package org.apache.lens.ml.algo.spark.lr;

import java.util.Map;
import org.apache.lens.ml.algo.api.AlgoParam;
import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

/**
 * Lens ML algorithm that trains a Spark MLlib logistic regression model using
 * {@code LogisticRegressionWithSGD}.
 *
 * <p>Three tunables are exposed through {@link AlgoParam} annotations: {@code iterations},
 * {@code stepSize} and {@code minBatchFraction}. {@link #parseAlgoParams(Map)} resolves them
 * before training. The trained MLlib model is wrapped in a
 * {@code LogitRegressionClassificationModel}.
 */
@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
public class LogisticRegressionAlgo extends BaseSparkAlgo {

    /** Key of the "maximum iterations" parameter; shared by the annotation and the lookup. */
    private static final String PARAM_ITERATIONS = "iterations";

    /** Key of the "step size" parameter; shared by the annotation and the lookup. */
    private static final String PARAM_STEP_SIZE = "stepSize";

    /** Key of the "batch fraction" parameter; shared by the annotation and the lookup. */
    private static final String PARAM_MIN_BATCH_FRACTION = "minBatchFraction";

    /** Number of SGD iterations handed to the trainer (falls back to 100 when unset). */
    @AlgoParam(name = PARAM_ITERATIONS, help = "Max number of iterations", defaultValue = "100")
    private int iterations;

    /** SGD step size handed to the trainer (falls back to 1.0 when unset). */
    @AlgoParam(name = PARAM_STEP_SIZE, help = "Step size", defaultValue = "1.0d")
    private double stepSize;

    /**
     * Fraction of the data used per SGD iteration, passed as the trainer's mini-batch fraction
     * argument (falls back to 1.0 when unset).
     */
    @AlgoParam(name = PARAM_MIN_BATCH_FRACTION, help = "Fraction for batched learning",
        defaultValue = "1.0d")
    private double minBatchFraction;

    /**
     * Creates the algorithm; both arguments are forwarded unchanged to the
     * {@link BaseSparkAlgo} constructor.
     *
     * @param name        first constructor argument of the base class (presumably the algorithm
     *                    name; confirm against BaseSparkAlgo)
     * @param description second constructor argument of the base class (presumably a
     *                    description; confirm against BaseSparkAlgo)
     */
    public LogisticRegressionAlgo(String name, String description) {
        super(name, description);
    }

    /**
     * Resolves the training parameters into this instance's fields.
     *
     * <p>Each value comes from the inherited {@code getParamValue}, using the same fallback
     * values as the annotation defaults. The {@code params} argument is not read directly
     * here. NOTE(review): {@code getParamValue} presumably reads parameters held by the base
     * class; confirm.
     *
     * @param params algorithm parameters supplied by the caller (unused in this method body)
     */
    @Override
    public void parseAlgoParams(Map<String, String> params) {
        // Lookups stay in the original order in case getParamValue has observable side effects.
        iterations = getParamValue(PARAM_ITERATIONS, 100);
        stepSize = getParamValue(PARAM_STEP_SIZE, 1.0d);
        minBatchFraction = getParamValue(PARAM_MIN_BATCH_FRACTION, 1.0d);
    }

    /**
     * Trains a logistic regression model with SGD on the given labeled data.
     *
     * @param modelId     identifier attached to the resulting classification model
     * @param trainingRDD labeled training points
     * @return the trained MLlib model wrapped in a {@code LogitRegressionClassificationModel}
     * @throws LensException as declared by the base-class contract
     */
    @Override
    protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
        throws LensException {
        return new LogitRegressionClassificationModel(
            modelId,
            LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction));
    }
}
