package hex.tree.xgboost.predict;

import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

/* loaded from: input_file:hex/tree/xgboost/predict/PredictTreeSHAPSortingTask.class */
/**
 * Computes TreeSHAP contributions for an XGBoost model and, for every row, emits a selection of
 * (feature id, contribution) pairs in the order chosen by {@link ContributionComposer}, followed
 * by one trailing value.
 *
 * <p>Per-row output layout: {@code nc[2k]} receives the k-th selected feature id and
 * {@code nc[2k + 1]} that feature's contribution. The last {@code NewChunk} receives the last
 * element of the contributions array (the extra slot allocated beyond {@code fullN()},
 * presumably the bias term).
 *
 * <p>The selection is controlled by the top-N / bottom-N / absolute-comparison options. In the
 * compact output format the ranking is done over the aggregated contributions produced by
 * {@code handleOutputFormat}; otherwise it is done over the expanded contributions directly.
 */
public class PredictTreeSHAPSortingTask extends PredictTreeSHAPTask {
    /** True when the requested output format is {@code Compact} (aggregated contributions). */
    private final boolean _outputAggregated;
    /** Number of highest contributions to keep, as passed to {@link ContributionComposer}. */
    private final int _topN;
    /** Number of lowest contributions to keep, as passed to {@link ContributionComposer}. */
    private final int _bottomN;
    /** Whether contributions are compared by absolute value when ranking. */
    private final boolean _compareAbs;

    /**
     * @param dataInfo  data layout used to build the one-hot encoded row view
     * @param modelInfo XGBoost model info, forwarded to the parent task
     * @param output    XGBoost model output, forwarded to the parent task
     * @param options   contribution options; the output format, top-N, bottom-N and
     *                  compare-abs settings are read from it
     */
    public PredictTreeSHAPSortingTask(DataInfo dataInfo, XGBoostModelInfo modelInfo, XGBoostOutput output,
                                      Model.Contributions.ContributionsOptions options) {
        super(dataInfo, modelInfo, output, options);
        // Constant-first equals so a null output format simply means "not compact".
        this._outputAggregated = Model.Contributions.ContributionsOutputFormat.Compact.equals(options._outputFormat);
        this._topN = options._topN;
        this._bottomN = options._bottomN;
        this._compareAbs = options._compareAbs;
    }

    /**
     * Fills the row input via the parent implementation and resets the feature-id array to the
     * identity permutation {@code 0, 1, ..., contribNameIds.length - 1}.
     *
     * @param chks           input chunks
     * @param row            row index within the chunks
     * @param input          buffer receiving the row's raw values
     * @param contribs       contributions buffer, forwarded to the parent implementation
     * @param contribNameIds buffer of feature ids, overwritten with the identity permutation
     */
    protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribNameIds) {
        super.fillInput(chks, row, input, contribs);
        for (int id = 0; id < contribNameIds.length; id++) {
            contribNameIds[id] = id;
        }
    }

    /**
     * Computes contributions for every row of the chunk and appends the sorted selection to
     * {@code nc}.
     */
    @Override
    public void map(Chunk[] chks, NewChunk[] nc) {
        MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(this._di, this._output._sparse);
        double[] input = MemoryManager.malloc8d(chks.length);
        // Contributions over the expanded feature space plus one trailing slot.
        float[] contribs = MemoryManager.malloc4f(this._di.fullN() + 1);
        // The array that is actually ranked: a separate aggregated buffer in compact mode,
        // otherwise the expanded contributions themselves (same object).
        float[] ranked = this._outputAggregated ? MemoryManager.malloc4f(chks.length) : contribs;
        // Feature ids are positions in `ranked`, so the id buffer must match its length.
        int[] contribNameIds = MemoryManager.malloc4(ranked.length);
        TreeSHAPPredictor.Workspace workspace = this._mojo.makeContributionsWorkspace();

        for (int row = 0; row < chks[0]._len; row++) {
            fillInput(chks, row, input, contribs, contribNameIds);
            rowFVec.setInput(input);
            this._mojo.calculateContributions(rowFVec, contribs, workspace);
            handleOutputFormat(rowFVec, contribs, ranked);

            int[] sortedIds = new ContributionComposer().composeContributions(
                    contribNameIds, ranked, this._topN, this._bottomN, this._compareAbs);

            // The sorted ids index `ranked`, so the per-feature values must be read from `ranked`.
            // Reading them from `contribs` (as the previous code did) paired aggregated ids with
            // expanded values in compact mode. The trailing slot is still taken from `contribs`,
            // exactly as before.
            addContribToNewChunk(ranked, sortedIds, contribs[contribs.length - 1], nc);
        }
    }

    /**
     * Appends (id, value) pairs and the trailing value of {@code contribs} to {@code nc}, with
     * both the ids and the trailing value referring to the same array.
     *
     * @param contribs       values indexed by the ids; its last element is emitted last
     * @param contribNameIds selected feature ids, in output order
     * @param nc             output chunks: pairs of (id, value) followed by one trailing column
     */
    protected void addContribToNewChunk(float[] contribs, int[] contribNameIds, NewChunk[] nc) {
        addContribToNewChunk(contribs, contribNameIds, contribs[contribs.length - 1], nc);
    }

    /**
     * Appends (id, value) pairs followed by {@code lastValue} to {@code nc}.
     *
     * <p>Even-indexed chunks receive feature ids and the following odd-indexed chunk receives
     * {@code values[id]}. The final chunk receives {@code lastValue}. Assumes {@code sortedIds}
     * holds at least {@code (nc.length - 1) / 2} entries, i.e. the composer's selection matches
     * the output schema — TODO confirm against the caller that builds {@code nc}.
     *
     * @param values    values indexed by the ids in {@code sortedIds}
     * @param sortedIds selected feature ids, in output order
     * @param lastValue value written to the last output chunk
     * @param nc        output chunks
     */
    protected void addContribToNewChunk(float[] values, int[] sortedIds, float lastValue, NewChunk[] nc) {
        for (int col = 0, k = 0; col < nc.length - 1; col += 2, k++) {
            int featureId = sortedIds[k];
            nc[col].addNum(featureId);
            nc[col + 1].addNum(values[featureId]);
        }
        nc[nc.length - 1].addNum(lastValue);
    }
}
