/*
 * Decompiled with CFR 0.152.
 */
package net.truej.sql.compiler;

import com.sun.source.util.JavacTask;
import com.sun.source.util.Plugin;
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskListener;
import com.sun.tools.javac.api.BasicJavacTask;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symtab;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.TypeTag;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.TreeMaker;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.JCDiagnostic;
import com.sun.tools.javac.util.List;
import com.sun.tools.javac.util.Name;
import com.sun.tools.javac.util.Names;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.RecordComponent;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import net.truej.sql.TrueSql;
import net.truej.sql.bindings.NullParameter;
import net.truej.sql.bindings.Standard;
import net.truej.sql.compiler.CompilerMessages;
import net.truej.sql.compiler.InvocationsFinder;
import net.truej.sql.compiler.JdbcMetadataFetcher;
import net.truej.sql.compiler.StatementGenerator;
import net.truej.sql.compiler.TypeChecker;
import net.truej.sql.fetch.As;
import net.truej.sql.fetch.NewConstraint;
import net.truej.sql.fetch.NoUpdateCount;
import net.truej.sql.fetch.Parameters;
import net.truej.sql.fetch.Q;
import net.truej.sql.fetch.UpdateCount;
import org.jetbrains.annotations.Nullable;

public class TrueSqlPlugin
implements Plugin {
    public static final String NAME = "TrueSql";
    HashSet<Symbol.MethodSymbol> trueSqlDslMethods = null;

    static java.util.List<Object> doofyEncode(final Object value) {
        if (value == null) {
            return null;
        }
        Class<?> cl = value.getClass();
        if (java.util.List.class.isAssignableFrom(cl)) {
            ArrayList<Object> result = new ArrayList<Object>();
            result.add(cl.getName());
            for (Object obj : (java.util.List)value) {
                result.add(TrueSqlPlugin.doofyEncode(obj));
            }
            return result;
        }
        if (cl.isRecord()) {
            ArrayList<Object> result = new ArrayList<Object>();
            result.add(cl.getName());
            for (RecordComponent rc : cl.getRecordComponents()) {
                try {
                    result.add(TrueSqlPlugin.doofyEncode(rc.getAccessor().invoke(value, new Object[0])));
                }
                catch (IllegalAccessException | InvocationTargetException e) {
                    throw new RuntimeException(e);
                }
            }
            return result;
        }
        if (cl.isEnum()) {
            return java.util.List.of(cl.getName(), ((Enum)value).name());
        }
        return new ArrayList<Object>(){
            {
                this.add(null);
                this.add(value);
            }
        };
    }

    static Object doofyDecode(Object in) {
        if (in == null) {
            return null;
        }
        java.util.List value = (java.util.List)in;
        try {
            Object className = value.get(0);
            if (className == null) {
                return value.get(1);
            }
            Class<?> cl = Class.forName((String)className);
            if (java.util.List.class.isAssignableFrom(cl)) {
                return value.stream().skip(1L).map(TrueSqlPlugin::doofyDecode).toList();
            }
            if (cl.isEnum()) {
                return Enum.valueOf(cl, (String)value.get(1));
            }
            return cl.getConstructors()[0].newInstance(value.stream().skip(1L).map(TrueSqlPlugin::doofyDecode).toArray());
        }
        catch (ClassNotFoundException | IllegalAccessException | InstantiationException | InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public String getName() {
        return NAME;
    }

    static String arrayClassNameToSourceCodeType(String className) {
        return switch (className.charAt(0)) {
            case '[' -> TrueSqlPlugin.arrayClassNameToSourceCodeType(className.substring(1)) + "[]";
            case 'B' -> "byte";
            case 'C' -> "char";
            case 'S' -> "short";
            case 'I' -> "int";
            case 'J' -> "long";
            case 'F' -> "float";
            case 'D' -> "double";
            case 'Z' -> "boolean";
            default -> className.substring(1, className.length() - 1);
        };
    }

    static String classNameToSourceCodeType(String className) {
        return className.startsWith("[") ? TrueSqlPlugin.arrayClassNameToSourceCodeType(className) : className;
    }

    static String arrayTypeToClassName(Type type) {
        if (type instanceof Type.ArrayType) {
            Type.ArrayType at = (Type.ArrayType)type;
            return "[" + TrueSqlPlugin.arrayTypeToClassName(at.elemtype);
        }
        return switch (type.getTag()) {
            case TypeTag.BYTE -> "B";
            case TypeTag.CHAR -> "C";
            case TypeTag.SHORT -> "S";
            case TypeTag.INT -> "I";
            case TypeTag.LONG -> "J";
            case TypeTag.FLOAT -> "F";
            case TypeTag.DOUBLE -> "D";
            case TypeTag.BOOLEAN -> "Z";
            default -> "L" + String.valueOf(type.tsym.flatName()) + ";";
        };
    }

    static String typeToClassName(Type type) {
        String string;
        if (type instanceof Type.ArrayType) {
            Type.ArrayType at = (Type.ArrayType)type;
            string = TrueSqlPlugin.arrayTypeToClassName(at);
        } else {
            string = type.tsym.flatName().toString();
        }
        return string;
    }

    void checkParameters(Symtab symtab, JCTree.JCMethodInvocation tree, FetchInvocation invocation) {
        if (invocation.parametersMetadata == null) {
            return;
        }
        static interface ParameterChecker {
            public void check(int var1, Type var2, ParameterMode var3);
        }
        ParameterChecker checkParameter = (pIndex, javaParameterType, javaParameterMode) -> {
            JdbcMetadataFetcher.SqlParameterMetadata pMetadata = invocation.parametersMetadata.get(pIndex);
            if (!invocation.onDatabase.equals("PostgreSQL") && pMetadata.mode() != javaParameterMode) {
                throw new ValidationException(tree, "for parameter " + (pIndex + 1) + " expected mode " + String.valueOf((Object)pMetadata.mode()) + " but has " + String.valueOf((Object)javaParameterMode));
            }
            if (javaParameterType != symtab.botType) {
                Standard.Binding binding = TypeChecker.getBindingForClass(tree, invocation.bindings, true, TrueSqlPlugin.typeToClassName(javaParameterType));
                TypeChecker.assertTypesCompatible(invocation.onDatabase, pMetadata.sqlType(), pMetadata.sqlTypeName(), pMetadata.javaClassName(), pMetadata.scale(), binding, (typeKind, expected, has) -> new ValidationException(tree, typeKind + " mismatch for parameter " + (pIndex + 1) + ". Expected " + has + " but has " + expected));
            }
        };
        int pIndex2 = 0;
        block10: for (InvocationsFinder.QueryPart part : invocation.query.parts()) {
            InvocationsFinder.QueryPart queryPart;
            Objects.requireNonNull(part);
            int n = 0;
            switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{InvocationsFinder.TextPart.class, InvocationsFinder.InOrInoutParameter.class, InvocationsFinder.OutParameter.class, InvocationsFinder.UnfoldParameter.class}, (Object)queryPart, n)) {
                default: {
                    throw new MatchException(null, null);
                }
                case 0: {
                    InvocationsFinder.TextPart __ = (InvocationsFinder.TextPart)queryPart;
                    continue block10;
                }
                case 1: {
                    InvocationsFinder.InOrInoutParameter inOrInoutParameter;
                    InvocationsFinder.InOrInoutParameter p = (InvocationsFinder.InOrInoutParameter)queryPart;
                    Type type = p.expression().type;
                    Objects.requireNonNull(p);
                    int n2 = 0;
                    checkParameter.check(pIndex2, type, switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{InvocationsFinder.InParameter.class, InvocationsFinder.InoutParameter.class}, (Object)inOrInoutParameter, n2)) {
                        default -> throw new MatchException(null, null);
                        case 0 -> {
                            InvocationsFinder.InParameter __ = (InvocationsFinder.InParameter)inOrInoutParameter;
                            yield ParameterMode.IN;
                        }
                        case 1 -> {
                            InvocationsFinder.InoutParameter __ = (InvocationsFinder.InoutParameter)inOrInoutParameter;
                            yield ParameterMode.INOUT;
                        }
                    });
                    ++pIndex2;
                    continue block10;
                }
                case 2: {
                    InvocationsFinder.OutParameter p = (InvocationsFinder.OutParameter)queryPart;
                    checkParameter.check(pIndex2, p.toType(), ParameterMode.OUT);
                    ++pIndex2;
                    continue block10;
                }
                case 3: 
            }
            InvocationsFinder.UnfoldParameter p = (InvocationsFinder.UnfoldParameter)queryPart;
            int argumentCount = StatementGenerator.unfoldArgumentsCount(p.extractor());
            if (p.extractor() == null) {
                checkParameter.check(pIndex2, (Type)p.expression().type.getTypeArguments().getFirst(), ParameterMode.IN);
            } else {
                for (int i = 0; i < argumentCount; ++i) {
                    checkParameter.check(pIndex2 + i, ((JCTree.JCNewArray)p.extractor().body).elems.get((int)i).type, ParameterMode.IN);
                }
            }
            pIndex2 += argumentCount;
        }
    }

    static Class<?> primitiveTypeToBoxedClass(Type.JCPrimitiveType pt) {
        return switch (pt.getTag()) {
            case TypeTag.BYTE -> Byte.class;
            case TypeTag.CHAR -> Character.class;
            case TypeTag.SHORT -> Short.class;
            case TypeTag.LONG -> Long.class;
            case TypeTag.FLOAT -> Float.class;
            case TypeTag.INT -> Integer.class;
            case TypeTag.DOUBLE -> Double.class;
            default -> Boolean.class;
        };
    }

    static Type boxType(Names names, Symtab symtab, Type t) {
        Function<Class, Type> boxed = cl -> symtab.getClass((Symbol.ModuleSymbol)symtab.java_base, (Name)names.fromString((String)cl.getName())).type;
        if (t instanceof Type.JCPrimitiveType) {
            Type.JCPrimitiveType pt = (Type.JCPrimitiveType)t;
            return boxed.apply(TrueSqlPlugin.primitiveTypeToBoxedClass(pt));
        }
        return t;
    }

    void handle(Symtab symtab, Names names, TreeMaker maker, JCTree.JCMethodInvocation tree, FetchInvocation invocation) {
        this.checkParameters(symtab, tree, invocation);
        Symbol.ClassSymbol clParameterExtractor = symtab.getClass(symtab.java_base, names.fromString(Function.class.getName()));
        Symbol.ClassSymbol clGenerated = symtab.getClass(symtab.unnamedModule, names.fromString(invocation.generatedClassName));
        Symbol.MethodSymbol mtGenerated = (Symbol.MethodSymbol)clGenerated.members().getSymbolsByName(names.fromString(invocation.fetchMethodName + "__line" + invocation.lineNumber + "__")).iterator().next();
        tree.meth = maker.Select((JCTree.JCExpression)maker.Ident(clGenerated), mtGenerated);
        tree.args = List.nil();
        Function<Type, JCTree.JCExpression> createRwFor = type -> {
            Symbol.ClassSymbol rwClassSymbol;
            if (type == symtab.botType) {
                rwClassSymbol = symtab.getClass(symtab.unnamedModule, names.fromString(NullParameter.class.getName()));
            } else {
                String forClassName = TrueSqlPlugin.typeToClassName(type);
                Standard.Binding binding = TypeChecker.getBindingForClass(tree, invocation.bindings, true, forClassName);
                rwClassSymbol = symtab.getClass(symtab.unnamedModule, names.fromString(binding.rwClassName()));
            }
            Symbol.MethodSymbol rwClassConstructor = (Symbol.MethodSymbol)rwClassSymbol.members().getSymbols(sym -> {
                if (!(sym instanceof Symbol.MethodSymbol)) return false;
                Symbol.MethodSymbol m = (Symbol.MethodSymbol)sym;
                if (!m.name.equals(names.fromString("<init>"))) return false;
                if (!m.params.isEmpty()) return false;
                return true;
            }).iterator().next();
            JCTree.JCNewClass newClass = maker.NewClass(null, List.nil(), maker.Ident(rwClassSymbol), List.nil(), null);
            newClass.type = new Type.ClassType(Type.noType, List.nil(), rwClassSymbol);
            newClass.constructor = rwClassConstructor;
            newClass.constructor.type = rwClassConstructor.type;
            return newClass;
        };
        StatementGenerator.Query query = invocation.query;
        Objects.requireNonNull(query);
        StatementGenerator.Query query2 = query;
        int n = 0;
        switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{StatementGenerator.BatchedQuery.class, StatementGenerator.SingleQuery.class}, (Object)query2, n)) {
            default: {
                throw new MatchException(null, null);
            }
            case 0: {
                StatementGenerator.BatchedQuery bq = (StatementGenerator.BatchedQuery)query2;
                tree.args = tree.args.append(bq.listDataExpression());
                for (InvocationsFinder.QueryPart part : bq.parts()) {
                    InvocationsFinder.QueryPart queryPart;
                    Objects.requireNonNull(part);
                    int n2 = 0;
                    switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{InvocationsFinder.InParameter.class}, (Object)queryPart, n2)) {
                        case 0: {
                            InvocationsFinder.InParameter p = (InvocationsFinder.InParameter)queryPart;
                            JCTree.JCLambda extractor = new JCTree.JCLambda(List.of((JCTree.JCVariableDecl)bq.extractor().params.head), p.expression());
                            Type.ClassType extractorType = new Type.ClassType(Type.noType, List.of(((JCTree.JCVariableDecl)bq.extractor().params.head).type, TrueSqlPlugin.boxType(names, symtab, p.expression().type)), clParameterExtractor);
                            extractor.type = extractorType;
                            extractor.target = extractorType;
                            tree.args = tree.args.append(extractor);
                            tree.args = tree.args.append(createRwFor.apply(p.expression().type));
                            break;
                        }
                    }
                }
                break;
            }
            case 1: {
                StatementGenerator.SingleQuery sq = (StatementGenerator.SingleQuery)query2;
                block14: for (InvocationsFinder.QueryPart part : sq.parts()) {
                    InvocationsFinder.QueryPart queryPart;
                    Objects.requireNonNull(part);
                    int n3 = 0;
                    switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{InvocationsFinder.InOrInoutParameter.class, InvocationsFinder.OutParameter.class, InvocationsFinder.UnfoldParameter.class, InvocationsFinder.TextPart.class}, (Object)queryPart, n3)) {
                        default: {
                            throw new MatchException(null, null);
                        }
                        case 0: {
                            InvocationsFinder.InOrInoutParameter p = (InvocationsFinder.InOrInoutParameter)queryPart;
                            tree.args = tree.args.append(p.expression());
                            tree.args = tree.args.append(createRwFor.apply(p.expression().type));
                            break;
                        }
                        case 1: {
                            InvocationsFinder.OutParameter p = (InvocationsFinder.OutParameter)queryPart;
                            tree.args = tree.args.append(createRwFor.apply(p.toType()));
                            break;
                        }
                        case 2: {
                            InvocationsFinder.UnfoldParameter p = (InvocationsFinder.UnfoldParameter)queryPart;
                            int n4 = StatementGenerator.unfoldArgumentsCount(p.extractor());
                            Type unfoldType = (Type)p.expression().type.allparams().head;
                            tree.args = tree.args.append(p.expression());
                            if (p.extractor() == null) {
                                tree.args = tree.args.append(createRwFor.apply(unfoldType));
                                break;
                            }
                            for (int i = 0; i < n4; ++i) {
                                JCTree.JCExpression partExpression = ((JCTree.JCNewArray)p.extractor().body).elems.get(i);
                                JCTree.JCLambda extractor = new JCTree.JCLambda(List.of((JCTree.JCVariableDecl)p.extractor().params.head), partExpression);
                                Type.ClassType extractorType = new Type.ClassType(Type.noType, List.of(((JCTree.JCVariableDecl)p.extractor().params.head).type, TrueSqlPlugin.boxType(names, symtab, partExpression.type)), clParameterExtractor);
                                extractor.type = extractorType;
                                extractor.target = extractorType;
                                tree.args = tree.args.append(extractor);
                                tree.args = tree.args.append(createRwFor.apply(partExpression.type));
                            }
                            continue block14;
                        }
                        case 3: {
                            InvocationsFinder.TextPart textPart = (InvocationsFinder.TextPart)queryPart;
                        }
                    }
                }
            }
        }
        tree.args = tree.args.append(invocation.sourceExpression);
    }

    void addSymbol(HashSet<Symbol.MethodSymbol> to, Symbol.ClassSymbol clSym) {
        clSym.members().getSymbols().forEach(sym -> {
            if (sym instanceof Symbol.MethodSymbol) {
                Symbol.MethodSymbol mt = (Symbol.MethodSymbol)sym;
                to.add(mt);
            }
            if (sym instanceof Symbol.ClassSymbol) {
                Symbol.ClassSymbol cl = (Symbol.ClassSymbol)sym;
                this.addSymbol(to, cl);
            }
        });
    }

    HashSet<Symbol.MethodSymbol> parseDslMethods(Symtab symtab, Names names, Class<?> ... classes) {
        HashSet<Symbol.MethodSymbol> to = new HashSet<Symbol.MethodSymbol>();
        for (Class<?> aClass : classes) {
            this.addSymbol(to, symtab.enterClass(symtab.unnamedModule, names.fromString(aClass.getName())));
        }
        return to;
    }

    boolean isTrueSqlDslInvocation(Symtab symtab, Names names, JCTree.JCMethodInvocation tree) {
        Symbol.MethodSymbol mt;
        Symbol symbol;
        JCTree.JCExpression jCExpression;
        if (this.trueSqlDslMethods == null) {
            this.trueSqlDslMethods = this.parseDslMethods(symtab, names, As.class, Q.class, NewConstraint.class, Parameters.class, UpdateCount.class, NoUpdateCount.class);
        }
        if ((jCExpression = tree.meth) instanceof JCTree.JCIdent) {
            JCTree.JCIdent id = (JCTree.JCIdent)jCExpression;
            symbol = id.sym;
        } else {
            symbol = ((JCTree.JCFieldAccess)tree.meth).sym;
        }
        return symbol instanceof Symbol.MethodSymbol && this.trueSqlDslMethods.contains(mt = (Symbol.MethodSymbol)symbol);
    }

    void assertHasNoDanglingTrueSqlCalls(final Symtab symtab, final Names names, JCTree.JCCompilationUnit cu) {
        final boolean[] hasTrueSqlAnnotation = new boolean[]{false};
        cu.accept(new TreeScanner(this){

            @Override
            public void visitClassDef(JCTree.JCClassDecl tree) {
                if (tree.type.tsym.getAnnotation(TrueSql.class) != null) {
                    hasTrueSqlAnnotation[0] = true;
                }
                super.visitClassDef(tree);
            }
        });
        cu.accept(new TreeScanner(){

            @Override
            public void visitApply(JCTree.JCMethodInvocation tree) {
                block4: {
                    block5: {
                        super.visitApply(tree);
                        if (!TrueSqlPlugin.this.isTrueSqlDslInvocation(symtab, names, tree)) break block4;
                        JCTree.JCExpression jCExpression = tree.meth;
                        if (!(jCExpression instanceof JCTree.JCFieldAccess)) break block5;
                        JCTree.JCFieldAccess fa = (JCTree.JCFieldAccess)jCExpression;
                        if (fa.name.contentEquals("constraint")) break block4;
                    }
                    if (!hasTrueSqlAnnotation[0]) {
                        throw new ValidationException(tree, "TrueSql DSL used but class not annotated with @TrueSql");
                    }
                    throw new ValidationException(tree, "Incorrect TrueSql DSL usage - dangling call");
                }
            }
        });
    }

    @Override
    public void init(final JavacTask task, String ... args) {
        task.addTaskListener(new TaskListener(){

            @Override
            public void finished(TaskEvent e) {
                if (e.getKind() == TaskEvent.Kind.ANALYZE) {
                    HashMap cuTrees;
                    Context context = ((BasicJavacTask)task).getContext();
                    Symtab symtab = Symtab.instance(context);
                    Names names = Names.instance(context);
                    TreeMaker maker = TreeMaker.instance(context);
                    CompilerMessages messages = new CompilerMessages(context);
                    JCTree.JCCompilationUnit cu = (JCTree.JCCompilationUnit)e.getCompilationUnit();
                    HashMap dataFromAnnotationProcessor = context.get(HashMap.class);
                    boolean hasNoErrors = true;
                    Consumer<ValidationError> sendCompilationError = error -> messages.write((JCTree.JCCompilationUnit)e.getCompilationUnit(), error.tree, JCDiagnostic.DiagnosticType.ERROR, error.message);
                    if (dataFromAnnotationProcessor != null && (cuTrees = (HashMap)dataFromAnnotationProcessor.get(cu)) != null) {
                        for (Map.Entry kv : cuTrees.entrySet()) {
                            JCTree.JCMethodInvocation tree = (JCTree.JCMethodInvocation)kv.getKey();
                            MethodInvocationResult invocation = (MethodInvocationResult)TrueSqlPlugin.doofyDecode(kv.getValue());
                            if (!TrueSqlPlugin.this.isTrueSqlDslInvocation(symtab, names, tree)) continue;
                            try {
                                MethodInvocationResult methodInvocationResult;
                                Objects.requireNonNull(invocation);
                                int n = 0;
                                switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{FetchInvocation.class, ValidationError.class}, (Object)methodInvocationResult, n)) {
                                    default: {
                                        throw new MatchException(null, null);
                                    }
                                    case 0: {
                                        FetchInvocation fi = (FetchInvocation)methodInvocationResult;
                                        for (CompilationWarning warning : fi.warnings) {
                                            messages.write(cu, warning.tree, JCDiagnostic.DiagnosticType.WARNING, warning.message);
                                        }
                                        TrueSqlPlugin.this.handle(symtab, names, maker, tree, fi);
                                        break;
                                    }
                                    case 1: {
                                        ValidationError error2 = (ValidationError)methodInvocationResult;
                                        hasNoErrors = false;
                                        sendCompilationError.accept(error2);
                                        break;
                                    }
                                }
                            }
                            catch (ValidationException ex) {
                                hasNoErrors = false;
                                sendCompilationError.accept(new ValidationError(ex.tree, ex.getMessage()));
                            }
                        }
                    }
                    try {
                        if (hasNoErrors) {
                            TrueSqlPlugin.this.assertHasNoDanglingTrueSqlCalls(symtab, names, cu);
                        }
                    }
                    catch (ValidationException ex) {
                        sendCompilationError.accept(new ValidationError(ex.tree, ex.getMessage()));
                    }
                }
            }
        });
    }

    public record FetchInvocation(String onDatabase, java.util.List<Standard.Binding> bindings, String generatedClassName, String fetchMethodName, int lineNumber, java.util.List<CompilationWarning> warnings, JCTree.JCIdent sourceExpression, StatementGenerator.Query query, @Nullable java.util.List<JdbcMetadataFetcher.SqlParameterMetadata> parametersMetadata, String generatedCode) implements MethodInvocationResult
    {
    }

    static enum ParameterMode {
        IN,
        INOUT,
        OUT,
        UNKNOWN;

    }

    public static class ValidationException
    extends RuntimeException {
        final JCTree tree;

        public ValidationException(JCTree tree, String message) {
            super(message);
            this.tree = tree;
        }
    }

    public record CompilationWarning(JCTree tree, String message) {
    }

    public record ValidationError(JCTree tree, String message) implements MethodInvocationResult
    {
    }

    public static sealed interface MethodInvocationResult
    permits ValidationError, FetchInvocation {
    }
}

