/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.constraints.ge;

import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;

/**
 * Range-style generalized-expectation (GE) constraints for MaxEnt classifiers with a squared
 * (L2) penalty.
 *
 * <p>Each constraint is keyed by a feature index and holds, per constrained label, a target
 * interval {@code [lower, upper]} and a weight. Let {@code E} be the model expectation for a
 * (feature, label) pair. If {@code normalize} is set, {@code E} is divided by the feature's
 * accumulated data count.
 *
 * <ul>
 *   <li>When {@code E} lies inside the interval, the penalty is zero.</li>
 *   <li>Otherwise the penalty is {@code weight * (bound - E)^2}, where {@code bound} is the
 *       violated end of the interval.</li>
 *   <li>{@link #getValue()} returns the negated sum of all penalties.</li>
 * </ul>
 *
 * <p>The pseudo-feature index {@code numFeatures} is treated as an always-present feature. It
 * has value 1.0 for every instance, so a constraint registered on it applies to every instance.
 *
 * <p>Usage order, as implied by the code: call {@link #preProcess(InstanceList)} once to
 * accumulate counts. Then, per pass, call {@link #zeroExpectations()} before any
 * {@link #computeExpectations(FeatureVector, double[], double)} calls, because expectation
 * arrays are only allocated there. The per-input caches make instances of this class
 * stateful and not safe for concurrent use.
 */
public class MaxEntRangeL2FLGEConstraints
implements MaxEntGEConstraint {
    private final boolean useValues;
    private final boolean normalize;
    private final int numFeatures;
    private final int numLabels;
    /** Constraints keyed by feature index; key {@code numFeatures} is the always-on feature. */
    protected TIntObjectHashMap<MaxEntL2IndGEConstraint> constraints;
    /** Constrained feature indices of the input most recently passed to preProcess(FeatureVector). */
    protected TIntArrayList indexCache;
    /** Feature values parallel to {@link #indexCache}; only filled when {@code useValues}. */
    protected TDoubleArrayList valueCache;

    /**
     * Creates an empty constraint set.
     *
     * @param numFeatures number of real features; this index is used for the always-on feature
     * @param numLabels number of labels iterated over when computing expectations
     * @param useValues if true, weight contributions by feature values; otherwise treat features
     *     as binary
     * @param normalize if true, divide expectations by the feature's accumulated data count
     */
    public MaxEntRangeL2FLGEConstraints(int numFeatures, int numLabels, boolean useValues, boolean normalize) {
        this.numFeatures = numFeatures;
        this.numLabels = numLabels;
        this.useValues = useValues;
        this.normalize = normalize;
        this.constraints = new TIntObjectHashMap<MaxEntL2IndGEConstraint>();
        this.indexCache = new TIntArrayList();
        this.valueCache = new TDoubleArrayList();
    }

    /**
     * Adds a range constraint for a (feature, label) pair.
     *
     * @param fi feature index ({@code numFeatures} for the always-on feature)
     * @param li label index
     * @param lower lower end of the target range
     * @param upper upper end of the target range
     * @param weight weight of the squared penalty
     */
    public void addConstraint(int fi, int li, double lower, double upper, double weight) {
        MaxEntL2IndGEConstraint constraint = this.constraints.get(fi);
        if (constraint == null) {
            constraint = new MaxEntL2IndGEConstraint();
            this.constraints.put(fi, constraint);
        }
        constraint.add(li, lower, upper, weight);
    }

    /**
     * Accumulates each constrained feature's weighted count over {@code data}.
     *
     * <p>The contribution per occurrence is the instance weight times the feature value when
     * {@code useValues} is set, and the instance weight alone otherwise. The always-on feature
     * receives the instance weight for every instance.
     *
     * @param data training instances whose data are {@link FeatureVector}s
     * @return the positions (iteration order over {@code data}) of instances that touch at least
     *     one constraint
     */
    @Override
    public BitSet preProcess(InstanceList data) {
        BitSet bitSet = new BitSet(data.size());
        // The map is not modified below, so the always-on constraint can be looked up once.
        MaxEntL2IndGEConstraint defaultConstraint = this.constraints.get(this.numFeatures);
        int ii = 0;
        for (Instance instance : data) {
            double weight = data.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            for (int loc = 0; loc < fv.numLocations(); ++loc) {
                MaxEntL2IndGEConstraint constraint = this.constraints.get(fv.indexAtLocation(loc));
                if (constraint == null) {
                    continue;
                }
                constraint.count += this.useValues ? weight * fv.valueAtLocation(loc) : weight;
                bitSet.set(ii);
            }
            if (defaultConstraint != null) {
                defaultConstraint.count += weight;
                // Mark the CURRENT instance. The index is advanced only after this point;
                // advancing it first would mark the following instance instead.
                bitSet.set(ii);
            }
            ++ii;
        }
        return bitSet;
    }

    /**
     * Caches the constrained feature indices (and values, if {@code useValues}) of {@code input}.
     * The always-on feature, if constrained, is appended with value 1.0.
     */
    @Override
    public void preProcess(FeatureVector input) {
        this.indexCache.resetQuick();
        if (this.useValues) {
            this.valueCache.resetQuick();
        }
        for (int loc = 0; loc < input.numLocations(); ++loc) {
            int fi = input.indexAtLocation(loc);
            if (!this.constraints.containsKey(fi)) {
                continue;
            }
            this.indexCache.add(fi);
            if (this.useValues) {
                this.valueCache.add(input.valueAtLocation(loc));
            }
        }
        if (this.constraints.containsKey(this.numFeatures)) {
            this.indexCache.add(this.numFeatures);
            if (this.useValues) {
                this.valueCache.add(1.0);
            }
        }
    }

    /**
     * Sums the gradient contributions for {@code label} over the features cached by the last
     * {@link #preProcess(FeatureVector)} call. Each contribution is scaled by the feature value
     * when {@code useValues} is set. The {@code input} argument itself is not read.
     */
    @Override
    public double getCompositeConstraintFeatureValue(FeatureVector input, int label) {
        double value = 0.0;
        for (int i = 0; i < this.indexCache.size(); ++i) {
            double contribution = this.constraints.get(this.indexCache.getQuick(i)).getGradientContribution(label);
            value += this.useValues ? contribution * this.valueCache.getQuick(i) : contribution;
        }
        return value;
    }

    /**
     * Adds {@code weight * dist[li]} (times the feature value when {@code useValues}) to the
     * expectation of every constrained (feature, label) pair present in {@code input}.
     *
     * @param input the instance's features
     * @param dist model label distribution for {@code input}, of length at least {@code numLabels}
     * @param weight instance weight
     */
    @Override
    public void computeExpectations(FeatureVector input, double[] dist, double weight) {
        this.preProcess(input);
        for (int li = 0; li < this.numLabels; ++li) {
            double p = weight * dist[li];
            for (int i = 0; i < this.indexCache.size(); ++i) {
                MaxEntL2IndGEConstraint constraint = this.constraints.get(this.indexCache.getQuick(i));
                // Each expectation array is indexed by constrained-label slot, not by raw label.
                // incrementExpectation maps li through labelMap and skips unconstrained labels.
                constraint.incrementExpectation(li, this.useValues ? p * this.valueCache.getQuick(i) : p);
            }
        }
    }

    /**
     * Returns the negated total squared range-violation penalty.
     *
     * <p>Constraints whose accumulated count is not positive (the feature never occurred in the
     * data) are skipped.
     */
    @Override
    public double getValue() {
        double value = 0.0;
        for (int fi : this.constraints.keys()) {
            MaxEntL2IndGEConstraint constraint = this.constraints.get(fi);
            if (!(constraint.count > 0.0)) {
                continue;
            }
            for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
                value -= constraint.getValue(labelIndex);
            }
        }
        assert (!Double.isNaN(value) && !Double.isInfinite(value));
        return value;
    }

    /** Allocates fresh, zero-filled expectation arrays (one slot per constrained label). */
    @Override
    public void zeroExpectations() {
        for (int fi : this.constraints.keys()) {
            MaxEntL2IndGEConstraint constraint = this.constraints.get(fi);
            constraint.expectation = new double[constraint.getNumConstrainedLabels()];
        }
    }

    /**
     * Range constraints on a single feature. Constrained labels are stored in insertion order and
     * addressed through {@link #labelMap} (label index -> slot).
     */
    protected class MaxEntL2IndGEConstraint {
        /** Number of slots added so far; also the next slot to assign. */
        protected int index = 0;
        /** Weighted occurrence count of the feature in the data; the normalizer when normalizing. */
        protected double count = 0.0;
        protected ArrayList<Double> lower = new ArrayList<Double>();
        protected ArrayList<Double> upper = new ArrayList<Double>();
        protected ArrayList<Double> weights = new ArrayList<Double>();
        protected HashMap<Integer, Integer> labelMap = new HashMap<Integer, Integer>();
        /** Per-slot model expectations; allocated by zeroExpectations(). */
        protected double[] expectation;

        /**
         * Appends a range constraint for {@code label} in a new slot. Adding the same label twice
         * re-points {@code labelMap} to the newer slot.
         */
        public void add(int label, double lower, double upper, double weight) {
            this.lower.add(lower);
            this.upper.add(upper);
            this.weights.add(weight);
            this.labelMap.put(label, this.index);
            ++this.index;
        }

        /** Adds {@code value} to the expectation slot of label {@code li}, if it is constrained. */
        public void incrementExpectation(int li, double value) {
            Integer slot = this.labelMap.get(li);
            if (slot != null) {
                this.expectation[slot] += value;
            }
        }

        /**
         * Returns the squared penalty for label {@code li}.
         *
         * @return {@code weight * (bound - E)^2} if the (optionally normalized) expectation
         *     {@code E} is below {@code lower} or above {@code upper}; 0 if {@code E} is in range
         *     or the label is unconstrained
         */
        public double getValue(int li) {
            Integer slot = this.labelMap.get(li);
            if (slot == null) {
                return 0.0;
            }
            int i = slot;
            assert (this.count != 0.0);
            double ex = MaxEntRangeL2FLGEConstraints.this.normalize ? this.expectation[i] / this.count : this.expectation[i];
            if (ex < this.lower.get(i)) {
                return this.weights.get(i) * Math.pow(this.lower.get(i) - ex, 2.0);
            }
            if (ex > this.upper.get(i)) {
                return this.weights.get(i) * Math.pow(this.upper.get(i) - ex, 2.0);
            }
            return 0.0;
        }

        /** Returns the number of label slots added to this constraint. */
        public int getNumConstrainedLabels() {
            return this.index;
        }

        /**
         * Returns the derivative of the negated penalty for label {@code li} with respect to its
         * unnormalized expectation. The result is 0 when the label is unconstrained or its
         * expectation is within range.
         */
        public double getGradientContribution(int li) {
            Integer slot = this.labelMap.get(li);
            if (slot == null) {
                return 0.0;
            }
            int i = slot;
            assert (this.count != 0.0);
            if (MaxEntRangeL2FLGEConstraints.this.normalize) {
                double ex = this.expectation[i] / this.count;
                // Chain rule through E / count yields the extra 1/count factor.
                if (ex < this.lower.get(i)) {
                    return 2.0 * this.weights.get(i) * (this.lower.get(i) / this.count - this.expectation[i] / (this.count * this.count));
                }
                if (ex > this.upper.get(i)) {
                    return 2.0 * this.weights.get(i) * (this.upper.get(i) / this.count - this.expectation[i] / (this.count * this.count));
                }
            } else {
                double ex = this.expectation[i];
                if (ex < this.lower.get(i)) {
                    return 2.0 * this.weights.get(i) * (this.lower.get(i) - this.expectation[i]);
                }
                if (ex > this.upper.get(i)) {
                    return 2.0 * this.weights.get(i) * (this.upper.get(i) - this.expectation[i]);
                }
            }
            return 0.0;
        }
    }
}

