package tech.mlsql.ets.tensorflow;

import java.net.ServerSocket;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.spark.MLSQLSparkUtils$;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.TaskContextUtil$;
import os.SubProcess;
import scala.Option$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.package$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Nothing$;
import streaming.dsl.mmlib.algs.python.PythonScriptType;
import tech.mlsql.arrow.python.runner.PythonProjectRunner;
import tech.mlsql.common.utils.cluster.ml.MLWorkerProxy;
import tech.mlsql.common.utils.cluster.ml.ReportToMasterRequest;
import tech.mlsql.ets.ml.cluster.CurrentRole;
import tech.mlsql.ets.ml.cluster.DriverHost;
import tech.mlsql.ets.ml.cluster.LocalDirectoryManager$;
import tech.mlsql.ets.ml.cluster.PortManager$;
import tech.mlsql.ets.ml.cluster.TFContext;
import tech.mlsql.log.WriteLog$;

/* compiled from: DistributedTensorflow.scala */
/* loaded from: input_file:tech/mlsql/ets/tensorflow/DistributedTensorflow$$anonfun$tech$mlsql$ets$tensorflow$DistributedTensorflow$$startPs$1$1.class */
/**
 * Partition function that runs one TensorFlow parameter-server ("ps") task.
 *
 * <p>For partition index {@code i} it:
 * <ol>
 *   <li>builds a {@link TFContext} for role {@code ("ps", i)};</li>
 *   <li>reports this executor's host/port to the master and waits until
 *       {@code psTargetSize + workerTargetSize} participants have reported;</li>
 *   <li>runs the Python project on a daemon thread, streaming its output to the log;</li>
 *   <li>blocks until {@link TFContext#waitDoneOrFail()} returns, then reports success, stops the
 *       thread, kills the Python process if one was started, and closes the context.</li>
 * </ol>
 * The function always yields an empty iterator.
 */
public final class DistributedTensorflow$$anonfun$tech$mlsql$ets$tensorflow$DistributedTensorflow$$startPs$1$1 extends AbstractFunction1<Object, Iterator<Nothing$>> implements Serializable {
    public static final long serialVersionUID = 0;

    /** Flag value meaning "this task is killing Python on purpose; do not report a failure". */
    private static final String KILL_PYTHON = "kill-python";

    private final Map fitParam$1;
    private final int workerTargetSize$1;
    private final int psTargetSize$1;
    public final Map logConf$1;
    private final String tempSocketServerHost$1;
    private final int tempSocketServerPort$1;
    public final String pythonProjectPath$1;
    private final String projectName$1;
    public final PythonScriptType projectType$1;
    public final Seq command$1;

    /**
     * Runs the ps task for partition {@code i}.
     *
     * @param i partition index, used as the ps task index
     * @return an empty iterator
     */
    public final Iterator<Nothing$> apply(int i) {
        final TFContext tfContext = new TFContext(
                new DriverHost(this.tempSocketServerHost$1, this.tempSocketServerPort$1),
                new CurrentRole("ps", i));
        tfContext.assertCommand();

        registerWithMaster(tfContext);

        // Parameters handed to the Python process: the user's fit params plus the cluster/role specs.
        final Map paramMap = this.fitParam$1.$plus$plus(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("clusterSpec"), tfContext.createClusterSpec()), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("roleSpec"), tfContext.createRoleSpec())})));
        final String taskDirectory = LocalDirectoryManager$.MODULE$.setUpTaskDirectory(this.projectName$1);
        final TaskContext taskContext = TaskContext$.MODULE$.get();
        // Holds the started Python process; set by the runner thread once the process exists.
        final AtomicReference<Object> processRef = new AtomicReference<>();
        // Switched to KILL_PYTHON by this thread before it tears the runner down.
        final AtomicReference<String> flag = new AtomicReference<>("normal");

        // If Spark interrupts the task (e.g. it is cancelled), make sure the Python process dies with it.
        taskContext.addTaskCompletionListener(new TaskCompletionListener() {
            @Override
            public void onTaskCompletion(TaskContext completedContext) {
                Object process = processRef.get();
                if (completedContext.isInterrupted() && process != null) {
                    tfContext.killPython(((SubProcess) process).wrapped());
                }
            }
        });

        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                // Make the Spark TaskContext visible on this helper thread.
                TaskContextUtil$.MODULE$.setContext(taskContext);
                try {
                    runPythonProject(taskDirectory, paramMap, processRef);
                } catch (Exception e) {
                    // An intentional kill by the main thread is not a failure.
                    if (KILL_PYTHON.equals(flag.get())) {
                        return;
                    }
                    // NOTE(review): if the failure happened before the Python process started
                    // (e.g. the download failed), the exception is swallowed and no failure is
                    // reported. This behavior is preserved from the original; confirm it is intended.
                    Object process = processRef.get();
                    if (process == null) {
                        return;
                    }
                    tfContext.reportFails();
                    tfContext.killPython(((SubProcess) process).wrapped());
                    tfContext.close();
                    throw e;
                }
            }
        });
        thread.setDaemon(true);
        thread.start();

        // NOTE(review): if waitDoneOrFail()/reportSuccess() throws, the runner thread and the
        // Python process are not cleaned up here (original behavior).
        tfContext.waitDoneOrFail();
        tfContext.reportSuccess();

        // Raise the flag BEFORE interrupting. The interrupt may make the runner thread throw, and
        // its catch block must see KILL_PYTHON so it does not report a failure after we have
        // already reported success.
        flag.set(KILL_PYTHON);
        thread.interrupt();
        Object process = processRef.get();
        if (process != null) {
            tfContext.killPython(((SubProcess) process).wrapped());
        }
        tfContext.close();
        return package$.MODULE$.Iterator().apply(Nil$.MODULE$);
    }

    /**
     * Reports this ps task's host, port and role to the master via the context's worker proxy,
     * then waits for all {@code psTargetSize + workerTargetSize} participants.
     *
     * <p>The socket obtained from {@code PortManager.preTaken()} is released afterwards, including
     * when reporting or waiting throws, so it cannot leak.
     */
    private void registerWithMaster(TFContext tfContext) {
        ServerSocket preTaken = PortManager$.MODULE$.preTaken();
        try {
            int port = PortManager$.MODULE$.getPort(preTaken);
            MLWorkerProxy workerProxy = tfContext.workerProxy();
            workerProxy.reportToMaster(new ReportToMasterRequest(
                    MLSQLSparkUtils$.MODULE$.rpcEnv().address().host(),
                    port,
                    tfContext.currentRole().jobName(),
                    tfContext.currentRole().taskIndex(),
                    tfContext.isPs()));
            workerProxy.waitOthers(this.psTargetSize$1 + this.workerTargetSize$1, workerProxy.waitOthers$default$2());
        } finally {
            PortManager$.MODULE$.releasePreTaken(preTaken);
        }
    }

    /**
     * Downloads the project into {@code taskDirectory} and starts it with {@code command$1}.
     * It publishes the started process through {@code processRef}, then writes the process
     * output via {@code WriteLog}.
     */
    private void runPythonProject(String taskDirectory, Map paramMap, AtomicReference<Object> processRef) {
        LocalDirectoryManager$.MODULE$.downloadProject(taskDirectory, Option$.MODULE$.apply(this.pythonProjectPath$1), this.projectType$1);
        PythonProjectRunner runner = new PythonProjectRunner(taskDirectory, Predef$.MODULE$.Map().apply(Nil$.MODULE$));
        // Both maps are immutable, so the merged configuration can be computed once and reused.
        Map runConf = paramMap.$plus$plus(this.logConf$1);
        Iterator output = runner.run(this.command$1, runConf);
        processRef.set(runner.getPythonProcess().get());
        WriteLog$.MODULE$.write(output, runConf);
    }

    /** Generic {@code Function1} entry point; unboxes the partition index. */
    @Override
    public final Iterator<Nothing$> apply(Object obj) {
        return apply(BoxesRunTime.unboxToInt(obj));
    }

    public DistributedTensorflow$$anonfun$tech$mlsql$ets$tensorflow$DistributedTensorflow$$startPs$1$1(DistributedTensorflow distributedTensorflow, Map map, int i, int i2, Map map2, String str, int i3, String str2, String str3, PythonScriptType pythonScriptType, Seq seq) {
        this.fitParam$1 = map;
        this.workerTargetSize$1 = i;
        this.psTargetSize$1 = i2;
        this.logConf$1 = map2;
        this.tempSocketServerHost$1 = str;
        this.tempSocketServerPort$1 = i3;
        this.pythonProjectPath$1 = str2;
        this.projectName$1 = str3;
        this.projectType$1 = pythonScriptType;
        this.command$1 = seq;
    }
}
