/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import opennlp.tools.ml.maxent.quasinewton.QNModel;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.BinaryFileDataReader;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.DataReader;
import opennlp.tools.ml.model.GenericModelReader;
import opennlp.tools.ml.model.GenericModelWriter;
import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
import opennlp.tools.ml.model.RealValueFileEventStream;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link QNTrainer}: trains a {@link QNModel} from a small real-valued
 * training file and checks that training, evaluation and binary
 * serialization/deserialization work.
 */
public class QNTrainerTest {
    /** Iteration count passed to every {@code trainModel} call in this class. */
    private static final int ITERATIONS = 50;

    /** Training data file shared by all tests; resolved against the working directory. */
    private static final String TRAINING_DATA =
            "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt";

    /** Feature context evaluated against trained models; never modified by the tests. */
    private static final String[] FEATURES_TO_CLASSIFY = {
        "feature2", "feature3", "feature3", "feature3", "feature3", "feature3",
        "feature3", "feature3", "feature3", "feature3", "feature3", "feature3"
    };

    /** Tolerance used when comparing outcome values of original and deserialized model. */
    private static final double EVAL_DELTA = 1.0E-8;

    private DataIndexer testDataIndexer;

    /** Creates a fresh indexer with a cutoff of 1 before each test. */
    @BeforeEach
    void initIndexer() {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        this.testDataIndexer = new OnePassRealValueDataIndexer();
        this.testDataIndexer.init(trainingParameters, new HashMap<>());
    }

    /**
     * Feeds the shared training file into {@link #testDataIndexer}.
     *
     * <p>The event stream is opened in a try-with-resources block so the underlying
     * file is released as soon as indexing has finished, even if indexing fails.
     *
     * @throws IOException if the training file cannot be opened, read or closed
     */
    private void indexTrainingData() throws IOException {
        try (RealValueFileEventStream events = new RealValueFileEventStream(TRAINING_DATA)) {
            this.testDataIndexer.index(events);
        }
    }

    /** Training with a default-constructed trainer must yield a non-null model. */
    @Test
    void testTrainModelReturnsAQNModel() throws Exception {
        indexTrainingData();

        QNModel trainedModel = new QNTrainer().trainModel(ITERATIONS, this.testDataIndexer);

        Assertions.assertNotNull(trainedModel);
    }

    /** A trained model must return a non-null result when evaluating a feature context. */
    @Test
    void testInTinyDevSet() throws Exception {
        indexTrainingData();

        QNModel trainedModel = new QNTrainer(15).trainModel(ITERATIONS, this.testDataIndexer);
        double[] eval = trainedModel.eval(FEATURES_TO_CLASSIFY);

        Assertions.assertNotNull(eval);
    }

    /** Training with {@code new QNTrainer(15)} must yield a non-null model. */
    @Test
    void testModel() throws IOException {
        indexTrainingData();

        QNModel trainedModel = new QNTrainer(15).trainModel(ITERATIONS, this.testDataIndexer);

        Assertions.assertNotNull(trainedModel);
    }

    /**
     * Writes a trained model to an in-memory byte buffer, reads it back and checks
     * that the restored model equals the original and evaluates the same context
     * to the same values (within {@link #EVAL_DELTA}).
     */
    @Test
    void testSerdeModel() throws IOException {
        indexTrainingData();
        QNModel trainedModel = new QNTrainer(5, 700).trainModel(ITERATIONS, this.testDataIndexer);

        // Serialize into memory.
        ByteArrayOutputStream modelBytes = new ByteArrayOutputStream();
        GenericModelWriter modelWriter =
                new GenericModelWriter(trainedModel, new DataOutputStream(modelBytes));
        modelWriter.persist();
        modelWriter.close();

        // Deserialize from the bytes just written.
        InputStream serialized = new ByteArrayInputStream(modelBytes.toByteArray());
        GenericModelReader modelReader =
                new GenericModelReader(new BinaryFileDataReader(serialized));
        AbstractModel readModel = modelReader.getModel();
        QNModel deserModel = (QNModel) readModel;

        Assertions.assertEquals(trainedModel, deserModel);

        double[] eval01 = trainedModel.eval(FEATURES_TO_CLASSIFY);
        double[] eval02 = deserModel.eval(FEATURES_TO_CLASSIFY);
        Assertions.assertEquals(eval01.length, eval02.length);
        for (int i = 0; i < eval01.length; ++i) {
            Assertions.assertEquals(eval01[i], eval02[i], EVAL_DELTA);
        }
    }
}

