package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.ngram.NGram;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/** Tests {@link NGram}. */
public class NGramTest extends AbstractTestBase {
    /**
     * Expected output of an {@link NGram} with default parameters applied to the input built in
     * {@link #before()}, ordered by increasing number of n-grams per row.
     */
    private static final List<Row> EXPECTED_OUTPUT =
            Arrays.asList(
                    Row.of((Object) new String[0]),
                    Row.of((Object) new String[] {"a b", "b c"}),
                    Row.of((Object) new String[] {"a b", "b c", "c d"}));

    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table inputDataTable;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        // The (Object) casts make each String[] a single row field; without them the array
        // would be passed as Row.of's varargs array and spread across several fields.
        List<Row> inputData =
                Arrays.asList(
                        Row.of((Object) new String[0]),
                        Row.of((Object) new String[] {"a", "b", "c"}),
                        Row.of((Object) new String[] {"a", "b", "c", "d"}));
        inputDataTable = tEnv.fromDataStream(env.fromCollection(inputData)).as("input");
    }

    @Test
    public void testParam() {
        NGram nGram = new NGram();
        Assert.assertEquals("input", nGram.getInputCol());
        Assert.assertEquals("output", nGram.getOutputCol());
        Assert.assertEquals(2, nGram.getN());

        nGram.setInputCol("testInputCol").setOutputCol("testOutputCol").setN(5);

        Assert.assertEquals("testInputCol", nGram.getInputCol());
        Assert.assertEquals("testOutputCol", nGram.getOutputCol());
        Assert.assertEquals(5, nGram.getN());
    }

    @Test
    public void testOutputSchema() {
        NGram nGram = new NGram();
        // Use a local table so this test does not overwrite the shared fixture field.
        Table inputTable =
                tEnv.fromDataStream(env.fromElements(Row.of(new String[] {""}, "")))
                        .as("input", "dummyInput");
        Table output = nGram.transform(inputTable)[0];
        Assert.assertEquals(
                Arrays.asList(nGram.getInputCol(), "dummyInput", nGram.getOutputCol()),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransform() throws Exception {
        NGram nGram = new NGram();
        Table output = nGram.transform(inputDataTable)[0];
        verifyOutputResult(output, nGram.getOutputCol());
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        NGram nGram = new NGram();
        NGram loadedNGram =
                TestUtils.saveAndReload(
                        tEnv, nGram, TEMPORARY_FOLDER.newFolder().getAbsolutePath(), NGram::load);
        Table output = loadedNGram.transform(inputDataTable)[0];
        verifyOutputResult(output, loadedNGram.getOutputCol());
    }

    /**
     * Collects column {@code outputCol} of {@code output} and compares it against {@link
     * #EXPECTED_OUTPUT}.
     *
     * <p>The collected rows are sorted by the length of their n-gram array before comparing, so
     * the check does not depend on the order in which rows are collected. The expected rows all
     * have distinct lengths, which makes this ordering unambiguous.
     *
     * @param output the transformed table
     * @param outputCol name of the column holding the {@code String[]} n-grams
     * @throws Exception if executing the collecting job fails
     */
    private void verifyOutputResult(Table output, String outputCol) throws Exception {
        Expression outputColumn = Expressions.$(outputCol);
        // IteratorUtils.toList returns a raw List; the selected table has exactly one Row column.
        @SuppressWarnings("unchecked")
        List<Row> results =
                IteratorUtils.toList(
                        tEnv.toDataStream(output.select(outputColumn)).executeAndCollect());

        Assert.assertEquals(EXPECTED_OUTPUT.size(), results.size());
        results.sort(Comparator.comparingInt(row -> ((String[]) row.getField(0)).length));
        for (int i = 0; i < EXPECTED_OUTPUT.size(); i++) {
            Assert.assertArrayEquals(
                    (String[]) EXPECTED_OUTPUT.get(i).getField(0),
                    (String[]) results.get(i).getField(0));
        }
    }
}
