package org.apache.lens.ml.algo.spark;

import java.io.File;
import java.io.FilenameFilter;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.ml.algo.api.MLDriver;
import org.apache.lens.ml.algo.lib.Algorithms;
import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

/**
 * {@link MLDriver} implementation that runs Lens ML algorithms on Apache Spark.
 *
 * <p>Lifecycle: {@link #init(LensConf)} builds a {@link SparkConf} from Lens configuration keys prefixed with
 * {@code lens.ml.sparkdriver.} (prefix stripped) and registers the built-in Spark algorithms; {@link #start()}
 * creates a {@link JavaSparkContext} (unless one was supplied via {@link #useSparkContext(JavaSparkContext)});
 * {@link #stop()} stops the context only if this driver created it.
 *
 * <p>The {@code spark.master} setting selects the mode: {@code yarn-client}, {@code yarn-cluster}, or embedded
 * ({@code local} or unset). In the YARN modes a Spark home must be resolvable, and {@code HIVE_HOME} must point at a
 * Hive installation whose {@code lib} and {@code hcatalog/share/hcatalog} jars are added to the Spark context.
 *
 * <p>NOTE(review): no synchronization is used; the lifecycle is presumably driven from a single thread — confirm.
 */
public class SparkMLDriver implements MLDriver {
    public static final Log LOG = LogFactory.getLog(SparkMLDriver.class);

    /** Lens configuration keys with this prefix are copied, prefix stripped, into the Spark configuration. */
    private static final String SPARK_DRIVER_CONF_PREFIX = "lens.ml.sparkdriver.";

    /** Accepts file names ending in {@code .jar}. Static so it does not capture a driver instance. */
    private static final FilenameFilter JAR_FILE_FILTER = new FilenameFilter() {
        @Override // java.io.FilenameFilter
        public boolean accept(File dir, String name) {
            return name.endsWith(".jar");
        }
    };

    /** Whether {@link #sparkContext} was created by this driver and must therefore be stopped by it. */
    private boolean ownsSparkContext = true;
    private final Algorithms algorithms = new Algorithms();
    private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;
    private boolean isStarted;
    private SparkConf sparkConf;
    private JavaSparkContext sparkContext;

    /** Where the Spark application runs, derived from {@code spark.master} in {@link #init(LensConf)}. */
    private enum SparkMasterMode {
        EMBEDDED,
        YARN_CLIENT,
        YARN_CLUSTER
    }

    /**
     * Makes this driver use an externally managed Spark context instead of creating its own. A context supplied
     * here is not stopped by {@link #stop()}.
     *
     * @param javaSparkContext the context to use
     */
    public void useSparkContext(JavaSparkContext javaSparkContext) {
        this.ownsSparkContext = false;
        this.sparkContext = javaSparkContext;
    }

    /**
     * Reports whether an algorithm with the given name is registered with this driver.
     *
     * @param str algorithm name
     * @return {@code true} if the algorithm registry knows the name
     */
    @Override // org.apache.lens.ml.algo.api.MLDriver
    public boolean isAlgoSupported(String str) {
        return this.algorithms.isAlgoSupported(str);
    }

    /**
     * Creates an instance of the named algorithm and, if it is a {@link BaseSparkAlgo}, hands it this driver's
     * Spark context.
     *
     * @param str algorithm name
     * @return the algorithm instance, or {@code null} if the name is unsupported or creation failed (the failure
     *         is logged, not propagated)
     * @throws LensException if the driver has not been started
     */
    @Override // org.apache.lens.ml.algo.api.MLDriver
    public MLAlgo getAlgoInstance(String str) throws LensException {
        checkStarted();
        if (!isAlgoSupported(str)) {
            return null;
        }
        MLAlgo algo = null;
        try {
            algo = this.algorithms.getAlgoForName(str);
            if (algo instanceof BaseSparkAlgo) {
                ((BaseSparkAlgo) algo).setSparkContext(this.sparkContext);
            }
        } catch (LensException e) {
            // Existing contract: creation failures are logged and surfaced to the caller as whatever was built.
            LOG.error("Error creating algo object", e);
        }
        return algo;
    }

    /** Registers the Spark-backed algorithms this driver provides. */
    private void registerAlgos() {
        this.algorithms.register(NaiveBayesAlgo.class);
        this.algorithms.register(SVMAlgo.class);
        this.algorithms.register(LogisticRegressionAlgo.class);
        this.algorithms.register(DecisionTreeAlgo.class);
    }

    /**
     * Builds the Spark configuration from the Lens configuration and determines the master mode.
     *
     * @param lensConf Lens configuration; keys prefixed with {@code lens.ml.sparkdriver.} are copied into Spark
     * @throws IllegalArgumentException if {@code spark.master} is not {@code yarn-client}, {@code yarn-cluster},
     *                                  {@code local} or blank, or if a YARN mode is selected and no Spark home is set
     * @throws LensException declared by {@link MLDriver}; not thrown directly here
     */
    @Override // org.apache.lens.ml.algo.api.MLDriver
    public void init(LensConf lensConf) throws LensException {
        this.sparkConf = new SparkConf();
        registerAlgos();
        for (String key : lensConf.getProperties().keySet()) {
            if (key.startsWith(SPARK_DRIVER_CONF_PREFIX)) {
                this.sparkConf.set(key.substring(SPARK_DRIVER_CONF_PREFIX.length()),
                    (String) lensConf.getProperties().get(key));
            }
        }

        // The single-argument SparkConf.get throws NoSuchElementException for a missing key, which made the
        // "unset master means embedded" branch unreachable; use the defaulting overload instead.
        String master = this.sparkConf.get("spark.master", null);
        if ("yarn-client".equalsIgnoreCase(master)) {
            this.clientMode = SparkMasterMode.YARN_CLIENT;
        } else if ("yarn-cluster".equalsIgnoreCase(master)) {
            this.clientMode = SparkMasterMode.YARN_CLUSTER;
        } else if ("local".equalsIgnoreCase(master) || StringUtils.isBlank(master)) {
            this.clientMode = SparkMasterMode.EMBEDDED;
            if (StringUtils.isBlank(master)) {
                // Spark will not create a context without a master URL; embedded mode means a local master.
                this.sparkConf.setMaster("local");
            }
        } else {
            throw new IllegalArgumentException("Invalid master mode " + master);
        }

        if (this.clientMode == SparkMasterMode.YARN_CLIENT || this.clientMode == SparkMasterMode.YARN_CLUSTER) {
            String envSparkHome = System.getenv("SPARK_HOME");
            if (StringUtils.isNotBlank(envSparkHome)) {
                this.sparkConf.setSparkHome(envSparkHome);
            }
            // Without SPARK_HOME, spark.home may still come from the Lens properties copied above. The defaulting
            // get makes a missing value produce the intended error rather than NoSuchElementException.
            String sparkHome = this.sparkConf.get("spark.home", null);
            if (StringUtils.isBlank(sparkHome)) {
                throw new IllegalArgumentException("Spark home is not set");
            }
            LOG.info("Spark home is set to " + sparkHome);
        }
        this.sparkConf.setAppName("lens-ml");
    }

    /**
     * Starts the driver. This creates a Spark context from the configuration built by {@link #init(LensConf)},
     * unless one was supplied via {@link #useSparkContext(JavaSparkContext)}. Outside embedded mode it also adds the
     * Hive, HCatalog and Lens jars to the context. Calling this on an already started driver only logs a warning.
     *
     * @throws LensException if a context must be created but {@code init} has not been called, if {@code HIVE_HOME}
     *                       is not set, or if a Hive/HCatalog jar directory cannot be listed
     */
    @Override // org.apache.lens.ml.algo.api.MLDriver
    public void start() throws LensException {
        if (this.isStarted) {
            LOG.warn("Spark driver is already started");
            return;
        }

        // Resolve jars before creating a context, so a bad HIVE_HOME cannot leak a context this driver owns.
        List<String> extraJars = new ArrayList<>();
        if (this.clientMode != SparkMasterMode.EMBEDDED) {
            extraJars = collectClusterJars();
        }

        if (this.sparkContext == null) {
            if (this.sparkConf == null) {
                throw new LensException("Spark driver is not initialized, call init() before start()");
            }
            this.sparkContext = new JavaSparkContext(this.sparkConf);
            // Whatever useSparkContext() said earlier, this context was created here and must be stopped here.
            this.ownsSparkContext = true;
        }
        for (String jar : extraJars) {
            this.sparkContext.addJar(jar);
        }

        this.isStarted = true;
        LOG.info("Created Spark context for app: '" + this.sparkContext.appName() + "', Spark master: "
            + this.sparkContext.master());
    }

    /**
     * Collects the Hive lib jars, HCatalog jars and the jar containing this class for shipping to a cluster.
     *
     * @return absolute jar paths
     * @throws LensException if {@code HIVE_HOME} is unset or a jar directory cannot be listed
     */
    private static List<String> collectClusterJars() throws LensException {
        String hiveHome = System.getenv("HIVE_HOME");
        if (StringUtils.isBlank(hiveHome)) {
            throw new LensException("HIVE_HOME is not set");
        }
        LOG.info("HIVE_HOME at " + hiveHome);

        List<String> jars = new ArrayList<>();
        addJarsFromDirectory(new File(hiveHome, "lib"), "HIVE", jars);
        addJarsFromDirectory(new File(hiveHome + "/hcatalog/share/hcatalog"), "HCATALOG", jars);
        for (String lensJar : JavaSparkContext.jarOfClass(SparkMLDriver.class)) {
            LOG.info("Adding Lens JAR " + lensJar);
            jars.add(lensJar);
        }
        return jars;
    }

    /**
     * Appends the absolute paths of all {@code .jar} files directly inside {@code dir} to {@code jars}.
     *
     * @param dir   directory to scan (not recursive)
     * @param label label used in log and error messages
     * @param jars  destination list
     * @throws LensException if {@code dir} does not exist, is not a directory, or cannot be read
     */
    private static void addJarsFromDirectory(File dir, String label, List<String> jars) throws LensException {
        File[] jarFiles = dir.listFiles(JAR_FILE_FILTER);
        if (jarFiles == null) {
            // listFiles returns null rather than throwing when the directory is missing or unreadable.
            throw new LensException("Cannot list " + label + " jars, not a readable directory: "
                + dir.getAbsolutePath());
        }
        for (File jarFile : jarFiles) {
            String path = jarFile.getAbsolutePath();
            LOG.info("Adding " + label + " jar " + path);
            jars.add(path);
        }
    }

    /**
     * Stops the driver. The Spark context is stopped only if this driver created it, and the reference is then
     * cleared so a later {@link #start()} creates a fresh context instead of reusing a stopped one. Stopping a
     * driver that is not started only logs a warning.
     *
     * @throws LensException declared by {@link MLDriver}; not thrown directly here
     */
    @Override // org.apache.lens.ml.algo.api.MLDriver
    public void stop() throws LensException {
        if (!this.isStarted) {
            LOG.warn("Spark driver was not started");
            return;
        }
        this.isStarted = false;
        if (this.ownsSparkContext && this.sparkContext != null) {
            this.sparkContext.stop();
            this.sparkContext = null;
        }
        LOG.info("Stopped spark context " + this);
    }

    /**
     * Returns the names of the registered algorithms.
     *
     * @return algorithm names as reported by the algorithm registry
     */
    @Override // org.apache.lens.ml.algo.api.MLDriver
    public List<String> getAlgoNames() {
        return this.algorithms.getAlgorithmNames();
    }

    /**
     * Ensures {@link #start()} has completed and {@link #stop()} has not been called since.
     *
     * @throws LensException if the driver is not started
     */
    public void checkStarted() throws LensException {
        if (!this.isStarted) {
            throw new LensException("Spark driver is not started yet");
        }
    }

    /**
     * Returns the Spark context in use.
     *
     * @return the current context, or {@code null} if none has been created or supplied, or if an owned context was
     *         stopped by {@link #stop()}
     */
    public JavaSparkContext getSparkContext() {
        return this.sparkContext;
    }
}
