package opennlp.tools.ml.perceptron;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.PrepAttachDataUtil;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.TwoPassDataIndexer;
import opennlp.tools.util.TrainingParameters;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:opennlp/tools/ml/perceptron/PerceptronPrepAttachTest.class */
public class PerceptronPrepAttachTest {

    /** Expected accuracy of a perceptron trained with skipped averaging and cutoff 1. */
    private static final double SKIPPED_AVERAGING_ACCURACY = 0.773706362961129d;

    /**
     * Creates training parameters selecting the PERCEPTRON algorithm with a cutoff of 1,
     * the baseline shared by all factory-based tests in this class.
     *
     * @return a fresh, mutable parameter set callers may extend
     */
    private static TrainingParameters createPerceptronParameters() {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Algorithm", "PERCEPTRON");
        trainingParameters.put("Cutoff", 1);
        return trainingParameters;
    }

    /**
     * Trains a perceptron directly through {@link PerceptronTrainer#trainModel} on data
     * indexed by a {@link TwoPassDataIndexer} and checks its accuracy.
     */
    @Test
    public void testPerceptronOnPrepAttachData() throws IOException {
        TwoPassDataIndexer twoPassDataIndexer = new TwoPassDataIndexer();
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        trainingParameters.put("sort", false);
        twoPassDataIndexer.init(trainingParameters, new HashMap<String, String>());
        twoPassDataIndexer.index(PrepAttachDataUtil.createTrainingStream());
        PrepAttachDataUtil.testModel(
            new PerceptronTrainer().trainModel(400, twoPassDataIndexer, 1), 0.7650408516959644d);
    }

    /** Trains via {@link TrainerFactory} with skipped averaging enabled and checks accuracy. */
    @Test
    public void testPerceptronOnPrepAttachDataWithSkippedAveraging() throws IOException {
        TrainingParameters trainingParameters = createPerceptronParameters();
        trainingParameters.put("UseSkippedAveraging", true);
        PrepAttachDataUtil.testModel(
            TrainerFactory.getEventTrainer(trainingParameters, (Map<String, String>) null)
                .train(PrepAttachDataUtil.createTrainingStream()),
            SKIPPED_AVERAGING_ACCURACY);
    }

    /** Trains via {@link TrainerFactory} with a 500-iteration cap and a tolerance, checks accuracy. */
    @Test
    public void testPerceptronOnPrepAttachDataWithTolerance() throws IOException {
        TrainingParameters trainingParameters = createPerceptronParameters();
        trainingParameters.put("Iterations", 500);
        trainingParameters.put("Tolerance", 1.0E-4d);
        PrepAttachDataUtil.testModel(
            TrainerFactory.getEventTrainer(trainingParameters, (Map<String, String>) null)
                .train(PrepAttachDataUtil.createTrainingStream()),
            0.7677642980935875d);
    }

    /** Trains via {@link TrainerFactory} with a step-size decrease, checks accuracy. */
    @Test
    public void testPerceptronOnPrepAttachDataWithStepSizeDecrease() throws IOException {
        TrainingParameters trainingParameters = createPerceptronParameters();
        trainingParameters.put("Iterations", 500);
        trainingParameters.put("StepSizeDecrease", 0.06d);
        PrepAttachDataUtil.testModel(
            TrainerFactory.getEventTrainer(trainingParameters, (Map<String, String>) null)
                .train(PrepAttachDataUtil.createTrainingStream()),
            0.7791532557563754d);
    }

    /**
     * Writes a trained model with {@link BinaryPerceptronModelWriter}, reads it back with
     * {@link BinaryPerceptronModelReader}, and checks the round-tripped model keeps the
     * same accuracy as the original.
     */
    @Test
    public void testModelSerialization() throws IOException {
        TrainingParameters trainingParameters = createPerceptronParameters();
        trainingParameters.put("UseSkippedAveraging", true);
        AbstractModel model = (AbstractModel) TrainerFactory
            .getEventTrainer(trainingParameters, (Map<String, String>) null)
            .train(PrepAttachDataUtil.createTrainingStream());
        PrepAttachDataUtil.testModel(model, SKIPPED_AVERAGING_ACCURACY);

        ByteArrayOutputStream serialized = new ByteArrayOutputStream();
        BinaryPerceptronModelWriter writer =
            new BinaryPerceptronModelWriter(model, new DataOutputStream(serialized));
        writer.persist();
        writer.close();

        AbstractModel restored = new BinaryPerceptronModelReader(
            new DataInputStream(new ByteArrayInputStream(serialized.toByteArray()))).getModel();
        PrepAttachDataUtil.testModel(restored, SKIPPED_AVERAGING_ACCURACY);
    }

    /**
     * Trains two models from identical data and parameters and asserts that they are equal
     * and have equal hash codes.
     */
    @Test
    public void testModelEquals() throws IOException {
        TrainingParameters trainingParameters = createPerceptronParameters();
        trainingParameters.put("UseSkippedAveraging", true);
        EventTrainer eventTrainer =
            TrainerFactory.getEventTrainer(trainingParameters, (Map<String, String>) null);

        // Bind each model to its own local: the hash-code check must compare two distinct
        // instances, not one model with itself.
        AbstractModel modelA =
            (AbstractModel) eventTrainer.train(PrepAttachDataUtil.createTrainingStream());
        AbstractModel modelB =
            (AbstractModel) eventTrainer.train(PrepAttachDataUtil.createTrainingStream());

        Assert.assertEquals(modelA, modelB);
        Assert.assertEquals(modelA.hashCode(), modelB.hashCode());
    }

    /** Verifies that training records the training event hash in the supplied report map. */
    @Test
    public void verifyReportMap() throws IOException {
        TrainingParameters trainingParameters = createPerceptronParameters();
        trainingParameters.put("Iterations", 1);
        trainingParameters.put("UseSkippedAveraging", true);
        Map<String, String> reportMap = new HashMap<>();
        TrainerFactory.getEventTrainer(trainingParameters, reportMap)
            .train(PrepAttachDataUtil.createTrainingStream());
        Assert.assertTrue("Report Map does not contain the training event hash",
            reportMap.containsKey("Training-Eventhash"));
    }
}
