package hex.optimization;

import hex.optimization.OptimizationUtils;
import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.MathUtils;

/* loaded from: input_file:hex/optimization/L_BFGS.class */
public final class L_BFGS extends Iced {
    // Iteration cap: solve() leaves its loop once the iteration counter equals this value.
    int _maxIter = 500;
    // Gradient tolerance: solve() keeps iterating only while the L-infinity norm of the gradient exceeds this.
    double _gradEps = 1.0E-8d;
    // Objective tolerance: solve() keeps iterating only while the relative objective improvement exceeds this.
    double _objEps = 1.0E-10d;
    // Number of slots passed to the History constructor when solve() creates the history.
    int _historySz = 20;
    // Correction history; created by solve() only when still null, so it is reused across solve() calls.
    History _hist;

    /* loaded from: input_file:hex/optimization/L_BFGS$History.class */
    /**
     * Fixed-size circular store of correction pairs together with the
     * search-direction computation that consumes them.
     *
     * <p>Slot {@code j} holds a step vector {@code _s[j]}, a gradient-difference
     * vector {@code _y[j]} and the scalar {@code _rho[j] = 1 / (_s[j] . _y[j])}.
     * {@code _k} counts how many pairs have been stored in total; slots are
     * reused modulo {@code _m}.
     */
    public static final class History extends Iced {
        /** Step vectors, one row per slot. */
        private final double[][] _s;
        /** Gradient differences (new gradient minus old gradient), one row per slot. */
        private final double[][] _y;
        /** Per-slot reciprocal of the inner product of the slot's step and gradient difference. */
        private final double[] _rho;
        /** Scratch coefficients written by the first pass of getSearchDirection and read by the second. */
        private final double[] _alpha;
        /** Number of slots. */
        final int _m;
        /** Length of every stored vector. */
        final int _n;
        /** Total number of update() calls so far; also determines the next slot to overwrite. */
        int _k;

        /**
         * Allocates all slots and fills them with NaN so that unwritten entries are recognizable.
         *
         * @param i  number of slots to keep
         * @param i2 length of each stored vector
         */
        public History(int i, int i2) {
            this._m = i;
            this._alpha = new double[this._m];
            this._n = i2;
            // Outer arrays must be arrays of rows (double[][]); the rows themselves are allocated below.
            this._s = new double[i][];
            this._y = new double[i][];
            this._rho = MemoryManager.malloc8d(i);
            Arrays.fill(this._rho, Double.NaN);
            for (int slot = 0; slot < i; slot++) {
                this._s[slot] = MemoryManager.malloc8d(i2);
                Arrays.fill(this._s[slot], Double.NaN);
                this._y[slot] = MemoryManager.malloc8d(i2);
                Arrays.fill(this._y[slot], Double.NaN);
            }
        }

        /**
         * Maps an offset relative to the next write position onto a slot index.
         * Within this class it is called with 0 (the slot about to be written) or with
         * {@code -j} for {@code 1 <= j <= min(_k, _m)}, so {@code _k + i} is never negative here.
         *
         * @param i offset added to {@code _k} before taking the remainder by {@code _m}
         * @return slot index
         */
        int getId(int i) {
            return (this._k + i) % this._m;
        }

        /**
         * Stores a new correction pair in the next slot, overwriting the oldest one once
         * all slots are in use, and advances {@code _k}.
         *
         * <p>Declared public so the enclosing class can call it (the decompiler reports
         * the original modifier as private).
         *
         * @param dArr  step vector, copied into the slot's {@code _s} row
         *              (NOTE(review): presumably already scaled by the line-search step - confirm with caller)
         * @param dArr2 gradient after the step
         * @param dArr3 gradient before the step
         */
        public final void update(double[] dArr, double[] dArr2, double[] dArr3) {
            int id = getId(0);
            double[] gradDiff = this._y[id];
            double[] step = this._s[id];
            for (int j = 0; j < dArr2.length; j++) {
                gradDiff[j] = dArr2[j] - dArr3[j];
            }
            System.arraycopy(dArr, 0, step, 0, dArr.length);
            // NOTE(review): a zero inner product yields an infinite rho; no guard exists here.
            this._rho[id] = 1.0d / ArrayUtils.innerProduct(step, gradDiff);
            this._k++;
        }

        /**
         * Computes a search direction from the given gradient and the stored pairs.
         *
         * <p>With an empty history the result is the negated gradient. Otherwise the
         * gradient copy is processed by a pass from the newest stored pair to the oldest,
         * rescaled (and negated) using the newest pair, and then corrected by a pass from
         * the oldest pair back to the newest.
         *
         * @param dArr  gradient; only read
         * @param dArr2 output buffer; its first {@code dArr2.length} entries are overwritten
         * @return {@code dArr2}
         */
        protected final double[] getSearchDirection(double[] dArr, double[] dArr2) {
            System.arraycopy(dArr, 0, dArr2, 0, dArr2.length);
            if (this._k != 0) {
                int used = Math.min(this._k, this._m);
                // First pass: newest pair first.
                for (int back = 1; back <= used; back++) {
                    int id = getId(-back);
                    this._alpha[id] = this._rho[id] * ArrayUtils.innerProduct(this._s[id], dArr2);
                    MathUtils.wadd(dArr2, this._y[id], -this._alpha[id]);
                }
                // Scale by -1 / ((y . y) * rho) of the newest pair; the minus sign negates the vector.
                int newest = getId(-1);
                double[] yNewest = this._y[newest];
                ArrayUtils.mult(dArr2, (-1.0d) / (ArrayUtils.innerProduct(yNewest, yNewest) * this._rho[newest]));
                // Second pass: oldest pair first. The vector is already negated, hence -alpha - beta.
                for (int back = used; back > 0; back--) {
                    int id = getId(-back);
                    double beta = this._rho[id] * ArrayUtils.innerProduct(this._y[id], dArr2);
                    MathUtils.wadd(dArr2, this._s[id], (-this._alpha[id]) - beta);
                }
            } else {
                ArrayUtils.mult(dArr2, -1.0d);
            }
            return dArr2;
        }
    }

    /* loaded from: input_file:hex/optimization/L_BFGS$ProgressMonitor.class */
    /**
     * Callback that the four-argument solve() invokes at the end of each iteration
     * in which the line search succeeded.
     */
    public interface ProgressMonitor {
        /**
         * @param dArr         the point reported by the line search's getX() for this iteration
         * @param gradientInfo the gradient info produced by the line search for this iteration
         * @return {@code false} to make solve() stop iterating; {@code true} to let it continue
         */
        boolean progress(double[] dArr, OptimizationUtils.GradientInfo gradientInfo);
    }

    /* loaded from: input_file:hex/optimization/L_BFGS$Result.class */
    /** Immutable summary of a solve() run. */
    public static final class Result {
        /** Value of the iteration counter when solve() returned. */
        public final int iter;
        /** Point returned by the solver. */
        public final double[] coefs;
        /** Objective value and gradient associated with the result. */
        public final OptimizationUtils.GradientInfo ginfo;
        /** Whether one of the tolerance criteria was met. */
        public final boolean converged;
        /** Relative objective improvement of the last completed iteration. */
        public final double rel_improvement;

        /**
         * @param z            converged flag
         * @param i            iteration count
         * @param dArr         coefficients
         * @param gradientInfo objective value and gradient
         * @param d            relative improvement
         */
        public Result(boolean z, int i, double[] dArr, OptimizationUtils.GradientInfo gradientInfo, double d) {
            this.iter = i;
            this.coefs = dArr;
            this.ginfo = gradientInfo;
            this.converged = z;
            this.rel_improvement = d;
        }

        /**
         * Renders the result. For fewer than 10 coefficients the coefficient and gradient
         * vectors are printed in full; otherwise only the gradient's L-infinity norm is shown.
         */
        @Override
        public String toString() {
            String head = "L-BFGS_res(converged? " + this.converged + ", iter = " + this.iter + ", obj = " + this.ginfo._objVal + ", rel_improvement = " + this.rel_improvement;
            if (this.coefs.length < 10) {
                return head + ", coefs = " + Arrays.toString(this.coefs) + ", grad = " + Arrays.toString(this.ginfo._gradient) + ")";
            }
            // The separator before grad_linf_norm was missing, which glued the label onto the previous number.
            return head + ", grad_linf_norm = " + ArrayUtils.linfnorm(this.ginfo._gradient, false) + ")";
        }
    }

    /**
     * Sets the iteration cap used by solve().
     *
     * @param i new value for the cap
     * @return this solver, to allow chained configuration
     */
    public L_BFGS setMaxIter(int i) {
        _maxIter = i;
        return this;
    }

    /**
     * Sets the gradient tolerance that solve() compares against the gradient's L-infinity norm.
     *
     * @param d new tolerance
     * @return this solver, to allow chained configuration
     */
    public L_BFGS setGradEps(double d) {
        _gradEps = d;
        return this;
    }

    /**
     * Sets the tolerance that solve() compares against the relative objective improvement.
     *
     * @param d new tolerance
     * @return this solver, to allow chained configuration
     */
    public L_BFGS setObjEps(double d) {
        _objEps = d;
        return this;
    }

    /**
     * Sets the number of history slots used when solve() creates its History.
     * The value is not validated here, and an already created History is left untouched.
     *
     * @param i number of slots
     * @return this solver, to allow chained configuration
     */
    public L_BFGS setHistorySz(int i) {
        _historySz = i;
        return this;
    }

    /**
     * Returns the number of correction pairs recorded in the history so far.
     *
     * <p>The history is only created inside solve(), so before the first solve there
     * is none; in that case 0 is returned instead of failing with a
     * {@link NullPointerException}.
     *
     * @return the history's update counter, or 0 when no history exists yet
     */
    public int k() {
        return this._hist == null ? 0 : this._hist._k;
    }

    /**
     * @return the currently configured iteration cap
     */
    public int maxIter() {
        return _maxIter;
    }

    /**
     * Runs the iterative solver from the given start point.
     *
     * <p>Each iteration asks the history for a search direction, runs a
     * {@code MoreThuente} line search along it, records the resulting correction
     * pair, and reports progress. Iteration stops when {@code dArr} contains NaN or
     * infinite values, when the gradient's L-infinity norm or the relative objective
     * improvement drops to its tolerance, when the iteration cap is reached, when
     * the line search reports failure, or when the monitor returns {@code false}.
     *
     * @param gradientSolver  gradient provider handed to the line search
     * @param dArr            start point; its length sizes the history on first use
     * @param gradientInfo    objective value and gradient at the start point
     * @param progressMonitor callback consulted after every successful iteration
     * @return result built from the line search's final point and gradient info
     */
    public final Result solve(OptimizationUtils.GradientSolver gradientSolver, double[] dArr, OptimizationUtils.GradientInfo gradientInfo, ProgressMonitor progressMonitor) {
        if (_hist == null) {
            _hist = new History(_historySz, dArr.length);
        }
        // Lower bound applied to the step that seeds the next line search.
        final double minStep = 1.0E-16d;
        int iter = 0;
        double relImprovement = 1.0d;
        final double[] direction = new double[dArr.length];
        final OptimizationUtils.MoreThuente lineSearch = new OptimizationUtils.MoreThuente(gradientSolver, dArr, gradientInfo);
        while (!ArrayUtils.hasNaNsOrInfs(dArr)
                && ArrayUtils.linfnorm(gradientInfo._gradient, false) > _gradEps
                && relImprovement > _objEps
                && iter != _maxIter) {
            iter++;
            _hist.getSearchDirection(gradientInfo._gradient, direction);
            if (!lineSearch.evaluate(direction)) {
                break;
            }
            lineSearch.setInitialStep(Math.max(minStep, lineSearch.step()));
            final OptimizationUtils.GradientInfo next = lineSearch.ginfo();
            _hist.update(direction, next._gradient, gradientInfo._gradient);
            relImprovement = (gradientInfo._objVal - next._objVal) / Math.abs(gradientInfo._objVal);
            gradientInfo = next;
            if (!progressMonitor.progress(lineSearch.getX(), gradientInfo)) {
                break;
            }
        }
        final boolean gradientSmall = ArrayUtils.linfnorm(gradientInfo._gradient, false) <= _gradEps;
        final boolean converged = gradientSmall || relImprovement <= _objEps;
        return new Result(converged, iter, lineSearch.getX(), lineSearch.ginfo(), relImprovement);
    }

    /**
     * Convenience overload: evaluates the gradient at the start point and solves
     * with a monitor that never requests a stop.
     *
     * @param gradientSolver gradient provider
     * @param dArr           start point
     * @return result of the four-argument solve
     */
    public final Result solve(OptimizationUtils.GradientSolver gradientSolver, double[] dArr) {
        final OptimizationUtils.GradientInfo initial = gradientSolver.getGradient(dArr);
        final ProgressMonitor alwaysContinue = new ProgressMonitor() {
            @Override
            public boolean progress(double[] dArr2, OptimizationUtils.GradientInfo gradientInfo) {
                return true;
            }
        };
        return solve(gradientSolver, dArr, initial, alwaysContinue);
    }

    /**
     * Builds a start vector whose entries are drawn from {@link Random#nextGaussian()}.
     *
     * @param i number of coefficients requested from the allocator
     * @param j seed for the random generator, so equal seeds give equal vectors
     * @return the filled array
     */
    public static double[] startCoefs(int i, long j) {
        final double[] coefs = MemoryManager.malloc8d(i);
        final Random rng = new Random(j);
        int idx = 0;
        while (idx < coefs.length) {
            coefs[idx] = rng.nextGaussian();
            idx++;
        }
        return coefs;
    }
}
