/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.math;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import java.util.Iterator;

/**
 * Viterbi (max-sum dynamic programming) decoding for sequence models whose per-label scores are
 * linear in a feature vector, as computed by {@link #computeScores}.
 *
 * <p>Static utility; not instantiable.
 */
public final class ViterbiUtils {

    private ViterbiUtils() {
        throw new IllegalAccessError();
    }

    /**
     * Finds the highest-scoring label sequence for the given observations.
     *
     * <p>The score of a path is the sum of the per-position scores. Position 0 is scored from
     * row 0 of {@code features}. Every later position is scored from the {@code featuresPerState}
     * row that belongs to the previous label on the path.
     *
     * @param weights weight matrix indexed as {@code weights.get(label, featureIndex)}
     * @param features one row per sequence position; its row count defines the sequence
     *     length {@code m}, and only row 0 is read for scoring
     * @param featuresPerState feature rows for positions {@code >= 1}, one per (position,
     *     previous label) pair; row {@code position * classes - 1 + prevLabel} is read
     * @param classes number of labels; must be at least 1
     * @return an {@code m x classes} one-hot matrix of the decoded labels, or an {@code m x 1}
     *     matrix holding the label index itself when {@code classes == 2}; an empty sequence
     *     yields a matrix with zero rows
     * @throws IllegalArgumentException if {@code classes < 1}
     */
    public static DoubleMatrix decode(DoubleMatrix weights, DoubleMatrix features,
            DoubleMatrix featuresPerState, int classes) {
        if (classes < 1) {
            throw new IllegalArgumentException("classes must be at least 1, but was " + classes);
        }
        final int m = features.getRowCount();
        if (m == 0) {
            // An empty sequence has an empty decoding; there is no row 0 to score.
            return new DenseDoubleMatrix(0, outputColumns(classes));
        }
        double[][] scores = new double[m][classes];
        int[][] backpointers = new int[m][classes];
        forwardPass(weights, features, featuresPerState, classes, scores, backpointers);
        return backtrack(backpointers, argMax(scores[m - 1]), classes);
    }

    /**
     * Forward pass. Fills {@code scores[p][label]} with the best cumulative score of any path
     * that ends in {@code label} at position {@code p}. Fills {@code backpointers[p][label]}
     * with the previous label on that path.
     */
    private static void forwardPass(DoubleMatrix weights, DoubleMatrix features,
            DoubleMatrix featuresPerState, int classes, double[][] scores,
            int[][] backpointers) {
        // Position 0 has no predecessor; its backpointers keep the array default of 0.
        scores[0] = computeScores(classes, features.getRowVector(0), weights);

        for (int position = 1; position < scores.length; position++) {
            // NOTE(review): this reads rows position*classes - 1 + prevLabel. If featuresPerState
            // is laid out as position*classes + prevLabel, this is off by one. Confirm against
            // the code that builds featuresPerState.
            final int rowOffset = position * classes - 1;
            double[] best = scores[position];
            int[] from = backpointers[position];
            for (int prevLabel = 0; prevLabel < classes; prevLabel++) {
                double[] stepScores = computeScores(classes,
                        featuresPerState.getRowVector(rowOffset + prevLabel), weights);
                double pathSoFar = scores[position - 1][prevLabel];
                for (int label = 0; label < classes; label++) {
                    double candidate = stepScores[label] + pathSoFar;
                    // The first predecessor always wins, so best[] starts from a real score
                    // instead of the array's 0.0 default. Later predecessors must be strictly
                    // better, so ties keep the lower previous label.
                    if (prevLabel == 0 || candidate > best[label]) {
                        best[label] = candidate;
                        from[label] = prevLabel;
                    }
                }
            }
        }
    }

    /** Returns the index of the largest value; ties keep the lowest index. */
    private static int argMax(double[] values) {
        int bestIndex = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[bestIndex]) {
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Follows the backpointers from {@code lastLabel} at the final position back to position 0.
     * Writes each visited label as one row of the result.
     */
    private static DenseDoubleMatrix backtrack(int[][] backpointers, int lastLabel, int classes) {
        final int m = backpointers.length;
        DenseDoubleMatrix outcome = new DenseDoubleMatrix(m, outputColumns(classes));
        int label = lastLabel;
        for (int position = m - 1; position >= 0; position--) {
            outcome.setRowVector(position, encodeLabel(label, classes));
            label = backpointers[position][label];
        }
        return outcome;
    }

    /** Width of the output matrix: one column holding the label for binary, else one-hot. */
    private static int outputColumns(int classes) {
        return classes == 2 ? 1 : classes;
    }

    /**
     * Encodes a label as an output row. For {@code classes == 2} the result is a single cell
     * holding the label value. Otherwise it is a one-hot vector of length {@code classes}.
     */
    private static DoubleVector encodeLabel(int label, int classes) {
        if (classes == 2) {
            DenseDoubleVector binary = new DenseDoubleVector(1);
            binary.set(0, (double) label);
            return binary;
        }
        DenseDoubleVector oneHot = new DenseDoubleVector(classes);
        oneHot.set(label, 1.0);
        return oneHot;
    }

    /**
     * Computes one score per label. For each label, the score is the sum of
     * {@code weights.get(label, index)} over the indices of the non-zero entries of
     * {@code features}.
     *
     * <p>NOTE(review): the feature values themselves are not multiplied in. This matches
     * {@code weights * features} only for 0/1-valued features; confirm that is intended.
     *
     * @param classes number of labels (length of the returned array)
     * @param features the feature vector to score
     * @param weights weight matrix indexed as {@code weights.get(label, featureIndex)}
     * @return a newly allocated array of length {@code classes}
     */
    static double[] computeScores(int classes, DoubleVector features, DoubleMatrix weights) {
        double[] scores = new double[classes];
        Iterator<DoubleVector.DoubleVectorElement> nonZero = features.iterateNonZero();
        while (nonZero.hasNext()) {
            final int featureIndex = nonZero.next().getIndex();
            for (int label = 0; label < classes; label++) {
                scores[label] += weights.get(label, featureIndex);
            }
        }
        return scores;
    }
}

