package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.normalizer.Normalizer;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/feature/NormalizerTest.class */
/**
 * Tests {@link Normalizer}: parameter defaults and setters, the output schema, save/reload, the
 * rejection of an invalid {@code p}, and the transformed values for a dense and a sparse input
 * column.
 */
public class NormalizerTest extends AbstractTestBase {
    /** Name given to the first (dense vector) column of the input table. */
    private static final String DENSE_COL = "denseVec";

    /** Name given to the second (sparse vector) column of the input table. */
    private static final String SPARSE_COL = "sparseVec";

    /** Output column name used by every test that runs a transform. */
    private static final String OUTPUT_COL = "outputVec";

    /** Two input rows, each holding a dense vector followed by a sparse vector. */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(
                            Vectors.dense(2.1, 3.1, 2.3, 3.4, 5.3, 5.1),
                            Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {0.1, 0.2, 0.3})),
                    Row.of(
                            Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1),
                            Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {0.1, 0.2, 0.3})));

    /** Expected output for the dense column of {@link #INPUT_DATA} when {@code p = 1.5}. */
    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
            Arrays.asList(
                    Vectors.dense(
                            0.17386300895299714,
                            0.25665491797823387,
                            0.19042139075804446,
                            0.28149249068580484,
                            0.43879711783375464,
                            0.42223873602870726),
                    Vectors.dense(
                            0.20785190042726007,
                            0.3705186051094636,
                            0.11748150893714701,
                            0.2168889395762714,
                            0.4608889965995767,
                            0.3705186051094636));

    /** Expected output for the sparse column of {@link #INPUT_DATA} when {@code p = 1.5}. */
    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
            Arrays.asList(
                    Vectors.sparse(
                            5,
                            new int[] {1, 3, 4},
                            new double[] {
                                0.23070057753660791, 0.46140115507321583, 0.6921017326098237
                            }),
                    Vectors.sparse(
                            5,
                            new int[] {1, 2, 4},
                            new double[] {
                                0.23070057753660791, 0.46140115507321583, 0.6921017326098237
                            }));

    private StreamTableEnvironment tEnv;
    private Table inputDataTable;

    /** Creates a fresh table environment and registers {@link #INPUT_DATA} as the input table. */
    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputDataTable =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as(DENSE_COL, SPARSE_COL);
    }

    /** Builds the normalizer configuration shared by most tests: dense input, {@code p = 1.5}. */
    private static Normalizer denseNormalizer() {
        return new Normalizer().setInputCol(DENSE_COL).setOutputCol(OUTPUT_COL).setP(1.5);
    }

    /**
     * Executes {@code output}, collects the non-null values of {@code outputCol} and compares them
     * with {@code expectedData} using {@code TestUtils::compare}.
     *
     * @param output table produced by the transformer under test
     * @param outputCol name of the column holding the resulting vectors
     * @param expectedData vectors the column is expected to contain
     * @throws Exception if executing the job or comparing the results fails
     */
    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
            throws Exception {
        // TableImpl#getTableEnvironment is declared to return the base TableEnvironment, which has
        // no toDataStream(); the table was created from a StreamTableEnvironment, so narrow it.
        StreamTableEnvironment outputEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();

        // commons-collections 3.x returns a raw List; the stream is known to carry Row elements.
        @SuppressWarnings("unchecked")
        List<Row> results =
                IteratorUtils.toList(outputEnv.toDataStream(output).executeAndCollect());

        List<Vector> actualData = new ArrayList<>(results.size());
        for (Row row : results) {
            if (row.getField(outputCol) != null) {
                Vector vector = row.getFieldAs(outputCol);
                actualData.add(vector);
            }
        }
        TestBaseUtils.compareResultCollections(expectedData, actualData, TestUtils::compare);
    }

    /** Checks the default parameter values and that the fluent setters are reflected by getters. */
    @Test
    public void testParam() {
        Normalizer normalizer = new Normalizer();
        Assert.assertEquals("input", normalizer.getInputCol());
        Assert.assertEquals("output", normalizer.getOutputCol());
        Assert.assertEquals(2.0, normalizer.getP(), 1.0E-5);

        normalizer.setInputCol(DENSE_COL).setOutputCol(OUTPUT_COL).setP(1.5);
        Assert.assertEquals("denseVec", normalizer.getInputCol());
        Assert.assertEquals("outputVec", normalizer.getOutputCol());
        Assert.assertEquals(1.5, normalizer.getP(), 1.0E-5);
    }

    /** Checks that the output table keeps the input columns and appends the output column. */
    @Test
    public void testOutputSchema() {
        Table output = denseNormalizer().transform(inputDataTable)[0];
        Assert.assertEquals(
                Arrays.asList("denseVec", "sparseVec", "outputVec"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Checks that a saved and reloaded normalizer produces the expected dense output. */
    @Test
    public void testSaveLoadAndTransform() throws Exception {
        Normalizer loadedNormalizer =
                TestUtils.saveAndReload(
                        tEnv,
                        denseNormalizer(),
                        TEMPORARY_FOLDER.newFolder().getAbsolutePath(),
                        Normalizer::load);
        Table output = loadedNormalizer.transform(inputDataTable)[0];
        verifyOutputResult(output, loadedNormalizer.getOutputCol(), EXPECTED_DENSE_OUTPUT);
    }

    /** Checks that {@code p = 0.5} is rejected with the expected error message. */
    @Test
    public void testInvalidP() {
        try {
            new Normalizer()
                    .setInputCol(DENSE_COL)
                    .setOutputCol(OUTPUT_COL)
                    .setP(0.5)
                    .transform(inputDataTable);
            // Assert.fail() throws AssertionError, which the catch below does not intercept.
            Assert.fail();
        } catch (Exception e) {
            Assert.assertEquals("Parameter p is given an invalid value 0.5", e.getMessage());
        }
    }

    /** Checks the transformed values for the dense input column. */
    @Test
    public void testDenseTransform() throws Exception {
        Normalizer normalizer = denseNormalizer();
        Table output = normalizer.transform(inputDataTable)[0];
        verifyOutputResult(output, normalizer.getOutputCol(), EXPECTED_DENSE_OUTPUT);
    }

    /** Checks the transformed values for the sparse input column. */
    @Test
    public void testSparseTransform() throws Exception {
        Normalizer normalizer =
                new Normalizer().setInputCol(SPARSE_COL).setOutputCol(OUTPUT_COL).setP(1.5);
        Table output = normalizer.transform(inputDataTable)[0];
        verifyOutputResult(output, normalizer.getOutputCol(), EXPECTED_SPARSE_OUTPUT);
    }
}
