/*
 * Decompiled with CFR 0.152.
 */
package org.apache.dolphinscheduler.plugin.task.pytorch;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.dolphinscheduler.common.utils.JSONUtils;
import org.apache.dolphinscheduler.plugin.task.api.AbstractTask;
import org.apache.dolphinscheduler.plugin.task.api.ShellCommandExecutor;
import org.apache.dolphinscheduler.plugin.task.api.TaskCallBack;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext;
import org.apache.dolphinscheduler.plugin.task.api.model.TaskResponse;
import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters;
import org.apache.dolphinscheduler.plugin.task.api.parser.ParamUtils;
import org.apache.dolphinscheduler.plugin.task.api.parser.ParameterUtils;
import org.apache.dolphinscheduler.plugin.task.pytorch.GitProjectManager;
import org.apache.dolphinscheduler.plugin.task.pytorch.PythonEnvManager;
import org.apache.dolphinscheduler.plugin.task.pytorch.PytorchParameters;

public class PytorchTask
extends AbstractTask {
    private final ShellCommandExecutor shellCommandExecutor;
    protected PytorchParameters pytorchParameters;
    protected TaskExecutionContext taskExecutionContext;
    private PythonEnvManager pythonEnvManager;

    public PytorchTask(TaskExecutionContext taskExecutionContext) {
        super(taskExecutionContext);
        this.taskExecutionContext = taskExecutionContext;
        this.shellCommandExecutor = new ShellCommandExecutor(arg_0 -> ((PytorchTask)this).logHandle(arg_0), taskExecutionContext, this.logger);
    }

    public void init() {
        this.logger.info("python task params {}", (Object)this.taskExecutionContext.getTaskParams());
        this.pytorchParameters = (PytorchParameters)((Object)JSONUtils.parseObject((String)this.taskExecutionContext.getTaskParams(), PytorchParameters.class));
        if (!this.pytorchParameters.checkParameters()) {
            throw new TaskException("python task params is not valid");
        }
        this.pythonEnvManager = new PythonEnvManager();
        this.pythonEnvManager.setPythonEnvTool(this.pytorchParameters.getPythonEnvTool());
        this.pythonEnvManager.setCondaPythonVersion(this.pytorchParameters.getCondaPythonVersion());
    }

    public void handle(TaskCallBack taskCallBack) throws TaskException {
        try {
            String command = this.buildPythonExecuteCommand();
            TaskResponse taskResponse = this.shellCommandExecutor.run(command);
            this.setExitStatusCode(taskResponse.getExitStatusCode());
            this.setProcessId(taskResponse.getProcessId());
            this.setVarPool(this.shellCommandExecutor.getVarPool());
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            this.logger.error("The current Pytorch task has been interrupted", (Throwable)e);
            this.setExitStatusCode(-1);
            throw new TaskException("The current Pytorch task has been interrupted", (Throwable)e);
        }
        catch (Exception e) {
            this.setExitStatusCode(-1);
            throw new TaskException("Pytorch task execute failed", (Throwable)e);
        }
    }

    public void cancel() throws TaskException {
    }

    public String buildPythonExecuteCommand() throws Exception {
        String scriptParams;
        ArrayList<String> args = new ArrayList<String>();
        String pythonPath = this.pytorchParameters.getPythonPath();
        if (GitProjectManager.isGitPath(pythonPath)) {
            GitProjectManager gpm = new GitProjectManager();
            gpm.setPath(pythonPath);
            gpm.setBaseDir(this.taskExecutionContext.getExecutePath());
            gpm.prepareProject();
            this.pytorchParameters.setPythonPath(gpm.getGitLocalPath());
        }
        args.add(String.format("export PYTHONPATH=%s", this.pytorchParameters.getPythonPath()));
        if (this.pytorchParameters.getIsCreateEnvironment().booleanValue()) {
            String buildEnvCommand = this.pythonEnvManager.getBuildEnvCommand(this.pytorchParameters.getRequirementPath());
            args.add(buildEnvCommand);
        }
        if ((scriptParams = this.pytorchParameters.getScriptParams()) != null && !scriptParams.isEmpty()) {
            args.add(String.format("%s %s %s", this.getPythonCommand(), this.pytorchParameters.getScriptPath(), this.pytorchParameters.getScriptParams()));
        } else {
            args.add(String.format("%s %s", this.getPythonCommand(), this.pytorchParameters.getScriptPath()));
        }
        Map paramsMap = this.taskExecutionContext.getPrepareParamsMap();
        return ParameterUtils.convertParameterPlaceholders((String)String.join((CharSequence)"\n", args), (Map)ParamUtils.convert((Map)paramsMap));
    }

    private String getPythonCommand() {
        String pythonCommand = this.pytorchParameters.getIsCreateEnvironment() != false ? this.pythonEnvManager.getPythonCommand() : this.pytorchParameters.getPythonCommand();
        return pythonCommand;
    }

    public AbstractParameters getParameters() {
        return this.pytorchParameters;
    }
}

