package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/StateAssignmentOperation.class */
public class StateAssignmentOperation {
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    /** Job vertices that are candidates for receiving restored state, keyed by their vertex id. */
    private final Map<JobVertexID, ExecutionJobVertex> tasks;

    /** Operator states to restore from, keyed by operator id. */
    private final Map<OperatorID, OperatorState> operatorStates;

    /** Checkpoint id that is attached to every {@link JobManagerTaskRestore} created by this operation. */
    private final long restoreCheckpointId;

    /** If {@code false}, state that matches no operator makes the completeness check fail. */
    private final boolean allowNonRestoredState;

    /**
     * Creates a state assignment operation.
     *
     * @param j id of the checkpoint that is being restored
     * @param map job vertices keyed by vertex id; must not be {@code null}
     * @param map2 operator states keyed by operator id; must not be {@code null}
     * @param z whether state without a matching operator may be skipped instead of rejected
     * @throws NullPointerException if {@code map} or {@code map2} is {@code null}
     */
    public StateAssignmentOperation(long j, Map<JobVertexID, ExecutionJobVertex> map, Map<OperatorID, OperatorState> map2, boolean z) {
        this.restoreCheckpointId = j;
        this.allowNonRestoredState = z;
        this.tasks = Preconditions.checkNotNull(map);
        this.operatorStates = Preconditions.checkNotNull(map2);
    }

    /**
     * Distributes the known operator states onto the job vertices of this operation.
     *
     * <p>The state-to-operator mapping is validated first. Then, for every job vertex, one
     * {@link OperatorState} per operator is collected: the state is looked up under the
     * user-defined operator id if one is set, otherwise under the regular operator id. Operators
     * without state get a fresh, empty {@link OperatorState}. Vertices for which no operator had
     * any state are left untouched.
     *
     * @throws IllegalStateException if the completeness or parallelism checks fail
     */
    public void assignStates() {
        // Working copy: every state that gets matched to an operator is removed from it.
        Map<OperatorID, OperatorState> unassignedStates = new HashMap<>(this.operatorStates);
        checkStateMappingCompleteness(this.allowNonRestoredState, this.operatorStates, this.tasks);

        for (ExecutionJobVertex jobVertex : this.tasks.values()) {
            List<OperatorID> operatorIds = jobVertex.getOperatorIDs();
            List<OperatorID> userDefinedIds = jobVertex.getUserDefinedOperatorIDs();
            int operatorCount = operatorIds.size();

            List<OperatorState> vertexStates = new ArrayList<>(operatorCount);
            boolean anyStateFound = false;
            for (int idx = 0; idx < operatorCount; idx++) {
                OperatorID userDefinedId = userDefinedIds.get(idx);
                OperatorID lookupId = userDefinedId != null ? userDefinedId : operatorIds.get(idx);

                OperatorState restored = unassignedStates.remove(lookupId);
                if (restored != null) {
                    anyStateFound = true;
                } else {
                    restored = new OperatorState(lookupId, jobVertex.getParallelism(), jobVertex.getMaxParallelism());
                }
                vertexStates.add(restored);
            }

            if (anyStateFound) {
                assignAttemptState(jobVertex, vertexStates);
            }
        }
    }

    /**
     * Re-distributes the given operator states for one job vertex and hands the result to the
     * vertex' current execution attempts.
     *
     * @param executionJobVertex vertex whose subtasks receive the state
     * @param list one operator state per operator of the vertex, in operator order
     */
    private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> list) {
        final List<OperatorID> operatorIds = executionJobVertex.getOperatorIDs();

        // May override the vertex' max parallelism, so it must run before the max parallelism is read below.
        checkParallelismPreconditions(list, executionJobVertex);

        final int newParallelism = executionJobVertex.getParallelism();
        final List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), newParallelism);
        final int expectedEntries = newParallelism * operatorIds.size();

        Map<OperatorInstanceID, List<OperatorStateHandle>> managedOperatorStates = new HashMap<>(expectedEntries);
        Map<OperatorInstanceID, List<OperatorStateHandle>> rawOperatorStates = new HashMap<>(expectedEntries);
        reDistributePartitionableStates(list, newParallelism, operatorIds, managedOperatorStates, rawOperatorStates);

        Map<OperatorInstanceID, List<KeyedStateHandle>> managedKeyedStates = new HashMap<>(expectedEntries);
        Map<OperatorInstanceID, List<KeyedStateHandle>> rawKeyedStates = new HashMap<>(expectedEntries);
        reDistributeKeyedStates(list, newParallelism, operatorIds, keyGroupPartitions, managedKeyedStates, rawKeyedStates);

        assignTaskStateToExecutionJobVertices(
                executionJobVertex,
                managedOperatorStates,
                rawOperatorStates,
                managedKeyedStates,
                rawKeyedStates,
                newParallelism);
    }

    /**
     * Builds a {@link TaskStateSnapshot} for each of the first {@code i} subtasks of the vertex and
     * sets it as initial state on the subtask's current execution attempt. Subtasks for which no
     * operator reports state are skipped.
     *
     * @param executionJobVertex vertex whose subtasks are initialised
     * @param map managed operator state per operator instance
     * @param map2 raw operator state per operator instance
     * @param map3 managed keyed state per operator instance
     * @param map4 raw keyed state per operator instance
     * @param i number of subtasks to process (the caller passes the vertex parallelism)
     */
    private void assignTaskStateToExecutionJobVertices(ExecutionJobVertex executionJobVertex, Map<OperatorInstanceID, List<OperatorStateHandle>> map, Map<OperatorInstanceID, List<OperatorStateHandle>> map2, Map<OperatorInstanceID, List<KeyedStateHandle>> map3, Map<OperatorInstanceID, List<KeyedStateHandle>> map4, int i) {
        final List<OperatorID> operatorIds = executionJobVertex.getOperatorIDs();

        for (int subtaskIndex = 0; subtaskIndex < i; subtaskIndex++) {
            Execution attempt = executionJobVertex.getTaskVertices()[subtaskIndex].getCurrentExecutionAttempt();
            TaskStateSnapshot snapshot = new TaskStateSnapshot(operatorIds.size());

            boolean hasAnyState = false;
            for (OperatorID operatorId : operatorIds) {
                OperatorInstanceID instanceId = OperatorInstanceID.of(subtaskIndex, operatorId);
                OperatorSubtaskState subtaskState = operatorSubtaskStateFrom(instanceId, map, map2, map3, map4);
                hasAnyState |= subtaskState.hasState();
                // Stateless operators are still registered so the snapshot has an entry per operator.
                snapshot.putSubtaskStateByOperatorID(operatorId, subtaskState);
            }

            if (hasAnyState) {
                attempt.setInitialState(new JobManagerTaskRestore(this.restoreCheckpointId, snapshot));
            }
        }
    }

    /**
     * Assembles the {@link OperatorSubtaskState} of one operator instance from the four per-instance
     * state maps.
     *
     * @param operatorInstanceID instance to look up
     * @param map first operator-state map (the in-file caller passes managed operator state)
     * @param map2 second operator-state map (the in-file caller passes raw operator state)
     * @param map3 first keyed-state map (the in-file caller passes managed keyed state)
     * @param map4 second keyed-state map (the in-file caller passes raw keyed state)
     * @return an empty {@link OperatorSubtaskState} if no map contains the instance, otherwise a
     *     state built from the mapped lists, with an empty list for every map lacking the instance
     * @throws IllegalStateException if {@code map4} contains the instance but {@code map3} does not
     */
    public static OperatorSubtaskState operatorSubtaskStateFrom(OperatorInstanceID operatorInstanceID, Map<OperatorInstanceID, List<OperatorStateHandle>> map, Map<OperatorInstanceID, List<OperatorStateHandle>> map2, Map<OperatorInstanceID, List<KeyedStateHandle>> map3, Map<OperatorInstanceID, List<KeyedStateHandle>> map4) {
        final boolean inFirstOperatorMap = map.containsKey(operatorInstanceID);
        final boolean inSecondOperatorMap = map2.containsKey(operatorInstanceID);
        final boolean inFirstKeyedMap = map3.containsKey(operatorInstanceID);
        final boolean inSecondKeyedMap = map4.containsKey(operatorInstanceID);

        if (!(inFirstOperatorMap || inSecondOperatorMap || inFirstKeyedMap || inSecondKeyedMap)) {
            return new OperatorSubtaskState();
        }

        // An entry in the second keyed map is only accepted together with one in the first keyed map.
        Preconditions.checkState(inFirstKeyedMap || !inSecondKeyedMap);

        StateObjectCollection<OperatorStateHandle> firstOperatorState =
                new StateObjectCollection<>(map.getOrDefault(operatorInstanceID, Collections.emptyList()));
        StateObjectCollection<OperatorStateHandle> secondOperatorState =
                new StateObjectCollection<>(map2.getOrDefault(operatorInstanceID, Collections.emptyList()));
        StateObjectCollection<KeyedStateHandle> firstKeyedState =
                new StateObjectCollection<>(map3.getOrDefault(operatorInstanceID, Collections.emptyList()));
        StateObjectCollection<KeyedStateHandle> secondKeyedState =
                new StateObjectCollection<>(map4.getOrDefault(operatorInstanceID, Collections.emptyList()));

        return new OperatorSubtaskState(firstOperatorState, secondOperatorState, firstKeyedState, secondKeyedState);
    }

    /**
     * Runs the single-state parallelism check for every operator state in the list.
     *
     * @param list operator states to validate against the vertex
     * @param executionJobVertex vertex the states are about to be assigned to
     * @throws IllegalStateException if any state fails the check
     */
    public void checkParallelismPreconditions(List<OperatorState> list, ExecutionJobVertex executionJobVertex) {
        for (OperatorState operatorState : list) {
            checkParallelismPreconditions(operatorState, executionJobVertex);
        }
    }

    /**
     * Computes, for every operator and every new subtask index, the managed and raw keyed state
     * handles and stores them in the two output maps.
     *
     * @param list old operator states, index-aligned with {@code list2}
     * @param i new parallelism (number of subtasks to produce entries for)
     * @param list2 operator ids, index-aligned with {@code list}
     * @param list3 key-group range per new subtask index
     * @param map output: managed keyed state per operator instance
     * @param map2 output: raw keyed state per operator instance
     * @throws IllegalStateException if {@code list} and {@code list2} differ in size
     */
    private void reDistributeKeyedStates(List<OperatorState> list, int i, List<OperatorID> list2, List<KeyGroupRange> list3, Map<OperatorInstanceID, List<KeyedStateHandle>> map, Map<OperatorInstanceID, List<KeyedStateHandle>> map2) {
        final int operatorCount = list2.size();
        Preconditions.checkState(operatorCount == list.size(), "This method still depends on the order of the new and old operators");

        for (int operatorIndex = 0; operatorIndex < operatorCount; operatorIndex++) {
            final OperatorState oldState = list.get(operatorIndex);
            final OperatorID operatorId = list2.get(operatorIndex);
            final int oldParallelism = oldState.getParallelism();

            for (int subtaskIndex = 0; subtaskIndex < i; subtaskIndex++) {
                OperatorInstanceID instanceId = OperatorInstanceID.of(subtaskIndex, operatorId);
                Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> keyedStates =
                        reAssignSubKeyedStates(oldState, list3, subtaskIndex, i, oldParallelism);
                map.put(instanceId, keyedStates.f0);
                map2.put(instanceId, keyedStates.f1);
            }
        }
    }

    /**
     * Determines the managed and raw keyed state handles for one new subtask.
     *
     * <p>If the parallelism changed, the handles are obtained by intersecting the old state with the
     * subtask's key-group range. Otherwise the old subtask's keyed state is reused as is, or empty
     * lists are used when the old subtask has no state.
     *
     * <p>The intersection helpers may yield {@code null} when no old subtask carries any state;
     * such results are treated as empty lists so that re-scaling a stateless operator state does not
     * fail with a {@link NullPointerException}.
     *
     * @param operatorState old state of the operator
     * @param list key-group range per new subtask index
     * @param i index of the new subtask
     * @param i2 new parallelism
     * @param i3 old parallelism
     * @return tuple of (managed keyed handles, raw keyed handles); neither element is {@code null}
     */
    private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(OperatorState operatorState, List<KeyGroupRange> list, int i, int i2, int i3) {
        List<KeyedStateHandle> managedKeyed = null;
        List<KeyedStateHandle> rawKeyed = null;

        if (i2 != i3) {
            managedKeyed = getManagedKeyedStateHandles(operatorState, list.get(i));
            rawKeyed = getRawKeyedStateHandles(operatorState, list.get(i));
        } else {
            OperatorSubtaskState previous = operatorState.getState(i);
            if (previous != null) {
                managedKeyed = previous.getManagedKeyedState().asList();
                rawKeyed = previous.getRawKeyedState().asList();
            }
        }

        // Normalise "no state" to empty lists before inspecting the results.
        if (managedKeyed == null) {
            managedKeyed = Collections.emptyList();
        }
        if (rawKeyed == null) {
            rawKeyed = Collections.emptyList();
        }

        if (managedKeyed.isEmpty() && rawKeyed.isEmpty()) {
            return new Tuple2<>(Collections.emptyList(), Collections.emptyList());
        }
        return new Tuple2<>(managedKeyed, rawKeyed);
    }

    /**
     * Repartitions the managed and raw operator state of every operator to the new parallelism,
     * using {@link RoundRobinOperatorStateRepartitioner#INSTANCE}, and adds the results to the two
     * output maps.
     *
     * @param list old operator states, index-aligned with {@code list2}
     * @param i new parallelism
     * @param list2 operator ids, index-aligned with {@code list}
     * @param map output: managed operator state per operator instance
     * @param map2 output: raw operator state per operator instance
     * @throws IllegalStateException if {@code list} and {@code list2} differ in size
     */
    @VisibleForTesting
    static void reDistributePartitionableStates(List<OperatorState> list, int i, List<OperatorID> list2, Map<OperatorInstanceID, List<OperatorStateHandle>> map, Map<OperatorInstanceID, List<OperatorStateHandle>> map2) {
        Preconditions.checkState(list2.size() == list.size(), "This method still depends on the order of the new and old operators");

        // Outer index: operator, middle index: old subtask, inner list: that subtask's handles.
        List<List<List<OperatorStateHandle>>> managedPerOperator = new ArrayList<>(list.size());
        List<List<List<OperatorStateHandle>>> rawPerOperator = new ArrayList<>(list.size());
        splitManagedAndRawOperatorStates(list, managedPerOperator, rawPerOperator);

        final OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
        int operatorIndex = 0;
        while (operatorIndex < list2.size()) {
            final int oldParallelism = list.get(operatorIndex).getParallelism();
            final OperatorID operatorId = list2.get(operatorIndex);

            map.putAll(applyRepartitioner(operatorId, repartitioner, managedPerOperator.get(operatorIndex), oldParallelism, i));
            map2.putAll(applyRepartitioner(operatorId, repartitioner, rawPerOperator.get(operatorIndex), oldParallelism, i));
            operatorIndex++;
        }
    }

    /**
     * Splits every operator state into its managed and raw operator state, one handle list per old
     * subtask, and appends the per-operator results to the two output lists.
     *
     * <p>Subtasks without state contribute an empty list, so each per-operator list always has
     * exactly {@code operatorState.getParallelism()} entries.
     *
     * @param list operator states to split
     * @param list2 output: managed operator state per operator and old subtask
     * @param list3 output: raw operator state per operator and old subtask
     */
    private static void splitManagedAndRawOperatorStates(List<OperatorState> list, List<List<List<OperatorStateHandle>>> list2, List<List<List<OperatorStateHandle>>> list3) {
        for (OperatorState operatorState : list) {
            final int oldParallelism = operatorState.getParallelism();
            List<List<OperatorStateHandle>> managedBySubtask = new ArrayList<>(oldParallelism);
            List<List<OperatorStateHandle>> rawBySubtask = new ArrayList<>(oldParallelism);

            for (int subtaskIndex = 0; subtaskIndex < oldParallelism; subtaskIndex++) {
                OperatorSubtaskState subtaskState = operatorState.getState(subtaskIndex);
                if (subtaskState != null) {
                    managedBySubtask.add(subtaskState.getManagedOperatorState().asList());
                    rawBySubtask.add(subtaskState.getRawOperatorState().asList());
                } else {
                    managedBySubtask.add(Collections.emptyList());
                    rawBySubtask.add(Collections.emptyList());
                }
            }

            list2.add(managedBySubtask);
            list3.add(rawBySubtask);
        }
    }

    /**
     * Collects the managed keyed state handles of all old subtasks, restricted to the given
     * key-group range.
     *
     * @param operatorState old state of the operator
     * @param keyGroupRange key-group range the returned handles are intersected with
     * @return the non-null intersections of all managed keyed state handles with the range; an
     *     empty list (never {@code null}) if no subtask has state
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(OperatorState operatorState, KeyGroupRange keyGroupRange) {
        final int parallelism = operatorState.getParallelism();
        List<KeyedStateHandle> intersecting = null;

        for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {
            OperatorSubtaskState subtaskState = operatorState.getState(subtaskIndex);
            if (subtaskState == null) {
                continue;
            }
            StateObjectCollection<KeyedStateHandle> managedKeyedState = subtaskState.getManagedKeyedState();
            if (intersecting == null) {
                // Sized lazily from the first subtask that actually has state.
                intersecting = new ArrayList<>(parallelism * managedKeyedState.size());
            }
            extractIntersectingState(managedKeyedState, keyGroupRange, intersecting);
        }

        // Previously null was returned when no subtask had state, which made callers fail with an NPE.
        if (intersecting == null) {
            return Collections.emptyList();
        }
        return intersecting;
    }

    /**
     * Collects the raw keyed state handles of all old subtasks, restricted to the given key-group
     * range.
     *
     * @param operatorState old state of the operator
     * @param keyGroupRange key-group range the returned handles are intersected with
     * @return the non-null intersections of all raw keyed state handles with the range; an empty
     *     list (never {@code null}) if no subtask has state
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(OperatorState operatorState, KeyGroupRange keyGroupRange) {
        final int parallelism = operatorState.getParallelism();
        List<KeyedStateHandle> intersecting = null;

        for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {
            OperatorSubtaskState subtaskState = operatorState.getState(subtaskIndex);
            if (subtaskState == null) {
                continue;
            }
            StateObjectCollection<KeyedStateHandle> rawKeyedState = subtaskState.getRawKeyedState();
            if (intersecting == null) {
                // Sized lazily from the first subtask that actually has state.
                intersecting = new ArrayList<>(parallelism * rawKeyedState.size());
            }
            extractIntersectingState(rawKeyedState, keyGroupRange, intersecting);
        }

        // Previously null was returned when no subtask had state, which made callers fail with an NPE.
        if (intersecting == null) {
            return Collections.emptyList();
        }
        return intersecting;
    }

    /**
     * Appends to {@code list} the intersection of every non-null handle in {@code collection} with
     * the given key-group range, skipping handles whose intersection is {@code null}.
     *
     * @param collection handles to intersect; {@code null} elements are ignored
     * @param keyGroupRange range to intersect with
     * @param list output list the intersections are added to
     */
    private static void extractIntersectingState(Collection<KeyedStateHandle> collection, KeyGroupRange keyGroupRange, List<KeyedStateHandle> list) {
        for (KeyedStateHandle handle : collection) {
            if (handle == null) {
                continue;
            }
            KeyedStateHandle intersected = handle.getIntersection(keyGroupRange);
            if (intersected != null) {
                list.add(intersected);
            }
        }
    }

    /**
     * Creates one key-group range per operator index by delegating to
     * {@link KeyGroupRangeAssignment#computeKeyGroupRangeForOperatorIndex}.
     *
     * @param i first argument forwarded to the range computation (callers in this class pass the
     *     max parallelism); must be at least {@code i2}
     * @param i2 number of ranges to create (callers in this class pass the parallelism)
     * @return list of {@code i2} ranges, ordered by operator index
     * @throws IllegalArgumentException if {@code i < i2}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int i, int i2) {
        Preconditions.checkArgument(i >= i2);

        final List<KeyGroupRange> partitions = new ArrayList<>(i2);
        int operatorIndex = 0;
        while (operatorIndex < i2) {
            partitions.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(i, i2, operatorIndex));
            operatorIndex++;
        }
        return partitions;
    }

    /**
     * Verifies that an operator state can be restored into the given job vertex.
     *
     * <p>Fails if the state's max parallelism is below the vertex' parallelism, or if the max
     * parallelisms differ while the vertex reports its max parallelism as configured. If they differ
     * and the vertex' value is not configured, the vertex' max parallelism is overridden with the
     * one from the state.
     *
     * @param operatorState state to be restored
     * @param executionJobVertex vertex the state is restored into; its max parallelism may be changed
     * @throws IllegalStateException if the state cannot be restored with the vertex' settings
     */
    private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        final int restoredMaxParallelism = operatorState.getMaxParallelism();
        final int configuredParallelism = executionJobVertex.getParallelism();

        if (restoredMaxParallelism < configuredParallelism) {
            throw new IllegalStateException("The state for task " + executionJobVertex.getJobVertexId() + " can not be restored. The maximum parallelism (" + restoredMaxParallelism + ") of the restored state is lower than the configured parallelism (" + configuredParallelism + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
        }

        final int currentMaxParallelism = executionJobVertex.getMaxParallelism();
        if (restoredMaxParallelism == currentMaxParallelism) {
            return;
        }

        if (executionJobVertex.isMaxParallelismConfigured()) {
            throw new IllegalStateException("The maximum parallelism (" + restoredMaxParallelism + ") with which the latest checkpoint of the execution job vertex " + executionJobVertex + " has been taken and the current maximum parallelism (" + currentMaxParallelism + ") changed. This is currently not supported.");
        }

        LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}", executionJobVertex.getJobVertexId(), currentMaxParallelism, restoredMaxParallelism);
        executionJobVertex.setMaxParallelism(restoredMaxParallelism);
    }

    /**
     * Verifies that every operator state can be mapped to an operator of one of the job vertices.
     *
     * <p>An operator counts as known if its id appears in a vertex' operator ids or in its
     * user-defined operator ids. The user-defined ids must be included because state assignment
     * looks state up under the user-defined id first; checking only the regular ids would reject
     * (or report as skipped) state that is in fact going to be restored.
     *
     * @param z whether unmatched state is tolerated ({@code true}: log and continue,
     *     {@code false}: fail)
     * @param map operator states keyed by operator id
     * @param map2 job vertices keyed by vertex id
     * @throws IllegalStateException if unmatched state exists and {@code z} is {@code false}
     */
    private static void checkStateMappingCompleteness(boolean z, Map<OperatorID, OperatorState> map, Map<JobVertexID, ExecutionJobVertex> map2) {
        final HashSet<OperatorID> knownOperatorIds = new HashSet<>();
        for (ExecutionJobVertex jobVertex : map2.values()) {
            knownOperatorIds.addAll(jobVertex.getOperatorIDs());
            // Entries are null for operators without a user-defined id.
            for (OperatorID userDefinedId : jobVertex.getUserDefinedOperatorIDs()) {
                if (userDefinedId != null) {
                    knownOperatorIds.add(userDefinedId);
                }
            }
        }

        for (Map.Entry<OperatorID, OperatorState> stateEntry : map.entrySet()) {
            if (knownOperatorIds.contains(stateEntry.getKey())) {
                continue;
            }
            final OperatorState unmatched = stateEntry.getValue();
            if (!z) {
                throw new IllegalStateException("There is no operator for the state " + unmatched.getOperatorID());
            }
            LOG.info("Skipped checkpoint state for operator {}.", unmatched.getOperatorID());
        }
    }

    /**
     * Repartitions the operator state of one operator and indexes the result by operator instance.
     *
     * @param operatorID operator the state belongs to
     * @param operatorStateRepartitioner repartitioner to apply
     * @param list state handles per old subtask; {@code null} yields an empty result
     * @param i old parallelism
     * @param i2 new parallelism
     * @return map from (new subtask index, operator id) to that subtask's state handles
     * @throws NullPointerException if the repartitioner produced a {@code null} entry for a subtask
     */
    public static Map<OperatorInstanceID, List<OperatorStateHandle>> applyRepartitioner(OperatorID operatorID, OperatorStateRepartitioner operatorStateRepartitioner, List<List<OperatorStateHandle>> list, int i, int i2) {
        final List<List<OperatorStateHandle>> states = applyRepartitioner(operatorStateRepartitioner, list, i, i2);
        final Map<OperatorInstanceID, List<OperatorStateHandle>> result = new HashMap<>(states.size());

        for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) {
            // Check the element itself: the former check boxed a boolean, which is never null,
            // so it could not detect a missing entry.
            List<OperatorStateHandle> subtaskStates =
                    Preconditions.checkNotNull(states.get(subtaskIndex), "states.get(subtaskIndex) is null");
            result.put(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskStates);
        }
        return result;
    }

    /**
     * Applies the repartitioner to the given per-subtask state, tolerating missing state.
     *
     * @param operatorStateRepartitioner repartitioner to apply
     * @param list state handles per old subtask, or {@code null}
     * @param i second argument forwarded to {@code repartitionState} (callers in this class pass the
     *     old parallelism)
     * @param i2 third argument forwarded to {@code repartitionState} (callers in this class pass the
     *     new parallelism)
     * @return an empty list if {@code list} is {@code null}, otherwise the repartitioner's result
     */
    public static List<List<OperatorStateHandle>> applyRepartitioner(OperatorStateRepartitioner operatorStateRepartitioner, List<List<OperatorStateHandle>> list, int i, int i2) {
        if (list == null) {
            return Collections.emptyList();
        }
        return operatorStateRepartitioner.repartitionState(list, i, i2);
    }

    /**
     * Intersects every handle of the collection with the given key-group range and returns the
     * non-null intersections.
     *
     * @param collection keyed state handles to intersect; elements are dereferenced without a null
     *     check
     * @param keyGroupRange range to intersect with
     * @return new list with the non-null intersections, in iteration order
     */
    public static List<KeyedStateHandle> getKeyedStateHandles(Collection<? extends KeyedStateHandle> collection, KeyGroupRange keyGroupRange) {
        final List<KeyedStateHandle> intersecting = new ArrayList<>(collection.size());
        for (KeyedStateHandle handle : collection) {
            KeyedStateHandle intersected = handle.getIntersection(keyGroupRange);
            if (intersected != null) {
                intersecting.add(intersected);
            }
        }
        return intersecting;
    }
}
