package org.apache.flink.ml.common.broadcast.operator;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.datacache.nonkeyed.DataCacheReader;
import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.ml.common.broadcast.BroadcastContext;
import org.apache.flink.ml.common.broadcast.BroadcastStreamingRuntimeContext;
import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement;
import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementSerializer;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.metrics.groups.InternalOperatorMetricGroup;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
import org.apache.flink.runtime.util.NonClosingOutputStreamDecorator;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.ThrowingConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.class */
/**
 * Base class for operators that wrap another {@link StreamOperator} and expose broadcast variables
 * to the wrapped operator's user function.
 *
 * <p>The broadcast machinery is only activated when the wrapped operator is an {@link
 * AbstractUdfStreamOperator} whose user function is a {@link RichFunction}. In that case:
 *
 * <ul>
 *   <li>The user function's runtime context is replaced by a {@link BroadcastStreamingRuntimeContext}.
 *   <li>Until every named broadcast stream reports a finished cache in {@link BroadcastContext},
 *       elements and watermarks from non-blocked inputs are buffered in per-input {@link
 *       DataCacheWriter}s. They are replayed, in arrival order, once the variables are ready.
 *   <li>The buffered data is written to this operator's raw operator state on checkpoint and
 *       restored from it on recovery.
 * </ul>
 *
 * <p>Otherwise every element and watermark is forwarded directly to the wrapped operator.
 *
 * @param <T> output type of the wrapped operator
 * @param <S> type of the wrapped operator
 */
public abstract class AbstractBroadcastWrapperOperator<T, S extends StreamOperator<T>> implements StreamOperator<T>, StreamOperatorStateHandler.CheckpointedStreamOperator {
    private static final Logger LOG = LoggerFactory.getLogger(AbstractBroadcastWrapperOperator.class);

    protected final StreamOperatorParameters<T> parameters;
    protected final StreamConfig streamConfig;
    protected final StreamTask<?, ?> containingTask;
    protected final Output<StreamRecord<T>> output;
    protected final StreamOperatorFactory<T> operatorFactory;
    protected final OperatorMetricGroup metrics;
    protected final S wrappedOperator;
    protected transient StreamOperatorStateHandler stateHandler;
    protected transient InternalTimeServiceManager<?> timeServiceManager;

    // The fields below are only initialized when hasRichFunction is true.
    private MailboxExecutor mailboxExecutor;
    private String[] broadcastStreamNames;
    // NOTE(review): never set to true in this class and the field is private, so the "blocked"
    // branches below look unreachable as written — confirm whether blocking is still intended.
    private boolean[] isBlocked;
    private TypeSerializer<?>[] inTypeSerializers;
    private boolean broadcastVariablesReady;
    protected transient int indexOfSubtask;
    protected int numInputs;
    private BroadcastStreamingRuntimeContext wrappedOperatorRuntimeContext;
    // Directory under which per-input data cache files are created.
    private Path basePath;
    // One cache per network input; holds records/watermarks that arrived before the broadcast
    // variables were ready.
    private DataCacheWriter[] dataCacheWriters;
    // Whether the cache of each input may still contain data that has not been replayed.
    private boolean[] hasPendingElements;
    private final boolean hasRichFunction;

    /**
     * Creates the wrapped operator from {@code streamOperatorFactory}. If its user function is a
     * {@link RichFunction}, also prepares the broadcast state.
     *
     * <p>NOTE(review): the decompiler reports that this constructor was package-private in the
     * original source; it is kept public here to stay compatible with existing callers.
     *
     * @param streamOperatorParameters parameters of this (wrapper) operator
     * @param streamOperatorFactory factory used to create the wrapped operator
     * @param broadcastStreamNames names of the broadcast streams this operator waits for
     */
    public AbstractBroadcastWrapperOperator(StreamOperatorParameters<T> streamOperatorParameters, StreamOperatorFactory<T> streamOperatorFactory, String[] broadcastStreamNames) {
        this.parameters = Objects.requireNonNull(streamOperatorParameters);
        this.streamConfig = Objects.requireNonNull(streamOperatorParameters.getStreamConfig());
        this.containingTask = Objects.requireNonNull(streamOperatorParameters.getContainingTask());
        this.output = Objects.requireNonNull(streamOperatorParameters.getOutput());
        this.operatorFactory = Objects.requireNonNull(streamOperatorFactory);
        this.metrics = createOperatorMetricGroup(this.containingTask.getEnvironment(), this.streamConfig);
        this.wrappedOperator = (S) StreamOperatorFactoryUtil.createOperator(streamOperatorFactory, this.containingTask, this.streamConfig, this.output, streamOperatorParameters.getOperatorEventDispatcher()).f0;
        this.hasRichFunction = (this.wrappedOperator instanceof AbstractUdfStreamOperator) && (((AbstractUdfStreamOperator<?, ?>) this.wrappedOperator).getUserFunction() instanceof RichFunction);
        if (!this.hasRichFunction) {
            return;
        }

        AbstractUdfStreamOperator<?, ?> udfOperator = (AbstractUdfStreamOperator<?, ?>) this.wrappedOperator;
        this.wrappedOperatorRuntimeContext = new BroadcastStreamingRuntimeContext(this.containingTask.getEnvironment(), this.containingTask.getEnvironment().getAccumulatorRegistry().getUserMap(), this.wrappedOperator.getMetricGroup(), this.wrappedOperator.getOperatorID(), udfOperator.getProcessingTimeService(), null, this.containingTask.getEnvironment().getExternalResourceInfoProvider());
        ((RichFunction) udfOperator.getUserFunction()).setRuntimeContext(this.wrappedOperatorRuntimeContext);

        this.mailboxExecutor = this.containingTask.getMailboxExecutorFactory().createExecutor(-1);
        // Must be assigned before contextKey(...) is used below.
        this.indexOfSubtask = this.containingTask.getIndexInSubtaskGroup();
        for (String name : broadcastStreamNames) {
            BroadcastContext.putMailBoxExecutor(contextKey(name), this.mailboxExecutor);
        }
        this.broadcastStreamNames = broadcastStreamNames;

        // Only the leading network inputs are counted as inputs of this operator.
        StreamConfig.InputConfig[] inputs = this.streamConfig.getInputs(this.containingTask.getUserCodeClassLoader());
        int networkInputs = 0;
        while (networkInputs < inputs.length && inputs[networkInputs] instanceof StreamConfig.NetworkInputConfig) {
            networkInputs++;
        }
        this.numInputs = networkInputs;

        this.isBlocked = new boolean[this.numInputs];
        this.inTypeSerializers = new TypeSerializer[this.numInputs];
        for (int i = 0; i < this.numInputs; i++) {
            this.inTypeSerializers[i] = this.streamConfig.getTypeSerializerIn(i, this.containingTask.getUserCodeClassLoader());
        }
        this.broadcastVariablesReady = false;
        this.basePath = OperatorUtils.getDataCachePath(this.containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(), this.containingTask.getEnvironment().getIOManager().getSpillingDirectoriesPaths());
        // The writers themselves are created in initializeState(StateInitializationContext).
        this.dataCacheWriters = new DataCacheWriter[this.numInputs];
        this.hasPendingElements = new boolean[this.numInputs];
        Arrays.fill(this.hasPendingElements, true);
    }

    /** Builds the {@link BroadcastContext} key for a broadcast stream on this subtask. */
    private String contextKey(String broadcastStreamName) {
        return broadcastStreamName + "-" + this.indexOfSubtask;
    }

    /**
     * Checks whether every broadcast stream has finished caching. Once all have, publishes their
     * values to the wrapped operator's runtime context.
     *
     * <p>Each variable is registered under the stream name with everything up to and including the
     * first {@code '-'} stripped. After the first {@code true} result the answer is memoized.
     *
     * @return {@code true} iff all broadcast variables are available
     */
    protected boolean areBroadcastVariablesReady() {
        if (this.broadcastVariablesReady) {
            return true;
        }
        for (String name : this.broadcastStreamNames) {
            if (!BroadcastContext.isCacheFinished(contextKey(name))) {
                return false;
            }
        }
        // All caches are finished: publish every variable exactly once.
        for (String name : this.broadcastStreamNames) {
            this.wrappedOperatorRuntimeContext.setBroadcastVariable(name.substring(name.indexOf('-') + 1), BroadcastContext.getBroadcastVariable(contextKey(name)));
        }
        this.broadcastVariablesReady = true;
        return true;
    }

    /**
     * Registers (or reuses) the operator metric group for this operator.
     *
     * <p>Falls back to an unregistered group if registration fails. The failure is logged as a
     * warning.
     */
    private OperatorMetricGroup createOperatorMetricGroup(Environment environment, StreamConfig streamConfig) {
        try {
            InternalOperatorMetricGroup operatorMetricGroup = environment.getMetricGroup().getOrAddOperator(streamConfig.getOperatorID(), streamConfig.getOperatorName());
            if (streamConfig.isChainEnd()) {
                operatorMetricGroup.getIOMetricGroup().reuseOutputMetricsForTask();
            }
            return operatorMetricGroup;
        } catch (Exception e) {
            LOG.warn("An error occurred while instantiating task metrics.", e);
            return UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup();
        }
    }

    /** Yields to the mailbox until all broadcast variables are ready. */
    private void waitForBroadcastVariables() throws Exception {
        while (!areBroadcastVariablesReady()) {
            this.mailboxExecutor.yield();
        }
    }

    /** Replays the cache of input {@code inputIndex} once, if it has not been replayed yet. */
    private void flushPendingIfNeeded(int inputIndex, ThrowingConsumer<StreamRecord, Exception> elementConsumer, ThrowingConsumer<Watermark, Exception> watermarkConsumer, ThrowingConsumer<StreamRecord, Exception> keyContextSetter) throws Exception {
        if (this.hasPendingElements[inputIndex]) {
            processPendingElementsAndWatermarks(inputIndex, elementConsumer, watermarkConsumer, keyContextSetter);
            this.hasPendingElements[inputIndex] = false;
        }
    }

    /**
     * Routes an element of input {@code inputIndex}.
     *
     * <p>It is forwarded directly when there is no rich function. Otherwise it is cached while the
     * broadcast variables are not ready. Once they are ready, any cached data is replayed first and
     * then the element is forwarded.
     *
     * <p>Only the record value is cached; replayed records are created without a timestamp.
     *
     * @param streamRecord the incoming record
     * @param inputIndex index of the input the record arrived on
     * @param elementConsumer delivers a record to the wrapped operator
     * @param watermarkConsumer delivers a watermark to the wrapped operator (used during replay)
     * @param keyContextSetter invoked on each record right before {@code elementConsumer} on the
     *     non-blocked path (presumably sets the key context — confirm against subclasses)
     */
    public void processElementX(StreamRecord streamRecord, int inputIndex, ThrowingConsumer<StreamRecord, Exception> elementConsumer, ThrowingConsumer<Watermark, Exception> watermarkConsumer, ThrowingConsumer<StreamRecord, Exception> keyContextSetter) throws Exception {
        if (!this.hasRichFunction) {
            elementConsumer.accept(streamRecord);
            return;
        }
        if (this.isBlocked[inputIndex]) {
            waitForBroadcastVariables();
            elementConsumer.accept(streamRecord);
            return;
        }
        if (!areBroadcastVariablesReady()) {
            this.dataCacheWriters[inputIndex].addRecord(CacheElement.newRecord(streamRecord.getValue()));
            return;
        }
        flushPendingIfNeeded(inputIndex, elementConsumer, watermarkConsumer, keyContextSetter);
        keyContextSetter.accept(streamRecord);
        elementConsumer.accept(streamRecord);
    }

    /**
     * Routes a watermark of input {@code inputIndex}.
     *
     * <p>This follows the same rules as {@link #processElementX}: the watermark is cached while the
     * broadcast variables are not ready, otherwise pending cached data is replayed first.
     */
    public void processWatermarkX(Watermark watermark, int inputIndex, ThrowingConsumer<StreamRecord, Exception> elementConsumer, ThrowingConsumer<Watermark, Exception> watermarkConsumer, ThrowingConsumer<StreamRecord, Exception> keyContextSetter) throws Exception {
        if (!this.hasRichFunction) {
            watermarkConsumer.accept(watermark);
            return;
        }
        if (this.isBlocked[inputIndex]) {
            waitForBroadcastVariables();
            watermarkConsumer.accept(watermark);
            return;
        }
        if (!areBroadcastVariablesReady()) {
            this.dataCacheWriters[inputIndex].addRecord(CacheElement.newWatermark(watermark.getTimestamp()));
            return;
        }
        flushPendingIfNeeded(inputIndex, elementConsumer, watermarkConsumer, keyContextSetter);
        watermarkConsumer.accept(watermark);
    }

    /**
     * Handles the end of input {@code inputIndex}.
     *
     * <p>When a rich function is present, this blocks (yielding to the mailbox) until the broadcast
     * variables are ready. It then replays anything still cached for that input.
     */
    public void endInputX(int inputIndex, ThrowingConsumer<StreamRecord, Exception> elementConsumer, ThrowingConsumer<Watermark, Exception> watermarkConsumer, ThrowingConsumer<StreamRecord, Exception> keyContextSetter) throws Exception {
        if (!this.hasRichFunction) {
            return;
        }
        waitForBroadcastVariables();
        flushPendingIfNeeded(inputIndex, elementConsumer, watermarkConsumer, keyContextSetter);
    }

    /**
     * Replays every cached record and watermark of input {@code inputIndex} in cache order, then
     * clears the cache.
     */
    private void processPendingElementsAndWatermarks(int inputIndex, ThrowingConsumer<StreamRecord, Exception> elementConsumer, ThrowingConsumer<Watermark, Exception> watermarkConsumer, ThrowingConsumer<StreamRecord, Exception> keyContextSetter) throws Exception {
        List segments = this.dataCacheWriters[inputIndex].getSegments();
        if (segments.isEmpty()) {
            return;
        }
        DataCacheReader reader = new DataCacheReader(new CacheElementSerializer(this.inTypeSerializers[inputIndex]), segments);
        while (reader.hasNext()) {
            CacheElement cacheElement = (CacheElement) reader.next();
            switch (cacheElement.getType()) {
                case RECORD:
                    StreamRecord streamRecord = new StreamRecord(cacheElement.getRecord());
                    keyContextSetter.accept(streamRecord);
                    elementConsumer.accept(streamRecord);
                    break;
                case WATERMARK:
                    watermarkConsumer.accept(new Watermark(cacheElement.getWatermark()));
                    break;
                default:
                    throw new RuntimeException("Unsupported CacheElement type: " + cacheElement.getType());
            }
        }
        this.dataCacheWriters[inputIndex].clear();
    }

    @Override
    public void open() throws Exception {
        this.wrappedOperator.open();
    }

    /**
     * Closes the wrapped operator and releases this subtask's {@link BroadcastContext} entries.
     *
     * <p>The entries are released even if closing the wrapped operator throws.
     */
    @Override
    public void close() throws Exception {
        try {
            this.wrappedOperator.close();
        } finally {
            if (this.hasRichFunction) {
                for (String name : this.broadcastStreamNames) {
                    BroadcastContext.remove(contextKey(name));
                }
            }
        }
    }

    @Override
    public void finish() throws Exception {
        this.wrappedOperator.finish();
    }

    @Override
    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
        this.wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
    }

    /**
     * Initializes this operator's state handler, which triggers
     * {@link #initializeState(StateInitializationContext)}.
     *
     * <p>The wrapped operator is then initialized with a proxy state context that shares this
     * operator's state context under the {@code "wrapped-"} prefix.
     */
    @Override
    public void initializeState(StreamTaskStateInitializer streamTaskStateInitializer) throws Exception {
        StreamOperatorStateContext streamOperatorStateContext = streamTaskStateInitializer.streamOperatorStateContext(getOperatorID(), getClass().getSimpleName(), this.parameters.getProcessingTimeService(), this, this.streamConfig.getStateKeySerializer(this.containingTask.getUserCodeClassLoader()), this.containingTask.getCancelables(), this.metrics, this.streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(ManagedMemoryUseCase.STATE_BACKEND, this.containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(), this.containingTask.getUserCodeClassLoader()), false);
        this.stateHandler = new StreamOperatorStateHandler(streamOperatorStateContext, this.containingTask.getExecutionConfig(), this.containingTask.getCancelables());
        this.stateHandler.initializeOperatorState(this);
        this.timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();
        // The wrapped operator ignores the arguments and always receives the proxied context.
        this.wrappedOperator.initializeState((operatorID, operatorClassName, processingTimeService, keyContext, keySerializer, closeableRegistry, metricGroup, managedMemoryFraction, isUsingCustomRawKeyedState) -> {
            return new ProxyStreamOperatorStateContext(streamOperatorStateContext, "wrapped-", CloseableIterator.empty(), 0);
        });
    }

    @Override
    public OperatorSnapshotFutures snapshotState(long checkpointId, long timestamp, CheckpointOptions checkpointOptions, CheckpointStreamFactory checkpointStreamFactory) throws Exception {
        return this.stateHandler.snapshotState(this, Optional.ofNullable(this.timeServiceManager), this.streamConfig.getOperatorName(), checkpointId, timestamp, checkpointOptions, checkpointStreamFactory, false);
    }

    /**
     * Creates (or restores) one data cache writer per input.
     *
     * <p>The raw operator state holds at most one partition. It contains an {@code int} header
     * equal to {@link #numInputs}, followed by one {@link DataCacheSnapshot} per input in input
     * order.
     *
     * @throws IllegalStateException if more than one raw partition is present, or if the header
     *     does not match {@link #numInputs}
     */
    @Override
    public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
        @SuppressWarnings("unchecked")
        List<StatePartitionStreamProvider> rawStates = IteratorUtils.toList(stateInitializationContext.getRawOperatorStateInputs().iterator());
        Preconditions.checkState(rawStates.size() < 2, "The input from raw operator state should be one or zero.");
        if (!this.hasRichFunction) {
            return;
        }
        if (rawStates.isEmpty()) {
            // Fresh start: empty caches.
            for (int i = 0; i < this.numInputs; i++) {
                this.dataCacheWriters[i] = new DataCacheWriter(new CacheElementSerializer(this.inTypeSerializers[i]), this.basePath.getFileSystem(), OperatorUtils.createDataCacheFileGenerator(this.basePath, "cache", this.streamConfig.getOperatorID()));
            }
            return;
        }
        InputStream stream = rawStates.get(0).getStream();
        // Non-closing wrapper so reading the header does not close the shared stream.
        DataInputStream header = new DataInputStream(new NonClosingInputStreamDecorator(stream));
        Preconditions.checkState(header.readInt() == this.numInputs, "Number of input is wrong.");
        for (int i = 0; i < this.numInputs; i++) {
            DataCacheSnapshot snapshot = DataCacheSnapshot.recover(stream, this.basePath.getFileSystem(), OperatorUtils.createDataCacheFileGenerator(this.basePath, "cache", this.streamConfig.getOperatorID()));
            this.dataCacheWriters[i] = new DataCacheWriter(new CacheElementSerializer(this.inTypeSerializers[i]), this.basePath.getFileSystem(), OperatorUtils.createDataCacheFileGenerator(this.basePath, "cache", this.streamConfig.getOperatorID()), snapshot.getSegments());
        }
    }

    /**
     * Snapshots the wrapped operator (if it is checkpointed) and writes all data caches to a new
     * raw operator state partition.
     *
     * <p>The partition has the same layout that {@link
     * #initializeState(StateInitializationContext)} reads back.
     */
    @Override
    public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception {
        if (this.wrappedOperator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) {
            ((StreamOperatorStateHandler.CheckpointedStreamOperator) this.wrappedOperator).snapshotState(stateSnapshotContext);
        }
        if (!this.hasRichFunction) {
            return;
        }
        OperatorStateCheckpointOutputStream rawOperatorStateOutput = stateSnapshotContext.getRawOperatorStateOutput();
        rawOperatorStateOutput.startNewPartition();
        // Closing the decorator only flushes the header; the checkpoint stream stays open.
        try (DataOutputStream header = new DataOutputStream(new NonClosingOutputStreamDecorator(rawOperatorStateOutput))) {
            header.writeInt(this.numInputs);
        }
        for (int i = 0; i < this.numInputs; i++) {
            this.dataCacheWriters[i].writeSegmentsToFiles();
            // No reader position is recorded for the cache snapshot.
            new DataCacheSnapshot(this.basePath.getFileSystem(), (Tuple2) null, this.dataCacheWriters[i].getSegments()).writeTo(rawOperatorStateOutput);
        }
    }

    @Override
    public void setKeyContextElement1(StreamRecord<?> streamRecord) throws Exception {
        this.wrappedOperator.setKeyContextElement1(streamRecord);
    }

    @Override
    public void setKeyContextElement2(StreamRecord<?> streamRecord) throws Exception {
        this.wrappedOperator.setKeyContextElement2(streamRecord);
    }

    @Override
    public OperatorMetricGroup getMetricGroup() {
        return this.wrappedOperator.getMetricGroup();
    }

    @Override
    public OperatorID getOperatorID() {
        return this.wrappedOperator.getOperatorID();
    }

    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        this.wrappedOperator.notifyCheckpointComplete(checkpointId);
    }

    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        this.wrappedOperator.notifyCheckpointAborted(checkpointId);
    }

    @Override
    public void setCurrentKey(Object key) {
        this.wrappedOperator.setCurrentKey(key);
    }

    @Override
    public Object getCurrentKey() {
        return this.wrappedOperator.getCurrentKey();
    }
}
