/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.confidence;

import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.Segment;
import cc.mallet.fst.Transducer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import java.io.PrintStream;
import java.util.ArrayList;

/**
 * Evaluates how correcting one segment per sequence (as recorded in a list of corrected
 * segments) changes token accuracy, with separate statistics for the tokens outside the
 * corrected segment ("uncorrected region"), i.e. the tokens only affected by propagation.
 */
public class ConfidenceCorrectorEvaluator {
    // Stored by the constructor; no method in this class reads them.
    Object[] startTags;
    Object[] inTags;

    /**
     * Creates an evaluator.
     *
     * @param startTags tag labels stored on the instance (not used by this class's methods)
     * @param inTags tag labels stored on the instance (not used by this class's methods)
     */
    public ConfidenceCorrectorEvaluator(Object[] startTags, Object[] inTags) {
        this.startTags = startTags;
        this.inTags = inTags;
    }

    /**
     * Scans positions left to right and returns true at the first position outside
     * {@code correctedSegment} where {@code predSequence} disagrees with {@code trueSequence}.
     * Positions inside the segment that are scanned before that point are sanity-checked:
     * the corrected sequence must match the truth there.
     *
     * @return true if the prediction has an error outside the corrected segment
     * @throws IllegalStateException if a scanned position inside the segment has a corrected
     *     label differing from the true label; the three sequences are dumped to stderr first
     */
    private boolean containsErrorInUncorrectedSegments(Sequence trueSequence, Sequence predSequence, Sequence correctedSequence, Segment correctedSegment) {
        for (int i = 0; i < trueSequence.size(); ++i) {
            if (correctedSegment.indexInSegment(i)) {
                if (!correctedSequence.get(i).equals(trueSequence.get(i))) {
                    System.err.println("\nTruth: ");
                    printSequenceToStderr(trueSequence);
                    System.err.println("\nPredicted: ");
                    printSequenceToStderr(predSequence);
                    System.err.println("\nCorrected: ");
                    printSequenceToStderr(correctedSequence);
                    throw new IllegalStateException("Corrected sequence does not have correct labels for corrected segment: " + correctedSegment);
                }
            } else if (!predSequence.get(i).equals(trueSequence.get(i))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prints every element of {@code seq} to stderr, each followed by a space. Iterates by the
     * sequence's own length so a length mismatch cannot mask the error being reported.
     */
    private static void printSequenceToStderr(Sequence seq) {
        for (int j = 0; j < seq.size(); ++j) {
            System.err.print(seq.get(j) + " ");
        }
    }

    /**
     * Compares the model's Viterbi output against the supplied corrected sequences and prints
     * token-level statistics to {@code outputStream}.
     *
     * <p>Only instances with a non-null corrected segment are evaluated; when
     * {@code errorsInUncorrected} is true, they are further restricted to instances whose
     * Viterbi prediction has at least one error outside the corrected segment. Accuracies
     * are printed as NaN when no tokens (or no uncorrected-region tokens) were evaluated.
     *
     * @param model transducer decoded with {@link MaxLatticeDefault} to get the prediction
     * @param predictions corrected {@link Sequence}s, parallel to {@code ilist}
     * @param ilist instances whose data is the input sequence and whose target is the truth
     * @param correctedSegments corrected {@link Segment} per instance, or null when none
     * @param description header line printed before the statistics
     * @param outputStream destination of the report
     * @param errorsInUncorrected restrict evaluation to sequences with an uncorrected error
     * @throws IllegalArgumentException if the list sizes do not match {@code ilist}
     */
    public void evaluate(Transducer model, ArrayList predictions, InstanceList ilist, ArrayList correctedSegments, String description, PrintStream outputStream, boolean errorsInUncorrected) {
        if (predictions.size() != ilist.size() || correctedSegments.size() != ilist.size()) {
            throw new IllegalArgumentException("number of predicted sequence (" + predictions.size() + ") and number of corrected segments (" + correctedSegments.size() + ") must be equal to length of instancelist (" + ilist.size() + ")");
        }
        int numPropagatedIncorrect2Correct = 0;
        int numPropagatedCorrect2Incorrect = 0;
        int numPredictedCorrect = 0;
        int numCorrectedCorrect = 0;
        int numUncorrectedCorrectBeforePropagation = 0;
        int numUncorrectedCorrectAfterPropagation = 0;
        int totalTokens = 0;
        int totalTokensInUncorrectedRegion = 0;
        int numCorrectedSequences = 0;
        for (int i = 0; i < ilist.size(); ++i) {
            Segment correctedSegment = (Segment) correctedSegments.get(i);
            if (correctedSegment == null) {
                // Nothing was corrected here; skip the costly Viterbi decode entirely.
                continue;
            }
            Instance instance = (Instance) ilist.get(i);
            Sequence input = (Sequence) instance.getData();
            Sequence trueSequence = (Sequence) instance.getTarget();
            Sequence<Object> predSequence = new MaxLatticeDefault(model, input).bestOutputSequence();
            Sequence correctedSequence = (Sequence) predictions.get(i);
            if (errorsInUncorrected && !this.containsErrorInUncorrectedSegments(trueSequence, predSequence, correctedSequence, correctedSegment)) {
                continue;
            }
            ++numCorrectedSequences;
            totalTokens += trueSequence.size();
            boolean[] predictedMatches = this.getMatches(trueSequence, predSequence);
            boolean[] correctedMatches = this.getMatches(trueSequence, correctedSequence);
            for (int j = 0; j < predictedMatches.length; ++j) {
                if (predictedMatches[j]) {
                    ++numPredictedCorrect;
                }
                if (correctedMatches[j]) {
                    ++numCorrectedCorrect;
                }
                boolean inUncorrectedRegion = j < correctedSegment.getStart() || j > correctedSegment.getEnd();
                if (!inUncorrectedRegion) {
                    continue;
                }
                // Tokens outside the corrected segment change only through propagation.
                ++totalTokensInUncorrectedRegion;
                if (predictedMatches[j]) {
                    ++numUncorrectedCorrectBeforePropagation;
                }
                if (correctedMatches[j]) {
                    ++numUncorrectedCorrectAfterPropagation;
                }
                if (!predictedMatches[j] && correctedMatches[j]) {
                    ++numPropagatedIncorrect2Correct;
                } else if (predictedMatches[j] && !correctedMatches[j]) {
                    ++numPropagatedCorrect2Incorrect;
                }
            }
        }
        // Double division: yields NaN (not an exception) when a denominator is 0.
        double tokenAccuracyBeforeCorrection = (double) numPredictedCorrect / (double) totalTokens;
        double tokenAccuracyAfterCorrection = (double) numCorrectedCorrect / (double) totalTokens;
        double uncorrectedRegionAccuracyBeforeCorrection = (double) numUncorrectedCorrectBeforePropagation / (double) totalTokensInUncorrectedRegion;
        double uncorrectedRegionAccuracyAfterCorrection = (double) numUncorrectedCorrectAfterPropagation / (double) totalTokensInUncorrectedRegion;
        outputStream.println(String.valueOf(description) + "\nEvaluating effect of error-propagation in sequences containing at least one token error:" + "\ntotal number correctedsequences: " + numCorrectedSequences + "\ntotal number tokens: " + totalTokens + "\ntotal number tokens in \"uncorrected region\":" + totalTokensInUncorrectedRegion + "\ntotal number correct tokens before correction:" + numPredictedCorrect + "\ntotal number correct tokens after correction:" + numCorrectedCorrect + "\ntoken accuracy before correction: " + tokenAccuracyBeforeCorrection + "\ntoken accuracy after correction: " + tokenAccuracyAfterCorrection + "\nnumber tokens corrected by propagation: " + numPropagatedIncorrect2Correct + "\nnumber tokens made incorrect by propagation: " + numPropagatedCorrect2Incorrect + "\ntoken accuracy of \"uncorrected region\" before propagation: " + uncorrectedRegionAccuracyBeforeCorrection + "\ntoken accuracy of \"uncorrected region\" after propagataion: " + uncorrectedRegionAccuracyAfterCorrection);
    }

    /**
     * Compares two sequences position by position using {@code equals}.
     *
     * @return array whose element i is true iff {@code s1.get(i)} equals {@code s2.get(i)}
     * @throws IllegalArgumentException if the sequences differ in length
     */
    private boolean[] getMatches(Sequence s1, Sequence s2) {
        if (s1.size() != s2.size()) {
            throw new IllegalArgumentException("s1.size: " + s1.size() + " s2.size: " + s2.size());
        }
        boolean[] ret = new boolean[s1.size()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = s1.get(i).equals(s2.get(i));
        }
        return ret;
    }
}

