package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/feature/CountVectorizerTest.class */
public class CountVectorizerTest extends AbstractTestBase {

    /** Provides scratch directories; used by the save/load test to obtain fresh folders. */
    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** Execution environment, re-created in {@link #before()} for every test. */
    private StreamExecutionEnvironment env;

    /** Table environment bound to {@link #env}. */
    private StreamTableEnvironment tEnv;

    /** {@link #INPUT_DATA} exposed as a table with a single column named "input". */
    private Table inputTable;

    /** Delta used when comparing double-valued parameters in assertions. */
    private static final double EPS = 1.0E-5d;

    /**
     * Five training documents. Each row has exactly one field holding a {@code String[]} of
     * tokens; the {@code (Object)} cast keeps the array from being spread over several fields by
     * the varargs call.
     */
    private static final List<Row> INPUT_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of((Object) new String[] {"a", "c", "b", "c"}),
                            Row.of((Object) new String[] {"c", "d", "e"}),
                            Row.of((Object) new String[] {"a", "b", "c"}),
                            Row.of((Object) new String[] {"e", "f"}),
                            Row.of((Object) new String[] {"a", "c", "a"})));

    /**
     * Vectors expected when a default-configured {@link CountVectorizer} is fitted on and applied
     * to {@link #INPUT_DATA}.
     *
     * <p>Kept as a mutable {@link ArrayList} like the original; NOTE(review): the comparison
     * helper receives this list directly and presumably may reorder it - confirm before making it
     * immutable.
     */
    private static final List<SparseVector> EXPECTED_OUTPUT =
            new ArrayList<>(
                    Arrays.asList(
                            Vectors.sparse(
                                    6,
                                    IntStream.of(0, 1, 2).toArray(),
                                    DoubleStream.of(2.0d, 1.0d, 1.0d).toArray()),
                            Vectors.sparse(
                                    6,
                                    IntStream.of(0, 3, 4).toArray(),
                                    DoubleStream.of(1.0d, 1.0d, 1.0d).toArray()),
                            Vectors.sparse(
                                    6,
                                    IntStream.of(0, 1, 2).toArray(),
                                    DoubleStream.of(1.0d, 1.0d, 1.0d).toArray()),
                            Vectors.sparse(
                                    6,
                                    IntStream.of(3, 5).toArray(),
                                    DoubleStream.of(1.0d, 1.0d).toArray()),
                            Vectors.sparse(
                                    6,
                                    IntStream.of(0, 1).toArray(),
                                    DoubleStream.of(1.0d, 2.0d).toArray())));

    /**
     * Builds a fresh execution environment and table environment for each test and registers
     * {@link #INPUT_DATA} as a table whose only column is named "input".
     */
    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);

        Table unnamed = tEnv.fromDataStream(env.fromCollection(INPUT_DATA));
        inputTable = unnamed.as("input");
    }

    /**
     * Collects one column of a result table and compares it with the expected vectors.
     *
     * @param table table produced by a transform call
     * @param str name of the column that holds the {@link SparseVector} results
     * @param list vectors the column is expected to contain
     * @throws Exception if executing the job or collecting its results fails
     */
    private static void verifyPredictionResult(Table table, String str, List<SparseVector> list) throws Exception {
        // toDataStream is a StreamTableEnvironment method, so the environment obtained from the
        // table has to be narrowed before it can be called.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();

        // IteratorUtils.toList returns a raw List; every element comes from the map below and is
        // therefore a SparseVector.
        @SuppressWarnings("unchecked")
        List<SparseVector> actual =
                IteratorUtils.toList(
                        tableEnv.toDataStream(table)
                                .map(row -> (SparseVector) row.getField(str))
                                .executeAndCollect());

        TestBaseUtils.compareResultCollections(list, actual, TestUtils::compare);
    }

    /** Checks the default parameter values and that every setter is reflected by its getter. */
    @Test
    public void testParam() {
        CountVectorizer vectorizer = new CountVectorizer();

        // Defaults of a freshly constructed estimator.
        Assert.assertEquals("input", vectorizer.getInputCol());
        Assert.assertEquals("output", vectorizer.getOutputCol());
        // 9.223372036854776E18 is Long.MAX_VALUE widened to double.
        Assert.assertEquals(9.223372036854776E18d, vectorizer.getMaxDF(), EPS);
        Assert.assertEquals(1.0d, vectorizer.getMinDF(), EPS);
        Assert.assertEquals(1.0d, vectorizer.getMinTF(), EPS);
        // 262144 == 1 << 18.
        Assert.assertEquals(262144L, vectorizer.getVocabularySize());
        Assert.assertFalse(vectorizer.getBinary());

        // Override every parameter, in the same order as the original fluent chain.
        vectorizer.setInputCol("test_input");
        vectorizer.setOutputCol("test_output");
        vectorizer.setMinDF(0.1d);
        vectorizer.setMaxDF(0.9d);
        vectorizer.setMinTF(10.0d);
        vectorizer.setVocabularySize(1000);
        vectorizer.setBinary(true);

        Assert.assertEquals("test_input", vectorizer.getInputCol());
        Assert.assertEquals("test_output", vectorizer.getOutputCol());
        Assert.assertEquals(0.9d, vectorizer.getMaxDF(), EPS);
        Assert.assertEquals(0.1d, vectorizer.getMinDF(), EPS);
        Assert.assertEquals(10.0d, vectorizer.getMinTF(), EPS);
        Assert.assertEquals(1000L, vectorizer.getVocabularySize());
        Assert.assertTrue(vectorizer.getBinary());
    }

    /**
     * Checks that inconsistent minDF/maxDF settings are rejected with the message
     * "maxDF must be >= minDF.".
     *
     * <p>For the first two settings the message is expected on the exception thrown by
     * {@code fit} itself; for the last two it is expected on the root cause of the exception
     * raised while executing the transformed table.
     *
     * <p>Only {@link Exception} is caught: catching {@code Throwable} would also catch the
     * {@link AssertionError} raised by {@link Assert#fail(String)} and turn a missing exception
     * into a confusing message mismatch.
     */
    @Test
    public void testInvalidMinMaxDF() {
        final String expectedMessage = "maxDF must be >= minDF.";
        CountVectorizer countVectorizer = new CountVectorizer();

        countVectorizer.setMaxDF(0.1d);
        countVectorizer.setMinDF(0.2d);
        try {
            countVectorizer.fit(this.inputTable);
            Assert.fail("Expected fit to fail for maxDF=0.1, minDF=0.2");
        } catch (Exception e) {
            Assert.assertEquals(expectedMessage, e.getMessage());
        }

        countVectorizer.setMaxDF(1.0d);
        countVectorizer.setMinDF(2.0d);
        try {
            countVectorizer.fit(this.inputTable);
            Assert.fail("Expected fit to fail for maxDF=1.0, minDF=2.0");
        } catch (Exception e) {
            Assert.assertEquals(expectedMessage, e.getMessage());
        }

        countVectorizer.setMaxDF(1.0d);
        countVectorizer.setMinDF(0.9d);
        try {
            countVectorizer.fit(this.inputTable).transform(this.inputTable)[0].execute().print();
            Assert.fail("Expected execution to fail for maxDF=1.0, minDF=0.9");
        } catch (Exception e) {
            Assert.assertEquals(expectedMessage, ExceptionUtils.getRootCause(e).getMessage());
        }

        countVectorizer.setMaxDF(0.1d);
        countVectorizer.setMinDF(10.0d);
        try {
            countVectorizer.fit(this.inputTable).transform(this.inputTable)[0].execute().print();
            Assert.fail("Expected execution to fail for maxDF=0.1, minDF=10.0");
        } catch (Exception e) {
            Assert.assertEquals(expectedMessage, ExceptionUtils.getRootCause(e).getMessage());
        }
    }

    /** Checks that the transformed table exposes exactly the configured input and output columns. */
    @Test
    public void testOutputSchema() {
        CountVectorizer vectorizer = new CountVectorizer();
        vectorizer.setInputCol("test_input");
        vectorizer.setOutputCol("test_output");

        CountVectorizerModel model = vectorizer.fit(this.inputTable);
        Table renamedInput = this.inputTable.as("test_input");
        Table output = model.transform(renamedInput)[0];

        List<String> expectedColumns = Arrays.asList("test_input", "test_output");
        Assert.assertEquals(expectedColumns, output.getResolvedSchema().getColumnNames());
    }

    /** Fits with default parameters and checks the transform result against the expected vectors. */
    @Test
    public void testFitAndPredict() throws Exception {
        CountVectorizer vectorizer = new CountVectorizer();
        CountVectorizerModel model = vectorizer.fit(this.inputTable);
        Table output = model.transform(this.inputTable)[0];

        verifyPredictionResult(output, vectorizer.getOutputCol(), EXPECTED_OUTPUT);
    }

    /**
     * Round-trips both the estimator and the fitted model through save/reload, then checks the
     * model data schema and the transform result.
     */
    @Test
    public void testSaveLoadAndPredict() throws Exception {
        CountVectorizer vectorizer = new CountVectorizer();

        // Estimator round trip, then fit with the reloaded instance.
        String estimatorPath = this.tempFolder.newFolder().getAbsolutePath();
        CountVectorizer reloadedVectorizer =
                TestUtils.saveAndReload(this.tEnv, vectorizer, estimatorPath, CountVectorizer::load);
        CountVectorizerModel fittedModel = reloadedVectorizer.fit(this.inputTable);

        // Model round trip.
        String modelPath = this.tempFolder.newFolder().getAbsolutePath();
        CountVectorizerModel reloadedModel =
                TestUtils.saveAndReload(this.tEnv, fittedModel, modelPath, CountVectorizerModel::load);

        Table modelData = reloadedModel.getModelData()[0];
        Assert.assertEquals(
                Arrays.asList("vocabulary"), modelData.getResolvedSchema().getColumnNames());

        Table output = reloadedModel.transform(this.inputTable)[0];
        verifyPredictionResult(output, vectorizer.getOutputCol(), EXPECTED_OUTPUT);
    }

    /**
     * Checks that fitting on an input with no rows fails with the root-cause message
     * "The training set is empty.".
     *
     * <p>The empty input is produced by filtering {@link #INPUT_DATA} for rows of arity 0; every
     * row there has one field, so nothing passes the filter.
     *
     * <p>Only {@link Exception} is caught so that the {@link AssertionError} raised by
     * {@link Assert#fail(String)} is reported as-is when no exception occurs.
     */
    @Test
    public void testFitOnEmptyData() {
        try {
            Table emptyInput =
                    this.tEnv
                            .fromDataStream(
                                    this.env
                                            .fromCollection(INPUT_DATA)
                                            .filter(row -> row.getArity() == 0))
                            .as("input");
            new CountVectorizer().fit(emptyInput).getModelData()[0].execute().print();
            Assert.fail("Expected fitting on an empty training set to fail");
        } catch (Exception e) {
            Assert.assertEquals(
                    "The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
        }
    }

    /**
     * Checks the output when minDF/maxDF are set, first to 2.0/4.0 and then to 0.4/0.8; both
     * settings are expected to produce the same size-4 vectors.
     */
    @Test
    public void testMinMaxDF() throws Exception {
        // Typed list instead of a raw ArrayList, so the call below needs no unchecked conversion.
        List<SparseVector> expectedOutput =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.sparse(
                                        4,
                                        IntStream.of(0, 1, 2).toArray(),
                                        DoubleStream.of(2.0d, 1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        4,
                                        IntStream.of(0, 3).toArray(),
                                        DoubleStream.of(1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        4,
                                        IntStream.of(0, 1, 2).toArray(),
                                        DoubleStream.of(1.0d, 1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        4,
                                        IntStream.of(3).toArray(),
                                        DoubleStream.of(1.0d).toArray()),
                                Vectors.sparse(
                                        4,
                                        IntStream.of(0, 1).toArray(),
                                        DoubleStream.of(1.0d, 2.0d).toArray())));

        CountVectorizer countVectorizer = new CountVectorizer();
        countVectorizer.setMinDF(2.0d);
        countVectorizer.setMaxDF(4.0d);
        Table output = countVectorizer.fit(this.inputTable).transform(this.inputTable)[0];
        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);

        countVectorizer.setMinDF(0.4d);
        countVectorizer.setMaxDF(0.8d);
        output = countVectorizer.fit(this.inputTable).transform(this.inputTable)[0];
        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
    }

    /** Checks the output when minTF is set to 0.5. */
    @Test
    public void testMinTF() throws Exception {
        // Typed list instead of a raw ArrayList, so the call below needs no unchecked conversion.
        List<SparseVector> expectedOutput =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.sparse(
                                        6,
                                        IntStream.of(0).toArray(),
                                        DoubleStream.of(2.0d).toArray()),
                                Vectors.sparse(6, new int[0], new double[0]),
                                Vectors.sparse(6, new int[0], new double[0]),
                                Vectors.sparse(
                                        6,
                                        IntStream.of(3, 5).toArray(),
                                        DoubleStream.of(1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        6,
                                        IntStream.of(1).toArray(),
                                        DoubleStream.of(2.0d).toArray())));

        CountVectorizer countVectorizer = new CountVectorizer();
        countVectorizer.setMinTF(0.5d);
        Table output = countVectorizer.fit(this.inputTable).transform(this.inputTable)[0];
        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
    }

    /** Checks the output when the binary flag is enabled: every expected entry is 1.0. */
    @Test
    public void testBinary() throws Exception {
        // Typed list instead of a raw ArrayList, so the call below needs no unchecked conversion.
        List<SparseVector> expectedOutput =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.sparse(
                                        6,
                                        IntStream.of(0, 1, 2).toArray(),
                                        DoubleStream.of(1.0d, 1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        6,
                                        IntStream.of(0, 3, 4).toArray(),
                                        DoubleStream.of(1.0d, 1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        6,
                                        IntStream.of(0, 1, 2).toArray(),
                                        DoubleStream.of(1.0d, 1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        6,
                                        IntStream.of(3, 5).toArray(),
                                        DoubleStream.of(1.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        6,
                                        IntStream.of(0, 1).toArray(),
                                        DoubleStream.of(1.0d, 1.0d).toArray())));

        CountVectorizer countVectorizer = new CountVectorizer();
        countVectorizer.setBinary(true);
        Table output = countVectorizer.fit(this.inputTable).transform(this.inputTable)[0];
        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
    }

    /** Checks the output when the vocabulary size is set to 2: expected vectors have size 2. */
    @Test
    public void testVocabularySize() throws Exception {
        // Typed list instead of a raw ArrayList, so the call below needs no unchecked conversion.
        List<SparseVector> expectedOutput =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.sparse(
                                        2,
                                        IntStream.of(0, 1).toArray(),
                                        DoubleStream.of(2.0d, 1.0d).toArray()),
                                Vectors.sparse(
                                        2,
                                        IntStream.of(0).toArray(),
                                        DoubleStream.of(1.0d).toArray()),
                                Vectors.sparse(
                                        2,
                                        IntStream.of(0, 1).toArray(),
                                        DoubleStream.of(1.0d, 1.0d).toArray()),
                                Vectors.sparse(2, new int[0], new double[0]),
                                Vectors.sparse(
                                        2,
                                        IntStream.of(0, 1).toArray(),
                                        DoubleStream.of(1.0d, 2.0d).toArray())));

        CountVectorizer countVectorizer = new CountVectorizer();
        countVectorizer.setVocabularySize(2);
        Table output = countVectorizer.fit(this.inputTable).transform(this.inputTable)[0];
        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
    }

    /**
     * Checks that the fitted model data has a single "vocabulary" column and that the first
     * collected row holds the expected vocabulary array.
     */
    @Test
    public void testGetModelData() throws Exception {
        CountVectorizerModel model = new CountVectorizer().fit(this.inputTable);
        Table modelData = model.getModelData()[0];

        Assert.assertEquals(
                Arrays.asList("vocabulary"), modelData.getResolvedSchema().getColumnNames());

        String[] expectedVocabulary = {"c", "a", "b", "e", "d", "f"};
        List<?> collected =
                IteratorUtils.toList(this.tEnv.toDataStream(modelData).executeAndCollect());
        Row firstRow = (Row) collected.get(0);
        String[] actualVocabulary = (String[]) firstRow.getField(0);
        Assert.assertArrayEquals(expectedVocabulary, actualVocabulary);
    }

    /**
     * Copies the model data of a fitted model into a new {@link CountVectorizerModel} and checks
     * that it produces the expected vectors.
     */
    @Test
    public void testSetModelData() throws Exception {
        CountVectorizer vectorizer = new CountVectorizer();
        Table modelData = vectorizer.fit(this.inputTable).getModelData()[0];

        Table output =
                new CountVectorizerModel().setModelData(modelData).transform(this.inputTable)[0];

        verifyPredictionResult(output, vectorizer.getOutputCol(), EXPECTED_OUTPUT);
    }

    /**
     * Rebuilds one of this class's two serializable lambdas from its serialized description.
     *
     * <p>The lambda is selected by implementation method name and is only returned when the
     * functional interface, its method name and signature, the implementing class and the
     * implementation signature all match; anything else is rejected.
     *
     * <p>NOTE(review): this method has the shape of the hook javac synthesizes for serializable
     * lambdas and looks like decompiler output. It is probably redundant in a source file and a
     * candidate for removal - confirm against the build.
     *
     * @param serializedLambda serialized description of the lambda to recreate
     * @return the recreated lambda
     * @throws IllegalArgumentException if the description matches neither known lambda
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        // A string switch is the source-level form of the hashCode()/equals() dispatch the
        // decompiler emitted (which used an invalid boolean selector).
        switch (serializedLambda.getImplMethodName()) {
            case "lambda$verifyPredictionResult$6304b8d5$1":
                // Kind 6 is MethodHandleInfo.REF_invokeStatic.
                if (serializedLambda.getImplMethodKind() == 6
                        && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction")
                        && serializedLambda.getFunctionalInterfaceMethodName().equals("map")
                        && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;")
                        && serializedLambda.getImplClass().equals("org/apache/flink/ml/feature/CountVectorizerTest")
                        && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/String;Lorg/apache/flink/types/Row;)Lorg/apache/flink/ml/linalg/SparseVector;")) {
                    final String str = (String) serializedLambda.getCapturedArg(0);
                    // A lambda needs a functional-interface target type; Object alone is not one.
                    return (org.apache.flink.api.common.functions.MapFunction<Row, SparseVector>)
                            row -> (SparseVector) row.getField(str);
                }
                break;
            case "lambda$testFitOnEmptyData$3dedd8cf$1":
                if (serializedLambda.getImplMethodKind() == 6
                        && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/FilterFunction")
                        && serializedLambda.getFunctionalInterfaceMethodName().equals("filter")
                        && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Z")
                        && serializedLambda.getImplClass().equals("org/apache/flink/ml/feature/CountVectorizerTest")
                        && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/types/Row;)Z")) {
                    return (org.apache.flink.api.common.functions.FilterFunction<Row>)
                            row -> row.getArity() == 0;
                }
                break;
            default:
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
