package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.interaction.Interaction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/InteractionTest.class */
/** Tests {@link Interaction}. */
public class InteractionTest extends AbstractTestBase {

    /**
     * Input rows with columns {@code f0} (int), {@code f1} and {@code f2} (dense vectors) and
     * {@code f3} (sparse vector of size 17). The last row carries nulls in every vector column.
     */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(
                            1,
                            Vectors.dense(new double[] {1.0, 2.0}),
                            Vectors.dense(new double[] {3.0, 4.0}),
                            Vectors.sparse(17, new int[] {0, 3, 9}, new double[] {1.0, 2.0, 7.0})),
                    Row.of(
                            2,
                            Vectors.dense(new double[] {2.0, 8.0}),
                            Vectors.dense(new double[] {3.0, 4.0, 5.0}),
                            Vectors.sparse(
                                    17, new int[] {0, 2, 14}, new double[] {5.0, 4.0, 1.0})),
                    Row.of(3, null, null, null));

    /**
     * Expected interaction of {@code f0, f1, f2}: the flattened outer product of the inputs,
     * e.g. {@code 1 * [1, 2] x [3, 4] = [3, 4, 6, 8]}. Only two entries are listed because
     * {@link #verifyOutputResult} skips rows whose output is null (the all-null third row).
     */
    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
            Arrays.asList(
                    new DenseVector(new double[] {3.0, 4.0, 6.0, 8.0}),
                    new DenseVector(new double[] {12.0, 16.0, 20.0, 48.0, 64.0, 80.0}));

    /**
     * Expected interaction of {@code f0, f1, f2, f3}: the same flattened outer product, extended
     * by the sparse {@code f3} column (size 1 * 2 * 2 * 17 = 68 and 1 * 2 * 3 * 17 = 102).
     */
    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
            Arrays.asList(
                    new SparseVector(
                            68,
                            new int[] {0, 3, 9, 17, 20, 26, 34, 37, 43, 51, 54, 60},
                            new double[] {
                                3.0, 6.0, 21.0, 4.0, 8.0, 28.0, 6.0, 12.0, 42.0, 8.0, 16.0, 56.0
                            }),
                    new SparseVector(
                            102,
                            new int[] {
                                0, 2, 14, 17, 19, 31, 34, 36, 48, 51, 53, 65, 68, 70, 82, 85, 87, 99
                            },
                            new double[] {
                                60.0, 48.0, 12.0, 80.0, 64.0, 16.0, 100.0, 80.0, 20.0, 240.0,
                                192.0, 48.0, 320.0, 256.0, 64.0, 400.0, 320.0, 80.0
                            }));

    private StreamTableEnvironment tEnv;
    private Table inputDataTable;

    /** Creates a fresh table environment and the input table with columns f0..f3. */
    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputDataTable = tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("f0", "f1", "f2", "f3");
    }

    /**
     * Executes {@code output} and compares its non-null {@code outputCol} values with
     * {@code expected}.
     *
     * @param output the transformed table; must have been created by a StreamTableEnvironment
     * @param outputCol name of the column holding the interaction vector
     * @param expected the expected non-null output vectors
     */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
    private void verifyOutputResult(Table output, String outputCol, List<Vector> expected)
            throws Exception {
        // getTableEnvironment() is typed as the generic TableEnvironment, which has no
        // toDataStream; the cast to the streaming bridge is required.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        List<Row> rows = IteratorUtils.toList(tableEnv.toDataStream(output).executeAndCollect());

        List<Vector> actual = new ArrayList<>(rows.size());
        for (Row row : rows) {
            // Rows whose output is null (the all-null input row) are not part of `expected`.
            if (row.getField(outputCol) != null) {
                actual.add(row.getFieldAs(outputCol));
            }
        }
        TestBaseUtils.compareResultCollections(expected, actual, TestUtils::compare);
    }

    /** Checks the default output column and that the parameter setters round-trip. */
    @Test
    public void testParam() {
        Interaction interaction = new Interaction();
        Assert.assertEquals("output", interaction.getOutputCol());

        interaction.setInputCols("f0", "f1", "f2").setOutputCol("interactionVecVec");

        Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, interaction.getInputCols());
        Assert.assertEquals("interactionVecVec", interaction.getOutputCol());
    }

    /** Checks that the output column is appended after all input columns. */
    @Test
    public void testOutputSchema() {
        Interaction interaction =
                new Interaction().setInputCols("f0", "f1", "f2", "f3").setOutputCol("outputVec");
        Table output = interaction.transform(inputDataTable)[0];

        Assert.assertEquals(
                Arrays.asList("f0", "f1", "f2", "f3", "outputVec"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Save/reload round trip, then transform with the sparse column f3 included. */
    @Test
    public void testSaveLoadAndTransformSparse() throws Exception {
        Interaction interaction =
                new Interaction()
                        .setInputCols("f0", "f1", "f2", "f3")
                        .setOutputCol("interactionVecVec");
        Interaction loaded =
                TestUtils.saveAndReload(
                        tEnv,
                        interaction,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        Interaction::load);

        Table output = loaded.transform(inputDataTable)[0];
        verifyOutputResult(output, loaded.getOutputCol(), EXPECTED_SPARSE_OUTPUT);
    }

    /** Save/reload round trip, then transform using only the scalar and dense columns. */
    @Test
    public void testSaveLoadAndTransformDense() throws Exception {
        Interaction interaction =
                new Interaction().setInputCols("f0", "f1", "f2").setOutputCol("interactionVecVec");
        Interaction loaded =
                TestUtils.saveAndReload(
                        tEnv,
                        interaction,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        Interaction::load);

        Table output = loaded.transform(inputDataTable)[0];
        verifyOutputResult(output, loaded.getOutputCol(), EXPECTED_DENSE_OUTPUT);
    }
}
