package tech.harmonysoft.oss.traute.javac.common;

import com.sun.source.tree.AnnotationTree;
import com.sun.source.tree.BlockTree;
import com.sun.source.tree.CaseTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.DoWhileLoopTree;
import com.sun.source.tree.EnhancedForLoopTree;
import com.sun.source.tree.ForLoopTree;
import com.sun.source.tree.IfTree;
import com.sun.source.tree.ImportTree;
import com.sun.source.tree.LambdaExpressionTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ModifiersTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.VariableTree;
import com.sun.source.tree.WhileLoopTree;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.TreeMaker;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.TreeSet;
import java.util.function.Supplier;
import javax.lang.model.element.Modifier;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import tech.harmonysoft.oss.traute.common.instrumentation.InstrumentationType;
import tech.harmonysoft.oss.traute.common.util.TrauteConstants;
import tech.harmonysoft.oss.traute.javac.instrumentation.Instrumentator;
import tech.harmonysoft.oss.traute.javac.instrumentation.method.ReturnToInstrumentInfo;
import tech.harmonysoft.oss.traute.javac.instrumentation.parameter.ParameterToInstrumentInfo;

/**
 * AST scanner which finds the places in a single compilation unit where null-checks should be
 * applied and hands them to the configured instrumentators.
 *
 * <p>While scanning it:
 * <ul>
 *   <li>records non-static imports into the {@link CompilationUnitProcessingContext}, so that
 *       short annotation names can be resolved against them;</li>
 *   <li>for every method with a body, passes each reference-typed parameter marked by a configured
 *       not-null annotation to the parameter instrumentator (if parameter checks are enabled);</li>
 *   <li>for every method whose declaration carries a configured not-null annotation, passes each
 *       of that method's {@code return} statements to the return instrumentator (if return checks
 *       are enabled).</li>
 * </ul>
 *
 * <p>Per-class and per-method state is saved on entry and restored on exit, so local and anonymous
 * classes nested inside a method do not affect processing of the rest of the enclosing method.
 * {@code return} statements inside lambda bodies are not instrumented, because they return from
 * the lambda, not from the enclosing method.
 *
 * <p>Instances keep mutable scan state in fields without synchronization and should not be shared
 * between threads.
 */
public class InstrumentationApplianceFinder extends TreeScanner<Void, Void> {

    /** Prefix of the synthetic local variable names generated for instrumented return statements. */
    private static final String TMP_VARIABLE_PREFIX = "tmpTrauteVar";

    /**
     * Statement-level trees (blocks, if/loop/case statements) enclosing the node being visited;
     * the first element is the innermost one.
     */
    private final Deque<Tree> parents = new ArrayDeque<>();

    @NotNull
    private final CompilationUnitProcessingContext context;

    @NotNull
    private final Instrumentator<ParameterToInstrumentInfo> parameterInstrumenter;

    @NotNull
    private final Instrumentator<ReturnToInstrumentInfo> returnInstrumenter;

    /** Package of the current compilation unit, {@code null} for the unnamed package. */
    private String packageName;
    /** Simple name of the innermost class being visited. */
    private String className;
    /** Name of the innermost method being visited, {@code null} outside of methods. */
    private String methodName;
    /** Declared return type of the current method when its returns are to be instrumented. */
    private JCTree.JCExpression methodReturnType;
    /** Not-null annotation found on the current method declaration, if any. */
    private String methodNotNullAnnotation;
    /**
     * Source of suffixes for synthetic variable names. It is never reset, so generated names stay
     * unique even between a method and the methods of classes nested inside it.
     */
    private int tmpVariableCounter;
    /** Whether the innermost class being visited is an interface or an annotation type. */
    private boolean processingInterface;
    /** Whether {@code return} statements met at the current position should be instrumented. */
    private boolean instrumentReturnExpression;

    /**
     * @param context               processing context of the compilation unit being scanned
     * @param parameterInstrumenter instrumentator applied to not-null method parameters
     * @param returnInstrumenter    instrumentator applied to return statements of not-null methods
     */
    public InstrumentationApplianceFinder(@NotNull CompilationUnitProcessingContext context,
                                          @NotNull Instrumentator<ParameterToInstrumentInfo> parameterInstrumenter,
                                          @NotNull Instrumentator<ReturnToInstrumentInfo> returnInstrumenter) {
        this.context = context;
        this.parameterInstrumenter = parameterInstrumenter;
        this.returnInstrumenter = returnInstrumenter;
    }

    @Override
    public Void visitCompilationUnit(CompilationUnitTree compilationUnitTree, Void p) {
        // getPackageName() is null for a compilation unit in the unnamed package.
        Tree packageTree = compilationUnitTree.getPackageName();
        packageName = packageTree == null ? null : packageTree.toString();
        return super.visitCompilationUnit(compilationUnitTree, p);
    }

    @Override
    public Void visitClass(ClassTree classTree, Void p) {
        String outerClassName = className;
        boolean outerProcessingInterface = processingInterface;
        className = classTree.getSimpleName().toString();
        Tree.Kind kind = classTree.getKind();
        // Annotation types are interfaces as well (javac marks them with the INTERFACE flag).
        processingInterface = kind == Tree.Kind.INTERFACE || kind == Tree.Kind.ANNOTATION_TYPE;
        try {
            return super.visitClass(classTree, p);
        } finally {
            // Restore the enclosing class state so that nested classes don't leak into it.
            className = outerClassName;
            processingInterface = outerProcessingInterface;
        }
    }

    /** Remembers non-static imports; the import subtree itself holds nothing to instrument. */
    @Override
    public Void visitImport(ImportTree importTree, Void p) {
        if (!importTree.isStatic()) {
            context.addImport(importTree.getQualifiedIdentifier().toString());
        }
        return p;
    }

    @Override
    public Void visitMethod(MethodTree methodTree, Void p) {
        String outerMethodName = methodName;
        JCTree.JCExpression outerReturnType = methodReturnType;
        String outerNotNullAnnotation = methodNotNullAnnotation;
        boolean outerInstrumentReturn = instrumentReturnExpression;
        try {
            methodName = methodTree.getName().toString();
            // mayBeInstrumentReturnType() only assigns these when the method qualifies, so clear
            // them first to avoid inheriting values from an enclosing method.
            methodReturnType = null;
            methodNotNullAnnotation = null;
            instrumentReturnExpression = shouldInstrumentReturnExpression(methodTree);
            if (shouldInstrumentMethodParameters(methodTree)) {
                JCTree.JCBlock methodBody = getMethodBody(methodTree);
                if (methodBody != null) {
                    instrumentMethodParameters(methodTree, methodBody);
                }
            }
            return super.visitMethod(methodTree, p);
        } finally {
            methodName = outerMethodName;
            methodReturnType = outerReturnType;
            methodNotNullAnnotation = outerNotNullAnnotation;
            instrumentReturnExpression = outerInstrumentReturn;
        }
    }

    @Override
    public Void visitLambdaExpression(LambdaExpressionTree lambdaTree, Void p) {
        boolean outerInstrumentReturn = instrumentReturnExpression;
        // A 'return' inside a lambda body returns from the lambda, so the enclosing method's
        // not-null contract and return type do not apply to it.
        instrumentReturnExpression = false;
        try {
            return super.visitLambdaExpression(lambdaTree, p);
        } finally {
            instrumentReturnExpression = outerInstrumentReturn;
        }
    }

    /**
     * Interface members are only processed when they have a body (default, static or private
     * methods); abstract interface methods have nothing to instrument.
     */
    private boolean isInstrumentableMember(@NotNull MethodTree methodTree) {
        return !processingInterface || methodTree.getBody() != null;
    }

    /**
     * Checks whether return statements of the given method should be instrumented.
     *
     * <p>Side effect: when the method qualifies, {@link #mayBeInstrumentReturnType(MethodTree)}
     * stores its return type and not-null annotation; the short-circuit order matters.
     */
    private boolean shouldInstrumentReturnExpression(@NotNull MethodTree methodTree) {
        return isInstrumentableMember(methodTree)
               && context.getPluginSettings().isEnabled(InstrumentationType.METHOD_RETURN)
               && mayBeInstrumentReturnType(methodTree);
    }

    private boolean shouldInstrumentMethodParameters(@NotNull MethodTree methodTree) {
        return isInstrumentableMember(methodTree)
               && context.getPluginSettings().isEnabled(InstrumentationType.METHOD_PARAMETER);
    }

    /** Returns whether the given modifiers are present and contain the given modifier. */
    private static boolean hasFlag(@Nullable ModifiersTree modifiersTree, @NotNull Modifier modifier) {
        if (modifiersTree == null) {
            return false;
        }
        Set<Modifier> flags = modifiersTree.getFlags();
        return flags != null && flags.contains(modifier);
    }

    /**
     * Returns the body of the given method as a javac block, or {@code null} if the method is
     * abstract, has no body, or its body is not a {@link JCTree.JCBlock} (the latter is logged).
     */
    @Nullable
    private JCTree.JCBlock getMethodBody(@NotNull MethodTree methodTree) {
        if (hasFlag(methodTree.getModifiers(), Modifier.ABSTRACT)) {
            return null;
        }
        BlockTree body = methodTree.getBody();
        if (body == null) {
            return null;
        }
        if (body instanceof JCTree.JCBlock) {
            return (JCTree.JCBlock) body;
        }
        context.getLogger().reportDetails(String.format(
                "Expected to get a %s instance in the method AST but got %s",
                JCTree.JCBlock.class.getName(), body.getClass().getName()));
        return null;
    }

    /**
     * Passes every reference-typed parameter of the given method that is marked by a configured
     * not-null annotation to the parameter instrumentator.
     *
     * @param methodTree method whose parameters are examined
     * @param methodBody body of that method, into which the checks are to be applied
     */
    private void instrumentMethodParameters(@NotNull MethodTree methodTree, @NotNull JCTree.JCBlock methodBody) {
        // Parameters are handed over in descending index order. NOTE(review): presumably because
        // each check is inserted at the start of the body, so this yields checks in declaration
        // order - confirm against the instrumentator.
        TreeSet<ParameterToInstrumentInfo> toInstrument = new TreeSet<>(
                (first, second) -> Integer.compare(second.getMethodParameterIndex(), first.getMethodParameterIndex()));
        List<? extends VariableTree> parameters = methodTree.getParameters();
        int parametersNumber = parameters.size();
        String qualifiedMethodName = getQualifiedMethodName();
        for (int index = 0; index < parametersNumber; index++) {
            VariableTree parameter = parameters.get(index);
            if (parameter == null) {
                continue;
            }
            Tree type = parameter.getType();
            if (type != null && TrauteConstants.PRIMITIVE_TYPES.contains(type.toString())) {
                // Primitives can't be null. The index still advances so that it always reflects
                // the parameter's real position in the signature.
                continue;
            }
            Optional<String> annotation = findNotNullAnnotation(parameter.getModifiers());
            if (annotation.isPresent()) {
                toInstrument.add(new ParameterToInstrumentInfo(context, annotation.get(), parameter, methodBody,
                                                               qualifiedMethodName, index, parametersNumber));
            }
        }
        for (ParameterToInstrumentInfo info : toInstrument) {
            mayBeSetPosition(info.getMethodParameter(), context.getAstFactory());
            parameterInstrumenter.instrument(info);
        }
    }

    /**
     * Checks whether the method's return type may carry a not-null contract (it is a non-skipped
     * javac expression and the method has a configured not-null annotation). On success, stores
     * the return type and annotation for use by {@link #visitReturn(ReturnTree, Void)}.
     */
    private boolean mayBeInstrumentReturnType(@NotNull MethodTree methodTree) {
        Tree returnType = methodTree.getReturnType();
        if (!(returnType instanceof JCTree.JCExpression)
            || TrauteConstants.METHOD_RETURN_TYPES_TO_SKIP.contains(returnType.toString())) {
            return false;
        }
        Optional<String> annotation = findNotNullAnnotation(methodTree.getModifiers());
        if (!annotation.isPresent()) {
            return false;
        }
        methodNotNullAnnotation = annotation.get();
        methodReturnType = (JCTree.JCExpression) returnType;
        return true;
    }

    /** Returns a fresh synthetic variable name, unique within this scanner. */
    @NotNull
    private String getTmpVariableName() {
        return TMP_VARIABLE_PREFIX + (++tmpVariableCounter);
    }

    /** Moves the AST factory's position to the given tree's source position if it is a javac tree. */
    private void mayBeSetPosition(@NotNull Tree tree, @NotNull TreeMaker treeMaker) {
        if (tree instanceof JCTree) {
            treeMaker.at(((JCTree) tree).pos);
        }
    }

    /**
     * Finds a configured not-null annotation among the given modifiers.
     *
     * @return the resolved annotation name, or empty if there is none
     */
    @NotNull
    private Optional<String> findNotNullAnnotation(@Nullable ModifiersTree modifiersTree) {
        if (modifiersTree == null) {
            return Optional.empty();
        }
        List<? extends AnnotationTree> annotations = modifiersTree.getAnnotations();
        if (annotations == null) {
            return Optional.empty();
        }
        // Keep declaration order so that the result is deterministic when several annotations match.
        Set<String> annotationNames = new LinkedHashSet<>();
        for (AnnotationTree annotation : annotations) {
            if (annotation == null) {
                continue;
            }
            Tree annotationType = annotation.getAnnotationType();
            if (annotationType != null) {
                annotationNames.add(annotationType.toString());
            }
        }
        return findMatch(annotationNames);
    }

    /**
     * Resolves the given annotation names (as written in source) against the configured not-null
     * annotations. Each name is tried as written, then relative to the current package, then via
     * the recorded imports (wildcard or exact).
     *
     * @return the fully qualified name of the first match, or empty if none matches
     */
    @NotNull
    private Optional<String> findMatch(@NotNull Collection<String> collection) {
        Set<String> notNullAnnotations = context.getPluginSettings().getNotNullAnnotations();
        for (String name : collection) {
            if (notNullAnnotations.contains(name)) {
                return Optional.of(name);
            }
            if (packageName != null) {
                String samePackageName = packageName + "." + name;
                if (notNullAnnotations.contains(samePackageName)) {
                    return Optional.of(samePackageName);
                }
            }
            String dottedName = "." + name;
            for (String anImport : context.getImports()) {
                if (anImport.endsWith(".*")) {
                    String candidate = anImport.substring(0, anImport.length() - 1) + name;
                    if (notNullAnnotations.contains(candidate)) {
                        return Optional.of(candidate);
                    }
                } else if (anImport.endsWith(dottedName) && notNullAnnotations.contains(anImport)) {
                    // Require a '.' boundary: an import of 'x.MyNotNull' must not match 'NotNull'.
                    return Optional.of(anImport);
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Returns {@code [package.]Class.method} for the current method, or {@code null} when the
     * class or method name is unknown.
     */
    @Nullable
    private String getQualifiedMethodName() {
        if (className == null || methodName == null) {
            return null;
        }
        String packagePrefix = packageName == null ? "" : packageName + ".";
        return packagePrefix + className + "." + methodName;
    }

    /** Runs the given scan with {@code parent} registered as the innermost enclosing statement tree. */
    private Void scanWithParent(@NotNull Tree parent, @NotNull Supplier<Void> scan) {
        parents.push(parent);
        try {
            return scan.get();
        } finally {
            parents.pop();
        }
    }

    @Override
    public Void visitBlock(BlockTree blockTree, Void p) {
        return scanWithParent(blockTree, () -> super.visitBlock(blockTree, p));
    }

    @Override
    public Void visitIf(IfTree ifTree, Void p) {
        return scanWithParent(ifTree, () -> super.visitIf(ifTree, p));
    }

    @Override
    public Void visitForLoop(ForLoopTree forLoopTree, Void p) {
        return scanWithParent(forLoopTree, () -> super.visitForLoop(forLoopTree, p));
    }

    @Override
    public Void visitEnhancedForLoop(EnhancedForLoopTree enhancedForLoopTree, Void p) {
        return scanWithParent(enhancedForLoopTree, () -> super.visitEnhancedForLoop(enhancedForLoopTree, p));
    }

    @Override
    public Void visitWhileLoop(WhileLoopTree whileLoopTree, Void p) {
        return scanWithParent(whileLoopTree, () -> super.visitWhileLoop(whileLoopTree, p));
    }

    @Override
    public Void visitDoWhileLoop(DoWhileLoopTree doWhileLoopTree, Void p) {
        return scanWithParent(doWhileLoopTree, () -> super.visitDoWhileLoop(doWhileLoopTree, p));
    }

    @Override
    public Void visitCase(CaseTree caseTree, Void p) {
        return scanWithParent(caseTree, () -> super.visitCase(caseTree, p));
    }

    /**
     * Hands the return statement to the return instrumentator when the enclosing method (not
     * lambda) has a not-null contract, then continues scanning its expression.
     */
    @Override
    public Void visitReturn(ReturnTree returnTree, Void p) {
        if (instrumentReturnExpression
            && methodNotNullAnnotation != null
            && methodReturnType != null
            && !parents.isEmpty()) {
            mayBeSetPosition(returnTree, context.getAstFactory());
            returnInstrumenter.instrument(new ReturnToInstrumentInfo(context,
                                                                     methodNotNullAnnotation,
                                                                     returnTree,
                                                                     methodReturnType,
                                                                     getTmpVariableName(),
                                                                     parents.peek(),
                                                                     getQualifiedMethodName()));
        }
        return super.visitReturn(returnTree, p);
    }
}
