package hex.hglm;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.hglm.HGLMTask;
import java.io.Serializable;
import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.AutoBuffer;
import water.Futures;
import water.Job;
import water.Key;
import water.Keyed;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/hglm/HGLMModel.class */
public class HGLMModel extends Model<HGLMModel, HGLMParameters, HGLMModelOutput> {

    /* loaded from: input_file:hex/hglm/HGLMModel$HGLMModelOutput.class */
    /**
     * Output container of an HGLM model.
     *
     * <p>The instance is created from the builder and its {@link DataInfo}; the remaining
     * fields are filled in afterwards by {@link #setModelOutputFixMatVec},
     * {@link #setModelOutput} and {@link #setModelOutputFields}. Field declarations are kept
     * in their original order.
     */
    public static class HGLMModelOutput extends Model.Output {
        /** Data layout passed to the constructor; its adapted frame also supplies {@code _domains}. */
        public DataInfo _dinfo;
        // Both families are copied from the builder's parameters in the constructor.
        final GLMModel.GLMParameters.Family _family;
        final GLMModel.GLMParameters.Family _random_family;
        // Coefficient / group names: set by setModelOutputFields.
        public String[] _fixed_coefficient_names;
        public String[] _random_coefficient_names;
        public String[] _group_column_names;
        public long _training_time_ms;
        // Coefficients and T matrix: set by setModelOutputFields.
        public double[] _beta;
        public double[][] _ubeta;
        public double[][] _tmat;
        double _tauUVar;
        public double _tau_e_var;
        // Copies of the computation task's arrays (see the two task-based setters below).
        public double[][] _afjtyj;
        public double[][] _arjtyj;
        public double[][][] _afjtafj;
        public double[][][] _arjtarj;
        public double[][][] _afjtarj;
        public double[][] _yMinusXTimesZ;
        public double[][] _yMinusXTimesZValid;
        public int _num_fixed_coeffs;
        public int _num_random_coeffs;
        // Index bookkeeping taken over from the computation task in setModelOutput.
        int[] _randomCatIndices;
        int[] _randomNumIndices;
        int[] _randomCatArrayStartIndices;
        int _predStartIndexRandom;
        boolean _randomSlopeToo;
        int[] _fixedCatIndices;
        int _numLevel2Units;
        int _level2UnitIndex;
        int _predStartIndexFixed;
        public double[] _icc;
        public double _log_likelihood;
        public double _log_likelihood_valid;
        public int _iterations;
        public int _nobs;
        public int _nobs_valid;
        public double _yMinusFixPredSquare;
        public double _yMinusFixPredSquareValid;
        public TwoDimTable _scoring_history_valid;

        /**
         * Stores copies of the task's {@code _AfjTYj}, {@code _ArjTYj}, {@code _AfjTAfj} and
         * {@code _AfjTArj} arrays together with its observation count.
         *
         * @param computationEngineTask finished task to read from
         */
        public void setModelOutputFixMatVec(HGLMTask.ComputationEngineTask computationEngineTask) {
            _afjtyj = ArrayUtils.copy2DArray(computationEngineTask._AfjTYj);
            _arjtyj = ArrayUtils.copy2DArray(computationEngineTask._ArjTYj);
            _afjtafj = HGLMUtils.copy3DArray(computationEngineTask._AfjTAfj);
            _afjtarj = HGLMUtils.copy3DArray(computationEngineTask._AfjTArj);
            _nobs = computationEngineTask._nobs;
        }

        /**
         * Takes over the task's index bookkeeping (the index arrays are shared, not copied),
         * copies {@code _ArjTArj}, and resets the log-likelihood to negative infinity.
         *
         * @param computationEngineTask finished task to read from
         */
        public void setModelOutput(HGLMTask.ComputationEngineTask computationEngineTask) {
            _randomCatIndices = computationEngineTask._randomCatIndices;
            _randomNumIndices = computationEngineTask._randomNumIndices;
            _randomCatArrayStartIndices = computationEngineTask._randomCatArrayStartIndices;
            _predStartIndexRandom = computationEngineTask._predStartIndexRandom;
            // False only when there is exactly one random coefficient and it is the random intercept.
            _randomSlopeToo = computationEngineTask._numRandomCoeffs != 1
                    || !computationEngineTask._parms._random_intercept;
            _fixedCatIndices = computationEngineTask._fixedCatIndices;
            _predStartIndexFixed = computationEngineTask._predStartIndexFixed;
            _arjtarj = HGLMUtils.copy3DArray(computationEngineTask._ArjTArj);
            _log_likelihood = Double.NEGATIVE_INFINITY;
        }

        /**
         * Creates the output for the given builder.
         *
         * @param hglm     builder whose parameters provide the two family settings
         * @param dataInfo data layout; its adapted frame is passed to the superclass and
         *                 provides the domains
         */
        public HGLMModelOutput(HGLM hglm, DataInfo dataInfo) {
            super(hglm, dataInfo._adaptedFrame);
            HGLMParameters builderParms = (HGLMParameters) hglm._parms;
            _dinfo = dataInfo;
            _domains = dataInfo._adaptedFrame.domains();
            _family = builderParms._family;
            _random_family = builderParms._random_family;
        }

        /**
         * Copies names, variance estimates, coefficients and counters from the computation
         * state into this output.
         *
         * @param computationStateHGLM state to read from
         */
        public void setModelOutputFields(ComputationStateHGLM computationStateHGLM) {
            _fixed_coefficient_names = computationStateHGLM.getFixedCofficientNames();
            _random_coefficient_names = computationStateHGLM.getRandomCoefficientNames();
            _group_column_names = computationStateHGLM.getGroupColumnNames();
            _tauUVar = computationStateHGLM.getTauUVar();
            _tau_e_var = computationStateHGLM.getTauEVarE10();
            _tmat = computationStateHGLM.getT();
            _num_fixed_coeffs = computationStateHGLM.getNumFixedCoeffs();
            _num_random_coeffs = computationStateHGLM.getNumRandomCoeffs();
            _numLevel2Units = computationStateHGLM.getNumLevel2Units();
            _level2UnitIndex = computationStateHGLM.getLevel2UnitIndex();
            _nobs = computationStateHGLM._nobs;
            _beta = computationStateHGLM.getBeta();
            _ubeta = computationStateHGLM.getUbeta();
            // Overrides the value assigned above with the width of the first ubeta row.
            _num_random_coeffs = _ubeta[0].length;
            _iterations = computationStateHGLM._iter;
        }

        /** @return always 1 */
        @Override // hex.Model.Output
        public int nclasses() {
            return 1;
        }

        /** @return always {@link ModelCategory#Regression} */
        @Override // hex.Model.Output
        public ModelCategory getModelCategory() {
            return ModelCategory.Regression;
        }
    }

    /* loaded from: input_file:hex/hglm/HGLMModel$HGLMParameters.class */
    /**
     * User-settable parameters of the HGLM ("Hierarchical Generalized Linear Model") algorithm.
     *
     * <p>Instance field declarations are kept in their original order.
     */
    public static class HGLMParameters extends Model.Parameters {
        public double[] _initial_fixed_effects;
        public Key _initial_random_effects;
        public Key _initial_t_matrix;
        public String[] _random_columns;
        public String _group_column;
        public long _seed = -1;
        public int _max_iterations = -1;
        // Plain zero literals: the decompiler had substituted an unrelated commons-math3
        // constant (CMAESOptimizer.DEFAULT_STOPFITNESS, value 0) for the literal.
        public double _tau_u_var_init = 0.0d;
        public double _tau_e_var_init = 0.0d;
        public GLMModel.GLMParameters.Family _random_family = GLMModel.GLMParameters.Family.gaussian;
        public double _em_epsilon = 0.001d;
        public boolean _random_intercept = true;
        /**
         * Either a GLM or a DeepLearning {@code MissingValuesHandling} value; use
         * {@link #missingValuesHandling()} to read it as the GLM enum.
         */
        public Serializable _missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
        public Key<Frame> _plug_values = null;
        public boolean _use_all_factor_levels = false;
        public boolean _showFixedMatVecs = false;
        public int _score_iteration_interval = 5;
        public boolean _score_each_iteration = false;
        public boolean _gen_syn_data = false;
        public GLMModel.GLMParameters.Family _family = GLMModel.GLMParameters.Family.gaussian;
        public Method _method = Method.EM;

        /** Available fitting methods; only {@code EM} is defined. */
        public enum Method {
            EM
        }

        @Override // hex.Model.Parameters
        public String algoName() {
            return "HGLM";
        }

        @Override // hex.Model.Parameters
        public String fullName() {
            return "Hierarchical Generalized Linear Model";
        }

        @Override // hex.Model.Parameters
        public String javaName() {
            return HGLMModel.class.getName();
        }

        @Override // hex.Model.Parameters
        public long progressUnits() {
            return 1L;
        }

        /**
         * Returns {@link #_missing_values_handling} as the GLM enum, translating the
         * DeepLearning enum values {@code MeanImputation} and {@code Skip} when necessary.
         *
         * @return the GLM missing-values-handling mode
         * @throws IllegalStateException if a DeepLearning value other than
         *                               {@code MeanImputation} or {@code Skip} is configured
         */
        public GLMModel.GLMParameters.MissingValuesHandling missingValuesHandling() {
            if (_missing_values_handling instanceof GLMModel.GLMParameters.MissingValuesHandling) {
                return (GLMModel.GLMParameters.MissingValuesHandling) _missing_values_handling;
            }
            // A real assert replaces the decompiled "$assertionsDisabled" check; the compiler
            // regenerates the equivalent synthetic guard.
            assert _missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling;
            switch ((DeepLearningModel.DeepLearningParameters.MissingValuesHandling) _missing_values_handling) {
                case MeanImputation:
                    return GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
                case Skip:
                    return GLMModel.GLMParameters.MissingValuesHandling.Skip;
                default:
                    throw new IllegalStateException("Unsupported missing values handling value: " + _missing_values_handling);
            }
        }

        /**
         * @return {@code true} when the handling mode is {@code MeanImputation} or
         *         {@code PlugValues}
         */
        public boolean imputeMissing() {
            // Resolve the mode once instead of once per comparison.
            GLMModel.GLMParameters.MissingValuesHandling handling = missingValuesHandling();
            return handling == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation
                    || handling == GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
        }

        /**
         * Builds the imputer matching the handling mode: a plug-values imputer for
         * {@code PlugValues}, a mean imputer for every other mode.
         *
         * @return the imputer to use
         * @throws IllegalStateException if {@code PlugValues} is selected but the plug-values
         *                               key is unset or does not resolve to a frame
         */
        public DataInfo.Imputer makeImputer() {
            if (missingValuesHandling() != GLMModel.GLMParameters.MissingValuesHandling.PlugValues) {
                return new DataInfo.MeanImputer();
            }
            // Fetch the frame once so the value that was null-checked is the value that is used.
            Frame plugValues = _plug_values == null ? null : _plug_values.get();
            if (plugValues == null) {
                throw new IllegalStateException("Plug values frame needs to be specified when Missing Value Handling = PlugValues.");
            }
            return new GLM.PlugValuesImputer(plugValues);
        }
    }

    /**
     * Creates the model by forwarding all three arguments to the {@code Model} constructor.
     *
     * @param key             key of the model
     * @param hGLMParameters  parameters of the model
     * @param hGLMModelOutput output object of the model
     */
    public HGLMModel(Key<HGLMModel> key, HGLMParameters hGLMParameters, HGLMModelOutput hGLMModelOutput) {
        super(key, hGLMParameters, hGLMModelOutput);
    }

    /**
     * Creates the HGLM metric builder for this model.
     *
     * @param strArr passed through as the builder's first argument
     * @return a new {@code MetricBuilderHGLM} configured with this model's random-intercept
     *         setting and output
     */
    @Override // hex.Model
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        HGLMParameters parms = (HGLMParameters) this._parms;
        HGLMModelOutput output = (HGLMModelOutput) this._output;
        boolean randomIntercept = parms._random_intercept;
        return new MetricBuilderHGLM(strArr, true, true, randomIntercept, output);
    }

    /**
     * @return a fresh one-element array holding the single prediction column name
     *         {@code "predict"}
     */
    @Override // hex.Model
    public String[] makeScoringNames() {
        String[] names = new String[1];
        names[0] = "predict";
        return names;
    }

    /**
     * Row-wise scoring is not supported by this model.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // hex.Model
    protected double[] score0(double[] dArr, double[] dArr2) {
        final String message = "HGLMModel.score0 should never be called";
        throw new UnsupportedOperationException(message);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.String[], java.lang.String[][]] */
    @Override // hex.Model
    public Model<HGLMModel, HGLMParameters, HGLMModelOutput>.PredictScoreResult predictScoreImpl(Frame frame, Frame frame2, String str, Job job, boolean z, CFuncRef cFuncRef) {
        String[] makeScoringNames = makeScoringNames();
        ?? r0 = new String[makeScoringNames.length];
        boolean equals = ((HGLMParameters) this._parms).train().getKey().equals(frame.getKey());
        HGLMScore makeScoringTask = makeScoringTask(frame2, true, job, z && !((HGLMParameters) this._parms)._gen_syn_data);
        makeScoringTask.doAll(makeScoringNames.length, (byte) 3, makeScoringTask._dinfo._adaptedFrame);
        MetricBuilderHGLM metricBuilderHGLM = null;
        Frame frame3 = null;
        if (makeScoringTask._computeMetrics) {
            metricBuilderHGLM = makeScoringTask._mb;
            if (equals) {
                ((HGLMModelOutput) this._output)._yMinusXTimesZ = makeScoringTask._yMinusXTimesZ;
                ((HGLMModelOutput) this._output)._yMinusFixPredSquare = metricBuilderHGLM._yMinusFixPredSquare;
            } else {
                ((HGLMModelOutput) this._output)._yMinusXTimesZValid = makeScoringTask._yMinusXTimesZ;
                ((HGLMModelOutput) this._output)._yMinusFixPredSquareValid = metricBuilderHGLM._yMinusFixPredSquare;
            }
            frame3 = makeScoringTask.outputFrame();
        }
        r0[0] = makeScoringTask._predDomains;
        return new Model.PredictScoreResult(metricBuilderHGLM, frame3, makeScoringTask.outputFrame(Key.make(str), makeScoringNames, r0));
    }

    /**
     * Builds the {@code HGLMScore} task for the given frame.
     *
     * <p>If the frame contains the response column and that column is bad, scoring works on a
     * shallow copy of the frame with that column removed. The domain argument is {@code null}
     * whenever the output reports at most one class.
     *
     * @param frame frame to score
     * @param z     passed to the task as its last constructor argument
     * @param job   job passed to the task
     * @param z2    passed to the task as its second-to-last constructor argument; also gates
     *              which domain is chosen when the output has more than one class
     * @return the configured scoring task
     */
    private HGLMScore makeScoringTask(Frame frame, boolean z, Job job, boolean z2) {
        HGLMModelOutput output = (HGLMModelOutput) this._output;
        int responseIdx = frame.find(output.responseName());
        if (responseIdx > -1 && frame.vec(responseIdx).isBad()) {
            // Drop the bad response from a shallow copy so the caller's frame is untouched.
            frame = new Frame(frame.names(), frame.vecs());
            frame.remove(responseIdx);
        }
        DataInfo scoringInfo = output._dinfo.scoringInfo(output._names, frame);
        String[] domain;
        if (output.nclasses() <= 1) {
            domain = null;
        } else {
            boolean responseUsable = z2
                    && frame.vec(output.responseName()) != null
                    && !frame.vec(output.responseName()).isBad();
            if (responseUsable) {
                domain = frame.lastVec().domain();
            } else {
                domain = output._domains[output._domains.length - 1];
            }
        }
        return new HGLMScore(job, this, scoringInfo, domain, z2, z);
    }

    /**
     * Delegates removal to the superclass and hands back the caller's {@code Futures}.
     * (The decompiler reported this method as originally {@code protected}.)
     *
     * @param futures futures passed to the superclass
     * @param z       flag forwarded unchanged to the superclass
     * @return the same {@code futures} instance that was passed in
     */
    @Override // hex.Model, water.Keyed
    public Futures remove_impl(Futures futures, boolean z) {
        // The superclass result is intentionally not used; the argument itself is returned.
        super.remove_impl(futures, z);
        final Futures pending = futures;
        return pending;
    }

    /**
     * Pure delegation to the superclass implementation.
     * (The decompiler reported this method as originally {@code protected}.)
     *
     * @param autoBuffer buffer to write to
     * @return whatever the superclass returns
     */
    @Override // hex.Model, water.Keyed
    public AutoBuffer writeAll_impl(AutoBuffer autoBuffer) {
        AutoBuffer written = super.writeAll_impl(autoBuffer);
        return written;
    }

    /**
     * Pure delegation to the superclass implementation.
     * (The decompiler reported this method as originally {@code protected}.)
     *
     * @param autoBuffer buffer to read from
     * @param futures    futures forwarded to the superclass
     * @return whatever the superclass returns
     */
    @Override // hex.Model, water.Keyed
    public Keyed readAll_impl(AutoBuffer autoBuffer, Futures futures) {
        Keyed restored = super.readAll_impl(autoBuffer, futures);
        return restored;
    }

    /**
     * Renders the superclass description followed by the log-likelihood, fixed-effect
     * coefficients, the diagonal of {@code _tmat}, the residual variance field, ICC,
     * iteration count and the per-row contents of {@code _ubeta}.
     *
     * <p>Unlike a plain field dump, this never throws when the output has not been populated
     * yet: a {@code null} {@code _ubeta} yields no per-row sections, and diagonal entries that
     * do not exist in {@code _tmat} are skipped. For a fully populated output whose
     * {@code _tmat} has at least {@code _ubeta.length} rows and columns the text is unchanged.
     * The log-likelihood and the fixed-effect coefficients appear twice, as before, so the
     * rendered format stays stable.
     *
     * @return the textual summary
     */
    @Override // hex.Model
    public String toString() {
        HGLMModelOutput output = (HGLMModelOutput) this._output;
        double[][] ubeta = output._ubeta;
        double[][] tmat = output._tmat;
        // _ubeta has no initializer, so it is null until the output fields have been set.
        int numRows = ubeta == null ? 0 : ubeta.length;

        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" loglikelihood: ").append(output._log_likelihood);
        sb.append(" fixed effect coefficients: ").append(Arrays.toString(output._beta));
        for (int i = 0; i < numRows; i++) {
            // The loop bound comes from _ubeta, so guard against a missing or smaller _tmat.
            boolean hasDiagonal = tmat != null && i < tmat.length && tmat[i] != null && i < tmat[i].length;
            if (!hasDiagonal) {
                continue;
            }
            sb.append(" standard error of random effects for level 2 index ").append(i)
                    .append(": ").append(tmat[i][i]);
        }
        sb.append(" standard error of residual error: ").append(output._tau_e_var);
        sb.append(" ICC: ").append(Arrays.toString(output._icc));
        sb.append(" loglikelihood: ").append(output._log_likelihood);
        sb.append(" iterations taken to build model: ").append(output._iterations);
        sb.append(" coefficients for fixed effect: ").append(Arrays.toString(output._beta));
        for (int i = 0; i < numRows; i++) {
            sb.append(" coefficients for random effect for level 2 index: ").append(i)
                    .append(": ").append(Arrays.toString(ubeta[i]));
        }
        return sb.toString();
    }
}
