package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/ElementwiseProductTest.class */
/**
 * Tests {@link ElementwiseProduct}: default and explicit parameters, the output schema, save/load
 * round trips on dense and sparse input columns, and the error raised when the scaling vector size
 * does not match the input vector size.
 */
public class ElementwiseProductTest extends AbstractTestBase {
    /** Absolute tolerance used when comparing floating-point vector values. */
    private static final double TOLERANCE = 1.0E-5;

    private StreamTableEnvironment tEnv;
    private Table inputDataTable;

    /** Expected size of the sparse output vector for input id 0. */
    private static final int EXPECTED_OUTPUT_SPARSE_VEC_SIZE_1 = 5;
    /** Expected size of the sparse output vector for input id 1. */
    private static final int EXPECTED_OUTPUT_SPARSE_VEC_SIZE_2 = 5;

    /**
     * Input rows of (id, dense vector, sparse vector). Row 2 carries null vectors; the output
     * column for that row is expected to be null as well.
     */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(
                            0,
                            Vectors.dense(2.1, 3.1),
                            Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
                    Row.of(
                            1,
                            Vectors.dense(1.1, 3.3),
                            Vectors.sparse(
                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
                    Row.of(2, null, null));

    // Dense inputs scaled element-wise by (1.1, 1.1), e.g. 2.1 * 1.1 = 2.31.
    private static final double[] EXPECTED_OUTPUT_DENSE_VEC_ARRAY_1 = {2.31, 3.41};
    private static final double[] EXPECTED_OUTPUT_DENSE_VEC_ARRAY_2 = {1.21, 3.63};

    // Sparse inputs scaled by a vector that is 1.1 at indices 0 and 1 and zero elsewhere.
    private static final int[] EXPECTED_OUTPUT_SPARSE_VEC_INDICES_1 = {3};
    private static final double[] EXPECTED_OUTPUT_SPARSE_VEC_VALUES_1 = {0.0};
    private static final int[] EXPECTED_OUTPUT_SPARSE_VEC_INDICES_2 = {1, 2, 3, 4};
    private static final double[] EXPECTED_OUTPUT_SPARSE_VEC_VALUES_2 = {1.1, 0.0, 0.0, 0.0};

    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputDataTable =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("id", "vec", "sparseVec");
    }

    /**
     * Collects {@code output} and checks column {@code outputCol} of every row against the
     * expected result for that row's id.
     *
     * @param output the transformed table to collect
     * @param outputCol name of the column holding the element-wise product
     * @param isSparse {@code true} if {@code outputCol} holds {@link SparseVector}s, {@code false}
     *     if it holds {@link DenseVector}s
     * @throws Exception if executing or collecting the job fails
     */
    private void verifyOutputResult(Table output, String outputCol, boolean isSparse)
            throws Exception {
        // IteratorUtils.toList returns a raw List; every element is a Row collected from the job.
        @SuppressWarnings("unchecked")
        List<Row> results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
        Assert.assertEquals(3L, results.size());

        for (Row row : results) {
            // Cast explicitly: comparing the Object field against an int literal with == is not a
            // valid numeric comparison.
            int id = (int) row.getField(0);
            switch (id) {
                case 0:
                    if (isSparse) {
                        assertSparseVector(
                                row,
                                outputCol,
                                EXPECTED_OUTPUT_SPARSE_VEC_SIZE_1,
                                EXPECTED_OUTPUT_SPARSE_VEC_INDICES_1,
                                EXPECTED_OUTPUT_SPARSE_VEC_VALUES_1);
                    } else {
                        assertDenseVector(row, outputCol, EXPECTED_OUTPUT_DENSE_VEC_ARRAY_1);
                    }
                    break;
                case 1:
                    if (isSparse) {
                        assertSparseVector(
                                row,
                                outputCol,
                                EXPECTED_OUTPUT_SPARSE_VEC_SIZE_2,
                                EXPECTED_OUTPUT_SPARSE_VEC_INDICES_2,
                                EXPECTED_OUTPUT_SPARSE_VEC_VALUES_2);
                    } else {
                        assertDenseVector(row, outputCol, EXPECTED_OUTPUT_DENSE_VEC_ARRAY_2);
                    }
                    break;
                case 2:
                    // Null input vector must yield a null output vector.
                    Assert.assertNull(row.getField(outputCol));
                    break;
                default:
                    throw new UnsupportedOperationException("Input data id not exists.");
            }
        }
    }

    /** Asserts that {@code fieldName} of {@code row} is a sparse vector with the given contents. */
    private static void assertSparseVector(
            Row row,
            String fieldName,
            int expectedSize,
            int[] expectedIndices,
            double[] expectedValues) {
        SparseVector actual = (SparseVector) row.getField(fieldName);
        Assert.assertEquals(expectedSize, actual.size());
        Assert.assertArrayEquals(expectedIndices, actual.indices);
        Assert.assertArrayEquals(expectedValues, actual.values, TOLERANCE);
    }

    /** Asserts that {@code fieldName} of {@code row} is a dense vector with the given values. */
    private static void assertDenseVector(Row row, String fieldName, double[] expectedValues) {
        DenseVector actual = (DenseVector) row.getField(fieldName);
        Assert.assertArrayEquals(expectedValues, actual.values, TOLERANCE);
    }

    /** Checks parameter defaults and that setters are reflected by the getters. */
    @Test
    public void testParam() {
        ElementwiseProduct elementwiseProduct = new ElementwiseProduct();
        Assert.assertEquals("output", elementwiseProduct.getOutputCol());
        Assert.assertEquals("input", elementwiseProduct.getInputCol());

        elementwiseProduct
                .setInputCol("vec")
                .setOutputCol("outputVec")
                .setScalingVec(Vectors.dense(1.0, 2.0, 3.0));

        Assert.assertEquals("vec", elementwiseProduct.getInputCol());
        Assert.assertEquals(Vectors.dense(1.0, 2.0, 3.0), elementwiseProduct.getScalingVec());
        Assert.assertEquals("outputVec", elementwiseProduct.getOutputCol());
    }

    /** Checks that the output column is appended after all input columns. */
    @Test
    public void testOutputSchema() {
        ElementwiseProduct elementwiseProduct =
                new ElementwiseProduct()
                        .setInputCol("vec")
                        .setOutputCol("outputVec")
                        .setScalingVec(Vectors.dense(1.0, 2.0, 3.0));
        Table output = elementwiseProduct.transform(inputDataTable)[0];
        Assert.assertEquals(
                Arrays.asList("id", "vec", "sparseVec", "outputVec"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Saves and reloads a dense-vector transformer, then verifies its output. */
    @Test
    public void testSaveLoadAndTransformDense() throws Exception {
        ElementwiseProduct elementwiseProduct =
                new ElementwiseProduct()
                        .setInputCol("vec")
                        .setOutputCol("outputVec")
                        .setScalingVec(Vectors.dense(1.1, 1.1));
        ElementwiseProduct loaded =
                TestUtils.saveAndReload(
                        tEnv,
                        elementwiseProduct,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        ElementwiseProduct::load);
        verifyOutputResult(loaded.transform(inputDataTable)[0], loaded.getOutputCol(), false);
    }

    /** Checks that a scaling vector whose size differs from the input vector fails the job. */
    @Test
    public void testVectorSizeNotEquals() {
        try {
            ElementwiseProduct elementwiseProduct =
                    new ElementwiseProduct()
                            .setInputCol("vec")
                            .setOutputCol("outputVec")
                            .setScalingVec(Vectors.dense(1.1, 1.1, 2.0));
            Table output = elementwiseProduct.transform(inputDataTable)[0];
            IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
            // Assert.fail throws an AssertionError (an Error), so the catch below does not hide it.
            Assert.fail();
        } catch (Exception e) {
            Assert.assertEquals(
                    "The scaling vector size is 3, which is not equal input vector size(2).",
                    ExceptionUtils.getRootCause(e).getMessage());
        }
    }

    /** Saves and reloads a sparse-vector transformer, then verifies its output. */
    @Test
    public void testSaveLoadAndTransformSparse() throws Exception {
        ElementwiseProduct elementwiseProduct =
                new ElementwiseProduct()
                        .setInputCol("sparseVec")
                        .setOutputCol("outputVec")
                        .setScalingVec(Vectors.sparse(5, new int[] {0, 1}, new double[] {1.1, 1.1}));
        ElementwiseProduct loaded =
                TestUtils.saveAndReload(
                        tEnv,
                        elementwiseProduct,
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        ElementwiseProduct::load);
        verifyOutputResult(loaded.transform(inputDataTable)[0], loaded.getOutputCol(), true);
    }
}
