package org.apache.mahout.clustering.streaming.cluster;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.ConstantVector;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SingularValueDecomposition;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.VectorFunction;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.random.MultiNormal;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link BallKMeans} run against synthetic data.
 *
 * <p>The shared {@code syntheticData} pair is produced by
 * {@code DataUtils.sampleMultiNormalHypercube}. Its first element is the list of points that
 * gets clustered; its second element is used as the set of reference locations that the
 * resulting centroids are compared with (presumably the true cluster means at the hypercube
 * corners - confirm against {@code DataUtils}).
 */
public class BallKMeansTest {
    private static final int NUM_ITERATIONS = 20;
    private static final int K1 = 100;
    private static final int NUM_DIMENSIONS = 4;
    private static final int NUM_DATA_POINTS = 10000;
    private static final double DISTRIBUTION_RADIUS = 0.01d;

    /** Number of clusters requested for the hypercube data: 2^NUM_DIMENSIONS (16). */
    private static final int NUM_CLUSTERS = 1 << NUM_DIMENSIONS;

    /** Highest "number of runs" value exercised by {@link #testClusteringMultipleRuns()}. */
    private static final int MAX_NUM_RUNS = 10;

    private static final Pair<List<Centroid>, List<Centroid>> syntheticData =
        DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, DISTRIBUTION_RADIUS);

    /**
     * Clusters the synthetic data twice for every run count from 1 to {@code MAX_NUM_RUNS},
     * once with the boolean constructor flag set and once with it cleared, and asserts that the
     * total cluster cost with the flag set is strictly lower. Per the messages printed below,
     * the flag selects k-means++ seeding as opposed to random seeding.
     */
    @Test
    public void testClusteringMultipleRuns() {
        List<Centroid> points = syntheticData.getFirst();
        for (int numRuns = 1; numRuns <= MAX_NUM_RUNS; numRuns++) {
            BallKMeans plusPlusClusterer = new BallKMeans(
                new BruteSearch(new SquaredEuclideanDistanceMeasure()), NUM_CLUSTERS, NUM_ITERATIONS, true, numRuns);
            plusPlusClusterer.cluster(points);
            double plusPlusCost = ClusteringUtils.totalClusterCost(points, plusPlusClusterer);

            BallKMeans randomClusterer = new BallKMeans(
                new BruteSearch(new SquaredEuclideanDistanceMeasure()), NUM_CLUSTERS, NUM_ITERATIONS, false, numRuns);
            randomClusterer.cluster(points);
            double randomCost = ClusteringUtils.totalClusterCost(points, randomClusterer);

            System.out.printf("%d runs; kmeans++: %f; random: %f\n", numRuns, plusPlusCost, randomCost);
            Assert.assertTrue("kmeans++ cost should be less than random cost", plusPlusCost < randomCost);
        }
    }

    /**
     * Clusters the synthetic data into {@code NUM_CLUSTERS} clusters and checks that:
     * <ul>
     *   <li>the total weight of the centroids equals the total weight of the input points;</li>
     *   <li>the median weight returned by a nearest-neighbour query for each reference point is
     *       below {@code DISTRIBUTION_RADIUS};</li>
     *   <li>the centroid weight accumulated per reference point is exactly
     *       {@code NUM_DATA_POINTS / NUM_CLUSTERS}.</li>
     * </ul>
     * Timing information is printed to standard output but not asserted on.
     */
    @Test
    public void testClustering() {
        List<Centroid> points = syntheticData.getFirst();
        List<Centroid> references = syntheticData.getSecond();

        BruteSearch searcher = new BruteSearch(new SquaredEuclideanDistanceMeasure());
        BallKMeans clusterer = new BallKMeans(searcher, NUM_CLUSTERS, NUM_ITERATIONS);

        long startTime = System.currentTimeMillis();
        clusterer.cluster(points);
        long endTime = System.currentTimeMillis();

        Assert.assertEquals("Total weight not preserved",
            ClusteringUtils.totalWeight(points), ClusteringUtils.totalWeight(clusterer), 1.0E-9d);

        // The searcher handed to BallKMeans is queried after clustering, so it is expected to
        // hold the resulting centroids; the weight of the top hit is presumably the distance
        // from the reference point to its nearest centroid.
        OnlineSummarizer summarizer = new OnlineSummarizer();
        for (Vector reference : references) {
            WeightedThing<Vector> nearest = searcher.search(reference, 1).get(0);
            summarizer.add(nearest.getWeight());
        }
        Assert.assertTrue(
            String.format("Median weight [%f] too large [>%f]", summarizer.getMedian(), DISTRIBUTION_RADIUS),
            summarizer.getMedian() < DISTRIBUTION_RADIUS);

        double clusterTime = (endTime - startTime) / 1000.0d;
        System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n",
            searcher.getClass().getName(), clusterTime, (clusterTime / points.size()) * 1000000.0d);

        // Attribute every centroid's weight to the reference point closest to it.
        double[] cornerWeights = new double[NUM_CLUSTERS];
        BruteSearch referenceFinder = new BruteSearch(new EuclideanDistanceMeasure());
        for (Vector reference : references) {
            referenceFinder.add(reference);
        }
        for (Centroid centroid : clusterer) {
            WeightedThing<Vector> closest = referenceFinder.search(centroid, 1).get(0);
            // The reference vectors were added as Centroid instances, hence the cast.
            int corner = ((Centroid) closest.getValue()).getIndex();
            cornerWeights[corner] += centroid.getWeight();
        }

        for (double weight : cornerWeights) {
            System.out.printf("%f ", weight);
        }
        System.out.println();

        // 10000 points over 16 clusters: 625 points expected per reference point.
        int expectedPointsPerCorner = NUM_DATA_POINTS / NUM_CLUSTERS;
        for (double weight : cornerWeights) {
            Assert.assertEquals(expectedPointsPerCorner, weight, 0.0d);
        }
    }

    /**
     * Clusters the data from {@link #cubishTestData(double)} into six clusters and inspects
     * the first five coordinates of every resulting centroid.
     *
     * <p>Each of those five coordinate columns must have a minimum near 0, a maximum near 6 and
     * an L1 norm near 6. The singular values of the centroid matrix, each divided by 6, must
     * have both a sum of squares and an L1 norm within 0.05 of 5.
     */
    @Test
    public void testInitialization() {
        List<? extends WeightedVector> data = cubishTestData(DISTRIBUTION_RADIUS);
        BallKMeans clusterer = new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()), 6, NUM_ITERATIONS);
        clusterer.cluster(data);

        // One row per centroid, keeping only the first five coordinates.
        DenseMatrix centroidMatrix = new DenseMatrix(6, 5);
        int row = 0;
        for (Centroid centroid : clusterer) {
            centroidMatrix.viewRow(row).assign(centroid.viewPart(0, 5));
            row++;
        }

        // Per-column discrepancy: sum of three absolute deviations from the expected shape.
        Vector columnErrors = centroidMatrix.aggregateColumns(new VectorFunction() {
            @Override
            public double apply(Vector column) {
                return Math.abs(column.minValue())
                    + Math.abs(column.maxValue() - 6.0d)
                    + Math.abs(column.norm(1.0d) - 6.0d);
            }
        });
        // The mean discrepancy across columns must be close to zero.
        Assert.assertEquals(0.0d, columnErrors.norm(1.0d) / columnErrors.size(), 0.1d);

        Vector scaledSingularValues =
            new SingularValueDecomposition(centroidMatrix).getS().viewDiagonal().assign(Functions.div(6.0d));
        Assert.assertEquals(5.0d, scaledSingularValues.getLengthSquared(), 0.05d);
        Assert.assertEquals(5.0d, scaledSingularValues.norm(1.0d), 0.05d);
    }

    /**
     * Builds a list of 10-dimensional weighted vectors for {@link #testInitialization()}.
     *
     * <p>The list holds {@code K1} samples from a {@code MultiNormal} whose mean is the all-zero
     * vector, followed by 1000 samples from each of five {@code MultiNormal}s whose mean is
     * zero except for a value of 6 in coordinate 0, 1, 2, 3 or 4 respectively. Every vector has
     * weight 1 and a sequential index starting at 0.
     *
     * @param d the first constructor argument passed to every {@code MultiNormal}
     *          (presumably the radius of the distribution)
     * @return the generated vectors, {@code K1 + 5000} in total
     */
    private static List<? extends WeightedVector> cubishTestData(double d) {
        final int dimensions = 10;
        final int offsetClusters = 5;
        final int pointsPerOffsetCluster = 1000;

        List<WeightedVector> data = Lists.newArrayListWithCapacity(K1 + offsetClusters * pointsPerOffsetCluster);
        int index = 0;

        MultiNormal originSampler = new MultiNormal(d, new ConstantVector(0.0d, dimensions));
        for (int i = 0; i < K1; i++) {
            data.add(new WeightedVector(originSampler.sample(), 1.0d, index));
            index++;
        }

        for (int axis = 0; axis < offsetClusters; axis++) {
            DenseVector mean = new DenseVector(dimensions);
            mean.set(axis, 6.0d);
            MultiNormal offsetSampler = new MultiNormal(d, mean);
            for (int j = 0; j < pointsPerOffsetCluster; j++) {
                data.add(new WeightedVector(offsetSampler.sample(), 1.0d, index));
                index++;
            }
        }
        return data;
    }
}
