package opennlp.dl.namefinder;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;

/* loaded from: input_file:opennlp/dl/namefinder/NameFinderDL.class */
public class NameFinderDL implements TokenNameFinder {
    public static final String INPUT_IDS = "input_ids";
    public static final String ATTENTION_MASK = "attention_mask";
    public static final String TOKEN_TYPE_IDS = "token_type_ids";
    public static final String I_PER = "I-PER";
    public static final String B_PER = "B-PER";
    public static final String SEPARATOR = "[SEP]";
    private static final String CHARS_TO_REPLACE = "##";
    protected final OrtSession session;
    private final SentenceDetector sentenceDetector;
    private final Map<Integer, String> ids2Labels;
    private final Tokenizer tokenizer;
    private final Map<String, Integer> vocab;
    private final InferenceOptions inferenceOptions;
    protected final OrtEnvironment env;

    public NameFinderDL(File file, File file2, Map<Integer, String> map, SentenceDetector sentenceDetector) throws Exception {
        this(file, file2, map, new InferenceOptions(), sentenceDetector);
    }

    public NameFinderDL(File file, File file2, Map<Integer, String> map, InferenceOptions inferenceOptions, SentenceDetector sentenceDetector) throws Exception {
        this.env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        if (inferenceOptions.isGpu()) {
            sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
        }
        this.session = this.env.createSession(file.getPath(), sessionOptions);
        this.ids2Labels = map;
        this.vocab = loadVocab(file2);
        this.tokenizer = new WordpieceTokenizer(this.vocab.keySet());
        this.inferenceOptions = inferenceOptions;
        this.sentenceDetector = sentenceDetector;
    }

    public Span[] find(String[] strArr) {
        String str;
        LinkedList linkedList = new LinkedList();
        String join = String.join(" ", strArr);
        for (String str2 : this.sentenceDetector.sentDetect(join)) {
            for (Tokens tokens : tokenize(str2)) {
                try {
                    HashMap hashMap = new HashMap();
                    hashMap.put("input_ids", OnnxTensor.createTensor(this.env, LongBuffer.wrap(tokens.getIds()), new long[]{1, tokens.getIds().length}));
                    if (this.inferenceOptions.isIncludeAttentionMask()) {
                        hashMap.put("attention_mask", OnnxTensor.createTensor(this.env, LongBuffer.wrap(tokens.getMask()), new long[]{1, tokens.getMask().length}));
                    }
                    if (this.inferenceOptions.isIncludeTokenTypeIds()) {
                        hashMap.put("token_type_ids", OnnxTensor.createTensor(this.env, LongBuffer.wrap(tokens.getTypes()), new long[]{1, tokens.getTypes().length}));
                    }
                    float[][][] fArr = (float[][][]) this.session.run(hashMap).get(0).getValue();
                    int i = 0;
                    String[] tokens2 = tokens.getTokens();
                    for (int i2 = 0; i2 < fArr[0].length; i2++) {
                        float[] fArr2 = fArr[0][i2];
                        int maxIndex = maxIndex(fArr2);
                        String str3 = this.ids2Labels.get(Integer.valueOf(maxIndex));
                        double d = fArr2[maxIndex];
                        if (B_PER.equals(str3)) {
                            SpanEnd findSpanEnd = findSpanEnd(fArr, i2, this.ids2Labels, tokens2);
                            if (findSpanEnd.getIndex() != -1) {
                                StringBuilder sb = new StringBuilder();
                                int index = findSpanEnd.getIndex();
                                int i3 = i2;
                                while (i3 <= index) {
                                    if (tokens2[i3 + 1].startsWith(CHARS_TO_REPLACE)) {
                                        sb.append(tokens2[i3]).append(tokens2[i3 + 1].replace(CHARS_TO_REPLACE, ""));
                                        if (!tokens2[i3 + 2].startsWith(CHARS_TO_REPLACE)) {
                                            sb.append(" ");
                                        }
                                        i3++;
                                    } else {
                                        sb.append(tokens2[i3].replace(CHARS_TO_REPLACE, ""));
                                        if (!".".equals(tokens2[i3 + 1])) {
                                            sb.append(" ");
                                        }
                                    }
                                    i3++;
                                }
                                str = findByRegex(join, sb.toString().trim()).trim();
                            } else {
                                str = tokens2[i2];
                            }
                            if (!SEPARATOR.equals(str)) {
                                String replace = str.replace(CHARS_TO_REPLACE, "");
                                i = join.indexOf(replace, i);
                                if (i != -1) {
                                    linkedList.add(new Span(i, i + replace.length(), replace, d));
                                    i++;
                                }
                            }
                        }
                    }
                } catch (OrtException e) {
                    throw new RuntimeException("Error performing namefinder inference: " + e.getMessage(), e);
                }
            }
        }
        return (Span[]) linkedList.toArray(new Span[0]);
    }

    public void clearAdaptiveData() {
    }

    private SpanEnd findSpanEnd(float[][][] fArr, int i, Map<Integer, String> map, String[] strArr) {
        int i2 = -1;
        int i3 = 0;
        int i4 = i + 1;
        while (true) {
            if (i4 >= fArr[0].length) {
                break;
            }
            if (!I_PER.equals(map.get(Integer.valueOf(maxIndex(fArr[0][i4]))))) {
                i2 = i4 - 1;
                break;
            }
            i4++;
        }
        for (int i5 = 1; i5 <= i2 && i5 < strArr.length; i5++) {
            i3 += strArr[i5].length();
        }
        return new SpanEnd(i2, i3 + (i2 - 1));
    }

    private int maxIndex(float[] fArr) {
        double d = Double.NEGATIVE_INFINITY;
        int i = -1;
        for (int i2 = 0; i2 < fArr.length; i2++) {
            if (fArr[i2] > d) {
                i = i2;
                d = fArr[i2];
            }
        }
        return i;
    }

    private static String findByRegex(String str, String str2) {
        Matcher matcher = Pattern.compile(str2.replaceAll(" ", "\\\\s+").replaceAll("\\)", "\\\\)").replaceAll("\\(", "\\\\("), 2).matcher(str);
        return matcher.find() ? matcher.group(0) : str2;
    }

    private List<Tokens> tokenize(String str) {
        LinkedList linkedList = new LinkedList();
        String[] split = str.split("\\s+");
        int i = 0;
        while (true) {
            int i2 = i;
            if (i2 >= split.length) {
                return linkedList;
            }
            int documentSplitSize = i2 + this.inferenceOptions.getDocumentSplitSize();
            if (documentSplitSize > split.length) {
                documentSplitSize = split.length;
            }
            String join = String.join(" ", (CharSequence[]) Arrays.copyOfRange(split, i2, documentSplitSize));
            int splitOverlapSize = i2 - this.inferenceOptions.getSplitOverlapSize();
            String[] strArr = this.tokenizer.tokenize(join);
            int[] iArr = new int[strArr.length];
            for (int i3 = 0; i3 < strArr.length; i3++) {
                iArr[i3] = this.vocab.get(strArr[i3]).intValue();
            }
            long[] array = Arrays.stream(iArr).mapToLong(i4 -> {
                return i4;
            }).toArray();
            long[] jArr = new long[iArr.length];
            Arrays.fill(jArr, 1L);
            long[] jArr2 = new long[iArr.length];
            Arrays.fill(jArr2, 0L);
            linkedList.add(new Tokens(strArr, array, jArr, jArr2));
            i = splitOverlapSize + this.inferenceOptions.getDocumentSplitSize();
        }
    }

    private Map<String, Integer> loadVocab(File file) throws IOException {
        HashMap hashMap = new HashMap();
        BufferedReader bufferedReader = new BufferedReader(new FileReader(file.getPath()));
        try {
            String readLine = bufferedReader.readLine();
            int i = 0;
            while (readLine != null) {
                readLine = bufferedReader.readLine();
                i++;
                hashMap.put(readLine, Integer.valueOf(i));
            }
            bufferedReader.close();
            return hashMap;
        } catch (Throwable th) {
            try {
                bufferedReader.close();
            } catch (Throwable th2) {
                th.addSuppressed(th2);
            }
            throw th;
        }
    }
}
