package org.apache.beam.runners.spark.metrics;

import org.apache.beam.runners.core.metrics.TestMetricsSink;
import org.apache.beam.runners.spark.ReuseSparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.StreamingTest;
import org.apache.beam.runners.spark.io.CreateStream;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.metrics.MetricsOptions;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesMetricsPusher;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.TimestampedValue;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Checks that a counter incremented inside a pipeline run on the Spark runner is reported to
 * {@link TestMetricsSink}, both in streaming mode (via {@link CreateStream}) and in batch mode
 * (via {@link Create}).
 */
public class SparkMetricsPusherTest {
    private static final Logger LOG = LoggerFactory.getLogger(SparkMetricsPusherTest.class);
    private static final String COUNTER_NAME = "counter";

    // Configured with ReuseSparkContextRule.no(); presumably forces a fresh SparkContext per test.
    // NOTE: field name is misspelled ("Resue") but kept as-is since it is a public member.
    @Rule
    public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();

    @Rule
    public final TestPipeline pipeline = TestPipeline.create();

    /** Increments the {@code COUNTER_NAME} counter once per element and re-emits the element. */
    private static class CountingDoFn extends DoFn<Integer, Integer> {
        private final Counter counter =
                Metrics.counter(SparkMetricsPusherTest.class, COUNTER_NAME);

        @DoFn.ProcessElement
        public void processElement(ProcessContext context) {
            try {
                counter.inc();
                context.output(context.element());
            } catch (Exception e) {
                // Pass the throwable as the last argument so SLF4J records the full stack trace;
                // concatenating it into the message would only log e.toString().
                LOG.warn("Exception caught", e);
            }
        }
    }

    /** Returns the Spark batch interval configured on the pipeline options. */
    private Duration batchDuration() {
        return Duration.millis(
                pipeline.getOptions().as(SparkPipelineOptions.class).getBatchIntervalMillis());
    }

    /**
     * Blocks for one metrics push period plus one extra second, so that the pusher has had a
     * chance to report to {@link TestMetricsSink}. The push period is interpreted in seconds.
     */
    private void waitForMetricsPush() throws InterruptedException {
        long pushPeriodSeconds =
                pipeline.getOptions().as(MetricsOptions.class).getMetricsPushPeriod();
        Thread.sleep(Duration.standardSeconds(pushPeriodSeconds + 1).getMillis());
    }

    /** Asserts that {@link TestMetricsSink} holds {@code expected} for {@code COUNTER_NAME}. */
    private static void assertCounterEquals(long expected) {
        Assert.assertThat(TestMetricsSink.getCounterValue(COUNTER_NAME), Matchers.is(expected));
    }

    /** Resets the test sink and routes pipeline metrics to it before every test. */
    @Before
    public void init() {
        TestMetricsSink.clear();
        pipeline.getOptions().as(MetricsOptions.class).setMetricsSink(TestMetricsSink.class);
    }

    @Test
    @Category(StreamingTest.class)
    public void testInStreamingMode() throws Exception {
        Instant instant = new Instant(0L);
        Instant oneSecondLater = instant.plus(Duration.standardSeconds(1L));

        pipeline
                .apply(
                        CreateStream.of(VarIntCoder.of(), batchDuration())
                                .emptyBatch()
                                .advanceWatermarkForNextBatch(instant)
                                .nextBatch(
                                        TimestampedValue.of(1, instant),
                                        TimestampedValue.of(2, instant),
                                        TimestampedValue.of(3, instant))
                                .advanceWatermarkForNextBatch(oneSecondLater)
                                .nextBatch(
                                        TimestampedValue.of(4, oneSecondLater),
                                        TimestampedValue.of(5, oneSecondLater),
                                        TimestampedValue.of(6, oneSecondLater))
                                .advanceNextBatchWatermarkToInfinity())
                .apply(
                        Window.into(FixedWindows.of(Duration.standardSeconds(3L)))
                                .withAllowedLateness(Duration.ZERO))
                .apply(ParDo.of(new CountingDoFn()));

        pipeline.run();
        waitForMetricsPush();
        assertCounterEquals(6L);
    }

    @Test
    @Category(UsesMetricsPusher.class)
    public void testInSBatchMode() throws Exception {
        pipeline.apply(Create.of(1, 2, 3, 4, 5, 6)).apply(ParDo.of(new CountingDoFn()));

        pipeline.run();
        waitForMetricsPush();
        assertCounterEquals(6L);
    }
}
