package opennlp.tools.ml.maxent.quasinewton;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
import opennlp.tools.ml.model.RealValueFileEventStream;
import opennlp.tools.util.TrainingParameters;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:opennlp/tools/ml/maxent/quasinewton/NegLogLikelihoodTest.class */
/**
 * Tests for {@link NegLogLikelihood}: dimension, initial point, value and gradient, computed over
 * a small real-valued training file indexed by {@link OnePassRealValueDataIndexer}.
 *
 * <p>Expected parameter/gradient vectors in these tests are written in a "sorted" layout
 * (outcome labels and predicate labels each sorted alphabetically). The indexer may order labels
 * differently, so vectors are translated between the two layouts before comparison.
 */
public class NegLogLikelihoodTest {
    private static final double TOLERANCE01 = 1.0E-6d;
    private static final double TOLERANCE02 = 1.0E-10d;

    /** Training data shared by every test; the path is relative to the working directory. */
    private static final String TRAINING_DATA =
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt";
    private static final String TRAINING_DATA_ENCODING = "UTF-8";

    private DataIndexer testDataIndexer;

    @Before
    public void initIndexer() {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        this.testDataIndexer = new OnePassRealValueDataIndexer();
        this.testDataIndexer.init(trainingParameters, new HashMap<>());
    }

    @Test
    public void testDomainDimensionSanity() throws IOException {
        indexTrainingData();
        int expectedDimension =
            testDataIndexer.getPredLabels().length * testDataIndexer.getOutcomeLabels().length;
        Assert.assertEquals(expectedDimension, new NegLogLikelihood(testDataIndexer).getDimension());
    }

    @Test
    public void testInitialSanity() throws IOException {
        indexTrainingData();
        for (double d : new NegLogLikelihood(testDataIndexer).getInitialPoint()) {
            Assert.assertEquals(0.0d, d, TOLERANCE01);
        }
    }

    @Test
    public void testGradientSanity() throws IOException {
        indexTrainingData();
        NegLogLikelihood negLogLikelihood = new NegLogLikelihood(testDataIndexer);
        double[] gradient = negLogLikelihood.gradientAt(negLogLikelihood.getInitialPoint());
        Assert.assertNotNull(gradient);
        // A gradient must have one component per model parameter.
        Assert.assertEquals(negLogLikelihood.getDimension(), gradient.length);
    }

    @Test
    public void testValueAtInitialPoint() throws IOException {
        indexTrainingData();
        NegLogLikelihood negLogLikelihood = new NegLogLikelihood(testDataIndexer);
        Assert.assertEquals(13.86294361d,
            negLogLikelihood.valueAt(negLogLikelihood.getInitialPoint()), TOLERANCE01);
    }

    @Test
    public void testValueAtNonInitialPoint01() throws IOException {
        indexTrainingData();
        // A constant vector is layout-independent, so no dealignment is needed here.
        double[] point = {1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d};
        Assert.assertEquals(13.862943611198894d,
            new NegLogLikelihood(testDataIndexer).valueAt(point), TOLERANCE01);
    }

    @Test
    public void testValueAtNonInitialPoint02() throws IOException {
        indexTrainingData();
        double[] sortedPoint = {3.0d, 2.0d, 3.0d, 2.0d, 3.0d, 2.0d, 3.0d, 2.0d, 3.0d, 2.0d};
        double[] point = dealignDoubleArrayForTestData(sortedPoint,
            testDataIndexer.getPredLabels(), testDataIndexer.getOutcomeLabels());
        Assert.assertEquals(53.163219721099026d,
            new NegLogLikelihood(testDataIndexer).valueAt(point), TOLERANCE02);
    }

    @Test
    public void testGradientAtInitialPoint() throws IOException {
        indexTrainingData();
        NegLogLikelihood negLogLikelihood = new NegLogLikelihood(testDataIndexer);
        double[] expected = {-9.0d, -14.0d, -17.0d, 20.0d, 8.5d, 9.0d, 14.0d, 17.0d, -20.0d, -8.5d};
        assertAlignedArrayEquals(expected,
            negLogLikelihood.gradientAt(negLogLikelihood.getInitialPoint()), TOLERANCE01);
    }

    @Test
    public void testGradientAtNonInitialPoint() throws IOException {
        indexTrainingData();
        double[] sortedPoint = {0.2d, 0.5d, 0.2d, 0.5d, 0.2d, 0.5d, 0.2d, 0.5d, 0.2d, 0.5d};
        double[] point = dealignDoubleArrayForTestData(sortedPoint,
            testDataIndexer.getPredLabels(), testDataIndexer.getOutcomeLabels());
        double[] expected = {
            -12.755042847945553d, -21.227127506102434d, -72.57790706276435d,
            38.03525795198456d, 15.348650889354925d, 12.755042847945557d,
            21.22712750610244d, 72.57790706276438d, -38.03525795198456d,
            -15.348650889354925d};
        assertAlignedArrayEquals(expected,
            new NegLogLikelihood(testDataIndexer).gradientAt(point), TOLERANCE01);
    }

    /**
     * Indexes the shared training file into {@code testDataIndexer} and closes the event stream
     * once indexing is done.
     *
     * @throws IOException if the training file cannot be opened or read
     */
    private void indexTrainingData() throws IOException {
        try (RealValueFileEventStream events =
                 new RealValueFileEventStream(TRAINING_DATA, TRAINING_DATA_ENCODING)) {
            testDataIndexer.index(events);
        }
    }

    /**
     * Asserts that {@code actual} (in the indexer's layout), once converted to the sorted layout,
     * equals {@code expected} element-wise within {@code tolerance}. On failure, the message
     * reports the first differing index.
     */
    private void assertAlignedArrayEquals(double[] expected, double[] actual, double tolerance) {
        double[] aligned = alignDoubleArrayForTestData(actual,
            testDataIndexer.getPredLabels(), testDataIndexer.getOutcomeLabels());
        Assert.assertArrayEquals(expected, aligned, tolerance);
    }

    /**
     * Converts a vector from the indexer's layout to the sorted-label layout.
     *
     * @param dArr vector in the indexer's layout
     * @param strArr predicate labels in indexer order
     * @param strArr2 outcome labels in indexer order
     * @return a new vector in the sorted-label layout
     */
    private static double[] alignDoubleArrayForTestData(double[] dArr, String[] strArr,
                                                        String[] strArr2) {
        int[] positions = sortedToIndexerPositions(strArr, strArr2);
        double[] aligned = new double[positions.length];
        for (int i = 0; i < positions.length; i++) {
            aligned[i] = dArr[positions[i]];
        }
        return aligned;
    }

    /**
     * Converts a vector from the sorted-label layout to the indexer's layout. This is the inverse
     * of {@link #alignDoubleArrayForTestData}.
     *
     * @param dArr vector in the sorted-label layout
     * @param strArr predicate labels in indexer order
     * @param strArr2 outcome labels in indexer order
     * @return a new vector in the indexer's layout
     */
    private static double[] dealignDoubleArrayForTestData(double[] dArr, String[] strArr,
                                                          String[] strArr2) {
        int[] positions = sortedToIndexerPositions(strArr, strArr2);
        double[] dealigned = new double[positions.length];
        for (int i = 0; i < positions.length; i++) {
            dealigned[positions[i]] = dArr[i];
        }
        return dealigned;
    }

    /**
     * Builds the permutation between the two layouts. Element {@code i} is the indexer-layout
     * position that corresponds to sorted-layout position {@code i}. Both layouts place an entry
     * at {@code outcome * predLabels.length + pred}.
     */
    private static int[] sortedToIndexerPositions(String[] predLabels, String[] outcomeLabels) {
        Map<String, Integer> predPosition = labelPositions(predLabels);
        Map<String, Integer> outcomePosition = labelPositions(outcomeLabels);
        String[] sortedPreds = predLabels.clone();
        String[] sortedOutcomes = outcomeLabels.clone();
        Arrays.sort(sortedPreds);
        Arrays.sort(sortedOutcomes);

        int numPreds = sortedPreds.length;
        int[] positions = new int[numPreds * outcomeLabels.length];
        for (int o = 0; o < sortedOutcomes.length; o++) {
            int indexerOutcome = outcomePosition.get(sortedOutcomes[o]);
            for (int p = 0; p < numPreds; p++) {
                positions[o * numPreds + p] = indexerOutcome * numPreds + predPosition.get(sortedPreds[p]);
            }
        }
        return positions;
    }

    /** Maps each label to its index in {@code labels}; the last occurrence wins on duplicates. */
    private static Map<String, Integer> labelPositions(String[] labels) {
        Map<String, Integer> positions = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            positions.put(labels[i], i);
        }
        return positions;
    }
}
