package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.ml.feature.hashingtf.HashingTF;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/HashingTFTest.class */
public class HashingTFTest extends AbstractTestBase {
    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table inputDataTable;
    private static final List<Row> INPUT = Arrays.asList(Row.of(new Object[]{Arrays.asList("HashingTFTest", "Hashing", "Term", "Frequency", "Test")}), Row.of(new Object[]{Arrays.asList("HashingTFTest", "Hashing", "Hashing", "Test", "Test")}));
    private static final List<Row> EXPECTED_OUTPUT = Arrays.asList(Row.of(new Object[]{Vectors.sparse(262144, new int[]{67564, 89917, 113827, 131486, 228971}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d})}), Row.of(new Object[]{Vectors.sparse(262144, new int[]{67564, 131486, 228971}, new double[]{1.0d, 2.0d, 2.0d})}));
    private static final List<Row> EXPECTED_BINARY_OUTPUT = Arrays.asList(Row.of(new Object[]{Vectors.sparse(262144, new int[]{67564, 89917, 113827, 131486, 228971}, new double[]{1.0d, 1.0d, 1.0d, 1.0d, 1.0d})}), Row.of(new Object[]{Vectors.sparse(262144, new int[]{67564, 131486, 228971}, new double[]{1.0d, 1.0d, 1.0d})}));

    @Before
    public void before() {
        this.env = TestUtils.getExecutionEnvironment();
        this.tEnv = StreamTableEnvironment.create(this.env);
        this.inputDataTable = this.tEnv.fromDataStream(this.env.fromCollection(INPUT, Types.ROW(new TypeInformation[]{Types.LIST(Types.STRING)}))).as("input", new String[0]);
    }

    @Test
    public void testParam() {
        HashingTF hashingTF = new HashingTF();
        Assert.assertEquals("input", hashingTF.getInputCol());
        Assert.assertFalse(hashingTF.getBinary());
        Assert.assertEquals(262144L, hashingTF.getNumFeatures());
        Assert.assertEquals("output", hashingTF.getOutputCol());
        ((HashingTF) ((HashingTF) ((HashingTF) hashingTF.setInputCol("testInputCol")).setBinary(true)).setNumFeatures(1024)).setOutputCol("testOutputCol");
        Assert.assertEquals("testInputCol", hashingTF.getInputCol());
        Assert.assertTrue(hashingTF.getBinary());
        Assert.assertEquals(1024L, hashingTF.getNumFeatures());
        Assert.assertEquals("testOutputCol", hashingTF.getOutputCol());
    }

    @Test
    public void testOutputSchema() {
        HashingTF hashingTF = new HashingTF();
        this.inputDataTable = this.tEnv.fromDataStream(this.env.fromElements(new Row[]{Row.of(new Object[]{Arrays.asList(""), Arrays.asList("")})})).as("input", new String[]{"dummyInput"});
        Assert.assertEquals(Arrays.asList(hashingTF.getInputCol(), "dummyInput", hashingTF.getOutputCol()), hashingTF.transform(new Table[]{this.inputDataTable})[0].getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransform() throws Exception {
        HashingTF hashingTF = new HashingTF();
        verifyOutputResult(hashingTF.transform(new Table[]{this.inputDataTable})[0], hashingTF.getOutputCol(), EXPECTED_OUTPUT);
        hashingTF.setBinary(true);
        verifyOutputResult(hashingTF.transform(new Table[]{this.inputDataTable})[0], hashingTF.getOutputCol(), EXPECTED_BINARY_OUTPUT);
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Test
    public void testTransformArrayData() throws Exception {
        HashingTF hashingTF = new HashingTF();
        this.inputDataTable = this.tEnv.fromDataStream(this.env.fromElements(new String[]{new String[]{"HashingTFTest", "Hashing", "Term", "Frequency", "Test"}, new String[]{"HashingTFTest", "Hashing", "Hashing", "Test", "Test"}})).as("input", new String[0]);
        verifyOutputResult(hashingTF.transform(new Table[]{this.inputDataTable})[0], hashingTF.getOutputCol(), EXPECTED_OUTPUT);
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        HashingTF saveAndReload = TestUtils.saveAndReload(this.tEnv, new HashingTF(), TEMPORARY_FOLDER.newFolder().getAbsolutePath(), HashingTF::load);
        verifyOutputResult(saveAndReload.transform(new Table[]{this.inputDataTable})[0], saveAndReload.getOutputCol(), EXPECTED_OUTPUT);
    }

    private void verifyOutputResult(Table table, String str, List<Row> list) throws Exception {
        List list2 = IteratorUtils.toList(this.tEnv.toDataStream(table.select(new Expression[]{Expressions.$(str)})).executeAndCollect());
        Assert.assertEquals(list.size(), list2.size());
        list2.sort(Comparator.comparingInt(row -> {
            return row.getField(0).hashCode();
        }));
        list.sort(Comparator.comparingInt(row2 -> {
            return row2.getField(0).hashCode();
        }));
        for (int i = 0; i < list.size(); i++) {
            Assert.assertEquals(list.get(i).getField(0), ((Row) list2.get(i)).getField(0));
        }
    }
}
