package org.apache.lens.ml.impl;

import java.io.File;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.TableType;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.lens.client.LensClient;
import org.apache.lens.client.LensClientConfig;
import org.apache.lens.client.LensMLClient;
import org.apache.lens.ml.impl.MLTask;

/**
 * Command-line driver that loads training and test data files into Hive tables and then builds
 * and runs an {@link MLTask} over them through a {@link LensMLClient}.
 *
 * <p>Usage: {@code MLRunner <ml-conf-dir>}, where the directory contains {@code ml.properties},
 * {@code train.data} and {@code test.data}.
 */
public class MLRunner {
    private static final Log LOG = LogFactory.getLog(MLRunner.class);

    /** Properties file expected inside the configuration directory. */
    private static final String PROPERTIES_FILE = "ml.properties";

    /** Single partition column added to every created table. */
    private static final String PARTITION_COLUMN = "dummy_partition_col";

    /** Value of {@link #PARTITION_COLUMN} for the one partition registered per table. */
    private static final String PARTITION_VALUE = "dummy_val";

    private LensMLClient mlClient;
    private String algoName;
    private String database;
    private String trainTable;
    private String trainFile;
    private String testTable;
    private String testFile;
    private String outputTable;
    private String[] features;
    private String labelColumn;
    private HiveConf conf;

    /**
     * Initializes this runner from a configuration directory.
     *
     * <p>Reads {@code ml.properties} from {@code confDir}. The keys used are {@code algo},
     * {@code database}, {@code traintable}, {@code testtable}, {@code outputtable},
     * {@code labelcolumn}, and {@code features}. The {@code features} key holds a
     * comma-separated list; each entry is trimmed. The training and test data are expected at
     * {@code confDir/train.data} and {@code confDir/test.data}.
     *
     * <p>NOTE(review): {@link #createTable} registers the <em>parent directory</em> of each
     * data file as the partition location. Both data files here live directly in
     * {@code confDir}, so both tables' partitions point at {@code confDir}, which also holds
     * {@code ml.properties}. Confirm this is intended.
     *
     * @param lensMLClient client used to run the ML task
     * @param confDir directory holding {@code ml.properties} and the data files
     * @throws java.io.IOException if the properties file cannot be opened or read
     * @throws IllegalArgumentException if the {@code features} property is missing or blank
     * @throws Exception declared for backward compatibility with existing callers
     */
    public void init(LensMLClient lensMLClient, String confDir) throws Exception {
        Properties props = new Properties();
        // Close the stream even if load() throws.
        try (FileInputStream in = new FileInputStream(new File(new File(confDir), PROPERTIES_FILE))) {
            props.load(in);
        }
        init(lensMLClient,
            props.getProperty("algo"),
            props.getProperty("database"),
            props.getProperty("traintable"),
            confDir + File.separator + "train.data",
            props.getProperty("testtable"),
            confDir + File.separator + "test.data",
            props.getProperty("outputtable"),
            parseFeatures(props.getProperty("features")),
            props.getProperty("labelcolumn"));
    }

    /**
     * Splits a comma-separated feature list into trimmed column names.
     *
     * @param featureList raw value of the {@code features} property; may be {@code null}
     * @return the feature column names, each with surrounding whitespace removed
     * @throws IllegalArgumentException if {@code featureList} is {@code null} or blank
     */
    private static String[] parseFeatures(String featureList) {
        if (featureList == null || featureList.trim().isEmpty()) {
            throw new IllegalArgumentException(
                "Missing required property 'features' in " + PROPERTIES_FILE);
        }
        String[] parsed = featureList.split(",");
        for (int i = 0; i < parsed.length; i++) {
            parsed[i] = parsed[i].trim();
        }
        return parsed;
    }

    /**
     * Initializes this runner with explicit settings.
     *
     * <p>Also creates the {@link HiveConf} used for all Hive calls, built from a fresh
     * {@link LensClientConfig}.
     *
     * @param lensMLClient client used to run the ML task
     * @param algorithm name of the algorithm to train
     * @param databaseName Hive database in which the tables are created
     * @param trainTableName name of the training table
     * @param trainDataFile local path of the training data file
     * @param testTableName name of the test table
     * @param testDataFile local path of the test data file
     * @param outputTableName name of the output table passed to the task
     * @param featureColumns feature column names
     * @param labelColumnName label column name; if {@code null}, no label column is created
     */
    public void init(LensMLClient lensMLClient, String algorithm, String databaseName,
        String trainTableName, String trainDataFile, String testTableName, String testDataFile,
        String outputTableName, String[] featureColumns, String labelColumnName) {
        this.mlClient = lensMLClient;
        this.algoName = algorithm;
        this.database = databaseName;
        this.trainTable = trainTableName;
        this.trainFile = trainDataFile;
        this.testTable = testTableName;
        this.testFile = testDataFile;
        this.outputTable = outputTableName;
        this.features = featureColumns;
        this.labelColumn = labelColumnName;
        this.conf = new HiveConf(new LensClientConfig(), MLRunner.class);
    }

    /**
     * Creates the training and test tables, then builds and runs the ML task.
     *
     * <p>Both tables are (re)created through {@link #createTable}, so any existing tables with
     * the same names are dropped first.
     *
     * @return the task that was run
     * @throws Exception if table creation or the task itself fails
     */
    public MLTask train() throws Exception {
        LOG.info("Starting train & eval");
        createTable(this.trainTable, this.trainFile);
        createTable(this.testTable, this.testFile);

        MLTask.Builder builder = new MLTask.Builder();
        builder.algorithm(this.algoName).hiveConf(this.conf).labelColumn(this.labelColumn)
            .outputTable(this.outputTable).client(this.mlClient).trainingTable(this.trainTable)
            .testTable(this.testTable);
        for (String feature : this.features) {
            builder.addFeatureColumn(feature);
        }

        MLTask task = builder.build();
        LOG.info("Created task " + task.toString());
        task.run();
        return task;
    }

    /**
     * Drops and recreates the managed Hive table {@code database.tableName}, then registers one
     * partition that points at the directory containing {@code dataFile}.
     *
     * <p>The table has the following layout:
     * <ul>
     *   <li>The label column, if set, comes first.</li>
     *   <li>One column follows per feature. All data columns are typed {@code double}.</li>
     *   <li>It is read with {@link TextInputFormat}, using newline-separated rows and
     *       space-separated fields.</li>
     *   <li>It has a single string partition column {@value #PARTITION_COLUMN}.</li>
     * </ul>
     *
     * @param tableName table name within {@code database}
     * @param dataFile local path of a data file; its parent directory becomes the partition
     *     location
     * @throws HiveException if any Hive metadata operation fails
     */
    public void createTable(String tableName, String dataFile) throws HiveException {
        // Hive partitions are directories, so register the data file's parent directory.
        String partitionLocation =
            new Path(new File(dataFile).toURI()).getParent().toUri().toString();

        List<FieldSchema> columns = new ArrayList<>();
        if (this.labelColumn != null) {
            columns.add(new FieldSchema(this.labelColumn, "double", "Labelled Column"));
        }
        for (String feature : this.features) {
            columns.add(new FieldSchema(feature, "double", "Feature " + feature));
        }

        Hive hive = Hive.get(this.conf);
        Table table = hive.newTable(this.database + "." + tableName);
        table.setTableType(TableType.MANAGED_TABLE);
        table.getTTable().getSd().setCols(columns);
        table.setInputFormatClass(TextInputFormat.class);
        table.setSerdeParam("line.delim", "\n");
        table.setSerdeParam("field.delim", " ");

        List<FieldSchema> partitionColumns = new ArrayList<>(1);
        partitionColumns.add(new FieldSchema(PARTITION_COLUMN, "string", ""));
        table.setPartCols(partitionColumns);

        // Replace any existing table of the same name. The flags are passed to Hive#dropTable
        // unchanged. Presumably they mean deleteData=false and ignoreUnknownTab=true; confirm
        // against the Hive version in use.
        hive.dropTable(this.database, tableName, false, true);
        hive.createTable(table, true);
        LOG.info("Created table " + tableName);

        AddPartitionDesc partitionDesc = new AddPartitionDesc(this.database, tableName, false);
        Map<String, String> partitionSpec = new HashMap<>();
        partitionSpec.put(PARTITION_COLUMN, PARTITION_VALUE);
        partitionDesc.addPartition(partitionSpec, partitionLocation);
        hive.createPartitions(partitionDesc);
        LOG.info(tableName + ": Added partition " + partitionLocation);
    }

    /**
     * Entry point. Expects one argument: the ML configuration directory.
     *
     * <p>If no argument is given, prints usage and exits with status -1. Otherwise it
     * initializes a runner from the directory, trains, and prints the output table name.
     * NOTE(review): the Lens clients created here are never closed; consider closing them if
     * they expose a close method.
     *
     * @param args command-line arguments; {@code args[0]} is the configuration directory
     * @throws Exception if initialization or training fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.out.println("Usage: " + MLRunner.class.getName() + " <ml-conf-dir>");
            System.exit(-1);
        }
        String confDir = args[0];
        LensMLClient mlClient = new LensMLClient(new LensClient());
        MLRunner runner = new MLRunner();
        runner.init(mlClient, confDir);
        runner.train();
        System.out.println("Created the Model successfully. Output Table: " + runner.outputTable);
    }
}
