/*
 * Decompiled with CFR 0.152.
 */
package de.fraunhofer.aisec.cpg.passes;

import de.fraunhofer.aisec.cpg.TranslationResult;
import de.fraunhofer.aisec.cpg.frontends.java.JavaLanguageFrontend;
import de.fraunhofer.aisec.cpg.graph.CallExpression;
import de.fraunhofer.aisec.cpg.graph.ConstructExpression;
import de.fraunhofer.aisec.cpg.graph.ConstructorDeclaration;
import de.fraunhofer.aisec.cpg.graph.DeclaredReferenceExpression;
import de.fraunhofer.aisec.cpg.graph.ExplicitConstructorInvocation;
import de.fraunhofer.aisec.cpg.graph.Expression;
import de.fraunhofer.aisec.cpg.graph.FunctionDeclaration;
import de.fraunhofer.aisec.cpg.graph.HasType;
import de.fraunhofer.aisec.cpg.graph.MemberCallExpression;
import de.fraunhofer.aisec.cpg.graph.MethodDeclaration;
import de.fraunhofer.aisec.cpg.graph.NewExpression;
import de.fraunhofer.aisec.cpg.graph.Node;
import de.fraunhofer.aisec.cpg.graph.NodeBuilder;
import de.fraunhofer.aisec.cpg.graph.ParamVariableDeclaration;
import de.fraunhofer.aisec.cpg.graph.RecordDeclaration;
import de.fraunhofer.aisec.cpg.graph.StaticCallExpression;
import de.fraunhofer.aisec.cpg.graph.TranslationUnitDeclaration;
import de.fraunhofer.aisec.cpg.graph.ValueDeclaration;
import de.fraunhofer.aisec.cpg.graph.VariableDeclaration;
import de.fraunhofer.aisec.cpg.graph.type.FunctionPointerType;
import de.fraunhofer.aisec.cpg.graph.type.Type;
import de.fraunhofer.aisec.cpg.graph.type.TypeParser;
import de.fraunhofer.aisec.cpg.helpers.SubgraphWalker;
import de.fraunhofer.aisec.cpg.helpers.Util;
import de.fraunhofer.aisec.cpg.passes.Pass;
import de.fraunhofer.aisec.cpg.processing.strategy.Strategy;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Resolves call targets in the graph: links {@link CallExpression}s to the
 * {@link FunctionDeclaration}s they may invoke, links {@link ConstructExpression}s to the
 * {@link RecordDeclaration} they instantiate and its matching {@link ConstructorDeclaration}, and
 * resolves {@link ExplicitConstructorInvocation}s.
 *
 * <p>{@link #accept} walks every translation unit three times: first to collect declarations,
 * index records by name and remember which record declares each method; then to normalize
 * initializers into {@link ConstructExpression}s; and finally to resolve calls. Where no declared
 * target matches, implicit dummy declarations may be created (for function pointer targets and
 * static imports).
 */
public class CallResolver
extends Pass {
    private static final Logger LOGGER = LoggerFactory.getLogger(CallResolver.class);

    /** Matches the base name of a Java super call: {@code super} or {@code Some.Qualifier.super}. */
    private static final Pattern SUPER_CALL_PATTERN = Pattern.compile("(?<class>.+\\.)?super");

    /** All records of the translation, keyed by name; the first record seen for a name wins. */
    private final Map<String, RecordDeclaration> recordMap = new HashMap<>();
    /** Maps each method declaration to the type of the record that declares it. */
    private final Map<FunctionDeclaration, Type> containingType = new HashMap<>();
    /** The translation unit currently being resolved; used to look up non-member functions. */
    private @Nullable TranslationUnitDeclaration currentTU;
    private SubgraphWalker.ScopedWalker walker;

    /**
     * Releases per-run state so that a later run does not resolve against records or
     * method-to-type mappings left over from a previous translation.
     */
    @Override
    public void cleanup() {
        this.recordMap.clear();
        this.containingType.clear();
        this.currentTU = null;
    }

    /**
     * Runs the three resolution phases over all translation units of {@code translationResult}.
     *
     * @param translationResult the translation whose calls are resolved in place
     */
    @Override
    public void accept(@NonNull TranslationResult translationResult) {
        this.walker = new SubgraphWalker.ScopedWalker();
        // Phase 1: collect scoped declarations, index records and remember method owners.
        this.walker.registerHandler((currClass, parent, currNode) -> this.walker.collectDeclarations((Node)currNode));
        this.walker.registerHandler(this::findRecords);
        this.walker.registerHandler(this::registerMethods);
        this.walkAllTranslationUnits(translationResult);
        // Phase 2: runs only after phase 1 indexed the records of every translation unit.
        this.walker.clearCallbacks();
        this.walker.registerHandler(this::fixInitializers);
        this.walkAllTranslationUnits(translationResult);
        // Phase 3: resolve calls and constructor invocations.
        this.walker.clearCallbacks();
        this.walker.registerHandler(this::resolve);
        this.walkAllTranslationUnits(translationResult);
    }

    /** Runs the walker's currently registered handlers over every translation unit. */
    private void walkAllTranslationUnits(@NonNull TranslationResult translationResult) {
        for (TranslationUnitDeclaration tu : translationResult.getTranslationUnits()) {
            this.walker.iterate(tu);
        }
    }

    /** Indexes record declarations by name; an already indexed name is not overwritten. */
    private void findRecords(@NonNull Node node, RecordDeclaration curClass) {
        if (node instanceof RecordDeclaration) {
            this.recordMap.putIfAbsent(node.getName(), (RecordDeclaration)node);
        }
    }

    /** Remembers the type of the enclosing record for every method declaration. */
    private void registerMethods(RecordDeclaration currentClass, Node parent, @NonNull Node currentNode) {
        if (currentNode instanceof MethodDeclaration && currentClass != null) {
            this.containingType.put((FunctionDeclaration)currentNode, TypeParser.createFrom(currentClass.getName(), true));
        }
    }

    /**
     * Normalizes constructor-like initializers into implicit {@link ConstructExpression}s.
     *
     * <ul>
     *   <li>A variable whose type root names a known record, that has no initializer and allows an
     *       implicit one, receives an argument-less construct expression.
     *   <li>A variable of a known record type initialized by a call named like that type has the
     *       call replaced by a construct expression carrying the call's arguments; the call node is
     *       disconnected from the graph.
     *   <li>A {@link NewExpression} without initializer receives an argument-less construct
     *       expression.
     * </ul>
     */
    private void fixInitializers(@NonNull Node node, RecordDeclaration curClass) {
        if (node instanceof VariableDeclaration) {
            VariableDeclaration declaration = (VariableDeclaration)node;
            String typeString = declaration.getType().getRoot().getName();
            if (!this.recordMap.containsKey(typeString)) {
                return;
            }
            Expression currInitializer = declaration.getInitializer();
            if (currInitializer == null && declaration.isImplicitInitializerAllowed()) {
                ConstructExpression initializer = NodeBuilder.newConstructExpression("()");
                initializer.setImplicit(true);
                declaration.setInitializer(initializer);
            } else if (currInitializer instanceof CallExpression && currInitializer.getName().equals(typeString)) {
                CallExpression call = (CallExpression)currInitializer;
                List<Expression> arguments = call.getArguments();
                String signature = arguments.stream().map(Node::getCode).collect(Collectors.joining(", "));
                ConstructExpression initializer = NodeBuilder.newConstructExpression("(" + signature + ")");
                initializer.setArguments(new ArrayList<Expression>(arguments));
                initializer.setImplicit(true);
                declaration.setInitializer(initializer);
                currInitializer.disconnectFromGraph();
            }
        } else if (node instanceof NewExpression) {
            NewExpression newExpression = (NewExpression)node;
            if (newExpression.getInitializer() == null) {
                ConstructExpression initializer = NodeBuilder.newConstructExpression("()");
                initializer.setImplicit(true);
                newExpression.setInitializer(initializer);
            }
        }
    }

    /**
     * Resolves a Java {@code super.m()} / {@code X.super.m()} call: points the call's base at the
     * target record's {@code this} and resolves the method against that record.
     *
     * @param curClass the record containing the call; if {@code null} the call cannot be resolved
     * @param call a call whose base is a {@link DeclaredReferenceExpression} named like a super call
     */
    private void handleSuperCall(RecordDeclaration curClass, CallExpression call) {
        if (curClass == null) {
            // Without an enclosing record there is no superclass to resolve against.
            Util.warnWithFileLocation(call, LOGGER, "super call outside of a record declaration, cannot resolve", new Object[0]);
            return;
        }
        RecordDeclaration target = null;
        if (call.getBase().getName().equals("super")) {
            if (!curClass.getSuperClasses().isEmpty()) {
                target = this.recordMap.get(curClass.getSuperClasses().get(0).getTypeName());
            } else {
                Util.warnWithFileLocation(call, LOGGER, "super call without direct superclass! Expected java.lang.Object to be present at least!", new Object[0]);
            }
        } else {
            target = this.handleSpecificSupertype(curClass, call);
        }
        if (target != null) {
            ((DeclaredReferenceExpression)call.getBase()).setRefersTo(target.getThis());
            this.handleMethodCall(target, call);
        }
    }

    /**
     * Determines the target record of a qualified {@code X.super} call: {@code X} itself if it is
     * an interface implemented by {@code curClass}, otherwise the first superclass of record
     * {@code X}.
     *
     * @return the target record, or {@code null} if it is unknown
     */
    private RecordDeclaration handleSpecificSupertype(RecordDeclaration curClass, CallExpression call) {
        String qualifiedName = call.getBase().getName();
        String baseName = qualifiedName.substring(0, qualifiedName.lastIndexOf(".super"));
        if (curClass.getImplementedInterfaces().contains(TypeParser.createFrom(baseName, true))) {
            return this.recordMap.get(baseName);
        }
        RecordDeclaration base = this.recordMap.get(baseName);
        if (base != null) {
            if (!base.getSuperClasses().isEmpty()) {
                return this.recordMap.get(base.getSuperClasses().get(0).getTypeName());
            }
            Util.warnWithFileLocation(call, LOGGER, "super call without direct superclass! Expected java.lang.Object to be present at least!", new Object[0]);
        }
        return null;
    }

    /** Phase-3 walker callback: dispatches on the node kind. */
    private void resolve(@NonNull Node node, RecordDeclaration curClass) {
        if (node instanceof TranslationUnitDeclaration) {
            this.currentTU = (TranslationUnitDeclaration)node;
        } else if (node instanceof ExplicitConstructorInvocation) {
            this.resolveExplicitConstructorInvocation((ExplicitConstructorInvocation)node);
        } else if (node instanceof CallExpression) {
            CallExpression call = (CallExpression)node;
            this.resolveArguments(call, curClass);
            this.handleCallExpression(curClass, call);
        } else if (node instanceof ConstructExpression) {
            this.resolveConstructExpression((ConstructExpression)node);
        }
    }

    /**
     * Chooses the resolution strategy for a call: Java super calls, calls through a member of
     * function pointer type, calls through a function-pointer variable in scope, or normal calls.
     */
    private void handleCallExpression(RecordDeclaration curClass, CallExpression call) {
        if (this.lang instanceof JavaLanguageFrontend
                && call.getBase() instanceof DeclaredReferenceExpression
                && SUPER_CALL_PATTERN.matcher(call.getBase().getName()).matches()) {
            this.handleSuperCall(curClass, call);
            return;
        }
        if (call instanceof MemberCallExpression) {
            Node member = ((MemberCallExpression)call).getMember();
            if (member instanceof HasType && ((HasType)member).getType() instanceof FunctionPointerType) {
                this.handleFunctionPointerCall(call, member);
                return;
            }
        }
        Optional<? extends ValueDeclaration> funcPointer = this.walker.getDeclarationForScope(call, v -> v.getType() instanceof FunctionPointerType && v.getName().equals(call.getName()));
        if (funcPointer.isPresent()) {
            this.handleFunctionPointerCall(call, funcPointer.get());
        } else {
            this.handleNormalCalls(curClass, call);
        }
    }

    /**
     * Resolves call expressions nested anywhere inside the arguments of {@code call} (without
     * descending into nested record declarations) before the call itself is handled.
     * NOTE(review): the walker presumably visits those nested calls again later, so they may be
     * resolved more than once — confirm this is intended.
     */
    private void resolveArguments(CallExpression call, RecordDeclaration curClass) {
        ArrayDeque<Node> worklist = new ArrayDeque<>();
        call.getArguments().forEach(worklist::push);
        while (!worklist.isEmpty()) {
            Node curr = worklist.pop();
            if (curr instanceof CallExpression) {
                this.resolve(curr, curClass);
                continue;
            }
            Iterator<Node> it = Strategy.AST_FORWARD(curr);
            while (it.hasNext()) {
                Node astChild = it.next();
                if (!(astChild instanceof RecordDeclaration)) {
                    worklist.push(astChild);
                }
            }
        }
    }

    /**
     * Resolves a call that is neither a super call nor a function-pointer call. Outside of records
     * it matches functions of the current translation unit by name and signature; inside records
     * it tries static imports first and then method resolution.
     */
    private void handleNormalCalls(RecordDeclaration curClass, CallExpression call) {
        if (curClass == null && this.currentTU != null) {
            List<FunctionDeclaration> invocationCandidates = this.currentTU.getDeclarations().stream()
                    .filter(FunctionDeclaration.class::isInstance)
                    .map(FunctionDeclaration.class::cast)
                    .filter(f -> f.getName().equals(call.getName()) && f.hasSignature(call.getSignature()))
                    .collect(Collectors.toList());
            call.setInvokes(invocationCandidates);
        } else if (!this.handlePossibleStaticImport(call, curClass)) {
            this.handleMethodCall(curClass, call);
        }
    }

    /**
     * Resolves a method call. Starting from the call's current {@code invokes}, it keeps their
     * overriders declared in one of the possible containing types. If that yields nothing, the
     * unqualified method name is looked up in those records and, recursively, their supertypes.
     * Plain (non-member, non-static) calls inside a record get the record's {@code this} as base.
     */
    private void handleMethodCall(RecordDeclaration curClass, CallExpression call) {
        Set<Type> possibleContainingTypes = this.getPossibleContainingTypes(call, curClass);
        List<FunctionDeclaration> invocationCandidates = call.getInvokes().stream()
                .map(f -> this.getOverridingCandidates(possibleContainingTypes, (FunctionDeclaration)f))
                .flatMap(Collection::stream)
                .collect(Collectors.toList());
        if (invocationCandidates.isEmpty()) {
            String[] nameParts = call.getName().split("\\.");
            if (nameParts.length > 0) {
                List<Type> signature = call.getSignature();
                Set<RecordDeclaration> records = possibleContainingTypes.stream()
                        .map(t -> this.recordMap.get(t.getTypeName()))
                        .filter(Objects::nonNull)
                        .collect(Collectors.toSet());
                invocationCandidates = this.getInvocationCandidatesFromParents(nameParts[nameParts.length - 1], signature, records);
            }
        }
        if (curClass != null && !(call instanceof MemberCallExpression) && !(call instanceof StaticCallExpression)) {
            call.setBase(curClass.getThis());
        }
        call.setInvokes(invocationCandidates);
    }

    /**
     * Links a construct expression to the record named by its type and, for records that have
     * source code, to the constructor matching the expression's signature.
     */
    private void resolveConstructExpression(ConstructExpression constructExpression) {
        List<Type> signature = constructExpression.getSignature();
        String typeName = constructExpression.getType().getTypeName();
        RecordDeclaration record = this.recordMap.get(typeName);
        constructExpression.setInstantiates(record);
        if (record != null && record.getCode() != null && !record.getCode().isEmpty()) {
            ConstructorDeclaration constructor = this.getConstructorDeclaration(signature, record);
            if (constructor != null) {
                constructExpression.setConstructor(constructor);
            } else {
                LOGGER.warn("Unexpected: Could not find constructor for {} with signature {}", record.getName(), signature);
            }
        }
    }

    /**
     * Resolves a call through a function pointer by following the data flow backwards
     * ({@code prevDFG}) from {@code pointer} until function declarations are reached.
     *
     * <p>Declarations whose signature matches the call become targets. For an implicit declaration
     * with a mismatching signature, a dummy with the call's signature is used instead. If the last
     * reference seen on the path referred to the original declaration, it is redirected to the
     * dummy. Explicit declarations with a mismatching signature are ignored.
     *
     * @param call the call to resolve
     * @param pointer the node holding the function pointer (a member or a variable declaration)
     */
    private void handleFunctionPointerCall(CallExpression call, Node pointer) {
        List<Type> signature = call.getSignature();
        List<FunctionDeclaration> invocationCandidates = new ArrayList<>();
        ArrayDeque<Node> worklist = new ArrayDeque<>();
        // Identity semantics: visit each graph node once, regardless of its equals().
        Set<Node> seen = Collections.newSetFromMap(new IdentityHashMap<Node, Boolean>());
        worklist.push(pointer);
        DeclaredReferenceExpression finalReference = null;
        while (!worklist.isEmpty()) {
            Node curr = worklist.pop();
            if (!seen.add(curr)) {
                continue;
            }
            if (curr instanceof FunctionDeclaration) {
                FunctionDeclaration function = (FunctionDeclaration)curr;
                if (function.hasSignature(signature)) {
                    invocationCandidates.add(function);
                } else if (function.isImplicit()) {
                    FunctionDeclaration dummy = this.createDummyWithMatchingSignature(function, signature);
                    invocationCandidates.add(dummy);
                    if (finalReference != null && finalReference.getRefersTo().contains(function)) {
                        finalReference.setRefersTo(dummy);
                    }
                }
                // A function declaration ends this branch of the backwards search.
                continue;
            }
            if (curr instanceof DeclaredReferenceExpression) {
                finalReference = (DeclaredReferenceExpression)curr;
            }
            curr.getPrevDFG().forEach(worklist::push);
        }
        call.setInvokes(invocationCandidates);
    }

    /**
     * Resolves an explicit constructor invocation to the constructor of its containing class that
     * matches the argument types; sets an empty target list if none matches.
     */
    private void resolveExplicitConstructorInvocation(ExplicitConstructorInvocation eci) {
        if (eci.getContainingClass() == null) {
            return;
        }
        RecordDeclaration record = this.recordMap.get(eci.getContainingClass());
        List<Type> signature = eci.getArguments().stream().map(Expression::getType).collect(Collectors.toList());
        if (record != null) {
            ConstructorDeclaration constructor = this.getConstructorDeclaration(signature, record);
            List<FunctionDeclaration> invokes = new ArrayList<>();
            if (constructor != null) {
                invokes.add(constructor);
            }
            eci.setInvokes(invokes);
        }
    }

    /**
     * Tries to resolve {@code call} against the static imports of {@code curClass}. Matching is by
     * the last segment of the call name. If imported functions match by name but none by
     * signature, dummies are generated in the imported records.
     *
     * @return {@code true} if the call was resolved through static imports
     */
    private boolean handlePossibleStaticImport(@Nullable CallExpression call, RecordDeclaration curClass) {
        if (call == null || curClass == null) {
            return false;
        }
        String name = call.getName().substring(call.getName().lastIndexOf('.') + 1);
        List<FunctionDeclaration> nameMatches = curClass.getStaticImports().stream()
                .filter(FunctionDeclaration.class::isInstance)
                .map(FunctionDeclaration.class::cast)
                .filter(f -> f.getName().equals(name) || f.getName().endsWith("." + name))
                .collect(Collectors.toList());
        if (nameMatches.isEmpty()) {
            return false;
        }
        List<FunctionDeclaration> invokes = new ArrayList<>();
        FunctionDeclaration target = nameMatches.stream()
                .filter(f -> f.hasSignature(call.getSignature()))
                .findFirst()
                .orElse(null);
        if (target == null) {
            this.generateStaticImportDummies(call, name, invokes, curClass);
        } else {
            invokes.add(target);
        }
        call.setInvokes(invokes);
        return true;
    }

    /**
     * Creates an implicit static method dummy named {@code name}, with the call's signature, in
     * every known record from which {@code curClass} statically imports {@code name}. Each dummy
     * is added to the record's methods, to {@code curClass}'s static imports and to
     * {@code invokes}.
     */
    private void generateStaticImportDummies(@NonNull CallExpression call, @NonNull String name, @NonNull List<FunctionDeclaration> invokes, RecordDeclaration curClass) {
        if (curClass == null) {
            LOGGER.warn("Cannot generate dummies for imports of a null class: {}", call.toString());
            return;
        }
        List<RecordDeclaration> containingRecords = curClass.getStaticImportStatements().stream()
                .filter(i -> i.endsWith("." + name))
                .map(i -> i.substring(0, i.lastIndexOf('.')))
                .map(this.recordMap::get)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        for (RecordDeclaration record : containingRecords) {
            MethodDeclaration dummy = NodeBuilder.newMethodDeclaration(name, "", true, record);
            dummy.setImplicit(true);
            List<ParamVariableDeclaration> params = Util.createParameters(call.getSignature());
            dummy.setParameters(params);
            record.getMethods().add(dummy);
            curClass.getStaticImports().add(dummy);
            invokes.add(dummy);
        }
    }

    /**
     * Looks for an existing declaration with the template's name and the given signature: among
     * the methods of the template's record for methods, otherwise among the functions of the
     * current translation unit.
     */
    private Optional<FunctionDeclaration> checkExistingDummies(FunctionDeclaration template, List<Type> signature) {
        if (template instanceof MethodDeclaration) {
            RecordDeclaration owner = ((MethodDeclaration)template).getRecordDeclaration();
            if (owner != null) {
                return owner.getMethods().stream()
                        .filter(m -> m.getName().equals(template.getName()) && m.hasSignature(signature))
                        .map(FunctionDeclaration.class::cast)
                        .findFirst();
            }
        }
        if (this.currentTU == null) {
            LOGGER.error("No current translation unit when trying to find matching dummy for {}", template);
            return Optional.empty();
        }
        return this.currentTU.getDeclarations().stream()
                .filter(FunctionDeclaration.class::isInstance)
                .map(FunctionDeclaration.class::cast)
                .filter(f -> f.getName().equals(template.getName()) && f.hasSignature(signature))
                .findFirst();
    }

    /**
     * Returns a declaration shaped like {@code template} but with {@code signature}, reusing an
     * existing one when possible. A new implicit dummy is added to the template's record, or to the
     * current translation unit when there is no record.
     */
    private FunctionDeclaration createDummyWithMatchingSignature(FunctionDeclaration template, List<Type> signature) {
        Optional<FunctionDeclaration> existing = this.checkExistingDummies(template, signature);
        if (existing.isPresent()) {
            return existing.get();
        }
        List<ParamVariableDeclaration> parameters = Util.createParameters(signature);
        if (template instanceof MethodDeclaration) {
            RecordDeclaration containingRecord = ((MethodDeclaration)template).getRecordDeclaration();
            MethodDeclaration dummy = NodeBuilder.newMethodDeclaration(template.getName(), template.getCode(), ((MethodDeclaration)template).isStatic(), containingRecord);
            dummy.setImplicit(true);
            dummy.setParameters(parameters);
            if (containingRecord != null) {
                containingRecord.getMethods().add(dummy);
            } else if (this.currentTU == null) {
                LOGGER.error("No current translation unit when trying to generate method dummy {}", dummy.getName());
            } else {
                this.currentTU.getDeclarations().add(dummy);
            }
            return dummy;
        }
        FunctionDeclaration dummy = NodeBuilder.newFunctionDeclaration(template.getName(), template.getCode());
        dummy.setParameters(parameters);
        dummy.setImplicit(true);
        if (this.currentTU == null) {
            LOGGER.error("No current translation unit when trying to generate function dummy {}", dummy.getName());
        } else {
            this.currentTU.getDeclarations().add(dummy);
        }
        return dummy;
    }

    /**
     * Collects the types that may declare the method targeted by {@code node}:
     * <ul>
     *   <li>for member calls, the base's type and its possible subtypes;
     *   <li>for static calls, the target record's type;
     *   <li>for other calls inside a record, that record's type and its supertypes.
     * </ul>
     */
    private Set<Type> getPossibleContainingTypes(Node node, RecordDeclaration curClass) {
        Set<Type> possibleTypes = new HashSet<>();
        if (node instanceof MemberCallExpression) {
            MemberCallExpression memberCall = (MemberCallExpression)node;
            if (memberCall.getBase() instanceof HasType) {
                HasType base = (HasType)memberCall.getBase();
                possibleTypes.add(base.getType());
                possibleTypes.addAll(base.getPossibleSubTypes());
            }
        } else if (node instanceof StaticCallExpression) {
            StaticCallExpression staticCall = (StaticCallExpression)node;
            if (staticCall.getTargetRecord() != null) {
                possibleTypes.add(TypeParser.createFrom(staticCall.getTargetRecord(), true));
            }
        } else if (curClass != null) {
            possibleTypes.add(TypeParser.createFrom(curClass.getName(), true));
            possibleTypes.addAll(curClass.getSuperTypes());
        }
        return possibleTypes;
    }

    /** Returns methods of {@code record} named {@code name} (optionally record-qualified) with the given signature. */
    private List<FunctionDeclaration> getInvocationCandidatesFromRecord(RecordDeclaration record, String name, List<Type> signature) {
        Pattern namePattern = Pattern.compile("(" + Pattern.quote(record.getName()) + "\\.)?" + Pattern.quote(name));
        return record.getMethods().stream()
                .filter(m -> namePattern.matcher(m.getName()).matches() && m.hasSignature(signature))
                .map(FunctionDeclaration.class::cast)
                .collect(Collectors.toList());
    }

    /**
     * Searches {@code possibleTypes} for matching methods. If none matches at this level, the
     * search recurses into the supertypes of every given record.
     */
    private List<FunctionDeclaration> getInvocationCandidatesFromParents(String name, List<Type> signature, Set<RecordDeclaration> possibleTypes) {
        if (possibleTypes.isEmpty()) {
            return new ArrayList<>();
        }
        List<FunctionDeclaration> firstLevelCandidates = possibleTypes.stream()
                .map(r -> this.getInvocationCandidatesFromRecord(r, name, signature))
                .flatMap(Collection::stream)
                .collect(Collectors.toList());
        if (!firstLevelCandidates.isEmpty()) {
            return firstLevelCandidates;
        }
        return possibleTypes.stream()
                .map(RecordDeclaration::getSuperTypeDeclarations)
                .map(superTypes -> this.getInvocationCandidatesFromParents(name, signature, (Set<RecordDeclaration>)superTypes))
                .flatMap(Collection::stream)
                .collect(Collectors.toList());
    }

    /**
     * Returns the functions overriding {@code declaration} whose declaring type is in
     * {@code possibleSubTypes}. {@code declaration} itself is not part of the result.
     */
    private Set<FunctionDeclaration> getOverridingCandidates(Set<Type> possibleSubTypes, FunctionDeclaration declaration) {
        return declaration.getOverriddenBy().stream()
                .filter(f -> possibleSubTypes.contains(this.containingType.get(f)))
                .collect(Collectors.toSet());
    }

    /** Returns the first constructor of {@code record} matching {@code signature}, or {@code null}. */
    private @Nullable ConstructorDeclaration getConstructorDeclaration(List<Type> signature, RecordDeclaration record) {
        return record.getConstructors().stream().filter(f -> f.hasSignature(signature)).findFirst().orElse(null);
    }
}

