package cn.hperfect.nbquerier.core.components.executor;

import cn.hperfect.nbquerier.config.NbQuerierCons;
import cn.hperfect.nbquerier.config.properties.NbQuerierProperties;
import cn.hperfect.nbquerier.core.components.builder.INbSqlBuilder;
import cn.hperfect.nbquerier.core.components.executor.options.DoQueryOptions;
import cn.hperfect.nbquerier.core.components.executor.options.DoUpdateOptions;
import cn.hperfect.nbquerier.core.components.result.IResultSetHandler;
import cn.hperfect.nbquerier.core.components.result.json.NbJsonSerializer;
import cn.hperfect.nbquerier.core.components.type.INbQueryType;
import cn.hperfect.nbquerier.core.components.type.NbQueryType;
import cn.hperfect.nbquerier.core.metedata.INbExecuteBatch;
import cn.hperfect.nbquerier.core.metedata.QueryValParam;
import cn.hperfect.nbquerier.core.querier.NbQuerier;
import cn.hperfect.nbquerier.core.transaction.INbTransaction;
import cn.hperfect.nbquerier.core.type.JsonNbType;
import cn.hperfect.nbquerier.enums.DbType;
import cn.hperfect.nbquerier.enums.ResultType;
import cn.hperfect.nbquerier.exceptions.NbSQLException;
import cn.hperfect.nbquerier.exceptions.NbSQLExecuteException;
import cn.hperfect.nbquerier.exceptions.NbSQLMessageException;
import cn.hperfect.nbquerier.exceptions.TypeConvertException;
import cn.hperfect.nbquerier.toolkit.SqlLoggerUtils;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.convert.Convert;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.date.TimeInterval;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ReUtil;
import cn.hutool.core.util.StrUtil;
import com.impossibl.postgres.api.jdbc.PGType;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.postgresql.util.PGobject;

import java.sql.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Default {@link INbExecutor} implementation on top of plain JDBC.
 * <p>
 * Obtains a connection from the {@link INbTransaction}, rewrites the SQL placeholders into JDBC
 * {@code ?} markers, binds the parameters, executes the statement and hands query results to the
 * {@link IResultSetHandler}. Transaction handling (auto-commit, commit, rollback, close) is
 * delegated to the {@link INbTransaction}.
 * <p>
 * TODO (author's note): known bug — transactions are only supported when the data-source name is empty.
 *
 * @author huanxi
 * @version 1.0
 * @date 2021/11/24 10:50 AM
 */
@Slf4j
public class DefaultNbExecutor implements INbExecutor {

    /**
     * Maximum number of rows put into a single INSERT by {@link #insertBatch}.
     */
    private static final int INSERT_BATCH_SIZE = 1000;

    private final NbQuerierProperties config;
    private final INbTransaction transaction;
    private final IResultSetHandler resultSetHandler;

    private final INbSqlBuilder sqlBuilder;

    private final NbJsonSerializer jsonSerializer;


    /**
     * @param tx               source of connections; also used for auto-commit/commit/rollback and closed after operations
     * @param config           querier settings (database type, SQL logging)
     * @param nbSqlBuilder     builds the INSERT statements used by {@link #insertBatch}
     * @param resultSetHandler converts result sets for {@code LIST} queries
     * @param jsonSerializer   serializes JSON parameters for PostgreSQL drivers
     */
    public DefaultNbExecutor(INbTransaction tx, NbQuerierProperties config, INbSqlBuilder nbSqlBuilder, IResultSetHandler resultSetHandler, NbJsonSerializer jsonSerializer) {
        this.transaction = tx;
        this.config = config;
        this.sqlBuilder = nbSqlBuilder;
        this.resultSetHandler = resultSetHandler;
        this.jsonSerializer = jsonSerializer;
    }

    /**
     * SQL rewritten with JDBC {@code ?} markers, plus the parameters in marker order.
     */
    @Data
    @AllArgsConstructor
    static class PrepareSql {
        private String sql;
        private List<QueryValParam> params;
    }

    /**
     * Executes a query and converts the result according to {@code options.getResultType()}.
     * <ul>
     *   <li>{@code INT}, {@code LONG}, {@code DOUBLE}: first column of the first row, or 0 when there is no row;</li>
     *   <li>{@code STRING}: first column of the first row, or {@code null} when there is no row;</li>
     *   <li>{@code LIST}: whatever {@link IResultSetHandler#toMap(ResultSet)} returns.</li>
     * </ul>
     * The result set, statement and transaction are closed afterwards. The SQL is logged only on success and
     * only when {@code config.isLogSql()} is true.
     *
     * @param options data-source name, SQL, parameters and expected result type
     * @param <T>     caller-side result type; the cast is unchecked (TODO: derive it from the result type)
     * @return the converted result
     * @throws NbSQLExecuteException wrapping any failure (including unsupported result types), with the SQL
     *                               and parameters in the message
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T doQuery(DoQueryOptions options) {
        String dsName = options.getDsName();
        ResultType resultType = options.getResultType();
        List<QueryValParam> params = options.getParams();
        String sql = options.getSql();
        TimeInterval timer = DateUtil.timer();
        T result = null;
        Connection connection = null;
        ResultSet resultSet = null;
        PreparedStatement preparedStatement = null;
        try {
            connection = this.transaction.getConnection(dsName);
            PrepareSql prepareSql = getPrepareSql(sql, params);
            preparedStatement = connection.prepareStatement(prepareSql.getSql());
            setParams(prepareSql.getParams(), connection, preparedStatement);
            resultSet = preparedStatement.executeQuery();
            switch (resultType) {
                case INT:
                    if (resultSet.next()) {
                        result = (T) Convert.toInt(resultSet.getInt(1));
                    } else {
                        result = (T) Convert.toInt(0);
                    }
                    break;
                case LONG:
                    if (resultSet.next()) {
                        result = (T) Convert.toLong(resultSet.getLong(1));
                    } else {
                        result = (T) Convert.toLong(0);
                    }
                    break;
                case LIST:
                    result = (T) resultSetHandler.toMap(resultSet);
                    break;
                case STRING:
                    if (resultSet.next()) {
                        result = (T) Convert.toStr(resultSet.getString(1));
                    }
                    break;
                case DOUBLE:
                    if (resultSet.next()) {
                        result = (T) Convert.toDouble(resultSet.getDouble(1), 0d);
                    } else {
                        result = (T) Convert.toDouble(0);
                    }
                    break;
                default:
                    throw new NbSQLException("为实现该类型查询:{}", resultType);
            }
        } catch (Throwable throwable) {
            throw new NbSQLExecuteException(SqlLoggerUtils.logSql(sql, params), throwable);
        } finally {
            // close dependents before releasing the connection
            IoUtil.close(resultSet);
            IoUtil.close(preparedStatement);
            IoUtil.close(transaction);
        }

        if (config.isLogSql()) {
            log.info(SqlLoggerUtils.logSql(timer, sql, params));
        }
        return result;
    }

    /**
     * Debugging helper: runs {@code sql} (no parameters) with a {@code null} data-source name and logs
     * every value produced by the result-set handler. The result set, statement and transaction are closed
     * afterwards.
     *
     * @param sql SQL without placeholder parameters
     * @param <T> unused; this method always returns {@code null}
     * @return always {@code null}
     * @throws NbSQLExecuteException wrapping any failure
     * @deprecated use {@link #doQuery(DoQueryOptions)}, which returns the result instead of logging it
     */
    @Override
    @SuppressWarnings("unchecked")
    @Deprecated
    public <T> T doQuery(String sql) {
        Connection connection = null;
        List<QueryValParam> params = ListUtil.empty();
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        try {
            connection = this.transaction.getConnection(null);
            PrepareSql prepareSql = getPrepareSql(sql, params);
            preparedStatement = connection.prepareStatement(prepareSql.getSql(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
            setParams(prepareSql.getParams(), connection, preparedStatement);
            resultSet = preparedStatement.executeQuery();
            // NOTE(review): if toMap iterates the result set itself, the next() call here skips rows — confirm
            while (resultSet.next()) {
                T result = (T) resultSetHandler.toMap(resultSet);
                log.info("查出结果:{}", result);
            }
        } catch (Throwable throwable) {
            throw new NbSQLExecuteException(SqlLoggerUtils.logSql(sql, params), throwable);
        } finally {
            // the result set used to be leaked here; close it before the statement and connection
            IoUtil.close(resultSet);
            IoUtil.close(preparedStatement);
            IoUtil.close(transaction);
        }
        return null;
    }

    /**
     * Replaces every placeholder matched by {@link NbQuerierCons#SQL_PARAM_PATTERN} with a JDBC {@code ?}
     * marker and reorders the parameters to follow the marker order.
     * <p>
     * Group 1 of the pattern is the index into {@code params}. A placeholder used several times produces
     * several markers bound to the same parameter. Each match is replaced exactly where it was found, in a
     * single left-to-right pass, so a placeholder that is a textual prefix of another one cannot corrupt it.
     * When {@code sql} is {@code null} or {@code params} is empty, the SQL is returned unchanged with an
     * empty parameter list.
     *
     * @param sql    SQL containing placeholders
     * @param params parameters referenced by placeholder index
     * @return the rewritten SQL and the reordered parameters
     * @throws IllegalArgumentException  if a matched placeholder has no index group
     * @throws IndexOutOfBoundsException if a placeholder index is outside {@code params}
     */
    public PrepareSql getPrepareSql(String sql, List<QueryValParam> params) {
        List<QueryValParam> paramsNew = new ArrayList<>();
        if (sql == null || CollUtil.isEmpty(params)) {
            return new PrepareSql(sql, paramsNew);
        }
        List<String> placeholders = ReUtil.findAll(NbQuerierCons.SQL_PARAM_PATTERN, sql, 0);
        StringBuilder rewritten = new StringBuilder(sql.length());
        int cursor = 0;
        for (String placeholder : placeholders) {
            String indexStr = ReUtil.get(NbQuerierCons.SQL_PARAM_PATTERN, placeholder, 1);
            Assert.notBlank(indexStr, "参数替换失败，占位符不合法:{}", placeholder);
            // matches come back in order and never overlap, so the first occurrence at/after the cursor
            // is this match; replace only it rather than every occurrence in the SQL
            int start = sql.indexOf(placeholder, cursor);
            rewritten.append(sql, cursor, start).append('?');
            cursor = start + placeholder.length();
            int index = Convert.toInt(indexStr);
            paramsNew.add(params.get(index));
        }
        rewritten.append(sql, cursor, sql.length());
        return new PrepareSql(rewritten.toString(), paramsNew);
    }

    /**
     * Binds {@code params} to {@code preparedStatement} in order (JDBC index 1, 2, ...).
     * <p>
     * {@code ARRAY} values become a JDBC {@link Array} of the element type's SQL name, {@code JSON} values are
     * bound according to the configured database type, and {@code GENERAL} values are bound after
     * {@code type.convert}.
     *
     * @param params            parameters in marker order; may be null or empty
     * @param connection        used to create JDBC arrays
     * @param preparedStatement statement to bind
     * @throws TypeConvertException  if a value cannot be converted
     * @throws SQLException          if binding fails
     * @throws NbSQLMessageException for unhandled data types or invalid JSON values
     */
    private void setParams(List<QueryValParam> params, Connection connection, PreparedStatement preparedStatement) throws TypeConvertException, SQLException {
        if (CollUtil.isEmpty(params)) {
            return;
        }
        // iterate rather than params.get(i): the list may be a LinkedList, where get(i) is O(n)
        int parameterIndex = 1;
        for (QueryValParam param : params) {
            INbQueryType type = param.getType();
            Assert.notNull(type, "设置参数的type不能为空");
            Assert.notNull(type.getDbDataType(), "type:{},dbType不能为空", type);
            // ARRAY and JSON need special handling; GENERAL is bound directly
            switch (type.getDbDataType()) {
                case ARRAY:
                    NbQueryType arraySubType = type.getArraySubType();
                    Object[] elements = (Object[]) type.convert(param.getValue());
                    Array arrayObj = connection.createArrayOf(arraySubType.getDbTypeSql(), elements);
                    preparedStatement.setObject(parameterIndex, arrayObj);
                    break;
                case JSON:
                    setJsonParam(type, param.getValue(), parameterIndex, preparedStatement);
                    break;
                case GENERAL:
                    preparedStatement.setObject(parameterIndex, type.convert(param.getValue()));
                    break;
                default:
                    throw new NbSQLMessageException("未处理的数据类型");
            }
            parameterIndex++;
        }
    }

    /**
     * Binds one JSON parameter in the form the configured database expects.
     * <p>
     * When the value is non-null and the JSON type declares whether it is an array, the value's shape
     * (collection/array or not) must agree with that declaration.
     *
     * @param type              the parameter type; must be a {@link JsonNbType}
     * @param value             raw parameter value
     * @param parameterIndex    1-based JDBC parameter index
     * @param preparedStatement statement to bind
     * @throws NbSQLMessageException if the value's shape is wrong or the database type is unsupported
     */
    private void setJsonParam(INbQueryType type, Object value, int parameterIndex, PreparedStatement preparedStatement) throws TypeConvertException, SQLException {
        JsonNbType jsonNbType = (JsonNbType) type;
        if (value != null && jsonNbType.getJsonArray() != null) {
            boolean isArrayValue = value instanceof Collection || value.getClass().isArray();
            boolean jsonArray = jsonNbType.isJsonArray();
            if (jsonArray && !isArrayValue) {
                throw new NbSQLMessageException("json值设置错误,值必须为数组");
            }
            if (!jsonArray && isArrayValue) {
                throw new NbSQLMessageException("json值设置错误,值不能为数组");
            }
        }

        DbType dbType = config.getDbType();
        if (dbType == DbType.PG_NG) {
            // pgjdbc-ng: bind the serialized JSON text with an explicit JSON SQL type
            Object obj = type.convert(value);
            preparedStatement.setObject(parameterIndex, jsonSerializer.serialize(obj), PGType.JSON);
        } else if (dbType == DbType.POSTGRE_SQL) {
            // official PostgreSQL driver: wrap the serialized JSON (array or object) in a jsonb PGobject
            Object obj = type.convert(value);
            PGobject pGobject = new PGobject();
            pGobject.setValue(jsonSerializer.serialize(obj));
            pGobject.setType("jsonb");
            preparedStatement.setObject(parameterIndex, pGobject);
        } else if (dbType == DbType.MYSQL) {
            preparedStatement.setObject(parameterIndex, type.convert(value));
        } else {
            // report the unsupported database type (previously the query type was formatted here)
            throw new NbSQLMessageException("不支持数据库类型:{},保存json字段", dbType);
        }
    }

    /**
     * Inserts {@code maps} in groups of at most {@value #INSERT_BATCH_SIZE} rows, one INSERT per group built
     * by {@link INbSqlBuilder#buildSaveSql}.
     * <p>
     * A single group is executed directly. Several groups are executed with auto-commit disabled, then
     * committed together, or rolled back on failure; the transaction is closed afterwards. The last
     * {@code doUpdate} argument is {@code true} for the single group and {@code false} per group otherwise
     * (presumably the auto-close flag — confirm in {@link INbExecutor}).
     *
     * @param querier query context passed to the SQL builder; supplies the data-source name and parameters
     * @param maps    rows to insert; null or empty inserts nothing
     * @param params  not read by this implementation — parameters come from {@code querier.getParams()}
     * @return the sum of the update counts
     * @throws NbSQLExecuteException if a multi-group insert fails; a rollback failure is attached as suppressed
     */
    @Override
    public int insertBatch(NbQuerier<?> querier, List<Map<String, Object>> maps, List<QueryValParam> params) {
        if (CollUtil.isEmpty(maps)) {
            return 0;
        }
        List<List<Map<String, Object>>> groups = CollUtil.split(maps, INSERT_BATCH_SIZE);
        if (groups.size() == 1) {
            String sql = sqlBuilder.buildSaveSql(querier, groups.get(0));
            return this.doUpdate(querier.getDsName(), sql, querier.getParams(), true);
        }

        int sum = 0;
        try {
            transaction.setAutoCommit(false);
            for (List<Map<String, Object>> group : groups) {
                String sql = sqlBuilder.buildSaveSql(querier, group);
                sum += this.doUpdate(querier.getDsName(), sql, querier.getParams(), false);
            }
            transaction.commit();
        } catch (Exception exception) {
            NbSQLExecuteException failure = new NbSQLExecuteException("批量保存错误", exception);
            try {
                transaction.rollback();
            } catch (SQLException rollbackError) {
                // keep the insert failure as the primary cause instead of replacing it
                failure.addSuppressed(rollbackError);
            }
            throw failure;
        } finally {
            IoUtil.close(transaction);
        }
        return sum;
    }

    /**
     * Executes an INSERT/UPDATE/DELETE, or adds it to a JDBC batch when {@code options.getBatch()} is set.
     * <p>
     * Batch mode disables auto-commit, reuses the batch's prepared statement (creating and storing it on
     * first use), binds the parameters and calls {@link INbExecuteBatch#addBatch()}; the statement stays open
     * for the batch and 0 is returned. Otherwise the statement is executed immediately and closed.
     * The transaction is closed afterwards when {@code options.isAutoClose()} is true. The SQL is logged only
     * on success and only when {@code config.isLogSql()} is true.
     * <p>
     * Placeholder errors from {@link #getPrepareSql} are raised before the statement is prepared and
     * propagate unwrapped.
     *
     * @param options data-source name, SQL, parameters, auto-close flag and optional batch
     * @return the update count, or 0 in batch mode
     * @throws NbSQLMessageException rethrown unchanged; also raised for MySQL
     *                               "Field '...' doesn't have a default value" errors, naming the field
     * @throws NbSQLExecuteException wrapping any other failure, with the SQL and parameters in the message
     */
    @Override
    public int doUpdate(DoUpdateOptions options) {
        String dsName = options.getDsName();
        boolean autoClose = options.isAutoClose();
        TimeInterval timer = DateUtil.timer();
        PrepareSql prepareSql = getPrepareSql(options.getSql(), options.getParams());
        INbExecuteBatch batch = options.getBatch();
        int result = 0;
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        boolean isBatch = batch != null;

        try {
            connection = transaction.getConnection(dsName);
            if (isBatch) {
                transaction.setAutoCommit(false);
                preparedStatement = batch.getPreparedStatement();
                if (preparedStatement == null) {
                    preparedStatement = connection.prepareStatement(prepareSql.getSql());
                    batch.setPreparedStatement(preparedStatement);
                }
                // TODO: verify that a reused statement was prepared for the same SQL
                setParams(prepareSql.getParams(), connection, preparedStatement);
                batch.addBatch();
            } else {
                preparedStatement = connection.prepareStatement(prepareSql.getSql());
                setParams(prepareSql.getParams(), connection, preparedStatement);
                result = preparedStatement.executeUpdate();
            }
        } catch (Throwable throwable) {
            if (throwable instanceof NbSQLMessageException) {
                throw (NbSQLMessageException) throwable;
            }
            if (StrUtil.isNotBlank(throwable.getMessage())) {
                if (config.getDbType() == DbType.MYSQL) {
                    // translate MySQL's missing-default error into a "field is required" message
                    String message = ReUtil.get("Field '(\\S+)' doesn't have a default value", throwable.getMessage(), 1);
                    if (StrUtil.isNotBlank(message)) {
                        throw new NbSQLMessageException("字段:{}为必填项", message);
                    }
                }
            }
            throw new NbSQLExecuteException(SqlLoggerUtils.logSql(options.getSql(), options.getParams()), throwable);
        } finally {
            // close the statement before releasing its connection; batch statements are owned by the batch
            if (!isBatch) {
                IoUtil.close(preparedStatement);
            }
            if (autoClose) {
                IoUtil.close(transaction);
            }
        }
        if (config.isLogSql()) {
            log.info(SqlLoggerUtils.logSql(timer, options.getSql(), options.getParams()));
        }
        return result;
    }
}
