/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.types.AbstractFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.ParameterizedFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.Matrices;
import cc.mallet.types.Matrix;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Randoms;

/**
 * A Potts-style factor over a set of discrete variables {@code xs}, parameterized by a
 * continuous variable {@code alpha}.
 *
 * <p>For a fixed value of alpha, the factor is materialized by {@link #sliceForAlpha} as a
 * log-table equal to {@code diag(alpha) + constant(-alpha)}. Assuming {@code Matrices.diag}
 * places alpha on the entries where all variables agree, the log-potential is {@code 0} when
 * every variable in {@code xs} takes the same outcome and {@code -alpha} otherwise. This is the
 * shape assumed by {@link #sumGradLog}.
 *
 * <p>Most generic factor operations (marginalization, max-extraction, normalization,
 * sampling, direct index lookup) are not supported and throw
 * {@link UnsupportedOperationException}.
 */
public class PottsTableFactor
extends AbstractFactor
implements ParameterizedFactor {
    /** The continuous parameter controlling the penalty for disagreeing assignments. */
    private final Variable alpha;
    /** The discrete variables the Potts potential is defined over. */
    private final VarSet xs;

    /**
     * Creates a Potts factor over an arbitrary set of discrete variables.
     *
     * @param xs    the discrete variables
     * @param alpha the continuous parameter variable
     * @throws IllegalArgumentException if {@code alpha} is not continuous
     */
    public PottsTableFactor(VarSet xs, Variable alpha) {
        super(PottsTableFactor.combineVariables(alpha, xs));
        this.alpha = alpha;
        this.xs = xs;
    }

    /**
     * Creates a pairwise Potts factor over two discrete variables.
     *
     * @param x1    the first discrete variable
     * @param x2    the second discrete variable
     * @param alpha the continuous parameter variable
     * @throws IllegalArgumentException if {@code alpha} is not continuous
     */
    public PottsTableFactor(Variable x1, Variable x2, Variable alpha) {
        super(new HashVarSet(new Variable[]{x1, x2, PottsTableFactor.requireContinuous(alpha)}));
        this.alpha = alpha;
        this.xs = new HashVarSet(new Variable[]{x1, x2});
    }

    /**
     * Validates that {@code alpha} is continuous. It is called from the {@code super(...)}
     * arguments so a bad parameter is rejected before the base-class constructor runs.
     *
     * @return {@code alpha}, unchanged
     * @throws IllegalArgumentException if {@code alpha} is not continuous
     */
    private static Variable requireContinuous(Variable alpha) {
        if (!alpha.isContinuous()) {
            throw new IllegalArgumentException("alpha must be continuous");
        }
        return alpha;
    }

    /** Returns a new var set containing {@code xs} plus {@code alpha}, after validating alpha. */
    private static VarSet combineVariables(Variable alpha, VarSet xs) {
        HashVarSet ret = new HashVarSet(xs);
        ret.add(PottsTableFactor.requireContinuous(alpha));
        return ret;
    }

    @Override
    protected Factor extractMaxInternal(VarSet varSet) {
        throw new UnsupportedOperationException();
    }

    @Override
    protected double lookupValueInternal(int i) {
        throw new UnsupportedOperationException();
    }

    @Override
    protected Factor marginalizeInternal(VarSet varsToKeep) {
        throw new UnsupportedOperationException();
    }

    /**
     * Evaluates the factor at the iterator's current assignment. The assignment must supply
     * a value for {@code alpha} as well as for {@code xs}.
     */
    @Override
    public double value(AssignmentIterator it) {
        Assignment assn = it.assignment();
        Factor tbl = this.sliceForAlpha(assn);
        return tbl.value(assn);
    }

    /**
     * Builds the table factor over {@code xs} obtained by fixing alpha to its value in
     * {@code assn}.
     */
    private Factor sliceForAlpha(Assignment assn) {
        double alph = assn.getDouble(this.alpha);
        int[] sizes = this.sizesFromVarSet(this.xs);
        Matrix diag = Matrices.diag(sizes, alph);
        Matrix matrix = Matrices.constant(sizes, -alph);
        matrix.plusEquals(diag);
        // NOTE(review): assumes Matrices.constant returns a SparseMatrixn; confirm in Matrices.
        return LogTableFactor.makeFromLogMatrix(this.xs.toVariableArray(), (SparseMatrixn) matrix);
    }

    /** Returns the number of outcomes of each variable in {@code vars}, in var-set order. */
    private int[] sizesFromVarSet(VarSet vars) {
        int[] szs = new int[vars.size()];
        for (int i = 0; i < szs.length; ++i) {
            szs[i] = vars.get(i).getNumOutcomes();
        }
        return szs;
    }

    @Override
    public Factor normalize() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Assignment sample(Randoms r) {
        throw new UnsupportedOperationException();
    }

    /** Returns the natural log of {@link #value(AssignmentIterator)}. */
    @Override
    public double logValue(AssignmentIterator it) {
        return Math.log(this.value(it));
    }

    /**
     * Fixes alpha (and any other variables present in {@code assn}) and returns the
     * resulting factor. {@code assn} must contain a value for alpha.
     */
    @Override
    public Factor slice(Assignment assn) {
        Factor alphSlice = this.sliceForAlpha(assn);
        return alphSlice.slice(assn);
    }

    @Override
    public String dumpToString() {
        StringBuilder buf = new StringBuilder();
        buf.append("[Potts: alpha:");
        buf.append(this.alpha);
        buf.append(" xs:");
        buf.append(this.xs);
        buf.append("]");
        return buf.toString();
    }

    /**
     * Returns the expectation under {@code q} of the derivative of the log-potential with
     * respect to alpha. That derivative is {@code -1} on assignments where the variables
     * disagree and {@code 0} otherwise, so the result is the negated probability mass
     * {@code q} puts on disagreeing assignments.
     *
     * @param q     a distribution whose domain includes {@code xs}
     * @param param must be this factor's {@code alpha} variable
     * @param theta unused
     * @throws IllegalArgumentException if {@code param} is not this factor's alpha
     */
    @Override
    public double sumGradLog(Factor q, Variable param, Assignment theta) {
        this.checkParam(param);
        return -this.disagreementMass(q);
    }

    /**
     * Returns the variance under {@code q} of the derivative of the log-potential with
     * respect to alpha.
     *
     * @param q     a distribution whose domain includes {@code xs}
     * @param param must be this factor's {@code alpha} variable
     * @param theta unused
     * @throws IllegalArgumentException if {@code param} is not this factor's alpha
     */
    public double secondDerivative(Factor q, Variable param, Assignment theta) {
        this.checkParam(param);
        double mass = this.disagreementMass(q);
        // The gradient g is -1 on disagreeing assignments and 0 otherwise, so
        // E[g] = -mass and E[g^2] = mass; Var[g] = mass - mass^2. Computing the mass once
        // avoids marginalizing q and scanning xs twice.
        return mass - mass * mass;
    }

    /** Throws if {@code param} is not this factor's alpha variable. */
    private void checkParam(Variable param) {
        if (param != this.alpha) {
            throw new IllegalArgumentException(
                "param " + param + " is not this factor's alpha " + this.alpha);
        }
    }

    /**
     * Sums {@code q}'s marginal over {@code xs} across every assignment in which the
     * variables do not all take the same value.
     */
    private double disagreementMass(Factor q) {
        Factor qXs = q.marginalize(this.xs);
        double mass = 0.0;
        AssignmentIterator it = this.xs.assignmentIterator();
        while (it.hasNext()) {
            if (!this.isAllEqual(it.assignment())) {
                mass += qXs.value(it);
            }
            it.advance();
        }
        return mass;
    }

    /**
     * Returns true if every variable in {@code xs} has an equal value in {@code assn}.
     * An empty {@code xs} is vacuously all-equal.
     */
    private boolean isAllEqual(Assignment assn) {
        int n = this.xs.size();
        if (n == 0) {
            return true;
        }
        Object first = assn.getObject(this.xs.get(0));
        for (int i = 1; i < n; ++i) {
            if (!first.equals(assn.getObject(this.xs.get(i)))) {
                return false;
            }
        }
        return true;
    }

    @Override
    public Factor duplicate() {
        return new PottsTableFactor(this.xs, this.alpha);
    }

    @Override
    public boolean isNaN() {
        return false;
    }

    /** Exact equality; {@code epsilon} is ignored because the factor has no numeric table. */
    @Override
    public boolean almostEquals(Factor p, double epsilon) {
        return this.equals(p);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        PottsTableFactor that = (PottsTableFactor) o;
        boolean sameAlpha = this.alpha == null ? that.alpha == null : this.alpha.equals(that.alpha);
        boolean sameXs = this.xs == null ? that.xs == null : this.xs.equals(that.xs);
        return sameAlpha && sameXs;
    }

    @Override
    public int hashCode() {
        int result = this.alpha != null ? this.alpha.hashCode() : 0;
        result = 29 * result + (this.xs != null ? this.xs.hashCode() : 0);
        return result;
    }
}

