package org.apache.flink.runtime.checkpoint;

import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.util.SerializedValue;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/TaskState.class */
/**
 * Holds the checkpointed state of all parallel subtasks of a single job vertex.
 *
 * <p>Two kinds of state are kept, each keyed by an integer index:
 * <ul>
 *   <li>per-subtask state ({@link SubtaskState}), indexed by subtask index; indices are
 *       validated against {@code [0, parallelism)};</li>
 *   <li>key-value state ({@link KeyGroupState}), indexed by an integer whose range is NOT
 *       validated here (NOTE(review): presumably a key-group index — confirm with callers).</li>
 * </ul>
 *
 * <p>This class performs no synchronization; both maps are plain {@link HashMap}s.
 */
public class TaskState implements Serializable {
    private static final long serialVersionUID = -4845578005863201810L;

    /** The job vertex whose subtasks this state belongs to. */
    private final JobVertexID jobVertexID;

    /** Per-subtask state, keyed by subtask index in {@code [0, parallelism)}. */
    private final Map<Integer, SubtaskState> subtaskStates;

    /** Key-value state, keyed by an (unvalidated) integer index. */
    private final Map<Integer, KeyGroupState> kvStates = new HashMap<>();

    /** Number of parallel subtasks; upper (exclusive) bound for subtask indices. */
    private final int parallelism;

    /**
     * Creates an empty task state.
     *
     * @param jobVertexID the job vertex this state belongs to
     * @param parallelism number of parallel subtasks; valid subtask indices are
     *     {@code [0, parallelism)}
     * @throws IllegalArgumentException if {@code parallelism} is negative (raised by the
     *     {@link HashMap} capacity constructor)
     */
    public TaskState(JobVertexID jobVertexID, int parallelism) {
        this.jobVertexID = jobVertexID;
        this.subtaskStates = new HashMap<>(parallelism);
        this.parallelism = parallelism;
    }

    /** Returns the job vertex this state belongs to. */
    public JobVertexID getJobVertexID() {
        return this.jobVertexID;
    }

    /**
     * Stores (or replaces) the state of one subtask.
     *
     * @param subtaskIndex index of the subtask, must be in {@code [0, parallelism)}
     * @param subtaskState the state to store (may be {@code null}; it is stored as-is)
     * @throws IndexOutOfBoundsException if {@code subtaskIndex} is out of range
     */
    public void putState(int subtaskIndex, SubtaskState subtaskState) {
        checkSubtaskIndex(subtaskIndex);
        this.subtaskStates.put(subtaskIndex, subtaskState);
    }

    /**
     * Returns the state of one subtask.
     *
     * @param subtaskIndex index of the subtask, must be in {@code [0, parallelism)}
     * @return the stored state, or {@code null} if none has been stored for that index
     * @throws IndexOutOfBoundsException if {@code subtaskIndex} is out of range
     */
    public SubtaskState getState(int subtaskIndex) {
        checkSubtaskIndex(subtaskIndex);
        return this.subtaskStates.get(subtaskIndex);
    }

    /**
     * Returns all collected subtask states.
     *
     * <p>The returned collection is a live view backed by the internal map, not a copy.
     */
    public Collection<SubtaskState> getStates() {
        return this.subtaskStates.values();
    }

    /**
     * Returns the summed size of all subtask states and all key-value states, as reported
     * by their respective {@code getStateSize()} methods.
     */
    public long getStateSize() {
        long totalSize = 0;
        for (SubtaskState subtaskState : this.subtaskStates.values()) {
            totalSize += subtaskState.getStateSize();
        }
        for (KeyGroupState keyGroupState : this.kvStates.values()) {
            totalSize += keyGroupState.getStateSize();
        }
        return totalSize;
    }

    /** Returns how many subtask states have been stored so far. */
    public int getNumberCollectedStates() {
        return this.subtaskStates.size();
    }

    /** Returns the number of parallel subtasks this state was created for. */
    public int getParallelism() {
        return this.parallelism;
    }

    /**
     * Stores (or replaces) a key-value state under the given index. The index is not
     * range-checked.
     */
    public void putKvState(int index, KeyGroupState keyGroupState) {
        this.kvStates.put(index, keyGroupState);
    }

    /** Returns the key-value state stored under {@code index}, or {@code null} if absent. */
    public KeyGroupState getKvState(int index) {
        return this.kvStates.get(index);
    }

    /**
     * Extracts the serialized payloads of the key-value states for the requested indices.
     *
     * @param indices the indices to look up; indices without a stored state are skipped
     * @return a new mutable map from index to {@link KeyGroupState#getKeyGroupState()}
     */
    public Map<Integer, SerializedValue<StateHandle<?>>> getUnwrappedKvStates(Set<Integer> indices) {
        Map<Integer, SerializedValue<StateHandle<?>>> unwrapped = new HashMap<>(indices.size());
        for (Integer index : indices) {
            // single lookup instead of get()-then-get()
            KeyGroupState keyGroupState = this.kvStates.get(index);
            if (keyGroupState != null) {
                unwrapped.put(index, keyGroupState.getKeyGroupState());
            }
        }
        return unwrapped;
    }

    /** Returns how many key-value states have been stored so far. */
    public int getNumberCollectedKvStates() {
        return this.kvStates.size();
    }

    /**
     * Discards every subtask state and then every key-value state.
     *
     * <p>If discarding one state throws, the exception propagates immediately and the
     * remaining states are not discarded.
     *
     * @param classLoader passed through to each state's {@code discard} method
     * @throws Exception whatever an individual state's {@code discard} throws
     */
    public void discard(ClassLoader classLoader) throws Exception {
        for (SubtaskState subtaskState : this.subtaskStates.values()) {
            subtaskState.discard(classLoader);
        }
        for (KeyGroupState keyGroupState : this.kvStates.values()) {
            keyGroupState.discard(classLoader);
        }
    }

    /**
     * Verifies that {@code subtaskIndex} lies in {@code [0, parallelism)}.
     *
     * @throws IndexOutOfBoundsException if it does not; the message reports the actual bound
     *     ({@code parallelism}) rather than the number of states collected so far
     */
    private void checkSubtaskIndex(int subtaskIndex) {
        if (subtaskIndex < 0 || subtaskIndex >= this.parallelism) {
            throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex
                    + " exceeds the maximum number of sub tasks " + this.parallelism);
        }
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof TaskState)) {
            return false;
        }
        TaskState other = (TaskState) obj;
        // jobVertexID is not null-checked in the constructor, so compare null-safely
        return Objects.equals(this.jobVertexID, other.jobVertexID)
                && this.parallelism == other.parallelism
                && this.subtaskStates.equals(other.subtaskStates)
                && this.kvStates.equals(other.kvStates);
    }

    @Override
    public int hashCode() {
        return this.parallelism + (31 * Objects.hash(this.jobVertexID, this.subtaskStates, this.kvStates));
    }

    @Override
    public String toString() {
        return "TaskState(jobVertexID: " + this.jobVertexID
                + ", parallelism: " + this.parallelism
                + ", collected subtask states: " + this.subtaskStates.size()
                + ", collected kv states: " + this.kvStates.size() + ")";
    }
}
