package com.microsoft.azure.synapse.ml.vw;

import com.microsoft.azure.synapse.ml.build.BuildInfo$;
import com.microsoft.azure.synapse.ml.core.env.FileUtilities$;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.Transformer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import scala.Predef$;

/* compiled from: VWContextualBandidSpec.scala */
/* loaded from: input_file:com/microsoft/azure/synapse/ml/vw/CBDatasetHelper$.class */
/**
 * Singleton helper (Scala-object layout, reachable through {@link #MODULE$}) that reads the
 * contextual-bandit training CSV and pushes it through a featurization pipeline.
 */
public final class CBDatasetHelper$ {
    /** The single instance; assigned by the private constructor when the class is initialized. */
    public static CBDatasetHelper$ MODULE$;

    static {
        // Instantiating the class is what populates MODULE$ (see the constructor).
        new CBDatasetHelper$();
    }

    private CBDatasetHelper$() {
        MODULE$ = this;
    }

    /**
     * Reads a delimited text file with a header row and schema inference enabled.
     *
     * @param sparkSession session used to create the reader
     * @param str          name inspected for a {@code ".csv"} suffix to pick the delimiter
     *                     (comma for {@code ".csv"}, tab otherwise)
     * @param str2         location handed to the reader's {@code csv(...)} call
     * @return the loaded rows
     */
    public Dataset<Row> readCSV(SparkSession sparkSession, String str, String str2) {
        return sparkSession.read()
            .option("header", "true")
            .option("inferSchema", "true")
            .option("treatEmptyValuesAsNulls", "false")
            .option("delimiter", delimiterFor(str))
            .csv(str2);
    }

    /**
     * Loads {@code VowpalWabbit/Train/cbdata.train.csv} from the build's dataset directory,
     * normalizes the action/cost/probability columns, and returns the result of fitting and
     * applying the featurization pipeline to that same data.
     *
     * @param sparkSession session used to read the CSV
     * @return the transformed dataset, including the {@code shared} and {@code features} columns
     */
    public Dataset<Row> getCBDataset(SparkSession sparkSession) {
        String trainFileName = "cbdata.train.csv";
        String trainFilePath = FileUtilities$.MODULE$
            .join(
                BuildInfo$.MODULE$.datasetDir(),
                Predef$.MODULE$.wrapRefArray(new String[]{"VowpalWabbit", "Train", trainFileName}))
            .toString();

        // Everything downstream operates on a single partition.
        Dataset<Row> loaded = readCSV(sparkSession, trainFileName, trainFilePath).repartition(1);

        // Cast each raw column in place first, then give it the name used downstream.
        Dataset<Row> withAction = loaded
            .withColumn("chosen_action", functions$.MODULE$.col("chosen_action").cast("Int"))
            .withColumnRenamed("chosen_action", "chosenAction");
        Dataset<Row> withLabel = withAction
            .withColumn("cost", functions$.MODULE$.col("cost").cast("Double"))
            .withColumnRenamed("cost", "label");
        Dataset<Row> prepared = withLabel
            .withColumn("prob", functions$.MODULE$.col("prob").cast("Double"))
            .withColumnRenamed("prob", "probability");

        Pipeline pipeline = new Pipeline();
        Transformer[] stages = new Transformer[]{
            // Context features shared by every action.
            featurizer(
                new String[]{"shared_id", "shared_major", "shared_hobby", "shared_fav_character"},
                "shared"),
            // One featurizer per candidate action.
            featurizer(new String[]{"action1_topic"}, "action1_features"),
            featurizer(new String[]{"action2_topic"}, "action2_features"),
            featurizer(new String[]{"action3_topic"}, "action3_features"),
            featurizer(new String[]{"action4_topic"}, "action4_features"),
            featurizer(new String[]{"action5_topic"}, "action5_features"),
            // Combine the per-action outputs into the single "features" column.
            (VectorZipper) new VectorZipper()
                .setInputCols(new String[]{
                    "action1_features",
                    "action2_features",
                    "action3_features",
                    "action4_features",
                    "action5_features"})
                .setOutputCol("features")
        };

        // The pipeline is fitted on, and then applied to, the same prepared dataset.
        return pipeline.setStages(stages).fit(prepared).transform(prepared);
    }

    /** Returns "," when the given name ends with ".csv", otherwise a tab. */
    private static String delimiterFor(String fileName) {
        if (fileName.endsWith(".csv")) {
            return ",";
        }
        return "\t";
    }

    /** Builds a featurizer that reads {@code inputCols} and writes {@code outputCol}. */
    private static VowpalWabbitFeaturizer featurizer(String[] inputCols, String outputCol) {
        return (VowpalWabbitFeaturizer) new VowpalWabbitFeaturizer()
            .setInputCols(inputCols)
            .setOutputCol(outputCol);
    }
}
