package hex.optimization;

import java.util.Arrays;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

/* loaded from: input_file:hex/optimization/ADMM.class */
/**
 * ADMM (alternating direction method of multipliers) utilities: an L1-penalized / box-constrained
 * splitting solver built on top of a user-supplied {@link ProximalSolver}, plus the soft-thresholding
 * ({@link #shrinkage}) and sub-gradient ({@link #subgrad}) helpers it uses.
 *
 * <p>Convention used throughout: the LAST element of every coefficient vector is the intercept slot.
 * It is never L1-penalized and is skipped by {@link #subgrad}.
 */
public class ADMM {

    /** ADMM solver for an L1 penalty on all but the last coordinate plus optional per-coordinate box bounds. */
    public static class L1Solver {
        /** Relative tolerance for the residual-based stopping test (tightened locally during a solve). */
        final double RELTOL;
        /** Absolute tolerance for the residual-based stopping test; scaled by sqrt(N) in {@link #solve}. */
        final double ABSTOL;
        /** Norm of the projected sub-gradient at the last checked point; reset to +Infinity by each solve call. */
        double gerr;
        /** Iteration at which the ADMM loop returned (max_iter when it did not converge). */
        int iter;
        /** Convergence threshold on {@link #gerr}. */
        final double _eps;
        /** Maximum number of ADMM iterations. */
        final int max_iter;
        /** Norm used to reduce the projected sub-gradient to the scalar {@link #gerr}. */
        MathUtils.Norm _gradientNorm;
        public static double DEFAULT_RELTOL = 0.01d;
        public static double DEFAULT_ABSTOL = 1.0E-4d;

        /**
         * Selects the norm used to measure the projected sub-gradient.
         *
         * @param norm one of L_Infinite, L2_2, L2 or L1; other values make the error computation throw
         * @return this solver, for chaining
         */
        public L1Solver setGradientNorm(MathUtils.Norm norm) {
            this._gradientNorm = norm;
            return this;
        }

        /** Creates a solver with {@link #DEFAULT_RELTOL} and {@link #DEFAULT_ABSTOL}. */
        public L1Solver(double eps, int maxIter) {
            this(eps, maxIter, DEFAULT_RELTOL, DEFAULT_ABSTOL);
        }

        /**
         * @param eps     convergence threshold on the projected sub-gradient norm
         * @param maxIter maximum number of ADMM iterations
         * @param reltol  relative tolerance for the residual stopping test
         * @param abstol  absolute tolerance for the residual stopping test
         */
        public L1Solver(double eps, int maxIter, double reltol, double abstol) {
            this._gradientNorm = MathUtils.Norm.L_Infinite;
            this._eps = eps;
            this.max_iter = maxIter;
            this.RELTOL = reltol;
            this.ABSTOL = abstol;
        }

        /** Solves with an intercept and no bounds; see {@link #solve(ProximalSolver, double[], double, boolean, double[], double[])}. */
        public boolean solve(ProximalSolver solver, double[] z, double l1pen) {
            return solve(solver, z, l1pen, true, null, null);
        }

        /**
         * Computes the norm of the projected sub-gradient at {@code z}, stores it in {@link #gerr} and returns it.
         *
         * @param z     current solution
         * @param grad  gradient of the smooth part at {@code z} (not modified; a copy is used)
         * @param l1pen L1 penalty weight
         * @param lb    per-coordinate lower bounds, or null
         * @param ub    per-coordinate upper bounds, or null
         */
        private double computeErr(double[] z, double[] grad, double l1pen, double[] lb, double[] ub) {
            final double[] pg = grad.clone();
            // subgrad() never touches this slot (the intercept).
            final int last = pg.length - 1;
            if (lb != null) {
                for (int i = 0; i < z.length; i++) {
                    // At the lower bound a positive gradient points out of the feasible box.
                    if (z[i] == lb[i] && pg[i] > 0) {
                        pg[i] = blockedGradient(i, last, z[i], l1pen);
                    }
                }
            }
            if (ub != null) {
                for (int i = 0; i < z.length; i++) {
                    // At the upper bound a negative gradient points out of the feasible box.
                    if (z[i] == ub[i] && pg[i] < 0) {
                        pg[i] = blockedGradient(i, last, z[i], l1pen);
                    }
                }
            }
            ADMM.subgrad(l1pen, z, pg);
            switch (_gradientNorm) {
                case L_Infinite:
                    gerr = ArrayUtils.linfnorm(pg, false);
                    break;
                case L2_2:
                    gerr = ArrayUtils.l2norm2(pg, false);
                    break;
                case L2:
                    gerr = Math.sqrt(ArrayUtils.l2norm2(pg, false));
                    break;
                case L1:
                    gerr = ArrayUtils.l1norm(pg, false);
                    break;
                default:
                    throw H2O.unimpl();
            }
            return gerr;
        }

        /**
         * Returns the gradient value for a coordinate whose move is blocked by an active bound.
         * The value is chosen so that the projected sub-gradient becomes zero.
         *
         * <p>For penalized slots, {@link ADMM#subgrad} later adds sign(z)*l1pen (or soft-thresholds at
         * 1.001*l1pen when z == 0), which cancels -sign(z)*l1pen to zero. The last slot is skipped by
         * subgrad, so it must be zeroed here directly. Previously it was left at +/-l1pen and inflated gerr.
         */
        private static double blockedGradient(int i, int last, double zi, double l1pen) {
            if (i == last) {
                return 0.0d;
            }
            return zi >= 0.0d ? -l1pen : l1pen;
        }

        /** Clamps {@code v} into [lb[j], ub[j]], applying the lower bound first; null arrays mean "unbounded". */
        private static double clamp(double v, double[] lb, double[] ub, int j) {
            if (lb != null && v < lb[j]) {
                v = lb[j];
            }
            if (ub != null && v > ub[j]) {
                v = ub[j];
            }
            return v;
        }

        /**
         * Runs ADMM: the x-update is delegated to {@code solver}; the z-update is soft-thresholding
         * (threshold l1pen / rho[j]) followed by clamping to the bounds; u is the scaled dual variable.
         *
         * @param solver       proximal solver providing the x-update, rho and the gradient
         * @param z            in: starting point; out: the solution (modified in place)
         * @param l1pen        L1 penalty weight, applied to all but the last coordinate
         * @param hasIntercept if true, the last coordinate is updated (clamped, not penalized);
         *                     if false, it keeps its initial value
         * @param lb           per-coordinate lower bounds, or null
         * @param ub           per-coordinate upper bounds, or null
         * @return true if converged (gerr <= eps) or if there was nothing to split (no penalty, no bounds);
         *         false if max_iter was reached
         */
        public boolean solve(ProximalSolver solver, double[] z, double l1pen, boolean hasIntercept, double[] lb, double[] ub) {
            gerr = Double.POSITIVE_INFINITY;
            if (l1pen == 0.0d && lb == null && ub == null) {
                // Nothing to split: the proximal solver alone handles the problem.
                solver.solve(null, z);
                return true;
            }
            final int n = z.length;
            final int icpt = n - 1;    // intercept slot, never L1-penalized
            final double orlx = 1.0d;  // over-relaxation factor; 1.0 == plain ADMM
            double abstol = ABSTOL * Math.sqrt(n);
            double reltol = RELTOL;
            final double[] rho = solver.rho();
            final double[] u = MemoryManager.malloc8d(n);                // scaled dual variable
            final double[] x = z.clone();
            final double[] betaGiven = MemoryManager.malloc8d(n);        // z - u, passed to the x-update
            final double[] kappa = MemoryManager.malloc8d(rho.length);   // per-coordinate shrinkage threshold
            if (l1pen > 0.0d) {
                for (int j = 0; j < icpt; j++) {
                    kappa[j] = l1pen / rho[j];
                }
            }
            for (int i = 0; i < max_iter; i++) {
                solver.solve(betaGiven, x);
                double rnorm = 0.0d;
                double snorm = 0.0d;
                double unorm = 0.0d;
                double xnorm = 0.0d;
                for (int j = 0; j < icpt; j++) {
                    final double xj = x[j];
                    final double zjOld = z[j];
                    final double xHat = xj * orlx + (1.0d - orlx) * zjOld;
                    final double zj = clamp(ADMM.shrinkage(xHat + u[j], kappa[j]), lb, ub, j);
                    u[j] += xHat - zj;
                    betaGiven[j] = zj - u[j];
                    final double r = xj - zj;
                    final double s = zj - zjOld;
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += xj * xj;
                    unorm += rho[j] * rho[j] * u[j] * u[j];
                    z[j] = zj;
                }
                if (hasIntercept) {
                    final double icptVal = clamp(x[icpt], lb, ub, icpt);
                    final double r = x[icpt] - icptVal;
                    final double s = icptVal - z[icpt];
                    u[icpt] += r;
                    betaGiven[icpt] = icptVal - u[icpt];
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += icptVal * icptVal;
                    unorm += rho[icpt] * rho[icpt] * u[icpt] * u[icpt];
                    z[icpt] = icptVal;
                }
                // NOTE(review): rnorm/snorm are SQUARED norms, but the thresholds are built from un-squared
                // norms. This is kept as-is to preserve the existing convergence behavior; confirm intent.
                if (rnorm < abstol + reltol * Math.sqrt(xnorm) && snorm < abstol + reltol * Math.sqrt(unorm)) {
                    final double oldGerr = gerr;
                    computeErr(z, solver.gradient(z), l1pen, lb, ub);
                    if (gerr <= _eps) {
                        iter = i;
                        Log.info("ADMM.L1Solver: converged at iteration = " + i + ", gerr = " + gerr + ", inner solver took " + solver.iter() + " iterations");
                        return true;
                    }
                    Log.debug("ADMM.L1Solver: iter = " + i + " , gerr =  " + gerr + ", oldGerr = " + oldGerr + ", rnorm = " + rnorm + ", snorm  " + snorm);
                    // Residuals are small but the sub-gradient is not: tighten tolerances (down to a floor) and continue.
                    // Previously reltol was additionally multiplied by 0.1 unconditionally, defeating the floor.
                    if (abstol > 1.0E-12d) {
                        abstol *= 0.1d;
                    }
                    if (reltol > 1.0E-10d) {
                        reltol *= 0.1d;
                    }
                }
            }
            computeErr(z, solver.gradient(z), l1pen, lb, ub);
            Log.warn("ADMM solver reached maximum number of iterations (" + max_iter + ")");
            if (gerr > _eps) {
                Log.warn("ADMM solver finished with gerr = " + gerr + " >  eps = " + _eps);
            }
            iter = max_iter;
            return false;
        }

        /**
         * Heuristic estimate of the ADMM penalty rho for one coordinate.
         *
         * <p>Returns 0 for an infinite {@code x}. Otherwise, with a non-zero penalty and non-zero x, it
         * returns 0.25 times the positive root of a quadratic in (l1pen, x); it logs a warning and falls
         * back to 0 when that root is not positive. If either bound is finite, the estimate is replaced
         * entirely: 10 when x is within 1e-4 of (or beyond) a bound, else 0.1.
         *
         * @param x     current coefficient value
         * @param l1pen L1 penalty weight
         * @param lb    lower bound (use -Infinity for none)
         * @param ub    upper bound (use +Infinity for none)
         */
        public static double estimateRho(double x, double l1pen, double lb, double ub) {
            if (Double.isInfinite(x)) {
                return 0.0d;
            }
            double rho = 0.0d;
            if (l1pen != 0.0d && x != 0.0d) {
                if (x > 0.0d) {
                    final double disc = l1pen * (l1pen + 4.0d * x);
                    if (disc >= 0.0d) {
                        final double r = (l1pen + Math.sqrt(disc)) / (2.0d * x);
                        if (r > 0.0d) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(1)! r = " + r);
                        }
                    }
                } else if (x < 0.0d) {
                    final double disc = l1pen * (l1pen - 4.0d * x);
                    if (disc >= 0.0d) {
                        final double r = -(l1pen + Math.sqrt(disc)) / (2.0d * x);
                        if (r > 0.0d) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(2)!  r = " + r);
                        }
                    }
                }
                rho *= 0.25d;
            }
            if (!Double.isInfinite(lb) || !Double.isInfinite(ub)) {
                // Distance to the nearest bound; false for NaN, so NaN yields 0.1.
                rho = Math.min(x - lb, ub - x) < 1.0E-4d ? 10.0d : 0.1d;
            }
            return rho;
        }
    }

    /** Solver for the x-update of ADMM, supplied by the caller of {@link L1Solver#solve}. */
    public interface ProximalSolver {
        /** Per-coordinate ADMM penalties; indexed like the coefficient vector. */
        double[] rho();

        /**
         * Performs the x-update. The first argument is the ADMM target (z - u), or null when no splitting
         * is needed. The result is presumably written into the second argument; L1Solver reads it from
         * there (TODO confirm with implementations).
         */
        boolean solve(double[] betaGiven, double[] result);

        /** Not used by L1Solver; meaning defined by implementations. */
        boolean hasGradient();

        /** Gradient of the smooth objective at {@code beta}; same length as {@code beta}. */
        double[] gradient(double[] beta);

        /** Number of iterations the inner solver took (reported in L1Solver's log message). */
        int iter();
    }

    /**
     * Soft-thresholding operator: sign(x) * max(|x| - kappa, 0).
     *
     * @param x     input value
     * @param kappa threshold
     * @return 0 when |x| <= kappa, otherwise x moved towards 0 by kappa (NaN propagates)
     */
    public static double shrinkage(double x, double kappa) {
        final double sign = x < 0.0d ? -1.0d : 1.0d;
        final double abs = x * sign;
        if (abs <= kappa) {
            return 0.0d;
        }
        return sign * (abs - kappa);
    }

    /**
     * Converts {@code grad} in place into a sub-gradient of f + l1pen*||beta||_1 for every slot except the last.
     * Non-zero coefficients get +/- l1pen added. Zero coefficients are soft-thresholded at 1.001*l1pen.
     * The result is then shrunk slightly by 0.001*l1pen for non-zero coefficients. Does nothing when
     * {@code beta} is null.
     *
     * @param l1pen L1 penalty weight
     * @param beta  coefficients (last slot is the intercept and is skipped)
     * @param grad  gradient, overwritten with the sub-gradient
     */
    public static void subgrad(double l1pen, double[] beta, double[] grad) {
        if (beta == null) {
            return;
        }
        for (int i = 0; i < grad.length - 1; i++) {
            if (beta[i] < 0.0d) {
                grad[i] = shrinkage(grad[i] - l1pen, l1pen * 0.001d);
            } else if (beta[i] > 0.0d) {
                grad[i] = shrinkage(grad[i] + l1pen, l1pen * 0.001d);
            } else {
                grad[i] = shrinkage(grad[i], 1.001d * l1pen);
            }
        }
    }
}
