package hex.tree.xgboost.predict;

import hex.ContributionsWithBackgroundFrameTask;
import hex.DataInfo;
import hex.Distribution;
import hex.DistributionFactory;
import hex.Model;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import java.util.Arrays;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;

/**
 * Computes interventional (background-frame based) TreeSHAP contributions for an XGBoost model.
 *
 * <p>For every row of the explained frame and every row of the background frame, one row of
 * contributions is emitted, computed by
 * {@code XGBoostJavaMojoModel.calculateInterventionalContributions}. Contributions are produced in
 * link space and, when output-space reporting is requested, rescaled through the model's
 * distribution link inverse.
 */
public class PredictTreeSHAPWithBackgroundTask extends ContributionsWithBackgroundFrameTask<PredictTreeSHAPWithBackgroundTask> {
    /** Link-space differences smaller than this are treated as zero when rescaling to output space. */
    private static final double LINK_SPACE_EPSILON = 1.0E-6d;

    protected final DataInfo _di;
    protected final XGBoostModelInfo _modelInfo;
    protected final XGBoostOutput _output;
    /** True when the compact output format was requested (categorical contributions aggregated). */
    protected final boolean _outputAggregated;
    /** True when contributions should be rescaled from link space into output space. */
    protected final boolean _outputSpace;
    /** Distribution used for the link inverse; {@code null} unless {@link #_outputSpace} is set. */
    protected final Distribution _distribution;
    /** Per-node scorer, rebuilt locally in {@link #setupLocal()} (not serialized with the task). */
    protected transient XGBoostJavaMojoModel _mojo;

    /**
     * Creates the task.
     *
     * @param dataInfo        data layout used to one-hot encode input rows
     * @param modelInfo       booster bytes and training parameters of the model
     * @param output          model output metadata (names, domains, sparsity, response)
     * @param options         contribution options; only {@code _outputFormat} is read here
     * @param frame           frame whose rows are explained
     * @param backgroundFrame frame whose rows serve as the interventional background
     * @param perReference    forwarded unchanged to {@link ContributionsWithBackgroundFrameTask};
     *                        its meaning is defined by the base class
     * @param outputSpace     whether to rescale contributions into output space
     */
    public PredictTreeSHAPWithBackgroundTask(DataInfo dataInfo, XGBoostModelInfo modelInfo, XGBoostOutput output,
                                             Model.Contributions.ContributionsOptions options,
                                             Frame frame, Frame backgroundFrame,
                                             boolean perReference, boolean outputSpace) {
        super(frame._key, backgroundFrame._key, perReference);
        this._di = dataInfo;
        this._modelInfo = modelInfo;
        this._output = output;
        this._outputAggregated = Model.Contributions.ContributionsOutputFormat.Compact.equals(options._outputFormat);
        this._outputSpace = outputSpace;
        this._distribution = outputSpace ? resolveDistribution(modelInfo, output) : null;
    }

    /**
     * Picks the distribution whose link inverse maps contributions into output space.
     * An AUTO distribution on a binomial classifier resolves to bernoulli; in all other cases the
     * distribution is derived from the model parameters.
     */
    private static Distribution resolveDistribution(XGBoostModelInfo modelInfo, XGBoostOutput output) {
        boolean autoBinomial = modelInfo._parameters.getDistributionFamily().equals(DistributionFamily.AUTO)
                && output.isBinomialClassifier();
        return autoBinomial
                ? DistributionFactory.getDistribution(DistributionFamily.bernoulli)
                : DistributionFactory.getDistribution(modelInfo._parameters);
    }

    /** Builds the node-local MOJO scorer from the serialized booster and auxiliary node weights. */
    protected void setupLocal() {
        this._mojo = new XGBoostJavaMojoModel(this._modelInfo._boosterBytes, this._modelInfo.auxNodeWeightBytes(),
                this._output._names, this._output._domains, this._output.responseName(), true);
    }

    /**
     * Copies row {@code row} of every chunk into {@code input} (via {@code Chunk.atd}).
     *
     * @param chunks one chunk per column
     * @param row    chunk-relative row index
     * @param input  destination; must hold at least {@code chunks.length} values
     */
    protected void fillInput(Chunk[] chunks, int row, double[] input) {
        for (int col = 0; col < chunks.length; col++) {
            input[col] = chunks[col].atd(row);
        }
    }

    /**
     * Appends one row of contributions to the output chunks.
     *
     * <p>The last element of {@code contribs} is treated as the bias/baseline term and is written to
     * the last output chunk. The first {@code ncs.length - 1} elements are written to the remaining
     * chunks. When {@link #_outputSpace} is set, every contribution is multiplied by
     * {@code (linkInv(total) - linkInv(bias)) / (total - bias)}, where {@code total} is the sum of
     * all elements. The bias is replaced by {@code linkInv(bias)}, so the emitted values still sum to
     * {@code linkInv(total)} (assuming {@code ncs.length == contribs.length}).
     *
     * @param contribs link-space contributions followed by the bias term
     * @param ncs      output chunks, the last one receiving the bias term
     */
    protected void addContribToNewChunk(double[] contribs, NewChunk[] ncs) {
        double scale = 1.0d;
        double biasTerm = contribs[contribs.length - 1];
        if (this._outputSpace) {
            double linkSpaceTotal = Arrays.stream(contribs).sum();
            double outSpaceTotal = this._distribution.linkInv(linkSpaceTotal);
            double outSpaceBias = this._distribution.linkInv(biasTerm);
            double linkSpaceDelta = linkSpaceTotal - biasTerm;
            // Avoid dividing by a (near) zero link-space delta; contributions are then reported as 0.
            scale = Math.abs(linkSpaceDelta) < LINK_SPACE_EPSILON ? 0.0d : (outSpaceTotal - outSpaceBias) / linkSpaceDelta;
            biasTerm = outSpaceBias;
        }
        for (int i = 0; i < ncs.length - 1; i++) {
            ncs[i].addNum(contribs[i] * scale);
        }
        ncs[ncs.length - 1].addNum(biasTerm);
    }

    /**
     * Emits one contribution row for every (row of {@code cs}, row of {@code bgCs}) pair.
     *
     * @param cs   chunks of the explained frame
     * @param bgCs chunks of the background frame
     * @param ncs  output chunks receiving the contributions
     */
    protected void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] ncs) {
        MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(this._di, this._output._sparse);
        MutableOneHotEncoderFVec bgRowFVec = new MutableOneHotEncoderFVec(this._di, this._output._sparse);
        double[] input = MemoryManager.malloc8d(cs.length);
        // Sized from the background chunks: fillInput(bgCs, ...) writes bgCs.length values into it.
        double[] bgInput = MemoryManager.malloc8d(bgCs.length);
        double[] contribs = MemoryManager.malloc8d(this._outputAggregated ? ncs.length : this._di.fullN() + 1);
        for (int row = 0; row < cs[0]._len; row++) {
            fillInput(cs, row, input);
            rowFVec.setInput(input);
            for (int bgRow = 0; bgRow < bgCs[0]._len; bgRow++) {
                // The scorer accumulates into contribs, so reset it for every pair.
                Arrays.fill(contribs, 0.0d);
                fillInput(bgCs, bgRow, bgInput);
                bgRowFVec.setInput(bgInput);
                this._mojo.calculateInterventionalContributions(rowFVec, bgRowFVec, contribs,
                        this._outputAggregated ? this._di._catOffsets : null, false);
                addContribToNewChunk(contribs, ncs);
            }
        }
    }
}
