package hex.kmeans;

import hex.ClusteringModel;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsClustering;
import hex.ToEigenVec;
import hex.genmodel.GenModel;
import hex.genmodel.IClusteringModel;
import hex.kmeans.KMeans;
import hex.util.EffectiveParametersUtils;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.DKV;
import water.Job;
import water.Key;
import water.MRTask;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/kmeans/KMeansModel.class */
public class KMeansModel extends ClusteringModel<KMeansModel, KMeansParameters, KMeansOutput> {
    // Assertion switch surfaced by the decompiler; assigned once in this class's static
    // initializer. When true, the explicit AssertionError checks in this class are skipped.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/kmeans/KMeansModel$KMeansOutput.class */
    /**
     * Output container of a k-means model, extending {@link ClusteringModel.ClusteringOutput}
     * with k-means specific bookkeeping fields.
     *
     * <p>Only the initial values of the array-typed history fields are set here; everything
     * else is left at its Java default and is presumably filled in by the model builder
     * (not visible in this file).
     */
    public static class KMeansOutput extends ClusteringModel.ClusteringOutput {
        // Scalar / per-cluster statistics. Their exact semantics are defined by the code that
        // populates them, which is outside this file.
        public int _iterations;
        public double[] _withinss;
        public double _tot_withinss;

        // History arrays start with a single placeholder entry.
        public double[] _history_withinss = {Double.NaN};

        public double _totss;
        public double _betweenss;
        public int _categorical_column_count;

        // First entry is the wall-clock time at which this output object was constructed.
        public long[] _training_time_ms = {System.currentTimeMillis()};
        public double[] _reassigned_count = {Double.NaN};

        // The LAST element is what KMeansModel reads as the model's cluster count.
        public int[] _k = {0};

        /**
         * Creates the output for the given builder.
         *
         * @param kMeans the k-means builder, forwarded to the superclass constructor
         */
        public KMeansOutput(KMeans kMeans) {
            super(kMeans);
        }
    }

    /* loaded from: input_file:hex/kmeans/KMeansModel$KMeansParameters.class */
    /**
     * Parameters of a k-means model, extending {@link ClusteringModel.ClusteringParameters}.
     */
    public static class KMeansParameters extends ClusteringModel.ClusteringParameters {
        public Key<Frame> _user_points;
        public int _max_iterations = 10;
        // When true, scoring uses the standardized centers (see KMeansModel.score0).
        public boolean _standardize = true;
        public KMeans.Initialization _init = KMeans.Initialization.Furthest;
        // When true, predictions are emitted as one indicator column per cluster
        // (see KMeansModel.predictScoreImpl).
        public boolean _pred_indicator = false;
        public boolean _estimate_k = false;
        public int[] _cluster_size_constraints = null;

        /** @return the short algorithm name, {@code "KMeans"} */
        public String algoName() {
            return "KMeans";
        }

        /** @return the display name of the algorithm, {@code "K-means"} */
        public String fullName() {
            return "K-means";
        }

        /** @return the fully qualified class name of {@link KMeansModel} */
        public String javaName() {
            final Class<KMeansModel> modelClass = KMeansModel.class;
            return modelClass.getName();
        }

        /**
         * Number of progress units reported for a training run.
         *
         * @return {@code _k} when {@code _estimate_k} is set, otherwise {@code _max_iterations}
         */
        public long progressUnits() {
            if (this._estimate_k) {
                return this._k;
            }
            return this._max_iterations;
        }
    }

    /**
     * Returns the eigen-vector conversion helper used by this model.
     *
     * @return the shared {@code LinearAlgebraUtils.toEigen} instance
     */
    public ToEigenVec getToEigenVec() {
        final ToEigenVec sharedConverter = LinearAlgebraUtils.toEigen;
        return sharedConverter;
    }

    /**
     * Creates a k-means model; all three arguments are forwarded unchanged to the
     * {@link ClusteringModel} constructor.
     *
     * @param key              key of the model (raw {@code Key} type, as produced by the decompiler)
     * @param kMeansParameters parameters of the model
     * @param kMeansOutput     output object of the model
     */
    public KMeansModel(Key key, KMeansParameters kMeansParameters, KMeansOutput kMeansOutput) {
        super(key, kMeansParameters, kMeansOutput);
    }

    /**
     * Initializes the effective ("actual") parameter values: runs the superclass logic first,
     * then lets {@link EffectiveParametersUtils} set up fold assignment and categorical
     * encoding, passing {@code Enum} as the encoding scheme argument.
     */
    public void initActualParamValues() {
        super.initActualParamValues();
        final KMeansParameters params = this._parms;
        EffectiveParametersUtils.initFoldAssignment(params);
        EffectiveParametersUtils.initCategoricalEncoding(
                params, Model.Parameters.CategoricalEncodingScheme.Enum);
    }

    /**
     * Creates the clustering metric builder for this model.
     *
     * @param strArr expected to be {@code null}; a non-null value trips an assertion when
     *               assertions are enabled
     * @return a {@link ModelMetricsClustering.MetricBuilderClustering} sized by the output's
     *         feature count and by the last entry of {@code _output._k}
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        // Decompiled form of: assert strArr == null;
        if (!$assertionsDisabled && strArr != null) {
            throw new AssertionError();
        }
        final KMeansOutput out = this._output;
        return new ModelMetricsClustering.MetricBuilderClustering(out.nfeatures(), out._k[out._k.length - 1]);
    }

    /* JADX WARN: Type inference failed for: r0v13, types: [hex.kmeans.KMeansModel$1] */
    /**
     * Scores a frame. Without {@code _pred_indicator} this simply defers to the superclass.
     * With it, one indicator column per cluster is produced: each row gets 1.0 in the column
     * of its assigned cluster and 0.0 elsewhere.
     *
     * <p>Column names are built with {@code Double.toString}, so they come out as
     * {@code cluster_1.0}, {@code cluster_2.0}, ...
     *
     * @param frame    frame handed to the metric builder after scoring
     * @param frame2   frame whose columns are scored; the first {@code _output._names.length}
     *                 columns are read as model inputs
     * @param str      name used to create the key of the returned frame
     * @param job      optional job, polled for stop requests and updated per chunk; may be null
     * @param z        not used in the indicator path; forwarded to the superclass otherwise
     * @param cFuncRef not used in the indicator path; forwarded to the superclass otherwise
     * @return the prediction frame (already stored in the DKV in the indicator path)
     */
    protected Frame predictScoreImpl(Frame frame, Frame frame2, String str, final Job job, boolean z, CFuncRef cFuncRef) {
        if (!this._parms._pred_indicator) {
            return super.predictScoreImpl(frame, frame2, str, job, z, cFuncRef);
        }

        final int[] kHistory = this._output._k;
        final int clusterCount = kHistory[kHistory.length - 1];

        // Working copy of the input frame with one zero-filled column appended per cluster.
        final Frame work = new Frame(frame2);
        for (int label = 1; label <= clusterCount; label++) {
            work.add("cluster_" + Double.toString(label), work.anyVec().makeZero());
        }

        new MRTask() {
            public void map(Chunk[] chunkArr) {
                // Skip the chunk entirely when the task was cancelled or the job asked to stop.
                if (isCancelled() || (job != null && job.stop_requested())) {
                    return;
                }
                final int featureCount = KMeansModel.this._output._names.length;
                final double[] rowBuffer = new double[featureCount];
                final double[] indicator = new double[clusterCount];
                final int rowCount = chunkArr[0]._len;
                for (int row = 0; row < rowCount; row++) {
                    // score_indicator requires an all-zero output vector on entry.
                    Arrays.fill(indicator, 0.0d);
                    final double[] scored = KMeansModel.this.score_indicator(chunkArr, row, rowBuffer, indicator);
                    for (int cluster = 0; cluster < indicator.length; cluster++) {
                        // Indicator columns sit right after the feature columns.
                        chunkArr[featureCount + cluster].set(row, scored[cluster]);
                    }
                }
                if (job != null) {
                    job.update(1L);
                }
            }
        }.doAll(work);

        // Keep only the appended indicator columns and publish them under the requested key.
        final int firstIndicatorCol = this._output._names.length;
        final Frame indicatorCols = work.extractFrame(firstIndicatorCol, work.numCols());
        final Frame result = new Frame(Key.make(str), indicatorCols.names(), indicatorCols.vecs());
        DKV.put(result);
        makeMetricBuilder(null).makeModelMetrics(this, frame, (Frame) null, (Frame) null);
        return result;
    }

    /**
     * Scores one row and writes a one-hot cluster indicator into {@code dArr2}.
     *
     * @param chunkArr chunks holding the row; the first {@code dArr.length} chunks are read
     * @param i        row index within the chunks
     * @param dArr     scratch buffer for the row values; overwritten (and further modified in
     *                 place by {@code score0})
     * @param dArr2    output vector, must be all zeros on entry; the slot of the assigned
     *                 cluster is set to 1.0
     * @return {@code dArr2}
     */
    public double[] score_indicator(Chunk[] chunkArr, int i, double[] dArr, double[] dArr2) {
        final boolean assertionsOn = !$assertionsDisabled;
        if (assertionsOn) {
            if (!this._parms._pred_indicator) {
                throw new AssertionError();
            }
            if (dArr.length != this._output._names.length || dArr2.length != this._output._centers_raw.length) {
                throw new AssertionError();
            }
        }

        // Pull the row out of the chunks.
        for (int col = 0; col < dArr.length; col++) {
            dArr[col] = chunkArr[col].atd(i);
        }

        final double[] assignment = new double[1];
        score0(dArr, assignment);

        if (assertionsOn) {
            // NOTE(review): dArr2 was already dereferenced by the length check above, so the
            // null test here can never be the first thing to fail.
            if (dArr2 == null || ArrayUtils.l2norm2(dArr2) != 0.0d) {
                throw new AssertionError("preds must be a vector of all zeros, got " + Arrays.toString(dArr2));
            }
            if (assignment[0] < 0.0d || assignment[0] >= dArr2.length) {
                throw new AssertionError("Cluster number must be an integer in [0," + String.valueOf(dArr2.length) + ")");
            }
        }

        dArr2[(int) assignment[0]] = 1.0d;
        return dArr2;
    }

    /**
     * Computes, for one row, the vector returned by {@code GenModel.KMeans_simplex} against
     * the model's centers (standardized centers when {@code _standardize} is set).
     *
     * <p>NOTE(review): unlike {@code score0}, the row is not passed through
     * {@code Kmeans_preprocessData} before the distance computation - confirm this is intended.
     *
     * @param chunkArr chunks holding the row; the first {@code dArr.length} chunks are read
     * @param i        row index within the chunks
     * @param dArr     scratch buffer for the row values; overwritten
     * @return the ratio vector; with assertions enabled it is checked to have one entry per
     *         cluster and to sum to 1 within 1e-6
     */
    public double[] score_ratio(Chunk[] chunkArr, int i, double[] dArr) {
        final boolean assertionsOn = !$assertionsDisabled;
        if (assertionsOn && !this._parms._pred_indicator) {
            throw new AssertionError();
        }
        if (assertionsOn && dArr.length != this._output._names.length) {
            throw new AssertionError();
        }

        for (int col = 0; col < dArr.length; col++) {
            dArr[col] = chunkArr[col].atd(i);
        }

        final double[][] centers;
        if (this._parms._standardize) {
            centers = this._output._centers_std_raw;
        } else {
            centers = this._output._centers_raw;
        }
        final double[] ratios = GenModel.KMeans_simplex(centers, dArr, this._output._domains);

        if (assertionsOn) {
            if (ratios.length != this._output._k[this._output._k.length - 1]) {
                throw new AssertionError();
            }
            // Written as a negated '<' on purpose so that a NaN sum also fails the check.
            if (!(Math.abs(ArrayUtils.sum(ratios) - 1.0d) < 1.0E-6d)) {
                throw new AssertionError("Sum of k-means distance ratios should equal 1");
            }
        }
        return ratios;
    }

    /**
     * Three-argument scoring overload; the third argument is ignored and the call is
     * delegated to {@link #score0(double[], double[])}.
     *
     * @param dArr  row values
     * @param dArr2 output array
     * @param d     unused
     * @return the result of the two-argument overload
     */
    protected double[] score0(double[] dArr, double[] dArr2, double d) {
        return score0(dArr, dArr2);
    }

    /**
     * Assigns a row to its closest cluster center.
     *
     * <p>The row is first run through {@code GenModel.Kmeans_preprocessData} (in place, using
     * the output's {@code _normSub}, {@code _normMul} and {@code _mode}); the closest center
     * is then looked up among the standardized centers when {@code _standardize} is set,
     * otherwise among the raw centers.
     *
     * @param dArr  row values; modified in place by the preprocessing step
     * @param dArr2 output array; element 0 receives the closest cluster
     * @return {@code dArr2}
     */
    protected double[] score0(double[] dArr, double[] dArr2) {
        final KMeansOutput out = this._output;
        final double[][] centers;
        if (this._parms._standardize) {
            centers = out._centers_std_raw;
        } else {
            centers = out._centers_raw;
        }
        GenModel.Kmeans_preprocessData(dArr, out._normSub, out._normMul, out._mode);
        dArr2[0] = GenModel.KMeans_closest(centers, dArr, out._domains);
        return dArr2;
    }

    /**
     * Reads one cell and returns it after {@code GenModel.Kmeans_preprocessData} has been
     * applied for that column.
     *
     * @param chunkArr chunks to read from
     * @param i        row index within the chunk
     * @param i2       column index (selects both the chunk and the preprocessing entry)
     * @return the preprocessed cell value
     */
    protected double data(Chunk[] chunkArr, int i, int i2) {
        final double rawValue = chunkArr[i2].atd(i);
        final KMeansOutput out = this._output;
        return GenModel.Kmeans_preprocessData(rawValue, i2, out._normSub, out._normMul, out._mode);
    }

    /**
     * Lists the interfaces reported for the POJO of this model.
     *
     * @return a one-element array containing {@link IClusteringModel}
     */
    protected Class<?>[] getPojoInterfaces() {
        final Class<?>[] interfaces = new Class<?>[1];
        interfaces[0] = IClusteringModel.class;
        return interfaces;
    }

    /**
     * Emits the body of the generated predict method and registers generators for the
     * constant arrays it refers to.
     *
     * <p>With {@code _standardize} the generated code preprocesses {@code data} using the
     * MEANS / MULTS / MODES arrays and then looks up the closest of the standardized centers;
     * without it only the raw centers array is emitted and used directly.
     *
     * @param sBPrintStream         stream receiving the method body
     * @param codeGeneratorPipeline not used here
     * @param codeGeneratorPipeline2 pipeline that receives the array-class generators
     * @param z                     not used here
     */
    protected void toJavaPredictBody(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline, CodeGeneratorPipeline codeGeneratorPipeline2, boolean z) {
        final String javaId = JCodeGen.toJavaId(this._key.toString());
        final String centersId = javaId + "_CENTERS";

        if (this._parms._standardize) {
            codeGeneratorPipeline2.add(new CodeGenerator() {
                public void generate(JCodeSB jCodeSB) {
                    JCodeGen.toClassWithArray(jCodeSB, (String) null, javaId + "_MEANS", KMeansModel.this._output._normSub, "Column means of training data");
                    JCodeGen.toClassWithArray(jCodeSB, (String) null, javaId + "_MULTS", KMeansModel.this._output._normMul, "Reciprocal of column standard deviations of training data");
                    JCodeGen.toClassWithArray(jCodeSB, (String) null, javaId + "_MODES", KMeansModel.this._output._mode, "Mode for categorical columns");
                    JCodeGen.toClassWithArray(jCodeSB, (String) null, javaId + "_CENTERS", KMeansModel.this._output._centers_std_raw, "Normalized cluster centers[K][features]");
                }
            });
            sBPrintStream.ip("Kmeans_preprocessData(data,")
                    .pj(javaId + "_MEANS", "VALUES,")
                    .pj(javaId + "_MULTS", "VALUES,")
                    .pj(javaId + "_MODES", "VALUES")
                    .p(");")
                    .nl();
            sBPrintStream.ip("preds[0] = KMeans_closest(").pj(centersId, "VALUES").p(", data, DOMAINS); ").nl();
            return;
        }

        // Non-standardized model: only the raw centers are needed.
        codeGeneratorPipeline2.add(new CodeGenerator() {
            public void generate(JCodeSB jCodeSB) {
                JCodeGen.toClassWithArray(jCodeSB, (String) null, javaId + "_CENTERS", KMeansModel.this._output._centers_raw, "Denormalized cluster centers[K][features]");
            }
        });
        sBPrintStream.ip("preds[0] = KMeans_closest(").pj(centersId, "VALUES").p(",data, DOMAINS);").nl();
    }

    /**
     * Emits two extra methods into the generated class: {@code distances(double[], double[])}
     * and {@code getNumClusters()}.
     *
     * @param sBPrintStream         stream receiving the generated source
     * @param codeGeneratorPipeline not used here
     * @param z                     not used here
     * @return the same {@code sBPrintStream}
     */
    protected SBPrintStream toJavaTransform(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline, boolean z) {
        final String[] distancesPreamble = {
            "// Pass in data in a double[], in a same way as to the score0 function.",
            "// Cluster distances will be stored into the distances[] array. Function",
            "// will return the closest cluster. This way the caller can avoid to call",
            "// score0(..) to retrieve the cluster where the data point belongs.",
            "public final int distances( double[] data, double[] distances ) {",
        };
        final String[] numClustersPreamble = {
            "// Returns number of cluster used by this model.",
            "public final int getNumClusters() {",
        };

        // distances(...)
        sBPrintStream.nl();
        for (String line : distancesPreamble) {
            sBPrintStream.ip(line).nl();
        }
        toJavaDistancesBody(sBPrintStream.ii(1));
        sBPrintStream.ip("return cluster;").nl();
        sBPrintStream.di(1).ip("}").nl();

        // getNumClusters()
        sBPrintStream.nl();
        for (String line : numClustersPreamble) {
            sBPrintStream.ip(line).nl();
        }
        toJavaGetNumClustersBody(sBPrintStream.ii(1));
        sBPrintStream.ip("return nclusters;").nl();
        sBPrintStream.di(1).ip("}").nl();

        // A fresh pipeline is created and run at one extra indent level.
        final CodeGeneratorPipeline classContext = new CodeGeneratorPipeline();
        classContext.generate(sBPrintStream.ii(1));
        sBPrintStream.di(1);
        return sBPrintStream;
    }

    /**
     * Emits the body of the generated {@code distances} method: an optional preprocessing
     * call (only with {@code _standardize}) followed by the {@code KMeans_distances} call
     * that defines the local {@code cluster}.
     *
     * @param sBPrintStream stream receiving the generated statements
     */
    private void toJavaDistancesBody(SBPrintStream sBPrintStream) {
        final String javaId = JCodeGen.toJavaId(this._key.toString());
        final String centersId = javaId + "_CENTERS";

        if (this._parms._standardize) {
            sBPrintStream.ip("Kmeans_preprocessData(data,")
                    .pj(javaId + "_MEANS", "VALUES,")
                    .pj(javaId + "_MULTS", "VALUES,")
                    .pj(javaId + "_MODES", "VALUES")
                    .p(");")
                    .nl();
            sBPrintStream.ip("int cluster = KMeans_distances(").pj(centersId, "VALUES").p(", data, DOMAINS, distances); ").nl();
            return;
        }
        sBPrintStream.ip("int cluster = KMeans_distances(").pj(centersId, "VALUES").p(",data, DOMAINS, distances);").nl();
    }

    /**
     * Emits the statement of the generated {@code getNumClusters} method that defines
     * {@code nclusters} as the length of the generated centers array.
     *
     * @param sBPrintStream stream receiving the generated statement
     */
    private void toJavaGetNumClustersBody(SBPrintStream sBPrintStream) {
        final String centersId = JCodeGen.toJavaId(this._key.toString()) + "_CENTERS";
        sBPrintStream.ip("int nclusters = ");
        sBPrintStream.pj(centersId, "VALUES").p(".length;").nl();
    }

    /**
     * Tells whether the centers matrix used for code generation has more than one million
     * cells (clusters x features). The standardized centers are inspected when
     * {@code _standardize} is set, the raw centers otherwise.
     *
     * <p>The cell count is computed in {@code long} arithmetic so that a very large matrix
     * cannot wrap around in {@code int} and be reported as small. A missing or empty centers
     * array is reported as not too big instead of failing on {@code centers[0]}.
     *
     * @return {@code true} when the number of cells exceeds 1,000,000
     */
    protected boolean toJavaCheckTooBig() {
        final double[][] centers = this._parms._standardize
                ? this._output._centers_std_raw
                : this._output._centers_raw;
        if (centers == null || centers.length == 0) {
            return false;
        }
        // Widen before multiplying: int * int could overflow for huge matrices.
        final long cellCount = (long) centers.length * centers[0].length;
        return cellCount > 1000000L;
    }

    /**
     * Creates a {@link KMeansMojoWriter} for this model.
     *
     * <p>The {@code m130} prefix is a decompiler artifact; per the annotation below, the
     * method was named {@code getMojo} before it was merged with its bridge method.
     *
     * @return a new writer constructed with this model
     */
    /* renamed from: getMojo, reason: merged with bridge method [inline-methods] */
    public KMeansMojoWriter m130getMojo() {
        return new KMeansMojoWriter(this);
    }

    // Decompiled assertion bootstrap: caches whether assertions are disabled for this class
    // so the explicit AssertionError checks can be skipped cheaply.
    static {
        $assertionsDisabled = !KMeansModel.class.desiredAssertionStatus();
    }
}
