/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.topicmodel;

import dragon.ir.topicmodel.AbstractModel;
import dragon.matrix.IntSparseMatrix;
import java.util.Arrays;
import java.util.Random;

/**
 * Mixture model estimated jointly over several document collections.
 *
 * <p>Each term occurrence in document {@code i} of collection {@code n} is explained either by a
 * shared background distribution (weight {@code bkgCoefficient}) or by one of {@code themeNum}
 * themes. The theme's term distribution mixes a theme model shared by all collections (weight
 * {@code comCoefficient}) with a collection-specific theme model (weight
 * {@code 1 - comCoefficient}). The theme is selected with the document's membership weight.
 * Parameters are fitted by iterative posterior-and-renormalise (EM-style) updates in
 * {@link #estimateModel()}.
 */
public class CrossMixtureModel
extends AbstractModel {
    /** One document-by-term count matrix per collection. */
    protected IntSparseMatrix[] arrTopicReader;
    /** Background term distribution, pre-multiplied by {@link #bkgCoefficient}. */
    protected double[] bkgModel;
    /** Mixing weight of the background distribution. */
    protected double bkgCoefficient;
    /** Weight of the cross-collection (common) theme model inside each theme. */
    protected double comCoefficient;
    protected int themeNum;
    protected int collectionNum;
    /** Largest column count over all collection matrices. */
    protected int maxTermNum;
    /** Largest row count over all collection matrices. */
    protected int maxDocNum;
    /** Indexed [collection][theme][document]; null until {@link #estimateModel()} runs. */
    private double[][][] arrDocWeight;
    /** Collection-specific theme models, indexed [collection][theme][term]. */
    private double[][][] arrProb;
    /** Common theme models shared by all collections, indexed [theme][term]. */
    private double[][] arrCommonProb;

    /**
     * Creates a model over the given collections.
     *
     * @param arrTopicMatrix one document-by-term count matrix per collection; must be non-empty
     * @param themeNum number of themes to estimate
     * @param bkgModel background term probabilities; copied and scaled by {@code bkgCoefficient}
     * @param bkgCoefficient mixing weight of the background distribution
     * @param comCoefficient weight of the common theme model within each theme
     * @throws IllegalArgumentException if {@code arrTopicMatrix} is empty
     */
    public CrossMixtureModel(IntSparseMatrix[] arrTopicMatrix, int themeNum, double[] bkgModel, double bkgCoefficient, double comCoefficient) {
        if (arrTopicMatrix.length == 0) {
            throw new IllegalArgumentException("at least one collection matrix is required");
        }
        this.arrTopicReader = arrTopicMatrix;
        this.themeNum = themeNum;
        this.collectionNum = arrTopicMatrix.length;
        this.bkgCoefficient = bkgCoefficient;
        this.comCoefficient = comCoefficient;
        // Pre-scale once so the E-step can use bkgModel[t] directly as the background share.
        this.bkgModel = new double[bkgModel.length];
        for (int t = 0; t < bkgModel.length; ++t) {
            this.bkgModel[t] = bkgModel[t] * bkgCoefficient;
        }
        this.maxTermNum = arrTopicMatrix[0].columns();
        this.maxDocNum = arrTopicMatrix[0].rows();
        for (int c = 1; c < arrTopicMatrix.length; ++c) {
            this.maxTermNum = Math.max(this.maxTermNum, arrTopicMatrix[c].columns());
            this.maxDocNum = Math.max(this.maxDocNum, arrTopicMatrix[c].rows());
        }
    }

    /**
     * Returns the collection-specific theme models, indexed [collection][theme][term].
     * The internal array is returned, not a copy. It is null before {@link #estimateModel()} runs.
     */
    public double[][][] getModels() {
        return this.arrProb;
    }

    /**
     * Returns the common theme models, indexed [theme][term].
     * The internal array is returned, not a copy. It is null before {@link #estimateModel()} runs.
     */
    public double[][] getCommonModels() {
        return this.arrCommonProb;
    }

    /**
     * Returns the document-theme memberships, indexed [collection][theme][document].
     * The internal array is returned, not a copy. It is null before {@link #estimateModel()} runs.
     */
    public double[][][] getDocMemberships() {
        return this.arrDocWeight;
    }

    /**
     * Runs {@code iterations} rounds of estimation.
     *
     * <p>Each round has two phases. The first accumulates expected theme/term counts for every
     * document and updates that document's theme memberships. The second renormalises the
     * accumulated counts into the common and collection-specific theme models.
     *
     * @return always {@code true}
     */
    public boolean estimateModel() {
        this.arrProb = new double[this.collectionNum][this.themeNum][this.maxTermNum];
        this.arrCommonProb = new double[this.themeNum][this.maxTermNum];
        this.arrDocWeight = new double[this.collectionNum][this.themeNum][this.maxDocNum];
        double[][][] arrTempProb = new double[this.collectionNum][this.themeNum][this.maxTermNum];
        double[][] arrTempCommonProb = new double[this.themeNum][this.maxTermNum];
        double[] arrDocWeightSum = new double[this.themeNum];
        this.initialize(this.maxTermNum, this.collectionNum, this.themeNum, this.maxDocNum, this.arrCommonProb, this.arrProb, this.arrDocWeight);
        this.printStatus("Estimating the coefficients of simple mixture model...");
        for (int k = 0; k < this.iterations; ++k) {
            this.printStatus("Iteration #" + (k + 1));
            for (double[] row : arrTempCommonProb) {
                Arrays.fill(row, 0.0);
            }
            for (double[][] collection : arrTempProb) {
                for (double[] row : collection) {
                    Arrays.fill(row, 0.0);
                }
            }
            for (int n = 0; n < this.collectionNum; ++n) {
                int docNum = this.arrTopicReader[n].rows();
                for (int i = 0; i < docNum; ++i) {
                    this.accumulateDocument(n, i, arrTempProb[n], arrTempCommonProb, arrDocWeightSum);
                }
            }
            for (int m = 0; m < this.themeNum; ++m) {
                normalizeInto(arrTempCommonProb[m], this.arrCommonProb[m]);
                for (int n = 0; n < this.collectionNum; ++n) {
                    normalizeInto(arrTempProb[n][m], this.arrProb[n][m]);
                }
            }
        }
        this.printStatus("");
        return true;
    }

    /**
     * Processes one document in the first phase of an iteration.
     *
     * <p>It adds the document's expected per-theme term counts to {@code tempProb} and
     * {@code tempCommonProb}, then replaces the document's theme memberships with the
     * renormalised per-theme totals.
     *
     * @param col collection index
     * @param doc document (row) index within the collection
     * @param tempProb accumulator for this collection's specific theme models, [theme][term]
     * @param tempCommonProb accumulator for the common theme models, [theme][term]
     * @param docWeightSum scratch buffer of length {@code themeNum}; overwritten
     */
    private void accumulateDocument(int col, int doc, double[][] tempProb, double[][] tempCommonProb, double[] docWeightSum) {
        int[] arrIndex = this.arrTopicReader[col].getNonZeroColumnsInRow(doc);
        int[] arrFreq = this.arrTopicReader[col].getNonZeroIntScoresInRow(doc);
        double[][] specific = this.arrProb[col];
        double[][] membership = this.arrDocWeight[col];
        double specificWeight = 1.0 - this.comCoefficient;
        Arrays.fill(docWeightSum, 0.0);
        for (int j = 0; j < arrIndex.length; ++j) {
            // j is the position in the sparse row; termIndex is the actual term (column) id.
            int termIndex = arrIndex[j];
            double themeProbSum = 0.0;
            for (int m = 0; m < this.themeNum; ++m) {
                double mixed = this.comCoefficient * this.arrCommonProb[m][termIndex] + specificWeight * specific[m][termIndex];
                themeProbSum += mixed * membership[m][doc];
            }
            // Posterior that this occurrence came from the background. Guard 0/0 so NaN cannot
            // leak into the accumulators.
            double bkgDenominator = themeProbSum * (1.0 - this.bkgCoefficient) + this.bkgModel[termIndex];
            double bkgProb = bkgDenominator > 0.0 ? this.bkgModel[termIndex] / bkgDenominator : 0.0;
            for (int m = 0; m < this.themeNum; ++m) {
                double commonPart = this.comCoefficient * this.arrCommonProb[m][termIndex];
                double mixed = commonPart + specificWeight * specific[m][termIndex];
                double themeProb = themeProbSum == 0.0 ? 0.0 : mixed * membership[m][doc] / themeProbSum;
                // Share of this theme's probability that comes from the common model.
                double comThemeProb = mixed > 0.0 ? commonPart / mixed : 0.0;
                double termProb = arrFreq[j] * themeProb;
                // Membership update uses the count before the background discount.
                docWeightSum[m] += termProb;
                double themeCount = termProb * (1.0 - bkgProb);
                tempProb[m][termIndex] += themeCount * (1.0 - comThemeProb);
                tempCommonProb[m][termIndex] += themeCount * comThemeProb;
            }
        }
        double total = 0.0;
        for (double w : docWeightSum) {
            total += w;
        }
        for (int m = 0; m < this.themeNum; ++m) {
            membership[m][doc] = total > 0.0 ? docWeightSum[m] / total : 0.0;
        }
    }

    /**
     * Writes {@code src} scaled to sum to one into {@code dst}.
     * If {@code src} has no positive mass, {@code dst} is filled with zeros instead of NaN.
     */
    private static void normalizeInto(double[] src, double[] dst) {
        double sum = 0.0;
        for (double v : src) {
            sum += v;
        }
        for (int j = 0; j < src.length; ++j) {
            dst[j] = sum > 0.0 ? src[j] / sum : 0.0;
        }
    }

    /**
     * Sets the starting parameters.
     *
     * <p>Every common and collection-specific theme model starts uniform over
     * {@code maxTermNum} terms. Each document's theme memberships start as normalised random
     * draws. A fixed {@link Random} seed is used when {@code seed >= 0}.
     */
    protected void initialize(int maxTermNum, int collectionNum, int themeNum, int maxDocNum, double[][] arrCommonModel, double[][][] arrModel, double[][][] arrDocMembership) {
        double termProb = 1.0 / (double)maxTermNum;
        for (int i = 0; i < themeNum; ++i) {
            Arrays.fill(arrCommonModel[i], 0, maxTermNum, termProb);
        }
        for (int n = 0; n < collectionNum; ++n) {
            for (int i = 0; i < themeNum; ++i) {
                Arrays.fill(arrModel[n][i], 0, maxTermNum, termProb);
            }
        }
        Random random = this.seed >= 0 ? new Random(this.seed) : new Random();
        // Draw order (collection, document, theme) is kept so seeded runs stay reproducible.
        for (int n = 0; n < collectionNum; ++n) {
            double[][] membership = arrDocMembership[n];
            for (int j = 0; j < maxDocNum; ++j) {
                double docProb = 0.0;
                for (int i = 0; i < themeNum; ++i) {
                    membership[i][j] = random.nextDouble();
                    docProb += membership[i][j];
                }
                for (int i = 0; i < themeNum; ++i) {
                    membership[i][j] = membership[i][j] / docProb;
                }
            }
        }
    }
}

