package org.apache.flink.ml.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/clustering/KMeansTest.class */
/** Tests {@link KMeans} and {@link KMeansModel}. */
public class KMeansTest extends AbstractTestBase {

    /** Six 2-D points: three near (0, 0) and three near (9, 0). */
    private static final List<DenseVector> DATA =
            Arrays.asList(
                    Vectors.dense(0.0, 0.0),
                    Vectors.dense(0.0, 0.3),
                    Vectors.dense(0.3, 0.0),
                    Vectors.dense(9.0, 0.0),
                    Vectors.dense(9.0, 0.6),
                    Vectors.dense(9.6, 0.0));

    /** The two groups of {@link #DATA} that the k=2 tests below expect to be separated. */
    private static final List<Set<DenseVector>> EXPECTED_GROUPS =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(0.0, 0.0),
                                    Vectors.dense(0.0, 0.3),
                                    Vectors.dense(0.3, 0.0))),
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(9.0, 0.0),
                                    Vectors.dense(9.0, 0.6),
                                    Vectors.dense(9.6, 0.0))));

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table dataTable;

    @Before
    public void before() {
        Configuration config = new Configuration();
        // Keeps checkpointing enabled once finite (bounded) tasks have finished.
        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);
        dataTable = tEnv.fromDataStream(env.fromCollection(DATA)).as("features");
    }

    /**
     * Groups the feature vectors of {@code rows} by their predicted cluster id.
     *
     * <p>Cluster ids themselves are discarded; only the partition into groups is returned, so the
     * result can be compared against an expected partition regardless of id assignment.
     *
     * @param rows rows that contain a {@link Vector} column and an {@code Integer} column
     * @param featuresCol name of the vector column; each value is converted to a dense vector
     * @param predictionCol name of the integer cluster-id column
     * @return one set of dense feature vectors per distinct prediction value, in unspecified order
     */
    public static List<Set<DenseVector>> groupFeaturesByPrediction(
            List<Row> rows, String featuresCol, String predictionCol) {
        Map<Integer, Set<DenseVector>> groups = new HashMap<>();
        for (Row row : rows) {
            DenseVector features = ((Vector) row.getField(featuresCol)).toDense();
            Integer prediction = (Integer) row.getField(predictionCol);
            groups.computeIfAbsent(prediction, unused -> new HashSet<>()).add(features);
        }
        return new ArrayList<>(groups.values());
    }

    /**
     * Executes {@code output} and groups its rows by prediction, using the column names that
     * {@code kMeans} is configured with.
     */
    private static List<Set<DenseVector>> collectGroups(Table output, KMeans kMeans) {
        // IteratorUtils.toList returns a raw List; every element is a Row produced by the table.
        @SuppressWarnings("unchecked")
        List<Row> rows = IteratorUtils.toList(output.execute().collect());
        return groupFeaturesByPrediction(rows, kMeans.getFeaturesCol(), kMeans.getPredictionCol());
    }

    /**
     * Asserts that two groupings hold the same groups. Their order is ignored, because
     * {@link CollectionUtils#isEqualCollection} compares element cardinalities rather than
     * positions.
     */
    private static void assertSameGroups(
            List<Set<DenseVector>> expected, List<Set<DenseVector>> actual) {
        Assert.assertTrue(
                "expected groups " + expected + " but got " + actual,
                CollectionUtils.isEqualCollection(expected, actual));
    }

    @Test
    public void testParam() {
        KMeans kMeans = new KMeans();
        Assert.assertEquals("features", kMeans.getFeaturesCol());
        Assert.assertEquals("prediction", kMeans.getPredictionCol());
        Assert.assertEquals("euclidean", kMeans.getDistanceMeasure());
        Assert.assertEquals("random", kMeans.getInitMode());
        Assert.assertEquals(2L, kMeans.getK());
        Assert.assertEquals(20L, kMeans.getMaxIter());
        Assert.assertEquals(KMeans.class.getName().hashCode(), kMeans.getSeed());

        // K is set twice; the assertion below checks that the last value (3) wins.
        kMeans.setK(9)
                .setFeaturesCol("test_feature")
                .setPredictionCol("test_prediction")
                .setK(3)
                .setMaxIter(30)
                .setSeed(100L);

        Assert.assertEquals("test_feature", kMeans.getFeaturesCol());
        Assert.assertEquals("test_prediction", kMeans.getPredictionCol());
        Assert.assertEquals(3L, kMeans.getK());
        Assert.assertEquals(30L, kMeans.getMaxIter());
        Assert.assertEquals(100L, kMeans.getSeed());
    }

    @Test
    public void testOutputSchema() {
        Table input = dataTable.as("test_feature");
        KMeans kMeans =
                new KMeans().setFeaturesCol("test_feature").setPredictionCol("test_prediction");
        Table output = kMeans.fit(input).transform(input)[0];

        Assert.assertEquals(
                Arrays.asList("test_feature", "test_prediction"),
                output.getResolvedSchema().getColumnNames());
        assertSameGroups(EXPECTED_GROUPS, collectGroups(output, kMeans));
    }

    @Test
    public void testFewerDistinctPointsThanCluster() {
        Table input =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Vectors.dense(0.0, 0.1),
                                                Vectors.dense(0.0, 0.1),
                                                Vectors.dense(0.0, 0.1))),
                                Schema.newBuilder()
                                        .column("f0", DataTypes.of(DenseVector.class))
                                        .build())
                        .as("features");
        KMeans kMeans = new KMeans().setK(2);
        Table output = kMeans.fit(input).transform(input)[0];

        // Only one distinct point exists, so every row must land in a single group.
        List<Set<DenseVector>> expected =
                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
        assertSameGroups(expected, collectGroups(output, kMeans));
    }

    @Test
    public void testFitAndPredict() {
        KMeans kMeans = new KMeans().setMaxIter(2).setK(2);
        Table output = kMeans.fit(dataTable).transform(dataTable)[0];

        Assert.assertEquals(
                Arrays.asList("features", "prediction"),
                output.getResolvedSchema().getColumnNames());
        assertSameGroups(EXPECTED_GROUPS, collectGroups(output, kMeans));
    }

    @Test
    public void testInputTypeConversion() {
        dataTable = TestUtils.convertDataTypesToSparseInt(tEnv, dataTable);
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(dataTable));

        KMeans kMeans = new KMeans().setMaxIter(2).setK(2);
        Table output = kMeans.fit(dataTable).transform(dataTable)[0];

        Assert.assertEquals(
                Arrays.asList("features", "prediction"),
                output.getResolvedSchema().getColumnNames());
        assertSameGroups(EXPECTED_GROUPS, collectGroups(output, kMeans));
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        KMeans kMeans = new KMeans().setMaxIter(2).setK(2);
        KMeans reloadedKMeans =
                TestUtils.saveAndReload(tEnv, kMeans, tempFolder.newFolder().getAbsolutePath());
        KMeansModel model =
                TestUtils.saveAndReload(
                        tEnv,
                        reloadedKMeans.fit(dataTable),
                        tempFolder.newFolder().getAbsolutePath());
        Table output = model.transform(dataTable)[0];

        Assert.assertEquals(
                Arrays.asList("centroids", "weights"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        Assert.assertEquals(
                Arrays.asList("features", "prediction"),
                output.getResolvedSchema().getColumnNames());
        assertSameGroups(EXPECTED_GROUPS, collectGroups(output, kMeans));
    }

    @Test
    public void testGetModelData() throws Exception {
        KMeansModel model = new KMeans().setMaxIter(2).setK(2).fit(dataTable);
        Assert.assertEquals(
                Arrays.asList("centroids", "weights"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());

        // IteratorUtils.toList returns a raw List; the stream only yields KMeansModelData.
        @SuppressWarnings("unchecked")
        List<KMeansModelData> modelData =
                IteratorUtils.toList(
                        KMeansModelData.getModelDataStream(model.getModelData()[0])
                                .executeAndCollect());
        Assert.assertEquals(1L, modelData.size());

        DenseVector[] centroids = modelData.get(0).centroids;
        Assert.assertEquals(2L, centroids.length);
        // Centroid order is not fixed, so sort by the first coordinate before comparing.
        Arrays.sort(centroids, Comparator.comparingDouble(centroid -> centroid.get(0)));
        Assert.assertArrayEquals(new double[] {0.1, 0.1}, centroids[0].values, 1e-5);
        Assert.assertArrayEquals(new double[] {9.2, 0.2}, centroids[1].values, 1e-5);
    }

    @Test
    public void testSetModelData() {
        KMeans kMeans = new KMeans().setMaxIter(2).setK(2);
        KMeansModel fittedModel = kMeans.fit(dataTable);

        // A fresh model fed only the fitted model's data and params must predict identically.
        KMeansModel model = new KMeansModel().setModelData(fittedModel.getModelData());
        ReadWriteUtils.updateExistingParams(model, fittedModel.getParamMap());

        Table output = model.transform(dataTable)[0];
        assertSameGroups(EXPECTED_GROUPS, collectGroups(output, kMeans));
    }
}
