package org.apache.lens.ml.spark.algos;

import java.util.List;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hive.hcatalog.data.HCatRecord;
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.AlgoArgParser;
import org.apache.lens.ml.AlgoParam;
import org.apache.lens.ml.Algorithm;
import org.apache.lens.ml.MLAlgo;
import org.apache.lens.ml.MLModel;
import org.apache.lens.ml.spark.HiveTableRDD;
import org.apache.lens.ml.spark.models.KMeansClusteringModel;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import scala.Tuple2;

/**
 * {@link MLAlgo} implementation that trains a Spark MLLib KMeans clustering model from the rows
 * of a Hive table.
 *
 * <p>Tunable settings are declared on the {@link AlgoParam}-annotated fields below. The raw
 * training arguments are handed, together with this instance, to
 * {@code AlgoArgParser.parseArgs}; the list it returns is used as the names of the feature
 * columns.
 */
@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
public class KMeansAlgo implements MLAlgo {
    /** Configuration supplied through {@link #configure(LensConf)}. */
    private transient LensConf conf;

    // NOTE(review): this field is never assigned inside this class; presumably it is injected
    // from outside (e.g. reflectively) before train() is called - confirm against the caller.
    private JavaSparkContext sparkContext;

    @AlgoParam(name = "k", help = "Number of cluster")
    private int k;

    @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
    private String partFilter = null;

    @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
    private int maxIterations = 100;

    @AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
    private int runs = 1;

    @AlgoParam(name = "initializationMode", help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
    private String initializationMode = "k-means||";

    /**
     * Returns the algorithm name declared in this class's {@link Algorithm} annotation.
     *
     * @return the value of {@link Algorithm#name()}
     */
    @Override
    public String getName() {
        return getClass().getAnnotation(Algorithm.class).name();
    }

    /**
     * Returns the description declared in this class's {@link Algorithm} annotation.
     *
     * @return the value of {@link Algorithm#description()}
     */
    @Override
    public String getDescription() {
        return getClass().getAnnotation(Algorithm.class).description();
    }

    /**
     * Stores the given configuration so it can later be read back through {@link #getConf()}.
     *
     * @param lensConf configuration to remember
     */
    @Override
    public void configure(LensConf lensConf) {
        this.conf = lensConf;
    }

    /**
     * Returns the configuration last passed to {@link #configure(LensConf)}.
     *
     * @return the stored configuration, or {@code null} if none has been set
     */
    @Override
    public LensConf getConf() {
        return this.conf;
    }

    /**
     * Trains a KMeans model on the feature columns of a Hive table.
     *
     * <p>Feature columns are laid out in the vector in table-column order, not in the order in
     * which they were listed in the arguments. A {@code null} column value is treated as
     * {@code 0.0}; any other value must be a {@link Number}.
     *
     * @param lensConf properties copied into the {@link HiveConf} used for the metastore lookup
     *     and for building the table RDD
     * @param str database name passed to the Hive table lookup
     * @param str2 table name passed to the Hive table lookup
     * @param str3 first constructor argument of the resulting {@link KMeansClusteringModel}
     *     (presumably the model id - confirm against that class)
     * @param strArr raw algorithm arguments handed to {@code AlgoArgParser.parseArgs}
     * @return the trained clustering model
     * @throws LensException if the table cannot be read, a feature name does not resolve to a
     *     distinct table column, or training fails; the underlying failure is kept as the cause
     */
    @Override
    public MLModel train(LensConf lensConf, String str, String str2, String str3, String... strArr) throws LensException {
        List<String> features = AlgoArgParser.parseArgs(this, strArr);
        try {
            Table table = Hive.get(toHiveConf(lensConf)).getTable(str, str2);
            int[] featurePositions = resolveFeaturePositions(table.getAllCols(), features);

            // A fresh HiveConf is built for the RDD on purpose: the callee may modify the
            // configuration it is given, and the one handed to Hive.get() should stay untouched.
            return new KMeansClusteringModel(
                    str3,
                    KMeans.train(
                            HiveTableRDD.createHiveTableRDD(this.sparkContext, toHiveConf(lensConf), str, str2, this.partFilter)
                                    .map(featureExtractor(featurePositions))
                                    .rdd(),
                            this.k,
                            this.maxIterations,
                            this.runs,
                            this.initializationMode));
        } catch (Exception e) {
            throw new LensException("KMeans algo failed for " + str + "." + str2, e);
        }
    }

    /**
     * Maps feature names to their column positions in the table, in table-column order.
     *
     * @param columns all columns of the table
     * @param features names of the columns to use as features
     * @return one column position per feature
     * @throws IllegalArgumentException if the number of matching columns differs from the number
     *     of features (a feature is missing from the table or listed more than once); without
     *     this check the unresolved slots would silently point at column 0
     */
    private static int[] resolveFeaturePositions(List<FieldSchema> columns, List<String> features) {
        int[] positions = new int[features.size()];
        int found = 0;
        for (int col = 0; col < columns.size(); col++) {
            if (features.contains(columns.get(col).getName())) {
                positions[found++] = col;
            }
        }
        if (found != positions.length) {
            throw new IllegalArgumentException("Expected " + positions.length
                    + " distinct feature columns but matched " + found
                    + " in the table; check for missing or duplicate feature names: " + features);
        }
        return positions;
    }

    /**
     * Builds the function that turns one table record into a dense feature vector.
     *
     * <p>The function is created in a static context so that it only holds the position array
     * and no reference to a {@code KMeansAlgo} instance (and its {@link JavaSparkContext}).
     *
     * @param positions column positions to read from each record, in vector order
     * @return a function producing one vector component per entry of {@code positions}
     */
    private static Function<Tuple2<WritableComparable, HCatRecord>, Vector> featureExtractor(final int[] positions) {
        return new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
            @Override
            public Vector call(Tuple2<WritableComparable, HCatRecord> tuple2) throws Exception {
                HCatRecord record = tuple2._2();
                double[] values = new double[positions.length];
                for (int i = 0; i < positions.length; i++) {
                    Object value = record.get(positions[i]);
                    // Accept every numeric column type, not only DOUBLE; nulls count as 0.0.
                    values[i] = value == null ? 0.0d : ((Number) value).doubleValue();
                }
                return Vectors.dense(values);
            }
        };
    }

    /**
     * Creates a new {@link HiveConf} and copies every property of the given configuration into it.
     *
     * @param lensConf source of the properties
     * @return a freshly created {@link HiveConf} carrying those properties
     */
    private HiveConf toHiveConf(LensConf lensConf) {
        HiveConf hiveConf = new HiveConf();
        for (String key : lensConf.getProperties().keySet()) {
            hiveConf.set(key, (String) lensConf.getProperties().get(key));
        }
        return hiveConf;
    }
}
