/*
 * Decompiled with CFR 0.152.
 */
package tl.lin.data.cfd;

import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import java.util.Iterator;
import tl.lin.data.cfd.Object2IntConditionalFrequencyDistribution;
import tl.lin.data.fd.Object2IntFrequencyDistribution;
import tl.lin.data.fd.Object2IntFrequencyDistributionFastutil;
import tl.lin.data.fd.Object2LongFrequencyDistribution;
import tl.lin.data.fd.Object2LongFrequencyDistributionFastutil;
import tl.lin.data.pair.PairOfObjectInt;
import tl.lin.data.pair.PairOfObjectLong;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * {@link Object2IntConditionalFrequencyDistribution} backed by fastutil hash maps.
 *
 * <p>Each condition maps to its own {@link Object2IntFrequencyDistribution} holding the
 * {@code int} counts of events observed under that condition. The per-event marginal counts
 * (summed over all conditions) and the grand total are maintained incrementally as
 * {@code long} values, so they can be read without rescanning every condition.
 *
 * <p>This class performs no synchronization of its own.
 *
 * @param <K> type used for both events and conditions
 */
public class Object2IntConditionalFrequencyDistributionFastutil<K extends Comparable<K>>
implements Object2IntConditionalFrequencyDistribution<K> {
    /** Event counts for each condition, keyed by condition. */
    private final Object2ObjectMap<K, Object2IntFrequencyDistribution<K>> distributions =
            new Object2ObjectOpenHashMap<K, Object2IntFrequencyDistribution<K>>();
    /** Count of each event summed over all conditions. */
    private final Object2LongFrequencyDistribution<K> marginals =
            new Object2LongFrequencyDistributionFastutil<K>();
    /** Sum of every count stored across all conditions. */
    private long sumOfAllCounts = 0L;

    /**
     * Sets the count of event {@code k} under condition {@code cond} to {@code v}, replacing
     * any previous count. The marginal count of {@code k} and the grand total are adjusted by
     * the difference between the new and the previous count.
     *
     * @param k the event
     * @param cond the condition
     * @param v the new count
     */
    public void set(K k, K cond, int v) {
        Object2IntFrequencyDistribution<K> fd = this.distributions.get(cond);
        if (fd == null) {
            fd = new Object2IntFrequencyDistributionFastutil<K>();
            this.distributions.put(cond, fd);
        }
        // An absent event reads as 0 (the existing-condition path has always relied on this).
        int previous = fd.get(k);
        fd.set(k, v);
        // Compute the delta in long arithmetic so extreme int values cannot wrap around.
        long delta = (long) v - (long) previous;
        this.marginals.increment(k, delta);
        this.sumOfAllCounts += delta;
    }

    /**
     * Adds one to the count of event {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the condition
     */
    public void increment(K k, K cond) {
        this.increment(k, cond, 1);
    }

    /**
     * Adds {@code v} to the count of event {@code k} under condition {@code cond}.
     *
     * <p>NOTE: the per-condition count is an {@code int}; the sum {@code current + v} is
     * computed in {@code int} arithmetic and wraps on overflow.
     *
     * @param k the event
     * @param cond the condition
     * @param v the amount to add (may be negative)
     */
    public void increment(K k, K cond, int v) {
        this.set(k, cond, this.get(k, cond) + v);
    }

    /**
     * Returns the count of event {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the condition
     * @return the count, or 0 if the condition has never been seen
     */
    public int get(K k, K cond) {
        Object2IntFrequencyDistribution<K> fd = this.distributions.get(cond);
        return fd == null ? 0 : fd.get(k);
    }

    /**
     * Returns the count of event {@code k} summed over all conditions.
     *
     * @param k the event
     * @return the marginal count as reported by the backing marginal distribution
     *     (presumably 0 for an event never seen — depends on that distribution's {@code get})
     */
    public long getMarginalCount(K k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution of events observed under condition {@code cond}.
     *
     * <p>For a known condition this is the live internal distribution. Modifying it directly
     * bypasses the marginal and total bookkeeping of this object. For an unknown condition a
     * fresh empty distribution is returned that is not stored, so changes to it are not
     * reflected here.
     *
     * @param cond the condition
     * @return the conditional distribution; never {@code null}
     */
    public Object2IntFrequencyDistribution<K> getConditionalDistribution(K cond) {
        Object2IntFrequencyDistribution<K> fd = this.distributions.get(cond);
        return fd != null ? fd : new Object2IntFrequencyDistributionFastutil<K>();
    }

    /**
     * Returns the sum of all counts across all conditions.
     *
     * @return the grand total
     */
    public long getSumOfAllCounts() {
        return this.sumOfAllCounts;
    }

    /**
     * Verifies the internal bookkeeping by recomputing everything from the per-condition
     * distributions. It checks that:
     * <ul>
     *   <li>each conditional distribution's reported sum equals the sum of its entries;</li>
     *   <li>the grand total equals the sum of all conditional sums;</li>
     *   <li>the stored marginals agree with the recomputed marginals in both directions.</li>
     * </ul>
     *
     * @throws RuntimeException if any inconsistency is found
     */
    public void check() {
        // Recompute marginals with long counts. An int accumulator could overflow when summing
        // over many conditions, even though the stored marginals are long.
        Object2LongFrequencyDistribution<K> recomputed =
                new Object2LongFrequencyDistributionFastutil<K>();
        long totalSum = 0L;
        for (Object2IntFrequencyDistribution<K> fd : this.distributions.values()) {
            long conditionalSum = 0L;
            for (PairOfObjectInt<K> pair : fd) {
                conditionalSum += pair.getRightElement();
                recomputed.increment(pair.getLeftElement(), (long) pair.getRightElement());
            }
            if (conditionalSum != fd.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            totalSum += fd.getSumOfCounts();
        }
        if (totalSum != this.getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + totalSum + ", Expected " + this.getSumOfAllCounts());
        }
        // Every recomputed marginal must match the stored marginal...
        for (PairOfObjectLong<K> e : recomputed) {
            if (e.getRightElement() != this.marginals.get(e.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
        // ...and every stored marginal must match the recomputed one. This catches stale
        // entries in `marginals`. The previous loop here compared the recomputed map with
        // itself and therefore could never fail.
        for (PairOfObjectLong<K> e : this.marginals) {
            if (e.getRightElement() != recomputed.get(e.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}

