package cn.dolphin.core.dialect.util;


import cn.dolphin.core.consts.SqlTypeConstant;
import cn.dolphin.core.dialect.Dialect;
import cn.dolphin.core.dialect.common.MySQLDialect;
import cn.dolphin.core.dialect.common.OracleDialect;
import cn.dolphin.core.util.StrUtil;
import org.springframework.jdbc.core.JdbcTemplate;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 自定义分类插件工具类
 */
public class DialectUtil {

    private final static Map<Object, Dialect> map = new HashMap<Object, Dialect>();


    /**
     * 根据数据类型获取分页插件
     * @param dbType
     * @return
     * @throws Exception
     */
    public static Dialect getDialect(String dbType) throws Exception {
        try {
            Dialect dialect = map.get(dbType);
            if (dialect != null) {
                return dialect;
            }
            if (dbType.equals(SqlTypeConstant.ORACLE)) {
                dialect = new OracleDialect();
                map.put(dbType, dialect);
            }else if (dbType.equals(SqlTypeConstant.MYSQL)) {
                dialect = new MySQLDialect();
                map.put(dbType, dialect);
            }else {
                throw new Exception("没有设置合适的数据库类型");
            }
            return dialect;
        }catch (Exception e){
            throw new Exception("没有设置合适的数据库类型");
        }
    }


    /**
     * 获取分页插件
     * @param jt
     * @return
     */
    public final static Dialect getDialect(JdbcTemplate jt) {
        Dialect dialect = map.get(jt);
        if (dialect != null) {
            return dialect;
        }
        try {
            dialect = getDialect(jt.getDataSource().getConnection());
            map.put(jt, dialect);
            return dialect;
        } catch (SQLException e) {
            throw new RuntimeException("Can't get the database product name .");
        }
    }

    /**
     * 获取分页插件
     * @param dataSource
     * @return
     */
    public static Dialect getDialect(DataSource dataSource) {
        Dialect dialect = map.get(dataSource);
        if (dialect != null) {
            return dialect;
        }
        try {
            dialect = getDialect(dataSource.getConnection());
            map.put(dataSource, dialect);
            return dialect;
        } catch (SQLException e) {
            throw new RuntimeException("Can't get the database product name .");
        }
    }


    /**
     * 获取分页插件
     * @param conn
     * @return
     * @throws Exception
     */
    public static Dialect getDialect(Connection conn){
        String name = null;
        try {
            name = conn.getMetaData().getDatabaseProductName();

            if(StrUtil.isEmpty(name)){
                throw new RuntimeException("Get dialect error.");
            }
            Dialect dialect;
            if(SqlTypeConstant.ORACLE.equals(name.toLowerCase())){
                dialect = new OracleDialect();
            }else if(SqlTypeConstant.MYSQL.equals(name.toLowerCase())){
                dialect = new MySQLDialect();
            }else {
                throw new Exception("没有设置合适的数据库类型");
            }
            return dialect;
        } catch (Exception e) {
            throw new RuntimeException("Can't get the database product name .");
        }
    }


}
