package org.apache.mahout.math.random;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/random/ChineseRestaurantTest.class */
/**
 * Statistical tests for {@link ChineseRestaurant}: the rank-ordered distribution of table sizes,
 * the behaviour with a discount of 1, and how the table count and the fraction of single-customer
 * tables grow with the number of samples for several discounts.
 */
public final class ChineseRestaurantTest extends MahoutTestCase {

    /**
     * Draws 100 samples from each of 1000 independent restaurants (alpha = 10), sorts each
     * restaurant's table sizes in descending order and sums them rank by rank across all
     * restaurants. The summed sizes at a few ranks are then checked against fixed expectations.
     */
    @Test
    public void testDepth() {
        // totals.get(r) = sum over all restaurants of the size of that restaurant's r-th largest table
        List<Integer> totals = Lists.newArrayList();
        for (int trial = 0; trial < 1000; trial++) {
            ChineseRestaurant restaurant = new ChineseRestaurant(10.0);
            Multiset<Integer> seating = HashMultiset.create();
            for (int customer = 0; customer < 100; customer++) {
                seating.add(restaurant.sample());
            }

            List<Integer> sizes = Lists.newArrayList();
            for (Integer table : seating.elementSet()) {
                sizes.add(seating.count(table));
            }
            Collections.sort(sizes, Collections.reverseOrder());

            // Make sure every rank seen in this trial has a slot in the running totals.
            while (totals.size() < sizes.size()) {
                totals.add(0);
            }
            for (int rank = 0; rank < sizes.size(); rank++) {
                totals.set(rank, totals.get(rank) + sizes.get(rank));
            }
        }
        assertEquals(25000.0, totals.get(0).intValue(), 1000.0);
        assertEquals(24000.0, totals.get(1).intValue(), 1000.0);
        assertEquals(8000.0, totals.get(2).intValue(), 200.0);
        assertEquals(1000.0, totals.get(15).intValue(), 50.0);
        assertEquals(1000.0, totals.get(20).intValue(), 40.0);
    }

    /**
     * With alpha = 100 and a discount of 1, 10000 samples are expected to produce 10000 tables,
     * each with a count of exactly one.
     */
    @Test
    public void testExtremeDiscount() {
        ChineseRestaurant restaurant = new ChineseRestaurant(100.0, 1.0);
        for (int i = 0; i < 10000; i++) {
            restaurant.sample();
        }
        assertEquals(10000L, restaurant.size());
        for (int table = 0; table < 10000; table++) {
            assertEquals(1L, restaurant.count(table));
        }
    }

    /**
     * Samples three restaurants (discounts 0, 0.5 and 0.9) in lock step up to 200000 samples and
     * inspects them at checkpoints n = m * 10^k, with m in {1, 1.5, 2, 3, 5, 8}.
     *
     * <ul>
     *   <li>For 50 &lt; n &lt;= 900: rows [log(size()), log(n), 1] are recorded for the 0.5 and
     *       0.9 restaurants.</li>
     *   <li>For n &gt; 900: a log-log line is fitted to those rows. Its slope must match the
     *       discount, and its prediction must be within 1 of the actual log(size()).</li>
     *   <li>For n &gt; 10000: the fraction of tables with a count of one must be close to the
     *       discount.</li>
     * </ul>
     */
    @Test
    public void testGrowth() {
        ChineseRestaurant s0 = new ChineseRestaurant(10.0, 0.0);
        ChineseRestaurant s5 = new ChineseRestaurant(10.0, 0.5);
        ChineseRestaurant s9 = new ChineseRestaurant(10.0, 0.9);
        ImmutableSet<Double> mantissas = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0);

        int rows = 0;
        Matrix m5 = new DenseMatrix(20, 3);
        Matrix m9 = new DenseMatrix(20, 3);
        for (int n = 0; n <= 200000; n++) {
            // Leading "digit" of n in base 10. For n = 0 this is NaN, which is never a checkpoint.
            double mantissa = n / Math.pow(10.0, Math.floor(Math.log10(n)));
            if (mantissas.contains(mantissa)) {
                if (n > 900) {
                    double predicted5 = predictSize(m5.viewPart(0, rows, 0, 3), n, 0.5);
                    assertEquals(predicted5, Math.log(s5.size()), 1.0);
                    double predicted9 = predictSize(m9.viewPart(0, rows, 0, 3), n, 0.9);
                    assertEquals(predicted9, Math.log(s9.size()), 1.0);
                } else if (n > 50) {
                    m5.viewRow(rows).assign(new double[] {Math.log(s5.size()), Math.log(n), 1.0});
                    m9.viewRow(rows).assign(new double[] {Math.log(s9.size()), Math.log(n), 1.0});
                    rows++;
                }
                if (n > 10000) {
                    // Cast before dividing: both operands are integral, so a plain '/' would
                    // truncate the fraction to 0 and the 0.5 / 0.9 checks could never pass.
                    assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25);
                    assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1);
                    assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05);
                }
            }
            s0.sample();
            s5.sample();
            s9.sample();
        }
    }

    /**
     * Fits column 0 of {@code history} against columns 1-2 by least squares and predicts
     * column 0 at {@code n}.
     *
     * @param history rows of [log(size()), log(samples), 1] recorded at earlier checkpoints
     * @param n number of samples at which to predict
     * @param expectedSlope value the fitted coefficient on log(samples) must match within 0.2
     * @return the fitted slope * log(n) + intercept
     */
    private static double predictSize(Matrix history, int n, double expectedSlope) {
        int rows = history.rowSize();
        Matrix x = history.viewPart(0, rows, 1, 2);
        Matrix y = history.viewPart(0, rows, 0, 1);
        Matrix xt = x.transpose();
        // Solve the normal equations (X'X) b = X'y with QR, then transpose b into a 1x2 row
        // [slope, intercept].
        Matrix coefficients = new QRDecomposition(xt.times(x)).solve(xt.times(y)).transpose();
        assertEquals(expectedSlope, coefficients.get(0, 0), 0.2);
        return coefficients.times(new DenseVector(new double[] {Math.log(n), 1.0})).get(0);
    }

    /**
     * Counts the indices j in [0, restaurant.size()) for which restaurant.count(j) == 1, i.e.
     * the tables with exactly one customer.
     */
    private static int hapaxCount(ChineseRestaurant restaurant) {
        int singles = 0;
        for (int table = 0; table < restaurant.size(); table++) {
            if (restaurant.count(table) == 1) {
                singles++;
            }
        }
        return singles;
    }
}
