/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.matrix;

import com.aliasi.matrix.SvdMatrix;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link SvdMatrix}.
 *
 * <p>Runs {@link SvdMatrix#partialSvd} on small matrices and checks structural properties of the
 * resulting factorization (ordering of singular values, orthonormality of singular vectors) as well
 * as how closely the supplied entries are reconstructed.
 */
public class SvdMatrixTest {
    /** Shared (unseeded) source of randomness for generated test data. */
    static final Random RANDOM = new Random();

    /** Row count of the dense random matrix built in {@link #testFull()}. */
    static final int M = 5;
    /** Column count of the dense random matrix built in {@link #testFull()}. */
    static final int N = 10;
    /** Row count of the sparse random matrix built in {@link #testPartial()}. */
    static final int M2 = 1000;
    /** Exclusive upper bound on column ids generated in {@link #testPartial()}. */
    static final int N2 = 500;
    /** Exclusive upper bound on each random gap between column ids in {@link #testPartial()}. */
    static final int MAX_INCR2 = 100;

    /** Checks that a fixed, fully populated 4x3 matrix is factored accurately at order 3. */
    @Test
    public void testFixed() {
        double[][] values = {
            {5.0, 9.0, 2.0},
            {3.0, -4.0, 5.0},
            {2.0, 5.0, 1.0},
            {-8.0, 3.0, 3.0}
        };
        int m = values.length;
        int n = values[0].length;
        assertConverge(m, n, denseColumnIds(m, n), values, 3, 0.001);
    }

    /**
     * Builds a dense {@code M x N} matrix of uniform random values in [1, 5).
     *
     * <p>NOTE(review): the inputs are constructed but {@link #assertConverge} is never invoked, so
     * this test currently asserts nothing. Wire in a convergence check once an order and tolerance
     * that reliably pass are known.
     */
    @Test
    public void testFull() {
        int[][] columnIds = denseColumnIds(M, N);
        double[][] values = new double[M][N];
        for (int i = 0; i < M; ++i) {
            for (int j = 0; j < N; ++j) {
                values[i][j] = random(1.0, 5.0);
            }
        }
    }

    /**
     * Builds a sparse {@code M2}-row matrix whose rows hold strictly increasing random column ids
     * below {@code N2}, each paired with a uniform random value in [1, 5).
     *
     * <p>NOTE(review): as with {@link #testFull()}, no convergence check is run on the generated
     * data, so this test currently asserts nothing.
     */
    @Test
    public void testPartial() {
        int[][] columnIds = new int[M2][];
        double[][] values = new double[M2][];
        for (int i = 0; i < M2; ++i) {
            int[] columnIdsForRowBuf = new int[N2];
            int pos = 0;
            int j = 0;
            while (true) {
                int incr = RANDOM.nextInt(MAX_INCR2);
                // A zero gap is allowed only before the first id is recorded (so column 0 can
                // appear). After that the gap is forced to be at least 1 so ids stay strictly
                // increasing. Keying this on j != 0 would let a second 0 through after a first 0.
                if (incr == 0 && pos > 0) {
                    incr = 1;
                }
                j += incr;
                if (j >= N2) {
                    break;
                }
                columnIdsForRowBuf[pos++] = j;
            }
            columnIds[i] = new int[pos];
            System.arraycopy(columnIdsForRowBuf, 0, columnIds[i], 0, pos);
            values[i] = new double[pos];
            for (int k = 0; k < pos; ++k) {
                values[i][k] = random(1.0, 5.0);
            }
        }
    }

    /**
     * Factors the given sparse matrix with {@link SvdMatrix#partialSvd} and asserts the result is
     * well formed and accurate.
     *
     * <p>Asserts that the first singular value is non-negative and the singular values are
     * non-increasing. It also asserts that the left and right singular vector matrices have
     * orthonormal columns, and that every supplied entry is reproduced by
     * {@link SvdMatrix#value(int, int)} to within {@code tolerance}.
     *
     * @param numRows number of matrix rows (not currently used by the checks)
     * @param numCols number of matrix columns (not currently used by the checks)
     * @param columnIds {@code columnIds[i][k]} is the column of the k-th supplied entry of row i
     * @param values {@code values[i][k]} is the value of that entry
     * @param maxOrder passed as the maximum order to {@code partialSvd}
     * @param tolerance maximum allowed absolute error when reconstructing each supplied entry
     */
    void assertConverge(int numRows, int numCols, int[][] columnIds, double[][] values,
                        int maxOrder, double tolerance) {
        double featureInit = 0.1;
        double initialLearningRate = 0.001;
        double annealingRate = 100000.0;
        double regularization = 0.0;
        double minImprovement = 0.0;
        int minEpochs = 1000;
        int maxEpochs = 1000000;
        // NOTE(review): the null argument is presumably an optional progress reporter — confirm
        // against the partialSvd signature.
        SvdMatrix matrix = SvdMatrix.partialSvd(columnIds, values, maxOrder, featureInit,
                initialLearningRate, annealingRate, regularization, null, minImprovement,
                minEpochs, maxEpochs);

        double[] singularValues = matrix.singularValues();
        Assert.assertTrue("first singular value is negative", singularValues[0] >= 0.0);
        for (int k = 1; k < singularValues.length; ++k) {
            Assert.assertTrue("singular values increase at index " + k,
                    singularValues[k] <= singularValues[k - 1]);
        }

        assertOrthonormal(matrix.leftSingularVectors());
        assertOrthonormal(matrix.rightSingularVectors());

        for (int row = 0; row < columnIds.length; ++row) {
            for (int k = 0; k < columnIds[row].length; ++k) {
                int column = columnIds[row][k];
                Assert.assertEquals("entry (" + row + "," + column + ")",
                        values[row][k], matrix.value(row, column), tolerance);
            }
        }
    }

    /** Asserts every column of {@code xs} has unit length and all column pairs are orthogonal. */
    void assertOrthonormal(double[][] xs) {
        int numCols = xs[0].length;
        for (int j = 0; j < numCols; ++j) {
            assertUnitLengthColumn(xs, j);
            for (int k = j + 1; k < numCols; ++k) {
                assertOrthogonalColumns("col=" + j + " col2=" + k, xs, j, k);
            }
        }
    }

    /** Asserts column {@code j} of {@code xs} has squared norm 1 (within 0.01). */
    void assertUnitLengthColumn(double[][] xs, int j) {
        double sum = 0.0;
        for (double[] row : xs) {
            sum += row[j] * row[j];
        }
        Assert.assertEquals("unit columns", 1.0, sum, 0.01);
    }

    /** Asserts columns {@code i} and {@code j} of {@code xs} have dot product 0 (within 0.01). */
    void assertOrthogonalColumns(String msg, double[][] xs, int i, int j) {
        double sum = 0.0;
        for (double[] row : xs) {
            sum += row[i] * row[j];
        }
        Assert.assertEquals("ortho columns " + msg, 0.0, sum, 0.01);
    }

    /** Asserts {@code xs} has squared norm 1 (within 0.01). */
    void assertUnitLength(double[] xs) {
        assertProduct(xs, xs, 1.0);
    }

    /** Asserts {@code xs} and {@code ys} have dot product 0 (within 0.01). */
    void assertOrthogonal(String msg, double[] xs, double[] ys) {
        assertProduct(msg, xs, ys, 0.0);
    }

    /** Asserts the dot product of {@code xs} and {@code ys} equals {@code expected} (within 0.01). */
    void assertProduct(double[] xs, double[] ys, double expected) {
        assertProduct("", xs, ys, expected);
    }

    /**
     * Asserts the dot product of {@code xs} and {@code ys} equals {@code expected} within 0.01.
     * Iterates over the indices of {@code xs}, so {@code ys} must be at least as long.
     */
    void assertProduct(String msg, double[] xs, double[] ys, double expected) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            sum += xs[i] * ys[i];
        }
        Assert.assertEquals(msg, expected, sum, 0.01);
    }

    /** Returns a uniform random value in {@code [min, max)} drawn from {@link #RANDOM}. */
    double random(double min, double max) {
        return min + (max - min) * RANDOM.nextDouble();
    }

    /** Prints {@code xs} to standard output, one comma-separated row per line. */
    void printMatrix(double[][] xs) {
        for (double[] row : xs) {
            for (int j = 0; j < row.length; ++j) {
                if (j > 0) {
                    System.out.print(", ");
                }
                printNumber(row[j]);
            }
            System.out.println();
        }
    }

    /** Prints {@code x} to standard output in a fixed-width, 3-decimal format. */
    void printNumber(double x) {
        System.out.printf("% 7.3f", x);
    }

    /** Returns a {@code numRows x numCols} column-id matrix whose every row is 0..numCols-1. */
    private static int[][] denseColumnIds(int numRows, int numCols) {
        int[][] ids = new int[numRows][numCols];
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numCols; ++j) {
                ids[i][j] = j;
            }
        }
        return ids;
    }
}

