/*
 * Decompiled with CFR 0.152.
 */
package net.codecrete.windowsapi.writer;

import java.lang.runtime.SwitchBootstraps;
import java.util.Iterator;
import java.util.Objects;
import net.codecrete.windowsapi.metadata.ComInterface;
import net.codecrete.windowsapi.metadata.EnumType;
import net.codecrete.windowsapi.metadata.Method;
import net.codecrete.windowsapi.metadata.Parameter;
import net.codecrete.windowsapi.metadata.Primitive;
import net.codecrete.windowsapi.metadata.PrimitiveKind;
import net.codecrete.windowsapi.metadata.Type;
import net.codecrete.windowsapi.metadata.TypeAlias;
import net.codecrete.windowsapi.writer.AddressLayout;
import net.codecrete.windowsapi.writer.CommentWriter;
import net.codecrete.windowsapi.writer.FunctionCodeWriterBase;
import net.codecrete.windowsapi.writer.GenerationContext;

/**
 * Generates the Java source file for a COM interface.
 * <p>
 * The generated Java interface contains:
 * </p>
 * <ul>
 *     <li>one abstract method per COM method declared by the interface itself
 *     (inherited methods are declared by the generated super interface),</li>
 *     <li>static helper methods ({@code wrap}, {@code addressLayout}, {@code create} and,
 *     if the interface has an IID, {@code iid}),</li>
 *     <li>the memory layouts for the COM object and for pointers to it,</li>
 *     <li>a {@code $DOWNCALL} class implementing the interface by calling through the vtable,</li>
 *     <li>one {@code VFUNC<n>$IMPL} class per own method holding its downcall handle,</li>
 *     <li>{@code $UPCALL} / {@code $UPCALL_IMPL} classes used by {@code create} to expose
 *     a Java object as a COM object.</li>
 * </ul>
 */
class ComInterfaceWriter
        extends FunctionCodeWriterBase<ComInterface> {
    private final CommentWriter commentWriter = new CommentWriter();

    ComInterfaceWriter(GenerationContext generationContext) {
        super(generationContext);
    }

    /**
     * Writes the Java source file for the given COM interface.
     *
     * @param comInterface the COM interface
     */
    void writeComInterface(ComInterface comInterface) {
        String className = toJavaClassName(comInterface.name());
        withFile(comInterface.namespace(), comInterface, className, this::writeComInterfaceContent);
    }

    /**
     * Returns the COM interface currently being generated.
     *
     * @return the COM interface
     */
    private ComInterface currentInterface() {
        return (ComInterface) type;
    }

    /**
     * Writes the content of the generated file: package declaration, imports and the
     * complete Java interface including all nested helper classes.
     * <p>
     * Passed to {@code withFile} as the content callback.
     * </p>
     */
    private void writeComInterfaceContent() {
        var comInterface = currentInterface();
        writer.printf("""
                package %s;

                import java.lang.foreign.*;
                import java.lang.invoke.*;
                import static java.lang.foreign.ValueLayout.*;

                """, packageName);
        writeComInterfaceComment();

        String extendsClause = getExtendsClause(comInterface);
        writer.printf("public interface %s%s {\n", className, extendsClause);

        String[] methodNames = getAllMethodNames(comInterface);
        writeComInterfaceMethods(methodNames);
        writeCommonMethods(methodNames);
        writeAddressLayouts();
        writeIidInnerClass();
        writeDowncallWrapper(methodNames, extendsClause);

        // one downcall handle class per own method, named after its vtable slot
        int methodOffset = getNumSuperMethods(comInterface);
        var ownMethods = comInterface.methods();
        for (int i = 0; i < ownMethods.size(); i++) {
            writeFunctionInnerClass(ownMethods.get(i), "VFUNC" + (methodOffset + i));
        }

        writeUpcallWrapper(methodNames);
        writeUpcallImplementation(methodNames);
        writer.println("}");
    }

    /**
     * Builds the {@code extends} clause of the generated interface.
     *
     * @param comInterface the COM interface
     * @return {@code ""} if the interface does not extend another COM interface, otherwise
     * {@code " extends <Name>"}; the name is package-qualified if the super interface
     * belongs to a different namespace
     */
    private String getExtendsClause(ComInterface comInterface) {
        var implementedInterface = comInterface.implementedInterface();
        if (implementedInterface == null) {
            return "";
        }

        String superName = toJavaClassName(implementedInterface.name());
        // NOTE(review): namespaces are compared by identity; assumes the metadata model
        // uses a single instance per namespace — confirm
        if (comInterface.namespace() != implementedInterface.namespace()) {
            superName = toJavaPackageName(implementedInterface.namespace().name()) + "." + superName;
        }
        return " extends " + superName;
    }

    /**
     * Writes the abstract Java methods for the COM methods declared by this interface.
     *
     * @param methodNames Java names of all methods (including inherited ones), indexed by vtable slot
     */
    private void writeComInterfaceMethods(String[] methodNames) {
        var comInterface = currentInterface();
        int methodOffset = getNumSuperMethods(comInterface);
        var ownMethods = comInterface.methods();
        for (int i = 0; i < ownMethods.size(); i++) {
            var method = ownMethods.get(i);
            commentWriter.writeFunctionComment(writer, method, "COM interface method");
            writer.print("    ");
            writeFunctionSignatureIntro(method, methodNames[methodOffset + i]);
            writeFunctionSignatureParameters(method);
            writer.println(";");
            writer.println();
        }
    }

    /**
     * Writes the struct layout of the COM object (a pointer to a vtable with one address
     * per method, inherited ones included) and the address layout for pointers to it,
     * followed by the additional address layouts required by the interface's methods.
     */
    private void writeAddressLayouts() {
        var comInterface = currentInterface();
        int vtableSize = getNumSuperMethods(comInterface) + comInterface.methods().size();
        writer.printf("""
                    StructLayout %1$s$COM_OBJECT_LAYOUT = MemoryLayout.structLayout(
                        ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(%2$d, ADDRESS)).withName("vtable")
                    );

                    AddressLayout %1$s$ADDRESS_LAYOUT = ADDRESS.withTargetLayout(%1$s$COM_OBJECT_LAYOUT);

                """, className, vtableSize);
        AddressLayout.requiredLayouts(comInterface)
                .forEach(layoutType -> writeAddressLayoutInitialization((AddressLayout) layoutType, ""));
        writer.println();
    }

    /**
     * Writes the nested class holding the interface identifier as a memory segment.
     * Nothing is written if the interface has no IID.
     */
    private void writeIidInnerClass() {
        var iid = currentInterface().getIid();
        if (iid == null) {
            return;
        }

        writer.printf("    class $IID$%1$s {\n", className);
        writer.print("        private static final Arena ARENA = Arena.ofAuto();\n\n");
        writeCreateGuidMethod(8);
        writeGuidConstantMemorySegment("IID", iid, 8);
        writer.print("    }\n\n");
    }

    /**
     * Writes the static helper methods of the generated interface: {@code wrap},
     * {@code addressLayout}, {@code iid} (only if the interface has an IID) and {@code create}.
     *
     * @param methodNames Java names of all methods (including inherited ones); its length
     *                    is the vtable size used by {@code create}
     */
    private void writeCommonMethods(String[] methodNames) {
        var comInterface = currentInterface();
        String javaName = toJavaClassName(comInterface.name());

        writeComment("Wraps the given COM object in a Java object with methods to call the COM interface functions.");
        writer.printf("""
                    static %s wrap(MemorySegment comObject) {
                        return new $DOWNCALL(comObject.reinterpret(8L));
                    }

                """, className);

        writeComment("Gets the address layout for pointers to {@code %s} COM interfaces.", comInterface.name());
        writer.printf("""
                    static AddressLayout addressLayout() {
                        return %1$s$ADDRESS_LAYOUT;
                    }

                """, className);

        var iid = comInterface.getIid();
        if (iid != null) {
            writeComment("Interface identifier (IID) for {@code %s} ({@code {%s}}).", comInterface.name(), iid);
            writer.printf("""
                        static MemorySegment iid() {
                            return $IID$%1$s.IID$SEG;
                        }

                    """, className);
        }

        writeComment("Creates a COM object instance for the given Java object implementing {@code %s}.", javaName);
        writer.printf("""
                    static MemorySegment create(%1$s obj, Arena arena) {
                        var unwrapper = new $UPCALL(obj);
                        var vtable = arena.allocate(ADDRESS, %2$d);
                        var linker = Linker.nativeLinker();
                        for (int i = 0; i < %2$d; i++)
                            vtable.set(ADDRESS, 8 * i, linker.upcallStub($UPCALL_IMPL.HANDLES[i].bindTo(unwrapper), $UPCALL_IMPL.DESCRIPTORS[i], arena));
                        var objSegment = arena.allocate(ADDRESS);
                        objSegment.set(ADDRESS, 0, vtable);
                        return objSegment;
                    }

                """, javaName, methodNames.length);
    }

    /**
     * Writes the {@code $DOWNCALL} class that implements the interface by looking up each
     * function in the COM object's vtable and invoking it.
     *
     * @param methodNames      Java names of all methods (including inherited ones), indexed by vtable slot
     * @param extendsInterface the interface's {@code extends} clause ({@code ""} if none); if present,
     *                         the class extends the super interface's {@code $DOWNCALL} class
     */
    private void writeDowncallWrapper(String[] methodNames, String extendsInterface) {
        var comInterface = currentInterface();
        int methodOffset = getNumSuperMethods(comInterface);
        String extendsSuperClass = extendsInterface.isEmpty() ? "" : extendsInterface + ".$DOWNCALL";

        writer.printf("""
                    class $DOWNCALL%2$s implements %1$s {
                        private static final VarHandle VTABLE_FUNC_VARHANDLE = %1$s$COM_OBJECT_LAYOUT.varHandle(
                                MemoryLayout.PathElement.groupElement("vtable"),
                                MemoryLayout.PathElement.dereferenceElement(),
                                MemoryLayout.PathElement.sequenceElement()
                        );

                        private static MemorySegment vtableFunc(MemorySegment comObject, long index) {
                            return (MemorySegment) VTABLE_FUNC_VARHANDLE.get(comObject, 0L, index);
                        }

                """, className, extendsSuperClass);

        // the comObject field is declared only by the root of the hierarchy
        if (comInterface.implementedInterface() != null) {
            writer.print("""
                            protected $DOWNCALL(MemorySegment comObject) {
                                super(comObject);
                            }

                    """);
        } else {
            writer.print("""
                            protected final MemorySegment comObject;

                            protected $DOWNCALL(MemorySegment comObject) {
                                this.comObject = comObject;
                            }

                    """);
        }

        var ownMethods = comInterface.methods();
        for (int i = 0; i < ownMethods.size(); i++) {
            var method = ownMethods.get(i);
            int methodIndex = methodOffset + i;
            writer.print("        public ");
            writeFunctionSignatureIntro(method, methodNames[methodIndex]);
            writeFunctionSignatureParameters(method);
            writer.println(" {");
            // the COM object itself is passed as the implicit first argument
            String invokeString = "VFUNC" + methodIndex + "$IMPL.HANDLE.invokeExact(vtableFunc(comObject, "
                    + methodIndex + "), comObject" + (method.parameters().length > 0 ? ", " : "");
            writeInvoke(method, invokeString, 12);
            writer.println("        }");
            writer.println();
        }
        writer.println("    }");
        writer.println();
    }

    /**
     * Writes the nested class holding the function descriptor and downcall handle for a COM method.
     *
     * @param method         the COM method
     * @param innerClassName the name prefix of the class ({@code $IMPL} is appended)
     */
    private void writeFunctionInnerClass(Method method, String innerClassName) {
        writer.printf("    class %s$IMPL {\n", innerClassName);
        writer.print("        private static final FunctionDescriptor DESC = ");
        writeFunctionDescriptor(method, className + "$ADDRESS_LAYOUT");
        writer.println(";");
        writer.print("""
                        private static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(DESC);
                    }

                """);
    }

    /**
     * Writes the {@code $UPCALL} class, which adapts upcalls (receiving the COM object pointer as
     * first argument) to calls of the corresponding method of the wrapped Java object.
     *
     * @param methodNames Java names of all methods (including inherited ones), indexed by vtable slot
     */
    private void writeUpcallWrapper(String[] methodNames) {
        Method[] methods = getAllMethods(currentInterface());
        writer.printf("""
                    class $UPCALL {
                        private final %1$s javaObject;

                        private $UPCALL(%1$s javaObject) {
                            this.javaObject = javaObject;
                        }

                """, className);

        for (int i = 0; i < methodNames.length; i++) {
            var method = methods[i];
            var parameters = method.parameters();

            writer.print("        ");
            writeFunctionSignatureIntro(method, methodNames[i]);
            writer.print("MemorySegment thisObject");
            if (parameters.length > 0) {
                writer.print(", ");
            }
            writeFunctionSignatureParameters(method);
            writer.println(" {");

            writer.print("            ");
            if (method.hasReturnType()) {
                writer.print("return ");
            }
            writer.print("javaObject.");
            writer.print(methodNames[i]);
            writer.print("(");
            for (int j = 0; j < parameters.length; j++) {
                if (j > 0) {
                    writer.print(", ");
                }
                writer.print(getJavaSafeName(parameters[j].name()));
            }
            writer.println(");");
            writer.println("        }");
        }
        writer.print("    }\n\n");
    }

    /**
     * Writes the {@code $UPCALL_IMPL} class containing the function descriptors and the method
     * handles (looked up on {@code $UPCALL}) for all vtable slots.
     *
     * @param methodNames Java names of all methods (including inherited ones), indexed by vtable slot
     */
    private void writeUpcallImplementation(String[] methodNames) {
        int numMethods = methodNames.length;
        Method[] methods = getAllMethods(currentInterface());

        writer.printf("""
                    class $UPCALL_IMPL {
                        private static final FunctionDescriptor[] DESCRIPTORS = createDescriptors();
                        private static final MethodHandle[] HANDLES = createHandles();

                        private static FunctionDescriptor[] createDescriptors() {
                            var descriptors = new FunctionDescriptor[%d];
                """, numMethods);
        for (int i = 0; i < numMethods; i++) {
            writer.printf("            descriptors[%d] = ", i);
            writeFunctionDescriptor(methods[i], "ADDRESS");
            writer.println(";");
        }

        writer.printf("""
                            return descriptors;
                        }

                        private static MethodHandle[] createHandles() {
                            try {
                                var lookup = MethodHandles.lookup();
                                var handles = new MethodHandle[%d];
                """, numMethods);
        for (int i = 0; i < numMethods; i++) {
            writer.printf("                handles[%1$d] = lookup.findVirtual($UPCALL.class, \"%2$s\", DESCRIPTORS[%1$d].toMethodType());\n",
                    i, methodNames[i]);
        }

        // closing lines use the same indentation as the rest of the generated class
        writer.print("""
                                return handles;
                            } catch (ReflectiveOperationException e) {
                                throw new RuntimeException(e);
                            }
                        }
                    }
                """);
    }

    /**
     * Counts the methods inherited from all super interfaces (i.e., the vtable slots
     * preceding this interface's own methods).
     *
     * @param comInterface the COM interface
     * @return number of inherited methods
     */
    private static int getNumSuperMethods(ComInterface comInterface) {
        int methodCount = 0;
        for (var superInterface = comInterface.implementedInterface(); superInterface != null;
             superInterface = superInterface.implementedInterface()) {
            methodCount += superInterface.methods().size();
        }
        return methodCount;
    }

    /**
     * Fills the array with all methods of the interface in vtable order (super interface methods first).
     *
     * @param comInterface the COM interface
     * @param methods      array to fill, starting at index 0
     * @return the number of methods written
     */
    private static int collectMethods(ComInterface comInterface, Method[] methods) {
        var superInterface = comInterface.implementedInterface();
        int index = superInterface != null ? collectMethods(superInterface, methods) : 0;
        for (var method : comInterface.methods()) {
            methods[index] = method;
            index += 1;
        }
        return index;
    }

    /**
     * Returns all methods of the interface in vtable order (super interface methods first).
     *
     * @param comInterface the COM interface
     * @return array of methods indexed by vtable slot
     */
    private static Method[] getAllMethods(ComInterface comInterface) {
        int methodCount = getNumSuperMethods(comInterface) + comInterface.methods().size();
        var methods = new Method[methodCount];
        collectMethods(comInterface, methods);
        return methods;
    }

    /**
     * Returns the Java names of all methods of the interface in vtable order
     * (super interface methods first).
     *
     * @param comInterface the COM interface
     * @return array of Java method names indexed by vtable slot
     */
    private static String[] getAllMethodNames(ComInterface comInterface) {
        int methodCount = getNumSuperMethods(comInterface) + comInterface.methods().size();
        var methodNames = new String[methodCount];
        collectMethodNames(comInterface, methodNames);
        return methodNames;
    }

    /**
     * Assigns Java names to the methods of the interface and its super interfaces.
     * <p>
     * Super interfaces are processed first, so inherited methods keep the name assigned at their
     * own level. For each own method, all methods of the hierarchy with the same name and the
     * same signature key are considered a group. If the group has a single member, the plain
     * name is used; otherwise, each own method gets its 1-based position in the group as a suffix.
     * </p>
     *
     * @param comInterface the COM interface
     * @param methodNames  array to fill (indexed by vtable slot); slots already set are kept
     * @return the number of methods of the interface including inherited ones
     */
    private static int collectMethodNames(ComInterface comInterface, String[] methodNames) {
        Method[] methods = getAllMethods(comInterface);
        int numSuperMethods = 0;
        if (comInterface.implementedInterface() != null) {
            numSuperMethods = collectMethodNames(comInterface.implementedInterface(), methodNames);
        }

        var signatureKeys = new long[methods.length];
        for (int i = 0; i < methods.length; i++) {
            signatureKeys[i] = getSignatureKey(methods[i]);
        }

        int numOwnMethods = comInterface.methods().size();
        for (int i = numSuperMethods; i < numSuperMethods + numOwnMethods; i++) {
            // already named while processing an earlier method of the same group
            if (methodNames[i] != null) {
                continue;
            }

            String methodName = methods[i].name();
            long signatureKey = signatureKeys[i];
            int numSameSignature = 0;
            int currentMethodIndex = -1;
            for (int j = 0; j < signatureKeys.length; j++) {
                if (signatureKeys[j] != signatureKey || !methodName.equals(methods[j].name())) {
                    continue;
                }
                numSameSignature += 1;
                if (j > i) {
                    // later methods of the group are numbered right away
                    methodNames[j] = methodName + numSameSignature;
                } else if (j == i) {
                    currentMethodIndex = numSameSignature;
                }
            }
            methodNames[i] = numSameSignature == 1 ? methodName : methodName + currentMethodIndex;
        }
        return numSuperMethods + numOwnMethods;
    }

    /**
     * Computes a key that is equal for methods whose parameter lists map to the same Java types
     * (as approximated by {@code getJavaTypeKey}).
     * <p>
     * The key starts with the parameter count and appends 4 bits per parameter. For methods with
     * many parameters, early bits are shifted out and different signatures may collide; since the
     * key is only compared for equality together with the method name, a collision merely causes
     * an unnecessary numeric suffix.
     * </p>
     *
     * @param method the method
     * @return the signature key
     */
    private static long getSignatureKey(Method method) {
        long key = method.parameters().length;
        for (Parameter param : method.parameters()) {
            key = key << 4 | getJavaTypeKey(param.type());
        }
        return key;
    }

    /**
     * Maps a type to a small key identifying its Java representation: primitives and enumerations
     * by their (base) primitive kind, type aliases by their aliased type, and everything else to
     * {@code 0} (presumably all represented as {@code MemorySegment} — TODO confirm).
     *
     * @param type the type (must not be {@code null})
     * @return the type key
     */
    private static long getJavaTypeKey(Type type) {
        return switch (type) {
            case Primitive primitive -> getPrimitiveJavaTypeIndex(primitive);
            case EnumType enumType -> getPrimitiveJavaTypeIndex(enumType.baseType());
            case TypeAlias typeAlias -> getJavaTypeKey(typeAlias.aliasedType());
            default -> 0L;
        };
    }

    /**
     * Maps a primitive type to a key in the range 1 to 7; kinds that presumably share the same
     * Java type (e.g., 64-bit and pointer-sized integers) get the same key.
     *
     * @param type the primitive type
     * @return the type key
     * @throws AssertionError if the primitive kind is not supported
     */
    private static long getPrimitiveJavaTypeIndex(Primitive type) {
        return switch (type.kind()) {
            case PrimitiveKind.INT64, PrimitiveKind.UINT64, PrimitiveKind.INT_PTR, PrimitiveKind.UINT_PTR -> 1L;
            case PrimitiveKind.INT32, PrimitiveKind.UINT32 -> 2L;
            case PrimitiveKind.UINT16, PrimitiveKind.INT16, PrimitiveKind.CHAR -> 3L;
            case PrimitiveKind.BYTE, PrimitiveKind.SBYTE -> 4L;
            case PrimitiveKind.SINGLE -> 5L;
            case PrimitiveKind.DOUBLE -> 6L;
            case PrimitiveKind.BOOL -> 7L;
            default -> throw new AssertionError("Unexpected primitive type: " + type.name());
        };
    }

    /**
     * Writes the Javadoc comment of the generated interface, including the documentation URL.
     */
    private void writeComInterfaceComment() {
        writer.printf("/**\n * {@code %s} COM interface\n", currentInterface().name());
        writeDocumentationUrl(this.type);
        writer.println(" */");
    }
}

