package hex.modelselection;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.ModelBuilderListener;
import hex.ModelCategory;
import hex.genmodel.utils.MathUtils;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.modelselection.ModelSelectionModel;
import hex.modelselection.ModelSelectionUtils;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import jsr166y.CountedCompleter;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.H2O;
import water.HeartBeat;
import water.Key;
import water.Lockable;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.PrettyPrint;

/* loaded from: input_file:hex/modelselection/ModelSelection.class */
public class ModelSelection extends ModelBuilder<ModelSelectionModel, ModelSelectionModel.ModelSelectionParameters, ModelSelectionModel.ModelSelectionModelOutput> {
    // Builder-level result arrays. initModelParameters() allocates them with one slot per subset
    // size up to max_predictor_number; they are not allocated for backward mode.
    // NOTE(review): none of these four is read in the visible code — confirm they are still used.
    public String[][] _bestModelPredictors;
    public double[] _bestR2Values;
    public String[][] _predictorsAdd;
    public String[][] _predictorsRemoved;
    // DataInfo for the training frame; built in ModelSelectionDriver.computeImpl().
    DataInfo _dinfo;
    // Coefficient names taken from _dinfo at the start of maxrsweep model building.
    String[] _coefNames;
    // Number of candidate predictors (length of _predictorNames when it was extracted in init).
    public int _numPredictors;
    // Candidate predictor names; maxrsweep replaces them with the duplicate-free list it computes.
    public String[] _predictorNames;
    // maxrsweep state: the cross-product matrix held in memory (single-node mode) ...
    double[][] _currCPM;
    // ... or stored in DKV as a Frame (multinode mode).
    Frame _currCPMFrame;
    // Multinode maxrsweep: one entry per CPM row, initialized to 1 (presumably sweep-tracking flags).
    int[] _trackSweep;
    // Cross-validation settings moved off _parms in init(); passed on to the GLM models built
    // for each predictor subset.
    public int _glmNFolds;
    Model.Parameters.FoldAssignmentScheme _foldAssignment;
    String _foldColumn;

    /* loaded from: input_file:hex/modelselection/ModelSelection$ModelSelectionDriver.class */
    public class ModelSelectionDriver extends ModelBuilder<ModelSelectionModel, ModelSelectionModel.ModelSelectionParameters, ModelSelectionModel.ModelSelectionModelOutput>.Driver {
        public ModelSelectionDriver() {
            super();
        }

        /* JADX WARN: Type inference failed for: r1v122, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v126, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v130, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v134, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v138, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v142, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v31, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v37, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v43, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v49, types: [java.lang.String[], java.lang.String[][]] */
        /* JADX WARN: Type inference failed for: r1v66, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v72, types: [double[], double[][]] */
        public final void buildModel() {
            Lockable lockable = null;
            try {
                int i = 0;
                ModelSelectionModel modelSelectionModel = new ModelSelectionModel(ModelSelection.this.dest(), (ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms, new ModelSelectionModel.ModelSelectionModelOutput(ModelSelection.this, ModelSelection.this._dinfo));
                modelSelectionModel.write_lock(ModelSelection.this._job);
                ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._mode = ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode;
                if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode)) {
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._best_model_ids = new Key[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coef_p_values = new double[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._z_values = new double[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._best_predictors_subset = new String[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coefficient_names = new String[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._predictors_removed_per_step = new String[ModelSelection.this._numPredictors];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._predictors_added_per_step = new String[ModelSelection.this._numPredictors];
                } else {
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._best_r2_values = new double[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._best_predictors_subset = new String[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coefficient_names = new String[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._predictors_removed_per_step = new String[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._predictors_added_per_step = new String[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                    if (!ModelSelectionModel.ModelSelectionParameters.Mode.maxrsweep.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode) || ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._build_glm_model) {
                        ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._best_model_ids = new Key[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                    } else {
                        ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coefficient_values = new double[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                        ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coefficient_values_normalized = new double[((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number];
                        ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._best_model_ids = null;
                    }
                }
                if (ModelSelectionModel.ModelSelectionParameters.Mode.allsubsets.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode)) {
                    buildAllSubsetsModels(modelSelectionModel);
                } else if (ModelSelectionModel.ModelSelectionParameters.Mode.maxr.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode)) {
                    buildMaxRModels(modelSelectionModel);
                } else if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode)) {
                    i = buildBackwardModels(modelSelectionModel);
                } else if (ModelSelectionModel.ModelSelectionParameters.Mode.maxrsweep.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode)) {
                    buildMaxRSweepModels(modelSelectionModel);
                }
                ModelSelection.this._job.update(0L, "Completed GLM model building.  Extracting results now.");
                modelSelectionModel.update(ModelSelection.this._job);
                if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._mode)) {
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).shrinkArrays(i);
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).generateSummary(i);
                } else {
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).generateSummary();
                }
                modelSelectionModel.update(ModelSelection.this._job);
                modelSelectionModel.unlock(ModelSelection.this._job);
            } catch (Throwable th) {
                lockable.update(ModelSelection.this._job);
                lockable.unlock(ModelSelection.this._job);
                throw th;
            }
        }

        Frame extractCPM(double[][] dArr, String[] strArr) {
            List list = (List) Arrays.stream((strArr == null || strArr.length == 0) ? ModelSelection.this._dinfo.coefNames() : strArr).collect(Collectors.toList());
            if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._intercept) {
                list.add("intercept");
            }
            list.add("XTYnYTY");
            new ArrayUtils();
            Frame frame = ArrayUtils.frame(Key.make(), (String[]) list.stream().toArray(i -> {
                return new String[i];
            }), dArr);
            Scope.track(frame);
            return ModelSelection.this.rebalance(frame, false, Key.make().toString());
        }

        /* JADX WARN: Multi-variable type inference failed */
        void buildMaxRSweepModels(ModelSelectionModel modelSelectionModel) {
            ModelSelection.this._coefNames = ModelSelection.this._dinfo.coefNames();
            ModelSelectionUtils.CPMnPredNames genCPMPredNamesIndex = ModelSelectionUtils.genCPMPredNamesIndex(ModelSelection.this._job._key, ModelSelection.this._dinfo, ModelSelection.this._predictorNames, (ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms);
            ModelSelection.this._predictorNames = genCPMPredNamesIndex._predNames;
            if (ModelSelection.this._predictorNames.length < ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number) {
                ModelSelection.this.error("max_predictor_number", "Your dataset contains duplicated predictors.  After removal, reduce your max_predictor_number to " + ModelSelection.this._predictorNames.length + " or less.");
            }
            if (ModelSelection.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(ModelSelection.this);
            }
            int length = genCPMPredNamesIndex._cpm.length;
            if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._multinode_mode) {
                ModelSelection.this._currCPMFrame = extractCPM(genCPMPredNamesIndex._cpm, genCPMPredNamesIndex._coefNames);
                Scope.track(ModelSelection.this._currCPMFrame);
                DKV.put(ModelSelection.this._currCPMFrame);
                ModelSelection.this._trackSweep = IntStream.range(0, genCPMPredNamesIndex._cpm.length).map(i -> {
                    return 1;
                }).toArray();
                if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._intercept) {
                    ModelSelectionUtils.sweepCPMParallel(ModelSelection.this._currCPMFrame, new int[]{length - 2}, ModelSelection.this._trackSweep);
                }
                genCPMPredNamesIndex._cpm = (double[][]) null;
            } else {
                ModelSelection.this._currCPM = genCPMPredNamesIndex._cpm;
                if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._intercept) {
                    ModelSelectionUtils.sweepCPM(ModelSelection.this._currCPM, new int[]{length - 2}, false);
                }
            }
            int[][] iArr = genCPMPredNamesIndex._pred2CPMMapping;
            ModelSelection.this.checkMemoryFootPrint(length);
            double calR2Scale = 1.0d / ModelSelectionUtils.calR2Scale(ModelSelection.this.train(), ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._response_column);
            ModelSelectionUtils.CoeffNormalization generateScale = ModelSelectionUtils.generateScale(ModelSelection.this._dinfo, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._standardize);
            List<Integer> arrayList = new ArrayList();
            ArrayList arrayList2 = new ArrayList(Arrays.asList(ModelSelection.this._predictorNames));
            List<Integer> list = (List) IntStream.rangeClosed(0, arrayList2.size() - 1).boxed().collect(Collectors.toList());
            SweepModel sweepModel = null;
            List<String> list2 = (List) Stream.of((Object[]) ModelSelection.this._coefNames).collect(Collectors.toList());
            BitSet bitSet = new BitSet(ModelSelection.this._predictorNames.length);
            for (int i2 = 1; i2 <= ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number; i2++) {
                HashSet hashSet = new HashSet();
                sweepModel = ModelSelection.this.forwardStep(arrayList, list, hashSet, bitSet, iArr, sweepModel, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._intercept);
                list.removeAll(arrayList);
                ModelSelection.this._job.update(i2, "Finished forward step with " + i2 + " predictors.");
                if (i2 <= ModelSelection.this._numPredictors && i2 > 1) {
                    sweepModel = ModelSelection.this.replacement(arrayList, list, hashSet, bitSet, sweepModel, iArr);
                    arrayList = (List) IntStream.of(sweepModel._predSubset).boxed().collect(Collectors.toList());
                    list = (List) IntStream.rangeClosed(0, arrayList2.size() - 1).boxed().collect(Collectors.toList());
                    list.removeAll(arrayList);
                }
                if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._build_glm_model) {
                    GLMModel buildGLMModel = buildGLMModel(arrayList);
                    DKV.put(buildGLMModel);
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).updateBestModels(buildGLMModel, i2 - 1);
                } else {
                    ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).updateBestModels(ModelSelection.this._predictorNames, list2, i2 - 1, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._intercept, sweepModel._CPM.length, sweepModel._predSubset, sweepModel._CPM, calR2Scale, generateScale, iArr, ModelSelection.this._dinfo);
                }
            }
        }

        public GLMModel buildGLMModel(List<Integer> list) {
            Frame generateOneFrame = ModelSelectionUtils.generateOneFrame(list.stream().mapToInt((v0) -> {
                return v0.intValue();
            }).toArray(), ModelSelection.this._parms, ModelSelection.this._predictorNames, null);
            DKV.put(generateOneFrame);
            Field[] declaredFields = ModelSelectionModel.ModelSelectionParameters.class.getDeclaredFields();
            Field[] declaredFields2 = Model.Parameters.class.getDeclaredFields();
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters();
            ModelSelectionUtils.setParamField(ModelSelection.this._parms, gLMParameters, false, declaredFields, Collections.emptyList());
            ModelSelectionUtils.setParamField(ModelSelection.this._parms, gLMParameters, true, declaredFields2, Collections.emptyList());
            gLMParameters._train = generateOneFrame._key;
            if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._valid != null) {
                gLMParameters._valid = ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._valid;
            }
            GLMModel gLMModel = new GLM(gLMParameters).trainModel().get();
            DKV.remove(generateOneFrame._key);
            return gLMModel;
        }

        void buildMaxRModels(ModelSelectionModel modelSelectionModel) {
            ArrayList arrayList = new ArrayList();
            ArrayList arrayList2 = new ArrayList(Arrays.asList(ModelSelection.this._predictorNames));
            List list = (List) IntStream.rangeClosed(0, arrayList2.size() - 1).boxed().collect(Collectors.toList());
            for (int i = 1; i <= ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number; i++) {
                HashSet hashSet = new HashSet();
                GLMModel forwardStep = ModelSelection.forwardStep(arrayList, arrayList2, i - 1, list, (ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms, ModelSelection.this._foldColumn, ModelSelection.this._glmNFolds, ModelSelection.this._foldAssignment, hashSet);
                list.removeAll(arrayList);
                ModelSelection.this._job.update(i, "Finished building all models with " + i + " predictors.");
                if (i < ModelSelection.this._numPredictors && i > 1) {
                    GLMModel replacement = ModelSelection.replacement(arrayList, arrayList2, forwardStep != null ? forwardStep.r2() : 0.0d, (ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, list, ModelSelection.this._foldAssignment, hashSet);
                    if (replacement != null) {
                        forwardStep.delete();
                        forwardStep = replacement;
                    }
                    list.removeAll(arrayList);
                }
                DKV.put(forwardStep);
                ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).updateBestModels(forwardStep, i - 1);
            }
        }

        private int buildBackwardModels(ModelSelectionModel modelSelectionModel) {
            ArrayList arrayList = new ArrayList(Arrays.asList(ModelSelection.this._predictorNames));
            Frame frame = (Frame) DKV.getGet(((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._train);
            List<String> list = (List) arrayList.stream().filter(str -> {
                return frame.vec(str).isNumeric();
            }).collect(Collectors.toList());
            List<String> list2 = (List) arrayList.stream().filter(str2 -> {
                return !list.contains(str2);
            }).collect(Collectors.toList());
            int i = 0;
            String[] strArr = (String[]) arrayList.toArray(new String[0]);
            for (int i2 = ModelSelection.this._numPredictors; i2 >= ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._min_predictor_number; i2--) {
                int i3 = i2 - 1;
                Frame generateOneFrame = ModelSelectionUtils.generateOneFrame(null, ModelSelection.this._parms, strArr, ModelSelection.this._foldColumn);
                DKV.put(generateOneFrame);
                GLMModel gLMModel = new GLM(ModelSelectionUtils.generateGLMParameters(new Frame[]{generateOneFrame}, (ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, ModelSelection.this._foldAssignment)[0]).trainModel().get();
                DKV.put(gLMModel);
                ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).extractPredictors4NextModel(gLMModel, i3, arrayList, list, list2);
                i++;
                DKV.remove(generateOneFrame._key);
                strArr = (String[]) arrayList.toArray(new String[0]);
                ModelSelection.this._job.update(i2, "Finished building all models with " + i2 + " predictors.");
                if ((((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._p_values_threshold > CMAESOptimizer.DEFAULT_STOPFITNESS && DoubleStream.of(((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coef_p_values[i3]).limit(((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output)._coef_p_values[i3].length - 1).allMatch(d -> {
                    return d <= ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._p_values_threshold;
                })) || arrayList.size() == 0) {
                    break;
                }
            }
            return i;
        }

        void buildAllSubsetsModels(ModelSelectionModel modelSelectionModel) {
            for (int i = 1; i <= ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._max_predictor_number; i++) {
                Frame[] generateTrainingFrames = ModelSelectionUtils.generateTrainingFrames((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms, i, ModelSelection.this._predictorNames, MathUtils.combinatorial(ModelSelection.this._numPredictors, i), ModelSelection.this._foldColumn);
                GLMModel buildExtractBestR2Model = ModelSelection.buildExtractBestR2Model(generateTrainingFrames, (ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms, ModelSelection.this._glmNFolds, ModelSelection.this._foldColumn, ModelSelection.this._foldAssignment);
                DKV.put(buildExtractBestR2Model);
                ((ModelSelectionModel.ModelSelectionModelOutput) modelSelectionModel._output).updateBestModels(buildExtractBestR2Model, i - 1);
                ModelSelectionUtils.removeTrainingFrames(generateTrainingFrames);
                ModelSelection.this._job.update(i, "Finished building all models with " + i + " predictors.");
            }
        }

        /* JADX WARN: Multi-variable type inference failed */
        @Override // hex.ModelBuilder.Driver
        public void computeImpl() {
            if (((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._lambda_search || !((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._intercept || ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._lambda == null || ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._lambda[0] > CMAESOptimizer.DEFAULT_STOPFITNESS) {
                ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._use_all_factor_levels = true;
            }
            ModelSelection.this._dinfo = new DataInfo((Frame) ModelSelection.this._train.m1494clone(), ModelSelection.this._valid, 1, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._use_all_factor_levels, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms)._standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms).missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.Skip, ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms).imputeMissing(), ((ModelSelectionModel.ModelSelectionParameters) ModelSelection.this._parms).makeImputer(), false, ModelSelection.this.hasWeightCol(), ModelSelection.this.hasOffsetCol(), ModelSelection.this.hasFoldCol(), (Model.InteractionSpec) null);
            ModelSelection.this.init(true);
            if (ModelSelection.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(ModelSelection.this);
            }
            ModelSelection.this._job.update(0L, "finished init and ready to build models");
            buildModel();
        }

        @Override // hex.ModelBuilder.Driver, jsr166y.CountedCompleter
        public /* bridge */ /* synthetic */ boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
            return super.onExceptionalCompletion(th, countedCompleter);
        }

        @Override // hex.ModelBuilder.Driver, jsr166y.CountedCompleter
        public /* bridge */ /* synthetic */ void onCompletion(CountedCompleter countedCompleter) {
            super.onCompletion(countedCompleter);
        }

        @Override // hex.ModelBuilder.Driver, water.H2O.H2OCountedCompleter
        public /* bridge */ /* synthetic */ void compute2() {
            super.compute2();
        }

        @Override // hex.ModelBuilder.Driver
        public /* bridge */ /* synthetic */ void setCallback(ModelBuilderListener modelBuilderListener) {
            super.setCallback(modelBuilderListener);
        }
    }

    /**
     * One candidate produced by the maxrsweep search. It holds the indices of the predictors in
     * the subset, the cross-product matrix for that subset (presumably already swept — confirm
     * against ModelSelection.forwardStep/replacement), and an error-variance value.
     */
    public static class SweepModel {
        int[] _predSubset;        // predictor indices making up this subset
        double[][] _CPM;          // cross-product matrix associated with the subset
        double _errorVariance;

        public SweepModel(int[] predSubset, double[][] cpm, double errorVariance) {
            _errorVariance = errorVariance;
            _CPM = cpm;
            _predSubset = predSubset;
        }
    }

    /**
     * Creates a builder with default {@link ModelSelectionModel.ModelSelectionParameters}.
     * No initialization is run.
     *
     * @param startupOnce forwarded unchanged to the {@link ModelBuilder} constructor
     */
    public ModelSelection(boolean startupOnce) {
        super(new ModelSelectionModel.ModelSelectionParameters(), startupOnce);
        _foldColumn = null;
        _foldAssignment = null;
        _glmNFolds = 0;
    }

    /**
     * Creates a builder for the given parameters, resets the cross-validation fields and runs
     * the cheap {@code init(false)} validation.
     *
     * @param parms model-selection parameters
     */
    public ModelSelection(ModelSelectionModel.ModelSelectionParameters parms) {
        super(parms);
        _foldColumn = null;
        _foldAssignment = null;
        _glmNFolds = 0;
        init(false);
    }

    /**
     * Creates a builder for the given parameters and destination key, resets the
     * cross-validation fields and runs the cheap {@code init(false)} validation.
     *
     * @param parms    model-selection parameters
     * @param modelKey key forwarded to the {@link ModelBuilder} constructor
     */
    public ModelSelection(ModelSelectionModel.ModelSelectionParameters parms, Key<ModelSelectionModel> modelKey) {
        super(parms, modelKey);
        _foldColumn = null;
        _foldAssignment = null;
        _glmNFolds = 0;
        init(false);
    }

    /**
     * Always delegates to {@code nModelsInParallel(1, 2)}. The {@code folds} argument is
     * intentionally ignored: init() moves cross-validation settings off {@code _parms} and onto
     * the per-subset GLM builds.
     */
    @Override // hex.ModelBuilder
    protected int nModelsInParallel(int folds) {
        int parallelism = nModelsInParallel(1, 2);
        return parallelism;
    }

    /** Creates the driver that performs model selection for this builder. */
    @Override // hex.ModelBuilder
    public ModelSelectionDriver trainModelImpl() {
        ModelSelectionDriver driver = this.new ModelSelectionDriver();
        return driver;
    }

    /** Model selection only builds regression models. */
    @Override // hex.ModelBuilder
    public ModelCategory[] can_build() {
        ModelCategory[] supported = {ModelCategory.Regression};
        return supported;
    }

    /** Reports that this builder is supervised (it requires a response column). */
    @Override // hex.ModelBuilder
    public boolean isSupervised() {
        final boolean supervised = true;
        return supervised;
    }

    /** MOJO export is not available for model-selection models. */
    @Override // hex.ModelBuilder
    public boolean haveMojo() {
        final boolean mojoSupported = false;
        return mojoSupported;
    }

    /** POJO export is not available for model-selection models. */
    @Override // hex.ModelBuilder
    public boolean havePojo() {
        final boolean pojoSupported = false;
        return pojoSupported;
    }

    /**
     * Validates the model-selection parameters.
     *
     * <p>If cross-validation was requested ({@code nfolds > 0} or a fold column), it is rejected
     * for {@code backward} and {@code maxrsweep}. For the other modes, the fold settings are moved
     * onto this builder ({@code _glmNFolds}, {@code _foldAssignment}, {@code _foldColumn}) and
     * cleared on {@code _parms} before {@code super.init} runs. The builder later passes them to
     * the GLM models built for each predictor subset.
     *
     * @param expensive when true and no errors have been recorded, also runs the mode-specific
     *                  parameter checks and allocates the builder-level result arrays
     */
    @Override // hex.ModelBuilder
    public void init(boolean expensive) {
        ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters) this._parms;
        if (parms._nfolds > 0 || parms._fold_column != null) {
            if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals(parms._mode)) {
                error("nfolds/fold_column", "cross-validation is not supported for backward selection.");
            } else if (ModelSelectionModel.ModelSelectionParameters.Mode.maxrsweep.equals(parms._mode)) {
                error("nfolds/fold_column", "cross-validation is not supported for maxrsweep.");
            } else {
                // Keep CV settings on the builder and clear them on _parms.
                this._glmNFolds = parms._nfolds;
                if (parms._fold_assignment != null) {
                    this._foldAssignment = parms._fold_assignment;
                    parms._fold_assignment = null;
                }
                if (parms._fold_column != null) {
                    this._foldColumn = parms._fold_column;
                    parms._fold_column = null;
                }
                parms._nfolds = 0;
            }
        }
        super.init(expensive);
        if (error_count() <= 0 && expensive) {
            initModelSelectionParameters();
            if (error_count() > 0) {
                return;
            }
            initModelParameters();
        }
    }

    /**
     * Allocates the builder-level result arrays with one slot per subset size, up to
     * {@code max_predictor_number}. Nothing is allocated for backward selection.
     *
     * <p>The predictor arrays are {@code String[][]} (one name array per subset size), so they
     * are allocated as {@code new String[n][]}.
     */
    private void initModelParameters() {
        ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters) this._parms;
        if (ModelSelectionModel.ModelSelectionParameters.Mode.backward.equals(parms._mode)) {
            return;
        }
        int maxPreds = parms._max_predictor_number;
        this._bestR2Values = new double[maxPreds];
        this._bestModelPredictors = new String[maxPreds][];
        this._predictorsAdd = new String[maxPreds][];
        this._predictorsRemoved = new String[maxPreds][];
    }

    /**
     * Resolves the predictor names for this run and validates/normalizes the mode-specific
     * parameters.
     *
     * <p>For {@code maxr}, {@code allsubsets} and {@code maxrsweep}:
     * <ul>
     *   <li>defaults lambda to 0 when neither lambda, lambda_search nor alpha was given (not for maxrsweep);</li>
     *   <li>rejects multi-class responses;</li>
     *   <li>forces the gaussian family;</li>
     *   <li>checks {@code max_predictor_number}.</li>
     * </ul>
     *
     * <p>For every other mode (the backward branch):
     * <ul>
     *   <li>enables p-value computation;</li>
     *   <li>rejects validation frames and lambda_search;</li>
     *   <li>requires lambda to be absent or exactly {0};</li>
     *   <li>rejects multinomial/ordinal;</li>
     *   <li>checks {@code min_predictor_number}.</li>
     * </ul>
     *
     * <p>Problems are reported through {@code error(...)}/{@code warn(...)} rather than thrown.
     */
    private void initModelSelectionParameters() {
        this._predictorNames = ModelSelectionUtils.extractPredictorNames(this._parms, this._dinfo, this._foldColumn);
        this._numPredictors = this._predictorNames.length;
        final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters) this._parms;
        final ModelSelectionModel.ModelSelectionParameters.Mode mode = parms._mode;
        final boolean isMaxrSweep = ModelSelectionModel.ModelSelectionParameters.Mode.maxrsweep.equals(mode);
        final boolean isRegressionSubsetMode = ModelSelectionModel.ModelSelectionParameters.Mode.maxr.equals(mode)
                || ModelSelectionModel.ModelSelectionParameters.Mode.allsubsets.equals(mode)
                || isMaxrSweep;

        if (isRegressionSubsetMode) {
            // Unregularized fit by default (maxrsweep does not build GLMs unless asked to).
            if (parms._lambda == null && !parms._lambda_search && parms._alpha == null && !isMaxrSweep) {
                parms._lambda = new double[]{0.0};
            }
            if (nclasses() > 1) {
                error("response", "'allsubsets', 'maxr', 'maxrsweep' only works with regression.");
            }
            if (!GLMModel.GLMParameters.Family.AUTO.equals(parms._family) && !GLMModel.GLMParameters.Family.gaussian.equals(parms._family)) {
                error("_family", "ModelSelection only supports Gaussian family for 'allsubset' and 'maxr' mode.");
            }
            if (GLMModel.GLMParameters.Family.AUTO.equals(parms._family)) {
                parms._family = GLMModel.GLMParameters.Family.gaussian;
            }
            if (parms._max_predictor_number < 1 || parms._max_predictor_number > this._numPredictors) {
                error("max_predictor_number", "max_predictor_number must exceed 0 and be no greater than the number of predictors of the training frame.");
            }
        } else {
            parms._compute_p_values = true;
            if (parms._valid != null) {
                error("validation_frame", " is not supported for ModelSelection mode='backward'");
            }
            if (parms._lambda_search) {
                error("lambda_search", "backward selection does not support lambda_search.");
            }
            // An empty lambda array is treated like an unset one; indexing it would otherwise throw.
            if (parms._lambda == null || parms._lambda.length == 0) {
                parms._lambda = new double[]{0.0};
            } else {
                if (parms._lambda.length > 1) {
                    error("lambda", "if set must be set to 0 and cannot be an array or more than length one for backward selection.");
                }
                if (parms._lambda[0] != 0.0) {
                    error("lambda", "must be set to 0 for backward selection");
                }
            }
            if (GLMModel.GLMParameters.Family.multinomial.equals(parms._family) || GLMModel.GLMParameters.Family.ordinal.equals(parms._family)) {
                error("family", "backward selection does not support multinomial or ordinal");
            }
            if (parms._min_predictor_number <= 0) {
                error("min_predictor_number", "must be >= 1.");
            }
            if (parms._min_predictor_number > this._numPredictors) {
                error("min_predictor_number", "cannot exceed the total number of predictors (" + this._numPredictors + ")in the dataset.");
            }
        }

        if (parms._nparallelism < 0) {
            error("nparallelism", "must be >= 0.");
        }
        if (parms._nparallelism == 0) {
            parms._nparallelism = H2O.NUMCPUS;
        }
        // Only warn when the user actually supplied a validation frame that maxrsweep will ignore.
        if (isMaxrSweep && parms._valid != null) {
            warn("validation_frame", " is not used in choosing the best k subset for ModelSelection models with maxrsweep.");
        }
        if (isMaxrSweep && !parms._build_glm_model && parms._influence != null) {
            error("influence", " can only be set if glm models are built.  With maxrsweep model without build_glm_model = true, no GLM models will be built and hence no regression influence diagnostics can be calculated.");
        }
    }

    /**
     * Reports an error on {@code _train} when the estimated memory for the per-thread matrices
     * exceeds the driver node's free memory.
     *
     * <p>The estimate is {@code cpus_allowed * p^3}, computed in {@code long} so that large
     * column counts cannot overflow into a negative value and silently pass the check. In
     * {@code maxrsweep} mode, {@code p} is first scaled to
     * {@code ceil(p * (max_predictor_number + 2) / fullN)}. The division is done in floating point
     * so that the ceiling actually applies.
     *
     * @param i the matrix dimension to check (presumably the number of expanded columns; confirm at call site)
     */
    protected void checkMemoryFootPrint(int i) {
        final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters) this._parms;
        int p = i;
        if (ModelSelectionModel.ModelSelectionParameters.Mode.maxrsweep.equals(parms._mode)) {
            final int fullN = this._dinfo.fullN();
            if (fullN > 0) {  // avoid dividing by zero when there are no expanded columns
                p = (int) Math.ceil((double) p * (parms._max_predictor_number + 2) / fullN);
            }
        }
        final HeartBeat heartBeat = H2O.SELF._heartbeat;
        // Widen before multiplying: the int product overflows once p exceeds roughly 1290.
        final long memUsage = (long) heartBeat._cpus_allowed * p * p * p;
        final long freeMem = heartBeat.get_free_mem();
        if (memUsage > freeMem) {
            error("_train", "Gram matrices (one per thread) won't fit in the driver node's memory (" + PrettyPrint.bytes(memUsage) + " > " + PrettyPrint.bytes(freeMem) + ") - try reducing the number of columns and/or the number of categorical factors (or switch to the L-BFGS solver).");
        }
    }

    /**
     * Performs one sweep-based forward step.
     *
     * <p>Steps:
     * <ol>
     *   <li>Scores every candidate in {@code list2} with {@code generateAllErrVar}.</li>
     *   <li>Picks the first candidate with the smallest error variance.</li>
     *   <li>Sweeps the shared CPM on that predictor's CPM indices. The frame-backed CPM is used when
     *       {@code _multinode_mode} is set.</li>
     *   <li>Appends the predictor to {@code list}.</li>
     * </ol>
     *
     * @param list       predictor indices of the current subset; the chosen predictor is appended
     * @param list2      candidate predictor indices
     * @param set        passed through to {@code generateAllErrVar}
     * @param bitSet     passed through to {@code generateAllErrVar}
     * @param iArr       predictor-to-CPM index mapping (presumably one int[] of CPM indices per predictor)
     * @param sweepModel model of the previous step, or {@code null} for the first step
     * @param z          presumably the intercept flag; it fills the slot where forwardStepR passes {@code _intercept}
     * @return the model for the enlarged subset. If no candidate scores below {@code Double.MAX_VALUE},
     *         returns a model with null subset and CPM.
     */
    public SweepModel forwardStep(List<Integer> list, List<Integer> list2, Set<BitSet> set, BitSet bitSet, int[][] iArr, SweepModel sweepModel, boolean z) {
        final int lastCpmIndex = (sweepModel == null) ? -1 : sweepModel._CPM.length - 1;
        final double[] candidateErrVars = ModelSelectionUtils.generateAllErrVar(
                this._currCPM, this._currCPMFrame, lastCpmIndex, list, list2, set, bitSet, iArr, z);

        int winner = -1;
        double lowestErrVar = Double.MAX_VALUE;
        for (int k = 0; k < candidateErrVars.length; ++k) {
            final double errVar = candidateErrVars[k];
            if (errVar < lowestErrVar) {
                lowestErrVar = errVar;
                winner = k;
            }
        }
        if (winner < 0) {
            // Nothing could be added.
            return new SweepModel(null, (double[][]) null, lowestErrVar);
        }

        final int addedPredictor = list2.get(winner);
        final int[] sweepIndices = ModelSelectionUtils
                .extractCPMIndexFromPredOnly(iArr, new int[]{addedPredictor})
                .stream()
                .mapToInt(Integer::intValue)
                .toArray();
        final boolean multinode = ((ModelSelectionModel.ModelSelectionParameters) this._parms)._multinode_mode;
        if (multinode) {
            ModelSelectionUtils.sweepCPMParallel(this._currCPMFrame, sweepIndices, this._trackSweep);
        } else {
            ModelSelectionUtils.sweepCPM(this._currCPM, sweepIndices, false);
        }

        list.add(addedPredictor);
        final int[] enlargedSubset = list.stream().mapToInt(Integer::intValue).toArray();
        final double[][] subsetCPM = multinode
                ? ModelSelectionUtils.extractPredSubsetsCPMFrame(this._currCPMFrame, enlargedSubset, iArr, z)
                : ModelSelectionUtils.extractPredSubsetsCPM(this._currCPM, enlargedSubset, iArr, z);
        return new SweepModel(enlargedSubset, subsetCPM, lowestErrVar);
    }

    /**
     * Sweep-based replacement phase.
     *
     * <p>For each position in the current subset, removes that predictor from a copy of the subset
     * and drops the remaining members from the candidate list. It then asks {@code forwardStepR} for
     * a better predictor to put in that position. A result is accepted only if it has a non-null CPM
     * and strictly lower error variance than the best seen so far; on acceptance the removed
     * predictor goes back into {@code list2}. Full passes repeat until a pass makes no improvement.
     *
     * @param list       predictor indices of the current subset (may be modified by forwardStepR)
     * @param list2      candidate predictor indices (modified in place)
     * @param set        passed through to forwardStepR
     * @param bitSet     passed through to forwardStepR
     * @param sweepModel starting model; its error variance is the initial bound
     * @param iArr       passed through to forwardStepR
     * @return a snapshot of the last accepted model, or of {@code sweepModel} if nothing was accepted
     */
    public SweepModel replacement(List<Integer> list, List<Integer> list2, Set<BitSet> set, BitSet bitSet, SweepModel sweepModel, int[][] iArr) {
        double bestErrVar = sweepModel._errorVariance;
        final int subsetSize = list.size();
        SweepModel working = new SweepModel(sweepModel._predSubset, sweepModel._CPM, sweepModel._errorVariance);
        SweepModel bestSnapshot = new SweepModel(sweepModel._predSubset, sweepModel._CPM, sweepModel._errorVariance);

        boolean improvedInPass;
        do {
            improvedInPass = false;
            for (int pos = 0; pos < subsetSize; ++pos) {
                final List<Integer> reduced = new ArrayList<>(list);
                final int droppedPredictor = reduced.remove(pos);   // remove by index
                list2.removeAll(reduced);
                working._predSubset = reduced.stream().mapToInt(Integer::intValue).toArray();

                final SweepModel trial = forwardStepR(list, list2, set, bitSet, iArr, working, bestErrVar, pos);
                final boolean accepted = trial._CPM != null && bestErrVar > trial._errorVariance;
                if (!accepted) {
                    continue;
                }
                working = trial;
                bestErrVar = working._errorVariance;
                improvedInPass = true;
                bestSnapshot = new SweepModel(working._predSubset, working._CPM, working._errorVariance);
                list2.add(droppedPredictor);
            }
        } while (improvedInPass);
        return bestSnapshot;
    }

    /**
     * Tries to replace the predictor at position {@code i} of the current subset with the best
     * candidate from {@code list2}.
     *
     * <p>The candidates are scored on a deep copy of {@code sweepModel._CPM}, after sweeping that
     * copy on the indices returned by {@code extractSweepIndices}.
     *
     * <p>A replacement is committed only when it strictly lowers the error variance below
     * {@code d}. That matches the strict comparison used by {@code replacement}. If ties were
     * committed here, {@code list} and the shared CPM would change while the caller's recorded best
     * model would not. On commit:
     * <ul>
     *   <li>{@code list[i]} is replaced;</li>
     *   <li>the shared CPM ({@code _currCPMFrame} in multinode mode, otherwise {@code _currCPM}) is
     *       swept on the old and on the new predictor's CPM indices;</li>
     *   <li>{@code sweepModel} is updated in place and returned.</li>
     * </ul>
     *
     * <p>Otherwise this method does not modify {@code list}, the shared CPM or {@code sweepModel}.
     * It returns a model with null subset and CPM that carries the lowest error variance found.
     *
     * @param list       predictor indices of the current subset
     * @param list2      candidate predictor indices
     * @param set        passed through to {@code generateAllErrVarR}
     * @param bitSet     passed through to {@code generateAllErrVarR}
     * @param iArr       predictor-to-CPM index mapping ({@code iArr[pred]} is swept per predictor)
     * @param sweepModel model whose CPM is copied for scoring and which is updated on success
     * @param d          error variance that must be strictly beaten
     * @param i          position in {@code list} of the predictor being replaced
     */
    public SweepModel forwardStepR(List<Integer> list, List<Integer> list2, Set<BitSet> set, BitSet bitSet, int[][] iArr, SweepModel sweepModel, double d, int i) {
        final ModelSelectionModel.ModelSelectionParameters parms = (ModelSelectionModel.ModelSelectionParameters) this._parms;
        final boolean hasIntercept = parms._intercept;

        final double[][] cpmCopy = ArrayUtils.deepClone(sweepModel._CPM);
        final int removedPredictor = list.get(i);
        final int[] sweepIndices = ModelSelectionUtils.extractSweepIndices(list, i, removedPredictor, iArr, hasIntercept);
        // sweepCPM(cpmCopy, ...) runs before generateAllErrVarR, so the scorer sees the swept copy.
        final double[] candidateErrVars = ModelSelectionUtils.generateAllErrVarR(this._currCPM, this._currCPMFrame, cpmCopy, i, list, list2, set, bitSet, iArr, hasIntercept, sweepIndices, ModelSelectionUtils.sweepCPM(cpmCopy, sweepIndices, true));

        int bestPos = -1;
        double minErrVar = Double.MAX_VALUE;
        for (int k = 0; k < candidateErrVars.length; k++) {
            if (candidateErrVars[k] < minErrVar) {
                minErrVar = candidateErrVars[k];
                bestPos = k;
            }
        }
        // Reject ties as well: a non-improving swap must not touch list, the shared CPM or sweepModel.
        if (bestPos == -1 || minErrVar >= d) {
            return new SweepModel(null, (double[][]) null, minErrVar);
        }

        final int addedPredictor = list2.get(bestPos);
        list.remove(i);                  // remove by index
        list.add(i, addedPredictor);
        final int[] newSubset = list.stream().mapToInt(Integer::intValue).toArray();
        sweepModel._predSubset = newSubset;

        final double[][] subsetCPM;
        if (parms._multinode_mode) {
            ModelSelectionUtils.sweepCPMParallel(this._currCPMFrame, iArr[removedPredictor], this._trackSweep);
            ModelSelectionUtils.sweepCPMParallel(this._currCPMFrame, iArr[addedPredictor], this._trackSweep);
            subsetCPM = ModelSelectionUtils.extractPredSubsetsCPMFrame(this._currCPMFrame, newSubset, iArr, hasIntercept);
        } else {
            ModelSelectionUtils.sweepCPM(this._currCPM, iArr[removedPredictor], false);
            ModelSelectionUtils.sweepCPM(this._currCPM, iArr[addedPredictor], false);
            subsetCPM = ModelSelectionUtils.extractPredSubsetsCPM(this._currCPM, newSubset, iArr, hasIntercept);
        }
        sweepModel._CPM = subsetCPM;
        sweepModel._errorVariance = minErrVar;
        return sweepModel;
    }

    /**
     * Trains one GLM per training frame and returns the model chosen by
     * {@code ModelSelectionUtils.findBestModel}. Judging by the method name, that is presumably the
     * model with the highest R2.
     *
     * <p>Training runs through {@code ModelBuilderHelper.trainModelsParallel}, with
     * {@code modelSelectionParameters._nparallelism} as the degree of parallelism.
     *
     * @param frameArr                 candidate training frames
     * @param modelSelectionParameters source of the GLM settings and parallelism
     * @param i                        forwarded to {@code generateGLMParameters} (presumably the GLM nfolds)
     * @param str                      forwarded to {@code generateGLMParameters} (fold column name)
     * @param foldAssignmentScheme     forwarded to {@code generateGLMParameters}
     */
    public static GLMModel buildExtractBestR2Model(Frame[] frameArr, ModelSelectionModel.ModelSelectionParameters modelSelectionParameters, int i, String str, Model.Parameters.FoldAssignmentScheme foldAssignmentScheme) {
        final GLMModel.GLMParameters[] glmParams = ModelSelectionUtils.generateGLMParameters(frameArr, modelSelectionParameters, i, str, foldAssignmentScheme);
        final GLM[] builders = ModelSelectionUtils.buildGLMBuilders(glmParams);
        final GLM[] trained = (GLM[]) ModelBuilderHelper.trainModelsParallel(builders, modelSelectionParameters._nparallelism);
        return ModelSelectionUtils.findBestModel(trained);
    }

    /**
     * GLM-based forward step.
     *
     * <p>Builds the candidate training frames with {@code generateMaxRTrainingFrames}, trains a GLM
     * on each, and keeps the best model. It then inserts, at position {@code i} of {@code list}, the
     * index (in {@code list2}) of the model column that is not already in the subset. The last such
     * column in {@code extraModelColumnNames} order wins.
     *
     * <p>The temporary training frames are always removed. Cleanup happens in a {@code finally}, so
     * it also runs when model building throws.
     *
     * @param list                     predictor indices of the current subset; the new predictor is inserted at {@code i}
     * @param list2                    predictor names; indices in {@code list} refer to this list
     * @param i                        insertion position for the new predictor
     * @param list3                    forwarded to {@code generateMaxRTrainingFrames} (presumably the valid candidate indices)
     * @param modelSelectionParameters model selection settings
     * @param str                      fold column name
     * @param i2                       forwarded to {@code buildExtractBestR2Model}
     * @param foldAssignmentScheme     forwarded to {@code buildExtractBestR2Model}
     * @param set                      forwarded to {@code generateMaxRTrainingFrames}; may be {@code null}
     * @return the best model, or {@code null} when no training frame was generated
     */
    public static GLMModel forwardStep(List<Integer> list, List<String> list2, int i, List<Integer> list3, ModelSelectionModel.ModelSelectionParameters modelSelectionParameters, String str, int i2, Model.Parameters.FoldAssignmentScheme foldAssignmentScheme, Set<BitSet> set) {
        final Frame[] trainingFrames = ModelSelectionUtils.generateMaxRTrainingFrames(modelSelectionParameters, list2.toArray(new String[0]), str, list, i, list3, set);
        if (trainingFrames.length <= 0) {
            return null;
        }
        try {
            final GLMModel bestModel = buildExtractBestR2Model(trainingFrames, modelSelectionParameters, i2, str, foldAssignmentScheme);
            final List<String> modelColumns = ModelSelectionUtils.extraModelColumnNames(list2, bestModel);
            for (int k = modelColumns.size() - 1; k >= 0; k--) {
                final int predIndex = list2.indexOf(modelColumns.get(k));
                if (!list.contains(predIndex)) {
                    list.add(i, predIndex);
                    break;
                }
            }
            return bestModel;
        } finally {
            // Release the temporary frames even if training or column extraction fails.
            ModelSelectionUtils.removeTrainingFrames(trainingFrames);
        }
    }

    /**
     * Convenience overload of the nine-argument {@code forwardStep}. It supplies {@code null} as the
     * trailing {@code Set<BitSet>} argument; all other arguments are forwarded unchanged.
     */
    public static GLMModel forwardStep(List<Integer> list, List<String> list2, int i, List<Integer> list3, ModelSelectionModel.ModelSelectionParameters modelSelectionParameters, String str, int i2, Model.Parameters.FoldAssignmentScheme foldAssignmentScheme) {
        final Set<BitSet> noCombos = null;
        return forwardStep(list, list2, i, list3, modelSelectionParameters, str, i2, foldAssignmentScheme, noCombos);
    }

    /**
     * GLM-based replacement phase.
     *
     * <p>For each position in the current subset, drops that predictor and calls {@code forwardStep}
     * to insert a new predictor at the same position. A model is accepted only if it adds exactly
     * one predictor and its R2 strictly exceeds the best so far. On acceptance:
     * <ul>
     *   <li>the inserted predictor is removed from {@code list3};</li>
     *   <li>the dropped predictor is added back to {@code list3};</li>
     *   <li>{@code list} becomes the new subset;</li>
     *   <li>the previously accepted model is deleted.</li>
     * </ul>
     * Rejected models are deleted. Full passes repeat until a pass makes no improvement.
     *
     * <p>{@code forwardStep} inserts the new predictor at the position passed to it. The predictor
     * to remove from {@code list3} is therefore the one at {@code pos}, not the last element.
     *
     * @param list                     predictor indices of the current subset (replaced in place on improvement)
     * @param list2                    predictor names
     * @param d                        R2 that must be strictly exceeded
     * @param modelSelectionParameters model selection settings
     * @param i                        forwarded to {@code forwardStep} as its {@code i2} argument
     * @param str                      fold column name
     * @param list3                    valid candidate predictor indices (updated in place)
     * @param foldAssignmentScheme     forwarded to {@code forwardStep}
     * @param set                      forwarded to {@code forwardStep}
     * @return the last accepted model, or {@code null} if no replacement improved R2
     */
    public static GLMModel replacement(List<Integer> list, List<String> list2, double d, ModelSelectionModel.ModelSelectionParameters modelSelectionParameters, int i, String str, List<Integer> list3, Model.Parameters.FoldAssignmentScheme foldAssignmentScheme, Set<BitSet> set) {
        final int subsetSize = list.size();
        double bestR2 = d;
        GLMModel bestModel = null;
        boolean improved;
        do {
            improved = false;
            for (int pos = 0; pos < subsetSize; pos++) {
                final List<Integer> candidate = new ArrayList<>(list);
                final int removedPredictor = candidate.remove(pos);   // remove by index
                final GLMModel model = forwardStep(candidate, list2, pos, list3, modelSelectionParameters, str, i, foldAssignmentScheme, set);
                if (model == null) {
                    continue;
                }
                // A valid replacement must have inserted exactly one predictor back at `pos`.
                if (candidate.size() == subsetSize && model.r2() > bestR2) {
                    improved = true;
                    list3.remove(candidate.get(pos));   // Integer -> remove(Object): the newly inserted predictor
                    if (bestModel != null) {
                        bestModel.delete();
                    }
                    bestModel = model;
                    bestR2 = bestModel.r2();
                    list.clear();
                    list.addAll(candidate);
                    list3.add(removedPredictor);
                } else {
                    model.delete();
                }
            }
        } while (improved);
        return bestModel;
    }
}
