/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.NBClassifier;
import dragon.ir.classification.featureselection.NullFeatureSelector;
import dragon.ir.index.IRDoc;
import dragon.ir.index.IRTerm;
import dragon.ir.index.IndexReader;
import dragon.ir.kngbase.KnowledgeBase;
import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.matrix.DoubleSparseMatrix;
import dragon.util.MathUtil;

/**
 * Naive Bayes classifier whose per-class term distribution is a linear mixture of
 * <ul>
 *   <li>the class's own term frequencies (maximum-likelihood estimate),</li>
 *   <li>a collection-wide background model, weighted by the background coefficient, and</li>
 *   <li>optionally a topic-to-term translation model, weighted by the translation coefficient.</li>
 * </ul>
 * The natural logarithm of the mixture is stored in the inherited {@code model} matrix
 * (one row per class, one column per selected feature).
 */
public class SemanticNBClassifier extends NBClassifier {
    private IndexReader topicIndexReader;
    private DoubleSparseMatrix topicTransMatrix;
    private double transCoefficient;
    private double bkgCoefficient;
    /** Maps a topic index of {@code topicIndexReader} to a row of {@code topicTransMatrix}; negative means "no row". */
    private int[] topicMap;
    /** Maps a column of {@code topicTransMatrix} to a term index of {@code indexReader}; negative means "no term". */
    private int[] termMap;

    /**
     * Creates a classifier from a previously saved model file (delegates entirely to the superclass).
     *
     * @param modelFile path of the model file handed to {@link NBClassifier}
     */
    public SemanticNBClassifier(String modelFile) {
        super(modelFile);
    }

    /**
     * Creates a classifier that only smooths with the background model (no translation model).
     *
     * @param indexReader    index of the documents to train on
     * @param bkgCoefficient weight of the background model in the mixture
     */
    public SemanticNBClassifier(IndexReader indexReader, double bkgCoefficient) {
        super(indexReader);
        this.topicIndexReader = null;
        this.topicTransMatrix = null;
        this.transCoefficient = 0.0;
        this.bkgCoefficient = bkgCoefficient;
        this.featureSelector = new NullFeatureSelector();
    }

    /**
     * Creates a classifier using a translation matrix whose rows and columns are already
     * expressed in the index spaces of {@code topicIndexReader} and {@code indexReader}
     * (identity mappings are installed).
     *
     * @param indexReader      index of the documents to train on
     * @param topicIndexReader index of the same documents in topic space
     * @param topicTransMatrix topic (row) to term (column) translation scores
     * @param transCoefficient weight of the translation model in the mixture
     * @param bkgCoefficient   weight of the background model in the mixture
     */
    public SemanticNBClassifier(IndexReader indexReader, IndexReader topicIndexReader, DoubleSparseMatrix topicTransMatrix, double transCoefficient, double bkgCoefficient) {
        super(indexReader);
        this.featureSelector = new NullFeatureSelector();
        this.topicIndexReader = topicIndexReader;
        this.topicTransMatrix = topicTransMatrix;
        this.transCoefficient = transCoefficient;
        this.bkgCoefficient = bkgCoefficient;
        this.topicMap = identityMap(topicIndexReader.getCollection().getTermNum());
        this.termMap = identityMap(indexReader.getCollection().getTermNum());
    }

    /**
     * Creates a classifier using the knowledge matrix of a {@link KnowledgeBase}. Topics and
     * terms are matched to the knowledge base's rows and columns by key; entries that cannot
     * be matched are skipped when the translation model is computed.
     *
     * @param indexReader      index of the documents to train on
     * @param topicIndexReader index of the same documents in topic space
     * @param kngBase          knowledge base supplying the translation matrix and its key lists
     * @param transCoefficient weight of the translation model in the mixture
     * @param bkgCoefficient   weight of the background model in the mixture
     */
    public SemanticNBClassifier(IndexReader indexReader, IndexReader topicIndexReader, KnowledgeBase kngBase, double transCoefficient, double bkgCoefficient) {
        super(indexReader);
        this.featureSelector = new NullFeatureSelector();
        this.topicIndexReader = topicIndexReader;
        this.topicTransMatrix = kngBase.getKnowledgeMatrix();
        this.transCoefficient = transCoefficient;
        this.bkgCoefficient = bkgCoefficient;
        this.topicMap = new int[topicIndexReader.getCollection().getTermNum()];
        for (int i = 0; i < this.topicMap.length; ++i) {
            // Presumably negative when the key is unknown; negative rows are ignored later.
            this.topicMap[i] = kngBase.getRowKeyList().search(topicIndexReader.getTermKey(i));
        }
        this.termMap = new int[kngBase.getColumnKeyList().size()];
        for (int i = 0; i < this.termMap.length; ++i) {
            IRTerm curTerm = indexReader.getIRTerm(kngBase.getColumnKeyList().search(i));
            this.termMap[i] = curTerm == null ? -1 : curTerm.getIndex();
        }
    }

    /** @return the weight of the translation model in the mixture */
    public double getTranslationCoefficient() {
        return this.transCoefficient;
    }

    /** @param transCoefficient the weight of the translation model in the mixture */
    public void setTranslationCoefficient(double transCoefficient) {
        this.transCoefficient = transCoefficient;
    }

    /** @return the weight of the background model in the mixture */
    public double getBackgroundCoefficient() {
        return this.bkgCoefficient;
    }

    /** @param bkgCoefficient the weight of the background model in the mixture */
    public void setBackgroundCoefficient(double bkgCoefficient) {
        this.bkgCoefficient = bkgCoefficient;
    }

    /**
     * Trains the class priors and the per-class log term distributions.
     * <p>
     * Does nothing when no index reader is available: every step below reads from it, so
     * continuing would leave the classifier half-initialised before failing.
     *
     * @param trainingDocSet the labelled training documents
     */
    @Override
    public void train(DocClassSet trainingDocSet) {
        if (this.indexReader == null) {
            return;
        }
        this.classPrior = this.getClassPrior(trainingDocSet);
        this.featureSelector.train(this.indexReader, trainingDocSet);

        int classNum = trainingDocSet.getClassNum();
        this.arrLabel = new String[classNum];
        for (int i = 0; i < classNum; ++i) {
            this.arrLabel[i] = trainingDocSet.getDocClass(i).getClassName();
        }

        this.model = new DoubleFlatDenseMatrix(classNum, this.featureSelector.getSelectedFeatureNum());
        double[] bkgModel = this.getBackgroundModel(this.indexReader);

        for (int i = 0; i < classNum; ++i) {
            DocClass cur = trainingDocSet.getDocClass(i);
            long classSum = this.accumulateTermCounts(i, cur);
            int columns = this.model.columns();

            if (this.topicTransMatrix != null) {
                double[] transModel = this.computeTranslationModel(cur);
                // A class without any counted term contributes nothing of its own; using a zero
                // weight avoids 0 * Infinity = NaN and falls back to translation + background.
                double classWeight = classSum != 0
                        ? (1.0 - this.bkgCoefficient) * (1.0 - this.transCoefficient) / (double) classSum
                        : 0.0;
                double bkgWeight = (1.0 - this.transCoefficient) * this.bkgCoefficient;
                for (int k = 0; k < columns; ++k) {
                    this.model.setDouble(i, k, Math.log(transModel[k] * this.transCoefficient + this.model.getDouble(i, k) * classWeight + bkgModel[k] * bkgWeight));
                }
            } else {
                double classWeight = classSum != 0
                        ? (1.0 - this.bkgCoefficient) / (double) classSum
                        : 0.0;
                for (int k = 0; k < columns; ++k) {
                    this.model.setDouble(i, k, Math.log(this.model.getDouble(i, k) * classWeight + bkgModel[k] * this.bkgCoefficient));
                }
            }
        }
    }

    /** Returns the array {@code [0, 1, ..., size - 1]}. */
    private static int[] identityMap(int size) {
        int[] map = new int[size];
        for (int i = 0; i < size; ++i) {
            map[i] = i;
        }
        return map;
    }

    /**
     * Adds the raw term frequencies of every document in {@code docClass} to row
     * {@code classIndex} of the model and returns the total frequency that was added.
     * Terms rejected by the feature selector (negative mapping) are ignored.
     */
    private long accumulateTermCounts(int classIndex, DocClass docClass) {
        // long: the sum of frequencies over a large class can exceed the int range.
        long total = 0L;
        for (int j = 0; j < docClass.getDocNum(); ++j) {
            IRDoc curDoc = docClass.getDoc(j);
            int[] arrIndex = this.indexReader.getTermIndexList(curDoc.getIndex());
            int[] arrFreq = this.indexReader.getTermFrequencyList(curDoc.getIndex());
            if (arrIndex == null || arrFreq == null) {
                continue;
            }
            for (int k = 0; k < arrIndex.length; ++k) {
                int newTermIndex = this.featureSelector.map(arrIndex[k]);
                if (newTermIndex < 0) {
                    continue;
                }
                total += arrFreq[k];
                this.model.add(classIndex, newTermIndex, arrFreq[k]);
            }
        }
        return total;
    }

    /**
     * Builds the translation model of a class: the class's topic frequencies (restricted to
     * topics that have a non-empty row in the translation matrix) are normalised to rates,
     * and each topic spreads its rate over terms according to its translation scores.
     *
     * @param curClass the class whose documents are looked up in the topic index
     * @return one value per selected feature
     */
    private double[] computeTranslationModel(DocClass curClass) {
        int topicNum = this.topicIndexReader.getCollection().getTermNum();
        int termNum = this.indexReader.getCollection().getTermNum();
        int docNum = this.topicIndexReader.getCollection().getDocNum();

        // Topic frequencies over all documents of the class that exist in the topic index.
        int[] arrCount = new int[topicNum];
        for (int i = 0; i < curClass.getDocNum(); ++i) {
            IRDoc curDoc = curClass.getDoc(i);
            if (curDoc.getIndex() >= docNum) {
                continue;
            }
            int[] arrIndex = this.topicIndexReader.getTermIndexList(curDoc.getIndex());
            int[] arrFreq = this.topicIndexReader.getTermFrequencyList(curDoc.getIndex());
            if (arrIndex == null) {
                continue;
            }
            for (int j = 0; j < arrIndex.length; ++j) {
                arrCount[arrIndex[j]] += arrFreq[j];
            }
        }

        // Discard topics that cannot be translated, so they do not take part in the rate sum.
        for (int i = 0; i < this.topicMap.length; ++i) {
            int topicIndex = this.topicMap[i];
            boolean translatable = topicIndex >= 0
                    && topicIndex < this.topicTransMatrix.rows()
                    && this.topicTransMatrix.getNonZeroNumInRow(topicIndex) > 0;
            if (!translatable) {
                arrCount[i] = 0;
            }
        }

        double sum = MathUtil.sumArray(arrCount);
        double[] arrModel = new double[termNum];
        for (int i = 0; i < topicNum; ++i) {
            if (arrCount[i] <= 0) {
                continue;
            }
            int topicIndex = this.topicMap[i];
            double rate = (double) arrCount[i] / sum;
            int[] arrColumn = this.topicTransMatrix.getNonZeroColumnsInRow(topicIndex);
            double[] arrScore = this.topicTransMatrix.getNonZeroDoubleScoresInRow(topicIndex);
            for (int j = 0; j < arrColumn.length; ++j) {
                int column = arrColumn[j];
                // Columns outside the term mapping have no counterpart in the term index.
                if (column < 0 || column >= this.termMap.length) {
                    continue;
                }
                int termIndex = this.termMap[column];
                if (termIndex < 0) {
                    continue;
                }
                arrModel[termIndex] += rate * arrScore[j];
            }
        }

        int featureNum = this.featureSelector.getSelectedFeatureNum();
        if (arrModel.length == featureNum) {
            return arrModel;
        }

        // Project onto the selected features and renormalise over the retained mass.
        double[] arrSelectedModel = new double[featureNum];
        double selectedSum = 0.0;
        for (int i = 0; i < arrModel.length; ++i) {
            int featureIndex = this.featureSelector.map(i);
            if (featureIndex < 0) {
                continue;
            }
            selectedSum += arrModel[i];
            arrSelectedModel[featureIndex] = arrModel[i];
        }
        // With no retained mass every entry is already 0; dividing would turn them into NaN.
        if (selectedSum != 0.0) {
            for (int i = 0; i < arrSelectedModel.length; ++i) {
                arrSelectedModel[i] = arrSelectedModel[i] / selectedSum;
            }
        }
        return arrSelectedModel;
    }

    /**
     * Builds the background model: the collection frequency of every selected feature,
     * normalised by the total frequency of all selected features.
     *
     * @param reader index supplying the collection-wide term frequencies
     * @return one value per selected feature
     */
    private double[] getBackgroundModel(IndexReader reader) {
        int termNum = reader.getCollection().getTermNum();
        int featureNum = this.featureSelector.getSelectedFeatureNum();
        double[] arrModel = new double[featureNum];
        double sum = 0.0;
        for (int i = 0; i < termNum; ++i) {
            int newIndex = this.featureSelector.map(i);
            if (newIndex < 0) {
                continue;
            }
            arrModel[newIndex] = reader.getIRTerm(i).getFrequency();
            sum += arrModel[newIndex];
        }
        // A zero total would make every entry 0 / 0 = NaN; leave the zeros in place instead.
        if (sum != 0.0) {
            for (int i = 0; i < featureNum; ++i) {
                arrModel[i] = arrModel[i] / sum;
            }
        }
        return arrModel;
    }
}

