/*
 * Decompiled with CFR 0.152.
 */
package net.codecrete.windowsapi.winmd;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Set;
import net.codecrete.windowsapi.winmd.Blob;
import net.codecrete.windowsapi.winmd.LittleEndianDataInputStream;
import net.codecrete.windowsapi.winmd.WinmdException;
import net.codecrete.windowsapi.winmd.tables.ClassLayout;
import net.codecrete.windowsapi.winmd.tables.CodedIndexes;
import net.codecrete.windowsapi.winmd.tables.Constant;
import net.codecrete.windowsapi.winmd.tables.CustomAttribute;
import net.codecrete.windowsapi.winmd.tables.Field;
import net.codecrete.windowsapi.winmd.tables.FieldLayout;
import net.codecrete.windowsapi.winmd.tables.ImplMap;
import net.codecrete.windowsapi.winmd.tables.InterfaceImpl;
import net.codecrete.windowsapi.winmd.tables.MemberRef;
import net.codecrete.windowsapi.winmd.tables.MethodDef;
import net.codecrete.windowsapi.winmd.tables.NestedClass;
import net.codecrete.windowsapi.winmd.tables.Param;
import net.codecrete.windowsapi.winmd.tables.RowKeyTableIterable;
import net.codecrete.windowsapi.winmd.tables.Table;
import net.codecrete.windowsapi.winmd.tables.TableRangeIterable;
import net.codecrete.windowsapi.winmd.tables.TypeDef;
import net.codecrete.windowsapi.winmd.tables.TypeRef;

/**
 * Parser for Windows metadata ({@code .winmd}) files.
 *
 * <p>The file is a PE image containing CLR metadata (ECMA-335, Partition II). All parsing happens
 * in the constructor. Only the string heap, the blob heap and the tables listed in
 * {@code USED_TABLES} are kept in memory. The accessor methods decode individual rows on demand.
 * Row indexes are 1-based; 0 means "no row".
 */
public class MetadataFile {
    // Metadata table numbers (ECMA-335, Partition II, 22) of the tables whose rows are retained.
    private static final int TYPE_REF_TABLE = 1;
    private static final int TYPE_DEF_TABLE = 2;
    private static final int FIELD_TABLE = 4;
    private static final int METHOD_DEF_TABLE = 6;
    private static final int PARAM_TABLE = 8;
    private static final int INTERFACE_IMPL_TABLE = 9;
    private static final int MEMBER_REF_TABLE = 10;
    private static final int CONSTANT_TABLE = 11;
    private static final int CUSTOM_ATTRIBUTE_TABLE = 12;
    private static final int CLASS_LAYOUT_TABLE = 15;
    private static final int FIELD_LAYOUT_TABLE = 16;
    private static final int MODULE_REF_TABLE = 26;
    private static final int IMPL_MAP_TABLE = 28;
    private static final int NESTED_CLASS_TABLE = 41;

    /** Tables whose row data is loaded into memory; all other tables are skipped while reading. */
    private static final Set<Integer> USED_TABLES = Set.of(
            CLASS_LAYOUT_TABLE, CONSTANT_TABLE, CUSTOM_ATTRIBUTE_TABLE, FIELD_TABLE, FIELD_LAYOUT_TABLE,
            IMPL_MAP_TABLE, INTERFACE_IMPL_TABLE, MEMBER_REF_TABLE, METHOD_DEF_TABLE, MODULE_REF_TABLE,
            NESTED_CLASS_TABLE, PARAM_TABLE, TYPE_DEF_TABLE, TYPE_REF_TABLE);

    private final LittleEndianDataInputStream inputStream;
    private String version;
    private MetadataStream[] streams;
    private byte[] blobHeap;
    private byte[] stringHeap;
    // Indexed by table number; an entry is null if the table is absent from the file.
    private final Table[] tables = new Table[64];
    private Table classLayouts;
    private Table constants;
    private Table customAttributes;
    private Table fields;
    private Table fieldLayouts;
    private Table implMaps;
    private Table interfaceImpls;
    private Table memberRefs;
    private Table methodDefs;
    private Table moduleRefs;
    private Table nestedClasses;
    private Table params;
    private Table typeDefs;
    private Table typeRefs;
    private int hasCustomAttributeIndexWidth;
    private int hasConstantIndexWidth;
    private int memberForwardedIndexWidth;

    /**
     * Reads and parses the metadata from the given stream.
     *
     * <p>This class does not close the stream.
     *
     * @param inputStream stream positioned at the start of the {@code .winmd} file
     * @throws WinmdException if the data is invalid or cannot be read
     */
    public MetadataFile(InputStream inputStream) {
        this.inputStream = new LittleEndianDataInputStream(inputStream);
        try {
            read();
        } catch (IOException e) {
            throw new WinmdException("Failed to read .winmd file", e);
        }
    }

    /**
     * Returns the version string from the metadata root header.
     *
     * @return the metadata version string
     */
    public String getVersion() {
        return version;
    }

    /** Returns the metadata stream headers, sorted by file offset. */
    MetadataStream[] getStreams() {
        return streams;
    }

    /**
     * Returns the ClassLayout row of a type.
     *
     * @param parent TypeDef row index of the type
     * @return the class layout, or {@code null} if the type has none
     */
    public ClassLayout getClassLayout(int parent) {
        // Key column Parent sits at byte offset 6 (after PackingSize (2) and ClassSize (4)).
        int index = findRow(classLayouts, parent, simpleIndexWidth(TYPE_DEF_TABLE), 6);
        if (index == 0) {
            return null;
        }
        int[] values = new int[3];
        classLayouts.getRow(index, values);
        return new ClassLayout(values[0], values[1], values[2]);
    }

    /**
     * Returns the Constant row belonging to the given owner.
     *
     * @param parent HasConstant coded index of the owner
     * @return the constant
     * @throws WinmdException if there is no constant for {@code parent}
     */
    public Constant getConstant(int parent) {
        // Key column Parent sits at byte offset 2 (after the 1-byte Type plus 1 padding byte).
        int index = findRow(constants, parent, hasConstantIndexWidth, 2);
        if (index == 0) {
            // Fail explicitly instead of reading the non-existent row 0 when assertions are off.
            throw new WinmdException("Invalid data (no constant for parent " + parent + ")");
        }
        int[] values = new int[3];
        constants.getRow(index, values);
        return new Constant(values[0], values[1], values[2]);
    }

    /**
     * Returns the custom attributes attached to an element.
     *
     * @param parent HasCustomAttribute coded index of the element
     * @return the matching CustomAttribute rows (decoded lazily)
     */
    public Iterable<CustomAttribute> getCustomAttributes(int parent) {
        return new RowKeyTableIterable<CustomAttribute>(customAttributes, parent, hasCustomAttributeIndexWidth, index -> {
            int[] values = new int[3];
            customAttributes.getRow((int) index, values);
            return new CustomAttribute(values[0], values[1], values[2]);
        });
    }

    /**
     * Returns the fields of a type.
     *
     * @param typeDefIndex TypeDef row index of the type
     * @return the contiguous range of Field rows owned by the type (possibly empty)
     */
    public Iterable<Field> getFields(int typeDefIndex) {
        // A type owns the Field rows from its FieldList (column 4) up to the next type's
        // FieldList, or up to the end of the Field table for the last type.
        int firstField = typeDefs.getValue(typeDefIndex, 4);
        int lastField = typeDefIndex < typeDefs.numRows()
                ? typeDefs.getValue(typeDefIndex + 1, 4) - 1
                : fields.numRows();
        assert firstField <= lastField + 1;
        return new TableRangeIterable<Field>(firstField, lastField, index -> {
            int[] values = new int[3];
            fields.getRow((int) index, values);
            return new Field((int) index, values[0], values[1], values[2]);
        });
    }

    /**
     * Returns the FieldLayout row of a field.
     *
     * @param field Field row index
     * @return the field layout, or {@code null} if the field has none
     */
    public FieldLayout getFieldLayout(int field) {
        // Key column Field sits at byte offset 4 (after the 4-byte Offset).
        int index = findRow(fieldLayouts, field, simpleIndexWidth(FIELD_TABLE), 4);
        if (index == 0) {
            return null;
        }
        int[] values = new int[2];
        fieldLayouts.getRow(index, values);
        return new FieldLayout(values[0], values[1]);
    }

    /**
     * Returns the ImplMap (P/Invoke) row of a member.
     *
     * @param memberForwarded MemberForwarded coded index of the member
     * @return the implementation map, or {@code null} if there is none
     */
    public ImplMap getImplMap(int memberForwarded) {
        // Key column MemberForwarded sits at byte offset 2 (after the 2-byte MappingFlags).
        int index = findRow(implMaps, memberForwarded, memberForwardedIndexWidth, 2);
        if (index == 0) {
            return null;
        }
        int[] values = new int[4];
        implMaps.getRow(index, values);
        return new ImplMap(values[0], values[1], values[2], values[3]);
    }

    /**
     * Returns the interfaces implemented by a type.
     *
     * @param classIndex TypeDef row index of the implementing type
     * @return the matching InterfaceImpl rows (decoded lazily)
     */
    public Iterable<InterfaceImpl> getInterfaceImpl(int classIndex) {
        return new RowKeyTableIterable<InterfaceImpl>(interfaceImpls, classIndex, simpleIndexWidth(TYPE_DEF_TABLE), index -> {
            int[] values = new int[2];
            interfaceImpls.getRow((int) index, values);
            return new InterfaceImpl(values[0], values[1]);
        });
    }

    /**
     * Returns a MemberRef row.
     *
     * @param index MemberRef row index
     * @return the decoded row
     */
    public MemberRef getMemberRef(int index) {
        int[] values = new int[3];
        memberRefs.getRow(index, values);
        return new MemberRef(values[0], values[1], values[2]);
    }

    /**
     * Returns a MethodDef row.
     *
     * @param index MethodDef row index
     * @return the decoded row
     */
    public MethodDef getMethodDef(int index) {
        int[] values = new int[6];
        methodDefs.getRow(index, values);
        return new MethodDef(index, values[0], values[1], values[2], values[3], values[4], values[5]);
    }

    /**
     * Returns the methods of a type.
     *
     * @param typeDefIndex TypeDef row index of the type
     * @return the contiguous range of MethodDef rows owned by the type (possibly empty)
     */
    public Iterable<MethodDef> getMethodDefs(int typeDefIndex) {
        // A type owns the MethodDef rows from its MethodList (column 5) up to the next type's
        // MethodList, or up to the end of the MethodDef table for the last type.
        int firstMethod = typeDefs.getValue(typeDefIndex, 5);
        int lastMethod = typeDefIndex < typeDefs.numRows()
                ? typeDefs.getValue(typeDefIndex + 1, 5) - 1
                : methodDefs.numRows();
        assert firstMethod <= lastMethod + 1;
        return new TableRangeIterable<MethodDef>(firstMethod, lastMethod, this::getMethodDef);
    }

    /**
     * Returns the Name column of a ModuleRef row.
     *
     * @param moduleRef ModuleRef row index
     * @return a string heap index; resolve it with {@link #getString(int)}
     */
    public int getModuleRefName(int moduleRef) {
        int[] values = new int[1];
        moduleRefs.getRow(moduleRef, values);
        return values[0];
    }

    /**
     * Returns the NestedClass row of a type.
     *
     * @param nestedClass TypeDef row index of the (potentially) nested type
     * @return the nesting information, or {@code null} if the type is not nested
     */
    public NestedClass getNestedClass(int nestedClass) {
        // Key column NestedClass is the first column (byte offset 0).
        int index = findRow(nestedClasses, nestedClass, simpleIndexWidth(TYPE_DEF_TABLE), 0);
        if (index == 0) {
            return null;
        }
        int[] values = new int[2];
        nestedClasses.getRow(index, values);
        return new NestedClass(values[0], values[1]);
    }

    /**
     * Returns the parameters of a method.
     *
     * @param methodDefIndex MethodDef row index of the method
     * @return the contiguous range of Param rows owned by the method (possibly empty)
     */
    public Iterable<Param> getParameters(int methodDefIndex) {
        // A method owns the Param rows from its ParamList (column 5) up to the next method's
        // ParamList, or up to the end of the Param table for the last method.
        int firstParam = methodDefs.getValue(methodDefIndex, 5);
        int lastParam = methodDefIndex < methodDefs.numRows()
                ? methodDefs.getValue(methodDefIndex + 1, 5) - 1
                : params.numRows();
        assert firstParam <= lastParam + 1;
        return new TableRangeIterable<Param>(firstParam, lastParam, index -> {
            int[] values = new int[3];
            params.getRow((int) index, values);
            return new Param((int) index, values[0], values[1], values[2]);
        });
    }

    /**
     * Returns a TypeDef row.
     *
     * @param typeDefIndex TypeDef row index
     * @return the decoded row
     */
    public TypeDef getTypeDef(int typeDefIndex) {
        int[] values = new int[6];
        typeDefs.getRow(typeDefIndex, values);
        return new TypeDef(values[0], values[1], values[2], values[3], values[4], values[5]);
    }

    /**
     * Returns all TypeDef rows in table order.
     *
     * @return rows 1 to {@link #getTypeDefinitionCount()} (decoded lazily)
     */
    public Iterable<TypeDef> getTypeDefs() {
        return new TableRangeIterable<TypeDef>(1, typeDefs.numRows(), this::getTypeDef);
    }

    /**
     * Returns the number of rows in the TypeDef table.
     *
     * @return the type definition count
     */
    public int getTypeDefinitionCount() {
        return typeDefs.numRows();
    }

    /**
     * Returns a TypeRef row.
     *
     * @param index TypeRef row index
     * @return the decoded row
     */
    public TypeRef getTypeRef(int index) {
        int[] values = new int[3];
        typeRefs.getRow(index, values);
        return new TypeRef(values[0], values[1], values[2]);
    }

    /**
     * Returns the NUL-terminated UTF-8 string starting at the given string heap offset.
     *
     * @param index offset into the #Strings heap
     * @return the string, or {@code null} if {@code index} is 0
     * @throws WinmdException if {@code index} lies outside the string heap
     */
    public String getString(int index) {
        if (index == 0) {
            return null;
        }
        if (index < 0 || index >= stringHeap.length) {
            throw new WinmdException("Invalid data (string heap index out of range)");
        }
        int end = index;
        // Bounded scan: a missing terminator at the end of the heap must not overrun the array.
        while (end < stringHeap.length && stringHeap[end] != 0) {
            ++end;
        }
        return new String(stringHeap, index, end - index, StandardCharsets.UTF_8);
    }

    /**
     * Returns the blob starting at the given blob heap offset.
     *
     * <p>Each blob is prefixed with its length in the compressed form of ECMA-335 II.24.2.4:
     * 1 byte ({@code 0xxxxxxx}), 2 bytes ({@code 10xxxxxx ...}) or 4 bytes ({@code 110xxxxx ...}).
     *
     * @param index offset into the #Blob heap
     * @return a view of the blob's content (excluding the length prefix)
     * @throws WinmdException if the length prefix is malformed
     */
    public Blob getBlob(int index) {
        int length;
        int b1 = blobHeap[index] & 0xFF;
        if ((b1 & 0x80) == 0) {
            length = b1;
            index += 1;
        } else if ((b1 & 0xC0) == 0x80) {
            length = ((b1 & 0x3F) << 8) + (blobHeap[index + 1] & 0xFF);
            index += 2;
        } else if ((b1 & 0xE0) == 0xC0) {
            length = ((b1 & 0x1F) << 24)
                    + ((blobHeap[index + 1] & 0xFF) << 16)
                    + ((blobHeap[index + 2] & 0xFF) << 8)
                    + (blobHeap[index + 3] & 0xFF);
            index += 4;
        } else {
            throw new WinmdException("Invalid data in blob");
        }
        return new Blob(blobHeap, index, length);
    }

    private void read() throws IOException {
        readPEHeaders();
        readMetadataHeader();
        readStreams();
    }

    /**
     * Walks the DOS, COFF and PE32 optional headers and the section table, then positions the
     * input at the start of the CLR metadata root.
     */
    private void readPEHeaders() throws IOException {
        // DOS header: starts with "MZ"; the file offset of the PE signature is stored at 0x3C.
        byte[] magicBytes = new byte[2];
        inputStream.readFully(magicBytes);
        if (magicBytes[0] != 'M' || magicBytes[1] != 'Z') {
            throw new WinmdException("Invalid data (expected magic bytes \"MZ\")");
        }
        inputStream.skipTo(0x3C);
        int signatureOffset = inputStream.readInt();
        inputStream.skipTo(signatureOffset);

        magicBytes = new byte[4];
        inputStream.readFully(magicBytes);
        if (magicBytes[0] != 'P' || magicBytes[1] != 'E' || magicBytes[2] != 0 || magicBytes[3] != 0) {
            throw new WinmdException("Invalid data (expected magic bytes \"PE\\0\\0\")");
        }

        // COFF file header (20 bytes).
        inputStream.readUnsignedShort(); // Machine
        int numSections = inputStream.readUnsignedShort();
        inputStream.readInt(); // TimeDateStamp
        inputStream.readInt(); // PointerToSymbolTable
        inputStream.readInt(); // NumberOfSymbols
        int optionalHeaderSize = inputStream.readUnsignedShort();
        inputStream.readUnsignedShort(); // Characteristics
        if (optionalHeaderSize < 2) {
            throw new WinmdException("Invalid data (expected optional header)");
        }

        // Optional header: the 4-byte PE signature plus the 20-byte COFF header precede it.
        int optionalHeaderOffset = signatureOffset + 4 + 20;
        int magicNumber = inputStream.readUnsignedShort();
        if (magicNumber != 0x10B) {
            throw new WinmdException("Invalid data (expected magic number 0x10b)");
        }
        // PE32 data directories start at offset 96 (8 bytes each); the CLR runtime header is
        // directory 14, so the optional header must hold at least 15 directories.
        int dataDirectoriesOffset = optionalHeaderOffset + 96;
        int clrDirectoryIndex = 14;
        if (optionalHeaderSize < 96 + (clrDirectoryIndex + 1) * 8) {
            throw new WinmdException("Invalid data (optional header too small for CLR runtime header)");
        }
        inputStream.skipTo(dataDirectoriesOffset + clrDirectoryIndex * 8);
        int clrRuntimeHeaderAddress = inputStream.readInt();

        // The section table follows the optional header, whose size is given by the COFF header
        // (not necessarily the 224 bytes of a PE32 header with 16 data directories).
        inputStream.skipTo(optionalHeaderOffset + optionalHeaderSize);
        Section[] sections = new Section[numSections];
        for (int i = 0; i < numSections; ++i) {
            inputStream.skipNBytes(8); // Name
            int virtualSize = inputStream.readInt();
            int virtualAddress = inputStream.readInt();
            inputStream.skipNBytes(4); // SizeOfRawData
            int pointerToRawData = inputStream.readInt();
            inputStream.skipNBytes(16); // relocation/line-number pointers and counts, Characteristics
            sections[i] = new Section(virtualSize, virtualAddress, pointerToRawData);
        }

        // CLR runtime header: cb (must be 72), runtime version (2 + 2), then the metadata RVA.
        inputStream.skipTo(getOffset(sections, clrRuntimeHeaderAddress));
        int size = inputStream.readInt();
        if (size != 72) {
            throw new WinmdException("Invalid data (unexpected size in CLR runtime header)");
        }
        inputStream.skipNBytes(4);
        int metaDataAddress = inputStream.readInt();
        inputStream.skipTo(getOffset(sections, metaDataAddress));
    }

    /** Reads the metadata root header (ECMA-335 II.24.2.1) and its stream headers. */
    private void readMetadataHeader() throws IOException {
        long metadataRootOffset = inputStream.getOffset();
        int magicBytes = inputStream.readInt();
        if (magicBytes != 0x424A5342) { // "BSJB" in little-endian order
            throw new WinmdException("Invalid data (invalid magic bytes in metadata header)");
        }
        inputStream.skipNBytes(8); // MajorVersion, MinorVersion, Reserved
        int versionLength = inputStream.readInt();
        byte[] versionBytes = new byte[versionLength];
        inputStream.readFully(versionBytes);
        version = createString(versionBytes);
        inputStream.skipNBytes(2); // Flags
        int numStreams = inputStream.readUnsignedShort();
        streams = new MetadataStream[numStreams];
        for (int i = 0; i < numStreams; ++i) {
            // Stream offsets are relative to the metadata root.
            int offset = inputStream.readInt() + (int) metadataRootOffset;
            int size = inputStream.readInt();
            String name = readUtf8String();
            streams[i] = new MetadataStream(offset, size, name);
        }
        // Process streams in file order, presumably because the input can only move forward.
        Arrays.sort(streams, Comparator.comparingInt(MetadataStream::offset));
    }

    /** Loads the tables stream and the string and blob heaps; other streams are ignored. */
    private void readStreams() throws IOException {
        for (MetadataStream stream : streams) {
            inputStream.skipTo(stream.offset());
            switch (stream.name()) {
                case "#~" -> {
                    readTablesHeader();
                    readTables();
                }
                case "#Strings" -> stringHeap = readBytes(stream.size());
                case "#Blob" -> blobHeap = readBytes(stream.size());
                default -> {
                    // Streams such as #US and #GUID are not needed.
                }
            }
        }
    }

    /**
     * Reads the header of the #~ stream (ECMA-335 II.24.2.6): heap index sizes, the set of present
     * tables and their row counts. It then derives each table's column widths.
     */
    private void readTablesHeader() throws IOException {
        int headerOffset = (int) inputStream.getOffset();
        inputStream.skipTo(headerOffset + 6);
        byte heapSizes = inputStream.readByte();
        int stringIndexWidth = (heapSizes & 1) != 0 ? 4 : 2;
        int guidIndexWidth = (heapSizes & 2) != 0 ? 4 : 2;
        int blobIndexWidth = (heapSizes & 4) != 0 ? 4 : 2;
        inputStream.skipTo(headerOffset + 8);
        long availableTables = inputStream.readLong();
        // Skip the 8-byte Sorted mask; the row counts of the present tables follow.
        inputStream.skipTo(headerOffset + 24);
        for (int i = 0; i < 64; ++i) {
            if ((availableTables & (1L << i)) != 0L) {
                tables[i] = new Table(inputStream.readInt());
            }
        }
        int typeDefOrRefIndexWidth = codedIndexWidth(CodedIndexes.TYPE_DEF_OR_REF_TABLES);
        hasConstantIndexWidth = codedIndexWidth(CodedIndexes.HAS_CONSTANT_TABLES);
        hasCustomAttributeIndexWidth = codedIndexWidth(CodedIndexes.HAS_CUSTOM_ATTRIBUTE_TABLES);
        int hasFieldMarshalIndexWidth = codedIndexWidth(CodedIndexes.HAS_FIELD_MARSHAL_TABLES);
        int hasDeclSecurityIndexWidth = codedIndexWidth(CodedIndexes.HAS_DECL_SECURITY_TABLES);
        int memberRefParentIndexWidth = codedIndexWidth(CodedIndexes.MEMBER_REF_PARENT_TABLES);
        int hasSemanticsIndexWidth = codedIndexWidth(CodedIndexes.HAS_SEMANTICS_TABLES);
        int methodDefOrRefIndexWidth = codedIndexWidth(CodedIndexes.METHOD_DEF_OR_REF_TABLES);
        memberForwardedIndexWidth = codedIndexWidth(CodedIndexes.MEMBER_FORWARDED_TABLES);
        int implementationIndexWidth = codedIndexWidth(CodedIndexes.IMPLEMENTATION_TABLES);
        int customAttributeTypeIndexWidth = codedIndexWidth(CodedIndexes.CUSTOM_ATTRIBUTE_TYPE_TABLES);
        int resolutionScopeIndexWidth = codedIndexWidth(CodedIndexes.RESOLUTION_SCOPE_TABLES);
        int typeOrMethodDefIndexWidth = codedIndexWidth(CodedIndexes.TYPE_OR_METHOD_DEF_TABLES);
        setColumnWidths(32, 4, 8, 4, blobIndexWidth, stringIndexWidth, stringIndexWidth); // Assembly
        setColumnWidths(34, 4, 4, 4); // AssemblyOS
        setColumnWidths(33, 4); // AssemblyProcessor
        setColumnWidths(35, 8, 4, blobIndexWidth, stringIndexWidth, stringIndexWidth, blobIndexWidth); // AssemblyRef
        setColumnWidths(37, 4, 4, 4, simpleIndexWidth(35)); // AssemblyRefOS
        setColumnWidths(36, 4, simpleIndexWidth(35)); // AssemblyRefProcessor
        setColumnWidths(15, 2, 4, simpleIndexWidth(2)); // ClassLayout
        setColumnWidths(11, 2, hasConstantIndexWidth, blobIndexWidth); // Constant
        setColumnWidths(12, hasCustomAttributeIndexWidth, customAttributeTypeIndexWidth, blobIndexWidth); // CustomAttribute
        setColumnWidths(14, 2, hasDeclSecurityIndexWidth, blobIndexWidth); // DeclSecurity
        setColumnWidths(20, 2, stringIndexWidth, typeDefOrRefIndexWidth); // Event
        setColumnWidths(18, simpleIndexWidth(2), simpleIndexWidth(20)); // EventMap
        setColumnWidths(39, 4, 4, stringIndexWidth, stringIndexWidth, implementationIndexWidth); // ExportedType
        setColumnWidths(4, 2, stringIndexWidth, blobIndexWidth); // Field
        setColumnWidths(16, 4, simpleIndexWidth(4)); // FieldLayout
        setColumnWidths(13, hasFieldMarshalIndexWidth, blobIndexWidth); // FieldMarshal
        setColumnWidths(29, 4, simpleIndexWidth(4)); // FieldRVA
        setColumnWidths(38, 4, stringIndexWidth, blobIndexWidth); // File
        setColumnWidths(42, 2, 2, typeOrMethodDefIndexWidth, stringIndexWidth); // GenericParam
        setColumnWidths(44, simpleIndexWidth(42), typeDefOrRefIndexWidth); // GenericParamConstraint
        setColumnWidths(28, 2, memberForwardedIndexWidth, stringIndexWidth, simpleIndexWidth(26)); // ImplMap
        setColumnWidths(9, simpleIndexWidth(2), typeDefOrRefIndexWidth); // InterfaceImpl
        setColumnWidths(40, 4, 4, stringIndexWidth, implementationIndexWidth); // ManifestResource
        setColumnWidths(10, memberRefParentIndexWidth, stringIndexWidth, blobIndexWidth); // MemberRef
        setColumnWidths(6, 4, 2, 2, stringIndexWidth, blobIndexWidth, simpleIndexWidth(8)); // MethodDef
        setColumnWidths(25, simpleIndexWidth(2), methodDefOrRefIndexWidth, methodDefOrRefIndexWidth); // MethodImpl
        setColumnWidths(24, 2, simpleIndexWidth(6), hasSemanticsIndexWidth); // MethodSemantics
        setColumnWidths(43, methodDefOrRefIndexWidth, blobIndexWidth); // MethodSpec
        setColumnWidths(0, 2, stringIndexWidth, guidIndexWidth, guidIndexWidth, guidIndexWidth); // Module
        setColumnWidths(26, stringIndexWidth); // ModuleRef
        setColumnWidths(41, simpleIndexWidth(2), simpleIndexWidth(2)); // NestedClass
        setColumnWidths(8, 2, 2, stringIndexWidth); // Param
        setColumnWidths(23, 2, stringIndexWidth, blobIndexWidth); // Property
        setColumnWidths(21, simpleIndexWidth(2), simpleIndexWidth(23)); // PropertyMap
        setColumnWidths(17, blobIndexWidth); // StandAloneSig
        setColumnWidths(2, 4, stringIndexWidth, stringIndexWidth, typeDefOrRefIndexWidth, simpleIndexWidth(4), simpleIndexWidth(6)); // TypeDef
        setColumnWidths(1, resolutionScopeIndexWidth, stringIndexWidth, stringIndexWidth); // TypeRef
        setColumnWidths(27, blobIndexWidth); // TypeSpec
    }

    /** Sets the column widths (in bytes) of a table if the table is present. */
    private void setColumnWidths(int tableIndex, int... widths) {
        if (tables[tableIndex] != null) {
            tables[tableIndex].setColumnWidths(widths);
        }
    }

    /** Reads the rows of the tables in {@code USED_TABLES} and skips the rows of all others. */
    private void readTables() throws IOException {
        // Row data follows the row counts directly, in ascending table-number order.
        // NOTE(review): table 63 is never visited; the spec defines tables only up to 0x2C.
        for (int i = 0; i < 63; ++i) {
            Table table = tables[i];
            if (table == null || table.numRows() == 0) {
                continue;
            }
            int tableLength = table.numRows() * table.width();
            if (USED_TABLES.contains(i)) {
                table.setData(readBytes(tableLength));
            } else {
                inputStream.skipNBytes(tableLength);
            }
        }
        classLayouts = tables[CLASS_LAYOUT_TABLE];
        constants = tables[CONSTANT_TABLE];
        customAttributes = tables[CUSTOM_ATTRIBUTE_TABLE];
        fields = tables[FIELD_TABLE];
        fieldLayouts = tables[FIELD_LAYOUT_TABLE];
        implMaps = tables[IMPL_MAP_TABLE];
        interfaceImpls = tables[INTERFACE_IMPL_TABLE];
        memberRefs = tables[MEMBER_REF_TABLE];
        methodDefs = tables[METHOD_DEF_TABLE];
        moduleRefs = tables[MODULE_REF_TABLE];
        nestedClasses = tables[NESTED_CLASS_TABLE];
        params = tables[PARAM_TABLE];
        typeDefs = tables[TYPE_DEF_TABLE];
        typeRefs = tables[TYPE_REF_TABLE];
    }

    /** Reads exactly {@code length} bytes from the input. */
    private byte[] readBytes(int length) throws IOException {
        byte[] data = new byte[length];
        inputStream.readFully(data);
        return data;
    }

    /**
     * Looks up a row by key in a sorted table, tolerating tables absent from the file.
     *
     * @param table the table (may be {@code null})
     * @param key key value to look for
     * @param keyWidth width of the key column in bytes
     * @param keyOffset byte offset of the key column within a row
     * @return the row index, or 0 if not found or if the table is absent
     */
    private static int findRow(Table table, int key, int keyWidth, int keyOffset) {
        return table != null ? table.indexByPrimaryKey(key, keyWidth, keyOffset) : 0;
    }

    /** Returns the width (in bytes) of a simple index into the given table; 2 if the table is absent. */
    private int simpleIndexWidth(int table) {
        return tables[table] != null ? tables[table].indexWidth() : 2;
    }

    /**
     * Returns the width of a coded index over the given tables (ECMA-335 II.24.2.6).
     *
     * <p>A coded index reserves {@code ceil(log2(n))} low bits for the table tag. It is 2 bytes
     * wide if every referenced table has fewer than {@code 2^(16 - tagBits)} rows, otherwise 4.
     */
    private int codedIndexWidth(int... tableIndexes) {
        int numBitsTable = 32 - Integer.numberOfLeadingZeros(tableIndexes.length - 1);
        int max16BitIndex = 1 << (16 - numBitsTable);
        for (int index : tableIndexes) {
            if (tables[index] != null && tables[index].numRows() >= max16BitIndex) {
                return 4;
            }
        }
        return 2;
    }

    /** Reads a NUL-terminated UTF-8 string padded to a multiple of 4 bytes (stream header names). */
    private String readUtf8String() throws IOException {
        int length;
        ByteArrayOutputStream utf8Buffer = new ByteArrayOutputStream();
        byte[] fourBytes = new byte[4];
        do {
            inputStream.readFully(fourBytes);
            length = 0;
            while (length < 4 && fourBytes[length] != 0) {
                ++length;
            }
            utf8Buffer.write(fourBytes, 0, length);
        } while (length == 4);
        return utf8Buffer.toString(StandardCharsets.UTF_8);
    }

    /** Maps a relative virtual address to a file offset using the section table. */
    private static int getOffset(Section[] sections, int virtualAddress) {
        for (Section section : sections) {
            if (virtualAddress >= section.virtualAddress
                    && virtualAddress < section.virtualAddress + section.virtualSize) {
                return virtualAddress - section.virtualAddress + section.pointerToRawData;
            }
        }
        throw new WinmdException("Invalid data (virtual address outside sections)");
    }

    /** Decodes UTF-8 bytes up to the first NUL byte, or up to the end if there is none. */
    private static String createString(byte[] codeUnits) {
        int length = 0;
        while (length < codeUnits.length && codeUnits[length] != 0) {
            ++length;
        }
        return new String(codeUnits, 0, length, StandardCharsets.UTF_8);
    }

    /** Header of a metadata stream: absolute file offset, size in bytes, and name. */
    record MetadataStream(int offset, int size, String name) {
    }

    /** Subset of a PE section header needed to map virtual addresses to file offsets. */
    private record Section(int virtualSize, int virtualAddress, int pointerToRawData) {
    }
}

