/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.nlp;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.ViterbiUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.writable.MatrixWritable;
import de.jungblut.writable.VectorWritable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.math3.util.FastMath;
import org.apache.hadoop.io.Writable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Hidden Markov Model over {@code numHiddenStates} hidden states and {@code numVisibleStates}
 * visible (observation) states.
 *
 * <p>Observations are passed as vectors. The training and forward/backward code treats every
 * non-zero index of an observation vector as an emitted visible state; the element value itself
 * is not used there. Tables are indexed as {@code transition[from][to]},
 * {@code emission[hidden][visible]} and {@code prior[hidden]}.
 *
 * <p>After {@link #trainSupervised} or {@link #trainUnsupervised}, the transition and emission
 * tables hold natural-log probabilities, while the hidden prior stays in linear space.
 * {@link #predict(DoubleVector)} relies on exactly that mix.
 */
public final class HMM extends AbstractClassifier implements Writable {
    private static final Logger LOG = LogManager.getLogger(HMM.class);
    private int numVisibleStates;
    private int numHiddenStates;
    private DoubleMatrix transitionProbabilityMatrix;
    private DoubleMatrix emissionProbabilityMatrix;
    private DoubleVector hiddenPriorProbability;
    /** Seed for the random initialisation used by {@link #trainUnsupervised}. */
    private final long seed;

    /**
     * Creates a model without any probability tables (they stay {@code null}), e.g. to be
     * populated via {@link #readFields(DataInput)}.
     */
    public HMM() {
        this.seed = System.currentTimeMillis();
    }

    /**
     * Creates a model with zero-filled tables and a time-based seed.
     *
     * @param numVisibleStates number of distinct observation symbols
     * @param numHiddenStates number of hidden states
     */
    public HMM(int numVisibleStates, int numHiddenStates) {
        this(numVisibleStates, numHiddenStates, System.currentTimeMillis());
    }

    /**
     * Creates a model with zero-filled tables and an explicit seed (useful for reproducible
     * unsupervised training).
     */
    HMM(int numVisibleStates, int numHiddenStates, long seed) {
        this.seed = seed;
        this.numVisibleStates = numVisibleStates;
        this.numHiddenStates = numHiddenStates;
        this.transitionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates, numHiddenStates);
        this.emissionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates, numVisibleStates);
        this.hiddenPriorProbability = new DenseDoubleVector(numHiddenStates);
    }

    /** Normalizes prior, transition rows and emission rows to sum to one (linear space). */
    private void normalizeProbabilities() {
        normalize(this.hiddenPriorProbability, this.transitionProbabilityMatrix, this.emissionProbabilityMatrix, false);
    }

    /** Normalizes like {@link #normalizeProbabilities()}, then takes the log of transition and emission rows. */
    private void logNormalizeProbabilities() {
        normalize(this.hiddenPriorProbability, this.transitionProbabilityMatrix, this.emissionProbabilityMatrix, true);
    }

    /**
     * Runs the forward algorithm on the current tables and returns the sum of the final alpha row.
     *
     * <p>NOTE(review): the tables are used as-is. After training they hold log-probabilities, which
     * this method does not convert back — confirm this is what callers expect.
     *
     * @param observationSequence observations in time order, at least one
     * @return the forward likelihood of the sequence
     * @throws IllegalArgumentException if the sequence is empty
     */
    public double estimateLikelihood(DoubleVector[] observationSequence) {
        Preconditions.checkArgument(observationSequence.length > 0, "Observation sequence must contain at least one element!");
        DoubleMatrix alpha = new DenseDoubleMatrix(observationSequence.length, this.numHiddenStates);
        return estimateLikelihood(forward(alpha, this.transitionProbabilityMatrix, this.emissionProbabilityMatrix, this.hiddenPriorProbability, observationSequence));
    }

    /** Returns the sum of the last row of a forward (alpha) matrix. */
    private static double estimateLikelihood(DoubleMatrix alpha) {
        return alpha.getRowVector(alpha.getRowCount() - 1).sum();
    }

    /**
     * Delegates Viterbi decoding to {@link ViterbiUtils#decode} using this model's emission table.
     *
     * @param observationSequence observations in time order
     * @param featuresPerHiddenState per-state feature vectors passed through to the decoder
     * @return the decoder's result matrix
     */
    public DoubleMatrix decode(DoubleVector[] observationSequence, DoubleVector[] featuresPerHiddenState) {
        return ViterbiUtils.decode(this.emissionProbabilityMatrix, (DoubleMatrix) new SparseDoubleRowMatrix(observationSequence), (DoubleMatrix) new SparseDoubleRowMatrix(featuresPerHiddenState), this.numHiddenStates);
    }

    /**
     * Trains the model with Baum-Welch (EM) on a single observation sequence.
     *
     * <p>The tables are first initialised randomly from {@code seed} and normalized. Each iteration
     * re-estimates prior, transitions and emissions from the forward/backward matrices. Training
     * stops after {@code maxIterations} iterations, or once the summed squared change of all tables
     * drops below {@code epsilon}. The final transition and emission tables are stored as
     * log-probabilities.
     *
     * <p>NOTE(review): alpha/beta are not scaled, so very long sequences may underflow — confirm
     * input lengths.
     *
     * @param features observation sequence in time order, at least one element
     * @param epsilon convergence threshold on the squared table difference
     * @param maxIterations upper bound on EM iterations
     * @param verbose whether to log the per-iteration model difference
     * @throws IllegalArgumentException if {@code features} is empty
     */
    public void trainUnsupervised(DoubleVector[] features, double epsilon, int maxIterations, boolean verbose) {
        Preconditions.checkArgument(features.length > 0, "Feature array length must be at least 1! Given: " + features.length);
        Random random = new Random(this.seed);
        this.transitionProbabilityMatrix = new DenseDoubleMatrix(this.numHiddenStates, this.numHiddenStates, random);
        this.emissionProbabilityMatrix = new DenseDoubleMatrix(this.numHiddenStates, this.numVisibleStates, random);
        this.hiddenPriorProbability = new DenseDoubleVector(this.numHiddenStates);
        for (int i = 0; i < this.numHiddenStates; ++i) {
            this.hiddenPriorProbability.set(i, random.nextDouble());
        }
        this.normalizeProbabilities();
        // Reused across iterations; forward/backward overwrite every cell they read.
        DoubleMatrix alpha = new DenseDoubleMatrix(features.length, this.numHiddenStates);
        DoubleMatrix beta = new DenseDoubleMatrix(features.length, this.numHiddenStates);
        for (int iteration = 0; iteration < maxIterations; ++iteration) {
            // Work on copies so the convergence check can compare old and new tables.
            DoubleMatrix nextTransition = this.transitionProbabilityMatrix.deepCopy();
            DoubleMatrix nextEmission = this.emissionProbabilityMatrix.deepCopy();
            alpha = forward(alpha, nextTransition, nextEmission, this.hiddenPriorProbability, features);
            beta = backward(beta, nextTransition, nextEmission, features);
            // Unnormalized gamma at t = 0; normalized below.
            DoubleVector nextPrior = alpha.getRowVector(0).multiply(beta.getRowVector(0));
            double modelLikelihood = estimateLikelihood(alpha);
            // Transitions must be re-estimated first: they read the not-yet-updated emission copy.
            reestimateTransitions(nextTransition, nextEmission, alpha, beta, features, modelLikelihood);
            reestimateEmissions(nextEmission, alpha, beta, features, modelLikelihood);
            normalize(nextPrior, nextTransition, nextEmission, false);
            double difference = this.transitionProbabilityMatrix.subtract(nextTransition).pow(2.0).sum()
                    + this.emissionProbabilityMatrix.subtract(nextEmission).pow(2.0).sum()
                    + this.hiddenPriorProbability.subtract(nextPrior).pow(2.0).sum();
            if (verbose) {
                LOG.info("Iteration " + iteration + " | Model difference: " + difference + "\r");
            }
            this.transitionProbabilityMatrix = nextTransition;
            this.emissionProbabilityMatrix = nextEmission;
            this.hiddenPriorProbability = nextPrior;
            if (difference < epsilon) {
                break;
            }
        }
        normalize(this.hiddenPriorProbability, this.transitionProbabilityMatrix, this.emissionProbabilityMatrix, true);
    }

    /**
     * Baum-Welch transition update, in place:
     * {@code a[i][j] *= sum_t alpha[t][i] * b_j(o_{t+1}) * beta[t+1][j] / P(O)}.
     */
    private static void reestimateTransitions(DoubleMatrix transition, DoubleMatrix emission, DoubleMatrix alpha, DoubleMatrix beta, DoubleVector[] features, double modelLikelihood) {
        int numHiddenStates = transition.getRowCount();
        for (int from = 0; from < numHiddenStates; ++from) {
            for (int to = 0; to < numHiddenStates; ++to) {
                double expected = 0.0;
                for (int t = 0; t < features.length - 1; ++t) {
                    Iterator<DoubleVector.DoubleVectorElement> active = features[t + 1].iterateNonZero();
                    while (active.hasNext()) {
                        expected += alpha.get(t, from) * emission.get(to, active.next().getIndex()) * beta.get(t + 1, to);
                    }
                }
                transition.set(from, to, transition.get(from, to) * expected / modelLikelihood);
            }
        }
    }

    /**
     * Baum-Welch emission update, in place: for every state {@code i} and symbol {@code j},
     * {@code b[i][j] = sum over t where o_t contains j of alpha[t][i] * beta[t][i] / P(O)}.
     * Symbols never observed get 0.
     */
    private static void reestimateEmissions(DoubleMatrix emission, DoubleMatrix alpha, DoubleMatrix beta, DoubleVector[] features, double modelLikelihood) {
        int numSymbols = emission.getColumnCount();
        for (int state = 0; state < emission.getRowCount(); ++state) {
            // One pass over the sequence per state; accumulation order per symbol is t-ascending.
            double[] expected = new double[numSymbols];
            for (int t = 0; t < features.length; ++t) {
                double occupancy = alpha.get(t, state) * beta.get(t, state);
                Iterator<DoubleVector.DoubleVectorElement> active = features[t].iterateNonZero();
                while (active.hasNext()) {
                    expected[active.next().getIndex()] += occupancy;
                }
            }
            for (int symbol = 0; symbol < numSymbols; ++symbol) {
                emission.set(state, symbol, expected[symbol] / modelLikelihood);
            }
        }
    }

    /**
     * Fills {@code beta} with the backward variables (unscaled) and returns it. The last row is all
     * ones.
     */
    private static DoubleMatrix backward(DoubleMatrix beta, DoubleMatrix transitionProbabilityMatrix, DoubleMatrix emissionProbabilityMatrix, DoubleVector[] features) {
        int numHiddenStates = beta.getColumnCount();
        beta.setRowVector(features.length - 1, DenseDoubleVector.ones(numHiddenStates));
        for (int t = features.length - 2; t >= 0; --t) {
            for (int i = 0; i < numHiddenStates; ++i) {
                double sum = 0.0;
                for (int j = 0; j < numHiddenStates; ++j) {
                    Iterator<DoubleVector.DoubleVectorElement> active = features[t + 1].iterateNonZero();
                    while (active.hasNext()) {
                        sum += beta.get(t + 1, j) * transitionProbabilityMatrix.get(i, j) * emissionProbabilityMatrix.get(j, active.next().getIndex());
                    }
                }
                beta.set(t, i, sum);
            }
        }
        return beta;
    }

    /** Fills {@code alpha} with the forward variables (unscaled) and returns it. */
    private static DoubleMatrix forward(DoubleMatrix alpha, DoubleMatrix transitionProbabilityMatrix, DoubleMatrix emissionProbabilityMatrix, DoubleVector hiddenPriorProbability, DoubleVector[] features) {
        int numHiddenStates = alpha.getColumnCount();
        for (int i = 0; i < numHiddenStates; ++i) {
            alpha.set(0, i, hiddenPriorProbability.get(i) * emissionSum(emissionProbabilityMatrix, i, features[0]));
        }
        for (int t = 1; t < features.length; ++t) {
            for (int i = 0; i < numHiddenStates; ++i) {
                double sum = 0.0;
                for (int j = 0; j < numHiddenStates; ++j) {
                    sum += alpha.get(t - 1, j) * transitionProbabilityMatrix.get(j, i);
                }
                alpha.set(t, i, sum * emissionSum(emissionProbabilityMatrix, i, features[t]));
            }
        }
        return alpha;
    }

    /** Sums the emission entries of {@code state} over every non-zero index of {@code observation}. */
    private static double emissionSum(DoubleMatrix emission, int state, DoubleVector observation) {
        double sum = 0.0;
        Iterator<DoubleVector.DoubleVectorElement> active = observation.iterateNonZero();
        while (active.hasNext()) {
            sum += emission.get(state, active.next().getIndex());
        }
        return sum;
    }

    /**
     * Normalizes the prior (in place) and every transition/emission row to sum to one. If
     * {@code log} is set, the transition and emission rows are then replaced by their logs; the
     * prior is never logged. Zero-sum vectors are left unscaled instead of becoming NaN.
     */
    private static void normalize(DoubleVector hiddenPriorProbability, DoubleMatrix transitionProbabilityMatrix, DoubleMatrix emissionProbabilityMatrix, boolean log) {
        double priorSum = hiddenPriorProbability.sum();
        if (priorSum != 0.0) {
            for (int i = 0; i < hiddenPriorProbability.getDimension(); ++i) {
                hiddenPriorProbability.set(i, hiddenPriorProbability.get(i) / priorSum);
            }
        }
        // Both matrices have numHiddenStates rows.
        for (int row = 0; row < transitionProbabilityMatrix.getRowCount(); ++row) {
            transitionProbabilityMatrix.setRowVector(row, normalizeRow(transitionProbabilityMatrix.getRowVector(row), log));
            emissionProbabilityMatrix.setRowVector(row, normalizeRow(emissionProbabilityMatrix.getRowVector(row), log));
        }
    }

    /** Scales {@code row} to sum to one (unless its sum is zero) and optionally takes the log. */
    private static DoubleVector normalizeRow(DoubleVector row, boolean log) {
        double sum = row.sum();
        DoubleVector normalized = sum != 0.0 ? row.divide(sum) : row;
        return log ? normalized.log() : normalized;
    }

    /**
     * Trains the model by counting, with add-one (Laplace) smoothing, over a labelled sequence.
     *
     * <p>{@code outcome[i]} is the hidden state at time {@code i}. It is either a one-dimensional
     * vector holding the state index (0 or 1, requiring {@code numHiddenStates == 2}) or a vector
     * of dimension {@code numHiddenStates}, whose max index is the state. All previous counts are
     * discarded. The final transition and emission tables are stored as log-probabilities.
     *
     * @param features observations in time order, each of dimension {@code numVisibleStates}
     * @param outcome hidden-state labels aligned with {@code features}
     * @throws IllegalArgumentException on empty input, length mismatch, or dimension mismatch
     */
    public void trainSupervised(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkArgument(features.length == outcome.length, "Feature array length must match outcome array length: " + features.length + " != " + outcome.length);
        Preconditions.checkArgument(features.length > 0, "Feature array length be at least 1! Given: " + features.length);
        Preconditions.checkArgument(features[0].getDimension() == this.numVisibleStates, "Feature vector's dimension must match the number of visible states! Given: " + features[0].getDimension() + ", but expected " + this.numVisibleStates);
        int outcomeDimension = outcome[0].getDimension();
        // A one-dimensional outcome is a binary label, i.e. it implies two hidden states.
        int impliedHiddenStates = outcomeDimension == 1 ? 2 : outcomeDimension;
        Preconditions.checkArgument(impliedHiddenStates == this.numHiddenStates, "Outcome dimension didn't match the given number of hidden states: " + outcomeDimension + " != " + this.numHiddenStates);
        // Laplace pseudo-counts. The prior is reset (not incremented) so that repeated calls do
        // not accumulate the previous model's prior.
        this.hiddenPriorProbability = DenseDoubleVector.ones(this.numHiddenStates);
        for (int row = 0; row < this.numHiddenStates; ++row) {
            this.transitionProbabilityMatrix.setRowVector(row, DenseDoubleVector.ones(this.numHiddenStates));
            this.emissionProbabilityMatrix.setRowVector(row, DenseDoubleVector.ones(this.numVisibleStates));
        }
        for (int i = 0; i < features.length; ++i) {
            int state = this.getOutcomeState(outcome[i]);
            this.hiddenPriorProbability.set(state, this.hiddenPriorProbability.get(state) + 1.0);
            Iterator<DoubleVector.DoubleVectorElement> active = features[i].iterateNonZero();
            while (active.hasNext()) {
                int symbol = active.next().getIndex();
                this.emissionProbabilityMatrix.set(state, symbol, this.emissionProbabilityMatrix.get(state, symbol) + 1.0);
            }
            if (i + 1 < features.length) {
                int nextState = this.getOutcomeState(outcome[i + 1]);
                this.transitionProbabilityMatrix.set(state, nextState, this.transitionProbabilityMatrix.get(state, nextState) + 1.0);
            }
        }
        this.logNormalizeProbabilities();
    }

    /** Same as {@link #trainSupervised(DoubleVector[], DoubleVector[])}. */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        this.trainSupervised(features, outcome);
    }

    /**
     * Returns the posterior over hidden states for a single observation, using the log emission
     * table and the linear prior.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        return this.toPosterior(this.emissionProbabilityMatrix.multiplyVectorRow(features));
    }

    /**
     * Returns the posterior over hidden states for an observation, given the previous state
     * distribution.
     *
     * <p>The transition contribution for state {@code i} is {@code sum_p prev[p] * T[p][i]}, i.e.
     * {@code T^T * previousOutcome}, because the transition table is indexed {@code [from][to]}.
     *
     * @param features the current observation
     * @param previousOutcome previous hidden-state vector of dimension {@code numHiddenStates}
     */
    public DoubleVector predict(DoubleVector features, DoubleVector previousOutcome) {
        DoubleVector emissionScores = this.emissionProbabilityMatrix.multiplyVectorRow(features);
        DoubleVector transitionScores = this.transitionProbabilityMatrix.transpose().multiplyVectorRow(previousOutcome);
        // add() returns a new vector; the result must be kept.
        return this.toPosterior(emissionScores.add(transitionScores));
    }

    /**
     * Converts per-state log scores (modified in place) into a normalized distribution weighted by
     * the linear prior. The max is subtracted before exponentiating for numerical stability.
     */
    private DoubleVector toPosterior(DoubleVector logScores) {
        double max = logScores.max();
        for (int state = 0; state < logScores.getDimension(); ++state) {
            logScores.set(state, FastMath.exp(logScores.get(state) - max) * this.hiddenPriorProbability.get(state));
        }
        return logScores.divide(logScores.sum());
    }

    public int getNumHiddenStates() {
        return this.numHiddenStates;
    }

    public int getNumVisibleStates() {
        return this.numVisibleStates;
    }

    /** Returns the emission table (live reference, not a copy). */
    public DoubleMatrix getEmissionProbabilitiyMatrix() {
        return this.emissionProbabilityMatrix;
    }

    /** Returns the hidden prior (live reference, not a copy). */
    public DoubleVector getHiddenPriorProbability() {
        return this.hiddenPriorProbability;
    }

    /** Returns the transition table (live reference, not a copy). */
    public DoubleMatrix getTransitionProbabilityMatrix() {
        return this.transitionProbabilityMatrix;
    }

    /**
     * Maps an outcome vector to a hidden-state index. A one-dimensional outcome holds the index
     * itself; any other outcome is treated as one-hot or scores, and its max index is used.
     */
    private int getOutcomeState(DoubleVector out) {
        return out.getDimension() == 1 ? (int) out.get(0) : out.maxIndex();
    }

    /** Serializes dimensions, prior, transition and emission tables, in that order. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(this.numVisibleStates);
        out.writeInt(this.numHiddenStates);
        VectorWritable.writeVector(this.hiddenPriorProbability, out);
        // Assumes both tables are dense, which holds for every table this class creates.
        MatrixWritable.writeDenseMatrix((DenseDoubleMatrix) this.transitionProbabilityMatrix, out);
        MatrixWritable.writeDenseMatrix((DenseDoubleMatrix) this.emissionProbabilityMatrix, out);
    }

    /** Reads the fields in the order written by {@link #write(DataOutput)}. The seed is not serialized. */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.numVisibleStates = in.readInt();
        this.numHiddenStates = in.readInt();
        this.hiddenPriorProbability = VectorWritable.readVector(in);
        this.transitionProbabilityMatrix = MatrixWritable.readDenseMatrix(in);
        this.emissionProbabilityMatrix = MatrixWritable.readDenseMatrix(in);
    }
}

