package org.apache.flink.ml.feature.stringindexer;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/feature/stringindexer/StringIndexerTest.class */
/**
 * Tests for {@link StringIndexer} and {@link StringIndexerModel}.
 *
 * <p>Each test trains on a ten-row table with a string column ({@code inputCol1}) and a numeric
 * column ({@code inputCol2}), then transforms a three-row table whose first column contains the
 * string {@code "e"}, which does not occur in the training data.
 */
public class StringIndexerTest extends AbstractTestBase {
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainTable;
    private Table predictTable;

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** Expected model content, one array per input column, for the "alphabetAsc" order. */
    private final String[][] expectedAlphabeticAscModelData = {
        {"a", "b", "c", "d"},
        {"-1.0", "0.0", "1.0", "2.0"}
    };

    // Expected rows are (inputCol1, inputCol2, outputCol1, outputCol2) with handleInvalid "keep".
    private final List<Row> expectedAlphabeticAscPredictData =
            Arrays.asList(
                    Row.of("a", 2.0, 0.0, 3.0),
                    Row.of("b", 1.0, 1.0, 2.0),
                    Row.of("e", 2.0, 4.0, 3.0));
    private final List<Row> expectedAlphabeticDescPredictData =
            Arrays.asList(
                    Row.of("a", 2.0, 3.0, 0.0),
                    Row.of("b", 1.0, 2.0, 1.0),
                    Row.of("e", 2.0, 4.0, 0.0));
    private final List<Row> expectedFreqAscPredictData =
            Arrays.asList(
                    Row.of("a", 2.0, 2.0, 3.0),
                    Row.of("b", 1.0, 3.0, 1.0),
                    Row.of("e", 2.0, 4.0, 3.0));
    private final List<Row> expectedFreqDescPredictData =
            Arrays.asList(
                    Row.of("a", 2.0, 1.0, 0.0),
                    Row.of("b", 1.0, 0.0, 2.0),
                    Row.of("e", 2.0, 4.0, 0.0));

    /** Builds the execution environments and the training / prediction tables for each test. */
    @Before
    public void before() {
        Configuration configuration = new Configuration();
        configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);

        List<Row> trainData =
                Arrays.asList(
                        Row.of("a", 1.0),
                        Row.of("b", 1.0),
                        Row.of("b", 2.0),
                        Row.of("c", 0.0),
                        Row.of("d", 2.0),
                        Row.of("a", 2.0),
                        Row.of("b", 2.0),
                        Row.of("b", -1.0),
                        Row.of("a", -1.0),
                        Row.of("c", -1.0));
        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("inputCol1", "inputCol2");

        List<Row> predictData = Arrays.asList(Row.of("a", 2.0), Row.of("b", 1.0), Row.of("e", 2.0));
        predictTable =
                tEnv.fromDataStream(env.fromCollection(predictData)).as("inputCol1", "inputCol2");
    }

    /** Checks the parameter defaults and that values set through the fluent setters are kept. */
    @Test
    public void testParam() {
        StringIndexer stringIndexer = new StringIndexer();
        Assert.assertEquals("arbitrary", stringIndexer.getStringOrderType());
        Assert.assertEquals("error", stringIndexer.getHandleInvalid());

        stringIndexer
                .setInputCols(new String[] {"inputCol1", "inputCol2"})
                .setOutputCols(new String[] {"outputCol1", "outputCol2"})
                .setStringOrderType("alphabetAsc")
                .setHandleInvalid("skip");

        Assert.assertArrayEquals(
                new String[] {"inputCol1", "inputCol2"}, stringIndexer.getInputCols());
        Assert.assertArrayEquals(
                new String[] {"outputCol1", "outputCol2"}, stringIndexer.getOutputCols());
        Assert.assertEquals("alphabetAsc", stringIndexer.getStringOrderType());
        Assert.assertEquals("skip", stringIndexer.getHandleInvalid());
    }

    /** Checks that the transformed table has the input columns followed by the output columns. */
    @Test
    public void testOutputSchema() {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setStringOrderType("alphabetAsc");
        stringIndexer.setHandleInvalid("skip");

        Table output = fitAndTransform(stringIndexer);

        Assert.assertEquals(
                Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Checks the produced indices for every supported string order type. */
    @Test
    public void testStringOrderType() throws Exception {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setHandleInvalid("keep");

        stringIndexer.setStringOrderType("alphabetAsc");
        verifyPredictionResult(
                expectedAlphabeticAscPredictData, collect(fitAndTransform(stringIndexer)));

        stringIndexer.setStringOrderType("alphabetDesc");
        verifyPredictionResult(
                expectedAlphabeticDescPredictData, collect(fitAndTransform(stringIndexer)));

        stringIndexer.setStringOrderType("frequencyAsc");
        verifyPredictionResult(expectedFreqAscPredictData, collect(fitAndTransform(stringIndexer)));

        stringIndexer.setStringOrderType("frequencyDesc");
        verifyPredictionResult(expectedFreqDescPredictData, collect(fitAndTransform(stringIndexer)));

        // With "arbitrary" the concrete indices are not pinned down, so only their range and the
        // number of distinct values per output column are checked.
        stringIndexer.setStringOrderType("arbitrary");
        List<Row> arbitraryResult = collect(fitAndTransform(stringIndexer));
        HashSet<Double> firstColIndices = new HashSet<>();
        HashSet<Double> secondColIndices = new HashSet<>();
        for (Row row : arbitraryResult) {
            double firstIndex = (Double) row.getField(2);
            firstColIndices.add(firstIndex);
            Assert.assertTrue(firstIndex >= 0.0 && firstIndex <= 4.0);

            double secondIndex = (Double) row.getField(3);
            Assert.assertTrue(secondIndex >= 0.0 && secondIndex <= 3.0);
            secondColIndices.add(secondIndex);
        }
        // The prediction table holds three distinct strings ("a", "b", "e") in the first column
        // and two distinct values (2.0, 1.0) in the second one.
        Assert.assertEquals(3, firstColIndices.size());
        Assert.assertEquals(2, secondColIndices.size());
    }

    /** Checks the "keep", "skip" and "error" strategies for the unseen string {@code "e"}. */
    @Test
    public void testHandleInvalid() throws Exception {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setStringOrderType("alphabetAsc");

        stringIndexer.setHandleInvalid("keep");
        verifyPredictionResult(
                expectedAlphabeticAscPredictData, collect(fitAndTransform(stringIndexer)));

        stringIndexer.setHandleInvalid("skip");
        List<Row> expectedSkipResult =
                Arrays.asList(Row.of("a", 2.0, 0.0, 3.0), Row.of("b", 1.0, 1.0, 2.0));
        verifyPredictionResult(expectedSkipResult, collect(fitAndTransform(stringIndexer)));

        stringIndexer.setHandleInvalid("error");
        // The failure is captured and asserted on outside the try block so that a missing
        // exception is reported as such instead of being caught by the handler itself.
        Exception thrown = null;
        try {
            collect(fitAndTransform(stringIndexer));
        } catch (Exception e) {
            thrown = e;
        }
        Assert.assertNotNull(
                "Expected the job to fail because the input contains an unseen string.", thrown);
        Assert.assertEquals(
                "The input contains unseen string: e. See "
                        + HasHandleInvalid.HANDLE_INVALID
                        + " parameter for more options.",
                ExceptionUtils.getRootCause(thrown).getMessage());
    }

    /** Checks the end-to-end result of fitting and transforming with "alphabetAsc" / "keep". */
    @Test
    public void testFitAndPredict() throws Exception {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setStringOrderType("alphabetAsc");
        stringIndexer.setHandleInvalid("keep");

        verifyPredictionResult(
                expectedAlphabeticAscPredictData, collect(fitAndTransform(stringIndexer)));
    }

    /** Checks that both the estimator and the fitted model survive a save / reload round trip. */
    @Test
    public void testSaveLoadAndPredict() throws Exception {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setStringOrderType("alphabetAsc");
        stringIndexer.setHandleInvalid("keep");

        StringIndexer loadedIndexer =
                TestUtils.saveAndReload(
                        tEnv, stringIndexer, tempFolder.newFolder().getAbsolutePath());
        StringIndexerModel model = loadedIndexer.fit(trainTable);
        StringIndexerModel loadedModel =
                TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());

        Assert.assertEquals(
                Collections.singletonList("stringArrays"),
                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
        verifyPredictionResult(
                expectedAlphabeticAscPredictData,
                collect(loadedModel.transform(predictTable)[0]));
    }

    /** Checks the schema and content of the model data produced by fitting. */
    @Test
    public void testGetModelData() throws Exception {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setStringOrderType("alphabetAsc");

        Table modelDataTable = stringIndexer.fit(trainTable).getModelData()[0];
        Assert.assertEquals(
                Collections.singletonList("stringArrays"),
                modelDataTable.getResolvedSchema().getColumnNames());

        // IteratorUtils.toList returns a raw List; the elements come from the model data stream.
        @SuppressWarnings("unchecked")
        List<StringIndexerModelData> collectedModelData =
                IteratorUtils.toList(
                        StringIndexerModelData.getModelDataStream(modelDataTable)
                                .executeAndCollect());
        Assert.assertEquals(1, collectedModelData.size());

        StringIndexerModelData modelData = collectedModelData.get(0);
        Assert.assertEquals(2, modelData.stringArrays.length);
        Assert.assertArrayEquals(expectedAlphabeticAscModelData[0], modelData.stringArrays[0]);
        Assert.assertArrayEquals(expectedAlphabeticAscModelData[1], modelData.stringArrays[1]);
    }

    /** Checks that a fresh model given the params and model data of a fitted one predicts alike. */
    @Test
    public void testSetModelData() throws Exception {
        StringIndexer stringIndexer = createIndexer();
        stringIndexer.setStringOrderType("alphabetAsc");
        stringIndexer.setHandleInvalid("keep");
        StringIndexerModel fittedModel = stringIndexer.fit(trainTable);

        StringIndexerModel newModel = new StringIndexerModel();
        ReadWriteUtils.updateExistingParams(newModel, fittedModel.getParamMap());
        newModel.setModelData(fittedModel.getModelData());

        verifyPredictionResult(
                expectedAlphabeticAscPredictData, collect(newModel.transform(predictTable)[0]));
    }

    /**
     * Compares expected and actual prediction rows using a field-by-field ordering.
     *
     * <p>The comparator orders two rows by the string representation of their fields, looking only
     * at the positions both rows have; rows that agree on all of those positions compare as equal.
     * The comparison itself is delegated to the inherited {@code compareResultCollections}
     * (presumably it sorts both lists with the comparator before comparing them; confirm against
     * the base class).
     *
     * @param expected the rows the transformation is expected to produce
     * @param actual the rows that were actually collected
     */
    public static void verifyPredictionResult(List<Row> expected, List<Row> actual) {
        compareResultCollections(
                expected,
                actual,
                (left, right) -> {
                    int sharedArity = Math.min(left.getArity(), right.getArity());
                    for (int i = 0; i < sharedArity; i++) {
                        String leftField = String.valueOf(left.getField(i));
                        String rightField = String.valueOf(right.getField(i));
                        int order = leftField.compareTo(rightField);
                        if (order != 0) {
                            return order;
                        }
                    }
                    return 0;
                });
    }

    /** Creates an indexer configured with the input and output columns shared by the tests. */
    private static StringIndexer createIndexer() {
        StringIndexer indexer = new StringIndexer();
        indexer.setInputCols(new String[] {"inputCol1", "inputCol2"});
        indexer.setOutputCols(new String[] {"outputCol1", "outputCol2"});
        return indexer;
    }

    /** Fits the indexer on the training table and returns the transformed prediction table. */
    private Table fitAndTransform(StringIndexer indexer) {
        return indexer.fit(trainTable).transform(predictTable)[0];
    }

    /** Executes the job behind the given table and collects all of its rows into a list. */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
    private List<Row> collect(Table table) throws Exception {
        return IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
    }
}
