package jptools.ml.classifier.impl;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import jptools.ml.classifier.IClassification;
import jptools.ml.classifier.IClassifier;

/* loaded from: input_file:jptools/ml/classifier/impl/AbstractClassifier.class */
/**
 * Base {@link IClassifier} that maintains feature and category counts over a bounded,
 * first-in-first-out memory of training samples.
 *
 * <p>Each call to {@link #train(IClassification)} adds the sample's features and category to
 * the counters and appends the sample to a memory queue. While the queue holds more than
 * {@link #getMaxClassicicationSize()} samples, the oldest samples are forgotten and their
 * counts are subtracted again, so the statistics only reflect the most recent samples.
 *
 * <p>Instances are not synchronized; concurrent use requires external locking.
 *
 * @param <T> the feature type
 * @param <K> the category type
 */
public abstract class AbstractClassifier<T extends Serializable, K extends Serializable> implements IClassifier<T, K>, Serializable {
    private static final long serialVersionUID = 5504911666956811966L;
    private static final int INITIAL_CATEGORY_DICTIONARY_CAPACITY = 16;
    private static final int INITIAL_FEATURE_DICTIONARY_CAPACITY = 32;

    // Field names and declared types are intentionally unchanged so that instances serialized
    // under the fixed serialVersionUID remain readable.

    /** category -> (feature -> number of times the feature was recorded under that category). */
    private Map<K, Map<T, Integer>> featureCountPerCategory;
    /** feature -> number of times the feature was recorded across all categories. */
    private Map<T, Integer> totalFeatureCount;
    /** category -> number of remembered samples labelled with that category. */
    private Map<K, Integer> totalCategoryCount;
    /** Remembered training samples, oldest first. */
    private Queue<IClassification<T, K>> memoryQueue;
    /** Maximum number of samples kept in {@link #memoryQueue}. */
    private int maxClassificationSize;

    /**
     * Creates an empty classifier that remembers at most {@code maxSize} training samples.
     *
     * <p>NOTE(review): this calls the overridable {@link #reset()}; a subclass override runs
     * before the subclass's own fields are initialized.
     *
     * @param maxSize the maximum number of samples kept in memory
     */
    public AbstractClassifier(int maxSize) {
        this.maxClassificationSize = maxSize;
        reset();
    }

    /**
     * Returns the maximum number of training samples kept in memory. (The misspelled name is
     * declared by {@link IClassifier} and is kept for compatibility.)
     *
     * @return the memory limit
     */
    @Override
    public int getMaxClassicicationSize() {
        return this.maxClassificationSize;
    }

    /**
     * Changes the memory limit. If the memory currently holds more samples than the new limit,
     * the oldest samples are forgotten immediately and their counts are removed from the
     * statistics.
     *
     * @param maxSize the new maximum number of samples kept in memory
     */
    protected void setMaxClassicicationSize(int maxSize) {
        this.maxClassificationSize = maxSize;
        trimMemory();
    }

    /**
     * Trains the classifier with one sample built from a category and its features.
     *
     * @param k the category of the sample
     * @param list the features of the sample
     */
    @Override
    public void train(K k, List<T> list) {
        train(new ClassificationImpl(list, k));
    }

    /**
     * Records a sample: counts each of its features under its category, counts the category,
     * and remembers the sample. Remembered samples beyond the memory limit are forgotten,
     * oldest first.
     *
     * @param iClassification the sample to learn from
     */
    public void train(IClassification<T, K> iClassification) {
        K category = iClassification.getCategory();
        for (Iterator<T> features = iClassification.getFeatures().iterator(); features.hasNext();) {
            incrementFeature(features.next(), category);
        }
        incrementCategory(category);
        this.memoryQueue.offer(iClassification);
        trimMemory();
    }

    /** Discards all counts and all remembered samples; the memory limit is kept. */
    @Override
    public void reset() {
        this.featureCountPerCategory = new HashMap<K, Map<T, Integer>>(INITIAL_CATEGORY_DICTIONARY_CAPACITY);
        this.totalFeatureCount = new HashMap<T, Integer>(INITIAL_FEATURE_DICTIONARY_CAPACITY);
        this.totalCategoryCount = new HashMap<K, Integer>(INITIAL_CATEGORY_DICTIONARY_CAPACITY);
        this.memoryQueue = new LinkedList<IClassification<T, K>>();
    }

    /**
     * Returns every feature that currently has a non-zero count. The result is a live view of
     * the internal key set, not a copy.
     *
     * @return the known features
     */
    protected Set<T> getFeatures() {
        return this.totalFeatureCount.keySet();
    }

    /**
     * Returns every category that currently has a non-zero count. The result is a live view of
     * the internal key set, not a copy. (Originally protected; the decompiler widened it to
     * public, which is kept for compatibility.)
     *
     * @return the known categories
     */
    public Set<K> getCategories() {
        return this.totalCategoryCount.keySet();
    }

    /**
     * Returns the sum of all category counts, i.e. the number of remembered samples as recorded
     * by {@link #incrementCategory}. (Originally protected; widened to public by the decompiler.)
     *
     * @return the total category count
     */
    public int getCategoriesTotal() {
        int sum = 0;
        for (Integer count : this.totalCategoryCount.values()) {
            sum += count.intValue();
        }
        return sum;
    }

    /**
     * Increments the count of {@code t} within category {@code k} and its overall count.
     *
     * @param t the feature
     * @param k the category the feature was observed in
     */
    protected void incrementFeature(T t, K k) {
        Map<T, Integer> counts = this.featureCountPerCategory.get(k);
        if (counts == null) {
            counts = new HashMap<T, Integer>(INITIAL_FEATURE_DICTIONARY_CAPACITY);
            this.featureCountPerCategory.put(k, counts);
        }
        increment(counts, t);
        increment(this.totalFeatureCount, t);
    }

    /**
     * Increments the count of category {@code k}.
     *
     * @param k the category
     */
    protected void incrementCategory(K k) {
        increment(this.totalCategoryCount, k);
    }

    /**
     * Decrements the count of {@code t} within category {@code k} and its overall count.
     * Entries that reach zero are removed, and so is a category's feature map once it is empty.
     * Does nothing if {@code t} has no count under {@code k}.
     *
     * @param t the feature
     * @param k the category the feature was observed in
     */
    protected void decrementFeature(T t, K k) {
        Map<T, Integer> counts = this.featureCountPerCategory.get(k);
        if (counts == null || !decrement(counts, t)) {
            return;
        }
        if (counts.isEmpty()) {
            this.featureCountPerCategory.remove(k);
        }
        decrement(this.totalFeatureCount, t);
    }

    /**
     * Decrements the count of category {@code k}, removing it when it reaches zero. Does
     * nothing if the category is unknown.
     *
     * @param k the category
     */
    protected void decrementCategory(K k) {
        decrement(this.totalCategoryCount, k);
    }

    /**
     * Returns how often {@code t} was recorded under category {@code k}.
     *
     * @param t the feature
     * @param k the category
     * @return the count, or 0 if none
     */
    protected int getFeatureCount(T t, K k) {
        Map<T, Integer> counts = this.featureCountPerCategory.get(k);
        return counts == null ? 0 : countOf(counts, t);
    }

    /**
     * Returns how often {@code t} was recorded across all categories.
     *
     * @param t the feature
     * @return the count, or 0 if none
     */
    protected int getFeatureCount(T t) {
        return countOf(this.totalFeatureCount, t);
    }

    /**
     * Returns the count of category {@code k}. (Originally protected; widened to public by the
     * decompiler.)
     *
     * @param k the category
     * @return the count, or 0 if none
     */
    public int getCategoryCount(K k) {
        return countOf(this.totalCategoryCount, k);
    }

    /**
     * Weighted average with weight 1.0 and assumed probability 0.5. (Originally protected;
     * widened to public by the decompiler.)
     *
     * @param t the feature
     * @param k the category
     * @return the weighted probability
     * @see #calculateFeatureWeighedAverage(Serializable, Serializable, float, float)
     */
    public float calculateFeatureWeighedAverage(T t, K k) {
        return calculateFeatureWeighedAverage(t, k, 1.0f);
    }

    /**
     * Weighted average with the given weight and an assumed probability of 0.5.
     *
     * @param t the feature
     * @param k the category
     * @param weight the weight given to the assumed probability
     * @return the weighted probability
     */
    protected float calculateFeatureWeighedAverage(T t, K k, float weight) {
        return calculateFeatureWeighedAverage(t, k, weight, 0.5f);
    }

    /**
     * Blends {@link #calculateFeatureProbability} with an assumed prior:
     * {@code (weight * assumedProbability + total * p) / (weight + total)}, where {@code total}
     * is the feature's overall count. With no data the result is {@code assumedProbability}.
     * If both {@code weight} and {@code total} are 0 the result is NaN.
     *
     * @param t the feature
     * @param k the category
     * @param weight the weight given to the assumed probability
     * @param assumedProbability the prior used when little data is available
     * @return the weighted probability
     */
    protected float calculateFeatureWeighedAverage(T t, K k, float weight, float assumedProbability) {
        float basicProbability = calculateFeatureProbability(t, k);
        int total = countOf(this.totalFeatureCount, t);
        return ((weight * assumedProbability) + (total * basicProbability)) / (weight + total);
    }

    /**
     * Returns the fraction of all recorded occurrences of {@code t} that were recorded under
     * category {@code k}, or 0 if {@code t} has never been recorded.
     *
     * @param t the feature
     * @param k the category
     * @return a value in [0, 1]
     */
    protected float calculateFeatureProbability(T t, K k) {
        int total = getFeatureCount(t);
        if (total == 0) {
            return 0.0f;
        }
        // Divide in floating point: integer division would truncate every result to 0 or 1.
        return (float) getFeatureCount(t, k) / total;
    }

    /** Forgets the oldest samples until the memory respects the current limit. */
    private void trimMemory() {
        // The isEmpty() guard keeps remove() from throwing when the limit is zero or negative.
        while (!this.memoryQueue.isEmpty() && this.memoryQueue.size() > this.maxClassificationSize) {
            forget(this.memoryQueue.remove());
        }
    }

    /** Subtracts a previously trained sample's feature and category counts. */
    private void forget(IClassification<T, K> classification) {
        K category = classification.getCategory();
        for (Iterator<T> features = classification.getFeatures().iterator(); features.hasNext();) {
            decrementFeature(features.next(), category);
        }
        decrementCategory(category);
    }

    /** Adds one to the count stored for {@code key}, starting from zero if absent. */
    private static <E> void increment(Map<E, Integer> counts, E key) {
        Integer current = counts.get(key);
        counts.put(key, current == null ? 1 : current.intValue() + 1);
    }

    /**
     * Subtracts one from the count stored for {@code key}, removing the entry at one.
     * Returns false (and changes nothing) if {@code key} is absent.
     */
    private static <E> boolean decrement(Map<E, Integer> counts, E key) {
        Integer current = counts.get(key);
        if (current == null) {
            return false;
        }
        if (current.intValue() == 1) {
            counts.remove(key);
        } else {
            counts.put(key, current.intValue() - 1);
        }
        return true;
    }

    /** Returns the count stored for {@code key}, or 0 if absent. */
    private static <E> int countOf(Map<E, Integer> counts, E key) {
        Integer current = counts.get(key);
        return current == null ? 0 : current.intValue();
    }
}
