package hex;

import hex.Distribution;
import hex.Model;
import hex.Model.Output;
import hex.Model.Parameters;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.TreeSet;
import jsr166y.CountedCompleter;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.rapids.ASTKFold;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MRUtils;
import water.util.MathUtils;
import water.util.TwoDimTable;
import water.util.VecUtils;

/* loaded from: input_file:hex/ModelBuilder.class */
public abstract class ModelBuilder<M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> extends Iced {
    /** Job tracking this build; get() returns the model obtained from it. */
    public Job _job;
    /** Key of the model this builder produces (returned by dest()). */
    protected Key<M> _result;
    /** Build start time in epoch millis; set before each training run and checked by timeout(). */
    private long _start_time;
    // Algorithm registry, appended to by the registration constructor ModelBuilder(P, boolean, String).
    // The three arrays are parallel: index k holds the lower-cased builder class name, the schema
    // package prefix, and the prototype builder instance for the same algorithm.
    private static String[] ALGOBASES;
    private static String[] SCHEMAS;
    private static ModelBuilder[] BUILDERS;
    /** Parameters of this build. */
    public P _parms;
    /** Working copy of the training frame made by init(); separateFeatureVecs() moves special columns to its end. */
    protected transient Frame _train;
    /** Validation frame, exposed through valid(). */
    protected transient Frame _valid;
    // NOTE(review): the four maps below are not referenced in the visible part of this file;
    // presumably they index builders/models by algorithm name — confirm at their use sites.
    private static final Map<String, Class<? extends ModelBuilder>> _builders;
    private static final Map<Class<? extends Model>, String> _model_class_to_algo;
    private static final Map<String, String> _algo_to_algo_full_name;
    private static final Map<String, Class<? extends Model>> _algo_to_model_class;
    /** Response column of _train, set by separateFeatureVecs() (null for unsupervised builders). */
    protected transient Vec _response;
    /** Validation response; vresponse() falls back to _response when this is null. */
    protected transient Vec _vresponse;
    // Offset, weights and fold columns located by separateFeatureVecs(); null when not requested.
    protected transient Vec _offset;
    protected transient Vec _weights;
    protected transient Vec _fold;
    /** Number of classes returned by nclasses(); more than one means the builder is a classifier. */
    protected int _nclass;
    // NOTE(review): the two arrays below are not referenced in the visible part of this file.
    transient double[] _distribution;
    protected transient double[] _priorClassDist;
    /** Validation messages appended by message(); error() records them at log level 1. */
    public ValidationMessage[] _messages;
    /** Number of recorded errors; -1 until clearValidationErrors() runs (error_count() asserts >= 0). */
    private int _error_count;
    /** Names of the columns most recently dropped by FilterCols.doIt. */
    public transient HashSet<String> _removedCols;
    // Decompiler artifact: javac's synthetic assertion-status flag, presumably initialized in a static
    // block outside this view. NOTE(review): because it is declared explicitly, the `assert` keyword
    // cannot safely be used in this class (javac would clash on the synthetic name); assertions are
    // hand-written as `if (!$assertionsDisabled && cond) throw new AssertionError()`.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/ModelBuilder$BuilderVisibility.class */
    /**
     * Maturity level a builder advertises through builderVisibility(); the base class reports Stable.
     */
    public enum BuilderVisibility {
        Experimental,
        Beta,
        Stable
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:hex/ModelBuilder$Driver.class */
    /**
     * Base completer for one model build; concrete builders return a subclass from trainModelImpl().
     * When the task completes, the output clock of the model stored under dest() is stopped.
     */
    public abstract class Driver extends H2O.H2OCountedCompleter<ModelBuilder<M, P, O>.Driver> {
        /* JADX INFO: Access modifiers changed from: protected */
        public Driver() {
            // relies on the superclass's no-argument constructor
        }

        /**
         * @param parent completer handed to the superclass constructor
         */
        protected Driver(H2O.H2OCountedCompleter parent) {
            super(parent);
        }

        @Override // jsr166y.CountedCompleter
        public void onCompletion(CountedCompleter caller) {
            Key<M> modelKey = ModelBuilder.this.dest();
            try {
                modelKey.get()._output.stopClock();
            } catch (Throwable ignored) {
                // Best effort: a missing model, or any failure while stopping the clock, must not
                // disturb task completion, so every Throwable is deliberately discarded here.
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/ModelBuilder$FilterCols.class */
    /**
     * Drops from a frame every column for which filter(Vec) returns true.
     * The last {@code _specialVecs} columns are never examined.
     */
    public abstract class FilterCols {
        /** Count of trailing columns excluded from filtering. */
        final int _specialVecs;

        public FilterCols(int specialVecs) {
            this._specialVecs = specialVecs;
        }

        /** @return true if {@code vec} should be removed */
        protected abstract boolean filter(Vec vec);

        /**
         * Removes all filtered columns from {@code frame} in one bulk call. When anything is removed,
         * {@code _removedCols} is replaced by the set of removed names and a warning on "_train" is
         * recorded with {@code prefix} followed by those names.
         *
         * @param frame   frame modified in place
         * @param prefix  start of the warning text
         * @param logInfo when true, the warning text is also written to the info log
         */
        void doIt(Frame frame, String prefix, boolean logInfo) {
            Vec[] vecs = frame.vecs();
            int examined = vecs.length - this._specialVecs;
            ArrayList<Integer> hits = new ArrayList<>();
            for (int col = 0; col < examined; col++) {
                if (filter(vecs[col])) {
                    hits.add(col);
                }
            }
            if (hits.isEmpty()) {
                return;
            }
            int[] dropIdx = new int[hits.size()];
            HashSet<String> droppedNames = new HashSet<>(hits.size());
            for (int k = 0; k < dropIdx.length; k++) {
                dropIdx[k] = hits.get(k);
                // Names must be read before the bulk removal below shifts the columns.
                droppedNames.add(frame._names[dropIdx[k]]);
            }
            ModelBuilder.this._removedCols = droppedNames;
            frame.remove(dropIdx);
            String report = prefix + droppedNames.toString();
            ModelBuilder.this.warn("_train", report);
            if (logInfo) {
                Log.info(report);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/ModelBuilder$HoldoutPredictionCombiner.class */
    /**
     * Sums per-fold prediction columns into a single set of columns.
     * Input chunks are laid out fold-major: the chunk for fold f, column c is at index
     * {@code f * _cols + c}. Output column c receives, row by row, the sum over all folds.
     */
    public static class HoldoutPredictionCombiner extends MRTask<HoldoutPredictionCombiner> {
        /** Number of folds present in the input. */
        int _folds;
        /** Number of prediction columns per fold (and in the output). */
        int _cols;

        public HoldoutPredictionCombiner(int folds, int cols) {
            this._folds = folds;
            this._cols = cols;
        }

        @Override // water.MRTask
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            for (int col = 0; col < this._cols; col++) {
                int rows = chunkArr[0].len();
                double[] sums = new double[rows];
                for (int row = 0; row < rows; row++) {
                    // Accumulate folds in ascending order, starting from 0.0.
                    double total = 0.0d;
                    for (int fold = 0; fold < this._folds; fold++) {
                        total += chunkArr[fold * this._cols + col].atd(row);
                    }
                    sums[row] = total;
                }
                newChunkArr[col].setDoubles(sums);
            }
        }
    }

    /* loaded from: input_file:hex/ModelBuilder$ValidationMessage.class */
    public static final class ValidationMessage extends Iced {
        final byte _log_level;
        final String _field_name;
        final String _message;

        public ValidationMessage(byte b, String str, String str2) {
            this._log_level = b;
            this._field_name = str;
            this._message = str2;
            Log.log(b, str + ": " + str2);
        }

        public int log_level() {
            return this._log_level;
        }

        public String toString() {
            return Log.LVLS[this._log_level] + " on field: " + this._field_name + ": " + this._message;
        }
    }

    /**
     * Returns the model produced by this builder, as given by {@code _job.get()}.
     * With assertions enabled, first checks that the job targets this builder's result key.
     */
    public final M get() {
        if (!$assertionsDisabled && this._job._result != this._result) {
            throw new AssertionError();
        }
        return (M) this._job.get();
    }

    /** @return whether the underlying job reports itself as stopped */
    public final boolean isStopped() {
        Job job = this._job;
        return job.isStopped();
    }

    /** @return the key under which this builder's model is stored */
    public final Key<M> dest() {
        Key<M> key = this._result;
        return key;
    }

    /**
     * @return true when a positive _max_runtime_secs is set and more than that many seconds have
     *         elapsed since _start_time (which must have been set; asserted)
     */
    protected boolean timeout() {
        if (!$assertionsDisabled && this._start_time <= 0) {
            throw new AssertionError("Must set _start_time for each individual model.");
        }
        double limitSecs = this._parms._max_runtime_secs;
        // Written as !(x > 0) so that NaN, like zero or negatives, means "no limit".
        if (!(limitSecs > 0.0d)) {
            return false;
        }
        long elapsedMs = System.currentTimeMillis() - this._start_time;
        return elapsedMs > ((long) (limitSecs * 1000.0d));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /** @return true if the job was asked to stop or the runtime budget is exhausted */
    public boolean stop_requested() {
        if (this._job.stop_requested()) {
            return true;
        }
        return timeout();
    }

    /** @return a new key built from the next unique model id for algorithm {@code str} */
    public static Key<? extends Model> defaultKey(String str) {
        String modelId = H2O.calcNextUniqueModelId(str);
        return Key.make(modelId);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public ModelBuilder(P p) {
        this(p, defaultKey(p.algoName()));
    }

    protected ModelBuilder(P p, Key<M> key) {
        this._messages = new ValidationMessage[0];
        this._error_count = -1;
        this._removedCols = new HashSet<>();
        this._result = key;
        this._job = new Job(key, p.javaName(), p.algoName());
        this._parms = p;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public ModelBuilder(P p, Job job) {
        this._messages = new ValidationMessage[0];
        this._error_count = -1;
        this._removedCols = new HashSet<>();
        this._job = job;
        this._result = (Key<M>) defaultKey(p.algoName());
        this._parms = p;
    }

    /**
     * Returns the lower-cased base names of all registered algorithms, in registration order.
     * A copy is returned so callers cannot corrupt the shared registry by writing into the array.
     *
     * @return a fresh array of algorithm names, or null if the registry array has not been created
     */
    public static String[] algos() {
        return ALGOBASES == null ? null : ALGOBASES.clone();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public ModelBuilder(P p, boolean z) {
        this(p, z, "hex.schemas.");
    }

    /**
     * Registration constructor: records this instance as the prototype builder for its algorithm.
     * The algorithm base name is the builder's simple class name, lower-cased. Registering the same
     * algorithm twice fails.
     *
     * @param p   parameters held by the prototype (read by algoName/javaName, cloned by make)
     * @param z   must be true (asserted); presumably only a marker selecting this constructor
     * @param str schema package prefix later returned by schemaDirectory
     */
    protected ModelBuilder(P p, boolean z, String str) {
        this._messages = new ValidationMessage[0];
        this._error_count = -1;
        this._removedCols = new HashSet<>();
        if (!$assertionsDisabled && !z) {
            throw new AssertionError();
        }
        this._job = null;
        this._result = null;
        this._parms = p;
        init(false);
        // Locale.ROOT keeps registry names plain ASCII regardless of the JVM default locale; under a
        // Turkish locale, for example, "I".toLowerCase() yields a dotless 'ı' and lookups would miss.
        String lowerCase = getClass().getSimpleName().toLowerCase(Locale.ROOT);
        if (ArrayUtils.find(ALGOBASES, lowerCase) != -1) {
            throw H2O.fail("Only called once at startup per ModelBuilder, and " + lowerCase + " has already been called");
        }
        ALGOBASES = (String[]) Arrays.copyOf(ALGOBASES, ALGOBASES.length + 1);
        BUILDERS = (ModelBuilder[]) Arrays.copyOf(BUILDERS, BUILDERS.length + 1);
        SCHEMAS = (String[]) Arrays.copyOf(SCHEMAS, SCHEMAS.length + 1);
        ALGOBASES[ALGOBASES.length - 1] = lowerCase;
        BUILDERS[BUILDERS.length - 1] = this;
        SCHEMAS[SCHEMAS.length - 1] = str;
    }

    /**
     * Looks up {@code algo} in the registry.
     *
     * @return the registry index of {@code algo}
     * @throws H2OIllegalArgumentException if no builder is registered under that name
     */
    private static int registeredAlgoIndex(String algo) {
        int idx = ArrayUtils.find(ALGOBASES, algo);
        if (idx == -1) {
            throw new H2OIllegalArgumentException("Unregistered algorithm " + algo);
        }
        return idx;
    }

    /** @return the algoName() of the parameters registered for base name {@code str} */
    public static String algoName(String str) {
        return BUILDERS[registeredAlgoIndex(str)]._parms.algoName();
    }

    /** @return the javaName() of the parameters registered for base name {@code str} */
    public static String javaName(String str) {
        return BUILDERS[registeredAlgoIndex(str)]._parms.javaName();
    }

    /** @return the algorithm name followed by "Parameters" */
    public static String paramName(String str) {
        return algoName(str) + "Parameters";
    }

    /** @return the schema package prefix registered for base name {@code str} */
    public static String schemaDirectory(String str) {
        return SCHEMAS[registeredAlgoIndex(str)];
    }

    /**
     * Creates a fresh builder for algorithm {@code str} (matched case-insensitively) by cloning the
     * registered prototype and its parameters, then attaching {@code job} and result {@code key}.
     *
     * @throws H2OIllegalArgumentException if the algorithm is not registered (assertions disabled)
     */
    public static <B extends ModelBuilder> B make(String str, Job job, Key<Model> key) {
        int find = ArrayUtils.find(ALGOBASES, str.toLowerCase(Locale.ROOT));
        if (!$assertionsDisabled && find == -1) {
            throw new AssertionError("Unregistered algorithm " + str);
        }
        if (find == -1) {
            // Without assertions this previously surfaced as an ArrayIndexOutOfBoundsException.
            throw new H2OIllegalArgumentException("Unregistered algorithm " + str);
        }
        B b = (B) BUILDERS[find].m62clone();
        b._job = job;
        b._result = key;
        // The class type parameter P is not usable in a static method; b is raw, so its _parms is
        // typed as the bound Model.Parameters.
        b._parms = (Model.Parameters) BUILDERS[find]._parms.m62clone();
        return b;
    }

    /** @return the working training frame prepared by init() */
    public final Frame train() {
        Frame working = this._train;
        return working;
    }

    /** @return the validation frame */
    protected final Frame valid() {
        Frame validation = this._valid;
        return validation;
    }

    /** @return the training response column */
    public Vec response() {
        Vec resp = this._response;
        return resp;
    }

    /** @return the validation response, or the training response when none is set */
    public Vec vresponse() {
        if (this._vresponse != null) {
            return this._vresponse;
        }
        return this._response;
    }

    /**
     * Starts training asynchronously on the job and returns the job.
     * Without cross-validation the driver from trainModelImpl() is started directly; otherwise a
     * wrapper task runs computeCrossValidation().
     *
     * @throws H2OModelBuilderIllegalArgumentException if validation recorded errors
     */
    public final Job<M> trainModel() {
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        this._start_time = System.currentTimeMillis();
        if (!nFoldCV()) {
            return this._job.start(trainModelImpl(), this._parms.progressUnits());
        }
        H2O.H2OCountedCompleter cvDriver = new H2O.H2OCountedCompleter() { // from class: hex.ModelBuilder.1
            @Override // water.H2O.H2OCountedCompleter
            public void compute2() {
                ModelBuilder.this.computeCrossValidation();
                tryComplete();
            }
        };
        // Work estimate: nFoldWork() fold models plus two extra units (presumably fold setup and the
        // main model — TODO confirm), each scaled by progressUnits().
        return this._job.start(cvDriver, (nFoldWork() + 2) * this._parms.progressUnits());
    }

    /**
     * Trains in the calling thread without starting the job, running cross-validation when
     * requested, and returns the model stored under the result key.
     *
     * @throws H2OModelBuilderIllegalArgumentException if validation recorded errors
     */
    public final M trainModelNested() {
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        this._start_time = System.currentTimeMillis();
        if (!nFoldCV()) {
            trainModelImpl().compute2();
        } else {
            computeCrossValidation();
        }
        return this._result.get();
    }

    /** Creates the driver task that trains a single model from the current parameters. */
    protected abstract ModelBuilder<M, P, O>.Driver trainModelImpl();

    /**
     * Number of cross-validation models to train concurrently: {@code _nfolds} when parallel CV is
     * enabled, no runtime limit is set and the training frame is under 1e6 bytes; otherwise 1.
     */
    protected int nModelsInParallel() {
        boolean runAllAtOnce = this._parms._parallelize_cross_validation
                && this._parms._max_runtime_secs == 0.0d
                && this._train.byteSize() < 1000000.0d;
        return runAllAtOnce ? this._parms._nfolds : 1;
    }

    /**
     * Number of cross-validation folds: {@code _nfolds}, or, when a fold column is set, the number of
     * levels of that column as reported by a temporary categorical Vec.
     */
    private int nFoldWork() {
        if (this._parms._fold_column == null) {
            return this._parms._nfolds;
        }
        Vec levels = VecUtils.toCategoricalVec(train().vec(this._parms._fold_column));
        try {
            return levels.domain().length;
        } finally {
            // Always release the temporary Vec, even if reading its domain fails.
            levels.remove();
        }
    }

    /**
     * Runs the whole cross-validation flow inside a fresh Scope:
     * assign folds, build the per-fold weights (one progress unit), build the fold frames and
     * builders, train the fold models and submit the main model, score the fold models while the
     * main model trains, then attach the combined cross-validation metrics to the main model.
     * The job is not ready for view until this completes.
     */
    public void computeCrossValidation() {
        if (!$assertionsDisabled && !this._job.isRunning()) {
            throw new AssertionError();
        }
        this._job.setReadyForView(false);
        final int nFolds = nFoldWork();
        init(false);
        // Enter before the try: if enter() itself failed, exit() must not pop an unrelated scope.
        // The finally guarantees exactly one exit(); the old catch-and-rethrow ran it a second time
        // whenever the normal-path exit() threw.
        Scope.enter();
        try {
            Vec foldAssignment = cv_AssignFold(nFolds);
            Vec[] weights = cv_makeWeights(nFolds, foldAssignment);
            this._job.update(1L);
            ModelBuilder<M, P, O>[] cvBuilders = cv_makeFramesAndBuilders(nFolds, weights);
            H2O.H2OCountedCompleter mainModelTask = cv_buildModels(nFolds, cvBuilders);
            ModelMetrics.MetricBuilder[] cvMetrics = cv_scoreCVModels(nFolds, weights, cvBuilders);
            if (mainModelTask != null) {
                mainModelTask.join();
            }
            cv_mainModelScores(nFolds, cvMetrics, cvBuilders);
            this._job.setReadyForView(true);
            DKV.put(this._job);
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Returns a Vec holding a fold id for every training row.
     * A user-supplied fold column is returned unchanged after checking that it holds contiguous
     * integers 0..N-1 or 1..N; otherwise fold ids are generated according to _fold_assignment.
     *
     * @param i number of folds N
     * @throws H2OIllegalArgumentException if the fold column is not contiguous integers
     */
    public Vec cv_AssignFold(int i) {
        Vec userFolds = train().vec(this._parms._fold_column);
        if (userFolds == null) {
            long seed = this._parms.nFoldSeed();
            Log.info("Creating " + i + " cross-validation splits with random number seed: " + seed);
            switch (this._parms._fold_assignment) {
                case Modulo:
                    return ASTKFold.moduloKfoldColumn(train().anyVec().makeZero(), i);
                case Stratified:
                    return ASTKFold.stratifiedKFoldColumn(response(), i, seed);
                case AUTO:
                case Random:
                    return ASTKFold.kfoldColumn(train().anyVec().makeZero(), i, seed);
                default:
                    throw H2O.unimpl();
            }
        }
        boolean contiguous = userFolds.isInt()
                && ((userFolds.min() == 0.0d && userFolds.max() == i - 1)
                    || (userFolds.min() == 1.0d && userFolds.max() == i));
        if (!contiguous) {
            throw new H2OIllegalArgumentException("Fold column must be either categorical or contiguous integers from 0..N-1 or 1..N");
        }
        return userFolds;
    }

    /**
     * Builds the per-fold training/holdout weight columns (2 per fold).
     * For fold f, column 2f holds the row weight for rows outside fold f (0 inside), and column
     * 2f+1 holds the row weight for rows inside fold f (0 outside). Row weights come from the
     * weights column, or are 1 when none is set; fold ids are taken modulo nFolds.
     * Optionally stores a copy of the fold assignment as frame "cv_fold_assignment_<result key>".
     *
     * @param nFolds         number of folds
     * @param foldAssignment per-row fold ids (from cv_AssignFold)
     * @return 2 * nFolds weight columns
     * @throws H2OIllegalArgumentException if any produced column is constant
     */
    public Vec[] cv_makeWeights(final int nFolds, Vec foldAssignment) {
        String weightsColumn = this._parms._weights_column;
        boolean hasUserWeights = weightsColumn != null;
        Vec rowWeights = hasUserWeights ? train().vec(weightsColumn) : train().anyVec().makeCon(1.0d);
        Vec[] foldWeights = new MRTask() { // from class: hex.ModelBuilder.2
            @Override // water.MRTask
            public void map(Chunk[] cs, NewChunk[] out) {
                Chunk foldIds = cs[0];
                Chunk weightChunk = cs[1];
                for (int row = 0; row < weightChunk._len; row++) {
                    int holdout = ((int) foldIds.at8(row)) % nFolds;
                    double w = weightChunk.atd(row);
                    for (int f = 0; f < nFolds; f++) {
                        boolean inHoldout = f == holdout;
                        out[2 * f].addNum(inHoldout ? 0.0d : w);
                        out[2 * f + 1].addNum(inHoldout ? w : 0.0d);
                    }
                }
            }
        }.doAll(2 * nFolds, (byte) 3, new Frame(foldAssignment, rowWeights)).outputFrame().vecs();
        if (this._parms._keep_cross_validation_fold_assignment) {
            DKV.put(new Frame(Key.make("cv_fold_assignment_" + this._result.toString()), new String[]{"fold_assignment"}, new Vec[]{foldAssignment.makeCopy()}));
        }
        // A user fold column belongs to the training frame; only a generated assignment is removed.
        boolean generatedFolds = this._parms._fold_column == null;
        if (generatedFolds && !this._parms._keep_cross_validation_fold_assignment) {
            foldAssignment.remove();
        }
        if (!hasUserWeights) {
            rowWeights.remove();
        }
        for (int k = 0; k < foldWeights.length; k++) {
            if (foldWeights[k].isConst()) {
                throw new H2OIllegalArgumentException("Not enough data to create " + nFolds + " random cross-validation splits. Either reduce nfolds, specify a larger dataset (or specify another random number seed, if applicable).");
            }
        }
        return foldWeights;
    }

    /**
     * Creates, for each fold, a training frame, a validation frame and a cloned builder.
     * Both frames share the training columns (minus the user weights column, if any) and add a
     * "__internal_cv_weights__" column: vecArr[2k] for fold k's training frame, vecArr[2k+1] for its
     * validation frame. They are stored in DKV as "<result>_cv_<k+1>_train" / "..._valid". Each fold
     * builder trains on them with nfolds = 0 and AUTO fold assignment.
     *
     * @param i      number of folds
     * @param vecArr per-fold weight columns from cv_makeWeights (2 per fold)
     * @return the configured fold builders, indexed by fold
     * @throws H2OIllegalArgumentException if the training frame already has a "__internal_cv_weights__" column
     * @throws H2OModelBuilderIllegalArgumentException if any fold builder fails validation
     */
    public ModelBuilder<M, P, O>[] cv_makeFramesAndBuilders(int i, Vec[] vecArr) {
        long checksum = this._parms.checksum();
        String resultName = this._result.toString();
        if (train().find("__internal_cv_weights__") != -1) {
            throw new H2OIllegalArgumentException("Frame cannot contain a Vec called '__internal_cv_weights__'.");
        }
        Frame shared = new Frame(train().names(), train().vecs());
        if (this._parms._weights_column != null) {
            shared.remove(this._parms._weights_column);
        }
        ModelBuilder<M, P, O>[] builders = new ModelBuilder[i];
        for (int fold = 0; fold < i; fold++) {
            String foldName = resultName + "_cv_" + (fold + 1);
            Frame cvTrain = new Frame(Key.make(foldName + "_train"), shared.names(), shared.vecs());
            cvTrain.add("__internal_cv_weights__", vecArr[2 * fold]);
            DKV.put(cvTrain);
            Frame cvValid = new Frame(Key.make(foldName + "_valid"), shared.names(), shared.vecs());
            cvValid.add("__internal_cv_weights__", vecArr[(2 * fold) + 1]);
            DKV.put(cvValid);
            ModelBuilder<M, P, O> cvBuilder = (ModelBuilder) m62clone();
            cvBuilder._result = Key.make(foldName);
            cvBuilder._parms = (P) this._parms.m62clone();
            cvBuilder._parms._weights_column = "__internal_cv_weights__";
            cvBuilder._parms._train = cvTrain._key;
            cvBuilder._parms._valid = cvValid._key;
            cvBuilder._parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO;
            cvBuilder._parms._nfolds = 0;
            cvBuilder.init(false);
            if (cvBuilder.error_count() > 0) {
                for (ValidationMessage vm : cvBuilder._messages) {
                    if (vm._log_level == 1) {
                        // error() also increments _error_count; message() alone would record the
                        // text without counting it, and the check below would never fire.
                        error(vm._field_name, vm._message);
                    } else {
                        message(vm._log_level, vm._field_name, vm._message);
                    }
                }
            }
            builders[fold] = cvBuilder;
        }
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        // Building the fold builders must not have changed this builder's parameters.
        if (!$assertionsDisabled && checksum != this._parms.checksum()) {
            throw new AssertionError();
        }
        return builders;
    }

    /**
     * Trains the fold models, then submits the main model.
     * Fold models are submitted in batches of nModelsInParallel(), and each batch is joined before
     * the next one is launched. Once a stop is requested, no more fold models are launched; those
     * already submitted are still awaited, and null is returned without building the main model.
     *
     * @param i               number of folds
     * @param modelBuilderArr per-fold builders
     * @return the submitted main-model task, or null if a stop was requested
     */
    public H2O.H2OCountedCompleter cv_buildModels(int i, ModelBuilder<M, P, O>[] modelBuilderArr) {
        H2O.H2OCountedCompleter[] tasks = new H2O.H2OCountedCompleter[i];
        int inFlight = 0;
        for (int fold = 0; fold < i && !this._job.stop_requested(); fold++) {
            Log.info("Building cross-validation model " + (fold + 1) + " / " + i + ".");
            modelBuilderArr[fold]._start_time = System.currentTimeMillis();
            tasks[fold] = H2O.submitTask(modelBuilderArr[fold].trainModelImpl());
            inFlight++;
            if (inFlight == nModelsInParallel()) {
                // Wait for the current batch: tasks[fold - inFlight + 1 .. fold].
                while (inFlight > 0) {
                    tasks[fold + 1 - inFlight].join();
                    inFlight--;
                }
            }
        }
        // Every launched fold model must finish before the main model is built. Slots for folds
        // that were never launched (stop requested) are null and must be skipped.
        for (H2O.H2OCountedCompleter task : tasks) {
            if (task != null) {
                task.join();
            }
        }
        if (this._job.stop_requested()) {
            return null;
        }
        if (!$assertionsDisabled && !this._job.isRunning()) {
            throw new AssertionError();
        }
        Log.info("Building main model.");
        this._start_time = System.currentTimeMillis();
        modifyParmsForCrossValidationMainModel(modelBuilderArr);
        return H2O.submitTask(trainModelImpl());
    }

    /**
     * Scores each fold model on its validation frame and frees the fold frames and weight Vecs.
     * Holdout predictions ("prediction_<fold model key>") are written when the problem is binomial
     * or when predictions are kept.
     *
     * @param i               number of folds
     * @param vecArr          per-fold weight columns (2 per fold), removed here
     * @param modelBuilderArr per-fold builders holding the trained fold models
     * @return one metric builder per fold, or null if a stop was requested
     */
    public ModelMetrics.MetricBuilder[] cv_scoreCVModels(int i, Vec[] vecArr, ModelBuilder<M, P, O>[] modelBuilderArr) {
        if (this._job.stop_requested()) {
            return null;
        }
        ModelMetrics.MetricBuilder[] metrics = new ModelMetrics.MetricBuilder[i];
        Futures futures = new Futures();
        for (int fold = 0; fold < i; fold++) {
            if (this._job.stop_requested()) {
                // Do not abandon removals already queued for earlier folds.
                futures.blockForPending();
                return null;
            }
            Frame valid = modelBuilderArr[fold].valid();
            Frame adapted = new Frame(valid);
            M cvModel = modelBuilderArr[fold].dest().get();
            cvModel.adaptTestForTrain(adapted, true, !isSupervised());
            metrics[fold] = cvModel.scoreMetrics(adapted);
            if (nclasses() == 2 || this._parms._keep_cross_validation_predictions) {
                cvModel.predictScoreImpl(valid, adapted, "prediction_" + modelBuilderArr[fold]._result.toString(), null);
            }
            // Free per-fold resources as early as possible.
            Model.cleanup_adapt(adapted, valid);
            DKV.remove(adapted._key, futures);
            DKV.remove(modelBuilderArr[fold]._parms._train, futures);
            DKV.remove(modelBuilderArr[fold]._parms._valid, futures);
            vecArr[2 * fold].remove(futures);
            vecArr[(2 * fold) + 1].remove(futures);
        }
        futures.blockForPending();
        return metrics;
    }

    /**
     * Finalizes cross-validation on the main model.
     * Records the fold model keys, reduces all per-fold metric builders into metricBuilderArr[0] and
     * turns them into the main model's cross-validation metrics, optionally keeps the holdout
     * predictions and the fold assignment, builds the CV summary table and stores the model in DKV.
     * Does nothing if a stop was requested.
     *
     * @param i                number of folds
     * @param metricBuilderArr per-fold metric builders; element 0 accumulates all of them (mutated)
     * @param modelBuilderArr  per-fold builders whose result keys name the fold models
     */
    public void cv_mainModelScores(int i, ModelMetrics.MetricBuilder[] metricBuilderArr, ModelBuilder<M, P, O>[] modelBuilderArr) {
        if (this._job.stop_requested()) {
            return;
        }
        if (!$assertionsDisabled && !this._job.isRunning()) {
            throw new AssertionError();
        }
        M m = this._result.get();
        Log.info("Computing " + i + "-fold cross-validation metrics.");
        m._output._cross_validation_models = new Key[i];
        // Keys of the per-fold prediction frames written by cv_scoreCVModels; published on the model
        // only when predictions are kept.
        Key[] keyArr = new Key[i];
        m._output._cross_validation_predictions = this._parms._keep_cross_validation_predictions ? keyArr : null;
        for (int i2 = 0; i2 < i; i2++) {
            // Fold metrics accumulate into metricBuilderArr[0].
            if (i2 > 0) {
                metricBuilderArr[0].reduce(metricBuilderArr[i2]);
            }
            Key<M> key = modelBuilderArr[i2]._result;
            m._output._cross_validation_models[i2] = key;
            keyArr[i2] = Key.make("prediction_" + key.toString());
        }
        // Combined holdout predictions: built when they are kept, or for binomial problems, where
        // they are passed to makeModelMetrics below (presumably for binomial-only metrics — TODO confirm).
        Frame frame = null;
        if (this._parms._keep_cross_validation_predictions || nclasses() == 2) {
            Key make = Key.make("cv_holdout_prediction_" + m._key.toString());
            if (this._parms._keep_cross_validation_predictions) {
                m._output._cross_validation_holdout_predictions_frame_id = make;
            }
            frame = combineHoldoutPredictions(keyArr, make);
        }
        if (this._parms._keep_cross_validation_fold_assignment) {
            m._output._cross_validation_fold_assignment_frame_id = Key.make("cv_fold_assignment_" + this._result.toString());
            // Untracked from the current Scope, presumably so the Scope.exit() in
            // computeCrossValidation does not delete the kept fold-assignment frame.
            Scope.untrack(((Frame) m._output._cross_validation_fold_assignment_frame_id.get()).keys());
        }
        // Per-fold prediction frames exist only if cv_scoreCVModels wrote them (hence the null check):
        // keep them (untrack) when requested, otherwise delete them.
        for (Key key2 : keyArr) {
            Frame frame2 = (Frame) DKV.getGet(key2);
            if (frame2 != null) {
                if (this._parms._keep_cross_validation_predictions) {
                    Scope.untrack(frame2.keys());
                } else {
                    frame2.remove();
                }
            }
        }
        m._output._cross_validation_metrics = metricBuilderArr[0].makeModelMetrics(m, this._parms.train(), null, frame);
        // The combined holdout frame is no longer needed for metrics: keep or delete it.
        if (frame != null) {
            if (this._parms._keep_cross_validation_predictions) {
                Scope.untrack(frame.keys());
            } else {
                frame.remove();
            }
        }
        m._output._cross_validation_metrics._description = i + "-fold cross-validation on training data (Metrics computed for combined holdout predictions)";
        Log.info(m._output._cross_validation_metrics.toString());
        m._output._cross_validation_metrics_summary = makeCrossValidationSummaryTable(m._output._cross_validation_models);
        DKV.put(m);
    }

    /**
     * Hook invoked by cv_buildModels just before the main model is submitted, giving subclasses
     * access to the trained fold builders. The default does nothing.
     */
    public void modifyParmsForCrossValidationMainModel(ModelBuilder<M, P, O>[] modelBuilderArr) {
        // default: leave the main model's parameters unchanged
    }

    /** @return true if cross-validation was requested through a fold column or a nonzero _nfolds */
    public boolean nFoldCV() {
        return this._parms._fold_column != null || this._parms._nfolds != 0;
    }

    /** @return the model categories this builder can produce */
    public abstract ModelCategory[] can_build();

    /** @return this builder's maturity level; Stable unless overridden */
    public BuilderVisibility builderVisibility() {
        BuilderVisibility level = BuilderVisibility.Stable;
        return level;
    }

    /** Resets validation state (messages and error count); called at the start of init(). */
    public void clearInitState() {
        clearValidationErrors();
    }

    /** @return whether init(true) logs the build parameters; true by default */
    protected boolean logMe() {
        return true;
    }

    /** @return whether this builder uses a response column; false by default */
    public boolean isSupervised() {
        return false;
    }

    /** @return whether an offset column is configured */
    public boolean hasOffsetCol() {
        String name = this._parms._offset_column;
        return name != null;
    }

    /** @return whether a weights column is configured */
    public boolean hasWeightCol() {
        String name = this._parms._weights_column;
        return name != null;
    }

    /** @return whether a fold column is configured */
    public boolean hasFoldCol() {
        String name = this._parms._fold_column;
        return name != null;
    }

    /** @return how many of the offset, weights and fold columns are configured */
    public int numSpecialCols() {
        int count = 0;
        if (hasOffsetCol()) {
            count++;
        }
        if (hasWeightCol()) {
            count++;
        }
        if (hasFoldCol()) {
            count++;
        }
        return count;
    }

    /** @return number of response classes (_nclass) */
    public int nclasses() {
        int classes = this._nclass;
        return classes;
    }

    /** @return true when there are at least two classes, i.e. the builder is a classifier */
    public final boolean isClassifier() {
        return nclasses() >= 2;
    }

    /**
     * Locates the special columns named in _parms in _train, validates them, and moves each one
     * found to the end of _train (removed by name, then re-added). The special columns are weights,
     * offset, fold and, for supervised builders, the response. Problems are recorded with
     * error(...) rather than thrown; _weights/_offset/_fold/_response are set from the columns found.
     *
     * @return number of special columns found, and therefore moved to the end of _train
     */
    protected int separateFeatureVecs() {
        int specialCount = 0;
        if (this._parms._weights_column != null) {
            Vec weights = this._train.remove(this._parms._weights_column);
            if (weights == null) {
                error("_weights_column", "Weights column '" + this._parms._weights_column + "' not found in the training frame");
            } else {
                if (!weights.isNumeric()) {
                    error("_weights_column", "Invalid weights column '" + this._parms._weights_column + "', weights must be numeric");
                }
                this._weights = weights;
                // Report every weights problem under the same field name. These three checks used
                // the misspelled "_weights_columns", unlike the two checks above.
                if (weights.naCnt() > 0) {
                    error("_weights_column", "Weights cannot have missing values.");
                }
                if (weights.min() < 0.0d) {
                    error("_weights_column", "Weights must be >= 0");
                }
                if (weights.max() == 0.0d) {
                    error("_weights_column", "Max. weight must be > 0");
                }
                this._train.add(this._parms._weights_column, weights);
                specialCount++;
            }
        } else {
            this._weights = null;
            if (!$assertionsDisabled && hasWeightCol()) {
                throw new AssertionError();
            }
        }
        if (this._parms._offset_column != null) {
            Vec offset = this._train.remove(this._parms._offset_column);
            if (offset == null) {
                error("_offset_column", "Offset column '" + this._parms._offset_column + "' not found in the training frame");
            } else {
                if (!offset.isNumeric()) {
                    error("_offset_column", "Invalid offset column '" + this._parms._offset_column + "', offset must be numeric");
                }
                this._offset = offset;
                if (offset.naCnt() > 0) {
                    error("_offset_column", "Offset cannot have missing values.");
                }
                if (this._weights == this._offset) {
                    error("_offset_column", "Offset must be different from weights");
                }
                this._train.add(this._parms._offset_column, offset);
                specialCount++;
            }
        } else {
            this._offset = null;
            if (!$assertionsDisabled && hasOffsetCol()) {
                throw new AssertionError();
            }
        }
        if (this._parms._fold_column != null) {
            Vec fold = this._train.remove(this._parms._fold_column);
            if (fold == null) {
                error("_fold_column", "Fold column '" + this._parms._fold_column + "' not found in the training frame");
            } else {
                if (!fold.isInt() && !fold.isCategorical()) {
                    error("_fold_column", "Invalid fold column '" + this._parms._fold_column + "', fold must be integer or categorical");
                }
                if (fold.min() < 0.0d) {
                    error("_fold_column", "Invalid fold column '" + this._parms._fold_column + "', fold must be non-negative");
                }
                if (fold.isConst()) {
                    error("_fold_column", "Invalid fold column '" + this._parms._fold_column + "', fold cannot be constant");
                }
                this._fold = fold;
                if (fold.naCnt() > 0) {
                    error("_fold_column", "Fold cannot have missing values.");
                }
                if (this._fold == this._weights) {
                    error("_fold_column", "Fold must be different from weights");
                }
                if (this._fold == this._offset) {
                    error("_fold_column", "Fold must be different from offset");
                }
                this._train.add(this._parms._fold_column, fold);
                specialCount++;
            }
        } else {
            this._fold = null;
            if (!$assertionsDisabled && hasFoldCol()) {
                throw new AssertionError();
            }
        }
        if (isSupervised() && this._parms._response_column != null) {
            this._response = this._train.remove(this._parms._response_column);
            if (this._response == null) {
                error("_response_column", "Response column '" + this._parms._response_column + "' not found in the training frame");
            } else {
                if (this._response == this._offset) {
                    error("_response_column", "Response column must be different from offset_column");
                }
                if (this._response == this._weights) {
                    error("_response_column", "Response column must be different from weights_column");
                }
                if (this._response == this._fold) {
                    error("_response_column", "Response column must be different from fold_column");
                }
                this._train.add(this._parms._response_column, this._response);
                specialCount++;
            }
        } else {
            this._response = null;
        }
        return specialCount;
    }

    /** @return whether ignoreBadColumns drops string columns; true by default */
    protected boolean ignoreStringColumns() {
        return true;
    }

    /** @return whether ignoreBadColumns drops constant columns; defaults to _ignore_const_cols */
    protected boolean ignoreConstColumns() {
        boolean fromParms = this._parms._ignore_const_cols;
        return fromParms;
    }

    /**
     * Drops unusable predictor columns from _train, leaving the last {@code i} columns untouched.
     * A column is dropped if it is constant (when ignoreConstColumns()), Vec.isBad(), or a string
     * column (when ignoreStringColumns()).
     * NOTE(review): the whole pass runs only when _ignore_const_cols is set, so bad and string
     * columns are also kept when it is off, and the warning always says "constant" — confirm intended.
     *
     * @param i number of trailing (special) columns to skip
     * @param z when true, the dropped-column message is also written to the info log
     */
    protected void ignoreBadColumns(int i, boolean z) {
        if (!this._parms._ignore_const_cols) {
            return;
        }
        FilterCols dropper = new FilterCols(i) { // from class: hex.ModelBuilder.3
            @Override // hex.ModelBuilder.FilterCols
            protected boolean filter(Vec vec) {
                if (ModelBuilder.this.ignoreConstColumns() && vec.isConst()) {
                    return true;
                }
                if (vec.isBad()) {
                    return true;
                }
                return ModelBuilder.this.ignoreStringColumns() && vec.isString();
            }
        };
        dropper.doIt(this._train, "Dropping constant columns: ", z);
    }

    /** Hook for subclasses (presumably to check memory requirements); does nothing by default. */
    protected void checkMemoryFootPrint() {
        // no-op by default
    }

    /** @return whether the prior class distribution is computed; by default only for classifiers */
    protected boolean computePriorClassDistribution() {
        boolean classifier = isClassifier();
        return classifier;
    }

    /**
     * @return number of validation errors recorded since the last clearValidationErrors()
     *         (asserts that validation has been reset at least once, i.e. init() has run)
     */
    public int error_count() {
        if (!$assertionsDisabled && this._error_count < 0) {
            throw new AssertionError("init() not run yet");
        }
        return this._error_count;
    }

    /** Records a message for {@code field} at level 5 (presumably hidden from users — confirm). */
    public void hide(String field, String text) {
        final byte level = (byte) 5;
        message(level, field, text);
    }

    /** Records an informational message (level 3) for {@code field}. */
    public void info(String field, String text) {
        final byte level = (byte) 3;
        message(level, field, text);
    }

    /** Records a warning (level 2) for {@code field}. */
    public void warn(String field, String text) {
        final byte level = (byte) 2;
        message(level, field, text);
    }

    /** Records an error (level 1) for {@code field} and increments the error count. */
    public void error(String field, String text) {
        final byte level = (byte) 1;
        message(level, field, text);
        this._error_count++;
    }

    /**
     * Discards every recorded validation message (of any level) and resets the error counter to zero.
     */
    public void clearValidationErrors() {
        _error_count = 0;
        _messages = new ValidationMessage[0];
    }

    /**
     * Appends a validation message to {@code _messages}, growing the array by one slot.
     *
     * @param level     message level (the helpers above use 1 = error, 2 = warn, 3 = info, 5 = hide)
     * @param fieldName the parameter the message refers to
     * @param text      human-readable explanation
     */
    public void message(byte level, String fieldName, String text) {
        final int slot = _messages.length;
        _messages = Arrays.copyOf(_messages, slot + 1);
        _messages[slot] = new ValidationMessage(level, fieldName, text);
    }

    /**
     * Renders every level-1 (error) message, one per line, each terminated by a newline.
     *
     * @return the concatenated error messages, or an empty string if there are none
     */
    public String validationErrors() {
        StringBuilder out = new StringBuilder();
        for (int k = 0; k < _messages.length; k++) {
            ValidationMessage msg = _messages[k];
            if (msg._log_level != 1) {
                continue;
            }
            out.append(msg.toString()).append("\n");
        }
        return out.toString();
    }

    /**
     * Validates {@code _parms} against the training (and optional validation) frame and sets up the
     * builder's working state: {@code _train}, {@code _valid}, {@code _vresponse}, {@code _nclass},
     * {@code _distribution} and {@code _priorClassDist}.
     *
     * <p>Problems are not thrown; they are recorded through {@link #error}, {@link #warn} and
     * {@link #hide}, so callers should check {@link #error_count()} afterwards. Two paths return early
     * and skip the remaining checks:
     * <ul>
     *   <li>a missing training frame;</li>
     *   <li>for supervised builders with a non-null {@code _train}, an unset response column.</li>
     * </ul>
     *
     * @param z when true, this method additionally:
     *          <ul>
     *            <li>logs the parameters;</li>
     *            <li>reports a missing {@code _train} key as an error;</li>
     *            <li>rebalances the frames, if no errors have been recorded so far;</li>
     *            <li>runs {@link #checkDistributions()};</li>
     *            <li>surfaces validation-frame adaptation notes as warnings.</li>
     *          </ul>
     */
    public void init(boolean z) {
        if (z && logMe()) {
            Log.info("Building H2O " + getClass().getSimpleName().toString() + " model with these parameters:");
            Log.info(new String(this._parms.writeJSON(new AutoBuffer()).buf()));
        }
        // NOTE(review): presumably resets messages/error count left by an earlier init() call;
        // clearInitState() is defined outside this view.
        clearInitState();
        if (!$assertionsDisabled && this._parms == null) {
            throw new AssertionError();
        }
        // No training frame key at all: only an error when z is set, but nothing else can be validated.
        if (this._parms._train == null) {
            if (z) {
                error("_train", "Missing training frame");
                return;
            }
            return;
        }
        Frame train = this._parms.train();
        if (train == null) {
            error("_train", "Missing training frame: " + this._parms._train);
            return;
        }
        // Work on a key-less copy with cloned name/vec arrays, so the column removals below act on this
        // copy rather than on the user's Frame object.
        this._train = new Frame((Key) null, (String[]) train._names.clone(), (Vec[]) train.vecs().clone());
        // Cross-validation parameter sanity checks.
        if (this._parms._nfolds < 0 || this._parms._nfolds == 1) {
            error("_nfolds", "nfolds must be either 0 or >1.");
        }
        if (this._parms._nfolds > 1 && this._parms._nfolds > train().numRows()) {
            error("_nfolds", "nfolds cannot be larger than the number of rows (" + train().numRows() + ").");
        }
        if (this._parms._fold_column != null) {
            hide("_fold_assignment", "Fold assignment is ignored when a fold column is specified.");
            if (this._parms._nfolds > 1) {
                error("_nfolds", "nfolds cannot be specified at the same time as a fold column.");
            } else {
                hide("_nfolds", "nfolds is ignored when a fold column is specified.");
            }
            if (this._parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO) {
                error("_fold_assignment", "Fold assignment is not allowed in conjunction with a fold column.");
            }
        }
        if (this._parms._nfolds > 1) {
            hide("_fold_column", "Fold column is ignored when nfolds > 1.");
        }
        if (!nFoldCV()) {
            hide("_keep_cross_validation_predictions", "Only for cross-validation.");
            hide("_keep_cross_validation_fold_assignment", "Only for cross-validation.");
            hide("_fold_assignment", "Only for cross-validation.");
            if (this._parms._fold_assignment != Model.Parameters.FoldAssignmentScheme.AUTO) {
                error("_fold_assignment", "Fold assignment is only allowed for cross-validation.");
            }
        }
        if (this._parms._distribution != Distribution.Family.tweedie) {
            hide("_tweedie_power", "Only for Tweedie Distribution.");
        }
        // NOTE(review): the power range is enforced for every distribution, not only tweedie. For a
        // tweedie run with an invalid power, checkDistributions() reports the same problem a second time.
        if (this._parms._tweedie_power <= 1.0d || this._parms._tweedie_power >= 2.0d) {
            error("_tweedie_power", "Tweedie power must be between 1 and 2 (exclusive).");
        }
        if (this._parms._ignored_columns != null) {
            this._train.remove(this._parms._ignored_columns);
            if (z) {
                Log.info("Dropping ignored columns: " + Arrays.toString(this._parms._ignored_columns));
            }
        }
        // NOTE(review): _valid is reassigned from _parms.valid() (or set to null) further below. The
        // rebalance of _valid here therefore only touches whatever _valid held on entry (possibly null),
        // and its result is then discarded.
        if (z && error_count() == 0) {
            this._train = rebalance(this._train, false, this._result + ".temporary.train");
            this._valid = rebalance(this._valid, false, this._result + ".temporary.valid");
        }
        // separateFeatureVecs() (defined earlier in the class) sets up the special vecs from _train. Its
        // result is passed on as ignoreBadColumns' first argument.
        ignoreBadColumns(separateFeatureVecs(), z);
        if (this._train.numCols() == 0) {
            error("_train", "There are no usable columns to generate model");
        }
        if (isSupervised()) {
            if (this._response != null) {
                if (this._parms._distribution != Distribution.Family.tweedie) {
                    hide("_tweedie_power", "Tweedie power is only used for Tweedie distribution.");
                }
                if (this._parms._distribution != Distribution.Family.quantile) {
                    hide("_quantile_alpha", "Quantile (alpha) is only used for Quantile regression.");
                }
                if (z) {
                    checkDistributions();
                }
                // Categorical response: one class per level; anything else counts as a single class.
                this._nclass = this._response.isCategorical() ? this._response.cardinality() : 1;
                if (this._response.isConst()) {
                    error("_response", "Response cannot be constant.");
                }
            }
            if (!this._parms._balance_classes) {
                hide("_max_after_balance_size", "Balance classes is false, hide max_after_balance_size");
            } else if (this._parms._weights_column != null && this._weights != null && !this._weights.isBinary()) {
                error("_balance_classes", "Balance classes and observation weights are not currently supported together.");
            }
            if (this._parms._max_after_balance_size <= 0.0d) {
                error("_max_after_balance_size", "Max size after balancing needs to be positive, suggest 1.0f");
            }
            if (this._train != null) {
                if (this._train.numCols() <= 1) {
                    error("_train", "Training data must have at least 2 features (incl. response).");
                }
                if (null == this._parms._response_column) {
                    error("_response_column", "Response column parameter not set.");
                    return;
                }
                if (this._response != null && computePriorClassDistribution()) {
                    if (isClassifier() && isSupervised()) {
                        // Per-class counts (weighted when a weights vec exists) and their relative frequencies.
                        MRUtils.ClassDist doAll = this._weights != null ? new MRUtils.ClassDist(nclasses()).doAll(this._response, this._weights) : new MRUtils.ClassDist(nclasses()).doAll(this._response);
                        this._distribution = doAll.dist();
                        this._priorClassDist = doAll.rel_dist();
                    } else {
                        // Single bucket whose size is the row count, scaled by the mean weight when weighted.
                        double[] dArr = new double[1];
                        dArr[0] = (this._weights != null ? this._weights.mean() : 1.0d) * train().numRows();
                        this._distribution = dArr;
                        this._priorClassDist = new double[]{1.0d};
                    }
                }
            }
            if (!isClassifier()) {
                hide("_balance_classes", "Balance classes is only applicable to classification problems.");
                hide("_class_sampling_factors", "Class sampling factors is only applicable to classification problems.");
                hide("_max_after_balance_size", "Max after balance size is only applicable to classification problems.");
                hide("_max_confusion_matrix_size", "Max confusion matrix size is only applicable to classification problems.");
            }
            if (this._nclass <= 2) {
                hide("_max_hit_ratio_k", "Max K-value for hit ratio is only applicable to multi-class classification problems.");
                hide("_max_confusion_matrix_size", "Only for multi-class classification problems.");
            }
            if (!this._parms._balance_classes) {
                hide("_max_after_balance_size", "Only used with balanced classes");
                hide("_class_sampling_factors", "Class sampling factors is only applicable if balancing classes.");
            }
        } else {
            // Unsupervised: response-related parameters and state are not used.
            hide("_response_column", "Ignored for unsupervised methods.");
            hide("_balance_classes", "Ignored for unsupervised methods.");
            hide("_class_sampling_factors", "Ignored for unsupervised methods.");
            hide("_max_after_balance_size", "Ignored for unsupervised methods.");
            hide("_max_confusion_matrix_size", "Ignored for unsupervised methods.");
            this._response = null;
            this._vresponse = null;
            this._nclass = 1;
        }
        if (this._nclass > 1000) {
            error("_nclass", "Too many levels in response column: " + this._nclass + ", maximum supported number of classes is 1000.");
        }
        // Validation frame: take a key-less copy and align its columns with _train.
        Frame valid = this._parms.valid();
        if (valid != null) {
            if (valid.numRows() == 0) {
                error("_validation_frame", "Validation frame must have > 0 rows.");
            }
            this._valid = new Frame((Key) null, (String[]) valid._names.clone(), (Vec[]) valid.vecs().clone());
            try {
                // NOTE(review): presumably adapts _valid in place; the assertion below expects _valid's names
                // to equal _train's afterwards. The returned strings are notes about the adaptation.
                String[] adaptTestForTrain = Model.adaptTestForTrain(this._train._names, this._parms._weights_column, this._parms._offset_column, this._parms._fold_column, null, this._train.domains(), this._valid, this._parms.missingColumnsType(), z, true, null);
                this._vresponse = this._valid.vec(this._parms._response_column);
                if (this._vresponse == null && this._parms._response_column != null) {
                    error("_validation_frame", "Validation frame must have a response column '" + this._parms._response_column + "'.");
                }
                if (z) {
                    for (String str : adaptTestForTrain) {
                        Log.info(str);
                        warn("_valid", str);
                    }
                }
                if (!$assertionsDisabled && z && this._valid != null && !Arrays.equals(this._train._names, this._valid._names)) {
                    throw new AssertionError();
                }
            } catch (IllegalArgumentException e) {
                error("_valid", e.getMessage());
            }
        } else {
            this._valid = null;
            this._vresponse = null;
        }
        if (this._parms._checkpoint != null && DKV.get(this._parms._checkpoint) == null) {
            error("_checkpoint", "Checkpoint has to point to existing model!");
        }
        // Early-stopping parameters.
        if (this._parms._stopping_tolerance < 0.0d) {
            error("_stopping_tolerance", "Stopping tolerance must be >= 0.");
        }
        if (this._parms._stopping_tolerance >= 1.0d) {
            error("_stopping_tolerance", "Stopping tolerance must be < 1.");
        }
        if (this._parms._stopping_rounds == 0) {
            if (this._parms._stopping_metric != ScoreKeeper.StoppingMetric.AUTO) {
                warn("_stopping_metric", "Stopping metric is ignored for _stopping_rounds=0.");
            }
            if (this._parms._stopping_tolerance != this._parms.defaultStoppingTolerance()) {
                warn("_stopping_tolerance", "Stopping tolerance is ignored for _stopping_rounds=0.");
            }
        } else if (this._parms._stopping_rounds < 0) {
            error("_stopping_rounds", "Stopping rounds must be >= 0.");
        } else if (isClassifier()) {
            if (this._parms._stopping_metric == ScoreKeeper.StoppingMetric.deviance) {
                error("_stopping_metric", "Stopping metric cannot be deviance for classification.");
            }
            if (nclasses() != 2 && this._parms._stopping_metric == ScoreKeeper.StoppingMetric.AUC) {
                error("_stopping_metric", "Stopping metric cannot be AUC for multinomial classification.");
            }
        } else if (this._parms._stopping_metric == ScoreKeeper.StoppingMetric.misclassification || this._parms._stopping_metric == ScoreKeeper.StoppingMetric.AUC || this._parms._stopping_metric == ScoreKeeper.StoppingMetric.logloss) {
            error("_stopping_metric", "Stopping metric cannot be " + this._parms._stopping_metric.toString() + " for regression.");
        }
        if (this._parms._max_runtime_secs < 0.0d) {
            error("_max_runtime_secs", "Max runtime (in seconds) must be greater than 0 (or 0 for unlimited).");
        }
    }

    /**
     * Returns {@code frame} repartitioned into {@link #desiredChunks(Frame, boolean)} chunks, or the frame
     * itself when no repartitioning is needed.
     *
     * <p>A rebalanced copy is stored in the DKV under a user-hidden key derived from {@code str} and is
     * registered with {@code Scope.track}.
     *
     * @param frame the frame to rebalance; may be {@code null}
     * @param z     forwarded to {@link #desiredChunks(Frame, boolean)}
     * @param str   non-null base name for the rebalanced frame's key. Its last five characters (e.g.
     *              "train" / "valid") label the log messages.
     * @return {@code null} for a {@code null} frame. The input frame itself when it has no vec to inspect
     *         (no columns) or already has at least the desired number of chunks. Otherwise the rebalanced copy.
     */
    protected Frame rebalance(Frame frame, boolean z, String str) {
        if (frame == null) {
            return null;
        }
        // A frame whose columns were all removed (e.g. every column ignored) has no vec to inspect.
        // Previously this dereferenced a null anyVec(). Hand the frame back so the caller's
        // "no usable columns" validation reports the problem instead.
        Vec any = frame.anyVec();
        if (any == null) {
            return frame;
        }
        int desiredChunks = desiredChunks(frame, z);
        int currentChunks = any.nChunks();
        // Label such as "train"/"valid"; clamp so a short name cannot throw.
        String label = str.substring(Math.max(0, str.length() - 5));
        if (currentChunks >= desiredChunks) {
            if (desiredChunks > 1) {
                Log.info(label + " dataset already contains " + currentChunks + " chunks. No need to rebalance.");
            }
            return frame;
        }
        Log.info("Rebalancing " + label + " dataset into " + desiredChunks + " chunks.");
        Key hiddenKey = Key.makeUserHidden(str + ".chunks" + desiredChunks);
        ((RebalanceDataSet) H2O.submitTask(new RebalanceDataSet(frame, hiddenKey, desiredChunks))).join();
        Frame rebalanced = (Frame) DKV.get(hiddenKey).get();
        Scope.track(rebalanced);
        return rebalanced;
    }

    /**
     * Target chunk count used by {@link #rebalance(Frame, boolean, String)}: one chunk per thousand rows
     * (rounded up), capped at {@code H2O.NUMCPUS}.
     *
     * @param frame the frame being sized
     * @param z     unused here; available to overriding implementations
     * @return the desired number of chunks
     */
    protected int desiredChunks(Frame frame, boolean z) {
        final double rowsPerChunk = 1000.0d;
        int byRows = (int) Math.ceil(frame.numRows() / rowsPerChunk);
        return Math.min(byRows, H2O.NUMCPUS);
    }

    /**
     * Checks the response (and distribution-specific parameters) against the configured distribution,
     * recording problems via {@link #error(String, String)}:
     * <ul>
     *   <li>poisson, gamma, tweedie: response minimum must be non-negative;</li>
     *   <li>tweedie: power must lie strictly between 1 and 2;</li>
     *   <li>quantile: alpha must lie in [0, 1].</li>
     * </ul>
     * Other distributions are not checked. Reads {@code _response} for the first three families only.
     */
    public void checkDistributions() {
        final Distribution.Family family = _parms._distribution;
        if (family == Distribution.Family.poisson) {
            if (_response.min() < 0.0d) {
                error("_response", "Response must be non-negative for Poisson distribution.");
            }
        } else if (family == Distribution.Family.gamma) {
            if (_response.min() < 0.0d) {
                error("_response", "Response must be non-negative for Gamma distribution.");
            }
        } else if (family == Distribution.Family.tweedie) {
            final double power = _parms._tweedie_power;
            // Kept as an explicit "outside the range" test so NaN behaves exactly as before (no error).
            if (power >= 2.0d || power <= 1.0d) {
                error("_tweedie_power", "Tweedie power must be between 1 and 2.");
            }
            if (_response.min() < 0.0d) {
                error("_response", "Response must be non-negative for Tweedie distribution.");
            }
        } else if (family == Distribution.Family.quantile) {
            final double alpha = _parms._quantile_alpha;
            if (alpha > 1.0d || alpha < 0.0d) {
                error("_quantile_alpha", "Quantile (alpha) must be between 0 and 1.");
            }
        }
    }

    /**
     * Combines the per-fold holdout prediction frames into a single frame stored under {@code key}.
     *
     * <p>The vecs of all folds are laid side by side (fold 0's columns, then fold 1's, ...) and handed to a
     * {@code HoldoutPredictionCombiner} sized by the fold and column counts. Column types, names and
     * domains are taken from the first fold's frame.
     *
     * @param keyArr keys of the per-fold prediction frames; must be non-empty and all frames must have the
     *               same number of columns
     * @param key    key for the combined output frame
     * @return the combined frame
     * @throws IllegalArgumentException if {@code keyArr} is null/empty or the folds disagree on column count
     */
    private static Frame combineHoldoutPredictions(Key<Frame>[] keyArr, Key key) {
        if (keyArr == null || keyArr.length == 0) {
            throw new IllegalArgumentException("No holdout prediction frames to combine.");
        }
        int length = keyArr.length;
        Frame first = keyArr[0].get();
        int ncols = first.numCols();
        Vec[] vecArr = new Vec[length * ncols];
        int next = 0;
        for (int fold = 0; fold < length; fold++) {
            // One DKV lookup per fold (the original re-fetched the frame for every column and loop test).
            Frame preds = keyArr[fold].get();
            if (preds.numCols() != ncols) {
                throw new IllegalArgumentException("Holdout predictions of fold " + fold + " have " + preds.numCols()
                        + " columns, expected " + ncols + ".");
            }
            for (int c = 0; c < ncols; c++) {
                vecArr[next++] = preds.vec(c);
            }
        }
        return new HoldoutPredictionCombiner(length, ncols).doAll(first.types(), new Frame(vecArr)).outputFrame(key, first.names(), first.domains());
    }

    /**
     * Builds the "Cross-Validation Metrics Summary" table from the validation metrics of the CV models.
     *
     * <p>Rows are discovered reflectively from the first model that has validation metrics. Every public
     * zero-argument method of its {@code ModelMetrics} (and of its {@code ConfusionMatrix}, if any) that
     * returns a {@code double}/{@code Double} becomes a row, except the names in the exclusion set. Rows are
     * sorted by method name.
     *
     * <p>Columns:
     * <ul>
     *   <li>column 0 holds the mean and column 1 the sigma (from {@code MathUtils.BasicStats});</li>
     *   <li>column {@code k + 2} holds the values for {@code keyArr[k]};</li>
     *   <li>models that are missing or lack validation metrics leave their column empty and are excluded
     *       from the mean/sd.</li>
     * </ul>
     *
     * @param keyArr keys of the cross-validation models; may be {@code null} or empty
     * @return the summary table, or {@code null} when there are no keys or no model has validation metrics
     */
    private TwoDimTable makeCrossValidationSummaryTable(Key[] keyArr) {
        if (keyArr == null || keyArr.length == 0) {
            return null;
        }
        int length = keyArr.length;
        String[] colTypes = new String[length + 2];
        Arrays.fill(colTypes, "string");
        String[] colFormats = new String[length + 2];
        Arrays.fill(colFormats, "%s");
        String[] colNames = new String[length + 2];
        colNames[0] = "mean";
        colNames[1] = "sd";
        for (int i = 0; i < length; i++) {
            colNames[i + 2] = "cv_" + (i + 1) + "_valid";
        }
        // Zero-argument methods that are never reported.
        HashSet<String> excluded = new HashSet<>(Arrays.asList(
                "total_rows", "makeSchema", "hr", "frame", "remove", "cm", "auc_obj"));

        // Template metrics: the first model that actually has validation metrics. The original code
        // dereferenced keyArr[0]'s metrics (and called cm() on them) before any null check.
        ModelMetrics template = null;
        for (Key key : keyArr) {
            Model model = (Model) DKV.getGet(key);
            if (model != null && model._output._validation_metrics != null) {
                template = model._output._validation_metrics;
                break;
            }
        }
        if (template == null) {
            return null;
        }

        // First method found per name wins: ModelMetrics methods first, then ConfusionMatrix methods.
        HashMap<String, Method> byName = new HashMap<>();
        collectDoubleGetters(template, excluded, byName);
        ConfusionMatrix templateCm = template.cm();
        if (templateCm != null) {
            collectDoubleGetters(templateCm, excluded, byName);
        }
        TreeSet<String> rowNames = new TreeSet<>(byName.keySet());
        ArrayList<Method> rows = new ArrayList<>(rowNames.size());
        for (String name : rowNames) {
            rows.add(byName.get(name));
        }

        int size = rows.size();
        TwoDimTable twoDimTable = new TwoDimTable("Cross-Validation Metrics Summary", null, rowNames.toArray(new String[0]), colNames, colTypes, colFormats, "");
        double[][] values = new double[length][size];
        boolean[] present = new boolean[length];
        for (int k = 0; k < length; k++) {
            Model model = (Model) DKV.getGet(keyArr[k]);
            if (model == null || model._output._validation_metrics == null) {
                continue; // leave cv_(k+1)_valid empty and keep it out of the statistics
            }
            ModelMetrics metrics = model._output._validation_metrics;
            ConfusionMatrix cm = metrics.cm();
            present[k] = true;
            // Fixed row index per metric, so a metric that fails for one model cannot shift later rows.
            for (int r = 0; r < size; r++) {
                Method method = rows.get(r);
                Double value = tryInvokeDouble(method, metrics);
                if (value == null && cm != null) {
                    value = tryInvokeDouble(method, cm);
                }
                if (value != null) {
                    values[k][r] = value;
                    twoDimTable.set(r, k + 2, Float.valueOf(value.floatValue()));
                }
            }
        }

        MathUtils.BasicStats basicStats = new MathUtils.BasicStats(size);
        for (int k = 0; k < length; k++) {
            if (present[k]) {
                basicStats.add(values[k], 1.0d);
            }
        }
        double[] mean = basicStats.mean();
        double[] sigma = basicStats.sigma();
        for (int r = 0; r < size; r++) {
            twoDimTable.set(r, 0, Float.valueOf((float) mean[r]));
            twoDimTable.set(r, 1, Float.valueOf((float) sigma[r]));
        }
        Log.info(twoDimTable);
        return twoDimTable;
    }

    /**
     * Records into {@code out} every public zero-argument method of {@code target} that is declared to
     * return {@code double}/{@code Double} and actually yields a {@code Double}. Names in {@code excluded}
     * and names already present in {@code out} are skipped.
     */
    private static void collectDoubleGetters(Object target, HashSet<String> excluded, HashMap<String, Method> out) {
        for (Method method : target.getClass().getMethods()) {
            String name = method.getName();
            if (excluded.contains(name) || out.containsKey(name)) {
                continue;
            }
            // Only probe double-returning getters rather than invoking every public method.
            Class<?> returnType = method.getReturnType();
            if (method.getParameterTypes().length != 0 || (returnType != double.class && returnType != Double.class)) {
                continue;
            }
            if (tryInvokeDouble(method, target) != null) {
                out.put(name, method);
            }
        }
    }

    /** Invokes zero-argument {@code method} on {@code target}; returns its {@code Double} result, or {@code null} if not applicable. */
    private static Double tryInvokeDouble(Method method, Object target) {
        try {
            Object result = method.invoke(target);
            return result instanceof Double ? (Double) result : null;
        } catch (ReflectiveOperationException | RuntimeException e) {
            // Wrong declaring class for this target, or the getter itself failed: treat the value as missing.
            return null;
        }
    }

    static {
        // Must be assigned first; assertion checks throughout the class read this flag.
        $assertionsDisabled = !ModelBuilder.class.desiredAssertionStatus();

        // Builder registries start out empty.
        ALGOBASES = new String[0];
        SCHEMAS = new String[0];
        BUILDERS = new ModelBuilder[0];

        // Lookup tables between algorithm names, builder classes and model classes, initially empty.
        _builders = new HashMap<>();
        _model_class_to_algo = new HashMap<>();
        _algo_to_algo_full_name = new HashMap<>();
        _algo_to_model_class = new HashMap<>();
    }
}
