package org.apache.mahout.clustering.classify;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.class */
/**
 * Tests {@link ClusterClassificationDriver} end to end.
 *
 * <p>Each test does the following:
 * <ol>
 *   <li>Writes the nine 2-D {@link #REFERENCE} points to a temp directory.</li>
 *   <li>Clusters them with {@link CanopyDriver} (Manhattan distance, 3.1 / 2.1 thresholds).</li>
 *   <li>Classifies the same points against the resulting clusters.</li>
 *   <li>Checks which points ended up in which cluster id, with and without outlier removal.</li>
 * </ol>
 */
public class ClusterClassificationDriverTest extends MahoutTestCase {

    /** Nine 2-D points; the no-outlier-removal test expects them to split into groups of 3, 4 and 2. */
    private static final double[][] REFERENCE = {
        {1.0, 1.0}, {2.0, 1.0}, {1.0, 2.0},
        {4.0, 4.0}, {5.0, 4.0}, {4.0, 5.0}, {5.0, 5.0},
        {9.0, 9.0}, {8.0, 8.0}
    };

    private FileSystem fs;
    private Path clusteringOutputPath;
    private Configuration conf;
    private Path pointsPath;
    private Path classifiedOutputPath;

    // Classified vectors bucketed by the cluster id ("0", "1", "2") read back from the output.
    private List<Vector> firstCluster;
    private List<Vector> secondCluster;
    private List<Vector> thirdCluster;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        fs = FileSystem.get(getConfiguration());
        firstCluster = Lists.newArrayList();
        secondCluster = Lists.newArrayList();
        thirdCluster = Lists.newArrayList();
    }

    /**
     * Wraps each row of {@code raw} in a {@link VectorWritable} backed by a sparse vector.
     *
     * @param raw one coordinate array per point
     * @return one writable per input row, in input order
     */
    private static List<VectorWritable> getPointsWritable(double[][] raw) {
        List<VectorWritable> points = Lists.newArrayList();
        for (double[] coordinates : raw) {
            RandomAccessSparseVector vector = new RandomAccessSparseVector(coordinates.length);
            vector.assign(coordinates);
            points.add(new VectorWritable(vector));
        }
        return points;
    }

    @Test
    public void testVectorClassificationWithOutlierRemovalMR() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        pointsPath = getTestTempDirPath("points");
        clusteringOutputPath = getTestTempDirPath("output");
        classifiedOutputPath = getTestTempDirPath("classifiedClusters");
        // Assign conf before using it; previously delete() received an unassigned (null) conf.
        conf = getConfiguration();
        HadoopUtil.delete(conf, classifiedOutputPath);
        ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file1"), fs, conf);
        runClustering(pointsPath, conf, false);
        runClassificationWithOutlierRemoval(false);
        collectVectorsForAssertion();
        assertVectorsWithOutlierRemoval();
    }

    @Test
    public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        pointsPath = getTestTempDirPath("points");
        clusteringOutputPath = getTestTempDirPath("output");
        classifiedOutputPath = getTestTempDirPath("classify");
        conf = getConfiguration();
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
        runClustering(pointsPath, conf, true);
        runClassificationWithoutOutlierRemoval();
        collectVectorsForAssertion();
        assertVectorsWithoutOutlierRemoval();
    }

    @Test
    public void testVectorClassificationWithOutlierRemoval() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        pointsPath = getTestTempDirPath("points");
        clusteringOutputPath = getTestTempDirPath("output");
        classifiedOutputPath = getTestTempDirPath("classify");
        // Use the test-case configuration like the other tests instead of a bare new Configuration().
        conf = getConfiguration();
        ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
        runClustering(pointsPath, conf, true);
        runClassificationWithOutlierRemoval(true);
        collectVectorsForAssertion();
        assertVectorsWithOutlierRemoval();
    }

    /**
     * Runs Canopy clustering over {@code input} into {@link #clusteringOutputPath}.
     *
     * <p>It then writes a {@link CanopyClusteringPolicy} into the {@code clusters-0-final} directory
     * so that the classifier can load the clusters.
     *
     * @param input         directory holding the input points
     * @param configuration configuration passed to the driver
     * @param runSequential forwarded as the last argument of {@code CanopyDriver.run}. Presumably it
     *                      selects sequential vs. MapReduce execution (the "MR" test passes
     *                      {@code false}); TODO confirm.
     */
    private void runClustering(Path input, Configuration configuration, boolean runSequential)
            throws IOException, InterruptedException, ClassNotFoundException {
        CanopyDriver.run(configuration, input, clusteringOutputPath, new ManhattanDistanceMeasure(),
                3.1, 2.1, false, 0.0, runSequential);
        ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
                new Path(clusteringOutputPath, "clusters-0-final"));
    }

    /**
     * Classifies all points with a 0.0 threshold.
     *
     * <p>The assertions below expect every point to be kept with this threshold.
     */
    private void runClassificationWithoutOutlierRemoval()
            throws IOException, InterruptedException, ClassNotFoundException {
        ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath,
                classifiedOutputPath, 0.0, true, true);
    }

    /**
     * Classifies all points with a 0.73 threshold.
     *
     * <p>The assertions expect most points to be dropped as outliers with this threshold.
     *
     * @param runSequential forwarded as the last driver argument (the MR test passes {@code false})
     */
    private void runClassificationWithOutlierRemoval(boolean runSequential)
            throws IOException, InterruptedException, ClassNotFoundException {
        ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath,
                classifiedOutputPath, 0.73, true, runSequential);
    }

    /** Reads every part file under {@link #classifiedOutputPath} and buckets vectors by cluster id. */
    private void collectVectorsForAssertion() throws IOException {
        Path[] outputDirs = FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
        for (FileStatus partFile : fs.listStatus(outputDirs, PathFilters.partFilter())) {
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, partFile.getPath(), conf);
            try {
                IntWritable clusterId = new IntWritable();
                WeightedVectorWritable point = new WeightedVectorWritable();
                while (reader.next(clusterId, point)) {
                    collectVector(clusterId.toString(), point.getVector());
                }
            } finally {
                // Previously the reader was never closed, leaking a file handle per part file.
                reader.close();
            }
        }
    }

    /** Adds {@code vector} to the bucket for cluster id {@code clusterId}; other ids are ignored. */
    private void collectVector(String clusterId, Vector vector) {
        if ("0".equals(clusterId)) {
            firstCluster.add(vector);
        } else if ("1".equals(clusterId)) {
            secondCluster.add(vector);
        } else if ("2".equals(clusterId)) {
            thirdCluster.add(vector);
        }
    }

    /** With outlier removal, expects exactly {1,1} and {9,9} as singletons plus one empty cluster. */
    private void assertVectorsWithOutlierRemoval() {
        Set<String> unmatchedReference = Sets.newHashSet("{0:9.0,1:9.0}", "{0:1.0,1:1.0}");
        List<List<Vector>> clusters = Lists.newArrayList();
        clusters.add(firstCluster);
        clusters.add(secondCluster);
        clusters.add(thirdCluster);

        int singletonCount = 0;
        int emptyCount = 0;
        for (List<Vector> cluster : clusters) {
            if (cluster.isEmpty()) {
                emptyCount++;
                continue;
            }
            singletonCount++;
            Assert.assertEquals("expecting only singleton clusters; got size=" + cluster.size(),
                    1, cluster.size());
            String member = cluster.get(0).asFormatString();
            Assert.assertTrue("not expecting cluster:" + member, unmatchedReference.contains(member));
            // Each reference point may be matched at most once.
            unmatchedReference.remove(member);
        }
        Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCount);
        Assert.assertEquals("Different number of singletons than expected!", 2, singletonCount);
        Assert.assertEquals("Didn't match all reference clusters!", 0, unmatchedReference.size());
    }

    /** Without outlier removal, expects every reference point in its own group (sizes 3, 4, 2). */
    private void assertVectorsWithoutOutlierRemoval() {
        assertClusterMembers(firstCluster,
                "{0:1.0,1:1.0}", "{0:2.0,1:1.0}", "{0:1.0,1:2.0}");
        assertClusterMembers(secondCluster,
                "{0:4.0,1:4.0}", "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", "{0:5.0,1:5.0}");
        assertClusterMembers(thirdCluster,
                "{0:9.0,1:9.0}", "{0:8.0,1:8.0}");
    }

    /**
     * Asserts that {@code cluster} has as many vectors as {@code expectedMembers}.
     *
     * <p>It also asserts that every vector's format string is one of {@code expectedMembers}.
     *
     * @param cluster         collected vectors for one cluster id
     * @param expectedMembers allowed {@code asFormatString()} values
     */
    private static void assertClusterMembers(List<Vector> cluster, String... expectedMembers) {
        Assert.assertEquals(expectedMembers.length, cluster.size());
        for (Vector member : cluster) {
            String formatted = member.asFormatString();
            Assert.assertTrue("unexpected cluster member: " + formatted,
                    ArrayUtils.contains(expectedMembers, formatted));
        }
    }
}
