package hex.glm;

import hex.DataInfo;
import hex.ModelMetrics;
import hex.glm.GLMModel;
import java.util.Arrays;
import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.FrameUtils;

/* loaded from: input_file:hex/glm/GLMScore.class */
/**
 * Distributed task that scores the rows of a {@link DataInfo}-adapted frame with a {@link GLMModel}.
 *
 * <p>For every row it fills a prediction buffer via {@link #scoreRow}. Depending on the constructor
 * flags it does two further things:
 * <ul>
 *   <li>feeds the buffer to a {@link ModelMetrics.MetricBuilder}, when {@code computeMetrics} is set;</li>
 *   <li>appends the predictions to the output {@link NewChunk}s, when {@code generatePredictions} is set.</li>
 * </ul>
 * Metric builders from different map calls are merged in {@link #reduce} and finalized in
 * {@link #postGlobal}.
 */
public class GLMScore extends MRTask<GLMScore> {
    final GLMModel _m;
    /** Optional job used for stop requests and progress updates; may be {@code null}. */
    final Job _j;
    /** Metric accumulator, created in {@link #map} when {@link #_computeMetrics} is set. */
    ModelMetrics.MetricBuilder _mb;
    /** Row-extraction info; may be filtered down to the columns with non-zero coefficients. */
    final DataInfo _dinfo;
    /** Whether rows are extracted sparsely (chosen from the sparsity of the unfiltered frame). */
    final boolean _sparse;
    final String[] _domain;
    final boolean _computeMetrics;
    final boolean _generatePredictions;
    /**
     * Copied from the model output in {@link #map}. When non-null and predictions are requested,
     * an extra column with sqrt(x' * vcov * x) is emitted (presumably the standard error of the
     * linear predictor -- confirm).
     */
    transient double[][] _vcov;
    /** Scratch vector for the vcov product, sized to {@code _vcov.length}. */
    transient double[] _tmp;
    /** Per-class linear predictor scratch space (multinomial scoring). */
    transient double[] _eta;
    final int _nclasses;
    /** Coefficients for single-output families; the last entry is always kept (presumably the intercept). */
    private final double[] _beta;
    /** One coefficient vector per class for the multinomial and ordinal families. */
    private final double[][] _beta_multinomial;
    private final double _defaultThreshold;

    /**
     * Creates a scoring task.
     *
     * <p>For families other than multinomial/ordinal, predictors whose coefficient is exactly zero
     * are dropped. Both the coefficient vector and the {@link DataInfo} are compacted so that they
     * contain only the remaining columns. Note that {@code _valid} is set on the resulting
     * {@link DataInfo}, which is the caller's instance when no filtering happened.
     *
     * @param job                 optional job for cancellation and progress; may be {@code null}
     * @param model               the GLM model to score with
     * @param dinfo               row-extraction info for the adapted frame
     * @param domain              response domain passed to the metric builder
     * @param computeMetrics      whether to accumulate model metrics
     * @param generatePredictions whether to emit prediction columns
     */
    public GLMScore(Job job, GLMModel model, DataInfo dinfo, String[] domain, boolean computeMetrics, boolean generatePredictions) {
        this._j = job;
        this._m = model;
        this._computeMetrics = computeMetrics;
        // Sparsity is judged on the unfiltered frame.
        this._sparse = FrameUtils.sparseRatio(dinfo._adaptedFrame) < 0.5d;
        this._domain = domain;
        this._generatePredictions = generatePredictions;
        this._nclasses = ((GLMModel.GLMOutput) model._output).nclasses();

        DataInfo scoringInfo = dinfo;
        if (isMultiClassFamily(family())) {
            this._beta = null;
            this._beta_multinomial = ((GLMModel.GLMOutput) model._output)._global_beta_multinomial;
        } else {
            double[] beta = model.beta();
            int nPredictors = beta.length - 1;
            int[] nonZero = new int[nPredictors];
            int nnz = 0;
            for (int c = 0; c < nPredictors; c++) {
                if (beta[c] != 0.0d) {
                    nonZero[nnz++] = c;
                }
            }
            if (nnz < nPredictors) {
                // Score only on columns with non-zero coefficients; always keep the last entry.
                int[] keep = Arrays.copyOf(nonZero, nnz);
                scoringInfo = dinfo.filterExpandedColumns(keep);
                double[] compact = MemoryManager.malloc8d(nnz + 1);
                for (int k = 0; k < nnz; k++) {
                    compact[k] = beta[keep[k]];
                }
                compact[nnz] = beta[nPredictors];
                beta = compact;
            }
            this._beta_multinomial = null;
            this._beta = beta;
        }
        this._dinfo = scoringInfo;
        this._dinfo._valid = true;
        this._defaultThreshold = model.defaultThreshold();
    }

    /** Returns the model's GLM parameters. */
    private GLMModel.GLMParameters parms() {
        return (GLMModel.GLMParameters) this._m._parms;
    }

    /** Returns the model's GLM family. */
    private GLMModel.GLMParameters.Family family() {
        return parms()._family;
    }

    /** True for the families scored with one coefficient vector per class. */
    private static boolean isMultiClassFamily(GLMModel.GLMParameters.Family f) {
        return f == GLMModel.GLMParameters.Family.multinomial || f == GLMModel.GLMParameters.Family.ordinal;
    }

    /**
     * Overflow-safe logistic function {@code 1 / (1 + exp(-eta))}.
     * The form {@code e / (1 + e)} would yield {@code Inf/Inf = NaN} for large eta.
     */
    private static double logistic(double eta) {
        return 1.0d / (1.0d + Math.exp(-eta));
    }

    /**
     * Scores one row into {@code preds} and returns it.
     *
     * <ul>
     *   <li>Ordinal and multinomial: {@code preds[0]} is the predicted class index and
     *       {@code preds[1..]} holds the class probabilities.</li>
     *   <li>Binomial and quasibinomial: {@code preds} is {label, 1 - p, p}, where the label is 1 when
     *       p is at or above the model's default threshold.</li>
     *   <li>Other families: {@code preds[0]} is the inverse-link of the linear predictor.</li>
     * </ul>
     *
     * @param row    extracted row
     * @param offset offset added to every linear predictor
     * @param preds  output buffer; must hold at least {@code nclasses + 1} entries for multi-class families
     * @return {@code preds}
     */
    public double[] scoreRow(DataInfo.Row row, double offset, double[] preds) {
        GLMModel.GLMParameters.Family family = family();
        if (family == GLMModel.GLMParameters.Family.ordinal) {
            scoreOrdinal(row, offset, preds);
        } else if (family == GLMModel.GLMParameters.Family.multinomial) {
            scoreMultinomial(row, offset, preds);
        } else {
            double mu = parms().linkInv(row.innerProduct(this._beta) + offset);
            if (family == GLMModel.GLMParameters.Family.binomial || family == GLMModel.GLMParameters.Family.quasibinomial) {
                preds[0] = mu >= this._defaultThreshold ? 1.0d : 0.0d;
                preds[1] = 1.0d - mu;
                preds[2] = mu;
            } else {
                preds[0] = mu;
            }
        }
        return preds;
    }

    /**
     * Ordinal scoring. Class probabilities are differences of successive logistic CDF values.
     * The predicted class is the first class whose linear predictor is non-negative, falling back
     * to the last class. If the CDF stops increasing, the remaining mass is clamped so the last
     * class gets only a tiny probability.
     */
    private void scoreOrdinal(DataInfo.Row row, double offset, double[] preds) {
        final double[][] classBetas = this._beta_multinomial;
        final int lastClass = this._nclasses - 1;
        Arrays.fill(preds, 1.0E-10d); // small floor for classes never assigned below
        preds[0] = lastClass;         // default prediction: last class
        double prevCdf = 0.0d;
        for (int c = 0; c < lastClass; c++) {
            double eta = row.innerProduct(classBetas[c]) + offset;
            double cdf = logistic(eta);
            preds[c + 1] = cdf - prevCdf;
            prevCdf = cdf;
            if (eta >= 0.0d) {
                preds[0] = c;
                break;
            }
        }
        // Continue the probability computation after the predicted class.
        for (int c = (int) preds[0] + 1; c < lastClass; c++) {
            double cdf = logistic(row.innerProduct(classBetas[c]) + offset);
            if (cdf <= prevCdf) {
                prevCdf = 0.9999999999d;
                break;
            }
            preds[c + 1] = cdf - prevCdf;
            prevCdf = cdf;
        }
        preds[this._nclasses] = 1.0d - prevCdf;
    }

    /**
     * Multinomial scoring with a max-shifted softmax.
     * The shift starts at -Infinity so it is the true maximum eta. Seeding it with 0 underflowed
     * every term to 0, giving NaN probabilities, when all etas were strongly negative.
     */
    private void scoreMultinomial(DataInfo.Row row, double offset, double[] preds) {
        final double[] eta = this._eta;
        final double[][] classBetas = this._beta_multinomial;
        double maxEta = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < classBetas.length; c++) {
            eta[c] = row.innerProduct(classBetas[c]) + offset;
            if (eta[c] > maxEta) {
                maxEta = eta[c];
            }
        }
        double sum = 0.0d;
        for (int c = 0; c < classBetas.length; c++) {
            eta[c] = Math.exp(eta[c] - maxEta);
            sum += eta[c];
        }
        double invSum = 1.0d / sum;
        for (int c = 0; c < classBetas.length; c++) {
            preds[c + 1] = eta[c] * invSum;
        }
        preds[0] = ArrayUtils.maxIndex(eta);
    }

    /**
     * Scores one row.
     *
     * <p>The row is scored and optionally folded into the metric builder. When prediction output
     * is enabled, the first {@code nPredCols} entries of {@code preds} are appended to the output
     * chunks, plus the optional vcov column. Rows with bad predictors yield NaN predictions;
     * rows with zero weight yield zeros and are not added to the metrics.
     */
    private void processRow(DataInfo.Row row, float[] response, double[] preds, NewChunk[] outputs, int nPredCols) {
        if (this._dinfo._responses != 0) {
            response[0] = (float) row.response[0];
        }
        if (row.predictors_bad) {
            Arrays.fill(preds, Double.NaN);
        } else if (row.weight == 0.0d) {
            Arrays.fill(preds, 0.0d);
        } else {
            scoreRow(row, row.offset, preds);
            if (this._computeMetrics && !row.response_bad) {
                this._mb.perRow(preds, response, row.weight, row.offset, this._m);
            }
        }
        if (!this._generatePredictions) {
            return;
        }
        for (int c = 0; c < nPredCols; c++) {
            outputs[c].addNum(preds[c]);
        }
        if (this._vcov != null) {
            // NOTE(review): _vcov comes from the model output while _dinfo may have been filtered to
            // non-zero columns in the constructor -- confirm the dimensions always agree.
            outputs[nPredCols].addNum(Math.sqrt(row.innerProduct(row.mtrxMul(this._vcov, this._tmp))));
        }
    }

    /** Scores every row of the chunk set; does nothing if the task is cancelled or a stop was requested. */
    public void map(Chunk[] chunks, NewChunk[] outputs) {
        if (isCancelled()) {
            return;
        }
        if (this._j != null && this._j.stop_requested()) {
            return;
        }
        if (isMultiClassFamily(family())) {
            this._eta = MemoryManager.malloc8d(this._nclasses);
        }
        this._vcov = ((GLMModel.GLMOutput) this._m._output)._vcov;
        if (this._generatePredictions && this._vcov != null) {
            this._tmp = MemoryManager.malloc8d(this._vcov.length);
        }
        double[] preds;
        if (this._computeMetrics) {
            this._mb = this._m.makeMetricBuilder(this._domain);
            preds = this._mb._work;
        } else {
            // Sized from the same class count that scoreRow indexes (up to preds[_nclasses]).
            preds = new double[this._nclasses + 1];
        }
        float[] response = new float[1];
        // Regression emits a single column; classifiers emit the label plus one column per class.
        int nPredCols = this._nclasses == 1 ? 1 : this._nclasses + 1;
        if (this._sparse) {
            for (DataInfo.Row row : this._dinfo.extractSparseRows(chunks)) {
                processRow(row, response, preds, outputs, nPredCols);
            }
        } else {
            DataInfo.Row row = this._dinfo.newDenseRow();
            int len = chunks[0]._len;
            for (int r = 0; r < len; r++) {
                this._dinfo.extractDenseRow(chunks, r, row);
                processRow(row, response, preds, outputs, nPredCols);
            }
        }
        if (this._j != null) {
            this._j.update(1L);
        }
    }

    /**
     * Merges another task's metric builder into this one.
     * Either side may be null, for example when a map call returned early, so a non-null builder
     * is never dropped and a null one is never handed to {@code MetricBuilder.reduce}.
     */
    public void reduce(GLMScore gLMScore) {
        if (gLMScore._mb == null) {
            return;
        }
        if (this._mb == null) {
            this._mb = gLMScore._mb;
        } else {
            this._mb.reduce(gLMScore._mb);
        }
    }

    /** Finalizes the accumulated metrics, if any were computed. */
    protected void postGlobal() {
        if (this._mb != null) {
            this._mb.postGlobal();
        }
    }
}
