/*
 * Decompiled with CFR 0.152.
 */
package dragon.ml.seqmodel.crf;

import dragon.ml.seqmodel.crf.AbstractTrainer;
import dragon.ml.seqmodel.crf.Labeler;
import dragon.ml.seqmodel.crf.ViterbiBasicLabeler;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

/**
 * Trainer that learns feature weights with additive, mistake-driven updates and
 * returns the average of the weight vectors seen during training (an averaged
 * perceptron in the style suggested by the class name).
 *
 * <p>For every training sequence the labeler is asked for up to
 * {@link #topSolutions} best labelings. Each labeling that differs from the manual
 * labeling is compared with it segment by segment; wherever they disagree the
 * features of the manual segment are rewarded and the features of the competing
 * segments are penalised.
 */
public class CollinsBasicTrainer
extends AbstractTrainer {
    /** Maximum number of best labelings requested from the labeler per sequence. */
    protected int topSolutions;
    /**
     * Relative score margin: the scan over best labelings stops at the first one whose
     * score is below {@code correctScore * (1 - beta)}.
     */
    protected double beta;
    /**
     * When {@code true}, decoding and scoring use the running average of the weights
     * instead of the current weights.
     */
    protected boolean useUpdated;

    /**
     * Creates a trainer with {@code topSolutions = min(3, model.getStateNum())},
     * {@code beta = 0.05} and {@code useUpdated = false}.
     *
     * @param model            the model graph handed to the superclass
     * @param featureGenerator the feature generator handed to the superclass
     */
    public CollinsBasicTrainer(ModelGraph model, FeatureGenerator featureGenerator) {
        super(model, featureGenerator);
        this.topSolutions = Math.min(3, model.getStateNum());
        this.beta = 0.05;
        this.useUpdated = false;
    }

    /**
     * Trains the feature generator on {@code dataset} and then learns {@code lambda}.
     *
     * <p>Runs at most {@code maxIteration} passes over the dataset and stops early after
     * a pass that produced no weight update. One progress line per pass is written to
     * {@code System.out}. On return {@code lambda} holds the average of the weight
     * vectors recorded after each processed sequence; if no sequence was processed
     * it stays all zeros.
     *
     * @param dataset the labelled training sequences
     * @return {@code false} if the feature generator fails to train, {@code true} otherwise
     */
    @Override
    public boolean train(Dataset dataset) {
        dataset.startScan();
        while (dataset.hasNext()) {
            this.model.mapLabelToState(dataset.next());
        }
        if (!this.featureGenerator.train(dataset)) {
            return false;
        }

        int featureNum = this.featureGenerator.getFeatureNum();
        this.lambda = new double[featureNum];
        double[] lambdaAvg = new double[featureNum];
        double[] lambdaSum = new double[featureNum];
        MathUtil.initArray(this.lambda, 0.0);
        MathUtil.initArray(lambdaAvg, 0.0);
        // The original zeroed this.lambda a second time here; lambdaSum was the intended target.
        MathUtil.initArray(lambdaSum, 0.0);

        Labeler labeler = this.getLabeler();
        DataSequence[] solutions = new DataSequence[this.topSolutions];
        int[] autoStartPos = new int[this.topSolutions];
        int trainingCount = 0;

        for (int t = 0; t < this.maxIteration; ++t) {
            int numErrs = 0;
            dataset.startScan();
            while (dataset.hasNext()) {
                // lambdaAvg is only read when useUpdated is set, so refresh it only then.
                if (this.useUpdated && trainingCount > 0) {
                    MathUtil.copyArray(lambdaSum, lambdaAvg);
                    MathUtil.multiArray(lambdaAvg, 1.0 / (double)trainingCount);
                }
                double[] decodeWeights = this.useUpdated ? lambdaAvg : this.lambda;

                // NOTE(review): cursors start at 0 while the manual scan below starts at
                // getMarkovOrder() - 1; confirm this is intended for orders above 1.
                MathUtil.initArray(autoStartPos, 0);
                DataSequence manualSeq = dataset.next();

                // Decode once; the ranked labelings are then read back via getBestSolution.
                labeler.label(manualSeq.copy(), decodeWeights);
                double correctScore = this.getSequenceScore(manualSeq, decodeWeights);
                // NOTE(review): for a negative correctScore this threshold lies above
                // correctScore itself; confirm that is the intended margin.
                double scoreThreshold = correctScore * (1.0 - this.beta);

                // Collect the best labelings that are within the margin and not fully correct.
                int solutionNum = 0;
                for (int k = 0; k < this.topSolutions; ++k) {
                    DataSequence candidate = manualSeq.copy();
                    double curScore = labeler.getBestSolution(candidate, k);
                    if (curScore < scoreThreshold) {
                        break;
                    }
                    this.model.mapLabelToState(candidate);
                    if (this.isCorrect(manualSeq, candidate)) {
                        continue;
                    }
                    solutions[solutionNum] = candidate;
                    ++solutionNum;
                }

                if (solutionNum > 0) {
                    int startPos = this.model.getMarkovOrder() - 1;
                    while (startPos < manualSeq.length()) {
                        int endPos = this.getSegmentEnd(manualSeq, startPos);

                        // The manual segment is "different" as soon as one collected solution
                        // does not reproduce the same boundaries and the same end label.
                        boolean different = false;
                        for (int s = 0; s < solutionNum; ++s) {
                            boolean sameSegment = autoStartPos[s] == startPos
                                && this.getSegmentEnd(solutions[s], autoStartPos[s]) == endPos
                                && manualSeq.getLabel(endPos) == solutions[s].getLabel(endPos);
                            if (!sameSegment) {
                                different = true;
                                break;
                            }
                        }

                        if (different) {
                            ++numErrs;
                            // Reward the manual segment ...
                            this.updateWeights(manualSeq, startPos, endPos, 1.0, this.lambda);
                            // ... and penalise every competing segment starting at or before endPos,
                            // splitting the penalty evenly across the collected solutions.
                            for (int s = 0; s < solutionNum; ++s) {
                                while (autoStartPos[s] <= endPos) {
                                    int autoEndPos = this.getSegmentEnd(solutions[s], autoStartPos[s]);
                                    this.updateWeights(solutions[s], autoStartPos[s], autoEndPos, -1.0 / (double)solutionNum, this.lambda);
                                    autoStartPos[s] = autoEndPos + 1;
                                }
                            }
                        }

                        // Advance every solution cursor past the manual segment (no-op after an update).
                        for (int s = 0; s < solutionNum; ++s) {
                            while (autoStartPos[s] <= endPos) {
                                autoStartPos[s] = this.getSegmentEnd(solutions[s], autoStartPos[s]) + 1;
                            }
                        }
                        startPos = endPos + 1;
                    }
                }

                MathUtil.sumArray(lambdaSum, this.lambda);
                ++trainingCount;
            }
            System.out.println("Iteration " + t + " numErrs " + numErrs);
            if (numErrs == 0) {
                break;
            }
        }

        // Without this guard an empty dataset (or maxIteration <= 0) would compute
        // 0 * (1.0 / 0) and fill lambda with NaN.
        if (trainingCount > 0) {
            MathUtil.multiArray(lambdaSum, 1.0 / (double)trainingCount);
            MathUtil.copyArray(lambdaSum, this.lambda);
        }
        return true;
    }

    /**
     * Checks whether two sequences carry the same label at every position of {@code manual}.
     *
     * @param manual the reference sequence; its length bounds the comparison
     * @param auto   the sequence to compare against the reference
     * @return {@code true} if no position differs
     */
    protected boolean isCorrect(DataSequence manual, DataSequence auto) {
        for (int i = 0; i < manual.length(); ++i) {
            if (manual.getLabel(i) != auto.getLabel(i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Adds {@code wt * value} to {@code grad} for every feature of the segment
     * {@code [startPos, endPos]} that is consistent with the labels of {@code dataSeq}.
     *
     * <p>A feature is consistent when its label equals the sequence label at {@code endPos}
     * and, if it has a previous label ({@code >= 0}), that label equals the sequence label
     * at {@code startPos - 1}.
     *
     * @param dataSeq  the labelled sequence whose features are scanned
     * @param startPos first position of the segment
     * @param endPos   last position of the segment
     * @param wt       weight applied to each feature value (may be negative)
     * @param grad     vector indexed by feature index, updated in place
     */
    protected void updateWeights(DataSequence dataSeq, int startPos, int endPos, double wt, double[] grad) {
        this.featureGenerator.startScanFeaturesAt(dataSeq, startPos, endPos);
        while (this.featureGenerator.hasNext()) {
            Feature feature = this.featureGenerator.next();
            int f = feature.getIndex();
            int yp = feature.getLabel();
            int yprev = feature.getPrevLabel();
            if (dataSeq.getLabel(endPos) != yp || yprev >= 0 && yprev != dataSeq.getLabel(startPos - 1)) {
                continue;
            }
            grad[f] += wt * feature.getValue();
        }
    }

    /**
     * Scores the labelling carried by {@code dataSeq} under the given weights.
     *
     * <p>Walks the sequence segment by segment, starting at {@code getMarkovOrder() - 1},
     * and sums {@code grad[index] * value} over the features that are consistent with the
     * sequence labels (same rule as in {@link #updateWeights}).
     *
     * @param dataSeq the labelled sequence to score
     * @param grad    weight vector indexed by feature index
     * @return the accumulated score
     */
    protected double getSequenceScore(DataSequence dataSeq, double[] grad) {
        int startPos = this.model.getMarkovOrder() - 1;
        double score = 0.0;
        while (startPos < dataSeq.length()) {
            int endPos = this.getSegmentEnd(dataSeq, startPos);
            this.featureGenerator.startScanFeaturesAt(dataSeq, startPos, endPos);
            while (this.featureGenerator.hasNext()) {
                Feature feature = this.featureGenerator.next();
                int f = feature.getIndex();
                int yp = feature.getLabel();
                int yprev = feature.getPrevLabel();
                if (dataSeq.getLabel(endPos) != yp || yprev >= 0 && yprev != dataSeq.getLabel(startPos - 1)) {
                    continue;
                }
                score += grad[f] * feature.getValue();
            }
            startPos = endPos + 1;
        }
        return score;
    }

    /**
     * Creates the labeler used for decoding during training.
     *
     * @return a new {@link ViterbiBasicLabeler} over this trainer's model and feature generator
     */
    protected Labeler getLabeler() {
        return new ViterbiBasicLabeler(this.model, this.featureGenerator);
    }

    /**
     * Returns the last position of the segment that begins at {@code start}.
     *
     * <p>This implementation returns {@code start}, so every segment is a single position;
     * the method is protected so subclasses can define longer segments.
     *
     * @param dataSeq the sequence being segmented (unused here)
     * @param start   first position of the segment
     * @return {@code start}
     */
    protected int getSegmentEnd(DataSequence dataSeq, int start) {
        return start;
    }
}

