package org.apache.lens.ml.spark.algos;

import java.util.Map;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.AlgoParam;
import org.apache.lens.ml.Algorithm;
import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

/**
 * Lens ML algorithm that trains a Spark MLlib Naive Bayes classifier.
 *
 * <p>The single tunable parameter, {@code lambda} (default {@code 1.0}), is forwarded unchanged to
 * {@link NaiveBayes#train(RDD, double)}. The resulting Spark model is wrapped in a
 * {@link NaiveBayesClassificationModel}.
 */
@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
public class NaiveBayesAlgo extends BaseSparkAlgo {

    /** Lambda value handed to the Spark Naive Bayes trainer; starts at 1.0 until parsed. */
    @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
    private double lambda = 1.0d;

    /**
     * Creates the algorithm.
     *
     * @param name        first constructor argument, forwarded unchanged to {@code BaseSparkAlgo}
     *                    (presumably the algorithm name — confirm against the base class)
     * @param description second constructor argument, forwarded unchanged to {@code BaseSparkAlgo}
     *                    (presumably the algorithm description — confirm against the base class)
     */
    public NaiveBayesAlgo(String name, String description) {
        super(name, description);
    }

    /**
     * Reads the {@code lambda} parameter, falling back to {@code 1.0} when it is absent.
     *
     * @param params parameter map supplied by the caller. It is not read directly here; the value
     *               comes from the inherited {@code getParamValue} helper (NOTE(review): presumably
     *               backed by parameters stored in {@code BaseSparkAlgo} — confirm)
     */
    @Override
    public void parseAlgoParams(Map<String, String> params) {
        lambda = getParamValue("lambda", 1.0d);
    }

    /**
     * Trains a Spark Naive Bayes model on the given data using the current {@code lambda}.
     *
     * @param modelId     identifier handed to the wrapping {@link NaiveBayesClassificationModel}
     * @param trainingRDD labeled training points
     * @return the trained model wrapped as a Lens classification model
     * @throws LensException declared by the overridden base method; not thrown directly here
     */
    @Override
    protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
            throws LensException {
        return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
    }
}
