package org.apache.mahout.clustering.dirichlet;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.MahalanobisDistanceMeasure;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/dirichlet/TestMapReduce.class */
/**
 * Runs {@link DirichletDriver} end to end over small synthetic two-dimensional data sets.
 *
 * <p>Each test generates normally distributed samples, writes them under the test's temporary
 * {@code input} directory, runs the driver for {@link #MAX_ITERATIONS} iterations and prints the
 * clusters read back from every iteration's output directory. The tests make no assertions about
 * the resulting models; they only verify that the driver runs and that its output is readable.
 *
 * <p>The {@code Seq} tests pass {@code --method sequential}; the {@code MR} tests omit the option
 * (presumably exercising the driver's default MapReduce path).
 */
public final class TestMapReduce extends MahoutTestCase {

    /** Iterations requested from the driver; also the number of cluster directories read back. */
    private static final int MAX_ITERATIONS = 5;

    /** Samples accumulated by the generate* helpers for the currently running test. */
    private final Collection<VectorWritable> sampleData = Lists.newArrayList();
    private FileSystem fs;
    private Configuration conf;

    /**
     * Wraps the given coordinates in a vector and appends it to the sample data.
     *
     * @param values coordinates of the sample; the vector takes its size from this array
     */
    private void addSample(double[] values) {
        sampleData.add(new VectorWritable(new DenseVector(values)));
    }

    /**
     * Generates two-dimensional samples whose coordinates share one standard deviation.
     *
     * @param num number of samples to generate
     * @param mx mean of the first coordinate
     * @param my mean of the second coordinate
     * @param sd standard deviation applied to both coordinates
     */
    private void generateSamples(int num, double mx, double my, double sd) {
        System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=" + sd);
        for (int n = 0; n < num; n++) {
            addSample(new double[] {UncommonDistributions.rNorm(mx, sd), UncommonDistributions.rNorm(my, sd)});
        }
    }

    /**
     * Generates two-dimensional samples with an independent standard deviation per coordinate.
     *
     * @param num number of samples to generate
     * @param mx mean of the first coordinate
     * @param my mean of the second coordinate
     * @param sdx standard deviation of the first coordinate
     * @param sdy standard deviation of the second coordinate
     */
    private void generateAsymmetricSamples(int num, double mx, double my, double sdx, double sdy) {
        System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=[" + sdx + ", " + sdy + ']');
        for (int n = 0; n < num; n++) {
            addSample(new double[] {UncommonDistributions.rNorm(mx, sdx), UncommonDistributions.rNorm(my, sdy)});
        }
    }

    /** Generates the four 100-sample groups shared by the Gaussian driver tests. */
    private void generateSymmetricTestData() {
        generateSamples(100, 0.0, 0.0, 0.5);
        generateSamples(100, 2.0, 0.0, 0.2);
        generateSamples(100, 0.0, 2.0, 0.3);
        generateSamples(100, 2.0, 2.0, 1.0);
    }

    /** Generates the two 100-sample groups shared by the Mahalanobis driver tests. */
    private void generateAsymmetricTestData() {
        generateAsymmetricSamples(100, 0.0, 0.0, 0.5, 3.0);
        generateAsymmetricSamples(100, 0.0, 3.0, 0.3, 4.0);
    }

    /**
     * Writes the parameters of a {@link MahalanobisDistanceMeasure} (inverse covariance matrix and
     * mean vector) to files under the temporary {@code mahalanobis} directory and records the file
     * locations and writable class names in {@link #conf}.
     *
     * <p>Each output stream is closed in a {@code finally} block so it is released even when the
     * write fails.
     *
     * @throws Exception if the measure cannot be configured or a parameter file cannot be written
     */
    private void writeMahalanobisParameters() throws Exception {
        DenseVector meanVector = new DenseVector(new double[] {0.0, 0.0});
        MahalanobisDistanceMeasure measure = new MahalanobisDistanceMeasure();
        measure.setMeanVector(meanVector);
        measure.setCovarianceMatrix(new DenseMatrix(new double[][] {{0.5, 0.0}, {0.0, 4.0}}));

        Path inverseCovarianceFile =
            new Path(getTestTempDirPath("mahalanobis"), "MahalanobisDistanceMeasureInverseCovarianceFile");
        conf.set("MahalanobisDistanceMeasure.inverseCovarianceFile", inverseCovarianceFile.toString());
        FileSystem inverseCovarianceFs = FileSystem.get(inverseCovarianceFile.toUri(), conf);
        MatrixWritable inverseCovariance = new MatrixWritable(measure.getInverseCovarianceMatrix());
        FSDataOutputStream out = inverseCovarianceFs.create(inverseCovarianceFile);
        try {
            inverseCovariance.write(out);
        } finally {
            Closeables.closeQuietly(out);
        }

        Path meanVectorFile =
            new Path(getTestTempDirPath("mahalanobis"), "MahalanobisDistanceMeasureMeanVectorFile");
        conf.set("MahalanobisDistanceMeasure.meanVectorFile", meanVectorFile.toString());
        FileSystem meanVectorFs = FileSystem.get(meanVectorFile.toUri(), conf);
        VectorWritable mean = new VectorWritable(meanVector);
        out = meanVectorFs.create(meanVectorFile);
        try {
            mean.write(out);
        } finally {
            Closeables.closeQuietly(out);
        }

        // NOTE(review): "maxtrixClass" is spelled this way on purpose here; it is a runtime
        // configuration key and presumably has to match what the distance measure reads.
        conf.set("MahalanobisDistanceMeasure.maxtrixClass", MatrixWritable.class.getName());
        conf.set("MahalanobisDistanceMeasure.vectorClass", VectorWritable.class.getName());
    }

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        conf = new Configuration();
        fs = FileSystem.get(conf);
    }

    /** Runs the driver with Gaussian models and {@code --method sequential}. */
    @Test
    public void testDriverIterationsSeq() throws Exception {
        generateSymmetricTestData();
        ClusteringTestUtils.writePointsToFile(sampleData, getTestTempFilePath("input/data.txt"), fs, conf);
        DistributionDescription description = new DistributionDescription(
            GaussianClusterDistribution.class.getName(), DenseVector.class.getName(), null, 2);
        Path output = getTestTempDirPath("output");
        String[] args = {
            optKey("input"), getTestTempDirPath("input").toString(),
            optKey("output"), output.toString(),
            optKey("modelDist"), description.getModelFactory(),
            optKey("modelPrototype"), description.getModelPrototype(),
            optKey("numClusters"), "20",
            optKey("maxIter"), String.valueOf(MAX_ITERATIONS),
            optKey("alpha"), "1.0",
            optKey("overwrite"),
            optKey("clustering"),
            optKey("method"), "sequential"
        };
        ToolRunner.run(conf, new DirichletDriver(), args);
        printModels(getClusters(output, MAX_ITERATIONS));
    }

    /** Runs the driver with Gaussian models and no explicit {@code --method} option. */
    @Test
    public void testDriverIterationsMR() throws Exception {
        generateSymmetricTestData();
        // The boolean argument selects a ClusteringTestUtils overload; its meaning is defined there.
        ClusteringTestUtils.writePointsToFile(sampleData, true, getTestTempFilePath("input/data.txt"), fs, conf);
        DistributionDescription description = new DistributionDescription(
            GaussianClusterDistribution.class.getName(), DenseVector.class.getName(), null, 2);
        Path output = getTestTempDirPath("output");
        String[] args = {
            optKey("input"), getTestTempDirPath("input").toString(),
            optKey("output"), output.toString(),
            optKey("modelDist"), description.getModelFactory(),
            optKey("modelPrototype"), description.getModelPrototype(),
            optKey("numClusters"), "20",
            optKey("maxIter"), String.valueOf(MAX_ITERATIONS),
            optKey("alpha"), "1.0",
            optKey("overwrite"),
            optKey("clustering")
        };
        ToolRunner.run(conf, new DirichletDriver(), args);
        printModels(getClusters(output, MAX_ITERATIONS));
    }

    /**
     * Runs the driver with distance-measure models backed by a Mahalanobis distance and
     * {@code --method sequential}. The driver is invoked directly through {@code run} rather than
     * through {@link ToolRunner}.
     */
    @Test
    public void testDriverIterationsMahalanobisSeq() throws Exception {
        generateAsymmetricTestData();
        ClusteringTestUtils.writePointsToFile(sampleData, getTestTempFilePath("input/data.txt"), fs, conf);
        DistributionDescription description = new DistributionDescription(
            DistanceMeasureClusterDistribution.class.getName(),
            DenseVector.class.getName(),
            MahalanobisDistanceMeasure.class.getName(),
            2);
        writeMahalanobisParameters();

        Path output = getTestTempDirPath("output");
        String[] args = {
            optKey("input"), getTestTempDirPath("input").toString(),
            optKey("output"), output.toString(),
            optKey("modelDist"), description.getModelFactory(),
            optKey("distanceMeasure"), description.getDistanceMeasure(),
            optKey("modelPrototype"), description.getModelPrototype(),
            optKey("numClusters"), "20",
            optKey("maxIter"), String.valueOf(MAX_ITERATIONS),
            optKey("alpha"), "1.0",
            optKey("overwrite"),
            optKey("clustering"),
            optKey("method"), "sequential"
        };
        DirichletDriver driver = new DirichletDriver();
        driver.setConf(conf);
        driver.run(args);
        printModels(getClusters(output, MAX_ITERATIONS));
    }

    /**
     * Runs the driver with distance-measure models backed by a Mahalanobis distance and no explicit
     * {@code --method} option.
     */
    @Test
    public void testDriverIterationsMahalanobisMR() throws Exception {
        generateAsymmetricTestData();
        // The boolean argument selects a ClusteringTestUtils overload; its meaning is defined there.
        ClusteringTestUtils.writePointsToFile(sampleData, true, getTestTempFilePath("input/data.txt"), fs, conf);
        DistributionDescription description = new DistributionDescription(
            DistanceMeasureClusterDistribution.class.getName(),
            DenseVector.class.getName(),
            MahalanobisDistanceMeasure.class.getName(),
            2);
        writeMahalanobisParameters();

        Path output = getTestTempDirPath("output");
        String[] args = {
            optKey("input"), getTestTempDirPath("input").toString(),
            optKey("output"), output.toString(),
            optKey("modelDist"), description.getModelFactory(),
            optKey("distanceMeasure"), description.getDistanceMeasure(),
            optKey("modelPrototype"), description.getModelPrototype(),
            optKey("numClusters"), "20",
            optKey("maxIter"), String.valueOf(MAX_ITERATIONS),
            optKey("alpha"), "1.0",
            optKey("overwrite"),
            optKey("clustering")
        };
        DirichletDriver driver = new DirichletDriver();
        driver.setConf(conf);
        ToolRunner.run(conf, driver, args);
        printModels(getClusters(output, MAX_ITERATIONS));
    }

    /**
     * Prints one line per iteration listing the formatted models of that iteration.
     *
     * @param result clusters per iteration, as returned by {@link #getClusters(Path, int)}
     */
    private void printModels(List<List<Cluster>> result) {
        StringBuilder models = new StringBuilder(100);
        int sampleIndex = 0;
        for (List<Cluster> iterationModels : result) {
            models.append("sample[").append(sampleIndex).append("]= ");
            sampleIndex++;
            for (int m = 0; m < iterationModels.size(); m++) {
                models.append('m').append(m).append(iterationModels.get(m).asFormatString(null)).append(", ");
            }
            models.append('\n');
        }
        models.append('\n');
        System.out.println(models.toString());
    }

    /**
     * Reads the clusters the driver wrote for iterations {@code 1..numIterations}.
     *
     * <p>Iteration {@code i} is read from {@code clusters-i}, except the last one, which is read
     * from {@code clusters-<numIterations>-final}.
     *
     * @param output driver output directory
     * @param numIterations number of iterations whose cluster directories are read
     * @return one list of clusters per iteration, in iteration order
     * @throws IOException if a cluster directory cannot be read
     */
    private List<List<Cluster>> getClusters(Path output, int numIterations) throws IOException {
        List<List<Cluster>> clustersPerIteration = Lists.newArrayList();
        for (int iteration = 1; iteration <= numIterations; iteration++) {
            String dirName = "clusters-" + iteration;
            if (iteration == numIterations) {
                dirName = dirName + "-final";
            }
            ClusterClassifier classifier = new ClusterClassifier();
            classifier.readFromSeqFiles(conf, new Path(output, dirName));
            clustersPerIteration.add(new ArrayList<Cluster>(classifier.getModels()));
        }
        return clustersPerIteration;
    }
}
