package hex.tree.xgboost.task;

import hex.CMetricScoringTask;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.GenModel;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.XGBoostBigScorePredict;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.udf.CFuncRef;

/* loaded from: input_file:hex/tree/xgboost/task/XGBoostScoreTask.class */
/**
 * Scoring task that runs an {@link XGBoostModel} over the chunks of a frame. It writes the
 * predictions into new chunks and accumulates model metrics, plus the optional custom metric
 * handled by {@link CMetricScoringTask}.
 *
 * <p>Output columns written by {@link #map}:
 * <ul>
 *   <li>regression ({@code nclasses() == 1}): one column holding the raw prediction;</li>
 *   <li>classification: column 0 holds the predicted label, and the remaining columns hold
 *       the per-class probabilities.</li>
 * </ul>
 */
public class XGBoostScoreTask extends CMetricScoringTask<XGBoostScoreTask> {
    private final XGBoostOutput _output;
    /** Index of the weights column within the mapped chunks, or -1 when rows are unweighted. */
    private final int _weightsChunkId;
    private final XGBoostModel _model;
    /** Passed to {@link XGBoostModel#setupBigScorePredict} when the predictor is created. */
    private final boolean _isTrain;
    /** Label threshold for binomial predictions, taken from the model's default threshold. */
    private final double _threshold;
    /** Metrics accumulated by map/reduce; may be null if no chunk was mapped. */
    public ModelMetrics.MetricBuilder _metricBuilder;
    /** Predictor created in {@link #setupLocal}; transient, so it is never serialized. */
    private transient XGBoostBigScorePredict _predict;

    /**
     * @param output           model output providing class count, names, response index and priors
     * @param weightsChunkId   index of the weights column in the mapped chunks, or -1 for none
     * @param isTrain          forwarded to {@link XGBoostModel#setupBigScorePredict}
     * @param model            the model used for scoring and passed to the metric builders
     * @param customMetricFunc optional custom metric function reference (handled by the superclass)
     */
    public XGBoostScoreTask(XGBoostOutput output, int weightsChunkId, boolean isTrain,
                            XGBoostModel model, CFuncRef customMetricFunc) {
        super(customMetricFunc);
        this._output = output;
        this._weightsChunkId = weightsChunkId;
        this._model = model;
        this._isTrain = isTrain;
        this._threshold = model.defaultThreshold();
    }

    /**
     * Creates the metric builder matching the number of classes: regression for 1, binomial for 2,
     * multinomial otherwise (using the model's configured AUC type).
     */
    private ModelMetrics.MetricBuilder createMetricsBuilder(int nclasses, String[] classNames) {
        switch (nclasses) {
            case 1:
                return new ModelMetricsRegression.MetricBuilderRegression();
            case 2:
                return new ModelMetricsBinomial.MetricBuilderBinomial(classNames);
            default:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(
                        nclasses, classNames, ((XGBoostModel.XGBoostParameters) _model._parms)._auc_type);
        }
    }

    /** Creates the predictor used by {@link #map}. */
    protected void setupLocal() {
        super.setupLocal();
        _predict = _model.setupBigScorePredict(_isTrain);
    }

    /**
     * Scores one set of chunks. It writes the predictions to {@code outputs} and feeds every row
     * to the metric builder and to the custom metric.
     *
     * <p>A fresh metric builder is created on every call. If the predictor returns no rows,
     * nothing is written.
     */
    public void map(Chunk[] chunks, NewChunk[] outputs) {
        final int nclasses = _output.nclasses();
        _metricBuilder = createMetricsBuilder(nclasses, _output.classNames());
        final float[][] preds = _predict.mo42initMap(_fr, chunks).predict(chunks);
        if (preds.length == 0) {
            return;
        }
        assert preds.length == chunks[0]._len
                : "expected " + chunks[0]._len + " prediction rows, got " + preds.length;
        final Chunk response = chunks[_output.responseIdx()];
        if (nclasses == 1) {
            scoreRegression(preds, chunks, response, outputs);
        } else if (nclasses == 2) {
            scoreBinomial(preds, chunks, response, outputs);
        } else {
            scoreMultinomial(preds, chunks, response, outputs);
        }
    }

    /** Regression: writes {@code preds[row][0]} to column 0 and records per-row metrics. */
    private void scoreRegression(float[][] preds, Chunk[] chunks, Chunk response, NewChunk[] outputs) {
        final int rows = chunks[0]._len;
        final double[] rowPreds = new double[1];
        final float[] actual = new float[1];
        for (int row = 0; row < rows; row++) {
            rowPreds[0] = preds[row][0];
            actual[0] = (float) response.atd(row);
            final double weight = weightAt(chunks, row);
            _metricBuilder.perRow(rowPreds, actual, weight, 0.0d, _model);
            customMetricPerRow(rowPreds, actual, weight, 0.0d, _model);
            outputs[0].addNum(preds[row][0]);
        }
    }

    /**
     * Binomial: treats {@code preds[row][0]} as the probability of class index 1. It writes
     * {label, 1 - p, p} to columns 0..2 before recording per-row metrics.
     */
    private void scoreBinomial(float[][] preds, Chunk[] chunks, Chunk response, NewChunk[] outputs) {
        final int rows = chunks[0]._len;
        final double[] rowPreds = new double[3];
        final float[] actual = new float[1];
        for (int row = 0; row < rows; row++) {
            final double p1 = preds[row][0];
            rowPreds[1] = 1.0d - p1;
            rowPreds[2] = p1;
            rowPreds[0] = GenModel.getPrediction(rowPreds, _output._priorClassDist, (double[]) null, _threshold);
            outputs[0].addNum(rowPreds[0]);
            outputs[1].addNum(rowPreds[1]);
            outputs[2].addNum(rowPreds[2]);
            final double weight = weightAt(chunks, row);
            actual[0] = (float) response.atd(row);
            _metricBuilder.perRow(rowPreds, actual, weight, 0.0d, _model);
            customMetricPerRow(rowPreds, actual, weight, 0.0d, _model);
        }
    }

    /**
     * Multinomial: copies {@code preds[row][c - 1]} into output column {@code c} for
     * {@code c >= 1}. It then writes the label derived from those probabilities to column 0 and
     * records per-row metrics.
     */
    private void scoreMultinomial(float[][] preds, Chunk[] chunks, Chunk response, NewChunk[] outputs) {
        final int rows = chunks[0]._len;
        final float[] actual = new float[1];
        final double[] rowPreds = MemoryManager.malloc8d(outputs.length);
        for (int row = 0; row < rows; row++) {
            for (int col = 1; col < rowPreds.length; col++) {
                final double prob = preds[row][col - 1];
                outputs[col].addNum(prob);
                rowPreds[col] = prob;
            }
            rowPreds[0] = GenModel.getPrediction(rowPreds, _output._priorClassDist, (double[]) null, _threshold);
            outputs[0].addNum(rowPreds[0]);
            actual[0] = (float) response.atd(row);
            final double weight = weightAt(chunks, row);
            _metricBuilder.perRow(rowPreds, actual, weight, 0.0d, _model);
            customMetricPerRow(rowPreds, actual, weight, 0.0d, _model);
        }
    }

    /** Returns the weight of {@code row}, or 1.0 when the task has no weights column. */
    private double weightAt(Chunk[] chunks, int row) {
        return _weightsChunkId != -1 ? chunks[_weightsChunkId].atd(row) : 1.0d;
    }

    /**
     * Merges another task's metrics into this one. A side whose metric builder was never created
     * (null) is tolerated instead of causing a NullPointerException, matching the null check in
     * {@link #postGlobal}.
     */
    public void reduce(XGBoostScoreTask other) {
        super.reduce(other);
        if (_metricBuilder == null) {
            _metricBuilder = other._metricBuilder;
        } else if (other._metricBuilder != null) {
            _metricBuilder.reduce(other._metricBuilder);
        }
    }

    /**
     * Finalizes the accumulated metrics with the computed custom metric. When a custom metric
     * function is configured, it also links this task back into the metric builder.
     */
    protected void postGlobal() {
        super.postGlobal();
        if (_metricBuilder == null) {
            return;
        }
        _metricBuilder.postGlobal(getComputedCustomMetric());
        if (null != this.cFuncRef) {
            _metricBuilder._CMetricScoringTask = this;
        }
    }
}
