package hex;

import au.com.bytecode.opencsv.CSVWriter;
import hex.Model;
import hex.Model.Output;
import hex.Model.Parameters;
import hex.ModelMetrics;
import hex.ModelTrainingEventsPublisher;
import hex.ScoreKeeper;
import hex.genmodel.MojoModel;
import hex.genmodel.utils.DistributionFamily;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.TreeSet;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import jsr166y.CountedCompleter;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.AutoBuffer;
import water.DKV;
import water.ExtensionManager;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.ListenerService;
import water.MRTask;
import water.Scope;
import water.Value;
import water.api.FSIOException;
import water.api.HDFSIOException;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.ChunkVisitor;
import water.fvec.FileVec;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.rapids.ast.prims.advmath.AstKFold;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.Countdown;
import water.util.FrameUtils;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.MRUtils;
import water.util.MathUtils;
import water.util.PrettyPrint;
import water.util.StringUtils;
import water.util.TwoDimTable;
import water.util.VecUtils;

/* loaded from: input_file:hex/ModelBuilder.class */
public abstract class ModelBuilder<M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> extends Iced {
    // Tracks temporary keys to clean up; every constructor in view initializes it to new Workspace(false).
    private transient Workspace _workspace;
    // Job driving this build; get() asserts that its result key equals _result.
    public Job<M> _job;
    // Destination key of the model being built.
    protected Key<M> _result;
    // Human-readable builder description used in log messages; constructors default it to "Main model".
    public String _desc;
    // Overall deadline, created by startClock() from _parms._max_runtime_secs.
    private Countdown _build_model_countdown;
    // Optional per-step deadline; when non-null, timeout() consults it instead of _build_model_countdown.
    private Countdown _build_step_countdown;
    // Registered algorithm names; appended by the startup-registration constructor, parallel to SCHEMAS/BUILDERS.
    private static String[] ALGOBASES;
    // Schema prefix registered per algorithm (e.g. "hex.schemas."), indexed like ALGOBASES.
    private static String[] SCHEMAS;
    // Prototype builder registered per algorithm; make() clones these. Indexed like ALGOBASES.
    private static ModelBuilder[] BUILDERS;
    // True only for prototype builders created by the startup-registration constructor.
    protected boolean _startUpOnceModelBuilder;
    // Working parameters used during training.
    public P _parms;
    // Clone of the parameters as supplied; attached to the model in Driver.computeParameters().
    public P _input_parms;
    // Training frame (see train()/setTrain()).
    protected transient Frame _train;
    // NOTE(review): presumably the training frame before adaptation; assigned outside this view — confirm.
    protected transient Frame _origTrain;
    // Validation frame (see valid()/setValid()).
    protected transient Frame _valid;
    // Lookup tables between algo names, builder classes and model classes; populated outside this view.
    private static final Map<String, Class<? extends ModelBuilder>> _builders;
    private static final Map<Class<? extends Model>, String> _model_class_to_algo;
    private static final Map<String, String> _algo_to_algo_full_name;
    private static final Map<String, Class<? extends Model>> _algo_to_model_class;
    // Set on each CV builder when the main model is trained in parallel with them (computeCrossValidation).
    protected transient ModelTrainingEventsPublisher _eventPublisher;
    // Coordinates main-model parameter updates with parallel CV builders.
    protected transient ModelBuilder<M, P, O>.ModelTrainingCoordinator _coordinator;
    // Special columns of the training data; response()/vresponse() expose the response vecs.
    protected transient Vec _response;
    // Validation response; vresponse() falls back to _response when this is null.
    protected transient Vec _vresponse;
    // NOTE(review): offset/weights/fold/treatment columns are assigned outside this view — confirm semantics there.
    protected transient Vec _offset;
    protected transient Vec _weights;
    protected transient Vec _fold;
    protected transient Vec _treatment;
    // NOTE(review): presumably original column names/domains before preprocessing; assigned outside this view.
    protected transient String[] _origNames;
    protected transient String[][] _origDomains;
    protected transient double[] _orig_projection_array;
    // NOTE(review): presumably the number of response classes; assigned outside this view.
    protected int _nclass;
    transient double[] _distribution;
    protected transient double[] _priorClassDist;
    // Validation messages collected for this builder; constructors start with an empty array.
    public ValidationMessage[] _messages;
    // Starts at -1 in every constructor here; NOTE(review): likely "not yet validated" — confirm in error_count().
    private int _error_count;
    // Names of the columns dropped by FilterCols.doIt().
    public transient HashSet<String> _removedCols;
    // Decompiler-emitted stand-in for the compiler's assertion flag; read by get() and computeCrossValidation().
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/ModelBuilder$ApproximatingHoldoutPredictionCombiner.class */
    /**
     * Holdout combiner that stores each combined prediction as a fixed-point decimal with
     * {@code _precision} fractional digits instead of as a raw double.
     */
    public static class ApproximatingHoldoutPredictionCombiner extends HoldoutPredictionCombiner {
        private final int _precision;

        public ApproximatingHoldoutPredictionCombiner(int folds, int cols, int precision) {
            super(folds, cols);
            _precision = precision;
        }

        @Override // hex.ModelBuilder.HoldoutPredictionCombiner
        protected void populateChunk(NewChunk newChunk, double[] dArr) {
            final long scale = PrettyPrint.pow10i(_precision);
            final int exponent = -_precision;
            for (int row = 0; row < dArr.length; row++) {
                final double value = dArr[row];
                if (!Double.isNaN(value)) {
                    // mantissa * 10^exponent approximates value to _precision decimal places
                    newChunk.addNum(Math.round(value * scale), exponent);
                } else {
                    newChunk.addNA();
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/ModelBuilder$Barrier.class */
    /**
     * Empty completer installed as the parent of a {@link Driver} by submitTrainModelTask(),
     * so that callers can {@code join()} on it to wait for the driver to finish.
     */
    public static class Barrier extends CountedCompleter {
        private Barrier() {
            super();
        }

        @Override // jsr166y.CountedCompleter
        public void compute() {
            // Intentionally empty: this completer only acts as a join point.
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$BuilderVisibility.class */
    /** Maturity level of an algorithm builder. */
    public enum BuilderVisibility {
        Experimental,
        Beta,
        Stable;

        /**
         * Case-insensitive counterpart of {@code valueOf}.
         *
         * @param str level name in any letter case
         * @return the matching constant
         * @throws IllegalArgumentException if no constant's name matches (including {@code null})
         */
        public static BuilderVisibility valueOfIgnoreCase(String str) throws IllegalArgumentException {
            final BuilderVisibility[] known = values();
            for (BuilderVisibility candidate : known) {
                if (candidate.name().equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException(String.format("Algorithm availability level of '%s' is not known. Available levels: %s", str, Arrays.toString(known)));
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:hex/ModelBuilder$Driver.class */
    /**
     * Completer that runs a single model build. It read-locks the input frames, delegates training
     * to {@link #computeImpl()}, attaches the input parameters to the result, optionally exports a
     * checkpoint and notifies listeners. Frame locks and the Scope are released exactly once.
     */
    public abstract class Driver extends H2O.H2OCountedCompleter<ModelBuilder<M, P, O>.Driver> {
        // Listener notified on success/failure; may be null.
        private ModelBuilderListener _callback;

        /* JADX INFO: Access modifiers changed from: protected */
        public Driver() {
        }

        /** Sets the listener invoked from onCompletion / onExceptionalCompletion. */
        public void setCallback(ModelBuilderListener modelBuilderListener) {
            this._callback = modelBuilderListener;
        }

        /**
         * Runs the build. Cleanup (frame unlock + Scope.exit) happens once in {@code finally}, and
         * {@code tryComplete()} is called only afterwards. tryComplete() may synchronously run
         * onCompletion (and thus the listener); an exception from there must not trigger a second
         * unlock/Scope.exit, which the previous catch-and-rethrow layout allowed.
         */
        @Override // water.H2O.H2OCountedCompleter
        public void compute2() {
            try {
                Scope.enter();
                ModelBuilder.this._parms.read_lock_frames(ModelBuilder.this._job);
                computeImpl();
                computeParameters();
                ModelBuilder.this.saveModelCheckpointIfConfigured();
                ModelBuilder.this.notifyModelListeners();
            } finally {
                ModelBuilder.this._parms.read_unlock_frames(ModelBuilder.this._job);
                if (ModelBuilder.this._parms._is_cv_model) {
                    // CV models do not call cleanUp() here; the workspace's to-delete keys are handed
                    // to Scope.exit (presumably so they survive until CV cleanup — confirm in Scope).
                    Scope.exit(cvWorkspaceKeys());
                } else {
                    ModelBuilder.this.cleanUp();
                    Scope.exit(new Key[0]);
                }
            }
            tryComplete();
        }

        /** Keys registered for deletion in the builder's workspace, or an empty array if there is no workspace. */
        private Key[] cvWorkspaceKeys() {
            if (ModelBuilder.this._workspace == null) {
                return new Key[0];
            }
            return (Key[]) ModelBuilder.this._workspace.getToDelete(true).keySet().toArray(new Key[0]);
        }

        @Override // jsr166y.CountedCompleter
        public void onCompletion(CountedCompleter countedCompleter) {
            ModelBuilder.this.setFinalState();
            if (this._callback != null) {
                this._callback.onModelSuccess(ModelBuilder.this._result.get());
            }
        }

        @Override // jsr166y.CountedCompleter
        public boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
            ModelBuilder.this.setFinalState();
            if (this._callback != null) {
                this._callback.onModelFailure(th, ModelBuilder.this._parms);
            }
            return true;
        }

        /** Algorithm-specific training body. */
        public abstract void computeImpl();

        /** Attaches the user-supplied input parameters to the built model (if any) under a write lock. */
        public final void computeParameters() {
            M m = ModelBuilder.this._result.get();
            if (m != null) {
                m.write_lock(ModelBuilder.this._job);
                m.setInputParms(ModelBuilder.this._input_parms);
                m.update(ModelBuilder.this._job);
                m.unlock(ModelBuilder.this._job);
            }
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$FilterCols.class */
    /**
     * Removes columns rejected by {@link #filter} from a frame, skipping the trailing
     * {@code _specialVecs} columns, records their names in {@code _removedCols}, and warns.
     */
    public abstract class FilterCols {
        final int _specialVecs;

        public FilterCols(int specialVecs) {
            _specialVecs = specialVecs;
        }

        /** @return true if the column should be removed */
        protected abstract boolean filter(Vec vec, String str);

        /**
         * @param frame frame to filter in place
         * @param str message prefix; the removed column names are appended to it
         * @param z when true, the message is also written with Log.info
         */
        public void doIt(Frame frame, String str, boolean z) {
            final ArrayList<Integer> dropped = new ArrayList<>();
            for (int col = 0; col < frame.vecs().length - _specialVecs; col++) {
                if (filter(frame.vec(col), frame._names[col])) {
                    dropped.add(col);
                }
            }
            if (dropped.isEmpty()) {
                return;
            }
            final int[] indices = new int[dropped.size()];
            ModelBuilder.this._removedCols = new HashSet<>(indices.length);
            for (int k = 0; k < indices.length; k++) {
                final int col = dropped.get(k);
                indices[k] = col;
                ModelBuilder.this._removedCols.add(frame._names[col]);
            }
            frame.remove(indices);
            final String msg = str + ModelBuilder.this._removedCols.toString();
            ModelBuilder.this.warn("_train", msg);
            if (z) {
                Log.info(msg);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/ModelBuilder$HoldoutPredictionCombiner.class */
    /**
     * Combines per-fold holdout prediction columns into one set of {@code _cols} output columns.
     * The input chunks are laid out fold-major: column {@code c} of fold {@code f} is at
     * {@code f * _cols + c}.
     */
    public static class HoldoutPredictionCombiner extends MRTask<HoldoutPredictionCombiner> {
        int _folds;
        int _cols;

        public HoldoutPredictionCombiner(int folds, int cols) {
            _folds = folds;
            _cols = cols;
        }

        @Override // water.MRTask
        public final void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            int col = 0;
            while (col < _cols) {
                final double[] combined = new double[chunkArr[0].len()];
                final ChunkVisitor.CombiningDoubleAryVisitor visitor = new ChunkVisitor.CombiningDoubleAryVisitor(combined);
                for (int fold = 0; fold < _folds; fold++) {
                    final Chunk source = chunkArr[fold * _cols + col];
                    source.processRows(visitor, 0, combined.length);
                    visitor.reset();
                }
                populateChunk(newChunkArr[col], combined);
                col++;
            }
        }

        /** Writes one combined column into the output chunk; subclasses may change the encoding. */
        protected void populateChunk(NewChunk newChunk, double[] dArr) {
            newChunk.setDoubles(dArr);
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$ModelTrainingCoordinator.class */
    /**
     * Consumes training events posted by the CV builders and lets the main builder refresh its
     * optimal parameters while CV models are still being trained.
     */
    public class ModelTrainingCoordinator {
        private final BlockingQueue<ModelTrainingEventsPublisher.Event> _events;
        private final ModelBuilder<M, P, O>[] _cvModelBuilders;
        // Number of CV builders that have not yet reported ALL_DONE.
        private int _inProgress;

        public ModelTrainingCoordinator(BlockingQueue<ModelTrainingEventsPublisher.Event> events, ModelBuilder<M, P, O>[] cvBuilders) {
            _events = events;
            _cvModelBuilders = cvBuilders;
            _inProgress = cvBuilders.length;
        }

        public void initStoppingParameters() {
            ModelBuilder.this.cv_initStoppingParameters();
        }

        /**
         * Blocks on the event queue until every CV builder finished, or until a ONE_DONE event lets
         * cv_updateOptimalParameters report success. Performs a final update once all are done.
         */
        public void updateParameters() {
            while (_inProgress > 0) {
                final ModelTrainingEventsPublisher.Event event;
                try {
                    event = _events.take();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Failed to update model parameters based on result of CV model training", e);
                }
                if (event == ModelTrainingEventsPublisher.Event.ALL_DONE) {
                    _inProgress--;
                } else if (event == ModelTrainingEventsPublisher.Event.ONE_DONE
                        && ModelBuilder.this.cv_updateOptimalParameters(_cvModelBuilders)) {
                    return;
                }
            }
            ModelBuilder.this.cv_updateOptimalParameters(_cvModelBuilders);
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$TrainModelNestedRunnable.class */
    /**
     * Remote task that recreates a builder for the given parameters on the H2O node and trains
     * it in nested (blocking) mode on the supplied frame.
     */
    private static class TrainModelNestedRunnable extends H2O.RemoteRunnable<TrainModelNestedRunnable> {
        private Job<?> _job;
        private Key<Model> _key;
        private Model.Parameters _parms;
        private Frame _fr;

        private TrainModelNestedRunnable(Job<?> job, Key<Model> key, Model.Parameters parameters, Frame frame) {
            this._job = job;
            this._key = key;
            this._parms = parameters;
            this._fr = frame;
        }

        @Override // water.H2O.RemoteRunnable
        public void run() {
            // The builder is raw here, so its parameter fields are typed Model.Parameters; the former
            // (P) casts referenced a type variable that is not in scope in this static nested class.
            ModelBuilder builder = ModelBuilder.make(this._parms.algoName(), this._job, this._key);
            builder._parms = this._parms;
            builder._input_parms = (Model.Parameters) this._parms.m829clone();
            builder.trainModelNested(this._fr);
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$TrainModelRunnable.class */
    /**
     * Remote task that rebuilds the given builder on the H2O node (setupOnRemote) and then
     * starts its asynchronous training (run).
     */
    private static class TrainModelRunnable extends H2O.RemoteRunnable<TrainModelRunnable> {
        private transient ModelBuilder _mb;
        private Job<Model> _job;
        private Key<Model> _key;
        private Model.Parameters _parms;
        private Model.Parameters _input_parms;

        private TrainModelRunnable(ModelBuilder modelBuilder) {
            this._mb = modelBuilder;
            this._job = this._mb._job;
            this._key = this._job._result;
            this._parms = this._mb._parms;
            this._input_parms = this._mb._input_parms;
        }

        @Override // water.H2O.RemoteRunnable
        public void setupOnRemote() {
            // _mb is transient, so it is recreated from the serialized fields; the raw builder's
            // parameter fields accept Model.Parameters directly (no out-of-scope (P) cast needed).
            this._mb = ModelBuilder.make(this._parms.algoName(), this._job, this._key);
            this._mb._parms = this._parms;
            this._mb._input_parms = this._input_parms;
            this._mb.init(false);
        }

        @Override // water.H2O.RemoteRunnable
        public void run() {
            this._mb.trainModel();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/ModelBuilder$TrainModelTaskController.class */
    /** Handle returned by submitTrainModelTask(): waits on the barrier or cancels the driver. */
    public class TrainModelTaskController {
        private final ModelBuilder<M, P, O>.Driver _driver;
        private final Barrier _barrier;

        TrainModelTaskController(ModelBuilder<M, P, O>.Driver driver, Barrier barrier) {
            _driver = driver;
            _barrier = barrier;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        /** Blocks until the barrier completer (the driver's parent) completes. */
        public void join() {
            final Barrier completion = _barrier;
            completion.join();
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        /** Cancels the underlying driver task. */
        public void cancel(boolean z) {
            final ModelBuilder<M, P, O>.Driver task = _driver;
            task.cancel(z);
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$ValidationMessage.class */
    public static final class ValidationMessage extends Iced {
        final byte _log_level;
        final String _field_name;
        final String _message;

        public ValidationMessage(byte b, String str, String str2) {
            this._log_level = b;
            this._field_name = str;
            this._message = str2;
            Log.log(b, str + ": " + str2);
        }

        public int log_level() {
            return this._log_level;
        }

        public String field() {
            return this._field_name;
        }

        public String message() {
            return this._message;
        }

        public String toString() {
            return Log.LVLS[this._log_level] + " on field: " + this._field_name + ": " + this._message;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/ModelBuilder$Workspace.class */
    /**
     * Holds keys created during the expensive build phase so they can be cleaned up later. The map
     * only exists when the workspace was created for the expensive phase.
     */
    public static class Workspace {
        private final IcedHashMap<Key, String> _toDelete;

        private Workspace(boolean expensive) {
            if (expensive) {
                _toDelete = new IcedHashMap<>();
            } else {
                _toDelete = null;
            }
        }

        /**
         * @param expensive whether the caller is in the expensive phase
         * @return the to-delete map, or null when {@code expensive} is false
         * @throws IllegalStateException if {@code expensive} is true but the map was never created
         */
        IcedHashMap<Key, String> getToDelete(boolean expensive) {
            if (expensive && _toDelete == null) {
                throw new IllegalStateException("ModelBuilder was not correctly initialized. Expensive phase requires field `_toDelete` to be non-null. Does your implementation of init method call super.init(true) or alternatively initWorkspace(true)?");
            }
            return expensive ? _toDelete : null;
        }

        /** Hands every still-existing registered key to the current Scope for tracking. */
        void cleanUp() {
            if (_toDelete == null) {
                return;
            }
            // Snapshot the keys before tracking them.
            final Key[] snapshot = (Key[]) _toDelete.keySet().toArray(new Key[0]);
            for (Key key : snapshot) {
                final Value value = DKV.get(key);
                if (value == null) {
                    continue;
                }
                if (value.isFrame()) {
                    Scope.track((Frame) value.get(Frame.class));
                } else if (value.isVec()) {
                    Scope.track((Vec) value.get(Vec.class));
                } else {
                    Scope.track_generic((Keyed) value.get(Keyed.class));
                }
            }
        }
    }

    /** @return the eigen-vector encoder for this builder; none by default */
    public ToEigenVec getToEigenVec() {
        final ToEigenVec none = null;
        return none;
    }

    /** @return true when the categorical encoding needs the response and the builder is supervised */
    public boolean shouldReorder(Vec vec) {
        final boolean encodingNeedsResponse = this._parms._categorical_encoding.needsResponse();
        return encodingNeedsResponse && isSupervised();
    }

    /** @return the model produced by the job (asserts the job targets this builder's result key) */
    public final M get() {
        if (!$assertionsDisabled && this._job._result != this._result) {
            throw new AssertionError();
        }
        return this._job.get();
    }

    /** @return whether the underlying job has stopped */
    public final boolean isStopped() {
        final Job<M> job = this._job;
        return job.isStopped();
    }

    /** @return the destination key of the model */
    public final Key<M> dest() {
        final Key<M> destination = this._result;
        return destination;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Starts the overall training countdown from {@code _max_runtime_secs}. */
    public final void startClock() {
        final Countdown countdown = Countdown.fromSeconds(this._parms._max_runtime_secs);
        this._build_model_countdown = countdown;
        countdown.start();
    }

    /** @return whether the active countdown (step countdown if set, else model countdown) expired */
    protected boolean timeout() {
        final Countdown active = this._build_step_countdown == null
                ? this._build_model_countdown
                : this._build_step_countdown;
        return active.timedOut();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /** @return true if the job was asked to stop or the time budget ran out */
    public boolean stop_requested() {
        if (this._job.stop_requested()) {
            return true;
        }
        return timeout();
    }

    /** @return remaining overall budget in whole seconds, rounded up */
    protected long remainingTimeSecs() {
        final double remainingMs = this._build_model_countdown.remainingTime();
        return (long) Math.ceil(remainingMs / 1000.0d);
    }

    /** @return a fresh model key built from the next unique model id for {@code str} */
    public static <S extends Model> Key<S> defaultKey(String str) {
        final String modelId = H2O.calcNextUniqueModelId(str);
        return Key.make(modelId);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public ModelBuilder(P p) {
        this(p, defaultKey(p.algoName()));
    }

    /** Creates a builder targeting {@code key} and a new Job for it. */
    protected ModelBuilder(P p, Key<M> key) {
        resetBuilderFieldDefaults();
        this._result = key;
        this._job = new Job<>(key, p.javaName(), p.algoName());
        this._parms = p;
        this._input_parms = (P) p.m829clone();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /** Creates a builder that runs under an existing job, with a freshly generated result key. */
    public ModelBuilder(P p, Job<M> job) {
        resetBuilderFieldDefaults();
        this._job = job;
        this._result = defaultKey(p.algoName());
        this._parms = p;
        this._input_parms = (P) p.m829clone();
    }

    /** @return the registered algorithm names */
    public static String[] algos() {
        final String[] names = ALGOBASES;
        return names;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public ModelBuilder(P p, boolean z) {
        this(p, z, "hex.schemas.");
    }

    /**
     * Startup-only registration constructor: records this instance as the prototype builder for
     * its algorithm name, together with the schema prefix {@code str}.
     */
    protected ModelBuilder(P p, boolean z, String str) {
        resetBuilderFieldDefaults();
        final String name = getName();
        if (!z) {
            throw H2O.fail("Algorithm " + name + " registration issue. It can only be called at startup.");
        }
        this._startUpOnceModelBuilder = true;
        this._job = null;
        this._result = null;
        this._parms = p;
        init(false);
        if (ArrayUtils.find(ALGOBASES, name) != -1) {
            throw H2O.fail("Only called once at startup per ModelBuilder, and " + name + " has already been called");
        }
        // Grow the three parallel registries by one slot each, then fill the new slot.
        ALGOBASES = Arrays.copyOf(ALGOBASES, ALGOBASES.length + 1);
        BUILDERS = Arrays.copyOf(BUILDERS, BUILDERS.length + 1);
        SCHEMAS = Arrays.copyOf(SCHEMAS, SCHEMAS.length + 1);
        ALGOBASES[ALGOBASES.length - 1] = name;
        BUILDERS[BUILDERS.length - 1] = this;
        SCHEMAS[SCHEMAS.length - 1] = str;
    }

    /** Instance field defaults shared by every constructor above (run before any other field is set). */
    private void resetBuilderFieldDefaults() {
        this._workspace = new Workspace(false);
        this._desc = "Main model";
        this._startUpOnceModelBuilder = false;
        this._messages = new ValidationMessage[0];
        this._error_count = -1;
        this._removedCols = new HashSet<>();
    }

    /** @return the algo name of the prototype registered under {@code str} */
    public static String algoName(String str) {
        final int idx = ArrayUtils.find(ALGOBASES, str);
        return BUILDERS[idx]._parms.algoName();
    }

    /** @return the Java name of the prototype registered under {@code str} */
    public static String javaName(String str) {
        final int idx = ArrayUtils.find(ALGOBASES, str);
        return BUILDERS[idx]._parms.javaName();
    }

    /** @return the parameters class name for {@code str} (algo name + "Parameters") */
    public static String paramName(String str) {
        final String algo = algoName(str);
        return algo + "Parameters";
    }

    /** @return the schema prefix registered for {@code str} */
    public static String schemaDirectory(String str) {
        final int idx = ArrayUtils.find(ALGOBASES, str);
        return SCHEMAS[idx];
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Looks up the registered prototype builder. The name is lower-cased with {@code Locale.ROOT}
     * (not the default locale, which would e.g. map 'I' to a dotless 'ı' under a Turkish locale)
     * before searching ALGOBASES.
     *
     * @param str algorithm name
     * @return the prototype builder, or empty if none is registered under that name
     */
    @SuppressWarnings("unchecked")
    public static <B extends ModelBuilder> Optional<B> getRegisteredBuilder(String str) {
        final int idx = ArrayUtils.find(ALGOBASES, str.toLowerCase(java.util.Locale.ROOT));
        if (idx < 0) {
            return Optional.empty();
        }
        // BUILDERS is a raw ModelBuilder[]; the cast is needed to produce Optional<B>.
        return Optional.of((B) BUILDERS[idx]);
    }

    /**
     * Creates a fresh, independent parameters object for the algorithm {@code str}.
     *
     * @throws IllegalStateException if the algorithm is not registered
     */
    @SuppressWarnings("unchecked")
    public static <P extends Model.Parameters> P makeParameters(String str) {
        // make() yields a raw builder whose _parms is typed Model.Parameters; cast back to P.
        return (P) make(str, null, null)._parms;
    }

    /**
     * Clones the prototype builder registered under {@code str} and binds it to {@code job} and
     * {@code key}. The clone gets its own copies of the parameters.
     *
     * @throws IllegalStateException if no builder is registered under {@code str}
     */
    @SuppressWarnings("unchecked")
    public static <B extends ModelBuilder> B make(String str, Job job, Key<Model> key) {
        final Optional<ModelBuilder> registered = getRegisteredBuilder(str);
        final ModelBuilder prototype = registered.orElseThrow(() -> {
            StringBuilder sb = new StringBuilder();
            sb.append("Unknown algo: '").append(str).append("'; Extension report: ");
            Log.err(ExtensionManager.getInstance().makeExtensionReport(sb));
            return new IllegalStateException("Algorithm '" + str + "' is not registered. Available algos: [" + StringUtils.join(",", ALGOBASES) + "]");
        });
        // The copy is raw, so its parameter fields are typed Model.Parameters (no type variable P here).
        final ModelBuilder copy = (ModelBuilder) prototype.m829clone();
        copy._job = job;
        copy._result = key;
        copy._parms = (Model.Parameters) copy._parms.m829clone();
        copy._input_parms = (Model.Parameters) copy._parms.m829clone();
        return (B) copy;
    }

    /** Creates a builder for {@code mp} with a freshly generated model key. */
    public static <B extends ModelBuilder, MP extends Model.Parameters> B make(MP mp) {
        return (B) make(mp, defaultKey(mp.algoName()));
    }

    /** Creates a builder for {@code mp} targeting {@code key}, with a new Job and cloned parameters. */
    public static <B extends ModelBuilder, MP extends Model.Parameters> B make(MP mp, Key<Model> key) {
        final B builder = make(mp.algoName(), new Job(key, mp.javaName(), mp.algoName()), key);
        builder._parms = (Model.Parameters) mp.m829clone();
        builder._input_parms = (Model.Parameters) mp.m829clone();
        return builder;
    }

    /** @return the training frame */
    public final Frame train() {
        final Frame frame = this._train;
        return frame;
    }

    /** Replaces the training frame. */
    public void setTrain(Frame frame) {
        _train = frame;
    }

    /** Replaces the validation frame. */
    public void setValid(Frame frame) {
        _valid = frame;
    }

    /** @return the validation frame (may be null) */
    public final Frame valid() {
        final Frame frame = this._valid;
        return frame;
    }

    /** @return the training response vec */
    public Vec response() {
        final Vec vec = this._response;
        return vec;
    }

    /** @return the validation response vec, falling back to the training response when unset */
    public Vec vresponse() {
        if (this._vresponse != null) {
            return this._vresponse;
        }
        return this._response;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Stamps the job on the model output, stops the output clock and saves the model under a
     * write lock, then logs completion. Does nothing if there is no destination key.
     */
    public void setFinalState() {
        final Key<M> destination = dest();
        if (destination == null) {
            return;
        }
        final M model = destination.get();
        final boolean hasOutput = model != null && model._output != null;
        if (hasOutput) {
            model._output._job = this._job;
            model._output.stopClock();
            model.write_lock(this._job);
            model.update(this._job);
            model.unlock(this._job);
        }
        Log.info("Completing model " + destination);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Exports the built model in binary form into {@code _export_checkpoints_dir} (file name = model
     * key) when that directory is set; I/O failures are rethrown as H2OIllegalArgumentException.
     */
    public void saveModelCheckpointIfConfigured() {
        final M model = this._result.get();
        if (model == null) {
            return;
        }
        final String checkpointDir = model._parms._export_checkpoints_dir;
        if (StringUtils.isNullOrEmpty(checkpointDir)) {
            return;
        }
        final String target = checkpointDir + "/" + model._key.toString();
        try {
            model.exportBinaryModel(target, true, new ModelExportOption[0]);
        } catch (IOException | FSIOException | HDFSIOException e) {
            throw new H2OIllegalArgumentException("export_checkpoints_dir", "saveModelIfConfigured", e);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /** Reports a "model_completed" event with the built model and its parameters. */
    public void notifyModelListeners() {
        final ListenerService listeners = ListenerService.getInstance();
        listeners.report("model_completed", this._result.get(), this._parms);
    }

    /**
     * Starts training of this builder on the H2O node via a TrainModelRunnable.
     *
     * @return this builder's job
     * @throws H2OModelBuilderIllegalArgumentException if validation recorded errors
     */
    public Job<M> trainModelOnH2ONode() {
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        this._input_parms = (P) this._parms.m829clone();
        // TrainModelRunnable only has a constructor taking the builder; it captures job, key and parms from it.
        H2O.runOnH2ONode(new TrainModelRunnable(this));
        return this._job;
    }

    /** Starts training without a completion listener. */
    public final Job<M> trainModel() {
        final ModelBuilderListener noListener = null;
        return trainModel(noListener);
    }

    /**
     * Starts training asynchronously: either a plain Driver, or, with n-fold CV, a task that runs
     * the full cross-validation (CV models plus main model).
     *
     * @param modelBuilderListener optional listener notified on success/failure (may be null)
     * @return the started job
     * @throws H2OModelBuilderIllegalArgumentException if validation recorded errors
     */
    public final Job<M> trainModel(final ModelBuilderListener modelBuilderListener) {
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        startClock();
        if (!nFoldCV()) {
            final ModelBuilder<M, P, O>.Driver driver = trainModelImpl();
            driver.setCallback(modelBuilderListener);
            return this._job.start(driver, this._parms.progressUnits(), this._parms._max_runtime_secs);
        }
        final H2O.H2OCountedCompleter cvTask = new H2O.H2OCountedCompleter() {
            @Override // water.H2O.H2OCountedCompleter
            public void compute2() {
                ModelBuilder.this.computeCrossValidation();
                tryComplete();
            }

            @Override // jsr166y.CountedCompleter
            public void onCompletion(CountedCompleter countedCompleter) {
                if (modelBuilderListener == null) {
                    return;
                }
                modelBuilderListener.onModelSuccess(ModelBuilder.this._result.get());
            }

            @Override // jsr166y.CountedCompleter
            public boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
                Log.warn("Model training job " + ModelBuilder.this._job._description + " completed with exception: " + th);
                if (modelBuilderListener != null) {
                    modelBuilderListener.onModelFailure(th, ModelBuilder.this._parms);
                }
                // Best-effort removal of the partial result; failures are only logged.
                try {
                    Keyed.remove(ModelBuilder.this._job._result);
                } catch (Exception e) {
                    Log.warn("Exception thrown when removing result from job " + ModelBuilder.this._job._description, e);
                }
                return true;
            }
        };
        return this._job.start(cvTask, (nFoldWork() + 1) * this._parms.progressUnits(), this._parms._max_runtime_secs);
    }

    /**
     * Trains synchronously in the caller's thread (CV or a single driver) and returns the model.
     *
     * @param frame optional replacement training frame (ignored when null)
     * @throws H2OModelBuilderIllegalArgumentException if validation recorded errors
     */
    public final M trainModelNested(Frame frame) {
        if (frame != null) {
            setTrain(frame);
        }
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        startClock();
        if (!nFoldCV()) {
            submitTrainModelTask().join();
        } else {
            computeCrossValidation();
        }
        return this._result.get();
    }

    /** Runs a nested training on the H2O node for {@code mp} and returns the model stored at {@code key}. */
    public static <MP extends Model.Parameters> Model trainModelNested(Job<?> job, Key<Model> key, MP mp, Frame frame) {
        final TrainModelNestedRunnable nested = new TrainModelNestedRunnable(job, key, mp, frame);
        H2O.runOnH2ONode(nested);
        return key.get();
    }

    /** @return a new, unstarted Driver implementing this algorithm's training */
    protected abstract ModelBuilder<M, P, O>.Driver trainModelImpl();

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Submits a fresh Driver to the H2O task pool with a Barrier as its completer.
     *
     * @return a controller that can join on the barrier or cancel the driver
     */
    public ModelBuilder<M, P, O>.TrainModelTaskController submitTrainModelTask() {
        final ModelBuilder<M, P, O>.Driver driver = trainModelImpl();
        final Barrier completion = new Barrier();
        driver.setCompleter(completion);
        H2O.submitTask(driver);
        return new TrainModelTaskController(driver, completion);
    }

    /** Legacy override hook; a positive value forces the parallelism. Returns 0 ("not set") by default. */
    @Deprecated
    protected int nModelsInParallel() {
        return 0;
    }

    /** Parallelism for {@code folds} CV models: the legacy override if positive, else the default policy with fallback 1. */
    protected int nModelsInParallel(int folds) {
        final int legacy = nModelsInParallel();
        if (legacy > 0) {
            return legacy;
        }
        return nModelsInParallel(folds, 1);
    }

    /**
     * @param parallelizationOnSmallData parallelism used when the training frame is small
     * @param defaultParallelization parallelism used otherwise
     * @return 1 if CV parallelization is disabled, else the chosen value capped at the node's thread count
     */
    protected int nModelsInParallel(int parallelizationOnSmallData, int defaultParallelization) {
        if (!this._parms._parallelize_cross_validation) {
            return 1;
        }
        final boolean smallData = this._train.byteSize() < smallDataSize();
        final int requested = smallData ? parallelizationOnSmallData : defaultParallelization;
        return Math.min(requested, (int) H2O.ARGS.nthreads);
    }

    /** Byte-size threshold below which the training frame counts as small. */
    protected long smallDataSize() {
        final long oneMillionBytes = 1_000_000L;
        return oneMillionBytes;
    }

    /**
     * Splits the overall {@code _max_runtime_secs} budget across the rounds needed to train
     * {@code cvModelsCount} CV models with {@code parallelization} at a time, plus one round for
     * the main model.
     *
     * @param cvModelsCount number of CV models; when not positive, the whole budget is returned
     * @param parallelization number of models trained concurrently
     * @return seconds available per model
     */
    private double maxRuntimeSecsPerModel(int cvModelsCount, int parallelization) {
        if (cvModelsCount <= 0) {
            return this._parms._max_runtime_secs;
        }
        // Floating-point division: int division truncated before ceil (e.g. 5/2 -> 2 -> 3 rounds
        // instead of 4), over-allotting time to each model.
        final double rounds = Math.ceil((double) cvModelsCount / parallelization + 1.0d);
        return this._parms._max_runtime_secs / rounds;
    }

    /** @return the number of CV folds: {@code _nfolds}, or the fold count derived from the fold column when one is set */
    protected int nFoldWork() {
        if (this._parms._fold_column == null) {
            return this._parms._nfolds;
        }
        final Frame training = this._parms._train.get();
        return FoldAssignment.nFoldWork(training.vec(this._parms._fold_column));
    }

    /**
     * Runs N-fold cross-validation and builds the main model.
     * <p>
     * The steps are:
     * <ol>
     *   <li>Assign rows to folds and derive 2*N weight vectors.</li>
     *   <li>Create and validate one sub-builder per fold.</li>
     *   <li>Train the CV models, optionally in parallel with the main model.</li>
     *   <li>Score the CV models on their holdout folds.</li>
     *   <li>Build the main model if it was not trained in parallel.</li>
     *   <li>Attach the combined CV metrics to the main model.</li>
     * </ol>
     * <p>
     * On an {@link Exception}, the per-fold frames, holdout predictions and CV models are removed
     * before rethrowing. Cleanup of the sub-builders and of this builder, plus {@code Scope.exit},
     * runs exactly once on every path. The previous decompiled form could run them twice if cleanup
     * itself failed.
     */
    public void computeCrossValidation() {
        if (!$assertionsDisabled && !this._job.isRunning()) {
            throw new AssertionError();
        }
        // Keep the job hidden from viewers until the CV results are attached below.
        this._job.setReadyForView(false);
        final int nFoldWork = nFoldWork();
        ModelBuilder<M, P, O>[] cvBuilders = null;
        try {
            Scope.enter();
            init(false);
            final Vec[] weights = cv_makeWeights(nFoldWork, cv_AssignFold(nFoldWork));
            cvBuilders = cv_makeFramesAndBuilders(nFoldWork, weights);
            final boolean buildMainModelAfterCV;
            if (useParallelMainModelBuilding(nFoldWork)) {
                final int nModelsInParallel = nModelsInParallel(nFoldWork);
                Log.info(this._desc + " will be trained in parallel to the Cross-Validation models (up to " + nModelsInParallel + " models running at the same time).");
                LinkedBlockingQueue eventQueue = new LinkedBlockingQueue();
                for (ModelBuilder<M, P, O> cvBuilder : cvBuilders) {
                    cvBuilder._eventPublisher = new ModelTrainingEventsPublisher(eventQueue);
                }
                this._coordinator = new ModelTrainingCoordinator(eventQueue, cvBuilders);
                // The main model (this builder) is trained in the same batch as the CV models.
                ModelBuilder[] allBuilders = (ModelBuilder[]) Arrays.copyOf(cvBuilders, cvBuilders.length + 1);
                allBuilders[allBuilders.length - 1] = this;
                new SubModelBuilder(this._job, allBuilders, nModelsInParallel).bulkBuildModels();
                buildMainModelAfterCV = false;
            } else {
                cv_buildModels(nFoldWork, cvBuilders);
                buildMainModelAfterCV = true;
            }
            ModelMetrics.MetricBuilder[] cvMetricBuilders = cv_scoreCVModels(nFoldWork, weights, cvBuilders);
            if (buildMainModelAfterCV) {
                buildMainModel((long) (maxRuntimeSecsPerModel(nFoldWork, nModelsInParallel(nFoldWork)) * 1000.0d));
            }
            // Infogram CV models are not scored (see makeCVMetrics), so there is nothing to combine.
            if (!cvBuilders[0].getName().equals("infogram")) {
                cv_mainModelScores(nFoldWork, cvMetricBuilders, cvBuilders);
            }
            this._job.setReadyForView(true);
            DKV.put(this._job);
        } catch (Exception e) {
            if (cvBuilders != null) {
                // The main model is incomplete, so drop everything the CV stage put into DKV.
                Futures futures = new Futures();
                for (ModelBuilder<M, P, O> cvBuilder : cvBuilders) {
                    DKV.remove(cvBuilder._parms._train, futures);
                    DKV.remove(cvBuilder._parms._valid, futures);
                    DKV.remove(Key.make(cvBuilder.getPredictionKey()), futures);
                    Keyed.remove(cvBuilder._result, futures, true);
                }
                futures.blockForPending();
            }
            throw e;
        } finally {
            if (cvBuilders != null) {
                for (ModelBuilder<M, P, O> cvBuilder : cvBuilders) {
                    cvBuilder.cleanUp();
                }
            }
            cleanUp();
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Assigns every training row to one of {@code i} folds.
     * <p>
     * A user-specified fold column takes precedence. Otherwise an internal fold column is generated
     * according to {@code _fold_assignment}: random (also for AUTO), modulo, or stratified on the
     * response.
     *
     * @param i number of folds (must be at least 2)
     * @return the fold assignment for the training frame
     */
    FoldAssignment cv_AssignFold(int i) {
        if (!$assertionsDisabled && i < 2) {
            throw new AssertionError();
        }
        final Vec userFold = train().vec(this._parms._fold_column);
        if (userFold != null) {
            return FoldAssignment.fromUserFoldSpecification(i, userFold);
        }
        final long seed = this._parms.getOrMakeRealSeed();
        Log.info("Creating " + i + " cross-validation splits with random number seed: " + seed);
        switch (this._parms._fold_assignment) {
            case AUTO:
            case Random:
                return FoldAssignment.fromInternalFold(i, AstKFold.kfoldColumn(train().anyVec().makeZero(), i, seed));
            case Modulo:
                return FoldAssignment.fromInternalFold(i, AstKFold.moduloKfoldColumn(train().anyVec().makeZero(), i));
            case Stratified:
                return FoldAssignment.fromInternalFold(i, AstKFold.stratifiedKFoldColumn(response(), i, seed));
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Builds 2*{@code i} weight vectors from a fold assignment.
     * <p>
     * For fold {@code k}, vector {@code 2*k} carries the row weight for rows outside fold {@code k}
     * and 0 inside it; this is the training weight. Vector {@code 2*k+1} is the complement: the
     * row weight inside fold {@code k} and 0 elsewhere; this is the holdout weight. Row weights come
     * from the user's weights column, or are 1 when there is none.
     * <p>
     * The fold assignment is optionally saved as {@code cv_fold_assignment_<model key>} and then
     * released.
     * <p>
     * The temporary all-ones weight Vec is removed even if the weight computation fails.
     *
     * @param i               number of folds
     * @param foldAssignment  per-row fold ids
     * @return the 2*i weight vectors
     * @throws H2OIllegalArgumentException if any resulting weight vector is constant, which means
     *         there is not enough data for the requested folds
     */
    Vec[] cv_makeWeights(final int i, FoldAssignment foldAssignment) {
        final String weightsColumn = this._parms._weights_column;
        final boolean useTempWeights = weightsColumn == null;
        final Vec rowWeights = useTempWeights ? train().anyVec().makeCon(1.0d) : train().vec(weightsColumn);
        final Vec[] vecs;
        try {
            vecs = new MRTask() {
                @Override
                public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
                    final Chunk foldChunk = chunkArr[0];
                    final Chunk weightChunk = chunkArr[1];
                    for (int row = 0; row < weightChunk._len; row++) {
                        final int fold = (int) foldChunk.atd(row);
                        final double weight = weightChunk.atd(row);
                        for (int k = 0; k < i; k++) {
                            final boolean inHoldout = fold == k;
                            newChunkArr[2 * k].addNum(inHoldout ? 0.0d : weight);
                            newChunkArr[(2 * k) + 1].addNum(inHoldout ? weight : 0.0d);
                        }
                    }
                }
            }.doAll(2 * i, (byte) 3, new Frame(foldAssignment.getAdaptedFold(), rowWeights)).outputFrame().vecs();
        } finally {
            if (useTempWeights) {
                rowWeights.remove();
            }
        }
        if (this._parms._keep_cross_validation_fold_assignment) {
            DKV.put(foldAssignment.toFrame(Key.make("cv_fold_assignment_" + this._result.toString())));
        }
        foldAssignment.remove(this._parms._keep_cross_validation_fold_assignment);
        for (Vec weightVec : vecs) {
            if (weightVec.isConst()) {
                throw new H2OIllegalArgumentException("Not enough data to create " + i + " random cross-validation splits. Either reduce nfolds, specify a larger dataset (or specify another random number seed, if applicable).");
            }
        }
        return vecs;
    }

    /**
     * Builds one train/validation frame pair and one configured, validated sub-builder per fold.
     * <p>
     * Each fold's frames share the training frame's Vecs (minus the user weights column, if any). They
     * also get a {@code __internal_cv_weights__} column: {@code vecArr[2*k]} for training and
     * {@code vecArr[2*k+1]} for validation.
     * <p>
     * Each sub-builder is a clone of this builder (via the decompiled {@code m829clone()}, presumably
     * {@code clone()}) with cloned parameters adjusted for a single CV model.
     *
     * @param i      number of folds
     * @param vecArr the 2*i weight vectors produced by {@code cv_makeWeights}
     * @return the i sub-builders, indexed by fold
     * @throws H2OIllegalArgumentException if the training frame already has a
     *         {@code __internal_cv_weights__} column
     * @throws H2OModelBuilderIllegalArgumentException if this builder has validation errors after the
     *         loop, including errors copied from failed sub-builders
     */
    private ModelBuilder<M, P, O>[] cv_makeFramesAndBuilders(int i, Vec[] vecArr) {
        // Snapshot used by the assertion at the end: building sub-builders must not mutate _parms.
        long checksum = this._parms.checksum();
        String key = this._result.toString();
        if (train().find("__internal_cv_weights__") != -1) {
            throw new H2OIllegalArgumentException("Frame cannot contain a Vec called '__internal_cv_weights__'.");
        }
        // Shallow copy of the training frame (same Vecs).
        Frame frame = new Frame(train().names(), train().vecs());
        if (this._parms._weights_column != null) {
            // The user weights were already multiplied into vecArr, so the original column is dropped.
            frame.remove(this._parms._weights_column);
        }
        ModelBuilder<M, P, O>[] modelBuilderArr = new ModelBuilder[i];
        // Frames of sub-builders that fail validation, removed if the whole step fails.
        ArrayList<Frame> arrayList = new ArrayList();
        double maxRuntimeSecsPerModel = maxRuntimeSecsPerModel(i, nModelsInParallel(i));
        for (int i2 = 0; i2 < i; i2++) {
            String str = key + "_cv_" + (i2 + 1);
            // Training frame for fold i2, written under the job's write lock.
            // NOTE(review): the lock is not released in this method; presumably released in cleanUp() - confirm.
            Frame frame2 = new Frame(Key.make(str + "_train"), frame.names(), frame.vecs());
            frame2.write_lock(this._job);
            frame2.add("__internal_cv_weights__", vecArr[2 * i2]);
            frame2.update(this._job);
            // Validation (holdout) frame for fold i2.
            Frame frame3 = new Frame(Key.make(str + "_valid"), frame.names(), frame.vecs());
            frame3.write_lock(this._job);
            frame3.add("__internal_cv_weights__", vecArr[(2 * i2) + 1]);
            frame3.update(this._job);
            ModelBuilder<M, P, O> modelBuilder = (ModelBuilder) m829clone();
            modelBuilder.setTrain(frame2);
            modelBuilder._result = Key.make(str);
            // Parameters of a single CV model: internal weights, explicit train/valid keys, no nested CV.
            modelBuilder._parms = (P) this._parms.m829clone();
            modelBuilder._parms._is_cv_model = true;
            modelBuilder._parms._cv_fold = i2;
            modelBuilder._parms._weights_column = "__internal_cv_weights__";
            modelBuilder._parms.setTrain(frame2._key);
            modelBuilder._parms._valid = frame3._key;
            modelBuilder._parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO;
            modelBuilder._parms._nfolds = 0;
            modelBuilder._parms._max_runtime_secs = maxRuntimeSecsPerModel;
            modelBuilder.clearValidationErrors();
            // _input_parms is a copy of the main builder's (unmodified) parameters.
            modelBuilder._input_parms = (P) this._parms.m829clone();
            modelBuilder._desc = "Cross-Validation model " + (i2 + 1) + " / " + i;
            modelBuilder.init(false);
            if (modelBuilder.error_count() > 0) {
                Log.info("Marking frame for failed cv model for removal: " + frame2._key);
                arrayList.add(frame2);
                Log.info("Marking frame for failed cv model for removal: " + frame3._key);
                arrayList.add(frame3);
                // Surface the sub-builder's messages on this builder; message() counts level-1 entries
                // towards this builder's error count.
                for (ValidationMessage validationMessage : modelBuilder._messages) {
                    message(validationMessage._log_level, validationMessage._field_name, validationMessage._message);
                }
            }
            modelBuilderArr[i2] = modelBuilder;
        }
        if (error_count() <= 0) {
            if ($assertionsDisabled || checksum == this._parms.checksum()) {
                return modelBuilderArr;
            }
            throw new AssertionError();
        }
        // Validation failed: remove the failed folds' weight Vecs and frames, then report all errors at once.
        // NOTE(review): frames of folds that validated successfully are not removed here, and the caller
        // (computeCrossValidation) never receives the builder array when this throws - possible DKV leak, confirm.
        Futures futures = new Futures();
        for (Frame frame4 : arrayList) {
            frame4.vec("__internal_cv_weights__").remove(futures);
            DKV.remove(frame4._key, futures);
            Log.info("Removing frame for failed cv model: " + frame4._key);
        }
        futures.blockForPending();
        throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
    }

    /**
     * Trains all CV sub-builders, at most {@code nModelsInParallel(i)} at a time. Afterwards it lets
     * the builder derive optimal parameters from the trained CV models.
     *
     * @param i               number of folds
     * @param modelBuilderArr the CV sub-builders
     */
    public void cv_buildModels(int i, ModelBuilder<M, P, O>[] modelBuilderArr) {
        final int parallelism = nModelsInParallel(i);
        final CVModelBuilder cvBuilder = makeCVModelBuilder(modelBuilderArr, parallelism);
        cvBuilder.bulkBuildModels();
        cv_computeAndSetOptimalParameters(modelBuilderArr);
    }

    /**
     * Factory for the driver that trains the CV sub-builders. Subclasses may override it to plug in a
     * custom driver.
     *
     * @param builders    the CV sub-builders to train
     * @param parallelism maximum number of models trained at the same time
     * @return a driver bound to this builder's job
     */
    protected CVModelBuilder makeCVModelBuilder(ModelBuilder<?, ?, ?>[] builders, int parallelism) {
        final CVModelBuilder driver = new CVModelBuilder(this._job, builders, parallelism);
        return driver;
    }

    /**
     * Scores each cross-validation model on its holdout frame and returns one metric builder per fold.
     * <p>
     * Holdout predictions are materialized under the sub-builder's prediction key in three cases: for
     * binomial problems, when CV predictions are kept, or for a Huber distribution. Otherwise metrics
     * are computed directly.
     * <p>
     * After a fold is scored, its train/valid frame keys and both of its weight vectors are removed.
     * The {@code Scope.Safe} guard is closed even when scoring fails; the decompiled version left it
     * open on exceptions.
     *
     * @param i               number of folds
     * @param vecArr          2*i weight vectors (training weights at 2*k, holdout weights at 2*k+1)
     * @param modelBuilderArr the i CV sub-builders
     * @return metric builders indexed by fold. An entry stays {@code null} when
     *         {@link #makeCVMetrics} returns false for that fold's builder.
     * @throws Job.JobCancelledException if the job is asked to stop before or during scoring
     */
    public ModelMetrics.MetricBuilder[] cv_scoreCVModels(int i, Vec[] vecArr, ModelBuilder<M, P, O>[] modelBuilderArr) {
        if (this._job.stop_requested()) {
            Log.info("Skipping scoring of CV models");
            throw new Job.JobCancelledException(this._job);
        }
        if (!$assertionsDisabled && vecArr.length != 2 * i) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && modelBuilderArr.length != i) {
            throw new AssertionError();
        }
        Log.info("Scoring the " + i + " CV models");
        ModelMetrics.MetricBuilder[] metricBuilderArr = new ModelMetrics.MetricBuilder[i];
        Futures futures = new Futures();
        for (int i2 = 0; i2 < i; i2++) {
            if (this._job.stop_requested()) {
                Log.info("Skipping scoring for last " + (i - i2) + " out of " + i + " CV models");
                throw new Job.JobCancelledException(this._job);
            }
            ModelBuilder<M, P, O> cvBuilder = modelBuilderArr[i2];
            Frame valid = cvBuilder.valid();
            Frame preds = null;
            try (Scope.Safe ignored = Scope.safe(valid)) {
                Frame adapted = new Frame(valid);
                if (makeCVMetrics(cvBuilder)) {
                    M m = cvBuilder.dest().get();
                    m.adaptTestForTrain(adapted, true, !isSupervised());
                    if (nclasses() == 2 || this._parms._keep_cross_validation_predictions || m.isDistributionHuber()) {
                        Model<M, P, O>.PredictScoreResult result = m.predictScoreImpl(valid, adapted, cvBuilder.getPredictionKey(), this._job, true, CFuncRef.from(this._parms._custom_metric_func));
                        preds = result.getPredictions();
                        // Untracked here so the Safe scope leaves the predictions alone (presumably);
                        // they are tracked again in the finally block.
                        Scope.untrack(preds);
                        result.makeModelMetrics(valid, adapted);
                        metricBuilderArr[i2] = result.getMetricBuilder();
                        DKV.put(m);
                    } else {
                        metricBuilderArr[i2] = m.scoreMetrics(adapted);
                    }
                }
            } finally {
                Scope.track(preds);
            }
            DKV.remove(cvBuilder._parms._train, futures);
            DKV.remove(cvBuilder._parms._valid, futures);
            vecArr[2 * i2].remove(futures);
            vecArr[(2 * i2) + 1].remove(futures);
        }
        futures.blockForPending();
        return metricBuilderArr;
    }

    /**
     * Whether the given CV sub-builder's model should be scored on its holdout frame. By default this
     * is every algorithm except "infogram".
     */
    protected boolean makeCVMetrics(ModelBuilder<?, ?, ?> modelBuilder) {
        final boolean isInfogram = modelBuilder.getName().equals("infogram");
        return !isInfogram;
    }

    /**
     * Whether the main model is trained in the same batch as the CV models.
     * <p>
     * This requires all of the following: more than one model in parallel, parallel CV enabled in
     * the parameters, and builder support.
     */
    private boolean useParallelMainModelBuilding(int i) {
        if (nModelsInParallel(i) <= 1) {
            return false;
        }
        if (!this._parms._parallelize_cross_validation) {
            return false;
        }
        return cv_canBuildMainModelInParallel();
    }

    /**
     * Hook: whether this algorithm supports training the main model concurrently with its CV models.
     *
     * @return {@code false} by default
     */
    protected boolean cv_canBuildMainModelInParallel() {
        final boolean supported = false;
        return supported;
    }

    /**
     * Hook for algorithms that refine their parameters from the CV sub-builders.
     *
     * @throws UnsupportedOperationException unless overridden
     */
    protected boolean cv_updateOptimalParameters(ModelBuilder<M, P, O>[] modelBuilderArr) {
        // Not supported by the base builder.
        throw new UnsupportedOperationException();
    }

    /**
     * Hook for algorithms that initialize early-stopping parameters for CV.
     *
     * @throws UnsupportedOperationException unless overridden
     */
    protected boolean cv_initStoppingParameters() {
        // Not supported by the base builder.
        throw new UnsupportedOperationException();
    }

    /**
     * Trains the main model after the CV models, within a per-step time budget.
     * <p>
     * The step countdown is cleared even when training fails, so no stale budget is left behind.
     *
     * @param j time budget for the main model, in milliseconds
     * @throws Job.JobCancelledException if the job was asked to stop
     */
    private void buildMainModel(long j) {
        if (this._job.stop_requested()) {
            Log.info("Skipping main model");
            throw new Job.JobCancelledException(this._job);
        }
        if (!$assertionsDisabled && !this._job.isRunning()) {
            throw new AssertionError();
        }
        Log.info("Building main model.");
        Log.info("Remaining time for main model (ms): " + j);
        this._build_step_countdown = new Countdown(j, true);
        try {
            submitTrainModelTask().join();
        } finally {
            this._build_step_countdown = null;
        }
    }

    /**
     * Attaches the cross-validation results to the main model ({@code _result}) and stores it in DKV.
     * <p>
     * The method performs these steps:
     * <ul>
     *   <li>Combine the per-fold metric builders into fold 0's builder.</li>
     *   <li>Optionally build the combined holdout-prediction frame.</li>
     *   <li>Compute the CV metrics and the CV summary table.</li>
     *   <li>Copy each CV model's scoring history.</li>
     *   <li>Remove the CV predictions and CV models the user did not ask to keep.</li>
     * </ul>
     *
     * @param i                number of folds
     * @param metricBuilderArr per-fold metric builders; element 0 is mutated to hold the aggregate
     * @param modelBuilderArr  the CV sub-builders (their {@code _result} keys are the CV models)
     */
    public void cv_mainModelScores(int i, ModelMetrics.MetricBuilder[] metricBuilderArr, ModelBuilder<M, P, O>[] modelBuilderArr) {
        M m = this._result.get();
        Log.info("Computing " + i + "-fold cross-validation metrics.");
        // The key arrays are published by reference and filled by the loop below.
        Key[] keyArr = new Key[i];
        m._output._cross_validation_models = this._parms._keep_cross_validation_models ? keyArr : null;
        Key<Frame>[] keyArr2 = new Key[i];
        m._output._cross_validation_predictions = this._parms._keep_cross_validation_predictions ? keyArr2 : null;
        for (int i2 = 0; i2 < i; i2++) {
            keyArr[i2] = modelBuilderArr[i2]._result;
            keyArr2[i2] = Key.make(modelBuilderArr[i2].getPredictionKey());
        }
        cv_makeAggregateModelMetrics(metricBuilderArr);
        Frame frame = null;
        // Same condition under which cv_scoreCVModels materializes per-fold holdout predictions.
        if (this._parms._keep_cross_validation_predictions || nclasses() == 2 || m.isDistributionHuber()) {
            Key<Frame> make = Key.make("cv_holdout_prediction_" + m._key.toString());
            if (this._parms._keep_cross_validation_predictions) {
                m._output._cross_validation_holdout_predictions_frame_id = make;
            }
            frame = combineHoldoutPredictions(keyArr2, make);
        }
        if (this._parms._keep_cross_validation_fold_assignment) {
            // Key matches the fold-assignment frame saved by cv_makeWeights; untracked so Scope exit keeps it.
            m._output._cross_validation_fold_assignment_frame_id = Key.make("cv_fold_assignment_" + this._result.toString());
            Frame frame2 = m._output._cross_validation_fold_assignment_frame_id.get();
            if (frame2 != null) {
                Scope.untrack(frame2.keysList());
            }
        }
        if (this._parms._keep_cross_validation_predictions) {
            for (Key<Frame> key : keyArr2) {
                Frame frame3 = (Frame) DKV.getGet(key);
                if (frame3 != null) {
                    Scope.untrack(frame3);
                }
            }
        } else {
            Log.info(Model.deleteAll(keyArr2) + " CV predictions were removed");
        }
        // CV metrics come from the aggregated builder (element 0) plus the combined holdout frame, if any.
        m._output._cross_validation_metrics = metricBuilderArr[0].makeModelMetrics(m, this._parms.train(), null, frame);
        if (frame != null) {
            if (this._parms._keep_cross_validation_predictions) {
                Scope.untrack(frame);
            } else {
                frame.remove();
            }
        }
        m._output._cross_validation_metrics._description = i + "-fold cross-validation on training data (Metrics computed for combined holdout predictions)";
        Log.info(m._output._cross_validation_metrics.toString());
        m._output._cross_validation_metrics_summary = makeCrossValidationSummaryTable(keyArr);
        // Cell-by-cell copy of each CV model's scoring history. This must happen before the CV models are deleted below.
        // NOTE(review): assumes every CV model has a non-null _scoring_history whenever the main model has one - confirm.
        if (m._output._scoring_history != null) {
            m._output._cv_scoring_history = new TwoDimTable[keyArr.length];
            for (int i3 = 0; i3 < keyArr.length; i3++) {
                TwoDimTable twoDimTable = ((Model) keyArr[i3].get())._output._scoring_history;
                String[] rowHeaders = twoDimTable.getRowHeaders();
                String[] colTypes = twoDimTable.getColTypes();
                int length = rowHeaders.length;
                int length2 = colTypes.length;
                TwoDimTable twoDimTable2 = new TwoDimTable(twoDimTable.getTableHeader(), twoDimTable.getTableDescription(), twoDimTable.getRowHeaders(), twoDimTable.getColHeaders(), twoDimTable.getColTypes(), twoDimTable.getColFormats(), twoDimTable.getColHeaderForRowHeaders());
                for (int i4 = 0; i4 < length; i4++) {
                    for (int i5 = 0; i5 < length2; i5++) {
                        twoDimTable2.set(i4, i5, twoDimTable.get(i4, i5));
                    }
                }
                m._output._cv_scoring_history[i3] = twoDimTable2;
            }
        }
        if (!this._parms._keep_cross_validation_models) {
            Log.info(Model.deleteAll(keyArr) + " CV models were removed");
        }
        m._output._total_run_time = this._build_model_countdown.elapsedTime();
        DKV.put(m);
    }

    /**
     * Merges all per-fold metric builders into the first one, in place, via {@code reduceForCV}.
     *
     * @param metricBuilderArr per-fold metric builders; element 0 becomes the aggregate
     */
    public void cv_makeAggregateModelMetrics(ModelMetrics.MetricBuilder[] metricBuilderArr) {
        if (metricBuilderArr.length < 2) {
            return; // nothing to merge
        }
        final ModelMetrics.MetricBuilder aggregate = metricBuilderArr[0];
        for (int fold = 1; fold < metricBuilderArr.length; fold++) {
            aggregate.reduceForCV(metricBuilderArr[fold]);
        }
    }

    /** Returns the DKV key name under which this builder's holdout predictions are stored. */
    private String getPredictionKey() {
        final String resultKey = this._result.toString();
        return "prediction_".concat(resultKey);
    }

    /**
     * Recomputes {@code _max_runtime_secs} for the main model after cross-validation.
     * <p>
     * Nothing is done when the budget is 0. A negative {@code _main_model_time_budget_factor} scales
     * the actually remaining time, with a floor of 1 second. A non-negative factor scales an estimate
     * built from the per-model CV budget; the remaining time is used when it is larger.
     */
    protected void setMaxRuntimeSecsForMainModel() {
        if (this._parms._max_runtime_secs == 0.0d) {
            return;
        }
        final double factor = this._parms._main_model_time_budget_factor;
        if (factor < 0.0d) {
            this._parms._max_runtime_secs = Math.max(1.0d, (-factor) * remainingTimeSecs());
            return;
        }
        final int folds = nFoldWork();
        final double remaining = remainingTimeSecs();
        final double perModelSecs = maxRuntimeSecsPerModel(folds, nModelsInParallel(folds));
        final double estimate = ((factor * perModelSecs) * folds) / (folds - 1.0d);
        this._parms._max_runtime_secs = Math.max(remaining, estimate);
    }

    /**
     * Hook called after all CV models are trained. Subclasses can derive parameters for the main
     * model from the CV sub-builders here. The base implementation does nothing.
     */
    public void cv_computeAndSetOptimalParameters(ModelBuilder<M, P, O>[] modelBuilderArr) {
        // intentionally empty
    }

    /** Whether cross-validation is requested, either through a fold column or a non-zero nfolds. */
    public boolean nFoldCV() {
        return this._parms._fold_column != null || this._parms._nfolds != 0;
    }

    /** Returns the model categories this builder can produce. */
    public abstract ModelCategory[] can_build();

    /**
     * Visibility level advertised for this builder.
     *
     * @return {@code BuilderVisibility.Stable} by default
     */
    public BuilderVisibility builderVisibility() {
        final BuilderVisibility visibility = BuilderVisibility.Stable;
        return visibility;
    }

    /** Resets state accumulated by a previous {@code init}; by default only the validation messages. */
    public void clearInitState() {
        this.clearValidationErrors();
    }

    /**
     * Whether {@code init(true)} logs the builder's parameters.
     *
     * @return {@code true} by default
     */
    protected boolean logMe() {
        final boolean logParameters = true;
        return logParameters;
    }

    /**
     * Whether the algorithm is supervised. For unsupervised builders, {@code separateFeatureVecs}
     * leaves {@code _response} unset.
     */
    public abstract boolean isSupervised();

    /**
     * Whether a response column may be omitted.
     *
     * @return {@code false} by default
     */
    public boolean isResponseOptional() {
        final boolean optional = false;
        return optional;
    }

    /** Whether an offset column is configured in the parameters. */
    public boolean hasOffsetCol() {
        final String offsetColumn = this._parms._offset_column;
        return offsetColumn != null;
    }

    /** Whether a weights column is configured in the parameters. */
    public boolean hasWeightCol() {
        final String weightsColumn = this._parms._weights_column;
        return weightsColumn != null;
    }

    /** Whether a fold column is configured in the parameters. */
    public boolean hasFoldCol() {
        final String foldColumn = this._parms._fold_column;
        return foldColumn != null;
    }

    /** Whether a treatment column is configured in the parameters. */
    public boolean hasTreatmentCol() {
        final String treatmentColumn = this._parms._treatment_column;
        return treatmentColumn != null;
    }

    /** Counts the configured special columns: offset, weights, fold and treatment. */
    public int numSpecialCols() {
        int count = 0;
        for (boolean present : new boolean[]{hasOffsetCol(), hasWeightCol(), hasFoldCol(), hasTreatmentCol()}) {
            if (present) {
                count++;
            }
        }
        return count;
    }

    /**
     * Whether models from this builder can be exported as POJO.
     *
     * @return {@code false} by default
     */
    public boolean havePojo() {
        final boolean supported = false;
        return supported;
    }

    /**
     * Whether models from this builder can be exported as MOJO.
     *
     * @return {@code false} by default
     */
    public boolean haveMojo() {
        final boolean supported = false;
        return supported;
    }

    /** Returns the number of response classes recorded in {@code _nclass}. */
    public int nclasses() {
        final int classCount = _nclass;
        return classCount;
    }

    /** Whether this is a classification problem, i.e. more than one response class. */
    public final boolean isClassifier() {
        final int classCount = nclasses();
        return classCount > 1;
    }

    /**
     * Hook: whether the stopping metric should be validated.
     *
     * @return {@code true} by default
     */
    protected boolean validateStoppingMetric() {
        final boolean validate = true;
        return validate;
    }

    /**
     * Hook: whether a binary response should be validated.
     *
     * @return {@code true} by default
     */
    protected boolean validateBinaryResponse() {
        final boolean validate = true;
        return validate;
    }

    /** Hook for algorithm-specific early-stopping reproducibility checks; no-op by default. */
    protected void checkEarlyStoppingReproducibility() {
        // intentionally empty
    }

    /**
     * Validates the special (non-predictor) columns named in the parameters against the training
     * frame. These are weights, offset, fold, treatment and, for supervised builders, response.
     * <p>
     * For each special column that is found, the method:
     * <ul>
     *   <li>removes it from {@code _train} and adds it back under the same name (presumably
     *       appending it after the predictors - confirm {@code Frame.add} semantics);</li>
     *   <li>caches its Vec in {@code _weights}, {@code _offset}, {@code _fold}, {@code _treatment}
     *       or {@code _response};</li>
     *   <li>records any problems with {@link #error}.</li>
     * </ul>
     * Fields of columns that are not configured are reset to {@code null}.
     * <p>
     * Fixes relative to the decompiled version:
     * <ul>
     *   <li>The treatment Vec no longer overwrites {@code _weights}.</li>
     *   <li>All weights errors are reported under the real parameter name {@code _weights_column};
     *       three were previously filed under {@code _weights_columns}.</li>
     * </ul>
     *
     * @return the number of special columns found in the training frame
     */
    public int separateFeatureVecs() {
        int specialCols = 0;

        if (this._parms._weights_column != null) {
            Vec weights = this._train.remove(this._parms._weights_column);
            if (weights == null) {
                error("_weights_column", "Weights column '" + this._parms._weights_column + "' not found in the training frame");
            } else {
                if (!weights.isNumeric()) {
                    error("_weights_column", "Invalid weights column '" + this._parms._weights_column + "', weights must be numeric");
                }
                this._weights = weights;
                if (weights.naCnt() > 0) {
                    error("_weights_column", "Weights cannot have missing values.");
                }
                if (weights.min() < 0.0d) {
                    error("_weights_column", "Weights must be >= 0");
                }
                if (weights.max() == 0.0d) {
                    error("_weights_column", "Max. weight must be > 0");
                }
                this._train.add(this._parms._weights_column, weights);
                specialCols++;
            }
        } else {
            this._weights = null;
            if (!$assertionsDisabled && hasWeightCol()) {
                throw new AssertionError();
            }
        }

        if (this._parms._offset_column != null) {
            Vec offset = this._train.remove(this._parms._offset_column);
            if (offset == null) {
                error("_offset_column", "Offset column '" + this._parms._offset_column + "' not found in the training frame");
            } else {
                if (!offset.isNumeric()) {
                    error("_offset_column", "Invalid offset column '" + this._parms._offset_column + "', offset must be numeric");
                }
                this._offset = offset;
                if (offset.naCnt() > 0) {
                    error("_offset_column", "Offset cannot have missing values.");
                }
                if (this._weights == this._offset) {
                    error("_offset_column", "Offset must be different from weights");
                }
                this._train.add(this._parms._offset_column, offset);
                specialCols++;
            }
        } else {
            this._offset = null;
            if (!$assertionsDisabled && hasOffsetCol()) {
                throw new AssertionError();
            }
        }

        if (this._parms._fold_column != null) {
            Vec fold = this._train.remove(this._parms._fold_column);
            if (fold == null) {
                error("_fold_column", "Fold column '" + this._parms._fold_column + "' not found in the training frame");
            } else {
                if (!fold.isInt() && !fold.isCategorical()) {
                    error("_fold_column", "Invalid fold column '" + this._parms._fold_column + "', fold must be integer or categorical");
                }
                if (fold.min() < 0.0d) {
                    error("_fold_column", "Invalid fold column '" + this._parms._fold_column + "', fold must be non-negative");
                }
                if (fold.isConst()) {
                    error("_fold_column", "Invalid fold column '" + this._parms._fold_column + "', fold cannot be constant");
                }
                this._fold = fold;
                if (fold.naCnt() > 0) {
                    error("_fold_column", "Fold cannot have missing values.");
                }
                if (this._fold == this._weights) {
                    error("_fold_column", "Fold must be different from weights");
                }
                if (this._fold == this._offset) {
                    error("_fold_column", "Fold must be different from offset");
                }
                this._train.add(this._parms._fold_column, fold);
                specialCols++;
            }
        } else {
            this._fold = null;
            if (!$assertionsDisabled && hasFoldCol()) {
                throw new AssertionError();
            }
        }

        if (this._parms._treatment_column != null) {
            Vec treatment = this._train.remove(this._parms._treatment_column);
            if (treatment == null) {
                error("_treatment_column", "Treatment column '" + this._parms._treatment_column + "' not found in the training frame");
            } else {
                this._treatment = treatment;
                if (!treatment.isCategorical()) {
                    error("_treatment_column", "Invalid treatment column '" + this._parms._treatment_column + "', treatment column must be categorical");
                }
                if (treatment.naCnt() > 0) {
                    error("_treatment_column", "Treatment column cannot have missing values.");
                }
                if (treatment.isCategorical() && treatment.domain().length != 2) {
                    error("_treatment_column", "Treatment column must contains only 0 or 1");
                }
                if (treatment.min() != 0.0d) {
                    error("_treatment_column", "Min. treatment column value must be 0");
                }
                if (treatment.max() != 1.0d) {
                    error("_treatment_column", "Max. treatment column value must be 1");
                }
                this._train.add(this._parms._treatment_column, treatment);
                specialCols++;
            }
        } else {
            this._treatment = null;
            if (!$assertionsDisabled && hasTreatmentCol()) {
                throw new AssertionError();
            }
        }

        if (!isSupervised() || this._parms._response_column == null) {
            this._response = null;
        } else {
            this._response = this._train.remove(this._parms._response_column);
            if (this._response == null) {
                error("_response_column", "Response column '" + this._parms._response_column + "' not found in the training frame");
            } else {
                if (this._response == this._offset) {
                    error("_response_column", "Response column must be different from offset_column");
                }
                if (this._response == this._weights) {
                    error("_response_column", "Response column must be different from weights_column");
                }
                if (this._response == this._fold) {
                    error("_response_column", "Response column must be different from fold_column");
                }
                if (this._response == this._treatment) {
                    error("_response_column", "Response column must be different from treatment_column");
                }
                this._train.add(this._parms._response_column, this._response);
                specialCols++;
            }
        }
        return specialCols;
    }

    /**
     * Hook: whether string columns are dropped by {@code ignoreBadColumns}.
     *
     * @return {@code true} by default
     */
    protected boolean ignoreStringColumns() {
        final boolean ignore = true;
        return ignore;
    }

    /** Hook: whether constant columns are dropped; defaults to the {@code ignore_const_cols} parameter. */
    protected boolean ignoreConstColumns() {
        final boolean ignore = this._parms._ignore_const_cols;
        return ignore;
    }

    /**
     * Hook: whether UUID columns are dropped by {@code ignoreBadColumns}.
     *
     * @return {@code true} by default
     */
    protected boolean ignoreUuidColumns() {
        final boolean ignore = true;
        return ignore;
    }

    /**
     * Drops unusable columns from the training frame.
     * <p>
     * Dropped columns are those that are bad, constant (when {@code ignoreConstColumns()}), string
     * (when {@code ignoreStringColumns()}) or UUID (when {@code ignoreUuidColumns()}).
     * <p>
     * NOTE(review): the whole filter, including bad/string/UUID columns, only runs when
     * {@code ignore_const_cols} is set - confirm this is intended.
     *
     * @param i number of predictors, passed to {@code FilterCols}
     * @param z passed through to {@code FilterCols.doIt}
     */
    protected void ignoreBadColumns(int i, boolean z) {
        if (!this._parms._ignore_const_cols) {
            return;
        }
        new ModelBuilder<M, P, O>.FilterCols(i) {
            @Override
            protected boolean filter(Vec vec, String str) {
                if (vec.isBad()) {
                    return true;
                }
                if (ModelBuilder.this.ignoreConstColumns() && vec.isConst(ModelBuilder.this.canLearnFromNAs())) {
                    return true;
                }
                if (ModelBuilder.this.ignoreStringColumns() && vec.isString()) {
                    return true;
                }
                return ModelBuilder.this.ignoreUuidColumns() && vec.isUUID();
            }
        }.doIt(this._train, "Dropping bad and constant columns: ", z);
    }

    /**
     * Hook: passed to {@code Vec.isConst} when detecting constant columns in {@code ignoreBadColumns}.
     *
     * @return {@code false} by default
     */
    protected boolean canLearnFromNAs() {
        final boolean learnsFromNAs = false;
        return learnsFromNAs;
    }

    /** Records an error when the response column is present but not numeric, categorical or time. */
    protected void checkResponseVariable() {
        final Vec resp = this._response;
        final boolean acceptable = resp == null || resp.isNumeric() || resp.isCategorical() || resp.isTime();
        if (!acceptable) {
            error("_response_column", "Use numerical, categorical or time variable. Currently used " + resp.get_type_str());
        }
    }

    /** Hook for algorithm-specific column filtering; no-op by default. */
    protected void ignoreInvalidColumns(int i, boolean z) {
        // intentionally empty
    }

    /**
     * Runs {@link #checkMemoryFootPrint_impl} unless the system property
     * {@code sys.ai.h2o.debug.noMemoryCheck} is set to true.
     */
    protected void checkMemoryFootPrint() {
        final boolean skipCheck = Boolean.getBoolean("sys.ai.h2o.debug.noMemoryCheck");
        if (!skipCheck) {
            checkMemoryFootPrint_impl();
        }
    }

    /** Algorithm-specific memory check; no-op by default. */
    protected void checkMemoryFootPrint_impl() {
        // intentionally empty
    }

    /** Hook: whether the prior class distribution is computed; by default only for classifiers. */
    protected boolean computePriorClassDistribution() {
        final boolean classifier = isClassifier();
        return classifier;
    }

    /**
     * Returns the number of validation errors recorded so far.
     *
     * @throws AssertionError (with assertions enabled) if the counter is negative, i.e. init() has not run
     */
    public int error_count() {
        if (!$assertionsDisabled && this._error_count < 0) {
            throw new AssertionError("init() not run yet");
        }
        return this._error_count;
    }

    /** Records a message at level 5 (presumably Log.TRACE) marking the field as hidden/ignored. */
    public void hide(String str, String str2) {
        final byte hiddenLevel = 5;
        message(hiddenLevel, str, str2);
    }

    /** Records an informational validation message (level 3) for the given field. */
    public void info(String str, String str2) {
        final byte infoLevel = 3;
        message(infoLevel, str, str2);
    }

    /** Records a warning validation message (level 2) for the given field. */
    public void warn(String str, String str2) {
        final byte warnLevel = 2;
        message(warnLevel, str, str2);
    }

    /**
     * Records a validation error (level 1) for the given field.
     * <p>
     * The error counter is incremented exactly once, by {@link #message}, which counts every
     * level-1 message. That is the same path used when errors are copied from CV sub-builders.
     * Previously this method incremented the counter a second time, so each error was counted twice.
     */
    public void error(String str, String str2) {
        message((byte) 1, str, str2);
    }

    /** Discards all validation messages and resets the error counter to zero. */
    public void clearValidationErrors() {
        this._error_count = 0;
        this._messages = new ValidationMessage[0];
    }

    /**
     * Appends a validation message. Messages at level 1 (error) also increment the error counter.
     *
     * @param b    log level of the message
     * @param str  parameter/field name the message refers to
     * @param str2 message text
     */
    public void message(byte b, String str, String str2) {
        final int oldLength = this._messages.length;
        final ValidationMessage[] grown = Arrays.copyOf(this._messages, oldLength + 1);
        grown[oldLength] = new ValidationMessage(b, str, str2);
        this._messages = grown;
        if (b == 1) {
            this._error_count++;
        }
    }

    /**
     * Returns the validation messages for the given field and log level, in recording order.
     *
     * @param str field name to match (compared with {@code equals} against each message's field)
     * @param b   log level to match
     */
    public ValidationMessage[] getMessagesByFieldAndSeverity(String str, byte b) {
        final ArrayList<ValidationMessage> matches = new ArrayList<>();
        for (ValidationMessage candidate : this._messages) {
            if (candidate._field_name.equals(str) && candidate._log_level == b) {
                matches.add(candidate);
            }
        }
        return matches.toArray(new ValidationMessage[0]);
    }

    /** Returns all error-level (1) validation messages, one per line. */
    public String validationErrors() {
        final int errorLevel = 1;
        return validationMessage(errorLevel);
    }

    /** Returns all warning-level (2) validation messages, one per line. */
    public String validationWarnings() {
        final int warningLevel = 2;
        return validationMessage(warningLevel);
    }

    /**
     * Joins the messages recorded at log level {@code i}, each followed by a newline.
     *
     * @param i log level to select
     * @return the concatenated messages, or an empty string if there are none
     */
    private String validationMessage(int i) {
        final StringBuilder out = new StringBuilder();
        for (ValidationMessage vm : this._messages) {
            if (vm._log_level != i) {
                continue;
            }
            out.append(vm.toString()).append('\n');
        }
        return out.toString();
    }

    /**
     * Validates the builder's parameters and prepares the working training/validation frames.
     *
     * <p>Problems are reported through {@code error(...)}, {@code warn(...)} and {@code hide(...)}
     * rather than thrown. Callers presumably check {@code error_count()} afterwards.
     *
     * @param z presumably the "expensive" flag. When true, this method also:
     *          <ul>
     *            <li>logs the parameters and protects the user frames in the current {@code Scope};</li>
     *            <li>fixes the random seed;</li>
     *            <li>auto-rebalances the frames, when enabled and no errors occurred so far;</li>
     *            <li>applies preprocessors and categorical encoding;</li>
     *            <li>reorders categorical levels for which {@code shouldReorder} returns true.</li>
     *          </ul>
     *          When false, only the cheap checks run.
     */
    public void init(boolean z) {
        if (z && logMe()) {
            Log.info("Building H2O " + getClass().getSimpleName() + " model with these parameters:");
            Log.info(new String(this._parms.writeJSON(new AutoBuffer()).buf()));
        }
        // Presumably resets messages/error counters left over from a previous init() call.
        clearInitState();
        initWorkspace(z);
        // Decompiled form of `assert _parms != null`.
        if (!$assertionsDisabled && this._parms == null) {
            throw new AssertionError();
        }
        if (this._parms._train == null) {
            // A missing training frame is only an error in the expensive pass.
            if (z) {
                error("_train", "Missing training frame");
                return;
            }
            return;
        }
        // NOTE(review): presumably checks that the training-frame key is consistent across nodes — confirm.
        new ObjectConsistencyChecker(this._parms._train).doAllNodes();
        Frame train = this._train != null ? this._train : this._parms.train();
        if (train == null) {
            error("_train", "Missing training frame: " + this._parms._train);
            return;
        }
        if (z) {
            Scope.protect(this._parms.train(), this._parms.valid());
        }
        // Work on a new Frame with cloned name/vec arrays, so the column removals below
        // operate on this copy's arrays.
        setTrain(new Frame(null, (String[]) train._names.clone(), (Vec[]) train.vecs().clone()));
        if (z) {
            this._parms.getOrMakeRealSeed();
        }
        if (this._parms._categorical_encoding.needsResponse() && !isSupervised()) {
            error("_categorical_encoding", "Categorical encoding scheme cannot be " + this._parms._categorical_encoding.toString() + " - no response column available.");
        }
        // ---- Cross-validation settings ----
        if (this._parms._nfolds < 0 || this._parms._nfolds == 1) {
            error("_nfolds", "nfolds must be either 0 or >1.");
        }
        if (this._parms._nfolds > 1 && this._parms._nfolds > train().numRows()) {
            error("_nfolds", "nfolds cannot be larger than the number of rows (" + train().numRows() + ").");
        }
        if (this._parms._fold_column != null) {
            hide("_fold_assignment", "Fold assignment is ignored when a fold column is specified.");
            if (this._parms._nfolds > 1) {
                error("_nfolds", "nfolds cannot be specified at the same time as a fold column.");
            } else {
                hide("_nfolds", "nfolds is ignored when a fold column is specified.");
            }
            // NOTE(review): the trailing `this._parms != null` test is redundant; _parms was already dereferenced.
            if (this._parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO && this._parms._fold_assignment != null && this._parms != null) {
                error("_fold_assignment", "Fold assignment is not allowed in conjunction with a fold column.");
            }
        }
        if (this._parms._nfolds > 1) {
            hide("_fold_column", "Fold column is ignored when nfolds > 1.");
        }
        if (!nFoldCV()) {
            hide("_keep_cross_validation_models", "Only for cross-validation.");
            hide("_keep_cross_validation_predictions", "Only for cross-validation.");
            hide("_keep_cross_validation_fold_assignment", "Only for cross-validation.");
            hide("_fold_assignment", "Only for cross-validation.");
            if (this._parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO && this._parms._fold_assignment != null) {
                error("_fold_assignment", "Fold assignment is only allowed for cross-validation.");
            }
        }
        // ---- Distribution settings ----
        if (this._parms._distribution == DistributionFamily.modified_huber) {
            error("_distribution", "Modified Huber distribution is not supported yet.");
        }
        if (this._parms._distribution != DistributionFamily.tweedie) {
            hide("_tweedie_power", "Only for Tweedie Distribution.");
        }
        // Note: the Tweedie power range is validated for every distribution, not just tweedie.
        if (this._parms._tweedie_power <= 1.0d || this._parms._tweedie_power >= 2.0d) {
            error("_tweedie_power", "Tweedie power must be between 1 and 2 (exclusive). For tweedie power = 1, use Poisson distribution. For tweedie power = 2, use Gamma distribution.");
        }
        // Drop ignored columns, except those getUsedColumns() reports as still in use.
        if (this._parms._ignored_columns != null) {
            HashSet hashSet = new HashSet(Arrays.asList(this._parms._ignored_columns));
            hashSet.removeAll(this._parms.getUsedColumns(train._names));
            String[] strArr = (String[]) hashSet.toArray(new String[0]);
            this._train.remove(strArr);
            if (z) {
                Log.info("Dropping ignored columns: " + Arrays.toString(strArr));
            }
        }
        if (this._parms._checkpoint != null) {
            if (DKV.get(this._parms._checkpoint) == null) {
                error("_checkpoint", "Checkpoint has to point to existing model!");
            }
            // NOTE(review): execution continues even when the checkpoint key is missing; `.get()` below
            // then presumably returns null and this line throws an NPE — confirm.
            // Adapt the training frame to the checkpointed model; each returned string is surfaced as a warning.
            for (String str : this._parms._checkpoint.get().adaptTestForTrain(this._train, z, false)) {
                warn("_checkpoint", str);
            }
            separateFeatureVecs();
        } else {
            // Bad/invalid columns are only pruned when not continuing from a checkpoint.
            ignoreBadColumns(separateFeatureVecs(), z);
            ignoreInvalidColumns(separateFeatureVecs(), z);
            checkResponseVariable();
        }
        if (z && error_count() == 0 && this._parms._auto_rebalance) {
            setTrain(rebalance(this._train, false, this._result + ".temporary.train"));
            separateFeatureVecs();
            this._valid = rebalance(this._valid, false, this._result + ".temporary.valid");
        }
        if (this._train.numCols() == 0) {
            error("_train", "There are no usable columns to generate model");
        }
        if (isSupervised()) {
            if (this._response != null) {
                if (this._parms._distribution != DistributionFamily.tweedie) {
                    hide("_tweedie_power", "Tweedie power is only used for Tweedie distribution.");
                }
                if (this._parms._distribution != DistributionFamily.quantile) {
                    hide("_quantile_alpha", "Quantile (alpha) is only used for Quantile regression.");
                }
                if (z) {
                    checkDistributions();
                }
                this._nclass = init_getNClass();
                if (this._parms._check_constant_response && this._response.isConst()) {
                    error("_response", "Response cannot be constant.");
                }
                if (validateBinaryResponse() && this._nclass == 1 && this._response.isBinary(true)) {
                    warn("_response", "We have detected that your response column has only 2 unique values (0/1). If you wish to train a binary model instead of a regression model, convert your target column to categorical before training.");
                }
            }
            if (!this._parms._balance_classes) {
                hide("_max_after_balance_size", "Balance classes is false, hide max_after_balance_size");
            } else if (this._parms._weights_column != null && this._weights != null && !this._weights.isBinary()) {
                error("_balance_classes", "Balance classes and observation weights are not currently supported together.");
            }
            // NOTE(review): CMAESOptimizer.DEFAULT_STOPFITNESS appears to be a decompiler substitution for the
            // literal 0.0 (this check and the messages compare against zero) — confirm.
            if (this._parms._max_after_balance_size <= CMAESOptimizer.DEFAULT_STOPFITNESS) {
                error("_max_after_balance_size", "Max size after balancing needs to be positive, suggest 1.0f");
            }
            if (this._train != null) {
                // GAM is exempt from the two-column minimum (compared by class name).
                if (this._train.numCols() <= 1 && !getClass().toString().equals("class hex.gam.GAM")) {
                    error("_train", "Training data must have at least 2 features (incl. response).");
                }
                if (null == this._parms._response_column) {
                    error("_response_column", "Response column parameter not set.");
                    return;
                }
                // Prior class distribution.
                if (this._response != null && computePriorClassDistribution()) {
                    if (!isClassifier() || !isSupervised()) {
                        // Non-classification: a single "class" whose count is the total (mean weight * rows).
                        double[] dArr = new double[1];
                        dArr[0] = (this._weights != null ? this._weights.mean() : 1.0d) * train().numRows();
                        this._distribution = dArr;
                        this._priorClassDist = new double[]{1.0d};
                    } else if (this._parms.getDistributionFamily() == DistributionFamily.quasibinomial) {
                        // Quasibinomial: collect the (at most 2) distinct response values as the class domain.
                        String[] stringDomain = new VecUtils.CollectDoubleDomain(null, 2).doAll(this._response).stringDomain(this._response.isInt());
                        MRUtils.ClassDistQuasibinomial doAll = this._weights != null ? new MRUtils.ClassDistQuasibinomial(stringDomain).doAll(this._response, this._weights) : new MRUtils.ClassDistQuasibinomial(stringDomain).doAll(this._response);
                        this._distribution = doAll.dist();
                        this._priorClassDist = doAll.relDist();
                    } else {
                        MRUtils.ClassDist doAll2 = this._weights != null ? new MRUtils.ClassDist(nclasses()).doAll(this._response, this._weights) : new MRUtils.ClassDist(nclasses()).doAll(this._response);
                        this._distribution = doAll2.dist();
                        this._priorClassDist = doAll2.relDist();
                    }
                }
            }
            if (!isClassifier()) {
                hide("_balance_classes", "Balance classes is only applicable to classification problems.");
                hide("_class_sampling_factors", "Class sampling factors is only applicable to classification problems.");
                hide("_max_after_balance_size", "Max after balance size is only applicable to classification problems.");
                hide("_max_confusion_matrix_size", "Max confusion matrix size is only applicable to classification problems.");
            }
            if (this._nclass <= 2) {
                hide("_max_confusion_matrix_size", "Only for multi-class classification problems.");
            }
            if (!this._parms._balance_classes) {
                hide("_max_after_balance_size", "Only used with balanced classes");
                hide("_class_sampling_factors", "Class sampling factors is only applicable if balancing classes.");
            }
        } else {
            // Unsupervised: drop the response (unless it is optional) and class-balancing settings.
            if (!isResponseOptional()) {
                hide("_response_column", "Ignored for unsupervised methods.");
                this._vresponse = null;
            }
            hide("_balance_classes", "Ignored for unsupervised methods.");
            hide("_class_sampling_factors", "Ignored for unsupervised methods.");
            hide("_max_after_balance_size", "Ignored for unsupervised methods.");
            hide("_max_confusion_matrix_size", "Ignored for unsupervised methods.");
            this._response = null;
            this._nclass = 1;
        }
        // Hard cap on response cardinality: 1048576 = 2^20 classes.
        if (this._nclass > 1048576) {
            error("_nclass", "Too many levels in response column: " + this._nclass + ", maximum supported number of classes is 1048576.");
        }
        // ---- Validation frame ----
        Frame valid = this._parms.valid();
        if (valid != null) {
            if (isResponseOptional() && this._parms._response_column != null && this._response == null) {
                this._vresponse = valid.vec(this._parms._response_column);
            }
            this._valid = adaptFrameToTrain(valid, "Validation Frame", "_validation_frame", z, false);
            if (!isResponseOptional() || (this._parms._response_column != null && this._valid.find(this._parms._response_column) >= 0)) {
                this._vresponse = this._valid.vec(this._parms._response_column);
            }
        } else {
            this._valid = null;
            this._vresponse = null;
        }
        if (z) {
            // Frames produced for a CV sub-model are not Scope-tracked (see trackEncoded).
            boolean z2 = !this._parms._is_cv_model;
            Frame encodeFrameCategoricals = encodeFrameCategoricals(applyPreprocessors(this._train, true, z2), z2);
            if (encodeFrameCategoricals != this._train) {
                // Remember the pre-encoding frame, names and domains.
                this._origTrain = this._train;
                this._origNames = this._train.names();
                this._origDomains = this._train.domains();
                setTrain(encodeFrameCategoricals);
                separateFeatureVecs();
            } else {
                this._origTrain = null;
            }
            if (this._valid != null) {
                setValid(encodeFrameCategoricals(applyPreprocessors(this._valid, false, z2), z2));
            }
            // Reorder categorical predictor levels by mean (weighted) response, where shouldReorder() asks for it.
            boolean z3 = false;
            Vec[] vecs = this._train.vecs();
            for (int i = 0; i < vecs.length; i++) {
                Vec vec = vecs[i];
                if (vec != this._response && vec != this._fold && vec.isCategorical() && shouldReorder(vec)) {
                    int length = vec.domain().length;
                    Log.info("Reordering categorical column " + this._train.name(i) + " (" + length + " levels) based on the mean (weighted) response per level.");
                    VecUtils.MeanResponsePerLevelTask meanResponsePerLevelTask = new VecUtils.MeanResponsePerLevelTask(length);
                    // Inputs: the column, its weights (constant 1 when no weights column), and the response.
                    Vec[] vecArr = new Vec[3];
                    vecArr[0] = vec;
                    vecArr[1] = this._parms._weights_column != null ? this._train.vec(this._parms._weights_column) : vec.makeCon(1.0d);
                    vecArr[2] = this._train.vec(this._parms._response_column);
                    double[] dArr2 = meanResponsePerLevelTask.doAll(vecArr).meanWeightedResponse;
                    // iArr: level indices sorted by mean response (ArrayUtils.sort keyed by dArr2).
                    int[] iArr = new int[length];
                    for (int i2 = 0; i2 < length; i2++) {
                        iArr[i2] = i2;
                    }
                    ArrayUtils.sort(iArr, dArr2);
                    // iArr2: inverse permutation, mapping old level -> new level.
                    int[] iArr2 = new int[length];
                    for (int i3 = 0; i3 < length; i3++) {
                        iArr2[iArr[i3]] = i3;
                    }
                    Vec anyVec = new VecUtils.ReorderTask(iArr2).doAll(1, (byte) 3, new Frame(vec)).outputFrame().anyVec();
                    String[] strArr2 = new String[length];
                    for (int i4 = 0; i4 < length; i4++) {
                        strArr2[i4] = vec.domain()[iArr[i4]];
                    }
                    anyVec.setDomain(strArr2);
                    vecs[i] = anyVec;
                    z3 = true;
                }
            }
            if (z3) {
                this._train.restructure(this._train.names(), vecs);
            }
        }
        // Sanity checks (decompiled asserts) on train/valid column agreement after categorical encoding.
        boolean z4 = this._parms._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.Binary;
        boolean z5 = (this._valid == null || ArrayUtils.difference(this._train._names, this._valid._names).length == 0) ? false : true;
        if (!$assertionsDisabled && z && !z4 && z5) {
            throw new AssertionError();
        }
        if (z5 && z4) {
            for (String str2 : this._train._names) {
                if (!$assertionsDisabled && !ArrayUtils.contains(this._valid._names, str2)) {
                    throw new AssertionError("Internal error during categorical encoding: training column " + str2 + " not in validation frame with columns " + Arrays.toString(this._valid._names));
                }
            }
        }
        // ---- Early stopping ----
        // The DEFAULT_STOPFITNESS comparisons below appear to stand for 0.0; see the note above.
        if (this._parms._stopping_tolerance < CMAESOptimizer.DEFAULT_STOPFITNESS) {
            error("_stopping_tolerance", "Stopping tolerance must be >= 0.");
        }
        if (this._parms._stopping_tolerance >= 1.0d) {
            error("_stopping_tolerance", "Stopping tolerance must be < 1.");
        }
        if (this._parms._stopping_rounds == 0) {
            if (this._parms._stopping_metric != ScoreKeeper.StoppingMetric.AUTO) {
                warn("_stopping_metric", "Stopping metric is ignored for _stopping_rounds=0.");
            }
            if (this._parms._stopping_tolerance != this._parms.defaultStoppingTolerance()) {
                warn("_stopping_tolerance", "Stopping tolerance is ignored for _stopping_rounds=0.");
            }
        } else if (this._parms._stopping_rounds < 0) {
            error("_stopping_rounds", "Stopping rounds must be >= 0.");
        } else {
            checkEarlyStoppingReproducibility();
            if (validateStoppingMetric()) {
                if (isClassifier()) {
                    // GLM classifiers (matched by class name) may still stop on deviance.
                    if (this._parms._stopping_metric == ScoreKeeper.StoppingMetric.deviance && !getClass().getSimpleName().contains("GLM")) {
                        error("_stopping_metric", "Stopping metric cannot be deviance for classification.");
                    }
                } else if (this._parms._stopping_metric.isClassificationOnly()) {
                    error("_stopping_metric", "Stopping metric cannot be " + this._parms._stopping_metric + " for regression.");
                }
            }
        }
        if (this._parms._stopping_metric == ScoreKeeper.StoppingMetric.custom || this._parms._stopping_metric == ScoreKeeper.StoppingMetric.custom_increasing) {
            checkCustomMetricForEarlyStopping();
        }
        if (this._parms._max_runtime_secs < CMAESOptimizer.DEFAULT_STOPFITNESS) {
            error("_max_runtime_secs", "Max runtime (in seconds) must be greater than 0 (or 0 for unlimited).");
        }
        // The checkpoint export directory must be writable, unless it is unset or this is a CV sub-model.
        if (StringUtils.isNullOrEmpty(this._parms._export_checkpoints_dir) || this._parms._is_cv_model || H2O.getPM().isWritableDirectory(this._parms._export_checkpoints_dir)) {
            return;
        }
        error("_export_checkpoints_dir", "Checkpoints directory path must point to a writable path.");
    }

    /**
     * Records a validation error when early stopping on a custom metric is requested but no custom
     * metric function has been configured.
     */
    protected void checkCustomMetricForEarlyStopping() {
        final boolean haveCustomMetric = this._parms._custom_metric_func != null;
        if (haveCustomMetric) {
            return;
        }
        error("_custom_metric_func", "Custom metric function needs to be defined in order to use it for early stopping.");
    }

    /**
     * Adapts {@code frame} to the training frame's layout. When {@code z} is true, also applies
     * categorical encoding, with any newly encoded frame tracked in the current Scope.
     *
     * @param frame frame to adapt
     * @param str   human-readable frame description used in messages
     * @param str2  parameter name that messages are reported against
     * @param z     whether to run the expensive adaptation (and categorical encoding)
     * @return the adapted (and possibly encoded) frame
     */
    public Frame init_adaptFrameToTrain(Frame frame, String str, String str2, boolean z) {
        final Frame adapted = adaptFrameToTrain(frame, str, str2, z, false);
        return z ? encodeFrameCategoricals(adapted, true) : adapted;
    }

    /**
     * Builds a copy of {@code frame} (fresh name/vec arrays) and adapts it in place to the training
     * frame's columns and domains via {@code Model.adaptTestForTrain}.
     *
     * <p>Problems are recorded as validation errors/warnings against {@code str2}. An
     * {@link IllegalArgumentException} raised during adaptation is turned into an error, not rethrown.
     *
     * @param frame source frame (not modified structurally; a new Frame is returned)
     * @param str   human-readable frame description used in messages
     * @param str2  parameter name that messages are reported against
     * @param z     expensive flag forwarded to the adaptation; also enables logging of adaptation notes
     * @param z2    forwarded as the final argument of {@code Model.adaptTestForTrain}
     * @return the adapted copy
     */
    private Frame adaptFrameToTrain(Frame frame, String str, String str2, boolean z, boolean z2) {
        if (frame.numRows() == 0) {
            error(str2, str + " must have > 0 rows.");
        }
        final Frame adapted = new Frame(null, frame._names.clone(), frame.vecs().clone());
        try {
            final String[] notes = Model.adaptTestForTrain(adapted, null, (String[][]) null, this._train._names, this._train.domains(), this._parms, z, true, null, getToEigenVec(), this._workspace.getToDelete(z), z2);
            final String responseName = this._parms._response_column;
            final boolean responseMissing = adapted.vec(responseName) == null && responseName != null && !isResponseOptional();
            if (responseMissing) {
                error(str2, str + " must have a response column '" + responseName + "'.");
            }
            if (!z) {
                return adapted;
            }
            for (String note : notes) {
                Log.info(note);
                warn(str2, note);
            }
        } catch (IllegalArgumentException iae) {
            error(str2, iae.getMessage());
        }
        return adapted;
    }

    /**
     * Runs the configured model preprocessors, in order, over {@code frame}.
     *
     * @param frame input frame
     * @param z     true to use {@code processTrain}, false to use {@code processValid}
     * @param z2    whether newly produced frames are Scope-tracked; when false the final frame is
     *              explicitly untracked
     * @return the frame produced by the last preprocessor, or {@code frame} if none are configured
     */
    private Frame applyPreprocessors(Frame frame, boolean z, boolean z2) {
        if (this._parms._preprocessors == null) {
            return frame;
        }
        // Warm up all preprocessor keys before fetching them one by one.
        for (Key<ModelPreprocessor> ppKey : this._parms._preprocessors) {
            DKV.prefetch(ppKey);
        }
        Frame current = frame;
        for (Key<ModelPreprocessor> ppKey : this._parms._preprocessors) {
            final ModelPreprocessor preprocessor = ppKey.get();
            final Frame produced;
            if (z) {
                produced = preprocessor.processTrain(current, this._parms);
            } else {
                produced = preprocessor.processValid(current, this._parms);
            }
            if (produced != current) {
                trackEncoded(produced, z2);
            }
            current = produced;
        }
        if (!z2) {
            Scope.untrack(current);
        }
        return current;
    }

    /**
     * Applies the configured categorical encoding scheme to {@code frame}, skipping non-predictor
     * columns. A newly created frame is registered via {@code trackEncoded}.
     *
     * @param frame frame to encode
     * @param z     forwarded to {@code trackEncoded} (Scope-track vs. workspace delete list)
     * @return the encoded frame, or {@code frame} itself when encoding produced no new frame
     */
    private Frame encodeFrameCategoricals(Frame frame, boolean z) {
        final Frame encoded = FrameUtils.categoricalEncoder(frame, this._parms.getNonPredictors(), this._parms._categorical_encoding, getToEigenVec(), this._parms._max_categorical_levels);
        if (encoded == frame) {
            return frame;
        }
        trackEncoded(encoded, z);
        return encoded;
    }

    /**
     * Registers a derived frame for cleanup.
     *
     * <p>With {@code z} true, the frame is tracked in the current {@code Scope}. Otherwise its key is
     * put into the workspace's to-delete map, with the current stack trace as the value.
     *
     * @param frame derived frame; must have a non-null key (checked when assertions are enabled)
     * @param z     whether to use Scope tracking
     */
    private void trackEncoded(Frame frame, boolean z) {
        // Decompiled form of `assert frame._key != null` (explicit flag, not the `assert` keyword).
        if (!$assertionsDisabled && frame._key == null) {
            throw new AssertionError();
        }
        if (!z) {
            final String origin = Arrays.toString(Thread.currentThread().getStackTrace());
            this._workspace.getToDelete(true).put(frame._key, origin);
            return;
        }
        Scope.track(frame);
    }

    /**
     * Rebalances {@code frame} into more chunks when it has too few non-empty chunks.
     *
     * <p>The frame is kept as-is when its non-empty chunk count is at least
     * {@code desiredChunks * rebalanceRatio()}. Otherwise a rebalanced copy is written under the
     * hidden key {@code str + ".chunks" + desiredChunks}, tracked in the current Scope, and returned.
     * The last five characters of {@code str} serve as a label in log messages; for example,
     * {@code init} passes names ending in ".train" / ".valid".
     *
     * @param frame frame to rebalance; null is passed through
     * @param z     forwarded to {@link #desiredChunks(Frame, boolean)}
     * @param str   base name for the rebalanced key; must be at least 5 characters long
     * @return the original or the rebalanced frame, or null for a null input
     */
    protected Frame rebalance(Frame frame, boolean z, String str) {
        if (frame == null) {
            return null;
        }
        final int chunks = desiredChunks(frame, z);
        final String label = str.substring(str.length() - 5);
        final double ratio = rebalanceRatio();
        final int existing = frame.anyVec().nonEmptyChunks();
        final boolean needsRebalance = existing < chunks * ratio;
        if (!needsRebalance) {
            if (chunks > 1) {
                Log.info(label + " dataset already contains " + existing + " (non-empty)  chunks. No need to rebalance. [desiredChunks=" + chunks, ", rebalanceRatio=" + ratio + "]");
            }
            return frame;
        }
        raiseReproducibilityWarning(label, chunks);
        Log.info("Rebalancing " + label + " dataset into " + chunks + " chunks.");
        final Key target = Key.makeUserHidden(str + ".chunks" + chunks);
        final RebalanceDataSet task = (RebalanceDataSet) H2O.submitTask(new RebalanceDataSet(frame, target, chunks));
        task.join();
        final Frame rebalanced = (Frame) DKV.get(target).get();
        Scope.track(rebalanced);
        return rebalanced;
    }

    /**
     * Hook called by {@link #rebalance(Frame, boolean, String)} just before a frame is rebalanced.
     * The base implementation does nothing; subclasses may override it to emit warnings.
     *
     * @param str dataset label as used by {@code rebalance}
     * @param i   number of chunks the dataset is about to be rebalanced into
     */
    protected void raiseReproducibilityWarning(String str, int i) {
        // intentionally empty
    }

    /**
     * Reads the rebalance ratio from the system property {@code rebalance.ratio.single} (one-node
     * cloud) or {@code rebalance.ratio.multi} (otherwise), with the {@code sys.ai.h2o.} prefix added
     * by {@link #getSysProperty}. Defaults to 1.0.
     */
    private double rebalanceRatio() {
        final String topology = H2O.getCloudSize() == 1 ? "single" : "multi";
        final String configured = getSysProperty("rebalance.ratio." + topology, "1.0");
        return Double.parseDouble(configured);
    }

    /**
     * Chooses the target chunk count for rebalancing.
     *
     * <p>The multi-node heuristic is used only on a cloud with more than one node and when
     * {@code rebalance.enableMulti} is set to true. Otherwise the single-node heuristic is used.
     *
     * @param frame frame to size
     * @param z     unused by this base implementation
     * @return desired number of chunks
     */
    protected int desiredChunks(Frame frame, boolean z) {
        if (H2O.getCloudSize() > 1 && Boolean.parseBoolean(getSysProperty("rebalance.enableMulti", "false"))) {
            return desiredChunkMulti(frame);
        }
        return desiredChunkSingle(frame);
    }

    /**
     * Single-node heuristic: one chunk per 1000 rows (rounded up), capped at {@code H2O.NUMCPUS}.
     */
    private int desiredChunkSingle(Frame frame) {
        final int byRows = (int) Math.ceil(frame.numRows() / 1000.0d);
        return Math.min(byRows, H2O.NUMCPUS);
    }

    /**
     * Multi-node heuristic that sizes chunks using {@code FileVec.calcOptimalChunkSize}.
     *
     * <p>This applies only when every column has type code 3 or 4; the warning text calls these the
     * numeric and categorical types. For any other column type it falls back to
     * {@link #desiredChunkSingle(Frame)}.
     *
     * <p>The data size is estimated as the larger of 4 bytes per non-NA value and
     * {@code frame.byteSize()}. The result is that size divided by the optimal chunk size, rounded up.
     */
    private int desiredChunkMulti(Frame frame) {
        for (byte type : frame.types()) {
            final boolean numericOrCategorical = type == 3 || type == 4;
            if (!numericOrCategorical) {
                Log.warn("Training frame contains columns non-numeric/categorical columns. Using old rebalance logic.");
                return desiredChunkSingle(frame);
            }
        }
        long nonMissing = 0;
        for (Vec column : frame.vecs()) {
            nonMissing += column.length() - column.naCnt();
        }
        final long totalBytes = Math.max(nonMissing * 4, frame.byteSize());
        final int chunkSize = FileVec.calcOptimalChunkSize(totalBytes, frame.numCols(), frame.numCols() * 4, H2O.NUMCPUS, H2O.getCloudSize(), false, true);
        final long remainder = totalBytes % chunkSize;
        final int chunkCount = (int) (totalBytes / chunkSize + (remainder > 0 ? 1 : 0));
        Log.info("Calculated optimal number of chunks = " + chunkCount);
        return chunkCount;
    }

    /**
     * Reads the JVM system property {@code "sys.ai.h2o." + str}.
     *
     * @param str  property name without the {@code sys.ai.h2o.} prefix
     * @param str2 value returned when the property is not set
     * @return the property value or {@code str2}
     */
    protected String getSysProperty(String str, String str2) {
        final String fullName = "sys.ai.h2o." + str;
        return System.getProperty(fullName, str2);
    }

    /**
     * Returns the number of response classes.
     *
     * <p>This is the response cardinality for a categorical response and 1 otherwise. The
     * quasibinomial distribution always reports 2. The response is inspected first in every case,
     * so a null response fails regardless of the distribution.
     */
    protected int init_getNClass() {
        final int levels = this._response.isCategorical() ? this._response.cardinality() : 1;
        return this._parms._distribution == DistributionFamily.quasibinomial ? 2 : levels;
    }

    /**
     * Records validation errors for distribution-specific constraints on the response and
     * distribution parameters.
     *
     * <ul>
     *   <li>poisson / gamma: the response minimum must be non-negative;</li>
     *   <li>tweedie: the power must lie strictly between 1 and 2, and the response must be non-negative;</li>
     *   <li>quantile: alpha must lie in [0, 1];</li>
     *   <li>huber: alpha must lie in [0, 1].</li>
     * </ul>
     *
     * <p>NaN parameters are not flagged, because every check is a plain comparison.
     */
    public void checkDistributions() {
        final DistributionFamily family = this._parms._distribution;
        if (family == DistributionFamily.poisson) {
            if (this._response.min() < 0.0d) {
                error("_response", "Response must be non-negative for Poisson distribution.");
            }
        } else if (family == DistributionFamily.gamma) {
            if (this._response.min() < 0.0d) {
                error("_response", "Response must be non-negative for Gamma distribution.");
            }
        } else if (family == DistributionFamily.tweedie) {
            final double power = this._parms._tweedie_power;
            if (power <= 1.0d || power >= 2.0d) {
                error("_tweedie_power", "Tweedie power must be between 1 and 2.");
            }
            if (this._response.min() < 0.0d) {
                error("_response", "Response must be non-negative for Tweedie distribution.");
            }
        } else if (family == DistributionFamily.quantile) {
            final double alpha = this._parms._quantile_alpha;
            if (alpha < 0.0d || alpha > 1.0d) {
                error("_quantile_alpha", "Quantile alpha must be between 0 and 1.");
            }
        } else if (family == DistributionFamily.huber) {
            final double alpha = this._parms._huber_alpha;
            if (alpha < 0.0d || alpha > 1.0d) {
                error("_huber_alpha", "Huber alpha must be between 0 and 1.");
            }
        }
    }

    /**
     * Combines per-fold holdout predictions using the configured precision.
     *
     * <p>A negative configured precision means "auto": 8 for classifiers and 0 (exact) otherwise.
     *
     * @param keyArr keys of the per-fold holdout prediction frames
     * @param key    key for the combined output frame
     * @return the combined frame
     */
    Frame combineHoldoutPredictions(Key<Frame>[] keyArr, Key<Frame> key) {
        final int requested = this._parms._keep_cross_validation_predictions_precision;
        final int effective = requested >= 0 ? requested : (isClassifier() ? 8 : 0);
        return combineHoldoutPredictions(keyArr, key, effective);
    }

    /**
     * Stitches the per-fold holdout prediction frames into a single frame stored under {@code key}.
     *
     * <p>The vecs of all fold frames are laid out fold by fold, with each fold contributing its
     * columns in order. They are handed to a {@link HoldoutPredictionCombiner} built by
     * {@link #makeHoldoutPredictionCombiner(int, int, int)}. The output takes its column types,
     * names and domains from the first fold's frame.
     *
     * @param keyArr keys of the per-fold holdout prediction frames; all must have the same column count
     * @param key    key for the combined output frame
     * @param i      precision; 0 means exact combining, positive values approximate
     * @return the combined frame
     * @throws IllegalArgumentException if {@code keyArr} is null/empty, if a fold frame has a
     *         different number of columns than the first one, or if {@code i} is negative
     */
    static Frame combineHoldoutPredictions(Key<Frame>[] keyArr, Key<Frame> key, int i) {
        if (keyArr == null || keyArr.length == 0) {
            throw new IllegalArgumentException("At least one holdout prediction frame is required.");
        }
        final int nFolds = keyArr.length;
        final Frame first = keyArr[0].get();
        final int nCols = first.numCols();
        final Vec[] vecs = new Vec[nFolds * nCols];
        int next = 0;
        for (Key<Frame> foldKey : keyArr) {
            // Resolve each fold frame once; the original looked the key up again for every column.
            final Frame fold = foldKey.get();
            if (fold.numCols() != nCols) {
                throw new IllegalArgumentException("Holdout prediction frame " + foldKey + " has " + fold.numCols()
                        + " columns, expected " + nCols + ".");
            }
            for (int c = 0; c < nCols; c++) {
                vecs[next++] = fold.vec(c);
            }
        }
        return makeHoldoutPredictionCombiner(nFolds, nCols, i).doAll(first.types(), new Frame(vecs)).outputFrame(key, first.names(), first.domains());
    }

    /**
     * Creates the task that merges holdout predictions.
     *
     * @param i  number of folds
     * @param i2 number of prediction columns per fold
     * @param i3 precision: 0 yields an exact {@link HoldoutPredictionCombiner}, positive values an
     *           {@link ApproximatingHoldoutPredictionCombiner}
     * @return the combiner task
     * @throws IllegalArgumentException if {@code i3} is negative
     */
    static HoldoutPredictionCombiner makeHoldoutPredictionCombiner(int i, int i2, int i3) {
        if (i3 < 0) {
            throw new IllegalArgumentException("Precision cannot be negative, got precision = " + i3);
        }
        if (i3 == 0) {
            return new HoldoutPredictionCombiner(i, i2);
        }
        return new ApproximatingHoldoutPredictionCombiner(i, i2, i3);
    }

    /**
     * Builds the "Cross-Validation Metrics Summary" table for the given cross-validation models.
     *
     * <p>Rows come from reflection. They are the public no-argument accessors returning a
     * {@code Double} on the first model's validation metrics and on its confusion matrix, excluding
     * a fixed list of names, and are sorted by name. When both objects expose the same name, the
     * metrics accessor wins. Columns are {@code mean}, {@code sd}, then one {@code cv_<k>_valid}
     * column per model key.
     *
     * <p>Each metric value is written to the row named after it, and each model to the column of its
     * own key index. A metric that cannot be evaluated for a model leaves that cell unset and
     * contributes 0 to the statistics. Models that are missing, or that have no validation metrics,
     * are skipped and excluded from mean/sd.
     *
     * @param keyArr keys of the cross-validation models; may be null or empty
     * @return the summary table, or null when there are no model keys
     */
    private TwoDimTable makeCrossValidationSummaryTable(Key[] keyArr) {
        if (keyArr == null || keyArr.length == 0) {
            return null;
        }
        final int nModels = keyArr.length;
        final String[] colTypes = new String[nModels + 2];
        Arrays.fill(colTypes, "float");
        final String[] colFormats = new String[nModels + 2];
        Arrays.fill(colFormats, "%f");
        final String[] colHeaders = new String[nModels + 2];
        colHeaders[0] = "mean";
        colHeaders[1] = "sd";
        for (int k = 0; k < nModels; k++) {
            colHeaders[k + 2] = "cv_" + (k + 1) + "_valid";
        }

        // Accessors that are not per-fold scalar metrics (or are reported elsewhere).
        final HashSet<String> excluded = new HashSet<>(Arrays.asList(
                "total_rows", "makeSchema", "hr", "frame", "model", "remove", "cm", "auc_obj", "aucpr"));
        if (null == this._parms._custom_metric_func) {
            excluded.add("custom");
            excluded.add("custom_increasing");
        }

        // Metric name -> accessor, discovered on the first model (metrics first, then confusion matrix).
        final Map<String, Method> accessors = new HashMap<>();
        final ModelMetrics template = ((Model) DKV.getGet(keyArr[0]))._output._validation_metrics;
        if (template != null) {
            collectDoubleAccessors(template, excluded, accessors);
            final ConfusionMatrix templateCm = template.cm();
            if (templateCm != null) {
                collectDoubleAccessors(templateCm, excluded, accessors);
            }
        }
        final String[] metricNames = new TreeSet<>(accessors.keySet()).toArray(new String[0]);
        final int nMetrics = metricNames.length;

        final TwoDimTable table = new TwoDimTable("Cross-Validation Metrics Summary", null, metricNames, colHeaders, colTypes, colFormats, "");
        final MathUtils.SimpleStats stats = new MathUtils.SimpleStats(nMetrics);
        for (int k = 0; k < nModels; k++) {
            final Model model = (Model) DKV.getGet(keyArr[k]);
            if (model == null) {
                continue;
            }
            final ModelMetrics metrics = model._output._validation_metrics;
            if (metrics == null) {
                continue;
            }
            final ConfusionMatrix cm = metrics.cm();
            final double[] row = new double[nMetrics];
            for (int r = 0; r < nMetrics; r++) {
                final Method accessor = accessors.get(metricNames[r]);
                // The accessor belongs either to the metrics object or to its confusion matrix.
                Double value = invokeDoubleAccessor(accessor, metrics);
                if (value == null && cm != null) {
                    value = invokeDoubleAccessor(accessor, cm);
                }
                if (value != null) {
                    row[r] = value;
                    table.set(r, k + 2, Float.valueOf(value.floatValue()));
                }
            }
            stats.add(row, 1.0d);
        }
        final double[] means = stats.mean();
        final double[] sigmas = stats.sigma();
        for (int r = 0; r < nMetrics; r++) {
            table.set(r, 0, Float.valueOf((float) means[r]));
            table.set(r, 1, Float.valueOf((float) sigmas[r]));
        }
        Log.info(table);
        return table;
    }

    /**
     * Adds to {@code sink} every public accessor of {@code target} that returns a non-null Double
     * when invoked with no arguments. Names that are excluded or already present in {@code sink}
     * are skipped.
     */
    private static void collectDoubleAccessors(Object target, HashSet<String> excluded, Map<String, Method> sink) {
        for (Method method : target.getClass().getMethods()) {
            final String name = method.getName();
            if (excluded.contains(name) || sink.containsKey(name)) {
                continue;
            }
            if (invokeDoubleAccessor(method, target) != null) {
                sink.put(name, method);
            }
        }
    }

    /**
     * Reflectively invokes a no-argument accessor.
     *
     * @return the result if it is a Double; null if the call fails (wrong receiver, arguments
     *         required, exception thrown) or yields anything else
     */
    private static Double invokeDoubleAccessor(Method method, Object target) {
        try {
            final Object result = method.invoke(target);
            return result instanceof Double ? (Double) result : null;
        } catch (Exception ignored) {
            // Best-effort probing: non-applicable accessors are simply skipped.
            return null;
        }
    }

    /**
     * Returns the lower-cased simple class name of this builder (e.g. "glm").
     *
     * <p>Lower-casing uses {@code Locale.ROOT}. With the JVM default locale the result would change
     * with the host locale; for example, under a Turkish locale "IsolationForest" would become
     * "ısolationforest" with a dotless i.
     */
    public String getName() {
        return getClass().getSimpleName().toLowerCase(java.util.Locale.ROOT);
    }

    /*
     * Delegates to this builder's Workspace#cleanUp().
     * The decompiler (JADX) widened this method from private to public.
     */
    public void cleanUp() {
        final Workspace workspace = this._workspace;
        workspace.cleanUp();
    }

    /**
     * Replaces the workspace with a fresh {@code Workspace(true)} when {@code z} is true; otherwise
     * leaves the current workspace untouched.
     *
     * @param z whether to create a new workspace
     */
    protected final void initWorkspace(boolean z) {
        if (!z) {
            return;
        }
        this._workspace = new Workspace(true);
    }

    /**
     * Creates a POJO writer for a model loaded from a MOJO. The base implementation does not support
     * this and always throws.
     *
     * @param model     the model to convert
     * @param mojoModel the MOJO the model was loaded from; its algorithm name appears in the message
     * @return never returns normally in this base implementation
     * @throws UnsupportedOperationException always
     */
    public PojoWriter makePojoWriter(Model<?, ?, ?> model, MojoModel mojoModel) {
        final String algorithm = mojoModel._algoName;
        throw new UnsupportedOperationException("MOJO Model for algorithm '" + algorithm + "' doesn't support conversion to POJO.");
    }

    static {
        // Explicit (decompiled) form of the assertion-status flag that javac normally synthesizes.
        $assertionsDisabled = !ModelBuilder.class.desiredAssertionStatus();
        // Builder/algorithm registry arrays start out empty.
        ALGOBASES = new String[0];
        SCHEMAS = new String[0];
        BUILDERS = new ModelBuilder[0];
        // Lookup maps between builders, model classes and algorithm names start out empty.
        _builders = new HashMap<>();
        _model_class_to_algo = new HashMap<>();
        _algo_to_algo_full_name = new HashMap<>();
        _algo_to_model_class = new HashMap<>();
    }
}
