package hex.naivebayes;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.genmodel.GenModel;
import hex.schemas.NaiveBayesModelV3;
import water.H2O;
import water.Key;
import water.api.ModelSchema;
import water.util.JCodeGen;
import water.util.SB;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/naivebayes/NaiveBayesModel.class */
/**
 * Naive Bayes classification model.
 *
 * <p>Scoring combines, in log space, the prior probability of each response class with the
 * conditional probability of every non-missing predictor value given that class, then
 * normalizes the result into per-class probabilities. The first {@code _ncats} predictor
 * columns are handled as categorical (table lookup); the remaining columns are handled as
 * numeric using a Gaussian density built from a stored mean and standard deviation.
 *
 * <p>The class can also emit equivalent standalone Java scoring code via
 * {@link #toJavaInit(SB, SB)} and {@link #toJavaPredictBody(SB, SB, SB)}.
 */
public class NaiveBayesModel extends Model<NaiveBayesModel, NaiveBayesModel.NaiveBayesParameters, NaiveBayesModel.NaiveBayesOutput> {

    /** Fitted state of a Naive Bayes model. */
    public static class NaiveBayesOutput extends Model.Output {
        /** Table form of the class priors (presumably for display; not read by the scoring code in this class). */
        public TwoDimTable _apriori;

        /** Prior probability of each response class, indexed by response level. */
        public double[] _apriori_raw;

        /** Table form of the conditional probabilities (presumably for display; not read by the scoring code in this class). */
        public TwoDimTable[] _pcond;

        /**
         * Conditional statistics indexed as {@code [predictor][response level][k]}.
         * For categorical predictors {@code k} is the predictor level and the value is a
         * probability; for numeric predictors {@code k == 0} is read as the mean and
         * {@code k == 1} as the standard deviation.
         */
        public double[][][] _pcond_raw;

        /** Per-response-level count used as the denominator base in Laplace smoothing. */
        public int[] _rescnt;

        /** Response levels; the array length is the number of classes scored. */
        public String[] _levels;

        /** Number of leading predictor columns that are treated as categorical. */
        public int _ncats;

        public NaiveBayesOutput(NaiveBayes naiveBayes) {
            super(naiveBayes);
        }
    }

    /** Tunable parameters of a Naive Bayes model. */
    public static class NaiveBayesParameters extends Model.Parameters {
        /** Laplace smoothing constant applied when a categorical level has no stored probability. */
        public double _laplace = 0.0d;

        /** Standard deviations at or below this value are replaced by {@link #_min_sdev}. */
        public double _eps_sdev = 0.0d;

        /** Replacement standard deviation used when the stored one is at or below {@link #_eps_sdev}. */
        public double _min_sdev = 0.001d;

        /** Probabilities at or below this value are replaced by {@link #_min_prob}. */
        public double _eps_prob = 0.0d;

        /** Replacement probability used when the computed one is at or below {@link #_eps_prob}. */
        public double _min_prob = 0.001d;

        /** Not read in this class; presumably controls metric computation during training. */
        public boolean _compute_metrics = true;
    }

    public NaiveBayesModel(Key key, NaiveBayesParameters naiveBayesParameters, NaiveBayesOutput naiveBayesOutput) {
        super(key, naiveBayesParameters, naiveBayesOutput);
    }

    /** Returns a fresh V3 schema object for this model type. */
    public ModelSchema schema() {
        return new NaiveBayesModelV3();
    }

    /**
     * Creates the metric builder matching this model's category.
     *
     * @param domain passed through to the builder; its length is used as the class count
     *               for multinomial models (presumably the response class labels)
     * @return a binomial or multinomial metric builder
     * @throws RuntimeException via {@code H2O.unimpl()} for any other model category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final NaiveBayesOutput out = (NaiveBayesOutput) this._output;
        switch (out.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length, domain);
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Scores a single row.
     *
     * @param data  predictor values; the first {@code _ncats} entries are categorical level
     *              indices, the rest are numeric values; {@code NaN} entries are skipped
     * @param preds output array of length {@code _levels.length + 1}; slot 0 receives the
     *              predicted class and slots {@code 1..n} the per-class probabilities
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds) {
        final NaiveBayesOutput out = (NaiveBayesOutput) this._output;
        final NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        final int nclasses = out._levels.length;
        final double[] logJoint = new double[nclasses];
        assert preds.length == nclasses + 1;

        for (int cls = 0; cls < nclasses; cls++) {
            // Work in log space so that the product of many small probabilities does not underflow.
            logJoint[cls] = Math.log(out._apriori_raw[cls]);

            // Categorical predictors: look up the stored conditional probability.
            for (int col = 0; col < out._ncats; col++) {
                if (Double.isNaN(data[col])) {
                    continue;
                }
                final int level = (int) data[col];
                final double[] levelProbs = out._pcond_raw[col][cls];
                final double prob;
                // Bound the lookup by the number of levels stored for THIS predictor and class
                // (not by the number of predictors); anything beyond falls back to the
                // Laplace-smoothed estimate.
                if (level < levelProbs.length) {
                    prob = levelProbs[level];
                } else {
                    prob = parms._laplace / (out._rescnt[cls] + (parms._laplace * out._domains[col].length));
                }
                logJoint[cls] += logWithFloor(prob);
            }

            // Numeric predictors: Gaussian density from the stored mean / standard deviation.
            for (int col = out._ncats; col < data.length; col++) {
                if (Double.isNaN(data[col])) {
                    continue;
                }
                final double[] stats = out._pcond_raw[col][cls];
                final double mean = Double.isNaN(stats[0]) ? 0.0d : stats[0];
                final double sdev;
                if (Double.isNaN(stats[1])) {
                    sdev = 1.0d;
                } else if (stats[1] <= parms._eps_sdev) {
                    sdev = parms._min_sdev;
                } else {
                    sdev = stats[1];
                }
                final double diff = data[col] - mean;
                final double density = Math.exp((-(diff * diff)) / ((2.0d * sdev) * sdev))
                        / (sdev * Math.sqrt(2.0d * Math.PI));
                logJoint[cls] += logWithFloor(density);
            }
        }

        // Normalize: p_i = 1 / sum_j exp(log_j - log_i).
        for (int i = 0; i < logJoint.length; i++) {
            double sum = 0.0d;
            for (double other : logJoint) {
                sum += Math.exp(other - logJoint[i]);
            }
            preds[i + 1] = 1.0d / sum;
        }
        preds[0] = GenModel.getPrediction(preds, out._priorClassDist, data, defaultThreshold());
        return preds;
    }

    /** Returns {@code log(prob)}, substituting {@code _min_prob} when {@code prob <= _eps_prob}. */
    private double logWithFloor(double prob) {
        final NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        return Math.log(prob <= parms._eps_prob ? parms._min_prob : prob);
    }

    /**
     * Emits the static members of the generated scoring class: the basic shape accessors and
     * the {@code RESCNT}, {@code APRIORI}, {@code PCOND} and {@code DOMLEN} tables.
     *
     * @param sb  forwarded to the superclass implementation
     * @param sb2 forwarded to the superclass implementation
     * @return the buffer returned by the superclass, with this model's members appended
     */
    protected SB toJavaInit(SB sb, SB sb2) {
        final NaiveBayesOutput out = (NaiveBayesOutput) this._output;
        final SB javaInit = super.toJavaInit(sb, sb2);
        javaInit.ip("public boolean isSupervised() { return " + isSupervised() + "; }").nl();
        javaInit.ip("public int nfeatures() { return " + out.nfeatures() + "; }").nl();
        javaInit.ip("public int nclasses() { return " + out.nclasses() + "; }").nl();
        JCodeGen.toStaticVar(javaInit, "RESCNT", out._rescnt, "Count of categorical levels in response.");
        JCodeGen.toStaticVar(javaInit, "APRIORI", out._apriori_raw, "Apriori class distribution of the response.");
        JCodeGen.toStaticVar(javaInit, "PCOND", out._pcond_raw, "Conditional probability of predictors.");

        // DOMLEN stays null when there are no categorical predictors.
        double[] domainLengths = null;
        if (out._ncats > 0) {
            domainLengths = new double[out._ncats];
            for (int col = 0; col < out._ncats; col++) {
                domainLengths[col] = out._domains[col].length;
            }
        }
        JCodeGen.toStaticVar(javaInit, "DOMLEN", domainLengths, "Number of unique levels for each categorical predictor.");
        return javaInit;
    }

    /**
     * Emits the body of the generated predict method; the emitted code mirrors
     * {@link #score0(double[], double[])}.
     *
     * @param body receives the generated statements
     * @param sb2  not used by this implementation
     * @param sb3  receives an (empty) auxiliary buffer
     */
    protected void toJavaPredictBody(SB body, SB sb2, SB sb3) {
        final NaiveBayesOutput out = (NaiveBayesOutput) this._output;
        final NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        final int nclasses = out._levels.length;

        // Expression used when a categorical level has no stored probability.
        final String unseenLevelProb = parms._laplace == 0.0d
                ? "0"
                : parms._laplace + "/(RESCNT[i] + " + parms._laplace + "*DOMLEN[j])";
        // Same probability floor as the in-memory scorer.
        final String addFlooredLog =
                "nums[i] += Math.log(prob <= " + parms._eps_prob + " ? " + parms._min_prob + " : prob);";

        // Nothing is ever written to this buffer; it is still handed to sb3 below as before.
        final SB extra = new SB();

        body.i().p("java.util.Arrays.fill(preds,0);").nl();
        body.i().p("double mean, sdev, prob;").nl();
        body.i().p("double[] nums = new double[" + nclasses + "];").nl();
        body.i().p("for(int i = 0; i < " + nclasses + "; i++) {").nl();
        body.i(1).p("nums[i] = Math.log(APRIORI[i]);").nl();

        // Categorical predictors.
        body.i(1).p("for(int j = 0; j < " + out._ncats + "; j++) {").nl();
        body.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
        body.i(2).p("int level = (int)data[j];").nl();
        // Bound the lookup by the level count of the indexed predictor, not by the predictor count.
        body.i(2).p("prob = level < PCOND[j][i].length ? PCOND[j][i][level] : " + unseenLevelProb).p(";").nl();
        body.i(2).p(addFlooredLog).nl();
        body.i(1).p("}").nl();

        // Numeric predictors.
        body.i(1).p("for(int j = " + out._ncats + "; j < data.length; j++) {").nl();
        body.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
        body.i(2).p("mean = Double.isNaN(PCOND[j][i][0]) ? 0 : PCOND[j][i][0];").nl();
        body.i(2).p("sdev = Double.isNaN(PCOND[j][i][1]) ? 1 : (PCOND[j][i][1] <= " + parms._eps_sdev + " ? " + parms._min_sdev + " : PCOND[j][i][1]);").nl();
        body.i(2).p("prob = Math.exp(-((data[j]-mean)*(data[j]-mean))/(2.*sdev*sdev)) / (sdev*Math.sqrt(2.*Math.PI));").nl();
        body.i(2).p(addFlooredLog).nl();
        body.i(1).p("}").nl();
        body.i().p("}").nl();

        // Normalization into per-class probabilities.
        body.i().p("double sum;").nl();
        body.i().p("for(int i = 0; i < nums.length; i++) {").nl();
        body.i(1).p("sum = 0;").nl();
        body.i(1).p("for(int j = 0; j < nums.length; j++) {").nl();
        body.i(2).p("sum += Math.exp(nums[j]-nums[i]);").nl();
        body.i(1).p("}").nl();
        body.i(1).p("preds[i+1] = 1/sum;").nl();
        body.i().p("}").nl();

        sb3.p(extra);
        body.i().p("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
    }
}
