package org.apache.flink.ml.stats;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.stats.fvaluetest.FValueTest;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests for {@link FValueTest}: default and custom parameters, output schema, plain and flattened
 * transformation results on dense and sparse feature vectors, and save/load round-tripping.
 */
public class FValueTestTest extends AbstractTestBase {

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamTableEnvironment tEnv;
    private Table denseInputTable;
    private Table sparseInputTable;

    /** Absolute tolerance for floating-point comparisons. */
    private static final double EPS = 1.0E-5;

    /** 20 rows of (label, 6-dimensional dense feature vector). */
    private static final List<Row> DENSE_INPUT_DATA =
            Arrays.asList(
                    Row.of(0.19775997, Vectors.dense(0.15266373, 0.30235661, 0.06203641, 0.45986034, 0.83525338, 0.92699705)),
                    Row.of(0.66009772, Vectors.dense(0.72698898, 0.76849622, 0.26920507, 0.64402929, 0.09337326, 0.07968589)),
                    Row.of(0.80865842, Vectors.dense(0.58961375, 0.34334054, 0.98887615, 0.62647321, 0.68177928, 0.55225681)),
                    Row.of(0.34142582, Vectors.dense(0.26886006, 0.37325939, 0.2229281, 0.1864426, 0.39064809, 0.19316241)),
                    Row.of(0.84756607, Vectors.dense(0.61091093, 0.88280845, 0.62233882, 0.25311894, 0.17993031, 0.81640447)),
                    Row.of(0.53360225, Vectors.dense(0.22537162, 0.51685714, 0.51849582, 0.60037494, 0.53262048, 0.01331005)),
                    Row.of(0.90053371, Vectors.dense(0.52409726, 0.89588471, 0.76990129, 0.1228517, 0.29587269, 0.61202358)),
                    Row.of(0.78779561, Vectors.dense(0.72613812, 0.46349747, 0.76911037, 0.19163103, 0.55786672, 0.55077816)),
                    Row.of(0.51604647, Vectors.dense(0.47222549, 0.79188496, 0.11524968, 0.6813039, 0.36233361, 0.34420889)),
                    Row.of(0.35325637, Vectors.dense(0.44951875, 0.02694226, 0.41524769, 0.9222317, 0.09120557, 0.31512178)),
                    Row.of(0.51408926, Vectors.dense(0.52802224, 0.32806203, 0.44891554, 0.01633442, 0.0970269, 0.69258857)),
                    Row.of(0.84489897, Vectors.dense(0.83594341, 0.42432199, 0.8487743, 0.54679121, 0.35410346, 0.72724968)),
                    Row.of(0.55342816, Vectors.dense(0.09385168, 0.8928588, 0.33625828, 0.89183268, 0.296849, 0.30164829)),
                    Row.of(0.89405683, Vectors.dense(0.80624061, 0.83760997, 0.63428133, 0.3113273, 0.02944858, 0.39977732)),
                    Row.of(0.54588131, Vectors.dense(0.51817346, 0.00738845, 0.77494778, 0.8544712, 0.13153282, 0.28767364)),
                    Row.of(0.96038024, Vectors.dense(0.32658881, 0.90655956, 0.99955954, 0.77088429, 0.04284752, 0.96525111)),
                    Row.of(0.71349698, Vectors.dense(0.97521246, 0.2025168, 0.67985305, 0.46534506, 0.92001748, 0.72820735)),
                    Row.of(0.43456735, Vectors.dense(0.24585653, 0.01953996, 0.70598881, 0.77448287, 0.4729746, 0.80146736)),
                    Row.of(0.52462506, Vectors.dense(0.17539792, 0.72016934, 0.3678759, 0.53209295, 0.29719397, 0.37429151)),
                    Row.of(0.43074793, Vectors.dense(0.72810013, 0.39850784, 0.1058295, 0.39858265, 0.52196395, 0.1060125)));

    /**
     * 6 rows of (label, 7-dimensional sparse feature vector). Feature 0 contains a NaN and feature
     * 6 is all zeros; the expected outputs below are NaN for both.
     */
    private static final List<Row> SPARSE_INPUT_DATA =
            Arrays.asList(
                    Row.of(4.6, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0, 0.0).toSparse()),
                    Row.of(6.6, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0, 0.0).toSparse()),
                    Row.of(5.1, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0, 0.0).toSparse()),
                    Row.of(7.6, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0, 0.0).toSparse()),
                    Row.of(9.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0, 0.0).toSparse()),
                    Row.of(9.0, Vectors.dense(Double.NaN, 9.0, 6.0, 4.0, 0.0, 0.0, 0.0).toSparse()));

    /** Expected non-flattened output for the dense input: (pValues, degreesOfFreedom, fValues). */
    private static final Row EXPECTED_OUTPUT_DENSE =
            Row.of(
                    Vectors.dense(0.01736587, 0.0149916659, 1.12697153E-4, 0.426990301, 0.275911201, 0.193549275),
                    new long[] {18, 18, 18, 18, 18, 18},
                    Vectors.dense(6.86260598, 7.23175589, 24.11424725, 0.6605354, 1.26266286, 1.82421406));

    /** Expected flattened output for the dense input: (featureIndex, pValue, dof, fValue). */
    private static final List<Row> EXPECTED_FLATTENED_OUTPUT_DENSE =
            Arrays.asList(
                    Row.of(0, 0.01736587, 18, 6.86260598),
                    Row.of(1, 0.0149916659, 18, 7.23175589),
                    Row.of(2, 1.12697153E-4, 18, 24.11424725),
                    Row.of(3, 0.426990301, 18, 0.6605354),
                    Row.of(4, 0.275911201, 18, 1.26266286),
                    Row.of(5, 0.193549275, 18, 1.82421406));

    /** Expected non-flattened output for the sparse input: (pValues, degreesOfFreedom, fValues). */
    private static final Row EXPECTED_OUTPUT_SPARSE =
            Row.of(
                    Vectors.dense(Double.NaN, 0.19167161, 0.06506426, 0.75183662, 0.16111045, 0.89090362, Double.NaN),
                    new long[] {4, 4, 4, 4, 4, 4, 4},
                    Vectors.dense(Double.NaN, 2.46254817, 6.37164347, 0.1147488, 2.94816821, 0.02134755, Double.NaN));

    /** Expected flattened output for the sparse input: (featureIndex, pValue, dof, fValue). */
    private static final List<Row> EXPECTED_FLATTENED_OUTPUT_SPARSE =
            Arrays.asList(
                    Row.of(0, Double.NaN, 4, Double.NaN),
                    Row.of(1, 0.19167161, 4, 2.46254817),
                    Row.of(2, 0.06506426, 4, 6.37164347),
                    Row.of(3, 0.75183662, 4, 0.1147488),
                    Row.of(4, 0.16111045, 4, 2.94816821),
                    Row.of(5, 0.89090362, 4, 0.02134755),
                    Row.of(6, Double.NaN, 4, Double.NaN));

    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        denseInputTable =
                tEnv.fromDataStream(env.fromCollection(DENSE_INPUT_DATA)).as("label", "features");
        sparseInputTable =
                tEnv.fromDataStream(env.fromCollection(SPARSE_INPUT_DATA)).as("label", "features");
    }

    /**
     * Executes {@code table} and collects all of its rows into a new mutable list.
     *
     * <p>{@link TableImpl#getTableEnvironment()} is typed as a plain {@code TableEnvironment}, so
     * it must be cast to {@link StreamTableEnvironment} to reach {@code toDataStream}. The tables
     * in this test are all created through a {@code StreamTableEnvironment}.
     */
    @SuppressWarnings("unchecked") // commons-collections' IteratorUtils.toList returns a raw List.
    private static List<Row> collectRows(Table table) throws Exception {
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
        return IteratorUtils.toList(tableEnv.toDataStream(table).executeAndCollect());
    }

    /**
     * Asserts that the flattened output of {@link FValueTest} matches {@code expected}.
     *
     * <p>Rows are matched by their first field (the feature index) regardless of arrival order.
     * Every field is compared numerically within {@link #EPS}.
     *
     * @param output the table produced by {@code FValueTest} with flatten enabled
     * @param expected the expected rows; this list is not modified
     */
    private static void verifyFlattenTransformationResult(Table output, List<Row> expected)
            throws Exception {
        List<Row> actual = collectRows(output);
        Assert.assertEquals(expected.size(), actual.size());

        // Sort a copy of the expected rows so the shared static fixtures are never mutated.
        Comparator<Row> byFeatureIndex =
                Comparator.comparing(row -> String.valueOf(row.getField(0)));
        List<Row> sortedExpected = new ArrayList<>(expected);
        sortedExpected.sort(byFeatureIndex);
        actual.sort(byFeatureIndex);

        for (int i = 0; i < sortedExpected.size(); i++) {
            Row expectedRow = sortedExpected.get(i);
            Row actualRow = actual.get(i);
            Assert.assertEquals(expectedRow.getArity(), actualRow.getArity());
            for (int j = 0; j < expectedRow.getArity(); j++) {
                // Fields mix integral and floating-point types; compare them uniformly as doubles.
                // JUnit's delta-based assertEquals treats NaN as equal to NaN.
                Assert.assertEquals(
                        Double.parseDouble(expectedRow.getField(j).toString()),
                        Double.parseDouble(actualRow.getField(j).toString()),
                        EPS);
            }
        }
    }

    /**
     * Asserts that the non-flattened output of {@link FValueTest} is a single row equal to {@code
     * expected}: (pValues vector, degreesOfFreedom long[], fValues vector).
     */
    private static void verifyTransformationResult(Table output, Row expected) throws Exception {
        List<Row> results = collectRows(output);
        Assert.assertEquals(1, results.size());

        Row result = results.get(0);
        Assert.assertEquals(3, result.getArity());
        Assert.assertArrayEquals(
                ((Vector) expected.getField(0)).toArray(),
                ((Vector) result.getField(0)).toArray(),
                EPS);
        Assert.assertArrayEquals((long[]) expected.getField(1), (long[]) result.getField(1));
        Assert.assertArrayEquals(
                ((Vector) expected.getField(2)).toArray(),
                ((Vector) result.getField(2)).toArray(),
                EPS);
    }

    @Test
    public void testParam() {
        FValueTest fValueTest = new FValueTest();
        Assert.assertEquals("label", fValueTest.getLabelCol());
        Assert.assertEquals("features", fValueTest.getFeaturesCol());
        Assert.assertFalse(fValueTest.getFlatten());

        fValueTest.setLabelCol("test_label").setFeaturesCol("test_features").setFlatten(true);

        Assert.assertEquals("test_features", fValueTest.getFeaturesCol());
        Assert.assertEquals("test_label", fValueTest.getLabelCol());
        Assert.assertTrue(fValueTest.getFlatten());
    }

    @Test
    public void testOutputSchema() {
        // Rename the input columns so the configured column names actually exist in the input.
        Table input = denseInputTable.as("test_label", "test_features");
        FValueTest fValueTest =
                new FValueTest().setFeaturesCol("test_features").setLabelCol("test_label");

        Assert.assertEquals(
                Arrays.asList("pValues", "degreesOfFreedom", "fValues"),
                fValueTest.transform(input)[0].getResolvedSchema().getColumnNames());

        fValueTest.setFlatten(true);
        Assert.assertEquals(
                Arrays.asList("featureIndex", "pValue", "degreeOfFreedom", "fValue"),
                fValueTest.transform(input)[0].getResolvedSchema().getColumnNames());
    }

    @Test
    public void testTransform() throws Exception {
        FValueTest fValueTest = new FValueTest();
        verifyTransformationResult(fValueTest.transform(denseInputTable)[0], EXPECTED_OUTPUT_DENSE);
        verifyTransformationResult(
                fValueTest.transform(sparseInputTable)[0], EXPECTED_OUTPUT_SPARSE);
    }

    @Test
    public void testTransformWithFlatten() throws Exception {
        FValueTest fValueTest = new FValueTest().setFlatten(true);
        verifyFlattenTransformationResult(
                fValueTest.transform(denseInputTable)[0], EXPECTED_FLATTENED_OUTPUT_DENSE);
        verifyFlattenTransformationResult(
                fValueTest.transform(sparseInputTable)[0], EXPECTED_FLATTENED_OUTPUT_SPARSE);
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        FValueTest loaded =
                TestUtils.saveAndReload(
                        tEnv,
                        new FValueTest(),
                        tempFolder.newFolder().getAbsolutePath(),
                        FValueTest::load);
        verifyTransformationResult(loaded.transform(denseInputTable)[0], EXPECTED_OUTPUT_DENSE);
        verifyTransformationResult(loaded.transform(sparseInputTable)[0], EXPECTED_OUTPUT_SPARSE);
    }
}
