package com.github.dreamroute.pager.starter.interceptor;

import cn.hutool.core.annotation.AnnotationUtil;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.github.dreamroute.pager.starter.anno.Pager;
import com.github.dreamroute.pager.starter.anno.PagerContainer;
import com.github.dreamroute.pager.starter.anno.PagerContainerBaseInfo;
import com.github.dreamroute.pager.starter.api.PageRequest;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.transaction.Transaction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;

/**
 * MyBatis interceptor that pages mapper methods annotated with {@link Pager}.
 *
 * <p>For a {@code @Pager} method whose (single) argument is a {@link PageRequest}, the interceptor:
 * <ol>
 *   <li>parses the method's SQL with Druid,</li>
 *   <li>runs a COUNT query,</li>
 *   <li>runs a LIMIT query when the total is non-zero.</li>
 * </ol>
 * It returns a {@code ResultWrapper} carrying the rows together with total, pageNum and pageSize.
 * All other queries are passed through untouched.
 */
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
public class PagerInterceptor implements Interceptor, ApplicationListener<ContextRefreshedEvent> {
    private static final Logger log = LoggerFactory.getLogger(PagerInterceptor.class);

    /** Pager metadata keyed by mapped-statement id ("mapperClassName.methodName"). */
    private final ConcurrentHashMap<String, PagerContainerBaseInfo> pagerContainer = new ConcurrentHashMap<>();
    private static final int SINGLE = 1;
    /** Column alias of the COUNT result in the generated count SQL. */
    private static final String COUNT_NAME = "_$count$_";
    private static final String SELECT = "SELECT ";
    private static final String WHERE = " WHERE ";
    private static final String FROM = " FROM ";
    private static final String DISTINCT = " DISTINCT ";
    /** MyBatis configuration, captured when the Spring context is refreshed. */
    private Configuration config;

    /**
     * Captures the MyBatis {@link Configuration} and records every mapper method annotated with {@link Pager}.
     *
     * @param contextRefreshedEvent the Spring refresh event; its context must hold a {@link SqlSessionFactory} bean
     */
    @Override
    public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) {
        this.config = contextRefreshedEvent.getApplicationContext().getBean(SqlSessionFactory.class).getConfiguration();
        Collection<Class<?>> mappers = this.config.getMapperRegistry().getMappers();
        if (mappers == null || mappers.isEmpty()) {
            return;
        }
        for (Class<?> mapper : mappers) {
            String mapperName = mapper.getName();
            Arrays.stream(mapper.getDeclaredMethods())
                    .filter(method -> AnnotationUtil.hasAnnotation(method, Pager.class))
                    .forEach(method -> {
                        PagerContainerBaseInfo info = new PagerContainerBaseInfo();
                        String distinctBy = (String) AnnotationUtil.getAnnotationValue(method, Pager.class, "distinctBy");
                        if (StringUtils.isNotBlank(distinctBy)) {
                            info.setDistinctBy(distinctBy);
                        }
                        // Key matches MappedStatement#getId() for annotation/XML mappers bound to this interface.
                        this.pagerContainer.put(mapperName + "." + method.getName(), info);
                    });
        }
    }

    /**
     * Pages a {@code @Pager} query: a COUNT query first, then (when the total is non-zero) the LIMIT query.
     *
     * @param invocation the intercepted {@code Executor.query(...)} call
     * @return a {@code ResultWrapper} for paged statements, otherwise whatever the original call returns
     * @throws IllegalArgumentException if a paged method receives more than one parameter, or its SQL
     *                                  cannot be paged (see {@link #parseSql(String, String)})
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args[SINGLE];

        // Only @Pager methods are paged. Checking this first also keeps the single-parameter rule
        // below from rejecting ordinary queries that merely receive a PageRequest among other params.
        if (this.pagerContainer.get(mappedStatement.getId()) == null) {
            return invocation.proceed();
        }

        Object pageParam = parameter;
        // Name under which the PageRequest is bound (e.g. via @Param); null when it is the bare parameter.
        String paramName = null;
        if (parameter instanceof MapperMethod.ParamMap) {
            MapperMethod.ParamMap<?> paramMap = (MapperMethod.ParamMap<?>) parameter;
            if (paramMap.values().stream().noneMatch(PageRequest.class::isInstance)) {
                return invocation.proceed();
            }
            IllegalArgumentException illegalArgumentException = new IllegalArgumentException("接口" + mappedStatement.getId() + "参数有误, 分页接口参数必有且仅能有一个，并且是继承了PageRequest的，需要把多个参数封装在一个对象中");
            // One named parameter appears twice in the map: under its own name and as "paramN".
            if (paramMap.size() != 2) {
                throw illegalArgumentException;
            }
            pageParam = paramMap.values().stream()
                    .filter(PageRequest.class::isInstance)
                    .findAny()
                    .orElseThrow(() -> illegalArgumentException);
            paramName = paramMap.keySet().stream()
                    .filter(key -> !key.toLowerCase().matches("param\\d+"))
                    .findAny()
                    .orElseThrow(() -> illegalArgumentException);
        }
        if (!(pageParam instanceof PageRequest)) {
            return invocation.proceed();
        }

        BoundSql originBoundSql = mappedStatement.getBoundSql(parameter);
        PagerContainer container = parseSql(originBoundSql.getSql(), mappedStatement.getId());
        List<ParameterMapping> originMappings = originBoundSql.getParameterMappings();
        container.setOriginPmList(originMappings);
        container.setAfterPmList(parseParameterMappings(this.config, originMappings, paramName, container.isSingleTable()));

        // NOTE(review): ProxyUtil.getOriginObj presumably unwraps plugin proxies around the executor — confirm.
        Executor executor = (Executor) ProxyUtil.getOriginObj(invocation.getTarget());
        Transaction transaction = executor.getTransaction();

        // 1) COUNT query, executed with the original parameter mappings.
        BoundSql countBoundSql = new BoundSql(this.config, container.getCountSql(), container.getOriginPmList(), parameter);
        copyProps(originBoundSql, countBoundSql, this.config);
        MappedStatement countStatement = new MappedStatement.Builder(this.config, mappedStatement.getId() + "(分页统计)",
                new StaticSqlSource(this.config, container.getCountSql()), SqlCommandType.SELECT).build();
        StatementHandler countHandler = this.config.newStatementHandler(executor, countStatement, parameter, RowBounds.DEFAULT, (ResultHandler) null, countBoundSql);
        ResultWrapper resultWrapper = new ResultWrapper();
        try (Statement countStmt = prepareStatement(transaction, countHandler)) {
            ((PreparedStatement) countStmt).execute();
            try (ResultSet resultSet = countStmt.getResultSet()) {
                while (resultSet.next()) {
                    resultWrapper.setTotal(resultSet.getLong(COUNT_NAME));
                }
            }
        }

        // 2) Page query.
        PageRequest pageRequest = (PageRequest) pageParam;
        int pageNum = pageRequest.getPageNum();
        int pageSize = pageRequest.getPageSize();
        resultWrapper.setPageNum(pageNum);
        resultWrapper.setPageSize(pageSize);
        // The generated "LIMIT ?, ?" binds the pageNum property as the row offset, so pageNum is
        // temporarily rewritten; the finally block restores the caller's object even on failure.
        pageRequest.setPageNum((pageNum - SINGLE) * pageSize);
        try {
            if (resultWrapper.getTotal() != 0) {
                BoundSql pageBoundSql = new BoundSql(this.config, container.getAfterSql(), container.getAfterPmList(), parameter);
                copyProps(originBoundSql, pageBoundSql, this.config);
                StatementHandler pageHandler = this.config.newStatementHandler(executor, mappedStatement, parameter, RowBounds.DEFAULT, (ResultHandler) null, pageBoundSql);
                try (Statement pageStmt = prepareStatement(transaction, pageHandler)) {
                    List<Object> rows = pageHandler.query(pageStmt, (ResultHandler) null);
                    resultWrapper.addAll(rows);
                }
            }
        } finally {
            pageRequest.setPageNum(pageNum);
        }
        return resultWrapper;
    }

    /**
     * Prepares and parameterizes a statement on the transaction's connection.
     *
     * <p>If parameterization fails, the statement is closed before the error propagates, so it is not leaked.
     */
    private Statement prepareStatement(Transaction transaction, StatementHandler statementHandler) throws SQLException {
        Statement statement = statementHandler.prepare(transaction.getConnection(), transaction.getTimeout());
        try {
            statementHandler.parameterize(statement);
        } catch (SQLException | RuntimeException e) {
            try {
                statement.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
        return statement;
    }

    /**
     * Builds the parameter mappings for the page query.
     *
     * <p>The result is the original mappings, then {@code pageNum} and {@code pageSize} (bound to
     * {@code LIMIT ?, ?}). For multi-table SQL the original mappings follow again, because the outer
     * query repeats the WHERE condition.
     *
     * @param configuration MyBatis configuration used to build the mappings
     * @param origin        original parameter mappings; may be null (treated as empty)
     * @param paramName     name under which the PageRequest is bound; blank when it is the bare parameter
     * @param singleTable   whether the SQL was classified as single-table
     * @return a new mutable list of mappings
     */
    private List<ParameterMapping> parseParameterMappings(Configuration configuration, List<ParameterMapping> origin, String paramName, boolean singleTable) {
        List<ParameterMapping> originMappings = origin == null ? new ArrayList<>() : origin;
        String pageNumProperty = "pageNum";
        String pageSizeProperty = "pageSize";
        if (StringUtils.isNotBlank(paramName)) {
            pageNumProperty = paramName + "." + pageNumProperty;
            pageSizeProperty = paramName + "." + pageSizeProperty;
        }
        List<ParameterMapping> result = new ArrayList<>(originMappings);
        result.add(new ParameterMapping.Builder(configuration, pageNumProperty, Integer.TYPE).build());
        result.add(new ParameterMapping.Builder(configuration, pageSizeProperty, Integer.TYPE).build());
        if (!singleTable) {
            result.addAll(originMappings);
        }
        return result;
    }

    /**
     * Rewrites a MySQL SELECT into its count SQL and page SQL.
     *
     * <ul>
     *   <li>Single table: {@code SELECT COUNT(*) ...} and the original SQL + {@code " LIMIT ?, ?"}.</li>
     *   <li>Multiple tables: the page is selected through a DISTINCT sub-query on {@code distinctBy}
     *       ({@code alias.column}). The rows are then fetched with {@code distinctBy IN (...)}.</li>
     * </ul>
     *
     * @param sql         the statement's SQL
     * @param statementId mapped-statement id, used to look up the {@code @Pager} metadata
     * @return a container holding the count SQL, page SQL and single-table flag
     * @throws IllegalArgumentException if the SQL has more than one ORDER BY item, or is multi-table
     *                                  without a {@code distinctBy} of the form {@code alias.column}
     */
    private PagerContainer parseSql(String sql, String statementId) {
        PagerContainer container = BeanUtil.copyProperties(this.pagerContainer.get(statementId), PagerContainer.class, new String[0]);

        SQLSelectStatement statement = (SQLSelectStatement) SQLUtils.parseSingleMysqlStatement(sql);
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        statement.accept(visitor);
        List<String> tables = visitor.getTables().keySet().stream()
                .map(name -> name.getName())
                .collect(Collectors.toList());

        MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) statement.getSelect().getQuery();
        String selectList = query.getSelectList().stream().map(Object::toString).collect(Collectors.joining(","));
        String from = query.getFrom().toString();
        String where = query.getWhere() != null ? query.getWhere().toString() : "";
        String whereClause = StringUtils.isNotBlank(where) ? WHERE + where : "";
        if (Optional.ofNullable(query.getOrderBy()).map(ob -> ob.getItems().size()).orElse(0) > SINGLE) {
            throw new IllegalArgumentException("分页插件不支持多个排序字段!");
        }
        String orderBy = query.getOrderBy() != null ? query.getOrderBy().toString() : "";

        String afterSql;
        if (tables.size() == SINGLE) {
            container.setCountSql(SELECT + "COUNT(*) " + COUNT_NAME + "  FROM " + from + whereClause);
            afterSql = sql + " LIMIT ?, ?";
            container.setSingleTable(true);
        } else {
            String distinctBy = container.getDistinctBy();
            String[] parts = StringUtils.isBlank(distinctBy) ? new String[0] : distinctBy.split("\\.");
            // Without "alias.column" the alias rewriting below would emit invalid SQL (or NPE), so fail fast.
            if (parts.length < 2 || StringUtils.isBlank(parts[0]) || StringUtils.isBlank(parts[SINGLE])) {
                throw new IllegalArgumentException("接口" + statementId + "为多表查询, 必须通过@Pager的distinctBy指定去重字段, 格式为: 表别名.列名");
            }
            String alias = parts[0];
            String column = parts[SINGLE];

            // Inner query pages the distinct keys. Its table alias is renamed so that it does not
            // clash with the same alias used by the outer query.
            String renamedAlias = "_" + alias;
            String innerSql = (SELECT + DISTINCT + distinctBy + FROM + from + whereClause + " " + orderBy + " LIMIT ?, ?")
                    .replaceAll("\\bAS\\s+" + alias + "\\b", "AS " + renamedAlias)
                    .replaceAll("\\b" + alias + "\\b", renamedAlias);
            // MySQL rejects LIMIT directly inside IN (...), hence the extra derived-table wrapper.
            String derivedAlias = "__" + StringUtils.repeat(alias, 2);
            String keySubQuery = SELECT + derivedAlias + "." + column + FROM + "(" + innerSql + ") " + derivedAlias;
            // Parenthesize the original condition so a top-level OR cannot bypass the IN restriction.
            afterSql = SELECT + selectList + FROM + from + WHERE + distinctBy + " IN (" + keySubQuery + ")"
                    + (StringUtils.isNotBlank(whereClause) ? " AND (" + where + ")" : "") + " " + orderBy;
            container.setCountSql("SELECT count(DISTINCT " + distinctBy + ") " + COUNT_NAME + FROM + from + whereClause);
        }
        container.setAfterSql(afterSql);
        return container;
    }

    /**
     * Copies the private {@code additionalParameters} and {@code metaParameters} fields from one {@link BoundSql}
     * to another. Values registered on the source (e.g. by dynamic SQL) then remain resolvable on the target.
     */
    private static void copyProps(BoundSql boundSql, BoundSql boundSql2, Configuration configuration) {
        MetaObject source = configuration.newMetaObject(boundSql);
        MetaObject target = configuration.newMetaObject(boundSql2);
        target.setValue("additionalParameters", source.getValue("additionalParameters"));
        target.setValue("metaParameters", source.getValue("metaParameters"));
    }
}
