/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.kngbase;

import dragon.ir.index.IRSignatureIndexList;
import dragon.matrix.DoubleSuperSparseMatrix;
import dragon.matrix.IntSparseMatrix;
import dragon.nlp.Counter;
import dragon.nlp.Token;
import dragon.util.MathUtil;
import java.io.File;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Random;

/**
 * Estimates, for each source signature, a probability distribution over destination
 * signatures ("topic signature translation").
 *
 * <p>Counts come from one of two sources:
 * <ul>
 *   <li>a source-signature/document matrix joined with a document/destination-signature
 *       matrix: every document containing the source signature contributes its destination
 *       signatures; or</li>
 *   <li>a precomputed source/destination co-occurrence matrix.</li>
 * </ul>
 *
 * <p>Low counts are trimmed (by the mean count and/or {@link #getProbThreshold()}), and the
 * surviving counts are normalized into weights. Optionally the weights are refined with an EM
 * procedure that mixes in a background model. That model is built from the collection
 * statistics in the destination index list.
 */
public class TopicSignatureModel {
    /** Default lower bound, as a fraction of the row total, for keeping a translation entry. */
    private static final double DEFAULT_PROB_THRESHOLD = 0.001;
    /** Default number of EM iterations. */
    private static final int DEFAULT_EM_ITERATIONS = 15;
    /** Default mixing weight of the background (collection) model in EM. */
    private static final double DEFAULT_BKG_COEFFI = 0.5;
    /** Co-occurrence rows with fewer non-zero cells than this are skipped by genTransMatrix. */
    private static final int MIN_COOCCUR_NONZERO = 5;
    /** After roughly this many buffered output cells both output matrices are flushed. */
    private static final int FLUSH_CELL_THRESHOLD = 5000000;
    /** A progress line is printed every this many source rows. */
    private static final int PROGRESS_INTERVAL = 1000;

    /** Source signatures; one translation row is produced per entry. */
    private IRSignatureIndexList srcIndexList;
    /** Destination signatures; only read for collection statistics in the EM step. */
    private IRSignatureIndexList destIndexList;
    /** Rows are source signatures; non-zero columns are row indices of destDocSignatureMatrix. */
    private IntSparseMatrix srcSignatureDocMatrix;
    /** Rows are documents; columns are destination signatures. */
    private IntSparseMatrix destDocSignatureMatrix;
    /** Rows are source signatures; columns are destination signatures; cells are counts. */
    private IntSparseMatrix cooccurMatrix;
    /** When true, each document contributes 1 per destination signature instead of its score. */
    private boolean useDocFrequency;
    private boolean useMeanTrim = true;
    private boolean useEM;
    private double probThreshold = DEFAULT_PROB_THRESHOLD;
    private double bkgCoeffi = DEFAULT_BKG_COEFFI;
    /** Dense scratch accumulator, one slot per destination signature. */
    private int[] buf;
    private int iterationNum = DEFAULT_EM_ITERATIONS;
    private int totalDestSignatureNum;
    /**
     * Source rows occurring in more documents than this use the dense array accumulator;
     * other rows use the hash-based one. It stays 0 until genTransMatrix computes it.
     */
    private int docThreshold;

    /**
     * Builds a document-based model without EM.
     * Each document contributes 1 per destination signature (doc-frequency counting).
     */
    public TopicSignatureModel(IRSignatureIndexList srcIndexList, IntSparseMatrix srcSignatureDocMatrix, IntSparseMatrix destDocSignatureMatrix) {
        this.srcIndexList = srcIndexList;
        this.srcSignatureDocMatrix = srcSignatureDocMatrix;
        this.destDocSignatureMatrix = destDocSignatureMatrix;
        this.useDocFrequency = true;
        this.useEM = false;
        this.totalDestSignatureNum = destDocSignatureMatrix.columns();
    }

    /** Builds a co-occurrence-matrix-based model without EM. */
    public TopicSignatureModel(IRSignatureIndexList srcIndexList, IntSparseMatrix cooccurMatrix) {
        this.srcIndexList = srcIndexList;
        this.cooccurMatrix = cooccurMatrix;
        this.useEM = false;
        this.totalDestSignatureNum = cooccurMatrix.columns();
    }

    /** Builds a co-occurrence-matrix-based model with EM enabled. */
    public TopicSignatureModel(IRSignatureIndexList srcIndexList, IRSignatureIndexList destIndexList, IntSparseMatrix cooccurMatrix) {
        this.srcIndexList = srcIndexList;
        this.destIndexList = destIndexList;
        this.cooccurMatrix = cooccurMatrix;
        this.useEM = true;
        this.totalDestSignatureNum = cooccurMatrix.columns();
    }

    /**
     * Builds a document-based model with EM enabled.
     * Each document contributes 1 per destination signature (doc-frequency counting).
     */
    public TopicSignatureModel(IRSignatureIndexList srcIndexList, IntSparseMatrix srcSignatureDocMatrix, IRSignatureIndexList destIndexList, IntSparseMatrix destDocSignatureMatrix) {
        this.srcIndexList = srcIndexList;
        this.srcSignatureDocMatrix = srcSignatureDocMatrix;
        this.destIndexList = destIndexList;
        this.destDocSignatureMatrix = destDocSignatureMatrix;
        this.useDocFrequency = true;
        this.useEM = true;
        this.totalDestSignatureNum = destDocSignatureMatrix.columns();
    }

    /** Enables or disables EM refinement. Enabling it requires a destination index list. */
    public void setUseEM(boolean option) {
        this.useEM = option;
    }

    public boolean getUseEM() {
        return this.useEM;
    }

    /** Sets the background-model mixing weight used by EM. */
    public void setEMBackgroundCoefficient(double coeffi) {
        this.bkgCoeffi = coeffi;
    }

    public double getEMBackgroundCoefficient() {
        return this.bkgCoeffi;
    }

    public void setEMIterationNum(int iterationNum) {
        this.iterationNum = iterationNum;
    }

    public int getEMIterationNum() {
        return this.iterationNum;
    }

    /**
     * Chooses document-frequency counting (true) or summing the per-document scores of
     * destDocSignatureMatrix (false).
     */
    public void setUseDocFrequency(boolean option) {
        this.useDocFrequency = option;
    }

    public boolean getUseDocFrequency() {
        return this.useDocFrequency;
    }

    /** When true, entries below the mean count of their row are trimmed. */
    public void setUseMeanTrim(boolean option) {
        this.useMeanTrim = option;
    }

    public boolean getUseMeanTrim() {
        return this.useMeanTrim;
    }

    /** Sets the minimum share of the row total an entry needs to survive trimming. */
    public void setProbThreshold(double threshold) {
        this.probThreshold = threshold;
    }

    public double getProbThreshold() {
        return this.probThreshold;
    }

    /**
     * Computes the translation distribution of every eligible source signature and writes it
     * to two on-disk matrices. The first is {@code matrixKey.index/.matrix}, with rows as source
     * and columns as destination signatures. The second is {@code matrixKey t.index/t.matrix},
     * its transpose. Existing files with those names are deleted first.
     *
     * @param minDocFrequency source signatures with a smaller document frequency are skipped
     * @param matrixPath directory receiving the output files
     * @param matrixKey base name of the output files
     * @return always {@code true}
     */
    public boolean genTransMatrix(int minDocFrequency, String matrixPath, String matrixKey) {
        String transIndexFile = matrixPath + "/" + matrixKey + ".index";
        String transMatrixFile = matrixPath + "/" + matrixKey + ".matrix";
        String transTIndexFile = matrixPath + "/" + matrixKey + "t.index";
        String transTMatrixFile = matrixPath + "/" + matrixKey + "t.matrix";
        deleteIfExists(transMatrixFile);
        deleteIfExists(transIndexFile);
        deleteIfExists(transTMatrixFile);
        deleteIfExists(transTIndexFile);

        DoubleSuperSparseMatrix outputTransMatrix = new DoubleSuperSparseMatrix(transIndexFile, transMatrixFile, false, false);
        outputTransMatrix.setFlushInterval(Integer.MAX_VALUE);
        DoubleSuperSparseMatrix outputTransTMatrix = new DoubleSuperSparseMatrix(transTIndexFile, transTMatrixFile, false, false);
        outputTransTMatrix.setFlushInterval(Integer.MAX_VALUE);

        int cellNum = 0;
        int rowNum = this.srcIndexList.size();
        this.buf = new int[this.totalDestSignatureNum];
        if (this.destDocSignatureMatrix != null) {
            this.docThreshold = this.computeDocThreshold(this.destDocSignatureMatrix);
        }
        for (int i = 0; i < rowNum; ++i) {
            if (i % PROGRESS_INTERVAL == 0) {
                System.out.println(new Date().toString() + " Processing Row#" + i);
            }
            if (!this.isTranslatable(i, minDocFrequency)) {
                continue;
            }
            ArrayList<Token> tokenList = this.genSignatureTranslation(i);
            for (Token curToken : tokenList) {
                outputTransMatrix.add(i, curToken.getIndex(), curToken.getWeight());
                outputTransTMatrix.add(curToken.getIndex(), i, curToken.getWeight());
            }
            // Count the cells before the list is discarded. The previous code cleared the list
            // first, so cellNum never grew and the periodic flush below never ran.
            cellNum += tokenList.size();
            if (cellNum >= FLUSH_CELL_THRESHOLD) {
                outputTransTMatrix.flush();
                outputTransMatrix.flush();
                cellNum = 0;
            }
        }
        outputTransTMatrix.finalizeData();
        outputTransTMatrix.close();
        outputTransMatrix.finalizeData();
        outputTransMatrix.close();
        return true;
    }

    /** Deletes the file at {@code path} if it exists; a failed delete is ignored. */
    private static void deleteIfExists(String path) {
        File file = new File(path);
        if (file.exists()) {
            file.delete();
        }
    }

    /**
     * Tells whether genTransMatrix should produce a translation row for a source signature.
     * The signature needs at least {@code minDocFrequency} documents. When a co-occurrence
     * matrix is used, its row also needs at least {@link #MIN_COOCCUR_NONZERO} non-zero cells.
     */
    private boolean isTranslatable(int srcSignatureIndex, int minDocFrequency) {
        if (this.srcIndexList.getIRSignature(srcSignatureIndex).getDocFrequency() < minDocFrequency) {
            return false;
        }
        return this.cooccurMatrix == null || this.cooccurMatrix.getNonZeroNumInRow(srcSignatureIndex) >= MIN_COOCCUR_NONZERO;
    }

    /**
     * Computes the translation distribution of one source signature.
     *
     * @param srcSignatureIndex row index of the source signature
     * @return tokens whose index is a destination signature and whose weight is its
     *     probability; the EM step is applied when enabled
     */
    public ArrayList<Token> genSignatureTranslation(int srcSignatureIndex) {
        ArrayList<Token> tokenList;
        if (this.srcSignatureDocMatrix != null) {
            int[] arrDoc = this.srcSignatureDocMatrix.getNonZeroColumnsInRow(srcSignatureIndex);
            // A dense array pays off for signatures that occur in many documents; for rare ones
            // a hash map avoids scanning the whole destination vocabulary.
            if (arrDoc.length > this.docThreshold) {
                tokenList = this.computeDistributionByArray(arrDoc);
            } else {
                tokenList = this.computeDistributionByHash(arrDoc);
            }
        } else {
            tokenList = this.computeDistributionByCooccurMatrix(srcSignatureIndex);
        }
        if (this.useEM) {
            tokenList = this.emTopicSignatureModel(tokenList);
        }
        return tokenList;
    }

    /**
     * Returns columns / (sampled average non-zero cells per row) / 8.
     * The average is estimated from random rows, so the result can differ between calls.
     */
    private int computeDocThreshold(IntSparseMatrix doctermMatrix) {
        return (int) ((double) doctermMatrix.columns() / this.computeAvgTermNum(doctermMatrix) / 8.0);
    }

    /** Averages the non-zero count of up to 50 randomly chosen rows (with replacement). */
    private double computeAvgTermNum(IntSparseMatrix doctermMatrix) {
        Random random = new Random();
        int num = Math.min(50, doctermMatrix.rows());
        double sum = 0.0;
        for (int i = 0; i < num; ++i) {
            int index = random.nextInt(doctermMatrix.rows());
            sum += (double) doctermMatrix.getNonZeroNumInRow(index);
        }
        return sum / (double) num;
    }

    /**
     * Returns the minimum count an entry needs to survive trimming.
     * With mean trimming this is the mean over {@code nonZeroNum} candidates. Otherwise it is
     * 0.5, which keeps every count of at least 1. Either way it is raised to at least
     * {@code rowTotal * probThreshold}.
     */
    private double computeTrimThreshold(double rowTotal, int nonZeroNum) {
        double mean = this.useMeanTrim ? rowTotal / (double) nonZeroNum : 0.5;
        return Math.max(mean, rowTotal * this.getMinInitProb());
    }

    /** Sets each token's weight to its frequency divided by {@code total}. */
    private static void normalizeWeights(ArrayList<Token> list, double total) {
        for (Token curToken : list) {
            curToken.setWeight((double) curToken.getFrequency() / total);
        }
    }

    /** Builds the trimmed, normalized distribution directly from one co-occurrence row. */
    private ArrayList<Token> computeDistributionByCooccurMatrix(int signatureIndex) {
        int[] arrIndex = this.cooccurMatrix.getNonZeroColumnsInRow(signatureIndex);
        int[] arrFreq = this.cooccurMatrix.getNonZeroIntScoresInRow(signatureIndex);
        double rowTotal = 0.0;
        for (int freq : arrFreq) {
            rowTotal += (double) freq;
        }
        double threshold = this.computeTrimThreshold(rowTotal, arrFreq.length);
        ArrayList<Token> list = new ArrayList<Token>();
        double keptTotal = 0.0;
        for (int i = 0; i < arrFreq.length; ++i) {
            if ((double) arrFreq[i] >= threshold) {
                list.add(new Token(arrIndex[i], arrFreq[i]));
                keptTotal += (double) arrFreq[i];
            }
        }
        normalizeWeights(list, keptTotal);
        return list;
    }

    /**
     * Aggregates the destination signatures of the given documents into the dense buffer.
     * It then returns the trimmed, normalized distribution in destination-index order.
     */
    private ArrayList<Token> computeDistributionByArray(int[] arrDoc) {
        if (this.buf == null) {
            this.buf = new int[this.totalDestSignatureNum];
        }
        MathUtil.initArray(this.buf, 0);
        for (int doc : arrDoc) {
            int[] arrIndex = this.destDocSignatureMatrix.getNonZeroColumnsInRow(doc);
            if (this.useDocFrequency) {
                for (int destIndex : arrIndex) {
                    this.buf[destIndex] += 1;
                }
            } else {
                int[] arrFreq = this.destDocSignatureMatrix.getNonZeroIntScoresInRow(doc);
                for (int k = 0; k < arrIndex.length; ++k) {
                    this.buf[arrIndex[k]] += arrFreq[k];
                }
            }
        }
        int nonZeroNum = 0;
        double rowTotal = 0.0;
        for (int count : this.buf) {
            if (count > 0) {
                ++nonZeroNum;
                rowTotal += (double) count;
            }
        }
        double threshold = this.computeTrimThreshold(rowTotal, nonZeroNum);
        ArrayList<Token> list = new ArrayList<Token>();
        double keptTotal = 0.0;
        for (int i = 0; i < this.buf.length; ++i) {
            if ((double) this.buf[i] >= threshold) {
                list.add(new Token(i, this.buf[i]));
                keptTotal += (double) this.buf[i];
            }
        }
        normalizeWeights(list, keptTotal);
        return list;
    }

    /** Same as computeDistributionByArray but aggregates through a hash map (for rare sources). */
    private ArrayList<Token> computeDistributionByHash(int[] arrDoc) {
        ArrayList<Token> tokenList = this.countTokensByHashMap(arrDoc);
        double rowTotal = 0.0;
        for (Token curToken : tokenList) {
            rowTotal += (double) curToken.getFrequency();
        }
        ArrayList<Token> list;
        if (this.useMeanTrim || rowTotal * this.getMinInitProb() > 1.0) {
            double threshold = this.computeTrimThreshold(rowTotal, tokenList.size());
            list = new ArrayList<Token>();
            rowTotal = 0.0;
            for (Token curToken : tokenList) {
                if ((double) curToken.getFrequency() >= threshold) {
                    list.add(curToken);
                    rowTotal += (double) curToken.getFrequency();
                }
            }
        } else {
            // The threshold would be at most 1. Aggregated counts are presumably >= 1, so
            // nothing would be trimmed; keep the whole list.
            list = tokenList;
        }
        normalizeWeights(list, rowTotal);
        return list;
    }

    /**
     * Sums, per destination signature, the contribution of each given document.
     * A document adds 1 under doc-frequency counting and its score otherwise. Each returned
     * token's frequency holds the aggregated count.
     */
    private ArrayList<Token> countTokensByHashMap(int[] arrDoc) {
        HashMap<Token, Counter> hash = new HashMap<Token, Counter>();
        for (int doc : arrDoc) {
            int termNum = this.destDocSignatureMatrix.getNonZeroNumInRow(doc);
            if (termNum == 0) {
                continue;
            }
            int[] arrTerm = this.destDocSignatureMatrix.getNonZeroColumnsInRow(doc);
            int[] arrFreq = this.useDocFrequency ? null : this.destDocSignatureMatrix.getNonZeroIntScoresInRow(doc);
            for (int i = 0; i < termNum; ++i) {
                Token curToken = this.useDocFrequency ? new Token(arrTerm[i], 1) : new Token(arrTerm[i], arrFreq[i]);
                Counter counter = hash.get(curToken);
                if (counter == null) {
                    hash.put(curToken, new Counter(curToken.getFrequency()));
                } else {
                    counter.addCount(curToken.getFrequency());
                }
            }
        }
        ArrayList<Token> list = new ArrayList<Token>(hash.size());
        for (Token curToken : hash.keySet()) {
            curToken.setFrequency(hash.get(curToken).getCount());
            list.add(curToken);
        }
        hash.clear();
        return list;
    }

    private double getMinInitProb() {
        return this.probThreshold;
    }

    /**
     * Refines token weights by EM against a background model. The background model is the
     * destination signatures' collection frequency (or doc frequency), normalized over the
     * tokens in {@code list}. Each iteration sets
     * w_j proportional to f_j * (1-λ)w_j / ((1-λ)w_j + λ c_j), where λ is the background
     * coefficient. Weights are updated in place.
     *
     * @throws IllegalStateException if no destination index list was supplied
     */
    private ArrayList<Token> emTopicSignatureModel(ArrayList<Token> list) {
        if (this.destIndexList == null) {
            throw new IllegalStateException("EM requires a destination IRSignatureIndexList; construct the model with one or call setUseEM(false)");
        }
        int termNum = list.size();
        double[] arrProb = new double[termNum];
        double[] arrCollectionProb = new double[termNum];
        double weightSum = 0.0;
        for (int i = 0; i < termNum; ++i) {
            int destIndex = list.get(i).getIndex();
            arrCollectionProb[i] = this.useDocFrequency
                    ? (double) this.destIndexList.getIRSignature(destIndex).getDocFrequency()
                    : (double) this.destIndexList.getIRSignature(destIndex).getFrequency();
            weightSum += arrCollectionProb[i];
        }
        if (!(weightSum > 0.0)) {
            // Empty list or no collection statistics: normalizing would yield NaN weights,
            // so keep the unrefined distribution.
            return list;
        }
        for (int i = 0; i < termNum; ++i) {
            arrCollectionProb[i] = arrCollectionProb[i] / weightSum;
        }
        for (int iter = 0; iter < this.iterationNum; ++iter) {
            weightSum = 0.0;
            for (int j = 0; j < termNum; ++j) {
                Token curToken = list.get(j);
                double topicPart = (1.0 - this.bkgCoeffi) * curToken.getWeight();
                arrProb[j] = topicPart / (topicPart + this.bkgCoeffi * arrCollectionProb[j]) * (double) curToken.getFrequency();
                weightSum += arrProb[j];
            }
            for (int j = 0; j < termNum; ++j) {
                list.get(j).setWeight(arrProb[j] / weightSum);
            }
        }
        return list;
    }
}

