package org.apache.lens.ml.spark.algos;

import java.util.Map;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.AlgoParam;
import org.apache.lens.ml.Algorithm;
import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree$;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.impurity.Entropy$;
import org.apache.spark.mllib.tree.impurity.Gini$;
import org.apache.spark.mllib.tree.impurity.Impurity;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.rdd.RDD;
import scala.Enumeration;

/**
 * Lens ML algorithm that trains a Spark MLlib decision tree via {@code DecisionTree.train}
 * and wraps the result in a {@link DecisionTreeClassificationModel}.
 *
 * <p>Recognized parameters:
 * <ul>
 *   <li>{@code algo} - {@code classification} or {@code regression} (case-insensitive)</li>
 *   <li>{@code impurity} - {@code gini}, {@code entropy} or {@code variance} (case-insensitive)</li>
 *   <li>{@code maxDepth} - maximum tree depth, defaults to 100</li>
 * </ul>
 * Both {@code algo} and {@code impurity} must be supplied; training fails with a
 * {@link LensException} otherwise.
 */
@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
public class DecisionTreeAlgo extends BaseSparkAlgo {

    /** Fallback for {@code maxDepth}; keep in sync with the {@code defaultValue} of its {@link AlgoParam}. */
    private static final int DEFAULT_MAX_DEPTH = 100;

    /** Spark MLlib tree mode ({@code Algo.Classification} or {@code Algo.Regression}); null until parsed. */
    @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
    private Enumeration.Value algo;

    /** Spark MLlib impurity measure (Gini, Entropy or Variance singleton); null until parsed. */
    @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. Allowed values are 'gini', 'entropy' and 'variance'")
    private Impurity decisionTreeImpurity;

    /** Maximum depth of the trained tree. */
    @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.", defaultValue = "100")
    private int maxDepth;

    /**
     * Creates the algorithm.
     *
     * @param name        algorithm name, forwarded to {@link BaseSparkAlgo}
     * @param description algorithm description, forwarded to {@link BaseSparkAlgo}
     */
    public DecisionTreeAlgo(String name, String description) {
        super(name, description);
    }

    /**
     * Reads the decision-tree parameters.
     *
     * <p>{@code algo} and {@code impurity} are matched case-insensitively. A value that is
     * absent or not recognized leaves the corresponding field unchanged, so a missing or
     * unrecognized value is only reported when {@link #trainInternal} runs.
     *
     * @param params algorithm parameters keyed by parameter name
     */
    @Override // org.apache.lens.ml.spark.algos.BaseSparkAlgo
    public void parseAlgoParams(Map<String, String> params) {
        String algoName = params.get("algo");
        if ("classification".equalsIgnoreCase(algoName)) {
            this.algo = Algo$.MODULE$.Classification();
        } else if ("regression".equalsIgnoreCase(algoName)) {
            this.algo = Algo$.MODULE$.Regression();
        }

        // Case-insensitive, consistent with how 'algo' is matched above.
        String impurityName = params.get("impurity");
        if ("gini".equalsIgnoreCase(impurityName)) {
            this.decisionTreeImpurity = Gini$.MODULE$;
        } else if ("entropy".equalsIgnoreCase(impurityName)) {
            this.decisionTreeImpurity = Entropy$.MODULE$;
        } else if ("variance".equalsIgnoreCase(impurityName)) {
            this.decisionTreeImpurity = Variance$.MODULE$;
        }

        // NOTE(review): maxDepth is read through the base-class getParamValue rather than from
        // 'params' directly; presumably the base class holds the same parameters - confirm.
        this.maxDepth = getParamValue("maxDepth", DEFAULT_MAX_DEPTH);
    }

    /**
     * Trains a Spark decision tree on the given data and wraps it as a Lens model.
     *
     * @param modelId      identifier given to the resulting model
     * @param trainingRDD  labeled training points
     * @return the trained model wrapped in a {@link DecisionTreeClassificationModel}
     * @throws LensException if {@code algo} or {@code impurity} was never set to a recognized value
     */
    @Override // org.apache.lens.ml.spark.algos.BaseSparkAlgo
    protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
            throws LensException {
        // Fail fast with a clear message instead of handing null to Spark.
        if (this.algo == null) {
            throw new LensException(
                    "Decision tree parameter 'algo' is missing or invalid; expected 'classification' or 'regression'");
        }
        if (this.decisionTreeImpurity == null) {
            throw new LensException(
                    "Decision tree parameter 'impurity' is missing or invalid; expected 'gini', 'entropy' or 'variance'");
        }
        return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(
                DecisionTree$.MODULE$.train(trainingRDD, this.algo, this.decisionTreeImpurity, this.maxDepth)));
    }
}
