package cn.net.vidyo.framework.builder.query;

import cn.net.vidyo.framework.builder.config.DataSourceConfig;
import cn.net.vidyo.framework.builder.domain.CusumExecuter;
import cn.net.vidyo.framework.builder.domain.DbType;
import cn.net.vidyo.framework.builder.domain.IDbQuery;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.sql.*;
import java.util.Optional;
import java.util.function.Consumer;

/**
 * Owns the single JDBC connection used while reading database metadata and offers small
 * helpers for running a query and handing each row to a callback.
 *
 * <p>The connection is opened lazily from {@link DataSourceConfig} (or from the optional
 * {@link DataSource} when one is present) and is re-opened transparently if it has been
 * closed in the meantime.
 */
public class ConnectManager {
    protected final Logger logger = LoggerFactory.getLogger(ConnectManager.class);

    /** Connection settings (URL, credentials, schema, database type, query dialect). */
    DataSourceConfig dataSourceConfig;

    /**
     * Optional data source. When non-null it is preferred over {@link DriverManager}.
     * NOTE(review): nothing in this class assigns it, so it is currently always {@code null}.
     *
     * @since 3.5.0
     */
    private DataSource dataSource;

    /** The cached connection; {@code null} until first use. */
    Connection connection;

    /** Schema that was last applied to the connection, if any. */
    private String schemaName;

    /**
     * Creates a manager bound to the given configuration. No connection is opened here.
     *
     * @param dataSourceConfig connection settings
     */
    public ConnectManager(DataSourceConfig dataSourceConfig) {
        this.dataSourceConfig = dataSourceConfig;
    }

    /**
     * Returns the cached connection, opening a new one when none exists or the cached one
     * has been closed. After opening, the configured (or database-default) schema is applied
     * on a best-effort basis.
     *
     * <p>The open-check and the creation happen under the same lock so that two concurrent
     * callers cannot both create (and one of them leak) a connection.
     *
     * @return an open connection
     * @throws RuntimeException wrapping the {@link SQLException} if the connection cannot be opened
     */
    public Connection getConn() {
        try {
            synchronized (this) {
                if (connection != null && !connection.isClosed()) {
                    return connection;
                }
                if (dataSource != null) {
                    connection = dataSource.getConnection();
                } else {
                    connection = DriverManager.getConnection(
                        dataSourceConfig.getUrl(),
                        dataSourceConfig.getUsername(),
                        dataSourceConfig.getPassword());
                }
                applySchema(connection);
                return connection;
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Applies the configured schema, falling back to the database default, to a freshly
     * opened connection. Failures are logged and otherwise ignored.
     */
    private void applySchema(Connection target) {
        String configured = dataSourceConfig.getSchemaName();
        String schema = StringUtils.isNotBlank(configured)
            ? configured
            : getDefaultSchema(dataSourceConfig.getDbType(), dataSourceConfig.getUsername());
        if (StringUtils.isBlank(schema)) {
            return;
        }
        schemaName = schema;
        try {
            target.setSchema(schemaName);
        } catch (Throwable t) {
            // Deliberately broad: presumably this also has to survive drivers that predate
            // Connection#setSchema and fail with an Error rather than an SQLException.
            logger.error("There may be exceptions in the driver and version of the database, " + t.getMessage(), t);
        }
    }

    /**
     * Returns the default schema for the given database type.
     *
     * @param dbType   database type
     * @param username login user name; used as the schema (upper-cased) for Oracle
     * @return the default schema, or {@code null} when the database type has none defined here
     *         (or, for Oracle, when {@code username} is {@code null})
     */
    protected String getDefaultSchema(DbType dbType, String username) {
        if (DbType.POSTGRE_SQL == dbType) {
            // PostgreSQL default schema
            return "public";
        }
        if (DbType.KINGBASE_ES == dbType) {
            // Kingbase default schema
            return "PUBLIC";
        }
        if (DbType.DB2 == dbType) {
            // DB2: use the current schema
            return "current schema";
        }
        if (DbType.ORACLE == dbType) {
            // Oracle: schema equals the user name. Locale.ROOT keeps the upper-casing stable
            // regardless of the JVM default locale.
            return username == null ? null : username.toUpperCase(java.util.Locale.ROOT);
        }
        return null;
    }

    /**
     * Runs a query and passes one wrapped object per result row to the callback.
     *
     * @param sql      SQL to execute
     * @param consumer creates the per-row wrapper and consumes it
     * @throws SQLException if preparing, executing or iterating the query fails
     */
    public void execute2(String sql, CusumExecuter consumer) throws SQLException {
        logger.debug("执行SQL:{}", sql);
        int count = 0;
        long start = System.nanoTime();
        // getConn() guarantees an open connection even if the caller never opened one.
        try (PreparedStatement preparedStatement = getConn().prepareStatement(sql);
             ResultSet resultSet = preparedStatement.executeQuery()) {
            while (resultSet.next()) {
                consumer.accept(consumer.createInstance(resultSet, dataSourceConfig.getDbQuery(), dataSourceConfig.getDbType()));
                count++;
            }
            long end = System.nanoTime();
            logger.debug("返回记录数:{},耗时(ms):{}", count, (end - start) / 1000000);
        }
    }

    /**
     * Runs a query and passes a {@link ResultSetWrapper} positioned on each row to the callback.
     *
     * @param sql      SQL to execute
     * @param consumer row handler
     * @throws SQLException if the query fails
     */
    public void execute(String sql, Consumer<ResultSetWrapper> consumer) throws SQLException {
        execute2(sql, new CusumExecuter<ResultSetWrapper>() {
            @Override
            public ResultSetWrapper createInstance(ResultSet resultSet, IDbQuery dbQuery, DbType dbType) {
                return new ResultSetWrapper(resultSet, dbQuery, dbType);
            }

            @Override
            public void accept(ResultSetWrapper o) {
                consumer.accept(o);
            }
        });
    }

    /**
     * Returns an open connection, creating or re-creating it when necessary.
     *
     * @return an open connection
     */
    public Connection getConnection() {
        return getConn();
    }

    /**
     * Closes the cached connection if there is one. A failure to close is logged, not thrown.
     */
    public void closeConnection() {
        Connection current = connection;
        if (current == null) {
            return;
        }
        try {
            current.close();
        } catch (SQLException sqlException) {
            logger.warn("Failed to close the database connection", sqlException);
        }
    }

    /**
     * Thin view over the current row of a {@link ResultSet} together with the query dialect
     * and database type needed to interpret it.
     */
    public static class ResultSetWrapper {

        private final IDbQuery dbQuery;

        private final ResultSet resultSet;

        private final DbType dbType;

        ResultSetWrapper(ResultSet resultSet, IDbQuery dbQuery, DbType dbType) {
            this.resultSet = resultSet;
            this.dbQuery = dbQuery;
            this.dbType = dbType;
        }

        /**
         * @return the underlying result set
         */
        public ResultSet getResultSet() {
            return resultSet;
        }

        /**
         * Reads a column of the current row as a string.
         *
         * @param columnLabel column label
         * @return the column value as returned by {@link ResultSet#getString(String)}
         * @throws RuntimeException wrapping the {@link SQLException} if the column cannot be read
         */
        public String getStringResult(String columnLabel) {
            try {
                return resultSet.getString(columnLabel);
            } catch (SQLException sqlException) {
                throw new RuntimeException(String.format("读取[%s]字段出错!", columnLabel), sqlException);
            }
        }

        /**
         * @return the field comment of the current row
         * @deprecated 3.5.3
         */
        @Deprecated
        public String getFiledComment() {
            return getComment(dbQuery.fieldComment());
        }

        /**
         * Reads and formats the comment stored in the given column.
         *
         * @param columnLabel column holding the comment
         * @return the formatted comment, or an empty string when {@code columnLabel} is blank
         * @deprecated 3.5.3
         */
        @Deprecated
        private String getComment(String columnLabel) {
            if (StringUtils.isBlank(columnLabel)) {
                return "";
            }
            return formatComment(getStringResult(columnLabel));
        }

        /**
         * @return the table comment of the current row
         * @deprecated 3.5.3
         */
        @Deprecated
        public String getTableComment() {
            return getComment(dbQuery.tableComment());
        }

        /**
         * Replaces every CRLF in the comment with a tab.
         *
         * @param comment raw comment
         * @return the formatted comment, or an empty string when {@code comment} is blank
         * @deprecated 3.5.3
         */
        @Deprecated
        public String formatComment(String comment) {
            // Literal replacement; no regular expression is needed.
            return StringUtils.isBlank(comment) ? "" : comment.replace("\r\n", "\t");
        }

        /**
         * Tells whether the current row describes a primary-key column. DB2, SQLite and
         * ClickHouse mark it with {@code "1"}; every other database type is matched against
         * {@code "PRI"} (case-insensitive).
         *
         * @return {@code true} if the current row is a primary-key column
         * @deprecated 3.5.3
         */
        @Deprecated
        public boolean isPrimaryKey() {
            String key = this.getStringResult(dbQuery.fieldKey());
            if (DbType.DB2 == dbType || DbType.SQLITE == dbType || DbType.CLICK_HOUSE == dbType) {
                return "1".equals(key);
            }
            return "PRI".equalsIgnoreCase(key);
        }
    }
}
