package org.apache.flink.ml.common.broadcast;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;

/* loaded from: input_file:org/apache/flink/ml/common/broadcast/TestSource.class */
public class TestSource extends RichParallelSourceFunction<Integer> implements CheckpointedFunction {
    private static volatile boolean hasThrown = false;
    private ListState<Integer> currentIdxState;
    private Integer currentIdx;
    private Integer mod;
    private Integer numPartitions;
    private Integer numElementsPerPartition;
    private volatile transient boolean running = true;

    public TestSource(int i) {
        this.numElementsPerPartition = Integer.valueOf(i);
    }

    public void open(Configuration configuration) {
        this.mod = Integer.valueOf(getRuntimeContext().getIndexOfThisSubtask());
        this.numPartitions = Integer.valueOf(getRuntimeContext().getNumberOfParallelSubtasks());
        this.running = true;
    }

    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
        this.currentIdxState.clear();
        this.currentIdxState.add(this.currentIdx);
    }

    public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
        this.currentIdxState = functionInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("currentIdx", BasicTypeInfo.INT_TYPE_INFO));
        Iterator it = ((Iterable) this.currentIdxState.get()).iterator();
        this.currentIdx = 0;
        if (it.hasNext()) {
            this.currentIdx = (Integer) it.next();
        }
    }

    public void run(SourceFunction.SourceContext<Integer> sourceContext) throws Exception {
        while (this.running && this.currentIdx.intValue() < this.numElementsPerPartition.intValue()) {
            synchronized (sourceContext.getCheckpointLock()) {
                sourceContext.collect(Integer.valueOf((this.currentIdx.intValue() * this.numPartitions.intValue()) + this.mod.intValue()));
                Integer num = this.currentIdx;
                this.currentIdx = Integer.valueOf(this.currentIdx.intValue() + 1);
            }
            Thread.sleep(1L);
            if (this.currentIdx.intValue() == this.numElementsPerPartition.intValue() / 2 && !hasThrown) {
                hasThrown = true;
                throw new RuntimeException("Failing source");
            }
        }
    }

    public void cancel() {
        this.running = false;
    }
}
