package org.apache.spark.ml.boosting;

import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.PredictorParams;
import org.apache.spark.ml.ensemble.HasBaseLearner;
import org.apache.spark.ml.ensemble.HasNumBaseLearners;
import org.apache.spark.ml.ensemble.HasNumRound;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.shared.HasSeed;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.param.shared.HasValidationIndicatorCol;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.ml.regression.BoostingRegressionModel;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Iterator;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: BoostingParams.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Eh\u0001\u0003\b\u0010!\u0003\r\t!E\r\t\u000b\u0005\u0003A\u0011A\"\t\u000b\u001d\u0003A\u0011\u0003%\t\r\u001d\u0003A\u0011CA\u0007\u0011\u001d\t9\u0003\u0001C\t\u0003SAq!a\u000f\u0001\t#\ti\u0004C\u0004\u0002R\u0001!\t\"a\u0015\t\u000f\u0005-\u0003\u0001\"\u0005\u0002^!I\u0011Q\r\u0001\u0012\u0002\u0013E\u0011q\r\u0005\b\u0003{\u0002A\u0011CA@\u0011\u001d\t\u0019\t\u0001C\t\u0003\u000bCq!a&\u0001\t#\tI\nC\u0004\u0002T\u0002!\t\"!6\t\u0013\u0005-\b!%A\u0005\u0012\u00055(A\u0004\"p_N$\u0018N\\4QCJ\fWn\u001d\u0006\u0003!E\t\u0001BY8pgRLgn\u001a\u0006\u0003%M\t!!\u001c7\u000b\u0005Q)\u0012!B:qCJ\\'B\u0001\f\u0018\u0003\u0019\t\u0007/Y2iK*\t\u0001$A\u0002pe\u001e\u001c\"\u0002\u0001\u000e!I)\u0012T\u0007O\u001e?!\tYb$D\u0001\u001d\u0015\u0005i\u0012!B:dC2\f\u0017BA\u0010\u001d\u0005\u0019\te.\u001f*fMB\u0011\u0011EI\u0007\u0002#%\u00111%\u0005\u0002\u0010!J,G-[2u_J\u0004\u0016M]1ngB\u0011Q\u0005K\u0007\u0002M)\u0011q%E\u0001\tK:\u001cX-\u001c2mK&\u0011\u0011F\n\u0002\u0013\u0011\u0006\u001ch*^7CCN,G*Z1s]\u0016\u00148\u000f\u0005\u0002,a5\tAF\u0003\u0002.]\u000511\u000f[1sK\u0012T!aL\t\u0002\u000bA\f'/Y7\n\u0005Eb#\u0001\u0004%bg^+\u0017n\u001a5u\u0007>d\u0007CA\u00164\u0013\t!DFA\u0004ICN\u001cV-\u001a3\u0011\u0005\u00152\u0014BA\u001c'\u00059A\u0015m\u001d\"bg\u0016dU-\u0019:oKJ\u0004\"aK\u001d\n\u0005ib#!\u0007%bgZ\u000bG.\u001b3bi&|g.\u00138eS\u000e\fGo\u001c:D_2\u0004\"a\u000b\u001f\n\u0005ub#A\u0002%bgR{G\u000e\u0005\u0002&\u007f%\u0011\u0001I\n\u0002\f\u0011\u0006\u001ch*^7S_VtG-\u0001\u0004%S:LG\u000fJ\u0002\u0001)\u0005!\u0005CA\u000eF\u0013\t1ED\u0001\u0003V]&$\u0018\u0001F3wC2,\u0018\r^3P]Z\u000bG.\u001b3bi&|g\u000eF\u0004JG\",x0a\u0001\u0015\u0005)k\u0005CA\u000eL\u0013\taED\u0001\u0004E_V\u0014G.\u001a\u0005\u0006\u001d\n\u0001\raT\u0001\u0003I\u001a\u0004\"\u0001\u00151\u000f\u0005EkfB\u0001*\\\u001d\t\u0019&L\u0004\u0002U3:\u0011Q\u000bW\u0007\u0002-*\u0011qKQ\u0001\u0007yI|w\u000e\u001e 
\n\u0003aI!AF\f\n\u0005Q)\u0012B\u0001/\u0014\u0003\r\u0019\u0018\u000f\\\u0005\u0003=~\u000bq\u0001]1dW\u0006<WM\u0003\u0002]'%\u0011\u0011M\u0019\u0002\n\t\u0006$\u0018M\u0012:b[\u0016T!AX0\t\u000b\u0011\u0014\u0001\u0019A3\u0002\u000f],\u0017n\u001a5ugB\u00191D\u001a&\n\u0005\u001dd\"!B!se\u0006L\b\"B5\u0003\u0001\u0004Q\u0017\u0001\u00032p_N$XM]:\u0011\u0007m17\u000e\u0005\u0002me:\u0011Q.\u001d\b\u0003]Bt!AU8\n\u0005I\u0019\u0012BA\u0014\u0012\u0013\tqf%\u0003\u0002ti\nYRI\\:f[\ndW\r\u0015:fI&\u001cG/[8o\u001b>$W\r\u001c+za\u0016T!A\u0018\u0014\t\u000bY\u0014\u0001\u0019A<\u0002\u00191\f'-\u001a7D_2t\u0015-\\3\u0011\u0005adhBA={!\t)F$\u0003\u0002|9\u00051\u0001K]3eK\u001aL!! @\u0003\rM#(/\u001b8h\u0015\tYH\u0004\u0003\u0004\u0002\u0002\t\u0001\ra^\u0001\u0010M\u0016\fG/\u001e:fg\u000e{GNT1nK\"9\u0011Q\u0001\u0002A\u0002\u0005\u001d\u0011\u0001\u00027pgN\u0004RaGA\u0005\u0015*K1!a\u0003\u001d\u0005%1UO\\2uS>t\u0017\u0007\u0006\b\u0002\u0010\u0005M\u0011QDA\u0010\u0003C\t\u0019#!\n\u0015\u0007)\u000b\t\u0002C\u0003O\u0007\u0001\u0007q\nC\u0004\u0002\u0016\r\u0001\r!a\u0006\u0002\u00159,Xn\u00117bgN,7\u000fE\u0002\u001c\u00033I1!a\u0007\u001d\u0005\rIe\u000e\u001e\u0005\u0006I\u000e\u0001\r!\u001a\u0005\u0006S\u000e\u0001\rA\u001b\u0005\u0006m\u000e\u0001\ra\u001e\u0005\u0007\u0003\u0003\u0019\u0001\u0019A<\t\u000f\u0005\u00151\u00011\u0001\u0002\b\u0005Y\u0001O]8cC\nLG.\u001b>f)!\tY#a\f\u00024\u0005]BcA(\u0002.!)a\n\u0002a\u0001\u001f\"1\u0011\u0011\u0007\u0003A\u0002]\f!CY8pgR<V-[4ii\u000e{GNT1nK\"1\u0011Q\u0007\u0003A\u0002]\f\u0011CY8pgR\u0004&o\u001c2b\u0007>dg*Y7f\u0011\u0019\tI\u0004\u0002a\u0001o\u0006\u0019\u0002o\\5tg>t\u0007K]8cC\u000e{GNT1nK\u0006iQ\u000f\u001d3bi\u0016<V-[4iiN$\"\"a\u0010\u0002D\u0005\u0015\u0013\u0011JA')\ry\u0015\u0011\t\u0005\u0006\u001d\u0016\u0001\ra\u0014\u0005\u0007\u0003c)\u0001\u0019A<\t\r\u0005\u001dS\u00011\u0001x\u0003-awn]:D_2t\u0015-\\3\t\r\u0005-S\u00011\u0001K\u0003\u0011\u0011W\r^1\t\r\u0005=S\u00011\u0001x\u0003e)\b\
u000fZ1uK\u0012\u0014un\\:u/\u0016Lw\r\u001b;D_2t\u0015-\\3\u0002\u000f\u00054x\rT8tgR1\u0011QKA-\u00037\"2ASA,\u0011\u0015qe\u00011\u0001P\u0011\u0019\t9E\u0002a\u0001o\"1\u0011Q\u0007\u0004A\u0002]$RASA0\u0003GBa!!\u0019\b\u0001\u0004Q\u0015\u0001B1wO2D\u0011\"!\u0006\b!\u0003\u0005\r!a\u0006\u0002\u001d\t,G/\u0019\u0013eK\u001a\fW\u000f\u001c;%eU\u0011\u0011\u0011\u000e\u0016\u0005\u0003/\tYg\u000b\u0002\u0002nA!\u0011qNA=\u001b\t\t\tH\u0003\u0003\u0002t\u0005U\u0014!C;oG\",7m[3e\u0015\r\t9\bH\u0001\u000bC:tw\u000e^1uS>t\u0017\u0002BA>\u0003c\u0012\u0011#\u001e8dQ\u0016\u001c7.\u001a3WCJL\u0017M\\2f\u0003\u00199X-[4iiR\u0019!*!!\t\r\u0005-\u0013\u00021\u0001K\u0003E)\u0007\u0010\u001e:bGR\u0014un\\:uK\u0012\u0014\u0015m\u001a\u000b\u0007\u0003\u000f\u000bY)!$\u0015\u0007=\u000bI\tC\u0003O\u0015\u0001\u0007q\n\u0003\u0004\u0002:)\u0001\ra\u001e\u0005\b\u0003\u001fS\u0001\u0019AAI\u0003\u0011\u0019X-\u001a3\u0011\u0007m\t\u0019*C\u0002\u0002\u0016r\u0011A\u0001T8oO\u0006aA/\u001a:nS:\fG/\u001a,bYR\u0011\u00121TAQ\u0003W\u000by+a-\u00028\u0006m\u0016qXAb!!Y\u0012QTA\f\u0015\u0006]\u0011bAAP9\t1A+\u001e9mKNBq!a)\f\u0001\u0004\t)+\u0001\bxSRDg+\u00197jI\u0006$\u0018n\u001c8\u0011\u0007m\t9+C\u0002\u0002*r\u0011qAQ8pY\u0016\fg\u000e\u0003\u0004\u0002..\u0001\rAS\u0001\u0006KJ\u0014xN\u001d\u0005\u0007\u0003c[\u0001\u0019\u0001&\u0002\rY,'O]8s\u0011\u0019\t)l\u0003a\u0001\u0015\u0006\u0019Ao\u001c7\t\u000f\u0005e6\u00021\u0001\u0002\u0018\u0005Aa.^7S_VtG\rC\u0004\u0002>.\u0001\r!a\u0006\u0002\r9,X\u000e\u0016:z\u0011\u001d\t\tm\u0003a\u0001\u0003/\tA!\u001b;fe\"9\u0011QY\u0006A\u0002\u0005\u001d\u0017aD5ogR\u0014X/\\3oi\u0006$\u0018n\u001c8\u0011\t\u0005%\u0017qZ\u0007\u0003\u0003\u0017T1!!4\u0012\u0003\u0011)H/\u001b7\n\t\u0005E\u00171\u001a\u0002\u0010\u0013:\u001cHO];nK:$\u0018\r^5p]\u0006IA/\u001a:nS:\fG/\u001a\u000b\u0017\u00037\u000b9.!7\u0002\\\u0006u\u0017q\\Aq\u0003G\f)/a:\u0002j\"1\u0011\u0011\r\u0007A\u0002)Cq!a)\r\u0001\u0004\t)\u000b\u0003\u0004\u0002.2\u00
01\rA\u0013\u0005\u0007\u0003cc\u0001\u0019\u0001&\t\r\u0005UF\u00021\u0001K\u0011\u001d\tI\f\u0004a\u0001\u0003/Aq!!0\r\u0001\u0004\t9\u0002C\u0004\u0002B2\u0001\r!a\u0006\t\u000f\u0005\u0015G\u00021\u0001\u0002H\"A\u0011Q\u0003\u0007\u0011\u0002\u0003\u0007!*\u0001\u000buKJl\u0017N\\1uK\u0012\"WMZ1vYR$\u0013\u0007M\u000b\u0003\u0003_T3ASA6\u0001")
/* loaded from: input_file:org/apache/spark/ml/boosting/BoostingParams.class */
/*
 * Shared parameters and helper routines for the boosting estimators.
 *
 * NOTE(review): this is decompiled output of the Scala trait in BoostingParams.scala (see the
 * @ScalaSignature above); changes here should be mirrored in the Scala source.
 * Defaults registered by $init$: numRound = 2, numBaseLearners = 10, tol = 1e-6.
 */
public interface BoostingParams extends PredictorParams, HasNumBaseLearners, HasWeightCol, HasSeed, HasBaseLearner, HasValidationIndicatorCol, HasTol, HasNumRound {

    /**
     * Evaluates an ensemble of regressors on a validation set.
     *
     * <p>The ensemble is wrapped in a {@link BoostingRegressionModel}. For every row, the absolute
     * residual {@code |label - prediction|} is passed through {@code function1}, and the results are
     * summed.
     *
     * @param dArr               first constructor argument of the ensemble model (presumably the
     *                           per-model weights — TODO confirm)
     * @param predictionModelArr base models wrapped by the ensemble model
     * @param str                name of the label column
     * @param str2               name of the features column used by the ensemble model
     * @param function1          per-row loss applied to the absolute residual
     * @param dataset            validation rows
     * @return the summed loss, or {@link Double#MAX_VALUE} when {@code dataset} is empty
     */
    default double evaluateOnValidation(double[] dArr, PredictionModel<Vector, ? extends PredictionModel<Vector, PredictionModel>>[] predictionModelArr, String str, String str2, Function1<Object, Object> function1, Dataset<Row> dataset) {
        // Check emptiness first so no model or UDF is built for nothing.
        if (dataset.isEmpty()) {
            return Double.MAX_VALUE;
        }
        BoostingRegressionModel model = new BoostingRegressionModel(dArr, predictionModelArr).setFeaturesCol(str2);
        UserDefinedFunction lossUdf = functions$.MODULE$.udf(function1, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Double());
        Column absoluteResidual = functions$.MODULE$.abs(
                functions$.MODULE$.col(str).$minus(functions$.MODULE$.col(model.getPredictionCol())));
        Column rowLoss = lossUdf.apply(Predef$.MODULE$.wrapRefArray(new Column[]{absoluteResidual}));
        Row total = (Row) model.transform(dataset)
                .agg(functions$.MODULE$.sum(rowLoss), Predef$.MODULE$.wrapRefArray(new Column[0]))
                .head();
        return total.getDouble(0);
    }

    /**
     * Evaluates an ensemble on a validation set using a 0/1 mismatch indicator.
     *
     * <p>For every row, the value {@code 0.0} (label equals prediction) or {@code 1.0} (otherwise) is
     * passed through {@code function1}, and the results are summed.
     *
     * <p>NOTE(review): the ensemble is wrapped in a {@link BoostingRegressionModel} even though this
     * overload compares labels for equality; confirm that its prediction is meaningful for
     * classifiers.
     *
     * @param i                  not read by this implementation (presumably the number of classes)
     * @param dArr               first constructor argument of the ensemble model
     * @param predictionModelArr base models wrapped by the ensemble model
     * @param str                name of the label column
     * @param str2               name of the features column used by the ensemble model
     * @param function1          per-row loss applied to the 0/1 mismatch indicator
     * @param dataset            validation rows
     * @return the summed loss, or {@link Double#MAX_VALUE} when {@code dataset} is empty
     */
    default double evaluateOnValidation(int i, double[] dArr, PredictionModel<Vector, ? extends PredictionModel<Vector, PredictionModel>>[] predictionModelArr, String str, String str2, Function1<Object, Object> function1, Dataset<Row> dataset) {
        // Check emptiness first so no model or UDF is built for nothing.
        if (dataset.isEmpty()) {
            return Double.MAX_VALUE;
        }
        BoostingRegressionModel model = new BoostingRegressionModel(dArr, predictionModelArr).setFeaturesCol(str2);
        UserDefinedFunction lossUdf = functions$.MODULE$.udf(function1, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Double());
        Column mismatch = functions$.MODULE$
                .when(functions$.MODULE$.col(str).$eq$eq$eq(functions$.MODULE$.col(model.getPredictionCol())), BoxesRunTime.boxToDouble(0.0d))
                .otherwise(BoxesRunTime.boxToDouble(1.0d));
        Column rowLoss = lossUdf.apply(Predef$.MODULE$.wrapRefArray(new Column[]{mismatch}));
        Row total = (Row) model.transform(dataset)
                .agg(functions$.MODULE$.sum(rowLoss), Predef$.MODULE$.wrapRefArray(new Column[0]))
                .head();
        return total.getDouble(0);
    }

    /**
     * Normalizes a weight column into a probability column and rescales it by the row count.
     *
     * <p>Adds {@code str2 = str / sum(str)} and {@code str3 = str2 * count(*)}. As a result,
     * {@code str3} sums to the number of rows, so it averages 1.0 per row.
     *
     * <p>NOTE(review): on an empty dataset {@code sum(str)} is null, and {@code getDouble(1)} is
     * expected to throw. A zero weight sum makes the division follow Spark's divide-by-zero
     * semantics.
     *
     * @param str     name of the existing weight column
     * @param str2    name of the normalized-probability column to add
     * @param str3    name of the rescaled column to add
     * @param dataset input rows
     * @return {@code dataset} with the two extra columns
     */
    default Dataset<Row> probabilize(String str, String str2, String str3, Dataset<Row> dataset) {
        // Read the row count and the weight sum from one aggregation pass. The decompiled version
        // referenced an undefined local and a synthetic tuple class here.
        Row stats = (Row) dataset.agg(
                functions$.MODULE$.count(functions$.MODULE$.lit(BoxesRunTime.boxToInteger(1))),
                Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.sum(str)})).first();
        double rowCount = (double) stats.getLong(0);
        double weightSum = stats.getDouble(1);
        return dataset
                .withColumn(str2, functions$.MODULE$.col(str).$div(BoxesRunTime.boxToDouble(weightSum)))
                .withColumn(str3, functions$.MODULE$.col(str2).$times(BoxesRunTime.boxToDouble(rowCount)));
    }

    /**
     * Multiplies each weight by {@code d} raised to {@code (1 - loss)}.
     *
     * @param str     name of the current weight column
     * @param str2    name of the per-row loss column
     * @param d       base of the exponentiation
     * @param str3    name of the column receiving {@code str * d^(1 - str2)}
     * @param dataset input rows
     * @return {@code dataset} with {@code str3} added or replaced
     */
    default Dataset<Row> updateWeights(String str, String str2, double d, String str3, Dataset<Row> dataset) {
        Column exponent = functions$.MODULE$.lit(BoxesRunTime.boxToInteger(1)).$minus(functions$.MODULE$.col(str2));
        Column factor = functions$.MODULE$.pow(functions$.MODULE$.lit(BoxesRunTime.boxToDouble(d)), exponent);
        return dataset.withColumn(str3, functions$.MODULE$.col(str).$times(factor));
    }

    /**
     * Returns {@code sum(str * str2)} over the dataset.
     *
     * <p>NOTE(review): on an empty dataset the sum is null, and {@code getDouble(0)} is expected to
     * throw.
     *
     * @param str     name of the per-row loss column
     * @param str2    name of the per-row probability column
     * @param dataset input rows
     * @return the probability-weighted loss
     */
    default double avgLoss(String str, String str2, Dataset<Row> dataset) {
        Column weighted = functions$.MODULE$.col(str).$times(functions$.MODULE$.col(str2));
        Row total = (Row) dataset.agg(functions$.MODULE$.sum(weighted), Predef$.MODULE$.wrapRefArray(new Column[0])).first();
        return total.getDouble(0);
    }

    /**
     * Computes {@code d / ((1 - d) * (i - 1))}.
     *
     * <p>With the default {@code i = 2} (see {@link #beta$default$2()}), this is {@code d / (1 - d)}.
     *
     * @param d average loss
     * @param i number of classes
     * @return the beta factor
     */
    default double beta(double d, int i) {
        return d / ((1 - d) * (i - 1));
    }

    /** Default value (2) for the second argument of {@link #beta(double, int)}. */
    default int beta$default$2() {
        return 2;
    }

    /**
     * Converts a beta factor into a model weight, {@code ln(1 / d)}.
     *
     * @param d beta factor
     * @return {@code 1.0} when {@code d == 0}, otherwise {@code ln(1 / d)}
     */
    default double weight(double d) {
        if (d == 0.0d) {
            return 1.0d;
        }
        return Math.log(1.0d / d);
    }

    /**
     * Draws a Poisson bootstrap bag from the dataset.
     *
     * <p>Each row is repeated {@code Poisson(row[str])} times. Rows whose value is exactly 0 are
     * dropped. Sampling is deterministic for a given {@code j}, partitioning and row order.
     *
     * @param str     name of the column holding each row's Poisson mean
     * @param j       base seed
     * @param dataset input rows
     * @return a DataFrame with the same schema containing the sampled rows
     */
    default Dataset<Row> extractBoostedBag(String str, long j, Dataset<Row> dataset) {
        int fieldIndex = dataset.schema().fieldIndex(str);
        return dataset.sparkSession().createDataFrame(dataset.rdd().mapPartitionsWithIndex((obj, iterator) -> {
            return $anonfun$extractBoostedBag$1(fieldIndex, j, BoxesRunTime.unboxToInt(obj), iterator);
        }, dataset.rdd().mapPartitionsWithIndex$default$2(), ClassTag$.MODULE$.apply(Row.class)), dataset.schema());
    }

    /**
     * Decides whether training continues, based on validation error.
     *
     * <p>Returns {@code (counter, error, tries)} as boxed values. The first element is
     * {@code i3 - 1} while training continues and {@code 0} when stopping.
     * NOTE(review): the meaning of each element is inferred from this method and
     * {@link #terminate}; confirm against the callers.
     *
     * @param z               whether a validation set is used
     * @param d               best validation error so far
     * @param d2              validation error of the current ensemble
     * @param d3              relative improvement tolerance
     * @param i               maximum number of rounds without improvement
     * @param i2              rounds without improvement so far
     * @param i3              current counter value
     * @param instrumentation logger for the stop message
     * @return the updated {@code (counter, error, tries)} triple
     */
    default Tuple3<Object, Object, Object> terminateVal(boolean z, double d, double d2, double d3, int i, int i2, int i3, Instrumentation instrumentation) {
        if (!z) {
            // No validation: continue, and reset the error and try counter.
            return new Tuple3<>(BoxesRunTime.boxToInteger(i3 - 1), BoxesRunTime.boxToDouble(0.0d), BoxesRunTime.boxToInteger(0));
        }
        if (d2 < d * (1.0d - d3)) {
            // Improved by more than the relative tolerance: keep the new error, reset tries.
            return new Tuple3<>(BoxesRunTime.boxToInteger(i3 - 1), BoxesRunTime.boxToDouble(d2), BoxesRunTime.boxToInteger(0));
        }
        if (i2 != i - 1) {
            // Not enough improvement, but tries remain: keep the previous best error.
            return new Tuple3<>(BoxesRunTime.boxToInteger(i3 - 1), BoxesRunTime.boxToDouble(d), BoxesRunTime.boxToInteger(i2 + 1));
        }
        instrumentation.logInfo(() -> "Stopped because new boosters don't improved validation performance more than " + d3 + " in " + i + " rounds.");
        return new Tuple3<>(BoxesRunTime.boxToInteger(0), BoxesRunTime.boxToDouble(0.0d), BoxesRunTime.boxToInteger(i2 + 1));
    }

    /**
     * Decides whether training continues after a boosting round.
     *
     * <p>Training stops when the average loss exceeds {@code (d5 - 1) / d5} or equals 0. Otherwise
     * the decision is delegated to
     * {@link #terminateVal(boolean, double, double, double, int, int, int, Instrumentation)}.
     *
     * @param d               average loss of the new booster
     * @param z               whether a validation set is used
     * @param d2              best validation error so far
     * @param d3              validation error of the current ensemble
     * @param d4              relative improvement tolerance
     * @param i               maximum number of rounds without improvement
     * @param i2              rounds without improvement so far
     * @param i3              current counter value
     * @param instrumentation logger for the stop messages
     * @param d5              number of classes (default 2.0, see {@link #terminate$default$10()})
     * @return the updated {@code (counter, error, tries)} triple
     */
    default Tuple3<Object, Object, Object> terminate(double d, boolean z, double d2, double d3, double d4, int i, int i2, int i3, Instrumentation instrumentation, double d5) {
        double lossThreshold = (d5 - 1.0d) / d5;
        if (d > lossThreshold) {
            instrumentation.logInfo(() -> "Stopped because the average loss of new booster is higher than " + lossThreshold);
            return new Tuple3<>(BoxesRunTime.boxToInteger(0), BoxesRunTime.boxToDouble(0.0d), BoxesRunTime.boxToInteger(1));
        }
        if (d != 0.0d) {
            return terminateVal(z, d2, d3, d4, i, i2, i3, instrumentation);
        }
        instrumentation.logInfo(() -> "Stopped because the average loss was " + d);
        return new Tuple3<>(BoxesRunTime.boxToInteger(0), BoxesRunTime.boxToDouble(0.0d), BoxesRunTime.boxToInteger(0));
    }

    /** Default value (2.0) for the tenth argument of {@code terminate}. */
    default double terminate$default$10() {
        return 2.0d;
    }

    /**
     * Per-partition body of {@link #extractBoostedBag}: repeats each row a Poisson number of times.
     *
     * @param i        index of the column holding the Poisson mean
     * @param j        base seed
     * @param i2       partition index
     * @param iterator rows of the partition
     * @return the sampled rows
     */
    static /* synthetic */ Iterator $anonfun$extractBoostedBag$1(int i, long j, int i2, Iterator iterator) {
        final int partitionIndex = i2;
        return iterator.zipWithIndex().flatMap(element -> {
            Tuple2<?, ?> indexed = (Tuple2<?, ?>) element;
            Row row = (Row) indexed._1();
            int rowIndex = BoxesRunTime.unboxToInt(indexed._2());
            double mean = row.getDouble(i);
            if (mean == 0.0d) {
                // An expected multiplicity of zero drops the row from the bag.
                return scala.package$.MODULE$.Iterator().empty();
            }
            PoissonDistribution poisson = new PoissonDistribution(mean);
            // Put the partition index in the high 32 bits so every (partition, row) pair gets a
            // distinct seed. A plain `seed + partition + row` collides along diagonals, e.g.
            // (0, 1) and (1, 0), which produces correlated draws when the means are equal.
            poisson.reseedRandomGenerator(j + (((long) partitionIndex) << 32) + rowIndex);
            return scala.package$.MODULE$.Iterator().fill(poisson.sample(), () -> row);
        });
    }

    /** Registers default values: numRound = 2, numBaseLearners = 10, tol = 1e-6. */
    static void $init$(BoostingParams boostingParams) {
        boostingParams.setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{boostingParams.numRound().$minus$greater(BoxesRunTime.boxToInteger(2))}));
        boostingParams.setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{boostingParams.numBaseLearners().$minus$greater(BoxesRunTime.boxToInteger(10))}));
        boostingParams.setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{boostingParams.tol().$minus$greater(BoxesRunTime.boxToDouble(1.0E-6d))}));
    }
}
