package org.apache.ctakes.assertion.attributes.features.selection;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.transform.TransformableFeature;

/* loaded from: input_file:org/apache/ctakes/assertion/attributes/features/selection/MutualInformationFeatureSelection.class */
public class MutualInformationFeatureSelection<OUTCOME_T> extends FeatureSelection<OUTCOME_T> {
    // Count statistics built by train(); no other method in this class assigns it.
    private MutualInformationStats<OUTCOME_T> mutualInfoStats;
    // Upper bound on how many feature names train() keeps and load() reads back.
    private int numFeatures;
    // How per-outcome scores are merged into one score per feature; set by the
    // constructor and overwritten by load() from the file header.
    private CombineScoreMethod combineScoreMethod;
    // Smoothing value handed to MutualInformationStats when train() creates it.
    private double smoothingCount;

    /* loaded from: input_file:org/apache/ctakes/assertion/attributes/features/selection/MutualInformationFeatureSelection$CombineScoreMethod.class */
    public enum CombineScoreMethod implements Function<Map<?, Double>, Double> {
        AVERAGE { // from class: org.apache.ctakes.assertion.attributes.features.selection.MutualInformationFeatureSelection.CombineScoreMethod.1
            public Double apply(Map<?, Double> map) {
                Collection<Double> values = map.values();
                int size = values.size();
                double d = 0.0d;
                Iterator<Double> it = values.iterator();
                while (it.hasNext()) {
                    d += it.next().doubleValue();
                }
                return Double.valueOf(d / size);
            }
        },
        MAX { // from class: org.apache.ctakes.assertion.attributes.features.selection.MutualInformationFeatureSelection.CombineScoreMethod.2
            public Double apply(Map<?, Double> map) {
                return (Double) Ordering.natural().max(map.values());
            }
        }
    }

    /* loaded from: input_file:org/apache/ctakes/assertion/attributes/features/selection/MutualInformationFeatureSelection$MutualInformationStats.class */
    public static class MutualInformationStats<OUTCOME_T> {
        protected Multiset<OUTCOME_T> classCounts = HashMultiset.create();
        protected Table<String, OUTCOME_T, Integer> classConditionalCounts = HashBasedTable.create();
        protected double smoothingCount;

        public MutualInformationStats(double d) {
            this.smoothingCount += d;
        }

        public void update(String str, OUTCOME_T outcome_t, int i) {
            Integer num = (Integer) this.classConditionalCounts.get(str, outcome_t);
            if (num == null) {
                num = 0;
            }
            this.classConditionalCounts.put(str, outcome_t, Integer.valueOf(num.intValue() + i));
            this.classCounts.add(outcome_t, i);
        }

        public double mutualInformation(String str, OUTCOME_T outcome_t) {
            int[][] iArr = new int[2][2];
            int size = this.classCounts.size();
            int[] iArr2 = {size - iArr2[1], sum(this.classConditionalCounts.row(str).values())};
            int[] iArr3 = {size - iArr3[1], this.classCounts.count(outcome_t)};
            iArr[1][1] = this.classConditionalCounts.contains(str, outcome_t) ? ((Integer) this.classConditionalCounts.get(str, outcome_t)).intValue() : 0;
            iArr[1][0] = iArr2[1] - iArr[1][1];
            iArr[0][1] = iArr3[1] - iArr[1][1];
            iArr[0][0] = ((size - iArr2[1]) - iArr3[1]) + iArr[1][1];
            double d = 0.0d;
            for (int i = 0; i <= 1; i++) {
                for (int i2 = 0; i2 <= 1; i2++) {
                    iArr[i][i2] = (int) (r0[r1] + this.smoothingCount);
                    d += (iArr[i][i2] / size) * Math.log((size * iArr[i][i2]) / (iArr2[i] * iArr3[i2]));
                }
            }
            return d;
        }

        private static int sum(Collection<Integer> collection) {
            int i = 0;
            Iterator<Integer> it = collection.iterator();
            while (it.hasNext()) {
                i += it.next().intValue();
            }
            return i;
        }

        /* JADX WARN: Multi-variable type inference failed */
        public void save(URI uri) throws IOException {
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(uri)));
            bufferedWriter.append((CharSequence) "Mutual Information Data\n");
            bufferedWriter.append((CharSequence) "Feature\t");
            bufferedWriter.append((CharSequence) Joiner.on("\t").join(this.classConditionalCounts.columnKeySet()));
            bufferedWriter.append((CharSequence) "\n");
            for (String str : this.classConditionalCounts.rowKeySet()) {
                bufferedWriter.append((CharSequence) str);
                for (Object obj : this.classConditionalCounts.columnKeySet()) {
                    bufferedWriter.append((CharSequence) "\t");
                    bufferedWriter.append((CharSequence) String.format("%f", Double.valueOf(mutualInformation(str, obj))));
                }
                bufferedWriter.append((CharSequence) "\n");
            }
            bufferedWriter.append((CharSequence) "\n");
            bufferedWriter.append((CharSequence) this.classConditionalCounts.toString());
            bufferedWriter.close();
        }

        public Function<String, Double> getScoreFunction(final CombineScoreMethod combineScoreMethod) {
            return new Function<String, Double>() { // from class: org.apache.ctakes.assertion.attributes.features.selection.MutualInformationFeatureSelection.MutualInformationStats.1
                /* JADX WARN: Multi-variable type inference failed */
                public Double apply(String str) {
                    Set columnKeySet = MutualInformationStats.this.classConditionalCounts.columnKeySet();
                    HashMap newHashMap = Maps.newHashMap();
                    for (Object obj : columnKeySet) {
                        newHashMap.put(obj, Double.valueOf(MutualInformationStats.this.mutualInformation(str, obj)));
                    }
                    return (Double) combineScoreMethod.apply(newHashMap);
                }
            };
        }
    }

    /**
     * Creates a selector with the defaults: {@link CombineScoreMethod#MAX}, a
     * smoothing count of 1.0 and at most 10 selected features.
     *
     * @param str value forwarded to the superclass constructor
     */
    public MutualInformationFeatureSelection(String str) {
        // The two-argument constructor supplies MAX and 1.0 for the remaining settings.
        this(str, 10);
    }

    /**
     * Creates a selector that uses {@link CombineScoreMethod#MAX} and a smoothing
     * count of 1.0.
     *
     * @param str value forwarded to the superclass constructor
     * @param i   maximum number of features to select
     */
    public MutualInformationFeatureSelection(String str, int i) {
        this(
                str,
                CombineScoreMethod.MAX,
                1.0d,
                i);
    }

    /**
     * Creates a fully configured selector.
     *
     * @param str                value forwarded to the superclass constructor
     * @param combineScoreMethod how per-outcome scores are merged into one score per feature
     * @param d                  smoothing count passed to the statistics object during training
     * @param i                  maximum number of features to select
     */
    public MutualInformationFeatureSelection(String str, CombineScoreMethod combineScoreMethod, double d, int i) {
        super(str);
        this.numFeatures = i;
        this.smoothingCount = d;
        this.combineScoreMethod = combineScoreMethod;
    }

    /**
     * Collects co-occurrence counts from the given instances and selects the
     * highest-scoring feature names.
     *
     * <p>For every instance, each feature accepted by {@code isTransformable} is
     * treated as a {@link TransformableFeature}; each of its contained features is
     * counted once against the instance outcome. Feature names are then ordered by
     * descending combined score and the first {@code numFeatures} of them are kept,
     * or all of them when fewer distinct names were seen.
     *
     * @param iterable training instances
     */
    public void train(Iterable<Instance<OUTCOME_T>> iterable) {
        this.mutualInfoStats = new MutualInformationStats<>(this.smoothingCount);
        for (Instance<OUTCOME_T> instance : iterable) {
            OUTCOME_T outcome = instance.getOutcome();
            for (Feature feature : instance.getFeatures()) {
                if (!isTransformable(feature)) {
                    continue;
                }
                // Cast only after the check above has accepted the feature.
                for (Feature contained : ((TransformableFeature) feature).getFeatures()) {
                    this.mutualInfoStats.update(getFeatureName(contained), outcome, 1);
                }
            }
        }

        Set<String> featureNames = this.mutualInfoStats.classConditionalCounts.rowKeySet();
        Function<String, Double> score = this.mutualInfoStats.getScoreFunction(this.combineScoreMethod);
        // Clamp the cut-off: subList would throw if fewer names exist than requested.
        int limit = Math.min(this.numFeatures, featureNames.size());
        this.selectedFeatureNames = Sets.newLinkedHashSet(
                Ordering.<Double>natural()
                        .onResultOf(score)
                        .reverse()
                        .immutableSortedCopy(featureNames)
                        .subList(0, limit));
        this.isTrained = true;
    }

    /**
     * Writes the selection to a text file: a header line
     * {@code CombineScoreType<TAB><method>} followed by one selected feature name
     * per line, in iteration order of the selected set.
     *
     * @param uri file location to write to
     * @throws IOException if this selector has not been trained or loaded, or if
     *                     the file cannot be opened or written
     */
    public void save(URI uri) throws IOException {
        if (!this.isTrained) {
            throw new IOException("MutualInformationFeatureExtractor: Cannot save before training.");
        }
        // try-with-resources closes the writer even when an append fails.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(uri)))) {
            writer.append("CombineScoreType\t");
            writer.append(this.combineScoreMethod.toString());
            writer.append('\n');
            for (String featureName : this.selectedFeatureNames) {
                writer.append(featureName);
                writer.append('\n');
            }
        }
    }

    /**
     * Restores a selection from a file in the format written by {@code save}: the
     * second tab-separated field of the first line names the
     * {@link CombineScoreMethod}, and each following line (trimmed) is a selected
     * feature name. At most {@code numFeatures} names are read; later lines are
     * ignored.
     *
     * @param uri file location to read from
     * @throws IOException              if the file cannot be read, is empty, or its
     *                                  header has no tab-separated second field
     * @throws IllegalArgumentException if the header names an unknown combine method
     */
    public void load(URI uri) throws IOException {
        this.selectedFeatureNames = Sets.newLinkedHashSet();
        // try-with-resources closes the reader even when parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(uri)))) {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("MutualInformationFeatureExtractor: Cannot load from empty file: " + uri);
            }
            String[] headerFields = header.split("\t");
            if (headerFields.length < 2) {
                throw new IOException("MutualInformationFeatureExtractor: Malformed header line: " + header);
            }
            this.combineScoreMethod = CombineScoreMethod.valueOf(headerFields[1]);

            String line;
            for (int loaded = 0; loaded < this.numFeatures && (line = reader.readLine()) != null; loaded++) {
                this.selectedFeatureNames.add(line.trim());
            }
        }
        this.isTrained = true;
    }
}
