package opennlp.tools.ml.naivebayes;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.TwoPassDataIndexer;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.ObjectStreamUtils;
import opennlp.tools.util.TrainingParameters;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:opennlp/tools/ml/naivebayes/NaiveBayesCorrectnessTest.class */
/**
 * Checks the probabilities produced by a model trained with
 * {@link NaiveBayesTrainer} on a four-event, two-outcome ("politics" /
 * "sports") training set against fixed expected values.
 */
public class NaiveBayesCorrectnessTest {

    /** Absolute tolerance used when comparing predicted probabilities. */
    private static final double TOLERANCE = 1.0E-4d;

    /** Indexer that is re-created and initialised before every test. */
    private DataIndexer testDataIndexer;

    /**
     * Creates a fresh {@link TwoPassDataIndexer} configured with a cutoff of 1
     * and sorting disabled.
     */
    @Before
    public void initIndexer() {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        trainingParameters.put("sort", false);
        this.testDataIndexer = new TwoPassDataIndexer();
        this.testDataIndexer.init(trainingParameters, new HashMap<>());
    }

    /** Both context features occur only in "politics" training events or in both outcomes. */
    @Test
    public void testNaiveBayes1() throws IOException {
        trainAndVerify(
            new Event("politics", new String[]{"bow=united", "bow=nations"}),
            0.9681650180264167d);
    }

    /** Context containing "bow=manchester", which only occurs in "sports" training events. */
    @Test
    public void testNaiveBayes2() throws IOException {
        trainAndVerify(
            new Event("sports", new String[]{"bow=manchester", "bow=united"}),
            0.9658833555831029d);
    }

    /** Single context feature that occurs in training events of both outcomes. */
    @Test
    public void testNaiveBayes3() throws IOException {
        trainAndVerify(
            new Event("politics", new String[]{"bow=united"}),
            0.6655036407766989d);
    }

    /** Empty context: the model has no features to evaluate. */
    @Test
    public void testNaiveBayes4() throws IOException {
        trainAndVerify(new Event("politics", new String[0]), 0.5833333333333334d);
    }

    /**
     * Indexes the shared training stream, trains a Naive Bayes model on it and
     * verifies the model's prediction for {@code event}.
     *
     * @param event               event whose context is evaluated and whose outcome is expected to win
     * @param expectedProbability probability the model should assign to the event's outcome
     * @throws IOException if indexing or training fails
     */
    private void trainAndVerify(Event event, double expectedProbability) throws IOException {
        this.testDataIndexer.index(createTrainingStream());
        MaxentModel model = new NaiveBayesTrainer().trainModel(this.testDataIndexer);
        testModel(model, event, expectedProbability);
    }

    /**
     * Asserts that {@code maxentModel} yields exactly two outcomes for the
     * event's context, that the best outcome is the event's outcome, and that
     * each entry of the distribution matches the expectation: {@code d} for the
     * event's outcome and {@code 1 - d} for the other one.
     *
     * @param maxentModel model under test
     * @param event       event supplying the context and the expected best outcome
     * @param d           expected probability of the event's outcome
     */
    private void testModel(MaxentModel maxentModel, Event event, double d) {
        double[] distribution = maxentModel.eval(event.getContext());
        String bestOutcome = maxentModel.getBestOutcome(distribution);
        String expectedOutcome = event.getOutcome();

        Assert.assertEquals(2L, distribution.length);
        Assert.assertEquals(expectedOutcome, bestOutcome);

        // The length assertion above guarantees exactly indices 0 and 1 here.
        for (int i = 0; i < distribution.length; i++) {
            boolean isExpectedOutcome = expectedOutcome.equals(maxentModel.getOutcome(i));
            double expected = isExpectedOutcome ? d : 1.0d - d;
            Assert.assertEquals(expected, distribution[i], TOLERANCE);
        }
    }

    /**
     * Builds the in-memory training data: two "politics" events and two
     * "sports" events, each carrying bag-of-words style {@code bow=...}
     * context features.
     *
     * @return a stream over the four training events
     * @throws IOException declared for callers; see the underlying stream factory
     */
    public static ObjectStream<Event> createTrainingStream() throws IOException {
        List<Event> trainingEvents = new ArrayList<>();
        trainingEvents.add(new Event("politics", new String[]{"bow=the", "bow=united", "bow=nations"}));
        trainingEvents.add(new Event("politics", new String[]{"bow=the", "bow=united", "bow=states", "bow=and"}));
        trainingEvents.add(new Event("sports", new String[]{"bow=manchester", "bow=united"}));
        trainingEvents.add(new Event("sports", new String[]{"bow=manchester", "bow=and", "bow=barca"}));
        return ObjectStreamUtils.createObjectStream(trainingEvents);
    }
}
