/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.eval;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.classification.Predictor;
import de.jungblut.classification.eval.EvaluationSplit;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.MathUtils;
import de.jungblut.partition.BlockPartitioner;
import de.jungblut.partition.Boundaries;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Static utilities to train and evaluate {@link Classifier}s, either on a single
 * train/test split or with k-fold cross validation.
 */
public final class Evaluator {
    private static final Logger LOG = LogManager.getLogger(Evaluator.class);

    private Evaluator() {
        // static utility class; instantiation is a programming error
        throw new IllegalAccessError();
    }

    /**
     * Splits the data via {@code EvaluationSplit.create}, trains the classifier on
     * the train part and evaluates it on the test part without a custom threshold.
     *
     * @param classifier the classifier to train and evaluate.
     * @param features the feature vectors.
     * @param outcome the outcome vectors.
     * @param splitFraction forwarded to {@code EvaluationSplit.create}.
     * @param random forwarded to {@code EvaluationSplit.create}.
     * @return the evaluation result on the test part.
     */
    public static EvaluationResult evaluateClassifier(Classifier classifier, DoubleVector[] features, DoubleVector[] outcome, float splitFraction, boolean random) {
        return evaluateClassifier(classifier, features, outcome, splitFraction, random, null);
    }

    /**
     * Splits the data via {@code EvaluationSplit.create}, trains the classifier on
     * the train part and evaluates it on the test part.
     *
     * @param classifier the classifier to train and evaluate.
     * @param features the feature vectors.
     * @param outcome the outcome vectors.
     * @param splitFraction forwarded to {@code EvaluationSplit.create}.
     * @param random forwarded to {@code EvaluationSplit.create}.
     * @param threshold optional binary decision threshold, {@code null} for the
     *        predictor's default class extraction.
     * @return the evaluation result on the test part.
     */
    public static EvaluationResult evaluateClassifier(Classifier classifier, DoubleVector[] features, DoubleVector[] outcome, float splitFraction, boolean random, Double threshold) {
        EvaluationSplit split = EvaluationSplit.create(features, outcome, splitFraction, random);
        return evaluateSplit(classifier, split, threshold);
    }

    /**
     * Trains the classifier on the split's train part and evaluates it on its test
     * part without a custom threshold.
     *
     * @param classifier the classifier to train and evaluate.
     * @param split the precomputed train/test split.
     * @return the evaluation result on the test part.
     */
    public static EvaluationResult evaluateSplit(Classifier classifier, EvaluationSplit split) {
        return evaluateSplit(classifier, split, null);
    }

    /**
     * Trains the classifier on the split's train part and evaluates it on its test
     * part.
     *
     * @param classifier the classifier to train and evaluate.
     * @param split the precomputed train/test split.
     * @param threshold optional binary decision threshold, {@code null} for the
     *        predictor's default class extraction.
     * @return the evaluation result on the test part.
     */
    public static EvaluationResult evaluateSplit(Classifier classifier, EvaluationSplit split, Double threshold) {
        return evaluateSplit(classifier, split.getTrainFeatures(), split.getTrainOutcome(), split.getTestFeatures(), split.getTestOutcome(), threshold);
    }

    /**
     * Trains the classifier on the given train data, then evaluates it on the given
     * test data.
     *
     * @param classifier the classifier to train and evaluate.
     * @param trainFeatures the training feature vectors.
     * @param trainOutcome the training outcome vectors.
     * @param testFeatures the test feature vectors.
     * @param testOutcome the test outcome vectors.
     * @param threshold optional binary decision threshold, {@code null} for the
     *        predictor's default class extraction.
     * @return the evaluation result on the test data.
     */
    public static EvaluationResult evaluateSplit(Classifier classifier, DoubleVector[] trainFeatures, DoubleVector[] trainOutcome, DoubleVector[] testFeatures, DoubleVector[] testOutcome, Double threshold) {
        classifier.train(trainFeatures, trainOutcome);
        return testClassifier(classifier, testFeatures, testOutcome, threshold);
    }

    /**
     * Evaluates an already trained predictor on a test set using its default class
     * extraction.
     *
     * @param classifier the trained predictor.
     * @param testFeatures the test feature vectors.
     * @param testOutcome the expected outcome vectors, index-aligned with
     *        {@code testFeatures}.
     * @return the collected metrics.
     */
    public static EvaluationResult testClassifier(Predictor classifier, DoubleVector[] testFeatures, DoubleVector[] testOutcome) {
        return testClassifier(classifier, testFeatures, testOutcome, null);
    }

    /**
     * Evaluates an already trained predictor on a test set.
     *
     * <p>If the outcome vectors have one or two dimensions the problem is treated as
     * binary: the label is read from the first outcome component and TP/FP/TN/FN,
     * log loss and AUC are collected. Otherwise a confusion matrix, the number of
     * correct predictions and the log loss are collected.
     *
     * @param classifier the trained predictor.
     * @param testFeatures the test feature vectors.
     * @param testOutcome the expected outcome vectors, index-aligned with
     *        {@code testFeatures}.
     * @param threshold optional decision threshold used only for binary problems;
     *        {@code null} uses the predictor's default class extraction.
     * @return the collected metrics.
     * @throws IllegalArgumentException if the test set is empty or the feature and
     *         outcome arrays differ in length.
     */
    public static EvaluationResult testClassifier(Predictor classifier, DoubleVector[] testFeatures, DoubleVector[] testOutcome, Double threshold) {
        Preconditions.checkArgument(testOutcome.length > 0, "Test set must not be empty.");
        Preconditions.checkArgument(testFeatures.length == testOutcome.length,
                "Number of test features (%s) doesn't match number of test outcomes (%s).",
                testFeatures.length, testOutcome.length);
        EvaluationResult result = new EvaluationResult();
        // outcomes of dimension 1 and 2 both end up with numLabels == 2 (binary)
        // NOTE(review): for two-dimensional outcomes only component 0 is read as the
        // binary label — confirm this matches how two-column outcomes are encoded.
        result.numLabels = Math.max(2, testOutcome[0].getDimension());
        result.testSize = testOutcome.length;
        if (result.isBinary()) {
            ArrayList<MathUtils.PredictionOutcomePair> outcomePredictedPairs = new ArrayList<MathUtils.PredictionOutcomePair>(testFeatures.length);
            for (int i = 0; i < testFeatures.length; ++i) {
                DoubleVector outcomeVector = testOutcome[i];
                DoubleVector predictedVector = classifier.predict(testFeatures[i]);
                int outcomeClass = observeBinaryClassificationElement(classifier, threshold, result, outcomeVector, predictedVector);
                // the AUC is computed from the raw score in the first prediction component
                outcomePredictedPairs.add(MathUtils.PredictionOutcomePair.from(outcomeClass, predictedVector.get(0)));
            }
            result.auc = MathUtils.computeAUC(outcomePredictedPairs);
        } else {
            int[][] confusionMatrix = new int[result.numLabels][result.numLabels];
            for (int i = 0; i < testFeatures.length; ++i) {
                DoubleVector predicted = classifier.predict(testFeatures[i]);
                DoubleVector outcomeVector = testOutcome[i];
                result.logLoss += outcomeVector.multiply(MathUtils.logVector(predicted)).sum();
                // the true class is the index of the largest outcome component
                int outcomeClass = outcomeVector.maxIndex();
                int prediction = classifier.extractPredictedClass(predicted);
                // rows are the real outcome, columns the prediction
                confusionMatrix[outcomeClass][prediction]++;
                if (outcomeClass == prediction) {
                    ++result.correct;
                }
            }
            result.confusionMatrix = confusionMatrix;
        }
        return result;
    }

    /**
     * Records one binary test example into {@code result}: adds its log-loss term
     * and increments exactly one of TP/FN/TN/FP.
     *
     * @param predictor the predictor used to turn the prediction into a class.
     * @param threshold optional decision threshold, {@code null} for the
     *        predictor's default class extraction.
     * @param result the result to update.
     * @param outcomeVector the expected outcome; component 0 (cast to int) is the
     *        label.
     * @param predictedVector the raw prediction for this example.
     * @return the label read from the outcome (0 or 1).
     * @throws IllegalArgumentException if the label is neither 0 nor 1 (the
     *         log-loss term has already been added at that point).
     */
    public static int observeBinaryClassificationElement(Predictor predictor, Double threshold, EvaluationResult result, DoubleVector outcomeVector, DoubleVector predictedVector) {
        int outcomeClass = (int) outcomeVector.get(0);
        // NOTE(review): assuming logVector is an element-wise log, this only adds
        // y * log(p); examples with y == 0 contribute nothing. Confirm whether the
        // (1 - y) * log(1 - p) term is intentionally omitted.
        result.logLoss += outcomeVector.multiply(MathUtils.logVector(predictedVector)).sum();
        int prediction = threshold == null
                ? predictor.extractPredictedClass(predictedVector)
                : predictor.extractPredictedClass(predictedVector, threshold);
        if (outcomeClass == 1) {
            if (prediction == 1) {
                ++result.truePositive;
            } else {
                ++result.falseNegative;
            }
        } else if (outcomeClass == 0) {
            if (prediction == 0) {
                ++result.trueNegative;
            } else {
                ++result.falsePositive;
            }
        } else {
            throw new IllegalArgumentException("Outcome class was neither 0 or 1. Was: " + outcomeClass + "; the supplied outcome value was: " + outcomeVector.get(0));
        }
        return outcomeClass;
    }

    /**
     * Runs k-fold cross validation on a single thread.
     *
     * @see #crossValidateClassifier(ClassifierFactory, DoubleVector[], DoubleVector[], int, int, Double, int, boolean)
     */
    public static <A extends Classifier> EvaluationResult crossValidateClassifier(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, int folds, Double threshold, boolean verbose) {
        return crossValidateClassifier(classifierFactory, features, outcome, numLabels, folds, threshold, 1, verbose);
    }

    /**
     * Runs k-fold cross validation: for every fold a fresh classifier from the
     * factory is trained on the rows outside the fold and tested on the rows inside
     * it. The per-fold results are summed and averaged.
     *
     * <p>{@code features} and {@code outcome} are handed to
     * {@code ArrayUtils.multiShuffle} directly (no copy is made here), so the
     * caller's arrays are presumably reordered in place. The shuffle is not
     * stratified.
     *
     * <p>A fold that throws is logged and excluded from the average. If the calling
     * thread is interrupted, its interrupt flag is restored and the average over the
     * folds finished so far is returned. If no fold finished, the returned result
     * holds only zero counts.
     *
     * @param classifierFactory creates a new, untrained classifier per fold.
     * @param features the feature vectors.
     * @param outcome the outcome vectors.
     * @param numLabels the number of labels stored in the averaged result.
     * @param folds the number of folds, must be positive.
     * @param threshold optional binary decision threshold, {@code null} for the
     *        predictor's default class extraction.
     * @param numThreads the number of worker threads evaluating folds.
     * @param verbose whether to log the split ranges and every fold's result.
     * @return the result averaged over all successfully evaluated folds.
     * @throws IllegalArgumentException if {@code folds} is not positive.
     */
    public static <A extends Classifier> EvaluationResult crossValidateClassifier(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, int folds, Double threshold, int numThreads, boolean verbose) {
        Preconditions.checkArgument(folds > 0, "Number of folds must be positive, but was %s", folds);
        int numFolds = folds + 1;
        ArrayUtils.multiShuffle(features, new DoubleVector[][]{outcome});
        EvaluationResult averagedModel = new EvaluationResult();
        averagedModel.numLabels = numLabels;
        int m = features.length;
        // the end index of blocks 1..folds delimits the (inclusive) test range of each
        // fold; see CallableEvaluation
        ArrayList<Boundaries.Range> partition = new ArrayList<Boundaries.Range>(new BlockPartitioner().partition(numFolds, m).getBoundaries());
        int[] splitRanges = new int[numFolds];
        for (int i = 1; i < numFolds; ++i) {
            splitRanges[i] = partition.get(i).getEnd();
        }
        // ranges are used as inclusive indices, so the last end is pulled back by one
        // (presumably the final block ends at m — TODO confirm Range.getEnd semantics)
        splitRanges[numFolds - 1] -= 1;
        if (verbose) {
            LOG.info("Computed split ranges: " + Arrays.toString(splitRanges) + "\n");
        }
        ExecutorService pool = Executors.newFixedThreadPool(numThreads, new ThreadFactoryBuilder().setDaemon(true).build());
        int completedFolds = 0;
        try {
            ExecutorCompletionService<EvaluationResult> completionService = new ExecutorCompletionService<EvaluationResult>(pool);
            for (int fold = 0; fold < folds; ++fold) {
                completionService.submit(new CallableEvaluation<A>(fold, splitRanges, m, classifierFactory, features, outcome, folds, threshold));
            }
            for (int fold = 0; fold < folds; ++fold) {
                try {
                    Future<EvaluationResult> take = completionService.take();
                    EvaluationResult foldSplit = take.get();
                    if (verbose) {
                        // numbered in completion order, not by fold index
                        LOG.info("Fold: " + (fold + 1));
                        foldSplit.print();
                        LOG.info("");
                    }
                    averagedModel.add(foldSplit);
                    ++completedFolds;
                } catch (ExecutionException e) {
                    // best effort: skip the failing fold and keep collecting the others
                    LOG.error("Evaluation of a cross validation fold failed.", e);
                } catch (InterruptedException e) {
                    // restore the interrupt flag and stop waiting for the remaining folds
                    Thread.currentThread().interrupt();
                    LOG.error("Interrupted while waiting for cross validation folds.", e);
                    break;
                }
            }
        } finally {
            // the pool is created per call: release its threads and cancel any fold
            // that is still running (e.g. after an interrupt)
            pool.shutdownNow();
        }
        // average only over folds that actually contributed, so failures don't count as zeros
        if (completedFolds > 0) {
            averagedModel.average(completedFolds);
        }
        return averagedModel;
    }

    /**
     * Runs single-threaded 10-fold cross validation.
     *
     * @see #crossValidateClassifier(ClassifierFactory, DoubleVector[], DoubleVector[], int, int, Double, int, boolean)
     */
    public static <A extends Classifier> EvaluationResult tenFoldCrossValidation(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, Double threshold, boolean verbose) {
        return crossValidateClassifier(classifierFactory, features, outcome, numLabels, 10, threshold, verbose);
    }

    /**
     * Runs 10-fold cross validation with {@code numThreads} worker threads.
     *
     * @see #crossValidateClassifier(ClassifierFactory, DoubleVector[], DoubleVector[], int, int, Double, int, boolean)
     */
    public static <A extends Classifier> EvaluationResult tenFoldCrossValidation(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, Double threshold, int numThreads, boolean verbose) {
        return crossValidateClassifier(classifierFactory, features, outcome, numLabels, 10, threshold, numThreads, verbose);
    }

    /**
     * Evaluates one cross-validation fold. The test rows are the inclusive index
     * range {@code [splitRanges[fold], splitRanges[fold + 1]]}, and all other rows
     * form the training set for a fresh classifier.
     */
    private static class CallableEvaluation<A extends Classifier>
    implements Callable<EvaluationResult> {
        private final int fold;
        private final int[] splitRanges;
        private final int m;
        private final DoubleVector[] features;
        private final DoubleVector[] outcome;
        private final ClassifierFactory<A> classifierFactory;
        private final Double threshold;

        // 'folds' is unused; it is kept so the construction site stays unchanged
        public CallableEvaluation(int fold, int[] splitRanges, int m, ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int folds, Double threshold) {
            this.fold = fold;
            this.splitRanges = splitRanges;
            this.m = m;
            this.classifierFactory = classifierFactory;
            this.features = features;
            this.outcome = outcome;
            this.threshold = threshold;
        }

        @Override
        public EvaluationResult call() throws Exception {
            int testStart = this.splitRanges[this.fold];
            int testEnd = this.splitRanges[this.fold + 1];
            // assumes ArrayUtils.subArray treats 'end' as inclusive, matching the
            // inclusive skip condition of the training loop below
            DoubleVector[] featureTest = ArrayUtils.subArray(this.features, testStart, testEnd);
            DoubleVector[] outcomeTest = ArrayUtils.subArray(this.outcome, testStart, testEnd);
            DoubleVector[] featureTrain = new DoubleVector[this.m - featureTest.length];
            DoubleVector[] outcomeTrain = new DoubleVector[this.m - featureTest.length];
            int index = 0;
            for (int i = 0; i < this.m; ++i) {
                if (i >= testStart && i <= testEnd) {
                    continue;
                }
                featureTrain[index] = this.features[i];
                outcomeTrain[index] = this.outcome[i];
                ++index;
            }
            return evaluateSplit(this.classifierFactory.newInstance(), featureTrain, outcomeTrain, featureTest, outcomeTest, this.threshold);
        }
    }

    /**
     * Mutable holder for evaluation counts and derived metrics. Binary results
     * ({@code numLabels == 2}) track TP/FP/TN/FN and AUC; multi-class results
     * track the number of correct predictions and a confusion matrix.
     */
    public static class EvaluationResult {
        int numLabels;
        int correct;
        int testSize;
        int truePositive;
        int falsePositive;
        int trueNegative;
        int falseNegative;
        int[][] confusionMatrix;
        double auc;
        // accumulated sum of outcome * log(prediction); negated and normalized by getLogLoss()
        double logLoss;

        /** @return the area under the ROC curve (binary results only, else 0). */
        public double getAUC() {
            return this.auc;
        }

        /** @return the negated accumulated log-likelihood divided by the test size. */
        public double getLogLoss() {
            return -this.logLoss / (double) this.testSize;
        }

        /** @return TP / (TP + FP); NaN when both are zero. */
        public double getPrecision() {
            return (double) this.truePositive / (double) (this.truePositive + this.falsePositive);
        }

        /** @return TP / (TP + FN); NaN when both are zero. */
        public double getRecall() {
            return (double) this.truePositive / (double) (this.truePositive + this.falseNegative);
        }

        /** @return FP / (FP + TN); NaN when both are zero. */
        public double getFalsePositiveRate() {
            return (double) this.falsePositive / (double) (this.falsePositive + this.trueNegative);
        }

        /** @return the fraction of correctly classified test examples. */
        public double getAccuracy() {
            if (this.isBinary()) {
                return ((double) this.truePositive + (double) this.trueNegative) / (double) (this.truePositive + this.trueNegative + this.falsePositive + this.falseNegative);
            }
            return (double) this.correct / (double) this.testSize;
        }

        /** @return the harmonic mean of precision and recall. */
        public double getF1Score() {
            double precision = this.getPrecision();
            double recall = this.getRecall();
            return 2.0 * (precision * recall) / (precision + recall);
        }

        /**
         * @return the Matthews correlation coefficient; NaN or infinite when any
         *         marginal sum is zero.
         */
        public double getMatthewsCorrelationCoefficient() {
            // computed in double: the int products overflow once the marginal sums
            // exceed a few hundred
            double tp = this.truePositive;
            double fp = this.falsePositive;
            double tn = this.trueNegative;
            double fn = this.falseNegative;
            return (tp * tn - fp * fn) / FastMath.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        }

        /** @return the number of correctly classified test examples. */
        public int getCorrect() {
            if (!this.isBinary()) {
                return this.correct;
            }
            return this.truePositive + this.trueNegative;
        }

        public int getNumLabels() {
            return this.numLabels;
        }

        public int getTestSize() {
            return this.testSize;
        }

        /** @return the confusion matrix (rows: real outcome, columns: prediction), or null for binary results. */
        public int[][] getConfusionMatrix() {
            return this.confusionMatrix;
        }

        public boolean isBinary() {
            return this.numLabels == 2;
        }

        /**
         * Adds all counts, AUC, log loss and the confusion matrix of {@code res} to
         * this result. {@code numLabels} is not changed. The confusion matrix of
         * {@code res} is copied, never shared, so later additions or
         * {@link #average(int)} never mutate {@code res}.
         *
         * @param res the result to add.
         */
        public void add(EvaluationResult res) {
            this.correct += res.correct;
            this.testSize += res.testSize;
            this.truePositive += res.truePositive;
            this.falsePositive += res.falsePositive;
            this.trueNegative += res.trueNegative;
            this.falseNegative += res.falseNegative;
            this.auc += res.auc;
            this.logLoss += res.logLoss;
            if (res.confusionMatrix == null) {
                return;
            }
            if (this.confusionMatrix == null) {
                this.confusionMatrix = new int[res.confusionMatrix.length][];
                for (int i = 0; i < res.confusionMatrix.length; ++i) {
                    this.confusionMatrix[i] = res.confusionMatrix[i].clone();
                }
            } else {
                for (int i = 0; i < this.confusionMatrix.length; ++i) {
                    for (int j = 0; j < this.confusionMatrix[i].length; ++j) {
                        this.confusionMatrix[i][j] += res.confusionMatrix[i][j];
                    }
                }
            }
        }

        /**
         * Divides every accumulated value by {@code n}. Integer counts and the
         * confusion matrix use integer division, so they are truncated.
         *
         * @param n the number of results that were added, must be positive.
         * @throws IllegalArgumentException if {@code n} is not positive.
         */
        public void average(int n) {
            Preconditions.checkArgument(n > 0, "Can only average over a positive number of results, but was %s", n);
            this.correct /= n;
            this.testSize /= n;
            this.truePositive /= n;
            this.falsePositive /= n;
            this.trueNegative /= n;
            this.falseNegative /= n;
            this.auc /= (double) n;
            this.logLoss /= (double) n;
            if (this.confusionMatrix != null) {
                for (int[] row : this.confusionMatrix) {
                    for (int j = 0; j < row.length; ++j) {
                        row[j] /= n;
                    }
                }
            }
        }

        public int getTruePositive() {
            return this.truePositive;
        }

        public int getFalsePositive() {
            return this.falsePositive;
        }

        public int getTrueNegative() {
            return this.trueNegative;
        }

        public int getFalseNegative() {
            return this.falseNegative;
        }

        /** Logs this result to the class logger. */
        public void print() {
            this.print(LOG);
        }

        /**
         * Logs this result at info level. For multi-class results the confusion
         * matrix is additionally printed to standard out.
         *
         * @param log the logger to write to.
         */
        public void print(Logger log) {
            log.info("Number of labels: " + this.getNumLabels());
            log.info("Testset size: " + this.getTestSize());
            log.info("Correctly classified: " + this.getCorrect());
            log.info("Accuracy: " + this.getAccuracy());
            log.info("Log loss: " + this.getLogLoss());
            if (this.isBinary()) {
                log.info("TP: " + this.truePositive);
                log.info("FP: " + this.falsePositive);
                log.info("TN: " + this.trueNegative);
                log.info("FN: " + this.falseNegative);
                log.info("Precision: " + this.getPrecision());
                log.info("Recall: " + this.getRecall());
                log.info("F1 Score: " + this.getF1Score());
                log.info("AUC: " + this.getAUC());
                log.info("MCC: " + this.getMatthewsCorrelationCoefficient());
            } else {
                this.printConfusionMatrix();
            }
        }

        /** Prints the confusion matrix to standard out, labelling rows by index only. */
        public void printConfusionMatrix() {
            this.printConfusionMatrix(null);
        }

        /**
         * Prints the confusion matrix to standard out. For each row it also prints the
         * number and percentage of misclassified examples.
         *
         * @param classNames optional names per label index; may be null.
         * @throws NullPointerException if there is no confusion matrix.
         * @throws IllegalArgumentException if {@code classNames} does not have one
         *         entry per label.
         */
        public void printConfusionMatrix(String[] classNames) {
            Preconditions.checkNotNull(this.confusionMatrix, "No confusion matrix found.");
            int labels = this.getNumLabels();
            if (classNames != null) {
                Preconditions.checkArgument(classNames.length == labels, (Object) ("Passed class names doesn't match with number of labels! Expected " + labels + " but was " + classNames.length));
            }
            System.out.println("\nConfusion matrix (real outcome on rows, prediction in columns)\n");
            for (int col = 0; col < labels; ++col) {
                System.out.format("%5d", col);
            }
            System.out.format(" <- %5s %5s\t%s\n", "sum", "perc", "class");
            NumberFormat percentFormat = NumberFormat.getPercentInstance();
            for (int row = 0; row < labels; ++row) {
                // off-diagonal entries of a row are the misclassified examples of that class
                int misclassified = 0;
                for (int col = 0; col < labels; ++col) {
                    if (row != col) {
                        misclassified += this.confusionMatrix[row][col];
                    }
                    System.out.format("%5d", this.confusionMatrix[row][col]);
                }
                float falsePercentage = (float) misclassified / (float) (misclassified + this.confusionMatrix[row][row]);
                String clz = classNames != null ? " " + row + " (" + classNames[row] + ")" : " " + row;
                System.out.format(" <- %5s %5s\t%s\n", misclassified, percentFormat.format(falsePercentage), clz);
            }
        }
    }
}

