/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.math;

import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.math.tuple.Tuple3;
import de.jungblut.online.ml.FeatureOutcomePair;
import de.jungblut.reader.Dataset;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import org.apache.commons.math3.util.FastMath;

/**
 * Static math helpers: mean/stddev normalization of matrices and datasets,
 * polynomial feature expansion, numerical gradients, guarded logarithms,
 * min-max scaling and AUC computation.
 */
public final class MathUtils {
    /**
     * Relative step-size factor used by {@link #numericalGradient}: the square
     * root of 2.2E-16 (roughly the machine epsilon of a double).
     */
    public static final double EPS = Math.sqrt(2.2E-16);

    /** Not instantiable; this class only hosts static helpers. */
    private MathUtils() {
        throw new IllegalAccessError();
    }

    /**
     * Subtracts from every non-zero entry the mean of the non-zero entries of
     * its row. Entries equal to zero are neither counted for the mean nor
     * written to the result matrix. The input matrix is not modified.
     *
     * @param pMatrix the matrix to normalize.
     * @return a tuple of the normalized copy and the vector of per-row means
     *         (a row without non-zero entries gets a mean of 0).
     */
    public static Tuple<DoubleMatrix, DoubleVector> meanNormalizeRows(DoubleMatrix pMatrix) {
        DenseDoubleMatrix matrix = new DenseDoubleMatrix(pMatrix.getRowCount(), pMatrix.getColumnCount());
        DenseDoubleVector meanVector = new DenseDoubleVector(matrix.getRowCount());
        for (int row = 0; row < matrix.getRowCount(); ++row) {
            double sum = 0.0;
            int nonZeroElements = 0;
            for (int column = 0; column < matrix.getColumnCount(); ++column) {
                double val = pMatrix.get(row, column);
                if (val != 0.0) {
                    sum += val;
                    ++nonZeroElements;
                }
            }
            // Guard the division so that an all-zero row keeps a mean of 0.
            double mean = nonZeroElements == 0 ? sum : sum / (double) nonZeroElements;
            meanVector.set(row, mean);
            for (int column = 0; column < matrix.getColumnCount(); ++column) {
                double val = pMatrix.get(row, column);
                if (val != 0.0) {
                    matrix.set(row, column, val - mean);
                }
            }
        }
        return new Tuple<>(matrix, meanVector);
    }

    /**
     * Standardizes every column of the given matrix: subtracts the column mean
     * and divides by the column's population standard deviation (variance is
     * divided by the column length). The input matrix is not modified.
     *
     * <p>NOTE(review): a column whose standard deviation is 0 is divided by 0
     * here, which presumably yields NaN entries - confirm whether callers
     * expect that.
     *
     * @param x the matrix to normalize.
     * @return a triple of the normalized copy, the per-column means and the
     *         per-column standard deviations.
     */
    public static Tuple3<DoubleMatrix, DoubleVector, DoubleVector> meanNormalizeColumns(DoubleMatrix x) {
        DenseDoubleMatrix toReturn = new DenseDoubleMatrix(x.getRowCount(), x.getColumnCount());
        int length = x.getColumnCount();
        DenseDoubleVector meanVector = new DenseDoubleVector(length);
        DenseDoubleVector stddevVector = new DenseDoubleVector(length);
        for (int col = 0; col < length; ++col) {
            DoubleVector column = x.getColumnVector(col);
            double mean = column.sum() / (double) column.getLength();
            double variance = column.subtract(mean).pow(2.0).sum() / (double) column.getLength();
            meanVector.set(col, mean);
            stddevVector.set(col, Math.sqrt(variance));
        }
        for (int col = 0; col < length; ++col) {
            DoubleVector normalized = x.getColumnVector(col)
                    .subtract(meanVector.get(col))
                    .divide(stddevVector.get(col));
            toReturn.setColumn(col, normalized.toArray());
        }
        return new Tuple3<>(toReturn, meanVector, stddevVector);
    }

    /**
     * Normalizes all feature vectors of the dataset in place; equivalent to
     * {@link #meanNormalizeColumns(Dataset, Predicate)} with a predicate that
     * accepts every sample.
     *
     * @param dataset the dataset whose features are normalized in place.
     * @return a tuple of the mean vector and the (clamped) stddev vector.
     */
    public static Tuple<DoubleVector, DoubleVector> meanNormalizeColumns(Dataset dataset) {
        return MathUtils.meanNormalizeColumns(dataset, x -> true);
    }

    /**
     * Normalizes, in place, the feature vectors of all samples accepted by the
     * predicate: the mean is subtracted and the result is divided by the
     * standard deviation. Mean and standard deviation are computed over the
     * accepted samples only; each standard deviation component is clamped to
     * be at least 1. Samples rejected by the predicate are left untouched.
     *
     * <p>The predicate is evaluated exactly once per sample so that all passes
     * operate on the same selection.
     *
     * @param dataset the dataset whose features are normalized in place.
     * @param filterPredicate returns true for the samples to include.
     * @return a tuple of the mean vector and the (clamped) stddev vector.
     * @throws IllegalArgumentException if no sample is accepted by the
     *         predicate (including an empty dataset).
     */
    public static Tuple<DoubleVector, DoubleVector> meanNormalizeColumns(Dataset dataset, Predicate<FeatureOutcomePair> filterPredicate) {
        DoubleVector[] features = dataset.getFeatures();
        int numSamples = features.length;

        // Pass 1: evaluate the filter once per sample and sum the accepted features.
        boolean[] selected = new boolean[numSamples];
        int numSelected = 0;
        DoubleVector sumVector = null;
        for (int i = 0; i < numSamples; ++i) {
            selected[i] = filterPredicate.test(new FeatureOutcomePair(features[i], dataset.getOutcomes()[i]));
            if (!selected[i]) {
                continue;
            }
            ++numSelected;
            sumVector = sumVector == null ? features[i] : sumVector.add(features[i]);
        }
        Preconditions.checkArgument(numSelected > 0, "No sample matched the filter predicate, nothing to normalize!");

        // Divide by the number of accepted samples, not by the dataset size,
        // otherwise the statistics are skewed whenever the filter rejects samples.
        DoubleVector mean = sumVector.divide((double) numSelected);

        // Pass 2: accumulate the squared deviations of the accepted samples.
        DoubleVector squaredDeviations = null;
        for (int i = 0; i < numSamples; ++i) {
            if (!selected[i]) {
                continue;
            }
            DoubleVector squared = features[i].subtract(mean).pow(2.0);
            squaredDeviations = squaredDeviations == null ? squared : squaredDeviations.add(squared);
        }
        // Clamp every component to >= 1 before it is used as a divisor.
        DoubleVector stdVector = squaredDeviations.divide((double) numSelected).sqrt()
                .apply((index, value) -> Math.max(1.0, value));

        // Pass 3: replace the accepted feature vectors by their normalized form.
        for (int i = 0; i < numSamples; ++i) {
            if (selected[i]) {
                features[i] = features[i].subtract(mean).divide(stdVector);
            }
        }
        return new Tuple<>(mean, stdVector);
    }

    /**
     * Expands every column of the seed matrix into {@code num} columns: the
     * column itself followed by its element-wise powers 2..num.
     *
     * @param seed the matrix to expand.
     * @param num the highest power to generate per column.
     * @return the seed instance itself if {@code num == 1}, otherwise a new
     *         matrix with {@code seed.getColumnCount() * num} columns.
     */
    public static DenseDoubleMatrix createPolynomials(DenseDoubleMatrix seed, int num) {
        if (num == 1) {
            return seed;
        }
        DenseDoubleMatrix m = new DenseDoubleMatrix(seed.getRowCount(), seed.getColumnCount() * num);
        int seedColumn = 0;
        for (int target = 0; target < m.getColumnCount(); target += num) {
            double[] column = seed.getColumn(seedColumn++);
            m.setColumn(target, column);
            for (int power = 2; power <= num; ++power) {
                DoubleVector raised = new DenseDoubleVector(column).pow((double) power);
                m.setColumn(target + power - 1, raised.toArray());
            }
        }
        return m;
    }

    /**
     * Approximates the gradient of the cost function at the given point with
     * central differences, using a per-component step of
     * {@code EPS * (|x_i| + 1)}. The cost function is evaluated twice per
     * component. The given vector is not modified; a deep copy is perturbed.
     *
     * @param vector the point at which to approximate the gradient.
     * @param f the cost function to differentiate.
     * @return the approximated gradient, one entry per component of vector.
     */
    public static DoubleVector numericalGradient(DoubleVector vector, CostFunction f) {
        DenseDoubleVector gradient = new DenseDoubleVector(vector.getLength());
        DoubleVector tmp = vector.deepCopy();
        for (int i = 0; i < vector.getLength(); ++i) {
            double original = vector.get(i);
            double stepSize = EPS * (Math.abs(original) + 1.0);
            tmp.set(i, original + stepSize);
            double costPlus = f.evaluateCost(tmp).getCost();
            tmp.set(i, original - stepSize);
            double costMinus = f.evaluateCost(tmp).getCost();
            gradient.set(i, (costPlus - costMinus) / (2.0 * stepSize));
            // Restore the component; otherwise all following partial
            // derivatives would be evaluated at a shifted point.
            tmp.set(i, original);
        }
        return gradient;
    }

    /**
     * Applies {@link #guardedLogarithm(double)} to every entry of the matrix.
     *
     * @param input the matrix to take the logarithm of; not modified.
     * @return a new dense matrix of the same dimensions.
     */
    public static DoubleMatrix logMatrix(DoubleMatrix input) {
        DenseDoubleMatrix log = new DenseDoubleMatrix(input.getRowCount(), input.getColumnCount());
        for (int row = 0; row < log.getRowCount(); ++row) {
            for (int col = 0; col < log.getColumnCount(); ++col) {
                log.set(row, col, MathUtils.guardedLogarithm(input.get(row, col)));
            }
        }
        return log;
    }

    /**
     * Applies {@link #guardedLogarithm(double)} to every entry of the vector.
     *
     * @param input the vector to take the logarithm of; not modified.
     * @return a new dense vector of the same dimension.
     */
    public static DoubleVector logVector(DoubleVector input) {
        DenseDoubleVector log = new DenseDoubleVector(input.getDimension());
        for (int col = 0; col < log.getDimension(); ++col) {
            log.set(col, MathUtils.guardedLogarithm(input.get(col)));
        }
        return log;
    }

    /**
     * Applies {@link #minMaxScale(double, double, double, double, double)} to
     * every entry of the matrix.
     *
     * @param input the matrix to scale; not modified.
     * @param fromMin lower bound of the source range.
     * @param fromMax upper bound of the source range.
     * @param toMin lower bound of the target range.
     * @param toMax upper bound of the target range.
     * @return a new dense matrix with the scaled values.
     */
    public static DoubleMatrix minMaxScale(DoubleMatrix input, double fromMin, double fromMax, double toMin, double toMax) {
        DenseDoubleMatrix scaled = new DenseDoubleMatrix(input.getRowCount(), input.getColumnCount());
        double[][] values = input.toArray();
        for (int row = 0; row < scaled.getRowCount(); ++row) {
            for (int col = 0; col < scaled.getColumnCount(); ++col) {
                scaled.set(row, col, MathUtils.minMaxScale(values[row][col], fromMin, fromMax, toMin, toMax));
            }
        }
        return scaled;
    }

    /**
     * Applies {@link #minMaxScale(double, double, double, double, double)} to
     * every entry of the vector.
     *
     * @param input the vector to scale; not modified.
     * @param fromMin lower bound of the source range.
     * @param fromMax upper bound of the source range.
     * @param toMin lower bound of the target range.
     * @param toMax upper bound of the target range.
     * @return a new dense vector with the scaled values.
     */
    public static DoubleVector minMaxScale(DoubleVector input, double fromMin, double fromMax, double toMin, double toMax) {
        DenseDoubleVector scaled = new DenseDoubleVector(input.getDimension());
        double[] values = input.toArray();
        for (int i = 0; i < values.length; ++i) {
            scaled.set(i, MathUtils.minMaxScale(values[i], fromMin, fromMax, toMin, toMax));
        }
        return scaled;
    }

    /**
     * Linearly maps x from the range [fromMin, fromMax] to [toMin, toMax].
     * No clamping is applied, and {@code fromMax == fromMin} results in a
     * floating point division by zero.
     *
     * @param x the value to scale.
     * @param fromMin lower bound of the source range.
     * @param fromMax upper bound of the source range.
     * @param toMin lower bound of the target range.
     * @param toMax upper bound of the target range.
     * @return the scaled value.
     */
    public static double minMaxScale(double x, double fromMin, double fromMax, double toMin, double toMax) {
        return (x - fromMin) * (toMax - toMin) / (fromMax - fromMin) + toMin;
    }

    /**
     * Natural logarithm with guards for inputs outside its useful domain.
     *
     * @param input the value to take the logarithm of.
     * @return 0 for NaN or infinite input, -10 for input less than or equal
     *         to zero (including negative zero), otherwise the logarithm
     *         computed by {@code FastMath.log}.
     */
    public static double guardedLogarithm(double input) {
        if (Double.isNaN(input) || Double.isInfinite(input)) {
            return 0.0;
        }
        // -0.0 compares equal to 0.0, so a single comparison covers both zeros.
        if (input <= 0.0) {
            return -10.0;
        }
        return FastMath.log((double) input);
    }

    /**
     * Computes the AUC (per the method name) of the given predictions.
     * Runs of equal predictions are accumulated as one group.
     *
     * <p>Side effect: the supplied list is sorted in place by ascending
     * prediction.
     *
     * @param outcomePredictedPairs the predictions with their true outcome
     *        class; must be a modifiable list because it is sorted.
     * @return the AUC; 1.0 if the list contains only one of the two classes
     *         (which includes the empty list).
     */
    public static double computeAUC(List<PredictionOutcomePair> outcomePredictedPairs) {
        Collections.sort(outcomePredictedPairs);
        int n = outcomePredictedPairs.size();
        int numOnes = 0;
        for (PredictionOutcomePair pair : outcomePredictedPairs) {
            if (pair.getOutcomeClass() == 1) {
                ++numOnes;
            }
        }
        if (numOnes == 0 || numOnes == n) {
            return 1.0;
        }
        long tp0 = numOnes;
        long truePos = numOnes;
        long tn = 0L;
        long accum = 0L;
        double threshold = outcomePredictedPairs.get(0).getPrediction();
        for (PredictionOutcomePair pair : outcomePredictedPairs) {
            double predictedValue = pair.getPrediction();
            if (predictedValue != threshold) {
                // A new prediction value starts: close the group of ties.
                threshold = predictedValue;
                accum += tn * (truePos + tp0);
                tp0 = truePos;
                tn = 0L;
            }
            // Outcome classes are restricted to 0 and 1 by the factory method.
            int outcome = pair.getOutcomeClass();
            tn += 1 - outcome;
            truePos -= outcome;
        }
        accum += tn * (truePos + tp0);
        // Evaluate the denominator in double: as an int product it overflows
        // once numOnes * (n - numOnes) exceeds about 2^30.
        return (double) accum / (2.0 * numOnes * (n - numOnes));
    }

    /**
     * Immutable pair of a binary outcome class (0 or 1) and a prediction,
     * ordered by ascending prediction.
     */
    public static class PredictionOutcomePair
    implements Comparable<PredictionOutcomePair> {
        private final int outcomeClass;
        private final double prediction;

        private PredictionOutcomePair(int outcomeClass, double prediction) {
            this.outcomeClass = outcomeClass;
            this.prediction = prediction;
        }

        /**
         * Creates a new pair.
         *
         * @param outcomeClass the true class, must be 0 or 1.
         * @param prediction the predicted value.
         * @return the new pair.
         * @throws IllegalArgumentException if outcomeClass is neither 0 nor 1.
         */
        public static PredictionOutcomePair from(int outcomeClass, double prediction) {
            Preconditions.checkArgument(outcomeClass == 0 || outcomeClass == 1, "Outcome class must be 0 or 1! Supplied: " + outcomeClass);
            return new PredictionOutcomePair(outcomeClass, prediction);
        }

        /** Orders pairs by their prediction using {@link Double#compare}. */
        @Override
        public int compareTo(PredictionOutcomePair o) {
            return Double.compare(this.prediction, o.prediction);
        }

        /** @return the true outcome class of this pair. */
        public int getOutcomeClass() {
            return this.outcomeClass;
        }

        /** @return the predicted value of this pair. */
        public double getPrediction() {
            return this.prediction;
        }
    }
}

