package opennlp.dl;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;

/* loaded from: input_file:opennlp/dl/Inference.class */
/**
 * Base class for running an ONNX Runtime model over WordPiece-tokenized text.
 *
 * <p>Subclasses implement {@link #infer(String)}; this class owns the ONNX session, the
 * vocabulary (token to id) and the WordPiece tokenizer built from that vocabulary, and provides
 * shared numeric helpers.
 */
public abstract class Inference {
    public static final String INPUT_IDS = "input_ids";
    public static final String ATTENTION_MASK = "attention_mask";
    public static final String TOKEN_TYPE_IDS = "token_type_ids";
    protected final OrtEnvironment env = OrtEnvironment.getEnvironment();
    protected final OrtSession session;
    private final Tokenizer tokenizer;
    private final Map<String, Integer> vocabulary;

    /**
     * Runs the model on the given text.
     *
     * @param str the input text
     * @return model output; its layout is defined by the concrete subclass
     * @throws Exception if inference fails
     */
    public abstract double[][] infer(String str) throws Exception;

    /**
     * Creates an ONNX session for {@code model} and loads the WordPiece vocabulary.
     *
     * @param model path to the ONNX model file
     * @param vocab path to the vocabulary file (one token per line; line number = token id)
     * @throws OrtException if the ONNX session cannot be created
     * @throws IOException if the vocabulary file cannot be read
     */
    public Inference(File model, File vocab) throws OrtException, IOException {
        // NOTE(review): the SessionOptions instance is never closed; confirm whether ONNX Runtime
        // allows closing it once the session has been created.
        this.session = this.env.createSession(model.getPath(), new OrtSession.SessionOptions());
        this.vocabulary = loadVocab(vocab);
        this.tokenizer = new WordpieceTokenizer(this.vocabulary.keySet());
    }

    /**
     * Tokenizes {@code str} and builds the model inputs for it.
     *
     * @param str the text to tokenize
     * @return the tokens, their vocabulary ids, an all-ones attention mask and all-zero token
     *     type ids (all arrays have the same length as the token array)
     * @throws IllegalStateException if the tokenizer emits a token absent from the vocabulary
     */
    public Tokens tokenize(String str) {
        final String[] tokens = this.tokenizer.tokenize(str);
        final long[] ids = new long[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            final Integer id = this.vocabulary.get(tokens[i]);
            if (id == null) {
                // Fail with context instead of an anonymous NullPointerException on unboxing.
                throw new IllegalStateException(
                    "Token '" + tokens[i] + "' produced by the tokenizer is not in the vocabulary");
            }
            ids[i] = id;
        }
        final long[] attentionMask = new long[tokens.length];
        Arrays.fill(attentionMask, 1L);
        // Java zero-initializes arrays, so every token type id is already 0.
        final long[] tokenTypeIds = new long[tokens.length];
        return new Tokens(tokens, ids, attentionMask, tokenTypeIds);
    }

    /**
     * Loads a vocabulary file with one token per line, mapping each token to its zero-based line
     * number. If a token appears more than once, its last occurrence wins.
     *
     * <p>The file is read as UTF-8 (assumed encoding of WordPiece vocab files).
     *
     * @param file the vocabulary file
     * @return a mutable map from token to id
     * @throws IOException if the file cannot be read or is not valid UTF-8
     */
    public Map<String, Integer> loadVocab(File file) throws IOException {
        final Map<String, Integer> vocab = new HashMap<>();
        try (BufferedReader reader = Files.newBufferedReader(file.toPath(), StandardCharsets.UTF_8)) {
            int index = 0;
            String line;
            // Insert every line, starting with the first, and never insert the terminating null.
            while ((line = reader.readLine()) != null) {
                vocab.put(line, index++);
            }
        }
        return vocab;
    }

    /**
     * Returns the index of the largest element.
     *
     * <p>When several elements share the maximum, the index of the last one is returned.
     *
     * @param dArr the values to search
     * @return the index of the maximum, or -1 if {@code dArr} is empty
     */
    public static int maxIndex(double[] dArr) {
        return IntStream.range(0, dArr.length).reduce((i, i2) -> {
            return dArr[i] > dArr[i2] ? i : i2;
        }).orElse(-1);
    }

    /**
     * Computes the softmax of {@code dArr} at full double precision.
     *
     * <p>The maximum value is subtracted before exponentiation so large logits do not overflow to
     * infinity. This is mathematically equivalent to the plain formula.
     *
     * @param dArr the logits
     * @return a new array of the same length whose entries sum to 1 (empty for empty input)
     */
    public double[] softmax(double[] dArr) {
        double max = Double.NEGATIVE_INFINITY;
        for (double value : dArr) {
            max = Math.max(max, value);
        }
        // Shifting by an infinite max would turn every entry into NaN; fall back to no shift.
        final double shift = Double.isInfinite(max) ? 0.0d : max;

        final double[] result = new double[dArr.length];
        double sum = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            result[i] = Math.exp(dArr[i] - shift);
            sum += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= sum;
        }
        return result;
    }

    /**
     * Widens a 2-D float array to doubles, row by row.
     *
     * @param fArr the source array; may be empty or jagged
     * @return a new array with the same shape as {@code fArr}
     */
    public double[][] convertFloatsToDoubles(float[][] fArr) {
        final double[][] result = new double[fArr.length][];
        for (int i = 0; i < fArr.length; i++) {
            // Size each row from its own length instead of assuming every row matches row 0.
            final float[] row = fArr[i];
            final double[] converted = new double[row.length];
            for (int j = 0; j < row.length; j++) {
                converted[j] = row[j];
            }
            result[i] = converted;
        }
        return result;
    }

    /**
     * Widens a float array to doubles.
     *
     * @param fArr the source array
     * @return a new array of the same length
     */
    public double[] convertFloatsToDoubles(float[] fArr) {
        final double[] result = new double[fArr.length];
        for (int i = 0; i < fArr.length; i++) {
            result[i] = fArr[i];
        }
        return result;
    }
}
