package streaming.dsl.mmlib.algs.python;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.Params;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.mlsql.session.MLSQLException;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import streaming.dsl.MLSQLExecuteContext;
import streaming.dsl.ScriptSQLExec$;
import streaming.dsl.mmlib.algs.Functions;
import streaming.dsl.mmlib.algs.MetricValue;
import streaming.dsl.mmlib.algs.SQLPythonFunc$;
import streaming.dsl.mmlib.algs.SQlBaseFunc;
import streaming.log.WowLog;
import tech.mlsql.common.utils.env.python.BasicCondaEnvManager$;
import tech.mlsql.common.utils.hdfs.HDFSOperator$;
import tech.mlsql.common.utils.log.Logging;

/* compiled from: PythonTrain.scala */
@ScalaSignature(bytes = "\u0006\u0001M3A!\u0001\u0002\u0001\u001b\tY\u0001+\u001f;i_:$&/Y5o\u0015\t\u0019A!\u0001\u0004qsRDwN\u001c\u0006\u0003\u000b\u0019\tA!\u00197hg*\u0011q\u0001C\u0001\u0006[6d\u0017N\u0019\u0006\u0003\u0013)\t1\u0001Z:m\u0015\u0005Y\u0011!C:ue\u0016\fW.\u001b8h\u0007\u0001\u0019B\u0001\u0001\b\u00151A\u0011qBE\u0007\u0002!)\t\u0011#A\u0003tG\u0006d\u0017-\u0003\u0002\u0014!\t1\u0011I\\=SK\u001a\u0004\"!\u0006\f\u000e\u0003\u0011I!a\u0006\u0003\u0003\u0013\u0019+hn\u0019;j_:\u001c\bCA\b\u001a\u0013\tQ\u0002C\u0001\u0007TKJL\u0017\r\\5{C\ndW\rC\u0003\u001d\u0001\u0011\u0005Q$\u0001\u0004=S:LGO\u0010\u000b\u0002=A\u0011q\u0004A\u0007\u0002\u0005!)\u0011\u0005\u0001C\u0001E\u0005\u0019BO]1j]~\u0003XM]0qCJ$\u0018\u000e^5p]R!1%P I!\t!#H\u0004\u0002&o9\u0011a\u0005\u000e\b\u0003OEr!\u0001\u000b\u0018\u000f\u0005%bS\"\u0001\u0016\u000b\u0005-b\u0011A\u0002\u001fs_>$h(C\u0001.\u0003\ry'oZ\u0005\u0003_A\na!\u00199bG\",'\"A\u0017\n\u0005I\u001a\u0014!B:qCJ\\'BA\u00181\u0013\t)d'A\u0002tc2T!AM\u001a\n\u0005aJ\u0014a\u00029bG.\fw-\u001a\u0006\u0003kYJ!a\u000f\u001f\u0003\u0013\u0011\u000bG/\u0019$sC6,'B\u0001\u001d:\u0011\u0015q\u0004\u00051\u0001$\u0003\t!g\rC\u0003AA\u0001\u0007\u0011)\u0001\u0003qCRD\u0007C\u0001\"F\u001d\ty1)\u0003\u0002E!\u00051\u0001K]3eK\u001aL!AR$\u0003\rM#(/\u001b8h\u0015\t!\u0005\u0003C\u0003JA\u0001\u0007!*\u0001\u0004qCJ\fWn\u001d\t\u0005\u0005.\u000b\u0015)\u0003\u0002M\u000f\n\u0019Q*\u00199\t\u000b9\u0003A\u0011A(\u0002\u000bQ\u0014\u0018-\u001b8\u0015\t\r\u0002\u0016K\u0015\u0005\u0006}5\u0003\ra\t\u0005\u0006\u00016\u0003\r!\u0011\u0005\u0006\u00136\u0003\rA\u0013")
/* loaded from: input_file:streaming/dsl/mmlib/algs/python/PythonTrain.class */
/*
 * Java rendering of the Scala class PythonTrain (see the "compiled from" note above).
 *
 * The class exposes two training entry points, train(...) and train_per_partition(...).
 * Every other public method is a one-line forwarder to a static helper that receives
 * `this` as its first argument.
 *
 * NOTE(review): receivers spelled `WowLog.class`, `Logging.class` and `SQlBaseFunc.class`
 * appear to be the decompiler's rendering of Scala trait implementation classes; they are
 * left exactly as found because the real target names are not visible in this file.
 */
public class PythonTrain implements Functions {
    // Storage behind the Logging accessor pair below; transient, so it is skipped by Java serialization.
    private transient Logger tech$mlsql$common$utils$log$Logging$$log_;

    /** Forwards to {@code Functions.Cclass.pythonCheckRequirements}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public void pythonCheckRequirements(Dataset<Row> dataset) {
        Functions.Cclass.pythonCheckRequirements(this, dataset);
    }

    /** Forwards to {@code Functions.Cclass.emptyDataFrame} (dataset overload). */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Dataset<Row> emptyDataFrame(Dataset<Row> dataset) {
        return Functions.Cclass.emptyDataFrame(this, dataset);
    }

    /** Forwards to {@code Functions.Cclass.emptyDataFrame} (session + string overload). */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Dataset<Row> emptyDataFrame(SparkSession sparkSession, String str) {
        return Functions.Cclass.emptyDataFrame(this, sparkSession, str);
    }

    /** Forwards to {@code Functions.Cclass.sampleUnbalanceWithMultiModel}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public void sampleUnbalanceWithMultiModel(Dataset<Row> dataset, String str, Map<String, String> map, Function2<Dataset<Row>, Object, BoxedUnit> function2) {
        Functions.Cclass.sampleUnbalanceWithMultiModel(this, dataset, str, map, function2);
    }

    /** Forwards to {@code Functions.Cclass.configureModel}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Object[] configureModel(Params params, Map<String, String> map) {
        return Functions.Cclass.configureModel(this, params, map);
    }

    /** Forwards to {@code Functions.Cclass.mapParams}; used below to pull "kafkaParam", "fitParam" and "systemParam" groups. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Map<String, String> mapParams(String str, Map<String, String> map) {
        return Functions.Cclass.mapParams(this, str, map);
    }

    /** Forwards to {@code Functions.Cclass.arrayParams}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Map<String, String>[] arrayParams(String str, Map<String, String> map) {
        return Functions.Cclass.arrayParams(this, str, map);
    }

    /** Forwards to {@code Functions.Cclass.arrayParamsWithIndex}; used by {@code train} for the "fitParam" groups. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Tuple2<Object, Map<String, String>>[] arrayParamsWithIndex(String str, Map<String, String> map) {
        return Functions.Cclass.arrayParamsWithIndex(this, str, map);
    }

    /** Forwards to {@code Functions.Cclass.getModelConstructField}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Object getModelConstructField(Object obj, String str, String str2) {
        return Functions.Cclass.getModelConstructField(this, obj, str, str2);
    }

    /** Forwards to {@code Functions.Cclass.getModelField}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Object getModelField(Object obj, String str) {
        return Functions.Cclass.getModelField(this, obj, str);
    }

    /** Forwards to {@code Functions.Cclass.loadModels}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public ArrayBuffer<Object> loadModels(String str, Function1<String, Object> function1) {
        return Functions.Cclass.loadModels(this, str, function1);
    }

    /** Forwards to {@code Functions.Cclass.trainModels}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public <T extends Model<T>> void trainModels(Dataset<Row> dataset, String str, Map<String, String> map, Function0<Params> function0) {
        Functions.Cclass.trainModels(this, dataset, str, map, function0);
    }

    /** Forwards to {@code Functions.Cclass.trainModelsWithMultiParamGroup2}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public void trainModelsWithMultiParamGroup2(Dataset<Row> dataset, String str, Map<String, String> map, Function0<Params> function0, Function2<Params, Map<String, String>, List<MetricValue>> function2) {
        Functions.Cclass.trainModelsWithMultiParamGroup2(this, dataset, str, map, function0, function2);
    }

    /** Forwards to {@code Functions.Cclass.trainModelsWithMultiParamGroup}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public <T extends Model<T>> void trainModelsWithMultiParamGroup(Dataset<Row> dataset, String str, Map<String, String> map, Function0<Params> function0, Function2<Params, Map<String, String>, List<MetricValue>> function2) {
        Functions.Cclass.trainModelsWithMultiParamGroup(this, dataset, str, map, function0, function2);
    }

    /** Forwards to {@code Functions.Cclass.predict_classification}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public UserDefinedFunction predict_classification(SparkSession sparkSession, Object obj, String str) {
        return Functions.Cclass.predict_classification(this, sparkSession, obj, str);
    }

    /**
     * Forwards to {@code Functions.Cclass.writeKafka}. Both training entry points call this when
     * "enableDataLocal" is false and use the returned map and the RDD's partition count.
     */
    @Override // streaming.dsl.mmlib.algs.Functions
    public Tuple2<Map<String, String>, RDD<byte[]>> writeKafka(Dataset<Row> dataset, String str, Map<String, String> map) {
        return Functions.Cclass.writeKafka(this, dataset, str, map);
    }

    /** Forwards to {@code Functions.Cclass.createTempModelLocalPath}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public String createTempModelLocalPath(String str, boolean z) {
        return Functions.Cclass.createTempModelLocalPath(this, str, z);
    }

    /** Forwards to {@code Functions.Cclass.distributeResource}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public boolean distributeResource(SparkSession sparkSession, String str, String str2) {
        return Functions.Cclass.distributeResource(this, sparkSession, str, str2);
    }

    /** Scala default-argument accessor for the second parameter of {@code createTempModelLocalPath}. */
    @Override // streaming.dsl.mmlib.algs.Functions
    public boolean createTempModelLocalPath$default$2() {
        return Functions.Cclass.createTempModelLocalPath$default$2(this);
    }

    /** Forwards to the WowLog {@code format} helper. */
    public String format(String str, boolean z) {
        return WowLog.class.format(this, str, z);
    }

    /** Forwards to the WowLog {@code wow_format} helper. */
    public String wow_format(String str) {
        return WowLog.class.wow_format(this, str);
    }

    /** Forwards to the WowLog {@code format_exception} helper. */
    public String format_exception(Exception exc) {
        return WowLog.class.format_exception(this, exc);
    }

    /** Forwards to the WowLog {@code format_throwable} helper. */
    public String format_throwable(Throwable th, boolean z) {
        return WowLog.class.format_throwable(this, th, z);
    }

    /** Forwards to the WowLog {@code format_cause} helper. */
    public String format_cause(Exception exc) {
        return WowLog.class.format_cause(this, exc);
    }

    /** Forwards to the WowLog {@code format_full_exception} helper, which receives the buffer to work on. */
    public void format_full_exception(ArrayBuffer<String> arrayBuffer, Exception exc, boolean z) {
        WowLog.class.format_full_exception(this, arrayBuffer, exc, z);
    }

    /** Scala default-argument accessor for the second parameter of {@code format}. */
    public boolean format$default$2() {
        return WowLog.class.format$default$2(this);
    }

    /** Scala default-argument accessor for the second parameter of {@code format_throwable}. */
    public boolean format_throwable$default$2() {
        return WowLog.class.format_throwable$default$2(this);
    }

    /** Scala default-argument accessor for the third parameter of {@code format_full_exception}. */
    public boolean format_full_exception$default$3() {
        return WowLog.class.format_full_exception$default$3(this);
    }

    /** Returns the logger field as currently stored (may be unset). */
    public Logger tech$mlsql$common$utils$log$Logging$$log_() {
        return this.tech$mlsql$common$utils$log$Logging$$log_;
    }

    /** Stores the given logger in the backing field. */
    public void tech$mlsql$common$utils$log$Logging$$log__$eq(Logger logger) {
        this.tech$mlsql$common$utils$log$Logging$$log_ = logger;
    }

    /** Forwards to the Logging {@code logName} helper. */
    public String logName() {
        return Logging.class.logName(this);
    }

    /** Forwards to the Logging {@code log} helper. */
    public Logger log() {
        return Logging.class.log(this);
    }

    /** Forwards to the Logging {@code logInfo} helper; the message is passed as a {@code Function0}. */
    public void logInfo(Function0<String> function0) {
        Logging.class.logInfo(this, function0);
    }

    /** Forwards to the Logging {@code logDebug} helper. */
    public void logDebug(Function0<String> function0) {
        Logging.class.logDebug(this, function0);
    }

    /** Forwards to the Logging {@code logTrace} helper. */
    public void logTrace(Function0<String> function0) {
        Logging.class.logTrace(this, function0);
    }

    /** Forwards to the Logging {@code logWarning} helper. */
    public void logWarning(Function0<String> function0) {
        Logging.class.logWarning(this, function0);
    }

    /** Forwards to the Logging {@code logError} helper. */
    public void logError(Function0<String> function0) {
        Logging.class.logError(this, function0);
    }

    /** Forwards to the Logging {@code logInfo} helper, throwable overload. */
    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.class.logInfo(this, function0, th);
    }

    /** Forwards to the Logging {@code logDebug} helper, throwable overload. */
    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.class.logDebug(this, function0, th);
    }

    /** Forwards to the Logging {@code logTrace} helper, throwable overload. */
    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.class.logTrace(this, function0, th);
    }

    /** Forwards to the Logging {@code logWarning} helper, throwable overload. */
    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.class.logWarning(this, function0, th);
    }

    /** Forwards to the Logging {@code logError} helper, throwable overload. */
    public void logError(Function0<String> function0, Throwable th) {
        Logging.class.logError(this, function0, th);
    }

    /** Forwards to the Logging {@code isTraceEnabled} helper. */
    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled(this);
    }

    /** Forwards to the Logging {@code initializeLogIfNecessary} helper. */
    public void initializeLogIfNecessary(boolean z) {
        Logging.class.initializeLogIfNecessary(this, z);
    }

    /** Forwards to the SQlBaseFunc {@code saveTraningParams} helper (spelling as in the upstream API). */
    public void saveTraningParams(SparkSession sparkSession, Map<String, String> map, String str) {
        SQlBaseFunc.class.saveTraningParams(this, sparkSession, map, str);
    }

    /** Forwards to the SQlBaseFunc {@code getTranningParams} helper (spelling as in the upstream API). */
    public Tuple2<Map<String, String>, Dataset<Tuple2<String, String>>> getTranningParams(SparkSession sparkSession, String str) {
        return SQlBaseFunc.class.getTranningParams(this, sparkSession, str);
    }

    /** Forwards to the SQlBaseFunc {@code cleanly} helper. */
    public <A, B> Option<B> cleanly(Function0<A> function0, Function1<A, BoxedUnit> function1, Function1<A, B> function12) {
        return SQlBaseFunc.class.cleanly(this, function0, function1, function12);
    }

    /**
     * Runs the per-partition closure over the JSON form of {@code dataset} and persists the outcome.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Reads the boolean switches "keepVersion", "keepLocalDirectory" and "enableDataLocal" and the
     *       optional "partitionKey" from {@code map} (defaults come from closures not visible here).</li>
     *   <li>When "enableDataLocal" is false, requires a non-empty "kafkaParam" group, calls
     *       {@link #writeKafka} and keeps the returned map and the RDD's partition count.</li>
     *   <li>Loads the Python project, increments the version for {@code str}, and prepares the
     *       "fitParam" / "systemParam" derived configuration.</li>
     *   <li>When "keepVersion" is false, validates {@code str} and deletes the existing model directory.</li>
     *   <li>Writes the closure's rows (Overwrite, parquet) under the meta path + "/0" and the training
     *       parameters under the meta path + "/1".</li>
     * </ol>
     *
     * @param dataset input data; converted with {@code toJSON().rdd()}
     * @param str     model path handed to {@code SQLPythonFunc$} path helpers
     * @param map     training parameters
     * @return the dataset read back from the meta path + "/0"
     * @throws MLSQLException when "keepVersion" is false and {@code str} contains "..", equals "/",
     *                        or splits on "/" into fewer than three parts
     */
    public Dataset<Row> train_per_partition(Dataset<Row> dataset, String str, Map<String, String> map) {
        boolean keepVersion = new StringOps(Predef$.MODULE$.augmentString((String) map.getOrElse("keepVersion", new PythonTrain$$anonfun$1(this)))).toBoolean();
        boolean keepLocalDirectory = new StringOps(Predef$.MODULE$.augmentString((String) map.getOrElse("keepLocalDirectory", new PythonTrain$$anonfun$2(this)))).toBoolean();
        Option partitionKey = map.get("partitionKey");
        // Held in refs because the values are reassigned below and then handed to the closure.
        ObjectRef kafkaParamRef = ObjectRef.create(mapParams("kafkaParam", map));
        boolean enableDataLocal = new StringOps(Predef$.MODULE$.augmentString((String) map.getOrElse("enableDataLocal", new PythonTrain$$anonfun$3(this)))).toBoolean();
        IntRef partitionCountRef = IntRef.create(-1);
        if (!enableDataLocal) {
            Predef$.MODULE$.require(((Map) kafkaParamRef.elem).size() > 0, new PythonTrain$$anonfun$train_per_partition$1(this));
            Tuple2<Map<String, String>, RDD<byte[]>> kafkaResult = writeKafka(dataset, str, map);
            if (kafkaResult == null) {
                throw new MatchError(kafkaResult);
            }
            partitionCountRef.elem = kafkaResult._2().getNumPartitions();
            kafkaParamRef.elem = kafkaResult._1();
        }
        Option<PythonScript> projectOpt = PythonAlgProject$.MODULE$.loadProject(map, dataset.sparkSession());
        SQLPythonFunc$.MODULE$.incrementVersion(str, keepVersion);
        MLSQLExecuteContext context = ScriptSQLExec$.MODULE$.contextGetOrForTest();
        // Option.get() throws when no project was loaded; same failure point as before.
        PythonScript script = projectOpt.get();
        Option filePath = Option$.MODULE$.apply(script.filePath());
        String projectName = script.projectName();
        PythonScriptType scriptType = script.scriptType();
        ObjectRef fitParamRef = ObjectRef.create(mapParams("fitParam", map));
        fitParamRef.elem = cleanNumber$1(fitParamRef);
        Map<String, String> systemParam = mapParams("systemParam", map);
        // Result is not used in this method; the call is kept in case it has side effects.
        MLFlowConfig$.MODULE$.buildFromSystemParam(systemParam);
        PythonConfig pythonConfig = PythonConfig$.MODULE$.buildFromSystemParam(systemParam);
        Map envs = EnvConfig$.MODULE$.buildFromSystemParam(systemParam).$plus$plus(
                Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{
                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                                Predef$.MODULE$.ArrowAssoc(BasicCondaEnvManager$.MODULE$.MLSQL_INSTNANCE_NAME_KEY()),
                                dataset.sparkSession().sparkContext().getConf().get("spark.app.name"))})));
        if (!keepVersion) {
            // Guard the delete below against traversal, the root, and shallow paths.
            if (str.contains("..") || "/".equals(str) || str.split("/").length < 3) {
                throw new MLSQLException("path should at least three layer");
            }
            HDFSOperator$.MODULE$.deleteDir(SQLPythonFunc$.MODULE$.getAlgModelPath(str, keepVersion));
        }
        RDD jsonRdd = dataset.toJSON().rdd();
        dataset.sparkSession()
                .createDataFrame(
                        jsonRdd.mapPartitionsWithIndex(
                                new PythonTrain$$anonfun$4(this, str, keepVersion, keepLocalDirectory, partitionKey, kafkaParamRef, enableDataLocal, partitionCountRef, projectOpt, context, filePath, projectName, scriptType, fitParamRef, envs),
                                jsonRdd.mapPartitionsWithIndex$default$2(),
                                ClassTag$.MODULE$.apply(Row.class)),
                        PythonTrainingResultSchema$.MODULE$.algSchema())
                .write()
                .mode(SaveMode.Overwrite)
                .parquet(new StringBuilder().append(SQLPythonFunc$.MODULE$.getAlgMetalPath(str, keepVersion)).append("/0").toString());
        // Second write: a one-element RDD holding Seq(Map(pythonPath, pythonVer), map), mapped to rows by the closure.
        dataset.sparkSession()
                .createDataFrame(
                        dataset.sparkSession().sparkContext().parallelize(
                                Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Seq[]{
                                        (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Map[]{
                                                (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{
                                                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("pythonPath"), pythonConfig.pythonPath()),
                                                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("pythonVer"), pythonConfig.pythonVer())})),
                                                map}))})),
                                1,
                                ClassTag$.MODULE$.apply(Seq.class))
                                .map(new PythonTrain$$anonfun$6(this), ClassTag$.MODULE$.apply(Row.class)),
                        PythonTrainingResultSchema$.MODULE$.trainParamsSchema())
                .write()
                .mode(SaveMode.Overwrite)
                .parquet(new StringBuilder().append(SQLPythonFunc$.MODULE$.getAlgMetalPath(str, keepVersion)).append("/1").toString());
        return dataset.sparkSession().read().parquet(new StringBuilder().append(SQLPythonFunc$.MODULE$.getAlgMetalPath(str, keepVersion)).append("/0").toString());
    }

    /**
     * Runs the training closure once per "fitParam" group and persists the outcome.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Reads "keepVersion" and "keepLocalDirectory" from {@code map}; the local-data switch, the
     *       saved-data string and the broadcast validation table come from a {@code DataManager}.</li>
     *   <li>When data is not local, requires a non-empty "kafkaParam" group, calls {@link #writeKafka}
     *       and keeps the returned map and the RDD's partition count.</li>
     *   <li>Collects the indexed "fitParam" groups; if there are none, logs a warning and falls back to a
     *       single group with index 0 and an empty map.</li>
     *   <li>Parallelizes the groups with one slice per group, loads the Python project and increments the
     *       version for {@code str}.</li>
     *   <li>Writes the closure's rows (Overwrite, parquet) under the meta path + "/0" and the training
     *       parameters under the meta path + "/1".</li>
     * </ol>
     *
     * @param dataset input data
     * @param str     model path handed to {@code SQLPythonFunc$} path helpers
     * @param map     training parameters
     * @return the dataset read back from the meta path + "/0"
     */
    public Dataset<Row> train(Dataset<Row> dataset, String str, Map<String, String> map) {
        boolean keepVersion = new StringOps(Predef$.MODULE$.augmentString((String) map.getOrElse("keepVersion", new PythonTrain$$anonfun$7(this)))).toBoolean();
        boolean keepLocalDirectory = new StringOps(Predef$.MODULE$.augmentString((String) map.getOrElse("keepLocalDirectory", new PythonTrain$$anonfun$8(this)))).toBoolean();
        // Held in refs because the values are reassigned below and then handed to the closure.
        ObjectRef kafkaParamRef = ObjectRef.create(mapParams("kafkaParam", map));
        DataManager dataManager = new DataManager(dataset, str, map);
        boolean enableDataLocal = dataManager.enableDataLocal();
        String saveDataToHDFSResult = dataManager.saveDataToHDFS();
        Broadcast<byte[][]> validateTable = dataManager.broadCastValidateTable();
        IntRef partitionCountRef = IntRef.create(-1);
        if (!enableDataLocal) {
            Predef$.MODULE$.require(((Map) kafkaParamRef.elem).size() > 0, new PythonTrain$$anonfun$train$1(this));
            Tuple2<Map<String, String>, RDD<byte[]>> kafkaResult = writeKafka(dataset, str, map);
            if (kafkaResult == null) {
                throw new MatchError(kafkaResult);
            }
            partitionCountRef.elem = kafkaResult._2().getNumPartitions();
            kafkaParamRef.elem = kafkaResult._1();
        }
        Map<String, String> systemParam = mapParams("systemParam", map);
        Tuple2<Object, Map<String, String>>[] fitParams = arrayParamsWithIndex("fitParam", map);
        if (fitParams.length == 0) {
            logWarning(new PythonTrain$$anonfun$train$2(this));
            // Fallback: one group, index 0, no parameters.
            fitParams = new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(0)), Predef$.MODULE$.Map().apply(Nil$.MODULE$))};
        }
        RDD fitParamRdd = dataset.sparkSession().sparkContext().parallelize(Predef$.MODULE$.wrapRefArray(fitParams), fitParams.length, ClassTag$.MODULE$.apply(Tuple2.class));
        MLFlowConfig mlflowConfig = MLFlowConfig$.MODULE$.buildFromSystemParam(systemParam);
        PythonConfig pythonConfig = PythonConfig$.MODULE$.buildFromSystemParam(systemParam);
        Map envs = EnvConfig$.MODULE$.buildFromSystemParam(systemParam).$plus$plus(
                Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{
                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                                Predef$.MODULE$.ArrowAssoc(BasicCondaEnvManager$.MODULE$.MLSQL_INSTNANCE_NAME_KEY()),
                                dataset.sparkSession().sparkContext().getConf().get("spark.app.name"))})));
        Option<PythonScript> projectOpt = PythonAlgProject$.MODULE$.loadProject(map, dataset.sparkSession());
        SQLPythonFunc$.MODULE$.incrementVersion(str, keepVersion);
        // The project accessors stay inline so they are evaluated after the context lookup, as before.
        dataset.sparkSession()
                .createDataFrame(
                        fitParamRdd.map(
                                new PythonTrain$$anonfun$9(this, str, keepVersion, keepLocalDirectory, kafkaParamRef, enableDataLocal, saveDataToHDFSResult, validateTable, partitionCountRef, systemParam, mlflowConfig, pythonConfig, envs, projectOpt, ScriptSQLExec$.MODULE$.contextGetOrForTest(), Option$.MODULE$.apply(((PythonScript) projectOpt.get()).filePath()), ((PythonScript) projectOpt.get()).projectName(), ((PythonScript) projectOpt.get()).scriptType()),
                                ClassTag$.MODULE$.apply(Row.class)),
                        PythonTrainingResultSchema$.MODULE$.algSchema())
                .write()
                .mode(SaveMode.Overwrite)
                .parquet(new StringBuilder().append(SQLPythonFunc$.MODULE$.getAlgMetalPath(str, keepVersion)).append("/0").toString());
        // Second write: a one-element RDD holding Seq(Map(pythonPath, pythonVer), map), mapped to rows by the closure.
        dataset.sparkSession()
                .createDataFrame(
                        dataset.sparkSession().sparkContext().parallelize(
                                Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Seq[]{
                                        (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Map[]{
                                                (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{
                                                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("pythonPath"), pythonConfig.pythonPath()),
                                                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("pythonVer"), pythonConfig.pythonVer())})),
                                                map}))})),
                                1,
                                ClassTag$.MODULE$.apply(Seq.class))
                                .map(new PythonTrain$$anonfun$11(this), ClassTag$.MODULE$.apply(Row.class)),
                        PythonTrainingResultSchema$.MODULE$.trainParamsSchema())
                .write()
                .mode(SaveMode.Overwrite)
                .parquet(new StringBuilder().append(SQLPythonFunc$.MODULE$.getAlgMetalPath(str, keepVersion)).append("/1").toString());
        return dataset.sparkSession().read().parquet(new StringBuilder().append(SQLPythonFunc$.MODULE$.getAlgMetalPath(str, keepVersion)).append("/0").toString());
    }

    /**
     * Maps every entry of the map held in {@code objectRef} through the {@code cleanNumber} closure and
     * returns the resulting map. The closure body is not visible in this file.
     */
    private final Map cleanNumber$1(ObjectRef objectRef) {
        return (Map) ((Map) objectRef.elem).map(new PythonTrain$$anonfun$cleanNumber$1$1(this), Map$.MODULE$.canBuildFrom());
    }

    /** Runs the {@code $init$} hook of each mixed-in helper, in the order shown. */
    public PythonTrain() {
        SQlBaseFunc.class.$init$(this);
        Logging.class.$init$(this);
        WowLog.class.$init$(this);
        Functions.Cclass.$init$(this);
    }
}
