package org.apache.spark.ml.classification;

import java.io.IOException;
import java.util.Arrays;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkException;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.boosting.BoostingParams;
import org.apache.spark.ml.ensemble.HasBaseLearner;
import org.apache.spark.ml.ensemble.HasNumBaseLearners;
import org.apache.spark.ml.ensemble.Utils$;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasAggregationDepth;
import org.apache.spark.ml.param.shared.HasCheckpointInterval;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.ml.util.DefaultParamsReader;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;

/* compiled from: BoostingClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u001dg\u0001B\u000f\u001f\u0001%B\u0001\"\u0011\u0001\u0003\u0006\u0004%\tE\u0011\u0005\t!\u0002\u0011\t\u0011)A\u0005\u0007\")\u0011\u000b\u0001C\u0001%\")A\u000b\u0001C\u0001+\")A\u000e\u0001C\u0001[\")1\u000f\u0001C\u0001i\")a\u000f\u0001C\u0001o\")\u0011\u0010\u0001C\u0001u\")\u0011\u000b\u0001C\u0001y\")Q\u0010\u0001C!}\"9\u0011q\u0002\u0001\u0005\u0002\u0005E\u0001bBA\u0011\u0001\u0011E\u00131\u0005\u0005\b\u0003\u001b\u0002A\u0011IA(\u000f\u001d\t9F\bE\u0001\u000332a!\b\u0010\t\u0002\u0005m\u0003BB)\u0010\t\u0003\ty\u0007C\u0004\u0002r=!\t%a\u001d\t\u000f\u0005mt\u0002\"\u0011\u0002~\u00199\u00111Q\b\u0001\u001f\u0005\u0015\u0005\"CAD'\t\u0005\t\u0015!\u00035\u0011\u0019\t6\u0003\"\u0001\u0002\n\"9\u0011\u0011S\n\u0005R\u0005MeABAO\u001f\u0011\ty\n\u0003\u0004R/\u0011\u0005\u0011\u0011\u0015\u0005\n\u0003K;\"\u0019!C\u0005\u0003OC\u0001\"a.\u0018A\u0003%\u0011\u0011\u0016\u0005\b\u0003w:B\u0011IA]\u0011%\tilDA\u0001\n\u0013\tyL\u0001\nC_>\u001cH/\u001b8h\u00072\f7o]5gS\u0016\u0014(BA\u0010!\u00039\u0019G.Y:tS\u001aL7-\u0019;j_:T!!\t\u0012\u0002\u00055d'BA\u0012%\u0003\u0015\u0019\b/\u0019:l\u0015\t)c%\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002O\u0005\u0019qN]4\u0004\u0001M!\u0001A\u000b\u001d<!\u0015YCF\f\u001b6\u001b\u0005q\u0012BA\u0017\u001f\u0005)\u0019E.Y:tS\u001aLWM\u001d\t\u0003_Ij\u0011\u0001\r\u0006\u0003c\u0001\na\u0001\\5oC2<\u0017BA\u001a1\u0005\u00191Vm\u0019;peB\u00111\u0006\u0001\t\u0003WYJ!a\u000e\u0010\u00037\t{wn\u001d;j]\u001e\u001cE.Y:tS\u001aL7-\u0019;j_:lu\u000eZ3m!\tY\u0013(\u0003\u0002;=\tA\"i\\8ti&twm\u00117bgNLg-[3s!\u0006\u0014\u0018-\\:\u0011\u0005qzT\"A\u001f\u000b\u0005y\u0002\u0013\u0001B;uS2L!\u0001Q\u001f\u0003\u00155cuK]5uC\ndW-A\u0002vS\u0012,\u0012a\u0011\t\u0003\t6s!!R&\u0011\u0005\u0019KU\"A$\u000b\u0005!C\u0013A\u0002\u001fs_>$hHC\u0001K\u0003\u0015\u00198-\u00197b\u0013\ta\u0015*\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u001d>\u0013aa\u0015;sS:<'B\u0001'J\u0003\u0011)\u0018
\u000e\u001a\u0011\u0002\rqJg.\u001b;?)\t!4\u000bC\u0003B\u0007\u0001\u00071)\u0001\btKR\u0014\u0015m]3MK\u0006\u0014h.\u001a:\u0015\u0005Y;V\"\u0001\u0001\t\u000ba#\u0001\u0019A-\u0002\u000bY\fG.^3\u0011\u0005iKgBA.g\u001d\taFM\u0004\u0002^G:\u0011aL\u0019\b\u0003?\u0006t!A\u00121\n\u0003\u001dJ!!\n\u0014\n\u0005\r\"\u0013BA\u0011#\u0013\t)\u0007%\u0001\u0005f]N,WN\u00197f\u0013\t9\u0007.A\u0004qC\u000e\\\u0017mZ3\u000b\u0005\u0015\u0004\u0013B\u00016l\u0005Y)en]3nE2,7\t\\1tg&4\u0017.\u001a:UsB,'BA4i\u0003I\u0019X\r\u001e(v[\n\u000b7/\u001a'fCJtWM]:\u0015\u0005Ys\u0007\"\u0002-\u0006\u0001\u0004y\u0007C\u00019r\u001b\u0005I\u0015B\u0001:J\u0005\rIe\u000e^\u0001\u0016g\u0016$8\t[3dWB|\u0017N\u001c;J]R,'O^1m)\t1V\u000fC\u0003Y\r\u0001\u0007q.\u0001\u0007tKR<V-[4ii\u000e{G\u000e\u0006\u0002Wq\")\u0001l\u0002a\u0001\u0007\u0006a1/\u001a;BY\u001e|'/\u001b;i[R\u0011ak\u001f\u0005\u00061\"\u0001\ra\u0011\u000b\u0002i\u0005!1m\u001c9z)\t!t\u0010C\u0004\u0002\u0002)\u0001\r!a\u0001\u0002\u000b\u0015DHO]1\u0011\t\u0005\u0015\u00111B\u0007\u0003\u0003\u000fQ1!!\u0003!\u0003\u0015\u0001\u0018M]1n\u0013\u0011\ti!a\u0002\u0003\u0011A\u000b'/Y7NCB\fQ!\u001a:s_J$b!a\u0005\u0002\u001a\u0005u\u0001c\u00019\u0002\u0016%\u0019\u0011qC%\u0003\r\u0011{WO\u00197f\u0011\u001d\tYb\u0003a\u0001\u0003'\tQ\u0001\\1cK2Dq!a\b\f\u0001\u0004\t\u0019\"\u0001\u0006qe\u0016$\u0017n\u0019;j_:\fQ\u0001\u001e:bS:$2!NA\u0013\u0011\u001d\t9\u0003\u0004a\u0001\u0003S\tq\u0001Z1uCN,G\u000f\r\u0003\u0002,\u0005m\u0002CBA\u0017\u0003g\t9$\u0004\u0002\u00020)\u0019\u0011\u0011\u0007\u0012\u0002\u0007M\fH.\u0003\u0003\u00026\u0005=\"a\u0002#bi\u0006\u001cX\r\u001e\t\u0005\u0003s\tY\u0004\u0004\u0001\u0005\u0019\u0005u\u0012QEA\u0001\u0002\u0003\u0015\t!a\u0010\u0003\u0007}#\u0013'\u0005\u0003\u0002B\u0005\u001d\u0003c\u00019\u0002D%\u0019\u0011QI%\u0003\u000f9{G\u000f[5oOB\u0019\u0001/!\u0013\n\u0007\u0005-\u0013JA\u0002B]f\fQa\u001e:ji\u0016,\"!!\u0015\u0011\u0007q\n\u0019&C\u0002\u0002Vu\u0012\u0001\"\u0014'
Xe&$XM]\u0001\u0013\u0005>|7\u000f^5oO\u000ec\u0017m]:jM&,'\u000f\u0005\u0002,\u001fM9q\"!\u0018\u0002d\u0005%\u0004c\u00019\u0002`%\u0019\u0011\u0011M%\u0003\r\u0005s\u0017PU3g!\u0011a\u0014Q\r\u001b\n\u0007\u0005\u001dTH\u0001\u0006N\u0019J+\u0017\rZ1cY\u0016\u00042\u0001]A6\u0013\r\ti'\u0013\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u000b\u0003\u00033\nAA]3bIV\u0011\u0011Q\u000f\t\u0005y\u0005]D'C\u0002\u0002zu\u0012\u0001\"\u0014'SK\u0006$WM]\u0001\u0005Y>\fG\rF\u00025\u0003\u007fBa!!!\u0013\u0001\u0004\u0019\u0015\u0001\u00029bi\"\u0014\u0001DQ8pgRLgnZ\"mCN\u001c\u0018NZ5fe^\u0013\u0018\u000e^3s'\r\u0019\u0012\u0011K\u0001\tS:\u001cH/\u00198dKR!\u00111RAH!\r\tiiE\u0007\u0002\u001f!1\u0011qQ\u000bA\u0002Q\n\u0001b]1wK&k\u0007\u000f\u001c\u000b\u0005\u0003+\u000bY\nE\u0002q\u0003/K1!!'J\u0005\u0011)f.\u001b;\t\r\u0005\u0005e\u00031\u0001D\u0005a\u0011un\\:uS:<7\t\\1tg&4\u0017.\u001a:SK\u0006$WM]\n\u0004/\u0005UDCAAR!\r\tiiF\u0001\nG2\f7o\u001d(b[\u0016,\"!!+\u0011\t\u0005-\u0016QW\u0007\u0003\u0003[SA!a,\u00022\u0006!A.\u00198h\u0015\t\t\u0019,\u0001\u0003kCZ\f\u0017b\u0001(\u0002.\u0006Q1\r\\1tg:\u000bW.\u001a\u0011\u0015\u0007Q\nY\f\u0003\u0004\u0002\u0002n\u0001\raQ\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002BB!\u00111VAb\u0013\u0011\t)-!,\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:org/apache/spark/ml/classification/BoostingClassifier.class */
/**
 * Boosting meta-estimator for classification.
 *
 * <p>The estimator repeatedly fits a configurable base {@link Classifier} on re-weighted training
 * instances. It then combines the fitted models and their estimator weights into a
 * {@link BoostingClassificationModel}.
 *
 * <p>The {@code algorithm} param selects the boosting variant:
 * <ul>
 *   <li>{@code "real"} requires the base learner to produce a
 *       {@link ProbabilisticClassificationModel}. Every fitted model gets estimator weight 1.0, and
 *       instance weights are updated from the predicted class probabilities (SAMME.R-style update,
 *       see {@link #$anonfun$train$9}).</li>
 *   <li>{@code "discrete"} accepts any {@link ClassificationModel}. Each fitted model gets estimator
 *       weight {@code log(1 / beta)} with {@code beta = err / ((1 - err) * (numClasses - 1))}, and
 *       misclassified instances are up-weighted by {@code 1 / beta} (SAMME-style update).</li>
 * </ul>
 *
 * <p>This Java source was decompiled from {@code BoostingClassifier.scala}. Most params are
 * declared in Scala traits and are wired up by the trait initializers called from the constructor.
 */
public class BoostingClassifier extends Classifier<Vector, BoostingClassifier, BoostingClassificationModel> implements BoostingClassifierParams, MLWritable {
    private final String uid;

    // The params below are created by the Scala trait initializers invoked from the constructor,
    // which hand them back through the *_setter_$*_$eq callbacks. They are meant to be assigned
    // once, during construction. They cannot be final in Java because the assignment happens in
    // those callbacks rather than directly in the constructor body.
    private Param<String> algorithm;
    private IntParam aggregationDepth;
    private IntParam checkpointInterval;
    private Param<Classifier<Vector, ? extends Classifier<Vector, Classifier, ClassificationModel>, ? extends ClassificationModel<Vector, ClassificationModel>>> baseLearner;
    private Param<String> weightCol;
    private Param<Object> numBaseLearners;

    /**
     * {@link MLReader} that restores a {@link BoostingClassifier} saved by
     * {@link BoostingClassifierWriter}.
     *
     * <p>Private in the Scala source; it is public here only as a decompilation artifact.
     */
    public static class BoostingClassifierReader extends MLReader<BoostingClassifier> {
        private final String className = BoostingClassifier.class.getName();

        private String className() {
            return this.className;
        }

        /**
         * Loads the estimator stored under {@code path}.
         *
         * <p>Params are restored from the saved metadata, and the base learner returned by
         * {@code BoostingClassifierParams.loadImpl} is re-attached.
         *
         * @param path directory the estimator was saved to
         * @return the restored estimator, carrying the uid recorded in the metadata
         * @throws MatchError if {@code loadImpl} returns {@code null}
         */
        @Override
        public BoostingClassifier load(String path) {
            Tuple2<DefaultParamsReader.Metadata, Classifier<Vector, ? extends Classifier<Vector, Classifier, ClassificationModel>, ? extends ClassificationModel<Vector, ClassificationModel>>> loaded =
                    BoostingClassifierParams$.MODULE$.loadImpl(path, sc(), className());
            if (loaded == null) {
                throw new MatchError(loaded);
            }
            DefaultParamsReader.Metadata metadata = loaded._1();
            BoostingClassifier estimator = new BoostingClassifier(metadata.uid());
            metadata.getAndSetParams(estimator, metadata.getAndSetParams$default$2());
            return estimator.setBaseLearner(loaded._2());
        }

        /**
         * Decompiler-generated alias of {@link #load(String)}, kept for source compatibility.
         *
         * @deprecated use {@link #load(String)}.
         */
        @Deprecated
        public BoostingClassifier m41load(String path) {
            return load(path);
        }
    }

    /** {@link MLWriter} that persists a {@link BoostingClassifier} via {@code BoostingClassifierParams.saveImpl}. */
    public static class BoostingClassifierWriter extends MLWriter {
        private final BoostingClassifier instance;

        /**
         * @param instance the estimator to persist
         */
        public BoostingClassifierWriter(BoostingClassifier instance) {
            this.instance = instance;
        }

        /**
         * Writes the wrapped estimator to {@code path}.
         *
         * @param path destination directory
         */
        @Override
        public void saveImpl(String path) {
            BoostingClassifierParams$.MODULE$.saveImpl(this.instance, path, sc(), BoostingClassifierParams$.MODULE$.saveImpl$default$4());
        }
    }

    /** Loads a saved estimator; delegates to the companion object's reader. */
    public static BoostingClassifier load(String str) {
        return BoostingClassifier$.MODULE$.m40load(str);
    }

    /** Returns the companion object's {@link MLReader} for this estimator. */
    public static MLReader<BoostingClassifier> read() {
        return BoostingClassifier$.MODULE$.read();
    }

    /** Saves this estimator to {@code str} via the {@link MLWritable} trait implementation. */
    @Override
    public void save(String str) throws IOException {
        MLWritable.save$(this, str);
    }

    /**
     * Returns the value of the {@code algorithm} param.
     *
     * <p>Delegates to the Scala trait's static forwarder. Calling {@code getAlgorithm()} here would
     * recurse forever.
     */
    @Override // org.apache.spark.ml.classification.BoostingClassifierParams
    public String getAlgorithm() {
        return BoostingClassifierParams.getAlgorithm$(this);
    }

    public final int getAggregationDepth() {
        return HasAggregationDepth.getAggregationDepth$(this);
    }

    public final int getCheckpointInterval() {
        return HasCheckpointInterval.getCheckpointInterval$(this);
    }

    /** Returns the configured base learner, via the {@code HasBaseLearner} trait implementation. */
    @Override // org.apache.spark.ml.ensemble.HasBaseLearner
    public Predictor getBaseLearner() {
        return HasBaseLearner.getBaseLearner$(this);
    }

    /**
     * Fits {@code predictor} on {@code dataset} using the given column names.
     *
     * <p>Delegates to the {@code HasBaseLearner} trait implementation.
     *
     * @param predictor the (unfitted) base learner
     * @param labelColName label column name
     * @param featuresColName features column name
     * @param predictionColName prediction column name
     * @param weightColName optional instance-weight column name
     * @param dataset training data
     * @return the fitted base model
     */
    @Override // org.apache.spark.ml.ensemble.HasBaseLearner
    public PredictionModel fitBaseLearner(Predictor predictor, String labelColName, String featuresColName, String predictionColName, Option weightColName, Dataset dataset) {
        return HasBaseLearner.fitBaseLearner$(this, predictor, labelColName, featuresColName, predictionColName, weightColName, dataset);
    }

    public final String getWeightCol() {
        return HasWeightCol.getWeightCol$(this);
    }

    /** Returns the {@code numBaseLearners} param value, via the {@code HasNumBaseLearners} trait implementation. */
    @Override // org.apache.spark.ml.ensemble.HasNumBaseLearners
    public int getNumBaseLearners() {
        return HasNumBaseLearners.getNumBaseLearners$(this);
    }

    @Override // org.apache.spark.ml.classification.BoostingClassifierParams
    public Param<String> algorithm() {
        return this.algorithm;
    }

    @Override // org.apache.spark.ml.classification.BoostingClassifierParams
    public void org$apache$spark$ml$classification$BoostingClassifierParams$_setter_$algorithm_$eq(Param<String> param) {
        this.algorithm = param;
    }

    public final IntParam aggregationDepth() {
        return this.aggregationDepth;
    }

    public final void org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_$eq(IntParam intParam) {
        this.aggregationDepth = intParam;
    }

    public final IntParam checkpointInterval() {
        return this.checkpointInterval;
    }

    public final void org$apache$spark$ml$param$shared$HasCheckpointInterval$_setter_$checkpointInterval_$eq(IntParam intParam) {
        this.checkpointInterval = intParam;
    }

    @Override // org.apache.spark.ml.ensemble.HasBaseLearner
    public Param<Classifier<Vector, ? extends Classifier<Vector, Classifier, ClassificationModel>, ? extends ClassificationModel<Vector, ClassificationModel>>> baseLearner() {
        return this.baseLearner;
    }

    @Override // org.apache.spark.ml.ensemble.HasBaseLearner
    public void org$apache$spark$ml$ensemble$HasBaseLearner$_setter_$baseLearner_$eq(Param<Classifier<Vector, ? extends Classifier<Vector, Classifier, ClassificationModel>, ? extends ClassificationModel<Vector, ClassificationModel>>> param) {
        this.baseLearner = param;
    }

    public final Param<String> weightCol() {
        return this.weightCol;
    }

    public final void org$apache$spark$ml$param$shared$HasWeightCol$_setter_$weightCol_$eq(Param<String> param) {
        this.weightCol = param;
    }

    @Override // org.apache.spark.ml.ensemble.HasNumBaseLearners
    public Param<Object> numBaseLearners() {
        return this.numBaseLearners;
    }

    @Override // org.apache.spark.ml.ensemble.HasNumBaseLearners
    public void org$apache$spark$ml$ensemble$HasNumBaseLearners$_setter_$numBaseLearners_$eq(Param<Object> param) {
        this.numBaseLearners = param;
    }

    @Override
    public String uid() {
        return this.uid;
    }

    /** Sets the base classifier that is fitted in every boosting round. */
    public BoostingClassifier setBaseLearner(Classifier<Vector, ? extends Classifier<Vector, Classifier, ClassificationModel>, ? extends ClassificationModel<Vector, ClassificationModel>> classifier) {
        return (BoostingClassifier) set(baseLearner(), classifier);
    }

    /** Sets the maximum number of boosting rounds (validated to be at least 1). */
    public BoostingClassifier setNumBaseLearners(int i) {
        return (BoostingClassifier) set(numBaseLearners(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets how often (in rounds) the instance-weight RDD is checkpointed. */
    public BoostingClassifier setCheckpointInterval(int i) {
        return (BoostingClassifier) set(checkpointInterval(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets the instance-weight column. */
    public BoostingClassifier setWeightCol(String str) {
        return (BoostingClassifier) set(weightCol(), str);
    }

    /** Sets the boosting variant ({@code "real"} or {@code "discrete"}). */
    public BoostingClassifier setAlgorithm(String str) {
        return (BoostingClassifier) set(algorithm(), str);
    }

    /**
     * Returns a copy of this estimator with the same uid and {@code extra} applied on top of its
     * params.
     *
     * <p>The base learner is copied with the same {@code extra}, so the two estimators do not share
     * one base-learner instance.
     */
    @Override
    public BoostingClassifier copy(ParamMap extra) {
        BoostingClassifier copied = new BoostingClassifier(uid());
        copyValues(copied, extra);
        return copied.setBaseLearner((Classifier) copied.getBaseLearner().copy(extra));
    }

    /**
     * Decompiler-generated alias of {@link #copy(ParamMap)}, kept for source compatibility.
     *
     * @deprecated use {@link #copy(ParamMap)}.
     */
    @Deprecated
    public BoostingClassifier m38copy(ParamMap paramMap) {
        return copy(paramMap);
    }

    /** 0/1 loss: {@code 1.0} when {@code label != prediction}, otherwise {@code 0.0}. */
    public double error(double label, double prediction) {
        return label != prediction ? 1.0d : 0.0d;
    }

    /**
     * Runs the boosting loop and returns the fitted ensemble.
     *
     * <p>Each round does four things:
     * <ol>
     *   <li>normalises the instance weights to sum to 1;</li>
     *   <li>fits the base learner on the re-weighted data;</li>
     *   <li>records the model and its estimator weight;</li>
     *   <li>updates the instance weights according to {@code algorithm}.</li>
     * </ol>
     *
     * <p>Training stops early in three cases:
     * <ul>
     *   <li>a round has zero weighted error;</li>
     *   <li>a discrete round is no better than chance ({@code err >= 1 - 1/numClasses}), in which
     *       case that round's model is discarded;</li>
     *   <li>the instance weights no longer sum to a positive value.</li>
     * </ul>
     *
     * <p>NOTE(review): the RDD closures below are Java lambdas passed to Scala {@code FunctionN}
     * parameters. Confirm they are serializable in the deployed build. Spark ships them to
     * executors.
     *
     * @param dataset training data with label, features and optional weight columns
     * @return the fitted {@link BoostingClassificationModel}
     * @throws SparkException (rethrown unchecked) if {@code algorithm} is not compatible with the
     *     fitted base model
     */
    @Override
    public BoostingClassificationModel train(Dataset<?> dataset) {
        return (BoostingClassificationModel) Instrumentation$.MODULE$.instrumented(instr -> {
            instr.logPipelineStage(this);
            instr.logDataset(dataset);
            instr.logParams(this, Predef$.MODULE$.wrapRefArray(new Param[]{labelCol(), featuresCol(), predictionCol(), weightCol(), algorithm(), numBaseLearners(), checkpointInterval()}));

            SparkSession spark = dataset.sparkSession();
            SparkContext sc = spark.sparkContext();
            int numClasses = getNumClasses(dataset, getNumClasses$default$2());
            instr.logNumClasses(numClasses);
            validateNumClasses(numClasses);

            RDD<Instance> instances = extractInstances(dataset, instance -> {
                $anonfun$train$2(this, numClasses, instance);
                return BoxedUnit.UNIT;
            });
            instances.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
            try {
                // Materialise the cache once; every boosting round re-reads these instances.
                instances.count();

                Metadata featuresMetadata = Utils$.MODULE$.getFeaturesMetadata(dataset, (String) $(featuresCol()), Utils$.MODULE$.getFeaturesMetadata$default$3());
                int maxLearners = BoxesRunTime.unboxToInt($(numBaseLearners()));
                int depth = BoxesRunTime.unboxToInt($(aggregationDepth()));
                String algo = (String) $(algorithm());

                PredictionModel[] models = new PredictionModel[maxLearners];
                double[] estimatorWeights = new double[maxLearners];

                RDD<Object> weights = instances.map(instance -> BoxesRunTime.boxToDouble(instance.weight()), ClassTag$.MODULE$.Double());
                PeriodicRDDCheckpointer<Object> checkpointer = new PeriodicRDDCheckpointer<>(BoxesRunTime.unboxToInt($(checkpointInterval())), sc, StorageLevel$.MODULE$.MEMORY_AND_DISK());
                try {
                    checkpointer.update(weights);
                    double sumWeights = BoxesRunTime.unboxToDouble(weights.treeReduce(BoostingClassifier::sumDoubles, depth));

                    int i = 0;
                    boolean done = false;
                    while (i < maxLearners && !done && sumWeights > 0) {
                        instr.logNamedValue("iteration", i);

                        // Snapshot this round's normaliser in an effectively-final local. RDDs are
                        // lazy, and a mutable holder would be read whenever Spark (re)computes this
                        // lineage, possibly after a later round has overwritten it.
                        final double norm = sumWeights;
                        RDD<Instance> normalized = instances.zip(weights, ClassTag$.MODULE$.Double()).map(pair -> {
                            if (pair == null) {
                                throw new MatchError(pair);
                            }
                            Instance inst = pair._1();
                            return new Instance(inst.label(), pair._2$mcD$sp() / norm, inst.features());
                        }, ClassTag$.MODULE$.apply(Instance.class));

                        PredictionModel model = fitBaseLearner((Predictor) $(baseLearner()), "label", "features", (String) $(predictionCol()), new Some("weight"), toTrainingFrame(spark, normalized, featuresMetadata));

                        if ("real".equals(algo) && model instanceof ProbabilisticClassificationModel) {
                            ProbabilisticClassificationModel probModel = (ProbabilisticClassificationModel) model;
                            RDD<Vector> probabilities = normalized.map(inst -> probModel.predictProbability(inst.features()), ClassTag$.MODULE$.apply(Vector.class));
                            double weightedError = BoxesRunTime.unboxToDouble(normalized.zip(probabilities, ClassTag$.MODULE$.apply(Vector.class)).treeAggregate(
                                    BoxesRunTime.boxToDouble(0.0d),
                                    (acc, pair) -> BoxesRunTime.boxToDouble($anonfun$train$7(this, BoxesRunTime.unboxToDouble(acc), pair)),
                                    BoostingClassifier::sumDoubles,
                                    depth,
                                    ClassTag$.MODULE$.Double()));
                            if (weightedError <= 0) {
                                done = true;
                            }
                            estimatorWeights[i] = 1.0d;
                            models[i] = probModel;
                            weights = normalized.zip(probabilities, ClassTag$.MODULE$.apply(Vector.class))
                                    .map(pair -> BoxesRunTime.boxToDouble($anonfun$train$9(numClasses, pair)), ClassTag$.MODULE$.Double());
                        } else if ("discrete".equals(algo) && model instanceof ClassificationModel) {
                            ClassificationModel classModel = (ClassificationModel) model;
                            RDD<Object> errors = normalized.map(inst -> BoxesRunTime.boxToDouble($anonfun$train$10(this, classModel, inst)), ClassTag$.MODULE$.Double());
                            double weightedError = BoxesRunTime.unboxToDouble(normalized.zip(errors, ClassTag$.MODULE$.Double()).treeAggregate(
                                    BoxesRunTime.boxToDouble(0.0d),
                                    (acc, pair) -> BoxesRunTime.boxToDouble($anonfun$train$11(BoxesRunTime.unboxToDouble(acc), pair)),
                                    BoostingClassifier::sumDoubles,
                                    depth,
                                    ClassTag$.MODULE$.Double()));
                            if (weightedError <= 0) {
                                done = true;
                            }
                            double beta = weightedError / ((1 - weightedError) * (numClasses - 1));
                            // A perfect learner (beta == 0) would get log(1/0) = +Infinity; cap it at 1.0.
                            // Training stops after this round anyway.
                            estimatorWeights[i] = beta == 0.0d ? 1.0d : Math.log(1.0d / beta);
                            models[i] = classModel;
                            if (weightedError >= 1.0d - (1.0d / numClasses)) {
                                // No better than chance: offset the i++ below so take(i) drops this model.
                                i--;
                                done = true;
                            }
                            weights = normalized.zip(errors, ClassTag$.MODULE$.Double())
                                    .map(pair -> BoxesRunTime.boxToDouble($anonfun$train$13(beta, pair)), ClassTag$.MODULE$.Double());
                        } else {
                            throw BoostingClassifier.<RuntimeException>sneakyThrow(new SparkException(
                                    "algorithm \"" + algo + "\" is not compatible with base learner \"" + $(baseLearner()) + "\"."));
                        }

                        checkpointer.update(weights);
                        sumWeights = BoxesRunTime.unboxToDouble(weights.treeReduce(BoostingClassifier::sumDoubles, depth));
                        i++;
                    }
                    return new BoostingClassificationModel(numClasses, Arrays.copyOf(estimatorWeights, i), Arrays.copyOf(models, i));
                } finally {
                    checkpointer.unpersistDataSet();
                    checkpointer.deleteAllCheckpoints();
                }
            } finally {
                instances.unpersist(instances.unpersist$default$1());
            }
        });
    }

    /**
     * Builds the DataFrame a base learner is fitted on.
     *
     * <p>The columns are the {@link Instance} fields ({@code label}, {@code weight},
     * {@code features}), and {@code featuresMetadata} is re-attached to the {@code features} column.
     */
    private static Dataset<?> toTrainingFrame(SparkSession spark, RDD<Instance> instances, Metadata featuresMetadata) {
        return spark.createDataFrame(instances, package$.MODULE$.universe().TypeTag().apply(
                        package$.MODULE$.universe().runtimeMirror(BoostingClassifier.class.getClassLoader()),
                        new TypeCreator() {
                            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                                return mirror.staticClass("org.apache.spark.ml.feature.Instance").asType().toTypeConstructor();
                            }
                        }))
                .withColumn("features", functions$.MODULE$.col("features"), featuresMetadata);
    }

    /** Adds two boxed doubles; the reduce/combine function for tree aggregations over {@code RDD[Double]}. */
    private static Object sumDoubles(Object left, Object right) {
        return BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(left) + BoxesRunTime.unboxToDouble(right));
    }

    /**
     * Throws {@code error} without Java's checked-exception analysis.
     *
     * <p>Scala has no checked exceptions, so code running inside a Scala function literal (here the
     * {@code Instrumentation.instrumented} body) must rethrow a checked {@link SparkException} this
     * way. That keeps the exception type Scala-side callers observe.
     */
    @SuppressWarnings("unchecked")
    private static <E extends Throwable> RuntimeException sneakyThrow(Throwable error) throws E {
        throw (E) error;
    }

    @Override
    public MLWriter write() {
        return new BoostingClassifierWriter(this);
    }

    /**
     * Decompiler-generated bridge for {@link #train(Dataset)}, kept for source compatibility.
     *
     * @deprecated use {@link #train(Dataset)}.
     */
    @Deprecated
    public PredictionModel m34train(Dataset dataset) {
        return train((Dataset<?>) dataset);
    }

    /** Validates one instance's label against {@code numClasses}; the {@code extractInstances} callback. */
    public static final void $anonfun$train$2(BoostingClassifier classifier, int numClasses, Instance instance) {
        classifier.validateLabel(instance.label(), numClasses);
    }

    /**
     * Seq-op for the "real" weighted error.
     *
     * <p>Adds {@code weight * 0/1-loss(label, argmax(probability))} to {@code acc}, where
     * {@code pair} is {@code (Instance, Vector probability)}.
     */
    public static final double $anonfun$train$7(BoostingClassifier classifier, double acc, Tuple2 pair) {
        if (pair == null) {
            throw new MatchError(new Tuple2(BoxesRunTime.boxToDouble(acc), pair));
        }
        Instance instance = (Instance) pair._1();
        Vector probability = (Vector) pair._2();
        return acc + instance.weight() * classifier.error(instance.label(), probability.argmax());
    }

    /**
     * SAMME.R-style instance-weight update for one {@code (Instance, Vector probability)} pair.
     *
     * <p>The result is {@code weight * exp(-((K-1)/K) * sum_k y_k * log(max(p_k, EPSILON)))}, where
     * {@code y_k = 1} for the true class and {@code -1/(K-1)} otherwise.
     */
    public static final double $anonfun$train$9(int numClasses, Tuple2 pair) {
        if (pair == null) {
            throw new MatchError(pair);
        }
        Instance instance = (Instance) pair._1();
        Vector probability = (Vector) pair._2();
        double margin = 0.0d;
        for (int k = 0; k < numClasses; k++) {
            double coding = instance.label() == (double) k ? 1.0d : -1.0d / (numClasses - 1.0d);
            margin += coding * Math.log(Math.max(probability.apply(k), org.apache.spark.ml.impl.Utils$.MODULE$.EPSILON()));
        }
        return instance.weight() * Math.exp(-((numClasses - 1.0d) / numClasses) * margin);
    }

    /** 0/1 loss of {@code model}'s hard prediction for {@code instance}. */
    public static final double $anonfun$train$10(BoostingClassifier classifier, ClassificationModel model, Instance instance) {
        return classifier.error(instance.label(), model.predict(instance.features()));
    }

    /** Seq-op for the "discrete" weighted error: adds {@code weight * error} for an {@code (Instance, error)} pair. */
    public static final double $anonfun$train$11(double acc, Tuple2 pair) {
        if (pair == null) {
            throw new MatchError(new Tuple2(BoxesRunTime.boxToDouble(acc), pair));
        }
        Instance instance = (Instance) pair._1();
        return acc + instance.weight() * pair._2$mcD$sp();
    }

    /** SAMME-style weight update: {@code weight * (1/beta)^error} for an {@code (Instance, error)} pair. */
    public static final double $anonfun$train$13(double beta, Tuple2 pair) {
        if (pair == null) {
            throw new MatchError(pair);
        }
        return ((Instance) pair._1()).weight() * Math.pow(1 / beta, pair._2$mcD$sp());
    }

    /**
     * Creates an estimator with the given uid.
     *
     * <p>Runs the Scala trait initializers in their compiled order; they populate the param fields
     * through the {@code *_setter_$*_$eq} callbacks above.
     */
    public BoostingClassifier(String str) {
        this.uid = str;
        org$apache$spark$ml$ensemble$HasNumBaseLearners$_setter_$numBaseLearners_$eq(new IntParam(this, "numBaseLearners", "number of base learners that will be used by the ensemble learner", ParamValidators$.MODULE$.gtEq(1.0d)));
        HasWeightCol.$init$(this);
        org$apache$spark$ml$ensemble$HasBaseLearner$_setter_$baseLearner_$eq(new Param<>(this, "baseLearner", "base learner that will be used by the ensemble learner"));
        HasCheckpointInterval.$init$(this);
        HasAggregationDepth.$init$(this);
        BoostingParams.$init$((BoostingParams) this);
        BoostingClassifierParams.$init$((BoostingClassifierParams) this);
        MLWritable.$init$(this);
    }

    /** Creates an estimator with a random {@code "BoostingClassifier"}-prefixed uid. */
    public BoostingClassifier() {
        this(Identifiable$.MODULE$.randomUID("BoostingClassifier"));
    }
}
