package hex.hglm;

import Jama.Matrix;
import hex.DataInfo;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsRegression;
import hex.ModelMetricsRegressionHGLM;
import hex.ModelMetricsSupervised;
import hex.glm.GLMModel;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.fvec.Frame;
import water.util.ArrayUtils;

/**
 * Metric builder for HGLM models.
 *
 * <p>Standard regression metrics are delegated to an inner {@link ModelMetricsRegression.MetricBuilderRegression}.
 * In addition, this builder accumulates the quantities needed for the HGLM log-likelihood:
 * <ul>
 *   <li>the sum of squared fixed-effect residuals {@code (y - beta.x - offset)^2};</li>
 *   <li>per-level-2 sums of {@code residual * z} (written into a caller-supplied array);</li>
 *   <li>the SSE of the full prediction.</li>
 * </ul>
 */
public class MetricBuilderHGLM extends ModelMetricsSupervised.MetricBuilderSupervised<MetricBuilderHGLM> {
    /** Natural log of 2*pi. */
    public static final double LOG_2PI = Math.log(2 * Math.PI);

    /** Delegate that computes the generic regression metrics (MSE, MAE, RMSLE, deviance, AIC). */
    ModelMetrics.MetricBuilder _metricBuilder;
    final boolean _intercept;
    final boolean _random_intercept;
    final boolean _computeMetrics;
    /** Fixed-effect coefficients taken from the model output at construction time. */
    public double[] _beta;
    /** Random-effect coefficients taken from the model output at construction time. */
    public double[][] _ubeta;
    /** T matrix taken from the model output at construction time. */
    public double[][] _tmat;
    /** Running sum of (y - beta.x - offset)^2 over non-zero-weight rows. */
    public double _yMinusFixPredSquare;
    /** Running sum of (y - prediction)^2 over non-zero-weight rows. */
    public double _sse;
    /** Number of non-zero-weight rows seen. */
    public int _nobs;

    /**
     * @param domain          response domain (may be null for regression)
     * @param computeMetrics  stored in {@code _computeMetrics}
     * @param intercept       stored in {@code _intercept}
     * @param randomIntercept stored in {@code _random_intercept}
     * @param output          model output supplying beta, ubeta and tmat
     */
    public MetricBuilderHGLM(String[] domain, boolean computeMetrics, boolean intercept, boolean randomIntercept,
                             HGLMModel.HGLMModelOutput output) {
        super(domain == null ? 0 : domain.length, domain);
        this._intercept = intercept;
        this._computeMetrics = computeMetrics;
        this._random_intercept = randomIntercept;
        this._metricBuilder = new ModelMetricsRegression.MetricBuilderRegression();
        this._beta = output._beta;
        this._ubeta = output._ubeta;
        this._tmat = output._tmat;
    }

    /**
     * Accumulates one scored row.
     *
     * <p>Rows with zero weight are ignored. Other rows are forwarded to the regression delegate and
     * folded into the HGLM accumulators.
     *
     * @param ds            prediction array; {@code ds[0]} is used as the predicted value
     * @param yact          actual response; {@code yact[0]} is used
     * @param weight        row weight
     * @param offset        row offset
     * @param xji           fixed-effect input row
     * @param zji           random-effect input row
     * @param yMinusXTimesZ per-level accumulator; the row at {@code level2Index} is updated
     * @param level2Index   level-2 group index of this row
     * @param m             model being scored
     * @return {@code ds}, unchanged
     */
    public double[] perRow(double[] ds, float[] yact, double weight, double offset, double[] xji, double[] zji,
                           double[][] yMinusXTimesZ, int level2Index, Model m) {
        if (weight == 0) {
            return ds;
        }
        _metricBuilder.perRow(ds, yact, weight, offset, m);
        add2(yact[0], ds[0], xji, zji, yMinusXTimesZ, level2Index, offset);
        return ds;
    }

    /**
     * Folds a single row into the HGLM accumulators.
     *
     * <p>NOTE(review): the accumulation is unweighted (the row weight is not applied); confirm this is intended.
     */
    private void add2(double yresp, double predictedVal, double[] input, double[] randomInput,
                      double[][] yMinusXTimesZ, int level2Index, double offset) {
        double fixedResidual = yresp - ArrayUtils.innerProduct(_beta, input) - offset;
        _yMinusFixPredSquare += fixedResidual * fixedResidual;
        // NOTE(review): ArrayUtils.mult(double[], double) appears to scale randomInput in place - confirm callers
        // do not reuse it.
        ArrayUtils.add(yMinusXTimesZ[level2Index], ArrayUtils.mult(randomInput, fixedResidual));
        _nobs++;
        double residual = yresp - predictedVal;
        _sse += residual * residual;
    }

    /**
     * Merges another builder's partial results into this one.
     *
     * <p>Only the delegate and the HGLM-specific accumulators are merged here.
     */
    @Override // hex.ModelMetrics.MetricBuilder
    public void reduce(MetricBuilderHGLM other) {
        _metricBuilder.reduce(other._metricBuilder);
        _yMinusFixPredSquare += other._yMinusFixPredSquare;
        _sse += other._sse;
        _nobs += other._nobs;
    }

    /** No-op: HGLM rows must be scored through the extended {@code perRow} overload. */
    @Override // hex.ModelMetrics.MetricBuilder
    public double[] perRow(double[] ds, float[] yact, Model m) {
        return ds;
    }

    /**
     * Builds HGLM metrics for {@code frame} and registers them on {@code model}.
     *
     * <p>When {@code frame} is the training frame, the log-likelihood uses the training-time quantities stored
     * in the model output. Otherwise, a {@link HGLMTask.ComputationEngineTask} is run over the adapted frame
     * to obtain the number of observations and the per-level A^T A terms. In that case
     * {@code output._nobs_valid} is updated.
     *
     * @throws NullPointerException if {@code model} is null
     */
    @Override // hex.ModelMetrics.MetricBuilder
    public ModelMetrics makeModelMetrics(Model model, Frame frame, Frame adaptedFrame, Frame preds) {
        // The training-frame check below dereferences model unconditionally, so fail fast with a clear message.
        Objects.requireNonNull(model, "model");
        HGLMModel hglmModel = (HGLMModel) model;
        HGLMModel.HGLMModelOutput output = (HGLMModel.HGLMModelOutput) hglmModel._output;
        ModelMetricsRegression regressionMetrics =
                (ModelMetricsRegression) _metricBuilder.makeModelMetrics(hglmModel, frame, null, null);
        boolean forTraining = model._parms.train().getKey().equals(frame.getKey());
        double[][] tmat = output._tmat;

        ModelMetricsRegressionHGLM metrics;
        if (forTraining) {
            double llg = calHGLMLlg(regressionMetrics._nobs, tmat, output._tau_e_var, output._arjtarj,
                    _yMinusFixPredSquare, output._yMinusXTimesZ);
            metrics = newHGLMMetrics(model, frame, regressionMetrics, output, tmat, llg);
        } else {
            HGLMModel.HGLMParameters params = (HGLMModel.HGLMParameters) hglmModel._parms;
            DataInfo dataInfo = makeScoringDataInfo(params, frame, adaptedFrame);
            HGLMTask.ComputationEngineTask engineTask = new HGLMTask.ComputationEngineTask(null, params, dataInfo);
            engineTask.doAll(dataInfo._adaptedFrame);
            double llg = calHGLMLlg(engineTask._nobs, tmat, output._tau_e_var, engineTask._ArjTArj,
                    _yMinusFixPredSquare, output._yMinusXTimesZValid);
            metrics = newHGLMMetrics(model, frame, regressionMetrics, output, tmat, llg);
            output._nobs_valid = engineTask._nobs;
        }
        model.addModelMetrics(metrics);
        return metrics;
    }

    /** Assembles the HGLM metrics object from the regression metrics, model output and log-likelihood. */
    private ModelMetricsRegressionHGLM newHGLMMetrics(Model model, Frame frame,
                                                      ModelMetricsRegression regressionMetrics,
                                                      HGLMModel.HGLMModelOutput output, double[][] tmat,
                                                      double logLikelihood) {
        // NOTE(review): weightedSigma() reads this builder's inherited accumulators. The visible perRow forwards
        // rows only to _metricBuilder, so confirm those accumulators are populated (otherwise
        // _metricBuilder.weightedSigma() may be what is intended).
        return new ModelMetricsRegressionHGLM(model, frame, regressionMetrics._nobs, weightedSigma(), logLikelihood,
                _customMetric, output._iterations, output._beta, output._ubeta, tmat, output._tau_e_var,
                regressionMetrics._MSE, _yMinusFixPredSquare / regressionMetrics._nobs, regressionMetrics.mae(),
                regressionMetrics._root_mean_squared_log_error, regressionMetrics._mean_residual_deviance,
                regressionMetrics.aic());
    }

    /**
     * Builds a {@link DataInfo} over {@code adaptedFrame} for scoring a non-training frame.
     *
     * <p>Weights and offset columns are enabled only if they are configured and present in {@code frame}.
     */
    private static DataInfo makeScoringDataInfo(HGLMModel.HGLMParameters params, Frame frame, Frame adaptedFrame) {
        List<String> frameNames = Arrays.asList(frame.names());
        boolean skipMissing =
                params.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip;
        boolean imputeMissing =
                params.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation
                || params.missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
        boolean hasWeights = params._weights_column != null && frameNames.contains(params._weights_column);
        boolean hasOffset = params._offset_column != null && frameNames.contains(params._offset_column);
        return new DataInfo(adaptedFrame, (Frame) null, 1, params._use_all_factor_levels,
                DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, skipMissing, imputeMissing,
                params.makeImputer(), false, hasWeights, hasOffset, false, (Model.InteractionSpec) null);
    }

    /**
     * Computes the HGLM log-likelihood.
     *
     * <p>It returns {@code -0.5 * llg}, where {@code llg} is built as follows:
     * <ul>
     *   <li>it starts at {@code nobs*log(2pi) + yMinusFixPredSquare/varResidual};</li>
     *   <li>for each level j it adds {@code log(varResidual * det(T^-1 + ZjTZj/varResidual) * det(T))};</li>
     *   <li>for each level j it subtracts {@code r_j (T^-1 + ZjTZj/varResidual)^-1 r_j^T / varResidual^2}, where
     *       {@code r_j = yMinusXTimesZ[j]}.</li>
     * </ul>
     *
     * @param nobs                number of observations
     * @param tmat                T matrix (must be square and invertible)
     * @param varResidual         residual variance
     * @param zjTTimesZj          per-level Zj^T Zj matrices; its length defines the number of levels
     * @param yMinusFixPredSquare sum of squared fixed-effect residuals
     * @param yMinusXTimesZ       per-level residual-times-Z row vectors
     * @return the log-likelihood
     */
    public static double calHGLMLlg(long nobs, double[][] tmat, double varResidual, double[][][] zjTTimesZj,
                                    double yMinusFixPredSquare, double[][] yMinusXTimesZ) {
        // Reuse one Matrix for both inverse and determinant instead of rebuilding it.
        Matrix tMatrix = new Matrix(tmat);
        double[][] tmatInv = tMatrix.inverse().getArray();
        double tmatDeterminant = tMatrix.det();
        double oneOverVar = 1.0d / varResidual;
        double oneOverVarSquare = oneOverVar * oneOverVar;
        double llg = (nobs * LOG_2PI) + (oneOverVar * yMinusFixPredSquare);
        for (int level = 0; level < zjTTimesZj.length; level++) {
            Matrix invTPlusZjTZ = new Matrix(calInvTPZjTZ(tmatInv, zjTTimesZj[level], oneOverVar));
            // NOTE(review): by the matrix determinant lemma, log|sigma^2 I + Zj T Zj^T| contains n_j*log(sigma^2).
            // Here log(varResidual) is added once per level - confirm against the reference derivation.
            llg += Math.log(varResidual * invTPlusZjTZ.det() * tmatDeterminant);
            Matrix residualTimesZ = new Matrix(new double[][] {yMinusXTimesZ[level]});
            llg -= oneOverVarSquare
                    * residualTimesZ.times(invTPlusZjTZ.inverse().times(residualTimesZ.transpose())).get(0, 0);
        }
        return -0.5d * llg;
    }

    /**
     * Returns {@code tmatInv + scale * zjTZj} as a new array.
     *
     * @param tmatInv inverse of T
     * @param zjTZj   Zj^T Zj for one level (same dimensions as {@code tmatInv})
     * @param scale   multiplier applied to {@code zjTZj}
     */
    public static double[][] calInvTPZjTZ(double[][] tmatInv, double[][] zjTZj, double scale) {
        return new Matrix(tmatInv).plus(new Matrix(zjTZj).times(scale)).getArray();
    }
}
