/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.ner;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.ViterbiUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.math.minimize.Minimizer;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.ner.ConditionalLikelihoodCostFunction;
import de.jungblut.ner.UnrollableDoubleVector;
import java.util.Collections;

/**
 * Maximum-entropy Markov model classifier.
 * <p>
 * Training fits a parameter matrix ({@code theta}) by minimizing a
 * {@link ConditionalLikelihoodCostFunction} with the configured
 * {@link Minimizer}. Prediction decodes with {@link ViterbiUtils#decode}.
 */
public final class MaxEntMarkovModel
extends AbstractClassifier {
    private final Minimizer minimizer;
    private final boolean verbose;
    private final int numIterations;
    private DoubleMatrix theta;
    private int classes;

    /**
     * Creates an untrained model that can be fitted with {@link #train}.
     *
     * @param minimizer the optimizer used to minimize the conditional likelihood cost.
     * @param numIterations the iteration budget handed to the minimizer.
     * @param verbose the verbosity flag handed to the minimizer.
     */
    public MaxEntMarkovModel(Minimizer minimizer, int numIterations, boolean verbose) {
        this.minimizer = minimizer;
        this.numIterations = numIterations;
        this.verbose = verbose;
    }

    /**
     * Creates a model from already-fitted parameters.
     * <p>
     * Such a model has no minimizer, so it can only predict; calling
     * {@link #train} on it throws {@link IllegalStateException}.
     *
     * @param theta the fitted parameter matrix.
     * @param classes the number of output classes.
     */
    public MaxEntMarkovModel(DenseDoubleMatrix theta, int classes) {
        this(null, -1, false);
        this.theta = theta;
        this.classes = classes;
    }

    /**
     * Fits {@code theta} to the given training data.
     * <p>
     * The number of classes is taken from the dimension of the first outcome
     * vector; a one-dimensional outcome is treated as two classes.
     *
     * @param features one feature vector per training example; the storage
     *          kind of the first vector (sparse or dense) selects the matrix type.
     * @param outcome one outcome vector per training example.
     * @throws NullPointerException if either array is {@code null}.
     * @throws IllegalArgumentException if the arrays are empty or differ in length.
     * @throws IllegalStateException if this model was created without a minimizer.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkNotNull(features, "features");
        Preconditions.checkNotNull(outcome, "outcome");
        Preconditions.checkArgument(features.length == outcome.length && features.length > 0,
                "There wasn't at least a single featurevector, or the two array didn't match in size.");
        Preconditions.checkState(this.minimizer != null,
                "This model was created without a minimizer and cannot be trained.");

        int outcomeDimension = outcome[0].getDimension();
        // A single output column encodes a binary problem, i.e. two classes.
        this.classes = outcomeDimension == 1 ? 2 : outcomeDimension;

        DoubleMatrix mat;
        if (features[0].isSparse()) {
            mat = new SparseDoubleRowMatrix(features);
        } else {
            mat = new DenseDoubleMatrix(features);
        }
        ConditionalLikelihoodCostFunction func =
                new ConditionalLikelihoodCostFunction(mat, new DenseDoubleMatrix(outcome));

        // Flattened parameter vector: one weight per (class, feature column) pair.
        DenseDoubleVector initialTheta = new DenseDoubleVector(mat.getColumnCount() * this.classes);
        DoubleVector input = this.minimizer.minimize(func, initialTheta, this.numIterations, this.verbose);

        // The flattened length is classes * columns, so this division is exact.
        this.theta = DenseMatrixFolder.unfoldMatrix(input, this.classes, input.getLength() / this.classes);
    }

    /**
     * Predicts the decoded output for a single example.
     *
     * @param features must be an {@link UnrollableDoubleVector}; its main vector
     *          and side vectors are forwarded to
     *          {@link #predict(DoubleVector, DoubleVector[])}.
     * @return the first row of the decoded matrix.
     * @throws IllegalArgumentException if {@code features} is not an
     *           {@link UnrollableDoubleVector} (including {@code null}).
     * @throws IllegalStateException if the model has not been trained.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        Preconditions.checkArgument(features instanceof UnrollableDoubleVector,
                "Features must be an instance of the class UnrollableDoubleVector.");
        UnrollableDoubleVector unrollable = (UnrollableDoubleVector) features;
        return this.predict(unrollable.getMainVector(), unrollable.getSideVectors());
    }

    /**
     * Returns the fitted parameter matrix.
     *
     * @return the parameter matrix, or {@code null} if the model has neither
     *         been trained nor constructed with parameters.
     */
    public DoubleMatrix getTheta() {
        return this.theta;
    }

    /**
     * Decodes a single example with {@link ViterbiUtils#decode}.
     *
     * @param feature the main feature vector, wrapped as a one-row sparse matrix.
     * @param featuresPerState additional feature vectors, wrapped as a sparse
     *          row matrix (presumably one row per state — verify against ViterbiUtils).
     * @return the first row of the decoded matrix.
     * @throws IllegalStateException if the model has not been trained.
     */
    public DoubleVector predict(DoubleVector feature, DoubleVector[] featuresPerState) {
        checkTrained();
        DoubleMatrix featureMatrix = new SparseDoubleRowMatrix(Collections.singletonList(feature));
        DoubleMatrix stateMatrix = new SparseDoubleRowMatrix(featuresPerState);
        return ViterbiUtils.decode(this.theta, featureMatrix, stateMatrix, this.classes).getRowVector(0);
    }

    /**
     * Decodes the given feature matrices with {@link ViterbiUtils#decode}.
     *
     * @param features the main feature matrix.
     * @param featuresPerState the additional per-state feature matrix.
     * @return the decoded matrix as returned by {@link ViterbiUtils#decode}.
     * @throws IllegalStateException if the model has not been trained.
     */
    public DoubleMatrix predict(DoubleMatrix features, DoubleMatrix featuresPerState) {
        checkTrained();
        return ViterbiUtils.decode(this.theta, features, featuresPerState, this.classes);
    }

    /** Fails fast instead of passing a null parameter matrix into the decoder. */
    private void checkTrained() {
        Preconditions.checkState(this.theta != null,
                "The model has not been trained yet, theta is not set.");
    }
}

