/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.search.smooth;

import dragon.ir.index.IRDoc;
import dragon.ir.index.IRTerm;
import dragon.ir.index.IndexReader;
import dragon.ir.query.Predicate;
import dragon.ir.query.RelSimpleQuery;
import dragon.ir.query.SimpleTermPredicate;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * Base class for estimating the mixing coefficients of a multi-component model with EM, using the
 * terms of a query and the documents of an index.
 *
 * <p>Subclasses supply the initial coefficients, the documents to iterate over and the
 * per-component value of every (document, query term) pair. Each iteration does three things:
 * <ul>
 *   <li>multiplies those values by the current coefficients;</li>
 *   <li>re-estimates the coefficients from the resulting responsibilities, weighted by document
 *       weight and divided by the number of query terms;</li>
 *   <li>re-weights every document in proportion to its previous weight times the product of its
 *       per-term mixture values.</li>
 * </ul>
 * The documents' weights are modified in place via {@code setWeight}.
 */
public abstract class AbstractMixtureWeightEM {
    protected IndexReader indexReader;
    protected int iterationNum;
    protected int componentNum;
    private PrintWriter statusOut;
    private final boolean docFirst;

    /**
     * @param indexReader  index used to resolve query terms, postings and per-document terms
     * @param componentNum number of mixture components (length of every coefficient array)
     * @param iterationNum number of EM iterations to run
     * @param docFirst     if true, iterate document by document with all query terms inside;
     *                     otherwise iterate query term by query term over each posting list
     */
    public AbstractMixtureWeightEM(IndexReader indexReader, int componentNum, int iterationNum, boolean docFirst) {
        this.indexReader = indexReader;
        this.iterationNum = iterationNum;
        this.componentNum = componentNum;
        this.docFirst = docFirst;
    }

    /**
     * Fills {@code initialParam} with the starting coefficients and {@code docs} with the documents
     * to iterate over. Every element of {@code docs} is dereferenced by the estimators, so all of
     * them must be set.
     */
    protected abstract void setInitialParameters(double[] initialParam, IRDoc[] docs);

    /** Called once per estimation, after the query terms have been resolved against the index. */
    protected abstract void init(RelSimpleQuery query);

    /** Document-first mode: announces the document whose terms are about to be evaluated. */
    protected abstract void setDoc(IRDoc doc);

    /** Term-first mode: announces the query term whose documents are about to be evaluated. */
    protected abstract void setQueryTerm(SimpleTermPredicate term);

    /**
     * Document-first mode: writes one value per component into {@code comp} for {@code term},
     * which occurs {@code freq} times in the current document. The caller multiplies these values
     * by the current coefficients.
     */
    protected abstract void getComponentValue(SimpleTermPredicate term, int freq, double[] comp);

    /**
     * Term-first mode: writes one value per component into {@code comp} for {@code doc}, in which
     * the current query term occurs {@code freq} times (0 if the document is not in its postings).
     */
    protected abstract void getComponentValue(IRDoc doc, int freq, double[] comp);

    /** Sets an optional writer that receives a copy of every status line (may be null). */
    public void setStatusOut(PrintWriter out) {
        this.statusOut = out;
    }

    /**
     * Runs EM for {@code query} and returns the estimated coefficients, one per component.
     *
     * <p>If none of the query's term predicates occurs in the index, there is nothing to fit. In
     * that case the coefficients produced by {@link #setInitialParameters} are returned unchanged.
     */
    public double[] estimateModelCoefficient(RelSimpleQuery query) {
        return this.docFirst ? this.breadthFirstEstimate(query) : this.depthFirstEstimate(query);
    }

    /** Document-first EM: the outer loop runs over documents, the inner loop over query terms. */
    private double[] breadthFirstEstimate(RelSimpleQuery query) {
        SimpleTermPredicate[] arrPredicate = this.checkSimpleTermQuery(query);
        this.init(query);
        double[] arrPreParam = new double[this.componentNum];
        double[] arrParam = new double[this.componentNum];
        double[] arrParamDocSum = new double[this.componentNum];
        double[] arrComp = new double[this.componentNum];
        int termNum = arrPredicate.length;
        int docNum = this.getDocNum();
        double[] arrLogDocWeight = new double[docNum];
        IRDoc[] arrDoc = new IRDoc[docNum];
        this.setInitialParameters(arrPreParam, arrDoc);
        if (termNum == 0) {
            // No evidence to fit; the division by termNum below would make every coefficient NaN.
            this.printStatus("No query term found in the index; keeping the initial coefficients.");
            return arrPreParam;
        }
        this.printStatus("Estimating the coefficients of the mixed model...");
        for (int iteration = 0; iteration < this.iterationNum; ++iteration) {
            this.printStatus("Iteration #" + (iteration + 1));
            Arrays.fill(arrParam, 0.0);
            for (int i = 0; i < docNum; ++i) {
                IRDoc doc = arrDoc[i];
                // Accumulated in log space: the plain product underflows to 0.0 for long queries.
                double logDocWeight = safeLog(doc.getWeight());
                Arrays.fill(arrParamDocSum, 0.0);
                this.setDoc(doc);
                for (int j = 0; j < termNum; ++j) {
                    // NOTE(review): i is passed as the document index, which assumes
                    // setInitialParameters stored the document with index i at arrDoc[i].
                    IRTerm docTerm = this.indexReader.getIRTerm(arrPredicate[j].getIndex(), i);
                    this.getComponentValue(arrPredicate[j], docTerm.getFrequency(), arrComp);
                    double termProb = this.weightComponents(arrPreParam, arrComp);
                    logDocWeight += safeLog(termProb);
                    this.addResponsibilities(arrComp, termProb, 1.0, arrParamDocSum);
                }
                for (int m = 0; m < this.componentNum; ++m) {
                    arrParam[m] += doc.getWeight() * arrParamDocSum[m];
                }
                arrLogDocWeight[i] = logDocWeight;
            }
            this.updateCoefficients(arrPreParam, arrParam, termNum);
            this.normalizeDocWeights(arrDoc, arrLogDocWeight);
        }
        this.printStatus("");
        return arrPreParam;
    }

    /**
     * Term-first EM: the outer loop runs over query terms. Each term's posting list is merged with
     * the full document range, so every document is visited once per term (frequency 0 for
     * documents that are not in the postings).
     */
    private double[] depthFirstEstimate(RelSimpleQuery query) {
        SimpleTermPredicate[] arrPredicate = this.checkSimpleTermQuery(query);
        this.init(query);
        double[] arrPreParam = new double[this.componentNum];
        double[] arrParam = new double[this.componentNum];
        double[] arrComp = new double[this.componentNum];
        int termNum = arrPredicate.length;
        int docNum = this.getDocNum();
        double[] arrLogDocWeight = new double[docNum];
        IRDoc[] arrDoc = new IRDoc[docNum];
        this.setInitialParameters(arrPreParam, arrDoc);
        if (termNum == 0) {
            // No evidence to fit; the division by termNum below would make every coefficient NaN.
            this.printStatus("No query term found in the index; keeping the initial coefficients.");
            return arrPreParam;
        }
        this.printStatus("Estimating the coefficients of the mixed model...");
        for (int iteration = 0; iteration < this.iterationNum; ++iteration) {
            this.printStatus("Iteration #" + (iteration + 1));
            for (int i = 0; i < docNum; ++i) {
                arrLogDocWeight[i] = safeLog(arrDoc[i].getWeight());
            }
            Arrays.fill(arrParam, 0.0);
            for (SimpleTermPredicate predicate : arrPredicate) {
                this.setQueryTerm(predicate);
                int[] arrIndex = this.indexReader.getTermDocIndexList(predicate.getIndex());
                int[] arrFreq = this.indexReader.getTermDocFrequencyList(predicate.getIndex());
                // NOTE(review): the merge assumes arrIndex is sorted by ascending document index
                // and that arrDoc[k] is the document with index k.
                int k = 0;
                for (int j = 0; j < arrIndex.length; ++j) {
                    for (; k < arrIndex[j]; ++k) {
                        this.observeDoc(arrDoc[k], k, 0, arrPreParam, arrComp, arrLogDocWeight, arrParam);
                    }
                    this.observeDoc(arrDoc[k], k, arrFreq[j], arrPreParam, arrComp, arrLogDocWeight, arrParam);
                    ++k;
                }
                for (; k < docNum; ++k) {
                    this.observeDoc(arrDoc[k], k, 0, arrPreParam, arrComp, arrLogDocWeight, arrParam);
                }
            }
            this.updateCoefficients(arrPreParam, arrParam, termNum);
            this.normalizeDocWeights(arrDoc, arrLogDocWeight);
        }
        this.printStatus("");
        return arrPreParam;
    }

    /**
     * Handles one (document, current query term) observation in term-first mode. It fetches the
     * component values, weights them by the current coefficients, and folds the resulting term
     * value into the document's log weight. It then adds the document-weighted responsibilities
     * to {@code acc}.
     */
    private void observeDoc(IRDoc doc, int docIndex, int freq, double[] coefficients, double[] comp,
            double[] logDocWeight, double[] acc) {
        this.getComponentValue(doc, freq, comp);
        double termProb = this.weightComponents(coefficients, comp);
        logDocWeight[docIndex] += safeLog(termProb);
        this.addResponsibilities(comp, termProb, doc.getWeight(), acc);
    }

    /** Multiplies {@code comp[m]} by {@code coefficients[m]} in place and returns their sum. */
    private double weightComponents(double[] coefficients, double[] comp) {
        double total = 0.0;
        for (int m = 0; m < this.componentNum; ++m) {
            comp[m] = coefficients[m] * comp[m];
            total += comp[m];
        }
        return total;
    }

    /**
     * Adds {@code scale * comp[m] / total} to {@code acc[m]} for every component. The update is
     * skipped when {@code total} is not positive: 0/0 would otherwise put NaN into {@code acc}
     * and, from there, into every coefficient.
     */
    private void addResponsibilities(double[] comp, double total, double scale, double[] acc) {
        if (!(total > 0.0)) {
            return;
        }
        for (int m = 0; m < this.componentNum; ++m) {
            acc[m] += scale * comp[m] / total;
        }
    }

    /** Sets {@code coefficients[m] = accumulated[m] / termNum} and reports each new value. */
    private void updateCoefficients(double[] coefficients, double[] accumulated, int termNum) {
        for (int m = 0; m < this.componentNum; ++m) {
            coefficients[m] = accumulated[m] / (double) termNum;
            this.printStatus("Coefficient #" + (m + 1) + " " + coefficients[m]);
        }
    }

    /**
     * Sets each document's weight to {@code exp(logWeights[i])}, normalised so the weights sum to
     * one. The maximum is subtracted before exponentiating, so the largest term is exactly 1 and
     * the sum can neither underflow to 0 nor overflow.
     *
     * <p>If no document has a finite positive likelihood, the current weights are kept. Without
     * this check the result would be 0/0 = NaN. {@code logWeights} is overwritten as scratch.
     */
    private void normalizeDocWeights(IRDoc[] docs, double[] logWeights) {
        double max = Double.NEGATIVE_INFINITY;
        for (double w : logWeights) {
            if (w > max) {
                max = w;
            }
        }
        if (Double.isInfinite(max)) {
            if (docs.length > 0) {
                this.printStatus("Degenerate document likelihoods; keeping the previous document weights.");
            }
            return;
        }
        double sum = 0.0;
        for (int i = 0; i < logWeights.length; ++i) {
            logWeights[i] = Math.exp(logWeights[i] - max);
            sum += logWeights[i];
        }
        for (int i = 0; i < docs.length; ++i) {
            docs[i].setWeight(logWeights[i] / sum);
        }
    }

    /** Natural log that maps non-positive (and NaN) inputs to negative infinity. */
    private static double safeLog(double x) {
        return x > 0.0 ? Math.log(x) : Double.NEGATIVE_INFINITY;
    }

    /** Number of documents in the indexed collection; subclasses may override. */
    protected int getDocNum() {
        return this.indexReader.getCollection().getDocNum();
    }

    /** Looks up a document by its sequence number in the index. */
    protected IRDoc getDoc(int seq) {
        return this.indexReader.getDoc(seq);
    }

    /**
     * Prints {@code line} to standard output and, if a status writer is set, writes and flushes
     * it there too. Best effort: any runtime failure is reported but not propagated.
     */
    private void printStatus(String line) {
        try {
            System.out.println(line);
            if (this.statusOut != null) {
                this.statusOut.write(line + "\n");
                this.statusOut.flush();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Collects the term predicates of {@code query} that occur in the index and returns copies
     * of them.
     *
     * <p>A predicate whose document frequency is still 0 is filled in from the index: document
     * frequency, frequency and term index are set. These setters are called on the query's own
     * predicate objects, so the query is modified. Predicates whose document frequency is still
     * not positive afterwards are dropped.
     */
    private SimpleTermPredicate[] checkSimpleTermQuery(RelSimpleQuery query) {
        ArrayList<SimpleTermPredicate> usable = new ArrayList<SimpleTermPredicate>();
        for (int i = 0; i < query.getChildNum(); ++i) {
            if (!((Predicate) query.getChild(i)).isTermPredicate()) {
                continue;
            }
            SimpleTermPredicate predicate = (SimpleTermPredicate) query.getChild(i);
            if (predicate.getDocFrequency() == 0) {
                IRTerm indexedTerm = this.indexReader.getIRTerm(predicate.getKey());
                if (indexedTerm != null) {
                    predicate.setDocFrequency(indexedTerm.getDocFrequency());
                    predicate.setFrequency(indexedTerm.getFrequency());
                    predicate.setIndex(indexedTerm.getIndex());
                }
            }
            if (predicate.getDocFrequency() > 0) {
                usable.add(predicate.copy());
            }
        }
        return usable.toArray(new SimpleTermPredicate[usable.size()]);
    }
}

