package org.apache.flink.ml.stats;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.stats.anovatest.ANOVATest;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/stats/ANOVATestTest.class */
/**
 * Tests for {@link ANOVATest}.
 *
 * <p>Covers default and custom parameters, the output schema with and without flattening,
 * transformation of dense and sparse inputs, and a save/load round trip.
 */
public class ANOVATestTest extends AbstractTestBase {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamTableEnvironment tEnv;
    private Table denseInputTable;
    private Table sparseInputTable;

    /** Absolute tolerance used for every floating-point comparison in this test. */
    private static final double EPS = 1e-5;

    /** Twenty (label, dense feature vector) rows, six features per row. */
    private static final List<Row> DENSE_INPUT_DATA =
            Arrays.asList(
                    Row.of(3, Vectors.dense(new double[] {0.85956061, 0.1645695, 0.48347596, 0.92102727, 0.42855644, 0.05746009})),
                    Row.of(2, Vectors.dense(new double[] {0.92500743, 0.65760154, 0.13295284, 0.53344893, 0.8994776, 0.24836496})),
                    Row.of(1, Vectors.dense(new double[] {0.03017182, 0.07244715, 0.87416449, 0.55843035, 0.91604736, 0.63346045})),
                    Row.of(5, Vectors.dense(new double[] {0.28325261, 0.36536881, 0.09223386, 0.37251258, 0.34742278, 0.70517077})),
                    Row.of(4, Vectors.dense(new double[] {0.64850904, 0.04090877, 0.21173176, 0.00148992, 0.13897166, 0.21182539})),
                    Row.of(4, Vectors.dense(new double[] {0.02609493, 0.44608735, 0.23910531, 0.95449222, 0.90763182, 0.8624905})),
                    Row.of(5, Vectors.dense(new double[] {0.09158744, 0.97745235, 0.41150139, 0.45830467, 0.52590925, 0.29441554})),
                    Row.of(4, Vectors.dense(new double[] {0.97211594, 0.1814442, 0.30340642, 0.17445413, 0.52756958, 0.02069296})),
                    Row.of(2, Vectors.dense(new double[] {0.06354593, 0.63527231, 0.49620335, 0.0141264, 0.62722219, 0.63497507})),
                    Row.of(1, Vectors.dense(new double[] {0.10814149, 0.8296426, 0.51775217, 0.57068344, 0.54633305, 0.12714921})),
                    Row.of(1, Vectors.dense(new double[] {0.72731796, 0.94010124, 0.45007811, 0.87650674, 0.53735565, 0.49568415})),
                    Row.of(2, Vectors.dense(new double[] {0.41827208, 0.85100628, 0.38685271, 0.60689503, 0.21784097, 0.91294433})),
                    Row.of(3, Vectors.dense(new double[] {0.65843656, 0.5880859, 0.18862706, 0.856398, 0.18029327, 0.94851926})),
                    Row.of(4, Vectors.dense(new double[] {0.3841634, 0.25138793, 0.96746644, 0.77048045, 0.44685196, 0.19813854})),
                    Row.of(5, Vectors.dense(new double[] {0.65982267, 0.23024125, 0.13598434, 0.60144265, 0.57848927, 0.85623564})),
                    Row.of(1, Vectors.dense(new double[] {0.35764189, 0.47623815, 0.5459232, 0.79508298, 0.14462443, 0.01802919})),
                    Row.of(5, Vectors.dense(new double[] {0.38532153, 0.90614554, 0.86629571, 0.13988735, 0.32062385, 0.00179492})),
                    Row.of(3, Vectors.dense(new double[] {0.2142368, 0.28306022, 0.59481646, 0.42567028, 0.52207663, 0.78082401})),
                    Row.of(1, Vectors.dense(new double[] {0.20788283, 0.76861782, 0.59595468, 0.62103642, 0.17781246, 0.77655345})),
                    Row.of(1, Vectors.dense(new double[] {0.1751708, 0.4547537, 0.46187865, 0.79781199, 0.05104487, 0.42406092})));

    /**
     * Six (label, sparse feature vector) rows, seven features per row. The first feature of the
     * last row is NaN, and the last feature is zero in every row.
     */
    private static final List<Row> SPARSE_INPUT_DATA =
            Arrays.asList(
                    Row.of(3, Vectors.dense(new double[] {6.0, 7.0, 0.0, 7.0, 6.0, 0.0, 0.0}).toSparse()),
                    Row.of(1, Vectors.dense(new double[] {0.0, 9.0, 6.0, 0.0, 5.0, 9.0, 0.0}).toSparse()),
                    Row.of(3, Vectors.dense(new double[] {0.0, 9.0, 3.0, 0.0, 5.0, 5.0, 0.0}).toSparse()),
                    Row.of(2, Vectors.dense(new double[] {0.0, 9.0, 8.0, 5.0, 6.0, 4.0, 0.0}).toSparse()),
                    Row.of(2, Vectors.dense(new double[] {8.0, 9.0, 6.0, 5.0, 4.0, 4.0, 0.0}).toSparse()),
                    Row.of(3, Vectors.dense(new double[] {Double.NaN, 9.0, 6.0, 4.0, 0.0, 0.0, 0.0}).toSparse()));

    /** Expected non-flattened output for the dense input: (pValues, degreesOfFreedom, fValues). */
    private static final Row EXPECTED_OUTPUT_DENSE =
            Row.of(
                    Vectors.dense(new double[] {0.64137831, 0.14830724, 0.69858474, 0.28038169, 0.86759161, 0.81608606}),
                    new long[] {19, 19, 19, 19, 19, 19},
                    Vectors.dense(new double[] {0.64110932, 1.98689258, 0.55499714, 1.40340562, 0.30881722, 0.3848595}));

    /** Expected flattened output for the dense input: (featureIndex, pValue, degreeOfFreedom, fValue). */
    private static final List<Row> EXPECTED_FLATTENED_OUTPUT_DENSE =
            Arrays.asList(
                    Row.of(0, 0.64137831, 19, 0.64110932),
                    Row.of(1, 0.14830724, 19, 1.98689258),
                    Row.of(2, 0.69858474, 19, 0.55499714),
                    Row.of(3, 0.28038169, 19, 1.40340562),
                    Row.of(4, 0.86759161, 19, 0.30881722),
                    Row.of(5, 0.81608606, 19, 0.3848595));

    /** Expected non-flattened output for the sparse input: (pValues, degreesOfFreedom, fValues). */
    private static final Row EXPECTED_OUTPUT_SPARSE =
            Row.of(
                    Vectors.dense(new double[] {Double.NaN, 0.71554175, 0.34278574, 0.45824059, 0.84633632, 0.15673368, Double.NaN}),
                    new long[] {5, 5, 5, 5, 5, 5, 5},
                    Vectors.dense(new double[] {Double.NaN, 0.375, 1.5625, 1.02364865, 0.17647059, 3.66, Double.NaN}));

    /** Expected flattened output for the sparse input: (featureIndex, pValue, degreeOfFreedom, fValue). */
    private static final List<Row> EXPECTED_FLATTENED_OUTPUT_SPARSE =
            Arrays.asList(
                    Row.of(0, Double.NaN, 5, Double.NaN),
                    Row.of(1, 0.71554175, 5, 0.375),
                    Row.of(2, 0.34278574, 5, 1.5625),
                    Row.of(3, 0.45824059, 5, 1.02364865),
                    Row.of(4, 0.84633632, 5, 0.17647059),
                    Row.of(5, 0.15673368, 5, 3.66),
                    Row.of(6, Double.NaN, 5, Double.NaN));

    /** Creates the table environment and the dense/sparse input tables with columns (label, features). */
    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        denseInputTable =
                tEnv.fromDataStream(env.fromCollection(DENSE_INPUT_DATA)).as("label", "features");
        sparseInputTable =
                tEnv.fromDataStream(env.fromCollection(SPARSE_INPUT_DATA)).as("label", "features");
    }

    /**
     * Executes {@code table} and collects all of its rows.
     *
     * @param table a table produced by a {@link StreamTableEnvironment}
     * @return every row emitted by the table
     * @throws Exception if job execution fails
     */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List; elements are Rows.
    private static List<Row> collectRows(Table table) throws Exception {
        // toDataStream is only declared on StreamTableEnvironment, so cast explicitly.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
        return IteratorUtils.toList(tableEnv.toDataStream(table).executeAndCollect());
    }

    /** Parses a row field's string form as a double, so Integer/Long/Double fields compare uniformly. */
    private static double toDouble(Object field) {
        return Double.parseDouble(field.toString());
    }

    /**
     * Checks a flattened result table against the expected rows.
     *
     * <p>Rows are matched after sorting both sides by feature index (field 0). Each field is then
     * compared numerically within {@link #EPS}.
     *
     * @param table the flattened output of {@link ANOVATest}
     * @param expected expected rows; this list is not modified
     */
    private static void verifyFlattenTransformationResult(Table table, List<Row> expected)
            throws Exception {
        List<Row> actual = collectRows(table);
        Assert.assertEquals(expected.size(), actual.size());

        Comparator<Row> byFeatureIndex = Comparator.comparingDouble(row -> toDouble(row.getField(0)));
        // Sort a copy: the expected lists are shared static fixtures and must not be mutated.
        List<Row> sortedExpected = new ArrayList<>(expected);
        sortedExpected.sort(byFeatureIndex);
        actual.sort(byFeatureIndex);

        for (int i = 0; i < sortedExpected.size(); i++) {
            Row expectedRow = sortedExpected.get(i);
            Row actualRow = actual.get(i);
            Assert.assertEquals(expectedRow.getArity(), actualRow.getArity());
            for (int j = 0; j < expectedRow.getArity(); j++) {
                // JUnit's delta comparison treats NaN as equal to NaN.
                Assert.assertEquals(
                        toDouble(expectedRow.getField(j)), toDouble(actualRow.getField(j)), EPS);
            }
        }
    }

    /**
     * Checks a non-flattened result table against the single expected row.
     *
     * @param table the non-flattened output of {@link ANOVATest}
     * @param expected row of (pValues vector, degreesOfFreedom long[], fValues vector)
     */
    private static void verifyTransformationResult(Table table, Row expected) throws Exception {
        List<Row> results = collectRows(table);
        Assert.assertEquals(1L, results.size());

        Row actual = results.get(0);
        Assert.assertEquals(3L, actual.getArity());
        Assert.assertArrayEquals(
                ((Vector) expected.getField(0)).toArray(), ((Vector) actual.getField(0)).toArray(), EPS);
        Assert.assertArrayEquals((long[]) expected.getField(1), (long[]) actual.getField(1));
        Assert.assertArrayEquals(
                ((Vector) expected.getField(2)).toArray(), ((Vector) actual.getField(2)).toArray(), EPS);
    }

    @Test
    public void testParam() {
        ANOVATest anovaTest = new ANOVATest();
        Assert.assertEquals("label", anovaTest.getLabelCol());
        Assert.assertEquals("features", anovaTest.getFeaturesCol());
        Assert.assertFalse(anovaTest.getFlatten());

        anovaTest.setLabelCol("test_label");
        anovaTest.setFeaturesCol("test_features");
        anovaTest.setFlatten(true);

        Assert.assertEquals("test_features", anovaTest.getFeaturesCol());
        Assert.assertEquals("test_label", anovaTest.getLabelCol());
        Assert.assertTrue(anovaTest.getFlatten());
    }

    @Test
    public void testOutputSchema() {
        ANOVATest anovaTest = new ANOVATest();
        anovaTest.setFeaturesCol("test_features");
        anovaTest.setLabelCol("test_label");

        Table output = anovaTest.transform(denseInputTable)[0];
        Assert.assertEquals(
                Arrays.asList("pValues", "degreesOfFreedom", "fValues"),
                output.getResolvedSchema().getColumnNames());

        anovaTest.setFlatten(true);
        Table flattenedOutput = anovaTest.transform(denseInputTable)[0];
        Assert.assertEquals(
                Arrays.asList("featureIndex", "pValue", "degreeOfFreedom", "fValue"),
                flattenedOutput.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransform() throws Exception {
        ANOVATest anovaTest = new ANOVATest();
        verifyTransformationResult(anovaTest.transform(denseInputTable)[0], EXPECTED_OUTPUT_DENSE);
        verifyTransformationResult(anovaTest.transform(sparseInputTable)[0], EXPECTED_OUTPUT_SPARSE);
    }

    @Test
    public void testTransformWithFlatten() throws Exception {
        ANOVATest anovaTest = new ANOVATest();
        anovaTest.setFlatten(true);
        verifyFlattenTransformationResult(
                anovaTest.transform(denseInputTable)[0], EXPECTED_FLATTENED_OUTPUT_DENSE);
        verifyFlattenTransformationResult(
                anovaTest.transform(sparseInputTable)[0], EXPECTED_FLATTENED_OUTPUT_SPARSE);
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        String path = tempFolder.newFolder().getAbsolutePath();
        Table output =
                TestUtils.saveAndReload(tEnv, new ANOVATest(), path, ANOVATest::load)
                        .transform(denseInputTable)[0];
        verifyTransformationResult(output, EXPECTED_OUTPUT_DENSE);
    }
}
