package org.apache.flink.runtime.checkpoint;

import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.class */
/**
 * Tests for the redistribution of partitionable operator state performed by
 * {@link StateAssignmentOperation#reDistributePartitionableStates}.
 *
 * <p>Every test builds an {@link OperatorState} with two old subtasks (indices 0 and 1). Each old
 * subtask holds one {@link OperatorStreamStateHandle} whose state names use one or more
 * {@link OperatorStateHandle.Mode}s. The state is then redistributed from parallelism 2 to 3, 1
 * and 2, and the test checks how often each state name, and how many offsets per mode, end up in
 * the new subtasks.
 */
public class StateAssignmentOperationTest extends TestLogger {

    @Test
    public void testRepartitionSplitDistributeStates() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);

        Map<String, OperatorStateHandle.StateMetaInfo> subtask0 = new HashMap<>(1);
        subtask0.put("t-1", metaInfo(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE, 0, 10));
        operatorState.putState(0, operatorSubtaskState(subtask0, "test1", 30));

        Map<String, OperatorStateHandle.StateMetaInfo> subtask1 = new HashMap<>(1);
        subtask1.put("t-2", metaInfo(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE, 0, 15));
        operatorState.putState(1, operatorSubtaskState(subtask1, "test2", 40));

        verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testRepartitionUnionState() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);

        Map<String, OperatorStateHandle.StateMetaInfo> subtask0 = new HashMap<>(2);
        subtask0.put("t-3", metaInfo(OperatorStateHandle.Mode.UNION, 0));
        subtask0.put("t-4", metaInfo(OperatorStateHandle.Mode.UNION, 22, 44));
        operatorState.putState(0, operatorSubtaskState(subtask0, "test1", 50));

        Map<String, OperatorStateHandle.StateMetaInfo> subtask1 = new HashMap<>(1);
        subtask1.put("t-3", metaInfo(OperatorStateHandle.Mode.UNION, 0));
        operatorState.putState(1, operatorSubtaskState(subtask1, "test2", 20));

        verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testRepartitionBroadcastState() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);

        Map<String, OperatorStateHandle.StateMetaInfo> subtask0 = new HashMap<>(2);
        subtask0.put("t-5", metaInfo(OperatorStateHandle.Mode.BROADCAST, 0, 10, 20));
        subtask0.put("t-6", metaInfo(OperatorStateHandle.Mode.BROADCAST, 30, 40, 50));
        operatorState.putState(0, operatorSubtaskState(subtask0, "test1", 60));

        Map<String, OperatorStateHandle.StateMetaInfo> subtask1 = new HashMap<>(2);
        subtask1.put("t-5", metaInfo(OperatorStateHandle.Mode.BROADCAST, 0, 10, 20));
        subtask1.put("t-6", metaInfo(OperatorStateHandle.Mode.BROADCAST, 30, 40, 50));
        operatorState.putState(1, operatorSubtaskState(subtask1, "test2", 60));

        verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testReDistributeCombinedPartitionableStates() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);

        Map<String, OperatorStateHandle.StateMetaInfo> subtask0 = new HashMap<>(6);
        subtask0.put("t-1", metaInfo(OperatorStateHandle.Mode.UNION, 0));
        subtask0.put("t-2", metaInfo(OperatorStateHandle.Mode.UNION, 22, 44));
        subtask0.put("t-3", metaInfo(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE, 52, 63));
        subtask0.put("t-4", metaInfo(OperatorStateHandle.Mode.BROADCAST, 67, 74, 75));
        subtask0.put("t-5", metaInfo(OperatorStateHandle.Mode.BROADCAST, 77, 88, 92));
        subtask0.put("t-6", metaInfo(OperatorStateHandle.Mode.BROADCAST, 101, 123, 127));
        operatorState.putState(0, operatorSubtaskState(subtask0, "test1", 130));

        Map<String, OperatorStateHandle.StateMetaInfo> subtask1 = new HashMap<>(4);
        subtask1.put("t-1", metaInfo(OperatorStateHandle.Mode.UNION, 0));
        subtask1.put("t-4", metaInfo(OperatorStateHandle.Mode.BROADCAST, 20, 27, 28));
        subtask1.put("t-5", metaInfo(OperatorStateHandle.Mode.BROADCAST, 30, 44, 48));
        subtask1.put("t-6", metaInfo(OperatorStateHandle.Mode.BROADCAST, 57, 79, 83));
        operatorState.putState(1, operatorSubtaskState(subtask1, "test2", 86));

        verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 3);
        verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 1);
        verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 2);
    }

    /** Creates a meta info entry with the given distribution mode and partition offsets. */
    private static OperatorStateHandle.StateMetaInfo metaInfo(
            OperatorStateHandle.Mode mode, long... offsets) {
        return new OperatorStateHandle.StateMetaInfo(offsets, mode);
    }

    /**
     * Wraps {@code stateNameToOffsets} in an {@link OperatorSubtaskState}.
     *
     * <p>The state handle goes into the first constructor slot; the other three handles are
     * {@code null}. It is backed by an in-memory {@link ByteStreamStateHandle} named
     * {@code handleName} holding {@code size} zero bytes.
     */
    private static OperatorSubtaskState operatorSubtaskState(
            Map<String, OperatorStateHandle.StateMetaInfo> stateNameToOffsets,
            String handleName,
            int size) {
        return new OperatorSubtaskState(
                new OperatorStreamStateHandle(
                        stateNameToOffsets, new ByteStreamStateHandle(handleName, new byte[size])),
                // Explicitly typed nulls pick the intended constructor overload.
                (OperatorStateHandle) null,
                (KeyedStateHandle) null,
                (KeyedStateHandle) null);
    }

    /**
     * Returns the summed offset count that each state name of {@code mode} must have in every new
     * subtask after redistribution.
     */
    private static int expectedOffsetsPerSubtask(
            OperatorStateHandle.Mode mode, int oldParallelism, int newParallelism) {
        if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(mode)) {
            // Scaling up spreads the split partitions so each subtask sees a single offset.
            return oldParallelism < newParallelism ? 1 : 2;
        }
        if (OperatorStateHandle.Mode.UNION.equals(mode)) {
            return 2;
        }
        // Every remaining mode (BROADCAST in these tests).
        return 3;
    }

    /**
     * Redistributes {@code operatorState} to {@code newParallelism} subtasks and checks the result.
     *
     * <p>For each new subtask, the offsets of every state name are summed per distribution mode.
     * Each sum is checked against {@link #expectedOffsetsPerSubtask}. For every redistributed
     * handle entry, the counter of its state name in {@code stateInfoCounts} is incremented by
     * one.
     *
     * @param operatorState the old operator state to redistribute
     * @param operatorID the operator ID, passed as the single new operator ID
     * @param oldParallelism the parallelism the old state was built with
     * @param newParallelism the parallelism to redistribute to
     * @param stateInfoCounts accumulator of handle entries per state name, updated in place
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void verifyAndCollectStateInfo(
            OperatorState operatorState,
            OperatorID operatorID,
            int oldParallelism,
            int newParallelism,
            Map<String, Integer> stateInfoCounts) {
        // Declared raw so this test does not have to name the key type of the redistribution
        // result; only its values (the handles assigned to each new subtask) are inspected.
        Map newOperatorStates = new HashMap<>(newParallelism);
        StateAssignmentOperation.reDistributePartitionableStates(
                Collections.singletonList(operatorState),
                newParallelism,
                Collections.singletonList(operatorID),
                newOperatorStates,
                new HashMap<>(newParallelism));

        Iterable<List<OperatorStateHandle>> handlesPerSubtask =
                (Iterable<List<OperatorStateHandle>>) newOperatorStates.values();

        for (List<OperatorStateHandle> subtaskHandles : handlesPerSubtask) {
            // Summed offset count per state name, grouped by distribution mode.
            Map<OperatorStateHandle.Mode, Map<String, Integer>> offsetCountsByMode =
                    new EnumMap<>(OperatorStateHandle.Mode.class);
            for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
                offsetCountsByMode.put(mode, new HashMap<>());
            }

            for (OperatorStateHandle handle : subtaskHandles) {
                for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry :
                        handle.getStateNameToPartitionOffsets().entrySet()) {
                    String stateName = entry.getKey();
                    OperatorStateHandle.StateMetaInfo stateMetaInfo = entry.getValue();

                    stateInfoCounts.merge(stateName, 1, Integer::sum);
                    offsetCountsByMode
                            .get(stateMetaInfo.getDistributionMode())
                            .merge(stateName, stateMetaInfo.getOffsets().length, Integer::sum);
                }
            }

            for (Map.Entry<OperatorStateHandle.Mode, Map<String, Integer>> modeEntry :
                    offsetCountsByMode.entrySet()) {
                int expected =
                        expectedOffsetsPerSubtask(
                                modeEntry.getKey(), oldParallelism, newParallelism);
                for (Integer offsetCount : modeEntry.getValue().values()) {
                    Assert.assertEquals(expected, offsetCount.intValue());
                }
            }
        }
    }

    /** Checks a single-mode operator state when rescaling from 2 to 3, 1 and 2 subtasks. */
    private void verifyOneKindPartitionableStateRescale(
            OperatorState operatorState, OperatorID operatorID) {
        verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 3);
        verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 1);
        verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 2);
    }

    /**
     * Checks a single-mode operator state with exactly two state names after redistribution.
     *
     * <p>The expected entry counts depend on which pair of names is present:
     * <ul>
     *   <li>{@code t-1}/{@code t-2}: 2 each when scaling up, otherwise 1 each;</li>
     *   <li>{@code t-3}: {@code 2 * newParallelism}; {@code t-4}: {@code newParallelism};</li>
     *   <li>{@code t-5}/{@code t-6}: {@code newParallelism} each.</li>
     * </ul>
     */
    private void verifyOneKindPartitionableStateRescale(
            OperatorState operatorState,
            OperatorID operatorID,
            int oldParallelism,
            int newParallelism) {
        Map<String, Integer> stateInfoCounts = new HashMap<>();
        verifyAndCollectStateInfo(
                operatorState, operatorID, oldParallelism, newParallelism, stateInfoCounts);

        Assert.assertEquals(2, stateInfoCounts.size());

        if (stateInfoCounts.containsKey("t-1")) {
            int expected = oldParallelism < newParallelism ? 2 : 1;
            Assert.assertEquals(expected, stateInfoCounts.get("t-1").intValue());
            Assert.assertEquals(expected, stateInfoCounts.get("t-2").intValue());
        }
        if (stateInfoCounts.containsKey("t-3")) {
            Assert.assertEquals(2 * newParallelism, stateInfoCounts.get("t-3").intValue());
            Assert.assertEquals(newParallelism, stateInfoCounts.get("t-4").intValue());
        }
        if (stateInfoCounts.containsKey("t-5")) {
            Assert.assertEquals(newParallelism, stateInfoCounts.get("t-5").intValue());
            Assert.assertEquals(newParallelism, stateInfoCounts.get("t-6").intValue());
        }
    }

    /**
     * Checks the mixed-mode state built by {@link #testReDistributeCombinedPartitionableStates()}
     * after redistribution to {@code newParallelism} subtasks.
     */
    private void verifyCombinedPartitionableStateRescale(
            OperatorState operatorState,
            OperatorID operatorID,
            int oldParallelism,
            int newParallelism) {
        // Keep a reference to the accumulator; the counts it collects are what we assert on.
        Map<String, Integer> stateInfoCounts = new HashMap<>();
        verifyAndCollectStateInfo(
                operatorState, operatorID, oldParallelism, newParallelism, stateInfoCounts);

        Assert.assertEquals(6, stateInfoCounts.size());

        // t-1 is UNION state present in both old subtasks.
        Assert.assertEquals(2 * newParallelism, stateInfoCounts.get("t-1").intValue());
        // t-2 is UNION state present in one old subtask.
        Assert.assertEquals(newParallelism, stateInfoCounts.get("t-2").intValue());
        // t-3 is SPLIT_DISTRIBUTE state whose two offsets only get separated when scaling up.
        Assert.assertEquals(
                oldParallelism < newParallelism ? 2 : 1, stateInfoCounts.get("t-3").intValue());
        // t-4..t-6 are BROADCAST state.
        Assert.assertEquals(newParallelism, stateInfoCounts.get("t-4").intValue());
        Assert.assertEquals(newParallelism, stateInfoCounts.get("t-5").intValue());
        Assert.assertEquals(newParallelism, stateInfoCounts.get("t-6").intValue());
    }
}
