package cn.fscode.commons.mybatis.plus;

import cn.fscode.commons.mybatis.plus.constants.MybatisPlusConstants;
import cn.fscode.commons.mybatis.plus.utils.DataBaseInit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.context.EnvironmentAware;
import org.springframework.core.env.Environment;

import java.util.regex.Pattern;

/**
 * @author shenguangyang
 */
/**
 * Optionally initializes the application database while bean definitions are being post-processed.
 *
 * <p>When the property {@code MybatisPlusConstants.MYBATIS_PLUS_EXTEND_PROPERTIES_PREFIX + ".init-db"}
 * is {@code true}, this class does three things:
 * <ul>
 *   <li>reads the {@code spring.datasource.*} connection settings from the {@link Environment};</li>
 *   <li>parses the database name out of the JDBC URL;</li>
 *   <li>passes everything to {@link DataBaseInit#initDb}.</li>
 * </ul>
 * Because this is a {@link BeanDefinitionRegistryPostProcessor}, the work runs during registry
 * post-processing, before ordinary beans are instantiated.
 *
 * @author shenguangyang
 */
public class MybatisInitPostProcessor implements BeanDefinitionRegistryPostProcessor, EnvironmentAware {
    private static final Logger log = LoggerFactory.getLogger(MybatisInitPostProcessor.class);

    /**
     * Characters that must not appear in a database name parsed from the URL:
     * '=', '&amp;', '/' and '\'. Their presence means the URL was not split where expected.
     */
    private static final Pattern ILLEGAL_DB_NAME_CHARS = Pattern.compile("[=&/\\\\]");

    private Environment environment;

    /**
     * Initializes the database when the init-db switch is enabled; otherwise does nothing.
     *
     * @param beanDefinitionRegistry the registry (not used by this implementation)
     * @throws IllegalStateException    if a required {@code spring.datasource.*} property is missing
     * @throws IllegalArgumentException if no valid database name can be parsed from the JDBC URL
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry beanDefinitionRegistry) throws BeansException {
        if (!isInitDb()) {
            return;
        }
        String driverClassName = getEnvironmentProperty("spring.datasource.driver-class-name");
        String jdbcUrl = getEnvironmentProperty("spring.datasource.url");
        String username = getEnvironmentProperty("spring.datasource.username");
        String password = getEnvironmentProperty("spring.datasource.password");
        String initDb = getDbNameFromUrl(jdbcUrl);

        // Initialize the database. The password is masked so credentials never end up in the logs.
        log.info("initDb [{}] driverClassName [{}] jdbcUrl [{}] username [{}] password [******]",
                initDb, driverClassName, jdbcUrl, username);
        try {
            DataBaseInit.initDb(jdbcUrl, driverClassName, username, password, initDb);
        } catch (ClassNotFoundException e) {
            // Best-effort: startup continues, but the stack trace is preserved.
            // SLF4J treats a trailing Throwable argument as the exception to log.
            log.error("database create failed, errorMessage [{}]", e.getMessage(), e);
        }
    }

    /**
     * Intentionally a no-op; all work happens in {@link #postProcessBeanDefinitionRegistry}.
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory configurableListableBeanFactory) throws BeansException {
        // nothing to do
    }

    @Override
    public void setEnvironment(Environment environment) {
        this.environment = environment;
    }

    /**
     * Reads the init-db switch.
     *
     * @return {@code true} only if the property is present and parses as {@code true}
     *         (case-insensitive); {@code false} when it is absent
     */
    private boolean isInitDb() {
        String property = MybatisPlusConstants.MYBATIS_PLUS_EXTEND_PROPERTIES_PREFIX + ".init-db";
        String value = environment.getProperty(property, "false");
        // Log the real key rather than a hard-coded name, so the message matches the configuration.
        log.info("{}: {}", property, value);
        return Boolean.parseBoolean(value);
    }

    /**
     * Returns a required environment property.
     *
     * @param property the property key
     * @return the non-null property value
     * @throws IllegalStateException if the property is not set
     */
    private String getEnvironmentProperty(String property) {
        String value = environment.getProperty(property);
        if (value == null) {
            throw new IllegalStateException("property [ " + property + " ] is null");
        }
        return value;
    }

    /**
     * Extracts the database name from a JDBC URL.
     *
     * <p>Example: {@code jdbc:mysql://192.168.116.131:53306/ums1?useUnicode=true&characterEncoding=utf8}
     * yields {@code ums1}.
     *
     * @param url the JDBC URL; it must contain {@code //host[:port]/dbName}
     * @return the database name
     * @throws IllegalArgumentException if the URL has no {@code //} authority, has no database path
     *                                  segment, or the parsed name is empty or contains
     *                                  '=', '&amp;', '/' or '\'
     */
    private String getDbNameFromUrl(String url) {
        int authorityStart = url.indexOf("//");
        if (authorityStart == -1) {
            throw new IllegalArgumentException("url [ " + url + " ] is illegal");
        }
        String afterAuthority = url.substring(authorityStart + 2);

        // The database name is the path segment right after host[:port]. If a '?' comes first,
        // or there is no '/' at all, the URL names no database.
        int pathStart = afterAuthority.indexOf('/');
        int queryStart = afterAuthority.indexOf('?');
        if (pathStart == -1 || (queryStart != -1 && queryStart < pathStart)) {
            throw new IllegalArgumentException("url [ " + url + " ] is illegal");
        }
        String path = afterAuthority.substring(pathStart + 1);

        int queryInPath = path.indexOf('?');
        String dbName = queryInPath == -1 ? path : path.substring(0, queryInPath);

        // Reject names that are empty or contain any of the forbidden characters.
        if (dbName.isEmpty() || ILLEGAL_DB_NAME_CHARS.matcher(dbName).find()) {
            throw new IllegalArgumentException("url [ " + dbName + " ] is illegal");
        }
        return dbName;
    }
}
