package co.cask.mmds.modeler.param;

import co.cask.mmds.spec.DoubleParam;
import co.cask.mmds.spec.ParamSpec;
import co.cask.mmds.spec.Params;
import co.cask.mmds.spec.Range;
import co.cask.mmds.spec.StringParam;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import org.apache.spark.ml.classification.LogisticRegression;

/**
 * Modeler parameters for logistic regression.
 *
 * <p>Adds the {@code threshold} and {@code family} parameters on top of whatever
 * {@link RegressionParams} already contributes, and knows how to copy the resulting
 * values onto a Spark {@link LogisticRegression} estimator.
 */
public class LogisticRegressionParams extends RegressionParams {
    /** User-facing description of the {@code threshold} parameter. */
    private static final String THRESHOLD_DESCRIPTION =
        "Threshold in binary classification. If the estimated probability of class label 1 is greater than threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often. A low threshold encourages the model to predict 1 more often.";

    /** User-facing description of the {@code family} parameter. */
    private static final String FAMILY_DESCRIPTION =
        "Label distribution to be used in the model. 'auto' will automatically select the family based on the number of classes. If numClasses == 1 or numClasses == 2, sets to 'binomial'. Else, sets to 'multinomial'. 'binomial' uses binary logistic regression with pivoting. 'multinomial' uses multinomial logistic (softmax) regression without pivoting.";

    private final DoubleParam threshold;
    private final StringParam family;

    /**
     * Builds the parameter set from raw string settings.
     *
     * @param map raw parameter values; handed to the parent constructor and to each
     *            parameter declared by this class
     */
    public LogisticRegressionParams(Map<String, String> map) {
        super(map);

        // Threshold defaults to 0.5 and is constrained by a range spanning 0.0 to 1.0.
        Range thresholdRange = new Range(Double.valueOf(0.0d), Double.valueOf(1.0d), true, true);
        threshold = new DoubleParam(
            "threshold", "Threshold", THRESHOLD_DESCRIPTION, 0.5d, thresholdRange, map);

        // Family defaults to "auto"; only the three listed values are offered.
        family = new StringParam(
            "family", "Family", FAMILY_DESCRIPTION, "auto",
            ImmutableSet.of("auto", "binomial", "multinomial"), map);
    }

    /**
     * Returns the parent's parameter specs with this class's {@code threshold} and
     * {@code family} parameters added via {@link Params#addParams}.
     */
    @Override
    public List<ParamSpec> getSpec() {
        List<ParamSpec> inherited = super.getSpec();
        return Params.addParams(inherited, threshold, family);
    }

    /**
     * Returns the parent's parameter map with this class's {@code threshold} and
     * {@code family} parameters added via {@link Params#putParams}.
     */
    @Override
    public Map<String, String> toMap() {
        Map<String, String> inherited = super.toMap();
        return Params.putParams(inherited, threshold, family);
    }

    /**
     * Applies every parameter value held by this object to the given estimator.
     *
     * <p>The inherited regression settings are applied first, followed by the
     * logistic-regression-specific {@code threshold} and {@code family}.
     *
     * @param logisticRegression the Spark estimator to configure
     */
    public void setParams(LogisticRegression logisticRegression) {
        // Spark's setters return the estimator itself, so the calls are chained;
        // the order of the calls matches the order of the parameters listed above.
        logisticRegression
            .setMaxIter(maxIterations.getVal().intValue())
            .setStandardization(standardization.getVal().booleanValue())
            .setRegParam(regularizationParam.getVal().doubleValue())
            .setElasticNetParam(elasticNetParam.getVal().doubleValue())
            .setTol(tolerance.getVal().doubleValue())
            .setThreshold(threshold.getVal().doubleValue())
            .setFamily(family.getVal());
    }
}
