/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification.featureselection;

import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.AbstractFeatureSelector;
import dragon.ir.index.IndexReader;
import dragon.matrix.IntDenseMatrix;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import dragon.nlp.Token;
import dragon.nlp.compare.IndexComparator;
import dragon.nlp.compare.WeightComparator;
import dragon.util.SortedArray;
import java.io.Serializable;

/**
 * Feature selector that scores terms by mutual information with the classes
 * of a training set and keeps the best-scoring fraction of them.
 *
 * <p>For each term with a positive total count in the term distribution, one
 * mutual-information score is computed per class. The per-class scores are
 * reduced to a single term weight either as the dot product with the class
 * priors ({@code avgMode == true}) or as the maximum over the classes
 * ({@code avgMode == false}).
 */
public class MutualInfoFeatureSelector
extends AbstractFeatureSelector
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Fraction of the vocabulary size used as the upper bound on selected terms. */
    private final double topPercentage;
    /** If true, combine per-class scores with the class priors; otherwise take the maximum. */
    private final boolean avgMode;

    /**
     * Creates a selector.
     *
     * @param topPercentage fraction of the total term count to keep; it is
     *                      multiplied by the vocabulary size and truncated to an int
     * @param avgMode       {@code true} to weight per-class scores by the class
     *                      priors, {@code false} to use the maximum per-class score
     */
    public MutualInfoFeatureSelector(double topPercentage, boolean avgMode) {
        this.topPercentage = topPercentage;
        this.avgMode = avgMode;
    }

    /**
     * Selects features using term statistics obtained from an index reader.
     *
     * @param indexReader source of the term distribution and of the vocabulary size
     * @param trainingSet labelled training documents
     * @return indices of the selected terms
     */
    @Override
    protected int[] getSelectedFeatures(IndexReader indexReader, DocClassSet trainingSet) {
        DoubleVector classPrior = this.getClassPrior(trainingSet);
        int docNum = MutualInfoFeatureSelector.countTrainingDocs(trainingSet);
        SortedArray rankedTerms = this.computeTermMI(this.getTermDistribution(indexReader, trainingSet), classPrior, docNum);
        return this.selectTopFeatures(rankedTerms, (double)indexReader.getCollection().getTermNum());
    }

    /**
     * Selects features using term statistics obtained from a document-term matrix.
     *
     * @param doctermMatrix matrix whose column count is used as the vocabulary size
     * @param trainingSet   labelled training documents
     * @return indices of the selected terms
     */
    @Override
    protected int[] getSelectedFeatures(SparseMatrix doctermMatrix, DocClassSet trainingSet) {
        DoubleVector classPrior = this.getClassPrior(trainingSet);
        int docNum = MutualInfoFeatureSelector.countTrainingDocs(trainingSet);
        SortedArray rankedTerms = this.computeTermMI(this.getTermDistribution(doctermMatrix, trainingSet), classPrior, docNum);
        return this.selectTopFeatures(rankedTerms, (double)doctermMatrix.columns());
    }

    /** Sums the document counts of every class in the training set. */
    private static int countTrainingDocs(DocClassSet trainingSet) {
        int docNum = 0;
        for (int i = 0; i < trainingSet.getClassNum(); ++i) {
            docNum += trainingSet.getDocClass(i).getDocNum();
        }
        return docNum;
    }

    /**
     * Takes the leading entries of the ranked term list and returns their
     * term indices.
     *
     * @param rankedTerms    tokens as returned by {@link #computeTermMI}
     * @param vocabularySize total number of terms; multiplied by
     *                       {@code topPercentage} to get the selection size
     * @return term indices read from a list kept under an {@code IndexComparator}
     *         (presumably ascending index order -- confirm against IndexComparator)
     */
    private int[] selectTopFeatures(SortedArray rankedTerms, double vocabularySize) {
        int termNum = (int)(this.topPercentage * vocabularySize);
        // Never request more terms than were scored, and never a negative
        // number: a negative topPercentage would otherwise produce a negative
        // count that is handed to SortedArray as its capacity.
        termNum = Math.max(0, Math.min(rankedTerms.size(), termNum));
        SortedArray selectedList = new SortedArray(termNum, new IndexComparator());
        for (int i = 0; i < termNum; ++i) {
            selectedList.add(rankedTerms.get(i));
        }
        int[] featureMap = new int[selectedList.size()];
        for (int i = 0; i < featureMap.length; ++i) {
            featureMap[i] = ((Token)selectedList.get(i)).getIndex();
        }
        return featureMap;
    }

    /**
     * Computes a mutual-information weight for every term with a positive
     * column sum in {@code termDistri}.
     *
     * @param termDistri matrix read as {@code getInt(class, term)}; presumably
     *                   per-class document counts for each term -- confirm
     *                   against getTermDistribution in the superclass
     * @param classPrior one prior per class; scaled by {@code docNum} to obtain
     *                   per-class totals
     * @param docNum     total number of training documents
     * @return tokens (index = term column, weight = score) in a list whose
     *         comparator is finally switched to {@code WeightComparator(true)};
     *         callers treat the first entries as the best terms
     */
    private SortedArray computeTermMI(IntDenseMatrix termDistri, DoubleVector classPrior, int docNum) {
        // Per-class totals: prior * total document count.
        DoubleVector classVector = classPrior.copy();
        classVector.multiply(docNum);

        // Per-term totals: column sums of the distribution matrix.
        DoubleVector termVector = new DoubleVector(termDistri.columns());
        for (int i = 0; i < termDistri.columns(); ++i) {
            termVector.set(i, termDistri.getColumnSum(i));
        }

        double total = docNum;
        DoubleVector miVector = new DoubleVector(classVector.size());
        SortedArray list = new SortedArray(termVector.size(), new IndexComparator());
        for (int i = 0; i < termVector.size(); ++i) {
            double termTotal = termVector.get(i);
            // Terms that never occur carry no information; leave them out.
            if (termTotal <= 0.0) {
                continue;
            }
            for (int j = 0; j < classVector.size(); ++j) {
                miVector.set(j, this.calMutualInformation(termDistri.getInt(j, i), classVector.get(j), termTotal, total));
            }
            Token curTerm = new Token(i, 0);
            if (this.avgMode) {
                curTerm.setWeight(miVector.dotProduct(classPrior));
            } else {
                curTerm.setWeight(miVector.getMaxValue());
            }
            list.add(curTerm);
        }
        // Switch from index order to weight order for ranking.
        // NOTE(review): the 'true' flag presumably requests descending order -- confirm.
        list.setComparator(new WeightComparator(true));
        return list;
    }

    /**
     * Computes {@code log(t1t2occur * total / (t1sum * t2sum))}.
     *
     * @param t1t2occur joint count of the two events
     * @param t1sum     total count of the first event
     * @param t2sum     total count of the second event
     * @param total     overall count
     * @return the logarithm above, or {@code 0.0} when any of the first three
     *         arguments is zero
     */
    private double calMutualInformation(double t1t2occur, double t1sum, double t2sum, double total) {
        if (t1t2occur == 0.0 || t1sum == 0.0 || t2sum == 0.0) {
            return 0.0;
        }
        return Math.log(t1t2occur * total / (t1sum * t2sum));
    }
}

