/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.JointClassification;
import com.aliasi.classify.JointClassifier;
import com.aliasi.lm.LanguageModel;
import com.aliasi.stats.MultivariateDistribution;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

/**
 * A joint classifier over character sequences that keeps one language model
 * per category together with a distribution over the categories.
 *
 * <p>For an input text, each category {@code c} receives the joint log (base 2)
 * estimate {@code log2Estimate(text) + log2Probability(c)}. The first term comes
 * from the category's language model and the second from the category
 * distribution.
 *
 * @param <L> the type of the per-category language models
 * @param <M> the type of the distribution over categories
 */
public class LMClassifier<L extends LanguageModel, M extends MultivariateDistribution>
implements JointClassifier<CharSequence> {
    // Language models, index-aligned with mCategories (stored by reference).
    final L[] mLanguageModels;
    // Distribution over categories; numDimensions() equals mCategories.length.
    final M mCategoryDistribution;
    // Category name -> language model, built from the constructor arrays.
    final HashMap<String, L> mCategoryToModel;
    // Distinct category names, index-aligned with mLanguageModels (stored by reference).
    final String[] mCategories;

    /**
     * Constructs a classifier from parallel arrays of categories and language
     * models and a distribution over the categories.
     *
     * <p>The arrays are stored by reference, not copied.
     *
     * @param categories the category names; must be distinct and number at least two
     * @param languageModels the language model for each category, index-aligned
     *     with {@code categories}
     * @param categoryDistribution distribution over the categories; its
     *     {@code numDimensions()} must equal {@code categories.length}
     * @throws IllegalArgumentException if a category is duplicated, fewer than two
     *     categories are supplied, the distribution's dimension count differs from
     *     the number of categories, or the two arrays differ in length
     */
    public LMClassifier(String[] categories, L[] languageModels, M categoryDistribution) {
        HashSet<String> seen = new HashSet<String>();
        for (String category : categories) {
            if (!seen.add(category)) {
                String msg = "Duplicate category=" + category;
                throw new IllegalArgumentException(msg);
            }
        }
        if (categories.length < 2) {
            String msg = "Require at least two categories. Found categories.length=" + categories.length;
            throw new IllegalArgumentException(msg);
        }
        int numDimensions = categoryDistribution.numDimensions();
        if (categories.length != numDimensions) {
            String msg = "Require same number of categories as dimensions. Found categories.length=" + categories.length + " Found categoryDistribution.numDimensions()=" + numDimensions;
            throw new IllegalArgumentException(msg);
        }
        if (categories.length != languageModels.length) {
            String msg = "Categories and language models must be same length. Found categories length=" + categories.length + " Found language models length=" + languageModels.length;
            throw new IllegalArgumentException(msg);
        }
        this.mCategories = categories;
        this.mLanguageModels = languageModels;
        this.mCategoryDistribution = categoryDistribution;
        this.mCategoryToModel = new HashMap<String, L>();
        for (int i = 0; i < categories.length; ++i) {
            this.mCategoryToModel.put(categories[i], languageModels[i]);
        }
    }

    /**
     * Returns a copy of the category names, in constructor order.
     *
     * @return a fresh array of the categories; modifying it does not affect this classifier
     */
    public String[] categories() {
        return this.mCategories.clone();
    }

    /**
     * Returns the language model registered for the given category.
     *
     * @param category the category name
     * @return the language model at the same index as {@code category}
     * @throws IllegalArgumentException if {@code category} is not one of this
     *     classifier's categories
     */
    public L languageModel(String category) {
        // Linear scan over the arrays so the answer always agrees with the
        // models that classifyJoint uses.
        for (int i = 0; i < this.mCategories.length; ++i) {
            if (category.equals(this.mCategories[i])) {
                return this.mLanguageModels[i];
            }
        }
        String msg = "Category not known.  Category=" + category;
        throw new IllegalArgumentException(msg);
    }

    /**
     * Returns the distribution over categories supplied at construction.
     *
     * @return the category distribution (not a copy)
     */
    public M categoryDistribution() {
        return this.mCategoryDistribution;
    }

    /**
     * Classifies the given character sequence.
     *
     * @param cSeq the text to classify
     * @return the joint classification of {@code cSeq}
     * @throws IllegalArgumentException if {@code cSeq} is {@code null}
     */
    @Override
    public JointClassification classify(CharSequence cSeq) {
        // The parameter is statically a CharSequence, so null is the only value
        // an instanceof test could ever reject; report it as before.
        if (cSeq == null) {
            String msg = "LM Classification requires CharSequence input. Found class=null";
            throw new IllegalArgumentException(msg);
        }
        return this.classifyJoint(Strings.toCharArray(cSeq), 0, cSeq.length());
    }

    /**
     * Classifies the characters {@code cs[start..end)}.
     *
     * <p>Each category is scored by the sum of its language model's
     * {@code log2Estimate} of the text and the category distribution's
     * {@code log2Probability} of the category.
     *
     * @param cs the character buffer
     * @param start index of the first character (inclusive)
     * @param end index one past the last character (exclusive)
     * @return the joint classification of the slice
     */
    public JointClassification classifyJoint(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        // The text is the same for every category, so build it once.
        String text = new String(cs, start, end - start);
        int numCategories = this.mCategories.length;
        // Java forbids generic array creation. The cast is safe because only
        // ScoredObject<String> instances are ever stored in this array.
        @SuppressWarnings("unchecked")
        ScoredObject<String>[] estimates = (ScoredObject<String>[]) new ScoredObject<?>[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            String category = this.mCategories[i];
            double charsGivenCatLogProb = this.mLanguageModels[i].log2Estimate(text);
            double catLogProb = this.mCategoryDistribution.log2Probability(category);
            estimates[i] = new ScoredObject<String>(category, charsGivenCatLogProb + catLogProb);
        }
        // NOTE(review): the "+ 2" presumably accounts for begin/end boundary
        // symbols in the per-character normalization; confirm against the LM.
        return LMClassifier.toJointClassification(estimates, end - start + 2);
    }

    /**
     * Converts per-category joint estimates into a {@link JointClassification}.
     *
     * <p>The array is sorted in place with {@code ScoredObject.reverseComparator()}.
     * The resulting score for each category is its joint estimate divided by
     * {@code length}.
     *
     * @param estimates the per-category joint log estimates; reordered by this call
     * @param length the normalizer applied to each joint estimate
     * @return the classification built from the sorted categories, the normalized
     *     scores and the raw joint estimates
     */
    static JointClassification toJointClassification(ScoredObject<String>[] estimates, double length) {
        Arrays.sort(estimates, ScoredObject.reverseComparator());
        int n = estimates.length;
        String[] categories = new String[n];
        double[] jointEstimates = new double[n];
        double[] scores = new double[n];
        for (int i = 0; i < n; ++i) {
            categories[i] = estimates[i].getObject();
            jointEstimates[i] = estimates[i].score();
            scores[i] = jointEstimates[i] / length;
        }
        return new JointClassification(categories, scores, jointEstimates);
    }
}

