package org.apache.spark.ml.spark.models.svm;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.apache.spark.SparkContext;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.ml.spark.ProgressListener;
import org.apache.spark.ml.spark.models.svm.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.storage.RDDInfo;
import scala.collection.JavaConversions;
import water.DKV;
import water.fvec.Frame;
import water.fvec.H2OFrame;
import water.fvec.Vec;
import water.util.Log;

/* loaded from: input_file:org/apache/spark/ml/spark/models/svm/SVM.class */
public class SVM extends ModelBuilder<SVMModel, SVMParameters, SVMModel.SVMOutput> {
    private final transient H2OContext hc;

    /* loaded from: input_file:org/apache/spark/ml/spark/models/svm/SVM$SVMDriver.class */
    private final class SVMDriver extends ModelBuilder<SVMModel, SVMParameters, SVMModel.SVMOutput>.Driver {
        private transient SparkContext sc;
        private transient H2OContext h2oContext;
        private transient SQLContext sqlContext;

        private SVMDriver() {
            super(SVM.this);
            this.sc = SVM.this.hc.sparkContext();
            this.h2oContext = SVM.this.hc;
            this.sqlContext = SQLContext.getOrCreate(this.sc);
        }

        public void computeImpl() {
            SVM.this.init(true);
            SVMModel sVMModel = new SVMModel(SVM.this.dest(), (SVMParameters) SVM.this._parms, new SVMModel.SVMOutput(SVM.this));
            sVMModel.delete_and_lock(SVM.this._job);
            RDD<LabeledPoint> trainingData = getTrainingData(SVM.this._train, ((SVMParameters) SVM.this._parms)._response_column, ((SVMModel.SVMOutput) sVMModel._output).nfeatures());
            trainingData.cache();
            SVMWithSGD sVMWithSGD = new SVMWithSGD();
            sVMWithSGD.setIntercept(((SVMParameters) SVM.this._parms)._add_intercept);
            sVMWithSGD.optimizer().setNumIterations(((SVMParameters) SVM.this._parms)._max_iterations);
            sVMWithSGD.optimizer().setStepSize(((SVMParameters) SVM.this._parms)._step_size);
            sVMWithSGD.optimizer().setRegParam(((SVMParameters) SVM.this._parms)._reg_param);
            sVMWithSGD.optimizer().setMiniBatchFraction(((SVMParameters) SVM.this._parms)._mini_batch_fraction);
            sVMWithSGD.optimizer().setConvergenceTol(((SVMParameters) SVM.this._parms)._convergence_tol);
            sVMWithSGD.optimizer().setGradient(((SVMParameters) SVM.this._parms)._gradient.get());
            sVMWithSGD.optimizer().setUpdater(((SVMParameters) SVM.this._parms)._updater.get());
            ProgressListener progressListener = new ProgressListener(this.sc, SVM.this._job, RDDInfo.fromRdd(trainingData), JavaConversions.iterableAsScalaIterable(Arrays.asList("treeAggregate")));
            this.sc.addSparkListener(progressListener);
            org.apache.spark.mllib.classification.SVMModel run = null == ((SVMParameters) SVM.this._parms)._initial_weights ? (org.apache.spark.mllib.classification.SVMModel) sVMWithSGD.run(trainingData) : sVMWithSGD.run(trainingData, vec2vec(((SVMParameters) SVM.this._parms).initialWeights().vecs()));
            trainingData.unpersist(false);
            this.sc.listenerBus().listeners().remove(progressListener);
            ((SVMModel.SVMOutput) sVMModel._output).weights_$eq(run.weights().toArray());
            ((SVMModel.SVMOutput) sVMModel._output).iterations_$eq(((SVMParameters) SVM.this._parms)._max_iterations);
            ((SVMModel.SVMOutput) sVMModel._output).interceptor_$eq(run.intercept());
            Frame frame = (Frame) DKV.getGet(((SVMParameters) SVM.this._parms)._train);
            sVMModel.score(frame).delete();
            ((SVMModel.SVMOutput) sVMModel._output)._training_metrics = ModelMetrics.getFromDKV(sVMModel, frame);
            sVMModel.update(SVM.this._job);
            if (SVM.this._valid != null) {
                sVMModel.score(((SVMParameters) SVM.this._parms).valid()).delete();
                ((SVMModel.SVMOutput) sVMModel._output)._validation_metrics = ModelMetrics.getFromDKV(sVMModel, ((SVMParameters) SVM.this._parms).valid());
                sVMModel.update(SVM.this._job);
            }
            ((SVMModel.SVMOutput) sVMModel._output).interceptor_$eq(run.intercept());
            Log.info(new Object[]{((SVMModel.SVMOutput) sVMModel._output)._model_summary});
        }

        private Vector vec2vec(Vec[] vecArr) {
            double[] dArr = new double[vecArr.length];
            for (int i = 0; i < vecArr.length; i++) {
                dArr[i] = vecArr[i].at(0L);
            }
            return Vectors.dense(dArr);
        }

        private RDD<LabeledPoint> getTrainingData(Frame frame, String str, int i) {
            return this.h2oContext.asDataFrame(new H2OFrame(frame), true, this.sqlContext).javaRDD().map(new RowToLabeledPoint(i, str, frame.domains())).rdd();
        }
    }

    public SVM(boolean z, H2OContext h2OContext) {
        super(new SVMParameters(), z);
        this.hc = h2OContext;
    }

    public SVM(SVMParameters sVMParameters, H2OContext h2OContext) {
        super(sVMParameters);
        init(false);
        this.hc = h2OContext;
    }

    protected ModelBuilder<SVMModel, SVMParameters, SVMModel.SVMOutput>.Driver trainModelImpl() {
        return new SVMDriver();
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial, ModelCategory.Regression};
    }

    public boolean isSupervised() {
        return true;
    }

    public void init(boolean z) {
        super.init(z);
        ((SVMParameters) this._parms).validate(this);
        if (this._train == null) {
            return;
        }
        if (null != ((SVMParameters) this._parms)._initial_weights) {
            Frame frame = ((SVMParameters) this._parms)._initial_weights.get();
            if (frame.numCols() != this._train.numCols() - numSpecialCols()) {
                error("_initial_weights", "The user-specified initial weights must have the same number of columns (" + (this._train.numCols() - numSpecialCols()) + ") as the training observations");
            }
            if (frame.hasNAs()) {
                error("_initial_weights", "Initial weights cannot contain missing values.");
            }
        }
        for (int i = 0; i < this._train.numCols(); i++) {
            Vec vec = this._train.vec(i);
            String name = this._train.name(i);
            if (vec.naCnt() > 0 && (null == ((SVMParameters) this._parms)._ignored_columns || Arrays.binarySearch(((SVMParameters) this._parms)._ignored_columns, name) < 0)) {
                error("_train", "Training frame cannot contain any missing values [" + name + "].");
            }
        }
        HashSet hashSet = null != ((SVMParameters) this._parms)._ignored_columns ? new HashSet(Arrays.asList(((SVMParameters) this._parms)._ignored_columns)) : new HashSet();
        for (int i2 = 0; i2 < this._train.vecs().length; i2++) {
            Vec vec2 = this._train.vec(i2);
            if (!hashSet.contains(this._train.name(i2)) && !vec2.isNumeric() && !vec2.isCategorical()) {
                error("_train", "SVM supports only frames with numeric values (except for result column). But a " + vec2.get_type_str() + " was found.");
            }
        }
        if (null != ((SVMParameters) this._parms)._response_column && null == this._train.vec(((SVMParameters) this._parms)._response_column)) {
            error("_train", "Training frame has to contain the response column.");
        }
        if (this._train == null || ((SVMParameters) this._parms)._response_column == null) {
            return;
        }
        String[] responseDomains = responseDomains();
        if (null == responseDomains) {
            if (!Double.isNaN(((SVMParameters) this._parms)._threshold)) {
                error("_threshold", "Threshold cannot be set for regression SVM. Set the threshold to NaN or modify the response column to an enum.");
            }
            if (this._train.vec(((SVMParameters) this._parms)._response_column).isNumeric()) {
                return;
            }
            error("_response_column", "Regression SVM requires the response column type to be numeric.");
            return;
        }
        if (Double.isNaN(((SVMParameters) this._parms)._threshold)) {
            error("_threshold", "Threshold has to be set for binomial SVM. Set the threshold to a numeric value or change the response column type.");
        }
        if (responseDomains.length != 2) {
            error("_response_column", "SVM requires the response column's domain to be of size 2.");
        }
    }

    private String[] responseDomains() {
        int find = ((SVMParameters) this._parms).train().find(((SVMParameters) this._parms)._response_column);
        if (find == -1) {
            return null;
        }
        return ((SVMParameters) this._parms).train().domains()[find];
    }

    public int numSpecialCols() {
        return (hasOffsetCol() ? 1 : 0) + (hasWeightCol() ? 1 : 0) + (hasFoldCol() ? 1 : 0) + 1;
    }
}
