/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.ner.eval;

import de.datexis.evaluation.ModelEvaluation;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Token;
import de.datexis.model.tag.BIO2Tag;
import de.datexis.model.tag.Tag;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.evaluation.classification.ConfusionMatrix;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Counter;

public class MentionTaggerEval
extends ModelEvaluation {
    protected int classes;
    protected Tag tagset;
    protected Evaluation eval;
    protected double accuracy;
    protected double precision;
    protected double recall;
    protected double f1;
    private ArrayList<Integer> examplesCurve;
    private ArrayList<Double> precisionCurve;
    private ArrayList<Double> recallCurve;
    private ArrayList<Double> f1Curve;
    private ArrayList<Double> errorCurve;
    Annotation.Source expectedSource;
    Annotation.Source predictedSource;

    public MentionTaggerEval(String experimentName) {
        this(experimentName, BIO2Tag.class);
    }

    public MentionTaggerEval(String experimentName, Class tagset) {
        this(experimentName, tagset, Annotation.Source.GOLD, Annotation.Source.PRED);
    }

    public MentionTaggerEval(String experimentName, Class tagset, Annotation.Source expected, Annotation.Source predicted) {
        super(experimentName);
        try {
            this.tagset = (Tag)tagset.newInstance();
        }
        catch (IllegalAccessException | InstantiationException reflectiveOperationException) {
            // empty catch block
        }
        this.classes = this.tagset.getVectorSize();
        this.expectedSource = expected;
        this.predictedSource = predicted;
    }

    public void clear() {
        super.clear();
        this.eval = new Evaluation(this.classes);
        this.examplesCurve = new ArrayList();
        this.precisionCurve = new ArrayList();
        this.recallCurve = new ArrayList();
        this.f1Curve = new ArrayList();
        this.errorCurve = new ArrayList();
    }

    public void eval(Token t, INDArray expected, INDArray predicted, boolean print) {
        this.eval.eval(expected, predicted);
        if (print) {
            System.out.println(t.getText() + "\t" + expected + "\t" + predicted);
        }
    }

    public void evalTimeSeries(INDArray expected, INDArray predicted) {
        this.eval.evalTimeSeries(expected, predicted);
    }

    public void evalTimeSeries(INDArray expected, INDArray predicted, INDArray labelsMask) {
        if (expected.shape()[2] == 1L) {
            this.eval.evalTimeSeries(expected.transpose(), predicted.transpose());
        } else {
            this.eval.evalTimeSeries(expected, predicted, labelsMask);
        }
    }

    public void evalTimeSeries(Evaluation ev) {
        this.eval = ev;
    }

    public void appendTrainingCurve(double precision, double recall, double f1) {
        this.examplesCurve.add(0);
        this.precisionCurve.add(precision);
        this.recallCurve.add(recall);
        this.f1Curve.add(f1);
        this.errorCurve.add(0.0);
    }

    public void appendTrainingCurve(int examples, double precision, double recall, double f1, double error) {
        this.examplesCurve.add(examples);
        this.precisionCurve.add(precision);
        this.recallCurve.add(recall);
        this.f1Curve.add(f1);
        this.errorCurve.add(error);
    }

    public void calculateMeasures(Dataset test) {
        for (int c = 0; c < this.classes; ++c) {
            double tp = 0.0;
            double fp = 0.0;
            double tn = 0.0;
            double fn = 0.0;
            for (Token t : test.streamTokens().collect(Collectors.toList())) {
                String g = t.getTag(this.expectedSource, this.tagset.getClass()).getTag();
                String p = t.getTag(this.predictedSource, this.tagset.getClass()).getTag();
                String cl = this.tagset.getTag(c);
                if (g.equals(cl) && p.equals(cl)) {
                    tp += 1.0;
                }
                if (!g.equals(cl) && p.equals(cl)) {
                    fp += 1.0;
                }
                if (!g.equals(cl) && !p.equals(cl)) {
                    tn += 1.0;
                }
                if (!g.equals(cl) || p.equals(cl)) continue;
                fn += 1.0;
            }
            ((Counter)this.counts.get(ModelEvaluation.Measure.TP)).setCount((Object)c, tp);
            ((Counter)this.counts.get(ModelEvaluation.Measure.FP)).setCount((Object)c, fp);
            ((Counter)this.counts.get(ModelEvaluation.Measure.TN)).setCount((Object)c, tn);
            ((Counter)this.counts.get(ModelEvaluation.Measure.FN)).setCount((Object)c, fn);
        }
    }

    private void calculateMeasures(Evaluation eval) {
        ConfusionMatrix m = eval.getConfusionMatrix();
        for (int c = 0; c < this.classes; ++c) {
            double tp = 0.0;
            double fp = 0.0;
            double tn = 0.0;
            double fn = 0.0;
            for (int p = 0; p < this.classes; ++p) {
                for (int g = 0; g < this.classes; ++g) {
                    int x = m.getCount((Comparable)Integer.valueOf(g), (Comparable)Integer.valueOf(p));
                    if (p == c && g == c) {
                        tp += (double)x;
                    }
                    if (p == c && g != c) {
                        fp += (double)x;
                    }
                    if (p != c && g != c) {
                        tn += (double)x;
                    }
                    if (p == c || g != c) continue;
                    fn += (double)x;
                }
            }
            ((Counter)this.counts.get(ModelEvaluation.Measure.TP)).setCount((Object)c, tp);
            ((Counter)this.counts.get(ModelEvaluation.Measure.FP)).setCount((Object)c, fp);
            ((Counter)this.counts.get(ModelEvaluation.Measure.TN)).setCount((Object)c, tn);
            ((Counter)this.counts.get(ModelEvaluation.Measure.FN)).setCount((Object)c, fn);
        }
    }

    public String printSequenceStats() {
        StringBuilder line = new StringBuilder();
        line.append("SEQUENCE Training per Config [macro-avg]\t\t\t\tTrain Time [ms]\t\t\tTest Time [ms]\n").append("Conf\t\t#EncMiss\t#TP\t#FP\t#TN\t#FN\tAcc\tPrec\tRec\tF1\t#Docs\t#Sents\t#Tokens\tTotal\tDoc\tSent\t#Docs\t#Sents\t#Tokens\tTotal\tDoc\tSent\n");
        System.out.println(line.toString());
        return line.toString();
    }

    public String printSequenceClassStats() {
        return this.printSequenceClassStats(true);
    }

    public String printSequenceClassStats(boolean calculate) {
        if (calculate) {
            this.calculateMeasures(this.eval);
        }
        StringBuilder line = new StringBuilder();
        line.append("SEQUENCE Labeling per Class [macro-avg]\n").append("Class\t#Tokns\t#Enc\t    TP\t    FP\t    TN\t    FN\tAcc\tPrec\tRec\tF1\n");
        double acc = 0.0;
        double pre = 0.0;
        double rec = 0.0;
        for (int c = 0; c < this.classes; ++c) {
            line.append(this.tagset.getTag(c)).append("\t");
            line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.TP)).getCount((Object)c) + ((Counter)this.counts.get(ModelEvaluation.Measure.FN)).getCount((Object)c))).append("\t");
            line.append("\t");
            line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.TP)).getCount((Object)c))).append("\t");
            line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.FP)).getCount((Object)c))).append("\t");
            line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.TN)).getCount((Object)c))).append("\t");
            line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.FN)).getCount((Object)c))).append("\t");
            line.append(this.fDbl(this.getAccuracy(c))).append("\t");
            line.append(this.fDbl(this.getPrecision(c))).append("\t");
            line.append(this.fDbl(this.getRecall(c))).append("\t");
            line.append(this.fDbl(this.getF1(c))).append("\t");
            line.append("\n");
            acc += this.getAccuracy(c);
            pre += this.getPrecision(c);
            rec += this.getRecall(c);
        }
        acc /= (double)this.classes;
        pre /= (double)this.classes;
        rec /= (double)this.classes;
        line.append("Total\t");
        line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.TP)).totalCount() + ((Counter)this.counts.get(ModelEvaluation.Measure.FN)).totalCount())).append("\t");
        line.append("\t");
        line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.TP)).totalCount())).append("\t");
        line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.FP)).totalCount())).append("\t");
        line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.TN)).totalCount())).append("\t");
        line.append(this.fInt(((Counter)this.counts.get(ModelEvaluation.Measure.FN)).totalCount())).append("\t");
        line.append(this.fDbl(acc)).append("\t");
        line.append(this.fDbl(pre)).append("\t");
        line.append(this.fDbl(rec)).append("\t");
        line.append(this.fDbl(this.getF1(pre, rec))).append("\t");
        line.append("\n");
        System.out.println(line.toString());
        return line.toString();
    }

    public String printTrainingCurve() {
        StringBuilder line = new StringBuilder();
        line.append("#\tCount\tPrec\tRec\tF1\tError\n");
        for (int i = 0; i < this.f1Curve.size(); ++i) {
            line.append(i).append("\t");
            line.append(this.fInt(this.examplesCurve.get(i).intValue())).append("\t");
            line.append(this.fDbl(this.precisionCurve.get(i))).append("\t");
            line.append(this.fDbl(this.recallCurve.get(i))).append("\t");
            line.append(this.fDbl(this.f1Curve.get(i))).append("\t");
            line.append(this.fDbl(this.errorCurve.get(i) / 100.0));
            line.append("\n");
        }
        return line.toString();
    }

    private double getAccuracy(int c) {
        return this.div(this.seqL(ModelEvaluation.Measure.TP, c) + this.seqL(ModelEvaluation.Measure.TN, c), this.seqL(ModelEvaluation.Measure.TP, c) + this.seqL(ModelEvaluation.Measure.TN, c) + this.seqL(ModelEvaluation.Measure.FP, c) + this.seqL(ModelEvaluation.Measure.FN, c));
    }

    private double getPrecision(int c) {
        return this.div(this.seqL(ModelEvaluation.Measure.TP, c), this.seqL(ModelEvaluation.Measure.TP, c) + this.seqL(ModelEvaluation.Measure.FP, c));
    }

    private double getRecall(int c) {
        return this.div(this.seqL(ModelEvaluation.Measure.TP, c), this.seqL(ModelEvaluation.Measure.TP, c) + this.seqL(ModelEvaluation.Measure.FN, c));
    }

    private double getF1(int c) {
        return this.getF1(this.getPrecision(c), this.getRecall(c));
    }

    private double getF1(double precision, double recall) {
        if (precision + recall == 0.0) {
            return 0.0;
        }
        return 2.0 * precision * recall / (precision + recall);
    }
}

