package hex.glm;

import hex.DataInfo;
import hex.glm.DispersionTask;
import hex.glm.GLMModel;
import hex.glm.GLMTask;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import water.Job;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:hex/glm/DispersionUtils.class */
/**
 * Static helpers for iteratively estimating GLM dispersion-related parameters for the gamma and Tweedie families.
 */
public class DispersionUtils {
    /** Number of iterations between polls for a user stop request or an exhausted runtime budget. */
    private static final int STOP_CHECK_INTERVAL = 100;

    /**
     * Estimates the gamma dispersion parameter by repeating an update of the form {@code d <- d - f(d) / df(d)}
     * (presumably Newton-Raphson on the maximum-likelihood score equation, with {@code df} the derivative of
     * {@code f}; confirm against {@link GLMTask.ComputeDiTriGammaTsk}).
     *
     * <p>Each iteration runs {@link GLMTask.ComputeDiTriGammaTsk} over the active data's adapted frame at the current
     * estimate. Iteration ends when:
     * <ul>
     *   <li>{@code df} is zero or the step is not finite: the current estimate is returned unchanged;</li>
     *   <li>|step| is below {@code _dispersion_epsilon}: the stepped estimate is returned;</li>
     *   <li>the job is stopped or the runtime budget is spent (polled every 100 iterations): a warning is logged
     *       and the latest estimate is returned;</li>
     *   <li>{@code _max_iterations_dispersion} iterations elapse: a warning is logged and the latest estimate is
     *       returned.</li>
     * </ul>
     * A step that would make the estimate negative is replaced by halving the current estimate.
     *
     * @param computeGammaMLSETsk completed task supplying {@code _wsum}, {@code _sumlnyiOui} and {@code _sumyiOverui}
     * @param d                   starting value of the estimate
     * @param dArr                coefficients forwarded to the digamma/trigamma task (presumably beta; confirm)
     * @param gLMParameters       supplies the iteration limit, tolerance and runtime limit
     * @param computationState    supplies the active {@link DataInfo}
     * @param job                 polled for stop requests; its key is passed to the per-iteration task
     * @param gLMModel            its output's {@code _start_time} anchors the runtime budget
     * @return the final estimate
     */
    public static double estimateGammaMLSE(GLMTask.ComputeGammaMLSETsk computeGammaMLSETsk, double d, double[] dArr, GLMModel.GLMParameters gLMParameters, ComputationState computationState, Job job, GLMModel gLMModel) {
        // Part of f(d) that does not depend on the estimate; computed once outside the loop.
        final double constantTerm = (computeGammaMLSETsk._wsum + computeGammaMLSETsk._sumlnyiOui) - computeGammaMLSETsk._sumyiOverui;
        final DataInfo activeData = computationState.activeData();
        final Frame frame = activeData._adaptedFrame;
        final long startTime = System.currentTimeMillis();
        final long runtimeBudgetMs = remainingRuntimeMillis(gLMParameters, gLMModel, startTime);
        double estimate = d;
        for (int iter = 0; iter < gLMParameters._max_iterations_dispersion; iter++) {
            GLMTask.ComputeDiTriGammaTsk gammaTsk = (GLMTask.ComputeDiTriGammaTsk) new GLMTask.ComputeDiTriGammaTsk(null, activeData, job._key, dArr, gLMParameters, estimate).doAll(frame);
            final double f = ((computeGammaMLSETsk._wsum * Math.log(estimate)) - gammaTsk._sumDigamma) + constantTerm;
            final double df = (computeGammaMLSETsk._wsum / estimate) - gammaTsk._sumTrigamma;
            final double step = f / df;
            if (df == 0.0d || !Double.isFinite(step)) {
                // No usable step can be taken from here.
                return estimate;
            }
            if (Math.abs(step) < gLMParameters._dispersion_epsilon) {
                return estimate - step;
            }
            final double next = estimate - step;
            // Keep the estimate positive: halve it instead of stepping past zero.
            estimate = next < 0.0d ? estimate * 0.5d : next;
            if (stopRequestedOrTimedOut(iter, job, startTime, runtimeBudgetMs)) {
                Log.warn("gamma dispersion parameter estimation was interrupted by user or due to time out.  Estimation process has not converged. Increase your max_runtime_secs if you have set maximum runtime for your model building process.");
                return estimate;
            }
        }
        Log.warn("gamma dispersion parameter estimation fails to converge within " + gLMParameters._max_iterations_dispersion + " iterations.  Increase max_iterations_dispersion or decrease dispersion_epsilon.");
        return estimate;
    }

    /**
     * Estimates the Tweedie dispersion parameter, with the other model parameters held fixed, by maximizing the
     * per-observation log-likelihood {@code _logLL / _nobsLL} computed by {@link DispersionTask.ComputeMaxSumSeriesTsk}.
     *
     * <p>Each iteration evaluates the task at the current dispersion and records the (log-likelihood, dispersion)
     * pair. Let {@code newtonStep = _dLogLL / _d2LogLL}, the ratio of the task's first- and second-derivative fields.
     * The next dispersion is chosen as follows:
     * <ul>
     *   <li>if |newtonStep| is below 0.001, {@link #dispersionLS} searches for a larger step;</li>
     *   <li>otherwise {@code dispersion - newtonStep} is tried, halving the dispersion instead if the result would
     *       be negative. If the trial does not increase the log-likelihood, the dispersion moves by
     *       {@code _dispersion_learning_rate * newtonStep} in the opposite direction.</li>
     * </ul>
     * A negative candidate is replaced by half the current dispersion.
     *
     * <p>The recorded dispersion with the highest log-likelihood is returned in any of these cases:
     * <ul>
     *   <li>the log-likelihood changes by less than {@code _dispersion_epsilon} between iterations;</li>
     *   <li>the line search returns a non-finite step;</li>
     *   <li>the job is stopped or the runtime budget is spent (polled every 100 iterations);</li>
     *   <li>{@code _max_iterations_dispersion} iterations elapse.</li>
     * </ul>
     * If no iteration ran, the initial dispersion is returned. {@code cleanUp()} is called on the
     * {@link TweedieMLDispersionOnly} helper on every exit path, including exceptions.
     *
     * @param gLMParameters supplies training frame, iteration limit, tolerance, learning rate and runtime limit
     * @param gLMModel      model used to build the helper; its output's {@code _start_time} anchors the runtime budget
     * @param job           polled for stop requests
     * @param dArr          coefficients forwarded to {@link TweedieMLDispersionOnly} (presumably beta; confirm)
     * @param dataInfo      data description forwarded to {@link TweedieMLDispersionOnly}
     * @return the dispersion estimate
     */
    public static double estimateTweedieDispersionOnly(GLMModel.GLMParameters gLMParameters, GLMModel gLMModel, Job job, double[] dArr, DataInfo dataInfo) {
        final long startTime = System.currentTimeMillis();
        final long runtimeBudgetMs = remainingRuntimeMillis(gLMParameters, gLMModel, startTime);
        final TweedieMLDispersionOnly tweedie = new TweedieMLDispersionOnly(gLMParameters.train(), gLMParameters, gLMModel, dArr, dataInfo);
        // try/finally so cleanUp() also runs on the line-search failure path and when a task throws.
        try {
            double dispersion = tweedie._dispersionParameter;
            final List<Double> logLLs = new ArrayList<>();       // per-observation log-likelihood of each iteration
            final List<Double> dispersions = new ArrayList<>();  // dispersion evaluated in the same iteration
            for (int iter = 0; iter < gLMParameters._max_iterations_dispersion; iter++) {
                tweedie.updateDispersionP(dispersion);
                final DispersionTask.ComputeMaxSumSeriesTsk tsk = new DispersionTask.ComputeMaxSumSeriesTsk(tweedie, gLMParameters, true);
                tsk.doAll(tweedie._infoFrame);
                final double logLL = tsk._logLL / tsk._nobsLL;
                logLLs.add(logLL);
                dispersions.add(dispersion);
                if (logLLs.size() > 1
                        && Math.abs(logLL - logLLs.get(logLLs.size() - 2)) < gLMParameters._dispersion_epsilon) {
                    Log.info("last dispersion " + dispersion);
                    return dispersionWithMaxLogLL(logLLs, dispersions);
                }
                final double newtonStep = tsk._dLogLL / tsk._d2LogLL;
                double candidate;
                if (Math.abs(newtonStep) < 0.001d) {
                    // The Newton step is tiny, so let the line search look for a larger one.
                    final double lsStep = dispersionLS(tsk, tweedie, gLMParameters);
                    if (!Double.isFinite(lsStep)) {
                        Log.info("last dispersion " + dispersion);
                        return dispersionWithMaxLogLL(logLLs, dispersions);
                    }
                    candidate = dispersion - lsStep;
                } else {
                    candidate = dispersion - newtonStep;
                    if (candidate < 0.0d) {
                        candidate = dispersion * 0.5d;
                    }
                    tweedie.updateDispersionP(candidate);
                    final DispersionTask.ComputeMaxSumSeriesTsk trialTsk = new DispersionTask.ComputeMaxSumSeriesTsk(tweedie, gLMParameters, false);
                    trialTsk.doAll(tweedie._infoFrame);
                    if (trialTsk._logLL / trialTsk._nobsLL <= logLL) {
                        // The trial did not improve the likelihood: move the other way, scaled by the learning rate.
                        candidate = dispersion + (gLMParameters._dispersion_learning_rate * newtonStep);
                    }
                }
                dispersion = candidate < 0.0d ? dispersion * 0.5d : candidate;
                if (stopRequestedOrTimedOut(iter, job, startTime, runtimeBudgetMs)) {
                    Log.warn("tweedie dispersion parameter estimation was interrupted by user or due to time out.  Estimation process has not converged. Increase your max_runtime_secs if you have set maximum runtime for your model building process.");
                    Log.info("last dispersion " + dispersion);
                    return dispersionWithMaxLogLL(logLLs, dispersions);
                }
            }
            if (dispersions.isEmpty()) {
                return dispersion;
            }
            Log.info("last dispersion " + dispersion);
            return dispersionWithMaxLogLL(logLLs, dispersions);
        } finally {
            tweedie.cleanUp();
        }
    }

    /**
     * Line search for the Tweedie dispersion update, used when the Newton step is very small.
     *
     * <p>Starts from {@code step = _dLogLL / _d2LogLL} of {@code computeMaxSumSeriesTsk}. It evaluates the
     * per-observation log-likelihood at {@code dispersion - step} and doubles {@code step} while the log-likelihood
     * keeps increasing, for at most {@code _max_iterations_dispersion} evaluations. Each trial value overwrites the
     * dispersion stored in {@code tweedieMLDispersionOnly}, and the original value is not restored.
     *
     * <p>NOTE(review): the returned step is the first one that did NOT improve the log-likelihood, or on exhaustion
     * a doubled step that was never evaluated, rather than the last improving step ({@code step / 2}). The first
     * trial is also accepted without comparing it with the log-likelihood at the current dispersion. Confirm this
     * is intended.
     *
     * @return the step at which the search stopped, or {@link Double#NaN} if the step became non-finite
     */
    public static double dispersionLS(DispersionTask.ComputeMaxSumSeriesTsk computeMaxSumSeriesTsk, TweedieMLDispersionOnly tweedieMLDispersionOnly, GLMModel.GLMParameters gLMParameters) {
        double bestLogLL = Double.NEGATIVE_INFINITY;
        final double dispersion = tweedieMLDispersionOnly._dispersionParameter;
        double step = computeMaxSumSeriesTsk._dLogLL / computeMaxSumSeriesTsk._d2LogLL;
        for (int iter = 0; iter < gLMParameters._max_iterations_dispersion; iter++) {
            if (!Double.isFinite(step)) {
                return Double.NaN;
            }
            tweedieMLDispersionOnly.updateDispersionP(dispersion - step);
            // Keep a reference to the task and read its fields after doAll, as estimateTweedieDispersionOnly does.
            final DispersionTask.ComputeMaxSumSeriesTsk trialTsk = new DispersionTask.ComputeMaxSumSeriesTsk(tweedieMLDispersionOnly, gLMParameters, false);
            trialTsk.doAll(tweedieMLDispersionOnly._infoFrame);
            final double trialLogLL = trialTsk._logLL / trialTsk._nobsLL;
            if (trialLogLL <= bestLogLL) {
                return step;
            }
            bestLogLL = trialLogLL;
            step = 2.0d * step;
        }
        return step;
    }

    /**
     * Subtracts {@code dArr} element-wise from {@code dArr2} in place ({@code dArr2[i] -= dArr[i]}) and returns
     * {@code dArr2}. Only the first {@code dArr2.length} entries of {@code dArr} are read, so {@code dArr} must be at
     * least that long.
     */
    public static double[] makeZeros(double[] dArr, double[] dArr2) {
        for (int i = 0; i < dArr2.length; i++) {
            dArr2[i] -= dArr[i];
        }
        return dArr2;
    }

    /**
     * Returns the dispersion recorded alongside the largest log-likelihood.
     * {@link Collections#max} and {@link List#indexOf} use the boxed {@link Double} ordering and equality.
     */
    private static double dispersionWithMaxLogLL(List<Double> logLLs, List<Double> dispersions) {
        return dispersions.get(logLLs.indexOf(Collections.max(logLLs)));
    }

    /**
     * Returns the milliseconds left of {@code _max_runtime_secs} as of {@code now}, measured from the model output's
     * {@code _start_time}. Returns {@link Long#MAX_VALUE} when no positive runtime limit is set.
     */
    private static long remainingRuntimeMillis(GLMModel.GLMParameters parms, GLMModel model, long now) {
        return parms._max_runtime_secs > 0.0d
                ? (long) ((parms._max_runtime_secs * 1000.0d) - (now - ((GLMModel.GLMOutput) model._output)._start_time))
                : Long.MAX_VALUE;
    }

    /**
     * On every {@value #STOP_CHECK_INTERVAL}th iteration (including iteration 0), reports whether the job was asked
     * to stop or more than {@code budgetMs} milliseconds have passed since {@code startTime}.
     */
    private static boolean stopRequestedOrTimedOut(int iter, Job job, long startTime, long budgetMs) {
        return iter % STOP_CHECK_INTERVAL == 0
                && (job.stop_requested() || System.currentTimeMillis() - startTime > budgetMs);
    }
}
