/*
 * Decompiled with CFR 0.152.
 */
package net.sourceforge.pmd.lang.apex.rule.security;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.sourceforge.pmd.lang.apex.ast.ASTAssignmentExpression;
import net.sourceforge.pmd.lang.apex.ast.ASTBinaryExpression;
import net.sourceforge.pmd.lang.apex.ast.ASTFieldDeclaration;
import net.sourceforge.pmd.lang.apex.ast.ASTLiteralExpression;
import net.sourceforge.pmd.lang.apex.ast.ASTMethod;
import net.sourceforge.pmd.lang.apex.ast.ASTMethodCallExpression;
import net.sourceforge.pmd.lang.apex.ast.ASTParameter;
import net.sourceforge.pmd.lang.apex.ast.ASTStandardCondition;
import net.sourceforge.pmd.lang.apex.ast.ASTUserClass;
import net.sourceforge.pmd.lang.apex.ast.ASTVariableDeclaration;
import net.sourceforge.pmd.lang.apex.ast.ASTVariableExpression;
import net.sourceforge.pmd.lang.apex.ast.AbstractApexNode;
import net.sourceforge.pmd.lang.apex.ast.ApexNode;
import net.sourceforge.pmd.lang.apex.rule.AbstractApexRule;
import net.sourceforge.pmd.lang.apex.rule.internal.Helper;
import net.sourceforge.pmd.lang.ast.Node;
import net.sourceforge.pmd.lang.rule.RuleTargetSelector;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Reports potential SOQL injection: arguments of {@code Database.query(...)} and
 * {@code Database.countQuery(...)} that are built from variables not known to be safe.
 *
 * <p>Each non-test, non-system {@link ASTUserClass} is analyzed in two passes. The first pass
 * collects facts about variables from method signatures, field declarations, variable
 * declarations and assignments. A variable is "safe" when it:
 * <ul>
 *   <li>has a safe type (see {@link #SAFE_VARIABLE_TYPES});</li>
 *   <li>holds a non-SELECT literal; or</li>
 *   <li>holds a {@code String.escapeSingleQuotes(...)} result.</li>
 * </ul>
 * A variable holding a SELECT string is recorded in {@link #selectContainingVariables}. The
 * second pass inspects every query call. The facts are collected without regard to statement
 * order.
 */
public class ApexSOQLInjectionRule
extends AbstractApexRule {
    /** Lower-cased Apex type names whose values are treated as safe to concatenate into SOQL. */
    private static final Set<String> SAFE_VARIABLE_TYPES = Collections.unmodifiableSet(Stream.of("double", "long", "decimal", "boolean", "id", "integer", "sobjecttype", "schema.sobjecttype", "sobjectfield", "schema.sobjectfield").collect(Collectors.toSet()));
    private static final String JOIN = "join";
    private static final String ESCAPE_SINGLE_QUOTES = "escapeSingleQuotes";
    private static final String STRING = "String";
    private static final String DATABASE = "Database";
    private static final String QUERY = "query";
    private static final String COUNT_QUERY = "countQuery";
    /** Whole-string, case-insensitive match for literals beginning with {@code select} and whitespace. */
    private static final Pattern SELECT_PATTERN = Pattern.compile("^select[\\s]+?.*?$", Pattern.CASE_INSENSITIVE);

    /** Fully-qualified names of variables considered safe; per-class state, cleared after each visit. */
    private final Set<String> safeVariables = new HashSet<>();

    /**
     * Fully-qualified names of variables holding a SELECT string. {@code TRUE} means the value
     * came from a plain SELECT literal; {@code FALSE} means unsafe values were concatenated into
     * it. Per-class state, cleared after each visit.
     */
    private final Map<String, Boolean> selectContainingVariables = new HashMap<>();

    /** Restricts this rule to Apex class nodes; all analysis starts from {@link #visit(ASTUserClass, Object)}. */
    @Override
    protected @NonNull RuleTargetSelector buildTargetSelector() {
        return RuleTargetSelector.forTypes(ASTUserClass.class);
    }

    /**
     * Analyzes one class: collects variable facts, then checks every query call in it.
     *
     * @param node the class being analyzed
     * @param data the rule context passed through to {@code asCtx}
     * @return {@code data}, unchanged
     */
    @Override
    public Object visit(ASTUserClass node, Object data) {
        if (Helper.isTestMethodOrClass(node) || Helper.isSystemLevelClass(node)) {
            return data;
        }
        try {
            for (ASTMethod method : node.descendants(ASTMethod.class)) {
                findSafeVariablesInSignature(method);
            }
            for (ASTFieldDeclaration field : node.descendants(ASTFieldDeclaration.class)) {
                collectVariableFacts(field);
            }
            for (ASTVariableDeclaration declaration : node.descendants(ASTVariableDeclaration.class)) {
                collectVariableFacts(declaration);
            }
            for (ASTAssignmentExpression assignment : node.descendants(ASTAssignmentExpression.class)) {
                collectVariableFacts(assignment);
            }
            for (ASTMethodCallExpression call : node.descendants(ASTMethodCallExpression.class)) {
                if (!Helper.isTestMethodOrClass(call) && isQueryMethodCall(call)) {
                    reportStrings(call, data);
                    reportVariables(call, data);
                }
            }
        } finally {
            // The collected facts are instance state. Clear them even when analysis throws, so
            // they cannot leak into a later visit.
            safeVariables.clear();
            selectContainingVariables.clear();
        }
        return data;
    }

    /** Runs both fact collectors over a declaration or assignment node. */
    private void collectVariableFacts(ApexNode<?> node) {
        findSanitizedVariables(node);
        findSelectContainingVariables(node);
    }

    /** Returns true when {@code m} is {@code Database.query} or {@code Database.countQuery}. */
    private boolean isQueryMethodCall(ASTMethodCallExpression m) {
        return Helper.isMethodName(m, DATABASE, QUERY) || Helper.isMethodName(m, DATABASE, COUNT_QUERY);
    }

    /** Returns true when {@code typeName} (case-insensitively) is a safe type; false for null. */
    private boolean isSafeVariableType(String typeName) {
        return typeName != null && SAFE_VARIABLE_TYPES.contains(typeName.toLowerCase(Locale.ROOT));
    }

    /** Marks method parameters whose declared type is safe as safe variables. */
    private void findSafeVariablesInSignature(ASTMethod m) {
        for (ASTParameter parameter : m.children(ASTParameter.class)) {
            if (isSafeVariableType(parameter.getType())) {
                safeVariables.add(Helper.getFQVariableName(parameter));
            }
        }
    }

    /**
     * Records facts about the variable that is the first variable-expression child of
     * {@code node}.
     *
     * <p>The variable becomes safe when it is assigned an integer, double or boolean literal, a
     * non-SELECT string literal, or a {@code String.escapeSingleQuotes(...)} result. It also
     * becomes safe when it is declared with a safe type. A SELECT string literal records it as
     * {@code TRUE} in {@link #selectContainingVariables}.
     */
    private void findSanitizedVariables(ApexNode<?> node) {
        ASTVariableExpression left = node.firstChild(ASTVariableExpression.class);
        if (left == null) {
            // Every fact below is about this variable; without one there is nothing to record.
            return;
        }
        String name = Helper.getFQVariableName(left);

        ASTLiteralExpression literal = node.firstChild(ASTLiteralExpression.class);
        if (literal != null) {
            if (literal.isInteger() || literal.isBoolean() || literal.isDouble()) {
                safeVariables.add(name);
            }
            if (literal.isString()) {
                if (SELECT_PATTERN.matcher(literal.getImage()).matches()) {
                    selectContainingVariables.put(name, Boolean.TRUE);
                } else {
                    safeVariables.add(name);
                }
            }
        }

        ASTMethodCallExpression right = node.firstChild(ASTMethodCallExpression.class);
        if (right != null && Helper.isMethodName(right, STRING, ESCAPE_SINGLE_QUOTES)) {
            safeVariables.add(name);
        }

        if (node instanceof ASTVariableDeclaration
                && isSafeVariableType(((ASTVariableDeclaration) node).getType())) {
            safeVariables.add(name);
        }
    }

    /** When {@code node} assigns a binary (concatenation) expression to a variable, classifies it. */
    private void findSelectContainingVariables(ApexNode<?> node) {
        ASTVariableExpression left = node.firstChild(ASTVariableExpression.class);
        ASTBinaryExpression right = node.firstChild(ASTBinaryExpression.class);
        if (left != null && right != null) {
            recursivelyCheckForSelect(left, right);
        }
    }

    /**
     * Walks a binary expression assigned to {@code var}, innermost nested binary expression first.
     *
     * <p>A level counts as safe when its first variable operand is a known safe variable, or its
     * first method-call operand is {@code String.escapeSingleQuotes}. The outcome depends on the
     * level's literal operand:
     * <ul>
     *   <li>A SELECT literal with a safe operand marks {@code var} safe.</li>
     *   <li>A SELECT literal with an unsafe operand records {@code var} as {@code FALSE}.</li>
     *   <li>No literal and an unsafe operand records {@code var} as {@code FALSE}.</li>
     *   <li>A non-SELECT literal changes nothing.</li>
     * </ul>
     *
     * <p>NOTE(review): only the FIRST direct variable and method-call operands are inspected, so
     * {@code safeVar + unsafeVar} is treated as safe. Confirm that this is intended.
     */
    private void recursivelyCheckForSelect(ASTVariableExpression var, ASTBinaryExpression node) {
        ASTBinaryExpression nested = node.firstChild(ASTBinaryExpression.class);
        if (nested != null) {
            recursivelyCheckForSelect(var, nested);
        }

        ASTVariableExpression concatenatedVar = node.firstChild(ASTVariableExpression.class);
        ASTMethodCallExpression methodCall = node.firstChild(ASTMethodCallExpression.class);
        boolean isSafeOperand =
                (concatenatedVar != null && safeVariables.contains(Helper.getFQVariableName(concatenatedVar)))
                || (methodCall != null && Helper.isMethodName(methodCall, STRING, ESCAPE_SINGLE_QUOTES));

        String varName = Helper.getFQVariableName(var);
        ASTLiteralExpression literal = node.firstChild(ASTLiteralExpression.class);
        if (literal == null) {
            if (!isSafeOperand) {
                selectContainingVariables.put(varName, Boolean.FALSE);
            }
        } else if (literal.isString() && SELECT_PATTERN.matcher(literal.getImage()).matches()) {
            if (isSafeOperand) {
                safeVariables.add(varName);
            } else {
                // A SELECT literal concatenated with an unsafe value.
                selectContainingVariables.put(varName, Boolean.FALSE);
            }
        }
    }

    /**
     * Reports every variable inside a concatenation passed directly to the query call {@code m}
     * that is not known to be safe.
     */
    private void reportStrings(ASTMethodCallExpression m, Object data) {
        // Variables appearing inside standard conditions within the call are exempt.
        Set<ASTVariableExpression> conditionVars = new HashSet<>();
        for (ASTStandardCondition condition : m.descendants(ASTStandardCondition.class)) {
            conditionVars.addAll(condition.descendants(ASTVariableExpression.class).toList());
        }
        for (ASTBinaryExpression concatenation : m.children(ASTBinaryExpression.class)) {
            for (ASTVariableExpression v : concatenation.descendants(ASTVariableExpression.class)) {
                if (!isSafeInQueryString(v, conditionVars)) {
                    asCtx(data).addViolation(v);
                }
            }
        }
    }

    /**
     * Returns true when {@code v} may appear in a query string concatenation. That is the case
     * when it satisfies any of the following:
     * <ul>
     *   <li>it holds a plain SELECT literal;</li>
     *   <li>it appears in a condition;</li>
     *   <li>it is a known safe variable;</li>
     *   <li>its nearest enclosing method call is {@code String.escapeSingleQuotes} or
     *       {@code String.join}.</li>
     * </ul>
     */
    private boolean isSafeInQueryString(ASTVariableExpression v, Set<ASTVariableExpression> conditionVars) {
        String fqName = Helper.getFQVariableName(v);
        if (Boolean.TRUE.equals(selectContainingVariables.get(fqName))) {
            return true;
        }
        if (conditionVars.contains(v) || safeVariables.contains(fqName)) {
            return true;
        }
        ASTMethodCallExpression parentCall = v.ancestors(ASTMethodCallExpression.class).first();
        return Helper.isMethodName(parentCall, STRING, ESCAPE_SINGLE_QUOTES)
                || Helper.isMethodName(parentCall, STRING, JOIN);
    }

    /**
     * Reports the first variable argument of the query call {@code m} when it was recorded as
     * holding a SELECT string with unsafe values concatenated in.
     */
    private void reportVariables(ASTMethodCallExpression m, Object data) {
        ASTVariableExpression var = m.firstChild(ASTVariableExpression.class);
        if (var != null && Boolean.FALSE.equals(selectContainingVariables.get(Helper.getFQVariableName(var)))) {
            asCtx(data).addViolation(var);
        }
    }
}

