package org.apache.beam.sdk.util.construction.graph;

import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.providers.GenerateSequenceSchemaTransformProvider;
import org.apache.beam.sdk.transforms.reflect.ByteBuddyDoFnInvokerFactory;
import org.apache.beam.sdk.util.construction.Environments;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.sdk.util.construction.graph.PipelineNode;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.hamcrest.Description;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.hamcrest.TypeSafeMatcher;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/sdk/util/construction/graph/GreedyStageFuserTest.class */
public class GreedyStageFuserTest {

    /** JUnit rule through which the negative tests declare the exception type and message they expect. */
    @Rule
    public ExpectedException thrown = ExpectedException.none();
    /** PCollection proto whose unique name is {@code impulse.out}; registered in the base components by {@code setup()}. */
    private final RunnerApi.PCollection impulseDotOut = RunnerApi.PCollection.newBuilder().setUniqueName("impulse.out").build();
    /** Graph node pairing the id {@code impulse.out} with {@link #impulseDotOut}; used as the stage input by many tests below. */
    private final PipelineNode.PCollectionNode impulseOutputNode = PipelineNode.pCollection("impulse.out", this.impulseDotOut);
    /** Base components shared by the tests; (re)assigned in {@code setup()} before each test runs. */
    private RunnerApi.Components partialComponents;

    /**
     * Rebuilds the base components before each test: one {@code impulse} transform with a single
     * output mapped to {@code impulse.out}, plus the {@code impulse.out} PCollection itself.
     */
    @Before
    public void setup() {
        RunnerApi.FunctionSpec.Builder impulseSpec =
                RunnerApi.FunctionSpec.newBuilder().setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN);
        RunnerApi.PTransform impulse =
                RunnerApi.PTransform.newBuilder()
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "impulse.out")
                        .setSpec(impulseSpec)
                        .build();
        partialComponents =
                RunnerApi.Components.newBuilder()
                        .putTransforms("impulse", impulse)
                        .putPcollections("impulse.out", impulseDotOut)
                        .build();
    }

    /**
     * Requesting a stage with an empty set of initial consumers is expected to fail with an
     * {@link IllegalArgumentException} whose message mentions "at least one PTransform".
     */
    @Test
    public void noInitialConsumersThrows() {
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(partialComponents);

        // Expectations are registered only after the pipeline is built, so they apply to the fuser call alone.
        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("at least one PTransform");

        GreedyStageFuser.forGrpcPortRead(pipeline, impulseOutputNode, Collections.emptySet());
    }

    /**
     * Two consumers of {@code read.out} live in different environments ({@code go} and {@code py}).
     * Passing both as initial consumers is expected to fail with an {@link IllegalArgumentException}
     * whose message contains "go", "py" and "same".
     */
    @Test
    public void differentEnvironmentsThrows() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;

        // Upstream ParDo; note it carries no payload and no environment id.
        RunnerApi.PTransform read =
                RunnerApi.PTransform.newBuilder()
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .build())
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "read.out")
                        .build();
        RunnerApi.PTransform goTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(outputTag, "go.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("go")
                        .build();
        RunnerApi.PTransform pyTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(outputTag, "py.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("py")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("read", read)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("goTransform", goTransform)
                        .putPcollections(
                                "go.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("go.out").build())
                        .putTransforms("pyTransform", pyTransform)
                        .putPcollections(
                                "py.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("py.out").build())
                        .putEnvironments("go", Environments.createDockerEnvironment("go"))
                        .putEnvironments("py", Environments.createDockerEnvironment("py"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        PipelineNode.PCollectionNode readOutput =
                PipelineNode.pCollection(
                        "read.out",
                        RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build());
        Set<PipelineNode.PTransformNode> consumers = pipeline.getPerElementConsumers(readOutput);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("go");
        thrown.expectMessage("py");
        thrown.expectMessage("same");

        GreedyStageFuser.forGrpcPortRead(pipeline, readOutput, consumers);
    }

    /**
     * A GroupByKey transform with no environment id is offered as the initial consumer. The fuser
     * is expected to fail with an {@link IllegalArgumentException} mentioning
     * "Environment must be populated".
     */
    @Test
    public void noEnvironmentThrows() {
        RunnerApi.PTransform gbk =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN))
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "gbk.out")
                        .build();
        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("runnerTransform", gbk)
                        .putPcollections(
                                "gbk.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("gbk.out").build())
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        thrown.expect(IllegalArgumentException.class);
        thrown.expectMessage("Environment must be populated");

        GreedyStageFuser.forGrpcPortRead(
                pipeline,
                impulseOutputNode,
                ImmutableSet.of(PipelineNode.pTransform("runnerTransform", gbk)));
    }

    /**
     * A ParDo and a window-assign transform both read {@code impulse.out} and share the
     * {@code common} environment. The resulting stage is expected to contain both transforms and
     * to materialize no output PCollections.
     */
    @Test
    public void fusesCompatibleEnvironments() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        final String windowId = ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD;

        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform window =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "window.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putTransforms(windowId, window)
                        .putPcollections(
                                "window.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("window.out").build())
                        .putEnvironments("common", Environments.createDockerEnvironment("common"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(
                                PipelineNode.pTransform("parDo", parDo),
                                PipelineNode.pTransform(windowId, window)));

        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.emptyIterable());
        MatcherAssert.assertThat(stage, hasSubtransforms("parDo", windowId));
    }

    /**
     * {@code parDo.out} is consumed by a ParDo that declares a state spec. The stage rooted at
     * {@code parDo} is expected to contain only {@code parDo} and to materialize {@code parDo.out}.
     */
    @Test
    public void materializesWithStatefulConsumer() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;

        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        // Downstream ParDo whose payload carries a (default) state spec.
        RunnerApi.PTransform stateful =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "parDo.out")
                        .putOutputs(outputTag, "stateful.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .putStateSpecs(
                                                                ByteBuddyDoFnInvokerFactory.STATE_PARAMETER_METHOD,
                                                                RunnerApi.StateSpec.getDefaultInstance())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putTransforms("stateful", stateful)
                        .putPcollections(
                                "stateful.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("stateful.out").build())
                        .putEnvironments("common", Environments.createDockerEnvironment("common"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("parDo", parDo)));

        PipelineNode.PCollectionNode expectedOutput =
                PipelineNode.pCollection(
                        "parDo.out",
                        RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build());
        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.contains(expectedOutput));
        MatcherAssert.assertThat(stage, hasSubtransforms("parDo"));
    }

    /**
     * {@code parDo.out} is consumed by a ParDo that declares a timer family spec. The stage rooted
     * at {@code parDo} is expected to contain only {@code parDo} and to materialize
     * {@code parDo.out}.
     */
    @Test
    public void materializesWithConsumerWithTimer() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        final String timerId = ByteBuddyDoFnInvokerFactory.TIMER_PARAMETER_METHOD;

        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        // Downstream ParDo whose payload carries a (default) timer family spec.
        RunnerApi.PTransform timerTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "parDo.out")
                        .putOutputs(outputTag, "timer.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .putTimerFamilySpecs(
                                                                timerId,
                                                                RunnerApi.TimerFamilySpec.getDefaultInstance())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putTransforms(timerId, timerTransform)
                        .putPcollections(
                                "timer.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("timer.out").build())
                        .putEnvironments("common", Environments.createDockerEnvironment("common"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("parDo", parDo)));

        PipelineNode.PCollectionNode expectedOutput =
                PipelineNode.pCollection(
                        "parDo.out",
                        RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build());
        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.contains(expectedOutput));
        MatcherAssert.assertThat(stage, hasSubtransforms("parDo"));
    }

    /**
     * Two ParDos in the {@code common} environment feed a flatten, whose output goes to a
     * window-assign transform in the same environment. The stage built from the per-element
     * consumers of {@code impulse.out} is expected to contain all four transforms and to
     * materialize nothing.
     */
    @Test
    public void fusesFlatten() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        final String windowId = ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD;

        RunnerApi.PTransform read =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        // The flatten itself has no environment id.
        RunnerApi.PTransform flatten =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("readInput", "read.out")
                        .putInputs("parDoInput", "parDo.out")
                        .putOutputs(outputTag, "flatten.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN))
                        .build();
        RunnerApi.PTransform window =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "flatten.out")
                        .putOutputs(outputTag, "window.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("read", read)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putTransforms("flatten", flatten)
                        .putPcollections(
                                "flatten.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("flatten.out").build())
                        .putTransforms(windowId, window)
                        .putPcollections(
                                "window.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("window.out").build())
                        .putEnvironments("common", Environments.createDockerEnvironment("common"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        pipeline.getPerElementConsumers(impulseOutputNode));

        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.emptyIterable());
        MatcherAssert.assertThat(stage, hasSubtransforms("read", "parDo", "flatten", windowId));
    }

    /**
     * A flatten receives one input from the {@code common} environment and one from {@code rare};
     * its output is consumed by a window-assign transform in {@code common}.
     *
     * <p>Expected: the stage rooted at {@code read} (common) contains read, flatten and window and
     * materializes nothing, while the stage rooted at {@code envRead} (rare) contains envRead and
     * flatten and materializes {@code flatten.out}.
     */
    @Test
    public void fusesFlattenWithDifferentEnvironmentInputs() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        final String windowId = ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD;

        RunnerApi.PTransform read =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform envRead =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("impulse", "impulse.out")
                        .putOutputs(outputTag, "envRead.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("rare")
                        .build();
        RunnerApi.PTransform flatten =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("readInput", "read.out")
                        .putInputs("otherEnvInput", "envRead.out")
                        .putOutputs(outputTag, "flatten.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN))
                        .build();
        RunnerApi.PTransform window =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "flatten.out")
                        .putOutputs(outputTag, "window.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("read", read)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("envRead", envRead)
                        .putPcollections(
                                "envRead.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("envRead.out").build())
                        .putTransforms("flatten", flatten)
                        .putPcollections(
                                "flatten.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("flatten.out").build())
                        .putTransforms(windowId, window)
                        .putPcollections(
                                "window.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("window.out").build())
                        .putEnvironments("common", Environments.createDockerEnvironment("common"))
                        .putEnvironments("rare", Environments.createDockerEnvironment("rare"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        // Stage rooted in the "common" environment.
        ExecutableStage commonStage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("read", read)));
        MatcherAssert.assertThat(commonStage.getOutputPCollections(), Matchers.emptyIterable());
        MatcherAssert.assertThat(commonStage, hasSubtransforms("read", "flatten", windowId));

        // Stage rooted in the "rare" environment.
        ExecutableStage rareStage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("envRead", envRead)));
        PipelineNode.PCollectionNode flattenOutput =
                PipelineNode.pCollection(
                        "flatten.out", components.getPcollectionsOrThrow("flatten.out"));
        MatcherAssert.assertThat(rareStage.getOutputPCollections(), Matchers.contains(flattenOutput));
        MatcherAssert.assertThat(rareStage, hasSubtransforms("envRead", "flatten"));
    }

    /**
     * A flatten is fed by one {@code py} ParDo and one {@code go} ParDo, and its output is consumed
     * by one {@code py} ParDo and one {@code go} window-assign transform.
     *
     * <p>Expected: both the stage rooted at {@code pyRead} and the stage rooted at {@code goRead}
     * materialize {@code flatten.out}; the py stage does not contain {@code pyParDo} and the go
     * stage does not contain {@code goWindow}.
     */
    @Test
    public void flattenWithHeterogeneousInputsAndOutputs() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;

        RunnerApi.PTransform pyRead =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "pyRead.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString())
                                        .build())
                        .setEnvironmentId("py")
                        .build();
        RunnerApi.PTransform goRead =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "goRead.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString())
                                        .build())
                        .setEnvironmentId("go")
                        .build();
        RunnerApi.PTransform pyParDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "flatten.out")
                        .putOutputs(outputTag, "pyParDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString())
                                        .build())
                        .setEnvironmentId("py")
                        .build();
        RunnerApi.PTransform goWindow =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "flatten.out")
                        .putOutputs(outputTag, "goWindow.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString())
                                        .build())
                        .setEnvironmentId("go")
                        .build();
        RunnerApi.PCollection flattenPc =
                RunnerApi.PCollection.newBuilder().setUniqueName("flatten.out").build();
        RunnerApi.PTransform flatten =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("py_input", "pyRead.out")
                        .putInputs("go_input", "goRead.out")
                        .putOutputs(outputTag, "flatten.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)
                                        .build())
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("pyRead", pyRead)
                        .putPcollections(
                                "pyRead.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("pyRead.out").build())
                        .putTransforms("goRead", goRead)
                        .putPcollections(
                                "goRead.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("goRead.out").build())
                        .putTransforms("flatten", flatten)
                        .putPcollections("flatten.out", flattenPc)
                        .putTransforms("pyParDo", pyParDo)
                        .putPcollections(
                                "pyParDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("pyParDo.out").build())
                        .putTransforms("goWindow", goWindow)
                        .putPcollections(
                                "goWindow.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("goWindow.out").build())
                        .putEnvironments("go", Environments.createDockerEnvironment("go"))
                        .putEnvironments("py", Environments.createDockerEnvironment("py"))
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        ExecutableStage pyStage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("pyRead", pyRead)));
        ExecutableStage goStage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("goRead", goRead)));

        MatcherAssert.assertThat(
                pyStage.getOutputPCollections(),
                Matchers.contains(PipelineNode.pCollection("flatten.out", flattenPc)));
        MatcherAssert.assertThat(
                pyStage.getTransforms(),
                Matchers.not(Matchers.hasItem(PipelineNode.pTransform("pyParDo", pyParDo))));
        MatcherAssert.assertThat(
                goStage.getOutputPCollections(),
                Matchers.contains(PipelineNode.pCollection("flatten.out", flattenPc)));
        MatcherAssert.assertThat(
                goStage.getTransforms(),
                Matchers.not(Matchers.hasItem(PipelineNode.pTransform("goWindow", goWindow))));
    }

    /**
     * {@code parDo} (environment {@code common}) is followed by a window-assign transform in
     * environment {@code rare}. The stage is expected to read {@code impulse.out}, run in the
     * {@code common} environment, contain only {@code parDo}, and materialize {@code parDo.out}.
     */
    @Test
    public void materializesWithDifferentEnvConsumer() {
        RunnerApi.Environment commonEnv = Environments.createDockerEnvironment("common");

        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs("out", "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PCollection parDoPc =
                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build();
        RunnerApi.PTransform window =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "parDo.out")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "window.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("rare")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("parDo", parDo)
                        .putPcollections("parDo.out", parDoPc)
                        .putTransforms(ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD, window)
                        .putPcollections(
                                "window.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("window.out").build())
                        .putEnvironments("rare", Environments.createDockerEnvironment("rare"))
                        .putEnvironments("common", commonEnv)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        pipeline.getPerElementConsumers(impulseOutputNode));

        MatcherAssert.assertThat(
                stage.getOutputPCollections(),
                Matchers.contains(PipelineNode.pCollection("parDo.out", parDoPc)));
        MatcherAssert.assertThat(stage.getInputPCollection(), Matchers.equalTo(impulseOutputNode));
        MatcherAssert.assertThat(stage.getEnvironment(), Matchers.equalTo(commonEnv));
        MatcherAssert.assertThat(
                stage.getTransforms(),
                Matchers.contains(PipelineNode.pTransform("parDo", parDo)));
    }

    /**
     * {@code read.out} has two consumers: a ParDo in {@code common} and a window-assign transform
     * in {@code rare}. The stage rooted at {@code read} is expected to contain only {@code read}
     * and to materialize its single output PCollection.
     */
    @Test
    public void materializesWithDifferentEnvSibling() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        RunnerApi.Environment commonEnv = Environments.createDockerEnvironment("common");

        RunnerApi.PTransform read =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform window =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(outputTag, "window.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("rare")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("read", read)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putTransforms(ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD, window)
                        .putPcollections(
                                "window.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("window.out").build())
                        .putEnvironments("rare", Environments.createDockerEnvironment("rare"))
                        .putEnvironments("common", commonEnv)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        PipelineNode.PTransformNode readNode = PipelineNode.pTransform("read", read);
        PipelineNode.PCollectionNode readOutput =
                (PipelineNode.PCollectionNode)
                        Iterables.getOnlyElement(pipeline.getOutputPCollections(readNode));

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("read", read)));

        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.contains(readOutput));
        MatcherAssert.assertThat(stage.getTransforms(), Matchers.contains(readNode));
    }

    /**
     * {@code read.out} is consumed by a ParDo that also declares a side input, and by a
     * window-assign transform; every transform with an environment id uses {@code common}. The
     * stage rooted at {@code read} is expected to contain only {@code read} and to materialize
     * its single output PCollection.
     */
    @Test
    public void materializesWithSideInputConsumer() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        RunnerApi.Environment commonEnv = Environments.createDockerEnvironment("common");

        RunnerApi.PTransform read =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        // Producer of the side input; it has neither a payload nor an environment id.
        RunnerApi.PTransform sideRead =
                RunnerApi.PTransform.newBuilder()
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN))
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "side_read.out")
                        .build();
        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putInputs("side_input", "side_read.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .putSideInputs(
                                                                "side_input",
                                                                RunnerApi.SideInput.getDefaultInstance())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform window =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(outputTag, "window.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.WindowIntoPayload.newBuilder()
                                                        .setWindowFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("read", read)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("side_read", sideRead)
                        .putPcollections(
                                "side_read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("side_read.out").build())
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putTransforms(ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD, window)
                        .putPcollections(
                                "window.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("window.out").build())
                        .putEnvironments("common", commonEnv)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        PipelineNode.PTransformNode readNode = PipelineNode.pTransform("read", read);
        PipelineNode.PCollectionNode readOutput =
                (PipelineNode.PCollectionNode)
                        Iterables.getOnlyElement(pipeline.getOutputPCollections(readNode));

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline, impulseOutputNode, ImmutableSet.of(readNode));

        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.contains(readOutput));
        MatcherAssert.assertThat(stage, hasSubtransforms(readNode.getId()));
    }

    /**
     * A stage that reads {@code read.out} and starts at {@code parDo}, which declares the side
     * input {@code side_input} backed by {@code side_read.out}. The stage is expected to report
     * exactly that side-input reference and to materialize no output PCollections.
     */
    @Test
    public void sideInputIncludedInStage() {
        final String outputTag = GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG;
        RunnerApi.Environment commonEnv = Environments.createDockerEnvironment("common");

        RunnerApi.PTransform read =
                RunnerApi.PTransform.newBuilder()
                        .setUniqueName("read")
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PTransform parDo =
                RunnerApi.PTransform.newBuilder()
                        .setUniqueName("parDo")
                        .putInputs("input", "read.out")
                        .putInputs("side_input", "side_read.out")
                        .putOutputs(outputTag, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(
                                                RunnerApi.ParDoPayload.newBuilder()
                                                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                                                        .putSideInputs(
                                                                "side_input",
                                                                RunnerApi.SideInput.getDefaultInstance())
                                                        .build()
                                                        .toByteString()))
                        .setEnvironmentId("common")
                        .build();
        RunnerApi.PCollection sideReadPc =
                RunnerApi.PCollection.newBuilder().setUniqueName("side_read.out").build();
        RunnerApi.PTransform sideRead =
                RunnerApi.PTransform.newBuilder()
                        .setUniqueName("side_read")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN))
                        .putInputs("input", "impulse.out")
                        .putOutputs(outputTag, "side_read.out")
                        .build();

        RunnerApi.Components components =
                partialComponents.toBuilder()
                        .putTransforms("read", read)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("side_read", sideRead)
                        .putPcollections("side_read.out", sideReadPc)
                        .putTransforms("parDo", parDo)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putEnvironments("common", commonEnv)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        // The stage input is the single output of "read", looked up through the pipeline.
        PipelineNode.PCollectionNode readOutput =
                (PipelineNode.PCollectionNode)
                        Iterables.getOnlyElement(
                                pipeline.getOutputPCollections(PipelineNode.pTransform("read", read)));
        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        readOutput,
                        ImmutableSet.of(PipelineNode.pTransform("parDo", parDo)));

        SideInputReference expectedSideInput =
                SideInputReference.of(
                        PipelineNode.pTransform("parDo", parDo),
                        "side_input",
                        PipelineNode.pCollection("side_read.out", sideReadPc));
        MatcherAssert.assertThat(stage.getSideInputs(), Matchers.contains(expectedSideInput));
        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.emptyIterable());
    }

    /**
     * Checks that a stage whose output is consumed as another transform's side input reports that
     * output among its output PCollections.
     *
     * <p>Pipeline built here: {@code impulse} writes {@code impulsePC}; {@code createSide} reads
     * {@code impulsePC} and writes {@code sidePC}; {@code processMain} reads {@code impulsePC} under
     * the {@code main} tag and declares {@code sidePC} as its side input {@code side}. The stage is
     * rooted at the output of {@code impulse}, and only {@code createSide} is in the transform set
     * passed to {@code GreedyStageFuser.forGrpcPortRead}.
     */
    @Test
    public void executableStageProducingSideInputMaterializesIt() {
        RunnerApi.Environment env = Environments.createDockerEnvironment("common");

        // Registered under the same "impulse" key that setup() already populated.
        RunnerApi.PTransform impulse =
                RunnerApi.PTransform.newBuilder()
                        .setUniqueName("impulse")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "impulsePC")
                        .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN))
                        .build();

        // Producer of the side input: a plain ParDo with no side inputs of its own.
        RunnerApi.ParDoPayload createSidePayload =
                RunnerApi.ParDoPayload.newBuilder()
                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                        .build();
        RunnerApi.PTransform createSide =
                RunnerApi.PTransform.newBuilder()
                        .setUniqueName("createSide")
                        .putInputs("input", "impulsePC")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "sidePC")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(createSidePayload.toByteString()))
                        .setEnvironmentId("common")
                        .build();

        // Consumer of the side input; it has no outputs in this fixture.
        RunnerApi.ParDoPayload processMainPayload =
                RunnerApi.ParDoPayload.newBuilder()
                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                        .putSideInputs("side", RunnerApi.SideInput.getDefaultInstance())
                        .build();
        RunnerApi.PTransform processMain =
                RunnerApi.PTransform.newBuilder()
                        .setUniqueName("processMain")
                        .putInputs("main", "impulsePC")
                        .putInputs("side", "sidePC")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(processMainPayload.toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.PCollection sidePc =
                RunnerApi.PCollection.newBuilder().setUniqueName("sidePC").build();
        RunnerApi.PCollection impulsePc =
                RunnerApi.PCollection.newBuilder().setUniqueName("impulsePC").build();

        RunnerApi.Components components =
                this.partialComponents.toBuilder()
                        .putTransforms("impulse", impulse)
                        .putTransforms("createSide", createSide)
                        .putTransforms("processMain", processMain)
                        .putPcollections("impulsePC", impulsePc)
                        .putPcollections("sidePC", sidePc)
                        .putEnvironments("common", env)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        PipelineNode.PCollectionNode impulseOutput =
                Iterables.getOnlyElement(
                        pipeline.getOutputPCollections(PipelineNode.pTransform("impulse", impulse)));
        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        impulseOutput,
                        ImmutableSet.of(PipelineNode.pTransform("createSide", createSide)));

        // The stage's outputs must be exactly the side-input PCollection.
        MatcherAssert.assertThat(
                stage.getOutputPCollections(),
                Matchers.contains(PipelineNode.pCollection("sidePC", sidePc)));
    }

    /**
     * Checks that a state spec declared on a ParDo in the stage is reported through
     * {@code ExecutableStage.getUserStates()}.
     *
     * <p>Pipeline built here: {@code read} consumes {@code impulse.out} and writes {@code read.out};
     * {@code parDo} consumes {@code read.out}, writes {@code parDo.out} and declares the state spec
     * {@code state_spec}; a separate {@code user_state} transform also consumes {@code impulse.out}.
     * The stage is rooted at {@code read.out} with {@code parDo} in the transform set passed to
     * {@code GreedyStageFuser.forGrpcPortRead}.
     */
    @Test
    public void userStateIncludedInStage() {
        RunnerApi.Environment env = Environments.createDockerEnvironment("common");

        RunnerApi.ParDoPayload readPayload =
                RunnerApi.ParDoPayload.newBuilder()
                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                        .build();
        RunnerApi.PTransform readTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(readPayload.toByteString()))
                        .setEnvironmentId("common")
                        .build();

        // The stateful ParDo: identical to a plain ParDo except for the declared state spec.
        RunnerApi.ParDoPayload statefulPayload =
                RunnerApi.ParDoPayload.newBuilder()
                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                        .putStateSpecs("state_spec", RunnerApi.StateSpec.getDefaultInstance())
                        .build();
        RunnerApi.PTransform parDoTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "parDo.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(statefulPayload.toByteString()))
                        .setEnvironmentId("common")
                        .build();

        // Main input of the stateful ParDo; the expected UserStateReference points at it.
        RunnerApi.PCollection readOutput =
                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build();

        RunnerApi.PTransform userStateTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "user_state.out")
                        .build();

        RunnerApi.Components components =
                this.partialComponents.toBuilder()
                        .putTransforms("read", readTransform)
                        .putPcollections("read.out", readOutput)
                        .putTransforms("user_state", userStateTransform)
                        .putPcollections(
                                "user_state.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("user_state.out").build())
                        .putTransforms("parDo", parDoTransform)
                        .putPcollections(
                                "parDo.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build())
                        .putEnvironments("common", env)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        PipelineNode.PCollectionNode readOutputNode =
                Iterables.getOnlyElement(
                        pipeline.getOutputPCollections(PipelineNode.pTransform("read", readTransform)));
        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline,
                        readOutputNode,
                        ImmutableSet.of(PipelineNode.pTransform("parDo", parDoTransform)));

        MatcherAssert.assertThat(
                stage.getUserStates(),
                Matchers.contains(
                        UserStateReference.of(
                                PipelineNode.pTransform("parDo", parDoTransform),
                                "state_spec",
                                PipelineNode.pCollection("read.out", readOutput))));
        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.emptyIterable());
    }

    /**
     * Checks that a stage ending in a PCollection consumed by a GroupByKey reports that PCollection
     * as an output and contains only the upstream ParDo.
     *
     * <p>Pipeline built here: {@code read} consumes {@code impulse.out} and writes {@code read.out};
     * {@code gbk} (GroupByKey URN, no environment) consumes {@code read.out} and writes
     * {@code gbk.out}. The stage is rooted at {@code impulse.out} with {@code read} in the transform
     * set passed to {@code GreedyStageFuser.forGrpcPortRead}.
     */
    @Test
    public void materializesWithGroupByKeyConsumer() {
        RunnerApi.Environment env = Environments.createDockerEnvironment("common");

        RunnerApi.ParDoPayload readPayload =
                RunnerApi.ParDoPayload.newBuilder()
                        .setDoFn(RunnerApi.FunctionSpec.newBuilder())
                        .build();
        RunnerApi.PTransform readTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "impulse.out")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "read.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                                        .setPayload(readPayload.toByteString()))
                        .setEnvironmentId("common")
                        .build();

        RunnerApi.PTransform gbkTransform =
                RunnerApi.PTransform.newBuilder()
                        .putInputs("input", "read.out")
                        .putOutputs(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG, "gbk.out")
                        .setSpec(
                                RunnerApi.FunctionSpec.newBuilder()
                                        .setUrn(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN))
                        .build();

        // NOTE(review): the PCollection registered under id "gbk.out" carries the unique name
        // "parDo.out". This looks like a copy-paste leftover; no assertion below reads it, so it
        // is kept as-is. Confirm before renaming.
        RunnerApi.PCollection gbkOutput =
                RunnerApi.PCollection.newBuilder().setUniqueName("parDo.out").build();

        RunnerApi.Components components =
                this.partialComponents.toBuilder()
                        .putTransforms("read", readTransform)
                        .putPcollections(
                                "read.out",
                                RunnerApi.PCollection.newBuilder().setUniqueName("read.out").build())
                        .putTransforms("gbk", gbkTransform)
                        .putPcollections("gbk.out", gbkOutput)
                        .putEnvironments("common", env)
                        .build();
        QueryablePipeline pipeline = QueryablePipeline.forPrimitivesIn(components);

        PipelineNode.PTransformNode readNode = PipelineNode.pTransform("read", readTransform);
        PipelineNode.PCollectionNode readOutputNode =
                Iterables.getOnlyElement(pipeline.getOutputPCollections(readNode));

        ExecutableStage stage =
                GreedyStageFuser.forGrpcPortRead(
                        pipeline, this.impulseOutputNode, ImmutableSet.of(readNode));

        // read.out must be the stage's only output, and "read" its only transform.
        MatcherAssert.assertThat(stage.getOutputPCollections(), Matchers.contains(readOutputNode));
        MatcherAssert.assertThat(stage, hasSubtransforms(readNode.getId()));
    }

    /**
     * Builds a matcher that accepts an {@link ExecutableStage} whose transform ids are exactly the
     * given ids: no extra ids, none missing, order irrelevant.
     *
     * @param str the first expected transform id
     * @param strArr any further expected transform ids; may be empty
     * @return a matcher comparing the set of ids of the stage's transforms with the expected set
     */
    private static TypeSafeMatcher<ExecutableStage> hasSubtransforms(String str, String... strArr) {
        // The explicit <String> witness avoids raw types and makes add(strArr) bind to the
        // varargs overload by type, so each id is added individually.
        final Set<String> expectedIds = ImmutableSet.<String>builder().add(str).add(strArr).build();
        return new TypeSafeMatcher<ExecutableStage>() {
            @Override
            protected boolean matchesSafely(ExecutableStage executableStage) {
                Set<String> actualIds =
                        executableStage.getTransforms().stream()
                                .map(PipelineNode.PTransformNode::getId)
                                .collect(Collectors.toSet());
                // Set equality is the same check as containsAll in both directions.
                return actualIds.equals(expectedIds);
            }

            @Override
            public void describeTo(Description description) {
                description.appendText("ExecutableStage with subtransform ids: " + expectedIds);
            }
        };
    }
}
