package org.apache.flink.ml.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.classification.knn.KnnModelData;
import org.apache.flink.ml.linalg.DenseMatrix;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/** Tests {@link Knn} and {@link KnnModel}. */
public class KnnTest {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainData;
    private Table predictData;

    /** Training rows: (features, label). */
    private static final List<Row> trainRows =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.dense(2.0, 3.0), 1.0),
                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
                            Row.of(Vectors.dense(200.1, 300.1), 2.0),
                            Row.of(Vectors.dense(200.2, 300.2), 2.0),
                            Row.of(Vectors.dense(200.3, 300.3), 2.0),
                            Row.of(Vectors.dense(200.4, 300.4), 2.0),
                            Row.of(Vectors.dense(200.4, 300.4), 2.0),
                            Row.of(Vectors.dense(200.6, 300.6), 2.0),
                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
                            Row.of(Vectors.dense(2.3, 3.2), 1.0),
                            Row.of(Vectors.dense(2.3, 3.2), 1.0),
                            Row.of(Vectors.dense(2.8, 3.2), 3.0),
                            Row.of(Vectors.dense(300.0, 3.2), 4.0),
                            Row.of(Vectors.dense(2.2, 3.2), 1.0),
                            Row.of(Vectors.dense(2.4, 3.2), 5.0),
                            Row.of(Vectors.dense(2.5, 3.2), 5.0),
                            Row.of(Vectors.dense(2.5, 3.2), 5.0),
                            Row.of(Vectors.dense(2.1, 3.1), 1.0)));

    /** Rows to predict: (features, expected label). */
    private static final List<Row> predictRows =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.dense(4.0, 4.1), 5.0),
                            Row.of(Vectors.dense(300.0, 42.0), 2.0)));

    @Before
    public void before() {
        Configuration config = new Configuration();
        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);

        Schema schema =
                Schema.newBuilder()
                        .column("f0", DataTypes.of(DenseVector.class))
                        .column("f1", DataTypes.DOUBLE())
                        .build();
        trainData =
                tEnv.fromDataStream(env.fromCollection(trainRows), schema).as("features", "label");
        predictData =
                tEnv.fromDataStream(env.fromCollection(predictRows), schema)
                        .as("features", "label");
    }

    /**
     * Collects {@code output} and asserts that, for every row, the value in {@code labelCol}
     * (converted to double) equals the value in {@code predictionCol}.
     *
     * @param output table produced by a transform; must have been created by a
     *     StreamTableEnvironment
     * @param labelCol name of the numeric label column
     * @param predictionCol name of the Double prediction column
     */
    private static void verifyPredictionResult(
            Table output, final String labelCol, final String predictionCol) throws Exception {
        // getTableEnvironment() is typed as TableEnvironment; toDataStream needs the stream env.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        DataStream<Tuple2<Double, Double>> labelAndPrediction =
                tableEnv.toDataStream(output)
                        .map(
                                new MapFunction<Row, Tuple2<Double, Double>>() {
                                    @Override
                                    public Tuple2<Double, Double> map(Row row) {
                                        // The label may be Integer after type conversion, so
                                        // normalize it through Number.
                                        return Tuple2.of(
                                                ((Number) row.getField(labelCol)).doubleValue(),
                                                (Double) row.getField(predictionCol));
                                    }
                                });

        @SuppressWarnings("unchecked") // commons-collections IteratorUtils returns a raw List
        List<Tuple2<Double, Double>> results =
                IteratorUtils.toList(labelAndPrediction.executeAndCollect());
        // Guard against a vacuous pass when the transform emits nothing.
        Assert.assertFalse("Expected at least one prediction row", results.isEmpty());
        for (Tuple2<Double, Double> result : results) {
            Assert.assertEquals(result.f0, result.f1);
        }
    }

    @Test
    public void testParam() {
        Knn knn = new Knn();
        Assert.assertEquals("features", knn.getFeaturesCol());
        Assert.assertEquals("label", knn.getLabelCol());
        Assert.assertEquals(5L, knn.getK().intValue());
        Assert.assertEquals("prediction", knn.getPredictionCol());

        // The setters update knn in place.
        knn.setLabelCol("test_label");
        knn.setFeaturesCol("test_features");
        knn.setK(4);
        knn.setPredictionCol("test_prediction");

        Assert.assertEquals("test_features", knn.getFeaturesCol());
        Assert.assertEquals("test_label", knn.getLabelCol());
        Assert.assertEquals(4L, knn.getK().intValue());
        Assert.assertEquals("test_prediction", knn.getPredictionCol());
    }

    @Test
    public void testOutputSchema() throws Exception {
        Knn knn = new Knn();
        knn.setLabelCol("test_label");
        knn.setFeaturesCol("test_features");
        knn.setK(4);
        knn.setPredictionCol("test_prediction");

        Table output =
                knn.fit(trainData.as("test_features, test_label"))
                        .transform(predictData.as("test_features, test_label"))[0];

        Assert.assertEquals(
                Arrays.asList("test_features", "test_label", "test_prediction"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFewerDistinctPointsThanCluster() throws Exception {
        Knn knn = new Knn();
        Table output = knn.fit(predictData).transform(predictData)[0];
        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        Knn knn = new Knn();
        Table output = knn.fit(trainData).transform(predictData)[0];
        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
    }

    @Test
    public void testInputTypeConversion() throws Exception {
        trainData = TestUtils.convertDataTypesToSparseInt(tEnv, trainData);
        predictData = TestUtils.convertDataTypesToSparseInt(tEnv, predictData);
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class, Integer.class},
                TestUtils.getColumnDataTypes(trainData));
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class, Integer.class},
                TestUtils.getColumnDataTypes(predictData));

        Knn knn = new Knn();
        Table output = knn.fit(trainData).transform(predictData)[0];
        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        Knn knn = new Knn();
        Knn reloadedKnn =
                TestUtils.saveAndReload(tEnv, knn, tempFolder.newFolder().getAbsolutePath());
        KnnModel model =
                TestUtils.saveAndReload(
                        tEnv,
                        reloadedKnn.fit(trainData),
                        tempFolder.newFolder().getAbsolutePath());

        Assert.assertEquals(
                Arrays.asList("packedFeatures", "featureNormSquares", "labels"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());

        Table output = model.transform(predictData)[0];
        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
    }

    @Test
    public void testModelSaveLoadAndPredict() throws Exception {
        Knn knn = new Knn();
        KnnModel model =
                TestUtils.saveAndReload(
                        tEnv, knn.fit(trainData), tempFolder.newFolder().getAbsolutePath());
        Table output = model.transform(predictData)[0];
        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelDataTable = new Knn().fit(trainData).getModelData()[0];
        List<String> columnNames = modelDataTable.getResolvedSchema().getColumnNames();
        Assert.assertEquals("packedFeatures", columnNames.get(0));
        Assert.assertEquals("featureNormSquares", columnNames.get(1));
        Assert.assertEquals("labels", columnNames.get(2));

        DataStream<Row> modelDataStream = tEnv.toDataStream(modelDataTable);
        @SuppressWarnings("unchecked") // commons-collections IteratorUtils returns a raw List
        List<Row> modelRows = IteratorUtils.toList(modelDataStream.executeAndCollect());
        Assert.assertFalse("Expected at least one model data row", modelRows.isEmpty());

        Row firstRow = modelRows.get(0);
        KnnModelData modelData =
                new KnnModelData(
                        (DenseMatrix) firstRow.getField(0),
                        (DenseVector) firstRow.getField(1),
                        (DenseVector) firstRow.getField(2));
        Assert.assertNotNull(modelData);
        // Features are 2-dimensional, so the packed matrix has 2 rows.
        Assert.assertEquals(2L, modelData.packedFeatures.numRows());
        Assert.assertEquals(modelData.packedFeatures.numCols(), modelData.labels.size());
        Assert.assertEquals(modelData.featureNormSquares.size(), modelData.labels.size());
    }

    @Test
    public void testSetModelData() throws Exception {
        Knn knn = new Knn();
        KnnModel trainedModel = knn.fit(trainData);
        KnnModel model = new KnnModel().setModelData(trainedModel.getModelData()[0]);
        ReadWriteUtils.updateExistingParams(model, trainedModel.getParamMap());

        Table output = model.transform(predictData)[0];
        verifyPredictionResult(output, knn.getLabelCol(), knn.getPredictionCol());
    }
}
