/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.eval;

import de.datexis.annotator.AnnotatorEvaluation;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.model.tag.Tag;
import java.io.Serializable;
import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.nd4j.evaluation.EvaluationAveraging;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.LoggerFactory;

/**
 * Evaluates classification results where each example is a pair of an expected label vector
 * and a predicted score vector, both encoded with a {@link LookupCacheEncoder}.
 * <p>
 * Label entries {@code > 0} mark relevant classes. Besides the confusion-matrix metrics of the
 * wrapped ND4J {@link Evaluation} (accuracy, top-K accuracy, precision, recall, F1), this class
 * accumulates per-example ranking metrics: reciprocal rank (MRR), average precision (MAP),
 * precision/recall at 1 and precision/recall at K.
 */
public class ClassificationEvaluation
extends AnnotatorEvaluation
implements IEvaluation<ClassificationEvaluation> {
    /** Encoder whose vocabulary defines the class labels. */
    protected LookupCacheEncoder encoder;
    /** Number of classes, taken from the encoder's embedding vector size. */
    protected int numClasses;
    /** Cut-off used for the "@K" metrics and for ND4J's top-N accuracy. */
    protected int K;
    /** Wrapped ND4J evaluation holding the confusion matrix. */
    protected Evaluation eval;
    // Running sums of per-example ranking metrics; the getters divide them by countExamples.
    protected double mrrsum = 0.0;
    protected double mapsum = 0.0;
    protected double p1sum = 0.0;
    protected double r1sum = 0.0;
    protected double pksum = 0.0;
    protected double rksum = 0.0;

    /**
     * Creates an evaluation comparing GOLD against PRED annotations with K = 3.
     *
     * @param experimentName name of the experiment
     * @param encoder encoder whose vocabulary defines the classes
     */
    public ClassificationEvaluation(String experimentName, LookupCacheEncoder encoder) {
        this(experimentName, Annotation.Source.GOLD, Annotation.Source.PRED, encoder, 3);
    }

    /**
     * Creates an evaluation.
     *
     * @param experimentName name of the experiment
     * @param expected source of the expected (target) annotations
     * @param predicted source of the predicted annotations
     * @param encoder encoder whose vocabulary defines the classes
     * @param K cut-off for the "@K" metrics
     */
    public ClassificationEvaluation(String experimentName, Annotation.Source expected, Annotation.Source predicted, LookupCacheEncoder encoder, int K) {
        super(experimentName, expected, predicted);
        this.K = K;
        this.encoder = encoder;
        this.numClasses = (int) encoder.getEmbeddingVectorSize();
        this.log = LoggerFactory.getLogger(ClassificationEvaluation.class);
        this.clear();
    }

    /** Resets the confusion matrix, all counters and all accumulated metric sums. */
    protected void clear() {
        this.eval = new Evaluation(this.encoder.getWords(), this.K);
        this.countDocs = 0;
        this.countExamples = 0;
        this.mrrsum = 0.0;
        this.mapsum = 0.0;
        this.p1sum = 0.0;
        this.r1sum = 0.0;
        this.pksum = 0.0;
        this.rksum = 0.0;
    }

    /** @return the main score of this evaluation, which is MAP */
    public double getScore() {
        return this.getMAP();
    }

    /** Not supported; use one of the {@code calculateScoresFrom...} methods instead. */
    public void calculateScores(Collection<Document> docs) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Evaluates expected annotations against the predicted annotation with maximum overlap.
     *
     * @param documents documents containing both expected and predicted annotations
     * @param annotationClass type of annotations to compare
     * @param matchAllPredicted if true, predicted annotations that were not matched by any
     *     expected annotation are additionally evaluated against their max-overlap expected one
     */
    public void calculateScoresFromAnnotations(Collection<Document> documents, Class<? extends Annotation> annotationClass, boolean matchAllPredicted) {
        // Identity semantics: a prediction counts as matched only if it is the very same object.
        IdentityHashMap<Annotation, Boolean> matched = new IdentityHashMap<>();
        this.countDocs += documents.size();
        for (Document doc : documents) {
            for (Annotation expected : doc.getAnnotations(this.expectedSource, annotationClass)) {
                Optional<? extends Annotation> predicted = doc.getAnnotationMaxOverlap(this.predictedSource, annotationClass, (Span) expected);
                if (!predicted.isPresent()) {
                    this.log.warn("Could not match predicted Annotation for expected Annotation {}-{}", expected.getBegin(), expected.getEnd());
                    continue;
                }
                Annotation match = predicted.get();
                matched.put(match, Boolean.TRUE);
                this.evalAnnotationPair(expected, match);
            }
            if (!matchAllPredicted) {
                continue;
            }
            for (Annotation predicted : doc.getAnnotations(this.predictedSource, annotationClass)) {
                if (matched.containsKey(predicted)) {
                    continue;
                }
                Optional<? extends Annotation> expected = doc.getAnnotationMaxOverlap(this.expectedSource, annotationClass, (Span) predicted);
                if (expected.isPresent()) {
                    this.evalAnnotationPair(expected.get(), predicted);
                }
            }
        }
    }

    /** Evaluates the encoder vectors of one expected/predicted annotation pair as an example. */
    private void evalAnnotationPair(Annotation expected, Annotation predicted) {
        INDArray r = expected.getVector(this.encoder.getClass()).transpose();
        INDArray p = predicted.getVector(this.encoder.getClass()).transpose();
        this.evalExample(r, p);
    }

    /**
     * Evaluates expected against predicted tags attached to every span of the given class.
     *
     * @param documents documents to evaluate
     * @param spanClass type of spans carrying the tags
     * @param tagClass type of tag to compare
     */
    public <T extends Tag> void calculateScoresFromTags(Collection<Document> documents, Class<? extends Span> spanClass, Class<T> tagClass) {
        this.countDocs += documents.size();
        for (Document doc : documents) {
            for (Span s : doc.getStream(spanClass).collect(Collectors.toList())) {
                INDArray r = s.getTag(this.expectedSource, tagClass).getVector().transpose();
                INDArray p = s.getTag(this.predictedSource, tagClass).getVector().transpose();
                this.evalExample(r, p);
            }
        }
    }

    /**
     * Evaluates a single example and adds it to all accumulated metrics.
     *
     * @param Y expected label vector; entries {@code > 0} mark relevant classes
     * @param Z predicted score vector
     */
    public void evalExample(INDArray Y, INDArray Z) {
        // Sort a flattened copy (dup) so that Z itself is not reordered by the sort.
        INDArray[] sorted = Nd4j.sortWithIndices(Nd4j.toFlattened(Z).dup(), 1, false);
        // sorted[0] is used as class indices ordered by descending score.
        INDArray Zi = sorted[0];
        // Indices summing to zero means every sort segment had length 1, i.e. wrong dimension.
        if (Zi.sumNumber().doubleValue() == 0.0) {
            this.log.warn("Sort on zero vector - please check vector dimensions!");
        }
        this.eval.eval(Y, Z);
        this.mapsum += this.AP(Y, Zi);
        this.mrrsum += this.RR(Y, Zi);
        this.p1sum += this.Prec(Y, Zi, 1);
        this.r1sum += this.Rec(Y, Zi, 1);
        this.pksum += this.Prec(Y, Zi, this.K);
        this.rksum += this.Rec(Y, Zi, this.K);
        ++this.countExamples;
    }

    /** @return {@code n / d}, or 0 if {@code d} is zero */
    protected double div(double n, double d) {
        if (d == 0.0) {
            return 0.0;
        }
        return n / d;
    }

    /**
     * Returns the 1-based rank of a class index within a ranked index list.
     *
     * @param idx class index to look up
     * @param l class indices ordered from best to worst
     * @return 1 if {@code idx} is the first entry, 2 if second, and so on
     * @throws IllegalArgumentException if {@code idx} does not occur in {@code l}
     */
    protected static int rank(int idx, INDArray l) {
        for (int i = 0; i < l.length(); ++i) {
            if (l.getInt(i) == idx) {
                return i + 1;
            }
        }
        throw new IllegalArgumentException("index does not exist in labels");
    }

    /** Reciprocal rank of the highest-valued expected class, or 0 if no label is positive. */
    private double RR(INDArray Y, INDArray Zi) {
        int relevant = ClassificationEvaluation.maxIndex(Y);
        if (relevant < 0) {
            return 0.0;
        }
        return 1.0 / ClassificationEvaluation.rank(relevant, Zi);
    }

    /**
     * Average precision: the mean of precision@k over every position k that holds a relevant
     * class. Returns 0 if no class is relevant. Supports multiple relevant classes.
     */
    private double AP(INDArray Y, INDArray Zi) {
        double sum = 0.0;
        int hits = 0;
        for (int k = 0; k < Y.length(); ++k) {
            if (Y.getDouble((long) Zi.getInt(k)) > 0.0) {
                ++hits;
                // precision at cut-off k+1, computed incrementally instead of re-scanning
                sum += (double) hits / (double) (k + 1);
            }
        }
        return hits > 0 ? sum / (double) hits : 0.0;
    }

    /** Number of relevant classes among the first {@code k} ranked indices. */
    private static int hitsAtK(INDArray Y, INDArray Zi, int k) {
        int hits = 0;
        for (int i = 0; i < k; ++i) {
            if (Y.getDouble((long) Zi.getInt(i)) > 0.0) {
                ++hits;
            }
        }
        return hits;
    }

    /** Precision at cut-off {@code k}. */
    private double Prec(INDArray Y, INDArray Zi, int k) {
        return (double) hitsAtK(Y, Zi, k) / (double) k;
    }

    /** Recall at cut-off {@code k}; 0 if the label vector sums to zero. */
    private double Rec(INDArray Y, INDArray Zi, int k) {
        double total = Y.sumNumber().doubleValue();
        if (total == 0.0) {
            return 0.0;
        }
        return hitsAtK(Y, Zi, k) / total;
    }

    /**
     * Returns the index of the first maximum entry.
     * <p>
     * The search starts at {@link Double#MIN_VALUE}, the smallest <em>positive</em> double. As a
     * result, -1 is returned when no entry is positive.
     */
    protected static int maxIndex(INDArray Y) {
        int idx = -1;
        double max = Double.MIN_VALUE;
        for (int i = 0; i < Y.length(); ++i) {
            double v = Y.getDouble((long) i);
            if (v > max) {
                max = v;
                idx = i;
            }
        }
        return idx;
    }

    /** @return top-1 accuracy from the confusion matrix */
    public double getAccuracy() {
        return this.eval.accuracy();
    }

    /** @return top-K accuracy as computed by ND4J */
    public double getAccuracyK() {
        return this.eval.topNAccuracy();
    }

    /** @return true positives divided by actual positives for class {@code c}, or 0 if none */
    protected double getAccuracy(int c) {
        return this.div(((Integer) this.eval.truePositives().get(c)).intValue(), ((Integer) this.eval.positive().get(c)).intValue());
    }

    /** @return micro-averaged precision */
    public double getMicroPrecision() {
        return this.eval.precision(EvaluationAveraging.Micro);
    }

    /** @return unweighted mean of per-class precision over all {@code numClasses} classes */
    public double getMacroPrecision() {
        double sum = 0.0;
        for (int c = 0; c < this.numClasses; ++c) {
            sum += this.getPrecision(c);
        }
        return sum / (double) this.numClasses;
    }

    /** @return precision of class {@code c} */
    protected double getPrecision(int c) {
        return this.eval.precision(Integer.valueOf(c));
    }

    /** @return micro-averaged recall */
    public double getMicroRecall() {
        return this.eval.recall(EvaluationAveraging.Micro);
    }

    /** @return unweighted mean of per-class recall over all {@code numClasses} classes */
    public double getMacroRecall() {
        double sum = 0.0;
        for (int c = 0; c < this.numClasses; ++c) {
            sum += this.getRecall(c);
        }
        return sum / (double) this.numClasses;
    }

    /** @return recall of class {@code c} */
    protected double getRecall(int c) {
        return this.eval.recall(c);
    }

    /** @return micro-averaged F1 */
    public double getMicroF1() {
        return this.eval.f1(EvaluationAveraging.Micro);
    }

    /** @return unweighted mean of per-class F1 over all {@code numClasses} classes */
    public double getMacroF1() {
        double sum = 0.0;
        for (int c = 0; c < this.numClasses; ++c) {
            sum += this.getF1(c);
        }
        return sum / (double) this.numClasses;
    }

    /** @return F1 of class {@code c} */
    protected double getF1(int c) {
        return this.eval.f1(c);
    }

    /** @return mean reciprocal rank (NaN if no examples were evaluated) */
    protected double getMRR() {
        return this.mrrsum / (double) this.countExamples;
    }

    /** @return mean average precision (NaN if no examples were evaluated) */
    public double getMAP() {
        return this.mapsum / (double) this.countExamples;
    }

    /** @return mean precision at K (NaN if no examples were evaluated) */
    public double getPrecisionK() {
        return this.pksum / (double) this.countExamples;
    }

    /** @return mean recall at K (NaN if no examples were evaluated) */
    public double getRecallK() {
        return this.rksum / (double) this.countExamples;
    }

    /** @return mean precision at 1 (NaN if no examples were evaluated) */
    public double getPrecision1() {
        return this.p1sum / (double) this.countExamples;
    }

    /** @return mean recall at 1 (NaN if no examples were evaluated) */
    public double getRecall1() {
        return this.r1sum / (double) this.countExamples;
    }

    /** Evaluates every row of {@code labels} against the same row of {@code networkPredictions}. */
    public void eval(INDArray labels, INDArray networkPredictions) {
        for (int i = 0; i < labels.rows(); ++i) {
            this.evalExample(labels.getRow((long) i), networkPredictions.getRow((long) i));
        }
    }

    /** Same as {@link #eval(INDArray, INDArray)}; record metadata is ignored. */
    public void eval(INDArray labels, INDArray networkPredictions, List<? extends Serializable> recordMetaData) {
        this.eval(labels, networkPredictions);
    }

    /**
     * Evaluates with an optional mask.
     * <p>
     * Only rank-3 time series with a rank-2 mask are supported. Per-output masking throws
     * {@link UnsupportedOperationException}.
     */
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
        if (maskArray == null) {
            if (labels.rank() == 3) {
                this.evalTimeSeries(labels, networkPredictions, null);
            } else {
                this.eval(labels, networkPredictions);
            }
            return;
        }
        if (labels.rank() == 3 && maskArray.rank() == 2) {
            this.evalTimeSeries(labels, networkPredictions, maskArray);
            return;
        }
        throw new UnsupportedOperationException(this.getClass().getSimpleName() + " does not support per-output masking");
    }

    /** Evaluates an unmasked time series. */
    public void evalTimeSeries(INDArray labels, INDArray predicted) {
        this.evalTimeSeries(labels, predicted, null);
    }

    /** Evaluates all non-masked time steps of a time series as individual examples. */
    public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
        Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);
        if (pair == null) {
            return;
        }
        this.eval(pair.getFirst(), pair.getSecond());
    }

    /**
     * Merges the results of another evaluation into this one.
     *
     * @param other evaluation to merge; {@code null} is ignored
     * @throws IllegalArgumentException if {@code other} uses a different K
     */
    public void merge(ClassificationEvaluation other) {
        if (other == null) {
            return;
        }
        if (other.K != this.K) {
            throw new IllegalArgumentException("Cannot merge evaluations with different K: " + this.K + " vs " + other.K);
        }
        this.eval.merge(other.eval);
        this.countDocs += other.countDocs;
        this.countExamples += other.countExamples;
        this.mrrsum += other.mrrsum;
        this.mapsum += other.mapsum;
        this.p1sum += other.p1sum;
        this.r1sum += other.r1sum;
        this.pksum += other.pksum;
        this.rksum += other.rksum;
    }

    /** Resets this evaluation to its initial, empty state. */
    public void reset() {
        this.clear();
    }

    public String stats() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    public String toJson() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    public String toYaml() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** @return a tab-separated header and value line with Acc, P and R at 1 and K, plus MAP */
    public String printClassificationAtKStats() {
        StringBuilder line = new StringBuilder();
        line.append(" Acc@1\t Acc@").append(this.K).append("\t P@1\t P@").append(this.K).append("\t R@1\t R@").append(this.K).append("\t MAP\n");
        line.append(fDbl(this.getAccuracy())).append("\t");
        line.append(fDbl(this.getAccuracyK())).append("\t");
        line.append(fDbl(this.getPrecision1())).append("\t");
        line.append(fDbl(this.getPrecisionK())).append("\t");
        line.append(fDbl(this.getRecall1())).append("\t");
        line.append(fDbl(this.getRecallK())).append("\t");
        line.append(fDbl(this.getMAP())).append("\t");
        line.append("\n");
        return line.toString();
    }

    /** @return a tab-separated header and value line with counts, MRR, MAP and macro metrics */
    public String printClassificationStats() {
        StringBuilder line = new StringBuilder();
        line.append(" count\t TP\t FP\t MRR\t P@1\t MAP\t mPrec\t mRec\t mF1\n");
        line.append(fInt((double) this.countExamples())).append("\t");
        line.append(fInt((double) this.eval.getTruePositives().totalCount())).append("\t");
        line.append(fInt((double) this.eval.getFalsePositives().totalCount())).append("\t");
        // MRR shares MAP's [0,1] range, so it is formatted the same way (no extra scaling).
        line.append(fDbl(this.getMRR())).append("\t");
        // NOTE(review): the "P@1" column prints top-1 accuracy, not getPrecision1(); these are
        // the same only for single-label targets - confirm which one is intended.
        line.append(fDbl(this.getAccuracy())).append("\t");
        line.append(fDbl(this.getMAP())).append("\t");
        line.append(fDbl(this.getMacroPrecision())).append("\t");
        line.append(fDbl(this.getMacroRecall())).append("\t");
        line.append(fDbl(this.getMacroF1())).append("\t");
        line.append("\n");
        return line.toString();
    }
}

