/*
 * Decompiled with CFR 0.152.
 */
package dragon.ml.seqmodel.crf;

import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.ml.seqmodel.crf.LBFGSBasicTrainer;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

/**
 * L-BFGS trainer for a segment-level (semi-Markov style) CRF.
 *
 * <p>Each label covers a segment of {@code 1..maxSegmentLength} consecutive positions. The
 * objective is the sum over training sequences of the conditional log-likelihood, minus a
 * quadratic penalty {@code invSigmaSquare * ||lambda||^2 / 2}. It is computed with a
 * segment-level forward/backward pass, either in probability space or in log space when
 * {@code doScaling} is set.
 */
public class LBFGSSegmentTrainer extends LBFGSBasicTrainer {
    /**
     * Offset between sequence positions and rows of the forward table: row {@code p - BASE}
     * holds position {@code p}. Row 0 is therefore the virtual start state before position 0.
     */
    private static final int BASE = -1;

    /** Longest segment (in positions) the model may assign a single label to. */
    private final int maxSegmentLength;

    public LBFGSSegmentTrainer(ModelGraph model, FeatureGenerator featureGenerator, int maxSegmentLength) {
        super(model, featureGenerator);
        this.maxSegmentLength = maxSegmentLength;
    }

    /**
     * Computes the penalized log-likelihood of every sequence in {@code diter} and writes its
     * gradient into {@code grad}.
     *
     * <p>Delegates to {@link #computeFunctionGradientLL} when {@code doScaling} is set. Empty
     * sequences, and sequences containing a labelled segment longer than
     * {@code maxSegmentLength}, are skipped. They contribute to neither the value nor the
     * gradient.
     *
     * @param diter training sequences; scanned from the start
     * @param lambda current feature weights
     * @param grad output array; overwritten with the gradient of the returned value
     * @return the penalized log-likelihood
     * @throws IllegalStateException if the computation fails; the original error is the cause
     */
    @Override
    protected double computeFunctionGradient(Dataset diter, double[] lambda, double[] grad) {
        if (this.doScaling) {
            return this.computeFunctionGradientLL(diter, lambda, grad);
        }
        try {
            int stateNum = this.model.getStateNum();
            double[] expF = new double[this.featureGenerator.getFeatureNum()];
            DoubleFlatDenseMatrix transMatrix = new DoubleFlatDenseMatrix(stateNum, stateNum);
            double[][] alpha = null;
            double[][] beta = null;
            double logli = this.initGradientWithPrior(lambda, grad);

            diter.startScan();
            while (diter.hasNext()) {
                DataSequence dataSeq = diter.next();
                int len = dataSeq.length();
                // Validate before touching grad so a skipped sequence leaves no partial counts.
                if (len == 0 || this.hasOverlongSegment(dataSeq)) {
                    continue;
                }
                MathUtil.initArray(expF, 0.0);
                alpha = ensureRows(alpha, len - BASE, stateNum);
                beta = ensureRows(beta, len, stateNum);

                // Backward pass: beta[i] sums over all segmentations of positions i+1..len-1.
                MathUtil.initArray(beta[len - 1], 1.0);
                for (int i = len - 2; i >= 0; --i) {
                    MathUtil.initArray(beta[i], 0.0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i + ell < len; ++ell) {
                        this.computeTransMatrix(lambda, dataSeq, i + 1, i + ell, transMatrix, true);
                        this.genStateVector(transMatrix, beta[i + ell], beta[i], false);
                    }
                }

                // Forward pass, accumulating empirical counts and expected feature values.
                double seqLogli = 0.0;
                MathUtil.initArray(alpha[0], 1.0);
                int segmentStart = 0;
                int segmentEnd = -1;
                for (int i = 0; i < len; ++i) {
                    if (segmentEnd < i) {
                        segmentStart = i;
                        segmentEnd = dataSeq.getSegmentEnd(i);
                    }
                    MathUtil.initArray(alpha[i - BASE], 0.0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i - ell >= BASE; ++ell) {
                        int start = i - ell + 1;
                        double[] alphaPrev = alpha[i - ell - BASE];
                        this.computeTransMatrix(lambda, dataSeq, start, i, transMatrix, true);
                        // True when [start, i] is exactly a labelled segment of the training data.
                        boolean isSegment = start == segmentStart && i == segmentEnd;
                        this.featureGenerator.startScanFeaturesAt(dataSeq, start, i);
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            int f = feature.getIndex();
                            int yp = feature.getLabel();
                            int yprev = feature.getPrevLabel();
                            double val = feature.getValue();
                            if (isSegment && dataSeq.getLabel(i) == yp
                                    && (i - ell >= 0 && yprev == dataSeq.getLabel(i - ell) || yprev < 0)) {
                                grad[f] += val;
                                seqLogli += val * lambda[f];
                            }
                            if (yprev < 0) {
                                // A negative previous label fires for every previous state.
                                for (int y = 0; y < transMatrix.rows(); ++y) {
                                    expF[f] += val * alphaPrev[y] * transMatrix.getDouble(y, yp) * beta[i][yp];
                                }
                            } else {
                                expF[f] += val * alphaPrev[yprev] * transMatrix.getDouble(yprev, yp) * beta[i][yp];
                            }
                        }
                        this.genStateVector(transMatrix, alphaPrev, alpha[i - BASE], true);
                    }
                }

                double Zx = MathUtil.sumArray(alpha[len - 1 - BASE]);
                seqLogli -= Math.log(Zx);
                logli += seqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= expF[f] / Zx;
                }
            }
            return logli;
        } catch (Exception e) {
            throw new IllegalStateException("Failed to compute segment CRF log-likelihood and gradient", e);
        }
    }

    /**
     * Log-space variant of {@link #computeFunctionGradient}.
     *
     * <p>Forward/backward tables and the expected feature values are kept as logarithms and
     * combined with {@code MathUtil.logSumExp}, which avoids overflow of the partition function
     * on long sequences. The skipping rules for empty or invalid sequences are the same.
     * NOTE(review): expectations use {@code MathUtil.log(featureValue)}, so this path presumably
     * assumes strictly positive feature values — confirm.
     *
     * @param diter training sequences; scanned from the start
     * @param lambda current feature weights
     * @param grad output array; overwritten with the gradient of the returned value
     * @return the penalized log-likelihood
     * @throws IllegalStateException if the computation fails; the original error is the cause
     */
    protected double computeFunctionGradientLL(Dataset diter, double[] lambda, double[] grad) {
        try {
            int stateNum = this.model.getStateNum();
            double[] expF = new double[this.featureGenerator.getFeatureNum()];
            DoubleFlatDenseMatrix transMatrix = new DoubleFlatDenseMatrix(stateNum, stateNum);
            double[][] alpha = null;
            double[][] beta = null;
            double logli = this.initGradientWithPrior(lambda, grad);

            diter.startScan();
            while (diter.hasNext()) {
                DataSequence dataSeq = diter.next();
                int len = dataSeq.length();
                // Validate before touching grad so a skipped sequence leaves no partial counts.
                if (len == 0 || this.hasOverlongSegment(dataSeq)) {
                    continue;
                }
                MathUtil.initArray(expF, MathUtil.LOG0);
                alpha = ensureRows(alpha, len - BASE, stateNum);
                beta = ensureRows(beta, len, stateNum);

                // Backward pass in log space (log 1 = 0 at the last position).
                MathUtil.initArray(beta[len - 1], 0.0);
                for (int i = len - 2; i >= 0; --i) {
                    MathUtil.initArray(beta[i], MathUtil.LOG0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i + ell < len; ++ell) {
                        this.computeTransMatrix(lambda, dataSeq, i + 1, i + ell, transMatrix, false);
                        this.genStateVectorLog(transMatrix, beta[i + ell], beta[i], false);
                    }
                }

                // Forward pass in log space.
                double seqLogli = 0.0;
                MathUtil.initArray(alpha[0], 0.0);
                int segmentStart = 0;
                int segmentEnd = -1;
                for (int i = 0; i < len; ++i) {
                    if (segmentEnd < i) {
                        segmentStart = i;
                        segmentEnd = dataSeq.getSegmentEnd(i);
                    }
                    MathUtil.initArray(alpha[i - BASE], MathUtil.LOG0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i - ell >= BASE; ++ell) {
                        int start = i - ell + 1;
                        double[] alphaPrev = alpha[i - ell - BASE];
                        this.computeTransMatrix(lambda, dataSeq, start, i, transMatrix, false);
                        boolean isSegment = start == segmentStart && i == segmentEnd;
                        // Scan the same segment [start, i] the transition matrix was built from.
                        this.featureGenerator.startScanFeaturesAt(dataSeq, start, i);
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            int f = feature.getIndex();
                            int yp = feature.getLabel();
                            int yprev = feature.getPrevLabel();
                            double val = feature.getValue();
                            if (isSegment && dataSeq.getLabel(i) == yp
                                    && (i - ell >= 0 && yprev == dataSeq.getLabel(i - ell) || yprev < 0)) {
                                grad[f] += val;
                                seqLogli += val * lambda[f];
                            }
                            double logVal = MathUtil.log(val);
                            if (yprev < 0) {
                                for (int y = 0; y < transMatrix.rows(); ++y) {
                                    expF[f] = MathUtil.logSumExp(expF[f],
                                            alphaPrev[y] + transMatrix.getDouble(y, yp) + logVal + beta[i][yp]);
                                }
                            } else {
                                expF[f] = MathUtil.logSumExp(expF[f],
                                        alphaPrev[yprev] + transMatrix.getDouble(yprev, yp) + logVal + beta[i][yp]);
                            }
                        }
                        this.genStateVectorLog(transMatrix, alphaPrev, alpha[i - BASE], true);
                    }
                }

                double logZx = MathUtil.logSumExp(alpha[len - 1 - BASE]);
                seqLogli -= logZx;
                logli += seqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= MathUtil.exp(expF[f] - logZx);
                }
            }
            return logli;
        } catch (Exception e) {
            throw new IllegalStateException("Failed to compute segment CRF log-likelihood and gradient", e);
        }
    }

    /**
     * Writes the gradient of the quadratic weight penalty into {@code grad}.
     *
     * @return the penalty term {@code -sum(lambda^2) * invSigmaSquare / 2}
     */
    private double initGradientWithPrior(double[] lambda, double[] grad) {
        double penalty = 0.0;
        for (int f = 0; f < lambda.length; ++f) {
            grad[f] = -1.0 * lambda[f] * this.invSigmaSquare;
            penalty -= lambda[f] * lambda[f] * this.invSigmaSquare / 2.0;
        }
        return penalty;
    }

    /**
     * Reports whether any labelled segment of {@code dataSeq} is longer than
     * {@code maxSegmentLength}. Uses the same segment tracking as the forward passes.
     */
    private boolean hasOverlongSegment(DataSequence dataSeq) {
        int segmentStart = 0;
        int segmentEnd = -1;
        for (int i = 0; i < dataSeq.length(); ++i) {
            if (segmentEnd < i) {
                segmentStart = i;
                segmentEnd = dataSeq.getSegmentEnd(i);
            }
            if (segmentEnd - segmentStart + 1 > this.maxSegmentLength) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns {@code buffer} if it already has at least {@code requiredRows} rows. Otherwise
     * returns a fresh {@code (2 * requiredRows) x width} array, so tables are reused across
     * sequences and only grow.
     */
    private static double[][] ensureRows(double[][] buffer, int requiredRows, int width) {
        if (buffer != null && buffer.length >= requiredRows) {
            return buffer;
        }
        return new double[2 * requiredRows][width];
    }
}

