package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.dct.DCT;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/DCTTest.class */
/**
 * Tests the {@link DCT} transformer: parameter defaults and setters, output schema, forward and
 * inverse transformation results, input type conversion, and save/load round-tripping.
 *
 * <p>Expected output vectors are written with three decimal places and compared against the
 * actual output using {@link #TOLERANCE}.
 */
public class DCTTest extends AbstractTestBase {
    /** Absolute tolerance used when comparing expected and actual output vectors. */
    private static final double TOLERANCE = 0.001d;

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    /** Input vectors shared by all transformation tests. */
    private static final List<Vector> inputData =
            Arrays.asList(Vectors.dense(1.0d, 1.0d, 1.0d, 1.0d), Vectors.dense(1.0d, 0.0d, -1.0d, 0.0d));

    /** Rows of (input vector, expected output vector) for the default (non-inverse) transform. */
    private static final List<Row> expectedForwardOutputData =
            Arrays.asList(
                    Row.of(Vectors.dense(1.0d, 1.0d, 1.0d, 1.0d), Vectors.dense(2.0d, 0.0d, 0.0d, 0.0d)),
                    Row.of(Vectors.dense(1.0d, 0.0d, -1.0d, 0.0d), Vectors.dense(0.0d, 0.924d, 1.0d, -0.383d)));

    /** Rows of (input vector, expected output vector) when {@code inverse} is set to true. */
    private static final List<Row> expectedInverseOutputData =
            Arrays.asList(
                    Row.of(Vectors.dense(1.0d, 1.0d, 1.0d, 1.0d), Vectors.dense(1.924d, -0.383d, 0.383d, 0.076d)),
                    Row.of(Vectors.dense(1.0d, 0.0d, -1.0d, 0.0d), Vectors.dense(0.0d, 1.0d, 1.0d, 0.0d)));

    private Table inputTable;

    /** Creates fresh environments and a single-column input table named {@code "input"}. */
    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputTable = tEnv.fromDataStream(env.fromCollection(inputData)).as("input");
    }

    /** Checks the default parameter values and that the setters are reflected by the getters. */
    @Test
    public void testParam() {
        DCT dct = new DCT();
        Assert.assertEquals("input", dct.getInputCol());
        Assert.assertEquals("output", dct.getOutputCol());
        Assert.assertFalse(dct.getInverse());

        dct.setInputCol("test_input").setOutputCol("test_output").setInverse(true);
        Assert.assertEquals("test_input", dct.getInputCol());
        Assert.assertEquals("test_output", dct.getOutputCol());
        Assert.assertTrue(dct.getInverse());
    }

    /** Checks that the output table keeps the input columns and appends the output column. */
    @Test
    public void testOutputSchema() {
        Table twoColumnInput =
                tEnv.fromDataStream(env.fromElements(Row.of(Vectors.dense(0.0d), "")))
                        .as("test_input", "dummy_input");
        DCT dct = new DCT().setInputCol("test_input").setOutputCol("test_output");

        Table output = dct.transform(twoColumnInput)[0];

        Assert.assertEquals(
                Arrays.asList(dct.getInputCol(), "dummy_input", dct.getOutputCol()),
                output.getResolvedSchema().getColumnNames());
    }

    /** Checks the transformation result with default parameters. */
    @Test
    public void testTransformForward() {
        DCT dct = new DCT();
        Table output = dct.transform(inputTable)[0];
        verifyTransformResult(output, expectedForwardOutputData, dct.getInputCol(), dct.getOutputCol());
    }

    /** Checks the transformation result with {@code inverse} set to true. */
    @Test
    public void testTransformInverse() {
        DCT dct = new DCT().setInverse(true);
        Table output = dct.transform(inputTable)[0];
        verifyTransformResult(output, expectedInverseOutputData, dct.getInputCol(), dct.getOutputCol());
    }

    /** Checks that an input column of {@link SparseVector} yields the same forward result. */
    @Test
    public void testInputTypeConversion() throws Exception {
        inputTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputTable);
        Assert.assertArrayEquals(new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(inputTable));

        DCT dct = new DCT();
        Table output = dct.transform(inputTable)[0];
        verifyTransformResult(output, expectedForwardOutputData, dct.getInputCol(), dct.getOutputCol());
    }

    /** Checks that a saved and reloaded instance produces the inverse-transform result. */
    @Test
    public void testSaveLoadAndTransform() throws Exception {
        DCT dct = new DCT().setInverse(true);
        DCT loadedDct =
                TestUtils.saveAndReload(tEnv, dct, TEMPORARY_FOLDER.newFolder().getAbsolutePath(), DCT::load);

        Table output = loadedDct.transform(inputTable)[0];
        verifyTransformResult(output, expectedInverseOutputData, loadedDct.getInputCol(), loadedDct.getOutputCol());
    }

    /**
     * Executes {@code output}, pairs its rows with {@code expected} by sorting both on the hash
     * code of the dense input vector, and asserts that the output vectors match within
     * {@link #TOLERANCE}.
     *
     * @param output table produced by the transformer under test
     * @param expected rows whose field 0 is the input vector and field 1 the expected output
     *     vector; this list is not modified
     * @param inputCol name of the input column in {@code output}
     * @param outputCol name of the output column in {@code output}
     */
    private static void verifyTransformResult(Table output, List<Row> expected, String inputCol, String outputCol) {
        // commons-collections 3.x returns a raw List here; the table rows are Row instances.
        @SuppressWarnings("unchecked")
        List<Row> actualRows = IteratorUtils.toList(output.execute().collect());
        actualRows.sort(Comparator.comparingInt(row -> denseHash(row.getField(inputCol))));

        // Sort a copy: the expected lists are static fixtures shared between tests and must not
        // be reordered as a side effect of verification.
        List<Row> expectedRows = new ArrayList<>(expected);
        expectedRows.sort(Comparator.comparingInt(row -> denseHash(row.getField(0))));

        Assert.assertEquals(expectedRows.size(), actualRows.size());
        for (int i = 0; i < expectedRows.size(); i++) {
            Vector expectedVector = expectedRows.get(i).getFieldAs(1);
            Vector actualVector = actualRows.get(i).getFieldAs(outputCol);
            Assert.assertArrayEquals(expectedVector.toArray(), actualVector.toArray(), TOLERANCE);
        }
    }

    /** Returns the hash code of the dense form of a non-null {@link Vector} field value. */
    private static int denseHash(Object field) {
        return ((Vector) Objects.requireNonNull(field)).toDense().hashCode();
    }
}
