package hex.tree;

import hex.ContributionsWithBackgroundFrameTask;
import hex.Distribution;
import hex.DistributionFactory;
import hex.Model;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModel.SharedTreeOutput;
import hex.tree.SharedTreeModel.SharedTreeParameters;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/tree/SharedTreeModelWithContributions.class */
/**
 * Base class for {@link SharedTreeModel}s that can explain their predictions with TreeSHAP feature
 * contributions.
 *
 * <p>Three scoring tasks are provided:
 * <ul>
 *   <li>{@link ScoreContributionsTask}: one output column per input column plus a trailing
 *       {@code BiasTerm} column;</li>
 *   <li>{@link ScoreContributionsSortingTask}: the contributions reduced to the top/bottom N
 *       entries chosen by {@link ContributionComposer}, emitted as (index, value) column pairs
 *       followed by the bias term;</li>
 *   <li>{@link ScoreContributionsWithBackgroundTask}: interventional contributions computed for
 *       every (input row, background row) pair.</li>
 * </ul>
 *
 * <p>Models with more than two classes are rejected with an {@link UnsupportedOperationException}.
 */
public abstract class SharedTreeModelWithContributions<M extends SharedTreeModel<M, P, O>, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput> extends SharedTreeModel<M, P, O> implements Model.Contributions {

    /** Name of the extra output column that holds the bias term. */
    private static final String BIAS_TERM_COLUMN = "BiasTerm";

    private static final String MULTINOMIAL_NOT_SUPPORTED =
            "Calculating contributions is currently not supported for multinomial models.";

    /**
     * MRTask computing TreeSHAP contributions for each row of the input frame.
     *
     * <p>Each output row holds one contribution per input column followed by a final bias-term
     * value; {@link #map(Chunk[], NewChunk[])} asserts there is exactly one more output chunk than
     * input chunks.
     */
    public class ScoreContributionsTask extends MRTask<ScoreContributionsTask> {
        protected final Key<SharedTreeModel> _modelKey;
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;

        /**
         * @param model model whose key is captured; the model itself is re-fetched by key in
         *     {@link #setupLocal()}
         */
        public ScoreContributionsTask(SharedTreeModel model) {
            _modelKey = model._key;
        }

        /** Fetches the model by key and builds the TreeSHAP ensemble over all of its trees. */
        @Override
        public void setupLocal() {
            _model = _modelKey.get();
            assert _model != null;
            _output = (SharedTreeModel.SharedTreeOutput) _model._output;
            assert _output != null;
            _treeSHAP = buildTreeSHAPEnsemble(_model, _output);
        }

        /**
         * Copies row {@code row} of {@code chks} into {@code input} and resets {@code contribs} to
         * zero before the next row's contributions are computed.
         */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs) {
            for (int col = 0; col < chks.length; col++) {
                input[col] = chks[col].atd(row);
            }
            Arrays.fill(contribs, 0.0f);
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            // one contribution per input column, plus the bias term in the last output chunk
            assert chks.length == nc.length - 1;
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(nc.length);
            TreeSHAPPredictor.Workspace workspace = _treeSHAP.makeWorkspace();
            for (int row = 0; row < chks[0]._len; row++) {
                fillInput(chks, row, input, contribs);
                _treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                doModelSpecificComputation(contribs);
                addContribToNewChunk(contribs, nc);
            }
        }

        /** Hook letting subclasses adjust a row's contributions before they are emitted; no-op here. */
        protected void doModelSpecificComputation(float[] contribs) {
        }

        /** Appends {@code contribs[i]} to output chunk {@code i}, for every output chunk. */
        protected void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            for (int i = 0; i < nc.length; i++) {
                nc[i].addNum(contribs[i]);
            }
        }
    }

    /**
     * Variant of {@link ScoreContributionsTask} that keeps only the contributions selected by
     * {@link ContributionComposer} (top/bottom N, optionally compared by absolute value).
     *
     * <p>Output chunks come in pairs (contribution index, contribution value), followed by a final
     * bias-term chunk.
     */
    public class ScoreContributionsSortingTask extends ScoreContributionsTask {
        private final int _topN;
        private final int _bottomN;
        private final boolean _compareAbs;

        public ScoreContributionsSortingTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
            super(model);
            _topN = options._topN;
            _bottomN = options._bottomN;
            _compareAbs = options._compareAbs;
        }

        /**
         * Same as {@link ScoreContributionsTask#fillInput(Chunk[], int, double[], float[])}, and
         * additionally resets {@code contribIds} to the identity sequence {@code 0..length-1}.
         */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribIds) {
            super.fillInput(chks, row, input, contribs);
            for (int i = 0; i < contribIds.length; i++) {
                contribIds[i] = i;
            }
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(chks.length + 1);
            int[] contribIds = MemoryManager.malloc4(chks.length + 1);
            TreeSHAPPredictor.Workspace workspace = _treeSHAP.makeWorkspace();
            // Allocated once per chunk instead of once per row; scoreContributions(...) likewise
            // reuses a single composer instance across calls.
            ContributionComposer composer = new ContributionComposer();
            for (int row = 0; row < chks[0]._len; row++) {
                fillInput(chks, row, input, contribs, contribIds);
                _treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                doModelSpecificComputation(contribs);
                int[] selectedIds = composer.composeContributions(contribIds, contribs, _topN, _bottomN, _compareAbs);
                addContribToNewChunk(contribs, selectedIds, nc);
            }
        }

        /**
         * Emits (index, value) pairs for the selected contributions into consecutive chunk pairs,
         * then the bias term (the last element of {@code contribs}) into the final chunk.
         */
        protected void addContribToNewChunk(float[] contribs, int[] selectedIds, NewChunk[] nc) {
            int pair = 0;
            for (int col = 0; col < nc.length - 1; col += 2) {
                int id = selectedIds[pair++];
                nc[col].addNum(id);
                nc[col + 1].addNum(contribs[id]);
            }
            nc[nc.length - 1].addNum(contribs[contribs.length - 1]);
        }
    }

    /**
     * Computes interventional TreeSHAP contributions of every input row against every background
     * row, emitting one output row per (input row, background row) pair.
     */
    public class ScoreContributionsWithBackgroundTask extends ContributionsWithBackgroundFrameTask<ScoreContributionsWithBackgroundTask> {
        protected final Key<SharedTreeModel> _modelKey;
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;
        protected boolean _expand;
        protected boolean _outputSpace;
        protected int[] _catOffsets;

        /**
         * @param frameKey first frame key, forwarded to the superclass (presumably the frame being
         *     explained — confirm against ContributionsWithBackgroundFrameTask)
         * @param backgroundFrameKey background frame key, forwarded to the superclass
         * @param perReference flag forwarded unchanged to the superclass constructor
         * @param model model whose key is captured; the model is re-fetched in {@link #setupLocal()}
         * @param expand passed to {@code calculateInterventionalContributions}; the enclosing
         *     {@code scoreContributions} requests {@code true} for the expanded output format
         * @param catOffsets offsets passed to {@code calculateInterventionalContributions}
         * @param outputSpace when true, contributions are rescaled from link space to output space
         *     in {@link #addContribToNewChunk(double[], NewChunk[])}
         */
        public ScoreContributionsWithBackgroundTask(Key<Frame> frameKey, Key<Frame> backgroundFrameKey, boolean perReference, SharedTreeModel model, boolean expand, int[] catOffsets, boolean outputSpace) {
            super(frameKey, backgroundFrameKey, perReference);
            _modelKey = model._key;
            _expand = expand;
            _catOffsets = catOffsets;
            _outputSpace = outputSpace;
        }

        /** Fetches the model by key and builds the TreeSHAP ensemble over all of its trees. */
        @Override
        public void setupLocal() {
            _model = _modelKey.get();
            assert _model != null;
            _output = (SharedTreeModel.SharedTreeOutput) _model._output;
            assert _output != null;
            _treeSHAP = buildTreeSHAPEnsemble(_model, _output);
        }

        /** Copies row {@code row} of {@code chks} into {@code values}. */
        protected void fillInput(Chunk[] chks, int row, double[] values) {
            for (int col = 0; col < chks.length; col++) {
                values[col] = chks[col].atd(row);
            }
        }

        @Override
        public void map(Chunk[] chks, Chunk[] bgChks, NewChunk[] nc) {
            assert chks.length <= nc.length - 1;
            double[] input = MemoryManager.malloc8d(chks.length);
            double[] background = MemoryManager.malloc8d(bgChks.length);
            double[] contribs = MemoryManager.malloc8d(nc.length);
            for (int row = 0; row < chks[0]._len; row++) {
                fillInput(chks, row, input);
                for (int bgRow = 0; bgRow < bgChks[0]._len; bgRow++) {
                    Arrays.fill(contribs, 0.0);
                    fillInput(bgChks, bgRow, background);
                    _treeSHAP.calculateInterventionalContributions(input, background, contribs, _catOffsets, _expand);
                    doModelSpecificComputation(contribs);
                    addContribToNewChunk(contribs, nc);
                }
            }
        }

        /** Hook letting subclasses adjust a row's contributions before they are emitted; no-op here. */
        protected void doModelSpecificComputation(double[] contribs) {
        }

        /**
         * Appends all but the last contribution to the leading chunks and the bias term (last
         * element) to the final chunk.
         *
         * <p>When {@code _outputSpace} is set, let {@code x} be the sum of all entries (bias
         * included) and {@code b} the bias term: each contribution is multiplied by
         * {@code (linkInv(x) - linkInv(b)) / (x - b)} (or by 0 when {@code |x - b| < 1e-6}), and
         * the emitted bias term becomes {@code linkInv(b)}.
         */
        protected void addContribToNewChunk(double[] contribs, NewChunk[] nc) {
            double ratio = 1.0;
            double biasTerm = contribs[contribs.length - 1];
            if (_outputSpace) {
                // one distribution lookup per row instead of two
                Distribution distribution = DistributionFactory.getDistribution(SharedTreeModelWithContributions.this._parms);
                double linkSpaceX = Arrays.stream(contribs).sum();
                double outSpaceX = distribution.linkInv(linkSpaceX);
                double outSpaceBias = distribution.linkInv(biasTerm);
                double linkDelta = linkSpaceX - biasTerm;
                ratio = Math.abs(linkDelta) < 1.0E-6 ? 0.0 : (outSpaceX - outSpaceBias) / linkDelta;
                biasTerm = outSpaceBias;
            }
            for (int i = 0; i < nc.length - 1; i++) {
                nc[i].addNum(contribs[i] * ratio);
            }
            nc[nc.length - 1].addNum(biasTerm);
        }
    }

    public SharedTreeModelWithContributions(Key<M> key, P p, O o) {
        super(key, p, o);
    }

    /**
     * Builds a TreeSHAP ensemble with one predictor per non-null tree key of {@code output}.
     * Asserts the number of predictors equals {@code output._ntrees}.
     */
    private static TreeSHAPPredictor<double[]> buildTreeSHAPEnsemble(SharedTreeModel model, SharedTreeModel.SharedTreeOutput output) {
        List<TreeSHAPPredictor<double[]>> predictors = new ArrayList<>(output._ntrees);
        for (int treeIdx = 0; treeIdx < output._ntrees; treeIdx++) {
            for (int treeClass = 0; treeClass < output._treeKeys[treeIdx].length; treeClass++) {
                if (output._treeKeys[treeIdx][treeClass] == null) {
                    continue; // skip empty tree slots
                }
                predictors.add(new TreeSHAP<>(model.getSharedTreeSubgraph(treeIdx, treeClass).getNodes()));
            }
        }
        assert predictors.size() == output._ntrees;
        return new TreeSHAPEnsemble<>(predictors, (float) output._init_f);
    }

    /** Returns a copy of {@code names} with {@code "BiasTerm"} appended. */
    private static String[] appendBiasTerm(String[] names) {
        if (names == null) {
            return new String[]{BIAS_TERM_COLUMN};
        }
        String[] withBias = Arrays.copyOf(names, names.length + 1);
        withBias[names.length] = BIAS_TERM_COLUMN;
        return withBias;
    }

    /** Throws if this model has more than two classes. */
    private void ensureContributionsSupported() {
        if (_output.nclasses() > 2) {
            throw new UnsupportedOperationException(MULTINOMIAL_NOT_SUPPORTED);
        }
    }

    /** True when {@code columnName} is the response, fold, weights or offset column. */
    private boolean isSkippedForContributions(String columnName) {
        return columnName.equals(_parms._response_column)
                || columnName.equals(_parms._fold_column)
                || columnName.equals(_parms._weights_column)
                || columnName.equals(_parms._offset_column);
    }

    /**
     * Computes cumulative column offsets for the expanded output: a numeric column is 1 wide,
     * a categorical column is (levels + 1) wide (the extra slot is the missing level).
     *
     * <p>Entries are written compactly (one per non-skipped column), so the result stays aligned
     * with {@link #expandedContributionNames(int)} regardless of where the skipped columns sit.
     */
    private int[] computeExpandedCatOffsets() {
        String[][] domains = _output._domains;
        String[] names = _output._names;
        int[] offsets = new int[domains.length + 1];
        int nOffsets = 1;
        for (int col = 0; col < domains.length; col++) {
            if (isSkippedForContributions(names[col])) {
                continue;
            }
            int width = domains[col] == null ? 1 : domains[col].length + 1;
            offsets[nOffsets] = offsets[nOffsets - 1] + width;
            nOffsets++;
        }
        return Arrays.copyOf(offsets, nOffsets);
    }

    /**
     * Builds the expanded output column names ({@code name}, or {@code name.level} per level plus
     * {@code name.missing(NA)}), followed by {@code BiasTerm} at index {@code nExpanded}.
     */
    private String[] expandedContributionNames(int nExpanded) {
        String[] names = new String[nExpanded + 1];
        names[nExpanded] = BIAS_TERM_COLUMN;
        int pos = 0;
        for (int col = 0; col < _output._names.length; col++) {
            String columnName = _output._names[col];
            if (isSkippedForContributions(columnName)) {
                continue;
            }
            String[] domain = _output._domains[col];
            if (domain == null) {
                names[pos++] = columnName;
            } else {
                for (String level : domain) {
                    names[pos++] = columnName + "." + level;
                }
                names[pos++] = columnName + ".missing(NA)";
            }
        }
        return names;
    }

    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> key) {
        return scoreContributions(frame, key, null);
    }

    /**
     * Returns a new Frame built from {@code frame}, adapted with {@code adaptTestForTrain}, with the
     * response, fold, weights and offset columns removed. Adaptation may create vecs that are not
     * shared with {@code frame}; the background-frame scoring path releases them with
     * {@link Frame#deleteTempFrameAndItsNonSharedVecs}.
     */
    protected Frame removeSpecialColumns(Frame frame) {
        Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        adapted.remove(_parms._response_column);
        adapted.remove(_parms._fold_column);
        adapted.remove(_parms._weights_column);
        adapted.remove(_parms._offset_column);
        return adapted;
    }

    /**
     * Like {@link #removeSpecialColumns(Frame)}, then drops every non-numeric column (iterating from
     * the last column so removals do not shift indices still to be visited).
     */
    public Frame removeSpecialNNonNumericColumns(Frame frame) {
        Frame adapted = removeSpecialColumns(frame);
        for (int col = adapted.numCols() - 1; col >= 0; col--) {
            if (!adapted.vec(col).isNumeric()) {
                adapted.remove(col);
            }
        }
        return adapted;
    }

    /**
     * Scores per-column contributions plus a {@code BiasTerm} column for every row of {@code frame}.
     *
     * @throws UnsupportedOperationException for models with more than two classes
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> key, Job<Frame> job) {
        ensureContributionsSupported();
        Frame adapted = removeSpecialColumns(frame);
        String[] outputNames = appendBiasTerm(adapted.names());
        return getScoreContributionsTask(this)
                .withPostMapAction(JobUpdatePostMap.forJob(job))
                .doAll(outputNames.length, Vec.T_NUM, adapted)
                .outputFrame(key, outputNames, (String[][]) null);
    }

    /**
     * Scores contributions, delegating to {@link #scoreContributions(Frame, Key, Job)} unless the
     * options request sorting, in which case only the top/bottom N contributions per row are output.
     *
     * @throws UnsupportedOperationException for models with more than two classes
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> key, Job<Frame> job, Model.Contributions.ContributionsOptions options) {
        ensureContributionsSupported();
        if (!options.isSortingRequired()) {
            return scoreContributions(frame, key, job);
        }
        Frame adapted = removeSpecialColumns(frame);
        String[] featureNames = adapted.names();
        int nFeatures = featureNames.length;
        ContributionComposer composer = new ContributionComposer();
        int topN = composer.checkAndAdjustInput(options._topN, nFeatures);
        int bottomN = composer.checkAndAdjustInput(options._bottomN, nFeatures);
        // each selected contribution occupies two columns (index, value)
        int width = Math.min((topN + bottomN) * 2, nFeatures * 2);
        String[] names = new String[width + 1];
        byte[] types = new byte[width + 1];
        String[][] domains = new String[width + 1][nFeatures + 1];
        composeScoreContributionTaskMetadata(names, types, domains, featureNames, options);
        return getScoreContributionsSoringTask(this, options)
                .withPostMapAction(JobUpdatePostMap.forJob(job))
                .doAll(types, adapted)
                .outputFrame(key, names, domains);
    }

    /**
     * Creates the task used for background-frame contributions.
     *
     * @param expand {@code true} when the expanded (one column per categorical level) output is requested
     * @param catOffsets offsets for the expanded output, or {@code null} for the compact output
     */
    protected abstract ScoreContributionsWithBackgroundTask getScoreContributionsWithBackgroundTask(SharedTreeModel model, Frame frame, Frame backgroundFrame, boolean expand, int[] catOffsets, Model.Contributions.ContributionsOptions options);

    /**
     * Scores interventional contributions of {@code frame} against {@code backgroundFrame}; when
     * {@code backgroundFrame} is null, delegates to the option-based overload.
     *
     * <p>Adapted copies of both frames are put into DKV for the duration of the task and are always
     * released afterwards, including when the computation fails.
     *
     * @throws UnsupportedOperationException for models with more than two classes
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> key, Job<Frame> job, Model.Contributions.ContributionsOptions options, Frame backgroundFrame) {
        if (backgroundFrame == null) {
            return scoreContributions(frame, key, job, options);
        }
        assert !options.isSortingRequired();
        ensureContributionsSupported();
        Log.info("Starting contributions calculation for " + this._key + "...");
        Frame adaptedFrame = null;
        Frame adaptedBackgroundFrame = null;
        try {
            adaptedFrame = removeSpecialColumns(frame);
            adaptedBackgroundFrame = removeSpecialColumns(backgroundFrame);
            DKV.put(adaptedFrame);
            DKV.put(adaptedBackgroundFrame);
            if (options._outputFormat == Model.Contributions.ContributionsOutputFormat.Compact || _output._domains == null) {
                String[] outputNames = appendBiasTerm(adaptedFrame.names());
                return getScoreContributionsWithBackgroundTask(this, adaptedFrame, adaptedBackgroundFrame, false, null, options)
                        .runAndGetOutput(job, key, outputNames);
            }
            assert Model.Parameters.CategoricalEncodingScheme.Enum.equals(_parms._categorical_encoding)
                    : "Unsupported categorical encoding. Only enum is supported.";
            int[] catOffsets = computeExpandedCatOffsets();
            String[] outputNames = expandedContributionNames(catOffsets[catOffsets.length - 1]);
            return getScoreContributionsWithBackgroundTask(this, adaptedFrame, adaptedBackgroundFrame, true, catOffsets, options)
                    .runAndGetOutput(job, key, outputNames);
        } finally {
            if (adaptedFrame != null) {
                Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, frame);
            }
            if (adaptedBackgroundFrame != null) {
                Frame.deleteTempFrameAndItsNonSharedVecs(adaptedBackgroundFrame, backgroundFrame);
            }
            Log.info("Finished contributions calculation for " + this._key + "...");
        }
    }

    /** Creates the task used by {@link #scoreContributions(Frame, Key, Job)}. */
    protected abstract ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model);

    /**
     * Creates the sorting task used when {@code options.isSortingRequired()}. (The name's spelling
     * is kept for compatibility with existing implementations.)
     */
    protected abstract ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options);
}
