package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.feature.sqltransformer.SQLTransformer;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/** Tests {@link SQLTransformer}. */
public class SQLTransformerTest extends AbstractTestBase {
    /** Statement that appends the sum and the product of v1 and v2 as new columns. */
    private static final String NUMERIC_STATEMENT =
            "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__";

    /** Input rows; the tables built in this test name the columns (id, v1, v2). */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(0, 1.0, 3.0),
                    Row.of(1, 2.0, 3.0),
                    Row.of(2, 2.0, 2.0),
                    Row.of(3, 4.0, 2.0));

    /** Expected result of {@link #NUMERIC_STATEMENT}: input columns plus v1 + v2 and v1 * v2. */
    private static final List<Row> EXPECTED_NUMERIC_DATA_OUTPUT =
            Arrays.asList(
                    Row.of(0, 1.0, 3.0, 4.0, 3.0),
                    Row.of(1, 2.0, 3.0, 5.0, 6.0),
                    Row.of(2, 2.0, 2.0, 4.0, 4.0),
                    Row.of(3, 4.0, 2.0, 6.0, 8.0));

    /** Expected result of appending {@code SQRT(v1)} to every input row. */
    private static final List<Row> EXPECTED_BUILT_IN_FUNCTION_OUTPUT =
            Arrays.asList(
                    Row.of(0, 1.0, 3.0, 1.0),
                    Row.of(1, 2.0, 3.0, Math.sqrt(2.0)),
                    Row.of(2, 2.0, 2.0, Math.sqrt(2.0)),
                    Row.of(3, 4.0, 2.0, 2.0));

    /** Expected (v2, SUM(v1)) pairs when grouping the input by v2. */
    private static final List<Row> EXPECTED_GROUP_BY_AGGREGATION_OUTPUT =
            Arrays.asList(Row.of(3.0, 3.0), Row.of(2.0, 6.0));

    /** Expected SUM(v1) when all input rows fall into the same 10-minute tumbling window. */
    private static final List<Row> EXPECTED_WINDOW_AGGREGATION_OUTPUT =
            Collections.singletonList(Row.of(9.0));

    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table inputTable;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        INPUT_DATA,
                                        new RowTypeInfo(
                                                new TypeInformation<?>[] {
                                                    Types.INT, Types.DOUBLE, Types.DOUBLE
                                                })))
                        .as("id", "v1", "v2");
    }

    @Test
    public void testParam() {
        SQLTransformer sqlTransformer = new SQLTransformer();
        sqlTransformer.setStatement("SELECT * FROM __THIS__");
        Assert.assertEquals("SELECT * FROM __THIS__", sqlTransformer.getStatement());
    }

    @Test
    public void testInvalidSQLStatement() {
        try {
            createTransformer("SELECT * FROM __THAT__");
            Assert.fail("Expected an exception for an invalid statement");
        } catch (Exception e) {
            Assert.assertEquals(
                    "Parameter statement is given an invalid value SELECT * FROM __THAT__",
                    e.getMessage());
        }
    }

    @Test
    public void testOutputSchema() {
        Table outputTable = createTransformer(NUMERIC_STATEMENT).transform(inputTable)[0];
        Assert.assertEquals(
                Arrays.asList("id", "v1", "v2", "v3", "v4"),
                outputTable.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransformNumericData() {
        Table outputTable = createTransformer(NUMERIC_STATEMENT).transform(inputTable)[0];
        verifyOutputResult(outputTable, EXPECTED_NUMERIC_DATA_OUTPUT);
    }

    @Test
    public void testBuiltInFunction() {
        Table outputTable =
                createTransformer("SELECT *, SQRT(v1) AS v3 FROM __THIS__")
                        .transform(inputTable)[0];
        verifyOutputResult(outputTable, EXPECTED_BUILT_IN_FUNCTION_OUTPUT);
    }

    @Test
    public void testGroupByAggregation() {
        Table outputTable =
                createTransformer("SELECT v2, SUM(v1) AS v3 FROM __THIS__ GROUP BY v2")
                        .transform(inputTable)[0];
        verifyOutputResult(outputTable, EXPECTED_GROUP_BY_AGGREGATION_OUTPUT);
    }

    @Test
    public void testWindowAggregation() {
        // Derive an event-time column from id (id seconds since epoch) so the query can apply a
        // tumbling window; a local table avoids mutating the shared fixture field.
        Table timedInputTable =
                tEnv.fromDataStream(
                        env.fromCollection(
                                INPUT_DATA,
                                new RowTypeInfo(
                                        new TypeInformation<?>[] {
                                            Types.INT, Types.DOUBLE, Types.DOUBLE
                                        },
                                        new String[] {"id", "v1", "v2"})),
                        Schema.newBuilder()
                                .column("id", DataTypes.INT())
                                .column("v1", DataTypes.DOUBLE())
                                .column("v2", DataTypes.DOUBLE())
                                .columnByExpression("time_ltz", "TO_TIMESTAMP_LTZ(id * 1000, 3)")
                                .watermark("time_ltz", "time_ltz - INTERVAL '5' SECOND")
                                .build());

        Table outputTable =
                createTransformer(
                                "SELECT SUM(v1) AS v3 FROM TABLE(TUMBLE(TABLE __THIS__, DESCRIPTOR(time_ltz), INTERVAL '10' MINUTES)) GROUP BY window_start, window_end")
                        .transform(timedInputTable)[0];
        verifyOutputResult(outputTable, EXPECTED_WINDOW_AGGREGATION_OUTPUT);
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        SQLTransformer sqlTransformer = createTransformer(NUMERIC_STATEMENT);
        SQLTransformer loadedSQLTransformer =
                TestUtils.saveAndReload(
                        tEnv,
                        sqlTransformer,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        SQLTransformer::load);
        Table outputTable = loadedSQLTransformer.transform(inputTable)[0];
        verifyOutputResult(outputTable, EXPECTED_NUMERIC_DATA_OUTPUT);
    }

    /**
     * Creates a transformer configured with the given statement.
     *
     * <p>{@code setStatement} updates the transformer in place (see {@link #testParam()}), so the
     * same instance is returned.
     */
    private static SQLTransformer createTransformer(String statement) {
        SQLTransformer sqlTransformer = new SQLTransformer();
        sqlTransformer.setStatement(statement);
        return sqlTransformer;
    }

    /**
     * Executes {@code table} and asserts that it produces exactly {@code expectedOutput}, ignoring
     * order but respecting how many times each row occurs.
     */
    private static void verifyOutputResult(Table table, List<Row> expectedOutput) {
        @SuppressWarnings("unchecked") // commons-collections IteratorUtils.toList returns a raw List.
        List<Row> actualOutput = IteratorUtils.toList(table.execute().collect());
        // Compare multisets: a plain HashSet comparison would hide duplicated or missing rows.
        Assert.assertEquals(countRows(expectedOutput), countRows(actualOutput));
    }

    /** Maps every distinct row to the number of times it occurs in {@code rows}. */
    private static Map<Row, Long> countRows(List<Row> rows) {
        return rows.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }
}
