package hex.psvm.psvm;

import javassist.bytecode.Opcode;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Iced;
import water.MRTask;
import water.fvec.C8DVolatileChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.TransformWrappedVec;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM.class */
public class PrimalDualIPM {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$CheckConvergenceTask.class */
    /**
     * Map/reduce pass that finishes the {@code z} column for the current iterate and
     * accumulates the two residuals the solver compares against its feasibility threshold.
     *
     * <p>For every row the task adds {@code nu * sign - 1} to the value already stored in
     * {@code z} (where {@code sign} is +1 for a strictly positive label and -1 otherwise),
     * writes the result back, adds the square of {@code la - xi + z} to {@link #_resd} and
     * adds {@code label * x} to {@link #_resp}. After the global reduction {@link #_resd}
     * holds the square root of the summed squares and {@link #_resp} the absolute value of
     * the summed products.
     */
    public static class CheckConvergenceTask extends PDIPMTask<CheckConvergenceTask> {
        /** Input: multiplier applied to the label sign when updating {@code z}. */
        private final double _nu;
        /** Output: sum of squared {@code la - xi + z}; square-rooted in {@link #postGlobal()}. */
        double _resd;
        /** Output: sum of {@code label * x}; made absolute in {@link #postGlobal()}. */
        double _resp;

        /**
         * @param parms solver parameters forwarded to {@link PDIPMTask}
         * @param nu    multiplier applied to the label sign
         */
        CheckConvergenceTask(Parms parms, double nu) {
            super(parms);
            this._nu = nu;
        }

        /** Updates {@code z} in place and accumulates both residual sums for one chunk. */
        @Override
        void map() {
            final int rows = _z._len;
            for (int row = 0; row < rows; row++) {
                final double label = _label.atd(row);
                // A strictly positive label counts as +1, everything else as -1.
                final int sign = label > 0.0d ? 1 : -1;
                final double z = _z.atd(row) + ((_nu * sign) - 1.0d);
                final double dual = (_la.atd(row) - _xi.atd(row)) + z;
                _z.set(row, z);
                _resd += dual * dual;
                _resp += label * _x.atd(row);
            }
        }

        /** Adds the partial sums of another task instance into this one. */
        @Override
        public void reduce(CheckConvergenceTask checkConvergenceTask) {
            this._resd += checkConvergenceTask._resd;
            this._resp += checkConvergenceTask._resp;
        }

        /** Converts the accumulated sums into the final residual values. */
        @Override
        public void postGlobal() {
            this._resp = Math.abs(this._resp);
            this._resd = Math.sqrt(this._resd);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$DeltaNuTask.class */
    /**
     * Map/reduce pass producing the two sums whose ratio {@code computeDeltaNu} returns.
     *
     * <p>The chunk array is read as {@code p} leading matrix columns followed by four trailing
     * columns. As invoked from {@code computeDeltaNu} in this file, the trailing columns are,
     * in order, {@code d}, {@code z}, {@code label} and {@code x}.
     */
    public static class DeltaNuTask extends MRTask<DeltaNuTask> {
        /** Input: coefficients subtracted (per leading column) from the {@code z} value. */
        private final double[] _vz;
        /** Input: coefficients subtracted (per leading column) from the label value. */
        private final double[] _vl;
        /** Output: accumulated numerator. */
        double _sum1;
        /** Output: accumulated denominator. */
        double _sum2;

        DeltaNuTask(double[] vz, double[] vl) {
            this._vz = vz;
            this._vl = vl;
        }

        @Override
        public void map(Chunk[] chunkArr) {
            final int p = chunkArr.length - 4;
            final Chunk d = chunkArr[p];
            final Chunk z = chunkArr[p + 1];
            final Chunk label = chunkArr[p + 2];
            final Chunk x = chunkArr[p + 3];
            final int rows = label._len;
            for (int row = 0; row < rows; row++) {
                double tw = z.atd(row);
                double tl = label.atd(row);
                // Subtract the projection onto the leading columns from both running values.
                for (int col = 0; col < p; col++) {
                    final Chunk column = chunkArr[col];
                    tw -= column.atd(row) * _vz[col];
                    tl -= column.atd(row) * _vl[col];
                }
                _sum1 += label.atd(row) * ((tw * d.atd(row)) + x.atd(row));
                _sum2 += label.atd(row) * tl * d.atd(row);
            }
        }

        @Override
        public void reduce(DeltaNuTask deltaNuTask) {
            _sum1 += deltaNuTask._sum1;
            _sum2 += deltaNuTask._sum2;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$InitTask.class */
    /**
     * Initializes the {@code la} and {@code xi} columns: every row receives one tenth of the
     * bound selected by its label ({@code _c_pos} for a strictly positive label,
     * {@code _c_neg} otherwise).
     */
    public static class InitTask extends PDIPMTask<InitTask> {
        /** @param parms solver parameters forwarded to {@link PDIPMTask} */
        InitTask(Parms parms) {
            super(parms);
        }

        /** Writes the starting value into {@code la} and {@code xi} for each row of the chunk. */
        @Override
        public void map() {
            final int rows = _label._len;
            for (int row = 0; row < rows; row++) {
                final double bound = _label.atd(row) > 0.0d ? _c_pos : _c_neg;
                final double start = bound / 10.0d;
                _la.set(row, start);
                _xi.set(row, start);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$LSHelper1.class */
    /**
     * Map/reduce pass that forms, per leading column, the dot product of that column with the
     * element-wise product of two trailing columns.
     *
     * <p>The chunk array is read as {@code p} leading columns followed by two trailing
     * columns, or three when {@code _output_z} is set. In the three-column case the
     * element-wise product is written into the backing array of the last column, which must
     * be a {@link C8DVolatileChunk}; otherwise a temporary array is used.
     */
    public static class LSHelper1 extends MRTask<LSHelper1> {
        /** Input: whether the element-wise product is stored into the last column. */
        private final boolean _output_z;
        /** Output: one dot product per leading column, summed across chunks by reduce. */
        double[] _row;

        LSHelper1(boolean outputZ) {
            this._output_z = outputZ;
        }

        @Override
        public void map(Chunk[] chunkArr) {
            final int trailing = _output_z ? 3 : 2;
            final int p = chunkArr.length - trailing;
            _row = new double[p];
            final Chunk first = chunkArr[p];
            final Chunk second = chunkArr[p + 1];

            final double[] product;
            if (_output_z) {
                product = ((C8DVolatileChunk) chunkArr[p + 2]).getValues();
            } else {
                product = new double[first._len];
            }
            final int rows = product.length;
            for (int row = 0; row < rows; row++) {
                product[row] = second.atd(row) * first.atd(row);
            }

            for (int col = 0; col < p; col++) {
                final Chunk column = chunkArr[col];
                double acc = 0.0d;
                for (int row = 0; row < rows; row++) {
                    acc += column.atd(row) * product[row];
                }
                _row[col] = acc;
            }
        }

        @Override
        public void reduce(LSHelper1 lSHelper1) {
            ArrayUtils.add(_row, lSHelper1._row);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$LineSearchTask.class */
    /**
     * Map/reduce pass that fills the {@code dxi} and {@code dla} columns and determines the
     * two step lengths {@code _ap} and {@code _ad} later handed to {@link MakeStepTask}.
     *
     * <p>Per chunk the step lengths are the smallest ratios that keep {@code x} within
     * {@code [0, c]} (with {@code c} being {@code _c_pos} or {@code _c_neg} depending on the
     * label) and keep {@code xi} and {@code la} from going below zero. Chunks are combined with
     * {@code min}; the global result is capped at 1 and scaled by 0.99.
     */
    public static class LineSearchTask extends PDIPMTask<LineSearchTask> {
        /** Output: step length derived from the {@code x}/{@code dx} columns. */
        private double _ap;
        /** Output: step length derived from {@code xi}/{@code dxi} and {@code la}/{@code dla}. */
        private double _ad;

        /** @param parms solver parameters forwarded to {@link PDIPMTask} */
        LineSearchTask(Parms parms) {
            super(parms);
        }

        /**
         * Delegates to the worker below, handing it the chunks and the raw backing arrays of
         * {@code dxi} and {@code dla} (both must be {@link C8DVolatileChunk}s).
         */
        @Override
        public void map() {
            map(this._label, this._tlx, this._tux, this._xilx, this._laux, this._xi, this._la, this._dx, this._x,
                    ((C8DVolatileChunk) this._dxi).getValues(), ((C8DVolatileChunk) this._dla).getValues());
        }

        /**
         * Computes the directions for {@code xi} and {@code la} and the chunk-local step lengths.
         *
         * @param dxi output array receiving {@code tlx - xilx * dx - xi}
         * @param dla output array receiving {@code tux + laux * dx - la}
         */
        private void map(Chunk label, Chunk tlx, Chunk tux, Chunk xilx, Chunk laux, Chunk xi, Chunk la,
                         Chunk dx, Chunk x, double[] dxi, double[] dla) {
            final int rows = dxi.length;
            for (int row = 0; row < rows; row++) {
                dxi[row] = (tlx.atd(row) - (xilx.atd(row) * dx.atd(row))) - xi.atd(row);
                dla[row] = (tux.atd(row) + (laux.atd(row) * dx.atd(row))) - la.atd(row);
            }
            double ap = Double.MAX_VALUE;
            double ad = Double.MAX_VALUE;
            for (int row = 0; row < rows; row++) {
                final double bound = label.atd(row) > 0.0d ? this._c_pos : this._c_neg;
                final double step = dx.atd(row);
                if (step > 0.0d) {
                    // Moving up: distance to the upper bound limits the step.
                    ap = Math.min(ap, (bound - x.atd(row)) / step);
                }
                if (step < 0.0d) {
                    // Moving down: distance to zero limits the step.
                    ap = Math.min(ap, (-x.atd(row)) / step);
                }
                if (dxi[row] < 0.0d) {
                    ad = Math.min(ad, (-xi.atd(row)) / dxi[row]);
                }
                if (dla[row] < 0.0d) {
                    ad = Math.min(ad, (-la.atd(row)) / dla[row]);
                }
            }
            this._ap = ap;
            this._ad = ad;
        }

        /** Keeps the smaller of each step length. */
        @Override
        public void reduce(LineSearchTask lineSearchTask) {
            this._ap = Math.min(this._ap, lineSearchTask._ap);
            this._ad = Math.min(this._ad, lineSearchTask._ad);
        }

        /** Caps both step lengths at 1 and scales them by 0.99. */
        @Override
        public void postGlobal() {
            this._ap = Math.min(this._ap, 1.0d) * 0.99d;
            this._ad = Math.min(this._ad, 1.0d) * 0.99d;
        }
    }

    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$LinearCombTransform.class */
    /**
     * Row transform that returns the linear combination of its inputs with a fixed set of
     * coefficients: {@code sum(input[i] * coefs[i])}.
     */
    private static class LinearCombTransform implements TransformWrappedVec.Transform {
        /** Coefficient applied to each input, indexed by input position. */
        private final double[] _coefs;
        /** Running sum for the row currently being processed. */
        double _sum;

        LinearCombTransform(double[] coefs) {
            this._coefs = coefs;
        }

        /** Clears the running sum. */
        @Override
        public void reset() {
            this._sum = 0.0d;
        }

        /** Adds input {@code i}, weighted by its coefficient, to the running sum. */
        @Override
        public void setInput(int i, double d) {
            this._sum += d * this._coefs[i];
        }

        /** @return the sum accumulated since the last {@link #reset()} */
        @Override
        public double apply() {
            return this._sum;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$LinearCombTransformFactory.class */
    /**
     * Factory for {@link LinearCombTransform} instances that all share one coefficient array.
     */
    public static class LinearCombTransformFactory extends Iced<LinearCombTransformFactory> implements TransformWrappedVec.TransformFactory<LinearCombTransformFactory> {
        /** Coefficients handed to every transform this factory creates. */
        private final double[] _coefs;

        /** Creates a factory with an empty coefficient array. */
        public LinearCombTransformFactory() {
            _coefs = new double[0];
        }

        /** @param coefs one coefficient per input of the transform */
        LinearCombTransformFactory(double... coefs) {
            _coefs = coefs;
        }

        /**
         * @param i number of inputs the transform will receive
         * @return a new transform over this factory's coefficients
         * @throws IllegalArgumentException if {@code i} differs from the number of coefficients
         */
        @Override
        public TransformWrappedVec.Transform create(int i) {
            final int expected = _coefs.length;
            if (i == expected) {
                return new LinearCombTransform(_coefs);
            }
            throw new IllegalArgumentException("Expected " + expected + " inputs, got: " + i);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$MakeStepTask.class */
    /**
     * Applies one step in place to three (value, direction) column pairs. The first pair is
     * advanced by {@code _ap}, the second and third by {@code _ad}.
     *
     * <p>As invoked from the solver loop in this file the six columns are, in order,
     * {@code x, dx, xi, dxi, la, dla}.
     */
    public static class MakeStepTask extends MRTask<MakeStepTask> {
        /** Step length for the first pair. */
        double _ap;
        /** Step length for the second and third pairs. */
        double _ad;

        MakeStepTask(double ap, double ad) {
            _ap = ap;
            _ad = ad;
        }

        /** Unpacks the first six chunks and forwards them to the six-argument overload. */
        @Override
        public void map(Chunk[] chunkArr) {
            map(chunkArr[0], chunkArr[1], chunkArr[2], chunkArr[3], chunkArr[4], chunkArr[5]);
        }

        /**
         * Updates {@code chunk}, {@code chunk3} and {@code chunk5} row by row.
         *
         * @param chunk  values advanced by {@code _ap * chunk2}
         * @param chunk2 direction for {@code chunk}
         * @param chunk3 values advanced by {@code _ad * chunk4}
         * @param chunk4 direction for {@code chunk3}
         * @param chunk5 values advanced by {@code _ad * chunk6}
         * @param chunk6 direction for {@code chunk5}
         */
        public void map(Chunk chunk, Chunk chunk2, Chunk chunk3, Chunk chunk4, Chunk chunk5, Chunk chunk6) {
            final int rows = chunk._len;
            for (int row = 0; row < rows; row++) {
                advance(chunk, chunk2, row, _ap);
                advance(chunk3, chunk4, row, _ad);
                advance(chunk5, chunk6, row, _ad);
            }
        }

        /** Sets {@code target[row]} to {@code target[row] + length * direction[row]}. */
        private static void advance(Chunk target, Chunk direction, int row, double length) {
            target.set(row, target.atd(row) + (length * direction.atd(row)));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$PDIPMTask.class */
    /**
     * Base class for solver tasks that run over the 13-column workspace frame. The chunk
     * array is bound to named fields in a fixed positional order before the no-argument
     * {@link #map()} of the subclass is invoked.
     *
     * @param <E> concrete task type
     */
    public static abstract class PDIPMTask<E extends PDIPMTask<E>> extends MRTask<E> {
        // Chunk bindings for the current map call, in workspace column order (0..12).
        transient Chunk _label;
        transient Chunk _x;
        transient Chunk _z;
        transient Chunk _xi;
        transient Chunk _dxi;
        transient Chunk _la;
        transient Chunk _dla;
        transient Chunk _tlx;
        transient Chunk _tux;
        transient Chunk _xilx;
        transient Chunk _laux;
        transient Chunk _d;
        transient Chunk _dx;
        /** Bound used for rows with a strictly positive label (copied from {@link Parms}). */
        final double _c_pos;
        /** Bound used for all other rows (copied from {@link Parms}). */
        final double _c_neg;

        PDIPMTask(Parms parms) {
            _c_pos = parms._c_pos;
            _c_neg = parms._c_neg;
        }

        /** Binds the first 13 chunks to the named fields, then runs {@link #map()}. */
        @Override
        public void map(Chunk[] chunkArr) {
            int col = 0;
            _label = chunkArr[col++];
            _x = chunkArr[col++];
            _z = chunkArr[col++];
            _xi = chunkArr[col++];
            _dxi = chunkArr[col++];
            _la = chunkArr[col++];
            _dla = chunkArr[col++];
            _tlx = chunkArr[col++];
            _tux = chunkArr[col++];
            _xilx = chunkArr[col++];
            _laux = chunkArr[col++];
            _d = chunkArr[col++];
            _dx = chunkArr[col];
            map();
        }

        /** Per-chunk work of the concrete task; the chunk fields are bound when this runs. */
        abstract void map();
    }

    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$Parms.class */
    public static class Parms {
        public int _max_iter;
        public double _mu_factor;
        public double _tradeoff;
        public double _feasible_threshold;
        public double _sgap_threshold;
        public double _x_epsilon;
        public double _c_neg;
        public double _c_pos;

        public Parms() {
            this._max_iter = Opcode.GOTO_W;
            this._mu_factor = 10.0d;
            this._tradeoff = CMAESOptimizer.DEFAULT_STOPFITNESS;
            this._feasible_threshold = 0.001d;
            this._sgap_threshold = 0.001d;
            this._x_epsilon = 1.0E-9d;
            this._c_neg = Double.NaN;
            this._c_pos = Double.NaN;
        }

        public Parms(double d, double d2) {
            this._max_iter = Opcode.GOTO_W;
            this._mu_factor = 10.0d;
            this._tradeoff = CMAESOptimizer.DEFAULT_STOPFITNESS;
            this._feasible_threshold = 0.001d;
            this._sgap_threshold = 0.001d;
            this._x_epsilon = 1.0E-9d;
            this._c_neg = Double.NaN;
            this._c_pos = Double.NaN;
            this._c_pos = d;
            this._c_neg = d2;
        }
    }

    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$ProgressObserver.class */
    /**
     * Callback invoked by the solver once per iteration, right after the convergence check
     * and before any early exit.
     */
    public interface ProgressObserver {
        /**
         * @param i  zero-based iteration index
         * @param d  surrogate gap computed at the start of the iteration
         * @param d2 primal residual ({@code _resp} of the convergence task)
         * @param d3 dual residual ({@code _resd} of the convergence task)
         * @param z  whether the convergence criteria were met in this iteration
         */
        void reportProgress(int i, double d, double d2, double d3, boolean z);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$SurrogateGapTask.class */
    /**
     * Map/reduce pass computing the surrogate gap: the sum over all rows of
     * {@code la * c + x * (xi - la)}, where {@code c} is {@code _c_pos} for a strictly
     * positive label and {@code _c_neg} otherwise.
     */
    public static class SurrogateGapTask extends PDIPMTask<SurrogateGapTask> {
        /** Output: the gap for this chunk, summed across chunks by {@link #reduce}. */
        private double _sum;

        /** @param parms solver parameters forwarded to {@link PDIPMTask} */
        SurrogateGapTask(Parms parms) {
            super(parms);
        }

        /** Computes the chunk-local gap and stores it in {@code _sum}. */
        @Override
        void map() {
            final int rows = _x._len;
            double total = 0.0d;
            // The two terms are summed in separate passes; keep them separate so the
            // floating-point accumulation order stays the same.
            for (int row = 0; row < rows; row++) {
                final double bound = _label.atd(row) > 0.0d ? _c_pos : _c_neg;
                total += _la.atd(row) * bound;
            }
            for (int row = 0; row < rows; row++) {
                total += _x.atd(row) * (_xi.atd(row) - _la.atd(row));
            }
            _sum = total;
        }

        /** Adds the partial gap of another task instance into this one. */
        @Override
        public void reduce(SurrogateGapTask surrogateGapTask) {
            this._sum += surrogateGapTask._sum;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/psvm/psvm/PrimalDualIPM$UpdateVarsTask.class */
    /**
     * Recomputes the per-row helper columns {@code tlx}, {@code tux}, {@code xilx},
     * {@code laux} and {@code d} from the current {@code x}, {@code xi} and {@code la}, and
     * replaces {@code z} by {@code tlx - tux - z}.
     *
     * <p>Distances of {@code x} to zero and to its bound, as well as the two ratios
     * {@code xi / x} and {@code la / (c - x)}, are floored at {@code _epsilon_x}.
     */
    public static class UpdateVarsTask extends PDIPMTask<UpdateVarsTask> {
        /** Input: floor applied to distances and ratios (from {@code Parms._x_epsilon}). */
        private final double _epsilon_x;
        /** Input: the {@code t} value of the current iteration. */
        private final double _t;

        /**
         * @param parms solver parameters; supplies the bounds and the epsilon floor
         * @param t     the {@code t} value of the current iteration
         */
        UpdateVarsTask(Parms parms, double t) {
            super(parms);
            this._epsilon_x = parms._x_epsilon;
            this._t = t;
        }

        /** Updates all helper columns and {@code z} for one chunk. */
        @Override
        void map() {
            final int rows = _z._len;
            for (int row = 0; row < rows; row++) {
                final double bound = _label.atd(row) > 0.0d ? _c_pos : _c_neg;
                final double x = _x.atd(row);
                // Floored distances of x to its lower (0) and upper (bound) limits.
                final double toLower = Math.max(x, _epsilon_x);
                final double toUpper = Math.max(bound - x, _epsilon_x);
                final double tlx = 1.0d / (_t * toLower);
                final double tux = 1.0d / (_t * toUpper);
                _tlx.set(row, tlx);
                _tux.set(row, tux);
                final double xilx = Math.max(_xi.atd(row) / toLower, _epsilon_x);
                final double laux = Math.max(_la.atd(row) / toUpper, _epsilon_x);
                _d.set(row, 1.0d / (xilx + laux));
                _xilx.set(row, xilx);
                _laux.set(row, laux);
                _z.set(row, (tlx - tux) - _z.atd(row));
            }
        }
    }

    /**
     * Runs the solver on the given matrix frame and label column.
     *
     * <p>Validates the label, allocates the eleven volatile work columns and delegates to the
     * private overload. The work columns are removed when that call finishes, whether it
     * returns normally or throws.
     *
     * @param frame            matrix columns the solver operates on
     * @param vec              label column; must have minimum -1 and maximum +1
     * @param parms            solver parameters
     * @param progressObserver optional per-iteration callback; may be {@code null}
     * @return the solution column {@code x}
     * @throws IllegalArgumentException if the label column fails the +1/-1 check
     */
    public static Vec solve(Frame frame, Vec vec, Parms parms, ProgressObserver progressObserver) {
        checkLabel(vec);
        final Frame volatileWorkspace = makeVolatileWorkspace(vec,
                "z", "xi", "dxi", "la", "dla", "tlx", "tux", "xilx", "laux", "d", "dx");
        try {
            return solve(frame, vec, parms, volatileWorkspace, progressObserver);
        } finally {
            // Single cleanup point: runs exactly once on both the normal and the failing path.
            volatileWorkspace.remove();
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Main iteration loop of the solver.
     *
     * <p>Builds the 13-column workspace ({@code label}, a zero-initialized {@code x}, then the
     * supplied volatile columns), initializes it and iterates up to {@code _max_iter} times.
     * Each iteration computes the surrogate gap, checks convergence, and otherwise updates
     * the helper columns, solves for the step direction, line-searches and applies the step.
     *
     * <p>The {@code x} column is allocated here; if anything in this method throws, it is
     * removed again before the exception propagates so it does not stay behind.
     *
     * @param frame             matrix columns the solver operates on
     * @param label             label column
     * @param params            solver parameters
     * @param volatileWorkspace frame with the columns
     *                          {@code z, xi, dxi, la, dla, tlx, tux, xilx, laux, d, dx};
     *                          removed by this method on normal completion
     * @param observer          optional per-iteration callback; may be {@code null}
     * @return the solution column {@code x}
     */
    private static Vec solve(Frame frame, Vec label, Parms params, Frame volatileWorkspace, ProgressObserver observer) {
        final Vec solution = label.makeZero();
        boolean succeeded = false;
        try {
            final Frame workspace = new Frame(new String[]{"label"}, new Vec[]{label});
            workspace.add("x", solution);
            workspace.add(volatileWorkspace);
            new InitTask(params).doAll(workspace);

            final Vec z = workspace.vec("z");
            final Vec la = workspace.vec("la");
            final Vec xi = workspace.vec("xi");
            final Vec x = workspace.vec("x");
            final Vec dxi = workspace.vec("dxi");
            final Vec dla = workspace.vec("dla");
            final Vec d = workspace.vec("d");
            final Vec dx = workspace.vec("dx");

            double nu = 0.0d;
            boolean converged = false;
            final long numConstraints = frame.numRows() * 2;
            for (int iter = 0; iter < params._max_iter; iter++) {
                final double gap = ((SurrogateGapTask) new SurrogateGapTask(params).doAll(workspace))._sum;
                final double t = (params._mu_factor * numConstraints) / gap;
                Log.info("Surrogate gap before iteration " + iter + ": " + gap + "; t: " + t);

                computePartialZ(frame, x, params._tradeoff, z);
                final CheckConvergenceTask check = (CheckConvergenceTask) new CheckConvergenceTask(params, nu).doAll(workspace);
                Log.info("Residual (primal): " + check._resp + "; residual (dual): " + check._resd + ". Feasible threshold: " + params._feasible_threshold);
                converged = check._resp <= params._feasible_threshold
                        && check._resd <= params._feasible_threshold
                        && gap <= params._sgap_threshold;
                if (observer != null) {
                    observer.reportProgress(iter, gap, check._resp, check._resd, converged);
                }
                if (converged) {
                    break;
                }

                new UpdateVarsTask(params, t).doAll(workspace);
                final LLMatrix system = MatrixUtils.productMtDM(frame, d);
                system.addUnitMat();
                final LLMatrix factor = system.cf();
                final double deltaNu = computeDeltaNu(frame, d, label, z, x, factor);
                computeDeltaX(frame, d, label, deltaNu, factor, z, dx);
                final LineSearchTask search = (LineSearchTask) new LineSearchTask(params).doAll(workspace);
                new MakeStepTask(search._ap, search._ad).doAll(x, dx, xi, dxi, la, dla);
                nu += search._ad * deltaNu;
            }
            if (!converged) {
                Log.warn("The algorithm didn't converge in the maximum number of iterations. Please consider changing the convergence parameters or increase the maximum number of iterations (" + params._max_iter + ").");
            }
            volatileWorkspace.remove();
            succeeded = true;
            return x;
        } finally {
            if (!succeeded) {
                // Nobody will receive the solution column, so release it here.
                solution.remove();
            }
        }
    }

    /**
     * Verifies that the label column has minimum -1 and maximum +1.
     *
     * @throws IllegalArgumentException if either extreme differs
     */
    private static void checkLabel(Vec label) {
        final boolean plusMinusOne = label.min() == -1.0d && label.max() == 1.0d;
        if (!plusMinusOne) {
            throw new IllegalArgumentException("Expected a binary response encoded as +1/-1");
        }
    }

    /**
     * Fills {@code z} row by row with the dot product of the row's matrix columns and
     * {@code MatrixUtils.productMtv(frame, x)}, minus {@code tradeoff * x}.
     *
     * @param frame    matrix columns
     * @param x        current solution column
     * @param tradeoff coefficient applied to {@code x}
     * @param z        output column, overwritten
     */
    private static void computePartialZ(Frame frame, Vec x, final double tradeoff, Vec z) {
        final double[] weights = MatrixUtils.productMtv(frame, x);
        final Vec[] columns = (Vec[]) ArrayUtils.append(frame.vecs(), x, z);
        new MRTask() {
            @Override
            public void map(Chunk[] chunkArr) {
                // Layout: p matrix columns, then x, then z.
                final int p = chunkArr.length - 2;
                final Chunk xs = chunkArr[p];
                final Chunk zs = chunkArr[p + 1];
                final int rows = chunkArr[0]._len;
                for (int row = 0; row < rows; row++) {
                    double dot = 0.0d;
                    for (int col = 0; col < p; col++) {
                        dot += chunkArr[col].atd(row) * weights[col];
                    }
                    zs.set(row, dot - (tradeoff * xs.atd(row)));
                }
            }
        }.doAll(columns);
    }

    /**
     * Computes the step direction for {@code x} into {@code dx}.
     *
     * <p>Wraps {@code z - deltaNu * label} in a temporary transform column, hands it to
     * {@code linearSolveViaICFCol} as the right-hand side, and removes the temporary column
     * afterwards, whether the solve succeeded or threw.
     *
     * @param frame   matrix columns
     * @param d       the {@code d} column
     * @param label   label column
     * @param deltaNu value returned by {@code computeDeltaNu}
     * @param factor  Cholesky factor used for the solve
     * @param z       the {@code z} column
     * @param dx      output column
     */
    private static void computeDeltaX(Frame frame, Vec d, Vec label, double deltaNu, LLMatrix factor, Vec z, Vec dx) {
        final TransformWrappedVec rhs =
                new TransformWrappedVec(new Vec[]{z, label}, new LinearCombTransformFactory(1.0d, -deltaNu));
        try {
            linearSolveViaICFCol(frame, d, rhs, factor, dx);
        } finally {
            // Single cleanup point: runs exactly once on both the normal and the failing path.
            rhs.remove();
        }
    }

    /**
     * Computes the ratio {@code _sum1 / _sum2} produced by a {@link DeltaNuTask} run over the
     * matrix columns followed by {@code d, z, label, x}.
     *
     * @param frame  matrix columns
     * @param d      the {@code d} column
     * @param label  label column
     * @param z      the {@code z} column
     * @param x      current solution column
     * @param factor Cholesky factor used for the two partial solves
     * @return the ratio of the two accumulated sums
     */
    private static double computeDeltaNu(Frame frame, Vec d, Vec label, Vec z, Vec x, LLMatrix factor) {
        final double[] vz = partialLinearSolveViaICFCol(frame, d, z, factor);
        final double[] vl = partialLinearSolveViaICFCol(frame, d, label, factor);
        final Vec[] columns = (Vec[]) ArrayUtils.append(frame.vecs(), d, z, label, x);
        final DeltaNuTask sums = new DeltaNuTask(vz, vl).doAll(columns);
        return sums._sum1 / sums._sum2;
    }

    /**
     * Runs {@link LSHelper1} (without storing the element-wise product) over the matrix
     * columns followed by {@code d} and {@code b}, then applies {@code cholSolve} of the given
     * factor to the resulting row.
     *
     * @return the vector returned by {@code factor.cholSolve}
     */
    private static double[] partialLinearSolveViaICFCol(Frame frame, Vec d, Vec b, LLMatrix factor) {
        final Vec[] columns = (Vec[]) ArrayUtils.append(frame.vecs(), d, b);
        final LSHelper1 helper = new LSHelper1(false).doAll(columns);
        return factor.cholSolve(helper._row);
    }

    /**
     * Two-pass solve that leaves its result in {@code out}.
     *
     * <p>Pass one runs {@link LSHelper1} with output enabled over the matrix columns followed
     * by {@code d}, {@code b} and {@code out}; this stores {@code b * d} in {@code out} and
     * yields the row handed to {@code cholSolve}. Pass two subtracts from every row of
     * {@code out} the sum over matrix columns of {@code column * solved[col] * d}.
     *
     * @param frame  matrix columns
     * @param d      the {@code d} column
     * @param b      right-hand-side column
     * @param factor Cholesky factor used for the solve
     * @param out    output column; must be volatile because {@link LSHelper1} writes to it
     *               through its backing array
     */
    private static void linearSolveViaICFCol(Frame frame, Vec d, Vec b, LLMatrix factor, Vec out) {
        final Vec[] firstPass = (Vec[]) ArrayUtils.append(frame.vecs(), d, b, out);
        final LSHelper1 helper = new LSHelper1(true).doAll(firstPass);
        final double[] solved = factor.cholSolve(helper._row);
        final Vec[] secondPass = (Vec[]) ArrayUtils.append(frame.vecs(), d, out);
        new MRTask() {
            @Override
            public void map(Chunk[] chunkArr) {
                // Layout: p matrix columns, then d, then the output column.
                final int p = chunkArr.length - 2;
                final Chunk ds = chunkArr[p];
                final Chunk target = chunkArr[p + 1];
                final int rows = chunkArr[0]._len;
                for (int row = 0; row < rows; row++) {
                    double correction = 0.0d;
                    for (int col = 0; col < p; col++) {
                        correction += chunkArr[col].atd(row) * solved[col] * ds.atd(row);
                    }
                    target.set(row, target.atd(row) - correction);
                }
            }
        }.doAll(secondPass);
    }

    /**
     * Builds a frame of volatile double columns, one per given name, created from the
     * template column via {@code makeVolatileDoubles}.
     *
     * @param template column the new columns are derived from
     * @param names    column names; their count determines how many columns are created
     * @return the new frame
     */
    private static Frame makeVolatileWorkspace(Vec template, String... names) {
        final Vec[] columns = template.makeVolatileDoubles(names.length);
        return new Frame(names, columns);
    }
}
