/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.gpullama3;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.request.ChatRequest;
import java.io.IOException;
import java.lang.ref.Cleaner;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.loader.ModelLoader;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;

/**
 * Shared plumbing for GPULlama3-backed chat models.
 *
 * <p>Loads a model, turns a {@link ChatRequest} into prompt tokens, runs generation on either
 * the CPU path or the TornadoVM (GPU) path, and frees the TornadoVM execution plan when the
 * instance is closed. Instances hold mutable per-call state ({@link #promptTokens},
 * {@link #state}); no synchronization is performed here.
 */
abstract class GPULlama3BaseModel
implements AutoCloseable {
    /** Shared cleaner that frees TornadoVM plans of instances that become unreachable unclosed. */
    private static final Cleaner CLEANER = Cleaner.create();

    /** Generation always starts at prompt position 0; the whole conversation is re-encoded per call. */
    private static final int START_POSITION = 0;

    State state;
    List<Integer> promptTokens;
    ChatFormat chatFormat;
    TornadoVMMasterPlan tornadoVMPlan;
    private int maxTokens;
    private boolean onGPU;
    private Model model;
    private Sampler sampler;
    /** Handle for the registered TornadoVM cleanup; {@code null} for CPU-only models. */
    private Cleaner.Cleanable cleanable;
    /** Set once the current plan has been freed so it is never freed twice. */
    private boolean closed = false;

    GPULlama3BaseModel() {
    }

    /**
     * Loads the model and prepares the inference state, sampler and chat format.
     *
     * <p>When {@code onGPU} is true, a TornadoVM execution plan is created and registered with a
     * {@link Cleaner}. The plan is freed by {@link #close()} or, as a fallback, when this instance
     * becomes unreachable. Calling {@code init} again first frees any plan created by an earlier
     * call.
     *
     * @param modelPath path of the model file to load
     * @param temperature temperature passed to the sampler
     * @param topP top-p value passed to the sampler
     * @param seed seed passed to the sampler
     * @param maxTokens value passed to the model loader and used as the per-call generation limit
     * @param onGPU whether inference runs through TornadoVM
     * @throws RuntimeException wrapping the {@link IOException} if the model cannot be loaded
     * @throws NullPointerException if a numeric or boolean argument is {@code null} (unboxing)
     */
    public void init(Path modelPath, Double temperature, Double topP, Integer seed, Integer maxTokens, Boolean onGPU) {
        // Re-initialisation must not leak the plan of an earlier call, and a plan created after
        // close() must still be freeable by a later close().
        releaseTornadoVMPlan();
        this.tornadoVMPlan = null;
        this.cleanable = null;
        this.closed = false;

        this.maxTokens = maxTokens;
        this.onGPU = onGPU;
        try {
            this.model = ModelLoader.loadModel(modelPath, this.maxTokens, true, this.onGPU);
            this.state = this.model.createNewState();
            this.sampler = Sampler.selectSampler(
                    this.model.configuration().vocabularySize(),
                    temperature.floatValue(),
                    topP.floatValue(),
                    seed.longValue());
            this.chatFormat = this.model.chatFormat();
            if (this.onGPU) {
                this.tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(this.state, this.model);
                // The action captures only the plan, never `this`, so the Cleaner can still fire.
                this.cleanable = CLEANER.register(this, new TornadoVMCleanupAction(this.tornadoVMPlan));
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from " + modelPath, e);
        }
    }

    /** Returns the loaded model, or {@code null} before {@link #init} has succeeded. */
    public Model getModel() {
        return this.model;
    }

    /** Returns the sampler created by {@link #init}, or {@code null} before it has succeeded. */
    public Sampler getSampler() {
        return this.sampler;
    }

    /**
     * Encodes the request's messages, generates a response and decodes it to text.
     *
     * <p>On return, {@link #promptTokens} holds the full prompt followed by the response tokens,
     * plus the stop token if one ended generation.
     *
     * @param request chat request whose messages form the prompt
     * @param tokenConsumer callback handed to the model's token generation
     * @return the decoded response text; if generation did not end on a stop token, a fixed
     *     "Ran out of context length" message instead
     */
    public String modelResponse(ChatRequest request, IntConsumer tokenConsumer) {
        this.promptTokens = new ArrayList<>();
        if (this.model.shouldAddBeginOfText()) {
            this.promptTokens.add(this.chatFormat.getBeginOfText());
        }
        this.processPromptMessages(request.messages());

        Set<Integer> stopTokens = this.chatFormat.getStopTokens();
        List<Integer> prompt = this.promptTokens.subList(START_POSITION, this.promptTokens.size());
        List<Integer> responseTokens = this.onGPU
                ? this.model.generateTokensGPU(this.state, START_POSITION, prompt, stopTokens,
                        this.maxTokens, this.sampler, false, tokenConsumer, this.tornadoVMPlan)
                : this.model.generateTokens(this.state, START_POSITION, prompt, stopTokens,
                        this.maxTokens, this.sampler, false, tokenConsumer);

        // Strip a trailing stop token so it is not decoded into the response text.
        Integer stopToken = null;
        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
            stopToken = responseTokens.removeLast();
        }
        String responseText = this.model.tokenizer().decode(responseTokens);
        this.promptTokens.addAll(responseTokens);
        if (stopToken == null) {
            // No stop token ended generation; this is reported as running out of context.
            return "Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX";
        }
        this.promptTokens.add(stopToken);
        return responseText;
    }

    /** Prints the metrics recorded by {@link LastRunMetrics} for the most recent run. */
    public void printLastMetrics() {
        LastRunMetrics.printMetrics();
    }

    /**
     * Appends the encoded chat messages to {@link #promptTokens}, then opens an assistant turn.
     *
     * <p>User and AI messages are always encoded. System messages are encoded only when the
     * model's {@code shouldAddSystemPrompt()} is true. Any other {@link ChatMessage} subtype is
     * skipped.
     */
    private void processPromptMessages(List<ChatMessage> messageList) {
        for (ChatMessage msg : messageList) {
            if (msg instanceof UserMessage userMessage) {
                this.promptTokens.addAll(this.chatFormat.encodeMessage(
                        new ChatFormat.Message(ChatFormat.Role.USER, userMessage.singleText())));
            } else if (msg instanceof SystemMessage systemMessage) {
                if (this.model.shouldAddSystemPrompt()) {
                    this.promptTokens.addAll(this.chatFormat.encodeMessage(
                            new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage.text())));
                }
            } else if (msg instanceof AiMessage aiMessage) {
                this.promptTokens.addAll(this.chatFormat.encodeMessage(
                        new ChatFormat.Message(ChatFormat.Role.ASSISTANT, aiMessage.text())));
            }
        }
        // Empty assistant header: generation continues in the assistant role.
        this.promptTokens.addAll(this.chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
    }

    /**
     * Frees the TornadoVM execution plan, if any. Idempotent; a no-op for CPU-only models.
     */
    public void freeTornadoVMGPUResources() {
        this.releaseTornadoVMPlan();
    }

    /** Releases GPU resources; see {@link #freeTornadoVMGPUResources()}. */
    @Override
    public void close() {
        this.freeTornadoVMGPUResources();
    }

    /**
     * Runs the registered cleanup at most once.
     *
     * <p>This is private so that {@link #init} (possibly called from a subclass constructor)
     * never invokes an overridable method.
     */
    private void releaseTornadoVMPlan() {
        if (!this.closed && this.cleanable != null) {
            this.cleanable.clean();
            this.closed = true;
        }
    }

    /**
     * Cleaner action that frees a TornadoVM plan.
     *
     * <p>It is a static class holding only the plan, so it does not keep the owning model
     * reachable. Exceptions are reported on stderr rather than propagated.
     */
    private static class TornadoVMCleanupAction
    implements Runnable {
        private final TornadoVMMasterPlan plan;

        TornadoVMCleanupAction(TornadoVMMasterPlan plan) {
            this.plan = plan;
        }

        @Override
        public void run() {
            if (this.plan != null) {
                try {
                    this.plan.freeTornadoExecutionPlan();
                }
                catch (Exception e) {
                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
                }
            }
        }
    }
}

