/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.stats;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.Reporter;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.util.AbstractExternalizable;
import java.io.IOException;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link LogisticRegression}.
 *
 * <p>{@link #testClass()} checks {@link LogisticRegression#classify} against a
 * hand-computed softmax. {@link #testEstimation()} fits a model to the small
 * three-outcome "wallet" data set below. It does so twice, first with dense and then
 * with sparse input vectors. Each fit then goes through a compile (serialization)
 * round trip and a hot-started re-estimation.
 */
public class LogisticRegressionTest {

    /**
     * Outcome index (0, 1 or 2) for each input row, passed to
     * {@link LogisticRegression#estimate} together with {@link #WALLET_DATA_MATRIX}.
     * There are ten entries per line; the leading comment is the index of the line's
     * first entry.
     */
    static final int[] WALLET_OUTCOME_VECTOR = {
        /*   0 */ 1, 1, 2, 2, 0, 2, 2, 2, 2, 2,
        /*  10 */ 1, 2, 2, 2, 2, 2, 2, 1, 0, 1,
        /*  20 */ 1, 2, 2, 2, 2, 1, 1, 0, 2, 2,
        /*  30 */ 2, 2, 0, 2, 2, 2, 2, 2, 2, 2,
        /*  40 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        /*  50 */ 1, 2, 2, 2, 2, 2, 2, 1, 2, 2,
        /*  60 */ 2, 2, 2, 0, 2, 2, 0, 2, 1, 0,
        /*  70 */ 0, 2, 2, 1, 1, 1, 2, 2, 2, 2,
        /*  80 */ 2, 2, 2, 2, 1, 2, 2, 1, 2, 2,
        /*  90 */ 2, 2, 2, 2, 2, 0, 0, 1, 0, 1,
        /* 100 */ 0, 1, 0, 2, 2, 1, 2, 0, 2, 1,
        /* 110 */ 2, 2, 1, 2, 2, 0, 1, 1, 0, 0,
        /* 120 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 1,
        /* 130 */ 1, 2, 1, 2, 1, 2, 2, 0, 2, 2,
        /* 140 */ 2, 2, 1, 2, 1, 2, 1, 2, 2, 2,
        /* 150 */ 2, 1, 2, 2, 1, 2, 2, 1, 2, 1,
        /* 160 */ 2, 0, 2, 1, 0, 1, 2, 1, 2, 1,
        /* 170 */ 1, 0, 1, 1, 0, 1, 1, 2, 2, 1,
        /* 180 */ 0, 1, 2, 1, 2, 0, 1, 2, 1, 2,
        /* 190 */ 2, 2, 2, 2, 1
    };

    /**
     * Input rows for the wallet data. They are passed to {@link LogisticRegression#estimate}
     * together with {@link #WALLET_OUTCOME_VECTOR}, presumably paired by row index.
     * Every row begins with 1.0. NOTE(review): this is presumably a constant intercept
     * feature; confirm.
     */
    static final double[][] WALLET_DATA_MATRIX = new double[][]{{1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 1.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 3.0, 1.0}, {1.0, 1.0, 0.0, 3.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0,
            0.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 3.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 0.0, 1.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0,
            0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 2.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 3.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 3.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 2.0, 0.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 3.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 0.0}, {1.0, 0.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 2.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 3.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 3.0, 1.0}, {1.0, 1.0, 0.0, 2.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 3.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0,  1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}};

    /**
     * Expected estimated weight vectors, one row per outcome. Each coefficient is checked
     * to within {@link #WEIGHT_TOLERANCE}. The last row is all zeros. NOTE(review): this
     * presumably corresponds to the reference outcome. Only as many rows are read as
     * {@link LogisticRegression#weightVectors()} returns.
     */
    static final double[][] WALLET_EXPECTED_FEATURES = {
        {-3.4712, 1.2673, 1.1804, 1.0817, -1.6006},
        {-1.2917, 1.1699, 0.4179, 0.1957, -0.804},
        {0.0, 0.0, 0.0, 0.0, 0.0}
    };

    /** Absolute tolerance when comparing estimated weights with {@link #WALLET_EXPECTED_FEATURES}. */
    private static final double WEIGHT_TOLERANCE = 0.1;

    /**
     * Checks {@code classify} on a two-weight-vector, three-outcome model against a
     * softmax computed by hand from the linear scores.
     */
    @Test
    public void testClass() {
        Vector[] weightVectors = {
            new DenseVector(new double[]{1.0, 2.0, 3.0}),
            new DenseVector(new double[]{-2.0, 1.0, -1.0})
        };
        LogisticRegression regression = new LogisticRegression(weightVectors);
        DenseVector testCase = new DenseVector(new double[]{1.0, -1.0, 2.0});

        // Linear scores: w0 . x = 1 - 2 + 6 = 5 and w1 . x = -2 - 1 - 2 = -5.
        // The test expects the third outcome, which has no weight vector, to score 0.
        double[] scores = {5.0, -5.0, 0.0};
        double[] unnormalized = new double[scores.length];
        for (int k = 0; k < scores.length; ++k) {
            unnormalized[k] = Math.exp(scores[k]);
        }
        Assert.assertEquals(1.0, unnormalized[2], 1.0E-4);

        double total = unnormalized[0] + unnormalized[1] + unnormalized[2];
        double[] expected = new double[unnormalized.length];
        for (int k = 0; k < unnormalized.length; ++k) {
            expected[k] = unnormalized[k] / total;
        }

        double[] estimated = regression.classify(testCase);
        Assert.assertEquals(expected.length, estimated.length);
        for (int k = 0; k < expected.length; ++k) {
            Assert.assertEquals(expected[k], estimated[k], 1.0E-7);
        }
    }

    /**
     * Returns a copy of {@code matrix} in which every vector has been converted with
     * {@link #sparseCopy(Vector)}.
     *
     * @param matrix vectors to copy
     * @return a new array of sparse copies, in the same order
     */
    static Vector[] sparseCopy(Vector[] matrix) {
        Vector[] result = new Vector[matrix.length];
        for (int i = 0; i < matrix.length; ++i) {
            result[i] = LogisticRegressionTest.sparseCopy(matrix[i]);
        }
        return result;
    }

    /**
     * Copies {@code v} into a {@link SparseFloatVector}. Every dimension is stored
     * explicitly, zeros included, and each value is narrowed to {@code float}.
     *
     * @param v vector to copy
     * @return a sparse vector with the same number of dimensions and (float) values
     */
    static Vector sparseCopy(Vector v) {
        int numDimensions = v.numDimensions();
        int[] dims = new int[numDimensions];
        float[] vals = new float[numDimensions];
        for (int d = 0; d < numDimensions; ++d) {
            dims[d] = d;
            vals[d] = (float) v.value(d);
        }
        return new SparseFloatVector(dims, vals, numDimensions);
    }

    /**
     * Runs {@link #assertCorrectRegression} on the wallet data, first with dense vectors
     * and then with sparse copies of them.
     */
    @Test
    public void testEstimation() throws IOException, ClassNotFoundException {
        Vector[] denseData = new Vector[WALLET_DATA_MATRIX.length];
        for (int i = 0; i < denseData.length; ++i) {
            denseData[i] = new DenseVector(WALLET_DATA_MATRIX[i]);
        }
        Vector[] sparseData = LogisticRegressionTest.sparseCopy(denseData);
        this.assertCorrectRegression(denseData);
        this.assertCorrectRegression(sparseData);
    }

    /**
     * Fits a model to {@code dataMatrix} and {@link #WALLET_OUTCOME_VECTOR}, then checks
     * three things:
     * <ul>
     *   <li>the weights match {@link #WALLET_EXPECTED_FEATURES};</li>
     *   <li>a compiled (serialized) copy reproduces the model;</li>
     *   <li>refitting, hot-started from the first model, also matches.</li>
     * </ul>
     *
     * @param dataMatrix input vectors, passed to {@link LogisticRegression#estimate}
     * @throws IOException if compiling the model fails
     * @throws ClassNotFoundException if compiling the model fails
     */
    void assertCorrectRegression(Vector[] dataMatrix) throws IOException, ClassNotFoundException {
        Reporter reporter = null;
        ObjectHandler<LogisticRegression> handler = null;

        // Cold start: no hot-start model, prior block size 3.
        LogisticRegression noHotStart = null;
        LogisticRegression regression = LogisticRegression.estimate(
            dataMatrix, WALLET_OUTCOME_VECTOR, RegressionPrior.noninformative(),
            3, noHotStart, AnnealingSchedule.inverse(0.05, 100.0),
            1.0E-5, 5, 10, 500000, handler, reporter);
        assertMatchesExpectedWeights(regression.weightVectors());

        assertCompiledCopyMatches(regression);

        // Hot start from the first model, with prior block size 2 and 1.0E-7 instead of
        // 1.0E-5 as the seventh argument. NOTE(review): that argument is presumably the
        // minimum-improvement convergence threshold; confirm against estimate's docs.
        LogisticRegression hotStarted = LogisticRegression.estimate(
            dataMatrix, WALLET_OUTCOME_VECTOR, RegressionPrior.noninformative(),
            2, regression, AnnealingSchedule.inverse(0.05, 100.0),
            1.0E-7, 5, 10, 500000, handler, reporter);
        assertMatchesExpectedWeights(hotStarted.weightVectors());
    }

    /**
     * Asserts that each coefficient of each weight vector is within
     * {@link #WEIGHT_TOLERANCE} of the corresponding entry of
     * {@link #WALLET_EXPECTED_FEATURES}.
     */
    private static void assertMatchesExpectedWeights(Vector[] weightVectors) {
        for (int i = 0; i < weightVectors.length; ++i) {
            for (int j = 0; j < weightVectors[i].numDimensions(); ++j) {
                Assert.assertEquals(WALLET_EXPECTED_FEATURES[i][j], weightVectors[i].value(j), WEIGHT_TOLERANCE);
            }
        }
    }

    /**
     * Compiles {@code regression} with {@link AbstractExternalizable#compile} and asserts
     * that the copy has the same number of outcomes and input dimensions, and equal
     * weight vectors.
     */
    private static void assertCompiledCopyMatches(LogisticRegression regression)
        throws IOException, ClassNotFoundException {
        LogisticRegression compiled = (LogisticRegression) AbstractExternalizable.compile(regression);
        Assert.assertEquals(regression.numOutcomes(), compiled.numOutcomes());
        // Compare against the compiled copy; comparing the model with itself can never fail.
        Assert.assertEquals(regression.numInputDimensions(), compiled.numInputDimensions());

        Vector[] originalWeights = regression.weightVectors();
        Vector[] compiledWeights = compiled.weightVectors();
        Assert.assertEquals(originalWeights.length, compiledWeights.length);
        for (int i = 0; i < originalWeights.length; ++i) {
            Assert.assertEquals((Object) originalWeights[i], (Object) compiledWeights[i]);
        }
    }
}

