package hex.glrm;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import hex.glrm.GLRM;
import hex.glrm.ModelMetricsGLRM;
import hex.svd.SVDModel;
import java.util.ArrayList;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/**
 * Generalized Low Rank Model (GLRM) built by {@link GLRM}.
 *
 * <p>Each data row is reconstructed by combining its low-rank representation row x (stored as a
 * separate frame under {@link GLRMOutput#_representation_key}) with the fitted archetypes in
 * {@link GLRMOutput#_archetypes_raw}. Row-wise scoring through {@link #score0} is not supported;
 * predictions are produced frame-at-a-time as reconstructions (see {@link #scoreReconstruction}),
 * and the archetypes themselves can be mapped back into the original feature space (see
 * {@link #scoreArchetypes}).
 */
public class GLRMModel extends Model<GLRMModel, GLRMModel.GLRMParameters, GLRMModel.GLRMOutput>
        implements Model.GLRMArchetypes {

    /** Training output of a GLRM model: fitted archetypes, diagnostics and column bookkeeping. */
    public static class GLRMOutput extends Model.Output {
        public int _iterations;
        public int _updates;
        public double _objective;
        public double _step_size;
        public double _avg_change_obj;
        public ArrayList<Double> _history_objective;
        public TwoDimTable _archetypes;
        /** Raw archetypes; combined with a representation row (lmulCatBlock/lmulNumCol) to rebuild a data row. */
        public GLRM.Archetypes _archetypes_raw;
        public ArrayList<Double> _history_step_size;
        public double[][] _eigenvectors_raw;
        public TwoDimTable _eigenvectors;
        public double[] _singular_vals;
        public String _representation_name;
        /** Key of the representation (X) frame; removed together with this model. */
        public Key<Frame> _representation_key;
        /** Key of an auxiliary model (presumably the initialization model); removed together with this model. */
        public Key<? extends Model> _init_key;
        /** Number of categorical columns; they come first in the internal column order. */
        public int _ncats;
        /** Number of numeric columns; they follow the categorical ones in the internal column order. */
        public int _nnums;
        public long _nobs;
        public int[] _catOffsets;
        /** Per numeric column; reversing the transform computes {@code v / _normMul[j] + _normSub[j]}. */
        public double[] _normSub;
        /** Per numeric column; see {@link #_normSub}. */
        public double[] _normMul;
        /** Maps an internal column index to its position in the model's column order. */
        public int[] _permutation;
        public String[] _names_expanded;
        /** Loss function per column, indexed in internal column order. */
        public GlrmLoss[] _lossFunc;
        public ArrayList<Long> _training_time_ms;
        public double _total_variance;
        public double[] _std_deviation;
        public TwoDimTable _importance;

        public GLRMOutput(GLRM glrm) {
            super(glrm);
            _history_objective = new ArrayList<>();
            _history_step_size = new ArrayList<>();
            _training_time_ms = new ArrayList<>();
        }

        @Override
        public int nfeatures() {
            return _names.length;
        }

        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.DimReduction;
        }
    }

    /** User-settable GLRM parameters together with their defaults. */
    public static class GLRMParameters extends Model.Parameters {
        public Key<Frame> _user_y;
        public Key<Frame> _user_x;
        public GlrmLoss[] _loss_by_col;
        public int[] _loss_by_col_idx;
        public String _representation_name;
        public DataInfo.TransformType _transform = DataInfo.TransformType.NONE;
        public int _k = 1;
        public GlrmInitialization _init = GlrmInitialization.PlusPlus;
        public SVDModel.SVDParameters.Method _svd_method = SVDModel.SVDParameters.Method.Randomized;
        public boolean _expand_user_y = true;
        public GlrmLoss _loss = GlrmLoss.Quadratic;
        public GlrmLoss _multi_loss = GlrmLoss.Categorical;
        public int _period = 1;
        public GlrmRegularizer _regularization_x = GlrmRegularizer.None;
        public GlrmRegularizer _regularization_y = GlrmRegularizer.None;
        public double _gamma_x = 0.0d;
        public double _gamma_y = 0.0d;
        public int _max_iterations = 1000;
        // Derived from the *default* _max_iterations at construction time; it is not
        // recomputed if _max_iterations is changed afterwards.
        public int _max_updates = 2 * _max_iterations;
        public double _init_step_size = 1.0d;
        public double _min_step_size = 1.0E-4d;
        public boolean _recover_svd = false;
        public boolean _impute_original = false;
        public boolean _verbose = true;

        @Override
        public String algoName() {
            return "GLRM";
        }

        @Override
        public String fullName() {
            return "Generalized Low Rank Modeling";
        }

        @Override
        public String javaName() {
            return GLRMModel.class.getName();
        }

        /** Progress units: one per iteration plus two fixed units. */
        @Override
        public long progressUnits() {
            return 2 + _max_iterations;
        }
    }

    /**
     * Distributed task over a frame laid out as {@code [A | X | P]}: the first {@code _ncolA}
     * columns hold the adapted data A, the next {@code _ncolX} columns hold the representation X,
     * and the trailing {@code _ncolA} columns P receive the reconstruction when
     * {@code _save_imputed} is set. Reconstruction metrics are accumulated in {@link #_mb}.
     */
    public class GLRMScore extends MRTask<GLRMScore> {
        /** Number of data columns in A (and of reconstruction columns in P). */
        final int _ncolA;
        /** Number of representation columns in X; call sites pass the rank {@code _k}. */
        final int _ncolX;
        /** Whether reconstructed values are written into the P columns. */
        final boolean _save_imputed;
        /** Whether numeric reconstructions are mapped back through the inverse normalization. */
        final boolean _reverse_transform;
        ModelMetricsGLRM.GlrmModelMetricsBuilder _mb;

        GLRMScore(int ncolA, int ncolX, boolean saveImputed, boolean reverseTransform) {
            _ncolA = ncolA;
            _ncolX = ncolX;
            _save_imputed = saveImputed;
            _reverse_transform = reverseTransform;
        }

        @Override
        public void map(Chunk[] chks) {
            final float[] actual = new float[_ncolA];
            final double[] xRow = new double[_ncolX];
            final double[] preds = new double[_ncolA];
            _mb = GLRMModel.this.makeMetricBuilder((String[]) null);

            final int nrows = chks[0]._len;
            for (int row = 0; row < nrows; row++) {
                double[] p = imputeRow(chks, row, xRow, preds);
                computeMetrics(chks, row, actual, p);
                if (_save_imputed) {
                    for (int c = 0; c < p.length; c++) {
                        chks[_ncolA + _ncolX + c].set(row, p[c]);
                    }
                }
            }
        }

        @Override
        public void reduce(GLRMScore other) {
            if (_mb == null) {
                // Nothing accumulated locally: adopt the other side's builder instead of dropping it.
                _mb = other._mb;
            } else if (other._mb != null) {
                _mb.reduce(other._mb);
            }
        }

        @Override
        protected void postGlobal() {
            if (_mb != null) {
                _mb.postGlobal();
            }
        }

        /** Copies row {@code row} of A into {@code actual} (as floats) and feeds it with the prediction to {@link #_mb}. */
        private void computeMetrics(Chunk[] chks, int row, float[] actual, double[] preds) {
            for (int c = 0; c < actual.length; c++) {
                actual[c] = (float) chks[c].atd(row);
            }
            _mb.perRow(preds, actual, GLRMModel.this);
        }

        /** Reads row {@code row} of X into {@code xRow} and reconstructs it into {@code preds}. */
        private double[] imputeRow(Chunk[] chks, int row, double[] xRow, double[] preds) {
            for (int j = 0; j < xRow.length; j++) {
                xRow[j] = chks[_ncolA + j].atd(row);
            }
            return reconstructRow(xRow, preds);
        }

        /**
         * Reconstructs one data row from its representation {@code x}. Categorical columns (the
         * first {@code _ncats} internal columns) go through {@code mimpute}, numeric ones through
         * {@code impute} and, if requested, the inverse normalization. Results are written to
         * {@code preds} at the positions given by {@code _permutation}.
         */
        private double[] reconstructRow(double[] x, double[] preds) {
            final GLRMOutput out = GLRMModel.this._output;
            assert preds.length == out._nnums + out._ncats
                : "prediction buffer has " + preds.length + " slots, expected " + (out._nnums + out._ncats);

            for (int d = 0; d < out._ncats; d++) {
                preds[out._permutation[d]] = out._lossFunc[d].mimpute(out._archetypes_raw.lmulCatBlock(x, d));
            }
            for (int d = out._ncats; d < preds.length; d++) {
                final int ds = d - out._ncats;
                double v = out._lossFunc[d].impute(out._archetypes_raw.lmulNumCol(x, ds));
                if (_reverse_transform) {
                    v = v / out._normMul[ds] + out._normSub[ds];
                }
                preds[out._permutation[d]] = v;
            }
            return preds;
        }
    }

    public GLRMModel(Key<GLRMModel> selfKey, GLRMParameters parms, GLRMOutput output) {
        super(selfKey, parms, output);
    }

    /** Also removes the representation frame and the auxiliary model referenced by the output. */
    @Override
    protected Futures remove_impl(Futures fs) {
        if (_output._init_key != null) {
            _output._init_key.remove(fs);
        }
        if (_output._representation_key != null) {
            _output._representation_key.remove(fs);
        }
        return super.remove_impl(fs);
    }

    /** Serializes the auxiliary model key and the representation frame key before the model itself. */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(_output._init_key);
        ab.putKey(_output._representation_key);
        return super.writeAll_impl(ab);
    }

    /** Reads back the keys written by {@link #writeAll_impl}, in the same order. */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(_output._init_key, fs);
        ab.getKey(_output._representation_key, fs);
        return super.readAll_impl(ab, fs);
    }

    @Override
    public GlrmMojoWriter getMojo() {
        return new GlrmMojoWriter(this);
    }

    /**
     * Decompiler-generated alias kept for backward compatibility.
     *
     * @deprecated use {@link #getMojo()}
     */
    @Deprecated
    public GlrmMojoWriter m93getMojo() {
        return getMojo();
    }

    /**
     * Builds the reconstruction of {@code adaptedFr}, publishes it in the DKV and computes
     * reconstruction metrics against {@code orig}.
     *
     * <p>NOTE(review): the representation frame is appended column-wise to {@code adaptedFr},
     * so this presumably requires {@code adaptedFr} to be row-aligned with the training data.
     *
     * @param orig              frame passed to the metrics builder
     * @param adaptedFr         frame already adapted to the training columns
     * @param destination_key   key for the result; a fresh key is generated when {@code null}
     * @param save_imputed      whether the reconstructed values are materialized
     * @param reverse_transform whether numeric columns are mapped back through the inverse normalization
     * @return the reconstruction frame, with columns named {@code "reconstr_" + name}
     */
    private Frame reconstruct(Frame orig, Frame adaptedFr, Key<Frame> destination_key,
                              boolean save_imputed, boolean reverse_transform) {
        final int ncols = _output._names.length;
        final int k = _parms._k;
        assert ncols == adaptedFr.numCols()
            : "adapted frame has " + adaptedFr.numCols() + " columns, model expects " + ncols;

        // Assemble [A | X | P]; P starts as zero vecs carrying A's domains.
        Frame fullFrm = new Frame(adaptedFr);
        Frame loadingFrm = DKV.get(_output._representation_key).get();
        fullFrm.add(loadingFrm);
        String[][] adaptedDomains = adaptedFr.domains();
        Vec template = fullFrm.anyVec();
        assert template != null : "cannot reconstruct a frame without vecs";
        for (int i = 0; i < ncols; i++) {
            Vec v = template.makeZero();
            v.setDomain(adaptedDomains[i]);
            fullFrm.add("reconstr_" + _output._names[i], v);
        }

        GLRMScore gs = new GLRMScore(ncols, k, save_imputed, reverse_transform).doAll(fullFrm);

        // Keep only the P columns.
        Frame imputed = fullFrm.extractFrame(ncols + k, fullFrm.numCols());
        Frame result = new Frame(destination_key == null ? Key.<Frame>make() : destination_key,
                                 imputed.names(), imputed.vecs());
        DKV.put(result);
        gs._mb.makeModelMetrics(this, orig, null, null);
        return result;
    }

    /** Prediction is the reconstruction; metrics are computed regardless of {@code computeMetrics}. */
    @Override
    protected Frame predictScoreImpl(Frame orig, Frame adaptedFr, String destination_key, Job j, boolean computeMetrics) {
        return reconstruct(orig, adaptedFr, Key.<Frame>make(destination_key), true, _parms._impute_original);
    }

    @Override
    public Frame scoreReconstruction(Frame frame, Key<Frame> destination_key, boolean reverse_transform) {
        Frame adaptedFr = new Frame(frame);
        adaptTestForTrain(adaptedFr, true, false);
        return reconstruct(frame, adaptedFr, destination_key, true, reverse_transform);
    }

    /**
     * Maps each of the {@code _k} archetypes into the model's column space, one row per archetype.
     *
     * @param frame             frame whose adapted column names and domains label the result
     * @param destination_key   key for the result; a fresh key is generated when {@code null}
     * @param reverse_transform whether numeric columns are mapped back through the inverse normalization
     * @return a {@code _k}-row frame of imputed archetype values
     */
    @Override
    public Frame scoreArchetypes(Frame frame, Key<Frame> destination_key, boolean reverse_transform) {
        final GLRMOutput out = _output;
        final int k = _parms._k;
        final int ncols = out._names.length;

        Frame adaptedFr = new Frame(frame);
        adaptTestForTrain(adaptedFr, true, false);
        assert ncols == adaptedFr.numCols()
            : "adapted frame has " + adaptedFr.numCols() + " columns, model expects " + ncols;
        String[][] adaptedDomains = adaptedFr.domains();

        double[][] proj = new double[k][out._nnums + out._ncats];
        for (int d = 0; d < out._ncats; d++) {
            double[][] block = out._archetypes_raw.getCatBlock(d);
            for (int r = 0; r < k; r++) {
                proj[r][out._permutation[d]] = out._lossFunc[d].mimpute(block[r]);
            }
        }
        for (int d = out._ncats; d < out._ncats + out._nnums; d++) {
            final int ds = d - out._ncats;
            final int col = out._permutation[d];
            for (int r = 0; r < k; r++) {
                double v = out._lossFunc[d].impute(out._archetypes_raw.getNum(ds, r));
                if (reverse_transform) {
                    v = v / out._normMul[ds] + out._normSub[ds];
                }
                proj[r][col] = v;
            }
        }

        Frame result = ArrayUtils.frame(destination_key == null ? Key.<Frame>make() : destination_key,
                                        adaptedFr.names(), proj);
        for (int i = 0; i < ncols; i++) {
            result.vec(i).setDomain(adaptedDomains[i]);
        }
        return result;
    }

    /** Row-wise scoring is not supported for GLRM. */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        throw H2O.unimpl();
    }

    /**
     * Computes reconstruction metrics for {@code frame} without materializing the reconstruction.
     *
     * @return the metrics, or {@code null} when {@code frame} is {@code null}
     */
    public ModelMetricsGLRM scoreMetricsOnly(Frame frame) {
        if (frame == null) {
            return null;
        }
        final int ncols = _output._names.length;
        Frame adaptedFr = new Frame(frame);
        adaptTestForTrain(adaptedFr, true, false);
        assert ncols == adaptedFr.numCols()
            : "adapted frame has " + adaptedFr.numCols() + " columns, model expects " + ncols;

        Frame fullFrm = new Frame(adaptedFr);
        Frame loadingFrm = DKV.get(_output._representation_key).get();
        fullFrm.add(loadingFrm);
        GLRMScore gs = new GLRMScore(ncols, _parms._k, false, _parms._impute_original).doAll(fullFrm);
        return gs._mb.makeModelMetrics(this, frame, null, null);
    }

    @Override
    public ModelMetricsGLRM.GlrmModelMetricsBuilder makeMetricBuilder(String[] domain) {
        return new ModelMetricsGLRM.GlrmModelMetricsBuilder(_parms._k, _output._permutation, _parms._impute_original);
    }

    /**
     * Decompiler-generated alias kept for backward compatibility.
     *
     * @deprecated use {@link #makeMetricBuilder(String[])}
     */
    @Deprecated
    public ModelMetricsGLRM.GlrmModelMetricsBuilder m94makeMetricBuilder(String[] domain) {
        return makeMetricBuilder(domain);
    }
}
