package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.PipelineModel;
import org.apache.flink.ml.feature.stopwordsremover.StopWordsRemover;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/StopWordsRemoverTest.class */
/**
 * Tests for {@link StopWordsRemover}.
 *
 * <p>Most transform tests build an input table whose {@code expected} column(s) hold the token
 * arrays the remover should produce. The remover's output column(s) are then compared against
 * those expectations row by row.
 */
public class StopWordsRemoverTest extends AbstractTestBase {
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
    }

    /**
     * Creates a remover that reads {@code inputCols} and writes {@code outputCols}.
     *
     * <p>Setters are invoked as plain statements: they mutate the remover in place (see
     * {@link #testParams()}), so their return values are not needed.
     */
    private static StopWordsRemover newRemover(String[] inputCols, String[] outputCols) {
        StopWordsRemover remover = new StopWordsRemover();
        remover.setInputCols(inputCols);
        remover.setOutputCols(outputCols);
        return remover;
    }

    /** Creates a table from {@code rows} whose columns are named {@code firstName, otherNames...}. */
    private Table createTable(Row[] rows, String firstName, String... otherNames) {
        return tEnv.fromDataStream(env.fromElements(rows)).as(firstName, otherNames);
    }

    /** Creates a two-column table with columns named {@code raw} and {@code expected}. */
    private Table createRawExpectedTable(Row... rows) {
        return createTable(rows, "raw", "expected");
    }

    /**
     * Input shared by the tests that use the remover's default stop words: each row pairs a raw
     * token array with the tokens expected to survive filtering.
     */
    private Table createDefaultStopWordsTable() {
        return createRawExpectedTable(
                Row.of(new String[] {"test", "test"}, new String[] {"test", "test"}),
                Row.of(new String[] {"a", "b", "c", "d"}, new String[] {"b", "c", "d"}),
                Row.of(new String[] {"a", "the", "an"}, new String[0]),
                Row.of(new String[] {"A", "The", "AN"}, new String[0]),
                Row.of(new String[] {null}, new String[] {null}),
                Row.of(new String[0], new String[0]));
    }

    /**
     * Executes {@code table} and returns all of its rows.
     *
     * <p>The result iterator is closed via try-with-resources even if collection throws, so it
     * is never left open.
     */
    private static List<Row> collectRows(Table table) throws Exception {
        List<Row> rows = new ArrayList<>();
        try (CloseableIterator<Row> iterator = table.execute().collect()) {
            iterator.forEachRemaining(rows::add);
        }
        return rows;
    }

    /** Verifies the single-column layout: {@code expected} versus {@code filtered}. */
    private static void verifyOutputResult(StopWordsRemover stopWordsRemover, Table table)
            throws Exception {
        verifyOutputResult(
                stopWordsRemover, table, new String[] {"expected"}, new String[] {"filtered"});
    }

    /**
     * Transforms {@code table} with {@code stopWordsRemover}. For every output row, asserts that
     * each {@code outputCols[i]} equals {@code expectedCols[i]}. Also asserts that the output has
     * as many rows as the input.
     *
     * @param expectedCols columns holding the expected filtered tokens
     * @param outputCols remover output columns, parallel to {@code expectedCols}
     */
    private static void verifyOutputResult(
            StopWordsRemover stopWordsRemover,
            Table table,
            String[] expectedCols,
            String[] outputCols)
            throws Exception {
        Assert.assertEquals(expectedCols.length, outputCols.length);
        Table output = stopWordsRemover.transform(new Table[] {table})[0];
        int inputCount = collectRows(table).size();
        // Collect the output fully before asserting, so a failed assertion cannot leave the
        // result iterator open.
        List<Row> outputRows = collectRows(output);
        for (Row row : outputRows) {
            for (int i = 0; i < expectedCols.length; i++) {
                Assert.assertArrayEquals(
                        (String[]) row.getFieldAs(expectedCols[i]),
                        (String[]) row.getFieldAs(outputCols[i]));
            }
        }
        Assert.assertEquals(inputCount, outputRows.size());
    }

    @Test
    public void testParams() {
        StopWordsRemover remover = new StopWordsRemover();
        Assert.assertTrue(
                Arrays.asList(remover.getStopWords()).containsAll(Arrays.asList("i", "would")));
        Assert.assertTrue(
                Arrays.asList(Locale.US.toString(), Locale.getDefault().toString())
                        .contains(remover.getLocale()));
        Assert.assertFalse(remover.getCaseSensitive());

        remover.setInputCols(new String[] {"f1", "f2"});
        remover.setOutputCols(new String[] {"o1", "o2"});
        remover.setStopWords(StopWordsRemover.loadDefaultStopWords("turkish"));
        remover.setLocale(Locale.US.toString());
        remover.setCaseSensitive(true);

        Assert.assertArrayEquals(new String[] {"f1", "f2"}, remover.getInputCols());
        Assert.assertArrayEquals(new String[] {"o1", "o2"}, remover.getOutputCols());
        Assert.assertTrue(
                Arrays.asList(remover.getStopWords()).containsAll(Arrays.asList("acaba", "yani")));
        Assert.assertEquals(Locale.US.toString(), remover.getLocale());
        Assert.assertTrue(remover.getCaseSensitive());
    }

    @Test
    public void testOutputSchema() {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        Table input =
                createRawExpectedTable(
                        Row.of(new String[] {"test", "test"}, new String[] {"test", "test"}));
        Table output = remover.transform(new Table[] {input})[0];
        Assert.assertEquals(
                Arrays.asList("raw", "expected", "filtered"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransform() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        verifyOutputResult(remover, createDefaultStopWordsTable());
    }

    @Test
    public void testTransformWithStopWordsList() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setStopWords(new String[] {"test", "a", "an", "the", null});
        Table input =
                createRawExpectedTable(
                        Row.of(new String[] {"test", "test"}, new String[0]),
                        Row.of(new String[] {"a", "b", "c", "d"}, new String[] {"b", "c", "d"}),
                        Row.of(new String[] {"a", "the", "an"}, new String[0]),
                        Row.of(new String[] {"A", "The", "AN"}, new String[0]),
                        Row.of(new String[] {null}, new String[0]),
                        Row.of(new String[0], new String[0]));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testTransformWithLocaledInputCaseInsensitive() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setStopWords(new String[] {"milk", "cookie"});
        remover.setCaseSensitive(false);
        remover.setLocale("tr");
        Table input =
                createRawExpectedTable(
                        Row.of(new String[] {"mİlk", "and", "nuts"}, new String[] {"and", "nuts"}),
                        Row.of(
                                new String[] {"cookIe", "and", "nuts"},
                                new String[] {"cookIe", "and", "nuts"}),
                        Row.of(new String[] {null}, new String[] {null}),
                        Row.of(new String[0], new String[0]));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testTransformWithLocaledInputCaseSensitive() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setStopWords(new String[] {"milk", "cookie"});
        remover.setCaseSensitive(true);
        remover.setLocale("tr");
        Table input =
                createRawExpectedTable(
                        Row.of(
                                new String[] {"mİlk", "and", "nuts"},
                                new String[] {"mİlk", "and", "nuts"}),
                        Row.of(
                                new String[] {"cookIe", "and", "nuts"},
                                new String[] {"cookIe", "and", "nuts"}),
                        Row.of(new String[] {null}, new String[] {null}),
                        Row.of(new String[0], new String[0]));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testInvalidLocale() {
        try {
            new StopWordsRemover().setLocale("rt");
            Assert.fail("Expected IllegalArgumentException for locale rt");
        } catch (Exception e) {
            Assert.assertEquals(IllegalArgumentException.class, e.getClass());
            Assert.assertEquals("Parameter locale is given an invalid value rt", e.getMessage());
        }
    }

    @Test
    public void testAvailableLocales() {
        Assert.assertTrue(StopWordsRemover.getAvailableLocales().contains(Locale.US.toString()));
        StopWordsRemover remover = new StopWordsRemover();
        // Every advertised locale must be accepted by the setter without throwing.
        for (String locale : StopWordsRemover.getAvailableLocales()) {
            remover.setLocale(locale);
        }
    }

    @Test
    public void testTransformCaseSensitive() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setCaseSensitive(true);
        Table input =
                createRawExpectedTable(
                        Row.of(new String[] {"A"}, new String[] {"A"}),
                        Row.of(new String[] {"The", "the"}, new String[] {"The"}));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testDefaultStopWordsOfSupportedLanguagesNotEmtpy() {
        String[] languages = {
            "danish", "dutch", "english", "finnish", "french", "german", "hungarian",
            "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish"
        };
        for (String language : languages) {
            Assert.assertTrue(StopWordsRemover.loadDefaultStopWords(language).length > 0);
        }
    }

    @Test
    public void testTransformWithLanguageSelection() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setStopWords(StopWordsRemover.loadDefaultStopWords("turkish"));
        Table input =
                createRawExpectedTable(
                        Row.of(new String[] {"acaba", "ama", "biri"}, new String[0]),
                        Row.of(new String[] {"hep", "her", "scala"}, new String[] {"scala"}));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testTransformWithIgnoredWords() throws Exception {
        Table input =
                createRawExpectedTable(
                        Row.of(
                                new String[] {"python", "scala", "a"},
                                new String[] {"python", "scala", "a"}),
                        Row.of(
                                new String[] {"Python", "Scala", "swift"},
                                new String[] {"Python", "Scala", "swift"}));
        HashSet<String> stopWords =
                new HashSet<>(Arrays.asList(StopWordsRemover.loadDefaultStopWords("english")));
        stopWords.remove("a");
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setStopWords(stopWords.toArray(new String[0]));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testTransformWithAdditionalWords() throws Exception {
        Table input =
                createRawExpectedTable(
                        Row.of(new String[] {"python", "scala", "a"}, new String[0]),
                        Row.of(new String[] {"Python", "Scala", "swift"}, new String[] {"swift"}));
        HashSet<String> stopWords =
                new HashSet<>(Arrays.asList(StopWordsRemover.loadDefaultStopWords("english")));
        stopWords.addAll(Arrays.asList("python", "scala"));
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        remover.setStopWords(stopWords.toArray(new String[0]));
        verifyOutputResult(remover, input);
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        StopWordsRemover remover = newRemover(new String[] {"raw"}, new String[] {"filtered"});
        StopWordsRemover loadedRemover =
                TestUtils.saveAndReload(
                        tEnv,
                        remover,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        StopWordsRemover::load);
        verifyOutputResult(loadedRemover, createDefaultStopWordsTable());
    }

    @Test
    public void testOutputColumnAlreadyExists() {
        try {
            StopWordsRemover remover =
                    newRemover(new String[] {"raw"}, new String[] {"expected"});
            Table input =
                    createRawExpectedTable(
                            Row.of(new String[] {"The", "the", "swift"}, new String[] {"swift"}));
            remover.transform(new Table[] {input});
            Assert.fail("Expected ValidationException for duplicate output column");
        } catch (Exception e) {
            Assert.assertEquals(ValidationException.class, e.getClass());
            Assert.assertEquals("Ambiguous column name: expected", e.getMessage());
        }
    }

    @Test
    public void testTransformMultipleColumns() throws Exception {
        Table input =
                createTable(
                        new Row[] {
                            Row.of(
                                    new String[] {"test", "test"},
                                    new String[] {"test1", "test2"},
                                    new String[0],
                                    new String[] {"test1", "test2"}),
                            Row.of(
                                    new String[] {"a", "b", "c", "d"},
                                    new String[] {"a", "b"},
                                    new String[] {"b", "c", "d"},
                                    new String[] {"b"}),
                            Row.of(
                                    new String[] {"a", "the", "an"},
                                    new String[] {"a", "the", "test1"},
                                    new String[0],
                                    new String[] {"test1"}),
                            Row.of(
                                    new String[] {"A", "The", "AN"},
                                    new String[] {"A", "The", "AN"},
                                    new String[0],
                                    new String[0]),
                            Row.of(
                                    new String[] {null},
                                    new String[] {null},
                                    new String[] {null},
                                    new String[] {null}),
                            Row.of(new String[0], new String[0], new String[0], new String[0])
                        },
                        "raw1",
                        "raw2",
                        "expected1",
                        "expected2");
        StopWordsRemover remover =
                newRemover(
                        new String[] {"raw1", "raw2"}, new String[] {"filtered1", "filtered2"});
        remover.setStopWords(new String[] {"test", "a", "an", "the"});
        verifyOutputResult(
                remover,
                input,
                new String[] {"expected1", "expected2"},
                new String[] {"filtered1", "filtered2"});
    }

    @Test
    public void testCompareSingleMultipleRemoverInPipeline() throws Exception {
        Table input =
                createTable(
                        new Row[] {
                            Row.of(new String[] {"test", "test"}, new String[] {"test1", "test2"}),
                            Row.of(new String[] {"a", "b", "c", "d"}, new String[] {"a", "b"}),
                            Row.of(
                                    new String[] {"a", "the", "an"},
                                    new String[] {"a", "the", "test1"}),
                            Row.of(
                                    new String[] {"A", "The", "AN"},
                                    new String[] {"A", "The", "AN"}),
                            Row.of(new String[] {null}, new String[] {null}),
                            Row.of(new String[0], new String[0])
                        },
                        "input1",
                        "input2");

        // One remover handling both columns ...
        List<Stage<?>> singleRemoverStages =
                Collections.singletonList(
                        newRemover(
                                new String[] {"input1", "input2"},
                                new String[] {"output1", "output2"}));
        // ... must match two single-column removers applied in sequence.
        List<Stage<?>> multipleRemoverStages =
                Arrays.asList(
                        newRemover(new String[] {"input1"}, new String[] {"output1"}),
                        newRemover(new String[] {"input2"}, new String[] {"output2"}));

        Table singleOutput = new PipelineModel(singleRemoverStages).transform(new Table[] {input})[0];
        Table multipleOutput =
                new PipelineModel(multipleRemoverStages).transform(new Table[] {input})[0];

        Assert.assertEquals(
                new HashSet<>(collectRows(singleOutput)),
                new HashSet<>(collectRows(multipleOutput)));
    }

    @Test
    public void testMismatchInputOutputCols() {
        try {
            StopWordsRemover remover =
                    newRemover(new String[] {"raw"}, new String[] {"expected1", "expected2"});
            Table input =
                    createRawExpectedTable(
                            Row.of(new String[] {"The", "the", "swift"}, new String[] {"swift"}));
            remover.transform(new Table[] {input});
            Assert.fail("Expected IllegalArgumentException for mismatched input/output columns");
        } catch (Exception e) {
            Assert.assertEquals(IllegalArgumentException.class, e.getClass());
            Assert.assertNull(e.getMessage());
        }
    }
}
