package org.apache.lens.ml;

import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.MediaType;
import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.query.LensQuery;
import org.apache.lens.api.query.QueryHandle;
import org.apache.lens.api.query.QueryStatus;
import org.apache.lens.ml.spark.SparkMLDriver;
import org.apache.lens.ml.spark.algos.BaseSparkAlgo;
import org.apache.spark.api.java.JavaSparkContext;
import org.glassfish.jersey.media.multipart.FormDataBodyPart;
import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
import org.glassfish.jersey.media.multipart.FormDataMultiPart;
import org.glassfish.jersey.media.multipart.MultiPartFeature;

/* loaded from: input_file:org/apache/lens/ml/LensMLImpl.class */
public class LensMLImpl implements LensML {
    /** Class-wide logger. */
    public static final Log LOG = LogFactory.getLog(LensMLImpl.class);
    /** Loaded ML drivers; not assigned until {@code init(HiveConf)} runs, emptied again by {@code stop()}. */
    protected List<MLDriver> drivers;
    /** Configuration supplied at construction time and replaced by {@code init(HiveConf)}. */
    private HiveConf conf;
    /** Optional Spark context; when non-null, {@code start()} hands it to every {@code SparkMLDriver}. */
    private JavaSparkContext sparkContext;

    /* loaded from: input_file:org/apache/lens/ml/LensMLImpl$RemoteQueryRunner.class */
    /**
     * Query runner that submits a query to a query REST endpoint as a multipart form and
     * polls the same endpoint until the query reports a finished status.
     */
    class RemoteQueryRunner extends TestQueryRunner {
        /** Base URL of the query endpoint that queries are posted to and polled at. */
        final String queryApiUrl;

        /**
         * @param lensSessionHandle session handle sent along with every request
         * @param str URL of the query endpoint
         */
        public RemoteQueryRunner(LensSessionHandle lensSessionHandle, String str) {
            super(lensSessionHandle);
            this.queryApiUrl = str;
        }

        /**
         * Posts the query with operation {@code execute} and blocks, polling every 500 ms,
         * until the query is finished.
         *
         * @param str the query text to run
         * @return the handle of the finished query
         * @throws LensException if the query finishes in any state other than SUCCESSFUL, or
         *         if the waiting thread is interrupted (the interrupt flag is restored first)
         */
        @Override // org.apache.lens.ml.TestQueryRunner
        public QueryHandle runQuery(String str) throws LensException {
            Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
            try {
                WebTarget target = client.target(this.queryApiUrl);

                FormDataMultiPart form = new FormDataMultiPart();
                form.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), this.sessionHandle, MediaType.APPLICATION_XML_TYPE));
                form.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), str));
                form.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute"));

                // Both persistent-resultset switches are turned off for queries issued by this runner.
                LensConf queryConf = new LensConf();
                queryConf.addProperty("lens.query.enable.persistent.resultset", "false");
                queryConf.addProperty("lens.query.enable.persistent.resultset.indriver", "false");
                form.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), queryConf, MediaType.APPLICATION_XML_TYPE));

                QueryHandle submitted = target.request().post(Entity.entity(form, MediaType.MULTIPART_FORM_DATA_TYPE), QueryHandle.class);

                LensQuery lensQuery = fetchQuery(target, submitted);
                QueryStatus status = lensQuery.getStatus();
                while (!status.isFinished()) {
                    // Wait first, then poll: avoids a redundant back-to-back request and a
                    // pointless sleep after the final status has already been read.
                    try {
                        Thread.sleep(500L);
                    } catch (InterruptedException e) {
                        // Restore the flag so callers further up can still observe the interruption.
                        Thread.currentThread().interrupt();
                        throw new LensException(e);
                    }
                    lensQuery = fetchQuery(target, submitted);
                    status = lensQuery.getStatus();
                }

                if (status.getStatus() != QueryStatus.Status.SUCCESSFUL) {
                    throw new LensException("Query failed " + lensQuery.getQueryHandle().getHandleId() + " reason:" + status.getErrorMessage());
                }
                return lensQuery.getQueryHandle();
            } finally {
                // A new client is built for every query, so it must be released here.
                client.close();
            }
        }

        /** Fetches the current state of the given query from the query endpoint. */
        private LensQuery fetchQuery(WebTarget target, QueryHandle handle) {
            return target.path(handle.toString()).queryParam("sessionid", this.sessionHandle).request().get(LensQuery.class);
        }
    }

    /**
     * Creates the service with the given configuration. Drivers are not loaded here;
     * that happens in {@code init(HiveConf)}.
     *
     * @param hiveConf configuration to use until {@code init(HiveConf)} replaces it
     */
    public LensMLImpl(HiveConf hiveConf) {
        conf = hiveConf;
    }

    /**
     * @return the configuration currently held by this service
     */
    public HiveConf getConf() {
        return conf;
    }

    /**
     * Stores the Spark context that {@code start()} passes to Spark-based drivers.
     *
     * @param javaSparkContext the context to share, or null for none
     */
    public void setSparkContext(JavaSparkContext javaSparkContext) {
        sparkContext = javaSparkContext;
    }

    /**
     * Collects the algorithm names reported by every loaded driver, in driver order.
     *
     * @return a new list with the names of all algorithms offered by the loaded drivers
     */
    @Override // org.apache.lens.ml.LensML
    public List<String> getAlgorithms() {
        List<String> algoNames = new ArrayList<String>();
        for (MLDriver driver : this.drivers) {
            algoNames.addAll(driver.getAlgoNames());
        }
        return algoNames;
    }

    /**
     * Asks each loaded driver in turn whether it supports the algorithm and returns an
     * instance from the first one that does.
     *
     * @param str name of the algorithm
     * @return the algorithm instance produced by the first supporting driver
     * @throws LensException if no loaded driver supports the algorithm
     */
    @Override // org.apache.lens.ml.LensML
    public MLAlgo getAlgoForName(String str) throws LensException {
        Iterator<MLDriver> candidates = this.drivers.iterator();
        while (candidates.hasNext()) {
            MLDriver driver = candidates.next();
            if (!driver.isAlgoSupported(str)) {
                continue;
            }
            return driver.getAlgoInstance(str);
        }
        throw new LensException("Algo not supported " + str);
    }

    /**
     * Trains a model with the named algorithm under a freshly generated random model ID,
     * stamps it with the creation time and algorithm name, and persists it.
     *
     * @param str name of the table to train on
     * @param str2 name of the algorithm
     * @param strArr algorithm parameters, passed through to the algorithm unchanged
     * @return the ID reported by the trained model
     * @throws LensException if the algorithm is not supported or the model cannot be saved
     */
    @Override // org.apache.lens.ml.LensML
    public String train(String str, String str2, String[] strArr) throws LensException {
        MLAlgo algo = getAlgoForName(str2);
        String modelId = UUID.randomUUID().toString();
        LOG.info("Begin training model " + modelId + ", algo=" + str2 + ", table=" + str + ", params=" + Arrays.toString(strArr));

        // Use the session's current database when a Hive session is attached to this thread.
        SessionState session = SessionState.get();
        String database;
        if (session != null) {
            database = session.getCurrentDatabase();
        } else {
            database = "default";
        }

        MLModel model = algo.train(toLensConf(this.conf), database, str, modelId, strArr);
        LOG.info("Done training model: " + modelId);
        model.setCreatedAt(new Date());
        model.setAlgoName(str2);

        try {
            Path savedAt = persistModel(model);
            LOG.info("Model saved: " + modelId + ", algo: " + str2 + ", path: " + savedAt);
            return model.getId();
        } catch (IOException saveFailure) {
            throw new LensException("Error saving model " + modelId + " for algo " + str2, saveFailure);
        }
    }

    /**
     * Resolves the directory that holds the models of one algorithm: the configured model
     * base directory (or its default) with the algorithm name appended.
     *
     * @param str name of the algorithm
     * @return path of the algorithm's model directory
     */
    private Path getAlgoDir(String str) throws IOException {
        String modelBaseDir = this.conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
        Path basePath = new Path(modelBaseDir);
        return new Path(basePath, str);
    }

    /**
     * Serializes the model with Java object serialization into a file named after the
     * model ID inside the algorithm's model directory, creating the directory if needed.
     * An existing file is never overwritten.
     *
     * @param mLModel the model to save; its algorithm name and ID determine the location
     * @return the path the model was written to
     * @throws IOException if the file cannot be created, written, or closed
     */
    private Path persistModel(MLModel mLModel) throws IOException {
        Path algoDir = getAlgoDir(mLModel.getAlgoName());
        FileSystem fileSystem = algoDir.getFileSystem(this.conf);
        if (!fileSystem.exists(algoDir)) {
            fileSystem.mkdirs(algoDir);
        }
        Path modelPath = new Path(algoDir, mLModel.getId());
        ObjectOutputStream out = null;
        boolean closed = false;
        try {
            out = new ObjectOutputStream(fileSystem.create(modelPath, false));
            out.writeObject(mLModel);
            out.flush();
            // Close explicitly on the success path: a failing close() means the data may not
            // have been fully written, so it must surface instead of being swallowed.
            out.close();
            closed = true;
            return modelPath;
        } catch (IOException e) {
            LOG.error("Error saving model " + mLModel.getId() + " reason: " + e.getMessage(), e);
            throw e;
        } finally {
            if (!closed) {
                // Failure path only: release the stream without masking the original error.
                IOUtils.closeQuietly(out);
            }
        }
    }

    /**
     * Lists the names of the entries in the algorithm's model directory.
     *
     * @param str name of the algorithm
     * @return the entry names, or null when the directory does not exist or is empty
     * @throws LensException wrapping any IOException from the file system
     */
    @Override // org.apache.lens.ml.LensML
    public List<String> getModels(String str) throws LensException {
        try {
            Path modelDir = getAlgoDir(str);
            FileSystem fs = modelDir.getFileSystem(this.conf);
            if (!fs.exists(modelDir)) {
                return null;
            }
            List<String> modelNames = new ArrayList<String>();
            FileStatus[] entries = fs.listStatus(modelDir);
            for (FileStatus entry : entries) {
                modelNames.add(entry.getPath().getName());
            }
            // Callers get null rather than an empty list when nothing has been saved.
            return modelNames.isEmpty() ? null : modelNames;
        } catch (IOException ioe) {
            throw new LensException(ioe);
        }
    }

    /**
     * Loads a model through {@code ModelLoader}.
     *
     * @param str name of the algorithm
     * @param str2 ID of the model
     * @return whatever {@code ModelLoader.loadModel} returns for the pair
     * @throws LensException wrapping any IOException raised while loading
     */
    @Override // org.apache.lens.ml.LensML
    public MLModel getModel(String str, String str2) throws LensException {
        MLModel loaded;
        try {
            loaded = ModelLoader.loadModel(this.conf, str, str2);
        } catch (IOException ioe) {
            throw new LensException(ioe);
        }
        return loaded;
    }

    /**
     * Replaces the configuration and loads the driver classes listed under
     * {@code lens.ml.drivers}. A class that cannot be found, is not an {@code MLDriver},
     * or fails to instantiate/initialise is logged and skipped.
     *
     * @param hiveConf configuration to adopt and to read the driver list from
     * @throws RuntimeException if no drivers are configured or none could be loaded
     */
    public synchronized void init(HiveConf hiveConf) {
        this.conf = hiveConf;
        String[] driverClassNames = hiveConf.getStrings("lens.ml.drivers");
        if (driverClassNames == null || driverClassNames.length == 0) {
            throw new RuntimeException("No ML Drivers specified in conf");
        }
        LOG.info("Loading drivers " + Arrays.toString(driverClassNames));
        this.drivers = new ArrayList<MLDriver>(driverClassNames.length);

        for (String driverClassName : driverClassNames) {
            Class<?> driverClass;
            try {
                driverClass = Class.forName(driverClassName);
            } catch (ClassNotFoundException notFound) {
                LOG.error("Driver class not found " + driverClassName);
                continue;
            }
            if (!MLDriver.class.isAssignableFrom(driverClass)) {
                LOG.warn("Not a driver class " + driverClassName);
                continue;
            }
            try {
                MLDriver driver = (MLDriver) driverClass.newInstance();
                // Each driver is initialised with its own copy of the configuration.
                driver.init(toLensConf(this.conf));
                this.drivers.add(driver);
                LOG.info("Added driver " + driverClassName);
            } catch (Exception failure) {
                LOG.error("Failed to create driver " + driverClassName + " reason: " + failure.getMessage(), failure);
            }
        }

        if (this.drivers.isEmpty()) {
            throw new RuntimeException("No ML drivers loaded");
        }
        LOG.info("Inited ML service");
    }

    /**
     * Starts every loaded driver. Spark drivers first receive the shared Spark context
     * when one has been set. A driver that fails to start is logged and does not stop
     * the remaining drivers from being started.
     */
    public synchronized void start() {
        Iterator<MLDriver> it = this.drivers.iterator();
        while (it.hasNext()) {
            MLDriver driver = it.next();
            try {
                if (this.sparkContext != null && driver instanceof SparkMLDriver) {
                    SparkMLDriver sparkDriver = (SparkMLDriver) driver;
                    sparkDriver.useSparkContext(this.sparkContext);
                }
                driver.start();
            } catch (LensException startFailure) {
                LOG.error("Failed to start driver " + driver, startFailure);
            }
        }
        LOG.info("Started ML service");
    }

    /**
     * Stops every loaded driver and then empties the driver list. A driver that fails
     * to stop is logged and the remaining drivers are still stopped.
     */
    public synchronized void stop() {
        Iterator<MLDriver> it = this.drivers.iterator();
        while (it.hasNext()) {
            MLDriver driver = it.next();
            try {
                driver.stop();
            } catch (LensException stopFailure) {
                LOG.error("Failed to stop driver " + driver, stopFailure);
            }
        }
        this.drivers.clear();
        LOG.info("Stopped ML service");
    }

    /**
     * Synchronized accessor for the configuration, taking the same lock as
     * {@code init}, {@code start} and {@code stop}.
     *
     * @return the configuration currently held by this service
     */
    public synchronized HiveConf getHiveConf() {
        return conf;
    }

    /** Delegates to {@code ModelLoader.clearCache()}. */
    public void clearModels() {
        ModelLoader.clearCache();
    }

    /**
     * Returns the string form of the location {@code ModelLoader} computes for a model.
     *
     * @param str name of the algorithm (presumably; passed straight to {@code ModelLoader})
     * @param str2 ID of the model (presumably; passed straight to {@code ModelLoader})
     * @return the model location as a string
     */
    @Override // org.apache.lens.ml.LensML
    public String getModelPath(String str, String str2) {
        return ModelLoader.getModelLocation(conf, str, str2).toString();
    }

    /**
     * Stub implementation of the interface method: performs no work and always returns null.
     * The overload that takes a {@code TestQueryRunner} contains the actual evaluation logic.
     *
     * @return always null
     */
    @Override // org.apache.lens.ml.LensML
    public MLTestReport testModel(LensSessionHandle lensSessionHandle, String str, String str2, String str3, String str4) throws LensException {
        return null;
    }

    /**
     * Evaluates a model, running the evaluation queries through a
     * {@code RemoteQueryRunner} bound to the given query endpoint.
     *
     * @param lensSessionHandle session used for the remote queries
     * @param str name of the table to test against
     * @param str2 name of the algorithm
     * @param str3 ID of the model
     * @param str4 URL of the query endpoint
     * @param str5 name of the output table
     * @return the report produced by the evaluation
     * @throws LensException if the evaluation fails
     */
    public MLTestReport testModelRemote(LensSessionHandle lensSessionHandle, String str, String str2, String str3, String str4, String str5) throws LensException {
        TestQueryRunner remoteRunner = new RemoteQueryRunner(lensSessionHandle, str4);
        return testModel(lensSessionHandle, str, str2, str3, remoteRunner, str5);
    }

    /**
     * Evaluates a saved model against a table: builds the test spec, creates the output
     * table through the runner if the spec reports it missing, runs the evaluation query,
     * and records the run in a test report that is then saved.
     *
     * @param lensSessionHandle session handle (not used directly by this method)
     * @param str name of the table to test against
     * @param str2 name of the algorithm
     * @param str3 ID of the model
     * @param testQueryRunner runner used to execute the generated queries
     * @param str4 name of the output table
     * @return the report describing this test run
     * @throws LensException if the algorithm is unknown, the model is not found, the test
     *         spec yields no query, a query fails, or loading the model raises an IOException
     */
    public MLTestReport testModel(LensSessionHandle lensSessionHandle, String str, String str2, String str3, TestQueryRunner testQueryRunner, String str4) throws LensException {
        if (!getAlgorithms().contains(str2)) {
            throw new LensException("No such algorithm " + str2);
        }
        try {
            MLModel model = ModelLoader.loadModel(this.conf, str2, str3);
            if (model == null) {
                throw new LensException("Model not found: " + str3 + " algorithm=" + str2);
            }

            // Fall back to "default" when there is no session or it reports no database.
            SessionState session = SessionState.get();
            String database = session != null ? session.getCurrentDatabase() : null;
            if (database == null) {
                database = "default";
            }

            // Dashes are replaced with underscores in the generated test ID.
            String testId = UUID.randomUUID().toString().replace("-", "_");
            TableTestingSpec spec = TableTestingSpec.newBuilder()
                .hiveConf(this.conf)
                .database(database)
                .inputTable(str)
                .featureColumns(model.getFeatureColumns())
                .outputColumn("prediction_result")
                .lableColumn(model.getLabelColumn())
                .algorithm(str2)
                .modelID(str3)
                .outputTable(str4)
                .testID(testId)
                .build();

            String evaluationQuery = spec.getTestQuery();
            if (evaluationQuery == null) {
                throw new LensException("Invalid test spec. table=" + str + " algorithm=" + str2 + " modelID=" + str3);
            }

            if (!spec.isOutputTableExists()) {
                LOG.info("Output table '" + str4 + "' does not exist for test algorithm = " + str2 + " modelid=" + str3 + ", Creating table using query: " + spec.getCreateOutputTableQuery());
                testQueryRunner.runQuery(spec.getCreateOutputTableQuery());
                LOG.info("Table created " + str4);
            }

            LOG.info("Running evaluation query " + evaluationQuery);
            QueryHandle evaluationHandle = testQueryRunner.runQuery(evaluationQuery);

            MLTestReport report = new MLTestReport();
            report.setReportID(testId);
            report.setAlgorithm(str2);
            report.setFeatureColumns(model.getFeatureColumns());
            report.setLabelColumn(model.getLabelColumn());
            report.setModelID(model.getId());
            report.setOutputColumn("prediction_result");
            report.setOutputTable(str4);
            report.setTestTable(str);
            report.setQueryID(evaluationHandle.toString());

            persistTestReport(report);
            LOG.info("Saved test report " + report.getReportID());
            return report;
        } catch (IOException ioe) {
            throw new LensException(ioe);
        }
    }

    /**
     * Saves the test report through {@code ModelLoader}. Saving is best-effort: an
     * IOException is logged, now together with its stack trace, and is not propagated,
     * so the caller still receives the in-memory report.
     *
     * @param mLTestReport the report to save
     */
    private void persistTestReport(MLTestReport mLTestReport) throws LensException {
        LOG.info("saving test report " + mLTestReport.getReportID());
        try {
            ModelLoader.saveTestReport(this.conf, mLTestReport);
            LOG.info("Saved report " + mLTestReport.getReportID());
        } catch (IOException e) {
            // Pass the exception itself so the cause and stack trace are not lost.
            LOG.error("Error saving report " + mLTestReport.getReportID() + " reason: " + e.getMessage(), e);
        }
    }

    /**
     * Lists the names of the entries in the sub-directory named {@code str} of the
     * configured test-report base directory (or its default).
     *
     * @param str name of the sub-directory to list (presumably the algorithm name)
     * @return the entry names (possibly empty), or null when either directory does not
     *         exist or the file system raises an IOException, which is logged
     */
    @Override // org.apache.lens.ml.LensML
    public List<String> getTestReports(String str) throws LensException {
        String reportBaseDir = this.conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT);
        Path basePath = new Path(reportBaseDir);
        try {
            FileSystem fs = basePath.getFileSystem(this.conf);
            if (!fs.exists(basePath)) {
                return null;
            }
            Path reportDir = new Path(basePath, str);
            if (!fs.exists(reportDir)) {
                return null;
            }
            List<String> reportNames = new ArrayList<String>();
            FileStatus[] entries = fs.listStatus(reportDir);
            for (FileStatus entry : entries) {
                reportNames.add(entry.getPath().getName());
            }
            return reportNames;
        } catch (IOException ioe) {
            LOG.error("Error reading report list for " + str, ioe);
            return null;
        }
    }

    /**
     * Loads a test report through {@code ModelLoader}.
     *
     * @param str first lookup key passed to {@code ModelLoader.loadReport} (presumably the algorithm name)
     * @param str2 second lookup key passed to {@code ModelLoader.loadReport} (presumably the report ID)
     * @return whatever {@code ModelLoader.loadReport} returns for the pair
     * @throws LensException wrapping any IOException raised while loading
     */
    @Override // org.apache.lens.ml.LensML
    public MLTestReport getTestReport(String str, String str2) throws LensException {
        MLTestReport report;
        try {
            report = ModelLoader.loadReport(this.conf, str, str2);
        } catch (IOException ioe) {
            throw new LensException(ioe);
        }
        return report;
    }

    /**
     * Loads the model and applies it to the given feature values.
     *
     * @param str name of the algorithm
     * @param str2 ID of the model
     * @param objArr feature values handed to the model's {@code predict}
     * @return the model's prediction
     * @throws LensException if the model cannot be loaded or is not found
     */
    @Override // org.apache.lens.ml.LensML
    public Object predict(String str, String str2, Object[] objArr) throws LensException {
        MLModel model = getModel(str, str2);
        if (model == null) {
            // Report a missing model as a LensException instead of failing with a NullPointerException.
            throw new LensException("Model not found: " + str2 + " algorithm=" + str);
        }
        return model.predict(objArr);
    }

    /**
     * Deletes a saved model through {@code ModelLoader}.
     *
     * @param str name of the algorithm
     * @param str2 ID of the model
     * @throws LensException if the deletion raises an IOException
     */
    @Override // org.apache.lens.ml.LensML
    public void deleteModel(String str, String str2) throws LensException {
        try {
            ModelLoader.deleteModel(this.conf, str, str2);
        } catch (IOException deleteFailure) {
            LOG.error("Error deleting model file. algorithm=" + str + " model=" + str2 + " reason: " + deleteFailure.getMessage(), deleteFailure);
            throw new LensException("Unable to delete model " + str2 + " for algorithm " + str, deleteFailure);
        }
        LOG.info("DELETED model " + str2 + " algorithm=" + str);
    }

    /**
     * Deletes a saved test report through {@code ModelLoader}.
     *
     * @param str name of the algorithm
     * @param str2 ID of the report
     * @throws LensException if the deletion raises an IOException
     */
    @Override // org.apache.lens.ml.LensML
    public void deleteTestReport(String str, String str2) throws LensException {
        try {
            ModelLoader.deleteTestReport(this.conf, str, str2);
        } catch (IOException deleteFailure) {
            LOG.error("Error deleting report " + str2 + " algorithm=" + str + " reason: " + deleteFailure.getMessage(), deleteFailure);
            throw new LensException("Unable to delete report " + str2 + " for algorithm " + str, deleteFailure);
        }
        LOG.info("DELETED report=" + str2 + " algorithm=" + str);
    }

    /**
     * Returns the argument usage of an algorithm when it is a {@code BaseSparkAlgo}.
     *
     * @param str name of the algorithm
     * @return the map from {@code BaseSparkAlgo.getArgUsage()}, or null when the algorithm
     *         is of another type or looking it up fails (the failure is logged)
     */
    @Override // org.apache.lens.ml.LensML
    public Map<String, String> getAlgoParamDescription(String str) {
        try {
            MLAlgo algo = getAlgoForName(str);
            if (!(algo instanceof BaseSparkAlgo)) {
                return null;
            }
            BaseSparkAlgo sparkAlgo = (BaseSparkAlgo) algo;
            return sparkAlgo.getArgUsage();
        } catch (LensException lookupFailure) {
            LOG.error("Error getting algo description : " + str, lookupFailure);
            return null;
        }
    }

    /**
     * Builds a new {@code LensConf} holding every property of the given configuration
     * whose key matches the regex {@code .*}.
     *
     * @param hiveConf configuration to copy from
     * @return a fresh {@code LensConf} populated with the copied properties
     */
    private LensConf toLensConf(HiveConf hiveConf) {
        LensConf copy = new LensConf();
        Map<String, String> target = copy.getProperties();
        target.putAll(hiveConf.getValByRegex(".*"));
        return copy;
    }
}
