package org.apache.spark.mllib.tree;

import org.apache.spark.Logging;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impl.TimeTracker;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.slf4j.Logger;
import scala.Function0;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;

/* compiled from: GradientBoostedTrees.scala */
/* loaded from: input_file:org/apache/spark/mllib/tree/GradientBoostedTrees$.class */
/**
 * Compiled module class of the Scala {@code object GradientBoostedTrees}: static-style
 * {@code train} entry points plus the shared gradient-boosting loop.
 *
 * <p>NOTE(review): this is decompiler output. Constructs such as {@code Logging.class.logInfo(this, ...)}
 * (the Scala trait implementation class) and the assignment to the {@code final} field
 * {@code MODULE$} inside the constructor mirror the bytecode and are not valid Java source as written.
 */
public final class GradientBoostedTrees$ implements Logging, Serializable {
    /** Singleton instance; assigned by the private constructor that the static initializer runs. */
    public static final GradientBoostedTrees$ MODULE$ = null;
    /** Backing field for the logger managed by the {@code Logging} trait. */
    private transient Logger org$apache$spark$Logging$$log_;

    static {
        // The constructor stores itself into MODULE$.
        new GradientBoostedTrees$();
    }

    // ---------------------------------------------------------------------------------------
    // Logging trait plumbing, as emitted by scalac: accessors for the logger field and
    // forwarders to the trait implementation class. Kept verbatim.
    // ---------------------------------------------------------------------------------------

    public Logger org$apache$spark$Logging$$log_() {
        return this.org$apache$spark$Logging$$log_;
    }

    public void org$apache$spark$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$Logging$$log_ = logger;
    }

    public String logName() {
        return Logging.class.logName(this);
    }

    public Logger log() {
        return Logging.class.log(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.class.logInfo(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.class.logDebug(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.class.logTrace(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.class.logWarning(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.class.logError(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.class.logInfo(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.class.logDebug(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.class.logTrace(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.class.logWarning(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.class.logError(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled(this);
    }

    /**
     * Trains a gradient-boosted trees model on {@code rdd}.
     *
     * @param rdd              training data
     * @param boostingStrategy boosting parameters
     * @return the result of {@code new GradientBoostedTrees(boostingStrategy).run(rdd)}
     */
    public GradientBoostedTreesModel train(RDD<LabeledPoint> rdd, BoostingStrategy boostingStrategy) {
        return new GradientBoostedTrees(boostingStrategy).run(rdd);
    }

    /**
     * Java-friendly overload of {@link #train(RDD, BoostingStrategy)}: unwraps the
     * {@code JavaRDD} and delegates.
     *
     * @param javaRDD          training data
     * @param boostingStrategy boosting parameters
     * @return the trained model
     */
    public GradientBoostedTreesModel train(JavaRDD<LabeledPoint> javaRDD, BoostingStrategy boostingStrategy) {
        return train(javaRDD.rdd(), boostingStrategy);
    }

    /**
     * Runs the gradient-boosting loop. The name-mangled method name suggests a Scala
     * {@code private[tree]}-style member; it is presumably invoked from the companion class
     * {@code GradientBoostedTrees}.
     *
     * <p>Tree 0 is fit on {@code input} with weight {@code 1.0}. Each later tree is fit on an RDD
     * derived by zipping the current prediction/error pairs with {@code input} and mapping through
     * a {@code loss}-dependent function. That mapper is defined in an anonymous class not shown
     * here; presumably it produces pseudo-residual labels. Later trees get weight
     * {@code learningRate}. Every base learner is trained as a regression tree with variance
     * impurity. The configured {@code treeStrategy().algo()} is applied only to the returned
     * ensemble.
     *
     * <p>When {@code validate} is true, the mean of the validation error values is recomputed
     * after every tree. Training stops as soon as the improvement over the best error seen so far
     * is below {@code validationTol}, which includes the case where the error got worse. The
     * returned ensemble is truncated to the prefix that achieved the best validation error. When
     * {@code validate} is false, all {@code numIterations} trees are returned.
     *
     * <p>If {@code input} had storage level {@code NONE} on entry, it is persisted at
     * {@code MEMORY_AND_DISK} for the duration of training. It is unpersisted before returning on
     * every exit path: normal completion, early stop, or exception.
     *
     * @param input            training data
     * @param validationInput  validation data; only used when {@code validate} is true
     * @param boostingStrategy boosting parameters (iterations, loss, learning rate, tree strategy,
     *                         validation tolerance); validated via {@code assertValid()}
     * @param validate         whether to use {@code validationInput} for early stopping
     * @return the trained ensemble, tagged with {@code boostingStrategy.treeStrategy().algo()}
     */
    public GradientBoostedTreesModel org$apache$spark$mllib$tree$GradientBoostedTrees$$boost(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, boolean validate) {
        TimeTracker timeTracker = new TimeTracker();
        timeTracker.start("total");
        timeTracker.start("init");
        boostingStrategy.assertValid();

        int numIterations = boostingStrategy.numIterations();
        DecisionTreeModel[] baseLearners = new DecisionTreeModel[numIterations];
        double[] baseLearnerWeights = new double[numIterations];
        Loss loss = boostingStrategy.loss();
        double learningRate = boostingStrategy.learningRate();
        double validationTol = boostingStrategy.validationTol();

        // Base learners are always regression trees on variance impurity; work on a copy so the
        // caller's strategy object is not mutated.
        Strategy treeStrategy = boostingStrategy.treeStrategy().copy();
        treeStrategy.algo_$eq(Algo$.MODULE$.Regression());
        treeStrategy.impurity_$eq(Variance$.MODULE$);
        treeStrategy.assertValid();

        // Cache the input only if the caller has not already chosen a storage level for it.
        StorageLevel storageLevel = input.getStorageLevel();
        StorageLevel NONE = StorageLevel$.MODULE$.NONE();
        boolean persistedInput = storageLevel == null ? NONE == null : storageLevel.equals(NONE);
        if (persistedInput) {
            input.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
        }

        try {
            timeTracker.stop("init");
            logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$1());
            logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$2());
            logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$3());

            // Tree 0: fit directly on the input with weight 1.0.
            timeTracker.start("building tree 0");
            DecisionTreeModel firstTree = new DecisionTree(treeStrategy).run(input);
            baseLearners[0] = firstTree;
            baseLearnerWeights[0] = 1.0d;
            // Boxed in an ObjectRef because the debug-logging closures capture it by reference.
            ObjectRef predError = ObjectRef.create(GradientBoostedTreesModel$.MODULE$.computeInitialPredictionAndError(input, 1.0d, firstTree, loss));
            logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$4(predError));
            timeTracker.stop("building tree 0");

            // Validation state is only built and evaluated when early stopping is requested.
            RDD<Tuple2<Object, Object>> validatePredError = null;
            double bestValidateError = 0.0d;
            if (validate) {
                validatePredError = GradientBoostedTreesModel$.MODULE$.computeInitialPredictionAndError(validationInput, 1.0d, firstTree, loss);
                bestValidateError = meanError(validatePredError);
            }
            // Number of leading trees that achieved the best validation error so far.
            int bestM = 1;

            RDD<LabeledPoint> data = ((RDD) predError.elem).zip(input, ClassTag$.MODULE$.apply(LabeledPoint.class)).map(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$5(loss), ClassTag$.MODULE$.apply(LabeledPoint.class));
            // Boxed in an IntRef because a debug-logging closure captures the iteration index.
            IntRef m = IntRef.create(1);
            while (m.elem < numIterations) {
                String treeLabel = "building tree " + m.elem;
                timeTracker.start(treeLabel);
                logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$6());
                logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$7(m));
                logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$8());
                DecisionTreeModel tree = new DecisionTree(treeStrategy).run(data);
                timeTracker.stop(treeLabel);
                baseLearners[m.elem] = tree;
                baseLearnerWeights[m.elem] = learningRate;

                predError.elem = GradientBoostedTreesModel$.MODULE$.updatePredictionError(input, (RDD) predError.elem, learningRate, tree, loss);
                logDebug(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$9(predError));

                if (validate) {
                    validatePredError = GradientBoostedTreesModel$.MODULE$.updatePredictionError(validationInput, validatePredError, learningRate, tree, loss);
                    double currentValidateError = meanError(validatePredError);
                    if (bestValidateError - currentValidateError < validationTol) {
                        // Improvement too small (or the error grew): stop. Use break rather
                        // than return so the cleanup below (timers, logging, unpersist) still runs.
                        break;
                    }
                    if (currentValidateError < bestValidateError) {
                        bestValidateError = currentValidateError;
                        bestM = m.elem + 1;
                    }
                }

                data = ((RDD) predError.elem).zip(input, ClassTag$.MODULE$.apply(LabeledPoint.class)).map(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$10(loss), ClassTag$.MODULE$.apply(LabeledPoint.class));
                m.elem++;
            }

            timeTracker.stop("total");
            logInfo(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$11());
            logInfo(new GradientBoostedTrees$$anonfun$org$apache$spark$mllib$tree$GradientBoostedTrees$$boost$12(timeTracker));

            if (validate) {
                // Keep only the prefix of trees that achieved the best validation error.
                return new GradientBoostedTreesModel(boostingStrategy.treeStrategy().algo(), (DecisionTreeModel[]) Predef$.MODULE$.refArrayOps(baseLearners).slice(0, bestM), (double[]) Predef$.MODULE$.doubleArrayOps(baseLearnerWeights).slice(0, bestM));
            }
            return new GradientBoostedTreesModel(boostingStrategy.treeStrategy().algo(), baseLearners, baseLearnerWeights);
        } finally {
            // Release the cache we added, on every exit path (including early stop and failures).
            if (persistedInput) {
                input.unpersist(input.unpersist$default$1());
            }
        }
    }

    /**
     * Returns the mean of the value (second) components of a pair RDD. Here that is the error
     * component of the (prediction, error) pairs built by {@code GradientBoostedTreesModel};
     * the pair layout is presumed from that method's name.
     *
     * @param predError pair RDD whose values are doubles
     * @return mean of the values
     */
    private double meanError(RDD<Tuple2<Object, Object>> predError) {
        return RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(predError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), Ordering$Double$.MODULE$).values()).mean();
    }

    /** Preserves singleton identity across Java deserialization by returning {@code MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Registers this instance as {@code MODULE$} and runs the {@code Logging} trait initializer. */
    private GradientBoostedTrees$() {
        MODULE$ = this;
        Logging.class.$init$(this);
    }
}
