package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests for {@link MaxAbsScaler} and {@link MaxAbsScalerModel}.
 *
 * <p>The expected outputs below divide each feature by the largest absolute value that feature
 * takes in the training data. For the dense data that maximum is (200, 400, 0); the third feature,
 * whose maximum is zero, is passed through unchanged. For the sparse data the maximum is
 * (2, 8, 6, 5).
 */
public class MaxAbsScalerTest {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table trainDataTable;
    private Table predictDataTable;
    private Table trainSparseDataTable;
    private Table predictSparseDataTable;

    private static final List<Row> TRAIN_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.dense(new double[] {0.0, 3.0, 0.0})),
                            Row.of(Vectors.dense(new double[] {2.1, 0.0, 0.0})),
                            Row.of(Vectors.dense(new double[] {4.1, 5.1, 0.0})),
                            Row.of(Vectors.dense(new double[] {6.1, 8.1, 0.0})),
                            Row.of(Vectors.dense(new double[] {200.0, -400.0, 0.0}))));

    private static final List<Row> PREDICT_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.dense(new double[] {150.0, 90.0, 1.0})),
                            Row.of(Vectors.dense(new double[] {50.0, 40.0, 1.0})),
                            Row.of(Vectors.dense(new double[] {100.0, 50.0, 0.5}))));

    private static final List<Row> TRAIN_SPARSE_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {4.0, 3.0})),
                            Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {2.0, -6.0})),
                            Row.of(Vectors.sparse(4, new int[] {1, 2}, new double[] {1.0, 3.0})),
                            Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {2.0, 8.0})),
                            Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 5.0}))));

    private static final List<Row> PREDICT_SPARSE_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {2.0, 4.0})),
                            Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {1.0, 3.0})),
                            Row.of(Vectors.sparse(4, new int[0], new double[0])),
                            Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 2.0}))));

    // NOTE: the expected rows are not listed in PREDICT_DATA order; the comparison in
    // verifyPredictionResult is relied on to be order-insensitive. Kept as mutable ArrayLists in
    // case the comparison reorders them in place.
    private static final List<Vector> EXPECTED_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Vectors.dense(new double[] {0.25, 0.1, 1.0}),
                            Vectors.dense(new double[] {0.5, 0.125, 0.5}),
                            Vectors.dense(new double[] {0.75, 0.225, 1.0})));

    private static final List<Vector> EXPECTED_SPARSE_DATA =
            new ArrayList<>(
                    Arrays.asList(
                            Vectors.sparse(4, new int[] {0, 1}, new double[] {1.0, 0.5}),
                            Vectors.sparse(4, new int[] {0, 2}, new double[] {0.5, 0.5}),
                            Vectors.sparse(4, new int[0], new double[0]),
                            Vectors.sparse(4, new int[] {1, 3}, new double[] {0.125, 0.4})));

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        trainDataTable = toInputTable(TRAIN_DATA);
        predictDataTable = toInputTable(PREDICT_DATA);
        trainSparseDataTable = toInputTable(TRAIN_SPARSE_DATA);
        predictSparseDataTable = toInputTable(PREDICT_SPARSE_DATA);
    }

    /** Converts {@code rows} into a single-column table whose column is named "input". */
    private Table toInputTable(List<Row> rows) {
        return tEnv.fromDataStream(env.fromCollection(rows)).as("input");
    }

    /**
     * Collects the {@code outputCol} column of {@code output} and compares it with {@code expected}
     * using {@link TestUtils#compare}.
     *
     * @param output table produced by a transform; it must have been created by a stream table
     *     environment so it can be converted back to a data stream
     * @param outputCol name of the column holding the scaled vectors
     * @param expected the vectors the column should contain
     */
    private static void verifyPredictionResult(
            Table output, String outputCol, List<Vector> expected) throws Exception {
        // getTableEnvironment() is declared as TableEnvironment; toDataStream needs the stream
        // variant that created the table.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List of Vectors here.
        List<Vector> actual =
                IteratorUtils.toList(
                        tableEnv.toDataStream(output)
                                .map(
                                        row -> (Vector) row.getFieldAs(outputCol),
                                        VectorTypeInfo.INSTANCE)
                                .executeAndCollect());
        TestBaseUtils.compareResultCollections(expected, actual, TestUtils::compare);
    }

    @Test
    public void testParam() {
        MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
        Assert.assertEquals("input", maxAbsScaler.getInputCol());
        Assert.assertEquals("output", maxAbsScaler.getOutputCol());

        maxAbsScaler.setInputCol("test_input").setOutputCol("test_output");
        Assert.assertEquals("test_input", maxAbsScaler.getInputCol());
        Assert.assertEquals("test_output", maxAbsScaler.getOutputCol());
    }

    @Test
    public void testOutputSchema() {
        MaxAbsScaler maxAbsScaler =
                new MaxAbsScaler().setInputCol("test_input").setOutputCol("test_output");
        MaxAbsScalerModel model = maxAbsScaler.fit(trainDataTable.as("test_input"));
        Table output = model.transform(predictDataTable.as("test_input"))[0];

        Assert.assertEquals(
                Arrays.asList("test_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
        Table output = maxAbsScaler.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, maxAbsScaler.getOutputCol(), EXPECTED_DATA);
    }

    @Test
    public void testFitDataWithNullValue() {
        List<Row> trainData =
                new ArrayList<>(
                        Arrays.asList(
                                Row.of(Vectors.dense(new double[] {0.0, 3.0})),
                                Row.of(Vectors.dense(new double[] {2.1, 0.0})),
                                // Cast so this is one null field, not a null varargs array.
                                Row.of((Object) null),
                                Row.of(Vectors.dense(new double[] {6.1, 8.1})),
                                Row.of(Vectors.dense(new double[] {200.0, 400.0}))));
        try {
            Table modelData = new MaxAbsScaler().fit(toInputTable(trainData)).getModelData()[0];
            IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect());
            Assert.fail("Fitting on data that contains a null vector should fail.");
        } catch (Exception e) {
            Assert.assertEquals(
                    "The vector must not be null.", ExceptionUtils.getRootCause(e).getMessage());
        }
    }

    @Test
    public void testFitAndPredictSparse() throws Exception {
        MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
        Table output = maxAbsScaler.fit(trainSparseDataTable).transform(predictSparseDataTable)[0];
        verifyPredictionResult(output, maxAbsScaler.getOutputCol(), EXPECTED_SPARSE_DATA);
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        MaxAbsScaler maxAbsScaler = new MaxAbsScaler();

        // Round-trip both the estimator and the fitted model through disk before predicting.
        MaxAbsScaler loadedScaler =
                TestUtils.saveAndReload(
                        tEnv,
                        maxAbsScaler,
                        tempFolder.newFolder().getAbsolutePath(),
                        MaxAbsScaler::load);
        MaxAbsScalerModel model = loadedScaler.fit(trainDataTable);
        MaxAbsScalerModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv,
                        model,
                        tempFolder.newFolder().getAbsolutePath(),
                        MaxAbsScalerModel::load);

        Table output = loadedModel.transform(predictDataTable)[0];
        verifyPredictionResult(output, maxAbsScaler.getOutputCol(), EXPECTED_DATA);
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelData = new MaxAbsScaler().fit(trainDataTable).getModelData()[0];
        Assert.assertEquals(
                Collections.singletonList("maxVector"),
                modelData.getResolvedSchema().getColumnNames());

        @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List of Rows here.
        List<Row> modelRows =
                IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect());
        // The -400.0 in TRAIN_DATA contributes its absolute value to the second feature.
        Assert.assertEquals(
                new DenseVector(new double[] {200.0, 400.0, 0.0}), modelRows.get(0).getField(0));
    }

    @Test
    public void testSetModelData() throws Exception {
        MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
        MaxAbsScalerModel fittedModel = maxAbsScaler.fit(trainDataTable);

        // Build a fresh model from the fitted model's data and parameters only.
        MaxAbsScalerModel newModel =
                new MaxAbsScalerModel().setModelData(fittedModel.getModelData()[0]);
        ParamUtils.updateExistingParams(newModel, fittedModel.getParamMap());

        Table output = newModel.transform(predictDataTable)[0];
        verifyPredictionResult(output, maxAbsScaler.getOutputCol(), EXPECTED_DATA);
    }
}
