package org.apache.flink.ml.feature.stringindexer;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/** Tests {@link IndexToStringModel}. */
public class IndexToStringModelTest extends AbstractTestBase {
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table predictTable;
    private Table modelTable;
    private Table predictTableWithUnseenValues;

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    // Each input index is mapped to the string at that position of the matching model array:
    // inputCol1 uses stringArrays[0], inputCol2 uses stringArrays[1].
    private final List<Row> expectedPrediction =
            Arrays.asList(Row.of(0, 3, "a", "2.0"), Row.of(1, 2, "b", "1.0"));
    private final String[][] stringArrays = {
        new String[] {"a", "b", "c", "d"}, new String[] {"-1.0", "0.0", "1.0", "2.0"}
    };

    @Before
    public void before() {
        Configuration configuration = new Configuration();
        configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);

        modelTable =
                tEnv.fromDataStream(env.fromElements(new StringIndexerModelData(stringArrays)))
                        .as("stringArrays");
        predictTable =
                tEnv.fromDataStream(
                                env.fromCollection(Arrays.asList(Row.of(0, 3), Row.of(1, 2))))
                        .as("inputCol1", "inputCol2");
        // Index 4 in inputCol1 is outside the 4-element model array, so it is "unseen".
        predictTableWithUnseenValues =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(Row.of(0, 3), Row.of(1, 2), Row.of(4, 1))))
                        .as("inputCol1", "inputCol2");
    }

    /**
     * Builds the model under test: maps inputCol1/inputCol2 to outputCol1/outputCol2 using the
     * shared {@code modelTable}.
     */
    private IndexToStringModel createModel() {
        return new IndexToStringModel()
                .setInputCols("inputCol1", "inputCol2")
                .setOutputCols("outputCol1", "outputCol2")
                .setModelData(modelTable);
    }

    @Test
    public void testOutputSchema() {
        Table output = createModel().transform(predictTable)[0];
        Assert.assertEquals(
                Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testInputWithUnseenValues() {
        // Build the pipeline outside the try block so setup errors are not mistaken for the
        // expected runtime failure.
        Table output = createModel().transform(predictTableWithUnseenValues)[0];
        try {
            IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
        } catch (Exception e) {
            Throwable rootCause = ExceptionUtils.getRootCause(e);
            // Older commons-lang3 versions return null when there is no nested cause.
            if (rootCause == null) {
                rootCause = e;
            }
            Assert.assertEquals("The input contains unseen index: 4.", rootCause.getMessage());
            return;
        }
        // Kept outside the try block: an AssertionError raised inside it would be swallowed
        // by the catch clause and reported with a misleading message.
        Assert.fail("Expected the job to fail because the input contains an unseen index.");
    }

    @Test
    public void testPredict() throws Exception {
        Table output = createModel().transform(predictTable)[0];
        StringIndexerTest.verifyPredictionResult(
                expectedPrediction,
                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()));
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        IndexToStringModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv, createModel(), tempFolder.newFolder().getAbsolutePath());
        Assert.assertEquals(
                Collections.singletonList("stringArrays"),
                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());

        Table output = loadedModel.transform(predictTable)[0];
        StringIndexerTest.verifyPredictionResult(
                expectedPrediction,
                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()));
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelData = createModel().getModelData()[0];
        // IteratorUtils.toList returns a raw List; the stream's element type is known here.
        @SuppressWarnings("unchecked")
        List<StringIndexerModelData> collected =
                IteratorUtils.toList(
                        StringIndexerModelData.getModelDataStream(modelData).executeAndCollect());

        Assert.assertEquals(1, collected.size());
        StringIndexerModelData data = collected.get(0);
        Assert.assertEquals(2, data.stringArrays.length);
        Assert.assertArrayEquals(stringArrays[0], data.stringArrays[0]);
        Assert.assertArrayEquals(stringArrays[1], data.stringArrays[1]);
    }
}
