package hex.glrm;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.glrm.GLRM;
import hex.glrm.ModelMetricsGLRM;
import hex.svd.SVDModel;
import java.util.Random;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/glrm/GLRMModel.class */
public class GLRMModel extends Model<GLRMModel, GLRMParameters, GLRMOutput> implements Model.GLRMArchetypes {
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/glrm/GLRMModel$GLRMOutput.class */
    /**
     * Training output of a GLRM model.
     *
     * <p>Besides the objective / step-size history and the archetypes, it carries the column
     * bookkeeping ({@code _ncats}, {@code _nnums}, {@code _permutation}, {@code _normSub},
     * {@code _normMul}, {@code _lossFunc}) that the scoring code uses to map model columns back
     * to training-frame columns. Field declaration order is intentionally left as-is; it may be
     * significant for H2O's reflection-based serialization (confirm before reordering).
     */
    public static class GLRMOutput extends Model.Output {
        public int _iterations;
        public int _updates;
        public double _objective;
        public double _avg_change_obj;
        public double[] _history_objective;
        public TwoDimTable _archetypes;
        public GLRM.Archetypes _archetypes_raw;
        public double[] _history_step_size;
        public double[][] _eigenvectors_raw;
        public TwoDimTable _eigenvectors;
        public double[] _singular_vals;
        public String _representation_name;
        public Key<Frame> _representation_key;
        public Key<? extends Model> _init_key;
        public int _ncats;
        public int _nnums;
        public long _nobs;
        public int[] _catOffsets;
        public double[] _normSub;
        public double[] _normMul;
        public int[] _permutation;
        public String[] _names_expanded;
        public GLRMParameters.Loss[] _lossFunc;
        public long[] _training_time_ms;

        /**
         * Builds an empty output for {@code glrm}. The history arrays start empty rather than
         * null.
         */
        public GLRMOutput(GLRM glrm) {
            super(glrm);
            _history_objective = new double[0];
            _history_step_size = new double[0];
            _training_time_ms = new long[0];
        }

        /** Returns the number of feature columns, i.e. the length of {@code _names}. */
        public int nfeatures() {
            final String[] featureNames = _names;
            return featureNames.length;
        }

        /** GLRM is always reported as a dimensionality-reduction model. */
        public ModelCategory getModelCategory() {
            final ModelCategory category = ModelCategory.DimReduction;
            return category;
        }
    }

    /* loaded from: input_file:hex/glrm/GLRMModel$GLRMParameters.class */
    public static class GLRMParameters extends Model.Parameters {
        public Key<Frame> _user_y;
        public Key<Frame> _user_x;
        public Loss[] _loss_by_col;
        public int[] _loss_by_col_idx;
        public String _representation_name;
        static final /* synthetic */ boolean $assertionsDisabled;
        public DataInfo.TransformType _transform = DataInfo.TransformType.NONE;
        public int _k = 1;
        public GLRM.Initialization _init = GLRM.Initialization.PlusPlus;
        public SVDModel.SVDParameters.Method _svd_method = SVDModel.SVDParameters.Method.Randomized;
        public boolean _expand_user_y = true;
        public Loss _loss = Loss.Quadratic;
        public Loss _multi_loss = Loss.Categorical;
        public int _period = 1;
        public Regularizer _regularization_x = Regularizer.None;
        public Regularizer _regularization_y = Regularizer.None;
        public double _gamma_x = 0.0d;
        public double _gamma_y = 0.0d;
        public int _max_iterations = 1000;
        public int _max_updates = 2 * this._max_iterations;
        public double _init_step_size = 1.0d;
        public double _min_step_size = 1.0E-4d;
        public long _seed = System.nanoTime();
        public boolean _recover_svd = false;
        public boolean _impute_original = false;
        public boolean _verbose = true;

        /* loaded from: input_file:hex/glrm/GLRMModel$GLRMParameters$Loss.class */
        /**
         * Per-column loss functions.
         *
         * <p>Numeric losses act on a single value. Logistic and Hinge are additionally flagged as
         * suitable for binary columns. The remaining losses are for categorical columns.
         * Constant order is unchanged because ordinals may be persisted.
         */
        public enum Loss {
            Quadratic(true),
            Absolute(true),
            Huber(true),
            Poisson(true),
            Periodic(true),
            Logistic(true, true),
            Hinge(true, true),
            Categorical(false),
            Ordinal(false);

            /** Whether the loss applies to numeric values (otherwise to categoricals). */
            private final boolean numeric;
            /** Whether the loss is flagged as suitable for binary columns. */
            private final boolean binary;

            Loss(boolean numeric) {
                this(numeric, false);
            }

            Loss(boolean numeric, boolean binary) {
                this.numeric = numeric;
                this.binary = binary;
            }

            /** Returns true if this loss is defined on a single numeric value. */
            public boolean isForNumeric() {
                return numeric;
            }

            /** Returns true if this loss is for categorical columns (the complement of numeric). */
            public boolean isForCategorical() {
                return !numeric;
            }

            /** Returns true if this loss is flagged as suitable for binary columns. */
            public boolean isForBinary() {
                return binary;
            }
        }

        /* loaded from: input_file:hex/glrm/GLRMModel$GLRMParameters$Regularizer.class */
        /**
         * Regularizers r(.) applicable to rows of X and columns of Y. The penalty each one
         * computes is implemented by GLRMParameters.regularize(double[], Regularizer); the
         * descriptions below restate that code. Constant order is unchanged because ordinals may be
         * persisted.
         */
        public enum Regularizer {
            // r(u) = 0
            None,
            // r(u) = sum of u_i^2
            Quadratic,
            // r(u) = sqrt(sum of u_i^2)
            L2,
            // r(u) = sum of |u_i|
            L1,
            // 0 if every u_i >= 0, +Infinity otherwise
            NonNegative,
            // 0 if u is non-negative with exactly one positive entry, +Infinity otherwise
            OneSparse,
            // 0 if u has exactly one entry equal to 1 and all others 0, +Infinity otherwise
            UnitOneSparse,
            // 0 if u is non-negative and sums to 1 (within rounding tolerance), +Infinity otherwise
            Simplex
        }

        /** Returns the short algorithm name used by the framework. */
        public String algoName() {
            final String shortName = "GLRM";
            return shortName;
        }

        /** Returns the human-readable algorithm name. */
        public String fullName() {
            final String longName = "Generalized Low Rank Modeling";
            return longName;
        }

        /** Returns the fully qualified name of the model class these parameters build. */
        public String javaName() {
            final Class<?> modelClass = GLRMModel.class;
            return modelClass.getName();
        }

        /** Returns the n-fold seed, which is simply the model seed {@code _seed}. */
        protected long nFoldSeed() {
            final long seed = _seed;
            return seed;
        }

        /**
         * Returns true when per-column losses were given and every one of them equals
         * {@code loss}.
         *
         * <p>This returns false when {@code _loss_by_col} is null. It returns true (vacuously)
         * when that array is empty.
         */
        private final boolean allLossEquals(Loss loss) {
            if (_loss_by_col == null) {
                return false;
            }
            for (Loss columnLoss : _loss_by_col) {
                if (columnLoss != loss) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Convenience overload of {@link #hasClosedForm(long)}. It first totals the missing-value
         * counts over every column of the training frame.
         */
        public final boolean hasClosedForm() {
            long missing = 0;
            for (Vec column : _train.get().vecs()) {
                missing += column.naCnt();
            }
            return hasClosedForm(missing);
        }

        /**
         * Decides whether the problem admits the closed-form solution.
         *
         * <p>All of the following must hold:
         * <ul>
         *   <li>there are no missing values;</li>
         *   <li>every loss is quadratic;</li>
         *   <li>both regularizers are None, Quadratic, or have zero weight.</li>
         * </ul>
         *
         * @param j total number of missing values in the training frame
         */
        public final boolean hasClosedForm(long j) {
            if (j != 0) {
                return false;
            }
            final boolean quadraticLoss;
            if (_loss_by_col == null) {
                quadraticLoss = _loss == Loss.Quadratic;
            } else {
                // Per-column losses must all be quadratic. Columns they do not cover fall back to
                // _loss, which then must be quadratic as well.
                quadraticLoss = allLossEquals(Loss.Quadratic)
                        && (_loss_by_col.length == _train.get().numCols() || _loss == Loss.Quadratic);
            }
            if (!quadraticLoss) {
                return false;
            }
            final boolean xRegularizerOk = _gamma_x == 0.0d
                    || _regularization_x == Regularizer.None
                    || _regularization_x == Regularizer.Quadratic;
            final boolean yRegularizerOk = _gamma_y == 0.0d
                    || _regularization_y == Regularizer.None
                    || _regularization_y == Regularizer.Quadratic;
            return xRegularizerOk && yRegularizerOk;
        }

        /** Evaluates the configured numeric loss {@code _loss} at (u = d, a = d2). */
        public final double loss(double d, double d2) {
            final Loss configured = _loss;
            return loss(d, d2, configured);
        }

        /**
         * Evaluates the numeric loss L(u, a) with u = {@code d} and a = {@code d2}.
         *
         * <p>For Hinge and Logistic, a == 0 acts as the label -1 and any other value as +1. The
         * Periodic loss has period {@code _period}.
         *
         * @throws AssertionError when assertions are enabled and {@code loss} is not numeric, or
         *     (Poisson) when {@code d2} is not >= 0
         * @throws RuntimeException for an unrecognised loss
         */
        public final double loss(double d, double d2, Loss loss) {
            if (!$assertionsDisabled && !loss.isForNumeric()) {
                throw new AssertionError("Loss function " + loss + " not applicable to numerics");
            }
            final double diff = d - d2;
            switch (loss) {
                case Quadratic:
                    return diff * diff;
                case Absolute:
                    return Math.abs(diff);
                case Huber: {
                    final double absDiff = Math.abs(diff);
                    return absDiff <= 1.0d ? 0.5d * diff * diff : absDiff - 0.5d;
                }
                case Poisson:
                    if (!$assertionsDisabled && !(d2 >= 0.0d)) {
                        throw new AssertionError("Poisson loss L(u,a) requires variable a >= 0");
                    }
                    // The a*log(a) term is taken as 0 when a == 0.
                    return Math.exp(d) + (d2 == 0.0d ? 0.0d : ((-d2 * d) + (d2 * Math.log(d2))) - d2);
                case Hinge:
                    return Math.max(1.0d - (d2 == 0.0d ? -d : d), 0.0d);
                case Logistic:
                    return Math.log(1.0d + Math.exp(d2 == 0.0d ? d : -d));
                case Periodic:
                    return 1.0d - Math.cos(((d2 - d) * (2.0d * Math.PI)) / _period);
                default:
                    throw new RuntimeException("Unknown loss function " + loss);
            }
        }

        /** Gradient in u of the configured numeric loss {@code _loss} at (u = d, a = d2). */
        public final double lgrad(double d, double d2) {
            final Loss configured = _loss;
            return lgrad(d, d2, configured);
        }

        /**
         * Gradient of the numeric loss L(u, a) with respect to u = {@code d}, for a = {@code d2}.
         *
         * <p>It uses the same label convention as {@code loss}: for Hinge and Logistic, a == 0
         * acts as the label -1.
         *
         * @throws AssertionError when assertions are enabled and {@code loss} is not numeric, or
         *     (Poisson) when {@code d2} is not >= 0
         * @throws RuntimeException for an unrecognised loss
         */
        public final double lgrad(double d, double d2, Loss loss) {
            if (!$assertionsDisabled && !loss.isForNumeric()) {
                throw new AssertionError("Loss function " + loss + " not applicable to numerics");
            }
            final double diff = d - d2;
            final double twoPi = 2.0d * Math.PI;
            switch (loss) {
                case Quadratic:
                    return 2.0d * diff;
                case Absolute:
                    return Math.signum(diff);
                case Huber:
                    // Linear gradient inside |u - a| <= 1, constant magnitude 1 outside.
                    return Math.abs(diff) <= 1.0d ? diff : Math.signum(diff);
                case Poisson:
                    if (!$assertionsDisabled && !(d2 >= 0.0d)) {
                        throw new AssertionError("Poisson loss L(u,a) requires variable a >= 0");
                    }
                    return Math.exp(d) - d2;
                case Hinge:
                    if (d2 == 0.0d) {
                        return (-d) <= 1.0d ? 1.0d : 0.0d;
                    }
                    return d <= 1.0d ? -1.0d : 0.0d;
                case Logistic:
                    return d2 == 0.0d ? 1.0d / (1.0d + Math.exp(-d)) : -1.0d / (1.0d + Math.exp(d));
                case Periodic:
                    return (twoPi / _period) * Math.sin(((d2 - d) * twoPi) / _period);
                default:
                    throw new RuntimeException("Unknown loss function " + loss);
            }
        }

        /** Evaluates the configured categorical loss {@code _multi_loss} for level {@code i}. */
        public final double mloss(double[] dArr, int i) {
            final Loss configured = _multi_loss;
            return mloss(dArr, i, configured);
        }

        /**
         * Evaluates a categorical loss of the score vector {@code dArr} (u) against the observed
         * level index {@code i} (a).
         *
         * <p>Categorical: the sum over levels of max(1 + u_j, 0). The observed level's term is
         * then swapped for max(1 - u_a, 0).
         *
         * <p>Ordinal: for each of the first {@code dArr.length - 1} thresholds j, adds
         * max(1 - u_j, 0) when a > j and 1 otherwise.
         *
         * @throws IllegalArgumentException if {@code i} is outside [0, dArr.length - 1]
         * @throws AssertionError when assertions are enabled and {@code loss} is not categorical
         */
        public static double mloss(double[] dArr, int i, Loss loss) {
            if (!$assertionsDisabled && !loss.isForCategorical()) {
                throw new AssertionError("Loss function " + loss + " not applicable to categoricals");
            }
            if (i < 0 || i > dArr.length - 1) {
                throw new IllegalArgumentException("Index must be between 0 and " + String.valueOf(dArr.length - 1));
            }
            switch (loss) {
                case Categorical: {
                    double total = 0.0d;
                    for (double score : dArr) {
                        total += Math.max(1.0d + score, 0.0d);
                    }
                    final double observedCorrection = Math.max(1.0d - dArr[i], 0.0d) - Math.max(1.0d + dArr[i], 0.0d);
                    return total + observedCorrection;
                }
                case Ordinal: {
                    double total = 0.0d;
                    for (int j = 0; j + 1 < dArr.length; j++) {
                        total += i > j ? Math.max(1.0d - dArr[j], 0.0d) : 1.0d;
                    }
                    return total;
                }
                default:
                    throw new RuntimeException("Unknown multidimensional loss function " + loss);
            }
        }

        /** Gradient of the configured categorical loss {@code _multi_loss} for level {@code i}. */
        public final double[] mlgrad(double[] dArr, int i) {
            final Loss configured = _multi_loss;
            return mlgrad(dArr, i, configured);
        }

        /**
         * Returns the (sub)gradient of {@code mloss(dArr, i, loss)} with respect to the score
         * vector, as a fresh array of the same length.
         *
         * @throws IllegalArgumentException if {@code i} is outside [0, dArr.length - 1]
         * @throws AssertionError when assertions are enabled and {@code loss} is not categorical
         */
        public static double[] mlgrad(double[] dArr, int i, Loss loss) {
            if (!$assertionsDisabled && !loss.isForCategorical()) {
                throw new AssertionError("Loss function " + loss + " not applicable to categoricals");
            }
            if (i < 0 || i > dArr.length - 1) {
                throw new IllegalArgumentException("Index must be between 0 and " + String.valueOf(dArr.length - 1));
            }
            final double[] grad = new double[dArr.length];
            switch (loss) {
                case Categorical:
                    for (int j = 0; j < grad.length; j++) {
                        grad[j] = (1.0d + dArr[j] > 0.0d) ? 1.0d : 0.0d;
                    }
                    // The observed level uses the max(1 - u_a, 0) term instead.
                    grad[i] = (1.0d - dArr[i] > 0.0d) ? -1.0d : 0.0d;
                    return grad;
                case Ordinal:
                    for (int j = 0; j + 1 < dArr.length; j++) {
                        grad[j] = (i <= j || 1.0d - dArr[j] <= 0.0d) ? 0.0d : -1.0d;
                    }
                    return grad;
                default:
                    throw new RuntimeException("Unknown multidimensional loss function " + loss);
            }
        }

        /** Penalty of a single row of X under {@code _regularization_x} (unweighted). */
        public final double regularize_x(double[] dArr) {
            final Regularizer r = _regularization_x;
            return regularize(dArr, r);
        }

        /** Penalty of a single vector of Y under {@code _regularization_y} (unweighted). */
        public final double regularize_y(double[] dArr) {
            final Regularizer r = _regularization_y;
            return regularize(dArr, r);
        }

        /**
         * Evaluates the (unweighted) penalty r(u) of {@code regularizer} on vector {@code dArr}.
         *
         * <p>Constraint-style regularizers return 0 when u satisfies the constraint and
         * {@link Double#POSITIVE_INFINITY} otherwise. A null vector has penalty 0.
         *
         * @throws RuntimeException for an unrecognised regularizer
         */
        public final double regularize(double[] dArr, Regularizer regularizer) {
            if (dArr == null) {
                return 0.0d;
            }
            switch (regularizer) {
                case None:
                    return 0.0d;
                case Quadratic: {
                    double squares = 0.0d;
                    for (double v : dArr) {
                        squares += v * v;
                    }
                    return squares;
                }
                case L2: {
                    double squares = 0.0d;
                    for (double v : dArr) {
                        squares += v * v;
                    }
                    return Math.sqrt(squares);
                }
                case L1: {
                    double absolutes = 0.0d;
                    for (double v : dArr) {
                        absolutes += Math.abs(v);
                    }
                    return absolutes;
                }
                case NonNegative:
                    for (double v : dArr) {
                        if (v < 0.0d) {
                            return Double.POSITIVE_INFINITY;
                        }
                    }
                    return 0.0d;
                case OneSparse: {
                    int positives = 0;
                    for (double v : dArr) {
                        if (v < 0.0d) {
                            return Double.POSITIVE_INFINITY;
                        }
                        if (v > 0.0d) {
                            positives++;
                        }
                    }
                    return positives == 1 ? 0.0d : Double.POSITIVE_INFINITY;
                }
                case UnitOneSparse: {
                    int ones = 0;
                    int zeros = 0;
                    for (double v : dArr) {
                        if (v == 1.0d) {
                            ones++;
                        } else if (v == 0.0d) {
                            zeros++;
                        } else {
                            return Double.POSITIVE_INFINITY;
                        }
                    }
                    return (ones == 1 && zeros == dArr.length - 1) ? 0.0d : Double.POSITIVE_INFINITY;
                }
                case Simplex: {
                    double sum = 0.0d;
                    double absSum = 0.0d;
                    for (double v : dArr) {
                        if (v < 0.0d) {
                            return Double.POSITIVE_INFINITY;
                        }
                        sum += v;
                        absSum += Math.abs(v);
                    }
                    // Allow for the rounding error accumulated by the summation above.
                    return MathUtils.equalsWithinRecSumErr(sum, 1.0d, dArr.length, absSum) ? 0.0d : Double.POSITIVE_INFINITY;
                }
                default:
                    throw new RuntimeException("Unknown regularization function " + regularizer);
            }
        }

        /** Total penalty of all rows of X under {@code _regularization_x} (unweighted). */
        public final double regularize_x(double[][] dArr) {
            final Regularizer r = _regularization_x;
            return regularize(dArr, r);
        }

        /** Total penalty of all vectors of Y under {@code _regularization_y} (unweighted). */
        public final double regularize_y(double[][] dArr) {
            final Regularizer r = _regularization_y;
            return regularize(dArr, r);
        }

        /**
         * Sums the per-row penalty of {@code regularizer} over every row of {@code dArr}.
         *
         * <p>It stops early as soon as the running total becomes infinite. A null matrix, or the
         * {@code None} regularizer, yields 0.
         */
        public final double regularize(double[][] dArr, Regularizer regularizer) {
            if (dArr == null || regularizer == Regularizer.None) {
                return 0.0d;
            }
            double total = 0.0d;
            for (int row = 0; row < dArr.length; row++) {
                total += regularize(dArr[row], regularizer);
                if (Double.isInfinite(total)) {
                    return total;
                }
            }
            return total;
        }

        /** Proximal step of size {@code d} for the X regularizer, weighted by {@code _gamma_x}. */
        public final double[] rproxgrad_x(double[] dArr, double d, Random random) {
            final double weight = _gamma_x;
            return rproxgrad(dArr, d, weight, _regularization_x, random);
        }

        /** Proximal step of size {@code d} for the Y regularizer, weighted by {@code _gamma_y}. */
        public final double[] rproxgrad_y(double[] dArr, double d, Random random) {
            final double weight = _gamma_y;
            return rproxgrad(dArr, d, weight, _regularization_y, random);
        }

        /**
         * Applies the proximal operator of {@code d * d2 * r(.)} to {@code dArr}, where r is
         * {@code regularizer}.
         *
         * <p>The input array itself (not a copy) is returned when it is null, when either factor
         * is zero, or when the regularizer is {@code None}. Every other case returns a fresh
         * array.
         *
         * @param dArr the point u to map
         * @param d step size (alpha)
         * @param d2 regularization weight (gamma)
         * @param regularizer the regularizer r
         * @param random random source handed to {@code ArrayUtils.maxIndex} for the sparse cases
         *     (presumably used to break ties — confirm)
         * @throws RuntimeException for an unrecognised regularizer
         */
        static double[] rproxgrad(double[] dArr, double d, double d2, Regularizer regularizer, Random random) {
            if (dArr == null || d == 0.0d || d2 == 0.0d) {
                return dArr;
            }
            double[] dArr2 = new double[dArr.length];
            switch (regularizer) {
                case None:
                    return dArr;
                case Quadratic: {
                    // The prox of alpha*gamma*||u||^2 is a uniform shrink.
                    final double shrink = 1.0d + ((2.0d * d) * d2);
                    for (int i = 0; i < dArr.length; i++) {
                        dArr2[i] = dArr[i] / shrink;
                    }
                    return dArr2;
                }
                case L2: {
                    // Block soft-thresholding: a non-positive weight maps u to the zero vector.
                    final double weight = 1.0d - ((d * d2) / ArrayUtils.l2norm(dArr));
                    if (weight < 0.0d) {
                        return dArr2;
                    }
                    for (int i = 0; i < dArr.length; i++) {
                        dArr2[i] = weight * dArr[i];
                    }
                    return dArr2;
                }
                case L1:
                    // Element-wise soft-thresholding by alpha*gamma.
                    for (int i = 0; i < dArr.length; i++) {
                        dArr2[i] = Math.max(dArr[i] - (d * d2), 0.0d) + Math.min(dArr[i] + (d * d2), 0.0d);
                    }
                    return dArr2;
                case NonNegative:
                    for (int i = 0; i < dArr.length; i++) {
                        dArr2[i] = Math.max(dArr[i], 0.0d);
                    }
                    return dArr2;
                case OneSparse: {
                    final int maxIndex = ArrayUtils.maxIndex(dArr, random);
                    dArr2[maxIndex] = dArr[maxIndex] > 0.0d ? dArr[maxIndex] : 1.0E-6d;
                    return dArr2;
                }
                case UnitOneSparse:
                    dArr2[ArrayUtils.maxIndex(dArr, random)] = 1.0d;
                    return dArr2;
                case Simplex:
                    return projectOntoSimplex(dArr);
                default:
                    throw new RuntimeException("Unknown regularization function " + regularizer);
            }
        }

        /**
         * Euclidean projection of {@code u} onto the probability simplex
         * {x : x >= 0, sum(x) = 1}, using the sort-based algorithm of Chen and Ye
         * (arXiv:1101.6081).
         *
         * <p>The previous inline version looped forever: it lacked a {@code break} once the
         * threshold was found, and it never exited after the scan reached index 0. An empty input
         * has no projection, so an empty array is returned for it.
         */
        private static double[] projectOntoSimplex(double[] u) {
            final int n = u.length;
            if (n == 0) {
                return new double[0];
            }
            // 1) Sort indices so that u[idx[0]] <= ... <= u[idx[n-1]] (ascending order is assumed
            //    from ArrayUtils.sort(int[], double[]) — confirm).
            int[] idx = new int[n];
            for (int i = 0; i < n; i++) {
                idx[i] = i;
            }
            ArrayUtils.sort(idx, u);
            // 2) Suffix sums of the sorted values: csum[i] = u[idx[i]] + ... + u[idx[n-1]].
            double[] csum = new double[n];
            csum[n - 1] = u[idx[n - 1]];
            for (int i = n - 2; i >= 0; i--) {
                csum[i] = csum[i + 1] + u[idx[i]];
            }
            // 3) Find the threshold t*. Scan i = n-1 .. 1 and take the first
            //    t_i = (csum[i] - 1) / (n - i) with t_i >= u[idx[i-1]]. If none qualifies, fall
            //    back to t = (sum(u) - 1) / n.
            double t = (csum[0] - 1.0d) / n;
            for (int i = n - 1; i >= 1; i--) {
                final double candidate = (csum[i] - 1.0d) / (n - i);
                if (candidate >= u[idx[i - 1]]) {
                    t = candidate;
                    break;
                }
            }
            // 4) Shift by t* and clip at zero.
            double[] x = new double[n];
            for (int i = 0; i < n; i++) {
                x[i] = Math.max(u[i] - t, 0.0d);
            }
            return x;
        }

        /** Projects a row of X onto the feasible set of {@code _regularization_x}. */
        public final double[] project_x(double[] dArr, Random random) {
            final Regularizer r = _regularization_x;
            return project(dArr, r, random);
        }

        /** Projects a vector of Y onto the feasible set of {@code _regularization_y}. */
        public final double[] project_y(double[] dArr, Random random) {
            final Regularizer r = _regularization_y;
            return project(dArr, r, random);
        }

        /**
         * Projects {@code dArr} onto the feasible set implied by {@code regularizer}.
         *
         * <p>Unconstrained regularizers (None, Quadratic, L2, L1) return the input unchanged.
         * Constraint regularizers use their proximal operator with unit step and weight. A null
         * input is returned as-is.
         *
         * @throws RuntimeException for an unrecognised regularizer
         */
        public final double[] project(double[] dArr, Regularizer regularizer, Random random) {
            if (dArr == null) {
                return null;
            }
            switch (regularizer) {
                case NonNegative:
                case OneSparse:
                case UnitOneSparse:
                    return rproxgrad(dArr, 1.0d, 1.0d, regularizer, random);
                case Simplex:
                    // Skip the sort-based projection when u already lies on the simplex.
                    if (regularize(dArr, regularizer) == 0.0d) {
                        return dArr;
                    }
                    return rproxgrad(dArr, 1.0d, 1.0d, regularizer, random);
                case None:
                case Quadratic:
                case L2:
                case L1:
                    return dArr;
                default:
                    throw new RuntimeException("Unknown regularization function " + regularizer);
            }
        }

        /** Imputes a numeric value from xy = {@code d} using the configured loss {@code _loss}. */
        public final double impute(double d) {
            final Loss configured = _loss;
            return impute(d, configured);
        }

        /**
         * Maps the low-rank value xy = {@code d} back to a data value under {@code loss}.
         *
         * <p>Poisson returns exp(d) - 1. Hinge and Logistic threshold at 0, giving 1 or 0. All
         * other numeric losses return {@code d} unchanged.
         *
         * @throws AssertionError when assertions are enabled and {@code loss} is not numeric
         * @throws RuntimeException for an unrecognised loss
         */
        public static double impute(double d, Loss loss) {
            if (!$assertionsDisabled && !loss.isForNumeric()) {
                throw new AssertionError("Loss function " + loss + " not applicable to numerics");
            }
            switch (loss) {
                case Poisson:
                    return Math.exp(d) - 1.0d;
                case Hinge:
                case Logistic:
                    return (d > 0.0d) ? 1.0d : 0.0d;
                case Quadratic:
                case Absolute:
                case Huber:
                case Periodic:
                    return d;
                default:
                    throw new RuntimeException("Unknown loss function " + loss);
            }
        }

        /** Imputes a categorical level from scores {@code dArr} using {@code _multi_loss}. */
        public final int mimpute(double[] dArr) {
            final Loss configured = _multi_loss;
            return mimpute(dArr, configured);
        }

        /**
         * Picks the categorical level whose loss {@code mloss(dArr, level, loss)} is smallest.
         * Ties are resolved by {@code ArrayUtils.minIndex}.
         *
         * @throws AssertionError when assertions are enabled and {@code loss} is not categorical
         * @throws RuntimeException for an unrecognised loss
         */
        public static int mimpute(double[] dArr, Loss loss) {
            if (!$assertionsDisabled && !loss.isForCategorical()) {
                throw new AssertionError("Loss function " + loss + " not applicable to categoricals");
            }
            switch (loss) {
                case Categorical:
                case Ordinal: {
                    final double[] levelLoss = new double[dArr.length];
                    for (int level = 0; level < levelLoss.length; level++) {
                        levelLoss[level] = mloss(dArr, level, loss);
                    }
                    return ArrayUtils.minIndex(levelLoss);
                }
                default:
                    throw new RuntimeException("Unknown multidimensional loss function " + loss);
            }
        }

        static {
            $assertionsDisabled = !GLRMModel.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/glrm/GLRMModel$GLRMScore.class */
    /**
     * Distributed pass that imputes every row from its representation X and the archetypes Y,
     * accumulating GLRM metrics along the way.
     *
     * <p>Expected chunk layout per row:
     * <ol>
     *   <li>{@code _ncolA} original (adapted) columns;</li>
     *   <li>{@code _ncolX} representation columns;</li>
     *   <li>only when {@code _save_imputed}, {@code _ncolA} output columns that receive the
     *       imputed values.</li>
     * </ol>
     *
     * <p>This class deliberately declares no static members. The decompiled
     * {@code $assertionsDisabled} field and {@code static {}} block are illegal in an inner
     * class before Java 16 (JLS 8.1.3), so plain {@code assert} statements are used instead.
     */
    public class GLRMScore extends MRTask<GLRMScore> {
        final int _ncolA;
        final int _ncolX;
        final boolean _save_imputed;
        final boolean _reverse_transform;
        ModelMetrics.MetricBuilder _mb;

        /**
         * Creates a scorer that takes its reverse-transform flag from {@code model}'s
         * {@code _impute_original} parameter.
         */
        GLRMScore(GLRMModel model, int ncolA, int ncolX, boolean saveImputed) {
            this(ncolA, ncolX, saveImputed, ((GLRMParameters) model._parms)._impute_original);
        }

        /**
         * @param ncolA number of original (adapted) columns at the front of each row
         * @param ncolX number of representation columns that follow them (callers pass k)
         * @param saveImputed write imputed values into the trailing {@code ncolA} columns
         * @param reverseTransform undo the numeric normalization (x / mul + sub) on imputed values
         */
        GLRMScore(int ncolA, int ncolX, boolean saveImputed, boolean reverseTransform) {
            _ncolA = ncolA;
            _ncolX = ncolX;
            _save_imputed = saveImputed;
            _reverse_transform = reverseTransform;
        }

        public void map(Chunk[] chks) {
            final float[] actual = new float[_ncolA];
            final double[] xRow = new double[_ncolX];
            final double[] preds = new double[_ncolA];
            _mb = GLRMModel.this.makeMetricBuilder(null);
            final int outStart = _ncolA + _ncolX;
            for (int row = 0; row < chks[0]._len; row++) {
                final double[] imputed = impute_data(chks, row, xRow, preds);
                compute_metrics(chks, row, actual, imputed);
                if (_save_imputed) {
                    for (int c = 0; c < imputed.length; c++) {
                        chks[outStart + c].set(row, imputed[c]);
                    }
                }
            }
        }

        /**
         * Merges the metrics of {@code other} into this task.
         *
         * <p>A task that never ran {@code map} has no builder. The previous version kept the left
         * side's null and silently dropped the right side's metrics. It could also pass a null
         * builder into {@code reduce}.
         */
        public void reduce(GLRMScore other) {
            if (_mb == null) {
                _mb = other._mb;
            } else if (other._mb != null) {
                _mb.reduce(other._mb);
            }
        }

        protected void postGlobal() {
            if (_mb != null) {
                _mb.postGlobal();
            }
        }

        /** Copies the row's original values into {@code actual} and feeds them to the metrics. */
        private float[] compute_metrics(Chunk[] chks, int row, float[] actual, double[] preds) {
            for (int c = 0; c < actual.length; c++) {
                actual[c] = (float) chks[c].atd(row);
            }
            _mb.perRow(preds, actual, GLRMModel.this);
            return actual;
        }

        /** Reads the row's representation X into {@code xRow}, then imputes into {@code preds}. */
        private double[] impute_data(Chunk[] chks, int row, double[] xRow, double[] preds) {
            for (int c = 0; c < xRow.length; c++) {
                xRow[c] = chks[_ncolA + c].atd(row);
            }
            impute_data(xRow, preds);
            return preds;
        }

        /**
         * Computes x * Y for one row and maps it back to data space. Categorical blocks come
         * first and numeric columns after them. Outputs are scattered to their training-frame
         * positions via {@code _permutation}.
         */
        private double[] impute_data(double[] xRow, double[] preds) {
            final GLRMOutput out = (GLRMOutput) GLRMModel.this._output;
            assert preds.length == out._nnums + out._ncats;
            for (int cat = 0; cat < out._ncats; cat++) {
                final double[] xy = out._archetypes_raw.lmulCatBlock(xRow, cat);
                preds[out._permutation[cat]] = GLRMParameters.mimpute(xy, out._lossFunc[cat]);
            }
            for (int col = out._ncats; col < preds.length; col++) {
                final int num = col - out._ncats;
                final double xy = out._archetypes_raw.lmulNumCol(xRow, num);
                final int dest = out._permutation[col];
                preds[dest] = GLRMParameters.impute(xy, out._lossFunc[col]);
                if (_reverse_transform) {
                    preds[dest] = preds[dest] / out._normMul[num] + out._normSub[num];
                }
            }
            return preds;
        }
    }

    /** Creates a GLRM model stored under {@code selfKey} with the given parameters and output. */
    public GLRMModel(Key selfKey, GLRMParameters parms, GLRMOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Removes the initialization model and the representation frame (when present), then
     * delegates to the superclass removal.
     */
    protected Futures remove_impl(Futures futures) {
        final GLRMOutput out = (GLRMOutput) _output;
        final Key<? extends Model> initKey = out._init_key;
        if (initKey != null) {
            initKey.remove(futures);
        }
        final Key<Frame> representationKey = out._representation_key;
        if (representationKey != null) {
            representationKey.remove(futures);
        }
        return super.remove_impl(futures);
    }

    /**
     * Serializes the init-model key and representation key before the superclass payload.
     * {@code readAll_impl} reads them back in the same order.
     */
    protected AutoBuffer writeAll_impl(AutoBuffer autoBuffer) {
        final GLRMOutput out = (GLRMOutput) _output;
        autoBuffer.putKey(out._init_key);
        autoBuffer.putKey(out._representation_key);
        return super.writeAll_impl(autoBuffer);
    }

    /**
     * Reads back the two keys written by {@code writeAll_impl}, in the same order, before the
     * superclass payload.
     */
    protected Keyed readAll_impl(AutoBuffer autoBuffer, Futures futures) {
        final GLRMOutput out = (GLRMOutput) _output;
        autoBuffer.getKey(out._init_key, futures);
        autoBuffer.getKey(out._representation_key, futures);
        return super.readAll_impl(autoBuffer, futures);
    }

    /**
     * Imputes every training row from X and Y, and publishes the imputed columns as a new frame.
     *
     * @param frame the caller's original frame, passed to metric construction
     * @param frame2 the frame already adapted to the training columns
     * @param key destination key, or null for a fresh one
     * @param z whether the scorer should save imputed values
     * @param z2 whether to undo the numeric normalization on imputed values
     * @return the frame of {@code reconstr_*} columns, already put into the DKV
     */
    private Frame reconstruct(Frame frame, Frame frame2, Key key, boolean z, boolean z2) {
        final GLRMOutput out = (GLRMOutput) _output;
        final int k = ((GLRMParameters) _parms)._k;
        final int ncols = out._names.length;
        if (!$assertionsDisabled && ncols != frame2.numCols()) {
            throw new AssertionError();
        }
        // Working layout: [adapted columns | representation X (k columns) | reconstr_* outputs]
        final Frame work = new Frame(frame2);
        work.add(DKV.get(out._representation_key).get());
        final String[][] domains = frame2.domains();
        for (int c = 0; c < ncols; c++) {
            final Vec zeros = work.anyVec().makeZero();
            zeros.setDomain(domains[c]);
            work.add("reconstr_" + out._names[c], zeros);
        }
        final GLRMScore task = (GLRMScore) new GLRMScore(ncols, k, z, z2).doAll(work);
        final Frame outputs = work.extractFrame(ncols + k, work.numCols());
        final Frame result = new Frame(null == key ? Key.make() : key, outputs.names(), outputs.vecs());
        DKV.put(result);
        // The returned metrics object is discarded; the call is kept for its side effects
        // (presumably recording the metrics on the model — confirm).
        task._mb.makeModelMetrics(this, frame, (Frame) null, (Frame) null);
        return result;
    }

    /**
     * Scoring entry point: reconstructs the adapted frame into {@code str} (or a fresh key).
     * Normalization is undone according to {@code _impute_original}.
     */
    protected Frame predictScoreImpl(Frame frame, Frame frame2, String str) {
        final Key destination;
        if (str == null) {
            destination = Key.make();
        } else {
            destination = Key.make(str);
        }
        final boolean undoTransform = ((GLRMParameters) _parms)._impute_original;
        return reconstruct(frame, frame2, destination, true, undoTransform);
    }

    /**
     * Adapts {@code frame} to the training columns and reconstructs it into {@code key}.
     *
     * @param z whether to undo the numeric normalization on imputed values
     */
    public Frame scoreReconstruction(Frame frame, Key key, boolean z) {
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        return reconstruct(frame, adapted, key, true, z);
    }

    /**
     * Maps the k archetypes (rows of Y) back into data space and returns them as a k-row frame
     * with the adapted frame's column names and domains.
     *
     * @param z whether to undo the numeric normalization on numeric columns
     */
    public Frame scoreArchetypes(Frame frame, Key key, boolean z) {
        final GLRMOutput out = (GLRMOutput) _output;
        final int k = ((GLRMParameters) _parms)._k;
        final int ncols = out._names.length;
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        if (!$assertionsDisabled && ncols != adapted.numCols()) {
            throw new AssertionError();
        }
        final String[][] domains = adapted.domains();
        final double[][] archetypes = new double[k][out._nnums + out._ncats];
        // Categorical blocks: impute the level for each archetype.
        for (int cat = 0; cat < out._ncats; cat++) {
            final double[][] block = out._archetypes_raw.getCatBlock(cat);
            for (int row = 0; row < k; row++) {
                archetypes[row][out._permutation[cat]] = GLRMParameters.mimpute(block[row], out._lossFunc[cat]);
            }
        }
        // Numeric columns: impute, then optionally undo the normalization.
        for (int col = out._ncats; col < out._ncats + out._nnums; col++) {
            final int num = col - out._ncats;
            for (int row = 0; row < k; row++) {
                final double y = out._archetypes_raw.getNum(num, row);
                final int dest = out._permutation[col];
                archetypes[row][dest] = GLRMParameters.impute(y, out._lossFunc[col]);
                if (z) {
                    archetypes[row][dest] = archetypes[row][dest] / out._normMul[num] + out._normSub[num];
                }
            }
        }
        final Frame result = ArrayUtils.frame(null == key ? Key.make() : key, adapted.names(), archetypes);
        for (int c = 0; c < ncols; c++) {
            result.vec(c).setDomain(domains[c]);
        }
        return result;
    }

    /**
     * Row-wise scoring is not supported for GLRM, because each row needs its own representation
     * X. Always throws the exception produced by {@code H2O.unimpl()}.
     */
    protected double[] score0(double[] data, double[] preds) {
        throw H2O.unimpl();
    }

    /**
     * Computes GLRM metrics for {@code frame} against the stored representation, without
     * materializing the reconstructed columns.
     *
     * <p>{@code MetricBuilder.makeModelMetrics} is declared to return a plain
     * {@code ModelMetrics}. The decompiled body returned it without a downcast, which does not
     * type-check against the {@code ModelMetricsGLRM} return type, so an explicit cast is
     * restored.
     *
     * @throws ClassCastException if the metric builder does not produce GLRM metrics
     */
    public ModelMetricsGLRM scoreMetricsOnly(Frame frame) {
        final GLRMOutput out = (GLRMOutput) _output;
        final int ncols = out._names.length;
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        if (!$assertionsDisabled && ncols != adapted.numCols()) {
            throw new AssertionError();
        }
        // Working layout: [adapted columns | representation X (k columns)]
        final Frame work = new Frame(adapted);
        work.add(DKV.get(out._representation_key).get());
        final GLRMScore task = (GLRMScore) new GLRMScore(this, ncols, ((GLRMParameters) _parms)._k, false).doAll(work);
        return (ModelMetricsGLRM) task._mb.makeModelMetrics(this, adapted, (Frame) null, (Frame) null);
    }

    /**
     * Creates a GLRM metric builder configured with the rank k, the column permutation, and the
     * impute-original flag. The {@code strArr} argument is unused.
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        final int k = ((GLRMParameters) _parms)._k;
        final int[] permutation = ((GLRMOutput) _output)._permutation;
        final boolean imputeOriginal = ((GLRMParameters) _parms)._impute_original;
        return new ModelMetricsGLRM.GLRMModelMetrics(k, permutation, imputeOriginal);
    }

    static {
        $assertionsDisabled = !GLRMModel.class.desiredAssertionStatus();
    }
}
