package org.apache.lens.ml.spark.algos;

import java.util.Map;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.AlgoParam;
import org.apache.lens.ml.Algorithm;
import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
import org.apache.lens.ml.spark.models.SVMClassificationModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;

/**
 * Lens ML algorithm that trains a Spark MLlib linear SVM classifier via {@link SVMWithSGD}.
 *
 * <p>The tunable hyper-parameters are declared on the annotated fields below. They are filled in
 * by {@link #parseAlgoParams(Map)} and consumed by {@link #trainInternal(String, RDD)}.
 */
@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo")
public class SVMAlgo extends BaseSparkAlgo {

    // Parameter keys, shared between the @AlgoParam declarations and the lookups in
    // parseAlgoParams so the two cannot drift apart.
    private static final String MIN_BATCH_FRACTION = "minBatchFraction";
    private static final String REG_PARAM = "regParam";
    private static final String STEP_SIZE = "stepSize";
    private static final String ITERATIONS = "iterations";

    /** Passed to SVMWithSGD.train as the mini-batch fraction. */
    @AlgoParam(name = MIN_BATCH_FRACTION, help = "Fraction for batched learning", defaultValue = "1.0d")
    private double minBatchFraction;

    /** Passed to SVMWithSGD.train as the regularization parameter. */
    @AlgoParam(name = REG_PARAM, help = "regularization parameter for gradient descent", defaultValue = "1.0d")
    private double regParam;

    /** Passed to SVMWithSGD.train as the SGD step size. */
    @AlgoParam(name = STEP_SIZE, help = "Iteration step size", defaultValue = "1.0d")
    private double stepSize;

    /** Passed to SVMWithSGD.train as the number of iterations. */
    @AlgoParam(name = ITERATIONS, help = "Number of iterations", defaultValue = "100")
    private int iterations;

    /**
     * Creates the algorithm.
     *
     * @param name forwarded unchanged to the {@link BaseSparkAlgo} constructor
     *     (presumably the algorithm name; TODO confirm)
     * @param description forwarded unchanged to the {@link BaseSparkAlgo} constructor
     *     (presumably the algorithm description; TODO confirm)
     */
    public SVMAlgo(String name, String description) {
        super(name, description);
    }

    /**
     * Resolves every hyper-parameter through the inherited {@code getParamValue}. The fallback
     * defaults match the {@code defaultValue}s declared on the fields.
     *
     * @param params not read directly here; values come from {@code getParamValue}
     */
    @Override
    public void parseAlgoParams(Map<String, String> params) {
        minBatchFraction = getParamValue(MIN_BATCH_FRACTION, 1.0d);
        regParam = getParamValue(REG_PARAM, 1.0d);
        stepSize = getParamValue(STEP_SIZE, 1.0d);
        iterations = getParamValue(ITERATIONS, 100);
    }

    /**
     * Trains an SVM with SGD on the supplied labeled points and wraps it in a Lens model.
     *
     * @param modelId forwarded as the first argument of {@link SVMClassificationModel}
     *     (presumably the model identifier; TODO confirm)
     * @param trainingRDD labeled training data handed to {@link SVMWithSGD}
     * @return the trained model wrapped in an {@link SVMClassificationModel}
     * @throws LensException declared by the overridden method
     */
    @Override
    protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
        throws LensException {
        // Argument order follows SVMWithSGD.train(input, numIterations, stepSize, regParam, miniBatchFraction).
        return new SVMClassificationModel(modelId,
            SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction));
    }
}
