package opennlp.dl.doccat;

import ai.onnxruntime.OrtException;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import opennlp.dl.AbstractDLTest;
import opennlp.dl.InferenceOptions;
import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:opennlp/dl/doccat/DocumentCategorizerDLEval.class */
/**
 * Evaluation tests for {@link DocumentCategorizerDL}.
 *
 * <p>Every test loads an ONNX model and its vocabulary from paths resolved against
 * {@code getOpennlpDataDir()} and compares the produced scores with previously
 * recorded reference values, using an absolute tolerance of {@link #DELTA}.
 */
public class DocumentCategorizerDLEval extends AbstractDLTest {
    private static final Logger logger = LoggerFactory.getLogger(DocumentCategorizerDLEval.class);

    /** Relative path of the five-class sentiment model used by most tests. */
    private static final String SENTIMENT_MODEL = "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.onnx";

    /** Relative path of the vocabulary belonging to {@link #SENTIMENT_MODEL}. */
    private static final String SENTIMENT_VOCAB = "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.vocab";

    /** Relative path of the JSON file from which category labels are read automatically. */
    private static final String SENTIMENT_CONFIG = "onnx/doccat/nlptown_bert-base-multilingual-uncased-sentiment.json";

    /** Absolute tolerance applied to every floating point comparison. */
    private static final double DELTA = 1.0E-6d;

    /** Long sample document shared by the tests that categorize a multi-sentence input. */
    final String text = "We try hard to identify the sources and licenses of all media such as text, images or sounds used in our encyclopedia articles. Still, we cannot guarantee that all media are used or marked correctly: for example, if an image description page states that an image was in the public domain, you should still check yourself whether that claim appears correct and decide for yourself whether your use of the image would be fine under the laws applicable to you. Wikipedia is primarily subject to U.S. law; re-users outside the U.S. should be aware that they are subject to the laws of their country, which almost certainly are different. Images published under the GFDL or one of the Creative Commons Licenses are unlikely to pose problems, as these are specific licenses with precise terms worldwide. Public domain images may need to be re-evaluated by a re-user because it depends on each country's copyright laws what is in the public domain there. There is no guarantee that something in the public domain in the U.S. was also in the public domain in your country.";

    /**
     * Categorizes the long sample text with explicitly supplied category labels.
     *
     * @throws IOException  declared for model/vocabulary loading
     * @throws OrtException declared for ONNX Runtime failures
     */
    @Test
    public void categorize() throws IOException, OrtException {
        DocumentCategorizerDL categorizer = newSentimentCategorizer(new InferenceOptions());
        double[] scores = categorizer.categorize(new String[]{text});
        assertLongTextScores(scores);
        Assertions.assertEquals("bad", categorizer.getBestCategory(scores));
    }

    /**
     * Categorizes the long sample text with labels taken from the model's JSON file
     * instead of an explicit category map.
     *
     * @throws IOException  declared for model/vocabulary/config loading
     * @throws OrtException declared for ONNX Runtime failures
     */
    @Test
    public void categorizeWithAutomaticLabels() throws IOException, OrtException {
        DocumentCategorizerDL categorizer = new DocumentCategorizerDL(
                new File(getOpennlpDataDir(), SENTIMENT_MODEL),
                new File(getOpennlpDataDir(), SENTIMENT_VOCAB),
                new File(getOpennlpDataDir(), SENTIMENT_CONFIG),
                new AverageClassificationScoringStrategy(),
                new InferenceOptions());
        double[] scores = categorizer.categorize(new String[]{text});
        assertLongTextScores(scores);
        Assertions.assertEquals("2 stars", categorizer.getBestCategory(scores));
    }

    /**
     * Categorizes a short input with GPU inference enabled on device 0.
     *
     * @throws Exception if the model cannot be loaded or inference fails
     */
    @Disabled("This test should only be run if a GPU device is present.")
    @Test
    public void categorizeWithGpu() throws Exception {
        InferenceOptions inferenceOptions = new InferenceOptions();
        inferenceOptions.setGpu(true);
        inferenceOptions.setGpuDeviceId(0);
        // Pass the configured options; a fresh InferenceOptions would drop the GPU settings.
        DocumentCategorizerDL categorizer = newSentimentCategorizer(inferenceOptions);
        double[] scores = categorizer.categorize(new String[]{"I am happy"});
        logger.debug(Arrays.toString(scores));
        double[] expected = {0.007819971069693565d, 0.006593209225684404d, 0.04995147883892059d, 0.3003573715686798d, 0.6352779865264893d};
        Assertions.assertArrayEquals(expected, scores, DELTA);
        Assertions.assertEquals(5, scores.length);
        Assertions.assertEquals("very good", categorizer.getBestCategory(scores));
    }

    /**
     * Categorizes a short input with a two-class model, with token type ids
     * switched off through {@link InferenceOptions}.
     *
     * @throws Exception if the model cannot be loaded or inference fails
     */
    @Test
    public void categorizeWithInferenceOptions() throws Exception {
        File model = new File(getOpennlpDataDir(), "onnx/doccat/lvwerra_distilbert-imdb.onnx");
        File vocab = new File(getOpennlpDataDir(), "onnx/doccat/lvwerra_distilbert-imdb.vocab");
        InferenceOptions inferenceOptions = new InferenceOptions();
        inferenceOptions.setIncludeTokenTypeIds(false);
        Map<Integer, String> categories = new HashMap<>();
        categories.put(0, "negative");
        categories.put(1, "positive");
        DocumentCategorizerDL categorizer = new DocumentCategorizerDL(model, vocab, categories, new AverageClassificationScoringStrategy(), inferenceOptions);
        double[] scores = categorizer.categorize(new String[]{"I am angry"});
        Assertions.assertArrayEquals(new double[]{0.8851314783096313d, 0.11486853659152985d}, scores, DELTA);
        Assertions.assertEquals(2, scores.length);
        Assertions.assertEquals("negative", categorizer.getBestCategory(scores));
    }

    /**
     * Checks the score reported for each category label by {@code scoreMap}.
     *
     * @throws Exception if the model cannot be loaded or inference fails
     */
    @Test
    public void scoreMap() throws Exception {
        Map<String, Double> scores = newSentimentCategorizer(new InferenceOptions()).scoreMap(new String[]{"I am happy"});
        Assertions.assertEquals(0.6352779865264893d, scores.get("very good"), DELTA);
        Assertions.assertEquals(0.3003573715686798d, scores.get("good"), DELTA);
        Assertions.assertEquals(0.04995147883892059d, scores.get("neutral"), DELTA);
        Assertions.assertEquals(0.006593209225684404d, scores.get("bad"), DELTA);
        Assertions.assertEquals(0.007819971069693565d, scores.get("very bad"), DELTA);
    }

    /**
     * Checks that {@code sortedScoreMap} yields five entries, in ascending score order,
     * each mapping to exactly one category.
     *
     * @throws IOException  declared for model/vocabulary loading
     * @throws OrtException declared for ONNX Runtime failures
     */
    @Test
    public void sortedScoreMap() throws IOException, OrtException {
        SortedMap<Double, Set<String>> sorted = newSentimentCategorizer(new InferenceOptions()).sortedScoreMap(new String[]{"I am happy"});
        Assertions.assertNotNull(sorted, "Result must not be NULL.");
        Assertions.assertEquals(5, sorted.size());
        double[] expectedAscending = {0.006593209225684404d, 0.007819971069693565d, 0.04995147883892059d, 0.3003573715686798d, 0.6352779865264893d};
        // The size assertion above guarantees the index stays within expectedAscending.
        int position = 0;
        for (Map.Entry<Double, Set<String>> entry : sorted.entrySet()) {
            Assertions.assertEquals(expectedAscending[position], entry.getKey(), DELTA);
            Assertions.assertEquals(1, entry.getValue().size());
            position++;
        }
    }

    /**
     * Checks the index/label lookups and the category count of the categorizer.
     *
     * @throws IOException  declared for model/vocabulary loading
     * @throws OrtException declared for ONNX Runtime failures
     */
    @Test
    public void doccat() throws IOException, OrtException {
        DocumentCategorizerDL categorizer = newSentimentCategorizer(new InferenceOptions());
        Assertions.assertEquals(1, categorizer.getIndex("bad"));
        Assertions.assertEquals("good", categorizer.getCategory(3));
        Assertions.assertEquals(5, categorizer.getNumberOfCategories());
    }

    /**
     * Builds a categorizer for the five-class sentiment model with the explicit
     * labels from {@link #getCategories()} and average score aggregation.
     *
     * @param options inference options handed to the categorizer
     * @return a new categorizer instance
     * @throws IOException  declared for model/vocabulary loading
     * @throws OrtException declared for ONNX Runtime failures
     */
    private DocumentCategorizerDL newSentimentCategorizer(InferenceOptions options) throws IOException, OrtException {
        return new DocumentCategorizerDL(
                new File(getOpennlpDataDir(), SENTIMENT_MODEL),
                new File(getOpennlpDataDir(), SENTIMENT_VOCAB),
                getCategories(),
                new AverageClassificationScoringStrategy(),
                options);
    }

    /**
     * Asserts that the scores computed for {@link #text}, sorted in descending order,
     * match the recorded reference values and that five scores were produced.
     *
     * @param scores raw scores returned by the categorizer
     */
    private void assertLongTextScores(double[] scores) {
        double[] actualDescending = Arrays.stream(scores)
                .boxed()
                .sorted(Collections.reverseOrder())
                .mapToDouble(Double::doubleValue)
                .toArray();
        double[] expectedDescending = {0.3391093313694d, 0.2611352801322937d, 0.24420668184757233d, 0.11939861625432968d, 0.03615010157227516d};
        logger.debug("Actual: {}", Arrays.toString(actualDescending));
        logger.debug("Expected: {}", Arrays.toString(expectedDescending));
        Assertions.assertArrayEquals(expectedDescending, actualDescending, DELTA);
        Assertions.assertEquals(5, scores.length);
    }

    /**
     * Returns the explicit index-to-label mapping used with the five-class sentiment model.
     *
     * @return a new mutable map from class index (0-4) to label
     */
    private Map<Integer, String> getCategories() {
        Map<Integer, String> categories = new HashMap<>();
        categories.put(0, "very bad");
        categories.put(1, "bad");
        categories.put(2, "neutral");
        categories.put(3, "good");
        categories.put(4, "very good");
        return categories;
    }
}
