package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.regextokenizer.RegexTokenizer;
import org.apache.flink.ml.feature.regextokenizer.RegexTokenizerParams;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/RegexTokenizerTest.class */
public class RegexTokenizerTest extends AbstractTestBase {
    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table inputDataTable;
    private static final List<Row> INPUT = Arrays.asList(Row.of(new Object[]{"Test for tokenization."}), Row.of(new Object[]{"Te,st. punct"}));

    @Before
    public void before() {
        this.env = TestUtils.getExecutionEnvironment();
        this.tEnv = StreamTableEnvironment.create(this.env);
        this.inputDataTable = this.tEnv.fromDataStream(this.env.fromCollection(INPUT)).as("input", new String[0]);
    }

    @Test
    public void testParam() {
        RegexTokenizer regexTokenizer = new RegexTokenizer();
        Assert.assertEquals("input", regexTokenizer.getInputCol());
        Assert.assertEquals("output", regexTokenizer.getOutputCol());
        Assert.assertEquals(1L, regexTokenizer.getMinTokenLength());
        Assert.assertEquals(true, regexTokenizer.getGaps());
        Assert.assertEquals("\\s+", regexTokenizer.getPattern());
        Assert.assertEquals(true, regexTokenizer.getToLowercase());
        ((RegexTokenizer) ((RegexTokenizer) ((RegexTokenizer) ((RegexTokenizer) ((RegexTokenizer) regexTokenizer.setInputCol("testInputCol")).setOutputCol("testOutputCol")).setMinTokenLength(3)).setGaps(false)).setPattern("\\s")).setToLowercase(false);
        Assert.assertEquals("testInputCol", regexTokenizer.getInputCol());
        Assert.assertEquals("testOutputCol", regexTokenizer.getOutputCol());
        Assert.assertEquals(3L, regexTokenizer.getMinTokenLength());
        Assert.assertEquals(false, regexTokenizer.getGaps());
        Assert.assertEquals("\\s", regexTokenizer.getPattern());
        Assert.assertEquals(false, regexTokenizer.getToLowercase());
    }

    @Test
    public void testOutputSchema() {
        RegexTokenizer regexTokenizer = new RegexTokenizer();
        this.inputDataTable = this.tEnv.fromDataStream(this.env.fromElements(new Row[]{Row.of(new Object[]{"", ""})})).as("input", new String[]{"dummyInput"});
        Assert.assertEquals(Arrays.asList(regexTokenizer.getInputCol(), "dummyInput", regexTokenizer.getOutputCol()), regexTokenizer.transform(new Table[]{this.inputDataTable})[0].getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransform() throws Exception {
        int intValue = ((Integer) RegexTokenizerParams.MIN_TOKEN_LENGTH.defaultValue).intValue();
        boolean booleanValue = ((Boolean) RegexTokenizerParams.GAPS.defaultValue).booleanValue();
        String str = (String) RegexTokenizerParams.PATTERN.defaultValue;
        checkTransform(intValue, booleanValue, str, ((Boolean) RegexTokenizerParams.TO_LOWERCASE.defaultValue).booleanValue(), Arrays.asList(Row.of(new Object[]{new String[]{"test", "for", "tokenization."}}), Row.of(new Object[]{new String[]{"te,st.", "punct"}})));
        checkTransform(intValue, booleanValue, str, false, Arrays.asList(Row.of(new Object[]{new String[]{"Test", "for", "tokenization."}}), Row.of(new Object[]{new String[]{"Te,st.", "punct"}})));
        checkTransform(intValue, false, "\\w+|\\p{Punct}", true, Arrays.asList(Row.of(new Object[]{new String[]{"test", "for", "tokenization", "."}}), Row.of(new Object[]{new String[]{"te", ",", "st", ".", "punct"}})));
        checkTransform(3, false, "\\w+|\\p{Punct}", true, Arrays.asList(Row.of(new Object[]{new String[]{"test", "for", "tokenization"}}), Row.of(new Object[]{new String[]{"punct"}})));
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        RegexTokenizer regexTokenizer = new RegexTokenizer();
        List<Row> asList = Arrays.asList(Row.of(new Object[]{new String[]{"test", "for", "tokenization."}}), Row.of(new Object[]{new String[]{"te,st.", "punct"}}));
        RegexTokenizer saveAndReload = TestUtils.saveAndReload(this.tEnv, regexTokenizer, TEMPORARY_FOLDER.newFolder().getAbsolutePath(), RegexTokenizer::load);
        verifyOutputResult(saveAndReload.transform(new Table[]{this.inputDataTable})[0], saveAndReload.getOutputCol(), asList);
    }

    private void checkTransform(int i, boolean z, String str, boolean z2, List<Row> list) throws Exception {
        RegexTokenizer regexTokenizer = (RegexTokenizer) ((RegexTokenizer) ((RegexTokenizer) ((RegexTokenizer) new RegexTokenizer().setMinTokenLength(i)).setGaps(z)).setPattern(str)).setToLowercase(z2);
        verifyOutputResult(regexTokenizer.transform(new Table[]{this.inputDataTable})[0], regexTokenizer.getOutputCol(), list);
    }

    private void verifyOutputResult(Table table, String str, List<Row> list) throws Exception {
        List list2 = IteratorUtils.toList(this.tEnv.toDataStream(table.select(new Expression[]{Expressions.$(str)})).executeAndCollect());
        Assert.assertEquals(list.size(), list2.size());
        list2.sort(Comparator.comparingInt(row -> {
            return ((String[]) row.getField(0))[0].hashCode();
        }));
        list.sort(Comparator.comparingInt(row2 -> {
            return ((String[]) row2.getField(0))[0].hashCode();
        }));
        for (int i = 0; i < list.size(); i++) {
            Assert.assertArrayEquals((String[]) ((Row) list2.get(i)).getField(0), (String[]) list.get(i).getField(0));
        }
    }
}
