package hex.hglm;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsRegressionHGLM;
import hex.glm.GLMModel;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.H2O;
import water.Job;
import water.Key;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/hglm/HGLM.class */
/**
 * Model builder for HGLM models.
 *
 * <p>{@link #init(boolean)} accepts only the Gaussian family (or no family) and the EM method, and
 * rejects cross-validation; {@link #can_build()} reports regression only. Training is performed by
 * {@link HGLMDriver}, which fits the model with an EM loop and records training and (optionally)
 * validation scoring history.
 */
public class HGLM extends ModelBuilder<HGLMModel, HGLMModel.HGLMParameters, HGLMModel.HGLMModelOutput> {
    /** Wall-clock time (ms since epoch) captured at the start of {@code computeImpl}. */
    long _startTime;
    /** Current EM estimates and iteration counter; created in {@code fitEM}. */
    private transient ComputationStateHGLM _state;
    private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");

    /**
     * Training driver: builds the {@link DataInfo}, creates and write-locks the model, fits it with EM
     * and scores it.
     */
    private class HGLMDriver extends ModelBuilder<HGLMModel, HGLMModel.HGLMParameters, HGLMModel.HGLMModelOutput>.Driver {
        DataInfo _dinfo;

        private HGLMDriver() {
            super();
            this._dinfo = null;
        }

        /**
         * Trains the model. Parameters are re-validated with {@code init(true)}. A {@link DataInfo} is
         * built over a clone of the training frame. The model is fitted with EM when that method is
         * selected, then scored on the training frame and, if present, the validation frame. Finally
         * the summary and scoring-history tables are filled in.
         *
         * <p>Once the model object exists, it is always updated and unlocked, even on failure.
         */
        @Override // hex.ModelBuilder.Driver
        public void computeImpl() {
            HGLM.this._startTime = System.currentTimeMillis();
            HGLM.this.init(true);
            if (HGLM.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(HGLM.this);
            }
            HGLM.this._job.update(0L, "Initializing HGLM model training");
            HGLMModel.HGLMParameters parms = HGLM.this._parms;
            HGLMModel model = null;
            ScoringHistory scoringHistory = new ScoringHistory();
            // Validation history only exists when a validation frame was supplied.
            ScoringHistory scoringHistoryValid = parms._valid == null ? null : new ScoringHistory();
            try {
                this._dinfo = new DataInfo((Frame) HGLM.this._train.m1510clone(), (Frame) null, 1,
                        parms._use_all_factor_levels,
                        DataInfo.TransformType.NONE, DataInfo.TransformType.NONE,
                        parms.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip,
                        parms.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation
                                || parms.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues,
                        parms.makeImputer(), false,
                        HGLM.this.hasWeightCol(), HGLM.this.hasOffsetCol(), HGLM.this.hasFoldCol(),
                        (Model.InteractionSpec) null);
                model = new HGLMModel(HGLM.this.dest(), parms, new HGLMModel.HGLMModelOutput(HGLM.this, this._dinfo));
                model.write_lock(HGLM.this._job);
                HGLM.this._job.update(1L, "Starting to build HGLM model...");
                if (HGLMModel.HGLMParameters.Method.EM == parms._method) {
                    fitEM(model, HGLM.this._job, scoringHistory, scoringHistoryValid);
                }
                HGLMModel.HGLMModelOutput output = model._output;
                output.setModelOutputFields(HGLM.this._state);
                scoreAndUpdateModel(model, true, scoringHistory);
                output._model_summary = generateSummary(output);
                output._start_time = HGLM.this._startTime;
                output._training_time_ms = System.currentTimeMillis() - HGLM.this._startTime;
                output._scoring_history = scoringHistory.to2dTable();
                if (scoringHistoryValid != null && HGLM.this.valid() != null) {
                    scoreAndUpdateModel(model, false, scoringHistoryValid);
                    if (!scoringHistoryValid.getScoringIters().isEmpty()) {
                        output._scoring_history_valid = scoringHistoryValid.to2dTable();
                    }
                }
            } finally {
                // The model is null if setup failed before it was created. Touching it then would throw an
                // NPE that masks the original exception.
                if (model != null) {
                    model.update(HGLM.this._job);
                    model.unlock(HGLM.this._job);
                }
            }
        }

        /** Builds the one-row model summary table: iterations, log-likelihood and noise variance. */
        private TwoDimTable generateSummary(HGLMModel.HGLMModelOutput output) {
            TwoDimTable summary = new TwoDimTable("HGLM Model", "summary", new String[]{""},
                    new String[]{"number_of_iterations", "loglikelihood", "noise_variance"},
                    new String[]{"int", "double", "double"},
                    new String[]{"%d", "%.5f", "%.5f"}, "");
            // Values are boxed explicitly so the Object overload of TwoDimTable.set is used.
            summary.set(0, 0, Integer.valueOf(output._iterations));
            summary.set(0, 1, Double.valueOf(output._log_likelihood));
            summary.set(0, 2, Double.valueOf(output._tau_e_var));
            return summary;
        }

        /** Returns the milliseconds elapsed since {@code since} (ms since epoch). */
        private long timeSinceLastScoring(long since) {
            return System.currentTimeMillis() - since;
        }

        /**
         * Scores {@code model} on the training frame ({@code onTrain == true}) or on the validation frame.
         * The resulting metrics are stored on the model output. An entry is appended to
         * {@code scoringHistory} only when metrics are found in the DKV.
         */
        private void scoreAndUpdateModel(HGLMModel model, boolean onTrain, ScoringHistory scoringHistory) {
            HGLMModel.HGLMModelOutput output = model._output;
            HGLMModel.HGLMParameters parms = HGLM.this._parms;
            Log.info("Scoring after " + timeSinceLastScoring(HGLM.this._startTime) + "ms at iteration " + output._iterations);
            long scoringStart = System.currentTimeMillis();
            if (!onTrain) {
                Log.info("Scoring on validation dataset.");
                model.score(parms.valid(), (String) null, CFuncRef.from(parms._custom_metric_func)).delete();
                ModelMetricsRegressionHGLM validMetrics = (ModelMetricsRegressionHGLM) ModelMetrics.getFromDKV(model, parms.valid());
                if (validMetrics != null) {
                    output._validation_metrics = validMetrics;
                    output._log_likelihood_valid = validMetrics.llg();
                    scoringHistory.addIterationScore(HGLM.this._state._iter, output._log_likelihood_valid, output._tau_e_var);
                }
                return;
            }
            model.score(parms.train(), (String) null, CFuncRef.from(parms._custom_metric_func)).delete();
            ModelMetricsRegressionHGLM trainMetrics = (ModelMetricsRegressionHGLM) ModelMetrics.getFromDKV(model, parms.train());
            output._training_metrics = trainMetrics;
            // Measure from the builder's start time. output._start_time is only assigned after the final
            // scoring pass, so using it here produced epoch-sized durations during training.
            output._training_time_ms = scoringStart - HGLM.this._startTime;
            if (trainMetrics != null) {
                output._log_likelihood = trainMetrics._log_likelihood;
                output._icc = trainMetrics._icc.clone();
                scoringHistory.addIterationScore(HGLM.this._state._iter, output._log_likelihood, trainMetrics._var_residual);
            }
        }

        /**
         * Fits the model with EM.
         *
         * <p>First, a {@link HGLMTask.ComputationEngineTask} pass computes the matrices and vectors the
         * update steps consume, and {@link HGLM#_state} is initialised from it. Each loop iteration then:
         * <ul>
         *   <li>re-estimates the random effects, the fixed coefficients and the T matrix;</li>
         *   <li>runs a {@link HGLMTask.ResidualLLHTask} to update the noise variance as the residual sum
         *       of squares times {@code _oneOverN};</li>
         *   <li>lets {@link #progress} decide whether to continue.</li>
         * </ul>
         *
         * <p>No iterations run when {@code _max_iterations <= 0}. A failure during the first iteration
         * is rethrown, wrapped in a {@link RuntimeException}. A failure in a later iteration ends the loop
         * and keeps the estimates last committed by {@link #progress}.
         */
        void fitEM(HGLMModel hGLMModel, Job job, ScoringHistory scoringHistory, ScoringHistory scoringHistory2) {
            HGLMModel.HGLMParameters parms = HGLM.this._parms;
            HGLMTask.ComputationEngineTask engineTask = new HGLMTask.ComputationEngineTask(job, parms, this._dinfo);
            engineTask.doAll(this._dinfo._adaptedFrame);
            hGLMModel._output.setModelOutput(engineTask);
            if (parms._showFixedMatVecs) {
                hGLMModel._output.setModelOutputFixMatVec(engineTask);
            }
            HGLM.this._state = new ComputationStateHGLM(HGLM.this._job, parms, this._dinfo, engineTask, 0);
            int iteration = 0;
            try {
                if (parms._max_iterations > 0) {
                    double[] beta = HGLM.this._state.getBeta().clone();
                    double tauEVarE10 = HGLM.this._state.getTauEVarE10();
                    double[][] tMat = ArrayUtils.copy2DArray(HGLM.this._state.getT());
                    double[][] ubeta;
                    HGLMTask.ResidualLLHTask residualTask;
                    do {
                        iteration++;
                        double[][][] cjInverse = HGLMUtils.generateCJInverse(engineTask._ArjTArj, tauEVarE10, HGLMUtils.generateTInverse(tMat));
                        ubeta = HGLMUtils.estimateNewRandomEffects(cjInverse, engineTask._ArjTYj, engineTask._ArjTAfj, beta);
                        beta = HGLMUtils.estimateFixedCoeff(engineTask._AfTAftInv, engineTask._AfjTYjSum, engineTask._AfjTArj, ubeta);
                        tMat = HGLMUtils.estimateNewtMat(ubeta, tauEVarE10, cjInverse, engineTask._oneOverJ);
                        residualTask = new HGLMTask.ResidualLLHTask(HGLM.this._job, parms, this._dinfo, ubeta, beta, engineTask);
                        residualTask.doAll(this._dinfo._adaptedFrame);
                        tauEVarE10 = residualTask._residualSquare * engineTask._oneOverN;
                        if (!HGLMUtils.checkPositiveG(engineTask._numLevel2Units, tMat)) {
                            // NOTE(review): despite the wording of this message, the loop is not exited here;
                            // termination is still decided by progress().
                            Log.info("HGLM model building is stopped due to matrix G in section II.V of the doc is no longer PSD");
                        }
                    } while (progress(beta, ubeta, tMat, tauEVarE10, scoringHistory, scoringHistory2, hGLMModel, residualTask));
                }
            } catch (Exception e) {
                if (iteration <= 1) {
                    throw new RuntimeException(e);
                }
                // Deliberate best effort: keep the last committed state, but do not hide the failure.
                Log.warn("HGLM EM iterations stopped at iteration " + iteration + " because of: " + e);
            }
        }

        /**
         * Called after each EM step with the newly computed estimates.
         *
         * <p>First increments the iteration counter. Returns {@code false} (stop) when
         * {@code _max_iterations} is reached, a stop was requested, or all of the following changed by
         * at most {@code _em_epsilon} relative to the current state:
         * <ul>
         *   <li>the fixed coefficients, the T matrix and the random effects (max-magnitude ratio);</li>
         *   <li>the noise variance (relative change).</li>
         * </ul>
         * Otherwise the estimates are committed to the state and the method returns {@code true}. When
         * committing, the model is scored on this iteration if {@code _score_each_iteration} is set or
         * the iteration is a multiple of {@code _score_iteration_interval}. On other iterations only a
         * cheaply computed log-likelihood is recorded.
         */
        public boolean progress(double[] dArr, double[][] dArr2, double[][] dArr3, double d, ScoringHistory scoringHistory, ScoringHistory scoringHistory2, HGLMModel hGLMModel, HGLMTask.ResidualLLHTask residualLLHTask) {
            ComputationStateHGLM state = HGLM.this._state;
            HGLMModel.HGLMParameters parms = HGLM.this._parms;
            state._iter++;
            if (state._iter >= parms._max_iterations || HGLM.this.stop_requested()) {
                // NOTE(review): the estimates computed in this final step are not committed to the state.
                return false;
            }
            double eps = parms._em_epsilon;
            double[] betaDiff = new double[dArr.length];
            ArrayUtils.minus(betaDiff, dArr, state.getBeta());
            double betaRelChange = ArrayUtils.maxMag(betaDiff) / ArrayUtils.maxMag(dArr);
            double[][] tDiff = new double[dArr3.length][dArr3[0].length];
            ArrayUtils.minus(tDiff, dArr3, state.getT());
            double tRelChange = ArrayUtils.maxMag(tDiff) / ArrayUtils.maxMag(dArr3);
            double[][] ubetaDiff = new double[dArr2.length][dArr2[0].length];
            ArrayUtils.minus(ubetaDiff, dArr2, state.getUbeta());
            boolean converged = betaRelChange <= eps
                    && tRelChange <= eps
                    && ArrayUtils.maxMag(ubetaDiff) / ArrayUtils.maxMag(dArr2) <= eps
                    && Math.abs(d - state.getTauEVarE10()) / d <= eps;
            if (converged) {
                return false;
            }
            state.setBeta(dArr);
            state.setUbeta(dArr2);
            state.setT(dArr3);
            state.setTauEVarE10(d);
            // Score every _score_iteration_interval-th iteration (init() guarantees the interval is >= 1).
            boolean scoreThisIteration = parms._score_each_iteration
                    || state._iter % parms._score_iteration_interval == 0;
            if (scoreThisIteration) {
                hGLMModel._output.setModelOutputFields(state);
                scoreAndUpdateModel(hGLMModel, true, scoringHistory);
                if (parms.valid() != null) {
                    scoreAndUpdateModel(hGLMModel, false, scoringHistory2);
                }
            } else {
                double llg = MetricBuilderHGLM.calHGLMLlg(state._nobs, dArr3, d, hGLMModel._output._arjtarj,
                        residualLLHTask._sse_fixed, residualLLHTask._yMinusXTimesZ);
                scoringHistory.addIterationScore(state._iter, llg, d);
            }
            return true;
        }
    }

    /**
     * Append-only record of scoring events. Each entry stores the iteration, the wall-clock time,
     * the log-likelihood and the noise variance; {@link #to2dTable()} renders them as a table.
     */
    public static class ScoringHistory {
        private final ArrayList<Integer> _scoringIters = new ArrayList<>();
        private final ArrayList<Long> _scoringTimes = new ArrayList<>();
        private final ArrayList<Double> _logLikelihood = new ArrayList<>();
        private final ArrayList<Double> _tauEVar = new ArrayList<>();

        ScoringHistory() {
        }

        /** Returns the live list of recorded iteration numbers. */
        public ArrayList<Integer> getScoringIters() {
            return this._scoringIters;
        }

        /** Records one scoring event, timestamped with the current wall-clock time. */
        public void addIterationScore(int i, double d, double d2) {
            this._scoringIters.add(Integer.valueOf(i));
            this._scoringTimes.add(Long.valueOf(System.currentTimeMillis()));
            this._logLikelihood.add(Double.valueOf(d));
            this._tauEVar.add(Double.valueOf(d2));
        }

        /**
         * Renders the history as a table with one row per recorded event and these columns:
         * timestamp, number_of_iterations, loglikelihood, noise_variance.
         */
        public TwoDimTable to2dTable() {
            int rows = this._scoringIters.size();
            TwoDimTable table = new TwoDimTable("Scoring History", "", new String[rows],
                    new String[]{"timestamp", "number_of_iterations", "loglikelihood", "noise_variance"},
                    new String[]{"string", "int", "double", "double"},
                    new String[]{"%s", "%d", "%.5f", "%.5f"}, "");
            for (int row = 0; row < rows; row++) {
                table.set(row, 0, HGLM.DATE_TIME_FORMATTER.print(this._scoringTimes.get(row).longValue()));
                table.set(row, 1, this._scoringIters.get(row));
                table.set(row, 2, this._logLikelihood.get(row));
                table.set(row, 3, this._tauEVar.get(row));
            }
            return table;
        }
    }

    @Override // hex.ModelBuilder
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression};
    }

    @Override // hex.ModelBuilder
    public boolean isSupervised() {
        return true;
    }

    @Override // hex.ModelBuilder
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    @Override // hex.ModelBuilder
    public boolean havePojo() {
        return false;
    }

    @Override // hex.ModelBuilder
    public boolean haveMojo() {
        return false;
    }

    public HGLM(boolean z) {
        super(new HGLMModel.HGLMParameters(), z);
    }

    protected HGLM(HGLMModel.HGLMParameters hGLMParameters) {
        super(hGLMParameters);
        init(false);
    }

    public HGLM(HGLMModel.HGLMParameters hGLMParameters, Key<HGLMModel> key) {
        super(hGLMParameters, key);
        init(false);
    }

    @Override // hex.ModelBuilder
    protected ModelBuilder<HGLMModel, HGLMModel.HGLMParameters, HGLMModel.HGLMModelOutput>.Driver trainModelImpl() {
        return new HGLMDriver();
    }

    /**
     * Validates HGLM parameters.
     *
     * <p>Frame-independent checks run first; if they or {@code super.init} report errors, an
     * {@link H2OModelBuilderIllegalArgumentException} is thrown immediately. When {@code expensive}
     * is true, the method also:
     * <ul>
     *   <li>maps {@code max_iterations == -1} to 1000;</li>
     *   <li>checks {@code group_column} and {@code random_columns} against the training frame;</li>
     *   <li>forces {@code max_iterations} to 0 when {@code gen_syn_data} is set.</li>
     * </ul>
     * Errors recorded in this second phase are not thrown here; callers check {@code error_count()}.
     */
    @Override // hex.ModelBuilder
    public void init(boolean expensive) {
        HGLMModel.HGLMParameters parms = this._parms;
        if (parms._nfolds > 0 || parms._fold_column != null) {
            error("nfolds or _fold_coumn", " cross validation is not supported in HGLM right now.");
        }
        if (parms._family != null && !GLMModel.GLMParameters.Family.gaussian.equals(parms._family)) {
            error("family", " only Gaussian families are supported now");
        }
        if (parms._method != null && !HGLMModel.HGLMParameters.Method.EM.equals(parms._method)) {
            error("method", " only EM (expectation maximization) is supported for now.");
        }
        if (GLMModel.GLMParameters.MissingValuesHandling.PlugValues == parms._missing_values_handling && parms._plug_values == null) {
            error("PlugValues", " if specified, must provide a frame with plug values in plug_values.");
        }
        if (parms._tau_u_var_init < 0.0) {
            error("tau_u_var_init", "if set, must > 0.0.");
        }
        if (parms._tau_e_var_init < 0.0) {
            error("tau_e_var_init", "if set, must > 0.0.");
        }
        if (parms._seed == 0) {
            error("seed", "cannot be set to zero.");
        }
        if (parms._em_epsilon < 0.0) {
            error("em_epsilon", "if specified, must >= 0.0.");
        }
        if (parms._score_iteration_interval <= 0) {
            error("score_iteration_interval", "if specified must be >= 1.");
        }
        super.init(expensive);
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        if (!expensive) {
            return;
        }
        if (parms._max_iterations == 0) {
            warn("max_iterations", "for HGLM, must be >= 1 (or -1 for unlimited or default setting) to obtain proper model.  Setting it to be 0 will only return the correct coefficient names and an empty model.");
            warn("_max_iterations", H2O.technote(2, "for HGLM, if specified, must be >= 1 or == -1."));
        }
        if (parms._max_iterations == -1) {
            parms._max_iterations = 1000;
        }
        Frame train = train();
        List<String> trainNames = Arrays.asList(train.names());
        if (parms._group_column == null) {
            error("group_column", " column used to generate level 2 units is missing");
        } else if (!trainNames.contains(parms._group_column)) {
            error("group_column", " is not found in the training frame.");
        } else if (!train.vec(parms._group_column).isCategorical()) {
            error("group_column", " should be a categorical column.");
        }
        if (parms._random_columns == null && !parms._random_intercept) {
            error("random_columns", " should not be null if random_intercept is false.  You must specify predictors in random_columns or set random_intercept to true.");
        }
        if (parms._random_columns != null && !Arrays.stream(parms._random_columns).allMatch(trainNames::contains)) {
            error("random_columns", " can only contain columns in the training frame.");
        }
        if (parms._gen_syn_data) {
            parms._max_iterations = 0;
            if (parms._tau_e_var_init <= 0.0) {
                error("tau_e_var_init", "If gen_syn_data is true, tau_e_var_init must be > 0.");
            }
        }
    }
}
