package org.apache.hudi.integ.testsuite.generator;

import java.io.IOException;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.StreamSupport;
import org.apache.avro.generic.GenericRecord;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.integ.testsuite.configuration.DFSDeltaConfig;
import org.apache.hudi.integ.testsuite.configuration.DeltaConfig;
import org.apache.hudi.integ.testsuite.converter.UpdateConverter;
import org.apache.hudi.integ.testsuite.reader.DFSAvroDeltaInputReader;
import org.apache.hudi.integ.testsuite.reader.DFSHoodieDatasetInputReader;
import org.apache.hudi.integ.testsuite.writer.DeltaOutputMode;
import org.apache.hudi.integ.testsuite.writer.DeltaWriteStats;
import org.apache.hudi.integ.testsuite.writer.DeltaWriterFactory;
import org.apache.hudi.keygen.BuiltinKeyGenerator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.storage.StorageLevel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/apache/hudi/integ/testsuite/generator/DeltaGenerator.class */
/**
 * Generates input deltas (inserts and updates) for the Hudi integration test suite and writes them
 * through the delta writer configured by {@link DeltaConfig}.
 *
 * <p>The Spark handles ({@link JavaSparkContext}, {@link SparkSession}) are {@code transient}, so they
 * are not available on a deserialized copy of this object. The methods below copy what their Spark
 * closures need into local variables instead of capturing {@code this}.
 */
public class DeltaGenerator implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(DeltaGenerator.class);
    private final DeltaConfig deltaOutputConfig;
    private final transient JavaSparkContext jsc;
    private final transient SparkSession sparkSession;
    private final String schemaStr;
    private final List<String> recordRowKeyFieldNames;
    private final List<String> partitionPathFieldNames;
    /** Id handed to the delta writer; incremented once per {@link #writeRecords} call. */
    private int batchId;

    /**
     * Creates a generator.
     *
     * @param deltaConfig         output configuration used for writing and for locating existing data
     * @param javaSparkContext    Spark context used to build RDDs on the driver
     * @param sparkSession        Spark session used by the Avro delta input reader
     * @param str                 Avro schema string of the generated records
     * @param builtinKeyGenerator key generator supplying the record-key and partition-path field names
     */
    public DeltaGenerator(DeltaConfig deltaConfig, JavaSparkContext javaSparkContext, SparkSession sparkSession, String str, BuiltinKeyGenerator builtinKeyGenerator) {
        this.deltaOutputConfig = deltaConfig;
        this.jsc = javaSparkContext;
        this.sparkSession = sparkSession;
        this.schemaStr = str;
        this.recordRowKeyFieldNames = builtinKeyGenerator.getRecordKeyFields();
        this.partitionPathFieldNames = builtinKeyGenerator.getPartitionPathFields();
    }

    /**
     * Lazily writes every partition of {@code javaRDD} through a delta writer adapter obtained from
     * {@link DeltaWriterFactory}, then increments the batch id.
     *
     * @param javaRDD records to write
     * @return a lazy RDD of the write stats produced for each partition
     * @throws UncheckedIOException (when the returned RDD is evaluated) if a writer fails with an IOException
     */
    public JavaRDD<DeltaWriteStats> writeRecords(JavaRDD<GenericRecord> javaRDD) {
        // Snapshot the batch id (and config) now. A closure reading this.batchId would be serialized at
        // job-submission time, after the increment below, and would write with the wrong batch id.
        final int currentBatchId = this.batchId;
        final DeltaConfig writerConfig = this.deltaOutputConfig;
        JavaRDD<DeltaWriteStats> writeStats = javaRDD.mapPartitions(records -> {
            try {
                return Collections.singletonList(
                        DeltaWriterFactory.getDeltaWriterAdapter(writerConfig, currentBatchId).write(records)).iterator();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }).flatMap(List::iterator);
        this.batchId++;
        return writeStats;
    }

    /**
     * Builds an RDD of generated insert records.
     *
     * <p>An empty RDD is repartitioned into {@code config.getNumInsertPartitions()} partitions. Each
     * partition then lazily generates records via {@link FlexibleSchemaRecordGenerationIterator}.
     *
     * <p>NOTE(review): every partition is given {@code config.getNumRecordsInsert()} as its count. If that
     * count is meant to be the total, the output is multiplied by the partition count. Confirm against
     * {@code FlexibleSchemaRecordGenerationIterator}.
     *
     * @param config node configuration supplying insert count, record size and partition count
     * @return lazily generated insert records
     */
    public JavaRDD<GenericRecord> generateInserts(DeltaConfig.Config config) {
        final long numRecordsToInsert = config.getNumRecordsInsert();
        final int recordSize = config.getRecordSize();
        // Copy fields into locals so the closure does not capture (and serialize) the whole generator.
        final String schema = this.schemaStr;
        final List<String> partitionFields = this.partitionPathFieldNames;
        return jsc.parallelize(Collections.<Integer>emptyList())
                .repartition(config.getNumInsertPartitions())
                .mapPartitions(ignored -> new LazyRecordGeneratorIterator(
                        new FlexibleSchemaRecordGenerationIterator(numRecordsToInsert, recordSize, schema, partitionFields)));
    }

    /**
     * Builds an RDD of update records, optionally unioned with freshly generated inserts.
     *
     * <p>Source records for the updates are read in one of two ways:
     * <ul>
     *   <li>If {@code numUpsertPartitions < 1}, they are read from the delta base path with
     *       {@link DFSAvroDeltaInputReader} and adjusted to exactly {@code numRecordsUpsert} records.</li>
     *   <li>Otherwise they are read from the dataset output path with {@link DFSHoodieDatasetInputReader}.
     *       The reader is driven by the per-file fraction when that fraction is positive, and by the record
     *       count otherwise.</li>
     * </ul>
     * The source records are repartitioned to the default parallelism, converted by {@link UpdateConverter},
     * and persisted with {@code DISK_ONLY}.
     *
     * @param config node configuration supplying insert/upsert counts and layout
     * @return updates, or inserts unioned with updates when {@code numRecordsInsert > 0}
     * @throws IllegalArgumentException if the delta output mode is not DFS
     * @throws IOException if reading existing records fails
     */
    public JavaRDD<GenericRecord> generateUpdates(DeltaConfig.Config config) throws IOException {
        if (this.deltaOutputConfig.getDeltaOutputMode() != DeltaOutputMode.DFS) {
            throw new IllegalArgumentException("Other formats are not supported at the moment");
        }
        JavaRDD<GenericRecord> inserts = config.getNumRecordsInsert() > 0 ? generateInserts(config) : null;
        // Safe: the output mode was verified to be DFS above.
        DFSDeltaConfig dfsConfig = (DFSDeltaConfig) this.deltaOutputConfig;

        JavaRDD<GenericRecord> existingRecords;
        if (config.getNumUpsertPartitions() < 1) {
            DFSAvroDeltaInputReader deltaInputReader = new DFSAvroDeltaInputReader(this.sparkSession, this.schemaStr,
                    dfsConfig.getDeltaBasePath(), Option.empty(), Option.empty());
            existingRecords = adjustRDDToGenerateExactNumUpdates(deltaInputReader.read(config.getNumRecordsUpsert()),
                    this.jsc, config.getNumRecordsUpsert());
        } else {
            DFSHoodieDatasetInputReader datasetInputReader =
                    new DFSHoodieDatasetInputReader(this.jsc, dfsConfig.getDatasetOutputPath(), this.schemaStr);
            existingRecords = config.getFractionUpsertPerFile() > 0.0d
                    ? datasetInputReader.read(config.getNumUpsertPartitions(), config.getNumUpsertFiles(), config.getFractionUpsertPerFile())
                    : datasetInputReader.read(config.getNumUpsertPartitions(), config.getNumUpsertFiles(), config.getNumRecordsUpsert());
        }

        log.info("Repartitioning records");
        JavaRDD<GenericRecord> repartitioned = existingRecords.repartition(this.jsc.defaultParallelism().intValue());
        log.info("Repartitioning records done");
        JavaRDD<GenericRecord> updates = new UpdateConverter(this.schemaStr, config.getRecordSize(),
                this.partitionPathFieldNames, this.recordRowKeyFieldNames).convert(repartitioned);
        log.info("Records converted");
        updates.persist(StorageLevel.DISK_ONLY());
        return inserts != null ? inserts.union(updates) : updates;
    }

    /**
     * Counts the records in each partition. Runs a Spark job.
     *
     * @param javaRDD records to count
     * @return map from partition index to the number of records in that partition
     */
    public Map<Integer, Long> getPartitionToCountMap(JavaRDD<GenericRecord> javaRDD) {
        return javaRDD.mapPartitionsWithIndex((partitionIndex, records) -> {
            // The iterator is single-use and sequential, so a plain loop is the cheapest way to count it.
            long count = 0L;
            while (records.hasNext()) {
                records.next();
                count++;
            }
            Tuple2<Integer, Long> partitionCount = new Tuple2<>(partitionIndex, count);
            return Collections.singletonList(partitionCount).iterator();
        }, true).mapToPair(pair -> pair).collectAsMap();
    }

    /**
     * Computes how many records each partition keeps after removing {@code j} records in total.
     *
     * <p>Partitions are visited in {@code map}'s iteration order. Each visited partition gives up as many
     * records as are still needed. Iteration stops once {@code j} records have been accounted for, and
     * partitions never visited are absent from the result, meaning they keep all their records.
     *
     * @param map partition index to record count
     * @param j   total number of records to remove
     * @return partition index to number of records to keep, for the visited partitions only
     */
    public Map<Integer, Long> getAdjustedPartitionsCount(Map<Integer, Long> map, long j) {
        long remainingToRemove = j;
        Map<Integer, Long> recordsToKeep = new HashMap<>();
        for (Map.Entry<Integer, Long> entry : map.entrySet()) {
            long available = entry.getValue();
            if (available < remainingToRemove) {
                // This partition is dropped entirely and the rest is taken from later partitions.
                remainingToRemove -= available;
                recordsToKeep.put(entry.getKey(), 0L);
            } else {
                recordsToKeep.put(entry.getKey(), available - remainingToRemove);
                remainingToRemove = 0;
            }
            if (remainingToRemove == 0) {
                break;
            }
        }
        return recordsToKeep;
    }

    /**
     * Returns an RDD holding exactly {@code j} records derived from {@code javaRDD}.
     *
     * <p>If there are too many records, whole partitions or partition tails are dropped. If there are
     * too few, records are duplicated until the count matches. An empty input cannot be grown; it is
     * returned unchanged and a warning is logged.
     *
     * @param javaRDD          source records
     * @param javaSparkContext context used to parallelize records collected on the driver
     * @param j                required number of records
     * @return RDD with exactly {@code j} records (or the empty input unchanged)
     */
    public JavaRDD<GenericRecord> adjustRDDToGenerateExactNumUpdates(JavaRDD<GenericRecord> javaRDD, JavaSparkContext javaSparkContext, long j) {
        Map<Integer, Long> partitionToCount = getPartitionToCountMap(javaRDD);
        long available = partitionToCount.values().stream().mapToLong(Long::longValue).sum();
        if (isSafeToTake(j, available)) {
            return growToExactCount(javaRDD, javaSparkContext, available, j);
        }
        if (j == available) {
            return javaRDD;
        }
        return trimToExactCount(javaRDD, partitionToCount, available - j);
    }

    /** Grows {@code records}, which currently hold {@code available} records, to exactly {@code required} records. */
    private JavaRDD<GenericRecord> growToExactCount(JavaRDD<GenericRecord> records, JavaSparkContext javaSparkContext,
            long available, long required) {
        if (available == 0) {
            // Without this guard the doubling loop below never progresses (0 * 2 == 0) and spins forever.
            log.warn("Cannot generate {} updates from an empty input; returning it unchanged", required);
            return records;
        }
        JavaRDD<GenericRecord> grown = records;
        long size = available;
        while (size != required) {
            long missing = required - size;
            if (missing > size) {
                // Still less than half way: double the RDD by unioning it with itself.
                grown = grown.union(grown);
                size *= 2;
            } else {
                // Close the remaining gap with records collected on the driver (at most `size` of them).
                grown = grown.union(javaSparkContext.parallelize(grown.take(Math.toIntExact(missing))));
                size += missing;
            }
        }
        return grown;
    }

    /** Drops {@code excess} records from {@code records}, using the per-partition counts in {@code partitionToCount}. */
    private JavaRDD<GenericRecord> trimToExactCount(JavaRDD<GenericRecord> records, Map<Integer, Long> partitionToCount,
            long excess) {
        final Map<Integer, Long> keepPerPartition = getAdjustedPartitionsCount(partitionToCount, excess);
        return records.mapPartitionsWithIndex((partitionIndex, partitionRecords) -> {
            Long toKeep = keepPerPartition.get(partitionIndex);
            if (toKeep == null) {
                // Partition was not touched by the adjustment: keep everything.
                return partitionRecords;
            }
            List<GenericRecord> kept = new ArrayList<>();
            for (long taken = 0; taken < toKeep && partitionRecords.hasNext(); taken++) {
                kept.add(partitionRecords.next());
            }
            return kept.iterator();
        }, true);
    }

    /** Returns true when more records are required ({@code j}) than are available ({@code j2}). */
    private boolean isSafeToTake(long j, long j2) {
        return j > j2;
    }
}
