package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/** Tests {@link RandomSplitter}: parameters, output schema, split weights, seeding and save/load. */
public class RandomSplitterTest extends AbstractTestBase {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        // Single parallelism fixes the order in which rows reach the splitter; testSeed relies on
        // the split being repeatable.
        env.setParallelism(1);
        tEnv = StreamTableEnvironment.create(env);
    }

    /**
     * Returns a single-column table holding the longs {@code 0..maxValue}. Both ends are inclusive,
     * so the table has {@code maxValue + 1} rows.
     */
    private Table getTable(int maxValue) {
        return tEnv.fromDataStream(env.fromSequence(0L, maxValue));
    }

    @Test
    public void testParam() {
        RandomSplitter splitter = new RandomSplitter();
        splitter.setWeights(new Double[] {0.3, 0.4}).setSeed(5L);

        Assert.assertArrayEquals(new Double[] {0.3, 0.4}, splitter.getWeights());
        Assert.assertEquals(5L, splitter.getSeed());
    }

    @Test
    public void testOutputSchema() {
        Table input =
                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
                        .as("test_input", "dummy_input");

        Table[] output =
                new RandomSplitter()
                        .setWeights(new Double[] {0.5, 0.1})
                        .transform(new Table[] {input});

        // One output table per weight, each keeping the input schema unchanged.
        Assert.assertEquals(2, output.length);
        for (Table table : output) {
            Assert.assertEquals(
                    Arrays.asList("test_input", "dummy_input"),
                    table.getResolvedSchema().getColumnNames());
        }
    }

    @Test
    public void testWeights() throws Exception {
        // 1001 rows split 2:1:2 -> roughly 400 / 200 / 400 rows (10% tolerance).
        Table input = getTable(1000);
        Table[] output =
                new RandomSplitter()
                        .setWeights(new Double[] {2.0, 1.0, 2.0})
                        .transform(new Table[] {input});

        Assert.assertEquals(1.0, collect(output[0]).size() / 400.0, 0.1);
        Assert.assertEquals(1.0, collect(output[1]).size() / 200.0, 0.1);
        Assert.assertEquals(1.0, collect(output[2]).size() / 400.0, 0.1);
        verifyResultTables(input, output);
    }

    @Test
    public void testSeed() throws Exception {
        Table input = getTable(100);
        RandomSplitter splitter = new RandomSplitter().setWeights(new Double[] {2.0, 1.0, 2.0});

        Table[] first = splitter.transform(new Table[] {input});
        Table[] second = splitter.transform(new Table[] {input});

        // The same seed must reproduce the same split, not merely splits of the same size.
        Assert.assertEquals(first.length, second.length);
        for (int i = 0; i < first.length; i++) {
            Assert.assertEquals(collect(first[i]), collect(second[i]));
        }
    }

    @Test
    public void testSaveLoadAndTransform() throws Exception {
        // 2001 rows split 4:6 -> roughly 800 / 1200 rows (10% tolerance).
        Table input = getTable(2000);
        RandomSplitter splitter = new RandomSplitter().setWeights(new Double[] {4.0, 6.0});
        RandomSplitter loaded =
                TestUtils.saveAndReload(
                        tEnv,
                        splitter,
                        tempFolder.newFolder().getAbsolutePath(),
                        RandomSplitter::load);

        Table[] output = loaded.transform(new Table[] {input});

        Assert.assertEquals(1.0, collect(output[0]).size() / 800.0, 0.1);
        Assert.assertEquals(1.0, collect(output[1]).size() / 1200.0, 0.1);
        verifyResultTables(input, output);
    }

    /**
     * Checks that the union of all output tables equals the input table, compared by the long
     * value in field 0, so that no row is lost or duplicated.
     */
    private void verifyResultTables(Table input, Table[] outputs) throws Exception {
        List<Row> expected = collect(input);
        List<Row> actual = new ArrayList<>();
        for (Table output : outputs) {
            actual.addAll(collect(output));
        }

        Assert.assertEquals(expected.size(), actual.size());
        TestBaseUtils.compareResultCollections(
                expected, actual, Comparator.<Row>comparingLong(row -> row.<Long>getFieldAs(0)));
    }

    /** Executes {@code table}, collects all of its rows, and closes the result iterator. */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List; elements are Rows.
    private List<Row> collect(Table table) throws Exception {
        try (CloseableIterator<Row> rows = tEnv.toDataStream(table).executeAndCollect()) {
            return IteratorUtils.toList(rows);
        }
    }
}
