/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.learning;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFEvaluator;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.util.LabelsAssignment;
import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labels;
import cc.mallet.types.LabelsSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;
import gnu.trove.TIntArrayList;
import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Default {@link ACRFTrainer}: fits an {@link ACRF} by repeatedly calling
 * {@code optimize(1)} on an {@link Optimizer} over the ACRF's gradient objective
 * (by default a {@link LimitedMemoryBFGS} built from {@code acrf.getMaximizable(...)}).
 * An {@link ACRFEvaluator} is invoked after every optimizer step.
 *
 * <p>Training stops when any of the following happens:
 * <ul>
 *   <li>the optimizer reports convergence;</li>
 *   <li>the evaluator returns {@code false};</li>
 *   <li>the objective value changes by less than {@code 1e-5 * totalNodes} between two
 *       iterations (only when the objective is an {@link ACRF.MaximizableACRF});</li>
 *   <li>the iteration budget is exhausted.</li>
 * </ul>
 */
public class DefaultAcrfTrainer
implements ACRFTrainer {
    private static final Logger logger = MalletLogger.getLogger(DefaultAcrfTrainer.class.getName());
    /**
     * Optimizer to use. When {@code null}, each training run builds a new LimitedMemoryBFGS.
     * NOTE(review): when set, it is used as-is regardless of which objective is being trained,
     * so it must already be bound to the right Optimizable.
     */
    private Optimizer maxer;
    /**
     * If true, an optimizer exception that occurs after a reset has already been tried
     * (with no successful iteration in between) is rethrown instead of ending training.
     */
    private static boolean rethrowExceptions = false;
    /** Prefix handed to every evaluator via setOutputPrefix before it runs. */
    private File outputPrefix = new File("");
    /** Training-set fractions used for the warm-up rounds of incrementalTrain. */
    private static final double[] SIZE = new double[]{0.1, 0.5};
    /** Number of iterations run on each warm-up subset in incrementalTrain. */
    private static final int SUBSET_ITER = 10;
    /** Shared fixed-seed RNG used for random training-set splits. */
    private static final Random r = new Random(1729L);

    /** Sets the output prefix passed to evaluators before each evaluation. */
    public void setOutputPrefix(File f) {
        this.outputPrefix = f;
    }

    /** Returns the explicitly configured optimizer, or {@code null} if the default is used. */
    public Optimizer getMaxer() {
        return this.maxer;
    }

    /** Sets the optimizer to use; {@code null} selects a fresh LimitedMemoryBFGS per run. */
    public void setMaxer(Optimizer maxer) {
        this.maxer = maxer;
    }

    /** Returns whether repeated optimizer failures are rethrown. */
    public static boolean isRethrowExceptions() {
        return rethrowExceptions;
    }

    /** Sets whether repeated optimizer failures are rethrown (global, affects all trainers). */
    public static void setRethrowExceptions(boolean rethrowExceptions) {
        DefaultAcrfTrainer.rethrowExceptions = rethrowExceptions;
    }

    /** Trains for a single iteration with a {@link LogEvaluator} and no validation/test set. */
    @Override
    public boolean train(ACRF acrf, InstanceList training) {
        return this.train(acrf, training, null, null, new LogEvaluator(), 1);
    }

    /** Trains for up to {@code numIter} iterations with a {@link LogEvaluator}. */
    @Override
    public boolean train(ACRF acrf, InstanceList training, int numIter) {
        return this.train(acrf, training, null, null, new LogEvaluator(), numIter);
    }

    /** Trains for up to {@code numIter} iterations using the given evaluator. */
    @Override
    public boolean train(ACRF acrf, InstanceList training, ACRFEvaluator eval, int numIter) {
        return this.train(acrf, training, null, null, eval, numIter);
    }

    /** Trains with validation/test sets and a {@link LogEvaluator}. */
    @Override
    public boolean train(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, int numIter) {
        return this.train(acrf, training, validation, testing, new LogEvaluator(), numIter);
    }

    /** Builds the objective for {@code trainingList} and trains on it. */
    @Override
    public boolean train(ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet, ACRFEvaluator eval, int numIter) {
        Optimizable.ByGradientValue macrf = this.createOptimizable(acrf, trainingList);
        return this.train(acrf, trainingList, validationList, testSet, eval, numIter, macrf);
    }

    /** Hook for subclasses: builds the objective to optimize (default: acrf.getMaximizable). */
    protected Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList trainingList) {
        return acrf.getMaximizable(trainingList);
    }

    /** {@link #incrementalTrain(ACRF, InstanceList, InstanceList, InstanceList, ACRFEvaluator, int)} with a LogEvaluator. */
    public boolean incrementalTrain(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, int numIter) {
        return this.incrementalTrain(acrf, training, validation, testing, new LogEvaluator(), numIter);
    }

    /**
     * Warm-starts training on growing subsets of the data (fractions in {@code SIZE},
     * {@code SUBSET_ITER} iterations each), then trains on the full set.
     *
     * @return the convergence flag of the final full-data run
     */
    public boolean incrementalTrain(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, ACRFEvaluator eval, int numIter) {
        long stime = System.currentTimeMillis();
        for (int i = 0; i < SIZE.length; ++i) {
            InstanceList subset = training.split(new double[]{SIZE[i], 1.0 - SIZE[i]})[0];
            logger.info("Training on subset of size " + subset.size());
            Optimizable.ByGradientValue subset_macrf = this.createOptimizable(acrf, subset);
            // The optimizer only sees the subset; the evaluator is still given the full
            // training list, and no test set is used during warm-up.
            this.train(acrf, training, validation, null, eval, SUBSET_ITER, subset_macrf);
            logger.info("Subset training " + i + " finished...");
        }
        long etime = System.currentTimeMillis();
        logger.info("All subset training finished.  Time = " + (etime - stime) + " ms.");
        return this.train(acrf, training, validation, testing, eval, numIter);
    }

    /**
     * Core training loop: one optimizer step per iteration, followed by evaluation.
     *
     * <p>On a RuntimeException from the optimizer, L-BFGS / conjugate gradient is reset and
     * training continues. A second failure before any successful iteration ends training
     * (reported as converged), or is rethrown when {@link #isRethrowExceptions()} is set.
     *
     * @param macrf objective being optimized; if it is an {@link ACRF.MaximizableACRF}, its
     *              node count scales the value-change convergence threshold, otherwise the
     *              threshold is 0 and that test never fires
     * @return true if training stopped because of convergence (including evaluator request
     *         and error shutdown), false if it ran out of iterations
     */
    @Override
    public boolean train(ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet, ACRFEvaluator eval, int numIter, Optimizable.ByGradientValue macrf) {
        Optimizer maximizer = this.createMaxer(macrf);
        boolean converged = false;
        // True while a RuntimeException may still be handled by resetting the optimizer.
        boolean resetOnError = true;
        long stime = System.currentTimeMillis();
        int numNodes = macrf instanceof ACRF.MaximizableACRF ? ((ACRF.MaximizableACRF) macrf).getTotalNodes() : 0;
        double thresh = 1.0E-5 * (double) numNodes;
        if (testSet == null) {
            logger.warning("ACRF trainer: No test set provided.");
        }
        double prevValue = Double.NEGATIVE_INFINITY;
        int iter = 0;
        for (; iter < numIter; ++iter) {
            long now = System.currentTimeMillis();
            logger.info("ACRF trainer iteration " + iter + " at time " + (now - stime));
            try {
                converged = maximizer.optimize(1);
                // Non-short-circuit: the evaluator runs even when the optimizer converged.
                converged |= this.callEvaluator(acrf, trainingList, validationList, testSet, iter, eval);
                if (converged) {
                    break;
                }
                resetOnError = true;
            }
            catch (RuntimeException e) {
                e.printStackTrace();
                if (resetOnError) {
                    // First failure: clear the optimizer's memory and try again next iteration.
                    logger.warning("Exception in iteration " + iter + ":" + e + "\n  Resetting LBFGs and trying again...");
                    resetMaximizer(maximizer);
                    resetOnError = false;
                } else {
                    // Failed again right after a reset: give up.
                    logger.warning("Exception in iteration " + iter + ":" + e + "\n   Quitting and saying converged...");
                    converged = true;
                    if (rethrowExceptions) {
                        throw e;
                    }
                    break;
                }
            }
            if (converged) {
                break;
            }
            double currentValue = macrf.getValue();
            if (Math.abs(currentValue - prevValue) < thresh) {
                // Ignore a flat objective right after a reset; the optimizer is restarting.
                if (resetOnError) {
                    logger.info("ACRFTrainer saying converged:  Current value " + currentValue + ", previous " + prevValue + "\n...threshold was " + thresh + " = 1e-5 * " + numNodes);
                    converged = true;
                    break;
                }
            } else {
                prevValue = currentValue;
            }
        }
        if (iter >= numIter) {
            logger.info("ACRFTrainer: Too many iterations, stopping training.  maxIter = " + numIter);
        }
        long etime = System.currentTimeMillis();
        logger.info("ACRF training time (ms) = " + (etime - stime));
        if (macrf instanceof ACRF.MaximizableACRF) {
            ((ACRF.MaximizableACRF) macrf).report();
        }
        if (testSet != null && eval != null) {
            boolean oldCache = acrf.isCacheUnrolledGraphs();
            acrf.setCacheUnrolledGraphs(false);
            try {
                eval.test(acrf, testSet, "Testing");
            } finally {
                // Restore caching even if the final test throws.
                acrf.setCacheUnrolledGraphs(oldCache);
            }
        }
        return converged;
    }

    /** Clears the internal state of optimizers that support it (L-BFGS, conjugate gradient). */
    private static void resetMaximizer(Optimizer maximizer) {
        if (maximizer instanceof LimitedMemoryBFGS) {
            ((LimitedMemoryBFGS) maximizer).reset();
        }
        if (maximizer instanceof ConjugateGradient) {
            ((ConjugateGradient) maximizer).reset();
        }
    }

    /** Returns the configured optimizer, or a new LimitedMemoryBFGS over {@code macrf}. */
    private Optimizer createMaxer(Optimizable.ByGradientValue macrf) {
        if (this.maxer == null) {
            return new LimitedMemoryBFGS(macrf);
        }
        return this.maxer;
    }

    /**
     * Runs the evaluator for one iteration with graph caching disabled.
     *
     * @return true if training should stop (the evaluator returned false); false if there is
     *         no evaluator or it asked to continue
     */
    protected boolean callEvaluator(ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet, int iter, ACRFEvaluator eval) {
        if (eval == null) {
            return false;
        }
        eval.setOutputPrefix(this.outputPrefix);
        boolean wasCached = acrf.isCacheUnrolledGraphs();
        acrf.setCacheUnrolledGraphs(false);
        try {
            Timing timing = new Timing();
            boolean keepGoing = eval.evaluate(acrf, iter + 1, trainingList, validationList, testSet);
            if (!keepGoing) {
                logger.info("ACRF trainer: evaluator returned false. Quitting.");
            }
            timing.tick("Evaluation time (iteration " + iter + ")");
            return !keepGoing;
        } finally {
            // Always restore the caller's cache setting, including the early-stop path.
            acrf.setCacheUnrolledGraphs(wasCached);
        }
    }

    /**
     * Trains 5 iterations, then asks every template to add some unsupported weights, then
     * continues training on the same objective for up to {@code numIter} iterations.
     */
    public boolean someUnsupportedTrain(ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet, ACRFEvaluator eval, int numIter) {
        Optimizable.ByGradientValue macrf = this.createOptimizable(acrf, trainingList);
        this.train(acrf, trainingList, validationList, testSet, eval, 5, macrf);
        for (ACRF.Template tmpl : acrf.getTemplates()) {
            tmpl.addSomeUnsupportedWeights(trainingList);
        }
        logger.info("Some unsupported weights initialized.  Training...");
        return this.train(acrf, trainingList, validationList, testSet, eval, numIter, macrf);
    }

    /** Decodes {@code testing} once and reports the predictions through {@code eval}. */
    public void test(ACRF acrf, InstanceList testing, ACRFEvaluator eval) {
        this.test(acrf, testing, new ACRFEvaluator[]{eval});
    }

    /** Decodes {@code testing} once and reports the same predictions through each evaluator. */
    public void test(ACRF acrf, InstanceList testing, ACRFEvaluator[] evals) {
        List pred = acrf.getBestLabels(testing);
        for (ACRFEvaluator ev : evals) {
            ev.setOutputPrefix(this.outputPrefix);
            ev.test(testing, pred, "Testing");
        }
    }

    /** Returns the shared fixed-seed RNG used for training-set splits. */
    public static Random getRandom() {
        return r;
    }

    /**
     * Trains on random subsets, where each round uses fraction {@code proportions[i]} of the
     * training data for {@code iterPerProportion} iterations. It then trains on the full data
     * with an effectively unbounded iteration budget.
     */
    public void train(ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing, ACRFEvaluator eval, double[] proportions, int iterPerProportion) {
        for (int i = 0; i < proportions.length; ++i) {
            double proportion = proportions[i];
            // split() normalizes its proportions, so the remainder must be 1 - p for the
            // first part to actually hold fraction p of the data.
            InstanceList[] lists = training.split(r, new double[]{proportion, 1.0 - proportion});
            logger.info("ACRF trainer: Round " + i + ", training proportion = " + proportion);
            this.train(acrf, lists[0], validation, testing, eval, iterPerProportion);
        }
        logger.info("ACRF trainer: Training on full data");
        this.train(acrf, training, validation, testing, eval, 99999);
    }

    /** Evaluator that appends test-set result tables to a file. */
    public static class FileEvaluator
    extends ACRFEvaluator {
        private File file;

        public FileEvaluator(File file) {
            this.file = file;
        }

        /** Tests on {@code testing} (when present and scheduled); never requests early stop. */
        @Override
        public boolean evaluate(ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing) {
            if (testing != null && this.shouldDoEvaluate(iter)) {
                this.test(acrf, testing, "Testing ");
            }
            return true;
        }

        /** Computes results for {@code returnedList} and appends them to the file (best effort). */
        @Override
        public void test(InstanceList testList, List returnedList, String description) {
            logger.info("Number of testing instances = " + testList.size());
            TestResults results = LogEvaluator.computeTestResults(testList, returnedList);
            PrintWriter writer = null;
            try {
                writer = new PrintWriter(new FileWriter(this.file, true));
                results.print(description, writer);
            }
            catch (Exception e) {
                // Best effort: a failed write must not abort training.
                e.printStackTrace();
            }
            finally {
                if (writer != null) {
                    writer.close();
                }
            }
        }
    }

    /** Evaluator that logs per-label and joint accuracy for the training and test sets. */
    public static class LogEvaluator
    extends ACRFEvaluator {
        private TestResults lastResults;

        /** Tests on whichever of training/testing is non-null (when scheduled); never stops early. */
        @Override
        public boolean evaluate(ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing) {
            if (this.shouldDoEvaluate(iter)) {
                if (training != null) {
                    this.test(acrf, training, "Training");
                }
                if (testing != null) {
                    this.test(acrf, testing, "Testing");
                }
            }
            return true;
        }

        /** Computes and logs results, remembering them for {@link #getJointAccuracy()}. */
        @Override
        public void test(InstanceList testList, List returnedList, String description) {
            logger.info(String.valueOf(description) + ": Number of instances = " + testList.size());
            TestResults results = LogEvaluator.computeTestResults(testList, returnedList);
            results.log(description);
            this.lastResults = results;
        }

        /**
         * Builds a confusion matrix by pairing each instance's target (a LabelsAssignment)
         * with the corresponding element of {@code returnedList} (a LabelsSequence).
         */
        public static TestResults computeTestResults(InstanceList testList, List returnedList) {
            TestResults results = new TestResults(testList);
            Iterator it1 = testList.iterator();
            Iterator it2 = returnedList.iterator();
            while (it1.hasNext()) {
                Instance inst = (Instance) it1.next();
                LabelsAssignment lblseq = (LabelsAssignment) inst.getTarget();
                LabelsSequence target = lblseq.getLabelsSequence();
                LabelsSequence returned = (LabelsSequence) it2.next();
                LogEvaluator.compareLabelings(results, returned, target);
            }
            results.computeStatistics();
            return results;
        }

        /** Adds every position of a returned/target sequence pair to {@code results}. */
        static void compareLabelings(TestResults results, LabelsSequence returned, LabelsSequence target) {
            assert (returned.size() == target.size());
            for (int i = 0; i < returned.size(); ++i) {
                results.incrementCount(returned.getLabels(i), target.getLabels(i));
            }
        }

        /**
         * Returns the joint accuracy from the most recent test call.
         *
         * @throws IllegalStateException if no test has been run yet
         */
        public double getJointAccuracy() {
            if (this.lastResults == null) {
                throw new IllegalStateException("LogEvaluator: no test results available yet");
            }
            return this.lastResults.getJointAccuracy();
        }
    }

    /**
     * Confusion-matrix statistics over a single alphabet that merges every output
     * alphabet (factor) of the instances' LabelsAssignment targets.
     */
    public static class TestResults {
        public int[][] confusion;
        public int numClasses;
        public int[] trueCounts;
        public int[] returnedCounts;
        public double[] precision;
        public double[] recall;
        public double[] f1;
        /** factors[f] holds the merged-alphabet indices of factor f's labels. */
        public TIntArrayList[] factors;
        /** Number of positions seen. */
        public int maxT = 0;
        /** Number of positions where every factor's label was correct. */
        public int correctT = 0;
        public Alphabet alphabet = new Alphabet();

        /** Sets up the alphabet from the first instance; the list must be non-empty. */
        TestResults(InstanceList ilist) {
            this((Instance) ilist.get(0));
        }

        TestResults(Instance inst) {
            this.setupAlphabet(inst);
            this.numClasses = this.alphabet.size();
            this.confusion = new int[this.numClasses][this.numClasses];
            this.precision = new double[this.numClasses];
            this.recall = new double[this.numClasses];
            this.f1 = new double[this.numClasses];
        }

        /** Merges each slice's output alphabet into {@code alphabet}, recording index lists. */
        private void setupAlphabet(Instance inst) {
            LabelsAssignment lblseq = (LabelsAssignment) inst.getTarget();
            this.factors = new TIntArrayList[lblseq.numSlices()];
            for (int i = 0; i < lblseq.numSlices(); ++i) {
                LabelAlphabet dict = lblseq.getOutputAlphabet(i);
                this.factors[i] = new TIntArrayList(dict.size());
                for (int j = 0; j < dict.size(); ++j) {
                    this.factors[i].add(this.alphabet.lookupIndex(dict.lookupObject(j)));
                }
            }
        }

        /** Records one position: each factor's (true, returned) pair plus joint correctness. */
        void incrementCount(Labels lblsReturned, Labels lblsTarget) {
            boolean allSame = true;
            for (int j = 0; j < lblsReturned.size(); ++j) {
                Label lret = lblsReturned.get(j);
                Label ltarget = lblsTarget.get(j);
                // Look up the true label first, then the returned one (same order as before).
                int idxTrue = this.alphabet.lookupIndex(ltarget.getEntry());
                int idxRet = this.alphabet.lookupIndex(lret.getEntry());
                if (idxTrue != idxRet) {
                    allSame = false;
                }
                this.confusion[idxTrue][idxRet]++;
            }
            ++this.maxT;
            if (allSame) {
                ++this.correctT;
            }
        }

        /**
         * Derives per-class counts, precision, recall and F1 from the confusion matrix.
         * Precision/recall default to 1.0 for classes never returned/never present; F1 is
         * 0 when precision and recall are both 0 (avoids 0/0 = NaN).
         */
        void computeStatistics() {
            this.trueCounts = new int[this.numClasses];
            this.returnedCounts = new int[this.numClasses];
            for (int i = 0; i < this.numClasses; ++i) {
                for (int j = 0; j < this.numClasses; ++j) {
                    this.trueCounts[i] += this.confusion[i][j];
                    this.returnedCounts[j] += this.confusion[i][j];
                }
            }
            for (int i = 0; i < this.numClasses; ++i) {
                double correct = this.confusion[i][i];
                this.precision[i] = this.returnedCounts[i] == 0 ? (correct == 0.0 ? 1.0 : 0.0) : correct / (double) this.returnedCounts[i];
                this.recall[i] = this.trueCounts[i] == 0 ? 1.0 : correct / (double) this.trueCounts[i];
                double denom = this.precision[i] + this.recall[i];
                this.f1[i] = denom == 0.0 ? 0.0 : 2.0 * this.precision[i] * this.recall[i] / denom;
            }
        }

        /** Returns {correct, returned} summed over the labels of factor {@code fnum}. */
        private int[] factorCounts(int fnum) {
            int correct = 0;
            int returned = 0;
            TIntArrayList labels = this.factors[fnum];
            for (int k = 0; k < labels.size(); ++k) {
                int lbl = labels.get(k);
                correct += this.confusion[lbl][lbl];
                returned += this.returnedCounts[lbl];
            }
            return new int[]{correct, returned};
        }

        public void log() {
            this.log("");
        }

        /** Logs the per-class table, per-factor accuracy and joint accuracy. */
        public void log(String desc) {
            logger.info(String.valueOf(desc) + ":  i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
            for (int i = 0; i < this.numClasses; ++i) {
                logger.info(String.valueOf(desc) + ":  " + i + "\t" + this.alphabet.lookupObject(i) + "\t" + this.trueCounts[i] + "\t" + this.confusion[i][i] + "\t" + this.returnedCounts[i] + "\t" + this.precision[i] + "\t" + this.recall[i] + "\t" + this.f1[i] + "\t");
            }
            for (int fnum = 0; fnum < this.factors.length; ++fnum) {
                int[] counts = this.factorCounts(fnum);
                logger.info(String.valueOf(desc) + ":  Factor " + fnum + " accuracy: (" + counts[0] + " " + counts[1] + ") " + (double) counts[0] / (double) counts[1]);
            }
            logger.info(String.valueOf(desc) + " CorrectT " + this.correctT + "  maxt " + this.maxT);
            logger.info(String.valueOf(desc) + " Joint accuracy: " + (double) this.correctT / (double) this.maxT);
        }

        /** Writes the same report as {@link #log(String)} to {@code out}. */
        public void print(String desc, PrintWriter out) {
            out.println("i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
            for (int i = 0; i < this.numClasses; ++i) {
                out.println(String.valueOf(i) + "\t" + this.alphabet.lookupObject(i) + "\t" + this.trueCounts[i] + "\t" + this.confusion[i][i] + "\t" + this.returnedCounts[i] + "\t" + this.precision[i] + "\t" + this.recall[i] + "\t" + this.f1[i] + "\t");
            }
            for (int fnum = 0; fnum < this.factors.length; ++fnum) {
                int[] counts = this.factorCounts(fnum);
                out.println(String.valueOf(desc) + " Factor " + fnum + " accuracy: (" + counts[0] + " " + counts[1] + ") " + (double) counts[0] / (double) counts[1]);
            }
            out.println(String.valueOf(desc) + " CorrectT " + this.correctT + "  maxt " + this.maxT);
            out.println(String.valueOf(desc) + " Joint accuracy: " + (double) this.correctT / (double) this.maxT);
        }

        /** Dumps every (true, returned, count) cell of the confusion matrix to stdout. */
        void printConfusion() {
            System.out.println("True\t\tReturned\tCount");
            for (int i = 0; i < this.numClasses; ++i) {
                for (int j = 0; j < this.numClasses; ++j) {
                    System.out.println(String.valueOf(i) + "\t\t" + j + "\t" + this.confusion[i][j]);
                }
            }
        }

        /** Fraction of positions where all factors were correct (NaN if none were seen). */
        public double getJointAccuracy() {
            return (double) this.correctT / (double) this.maxT;
        }
    }
}

