package org.zhuyb.graphbatis.interceptor;

import com.google.common.base.CaseFormat;
import graphql.language.Field;
import graphql.language.Selection;
import graphql.language.SelectionSet;
import graphql.schema.DataFetchingEnvironment;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SelectExpressionItem;
import net.sf.jsqlparser.statement.select.SelectItem;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zhuyb.graphbatis.DataFetchingEnvHolder;
import org.zhuyb.graphbatis.entity.Tables;

@Intercepts({@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
/* loaded from: input_file:org/zhuyb/graphbatis/interceptor/CleanSqlInterceptor.class */
public class CleanSqlInterceptor implements Interceptor {
    public static final Logger logger = LoggerFactory.getLogger(CleanSqlInterceptor.class);
    public static final int BOUND_SQL_INDEX = 5;
    public static final int MAPPED_STATEMENT_INDEX = 0;
    public static final int DEFAULT_MAX_LOOP_DEEP = -1;
    private int maxLoopDeep = -1;

    public Object intercept(Invocation invocation) throws Throwable {
        Object proceed;
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        if (SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
            long currentTimeMillis = System.currentTimeMillis();
            DataFetchingEnvironment dataFetchingEnvironment = DataFetchingEnvHolder.get();
            if (dataFetchingEnvironment != null) {
                Invocation invocation2 = new Invocation(invocation.getTarget(), invocation.getMethod(), args);
                BoundSql boundSql = (BoundSql) args[5];
                args[5] = new BoundSql(mappedStatement.getConfiguration(), getCleanSql(dataFetchingEnvironment, boundSql.getSql()), boundSql.getParameterMappings(), boundSql.getParameterObject());
                logger.debug("clean sql cost {}ms", Long.valueOf(System.currentTimeMillis() - currentTimeMillis));
                proceed = invocation2.proceed();
                DataFetchingEnvHolder.remove();
            } else {
                proceed = invocation.proceed();
            }
        } else {
            proceed = invocation.proceed();
        }
        return proceed;
    }

    private String getCleanSql(DataFetchingEnvironment dataFetchingEnvironment, String str) throws JSQLParserException {
        return getCleanSql(dataFetchingEnvironment, str, 0);
    }

    private String getCleanSql(DataFetchingEnvironment dataFetchingEnvironment, String str, int i) throws JSQLParserException {
        if (this.maxLoopDeep != -1 && i > this.maxLoopDeep) {
            return str;
        }
        PlainSelect plainSelect = (PlainSelect) CCJSqlParserUtil.parse(str).getSelectBody();
        Set<String> allGraphQLFieldNames = getAllGraphQLFieldNames(dataFetchingEnvironment);
        plainSelect.setSelectItems(getCleanSelectItems(plainSelect, allGraphQLFieldNames));
        HashSet hashSet = new HashSet();
        Set<String> cleanSelectTablesAlias = getCleanSelectTablesAlias(plainSelect, allGraphQLFieldNames);
        hashSet.addAll(cleanSelectTablesAlias);
        hashSet.addAll(getCleanWhereTables(plainSelect));
        Tables cleanTables = getCleanTables(plainSelect, hashSet, cleanSelectTablesAlias);
        plainSelect.setFromItem(cleanTables.getFromItem());
        plainSelect.setJoins(cleanTables.getJoins());
        String plainSelect2 = plainSelect.toString();
        if (str.equals(plainSelect2)) {
            return plainSelect2;
        }
        logger.debug("loop times {} clean sql ==> {}", Integer.valueOf(i + 1), plainSelect2);
        return getCleanSql(dataFetchingEnvironment, plainSelect2, i + 1);
    }

    private Tables getCleanTables(PlainSelect plainSelect, Set<String> set, Set<String> set2) {
        List<Join> list;
        List<Join> joins = plainSelect.getJoins();
        FromItem fromItem = (Table) plainSelect.getFromItem();
        Set<Join> explicitJoins = getExplicitJoins(set, joins);
        FromItem fromItem2 = fromItem;
        if (explicitJoins.isEmpty()) {
            list = Collections.emptyList();
        } else {
            HashSet hashSet = new HashSet();
            hashSet.addAll(explicitJoins);
            hashSet.addAll(getImplicitJoin(set, joins, explicitJoins, fromItem));
            List<Join> cleanOrderedJoins = getCleanOrderedJoins(joins, hashSet);
            if (isNeedFromItem(explicitJoins, fromItem, set) || cleanOrderedJoins.size() <= 0) {
                list = cleanOrderedJoins;
            } else {
                int size = cleanOrderedJoins.size();
                if (size == 2 && set2.size() == 1) {
                    FromItem fromItem3 = fromItem;
                    Iterator<Join> it = cleanOrderedJoins.iterator();
                    while (true) {
                        if (!it.hasNext()) {
                            break;
                        }
                        FromItem rightItem = it.next().getRightItem();
                        if (rightItem.getAlias().getName().equals(set2.stream().findFirst().get())) {
                            fromItem3 = rightItem;
                            break;
                        }
                    }
                    fromItem2 = fromItem3;
                    list = Collections.emptyList();
                } else {
                    fromItem2 = cleanOrderedJoins.get(0).getRightItem();
                    list = size <= 1 ? Collections.emptyList() : cleanOrderedJoins.subList(1, size);
                }
            }
        }
        return new Tables(fromItem2, list);
    }

    private Set<String> getCleanSelectTablesAlias(PlainSelect plainSelect, Set<String> set) {
        HashSet hashSet = new HashSet();
        selectItemsLoop(plainSelect.getSelectItems(), set, selectItem -> {
            hashSet.add(getTableAliasName(((SelectExpressionItem) selectItem).getExpression().getTable()));
        });
        return hashSet;
    }

    private List<SelectItem> getCleanSelectItems(PlainSelect plainSelect, Set<String> set) {
        ArrayList arrayList = new ArrayList();
        selectItemsLoop(plainSelect.getSelectItems(), set, selectItem -> {
            arrayList.add(selectItem);
        });
        return arrayList;
    }

    @NotNull
    private List<Join> getCleanOrderedJoins(List<Join> list, Set<Join> set) {
        ArrayList arrayList = new ArrayList(set.size());
        if (list != null) {
            for (Join join : list) {
                if (set.contains(join)) {
                    arrayList.add(join);
                }
            }
        }
        logger.debug("sorted joins {}", set);
        return arrayList;
    }

    private List<Join> getImplicitJoin(Set<String> set, List<Join> list, Set<Join> set2, Table table) {
        ArrayList arrayList = new ArrayList();
        for (Join join : set2) {
            EqualsTo onExpression = join.getOnExpression();
            Column leftExpression = onExpression.getLeftExpression();
            Column rightExpression = onExpression.getRightExpression();
            String name = leftExpression.getTable().getName();
            Join lostJoin = getTableAliasName((Table) join.getRightItem()).equals(name) ? getLostJoin(set, list, table, rightExpression.getTable().getName()) : getLostJoin(set, list, table, name);
            if (lostJoin != null) {
                arrayList.add(lostJoin);
            }
        }
        logger.debug("add implicit joins {}", arrayList);
        return arrayList;
    }

    private boolean isNeedFromItem(Set<Join> set, Table table, Set<String> set2) {
        boolean z = false;
        String name = table.getAlias().getName();
        if (set2.contains(name)) {
            z = true;
        } else {
            Iterator<Join> it = set.iterator();
            while (it.hasNext()) {
                EqualsTo onExpression = it.next().getOnExpression();
                Column leftExpression = onExpression.getLeftExpression();
                Column rightExpression = onExpression.getRightExpression();
                String name2 = leftExpression.getTable().getName();
                if (name.equals(rightExpression.getTable().getName()) || name.equals(name2)) {
                    z = true;
                    break;
                }
            }
        }
        return z;
    }

    private Set<Join> getExplicitJoins(Set<String> set, List<Join> list) {
        HashSet hashSet = new HashSet();
        if (list != null) {
            for (Join join : list) {
                if (set.contains(getTableAliasName((Table) join.getRightItem()))) {
                    hashSet.add(join);
                }
            }
        }
        logger.debug("add explicit joins {}", hashSet);
        return hashSet;
    }

    private String getTableAliasName(Table table) {
        return table.getAlias() != null ? table.getAlias().getName() : table.getName();
    }

    private Join getLostJoin(Set<String> set, List<Join> list, Table table, String str) {
        if (set.contains(str) || getTableAliasName(table).equals(str)) {
            return null;
        }
        for (Join join : list) {
            if (getTableAliasName((Table) join.getRightItem()).equals(str)) {
                return join;
            }
        }
        return null;
    }

    private Set<String> getCleanWhereTables(PlainSelect plainSelect) {
        HashSet hashSet = new HashSet();
        BinaryExpression binaryExpression = (BinaryExpression) plainSelect.getWhere();
        ArrayList arrayList = new ArrayList();
        nextExpression(binaryExpression, arrayList);
        for (Column column : arrayList) {
            logger.debug("where table {}", column.getTable().toString());
            hashSet.add(column.getTable().toString());
        }
        return hashSet;
    }

    private void selectItemsLoop(List<SelectItem> list, Set<String> set, Consumer<SelectItem> consumer) {
        Iterator<SelectItem> it = list.iterator();
        while (it.hasNext()) {
            SelectExpressionItem selectExpressionItem = (SelectItem) it.next();
            Column expression = selectExpressionItem.getExpression();
            if (set.contains(expression.getColumnName())) {
                consumer.accept(selectExpressionItem);
            } else {
                logger.debug("column {} removed", expression.getColumnName());
            }
        }
    }

    private void nextExpression(BinaryExpression binaryExpression, List<Column> list) {
        if (binaryExpression == null) {
            return;
        }
        Expression leftExpression = binaryExpression.getLeftExpression();
        if (leftExpression != null && (leftExpression instanceof BinaryExpression)) {
            nextExpression((BinaryExpression) leftExpression, list);
        }
        if (leftExpression != null && (leftExpression instanceof Column)) {
            list.add((Column) leftExpression);
        }
        Expression rightExpression = binaryExpression.getRightExpression();
        if (rightExpression != null && (rightExpression instanceof BinaryExpression)) {
            nextExpression((BinaryExpression) rightExpression, list);
        }
        if (rightExpression == null || !(rightExpression instanceof Column)) {
            return;
        }
        list.add((Column) rightExpression);
    }

    private Set<String> getAllGraphQLArgumentsNames(DataFetchingEnvironment dataFetchingEnvironment) {
        Set<String> set = null;
        Map arguments = dataFetchingEnvironment.getArguments();
        if (arguments != null) {
            set = (Set) arguments.keySet().stream().map(str -> {
                return (String) CaseFormat.LOWER_CAMEL.converterTo(CaseFormat.LOWER_UNDERSCORE).convert(str);
            }).collect(Collectors.toSet());
        }
        return set;
    }

    private Set<String> getAllGraphQLFieldNames(DataFetchingEnvironment dataFetchingEnvironment) {
        Set<String> set = null;
        List fields = dataFetchingEnvironment.getFields();
        if (fields != null) {
            set = new HashSet();
            Iterator it = fields.iterator();
            while (it.hasNext()) {
                getAllGraphQLFieldNames(set, (Field) it.next());
            }
            if (set != null) {
                set = (Set) set.stream().map(str -> {
                    return (String) CaseFormat.LOWER_CAMEL.converterTo(CaseFormat.LOWER_UNDERSCORE).convert(str);
                }).collect(Collectors.toSet());
                logger.debug("all graphQL field names {}", set);
            }
        }
        if (set == null || set.isEmpty()) {
            logger.debug("all graphQL field names is empty");
        }
        return set;
    }

    private void getAllGraphQLFieldNames(Set<String> set, Field field) {
        SelectionSet selectionSet;
        List selections;
        if (field == null || (selectionSet = field.getSelectionSet()) == null || (selections = selectionSet.getSelections()) == null) {
            return;
        }
        Iterator it = selections.iterator();
        while (it.hasNext()) {
            Field field2 = (Field) ((Selection) it.next());
            set.add(field2.getName());
            getAllGraphQLFieldNames(set, field2);
        }
    }

    public Object plugin(Object obj) {
        return Plugin.wrap(obj, this);
    }

    public void setProperties(Properties properties) {
        String property = properties.getProperty("maxLoopDeep");
        if (StringUtils.isNumeric(property)) {
            this.maxLoopDeep = Integer.valueOf(property).intValue();
        }
        logger.info("properties {}", properties);
    }
}
