package org.apache.mahout.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Checks {@link RunningSumsGaussianAccumulator} and {@link OnlineGaussianAccumulator} against each
 * other, and against mean/std statistics computed directly from a set of randomly generated 2-d
 * samples that is rebuilt before every test.
 */
public final class TestGaussianAccumulators extends MahoutTestCase {

    private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);

    /** Samples observed by the tests; (re)built in {@link #setUp()}. */
    private Collection<VectorWritable> sampleData;

    /** Number of samples in {@link #sampleData}. */
    private int sampleN;

    /** Per-component mean of {@link #sampleData}. */
    private Vector sampleMean;

    /** Per-component standard deviation of {@link #sampleData}, using the (n - 1) divisor. */
    private Vector sampleStd;

    /**
     * Generates fresh sample data and computes the reference mean and standard deviation that
     * the accumulator results are compared with.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        sampleData = new ArrayList<VectorWritable>();
        generateSamples();

        // Reference mean: component-wise sum divided by the sample count.
        sampleN = 0;
        Vector sum = new DenseVector(2);
        for (VectorWritable sample : sampleData) {
            sample.get().addTo(sum);
            sampleN++;
        }
        sampleMean = sum.divide(sampleN);

        // Reference std: square root of the sum of squared deviations divided by (n - 1).
        Vector sumOfSquares = new DenseVector(2);
        for (VectorWritable sample : sampleData) {
            Vector delta = sample.get().minus(sampleMean);
            delta.times(delta).addTo(sumOfSquares);
        }
        // sumOfSquares is a discarded local, so no defensive copy is needed before the
        // in-place square root below.
        sampleStd = sumOfSquares.divide(sampleN - 1);
        sampleStd.assign(new SquareRootFunction());

        log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]", new Object[] {
            sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1)});
    }

    /**
     * Appends {@code count} 2-d samples to {@link #sampleData}, drawing each component
     * independently via {@code UncommonDistributions.rNorm}.
     *
     * @param count number of samples to generate
     * @param meanX mean passed to {@code rNorm} for the first component
     * @param meanY mean passed to {@code rNorm} for the second component
     * @param sdX   standard-deviation argument passed to {@code rNorm} for the first component
     * @param sdY   standard-deviation argument passed to {@code rNorm} for the second component
     */
    private void generate2dSamples(int count, double meanX, double meanY, double sdX, double sdY) {
        log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]",
            new Object[] {count, meanX, meanY, sdX, sdY});
        for (int i = 0; i < count; i++) {
            double x = UncommonDistributions.rNorm(meanX, sdX);
            double y = UncommonDistributions.rNorm(meanY, sdY);
            sampleData.add(new VectorWritable(new DenseVector(new double[] {x, y})));
        }
    }

    /** Populates {@link #sampleData} with the fixed sample configuration used by all tests. */
    private void generateSamples() {
        generate2dSamples(50000, 1.0, 2.0, 3.0, 4.0);
    }

    /** Both accumulators must agree when nothing has been observed. */
    @Test
    public void testAccumulatorNoSamples() {
        RunningSumsGaussianAccumulator runningSums = new RunningSumsGaussianAccumulator();
        OnlineGaussianAccumulator online = new OnlineGaussianAccumulator();
        runningSums.compute();
        online.compute();
        assertEquals("N", runningSums.getN(), online.getN(), EPSILON_TIGHT);
        assertEquals("Means", runningSums.getMean(), online.getMean());
        assertEquals("Avg Stds", runningSums.getAverageStd(), online.getAverageStd(), EPSILON_TIGHT);
    }

    /** Both accumulators must agree after observing a single zero vector with weight 1. */
    @Test
    public void testAccumulatorOneSample() {
        RunningSumsGaussianAccumulator runningSums = new RunningSumsGaussianAccumulator();
        OnlineGaussianAccumulator online = new OnlineGaussianAccumulator();
        Vector zero = new DenseVector(2);
        runningSums.observe(zero, 1.0);
        online.observe(zero, 1.0);
        runningSums.compute();
        online.compute();
        assertEquals("N", runningSums.getN(), online.getN(), EPSILON_TIGHT);
        assertEquals("Means", runningSums.getMean(), online.getMean());
        assertEquals("Avg Stds", runningSums.getAverageStd(), online.getAverageStd(), EPSILON_TIGHT);
    }

    /** The online accumulator must reproduce the directly computed sample statistics. */
    @Test
    public void testOLAccumulatorResults() {
        OnlineGaussianAccumulator online = new OnlineGaussianAccumulator();
        for (VectorWritable sample : sampleData) {
            online.observe(sample.get(), 1.0);
        }
        online.compute();
        log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]", new Object[] {
            online.getN(), online.getMean().get(0), online.getMean().get(1),
            online.getStd().get(0), online.getStd().get(1)});
        assertEquals("OL N", sampleN, online.getN(), EPSILON_TIGHT);
        assertEquals("OL Mean", sampleMean.zSum(), online.getMean().zSum(), EPSILON_TIGHT);
        assertEquals("OL Std", sampleStd.zSum(), online.getStd().zSum(), EPSILON_TIGHT);
    }

    /** The running-sums accumulator must reproduce the directly computed sample statistics. */
    @Test
    public void testRSAccumulatorResults() {
        RunningSumsGaussianAccumulator runningSums = new RunningSumsGaussianAccumulator();
        for (VectorWritable sample : sampleData) {
            runningSums.observe(sample.get(), 1.0);
        }
        runningSums.compute();
        log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]", new Object[] {
            (int) runningSums.getN(), runningSums.getMean().get(0), runningSums.getMean().get(1),
            runningSums.getStd().get(0), runningSums.getStd().get(1)});
        // Labels say "RS": they previously read "OL", which misattributed failures.
        assertEquals("RS N", sampleN, runningSums.getN(), EPSILON_TIGHT);
        assertEquals("RS Mean", sampleMean.zSum(), runningSums.getMean().zSum(), EPSILON_TIGHT);
        // Looser tolerance than the online test for std, as in the original test.
        assertEquals("RS Std", sampleStd.zSum(), runningSums.getStd().zSum(), 1.0E-4);
    }

    /** Both accumulators must agree when every sample carries weight 0.5. */
    @Test
    public void testAccumulatorWeightedResults() {
        assertAccumulatorsAgreeWithWeight(0.5);
    }

    /** Both accumulators must agree when every sample carries weight 1.5. */
    @Test
    public void testAccumulatorWeightedResults2() {
        assertAccumulatorsAgreeWithWeight(1.5);
    }

    /** Tolerance used for the N, mean and (online) std comparisons. */
    private static final double EPSILON_TIGHT = 1.0E-6;

    /**
     * Feeds every sample to both accumulators with the same weight and checks that N, mean,
     * std and variance agree.
     *
     * @param weight weight passed to {@code observe} for every sample
     */
    private void assertAccumulatorsAgreeWithWeight(double weight) {
        RunningSumsGaussianAccumulator runningSums = new RunningSumsGaussianAccumulator();
        OnlineGaussianAccumulator online = new OnlineGaussianAccumulator();
        for (VectorWritable sample : sampleData) {
            runningSums.observe(sample.get(), weight);
            online.observe(sample.get(), weight);
        }
        runningSums.compute();
        online.compute();
        assertEquals("N", runningSums.getN(), online.getN(), EPSILON_TIGHT);
        assertEquals("Means", runningSums.getMean().zSum(), online.getMean().zSum(), EPSILON_TIGHT);
        assertEquals("Stds", runningSums.getStd().zSum(), online.getStd().zSum(), 0.001);
        assertEquals("Variance", runningSums.getVariance().zSum(), online.getVariance().zSum(), 0.01);
    }
}
