package org.apache.lens.ml.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.lens.client.LensMLClient;
import org.apache.lens.ml.api.LensML;
import org.apache.lens.ml.api.MLTestReport;

/**
 * A single train-then-test ML job, intended to be executed as a {@link Runnable}.
 *
 * <p>When run, the task trains a model on {@code trainingTable} using the configured algorithm,
 * label column, feature columns and extra parameters, prints the model metadata, tests the model
 * against {@code testTable} (falling back to the training table when no test table was set),
 * prints the test report and records the resulting model and report IDs.
 *
 * <p>If a {@link LensMLClient} was supplied through the {@link Builder}, the task works in client
 * mode and uses that client as its {@link LensML}; otherwise it uses the service returned by
 * {@code MLUtils.getMLService()} ("Lens server" mode).
 *
 * <p>Instances are created only through {@link Builder}.
 */
public class MLTask implements Runnable {
    private static final Log LOG = LogFactory.getLog(MLTask.class);

    /** Lifecycle state; {@code null} until {@link #run()} is first invoked. */
    private State taskState;
    private String algorithm;
    private String trainingTable;
    /** Table used for testing; replaced by {@code trainingTable} at test time when unset. */
    private String testTable;
    // NOTE(review): partitionSpec and configuration are stored and exposed via getters but are
    // not read by the task logic in this class.
    private String partitionSpec;
    private String labelColumn;
    /** Feature columns; stays {@code null} unless at least one was added through the builder. */
    private List<String> featureColumns;
    private HiveConf configuration;
    /** The ML service actually used; chosen at the start of {@link #runTask()}. */
    private LensML ml;
    /** Random UUID assigned at construction, used to tag log and console output. */
    private String taskID;
    /** Optional remote client; {@code null} means server mode. */
    private LensMLClient mlClient;
    private String outputTable;
    /** Additional algorithm arguments, passed to training as key/value pairs. */
    private Map<String, String> extraParams;
    private String modelID;
    private String reportID;

    /**
     * Fluent builder for {@link MLTask}.
     *
     * <p>This builder is single-use: {@link #build()} hands over the configured task and clears
     * the builder's reference to it, so any later call on the same builder fails with a
     * {@link NullPointerException}.
     */
    public static class Builder {
        private MLTask task = new MLTask();

        /** Sets the table the model is trained on. */
        public Builder trainingTable(String str) {
            this.task.trainingTable = str;
            return this;
        }

        /** Sets the table the model is tested on; defaults to the training table if never set. */
        public Builder testTable(String str) {
            this.task.testTable = str;
            return this;
        }

        /** Sets the name of the algorithm to train with. */
        public Builder algorithm(String str) {
            this.task.algorithm = str;
            return this;
        }

        /** Sets the label column passed to training. */
        public Builder labelColumn(String str) {
            this.task.labelColumn = str;
            return this;
        }

        /** Supplies a remote client; when set, the task runs in client mode. */
        public Builder client(LensMLClient lensMLClient) {
            this.task.mlClient = lensMLClient;
            return this;
        }

        /** Appends one feature column; call repeatedly to add several, in order. */
        public Builder addFeatureColumn(String str) {
            if (this.task.featureColumns == null) {
                this.task.featureColumns = new ArrayList<>();
            }
            this.task.featureColumns.add(str);
            return this;
        }

        /** Attaches a Hive configuration to the task. */
        public Builder hiveConf(HiveConf hiveConf) {
            this.task.configuration = hiveConf;
            return this;
        }

        /** Adds (or replaces) one extra algorithm argument. */
        public Builder extraParam(String str, String str2) {
            this.task.extraParams.put(str, str2);
            return this;
        }

        /** Sets the partition spec stored on the task. */
        public Builder partitionSpec(String str) {
            this.task.partitionSpec = str;
            return this;
        }

        /** Sets the output table passed to model testing. */
        public Builder outputTable(String str) {
            this.task.outputTable = str;
            return this;
        }

        /**
         * Returns the configured task and detaches it from this builder.
         *
         * @return the built task
         */
        public MLTask build() {
            MLTask built = this.task;
            this.task = null;
            return built;
        }
    }

    /** Lifecycle states of an {@link MLTask}. */
    public enum State {
        RUNNING,
        SUCCESSFUL,
        FAILED
    }

    private MLTask() {
        this.extraParams = new HashMap<>();
        this.taskID = UUID.randomUUID().toString();
    }

    /**
     * Runs the task, moving its state to RUNNING and then to SUCCESSFUL or FAILED.
     *
     * <p>Any exception thrown while training or testing is logged and recorded as
     * {@link State#FAILED}; it is not propagated to the caller.
     */
    @Override
    public void run() {
        this.taskState = State.RUNNING;
        LOG.info("Starting " + this.taskID);
        try {
            runTask();
            this.taskState = State.SUCCESSFUL;
            LOG.info("Complete " + this.taskID);
        } catch (Exception e) {
            this.taskState = State.FAILED;
            // A failed task is an error condition; logging it at INFO hides it from monitoring.
            LOG.error("Error running task " + this.taskID, e);
        }
    }

    /**
     * Selects the ML service, trains the model, tests it and records the resulting IDs.
     *
     * @throws Exception if training, model lookup or testing fails
     */
    private void runTask() throws Exception {
        if (this.mlClient != null) {
            this.ml = this.mlClient;
            LOG.info("Working in client mode. Lens session handle " + this.mlClient.getSessionHandle().getPublicId());
        } else {
            this.ml = MLUtils.getMLService();
            LOG.info("Working in Lens server");
        }
        String[] trainingArgs = buildTrainingArgs();
        LOG.info("Starting task " + this.taskID + " algo args: " + Arrays.toString(trainingArgs));
        this.modelID = this.ml.train(this.trainingTable, this.algorithm, trainingArgs);
        printModelMetadata(this.taskID, this.modelID);
        LOG.info("Starting test " + this.taskID);
        if (this.testTable == null) {
            this.testTable = this.trainingTable;
        }
        // mlClient is null in server mode, so only dereference it when present. Previously this
        // threw a NullPointerException for every server-mode run after training had succeeded.
        // NOTE(review): in server mode no session handle is available here and null is passed;
        // confirm the server-side LensML implementation accepts a null handle.
        MLTestReport testReport = this.ml.testModel(
            this.mlClient != null ? this.mlClient.getSessionHandle() : null,
            this.testTable, this.algorithm, this.modelID, this.outputTable);
        this.reportID = testReport.getReportID();
        printTestReport(this.taskID, testReport);
        saveTask();
    }

    /** Placeholder for persisting task details; currently it only writes a log line. */
    private void saveTask() {
        LOG.info("Saving task details to DB");
    }

    /** Prints the test report for the given task ID to standard output. */
    private void printTestReport(String str, MLTestReport mLTestReport) {
        StringBuilder out = new StringBuilder("Example: ").append(str);
        out.append("\n\t");
        out.append("EvaluationReport: ").append(mLTestReport.toString());
        System.out.println(out.toString());
    }

    /**
     * Builds the flat argument array passed to training.
     *
     * <p>The array holds {@code "label", labelColumn}, then one {@code "feature", column} pair per
     * feature column (in insertion order), then one {@code key, value} pair per extra parameter.
     *
     * @return the training arguments; never {@code null}
     */
    private String[] buildTrainingArgs() {
        List<String> args = new ArrayList<>();
        args.add("label");
        args.add(this.labelColumn);
        // featureColumns is only created when a column is added, so it may legitimately be null.
        if (this.featureColumns != null) {
            for (String feature : this.featureColumns) {
                args.add("feature");
                args.add(feature);
            }
        }
        for (Map.Entry<String, String> param : this.extraParams.entrySet()) {
            args.add(param.getKey());
            args.add(param.getValue());
        }
        return args.toArray(new String[args.size()]);
    }

    /**
     * Prints the trained model's metadata for the given task ID to standard output.
     *
     * @throws Exception if the model cannot be fetched from the ML service
     */
    private void printModelMetadata(String str, String str2) throws Exception {
        StringBuilder out = new StringBuilder("Example: ").append(str);
        out.append("\n\t");
        out.append("Model: ");
        out.append(this.ml.getModel(this.algorithm, str2).toString());
        System.out.println(out.toString());
    }

    @Override
    public String toString() {
        return "MLTask(taskState=" + getTaskState() + ", algorithm=" + getAlgorithm() + ", trainingTable=" + getTrainingTable() + ", testTable=" + getTestTable() + ", partitionSpec=" + getPartitionSpec() + ", labelColumn=" + getLabelColumn() + ", featureColumns=" + getFeatureColumns() + ", configuration=" + getConfiguration() + ", ml=" + this.ml + ", taskID=" + this.taskID + ", mlClient=" + getMlClient() + ", outputTable=" + getOutputTable() + ", extraParams=" + getExtraParams() + ", modelID=" + getModelID() + ", reportID=" + getReportID() + ")";
    }

    /** @return the current state, or {@code null} if the task has not been run */
    public State getTaskState() {
        return this.taskState;
    }

    public String getAlgorithm() {
        return this.algorithm;
    }

    public String getTrainingTable() {
        return this.trainingTable;
    }

    /** @return the test table; after testing starts, this is the training table if none was set */
    public String getTestTable() {
        return this.testTable;
    }

    public String getPartitionSpec() {
        return this.partitionSpec;
    }

    public String getLabelColumn() {
        return this.labelColumn;
    }

    /** @return the live feature-column list, or {@code null} if none were added */
    public List<String> getFeatureColumns() {
        return this.featureColumns;
    }

    public HiveConf getConfiguration() {
        return this.configuration;
    }

    /** @return the remote client, or {@code null} when the task runs in server mode */
    public LensMLClient getMlClient() {
        return this.mlClient;
    }

    public String getOutputTable() {
        return this.outputTable;
    }

    /** @return the live (mutable) map of extra algorithm parameters */
    public Map<String, String> getExtraParams() {
        return this.extraParams;
    }

    /** @return the trained model ID, or {@code null} before training completes */
    public String getModelID() {
        return this.modelID;
    }

    /** @return the test report ID, or {@code null} before testing completes */
    public String getReportID() {
        return this.reportID;
    }
}
