/*
 * Decompiled with CFR 0.152.
 */
package net.raphimc.javadowngrader.transformer.j8;

import java.util.Arrays;
import net.raphimc.javadowngrader.util.ASMUtil;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.InvokeDynamicInsnNode;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.objectweb.asm.tree.VarInsnNode;

public class StringConcatFactoryReplacer {
    private static final char STACK_ARG_CONSTANT = '\u0001';
    private static final char BSM_ARG_CONSTANT = '\u0002';

    public static void replace(ClassNode classNode) {
        for (MethodNode methodNode : classNode.methods) {
            for (AbstractInsnNode instruction : methodNode.instructions.toArray()) {
                int i;
                if (instruction.getOpcode() != 186) continue;
                InvokeDynamicInsnNode insn = (InvokeDynamicInsnNode)instruction;
                if (!insn.bsm.getOwner().equals("java/lang/invoke/StringConcatFactory") || !insn.bsm.getName().equals("makeConcatWithConstants")) continue;
                String pattern = (String)insn.bsmArgs[0];
                Type[] stackArgs = Type.getArgumentTypes((String)insn.desc);
                Object[] bsmArgs = Arrays.copyOfRange(insn.bsmArgs, 1, insn.bsmArgs.length);
                int stackArgsCount = StringConcatFactoryReplacer.count(pattern, '\u0001');
                int bsmArgsCount = StringConcatFactoryReplacer.count(pattern, '\u0002');
                if (stackArgs.length != stackArgsCount) {
                    throw new IllegalStateException("Stack args count does not match");
                }
                if (bsmArgs.length != bsmArgsCount) {
                    throw new IllegalStateException("BSM args count does not match");
                }
                int freeVarIndex = ASMUtil.getFreeVarIndex(methodNode);
                int[] stackIndices = new int[stackArgsCount];
                for (i = 0; i < stackArgs.length; ++i) {
                    stackIndices[i] = freeVarIndex;
                    freeVarIndex += stackArgs[i].getSize();
                }
                for (i = stackIndices.length - 1; i >= 0; --i) {
                    methodNode.instructions.insertBefore((AbstractInsnNode)insn, (AbstractInsnNode)new VarInsnNode(stackArgs[i].getOpcode(54), stackIndices[i]));
                }
                InsnList converted = StringConcatFactoryReplacer.convertStringConcatFactory(pattern, stackArgs, stackIndices, bsmArgs);
                methodNode.instructions.insertBefore((AbstractInsnNode)insn, converted);
                methodNode.instructions.remove((AbstractInsnNode)insn);
            }
        }
    }

    private static InsnList convertStringConcatFactory(String pattern, Type[] stackArgs, int[] stackIndices, Object[] bsmArgs) {
        InsnList insns = new InsnList();
        char[] chars = pattern.toCharArray();
        int stackArgsIndex = 0;
        int bsmArgsIndex = 0;
        StringBuilder partBuilder = new StringBuilder();
        insns.add((AbstractInsnNode)new TypeInsnNode(187, "java/lang/StringBuilder"));
        insns.add((AbstractInsnNode)new InsnNode(89));
        insns.add((AbstractInsnNode)new MethodInsnNode(183, "java/lang/StringBuilder", "<init>", "()V"));
        for (char c : chars) {
            if (c == '\u0001') {
                if (partBuilder.length() != 0) {
                    insns.add((AbstractInsnNode)new LdcInsnNode((Object)partBuilder.toString()));
                    insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;"));
                    partBuilder = new StringBuilder();
                }
                Type stackArg = stackArgs[stackArgsIndex++];
                int stackIndex = stackIndices[stackArgsIndex - 1];
                if (stackArg.getSort() == 10) {
                    insns.add((AbstractInsnNode)new VarInsnNode(25, stackIndex));
                    insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "append", "(Ljava/lang/Object;)Ljava/lang/StringBuilder;"));
                    continue;
                }
                if (stackArg.getSort() == 9) {
                    insns.add((AbstractInsnNode)new VarInsnNode(25, stackIndex));
                    insns.add((AbstractInsnNode)new MethodInsnNode(184, "java/util/Arrays", "toString", "([Ljava/lang/Object;)Ljava/lang/String;"));
                    insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;"));
                    continue;
                }
                insns.add((AbstractInsnNode)new VarInsnNode(stackArg.getOpcode(21), stackIndex));
                insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "append", "(" + stackArg.getDescriptor() + ")Ljava/lang/StringBuilder;"));
                continue;
            }
            if (c == '\u0002') {
                insns.add((AbstractInsnNode)new LdcInsnNode(bsmArgs[bsmArgsIndex++]));
                insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "append", "(Ljava/lang/Object;)Ljava/lang/StringBuilder;"));
                continue;
            }
            partBuilder.append(c);
        }
        if (partBuilder.length() != 0) {
            insns.add((AbstractInsnNode)new LdcInsnNode((Object)partBuilder.toString()));
            insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;"));
        }
        insns.add((AbstractInsnNode)new MethodInsnNode(182, "java/lang/StringBuilder", "toString", "()Ljava/lang/String;"));
        return insns;
    }

    private static int count(String s, char search) {
        char[] chars = s.toCharArray();
        int count = 0;
        for (char c : chars) {
            if (c != search) continue;
            ++count;
        }
        return count;
    }
}

