package tech.ydb.yoj.repository.ydb.merge;

import lombok.NonNull;
import lombok.Value;
import lombok.With;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.cache.RepositoryCache;
import tech.ydb.yoj.repository.db.exception.EntityAlreadyExistsException;
import tech.ydb.yoj.repository.ydb.YdbRepository.Query;
import tech.ydb.yoj.repository.ydb.exception.YdbRepositoryException;
import tech.ydb.yoj.repository.ydb.statement.Statement;
import tech.ydb.yoj.repository.ydb.statement.YqlStatement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * {@link YqlQueriesMerger} that collapses a sequence of modification queries into a small list of
 * batched queries.
 * <p>
 * Queries are grouped per table (entity class + table name). Inside a table, each entity id is driven
 * through a small state machine ({@link MergingState}). For example, {@code DELETE} followed by
 * {@code INSERT} becomes a single {@code UPSERT}, and repeated writes keep only the latest value.
 * <p>
 * {@code DELETE_ALL} and {@code UPDATE} are table-wide:
 * <ul>
 *     <li>{@code DELETE_ALL} discards every pending per-entity modification of its table;</li>
 *     <li>{@code UPDATE} is only accepted on a table with nothing pending;</li>
 *     <li>no per-entity modification may follow either of them on the same table.</li>
 * </ul>
 * <p>
 * Instances keep mutable state in unsynchronized maps.
 */
public class ByEntityYqlQueriesMerger implements YqlQueriesMerger {
    private static final Logger log = LoggerFactory.getLogger(ByEntityYqlQueriesMerger.class);

    private static final Set<Statement.QueryType> SUPPORTED_QUERY_TYPES = EnumSet.of(
            Statement.QueryType.INSERT,
            Statement.QueryType.DELETE,
            Statement.QueryType.UPSERT,
            Statement.QueryType.UPDATE,
            Statement.QueryType.DELETE_ALL);

    /** Allowed (current state, next query type) -> next state transitions; absent pairs are rejected. */
    private static final Map<TransitionKey, MergingState> TRANSITION_MAP = createTransitionMap();

    /** Pending modifications per table, iterated in the order tables were first touched. */
    private final Map<TableMetadata, TableState> states = new LinkedHashMap<>();
    private final RepositoryCache cache;

    ByEntityYqlQueriesMerger(RepositoryCache cache) {
        this.cache = cache;
    }

    /**
     * Records the next modification query, merging it with what is already pending for the same
     * table/entity.
     *
     * @param query an INSERT, UPSERT, DELETE, UPDATE or DELETE_ALL query backed by a {@link YqlStatement}
     * @throws YdbRepositoryException if the query type is unsupported, if UPDATE follows other modifications
     *                                of the table, if a per-entity query follows DELETE_ALL/UPDATE, or if the
     *                                state transition is not allowed
     * @throws EntityAlreadyExistsException if an entity with a pending INSERT is inserted again
     */
    @Override
    public void onNext(Query<?> query) {
        Statement.QueryType queryType = query.getStatement().getQueryType();
        check(SUPPORTED_QUERY_TYPES.contains(queryType), "Unsupported query type: " + queryType);

        TableMetadata table = new TableMetadata(getEntityClass(query), getTableName(query));
        TableState tableState = states.computeIfAbsent(table, __ -> new TableState());

        if (queryType == Statement.QueryType.DELETE_ALL) {
            // DELETE_ALL supersedes every pending modification of this table.
            tableState.entityStates.clear();
            tableState.update = null;
            tableState.deleteAll = query;
            return;
        }
        if (queryType == Statement.QueryType.UPDATE) {
            check(tableState.isEmpty(), "Update operation couldn't be after other modifications");
            tableState.update = query;
            return;
        }

        check(tableState.deleteAll == null && tableState.update == null,
                "Modifications after delete_all or update aren't allowed");

        Entity.Id id = getEntityId(query);
        EntityState previous = tableState.entityStates.get(id);
        EntityState next = previous == null
                ? new EntityState(query, doTransition(MergingState.INITIAL, queryType, query))
                : mergeInto(previous, queryType, query);
        tableState.entityStates.put(id, next);
    }

    /**
     * Applies {@code query} to an entity that already has a pending modification.
     * <p>
     * The stored query always matches the resulting state:
     * <ul>
     *     <li>{@code INSERT} state holds an insert query;</li>
     *     <li>{@code UPSERT} state holds an upsert query;</li>
     *     <li>{@code DELETE} state holds a delete query;</li>
     *     <li>{@code INS_DEL} state keeps the original insert query.</li>
     * </ul>
     */
    private static EntityState mergeInto(EntityState previous, Statement.QueryType queryType, Query<?> query) {
        MergingState oldState = previous.getState();
        MergingState newState = doTransition(oldState, queryType, query);

        if (newState == MergingState.INS_DEL) {
            // Keep the pending INSERT query: getQueries() emits it together with a DELETE of the same id.
            return previous.withState(newState);
        }

        Query<?> replaceWith = query;
        if (oldState == MergingState.DELETE && queryType == Statement.QueryType.INSERT) {
            // DELETE, INSERT -> UPSERT
            replaceWith = convertInsertToUpsert(query);
        } else if (newState == MergingState.INSERT && queryType == Statement.QueryType.UPSERT) {
            // (INSERT | INS_DEL), UPSERT -> INSERT of the upserted value. Reuse the pending insert's statement
            // so the INSERT batch never contains an UPSERT statement.
            replaceWith = replaceValue(previous.getQuery(), query);
        }
        return new EntityState(replaceWith, newState);
    }

    /**
     * Returns the merged queries. The order is:
     * <ol>
     *     <li>all INSERT batches;</li>
     *     <li>all UPSERT batches;</li>
     *     <li>all DELETE batches;</li>
     *     <li>the table-wide DELETE_ALL/UPDATE queries.</li>
     * </ol>
     * Each batch holds at most one merged query per table.
     *
     * @throws EntityAlreadyExistsException if a pending INSERT targets an entity present in the cache
     */
    @Override
    public List<Query<?>> getQueries() {
        Map<MergingState, List<Query<?>>> batches = new EnumMap<>(MergingState.class);
        List<Query<?>> tableWideQueries = new ArrayList<>();

        for (TableState tableState : states.values()) {
            if (tableState.deleteAll != null) {
                tableWideQueries.add(tableState.deleteAll);
            } else if (tableState.update != null) {
                tableWideQueries.add(tableState.update);
            } else {
                for (Map.Entry<MergingState, Query<?>> entry : mergeTableQueries(tableState).entrySet()) {
                    batches.computeIfAbsent(entry.getKey(), __ -> new ArrayList<>()).add(entry.getValue());
                }
            }
        }

        List<Query<?>> result = new ArrayList<>();
        addAllIfNonNull(result, batches.get(MergingState.INSERT));
        addAllIfNonNull(result, batches.get(MergingState.UPSERT));
        addAllIfNonNull(result, batches.get(MergingState.DELETE));
        result.addAll(tableWideQueries);
        return result;
    }

    /**
     * Merges the pending per-entity queries of one table into at most one query per resulting state.
     * <p>
     * Entities in {@code INS_DEL} state contribute a DELETE of their id and their original INSERT.
     * UPSERT/INSERT entries are filtered via {@link #needIgnoreQuery}.
     */
    private Map<MergingState, Query<?>> mergeTableQueries(TableState tableState) {
        Map<MergingState, Query<?>> merged = new EnumMap<>(MergingState.class);
        for (EntityState entityState : tableState.entityStates.values()) {
            MergingState bucket = entityState.getState();
            if (bucket == MergingState.INS_DEL) {
                updateCurQueries(merged, convertInsertToDelete(entityState.getQuery()), MergingState.DELETE);
                bucket = MergingState.INSERT;
            } else if (needIgnoreQuery(entityState)) {
                log.trace("Ignoring query: [{}]", entityState.getQuery().getStatement());
                continue;
            }
            updateCurQueries(merged, entityState.getQuery(), bucket);
        }
        return merged;
    }

    /**
     * Consults the repository cache for a pending UPSERT or INSERT.
     * <p>
     * For an UPSERT, returns {@code true} if the cache holds a value equal to the one being written.
     * For an INSERT, throws if the cache holds the entity. Any other state yields {@code false}.
     */
    private boolean needIgnoreQuery(EntityState entityState) {
        MergingState state = entityState.getState();
        if (state != MergingState.UPSERT && state != MergingState.INSERT) {
            return false;
        }

        Query<?> query = entityState.getQuery();
        Class<?> clazz = getEntityClass(query);
        Entity.Id entityId = getEntityId(query);
        RepositoryCache.Key key = new RepositoryCache.Key(clazz, entityId);

        if (state == MergingState.UPSERT) {
            Object newValue = query.getValues().get(0);
            boolean newValueEqualsCached = cache.get(key)
                    .map(entity -> entity.equals(newValue))
                    .orElse(false);
            if (newValueEqualsCached) {
                log.trace("New value {} is equal to cached value", newValue);
            }
            return newValueEqualsCached;
        }

        // INSERT case
        if (cache.contains(key) && cache.get(key).isPresent()) {
            throw new EntityAlreadyExistsException("Entity " + entityId + " already exists");
        }
        return false;
    }

    private static void addAllIfNonNull(List<Query<?>> result, List<Query<?>> additional) {
        if (additional != null) {
            result.addAll(additional);
        }
    }

    /** Adds {@code newQuery} to the {@code curState} slot, merging it with the query already stored there. */
    private static void updateCurQueries(Map<MergingState, Query<?>> curQueries, Query<?> newQuery,
                                         MergingState curState) {
        curQueries.merge(curState, newQuery, (accumulated, added) -> accumulated.merge(added));
    }

    /**
     * Returns the state reached from {@code state} when a query of {@code nextQueryType} arrives.
     *
     * @throws EntityAlreadyExistsException on a second INSERT of an entity with a pending INSERT
     * @throws YdbRepositoryException if the transition is not in {@link #TRANSITION_MAP}
     */
    private static MergingState doTransition(MergingState state, Statement.QueryType nextQueryType, Query<?> query) {
        if (state == MergingState.INSERT && nextQueryType == Statement.QueryType.INSERT) {
            throw new EntityAlreadyExistsException("Entity " + getEntityId(query) + " already exists");
        }
        MergingState nextState = TRANSITION_MAP.get(new TransitionKey(state, nextQueryType));
        check(nextState != null, "Incorrect transition, from " + state + " by " + nextQueryType);
        return nextState;
    }

    /**
     * Builds an UPSERT of the inserted entity.
     * NOTE(review): the statement is built from the entity class only. If {@code query} targeted a
     * non-default table name, the resulting statement may target a different table. TODO confirm.
     */
    @SuppressWarnings("unchecked")
    private static Query convertInsertToUpsert(Query<?> query) {
        return new Query<>(YqlStatement.save(getEntityClass(query)), query.getValues().get(0));
    }

    /**
     * Builds a DELETE of the inserted entity's id.
     * NOTE(review): same table-name caveat as {@link #convertInsertToUpsert}.
     */
    @SuppressWarnings("unchecked")
    private static Query convertInsertToDelete(Query<?> query) {
        return new Query<>(YqlStatement.delete(getEntityClass(query)), getEntityId(query));
    }

    /** Builds a query with {@code template}'s statement and {@code source}'s single value. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Query<?> replaceValue(Query<?> template, Query<?> source) {
        return new Query(template.getStatement(), source.getValues().get(0));
    }

    /** Extracts the entity id: the value itself for DELETE, otherwise the id of the entity value. */
    private static Entity.Id getEntityId(Query<?> query) {
        check(query.getValues().size() == 1, "Unsupported query");

        Object value = query.getValues().get(0);
        if (query.getStatement().getQueryType() == Statement.QueryType.DELETE) {
            return (Entity.Id) value;
        } else {
            return ((Entity) value).getId();
        }
    }

    private static Class getEntityClass(Query query) {
        return convertQueryToYqlStatement(query).getInSchemaType();
    }

    private static String getTableName(Query query) {
        return convertQueryToYqlStatement(query).getTableName();
    }

    private static YqlStatement convertQueryToYqlStatement(Query query) {
        return (YqlStatement) query.getStatement();
    }

    private static void check(boolean condition, String message) {
        if (!condition) {
            throw new YdbRepositoryException(message);
        }
    }

    private static Map<TransitionKey, MergingState> createTransitionMap() {
        Map<TransitionKey, MergingState> table = new HashMap<>();
        table.put(new TransitionKey(MergingState.INITIAL, Statement.QueryType.INSERT), MergingState.INSERT);
        table.put(new TransitionKey(MergingState.INITIAL, Statement.QueryType.UPSERT), MergingState.UPSERT);
        table.put(new TransitionKey(MergingState.INITIAL, Statement.QueryType.DELETE), MergingState.DELETE);

        // (INSERT, INSERT) is intentionally absent: doTransition rejects it with EntityAlreadyExistsException.
        table.put(new TransitionKey(MergingState.INSERT, Statement.QueryType.UPSERT), MergingState.INSERT);
        table.put(new TransitionKey(MergingState.INSERT, Statement.QueryType.DELETE), MergingState.INS_DEL);

        table.put(new TransitionKey(MergingState.INS_DEL, Statement.QueryType.INSERT), MergingState.INSERT);
        table.put(new TransitionKey(MergingState.INS_DEL, Statement.QueryType.UPSERT), MergingState.INSERT);
        table.put(new TransitionKey(MergingState.INS_DEL, Statement.QueryType.DELETE), MergingState.INS_DEL);

        table.put(new TransitionKey(MergingState.UPSERT, Statement.QueryType.INSERT), MergingState.INSERT);
        table.put(new TransitionKey(MergingState.UPSERT, Statement.QueryType.UPSERT), MergingState.UPSERT);
        table.put(new TransitionKey(MergingState.UPSERT, Statement.QueryType.DELETE), MergingState.DELETE);

        table.put(new TransitionKey(MergingState.DELETE, Statement.QueryType.INSERT), MergingState.UPSERT);
        table.put(new TransitionKey(MergingState.DELETE, Statement.QueryType.UPSERT), MergingState.UPSERT);
        table.put(new TransitionKey(MergingState.DELETE, Statement.QueryType.DELETE), MergingState.DELETE);

        return table;
    }

    /** Pending query for one entity together with its merging state. */
    @With
    @Value
    private static class EntityState {
        private Query<?> query;
        private MergingState state;
    }

    /** Pending modifications of one table: per-entity states or a single table-wide query. */
    private static final class TableState {
        private final Map<Entity.Id, EntityState> entityStates = new LinkedHashMap<>();
        private Query<?> deleteAll;
        private Query<?> update;

        public boolean isEmpty() {
            return entityStates.isEmpty() && update == null && deleteAll == null;
        }
    }

    @Value
    private static class TransitionKey {
        private MergingState state;
        private Statement.QueryType nextQueryType;
    }

    @Value
    private static class TableMetadata {
        private @NonNull Class<?> entityClass;
        private @NonNull String tableName;
    }

    private enum MergingState {
        INITIAL,
        INSERT,
        /** Inserted and then deleted within the merged sequence. */
        INS_DEL,
        UPSERT,
        DELETE,
    }
}
