/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * Multi-class classifier that scores each label as a per-class default weight plus a
 * dot product between the instance's feature vector and that class's row of a flat
 * parameter array.
 *
 * <p>The parameter array is laid out row-major, one row per label. Each row holds
 * {@code defaultFeatureIndex + 1} entries: one weight per feature index
 * {@code 0 .. defaultFeatureIndex - 1}, followed by the default weight at offset
 * {@code defaultFeatureIndex}. {@code defaultFeatureIndex} is fixed to the size of the
 * data alphabet at construction time (or read back during deserialization), so every
 * method in this class uses it, rather than the current alphabet size, to compute the
 * row stride.
 */
public class MCMaxEnt extends Classifier implements Serializable {
    /** Flat, row-major weight array; see the class comment for the layout. */
    double[] parameters;
    /** Offset of the default weight inside each row; also the number of feature weights per row. */
    int defaultFeatureIndex;
    /** Feature selection shared by all classes, or {@code null}. */
    FeatureSelection featureSelection;
    /** One feature selection per class, or {@code null}; takes precedence over {@link #featureSelection} when non-null. */
    FeatureSelection[] perClassFeatureSelection;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;
    /** Marker written to the serialized stream in place of an absent (null) object. */
    static final int NULL_INTEGER = -1;

    /**
     * Creates a classifier over the given pipe and parameters.
     *
     * @param dataPipe pipe whose data alphabet size determines {@code defaultFeatureIndex}
     * @param parameters flat weight array, stored by reference (not copied)
     * @param featureSelection global feature selection, or {@code null}
     * @param perClassFeatureSelection per-class feature selections, or {@code null};
     *        asserted not to be supplied together with {@code featureSelection}
     */
    public MCMaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection featureSelection, FeatureSelection[] perClassFeatureSelection) {
        super(dataPipe);
        assert (featureSelection == null || perClassFeatureSelection == null);
        this.parameters = parameters;
        this.featureSelection = featureSelection;
        this.perClassFeatureSelection = perClassFeatureSelection;
        this.defaultFeatureIndex = dataPipe.getDataAlphabet().size();
    }

    /** Creates a classifier with a single global feature selection. */
    public MCMaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection featureSelection) {
        this(dataPipe, parameters, featureSelection, null);
    }

    /** Creates a classifier with per-class feature selections. */
    public MCMaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection[] perClassFeatureSelection) {
        this(dataPipe, parameters, null, perClassFeatureSelection);
    }

    /** Creates a classifier with no feature selection. */
    public MCMaxEnt(Pipe dataPipe, double[] parameters) {
        this(dataPipe, parameters, null, null);
    }

    /**
     * Returns the internal parameter array itself (not a copy); writes through the
     * returned reference change this classifier's weights.
     */
    public double[] getParameters() {
        return this.parameters;
    }

    /**
     * Overwrites a single weight.
     *
     * @param classIndex row (label) index
     * @param featureIndex offset within the row; {@code defaultFeatureIndex} addresses the default weight
     * @param value new weight
     */
    public void setParameter(int classIndex, int featureIndex, double value) {
        // Use the stride the array was laid out with, not the alphabet's current size,
        // so the write stays aligned with the scoring methods even if the alphabet grew.
        int numFeatures = this.defaultFeatureIndex + 1;
        this.parameters[classIndex * numFeatures + featureIndex] = value;
    }

    /**
     * Fills {@code scores} with the raw (un-exponentiated, un-normalized) score of each label.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @param scores output array; asserted to have one slot per label
     */
    public void getUnnormalizedClassificationScores(Instance instance, double[] scores) {
        this.computeRawScores(instance, scores);
    }

    /**
     * Fills {@code scores} with the raw label scores, exponentiated and normalized to sum to one.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @param scores output array; asserted to have one slot per label
     */
    public void getClassificationScores(Instance instance, double[] scores) {
        int numLabels = this.getLabelAlphabet().size();
        this.computeRawScores(instance, scores);
        // Subtract the maximum before exponentiating so the largest exponent is exp(0).
        double max = MatrixOps.max(scores);
        double sum = 0.0;
        for (int li = 0; li < numLabels; li++) {
            scores[li] = Math.exp(scores[li] - max);
            sum += scores[li];
        }
        for (int li = 0; li < numLabels; li++) {
            scores[li] /= sum;
        }
    }

    /**
     * Shared scoring loop: for every label, writes the row's default weight plus the
     * value of {@code MatrixOps.rowDotProduct} for that row and the instance's feature
     * vector. The per-class feature selection is passed when present, otherwise the
     * global one (which may be {@code null}).
     */
    private void computeRawScores(Instance instance, double[] scores) {
        int numFeatures = this.defaultFeatureIndex + 1;
        int numLabels = this.getLabelAlphabet().size();
        assert (scores.length == numLabels);
        FeatureVector fv = (FeatureVector)instance.getData();
        // Guard against a null pipe so the sanity check itself cannot throw.
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());
        for (int li = 0; li < numLabels; li++) {
            FeatureSelection selection = this.perClassFeatureSelection == null
                    ? this.featureSelection
                    : this.perClassFeatureSelection[li];
            scores[li] = this.parameters[li * numFeatures + this.defaultFeatureIndex]
                    + MatrixOps.rowDotProduct(this.parameters, numFeatures, li, fv, this.defaultFeatureIndex, selection);
        }
    }

    /**
     * Classifies an instance by computing normalized scores for every label.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @return a classification wrapping a label vector of the normalized scores
     */
    @Override
    public Classification classify(Instance instance) {
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        this.getClassificationScores(instance, scores);
        return new Classification(instance, this, new LabelVector(this.getLabelAlphabet(), scores));
    }

    /** Prints, for every label, its default weight followed by each feature name and weight to standard output. */
    @Override
    public void print() {
        Alphabet dict = this.getAlphabet();
        LabelAlphabet labelDict = this.getLabelAlphabet();
        // Same stride as the scoring methods; the alphabet's current size may differ
        // from the size the parameter array was laid out with.
        int numFeatures = this.defaultFeatureIndex + 1;
        int numLabels = labelDict.size();
        for (int li = 0; li < numLabels; li++) {
            int rowStart = li * numFeatures;
            System.out.println("FEATURES FOR CLASS " + labelDict.lookupObject(li));
            System.out.println(" <default> " + this.parameters[rowStart + this.defaultFeatureIndex]);
            for (int i = 0; i < this.defaultFeatureIndex; i++) {
                Object name = dict.lookupObject(i);
                double weight = this.parameters[rowStart + i];
                System.out.println(" " + name + " " + weight);
            }
        }
    }

    /**
     * Custom serialization. Stream layout: version, instance pipe, parameter count and
     * values, {@code defaultFeatureIndex}, then the global feature selection and the
     * per-class feature selections, each preceded by an int marker
     * ({@code NULL_INTEGER} when absent).
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        int np = this.parameters.length;
        out.writeInt(np);
        for (int p = 0; p < np; p++) {
            out.writeDouble(this.parameters[p]);
        }
        out.writeInt(this.defaultFeatureIndex);
        if (this.featureSelection == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            // 1 marks "object follows"; readObject checks for exactly this value.
            out.writeInt(1);
            out.writeObject(this.featureSelection);
        }
        if (this.perClassFeatureSelection == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            out.writeInt(this.perClassFeatureSelection.length);
            for (FeatureSelection selection : this.perClassFeatureSelection) {
                if (selection == null) {
                    out.writeInt(NULL_INTEGER);
                } else {
                    out.writeInt(1);
                    out.writeObject(selection);
                }
            }
        }
    }

    /**
     * Custom deserialization; reads the layout produced by {@code writeObject}.
     *
     * @throws ClassNotFoundException if the stream's version differs from
     *         {@code CURRENT_SERIAL_VERSION}, or if a contained object's class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched MCMaxEnt versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe)in.readObject();
        int np = in.readInt();
        this.parameters = new double[np];
        for (int p = 0; p < np; p++) {
            this.parameters[p] = in.readDouble();
        }
        this.defaultFeatureIndex = in.readInt();
        int opt = in.readInt();
        if (opt == 1) {
            this.featureSelection = (FeatureSelection)in.readObject();
        }
        // A negative count (NULL_INTEGER) means no per-class selections were written.
        int nfs = in.readInt();
        if (nfs >= 0) {
            this.perClassFeatureSelection = new FeatureSelection[nfs];
            for (int i = 0; i < nfs; i++) {
                opt = in.readInt();
                if (opt == 1) {
                    this.perClassFeatureSelection[i] = (FeatureSelection)in.readObject();
                }
            }
        }
    }
}

