/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Multinomial;
import cc.mallet.types.TokenSequence;
import cc.mallet.util.Randoms;
import gnu.trove.TIntHashSet;
import gnu.trove.TIntIntHashMap;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;

public class Dirichlet {
    /** Alphabet naming the dimensions, if any; constructors that receive one call {@code stopGrowth()} on it. */
    Alphabet dict;
    /** Sum of the Dirichlet parameters; parameter k is {@code magnitude * partition[k]}. */
    double magnitude = 1.0;
    /** The Dirichlet parameters divided by {@link #magnitude}. */
    double[] partition;
    /** Random number source, created lazily by {@code initRandom()}. */
    Randoms random = null;
    /**
     * Note the sign: this is the NEGATED Euler-Mascheroni constant, i.e. digamma(1) = -0.5772...
     */
    public static final double EULER_MASCHERONI = -0.5772156649015329;
    /** pi^2 / 6, which equals trigamma(1). */
    public static final double PI_SQUARED_OVER_SIX = 1.6449340668482264;
    /** log(2 pi) / 2, the constant term of Stirling's series for log Gamma. */
    public static final double HALF_LOG_TWO_PI = Math.log(Math.PI * 2) / 2.0;
    // Magnitudes of the coefficients B_{2k} / (2k) of the asymptotic digamma expansion
    // (B_{2k} = Bernoulli numbers). They must be written with floating-point literals:
    // integer division such as 1/12 truncates to 0 and silently drops the correction terms.
    public static final double DIGAMMA_COEF_1 = 1.0 / 12.0;
    public static final double DIGAMMA_COEF_2 = 1.0 / 120.0;
    public static final double DIGAMMA_COEF_3 = 1.0 / 252.0;
    public static final double DIGAMMA_COEF_4 = 1.0 / 240.0;
    public static final double DIGAMMA_COEF_5 = 1.0 / 132.0;
    public static final double DIGAMMA_COEF_6 = 691.0 / 32760.0;
    public static final double DIGAMMA_COEF_7 = 1.0 / 12.0;
    public static final double DIGAMMA_COEF_8 = 3617.0 / 8160.0;
    public static final double DIGAMMA_COEF_9 = 43867.0 / 14364.0;
    public static final double DIGAMMA_COEF_10 = 174611.0 / 6600.0;
    /** digamma shifts its argument upward until it reaches at least this value. */
    public static final double DIGAMMA_LARGE = 9.5;
    /** Below this argument digamma uses the approximation EULER_MASCHERONI - 1/z. */
    public static final double DIGAMMA_SMALL = 1.0E-6;

    /**
     * Creates a Dirichlet from an explicit magnitude and an already-normalized partition.
     *
     * @param m the magnitude (sum of the parameters)
     * @param p the partition; the array is stored by reference, not copied
     */
    public Dirichlet(double m, double[] p) {
        this.partition = p;
        this.magnitude = m;
    }

    /**
     * Creates a Dirichlet from unnormalized parameters. The magnitude becomes their sum, and the
     * partition holds each parameter divided by that sum. The input array is not modified.
     *
     * @param p the Dirichlet parameters
     */
    public Dirichlet(double[] p) {
        this.magnitude = 0.0;
        double[] normalized = new double[p.length];
        this.partition = normalized;
        for (double value : p) {
            this.magnitude += value;
        }
        final double total = this.magnitude;
        for (int idx = 0; idx < normalized.length; idx++) {
            normalized[idx] = p[idx] / total;
        }
    }

    public Dirichlet(double[] alphas, Alphabet dict) {
        this(alphas);
        if (dict != null && alphas.length != dict.size()) {
            throw new IllegalArgumentException("alphas and dict sizes do not match.");
        }
        this.dict = dict;
        if (dict != null) {
            dict.stopGrowth();
        }
    }

    /**
     * Creates a symmetric Dirichlet over the alphabet's dimensions with every parameter 1.0.
     *
     * @param dict alphabet naming the dimensions; its growth is stopped
     */
    public Dirichlet(Alphabet dict) {
        this(dict, 1.0);
    }

    /**
     * Creates a symmetric Dirichlet over the alphabet's dimensions with every parameter alpha.
     *
     * @param dict alphabet naming the dimensions; its growth is stopped
     * @param alpha value of each individual parameter
     */
    public Dirichlet(Alphabet dict, double alpha) {
        this(dict.size(), alpha);
        dict.stopGrowth();
        this.dict = dict;
    }

    /**
     * Creates a symmetric Dirichlet of the given dimension with every parameter 1.0.
     *
     * @param size number of dimensions
     */
    public Dirichlet(int size) {
        this(size, 1.0);
    }

    /**
     * Creates a symmetric Dirichlet in which every parameter equals alpha. The magnitude is
     * {@code size * alpha}, and every partition entry is {@code 1/size}.
     *
     * @param size number of dimensions; 0 yields an empty partition
     * @param alpha value of each individual parameter
     */
    public Dirichlet(int size, double alpha) {
        this.magnitude = (double) size * alpha;
        this.partition = new double[size];
        // Arrays.fill is a no-op for size 0, whereas writing partition[0] would throw.
        Arrays.fill(this.partition, 1.0 / (double) size);
    }

    /** Creates the random number generator on first use; later calls leave it untouched. */
    private void initRandom() {
        if (this.random != null) {
            return;
        }
        this.random = new Randoms();
    }

    /**
     * Samples one multinomial distribution from this Dirichlet. It draws a Gamma(alpha_k, 1)
     * variate per dimension and normalizes the draws so they sum to one.
     *
     * @return a new array of length {@code partition.length}
     */
    public double[] nextDistribution() {
        double[] sample = new double[this.partition.length];
        this.initRandom();
        double total = 0.0;
        for (int k = 0; k < sample.length; k++) {
            double draw = this.random.nextGamma(this.partition[k] * this.magnitude, 1.0);
            // Non-positive draws are replaced by a small floor so no component is zero.
            sample[k] = draw <= 0.0 ? 1.0E-4 : draw;
            total += sample[k];
        }
        for (int k = 0; k < sample.length; k++) {
            sample[k] /= total;
        }
        return sample;
    }

    /**
     * Formats a magnitude and a distribution as {@code "magnitude:\tv1\tv2\t..."}. It uses the
     * default-locale NumberFormat with at most five fraction digits.
     *
     * @param magnitude value printed before the colon
     * @param distribution values printed after it, each followed by a tab
     * @return the formatted line
     */
    public static String distributionToString(double magnitude, double[] distribution) {
        NumberFormat fmt = NumberFormat.getInstance();
        fmt.setMaximumFractionDigits(5);
        StringBuilder line = new StringBuilder(fmt.format(magnitude)).append(":\t");
        for (double value : distribution) {
            line.append(fmt.format(value)).append('\t');
        }
        return line.toString();
    }

    /**
     * Writes the Dirichlet parameters ({@code magnitude * partition[i]}) to a file, one per line.
     *
     * @param filename path of the file to create or overwrite
     * @throws IOException if the file cannot be opened or any write fails
     */
    public void toFile(String filename) throws IOException {
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(filename)));
        try {
            for (int i = 0; i < this.partition.length; i++) {
                out.println(this.magnitude * this.partition[i]);
            }
            // PrintWriter never throws on write errors; checkError() flushes and reports them,
            // so a failed write is surfaced instead of silently producing a truncated file.
            if (out.checkError()) {
                throw new IOException("error writing Dirichlet parameters to " + filename);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Samples a fresh distribution from this Dirichlet, then draws a count histogram from it.
     *
     * @param n target observation length passed through to {@code drawObservation(int, double[])}
     * @return per-dimension counts
     */
    public int[] drawObservation(int n) {
        this.initRandom();
        return this.drawObservation(n, this.nextDistribution());
    }

    /**
     * Draws a histogram of counts over the dimensions from a fixed distribution. The number of
     * draws is random with mean n: Poisson(n) when n &lt; 100, otherwise a rounded normal draw
     * from {@code random.nextGaussian(n, n)} (assumed to be mean n, variance n; confirm against
     * Randoms).
     *
     * @param n expected number of tokens in the observation
     * @param distribution probabilities used for each individual draw
     * @return counts per dimension, length {@code partition.length}
     */
    public int[] drawObservation(int n, double[] distribution) {
        this.initRandom();
        int[] histogram = new int[this.partition.length];
        // Previously nextPoisson() with no argument was used, which samples with mean 1 rather
        // than n. The normal approximation avoids the Poisson sampler for large n.
        int count = n < 100
                ? this.random.nextPoisson(n)
                : (int) Math.round(this.random.nextGaussian(n, n));
        for (int draw = 0; draw < count; draw++) {
            histogram[this.random.nextDiscrete(distribution)]++;
        }
        return histogram;
    }

    /**
     * Draws d independent observations, each produced by {@code drawObservation(n)}.
     *
     * @param d number of observations
     * @param n target length of each observation
     * @return array whose elements are {@code int[]} histograms
     */
    public Object[] drawObservations(int d, int n) {
        Object[] result = new Object[d];
        for (int doc = 0; doc < result.length; doc++) {
            result[doc] = this.drawObservation(n);
        }
        return result;
    }

    /**
     * Computes log Gamma(z) from the Weierstrass product, truncated after 9,999,999 factors:
     * {@code -gamma*z - log z + sum_k (z/k - log(1 + z/k))}. It is very slow and intended only as
     * a reference implementation.
     *
     * @param z the argument
     * @return an approximation of log Gamma(z)
     */
    public static double logGammaDefinition(double z) {
        final int terms = 10000000;
        double acc = -0.5772156649015329 * z - Math.log(z);
        for (int k = 1; k < terms; k++) {
            double ratio = z / (double) k;
            acc += ratio - Math.log(1.0 + ratio);
        }
        return acc;
    }

    /**
     * Computes {@code log Gamma(z + n) - log Gamma(z)} as {@code sum_{i<n} log(z + i)}.
     *
     * @param z the base argument
     * @param n number of unit steps; values &lt;= 0 give 0.0
     * @return the log-gamma difference
     */
    public static double logGammaDifference(double z, int n) {
        double total = 0.0;
        for (int offset = 0; offset < n; offset++) {
            total += Math.log(z + offset);
        }
        return total;
    }

    /**
     * Computes log Gamma(z); currently delegates to the Stirling-series implementation.
     *
     * @param z the argument
     * @return an approximation of log Gamma(z)
     */
    public static double logGamma(double z) {
        final double value = logGammaStirling(z);
        return value;
    }

    /**
     * Computes log Gamma(z) with Stirling's series, carried through the 1/(1260 z^5) term.
     * Arguments below 2 are first shifted up by whole units. The result is then shifted back
     * using {@code log Gamma(z) = log Gamma(z + 1) - log z}.
     *
     * @param z the argument
     * @return an approximation of log Gamma(z)
     */
    public static double logGammaStirling(double z) {
        double x = z;
        int shift = 0;
        for (; x < 2.0; x += 1.0) {
            shift++;
        }
        double value = HALF_LOG_TWO_PI
                + (x - 0.5) * Math.log(x)
                - x
                + 1.0 / (12.0 * x)
                - 1.0 / (360.0 * x * x * x)
                + 1.0 / (1260.0 * x * x * x * x * x);
        for (int step = shift; step > 0; step--) {
            x -= 1.0;
            value -= Math.log(x);
        }
        return value;
    }

    /**
     * Computes log Gamma(z) with Nemes' approximation:
     * {@code (log 2pi - log z)/2 + z (log(z + 1/(12z - 1/(10z))) - 1)}.
     *
     * @param z the argument
     * @return an approximation of log Gamma(z)
     */
    public static double logGammaNemes(double z) {
        double shiftedArgument = z + 1.0 / (12.0 * z - 1.0 / (10.0 * z));
        return HALF_LOG_TWO_PI - Math.log(z) / 2.0 + z * (Math.log(shiftedArgument) - 1.0);
    }

    /*
     * Unable to fully structure code
     */
    public static double digamma(double z) {
        psi = 0.0;
        if (!(z < 1.0E-6)) ** GOTO lbl7
        psi = -0.5772156649015329 - 1.0 / z;
        return psi;
lbl-1000:
        // 1 sources

        {
            psi -= 1.0 / z;
            z += 1.0;
lbl7:
            // 2 sources

            ** while (z < 9.5)
        }
lbl8:
        // 1 sources

        invZ = 1.0 / z;
        invZSquared = invZ * invZ;
        return psi += Math.log(z) - 0.5 * invZ - invZSquared * (0.0 - invZSquared * (0.0 - invZSquared * (0.0 - invZSquared * (0.0 - invZSquared * (0.0 - invZSquared * (0.0 - invZSquared * 0.0))))));
    }

    /**
     * Computes {@code digamma(x + n) - digamma(x)} exactly as {@code sum_{i<n} 1/(x + i)}.
     *
     * @param x the base argument
     * @param n number of unit steps; values &lt;= 0 give 0.0
     * @return the digamma difference
     */
    public static double digammaDifference(double x, int n) {
        double total = 0.0;
        for (int offset = 0; offset < n; offset++) {
            total += 1.0 / (x + offset);
        }
        return total;
    }

    /**
     * Computes the trigamma function psi'(z) = d^2/dz^2 log Gamma(z).
     *
     * <p>The argument is shifted by the smallest whole number that brings it to at least 10.
     * There the asymptotic series {@code 1/z + 1/(2z^2) + sum_k B_2k / z^(2k+1)} is evaluated
     * through the z^-11 term. The result is shifted back with {@code psi'(z) = psi'(z + 1) + 1/z^2}.
     * The recurrence terms are computed from the original argument, smallest first, to limit
     * rounding. Earlier versions evaluated the series at z &gt;= 2 with 7-digit coefficients,
     * which is only accurate to about 1e-5.
     *
     * @param z the argument; intended for positive values
     * @return an approximation of psi'(z)
     */
    public static double trigamma(double z) {
        int shift = 0;
        while (z + (double) shift < 10.0) {
            ++shift;
        }
        double shifted = z + (double) shift;
        double oneOverZ = 1.0 / shifted;
        double oneOverZSquared = oneOverZ * oneOverZ;
        // Horner form of 1/6 - u/30 + u^2/42 - u^3/30 + 5 u^4/66, with u = 1/z^2.
        double series = 5.0 / 66.0;
        series = -1.0 / 30.0 + oneOverZSquared * series;
        series = 1.0 / 42.0 + oneOverZSquared * series;
        series = -1.0 / 30.0 + oneOverZSquared * series;
        series = 1.0 / 6.0 + oneOverZSquared * series;
        double result = oneOverZ + 0.5 * oneOverZSquared + oneOverZSquared * oneOverZ * series;
        for (int i = shift - 1; i >= 0; --i) {
            double term = z + (double) i;
            result += 1.0 / (term * term);
        }
        return result;
    }

    /**
     * Estimates the concentration (sum of parameters) of a symmetric Dirichlet by 200 fixed-point
     * iterations, working from aggregated histograms.
     *
     * @param countHistogram countHistogram[c] = number of (observation, dimension) pairs with
     *     count c (presumably built as in testSymmetricConcentration; confirm for other callers)
     * @param observationLengths observationLengths[N] = number of observations of total length N
     * @param numDimensions number of dimensions K; each parameter is currentValue / K
     * @param currentValue initial concentration
     * @return the updated concentration; NaN/infinite if the denominator is zero (no observations)
     */
    public static double learnSymmetricConcentration(int[] countHistogram, int[] observationLengths, int numDimensions, double currentValue) {
        // Highest count with a non-zero histogram entry; larger indices contribute nothing.
        int largestNonZeroCount = 0;
        for (int count = 0; count < countHistogram.length; count++) {
            if (countHistogram[count] > 0) {
                largestNonZeroCount = count;
            }
        }
        // Dense, increasing list of the observation lengths that actually occur.
        int[] nonZeroLengthIndex = new int[observationLengths.length];
        int denseIndexSize = 0;
        for (int length = 0; length < observationLengths.length; length++) {
            if (observationLengths[length] > 0) {
                nonZeroLengthIndex[denseIndexSize++] = length;
            }
        }
        for (int iteration = 1; iteration <= 200; iteration++) {
            double currentParameter = currentValue / (double) numDimensions;
            // Numerator: sum_c countHistogram[c] * (digamma(a + c) - digamma(a)), a = per-dimension parameter,
            // accumulated incrementally.
            double currentDigamma = 0.0;
            double numerator = 0.0;
            for (int count = 1; count <= largestNonZeroCount; count++) {
                currentDigamma += 1.0 / (currentParameter + (double) count - 1.0);
                numerator += (double) countHistogram[count] * currentDigamma;
            }
            // Denominator: sum_N observationLengths[N] * (digamma(A + N) - digamma(A)), A = currentValue.
            currentDigamma = 0.0;
            double denominator = 0.0;
            int previousLength = 0;
            double cachedDigamma = Dirichlet.digamma(currentValue);
            for (int denseIndex = 0; denseIndex < denseIndexSize; denseIndex++) {
                int length = nonZeroLengthIndex[denseIndex];
                if (length - previousLength > 20) {
                    // A large jump is cheaper to compute directly than term by term.
                    currentDigamma = Dirichlet.digamma(currentValue + (double) length) - cachedDigamma;
                } else {
                    for (int offset = previousLength; offset < length; offset++) {
                        currentDigamma += 1.0 / (currentValue + (double) offset);
                    }
                }
                denominator += currentDigamma * (double) observationLengths[length];
                // currentDigamma now covers offsets [0, length); without this update the next
                // incremental step would re-add offsets starting from 0.
                previousLength = length;
            }
            currentValue = currentParameter * numerator / denominator;
        }
        return currentValue;
    }

    /**
     * Manual experiment for learnSymmetricConcentration. For true concentrations
     * numDimensions * 2^exponent (exponent = -5 .. 3), it draws observations and prints the
     * estimate from learnParametersWithHistogram. It then prints the true value, the
     * histogram-based estimate, and their absolute difference.
     *
     * @param numDimensions number of dimensions
     * @param numObservations number of observations drawn per trial
     * @param observationMeanLength target length of each observation
     */
    public static void testSymmetricConcentration(int numDimensions, int numObservations, int observationMeanLength) {
        for (int exponent = -5; exponent < 4; exponent++) {
            // The loop variable previously went unused, so all nine trials used the same alpha.
            double alpha = (double) numDimensions * Math.pow(2.0, exponent);
            Dirichlet prior = new Dirichlet(numDimensions, alpha / (double) numDimensions);
            int[] countHistogram = new int[1000000];
            int[] observationLengths = new int[1000000];
            Object[] observations = prior.drawObservations(numObservations, observationMeanLength);
            Dirichlet optimizedDirichlet = new Dirichlet(numDimensions, 1.0);
            optimizedDirichlet.learnParametersWithHistogram(observations);
            System.out.println(optimizedDirichlet.magnitude);
            for (int i = 0; i < numObservations; i++) {
                int[] observation = (int[]) observations[i];
                int total = 0;
                for (int k = 0; k < numDimensions; k++) {
                    if (observation[k] > 0) {
                        total += observation[k];
                        countHistogram[observation[k]]++;
                    }
                }
                observationLengths[total]++;
            }
            double estimatedAlpha = Dirichlet.learnSymmetricConcentration(countHistogram, observationLengths, numDimensions, 1.0);
            System.out.println(String.valueOf(alpha) + "\t" + estimatedAlpha + "\t" + Math.abs(alpha - estimatedAlpha));
        }
    }

    /**
     * Convenience overload of the six-argument learnParameters. It uses shape 1.00001, scale 1.0
     * and 200 iterations.
     *
     * @param parameters Dirichlet parameters, updated in place
     * @param observations per-dimension count histograms
     * @param observationLengths histogram of observation lengths
     * @return the sum of the updated parameters
     */
    public static double learnParameters(double[] parameters, int[][] observations, int[] observationLengths) {
        final double defaultShape = 1.00001;
        final double defaultScale = 1.0;
        final int defaultIterations = 200;
        return learnParameters(parameters, observations, observationLengths, defaultShape, defaultScale, defaultIterations);
    }

    /**
     * Updates Dirichlet parameters in place by fixed-point iteration on histogram statistics.
     * Each step computes
     * {@code alpha_k <- alpha_k * (sum_c h_k[c] (psi(alpha_k + c) - psi(alpha_k)) + shape) / D},
     * where {@code D = sum_N L[N] (psi(S + N) - psi(S)) - 1/scale}. The digamma differences are
     * accumulated incrementally.
     *
     * @param parameters Dirichlet parameters (one per dimension), overwritten with the result
     * @param observations observations[k][c] is presumably the number of observations in which
     *     dimension k has count c; must have at least parameters.length rows
     * @param observationLengths observationLengths[N] = number of observations of length N
     * @param shape added to each numerator (hyperparameter)
     * @param scale its reciprocal is subtracted from the denominator (hyperparameter)
     * @param numIterations number of fixed-point iterations
     * @return the sum of the parameters after the final iteration
     * @throws RuntimeException if the final sum is negative
     */
    public static double learnParameters(double[] parameters, int[][] observations, int[] observationLengths, double shape, double scale, int numIterations) {
        double parametersSum = 0.0;
        for (int k = 0; k < parameters.length; k++) {
            parametersSum += parameters[k];
        }

        // For every histogram, the largest count index holding a non-zero entry (-1 if none).
        int[] nonZeroLimits = new int[observations.length];
        for (int row = 0; row < observations.length; row++) {
            int[] hist = observations[row];
            int limit = -1;
            for (int c = 0; c < hist.length; c++) {
                if (hist[c] > 0) {
                    limit = c;
                }
            }
            nonZeroLimits[row] = limit;
        }

        for (int iteration = 0; iteration < numIterations; iteration++) {
            double denominator = 0.0;
            double lengthDigamma = 0.0;
            for (int len = 1; len < observationLengths.length; len++) {
                lengthDigamma += 1.0 / (parametersSum + (double) len - 1.0);
                denominator += (double) observationLengths[len] * lengthDigamma;
            }
            denominator -= 1.0 / scale;

            parametersSum = 0.0;
            for (int k = 0; k < parameters.length; k++) {
                double previous = parameters[k];
                int limit = nonZeroLimits[k];
                int[] hist = observations[k];
                double countDigamma = 0.0;
                double numerator = 0.0;
                for (int c = 1; c <= limit; c++) {
                    countDigamma += 1.0 / (previous + (double) c - 1.0);
                    numerator += (double) hist[c] * countDigamma;
                }
                parameters[k] = previous * (numerator + shape) / denominator;
                parametersSum += parameters[k];
            }
        }

        if (parametersSum < 0.0) {
            throw new RuntimeException("sum: " + parametersSum);
        }
        return parametersSum;
    }

    /**
     * Builds per-dimension count histograms and a length histogram from raw observations, then
     * fits the parameters with {@code learnParametersWithHistogram(int[][], int[])}.
     *
     * @param observations elements are {@code int[]} count vectors, one entry per dimension
     * @return elapsed wall-clock milliseconds reported by the histogram-based fit
     */
    public long learnParametersWithHistogram(Object[] observations) {
        final int numBins = this.partition.length;
        int[] maxBinCounts = new int[numBins];
        int maxLength = 0;
        for (Object raw : observations) {
            int[] observation = (int[]) raw;
            int length = 0;
            for (int bin = 0; bin < observation.length; bin++) {
                maxBinCounts[bin] = Math.max(maxBinCounts[bin], observation[bin]);
                length += observation[bin];
            }
            maxLength = Math.max(maxLength, length);
        }

        // binCountHistograms[bin][c] = number of observations with count c in that bin.
        int[][] binCountHistograms = new int[numBins][];
        for (int bin = 0; bin < numBins; bin++) {
            binCountHistograms[bin] = new int[maxBinCounts[bin] + 1];
        }
        int[] lengthHistogram = new int[maxLength + 1];
        for (Object raw : observations) {
            int[] observation = (int[]) raw;
            int length = 0;
            for (int bin = 0; bin < observation.length; bin++) {
                binCountHistograms[bin][observation[bin]]++;
                length += observation[bin];
            }
            lengthHistogram[length]++;
        }
        return this.learnParametersWithHistogram(binCountHistograms, lengthHistogram);
    }

    /**
     * Fits the Dirichlet parameters with 1000 fixed-point iterations on histogram statistics.
     * The digamma differences are accumulated incrementally over the histogram indices. Bins
     * whose histogram has at most one entry get the floor value 1e-6.
     *
     * @param binCountHistograms binCountHistograms[k][c] = observations with count c in dimension k
     * @param lengthHistogram lengthHistogram[N] = observations of total length N
     * @return elapsed wall-clock milliseconds
     */
    public long learnParametersWithHistogram(int[][] binCountHistograms, int[] lengthHistogram) {
        final long startTime = System.currentTimeMillis();
        final int numBins = this.partition.length;
        double[] alphas = new double[numBins];
        double alphaSum = 0.0;
        for (int bin = 0; bin < numBins; bin++) {
            alphas[bin] = this.magnitude * this.partition[bin];
            alphaSum += alphas[bin];
        }

        for (int iteration = 0; iteration < 1000; iteration++) {
            double denominator = 0.0;
            double runningDigamma = 0.0;
            for (int len = 1; len < lengthHistogram.length; len++) {
                runningDigamma += 1.0 / (alphaSum + (double) len - 1.0);
                denominator += (double) lengthHistogram[len] * runningDigamma;
            }
            assert denominator > 0.0;
            assert !Double.isNaN(denominator);

            alphaSum = 0.0;
            for (int bin = 0; bin < numBins; bin++) {
                double previousAlpha = alphas[bin];
                int[] histogram = binCountHistograms[bin];
                double updated;
                if (histogram.length <= 1) {
                    updated = 1.0E-6;
                } else {
                    updated = 0.0;
                    runningDigamma = 0.0;
                    for (int count = 1; count < histogram.length; count++) {
                        runningDigamma += 1.0 / (previousAlpha + (double) count - 1.0);
                        updated += (double) histogram[count] * runningDigamma;
                    }
                }
                if (!(updated > 0.0)) {
                    // Diagnostic dump of the offending histogram.
                    System.out.println("length of empty array: " + new int[0].length);
                    for (int count = 0; count < histogram.length; count++) {
                        System.out.print(histogram[count] + " ");
                    }
                    System.out.println();
                }
                assert updated > 0.0;
                assert !Double.isNaN(updated);
                updated *= previousAlpha / denominator;
                alphas[bin] = updated;
                alphaSum += updated;
            }
        }

        for (int bin = 0; bin < numBins; bin++) {
            this.partition[bin] = alphas[bin] / alphaSum;
        }
        // The magnitude is left unchanged when there are no bins.
        if (numBins > 0) {
            this.magnitude = alphaSum;
        }
        return System.currentTimeMillis() - startTime;
    }

    /**
     * Transposes raw observations into per-bin count columns plus observation lengths, then fits
     * the parameters with {@code learnParametersWithDigamma(int[][], int[])}.
     *
     * @param observations elements are {@code int[]} count vectors with at least partition.length entries
     * @return elapsed wall-clock milliseconds reported by the fit
     */
    public long learnParametersWithDigamma(Object[] observations) {
        final int numBins = this.partition.length;
        int[][] binCounts = new int[numBins][observations.length];
        int[] observationLengths = new int[observations.length];
        for (int doc = 0; doc < observations.length; doc++) {
            int[] observation = (int[]) observations[doc];
            int total = 0;
            for (int bin = 0; bin < numBins; bin++) {
                binCounts[bin][doc] = observation[bin];
                total += observation[bin];
            }
            observationLengths[doc] = total;
        }
        return this.learnParametersWithDigamma(binCounts, observationLengths);
    }

    /**
     * Fits the Dirichlet parameters with 1000 fixed-point iterations that call digamma directly.
     * Each iteration computes
     * {@code alpha_k <- alpha_k * sum_i (psi(alpha_k + n_ik) - psi(alpha_k)) / sum_i (psi(A + n_i) - psi(A))}.
     * Non-positive numerators are replaced by 1e-6.
     *
     * @param binCounts binCounts[k][i] = count of dimension k in observation i
     * @param observationLengths total length of each observation
     * @return elapsed wall-clock milliseconds
     */
    public long learnParametersWithDigamma(int[][] binCounts, int[] observationLengths) {
        final long startTime = System.currentTimeMillis();
        final int numBins = this.partition.length;
        double[] updated = new double[numBins];
        for (int iteration = 0; iteration < 1000; iteration++) {
            double updatedMagnitude = 0.0;
            double denominator = 0.0;
            for (int length : observationLengths) {
                denominator += Dirichlet.digamma(this.magnitude + (double) length);
            }
            denominator -= (double) observationLengths.length * Dirichlet.digamma(this.magnitude);

            for (int bin = 0; bin < numBins; bin++) {
                int[] counts = binCounts[bin];
                double alphaK = this.magnitude * this.partition[bin];
                double digammaAlphaK = Dirichlet.digamma(alphaK);
                double numerator = 0.0;
                for (int count : counts) {
                    // Zero counts reuse the cached digamma(alphaK).
                    numerator += count == 0 ? digammaAlphaK : Dirichlet.digamma(alphaK + (double) count);
                }
                numerator -= (double) counts.length * digammaAlphaK;
                double value = numerator <= 0.0 ? 1.0E-6 : numerator * (alphaK / denominator);
                if (value <= 0.0) {
                    System.out.println(String.valueOf(value) + "\t" + alphaK + "\t" + denominator);
                }
                assert value > 0.0;
                assert !Double.isNaN(value);
                updated[bin] = value;
                updatedMagnitude += value;
            }

            this.magnitude = updatedMagnitude;
            for (int bin = 0; bin < numBins; bin++) {
                this.partition[bin] = updated[bin] / this.magnitude;
            }
        }
        return System.currentTimeMillis() - startTime;
    }

    /**
     * Estimates the Dirichlet by moment matching. The partition becomes the mean per-bin
     * proportion across observations. The magnitude is
     * {@code exp(sum_k log(m_k (1 - m_k) / v_k - 1) / (K - 1))}, summed over bins with non-zero
     * mean, where v_k is the sample variance (divisor n - 1) of the bin-k proportion.
     *
     * @param observations elements are {@code int[]} count vectors with at least partition.length entries
     * @return elapsed wall-clock milliseconds
     */
    public long learnParametersWithMoments(Object[] observations) {
        final long startTime = System.currentTimeMillis();
        final int numObs = observations.length;
        final int numBins = this.partition.length;
        int[] observationLengths = new int[numObs];
        double[] variances = new double[numBins];
        Arrays.fill(this.partition, 0.0);

        // Accumulate per-bin proportions; divided by numObs below to give the mean.
        for (int doc = 0; doc < numObs; doc++) {
            int[] observation = (int[]) observations[doc];
            int length = 0;
            for (int bin = 0; bin < numBins; bin++) {
                length += observation[bin];
            }
            observationLengths[doc] = length;
            for (int bin = 0; bin < numBins; bin++) {
                this.partition[bin] += (double) observation[bin] / (double) length;
            }
        }
        for (int bin = 0; bin < numBins; bin++) {
            this.partition[bin] /= (double) numObs;
        }

        // Sample variance of each bin's proportion.
        for (int doc = 0; doc < numObs; doc++) {
            int[] observation = (int[]) observations[doc];
            for (int bin = 0; bin < numBins; bin++) {
                double diff = (double) observation[bin] / (double) observationLengths[doc] - this.partition[bin];
                variances[bin] += diff * diff;
            }
        }
        for (int bin = 0; bin < numBins; bin++) {
            variances[bin] /= (double) (numObs - 1);
        }

        double logSum = 0.0;
        for (int bin = 0; bin < numBins; bin++) {
            double mean = this.partition[bin];
            if (mean != 0.0) {
                logSum += Math.log(mean * (1.0 - mean) / variances[bin] - 1.0);
            }
        }
        this.magnitude = Math.exp(logSum / (double) (numBins - 1));
        return System.currentTimeMillis() - startTime;
    }

    /**
     * Transposes raw observations into per-bin count columns plus observation lengths, then fits
     * the parameters with {@code learnParametersWithLeaveOneOut(int[][], int[])}.
     *
     * @param observations elements are {@code int[]} count vectors with at least partition.length entries
     * @return elapsed wall-clock milliseconds reported by the fit
     */
    public long learnParametersWithLeaveOneOut(Object[] observations) {
        final int numBins = this.partition.length;
        int[][] binCounts = new int[numBins][observations.length];
        int[] observationLengths = new int[observations.length];
        for (int doc = 0; doc < observations.length; doc++) {
            int[] observation = (int[]) observations[doc];
            int total = 0;
            for (int bin = 0; bin < numBins; bin++) {
                binCounts[bin][doc] = observation[bin];
                total += observation[bin];
            }
            observationLengths[doc] = total;
        }
        return this.learnParametersWithLeaveOneOut(binCounts, observationLengths);
    }

    /**
     * Fits the Dirichlet parameters with 1000 leave-one-out fixed-point iterations:
     * {@code alpha_k <- alpha_k * sum_i n_ik / (n_ik - 1 + alpha_k) / sum_i n_i / (n_i - 1 + A)}.
     * Only counts of at least 2 contribute to the numerator. Bins with an empty numerator get 1e-6.
     *
     * @param binCounts binCounts[k][i] = count of dimension k in observation i
     * @param observationLengths total length of each observation
     * @return elapsed wall-clock milliseconds
     */
    public long learnParametersWithLeaveOneOut(int[][] binCounts, int[] observationLengths) {
        final long startTime = System.currentTimeMillis();
        final int numBins = this.partition.length;
        double[] updated = new double[numBins];
        double[] binSums = new double[numBins];
        for (int iteration = 0; iteration < 1000; iteration++) {
            double observationSum = 0.0;
            for (int length : observationLengths) {
                observationSum += (double) length / ((double) (length - 1) + this.magnitude);
            }

            for (int bin = 0; bin < numBins; bin++) {
                double alphaK = this.magnitude * this.partition[bin];
                double sum = 0.0;
                for (int count : binCounts[bin]) {
                    if (count >= 2) {
                        sum += (double) count / ((double) (count - 1) + alphaK);
                    }
                }
                binSums[bin] = sum;
            }

            double parameterSum = 0.0;
            for (int bin = 0; bin < numBins; bin++) {
                updated[bin] = binSums[bin] == 0.0
                        ? 1.0E-6
                        : this.partition[bin] * this.magnitude * binSums[bin] / observationSum;
                parameterSum += updated[bin];
            }
            for (int bin = 0; bin < numBins; bin++) {
                this.partition[bin] = updated[bin] / parameterSum;
            }
            this.magnitude = parameterSum;
        }
        return System.currentTimeMillis() - startTime;
    }

    /**
     * Returns the L1 distance between the parameter vectors ({@code magnitude * partition[k]}) of
     * two Dirichlets.
     *
     * @param other Dirichlet of the same dimension
     * @return sum over k of |alpha_k - other.alpha_k|
     * @throws IllegalArgumentException if the dimensions differ
     */
    public double absoluteDifference(Dirichlet other) {
        if (other.partition.length != this.partition.length) {
            throw new IllegalArgumentException("dirichlets must have the same dimension to be compared");
        }
        double total = 0.0;
        for (int k = 0; k < this.partition.length; k++) {
            double mine = this.partition[k] * this.magnitude;
            double theirs = other.partition[k] * other.magnitude;
            total += Math.abs(mine - theirs);
        }
        return total;
    }

    /**
     * Returns the squared L2 distance between the parameter vectors ({@code magnitude * partition[k]})
     * of two Dirichlets.
     *
     * @param other Dirichlet of the same dimension
     * @return sum over k of (alpha_k - other.alpha_k)^2
     * @throws IllegalArgumentException if the dimensions differ
     */
    public double squaredDifference(Dirichlet other) {
        if (other.partition.length != this.partition.length) {
            throw new IllegalArgumentException("dirichlets must have the same dimension to be compared");
        }
        double total = 0.0;
        for (int k = 0; k < this.partition.length; k++) {
            double delta = this.partition[k] * this.magnitude - other.partition[k] * other.magnitude;
            total += Math.pow(delta, 2.0);
        }
        return total;
    }

    /**
     * Benchmark: for n = 1..99, times one million calls of {@code digamma(x + n)} against one
     * million calls of {@code digammaDifference(x, n)}. Prints both timings and the two values
     * being compared. Does not read or modify instance state.
     *
     * @param x base argument
     */
    public void checkBreakeven(double x) {
        final int repetitions = 1000000;
        double digammaX = Dirichlet.digamma(x);
        for (int n = 1; n < 100; n++) {
            long t0 = System.currentTimeMillis();
            for (int rep = 0; rep < repetitions; rep++) {
                Dirichlet.digamma(x + (double) n);
            }
            long direct = System.currentTimeMillis() - t0;

            t0 = System.currentTimeMillis();
            for (int rep = 0; rep < repetitions; rep++) {
                Dirichlet.digammaDifference(x, n);
            }
            long indirect = System.currentTimeMillis() - t0;

            System.out.println(String.valueOf(n) + "\tdirect: " + direct + "\tindirect: " + indirect + " (" + (direct - indirect) + ")");
            System.out.println("  " + (Dirichlet.digamma(x + (double) n) - digammaX) + " " + Dirichlet.digammaDifference(x, n));
        }
    }

    /**
     * Simulation comparing the four estimators. It samples a true Dirichlet of magnitude sum over
     * k dimensions and draws n observations of target length w. It then fits four fresh symmetric
     * Dirichlets (digamma, histogram, moments, leave-one-out) and reports each one's time and L1
     * error, tab-separated.
     *
     * @param sum true magnitude
     * @param k number of dimensions
     * @param n number of observations
     * @param w target observation length
     * @return one tab-separated report line
     */
    public static String compare(double sum, int k, int n, int w) {
        final double perDimension = sum / (double) k;
        StringBuilder report = new StringBuilder();
        report.append(String.valueOf(sum)).append("\t").append(k).append("\t").append(n).append("\t").append(w).append("\t");

        Dirichlet truth = new Dirichlet(sum, new Dirichlet(k, perDimension).nextDistribution());
        Object[] observations = truth.drawObservations(n, w);

        Dirichlet estimate = new Dirichlet(k, perDimension);
        long elapsed = estimate.learnParametersWithDigamma(observations);
        report.append(elapsed).append("\t").append(truth.absoluteDifference(estimate)).append("\t");

        estimate = new Dirichlet(k, perDimension);
        elapsed = estimate.learnParametersWithHistogram(observations);
        report.append(elapsed).append("\t").append(truth.absoluteDifference(estimate)).append("\t");

        estimate = new Dirichlet(k, perDimension);
        elapsed = estimate.learnParametersWithMoments(observations);
        report.append(elapsed).append("\t").append(truth.absoluteDifference(estimate)).append("\t");

        estimate = new Dirichlet(k, perDimension);
        elapsed = estimate.learnParametersWithLeaveOneOut(observations);
        report.append(elapsed).append("\t").append(truth.absoluteDifference(estimate)).append("\t");

        return report.toString();
    }

    /**
     * Log likelihood ratio for two sparse count vectors under a symmetric Dirichlet-multinomial.
     * For every key present in either map it adds
     * {@code lgamma(a) + lgamma(a + x + y) - lgamma(a + x) - lgamma(a + y)}, where a = alpha.
     * It then adds
     * {@code lgamma(S + X) + lgamma(S + Y) - lgamma(S) - lgamma(S + X + Y)}, where S = alphaSum
     * and X, Y are the totals. Keys absent from both maps contribute zero and are skipped.
     *
     * @param countsX sparse counts for the first sample (missing key = count 0)
     * @param countsY sparse counts for the second sample (missing key = count 0)
     * @param alpha per-dimension Dirichlet parameter
     * @param alphaSum sum of all Dirichlet parameters
     * @return the log likelihood ratio
     */
    public static double dirichletMultinomialLikelihoodRatio(TIntIntHashMap countsX, TIntIntHashMap countsY, double alpha, double alphaSum) {
        double logLikelihood = 0.0;
        // Every per-key term starts with logGamma(alpha); compute it once.
        double logGammaAlpha = Dirichlet.logGamma(alpha);
        int totalX = 0;
        int totalY = 0;
        TIntHashSet distinctKeys = new TIntHashSet();
        distinctKeys.addAll(countsX.keys());
        distinctKeys.addAll(countsY.keys());
        // gnu.trove.TIntHashSet is not Iterable, so iterate over a snapshot of its keys.
        int[] keys = distinctKeys.toArray();
        for (int i = 0; i < keys.length; i++) {
            int key = keys[i];
            int x = countsX.containsKey(key) ? countsX.get(key) : 0;
            int y = countsY.containsKey(key) ? countsY.get(key) : 0;
            totalX += x;
            totalY += y;
            logLikelihood += logGammaAlpha + Dirichlet.logGamma(alpha + (double) x + (double) y) - Dirichlet.logGamma(alpha + (double) x) - Dirichlet.logGamma(alpha + (double) y);
        }
        logLikelihood += Dirichlet.logGamma(alphaSum + (double) totalX) + Dirichlet.logGamma(alphaSum + (double) totalY) - Dirichlet.logGamma(alphaSum) - Dirichlet.logGamma(alphaSum + (double) totalX + (double) totalY);
        return logLikelihood;
    }

    /**
     * Dense-array version of the symmetric Dirichlet-multinomial log likelihood ratio. Each
     * dimension adds {@code lgamma(a) + lgamma(a + x + y) - lgamma(a + x) - lgamma(a + y)}, and
     * the totals then add {@code lgamma(S + X) + lgamma(S + Y) - lgamma(S) - lgamma(S + X + Y)}.
     *
     * @param countsX counts for the first sample
     * @param countsY counts for the second sample, same length as countsX
     * @param alpha per-dimension Dirichlet parameter
     * @param alphaSum sum of all Dirichlet parameters
     * @return the log likelihood ratio
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double dirichletMultinomialLikelihoodRatio(int[] countsX, int[] countsY, double alpha, double alphaSum) {
        if (countsX.length != countsY.length) {
            throw new IllegalArgumentException("both arrays must contain the same number of dimensions");
        }
        final double logGammaAlpha = Dirichlet.logGamma(alpha);
        double logLikelihood = 0.0;
        int totalX = 0;
        int totalY = 0;
        for (int dim = 0; dim < countsX.length; dim++) {
            int x = countsX[dim];
            int y = countsY[dim];
            totalX += x;
            totalY += y;
            logLikelihood += logGammaAlpha + Dirichlet.logGamma(alpha + (double) x + (double) y) - Dirichlet.logGamma(alpha + (double) x) - Dirichlet.logGamma(alpha + (double) y);
        }
        double totalsTerm = Dirichlet.logGamma(alphaSum + (double) totalX) + Dirichlet.logGamma(alphaSum + (double) totalY) - Dirichlet.logGamma(alphaSum) - Dirichlet.logGamma(alphaSum + (double) totalX + (double) totalY);
        return logLikelihood + totalsTerm;
    }

    /**
     * Dirichlet-compound-multinomial log likelihood ratio using this Dirichlet's own
     * parameters: dimension k uses {@code partition[k] * magnitude} as its alpha, and
     * {@code magnitude} serves as the total concentration.
     *
     * @param countsX per-dimension counts of the first sample
     * @param countsY per-dimension counts of the second sample
     * @return the log likelihood ratio
     * @throws IllegalArgumentException if the arrays and this Dirichlet differ in dimensionality
     */
    public double dirichletMultinomialLikelihoodRatio(int[] countsX, int[] countsY) {
        final int dims = countsX.length;
        if (dims != countsY.length || dims != this.partition.length) {
            throw new IllegalArgumentException("both arrays and the Dirichlet prior must contain the same number of dimensions");
        }
        double result = 0.0;
        int sumX = 0;
        int sumY = 0;
        for (int d = 0; d < dims; d++) {
            final int cx = countsX[d];
            final int cy = countsY[d];
            sumX += cx;
            sumY += cy;
            final double a = this.partition[d] * this.magnitude;
            result += Dirichlet.logGamma(a) + Dirichlet.logGamma(a + (double) cx + (double) cy) - Dirichlet.logGamma(a + (double) cx) - Dirichlet.logGamma(a + (double) cy);
        }
        final double m = this.magnitude;
        result += Dirichlet.logGamma(m + (double) sumX) + Dirichlet.logGamma(m + (double) sumY) - Dirichlet.logGamma(m) - Dirichlet.logGamma(m + (double) sumX + (double) sumY);
        return result;
    }

    /**
     * Computes a log likelihood ratio for two count vectors under the Ewens sampling
     * formula with concentration {@code lambda}: the log probability of the pooled
     * counts (x + y) minus the log probabilities of x and y taken separately.
     *
     * <p>For each sample, the statistic is a histogram whose entry j is the number of
     * dimensions whose count equals j.
     *
     * @param countsX per-dimension counts of the first sample (non-negative)
     * @param countsY per-dimension counts of the second sample (non-negative)
     * @param lambda  Ewens concentration parameter
     * @return the log likelihood ratio
     * @throws IllegalArgumentException if the arrays differ in length or any count is negative
     */
    public static double ewensLikelihoodRatio(int[] countsX, int[] countsY, double lambda) {
        if (countsX.length != countsY.length) {
            throw new IllegalArgumentException("both arrays must contain the same number of dimensions");
        }
        int totalX = 0;
        int totalY = 0;
        for (int key = 0; key < countsX.length; key++) {
            // Counts index the histograms below, so a negative count would be an invalid index.
            if (countsX[key] < 0 || countsY[key] < 0) {
                throw new IllegalArgumentException("counts must be non-negative (dimension " + key + ")");
            }
            totalX += countsX[key];
            totalY += countsY[key];
        }
        int total = totalX + totalY;

        // countHistogramZ[j] = number of dimensions whose count in sample Z equals j.
        int[] countHistogramX = new int[total + 1];
        int[] countHistogramY = new int[total + 1];
        int[] countHistogramBoth = new int[total + 1];
        for (int key = 0; key < countsX.length; key++) {
            int x = countsX[key];
            int y = countsY[key];
            countHistogramX[x]++;
            // The y count belongs in the Y histogram; tallying it into X left Y all zeros.
            countHistogramY[y]++;
            countHistogramBoth[x + y]++;
        }

        double logLikelihood = 0.0;
        // j = 0 is skipped: dimensions with zero count do not enter the formula.
        for (int j = 1; j <= total; j++) {
            if (countHistogramX[j] != 0 || countHistogramY[j] != 0 || countHistogramBoth[j] != 0) {
                logLikelihood += (double)(countHistogramBoth[j] - countHistogramX[j] - countHistogramY[j]) * Math.log(lambda / (double)j);
                logLikelihood += Dirichlet.logGamma(countHistogramX[j] + 1) + Dirichlet.logGamma(countHistogramY[j] + 1) - Dirichlet.logGamma(countHistogramBoth[j] + 1);
            }
        }
        // Multinomial-coefficient term: n! / (n_x! n_y!).
        logLikelihood += Dirichlet.logGamma(total + 1) - Dirichlet.logGamma(totalX + 1) - Dirichlet.logGamma(totalY + 1);
        // Rising-factorial normalizers, lambda^(n) = Gamma(lambda + n) / Gamma(lambda).
        return logLikelihood + Dirichlet.logGamma(lambda + (double)totalX) + Dirichlet.logGamma(lambda + (double)totalY) - Dirichlet.logGamma(lambda) - Dirichlet.logGamma(lambda + (double)totalX + (double)totalY);
    }

    /**
     * Runs {@link #compare} over a grid of experimental settings and writes one result
     * line per run to the file {@code comparison} in the working directory.
     *
     * <p>Dimensions take the values 10..160, document counts 100..1600, and mean sizes
     * 100..1600, each doubling over 5 steps. Every setting is repeated 10 times. Progress
     * goes to standard output. Any exception is printed to standard output rather than
     * propagated, and the output file is closed in every case.
     */
    public static void runComparison() {
        PrintWriter out = null;
        try {
            out = new PrintWriter(new BufferedWriter(new FileWriter("comparison")));
            int dimensions = 10;
            for (int j = 0; j < 5; j++) {
                int documents = 100;
                for (int k = 0; k < 5; k++) {
                    int meanSize = 100;
                    for (int l = 0; l < 5; l++) {
                        // dimensions is echoed twice because compare() receives it as both its first and second arguments.
                        System.out.println(String.valueOf(dimensions) + "\t" + dimensions + "\t" + documents + "\t" + meanSize);
                        for (int m = 0; m < 10; m++) {
                            out.println(Dirichlet.compare(dimensions, dimensions, documents, meanSize));
                        }
                        out.flush();
                        meanSize *= 2;
                    }
                    documents *= 2;
                }
                dimensions *= 2;
            }
        }
        catch (Exception e) {
            e.printStackTrace(System.out);
        }
        finally {
            // Previously the writer leaked whenever compare() threw; close() also flushes.
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Command-line entry point. {@code args} is ignored; the method runs
     * {@code testSymmetricConcentration} with fixed arguments (1000, 100, 1000).
     */
    public static void main(String[] args) {
        testSymmetricConcentration(1000, 100, 1000);
    }

    /**
     * Returns the alphabet that labels this Dirichlet's dimensions. The result may be
     * {@code null}; {@code print()} handles that case explicitly.
     */
    public Alphabet getAlphabet() {
        final Alphabet alphabet = dict;
        return alphabet;
    }

    /** Returns the number of dimensions (the length of the partition vector). */
    public int size() {
        final double[] parts = partition;
        return parts.length;
    }

    /**
     * Returns the Dirichlet parameter for one dimension, computed as
     * {@code magnitude * partition[featureIndex]}.
     *
     * @param featureIndex index of the dimension
     * @return the parameter for that dimension
     */
    public double alpha(int featureIndex) {
        final double share = partition[featureIndex];
        return magnitude * share;
    }

    /**
     * Prints the parameters to standard output, one {@code label=value} line per
     * dimension. The label is the alphabet entry when an alphabet is set, otherwise the
     * dimension index. The value is {@code magnitude * partition[j]}.
     */
    public void print() {
        System.out.println("Dirichlet:");
        for (int j = 0; j < this.partition.length; j++) {
            // Parenthesizing via a local fixes a precedence bug: "=value" used to bind
            // only to the index branch, so named dimensions were printed without a value.
            String label = this.dict != null ? this.dict.lookupObject(j).toString() : String.valueOf(j);
            System.out.println(label + "=" + this.magnitude * this.partition[j]);
        }
    }

    /**
     * Draws one probability vector from this Dirichlet. For each dimension it takes a
     * Gamma variate with shape {@code magnitude * partition[i]}, then divides every draw
     * by their total. (Assumes {@code Randoms.nextGamma(a)} samples Gamma(a, 1); TODO
     * confirm.)
     *
     * @param r random source
     * @return a freshly allocated vector whose entries sum to 1
     */
    protected double[] randomRawMultinomial(Randoms r) {
        final int dims = this.partition.length;
        double[] draws = new double[dims];
        double total = 0.0;
        for (int d = 0; d < dims; d++) {
            draws[d] = r.nextGamma(this.magnitude * this.partition[d]);
            total += draws[d];
        }
        for (int d = 0; d < dims; d++) {
            draws[d] /= total;
        }
        return draws;
    }

    /**
     * Samples a probability vector from this Dirichlet and wraps it in a
     * {@link Multinomial} that shares this Dirichlet's alphabet.
     *
     * @param r random source
     * @return the sampled multinomial
     */
    public Multinomial randomMultinomial(Randoms r) {
        double[] probabilities = this.randomRawMultinomial(r);
        int dims = this.partition.length;
        return new Multinomial(probabilities, this.dict, dims, false, false);
    }

    /**
     * Samples a probability vector from this Dirichlet and scales it into the
     * parameters of a new Dirichlet whose parameters sum to
     * {@code size() * averageAlpha}.
     *
     * @param r            random source
     * @param averageAlpha desired mean parameter value of the new Dirichlet
     * @return a new Dirichlet sharing this one's alphabet
     */
    public Dirichlet randomDirichlet(Randoms r, double averageAlpha) {
        double[] weights = this.randomRawMultinomial(r);
        final double alphaSum = (double) weights.length * averageAlpha;
        for (int idx = 0; idx < weights.length; idx++) {
            weights[idx] *= alphaSum;
        }
        return new Dirichlet(weights, this.dict);
    }

    /**
     * Samples a multinomial from this Dirichlet, then draws a feature sequence of the
     * given length from that multinomial.
     *
     * @param r      random source, used for both sampling steps
     * @param length number of positions in the sequence
     * @return the sampled feature sequence
     */
    public FeatureSequence randomFeatureSequence(Randoms r, int length) {
        return this.randomMultinomial(r).randomFeatureSequence(r, length);
    }

    /**
     * Builds a {@link FeatureVector} from a randomly sampled feature sequence.
     *
     * @param r    random source
     * @param size length of the underlying sampled sequence
     * @return the resulting feature vector
     */
    public FeatureVector randomFeatureVector(Randoms r, int size) {
        FeatureSequence sequence = this.randomFeatureSequence(r, size);
        return new FeatureVector(sequence);
    }

    /**
     * Samples a feature sequence and converts each position's alphabet object into a
     * token, using its {@code toString()} form.
     *
     * @param r      random source
     * @param length number of tokens to produce
     * @return the sampled token sequence
     */
    public TokenSequence randomTokenSequence(Randoms r, int length) {
        FeatureSequence features = this.randomFeatureSequence(r, length);
        TokenSequence tokens = new TokenSequence(length);
        for (int pos = 0; pos < length; pos++) {
            Object entry = features.getObjectAtPosition(pos);
            tokens.add(entry.toString());
        }
        return tokens;
    }

    /**
     * Public alias for drawing a raw probability vector from this Dirichlet.
     *
     * @param r random source
     * @return a freshly allocated vector whose entries sum to 1
     */
    public double[] randomVector(Randoms r) {
        double[] sample = randomRawMultinomial(r);
        return sample;
    }

    /**
     * Base class for estimators that fit a {@link Dirichlet} to a collection of
     * multinomials. All stored multinomials are required to have the same size and the
     * same {@link Alphabet} instance.
     */
    public static abstract class Estimator {
        /** Training multinomials; every element has the same size and Alphabet instance. */
        ArrayList<Multinomial> multinomials;

        /** Creates an estimator with no training multinomials. */
        public Estimator() {
            this.multinomials = new ArrayList<Multinomial>();
        }

        /**
         * Creates an estimator from a copy of the given multinomials.
         *
         * @param multinomialsTraining training data
         * @throws IllegalArgumentException if the multinomials differ in size or Alphabet
         */
        public Estimator(Collection<Multinomial> multinomialsTraining) {
            this.multinomials = new ArrayList<Multinomial>(multinomialsTraining);
            for (int i = 1; i < this.multinomials.size(); i++) {
                checkCompatible(this.multinomials.get(i - 1), this.multinomials.get(i));
            }
        }

        /**
         * Appends a training multinomial.
         *
         * @param m the multinomial to add
         * @throws IllegalArgumentException if {@code m} differs in size or Alphabet from
         *         those already added (the same invariant the collection constructor enforces)
         */
        public void addMultinomial(Multinomial m) {
            if (!this.multinomials.isEmpty()) {
                checkCompatible(this.multinomials.get(this.multinomials.size() - 1), m);
            }
            this.multinomials.add(m);
        }

        /** Fits a Dirichlet to the stored multinomials. */
        public abstract Dirichlet estimate();

        /** Throws unless {@code a} and {@code b} have equal size and the identical Alphabet instance. */
        private static void checkCompatible(Multinomial a, Multinomial b) {
            if (a.size() != b.size() || a.getAlphabet() != b.getAlphabet()) {
                throw new IllegalArgumentException("All multinomials must have same size and Alphabet.");
            }
        }
    }

    /**
     * Method-of-moments Dirichlet estimator. Currently incomplete: it computes a
     * normalized mean vector, but the concentration (variance-matching) step is
     * missing, so {@link #estimate()} always throws.
     */
    public static class MethodOfMomentsEstimator
    extends Estimator {
        /**
         * Not yet implemented.
         *
         * @throws UnsupportedOperationException always, once the mean has been computed
         * @throws IndexOutOfBoundsException if no multinomials have been added
         */
        @Override
        public Dirichlet estimate() {
            int dims = this.multinomials.get(0).size();
            double[] alphas = new double[dims];
            // Every training multinomial contributes, including the first (the loop used to start at index 1).
            // Presumably addProbabilitiesTo accumulates each multinomial's probabilities into alphas; TODO confirm.
            for (Multinomial m : this.multinomials) {
                m.addProbabilitiesTo(alphas);
            }
            double alphaSum = 0.0;
            for (int i = 0; i < alphas.length; i++) {
                alphaSum += alphas[i];
            }
            for (int i = 0; i < alphas.length; i++) {
                alphas[i] /= alphaSum;
            }
            // TODO: scale alphas by a concentration fitted to the observed variance, then return a Dirichlet.
            throw new UnsupportedOperationException("Not yet implemented.");
        }
    }
}

