/*
 * Decompiled with CFR 0.152.
 */
package dragon.ml.seqmodel.crf;

import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.ml.seqmodel.crf.AbstractTrainer;
import dragon.ml.seqmodel.crf.LBFGS;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;
import java.util.Date;

/**
 * Trains the feature weights ({@code lambda}) of a sequence model by maximising the
 * L2-penalised conditional log-likelihood of a {@link Dataset} with L-BFGS.
 *
 * <p>The penalty subtracts {@code invSigmaSquare / 2 * |lambda|^2} from the log-likelihood.
 * The log-likelihood and its gradient are computed with a scaled backward/forward
 * (beta/alpha) pass over every training sequence; see {@link #computeFunctionGradient}.
 */
public class LBFGSBasicTrainer extends AbstractTrainer {
    /** History size passed to {@code LBFGS.lbfgs} as its second argument (see setGradientHistory). */
    protected int mForHessian = 7;
    /** Convergence tolerance passed to {@code LBFGS.lbfgs}. */
    protected double epsForConvergence = 0.001;
    /** Coefficient of the L2 penalty on {@code lambda}. */
    protected double invSigmaSquare = 0.01;

    public LBFGSBasicTrainer(ModelGraph model, FeatureGenerator featureGenerator) {
        super(model, featureGenerator);
    }

    /**
     * Sets the gradient-history size handed to L-BFGS.
     *
     * @param history number of stored corrections
     */
    public void setGradientHistory(int history) {
        this.mForHessian = history;
    }

    /**
     * Sets the L-BFGS convergence tolerance.
     *
     * <p>Kept for backward compatibility only: an {@code int} cannot express the fractional
     * tolerances this setting normally takes (the default is 0.001). Prefer
     * {@link #setAccuracy(double)}.
     *
     * @param eps convergence tolerance
     */
    public void setAccuracy(int eps) {
        this.setAccuracy((double) eps);
    }

    /**
     * Sets the L-BFGS convergence tolerance.
     *
     * @param eps convergence tolerance, e.g. {@code 1e-3}
     */
    public void setAccuracy(double eps) {
        this.epsForConvergence = eps;
    }

    /**
     * Sets the L2 penalty coefficient.
     *
     * <p>Kept for backward compatibility only: an {@code int} cannot express the fractional
     * coefficients this setting normally takes (the default is 0.01). Prefer
     * {@link #setInvSigmaSquare(double)}.
     *
     * @param invSigmaSquare penalty coefficient
     */
    public void setInvSigmaSquare(int invSigmaSquare) {
        this.setInvSigmaSquare((double) invSigmaSquare);
    }

    /**
     * Sets the L2 penalty coefficient.
     *
     * @param invSigmaSquare penalty coefficient, e.g. {@code 0.01}
     */
    public void setInvSigmaSquare(double invSigmaSquare) {
        this.invSigmaSquare = invSigmaSquare;
    }

    /**
     * Fits the model to {@code dataset}.
     *
     * <p>First maps every label in the data to a model state and trains the feature
     * generator. It then runs L-BFGS starting from {@code lambda = 0}, iterating while
     * L-BFGS leaves {@code iflag[0] != 0} and the iteration count does not exceed
     * {@code maxIteration}. A progress line is printed to standard output after each
     * function evaluation.
     *
     * @param dataset training sequences; scanned several times
     * @return {@code false} if the feature generator fails to train or L-BFGS throws,
     *     {@code true} otherwise
     * @throws IllegalStateException if evaluating the log-likelihood or gradient fails
     */
    @Override
    public boolean train(Dataset dataset) {
        dataset.startScan();
        while (dataset.hasNext()) {
            this.model.mapLabelToState(dataset.next());
        }
        if (!this.featureGenerator.train(dataset)) {
            return false;
        }
        int featureNum = this.featureGenerator.getFeatureNum();
        // Java zero-initialises new arrays, so optimisation starts from lambda = 0.
        this.lambda = new double[featureNum];
        double[] gradLogli = new double[featureNum];
        double[] diag = new double[featureNum];
        // NOTE(review): iprint[0] = -1 presumably disables LBFGS progress output; confirm
        // against the LBFGS implementation.
        int[] iprint = {-1, 0};
        int[] iflag = {0};
        int icall = 0;
        do {
            double f = this.computeFunctionGradient(dataset, this.lambda, gradLogli);
            System.out.println(new Date().toString() + " Iteration: " + icall + " log likelihood " + f + " norm(grad logli) " + this.norm(gradLogli) + " norm(x) " + this.norm(this.lambda));
            // The optimiser is fed the negated objective and gradient, i.e. it minimises -logli.
            f = -f;
            for (int j = 0; j < gradLogli.length; ++j) {
                gradLogli[j] = -gradLogli[j];
            }
            try {
                // xtol is not declared in this class; presumably inherited from AbstractTrainer.
                LBFGS.lbfgs(featureNum, this.mForHessian, this.lambda, f, gradLogli, false, diag, iprint, this.epsForConvergence, xtol, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.err.println("CRF: lbfgs failed.\n" + e);
                if (e.iflag == -1) {
                    System.err.println("Possible reasons could be: \n \t 1. Bug in the feature generation or data handling code\n\t 2. Not enough features to make observed feature value==expected value\n");
                }
                return false;
            }
        } while (iflag[0] != 0 && ++icall <= this.maxIteration);
        return true;
    }

    /**
     * Returns the Euclidean (L2) norm of {@code ar}.
     *
     * @param ar vector
     * @return {@code sqrt(sum(ar[i]^2))}
     */
    protected double norm(double[] ar) {
        double sumSq = 0.0;
        for (double v : ar) {
            sumSq += v * v;
        }
        return Math.sqrt(sumSq);
    }

    /**
     * Computes the L2-penalised conditional log-likelihood of {@code diter} under weights
     * {@code lambda}, and writes its gradient into {@code grad}.
     *
     * <p>For each non-empty sequence, a backward (beta) pass and then a forward (alpha) pass
     * are run. Each position is divided by a scaling factor, which appears intended to
     * avoid numeric overflow. The logs of the alpha scaling factors are subtracted again, so
     * the returned value does not depend on the scaling. The gradient is the observed feature
     * values minus the expected feature values, minus {@code lambda * invSigmaSquare}.
     *
     * @param diter training sequences; rescanned from the start
     * @param lambda current feature weights
     * @param grad output array, overwritten with the gradient; expected to be sized like
     *     {@code lambda}
     * @return penalised log-likelihood
     * @throws IllegalStateException wrapping any exception raised during evaluation
     */
    protected double computeFunctionGradient(Dataset diter, double[] lambda, double[] grad) {
        double logli = 0.0;
        int markovOrder = this.model.getMarkovOrder();
        int stateNum = this.model.getStateNum();
        double[] alpha = new double[stateNum];
        double[] newAlpha = new double[stateNum];
        // Grown lazily to twice the longest sequence seen so far, then reused.
        double[][] beta = null;
        double[] scale = null;
        double[] expF = new double[this.featureGenerator.getFeatureNum()];
        DoubleFlatDenseMatrix transMatrix = new DoubleFlatDenseMatrix(stateNum, stateNum);
        try {
            // L2 penalty contribution to objective and gradient.
            for (int f = 0; f < lambda.length; ++f) {
                grad[f] = -lambda[f] * this.invSigmaSquare;
                logli -= lambda[f] * lambda[f] * this.invSigmaSquare / 2.0;
            }
            diter.startScan();
            while (diter.hasNext()) {
                DataSequence dataSeq = diter.next();
                int seqLen = dataSeq.length();
                if (seqLen == 0) {
                    // An empty sequence contributes nothing; without this guard it would index scale[-1].
                    continue;
                }
                MathUtil.initArray(alpha, 1.0);
                MathUtil.initArray(expF, 0.0);
                if (beta == null || beta.length < seqLen) {
                    beta = new double[2 * seqLen][stateNum];
                    scale = new double[2 * seqLen];
                }

                // Backward pass: fill beta[i - 1] from beta[i], rescaling each step.
                int last = seqLen - 1;
                scale[last] = this.doScaling ? (double) stateNum : 1.0;
                MathUtil.initArray(beta[last], 1.0 / scale[last]);
                for (int i = last; i > markovOrder - 1; --i) {
                    this.computeTransMatrix(lambda, dataSeq, i, i, transMatrix, true);
                    MathUtil.initArray(beta[i - 1], 0.0);
                    this.genStateVector(transMatrix, beta[i], beta[i - 1], false);
                    scale[i - 1] = this.doScaling ? MathUtil.sumArray(beta[i - 1]) : 1.0;
                    // Factors with magnitude below 1 are replaced by 1 rather than used as divisors.
                    if (scale[i - 1] < 1.0 && scale[i - 1] > -1.0) {
                        scale[i - 1] = 1.0;
                    }
                    MathUtil.multiArray(beta[i - 1], 1.0 / scale[i - 1]);
                }

                // Forward pass: advance alpha and accumulate observed/expected feature values.
                double seqLogli = 0.0;
                for (int i = markovOrder - 1; i < seqLen; ++i) {
                    this.computeTransMatrix(lambda, dataSeq, i, i, transMatrix, true);
                    MathUtil.initArray(newAlpha, 0.0);
                    this.genStateVector(transMatrix, alpha, newAlpha, true);
                    this.featureGenerator.startScanFeaturesAt(dataSeq, i, i);
                    while (this.featureGenerator.hasNext()) {
                        Feature feature = this.featureGenerator.next();
                        int f = feature.getIndex();
                        int yp = feature.getLabel();
                        int yprev = feature.getPrevLabel();
                        double val = feature.getValue();
                        // Observed value: the feature matches the gold label(s) at position i.
                        if (dataSeq.getLabel(i) == yp
                                && ((i - 1 >= 0 && yprev == dataSeq.getLabel(i - 1)) || yprev < 0)) {
                            grad[f] += val;
                            seqLogli += val * lambda[f];
                        }
                        // Expected value (unnormalised); divided by Zx after the pass.
                        if (yprev < 0) {
                            expF[f] += val * newAlpha[yp] * beta[i][yp];
                        } else {
                            expF[f] += val * alpha[yprev] * transMatrix.getDouble(yprev, yp) * beta[i][yp];
                        }
                    }
                    MathUtil.copyArray(newAlpha, alpha);
                    MathUtil.multiArray(alpha, 1.0 / scale[i]);
                }

                double zx = MathUtil.sumArray(alpha);
                seqLogli -= Math.log(zx);
                // Compensate for the per-position alpha scaling.
                for (int i = markovOrder - 1; i < seqLen; ++i) {
                    seqLogli -= Math.log(scale[i]);
                }
                logli += seqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= expF[f] / zx;
                }
            }
        } catch (Exception e) {
            // Surface the failure (with its cause) instead of terminating the JVM.
            throw new IllegalStateException("CRF: failed to compute log-likelihood and gradient", e);
        }
        return logli;
    }
}

