/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import opennlp.tools.ml.maxent.quasinewton.NegLogLikelihood;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
import opennlp.tools.ml.model.RealValueFileEventStream;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link NegLogLikelihood} over the real-valued training file
 * {@code real-valued-weights-training-data.txt}.
 *
 * <p>Hard-coded parameter/gradient vectors in these tests are written in "sorted label" order:
 * outcome labels sorted lexicographically, and within each outcome the predicate labels sorted
 * lexicographically, laid out outcome-major ({@code outcome * numPreds + pred}). The indexer's
 * own label order is not assumed to be sorted, so {@link #alignDoubleArrayForTestData} and
 * {@link #dealignDoubleArrayForTestData} translate between the two orders.
 */
public class NegLogLikelihoodTest {
    private static final double TOLERANCE01 = 1.0E-6;
    private static final double TOLERANCE02 = 1.0E-10;
    /** Training events shared by every test in this class. */
    private static final String TRAINING_DATA =
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt";
    private DataIndexer testDataIndexer;

    @BeforeEach
    void initIndexer() {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", 1);
        this.testDataIndexer = new OnePassRealValueDataIndexer();
        this.testDataIndexer.init(trainingParameters, new HashMap<>());
    }

    @Test
    void testDomainDimensionSanity() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        int correctDomainDimension =
            this.testDataIndexer.getPredLabels().length * this.testDataIndexer.getOutcomeLabels().length;
        Assertions.assertEquals(correctDomainDimension, objectFunction.getDimension());
    }

    @Test
    void testInitialSanity() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        for (double anInitial : objectFunction.getInitialPoint()) {
            Assertions.assertEquals(0.0, anInitial, TOLERANCE01);
        }
    }

    @Test
    void testGradientSanity() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        double[] initial = objectFunction.getInitialPoint();
        double[] gradientAtInitial = objectFunction.gradientAt(initial);
        Assertions.assertNotNull(gradientAtInitial);
    }

    @Test
    void testValueAtInitialPoint() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        double value = objectFunction.valueAt(objectFunction.getInitialPoint());
        double expectedValue = 13.86294361;
        Assertions.assertEquals(expectedValue, value, TOLERANCE01);
    }

    @Test
    void testValueAtNonInitialPoint01() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        // Every component is identical, so label order is irrelevant and no dealignment is needed.
        double[] nonInitialPoint = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
        double value = objectFunction.valueAt(nonInitialPoint);
        double expectedValue = 13.862943611198894;
        Assertions.assertEquals(expectedValue, value, TOLERANCE01);
    }

    @Test
    void testValueAtNonInitialPoint02() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        double[] nonInitialPoint = {3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0};
        double value = objectFunction.valueAt(this.dealignDoubleArrayForTestData(
            nonInitialPoint, this.testDataIndexer.getPredLabels(), this.testDataIndexer.getOutcomeLabels()));
        double expectedValue = 53.163219721099026;
        Assertions.assertEquals(expectedValue, value, TOLERANCE02);
    }

    @Test
    void testGradientAtInitialPoint() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        double[] gradientAtInitialPoint = objectFunction.gradientAt(objectFunction.getInitialPoint());
        double[] expectedGradient = {-9.0, -14.0, -17.0, 20.0, 8.5, 9.0, 14.0, 17.0, -20.0, -8.5};
        Assertions.assertTrue(
            this.compareDoubleArray(expectedGradient, gradientAtInitialPoint, this.testDataIndexer, TOLERANCE01),
            () -> "expected (sorted order) " + Arrays.toString(expectedGradient)
                + " but was (indexer order) " + Arrays.toString(gradientAtInitialPoint));
    }

    @Test
    void testGradientAtNonInitialPoint() throws IOException {
        NegLogLikelihood objectFunction = indexTrainingData();
        double[] nonInitialPoint = {0.2, 0.5, 0.2, 0.5, 0.2, 0.5, 0.2, 0.5, 0.2, 0.5};
        double[] gradientAtNonInitialPoint = objectFunction.gradientAt(this.dealignDoubleArrayForTestData(
            nonInitialPoint, this.testDataIndexer.getPredLabels(), this.testDataIndexer.getOutcomeLabels()));
        double[] expectedGradient = {
            -12.755042847945553, -21.227127506102434, -72.57790706276435, 38.03525795198456, 15.348650889354925,
            12.755042847945557, 21.22712750610244, 72.57790706276438, -38.03525795198456, -15.348650889354925};
        Assertions.assertTrue(
            this.compareDoubleArray(expectedGradient, gradientAtNonInitialPoint, this.testDataIndexer, TOLERANCE01),
            () -> "expected (sorted order) " + Arrays.toString(expectedGradient)
                + " but was (indexer order) " + Arrays.toString(gradientAtNonInitialPoint));
    }

    /**
     * Indexes {@link #TRAINING_DATA} into {@link #testDataIndexer} and builds the objective
     * function from the indexer.
     *
     * @return a {@link NegLogLikelihood} over the freshly indexed data
     * @throws IOException if the training file cannot be read
     */
    private NegLogLikelihood indexTrainingData() throws IOException {
        // The stream was previously never closed, leaking a file handle per test. NegLogLikelihood
        // is built from the indexer rather than from the stream, so the stream is released as
        // soon as indexing returns.
        try (RealValueFileEventStream events =
                 new RealValueFileEventStream(TRAINING_DATA, StandardCharsets.UTF_8.name())) {
            this.testDataIndexer.index(events);
        }
        return new NegLogLikelihood(this.testDataIndexer);
    }

    /**
     * Reorders a vector from indexer label order into sorted label order.
     *
     * @param values vector in indexer order; must have at least {@code preds * outcomes} entries
     * @param predLabels predicate labels in indexer order
     * @param outcomeLabels outcome labels in indexer order
     * @return a new vector of length {@code preds * outcomes} in sorted order
     */
    private double[] alignDoubleArrayForTestData(double[] values, String[] predLabels, String[] outcomeLabels) {
        int[] positions = sortedToIndexerPositions(predLabels, outcomeLabels);
        double[] aligned = new double[positions.length];
        for (int k = 0; k < positions.length; ++k) {
            aligned[k] = values[positions[k]];
        }
        return aligned;
    }

    /**
     * Reorders a vector from sorted label order into indexer label order (inverse of
     * {@link #alignDoubleArrayForTestData}).
     *
     * @param values vector in sorted order; must have at least {@code preds * outcomes} entries
     * @param predLabels predicate labels in indexer order
     * @param outcomeLabels outcome labels in indexer order
     * @return a new vector of length {@code preds * outcomes} in indexer order
     */
    private double[] dealignDoubleArrayForTestData(double[] values, String[] predLabels, String[] outcomeLabels) {
        int[] positions = sortedToIndexerPositions(predLabels, outcomeLabels);
        double[] dealigned = new double[positions.length];
        for (int k = 0; k < positions.length; ++k) {
            dealigned[positions[k]] = values[k];
        }
        return dealigned;
    }

    /**
     * Builds the permutation shared by align/dealign: {@code result[k]} is the indexer-order
     * position of the parameter found at sorted-order position {@code k}. Both orders are
     * outcome-major ({@code outcome * numPreds + pred}).
     */
    private static int[] sortedToIndexerPositions(String[] predLabels, String[] outcomeLabels) {
        String[] sortedPreds = predLabels.clone();
        String[] sortedOutcomes = outcomeLabels.clone();
        Arrays.sort(sortedPreds);
        Arrays.sort(sortedOutcomes);
        Map<String, Integer> predIndex = invertIndex(predLabels);
        Map<String, Integer> outcomeIndex = invertIndex(outcomeLabels);
        int numPreds = predLabels.length;
        int[] positions = new int[numPreds * outcomeLabels.length];
        for (int i = 0; i < sortedOutcomes.length; ++i) {
            int outcomeBase = outcomeIndex.get(sortedOutcomes[i]) * numPreds;
            for (int j = 0; j < numPreds; ++j) {
                positions[i * numPreds + j] = outcomeBase + predIndex.get(sortedPreds[j]);
            }
        }
        return positions;
    }

    /** Maps each label to its index in {@code labels} (a later duplicate overwrites an earlier one). */
    private static Map<String, Integer> invertIndex(String[] labels) {
        Map<String, Integer> inverted = new HashMap<>();
        for (int i = 0; i < labels.length; ++i) {
            inverted.put(labels[i], i);
        }
        return inverted;
    }

    /**
     * Compares a sorted-order expected vector against an indexer-order actual vector.
     *
     * @param expected expected values in sorted label order
     * @param actual actual values in the indexer's label order
     * @param indexer supplies the predicate/outcome labels used to realign {@code actual}
     * @param tolerance maximum allowed absolute difference per element
     * @return {@code true} iff the lengths match and every element is within {@code tolerance};
     *     a NaN on either side is always a mismatch
     */
    private boolean compareDoubleArray(double[] expected, double[] actual, DataIndexer indexer, double tolerance) {
        double[] alignedActual = this.alignDoubleArrayForTestData(actual, indexer.getPredLabels(), indexer.getOutcomeLabels());
        if (expected.length != alignedActual.length) {
            return false;
        }
        for (int i = 0; i < alignedActual.length; ++i) {
            // Phrased as !(diff <= tol) rather than (diff > tol): any comparison with NaN is false,
            // so the old form let NaN gradients pass silently.
            if (!(StrictMath.abs(alignedActual[i] - expected[i]) <= tolerance)) {
                return false;
            }
        }
        return true;
    }
}

