/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.pr.CRFOptimizableByKL;
import cc.mallet.fst.semi_supervised.pr.ConstraintsOptimizableByPR;
import cc.mallet.fst.semi_supervised.pr.PRAuxiliaryModel;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.util.ArrayList;
import java.util.BitSet;

/**
 * Trains a {@link CRF} against a list of {@link PRConstraint}s (the "PR" in the
 * name presumably stands for posterior regularization -- confirm against the
 * package documentation).
 *
 * <p>Each outer iteration runs two L-BFGS optimizations in sequence: first a
 * {@link ConstraintsOptimizableByPR} over the auxiliary model, then a
 * {@link CRFOptimizableByKL} over the CRF, fed with the cached dot products
 * produced by the first step. The reported total value is the second step's
 * value minus the first step's complete value contribution.
 */
public class CRFTrainerByPR
extends TransducerTrainer
implements TransducerTrainer.ByOptimization {
    /** Gaussian prior variance handed to the CRF step unless overridden. */
    private static final double DEFAULT_P_GAUSSIAN_PRIOR_VARIANCE = 10.0;
    /** Relative-change threshold used by the stopping test unless overridden. */
    private static final double DEFAULT_TOLERANCE = 0.001;

    /** True only if the most recent outer iteration met the tolerance test. */
    private boolean converged;
    /** Number of completed outer iterations, accumulated across train() calls. */
    private int iter;
    private final int numThreads;
    private double pGpv;
    private double tolerance;
    /** Total objective value of the most recent outer iteration. */
    private double value;
    /** Complete value contribution of the most recent constraint step. */
    private double qValue;
    private final ArrayList<PRConstraint> constraints;
    /** Optimizer of the most recent constraint step; null before training. */
    private LimitedMemoryBFGS bfgs;
    private final CRF crf;
    private StateLabelMap stateLabelMap;

    /**
     * Creates a trainer that uses a single thread.
     *
     * @param crf the CRF to train
     * @param constraints the constraints to train against
     */
    public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints) {
        this(crf, constraints, 1);
    }

    /**
     * Creates a trainer. The state-label map is initialised from the CRF's
     * output alphabet and can be replaced with {@link #setStateLabelMap}.
     *
     * @param crf the CRF to train
     * @param constraints the constraints to train against
     * @param numThreads thread count passed to both optimizable objects
     */
    public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints, int numThreads) {
        this.crf = crf;
        this.iter = 0;
        this.value = Double.NEGATIVE_INFINITY;
        this.constraints = constraints;
        this.pGpv = DEFAULT_P_GAUSSIAN_PRIOR_VARIANCE;
        this.tolerance = DEFAULT_TOLERANCE;
        this.numThreads = numThreads;
        this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(), true);
    }

    /** Returns the number of outer iterations completed so far. */
    @Override
    public int getIteration() {
        return this.iter;
    }

    /** Returns the CRF being trained. */
    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    /**
     * Returns whether the last training run stopped because the change in the
     * total value fell below the tolerance.
     */
    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /** Replaces the state-label map given to every constraint at training time. */
    public void setStateLabelMap(StateLabelMap map) {
        this.stateLabelMap = map;
    }

    /** Sets the Gaussian prior variance used by the CRF optimization step. */
    public void setPGaussianPriorVariance(double pGpv) {
        this.pGpv = pGpv;
    }

    /** Sets the relative-change tolerance used by the stopping test. */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * Trains for at most {@code numIterations} further outer iterations with no
     * minimum iteration count.
     *
     * @return true if the tolerance test was met
     */
    @Override
    public boolean train(InstanceList train, int numIterations) {
        return this.train(train, 0, numIterations);
    }

    /**
     * Trains with no cap on the L-BFGS iterations inside each step.
     *
     * @return true if the tolerance test was met
     */
    public boolean train(InstanceList train, int minIter, int maxIter) {
        return this.train(train, minIter, maxIter, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code maxIter} further outer iterations on the instances that
     * at least one constraint applies to.
     *
     * <p>Training stops early once the iteration counter is at least
     * {@code minIter} and
     * {@code 2*|value - oldValue| <= tolerance * (|value| + |oldValue| + 1e-5)}.
     * Note that {@code minIter} is compared against the cumulative iteration
     * counter, not against the iterations of this call.
     *
     * @param train training instances; unconstrained ones are dropped
     * @param minIter iteration count below which the stopping test is skipped
     * @param maxIter maximum number of outer iterations to add in this call
     * @param maxIterPerStep iteration cap passed to each L-BFGS run
     * @return true if the tolerance test was met, false if the iteration budget
     *         ran out first
     */
    public boolean train(InstanceList train, int minIter, int maxIter, int maxIterPerStep) {
        double oldValue = 0.0;
        // Sum in long and clamp: iter + maxIter overflows int when a caller
        // resumes training with a very large maxIter (e.g. Integer.MAX_VALUE).
        int max = (int)Math.min((long)this.iter + (long)maxIter, (long)Integer.MAX_VALUE);

        InstanceList constrained = this.prepareConstrainedInstances(train);
        PRAuxiliaryModel model = new PRAuxiliaryModel(this.crf, this.constraints);
        while (this.iter < max) {
            // Parameters are about to change, so any earlier convergence is stale.
            this.converged = false;
            long startTime = System.currentTimeMillis();

            // Step 1: optimize the constraint (auxiliary model) objective.
            ConstraintsOptimizableByPR opt = new ConstraintsOptimizableByPR(this.crf, constrained, model, this.numThreads);
            this.bfgs = new LimitedMemoryBFGS(opt);
            try {
                this.bfgs.optimize(maxIterPerStep);
            }
            catch (Exception e) {
                // Best effort: report the failure and carry on with the
                // parameters reached so far.
                e.printStackTrace();
            }
            finally {
                // Also runs when an Error escapes, so shutdown is never skipped.
                opt.shutdown();
            }
            this.qValue = opt.getCompleteValueContribution();
            assert (this.qValue > 0.0);

            // Step 2: optimize the CRF objective using step 1's cached dots.
            CRFOptimizableByKL optP = new CRFOptimizableByKL(this.crf, constrained, model, opt.getCachedDots(), this.numThreads, 1.0);
            optP.setGaussianPriorVariance(this.pGpv);
            LimitedMemoryBFGS bfgsP = new LimitedMemoryBFGS(optP);
            try {
                bfgsP.optimize(maxIterPerStep);
            }
            catch (Exception e) {
                e.printStackTrace();
            }
            finally {
                optP.shutdown();
            }
            this.value = optP.getValue() - this.qValue;
            assert (this.value < 0.0);
            System.err.println("Total value = " + this.value + " (pValue = " + optP.getValue() + ") (qValue = " + -this.qValue + ")");
            System.err.println("Time for iteration " + String.format("%.2f", (double)(System.currentTimeMillis() - startTime) / 1000.0) + "s");

            if (this.iter >= minIter && 2.0 * Math.abs(this.value - oldValue) <= this.tolerance * (Math.abs(this.value) + Math.abs(oldValue) + 1.0E-5)) {
                System.err.println("AP value difference below tolerance (oldValue: " + oldValue + " newValue: " + this.value + ")");
                this.converged = true;
                break;
            }
            oldValue = this.value;
            this.runEvaluators();
            ++this.iter;
        }
        return this.converged;
    }

    /**
     * Pre-processes the data with every constraint, hands each constraint the
     * current state-label map, and returns a copy of {@code train} containing
     * only the instances flagged by at least one constraint.
     */
    private InstanceList prepareConstrainedInstances(InstanceList train) {
        BitSet constrainedInstances = new BitSet();
        for (PRConstraint constraint : this.constraints) {
            constrainedInstances.or(constraint.preProcess(train));
            constraint.setStateLabelMap(this.stateLabelMap);
        }
        InstanceList kept = train.cloneEmpty();
        int removed = 0;
        for (int ii = 0; ii < train.size(); ++ii) {
            if (constrainedInstances.get(ii)) {
                kept.add((Instance)train.get(ii));
            } else {
                ++removed;
            }
        }
        System.err.println("Removed " + removed + " instances that do not contain constraints.");
        return kept;
    }

    /**
     * Returns the total value of the most recent outer iteration, or negative
     * infinity if no iteration has run.
     */
    public double getTotalValue() {
        return this.value;
    }

    /** Returns the complete value contribution of the most recent constraint step. */
    public double getQValue() {
        return this.qValue;
    }

    /**
     * Returns the L-BFGS optimizer of the most recent constraint step, or null
     * if training has not started.
     */
    @Override
    public Optimizer getOptimizer() {
        return this.bfgs;
    }
}

