package hex.optimization;

import hex.optimization.L_BFGS;
import hex.optimization.OptimizationUtils;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

/* loaded from: input_file:hex/optimization/ADMM.class */
/**
 * ADMM (Alternating Direction Method of Multipliers) helpers for problems with an L1 penalty
 * and/or box constraints on the coefficients.
 *
 * <p>The smooth sub-problem is delegated to a {@link ProximalSolver}. {@link L1Solver} performs
 * the z-update (soft-thresholding via {@link #shrinkage} followed by clipping to the bounds) and
 * the scaled dual update. It checks convergence on the primal/dual residuals and then on the norm
 * of the L1 subgradient.
 */
public class ADMM {

    /**
     * Iterative ADMM solver for an L1-penalized and/or box-constrained problem.
     *
     * <p>The scaled dual variable {@link #_u} is allocated on the first solve and reused by later
     * calls, so a subsequent {@code solve} warm-starts from it.
     */
    public static class L1Solver {
        /** Relative tolerance for the primal/dual residual stopping test. */
        final double RELTOL;
        /** Absolute tolerance (scaled by sqrt(n) in {@code solve}) for the residual stopping test. */
        final double ABSTOL;
        /** Last gradient error computed by {@code computeErr}; +Infinity before the first check. */
        double gerr;
        /** Iteration index at which {@code solve} converged, or {@code max_iter} if it did not. */
        int iter;
        /** The solve is accepted as converged once {@code gerr <= _eps}. */
        final double _eps;
        final int max_iter;
        /** Norm used to reduce the projected subgradient to the scalar {@link #gerr}. */
        MathUtils.Norm _gradientNorm;
        /** Scaled dual variable; updated in place and reused across solves. */
        public double[] _u;
        public static double DEFAULT_RELTOL = 0.01d;
        public static double DEFAULT_ABSTOL = 1.0E-4d;
        /** Optional progress callback, invoked on every 5th iteration. */
        public L_BFGS.ProgressMonitor _pm;

        public L1Solver setGradientNorm(MathUtils.Norm norm) {
            _gradientNorm = norm;
            return this;
        }

        public L1Solver(double eps, int max_iter, double[] u) {
            this(eps, max_iter, DEFAULT_RELTOL, DEFAULT_ABSTOL, u);
        }

        public L1Solver(double eps, int max_iter, double reltol, double abstol, double[] u) {
            _gradientNorm = MathUtils.Norm.L_Infinite;
            _eps = eps;
            this.max_iter = max_iter;
            _u = u;
            RELTOL = reltol;
            ABSTOL = abstol;
        }

        /** Solves without box constraints; see the 6-argument overload. */
        public boolean solve(ProximalSolver solver, double[] z, double l1pen, boolean hasIntercept) {
            return solve(solver, z, l1pen, hasIntercept, null, null);
        }

        /**
         * Computes the gradient error at {@code z} and stores it in {@link #gerr}.
         *
         * <p>Steps:
         * <ul>
         *   <li>Work on a copy of {@code gradient}.</li>
         *   <li>For coefficients sitting on an active bound, replace the gradient with
         *       {@code -l1pen} or {@code +l1pen}. The value is chosen so that {@link ADMM#subgrad}
         *       maps it to 0.</li>
         *   <li>Apply the L1 subgradient correction.</li>
         *   <li>Reduce the result with {@link #_gradientNorm}.</li>
         * </ul>
         *
         * @param hasIntercept if true the last coefficient is unpenalized and skipped by the
         *                     subgradient correction
         * @return the new value of {@link #gerr}
         */
        private double computeErr(double[] z, double[] gradient, double l1pen, double[] lb, double[] ub,
                                  boolean hasIntercept) {
            double[] grad = gradient.clone();
            gerr = 0.0;
            if (lb != null) {
                for (int j = 0; j < z.length; j++) {
                    // Lower bound active: descending along -grad would leave the feasible box.
                    if (z[j] == lb[j] && grad[j] > 0.0) {
                        grad[j] = z[j] >= 0.0 ? -l1pen : l1pen;
                    }
                }
            }
            if (ub != null) {
                for (int j = 0; j < z.length; j++) {
                    // Upper bound active: descending along -grad would leave the feasible box.
                    if (z[j] == ub[j] && grad[j] < 0.0) {
                        grad[j] = z[j] >= 0.0 ? -l1pen : l1pen;
                    }
                }
            }
            ADMM.subgrad(l1pen, z, grad, hasIntercept);
            switch (_gradientNorm) {
                case L_Infinite:
                    gerr = ArrayUtils.linfnorm(grad, false);
                    break;
                case L2_2:
                    gerr = ArrayUtils.l2norm2(grad, false);
                    break;
                case L2:
                    gerr = Math.sqrt(ArrayUtils.l2norm2(grad, false));
                    break;
                case L1:
                    gerr = ArrayUtils.l1norm(grad, false);
                    break;
                default:
                    throw H2O.unimpl();
            }
            return gerr;
        }

        /**
         * Runs ADMM iterations and writes the solution into {@code z} in place.
         *
         * @param solver       solves the smooth sub-problem, writing into its second argument
         * @param z            initial point on entry; solution on exit
         * @param l1pen        L1 penalty strength (0 disables the soft-thresholding)
         * @param hasIntercept if true the last coefficient is not L1-penalized (only clipped)
         * @param lb           optional per-coefficient lower bounds (may be null)
         * @param ub           optional per-coefficient upper bounds (may be null)
         * @return true if converged (or trivially solved), false if the iteration limit was hit
         *         or {@code solver.solve} returned false first
         */
        public boolean solve(ProximalSolver solver, double[] z, double l1pen, boolean hasIntercept,
                             double[] lb, double[] ub) {
            gerr = Double.POSITIVE_INFINITY;
            iter = 0;
            if (l1pen == 0.0 && lb == null && ub == null) {
                // Nothing for ADMM to split off: a single solve of the smooth problem suffices.
                solver.solve(null, z);
                return true;
            }
            final int n = z.length;
            final int nPenalized = n - (hasIntercept ? 1 : 0);
            final double orlx = 1.0; // over-relaxation factor (1.0 = plain ADMM)
            double abstol = ABSTOL * Math.sqrt(n);
            double reltol = RELTOL;
            double[] rho = solver.rho();
            double[] x = z.clone();
            double[] betaGiven = MemoryManager.malloc8d(n);
            double[] u;
            if (_u != null) {
                // Warm start from the dual kept by a previous solve.
                u = _u;
                for (int j = 0; j < nPenalized; j++) {
                    betaGiven[j] = z[j] - _u[j];
                }
            } else {
                u = _u = MemoryManager.malloc8d(z.length);
            }
            // Per-coefficient soft-threshold level for the z-update.
            double[] kappa = MemoryManager.malloc8d(rho.length);
            if (l1pen > 0.0) {
                for (int j = 0; j < nPenalized; j++) {
                    kappa[j] = rho[j] != 0.0 ? l1pen / rho[j] : 0.0;
                }
            }

            int i = 0;
            for (; i < max_iter && solver.solve(betaGiven, x); i++) {
                if (_pm != null && (i + 1) % 5 == 0) {
                    _pm.progress(z, solver.gradient(z));
                }
                double rnorm = 0.0; // squared primal residual ||x - z||^2
                double snorm = 0.0; // squared change in z
                double unorm = 0.0; // squared norm of rho .* u
                double xnorm = 0.0; // squared norm of x
                for (int j = 0; j < nPenalized; j++) {
                    double xj = x[j];
                    double zOld = z[j];
                    double xHat = xj * orlx + (1.0 - orlx) * zOld;
                    double zj = ADMM.shrinkage(xHat + u[j], kappa[j]);
                    if (lb != null && zj < lb[j]) {
                        zj = lb[j];
                    }
                    if (ub != null && zj > ub[j]) {
                        zj = ub[j];
                    }
                    u[j] += xHat - zj;
                    betaGiven[j] = zj - u[j];
                    double r = xj - zj;
                    double s = zj - zOld;
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += xj * xj;
                    unorm += rho[j] * rho[j] * u[j] * u[j];
                    z[j] = zj;
                }
                if (hasIntercept) {
                    // The intercept is only clipped to its bounds, never soft-thresholded.
                    int idx = x.length - 1;
                    double icpt = x[idx];
                    if (lb != null && icpt < lb[idx]) {
                        icpt = lb[idx];
                    }
                    if (ub != null && icpt > ub[idx]) {
                        icpt = ub[idx];
                    }
                    double r = x[idx] - icpt;
                    double s = icpt - z[idx];
                    u[idx] += r;
                    betaGiven[idx] = icpt - u[idx];
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += icpt * icpt;
                    unorm += rho[idx] * rho[idx] * u[idx] * u[idx];
                    z[idx] = icpt;
                }
                if (rnorm < abstol + reltol * Math.sqrt(xnorm) && snorm < abstol + reltol * Math.sqrt(unorm)) {
                    // Residuals are small; accept only if the gradient check also passes.
                    double oldGerr = gerr;
                    computeErr(z, solver.gradient(z)._gradient, l1pen, lb, ub, hasIntercept);
                    if (gerr <= _eps) {
                        iter = i;
                        if (_pm != null && (i + 1) % 5 == 0) {
                            _pm.progress(z, solver.gradient(z));
                        }
                        return true;
                    }
                    Log.debug("ADMM.L1Solver: iter = " + i + " , gerr =  " + gerr + ", oldGerr = " + oldGerr + ", rnorm = " + rnorm + ", snorm  " + snorm);
                    // Gradient check failed: tighten the residual tolerances and keep iterating.
                    if (abstol > 1.0E-12d) {
                        abstol *= 0.1d;
                    }
                    if (reltol > 1.0E-10d) {
                        reltol *= 0.1d;
                    }
                    // NOTE(review): this unconditional decrease defeats the 1e-10 floor checked just
                    // above (reltol shrinks 100x while above 1e-10, then 10x per check without bound).
                    // Kept as-is to preserve convergence behaviour -- confirm whether it is intended.
                    reltol *= 0.1d;
                }
            }
            computeErr(z, solver.gradient(z)._gradient, l1pen, lb, ub, hasIntercept);
            // Compare the loop counter, not the field: `iter` is still 0 here.
            if (i == max_iter) {
                Log.warn("ADMM solver reached maximum number of iterations (" + max_iter + ")");
            } else {
                Log.warn("ADMM solver stopped after " + i + " iterations. (max_iter=" + max_iter + ")");
            }
            if (gerr > _eps) {
                Log.warn("ADMM solver finished with gerr = " + gerr + " >  eps = " + _eps);
            }
            // Signals "not converged" to callers, even when solver.solve stopped the loop early.
            iter = max_iter;
            if (_pm != null && (i + 1) % 5 == 0) {
                _pm.progress(z, solver.gradient(z));
            }
            return false;
        }

        public String toString() {
            return "iter = " + iter + ", gerr = " + gerr;
        }

        /**
         * Heuristic estimate of the ADMM penalty parameter rho for a single coefficient.
         *
         * <p>Without finite bounds, rho comes from a quadratic in {@code x} and {@code l1pen},
         * scaled by 0.25 (0 if either of them is 0).
         *
         * <p>With any finite bound, that estimate is replaced:
         * <ul>
         *   <li>10 if {@code x} is within 1e-4 of a bound;</li>
         *   <li>0.1 otherwise.</li>
         * </ul>
         *
         * @return the estimate; 0 when {@code x} is infinite
         */
        public static double estimateRho(double x, double l1pen, double lb, double ub) {
            if (Double.isInfinite(x)) {
                return 0.0;
            }
            double rho = 0.0;
            if (l1pen != 0.0 && x != 0.0) {
                if (x > 0.0) {
                    double disc = l1pen * (l1pen + 4.0d * x);
                    if (disc >= 0.0) {
                        double r = (l1pen + Math.sqrt(disc)) / (2.0d * x);
                        if (r > 0.0) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(1)! r = " + r);
                        }
                    }
                } else if (x < 0.0) {
                    double disc = l1pen * (l1pen - 4.0d * x);
                    if (disc >= 0.0) {
                        double r = (-(l1pen + Math.sqrt(disc))) / (2.0d * x);
                        if (r > 0.0) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(2)!  r = " + r);
                        }
                    }
                }
                rho *= 0.25d;
            }
            if (!Double.isInfinite(lb) || !Double.isInfinite(ub)) {
                // Stiff coupling when x sits on (or near) a bound, soft otherwise; NaN -> 0.1.
                rho = Math.min(x - lb, ub - x) < 1.0E-4d ? 10.0d : 0.1d;
            }
            return rho;
        }
    }

    /** The smooth sub-problem solver that {@link L1Solver} alternates with. */
    public interface ProximalSolver {
        /** Per-coefficient ADMM penalty parameters; must have one entry per coefficient. */
        double[] rho();

        /**
         * Solves the smooth sub-problem and writes the result into {@code result}.
         *
         * <p>{@link L1Solver} passes {@code null} as {@code beta_given} for the unpenalized,
         * unconstrained case, and {@code z - u} otherwise. Returning false stops the ADMM loop.
         */
        boolean solve(double[] beta_given, double[] result);

        boolean hasGradient();

        /** Gradient information of the smooth objective at {@code beta}. */
        OptimizationUtils.GradientInfo gradient(double[] beta);

        int iter();
    }

    /**
     * Soft-thresholding operator: {@code sign(x) * max(|x| - kappa, 0)}.
     *
     * <p>Returns exactly 0.0 when {@code |x| <= kappa}.
     */
    public static double shrinkage(double x, double kappa) {
        double sign = x < 0.0 ? -1.0d : 1.0d;
        double abs = x * sign;
        return abs <= kappa ? 0.0 : sign * (abs - kappa);
    }

    /**
     * Converts {@code grad} in place into the L1-penalized subgradient at {@code beta}, assuming
     * the last entry is an unpenalized intercept. Equivalent to
     * {@code subgrad(lambda, beta, grad, true)}.
     */
    public static void subgrad(double lambda, double[] beta, double[] grad) {
        subgrad(lambda, beta, grad, true);
    }

    /**
     * Converts {@code grad} in place into the L1-penalized subgradient at {@code beta}.
     *
     * <p>Per coefficient:
     * <ul>
     *   <li>Nonzero coefficient: {@code lambda * sign(beta)} is added, then values within
     *       {@code lambda * 1e-4} of zero are shrunk to 0.</li>
     *   <li>Zero coefficient: the gradient is soft-thresholded by {@code lambda}, which gives
     *       the minimum-norm subgradient.</li>
     * </ul>
     *
     * @param hasIntercept if true the last entry is treated as an unpenalized intercept and
     *                     left untouched
     */
    public static void subgrad(double lambda, double[] beta, double[] grad, boolean hasIntercept) {
        if (beta == null) {
            return;
        }
        final int end = hasIntercept ? grad.length - 1 : grad.length;
        final double slack = lambda * 1.0E-4d;
        for (int i = 0; i < end; i++) {
            if (beta[i] < 0.0) {
                grad[i] = shrinkage(grad[i] - lambda, slack);
            } else if (beta[i] > 0.0) {
                grad[i] = shrinkage(grad[i] + lambda, slack);
            } else {
                grad[i] = shrinkage(grad[i], lambda);
            }
        }
    }
}
