package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests for {@link VectorSlicer}: parameter defaults and setters, save/load round-trip, rejection of
 * invalid index lists, and slicing of dense and sparse input vectors.
 */
public class VectorSlicerTest extends AbstractTestBase {
    private StreamTableEnvironment tEnv;
    private Table inputDataTable;

    // Two input rows: (id, dense vector of size 6, sparse vector of size 5).
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(
                            0,
                            Vectors.dense(new double[] {2.1d, 3.1d, 2.3d, 3.4d, 5.3d, 5.1d}),
                            Vectors.sparse(
                                    5, new int[] {1, 3, 4}, new double[] {0.1d, 0.2d, 0.3d})),
                    Row.of(
                            1,
                            Vectors.dense(new double[] {2.3d, 4.1d, 1.3d, 2.4d, 5.1d, 4.1d}),
                            Vectors.sparse(
                                    5, new int[] {1, 2, 4}, new double[] {0.1d, 0.2d, 0.3d})));

    // Expected slices for indices {0, 1, 2}: dense rows 0 and 1, then sparse rows 0 and 1.
    private static final DenseVector EXPECTED_OUTPUT_DATA_1 =
            Vectors.dense(new double[] {2.1d, 3.1d, 2.3d});
    private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
            Vectors.dense(new double[] {2.3d, 4.1d, 1.3d});
    private static final SparseVector EXPECTED_OUTPUT_DATA_3 =
            Vectors.sparse(3, new int[] {1}, new double[] {0.1d});
    private static final SparseVector EXPECTED_OUTPUT_DATA_4 =
            Vectors.sparse(3, new int[] {1, 2}, new double[] {0.1d, 0.2d});

    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputDataTable =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA))
                        .as("id", new String[] {"vec", "sparseVec"});
    }

    /**
     * Builds a slicer that reads {@code inputCol}, writes {@code "sliceVec"} and keeps {@code
     * indices}. Any validation exception raised by the setters propagates to the caller.
     */
    private static VectorSlicer newSlicer(String inputCol, Integer... indices) {
        return (VectorSlicer)
                ((VectorSlicer)
                                ((VectorSlicer) new VectorSlicer().setInputCol(inputCol))
                                        .setOutputCol("sliceVec"))
                        .setIndices(indices);
    }

    /** Executes {@code table} and collects all of its rows into a list. */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
    private List<Row> collect(Table table) throws Exception {
        return IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
    }

    /**
     * Returns the id stored in field 0 of {@code row}. The value is unboxed explicitly so it is
     * compared by value rather than by boxed-reference identity.
     */
    private static int idOf(Row row) {
        return (Integer) row.getField(0);
    }

    /**
     * Checks that {@code output} holds exactly the two expected rows, comparing column {@code
     * outputCol} against the slice expected for indices {0, 1, 2}.
     *
     * @param isSparse whether the sliced input column was the sparse vector column
     */
    private void verifyOutputResult(Table output, String outputCol, boolean isSparse)
            throws Exception {
        List<Row> results = collect(output);
        Assert.assertEquals(2L, results.size());
        for (Row row : results) {
            int id = idOf(row);
            Object actual = row.getField(outputCol);
            switch (id) {
                case 0:
                    Assert.assertEquals(
                            isSparse ? EXPECTED_OUTPUT_DATA_3 : EXPECTED_OUTPUT_DATA_1, actual);
                    break;
                case 1:
                    Assert.assertEquals(
                            isSparse ? EXPECTED_OUTPUT_DATA_4 : EXPECTED_OUTPUT_DATA_2, actual);
                    break;
                default:
                    Assert.fail("Result id value is error, it must be 0 or 1, but was " + id);
            }
        }
    }

    @Test
    public void testParam() {
        VectorSlicer vectorSlicer = new VectorSlicer();
        Assert.assertEquals("input", vectorSlicer.getInputCol());
        Assert.assertEquals("output", vectorSlicer.getOutputCol());

        ((VectorSlicer) ((VectorSlicer) vectorSlicer.setInputCol("vec")).setOutputCol("sliceVec"))
                .setIndices(new Integer[] {0, 1, 2});
        Assert.assertEquals("vec", vectorSlicer.getInputCol());
        Assert.assertEquals("sliceVec", vectorSlicer.getOutputCol());
        Assert.assertArrayEquals(new Integer[] {0, 1, 2}, vectorSlicer.getIndices());
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        VectorSlicer reloaded =
                TestUtils.saveAndReload(
                        tEnv,
                        newSlicer("vec", 0, 1, 2),
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        VectorSlicer::load);
        verifyOutputResult(
                reloaded.transform(new Table[] {inputDataTable})[0],
                reloaded.getOutputCol(),
                false);
    }

    @Test
    public void testEmptyIndices() {
        try {
            newSlicer("vec").transform(new Table[] {inputDataTable});
            Assert.fail("Expected empty indices to be rejected.");
        } catch (Exception e) {
            Assert.assertEquals("Parameter indices is given an invalid value {}", e.getMessage());
        }
    }

    @Test
    public void testIndicesLargerThanVectorSize() {
        try {
            collect(newSlicer("vec", 1, 2, 10).transform(new Table[] {inputDataTable})[0]);
            Assert.fail("Expected an out-of-range index to fail at execution time.");
        } catch (Exception e) {
            Assert.assertEquals(
                    "Index value 10 is greater than vector size:6",
                    ExceptionUtils.getRootCause(e).getMessage());
        }
    }

    @Test
    public void testIndicesSmallerThanZero() {
        try {
            newSlicer("vec", 1, -2);
            Assert.fail("Expected a negative index to be rejected.");
        } catch (Exception e) {
            Assert.assertEquals(
                    "Parameter indices is given an invalid value {1,-2}", e.getMessage());
        }
    }

    @Test
    public void testDuplicateIndices() {
        try {
            newSlicer("vec", 1, 1, 3);
            Assert.fail("Expected duplicate indices to be rejected.");
        } catch (Exception e) {
            Assert.assertEquals(
                    "Parameter indices is given an invalid value {1,1,3}", e.getMessage());
        }
    }

    @Test
    public void testDenseTransform() throws Exception {
        VectorSlicer vectorSlicer = newSlicer("vec", 0, 1, 2);
        verifyOutputResult(
                vectorSlicer.transform(new Table[] {inputDataTable})[0],
                vectorSlicer.getOutputCol(),
                false);
    }

    @Test
    public void testDenseTransformWithUnorderedIndices() throws Exception {
        VectorSlicer vectorSlicer = newSlicer("vec", 0, 2, 1);
        List<Row> results = collect(vectorSlicer.transform(new Table[] {inputDataTable})[0]);
        Assert.assertEquals(2L, results.size());
        for (Row row : results) {
            int id = idOf(row);
            Object actual = row.getField(vectorSlicer.getOutputCol());
            // The output follows the order of the given indices, not ascending index order.
            switch (id) {
                case 0:
                    Assert.assertEquals(
                            Vectors.dense(new double[] {2.1d, 2.3d, 3.1d}), actual);
                    break;
                case 1:
                    Assert.assertEquals(
                            Vectors.dense(new double[] {2.3d, 1.3d, 4.1d}), actual);
                    break;
                default:
                    Assert.fail("Result id value is error, it must be 0 or 1, but was " + id);
            }
        }
    }

    @Test
    public void testSparseTransform() throws Exception {
        VectorSlicer vectorSlicer = newSlicer("sparseVec", 0, 1, 2);
        verifyOutputResult(
                vectorSlicer.transform(new Table[] {inputDataTable})[0],
                vectorSlicer.getOutputCol(),
                true);
    }
}
