package org.apache.spark.ml.r;

import java.util.Arrays;
import org.apache.spark.SparkException;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.feature.RFormula;
import org.apache.spark.ml.feature.RFormulaModel;
import org.apache.spark.ml.r.AFTSurvivalRegressionWrapper;
import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.util.MLReadable;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.LinearSeqOps;
import scala.collection.StringOps$;
import scala.collection.immutable.List;
import scala.reflect.ClassTag$;
import scala.util.matching.Regex;

/* compiled from: AFTSurvivalRegressionWrapper.scala */
/* loaded from: input_file:org/apache/spark/ml/r/AFTSurvivalRegressionWrapper$.class */
public final class AFTSurvivalRegressionWrapper$ implements MLReadable<AFTSurvivalRegressionWrapper> {
    public static final AFTSurvivalRegressionWrapper$ MODULE$ = new AFTSurvivalRegressionWrapper$();
    private static final Regex FORMULA_REGEXP;

    static {
        MLReadable.$init$(MODULE$);
        FORMULA_REGEXP = StringOps$.MODULE$.r$extension(Predef$.MODULE$.augmentString("Surv\\(([^,]+), ([^,]+)\\) ~ (.+)"));
    }

    private Regex FORMULA_REGEXP() {
        return FORMULA_REGEXP;
    }

    private Tuple2<String, String> formulaRewrite(String str) {
        if (str != null) {
            try {
                Option unapplySeq = FORMULA_REGEXP().unapplySeq(str);
                if (!unapplySeq.isEmpty() && unapplySeq.get() != null && ((List) unapplySeq.get()).lengthCompare(3) == 0) {
                    Tuple3 tuple3 = new Tuple3((String) ((LinearSeqOps) unapplySeq.get()).apply(0), (String) ((LinearSeqOps) unapplySeq.get()).apply(1), (String) ((LinearSeqOps) unapplySeq.get()).apply(2));
                    String str2 = (String) tuple3._1();
                    String str3 = (String) tuple3._2();
                    String str4 = (String) tuple3._3();
                    if (str4.contains(".")) {
                        throw new UnsupportedOperationException("Terms of survreg formula can not support dot operator.");
                    }
                    return new Tuple2<>(str2.trim() + "~" + str4.trim(), str3.trim());
                }
            } catch (MatchError e) {
                throw new SparkException("Could not parse formula: " + str);
            }
        }
        throw new MatchError(str);
    }

    /* JADX WARN: Multi-variable type inference failed */
    public AFTSurvivalRegressionWrapper fit(String str, Dataset<Row> dataset, int i, String str2) {
        Tuple2<String, String> formulaRewrite = formulaRewrite(str);
        if (formulaRewrite == null) {
            throw new MatchError(formulaRewrite);
        }
        Tuple2 tuple2 = new Tuple2((String) formulaRewrite._1(), (String) formulaRewrite._2());
        String str3 = (String) tuple2._1();
        String str4 = (String) tuple2._2();
        RFormula stringIndexerOrderType = new RFormula().setFormula(str3).setStringIndexerOrderType(str2);
        RWrapperUtils$.MODULE$.checkDataColumns(stringIndexerOrderType, dataset);
        RFormulaModel fit = stringIndexerOrderType.fit((Dataset<?>) dataset);
        return new AFTSurvivalRegressionWrapper(new Pipeline().setStages(new PipelineStage[]{fit, ((AFTSurvivalRegression) new AFTSurvivalRegression().setCensorCol(str4).setFitIntercept(stringIndexerOrderType.hasIntercept()).setFeaturesCol(stringIndexerOrderType.getFeaturesCol())).setAggregationDepth(i)}).fit((Dataset<?>) dataset), (String[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps((Attribute[]) AttributeGroup$.MODULE$.fromStructField(fit.transform(dataset).schema().apply(stringIndexerOrderType.getFeaturesCol())).attributes().get()), attribute -> {
            return (String) attribute.name().get();
        }, ClassTag$.MODULE$.apply(String.class)));
    }

    @Override // org.apache.spark.ml.util.MLReadable
    public MLReader<AFTSurvivalRegressionWrapper> read() {
        return new AFTSurvivalRegressionWrapper.AFTSurvivalRegressionWrapperReader();
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.ml.util.MLReadable
    public AFTSurvivalRegressionWrapper load(String str) {
        Object load;
        load = load(str);
        return (AFTSurvivalRegressionWrapper) load;
    }

    private AFTSurvivalRegressionWrapper$() {
    }
}
