package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.imputer.Imputer;
import org.apache.flink.ml.feature.imputer.ImputerModel;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/feature/ImputerTest.class */
/**
 * Tests {@link Imputer} and {@link ImputerModel}.
 *
 * <p>Every test starts from the same six-row training table (columns {@code f1}..{@code f4})
 * containing {@code null} and {@code NaN} entries, fits an imputer on it and checks either the
 * produced model data or the imputed output columns {@code o1}..{@code o4}.
 */
public class ImputerTest extends AbstractTestBase {

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    /** Tolerance used when comparing floating point values. */
    private static final double EPS = 1.0e-5;

    /** Input rows used both for training and for prediction. */
    private static final List<Row> TRAIN_DATA =
            mutableRows(
                    Row.of(Double.NaN, 9.0, 1, 9.0f),
                    Row.of(1.0, 9.0, null, 9.0f),
                    Row.of(1.5, 7.0, 1, 7.0f),
                    Row.of(1.5, Double.NaN, 2, Float.NaN),
                    Row.of(4.0, 5.0, 4, 5.0f),
                    Row.of(null, 4.0, null, 4.0f));

    /** Expected output columns when the "mean" strategy is used. */
    private static final List<Row> EXPECTED_MEAN_STRATEGY_OUTPUT =
            mutableRows(
                    Row.of(2.0, 9.0, 1.0, 9.0),
                    Row.of(1.0, 9.0, 2.0, 9.0),
                    Row.of(1.5, 7.0, 1.0, 7.0),
                    Row.of(1.5, 6.8, 2.0, 6.8),
                    Row.of(4.0, 5.0, 4.0, 5.0),
                    Row.of(2.0, 4.0, 2.0, 4.0));

    /** Expected output columns when the "median" strategy is used. */
    private static final List<Row> EXPECTED_MEDIAN_STRATEGY_OUTPUT =
            mutableRows(
                    Row.of(1.5, 9.0, 1.0, 9.0),
                    Row.of(1.0, 9.0, 1.0, 9.0),
                    Row.of(1.5, 7.0, 1.0, 7.0),
                    Row.of(1.5, 7.0, 2.0, 7.0),
                    Row.of(4.0, 5.0, 4.0, 5.0),
                    Row.of(1.5, 4.0, 1.0, 4.0));

    /** Expected output columns when the "most_frequent" strategy is used. */
    private static final List<Row> EXPECTED_MOST_FREQUENT_STRATEGY_OUTPUT =
            mutableRows(
                    Row.of(1.5, 9.0, 1.0, 9.0),
                    Row.of(1.0, 9.0, 1.0, 9.0),
                    Row.of(1.5, 7.0, 1.0, 7.0),
                    Row.of(1.5, 9.0, 2.0, 9.0),
                    Row.of(4.0, 5.0, 4.0, 5.0),
                    Row.of(1.5, 4.0, 1.0, 4.0));

    /**
     * Strategy name mapped to the rows expected for that strategy. The map is read-only so that no
     * test can change what the other tests iterate over.
     */
    private static final Map<String, List<Row>> strategyAndExpectedOutputs =
            createStrategyAndExpectedOutputs();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainDataTable;
    private Table predictDataTable;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("f1", "f2", "f3", "f4");
        // Prediction deliberately reuses the training rows.
        predictDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("f1", "f2", "f3", "f4");
    }

    /** Wraps the given rows in a mutable list (the comparison helper may sort lists in place). */
    private static List<Row> mutableRows(Row... rows) {
        return new ArrayList<>(Arrays.asList(rows));
    }

    /** Builds the read-only strategy-to-expected-output lookup. */
    private static Map<String, List<Row>> createStrategyAndExpectedOutputs() {
        Map<String, List<Row>> outputs = new HashMap<>();
        outputs.put("mean", EXPECTED_MEAN_STRATEGY_OUTPUT);
        outputs.put("median", EXPECTED_MEDIAN_STRATEGY_OUTPUT);
        outputs.put("most_frequent", EXPECTED_MOST_FREQUENT_STRATEGY_OUTPUT);
        return Collections.unmodifiableMap(outputs);
    }

    /** Creates an imputer reading {@code f1..f4} and writing {@code o1..o4}. */
    private static Imputer newDefaultImputer() {
        return new Imputer()
                .setInputCols("f1", "f2", "f3", "f4")
                .setOutputCols("o1", "o2", "o3", "o4");
    }

    /** Copies the named fields of {@code source} into a new positional row, in the given order. */
    private static Row projectRow(Row source, List<String> fieldNames) {
        Row projected = new Row(fieldNames.size());
        for (int i = 0; i < fieldNames.size(); i++) {
            projected.setField(i, source.getField(fieldNames.get(i)));
        }
        return projected;
    }

    /**
     * Orders rows by the string form of their fields, field by field, over the shared arity prefix.
     * Only used to bring expected and actual rows into the same order before comparing them.
     */
    private static int compareRowsAsStrings(Row left, Row right) {
        int sharedArity = Math.min(left.getArity(), right.getArity());
        for (int i = 0; i < sharedArity; i++) {
            int order = String.valueOf(left.getField(i)).compareTo(String.valueOf(right.getField(i)));
            if (order != 0) {
                return order;
            }
        }
        return 0;
    }

    /**
     * Runs {@code action} and returns what it threw. Fails the test if nothing was thrown; the
     * failure is raised outside the {@code try} so that it cannot be caught by this helper itself.
     */
    private static Throwable captureFailure(Runnable action) {
        try {
            action.run();
        } catch (Throwable thrown) {
            return thrown;
        }
        Assert.fail("Expected an exception, but the action completed normally.");
        return null; // Unreachable: Assert.fail always throws.
    }

    /** Collects the model data table and returns the map stored in field 0 of its first row. */
    private Map<?, ?> collectSurrogates(Table modelData) throws Exception {
        Row firstRow = (Row) IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect()).get(0);
        Map<?, ?> surrogates = (Map<?, ?>) firstRow.getField(0);
        Assert.assertNotNull(surrogates);
        return surrogates;
    }

    /**
     * Collects {@code output}, keeps only {@code outputCols} of each row and compares the result
     * with {@code expected}, ignoring row order.
     *
     * @param output table produced by the model under test
     * @param outputCols names of the columns to compare, in the order used by {@code expected}
     * @param expected expected rows; not modified by this method
     */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
    private static void verifyPredictionResult(Table output, List<String> outputCols, List<Row> expected)
            throws Exception {
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        List<Row> collected = IteratorUtils.toList(tableEnv.toDataStream(output).executeAndCollect());
        List<Row> actual =
                collected.stream().map(row -> projectRow(row, outputCols)).collect(Collectors.toList());
        // Hand over a copy so the shared expected fixtures are never reordered.
        List<Row> expectedCopy = new ArrayList<>(expected);
        TestBaseUtils.compareResultCollections(expectedCopy, actual, ImputerTest::compareRowsAsStrings);
    }

    @Test
    public void testParam() {
        Imputer imputer = newDefaultImputer();
        Assert.assertArrayEquals(new String[] {"f1", "f2", "f3", "f4"}, imputer.getInputCols());
        Assert.assertArrayEquals(new String[] {"o1", "o2", "o3", "o4"}, imputer.getOutputCols());
        Assert.assertEquals("mean", imputer.getStrategy());
        Assert.assertEquals(Double.NaN, imputer.getMissingValue(), EPS);
        Assert.assertEquals(0.001, imputer.getRelativeError(), EPS);

        imputer.setMissingValue(0.0)
                .setStrategy("median")
                .setRelativeError(0.1)
                .setInputCols("f1", "f2")
                .setOutputCols("o1", "o2");
        Assert.assertEquals("median", imputer.getStrategy());
        Assert.assertEquals(0.0, imputer.getMissingValue(), EPS);
        Assert.assertEquals(0.1, imputer.getRelativeError(), EPS);
        Assert.assertArrayEquals(new String[] {"f1", "f2"}, imputer.getInputCols());
        Assert.assertArrayEquals(new String[] {"o1", "o2"}, imputer.getOutputCols());
    }

    @Test
    public void testOutputSchema() {
        Table output = newDefaultImputer().fit(trainDataTable).transform(predictDataTable)[0];
        Assert.assertEquals(
                Arrays.asList("f1", "f2", "f3", "f4", "o1", "o2", "o3", "o4"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        for (Map.Entry<String, List<Row>> entry : strategyAndExpectedOutputs.entrySet()) {
            Imputer imputer = newDefaultImputer().setStrategy(entry.getKey());
            Table output = imputer.fit(trainDataTable).transform(predictDataTable)[0];
            verifyPredictionResult(output, Arrays.asList("o1", "o2", "o3", "o4"), entry.getValue());
        }
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        Imputer imputer = newDefaultImputer();
        Imputer loadedImputer =
                TestUtils.saveAndReload(
                        tEnv, imputer, tempFolder.newFolder().getAbsolutePath(), Imputer::load);
        ImputerModel model = loadedImputer.fit(trainDataTable);
        ImputerModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv, model, tempFolder.newFolder().getAbsolutePath(), ImputerModel::load);
        Assert.assertEquals(
                Collections.singletonList("surrogates"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        verifyPredictionResult(
                loadedModel.transform(predictDataTable)[0],
                Arrays.asList(imputer.getOutputCols()),
                EXPECTED_MEAN_STRATEGY_OUTPUT);
    }

    @Test
    public void testFitOnEmptyData() {
        // No training row has arity 0, so the filter leaves an empty table.
        Table emptyTable =
                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(row -> row.getArity() == 0))
                        .as("f1", "f2", "f3", "f4");

        for (String strategy : strategyAndExpectedOutputs.keySet()) {
            // The median strategy is not covered by this test; skip it here instead of removing
            // it from the shared map, which would also hide it from testFitAndPredict.
            if ("median".equals(strategy)) {
                continue;
            }
            Throwable failure =
                    captureFailure(
                            () ->
                                    newDefaultImputer()
                                            .setStrategy(strategy)
                                            .fit(emptyTable)
                                            .getModelData()[0]
                                            .execute()
                                            .print());
            Assert.assertEquals(
                    "The training set is empty or does not contains valid data.",
                    ExceptionUtils.getRootCause(failure).getMessage());
        }
    }

    @Test
    public void testNoValidDataOnMedianStrategy() {
        // With missingValue = 1.0, column f1 holds only NaN, null and the missing value.
        List<Row> rows = mutableRows(Row.of(Double.NaN, 3.0f), Row.of(null, 2.0f), Row.of(1.0, 1.0f));
        Table noValidDataTable = tEnv.fromDataStream(env.fromCollection(rows)).as("f1", "f2");

        Throwable failure =
                captureFailure(
                        () ->
                                new Imputer()
                                        .setInputCols("f1", "f2")
                                        .setOutputCols("o1", "o2")
                                        .setStrategy("median")
                                        .setMissingValue(1.0)
                                        .fit(noValidDataTable)
                                        .getModelData()[0]
                                        .execute()
                                        .print());
        Assert.assertEquals(
                "Surrogate cannot be computed. All the values in column [f1] are null, NaN or missingValue.",
                ExceptionUtils.getRootCause(failure).getMessage());
    }

    @Test
    public void testMultipleModeOnMostFrequentStrategy() throws Exception {
        // Both columns contain 1.0 and 2.0 exactly twice each.
        List<Row> rows = mutableRows(Row.of(1.0, 2.0), Row.of(1.0, 2.0), Row.of(2.0, 1.0), Row.of(2.0, 1.0));
        Table tiedModesTable = tEnv.fromDataStream(env.fromCollection(rows)).as("f1", "f2");

        Imputer imputer =
                new Imputer()
                        .setInputCols("f1", "f2")
                        .setOutputCols("o1", "o2")
                        .setStrategy("most_frequent");
        Map<?, ?> surrogates = collectSurrogates(imputer.fit(tiedModesTable).getModelData()[0]);

        Assert.assertEquals(1.0, (Double) surrogates.get("f1"), EPS);
        Assert.assertEquals(1.0, (Double) surrogates.get("f2"), EPS);
    }

    @Test
    public void testInconsistentInputsAndOutputs() {
        Throwable failure =
                captureFailure(
                        () ->
                                new Imputer()
                                        .setInputCols("f1", "f2", "f3", "f4")
                                        .setOutputCols("o1", "o2", "o3")
                                        .fit(trainDataTable));
        Assert.assertEquals(
                "Num of input columns and output columns are inconsistent.", failure.getMessage());
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelData = newDefaultImputer().fit(trainDataTable).getModelData()[0];
        Assert.assertEquals(
                Collections.singletonList("surrogates"), modelData.getResolvedSchema().getColumnNames());

        Map<?, ?> surrogates = collectSurrogates(modelData);
        Assert.assertEquals(2.0, (Double) surrogates.get("f1"), EPS);
        Assert.assertEquals(6.8, (Double) surrogates.get("f2"), EPS);
        Assert.assertEquals(2.0, (Double) surrogates.get("f3"), EPS);
        Assert.assertEquals(6.8, (Double) surrogates.get("f4"), EPS);
    }

    @Test
    public void testSetModelData() throws Exception {
        Imputer imputer = newDefaultImputer();
        Table modelData = imputer.fit(trainDataTable).getModelData()[0];
        ImputerModel model =
                new ImputerModel()
                        .setModelData(modelData)
                        .setInputCols("f1", "f2", "f3", "f4")
                        .setOutputCols("o1", "o2", "o3", "o4");
        verifyPredictionResult(
                model.transform(predictDataTable)[0],
                Arrays.asList(imputer.getOutputCols()),
                EXPECTED_MEAN_STRATEGY_OUTPUT);
    }
}
