package dragon.ml.seqmodel.crf;

import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

/* loaded from: input_file:dragon/ml/seqmodel/crf/LBFGSSegmentTrainer.class */
/**
 * L-BFGS trainer for a semi-Markov (segment-level) CRF.
 *
 * <p>Each labelled unit is a segment of one or more consecutive positions, at most
 * {@code maxSegmentLength} long. The forward and backward recursions therefore sum over every
 * segment length {@code 1..maxSegmentLength} that ends (forward) or starts (backward) at a
 * position. Training sequences whose gold segmentation contains a longer segment are skipped
 * entirely.
 */
public class LBFGSSegmentTrainer extends LBFGSBasicTrainer {
    /** Longest segment, in positions, that the recursions consider. */
    private final int maxSegmentLength;

    /**
     * Creates a segment trainer.
     *
     * @param modelGraph label-state graph, passed to the base trainer
     * @param featureGenerator segment feature source, passed to the base trainer
     * @param maxSegmentLength longest segment the recursions consider
     */
    public LBFGSSegmentTrainer(ModelGraph modelGraph, FeatureGenerator featureGenerator, int maxSegmentLength) {
        super(modelGraph, featureGenerator);
        this.maxSegmentLength = maxSegmentLength;
    }

    /**
     * Computes the penalised conditional log-likelihood of {@code dataset} and writes its gradient
     * into {@code grad}, working in probability space. When {@code doScaling} is set, the work is
     * delegated to {@link #computeFunctionGradientLL}, which works in log space instead.
     *
     * @param dataset training sequences; rescanned from the start
     * @param lambda current feature weights
     * @param grad output array, overwritten with the gradient of the returned value
     * @return the sum over usable sequences of (gold-segmentation score - log Z), minus the
     *     quadratic penalty {@code sum(lambda^2) * invSigmaSquare / 2}
     * @throws IllegalStateException wrapping any exception raised during the computation
     */
    @Override
    protected double computeFunctionGradient(Dataset dataset, double[] lambda, double[] grad) {
        if (this.doScaling) {
            return computeFunctionGradientLL(dataset, lambda, grad);
        }
        try {
            int stateNum = this.model.getStateNum();
            DoubleFlatDenseMatrix trans = new DoubleFlatDenseMatrix(stateNum, stateNum);
            double[] expected = new double[this.featureGenerator.getFeatureNum()];
            double[][] alpha = null;
            double[][] beta = null;
            double objective = initGradientWithPrior(lambda, grad);

            dataset.startScan();
            while (dataset.hasNext()) {
                DataSequence seq = dataset.next();
                int len = seq.length();
                // Validate the gold segmentation before touching grad, so a sequence the model
                // cannot represent leaves no partial empirical counts behind. Empty sequences
                // contribute nothing to the objective.
                if (len == 0 || !goldSegmentsWithinLimit(seq)) {
                    continue;
                }
                MathUtil.initArray(expected, 0.0d);
                alpha = growRowsIfNeeded(alpha, len + 1, stateNum);
                beta = growRowsIfNeeded(beta, len, stateNum);

                // Backward pass: beta[p] accumulates, for each segment (p+1 .. p+ell), that
                // segment's transition matrix applied to beta[p+ell]; beta[len-1] is the end.
                MathUtil.initArray(beta[len - 1], 1.0d);
                for (int p = len - 2; p >= 0; p--) {
                    MathUtil.initArray(beta[p], 0.0d);
                    for (int ell = 1; ell <= this.maxSegmentLength && p + ell < len; ell++) {
                        computeTransMatrix(lambda, seq, p + 1, p + ell, trans, true);
                        genStateVector(trans, beta[p + ell], beta[p], false);
                    }
                }

                // Forward pass. alpha is shifted by one: alpha[p + 1] belongs to position p and
                // alpha[0] is the start vector.
                MathUtil.initArray(alpha[0], 1.0d);
                double goldScore = 0.0d;
                int goldStart = 0;
                int goldEnd = -1;
                for (int end = 0; end < len; end++) {
                    if (goldEnd < end) {
                        goldStart = end;
                        goldEnd = seq.getSegmentEnd(end);
                    }
                    MathUtil.initArray(alpha[end + 1], 0.0d);
                    for (int ell = 1; ell <= this.maxSegmentLength && end - ell >= -1; ell++) {
                        int start = end - ell + 1;
                        double[] alphaPrev = alpha[start]; // forward vector for position start - 1
                        computeTransMatrix(lambda, seq, start, end, trans, true);
                        boolean isGold = start == goldStart && end == goldEnd;
                        this.featureGenerator.startScanFeaturesAt(seq, start, end);
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            int index = feature.getIndex();
                            int label = feature.getLabel();
                            int prevLabel = feature.getPrevLabel();
                            double value = feature.getValue();
                            if (isGold && firesOnGoldSegment(seq, start, end, label, prevLabel)) {
                                grad[index] += value;
                                goldScore += value * lambda[index];
                            }
                            if (prevLabel < 0) {
                                // Feature ignores the previous label: sum over every previous state.
                                for (int y = 0; y < trans.rows(); y++) {
                                    expected[index] += value * alphaPrev[y] * trans.getDouble(y, label)
                                            * beta[end][label];
                                }
                            } else {
                                expected[index] += value * alphaPrev[prevLabel]
                                        * trans.getDouble(prevLabel, label) * beta[end][label];
                            }
                        }
                        genStateVector(trans, alphaPrev, alpha[end + 1], true);
                    }
                }

                double z = MathUtil.sumArray(alpha[len]);
                objective += goldScore - Math.log(z);
                for (int k = 0; k < grad.length; k++) {
                    grad[k] -= expected[k] / z;
                }
            }
            return objective;
        } catch (Exception e) {
            throw new IllegalStateException("Failed to compute the segment CRF objective and gradient", e);
        }
    }

    /**
     * Log-space counterpart of {@link #computeFunctionGradient}: computes the same penalised
     * log-likelihood and gradient, but keeps forward/backward vectors and expected feature counts
     * as logarithms and combines them with {@code MathUtil.logSumExp}.
     *
     * @param dataset training sequences; rescanned from the start
     * @param lambda current feature weights
     * @param grad output array, overwritten with the gradient of the returned value
     * @return the sum over usable sequences of (gold-segmentation score - log Z), minus the
     *     quadratic penalty {@code sum(lambda^2) * invSigmaSquare / 2}
     * @throws IllegalStateException wrapping any exception raised during the computation
     */
    protected double computeFunctionGradientLL(Dataset dataset, double[] lambda, double[] grad) {
        try {
            int stateNum = this.model.getStateNum();
            DoubleFlatDenseMatrix logTrans = new DoubleFlatDenseMatrix(stateNum, stateNum);
            double[] logExpected = new double[this.featureGenerator.getFeatureNum()];
            double[][] logAlpha = null;
            double[][] logBeta = null;
            double objective = initGradientWithPrior(lambda, grad);

            dataset.startScan();
            while (dataset.hasNext()) {
                DataSequence seq = dataset.next();
                int len = seq.length();
                // See computeFunctionGradient: reject before touching grad; skip empty sequences.
                if (len == 0 || !goldSegmentsWithinLimit(seq)) {
                    continue;
                }
                MathUtil.initArray(logExpected, MathUtil.LOG0);
                logAlpha = growRowsIfNeeded(logAlpha, len + 1, stateNum);
                logBeta = growRowsIfNeeded(logBeta, len, stateNum);

                // Backward pass in log space; log(1) = 0 at the last position.
                MathUtil.initArray(logBeta[len - 1], 0.0d);
                for (int p = len - 2; p >= 0; p--) {
                    MathUtil.initArray(logBeta[p], MathUtil.LOG0);
                    for (int ell = 1; ell <= this.maxSegmentLength && p + ell < len; ell++) {
                        computeTransMatrix(lambda, seq, p + 1, p + ell, logTrans, false);
                        genStateVectorLog(logTrans, logBeta[p + ell], logBeta[p], false);
                    }
                }

                // Forward pass in log space, shifted by one as in the probability-space version.
                MathUtil.initArray(logAlpha[0], 0.0d);
                double goldScore = 0.0d;
                int goldStart = 0;
                int goldEnd = -1;
                for (int end = 0; end < len; end++) {
                    if (goldEnd < end) {
                        goldStart = end;
                        goldEnd = seq.getSegmentEnd(end);
                    }
                    MathUtil.initArray(logAlpha[end + 1], MathUtil.LOG0);
                    for (int ell = 1; ell <= this.maxSegmentLength && end - ell >= -1; ell++) {
                        int start = end - ell + 1;
                        double[] alphaPrev = logAlpha[start]; // log forward vector for position start - 1
                        computeTransMatrix(lambda, seq, start, end, logTrans, false);
                        boolean isGold = start == goldStart && end == goldEnd;
                        // Scan the same (start, end) span that computeTransMatrix and the gold test
                        // use; the previous code passed start - 1 here, which could be -1.
                        this.featureGenerator.startScanFeaturesAt(seq, start, end);
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            int index = feature.getIndex();
                            int label = feature.getLabel();
                            int prevLabel = feature.getPrevLabel();
                            double value = feature.getValue();
                            if (isGold && firesOnGoldSegment(seq, start, end, label, prevLabel)) {
                                grad[index] += value;
                                goldScore += value * lambda[index];
                            }
                            // NOTE(review): log-space accumulation assumes feature values are
                            // positive; confirm how MathUtil.log treats values <= 0.
                            double logValue = MathUtil.log(value);
                            if (prevLabel < 0) {
                                for (int y = 0; y < logTrans.rows(); y++) {
                                    logExpected[index] = MathUtil.logSumExp(logExpected[index],
                                            alphaPrev[y] + logTrans.getDouble(y, label) + logValue
                                                    + logBeta[end][label]);
                                }
                            } else {
                                logExpected[index] = MathUtil.logSumExp(logExpected[index],
                                        alphaPrev[prevLabel] + logTrans.getDouble(prevLabel, label) + logValue
                                                + logBeta[end][label]);
                            }
                        }
                        genStateVectorLog(logTrans, alphaPrev, logAlpha[end + 1], true);
                    }
                }

                double logZ = MathUtil.logSumExp(logAlpha[len]);
                objective += goldScore - logZ;
                for (int k = 0; k < grad.length; k++) {
                    grad[k] -= MathUtil.exp(logExpected[k] - logZ);
                }
            }
            return objective;
        } catch (Exception e) {
            throw new IllegalStateException("Failed to compute the segment CRF objective and gradient (log space)", e);
        }
    }

    /**
     * Overwrites {@code grad} with the gradient of the quadratic penalty and returns the penalty
     * term, {@code -sum(lambda^2) * invSigmaSquare / 2}, that seeds the objective.
     */
    private double initGradientWithPrior(double[] lambda, double[] grad) {
        double penalty = 0.0d;
        for (int k = 0; k < lambda.length; k++) {
            grad[k] = -lambda[k] * this.invSigmaSquare;
            penalty -= lambda[k] * lambda[k] * this.invSigmaSquare / 2.0d;
        }
        return penalty;
    }

    /**
     * Walks the gold segmentation of {@code seq}, using {@code getSegmentEnd}, and reports whether
     * every segment fits within {@code maxSegmentLength}.
     */
    private boolean goldSegmentsWithinLimit(DataSequence seq) {
        int pos = 0;
        while (pos < seq.length()) {
            int end = seq.getSegmentEnd(pos);
            if (end - pos + 1 > this.maxSegmentLength) {
                return false;
            }
            // Always advance, even if getSegmentEnd reports an end before pos.
            pos = Math.max(end, pos) + 1;
        }
        return true;
    }

    /**
     * Returns whether a feature with {@code label}/{@code prevLabel} matches the gold labels of the
     * segment {@code start..end}. A negative {@code prevLabel} matches any previous label.
     */
    private static boolean firesOnGoldSegment(DataSequence seq, int start, int end, int label, int prevLabel) {
        if (seq.getLabel(end) != label) {
            return false;
        }
        return prevLabel < 0 || (start > 0 && prevLabel == seq.getLabel(start - 1));
    }

    /**
     * Returns {@code rows} if it already has at least {@code needed} rows; otherwise allocates a
     * fresh {@code (2 * needed) x cols} array so later, longer sequences are less likely to force
     * another allocation.
     */
    private static double[][] growRowsIfNeeded(double[][] rows, int needed, int cols) {
        if (rows != null && rows.length >= needed) {
            return rows;
        }
        return new double[2 * needed][cols];
    }
}
