package org.apache.mahout.clustering.dirichlet;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.NormalModel;
import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;

/* loaded from: input_file:org/apache/mahout/clustering/dirichlet/TestMapReduce.class */
/**
 * Tests for the Dirichlet clustering mapper, reducer and driver, plus
 * write/readFields round-trips of the model and cluster classes.
 *
 * <p>The mapper, reducer and driver tests contain no assertions: they pass as
 * long as the code under test completes without throwing, and they print the
 * resulting models to standard output for manual inspection. Only the
 * serialization tests assert on results.
 */
public class TestMapReduce extends MahoutTestCase {

    /** Model factory class name handed to the driver for the symmetric tests. */
    private static final String SAMPLED_NORMAL_FACTORY =
            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution";

    /** Model factory class name handed to the driver for the asymmetric test. */
    private static final String ASYMMETRIC_SAMPLED_NORMAL_FACTORY =
            "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution";

    /** Number of state directories (state-0 .. state-10) read back after a driver run. */
    private static final int STATES_TO_INSPECT = 11;

    /** Sample points accumulated by the generateSamples/addSample helpers. */
    private List<VectorWritable> sampleData = new ArrayList<VectorWritable>();
    private FileSystem fs;
    private Configuration conf;

    /**
     * Appends {@code count} two-dimensional samples with separate spreads per axis.
     *
     * @param count number of samples to append
     * @param meanX first argument passed to rNorm for the x coordinate
     * @param meanY first argument passed to rNorm for the y coordinate
     * @param sdX second argument passed to rNorm for the x coordinate
     * @param sdY second argument passed to rNorm for the y coordinate
     */
    private void generateSamples(int count, double meanX, double meanY, double sdX, double sdY) {
        System.out.println("Generating " + count + " samples m=[" + meanX + ", " + meanY + "] sd=[" + sdX + ", " + sdY + ']');
        addNormalSamples(count, meanX, meanY, sdX, sdY);
    }

    /**
     * Stores the given coordinates as one sample vector.
     *
     * <p>The vector is always created with size 2, so callers must not pass more
     * than two values.
     *
     * @param values coordinates of the sample
     */
    private void addSample(double[] values) {
        DenseVector point = new DenseVector(2);
        for (int index = 0; index < values.length; index++) {
            point.setQuick(index, values[index]);
        }
        sampleData.add(new VectorWritable(point));
    }

    /**
     * Appends {@code count} two-dimensional samples sharing one spread on both axes.
     *
     * @param count number of samples to append
     * @param meanX first argument passed to rNorm for the x coordinate
     * @param meanY first argument passed to rNorm for the y coordinate
     * @param sd second argument passed to rNorm for both coordinates
     */
    private void generateSamples(int count, double meanX, double meanY, double sd) {
        System.out.println("Generating " + count + " samples m=[" + meanX + ", " + meanY + "] sd=" + sd);
        addNormalSamples(count, meanX, meanY, sd, sd);
    }

    /** Draws {@code count} (x, y) pairs from UncommonDistributions.rNorm, x first, and stores each. */
    private void addNormalSamples(int count, double meanX, double meanY, double sdX, double sdY) {
        for (int n = 0; n < count; n++) {
            addSample(new double[]{UncommonDistributions.rNorm(meanX, sdX), UncommonDistributions.rNorm(meanY, sdY)});
        }
    }

    /**
     * Removes the "output" and "input" paths left by earlier runs, creates the
     * Hadoop configuration and file system handles, and recreates the local
     * "input" directory.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.apache.mahout.common.MahoutTestCase
    public void setUp() throws Exception {
        super.setUp();
        // NOTE(review): presumably fixes the random seed so runs are repeatable - confirm in RandomUtils.
        RandomUtils.useTestSeed();
        ClusteringTestUtils.rmr("output");
        ClusteringTestUtils.rmr("input");
        conf = new Configuration();
        fs = FileSystem.get(conf);
        new File("input").mkdir();
    }

    /**
     * Deletes every entry directly inside the local "input" directory.
     *
     * <p>{@link File#listFiles()} returns {@code null} when the path is not a
     * directory or cannot be read; that case is treated as "nothing to delete"
     * instead of failing with a NullPointerException. Deletion itself is best
     * effort, as it was before.
     */
    private static void clearInputDirectory() {
        File[] leftovers = new File("input").listFiles();
        if (leftovers == null) {
            return;
        }
        for (File leftover : leftovers) {
            leftover.delete();
        }
    }

    /**
     * Configures a fresh mapper with {@code state} and feeds it every sample in
     * {@link #sampleData}.
     *
     * @param state Dirichlet state the mapper is configured with
     * @return the collector holding everything the mapper emitted
     */
    private DummyOutputCollector<Text, VectorWritable> mapSamples(DirichletState<VectorWritable> state)
            throws Exception {
        DirichletMapper mapper = new DirichletMapper();
        mapper.configure(state);
        DummyOutputCollector<Text, VectorWritable> mapOutput = new DummyOutputCollector<Text, VectorWritable>();
        for (VectorWritable sample : sampleData) {
            mapper.map((WritableComparable<?>) null, sample, mapOutput, (Reporter) null);
        }
        return mapOutput;
    }

    /**
     * Configures a fresh reducer with {@code state}, reduces every key found in
     * {@code mapOutput}, and returns the reducer's new models.
     *
     * @param state Dirichlet state the reducer is configured with
     * @param mapOutput mapper output to reduce, one reduce call per key
     * @return the models reported by the reducer after all keys were processed
     */
    private static Model<VectorWritable>[] reduceToModels(DirichletState<VectorWritable> state,
            DummyOutputCollector<Text, VectorWritable> mapOutput) throws Exception {
        DirichletReducer reducer = new DirichletReducer();
        reducer.configure(state);
        DummyOutputCollector<Text, DirichletCluster<VectorWritable>> reduceOutput =
                new DummyOutputCollector<Text, DirichletCluster<VectorWritable>>();
        for (String key : mapOutput.getKeys()) {
            reducer.reduce(new Text(key), mapOutput.getValue(key).iterator(), reduceOutput, (Reporter) null);
        }
        return reducer.getNewModels();
    }

    /**
     * Reads the cluster lists stored under "output/state-0" .. "output/state-10".
     *
     * <p>NOTE(review): the driver tests that request 15 iterations still inspect
     * only these first 11 states, exactly as the original code did.
     *
     * @param modelFactory model factory class name placed in the job configuration
     * @return one cluster list per state directory, in state order
     */
    private static List<List<DirichletCluster<VectorWritable>>> readStates(String modelFactory) throws Exception {
        JobConf job = new JobConf(KMeansDriver.class);
        job.set("org.apache.mahout.clustering.dirichlet.modelFactory", modelFactory);
        job.set("org.apache.mahout.clustering.dirichlet.modelPrototype", "org.apache.mahout.math.DenseVector");
        job.set("org.apache.mahout.clustering.dirichlet.prototypeSize", "2");
        job.set("org.apache.mahout.clustering.dirichlet.numClusters", "20");
        job.set("org.apache.mahout.clustering.dirichlet.alpha_0", "1.0");
        List<List<DirichletCluster<VectorWritable>>> states = new ArrayList<List<DirichletCluster<VectorWritable>>>();
        for (int stateIndex = 0; stateIndex < STATES_TO_INSPECT; stateIndex++) {
            job.set("org.apache.mahout.clustering.dirichlet.stateIn", "output/state-" + stateIndex);
            states.add(DirichletMapper.getDirichletState(job).getClusters());
        }
        return states;
    }

    /** Runs the mapper over 10 samples; passes if no exception is thrown. */
    public void testMapper() throws Exception {
        generateSamples(10, 0.0d, 0.0d, 1.0d);
        DirichletState<VectorWritable> state = new DirichletState<VectorWritable>(
                new NormalModelDistribution(new VectorWritable(new DenseVector(2))), 5, 1.0d);
        mapSamples(state);
    }

    /** Runs one map pass and one reduce pass over 400 samples and updates the state with the result. */
    public void testReducer() throws Exception {
        generateSamples(100, 0.0d, 0.0d, 1.0d);
        generateSamples(100, 2.0d, 0.0d, 1.0d);
        generateSamples(100, 0.0d, 2.0d, 1.0d);
        generateSamples(100, 2.0d, 2.0d, 1.0d);
        DirichletState<VectorWritable> state = new DirichletState<VectorWritable>(
                new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1.0d);
        DummyOutputCollector<Text, VectorWritable> mapOutput = mapSamples(state);
        state.update(reduceToModels(state, mapOutput));
    }

    /**
     * Prints, for each model array, every model whose count exceeds {@code significant}.
     *
     * @param result model arrays, one per iteration
     * @param significant models with {@code count()} at or below this value are skipped
     */
    private static void printModels(Iterable<Model<VectorWritable>[]> result, int significant) {
        int row = 0;
        for (Model<VectorWritable>[] models : result) {
            System.out.print("sample[" + row + "]= ");
            row++;
            for (int k = 0; k < models.length; k++) {
                Model<VectorWritable> model = models[k];
                if (model.count() > significant) {
                    System.out.print("m" + k + model.toString() + ", ");
                }
            }
            System.out.println();
        }
        System.out.println();
    }

    /** Runs 10 in-process map/reduce iterations over 400 samples and prints the models of each. */
    public void testMRIterations() throws Exception {
        generateSamples(100, 0.0d, 0.0d, 1.0d);
        generateSamples(100, 2.0d, 0.0d, 1.0d);
        generateSamples(100, 0.0d, 2.0d, 1.0d);
        generateSamples(100, 2.0d, 2.0d, 1.0d);
        DirichletState<VectorWritable> state = new DirichletState<VectorWritable>(
                new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1.0d);
        List<Model<VectorWritable>[]> modelsPerIteration = new ArrayList<Model<VectorWritable>[]>();
        for (int iteration = 0; iteration < 10; iteration++) {
            DummyOutputCollector<Text, VectorWritable> mapOutput = mapSamples(state);
            Model<VectorWritable>[] newModels = reduceToModels(state, mapOutput);
            state.update(newModels);
            modelsPerIteration.add(newModels);
        }
        printModels(modelsPerIteration, 0);
    }

    /**
     * Writes 400 samples to a single input file, runs the driver, and prints the stored states.
     *
     * <p>NOTE(review): the trailing runJob arguments look like (numClusters,
     * iterations, alpha_0, numReducers) - confirm against DirichletDriver.
     */
    public void testDriverMRIterations() throws Exception {
        clearInputDirectory();
        generateSamples(100, 0.0d, 0.0d, 0.5d);
        generateSamples(100, 2.0d, 0.0d, 0.2d);
        generateSamples(100, 0.0d, 2.0d, 0.3d);
        generateSamples(100, 2.0d, 2.0d, 1.0d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data.txt", fs, conf);
        DirichletDriver.runJob("input", "output", SAMPLED_NORMAL_FACTORY, 20, 10, 1.0d, 1);
        printResults(readStates(SAMPLED_NORMAL_FACTORY), 0);
    }

    /**
     * Prints, for each cluster list, every cluster whose model count exceeds {@code significant}.
     *
     * @param clusters cluster lists, one per stored state
     * @param significant clusters whose model {@code count()} is at or below this value are skipped
     */
    private static void printResults(List<List<DirichletCluster<VectorWritable>>> clusters, int significant) {
        int row = 0;
        for (List<DirichletCluster<VectorWritable>> state : clusters) {
            System.out.print("sample[" + row + "]= ");
            row++;
            for (int k = 0; k < state.size(); k++) {
                DirichletCluster<VectorWritable> cluster = state.get(k);
                Model<VectorWritable> model = cluster.getModel();
                if (model.count() > significant) {
                    System.out.print("m" + k + '(' + ((int) cluster.getTotalCount()) + ')' + model.toString() + ", ");
                }
            }
            System.out.println();
        }
        System.out.println();
    }

    /** Writes four input files of 500 samples each, runs the driver with a final argument of 1, and prints the states. */
    public void testDriverMnRIterations() throws Exception {
        clearInputDirectory();
        generate4Datasets();
        DirichletDriver.runJob("input", "output", SAMPLED_NORMAL_FACTORY, 20, 15, 1.0d, 1);
        printResults(readStates(SAMPLED_NORMAL_FACTORY), 0);
    }

    /**
     * Writes four files, input/data1.txt .. input/data4.txt, each holding 500
     * freshly generated samples. The sample list is replaced between files so
     * each file only contains its own data set.
     */
    private void generate4Datasets() throws IOException {
        generateSamples(500, 0.0d, 0.0d, 0.5d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data1.txt", fs, conf);
        sampleData = new ArrayList<VectorWritable>();
        generateSamples(500, 2.0d, 0.0d, 0.2d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data2.txt", fs, conf);
        sampleData = new ArrayList<VectorWritable>();
        generateSamples(500, 0.0d, 2.0d, 0.3d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data3.txt", fs, conf);
        sampleData = new ArrayList<VectorWritable>();
        generateSamples(500, 2.0d, 2.0d, 1.0d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data4.txt", fs, conf);
    }

    /** Same as {@link #testDriverMnRIterations()} but passes 2 as the final runJob argument. */
    public void testDriverMnRnIterations() throws Exception {
        clearInputDirectory();
        generate4Datasets();
        DirichletDriver.runJob("input", "output", SAMPLED_NORMAL_FACTORY, 20, 15, 1.0d, 2);
        printResults(readStates(SAMPLED_NORMAL_FACTORY), 0);
    }

    /**
     * Runs the driver with the asymmetric model factory over four input files;
     * the first file is generated with different spreads on the two axes.
     */
    public void testDriverMnRnIterationsAsymmetric() throws Exception {
        clearInputDirectory();
        generateSamples(500, 0.0d, 0.0d, 0.5d, 1.0d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data1.txt", fs, conf);
        sampleData = new ArrayList<VectorWritable>();
        generateSamples(500, 2.0d, 0.0d, 0.2d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data2.txt", fs, conf);
        sampleData = new ArrayList<VectorWritable>();
        generateSamples(500, 0.0d, 2.0d, 0.3d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data3.txt", fs, conf);
        sampleData = new ArrayList<VectorWritable>();
        generateSamples(500, 2.0d, 2.0d, 1.0d);
        ClusteringTestUtils.writePointsToFile(sampleData, "input/data4.txt", fs, conf);
        DirichletDriver.runJob("input", "output", ASYMMETRIC_SAMPLED_NORMAL_FACTORY, 20, 15, 1.0d, 2);
        printResults(readStates(ASYMMETRIC_SAMPLED_NORMAL_FACTORY), 0);
    }

    /** Writes a NormalModel to a buffer, reads it into a new instance, and compares their string forms. */
    public void testNormalModelWritableSerialization() throws Exception {
        NormalModel original = new NormalModel(new DenseVector(new double[]{1.1d, 2.2d, 3.3d}), 3.3d);
        DataOutputBuffer out = new DataOutputBuffer();
        original.write(out);
        NormalModel restored = new NormalModel();
        DataInputBuffer in = new DataInputBuffer();
        in.reset(out.getData(), out.getLength());
        restored.readFields(in);
        assertEquals("models", original.toString(), restored.toString());
    }

    /** Writes a SampledNormalModel to a buffer, reads it into a new instance, and compares their string forms. */
    public void testSampledNormalModelWritableSerialization() throws Exception {
        SampledNormalModel original = new SampledNormalModel(new DenseVector(new double[]{1.1d, 2.2d, 3.3d}), 3.3d);
        DataOutputBuffer out = new DataOutputBuffer();
        original.write(out);
        SampledNormalModel restored = new SampledNormalModel();
        DataInputBuffer in = new DataInputBuffer();
        in.reset(out.getData(), out.getLength());
        restored.readFields(in);
        assertEquals("models", original.toString(), restored.toString());
    }

    /** Writes an AsymmetricSampledNormalModel to a buffer, reads it into a new instance, and compares their string forms. */
    public void testAsymmetricSampledNormalModelWritableSerialization() throws Exception {
        AsymmetricSampledNormalModel original = new AsymmetricSampledNormalModel(
                new DenseVector(new double[]{1.1d, 2.2d, 3.3d}), new DenseVector(new double[]{3.3d, 4.4d, 5.5d}));
        DataOutputBuffer out = new DataOutputBuffer();
        original.write(out);
        AsymmetricSampledNormalModel restored = new AsymmetricSampledNormalModel();
        DataInputBuffer in = new DataInputBuffer();
        in.reset(out.getData(), out.getLength());
        restored.readFields(in);
        assertEquals("models", original.toString(), restored.toString());
    }

    /**
     * Writes a DirichletCluster to a buffer, reads it into a new instance, and
     * checks that the total count and the model's string form are preserved.
     */
    public void testClusterWritableSerialization() throws Exception {
        DirichletCluster<VectorWritable> original = new DirichletCluster<VectorWritable>(
                new NormalModel(new DenseVector(new double[]{1.1d, 2.2d, 3.3d}), 4.0d), 10.0d);
        DataOutputBuffer out = new DataOutputBuffer();
        original.write(out);
        DirichletCluster<VectorWritable> restored = new DirichletCluster<VectorWritable>();
        DataInputBuffer in = new DataInputBuffer();
        in.reset(out.getData(), out.getLength());
        restored.readFields(in);
        assertEquals("count", Double.valueOf(original.getTotalCount()), Double.valueOf(restored.getTotalCount()));
        assertNotNull("model null", restored.getModel());
        assertEquals("model", original.getModel().toString(), restored.getModel().toString());
    }
}
