package org.apache.flink.runtime.iterative.task;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.runtime.event.task.TaskEvent;
import org.apache.flink.runtime.io.network.api.reader.MutableRecordReader;
import org.apache.flink.runtime.iterative.event.AllWorkersDoneEvent;
import org.apache.flink.runtime.iterative.event.TerminationEvent;
import org.apache.flink.runtime.iterative.event.WorkerDoneEvent;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.operators.RegularPactTask;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.shaded.com.google.common.base.Preconditions;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Synchronization sink of an iterative job.
 *
 * <p>Reads only events (never records) from input gate 0. For every superstep it:
 * <ol>
 *   <li>blocks in {@link #readHeadEventChannel(IntValue)} until the {@link SyncEventHandler}
 *       signals the end of the superstep;</li>
 *   <li>checks for termination, either because the configured maximum number of iterations is
 *       reached or because the optional {@link ConvergenceCriterion} accepts the value of its
 *       named aggregator;</li>
 *   <li>sends either a {@link TerminationEvent} or an {@link AllWorkersDoneEvent} carrying the
 *       aggregators back over the same reader, and resets the aggregators for the next
 *       superstep.</li>
 * </ol>
 */
public class IterationSynchronizationSinkTask extends AbstractInvokable implements Terminable {

    private static final Logger log = LoggerFactory.getLogger(IterationSynchronizationSinkTask.class);

    /** Reader on input gate 0; used both to receive worker events and to send events back. */
    private MutableRecordReader<IntValue> headEventReader;

    /**
     * Listener for {@link WorkerDoneEvent}s; it is given the aggregator map
     * (presumably to merge the workers' partial aggregates into it — confirm in SyncEventHandler).
     */
    private SyncEventHandler eventHandler;

    /** Optional convergence criterion; only set when the task config uses one. */
    private ConvergenceCriterion<Value> convergenceCriterion;

    /** Iteration aggregators keyed by their registered name. */
    private Map<String, Aggregator<?>> aggregators;

    /** Name of the aggregator fed to {@link #convergenceCriterion}; {@code null} if none. */
    private String convergenceAggregatorName;

    /** Iteration number at which the task terminates regardless of convergence. */
    private int maxNumberOfIterations;

    /** 1-based number of the superstep currently being processed. */
    private int currentIteration = 1;

    /** Termination flag, held in an AtomicBoolean (presumably so it can be set from another thread). */
    private final AtomicBoolean terminated = new AtomicBoolean(false);

    @Override
    public void registerInputOutput() {
        this.headEventReader = new MutableRecordReader<>(getEnvironment().getInputGate(0));
    }

    /**
     * Runs the superstep synchronization loop until termination is requested.
     *
     * @throws Exception if reading or sending events fails, or if the configuration is invalid
     */
    @Override
    public void invoke() throws Exception {
        TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());

        // instantiate the iteration aggregators, keyed by name
        this.aggregators = new HashMap<>();
        for (AggregatorWithName<?> aggregatorWithName : taskConfig.getIterationAggregators(getUserCodeClassLoader())) {
            this.aggregators.put(aggregatorWithName.getName(), aggregatorWithName.getAggregator());
        }

        // a convergence criterion always needs the name of the aggregator it inspects
        if (taskConfig.usesConvergenceCriterion()) {
            this.convergenceCriterion = taskConfig.getConvergenceCriterion(getUserCodeClassLoader());
            this.convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName();
            Preconditions.checkNotNull(this.convergenceAggregatorName);
        }

        this.maxNumberOfIterations = taskConfig.getNumberOfIterations();

        this.eventHandler = new SyncEventHandler(taskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0), this.aggregators, getEnvironment().getUserClassLoader());
        this.headEventReader.registerTaskEventListener(this.eventHandler, WorkerDoneEvent.class);

        // reusable record holder; the reader must never actually produce a record
        IntValue intValue = new IntValue();

        while (!terminationRequested()) {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("starting iteration [" + this.currentIteration + "]"));
            }

            readHeadEventChannel(intValue);

            if (log.isInfoEnabled()) {
                log.info(formatLogString("finishing iteration [" + this.currentIteration + "]"));
            }

            if (checkForConvergence()) {
                if (log.isInfoEnabled()) {
                    log.info(formatLogString("signaling that all workers are to terminate in iteration [" + this.currentIteration + "]"));
                }
                requestTermination();
                sendToAllWorkers(new TerminationEvent());
            } else {
                if (log.isInfoEnabled()) {
                    log.info(formatLogString("signaling that all workers are done in iteration [" + this.currentIteration + "]"));
                }
                sendToAllWorkers(new AllWorkersDoneEvent(this.aggregators));

                // the aggregates were sent out; start the next superstep from a clean state
                for (Aggregator<?> aggregator : this.aggregators.values()) {
                    aggregator.reset();
                }
                this.currentIteration++;
            }
        }
    }

    /**
     * Decides whether the iteration should stop after the current superstep.
     *
     * @return {@code true} if the maximum number of iterations is reached, or if a convergence
     *     criterion is configured and accepts the current value of its aggregator
     * @throws RuntimeException if the configured convergence aggregator is not registered
     */
    private boolean checkForConvergence() {
        // NOTE(review): equality check, so a non-positive maxNumberOfIterations never triggers this branch
        if (this.maxNumberOfIterations == this.currentIteration) {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("maximum number of iterations [" + this.currentIteration + "] reached, terminating..."));
            }
            return true;
        }

        // no convergence criterion configured
        if (this.convergenceAggregatorName == null) {
            return false;
        }

        Aggregator<?> aggregator = this.aggregators.get(this.convergenceAggregatorName);
        if (aggregator == null) {
            throw new RuntimeException("Error: Aggregator for convergence criterion was null.");
        }

        if (this.convergenceCriterion.isConverged(this.currentIteration, aggregator.getAggregate())) {
            if (log.isInfoEnabled()) {
                log.info(formatLogString("convergence reached after [" + this.currentIteration + "] iterations, terminating..."));
            }
            return true;
        }
        return false;
    }

    /**
     * Blocks on the head event reader until the current superstep ends, letting the registered
     * event handler process the incoming worker events.
     *
     * @param intValue reusable holder passed to the reader; it must never be filled
     * @throws IOException if reading from the input gate fails
     * @throws RuntimeException if a record arrives, or if the read is interrupted before the
     *     handler reports the end of the superstep
     */
    private void readHeadEventChannel(IntValue intValue) throws IOException {
        this.eventHandler.resetEndOfSuperstep();
        try {
            if (this.headEventReader.next(intValue)) {
                throw new RuntimeException("Synchronization task must not see any records!");
            }
        } catch (InterruptedException e) {
            // NOTE(review): the interrupt appears to be the expected end-of-superstep signal, so the
            // interrupt flag is deliberately not restored; confirm before changing this.
            if (!this.eventHandler.isEndOfSuperstep()) {
                throw new RuntimeException("Event handler interrupted without reaching end-of-superstep.");
            }
        }
    }

    /** Sends {@code taskEvent} back over the head event reader. */
    private void sendToAllWorkers(TaskEvent taskEvent) throws IOException, InterruptedException {
        this.headEventReader.sendTaskEvent(taskEvent);
    }

    /** Prefixes {@code str} with task name/instance information for logging. */
    private String formatLogString(String str) {
        return RegularPactTask.constructLogString(str, getEnvironment().getTaskName(), this);
    }

    @Override
    public boolean terminationRequested() {
        return this.terminated.get();
    }

    @Override
    public void requestTermination() {
        this.terminated.set(true);
    }
}
