//
// Windows API Generator for Java
// Copyright (c) 2025 Manuel Bleichenbacher
// Licensed under MIT License
// https://opensource.org/licenses/MIT
//
package net.codecrete.windowsapi.writer;

import net.codecrete.windowsapi.events.Event;
import net.codecrete.windowsapi.metadata.ComInterface;
import net.codecrete.windowsapi.metadata.Delegate;
import net.codecrete.windowsapi.metadata.EnumType;
import net.codecrete.windowsapi.metadata.Namespace;
import net.codecrete.windowsapi.metadata.Pointer;
import net.codecrete.windowsapi.metadata.Primitive;
import net.codecrete.windowsapi.metadata.PrimitiveKind;
import net.codecrete.windowsapi.metadata.Struct;
import net.codecrete.windowsapi.metadata.Type;
import net.codecrete.windowsapi.metadata.TypeAlias;
import net.codecrete.windowsapi.winmd.LayoutRequirement;

import java.io.PrintWriter;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Set;
import java.util.UUID;

/**
 * Base class for generating Java code.
 * <p>
 * Holds the state of the Java source file currently being written (see {@link #withFile}) and
 * provides helpers to map metadata types to Java types and FFM layout names, and to emit common
 * code fragments (layouts, constant values, GUIDs, comments).
 * </p>
 * <p>
 * Numbers written into the generated code are formatted independently of the JVM's default
 * locale, so the generated source is identical on every machine.
 * </p>
 *
 * @param <T> the metadata type the generated Java files are based on
 */
@SuppressWarnings("java:S1192")
class JavaCodeWriter<T extends Type> {
    /**
     * Shared generation context.
     */
    protected final GenerationContext generationContext;

    /**
     * Primitive metadata type for {@code void} (looked up once from the metadata).
     */
    protected final Primitive voidType;

    /**
     * Print writer for the current Java file (only set while {@link #withFile} runs its action)
     */
    protected PrintWriter writer;
    /**
     * Namespace of the current Java file (only set while {@link #withFile} runs its action)
     */
    protected Namespace namespace;
    /**
     * The metadata type for the current Java file (if any)
     */
    protected T type;
    /**
     * The Java package name of the current Java file
     */
    protected String packageName;
    /**
     * The Java class name of the current Java file
     */
    protected String className;

    /**
     * Counter used to number bit fields within the current Java file; reset by {@link #withFile}.
     */
    private int bitFieldNumber = 0;

    /**
     * Generates a new instance.
     *
     * @param generationContext the code generation context
     */
    JavaCodeWriter(GenerationContext generationContext) {
        this.generationContext = generationContext;
        voidType = generationContext.metadata().getPrimitive(PrimitiveKind.VOID);
    }

    /**
     * Gets the code generation context.
     *
     * @return the generation context
     */
    protected GenerationContext generationContext() {
        return generationContext;
    }

    /**
     * Creates a new Java source file for the specified package and class name and executes the action.
     * <p>
     * While the action is executed, the namespace, type, package name, class name and the print writer
     * are available through instance variables. They are reset when this method returns, whether
     * or not the action succeeded.
     * </p>
     * <p>
     * If the action is successful, the event listener is notified that a new file has been created.
     * </p>
     *
     * @param namespace the namespace
     * @param type      the metadata type that is the basis for the Java file (or {@code null}).
     * @param className the Java class name
     * @param action    the action to execute
     */
    protected void withFile(Namespace namespace, T type, String className, Runnable action) {
        packageName = toJavaPackageName(namespace.name());
        var path = createJavaClassPath(packageName, className);
        try (var w = generationContext.createWriter(path)) {
            writeHeader(w);
            writer = w;
            this.namespace = namespace;
            this.type = type;
            this.className = className;
            action.run();
            // only reached if the action did not throw
            generationContext.notify(new Event.JavaSourceGenerated(path));

        } finally {
            // reset all per-file state so nothing leaks into the next file
            writer = null;
            this.namespace = null;
            this.type = null;
            this.packageName = null;
            this.className = null;
            this.bitFieldNumber = 0;
        }
    }

    /**
     * Writes the header comment that marks the file as generated code.
     *
     * @param out the print writer of the new file
     */
    private static void writeHeader(PrintWriter out) {
        out.print("""
                // Code generated by Windows API Generator.
                // Do not modify manually.
                
                """);
    }

    /**
     * Returns the next bit field number and increases it.
     * <p>
     * The counter starts at 0 for each file, so the first call within a file returns 1.
     * </p>
     *
     * @return the bit field number
     */
    protected int consumeBitFieldNumber() {
        bitFieldNumber += 1;
        return bitFieldNumber;
    }

    /**
     * Creates a valid Java class name for the given type name.
     * <p>
     * Both names are without package / namespace name.
     * </p>
     *
     * @param typeName the type name
     * @return the Java class name
     */
    static String toJavaClassName(String typeName) {
        // "AVIStreamHeader" and "AVISTREAMHEADER" conflict with each other
        // when the associated Java file is created as they only differ in case.
        if (typeName.equals("AVISTREAMHEADER"))
            return "AVISTREAMHEADER_";
        return typeName;
    }

    /**
     * Converts the namespace name to a Java package name.
     * <p>
     * The namespace name is lowercased (using {@link Locale#ROOT}) and prefixed with the
     * configured base package, if any.
     * </p>
     *
     * @param namespace the namespace
     * @return the Java package name
     */
    protected String toJavaPackageName(String namespace) {
        var lowercaseNamespace = namespace.toLowerCase(Locale.ROOT);
        if (generationContext.basePackage().isEmpty())
            return lowercaseNamespace;
        return generationContext.basePackage() + "." + lowercaseNamespace;
    }

    /**
     * Creates a path name for the given Java package and class name.
     * <p>
     * The resulting path name is relative to the output directory and ends with ".java".
     * </p>
     *
     * @param packageName the package name
     * @param className   the class name
     * @return the path name
     */
    private Path createJavaClassPath(String packageName, String className) {
        var pathComponents = packageName.split("\\.");
        var firstComponent = pathComponents[0];
        // shift the remaining package components left by one and reuse the freed last slot
        // for the file name, yielding the "more" arguments for Path.of(first, more...)
        System.arraycopy(pathComponents, 1, pathComponents, 0, pathComponents.length - 1);
        pathComponents[pathComponents.length - 1] = className + ".java";
        return Path.of(firstComponent, pathComponents);
    }

    /**
     * Gets the Java type for the given primitive type.
     *
     * @param type the primitive type
     * @return the name of the Java type
     * @throws AssertionError if the primitive type has no Java equivalent (e.g. {@code void})
     */
    static String getPrimitiveJavaType(Primitive type) {
        return switch (type.kind()) {
            case PrimitiveKind.INT64, PrimitiveKind.UINT64, PrimitiveKind.INT_PTR, PrimitiveKind.UINT_PTR -> "long";
            case PrimitiveKind.INT32, PrimitiveKind.UINT32 -> "int";
            case PrimitiveKind.UINT16, PrimitiveKind.INT16, PrimitiveKind.CHAR -> "short";
            case PrimitiveKind.BYTE, PrimitiveKind.SBYTE -> "byte";
            case PrimitiveKind.SINGLE -> "float";
            case PrimitiveKind.DOUBLE -> "double";
            case PrimitiveKind.BOOL -> "boolean";
            default -> throw new AssertionError("Unexpected primitive type: " + type.name());
        };
    }

    /**
     * Gets a Java expression for the given integer constant value.
     * <p>
     * The value is narrowed to the target type (e.g. {@code 0xFFFF} becomes {@code (short) -1}).
     * </p>
     *
     * @param javaType the resulting Java type ({@code long}, {@code int}, {@code short} or {@code byte})
     * @param value    the constant value (must be a {@link Number})
     * @return the Java expression
     * @throws AssertionError if the value is not a number or the Java type is not an integer type
     */
    static String getJavaIntegerConstant(String javaType, Object value) {
        if (!(value instanceof Number number))
            throw new AssertionError("Unexpected integer constant: " + value);
        return switch (javaType) {
            case "long" -> number.longValue() + "L";
            case "int" -> Integer.toString(number.intValue());
            case "short" -> "(short) " + number.shortValue();
            case "byte" -> "(byte) " + number.byteValue();
            default -> throw new AssertionError("Unexpected value: " + javaType);
        };
    }

    /**
     * Gets the Java type for the given metadata type.
     * <p>
     * Type aliases are resolved. For enumerations, the base integer type
     * is used. For non-primitive types, the result will be "MemorySegment".
     * </p>
     *
     * @param type the metadata type
     * @return the Java type
     */
    static String getJavaType(Type type) {
        return switch (type) {
            case Primitive primitive -> getPrimitiveJavaType(primitive);
            case EnumType enumType -> getPrimitiveJavaType(enumType.baseType());
            case TypeAlias typeAlias -> getJavaType(typeAlias.aliasedType());
            default -> "MemorySegment";
        };
    }

    /**
     * Words that cannot be used as Java identifiers as-is.
     * <p>
     * Contains all reserved keywords (JLS §3.9, including {@code _}), the literals {@code true},
     * {@code false} and {@code null}, and the contextual keyword {@code var}.
     * </p>
     */
    private static final Set<String> JAVA_KEYWORDS = Set.of(
            "_", "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class",
            "const", "continue", "default", "do", "double", "else", "enum", "extends", "false", "final",
            "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
            "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
            "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
            "throws", "transient", "true", "try", "var", "void", "volatile", "while"
    );

    /**
     * Gets a Java-safe identifier name.
     * <p>
     * Prevents clashes with keywords by appending an underscore to names that are reserved in Java.
     * All other names are returned unchanged.
     * </p>
     *
     * @param name the original identifier/name
     * @return the Java-safe identifier
     */
    static String getJavaSafeName(String name) {
        if (JAVA_KEYWORDS.contains(name))
            return name + "_";
        return name;
    }

    /**
     * Writes Java code to initialize a variable with a given address layout.
     *
     * @param addressLayout the address layout
     * @param modifiers     modifiers for the variable (such as "static"), including a trailing space
     * @throws AssertionError if the address layout is neither a struct, pointer-to-address nor
     *                        pointer-to-unknown layout
     */
    void writeAddressLayoutInitialization(AddressLayout addressLayout, String modifiers) {
        writer.printf("    %sAddressLayout %s = %s.",
                modifiers,
                addressLayout.name(),
                addressLayout.aligned() ? "ADDRESS" : "ADDRESS_UNALIGNED"
        );

        if (addressLayout.isForStruct()) {
            // struct pointers target an opaque byte sequence of the struct's size
            writer.printf(Locale.ROOT, "withTargetLayout(MemoryLayout.sequenceLayout(%d, JAVA_BYTE)",
                    addressLayout.structSize());
            if (addressLayout.packageSize() != 1) {
                writer.printf(Locale.ROOT, ".withByteAlignment(%d)", addressLayout.packageSize());
            }
            writer.print(")");
        } else if (addressLayout == AddressLayout.pointerToAddress(true)
                || addressLayout == AddressLayout.pointerToAddress(false)) {
            writer.print("withTargetLayout(ADDRESS)");
        } else if (addressLayout == AddressLayout.pointerToUnknown(true)
                || addressLayout == AddressLayout.pointerToUnknown(false)) {
            writer.print("withTargetLayout(MemoryLayout.sequenceLayout(Long.MAX_VALUE, JAVA_BYTE))");
        } else {
            // fail loudly instead of emitting an incomplete statement such as "ADDRESS.;"
            throw new AssertionError("Unexpected address layout type: " + addressLayout.name());
        }

        writer.println(";");
    }

    /**
     * Gets the layout name for the specified metadata type and alignment.
     * <p>
     * If the layout name refers to a generated layout in another package,
     * the package name is included. If it is a built-in Java FFM layout or a
     * generated layout in the same package, the Java package name is omitted.
     * To determine the package, the type's namespace is compared against the
     * current namespace.
     * </p>
     *
     * @param type             the type
     * @param packageSize      the package size / alignment
     * @param currentNamespace the current namespace
     * @return the layout name
     * @throws AssertionError if no layout can be determined for the type
     */
    String getLayoutName(Type type, int packageSize, Namespace currentNamespace) {
        return switch (type) {
            case Primitive primitive -> getPrimitiveLayoutName(primitive, packageSize);
            case TypeAlias typeAlias -> getLayoutName(typeAlias.aliasedType(), packageSize, currentNamespace);
            case Pointer pointer -> AddressLayout.getAddressLayout(pointer.referencedType(), packageSize >= 8).name();
            case Delegate ignored -> AddressLayout.pointerToAddress(packageSize >= 8).name();
            case ComInterface ignored -> AddressLayout.pointerToAddress(packageSize >= 8).name();
            case EnumType enumType -> getPrimitiveLayoutName(enumType.baseType(), packageSize);
            case Struct struct -> getStructLayoutName(struct, currentNamespace);
            default -> throw new AssertionError("Unexpected type: " + type.name());
        };
    }

    /**
     * Gets the layout name for the specified metadata type, assuming an alignment of 8 bytes.
     * <p>
     * If the layout name refers to a generated layout in another package,
     * the package name is included. If it is a built-in Java FFM layout or a
     * generated layout in the same package, the Java package name is omitted.
     * To determine the package, the type's namespace is compared against the
     * current namespace.
     * </p>
     *
     * @param type             the type
     * @param currentNamespace the current namespace
     * @return the layout name
     */
    String getLayoutName(Type type, Namespace currentNamespace) {
        return getLayoutName(type, 8, currentNamespace);
    }

    /**
     * Gets the layout name for the given struct type.
     * <p>
     * The Java package name is included if the struct is in a namespace different from the current namespace.
     * </p>
     *
     * @param struct           the struct type
     * @param currentNamespace the current namespace
     * @return the layout expression, e.g. {@code POINT.layout()}
     */
    private String getStructLayoutName(Struct struct, Namespace currentNamespace) {
        if (currentNamespace != struct.namespace()) {
            return toJavaPackageName(struct.namespace().name()) + "." + toJavaClassName(struct.name()) + ".layout()";
        } else {
            return toJavaClassName(struct.name()) + ".layout()";
        }
    }

    /**
     * Writes the layout name for the given struct type.
     * <p>
     * If the alignment is lower than the struct's natural alignment, it is assumed that a local
     * layout definition with the name {@code STRUCT$LAYOUT_UNALIGNED} exists.
     * </p>
     * <p>
     * The Java package name is included if the struct is in a namespace different from the current namespace.
     * </p>
     *
     * @param struct           the struct type
     * @param alignment        the alignment (in bytes)
     * @param currentNamespace the current namespace
     */
    void writeStructLayoutName(Struct struct, int alignment, Namespace currentNamespace) {
        if (struct.packageSize() > alignment) {
            writer.printf("%s$LAYOUT_UNALIGNED", struct.name());
        } else if (currentNamespace != struct.namespace()) {
            writer.printf("%s.%s.layout()",
                    toJavaPackageName(struct.namespace().name()),
                    toJavaClassName(struct.name()));
        } else {
            writer.printf("%s.layout()", toJavaClassName(struct.name()));
        }
    }

    /**
     * Gets the layout name for the given primitive type and alignment.
     * <p>
     * The {@code _UNALIGNED} variant is used if the primitive's size exceeds the alignment.
     * </p>
     *
     * @param type      the primitive type
     * @param alignment the alignment (in bytes)
     * @return the layout name
     * @throws AssertionError if the primitive type has no Java FFM layout
     */
    static String getPrimitiveLayoutName(Primitive type, int alignment) {
        var isUnaligned = LayoutRequirement.primitiveSize(type) > alignment;
        return switch (type.kind()) {
            case PrimitiveKind.INT64, PrimitiveKind.UINT64, PrimitiveKind.INT_PTR, PrimitiveKind.UINT_PTR ->
                    isUnaligned ? "JAVA_LONG_UNALIGNED" : "JAVA_LONG";
            case PrimitiveKind.INT32, PrimitiveKind.UINT32 -> isUnaligned ? "JAVA_INT_UNALIGNED" : "JAVA_INT";
            case PrimitiveKind.INT16, PrimitiveKind.UINT16, PrimitiveKind.CHAR ->
                    isUnaligned ? "JAVA_SHORT_UNALIGNED" : "JAVA_SHORT";
            case PrimitiveKind.BYTE, PrimitiveKind.SBYTE -> "JAVA_BYTE";
            case PrimitiveKind.SINGLE -> isUnaligned ? "JAVA_FLOAT_UNALIGNED" : "JAVA_FLOAT";
            case PrimitiveKind.DOUBLE -> isUnaligned ? "JAVA_DOUBLE_UNALIGNED" : "JAVA_DOUBLE";
            case PrimitiveKind.BOOL -> "JAVA_BOOLEAN";
            default -> throw new AssertionError("Unexpected primitive type: " + type.name());
        };
    }

    /**
     * Writes a Java expression for a given value, suitable for the given type.
     *
     * @param type  the type (primitive, type alias or pointer)
     * @param value the value
     * @throws AssertionError if values of the given type are not supported
     */
    void writeValue(Type type, Object value) {
        switch (type) {
            case Primitive primitive -> writePrimitiveValue(primitive, value);
            case TypeAlias typeAlias -> writeValue(typeAlias.aliasedType(), value);
            case Pointer ignored -> writePointerValue(value);
            default -> throw new AssertionError("Unexpected type: " + type.name());
        }
    }

    /**
     * Writes a Java expression to initialize a memory segment with a given address.
     *
     * @param address the address (written followed by {@code L}, i.e. as a long literal)
     */
    void writePointerValue(Object address) {
        writer.printf("MemorySegment.ofAddress(%sL)", address);
    }

    /**
     * Writes a constant value for a primitive type.
     * <p>
     * The value must be of the boxed type matching the primitive kind (e.g. {@code Long} for
     * {@code INT64}, {@code Float} for {@code SINGLE}); otherwise a {@link ClassCastException} is thrown.
     * Floating-point values are written in their shortest round-trippable, locale-independent form.
     * </p>
     *
     * @param primitive the primitive type
     * @param value     the constant value
     * @throws AssertionError if constants of the given primitive kind are not supported
     */
    void writePrimitiveValue(Primitive primitive, Object value) {
        switch (primitive.kind()) {
            case PrimitiveKind.INT64, PrimitiveKind.UINT64 -> writer.print((Long) value + "L");
            case PrimitiveKind.INT_PTR, PrimitiveKind.UINT_PTR -> writer.print(((Number) value).longValue() + "L");
            case PrimitiveKind.INT32, PrimitiveKind.UINT32 -> writer.print(((Integer) value).intValue());
            case PrimitiveKind.INT16, PrimitiveKind.UINT16 -> writer.print(((Short) value).intValue());
            case PrimitiveKind.BYTE -> writer.print(((Byte) value).intValue());
            case PrimitiveKind.SINGLE -> writer.print(toJavaFloatLiteral((Float) value));
            case PrimitiveKind.DOUBLE -> writer.print(toJavaDoubleLiteral((Double) value));
            default -> throw new AssertionError("Unexpected type: " + primitive.name());
        }
    }

    /**
     * Converts a float value into a Java expression that evaluates to exactly that value.
     * <p>
     * {@link Float#toString(float)} is used (instead of {@code %f}, which is limited to six decimal
     * places and locale-dependent). NaN and infinities, which have no literal form, are mapped to
     * the corresponding {@link Float} constants.
     * </p>
     *
     * @param value the value
     * @return the Java expression
     */
    static String toJavaFloatLiteral(float value) {
        if (Float.isNaN(value))
            return "Float.NaN";
        if (Float.isInfinite(value))
            return value > 0 ? "Float.POSITIVE_INFINITY" : "Float.NEGATIVE_INFINITY";
        return Float.toString(value) + "f";
    }

    /**
     * Converts a double value into a Java expression that evaluates to exactly that value.
     * <p>
     * NaN and infinities, which have no literal form, are mapped to the corresponding
     * {@link Double} constants.
     * </p>
     *
     * @param value the value
     * @return the Java expression
     */
    static String toJavaDoubleLiteral(double value) {
        if (Double.isNaN(value))
            return "Double.NaN";
        if (Double.isInfinite(value))
            return value > 0 ? "Double.POSITIVE_INFINITY" : "Double.NEGATIVE_INFINITY";
        return Double.toString(value);
    }

    /**
     * Pre-built run of spaces used for the common (small) indents.
     */
    private static final String SPACES = "        ".repeat(10);

    /**
     * Gets the indent (multiple spaces) for the specified indenting.
     *
     * @param indenting the indenting (number of spaces, must not be negative)
     * @return the indent
     */
    static String getIndent(int indenting) {
        return indenting <= SPACES.length() ? SPACES.substring(0, indenting) : " ".repeat(indenting);
    }

    /**
     * Writes the indent (multiple spaces) for the specified indenting to the current file.
     *
     * @param indenting the indenting (number of spaces, must not be negative)
     */
    void writeIndent(int indenting) {
        writer.write(getIndent(indenting));
    }

    /**
     * Writes a Java method called "createGuid" for creating a memory segment with a GUID.
     * <p>
     * The generated method allocates from {@code ARENA}, which is assumed to be declared
     * by the generated class.
     * </p>
     *
     * @param indenting the indenting (number of spaces)
     */
    void writeCreateGuidMethod(int indenting) {
        writer.printf("""
                %1$sprivate static MemorySegment createGuid(long v1, long v2) {
                %1$s    var seg = ARENA.allocate(16, 8);
                %1$s    seg.set(ValueLayout.JAVA_LONG, 0, v1);
                %1$s    seg.set(ValueLayout.JAVA_LONG, 8, v2);
                %1$s    return seg;
                %1$s}
                
                """, getIndent(indenting));
    }

    /**
     * Writes a Java variable initialization for a memory segment with a GUID.
     * <p>
     * The variable is named {@code <name>$SEG} and initialized by calling {@code createGuid}
     * (see {@link #writeCreateGuidMethod(int)}). The two long arguments are arranged so that
     * storing them with {@code ValueLayout.JAVA_LONG} (native byte order) on a little-endian
     * platform yields the Windows GUID memory layout.
     * </p>
     *
     * @param name      the variable name
     * @param uuid      the GUID value
     * @param indenting the indenting (number of spaces)
     */
    void writeGuidConstantMemorySegment(String name, UUID uuid, int indenting) {
        writer.printf(Locale.ROOT, """
                        %2$sprivate static final MemorySegment %1$s$SEG = createGuid(%3$dL, %4$dL);
                        
                        """, name, getIndent(indenting),
                reorderMostSignificantBits(uuid.getMostSignificantBits()),
                // Data4 is a plain byte array: reverse so its first byte ends up at the lowest address
                Long.reverseBytes(uuid.getLeastSignificantBits()));
    }

    /**
     * Rearranges the most significant UUID bits (Data1 | Data2 | Data3, from high to low)
     * so that Data1 occupies the low 32 bits, Data2 the next 16 bits and Data3 the top 16 bits.
     *
     * @param bits the most significant bits of the UUID
     * @return the rearranged bits
     */
    private static long reorderMostSignificantBits(long bits) {
        var data1 = (bits >> 32) & 0x00000000ffffffffL;
        var data2 = (bits << 16) & 0x0000ffff00000000L;
        var data3 = (bits << 48) & 0xffff000000000000L;
        return data1 | data2 | data3;
    }

    /**
     * Writes a simple formatted comment.
     * <p>
     * The format is applied using {@link Locale#ROOT} so numbers are written locale-independently.
     * </p>
     *
     * @param format the format (printf-style)
     * @param args   the arguments for the format
     */
    void writeComment(String format, Object... args) {
        writer.println("    /**");
        writer.print("     * ");
        writer.printf(Locale.ROOT, format, args);
        writer.println();
        writer.println("     */");
    }

    /**
     * Writes a comment with notes.
     * <p>
     * Each note is added in a separate paragraph. If a note is {@code null}, it is ignored.
     * </p>
     *
     * @param comment the comment
     * @param notes   the notes
     */
    protected void writeCommentWithNotes(String comment, String... notes) {
        writer.println("    /**");
        writer.print("     * ");
        writer.println(comment);
        for (String note : notes) {
            if (note != null) {
                writer.println("     * <p>");
                writer.print("     * ");
                writer.println(note);
                writer.println("     * </p>");
            }
        }
        writer.println("     */");
    }

    /**
     * Writes the type's documentation URL to the top level comment (if available).
     * <p>
     * The output is a blank comment line followed by a {@code @see} tag; the caller is
     * expected to have opened the comment and to close it afterwards.
     * </p>
     *
     * @param type the type
     */
    protected void writeDocumentationUrl(Type type) {
        var documentationUrl = type.documentationUrl();
        if (documentationUrl != null) {
            writer.printf("""
                             *
                             * @see <a href="%1$s">%2$s (Microsoft)</a>
                            """,
                    documentationUrl,
                    type.nativeName()
            );
        }
    }
}
