/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.extract;

import cc.mallet.extract.ExactMatchComparator;
import cc.mallet.extract.Extraction;
import cc.mallet.extract.ExtractionEvaluator;
import cc.mallet.extract.Field;
import cc.mallet.extract.FieldComparator;
import cc.mallet.extract.Record;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.MatrixOps;
import java.io.OutputStream;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.Iterator;

/**
 * {@link ExtractionEvaluator} that prints, for every label in the extraction's
 * {@link LabelAlphabet}, segment counts (correct / predicted / target) and
 * precision, recall and F1, followed by micro- and macro-averaged summaries.
 *
 * <p>A predicted field value counts as correct when the target record has a field
 * with the same name whose {@code isValue(value, comparator)} accepts it, using the
 * configured {@link FieldComparator} (an {@link ExactMatchComparator} by default).
 */
public class PerFieldF1Evaluator
implements ExtractionEvaluator {
    private FieldComparator comparator = new ExactMatchComparator();
    /** Destination for per-mismatch error reports; {@code null} disables them. */
    private PrintStream errorOutputStream = null;

    /** Returns the comparator used to decide whether a predicted value matches the target. */
    public FieldComparator getComparator() {
        return this.comparator;
    }

    /** Sets the comparator used to decide whether a predicted value matches the target. */
    public void setComparator(FieldComparator comparator) {
        this.comparator = comparator;
    }

    /** Returns the stream mismatches are reported to, or {@code null} if reporting is off. */
    public PrintStream getErrorOutputStream() {
        return this.errorOutputStream;
    }

    /**
     * Sets the stream that incorrect predictions are reported to.
     *
     * @param errorOutputStream destination for error reports; {@code null} disables
     *     reporting. A {@link PrintStream} is used as-is instead of being wrapped again.
     */
    public void setErrorOutputStream(OutputStream errorOutputStream) {
        if (errorOutputStream == null) {
            // Previously new PrintStream(null) threw an NPE, making it impossible to turn reporting off.
            this.errorOutputStream = null;
        } else if (errorOutputStream instanceof PrintStream) {
            this.errorOutputStream = (PrintStream) errorOutputStream;
        } else {
            this.errorOutputStream = new PrintStream(errorOutputStream);
        }
    }

    /** Evaluates {@code extraction} with an empty description, writing results to {@code System.out}. */
    @Override
    public void evaluate(Extraction extraction) {
        this.evaluate("", extraction, System.out);
    }

    /**
     * Scores every document of {@code extraction} against its target record and prints
     * per-label counts, per-label P/R/F1, and micro/macro-averaged summaries.
     *
     * @param description prefix for the section headings
     * @param extraction predicted records paired with their target records
     * @param out stream the report is written to
     */
    public void evaluate(String description, Extraction extraction, PrintStream out) {
        int numDocs = extraction.getNumDocuments();
        assert (numDocs == extraction.getNumRecords());
        LabelAlphabet dict = extraction.getLabelAlphabet();
        int numLabels = dict.size();
        int[] numCorr = new int[numLabels];
        int[] numPred = new int[numLabels];
        int[] numTrue = new int[numLabels];

        for (int docnum = 0; docnum < numDocs; docnum++) {
            Record extracted = extraction.getRecord(docnum);
            Record target = extraction.getTargetRecord(docnum);
            this.countDocument(extracted, target, numCorr, numPred, numTrue);
        }

        printCounts(description, dict, numCorr, numPred, numTrue, out);
        printScores(description, dict, numCorr, numPred, numTrue, out);
    }

    /** Adds one document's predicted, correct and target value counts to the per-label tallies. */
    private void countDocument(Record extracted, Record target,
                               int[] numCorr, int[] numPred, int[] numTrue) {
        for (Iterator<?> it = extracted.fieldsIterator(); it.hasNext(); ) {
            Field predField = (Field) it.next();
            Label name = predField.getName();
            Field trueField = target.getField(name);
            int idx = name.getIndex();
            for (int j = 0; j < predField.numValues(); j++) {
                numPred[idx]++;
                if (trueField != null && trueField.isValue(predField.value(j), this.comparator)) {
                    numCorr[idx]++;
                } else if (this.errorOutputStream != null) {
                    this.errorOutputStream.println("Error in extraction!");
                    this.errorOutputStream.println("Predicted " + predField);
                    this.errorOutputStream.println("True " + trueField);
                    this.errorOutputStream.println();
                }
            }
        }
        for (Iterator<?> it = target.fieldsIterator(); it.hasNext(); ) {
            Field trueField = (Field) it.next();
            numTrue[trueField.getName().getIndex()] += trueField.numValues();
        }
    }

    /** Prints the raw correct / predicted / target counts for every label. */
    private static void printCounts(String description, LabelAlphabet dict,
                                    int[] numCorr, int[] numPred, int[] numTrue, PrintStream out) {
        out.println(description + " SEGMENT counts");
        out.println("Name\tCorrect\tPred\tTarget");
        for (int i = 0; i < numCorr.length; i++) {
            Label name = dict.lookupLabel(i);
            out.println(name + "\t" + numCorr[i] + "\t" + numPred[i] + "\t" + numTrue[i]);
        }
        out.println();
    }

    /** Prints per-label P/R/F1 and the micro- and macro-averaged summaries. */
    private static void printScores(String description, LabelAlphabet dict,
                                    int[] numCorr, int[] numPred, int[] numTrue, PrintStream out) {
        DecimalFormat f = new DecimalFormat("0.####");
        double totalF1 = 0.0;
        int totalFields = 0;
        out.println(description + " per-field F1");
        out.println("Name\tP\tR\tF1");
        for (int i = 0; i < numCorr.length; i++) {
            double p = precisionOf(numCorr[i], numPred[i]);
            double r = recallOf(numCorr[i], numTrue[i]);
            double fieldF1 = harmonicMean(p, r);
            // Labels that appear in neither predictions nor targets are excluded from the macro average.
            if (numPred[i] > 0 || numTrue[i] > 0) {
                totalF1 += fieldF1;
                totalFields++;
            }
            Label name = dict.lookupLabel(i);
            out.println(name + "\t" + f.format(p) + "\t" + f.format(r) + "\t" + f.format(fieldF1));
        }

        int totalCorr = MatrixOps.sum(numCorr);
        int totalPred = MatrixOps.sum(numPred);
        int totalTrue = MatrixOps.sum(numTrue);
        // Use the same zero-denominator conventions as the per-field scores instead of printing NaN.
        double microP = precisionOf(totalCorr, totalPred);
        double microR = recallOf(totalCorr, totalTrue);
        double microF1 = harmonicMean(microP, microR);
        double macroF1 = totalFields == 0 ? 0.0 : totalF1 / totalFields;
        out.println("OVERALL (micro-averaged) P=" + f.format(microP) + " R=" + f.format(microR) + " F1=" + f.format(microF1));
        out.println("OVERALL (macro-averaged) F1=" + f.format(macroF1));
        out.println();
    }

    /** Precision = correct / predicted; 0 when nothing was predicted. */
    private static double precisionOf(int correct, int predicted) {
        return predicted == 0 ? 0.0 : (double) correct / (double) predicted;
    }

    /** Recall = correct / target; 1 when the target contains nothing to find. */
    private static double recallOf(int correct, int target) {
        return target == 0 ? 1.0 : (double) correct / (double) target;
    }

    /** Harmonic mean of precision and recall (F1); 0 when both are 0. */
    private static double harmonicMean(double p, double r) {
        return p + r == 0.0 ? 0.0 : 2.0 * p * r / (p + r);
    }
}

