package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.robustscaler.RobustScaler;
import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests {@link RobustScaler} and {@link RobustScalerModel}.
 *
 * <p>The default training data has input vectors {@code (i, -i)} for {@code i = 0..8}; the model
 * fitted on it is expected to have medians {@code (4, -4)} and ranges {@code (4, 4)} (see {@link
 * #testGetModelData()}), which is what the expected outputs below are derived from.
 */
public class RobustScalerTest extends AbstractTestBase {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainDataTable;
    private Table predictDataTable;

    /** Tolerance used when comparing floating-point parameter values. */
    private static final double EPS = 1.0E-5;

    /** Training rows {@code (id, input)} where {@code input = (i, -i)} for {@code i = 0..8}. */
    private static final List<Row> TRAIN_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(0, Vectors.dense(0.0, 0.0)),
                            Row.of(1, Vectors.dense(1.0, -1.0)),
                            Row.of(2, Vectors.dense(2.0, -2.0)),
                            Row.of(3, Vectors.dense(3.0, -3.0)),
                            Row.of(4, Vectors.dense(4.0, -4.0)),
                            Row.of(5, Vectors.dense(5.0, -5.0)),
                            Row.of(6, Vectors.dense(6.0, -6.0)),
                            Row.of(7, Vectors.dense(7.0, -7.0)),
                            Row.of(8, Vectors.dense(8.0, -8.0))));

    /** Single-column rows {@code (input)} to transform with the fitted model. */
    private static final List<Row> PREDICT_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.dense(3.0, -3.0)),
                            Row.of(Vectors.dense(6.0, -6.0)),
                            Row.of(Vectors.dense(99.0, -99.0))));

    /**
     * Expected output for {@link #PREDICT_DATA} with default parameters: each input divided by the
     * range 4 (no centering).
     */
    private static final List<DenseVector> EXPECTED_OUTPUT =
            new ArrayList<>(
                    Arrays.asList(
                            Vectors.dense(0.75, -0.75),
                            Vectors.dense(1.5, -1.5),
                            Vectors.dense(24.75, -24.75)));

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("id", "input");
        predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("input");
    }

    /**
     * Collects the {@link DenseVector}s in column {@code outputCol} of {@code output} and compares
     * them with {@code expected} via {@link TestBaseUtils#compareResultCollections}, using {@link
     * TestUtils#compare} as the comparator.
     *
     * @param output the transformed table; must have been created by a {@link
     *     StreamTableEnvironment}
     * @param outputCol name of the column holding the scaled vectors
     * @param expected the expected vectors
     */
    private static void verifyPredictionResult(
            Table output, String outputCol, List<DenseVector> expected) throws Exception {
        // getTableEnvironment() is declared as a plain TableEnvironment; toDataStream is only
        // available on the streaming bridge environment that before() creates.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        // IteratorUtils.toList returns a raw List; every element is a DenseVector by the map above.
        @SuppressWarnings("unchecked")
        List<DenseVector> actual =
                IteratorUtils.toList(
                        tableEnv.toDataStream(output)
                                .map(row -> (DenseVector) row.getField(outputCol))
                                .executeAndCollect());
        TestBaseUtils.compareResultCollections(expected, actual, TestUtils::compare);
    }

    @Test
    public void testParam() {
        RobustScaler robustScaler = new RobustScaler();
        Assert.assertEquals("input", robustScaler.getInputCol());
        Assert.assertEquals("output", robustScaler.getOutputCol());
        Assert.assertEquals(0.25, robustScaler.getLower(), EPS);
        Assert.assertEquals(0.75, robustScaler.getUpper(), EPS);
        Assert.assertEquals(0.001, robustScaler.getRelativeError(), EPS);
        Assert.assertFalse(robustScaler.getWithCentering());
        Assert.assertTrue(robustScaler.getWithScaling());

        robustScaler
                .setInputCol("test_input")
                .setOutputCol("test_output")
                .setLower(0.1)
                .setUpper(0.9)
                .setRelativeError(0.01)
                .setWithCentering(true)
                .setWithScaling(false);

        Assert.assertEquals("test_input", robustScaler.getInputCol());
        Assert.assertEquals("test_output", robustScaler.getOutputCol());
        Assert.assertEquals(0.1, robustScaler.getLower(), EPS);
        Assert.assertEquals(0.9, robustScaler.getUpper(), EPS);
        Assert.assertEquals(0.01, robustScaler.getRelativeError(), EPS);
        Assert.assertTrue(robustScaler.getWithCentering());
        Assert.assertFalse(robustScaler.getWithScaling());
    }

    @Test
    public void testOutputSchema() {
        Table tempTable = trainDataTable.as("id", "test_input");
        RobustScaler robustScaler =
                new RobustScaler().setInputCol("test_input").setOutputCol("test_output");
        Table output = robustScaler.fit(tempTable).transform(tempTable)[0];
        Assert.assertEquals(
                Arrays.asList("id", "test_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        RobustScaler robustScaler = new RobustScaler();
        Table output = robustScaler.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, robustScaler.getOutputCol(), EXPECTED_OUTPUT);
    }

    @Test
    public void testInputTypeConversion() throws Exception {
        trainDataTable =
                TestUtils.convertDataTypesToSparseInt(
                        tEnv, trainDataTable.select(Expressions.$("input")));
        predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class},
                TestUtils.getColumnDataTypes(predictDataTable));

        RobustScaler robustScaler = new RobustScaler();
        Table output = robustScaler.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, robustScaler.getOutputCol(), EXPECTED_OUTPUT);
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        RobustScaler robustScaler = new RobustScaler();
        RobustScaler loadedScaler =
                TestUtils.saveAndReload(
                        tEnv,
                        robustScaler,
                        tempFolder.newFolder().getAbsolutePath(),
                        RobustScaler::load);
        RobustScalerModel model = loadedScaler.fit(trainDataTable);
        RobustScalerModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv,
                        model,
                        tempFolder.newFolder().getAbsolutePath(),
                        RobustScalerModel::load);

        Assert.assertEquals(
                Arrays.asList("medians", "ranges"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        verifyPredictionResult(
                loadedModel.transform(predictDataTable)[0],
                robustScaler.getOutputCol(),
                EXPECTED_OUTPUT);
    }

    @Test
    public void testFitOnEmptyData() {
        // Capture the failure instead of calling Assert.fail() inside the try block: a
        // catch (Throwable) would otherwise swallow the AssertionError raised by fail().
        Throwable failure = null;
        try {
            Table emptyTable =
                    tEnv.fromDataStream(
                                    env.fromCollection(TRAIN_DATA)
                                            .filter(row -> row.getArity() == 0))
                            .as("id", "input");
            Table modelData = new RobustScaler().fit(emptyTable).getModelData()[0];
            modelData.execute().print();
        } catch (Throwable e) {
            failure = e;
        }
        Assert.assertNotNull("Expected fitting on an empty training set to fail.", failure);
        Assert.assertEquals(
                "The training set is empty.", ExceptionUtils.getRootCause(failure).getMessage());
    }

    @Test
    public void testWithCentering() throws Exception {
        RobustScaler robustScaler = new RobustScaler().setWithCentering(true);
        // (x - median) / range with median 4 (resp. -4) and range 4.
        List<DenseVector> expected =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.dense(-0.25, 0.25),
                                Vectors.dense(0.5, -0.5),
                                Vectors.dense(23.75, -23.75)));
        Table output = robustScaler.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, robustScaler.getOutputCol(), expected);
    }

    @Test
    public void testWithoutScaling() throws Exception {
        RobustScaler robustScaler =
                new RobustScaler().setWithCentering(true).setWithScaling(false);
        // Centering only: x - median with median 4 (resp. -4).
        List<DenseVector> expected =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.dense(-1.0, 1.0),
                                Vectors.dense(2.0, -2.0),
                                Vectors.dense(95.0, -95.0)));
        Table output = robustScaler.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, robustScaler.getOutputCol(), expected);
    }

    @Test
    public void testIncompatibleNumOfFeatures() {
        // Same failure-capture pattern as testFitOnEmptyData so Assert's own error is not
        // swallowed by the catch block.
        Throwable failure = null;
        try {
            List<Row> threeFeatureData =
                    new ArrayList<>(
                            Arrays.asList(
                                    Row.of(Vectors.dense(1.0, 2.0, 3.0)),
                                    Row.of(Vectors.dense(-1.0, -2.0, -3.0))));
            Table predictTable =
                    tEnv.fromDataStream(env.fromCollection(threeFeatureData)).as("input");
            Table output = new RobustScaler().fit(trainDataTable).transform(predictTable)[0];
            output.execute().print();
        } catch (Throwable e) {
            failure = e;
        }
        Assert.assertNotNull(
                "Expected transforming data with a different number of features to fail.",
                failure);
        Assert.assertTrue(
                ExceptionUtils.getRootCause(failure)
                        .getMessage()
                        .contains("Number of features must be"));
    }

    @Test
    public void testZeroRange() throws Exception {
        // Both features have lower and upper quantiles equal to 1, so the range is zero.
        List<Row> trainData =
                new ArrayList<>(
                        Arrays.asList(
                                Row.of(0, Vectors.dense(0.0, 0.0)),
                                Row.of(1, Vectors.dense(1.0, 1.0)),
                                Row.of(2, Vectors.dense(1.0, 1.0)),
                                Row.of(3, Vectors.dense(1.0, 1.0)),
                                Row.of(4, Vectors.dense(4.0, 4.0))));
        List<DenseVector> expected =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.dense(0.0, -0.0),
                                Vectors.dense(0.0, -0.0),
                                Vectors.dense(0.0, -0.0)));
        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("id", "input");

        RobustScaler robustScaler = new RobustScaler();
        Table output = robustScaler.fit(trainTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, robustScaler.getOutputCol(), expected);
    }

    @Test
    public void testNaNData() throws Exception {
        List<Row> trainData =
                new ArrayList<>(
                        Arrays.asList(
                                Row.of(0, Vectors.dense(0.0, Double.NaN)),
                                Row.of(1, Vectors.dense(Double.NaN, 0.0)),
                                Row.of(2, Vectors.dense(1.0, -1.0)),
                                Row.of(3, Vectors.dense(2.0, -2.0)),
                                Row.of(4, Vectors.dense(3.0, -3.0)),
                                Row.of(5, Vectors.dense(4.0, -4.0))));
        // NaN inputs are expected to stay NaN; the other values correspond to a range of 2,
        // i.e. statistics computed over the non-NaN values only.
        List<DenseVector> expected =
                new ArrayList<>(
                        Arrays.asList(
                                Vectors.dense(0.0, Double.NaN),
                                Vectors.dense(Double.NaN, 0.0),
                                Vectors.dense(0.5, -0.5),
                                Vectors.dense(1.0, -1.0),
                                Vectors.dense(1.5, -1.5),
                                Vectors.dense(2.0, -2.0)));
        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("id", "input");

        RobustScaler robustScaler = new RobustScaler();
        Table output = robustScaler.fit(trainTable).transform(trainTable)[0];
        verifyPredictionResult(output, robustScaler.getOutputCol(), expected);
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelData = new RobustScaler().fit(trainDataTable).getModelData()[0];
        Assert.assertEquals(
                Arrays.asList("medians", "ranges"), modelData.getResolvedSchema().getColumnNames());

        // IteratorUtils.toList returns a raw List; toDataStream(Table) produces Rows.
        @SuppressWarnings("unchecked")
        List<Row> modelRows =
                IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect());
        Row modelRow = modelRows.get(0);
        DenseVector medians = (DenseVector) modelRow.getField(0);
        DenseVector ranges = (DenseVector) modelRow.getField(1);

        Assert.assertEquals(Vectors.dense(4.0, -4.0), medians);
        Assert.assertEquals(Vectors.dense(4.0, 4.0), ranges);
    }

    @Test
    public void testSetModelData() throws Exception {
        RobustScaler robustScaler = new RobustScaler();
        Table modelData = robustScaler.fit(trainDataTable).getModelData()[0];
        RobustScalerModel model = new RobustScalerModel().setModelData(modelData);
        verifyPredictionResult(
                model.transform(predictDataTable)[0], robustScaler.getOutputCol(), EXPECTED_OUTPUT);
    }
}
