package org.apache.spark.ml.api.r;

import java.util.NoSuchElementException;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.RFormula;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.sql.DataFrame;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.reflect.ClassTag$;

/* compiled from: SparkRWrappers.scala */
/* loaded from: input_file:org/apache/spark/ml/api/r/SparkRWrappers$.class */
public final class SparkRWrappers$ {
    public static final SparkRWrappers$ MODULE$ = null;

    static {
        new SparkRWrappers$();
    }

    public PipelineModel fitRModelFormula(String str, DataFrame dataFrame, String str2, double d, double d2) {
        Params fitIntercept;
        RFormula formula = new RFormula().setFormula(str);
        if ("gaussian".equals(str2)) {
            fitIntercept = new LinearRegression().setRegParam(d).setElasticNetParam(d2).setFitIntercept(formula.hasIntercept());
        } else {
            if (!"binomial".equals(str2)) {
                throw new MatchError(str2);
            }
            fitIntercept = new LogisticRegression().setRegParam(d).setElasticNetParam(d2).setFitIntercept(formula.hasIntercept());
        }
        return new Pipeline().setStages(new PipelineStage[]{formula, fitIntercept}).fit(dataFrame);
    }

    public double[] getModelWeights(PipelineModel pipelineModel) {
        Transformer transformer = (Transformer) Predef$.MODULE$.refArrayOps(pipelineModel.stages()).last();
        if (transformer instanceof LinearRegressionModel) {
            LinearRegressionModel linearRegressionModel = (LinearRegressionModel) transformer;
            return (double[]) Predef$.MODULE$.doubleArrayOps(new double[]{linearRegressionModel.intercept()}).$plus$plus(Predef$.MODULE$.doubleArrayOps(linearRegressionModel.weights().toArray()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        }
        if (transformer instanceof LogisticRegressionModel) {
            throw new UnsupportedOperationException("No weights available for LogisticRegressionModel");
        }
        throw new MatchError(transformer);
    }

    public String[] getModelFeatures(PipelineModel pipelineModel) {
        Transformer transformer = (Transformer) Predef$.MODULE$.refArrayOps(pipelineModel.stages()).last();
        if (transformer instanceof LinearRegressionModel) {
            LinearRegressionModel linearRegressionModel = (LinearRegressionModel) transformer;
            return (String[]) Predef$.MODULE$.refArrayOps(new String[]{"(Intercept)"}).$plus$plus(Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps((Object[]) AttributeGroup$.MODULE$.fromStructField(linearRegressionModel.summary().predictions().schema().apply(linearRegressionModel.summary().featuresCol())).attributes().get()).map(new SparkRWrappers$$anonfun$getModelFeatures$1(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)))), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
        }
        if (transformer instanceof LogisticRegressionModel) {
            throw new UnsupportedOperationException("No features names available for LogisticRegressionModel");
        }
        throw new MatchError(transformer);
    }

    private SparkRWrappers$() {
        MODULE$ = this;
    }
}
