package opennlp.tools.ml.maxent;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.DataIndexerFactory;
import opennlp.tools.ml.model.Event;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.ObjectStreamUtils;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.model.ModelUtil;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests training GIS/maxent event models through {@link TrainerFactory} with various
 * {@link TrainingParameters}, and selecting the {@link DataIndexer} implementation via the
 * {@code "DataIndexer"} training parameter.
 */
public class GISIndexingTest {

    /** Feature contexts of the three training events; entry {@code i} pairs with {@code outputs[i]}. */
    private static final String[][] cntx = {
        {"dog", "cat", "mouse"},
        {"text", "print", "mouse"},
        {"dog", "pig", "cat", "mouse"}
    };

    /** Outcome label of each training event (two distinct labels: "A" and "B"). */
    private static final String[] outputs = {"A", "B", "A"};

    /**
     * Builds a fresh stream of the three fixture events.
     *
     * @return an {@link ObjectStream} over events pairing {@code outputs[i]} with {@code cntx[i]}
     */
    private ObjectStream<Event> createEventStream() {
        List<Event> events = new ArrayList<>(cntx.length);
        for (int i = 0; i < cntx.length; i++) {
            events.add(new Event(outputs[i], cntx[i]));
        }
        return ObjectStreamUtils.createObjectStream(events);
    }

    /** Trains with default parameters and a cutoff of 1; expects a non-null model. */
    @Test
    void testGISTrainSignature1() throws IOException {
        try (ObjectStream<Event> eventStream = createEventStream()) {
            TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
            params.put("Cutoff", 1);
            Assertions.assertNotNull(TrainerFactory.getEventTrainer(params, null).train(eventStream));
        }
    }

    /** Trains with a cutoff of 1 and smoothing enabled; expects a non-null model. */
    @Test
    void testGISTrainSignature2() throws IOException {
        try (ObjectStream<Event> eventStream = createEventStream()) {
            TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
            params.put("Cutoff", 1);
            params.put("smoothing", true);
            Assertions.assertNotNull(TrainerFactory.getEventTrainer(params, null).train(eventStream));
        }
    }

    /** Trains with 10 iterations and a cutoff of 1; expects a non-null model. */
    @Test
    void testGISTrainSignature3() throws IOException {
        try (ObjectStream<Event> eventStream = createEventStream()) {
            TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
            params.put("Iterations", 10);
            params.put("Cutoff", 1);
            Assertions.assertNotNull(TrainerFactory.getEventTrainer(params, null).train(eventStream));
        }
    }

    /**
     * Trains directly through {@link GISTrainer#trainModel} after setting a Gaussian prior
     * sigma; expects a non-null model.
     */
    @Test
    void testGISTrainSignature4() throws IOException {
        try (ObjectStream<Event> eventStream = createEventStream()) {
            TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
            params.put("Iterations", 10);
            params.put("Cutoff", 1);
            // The factory returns the EventTrainer interface; the GIS-specific setter needs the concrete type.
            GISTrainer trainer = (GISTrainer) TrainerFactory.getEventTrainer(params, null);
            trainer.setGaussianSigma(0.01);
            Assertions.assertNotNull(trainer.trainModel(eventStream));
        }
    }

    /** Trains with smoothing and message printing explicitly disabled; expects a non-null model. */
    @Test
    void testGISTrainSignature5() throws IOException {
        try (ObjectStream<Event> eventStream = createEventStream()) {
            TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
            params.put("Iterations", 10);
            params.put("Cutoff", 1);
            params.put("smoothing", false);
            params.put("PrintMessages", false);
            Assertions.assertNotNull(TrainerFactory.getEventTrainer(params, null).train(eventStream));
        }
    }

    /**
     * Checks that the trainer picks the data indexer named by the {@code "DataIndexer"} parameter.
     * It first uses GIS with the one-pass indexer, then QN with the two-pass indexer.
     */
    @Test
    void testIndexingWithTrainingParameters() throws IOException {
        try (ObjectStream<Event> eventStream = createEventStream()) {
            TrainingParameters params = TrainingParameters.defaultParams();
            params.put("Iterations", 10);
            params.put("DataIndexer", "OnePass");
            params.put("Cutoff", 1);
            params.put("sort", true);

            // getDataIndexer is declared on AbstractEventTrainer, not on the EventTrainer interface.
            AbstractEventTrainer gisTrainer =
                (AbstractEventTrainer) TrainerFactory.getEventTrainer(params, new HashMap<>());
            Assertions.assertEquals("opennlp.tools.ml.maxent.GISTrainer", gisTrainer.getClass().getName());

            DataIndexer onePassIndexer = gisTrainer.getDataIndexer(eventStream);
            Assertions.assertEquals("opennlp.tools.ml.model.OnePassDataIndexer",
                onePassIndexer.getClass().getName());
            Assertions.assertEquals(3, onePassIndexer.getNumEvents());
            Assertions.assertEquals(2, onePassIndexer.getOutcomeLabels().length);
            Assertions.assertEquals(6, onePassIndexer.getPredLabels().length);

            // Rewind so the second indexer sees the same events.
            eventStream.reset();

            params.put("Algorithm", "MAXENT_QN");
            params.put("DataIndexer", "TwoPass");
            params.put("Cutoff", 2);

            AbstractEventTrainer qnTrainer =
                (AbstractEventTrainer) TrainerFactory.getEventTrainer(params, new HashMap<>());
            Assertions.assertEquals("opennlp.tools.ml.maxent.quasinewton.QNTrainer",
                qnTrainer.getClass().getName());
            Assertions.assertEquals("opennlp.tools.ml.model.TwoPassDataIndexer",
                qnTrainer.getDataIndexer(eventStream).getClass().getName());
        }
    }

    /**
     * Checks that {@link DataIndexerFactory} resolves the built-in indexer aliases and a
     * fully-qualified custom class name. It also checks the counts the built-in indexers produce.
     */
    @Test
    void testIndexingFactory() throws IOException {
        Map<String, String> reportMap = new HashMap<>();
        TrainingParameters params = new TrainingParameters();
        params.put("Cutoff", 1);

        try (ObjectStream<Event> eventStream = createEventStream()) {
            params.put("DataIndexer", "OnePass");
            DataIndexer onePassIndexer = DataIndexerFactory.getDataIndexer(params, reportMap);
            Assertions.assertEquals("opennlp.tools.ml.model.OnePassDataIndexer",
                onePassIndexer.getClass().getName());
            onePassIndexer.index(eventStream);
            Assertions.assertEquals(3, onePassIndexer.getNumEvents());
            Assertions.assertEquals(2, onePassIndexer.getOutcomeLabels().length);
            Assertions.assertEquals(6, onePassIndexer.getPredLabels().length);

            // Rewind so the second indexer sees the same events.
            eventStream.reset();

            params.put("DataIndexer", "TwoPass");
            DataIndexer twoPassIndexer = DataIndexerFactory.getDataIndexer(params, reportMap);
            Assertions.assertEquals("opennlp.tools.ml.model.TwoPassDataIndexer",
                twoPassIndexer.getClass().getName());
            twoPassIndexer.index(eventStream);
            Assertions.assertEquals(3, twoPassIndexer.getNumEvents());
            Assertions.assertEquals(2, twoPassIndexer.getOutcomeLabels().length);
            Assertions.assertEquals(6, twoPassIndexer.getPredLabels().length);
        }

        // The remaining checks only resolve the indexer class; they never index events.
        params.put("DataIndexer", "OnePassRealValue");
        Assertions.assertEquals("opennlp.tools.ml.model.OnePassRealValueDataIndexer",
            DataIndexerFactory.getDataIndexer(params, reportMap).getClass().getName());

        params.put("DataIndexer", "opennlp.tools.ml.maxent.MockDataIndexer");
        Assertions.assertEquals("opennlp.tools.ml.maxent.MockDataIndexer",
            DataIndexerFactory.getDataIndexer(params, reportMap).getClass().getName());
    }
}
