package hex.tree.xgboost.predict;

import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.predict.XGBoostPredict;
import hex.tree.xgboost.util.BoosterHelper;
import java.util.HashMap;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:hex/tree/xgboost/predict/XGBoostNativeBigScoreChunkPredict.class */
public class XGBoostNativeBigScoreChunkPredict implements XGBoostPredictContrib, Model.BigScoreChunkPredict {
    /** Logged when Rabit cannot be shut down after scoring a chunk. */
    private static final String RABIT_SHUTDOWN_FAILED_MSG =
            "Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.";

    private final double _threshold;
    private final int _responseIndex;
    private final int _offsetIndex;
    private final XGBoostModelInfo _modelInfo;
    private final XGBoostModel.XGBoostParameters _parms;
    private final DataInfo _dataInfo;
    private final BoosterParms _boosterParms;
    private final XGBoostOutput _output;
    /** Raw per-row predictions for the chunk given to the constructor; {@code null} if that chunk had no rows. */
    private final float[][] _preds;

    /**
     * Creates a predictor for one chunk and eagerly scores it with the native booster.
     *
     * @param xGBoostModelInfo   model info used to deserialize the native booster
     * @param xGBoostParameters  model parameters (response/offset column names, booster type, ntrees)
     * @param dataInfo           data layout used to convert chunks into a {@link DMatrix}
     * @param boosterParms       parameters applied to the deserialized booster
     * @param d                  threshold forwarded to {@link XGBoostMojoModel#toPreds}
     * @param xGBoostOutput      model output (number of classes, sparsity flag)
     * @param frame              frame used to look up the response and offset column indices
     * @param chunkArr           the chunks to score
     */
    public XGBoostNativeBigScoreChunkPredict(XGBoostModelInfo xGBoostModelInfo, XGBoostModel.XGBoostParameters xGBoostParameters, DataInfo dataInfo, BoosterParms boosterParms, double d, XGBoostOutput xGBoostOutput, Frame frame, Chunk[] chunkArr) {
        this._modelInfo = xGBoostModelInfo;
        this._parms = xGBoostParameters;
        this._dataInfo = dataInfo;
        this._boosterParms = boosterParms;
        this._threshold = d;
        this._output = xGBoostOutput;
        this._responseIndex = frame.find(this._parms._response_column);
        this._offsetIndex = frame.find(this._parms._offset_column);
        this._preds = scoreChunk(chunkArr, XGBoostPredict.OutputType.PREDICT);
    }

    /**
     * Produces the final prediction for one row, using the raw scores computed in the constructor.
     *
     * @param chunkArr chunks of the row being scored; the first {@code dArr.length} columns are copied into {@code dArr}
     * @param d        offset (unused here: any offset column is read when the chunk is converted to a DMatrix)
     * @param i        row index within the chunk
     * @param dArr     scratch buffer that receives the row's raw values
     * @param dArr2    output buffer passed to {@link XGBoostMojoModel#toPreds}
     * @return the array returned by {@link XGBoostMojoModel#toPreds}
     */
    @Override
    public double[] score0(Chunk[] chunkArr, double d, int i, double[] dArr, double[] dArr2) {
        for (int i2 = 0; i2 < dArr.length; i2++) {
            dArr[i2] = chunkArr[i2].atd(i);
        }
        return XGBoostMojoModel.toPreds(dArr, this._preds[i], dArr2, this._output.nclasses(), (double[]) null, this._threshold);
    }

    /** Computes approximate per-feature contributions for the given chunk ({@code null} if it has no rows). */
    @Override // hex.tree.xgboost.predict.XGBoostPredictContrib
    public float[][] predictContrib(Chunk[] chunkArr) {
        return scoreChunk(chunkArr, XGBoostPredict.OutputType.PREDICT_CONTRIB_APPROX);
    }

    /** Computes raw booster predictions for the given chunk ({@code null} if it has no rows). */
    @Override // hex.tree.xgboost.predict.XGBoostPredict
    public float[][] predict(Chunk[] chunkArr) {
        return scoreChunk(chunkArr, XGBoostPredict.OutputType.PREDICT);
    }

    /**
     * Converts {@code chunkArr} to a DMatrix and runs the native booster on it.
     *
     * <p>The booster and the DMatrix are always disposed, and Rabit is always shut down,
     * including when scoring fails.
     *
     * @param chunkArr   chunks to score
     * @param outputType which kind of output to request from the booster
     * @return the booster output; {@code null} when the chunk has no rows, or an empty
     *         array when the booster returns {@code null}
     * @throws IllegalStateException         wrapping any {@link XGBoostError}
     * @throws UnsupportedOperationException for output types other than PREDICT and PREDICT_CONTRIB_APPROX
     */
    private float[][] scoreChunk(Chunk[] chunkArr, XGBoostPredict.OutputType outputType) {
        DMatrix data = null;
        Booster booster = null;
        try {
            // NOTE(review): Rabit is initialized around scoring, presumably because parts of the
            // native predict path depend on it — confirm against the XGBoost4J version in use.
            Rabit.init(new HashMap<>());
            data = XGBoostUtils.convertChunksToDMatrix(this._dataInfo, chunkArr, this._responseIndex, this._output._sparse, this._offsetIndex);
            if (data.rowNum() == 0) {
                return null;
            }
            booster = this._modelInfo.deserializeBooster();
            booster.setParams(this._boosterParms.get());
            // Only DART boosters get an explicit tree limit; 0 is passed otherwise.
            final int treeLimit = this._parms._booster == XGBoostModel.XGBoostParameters.Booster.dart
                    ? this._parms._ntrees
                    : 0;
            final float[][] preds;
            switch (outputType) {
                case PREDICT:
                    preds = booster.predict(data, false, treeLimit);
                    break;
                case PREDICT_CONTRIB_APPROX:
                    preds = booster.predictContrib(data, treeLimit);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported output type: " + outputType);
            }
            return preds == null ? new float[0][] : preds;
        } catch (XGBoostError e) {
            throw new IllegalStateException("Failed to score with XGBoost.", e);
        } finally {
            // Release native resources on every path (dispose is called with null handles elsewhere,
            // so it tolerates whichever of the two was never created).
            BoosterHelper.dispose(booster, data);
            shutdownRabit();
        }
    }

    /** Shuts Rabit down, logging rather than propagating any failure. */
    private static void shutdownRabit() {
        try {
            Rabit.shutdown();
        } catch (XGBoostError e) {
            Log.err(new Object[]{RABIT_SHUTDOWN_FAILED_MSG, e});
        }
    }

    /** Nothing to release: native resources are freed inside each scoring call. */
    @Override
    public void close() {
    }
}
