package opennlp.tools.ml.maxent.quasinewton;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import opennlp.tools.ml.model.BinaryFileDataReader;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.GenericModelReader;
import opennlp.tools.ml.model.GenericModelWriter;
import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
import opennlp.tools.ml.model.RealValueFileEventStream;
import opennlp.tools.util.TrainingParameters;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:opennlp/tools/ml/maxent/quasinewton/QNTrainerTest.class */
/**
 * Tests for {@link QNTrainer}.
 *
 * <p>Each test indexes the same real-valued training file with a fresh
 * {@link OnePassRealValueDataIndexer}, trains a model for {@link #ITERATIONS} iterations and
 * then checks the trained model (non-null result, evaluation, {@code equals}, and a
 * write/read round trip).
 */
public class QNTrainerTest {
    /** Iteration count passed to every {@code trainModel} call in this class. */
    private static final int ITERATIONS = 50;

    /** Training data shared by all tests; the path is relative to the working directory. */
    private static final String TRAINING_DATA =
            "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt";

    private DataIndexer testDataIndexer;

    /** Creates a fresh indexer with a cutoff of 1 before every test. */
    @BeforeEach
    void initIndexer() {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        this.testDataIndexer = new OnePassRealValueDataIndexer();
        // Diamond instead of a raw HashMap: avoids the unchecked-conversion warning.
        this.testDataIndexer.init(trainingParameters, new HashMap<>());
    }

    /** Training with the single-argument trainer must yield a model. */
    @Test
    void testTrainModelReturnsAQNModel() throws Exception {
        indexTrainingData();
        Assertions.assertNotNull(new QNTrainer(false).trainModel(ITERATIONS, this.testDataIndexer));
    }

    /** A trained model must produce an evaluation result for the sample context. */
    @Test
    void testInTinyDevSet() throws Exception {
        indexTrainingData();
        QNModel trainedModel = new QNTrainer(15, true).trainModel(ITERATIONS, this.testDataIndexer);
        Assertions.assertNotNull(trainedModel.eval(sampleContext()));
    }

    /** A trained model must not consider itself equal to {@code null}. */
    @Test
    void testModel() throws IOException {
        indexTrainingData();
        QNModel trainedModel = new QNTrainer(15, true).trainModel(ITERATIONS, this.testDataIndexer);
        // Deliberately calls equals directly: assertNotEquals(null, model) would never
        // invoke the model's own equals implementation.
        Assertions.assertFalse(trainedModel.equals((Object) null));
    }

    /**
     * Writes a trained model to an in-memory buffer, reads it back, and checks that the
     * restored model equals the original and evaluates the sample context to the same
     * values (within 1e-8).
     */
    @Test
    void testSerdeModel() throws IOException {
        indexTrainingData();
        QNModel trainedModel = new QNTrainer(5, 700, true).trainModel(ITERATIONS, this.testDataIndexer);

        ByteArrayOutputStream modelBytes = new ByteArrayOutputStream();
        GenericModelWriter modelWriter = new GenericModelWriter(trainedModel, new DataOutputStream(modelBytes));
        modelWriter.persist();
        modelWriter.close();

        GenericModelReader modelReader = new GenericModelReader(
                new BinaryFileDataReader(new ByteArrayInputStream(modelBytes.toByteArray())));
        // NOTE(review): assumes getModel() is declared with a supertype of QNModel, so the
        // result is narrowed explicitly; the cast is a no-op if it already returns QNModel.
        QNModel restoredModel = (QNModel) modelReader.getModel();

        Assertions.assertEquals(trainedModel, restoredModel);

        String[] context = sampleContext();
        double[] expected = trainedModel.eval(context);
        double[] actual = restoredModel.eval(context);
        // Checks both the length and every element within the tolerance.
        Assertions.assertArrayEquals(expected, actual, 1.0E-8d);
    }

    /**
     * Feeds the shared training file into {@link #testDataIndexer}, closing the underlying
     * event stream afterwards so the file handle is not leaked.
     *
     * @throws IOException if the training file cannot be opened, read or closed
     */
    private void indexTrainingData() throws IOException {
        try (RealValueFileEventStream events = new RealValueFileEventStream(TRAINING_DATA)) {
            this.testDataIndexer.index(events);
        }
    }

    /**
     * Returns the context used for evaluation: one {@code "feature2"} followed by eleven
     * {@code "feature3"} entries. A new array is returned on every call so tests never
     * share mutable state.
     */
    private static String[] sampleContext() {
        return new String[] {
            "feature2",
            "feature3", "feature3", "feature3", "feature3", "feature3", "feature3",
            "feature3", "feature3", "feature3", "feature3", "feature3"
        };
    }
}
