package org.apache.beam.runners.spark.translation.streaming;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Uninterruptibles;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkPipelineResult;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.MetricMatchers;
import org.apache.beam.sdk.metrics.MetricNameFilter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.common.serialization.StringSerializer;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.TestName;

/**
 * Tests that a Spark streaming pipeline reading from an embedded Kafka topic resumes from its
 * checkpoint.
 *
 * <p>The pipeline is run twice against the same checkpoint directory. The second run is expected to:
 * <ul>
 *   <li>continue the aggregator and metric values of the first run;</li>
 *   <li>emit a single fixed window containing the records produced across both runs.</li>
 * </ul>
 */
public class ResumeFromCheckpointStreamingTest {
    private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER = new EmbeddedKafkaCluster.EmbeddedZookeeper();
    private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER = new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(), new Properties());
    private static final String TOPIC = "kafka_beam_test_topic";

    @Rule
    public TemporaryFolder tmpFolder = new TemporaryFolder();

    @Rule
    public ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();

    @Rule
    public transient TestName testName = new TestName();

    /**
     * Passes through every element except the {@code "EOF"} marker.
     *
     * <p>Every element increments the {@code allMessages} counter, including {@code "EOF"}. Only
     * non-EOF elements increment the {@code processedMessages} aggregator. Each element also checks
     * that the side input still holds {@code "side1"} and {@code "side2"}, i.e. that the side input
     * survives resuming from the checkpoint.
     */
    public static class EOFShallNotPassFn extends DoFn<String, String> {
        private final PCollectionView<List<String>> view;
        private final Aggregator<Long, Long> aggregator =
                createAggregator("processedMessages", Sum.ofLongs());
        private final Counter counter =
                Metrics.counter(ResumeFromCheckpointStreamingTest.class, "allMessages");

        private EOFShallNotPassFn(PCollectionView<List<String>> view) {
            this.view = view;
        }

        @DoFn.ProcessElement
        public void process(DoFn<String, String>.ProcessContext c) {
            String element = c.element();
            Assert.assertThat(c.sideInput(view), Matchers.containsInAnyOrder("side1", "side2"));
            counter.inc();
            if (!element.equals("EOF")) {
                aggregator.addValue(1L);
                c.output(element);
            }
        }
    }

    /**
     * Asserts that each input {@code Iterable} contains exactly the expected elements, in any order.
     *
     * <p>The results are recorded in the {@code PAssertSuccess} / {@code PAssertFailure}
     * aggregators. This replaces {@code PAssert}, which would insert a flatten into the pipeline.
     */
    public static class PAssertWithoutFlatten<T> extends PTransform<PCollection<Iterable<T>>, PDone> {
        private final T[] expected;

        /** Checks one {@code Iterable} against {@code expected} and records the outcome. */
        public static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> {
            private final Aggregator<Integer, Integer> success = createAggregator("PAssertSuccess", Sum.ofIntegers());
            private final Aggregator<Integer, Integer> failure = createAggregator("PAssertFailure", Sum.ofIntegers());
            private final T[] expected;

            AssertDoFn(T[] expected) {
                this.expected = expected;
            }

            @DoFn.ProcessElement
            public void processElement(DoFn<Iterable<T>, Void>.ProcessContext c) throws Exception {
                try {
                    Assert.assertThat(c.element(), Matchers.containsInAnyOrder(expected));
                    success.addValue(1);
                } catch (Throwable t) {
                    // Count the failure, then rethrow so the pipeline still fails loudly.
                    failure.addValue(1);
                    throw t;
                }
            }
        }

        private PAssertWithoutFlatten(T... expected) {
            this.expected = expected;
        }

        @Override
        public PDone expand(PCollection<Iterable<T>> input) {
            input.apply(ParDo.of(new AssertDoFn<T>(expected)));
            return PDone.in(input.getPipeline());
        }
    }

    /** Starts the embedded ZooKeeper and then the embedded Kafka broker once for the whole class. */
    @BeforeClass
    public static void init() throws IOException {
        EMBEDDED_ZOOKEEPER.startup();
        EMBEDDED_KAFKA_CLUSTER.startup();
    }

    /**
     * Publishes every entry of {@code messages} to {@link #TOPIC}.
     *
     * <p>Each record uses the map key as its key and the map value, encoded with
     * {@link InstantCoder}, as its value.
     *
     * <p>The producer is closed exactly once, by try-with-resources. Per the
     * {@code KafkaProducer#close()} contract, closing waits for outstanding asynchronous sends to
     * complete before this method returns.
     *
     * @param messages records to publish, keyed by Kafka record key
     */
    private static void produce(Map<String, Instant> messages) {
        Properties producerProps = new Properties();
        producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps());
        // NOTE(review): "request.required.acks" is the legacy Scala-producer key. The Java
        // KafkaProducer reads "acks" instead, so this may be ignored — TODO confirm.
        producerProps.put("request.required.acks", 1);
        producerProps.put("bootstrap.servers", EMBEDDED_KAFKA_CLUSTER.getBrokerList());

        Serializer<Instant> instantSerializer = new Serializer<Instant>() {
            @Override
            public void configure(Map<String, ?> configs, boolean isKey) {
            }

            @Override
            public byte[] serialize(String topic, Instant instant) {
                return CoderHelpers.toByteArray(instant, InstantCoder.of());
            }

            @Override
            public void close() {
            }
        };

        try (KafkaProducer<String, Instant> kafkaProducer =
                new KafkaProducer<>(producerProps, new StringSerializer(), instantSerializer)) {
            for (Map.Entry<String, Instant> entry : messages.entrySet()) {
                kafkaProducer.send(new ProducerRecord<>(TOPIC, entry.getKey(), entry.getValue()));
            }
        }
    }

    /**
     * Runs the pipeline, then runs it again from the same checkpoint directory.
     *
     * <p>After each run it checks the aggregator and metric values. After the second run it also
     * checks that the windowed assertion succeeded exactly once and never failed.
     */
    @Test
    public void testWithResume() throws Exception {
        SparkPipelineOptions options = PipelineOptionsFactory.create().as(SparkPipelineOptions.class);
        options.setRunner(SparkRunner.class);
        options.setCheckpointDir(tmpFolder.newFolder().toString());
        options.setCheckpointDurationMillis(500L);
        options.setJobName(testName.getMethodName());
        options.setSparkMaster("local[*]");

        MetricsFilter metricsFilter = MetricsFilter.builder()
                .addNameFilter(MetricNameFilter.inNamespace(ResumeFromCheckpointStreamingTest.class))
                .build();

        // First run: four records, all inside the fixed window [0, 500).
        produce(ImmutableMap.of("k1", new Instant(100L), "k2", new Instant(200L), "k3", new Instant(300L), "k4", new Instant(400L)));
        SparkPipelineResult firstResult = run(options);
        firstResult.waitUntilFinish(Duration.standardSeconds(5L));
        assertMessageCounts(firstResult, metricsFilter, 4L, 4L);

        // Reset runner-side singletons so the second run must restore its state from the checkpoint.
        AggregatorsAccumulator.clear();
        SparkMetricsContainer.clear();
        GlobalWatermarkHolder.clear();

        // Second run: one more data record plus the "EOF" marker.
        // "EOF" is counted in allMessages but not in processedMessages.
        produce(ImmutableMap.of("k5", new Instant(499L), "EOF", new Instant(500L)));
        SparkPipelineResult secondResult = runAgain(options);
        secondResult.waitUntilFinish(Duration.standardSeconds(5L));
        assertMessageCounts(secondResult, metricsFilter, 5L, 6L);

        int successes = secondResult.getAggregatorValue("PAssertSuccess", Integer.class);
        Assert.assertThat(String.format("Expected %d successful assertions, but found %d.", 1, successes), successes, Matchers.is(1));
        int failures = secondResult.getAggregatorValue("PAssertFailure", Integer.class);
        Assert.assertThat(String.format("Found %d failed assertions.", failures), failures, Matchers.is(0));
    }

    /**
     * Checks the message counts recorded by {@link EOFShallNotPassFn}.
     *
     * @param result the finished pipeline run to inspect
     * @param metricsFilter filter selecting this test class's metrics namespace
     * @param expectedProcessed expected value of the {@code processedMessages} aggregator
     * @param expectedAll expected attempted value of the {@code allMessages} counter
     */
    private static void assertMessageCounts(SparkPipelineResult result, MetricsFilter metricsFilter, long expectedProcessed, long expectedAll) {
        Long processed = result.getAggregatorValue("processedMessages", Long.class);
        Assert.assertThat(String.format("Expected %d processed messages count but found %d", expectedProcessed, processed), processed, Matchers.equalTo(expectedProcessed));
        Assert.assertThat(result.metrics().queryMetrics(metricsFilter).counters(), Matchers.hasItem(MetricMatchers.attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(), "allMessages", "EOFShallNotPassFn", expectedAll)));
    }

    /** Sleeps briefly, then runs the pipeline again with the same options (and checkpoint dir). */
    private SparkPipelineResult runAgain(SparkPipelineOptions options) {
        Uninterruptibles.sleepUninterruptibly(10L, TimeUnit.MILLISECONDS);
        return run(options);
    }

    /**
     * Builds and runs the test pipeline:
     * <ol>
     *   <li>Read from Kafka and keep only the keys.</li>
     *   <li>Drop {@code "EOF"} while checking a side input, via {@link EOFShallNotPassFn}.</li>
     *   <li>Window into 500 ms fixed windows that fire when the watermark passes the end of the
     *       window.</li>
     *   <li>Group everything under a single key.</li>
     *   <li>Assert the grouped window contents.</li>
     * </ol>
     *
     * @param options pipeline options; the checkpoint directory in here is what enables resuming
     * @return the result of {@link Pipeline#run()}, as a {@link SparkPipelineResult}
     */
    private static SparkPipelineResult run(SparkPipelineOptions options) {
        KafkaIO.Read<String, Instant> read = KafkaIO.<String, Instant>read()
                .withBootstrapServers(EMBEDDED_KAFKA_CLUSTER.getBrokerList())
                .withTopics(Collections.singletonList(TOPIC))
                .withKeyCoder(StringUtf8Coder.of())
                .withValueCoder(InstantCoder.of())
                .updateConsumerProperties(ImmutableMap.<String, Object>of("auto.offset.reset", "earliest"))
                .withTimestampFn(new SerializableFunction<KV<String, Instant>, Instant>() {
                    @Override
                    public Instant apply(KV<String, Instant> kv) {
                        return kv.getValue();
                    }
                })
                .withWatermarkFn(new SerializableFunction<KV<String, Instant>, Instant>() {
                    @Override
                    public Instant apply(KV<String, Instant> kv) {
                        // The "EOF" record advances the watermark to the end of time so pending windows fire.
                        return kv.getKey().equals("EOF") ? BoundedWindow.TIMESTAMP_MAX_VALUE : kv.getValue();
                    }
                });

        Pipeline pipeline = Pipeline.create(options);
        PCollectionView<List<String>> sideInput = pipeline
                .apply(Create.of(ImmutableList.of("side1", "side2")).withCoder(StringUtf8Coder.of()))
                .apply(View.<String>asList());

        pipeline
                .apply(read.withoutMetadata())
                .apply(Keys.<String>create())
                .apply("EOFShallNotPassFn", ParDo.of(new EOFShallNotPassFn(sideInput)).withSideInputs(sideInput))
                .apply(Window.<String>into(FixedWindows.of(Duration.millis(500L)))
                        .triggering(AfterWatermark.pastEndOfWindow())
                        .accumulatingFiredPanes()
                        .withAllowedLateness(Duration.ZERO))
                // A single constant key puts each window's elements into one Iterable.
                .apply(WithKeys.<Integer, String>of(1))
                .apply(GroupByKey.<Integer, String>create())
                .apply(Values.<Iterable<String>>create())
                .apply(new PAssertWithoutFlatten<String>("k1", "k2", "k3", "k4", "k5"));

        // Pipeline#run() is declared to return PipelineResult; the Spark runner returns a SparkPipelineResult.
        return (SparkPipelineResult) pipeline.run();
    }

    /** Shuts down the embedded Kafka broker, then ZooKeeper. */
    @AfterClass
    public static void tearDown() {
        EMBEDDED_KAFKA_CLUSTER.shutdown();
        EMBEDDED_ZOOKEEPER.shutdown();
    }
}
