package org.apache.flink.ml.feature;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.feature.bucketizer.Bucketizer;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests {@link Bucketizer}: parameter accessors, output schema, transformation under each
 * {@code handleInvalid} mode ("skip", "keep", "error"), integer-typed input, and save/reload.
 */
public class BucketizerTest extends AbstractTestBase {

    /** Three input rows laid out as (id, f1, f2, f3, f4); they include infinities and a NaN. */
    private static final List<Row> inputData =
            Arrays.asList(
                    Row.of(1, -0.5, 0.0, 1.0, 0.0),
                    Row.of(2, Double.NEGATIVE_INFINITY, 1.0, Double.POSITIVE_INFINITY, 1.0),
                    Row.of(3, Double.NaN, -0.5, -0.5, 2.0));

    /** One array of split points per input column (f1..f4). */
    private static final Double[][] splitsArray = {
        {-0.5, 0.0, 0.5},
        {-1.0, 0.0, 2.0},
        {Double.NEGATIVE_INFINITY, 10.0, Double.POSITIVE_INFINITY},
        {Double.NEGATIVE_INFINITY, 1.5, Double.POSITIVE_INFINITY}
    };

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** Expected (id, o1, o2, o3, o4) rows when {@code handleInvalid} is "keep". */
    private final List<Row> expectedKeepResult =
            Arrays.asList(
                    Row.of(1, 0.0, 1.0, 0.0, 0.0),
                    Row.of(2, 2.0, 1.0, 1.0, 0.0),
                    Row.of(3, 2.0, 0.0, 0.0, 1.0));

    /** Expected (id, o1, o2, o3, o4) rows when {@code handleInvalid} is "skip". */
    private final List<Row> expectedSkipResult =
            Collections.singletonList(Row.of(1, 0.0, 1.0, 0.0, 0.0));

    private StreamTableEnvironment tEnv;
    private Table inputTable;

    /**
     * Creates a fresh execution environment (parallelism 4, checkpointing every 100 ms, no
     * restarts) and registers {@link #inputData} as a table with columns id, f1, f2, f3, f4.
     */
    @Before
    public void before() {
        Configuration config = new Configuration();
        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());

        tEnv = StreamTableEnvironment.create(env);
        inputTable = tEnv.fromDataStream(env.fromCollection(inputData)).as("id", "f1", "f2", "f3", "f4");
    }

    /**
     * Builds a Bucketizer over columns f1..f4 producing o1..o4 with {@link #splitsArray};
     * {@code handleInvalid} is left for the caller to set.
     */
    private static Bucketizer newBucketizer() {
        return new Bucketizer()
                .setInputCols("f1", "f2", "f3", "f4")
                .setOutputCols("o1", "o2", "o3", "o4")
                .setSplitsArray(splitsArray);
    }

    /**
     * Executes {@code outputTable}, projects every collected row onto ("id", outputCols...) and
     * compares the projection with {@code expected}, ordering both sides by the integer id.
     *
     * @param outputTable table produced by the transformer under test
     * @param outputCols names of the output columns to extract, in order
     * @param expected expected rows, each laid out as (id, outputCols...)
     * @throws Exception if executing the job or collecting its results fails
     */
    private void verifyOutputResult(Table outputTable, String[] outputCols, List<Row> expected)
            throws Exception {
        // commons-collections 3.x returns a raw List; the stream is a DataStream<Row>, so every
        // element is a Row and the unchecked conversion is safe.
        @SuppressWarnings("unchecked")
        List<Row> collected = IteratorUtils.toList(tEnv.toDataStream(outputTable).executeAndCollect());

        List<Row> actual = new ArrayList<>(collected.size());
        for (Row source : collected) {
            Row projected = new Row(outputCols.length + 1);
            projected.setField(0, source.getField("id"));
            for (int col = 0; col < outputCols.length; col++) {
                projected.setField(col + 1, source.getField(outputCols[col]));
            }
            actual.add(projected);
        }

        compareResultCollections(
                expected, actual, Comparator.comparingInt((Row row) -> (Integer) row.getField(0)));
    }

    /** Checks the default of {@code handleInvalid} and that every setter round-trips its value. */
    @Test
    public void testParam() {
        Bucketizer bucketizer = new Bucketizer();
        Assert.assertEquals("error", bucketizer.getHandleInvalid());

        bucketizer
                .setInputCols("f1", "f2", "f3", "f4")
                .setOutputCols("o1", "o2", "o3", "o4")
                .setHandleInvalid("skip")
                .setSplitsArray(splitsArray);

        Assert.assertArrayEquals(new String[] {"f1", "f2", "f3", "f4"}, bucketizer.getInputCols());
        Assert.assertArrayEquals(new String[] {"o1", "o2", "o3", "o4"}, bucketizer.getOutputCols());
        Assert.assertEquals("skip", bucketizer.getHandleInvalid());

        Double[][] actualSplits = bucketizer.getSplitsArray();
        Assert.assertEquals(splitsArray.length, actualSplits.length);
        for (int i = 0; i < splitsArray.length; i++) {
            Assert.assertArrayEquals(splitsArray[i], actualSplits[i]);
        }
    }

    /** Checks that the output table has the input columns followed by the output columns. */
    @Test
    public void testOutputSchema() {
        Bucketizer bucketizer = newBucketizer().setHandleInvalid("skip");
        Table output = bucketizer.transform(inputTable)[0];

        Assert.assertEquals(
                Arrays.asList("id", "f1", "f2", "f3", "f4", "o1", "o2", "o3", "o4"),
                output.getResolvedSchema().getColumnNames());
    }

    /**
     * Checks the transformation result for "skip" and "keep", and that "error" makes the job
     * fail with the documented message.
     */
    @Test
    public void testTransform() throws Exception {
        Bucketizer bucketizer = newBucketizer();

        bucketizer.setHandleInvalid("skip");
        verifyOutputResult(
                bucketizer.transform(inputTable)[0], bucketizer.getOutputCols(), expectedSkipResult);

        bucketizer.setHandleInvalid("keep");
        verifyOutputResult(
                bucketizer.transform(inputTable)[0], bucketizer.getOutputCols(), expectedKeepResult);

        bucketizer.setHandleInvalid("error");
        Throwable thrown = null;
        try {
            IteratorUtils.toList(bucketizer.transform(inputTable)[0].execute().collect());
        } catch (Throwable th) {
            thrown = th;
        }
        // Asserted outside the try block so that a job which unexpectedly succeeds is reported as
        // such, rather than the failure itself being caught and re-interpreted by the handler.
        Assert.assertNotNull(
                "Expected the transformation to fail when handleInvalid is \"error\".", thrown);

        // Some commons-lang3 versions return null from getRootCause when there is no cause chain.
        Throwable rootCause = ExceptionUtils.getRootCause(thrown);
        if (rootCause == null) {
            rootCause = thrown;
        }
        Assert.assertEquals(
                "The input contains invalid value. See " + HasHandleInvalid.HANDLE_INVALID + " parameter for more options.",
                rootCause.getMessage());
    }

    /** Checks the "skip" result after the input columns have been converted to integers. */
    @Test
    public void testInputTypeConversion() throws Exception {
        inputTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputTable);
        Assert.assertArrayEquals(
                new Class<?>[] {
                    Integer.class, Integer.class, Integer.class, Integer.class, Integer.class
                },
                TestUtils.getColumnDataTypes(inputTable));

        Bucketizer bucketizer = newBucketizer();
        bucketizer.setHandleInvalid("skip");

        List<Row> expected =
                Arrays.asList(Row.of(1, 1.0, 1.0, 0.0, 0.0), Row.of(3, 1.0, 1.0, 0.0, 1.0));
        verifyOutputResult(bucketizer.transform(inputTable)[0], bucketizer.getOutputCols(), expected);
    }

    /** Checks that a saved and reloaded Bucketizer yields the same "keep" result. */
    @Test
    public void testSaveLoadAndTransform() throws Exception {
        Bucketizer bucketizer = newBucketizer().setHandleInvalid("keep");
        Bucketizer loaded =
                TestUtils.saveAndReload(tEnv, bucketizer, tempFolder.newFolder().getAbsolutePath());

        verifyOutputResult(
                loaded.transform(inputTable)[0], loaded.getOutputCols(), expectedKeepResult);
    }
}
