/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.JointClassifier;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.stats.MultivariateEstimator;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Counter;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A naive Bayes classifier over binary (active / inactive) features.
 *
 * <p>Each input is mapped by the {@link FeatureExtractor} to a feature-to-value map. A
 * feature is <em>active</em> for that input when its value is strictly greater than the
 * activation threshold.
 *
 * <p>Training through {@link #handle(Classified)} does two things:
 * <ul>
 *   <li>it records one observation of the best category in a {@link MultivariateEstimator};</li>
 *   <li>it counts, per category, how many training inputs had each feature active.</li>
 * </ul>
 *
 * <p>For a category {@code c} with {@code n} training inputs, classification computes
 * <pre>
 *   log2 P(c)
 *     + sum over known active features f:   log2((count(c,f) + 1) / (n + 2))
 *     + sum over known inactive features f: log2((n - count(c,f) + 1) / (n + 2))
 * </pre>
 * Here {@code P(c)} is the probability reported by the category estimator. A feature is
 * "known" if it was active in at least one training input. Features never seen during
 * training are ignored.
 *
 * <p>Instances are mutable and perform no synchronization.
 *
 * @param <E> type of object being classified
 */
public class BernoulliClassifier<E>
implements JointClassifier<E>,
ObjectHandler<Classified<E>>,
Serializable {
    static final long serialVersionUID = -7761909693358968780L;
    /** Counts of training inputs per category; also supplies category labels and priors. */
    private final MultivariateEstimator mCategoryDistribution;
    private final FeatureExtractor<E> mFeatureExtractor;
    /** A feature is active only when its extracted value is strictly greater than this. */
    private final double mActivationThreshold;
    /** Every feature that was active in at least one training input. */
    private final Set<String> mFeatureSet;
    /** Per category, the number of training inputs in which each feature was active. */
    private final Map<String, ObjectToCounterMap<String>> mFeatureDistributionMap;

    /**
     * Constructs an untrained classifier with a feature activation threshold of {@code 0.0}.
     *
     * @param featureExtractor extractor mapping inputs to feature/value maps
     */
    public BernoulliClassifier(FeatureExtractor<E> featureExtractor) {
        this(featureExtractor, 0.0);
    }

    /**
     * Constructs an untrained classifier.
     *
     * @param featureExtractor extractor mapping inputs to feature/value maps
     * @param featureActivationThreshold a feature is active only if its value is strictly
     *     greater than this threshold
     */
    public BernoulliClassifier(FeatureExtractor<E> featureExtractor, double featureActivationThreshold) {
        this(new MultivariateEstimator(), featureExtractor, featureActivationThreshold,
             new HashSet<String>(), new HashMap<String, ObjectToCounterMap<String>>());
    }

    /**
     * Constructs a classifier from already-built state (used by deserialization).
     * The collections are stored as given, not copied.
     */
    BernoulliClassifier(MultivariateEstimator catDistro, FeatureExtractor<E> featureExtractor, double activationThreshold, Set<String> featureSet, Map<String, ObjectToCounterMap<String>> featureDistributionMap) {
        this.mCategoryDistribution = catDistro;
        this.mFeatureExtractor = featureExtractor;
        this.mActivationThreshold = activationThreshold;
        this.mFeatureSet = featureSet;
        this.mFeatureDistributionMap = featureDistributionMap;
    }

    /**
     * Returns the activation threshold.
     *
     * @return threshold a feature value must strictly exceed to count as active
     */
    public double featureActivationThreshold() {
        return this.mActivationThreshold;
    }

    /**
     * Returns the feature extractor used for training and classification.
     *
     * @return the feature extractor
     */
    public FeatureExtractor<E> featureExtractor() {
        return this.mFeatureExtractor;
    }

    /**
     * Returns the categories seen in training, in the category estimator's dimension order.
     *
     * @return a freshly allocated array of category labels
     */
    public String[] categories() {
        int numCategories = this.mCategoryDistribution.numDimensions();
        String[] categories = new String[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            categories[i] = this.mCategoryDistribution.label(i);
        }
        return categories;
    }

    /**
     * Trains on one example, using its object and classification.
     *
     * @param classified the training example
     */
    @Override
    public void handle(Classified<E> classified) {
        this.handle(classified.getObject(), classified.getClassification());
    }

    /**
     * Trains on one input labelled with the classification's best category.
     *
     * <p>The category's estimator count is incremented. Each active feature of the input
     * has its count for that category incremented and is added to the known-feature set.
     */
    void handle(E input, Classification classification) {
        String category = classification.bestCategory();
        this.mCategoryDistribution.train(category, 1L);
        ObjectToCounterMap<String> categoryCounter = this.mFeatureDistributionMap.get(category);
        if (categoryCounter == null) {
            // Created even when the input has no active features, so that every trained
            // category has an entry (classify relies on this).
            categoryCounter = new ObjectToCounterMap<String>();
            this.mFeatureDistributionMap.put(category, categoryCounter);
        }
        for (String feature : this.activeFeatureSet(input)) {
            categoryCounter.increment(feature);
            this.mFeatureSet.add(feature);
        }
    }

    /**
     * Classifies an input by the log2 joint score described in the class documentation.
     *
     * @param input the object to classify
     * @return categories paired with their log2 scores, ordered as by
     *     {@code ObjectToDoubleMap.scoredObjectsOrderedByValueList()}
     */
    @Override
    public JointClassification classify(E input) {
        // Only features seen in training carry evidence. A feature unseen by every category
        // is dropped here, rather than being skipped per category by its zero count.
        Set<String> knownActiveFeatures = this.activeFeatureSet(input);
        knownActiveFeatures.retainAll(this.mFeatureSet);
        Set<String> inactiveFeatures = new HashSet<String>(this.mFeatureSet);
        inactiveFeatures.removeAll(knownActiveFeatures);

        ObjectToDoubleMap<String> categoryToLog2P = new ObjectToDoubleMap<String>();
        int numCategories = this.mCategoryDistribution.numDimensions();
        for (long i = 0L; i < numCategories; ++i) {
            String category = this.mCategoryDistribution.label(i);
            double categoryCount = this.mCategoryDistribution.getCount(i);
            double denominator = categoryCount + 2.0; // add-one smoothing over {on, off}
            ObjectToCounterMap<String> categoryFeatureCounts = this.mFeatureDistributionMap.get(category);
            double log2P = Math.log2(this.mCategoryDistribution.probability(i));
            for (String activeFeature : knownActiveFeatures) {
                // A zero count here must still be penalised with log2(1/(n+2)). Skipping it
                // would favour categories that never saw the feature.
                double featureCount = categoryFeatureCounts.getCount(activeFeature);
                log2P += Math.log2((featureCount + 1.0) / denominator);
            }
            for (String inactiveFeature : inactiveFeatures) {
                double notFeatureCount = categoryCount - categoryFeatureCounts.getCount(inactiveFeature);
                log2P += Math.log2((notFeatureCount + 1.0) / denominator);
            }
            categoryToLog2P.set(category, log2P);
        }

        List<ScoredObject<String>> scoredCategories = categoryToLog2P.scoredObjectsOrderedByValueList();
        String[] categories = new String[numCategories];
        double[] log2Ps = new double[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            ScoredObject<String> scored = scoredCategories.get(i);
            categories[i] = scored.getObject();
            log2Ps[i] = scored.score();
        }
        return new JointClassification(categories, log2Ps);
    }

    /** Serializes this classifier through {@link Serializer}. */
    Object writeReplace() {
        return new Serializer<E>(this);
    }

    /**
     * Returns the features of {@code input} whose value is strictly greater than the
     * activation threshold.
     *
     * @return a new, mutable set owned by the caller
     */
    private Set<String> activeFeatureSet(E input) {
        Set<String> activeFeatures = new HashSet<String>();
        Map<String, ? extends Number> featureMap = this.mFeatureExtractor.features(input);
        for (Map.Entry<String, ? extends Number> entry : featureMap.entrySet()) {
            if (entry.getValue().doubleValue() > this.mActivationThreshold) {
                activeFeatures.add(entry.getKey());
            }
        }
        return activeFeatures;
    }

    /**
     * Externalizable serialization proxy for {@link BernoulliClassifier}.
     *
     * <p>The wire format is written in this order:
     * <ol>
     *   <li>the category estimator and the feature extractor, as objects;</li>
     *   <li>the threshold, as a double;</li>
     *   <li>the known-feature count, followed by each feature as a UTF string;</li>
     *   <li>the category count, followed by one record per category: the category (UTF),
     *       its feature count, and then (feature UTF, int count) pairs.</li>
     * </ol>
     *
     * <p>Note: {@code writeUTF} cannot encode strings longer than 65535 encoded bytes.
     */
    static class Serializer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = 4803666611627400222L;
        final BernoulliClassifier<F> mClassifier;

        public Serializer(BernoulliClassifier<F> classifier) {
            this.mClassifier = classifier;
        }

        /** No-arg constructor required for externalizable deserialization. */
        public Serializer() {
            this(null);
        }

        @Override
        public void writeExternal(ObjectOutput objOut) throws IOException {
            BernoulliClassifier<F> classifier = this.mClassifier;
            objOut.writeObject(classifier.mCategoryDistribution);
            objOut.writeObject(classifier.mFeatureExtractor);
            objOut.writeDouble(classifier.mActivationThreshold);
            objOut.writeInt(classifier.mFeatureSet.size());
            for (String feature : classifier.mFeatureSet) {
                objOut.writeUTF(feature);
            }
            objOut.writeInt(classifier.mFeatureDistributionMap.size());
            for (Map.Entry<String, ObjectToCounterMap<String>> categoryEntry
                     : classifier.mFeatureDistributionMap.entrySet()) {
                objOut.writeUTF(categoryEntry.getKey());
                ObjectToCounterMap<String> featureCounts = categoryEntry.getValue();
                objOut.writeInt(featureCounts.size());
                for (Map.Entry<String, Counter> countEntry : featureCounts.entrySet()) {
                    objOut.writeUTF(countEntry.getKey());
                    objOut.writeInt(countEntry.getValue().intValue());
                }
            }
        }

        @Override
        public Object read(ObjectInput objIn) throws ClassNotFoundException, IOException {
            MultivariateEstimator estimator = (MultivariateEstimator) objIn.readObject();
            @SuppressWarnings("unchecked") // written from a FeatureExtractor<F> by writeExternal
            FeatureExtractor<F> featureExtractor = (FeatureExtractor<F>) objIn.readObject();
            double activationThreshold = objIn.readDouble();

            int featureSetSize = objIn.readInt();
            Set<String> featureSet = new HashSet<String>(2 * featureSetSize);
            for (int i = 0; i < featureSetSize; ++i) {
                featureSet.add(objIn.readUTF());
            }

            int featureDistributionMapSize = objIn.readInt();
            Map<String, ObjectToCounterMap<String>> featureDistributionMap
                = new HashMap<String, ObjectToCounterMap<String>>(2 * featureDistributionMapSize);
            for (int i = 0; i < featureDistributionMapSize; ++i) {
                String category = objIn.readUTF();
                int mapSize = objIn.readInt();
                ObjectToCounterMap<String> featureCounts = new ObjectToCounterMap<String>();
                featureDistributionMap.put(category, featureCounts);
                for (int j = 0; j < mapSize; ++j) {
                    String feature = objIn.readUTF();
                    int count = objIn.readInt();
                    featureCounts.set(feature, count);
                }
            }
            return new BernoulliClassifier<F>(estimator, featureExtractor, activationThreshold,
                                              featureSet, featureDistributionMap);
        }
    }
}

