package org.apache.mahout.math.neighborhood;

import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.random.Normal;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.class */
/**
 * Tests for {@link LocalitySensitiveHashSearch}: recall/speed trade-off against an exhaustive
 * {@link BruteSearch}, and the relationship between hash Hamming distance and cosine similarity.
 */
public class LocalitySensitiveHashSearchTest {
    /** Number of random rows indexed in {@link #testNormal()}. */
    private static final int ROWS = 100000;
    /** Dimension of every random vector used by both tests. */
    private static final int DIMENSION = 10;
    /** Number of leading rows used as queries by {@link #evaluateStrategy}. */
    private static final int QUERIES = 100;
    /** Number of results requested from the hashed searcher per query. */
    private static final int CANDIDATES = 150;
    /** Number of exact nearest neighbours requested from the brute-force searcher per query. */
    private static final int TRUE_NEIGHBORS = 100;
    /** Rows of the projection matrix in {@link #testDotCorrelation()}; one per hash bit. */
    private static final int HASH_BITS = 64;
    /** Number of random vectors hashed in {@link #testDotCorrelation()}. */
    private static final int DOT_SAMPLES = 500000;
    /** Per-distance counter value at which {@link #testDotCorrelation()} stops printing samples. */
    private static final int PRINT_LIMIT = 200;

    /**
     * Sweeps the raise-hash-limit strategy over -0.1, 0.0, 0.1, ..., 1.0 on 100,000 Gaussian
     * vectors. For each value it prints the speedup versus an exhaustive scan and the quartiles
     * of the per-query neighbour overlap (see {@link #evaluateStrategy}). It then asserts that
     * the search is either accurate enough or fast enough.
     */
    @Test
    public void testNormal() {
        DenseMatrix data = new DenseMatrix(ROWS, DIMENSION);
        data.assign(new Normal());

        EuclideanDistanceMeasure distance = new EuclideanDistanceMeasure();
        BruteSearch reference = new BruteSearch(distance);
        reference.addAllMatrixSlicesAsWeightedVectors(data);

        LocalitySensitiveHashSearch candidate = new LocalitySensitiveHashSearch(distance, 10);
        candidate.addAllMatrixSlicesAsWeightedVectors(data);
        candidate.setSearchSize(200);
        // Start the first round of the sweep from a zero evaluation count.
        candidate.resetEvaluationCount();

        // One distance per stored row per query: the cost of scanning everything for all
        // QUERIES queries issued by evaluateStrategy (100 * 100,000 = 1e7).
        double exhaustiveEvaluations = (double) QUERIES * ROWS;

        System.out.printf("speedup,q1,q2,q3\n");
        for (int i = 0; i < 12; i++) {
            // Yields -0.1, 0.0, 0.1, ..., 1.0; computed from i each time to avoid float drift.
            double strategy = (i - 1.0d) / 10.0d;
            candidate.setRaiseHashLimitStrategy(strategy);

            OnlineSummarizer overlap = evaluateStrategy(data, reference, candidate);
            // resetEvaluationCount() returns the count accumulated by this round and zeroes it.
            double speedup = exhaustiveEvaluations / candidate.resetEvaluationCount();
            double q1 = overlap.getQuartile(1);
            double median = overlap.getQuartile(2);
            double q3 = overlap.getQuartile(3);
            System.out.printf("%.1f,%.2f,%.2f,%.2f\n", speedup, q1, median, q3);

            // NOTE(review): `median` is a raw overlap COUNT in [0, TRUE_NEIGHBORS], but these
            // thresholds look like recall fractions. As written, a median overlap of a single
            // neighbour satisfies all three. Confirm intent (e.g. divide by TRUE_NEIGHBORS)
            // before tightening, since that would make the test much stricter.
            String round = "strategy " + strategy + ": speedup " + speedup + ", median overlap " + median;
            Assert.assertTrue(round, median > 0.45d);
            Assert.assertTrue(round, speedup > 4.0d || median > 0.9d);
            Assert.assertTrue(round, speedup > 15.0d || median > 0.8d);
        }
    }

    /**
     * Measures how many true nearest neighbours the hashed searcher recovers for each of the
     * first {@value #QUERIES} rows of {@code data}.
     *
     * <p>For every query row:
     * <ul>
     *   <li>the hashed searcher is asked for {@value #CANDIDATES} results;</li>
     *   <li>the brute-force searcher is asked for the exact top {@value #TRUE_NEIGHBORS};</li>
     *   <li>the size of the intersection of the two index sets is added to the summary. It is an
     *       integer between 0 and {@value #TRUE_NEIGHBORS}.</li>
     * </ul>
     *
     * @param data the indexed vectors; its leading rows double as queries
     * @param reference exhaustive searcher providing the ground-truth neighbours
     * @param candidate hashed searcher under evaluation
     * @return summary of the per-query overlap counts
     */
    private static OnlineSummarizer evaluateStrategy(Matrix data, BruteSearch reference,
                                                     LocalitySensitiveHashSearch candidate) {
        OnlineSummarizer overlap = new OnlineSummarizer();
        for (int row = 0; row < QUERIES; row++) {
            Vector query = data.viewRow(row);
            BitSet found = indicesOf(candidate.search(query, CANDIDATES));
            BitSet truth = indicesOf(reference.search(query, TRUE_NEIGHBORS));
            found.and(truth);
            overlap.add(found.cardinality());
        }
        return overlap;
    }

    /**
     * Collects the row index of every search result into a bit set. The searchers in this class
     * were populated via {@code addAllMatrixSlicesAsWeightedVectors}, so each result value is
     * cast to {@link WeightedVector} to read its index.
     *
     * @param results search results to collect
     * @return a bit set with one bit set per distinct result index
     */
    private static BitSet indicesOf(List<WeightedThing<Vector>> results) {
        BitSet indices = new BitSet();
        for (WeightedThing<Vector> result : results) {
            indices.set(((WeightedVector) result.getValue()).getIndex());
        }
        return indices;
    }

    /**
     * Exploratory output only (no assertions). The test proceeds as follows:
     * <ol>
     *   <li>Hashes one random query with a random {@value #HASH_BITS}x{@value #DIMENSION}
     *       projection.</li>
     *   <li>Hashes {@value #DOT_SAMPLES} random vectors with the same projection.</li>
     *   <li>Prints the Hamming distance between each sample's hash and the query's hash, paired
     *       with the cosine similarity of the two vectors. At most 199 pairs are printed per
     *       distance.</li>
     *   <li>Prints a histogram of how often each distance occurred.</li>
     * </ol>
     */
    @Test
    public void testDotCorrelation() {
        Normal gen = new Normal();
        DenseMatrix projection = new DenseMatrix(HASH_BITS, DIMENSION);
        projection.assign(gen);

        DenseVector query = new DenseVector(DIMENSION);
        query.assign(gen);
        long queryHash = HashedVector.computeHash64(query, projection);

        // XOR of two 64-bit hashes has between 0 and 64 set bits, hence 65 buckets.
        int[] histogram = new int[HASH_BITS + 1];
        DenseVector sample = new DenseVector(DIMENSION);
        for (int i = 0; i < DOT_SAMPLES; i++) {
            sample.assign(gen);
            int hamming = Long.bitCount(queryHash ^ HashedVector.computeHash64(sample, projection));
            histogram[hamming]++;
            // The counter is bumped before this check, so at most PRINT_LIMIT - 1 samples
            // are printed per distance.
            if (histogram[hamming] < PRINT_LIMIT) {
                double cosine = sample.dot(query) / Math.sqrt(sample.getLengthSquared() * query.getLengthSquared());
                System.out.printf("%d, %.3f\n", hamming, cosine);
            }
        }
        for (int d = 0; d <= HASH_BITS; d++) {
            System.out.printf("%d, %d\n", d, histogram[d]);
        }
    }
}
