/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.matrix;

import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.AbstractMatrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
 * A matrix stored in factored form as row vectors and column vectors that all
 * share a common length, the order. The value at {@code (row, column)} is the
 * dot product of row vector {@code row} and column vector {@code column}.
 *
 * <p>Instances can be built directly from factor vectors, from singular
 * vectors plus singular values, or estimated from complete or partial data
 * with {@link #svd svd} and {@link #partialSvd partialSvd}.
 */
public class SvdMatrix
extends AbstractMatrix {
    /** Row factor vectors; each has length {@code mOrder}. */
    private final double[][] mRowVectors;
    /** Column factor vectors; each has length {@code mOrder}. */
    private final double[][] mColumnVectors;
    /** Number of factors, i.e. the common length of every row and column vector. */
    private final int mOrder;
    /** Shared zero-length 2D array. */
    static final double[][] EMPTY_DOUBLE_2D_ARRAY = new double[0][];

    /**
     * Constructs a factored matrix from row and column factor vectors.
     * The arrays are stored as given, not copied.
     *
     * @param rowVectors one vector per row, each of length {@code order}
     * @param columnVectors one vector per column, each of length {@code order}
     * @param order the number of factors
     * @throws IllegalArgumentException if any vector's length differs from {@code order}
     */
    public SvdMatrix(double[][] rowVectors, double[][] columnVectors, int order) {
        verifyDimensions("row", order, rowVectors);
        verifyDimensions("column", order, columnVectors);
        this.mRowVectors = rowVectors;
        this.mColumnVectors = columnVectors;
        this.mOrder = order;
    }

    /**
     * Constructs a factored matrix from left/right singular vectors and
     * singular values. Both sets of vectors are copied and scaled by the
     * square roots of the singular values, so that the dot products reproduce
     * {@code sum_k u[i][k] * s[k] * v[j][k]}.
     *
     * @param rowSingularVectors one vector per row, each of length {@code singularValues.length}
     * @param columnSingularVectors one vector per column, each of length {@code singularValues.length}
     * @param singularValues the singular values; the order is their count
     * @throws IllegalArgumentException if a vector has the wrong length, or a
     *     singular value is negative or NaN (its square root would be NaN)
     */
    public SvdMatrix(double[][] rowSingularVectors, double[][] columnSingularVectors, double[] singularValues) {
        this.mOrder = singularValues.length;
        verifyDimensions("row", this.mOrder, rowSingularVectors);
        verifyDimensions("column", this.mOrder, columnSingularVectors);
        double[] sqrtSingularValues = new double[this.mOrder];
        for (int k = 0; k < sqrtSingularValues.length; ++k) {
            double singularValue = singularValues[k];
            // Written as !(x >= 0) so that NaN is rejected as well.
            if (!(singularValue >= 0.0)) {
                String msg = "Singular values must be non-negative. Found singularValues[" + k + "]=" + singularValue;
                throw new IllegalArgumentException(msg);
            }
            sqrtSingularValues[k] = Math.sqrt(singularValue);
        }
        this.mRowVectors = new double[rowSingularVectors.length][this.mOrder];
        this.mColumnVectors = new double[columnSingularVectors.length][this.mOrder];
        scale(this.mRowVectors, rowSingularVectors, sqrtSingularValues);
        scale(this.mColumnVectors, columnSingularVectors, sqrtSingularValues);
    }

    @Override
    public int numRows() {
        return this.mRowVectors.length;
    }

    @Override
    public int numColumns() {
        return this.mColumnVectors.length;
    }

    /**
     * Returns the number of factors in this matrix.
     * This also works for a matrix with no rows.
     *
     * @return the order
     */
    public int order() {
        return this.mOrder;
    }

    /**
     * Returns the dot product of the row vector for {@code row} and the
     * column vector for {@code column}.
     */
    @Override
    public double value(int row, int column) {
        double[] rowVec = this.mRowVectors[row];
        double[] colVec = this.mColumnVectors[column];
        double result = 0.0;
        for (int k = 0; k < rowVec.length; ++k) {
            result += rowVec[k] * colVec[k];
        }
        return result;
    }

    /**
     * Returns a fresh array holding {@link #singularValue(int)} for every
     * order from {@code 0} to {@code order() - 1}.
     *
     * @return the singular values
     */
    public double[] singularValues() {
        double[] singularValues = new double[this.mOrder];
        for (int k = 0; k < singularValues.length; ++k) {
            singularValues[k] = this.singularValue(k);
        }
        return singularValues;
    }

    /**
     * Returns the singular value for the given factor. It is computed as the
     * Euclidean length of that factor's row-vector column times the length of
     * its column-vector column.
     *
     * @param order the factor index, in {@code [0, order())}
     * @return the singular value for the factor
     * @throws IllegalArgumentException if {@code order} is negative or not less than {@link #order()}
     */
    public double singularValue(int order) {
        if (order < 0 || order >= this.mOrder) {
            String msg = "Maximum order=" + (this.mOrder - 1) + " found order=" + order;
            throw new IllegalArgumentException(msg);
        }
        return columnLength(this.mRowVectors, order) * columnLength(this.mColumnVectors, order);
    }

    /**
     * Returns a copy of the row vectors with each factor column scaled to unit length.
     *
     * @return the left singular vectors, one per row
     */
    public double[][] leftSingularVectors() {
        return normalizeColumns(this.mRowVectors);
    }

    /**
     * Returns a copy of the column vectors with each factor column scaled to unit length.
     *
     * @return the right singular vectors, one per column
     */
    public double[][] rightSingularVectors() {
        return normalizeColumns(this.mColumnVectors);
    }

    /**
     * Estimates a factored approximation of a dense matrix. Every entry of
     * {@code values} is treated as observed, and the work is delegated to
     * {@link #partialSvd(int[][], double[][], int, double, double, double, double, Reporter, double, int, int)}.
     *
     * @param values the dense matrix; all rows must be non-null and of equal length
     * @param reporter progress/error reporter; a silent reporter is used if {@code null}
     * @throws IllegalArgumentException if {@code values} is null, has a null row,
     *     is ragged, or any training parameter is invalid
     */
    public static SvdMatrix svd(double[][] values, int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        if (values == null) {
            throw fatal(reporter, "Values must not be null");
        }
        int m = values.length;
        for (int i = 0; i < m; ++i) {
            if (values[i] == null) {
                throw fatal(reporter, "All values must be non-null. Found null row=" + i);
            }
        }
        // An empty matrix has no first row to take the width from.
        int n = m == 0 ? 0 : values[0].length;
        reporter.info("Calculating SVD");
        reporter.info("#Rows=" + m + " #Cols=" + n);
        for (int i = 1; i < m; ++i) {
            if (values[i].length != n) {
                throw fatal(reporter, "All rows must be of same length. Found row[0].length=" + n + " row[" + i + "]=" + values[i].length);
            }
        }
        // Every row observes every column, so all rows can share one id array.
        int[] sharedRow = new int[n];
        for (int j = 0; j < n; ++j) {
            sharedRow[j] = j;
        }
        int[][] columnIds = new int[m][];
        Arrays.fill(columnIds, sharedRow);
        return partialSvd(columnIds, values, maxOrder, featureInit, initialLearningRate, annealingRate, regularization, reporter, minImprovement, minEpochs, maxEpochs);
    }

    /**
     * Estimates a factored approximation of a sparse matrix using a fresh
     * {@link Random}. See the package-private overload for details.
     */
    public static SvdMatrix partialSvd(int[][] columnIds, double[][] values, int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        return partialSvd(columnIds, values, maxOrder, featureInit, initialLearningRate, annealingRate, regularization, new Random(), reporter, minImprovement, minEpochs, maxEpochs);
    }

    /**
     * Estimates a factored approximation of a sparse matrix by stochastic
     * gradient descent, fitting one factor at a time.
     *
     * <p>Row {@code r} has observed entries {@code values[r][i]} at columns
     * {@code columnIds[r][i]}. The number of columns is one more than the
     * largest column id, or zero if there are no entries. The order actually
     * fitted is {@code maxOrder} capped by the number of rows and columns.
     *
     * <p>Each factor's row and column vectors start as Gaussian samples scaled
     * by {@code featureInit}. They are then refined for up to {@code maxEpochs}
     * passes over the entries. Pass {@code epoch} uses learning rate
     * {@code initialLearningRate / (1 + epoch / annealingRate)} and shrinkage
     * weight {@code regularization}. Predictions include the contributions of
     * previously fitted factors. A factor stops early once
     * {@code epoch >= minEpochs} and the relative RMSE change from the previous
     * pass is below {@code minImprovement}.
     *
     * @param columnIds per-row column ids: non-negative and strictly increasing within each row
     * @param values per-row observed values, parallel to {@code columnIds}
     * @param maxOrder maximum number of factors; must be at least 1
     * @param featureInit scale of the initial random values; finite and non-zero
     * @param initialLearningRate learning rate at epoch 0; finite and non-negative
     * @param annealingRate learning-rate decay parameter; finite and positive
     * @param regularization shrinkage weight; finite and non-negative
     * @param random source of initial values
     * @param reporter progress/error reporter; a silent reporter is used if {@code null}
     * @param minImprovement convergence threshold on relative RMSE change; finite and non-negative
     * @param minEpochs minimum passes per factor before convergence is tested; positive
     * @param maxEpochs maximum passes per factor; at least {@code minEpochs}
     * @return the fitted factored matrix
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    static SvdMatrix partialSvd(int[][] columnIds, double[][] values, int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, Random random, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("Start");
        checkHyperparameters(maxOrder, featureInit, initialLearningRate, annealingRate, regularization, reporter, minImprovement, minEpochs, maxEpochs);
        checkSparseData(columnIds, values, reporter);

        int numRows = columnIds.length;
        int numEntries = 0;
        int maxColumnIndex = -1;
        for (int[] idsForRow : columnIds) {
            numEntries += idsForRow.length;
            if (idsForRow.length > 0) {
                // Ids are validated as strictly increasing, so the last one is the row's maximum.
                maxColumnIndex = Math.max(maxColumnIndex, idsForRow[idsForRow.length - 1]);
            }
        }
        // With no entries there are no columns, so the order below is 0
        // and the RMSE computation never divides by zero entries.
        int numColumns = maxColumnIndex + 1;
        int order = Math.min(maxOrder, Math.min(numRows, numColumns));

        // cache[r][i] holds the prediction for entry (r, columnIds[r][i]) from all factors fitted so far.
        double[][] cache = new double[numRows][];
        for (int row = 0; row < numRows; ++row) {
            cache[row] = new double[values[row].length];
        }
        List<double[]> rowFactors = new ArrayList<double[]>(order);
        List<double[]> columnFactors = new ArrayList<double[]>(order);
        for (int factor = 0; factor < order; ++factor) {
            reporter.info("  Factor=" + factor);
            double[] rowVector = initArray(numRows, featureInit, random);
            double[] columnVector = initArray(numColumns, featureInit, random);
            double rmseLast = Double.POSITIVE_INFINITY;
            for (int epoch = 0; epoch < maxEpochs; ++epoch) {
                double learningRate = initialLearningRate / (1.0 + epoch / annealingRate);
                double sumOfSquareErrors = sgdEpoch(columnIds, values, cache, rowVector, columnVector, learningRate, regularization);
                double rmse = Math.sqrt(sumOfSquareErrors / numEntries);
                reporter.info("    epoch=" + epoch + " rmse=" + rmse);
                double relDiff = relativeDifference(rmse, rmseLast);
                // Record before a possible break so the summary below reports the final RMSE.
                rmseLast = rmse;
                if (epoch >= minEpochs && relDiff < minImprovement) {
                    reporter.info("Converged in epoch=" + epoch + " rmse=" + rmse + " relDiff=" + relDiff);
                    break;
                }
            }
            reporter.info("Order=" + factor + " RMSE=" + rmseLast);
            rowFactors.add(rowVector);
            columnFactors.add(columnVector);
            addFactorToCache(cache, columnIds, rowVector, columnVector);
        }
        return new SvdMatrix(factorsToMatrix(rowFactors, numRows), factorsToMatrix(columnFactors, numColumns), order);
    }

    /** Validates the scalar training parameters of {@code partialSvd}, reporting and throwing on the first violation. */
    private static void checkHyperparameters(int maxOrder, double featureInit, double initialLearningRate, double annealingRate, double regularization, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        if (maxOrder < 1) {
            throw fatal(reporter, "Max order must be >= 1. Found maxOrder=" + maxOrder);
        }
        if (minImprovement < 0.0 || notFinite(minImprovement)) {
            throw fatal(reporter, "Min improvement must be finite and non-negative. Found minImprovement=" + minImprovement);
        }
        if (minEpochs <= 0 || maxEpochs < minEpochs) {
            throw fatal(reporter, "Min epochs must be positive and less than or equal to max epochs. found minEpochs=" + minEpochs + " maxEpochs=" + maxEpochs);
        }
        if (notFinite(featureInit) || featureInit == 0.0) {
            throw fatal(reporter, "Feature inits must be finite and non-zero. Found featureInit=" + featureInit);
        }
        if (notFinite(initialLearningRate) || initialLearningRate < 0.0) {
            throw fatal(reporter, "Initial learning rate must be finite and non-negative. Found initialLearningRate=" + initialLearningRate);
        }
        if (notFinite(regularization) || regularization < 0.0) {
            throw fatal(reporter, "Regularization must be finite and non-negative. Found regularization=" + regularization);
        }
        // A zero rate would make the epoch-0 learning rate 0/0 = NaN.
        if (annealingRate <= 0.0 || notFinite(annealingRate)) {
            throw fatal(reporter, "Annealing rate must be finite and positive. Found rate=" + annealingRate);
        }
    }

    /** Validates the sparse input arrays of {@code partialSvd}, reporting and throwing on the first violation. */
    private static void checkSparseData(int[][] columnIds, double[][] values, Reporter reporter) {
        if (columnIds == null) {
            throw fatal(reporter, "ColumnIds must not be null.");
        }
        if (values == null) {
            throw fatal(reporter, "Values must not be null");
        }
        if (columnIds.length != values.length) {
            throw fatal(reporter, "ColumnIds and values must have the same number of rows. Found columnIds.length=" + columnIds.length + " values.length=" + values.length);
        }
        for (int row = 0; row < columnIds.length; ++row) {
            int[] idsForRow = columnIds[row];
            if (idsForRow == null) {
                throw fatal(reporter, "All column Ids must be non-null. Found null in row=" + row);
            }
            if (values[row] == null) {
                throw fatal(reporter, "All values must be non-null. Found null row=" + row);
            }
            if (idsForRow.length != values[row].length) {
                throw fatal(reporter, "column Ids and values must be same length. For row=" + row + " Found columnIds[row].length=" + idsForRow.length + " Found values[row].length=" + values[row].length);
            }
            for (int i = 0; i < idsForRow.length; ++i) {
                if (idsForRow[i] < 0) {
                    throw fatal(reporter, "Column ids must be non-negative. Found columnIds[" + row + "][" + i + "]=" + idsForRow[i]);
                }
                if (i > 0 && idsForRow[i - 1] >= idsForRow[i]) {
                    throw fatal(reporter, "Column ids must be strictly increasing within a row. At row=" + row + " found columnIds[" + row + "][" + (i - 1) + "]=" + idsForRow[i - 1] + " columnIds[" + row + "][" + i + "]=" + idsForRow[i]);
                }
            }
        }
    }

    /**
     * Runs one stochastic-gradient pass over all observed entries for the
     * current factor, updating {@code rowVector} and {@code columnVector} in place.
     *
     * @return the sum of squared prediction errors seen during the pass
     */
    private static double sgdEpoch(int[][] columnIds, double[][] values, double[][] cache, double[] rowVector, double[] columnVector, double learningRate, double regularization) {
        double sumOfSquareErrors = 0.0;
        for (int row = 0; row < columnIds.length; ++row) {
            int[] idsForRow = columnIds[row];
            double[] valuesForRow = values[row];
            double[] cacheForRow = cache[row];
            for (int i = 0; i < idsForRow.length; ++i) {
                int column = idsForRow[i];
                double error = valuesForRow[i] - predict(row, column, rowVector, columnVector, cacheForRow[i]);
                sumOfSquareErrors += error * error;
                // Both updates use the values from before this step.
                double rowCurrent = rowVector[row];
                double columnCurrent = columnVector[column];
                rowVector[row] += learningRate * (error * columnCurrent - regularization * rowCurrent);
                columnVector[column] += learningRate * (error * rowCurrent - regularization * columnCurrent);
            }
        }
        return sumOfSquareErrors;
    }

    /** Adds the just-fitted factor's contribution to every cached prediction. */
    private static void addFactorToCache(double[][] cache, int[][] columnIds, double[] rowVector, double[] columnVector) {
        for (int row = 0; row < cache.length; ++row) {
            double[] cacheRow = cache[row];
            for (int i = 0; i < cacheRow.length; ++i) {
                cacheRow[i] = predict(row, columnIds[row][i], rowVector, columnVector, cacheRow[i]);
            }
        }
    }

    /**
     * Lays per-factor vectors out as a {@code dimension x factors.size()}
     * matrix, so that result[i][k] = factors.get(k)[i]. Unlike
     * {@link #transpose}, this also works when there are no factors.
     */
    private static double[][] factorsToMatrix(List<double[]> factors, int dimension) {
        double[][] result = new double[dimension][factors.size()];
        for (int k = 0; k < factors.size(); ++k) {
            double[] factor = factors.get(k);
            for (int i = 0; i < dimension; ++i) {
                result[i][k] = factor[i];
            }
        }
        return result;
    }

    /** Reports {@code msg} as fatal and returns an exception carrying it, for the caller to throw. */
    private static IllegalArgumentException fatal(Reporter reporter, String msg) {
        reporter.fatal(msg);
        return new IllegalArgumentException(msg);
    }

    /**
     * Returns {@code |x - y| / (|x| + |y|)}, or {@code 0.0} when {@code x == y}.
     * The special case avoids {@code 0/0 = NaN} when both values are zero.
     */
    static double relativeDifference(double x, double y) {
        if (x == y) {
            return 0.0;
        }
        return Math.abs(x - y) / (Math.abs(x) + Math.abs(y));
    }

    /**
     * Returns the transpose of {@code xs}. The output width is the number of
     * input rows and the height is {@code xs[0].length}, so {@code xs} must be
     * non-empty and rectangular.
     */
    static double[][] transpose(double[][] xs) {
        double[][] ys = new double[xs[0].length][xs.length];
        for (int i = 0; i < xs.length; ++i) {
            for (int j = 0; j < xs[i].length; ++j) {
                ys[j][i] = xs[i][j];
            }
        }
        return ys;
    }

    /** Returns {@code cache + rowVector[row] * columnVector[column]}. */
    static double predict(int row, int column, double[] rowVector, double[] columnVector, double cache) {
        return cache + rowVector[row] * columnVector[column];
    }

    /** Returns {@code size} Gaussian samples from {@code random}, each multiplied by {@code val}. */
    static double[] initArray(int size, double val, Random random) {
        double[] xs = new double[size];
        for (int i = 0; i < xs.length; ++i) {
            xs[i] = random.nextGaussian() * val;
        }
        return xs;
    }

    /** Returns true if {@code x} is NaN or infinite. */
    static boolean notFinite(double x) {
        return Double.isNaN(x) || Double.isInfinite(x);
    }

    /** Returns the Euclidean length of column {@code col} of {@code xs}. */
    static double columnLength(double[][] xs, int col) {
        double sumOfSquares = 0.0;
        for (double[] x : xs) {
            sumOfSquares += x[col] * x[col];
        }
        return Math.sqrt(sumOfSquares);
    }

    /** Sets {@code vecs[i][k] = singularVecs[i][k] * singularVals[k]} for every cell of {@code vecs}. */
    static void scale(double[][] vecs, double[][] singularVecs, double[] singularVals) {
        for (int i = 0; i < vecs.length; ++i) {
            for (int k = 0; k < vecs[i].length; ++k) {
                vecs[i][k] = singularVecs[i][k] * singularVals[k];
            }
        }
    }

    /**
     * Throws if any of {@code vectors} does not have length {@code order}.
     * {@code prefix} names the vector set in the message.
     */
    static void verifyDimensions(String prefix, int order, double[][] vectors) {
        for (int i = 0; i < vectors.length; ++i) {
            if (vectors[i].length != order) {
                String msg = "All vectors must have length equal to order. order=" + order + " " + prefix + "Vectors[" + i + "].length=" + vectors[i].length;
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Returns a copy of {@code xs} with each column divided by its Euclidean
     * length. An empty input yields an empty result. An all-zero column has no
     * direction and is left as zeros rather than becoming {@code 0/0 = NaN}.
     */
    static double[][] normalizeColumns(double[][] xs) {
        int numDims = xs.length;
        if (numDims == 0) {
            return new double[0][];
        }
        int order = xs[0].length;
        double[][] result = new double[numDims][order];
        for (int j = 0; j < order; ++j) {
            double sumOfSquares = 0.0;
            for (int i = 0; i < numDims; ++i) {
                double valIJ = xs[i][j];
                result[i][j] = valIJ;
                sumOfSquares += valIJ * valIJ;
            }
            double length = Math.sqrt(sumOfSquares);
            if (length > 0.0) {
                for (int i = 0; i < numDims; ++i) {
                    result[i][j] /= length;
                }
            }
        }
        return result;
    }
}

