package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.testutils.EmptyStreamStateHandle;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Parameterized test for {@link StateAssignmentOperation#channelStateNonRescalingRepartitioner}.
 *
 * <p>For every combination of old parallelism, new parallelism, number of offsets stored in the
 * channel state handles, and kind of channel state (input channel or result subpartition), the
 * test redistributes the channel state of subtask 0 and expects an {@link
 * IllegalArgumentException} exactly when the parallelism changes while the handle carries at least
 * one offset.
 */
@RunWith(Parameterized.class)
public class ChannelStateNoRescalingPartitionerTest {
    private static final OperatorID OPERATOR_ID = new OperatorID();

    private final int oldParallelism;
    private final int newParallelism;
    private final int offsetsSize;
    private final Function<OperatorSubtaskState, ? extends StateObjectCollection<?>> extractState;

    /**
     * Builds the cross product of old parallelism {1, 2}, new parallelism {1, 2}, offsets size
     * {0, 1, 2} and the two channel-state extractors (input channel / result subpartition).
     *
     * @return one {@code Object[]} per combination, in constructor-argument order
     */
    @Parameterized.Parameters(name = "oldParallelism: {0}, newParallelism: {1}, offsetSize: {2}")
    public static Collection<Object[]> parameters() {
        final int[] parallelisms = {1, 2};
        final int[] offsetsSizes = {0, 1, 2};
        // A typed target is required here: with a raw List the element type is inferred as
        // Object, which is not a functional interface, so the references would not compile.
        final List<Function<OperatorSubtaskState, StateObjectCollection<?>>> extractors =
                Arrays.asList(
                        OperatorSubtaskState::getInputChannelState,
                        OperatorSubtaskState::getResultSubpartitionState);

        final List<Object[]> parameters = new ArrayList<>();
        for (int oldParallelism : parallelisms) {
            for (int newParallelism : parallelisms) {
                for (int offsetsSize : offsetsSizes) {
                    for (Function<OperatorSubtaskState, StateObjectCollection<?>> extractor :
                            extractors) {
                        parameters.add(
                                new Object[] {
                                    oldParallelism, newParallelism, offsetsSize, extractor
                                });
                    }
                }
            }
        }
        return parameters;
    }

    /**
     * @param oldParallelism parallelism of the restored operator state
     * @param newParallelism parallelism the state is redistributed to
     * @param offsetsSize number of offsets stored in each channel state handle
     * @param extractState selects which channel state of a subtask state is redistributed
     */
    public ChannelStateNoRescalingPartitionerTest(
            int oldParallelism,
            int newParallelism,
            int offsetsSize,
            Function<OperatorSubtaskState, ? extends StateObjectCollection<?>> extractState) {
        this.oldParallelism = oldParallelism;
        this.newParallelism = newParallelism;
        this.offsetsSize = offsetsSize;
        this.extractState = extractState;
    }

    /**
     * Redistributes the channel state of subtask 0 with the non-rescaling repartitioner and checks
     * that it fails (with {@link IllegalArgumentException}) if and only if {@link #shouldFail()}.
     */
    @Test
    public <T extends AbstractChannelStateHandle<?>> void testNoRescaling() {
        final OperatorState operatorState =
                new OperatorState(OPERATOR_ID, oldParallelism, oldParallelism);
        // Only the channel state collections are populated; every other state kind is empty.
        final OperatorSubtaskState subtaskState =
                new OperatorSubtaskState(
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.empty(),
                        StateObjectCollection.singleton(
                                new InputChannelStateHandle(
                                        new InputChannelInfo(0, 0),
                                        new EmptyStreamStateHandle(),
                                        getOffset())),
                        StateObjectCollection.singleton(
                                new ResultSubpartitionStateHandle(
                                        new ResultSubpartitionInfo(0, 0),
                                        new EmptyStreamStateHandle(),
                                        getOffset())));
        operatorState.putState(0, subtaskState);

        try {
            StateAssignmentOperation.reDistributePartitionableStates(
                    Collections.singletonList(operatorState),
                    newParallelism,
                    Collections.singletonList(OperatorIDPair.generatedIDOnly(OPERATOR_ID)),
                    extractState,
                    StateAssignmentOperation.channelStateNonRescalingRepartitioner("test"));
            if (shouldFail()) {
                Assert.fail(
                        "expected to fail for: oldParallelism="
                                + oldParallelism
                                + ", newParallelism="
                                + newParallelism
                                + ", offsetsSize="
                                + offsetsSize
                                + ", extractState="
                                + extractState);
            }
        } catch (IllegalArgumentException e) {
            // Only an expected failure may be swallowed; anything else surfaces as a test error.
            if (!shouldFail()) {
                throw e;
            }
        }
    }

    /** Redistribution must fail when parallelism changes and the handle carries offsets. */
    private boolean shouldFail() {
        return oldParallelism != newParallelism && offsetsSize > 0;
    }

    /** Returns a new mutable list of {@link #offsetsSize} zero offsets. */
    private List<Long> getOffset() {
        return new ArrayList<>(Collections.nCopies(offsetsSize, 0L));
    }
}
