package org.apache.flink.ml.clustering;

import java.lang.invoke.SerializedLambda;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
import org.apache.flink.ml.common.window.CountTumblingWindows;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.common.window.GlobalWindows;
import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/clustering/AgglomerativeClusteringTest.class */
/**
 * Tests {@link AgglomerativeClustering}: parameter defaults and setters, output schemas, clustering
 * results for several linkage / distance-measure / window configurations, the reported merge
 * distances, and save/load round-tripping.
 */
public class AgglomerativeClusteringTest extends AbstractTestBase {

    /** Six 2-d points shared by all tests. */
    private static final List<DenseVector> INPUT_DATA =
            Arrays.asList(
                    Vectors.dense(1.0, 1.0),
                    Vectors.dense(1.0, 4.0),
                    Vectors.dense(1.0, 0.0),
                    Vectors.dense(4.0, 4.0),
                    Vectors.dense(4.0, 1.5),
                    Vectors.dense(4.0, 0.0));

    // Expected values of the "distance" column of the merge-info table, in output order.
    private static final double[] EUCLIDEAN_AVERAGE_MERGE_DISTANCES = {
        1.0, 1.5, 3.0, 3.1394402, 3.9559706
    };
    private static final double[] COSINE_AVERAGE_MERGE_DISTANCES = {
        0.0, 1.110223E-16, 0.0636708, 0.142507, 0.3664484
    };
    private static final double[] MANHATTAN_AVERAGE_MERGE_DISTANCES = {1.0, 1.5, 3.0, 3.75, 4.875};
    private static final double[] EUCLIDEAN_SINGLE_MERGE_DISTANCES = {1.0, 1.5, 2.5, 3.0, 3.0};
    private static final double[] EUCLIDEAN_WARD_MERGE_DISTANCES = {
        1.0, 1.5, 3.0, 4.2573465, 5.5113519
    };
    private static final double[] EUCLIDEAN_COMPLETE_MERGE_DISTANCES = {
        1.0, 1.5, 3.0, 3.3541019, 5.0
    };

    // Expected clusters, each given as the set of feature vectors that share one prediction.
    private static final List<Set<DenseVector>> EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(1.0, 1.0),
                                    Vectors.dense(1.0, 0.0),
                                    Vectors.dense(4.0, 1.5),
                                    Vectors.dense(4.0, 0.0))),
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(1.0, 4.0), Vectors.dense(4.0, 4.0))));

    private static final List<Set<DenseVector>> EUCLIDEAN_WARD_THRESHOLD_AS_TWO_RESULT =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(1.0, 1.0), Vectors.dense(1.0, 0.0))),
                    new HashSet<>(Collections.singletonList(Vectors.dense(1.0, 4.0))),
                    new HashSet<>(Collections.singletonList(Vectors.dense(4.0, 4.0))),
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(4.0, 1.5), Vectors.dense(4.0, 0.0))));

    private static final List<Set<DenseVector>> EUCLIDEAN_WARD_COUNT_FIVE_WINDOW_AS_TWO_RESULT =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(1.0, 1.0), Vectors.dense(1.0, 0.0))),
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(1.0, 4.0),
                                    Vectors.dense(4.0, 4.0),
                                    Vectors.dense(4.0, 1.5))));

    private static final List<Set<DenseVector>> EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(1.0, 1.0), Vectors.dense(1.0, 0.0))),
                    new HashSet<>(Collections.singletonList(Vectors.dense(1.0, 4.0))),
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(4.0, 0.0), Vectors.dense(4.0, 1.5))),
                    new HashSet<>(Collections.singletonList(Vectors.dense(4.0, 4.0))));

    private static final List<Set<DenseVector>> EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(1.0, 1.0),
                                    Vectors.dense(1.0, 0.0),
                                    Vectors.dense(4.0, 1.5),
                                    Vectors.dense(4.0, 0.0))),
                    new HashSet<>(
                            Arrays.asList(Vectors.dense(1.0, 4.0), Vectors.dense(4.0, 4.0))));

    /** Absolute tolerance used when comparing floating point values. */
    private static final double TOLERANCE = 1.0E-7;

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table inputDataTable;

    /** Creates fresh environments and the single-column "features" input table. */
    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputDataTable = createInputTable();
    }

    /** Checks the default parameter values and that every setter is reflected by its getter. */
    @Test
    public void testParam() {
        AgglomerativeClustering agglomerativeClustering = new AgglomerativeClustering();

        Assert.assertEquals("features", agglomerativeClustering.getFeaturesCol());
        Assert.assertEquals(2, agglomerativeClustering.getNumClusters().intValue());
        Assert.assertNull(agglomerativeClustering.getDistanceThreshold());
        Assert.assertEquals("ward", agglomerativeClustering.getLinkage());
        Assert.assertEquals("euclidean", agglomerativeClustering.getDistanceMeasure());
        Assert.assertFalse(agglomerativeClustering.getComputeFullTree());
        Assert.assertEquals("prediction", agglomerativeClustering.getPredictionCol());
        Assert.assertEquals(GlobalWindows.getInstance(), agglomerativeClustering.getWindows());

        agglomerativeClustering
                .setFeaturesCol("test_features")
                .setNumClusters((Integer) null)
                .setDistanceThreshold(0.01)
                .setLinkage("average")
                .setDistanceMeasure("cosine")
                .setComputeFullTree(true)
                .setPredictionCol("cluster_id")
                .setWindows(ProcessingTimeTumblingWindows.of(Time.milliseconds(100L)));

        Assert.assertEquals("test_features", agglomerativeClustering.getFeaturesCol());
        Assert.assertNull(agglomerativeClustering.getNumClusters());
        Assert.assertEquals(
                0.01, agglomerativeClustering.getDistanceThreshold().doubleValue(), TOLERANCE);
        Assert.assertEquals("average", agglomerativeClustering.getLinkage());
        Assert.assertEquals("cosine", agglomerativeClustering.getDistanceMeasure());
        Assert.assertTrue(agglomerativeClustering.getComputeFullTree());
        Assert.assertEquals("cluster_id", agglomerativeClustering.getPredictionCol());
        Assert.assertEquals(
                ProcessingTimeTumblingWindows.of(Time.milliseconds(100L)),
                agglomerativeClustering.getWindows());
    }

    /** Checks the column names of both output tables (predictions and merge info). */
    @Test
    public void testOutputSchema() {
        Table input =
                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
                        .as("test_features", "dummy_input");

        Table[] outputs =
                new AgglomerativeClustering()
                        .setFeaturesCol("test_features")
                        .setPredictionCol("test_prediction")
                        .transform(input);

        Assert.assertEquals(2, outputs.length);
        Assert.assertEquals(
                Arrays.asList("test_features", "dummy_input", "test_prediction"),
                outputs[0].getResolvedSchema().getColumnNames());
        Assert.assertEquals(
                Arrays.asList("clusterId1", "clusterId2", "distance", "sizeOfMergedCluster"),
                outputs[1].getResolvedSchema().getColumnNames());
    }

    /**
     * Checks ward/euclidean clustering with the default cluster count, with the full tree
     * computed, and with a distance threshold instead of a cluster count.
     */
    @Test
    public void testTransform() throws Exception {
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage("ward")
                        .setDistanceMeasure("euclidean")
                        .setPredictionCol("pred");

        Table output = agglomerativeClustering.transform(inputDataTable)[0];
        verifyClusteringResult(
                EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT,
                output,
                agglomerativeClustering.getFeaturesCol(),
                agglomerativeClustering.getPredictionCol());

        // Computing the full tree must not change the resulting clusters.
        agglomerativeClustering.setComputeFullTree(true);
        output = agglomerativeClustering.transform(inputDataTable)[0];
        verifyClusteringResult(
                EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT,
                output,
                agglomerativeClustering.getFeaturesCol(),
                agglomerativeClustering.getPredictionCol());

        // Cluster by distance threshold instead of by a fixed number of clusters.
        agglomerativeClustering.setNumClusters((Integer) null).setDistanceThreshold(2.0);
        output = agglomerativeClustering.transform(inputDataTable)[0];
        verifyClusteringResult(
                EUCLIDEAN_WARD_THRESHOLD_AS_TWO_RESULT,
                output,
                agglomerativeClustering.getFeaturesCol(),
                agglomerativeClustering.getPredictionCol());
    }

    /** Checks average/euclidean clustering into two clusters. */
    @Test
    public void testTransformWithAverageLinkage() throws Exception {
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage("average")
                        .setDistanceMeasure("euclidean")
                        .setNumClusters(2)
                        .setPredictionCol("pred");

        Table output = agglomerativeClustering.transform(inputDataTable)[0];
        verifyClusteringResult(
                EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT,
                output,
                agglomerativeClustering.getFeaturesCol(),
                agglomerativeClustering.getPredictionCol());
    }

    /** With the largest possible distance threshold every point must get the same prediction. */
    @Test
    public void testLargeDistanceThreshold() throws Exception {
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setNumClusters((Integer) null)
                        .setDistanceThreshold(Double.MAX_VALUE);

        Table output = agglomerativeClustering.transform(inputDataTable)[0];

        Set<Integer> clusterIds = new HashSet<>();
        tEnv.toDataStream(output)
                .executeAndCollect()
                .forEachRemaining(
                        row ->
                                clusterIds.add(
                                        row.<Integer>getFieldAs(
                                                agglomerativeClustering.getPredictionCol())));

        Assert.assertEquals(1, clusterIds.size());
    }

    /** Checks ward/euclidean clustering applied per count-tumbling window of five records. */
    @Test
    public void testTransformWithCountTumblingWindows() throws Exception {
        env.setParallelism(1);
        // Rebuild the input table after changing the parallelism of the environment.
        inputDataTable = createInputTable();

        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage("ward")
                        .setDistanceMeasure("euclidean")
                        .setPredictionCol("pred")
                        .setWindows(CountTumblingWindows.of(5L));

        Table output = agglomerativeClustering.transform(inputDataTable)[0];
        verifyClusteringResult(
                EUCLIDEAN_WARD_COUNT_FIVE_WINDOW_AS_TWO_RESULT,
                output,
                agglomerativeClustering.getFeaturesCol(),
                agglomerativeClustering.getPredictionCol());
    }

    /**
     * Checks ward/euclidean clustering applied per one-second event-time window. Each record's
     * timestamp is derived from its first coordinate, and every expected cluster must be fully
     * contained in one of the predicted groups.
     */
    @Test
    public void testTransformWithEventTimeTumblingWindows() throws Exception {
        RowTypeInfo rowTypeInfo =
                new RowTypeInfo(
                        new TypeInformation<?>[] {DenseVectorTypeInfo.INSTANCE, Types.INSTANT},
                        new String[] {"features", "ts"});
        Instant baseTime = Instant.now();

        Schema schema =
                Schema.newBuilder()
                        .column("features", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
                        .column("ts", DataTypes.TIMESTAMP_LTZ(3))
                        .watermark("ts", "ts - INTERVAL '5' SECOND")
                        .build();

        Table input =
                tEnv.fromDataStream(
                        env.fromCollection(INPUT_DATA)
                                .setParallelism(1)
                                .map(
                                        vector ->
                                                Row.of(
                                                        vector,
                                                        baseTime.plusSeconds(
                                                                (long) vector.get(0))),
                                        rowTypeInfo),
                        schema);

        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage("ward")
                        .setDistanceMeasure("euclidean")
                        .setPredictionCol("pred")
                        .setWindows(EventTimeTumblingWindows.of(Time.seconds(1L)));

        List<Set<DenseVector>> actualClusters =
                KMeansTest.groupFeaturesByPrediction(
                        collectRows(agglomerativeClustering.transform(input)[0]),
                        agglomerativeClustering.getFeaturesCol(),
                        agglomerativeClustering.getPredictionCol());

        for (Set<DenseVector> expectedCluster : EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT) {
            boolean covered =
                    actualClusters.stream()
                            .anyMatch(actual -> actual.containsAll(expectedCluster));
            Assert.assertTrue(
                    "No predicted cluster contains all of " + expectedCluster, covered);
        }
    }

    /**
     * Checks the merge distances reported in the second output table for each supported
     * combination of linkage and distance measure, and that without the full tree the last merge
     * is not reported.
     */
    @Test
    public void testMergeInfo() throws Exception {
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage("average")
                        .setDistanceMeasure("euclidean")
                        .setPredictionCol("pred")
                        .setComputeFullTree(true);
        verifyMergeInfo(
                EUCLIDEAN_AVERAGE_MERGE_DISTANCES,
                agglomerativeClustering.transform(inputDataTable)[1]);

        agglomerativeClustering.setDistanceMeasure("cosine");
        verifyMergeInfo(
                COSINE_AVERAGE_MERGE_DISTANCES,
                agglomerativeClustering.transform(inputDataTable)[1]);

        agglomerativeClustering.setDistanceMeasure("manhattan");
        verifyMergeInfo(
                MANHATTAN_AVERAGE_MERGE_DISTANCES,
                agglomerativeClustering.transform(inputDataTable)[1]);

        agglomerativeClustering.setDistanceMeasure("euclidean").setLinkage("complete");
        verifyMergeInfo(
                EUCLIDEAN_COMPLETE_MERGE_DISTANCES,
                agglomerativeClustering.transform(inputDataTable)[1]);

        agglomerativeClustering.setLinkage("single");
        verifyMergeInfo(
                EUCLIDEAN_SINGLE_MERGE_DISTANCES,
                agglomerativeClustering.transform(inputDataTable)[1]);

        agglomerativeClustering.setLinkage("ward");
        verifyMergeInfo(
                EUCLIDEAN_WARD_MERGE_DISTANCES,
                agglomerativeClustering.transform(inputDataTable)[1]);

        // Without the full tree, every expected merge except the last one is reported.
        agglomerativeClustering.setComputeFullTree(false);
        verifyMergeInfo(
                Arrays.copyOfRange(
                        EUCLIDEAN_WARD_MERGE_DISTANCES,
                        0,
                        EUCLIDEAN_WARD_MERGE_DISTANCES.length - 1),
                agglomerativeClustering.transform(inputDataTable)[1]);
    }

    /** Checks that a saved and reloaded operator produces the same clustering result. */
    @Test
    public void testSaveLoadTransform() throws Exception {
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage("ward")
                        .setDistanceMeasure("euclidean")
                        .setPredictionCol("pred");

        AgglomerativeClustering reloaded =
                TestUtils.saveAndReload(
                        tEnv,
                        agglomerativeClustering,
                        tempFolder.newFolder().getAbsolutePath(),
                        AgglomerativeClustering::load);

        verifyClusteringResult(
                EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT,
                reloaded.transform(inputDataTable)[0],
                reloaded.getFeaturesCol(),
                reloaded.getPredictionCol());
    }

    /** Builds the single-column "features" table from {@link #INPUT_DATA}. */
    private Table createInputTable() {
        return tEnv.fromDataStream(env.fromCollection(INPUT_DATA).map(vector -> vector))
                .as("features");
    }

    /**
     * Executes the job behind {@code table} and collects all of its rows.
     *
     * @param table the table to materialize
     * @return the collected rows, in the order they were produced
     */
    @SuppressWarnings("unchecked") // commons-collections 3 IteratorUtils returns a raw List.
    private List<Row> collectRows(Table table) throws Exception {
        return IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
    }

    /**
     * Asserts that the third column (index 2) of {@code table} matches the expected distances
     * row by row, within {@link #TOLERANCE}.
     *
     * @param dArr the expected merge distances, in output order
     * @param table the merge-info output table
     */
    private void verifyMergeInfo(double[] dArr, Table table) throws Exception {
        List<Row> rows = collectRows(table);
        Assert.assertEquals(dArr.length, rows.size());
        for (int i = 0; i < rows.size(); i++) {
            double actualDistance = rows.get(i).<Number>getFieldAs(2).doubleValue();
            Assert.assertEquals(dArr[i], actualDistance, TOLERANCE);
        }
    }

    /**
     * Asserts that grouping the rows of {@code table} by prediction yields exactly the expected
     * clusters, irrespective of cluster order.
     *
     * @param list the expected clusters
     * @param table the prediction output table
     * @param str the name of the features column
     * @param str2 the name of the prediction column
     */
    public void verifyClusteringResult(
            List<Set<DenseVector>> list, Table table, String str, String str2) throws Exception {
        List<Set<DenseVector>> actualClusters =
                KMeansTest.groupFeaturesByPrediction(collectRows(table), str, str2);
        Assert.assertTrue(CollectionUtils.isEqualCollection(list, actualClusters));
    }
}
