/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification.featureselection;

import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.AbstractFeatureSelector;
import dragon.ir.index.IndexReader;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import dragon.nlp.Token;
import dragon.nlp.compare.IndexComparator;
import dragon.nlp.compare.WeightComparator;
import dragon.util.MathUtil;
import dragon.util.SortedArray;
import java.io.Serializable;

/**
 * Feature selector that scores every term of an indexed collection by its
 * information gain with respect to the training classes and keeps the
 * best-scoring fraction of the vocabulary.
 *
 * <p>The score of a term {@code t} is
 * {@code H(C) - P(t) * H(C|t) - P(~t) * H(C|~t)}, where {@code P(t)} is the
 * fraction of labelled training documents containing the term. Only terms with
 * a strictly positive, well-defined score are eligible for selection.
 */
public class InfoGainFeatureSelector
extends AbstractFeatureSelector
implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Fraction of the collection's term count that should be retained. */
    private double topPercentage;

    /**
     * @param topPercentage fraction of the collection's terms to keep; it is
     *        multiplied by the collection's term count to obtain the number of
     *        selected features
     */
    public InfoGainFeatureSelector(double topPercentage) {
        this.topPercentage = topPercentage;
    }

    /**
     * Not supported by this selector: prints a notice to standard output and
     * returns {@code null}.
     *
     * @param doctermMatrix ignored
     * @param trainingSet ignored
     * @return always {@code null}
     */
    @Override
    protected int[] getSelectedFeatures(SparseMatrix doctermMatrix, DocClassSet trainingSet) {
        System.out.println("InfoGainSelector does not accept SparseMatrix as input. Please use IndexReader as input instead.");
        return null;
    }

    /**
     * Selects the leading entries of the information-gain ranking.
     *
     * @param indexReader index providing the term/document postings
     * @param trainingSet labelled training documents
     * @return indices of the selected terms, in the order produced by
     *         {@code IndexComparator} (presumably ascending term index - confirm)
     */
    @Override
    protected int[] getSelectedFeatures(IndexReader indexReader, DocClassSet trainingSet) {
        SortedArray ranked = this.computeTermIG(indexReader, trainingSet);

        // Clamp at zero so a negative topPercentage yields an empty selection
        // instead of a negative capacity being handed to SortedArray.
        int requested = (int)(this.topPercentage * (double)indexReader.getCollection().getTermNum());
        int termNum = Math.max(0, Math.min(ranked.size(), requested));

        // Re-sort the chosen tokens with IndexComparator before building the map.
        SortedArray selectedList = new SortedArray(termNum, new IndexComparator());
        for (int i = 0; i < termNum; ++i) {
            selectedList.add(ranked.get(i));
        }

        int[] featureMap = new int[selectedList.size()];
        for (int i = 0; i < featureMap.length; ++i) {
            featureMap[i] = ((Token)selectedList.get(i)).getIndex();
        }
        return featureMap;
    }

    /**
     * Computes the information gain of every term and returns the terms with a
     * strictly positive gain.
     *
     * @param indexReader index providing the term/document postings
     * @param trainingSet labelled training documents
     * @return list of {@code Token}s (index = term index, weight = information
     *         gain) whose comparator has been switched to
     *         {@code WeightComparator(true)} (presumably highest weight first -
     *         confirm against WeightComparator)
     */
    private SortedArray computeTermIG(IndexReader indexReader, DocClassSet trainingSet) {
        int classNum = trainingSet.getClassNum();
        int trainingDocNum = 0;
        for (int i = 0; i < classNum; ++i) {
            trainingDocNum += trainingSet.getDocClass(i).getDocNum();
        }

        DoubleVector classPrior = this.getClassPrior(trainingSet);
        double classEntropy = this.calEntropy(classPrior);

        // Per-class document counts, rebuilt from the priors.
        DoubleVector classVector = classPrior.copy();
        classVector.multiply(trainingDocNum);

        // arrDoc[d] is the class label of document d, or -1 when d is not a training document.
        int[] arrDoc = new int[indexReader.getCollection().getDocNum()];
        MathUtil.initArray(arrDoc, -1);
        for (int i = 0; i < classNum; ++i) {
            DocClass docClass = trainingSet.getDocClass(i);
            for (int j = 0; j < docClass.getDocNum(); ++j) {
                arrDoc[docClass.getDoc(j).getIndex()] = i;
            }
        }

        int termNum = indexReader.getCollection().getTermNum();
        SortedArray list = new SortedArray(termNum, new IndexComparator());
        DoubleVector classDistrWiTerm = new DoubleVector(classPrior.size());
        DoubleVector classDistrWoTerm = new DoubleVector(classPrior.size());

        for (int i = 0; i < termNum; ++i) {
            int[] arrDocIndex = indexReader.getTermDocIndexList(i);
            if (arrDocIndex == null || arrDocIndex.length == 0) {
                continue;
            }

            classDistrWiTerm.assign(0.0);
            classDistrWoTerm.assign(classVector);
            int docCount = 0;
            for (int j = 0; j < arrDocIndex.length; ++j) {
                int docLabel = arrDoc[arrDocIndex[j]];
                if (docLabel < 0) {
                    continue;
                }
                classDistrWiTerm.add(docLabel, 1.0);
                classDistrWoTerm.add(docLabel, -1.0);
                ++docCount;
            }

            // A term seen in no training document, or in all of them, splits
            // nothing: its gain is zero. Skipping the second case also avoids
            // dividing by (trainingDocNum - docCount) == 0 below.
            if (docCount == 0 || docCount >= trainingDocNum) {
                continue;
            }

            // Turn the per-class counts into conditional class distributions.
            classDistrWiTerm.multiply(1.0 / (double)docCount);
            classDistrWoTerm.multiply(1.0 / (double)(trainingDocNum - docCount));

            // Weight each conditional entropy by the probability of its branch.
            double probWiTerm = (double)docCount / (double)trainingDocNum;
            double infoGain = classEntropy
                    - probWiTerm * this.calEntropy(classDistrWiTerm)
                    - (1.0 - probWiTerm) * this.calEntropy(classDistrWoTerm);

            // "> 0.0" (rather than "!(<= 0.0)") also rejects NaN scores.
            if (infoGain > 0.0) {
                Token curTerm = new Token(i, 0);
                curTerm.setWeight(infoGain);
                list.add(curTerm);
            }
        }

        list.setComparator(new WeightComparator(true));
        return list;
    }

    /**
     * Computes the entropy {@code -sum(p * ln p)} of a probability vector.
     *
     * <p>Entries that are zero contribute nothing (the {@code 0 * ln 0 = 0}
     * convention). Entries that are slightly negative because of floating-point
     * rounding are treated the same way, so {@code Math.log} is never called
     * with a negative argument.
     *
     * @param probVector probabilities to evaluate
     * @return the entropy in nats
     */
    private double calEntropy(DoubleVector probVector) {
        double sum = 0.0;
        for (int i = 0; i < probVector.size(); ++i) {
            double p = probVector.get(i);
            if (p <= 0.0) {
                continue;
            }
            sum -= p * Math.log(p);
        }
        return sum;
    }
}

