package org.apache.lens.ml.algo.spark;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.algo.api.AlgoParam;
import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.ml.algo.spark.TableTrainingSpec;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

/* loaded from: input_file:org/apache/lens/ml/algo/spark/BaseSparkAlgo.class */
/**
 * Common base for Spark backed {@link MLAlgo} implementations.
 *
 * <p>The class turns the flat {@code key, value, key, value, ...} argument list handed to
 * {@link #train} into typed state (feature columns, label column, partition filter, training
 * fraction and a map of the remaining parameters), builds a {@link TableTrainingSpec} from that
 * state, and delegates the actual model fitting to {@link #trainInternal}.
 *
 * <p>Instances hold mutable, unsynchronized state that is rewritten by every
 * {@link #parseParams(String[])} call.
 */
public abstract class BaseSparkAlgo implements MLAlgo {
    public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
    private final String name;
    private final String description;
    protected JavaSparkContext sparkContext;

    /** Parameters that are neither feature nor label arguments, keyed by name with all '-' removed. */
    protected Map<String, String> params;
    protected transient LensConf conf;

    @AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
    protected double trainingFraction;

    /** True only when the last parsed argument list contained an explicit trainingFraction. */
    private boolean useTrainingFraction;

    @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
    protected String label;

    @AlgoParam(name = "partition", help = "Partition filter used to create create HCatInputFormats")
    protected String partitionFilter;

    /** Feature columns in the order they were supplied; null when no feature argument was given. */
    @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
    protected List<String> features;

    /**
     * Creates the algorithm with its display metadata.
     *
     * @param str  value returned by {@link #getName()}
     * @param str2 value returned by {@link #getDescription()}
     */
    public BaseSparkAlgo(String str, String str2) {
        this.name = str;
        this.description = str2;
    }

    /**
     * Sets the Spark context that {@link #train} passes to the training spec when creating RDDs.
     *
     * @param javaSparkContext context to use for subsequent training runs
     */
    public void setSparkContext(JavaSparkContext javaSparkContext) {
        this.sparkContext = javaSparkContext;
    }

    /**
     * @return the configuration last passed to {@link #configure(LensConf)}, or null if none was set
     */
    @Override // org.apache.lens.ml.algo.api.MLAlgo
    public LensConf getConf() {
        return this.conf;
    }

    /**
     * Stores the given configuration; it is only handed back by {@link #getConf()} in this class.
     *
     * @param lensConf configuration to remember
     */
    @Override // org.apache.lens.ml.algo.api.MLAlgo
    public void configure(LensConf lensConf) {
        this.conf = lensConf;
    }

    /**
     * Parses the arguments, builds the training spec for the given table and trains a model.
     *
     * @param lensConf configuration whose properties are copied into the Hive configuration of the spec
     * @param str      database passed to the training spec
     * @param str2     table passed to the training spec and recorded on the returned model
     * @param str3     identifier forwarded unchanged to {@link #trainInternal}
     * @param strArr   flat list of key/value pairs, see {@link #parseParams(String[])}
     * @return the trained model, annotated with table, raw params, label column and feature columns
     * @throws LensException            if {@link #trainInternal} fails
     * @throws IllegalArgumentException if the argument list is malformed
     */
    @Override // org.apache.lens.ml.algo.api.MLAlgo
    public MLModel<?> train(LensConf lensConf, String str, String str2, String str3, String... strArr) throws LensException {
        parseParams(strArr);

        TableTrainingSpec.TableTrainingSpecBuilder specBuilder = TableTrainingSpec.newBuilder()
            .hiveConf(toHiveConf(lensConf))
            .database(str)
            .table(str2)
            .partitionFilter(this.partitionFilter)
            .featureColumns(this.features)
            .labelColumn(this.label);
        // Only override the spec's own fraction when the caller asked for one explicitly.
        if (this.useTrainingFraction) {
            specBuilder.trainingFraction(this.trainingFraction);
        }
        TableTrainingSpec spec = specBuilder.build();

        // features stays null when no feature argument was supplied; the log line must not
        // be the thing that fails the training run in that case.
        int featureCount = this.features == null ? 0 : this.features.size();
        LOG.info("Training  with " + featureCount + " features");

        spec.createRDDs(this.sparkContext);
        BaseSparkClassificationModel model = trainInternal(str3, spec.getTrainingRDD());
        model.setTable(str2);
        model.setParams(Arrays.asList(strArr));
        model.setLabelColumn(this.label);
        model.setFeatureColumns(this.features);
        return model;
    }

    /**
     * Copies every property of the given Lens configuration into a fresh {@link HiveConf}.
     *
     * @param lensConf source configuration
     * @return a new Hive configuration carrying all properties of {@code lensConf}
     */
    protected HiveConf toHiveConf(LensConf lensConf) {
        HiveConf hiveConf = new HiveConf();
        for (String str : lensConf.getProperties().keySet()) {
            hiveConf.set(str, (String) lensConf.getProperties().get(str));
        }
        return hiveConf;
    }

    /**
     * Parses a flat {@code key, value, ...} argument list into this instance's state and then calls
     * {@link #parseAlgoParams(Map)} with the remaining parameters.
     *
     * <p>Keys {@code f}/{@code feature} (case-insensitive) add a feature column, keys
     * {@code l}/{@code label} (case-insensitive) set the label column. Every other key is stored in
     * {@link #params} after all '-' characters are removed from it. Note that the feature and label
     * keys are matched before that stripping, so they are only recognized without dashes.
     *
     * <p>All state derived from a previous call is discarded first, so calling this method (or
     * {@link #train}) repeatedly on one instance neither accumulates duplicate feature columns nor
     * reuses a label, partition filter or training fraction that the new arguments did not supply.
     *
     * @param strArr flat list of key/value pairs; its length must be even
     * @throws IllegalArgumentException if the length is odd or the training fraction is not a number
     */
    public void parseParams(String[] strArr) {
        if (strArr.length % 2 != 0) {
            throw new IllegalArgumentException("Invalid number of params " + strArr.length);
        }

        // Start from a clean slate: params was always rebuilt per call, the typed fields must be too.
        this.params = new LinkedHashMap<String, String>();
        this.features = null;
        this.label = null;
        this.partitionFilter = null;
        this.trainingFraction = 0.0d;
        this.useTrainingFraction = false;

        for (int i = 0; i < strArr.length; i += 2) {
            String key = strArr[i];
            String value = strArr[i + 1];
            if ("f".equalsIgnoreCase(key) || "feature".equalsIgnoreCase(key)) {
                if (this.features == null) {
                    this.features = new ArrayList<String>();
                }
                this.features.add(value);
            } else if ("l".equalsIgnoreCase(key) || "label".equalsIgnoreCase(key)) {
                this.label = value;
            } else {
                this.params.put(key.replaceAll("\\-+", ""), value);
            }
        }

        if (this.params.containsKey("trainingFraction")) {
            try {
                this.trainingFraction = Double.parseDouble(this.params.get("trainingFraction"));
                this.useTrainingFraction = true;
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Invalid training fraction", e);
            }
        }

        // The long form wins when both "partition" and "p" are present.
        if (this.params.containsKey("partition")) {
            this.partitionFilter = this.params.get("partition");
        } else if (this.params.containsKey("p")) {
            this.partitionFilter = this.params.get("p");
        }

        parseAlgoParams(this.params);
    }

    /**
     * Looks up a parameter parsed by {@link #parseParams(String[])} as a double.
     *
     * @param str parameter name (as stored, i.e. without dashes)
     * @param d   value to return when the parameter is absent, unparseable, or nothing was parsed yet
     * @return the parsed value or {@code d}
     */
    public double getParamValue(String str, double d) {
        if (this.params != null && this.params.containsKey(str)) {
            try {
                return Double.parseDouble(this.params.get(str));
            } catch (NumberFormatException e) {
                // Deliberately best-effort: fall back to the default, but keep the cause in the log.
                LOG.warn("Couldn't parse param value: " + str + " as double.", e);
            }
        }
        return d;
    }

    /**
     * Looks up a parameter parsed by {@link #parseParams(String[])} as an int.
     *
     * @param str parameter name (as stored, i.e. without dashes)
     * @param i   value to return when the parameter is absent, unparseable, or nothing was parsed yet
     * @return the parsed value or {@code i}
     */
    public int getParamValue(String str, int i) {
        if (this.params != null && this.params.containsKey(str)) {
            try {
                return Integer.parseInt(this.params.get(str));
            } catch (NumberFormatException e) {
                // Deliberately best-effort: fall back to the default, but keep the cause in the log.
                LOG.warn("Couldn't parse param value: " + str + " as integer.", e);
            }
        }
        return i;
    }

    /**
     * @return the name given at construction time
     */
    @Override // org.apache.lens.ml.algo.api.MLAlgo
    public String getName() {
        return this.name;
    }

    /**
     * @return the description given at construction time
     */
    @Override // org.apache.lens.ml.algo.api.MLAlgo
    public String getDescription() {
        return this.description;
    }

    /**
     * Builds a usage description from annotations.
     *
     * <p>If the runtime class carries an {@link Algorithm} annotation, its name and description come
     * first. They are followed by one {@code "[param] <name>"} entry for every {@link AlgoParam}
     * annotated field declared on the runtime class and each of its superclasses up to and including
     * {@code BaseSparkAlgo}.
     *
     * @return insertion-ordered map from usage label to help text
     */
    public Map<String, String> getArgUsage() {
        Map<String, String> usage = new LinkedHashMap<String, String>();
        Class<?> cls = getClass();
        Algorithm algorithm = cls.getAnnotation(Algorithm.class);
        if (algorithm != null) {
            usage.put("Algorithm Name", algorithm.name());
            usage.put("Algorithm Description", algorithm.description());
        }
        while (cls != null) {
            for (Field field : cls.getDeclaredFields()) {
                AlgoParam algoParam = field.getAnnotation(AlgoParam.class);
                if (algoParam != null) {
                    usage.put("[param] " + algoParam.name(), algoParam.help() + " Default Value = " + algoParam.defaultValue());
                }
            }
            // Stop after this class: nothing above it declares algorithm parameters.
            if (cls.equals(BaseSparkAlgo.class)) {
                break;
            }
            cls = cls.getSuperclass();
        }
        return usage;
    }

    /**
     * Hook for subclasses to read their algorithm specific settings.
     *
     * @param map the parameters left over after feature and label arguments were extracted; this is
     *            the same map instance as {@link #params}
     */
    public abstract void parseAlgoParams(Map<String, String> map);

    /**
     * Fits the model on the prepared training data.
     *
     * @param str identifier forwarded from the fourth argument of {@link #train}
     * @param rdd training RDD produced by the training spec
     * @return the trained model
     * @throws LensException if training fails
     */
    protected abstract BaseSparkClassificationModel trainInternal(String str, RDD<LabeledPoint> rdd) throws LensException;
}
