package hex.optimization;

import hex.glm.ComputationState;
import hex.glm.ConstrainedGLMUtils;
import hex.glm.GLM;
import java.util.Arrays;
import java.util.List;
import water.Iced;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/optimization/OptimizationUtils.class */
public class OptimizationUtils {

    /**
     * Weak-Wolfe line search by bracketing and bisection, used together with the constrained GLM
     * utilities ({@link ConstrainedGLMUtils}).
     *
     * <p>The search direction is {@code d = target - beta0}, where {@code beta0} is the current
     * coefficient vector of the {@link ComputationState}. Trial points {@code beta0 + alpha * d} are
     * evaluated until both
     * <ul>
     *   <li>{@code f(alpha) <= f(0) + _betaLS1 * alpha * g(0)'d} (sufficient decrease) and</li>
     *   <li>{@code g(alpha)'d >= _betaLS2 * g(0)'d} (curvature)</li>
     * </ul>
     * hold. The step is kept inside the bracket {@code [_alphal, _alphar]}: it is bisected once it is
     * known to be too long and multiplied by {@code _lambdaLS} while no upper end is known.
     */
    public static final class ExactLineSearch {
        /** Lower end of the step bracket. */
        public double _alphal;
        /** Upper end of the step bracket; {@code +Inf} until a too-long step has been seen. */
        public double _alphar;
        /** Current trial step. */
        public double _alphai;
        /** Search direction {@code target - beta0}. */
        public double[] _direction;
        /** Start point: the array returned by {@code ComputationState.beta()} (a reference, not a copy). */
        public double[] _originalBeta;
        /** Accepted point, set by a successful {@link #findAlpha} or by {@link #setBetaConstraintsDeriv}. */
        public double[] _newBeta;
        /** Gradient info of the reference point; replaced by the accepted point's info on success. */
        public GLM.GLMGradientInfo _ginfoOriginal;
        /** Directional derivative {@code g(0)'d} at the start point. */
        public double _currGradDirIP;
        /** Coefficient names, passed to {@code ConstrainedGLMUtils.updateConstraintValues} with each point. */
        public String[] _coeffNames;
        /** Sufficient-decrease (Armijo) constant. */
        public final double _betaLS1 = 1.0E-4d;
        /** Curvature constant. */
        public final double _betaLS2 = 0.99d;
        /** Expansion factor applied to the step while the bracket has no upper end. */
        public final double _lambdaLS = 2.0d;
        /** Maximum number of trial points evaluated by one {@link #findAlpha} call. */
        public int _maxIteration = 50;

        /**
         * Creates a search from the state's current coefficients toward {@code dArr}.
         *
         * @param dArr target coefficient vector
         * @param computationState state providing the current coefficients and gradient info
         * @param list coefficient names
         */
        public ExactLineSearch(double[] dArr, ComputationState computationState, List<String> list) {
            reset(dArr, computationState, list);
        }

        /**
         * Re-initializes the search: direction {@code dArr - state.beta()}, start gradient info from
         * the state, bracket {@code [0, +Inf)} and initial step 1.
         *
         * @param dArr target coefficient vector
         * @param computationState state providing the current coefficients and gradient info
         * @param list coefficient names
         */
        public void reset(double[] dArr, ComputationState computationState, List<String> list) {
            this._direction = new double[dArr.length];
            ArrayUtils.subtract(dArr, computationState.beta(), this._direction);
            this._ginfoOriginal = computationState.ginfo();
            this._originalBeta = computationState.beta();
            this._alphai = 1.0d;
            this._alphal = 0.0d;
            this._alphar = Double.POSITIVE_INFINITY;
            this._coeffNames = list.toArray(new String[0]);
            this._currGradDirIP = ArrayUtils.innerProduct(this._ginfoOriginal._gradient, this._direction);
        }

        /**
         * Sufficient-decrease test for the current trial step {@code _alphai}. A NaN objective fails.
         */
        public boolean evaluateFirstWolfe(GLM.GLMGradientInfo gLMGradientInfo) {
            return gLMGradientInfo._objVal <= this._ginfoOriginal._objVal + ((this._alphai * this._betaLS1) * this._currGradDirIP);
        }

        /**
         * Curvature test: the directional derivative at the trial point must have risen to at least
         * {@code _betaLS2} times the initial one. A NaN gradient fails.
         */
        public boolean evaluateSecondWolfe(GLM.GLMGradientInfo gLMGradientInfo) {
            return ArrayUtils.innerProduct(gLMGradientInfo._gradient, this._direction) >= this._betaLS2 * this._currGradDirIP;
        }

        /**
         * Updates the bracket and picks the next trial step from the Wolfe test outcomes.
         *
         * @param firstWolfe whether sufficient decrease held at {@code _alphai}
         * @param secondWolfe whether the curvature condition held at {@code _alphai}
         * @return {@code false} only when both conditions hold (nothing to update), {@code true}
         *     after a new trial step has been chosen
         */
        public boolean setAlphai(boolean firstWolfe, boolean secondWolfe) {
            if (!firstWolfe) {
                // Sufficient decrease failed: the step is too long, whatever the curvature test says.
                // This also covers trial points whose objective/gradient overflowed to NaN/Inf.
                this._alphar = this._alphai;
                this._alphai = 0.5d * (this._alphal + this._alphar);
                return true;
            }
            if (secondWolfe) {
                return false;
            }
            // Sufficient decrease holds but the curvature condition fails: the step is too short.
            this._alphal = this._alphai;
            if (this._alphar < Double.POSITIVE_INFINITY) {
                this._alphai = 0.5d * (this._alphal + this._alphar);
            } else {
                this._alphai = this._lambdaLS * this._alphai;
            }
            return true;
        }

        /**
         * Makes {@code dArr3} the accepted point: updates the constraint values and the state's
         * constraint info at {@code dArr3}, then recomputes {@code _ginfoOriginal} there.
         *
         * <p>NOTE(review): {@code dArr}/{@code dArr2} are forwarded unchanged to
         * {@code ConstrainedGLMUtils.calGradient}; presumably the constraint multipliers - confirm.
         */
        public void setBetaConstraintsDeriv(double[] dArr, double[] dArr2, ComputationState computationState, ConstrainedGLMUtils.LinearConstraints[] linearConstraintsArr, ConstrainedGLMUtils.LinearConstraints[] linearConstraintsArr2, GLM.GLMGradientSolver gLMGradientSolver, double[] dArr3) {
            this._newBeta = dArr3;
            ConstrainedGLMUtils.updateConstraintValues(dArr3, Arrays.asList(this._coeffNames), linearConstraintsArr, linearConstraintsArr2);
            ConstrainedGLMUtils.calculateConstraintSquare(computationState, linearConstraintsArr, linearConstraintsArr2);
            computationState.updateConstraintInfo(linearConstraintsArr, linearConstraintsArr2);
            this._ginfoOriginal = ConstrainedGLMUtils.calGradient(dArr3, computationState, gLMGradientSolver, dArr, dArr2, linearConstraintsArr, linearConstraintsArr2);
        }

        /**
         * Searches for a step satisfying both Wolfe conditions.
         *
         * <p>Each trial updates the constraint values and the state's constraint info at the trial
         * point, so after a failed search they reflect the last trial point.
         *
         * @return {@code true} if a step was found; then {@code _newBeta} and {@code _ginfoOriginal}
         *     hold the accepted point. {@code false} if {@code g(0)'d > 0} (not a descent direction),
         *     if the bracket collapsed below 1e-12, or after {@code _maxIteration} trials.
         */
        public boolean findAlpha(double[] dArr, double[] dArr2, ComputationState computationState, ConstrainedGLMUtils.LinearConstraints[] linearConstraintsArr, ConstrainedGLMUtils.LinearConstraints[] linearConstraintsArr2, GLM.GLMGradientSolver gLMGradientSolver) {
            if (this._currGradDirIP > 0.0d) {
                return false;
            }
            final List<String> coeffNames = Arrays.asList(this._coeffNames);
            double[] scaledDirection = new double[this._originalBeta.length];
            for (int iter = 0; iter < this._maxIteration; iter++) {
                ArrayUtils.mult(this._direction, scaledDirection, this._alphai);
                double[] trialBeta = ArrayUtils.add(scaledDirection, this._originalBeta);
                ConstrainedGLMUtils.updateConstraintValues(trialBeta, coeffNames, linearConstraintsArr, linearConstraintsArr2);
                ConstrainedGLMUtils.calculateConstraintSquare(computationState, linearConstraintsArr, linearConstraintsArr2);
                computationState.updateConstraintInfo(linearConstraintsArr, linearConstraintsArr2);
                GLM.GLMGradientInfo trialGinfo = ConstrainedGLMUtils.calGradient(trialBeta, computationState, gLMGradientSolver, dArr, dArr2, linearConstraintsArr, linearConstraintsArr2);
                boolean firstWolfe = evaluateFirstWolfe(trialGinfo);
                boolean secondWolfe = evaluateSecondWolfe(trialGinfo);
                if (firstWolfe && secondWolfe) {
                    this._newBeta = trialBeta;
                    this._ginfoOriginal = trialGinfo;
                    return true;
                }
                // Bracket collapsed: no acceptable step can be distinguished any more.
                if (!setAlphai(firstWolfe, secondWolfe) || this._alphar < 1.0E-12d) {
                    return false;
                }
            }
            return false;
        }
    }

    /**
     * Objective value of an optimization problem at some point, together with its gradient.
     */
    public static class GradientInfo extends Iced {
        /** Objective function value. */
        public double _objVal;
        /**
         * Gradient of the objective. NOTE(review): may be {@code null} for objective-only results;
         * MoreThuente explicitly rejects such instances.
         */
        public double[] _gradient;

        /**
         * @param objVal objective function value
         * @param grad gradient of the objective
         */
        public GradientInfo(double objVal, double[] grad) {
            _objVal = objVal;
            _gradient = grad;
        }

        /**
         * Checks that the objective is not NaN and the gradient holds no NaN or infinite entries.
         *
         * @return {@code true} if the values are usable
         */
        public boolean isValid() {
            if (Double.isNaN(_objVal)) {
                return false;
            }
            return !ArrayUtils.hasNaNsOrInfs(_gradient);
        }

        public String toString() {
            return new StringBuilder(" objVal = ")
                .append(_objVal)
                .append(", ")
                .append(Arrays.toString(_gradient))
                .toString();
        }
    }

    /**
     * Evaluates the objective function of an optimization problem at a coefficient vector.
     */
    public interface GradientSolver {
        /**
         * Computes the objective value and its gradient at {@code beta}.
         *
         * @param beta point at which to evaluate
         * @return objective value and gradient
         */
        GradientInfo getGradient(double[] beta);

        /**
         * Computes the objective value at {@code beta}.
         *
         * <p>NOTE(review): callers in this file read only {@code _objVal} from the result; whether
         * the gradient is filled in presumably depends on the implementation - confirm.
         *
         * @param beta point at which to evaluate
         * @return objective value (gradient possibly unset)
         */
        GradientInfo getObjective(double[] beta);
    }

    /**
     * A line search that moves a current point along a given direction.
     */
    public interface LineSearchSolver {
        /**
         * Searches along {@code direction} from the current point {@link #getX()}.
         * The implementations in this file return {@code true} only when the objective decreased,
         * in which case the current point is moved.
         *
         * @param direction search direction
         * @return whether an improving point was accepted
         */
        boolean evaluate(double[] direction);

        /** @return step length of the last accepted point */
        double step();

        /** @return objective/gradient information at the current point */
        GradientInfo ginfo();

        /**
         * Sets the first trial step of subsequent searches (implementations may ignore it).
         *
         * @param initialStep initial trial step
         * @return this solver
         */
        LineSearchSolver setInitialStep(double initialStep);

        /** @return number of function evaluations used, or a negative value if not tracked */
        int nfeval();

        /** @return objective value at the current point */
        double getObj();

        /** @return the current point */
        double[] getX();
    }

    /**
     * More-Thuente line search, following the structure of MINPACK's {@code mcsrch}/{@code cstep}.
     *
     * <p>Looks for a step {@code stp} along a descent direction {@code d} with
     * {@code f(stp) < f(0) + stp * _ftol * g(0)'d} and {@code |g(stp)'d| <= _gtol * |g(0)'d|},
     * choosing trial steps by safeguarded cubic/quadratic interpolation ({@link #nextStep}). The
     * outcome of the last {@link #evaluate} call is stored in {@code _returnStatus}, an index into
     * {@link #messages}.
     */
    public static final class MoreThuente implements LineSearchSolver {
        /** Lower bound of the current step interval. */
        double _stMin;
        /** Upper bound of the current step interval. */
        double _stMax;
        /** First trial step of {@link #evaluate}; see {@link #setInitialStep}. */
        double _initialStep;
        /** Relative decrease of the start objective a point needs to be remembered as fallback. */
        double _minRelativeImprovement;
        private final GradientSolver _gslvr;
        /** Current point; replaced by the accepted point after an improving {@link #evaluate}. */
        private double[] _beta;
        /** Relative tolerance on the bracket width (status 2). */
        double _xtol;
        /** Sufficient-decrease parameter. */
        double _ftol;
        /** Curvature parameter. */
        double _gtol;
        /** Extrapolation factor for the step interval while no bracket is known. */
        double _xtrapf;
        /** Objective, directional derivative and step at the endpoint with the lowest value. */
        double _fvx;
        double _dgx;
        double _stx;
        /** Fallback point: lowest {@code f - ftest} among points whose objective beat the threshold. */
        double _bestStep;
        GradientInfo _betGradient;
        double _bestPsiVal;
        /** Gradient info at {@code _stx}. */
        GradientInfo _ginfox;
        /** Objective, directional derivative and step at the other endpoint of the interval. */
        double _fvy;
        double _dgy;
        double _sty;
        /** Whether a minimizer has been bracketed. */
        boolean _brackt;
        /** Whether the next step is additionally limited by the 0.66 bracket safeguard. */
        boolean _bound;
        /** Index into {@link #messages} describing the outcome of the last {@link #evaluate}. */
        int _returnStatus;
        public final String[] messages;
        private int _iter;
        /** Maximum number of gradient evaluations per {@link #evaluate} call. */
        int _maxfev;
        /**
         * Largest allowed step. NOTE(review): {@link #evaluate} lowers this field permanently when a
         * trial point yields non-finite values, so it carries over to later calls.
         */
        double _maxStep;
        /** Smallest allowed step. */
        double _minStep;

        /** Evaluates the gradient at {@code dArr}; uses ftol = gtol = 0.1 and xtol = 0.01. */
        public MoreThuente(GradientSolver gradientSolver, double[] dArr) {
            this(gradientSolver, dArr, gradientSolver.getGradient(dArr), 0.1d, 0.1d, 0.01d);
        }

        /** Uses ftol = gtol = 0.1 and xtol = 1e-8. */
        public MoreThuente(GradientSolver gradientSolver, double[] dArr, GradientInfo gradientInfo) {
            this(gradientSolver, dArr, gradientInfo, 0.1d, 0.1d, 1.0E-8d);
        }

        /**
         * @param gradientSolver objective/gradient evaluator
         * @param dArr start point
         * @param gradientInfo objective and gradient at {@code dArr}; the gradient is required
         * @param d sufficient-decrease parameter ftol
         * @param d2 curvature parameter gtol
         * @param d3 relative bracket-width tolerance xtol
         * @throws IllegalArgumentException if {@code gradientInfo} has no gradient
         */
        public MoreThuente(GradientSolver gradientSolver, double[] dArr, GradientInfo gradientInfo, double d, double d2, double d3) {
            this._initialStep = 1.0d;
            this._minRelativeImprovement = 1.0E-8d;
            this._xtol = 1.0E-8d;
            this._ftol = 0.1d;
            this._gtol = 0.1d;
            this._xtrapf = 4.0d;
            this.messages = new String[]{"In progress or not evaluated", "The sufficient decrease condition and the directional derivative condition hold.", "Relative width of the interval of uncertainty is at most xtol.", "Number of calls to gradient solver has reached the limit.", "The step is at the lower bound stpmin.", "The step is at the upper bound stpmax.", "Rounding errors prevent further progress, ftol/gtol tolerances may be too small.", "Non-negative differential."};
            this._maxfev = 20;
            this._maxStep = 1.0E10d;
            this._minStep = 1.0E-10d;
            this._gslvr = gradientSolver;
            this._beta = dArr;
            this._ginfox = gradientInfo;
            if (gradientInfo._gradient == null) {
                throw new IllegalArgumentException("GradientInfo for MoreThuente line search solver must include gradient");
            }
            this._ftol = d;
            this._gtol = d2;
            this._xtol = d3;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public MoreThuente setInitialStep(double d) {
            this._initialStep = d;
            return this;
        }

        /** @return number of gradient evaluations used by the last {@link #evaluate} */
        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public int nfeval() {
            return this._iter;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public double getObj() {
            return ginfo()._objVal;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public double[] getX() {
            return this._beta;
        }

        /**
         * Computes the next trial step and updates the interval endpoints (cf. MINPACK {@code cstep}).
         *
         * @param gradientInfo objective/gradient at the trial step
         * @param d directional derivative at the trial step
         * @param d2 trial step
         * @param d3 offset subtracted from all derivatives (and {@code step * d3} from the values);
         *     {@code evaluate} passes {@code ftol * g(0)'d} to work on the modified function, else 0
         * @return the next trial step, or NaN if the trial step is outside the bracket or the
         *     derivative at {@code _stx} does not point toward the trial step
         */
        private double nextStep(GradientInfo gradientInfo, double d, double d2, double d3) {
            double d4;
            double d5 = gradientInfo._objVal - (d2 * d3);
            double d6 = d - d3;
            double d7 = this._fvx - (this._stx * d3);
            double d8 = this._fvy - (this._sty * d3);
            double d9 = this._stx;
            double d10 = this._sty;
            double d11 = this._dgx - d3;
            double d12 = this._dgy - d3;
            if ((this._brackt && (d2 <= Math.min(d9, d10) || d2 >= Math.max(d9, d10))) || d11 * (d2 - d9) >= 0.0d) {
                return Double.NaN;
            }
            double d13 = ((3.0d * (d7 - d5)) / (d2 - d9)) + d11 + d6;
            double max = Math.max(Math.max(Math.abs(d13), Math.abs(d11)), Math.abs(d6));
            double d14 = 1.0d / max;
            double d15 = d13 * d14;
            double sqrt = max * Math.sqrt(Math.max(0.0d, (d15 * d15) - ((d11 * d14) * (d6 * d14))));
            if (d5 > d7) {
                // Case 1: higher value at the trial step -> minimizer is bracketed.
                if (d2 < d9) {
                    sqrt = -sqrt;
                }
                this._bound = true;
                this._brackt = true;
                double d16 = d9 + ((((sqrt - d11) + d13) / (((sqrt - d11) + sqrt) + d6)) * (d2 - d9));
                double d17 = d9 + (((d11 / (((d7 - d5) / (d2 - d9)) + d11)) / 2.0d) * (d2 - d9));
                d4 = Math.abs(d16 - d9) < Math.abs(d17 - d9) ? d16 : d16 + ((d17 - d16) / 2.0d);
            } else if (d6 * d11 < 0.0d) {
                // Case 2: lower value and derivatives of opposite sign -> minimizer is bracketed.
                if (d2 > d9) {
                    sqrt = -sqrt;
                }
                this._bound = false;
                this._brackt = true;
                double d18 = d2 + ((((sqrt - d6) + d13) / (((sqrt - d6) + sqrt) + d11)) * (d9 - d2));
                double d19 = d2 + ((d6 / (d6 - d11)) * (d9 - d2));
                d4 = Math.abs(d18 - d2) > Math.abs(d19 - d2) ? d18 : d19;
            } else if (Math.abs(d6) < Math.abs(d11)) {
                // Case 3: same sign, derivative magnitude decreases.
                if (d2 > d9) {
                    sqrt = -sqrt;
                }
                this._bound = true;
                double d20 = ((sqrt - d6) + d13) / (((sqrt + d11) - d6) + sqrt);
                double d21 = (d20 >= 0.0d || sqrt == 0.0d) ? d2 > d9 ? this._stMax : this._stMin : d2 + (d20 * (d9 - d2));
                double d22 = d2 + ((d6 / (d6 - d11)) * (d9 - d2));
                if (this._brackt) {
                    d4 = Math.abs(d2 - d21) < Math.abs(d2 - d22) ? d21 : d22;
                } else {
                    d4 = Math.abs(d2 - d21) > Math.abs(d2 - d22) ? d21 : d22;
                }
            } else {
                // Case 4: same sign, derivative magnitude does not decrease.
                this._bound = false;
                if (this._brackt) {
                    double d23 = ((3.0d * (d5 - d8)) / (d10 - d2)) + d12 + d6;
                    double sqrt2 = Math.sqrt((d23 * d23) - (d12 * d6));
                    if (d2 > d10) {
                        sqrt2 = -sqrt2;
                    }
                    d4 = d2 + ((((sqrt2 - d6) + d23) / (((sqrt2 - d6) + sqrt2) + d12)) * (d10 - d2));
                } else {
                    d4 = d2 > d9 ? this._stMax : this._stMin;
                }
            }
            // Update the interval endpoints.
            if (d5 > d7) {
                this._sty = d2;
                this._fvy = gradientInfo._objVal;
                this._dgy = d;
            } else {
                if (d6 * d11 < 0.0d) {
                    this._sty = this._stx;
                    this._fvy = this._fvx;
                    this._dgy = this._dgx;
                }
                this._stx = d2;
                this._fvx = gradientInfo._objVal;
                this._dgx = d;
                this._ginfox = gradientInfo;
            }
            // Clamp into the step interval and apply the bracket safeguard.
            if (d4 > this._stMax) {
                d4 = this._stMax;
            }
            if (d4 < this._stMin) {
                d4 = this._stMin;
            }
            if (this._brackt & this._bound) {
                d4 = this._sty > this._stx ? Math.min(this._stx + (0.66d * (this._sty - this._stx)), d4) : Math.max(this._stx + (0.66d * (this._sty - this._stx)), d4);
            }
            return d4;
        }

        public String toString() {
            return "MoreThuente line search, iter = " + this._iter + ", status = " + this.messages[this._returnStatus] + ", step = " + this._stx + ", I = [" + this._stMin + ", " + this._stMax + "], grad = " + this._dgx + ", bestObj = " + this._fvx;
        }

        /** Adopts the remembered fallback point if there is one, otherwise the given trial point. */
        private void selectBestOrCurrent(double step, GradientInfo current) {
            if (this._betGradient != null) {
                this._stx = this._bestStep;
                this._ginfox = this._betGradient;
            } else {
                this._stx = step;
                this._ginfox = current;
            }
        }

        /**
         * Runs the line search from the current point {@link #getX()} along {@code dArr}.
         *
         * <p>Every exit sets {@code _returnStatus}. If the accepted point has a lower objective than
         * the start point, the current point becomes {@code x + step() * dArr}.
         *
         * @param dArr search direction; must satisfy {@code ftol * g'd < 0} (status 7 otherwise)
         * @return {@code true} if the objective decreased
         */
        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public boolean evaluate(double[] dArr) {
            final double[] direction = dArr;
            final double oldObjVal = this._ginfox._objVal;
            double step = this._initialStep;
            this._returnStatus = 0;
            this._bound = false;
            this._brackt = false;
            this._sty = 0.0d;
            this._stx = 0.0d;
            this._stMax = 0.0d;
            this._stMin = 0.0d;
            this._betGradient = null;
            this._bestPsiVal = Double.POSITIVE_INFINITY;
            this._bestStep = 0.0d;
            // A point must beat this objective to be remembered as the fallback point.
            final double maxObj = oldObjVal - (this._minRelativeImprovement * oldObjVal);
            final double dgInit = ArrayUtils.innerProduct(this._ginfox._gradient, direction);
            final double dgTest = dgInit * this._ftol;
            if (dgTest > 1.0E-4d) {
                Log.warn(new Object[]{"MoreThuente LS: got possitive differential " + dgTest});
            }
            if (dgTest >= 0.0d) {
                this._returnStatus = 7;
                return false;
            }
            double[] trialBeta = new double[this._beta.length];
            double width = this._maxStep - this._minStep;
            double oldWidth = 2.0d * width;
            boolean firstStage = true;
            this._fvy = oldObjVal;
            this._fvx = oldObjVal;
            this._dgy = dgInit;
            this._dgx = dgInit;
            this._iter = 0;
            while (true) {
                if (this._brackt) {
                    this._stMin = Math.min(this._stx, this._sty);
                    this._stMax = Math.max(this._stx, this._sty);
                } else {
                    this._stMin = this._stx;
                    this._stMax = step + (this._xtrapf * (step - this._stx));
                }
                final double stp = Math.max(Math.min(step, this._maxStep), this._minStep);
                if (Double.isNaN(stp)) {
                    // nextStep() found no usable step; stop before evaluating at NaN coefficients.
                    this._returnStatus = 6;
                    break;
                }
                final double ftest = oldObjVal + (stp * dgTest);
                for (int i = 0; i < trialBeta.length; i++) {
                    trialBeta[i] = this._beta[i] + (stp * direction[i]);
                }
                final GradientInfo newGinfo = this._gslvr.getGradient(trialBeta);
                if (newGinfo._objVal < maxObj && (this._betGradient == null || newGinfo._objVal - ftest < this._bestPsiVal)) {
                    this._bestPsiVal = newGinfo._objVal - ftest;
                    this._betGradient = newGinfo;
                    this._bestStep = stp;
                }
                this._iter++;
                if (this._iter < this._maxfev && (Double.isNaN(newGinfo._objVal) || Double.isInfinite(newGinfo._objVal) || ArrayUtils.hasNaNsOrInfs(newGinfo._gradient))) {
                    // Non-finite result: use stp as upper bracket end, cap the step and halve it.
                    this._brackt = true;
                    this._sty = stp;
                    this._maxStep = stp;
                    this._fvy = Double.POSITIVE_INFINITY;
                    this._dgy = Double.MAX_VALUE;
                    step = stp * 0.5d;
                    continue;
                }
                final double dg = ArrayUtils.innerProduct(newGinfo._gradient, direction);
                if (this._brackt && (stp <= this._stMin || stp >= this._stMax)) {
                    this._returnStatus = 6;
                    break;
                }
                if (stp == this._maxStep && newGinfo._objVal <= ftest && dg <= dgTest) {
                    this._returnStatus = 5;
                    this._stx = stp;
                    this._ginfox = newGinfo;
                    break;
                }
                boolean atMinStep = false;
                if (stp == this._minStep && (newGinfo._objVal > ftest || dg >= dgTest)) {
                    // Stuck at the lower bound; statuses 3, 2 and 1 below still take precedence.
                    this._returnStatus = 4;
                    selectBestOrCurrent(stp, newGinfo);
                    atMinStep = true;
                }
                if (this._iter >= this._maxfev) {
                    this._returnStatus = 3;
                    selectBestOrCurrent(stp, newGinfo);
                    break;
                }
                if (this._brackt && this._stMax - this._stMin <= this._xtol * this._stMax) {
                    // Keep the step consistent with the gradient info that is returned.
                    this._stx = stp;
                    this._dgx = dg;
                    this._fvx = newGinfo._objVal;
                    this._ginfox = newGinfo;
                    this._returnStatus = 2;
                    break;
                }
                if (newGinfo._objVal < ftest && Math.abs(dg) <= (-this._gtol) * dgInit) {
                    this._stx = stp;
                    this._dgx = dg;
                    this._fvx = newGinfo._objVal;
                    this._ginfox = newGinfo;
                    this._returnStatus = 1;
                    break;
                }
                if (atMinStep) {
                    break;
                }
                firstStage = firstStage && (newGinfo._objVal > ftest || dg < dgTest);
                // In the first stage, work on the modified function f(s) - s * dgTest.
                final double offset = (firstStage && newGinfo._objVal <= this._fvx && newGinfo._objVal > ftest) ? dgTest : 0.0d;
                double next = nextStep(newGinfo, dg, stp, offset);
                if (this._brackt) {
                    // Force sufficient shrinking of the bracket by bisection if needed.
                    if (Math.abs(this._sty - this._stx) >= 0.66d * oldWidth) {
                        next = this._stx + (0.5d * (this._sty - this._stx));
                    }
                    oldWidth = width;
                    width = Math.abs(this._sty - this._stx);
                }
                step = next;
            }
            boolean improved = this._ginfox._objVal < oldObjVal;
            if (improved) {
                for (int i = 0; i < trialBeta.length; i++) {
                    trialBeta[i] = this._beta[i] + (this._stx * direction[i]);
                }
                this._beta = trialBeta;
            }
            return improved;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public double step() {
            return this._stx;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public GradientInfo ginfo() {
            return this._ginfox;
        }
    }

    /**
     * Derivative-free backtracking line search on an L1-penalized objective.
     *
     * <p>Tries steps {@code 1, _stepDec, _stepDec^2, ...} along the direction and accepts the first
     * whose penalized objective {@code f(beta) + _l1pen * l1norm(beta)} is lower than the current
     * one. Only objective values are requested from the solver.
     */
    public static final class SimpleBacktrackingLS implements LineSearchSolver {
        private double[] _beta;
        /** Factor by which the step shrinks after each rejected trial. */
        final double _stepDec = 0.33d;
        private double _step;
        private final GradientSolver _gslvr;
        private GradientInfo _ginfo;
        /** Penalized objective at the current point. */
        private double _objVal;
        /** L1 penalty weight. */
        final double _l1pen;
        /** Maximum number of trial steps per {@link #evaluate} call. */
        int _maxfev;
        /** Smallest coordinate move still worth trying; limits how far the step is shrunk. */
        double _minStep;

        /** Evaluates the objective at {@code dArr} via {@code getObjective}. */
        public SimpleBacktrackingLS(GradientSolver gradientSolver, double[] dArr, double d) {
            this(gradientSolver, dArr, d, gradientSolver.getObjective(dArr));
        }

        /**
         * @param gradientSolver objective evaluator
         * @param dArr start point
         * @param d L1 penalty weight
         * @param gradientInfo objective at {@code dArr}
         */
        public SimpleBacktrackingLS(GradientSolver gradientSolver, double[] dArr, double d, GradientInfo gradientInfo) {
            // _stepDec is a final field with an initializer; it must not be assigned again here.
            this._maxfev = 20;
            this._minStep = 1.0E-4d;
            this._gslvr = gradientSolver;
            this._beta = dArr;
            this._ginfo = gradientInfo;
            this._l1pen = d;
            this._objVal = this._ginfo._objVal + (this._l1pen * ArrayUtils.l1norm(this._beta, true));
        }

        /** Evaluations are not tracked by this implementation; always returns -1. */
        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public int nfeval() {
            return -1;
        }

        /** @return penalized objective at the current point */
        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public double getObj() {
            return this._objVal;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public double[] getX() {
            return this._beta;
        }

        /** The initial step is always 1; the argument is ignored. */
        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public LineSearchSolver setInitialStep(double d) {
            return this;
        }

        /**
         * Backtracks along {@code dArr} until the penalized objective decreases.
         *
         * @param dArr search direction
         * @return {@code true} if a step was accepted; the current point, objective and step are then
         *     updated
         */
        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public boolean evaluate(double[] dArr) {
            // Stop shrinking once the largest |dArr[i]| would move its coordinate by less than
            // _minStep (the step is capped at 1).
            double minStep = 1.0d;
            for (double dirValue : dArr) {
                double candidate = Math.abs(this._minStep / dirValue);
                if (candidate < minStep) {
                    minStep = candidate;
                }
            }
            double[] newBeta = dArr.clone();
            double step = 1.0d;
            for (int evals = 0; evals < this._maxfev && step >= minStep; evals++, step *= this._stepDec) {
                GradientInfo objective = this._gslvr.getObjective(ArrayUtils.wadd(this._beta, dArr, newBeta, step));
                double penalizedObj = objective._objVal + (this._l1pen * ArrayUtils.l1norm(newBeta, true));
                if (penalizedObj < this._objVal) {
                    this._ginfo = objective;
                    this._objVal = penalizedObj;
                    this._beta = newBeta;
                    this._step = step;
                    return true;
                }
            }
            return false;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public double step() {
            return this._step;
        }

        @Override // hex.optimization.OptimizationUtils.LineSearchSolver
        public GradientInfo ginfo() {
            return this._ginfo;
        }

        public String toString() {
            return "";
        }
    }
}
