package opennlp.dl.doccat;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.IOException;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.IntStream;
import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:opennlp/dl/doccat/DocumentCategorizerDL.class */
/**
 * {@link DocumentCategorizer} backed by an ONNX classification model and a wordpiece vocabulary.
 *
 * <p>A document is split on whitespace into windows of at most
 * {@link InferenceOptions#getDocumentSplitSize()} words, with consecutive windows sharing
 * {@link InferenceOptions#getSplitOverlapSize()} words. Each window is wordpiece-tokenized, run
 * through the model, and turned into probabilities with a softmax. The
 * {@link ClassificationScoringStrategy} then combines the per-window probabilities into the final
 * scores.
 */
public class DocumentCategorizerDL extends AbstractDL implements DocumentCategorizer {

    private static final Logger logger = LoggerFactory.getLogger(DocumentCategorizerDL.class);

    /** Maps a model output index to its category name. */
    private final Map<Integer, String> categories;
    private final ClassificationScoringStrategy classificationScoringStrategy;
    private final InferenceOptions inferenceOptions;

    /**
     * Creates a categorizer from a model, its vocabulary and an explicit index-to-category map.
     *
     * @param model the ONNX model file
     * @param vocabulary the vocabulary file, loaded via {@code loadVocab}
     * @param categories maps model output index to category name; stored as given (not copied)
     * @param classificationScoringStrategy combines per-window probabilities into final scores
     * @param inferenceOptions GPU, model-input and document-splitting options
     * @throws IOException if loading the vocabulary fails
     * @throws OrtException if the ONNX session cannot be created
     */
    public DocumentCategorizerDL(File model, File vocabulary, Map<Integer, String> categories,
            ClassificationScoringStrategy classificationScoringStrategy,
            InferenceOptions inferenceOptions) throws IOException, OrtException {
        this.env = OrtEnvironment.getEnvironment();
        final OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        if (inferenceOptions.isGpu()) {
            sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
        }
        this.session = this.env.createSession(model.getPath(), sessionOptions);
        this.vocab = loadVocab(vocabulary);
        this.tokenizer = new WordpieceTokenizer(this.vocab.keySet());
        this.categories = categories;
        this.classificationScoringStrategy = classificationScoringStrategy;
        this.inferenceOptions = inferenceOptions;
    }

    /**
     * Creates a categorizer whose categories are read from the {@code id2label} section of a JSON
     * model config file.
     *
     * @param model the ONNX model file
     * @param vocabulary the vocabulary file, loaded via {@code loadVocab}
     * @param config JSON config containing {@code id2label}; unknown properties are ignored
     * @param classificationScoringStrategy combines per-window probabilities into final scores
     * @param inferenceOptions GPU, model-input and document-splitting options
     * @throws IOException if the config or vocabulary cannot be read
     * @throws OrtException if the ONNX session cannot be created
     */
    public DocumentCategorizerDL(File model, File vocabulary, File config,
            ClassificationScoringStrategy classificationScoringStrategy,
            InferenceOptions inferenceOptions) throws IOException, OrtException {
        // The config is parsed before delegating, so a bad config fails before a session is created.
        this(model, vocabulary, readCategoriesFromFile(config), classificationScoringStrategy,
                inferenceOptions);
    }

    /**
     * Scores a document against every category.
     *
     * <p>All elements of {@code strArr} are joined with single spaces and treated as one document.
     * Callers may therefore pass either the full text as a single element or pre-split tokens.
     *
     * @param strArr the document text
     * @return the scores combined by the scoring strategy ({@link #getBestCategory} and
     *     {@link #scoreMap} index this array by category id), or an empty array if
     *     {@code strArr} is null/empty or inference fails (the failure is logged)
     */
    public double[] categorize(String[] strArr) {
        if (strArr == null || strArr.length == 0) {
            return new double[0];
        }
        try {
            final List<double[]> chunkScores = new ArrayList<>();
            for (Tokens tokens : tokenize(String.join(" ", strArr))) {
                chunkScores.add(scoreChunk(tokens));
            }
            return this.classificationScoringStrategy.score(chunkScores);
        } catch (Exception e) {
            logger.error("Unable to perform document classification inference", e);
            return new double[0];
        }
    }

    /**
     * Same as {@link #categorize(String[])}; {@code map} is ignored.
     */
    public double[] categorize(String[] strArr, Map<String, Object> map) {
        return categorize(strArr);
    }

    /**
     * Returns the category whose score is highest.
     * On ties, the later index wins. Returns {@code null} for an empty array.
     */
    public String getBestCategory(double[] dArr) {
        return this.categories.get(maxIndex(dArr));
    }

    /** Returns the index of category {@code str}, or -1 if it is unknown. */
    public int getIndex(String str) {
        return getKey(str);
    }

    /** Returns the category name for index {@code i}, or {@code null} if there is none. */
    public String getCategory(int i) {
        return this.categories.get(i);
    }

    /** Returns the number of configured categories. */
    public int getNumberOfCategories() {
        return this.categories.size();
    }

    /**
     * Renders scores as {@code category[score]} pairs separated by two spaces.
     * Each score is written with four decimals in index order.
     *
     * @param dArr scores as returned by {@link #categorize(String[])}
     * @return the formatted scores; empty if {@code dArr} is null or empty
     */
    public String getAllResults(double[] dArr) {
        if (dArr == null) {
            return "";
        }
        final StringBuilder sb = new StringBuilder();
        for (int i = 0; i < dArr.length; i++) {
            if (i > 0) {
                sb.append("  ");
            }
            sb.append(getCategory(i))
                    .append('[')
                    .append(String.format(Locale.ROOT, "%.4f", dArr[i]))
                    .append(']');
        }
        return sb.toString();
    }

    /**
     * Maps each category name to its score for the given document.
     * Categories whose index has no score (e.g. because inference failed) are omitted.
     */
    public Map<String, Double> scoreMap(String[] strArr) {
        final double[] scores = categorize(strArr);
        final Map<String, Double> result = new HashMap<>();
        for (Map.Entry<Integer, String> entry : this.categories.entrySet()) {
            final int index = entry.getKey();
            if (index >= 0 && index < scores.length) {
                result.put(entry.getValue(), scores[index]);
            }
        }
        return result;
    }

    /**
     * Groups category names by score, in ascending score order.
     * Categories whose index has no score (e.g. because inference failed) are omitted.
     */
    public SortedMap<Double, Set<String>> sortedScoreMap(String[] strArr) {
        final double[] scores = categorize(strArr);
        final SortedMap<Double, Set<String>> result = new TreeMap<>();
        for (Map.Entry<Integer, String> entry : this.categories.entrySet()) {
            final int index = entry.getKey();
            if (index >= 0 && index < scores.length) {
                result.computeIfAbsent(scores[index], k -> new HashSet<>()).add(entry.getValue());
            }
        }
        return result;
    }

    /** Linear lookup of the index whose category name equals {@code str}; -1 if none. */
    private int getKey(String str) {
        for (Map.Entry<Integer, String> entry : this.categories.entrySet()) {
            if (entry.getValue().equals(str)) {
                return entry.getKey();
            }
        }
        return -1;
    }

    /**
     * Runs the model on one window and returns the softmax of the first output row.
     * The input tensors and the session result own native memory, so they are closed here.
     */
    private double[] scoreChunk(Tokens tokens) throws OrtException {
        final Map<String, OnnxTensor> inputs = new HashMap<>();
        try {
            inputs.put(AbstractDL.INPUT_IDS, toTensor(tokens.ids()));
            if (this.inferenceOptions.isIncludeAttentionMask()) {
                inputs.put(AbstractDL.ATTENTION_MASK, toTensor(tokens.mask()));
            }
            if (this.inferenceOptions.isIncludeTokenTypeIds()) {
                inputs.put(AbstractDL.TOKEN_TYPE_IDS, toTensor(tokens.types()));
            }
            try (OrtSession.Result result = this.session.run(inputs)) {
                // The value is copied into a Java array, so it remains valid after close().
                final float[][] logits = (float[][]) result.get(0).getValue();
                return softmax(logits[0]);
            }
        } finally {
            for (OnnxTensor tensor : inputs.values()) {
                tensor.close();
            }
        }
    }

    /** Wraps {@code values} in a tensor of shape {@code [1, values.length]}. */
    private OnnxTensor toTensor(long[] values) throws OrtException {
        return OnnxTensor.createTensor(this.env, LongBuffer.wrap(values),
                new long[] {1, values.length});
    }

    /**
     * Splits {@code str} into overlapping whitespace-delimited word windows and tokenizes each one.
     *
     * @throws IllegalStateException if the split size is not positive or the overlap is not smaller
     *     than the split size (the window would never advance)
     */
    private List<Tokens> tokenize(String str) {
        final int splitSize = this.inferenceOptions.getDocumentSplitSize();
        final int overlap = this.inferenceOptions.getSplitOverlapSize();
        if (splitSize <= 0 || overlap >= splitSize) {
            throw new IllegalStateException("documentSplitSize (" + splitSize
                    + ") must be positive and greater than splitOverlapSize (" + overlap + ")");
        }
        final int stride = splitSize - overlap;

        final String[] words = str.split("\\s+");
        final List<Tokens> chunks = new ArrayList<>();
        for (int start = 0; start < words.length; start += stride) {
            final int end = Math.min(start + splitSize, words.length);
            chunks.add(toTokens(String.join(" ", Arrays.copyOfRange(words, start, end))));
            if (end == words.length) {
                // Any later window would be a strict subset of this one.
                break;
            }
        }
        return chunks;
    }

    /**
     * Wordpiece-tokenizes one window.
     * The attention mask is all ones and the token type ids are all zeros.
     */
    private Tokens toTokens(String chunk) {
        final String[] wordpieces = this.tokenizer.tokenize(chunk);
        final long[] ids = new long[wordpieces.length];
        for (int i = 0; i < wordpieces.length; i++) {
            final Integer id = this.vocab.get(wordpieces[i]);
            if (id == null) {
                throw new IllegalStateException("Token not found in vocabulary: " + wordpieces[i]);
            }
            ids[i] = id.longValue();
        }
        final long[] mask = new long[ids.length];
        Arrays.fill(mask, 1L);
        final long[] types = new long[ids.length];
        return new Tokens(wordpieces, ids, mask, types);
    }

    /**
     * Computes a softmax over {@code fArr}.
     * The maximum logit is subtracted first so that {@code exp} cannot overflow to infinity. Each
     * probability is rounded to float precision, as in the original implementation.
     */
    private double[] softmax(float[] fArr) {
        double max = Double.NEGATIVE_INFINITY;
        for (float logit : fArr) {
            max = Math.max(max, logit);
        }
        final double[] exps = new double[fArr.length];
        double sum = 0.0d;
        for (int i = 0; i < fArr.length; i++) {
            exps[i] = Math.exp(fArr[i] - max);
            sum += exps[i];
        }
        final double[] probabilities = new double[fArr.length];
        for (int i = 0; i < probabilities.length; i++) {
            probabilities[i] = (float) (exps[i] / sum);
        }
        return probabilities;
    }

    /** Index of the largest value; the later index wins ties; -1 for an empty array. */
    private int maxIndex(double[] dArr) {
        return IntStream.range(0, dArr.length)
                .reduce((i, j) -> dArr[i] > dArr[j] ? i : j)
                .orElse(-1);
    }

    /**
     * Reads the {@code id2label} map from a JSON config file and converts its string keys to
     * integer indices. Jackson detects the file's encoding from its bytes instead of using the
     * platform default charset.
     *
     * @throws IOException if the file cannot be read or parsed
     * @throws NumberFormatException if an {@code id2label} key is not an integer
     */
    private static Map<Integer, String> readCategoriesFromFile(File file) throws IOException {
        final ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        final DocumentCategorizerConfig config =
                objectMapper.readValue(file, DocumentCategorizerConfig.class);
        final Map<Integer, String> result = new HashMap<>();
        for (Map.Entry<String, String> entry : config.getId2label().entrySet()) {
            result.put(Integer.valueOf(entry.getKey()), entry.getValue());
        }
        return result;
    }
}
