package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/FeatureHasherTest.class */
/**
 * Tests {@link FeatureHasher}: parameter defaults and setters, save/reload, and the vectors
 * produced by {@code transform} for a fixed two-row input table.
 */
public class FeatureHasherTest extends AbstractTestBase {
    /** Two input rows; the first field is the row id used to select the expected vector. */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(new Object[] {0, "a", 1.0, true}),
                    Row.of(new Object[] {1, "c", 1.0, false}));

    /** Expected output vector for the row whose id is 0. */
    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
            Vectors.sparse(1000, new int[] {607, 635, 913}, new double[] {1.0, 1.0, 1.0});

    /** Expected output vector for the row whose id is 1. */
    private static final SparseVector EXPECTED_OUTPUT_DATA_2 =
            Vectors.sparse(1000, new int[] {242, 869, 913}, new double[] {1.0, 1.0, 1.0});

    private StreamTableEnvironment tEnv;
    private Table inputDataTable;

    /** Creates the table environment and the input table with columns id, f0, f1 and f2. */
    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputDataTable =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA))
                        .as("id", new String[] {"f0", "f1", "f2"});
    }

    /**
     * Collects {@code table} and checks that it holds exactly two rows whose {@code str} column
     * matches the expected vector for the row id stored in field 0.
     *
     * @param table the transformed table to verify
     * @param str name of the output column that holds the hashed vector
     * @throws Exception if executing the job or collecting its results fails
     * @throws RuntimeException if a row carries an id other than 0 or 1
     */
    private void verifyOutputResult(Table table, String str) throws Exception {
        @SuppressWarnings("unchecked")
        List<Row> results = IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
        Assert.assertEquals(2, results.size());
        for (Row row : results) {
            Object id = row.getField(0);
            // Compare by value: '==' between an Object and an int literal is a reference
            // comparison against a boxed Integer and only works by virtue of the Integer cache.
            if (Integer.valueOf(0).equals(id)) {
                Assert.assertEquals(EXPECTED_OUTPUT_DATA_1, row.getField(str));
            } else if (Integer.valueOf(1).equals(id)) {
                Assert.assertEquals(EXPECTED_OUTPUT_DATA_2, row.getField(str));
            } else {
                throw new RuntimeException("unknown output value.");
            }
        }
    }

    /**
     * Builds the hasher configuration shared by the transform tests: input columns f0, f1, f2,
     * output column "vec" and 1000 features.
     *
     * @param withCategoricalCols whether to declare f0 and f2 as categorical columns
     * @return the configured hasher
     */
    private static FeatureHasher createHasher(boolean withCategoricalCols) {
        FeatureHasher hasher = new FeatureHasher();
        hasher.setInputCols(new String[] {"f0", "f1", "f2"});
        hasher.setOutputCol("vec");
        if (withCategoricalCols) {
            hasher.setCategoricalCols(new String[] {"f0", "f2"});
        }
        hasher.setNumFeatures(1000);
        return hasher;
    }

    /**
     * Saves {@code hasher} to a fresh temporary folder, reloads it, transforms the current input
     * table with the reloaded instance and verifies the output column.
     *
     * @param hasher the configured hasher to round-trip and apply
     * @throws Exception if saving, loading or executing the job fails
     */
    private void saveReloadAndVerify(FeatureHasher hasher) throws Exception {
        FeatureHasher reloaded =
                TestUtils.saveAndReload(
                        tEnv,
                        hasher,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        FeatureHasher::load);
        Table output = reloaded.transform(new Table[] {inputDataTable})[0];
        verifyOutputResult(output, reloaded.getOutputCol());
    }

    /** Checks the default parameter values and that each setter is reflected by its getter. */
    @Test
    public void testParam() {
        FeatureHasher featureHasher = new FeatureHasher();
        Assert.assertEquals("output", featureHasher.getOutputCol());
        Assert.assertArrayEquals(new String[0], featureHasher.getCategoricalCols());
        Assert.assertEquals(262144, featureHasher.getNumFeatures());

        featureHasher.setInputCols(new String[] {"f0", "f1", "f2"});
        featureHasher.setOutputCol("vec");
        featureHasher.setCategoricalCols(new String[] {"f0", "f2"});
        featureHasher.setNumFeatures(1000);

        Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, featureHasher.getInputCols());
        Assert.assertEquals("vec", featureHasher.getOutputCol());
        Assert.assertArrayEquals(new String[] {"f0", "f2"}, featureHasher.getCategoricalCols());
        Assert.assertEquals(1000, featureHasher.getNumFeatures());
    }

    /** A saved and reloaded hasher with categorical columns set yields the expected vectors. */
    @Test
    public void testSaveLoadAndTransform() throws Exception {
        saveReloadAndVerify(createHasher(true));
    }

    /**
     * Leaving the categorical columns unset is expected to yield the same vectors as declaring
     * f0 and f2 categorical.
     */
    @Test
    public void testCategoricalColsNotSet() throws Exception {
        saveReloadAndVerify(createHasher(false));
    }

    /**
     * After the input columns are converted so that their types are Integer, String, Integer and
     * Boolean, the hasher is expected to yield the same vectors as for the original input.
     */
    @Test
    public void testInputTypeConversion() throws Exception {
        inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
        Assert.assertArrayEquals(
                new Class[] {Integer.class, String.class, Integer.class, Boolean.class},
                TestUtils.getColumnDataTypes(inputDataTable));
        saveReloadAndVerify(createHasher(true));
    }
}
