package org.apache.flink.ml.clustering;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.CoreOptions;
import org.apache.flink.configuration.RestOptions;
import org.apache.flink.metrics.Gauge;
import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.util.InMemorySinkFunction;
import org.apache.flink.ml.util.InMemorySourceFunction;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.runtime.client.JobStatusMessage;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
import org.apache.flink.runtime.testutils.InMemoryReporter;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.TestLogger;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/clustering/OnlineKMeansTest.class */
public class OnlineKMeansTest extends TestLogger {

    /** Scratch directory used by the save/load tests. */
    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** First training batch: three points near (10, 0) and three near (-10, 0). */
    private static final DenseVector[] trainData1 = {
        Vectors.dense(10.0, 0.0),
        Vectors.dense(10.0, 0.3),
        Vectors.dense(10.3, 0.0),
        Vectors.dense(-10.0, 0.0),
        Vectors.dense(-10.0, 0.6),
        Vectors.dense(-10.6, 0.0)
    };

    /** Second training batch: three points near (10, 100) and three near (-10, -100). */
    private static final DenseVector[] trainData2 = {
        Vectors.dense(10.0, 100.0),
        Vectors.dense(10.0, 100.3),
        Vectors.dense(10.3, 100.0),
        Vectors.dense(-10.0, -100.0),
        Vectors.dense(-10.0, -100.6),
        Vectors.dense(-10.6, -100.0)
    };

    /** Points fed to the prediction source: three near (10, 10) and three near (-10, 10). */
    private static final DenseVector[] predictData = {
        Vectors.dense(10.0, 10.0),
        Vectors.dense(10.3, 10.0),
        Vectors.dense(10.0, 10.3),
        Vectors.dense(-10.0, 10.0),
        Vectors.dense(-10.3, 10.0),
        Vectors.dense(-10.0, 10.3)
    };

    /** Expected grouping of {@link #predictData} into two clusters split by the sign of x. */
    private static final List<Set<DenseVector>> expectedGroups1 =
            Arrays.asList(
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(10.0, 10.0),
                                    Vectors.dense(10.3, 10.0),
                                    Vectors.dense(10.0, 10.3))),
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(-10.0, 10.0),
                                    Vectors.dense(-10.3, 10.0),
                                    Vectors.dense(-10.0, 10.3))));

    /** Expected grouping of {@link #predictData} when every point lands in one cluster. */
    private static final List<Set<DenseVector>> expectedGroups2 =
            Collections.singletonList(
                    new HashSet<>(
                            Arrays.asList(
                                    Vectors.dense(10.0, 10.0),
                                    Vectors.dense(10.3, 10.0),
                                    Vectors.dense(10.0, 10.3),
                                    Vectors.dense(-10.0, 10.0),
                                    Vectors.dense(-10.3, 10.0),
                                    Vectors.dense(-10.0, 10.3))));

    // Cluster sizing shared by all tests in this class.
    private static final int defaultParallelism = 4;
    private static final int numTaskManagers = 2;
    private static final int numSlotsPerTaskManager = 2;

    /** Last model data version observed by {@code waitModelDataUpdate}; reset before each test. */
    private int currentModelDataVersion;

    // Per-test in-memory sources and sinks, recreated in before().
    private InMemorySourceFunction<DenseVector> trainSource;
    private InMemorySourceFunction<DenseVector> predictSource;
    private InMemorySinkFunction<Row> outputSink;
    private InMemorySinkFunction<KMeansModelData> modelDataSink;

    // Class-wide cluster, metric reporter and environments, created in beforeClass().
    private static InMemoryReporter reporter;
    private static MiniCluster miniCluster;
    private static StreamExecutionEnvironment env;
    private static StreamTableEnvironment tEnv;

    // Per-test input tables, each exposing a single "features" column.
    private Table offlineTrainTable;
    private Table onlineTrainTable;
    private Table onlinePredictTable;

    /**
     * Starts the shared mini cluster with an in-memory metric reporter attached and creates the
     * stream and table environments used by every test.
     */
    @BeforeClass
    public static void beforeClass() throws Exception {
        Configuration config = new Configuration();
        config.set(RestOptions.BIND_PORT, "18081-19091");
        config.set(CoreOptions.DEFAULT_PARALLELISM, defaultParallelism);

        // The reporter has to be registered in the configuration before the cluster is built.
        reporter = InMemoryReporter.create();
        reporter.addToConfiguration(config);

        MiniClusterConfiguration clusterConfig =
                new MiniClusterConfiguration.Builder()
                        .setConfiguration(config)
                        .setNumTaskManagers(numTaskManagers)
                        .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
                        .build();
        miniCluster = new MiniCluster(clusterConfig);
        miniCluster.start();

        env = TestUtils.getExecutionEnvironment(config);
        tEnv = StreamTableEnvironment.create(env);
    }

    /**
     * Resets the tracked model version and builds fresh in-memory sources, sinks and input
     * tables for the next test.
     */
    @Before
    public void before() throws Exception {
        currentModelDataVersion = 0;

        trainSource = new InMemorySourceFunction<>();
        predictSource = new InMemorySourceFunction<>();
        outputSink = new InMemorySinkFunction<>();
        modelDataSink = new InMemorySinkFunction<>();

        // Every table exposes its single column under the name "features".
        offlineTrainTable = tEnv.fromDataStream(env.fromElements(trainData1)).as("features");
        onlineTrainTable =
                tEnv.fromDataStream(env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE))
                        .as("features");
        onlinePredictTable =
                tEnv.fromDataStream(env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE))
                        .as("features");
    }

    /**
     * Requests cancellation of every job currently listed by the mini cluster so that jobs
     * submitted by one test do not keep running into the next one. The cancellation futures are
     * not awaited.
     */
    @After
    public void after() throws Exception {
        for (JobStatusMessage job : miniCluster.listJobs().get()) {
            miniCluster.cancelJob(job.getJobId());
        }
    }

    /**
     * Shuts down the shared mini cluster.
     *
     * <p>The cluster reference is still null when {@code beforeClass} failed before assigning it;
     * in that case there is nothing to close and the original setup failure should not be
     * accompanied by a NullPointerException from teardown.
     */
    @AfterClass
    public static void afterClass() throws Exception {
        if (miniCluster != null) {
            miniCluster.close();
        }
    }

    /**
     * Wires the given model into the job graph: its predictions on the online prediction table
     * go to {@code outputSink}, and its model data stream goes to {@code modelDataSink}.
     *
     * @param onlineKMeansModel the model whose output and model data should be collected
     */
    private void transformAndOutputData(OnlineKMeansModel onlineKMeansModel) {
        Table predictions = onlineKMeansModel.transform(onlinePredictTable)[0];
        tEnv.toDataStream(predictions).addSink(outputSink);

        Table modelData = onlineKMeansModel.getModelData()[0];
        KMeansModelData.getModelDataStream(modelData).addSink(modelDataSink);
    }

    /**
     * Blocks until at least {@code defaultParallelism} "modelDataVersion" metrics are reported for
     * the job (presumably one per parallel subtask), then waits for the first model data version
     * change.
     *
     * @param jobID the job whose metrics are polled
     * @throws InterruptedException if the thread is interrupted while sleeping between polls
     */
    private void waitInitModelDataSetup(JobID jobID) throws InterruptedException {
        while (true) {
            int registeredGauges = reporter.findMetrics(jobID, "modelDataVersion").size();
            if (registeredGauges >= defaultParallelism) {
                break;
            }
            Thread.sleep(100L);
        }
        waitModelDataUpdate(jobID);
    }

    /**
     * Polls the job's "modelDataVersion" gauges every 100 ms until the smallest reported version
     * (0 when no gauge is found) differs from {@code currentModelDataVersion}, then records that
     * new value.
     *
     * @param jobID the job whose metrics are polled
     * @throws InterruptedException if the thread is interrupted while sleeping between polls
     */
    private void waitModelDataUpdate(JobID jobID) throws InterruptedException {
        for (; ; ) {
            // The gauge value is read as a String and parsed into an int.
            int observedVersion =
                    reporter.findMetrics(jobID, "modelDataVersion").values().stream()
                            .mapToInt(
                                    metric ->
                                            Integer.parseInt(
                                                    (String) ((Gauge<?>) metric).getValue()))
                            .min()
                            .orElse(0);

            if (observedVersion == currentModelDataVersion) {
                Thread.sleep(100L);
                continue;
            }
            currentModelDataVersion = observedVersion;
            return;
        }
    }

    /**
     * Feeds {@code predictData} into the prediction source, collects the same number of output
     * rows from the sink, groups them by predicted cluster and checks that the groups equal the
     * expected ones regardless of order.
     *
     * @param expectedGroups the expected clusters of feature vectors
     * @param featuresCol name of the features column in the output rows
     * @param predictionCol name of the prediction column in the output rows
     */
    private void predictAndAssert(
            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
            throws Exception {
        predictSource.addAll(predictData);
        List<Row> outputRows = outputSink.poll(predictData.length);
        List<Set<DenseVector>> actualGroups =
                KMeansTest.groupFeaturesByPrediction(outputRows, featuresCol, predictionCol);
        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
    }

    /**
     * Submits the job graph to the mini cluster and returns the id of the submitted job, waiting
     * at most one second for the submission to be acknowledged.
     *
     * @param jobGraph the job to submit
     * @return the id reported by the submission result
     * @throws TimeoutException if the submission is not acknowledged within one second
     */
    private JobID submitJob(JobGraph jobGraph)
            throws ExecutionException, InterruptedException, TimeoutException {
        return miniCluster
                .submitJob(jobGraph)
                .thenApply(submissionResult -> submissionResult.getJobID())
                .get(1L, TimeUnit.SECONDS);
    }

    /** Checks the default parameter values of {@link OnlineKMeans} and that setters override them. */
    @Test
    public void testParam() {
        OnlineKMeans onlineKMeans = new OnlineKMeans();

        // Defaults.
        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
        Assert.assertEquals("euclidean", onlineKMeans.getDistanceMeasure());
        Assert.assertEquals(32L, onlineKMeans.getGlobalBatchSize());
        Assert.assertEquals(0.0, onlineKMeans.getDecayFactor(), 1.0E-5);
        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());

        onlineKMeans
                .setFeaturesCol("test_feature")
                .setPredictionCol("test_prediction")
                .setGlobalBatchSize(5)
                .setDecayFactor(0.25)
                .setSeed(100L);

        // Overridden values; batch strategy and distance measure were not touched.
        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
        Assert.assertEquals("euclidean", onlineKMeans.getDistanceMeasure());
        Assert.assertEquals(5L, onlineKMeans.getGlobalBatchSize());
        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1.0E-5);
        Assert.assertEquals(100L, onlineKMeans.getSeed());
    }

    /**
     * Trains online from randomly generated initial model data and checks the predicted grouping
     * after each of the two training batches.
     */
    @Test
    public void testFitAndPredict() throws Exception {
        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(
                                KMeansModelData.generateRandomModelData(tEnv, 2, 2, 0.0, 0L));
        OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
        transformAndOutputData(onlineModel);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        waitInitModelDataSetup(jobId);

        String featuresCol = onlineKMeans.getFeaturesCol();
        String predictionCol = onlineKMeans.getPredictionCol();

        trainSource.addAll(trainData1);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedGroups1, featuresCol, predictionCol);

        trainSource.addAll(trainData2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedGroups2, featuresCol, predictionCol);
    }

    /**
     * Repeats the fit-and-predict scenario after converting all three input tables with
     * {@code TestUtils.convertDataTypesToSparseInt}, asserting first that each table's column type
     * became {@link SparseVector}.
     */
    @Test
    public void testInputTypeConversion() throws Exception {
        offlineTrainTable = TestUtils.convertDataTypesToSparseInt(tEnv, offlineTrainTable);
        onlineTrainTable = TestUtils.convertDataTypesToSparseInt(tEnv, onlineTrainTable);
        onlinePredictTable = TestUtils.convertDataTypesToSparseInt(tEnv, onlinePredictTable);

        Class<?>[] sparseOnly = new Class[] {SparseVector.class};
        Assert.assertArrayEquals(sparseOnly, TestUtils.getColumnDataTypes(offlineTrainTable));
        Assert.assertArrayEquals(sparseOnly, TestUtils.getColumnDataTypes(onlineTrainTable));
        Assert.assertArrayEquals(sparseOnly, TestUtils.getColumnDataTypes(onlinePredictTable));

        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(
                                KMeansModelData.generateRandomModelData(tEnv, 2, 2, 0.0, 0L));
        OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
        transformAndOutputData(onlineModel);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        waitInitModelDataSetup(jobId);

        String featuresCol = onlineKMeans.getFeaturesCol();
        String predictionCol = onlineKMeans.getPredictionCol();

        trainSource.addAll(trainData1);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedGroups1, featuresCol, predictionCol);

        trainSource.addAll(trainData2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedGroups2, featuresCol, predictionCol);
    }

    /**
     * Uses the model data of an offline {@link KMeans} fit on the first training set as the
     * initial model data, so the first prediction check needs no online training batch.
     */
    @Test
    public void testInitWithKMeans() throws Exception {
        Table initialModelData =
                new KMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .fit(offlineTrainTable)
                        .getModelData()[0];
        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(initialModelData);
        OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
        transformAndOutputData(onlineModel);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        waitInitModelDataSetup(jobId);

        String featuresCol = onlineKMeans.getFeaturesCol();
        String predictionCol = onlineKMeans.getPredictionCol();
        predictAndAssert(expectedGroups1, featuresCol, predictionCol);

        trainSource.addAll(trainData2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedGroups2, featuresCol, predictionCol);
    }

    /**
     * Trains with a decay factor of 0.5 on top of offline KMeans model data and compares the
     * second emitted model data (weights and centroids ordered by first coordinate) against fixed
     * expected values.
     */
    @Test
    public void testDecayFactor() throws Exception {
        Table initialModelData =
                new KMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .fit(offlineTrainTable)
                        .getModelData()[0];
        OnlineKMeansModel onlineModel =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setDecayFactor(0.5)
                        .setInitialModelData(initialModelData)
                        .fit(onlineTrainTable);
        transformAndOutputData(onlineModel);
        submitJob(env.getStreamGraph().getJobGraph());

        // Discard the first emitted model data; the assertions target the one produced after
        // the training batch below.
        modelDataSink.poll();
        trainSource.addAll(trainData2);
        KMeansModelData actual = modelDataSink.poll();

        DenseVector[] expectedCentroids = {
            Vectors.dense(-10.2, -66.73333333333333), Vectors.dense(10.1, 66.76666666666667)
        };
        KMeansModelData expected =
                new KMeansModelData(expectedCentroids, Vectors.dense(4.5, 4.5));

        Assert.assertArrayEquals(expected.weights.values, actual.weights.values, 1.0E-5);
        Assert.assertEquals(expected.centroids.length, actual.centroids.length);

        // Order the actual centroids by first coordinate so they line up with the expected ones.
        Arrays.sort(actual.centroids, Comparator.comparingDouble(centroid -> centroid.get(0)));
        int index = 0;
        for (DenseVector expectedCentroid : expected.centroids) {
            Assert.assertArrayEquals(
                    expectedCentroid.values, actual.centroids[index].values, 1.0E-5);
            index++;
        }
    }

    /**
     * Fitting with a global batch size of 2 must be rejected with an
     * {@link IllegalStateException} carrying the documented message.
     *
     * <p>Only {@link Exception} is caught: the {@link AssertionError} raised by
     * {@code Assert.fail} when no exception is thrown must propagate with its own message
     * instead of being caught and reported as a class mismatch.
     */
    @Test
    public void testBatchSizeLessThanParallelism() {
        try {
            new OnlineKMeans()
                    .setFeaturesCol("features")
                    .setPredictionCol("prediction")
                    .setGlobalBatchSize(2)
                    .setInitialModelData(
                            KMeansModelData.generateRandomModelData(tEnv, 2, 2, 0.0, 0L))
                    .fit(onlineTrainTable);
            Assert.fail("Expected IllegalStateException");
        } catch (Exception e) {
            // Exact class match: a subclass of IllegalStateException is not accepted either.
            Assert.assertEquals(IllegalStateException.class, e.getClass());
            Assert.assertEquals(
                    "There are more subtasks in the training process than the number of elements in each batch. Some subtasks might be idling forever.",
                    e.getMessage());
        }
    }

    /**
     * Saves and reloads both the estimator and the fitted model, re-attaches the fitted model's
     * model data to the reloaded model, and then runs the usual prediction checks.
     */
    @Test
    public void testSaveAndReload() throws Exception {
        Table initialModelData =
                new KMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .fit(offlineTrainTable)
                        .getModelData()[0];
        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(initialModelData);

        // Round-trip the estimator; the job built so far is executed to completion before loading.
        String estimatorPath = tempFolder.newFolder().getAbsolutePath();
        onlineKMeans.save(estimatorPath);
        miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());
        OnlineKMeans loadedKMeans = OnlineKMeans.load(tEnv, estimatorPath);
        OnlineKMeansModel fittedModel = loadedKMeans.fit(onlineTrainTable);

        // Round-trip the model and attach the fitted model's model data to the reloaded instance.
        String modelPath = tempFolder.newFolder().getAbsolutePath();
        fittedModel.save(modelPath);
        OnlineKMeansModel loadedModel = OnlineKMeansModel.load(tEnv, modelPath);
        loadedModel.setModelData(fittedModel.getModelData());
        transformAndOutputData(loadedModel);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        waitInitModelDataSetup(jobId);

        String featuresCol = onlineKMeans.getFeaturesCol();
        String predictionCol = onlineKMeans.getPredictionCol();
        predictAndAssert(expectedGroups1, featuresCol, predictionCol);

        trainSource.addAll(trainData2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedGroups2, featuresCol, predictionCol);
    }

    /**
     * Trains from random initial model data on the first training set and compares the second
     * emitted model data (weights and centroids ordered by first coordinate) against fixed
     * expected values.
     */
    @Test
    public void testGetModelData() throws Exception {
        OnlineKMeansModel onlineModel =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(
                                KMeansModelData.generateRandomModelData(tEnv, 2, 2, 0.0, 0L))
                        .fit(onlineTrainTable);
        transformAndOutputData(onlineModel);
        submitJob(env.getStreamGraph().getJobGraph());

        // Discard the first emitted model data; the assertions target the one produced after
        // the training batch below.
        modelDataSink.poll();
        trainSource.addAll(trainData1);
        KMeansModelData actual = modelDataSink.poll();

        DenseVector[] expectedCentroids = {Vectors.dense(-10.2, 0.2), Vectors.dense(10.1, 0.1)};
        KMeansModelData expected =
                new KMeansModelData(expectedCentroids, Vectors.dense(3.0, 3.0));

        Assert.assertArrayEquals(expected.weights.values, actual.weights.values, 1.0E-5);
        Assert.assertEquals(expected.centroids.length, actual.centroids.length);

        // Order the actual centroids by first coordinate so they line up with the expected ones.
        Arrays.sort(actual.centroids, Comparator.comparingDouble(centroid -> centroid.get(0)));
        int index = 0;
        for (DenseVector expectedCentroid : expected.centroids) {
            Assert.assertArrayEquals(
                    expectedCentroid.values, actual.centroids[index].values, 1.0E-5);
            index++;
        }
    }

    /**
     * Drives an {@link OnlineKMeansModel} directly from a hand-fed model data stream: after each
     * pushed {@link KMeansModelData} the predictions must match the corresponding expected
     * grouping.
     */
    @Test
    public void testSetModelData() throws Exception {
        KMeansModelData modelData1 =
                new KMeansModelData(
                        new DenseVector[] {Vectors.dense(10.1, 0.1), Vectors.dense(-10.2, 0.2)},
                        Vectors.dense(0.0, 0.0));
        KMeansModelData modelData2 =
                new KMeansModelData(
                        new DenseVector[] {
                            Vectors.dense(10.1, 100.1), Vectors.dense(-10.2, -100.2)
                        },
                        Vectors.dense(0.0, 0.0));

        // Parameterized source so the element type is checked at compile time.
        InMemorySourceFunction<KMeansModelData> modelDataSource = new InMemorySourceFunction<>();
        Table modelDataTable =
                tEnv.fromDataStream(
                        env.addSource(modelDataSource, TypeInformation.of(KMeansModelData.class)));

        OnlineKMeansModel onlineModel =
                new OnlineKMeansModel()
                        .setModelData(modelDataTable)
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction");
        transformAndOutputData(onlineModel);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());

        modelDataSource.addAll(modelData1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(
                expectedGroups1, onlineModel.getFeaturesCol(), onlineModel.getPredictionCol());

        modelDataSource.addAll(modelData2);
        waitModelDataUpdate(jobId);
        predictAndAssert(
                expectedGroups2, onlineModel.getFeaturesCol(), onlineModel.getPredictionCol());
    }
}
