package opennlp.tools.ml.naivebayes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.HashMap;
import java.util.stream.Stream;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.TwoPassDataIndexer;
import opennlp.tools.util.TrainingParameters;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Checks that a {@link NaiveBayesModel} trained from a {@link TwoPassDataIndexer} survives
 * serialization: a model read back from the binary format must expose the same outcome labels
 * and produce the same evaluation probabilities as the in-memory original, and the plain-text
 * format must be stable across a write-read-write round trip.
 */
public class NaiveBayesSerializedCorrectnessTest extends AbstractNaiveBayesTest {

    /** Indexer rebuilt before every test from the training stream of the base class. */
    private DataIndexer testDataIndexer;

    /**
     * Builds a fresh indexer over {@code createTrainingStream()} (inherited from
     * {@link AbstractNaiveBayesTest}) with a cutoff of 1 and sorting disabled.
     *
     * @throws IOException if reading the training stream fails
     */
    @BeforeEach
    void initIndexer() throws IOException {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        trainingParameters.put("sort", false);
        this.testDataIndexer = new TwoPassDataIndexer();
        // Second argument is an empty map; presumably the indexer's report/info map — confirm
        // against DataIndexer#init.
        this.testDataIndexer.init(trainingParameters, new HashMap<>());
        this.testDataIndexer.index(createTrainingStream());
    }

    /**
     * Trains a model, round-trips it through the binary format, and compares the two models
     * on a single event.
     *
     * @param label   expected outcome label of the event
     * @param context feature strings of the event (may be empty)
     * @throws IOException if training or (de)serialization fails
     */
    @MethodSource("provideLabelsWithContext")
    @ParameterizedTest
    void testNaiveBayes(String label, String[] context) throws IOException {
        NaiveBayesModel trainedModel = new NaiveBayesTrainer().trainModel(this.testDataIndexer);
        testModelOutcome(trainedModel, persistedModel(trainedModel), new Event(label, context));
    }

    /**
     * Supplies (label, context) pairs for {@link #testNaiveBayes}. The last case uses an empty
     * context to exercise evaluation without any features.
     */
    private static Stream<Arguments> provideLabelsWithContext() {
        return Stream.of(
            Arguments.of("politics", new String[] {"bow=united", "bow=nations"}),
            Arguments.of("sports", new String[] {"bow=manchester", "bow=united"}),
            Arguments.of("politics", new String[] {"bow=united"}),
            Arguments.of("politics", new String[0]));
    }

    /**
     * Writes a trained model as plain text, reads it back, writes it again, and requires both
     * textual representations to be identical.
     *
     * @throws IOException if writing or reading the model fails
     */
    @Test
    void testPlainTextModel() throws IOException {
        NaiveBayesModel trainedModel = new NaiveBayesTrainer().trainModel(this.testDataIndexer);

        StringWriter firstPass = new StringWriter();
        // NOTE(review): assumes persist() flushes/closes the BufferedWriter so that
        // firstPass.toString() sees the full output — confirm in PlainTextNaiveBayesModelWriter.
        new PlainTextNaiveBayesModelWriter(trainedModel, new BufferedWriter(firstPass)).persist();

        PlainTextNaiveBayesModelReader reader = new PlainTextNaiveBayesModelReader(
            new BufferedReader(new StringReader(firstPass.toString())));
        reader.checkModelType();
        NaiveBayesModel reloadedModel = reader.constructModel();

        StringWriter secondPass = new StringWriter();
        new PlainTextNaiveBayesModelWriter(reloadedModel, new BufferedWriter(secondPass)).persist();

        Assertions.assertEquals(firstPass.toString(), secondPass.toString());
    }

    /**
     * Writes {@code model} to a temporary file in the binary format and reads it back.
     *
     * @param model the model to serialize
     * @return the model reconstructed from the binary file
     * @throws IOException if creating, writing, or reading the temporary file fails
     */
    private static NaiveBayesModel persistedModel(NaiveBayesModel model) throws IOException {
        File file = Files.createTempFile("ptnb-", ".bin").toFile();
        try {
            new BinaryNaiveBayesModelWriter(model, file).persist();
            BinaryNaiveBayesModelReader reader = new BinaryNaiveBayesModelReader(file);
            reader.checkModelType();
            return reader.constructModel();
        } finally {
            // Best-effort cleanup that never masks the real outcome. If the immediate delete
            // fails (e.g. a lingering file handle), fall back to deleting at JVM exit instead
            // of leaking the temp file.
            if (!file.delete()) {
                file.deleteOnExit();
            }
        }
    }

    /**
     * Asserts that two models share the same outcome labels (in index order) and give the
     * same probabilities, within 1e-12, for the context of {@code event}.
     *
     * @param expected the reference model
     * @param actual   the model under comparison (e.g. after deserialization)
     * @param event    supplies the context to evaluate
     */
    private static void testModelOutcome(NaiveBayesModel expected, NaiveBayesModel actual,
                                         Event event) {
        Assertions.assertArrayEquals(extractLabels(expected), extractLabels(actual));
        String[] context = event.getContext();
        Assertions.assertArrayEquals(expected.eval(context), actual.eval(context), 1.0E-12d);
    }

    /**
     * Collects the model's outcome labels ordered by outcome index.
     *
     * @param model the model to inspect
     * @return an array whose element {@code i} is {@code model.getOutcome(i)}
     */
    private static String[] extractLabels(NaiveBayesModel model) {
        String[] labels = new String[model.getNumOutcomes()];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = model.getOutcome(i);
        }
        return labels;
    }
}
