package org.apache.flink.iteration.operator.coordinator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IllegalFormatException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.function.Supplier;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.ThrowingRunnable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Aligns the progress reported by the head operators of one iteration (identified by an {@link
 * IterationID}) and notifies the registered {@link SharedProgressAlignerListener}s once an epoch or
 * a checkpoint has been reported by all of them.
 *
 * <p>One instance is shared per iteration through {@link #getOrCreate}. Every state mutation is
 * submitted as a task to the {@link Executor} obtained at creation time (see {@code
 * runInEventLoop}); the plain {@link HashMap}s below are therefore only safe if that executor runs
 * tasks one at a time. NOTE(review): confirm callers always supply a single-threaded executor.
 */
public class SharedProgressAligner {
    private static final Logger LOG = LoggerFactory.getLogger(SharedProgressAligner.class);

    /**
     * Registry of the aligner shared by all coordinators of the same iteration.
     *
     * <p>NOTE(review): kept public and non-final for compatibility; it could become {@code private
     * static final} if no code outside this class reads or writes it directly.
     */
    public static ConcurrentHashMap<IterationID, SharedProgressAligner> instances =
            new ConcurrentHashMap<>();

    private final IterationID iterationId;

    /** Number of (operator, subtask) reports required for an epoch or checkpoint to be aligned. */
    private final int totalHeadParallelism;

    private final OperatorCoordinator.Context context;

    /** Executor on which all state-mutating tasks run. */
    private final Executor executor;

    /**
     * Set once an aligned epoch is terminated or termination is reported explicitly; suppresses
     * further checkpoint-aligned broadcasts.
     */
    private boolean globallyTerminating;

    private final Map<Integer, EpochStatus> statusByEpoch = new HashMap<>();
    private final Map<OperatorID, SharedProgressAlignerListener> listeners = new HashMap<>();
    private final Map<Long, CheckpointStatus> checkpointStatuses = new HashMap<>();

    /** Accumulates the checkpoint requests received for a single checkpoint id. */
    private static class CheckpointStatus {
        private final long totalHeadParallelism;
        private final List<CompletableFuture<byte[]>> stateFutures = new ArrayList<>();
        private int notifiedCoordinatorParallelism;

        private CheckpointStatus(long totalHeadParallelism) {
            this.totalHeadParallelism = totalHeadParallelism;
        }

        /**
         * Records one coordinator's request for this checkpoint.
         *
         * @param parallelism parallelism contributed by the requesting coordinator (presumably the
         *     parallelism of its head operator)
         * @param stateFuture future completed once the checkpoint is aligned
         * @return {@code true} when the accumulated parallelism equals the total head parallelism
         */
        public boolean notify(int parallelism, CompletableFuture<byte[]> stateFuture) {
            stateFutures.add(stateFuture);
            notifiedCoordinatorParallelism += parallelism;
            return notifiedCoordinatorParallelism == totalHeadParallelism;
        }

        /** Returns the futures of every coordinator that requested this checkpoint so far. */
        public List<CompletableFuture<byte[]>> getStateFutures() {
            return stateFutures;
        }
    }

    /** Collects the {@link SubtaskAlignedEvent}s reported for a single epoch. */
    public static class EpochStatus {
        private final int epoch;
        private final long totalHeadParallelism;
        private final Map<OperatorInstanceID, SubtaskAlignedEvent> reportedSubtasks =
                new HashMap<>();

        public EpochStatus(int epoch, long totalHeadParallelism) {
            this.epoch = epoch;
            this.totalHeadParallelism = totalHeadParallelism;
        }

        /**
         * Records (or overwrites) the report of one subtask.
         *
         * @return {@code true} when every expected subtask has reported for this epoch
         * @throws IllegalStateException if more subtasks reported than the total head parallelism
         */
        public boolean report(
                OperatorID operatorId, int subtaskIndex, SubtaskAlignedEvent subtaskAlignedEvent) {
            reportedSubtasks.put(new OperatorInstanceID(subtaskIndex, operatorId), subtaskAlignedEvent);
            // Template form: the (possibly large) map is only rendered when the check fails.
            Preconditions.checkState(
                    reportedSubtasks.size() <= totalHeadParallelism,
                    "Received more subtasks %s than the expected total parallelism %s",
                    reportedSubtasks,
                    totalHeadParallelism);
            return reportedSubtasks.size() == totalHeadParallelism;
        }

        /** Drops the reports of every subtask of the given operator. */
        public void remove(OperatorID operatorId) {
            reportedSubtasks
                    .entrySet()
                    .removeIf(entry -> entry.getKey().getOperatorId().equals(operatorId));
        }

        /** Drops the report of a single subtask of the given operator. */
        public void remove(OperatorID operatorId, int subtaskIndex) {
            reportedSubtasks.remove(new OperatorInstanceID(subtaskIndex, operatorId));
        }

        /**
         * Decides whether the iteration terminates after this (fully aligned) epoch.
         *
         * @return {@code false} for epoch 0; otherwise {@code true} if no records were processed
         *     this round, or if at least one criteria stream reported and all criteria streams
         *     together processed no records
         * @throws IllegalStateException if not every subtask has reported yet
         */
        public boolean isTerminated() {
            Preconditions.checkState(
                    reportedSubtasks.size() == totalHeadParallelism,
                    "The round is not globally aligned yet");
            if (epoch == 0) {
                return false;
            }
            long totalRecords = 0;
            long criteriaRecords = 0;
            boolean hasCriteriaStream = false;
            for (SubtaskAlignedEvent event : reportedSubtasks.values()) {
                totalRecords += event.getNumRecordsThisRound();
                if (event.isCriteriaStream()) {
                    hasCriteriaStream = true;
                    criteriaRecords += event.getNumRecordsThisRound();
                }
            }
            return totalRecords == 0 || (hasCriteriaStream && criteriaRecords == 0);
        }
    }

    /**
     * Returns the aligner registered for {@code iterationId}, creating it if absent.
     *
     * <p>When an aligner already exists, the remaining arguments are ignored and the executor
     * supplier is not invoked.
     */
    public static SharedProgressAligner getOrCreate(
            IterationID iterationId,
            int totalHeadParallelism,
            OperatorCoordinator.Context context,
            Supplier<Executor> executorFactory) {
        return instances.computeIfAbsent(
                iterationId,
                ignored ->
                        new SharedProgressAligner(
                                iterationId, totalHeadParallelism, context, executorFactory.get()));
    }

    @VisibleForTesting
    static ConcurrentHashMap<IterationID, SharedProgressAligner> getInstances() {
        return instances;
    }

    private SharedProgressAligner(
            IterationID iterationId,
            int totalHeadParallelism,
            OperatorCoordinator.Context context,
            Executor executor) {
        this.iterationId = Objects.requireNonNull(iterationId, "iterationId");
        this.totalHeadParallelism = totalHeadParallelism;
        this.context = Objects.requireNonNull(context, "context");
        this.executor = Objects.requireNonNull(executor, "executor");
    }

    /** Asynchronously registers (or replaces) the listener for the given operator. */
    public void registerAlignedListener(
            OperatorID operatorId, SharedProgressAlignerListener listener) {
        runInEventLoop(
                () -> listeners.put(operatorId, listener),
                "Register listeners %s",
                operatorId.toHexString());
    }

    /**
     * Asynchronously removes the listener of the given operator; once no listener is left, this
     * aligner is dropped from the shared registry.
     */
    public void unregisterListener(OperatorID operatorId) {
        runInEventLoop(
                () -> {
                    listeners.remove(operatorId);
                    if (listeners.isEmpty()) {
                        // Remove only if the registry still maps to *this* instance, so a newer
                        // aligner for the same iteration is never evicted by mistake.
                        instances.remove(iterationId, this);
                    }
                },
                "Unregister listeners %s",
                operatorId.toHexString());
    }

    /**
     * Asynchronously records a subtask's progress for an epoch. When the epoch becomes aligned, a
     * {@link GloballyAlignedEvent} is broadcast to every listener, and a terminated epoch marks the
     * iteration as globally terminating.
     */
    public void reportSubtaskProgress(
            OperatorID operatorId, int subtaskIndex, SubtaskAlignedEvent subtaskAlignedEvent) {
        runInEventLoop(
                () -> {
                    LOG.debug(
                            "Processing {} from {}-{}",
                            subtaskAlignedEvent,
                            operatorId,
                            subtaskIndex);
                    int epoch = subtaskAlignedEvent.getEpoch();
                    EpochStatus epochStatus =
                            statusByEpoch.computeIfAbsent(
                                    epoch, e -> new EpochStatus(e, totalHeadParallelism));
                    if (epochStatus.report(operatorId, subtaskIndex, subtaskAlignedEvent)) {
                        // Evaluate once: listeners run in between and must not change the verdict.
                        boolean terminated = epochStatus.isTerminated();
                        GloballyAlignedEvent alignedEvent =
                                new GloballyAlignedEvent(epoch, terminated);
                        for (SharedProgressAlignerListener listener : listeners.values()) {
                            listener.onAligned(alignedEvent);
                        }
                        if (terminated) {
                            globallyTerminating = true;
                        }
                    }
                },
                "Report subtask %s-%d",
                operatorId.toHexString(),
                subtaskIndex);
    }

    /**
     * Asynchronously records a coordinator's request for {@code checkpointId}. Once the requests
     * cover the total head parallelism, listeners get a {@link CoordinatorCheckpointEvent} (unless
     * the iteration is globally terminating), and every pending future is completed with an empty
     * state.
     *
     * <p>NOTE(review): statuses of checkpoints that never align are never removed.
     */
    public void requestCheckpoint(
            long checkpointId, int parallelism, CompletableFuture<byte[]> stateFuture) {
        runInEventLoop(
                () -> {
                    CheckpointStatus status =
                            checkpointStatuses.computeIfAbsent(
                                    checkpointId,
                                    ignored -> new CheckpointStatus(totalHeadParallelism));
                    if (status.notify(parallelism, stateFuture)) {
                        if (!globallyTerminating) {
                            CoordinatorCheckpointEvent checkpointEvent =
                                    new CoordinatorCheckpointEvent(checkpointId);
                            for (SharedProgressAlignerListener listener : listeners.values()) {
                                listener.onCheckpointAligned(checkpointEvent);
                            }
                        }
                        for (CompletableFuture<byte[]> future : status.getStateFutures()) {
                            future.complete(new byte[0]);
                        }
                        checkpointStatuses.remove(checkpointId);
                    }
                },
                "Coordinator report checkpoint %d",
                checkpointId);
    }

    /** Asynchronously marks the iteration as globally terminating. */
    public void notifyGloballyTerminating() {
        runInEventLoop(() -> globallyTerminating = true, "Report globally terminating");
    }

    /** Asynchronously drops every recorded report of the given operator, for all epochs. */
    public void removeProgressInfo(OperatorID operatorId) {
        runInEventLoop(
                () -> statusByEpoch.values().forEach(status -> status.remove(operatorId)),
                "remove the progress information for %s",
                operatorId);
    }

    /** Asynchronously drops the recorded reports of one subtask, for all epochs. */
    public void removeProgressInfo(OperatorID operatorId, int subtaskIndex) {
        runInEventLoop(
                () ->
                        statusByEpoch
                                .values()
                                .forEach(status -> status.remove(operatorId, subtaskIndex)),
                "remove the progress information for %s-%d",
                operatorId,
                subtaskIndex);
    }

    /**
     * Submits {@code action} to the executor. Any non-fatal throwable is logged together with the
     * formatted action description, and then fails the job through the coordinator context.
     *
     * @param action the state mutation to run
     * @param actionName {@link String#format} template describing the action, used only for logging
     * @param actionArgs arguments for {@code actionName}
     */
    private void runInEventLoop(
            ThrowingRunnable<Throwable> action, String actionName, Object... actionArgs) {
        executor.execute(
                () -> {
                    try {
                        action.run();
                    } catch (Throwable t) {
                        ExceptionUtils.rethrowIfFatalErrorOrOOM(t);
                        LOG.error(
                                "Uncaught exception in the SharedProgressAligner for iteration {} while {}. Triggering job failover.",
                                iterationId,
                                describe(actionName, actionArgs),
                                t);
                        context.failJob(t);
                    }
                });
    }

    /**
     * Formats an action description without ever throwing, so a bad template cannot prevent the
     * failover in {@code runInEventLoop}; falls back to the raw template.
     */
    private static String describe(String template, Object... args) {
        try {
            return String.format(template, args);
        } catch (IllegalFormatException e) {
            return template;
        }
    }

    @VisibleForTesting
    int getNumberListeners() {
        return listeners.size();
    }
}
