package org.apache.beam.runners.direct;

import java.util.Iterator;
import org.apache.beam.runners.direct.DirectGroupByKey;
import org.apache.beam.runners.direct.repackaged.com.google.common.collect.HashMultiset;
import org.apache.beam.runners.direct.repackaged.com.google.common.collect.ImmutableSet;
import org.apache.beam.runners.direct.repackaged.runners.core.KeyedWorkItem;
import org.apache.beam.runners.direct.repackaged.runners.core.KeyedWorkItems;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matchers;
import org.joda.time.Instant;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactoryTest.class */
public class GroupByKeyOnlyEvaluatorFactoryTest {
    private BundleFactory bundleFactory = ImmutableListBundleFactory.create();

    @Rule
    public TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false);

    /* loaded from: input_file:org/apache/beam/runners/direct/GroupByKeyOnlyEvaluatorFactoryTest$KeyedWorkItemMatcher.class */
    private static class KeyedWorkItemMatcher<K, V> extends BaseMatcher<WindowedValue<KeyedWorkItem<K, V>>> {
        private final KeyedWorkItem<K, V> myWorkItem;
        private final Coder<K> keyCoder;

        public KeyedWorkItemMatcher(KeyedWorkItem<K, V> keyedWorkItem, Coder<K> coder) {
            this.myWorkItem = keyedWorkItem;
            this.keyCoder = coder;
        }

        public boolean matches(Object obj) {
            if (obj == null || !(obj instanceof WindowedValue)) {
                return false;
            }
            WindowedValue windowedValue = (WindowedValue) obj;
            HashMultiset create = HashMultiset.create();
            HashMultiset create2 = HashMultiset.create();
            Iterator it = this.myWorkItem.elementsIterable().iterator();
            while (it.hasNext()) {
                create.add((WindowedValue) it.next());
            }
            Iterator it2 = ((KeyedWorkItem) windowedValue.getValue()).elementsIterable().iterator();
            while (it2.hasNext()) {
                create2.add((WindowedValue) it2.next());
            }
            try {
                if (create.equals(create2)) {
                    if (this.keyCoder.structuralValue(this.myWorkItem.key()).equals(this.keyCoder.structuralValue(((KeyedWorkItem) windowedValue.getValue()).key()))) {
                        return true;
                    }
                }
                return false;
            } catch (Exception e) {
                return false;
            }
        }

        public void describeTo(Description description) {
            description.appendText("KeyedWorkItem<K, V> containing key ").appendValue(this.myWorkItem.key()).appendText(" and values ").appendValueList("[", ", ", "]", this.myWorkItem.elementsIterable());
        }
    }

    @Test
    public void testInMemoryEvaluator() throws Exception {
        KV of = KV.of("foo", -1);
        KV of2 = KV.of("foo", 1);
        KV of3 = KV.of("foo", 3);
        KV of4 = KV.of("bar", 22);
        KV of5 = KV.of("bar", 12);
        KV of6 = KV.of("baz", Integer.MAX_VALUE);
        PCollection apply = this.p.apply(Create.of(of, new KV[]{of4, of2, of6, of5, of3}));
        PCollection apply2 = apply.apply(new DirectGroupByKey.DirectGroupByKeyOnly());
        CommittedBundle commit = this.bundleFactory.createBundle(apply).commit(Instant.now());
        EvaluationContext evaluationContext = (EvaluationContext) Mockito.mock(EvaluationContext.class);
        StructuralKey of7 = StructuralKey.of("foo", StringUtf8Coder.of());
        UncommittedBundle createKeyedBundle = this.bundleFactory.createKeyedBundle(of7, apply2);
        StructuralKey of8 = StructuralKey.of("bar", StringUtf8Coder.of());
        UncommittedBundle createKeyedBundle2 = this.bundleFactory.createKeyedBundle(of8, apply2);
        StructuralKey of9 = StructuralKey.of("baz", StringUtf8Coder.of());
        UncommittedBundle createKeyedBundle3 = this.bundleFactory.createKeyedBundle(of9, apply2);
        Mockito.when(evaluationContext.createKeyedBundle(of7, apply2)).thenReturn(createKeyedBundle);
        Mockito.when(evaluationContext.createKeyedBundle(of8, apply2)).thenReturn(createKeyedBundle2);
        Mockito.when(evaluationContext.createKeyedBundle(of9, apply2)).thenReturn(createKeyedBundle3);
        Coder keyCoder = apply.getCoder().getKeyCoder();
        TransformEvaluator forApplication = new GroupByKeyOnlyEvaluatorFactory(evaluationContext).forApplication(DirectGraphs.getProducer(apply2), commit);
        forApplication.processElement(WindowedValue.valueInGlobalWindow(of));
        forApplication.processElement(WindowedValue.valueInGlobalWindow(of2));
        forApplication.processElement(WindowedValue.valueInGlobalWindow(of3));
        forApplication.processElement(WindowedValue.valueInGlobalWindow(of4));
        forApplication.processElement(WindowedValue.valueInGlobalWindow(of5));
        forApplication.processElement(WindowedValue.valueInGlobalWindow(of6));
        forApplication.finishBundle();
        Assert.assertThat(createKeyedBundle.commit(Instant.now()).getElements(), Matchers.contains(new KeyedWorkItemMatcher(KeyedWorkItems.elementsWorkItem("foo", ImmutableSet.of(WindowedValue.valueInGlobalWindow(-1), WindowedValue.valueInGlobalWindow(1), WindowedValue.valueInGlobalWindow(3))), keyCoder)));
        Assert.assertThat(createKeyedBundle2.commit(Instant.now()).getElements(), Matchers.contains(new KeyedWorkItemMatcher(KeyedWorkItems.elementsWorkItem("bar", ImmutableSet.of(WindowedValue.valueInGlobalWindow(12), WindowedValue.valueInGlobalWindow(22))), keyCoder)));
        Assert.assertThat(createKeyedBundle3.commit(Instant.now()).getElements(), Matchers.contains(new KeyedWorkItemMatcher(KeyedWorkItems.elementsWorkItem("baz", ImmutableSet.of(WindowedValue.valueInGlobalWindow(Integer.MAX_VALUE))), keyCoder)));
    }

    private <K, V> KV<K, WindowedValue<V>> gwValue(KV<K, V> kv) {
        return KV.of(kv.getKey(), WindowedValue.valueInGlobalWindow(kv.getValue()));
    }
}
