package cn.tenfell.plugins.mybatisplus.interce;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Console;
import cn.hutool.core.lang.Dict;
import cn.hutool.core.util.ClassUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.Db;
import cn.hutool.db.Entity;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.mapping.Environment;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.Configuration;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @author fs
 */
@Intercepts({@Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})})
public class TableChildrenSelectInterceptor implements Interceptor {
    private final static ThreadLocal<String> SQL_SELECT = new ThreadLocal<>();
    private final static ThreadLocal<String> PARENT_FIELD = new ThreadLocal<>();
    public static void set(String sql,String parentField){
        SQL_SELECT.set(sql);
        PARENT_FIELD.set(parentField);
    }
    private static String getSql(){
        return SQL_SELECT.get();
    }
    private static String getParentField(){
        return PARENT_FIELD.get();
    }
    private static Object handler(Invocation invocation) throws InvocationTargetException,IllegalAccessException {
        if(StrUtil.isNotBlank(getSql())){
            SQL_SELECT.remove();
        }
        if(StrUtil.isNotBlank(getParentField())){
            PARENT_FIELD.remove();
        }
        return invocation.proceed();
    }
    private static Object handler(List list) {
        if(StrUtil.isNotBlank(getSql())){
            SQL_SELECT.remove();
        }
        if(StrUtil.isNotBlank(getParentField())){
            PARENT_FIELD.remove();
        }
        return list;
    }
    @Override
    public Object intercept(Invocation invocation) throws Exception{
        String sql = getSql();
        if(StrUtil.isBlank(sql)){
            return handler(invocation);
        }
        List list= (List)invocation.proceed();
        if(list == null || list.size() == 0){
            return handler(list);
        }
        String parentField =getParentField();
        Field field = ReflectUtil.getField(list.get(0).getClass(),parentField);
        Class childType = field.getType();
        Class fieldType = field.getType();
        if(ClassUtil.isAssignable(Iterable.class, fieldType)){
            Type[] types = ((ParameterizedType)field.getGenericType()).getActualTypeArguments();
            childType = (Class)types[0];
        }
        Configuration configuration = SqlHelper.sqlSessionFactory(childType).getConfiguration();
        Environment environment = (Environment)ReflectUtil.getFieldValue(configuration,"environment");
        DataSource dataSource = (DataSource)ReflectUtil.getFieldValue(environment,"dataSource");
        for(Object obj:list){
            if(obj == null){
                continue;
            }
            List data = getChildren(obj,sql,childType,dataSource);
            if(ClassUtil.isAssignable(Iterable.class, fieldType)){
                ReflectUtil.setFieldValue(obj,parentField,data);
            }else if(fieldType == childType && data.size() <= 1){
                if(data.size() == 1){
                    ReflectUtil.setFieldValue(obj,parentField,data.get(0));
                }else{
                    ReflectUtil.setFieldValue(obj,parentField,null);
                }
            }else{
                Console.error("不支持此类型注入:{}",fieldType.toString());
            }

        }
        return handler(list);
    }
    private List<Dict> getColumns(Class clazz){
        Map<String,String> columnMaps = new HashMap<>();
        try{
            TableInfo tableInfo = SqlHelper.table(clazz);
            columnMaps.put(tableInfo.getKeyColumn(),tableInfo.getKeyProperty());
            List<TableFieldInfo> tableFieldInfos = tableInfo.getFieldList();
            for(TableFieldInfo tableFieldInfo:tableFieldInfos){
                columnMaps.put(tableFieldInfo.getColumn(),tableFieldInfo.getProperty());
            }
        }catch (Exception e){

        }
        Field[] fields = ReflectUtil.getFields(clazz);
        for(Field field:fields){
            String property = field.getName();
            String column = StrUtil.toUnderlineCase(property);
            if(columnMaps.get(column) != null){
                continue;
            }
            columnMaps.put(column,property);
        }
        final List<Dict> list = new ArrayList<>();
        CollUtil.forEach(columnMaps, new CollUtil.KVConsumer<String, String>() {
            @Override
            public void accept(String key, String val, int i) {
                list.add(Dict.create().set("column",key).set("property",val));
            }
        });
        return list;
    }
    private List getChildren(Object parent,String sql,Class child,DataSource dataSource) throws SQLException {
        TableInfo parentInfo = SqlHelper.table(parent.getClass());
        List<Dict> parentColList = getColumns(parent.getClass());
        String parentTable = parentInfo.getTableName();
        String tempSql = " " +sql + " ";
        int s ;
        while((s = tempSql.indexOf(parentTable+".")) != -1){
            int e = tempSql.indexOf(" ",s+1);
            String key = tempSql.substring(s,e);
            for(Dict parentCol:parentColList){
                String column = parentCol.getStr("column");
                String property = parentCol.getStr("property");
                if(StrUtil.equals(parentTable+"."+column,key)){
                    String value = ReflectUtil.getFieldValue(parent,property).toString();
                    tempSql = tempSql.replace(key,"'"+value+"'");
                    break;
                }
            }
        }
        String runSql = tempSql.trim();
        List<Entity> list = Db.use(dataSource).query(runSql);
        List res = new ArrayList();
        for(Entity entity:list){
            res.add(entity.toBean(child));
        }
        return res;
    }
    /**
     * 生成拦截对象的代理
     *
     * @param target 目标对象
     * @return 代理对象
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof ResultSetHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * mybatis配置的属性
     *
     * @param properties mybatis配置的属性
     */
    @Override
    public void setProperties(Properties properties) {

    }
}
