package org.apache.flink.ml.stats;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/** Tests for {@link ChiSqTest}: parameters, output schema, transformation and save/load. */
public class ChiSqTestTest extends AbstractTestBase {
    private StreamTableEnvironment tEnv;
    private Table inputTableWithDoubleLabel;
    private Table inputTableWithIntegerLabel;

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** (label, features) samples whose label column holds {@code Double} values. */
    private final List<Row> samplesWithDoubleLabel =
            Arrays.asList(
                    Row.of(0.0, Vectors.dense(5.0, 1.0)),
                    Row.of(2.0, Vectors.dense(6.0, 2.0)),
                    Row.of(1.0, Vectors.dense(7.0, 2.0)),
                    Row.of(1.0, Vectors.dense(5.0, 4.0)),
                    Row.of(0.0, Vectors.dense(5.0, 1.0)),
                    Row.of(2.0, Vectors.dense(6.0, 2.0)),
                    Row.of(1.0, Vectors.dense(7.0, 2.0)),
                    Row.of(1.0, Vectors.dense(5.0, 4.0)),
                    Row.of(2.0, Vectors.dense(5.0, 1.0)),
                    Row.of(0.0, Vectors.dense(5.0, 2.0)),
                    Row.of(0.0, Vectors.dense(5.0, 2.0)),
                    Row.of(1.0, Vectors.dense(9.0, 4.0)),
                    Row.of(1.0, Vectors.dense(9.0, 3.0)));

    /** Expected single output row (pValues, degreesOfFreedom, statistics) for the double labels. */
    private final List<Row> expectedChiSqTestResultWithDoubleLabel =
            Collections.singletonList(
                    Row.of(
                            Vectors.dense(0.03419350755, 0.24220177737),
                            new int[] {6, 6},
                            Vectors.dense(13.61904761905, 7.94444444444)));

    /**
     * Expected flattened output for the double labels: one row per feature with
     * (featureIndex, pValue, degreeOfFreedom, statistic).
     */
    private final List<Row> expectedChiSqTestResultWithDoubleLabelWithFlatten =
            Arrays.asList(
                    Row.of(0, 0.03419350755, 6, 13.61904761905),
                    Row.of(1, 0.24220177737, 6, 7.94444444444));

    /** (label, features) samples whose label column holds {@code Integer} values. */
    private final List<Row> samplesWithIntegerLabel =
            Arrays.asList(
                    Row.of(33, Vectors.dense(5.0, 0.0)),
                    Row.of(44, Vectors.dense(6.0, 1.0)),
                    Row.of(55, Vectors.dense(7.0, 1.0)),
                    Row.of(11, Vectors.dense(5.0, 1.0)),
                    Row.of(11, Vectors.dense(5.0, 0.0)),
                    Row.of(33, Vectors.dense(6.0, 2.0)),
                    Row.of(22, Vectors.dense(7.0, 2.0)),
                    Row.of(66, Vectors.dense(5.0, 3.0)),
                    Row.of(77, Vectors.dense(5.0, 3.0)),
                    Row.of(88, Vectors.dense(5.0, 4.0)),
                    Row.of(77, Vectors.dense(5.0, 6.0)),
                    Row.of(44, Vectors.dense(9.0, 6.0)),
                    Row.of(11, Vectors.dense(9.0, 8.0)));

    /** Expected single output row (pValues, degreesOfFreedom, statistics) for the integer labels. */
    private final List<Row> expectedChiSqTestResultWithIntegerLabel =
            Collections.singletonList(
                    Row.of(
                            Vectors.dense(0.35745138256, 0.39934987096),
                            new int[] {21, 42},
                            Vectors.dense(22.75, 43.69444444444)));

    /** Builds the table environment and the two (label, features) input tables. */
    @Before
    public void before() {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        inputTableWithDoubleLabel =
                tEnv.fromDataStream(env.fromCollection(samplesWithDoubleLabel))
                        .as("label", "features");
        inputTableWithIntegerLabel =
                tEnv.fromDataStream(env.fromCollection(samplesWithIntegerLabel))
                        .as("label", "features");
    }

    /**
     * Executes {@code table}, collects every output row and asserts that the rows equal
     * {@code list}, regardless of the order in which they were emitted.
     *
     * @param table the transformation output to check
     * @param list the expected rows
     * @throws Exception if the job fails while collecting results
     */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List
    private static void verifyPredictionResult(Table table, List<Row> list) throws Exception {
        // TableImpl#getTableEnvironment() is declared as TableEnvironment; toDataStream is only
        // available on StreamTableEnvironment, so the cast is required.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
        List<Row> actual =
                IteratorUtils.toList(tableEnv.toDataStream(table).executeAndCollect());
        // compareResultCollections orders both lists with this comparator before comparing them
        // pairwise, so it must be a genuine total order. The previous "equals ? 0 : 1" lambda
        // violated the Comparator contract (compare(a, b) and compare(b, a) were both 1). That
        // made the ordering a no-op and multi-row checks depend on emission order. Ordering by
        // the rendered row text is consistent for equal rows.
        TestBaseUtils.compareResultCollections(
                list, actual, (row1, row2) -> row1.toString().compareTo(row2.toString()));
    }

    /** Checks parameter defaults and that setters update the same instance. */
    @Test
    public void testParam() {
        ChiSqTest chiSqTest = new ChiSqTest();
        Assert.assertEquals("label", chiSqTest.getLabelCol());
        Assert.assertEquals("features", chiSqTest.getFeaturesCol());
        Assert.assertFalse(chiSqTest.getFlatten());

        chiSqTest.setLabelCol("test_label").setFeaturesCol("test_features").setFlatten(true);

        Assert.assertEquals("test_features", chiSqTest.getFeaturesCol());
        Assert.assertEquals("test_label", chiSqTest.getLabelCol());
        Assert.assertTrue(chiSqTest.getFlatten());
    }

    /** Checks output column names in both the default and the flattened layout. */
    @Test
    public void testOutputSchema() {
        ChiSqTest chiSqTest = new ChiSqTest().setFeaturesCol("features").setLabelCol("label");

        Table output = chiSqTest.transform(inputTableWithDoubleLabel)[0];
        Assert.assertEquals(
                Arrays.asList("pValues", "degreesOfFreedom", "statistics"),
                output.getResolvedSchema().getColumnNames());

        chiSqTest.setFlatten(true);
        Table flattenedOutput = chiSqTest.transform(inputTableWithDoubleLabel)[0];
        Assert.assertEquals(
                Arrays.asList("featureIndex", "pValue", "degreeOfFreedom", "statistic"),
                flattenedOutput.getResolvedSchema().getColumnNames());
    }

    /** Checks results for both double and integer label columns. */
    @Test
    public void testTransform() throws Exception {
        ChiSqTest chiSqTest = new ChiSqTest().setFeaturesCol("features").setLabelCol("label");

        verifyPredictionResult(
                chiSqTest.transform(inputTableWithDoubleLabel)[0],
                expectedChiSqTestResultWithDoubleLabel);
        verifyPredictionResult(
                chiSqTest.transform(inputTableWithIntegerLabel)[0],
                expectedChiSqTestResultWithIntegerLabel);
    }

    /** Checks the one-row-per-feature output produced when flatten is enabled. */
    @Test
    public void testTransformWithFlatten() throws Exception {
        ChiSqTest chiSqTest =
                new ChiSqTest().setFlatten(true).setFeaturesCol("features").setLabelCol("label");

        verifyPredictionResult(
                chiSqTest.transform(inputTableWithDoubleLabel)[0],
                expectedChiSqTestResultWithDoubleLabelWithFlatten);
    }

    /** Checks that a saved and reloaded ChiSqTest produces the same result. */
    @Test
    public void testSaveLoadAndTransform() throws Exception {
        ChiSqTest chiSqTest = new ChiSqTest().setFeaturesCol("features").setLabelCol("label");
        ChiSqTest loadedChiSqTest =
                TestUtils.saveAndReload(
                        tEnv, chiSqTest, tempFolder.newFolder().getAbsolutePath(), ChiSqTest::load);

        verifyPredictionResult(
                loadedChiSqTest.transform(inputTableWithDoubleLabel)[0],
                expectedChiSqTestResultWithDoubleLabel);
    }
}
