package ai.h2o.sparkling.examples;

import ai.h2o.sparkling.ml.algos.H2OAutoML;
import ai.h2o.sparkling.ml.algos.H2ODeepLearning;
import ai.h2o.sparkling.ml.algos.H2OGBM;
import ai.h2o.sparkling.ml.algos.H2OGridSearch;
import ai.h2o.sparkling.ml.algos.H2OXGBoost;
import ai.h2o.sparkling.ml.features.ColumnPruner;
import java.io.File;
import org.apache.spark.h2o.H2OContext$;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.Pipeline$;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineModel$;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: HamOrSpamDemo.scala */
/* loaded from: input_file:ai/h2o/sparkling/examples/HamOrSpamDemo$.class */
/**
 * Ham-or-spam SMS classification demo built on Spark ML pipelines and Sparkling Water.
 *
 * <p>The data file is read as one message per line, in the form {@code label<TAB>text}. A text
 * pipeline is built from these stages: regex tokenizer, stop-word remover, hashing TF, IDF, an
 * H2O estimator, and a column pruner. For each of several H2O estimators, the pipeline is
 * trained and its predictions are checked on one ham message and one spam message.
 *
 * <p>This is the module class of a Scala {@code object}. The singleton instance is exposed
 * through {@link #MODULE$}.
 */
public final class HamOrSpamDemo$ {
    /** Singleton instance, assigned by the private constructor during class initialization. */
    public static HamOrSpamDemo$ MODULE$;

    static {
        new HamOrSpamDemo$();
    }

    /**
     * Runs the demo.
     *
     * <p>Loads {@code ./examples/smalldata/smsData.txt} and starts (or reuses) the H2O context.
     * Then trains and checks one pipeline per estimator: GBM, Deep Learning, AutoML, grid
     * search, and XGBoost.
     *
     * @param strArr command-line arguments (not used)
     */
    public void main(String[] strArr) {
        SparkSession spark = SparkSession$.MODULE$.builder().appName("Ham or Spam Pipeline Demo").getOrCreate();
        String dataPath = "file://" + new File("./examples/smalldata/smsData.txt").getAbsolutePath();
        Dataset<Row> data = load(spark, dataPath);
        H2OContext$.MODULE$.getOrCreate();

        // Shared feature-engineering stages; each stage reads the previous stage's output column.
        RegexTokenizer tokenizer = createTokenizer();
        StopWordsRemover stopWordsRemover = createStopWordsRemover(tokenizer);
        HashingTF hashingTF = createHashingTF(stopWordsRemover);
        IDF idf = createIDF(hashingTF);
        ColumnPruner columnPruner = createColumnPruner(idf, hashingTF, stopWordsRemover, tokenizer);

        Estimator[] estimators = {gbm(), deepLearning(), autoML(), gridSearch(), xgboost()};
        for (Estimator estimator : estimators) {
            $anonfun$main$1(tokenizer, stopWordsRemover, hashingTF, idf, columnPruner, spark, data, estimator);
        }
    }

    /**
     * Builds a pipeline from the given stages and round-trips it through disk.
     *
     * <p>The pipeline is saved to {@code examples/build/pipeline}, overwriting any existing
     * save, and then loaded back.
     *
     * @param pipelineStageArr stages in execution order
     * @return the pipeline loaded back from {@code examples/build/pipeline}
     */
    public Pipeline createPipeline(PipelineStage[] pipelineStageArr) {
        String path = "examples/build/pipeline";
        new Pipeline().setStages(pipelineStageArr).write().overwrite().save(path);
        return Pipeline$.MODULE$.load(path);
    }

    /**
     * Checks the trained model on two fixed messages.
     *
     * <p>The first message is expected to be ham and the second spam. Uses Scala's
     * {@code Predef.assert}, which throws {@link AssertionError} on failure.
     *
     * @param sparkSession session used to build the one-row input frames
     * @param pipelineModel trained pipeline to check
     */
    public void assertPredictions(SparkSession sparkSession, PipelineModel pipelineModel) {
        Predef$.MODULE$.assert(!isSpam(sparkSession, "Michal, h2oworld party tonight in MV?", pipelineModel));
        Predef$.MODULE$.assert(isSpam(sparkSession, "We tried to contact you re your reply to our offer of a Video Handset? 750 anytime any networks mins? UNLIMITED TEXT?", pipelineModel));
    }

    /**
     * Classifies a single message with the trained pipeline.
     *
     * @param sparkSession session used to build the input frame
     * @param str message text, placed in a one-row frame with a single non-nullable
     *     {@code text} column
     * @param pipelineModel trained pipeline; its output must contain a string {@code prediction}
     *     column
     * @return {@code true} if the predicted label equals {@code "spam"}; {@code false} otherwise,
     *     including when the prediction is {@code null}
     */
    public boolean isSpam(SparkSession sparkSession, String str, PipelineModel pipelineModel) {
        StructType schema = new StructType(new StructField[]{new StructField("text", StringType$.MODULE$, false, StructField$.MODULE$.apply$default$4())});
        Dataset<Row> input = sparkSession.createDataFrame(
            sparkSession.sparkContext()
                .parallelize(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{str})), sparkSession.sparkContext().parallelize$default$2(), ClassTag$.MODULE$.apply(String.class))
                .map(text -> Row$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Object[]{text})), ClassTag$.MODULE$.apply(Row.class)),
            schema);
        Row predictionRow = (Row) pipelineModel.transform(input).select("prediction", Predef$.MODULE$.wrapRefArray(new String[0])).first();
        String prediction = predictionRow.getString(0);
        // Null-safe comparison. The decompiled `"spam" == 0` compared a String with an int,
        // which is not valid Java.
        return "spam".equals(prediction);
    }

    /**
     * Loads the tab-separated SMS data set.
     *
     * <p>Each line is split at the first tab into {@code label} and {@code text}. Lines with an
     * empty label, or with no tab at all, are skipped.
     *
     * @param sparkSession session used to read the file
     * @param str path or URI of the text file
     * @return frame with non-nullable string columns {@code label} and {@code text}
     */
    public Dataset<Row> load(SparkSession sparkSession, String str) {
        StructType schema = new StructType(new StructField[]{
            new StructField("label", StringType$.MODULE$, false, StructField$.MODULE$.apply$default$4()),
            new StructField("text", StringType$.MODULE$, false, StructField$.MODULE$.apply$default$4())});
        return sparkSession.createDataFrame(
            sparkSession.sparkContext().textFile(str, sparkSession.sparkContext().textFile$default$2())
                // Limit 2: everything after the first tab is kept as the message text.
                .map(line -> line.split("\t", 2), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(String.class)))
                .filter(parts -> BoxesRunTime.boxToBoolean($anonfun$load$2(parts)))
                .map(parts -> Row$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Object[]{parts[0], parts[1]})), ClassTag$.MODULE$.apply(Row.class)),
            schema);
    }

    /**
     * Creates a tokenizer that reads {@code text} and writes {@code words}.
     *
     * <p>Tokens are matches of {@code [a-zA-Z]+} (gaps disabled) with a minimum length of 3.
     *
     * @return the configured tokenizer
     */
    public RegexTokenizer createTokenizer() {
        return new RegexTokenizer().setInputCol("text").setOutputCol("words").setMinTokenLength(3).setGaps(false).setPattern("[a-zA-Z]+");
    }

    /**
     * Creates a case-insensitive stop-word remover.
     *
     * <p>It reads the tokenizer's output column and writes {@code filtered}.
     *
     * @param regexTokenizer tokenizer whose output column is used as input
     * @return the configured stop-word remover
     */
    public StopWordsRemover createStopWordsRemover(RegexTokenizer regexTokenizer) {
        String[] stopWords = {"the", "a", "", "in", "on", "at", "as", "not", "for"};
        return new StopWordsRemover().setInputCol(regexTokenizer.getOutputCol()).setOutputCol("filtered").setStopWords(stopWords).setCaseSensitive(false);
    }

    /**
     * Creates a 1024-feature hashing term-frequency stage.
     *
     * <p>It reads the stop-word remover's output column and writes {@code wordToIndex}.
     *
     * @param stopWordsRemover stage whose output column is used as input
     * @return the configured hashing TF stage
     */
    public HashingTF createHashingTF(StopWordsRemover stopWordsRemover) {
        return new HashingTF().setNumFeatures(1024).setInputCol(stopWordsRemover.getOutputCol()).setOutputCol("wordToIndex");
    }

    /**
     * Creates an IDF stage with a minimum document frequency of 4.
     *
     * <p>It reads the hashing TF output column and writes {@code tf_idf}, the feature column
     * used by the H2O estimators.
     *
     * @param hashingTF stage whose output column is used as input
     * @return the configured IDF estimator
     */
    public IDF createIDF(HashingTF hashingTF) {
        return new IDF().setMinDocFreq(4).setInputCol(hashingTF.getOutputCol()).setOutputCol("tf_idf");
    }

    /**
     * Creates a column pruner configured with the intermediate feature columns.
     *
     * <p>The columns are the IDF, hashing TF, stop-word remover and tokenizer outputs. The
     * {@code keep} flag is not set here; presumably the pruner drops these columns.
     *
     * @param idf IDF stage
     * @param hashingTF hashing TF stage
     * @param stopWordsRemover stop-word remover stage
     * @param regexTokenizer tokenizer stage
     * @return the configured column pruner
     */
    public ColumnPruner createColumnPruner(IDF idf, HashingTF hashingTF, StopWordsRemover stopWordsRemover, RegexTokenizer regexTokenizer) {
        return new ColumnPruner().setColumns(new String[]{idf.getOutputCol(), hashingTF.getOutputCol(), stopWordsRemover.getOutputCol(), regexTokenizer.getOutputCol()});
    }

    /**
     * Fits the pipeline and round-trips the fitted model through disk.
     *
     * <p>NOTE(review): the model is saved under {@code build/examples/model}, while
     * {@link #createPipeline} saves under {@code examples/build/pipeline}. Confirm this
     * asymmetry is intended.
     *
     * @param sparkSession not used by this method
     * @param pipeline pipeline to fit
     * @param dataset training data
     * @return the fitted model loaded back from {@code build/examples/model}
     */
    public PipelineModel trainPipeline(SparkSession sparkSession, Pipeline pipeline, Dataset<Row> dataset) {
        String path = "build/examples/model";
        pipeline.fit(dataset).write().overwrite().save(path);
        return PipelineModel$.MODULE$.load(path);
    }

    /**
     * Creates an H2O GBM on {@code tf_idf} predicting {@code label}.
     *
     * <p>Split ratio is 0.8 and the seed is 1.
     *
     * @return the configured GBM estimator
     */
    public H2OGBM gbm() {
        return new H2OGBM().setSplitRatio(0.8d).setSeed(1L).setFeaturesCols("tf_idf", Predef$.MODULE$.wrapRefArray(new String[0])).setLabelCol("label");
    }

    /**
     * Creates an H2O Deep Learning estimator on {@code tf_idf} predicting {@code label}.
     *
     * <p>Configuration: 10 epochs, L1 0.001, L2 0, seed 1, two hidden layers of 200 units.
     *
     * @return the configured deep learning estimator
     */
    public H2ODeepLearning deepLearning() {
        int[] hiddenLayers = {200, 200};
        return new H2ODeepLearning().setEpochs(10.0d).setL1(0.001d).setL2(0.0d).setSeed(1L).setHidden(hiddenLayers).setFeaturesCols("tf_idf", Predef$.MODULE$.wrapRefArray(new String[0])).setLabelCol("label");
    }

    /**
     * Creates an H2O AutoML estimator predicting {@code label}.
     *
     * <p>Configuration: seed 1, maximum runtime of 6000 seconds, at most 3 models, and unknown
     * categorical levels converted to NA.
     *
     * @return the configured AutoML estimator
     */
    public H2OAutoML autoML() {
        return new H2OAutoML().setLabelCol("label").setSeed(1L).setMaxRuntimeSecs(6000.0d).setMaxModels(3).setConvertUnknownCategoricalLevelsToNa(true);
    }

    /**
     * Creates an H2O grid search over GBM.
     *
     * <p>The grid varies {@code _ntrees} over {1, 30}. The base GBM uses max depth 6 and seed 1,
     * trains on {@code tf_idf}, and predicts {@code label}.
     *
     * @return the configured grid search estimator
     */
    public H2OGridSearch gridSearch() {
        // Equivalent to Scala `"_ntrees" -> Array[AnyRef](1, 30)`; Java autoboxes the values to Integer.
        Tuple2<String, Object[]> ntreesGrid = new Tuple2<>("_ntrees", new Object[]{1, 30});
        H2OGBM baseAlgo = new H2OGBM().setMaxDepth(6).setSeed(1L).setFeaturesCols("tf_idf", Predef$.MODULE$.wrapRefArray(new String[0])).setLabelCol("label").setConvertUnknownCategoricalLevelsToNa(true);
        return new H2OGridSearch().setHyperParameters(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{ntreesGrid}))).setAlgo(baseAlgo);
    }

    /**
     * Creates an H2O XGBoost estimator on {@code tf_idf} predicting {@code label}.
     *
     * <p>Unknown categorical levels are converted to NA.
     *
     * @return the configured XGBoost estimator
     */
    public H2OXGBoost xgboost() {
        return new H2OXGBoost().setFeaturesCols("tf_idf", Predef$.MODULE$.wrapRefArray(new String[0])).setLabelCol("label").setConvertUnknownCategoricalLevelsToNa(true);
    }

    /**
     * Loop body of {@link #main}.
     *
     * <p>Assembles the pipeline stages in this order: tokenizer, stop-word remover, hashing TF,
     * IDF, estimator, column pruner. It then trains the pipeline and asserts its predictions.
     */
    public static final /* synthetic */ void $anonfun$main$1(RegexTokenizer regexTokenizer, StopWordsRemover stopWordsRemover, HashingTF hashingTF, IDF idf, ColumnPruner columnPruner, SparkSession sparkSession, Dataset dataset, Estimator estimator) {
        PipelineStage[] stages = {regexTokenizer, stopWordsRemover, hashingTF, idf, estimator, columnPruner};
        Pipeline pipeline = MODULE$.createPipeline(stages);
        PipelineModel model = MODULE$.trainPipeline(sparkSession, pipeline, dataset);
        MODULE$.assertPredictions(sparkSession, model);
    }

    /**
     * Row filter used by {@link #load}.
     *
     * <p>Keeps only lines that split into exactly two parts and have a non-empty label. A line
     * without a tab splits into a single element. Without the length check, such a line would
     * fail later when the text column ({@code [1]}) is read.
     */
    public static final /* synthetic */ boolean $anonfun$load$2(String[] strArr) {
        return strArr.length == 2 && !strArr[0].isEmpty();
    }

    private HamOrSpamDemo$() {
        MODULE$ = this;
    }
}
