package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/VectorAssemblerTest.class */
/**
 * Tests {@link VectorAssembler}: parameter accessors, the three {@code handleInvalid} modes
 * ("keep", "error", "skip"), save/reload, and input type conversion.
 */
public class VectorAssemblerTest extends AbstractTestBase {
    /** Message expected at the root of the failure raised in "error" mode for a null input. */
    private static final String NULL_INPUT_MESSAGE = "Input column value should not be null.";

    /** Name of the output column configured on every assembler built by {@link #newAssembler}. */
    private static final String OUTPUT_COL = "assembledVec";

    /**
     * Input rows with columns (id, vec, num, sparseVec). The row with id 2 carries only nulls in
     * its feature columns and is used to exercise the {@code handleInvalid} modes.
     */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(
                            new Object[] {
                                0,
                                Vectors.dense(new double[] {2.1d, 3.1d}),
                                Double.valueOf(1.0d),
                                Vectors.sparse(5, new int[] {3}, new double[] {1.0d})
                            }),
                    Row.of(
                            new Object[] {
                                1,
                                Vectors.dense(new double[] {2.1d, 3.1d}),
                                Double.valueOf(1.0d),
                                Vectors.sparse(
                                        5,
                                        new int[] {4, 2, 3, 1},
                                        new double[] {4.0d, 2.0d, 3.0d, 1.0d})
                            }),
                    Row.of(new Object[] {2, null, null, null}));

    /** Expected assembled vector for the row with id 0. */
    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
            Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1d, 3.1d, 1.0d, 1.0d});

    /** Expected assembled vector for the row with id 1. */
    private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
            Vectors.dense(new double[] {2.1d, 3.1d, 1.0d, 0.0d, 1.0d, 2.0d, 3.0d, 4.0d});

    private StreamTableEnvironment tEnv;
    private Table inputDataTable;

    /**
     * Creates a fresh table environment (parallelism 4, checkpointing every 100 ms, no restarts)
     * and rebuilds the input table from {@link #INPUT_DATA} before each test.
     */
    @Before
    public void before() {
        Configuration configuration = new Configuration();
        configuration.set(
                ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        StreamExecutionEnvironment env =
                StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        this.tEnv = StreamTableEnvironment.create(env);
        this.inputDataTable =
                this.tEnv
                        .fromDataStream(env.fromCollection(INPUT_DATA))
                        .as("id", new String[] {"vec", "num", "sparseVec"});
    }

    /**
     * Builds an assembler over the columns (vec, num, sparseVec) that writes to {@link
     * #OUTPUT_COL}.
     *
     * @param handleInvalid value passed to {@code setHandleInvalid}
     * @return the configured assembler
     */
    private static VectorAssembler newAssembler(String handleInvalid) {
        VectorAssembler assembler = new VectorAssembler();
        // The setters configure the receiver itself, so the returned references are not needed.
        assembler.setInputCols(new String[] {"vec", "num", "sparseVec"});
        assembler.setOutputCol(OUTPUT_COL);
        assembler.setHandleInvalid(handleInvalid);
        return assembler;
    }

    /**
     * Collects {@code table} and checks the row count and the assembled vector of every row.
     *
     * @param table output table of the assembler
     * @param outputCol name of the column holding the assembled vector
     * @param expectedCount number of rows the table must contain
     * @throws Exception if executing or collecting the table fails
     */
    private void verifyOutputResult(Table table, String outputCol, int expectedCount)
            throws Exception {
        // IteratorUtils.toList returns a raw List; the elements are the Rows of the stream.
        @SuppressWarnings("unchecked")
        List<Row> rows = IteratorUtils.toList(this.tEnv.toDataStream(table).executeAndCollect());
        Assert.assertEquals(expectedCount, rows.size());
        for (Row row : rows) {
            Object id = row.getField(0);
            // Compare by value: '==' between a boxed id and an int literal is a reference
            // comparison at best and must not be relied on.
            if (Integer.valueOf(0).equals(id)) {
                Assert.assertEquals(EXPECTED_OUTPUT_DATA_1, row.getField(outputCol));
            } else if (Integer.valueOf(1).equals(id)) {
                Assert.assertEquals(EXPECTED_OUTPUT_DATA_2, row.getField(outputCol));
            } else {
                Assert.assertNull(row.getField(outputCol));
            }
        }
    }

    /** Checks the default parameter values and that the setters are reflected by the getters. */
    @Test
    public void testParam() {
        VectorAssembler vectorAssembler = new VectorAssembler();
        Assert.assertEquals("error", vectorAssembler.getHandleInvalid());
        Assert.assertEquals("output", vectorAssembler.getOutputCol());

        vectorAssembler.setInputCols(new String[] {"vec", "num", "sparseVec"});
        vectorAssembler.setOutputCol("assembledVec");
        vectorAssembler.setHandleInvalid("skip");

        Assert.assertArrayEquals(
                new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols());
        Assert.assertEquals("skip", vectorAssembler.getHandleInvalid());
        Assert.assertEquals("assembledVec", vectorAssembler.getOutputCol());
    }

    /** In "keep" mode all three rows are emitted; the all-null row yields a null output. */
    @Test
    public void testKeepInvalid() throws Exception {
        VectorAssembler vectorAssembler = newAssembler("keep");
        Table output = vectorAssembler.transform(new Table[] {this.inputDataTable})[0];
        Assert.assertEquals(
                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
                output.getResolvedSchema().getColumnNames());
        verifyOutputResult(output, vectorAssembler.getOutputCol(), 3);
    }

    /** In "error" mode the job must fail with {@link #NULL_INPUT_MESSAGE} as root cause. */
    @Test
    public void testErrorInvalid() {
        Throwable failure = null;
        try {
            newAssembler("error")
                    .transform(new Table[] {this.inputDataTable})[0]
                    .execute()
                    .collect()
                    .next();
        } catch (Throwable th) {
            failure = th;
        }
        // Asserted outside the try block so that a missing failure is reported as such instead
        // of being caught and turned into a confusing message mismatch.
        if (failure == null) {
            Assert.fail("Expected IllegalArgumentException");
        }
        Assert.assertEquals(NULL_INPUT_MESSAGE, ExceptionUtils.getRootCause(failure).getMessage());
    }

    /** In "skip" mode the all-null row is dropped, leaving two output rows. */
    @Test
    public void testSkipInvalid() throws Exception {
        VectorAssembler vectorAssembler = newAssembler("skip");
        Table output = vectorAssembler.transform(new Table[] {this.inputDataTable})[0];
        Assert.assertEquals(
                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
                output.getResolvedSchema().getColumnNames());
        verifyOutputResult(output, vectorAssembler.getOutputCol(), 2);
    }

    /** A saved and reloaded assembler must produce the same output as the original. */
    @Test
    public void testSaveLoadAndTransform() throws Exception {
        VectorAssembler loaded =
                TestUtils.saveAndReload(
                        this.tEnv,
                        newAssembler("skip"),
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath());
        Table output = loaded.transform(new Table[] {this.inputDataTable})[0];
        verifyOutputResult(output, loaded.getOutputCol(), 2);
    }

    /** The assembler must give the same results after the input columns are type-converted. */
    @Test
    public void testInputTypeConversion() throws Exception {
        this.inputDataTable = TestUtils.convertDataTypesToSparseInt(this.tEnv, this.inputDataTable);
        Assert.assertArrayEquals(
                new Class[] {Integer.class, SparseVector.class, Integer.class, SparseVector.class},
                TestUtils.getColumnDataTypes(this.inputDataTable));

        VectorAssembler loaded =
                TestUtils.saveAndReload(
                        this.tEnv,
                        newAssembler("skip"),
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath());
        Table output = loaded.transform(new Table[] {this.inputDataTable})[0];
        verifyOutputResult(output, loaded.getOutputCol(), 2);
    }
}
