package hex;

import hex.Model;
import hex.Model.Output;
import hex.Model.Parameters;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Lockable;
import water.MRTask;
import water.Weaver;
import water.api.ModelSchema;
import water.fvec.Chunk;
import water.fvec.EnumWrappedVec;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

/* loaded from: input_file:hex/Model.class */
public abstract class Model<M extends Model<M, P, O>, P extends Parameters, O extends Output> extends Lockable<M> {
    public P _parms;
    public String[] _warnings;
    public O _output;
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Distributed scoring pass. For every row of every chunk it calls
     * {@link Model#score0(Chunk[], int, double[], float[])}, appends the first
     * {@code _npredcols} values of the result to the output chunks, and feeds the row
     * into a {@link ModelMetrics.MetricBuilder} that is merged in {@link #reduce} and
     * finalized in {@link #postGlobal}.
     */
    public class BigScore extends MRTask<Model<M, P, O>.BigScore> {
        /** Domain handed to {@link Model#makeMetricBuilder(String[])} (scoreImpl passes the response domain, or null). */
        final String[] _domain;
        /** Number of prediction columns written per row. */
        final int _npredcols;
        /** Metric accumulator, created in {@link #map} and merged in {@link #reduce}. */
        ModelMetrics.MetricBuilder _mb;

        BigScore(String[] domain, int npredcols) {
            _domain = domain;
            _npredcols = npredcols;
        }

        @Override // water.MRTask
        public void map(Chunk[] chks, NewChunk[] preds) {
            double[] rowBuf = new double[Model.this._output.nfeatures()];
            _mb = Model.this.makeMetricBuilder(_domain);
            // A supervised metric builder is only given the last column as the "actual";
            // any other builder is given every column.
            int firstActual = (_mb instanceof ModelMetricsSupervised.MetricBuilderSupervised) ? chks.length - 1 : 0;
            int nActuals = chks.length - firstActual;
            float[] work = _mb._work;
            int nrows = chks[0]._len;
            for (int row = 0; row < nrows; row++) {
                float[] scored = Model.this.score0(chks, row, rowBuf, work);
                float[] actuals = new float[nActuals];
                for (int c = 0; c < nActuals; c++) {
                    actuals[c] = (float) chks[firstActual + c].atd(row);
                }
                _mb.perRow(work, actuals, Model.this);
                for (int c = 0; c < _npredcols; c++) {
                    preds[c].addNum(scored[c]);
                }
            }
        }

        @Override // water.MRTask
        public void reduce(Model<M, P, O>.BigScore other) {
            _mb.reduce(other._mb);
        }

        @Override // water.MRTask
        protected void postGlobal() {
            _mb.postGlobal();
        }
    }

    /**
     * Broad kind of problem a model solves.
     *
     * <p>Declaration order matters: {@code Output.checksum_impl} multiplies by
     * {@link #ordinal()}, so the constants must not be reordered.
     */
    public enum ModelCategory {
        Unknown, Binomial, Multinomial, Regression, Clustering, AutoEncoder, DimReduction
    }

    /**
     * Results of model training shared by all models. Holds the training frame's column
     * names and categorical domains, job state, timing, and the keys of the
     * {@link ModelMetrics} registered for the model.
     */
    public static abstract class Output extends Iced {
        /** Training column names; {@link #responseName()} treats the last one as the response. */
        public String[] _names;
        /** Per-column categorical domains; a null entry marks a non-categorical column. */
        public String[][] _domains;
        public Job.JobState _state;
        /** Keys registered through {@link #addModelMetrics}; duplicates (by value) are not added. */
        public Key[] _model_metrics = new Key[0];
        public long _training_start_time = 0;
        public long _training_duration_in_ms = 0;

        /** @return the number of training columns, i.e. the length of {@link #_names}. */
        public int nfeatures() {
            return this._names.length;
        }

        /**
         * Captures column names and domains from the builder's training frame.
         *
         * @param modelBuilder builder whose {@code _train} frame supplies names and domains
         * @throws IllegalArgumentException if the builder reports validation errors
         */
        public Output(ModelBuilder modelBuilder) {
            if (modelBuilder.error_count() > 0) {
                throw new IllegalArgumentException(modelBuilder.validationErrors());
            }
            this._names = modelBuilder._train.names();
            this._domains = modelBuilder._train.domains();
        }

        /** @return all training column names (the backing array, not a copy). */
        public String[] allNames() {
            return this._names;
        }

        /** @return the name of the last training column. */
        public String responseName() {
            return this._names[this._names.length - 1];
        }

        /** @return the domain of the last training column, or null if it is not categorical. */
        public String[] classNames() {
            return this._domains[this._domains.length - 1];
        }

        /** @return true when the last column has a categorical domain. */
        public boolean isClassifier() {
            return classNames() != null;
        }

        /** @return the number of class labels, or 1 for a non-categorical response. */
        public int nclasses() {
            String[] classNames = classNames();
            return classNames == null ? 1 : classNames.length;
        }

        /** @return Regression for a non-categorical response, Binomial for at most 2 classes, else Multinomial. */
        public ModelCategory getModelCategory() {
            if (!isClassifier()) {
                return ModelCategory.Regression;
            }
            if (nclasses() > 2) {
                return ModelCategory.Multinomial;
            }
            return ModelCategory.Binomial;
        }

        /**
         * Records the metrics' key unless an equal key is already recorded.
         *
         * @param modelMetrics metrics whose {@code _key} is recorded
         * @return {@code modelMetrics}, to allow call chaining
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public synchronized ModelMetrics addModelMetrics(ModelMetrics modelMetrics) {
            Key key = modelMetrics._key;
            for (Key existing : this._model_metrics) {
                // Compare by value: '==' only matched when both sides were the very same
                // Key instance, letting equal-but-distinct keys be registered twice.
                if (key == null ? existing == null : key.equals(existing)) {
                    return modelMetrics;
                }
            }
            this._model_metrics = Arrays.copyOf(this._model_metrics, this._model_metrics.length + 1);
            this._model_metrics[this._model_metrics.length - 1] = key;
            return modelMetrics;
        }

        /**
         * Checksum over names, domains and model category.
         *
         * <p>NOTE(review): the product is computed in int arithmetic before widening to
         * long, and {@code ModelCategory.Unknown} has ordinal 0, which forces the result
         * to 0 for such models. Left unchanged so existing checksum values stay stable.
         */
        long checksum_impl() {
            int namesHash = null == this._names ? 13 : Arrays.hashCode(this._names);
            int domainsHash = null == this._domains ? 17 : Arrays.deepHashCode(this._domains);
            return namesHash * domainsHash * getModelCategory().ordinal();
        }
    }

    /**
     * Parameters common to all model builders: frame keys, ignored columns and a few
     * shared switches. {@link #checksum_impl()} hashes every woven field reflectively.
     */
    public static abstract class Parameters extends Iced {
        public Key<Frame> _destination_key;
        /** Key of the training frame. */
        public Key<Frame> _train;
        /** Key of the validation frame, or null when there is none. */
        public Key<Frame> _valid;
        public String[] _ignored_columns;
        public boolean _score_each_iteration;
        /** {@code Model.score} logs a confusion matrix only if its {@code confusion_matrix} length is below this. */
        public int _max_confusion_matrix_size = 20;
        public boolean _dropNA20Cols = defaultDropNA20Cols();
        public boolean _dropConsCols = defaultDropConsCols();

        /** @return the training frame fetched through {@link #_train} (result of {@code get()}). */
        public final Frame train() {
            return this._train.get();
        }

        /** @return the validation frame, or null when {@link #_valid} is null. */
        public final Frame valid() {
            if (this._valid == null) {
                return null;
            }
            return this._valid.get();
        }

        /** Read-locks the training frame and, if distinct, the validation frame for {@code job}. */
        public void read_lock_frames(Job job) {
            train().read_lock(job._key);
            if (this._valid == null || this._train.equals(this._valid)) {
                return;
            }
            valid().read_lock(job._key);
        }

        /**
         * Releases the locks taken by {@link #read_lock_frames}. A frame whose key no
         * longer resolves (its {@code get()} returns null) is skipped.
         */
        public void read_unlock_frames(Job job) {
            Frame train = train();
            if (train != null) {
                train.unlock(job._key);
            }
            if (this._valid == null || this._train.equals(this._valid)) {
                return;
            }
            // Mirror the training-frame handling instead of throwing NPE during cleanup.
            Frame valid = valid();
            if (valid != null) {
                valid.unlock(job._key);
            }
        }

        protected boolean defaultDropNA20Cols() {
            return false;
        }

        protected boolean defaultDropConsCols() {
            return true;
        }

        /** @return the fill value used for training columns missing from a test frame. */
        /* JADX INFO: Access modifiers changed from: protected */
        public double missingColumnsType() {
            return Double.NaN;
        }

        /**
         * Order-independent checksum of every woven field plus the train/validation
         * frame checksums.
         *
         * <p>Fields are sorted by name. Each one is XOR-mixed with a prime chosen by its
         * position. Arrays use a value-based hash for every primitive component type
         * (boolean/byte/short/char were previously rejected with {@code H2O.unimpl()}).
         * A missing train or validation frame contributes a constant instead of throwing.
         */
        protected long checksum_impl() {
            long xs = 24589;
            int position = 0;
            Field[] fields = Weaver.getWovenFields(getClass());
            // Sort by name so the result does not depend on reflection order.
            Arrays.sort(fields, new Comparator<Field>() {
                @Override
                public int compare(Field a, Field b) {
                    return a.getName().compareTo(b.getName());
                }
            });
            for (Field field : fields) {
                long prime = MathUtils.PRIMES[position % MathUtils.PRIMES.length];
                Object value;
                try {
                    field.setAccessible(true);
                    value = field.get(this);
                } catch (IllegalAccessException e) {
                    throw new RuntimeException(e);
                }
                if (field.getType().isArray()) {
                    xs ^= value == null ? 912559 + prime : 912559 + prime * arrayHashCode(value);
                } else {
                    xs ^= value == null ? 4919 + prime : 4919 + prime * value.hashCode();
                }
                position++;
            }
            Frame train = this._train == null ? null : train();
            Frame valid = valid();
            long trainSum = train == null ? 43L : train.checksum();
            long validSum = valid == null ? 17L : valid.checksum();
            return xs ^ (trainSum * validSum);
        }

        /** Value-based hash of any array, dispatching to the matching {@link Arrays} overload. */
        private static int arrayHashCode(Object array) {
            if (array instanceof int[]) {
                return Arrays.hashCode((int[]) array);
            }
            if (array instanceof long[]) {
                return Arrays.hashCode((long[]) array);
            }
            if (array instanceof float[]) {
                return Arrays.hashCode((float[]) array);
            }
            if (array instanceof double[]) {
                return Arrays.hashCode((double[]) array);
            }
            if (array instanceof boolean[]) {
                return Arrays.hashCode((boolean[]) array);
            }
            if (array instanceof byte[]) {
                return Arrays.hashCode((byte[]) array);
            }
            if (array instanceof short[]) {
                return Arrays.hashCode((short[]) array);
            }
            if (array instanceof char[]) {
                return Arrays.hashCode((char[]) array);
            }
            return Arrays.deepHashCode((Object[]) array);
        }
    }

    /** @return whether this model is supervised; this base implementation always answers false. */
    public boolean isSupervised() {
        return false;
    }

    /** Appends {@code str} to {@link #_warnings}, growing the array by one slot. */
    public void addWarning(String str) {
        int count = _warnings.length;
        String[] grown = Arrays.copyOf(_warnings, count + 1);
        grown[count] = str;
        _warnings = grown;
    }

    /** Registers {@code modelMetrics} with this model's output; returns the same object. */
    public ModelMetrics addMetrics(ModelMetrics modelMetrics) {
        O output = _output;
        return output.addModelMetrics(modelMetrics);
    }

    /** Creates the metric accumulator used by {@link BigScore} for the given domain. */
    public abstract ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr);

    /** @return the API schema for this model. */
    public abstract ModelSchema schema();

    /**
     * Creates a model.
     *
     * @param key key passed to the {@code Lockable} superclass
     * @param p   training parameters; must be non-null (checked only when assertions are enabled)
     * @param o   training output
     */
    public Model(Key key, P p, O o) {
        super(key);
        _parms = p;
        _output = o;
        _warnings = new String[0];
        if (!$assertionsDisabled && p == null) {
            throw new AssertionError();
        }
    }

    /**
     * Adapts {@code frame} in place to this model's training columns and domains, using
     * {@code _parms.missingColumnsType()} as the fill value for missing columns.
     *
     * @return warning messages produced during adaptation
     */
    public String[] adaptTestForTrain(Frame frame, boolean z) {
        O out = _output;
        double fill = _parms.missingColumnsType();
        return adaptTestForTrain(out._names, out._domains, frame, fill, z);
    }

    /**
     * Adapts a test/validation frame, in place, to a training column layout.
     *
     * <p>For each training column:
     * <ul>
     *   <li>A missing column is reported. When {@code z} is true it is replaced by a constant column of
     *       {@code d} carrying the training domain.</li>
     *   <li>A categorical column whose domain already equals the training domain is kept as is.</li>
     *   <li>A categorical column with any other domain is remapped via {@code Vec.adaptTo}. The remapped
     *       Vec is kept when {@code z} is true and removed immediately otherwise.</li>
     *   <li>A categorical test column for a real-valued training column is rejected.</li>
     * </ul>
     * The frame is restructured only if every training column was matched.
     *
     * @param strArr  training column names
     * @param strArr2 training column domains (null entries for non-categorical columns)
     * @param frame   frame to adapt in place; may be null
     * @param d       fill value for missing columns
     * @param z       true to build replacement/remapped Vecs; false to only validate
     * @return warning messages (never null)
     * @throws IllegalArgumentException if no training column matched, or a categorical
     *         test column maps onto a real-valued training column
     */
    public static String[] adaptTestForTrain(String[] strArr, String[][] strArr2, Frame frame, double d, boolean z) throws IllegalArgumentException {
        if (frame == null) {
            return new String[0];
        }
        String[][] testDomains = frame.domains();
        // Fast paths: the layout is already the training layout.
        if (strArr == frame._names && strArr2 == testDomains) {
            return new String[0];
        }
        if (Arrays.equals(strArr, frame._names) && Arrays.deepEquals(strArr2, testDomains)) {
            return new String[0];
        }
        ArrayList<String> msgs = new ArrayList<String>();
        Vec[] vecs = new Vec[strArr.length];
        int good = 0; // number of training columns matched
        for (int i = 0; i < strArr.length; i++) {
            String name = strArr[i];
            String[] trainDomain = strArr2[i];
            Vec vec = frame.vec(name);
            if (vec == null) {
                msgs.add("Validation set is missing training column " + name);
                if (z) {
                    vec = frame.anyVec().makeCon(d);
                    vec.setDomain(trainDomain);
                }
            }
            if (vec != null) {
                if (trainDomain != null) {
                    String[] testDomain = vec.domain();
                    if (testDomain == trainDomain || Arrays.equals(testDomain, trainDomain)) {
                        // Already identical. Previously this case was wrapped needlessly, and when
                        // !z it was never counted, so an all-categorical frame was rejected.
                        good++;
                    } else {
                        EnumWrappedVec adapted = vec.adaptTo(trainDomain);
                        String[] adaptedDomain = adapted.domain();
                        if (!$assertionsDisabled && (adaptedDomain == null || adaptedDomain.length < trainDomain.length)) {
                            throw new AssertionError();
                        }
                        if (adaptedDomain.length > trainDomain.length) {
                            msgs.add("Validation column " + name + " has levels not trained on: " + Arrays.toString(Arrays.copyOfRange(adaptedDomain, trainDomain.length, adaptedDomain.length)));
                        }
                        if (z) {
                            vec = adapted;
                            good++;
                        } else {
                            adapted.remove(); // do not leak the remapped Vec when only validating
                            vec = null;
                        }
                    }
                } else {
                    if (vec.isEnum()) {
                        throw new IllegalArgumentException("Validation set has categorical column " + name + " which is real-valued in the training data");
                    }
                    good++;
                }
            }
            vecs[i] = vec;
        }
        if (good == 0) {
            throw new IllegalArgumentException("Validation set has no columns in common with the training set");
        }
        if (good == strArr.length) {
            frame.restructure(strArr, vecs);
        }
        return msgs.toArray(new String[msgs.size()]);
    }

    /**
     * Scores {@code frame}, writing predictions to a frame under a fresh random key.
     *
     * @see #score(Frame, String)
     */
    public Frame score(Frame frame) throws IllegalArgumentException {
        return score(frame, null);
    }

    /**
     * Scores {@code frame} with this model.
     *
     * <p>A temporary copy of {@code frame} is adapted to the training layout and scored.
     * The temporary is deleted afterwards, including when scoring fails. For classifiers,
     * a small confusion matrix is logged. When the frame's response column uses a
     * different domain than the predictions, the prediction column is re-expressed in
     * that domain.
     *
     * @param frame frame to score; not modified
     * @param str   destination key name for the prediction frame, or null for a random key
     * @return the prediction frame, already put into the DKV by {@link #scoreImpl}
     * @throws IllegalArgumentException if {@code frame} cannot be adapted
     */
    public Frame score(Frame frame, String str) throws IllegalArgumentException {
        Frame adapted = new Frame(frame);
        // Response column of the caller's frame; null for regression or when the frame has no response.
        Vec actual = this._output.isClassifier() ? frame.vec(this._output.responseName()) : null;
        try {
            adaptTestForTrain(adapted, true);
            Frame output = scoreImpl(frame, adapted, str);
            Vec predicted = output.vecs()[0];
            String[] predDomain = predicted.domain();
            if (this._output.isClassifier()) {
                if (!$assertionsDisabled && predDomain == null) {
                    throw new AssertionError();
                }
                logConfusionMatrix(frame, predDomain);
                relabelPredictions(output, actual, predicted, predDomain);
            }
            return output;
        } finally {
            deleteAdaptedFrame(frame, adapted);
        }
    }

    /** Logs the confusion matrix stored for (this, frame) if it exists and is below the configured size. */
    private void logConfusionMatrix(Frame frame, String[] predDomain) {
        ModelMetrics metrics = ModelMetrics.getFromDKV(this, frame);
        if (metrics == null) {
            return;
        }
        ModelCategory category = this._output.getModelCategory();
        ConfusionMatrix cm = metrics.cm();
        if (category == ModelCategory.Binomial) {
            cm = ((ModelMetricsBinomial) metrics)._cm;
        } else if (category == ModelCategory.Multinomial) {
            cm = ((ModelMetricsMultinomial) metrics)._cm;
        }
        if (cm == null || cm.domain == null) {
            return;
        }
        if (!$assertionsDisabled && !Arrays.deepEquals(cm.domain, predDomain)) {
            throw new AssertionError();
        }
        cm.table = cm.toTable();
        if (cm.confusion_matrix.length < this._parms._max_confusion_matrix_size) {
            Log.info(cm.table.toString(1));
        }
    }

    /**
     * Replaces column 0 of {@code output} with an EnumWrappedVec over {@code predicted}
     * that uses the response column's domain, when that domain differs from
     * {@code predDomain}. Does nothing when the scored frame had no response column.
     */
    private static void relabelPredictions(Frame output, Vec actual, Vec predicted, String[] predDomain) {
        if (actual == null) {
            return;
        }
        String[] actualDomain = actual.domain();
        if (actualDomain != null && predDomain != actualDomain && !Arrays.equals(predDomain, actualDomain)) {
            output.replace(0, new EnumWrappedVec(actual.group().addVec(), actual.get_espc(), actualDomain, predicted._key));
        }
    }

    /**
     * Deletes the temporary adapted frame. Vecs it shares with {@code frame} are nulled
     * out first so they survive. This assumes {@code vecs()} returns the frame's backing
     * array, as the original cleanup did.
     */
    private static void deleteAdaptedFrame(Frame frame, Frame adapted) {
        Vec[] vecs = adapted.vecs();
        for (int i = 0; i < vecs.length; i++) {
            if (frame.find(vecs[i]) != -1) {
                vecs[i] = null;
            }
        }
        adapted.delete();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v11, types: [java.lang.String[], java.lang.String[][]] */
    protected Frame scoreImpl(Frame frame, Frame frame2, String str) {
        if (!$assertionsDisabled && !Arrays.equals(this._output._names, frame2._names)) {
            throw new AssertionError();
        }
        int nclasses = this._output.nclasses();
        int i = nclasses == 1 ? 1 : nclasses + 1;
        String[] strArr = new String[i];
        ?? r0 = new String[i];
        strArr[0] = "predict";
        for (int i2 = 1; i2 < strArr.length; i2++) {
            strArr[i2] = this._output.classNames()[i2 - 1];
        }
        r0[0] = nclasses == 1 ? null : frame2.lastVec().domain();
        Model<M, P, O>.BigScore doAll = new BigScore(r0[0], i).doAll(i, frame2);
        doAll._mb.makeModelMetrics(this, frame, this instanceof SupervisedModel ? frame2.lastVec().sigma() : Double.NaN);
        Frame outputFrame = doAll.outputFrame(null == str ? Key.make() : Key.make(str), strArr, r0);
        DKV.put(outputFrame);
        return outputFrame;
    }

    /**
     * Copies row {@code i} of the first {@code _output._names.length} chunks into
     * {@code dArr}, then scores it with {@link #score0(double[], float[])}.
     */
    public float[] score0(Chunk[] chunkArr, int i, double[] dArr, float[] fArr) {
        if (!$assertionsDisabled && chunkArr.length < this._output._names.length) {
            throw new AssertionError();
        }
        for (int i2 = 0; i2 < this._output._names.length; i2++) {
            dArr[i2] = chunkArr[i2].atd(i);
        }
        return score0(dArr, fArr);
    }

    /** Scores one row of feature values, using {@code fArr} as a prediction buffer. */
    /* JADX INFO: Access modifiers changed from: protected */
    public abstract float[] score0(double[] dArr, float[] fArr);

    /**
     * Scores a single row.
     *
     * <p>The buffer width matches what {@link #scoreImpl} writes per row: one "predict"
     * slot plus one per class for classifiers. The previous {@code nclasses}-wide buffer
     * was one slot short for classifiers. The previous {@code maxIndex} result was
     * always 0 for regression, instead of the prediction.
     *
     * @return slot 0 of the prediction array, the value written to the "predict" column
     */
    public double score(double[] dArr) {
        int nclasses = this._output.nclasses();
        float[] preds = score0(dArr, new float[nclasses == 1 ? 1 : nclasses + 1]);
        return preds[0];
    }

    /** Removes every metrics key recorded in {@code _output._model_metrics}, if any. */
    @Override // water.Keyed
    protected Futures remove_impl(Futures futures) {
        Key[] metricKeys = _output._model_metrics;
        if (metricKeys == null) {
            return futures;
        }
        for (int idx = 0; idx < metricKeys.length; idx++) {
            metricKeys[idx].remove(futures);
        }
        return futures;
    }

    /** Product of the parameter checksum and the output checksum. */
    @Override // water.Keyed
    protected long checksum_impl() {
        long parmsSum = _parms.checksum_impl();
        long outputSum = _output.checksum_impl();
        return parmsSum * outputSum;
    }

    static {
        // Explicit form of the flag javac would otherwise synthesize for 'assert'.
        boolean enabled = Model.class.desiredAssertionStatus();
        $assertionsDisabled = !enabled;
    }
}
