package hex.tree.xgboost.predict;

import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import java.util.Arrays;
import org.apache.lucene.util.packed.PackedInts;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

/* loaded from: input_file:hex/tree/xgboost/predict/PredictTreeSHAPTask.class */
public class PredictTreeSHAPTask extends MRTask<PredictTreeSHAPTask> {
    protected final DataInfo _di;
    protected final XGBoostModelInfo _modelInfo;
    protected final XGBoostOutput _output;
    /** True when the requested output format is {@code Compact}; see {@link #handleOutputFormat}. */
    protected final boolean _outputAggregated;
    /** Per-node scorer, rebuilt in {@link #setupLocal()}; transient so it is not serialized with the task. */
    protected transient XGBoostJavaMojoModel _mojo;

    /**
     * Creates a task that writes per-row TreeSHAP contributions into the task's output chunks.
     *
     * @param di         data layout used to one-hot encode the input columns
     * @param modelInfo  supplies the booster bytes and auxiliary node weights for the scorer
     * @param output     supplies column names/domains, response name and the sparse flag
     * @param options    contributions options; only {@code _outputFormat} is read here
     */
    public PredictTreeSHAPTask(DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output,
                               Model.Contributions.ContributionsOptions options) {
        _di = di;
        _modelInfo = modelInfo;
        _output = output;
        _outputAggregated = Model.Contributions.ContributionsOutputFormat.Compact.equals(options._outputFormat);
    }

    /**
     * Builds the node-local {@link XGBoostJavaMojoModel} from the serialized booster before any
     * {@link #map} call runs on this node.
     */
    @Override
    public void setupLocal() {
        _mojo = new XGBoostJavaMojoModel(_modelInfo._boosterBytes, _modelInfo.auxNodeWeightBytes(),
                _output._names, _output._domains, _output.responseName(), true);
    }

    /**
     * Copies one row of the input chunks into {@code input} and zeroes the contributions buffer.
     *
     * @param cs       input chunks, one per column
     * @param row      chunk-relative row index
     * @param input    destination for the row's values; must hold at least {@code cs.length} entries
     * @param contribs contributions buffer, reset to zero for the new row
     */
    public void fillInput(Chunk[] cs, int row, double[] input, float[] contribs) {
        for (int col = 0; col < cs.length; col++) {
            input[col] = cs[col].atd(row);
        }
        // The buffer is cleared per row (presumably the SHAP calculation accumulates into it).
        // Use the literal 0f: the decompiled version referenced Lucene's PackedInts.COMPACT,
        // an unrelated constant that merely happens to equal 0f.
        Arrays.fill(contribs, 0f);
    }

    /**
     * Computes contributions for every row of the chunk and appends one value per output column.
     *
     * <p>In non-compact mode the output buffer is the same array as the raw contributions
     * buffer, so no copying happens between calculation and output.
     */
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
        final MutableOneHotEncoderFVec encodedRow = new MutableOneHotEncoderFVec(_di, _output._sparse);
        final double[] input = MemoryManager.malloc8d(cs.length);
        // One slot per one-hot-encoded feature plus one trailing slot.
        final float[] contribs = MemoryManager.malloc4f(_di.fullN() + 1);
        final float[] output = _outputAggregated ? MemoryManager.malloc4f(ncs.length) : contribs;
        final TreeSHAPPredictor.Workspace workspace = _mojo.makeContributionsWorkspace();
        final int rows = cs[0]._len;
        for (int row = 0; row < rows; row++) {
            fillInput(cs, row, input, contribs);
            encodedRow.setInput(input);
            _mojo.calculateContributions(encodedRow, contribs, workspace);
            handleOutputFormat(encodedRow, contribs, output);
            addContribToNewChunk(output, ncs);
        }
    }

    /**
     * In compact mode, folds the per-encoded-feature contributions back onto the original columns
     * and copies the trailing element (presumably the bias term — confirm) unchanged.
     * In non-compact mode this is a no-op: {@code output} already aliases {@code contribs}.
     *
     * @param encodedRow the encoder holding the current row, used for de-aggregation
     * @param contribs   raw contributions, one per encoded feature plus a trailing element
     * @param output     destination buffer for the aggregated contributions
     */
    public void handleOutputFormat(MutableOneHotEncoderFVec encodedRow, float[] contribs, float[] output) {
        if (_outputAggregated) {
            encodedRow.decodeAggregate(contribs, output);
            output[output.length - 1] = contribs[contribs.length - 1];
        }
    }

    /**
     * Appends {@code contribs[i]} to {@code ncs[i]} for every output chunk.
     *
     * @param contribs values for the current row; must hold at least {@code ncs.length} entries
     * @param ncs      output chunks, one per contribution column
     */
    protected void addContribToNewChunk(float[] contribs, NewChunk[] ncs) {
        for (int i = 0; i < ncs.length; i++) {
            ncs[i].addNum(contribs[i]);
        }
    }
}
