package co.cask.cdap.examples.dtree;

import co.cask.cdap.api.TxRunnable;
import co.cask.cdap.api.data.DatasetContext;
import co.cask.cdap.api.dataset.lib.FileSet;
import co.cask.cdap.api.flow.flowlet.StreamEvent;
import co.cask.cdap.api.spark.SparkExecutionContext;
import co.cask.cdap.api.spark.SparkMain;
import co.cask.cdap.api.spark.SparkMain$Transaction$;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.spark_project.jetty.util.URIUtil;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Tuple2;
import scala.collection.SeqLike;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: DecisionTreeRegressionTrainer.scala */
@ScalaSignature(bytes = "\u0006\u000193A!\u0001\u0002\u0001\u001b\tiB)Z2jg&|g\u000e\u0016:fKJ+wM]3tg&|g\u000e\u0016:bS:,'O\u0003\u0002\u0004\t\u0005)A\r\u001e:fK*\u0011QAB\u0001\tKb\fW\u000e\u001d7fg*\u0011q\u0001C\u0001\u0005G\u0012\f\u0007O\u0003\u0002\n\u0015\u0005!1-Y:l\u0015\u0005Y\u0011AA2p\u0007\u0001\u00192\u0001\u0001\b\u0015!\ty!#D\u0001\u0011\u0015\u0005\t\u0012!B:dC2\f\u0017BA\n\u0011\u0005\u0019\te.\u001f*fMB\u0011QCG\u0007\u0002-)\u0011q\u0003G\u0001\u0006gB\f'o\u001b\u0006\u00033\u0019\t1!\u00199j\u0013\tYbCA\u0005Ta\u0006\u00148.T1j]\")Q\u0004\u0001C\u0001=\u00051A(\u001b8jiz\"\u0012a\b\t\u0003A\u0001i\u0011A\u0001\u0005\u0006E\u0001!\teI\u0001\u0004eVtGC\u0001\u0013(!\tyQ%\u0003\u0002'!\t!QK\\5u\u0011\u0015A\u0013\u0005q\u0001*\u0003\r\u0019Xm\u0019\t\u0003+)J!a\u000b\f\u0003+M\u0003\u0018M]6Fq\u0016\u001cW\u000f^5p]\u000e{g\u000e^3yi\u001e)QF\u0001E\u0001]\u0005iB)Z2jg&|g\u000e\u0016:fKJ+wM]3tg&|g\u000e\u0016:bS:,'\u000f\u0005\u0002!_\u0019)\u0011A\u0001E\u0001aM\u0019qFD\u0019\u0011\u0005=\u0011\u0014BA\u001a\u0011\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011\u0015ir\u0006\"\u00016)\u0005q\u0003bB\u001c0\u0005\u0004%I\u0001O\u0001\u0004\u0019>;U#A\u001d\u0011\u0005izT\"A\u001e\u000b\u0005qj\u0014!B:mMRR'\"\u0001 \u0002\u0007=\u0014x-\u0003\u0002Aw\t1Aj\\4hKJDaAQ\u0018!\u0002\u0013I\u0014\u0001\u0002'P\u000f\u0002Bq\u0001R\u0018\u0002\u0002\u0013%Q)A\u0006sK\u0006$'+Z:pYZ,G#\u0001$\u0011\u0005\u001dcU\"\u0001%\u000b\u0005%S\u0015\u0001\u00027b]\u001eT\u0011aS\u0001\u0005U\u00064\u0018-\u0003\u0002N\u0011\n1qJ\u00196fGR\u0004")
/* loaded from: input_file:co/cask/cdap/examples/dtree/DecisionTreeRegressionTrainer.class */
/*
 * Spark program that trains and records a decision-tree regression model. It:
 *   1. resolves the input path (the "labels" location inside the training FileSet) and the
 *      output path (the model FileSet base location) inside a transaction;
 *   2. loads the input as LIBSVM data, indexes categorical features, and splits it 70/30 into
 *      training and test sets;
 *   3. fits a VectorIndexer + DecisionTreeRegressor pipeline, then logs the test RMSE and the
 *      learned tree;
 *   4. saves the tree model under a random UUID in the model FileSet;
 *   5. writes a ModelMeta summary keyed by that UUID to the model-meta dataset, in a second
 *      transaction.
 *
 * NOTE(review): this file is decompiler (JADX) output of DecisionTreeRegressionTrainer.scala,
 * not hand-written Java. It does not compile as-is. Examples:
 *   - the "??" placeholder type;
 *   - final fields assigned from the _setter_ methods;
 *   - SparkMain.class.X(...) calls that stand in for Scala trait implementation methods;
 *   - "new TxRunnable(this, ...)" on an interface.
 * Make real changes in the Scala source.
 */
public class DecisionTreeRegressionTrainer implements SparkMain {
    // Stream decoders declared by the SparkMain trait. They are filled in through the _setter_
    // methods below, presumably by the trait initializer called from the constructor (verify
    // against the Scala trait encoding).
    private final Function1<StreamEvent, Tuple2<Object, String>> timestampStringStreamDecoder;
    private final Function1<StreamEvent, String> stringStreamDecoder;
    // Lazily created instance of the trait's Transaction object. It is volatile so the unlocked
    // fast-path read in Transaction() sees a fully published instance.
    private volatile SparkMain$Transaction$ Transaction$module;

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v5 */
    /**
     * Slow path of {@link #Transaction()}. Under a lock on {@code this}, it creates the
     * Transaction object if it does not exist yet, then returns it.
     *
     * <p>The {@code BoxedUnit} local and the {@code r0 = r0} self-assignment are decompiler
     * artifacts with no effect.
     *
     * @return the (possibly newly created) Transaction object
     */
    private SparkMain$Transaction$ Transaction$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (this.Transaction$module == null) {
                this.Transaction$module = new SparkMain$Transaction$(this);
            }
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            r0 = r0;
            return this.Transaction$module;
        }
    }

    /**
     * Returns the trait's Transaction object, creating it on first use.
     *
     * <p>This is double-checked locking: a volatile read without the lock, falling back to
     * {@link #Transaction$lzycompute()}.
     *
     * @return the Transaction object, never {@code null}
     */
    public SparkMain$Transaction$ Transaction() {
        return this.Transaction$module == null ? Transaction$lzycompute() : this.Transaction$module;
    }

    /** Returns the trait-provided decoder from a stream event to a (timestamp, string) pair. */
    public Function1<StreamEvent, Tuple2<Object, String>> timestampStringStreamDecoder() {
        return this.timestampStringStreamDecoder;
    }

    /** Returns the trait-provided decoder from a stream event to its string body. */
    public Function1<StreamEvent, String> stringStreamDecoder() {
        return this.stringStreamDecoder;
    }

    /**
     * Scala trait field setter (synthetic name) used to initialize
     * {@code timestampStringStreamDecoder}.
     */
    public void co$cask$cdap$api$spark$SparkMain$_setter_$timestampStringStreamDecoder_$eq(Function1 function1) {
        this.timestampStringStreamDecoder = function1;
    }

    /** Scala trait field setter (synthetic name) used to initialize {@code stringStreamDecoder}. */
    public void co$cask$cdap$api$spark$SparkMain$_setter_$stringStreamDecoder_$eq(Function1 function1) {
        this.stringStreamDecoder = function1;
    }

    /**
     * Wraps a pair RDD in the trait's {@code SparkProgramRDDFunctions} extension. The work is
     * delegated to the trait implementation.
     *
     * @param rdd the key/value RDD to wrap
     * @param classTag class tag for the key type
     * @param classTag2 class tag for the value type
     * @return the RDD extension wrapper
     */
    public <K, V> SparkMain.SparkProgramRDDFunctions<K, V> SparkProgramRDDFunctions(RDD<Tuple2<K, V>> rdd, ClassTag<K> classTag, ClassTag<V> classTag2) {
        return SparkMain.class.SparkProgramRDDFunctions(this, rdd, classTag, classTag2);
    }

    /**
     * Wraps a SparkContext in the trait's {@code SparkProgramContextFunctions} extension. The work
     * is delegated to the trait implementation.
     *
     * @param sparkContext the context to wrap
     * @return the context extension wrapper
     */
    public SparkMain.SparkProgramContextFunctions SparkProgramContextFunctions(SparkContext sparkContext) {
        return SparkMain.class.SparkProgramContextFunctions(this, sparkContext);
    }

    /**
     * Program entry point: trains, evaluates, saves and records a decision-tree regression
     * model (see the class comment for the overall flow).
     *
     * @param sparkExecutionContext CDAP context used to run the two dataset transactions
     */
    public void run(SparkExecutionContext sparkExecutionContext) {
        // Mutable holders for the input/output paths. They are written inside the transaction
        // below and read after it returns.
        final AtomicReference atomicReference = new AtomicReference();
        final AtomicReference atomicReference2 = new AtomicReference();
        // Transaction 1: resolve the input path (training FileSet base + "labels") and the output
        // path (model FileSet base).
        // NOTE(review): toURI().getPath() keeps only the path component (scheme and authority are
        // dropped), so Spark resolves these against its default filesystem. Confirm this is
        // intended.
        // The (this, a, b) arguments to the TxRunnable interface are a decompiler rendering of
        // the synthetic anonymous-class constructor; see the initializer block below.
        sparkExecutionContext.execute(new TxRunnable(this, atomicReference, atomicReference2) { // from class: co.cask.cdap.examples.dtree.DecisionTreeRegressionTrainer$$anon$1
            private final AtomicReference inputPath$1;
            private final AtomicReference outputPath$1;

            public void run(DatasetContext datasetContext) {
                FileSet dataset = datasetContext.getDataset(DecisionTreeRegressionApp.TRAINING_DATASET);
                FileSet dataset2 = datasetContext.getDataset(DecisionTreeRegressionApp.MODEL_DATASET);
                this.inputPath$1.set(dataset.getBaseLocation().append("labels").toURI().getPath());
                this.outputPath$1.set(dataset2.getBaseLocation().toURI().getPath());
            }

            // Captures the two path holders from the enclosing run() call.
            {
                this.inputPath$1 = atomicReference;
                this.outputPath$1 = atomicReference2;
            }
        });
        SparkSession orCreate = SparkSession$.MODULE$.builder().appName("DecisionTreeRegressionExample").getOrCreate();
        // Input is read in LIBSVM format from the path resolved in transaction 1.
        Dataset<Row> load = orCreate.read().format("libsvm").load((String) atomicReference.get());
        // The feature indexer is fitted on the full dataset, before the split, so every category
        // value is seen. maxCategories = 4 bounds which features Spark's VectorIndexer treats
        // as categorical.
        VectorIndexerModel fit = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit((Dataset<?>) load);
        // 70% training / 30% test split.
        // NOTE(review): no seed is passed, so the split (and every metric below) changes from run
        // to run.
        Dataset<Row>[] randomSplit = load.randomSplit(new double[]{0.7d, 0.3d});
        // Decompiled Scala pattern match "val Array(train, test) = randomSplit". It throws
        // MatchError unless the split produced exactly two datasets.
        Option unapplySeq = Array$.MODULE$.unapplySeq(randomSplit);
        if (unapplySeq.isEmpty() || unapplySeq.get() == null || ((SeqLike) unapplySeq.get()).lengthCompare(2) != 0) {
            throw new MatchError(randomSplit);
        }
        Tuple2 tuple2 = new Tuple2((Dataset) ((SeqLike) unapplySeq.get()).apply(0), (Dataset) ((SeqLike) unapplySeq.get()).apply(1));
        // dataset = training split, dataset2 = test split.
        Dataset<?> dataset = (Dataset) tuple2._1();
        Dataset<?> dataset2 = (Dataset) tuple2._2();
        // Pipeline stages: [0] the fitted feature indexer, [1] a decision-tree regressor on
        // "label" / "indexedFeatures". The pipeline is fitted on the training split only.
        PipelineModel fit2 = new Pipeline().setStages(new PipelineStage[]{fit, new DecisionTreeRegressor().setLabelCol("label").setFeaturesCol("indexedFeatures")}).fit(dataset);
        Dataset<Row> transform = fit2.transform(dataset2);
        // Number of test rows whose prediction exactly equals the label.
        // NOTE(review): for a regression model with continuous labels, exact equality is rarely
        // meaningful. Confirm this metric is what ModelMeta expects.
        long count = transform.where("prediction=label").count();
        // Total number of test rows.
        long count2 = transform.count();
        double evaluate = new RegressionEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("rmse").evaluate(transform);
        DecisionTreeRegressionTrainer$.MODULE$.co$cask$cdap$examples$dtree$DecisionTreeRegressionTrainer$$LOG().info(new StringBuilder().append("Root Mean Squared Error (RMSE) on test data = ").append(BoxesRunTime.boxToDouble(evaluate)).toString());
        // Stage index 1 is the trained tree (index 0 is the feature indexer).
        DecisionTreeRegressionModel decisionTreeRegressionModel = (DecisionTreeRegressionModel) fit2.stages()[1];
        DecisionTreeRegressionTrainer$.MODULE$.co$cask$cdap$examples$dtree$DecisionTreeRegressionTrainer$$LOG().info(new StringBuilder().append("Learned regression tree model:\n").append(decisionTreeRegressionModel.toDebugString()).toString());
        // The model is saved at <model FileSet path>/<random UUID>. The same UUID keys the
        // metadata record written below.
        // NOTE(review): URIUtil.SLASH comes from Spark's shaded jetty package. It is likely a
        // decompiler substitution for a literal "/" separator, so don't depend on it in
        // hand-written code.
        final String uuid = UUID.randomUUID().toString();
        decisionTreeRegressionModel.save(new StringBuilder().append((String) atomicReference2.get()).append(URIUtil.SLASH).append(uuid).toString());
        // ModelMeta arguments, in order: number of features seen by the indexer, total test rows,
        // exact-match test rows, test RMSE, training fraction.
        // NOTE(review): 0.7 repeats the split fraction above; keep the two in sync (ideally share
        // one constant).
        final ModelMeta modelMeta = new ModelMeta(fit.numFeatures(), count2, count, evaluate, 0.7d);
        // Transaction 2: write the metadata under the model's UUID to the MODEL_META dataset.
        // NOTE(review): the decompiler dropped the dataset's concrete type on getDataset(...); the
        // write(key, value) call presumes a key/value-style dataset. Verify against the app
        // definition.
        sparkExecutionContext.execute(new TxRunnable(this, uuid, modelMeta) { // from class: co.cask.cdap.examples.dtree.DecisionTreeRegressionTrainer$$anon$2
            private final String id$1;
            private final ModelMeta meta$1;

            public void run(DatasetContext datasetContext) {
                datasetContext.getDataset(DecisionTreeRegressionApp.MODEL_META).write(this.id$1, this.meta$1);
            }

            // Captures the model id and metadata from the enclosing run() call.
            {
                this.id$1 = uuid;
                this.meta$1 = modelMeta;
            }
        });
        // NOTE(review): stop() is not in a finally block, so it is skipped if any step above
        // throws.
        orCreate.stop();
    }

    /** Creates the program and runs the SparkMain trait initializer. */
    public DecisionTreeRegressionTrainer() {
        SparkMain.class.$init$(this);
    }
}
