package cn.schoolwow.quickdao.flow.initial.datasource;

import cn.schoolwow.quickdao.domain.external.QuickDAOConfig;
import cn.schoolwow.quickdao.domain.internal.DatabaseType;
import cn.schoolwow.quickdao.provider.DatabaseProvider;
import cn.schoolwow.quickflow.domain.FlowContext;
import cn.schoolwow.quickflow.flow.BusinessFlow;
import com.zaxxer.hikari.HikariDataSource;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;

/**
 * Business flow that creates the database (or schema, for H2) named in the
 * data source's JDBC URL when it does not exist yet.
 *
 * <p>The flow runs four steps in order: it resolves the database provider and
 * connection settings from the {@code dataSource} entry of the flow context,
 * picks the SQL used to list existing databases, picks the SQL used to create
 * the database, and finally connects to the server (without a database name in
 * the URL) to run them.
 *
 * <p>Only {@link HikariDataSource} instances are supported because the JDBC
 * URL, driver class name and credentials are read from it.
 */
public class AutomaticCreateDatabase implements BusinessFlow {
    @Override
    public void executeBusinessFlow(FlowContext flowContext) throws Exception {
        setDatabaseType(flowContext);
        setQueryDatabaseSQL(flowContext);
        setCreateDatabaseSQL(flowContext);
        executeCreateDatabase(flowContext);
    }

    @Override
    public String name() {
        return "自动创建数据库";
    }

    /**
     * Reads the connection settings from the Hikari data source, selects the
     * matching {@link DatabaseProvider} and publishes {@code jdbcUrl},
     * {@code databaseType}, {@code databaseName} and, when present,
     * {@code driverClassName}, {@code username} and {@code password} as
     * current-flow data.
     *
     * @throws IllegalArgumentException if no registered provider matches the JDBC URL
     */
    private void setDatabaseType(FlowContext flowContext){
        DataSource dataSource = (DataSource) flowContext.checkData("dataSource");

        if(!(dataSource instanceof HikariDataSource)){
            // NOTE(review): broken(...) is expected to abort the flow here; otherwise the cast below fails.
            flowContext.broken("不支持的数据库连接池类型!类名:"+dataSource.getClass().getName());
        }

        HikariDataSource hikariDataSource = (HikariDataSource) dataSource;

        String jdbcUrl = hikariDataSource.getJdbcUrl();
        flowContext.putCurrentFlowData("jdbcUrl", jdbcUrl);

        // The flow context stores untyped values; the list is registered under this key as List<DatabaseProvider>.
        @SuppressWarnings("unchecked")
        List<DatabaseProvider> databaseProviderList = (List<DatabaseProvider>) flowContext.checkData("databaseProviderList");
        QuickDAOConfig quickDAOConfig = (QuickDAOConfig) flowContext.checkData("quickDAOConfig");

        // The first provider whose name appears right after "jdbc:" in the URL wins.
        for (DatabaseProvider databaseProvider : databaseProviderList) {
            if (jdbcUrl.contains("jdbc:" + databaseProvider.name())) {
                quickDAOConfig.databaseContext.databaseProvider = databaseProvider;
                break;
            }
        }
        if (null == quickDAOConfig.databaseContext.databaseProvider) {
            throw new IllegalArgumentException("不支持的数据库类型!jdbcUrl:" + jdbcUrl);
        }
        flowContext.putCurrentFlowData("databaseType", quickDAOConfig.databaseContext.databaseProvider.getDatabaseType());

        // Optional settings are only published when configured so that no null value is stored.
        if(null!=hikariDataSource.getDriverClassName()){
            flowContext.putCurrentFlowData("driverClassName", hikariDataSource.getDriverClassName());
        }
        if(null!=hikariDataSource.getUsername()){
            flowContext.putCurrentFlowData("username", hikariDataSource.getUsername());
        }
        if(null!=hikariDataSource.getPassword()){
            flowContext.putCurrentFlowData("password", hikariDataSource.getPassword());
        }

        // The database name is the last path segment of the URL, ignoring any query string.
        String urlWithoutQuery = stripQuery(jdbcUrl);
        String databaseName = urlWithoutQuery.substring(urlWithoutQuery.lastIndexOf('/')+1);
        flowContext.putCurrentFlowData("databaseName", databaseName);
    }

    /**
     * Chooses the statement that lists existing databases (first column of each
     * row is compared with the target name later) and stores it as
     * {@code queryDatabaseSQL}.
     */
    private void setQueryDatabaseSQL(FlowContext flowContext){
        DatabaseType databaseType = (DatabaseType) flowContext.checkData("databaseType");

        String queryDatabaseSQL = null;
        switch (databaseType){
            case H2:{
                queryDatabaseSQL = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA;";
            }break;
            case MariaDB:
            case Mysql:{
                queryDatabaseSQL = "SHOW DATABASES;";
            }break;
            case Postgresql:{
                queryDatabaseSQL = "SELECT datname FROM pg_database;";
            }break;
            case SQLServer:{
                queryDatabaseSQL = "SELECT name FROM sys.databases;";
            }break;
            case Oracle:{
                // NOTE(review): this lists tables, not databases; Oracle is rejected by setCreateDatabaseSQL anyway.
                queryDatabaseSQL = "SELECT table_name FROM user_tables;";
            }break;
            default:{
                flowContext.broken("不支持的数据库类型!"+databaseType.name());
            }break;
        }
        flowContext.putCurrentFlowData("queryDatabaseSQL", queryDatabaseSQL);
    }

    /**
     * Chooses the statement that creates the target database (a schema for H2)
     * and stores it as {@code createDatabaseSQL}. Oracle and unknown types are
     * reported through {@code flowContext.broken}.
     */
    private void setCreateDatabaseSQL(FlowContext flowContext){
        DatabaseType databaseType = (DatabaseType) flowContext.checkData("databaseType");
        String databaseName = (String) flowContext.checkData("databaseName");

        String createDatabaseSQL = null;
        switch (databaseType){
            case H2:{
                createDatabaseSQL = "CREATE SCHEMA "+databaseName+";";
            }break;
            case MariaDB:
            case Mysql:
            case Postgresql:
            case SQLServer:{
                createDatabaseSQL = "CREATE DATABASE "+databaseName+";";
            }break;
            case Oracle:
            default:{
                flowContext.broken("不支持的数据库类型!"+databaseType.name());
            }break;
        }
        flowContext.putCurrentFlowData("createDatabaseSQL", createDatabaseSQL);
    }

    /**
     * Connects to the server using the JDBC URL with the database name and
     * query string removed, checks (case-insensitively) whether the target
     * database is already listed, and creates it when it is not.
     *
     * @throws ClassNotFoundException if a driver class name is configured but cannot be loaded
     * @throws SQLException if connecting, listing or creating the database fails
     */
    private void executeCreateDatabase(FlowContext flowContext) throws ClassNotFoundException, SQLException {
        String driverClassName = (String) flowContext.getData("driverClassName");
        String jdbcUrl = (String) flowContext.checkData("jdbcUrl");
        String username = (String) flowContext.getData("username");
        String password = (String) flowContext.getData("password");

        String queryDatabaseSQL = (String) flowContext.checkData("queryDatabaseSQL");
        String createDatabaseSQL = (String) flowContext.checkData("createDatabaseSQL");
        String databaseName = (String) flowContext.checkData("databaseName");

        // Cut the URL after the last '/' of the path part, so a '/' inside the query string is never used.
        String urlWithoutQuery = stripQuery(jdbcUrl);
        String rawJdbcUrl = urlWithoutQuery.substring(0, urlWithoutQuery.lastIndexOf('/')+1);

        // The driver class name is optional for Hikari; DriverManager can locate service-registered drivers itself.
        if(null!=driverClassName){
            Class.forName(driverClassName);
        }
        try(Connection connection = DriverManager.getConnection(rawJdbcUrl, username, password)){
            boolean databaseExist = false;
            try(PreparedStatement queryStatement = connection.prepareStatement(queryDatabaseSQL);
                ResultSet resultSet = queryStatement.executeQuery()){
                while(resultSet.next()){
                    String rowDatabaseName = resultSet.getString(1);
                    if(databaseName.equalsIgnoreCase(rowDatabaseName)){
                        databaseExist = true;
                        break;
                    }
                }
            }
            if(!databaseExist){
                try(PreparedStatement createStatement = connection.prepareStatement(createDatabaseSQL)){
                    createStatement.executeUpdate();
                }
            }
        }
    }

    /** Returns the JDBC URL without its query string (everything from the first '?' on). */
    private static String stripQuery(String jdbcUrl){
        int queryIndex = jdbcUrl.indexOf('?');
        return queryIndex >= 0 ? jdbcUrl.substring(0, queryIndex) : jdbcUrl;
    }

}
