/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.data.document.transformer.graph;

import dev.langchain4j.Experimental;
import dev.langchain4j.community.data.document.graph.GraphDocument;
import dev.langchain4j.community.data.document.graph.GraphEdge;
import dev.langchain4j.community.data.document.graph.GraphNode;
import dev.langchain4j.community.data.document.transformer.graph.GraphTransformer;
import dev.langchain4j.community.data.document.transformer.graph.LLMGraphTransformerUtils;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.PromptTemplate;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

@Experimental
public class LLMGraphTransformer
implements GraphTransformer {
    private static final String DEFAULT_NODE_TYPE = "Node";
    private static final PromptTemplate SYSTEM_TEMPLATE = PromptTemplate.from((String)"You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.\nYour task is to identify entities and relations from a given text and generate output in JSON format.\nEach object should have keys: 'head', 'head_type', 'relation', 'tail', and 'tail_type'.\n{{nodes}}\n{{rels}}\nIMPORTANT NOTES:\n- Don't add any explanation or extra text.\n{{additional}}\n");
    private static final PromptTemplate USER_TEMPLATE = PromptTemplate.from((String)"Based on the following example, extract entities and relations from the provided text.\n{{nodes}}\n{{rels}}\nBelow are a number of examples of text and their extracted entities and relationships.\n{{examples}}\n{{additional}}\nFor the following text, extract entities and relations as in the provided example.\nText: {{input}}\n");
    private final List<String> allowedNodes;
    private final List<String> allowedRelationships;
    private final List<ChatMessage> prompt;
    private final String examples;
    private final String additionalInstructions;
    private final ChatModel chatModel;
    private final Integer maxAttempts;

    public LLMGraphTransformer(ChatModel chatModel, List<String> allowedNodes, List<String> allowedRelationships, List<ChatMessage> prompt, String additionalInstructions, String examples, Integer maxAttempts) {
        this.chatModel = (ChatModel)ValidationUtils.ensureNotNull((Object)chatModel, (String)"chatModel");
        this.examples = (String)ValidationUtils.ensureNotNull((Object)examples, (String)"examples");
        this.allowedNodes = Utils.getOrDefault(allowedNodes, List.of());
        this.allowedRelationships = Utils.getOrDefault(allowedRelationships, List.of());
        this.prompt = prompt;
        this.maxAttempts = (Integer)Utils.getOrDefault((Object)maxAttempts, (Object)1);
        this.additionalInstructions = (String)Utils.getOrDefault((Object)additionalInstructions, (Object)"");
    }

    public static Builder builder() {
        return new Builder();
    }

    public List<ChatMessage> createUnstructuredPrompt(String text) {
        if (this.prompt != null && !this.prompt.isEmpty()) {
            return this.prompt;
        }
        boolean withAllowedNodes = this.allowedNodes != null && !this.allowedNodes.isEmpty();
        boolean withAllowedRels = this.allowedRelationships != null && !this.allowedRelationships.isEmpty();
        SystemMessage systemMessage = SYSTEM_TEMPLATE.apply(Map.of("nodes", withAllowedNodes ? "The 'head_type' and 'tail_type' must be one of: " + String.valueOf(this.allowedNodes) : "", "rels", withAllowedRels ? "The 'relation' must be one of: " + String.valueOf(this.allowedRelationships) : "", "additional", this.additionalInstructions)).toSystemMessage();
        UserMessage userMessage = USER_TEMPLATE.apply(Map.of("nodes", withAllowedNodes ? "# ENTITY TYPES:\n" + String.valueOf(this.allowedNodes) : "", "rels", withAllowedRels ? "# RELATION TYPES:\n" + String.valueOf(this.allowedRelationships) : "", "examples", this.examples, "additional", this.additionalInstructions, "input", text)).toUserMessage();
        return List.of(systemMessage, userMessage);
    }

    public GraphDocument transform(Document document) {
        String text = document.text();
        List<ChatMessage> messages = this.createUnstructuredPrompt(text);
        HashSet<GraphNode> nodesSet = new HashSet<GraphNode>();
        HashSet<GraphEdge> relationships = new HashSet<GraphEdge>();
        List<Map<String, String>> parsedJson = this.getJsonResult(messages);
        if (parsedJson == null || parsedJson.isEmpty()) {
            return null;
        }
        for (Map<String, String> rel : parsedJson) {
            if (!rel.containsKey("head") || !rel.containsKey("tail") || !rel.containsKey("relation")) continue;
            GraphNode sourceNode = GraphNode.from((String)rel.get("head"), (String)rel.getOrDefault("head_type", DEFAULT_NODE_TYPE));
            GraphNode targetNode = GraphNode.from((String)rel.get("tail"), (String)rel.getOrDefault("tail_type", DEFAULT_NODE_TYPE));
            nodesSet.add(sourceNode);
            nodesSet.add(targetNode);
            String relation = rel.get("relation");
            GraphEdge edge = GraphEdge.from((GraphNode)sourceNode, (GraphNode)targetNode, (String)relation);
            relationships.add(edge);
        }
        if (nodesSet.isEmpty()) {
            return null;
        }
        return new GraphDocument(nodesSet, relationships, document);
    }

    private List<Map<String, String>> getJsonResult(List<ChatMessage> messages) {
        return (List)RetryUtils.withRetry(() -> {
            ChatResponse chat = this.chatModel.chat(messages);
            return (List)LLMGraphTransformerUtils.parseJson(LLMGraphTransformerUtils.getBacktickText(chat.aiMessage().text()));
        }, (int)this.maxAttempts);
    }

    public static class Builder {
        private ChatModel model;
        private List<String> allowedNodes;
        private List<String> allowedRelationships;
        private List<ChatMessage> prompt;
        private String additionalInstructions = "";
        private String examples;
        private Integer maxAttempts = 1;

        public Builder model(ChatModel model) {
            this.model = model;
            return this;
        }

        public Builder examples(String examples) {
            this.examples = examples;
            return this;
        }

        public Builder allowedNodes(List<String> allowedNodes) {
            this.allowedNodes = allowedNodes;
            return this;
        }

        public Builder allowedRelationships(List<String> allowedRelationships) {
            this.allowedRelationships = allowedRelationships;
            return this;
        }

        public Builder prompt(List<ChatMessage> prompt) {
            this.prompt = prompt;
            return this;
        }

        public Builder additionalInstructions(String additionalInstructions) {
            this.additionalInstructions = additionalInstructions;
            return this;
        }

        public Builder maxAttempts(Integer maxAttempts) {
            this.maxAttempts = maxAttempts;
            return this;
        }

        public LLMGraphTransformer build() {
            return new LLMGraphTransformer(this.model, this.allowedNodes, this.allowedRelationships, this.prompt, this.additionalInstructions, this.examples, this.maxAttempts);
        }
    }
}

