/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.BidirectionalIntObjectMap;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.Flops;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.Collection;

/**
 * Discrete factor whose entries are held directly (not in log space) in the
 * inherited {@code probs} matrix: the {@code logValue} accessors take
 * {@link Math#log} of the stored entry and the {@code setLogValue} mutators
 * apply {@link Math#exp} before storing.
 */
public class TableFactor
extends AbstractTableFactor {
    /**
     * Array convenience overload of {@link #multiplyAll(Collection)}.
     *
     * @param phis factors to multiply together
     * @return the result of {@link #multiplyAll(Collection)} on the array viewed as a list
     */
    public static DiscreteFactor multiplyAll(Factor[] phis) {
        return TableFactor.multiplyAll(Arrays.asList(phis));
    }

    /**
     * Multiplies a collection of factors into one table factor.
     *
     * <p>A single-element collection yields a duplicate of that element. Otherwise a new
     * {@code TableFactor} over the union of all the factors' variable sets is created and
     * each factor is multiplied into it in iteration order.
     *
     * @param phis collection whose elements must all be {@link Factor} instances
     * @return the duplicated single factor, or the accumulated product
     * @throws ClassCastException if an element is not a {@link Factor}, or if the single
     *         element's duplicate is not an {@link AbstractTableFactor}
     */
    public static AbstractTableFactor multiplyAll(Collection phis) {
        if (phis.size() == 1) {
            Factor only = (Factor)phis.iterator().next();
            return (AbstractTableFactor)only.duplicate();
        }
        // The parameter is a raw Collection, so its elements are statically typed as Object
        // and must be cast explicitly; a for-each over Factor does not compile here.
        HashVarSet allVars = new HashVarSet();
        for (Object element : phis) {
            allVars.addAll(((Factor)element).varSet());
        }
        // NOTE(review): relies on a freshly constructed factor being a multiplicative
        // identity -- confirm against the AbstractTableFactor(Collection) constructor.
        TableFactor product = new TableFactor(allVars);
        for (Object element : phis) {
            product.multiplyBy((Factor)element);
        }
        return product;
    }

    /** Creates a factor over a single variable; see the matching superclass constructor. */
    public TableFactor(Variable var) {
        super(var);
    }

    /** Creates a factor over a single variable with the given values. */
    public TableFactor(Variable var, double[] values) {
        super(var, values);
    }

    /** Creates a factor using the superclass no-argument constructor. */
    public TableFactor() {
    }

    /** Creates a factor from a variable map; see the matching superclass constructor. */
    public TableFactor(BidirectionalIntObjectMap varMap) {
        super(varMap);
    }

    /** Creates a factor over the given variables. */
    public TableFactor(Variable[] allVars) {
        super(allVars);
    }

    /** Creates a factor over the given collection of variables. */
    public TableFactor(Collection allVars) {
        super(allVars);
    }

    /** Creates a factor over the given variables with the given values. */
    public TableFactor(Variable[] allVars, double[] probs) {
        super(allVars, probs);
    }

    /** Creates a factor over the given variable set with the given values. */
    public TableFactor(VarSet allVars, double[] probs) {
        super(allVars, probs);
    }

    /** Creates a factor over the given variables with values taken from a matrix. */
    public TableFactor(Variable[] allVars, Matrix probsIn) {
        super(allVars, probsIn);
    }

    /**
     * Copy constructor. After delegating to the superclass, the value matrix is replaced
     * by a clone of {@code in}'s value matrix so the two factors do not share storage.
     *
     * @param in factor to copy
     */
    public TableFactor(AbstractTableFactor in) {
        super(in);
        this.probs = (Matrix)in.getValueMatrix().cloneMatrix();
    }

    /** Creates a factor over the given variable set with values taken from a matrix. */
    public TableFactor(VarSet allVars, Matrix probsIn) {
        super(allVars, probsIn);
    }

    /** Creates a factor from another factor and a values array; see the superclass constructor. */
    public TableFactor(AbstractTableFactor ptl, double[] probs) {
        super(ptl, probs);
    }

    /** Sets every entry to 1.0. */
    @Override
    void setAsIdentity() {
        this.setAll(1.0);
    }

    /** @return a new {@code TableFactor} built with the copy constructor */
    @Override
    public Factor duplicate() {
        return new TableFactor(this);
    }

    /** @return a new {@code TableFactor} over {@code vars} */
    @Override
    protected AbstractTableFactor createBlankSubset(Variable[] vars) {
        return new TableFactor(vars);
    }

    /**
     * Normalizes the value matrix in place via {@code oneNormalize()}.
     *
     * @return this factor
     */
    @Override
    public Factor normalize() {
        Flops.increment(2 * this.probs.numLocations());
        this.probs.oneNormalize();
        return this;
    }

    /** @return {@code oneNorm()} of the value matrix */
    @Override
    public double sum() {
        Flops.increment(this.probs.numLocations());
        return this.probs.oneNorm();
    }

    /** @return natural log of the entry at the iterator's current assignment index */
    @Override
    public double logValue(AssignmentIterator it) {
        Flops.log();
        return Math.log(this.rawValue(it.indexOfCurrentAssn()));
    }

    /** @return natural log of the entry for {@code assn} */
    @Override
    public double logValue(Assignment assn) {
        Flops.log();
        return Math.log(this.rawValue(assn));
    }

    /** @return natural log of the entry at single index {@code loc} */
    @Override
    public double logValue(int loc) {
        Flops.log();
        return Math.log(this.rawValue(loc));
    }

    /** @return the stored entry for {@code assn} */
    @Override
    public double value(Assignment assn) {
        return this.rawValue(assn);
    }

    /** @return the stored entry at single index {@code loc} */
    @Override
    public double value(int loc) {
        return this.rawValue(loc);
    }

    /** @return the stored entry at the iterator's current assignment index */
    @Override
    public double value(AssignmentIterator assn) {
        return this.rawValue(assn.indexOfCurrentAssn());
    }

    /**
     * Sums this factor's entries into {@code result}: {@code result} is zeroed, then each
     * stored entry is added to the {@code result} entry selected by
     * {@code largeIdxToSmall(result)}.
     *
     * @param result factor that receives the sums; overwritten
     * @return {@code result}
     */
    @Override
    protected Factor marginalizeInternal(AbstractTableFactor result) {
        result.setAll(0.0);
        int[] projection = this.largeIdxToSmall(result);
        int numLocs = this.probs.numLocations();
        for (int largeLoc = 0; largeLoc < numLocs; ++largeLoc) {
            double contribution = this.probs.valueAtLocation(largeLoc);
            result.probs.incrementSingleValue(projection[largeLoc], contribution);
        }
        Flops.increment(numLocs);
        return result;
    }

    /**
     * Multiplies each stored entry, in place, by the entry of {@code ptl} selected by
     * {@code largeIdxToSmall(ptl)}.
     */
    @Override
    protected void multiplyByInternal(DiscreteFactor ptl) {
        int[] projection = this.largeIdxToSmall(ptl);
        int numLocs = this.probs.numLocations();
        for (int loc = 0; loc < numLocs; ++loc) {
            double current = this.probs.valueAtLocation(loc);
            double multiplier = ptl.value(projection[loc]);
            this.probs.setValueAtLocation(loc, current * multiplier);
        }
        Flops.increment(numLocs);
    }

    /**
     * Divides each stored entry, in place, by the entry of {@code ptl} selected by
     * {@code largeIdxToSmall(ptl)}. Where the divisor is almost equal to zero
     * (per {@code Maths.almostEquals}) the result is set to 0.0 instead of dividing.
     */
    @Override
    protected void divideByInternal(DiscreteFactor ptl) {
        int[] projection = this.largeIdxToSmall(ptl);
        int numLocs = this.probs.numLocations();
        for (int loc = 0; loc < numLocs; ++loc) {
            double current = this.probs.valueAtLocation(loc);
            double divisor = ptl.value(projection[loc]);
            // Division by (almost) zero yields 0.0 rather than Infinity/NaN.
            double quotient = Maths.almostEquals(divisor, 0.0) ? 0.0 : current / divisor;
            this.probs.setValueAtLocation(loc, quotient);
        }
        Flops.increment(numLocs);
    }

    /**
     * Adds to each stored entry, in place, the entry of {@code ptl} selected by
     * {@code largeIdxToSmall(ptl)}.
     */
    @Override
    protected void plusEqualsInternal(DiscreteFactor ptl) {
        int[] projection = this.largeIdxToSmall(ptl);
        int numLocs = this.probs.numLocations();
        for (int loc = 0; loc < numLocs; ++loc) {
            double current = this.probs.valueAtLocation(loc);
            double addend = ptl.value(projection[loc]);
            this.probs.setValueAtLocation(loc, current + addend);
        }
        Flops.increment(numLocs);
    }

    /**
     * Looks up the stored entry for an assignment by reading, for each of this factor's
     * variables in order, the value {@code assn} gives it.
     *
     * @param assn assignment queried via {@code assn.get(var)} for every variable of this factor
     * @return the stored entry, as returned by {@link #rawValue(int)}
     */
    protected double rawValue(Assignment assn) {
        int numVars = this.getNumVars();
        int[] indices = new int[numVars];
        for (int i = 0; i < numVars; ++i) {
            indices[i] = assn.get(this.getVariable(i));
        }
        return this.rawValue(indices);
    }

    /** Converts per-variable indices to a single index and delegates to {@link #rawValue(int)}. */
    private double rawValue(int[] indices) {
        return this.rawValue(this.probs.singleIndex(indices));
    }

    /**
     * Returns the stored entry for a single index.
     *
     * @param singleIdx single index into the value matrix
     * @return 0.0 when {@code probs.location(singleIdx)} is negative (presumably meaning no
     *         entry is stored for that index -- confirm against {@code Matrix}); otherwise
     *         the value at that location
     */
    @Override
    protected double rawValue(int singleIdx) {
        int loc = this.probs.location(singleIdx);
        return loc < 0 ? 0.0 : this.probs.valueAtLocation(loc);
    }

    /** Raises every stored entry to {@code power}, in place. */
    @Override
    public void exponentiate(double power) {
        int numLocs = this.probs.numLocations();
        for (int loc = 0; loc < numLocs; ++loc) {
            double base = this.probs.valueAtLocation(loc);
            this.probs.setValueAtLocation(loc, Math.pow(base, power));
        }
        Flops.pow(numLocs);
    }

    /** Stores {@code exp(logValue)} as the entry for {@code assn}. */
    @Override
    public void setLogValue(Assignment assn, double logValue) {
        Flops.exp();
        this.setRawValue(assn, Math.exp(logValue));
    }

    /** Stores {@code exp(logValue)} as the entry at the iterator's current assignment. */
    @Override
    public void setLogValue(AssignmentIterator assnIt, double logValue) {
        Flops.exp();
        this.setRawValue(assnIt, Math.exp(logValue));
    }

    /** Stores {@code value} as the entry at the iterator's current assignment. */
    @Override
    public void setValue(AssignmentIterator assnIt, double value) {
        this.setRawValue(assnIt, value);
    }

    /** Stores {@code exp(vals[i])} via {@code setRawValue(i, ...)} for every index of {@code vals}. */
    @Override
    public void setLogValues(double[] vals) {
        Flops.exp(vals.length);
        for (int i = 0; i < vals.length; ++i) {
            this.setRawValue(i, Math.exp(vals[i]));
        }
    }

    /** Stores {@code vals[i]} via {@code setRawValue(i, ...)} for every index of {@code vals}. */
    @Override
    public void setValues(double[] vals) {
        for (int i = 0; i < vals.length; ++i) {
            this.setRawValue(i, vals[i]);
        }
    }

    /** Multiplies every stored entry by {@code v}, in place. */
    @Override
    public void timesEquals(double v) {
        Flops.increment(this.probs.numLocations());
        this.probs.timesEquals(v);
    }

    /**
     * Adds {@code v} to the entry at {@code loc}.
     *
     * <p>NOTE(review): {@code loc} is read through {@code valueAtLocation} but written
     * through {@code setRawValue(int, double)}; confirm both interpret the int the same way.
     */
    @Override
    protected void plusEqualsAtLocation(int loc, double v) {
        Flops.increment(1);
        double current = this.valueAtLocation(loc);
        this.setRawValue(loc, current + v);
    }

    /** @return the backing value matrix itself (not a copy) */
    @Override
    public Matrix getValueMatrix() {
        return this.probs;
    }

    /** @return a clone of the value matrix with {@link Math#log} applied to every location */
    @Override
    public Matrix getLogValueMatrix() {
        int numLocs = this.probs.numLocations();
        Flops.log(numLocs);
        Matrix logProbs = (Matrix)this.probs.cloneMatrix();
        for (int loc = 0; loc < numLocs; ++loc) {
            logProbs.setValueAtLocation(loc, Math.log(logProbs.valueAtLocation(loc)));
        }
        return logProbs;
    }

    /** @return the entry stored at location {@code idx} of the value matrix */
    @Override
    public double valueAtLocation(int idx) {
        return this.probs.valueAtLocation(idx);
    }

    /**
     * Builds a factor over {@code var} alone: for each outcome of {@code var}, the entry is
     * this factor's value at the union of that outcome with {@code observed}.
     */
    @Override
    protected Factor slice_onevar(Variable var, Assignment observed) {
        int numOutcomes = var.getNumOutcomes();
        double[] vals = new double[numOutcomes];
        for (int outcome = 0; outcome < numOutcomes; ++outcome) {
            Assignment single = new Assignment(var, outcome);
            vals[outcome] = this.value(Assignment.union(single, observed));
        }
        return new TableFactor(var, vals);
    }

    /**
     * Builds a factor over {@code v1} and {@code v2}: for each pair of outcomes, the entry is
     * this factor's value at the union of that pair with {@code observed}, stored at
     * {@code Matrixn.singleIndex} of the pair.
     */
    @Override
    protected Factor slice_twovar(Variable v1, Variable v2, Assignment observed) {
        int n1 = v1.getNumOutcomes();
        int n2 = v2.getNumOutcomes();
        int[] sizes = new int[]{n1, n2};
        Variable[] pair = new Variable[]{v1, v2};
        int[] outcomes = new int[2];
        double[] vals = new double[n1 * n2];
        for (int i = 0; i < n1; ++i) {
            outcomes[0] = i;
            for (int j = 0; j < n2; ++j) {
                outcomes[1] = j;
                Assignment pairAssn = new Assignment(pair, outcomes);
                Assignment full = Assignment.union(pairAssn, observed);
                int idx = Matrixn.singleIndex(sizes, new int[]{i, j});
                vals[idx] = this.value(full);
            }
        }
        return new TableFactor(new Variable[]{v1, v2}, vals);
    }

    /**
     * Builds a factor over the variables of {@code vars} that are not in {@code observed}:
     * each entry is this factor's value at the union of {@code observed} with one
     * assignment to the kept variables.
     */
    @Override
    protected Factor slice_general(Variable[] vars, Assignment observed) {
        HashVarSet toKeep = new HashVarSet(vars);
        toKeep.removeAll(observed.varSet());
        double[] vals = new double[toKeep.weight()];
        for (AssignmentIterator it = toKeep.assignmentIterator(); it.hasNext(); it.advance()) {
            Assignment full = Assignment.union(observed, it.assignment());
            vals[it.indexOfCurrentAssn()] = this.value(full);
        }
        return new TableFactor((VarSet)toKeep, vals);
    }

    /**
     * Creates a factor over {@code domain} whose entries are {@code exp(vals[i])}.
     * The input array is not modified.
     *
     * @param domain variable set of the new factor
     * @param vals log-space values
     * @return the new factor
     */
    public static TableFactor makeFromLogValues(VarSet domain, double[] vals) {
        double[] expVals = new double[vals.length];
        for (int i = 0; i < vals.length; ++i) {
            expVals[i] = Math.exp(vals[i]);
        }
        return new TableFactor(domain, expVals);
    }

    /**
     * Scales the factor in place by the reciprocal of the value at {@code argmax()}.
     *
     * <p>NOTE(review): if that value is 0.0 the scale factor is infinite, so the entries
     * become non-finite; no guard is applied here.
     *
     * @return this factor
     */
    @Override
    public AbstractTableFactor recenter() {
        double maxVal = this.valueAtLocation(this.argmax());
        this.timesEquals(1.0 / maxVal);
        return this;
    }
}

