package org.apache.wayang.spark.operators;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.IntUnaryOperator;
import java.util.function.LongUnaryOperator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.wayang.basic.operators.SampleOperator;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.platform.ChannelDescriptor;
import org.apache.wayang.core.platform.ChannelInstance;
import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
import org.apache.wayang.core.types.DataSetType;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.java.channels.CollectionChannel;
import org.apache.wayang.spark.channels.RddChannel;
import org.apache.wayang.spark.execution.SparkExecutor;
import scala.collection.JavaConversions;
import scala.reflect.ClassTag$;

/* loaded from: input_file:org/apache/wayang/spark/operators/SparkRandomPartitionSampleOperator.class */
public class SparkRandomPartitionSampleOperator<Type> extends SampleOperator<Type> implements SparkExecutionOperator {
    private Random rand;
    private int nb_partitions;
    private int partitionSize;
    private boolean first;
    static final /* synthetic */ boolean $assertionsDisabled;

    public SparkRandomPartitionSampleOperator(IntUnaryOperator intUnaryOperator, DataSetType<Type> dataSetType, LongUnaryOperator longUnaryOperator) {
        super(intUnaryOperator, dataSetType, SampleOperator.Methods.RANDOM, longUnaryOperator);
        this.nb_partitions = 0;
        this.partitionSize = 0;
        this.first = true;
    }

    public SparkRandomPartitionSampleOperator(SampleOperator<Type> sampleOperator) {
        super(sampleOperator);
        this.nb_partitions = 0;
        this.partitionSize = 0;
        this.first = true;
        if (!$assertionsDisabled && sampleOperator.getSampleMethod() != SampleOperator.Methods.RANDOM && sampleOperator.getSampleMethod() != SampleOperator.Methods.ANY) {
            throw new AssertionError();
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v95, types: [java.util.List[]] */
    /* JADX WARN: Type inference failed for: r0v96 */
    @Override // org.apache.wayang.spark.operators.SparkExecutionOperator
    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(ChannelInstance[] channelInstanceArr, ChannelInstance[] channelInstanceArr2, SparkExecutor sparkExecutor, OptimizationContext.OperatorContext operatorContext) {
        ArrayList arrayList;
        if (!$assertionsDisabled && channelInstanceArr.length != getNumInputs()) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && channelInstanceArr2.length != getNumOutputs()) {
            throw new AssertionError();
        }
        JavaRDD provideRdd = ((RddChannel.Instance) channelInstanceArr[0]).provideRdd();
        long datasetSize = isDataSetSizeKnown() ? getDatasetSize() : provideRdd.cache().count();
        int sampleSize = getSampleSize(operatorContext);
        if (sampleSize >= datasetSize) {
            ((CollectionChannel.Instance) channelInstanceArr2[0]).accept(provideRdd.collect());
            return ExecutionOperator.modelEagerExecution(channelInstanceArr, channelInstanceArr2, operatorContext);
        }
        this.rand = new Random(getSeed(operatorContext));
        SparkContext context = provideRdd.context();
        if (this.first) {
            this.nb_partitions = provideRdd.partitions().size();
            this.partitionSize = (int) Math.ceil(datasetSize / this.nb_partitions);
            this.first = false;
        }
        if (sampleSize == 1) {
            int nextInt = this.rand.nextInt(this.nb_partitions);
            int nextInt2 = this.rand.nextInt(this.partitionSize);
            arrayList = ((List[]) context.runJob(provideRdd.rdd(), new PartitionSampleFunction(nextInt2, nextInt2 + sampleSize), JavaConversions.asScalaBuffer(Collections.singletonList(Integer.valueOf(nextInt))), ClassTag$.MODULE$.apply(List.class)))[0];
        } else {
            HashMap hashMap = new HashMap();
            for (int i = 0; i < sampleSize; i++) {
                int nextInt3 = this.rand.nextInt(this.nb_partitions);
                int nextInt4 = this.rand.nextInt(this.partitionSize);
                ArrayList arrayList2 = (ArrayList) hashMap.get(Integer.valueOf(nextInt3));
                if (arrayList2 == null) {
                    ArrayList arrayList3 = new ArrayList();
                    arrayList3.add(Integer.valueOf(nextInt4));
                    hashMap.put(Integer.valueOf(nextInt3), arrayList3);
                } else {
                    arrayList2.add(Integer.valueOf(nextInt4));
                }
            }
            ArrayList arrayList4 = new ArrayList();
            ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(hashMap.size());
            Iterator it = hashMap.keySet().iterator();
            ArrayList arrayList5 = new ArrayList(hashMap.size());
            while (it.hasNext()) {
                int intValue = ((Integer) it.next()).intValue();
                List singletonList = Collections.singletonList(Integer.valueOf(intValue));
                ArrayList arrayList6 = (ArrayList) hashMap.get(Integer.valueOf(intValue));
                Collections.sort(arrayList6);
                arrayList5.add(newFixedThreadPool.submit(() -> {
                    return context.runJob(provideRdd.rdd(), new PartitionSampleListFunction(arrayList6), JavaConversions.asScalaBuffer(singletonList), ClassTag$.MODULE$.apply(List.class));
                }));
            }
            for (int i2 = 0; i2 < hashMap.size(); i2++) {
                try {
                    arrayList4.addAll(((List[]) ((Future) arrayList5.get(i2)).get())[0]);
                } catch (InterruptedException e) {
                    this.logger.error("Random partition sampling failed due to threads.", e);
                } catch (ExecutionException e2) {
                    throw new WayangException("Random partition sampling failed.", e2);
                }
            }
            newFixedThreadPool.shutdown();
            arrayList = arrayList4;
        }
        ((CollectionChannel.Instance) channelInstanceArr2[0]).accept(arrayList);
        return ExecutionOperator.modelEagerExecution(channelInstanceArr, channelInstanceArr2, operatorContext);
    }

    protected ExecutionOperator createCopy() {
        return new SparkRandomPartitionSampleOperator(this);
    }

    public List<ChannelDescriptor> getSupportedInputChannels(int i) {
        if ($assertionsDisabled || i <= getNumInputs() || (i == 0 && getNumInputs() == 0)) {
            return isDataSetSizeKnown() ? Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR) : Collections.singletonList(RddChannel.CACHED_DESCRIPTOR);
        }
        throw new AssertionError();
    }

    public List<ChannelDescriptor> getSupportedOutputChannels(int i) {
        if ($assertionsDisabled || i <= getNumOutputs() || (i == 0 && getNumOutputs() == 0)) {
            return Collections.singletonList(CollectionChannel.DESCRIPTOR);
        }
        throw new AssertionError();
    }

    public String getLoadProfileEstimatorConfigurationKey() {
        return "wayang.spark.random-partition-sample.load";
    }

    @Override // org.apache.wayang.spark.operators.SparkExecutionOperator
    public boolean containsAction() {
        return true;
    }

    static {
        $assertionsDisabled = !SparkRandomPartitionSampleOperator.class.desiredAssertionStatus();
    }
}
