package org.apache.flink.ml.feature;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests {@link OneHotEncoder} and {@link OneHotEncoderModel}.
 *
 * <p>Unless a test overrides them, the estimator is fitted on the single column {@code input}
 * with values {0, 1, 2, 0}, and the model is applied to {0, 1, 2}. The collected output is turned
 * into one map per output column (input value to encoded vector) and compared with {@link
 * #expectedOutput}.
 */
public class OneHotEncoderTest {

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainTable;
    private Table predictTable;
    /** One map per output column: input value -> expected encoded vector. */
    private Map<Double, Vector>[] expectedOutput;
    private OneHotEncoder estimator;

    @Before
    public void before() {
        Configuration configuration = new Configuration();
        // Allow checkpoints to continue after some (bounded) tasks have already finished.
        configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        // A failing job must surface its exception to the test instead of being retried.
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);

        trainTable = tableOf(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));
        predictTable = tableOf(Row.of(0.0), Row.of(1.0), Row.of(2.0));
        expectedOutput = expectedOutputWithDropLast();
        estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output");
    }

    /** Builds a single-column table named {@code input} from the given rows. */
    private Table tableOf(Row... rows) {
        return tEnv.fromDataStream(env.fromCollection(Arrays.asList(rows))).as("input");
    }

    /** Wraps a single column's expected value-to-vector mapping into the array form used here. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Map<Double, Vector>[] singleColumn(Map<Double, Vector> valueToVector) {
        return new Map[] {valueToVector};
    }

    /**
     * Expected encoding of the values {0, 1, 2} with the default {@code dropLast = true}: vectors
     * of size 2, where the last category (2) is encoded as the all-zero vector.
     */
    private static Map<Double, Vector>[] expectedOutputWithDropLast() {
        Map<Double, Vector> expected = new HashMap<>();
        expected.put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
        expected.put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
        expected.put(2.0, Vectors.sparse(2, new int[0], new double[0]));
        return singleColumn(expected);
    }

    /**
     * Executes {@code table} and collects, for every input/output column pair, a map from the
     * numeric input value to the produced vector.
     *
     * @param table the table to execute
     * @param inputCols names of the numeric input columns
     * @param outputCols names of the vector output columns, aligned index-by-index with inputCols
     * @return one map per column pair, in the order of {@code inputCols}
     * @throws Exception if the job fails or the result iterator cannot be closed
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Map<Double, Vector>[] executeAndCollect(
            Table table, String[] inputCols, String[] outputCols) throws Exception {
        Map<Double, Vector>[] columnMaps = new Map[inputCols.length];
        for (int i = 0; i < inputCols.length; i++) {
            columnMaps[i] = new HashMap<>();
        }
        // Close the iterator so the collecting job/connection is released even on failure.
        try (CloseableIterator<Row> rows = table.execute().collect()) {
            while (rows.hasNext()) {
                Row row = rows.next();
                for (int i = 0; i < inputCols.length; i++) {
                    double value = ((Number) row.getField(inputCols[i])).doubleValue();
                    columnMaps[i].put(value, (Vector) row.getField(outputCols[i]));
                }
            }
        }
        return columnMaps;
    }

    /**
     * Returns the deepest cause of {@code throwable}, or {@code throwable} itself if it has no
     * cause. Some commons-lang3 versions return null from getRootCause in that case.
     */
    private static Throwable rootCauseOf(Throwable throwable) {
        Throwable root = ExceptionUtils.getRootCause(throwable);
        return root == null ? throwable : root;
    }

    /**
     * Fits on {@link #trainTable}, transforms {@link #predictTable} and asserts that the job fails
     * with an IllegalArgumentException rejecting the non-integer value 0.5.
     */
    private void assertFailsOnNonIndexedValue() {
        try (CloseableIterator<Row> results =
                estimator.fit(trainTable).transform(predictTable)[0].execute().collect()) {
            results.next();
            // AssertionError is an Error, so it is not swallowed by the catch below.
            Assert.fail("Expected IllegalArgumentException");
        } catch (Exception e) {
            Throwable rootCause = rootCauseOf(e);
            Assert.assertEquals(IllegalArgumentException.class, rootCause.getClass());
            Assert.assertEquals(
                    "Value 0.5 cannot be parsed as indexed integer.", rootCause.getMessage());
        }
    }

    /** Checks parameter defaults and setters/getters on both the estimator and the model. */
    @Test
    public void testParam() {
        OneHotEncoder oneHotEncoder = new OneHotEncoder();
        Assert.assertTrue(oneHotEncoder.getDropLast());
        oneHotEncoder.setInputCols("test_input").setOutputCols("test_output").setDropLast(false);
        Assert.assertArrayEquals(new String[] {"test_input"}, oneHotEncoder.getInputCols());
        Assert.assertArrayEquals(new String[] {"test_output"}, oneHotEncoder.getOutputCols());
        Assert.assertFalse(oneHotEncoder.getDropLast());

        OneHotEncoderModel oneHotEncoderModel = new OneHotEncoderModel();
        Assert.assertTrue(oneHotEncoderModel.getDropLast());
        oneHotEncoderModel
                .setInputCols("test_input")
                .setOutputCols("test_output")
                .setDropLast(false);
        Assert.assertArrayEquals(new String[] {"test_input"}, oneHotEncoderModel.getInputCols());
        Assert.assertArrayEquals(new String[] {"test_output"}, oneHotEncoderModel.getOutputCols());
        Assert.assertFalse(oneHotEncoderModel.getDropLast());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        OneHotEncoderModel model = estimator.fit(trainTable);
        Table output = model.transform(predictTable)[0];
        Assert.assertArrayEquals(
                expectedOutput,
                executeAndCollect(output, model.getInputCols(), model.getOutputCols()));
    }

    /** Checks that input columns converted to Integer (as asserted below) are still encoded. */
    @Test
    public void testInputTypeConversion() throws Exception {
        trainTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainTable);
        predictTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictTable);
        Assert.assertArrayEquals(
                new Class<?>[] {Integer.class}, TestUtils.getColumnDataTypes(trainTable));
        Assert.assertArrayEquals(
                new Class<?>[] {Integer.class}, TestUtils.getColumnDataTypes(predictTable));

        OneHotEncoderModel model = estimator.fit(trainTable);
        Table output = model.transform(predictTable)[0];
        Assert.assertArrayEquals(
                expectedOutput,
                executeAndCollect(output, model.getInputCols(), model.getOutputCols()));
    }

    /** With {@code dropLast = false} every category, including the last, gets its own slot. */
    @Test
    public void testDropLast() throws Exception {
        estimator.setDropLast(false);
        Map<Double, Vector> expected = new HashMap<>();
        expected.put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0}));
        expected.put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0}));
        expected.put(2.0, Vectors.sparse(3, new int[] {2}, new double[] {1.0}));
        expectedOutput = singleColumn(expected);

        OneHotEncoderModel model = estimator.fit(trainTable);
        Table output = model.transform(predictTable)[0];
        Assert.assertArrayEquals(
                expectedOutput,
                executeAndCollect(output, model.getInputCols(), model.getOutputCols()));
    }

    /** Uses Integer-valued rows instead of Double ones; the expected encoding is unchanged. */
    @Test
    public void testInputDataType() throws Exception {
        trainTable = tableOf(Row.of(0), Row.of(1), Row.of(2), Row.of(0));
        predictTable = tableOf(Row.of(0), Row.of(1), Row.of(2));
        expectedOutput = expectedOutputWithDropLast();

        OneHotEncoderModel model = estimator.fit(trainTable);
        Table output = model.transform(predictTable)[0];
        Assert.assertArrayEquals(
                expectedOutput,
                executeAndCollect(output, model.getInputCols(), model.getOutputCols()));
    }

    /** handleInvalid = "skip" is rejected at fit time with an IllegalArgumentException. */
    @Test
    public void testNotSupportedHandleInvalidOptions() {
        estimator.setHandleInvalid("skip");
        try {
            estimator.fit(trainTable);
            Assert.fail("Expected IllegalArgumentException");
        } catch (Exception e) {
            Assert.assertEquals(IllegalArgumentException.class, e.getClass());
        }
    }

    @Test
    public void testNonIndexedTrainData() {
        trainTable = tableOf(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0));
        assertFailsOnNonIndexedValue();
    }

    @Test
    public void testNonIndexedPredictData() {
        predictTable = tableOf(Row.of(0.5), Row.of(1.0), Row.of(2.0), Row.of(0.0));
        assertFailsOnNonIndexedValue();
    }

    /** Round-trips both the estimator and the fitted model through save/load before predicting. */
    @Test
    public void testSaveLoad() throws Exception {
        estimator =
                TestUtils.saveAndReload(
                        tEnv, estimator, tempFolder.newFolder().getAbsolutePath());
        OneHotEncoderModel model =
                TestUtils.saveAndReload(
                        tEnv, estimator.fit(trainTable), tempFolder.newFolder().getAbsolutePath());
        Table output = model.transform(predictTable)[0];
        Assert.assertArrayEquals(
                expectedOutput,
                executeAndCollect(output, model.getInputCols(), model.getOutputCols()));
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelData = estimator.fit(trainTable).getModelData()[0];
        try (CloseableIterator<Tuple2<Integer, Integer>> records =
                OneHotEncoderModelData.getModelDataStream(modelData).executeAndCollect()) {
            // NOTE(review): presumably (column index, largest value seen); training max is 2.
            Assert.assertEquals(new Tuple2<>(0, 2), records.next());
        }
    }

    /** Builds a fresh model from a fitted model's data and params and checks its predictions. */
    @Test
    public void testSetModelData() throws Exception {
        OneHotEncoderModel fitted = estimator.fit(trainTable);
        OneHotEncoderModel model = new OneHotEncoderModel().setModelData(fitted.getModelData()[0]);
        ReadWriteUtils.updateExistingParams(model, fitted.getParamMap());
        Table output = model.transform(predictTable)[0];
        Assert.assertArrayEquals(
                expectedOutput,
                executeAndCollect(output, model.getInputCols(), model.getOutputCols()));
    }
}
