package org.apache.beam.runners.direct;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.core.ExecutionContext;
import org.apache.beam.runners.direct.AggregatorContainer;
import org.apache.beam.sdk.transforms.Aggregator;
import org.apache.beam.sdk.transforms.Sum;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/**
 * Tests for {@link AggregatorContainer}.
 *
 * <p>Covers three behaviors. Values written through a {@link AggregatorContainer.Mutator} only
 * become visible in the container after {@code commit()}. A committed mutator rejects further use
 * with {@link IllegalStateException}. Concurrent commits from many mutators are all accumulated.
 */
@RunWith(JUnit4.class)
public class AggregatorContainerTest {
    private static final String STEP_NAME = "step";
    private static final String AGGREGATOR_NAME = "sum_int";

    @Mock
    private ExecutionContext.StepContext stepContext;

    @Rule
    public final ExpectedException thrown = ExpectedException.none();
    private final AggregatorContainer container = AggregatorContainer.create();
    private final Class<?> fn = getClass();

    @Before
    public void setUp() {
        MockitoAnnotations.initMocks(this);
        // The container keys aggregates by step name, which it reads from the step context.
        Mockito.when(stepContext.getStepName()).thenReturn(STEP_NAME);
    }

    /** Creates (or looks up) the integer-sum aggregator used by every test, via {@code mutator}. */
    private Aggregator<Integer, Integer> sumAggregator(AggregatorContainer.Mutator mutator) {
        return mutator.createAggregatorForDoFn(fn, stepContext, AGGREGATOR_NAME, Sum.ofIntegers());
    }

    /** Reads the committed value of the integer-sum aggregator from the container. */
    private Integer committedSum() {
        return (Integer) container.getAggregate(STEP_NAME, AGGREGATOR_NAME);
    }

    @Test
    public void addsAggregatorsOnCommit() {
        AggregatorContainer.Mutator first = container.createMutator();
        sumAggregator(first).addValue(5);
        first.commit();
        Assert.assertThat(committedSum(), Matchers.equalTo(5));

        AggregatorContainer.Mutator second = container.createMutator();
        sumAggregator(second).addValue(8);
        Assert.assertThat("Shouldn't update value until commit", committedSum(), Matchers.equalTo(5));
        second.commit();
        Assert.assertThat(committedSum(), Matchers.equalTo(13));
    }

    @Test
    public void failToCreateAfterCommit() {
        AggregatorContainer.Mutator mutator = container.createMutator();
        mutator.commit();

        thrown.expect(IllegalStateException.class);
        sumAggregator(mutator).addValue(5);
    }

    @Test
    public void failToAddValueAfterCommit() {
        AggregatorContainer.Mutator mutator = container.createMutator();
        Aggregator<Integer, Integer> aggregator = sumAggregator(mutator);
        mutator.commit();

        thrown.expect(IllegalStateException.class);
        aggregator.addValue(5);
    }

    @Test
    public void failToAddValueAfterCommitWithPrevious() {
        AggregatorContainer.Mutator first = container.createMutator();
        sumAggregator(first).addValue(5);
        first.commit();

        // A second mutator over an aggregator that already has a committed value must also
        // reject writes once it has itself been committed.
        AggregatorContainer.Mutator second = container.createMutator();
        Aggregator<Integer, Integer> aggregator = sumAggregator(second);
        second.commit();

        thrown.expect(IllegalStateException.class);
        aggregator.addValue(5);
    }

    @Test
    public void concurrentWrites() throws Exception {
        ExecutorService executor = Executors.newFixedThreadPool(20);
        List<Future<?>> futures = new ArrayList<>();
        int expectedSum = 0;
        try {
            for (int i = 0; i < 100; i++) {
                expectedSum += i;
                final int value = i;
                // Mutators are created on the test thread; each is then used and committed
                // on a pool thread.
                final AggregatorContainer.Mutator mutator = container.createMutator();
                futures.add(executor.submit(new Runnable() {
                    @Override
                    public void run() {
                        sumAggregator(mutator).addValue(value);
                        mutator.commit();
                    }
                }));
            }
            executor.shutdown();
            Assert.assertThat(
                "Expected all threads to complete after 5 seconds",
                executor.awaitTermination(5L, TimeUnit.SECONDS),
                Matchers.equalTo(true));
            // submit() captures task exceptions in the Future. Rethrow them here so a failing
            // task reports its real cause instead of a mismatched sum.
            for (Future<?> future : futures) {
                future.get();
            }
        } finally {
            // Make sure no pool threads outlive the test, even if an assertion above failed.
            executor.shutdownNow();
        }
        Assert.assertThat(committedSum(), Matchers.equalTo(expectedSum));
    }
}
