/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr.constraints;

import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;
import java.util.BitSet;

/**
 * A {@link PRConstraint} that holds one constraint per constrained input
 * feature. Each constraint carries a per-label target vector, a weight, an
 * accumulated per-label expectation and an occurrence count.
 *
 * <p>Parameters, expectations and gradients are exchanged with callers as flat
 * arrays of length {@link #numDimensions()}. The entry for the constraint with
 * slot {@code ci} and label {@code li} lives at
 * {@code ci + li * constraints.size()}.
 *
 * <p>When {@code normalized} is true, expectations (and scores) are divided by
 * the number of times the constrained feature was seen by
 * {@link #preProcess(InstanceList)}.
 */
public class OneLabelL2PRConstraints
implements PRConstraint {
    /** Maps an input feature index to its constraint. */
    protected TIntObjectHashMap<OneLabelPRConstraint> constraints = new TIntObjectHashMap<OneLabelPRConstraint>();
    /** Maps an input feature index to its slot in the flat parameter arrays. */
    protected TIntIntHashMap constraintIndices;
    /** Translates state (destination) indices to label indices. */
    protected StateLabelMap map;
    /** Whether expectations are divided by the feature occurrence count. */
    protected boolean normalized;
    /** Constrained features found in the last vector given to {@link #preProcess(FeatureVector)}. */
    protected TIntArrayList cache;

    /**
     * Creates an empty constraint set.
     *
     * @param normalized whether expectations are divided by the feature count
     */
    public OneLabelL2PRConstraints(boolean normalized) {
        this.constraintIndices = new TIntIntHashMap();
        this.cache = new TIntArrayList();
        this.normalized = normalized;
    }

    /**
     * Copy constructor used by {@link #copy()}. Every constraint is copied via
     * {@link OneLabelPRConstraint#copy()}; {@code constraintIndices} and
     * {@code map} are stored by reference (shared with the source object), and
     * a fresh, empty cache is created.
     *
     * @param constraints constraints to copy, keyed by feature index
     * @param constraintIndices feature index to slot mapping (shared, not copied)
     * @param map state to label mapping (shared, not copied)
     * @param normalized whether expectations are divided by the feature count
     */
    protected OneLabelL2PRConstraints(TIntObjectHashMap<OneLabelPRConstraint> constraints, TIntIntHashMap constraintIndices, StateLabelMap map, boolean normalized) {
        for (int key : constraints.keys()) {
            this.constraints.put(key, constraints.get(key).copy());
        }
        this.constraintIndices = constraintIndices;
        this.map = map;
        this.cache = new TIntArrayList();
        this.normalized = normalized;
    }

    /**
     * Returns a new instance with copied constraints (zeroed expectations,
     * same counts) that shares the index map and state-label map with this one.
     */
    @Override
    public PRConstraint copy() {
        return new OneLabelL2PRConstraints(this.constraints, this.constraintIndices, this.map, this.normalized);
    }

    /**
     * Registers a constraint on input feature {@code fi}.
     *
     * <p>If {@code fi} is already constrained, the constraint is replaced but
     * the feature keeps its existing slot. Assigning a fresh slot in that case
     * would leave the slot numbering out of step with
     * {@code constraints.size()} and make the flat-array indexing used by the
     * other methods collide or run out of bounds.
     *
     * @param fi input feature index
     * @param target per-label target values
     * @param weight constraint weight
     */
    public void addConstraint(int fi, double[] target, double weight) {
        this.constraints.put(fi, new OneLabelPRConstraint(target, weight));
        if (!this.constraintIndices.containsKey(fi)) {
            this.constraintIndices.put(fi, this.constraintIndices.size());
        }
    }

    /**
     * Returns the length of the flat parameter array: number of labels times
     * number of constraints. Requires the state-label map to be set.
     */
    @Override
    public int numDimensions() {
        assert (this.map != null);
        return this.map.getNumLabels() * this.constraints.size();
    }

    /** Always returns {@code true}. */
    public boolean isOneStateConstraint() {
        return true;
    }

    /** Stores the state-label map used to translate state indices to labels. */
    @Override
    public void setStateLabelMap(StateLabelMap map) {
        this.map = map;
    }

    /**
     * Rebuilds the cache with the constrained features present in {@code fv}.
     * The cached features are the ones later used by {@link #getScore} and
     * {@link #incrementExpectations}.
     */
    @Override
    public void preProcess(FeatureVector fv) {
        this.cache.resetQuick();
        for (int loc = 0; loc < fv.numLocations(); loc++) {
            int fi = fv.indexAtLocation(loc);
            if (this.constraints.containsKey(fi)) {
                this.cache.add(fi);
            }
        }
    }

    /**
     * Scans every feature vector of every instance, incrementing the count of
     * each constrained feature once per occurrence.
     *
     * @param data instances whose data is a {@link FeatureVectorSequence}
     * @return a bit set in which bit {@code i} is set when the {@code i}-th
     *         instance contains at least one constrained feature
     */
    @Override
    public BitSet preProcess(InstanceList data) {
        BitSet bitSet = new BitSet(data.size());
        int ii = 0;
        for (Instance instance : data) {
            FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData();
            for (int ip = 0; ip < fvs.size(); ip++) {
                FeatureVector fv = fvs.get(ip);
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    if (this.constraints.containsKey(fi)) {
                        this.constraints.get(fi).count += 1.0;
                        bitSet.set(ii);
                    }
                }
            }
            ++ii;
        }
        return bitSet;
    }

    /**
     * Sums the parameters of all cached constrained features for the label of
     * {@code destIndex}. When normalized, each parameter is divided by the
     * feature's occurrence count. {@code input}, {@code inputPosition} and
     * {@code srcIndex} are not read; the features come from the cache filled
     * by {@link #preProcess(FeatureVector)}.
     */
    @Override
    public double getScore(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double[] parameters) {
        double dot = 0.0;
        int li2 = this.map.getLabelIndex(destIndex);
        int numConstraints = this.constraints.size();
        for (int i = 0; i < this.cache.size(); i++) {
            int fi = this.cache.getQuick(i);
            double param = parameters[this.constraintIndices.get(fi) + numConstraints * li2];
            if (this.normalized) {
                dot += param / this.constraints.get(fi).count;
            } else {
                dot += param;
            }
        }
        return dot;
    }

    /**
     * Adds {@code prob} to the expectation of the label of {@code destIndex}
     * for every cached constrained feature. The expectation arrays must have
     * been allocated (see {@link #zeroExpectations()}).
     */
    @Override
    public void incrementExpectations(FeatureVector input, int inputPosition, int srcIndex, int destIndex, double prob) {
        int li2 = this.map.getLabelIndex(destIndex);
        for (int i = 0; i < this.cache.size(); i++) {
            this.constraints.get(this.cache.getQuick(i)).expectation[li2] += prob;
        }
    }

    /**
     * Writes the per-constraint expectations into the flat array
     * {@code expectations}, overwriting the corresponding entries.
     */
    @Override
    public void getExpectations(double[] expectations) {
        assert (expectations.length == this.numDimensions());
        int numConstraints = this.constraints.size();
        for (int fi : this.constraintIndices.keys()) {
            int ci = this.constraintIndices.get(fi);
            OneLabelPRConstraint constraint = this.constraints.get(fi);
            for (int li = 0; li < constraint.expectation.length; li++) {
                expectations[ci + li * numConstraints] = constraint.expectation[li];
            }
        }
    }

    /**
     * Adds the values of the flat array {@code expectations} to the
     * per-constraint expectations.
     */
    @Override
    public void addExpectations(double[] expectations) {
        assert (expectations.length == this.numDimensions());
        int numConstraints = this.constraints.size();
        for (int fi : this.constraintIndices.keys()) {
            int ci = this.constraintIndices.get(fi);
            OneLabelPRConstraint constraint = this.constraints.get(fi);
            for (int li = 0; li < constraint.expectation.length; li++) {
                constraint.expectation[li] += expectations[ci + li * numConstraints];
            }
        }
    }

    /**
     * Replaces every constraint's expectation with a fresh zero-filled array
     * sized by the number of labels. Requires the state-label map to be set.
     */
    @Override
    public void zeroExpectations() {
        for (int fi : this.constraints.keys()) {
            this.constraints.get(fi).expectation = new double[this.map.getNumLabels()];
        }
    }

    /**
     * Returns the sum over constraints and labels of
     * {@code target * param - param^2 / (2 * weight)}.
     */
    @Override
    public double getAuxiliaryValueContribution(double[] parameters) {
        double value = 0.0;
        int numConstraints = this.constraints.size();
        for (int fi : this.constraints.keys()) {
            int ci = this.constraintIndices.get(fi);
            OneLabelPRConstraint constraint = this.constraints.get(fi);
            for (int li = 0; li < this.map.getNumLabels(); li++) {
                double param = parameters[ci + li * numConstraints];
                value += constraint.target[li] * param - param * param / (2.0 * constraint.weight);
            }
        }
        return value;
    }

    /**
     * Returns the sum over constraints and labels of
     * {@code weight * (target - expectation)^2 / 2}, where the expectation is
     * divided by the feature count when normalized. {@code parameters} is not
     * read.
     */
    @Override
    public double getCompleteValueContribution(double[] parameters) {
        double value = 0.0;
        for (int fi : this.constraints.keys()) {
            OneLabelPRConstraint constraint = this.constraints.get(fi);
            for (int li = 0; li < this.map.getNumLabels(); li++) {
                double expected = this.normalized
                        ? constraint.expectation[li] / constraint.count
                        : constraint.expectation[li];
                value += constraint.weight * Math.pow(constraint.target[li] - expected, 2.0) / 2.0;
            }
        }
        return value;
    }

    /**
     * Fills {@code gradient} with {@code target - expectation - param / weight}
     * for every constraint and label, where the expectation is divided by the
     * feature count when normalized.
     */
    @Override
    public void getGradient(double[] parameters, double[] gradient) {
        int numConstraints = this.constraints.size();
        for (int fi : this.constraints.keys()) {
            int ci = this.constraintIndices.get(fi);
            OneLabelPRConstraint constraint = this.constraints.get(fi);
            for (int li = 0; li < this.map.getNumLabels(); li++) {
                int index = ci + li * numConstraints;
                double expected = this.normalized
                        ? constraint.expectation[li] / constraint.count
                        : constraint.expectation[li];
                gradient[index] = constraint.target[li] - expected - parameters[index] / constraint.weight;
            }
        }
    }

    /**
     * State of a single constraint: per-label targets, a weight, the
     * accumulated per-label expectation and the feature occurrence count.
     */
    protected class OneLabelPRConstraint {
        /** Per-label target values; stored by reference, not copied. */
        protected double[] target;
        /** Per-label expectation; {@code null} until allocated. */
        protected double[] expectation;
        /** Number of occurrences counted by {@code preProcess(InstanceList)}. */
        protected double count;
        /** Constraint weight. */
        protected double weight;

        /**
         * Creates a constraint with no expectation array and a zero count.
         *
         * @param target per-label target values
         * @param weight constraint weight
         */
        public OneLabelPRConstraint(double[] target, double weight) {
            this.target = target;
            this.weight = weight;
            this.expectation = null;
            this.count = 0.0;
        }

        /**
         * Returns a constraint with the same target array, weight and count,
         * and a new zero-filled expectation array of the target's length.
         */
        public OneLabelPRConstraint copy() {
            OneLabelPRConstraint duplicate = new OneLabelPRConstraint(this.target, this.weight);
            duplicate.count = this.count;
            duplicate.expectation = new double[this.target.length];
            return duplicate;
        }
    }
}

