package net.sansa_stack.examples.spark.ml.kge;

import net.sansa_stack.examples.spark.ml.kge.CrossValidation;
import net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation.Bootstrapping;
import net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation.Holdout;
import net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation.kFold;
import net.sansa_stack.rdf.spark.kge.convertor.ByIndex;
import net.sansa_stack.rdf.spark.kge.triples.Triples;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxedUnit;
import scopt.OptionParser;
import scopt.Read$;

/* compiled from: CrossValidation.scala */
/* loaded from: input_file:net/sansa_stack/examples/spark/ml/kge/CrossValidation$.class */
public final class CrossValidation$ {
    public static final CrossValidation$ MODULE$ = null;
    private final OptionParser<CrossValidation.Config> parser;

    static {
        new CrossValidation$();
    }

    public void main(String[] strArr) {
        Some parse = parser().parse(Predef$.MODULE$.wrapRefArray(strArr), new CrossValidation.Config(CrossValidation$Config$.MODULE$.apply$default$1(), CrossValidation$Config$.MODULE$.apply$default$2(), CrossValidation$Config$.MODULE$.apply$default$3()));
        if (parse instanceof Some) {
            CrossValidation.Config config = (CrossValidation.Config) parse.x();
            run(config.in(), config.technique(), config.k());
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            if (!None$.MODULE$.equals(parse)) {
                throw new MatchError(parse);
            }
            Predef$.MODULE$.println(parser().usage());
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
    }

    public void run(String str, String str2, int i) {
        Tuple2 crossValidation;
        SparkSession orCreate = SparkSession$.MODULE$.builder().appName(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Cross validation techniques example  ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str}))).master("local[*]").config("spark.serializer", "org.apache.spark.serializer.KryoSerializer").getOrCreate();
        Predef$.MODULE$.println("==============================================");
        Predef$.MODULE$.println("|Cross validation techniques example |");
        Predef$.MODULE$.println("==============================================");
        Triples triples = new Triples(str, "\t", false, false, orCreate);
        Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(triples.getEntities()).take(10)).foreach(new CrossValidation$$anonfun$run$1());
        Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(triples.getEntities()).take(10)).foreach(new CrossValidation$$anonfun$run$2());
        ByIndex byIndex = new ByIndex(triples.triples(), orCreate);
        Dataset numeric = byIndex.numeric();
        Predef$.MODULE$.refArrayOps((Object[]) byIndex.numeric().take(10)).foreach(new CrossValidation$$anonfun$run$3());
        if ("holdout".equals(str2)) {
            crossValidation = new Holdout(numeric, 0.6f).crossValidation();
        } else if ("bootstrapping".equals(str2)) {
            crossValidation = new Bootstrapping(numeric).crossValidation();
        } else {
            if (!"kFold".equals(str2)) {
                throw new RuntimeException(new StringBuilder().append("'").append(str2).append("' - Not supported, yet.").toString());
            }
            crossValidation = new kFold(numeric, i, orCreate).crossValidation();
        }
        Tuple2 tuple2 = crossValidation;
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Tuple2 tuple22 = new Tuple2(tuple2._1(), tuple2._2());
        tuple22._1();
        tuple22._2();
        Predef$.MODULE$.println("<< DONE >>");
        orCreate.stop();
    }

    public OptionParser<CrossValidation.Config> parser() {
        return this.parser;
    }

    private CrossValidation$() {
        MODULE$ = this;
        this.parser = new OptionParser<CrossValidation.Config>() { // from class: net.sansa_stack.examples.spark.ml.kge.CrossValidation$$anon$1
            {
                head(Predef$.MODULE$.wrapRefArray(new String[]{"Cross validation techniques example"}));
                opt('i', "input", Read$.MODULE$.stringRead()).required().valueName("<path>").action(new CrossValidation$$anon$1$$anonfun$1(this)).text("path to file that contains the data");
                opt('t', "technique", Read$.MODULE$.stringRead()).required().valueName("{holdout | bootstrapping | kFold}").action(new CrossValidation$$anon$1$$anonfun$2(this)).text("cross validation techniques");
                opt("k", Read$.MODULE$.intRead()).optional().valueName("<value>").action(new CrossValidation$$anon$1$$anonfun$3(this)).text("The k value (used only for technique'kFold')");
                checkConfig(new CrossValidation$$anon$1$$anonfun$4(this));
                help("help").text("prints this usage text");
            }
        };
    }
}
