package org.apache.beam.sdk.transforms;

import java.io.Serializable;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnWithContext;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

/**
 * Tests for {@link DoFnWithContext}. Covered areas:
 * <ul>
 *   <li>aggregator creation and its argument validation;</li>
 *   <li>the default (empty) display data;</li>
 *   <li>the failure reported when {@code createAggregator} is invoked from a bundle or element
 *       method while a pipeline is running.</li>
 * </ul>
 *
 * <p>Anonymous {@link DoFnWithContext} subclasses declared in instance methods hold a reference to
 * this test instance. Implementing {@link Serializable} presumably lets those DoFns be serialized
 * when applied through {@link ParDo}.
 */
@RunWith(JUnit4.class)
public class DoFnWithContextTest implements Serializable {

    // transient so that serializing the (captured) test instance does not include the rule.
    @Rule
    public transient ExpectedException thrown = ExpectedException.none();

    /**
     * A {@link DoFnWithContext} whose {@code @ProcessElement} method does nothing. It is static
     * because it never needs the enclosing test instance.
     */
    private static class NoOpDoFnWithContext extends DoFnWithContext<Void, Void> {
        @DoFnWithContext.ProcessElement
        public void processElement(DoFnWithContext<Void, Void>.ProcessContext c) {
        }
    }

    /** The returned aggregator reports the requested name and the exact combiner instance. */
    @Test
    public void testCreateAggregatorWithCombinerSucceeds() {
        Sum.SumLongFn combiner = new Sum.SumLongFn();

        Aggregator<Long, Long> aggregator =
                new NoOpDoFnWithContext().createAggregator("testAggregator", combiner);

        Assert.assertEquals("testAggregator", aggregator.getName());
        Assert.assertEquals(combiner, aggregator.getCombineFn());
    }

    /** A null aggregator name is rejected with a descriptive NullPointerException. */
    @Test
    public void testCreateAggregatorWithNullNameThrowsException() {
        String name = null;

        thrown.expect(NullPointerException.class);
        thrown.expectMessage("name cannot be null");

        new NoOpDoFnWithContext().createAggregator(name, new Sum.SumLongFn());
    }

    /** A null CombineFn is rejected with a descriptive NullPointerException. */
    @Test
    public void testCreateAggregatorWithNullCombineFnThrowsException() {
        // Typed local selects the CombineFn overload without a raw-type cast.
        Combine.CombineFn<Object, Object, Object> combiner = null;

        thrown.expect(NullPointerException.class);
        thrown.expectMessage("combiner cannot be null");

        new NoOpDoFnWithContext().createAggregator("testAggregator", combiner);
    }

    /** A null SerializableFunction combiner is rejected with a descriptive NullPointerException. */
    @Test
    public void testCreateAggregatorWithNullSerializableFnThrowsException() {
        // Typed local selects the SerializableFunction overload without a raw-type cast.
        SerializableFunction<Iterable<Object>, Object> combiner = null;

        thrown.expect(NullPointerException.class);
        thrown.expectMessage("combiner cannot be null");

        new NoOpDoFnWithContext().createAggregator("testAggregator", combiner);
    }

    /** Creating a second aggregator with an already-used name fails with IllegalArgumentException. */
    @Test
    public void testCreateAggregatorWithSameNameThrowsException() {
        Max.MaxDoubleFn combiner = new Max.MaxDoubleFn();
        NoOpDoFnWithContext fn = new NoOpDoFnWithContext();
        fn.createAggregator("testAggregator", combiner);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("Cannot create");
        thrown.expectMessage("testAggregator");
        thrown.expectMessage("already exists");

        fn.createAggregator("testAggregator", combiner);
    }

    /** Aggregators with distinct names on the same DoFn are distinct (non-equal) objects. */
    @Test
    public void testCreateAggregatorsWithDifferentNamesSucceeds() {
        Max.MaxDoubleFn combiner = new Max.MaxDoubleFn();
        NoOpDoFnWithContext fn = new NoOpDoFnWithContext();

        Aggregator<Double, Double> first = fn.createAggregator("testAggregator", combiner);
        Aggregator<Double, Double> second = fn.createAggregator("aggregatorPrime", combiner);

        Assert.assertNotEquals(first, second);
    }

    /**
     * Checks aggregator delegation. Despite its name, this test exercises {@link DoFn} through
     * {@code NoOpDoFn}, not {@link DoFnWithContext}.
     *
     * <p>Scenario:
     * <ul>
     *   <li>an aggregator is created on the DoFn;</li>
     *   <li>the context's {@code createAggregatorInternal} is stubbed to return a mock;</li>
     *   <li>{@code setupDelegateAggregators()} is called.</li>
     * </ul>
     * After that, values added to the DoFn's aggregator must reach the mock.
     */
    @Test
    public void testDoFnWithContextUsingAggregators() {
        NoOpDoFn<Object, Object> noOpFn = new NoOpDoFn<>();
        DoFn<Object, Object>.Context context = noOpFn.context();

        DoFn<Object, Object> fn = Mockito.spy(noOpFn);
        context = Mockito.spy(context);

        @SuppressWarnings("unchecked") // Mockito.mock cannot express the generic Aggregator type.
        Aggregator<Long, Long> delegate = Mockito.mock(Aggregator.class);

        Sum.SumLongFn combiner = new Sum.SumLongFn();
        Aggregator<Long, Long> underTest = fn.createAggregator("test", combiner);

        Mockito.when(context.createAggregatorInternal("test", combiner)).thenReturn(delegate);

        context.setupDelegateAggregators();
        underTest.addValue(1L);

        Mockito.verify(delegate).addValue(1L);
    }

    /** A DoFnWithContext that does not override display-data population contributes no items. */
    @Test
    public void testDefaultPopulateDisplayDataImplementation() {
        DoFnWithContext<String, String> fn = new DoFnWithContext<String, String>() {
        };

        Assert.assertThat(DisplayData.from(fn).items(), Matchers.empty());
    }

    /**
     * Calling {@code createAggregator} from a {@code @StartBundle} method fails the pipeline with
     * an IllegalStateException cause.
     */
    @Test
    @Category(NeedsRunner.class)
    public void testCreateAggregatorInStartBundleThrows() {
        TestPipeline pipeline = createTestPipeline(new DoFnWithContext<String, String>() {
            @DoFnWithContext.StartBundle
            public void startBundle(DoFnWithContext<String, String>.Context c) {
                createAggregator("anyAggregate", new Max.MaxIntegerFn());
            }

            @DoFnWithContext.ProcessElement
            public void processElement(DoFnWithContext<String, String>.ProcessContext c) {
            }
        });

        thrown.expect(Pipeline.PipelineExecutionException.class);
        thrown.expectCause(Matchers.isA(IllegalStateException.class));

        pipeline.run();
    }

    /**
     * Calling {@code createAggregator} from a {@code @ProcessElement} method fails the pipeline
     * with an IllegalStateException cause.
     */
    @Test
    @Category(NeedsRunner.class)
    public void testCreateAggregatorInProcessElementThrows() {
        TestPipeline pipeline = createTestPipeline(new DoFnWithContext<String, String>() {
            @DoFnWithContext.ProcessElement
            public void processElement(DoFnWithContext<String, String>.ProcessContext c) {
                createAggregator("anyAggregate", new Max.MaxIntegerFn());
            }
        });

        thrown.expect(Pipeline.PipelineExecutionException.class);
        thrown.expectCause(Matchers.isA(IllegalStateException.class));

        pipeline.run();
    }

    /**
     * Calling {@code createAggregator} from a {@code @FinishBundle} method fails the pipeline with
     * an IllegalStateException cause.
     */
    @Test
    @Category(NeedsRunner.class)
    public void testCreateAggregatorInFinishBundleThrows() {
        TestPipeline pipeline = createTestPipeline(new DoFnWithContext<String, String>() {
            @DoFnWithContext.FinishBundle
            public void finishBundle(DoFnWithContext<String, String>.Context c) {
                createAggregator("anyAggregate", new Max.MaxIntegerFn());
            }

            @DoFnWithContext.ProcessElement
            public void processElement(DoFnWithContext<String, String>.ProcessContext c) {
            }
        });

        thrown.expect(Pipeline.PipelineExecutionException.class);
        thrown.expectCause(Matchers.isA(IllegalStateException.class));

        pipeline.run();
    }

    /**
     * Builds a {@link TestPipeline} that feeds a single null element into {@code fn} via
     * {@link ParDo}. The pipeline is returned without being run.
     *
     * @param fn the DoFn under test
     * @return the constructed, not-yet-run pipeline
     */
    private <InputT, OutputT> TestPipeline createTestPipeline(DoFnWithContext<InputT, OutputT> fn) {
        TestPipeline pipeline = TestPipeline.create();
        // Cast the null to InputT so the created PCollection's element type matches what ParDo.of(fn)
        // accepts; a bare Object[] would produce PCollection<Object>, which does not type-check.
        pipeline.apply(Create.of((InputT) null)).apply(ParDo.of(fn));
        return pipeline;
    }
}
