package org.apache.spark.ml.bagging;

import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.spark.SparkException;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.util.BaggingMetadataUtils$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.bfunctions$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.util.random.XORShiftRandom;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Some;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.util.Random;
import scala.util.Random$;

/* compiled from: Bagging.scala */
/* loaded from: input_file:org/apache/spark/ml/bagging/Bagging$.class */
/**
 * Sampling helpers for building bagged training sets.
 *
 * <p>Java view of the Scala object {@code org.apache.spark.ml.bagging.Bagging}. The singleton
 * instance is {@link #MODULE$}. The helpers fall into three groups:
 * <ul>
 *   <li>Row sampling expressed as Spark {@link Column}s of per-row replica counts
 *       ({@link #weightBag}, {@link #duplicateRow}, {@link #withWeightedBag},
 *       {@link #withSampledRows}).</li>
 *   <li>Local sampling of in-memory collections ({@link #arraySample},
 *       {@link #arrayIndicesSample}).</li>
 *   <li>Feature sub-setting of a vector column ({@link #withSampledFeatures}) and
 *       feature-count discovery ({@link #getNumFeatures}).</li>
 * </ul>
 */
public final class Bagging$ {
    /** Scala singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static Bagging$ MODULE$;

    static {
        new Bagging$();
    }

    /**
     * Builds an array column holding one replica count per bag for every row.
     *
     * <p>With replacement, bag {@code k} uses {@code bfunctions.poisson(d, j + k)}. That is presumably
     * a seeded Poisson(d) draw; confirm against {@code bfunctions}. Without replacement and
     * {@code d == 1}, every row gets count 1 in every bag. Otherwise bag {@code k} gets 1 when
     * {@code rand(j + k) < d} and 0 otherwise.
     *
     * @param z whether to sample with replacement
     * @param d sampling ratio; must be {@code > 0}, and {@code <= 1} when {@code z} is false
     * @param i number of bags, which is the length of the produced array
     * @param j base seed; bag {@code k} is seeded with {@code j + k}
     * @return an array column of per-bag counts
     * @throws IllegalArgumentException if {@code d} is out of range
     */
    public Column weightBag(boolean z, double d, int i, long j) {
        Predef$.MODULE$.require(d > ((double) 0), () -> {
            return "sampleRatio must be strictly positive";
        });
        if (z) {
            return functions$.MODULE$.array((Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i).map(obj -> {
                return $anonfun$weightBag$2(d, j, BoxesRunTime.unboxToInt(obj));
            }, IndexedSeq$.MODULE$.canBuildFrom()));
        }
        if (d == 1) {
            return functions$.MODULE$.array_repeat(functions$.MODULE$.lit(BoxesRunTime.boxToInteger(1)), i);
        }
        Predef$.MODULE$.require(d <= ((double) 1), () -> {
            return "Without replacement, the sampleRatio cannot be greater to one";
        });
        return functions$.MODULE$.array((Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i).map(obj2 -> {
            return $anonfun$weightBag$4(j, d, BoxesRunTime.unboxToInt(obj2));
        }, IndexedSeq$.MODULE$.canBuildFrom()));
    }

    /**
     * Returns a generator column that emits one output row per unit of {@code column}.
     *
     * <p>The value is cast to int and expanded with {@code explode(array_repeat(1, count))}, so a
     * count of 0 produces no output rows for that input row.
     *
     * @param column per-row replica count
     * @return an {@code explode} column suitable for {@code withColumn}/{@code select}
     */
    public Column duplicateRow(Column column) {
        return functions$.MODULE$.explode(functions$.MODULE$.array_repeat(functions$.MODULE$.lit(BoxesRunTime.boxToInteger(1)), column.cast(IntegerType$.MODULE$)));
    }

    /**
     * Samples a local sequence of doubles.
     *
     * <p>With replacement, each element is repeated as many times as a Poisson({@code d}) draw
     * seeded with {@code j}. This is a bootstrap, consistent with how {@link #weightBag} and
     * {@link #duplicateRow} replicate rows. Without replacement, {@code d == 1} returns
     * {@code seq} unchanged. Otherwise each element is kept independently with probability
     * {@code d}, using an {@link XORShiftRandom} seeded with {@code j}.
     *
     * @param z whether to sample with replacement
     * @param d sampling ratio; must be {@code > 0}, and {@code <= 1} when {@code z} is false
     * @param j random seed
     * @param seq elements (boxed doubles) to sample from
     * @return the sampled elements, in their original relative order
     * @throws IllegalArgumentException if {@code d} is out of range
     */
    public Seq<Object> arraySample(boolean z, double d, long j, Seq<Object> seq) {
        // Same validation (and messages) as weightBag, so both sampling paths agree.
        Predef$.MODULE$.require(d > ((double) 0), () -> {
            return "sampleRatio must be strictly positive";
        });
        if (z) {
            PoissonDistribution poissonDistribution = new PoissonDistribution(d);
            poissonDistribution.reseedRandomGenerator(j);
            return (Seq) seq.flatMap(obj -> {
                return $anonfun$arraySample$1(poissonDistribution, BoxesRunTime.unboxToDouble(obj));
            }, Seq$.MODULE$.canBuildFrom());
        }
        if (d == 1) {
            return seq;
        }
        Predef$.MODULE$.require(d <= ((double) 1), () -> {
            return "Without replacement, the sampleRatio cannot be greater to one";
        });
        XORShiftRandom xORShiftRandom = new XORShiftRandom(j);
        return (Seq) seq.flatMap(obj2 -> {
            return $anonfun$arraySample$2(xORShiftRandom, d, BoxesRunTime.unboxToDouble(obj2));
        }, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Samples up to {@code i} entries of {@code iArr}, deterministically for a given seed.
     *
     * <p>With replacement, {@code min(i, iArr.length)} entries are drawn uniformly. Duplicates are
     * then removed, keeping first occurrences, so the result may be shorter and is not sorted.
     * Without replacement, {@code min(i, iArr.length)} distinct positions are chosen by a seeded
     * shuffle, and the selected entries are returned in ascending order.
     *
     * <p>Entries of {@code iArr} are returned, not positions in it. For the usual input
     * {@code 0..n-1} the two are the same.
     *
     * @param z whether to sample with replacement
     * @param i maximum number of entries to draw; a non-positive value yields an empty array
     * @param j random seed; it is used by both branches
     * @param iArr candidate entries (typically feature indices)
     * @return the sampled entries
     */
    public int[] arrayIndicesSample(boolean z, int i, long j, int[] iArr) {
        // Clamp like Scala's Array.fill/take, which treat a negative count as zero.
        int take = Math.max(0, Math.min(i, iArr.length));
        Random random = new Random(j);
        if (z) {
            int[] drawn = new int[take];
            for (int k = 0; k < take; k++) {
                drawn[k] = iArr[random.nextInt(iArr.length)];
            }
            // Sequential IntStream.distinct keeps first occurrences, like Scala's distinct.
            return java.util.Arrays.stream(drawn).distinct().toArray();
        }
        if (take == iArr.length) {
            int[] all = iArr.clone();
            java.util.Arrays.sort(all);
            return all;
        }
        // Seeded Fisher-Yates shuffle (same swap order as scala.util.Random#shuffle). The previous
        // version used the global unseeded Random here, which made the seed ineffective.
        int[] shuffled = iArr.clone();
        for (int n = shuffled.length; n >= 2; n--) {
            int k = random.nextInt(n);
            int tmp = shuffled[n - 1];
            shuffled[n - 1] = shuffled[k];
            shuffled[k] = tmp;
        }
        int[] chosen = java.util.Arrays.copyOf(shuffled, take);
        java.util.Arrays.sort(chosen);
        return chosen;
    }

    /**
     * Adds column {@code str} holding {@link #weightBag}{@code (z, d, i, j)} to {@code dataset}.
     *
     * @return {@code dataset} with the per-bag count array column added or replaced
     */
    public Dataset<Row> withWeightedBag(boolean z, double d, int i, long j, String str, Dataset<Row> dataset) {
        return dataset.withColumn(str, weightBag(z, d, i, j));
    }

    /**
     * Replicates each row as many times as element {@code i} of array column {@code str}.
     *
     * <p>A temporary generator column is added and then dropped. Its name is chosen so that it
     * does not clash with, and therefore does not destroy, an existing column.
     *
     * @param str name of an array column of per-bag counts
     * @param i bag index to read from that array
     * @param dataset input rows
     * @return the resampled rows with the original columns
     */
    public Dataset<Row> withSampledRows(String str, int i, Dataset<Row> dataset) {
        String tmp = freshColumnName(dataset, "dummy");
        return dataset.withColumn(tmp, duplicateRow(functions$.MODULE$.col(str).apply(BoxesRunTime.boxToInteger(i)))).drop(functions$.MODULE$.col(tmp));
    }

    /**
     * Replicates each row as many times as the value of count column {@code str}.
     *
     * <p>A temporary generator column is added and then dropped. Its name is chosen so that it
     * does not clash with, and therefore does not destroy, an existing column.
     *
     * @param str name of a numeric column of replica counts
     * @param dataset input rows
     * @return the resampled rows with the original columns
     */
    public Dataset<Row> withSampledRows(String str, Dataset<Row> dataset) {
        String tmp = freshColumnName(dataset, "dummy");
        return dataset.withColumn(tmp, duplicateRow(functions$.MODULE$.col(str))).drop(functions$.MODULE$.col(tmp));
    }

    /**
     * Returns {@code base}, or {@code base_0}, {@code base_1}, ..., whichever is first absent
     * from {@code dataset}'s columns. The comparison ignores case because Spark column
     * resolution is case-insensitive by default.
     */
    private static String freshColumnName(Dataset<Row> dataset, String base) {
        java.util.Set<String> taken = new java.util.HashSet<>();
        for (String existing : dataset.columns()) {
            taken.add(existing.toLowerCase(java.util.Locale.ROOT));
        }
        String candidate = base;
        for (int suffix = 0; taken.contains(candidate.toLowerCase(java.util.Locale.ROOT)); suffix++) {
            candidate = base + "_" + suffix;
        }
        return candidate;
    }

    /**
     * Replaces vector column {@code str} with the sub-vector at positions {@code iArr}.
     *
     * <p>Dense vectors are rebuilt from {@code vector(iArr[k])}. Sparse vectors use
     * {@code SparseVector.slice(iArr)}. Any other vector type raises {@link MatchError}.
     *
     * @param str name of the vector column to rewrite in place
     * @param iArr feature positions to keep, in output order
     * @param dataset input rows
     * @return {@code dataset} with column {@code str} replaced by the sliced vectors
     */
    public Dataset<Row> withSampledFeatures(String str, int[] iArr, Dataset<Row> dataset) {
        return dataset.withColumn(str, functions$.MODULE$.udf(vector -> {
            Vector slice;
            if (vector instanceof DenseVector) {
                DenseVector denseVector = (DenseVector) vector;
                slice = Vectors$.MODULE$.dense((double[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(iArr)).map(i -> {
                    return denseVector.apply(i);
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())));
            } else {
                if (!(vector instanceof SparseVector)) {
                    throw new MatchError(vector);
                }
                slice = ((SparseVector) vector).slice(iArr);
            }
            return slice;
        }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.bagging.Bagging$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.bagging.Bagging$$typecreator2$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str)})));
    }

    /**
     * Determines the number of features in column {@code str}.
     *
     * <p>The value from {@code BaggingMetadataUtils.getNumFeatures(schema field)} is used when it
     * is defined. That value presumably comes from ML attribute metadata; confirm in
     * {@code BaggingMetadataUtils}. Otherwise this runs a Spark job that reads
     * {@code size(str)} of the first row.
     * NOTE(review): {@code functions.size} accepts array/map columns; confirm it resolves for the
     * column types passed here.
     *
     * @param dataset input data
     * @param str feature column name
     * @return the number of features
     * @throws SparkException if the fallback path finds no rows or a null size
     */
    public int getNumFeatures(Dataset<?> dataset, String str) {
        // getNumFeatures returns a scala.Option; hold it as Object and narrow explicitly.
        Object numFeatures = BaggingMetadataUtils$.MODULE$.getNumFeatures(dataset.schema().apply(str));
        if (numFeatures instanceof Some) {
            return BoxesRunTime.unboxToInt(((Some<?>) numFeatures).value());
        }
        if (!None$.MODULE$.equals(numFeatures)) {
            throw new MatchError(numFeatures);
        }
        Row[] rowArr = (Row[]) dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.size(functions$.MODULE$.col(str))})).take(1);
        if (rowArr.length == 0 || rowArr[0].get(0) == null) {
            throw new SparkException("ML algorithm was given empty dataset.");
        }
        return rowArr[0].getInt(0);
    }

    /** Per-bag count column for sampling with replacement; bag {@code i} is seeded with {@code j + i}. */
    public static final /* synthetic */ Column $anonfun$weightBag$2(double d, long j, int i) {
        return bfunctions$.MODULE$.poisson(d, j + i);
    }

    /** Per-bag 0/1 inclusion column for sampling without replacement: {@code if(rand(j+i) < d, 1, 0)}. */
    public static final /* synthetic */ Column $anonfun$weightBag$4(long j, double d, int i) {
        return functions$.MODULE$.expr(new StringBuilder(16).append("if(rand(").append(j).append("+").append(i).append(")<").append(d).append(",1,0)").toString());
    }

    /**
     * Bootstrap step for {@link #arraySample}: returns {@code d} repeated as many times as one
     * draw from {@code poissonDistribution}. Zero copies is possible.
     */
    public static final /* synthetic */ Seq $anonfun$arraySample$1(PoissonDistribution poissonDistribution, double d) {
        double[] replicas = new double[poissonDistribution.sample()];
        java.util.Arrays.fill(replicas, d);
        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapDoubleArray(replicas));
    }

    /** Bernoulli step for {@link #arraySample}: keeps {@code d2} when the next uniform draw is below {@code d}. */
    public static final /* synthetic */ Seq $anonfun$arraySample$2(XORShiftRandom xORShiftRandom, double d, double d2) {
        return xORShiftRandom.nextDouble() < d ? Seq$.MODULE$.apply(Predef$.MODULE$.wrapDoubleArray(new double[]{d2})) : Seq$.MODULE$.empty();
    }

    private Bagging$() {
        MODULE$ = this;
    }
}
