/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.NBClassifier;
import dragon.ir.index.IRDoc;
import dragon.ir.index.IRTerm;
import dragon.ir.index.IndexReader;
import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.matrix.IntRow;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

/**
 * Naive Bayes classifier that augments the labeled training set with
 * unlabeled documents.
 *
 * <p>Training first fits the model on the labeled documents only, then
 * repeatedly (a) classifies the unlabeled documents with the current model and
 * (b) re-estimates the model from the labeled documents plus the freshly
 * labeled ones, until the tracked score stops changing by more than
 * {@code convergeThreshold} or {@code runNum} iterations have been performed.
 *
 * <p>Unlabeled documents may come from the classifier's own index
 * ({@link #setUnlabeledData(DocClass)}) or from a different index
 * ({@link #setUnlabeledData(IndexReader, DocClass)}); in the latter case their
 * term vectors are translated into this classifier's term space and cached.
 *
 * <p>NOTE(review): the class name suggests the semi-supervised EM scheme of
 * Nigam et al.; that is an inference from the name, not from the code.
 */
public class NigamActiveLearning extends NBClassifier {
    /** Key prefix marking documents whose term vectors are cached in {@link #externalUnlabeled}. */
    private static final String EXTERNAL_KEY_PREFIX = "external_unlabeled";

    /** Term vectors (already mapped to this index's term ids) of external unlabeled documents. */
    private IntRow[] externalUnlabeled;
    /** Unlabeled documents used by {@link #train(DocClassSet)}; may be {@code null} if never set. */
    private DocClass unlabeledSet;
    /** Holds the caller-supplied unlabeled set while {@code classify} temporarily replaces it. */
    private DocClass unlabeledSetBackup;
    /** External documents are given indices starting here (the document count of the main index). */
    private int externalDocOffset;
    /** Iteration stops once the tracked score changes by no more than this amount. */
    private double convergeThreshold;
    /** Fraction of the testing documents that {@code classify} adds to the unlabeled pool. */
    private double unlabeledRate;
    /** Maximum number of classify/re-estimate iterations. */
    private int runNum;

    /**
     * Creates a classifier from a previously saved model file.
     *
     * <p>The iteration settings ({@code runNum}, {@code convergeThreshold},
     * {@code unlabeledRate}) keep their zero defaults with this constructor.
     *
     * @param modelFile path handed to the superclass constructor
     */
    public NigamActiveLearning(String modelFile) {
        super(modelFile);
    }

    /**
     * Creates a trainable classifier on top of the given index.
     *
     * @param indexReader   index holding the labeled (and internal unlabeled) documents
     * @param unlabeledRate fraction of the testing documents to treat as unlabeled
     *                      training data in {@link #classify(DocClassSet, DocClass)};
     *                      values {@code <= 0} disable that behaviour
     */
    public NigamActiveLearning(IndexReader indexReader, double unlabeledRate) {
        super(indexReader);
        this.externalDocOffset = indexReader.getCollection().getDocNum();
        this.runNum = 15;
        this.convergeThreshold = 1.0E-4;
        this.unlabeledRate = unlabeledRate;
    }

    /**
     * Registers unlabeled documents that live in a different index.
     *
     * <p>Each document's terms are translated into this classifier's term ids;
     * terms unknown to this classifier's index are dropped, and documents left
     * with no terms are skipped entirely. Every retained document is modified in
     * place: its index becomes {@code externalDocOffset + n} and its key is
     * prefixed with {@code "external_unlabeled"}, which is how later lookups
     * recognise it. Because of that in-place rewrite, passing the same
     * {@code docSet} a second time would read term lists by the rewritten index.
     *
     * @param newIndexReader index the documents in {@code docSet} belong to
     * @param docSet         unlabeled documents from {@code newIndexReader}
     */
    public void setUnlabeledData(IndexReader newIndexReader, DocClass docSet) {
        int[] termMap = this.getTermMap(newIndexReader, this.indexReader);
        this.externalUnlabeled = new IntRow[docSet.getDocNum()];
        this.unlabeledSet = new DocClass(0);
        int docNum = 0;
        for (int i = 0; i < this.externalUnlabeled.length; ++i) {
            IRDoc curDoc = docSet.getDoc(i);
            int[] arrIndex = newIndexReader.getTermIndexList(curDoc.getIndex());
            int[] arrFreq = newIndexReader.getTermFrequencyList(curDoc.getIndex());
            if (arrIndex == null) {
                continue;
            }

            // First pass: count the terms that also exist in the target index.
            int termNum = 0;
            for (int j = 0; j < arrIndex.length; ++j) {
                if (termMap[arrIndex[j]] >= 0) {
                    ++termNum;
                }
            }
            if (termNum == 0) {
                continue;
            }

            // Second pass: copy the surviving terms using the target index's ids.
            int[] arrNewIndex = new int[termNum];
            int[] arrNewFreq = new int[termNum];
            termNum = 0;
            for (int j = 0; j < arrIndex.length; ++j) {
                int newIndex = termMap[arrIndex[j]];
                if (newIndex < 0) {
                    continue;
                }
                arrNewIndex[termNum] = newIndex;
                arrNewFreq[termNum] = arrFreq[j];
                ++termNum;
            }

            this.externalUnlabeled[docNum] = new IntRow(docNum, termNum, arrNewIndex, arrNewFreq);
            curDoc.setIndex(this.externalDocOffset + docNum);
            curDoc.setKey(EXTERNAL_KEY_PREFIX + curDoc.getKey());
            this.unlabeledSet.addDoc(curDoc);
            ++docNum;
        }
    }

    /**
     * Registers unlabeled documents that live in this classifier's own index
     * and discards any previously cached external term vectors.
     *
     * @param docSet unlabeled documents
     */
    public void setUnlabeledData(DocClass docSet) {
        this.unlabeledSet = docSet;
        this.externalUnlabeled = null;
    }

    /**
     * Trains on {@code trainingDocSet} and then classifies {@code testingDocs}.
     *
     * <p>When {@code unlabeledRate > 0}, a deterministic (fixed seed) random
     * sample of the testing documents is added to the unlabeled pool for the
     * duration of training. The caller-supplied unlabeled set is restored
     * afterwards, even if training throws.
     *
     * @param trainingDocSet labeled documents grouped by class
     * @param testingDocs    documents to classify
     * @return the classification result, or {@code null} when neither an index
     *         reader nor a doc-term matrix is available
     */
    @Override
    public DocClassSet classify(DocClassSet trainingDocSet, DocClass testingDocs) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return null;
        }
        if (this.unlabeledRate > 0.0) {
            DocClass backup = this.unlabeledSet;
            DocClass workingSet = new DocClass(0);
            this.unlabeledSetBackup = backup;
            this.unlabeledSet = workingSet;
            try {
                if (backup != null) {
                    for (int i = 0; i < backup.getDocNum(); ++i) {
                        workingSet.addDoc(backup.getDoc(i));
                    }
                }
                ArrayList<IRDoc> list = new ArrayList<IRDoc>(testingDocs.getDocNum());
                for (int i = 0; i < testingDocs.getDocNum(); ++i) {
                    list.add(testingDocs.getDoc(i));
                }
                // Fixed seed keeps the sampled subset reproducible across runs.
                Collections.shuffle(list, new Random(10L));
                int num = (int)(this.unlabeledRate * (double)list.size());
                for (int i = 0; i < num; ++i) {
                    workingSet.addDoc(list.get(i));
                }
                this.train(trainingDocSet);
            } finally {
                // Always put the caller's unlabeled set back so a failed
                // training run cannot leave the temporary pool installed.
                workingSet.removeAll();
                this.unlabeledSet = backup;
            }
        } else {
            this.train(trainingDocSet);
        }
        return this.classify(testingDocs);
    }

    /**
     * Fits the model on the labeled documents and then refines it with the
     * unlabeled set.
     *
     * <p>Does nothing when neither an index reader nor a doc-term matrix is
     * available. When no unlabeled set has been registered, only the initial
     * supervised fit is performed.
     *
     * @param trainingDocSet labeled documents grouped by class
     */
    @Override
    public void train(DocClassSet trainingDocSet) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return;
        }
        this.classNum = trainingDocSet.getClassNum();
        this.arrLabel = new String[this.classNum];
        for (int i = 0; i < this.classNum; ++i) {
            this.arrLabel[i] = trainingDocSet.getDocClass(i).getClassName();
        }
        this.eStep(trainingDocSet);

        // Without unlabeled data there is nothing to iterate over; the
        // supervised model estimated above is the final model.
        if (this.unlabeledSet == null) {
            return;
        }

        double prevProb = 0.0;
        double prob = -1.7976931348623157E308;
        for (int curRun = 0;
                Math.abs(prob - prevProb) > this.convergeThreshold && curRun < this.runNum;
                ++curRun) {
            prevProb = prob;
            prob = 0.0;

            // Label the unlabeled documents with the current model.
            DocClassSet classifiedUnlabeledSet = this.classify(this.unlabeledSet);

            // NOTE(review): this sums the weights of the *labeled* training
            // documents, whereas classify(IRDoc) in this class assigns weights
            // to the documents it classifies (presumably the unlabeled ones
            // above). Confirm whether the unlabeled set was intended here.
            for (int i = 0; i < trainingDocSet.getClassNum(); ++i) {
                DocClass cur = trainingDocSet.getDocClass(i);
                for (int j = 0; j < cur.getDocNum(); ++j) {
                    prob += cur.getDoc(j).getWeight();
                }
            }

            // Score of each labeled document under its own class:
            // classPrior(i) + sum(freq * model(i, term)).
            for (int i = 0; i < trainingDocSet.getClassNum(); ++i) {
                DocClass cur = trainingDocSet.getDocClass(i);
                for (int j = 0; j < cur.getDocNum(); ++j) {
                    IRDoc curDoc = cur.getDoc(j);
                    double docProb = this.classPrior.get(i);
                    int[] arrIndex = this.indexReader.getTermIndexList(curDoc.getIndex());
                    int[] arrFreq = this.indexReader.getTermFrequencyList(curDoc.getIndex());
                    // A document without terms contributes only its class prior.
                    if (arrIndex != null) {
                        for (int k = 0; k < arrIndex.length; ++k) {
                            int newTermIndex = this.featureSelector.map(arrIndex[k]);
                            if (newTermIndex < 0) {
                                continue;
                            }
                            docProb += (double)arrFreq[k] * this.model.getDouble(i, newTermIndex);
                        }
                    }
                    prob += docProb;
                }
            }

            // Re-estimate from labeled documents plus the newly labeled ones.
            DocClassSet newTrainingSet = new DocClassSet(trainingDocSet.getClassNum());
            for (int i = 0; i < trainingDocSet.getClassNum(); ++i) {
                DocClass cur = trainingDocSet.getDocClass(i);
                for (int j = 0; j < cur.getDocNum(); ++j) {
                    newTrainingSet.addDoc(i, cur.getDoc(j));
                }
            }
            for (int i = 0; i < classifiedUnlabeledSet.getClassNum(); ++i) {
                DocClass cur = classifiedUnlabeledSet.getDocClass(i);
                for (int j = 0; j < cur.getDocNum(); ++j) {
                    newTrainingSet.addDoc(i, cur.getDoc(j));
                }
            }
            this.eStep(newTrainingSet);
        }
    }

    /**
     * Classifies a single document and stores the winning class's value from
     * {@code lastClassProb} as the document's weight.
     *
     * <p>Documents whose key starts with {@code "external_unlabeled"} are looked
     * up in the cached external term vectors; all others are read from the
     * index. A document with no term list is classified as an empty term vector.
     *
     * @param curDoc document to classify
     * @return index of the predicted class
     */
    @Override
    public int classify(IRDoc curDoc) {
        int[] arrIndex = this.termIndexesOf(curDoc);
        int[] arrFreq = this.termFrequenciesOf(curDoc);
        if (arrIndex == null) {
            // The index reader may report no term list for a document.
            arrIndex = new int[0];
            arrFreq = new int[0];
        }
        IntRow row = new IntRow(0, arrIndex.length, arrIndex, arrFreq);
        int label = this.classify(row);
        curDoc.setWeight(this.lastClassProb.get(label));
        return label;
    }

    /**
     * Re-estimates the class priors, retrains the feature selector and rebuilds
     * the per-class term model from the given (possibly augmented) training set.
     *
     * <p>Each model cell ends up as {@code log((1 + count) / (F + total))},
     * where {@code count} is the selected term's frequency in the class,
     * {@code total} the class's summed selected-term frequency and {@code F}
     * the number of selected features (add-one smoothing).
     */
    private void eStep(DocClassSet trainingDocSet) {
        this.classPrior = this.getClassPrior(trainingDocSet);
        this.featureSelector.train(this.indexReader, trainingDocSet);
        this.model = new DoubleFlatDenseMatrix(trainingDocSet.getClassNum(), this.featureSelector.getSelectedFeatureNum());
        // Start every count at one (the smoothing pseudo-count).
        this.model.assign(1.0);
        for (int i = 0; i < trainingDocSet.getClassNum(); ++i) {
            // The denominator starts at F to balance the per-cell pseudo-counts.
            int classSum = this.featureSelector.getSelectedFeatureNum();
            DocClass cur = trainingDocSet.getDocClass(i);
            for (int j = 0; j < cur.getDocNum(); ++j) {
                IRDoc curDoc = cur.getDoc(j);
                int[] arrIndex = this.termIndexesOf(curDoc);
                int[] arrFreq = this.termFrequenciesOf(curDoc);
                if (arrIndex == null) {
                    // Document without terms contributes nothing to the counts.
                    continue;
                }
                for (int k = 0; k < arrIndex.length; ++k) {
                    int newTermIndex = this.featureSelector.map(arrIndex[k]);
                    if (newTermIndex < 0) {
                        continue;
                    }
                    classSum += arrFreq[k];
                    this.model.add(i, newTermIndex, arrFreq[k]);
                }
            }
            double rate = 1.0 / (double)classSum;
            for (int k = 0; k < this.model.columns(); ++k) {
                this.model.setDouble(i, k, Math.log(this.model.getDouble(i, k) * rate));
            }
        }
    }

    /** Tells whether the document's term vector is stored in {@link #externalUnlabeled}. */
    private boolean isExternal(IRDoc doc) {
        return doc.getKey().startsWith(EXTERNAL_KEY_PREFIX);
    }

    /** Returns the document's term ids, from the external cache or the index; may be {@code null}. */
    private int[] termIndexesOf(IRDoc doc) {
        if (this.isExternal(doc)) {
            return this.externalUnlabeled[doc.getIndex() - this.externalDocOffset].getNonZeroColumns();
        }
        return this.indexReader.getTermIndexList(doc.getIndex());
    }

    /** Returns the document's term frequencies, parallel to {@link #termIndexesOf(IRDoc)}. */
    private int[] termFrequenciesOf(IRDoc doc) {
        if (this.isExternal(doc)) {
            return this.externalUnlabeled[doc.getIndex() - this.externalDocOffset].getNonZeroIntScores();
        }
        return this.indexReader.getTermFrequencyList(doc.getIndex());
    }

    /**
     * Builds a lookup table from term ids in {@code src} to term ids in
     * {@code dest}, matching terms by key.
     *
     * @return array indexed by source term id; {@code -1} where {@code dest}
     *         has no such term
     */
    private int[] getTermMap(IndexReader src, IndexReader dest) {
        int[] termMap = new int[src.getCollection().getTermNum()];
        for (int i = 0; i < termMap.length; ++i) {
            IRTerm irTerm = dest.getIRTerm(src.getTermKey(i));
            termMap[i] = irTerm != null ? irTerm.getIndex() : -1;
        }
        return termMap;
    }
}

