/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.confidence;

import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.Segment;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.confidence.ConstrainedForwardBackwardConfidenceEstimator;
import cc.mallet.fst.confidence.TransducerConfidenceEstimator;
import cc.mallet.fst.confidence.TransducerCorrector;
import cc.mallet.types.ArraySequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Vector;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * {@link TransducerCorrector} that simulates correcting one segment per instance and then
 * re-decodes the whole sequence with that correction held fixed.
 *
 * <p>Processing for each instance:
 * <ul>
 *   <li>If the Viterbi prediction of {@link #model} contains at least one wrong label, the
 *       segments are ranked with {@link #confidenceEstimator}.</li>
 *   <li>One segment is chosen: the first ranked one, or the first incorrect one when
 *       {@code findIncorrect} is set.</li>
 *   <li>The chosen segment's true labels are passed as a partial output sequence to
 *       {@link MaxLatticeDefault}, which decodes the sequence again.</li>
 * </ul>
 */
public class ConstrainedViterbiTransducerCorrector
implements TransducerCorrector {
    private static final Logger logger = MalletLogger.getLogger(ConstrainedViterbiTransducerCorrector.class.getName());

    /** Ranks the segments of an instance; the first ranked segment is treated as least confident. */
    TransducerConfidenceEstimator confidenceEstimator;
    /** Model used both for the initial prediction and for the constrained re-decoding. */
    Transducer model;
    /**
     * Filled by {@link #correctLeastConfidentSegments(InstanceList, Object[], Object[], boolean)}.
     * Holds one entry per instance: the corrected segment, or {@code null} when nothing was
     * corrected. The field is {@code null} until that method has been called.
     */
    ArrayList<Segment> leastConfidentSegments;

    /**
     * @param confidenceEstimator estimator used to rank segments by confidence
     * @param model transducer used for decoding
     */
    public ConstrainedViterbiTransducerCorrector(TransducerConfidenceEstimator confidenceEstimator, Transducer model) {
        this.confidenceEstimator = confidenceEstimator;
        this.model = model;
    }

    /**
     * Creates a corrector that ranks segments with a
     * {@link ConstrainedForwardBackwardConfidenceEstimator} built from {@code model}.
     *
     * @param model transducer used for decoding and confidence estimation
     */
    public ConstrainedViterbiTransducerCorrector(Transducer model) {
        this(new ConstrainedForwardBackwardConfidenceEstimator(model), model);
    }

    /** Returns the segment confidences held by the underlying confidence estimator. */
    public Vector getSegmentConfidences() {
        return this.confidenceEstimator.getSegmentConfidences();
    }

    /**
     * Returns the per-instance segments chosen by the most recent correction run.
     *
     * @return the segments chosen by correctLeastConfidentSegments ({@code null} entries mean
     *     no correction), or {@code null} if that method has not been called yet
     */
    public ArrayList getLeastConfidentSegments() {
        return this.leastConfidentSegments;
    }

    /**
     * Returns, for each instance, the first segment in its confidence ranking, without
     * performing any correction.
     *
     * @return one entry per instance; an entry is {@code null} when the estimator ranks no
     *     segments for that instance
     */
    public ArrayList getLeastConfidentSegments(InstanceList ilist, Object[] startTags, Object[] continueTags) {
        ArrayList<Segment> ret = new ArrayList<Segment>(ilist.size());
        for (int i = 0; i < ilist.size(); i++) {
            Segment[] orderedSegments = this.confidenceEstimator.rankSegmentsByConfidence((Instance) ilist.get(i), startTags, continueTags);
            ret.add(firstOrNull(orderedSegments));
        }
        return ret;
    }

    /** Equivalent to {@code correctLeastConfidentSegments(ilist, startTags, continueTags, false)}. */
    @Override
    public ArrayList correctLeastConfidentSegments(InstanceList ilist, Object[] startTags, Object[] continueTags) {
        return this.correctLeastConfidentSegments(ilist, startTags, continueTags, false);
    }

    /**
     * Corrects one segment per mispredicted instance and propagates the correction by
     * constrained re-decoding. Also resets and fills {@link #leastConfidentSegments}.
     *
     * @param ilist instances whose data is the input {@link Sequence} and whose target is the
     *     true label {@link Sequence}; true labels are cast to {@code String} when building the
     *     constraint sequence
     * @param startTags tags forwarded to the confidence estimator (presumably segment-start labels)
     * @param continueTags tags forwarded to the confidence estimator (presumably continuation labels)
     * @param findIncorrect if true, correct the highest-ranked segment that is not already
     *     correct. If no such segment exists, a warning is logged and the instance is left
     *     uncorrected.
     * @return one sequence per instance: the corrected prediction, or the original prediction
     *     when nothing was corrected
     */
    public ArrayList correctLeastConfidentSegments(InstanceList ilist, Object[] startTags, Object[] continueTags, boolean findIncorrect) {
        ArrayList<Sequence<Object>> correctedPredictionList = new ArrayList<Sequence<Object>>(ilist.size());
        this.leastConfidentSegments = new ArrayList<Segment>(ilist.size());
        logger.info(this.getClass().getName() + " ranking confidence using " + this.confidenceEstimator.getClass().getName());
        for (int i = 0; i < ilist.size(); i++) {
            logger.fine("correcting instance# " + i + " / " + ilist.size());
            Instance instance = (Instance) ilist.get(i);
            Sequence input = (Sequence) instance.getData();
            Sequence truth = (Sequence) instance.getTarget();
            Sequence<Object> predicted = new MaxLatticeDefault(this.model, input).bestOutputSequence();
            int numIncorrect = countMismatches(predicted, truth, predicted.size());
            if (numIncorrect == 0) {
                // Prediction already matches the truth: nothing to correct.
                this.leastConfidentSegments.add(null);
                correctedPredictionList.add(predicted);
                continue;
            }

            Segment[] orderedSegments = this.confidenceEstimator.rankSegmentsByConfidence(instance, startTags, continueTags);
            if (logger.isLoggable(Level.FINE)) {
                logger.fine("Ordered Segments:\n");
                if (orderedSegments != null) {
                    for (Segment segment : orderedSegments) {
                        logger.fine(String.valueOf(segment));
                    }
                }
                logger.fine("Correcting Segment: True Sequence:\n" + join(truth));
            }

            Segment leastConfidentSegment = findIncorrect ? firstIncorrect(orderedSegments) : firstOrNull(orderedSegments);
            if (findIncorrect && leastConfidentSegment == null) {
                logger.warning("cannot find incorrect segment, probably because error is in background state\n");
            }
            this.leastConfidentSegments.add(leastConfidentSegment);
            if (leastConfidentSegment == null) {
                // No segment to correct: keep the original prediction.
                correctedPredictionList.add(predicted);
                continue;
            }

            // Positions left null are not fixed by the correction; only the chosen segment's
            // tokens (and possibly the token after it) receive their true label.
            String[] constraints = new String[truth.size()];
            int numCorrectedTokens = 0;
            for (int j = 0; j < constraints.length; j++) {
                if (leastConfidentSegment.indexInSegment(j)) {
                    constraints[j] = (String) truth.get(j);
                    ++numCorrectedTokens;
                }
            }
            if (leastConfidentSegment.endsPrematurely()) {
                // The predicted segment stops too early, so the following token is also fixed.
                // The bounds check guards against a segment that already ends at the last token.
                int next = leastConfidentSegment.getEnd() + 1;
                if (next < constraints.length) {
                    constraints[next] = (String) truth.get(next);
                    ++numCorrectedTokens;
                }
            }
            ArraySequence<String> segmentCorrectedOutput = new ArraySequence<String>(constraints);
            if (logger.isLoggable(Level.FINE)) {
                logger.fine("Constrained Segment Sequence\n" + join(segmentCorrectedOutput));
            }

            Sequence<Object> correctedPrediction = new MaxLatticeDefault(this.model, leastConfidentSegment.getInput(), segmentCorrectedOutput).bestOutputSequence();
            if (logger.isLoggable(Level.FINE)) {
                int numIncorrectAfterCorrection = countMismatches(correctedPrediction, truth, truth.size());
                logger.fine("Num incorrect tokens in original prediction: " + numIncorrect);
                logger.fine("Num corrected tokens: " + numCorrectedTokens);
                logger.fine("Num incorrect tokens after correction-propagation: " + numIncorrectAfterCorrection);
                logger.fine("Correcting Segment: True Sequence:\n" + join(truth));
                logger.fine("Original prediction:\n" + join(predicted));
                logger.fine("Corrected prediction:\n" + join(correctedPrediction));
            }
            correctedPredictionList.add(correctedPrediction);
        }
        return correctedPredictionList;
    }

    /** Returns the first ranked segment, or null if the ranking is null or empty. */
    private static Segment firstOrNull(Segment[] segments) {
        return (segments == null || segments.length == 0) ? null : segments[0];
    }

    /** Returns the first non-null ranked segment that is not correct, or null if none exists. */
    private static Segment firstIncorrect(Segment[] segments) {
        if (segments == null) {
            return null;
        }
        for (Segment segment : segments) {
            if (segment != null && !segment.correct()) {
                return segment;
            }
        }
        return null;
    }

    /** Counts positions j in [0, length) where {@code predicted.get(j)} differs from {@code truth.get(j)}. */
    private static int countMismatches(Sequence predicted, Sequence truth, int length) {
        int mismatches = 0;
        for (int j = 0; j < length; j++) {
            if (!predicted.get(j).equals(truth.get(j))) {
                ++mismatches;
            }
        }
        return mismatches;
    }

    /** Renders the elements of a sequence separated by tabs, for debug logging. */
    private static String join(Sequence seq) {
        StringBuilder sb = new StringBuilder();
        for (int j = 0; j < seq.size(); j++) {
            if (j > 0) {
                sb.append('\t');
            }
            sb.append(seq.get(j));
        }
        return sb.toString();
    }
}

