package cn.mvapi.data.mysql.service.impl;

import cn.hutool.core.util.StrUtil;
import cn.mvapi.data.mysql.config.MySqlCreateConfig;
import cn.mvapi.data.mysql.dao.CreateTableMapper;
import cn.mvapi.data.mysql.pojo.enums.SqlOperateTypeEnum;
import cn.mvapi.data.mysql.pojo.vo.CreateSqlVO;
import cn.mvapi.data.mysql.service.TableInitService;
import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLDataType;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLAlterTableStatement;
import com.alibaba.druid.sql.ast.statement.SQLColumnDefinition;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlAlterTableChangeColumn;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlCreateTableStatement;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Synchronizes the MySQL schema with the CREATE TABLE statements declared by the registered
 * {@link TableInitService} beans, then lets every service initialize its table and data.
 *
 * <p>The declared SQL and the DDL currently stored in the database (the {@code "Create Table"}
 * value returned by {@link CreateTableMapper#getTableSql}) are both parsed with Druid and diffed
 * table-by-table and column-by-column. Which kinds of change are executed is controlled by the
 * create/update/delete flags of {@link MySqlCreateConfig}.
 *
 * @author shibaolin
 */
@Service
public class CreateTableServiceImpl {

    @Autowired
    private List<TableInitService> tableInitServices;
    @Autowired
    private CreateTableMapper createTableMapper;
    @Autowired
    private MySqlCreateConfig mySqlCreateConfig;

    /**
     * Startup hook: executes the schema changes permitted by the configuration, then calls
     * {@code initTable()} on every service, followed by {@code initData()} on every service.
     */
    @PostConstruct
    public void autoCreateTable() {
        // Boolean.TRUE.equals treats an unset (null) flag as "disabled" instead of failing on unboxing.
        boolean autoUpdate = Boolean.TRUE.equals(mySqlCreateConfig.getIsAutoUpdateTable());
        boolean autoCreate = Boolean.TRUE.equals(mySqlCreateConfig.getIsAutoCreateTable());
        boolean autoDelete = Boolean.TRUE.equals(mySqlCreateConfig.getIsAutoDeleteTable());
        if (autoUpdate || autoCreate || autoDelete) {
            List<CreateSqlVO> changeSql = getChangeSql().stream()
                    .filter(item -> isOperationEnabled(item, autoCreate, autoUpdate, autoDelete))
                    .collect(Collectors.toList());
            for (CreateSqlVO change : changeSql) {
                createTableMapper.execTableSql(change.getTableSql());
            }
        }
        // Initialize tables
        for (TableInitService tableInitService : tableInitServices) {
            tableInitService.initTable();
        }
        // Initialize data
        for (TableInitService tableInitService : tableInitServices) {
            tableInitService.initData();
        }
    }

    /**
     * Decides whether a generated change may run under the current configuration.
     * DELETE_TABLE is governed by the delete flag, CREATE_TABLE by the create flag, and every
     * other (column-level) operation by the update flag.
     */
    private boolean isOperationEnabled(CreateSqlVO item, boolean autoCreate, boolean autoUpdate, boolean autoDelete) {
        Object operateType = item.getOperateType();
        if (operateType.equals(SqlOperateTypeEnum.DELETE_TABLE.getCode())) {
            return autoDelete;
        }
        if (operateType.equals(SqlOperateTypeEnum.CREATE_TABLE.getCode())) {
            return autoCreate;
        }
        return autoUpdate;
    }

    /**
     * Computes every change needed to bring the database in line with the declared SQL.
     *
     * @return column updates for existing tables, then CREATE TABLE for missing tables, then
     *         DROP TABLE for tables that are no longer declared (not filtered by configuration)
     */
    public List<CreateSqlVO> getChangeSql() {
        // Parse all declared statements
        List<SQLStatement> sqlStatements = this.getNewSql();
        // Declared CREATE TABLE statements, keyed by normalized table name
        Map<String, MySqlCreateTableStatement> allNewCreateSql = this.getNewCreateSql(sqlStatements);
        // Current DDL of every table in the database, keyed by normalized table name
        Map<String, MySqlCreateTableStatement> dbSql = this.getDbSql();
        // Tables to create
        List<String> newTable = allNewCreateSql.keySet().stream()
                .filter(x -> !dbSql.containsKey(x))
                .collect(Collectors.toList());
        // Tables to drop
        List<String> deleteTable = dbSql.keySet().stream()
                .filter(x -> !allNewCreateSql.containsKey(x))
                .collect(Collectors.toList());
        // Tables that exist on both sides and may need column changes
        List<String> modifyTable = dbSql.keySet().stream()
                .filter(allNewCreateSql::containsKey)
                .collect(Collectors.toList());
        List<CreateSqlVO> tableUpdateSql = getTableUpdateSql(modifyTable, allNewCreateSql, dbSql);

        List<CreateSqlVO> allSql = new ArrayList<>();
        allSql.addAll(tableUpdateSql);
        allSql.addAll(getAddTableSql(newTable, allNewCreateSql));
        allSql.addAll(getDelSql(deleteTable, dbSql));
        return allSql;
    }

    /**
     * Builds DROP TABLE changes. The SQL uses the table name exactly as the database reports it
     * (original case and quoting); the VO keeps the normalized key.
     */
    private Collection<? extends CreateSqlVO> getDelSql(List<String> deleteTable,
                                                        Map<String, MySqlCreateTableStatement> dbSql) {
        return deleteTable.stream().map(tableName -> CreateSqlVO.builder()
                .tableName(tableName)
                .operateType(SqlOperateTypeEnum.DELETE_TABLE.getCode())
                .operateContent(SqlOperateTypeEnum.DELETE_TABLE.getName())
                .riskLevel(SqlOperateTypeEnum.DELETE_TABLE.getRiskLevel())
                .tableSql(getDelSqlByTableName(dbSql.get(tableName).getTableName()))
                .build())
                .collect(Collectors.toList());
    }

    private String getDelSqlByTableName(String tableName) {
        return StrUtil.format("drop table {}", tableName);
    }

    /**
     * Builds CREATE TABLE changes by re-emitting the declared (parsed) statements.
     */
    private Collection<? extends CreateSqlVO> getAddTableSql(List<String> newTable,
                                                             Map<String, MySqlCreateTableStatement> allNewCreateSql) {
        return newTable.stream().map(allNewCreateSql::get)
                .map(newCreateSql -> CreateSqlVO.builder()
                        .tableName(newCreateSql.getTableName())
                        .operateType(SqlOperateTypeEnum.CREATE_TABLE.getCode())
                        .operateContent(SqlOperateTypeEnum.CREATE_TABLE.getName())
                        .riskLevel(SqlOperateTypeEnum.CREATE_TABLE.getRiskLevel())
                        .tableSql(newCreateSql.toString())
                        .build())
                .collect(Collectors.toList());
    }

    /**
     * Collects the column-level changes for every table present both in the declared SQL and in
     * the database.
     *
     * @param modifyTable     normalized names of tables present on both sides
     * @param allNewCreateSql declared CREATE TABLE statements keyed by normalized name
     * @param dbSql           database CREATE TABLE statements keyed by normalized name
     * @return the column changes, table by table
     */
    private List<CreateSqlVO> getTableUpdateSql(List<String> modifyTable,
                                                Map<String, MySqlCreateTableStatement> allNewCreateSql,
                                                Map<String, MySqlCreateTableStatement> dbSql) {
        List<CreateSqlVO> allSql = new ArrayList<>();
        for (String tableName : modifyTable) {
            MySqlCreateTableStatement newCreateSql = allNewCreateSql.get(tableName);
            MySqlCreateTableStatement dbCreateSql = dbSql.get(tableName);
            allSql.addAll(getUpdateTableSql(tableName, newCreateSql, dbCreateSql));
        }
        return allSql;
    }

    /**
     * Diffs the columns of one table: ADD for declared-only columns, MODIFY for changed columns,
     * DROP for database-only columns (in that order).
     */
    private List<CreateSqlVO> getUpdateTableSql(String tableName, MySqlCreateTableStatement newCreateSql,
                                                MySqlCreateTableStatement dbCreateSql) {
        if (newCreateSql == null || dbCreateSql == null) {
            return new ArrayList<>();
        }
        // Generated DDL must use the name as the database reports it (case, backticks);
        // tableName is only the lower-cased, backtick-free lookup key.
        String sqlTableName = dbCreateSql.getTableName();
        Map<String, SQLColumnDefinition> newColumnMap = toColumnMap(newCreateSql);
        Map<String, SQLColumnDefinition> dbColumnMap = toColumnMap(dbCreateSql);
        // Columns to add (declaration order is preserved by the LinkedHashMap)
        List<SQLColumnDefinition> addColumnList = newColumnMap.keySet().stream()
                .filter(item -> !dbColumnMap.containsKey(item))
                .map(newColumnMap::get)
                .collect(Collectors.toList());
        // Columns to drop
        List<SQLColumnDefinition> deleteColumnList = dbColumnMap.keySet().stream()
                .filter(item -> !newColumnMap.containsKey(item))
                .map(dbColumnMap::get)
                .collect(Collectors.toList());
        // Columns that may need modification
        List<String> updateColumnName = newColumnMap.keySet().stream()
                .filter(dbColumnMap::containsKey)
                .collect(Collectors.toList());

        List<CreateSqlVO> allSql = new ArrayList<>();
        allSql.addAll(getAddColumnSql(tableName, sqlTableName, addColumnList));
        allSql.addAll(getUpdateColumnSql(tableName, sqlTableName, updateColumnName, newColumnMap, dbColumnMap));
        allSql.addAll(getDeleteColumnSql(tableName, sqlTableName, deleteColumnList));
        return allSql;
    }

    /**
     * Indexes a table's column definitions by normalized column name, keeping declaration order.
     * If a name repeats, the later definition wins.
     */
    private Map<String, SQLColumnDefinition> toColumnMap(MySqlCreateTableStatement createSql) {
        return createSql.getColumnDefinitions().stream()
                .collect(Collectors.toMap(e -> normalizeName(e.getColumnName()), item -> item,
                        (oldValue, newValue) -> newValue, LinkedHashMap::new));
    }

    /**
     * Normalizes a table or column identifier for comparison: lower case (Locale.ROOT) with
     * backticks removed.
     */
    private static String normalizeName(String name) {
        return name.toLowerCase(Locale.ROOT).replace("`", "");
    }

    /**
     * Builds MODIFY COLUMN changes for columns whose type, comment or default differs.
     * The operation type reflects the first difference found, checked in that order.
     * The MODIFY statement restates the full column definition, because MySQL resets any
     * attribute (e.g. NOT NULL, AUTO_INCREMENT) that MODIFY COLUMN does not repeat.
     */
    private List<CreateSqlVO> getUpdateColumnSql(String tableName, String sqlTableName, List<String> updateColumnName,
                                                 Map<String, SQLColumnDefinition> newColumnMap,
                                                 Map<String, SQLColumnDefinition> dbColumnMap) {
        List<CreateSqlVO> updateColumnSql = new ArrayList<>();
        for (String columnName : updateColumnName) {
            // Declared column
            SQLColumnDefinition newColumn = newColumnMap.get(columnName);
            // Column as it currently exists in the database
            SQLColumnDefinition dbColumn = dbColumnMap.get(columnName);

            // Identical definitions need no change
            if (newColumn.equals(dbColumn)) {
                continue;
            }

            SqlOperateTypeEnum sqlOperateType = null;
            if (!sqlColumnTypeEquals(newColumn.getDataType(), dbColumn.getDataType())) {
                sqlOperateType = SqlOperateTypeEnum.MODIFY_COLUMN_TYPE;
            } else if (!sqlEquals(newColumn.getComment(), dbColumn.getComment())) {
                sqlOperateType = SqlOperateTypeEnum.MODIFY_COLUMN_COMMENT;
            } else if (!sqlEquals(newColumn.getDefaultExpr(), dbColumn.getDefaultExpr())) {
                sqlOperateType = SqlOperateTypeEnum.MODIFY_COLUMN_DEFAULT;
            }
            if (sqlOperateType == null) {
                continue;
            }

            String sql = "ALTER TABLE " + sqlTableName + " MODIFY COLUMN " + buildColumnDefinition(newColumn, true);
            updateColumnSql.add(CreateSqlVO.builder()
                    .tableSql(sql)
                    .operateType(sqlOperateType.getCode())
                    .operateContent(sqlOperateType.getName())
                    .riskLevel(sqlOperateType.getRiskLevel())
                    .tableName(tableName)
                    .build());
        }
        return updateColumnSql;
    }

    /**
     * Renders a column definition: name, type, NOT NULL, DEFAULT, optional AUTO_INCREMENT, and
     * COMMENT. Other attributes (e.g. ON UPDATE, charset/collation) are not rendered.
     *
     * @param column               the parsed column definition
     * @param includeAutoIncrement whether to emit AUTO_INCREMENT when the column declares it;
     *                             MySQL rejects it on a new column that is not part of a key
     */
    private String buildColumnDefinition(SQLColumnDefinition column, boolean includeAutoIncrement) {
        StringBuilder sb = new StringBuilder(column.getNameAsString());
        if (column.getDataType() != null) {
            sb.append(' ').append(column.getDataType());
        }
        if (column.containsNotNullConstaint()) {
            sb.append(" NOT NULL");
        }
        if (column.getDefaultExpr() != null) {
            sb.append(" DEFAULT ").append(column.getDefaultExpr());
        }
        if (includeAutoIncrement && column.isAutoIncrement()) {
            sb.append(" AUTO_INCREMENT");
        }
        // Only emit COMMENT when one exists; "COMMENT null" is a syntax error
        if (column.getComment() != null) {
            sb.append(" COMMENT ").append(column.getComment());
        }
        return sb.toString();
    }

    /**
     * Decides whether two column types are equivalent.
     * Names are compared case-insensitively. Arguments (length/precision/scale) are compared only
     * for the positions both types declare, so e.g. {@code int} matches {@code int(11)} and
     * {@code decimal(10)} matches {@code decimal(10,0)}, but {@code decimal(10,2)} does not match
     * {@code decimal(10,4)}.
     *
     * @param dataType  declared type
     * @param dataType1 database type
     * @return true when the types are considered the same
     */
    private boolean sqlColumnTypeEquals(SQLDataType dataType, SQLDataType dataType1) {
        if (dataType == null || dataType1 == null) {
            return dataType == dataType1;
        }
        if (dataType.equals(dataType1)) {
            return true;
        }
        if (!dataType.getName().equalsIgnoreCase(dataType1.getName())) {
            return false;
        }
        List<?> args = dataType.getArguments();
        List<?> args1 = dataType1.getArguments();
        int common = Math.min(args.size(), args1.size());
        for (int i = 0; i < common; i++) {
            if (!args.get(i).toString().equalsIgnoreCase(args1.get(i).toString())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds DROP COLUMN changes for columns that exist only in the database.
     */
    private List<CreateSqlVO> getDeleteColumnSql(String tableName, String sqlTableName,
                                                 List<SQLColumnDefinition> deleteColumnList) {
        return deleteColumnList.stream().map(column -> {
            CreateSqlVO createSqlVO = new CreateSqlVO();
            createSqlVO.setTableName(tableName);
            createSqlVO.setOperateType(SqlOperateTypeEnum.DELETE_COLUMN.getCode());
            createSqlVO.setOperateContent(SqlOperateTypeEnum.DELETE_COLUMN.getName());
            createSqlVO.setRiskLevel(SqlOperateTypeEnum.DELETE_COLUMN.getRiskLevel());
            createSqlVO.setTableSql(getDeleteColumnSqlStr(sqlTableName, column));
            return createSqlVO;
        }).collect(Collectors.toList());
    }

    private String getDeleteColumnSqlStr(String tableName, SQLColumnDefinition column) {
        return "ALTER TABLE " + tableName + " DROP COLUMN " + column.getNameAsString();
    }

    /**
     * Builds ADD COLUMN changes for columns that are declared but missing from the database.
     */
    private List<CreateSqlVO> getAddColumnSql(String tableName, String sqlTableName,
                                              List<SQLColumnDefinition> addColumnList) {
        return addColumnList.stream().map(column -> {
            CreateSqlVO createSqlVO = new CreateSqlVO();
            SqlOperateTypeEnum operateTypeEnum = SqlOperateTypeEnum.ADD_COLUMN;
            createSqlVO.setTableName(tableName);
            createSqlVO.setOperateType(operateTypeEnum.getCode());
            createSqlVO.setOperateContent(operateTypeEnum.getName());
            createSqlVO.setTableSql(getAddColumnSqlStr(sqlTableName, column));
            createSqlVO.setRiskLevel(operateTypeEnum.getRiskLevel());
            return createSqlVO;
        }).collect(Collectors.toList());
    }

    /**
     * Builds the ADD COLUMN statement for one column (AUTO_INCREMENT is not emitted).
     *
     * @param tableName table name to use in the SQL
     * @param column    the declared column definition
     * @return the ALTER TABLE ... ADD COLUMN statement
     */
    private String getAddColumnSqlStr(String tableName, SQLColumnDefinition column) {
        return "ALTER TABLE " + tableName + " ADD COLUMN " + buildColumnDefinition(column, false);
    }

    /**
     * Selects ALTER TABLE statements that contain a MySQL CHANGE COLUMN item.
     * NOTE(review): not called anywhere in this class; presumably reserved for rename support.
     */
    private List<SQLAlterTableStatement> getNewRenameColumnSql(List<SQLStatement> sqlStatements) {
        return sqlStatements.stream()
                .filter(e -> e instanceof SQLAlterTableStatement)
                .map(e -> (SQLAlterTableStatement) e)
                .filter(e -> e.getItems().stream().anyMatch(item -> item instanceof MySqlAlterTableChangeColumn))
                .collect(Collectors.toList());
    }

    /**
     * Indexes the declared CREATE TABLE statements by normalized table name.
     *
     * @param sqlStatements all parsed declared statements
     * @return CREATE TABLE statements keyed by normalized name, in declaration order
     * @throws RuntimeException if two statements declare the same table
     */
    private Map<String, MySqlCreateTableStatement> getNewCreateSql(List<SQLStatement> sqlStatements) {
        return sqlStatements.stream()
                .filter(e -> e instanceof MySqlCreateTableStatement)
                .map(e -> (MySqlCreateTableStatement) e)
                .collect(Collectors.toMap(e -> normalizeName(e.getTableName()),
                        e -> e,
                        (e1, e2) -> {
                            throw new RuntimeException("重复的建表语句,表名：" + e1.getTableName());
                        },
                        LinkedHashMap::new));
    }

    /**
     * Loads and parses the current CREATE TABLE statement of every table listed by the mapper.
     * Tables whose result has no non-blank {@code "Create Table"} value are skipped.
     *
     * @return database CREATE TABLE statements keyed by normalized table name
     * @throws RuntimeException if two tables normalize to the same name
     */
    private Map<String, MySqlCreateTableStatement> getDbSql() {
        List<String> allTable = createTableMapper.getAllTable();
        List<String> allSql = allTable.stream().map(e -> createTableMapper.getTableSql(e))
                .filter(Objects::nonNull)
                .map(e -> e.get("Create Table"))
                .filter(StrUtil::isNotBlank)
                .collect(Collectors.toList());
        return allSql.stream()
                .map(e -> SQLUtils.parseStatements(e, DbType.mysql))
                .flatMap(Collection::stream)
                .filter(e -> e instanceof MySqlCreateTableStatement)
                .map(e -> (MySqlCreateTableStatement) e)
                .collect(Collectors.toMap(e -> normalizeName(e.getTableName()),
                        e -> e,
                        (e1, e2) -> {
                            throw new RuntimeException("数据库重复的表,表名：" + e1.getTableName());
                        },
                        LinkedHashMap::new));
    }

    /**
     * Collects and parses the non-blank SQL returned by every {@link TableInitService}.
     *
     * @return all parsed statements, in service order
     */
    private List<SQLStatement> getNewSql() {
        List<String> sqlList = new ArrayList<>();
        for (TableInitService tableInitService : tableInitServices) {
            String sql = tableInitService.getCreateTableSql();
            if (StrUtil.isNotBlank(sql)) {
                sqlList.add(sql);
            }
        }
        return sqlList.stream()
                .map(e -> SQLUtils.parseStatements(e, DbType.mysql))
                .flatMap(Collection::stream)
                .collect(Collectors.toList());
    }

    /**
     * Compares two SQL fragments loosely via their {@code toString()}.
     * A null value equals null or the literal "null" (any case). Otherwise, the values are equal
     * when {@code equals} holds, or when their trimmed, backtick-free strings match ignoring case.
     */
    public boolean sqlEquals(Object o1, Object o2) {
        String str1 = o1 == null ? null : o1.toString();
        String str2 = o2 == null ? null : o2.toString();
        if (str1 == null || str2 == null) {
            String other = str1 == null ? str2 : str1;
            return other == null || other.equalsIgnoreCase("null");
        }
        if (o1.equals(o2)) {
            return true;
        }
        return str1.trim().replace("`", "").equalsIgnoreCase(str2.trim().replace("`", ""));
    }

}
