package hex.tree.xgboost;

import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMTask;
import hex.tree.PlattScalingHelper;
import hex.tree.SharedTree;
import hex.tree.TreeStats;
import hex.tree.TreeUtils;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.exec.LocalXGBoostExecutor;
import hex.tree.xgboost.exec.RemoteXGBoostExecutor;
import hex.tree.xgboost.exec.XGBoostExecutor;
import hex.tree.xgboost.predict.XGBoostVariableImportance;
import hex.tree.xgboost.remote.SteamExecutorStarter;
import hex.tree.xgboost.util.FeatureScore;
import hex.util.CheckpointUtils;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.log4j.Logger;
import water.DKV;
import water.DTask;
import water.ExtensionManager;
import water.H2O;
import water.H2ONode;
import water.Key;
import water.Paxos;
import water.RPC;
import water.Scope;
import water.Value;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/tree/xgboost/XGBoost.class */
public class XGBoost extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> implements PlattScalingHelper.ModelBuilderWithCalibration<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> {
    // Class-wide logger; used by init and by the driver (fill-ratio and model messages).
    private static final Logger LOG;
    // Fill ratio below which isTrainDatasetSparse() reports the training frame as sparse.
    private static final double FILL_RATIO_THRESHOLD = 0.25d;
    // Set in init: the requested _ntrees, reduced by the checkpoint model's tree count when a checkpoint is used.
    private int _ntrees;
    // Set in init: XGBoostModel.getActualBackend(...) once the cloud is locked, cpu before that.
    private XGBoostModel.XGBoostParameters.Backend _backend;
    // Calibration frame exposed through getCalibrationFrame()/setCalibrationFrame().
    private transient Frame _calib;
    // NOTE(review): the next two fields are not referenced in the visible part of the file;
    // presumably GPU-detection state — confirm against the rest of the class.
    private static volatile boolean DEFAULT_GPU_BLACKLISTED;
    private static Set<Integer> GPUS;
    // Compiler-generated flag backing `assert` statements; read in init.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.tree.xgboost.XGBoost$1, reason: invalid class name */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$1.class */
    /*
     * Compiler-generated lookup table for the switch over DistributionFamily in init().
     * It maps each constant's ordinal to a dense case label (1..11). Every assignment sits in
     * its own try/catch so that a NoSuchFieldError for one constant does not stop the remaining
     * entries from being registered; an unregistered ordinal keeps the array default of 0 and
     * therefore reaches the switch's default branch.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Indexed by DistributionFamily.ordinal(); value is the case label used by init().
        static final /* synthetic */ int[] $SwitchMap$hex$genmodel$utils$DistributionFamily = new int[DistributionFamily.values().length];

        static {
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.bernoulli.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.modified_huber.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.multinomial.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.huber.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.poisson.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.gamma.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.tweedie.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.gaussian.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.laplace.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.quantile.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.AUTO.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$HasGPUTask.class */
    /**
     * Distributed task that records whether the executing node can use the GPU with the
     * given id, as reported by {@code XGBoost.hasGPU}.
     */
    public static class HasGPUTask extends DTask<HasGPUTask> {
        /** Id of the GPU to probe. */
        private final int _gpu_id;
        /** Probe result, filled in by {@link #compute2()}. */
        private boolean _hasGPU;

        private HasGPUTask(int gpuId) {
            _gpu_id = gpuId;
        }

        /** Compiler-generated accessor constructor; the second argument is only a signature marker. */
        /* synthetic */ HasGPUTask(int gpuId, AnonymousClass1 unused) {
            this(gpuId);
        }

        /** Runs the GPU probe and then signals completion of the task. */
        public void compute2() {
            _hasGPU = XGBoost.hasGPU(_gpu_id);
            tryComplete();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$XGBoostDriver.class */
    public class XGBoostDriver extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>.Driver {
        // System.currentTimeMillis() captured on the first doScoring call; 0 until then.
        long _firstScore;
        // System.currentTimeMillis() captured when the most recent scoring pass started.
        long _timeLastScoreStart;
        // System.currentTimeMillis() captured when the most recent scoring pass finished.
        long _timeLastScoreEnd;

        // Creates a driver for the enclosing builder with all scoring timestamps reset to 0.
        XGBoostDriver() {
            super(XGBoost.this);
            this._firstScore = 0L;
            this._timeLastScoreStart = 0L;
            this._timeLastScoreEnd = 0L;
        }

        /**
         * Driver entry point: re-runs the builder's full validation ({@code init(true)}),
         * aborts with an {@link H2OModelBuilderIllegalArgumentException} if any validation
         * error was recorded, and otherwise starts the model build.
         */
        public void computeImpl() {
            final XGBoost builder = XGBoost.this;
            builder.init(true);
            final boolean hasErrors = builder.error_count() > 0;
            if (hasErrors) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
            }
            buildModel();
        }

        /**
         * Builds the model, serializing on the per-GPU lock when the build will run on a GPU.
         *
         * The lock is taken only when all of the following hold: the requested backend is
         * {@code auto} or {@code gpu}, the configured GPU is available, the cloud has exactly
         * one node, and no parameter is GPU-incompatible. In every other case the build runs
         * without the lock.
         */
        final void buildModel() {
            final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            final boolean gpuBuild =
                    (XGBoostModel.XGBoostParameters.Backend.auto.equals(parms._backend)
                            || XGBoostModel.XGBoostParameters.Backend.gpu.equals(parms._backend))
                    && XGBoost.hasGPU(parms._gpu_id)
                    && H2O.getCloudSize() == 1
                    && parms.gpuIncompatibleParams().isEmpty();
            if (gpuBuild) {
                synchronized (XGBoostGPULock.lock(parms._gpu_id)) {
                    buildModelImpl();
                }
            } else {
                buildModelImpl();
            }
        }

        /**
         * Chooses the executor that will run the training.
         *
         * <ul>
         *   <li>{@code H2O.ARGS.use_external_xgboost} set: a remote executor obtained from
         *       {@link SteamExecutorStarter};</li>
         *   <li>system property {@code xgboost.external.address} set: a
         *       {@link RemoteXGBoostExecutor} for that address, using the
         *       {@code xgboost.external.user} / {@code xgboost.external.password} properties;</li>
         *   <li>otherwise: a {@link LocalXGBoostExecutor}.</li>
         * </ul>
         *
         * @param model the model being trained
         * @return the executor to use for this build
         * @throws IOException declared for executor creation failures
         */
        private XGBoostExecutor makeExecutor(XGBoostModel model) throws IOException {
            final XGBoostExecutor executor;
            if (H2O.ARGS.use_external_xgboost) {
                executor = SteamExecutorStarter.getInstance().getRemoteExecutor(model, XGBoost.this._train, XGBoost.this._job);
            } else {
                final String remoteAddress = H2O.getSysProperty("xgboost.external.address", (String) null);
                if (remoteAddress != null) {
                    final String user = H2O.getSysProperty("xgboost.external.user", (String) null);
                    final String password = H2O.getSysProperty("xgboost.external.password", (String) null);
                    executor = new RemoteXGBoostExecutor(model, XGBoost.this._train, remoteAddress, user, password);
                } else {
                    executor = new LocalXGBoostExecutor(model, XGBoost.this._train);
                }
            }
            return executor;
        }

        /**
         * Creates (or restores from a checkpoint) the model, trains it through an
         * {@link XGBoostExecutor} and finally releases the model lock.
         *
         * Steps:
         * <ol>
         *   <li>With a checkpoint, deep-clone the checkpoint model under the result key, attach
         *       the current parameters and lock it; otherwise create a fresh model and
         *       write-lock it.</li>
         *   <li>Decide the sparse/dense representation: forced by {@code _dmatrix_type} when it
         *       is {@code sparse} or {@code dense}, otherwise derived from the training data.</li>
         *   <li>Create the feature map and the variable-importance helper.</li>
         *   <li>Run setup and tree building with the executor.</li>
         * </ol>
         *
         * The executor is managed by try-with-resources, so it is always closed and an
         * exception thrown by {@code close()} is attached as suppressed to the primary failure
         * instead of replacing it. Any {@link Exception} is rethrown wrapped in a
         * {@link RuntimeException}; variable-importance cleanup and model unlock always run.
         */
        final void buildModelImpl() {
            final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            XGBoostModel model;
            if (parms.hasCheckpoint()) {
                XGBoostModel deepClone = DKV.get(parms._checkpoint).get().deepClone(XGBoost.this._result);
                deepClone._parms = XGBoost.this._parms;
                model = (XGBoostModel) deepClone.delete_and_lock(XGBoost.this._job);
            } else {
                model = new XGBoostModel(XGBoost.this._result, parms, new XGBoostOutput(XGBoost.this), XGBoost.this._train, XGBoost.this._valid);
                model.write_lock(XGBoost.this._job);
            }

            final XGBoostOutput output = (XGBoostOutput) model._output;
            if (parms._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.sparse) {
                output._sparse = true;
            } else if (parms._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.dense) {
                output._sparse = false;
            } else {
                // No explicit choice: fall back to the fill-ratio heuristic.
                output._sparse = isTrainDatasetSparse();
            }

            if (model.evalAutoParamsEnabled) {
                model.initActualParamValuesAfterOutputSetup(XGBoost.this.isClassifier(), XGBoost.this._nclass);
            }
            XGBoostUtils.createFeatureMap(model, XGBoost.this._train);
            final XGBoostVariableImportance varImp = model.setupVarImp();

            try (XGBoostExecutor executor = makeExecutor(model)) {
                model.model_info().updateBoosterBytes(executor.setup());
                scoreAndBuildTrees(model, executor, varImp);
            } catch (Exception e) {
                throw new RuntimeException("Error while training XGBoost model", e);
            } finally {
                varImp.cleanup();
                model.unlock(XGBoost.this._job);
            }
        }

        /**
         * Heuristic deciding whether the training data should use a sparse representation.
         *
         * Response, weights, fold and offset columns are ignored. For every remaining column:
         * a categorical column contributes one non-zero per row and {@code cardinality()}
         * expanded columns; any other column contributes its {@code nzCnt()} non-zeros and one
         * column. The fill ratio is {@code nonZeros / (expandedColumns * rows)}.
         *
         * The ratio is computed in floating point. Integer division would truncate it to 0 or 1
         * and would throw {@link ArithmeticException} when there are no rows or no feature
         * columns; in that degenerate case the ratio is now NaN, which does not satisfy the
         * threshold comparison.
         *
         * @return {@code true} when the fill ratio is below {@code FILL_RATIO_THRESHOLD} or when
         *         {@code rows * expandedColumns} exceeds {@link Integer#MAX_VALUE}
         */
        private boolean isTrainDatasetSparse() {
            final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            final long rows = XGBoost.this._train.numRows();
            long nonZeroCount = 0;
            long expandedColumns = 0;
            for (int col = 0; col < XGBoost.this._train.numCols(); col++) {
                final String name = XGBoost.this._train.name(col);
                final boolean special = name.equals(parms._response_column)
                        || name.equals(parms._weights_column)
                        || name.equals(parms._fold_column)
                        || name.equals(parms._offset_column);
                if (special) {
                    continue;
                }
                final Vec vec = XGBoost.this._train.vec(col);
                if (vec.isCategorical()) {
                    nonZeroCount += rows;
                    expandedColumns += vec.cardinality();
                } else {
                    nonZeroCount += vec.nzCnt();
                    expandedColumns++;
                }
            }
            // Promote to double BEFORE dividing so the ratio keeps its fractional part.
            final double fillRatio = (double) nonZeroCount / ((double) expandedColumns * (double) rows);
            XGBoost.LOG.info("fill ratio: " + fillRatio);
            return fillRatio < XGBoost.FILL_RATIO_THRESHOLD || rows * expandedColumns > Integer.MAX_VALUE;
        }

        /* JADX WARN: Code restructure failed: missing block: B:10:0x01b9, code lost:
        
            if (r0.isEmpty() != false) goto L40;
         */
        /* JADX WARN: Code restructure failed: missing block: B:12:0x01cc, code lost:
        
            if (((hex.tree.xgboost.XGBoostModel.XGBoostParameters) r9.this$0._parms)._booster == hex.tree.xgboost.XGBoostModel.XGBoostParameters.Booster.gblinear) goto L40;
         */
        /* JADX WARN: Code restructure failed: missing block: B:14:0x01d3, code lost:
        
            if (constraintCheckEnabled() == false) goto L40;
         */
        /* JADX WARN: Code restructure failed: missing block: B:15:0x01d6, code lost:
        
            r9.this$0._job.update(0, "Checking monotonicity constraints on the final model");
            r10.model_info().updateBoosterBytes(r11.updateBooster());
            checkConstraints(r10.model_info(), r0);
         */
        /* JADX WARN: Code restructure failed: missing block: B:16:0x01fa, code lost:
        
            r9.this$0._job.update(0, "Scoring the final model");
            doScoring(r10, r11, r12, true);
            r9.this$0._job.update(((hex.tree.xgboost.XGBoostModel.XGBoostParameters) r9.this$0._parms)._ntrees - ((hex.tree.xgboost.XGBoostOutput) r10._output)._ntrees);
         */
        /* JADX WARN: Code restructure failed: missing block: B:17:0x0233, code lost:
        
            return;
         */
        /* JADX WARN: Code restructure failed: missing block: B:9:0x01a3, code lost:
        
            r0 = ((hex.tree.xgboost.XGBoostModel.XGBoostParameters) r9.this$0._parms).monotoneConstraints();
         */
        /*
            Code decompiled incorrectly, please refer to instructions dump.
            To view partially-correct add '--show-bad-code' argument
        */
        /**
         * Placeholder left by the decompiler: the real body could not be reconstructed, so this
         * method unconditionally throws {@link UnsupportedOperationException}.
         *
         * NOTE(review): the recovered fragments in the JADX comment above suggest the original
         * ends by optionally checking monotone constraints and then scoring the final model —
         * restore the body from the upstream sources before relying on this class.
         */
        private void scoreAndBuildTrees(hex.tree.xgboost.XGBoostModel r10, hex.tree.xgboost.exec.XGBoostExecutor r11, hex.tree.xgboost.predict.XGBoostVariableImportance r12) {
            /*
                Method dump skipped, instructions count: 564
                To view this dump add '--comments-level debug' option
            */
            throw new UnsupportedOperationException("Method not decompiled: hex.tree.xgboost.XGBoost.XGBoostDriver.scoreAndBuildTrees(hex.tree.xgboost.XGBoostModel, hex.tree.xgboost.exec.XGBoostExecutor, hex.tree.xgboost.predict.XGBoostVariableImportance):void");
        }

        /**
         * Reads the {@code xgboost.monotonicity.checkEnabled} system property (default
         * {@code "true"}) and parses it as a boolean.
         *
         * @return whether the monotonicity check is enabled
         */
        private boolean constraintCheckEnabled() {
            final String flag = XGBoost.this.getSysProperty("xgboost.monotonicity.checkEnabled", "true");
            return Boolean.parseBoolean(flag);
        }

        /**
         * Verifies the monotone constraints against every tree of the trained booster.
         *
         * The booster is rebuilt from the model's booster bytes; each non-null tree in every
         * tree group is checked node by node.
         *
         * @param modelInfo   model info providing the booster bytes and the data info
         * @param constraints column name to constraint direction
         * @throws IllegalStateException if the booster is not a {@link GBTree} or a constraint
         *                               is violated
         */
        private void checkConstraints(XGBoostModelInfo modelInfo, Map<String, Integer> constraints) {
            GBTree booster = XGBoostJavaMojoModel.makePredictor(modelInfo._boosterBytes).getBooster();
            if (!(booster instanceof GBTree)) {
                throw new IllegalStateException("Expected booster object to be GBTree instead it is " + booster.getClass().getName());
            }
            final RegTree[][] treeGroups = booster.getGroupedTrees();
            final XGBoostUtils.FeatureProperties features = XGBoostUtils.assembleFeatureNames(modelInfo.dataInfo());
            for (int g = 0; g < treeGroups.length; g++) {
                final RegTree[] group = treeGroups[g];
                for (int t = 0; t < group.length; t++) {
                    final RegTree tree = group[t];
                    if (tree == null) {
                        continue;
                    }
                    checkConstraints(tree.getNodes(), constraints, features);
                }
            }
        }

        /**
         * Checks the monotone constraints on a single tree.
         *
         * The minimum and maximum leaf value reachable below every node are rolled up first.
         * Then, for each split on a constrained column:
         * a positive constraint requires {@code max(left) <= min(right)};
         * a negative constraint requires {@code min(left) >= max(right)}.
         *
         * @param nodes             nodes of the tree, root at index 0
         * @param constraints       column name to constraint direction
         * @param featureProperties supplies the feature name for each split index
         * @throws IllegalStateException on the first violated constraint
         */
        private void checkConstraints(RegTreeNode[] nodes, Map<String, Integer> constraints, XGBoostUtils.FeatureProperties featureProperties) {
            final int nodeCount = nodes.length;
            final float[] minPred = new float[nodeCount];
            final int[] minNode = new int[nodeCount];
            final float[] maxPred = new float[nodeCount];
            final int[] maxNode = new int[nodeCount];
            rollupMinMaxPreds(nodes, 0, minPred, minNode, maxPred, maxNode);

            for (int idx = 0; idx < nodeCount; idx++) {
                final RegTreeNode node = nodes[idx];
                if (node.isLeaf()) {
                    continue;
                }
                final String column = featureProperties._names[node.getSplitIndex()];
                if (!constraints.containsKey(column)) {
                    continue;
                }
                final int direction = constraints.get(column).intValue();
                final int left = node.getLeftChildIndex();
                final int right = node.getRightChildIndex();
                if (direction > 0 && maxPred[left] > minPred[right]) {
                    throw new IllegalStateException("Monotonicity constraint " + direction + " violated on column '" + column + "' (max(left) > min(right)): " + maxPred[left] + " > " + minPred[right] + "\nNode: " + node + "\nLeft Node (max): " + nodes[maxNode[left]] + "\nRight Node (min): " + nodes[minNode[right]]);
                }
                if (direction < 0 && minPred[left] < maxPred[right]) {
                    throw new IllegalStateException("Monotonicity constraint " + direction + " violated on column '" + column + "' (min(left) < max(right)): " + minPred[left] + " < " + maxPred[right] + "\nNode: " + node + "\nLeft Node (min): " + nodes[minNode[left]] + "\nRight Node (max): " + nodes[maxNode[right]]);
                }
            }
        }

        /**
         * Recursively fills, for the subtree rooted at {@code nodeIdx}, the smallest and largest
         * leaf value found below it and the indices of the leaves that hold them.
         *
         * On equal values the right child wins for both the minimum and the maximum.
         *
         * @param nodes   all nodes of the tree
         * @param nodeIdx index of the subtree root to process
         * @param minPred out: minimum leaf value per node
         * @param minNode out: index of the leaf holding the minimum, per node
         * @param maxPred out: maximum leaf value per node
         * @param maxNode out: index of the leaf holding the maximum, per node
         */
        private void rollupMinMaxPreds(RegTreeNode[] nodes, int nodeIdx, float[] minPred, int[] minNode, float[] maxPred, int[] maxNode) {
            final RegTreeNode node = nodes[nodeIdx];
            if (node.isLeaf()) {
                // A leaf is its own minimum and maximum.
                minPred[nodeIdx] = node.getLeafValue();
                minNode[nodeIdx] = nodeIdx;
                maxPred[nodeIdx] = node.getLeafValue();
                maxNode[nodeIdx] = nodeIdx;
                return;
            }
            final int left = node.getLeftChildIndex();
            final int right = node.getRightChildIndex();
            rollupMinMaxPreds(nodes, left, minPred, minNode, maxPred, maxNode);
            rollupMinMaxPreds(nodes, right, minPred, minNode, maxPred, maxNode);

            final int minChild;
            if (minPred[left] < minPred[right]) {
                minChild = left;
            } else {
                minChild = right;
            }
            minPred[nodeIdx] = minPred[minChild];
            minNode[nodeIdx] = minNode[minChild];

            final int maxChild;
            if (maxPred[left] > maxPred[right]) {
                maxChild = left;
            } else {
                maxChild = right;
            }
            maxPred[nodeIdx] = maxPred[maxChild];
            maxNode[nodeIdx] = maxNode[maxChild];
        }

        /**
         * Scores the model when scoring is due and, on the final call, builds the calibration
         * model.
         *
         * Scoring happens when any of these holds: {@code _score_each_iteration} is set, this is
         * the final scoring, the tree count is a multiple of a positive
         * {@code _score_tree_interval}, or (only when {@code _score_tree_interval} is 0) the
         * time-based rule fires. The time-based rule fires while still inside the initial score
         * interval since the first call, or when more than {@code _score_interval} has elapsed
         * since the last scoring started and that scoring took less than 10% of the elapsed
         * time.
         *
         * When scoring runs, the booster bytes are refreshed, train/validation metrics are
         * computed, and variable importances, model summary and scoring history are rebuilt
         * before the model is updated.
         *
         * @param model        model being trained
         * @param executor     executor supplying the current booster
         * @param varImp       helper producing feature scores
         * @param finalScoring {@code true} for the last scoring of the build
         * @return {@code true} if the model was scored during this call
         */
        private boolean doScoring(XGBoostModel model, XGBoostExecutor executor, XGBoostVariableImportance varImp, boolean finalScoring) {
            final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            final long now = System.currentTimeMillis();
            if (this._firstScore == 0) {
                this._firstScore = now;
            }
            final long sinceLastScoreStart = now - this._timeLastScoreStart;
            XGBoost.this._job.update(0L, "Built " + ((XGBoostOutput) model._output)._ntrees + " trees so far (out of " + parms._ntrees + ").");

            final boolean timeToScore = now - this._firstScore < ((long) parms._initial_score_interval)
                    || (sinceLastScoreStart > ((long) parms._score_interval)
                        && ((double) (this._timeLastScoreEnd - this._timeLastScoreStart)) / ((double) sinceLastScoreStart) < 0.1d);
            final boolean treeIntervalHit = parms._score_tree_interval > 0
                    && ((XGBoostOutput) model._output)._ntrees % parms._score_tree_interval == 0;
            final boolean scoreNow = parms._score_each_iteration
                    || finalScoring
                    || (timeToScore && parms._score_tree_interval == 0)
                    || treeIntervalHit;

            boolean scored = false;
            if (scoreNow) {
                this._timeLastScoreStart = now;
                model.model_info().updateBoosterBytes(executor.updateBooster());
                model.doScoring(XGBoost.this._train, parms.train(), XGBoost.this._valid, parms.valid());
                this._timeLastScoreEnd = System.currentTimeMillis();

                final XGBoostOutput out = (XGBoostOutput) model._output;
                out._varimp = XGBoost.computeVarImp(varImp.getFeatureScores(model.model_info()._boosterBytes));
                out._model_summary = SharedTree.createModelSummaryTable(out._ntrees, (TreeStats) null);
                out._scoring_history = SharedTree.createScoringHistoryTable(out, ((XGBoostOutput) model._output)._scored_train, out._scored_valid, XGBoost.this._job, out._training_time_ms, parms._custom_metric_func != null, false);
                if (out._varimp != null) {
                    final String[] names = out._varimp._names;
                    out._variable_importances = XGBoost.createVarImpTable(null, ArrayUtils.toDouble(out._varimp._varimp), names);
                    out._variable_importances_cover = XGBoost.createVarImpTable("Cover", ArrayUtils.toDouble(out._varimp._covers), names);
                    out._variable_importances_frequency = XGBoost.createVarImpTable("Frequency", ArrayUtils.toDouble(out._varimp._freqs), names);
                }
                model.update(XGBoost.this._job);
                XGBoost.LOG.info(model);
                scored = true;
            }

            final boolean calibrate = finalScoring && parms.calibrateModel() && !parms._is_cv_model;
            if (calibrate) {
                ((XGBoostOutput) model._output)._calib_model = PlattScalingHelper.buildCalibrationModel(XGBoost.this, XGBoost.this._parms, XGBoost.this._job, model);
                model.update(XGBoost.this._job);
            }
            return scored;
        }
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably advertises MOJO export support to the framework — confirm
     * against the ModelBuilder contract.
     */
    public boolean haveMojo() {
        return true;
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably advertises POJO export support to the framework — confirm
     * against the ModelBuilder contract.
     */
    public boolean havePojo() {
        return true;
    }

    /**
     * Reports the builder's visibility level.
     *
     * @return {@code Stable} when the XGBoost core extension is enabled according to the
     *         {@link ExtensionManager}, {@code Experimental} otherwise
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        if (ExtensionManager.getInstance().isCoreExtensionsEnabled(XGBoostExtension.NAME)) {
            return ModelBuilder.BuilderVisibility.Stable;
        }
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /**
     * Lists the model categories this builder supports.
     *
     * @return a new array containing Regression, Binomial and Multinomial, in that order
     */
    public ModelCategory[] can_build() {
        final ModelCategory[] supported = {
            ModelCategory.Regression,
            ModelCategory.Binomial,
            ModelCategory.Multinomial,
        };
        return supported;
    }

    /**
     * Creates a builder for the given parameters and immediately runs {@code init(false)},
     * i.e. parameter validation without the checks that init gates on its flag.
     *
     * @param xGBoostParameters training parameters
     */
    public XGBoost(XGBoostModel.XGBoostParameters xGBoostParameters) {
        super(xGBoostParameters);
        init(false);
    }

    /**
     * Creates a builder for the given parameters and model key and immediately runs
     * {@code init(false)}, i.e. parameter validation without the checks that init gates on its
     * flag.
     *
     * @param xGBoostParameters training parameters
     * @param key               key passed to the superclass constructor
     *                          (presumably the destination key of the model — confirm)
     */
    public XGBoost(XGBoostModel.XGBoostParameters xGBoostParameters, Key<XGBoostModel> key) {
        super(xGBoostParameters, key);
        init(false);
    }

    /**
     * Creates a builder with default parameters; unlike the other constructors it does not
     * call init.
     *
     * @param z flag forwarded unchanged to the superclass constructor
     *          (NOTE(review): presumably the framework's startup-registration flag — confirm)
     */
    public XGBoost(boolean z) {
        super(new XGBoostModel.XGBoostParameters(), z);
    }

    /** Always returns {@code true}: this builder declares itself supervised. */
    public boolean isSupervised() {
        return true;
    }

    /**
     * Number of models that may be built in parallel.
     *
     * @param i value forwarded to the two-argument overload when the backend is not GPU
     * @return 1 when the resolved backend is GPU, otherwise {@code nModelsInParallel(i, 2)}
     */
    protected int nModelsInParallel(int i) {
        final boolean gpuBackend = this._backend == XGBoostModel.XGBoostParameters.Backend.gpu;
        return gpuBackend ? 1 : nModelsInParallel(i, 2);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: trainModelImpl, reason: merged with bridge method [inline-methods] */
    /** Returns a new {@link XGBoostDriver} bound to this builder. */
    public XGBoostDriver m7trainModelImpl() {
        return new XGBoostDriver();
    }

    /**
     * Validates the XGBoost parameters and resolves derived builder state.
     *
     * After delegating to {@code super.init}, the method:
     * <ul>
     *   <li>refuses to run on a secured multi-node cluster unless
     *       {@code allow_insecure_xgboost} is set (and logs a warning when it is);</li>
     *   <li>when {@code z} is set, rejects responses with missing values and verifies that the
     *       XGBoost extension is available on all nodes;</li>
     *   <li>resolves the backend ({@code cpu} until the cloud is locked) and the number of
     *       trees to build (reduced by the checkpoint model's trees when a checkpoint
     *       exists);</li>
     *   <li>validates {@code _max_depth} (0 is replaced by {@link Integer#MAX_VALUE}), GPU
     *       settings, distribution vs. response type, learn/column sample rates, grow policy,
     *       monotone constraints and the {@code exact} tree method;</li>
     *   <li>initializes calibration through {@link PlattScalingHelper}.</li>
     * </ul>
     *
     * Validation problems are recorded via {@code error(...)}; when {@code z} is set and errors
     * exist after the {@code _max_depth} check, an exception is thrown immediately.
     *
     * @param z enables the checks described above that are gated on it, and the early throw
     * @throws H2OIllegalArgumentException on a secured multi-node cluster without
     *         {@code allow_insecure_xgboost}
     * @throws H2OModelBuilderIllegalArgumentException when {@code z} is set and validation
     *         errors were recorded
     */
    public void init(boolean z) {
        super.init(z);
        final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) this._parms;

        // The warning talks about a *secured* cluster, so it is only emitted when security is
        // actually enabled (it used to be logged for every multi-node cluster).
        if (H2O.CLOUD.size() > 1 && H2O.SELF.getSecurityManager().securityEnabled) {
            if (!H2O.ARGS.allow_insecure_xgboost) {
                throw new H2OIllegalArgumentException("Cannot run XGBoost on an SSL enabled cluster larger than 1 node. XGBoost does not support SSL encryption.");
            }
            LOG.info("Executing XGBoost on an secured cluster might compromise security.");
        }
        if (H2O.ARGS.client && parms._build_tree_one_node) {
            error("_build_tree_one_node", "Cannot run on a single node in client mode.");
        }
        if (z) {
            // Guard against a missing response vector so the NA check cannot fail with a
            // NullPointerException of its own.
            if (this._response != null && this._response.naCnt() > 0) {
                error("_response_column", "Response contains missing values (NAs) - not supported by XGBoost.");
            }
            if (!((XGBoostExtensionCheck) new XGBoostExtensionCheck().doAllNodes()).enabled) {
                error("XGBoost", "XGBoost is not available on all nodes!");
            }
        }

        if (Paxos._cloudLocked) {
            this._backend = XGBoostModel.getActualBackend(parms);
        } else {
            this._backend = XGBoostModel.XGBoostParameters.Backend.cpu;
        }

        if (parms.hasCheckpoint()) {
            Value value = DKV.get(parms._checkpoint);
            if (value != null) {
                // Only the trees missing from the checkpoint model still have to be built.
                this._ntrees = parms._ntrees - ((XGBoostOutput) ((XGBoostModel) CheckpointUtils.getAndValidateCheckpointModel(this, XGBoostModel.XGBoostParameters.CHECKPOINT_NON_MODIFIABLE_FIELDS, value))._output)._ntrees;
            }
        } else {
            this._ntrees = parms._ntrees;
        }

        if (parms._max_depth < 0) {
            error("_max_depth", "_max_depth must be >= 0.");
        }
        if (parms._max_depth == 0) {
            parms._max_depth = Integer.MAX_VALUE;
        }
        if (z && error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }

        if (parms._backend == XGBoostModel.XGBoostParameters.Backend.gpu) {
            if (!hasGPU(parms._gpu_id)) {
                error("_backend", "GPU backend (gpu_id: " + parms._gpu_id + ") is not functional. Check CUDA_PATH and/or GPU installation.");
            }
            if (H2O.getCloudSize() > 1) {
                error("_backend", "GPU backend is not supported in distributed mode.");
            }
            for (Map.Entry<String, Object> entry : parms.gpuIncompatibleParams().entrySet()) {
                error("_backend", "GPU backend is not available for parameter setting '" + entry.getKey() + " = " + entry.getValue() + "'. Use CPU backend instead.");
            }
        }

        if (parms._distribution == DistributionFamily.quasibinomial) {
            error("_distribution", "Quasibinomial is not supported for XGBoost in current H2O.");
        }
        if (parms._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.Enum) {
            error("_categorical_encoding", "Enum encoding is not supported for XGBoost in current H2O.");
        }

        // Distribution must match the response type; anything not listed is rejected.
        switch (parms._distribution) {
            case bernoulli:
                if (this._nclass != 2) {
                    error("_distribution", H2O.technote(2, "Binomial requires the response to be a 2-class categorical"));
                }
                break;
            case modified_huber:
                if (this._nclass != 2) {
                    error("_distribution", H2O.technote(2, "Modified Huber requires the response to be a 2-class categorical."));
                }
                break;
            case multinomial:
                if (!isClassifier()) {
                    error("_distribution", H2O.technote(2, "Multinomial requires an categorical response."));
                }
                break;
            case huber:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Huber requires the response to be numeric."));
                }
                break;
            case poisson:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Poisson requires the response to be numeric."));
                }
                break;
            case gamma:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Gamma requires the response to be numeric."));
                }
                break;
            case tweedie:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Tweedie requires the response to be numeric."));
                }
                break;
            case gaussian:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Gaussian requires the response to be numeric."));
                }
                break;
            case laplace:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Laplace requires the response to be numeric."));
                }
                break;
            case quantile:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Quantile requires the response to be numeric."));
                }
                break;
            case AUTO:
                break;
            default:
                error("_distribution", "Invalid distribution: " + parms._distribution);
                break;
        }

        if (0.0d >= parms._learn_rate || parms._learn_rate > 1.0d) {
            error("_learn_rate", "learn_rate must be between 0 and 1");
        }
        if (0.0d >= parms._col_sample_rate || parms._col_sample_rate > 1.0d) {
            error("_col_sample_rate", "col_sample_rate must be between 0 and 1");
        }
        if (parms._grow_policy == XGBoostModel.XGBoostParameters.GrowPolicy.lossguide && parms._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.hist) {
            error("_grow_policy", "must use tree_method=hist for grow_policy=lossguide");
        }

        if (this._train != null && !parms.monotoneConstraints().isEmpty()) {
            if (parms._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.approx) {
                error("_tree_method", "approx is not supported with _monotone_constraints, use auto/exact/hist instead");
            } else if (!$assertionsDisabled && parms._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.auto && parms._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.exact && parms._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.hist) {
                throw new AssertionError("Unexpected tree method used " + parms._tree_method);
            }
            TreeUtils.checkMonotoneConstraints(this, this._train, parms._monotone_constraints);
        }
        if (this._train != null && H2O.CLOUD.size() > 1 && parms._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.exact && !parms._build_tree_one_node) {
            error("_tree_method", "exact is not supported in distributed environment, set build_tree_one_node to true to use exact");
        }
        PlattScalingHelper.initCalibration(this, this._parms, z);
    }

    /* renamed from: getModelBuilder, reason: merged with bridge method [inline-methods] */
    /**
     * Returns this builder instance.
     * <p>
     * The {@code m8} prefix is a decompiler rename; the method was originally called
     * {@code getModelBuilder} and was merged with its bridge method.
     *
     * @return {@code this}
     */
    public XGBoost m8getModelBuilder() {
        return this;
    }

    /**
     * Returns the frame currently stored in the {@code _calib} field.
     * <p>
     * NOTE(review): presumably the calibration frame used by Platt scaling
     * (see {@code PlattScalingHelper.ModelBuilderWithCalibration}) - confirm against the interface.
     *
     * @return the value of {@code _calib}; may be {@code null} if it was never set
     */
    public Frame getCalibrationFrame() {
        return this._calib;
    }

    /**
     * Stores the given frame in the {@code _calib} field, replacing any previous value.
     *
     * @param frame the frame to store; no validation is performed here
     */
    public void setCalibrationFrame(Frame frame) {
        this._calib = frame;
    }

    /**
     * Always returns {@code true} for this builder.
     * <p>
     * NOTE(review): presumably overrides a superclass hook that tells the framework whether
     * missing values are usable by the algorithm - confirm against {@code ModelBuilder}.
     *
     * @return {@code true}
     */
    protected boolean canLearnFromNAs() {
        return true;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates the {@link DataInfo} for the given frames without any predictor/response
     * transformation ({@code TransformType.NONE}) and populates its coefficient names.
     * <p>
     * A {@code GLMTask.YMUTask} pass is always run over the adapted frame. Its results are only
     * applied when a weights column is set and no offset column is set: predictor sigma/mean are
     * updated, and the response sigma/mean as well when {@code i == 1}. When both weights and
     * offset columns are set, only a warning is logged.
     *
     * @param frame             first frame handed to the {@code DataInfo} constructor
     * @param frame2            second frame handed to the {@code DataInfo} constructor
     * @param xGBoostParameters parameters whose weights/offset/fold column settings are consulted
     * @param i                 value forwarded to {@code YMUTask}; {@code 1} additionally enables the
     *                          response statistics update (presumably the number of response
     *                          classes - TODO confirm)
     * @return the prepared {@code DataInfo}
     */
    public static DataInfo makeDataInfo(Frame frame, Frame frame2, XGBoostModel.XGBoostParameters xGBoostParameters, int i) {
        final boolean hasWeights = xGBoostParameters._weights_column != null;
        final boolean hasOffset = xGBoostParameters._offset_column != null;
        final boolean hasFold = xGBoostParameters._fold_column != null;

        DataInfo info = new DataInfo(frame, frame2, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, true, hasWeights, hasOffset, hasFold);
        GLMTask.YMUTask ymu = new GLMTask.YMUTask(info, i, i == 1, false, true, true).doAll(info._adaptedFrame);

        if (hasWeights) {
            if (hasOffset) {
                LOG.warn("Combination of offset and weights can lead to slight differences because Rollupstats aren't weighted - need to re-calculate weighted mean/sigma of the response including offset terms.");
            } else {
                info.updateWeightedSigmaAndMean(ymu.predictorSDs(), ymu.predictorMeans());
                if (i == 1) {
                    info.updateWeightedSigmaAndMeanForResponse(ymu.responseSDs(), ymu.responseMeans());
                }
            }
        }

        info.coefNames();
        // Equivalent of "assert info._coefNames != null" in the decompiled form.
        if (!$assertionsDisabled && info._coefNames == null) {
            throw new AssertionError();
        }
        return info;
    }

    /**
     * Rebalances a frame before training.
     * <p>
     * When {@code _build_tree_one_node} is not set, the call is delegated unchanged to the
     * superclass. Otherwise the frame is collapsed into a single chunk (via
     * {@link RebalanceDataSet}) stored under the key {@code str + ".1chk"}, and the new frame is
     * tracked in the current {@link Scope}. A frame that already has one chunk is returned as-is.
     *
     * @param frame frame to rebalance; in single-node mode a {@code null} frame yields {@code null}
     *              instead of a {@code NullPointerException}
     * @param z     forwarded to the superclass implementation only
     * @param str   name used to derive the new key; its last (up to) five characters are used in
     *              the log message
     * @return the rebalanced frame, the original frame if no rebalancing was needed, or
     *         {@code null} for a {@code null} input in single-node mode
     */
    protected Frame rebalance(Frame frame, boolean z, String str) {
        if (!((XGBoostModel.XGBoostParameters) this._parms)._build_tree_one_node) {
            return super.rebalance(frame, z, str);
        }
        // Nothing to rebalance; avoids an NPE on frame.anyVec() below.
        if (frame == null) {
            return null;
        }
        if (frame.anyVec().nChunks() == 1) {
            return frame;
        }
        // Clamp the start index so names shorter than five characters do not throw
        // StringIndexOutOfBoundsException just for building a log message.
        String suffix = str.substring(Math.max(0, str.length() - 5));
        LOG.info("Rebalancing " + suffix + " dataset onto a single node.");
        Key make = Key.make(str + ".1chk");
        H2O.submitTask(new RebalanceDataSet(frame, make, 1)).join();
        Frame frame2 = DKV.get(make).get();
        Scope.track(new Frame[]{frame2});
        return frame2;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds a variable-importance table by delegating to {@code ModelMetrics.calcVarImp}.
     * <p>
     * The table title is {@code "Variable Importances"}, followed by {@code " - " + str} when
     * {@code str} is not {@code null}. The column headers are fixed.
     *
     * @param str     optional title suffix; may be {@code null}
     * @param dArr    importance values passed through to {@code calcVarImp}
     * @param strArr  names passed through to {@code calcVarImp}
     * @return the table produced by {@code calcVarImp}
     */
    public static TwoDimTable createVarImpTable(String str, double[] dArr, String[] strArr) {
        String title = "Variable Importances";
        if (str != null) {
            title = title + " - " + str;
        }
        String[] columnHeaders = {"Relative Importance", "Scaled Importance", "Percentage"};
        return ModelMetrics.calcVarImp(dArr, strArr, title, columnHeaders);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Converts a feature-name to score map into parallel arrays wrapped in an {@code XgbVarImp}.
     * <p>
     * Array positions follow the iteration order of {@code map.entrySet()}.
     *
     * @param map feature scores keyed by feature name; must not be {@code null}
     * @return an {@code XgbVarImp} holding names, gains, covers and frequencies, or {@code null}
     *         when the map is empty
     */
    public static XgbVarImp computeVarImp(Map<String, FeatureScore> map) {
        if (map.isEmpty()) {
            return null;
        }
        final int featureCount = map.size();
        String[] names = new String[featureCount];
        float[] gains = new float[featureCount];
        float[] covers = new float[featureCount];
        int[] frequencies = new int[featureCount];

        int pos = 0;
        for (Map.Entry<String, FeatureScore> feature : map.entrySet()) {
            FeatureScore score = feature.getValue();
            gains[pos] = score._gain;
            covers[pos] = score._cover;
            frequencies[pos] = score._frequency;
            names[pos] = feature.getKey();
            pos++;
        }
        return new XgbVarImp(names, gains, covers, frequencies);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Checks whether the GPU with the given id is usable on the given node.
     * <p>
     * For the local node ({@code H2O.SELF}) the check runs in-process via {@link #hasGPU(int)}.
     * For any other node a {@code HasGPUTask} is sent over {@link RPC}, and the result is read
     * from its {@code _hasGPU} field once the call returns. The outcome is logged at debug level.
     *
     * @param h2ONode node to query
     * @param i       GPU id
     * @return {@code true} if the GPU was reported as available on that node
     */
    public static boolean hasGPU(H2ONode h2ONode, int i) {
        final boolean available;
        if (!H2O.SELF.equals(h2ONode)) {
            HasGPUTask remoteProbe = new HasGPUTask(i, null);
            new RPC(h2ONode, remoteProbe).call().get();
            available = remoteProbe._hasGPU;
        } else {
            available = hasGPU(i);
        }
        LOG.debug("Availability of GPU (id=" + i + ") on node " + h2ONode + ": " + available);
        return available;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Checks whether the GPU with the given id is usable on this node.
     * <p>
     * GPU id {@code 0} gets special treatment: once a probe of it fails,
     * {@code DEFAULT_GPU_BLACKLISTED} is set and later calls for id {@code 0} return
     * {@code false} without probing again.
     *
     * @param i GPU id
     * @return {@code true} if {@code hasGPU_impl} reports the GPU as available
     */
    public static boolean hasGPU(int i) {
        final boolean isDefaultGpu = (i == 0);
        if (isDefaultGpu && DEFAULT_GPU_BLACKLISTED) {
            return false;
        }
        final boolean available = hasGPU_impl(i);
        if (isDefaultGpu && !available) {
            // Remember the failure so the default GPU is not probed again.
            DEFAULT_GPU_BLACKLISTED = true;
        }
        return available;
    }

    /**
     * Probes whether the GPU with the given id can be used for XGBoost training on this node.
     * <p>
     * Returns {@code false} immediately when GPU support is disabled in {@code XGBoostExtension},
     * and {@code true} immediately when the id is already recorded in {@code GPUS}. Otherwise a
     * one-round training run on a tiny 2x2 matrix is attempted with the {@code grow_gpu_hist}
     * updater; on success the id is added to {@code GPUS}.
     * <p>
     * Rabit is initialised for the probe and shut down exactly once afterwards, on both the
     * success and the failure path.
     *
     * @param i GPU id to probe
     * @return {@code true} if the probe training run succeeded (or succeeded previously),
     *         {@code false} if GPU support is disabled or the run raised an {@code XGBoostError}
     * @throws IllegalStateException if the probe training matrix cannot be created
     */
    private static synchronized boolean hasGPU_impl(int i) {
        if (!XGBoostExtension.isGpuSupportEnabled()) {
            return false;
        }
        if (GPUS.contains(Integer.valueOf(i))) {
            return true;
        }

        // Failing to build the tiny probe matrix is not a "no GPU" signal, so it is reported
        // as an error rather than as an unavailable GPU.
        final DMatrix trainMat;
        try {
            trainMat = new DMatrix(new float[]{1.0f, 2.0f, 1.0f, 2.0f}, 2, 2);
            trainMat.setLabel(new float[]{1.0f, 0.0f});
        } catch (XGBoostError e) {
            throw new IllegalStateException("Couldn't prepare training matrix for XGBoost.", e);
        }

        Map<String, Object> params = new HashMap<>();
        params.put("updater", "grow_gpu_hist");
        params.put("silent", 1);
        params.put("gpu_id", Integer.valueOf(i));
        Map<String, DMatrix> watches = new HashMap<>();
        watches.put("train", trainMat);

        try {
            Rabit.init(new HashMap<String, String>());
            ml.dmlc.xgboost4j.java.XGBoost.train(trainMat, params, 1, watches, (IObjective) null, (IEvaluation) null);
            GPUS.add(Integer.valueOf(i));
            return true;
        } catch (XGBoostError e) {
            // Expected when the GPU cannot be used; Rabit shutdown is handled once in finally.
            return false;
        } finally {
            try {
                Rabit.shutdown();
            } catch (XGBoostError e) {
                LOG.warn("Cannot shutdown XGBoost Rabit for current thread.", e);
            }
        }
    }

    /**
     * Adjusts the main model's parameters from the results of the cross-validation models.
     * <p>
     * Does nothing when neither {@code _stopping_rounds} nor {@code _max_runtime_secs} is set.
     * Otherwise both are reset to zero, {@code _ntrees} is set to the (integer) average of the
     * {@code _ntrees} reported in the outputs of the given builders' models, and a warning is
     * recorded for each of the three changed parameters.
     *
     * @param modelBuilderArr builders of the cross-validation models; each model is looked up in
     *                        the DKV by the builder's {@code dest()} key
     */
    public void cv_computeAndSetOptimalParameters(ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>[] modelBuilderArr) {
        final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) this._parms;
        final boolean stoppingEnabled = parms._stopping_rounds != 0;
        final boolean runtimeLimited = parms._max_runtime_secs != 0.0d;
        if (!stoppingEnabled && !runtimeLimited) {
            return;
        }

        parms._stopping_rounds = 0;
        parms._max_runtime_secs = 0.0d;

        int totalTrees = 0;
        for (int k = 0; k < modelBuilderArr.length; k++) {
            totalTrees += ((XGBoostOutput) DKV.getGet(modelBuilderArr[k].dest())._output)._ntrees;
        }
        // Integer division: the average is truncated.
        parms._ntrees = totalTrees / modelBuilderArr.length;

        warn("_ntrees", "Setting optimal _ntrees to " + parms._ntrees + " for cross-validation main model based on early stopping of cross-validation models.");
        warn("_stopping_rounds", "Disabling convergence-based early stopping for cross-validation main model.");
        warn("_max_runtime_secs", "Disabling maximum allowed runtime for cross-validation main model.");
    }

    // Static state initialisation (emitted by the decompiler as an explicit initializer block).
    static {
        // Assertion checks in this class are skipped when the desired assertion status is false.
        $assertionsDisabled = !XGBoost.class.desiredAssertionStatus();
        LOG = Logger.getLogger(XGBoost.class);
        // Flipped to true by hasGPU(int) after a failed probe of GPU id 0, so that id is not probed again.
        DEFAULT_GPU_BLACKLISTED = false;
        // Ids of GPUs for which hasGPU_impl has completed a successful probe.
        GPUS = new HashSet();
    }
}
