/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.plugin.stream.kinesis;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.plugin.stream.kinesis.KinesisConfig;
import org.apache.pinot.plugin.stream.kinesis.KinesisConnectionHandler;
import org.apache.pinot.plugin.stream.kinesis.KinesisPartitionGroupOffset;
import org.apache.pinot.spi.stream.MessageBatch;
import org.apache.pinot.spi.stream.OffsetCriteria;
import org.apache.pinot.spi.stream.PartitionGroupConsumer;
import org.apache.pinot.spi.stream.PartitionGroupConsumptionStatus;
import org.apache.pinot.spi.stream.PartitionGroupMetadata;
import org.apache.pinot.spi.stream.StreamConfig;
import org.apache.pinot.spi.stream.StreamConsumerFactory;
import org.apache.pinot.spi.stream.StreamConsumerFactoryProvider;
import org.apache.pinot.spi.stream.StreamMetadataProvider;
import org.apache.pinot.spi.stream.StreamPartitionMsgOffset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.kinesis.model.Shard;

/**
 * {@link StreamMetadataProvider} for Amazon Kinesis that maps Kinesis shards to Pinot partition groups.
 *
 * <p>Only {@link #computePartitionGroupMetadata} is implemented; {@link #fetchPartitionCount} and
 * {@link #fetchStreamPartitionOffset} throw {@link UnsupportedOperationException}.
 */
public class KinesisStreamMetadataProvider implements StreamMetadataProvider {
    private static final Logger LOGGER = LoggerFactory.getLogger(KinesisStreamMetadataProvider.class);
    /** Prefix of Kinesis shard ids; the numeric suffix becomes the partition group id. */
    private static final String SHARD_ID_PREFIX = "shardId-";

    private final KinesisConnectionHandler _kinesisConnectionHandler;
    private final StreamConsumerFactory _kinesisStreamConsumerFactory;
    private final String _clientId;
    private final int _fetchTimeoutMs;

    /**
     * Creates a provider that builds its own {@link KinesisConnectionHandler} and consumer factory from
     * {@code streamConfig}.
     *
     * @param clientId client id handed to partition group consumers created by this provider
     * @param streamConfig stream configuration; also supplies the fetch timeout used when probing shard ends
     */
    public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig) {
        this(clientId, streamConfig, new KinesisConnectionHandler(new KinesisConfig(streamConfig)),
            StreamConsumerFactoryProvider.create(streamConfig));
    }

    /**
     * Creates a provider with an externally supplied connection handler and consumer factory.
     *
     * @param clientId client id handed to partition group consumers created by this provider
     * @param streamConfig stream configuration; only its fetch timeout is read here
     * @param kinesisConnectionHandler handler used to list the stream's shards
     * @param streamConsumerFactory factory used to create consumers when probing whether a shard was fully consumed
     */
    public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig,
        KinesisConnectionHandler kinesisConnectionHandler, StreamConsumerFactory streamConsumerFactory) {
        _kinesisConnectionHandler = kinesisConnectionHandler;
        _kinesisStreamConsumerFactory = streamConsumerFactory;
        _clientId = clientId;
        _fetchTimeoutMs = streamConfig.getFetchTimeoutMillis();
    }

    /**
     * Not supported for Kinesis; partition groups are derived via {@link #computePartitionGroupMetadata}.
     *
     * @throws UnsupportedOperationException always
     */
    public int fetchPartitionCount(long timeoutMillis) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported for Kinesis.
     *
     * @throws UnsupportedOperationException always
     */
    public StreamPartitionMsgOffset fetchStreamPartitionOffset(OffsetCriteria offsetCriteria, long timeoutMillis) {
        throw new UnsupportedOperationException();
    }

    /**
     * Computes the next set of partition groups from the stream's current shards and the existing
     * consumption statuses.
     *
     * <p>An existing partition group is dropped when its shard is no longer returned by the connection handler
     * (a warning is logged), or when its shard has an ending sequence number and the consumer reports that the
     * end of the shard was reached from the group's end offset. Otherwise the group is kept, resuming from its
     * end offset if it has one, else from its start offset.
     *
     * <p>A shard that no existing group tracks is added, starting at its starting sequence number, if it has no
     * parent, if its parent is not among the listed shards, or if its parent was found to be ended above.
     *
     * <p>The {@code clientId}, {@code streamConfig} and {@code timeoutMillis} arguments are not read; the values
     * captured at construction are used instead.
     *
     * @return the new partition group metadata list
     * @throws IOException if probing a shard's end via a partition group consumer fails
     * @throws TimeoutException if probing a shard's end times out
     */
    public List<PartitionGroupMetadata> computePartitionGroupMetadata(String clientId, StreamConfig streamConfig,
        List<PartitionGroupConsumptionStatus> partitionGroupConsumptionStatuses, int timeoutMillis)
        throws IOException, TimeoutException {
        List<PartitionGroupMetadata> newPartitionGroupMetadataList = new ArrayList<>();
        // Duplicate shard ids (if any) keep the first occurrence.
        Map<String, Shard> shardIdToShardMap = _kinesisConnectionHandler.getShards().stream()
            .collect(Collectors.toMap(Shard::shardId, s -> s, (s1, s2) -> s1));
        HashSet<String> shardsInCurrent = new HashSet<>();
        HashSet<String> shardsEnded = new HashSet<>();

        // Pass 1: carry forward or retire every partition group that is currently tracked.
        for (PartitionGroupConsumptionStatus status : partitionGroupConsumptionStatuses) {
            KinesisPartitionGroupOffset startCheckpoint = (KinesisPartitionGroupOffset) status.getStartOffset();
            Map<String, String> shardToStartSequence = startCheckpoint.getShardToStartSequenceMap();
            // Only the first shard of the offset map is considered (presumably each group holds one shard).
            String shardId = shardToStartSequence.keySet().iterator().next();
            shardsInCurrent.add(shardId);

            Shard shard = shardIdToShardMap.get(shardId);
            if (shard == null) {
                shardsEnded.add(shardId);
                LOGGER.warn("Kinesis shard with id: {} has expired. Data has been consumed from the shard till "
                        + "sequence number: {}. There can be potential data loss.", shardId,
                    shardToStartSequence.get(shardId));
                continue;
            }

            StreamPartitionMsgOffset currentEndOffset = status.getEndOffset();
            StreamPartitionMsgOffset newStartOffset;
            if (currentEndOffset == null) {
                newStartOffset = status.getStartOffset();
            } else {
                // Only a shard with an ending sequence number can have been fully consumed, so only then probe it.
                boolean hasEndingSequenceNumber = shard.sequenceNumberRange().endingSequenceNumber() != null;
                if (hasEndingSequenceNumber && consumedEndOfShard(currentEndOffset, status)) {
                    shardsEnded.add(shardId);
                    continue;
                }
                newStartOffset = currentEndOffset;
            }
            newPartitionGroupMetadataList.add(new PartitionGroupMetadata(status.getPartitionGroupId(), newStartOffset));
        }

        // Pass 2: start consuming shards that no existing partition group tracks yet.
        for (Map.Entry<String, Shard> entry : shardIdToShardMap.entrySet()) {
            String newShardId = entry.getKey();
            if (shardsInCurrent.contains(newShardId)) {
                continue;
            }
            Shard newShard = entry.getValue();
            String parentShardId = newShard.parentShardId();
            // Wait while the parent is still listed and has not been found ended in pass 1.
            boolean parentStillActive = parentShardId != null
                && shardIdToShardMap.containsKey(parentShardId)
                && !shardsEnded.contains(parentShardId);
            if (parentStillActive) {
                continue;
            }
            HashMap<String, String> shardToSequenceNumberMap = new HashMap<>();
            shardToSequenceNumberMap.put(newShardId, newShard.sequenceNumberRange().startingSequenceNumber());
            StreamPartitionMsgOffset newStartOffset = new KinesisPartitionGroupOffset(shardToSequenceNumberMap);
            int partitionGroupId = getPartitionGroupIdFromShardId(newShardId);
            newPartitionGroupMetadataList.add(new PartitionGroupMetadata(partitionGroupId, newStartOffset));
        }
        return newPartitionGroupMetadataList;
    }

    /**
     * Derives a partition group id from a shard id by removing the {@code "shardId-"} prefix and leading zeros.
     * An all-zero suffix maps to 0.
     *
     * @throws NumberFormatException if the remaining text is not an integer that fits in an {@code int}
     */
    private int getPartitionGroupIdFromShardId(String shardId) {
        String shardIdNum = StringUtils.stripStart(StringUtils.removeStart(shardId, SHARD_ID_PREFIX), "0");
        return shardIdNum.isEmpty() ? 0 : Integer.parseInt(shardIdNum);
    }

    /**
     * Fetches from {@code startCheckpoint} with a short-lived consumer and reports whether the returned batch
     * signals the end of the partition group.
     *
     * @param startCheckpoint offset to fetch from (the group's current end offset)
     * @param partitionGroupConsumptionStatus status used to create the probing consumer
     * @return {@code true} if the fetched batch reports end of partition group
     * @throws IOException if the consumer fails or cannot be closed
     * @throws TimeoutException if the fetch exceeds the configured fetch timeout
     */
    private boolean consumedEndOfShard(StreamPartitionMsgOffset startCheckpoint,
        PartitionGroupConsumptionStatus partitionGroupConsumptionStatus) throws IOException, TimeoutException {
        try (PartitionGroupConsumer partitionGroupConsumer =
            _kinesisStreamConsumerFactory.createPartitionGroupConsumer(_clientId, partitionGroupConsumptionStatus)) {
            // Read the end-of-group flag before the consumer is closed by try-with-resources.
            return partitionGroupConsumer.fetchMessages(startCheckpoint, null, _fetchTimeoutMs)
                .isEndOfPartitionGroup();
        }
    }

    /**
     * Releases resources held by this provider. Currently does nothing.
     */
    public void close() {
        // NOTE(review): the primary constructor creates its own KinesisConnectionHandler, which is never released
        // here. If that handler holds a client/connection, it likely should be closed when this provider owns it
        // (but not when it was injected) — confirm against KinesisConnectionHandler.
    }
}

