package cn.ibizlab.util.filter;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Holds the database type ({@link DbType}) of the application's data sources.
 *
 * <p>Types are cached per data-source name in a map shared by every {@link DbTypeContext};
 * the name {@code "master"} is used whenever no name (or {@code null}) is supplied. The active
 * {@link DbTypeContext} decides what {@link #get()} returns; the default context returns the
 * type cached for {@code "master"}.
 *
 * <p>Replacing and lazily creating the active context are synchronized on this class.
 */
public class DbTypeContextHolder {

    /** Data-source name used when no name (or {@code null}) is supplied. */
    private static final String MASTER_DS_NAME = "master";

    /** Cached database type per data-source name; shared by every {@link DbTypeContext}. */
    private static final Map<String, DbType> JDBC_DB_TYPE_CACHE = new ConcurrentHashMap<>();

    /** Active context, created lazily by {@link #getContext()}. Guarded by the class lock. */
    private static DbTypeContext context = null;

    /**
     * Detects the database type of a data source from its JDBC URL.
     *
     * <p>Opens one connection to read the URL from its metadata, hands the URL to
     * MyBatis-Plus {@code JdbcUtils.getDbType}, and closes the connection afterwards.
     *
     * @param dataSource the data source to inspect
     * @return the type derived from the JDBC URL, or {@link DbType#H2} if obtaining the
     *     connection or its metadata throws {@link SQLException}
     */
    public static DbType getDbType(DataSource dataSource) {
        Connection connection = null;
        try {
            connection = dataSource.getConnection();
            return JdbcUtils.getDbType(connection.getMetaData().getURL());
        } catch (SQLException e) {
            // Best-effort detection: fall back to H2 instead of propagating the error.
            return DbType.H2;
        } finally {
            closeQuietly(connection);
        }
    }

    /**
     * Closes {@code connection} if it is non-null and still open, ignoring close failures.
     *
     * <p>Failures are ignored deliberately so that a close error can never replace the result
     * {@link #getDbType(DataSource)} has already computed.
     */
    private static void closeQuietly(Connection connection) {
        if (connection == null) {
            return;
        }
        try {
            if (!connection.isClosed()) {
                connection.close();
            }
        } catch (SQLException ignored) {
            // Intentionally ignored; see method Javadoc.
        }
    }

    /** Maps a {@code null} data-source name to {@code "master"}; other names pass through. */
    private static String dsNameOrMaster(String dsName) {
        return dsName == null ? MASTER_DS_NAME : dsName;
    }

    /**
     * Returns the active context, creating the default one on first use.
     *
     * <p>The default context's {@link DbTypeContext#get()} returns the type cached for
     * {@code "master"}, or {@link DbType#H2} when nothing has been cached for it.
     *
     * @return the active context, never {@code null}
     */
    public static synchronized DbTypeContext getContext() {
        if (context == null) {
            context = new DbTypeContext() {
                @Override
                public DbType get() {
                    return peek(MASTER_DS_NAME);
                }
            };
        }
        return context;
    }

    /**
     * Initializes the active context from {@code dataSource}.
     *
     * <p>With the default {@link DbTypeContext#init(DataSource)}, this detects the data
     * source's type and caches it under {@code "master"}.
     *
     * @param dataSource the data source to inspect
     */
    public static void register(DataSource dataSource) {
        getContext().init(dataSource);
    }

    /**
     * Replaces the active context.
     *
     * <p>Synchronized with {@link #getContext()} so the new context is visible to all threads
     * and cannot be overwritten by a concurrent lazy creation of the default context. Passing
     * {@code null} makes the next lookup recreate the default context.
     *
     * @param dbTypeContext the new active context
     */
    public static synchronized void register(DbTypeContext dbTypeContext) {
        context = dbTypeContext;
    }

    /**
     * Replaces the active context, then initializes it from {@code dataSource}.
     *
     * @param dbTypeContext the new active context; must not be {@code null}
     * @param dataSource the data source passed to {@link DbTypeContext#init(DataSource)}
     * @throws NullPointerException if {@code dbTypeContext} is {@code null} (the context is
     *     still cleared first, as before)
     */
    public static void register(DbTypeContext dbTypeContext, DataSource dataSource) {
        register(dbTypeContext);
        // Initialize the argument, not the shared field, so a concurrent register() cannot
        // cause a different context to be initialized.
        dbTypeContext.init(dataSource);
    }

    /**
     * Returns the database type the active context considers current.
     *
     * @return the result of {@code getContext().get()}
     */
    public static DbType get() {
        return getContext().get();
    }

    /**
     * Caches {@code dbType} for data source {@code dsName} via the active context.
     *
     * @param dsName data-source name; {@code null} means {@code "master"} with the default
     *     {@link DbTypeContext#push(String, DbType)}
     * @param dbType the type to cache
     */
    static void push(String dsName, DbType dbType) {
        getContext().push(dsName, dbType);
    }

    /**
     * Caches the type named {@code dbType} for data source {@code dsName} via the active context.
     *
     * @param dsName data-source name; {@code null} means {@code "master"} with the default
     *     {@link DbTypeContext#push(String, DbType)}
     * @param dbType type name, converted with MyBatis-Plus {@code DbType.getDbType(String)}
     */
    static void push(String dsName, String dbType) {
        getContext().push(dsName, dbType);
    }

    /**
     * Strategy for resolving the current database type.
     *
     * <p>All default methods read and write the cache shared by
     * {@link DbTypeContextHolder}, keyed by data-source name.
     */
    public interface DbTypeContext {

        /**
         * Returns the database type this context considers current.
         *
         * @return the current database type
         */
        DbType get();

        /**
         * Detects the type of {@code dataSource} and caches it under {@code "master"}.
         *
         * @param dataSource the data source to inspect
         */
        default void init(DataSource dataSource) {
            push(getDbType(dataSource));
        }

        /**
         * Caches the type named {@code dbType} under {@code "master"}.
         *
         * @param dbType type name, converted with MyBatis-Plus {@code DbType.getDbType(String)}
         */
        default void push(String dbType) {
            push(DbType.getDbType(dbType));
        }

        /**
         * Caches {@code dbType} under {@code "master"}.
         *
         * @param dbType the type to cache
         */
        default void push(DbType dbType) {
            push(MASTER_DS_NAME, dbType);
        }

        /**
         * Caches the type named {@code dbType} under {@code dsName}.
         *
         * @param dsName data-source name; {@code null} means {@code "master"}
         * @param dbType type name, converted with MyBatis-Plus {@code DbType.getDbType(String)}
         */
        default void push(String dsName, String dbType) {
            push(dsName, DbType.getDbType(dbType));
        }

        /**
         * Caches {@code dbType} under {@code dsName}.
         *
         * @param dsName data-source name; {@code null} means {@code "master"}, matching
         *     {@link #peek(String)}
         * @param dbType the type to cache
         * @throws NullPointerException if {@code dbType} is {@code null} (the backing
         *     {@link ConcurrentHashMap} rejects null values)
         */
        default void push(String dsName, DbType dbType) {
            JDBC_DB_TYPE_CACHE.put(dsNameOrMaster(dsName), dbType);
        }

        /**
         * Returns the type cached for {@code dsName}.
         *
         * @param dsName data-source name; {@code null} means {@code "master"}
         * @return the cached type, or {@link DbType#H2} if none has been cached
         */
        default DbType peek(String dsName) {
            return JDBC_DB_TYPE_CACHE.getOrDefault(dsNameOrMaster(dsName), DbType.H2);
        }
    }

}
