/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.constraints;

import cc.mallet.fst.SumLattice;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;

/**
 * Generalized-expectation (GE) constraints that penalize, with a weighted squared (L2) term,
 * the normalized model expectation of a label given a feature whenever that expectation falls
 * outside a target range {@code [lower, upper]}.
 *
 * <p>Constraints are keyed by feature index. The index {@code fv.getAlphabet().size()} (one past
 * the last real feature index) is also accepted as a key, and it is treated as present in every
 * {@link FeatureVector}. NOTE(review): this presumably acts as an always-on "default" feature;
 * confirm with callers.
 *
 * <p>This is a one-state constraint: {@link #getCompositeConstraintFeatureValue} depends only on
 * the label of the destination state {@code si2}.
 */
public class OneLabelL2RangeGEConstraints
implements GEConstraint {
    /** Per-feature range constraints, keyed by feature index. */
    protected TIntObjectHashMap<OneLabelL2IndGEConstraint> constraints;
    /** Maps model states to label indices. */
    protected StateLabelMap map;
    /**
     * Constrained feature indices of the FeatureVector most recently passed to
     * {@link #preProcess(FeatureVector)}; read by {@link #getCompositeConstraintFeatureValue}.
     */
    protected TIntArrayList cache;

    public OneLabelL2RangeGEConstraints() {
        this.constraints = new TIntObjectHashMap<OneLabelL2IndGEConstraint>();
        this.cache = new TIntArrayList();
    }

    /**
     * Creates an instance that shares {@code constraints} and {@code map} with the caller but has
     * its own {@link #cache}.
     */
    protected OneLabelL2RangeGEConstraints(TIntObjectHashMap<OneLabelL2IndGEConstraint> constraints, StateLabelMap map) {
        this.constraints = constraints;
        this.map = map;
        this.cache = new TIntArrayList();
    }

    /**
     * Adds a range constraint on label {@code li} for feature {@code fi}.
     *
     * @param fi     feature index (or {@code alphabet.size()} for the always-present feature)
     * @param li     label index
     * @param lower  lower end of the target range for the normalized expectation
     * @param upper  upper end of the target range for the normalized expectation
     * @param weight multiplier applied to the squared distance outside the range
     */
    public void addConstraint(int fi, int li, double lower, double upper, double weight) {
        OneLabelL2IndGEConstraint constraint = this.constraints.get(fi);
        if (constraint == null) {
            constraint = new OneLabelL2IndGEConstraint();
            this.constraints.put(fi, constraint);
        }
        constraint.add(li, lower, upper, weight);
    }

    @Override
    public boolean isOneStateConstraint() {
        return true;
    }

    @Override
    public void setStateLabelMap(StateLabelMap map) {
        this.map = map;
    }

    /**
     * Clears {@code out} and fills it with the indices of the features of {@code fv} that carry
     * a constraint, followed by {@code fv.getAlphabet().size()} if that index is constrained.
     */
    private void collectConstrainedFeatures(FeatureVector fv, TIntArrayList out) {
        out.resetQuick();
        for (int loc = 0; loc < fv.numLocations(); loc++) {
            int fi = fv.indexAtLocation(loc);
            if (this.constraints.containsKey(fi)) {
                out.add(fi);
            }
        }
        int alwaysOnIndex = fv.getAlphabet().size();
        if (this.constraints.containsKey(alwaysOnIndex)) {
            out.add(alwaysOnIndex);
        }
    }

    /** Records in {@link #cache} which constrained features are present in {@code fv}. */
    @Override
    public void preProcess(FeatureVector fv) {
        collectConstrainedFeatures(fv, this.cache);
    }

    /**
     * Accumulates, for every constraint, the number of positions in {@code data} at which its
     * feature occurs. This count later normalizes expectations.
     *
     * @return a bit set with bit {@code ii} set iff instance {@code ii} has at least one position
     *         containing a constrained feature
     */
    @Override
    public BitSet preProcess(InstanceList data) {
        BitSet bitSet = new BitSet(data.size());
        TIntArrayList found = new TIntArrayList();
        int ii = 0;
        for (Instance instance : data) {
            FeatureVectorSequence fvs = (FeatureVectorSequence) instance.getData();
            for (int ip = 0; ip < fvs.size(); ip++) {
                collectConstrainedFeatures(fvs.get(ip), found);
                for (int j = 0; j < found.size(); j++) {
                    this.constraints.get(found.getQuick(j)).count += 1.0;
                }
                if (found.size() > 0) {
                    bitSet.set(ii);
                }
            }
            ii++;
        }
        return bitSet;
    }

    /**
     * Sums the gradient contributions for the label of destination state {@code si2} over the
     * constrained features cached by the last {@link #preProcess(FeatureVector)} call.
     * {@code fv}, {@code ip} and {@code si1} are not used.
     */
    @Override
    public double getCompositeConstraintFeatureValue(FeatureVector fv, int ip, int si1, int si2) {
        int li2 = this.map.getLabelIndex(si2);
        double value = 0.0;
        for (int i = 0; i < this.cache.size(); i++) {
            value += this.constraints.get(this.cache.getQuick(i)).getGradientContribution(li2);
        }
        return value;
    }

    /**
     * Returns the negated sum of all range-violation penalties. Constraints whose feature never
     * occurred ({@code count == 0}) are skipped.
     */
    @Override
    public double getValue() {
        double value = 0.0;
        int numLabels = this.map.getNumLabels();
        for (int fi : this.constraints.keys()) {
            OneLabelL2IndGEConstraint constraint = this.constraints.get(fi);
            if (constraint.count > 0.0) {
                for (int li = 0; li < numLabels; li++) {
                    value -= constraint.getValueContribution(li);
                }
            }
        }
        assert (!Double.isNaN(value) && !Double.isInfinite(value));
        return value;
    }

    /** Allocates a fresh zeroed expectation array for every constraint. */
    @Override
    public void zeroExpectations() {
        for (int fi : this.constraints.keys()) {
            OneLabelL2IndGEConstraint constraint = this.constraints.get(fi);
            constraint.expectation = new double[constraint.getNumConstrainedLabels()];
        }
    }

    /**
     * Adds {@code exp(gammas[ip + 1][s])} to the expectation of label {@code map.getLabelIndex(s)}
     * for every constrained feature present at position {@code ip}. This is done for every
     * non-null lattice, position and state. NOTE(review): gammas are presumably log marginals,
     * given the exponentiation.
     */
    @Override
    public void computeExpectations(ArrayList<SumLattice> lattices) {
        // Own buffer, separate from the cache field, so the state set by
        // preProcess(FeatureVector) is left untouched.
        TIntArrayList featureCache = new TIntArrayList();
        int numStates = this.map.getNumStates();
        for (SumLattice lattice : lattices) {
            if (lattice == null) {
                continue;
            }
            FeatureVectorSequence fvs = (FeatureVectorSequence) lattice.getInput();
            double[][] gammas = lattice.getGammas();
            for (int ip = 0; ip < fvs.size(); ip++) {
                collectConstrainedFeatures(fvs.getFeatureVector(ip), featureCache);
                if (featureCache.size() == 0) {
                    // Nothing to update at this position; skip the per-state exp() work.
                    continue;
                }
                for (int s2 = 0; s2 < numStates; s2++) {
                    int li = this.map.getLabelIndex(s2);
                    // NOTE(review): -2 is presumably StateLabelMap's marker for the start state;
                    // confirm against StateLabelMap.
                    if (li == -2) {
                        continue;
                    }
                    // gammas rows are offset by one relative to input positions.
                    double gammaProb = Math.exp(gammas[ip + 1][s2]);
                    for (int j = 0; j < featureCache.size(); j++) {
                        this.constraints.get(featureCache.getQuick(j)).incrementExpectation(li, gammaProb);
                    }
                }
            }
        }
    }

    /**
     * Returns a copy that shares this instance's constraint map (including counts and
     * expectations) and state-label map, but has its own {@link #cache}.
     */
    @Override
    public GEConstraint copy() {
        return new OneLabelL2RangeGEConstraints(this.constraints, this.map);
    }

    /**
     * Range constraints for a single feature. Each constrained label occupies one slot in the
     * parallel lists {@link #lower}, {@link #upper} and {@link #weights}, and in
     * {@link #expectation}.
     */
    protected class OneLabelL2IndGEConstraint {
        /** Number of slots added so far, which is also the next free slot. */
        protected int index = 0;
        /** Occurrence count of the feature, accumulated by preProcess(InstanceList). */
        protected double count = 0.0;
        protected ArrayList<Double> lower = new ArrayList<Double>();
        protected ArrayList<Double> upper = new ArrayList<Double>();
        protected ArrayList<Double> weights = new ArrayList<Double>();
        /** Label index to slot index. Re-adding a label points it at the newest slot. */
        protected HashMap<Integer, Integer> labelMap = new HashMap<Integer, Integer>();
        /** Unnormalized expectation per slot; allocated by zeroExpectations(). */
        protected double[] expectation;

        /** Appends a slot for {@code label} with the given range and weight. */
        public void add(int label, double lower, double upper, double weight) {
            this.lower.add(lower);
            this.upper.add(upper);
            this.weights.add(weight);
            this.labelMap.put(label, this.index);
            ++this.index;
        }

        /** Adds {@code value} to the expectation of label {@code li}, if {@code li} is constrained. */
        public void incrementExpectation(int li, double value) {
            Integer slot = this.labelMap.get(li);
            if (slot != null) {
                this.expectation[slot] += value;
            }
        }

        /**
         * Returns {@code w * (b - E/count)^2}, where {@code b} is the violated bound of label
         * {@code li}. Returns 0 when the expectation is in range, when {@code li} is
         * unconstrained, or when {@code count} is 0.
         */
        public double getValueContribution(int li) {
            Integer slot = this.labelMap.get(li);
            if (slot == null) {
                return 0.0;
            }
            assert (this.count != 0.0);
            if (this.count == 0.0) {
                // Consistent with getValue(), which ignores never-observed features.
                return 0.0;
            }
            int i = slot;
            double ex = this.expectation[i] / this.count;
            double lo = this.lower.get(i);
            double hi = this.upper.get(i);
            if (ex < lo) {
                return this.weights.get(i) * Math.pow(lo - ex, 2.0);
            }
            if (ex > hi) {
                return this.weights.get(i) * Math.pow(hi - ex, 2.0);
            }
            return 0.0;
        }

        public int getNumConstrainedLabels() {
            return this.index;
        }

        /**
         * Returns the derivative of {@code -w * (b - E/count)^2} with respect to the unnormalized
         * expectation {@code E} of label {@code li}, where {@code b} is the violated bound.
         * Returns 0 when the expectation is in range, when {@code li} is unconstrained, or when
         * {@code count} is 0. The zero-count case would otherwise produce NaN (Inf - Inf).
         */
        public double getGradientContribution(int li) {
            Integer slot = this.labelMap.get(li);
            if (slot == null) {
                return 0.0;
            }
            assert (this.count != 0.0);
            if (this.count == 0.0) {
                // Consistent with getValue(), which ignores never-observed features.
                return 0.0;
            }
            int i = slot;
            double ex = this.expectation[i] / this.count;
            double lo = this.lower.get(i);
            double hi = this.upper.get(i);
            double countSq = this.count * this.count;
            if (ex < lo) {
                return 2.0 * this.weights.get(i) * (lo / this.count - this.expectation[i] / countSq);
            }
            if (ex > hi) {
                return 2.0 * this.weights.get(i) * (hi / this.count - this.expectation[i] / countSq);
            }
            return 0.0;
        }
    }
}

