package org.apache.flink.cdc.connectors.tidb;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.cdc.connectors.tidb.table.StartupMode;
import org.apache.flink.cdc.connectors.tidb.table.utils.TableKeyRangeUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.shaded.guava31.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tikv.cdc.CDCClient;
import org.tikv.common.TiConfiguration;
import org.tikv.common.TiSession;
import org.tikv.common.key.RowKey;
import org.tikv.common.meta.TiTableInfo;
import org.tikv.kvproto.Cdcpb;
import org.tikv.kvproto.Coprocessor;
import org.tikv.kvproto.Kvrpcpb;
import org.tikv.shade.com.google.protobuf.ByteString;
import org.tikv.txn.KVClient;

/* loaded from: input_file:org/apache/flink/cdc/connectors/tidb/TiKVRichParallelSourceFunction.class */
/**
 * Parallel source function that reads one TiDB table directly from TiKV.
 *
 * <p>Every subtask owns the slice of the table's key range returned by {@link
 * TableKeyRangeUtils#getTableKeyRange} for its subtask index. With {@link StartupMode#INITIAL} the
 * subtask first scans a snapshot of that slice while holding the checkpoint lock; otherwise it
 * starts from the current TSO. It then follows change events through a TiKV {@link CDCClient}:
 * prewrite and commit rows are buffered until the client's resolved timestamp covers the commit,
 * after which the matching prewrite row is queued for a background thread that deserializes and
 * emits it under the checkpoint lock.
 *
 * <p>The last resolved timestamp is kept in operator list state. When a positive timestamp is
 * restored, the snapshot phase is skipped and change events resume from that timestamp.
 *
 * @param <T> type of the records produced by the deserialization schemas
 */
public class TiKVRichParallelSourceFunction<T> extends RichParallelSourceFunction<T>
        implements CheckpointListener, CheckpointedFunction, ResultTypeQueryable<T> {

    private static final long serialVersionUID = 1L;
    private static final Logger LOG = LoggerFactory.getLogger(TiKVRichParallelSourceFunction.class);

    /** resolvedTs value meaning "the snapshot has not been read yet" (INITIAL startup). */
    private static final long SNAPSHOT_VERSION_EPOCH = -1;

    /** resolvedTs placeholder used before a real TSO is known in non-INITIAL startup modes. */
    private static final long STREAMING_VERSION_START_EPOCH = 0;

    /** Seconds {@link #cancel()} waits for the emitter thread to terminate. */
    private static final long CLOSE_TIMEOUT = 30;

    /** Upper bound on rows pulled from the CDC client before the resolved timestamp is re-read. */
    private static final int MAX_ROWS_PER_BATCH = 1000;

    /** Milliseconds the emitter thread waits for a row before re-checking {@link #running}. */
    private static final long EMIT_POLL_TIMEOUT_MS = 100;

    private final TiKVSnapshotEventDeserializationSchema<T> snapshotEventDeserializationSchema;
    private final TiKVChangeEventDeserializationSchema<T> changeEventDeserializationSchema;
    private final TiConfiguration tiConf;
    private final StartupMode startupMode;
    private final String database;
    private final String tableName;

    private transient OutputCollector<T> outputCollector;
    private transient ExecutorService executorService;
    private transient ListState<Long> offsetState;
    private transient TiSession session;
    private transient Coprocessor.KeyRange keyRange;
    private transient CDCClient cdcClient;
    private transient SourceFunction.SourceContext<T> sourceContext;

    /** Timestamp up to which buffered commits may be flushed; persisted on every checkpoint. */
    private transient volatile long resolvedTs = -1;

    /** Set by {@link #initializeState} when a usable (positive) resolvedTs was restored. */
    private transient boolean restoredFromState;

    /**
     * Guards {@link #prewrites} and {@link #commits}: they are mutated by the source thread
     * ({@link #handleRow}, {@link #flushRows}) and by {@link #snapshotState}.
     */
    private transient Object bufferLock;

    /** Prewrite rows keyed by (startTs, row key), waiting for their commit. */
    private transient TreeMap<RowKeyWithTs, Cdcpb.Event.Row> prewrites;

    /** Commit rows keyed by (commitTs, row key); iteration order is oldest commit first. */
    private transient TreeMap<RowKeyWithTs, Cdcpb.Event.Row> commits;

    /** Prewrite rows whose commit is covered by resolvedTs; drained by the emitter thread. */
    private transient BlockingQueue<Cdcpb.Event.Row> committedEvents;

    /** Cleared by cancel()/close(); read by both the source thread and the emitter thread. */
    private transient volatile boolean running;

    /** Collector that forwards every record to the current {@link SourceFunction.SourceContext}. */
    public static class OutputCollector<T> implements Collector<T> {
        private SourceFunction.SourceContext<T> context;

        private OutputCollector() {
        }

        @Override
        public void collect(T record) {
            this.context.collect(record);
        }

        @Override
        public void close() {
            // The source context is not owned by this collector; nothing to release.
        }
    }

    /**
     * Buffer key combining a TiKV timestamp with a decoded row key.
     *
     * <p>Instances order by timestamp, then table id, then handle, so a {@link TreeMap} keyed by
     * them iterates rows in timestamp order.
     */
    public static class RowKeyWithTs implements Comparable<RowKeyWithTs> {
        private final long timestamp;
        private final RowKey rowKey;

        private RowKeyWithTs(long timestamp, RowKey rowKey) {
            this.timestamp = timestamp;
            this.rowKey = rowKey;
        }

        private RowKeyWithTs(long timestamp, byte[] key) {
            this(timestamp, RowKey.decode(key));
        }

        @Override
        public int compareTo(RowKeyWithTs other) {
            int result = Long.compare(this.timestamp, other.timestamp);
            if (result == 0) {
                result = Long.compare(this.rowKey.getTableId(), other.rowKey.getTableId());
            }
            if (result == 0) {
                result = Long.compare(this.rowKey.getHandle(), other.rowKey.getHandle());
            }
            return result;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.timestamp, this.rowKey.getTableId(), this.rowKey.getHandle());
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof RowKeyWithTs)) {
                return false;
            }
            RowKeyWithTs that = (RowKeyWithTs) obj;
            return this.timestamp == that.timestamp && this.rowKey.equals(that.rowKey);
        }

        /** Key for {@code row} at its start (prewrite) timestamp. */
        static RowKeyWithTs ofStart(Cdcpb.Event.Row row) {
            return new RowKeyWithTs(row.getStartTs(), row.getKey().toByteArray());
        }

        /** Key for {@code row} at its commit timestamp. */
        static RowKeyWithTs ofCommit(Cdcpb.Event.Row row) {
            return new RowKeyWithTs(row.getCommitTs(), row.getKey().toByteArray());
        }
    }

    public TiKVRichParallelSourceFunction(
            TiKVSnapshotEventDeserializationSchema<T> tiKVSnapshotEventDeserializationSchema,
            TiKVChangeEventDeserializationSchema<T> tiKVChangeEventDeserializationSchema,
            TiConfiguration tiConfiguration,
            StartupMode startupMode,
            String str,
            String str2) {
        this.snapshotEventDeserializationSchema = tiKVSnapshotEventDeserializationSchema;
        this.changeEventDeserializationSchema = tiKVChangeEventDeserializationSchema;
        this.tiConf = tiConfiguration;
        this.startupMode = startupMode;
        this.database = str;
        this.tableName = str2;
    }

    /**
     * Connects to TiKV, resolves this subtask's key range and allocates the change buffers.
     *
     * <p>The initial resolvedTs is only chosen here when {@link #initializeState} did not restore
     * one; a restored value is kept so the source can resume from it.
     *
     * @throws RuntimeException if the configured table does not exist
     */
    @Override
    public void open(Configuration configuration) throws Exception {
        super.open(configuration);
        this.session = TiSession.create(this.tiConf);
        TiTableInfo table = this.session.getCatalog().getTable(this.database, this.tableName);
        if (table == null) {
            throw new RuntimeException(
                    String.format("Table %s.%s does not exist.", this.database, this.tableName));
        }
        this.keyRange =
                TableKeyRangeUtils.getTableKeyRange(
                        table.getId(),
                        getRuntimeContext().getNumberOfParallelSubtasks(),
                        getRuntimeContext().getIndexOfThisSubtask());
        this.cdcClient = new CDCClient(this.session, this.keyRange);
        this.bufferLock = new Object();
        this.prewrites = new TreeMap<>();
        this.commits = new TreeMap<>();
        this.committedEvents = new LinkedBlockingQueue<>();
        this.outputCollector = new OutputCollector<>();
        if (!this.restoredFromState) {
            // initializeState runs before open; do not clobber a restored timestamp.
            this.resolvedTs =
                    this.startupMode == StartupMode.INITIAL
                            ? SNAPSHOT_VERSION_EPOCH
                            : STREAMING_VERSION_START_EPOCH;
        }
        this.running = true;
        this.executorService =
                Executors.newSingleThreadExecutor(
                        new ThreadFactoryBuilder()
                                .setNameFormat(
                                        "tidb-source-function-"
                                                + getRuntimeContext().getIndexOfThisSubtask())
                                .build());
    }

    /**
     * Reads the snapshot (INITIAL startup without restored state), then streams change events
     * until cancelled.
     */
    @Override
    public void run(SourceFunction.SourceContext<T> ctx) throws Exception {
        this.sourceContext = ctx;
        this.outputCollector.context = ctx;
        if (this.restoredFromState) {
            LOG.info("Skip snapshot read, resume from restored resolvedTs: {}", this.resolvedTs);
        } else if (this.startupMode == StartupMode.INITIAL) {
            synchronized (ctx.getCheckpointLock()) {
                readSnapshotEvents();
            }
        } else {
            LOG.info("Skip snapshot read");
            this.resolvedTs = this.session.getTimestamp().getVersion();
        }
        LOG.info("start read change events");
        this.cdcClient.start(this.resolvedTs);
        readChangeEvents();
    }

    /**
     * Buffers one CDC row: prewrites by start ts, commits by commit ts. COMMITTED rows go into
     * both maps; ROLLBACK drops the pending prewrite. Non-record keys are ignored.
     */
    private void handleRow(Cdcpb.Event.Row row) {
        if (!TableKeyRangeUtils.isRecordKey(row.getKey().toByteArray())) {
            return;
        }
        LOG.debug("binlog record, type: {}, data: {}", row.getType(), row);
        synchronized (this.bufferLock) {
            switch (row.getType()) {
                case COMMITTED:
                    this.prewrites.put(RowKeyWithTs.ofStart(row), row);
                    this.commits.put(RowKeyWithTs.ofCommit(row), row);
                    break;
                case COMMIT:
                    this.commits.put(RowKeyWithTs.ofCommit(row), row);
                    break;
                case PREWRITE:
                    this.prewrites.put(RowKeyWithTs.ofStart(row), row);
                    break;
                case ROLLBACK:
                    this.prewrites.remove(RowKeyWithTs.ofStart(row));
                    break;
                default:
                    LOG.warn("Unsupported row type:{}", row.getType());
                    break;
            }
        }
    }

    /**
     * Scans this subtask's key range at the current TSO and emits every record key through the
     * snapshot deserialization schema. Afterwards resolvedTs is set to the scan version, so that
     * change events are read from that point on.
     */
    protected void readSnapshotEvents() throws Exception {
        LOG.info("read snapshot events");
        try (KVClient scanClient = this.session.createKVClient()) {
            final long version = this.session.getTimestamp().getVersion();
            ByteString start = this.keyRange.getStart();
            while (true) {
                final List<Kvrpcpb.KvPair> pairs =
                        scanClient.scan(start, this.keyRange.getEnd(), version);
                if (pairs.isEmpty()) {
                    break;
                }
                for (Kvrpcpb.KvPair pair : pairs) {
                    if (TableKeyRangeUtils.isRecordKey(pair.getKey().toByteArray())) {
                        this.snapshotEventDeserializationSchema.deserialize(pair, this.outputCollector);
                    }
                }
                // Continue right after the last key returned by this batch.
                start = RowKey.toRawKey(pairs.get(pairs.size() - 1).getKey()).next().toByteString();
            }
            this.resolvedTs = version;
        }
    }

    /**
     * Pulls rows from the CDC client, buffers them and flushes commits covered by the client's
     * resolved timestamp, until cancelled or resolvedTs drops below zero. Emission is delegated to
     * a separate thread fed through {@link #committedEvents}.
     */
    protected void readChangeEvents() throws Exception {
        Preconditions.checkState(this.sourceContext != null, "sourceContext shouldn't be null");
        LOG.info("read change event from resolvedTs:{}", this.resolvedTs);
        this.executorService.execute(this::emitCommittedEvents);
        while (this.running && this.resolvedTs >= STREAMING_VERSION_START_EPOCH) {
            for (int i = 0; i < MAX_ROWS_PER_BATCH; i++) {
                final Cdcpb.Event.Row row = this.cdcClient.get();
                if (row == null) {
                    break;
                }
                handleRow(row);
            }
            final long maxResolvedTs = this.cdcClient.getMaxResolvedTs();
            this.resolvedTs = maxResolvedTs;
            flushRows(maxResolvedTs);
        }
    }

    /**
     * Emitter loop: takes committed rows off the queue and emits them while holding the checkpoint
     * lock, as required for sources that also checkpoint state. Polls with a timeout so that it
     * notices {@link #running} being cleared.
     */
    private void emitCommittedEvents() {
        final Object checkpointLock = this.sourceContext.getCheckpointLock();
        while (this.running) {
            final Cdcpb.Event.Row row;
            try {
                row = this.committedEvents.poll(EMIT_POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            if (row == null) {
                continue;
            }
            try {
                synchronized (checkpointLock) {
                    this.changeEventDeserializationSchema.deserialize(row, this.outputCollector);
                }
            } catch (Exception e) {
                // Best effort, as before: log and keep consuming the remaining rows.
                LOG.error("Failed to deserialize change event with startTs {}", row.getStartTs(), e);
            }
        }
    }

    /**
     * Moves every buffered commit with commit ts {@code <= j} to the emit queue, paired with its
     * prewrite row. A commit without a matching prewrite is logged and dropped instead of being
     * offered as {@code null} (which the queue rejects with an exception).
     *
     * @param j resolved timestamp bounding the commits to flush
     */
    protected void flushRows(long j) throws Exception {
        synchronized (this.bufferLock) {
            while (!this.commits.isEmpty() && this.commits.firstKey().timestamp <= j) {
                final Cdcpb.Event.Row commitRow = this.commits.pollFirstEntry().getValue();
                final Cdcpb.Event.Row prewriteRow =
                        this.prewrites.remove(RowKeyWithTs.ofStart(commitRow));
                if (prewriteRow == null) {
                    LOG.warn(
                            "No prewrite found for commit (startTs: {}, commitTs: {}), skip it",
                            commitRow.getStartTs(),
                            commitRow.getCommitTs());
                    continue;
                }
                this.committedEvents.offer(prewriteRow);
            }
        }
    }

    /** Stops both loops, closes the CDC client and waits up to CLOSE_TIMEOUT s for the emitter. */
    @Override
    public void cancel() {
        this.running = false;
        try {
            if (this.cdcClient != null) {
                this.cdcClient.close();
            }
            if (this.executorService != null) {
                this.executorService.shutdown();
                if (!this.executorService.awaitTermination(CLOSE_TIMEOUT, TimeUnit.SECONDS)) {
                    LOG.warn("Failed to close the tidb source function in {} seconds.", CLOSE_TIMEOUT);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.error("Unable to close cdcClient", e);
        } catch (Exception e) {
            LOG.error("Unable to close cdcClient", e);
        }
    }

    /** Stops the emitter thread and releases the TiKV session created in {@link #open}. */
    @Override
    public void close() throws Exception {
        this.running = false;
        try {
            if (this.executorService != null) {
                this.executorService.shutdownNow();
            }
            if (this.session != null) {
                this.session.close();
            }
        } finally {
            super.close();
        }
    }

    /**
     * Flushes buffered commits up to the current resolvedTs and persists that timestamp. The
     * volatile field is read once so the flushed and the stored timestamps are identical.
     */
    @Override
    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
        final long ts = this.resolvedTs;
        LOG.info(
                "snapshotState checkpoint: {} at resolvedTs: {}",
                functionSnapshotContext.getCheckpointId(),
                ts);
        flushRows(ts);
        this.offsetState.clear();
        this.offsetState.add(ts);
    }

    /**
     * Registers the resolvedTs list state and, on restore, loads the smallest stored timestamp.
     * Only a positive timestamp (a real TSO) marks the source as restored; placeholders such as
     * 0 or -1 fall back to the configured startup mode.
     */
    @Override
    public void initializeState(FunctionInitializationContext functionInitializationContext)
            throws Exception {
        LOG.info("initialize checkpoint");
        this.offsetState =
                functionInitializationContext
                        .getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("resolvedTsState", LongSerializer.INSTANCE));
        this.restoredFromState = false;
        if (!functionInitializationContext.isRestored()) {
            this.resolvedTs = 0L;
            LOG.info("Initialize State from resolvedTs: {}", this.resolvedTs);
            return;
        }
        long restored = Long.MAX_VALUE;
        for (Long ts : this.offsetState.get()) {
            if (ts != null) {
                // Smallest entry is the conservative choice if several were redistributed here.
                restored = Math.min(restored, ts);
            }
        }
        if (restored != Long.MAX_VALUE) {
            this.resolvedTs = restored;
            this.restoredFromState = restored > STREAMING_VERSION_START_EPOCH;
            LOG.info("Restore State from resolvedTs: {}", this.resolvedTs);
        }
    }

    @Override
    public void notifyCheckpointComplete(long j) throws Exception {
        // No action needed on checkpoint completion.
    }

    @Override
    public TypeInformation<T> getProducedType() {
        return this.snapshotEventDeserializationSchema.getProducedType();
    }
}
