package hex.hglm;

import Jama.Matrix;
import hex.DataInfo;
import hex.hglm.HGLMModel;
import hex.hglm.HGLMTask;
import java.util.Random;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Job;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Mutable working state of an HGLM fit: the current fixed coefficients, the per-level-2-unit
 * random coefficients, the T matrix and the variance scalars, together with the coefficient
 * and level-2 unit names taken from the {@link HGLMTask.ComputationEngineTask} that scanned
 * the training data.
 *
 * <p>The getters for {@code beta}, {@code ubeta} and {@code T} return the internal arrays
 * (no defensive copies); the matching setters copy values <em>into</em> those arrays.
 */
public class ComputationStateHGLM {
    /** Length of the fixed coefficient vector ({@code _beta.length} after initialisation). */
    final int _numFixedCoeffs;
    /** Number of random coefficients per level-2 unit ({@code _ubeta[0].length}). */
    final int _numRandomCoeffs;
    public final HGLMModel.HGLMParameters _parms;
    /** Iteration counter handed in by the caller at construction time. */
    int _iter;
    /** Fixed coefficients. */
    private double[] _beta;
    /** Random coefficients, indexed [level-2 unit][random coefficient]. */
    private double[][] _ubeta;
    /** Square matrix of size (number of random coefficients)^2; validated as symmetric when user supplied. */
    private double[][] _T;
    final DataInfo _dinfo;
    private final Job _job;
    /** Initialised from {@code tau_e_var_init}, or from |N(0,1)| when that parameter is not positive. */
    double _tauEVarE10 = 0.0;
    /** Not written anywhere in this class; presumably maintained by the fitting code -- TODO confirm. */
    double _tauEVarE17 = 0.0;
    /**
     * Scalar used to fill the diagonal of the initial T matrix ({@code tau_u_var_init} or |N(0,1)|).
     * When T is read from {@code initial_t_matrix} this holds {@code T[0][0]}, the same scale that
     * is used to draw the initial random effects.
     */
    private double _tauUVar;
    String[] _fixedCofficientNames;
    String[] _randomCoefficientNames;
    String[] _level2UnitNames;
    final int _numLevel2Unit;
    final int _level2UnitIndex;
    final int _nobs;

    /**
     * Immutable-by-convention snapshot of the quantities describing one state of the fit.
     * The arrays are stored by reference, not copied.
     */
    public static class ComputationStateSimple {
        public final double[] _beta;
        public final double[][] _ubeta;
        public final double[][] _tmat;
        public final double _tauEVar;

        /**
         * @param beta    fixed coefficients
         * @param ubeta   random coefficients
         * @param tmat    T matrix
         * @param tauEVar variance scalar stored as {@code _tauEVar}
         */
        public ComputationStateSimple(double[] beta, double[][] ubeta, double[][] tmat, double tauEVar) {
            _beta = beta;
            _ubeta = ubeta;
            _tmat = tmat;
            _tauEVar = tauEVar;
        }
    }

    /**
     * Builds the state and initialises coefficients, T and the variance scalars.
     *
     * @param job        job this state belongs to
     * @param parms      model parameters; {@code _seed} is overwritten when it equals -1
     * @param dinfo      data info of the training frame
     * @param engineTask completed task providing coefficient names, level-2 unit names and sizes
     * @param iter       starting iteration count
     * @throws IllegalArgumentException if a user supplied initial value is invalid
     *                                  (see {@link #initComputationStateHGLM})
     */
    public ComputationStateHGLM(Job job, HGLMModel.HGLMParameters parms, DataInfo dinfo,
                                HGLMTask.ComputationEngineTask engineTask, int iter) {
        _job = job;
        _parms = parms;
        _dinfo = dinfo;
        _iter = iter;
        _fixedCofficientNames = engineTask._fixedCoeffNames;
        _level2UnitNames = engineTask._level2UnitNames;
        _randomCoefficientNames = engineTask._randomCoeffNames;
        _level2UnitIndex = engineTask._level2UnitIndex;
        initComputationStateHGLM(engineTask);
        // Sizes are derived from the arrays that the initialisation just allocated.
        _numFixedCoeffs = _beta.length;
        _numRandomCoeffs = _ubeta[0].length;
        _numLevel2Unit = _ubeta.length;
        _nobs = engineTask._nobs;
    }

    /**
     * Initialises {@code _tauEVarE10}, {@code _T}, {@code _ubeta} and {@code _beta} from the user
     * supplied initial values, falling back to seeded random draws where none are given.
     *
     * <p>The steps run in a fixed order because they share one {@link Random}: changing the order
     * would change the values drawn for a given seed.
     *
     * @param engineTask task providing the dimensions of the random coefficient array
     * @throws IllegalArgumentException if {@code initial_t_matrix} is not symmetric, fails the
     *                                  Cholesky SPD check (only tested when {@code max_iterations > 0}),
     *                                  or {@code initial_fixed_effects} has the wrong length
     */
    void initComputationStateHGLM(HGLMTask.ComputationEngineTask engineTask) {
        final int numRandomCoeffs = _randomCoefficientNames.length;
        final int numFixedCoeffs = _fixedCofficientNames.length;

        if (_parms._seed == -1) {
            // -1 means "not set": pick a seed and store it so that it is reported with the parameters.
            _parms._seed = new Random().nextLong();
        }
        Log.info("Random seed: " + _parms._seed);
        final Random random = new Random(_parms._seed);

        _tauEVarE10 = initialVariance(_parms._tau_e_var_init, random);
        initTMatrix(numRandomCoeffs, random);
        initRandomEffects(engineTask, numRandomCoeffs, random);
        initFixedEffects(numFixedCoeffs);
    }

    /** Returns {@code userValue} when it is strictly positive, otherwise |N(0,1)| drawn from {@code random}. */
    private static double initialVariance(double userValue, Random random) {
        if (userValue > 0.0) {
            return userValue;
        }
        return Math.abs(random.nextGaussian());
    }

    /** Fills {@code _T} from {@code initial_t_matrix} (validated) or as a diagonal matrix of the tau-u variance. */
    private void initTMatrix(int numRandomCoeffs, Random random) {
        _T = new double[numRandomCoeffs][numRandomCoeffs];
        if (_parms._initial_t_matrix == null) {
            // Kept in its own field: writing this into _tauEVarE10 would discard tau_e_var_init.
            _tauUVar = initialVariance(_parms._tau_u_var_init, random);
            HGLMUtils.setDiagValues(_T, _tauUVar);
            return;
        }

        HGLMUtils.grabInitValuesFromFrame(_parms._initial_t_matrix, _T);
        if (!HGLMUtils.equal2DArrays(_T, ArrayUtils.transpose(_T), 1.0E-6d)) {
            throw new IllegalArgumentException("initial_t_matrix must be symmetric but is not!");
        }
        Matrix tMatrix = new Matrix(_T);
        // The Cholesky test is skipped when no iterations will run (max_iterations <= 0).
        if (_parms._max_iterations > 0 && !tMatrix.chol().isSPD()) {
            throw new IllegalArgumentException("initial_t_matrix must be positive semi definite but is not!");
        }
        _tauUVar = _T[0][0];
    }

    /** Fills {@code _ubeta} from {@code initial_random_effects} or with Gaussian draws scaled by sqrt(T[0][0]). */
    private void initRandomEffects(HGLMTask.ComputationEngineTask engineTask, int numRandomCoeffs, Random random) {
        _ubeta = new double[engineTask._numLevel2Units][engineTask._numRandomCoeffs];
        if (_parms._initial_random_effects != null) {
            HGLMUtils.grabInitValuesFromFrame(_parms._initial_random_effects, _ubeta);
        } else {
            ArrayUtils.gaussianVector(random, _ubeta, _level2UnitNames.length, numRandomCoeffs);
            ArrayUtils.mult(_ubeta, Math.sqrt(_T[0][0]));
        }
    }

    /** Sets {@code _beta} to a copy of {@code initial_fixed_effects}, or to zeros with the response mean in the last slot. */
    private void initFixedEffects(int numFixedCoeffs) {
        double[] userBeta = _parms._initial_fixed_effects;
        if (userBeta == null) {
            _beta = new double[numFixedCoeffs];
            // Last slot receives the response mean (presumably the intercept position -- TODO confirm).
            _beta[_beta.length - 1] = _parms.train().vec(_parms._response_column).mean();
            return;
        }
        if (userBeta.length != numFixedCoeffs) {
            throw new IllegalArgumentException("initial_fixed_effects must be an double[] array of size " + numFixedCoeffs);
        }
        // Copy: setBeta() writes into _beta and must not modify the caller's parameter array.
        _beta = userBeta.clone();
    }

    /** @return the internal fixed coefficient array (not a copy) */
    public double[] getBeta() {
        return _beta;
    }

    /** @return the internal random coefficient array (not a copy) */
    public double[][] getUbeta() {
        return _ubeta;
    }

    /**
     * @return the variance scalar used for the random coefficients at initialisation: the value
     *         placed on the diagonal of T, or {@code T[0][0]} when T came from {@code initial_t_matrix}
     */
    public double getTauUVar() {
        return _tauUVar;
    }

    /** @return the current value of {@code _tauEVarE10} */
    public double getTauEVarE10() {
        return _tauEVarE10;
    }

    /** @return names of the fixed coefficients as reported by the engine task */
    public String[] getFixedCofficientNames() {
        return _fixedCofficientNames;
    }

    /** @return names of the random coefficients as reported by the engine task */
    public String[] getRandomCoefficientNames() {
        return _randomCoefficientNames;
    }

    /** @return names of the level-2 units as reported by the engine task */
    public String[] getGroupColumnNames() {
        return _level2UnitNames;
    }

    /** @return the internal T matrix (not a copy) */
    public double[][] getT() {
        return _T;
    }

    public int getNumFixedCoeffs() {
        return _numFixedCoeffs;
    }

    public int getNumRandomCoeffs() {
        return _numRandomCoeffs;
    }

    public int getNumLevel2Units() {
        return _numLevel2Unit;
    }

    public int getLevel2UnitIndex() {
        return _level2UnitIndex;
    }

    /**
     * Copies {@code beta} into the internal fixed coefficient array.
     *
     * @param beta new values; must not be longer than the internal array
     */
    public void setBeta(double[] beta) {
        System.arraycopy(beta, 0, _beta, 0, beta.length);
    }

    /** Copies {@code ubeta} into the internal random coefficient array. */
    public void setUbeta(double[][] ubeta) {
        ArrayUtils.copy2DArray(ubeta, _ubeta);
    }

    /** Copies {@code tMat} into the internal T matrix. */
    public void setT(double[][] tMat) {
        ArrayUtils.copy2DArray(tMat, _T);
    }

    public void setTauEVarE10(double tauEVar) {
        _tauEVarE10 = tauEVar;
    }
}
