package org.apache.lens.ml;

import java.io.File;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.UriBuilder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.Database;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.lens.client.LensClient;
import org.apache.lens.client.LensClientConfig;
import org.apache.lens.client.LensMLClient;
import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
import org.apache.lens.ml.impl.MLTask;
import org.apache.lens.ml.impl.MLUtils;
import org.apache.lens.ml.server.MLApp;
import org.apache.lens.server.LensJerseyTest;
import org.apache.lens.server.query.QueryServiceResource;
import org.apache.lens.server.session.SessionResource;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.media.multipart.MultiPartFeature;
import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

@Test
/* loaded from: input_file:org/apache/lens/ml/TestMLResource.class */
public class TestMLResource extends LensJerseyTest {
    private static final Log LOG = LogFactory.getLog(TestMLResource.class);
    private static final String TEST_DB = "default";
    private WebTarget mlTarget;
    private LensMLClient mlClient;

    protected int getTestPort() {
        return 10002;
    }

    protected Application configure() {
        return new MLApp(new Class[]{SessionResource.class, QueryServiceResource.class});
    }

    protected void configureClient(ClientConfig clientConfig) {
        clientConfig.register(MultiPartFeature.class);
    }

    protected URI getBaseUri() {
        return UriBuilder.fromUri("http://localhost/").port(getTestPort()).path("/lensapi").build(new Object[0]);
    }

    @BeforeTest
    public void setUp() throws Exception {
        super.setUp();
        Hive hive = Hive.get(new HiveConf());
        Database database = new Database();
        database.setName(TEST_DB);
        hive.createDatabase(database, true);
        LensClientConfig lensClientConfig = new LensClientConfig();
        lensClientConfig.setLensDatabase(TEST_DB);
        lensClientConfig.set("lens.server.base.url", "http://localhost:" + getTestPort() + "/lensapi");
        this.mlClient = new LensMLClient(new LensClient(lensClientConfig));
    }

    @AfterTest
    public void tearDown() throws Exception {
        super.tearDown();
        try {
            Hive.get(new HiveConf()).dropDatabase(TEST_DB);
        } catch (Exception e) {
            e.printStackTrace();
        }
        this.mlClient.close();
    }

    @BeforeMethod
    public void setMLTarget() {
        this.mlTarget = target().path("ml");
    }

    @Test
    public void testMLResourceUp() throws Exception {
        Assert.assertEquals((String) this.mlTarget.request().get(String.class), "ML service is up");
    }

    @Test
    public void testGetAlgos() throws Exception {
        List algorithms = this.mlClient.getAlgorithms();
        Assert.assertNotNull(algorithms);
        Assert.assertTrue(algorithms.contains(MLUtils.getAlgoName(NaiveBayesAlgo.class)), MLUtils.getAlgoName(NaiveBayesAlgo.class));
        Assert.assertTrue(algorithms.contains(MLUtils.getAlgoName(SVMAlgo.class)), MLUtils.getAlgoName(SVMAlgo.class));
        Assert.assertTrue(algorithms.contains(MLUtils.getAlgoName(LogisticRegressionAlgo.class)), MLUtils.getAlgoName(LogisticRegressionAlgo.class));
        Assert.assertTrue(algorithms.contains(MLUtils.getAlgoName(DecisionTreeAlgo.class)), MLUtils.getAlgoName(DecisionTreeAlgo.class));
    }

    @Test
    public void testGetAlgoParams() throws Exception {
        Map algoParamDescription = this.mlClient.getAlgoParamDescription(MLUtils.getAlgoName(DecisionTreeAlgo.class));
        Assert.assertNotNull(algoParamDescription);
        Assert.assertFalse(algoParamDescription.isEmpty());
        for (String str : algoParamDescription.keySet()) {
            LOG.info("## Param " + str + " help = " + ((String) algoParamDescription.get(str)));
        }
    }

    @Test
    public void trainAndEval() throws Exception {
        LOG.info("Starting train & eval");
        String algoName = MLUtils.getAlgoName(NaiveBayesAlgo.class);
        HiveConf hiveConf = new HiveConf();
        URI uri = new File("data/naive_bayes/naive_bayes_train.data").toURI();
        String[] strArr = {"feature_1", "feature_2", "feature_3"};
        LOG.info("Creating training table from file " + uri.toString());
        try {
            ExampleUtils.createTable(hiveConf, TEST_DB, "naivebayes_training_table", uri.toString(), "label", new HashMap(), strArr);
        } catch (HiveException e) {
            e.printStackTrace();
        }
        MLTask.Builder builder = new MLTask.Builder();
        builder.algorithm(algoName).hiveConf(hiveConf).labelColumn("label").outputTable("naivebayes_eval_table").client(this.mlClient).trainingTable("naivebayes_training_table");
        builder.addFeatureColumn("feature_1").addFeatureColumn("feature_2").addFeatureColumn("feature_3");
        MLTask build = builder.build();
        LOG.info("Created task " + build.toString());
        build.run();
        Assert.assertEquals(build.getTaskState(), MLTask.State.SUCCESSFUL);
        String modelID = build.getModelID();
        String reportID = build.getReportID();
        Assert.assertNotNull(reportID);
        Assert.assertNotNull(modelID);
        MLTask.Builder builder2 = new MLTask.Builder();
        builder2.algorithm(algoName).hiveConf(hiveConf).labelColumn("label").outputTable("naivebayes_eval_table").client(this.mlClient).trainingTable("naivebayes_training_table");
        builder2.addFeatureColumn("feature_1").addFeatureColumn("feature_2").addFeatureColumn("feature_3");
        MLTask build2 = builder2.build();
        LOG.info("Created second task " + build2.toString());
        build2.run();
        String modelID2 = build2.getModelID();
        String reportID2 = build2.getReportID();
        Assert.assertNotNull(modelID2);
        Assert.assertNotNull(reportID2);
        Hive hive = Hive.get(hiveConf);
        List<Partition> partitions = hive.getPartitions(hive.getTable("naivebayes_eval_table"));
        Assert.assertNotNull(partitions);
        HashSet hashSet = new HashSet();
        for (Partition partition : partitions) {
            LOG.info("@@PART#0 " + partition.getSpec().toString());
            hashSet.add(partition.getSpec().get("part_testid"));
        }
        Assert.assertTrue(hashSet.contains(reportID), reportID + "  first partition not there");
        Assert.assertTrue(hashSet.contains(reportID2), reportID2 + " second partition not there");
        LOG.info("Completed task run");
    }
}
