package hex.glm;

import hex.DataInfo;
import hex.glm.DispersionTask;
import hex.glm.GLMModel;
import hex.glm.GLMTask;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.special.Gamma;
import water.Job;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;

/* loaded from: input_file:hex/glm/DispersionUtils.class */
public class DispersionUtils {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.glm.DispersionUtils$1MomentMethodThetaEstimation, reason: invalid class name */
    /* loaded from: input_file:hex/glm/DispersionUtils$1MomentMethodThetaEstimation.class */
    public class C1MomentMethodThetaEstimation extends MRTask<C1MomentMethodThetaEstimation> {
        double _muSqSum;
        double _sSqSum;
        double _muSum;
        double _wSum;

        C1MomentMethodThetaEstimation() {
        }

        @Override // water.MRTask
        public void map(Chunk[] chunkArr) {
            for (int i = 0; i < chunkArr[0]._len; i++) {
                double atd = chunkArr[2].atd(i);
                this._muSqSum += atd * Math.pow(chunkArr[0].atd(i), 2.0d);
                this._sSqSum += atd * Math.pow(chunkArr[1].atd(i) - chunkArr[0].atd(i), 2.0d);
                this._muSum += atd * chunkArr[0].atd(i);
                this._wSum += atd;
            }
        }

        @Override // water.MRTask
        public void reduce(C1MomentMethodThetaEstimation c1MomentMethodThetaEstimation) {
            this._muSqSum += c1MomentMethodThetaEstimation._muSqSum;
            this._sSqSum += c1MomentMethodThetaEstimation._sSqSum;
            this._muSum += c1MomentMethodThetaEstimation._muSum;
            this._wSum += c1MomentMethodThetaEstimation._wSum;
        }
    }

    /* loaded from: input_file:hex/glm/DispersionUtils$CalculateInitialTheta.class */
    /**
     * Map/reduce task that sums {@code weight * (column1 / column0 - 1)^2} over
     * all rows of three aligned columns (column 2 is the weight).
     * The total is exposed in {@link #_theta0}.
     */
    static class CalculateInitialTheta extends MRTask<CalculateInitialTheta> {
        /** Running total of the weighted squared relative deviations. */
        double _theta0;

        CalculateInitialTheta() {
        }

        /** Accumulates the rows of one chunk group into {@link #_theta0}. */
        @Override // water.MRTask
        public void map(Chunk[] chunkArr) {
            for (int row = 0; row < chunkArr[0]._len; row++) {
                final double weight = chunkArr[2].atd(row);
                final double ratioMinusOne = (chunkArr[1].atd(row) / chunkArr[0].atd(row)) - 1.0d;
                this._theta0 += weight * Math.pow(ratioMinusOne, 2.0d);
            }
        }

        /** Adds the partial total of {@code calculateInitialTheta} to this task. */
        @Override // water.MRTask
        public void reduce(CalculateInitialTheta calculateInitialTheta) {
            this._theta0 += calculateInitialTheta._theta0;
        }
    }

    /* loaded from: input_file:hex/glm/DispersionUtils$CalculateNegativeBinomialScoreAndInfo.class */
    /**
     * Map/reduce task that, for a fixed {@code theta}, accumulates two weighted
     * per-row sums ({@link #_score} and {@link #_info}) built from digamma,
     * trigamma and log terms of three aligned columns. Column 2 is the per-row
     * weight; columns 0 and 1 enter the formulas as written in {@link #map}.
     */
    static class CalculateNegativeBinomialScoreAndInfo extends MRTask<CalculateNegativeBinomialScoreAndInfo> {
        /** Weighted sum of the per-row score terms. */
        double _score;
        /** Weighted sum of the per-row information terms. */
        double _info;
        /** The theta value the sums are evaluated at. */
        double _theta;

        CalculateNegativeBinomialScoreAndInfo(double d) {
            this._theta = d;
        }

        /** Adds every row of the chunk group to {@link #_score} and {@link #_info}. */
        @Override // water.MRTask
        public void map(Chunk[] chunkArr) {
            for (int row = 0; row < chunkArr[0]._len; row++) {
                final double weight = chunkArr[2].atd(row);
                final double col0PlusTheta = chunkArr[0].atd(row) + this._theta;
                final double col1PlusTheta = chunkArr[1].atd(row) + this._theta;

                // Terms are added in the same left-to-right order as a single
                // parenthesised expression would evaluate them.
                double scoreTerm = Gamma.digamma(col1PlusTheta) - Gamma.digamma(this._theta);
                scoreTerm += Math.log(this._theta);
                scoreTerm += 1.0d;
                scoreTerm -= Math.log(col0PlusTheta);
                scoreTerm -= col1PlusTheta / col0PlusTheta;
                this._score += weight * scoreTerm;

                double infoTerm = (-Gamma.trigamma(col1PlusTheta)) + Gamma.trigamma(this._theta);
                infoTerm -= 1.0d / this._theta;
                infoTerm += 2.0d / col0PlusTheta;
                infoTerm -= col1PlusTheta / Math.pow(col0PlusTheta, 2.0d);
                this._info += weight * infoTerm;
            }
        }

        /** Merges the partial sums of another task instance into this one. */
        @Override // water.MRTask
        public void reduce(CalculateNegativeBinomialScoreAndInfo calculateNegativeBinomialScoreAndInfo) {
            this._score += calculateNegativeBinomialScoreAndInfo._score;
            this._info += calculateNegativeBinomialScoreAndInfo._info;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/glm/DispersionUtils$NegativeBinomialGradientAndHessian.class */
    /**
     * Map/reduce task that, for a fixed {@code theta}, accumulates a gradient
     * sum, a Hessian sum and a log-likelihood sum over the rows of three aligned
     * columns.
     *
     * <p>Columns are read by position: column 0 ({@code mu} below), column 1
     * ({@code y} below) and column 2 (the per-row weight {@code w}). The weight
     * multiplies the gradient and Hessian contributions.
     * NOTE(review): the log-likelihood contribution is accumulated without the
     * weight - confirm this is intended.
     */
    public static class NegativeBinomialGradientAndHessian extends MRTask<NegativeBinomialGradientAndHessian> {
        /** Weighted sum of per-row gradient terms. */
        double _grad;
        /** Weighted sum of per-row Hessian terms. */
        double _hess;
        /** The theta the sums are evaluated at. */
        double _theta;
        /** Cached {@code 1 / theta}. */
        double _invTheta;
        /** Cached {@code (1 / theta)^2}. */
        double _invThetaSq;
        /** Sum of per-row log-likelihood terms (not weighted). */
        double _llh;

        /**
         * Creates a task that evaluates the sums at the given theta.
         *
         * @param d the theta value; expected to be positive (checked only when
         *          assertions are enabled)
         */
        public NegativeBinomialGradientAndHessian(double d) {
            // Written as a negated "<=" so it rejects exactly the same values as
            // the hand-rolled assertion check it replaces (NaN is let through).
            assert !(d <= 0.0d) : "theta must be positive, got " + d;
            this._theta = d;
            this._invTheta = 1.0d / d;
            this._invThetaSq = this._invTheta * this._invTheta;
        }

        /** Adds every row of the chunk group to the three running sums. */
        @Override // water.MRTask
        public void map(Chunk[] chunkArr) {
            for (int i = 0; i < chunkArr[0]._len; i++) {
                final double mu = chunkArr[0].atd(i);
                final double y = chunkArr[1].atd(i);
                final double w = chunkArr[2].atd(i);
                this._grad += w * ((((-mu) * (y + this._invTheta)) / ((mu * this._theta) + 1.0d)) + ((y + (((Math.log((mu * this._theta) + 1.0d) - Gamma.digamma(y + this._invTheta)) + Gamma.digamma(this._invTheta)) * this._invTheta)) * this._invTheta));
                this._hess += w * ((((mu * mu) * (y + this._invTheta)) / Math.pow((mu * this._theta) + 1.0d, 2.0d)) + (((-y) + ((2.0d * mu) / ((mu * this._theta) + 1.0d)) + ((((((-2.0d) * Math.log((mu * this._theta) + 1.0d)) + (2.0d * Gamma.digamma(y + this._invTheta))) - (2.0d * Gamma.digamma(this._invTheta))) + ((Gamma.trigamma(y + this._invTheta) - Gamma.trigamma(this._invTheta)) * this._invTheta)) * this._invTheta)) * this._invThetaSq));
                this._llh += (((Gamma.logGamma(y + this._invTheta) - Gamma.logGamma(this._invTheta)) - Gamma.logGamma(y + 1.0d)) + (y * Math.log(this._theta * mu))) - ((y + this._invTheta) * Math.log(1.0d + (this._theta * mu)));
            }
        }

        /** Merges the partial sums of another task instance into this one. */
        @Override // water.MRTask
        public void reduce(NegativeBinomialGradientAndHessian negativeBinomialGradientAndHessian) {
            this._grad += negativeBinomialGradientAndHessian._grad;
            this._hess += negativeBinomialGradientAndHessian._hess;
            this._llh += negativeBinomialGradientAndHessian._llh;
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Iteratively refines the value {@code d} with Newton-style updates
     * {@code d <- d - numerator / denominator}. The numerator and denominator
     * combine the pre-computed sums in {@code computeGammaMLSETsk} with the
     * digamma/trigamma sums recomputed by a {@code ComputeDiTriGammaTsk} pass over
     * the active frame on every iteration.
     *
     * <p>Stops when the step magnitude drops below {@code _dispersion_epsilon},
     * when the step is not usable (zero denominator or non-finite), when the job
     * is stopped or the runtime budget is exhausted (checked every 100th
     * iteration), or after {@code _max_iterations_dispersion} iterations.
     * {@code CMAESOptimizer.DEFAULT_STOPFITNESS} is 0 in Commons Math 3 and is
     * used here only as a zero threshold.
     *
     * @param computeGammaMLSETsk task holding {@code _wsum}, {@code _sumlnyiOui} and {@code _sumyiOverui}
     * @param d                   starting value of the estimate
     * @param dArr                array forwarded unchanged to {@code ComputeDiTriGammaTsk}
     * @param gLMParameters       supplies iteration limit, epsilon and max runtime
     * @param computationState    supplies the active {@code DataInfo}
     * @param job                 job whose key is passed to the task and whose stop flag is polled
     * @param gLMModel            model whose output start time is used for the runtime budget
     * @return the last (or converged) estimate
     */
    public static double estimateGammaMLSE(GLMTask.ComputeGammaMLSETsk computeGammaMLSETsk, double d, double[] dArr, GLMModel.GLMParameters gLMParameters, ComputationState computationState, Job job, GLMModel gLMModel) {
        final double constantTerm = (computeGammaMLSETsk._wsum + computeGammaMLSETsk._sumlnyiOui) - computeGammaMLSETsk._sumyiOverui;
        final DataInfo activeInfo = computationState.activeData();
        final Frame adaptedFrame = activeInfo._adaptedFrame;
        final long startedAtMs = System.currentTimeMillis();

        // Remaining runtime budget in milliseconds; unlimited when no max runtime is set.
        final long budgetMs;
        if (gLMParameters._max_runtime_secs > CMAESOptimizer.DEFAULT_STOPFITNESS) {
            budgetMs = (long) ((gLMParameters._max_runtime_secs * 1000.0d) - (startedAtMs - ((GLMModel.GLMOutput) gLMModel._output)._start_time));
        } else {
            budgetMs = Long.MAX_VALUE;
        }

        for (int iter = 0; iter < gLMParameters._max_iterations_dispersion; iter++) {
            final GLMTask.ComputeDiTriGammaTsk gammaSums = (GLMTask.ComputeDiTriGammaTsk) new GLMTask.ComputeDiTriGammaTsk(null, activeInfo, job._key, dArr, gLMParameters, d).doAll(adaptedFrame);
            final double numerator = ((computeGammaMLSETsk._wsum * Math.log(d)) - gammaSums._sumDigamma) + constantTerm;
            final double denominator = (computeGammaMLSETsk._wsum / d) - gammaSums._sumTrigamma;
            final double step = numerator / denominator;

            // Unusable step: keep the current estimate.
            if (denominator == CMAESOptimizer.DEFAULT_STOPFITNESS || !Double.isFinite(step)) {
                return d;
            }

            final double candidate = d - step;
            if (Math.abs(step) < gLMParameters._dispersion_epsilon) {
                return candidate;
            }

            // A negative candidate is replaced by halving the current estimate.
            if (candidate < CMAESOptimizer.DEFAULT_STOPFITNESS) {
                d *= 0.5d;
            } else {
                d = candidate;
            }

            if (iter % 100 == 0 && (job.stop_requested() || System.currentTimeMillis() - startedAtMs > budgetMs)) {
                Log.warn("gamma dispersion parameter estimation was interrupted by user or due to time out.  Estimation process has not converged. Increase your max_runtime_secs if you have set maximum runtime for your model building process.");
                return d;
            }
        }
        Log.warn("gamma dispersion parameter estimation fails to converge within " + gLMParameters._max_iterations_dispersion + " iterations.  Increase max_iterations_dispersion or decrease dispersion_epsilon.");
        return d;
    }

    /**
     * Evaluates the Tweedie log-likelihood reported by {@code TweedieEstimator}
     * for the configured variance power and the given {@code d} (logged as
     * {@code phi}).
     *
     * <p>The response column is looked up by name in the adapted frame. When no
     * weights column is configured, a column of ones made compatible with the
     * adapted frame is used instead.
     * NOTE(review): that ones-column is created on every call and is not removed
     * here - confirm something else tracks or deletes it.
     *
     * @param gLMParameters supplies variance power, response and weights column names
     * @param dataInfo      supplies the adapted frame
     * @param d             the value passed to the estimator as its second argument
     * @param vec           first argument to {@code TweedieEstimator.compute}
     * @return the estimator's {@code _loglikelihood}
     */
    private static double getTweedieLogLikelihood(GLMModel.GLMParameters gLMParameters, DataInfo dataInfo, double d, Vec vec) {
        final TweedieEstimator estimator = new TweedieEstimator(gLMParameters._tweedie_variance_power, d, false, false, false, false);
        final Vec response = dataInfo._adaptedFrame.vec(gLMParameters._response_column);
        final Vec weights;
        if (gLMParameters._weights_column == null) {
            weights = dataInfo._adaptedFrame.makeCompatible(new Frame(Vec.makeOne(dataInfo._adaptedFrame.numRows())))[0];
        } else {
            weights = dataInfo._adaptedFrame.vec(gLMParameters._weights_column);
        }
        final double logLikelihood = estimator.compute(vec, response, weights)._loglikelihood;
        Log.debug("Tweedie LogLikelihood(p=" + gLMParameters._tweedie_variance_power + ", phi=" + d + ") = " + logLikelihood);
        return logLikelihood;
    }

    /**
     * Searches for the phi that maximises the Tweedie log-likelihood using a
     * bracketing phase followed by a golden-section search.
     *
     * <ol>
     *   <li>The already evaluated phis ({@code list2}) are sorted and paired with
     *       their likelihoods ({@code list}); the sorted sequence is scanned for
     *       the first decrease in likelihood, which gives the upper bound of the
     *       bracket.</li>
     *   <li>If the likelihood never decreases, the upper bound is doubled
     *       repeatedly (within the iteration budget) until it does.</li>
     *   <li>The bracket is then shrunk with ratio 0.618 until it is narrower than
     *       {@code _dispersion_epsilon}, the budget is used up, or the job is
     *       stopped.</li>
     * </ol>
     *
     * <p>Both bracketing loops stop as soon as a decrease is found. Without that
     * the first loop would never terminate, and the second would keep doubling
     * and could move the lower bound past the maximum.
     *
     * @param gLMParameters supplies iteration limit and epsilon
     * @param dataInfo      forwarded to {@code getTweedieLogLikelihood}
     * @param vec           forwarded to {@code getTweedieLogLikelihood}
     * @param list          likelihoods, position-aligned with {@code list2}
     * @param list2         phis that were already evaluated
     * @param job           job whose stop flag is polled
     * @return the midpoint of the final bracket
     */
    private static double goldenRatioDispersionSearch(GLMModel.GLMParameters gLMParameters, DataInfo dataInfo, Vec vec, List<Double> list, List<Double> list2, Job job) {
        final double minPhi = 1.0E-16d;

        // Sorted, mutable copies: elements are inserted/appended below.
        final List<Double> sortedPhis = list2.stream().sorted().collect(Collectors.toCollection(ArrayList::new));
        final List<Double> sortedLlhs = new ArrayList<>();
        for (Double phi : sortedPhis) {
            sortedLlhs.add(list.get(list2.indexOf(phi)));
        }

        // Phase 1: look for the first drop in likelihood among the known points.
        boolean increasing = true;
        double lowerBound = minPhi;
        double upperBound = sortedPhis.get(0);
        for (int i = 1; i < sortedPhis.size(); i++) {
            upperBound = sortedPhis.get(i);
            if (sortedLlhs.get(i - 1) > sortedLlhs.get(i)) {
                increasing = false;
                if (i > 2) {
                    lowerBound = sortedPhis.get(i - 2);
                } else {
                    // Too few points below the drop: anchor the bracket at the minimum phi.
                    sortedPhis.add(0, minPhi);
                    sortedLlhs.add(0, getTweedieLogLikelihood(gLMParameters, dataInfo, minPhi, vec));
                }
                break; // bracket found
            }
        }

        // Phase 2: if the likelihood is still increasing, double the upper bound until it drops.
        int evaluated = sortedPhis.size();
        final int iterationBudget = gLMParameters._max_iterations_dispersion - (10 * evaluated);
        while (increasing && iterationBudget > evaluated && !job.stop_requested()) {
            evaluated++;
            upperBound *= 2.0d;
            sortedPhis.add(upperBound);
            final double llh = getTweedieLogLikelihood(gLMParameters, dataInfo, upperBound, vec);
            Log.debug("Tweedie looking for the region containing the max. likelihood; upper bound = " + upperBound + "; llh = " + llh);
            sortedLlhs.add(llh);
            if (sortedLlhs.get(evaluated - 2) > sortedLlhs.get(evaluated - 1)) {
                if (evaluated > 3) {
                    lowerBound = sortedPhis.get(evaluated - 3);
                }
                Log.debug("Tweedie found the region containing the max. likelihood; phi lower bound = " + lowerBound + "; phi upper bound = " + upperBound);
                break; // bracket found
            }
        }

        // Phase 3: golden-section search inside [lowerBound, upperBound].
        double span = (upperBound - lowerBound) * 0.618d;
        double lowPhi = lowerBound;
        double highPhi = upperBound;
        double midLowPhi = sortedPhis.get(evaluated - 2);
        double midLowLlh = sortedLlhs.get(evaluated - 2);
        if (midLowPhi > upperBound) {
            midLowPhi = highPhi - span;
            midLowLlh = getTweedieLogLikelihood(gLMParameters, dataInfo, midLowPhi, vec);
        }
        double midHighPhi = lowPhi + span;
        double midHighLlh = getTweedieLogLikelihood(gLMParameters, dataInfo, midHighPhi, vec);
        for (; evaluated < iterationBudget; evaluated++) {
            Log.info("Tweedie golden-section search[iter=" + evaluated + ", phis=(" + lowPhi + ", " + midLowPhi + ", " + midHighPhi + ", " + highPhi + "), likelihoods=(..., " + midLowLlh + ", " + midHighLlh + ", ...)]");
            if (job.stop_requested()) {
                return (highPhi + lowPhi) / 2.0d;
            }
            if (midHighLlh > midLowLlh) {
                lowPhi = midLowPhi;
            } else {
                highPhi = midHighPhi;
            }
            span = (highPhi - lowPhi) * 0.618d;
            if (highPhi - lowPhi < gLMParameters._dispersion_epsilon) {
                return (highPhi + lowPhi) / 2.0d;
            }
            midLowPhi = highPhi - span;
            midHighPhi = lowPhi + span;
            midLowLlh = getTweedieLogLikelihood(gLMParameters, dataInfo, midLowPhi, vec);
            midHighLlh = getTweedieLogLikelihood(gLMParameters, dataInfo, midHighPhi, vec);
        }
        return (highPhi + lowPhi) / 2.0d;
    }

    /**
     * Estimates the Tweedie dispersion (phi) with Newton-style updates on the
     * series log-likelihood computed by {@code ComputeMaxSumSeriesTsk}, falling
     * back to {@code goldenRatioDispersionSearch} when a periodic sanity check
     * against {@code getTweedieLogLikelihood} shows the likelihood getting worse.
     *
     * <p>Per iteration:
     * <ul>
     *   <li>the average log-likelihood at the current phi is recorded;</li>
     *   <li>every 10th iteration, and on convergence, the sanity check runs;</li>
     *   <li>convergence is declared when the change in average log-likelihood is
     *       below {@code _dispersion_epsilon};</li>
     *   <li>if none of the last three recorded likelihoods is finite (after more
     *       than 10 iterations) the method gives up and returns NaN;</li>
     *   <li>small Newton steps (&lt; 0.001 in magnitude) are enlarged through
     *       {@code dispersionLS}; other steps are accepted only if they improve
     *       the likelihood, otherwise a learning-rate-scaled step is used.</li>
     * </ul>
     *
     * <p>{@code cleanUp()} is called on the estimator on every exit path,
     * including the exit taken when the line search fails.
     * NOTE(review): {@code Collections.max} orders NaN above every other double,
     * so a NaN likelihood would be picked as "best" - confirm upstream handling.
     *
     * @param gLMParameters supplies iteration limit, epsilon, learning rate and max runtime
     * @param gLMModel      model used for prediction and for the runtime budget
     * @param job           job whose stop flag is polled every 100th iteration
     * @param dArr          array forwarded to the prediction task and the estimator
     * @param dataInfo      supplies the adapted frame
     * @return the phi with the highest recorded likelihood, the golden-section
     *         result, or NaN when the iteration is numerically stuck
     */
    public static double estimateTweedieDispersionOnly(GLMModel.GLMParameters gLMParameters, GLMModel gLMModel, Job job, double[] dArr, DataInfo dataInfo) {
        final long startedAtMs = System.currentTimeMillis();
        // Remaining runtime budget in milliseconds; unlimited when no max runtime is set.
        final long budgetMs = gLMParameters._max_runtime_secs > 0.0d ? (long) ((gLMParameters._max_runtime_secs * 1000.0d) - (startedAtMs - ((GLMModel.GLMOutput) gLMModel._output)._start_time)) : Long.MAX_VALUE;
        final TweedieMLDispersionOnly estimator = new TweedieMLDispersionOnly(gLMParameters.train(), gLMParameters, gLMModel, dArr, dataInfo);
        final Vec prediction = Scope.track(new DispersionTask.GenPrediction(dArr, gLMModel, dataInfo).doAll(1, (byte) 3, dataInfo._adaptedFrame).outputFrame(Key.make(), new String[]{"prediction"}, (String[][]) null)).vec(0);

        double phi = estimator._dispersionParameter;
        final List<Double> logLikelihoods = new ArrayList<>();
        final List<Double> llhChanges = new ArrayList<>();
        final List<Double> phis = new ArrayList<>();

        // Points evaluated by the sanity check; handed to the golden-section fallback.
        double bestSanityLlh = getTweedieLogLikelihood(gLMParameters, dataInfo, phi, prediction);
        final List<Double> sanityLlhs = new ArrayList<>();
        final List<Double> sanityPhis = new ArrayList<>();
        sanityLlhs.add(bestSanityLlh);
        sanityPhis.add(phi);

        for (int i = 0; i < gLMParameters._max_iterations_dispersion; i++) {
            estimator.updateDispersionP(phi);
            final DispersionTask.ComputeMaxSumSeriesTsk seriesTask = new DispersionTask.ComputeMaxSumSeriesTsk(estimator, gLMParameters, true);
            seriesTask.doAll(estimator._infoFrame);
            final double avgLlh = seriesTask._logLL / seriesTask._nobsLL;
            logLikelihoods.add(avgLlh);
            phis.add(phi);

            if (logLikelihoods.size() > 1) {
                llhChanges.add(logLikelihoods.get(i) - logLikelihoods.get(i - 1));
                final boolean converged = Math.abs(llhChanges.get(llhChanges.size() - 1)) < gLMParameters._dispersion_epsilon;
                if (i % 10 == 0 || converged) {
                    final double sanityLlh = getTweedieLogLikelihood(gLMParameters, dataInfo, phi, prediction);
                    sanityLlhs.add(sanityLlh);
                    sanityPhis.add(phi);
                    if (sanityLlh < bestSanityLlh) {
                        Log.info("Tweedie sanity check FAIL. Trying Golden-section search instead of Newton's method.");
                        estimator.cleanUp();
                        final double goldenEstimate = goldenRatioDispersionSearch(gLMParameters, dataInfo, prediction, sanityLlhs, sanityPhis, job);
                        Log.info("Tweedie dispersion estimate = " + goldenEstimate);
                        return goldenEstimate;
                    }
                    bestSanityLlh = Math.max(bestSanityLlh, sanityLlh);
                    Log.debug("Tweedie sanity check OK");
                }
                if (converged) {
                    estimator.cleanUp();
                    Log.info("last dispersion " + phi);
                    return bestDispersionSoFar(phis, logLikelihoods);
                }
            }

            if (logLikelihoods.size() > 10 && logLikelihoods.stream().skip(logLikelihoods.size() - 3).noneMatch(llh -> llh != null && Double.isFinite(llh))) {
                Log.warn("tweedie dispersion parameter estimation got stuck in numerically unstable region.");
                estimator.cleanUp();
                return Double.NaN;
            }

            final double newtonStep = seriesTask._dLogLL / seriesTask._d2LogLL;
            double nextPhi;
            if (Math.abs(newtonStep) < 0.001d) {
                final double lineSearchStep = dispersionLS(seriesTask, estimator, gLMParameters);
                if (!Double.isFinite(lineSearchStep)) {
                    // Release the estimator's resources here too, as on every other exit.
                    estimator.cleanUp();
                    Log.info("last dispersion " + phi);
                    return bestDispersionSoFar(phis, logLikelihoods);
                }
                nextPhi = phi - lineSearchStep;
            } else {
                nextPhi = phi - newtonStep;
                if (nextPhi < 0.0d) {
                    nextPhi = phi * 0.5d;
                }
                // Accept the step only if it improves the likelihood.
                estimator.updateDispersionP(nextPhi);
                final DispersionTask.ComputeMaxSumSeriesTsk trialTask = new DispersionTask.ComputeMaxSumSeriesTsk(estimator, gLMParameters, false);
                trialTask.doAll(estimator._infoFrame);
                if (trialTask._logLL / trialTask._nobsLL <= avgLlh) {
                    nextPhi = phi + (gLMParameters._dispersion_learning_rate * newtonStep);
                }
            }
            phi = nextPhi < 0.0d ? phi * 0.5d : nextPhi;

            if (i % 100 == 0 && (job.stop_requested() || System.currentTimeMillis() - startedAtMs > budgetMs)) {
                Log.warn("tweedie dispersion parameter estimation was interrupted by user or due to time out.  Estimation process has not converged. Increase your max_runtime_secs if you have set maximum runtime for your model building process.");
                estimator.cleanUp();
                Log.info("last dispersion " + phi);
                return bestDispersionSoFar(phis, logLikelihoods);
            }
        }

        estimator.cleanUp();
        if (phis.size() <= 0) {
            return phi;
        }
        Log.info("last dispersion " + phi);
        return bestDispersionSoFar(phis, logLikelihoods);
    }

    /** Returns the entry of {@code phis} at the position of the largest value in {@code logLikelihoods}. */
    private static double bestDispersionSoFar(List<Double> phis, List<Double> logLikelihoods) {
        return phis.get(logLikelihoods.indexOf(Collections.max(logLikelihoods)));
    }

    /**
     * Computes a method-of-moments style ratio from the sums gathered by
     * {@code C1MomentMethodThetaEstimation}:
     * {@code muSqSum / (sSqSum - muSum / wSum)}.
     *
     * <p>The task is run over the columns in the order {@code (vec3, vec2, vec)},
     * so {@code vec} supplies the row weights. {@code gLMModel}, {@code dArr} and
     * {@code dataInfo} are not used by this method.
     *
     * @param vec  column used as the per-row weight
     * @param vec2 column compared against {@code vec3}
     * @param vec3 column whose weighted sum and weighted sum of squares are taken
     * @return the ratio described above
     */
    public static double estimateNegBinomialDispersionMomentMethod(GLMModel gLMModel, double[] dArr, DataInfo dataInfo, Vec vec, Vec vec2, Vec vec3) {
        final C1MomentMethodThetaEstimation moments = new C1MomentMethodThetaEstimation().doAll(vec3, vec2, vec);
        final double weightedMean = moments._muSum / moments._wSum;
        final double denominator = moments._sSqSum - weightedMean;
        return moments._muSqSum / denominator;
    }

    /**
     * Iteratively estimates theta from the score/information sums produced by
     * {@code CalculateNegativeBinomialScoreAndInfo} and returns its reciprocal.
     *
     * <p>The starting theta is the weight total divided by the sum from
     * {@code CalculateInitialTheta}. Each iteration evaluates the sums at
     * {@code |theta|} and adds {@code score / info} to it. Iteration stops when
     * the magnitude of that increment falls below {@code _dispersion_epsilon} or
     * after {@code _max_iterations_dispersion} iterations. A warning is logged
     * when the final theta is negative or the iteration limit was hit; the value
     * returned is {@code 1 / theta} in every case.
     * {@code CMAESOptimizer.DEFAULT_STOPFITNESS} is 0 in Commons Math 3 and is
     * used here only as a zero threshold.
     *
     * @param gLMParameters supplies iteration limit and epsilon
     * @param gLMModel      model used to generate the prediction column
     * @param dArr          array forwarded to the prediction task
     * @param dataInfo      supplies the adapted frame, weights and response column
     * @return the reciprocal of the final theta
     */
    public static double estimateNegBinomialDispersionFisherScoring(GLMModel.GLMParameters gLMParameters, GLMModel gLMModel, double[] dArr, DataInfo dataInfo) {
        final Vec weights;
        if (dataInfo._weights) {
            weights = dataInfo.getWeightsVec();
        } else {
            // No weights configured: use a column of ones compatible with the adapted frame.
            weights = dataInfo._adaptedFrame.makeCompatible(new Frame(Vec.makeOne(dataInfo._adaptedFrame.numRows())))[0];
        }

        final double weightTotal;
        if (weights == null) {
            weightTotal = dataInfo._adaptedFrame.numRows();
        } else {
            weightTotal = weights.mean() * weights.length();
        }

        final Vec prediction = new DispersionTask.GenPrediction(dArr, gLMModel, dataInfo).doAll(1, (byte) 3, dataInfo._adaptedFrame).outputFrame(Key.make(), new String[]{"prediction"}, (String[][]) null).vec(0);
        final Vec response = dataInfo._adaptedFrame.vec(dataInfo.responseChunkId(0));

        double theta = weightTotal / new CalculateInitialTheta().doAll(prediction, response, weights)._theta0;
        double increment = 1.0d;
        int iterations = 0;
        for (; iterations < gLMParameters._max_iterations_dispersion && Math.abs(increment) >= gLMParameters._dispersion_epsilon; iterations++) {
            final double absTheta = Math.abs(theta);
            final CalculateNegativeBinomialScoreAndInfo sums = new CalculateNegativeBinomialScoreAndInfo(absTheta).doAll(prediction, response, weights);
            increment = sums._score / sums._info;
            theta = absTheta + increment;
        }

        if (theta < CMAESOptimizer.DEFAULT_STOPFITNESS) {
            Log.warn("Dispersion estimate truncated at zero.");
        }
        if (iterations == gLMParameters._max_iterations_dispersion) {
            Log.warn("Iteration limit reached.");
        }
        return 1.0d / theta;
    }

    /**
     * Step-doubling line search for the Tweedie dispersion update.
     *
     * <p>Starting from the Newton step {@code _dLogLL / _d2LogLL} of
     * {@code computeMaxSumSeriesTsk}, the step is doubled for as long as the
     * average log-likelihood at {@code (starting dispersion - step)} keeps
     * strictly increasing. The first step that fails to improve on the previous
     * one is returned.
     *
     * <p>Side effect: the dispersion held by {@code tweedieMLDispersionOnly} is
     * updated through {@code updateDispersionP} for every trial step.
     *
     * @param computeMaxSumSeriesTsk  task holding the derivatives at the current dispersion
     * @param tweedieMLDispersionOnly estimator whose dispersion is moved during the search
     * @param gLMParameters           supplies the iteration limit
     * @return the selected step, or NaN if the step becomes non-finite
     */
    public static double dispersionLS(DispersionTask.ComputeMaxSumSeriesTsk computeMaxSumSeriesTsk, TweedieMLDispersionOnly tweedieMLDispersionOnly, GLMModel.GLMParameters gLMParameters) {
        double bestLlh = Double.NEGATIVE_INFINITY;
        final double startDispersion = tweedieMLDispersionOnly._dispersionParameter;
        double step = computeMaxSumSeriesTsk._dLogLL / computeMaxSumSeriesTsk._d2LogLL;
        for (int i = 0; i < gLMParameters._max_iterations_dispersion; i++) {
            if (!Double.isFinite(step)) {
                return Double.NaN;
            }
            tweedieMLDispersionOnly.updateDispersionP(startDispersion - step);
            // Keep the task in a local so both of its result fields can be read.
            final DispersionTask.ComputeMaxSumSeriesTsk trialTask = new DispersionTask.ComputeMaxSumSeriesTsk(tweedieMLDispersionOnly, gLMParameters, false);
            trialTask.doAll(tweedieMLDispersionOnly._infoFrame);
            final double trialLlh = trialTask._logLL / trialTask._nobsLL;
            if (trialLlh <= bestLlh) {
                return step;
            }
            bestLlh = trialLlh;
            step = 2.0d * step;
        }
        return step;
    }

    /**
     * Subtracts {@code dArr} from {@code dArr2} element by element, in place.
     *
     * <p>Exactly {@code dArr2.length} elements are processed, so {@code dArr}
     * must be at least that long. {@code dArr} is not modified.
     *
     * @param dArr  values to subtract
     * @param dArr2 array that is updated in place
     * @return {@code dArr2}, the same array instance that was passed in
     */
    public static double[] makeZeros(double[] dArr, double[] dArr2) {
        final int count = dArr2.length;
        int idx = 0;
        while (idx < count) {
            dArr2[idx] -= dArr[idx];
            idx++;
        }
        return dArr2;
    }
}
