package hex.ensemble;

import hex.ContributionsMeanAggregator;
import hex.ContributionsWithBackgroundFrameTask;
import hex.LinkFunction;
import hex.LinkFunctionFactory;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ensemble.Metalearner;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Stream;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.Keyed;
import water.LocalMR;
import water.MRTask;
import water.MemoryManager;
import water.MrFun;
import water.SplitToChunksApplyCombine;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.Log;
import water.util.MRUtils;
import water.util.fp.Function2;

/* loaded from: input_file:hex/ensemble/StackedEnsembleModel.class */
public class StackedEnsembleModel extends Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput> implements Model.Contributions {
    // Category of this ensemble; isUsefulBaseModel branches on whether it equals ModelCategory.Multinomial.
    public ModelCategory modelCategory;
    // Set to -1 by the constructor; presumably the row count of the training frame — TODO confirm against the builder.
    public long trainingFrameRows;
    // Set to null by the constructor; presumably the name of the response column — TODO confirm against the builder.
    public String responseColumn;
    // Compiler-generated flag that backs the decompiled `assert` checks in this class.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/ensemble/StackedEnsembleModel$GDeepSHAP.class */
    /**
     * Distributed task that folds base-model contributions and metalearner contributions,
     * laid out side by side in one wide frame, into per-feature contributions of the ensemble.
     *
     * <p>The input frame is expected to contain, for every base model {@code m}, the columns
     * {@code m_<feature>}, {@code m_BiasTerm}, {@code m_RowIdx} and {@code m_BackgroundRowIdx},
     * plus the metalearner columns {@code metalearner_<m>}, {@code metalearner_BiasTerm},
     * {@code metalearner_RowIdx} and {@code metalearner_BackgroundRowIdx}. Column positions are
     * resolved once in the constructor; a column that is missing resolves to {@code -1}.
     */
    public class GDeepSHAP extends MRTask<GDeepSHAP> {
        final String[] _columns;
        final int[][] _baseIdx;
        final int[] _metaIdx;
        final int[] _levelOneIdx;
        final int _biasTermIdx;
        final int _biasTermSrc;
        final Integer[] _baseModelIdx;
        final int[] _biasTermIndices;
        final int[] _rowIndices;
        final int[] _rowBgIndices;
        final StackedEnsembleParameters.MetalearnerTransform _metaLearnerTransform;

        /**
         * @param columns              feature column names (without the trailing bias/row-index columns)
         * @param baseModels           names of the base models whose contributions are in the frame
         * @param bigFrameColumns      all column names of the wide input frame
         * @param baseModelIdx         column offsets per base model; holds one more entry than there are base models
         * @param metalearnerTransform transform configured on the ensemble (stored only)
         */
        GDeepSHAP(String[] columns, String[] baseModels, String[] bigFrameColumns, Integer[] baseModelIdx, StackedEnsembleParameters.MetalearnerTransform metalearnerTransform) {
            this._columns = columns;
            this._baseIdx = new int[columns.length][baseModels.length];
            this._metaIdx = new int[baseModels.length];
            this._levelOneIdx = new int[baseModels.length];
            this._biasTermIdx = columns.length;
            List<String> names = Arrays.asList(bigFrameColumns);
            this._biasTermSrc = names.indexOf("metalearner_BiasTerm");
            this._baseModelIdx = baseModelIdx;
            this._metaLearnerTransform = metalearnerTransform;
            this._biasTermIndices = new int[baseModels.length];
            this._rowIndices = new int[baseModels.length + 1];
            this._rowBgIndices = new int[baseModels.length + 1];
            for (int col = 0; col < columns.length; col++) {
                for (int bm = 0; bm < baseModels.length; bm++) {
                    this._baseIdx[col][bm] = names.indexOf(baseModels[bm] + "_" + columns[col]);
                }
            }
            for (int bm = 0; bm < baseModels.length; bm++) {
                this._metaIdx[bm] = names.indexOf("metalearner_" + baseModels[bm]);
                this._levelOneIdx[bm] = names.indexOf(baseModels[bm]);
                // The bias term lives in "<model>_BiasTerm"; looking up "_RowIdx" here would
                // make baseModelBiasTerm() return the row index instead of the bias.
                this._biasTermIndices[bm] = names.indexOf(baseModels[bm] + "_BiasTerm");
                this._rowIndices[bm] = names.indexOf(baseModels[bm] + "_RowIdx");
                this._rowBgIndices[bm] = names.indexOf(baseModels[bm] + "_BackgroundRowIdx");
            }
            // The last slot of the row-index arrays belongs to the metalearner.
            this._rowIndices[baseModels.length] = names.indexOf("metalearner_RowIdx");
            this._rowBgIndices[baseModels.length] = names.indexOf("metalearner_BackgroundRowIdx");
        }

        /** Contribution of feature {@code col} in base model {@code baseModel} for chunk row {@code row}. */
        private double baseModelContribution(Chunk[] chunkArr, int row, int baseModel, int col) {
            return chunkArr[this._baseIdx[col][baseModel]].atd(row);
        }

        /** Metalearner contribution attributed to base model {@code baseModel} for chunk row {@code row}. */
        private double metalearnerContribution(Chunk[] chunkArr, int row, int baseModel) {
            return chunkArr[this._metaIdx[baseModel]].atd(row);
        }

        /** Bias term of base model {@code baseModel} for chunk row {@code row}. */
        private double baseModelBiasTerm(Chunk[] chunkArr, int row, int baseModel) {
            return chunkArr[this._biasTermIndices[baseModel]].atd(row);
        }

        /** Division that yields 0 when the denominator's magnitude is below 1e-6. */
        private double div(double numerator, double denominator) {
            if (Math.abs(denominator) < 1.0E-6d) {
                return 0.0d;
            }
            return numerator / denominator;
        }

        /**
         * For every row: rescales each base model's feature contributions by
         * (metalearner contribution of that base model) / (sum of that base model's feature
         * contributions) and sums the rescaled values per feature. The last three output columns
         * receive the metalearner bias term, the row index and the background row index.
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            double[] scale = MemoryManager.malloc8d(this._metaIdx.length);
            for (int row = 0; row < chunkArr[0]._len; row++) {
                long rowIdx = chunkArr[this._rowIndices[0]].at8(row);
                long bgRowIdx = chunkArr[this._rowBgIndices[0]].at8(row);
                // Sanity check: every model must describe the same (row, background row) pair.
                for (int i = 0; i < this._rowIndices.length; i++) {
                    assert rowIdx == chunkArr[this._rowIndices[i]].at8(row);
                    assert bgRowIdx == chunkArr[this._rowBgIndices[i]].at8(row);
                }
                Arrays.fill(scale, 0.0d);
                // _baseModelIdx holds one more entry than there are base models.
                for (int bm = 0; bm < this._baseModelIdx.length - 1; bm++) {
                    for (int col = 0; col < this._columns.length; col++) {
                        scale[bm] += baseModelContribution(chunkArr, row, bm, col);
                    }
                    scale[bm] = div(metalearnerContribution(chunkArr, row, bm), scale[bm]);
                }
                for (int col = 0; col < newChunkArr.length - 3; col++) {
                    double total = 0.0d;
                    for (int bm = 0; bm < scale.length; bm++) {
                        total += scale[bm] * baseModelContribution(chunkArr, row, bm, col);
                    }
                    newChunkArr[col].addNum(total);
                }
                newChunkArr[newChunkArr.length - 3].addNum(chunkArr[this._biasTermSrc].atd(row));
                newChunkArr[newChunkArr.length - 2].addNum(chunkArr[this._rowIndices[0]].at8(row));
                newChunkArr[newChunkArr.length - 1].addNum(chunkArr[this._rowBgIndices[0]].at8(row));
            }
        }
    }

    /* loaded from: input_file:hex/ensemble/StackedEnsembleModel$StackedEnsembleOutput.class */
    /**
     * Output of a stacked ensemble: the trained metalearner plus bookkeeping about the
     * level-one frame and the base-model prediction frames.
     */
    public static class StackedEnsembleOutput extends Model.Output {
        public Model _metalearner;
        public Frame _levelone_frame_id;
        public StackingStrategy _stacking_strategy;
        public Key<Frame>[] _base_model_predictions_keys;

        public StackedEnsembleOutput() {
        }

        public StackedEnsembleOutput(StackedEnsemble stackedEnsemble) {
            super(stackedEnsemble);
        }

        public StackedEnsembleOutput(Job job) {
            this._job = job;
        }

        /**
         * Number of features reported by the superclass, minus one when the metalearner
         * was configured with a fold column.
         */
        public int nfeatures() {
            int featureCount = super.nfeatures();
            if (this._metalearner._parms._fold_column != null) {
                featureCount -= 1;
            }
            return featureCount;
        }
    }

    /* loaded from: input_file:hex/ensemble/StackedEnsembleModel$StackedEnsembleParameters.class */
    /** User-facing parameters of a stacked ensemble, including the metalearner configuration. */
    public static class StackedEnsembleParameters extends Model.Parameters {
        public int _metalearner_nfolds;
        public Model.Parameters.FoldAssignmentScheme _metalearner_fold_assignment;
        public String _metalearner_fold_column;
        public Key<Frame> _blending;
        public Model.Parameters _metalearner_parameters;
        public Key<Model>[] _base_models = new Key[0];
        public boolean _keep_levelone_frame = false;
        public boolean _keep_base_model_predictions = false;
        public MetalearnerTransform _metalearner_transform = MetalearnerTransform.NONE;
        public Metalearner.Algorithm _metalearner_algorithm = Metalearner.Algorithm.AUTO;
        public String _metalearner_params = new String();
        public long _score_training_samples = 10000;

        /** Transformation applied to the base-model predictions before they reach the metalearner. */
        public enum MetalearnerTransform {
            NONE,
            Logit;

            private LinkFunction logitLink = LinkFunctionFactory.getLinkFunction(LinkFunctionType.logit);

            MetalearnerTransform() {
            }

            /**
             * Applies this transform to every column of {@code frame} and publishes the result
             * under {@code key}. Only {@code Logit} is implemented; values are clamped to
             * [1e-9, 0.999999999] before the logit link is applied.
             *
             * @throws RuntimeException via {@code H2O.unimpl} for any transform other than {@code Logit}
             */
            public Frame transform(StackedEnsembleModel stackedEnsembleModel, Frame frame, Key<Frame> key) {
                if (this != Logit) {
                    throw H2O.unimpl("Transformation " + name() + " is not supported.");
                }
                MRTask logitTask = new MRTask() {
                    public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
                        for (int col = 0; col < chunkArr.length; col++) {
                            Chunk input = chunkArr[col];
                            for (int row = 0; row < input._len; row++) {
                                double clamped = Math.min(0.999999999d, Math.max(input.atd(row), 1.0E-9d));
                                newChunkArr[col].addNum(MetalearnerTransform.this.logitLink.link(clamped));
                            }
                        }
                    }
                };
                return logitTask.doAll(frame.numCols(), (byte) 3, frame).outputFrame(key, frame._names, (String[][]) null);
            }
        }

        public String algoName() {
            return "StackedEnsemble";
        }

        public String fullName() {
            return "Stacked Ensemble";
        }

        public String javaName() {
            return StackedEnsembleModel.class.getName();
        }

        public long progressUnits() {
            return 1L;
        }

        /** Creates metalearner parameters for the currently configured algorithm. */
        public void initMetalearnerParams() {
            initMetalearnerParams(this._metalearner_algorithm);
        }

        /** Records {@code algorithm} and creates a fresh parameter object for it. */
        public void initMetalearnerParams(Metalearner.Algorithm algorithm) {
            this._metalearner_algorithm = algorithm;
            this._metalearner_parameters = Metalearners.createParameters(algorithm.name());
        }

        /** Resolves the blending frame key, or returns {@code null} when no key is set. */
        public final Frame blending() {
            return this._blending == null ? null : this._blending.get();
        }

        /** Non-predictor columns of the superclass plus the metalearner fold column, if any. */
        public String[] getNonPredictors() {
            HashSet<String> nonPredictors = new HashSet<>();
            nonPredictors.addAll(Arrays.asList(super.getNonPredictors()));
            if (this._metalearner_fold_column != null) {
                nonPredictors.add(this._metalearner_fold_column);
            }
            return nonPredictors.toArray(new String[0]);
        }

        /** Distribution family of the metalearner parameters when present, else the superclass value. */
        public DistributionFamily getDistributionFamily() {
            if (this._metalearner_parameters == null) {
                return super.getDistributionFamily();
            }
            return this._metalearner_parameters.getDistributionFamily();
        }

        /** Forwards the distribution family to the metalearner parameters, which must already exist. */
        public void setDistributionFamily(DistributionFamily distributionFamily) {
            assert this._metalearner_parameters != null;
            this._metalearner_parameters.setDistributionFamily(distributionFamily);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/ensemble/StackedEnsembleModel$StackedEnsemblePredictScoreResult.class */
    // Scoring result that carries an already computed ModelMetrics object instead of a MetricBuilder.
    public class StackedEnsemblePredictScoreResult extends Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput>.PredictScoreResult {
        // Metrics handed in at construction time and returned verbatim by makeModelMetrics.
        private final ModelMetrics _modelMetrics;

        // The same frame is passed twice to the superclass constructor and no MetricBuilder is supplied.
        public StackedEnsemblePredictScoreResult(Frame frame, ModelMetrics modelMetrics) {
            super(StackedEnsembleModel.this, (ModelMetrics.MetricBuilder) null, frame, frame);
            this._modelMetrics = modelMetrics;
        }

        // Ignores both frames and returns the stored metrics (may be null if none were supplied).
        public ModelMetrics makeModelMetrics(Frame frame, Frame frame2) {
            return this._modelMetrics;
        }

        // Always throws: this result type has no MetricBuilder.
        public ModelMetrics.MetricBuilder<?> getMetricBuilder() {
            throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, retrieve your metrics by calling getOrMakeMetrics method.");
        }
    }

    /* loaded from: input_file:hex/ensemble/StackedEnsembleModel$StackingStrategy.class */
    // How the metalearner's training data was obtained; stored in StackedEnsembleOutput._stacking_strategy.
    public enum StackingStrategy {
        // NOTE(review): presumably level-one data built from base-model cross-validation predictions — confirm in the builder.
        cross_validation,
        // NOTE(review): presumably level-one data built from predictions on the blending frame — confirm in the builder.
        blending
    }

    /** Counts the configured base models for which {@link #isUsefulBaseModel(Key)} is true. */
    int numOfUsefulBaseModels() {
        Key<Model>[] baseModels = ((StackedEnsembleParameters) this._parms)._base_models;
        return (int) Arrays.stream(baseModels).filter(this::isUsefulBaseModel).count();
    }

    /**
     * Computes per-reference contributions of the whole ensemble: one output row per
     * (row, background row) pair.
     *
     * <p>Steps: score contributions of every useful base model against the background frame,
     * align their columns, prefix them with the base model key and collect them in one wide
     * frame; score the metalearner's contributions on the level-one frames; finally combine
     * everything with {@link GDeepSHAP}.
     *
     * <p>All temporary frames are released in a {@code finally} block so that nothing leaks
     * when any step throws.
     *
     * @param frame                frame to explain
     * @param key                  destination key of the resulting frame
     * @param job                  job used for progress updates
     * @param contributionsOptions options supplied by the caller (output format and output space are honoured)
     * @param frame2               background frame
     * @return frame with the feature contributions followed by BiasTerm, RowIdx and BackgroundRowIdx
     * @throws IllegalArgumentException if base models yield different columns with the Original output format
     * @throws RuntimeException         if base models yield different columns otherwise, or no base model is used
     */
    private Frame baseLineContributions(Frame frame, Key<Frame> key, Job<Frame> job, Model.Contributions.ContributionsOptions contributionsOptions, Frame frame2) {
        List<String> baseModelNames = new ArrayList<>();
        List<Integer> baseModelOffsets = new ArrayList<>();
        String[] columns = null;
        baseModelOffsets.add(0);
        Frame combined = new Frame(new Vec[0]);
        // Temporaries are tracked here so the finally block can release whatever was created.
        Frame adaptedFrame = null;
        Frame levelOneFrame = null;
        Frame adaptedBackground = null;
        Frame levelOneBackground = null;
        try {
            for (Key<Model> baseModelKey : ((StackedEnsembleParameters) this._parms)._base_models) {
                if (!isUsefulBaseModel(baseModelKey)) {
                    continue;
                }
                baseModelNames.add(baseModelKey.toString());
                Frame contributions = baseModelKey.get().scoreContributions(frame, Key.make(key.toString() + "_" + baseModelKey), job, new Model.Contributions.ContributionsOptions().setOutputFormat(contributionsOptions._outputFormat).setOutputSpace(true).setOutputPerReference(true), frame2);
                if (null == columns) {
                    columns = contributions._names;
                }
                if (!Arrays.equals(columns, contributions._names)) {
                    // Same set of columns in a different order: bring them into the reference order.
                    if (columns.length == contributions._names.length) {
                        List<String> expected = Arrays.asList(columns);
                        List<String> actual = Arrays.asList(contributions._names);
                        if (new HashSet<>(expected).containsAll(actual)) {
                            int[] order = new int[columns.length];
                            for (int i = 0; i < columns.length; i++) {
                                order[i] = actual.indexOf(columns[i]);
                            }
                            contributions.reOrder(order);
                        }
                    }
                    if (!Arrays.equals(columns, contributions._names)) {
                        Frame.deleteTempFrameAndItsNonSharedVecs(contributions, combined);
                        if (Model.Contributions.ContributionsOutputFormat.Original.equals(contributionsOptions._outputFormat)) {
                            throw new IllegalArgumentException("Base model contributions have different columns likely due to models using different categorical encoding. Please use output_format=\"compact\".");
                        }
                        throw new RuntimeException("Base model contributions have different columns. This is not expected. Please fill in a bug report.");
                    }
                }
                contributions.setNames(Arrays.stream(contributions._names).map(name -> baseModelKey + "_" + name).toArray(String[]::new));
                combined.add(contributions);
                Frame.deleteTempFrameAndItsNonSharedVecs(contributions, combined);
                baseModelOffsets.add(Integer.valueOf(combined.numCols()));
            }
            if (baseModelNames.isEmpty()) {
                throw new RuntimeException("Stacked Ensemble \"" + this._key + "\" doesn't use any base models. Stopping contribution calculation as no feature contributes.");
            }
            if (!$assertionsDisabled && (!columns[columns.length - 3].equals("BiasTerm") || !columns[columns.length - 2].equals("RowIdx") || !columns[columns.length - 1].equals("BackgroundRowIdx"))) {
                throw new AssertionError();
            }
            String[] outputColumns = columns;
            String[] featureColumns = Arrays.copyOfRange(columns, 0, columns.length - 3);
            adaptedFrame = adaptFrameForScore(frame, false, new ArrayList());
            levelOneFrame = getLevelOnePredictFrame(frame, adaptedFrame, job);
            adaptedBackground = adaptFrameForScore(frame2, false, new ArrayList());
            levelOneBackground = getLevelOnePredictFrame(frame2, adaptedBackground, job);
            Frame metaContributions = ((StackedEnsembleOutput) this._output)._metalearner.scoreContributions(levelOneFrame, Key.make(key + "_" + ((StackedEnsembleOutput) this._output)._metalearner._key), job, new Model.Contributions.ContributionsOptions().setOutputFormat(contributionsOptions._outputFormat).setOutputSpace(contributionsOptions._outputSpace).setOutputPerReference(true), levelOneBackground);
            metaContributions.setNames(Arrays.stream(metaContributions._names).map(name -> "metalearner_" + name).toArray(String[]::new));
            combined.add(metaContributions);
            Frame.deleteTempFrameAndItsNonSharedVecs(metaContributions, combined);
            return ((GDeepSHAP) new GDeepSHAP(featureColumns, baseModelNames.toArray(new String[0]), combined._names, baseModelOffsets.toArray(new Integer[0]), ((StackedEnsembleParameters) this._parms)._metalearner_transform).withPostMapAction(JobUpdatePostMap.forJob(job)).doAll(outputColumns.length, (byte) 3, combined)).outputFrame(key, outputColumns, (String[][]) null);
        } finally {
            if (null != levelOneFrame) {
                Frame.deleteTempFrameAndItsNonSharedVecs(levelOneFrame, frame);
            }
            if (null != levelOneBackground) {
                Frame.deleteTempFrameAndItsNonSharedVecs(levelOneBackground, frame2);
            }
            Frame.deleteTempFrameAndItsNonSharedVecs(combined, frame);
            if (null != adaptedFrame) {
                Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, frame);
            }
            if (null != adaptedBackground) {
                Frame.deleteTempFrameAndItsNonSharedVecs(adaptedBackground, frame2);
            }
        }
    }

    /**
     * Estimates the work units of a contributions run.
     *
     * @param frame  frame to explain
     * @param frame2 background frame
     * @param z      when {@code false}, one extra rows-times-background-rows pass is added
     * @return max(rows, background rows) * (useful base models + 1) + rows * background rows,
     *         plus the extra pass when {@code z} is {@code false}
     */
    public long scoreContributionsWorkEstimate(Frame frame, Frame frame2, boolean z) {
        long scoringWork = Math.max(frame.numRows(), frame2.numRows()) * (numOfUsefulBaseModels() + 1);
        long pairCount = frame.numRows() * frame2.numRows();
        long estimate = scoringWork + pairCount;
        if (z) {
            return estimate;
        }
        return estimate + pairCount;
    }

    /**
     * Scores ensemble contributions of {@code frame} against the background frame {@code frame2}.
     *
     * <p>With {@code _outputPerReference} set, the per-reference result is returned as is.
     * Otherwise the per-reference contributions are averaged with
     * {@code ContributionsMeanAggregator}; the input is processed in sub-frames when the
     * background frame has more chunks than {@code H2O.CLOUD._memary} has entries or when
     * the memory estimate is not satisfied.
     *
     * @throws RuntimeException via {@code H2O.unimpl} when no background frame is given
     */
    public Frame scoreContributions(Frame frame, Key<Frame> key, Job<Frame> job, Model.Contributions.ContributionsOptions contributionsOptions, Frame frame2) {
        if (frame2 == null) {
            throw H2O.unimpl("StackedEnsemble supports contribution calculation only with a background frame.");
        }
        Log.info(new Object[]{"Starting contributions calculation for " + this._key + "..."});
        try {
            if (contributionsOptions._outputPerReference) {
                return baseLineContributions(frame, key, job, contributionsOptions, frame2);
            }
            // Averages the per-reference contributions of one (sub)frame over the background rows.
            Function2 averaged = (subFrame, isFinal) -> {
                Frame perReference = baseLineContributions(subFrame, Key.make(key + "_individual_contribs_for_subframe_" + subFrame._key), job, contributionsOptions, frame2);
                String[] featureNames = (String[]) Arrays.copyOf(perReference.names(), perReference.names().length - 3);
                String[] outputNames = (String[]) Arrays.copyOf(perReference.names(), perReference.names().length - 2);
                if (!$assertionsDisabled && !outputNames[outputNames.length - 1].equals("BiasTerm")) {
                    throw new AssertionError();
                }
                try {
                    return ((ContributionsMeanAggregator) new ContributionsMeanAggregator(job, (int) subFrame.numRows(), featureNames.length + 1, (int) frame2.numRows()).withPostMapAction(JobUpdatePostMap.forJob(job)).doAll(featureNames.length + 1, (byte) 3, perReference)).outputFrame(isFinal.booleanValue() ? key : Key.make(key + "_for_subframe_" + subFrame._key), outputNames, (String[][]) null);
                } finally {
                    perReference.delete(true);
                }
            };
            boolean mustSplit = frame2.anyVec().nChunks() > H2O.CLOUD._memary.length
                    || !ContributionsWithBackgroundFrameTask.enoughMinMemory(numOfUsefulBaseModels() * ContributionsWithBackgroundFrameTask.estimatePerNodeMinimalMemory(frame.numCols(), frame, frame2));
            if (!mustSplit) {
                Frame whole = (Frame) averaged.apply(frame, true);
                DKV.put(whole);
                return whole;
            }
            return SplitToChunksApplyCombine.splitApplyCombine(frame, part -> {
                return (Frame) averaged.apply(part, false);
            }, key);
        } finally {
            Log.info(new Object[]{"Finished contributions calculation for " + this._key + "..."});
        }
    }

    /**
     * Creates the model and resets the bookkeeping fields: no response column yet and
     * {@code -1} as the training-frame row count.
     */
    public StackedEnsembleModel(Key key, StackedEnsembleParameters stackedEnsembleParameters, StackedEnsembleOutput stackedEnsembleOutput) {
        super(key, stackedEnsembleParameters, stackedEnsembleOutput);
        this.responseColumn = null;
        this.trainingFrameRows = -1L;
    }

    /** Resolves an {@code AUTO} metalearner fold assignment to {@code Random} after the superclass defaults. */
    public void initActualParamValues() {
        super.initActualParamValues();
        StackedEnsembleParameters params = (StackedEnsembleParameters) this._parms;
        boolean autoAssignment = params._metalearner_fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO;
        if (autoAssignment) {
            params._metalearner_fold_assignment = Model.Parameters.FoldAssignmentScheme.Random;
        }
    }

    /**
     * A MOJO is available only if the superclass supports one and every useful base model
     * supports one as well. Base models are checked in declaration order and the check
     * stops at the first one without a MOJO.
     */
    public boolean haveMojo() {
        if (!super.haveMojo()) {
            return false;
        }
        for (Key<Model> baseModelKey : ((StackedEnsembleParameters) this._parms)._base_models) {
            if (!isUsefulBaseModel(baseModelKey)) {
                continue;
            }
            Model baseModel = DKV.getGet(baseModelKey);
            if (!baseModel.haveMojo()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Scores {@code frame} by building the level-one frame from the base models and passing it
     * to the metalearner. When {@code z} is set, the metalearner's most recent metrics are
     * cloned for this model and {@code frame}, registered on the ensemble, and the metalearner's
     * own metric entries are removed from the DKV.
     *
     * <p>The custom metric function is taken from this model's parameters; the
     * {@code cFuncRef} argument is not used.
     */
    protected Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput>.PredictScoreResult predictScoreImpl(Frame frame, Frame frame2, String str, Job job, boolean z, CFuncRef cFuncRef) {
        Frame levelOne = getLevelOnePredictFrame(frame, frame2, job);
        Log.info(new Object[]{"Finished creating \"level one\" frame for scoring: " + levelOne.toString()});
        Model metalearner = ((StackedEnsembleOutput) this._output)._metalearner;
        CFuncRef customMetric = CFuncRef.from(((StackedEnsembleParameters) this._parms)._custom_metric_func);
        Frame predictions = metalearner.score(levelOne, str, job, z, customMetric);
        ModelMetrics ensembleMetrics = null;
        if (z) {
            Key[] metalearnerMetrics = metalearner._output.getModelMetrics();
            ensembleMetrics = metalearnerMetrics[metalearnerMetrics.length - 1].get().deepCloneWithDifferentModelAndFrame(this, frame);
            addModelMetrics(ensembleMetrics);
            // The metalearner's own metric entries are no longer needed once cloned.
            for (Key staleKey : metalearner._output.clearModelMetrics(true)) {
                DKV.remove(staleKey);
            }
        }
        Frame.deleteTempFrameAndItsNonSharedVecs(levelOne, frame2);
        return new StackedEnsemblePredictScoreResult(predictions, ensembleMetrics);
    }

    /**
     * Builds the level-one frame for scoring: the predictions of every useful base model on
     * {@code frame}, optionally passed through the metalearner transform, plus the
     * non-predictor columns taken from {@code frame2}.
     *
     * <p>When a transform is configured, the untransformed predictions are only an
     * intermediate result: they are removed after the transform, and the transformed frame is
     * what gets returned.
     *
     * @param frame  frame scored by the base models
     * @param frame2 frame from which non-predictor columns are copied
     * @param job    job passed to base-model scoring
     * @return the level-one frame; the caller is responsible for deleting it
     * @throws H2OIllegalArgumentException if a transform is configured for a non-classification model
     */
    private Frame getLevelOnePredictFrame(final Frame frame, Frame frame2, final Job job) {
        StackedEnsembleParameters params = (StackedEnsembleParameters) this._parms;
        StackedEnsembleParameters.MetalearnerTransform transform;
        if (params._metalearner_transform == null || params._metalearner_transform == StackedEnsembleParameters.MetalearnerTransform.NONE) {
            transform = null;
        } else {
            if (!((StackedEnsembleOutput) this._output).isBinomialClassifier() && !((StackedEnsembleOutput) this._output).isMultinomialClassifier()) {
                throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
            }
            transform = params._metalearner_transform;
        }
        final String modelKey = this._key.toString();
        Key<Frame> levelOneKey = Key.make("preds_levelone_" + modelKey + frame._key);
        // Without a transform this is the final frame; with one it is only an intermediate result.
        Frame levelOne = transform == null ? new Frame(levelOneKey) : new Frame(new Vec[0]);
        final Model[] usefulModels = Stream.of(params._base_models).filter(this::isUsefulBaseModel).map(k -> {
            return k.get();
        }).toArray(n -> {
            return new Model[n];
        });
        if (usefulModels.length > 0) {
            final Frame[] predictions = new Frame[usefulModels.length];
            // Score all useful base models in parallel on this node.
            H2O.submitTask(new LocalMR(new MrFun() { // from class: hex.ensemble.StackedEnsembleModel.1
                protected void map(int i2) {
                    predictions[i2] = usefulModels[i2].score(frame, "preds_base_" + modelKey + usefulModels[i2]._key + frame._key, job, false);
                }
            }, usefulModels.length)).join();
            for (int i = 0; i < usefulModels.length; i++) {
                StackedEnsemble.addModelPredictionsToLevelOneFrame(usefulModels[i], predictions[i], levelOne);
                DKV.remove(predictions[i]._key);
                Frame.deleteTempFrameAndItsNonSharedVecs(predictions[i], levelOne);
            }
        }
        if (transform != null) {
            // Remove the untransformed intermediate, NOT the transformed frame that is returned.
            Frame untransformed = levelOne;
            levelOne = transform.transform(this, untransformed, levelOneKey);
            untransformed.remove();
        }
        StackedEnsemble.addNonPredictorsToLevelOneFrame(params, frame2, levelOne, false);
        return levelOne;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Tells whether the metalearner actually uses the given base model.
     *
     * <p>For multinomial ensembles the metalearner has several features per base model, named
     * {@code <baseModelKey>/...}; the base model is useful if any of them is used in
     * prediction. Otherwise the single feature named after the base model key is checked.
     * Requires a trained metalearner (asserted).
     */
    public boolean isUsefulBaseModel(Key<Model> key) {
        Model metalearner = ((StackedEnsembleOutput) this._output)._metalearner;
        if (!$assertionsDisabled && metalearner == null) {
            throw new AssertionError("can't use isUsefulBaseModel during training");
        }
        if (this.modelCategory == ModelCategory.Multinomial) {
            return Arrays.stream(metalearner._output._names)
                    .anyMatch(feature -> feature.startsWith(key.toString().concat("/")) && metalearner.isFeatureUsedInPredict(feature));
        }
        return metalearner.isFeatureUsedInPredict(key.toString());
    }

    /** Not supported: scoring of a stacked ensemble goes through {@code predictScoreImpl}. Always throws. */
    protected double[] score0(double[] dArr, double[] dArr2) {
        final String message = "StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().";
        throw new UnsupportedOperationException(message);
    }

    /** Not supported for stacked ensembles. Always throws. */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        final String message = "StackedEnsembleModel.makeMetricBuilder should never be called!";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Scores {@code frame} with the ensemble and returns the resulting metrics. When
     * {@code _score_training_samples} is positive and smaller than the number of rows, a sample
     * of the frame is scored instead; that sample is deleted afterwards, also on failure.
     * The prediction frame produced by scoring is deleted.
     */
    private ModelMetrics doScoreTrainingMetrics(Frame frame, Job job) {
        StackedEnsembleParameters params = (StackedEnsembleParameters) this._parms;
        Frame scored;
        if (params._score_training_samples > 0 && params._score_training_samples < frame.numRows()) {
            scored = MRUtils.sampleFrame(frame, params._score_training_samples, params._seed);
        } else {
            scored = frame;
        }
        try {
            Frame adapted = new Frame(scored);
            Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput>.PredictScoreResult result = predictScoreImpl(scored, adapted, null, job, true, CFuncRef.from(params._custom_metric_func));
            result.getPredictions().delete();
            return result.makeModelMetrics(scored, adapted);
        } finally {
            // Only a sample created here is ours to delete; never the caller's frame.
            if (scored != frame) {
                scored.delete();
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Fills the ensemble's metrics: training metrics are computed by re-scoring the training
     * frame, validation metrics are taken over from the metalearner, and cross-validation
     * metrics (with their summary) are cloned from the metalearner when it has them.
     */
    public void doScoreOrCopyMetrics(Job job) {
        StackedEnsembleOutput output = (StackedEnsembleOutput) this._output;
        // NOTE(review): `job` is not forwarded here (null is passed) — confirm this is intended.
        output._training_metrics = doScoreTrainingMetrics(((StackedEnsembleParameters) this._parms).train(), null);
        output._validation_metrics = output._metalearner._output._validation_metrics;
        if (output._metalearner._output._cross_validation_metrics == null) {
            return;
        }
        output._cross_validation_metrics = output._metalearner._output._cross_validation_metrics.deepCloneWithDifferentModelAndFrame(this, output._metalearner._parms.train());
        output._cross_validation_metrics_summary = output._metalearner._output._cross_validation_metrics_summary.clone();
    }

    /**
     * Deletes the stored base-model prediction frames and forgets their keys. When a level-one
     * frame exists and the prediction frame can be resolved, vecs shared with the level-one
     * frame are kept; otherwise the key is removed outright.
     */
    public void deleteBaseModelPredictions() {
        StackedEnsembleOutput output = (StackedEnsembleOutput) this._output;
        if (output._base_model_predictions_keys == null) {
            return;
        }
        for (Key<Frame> predictionsKey : output._base_model_predictions_keys) {
            if (output._levelone_frame_id != null && predictionsKey.get() != null) {
                Frame.deleteTempFrameAndItsNonSharedVecs(predictionsKey.get(), output._levelone_frame_id);
            } else {
                Keyed.remove(predictionsKey);
            }
        }
        output._base_model_predictions_keys = null;
    }

    /**
     * Removes everything owned by the ensemble before delegating to the superclass: base-model
     * predictions, the metalearner and the level-one frame (each only if present).
     */
    protected Futures remove_impl(Futures futures, boolean z) {
        deleteBaseModelPredictions();
        StackedEnsembleOutput output = (StackedEnsembleOutput) this._output;
        Model metalearner = output._metalearner;
        if (metalearner != null) {
            metalearner.remove(futures);
        }
        Frame levelOne = output._levelone_frame_id;
        if (levelOne != null) {
            levelOne.remove(futures);
        }
        return super.remove_impl(futures, z);
    }

    /**
     * Writes the keys this ensemble depends on, then delegates to the superclass.
     * <p>
     * The metalearner key is written first, followed by each base-model key in array order;
     * {@code readAll_impl} reads them back in the same order.
     *
     * @param autoBuffer buffer to write into
     * @return the result of {@code super.writeAll_impl(autoBuffer)}
     */
    protected AutoBuffer writeAll_impl(AutoBuffer autoBuffer) {
        StackedEnsembleOutput out = (StackedEnsembleOutput) this._output;
        autoBuffer.putKey(out._metalearner._key);

        Key<Model>[] baseModels = ((StackedEnsembleParameters) this._parms)._base_models;
        for (int i = 0; i < baseModels.length; i++) {
            autoBuffer.putKey(baseModels[i]);
        }
        return super.writeAll_impl(autoBuffer);
    }

    /**
     * Reads back the keys written by {@code writeAll_impl}, then delegates to the superclass.
     * <p>
     * The metalearner key is read first, followed by each base-model key in array order.
     *
     * @param autoBuffer buffer to read from
     * @param futures    futures passed to every {@code getKey} call and to the superclass
     * @return the result of {@code super.readAll_impl(autoBuffer, futures)}
     */
    protected Keyed readAll_impl(AutoBuffer autoBuffer, Futures futures) {
        StackedEnsembleOutput out = (StackedEnsembleOutput) this._output;
        autoBuffer.getKey(out._metalearner._key, futures);

        Key<Model>[] baseModels = ((StackedEnsembleParameters) this._parms)._base_models;
        for (int i = 0; i < baseModels.length; i++) {
            autoBuffer.getKey(baseModels[i], futures);
        }
        return super.readAll_impl(autoBuffer, futures);
    }

    /* renamed from: getMojo, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a MOJO writer for this ensemble.
     *
     * @return a new {@code StackedEnsembleMojoWriter} constructed from this model
     */
    public StackedEnsembleMojoWriter m66getMojo() {
        StackedEnsembleMojoWriter mojoWriter = new StackedEnsembleMojoWriter(this);
        return mojoWriter;
    }

    /**
     * Forwards the request to delete cross-validation models to the metalearner.
     * Does nothing when the ensemble has no metalearner.
     */
    public void deleteCrossValidationModels() {
        StackedEnsembleOutput out = (StackedEnsembleOutput) this._output;
        if (out._metalearner == null) {
            return;
        }
        out._metalearner.deleteCrossValidationModels();
    }

    /**
     * Forwards the request to delete cross-validation predictions to the metalearner.
     * Does nothing when the ensemble has no metalearner.
     */
    public void deleteCrossValidationPreds() {
        StackedEnsembleOutput out = (StackedEnsembleOutput) this._output;
        if (out._metalearner == null) {
            return;
        }
        out._metalearner.deleteCrossValidationPreds();
    }

    /**
     * Forwards the request to delete the cross-validation fold assignment to the metalearner.
     * Does nothing when the ensemble has no metalearner.
     */
    public void deleteCrossValidationFoldAssignment() {
        StackedEnsembleOutput out = (StackedEnsembleOutput) this._output;
        if (out._metalearner == null) {
            return;
        }
        out._metalearner.deleteCrossValidationFoldAssignment();
    }

    /**
     * Synthetic hook that rebuilds a serialized lambda belonging to this class.
     * <p>
     * The lambda is selected by its implementation method name (both known names contain
     * {@code scoreContributions}). The remaining descriptor fields are then checked, and a fresh
     * lambda is created from the captured arguments.
     * <p>
     * NOTE(review): this is decompiler output of compiler-generated code and is not valid Java
     * as written (see the notes inside); treat it as a description of the bytecode, not as
     * hand-maintained source.
     *
     * @param serializedLambda descriptor of the lambda being deserialized
     * @return the reconstructed lambda
     * @throws IllegalArgumentException if the descriptor matches none of the known lambdas
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // NOTE(review): decompiler artifact. The selector has three states (-1 = no match,
        // plus the two cases below), so it was presumably an int in the bytecode.
        boolean z = -1;
        // Stage 1: map the implementation method name to a selector, switching on the hash and
        // confirming with equals().
        switch (implMethodName.hashCode()) {
            case -1202319373:
                if (implMethodName.equals("lambda$scoreContributions$8d51a081$1")) {
                    z = true;
                    break;
                }
                break;
            case 255697212:
                if (implMethodName.equals("lambda$scoreContributions$1617fc2c$1")) {
                    z = false;
                    break;
                }
                break;
        }
        // Stage 2: verify the rest of the descriptor and rebuild the matching lambda.
        switch (z) {
            case false:
                // Single-argument Function wrapping a captured Function2: applies it with
                // `false` as the second argument.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("water/util/fp/Function") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("hex/ensemble/StackedEnsembleModel") && serializedLambda.getImplMethodSignature().equals("(Lwater/util/fp/Function2;Lwater/fvec/Frame;)Lwater/fvec/Frame;")) {
                    Function2 function2 = (Function2) serializedLambda.getCapturedArg(0);
                    return frame4 -> {
                        return (Frame) function2.apply(frame4, false);
                    };
                }
                break;
            case true:
                // Two-argument Function2 capturing the model, a destination key, a job, the
                // contribution options and a frame.
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("water/util/fp/Function2") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("hex/ensemble/StackedEnsembleModel") && serializedLambda.getImplMethodSignature().equals("(Lwater/Key;Lwater/Job;Lhex/Model$Contributions$ContributionsOptions;Lwater/fvec/Frame;Lwater/fvec/Frame;Ljava/lang/Boolean;)Lwater/fvec/Frame;")) {
                    // NOTE(review): the captured model is extracted but never referenced below;
                    // the unqualified baseLineContributions(...) call was presumably invoked on
                    // it in the bytecode. Confirm against the original source.
                    StackedEnsembleModel stackedEnsembleModel = (StackedEnsembleModel) serializedLambda.getCapturedArg(0);
                    Key key = (Key) serializedLambda.getCapturedArg(1);
                    Job job = (Job) serializedLambda.getCapturedArg(2);
                    Model.Contributions.ContributionsOptions contributionsOptions = (Model.Contributions.ContributionsOptions) serializedLambda.getCapturedArg(3);
                    Frame frame = (Frame) serializedLambda.getCapturedArg(4);
                    return (frame3, bool) -> {
                        // Intermediate contributions for the sub-frame, stored under a derived key.
                        Frame baseLineContributions2 = baseLineContributions(frame3, Key.make(key + "_individual_contribs_for_subframe_" + frame3._key), job, contributionsOptions, frame);
                        // strArr drops the last three column names; strArr2 drops the last two and
                        // is asserted to end with "BiasTerm".
                        String[] strArr = (String[]) Arrays.copyOf(baseLineContributions2.names(), baseLineContributions2.names().length - 3);
                        String[] strArr2 = (String[]) Arrays.copyOf(baseLineContributions2.names(), baseLineContributions2.names().length - 2);
                        if (!$assertionsDisabled && !strArr2[strArr2.length - 1].equals("BiasTerm")) {
                            throw new AssertionError();
                        }
                        try {
                            // Aggregate into the output frame: stored under `key` itself when bool
                            // is true, otherwise under a per-sub-frame key.
                            Frame outputFrame = ((ContributionsMeanAggregator) new ContributionsMeanAggregator(job, (int) frame3.numRows(), strArr.length + 1, (int) frame.numRows()).withPostMapAction(JobUpdatePostMap.forJob(job)).doAll(strArr.length + 1, (byte) 3, baseLineContributions2)).outputFrame(bool.booleanValue() ? key : Key.make(key + "_for_subframe_" + frame3._key), strArr2, (String[][]) null);
                            // The intermediate frame is deleted on success...
                            baseLineContributions2.delete(true);
                            return outputFrame;
                        } catch (Throwable th) {
                            // ...and on failure, before the error is rethrown.
                            baseLineContributions2.delete(true);
                            throw th;
                        }
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }

    // Caches the inverse of this class's desired assertion status for the
    // `$assertionsDisabled` checks used elsewhere in the class.
    static {
        boolean assertionsEnabled = StackedEnsembleModel.class.desiredAssertionStatus();
        $assertionsDisabled = !assertionsEnabled;
    }
}
