package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.standardscaler.StandardScaler;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/feature/StandardScalerTest.class */
/**
 * Tests {@link StandardScaler} and {@link StandardScalerModel}.
 *
 * <p>Every test runs against a three-row table with a single vector column named {@code input}.
 * The expected values below are pre-computed for that fixture and compared with an absolute
 * tolerance of {@link #TOLERANCE}.
 */
public class StandardScalerTest extends AbstractTestBase {
    /** Absolute tolerance used for all floating-point comparisons. */
    private static final double TOLERANCE = 1.0E-7d;

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table denseTable;

    /** Training and prediction fixture: three rows, each holding one dense 3-d vector. */
    private final List<Row> denseInput =
            Arrays.asList(
                    Row.of(new Object[] {Vectors.dense(new double[] {-2.5d, 9.0d, 1.0d})}),
                    Row.of(new Object[] {Vectors.dense(new double[] {1.4d, -5.0d, 1.0d})}),
                    Row.of(new Object[] {Vectors.dense(new double[] {2.0d, -1.0d, -2.0d})}));

    /** Expected output with withMean=true, withStd=false (input minus {@link #expectedMean}). */
    private final List<DenseVector> expectedResWithMean =
            Arrays.asList(
                    Vectors.dense(new double[] {-2.8d, 8.0d, 1.0d}),
                    Vectors.dense(new double[] {1.1d, -6.0d, 1.0d}),
                    Vectors.dense(new double[] {1.7d, -2.0d, -2.0d}));

    /** Expected output with the scaler's default parameters (withMean=false, withStd=true). */
    private final List<DenseVector> expectedResWithStd =
            Arrays.asList(
                    Vectors.dense(new double[] {-1.0231819d, 1.2480754d, 0.5773502d}),
                    Vectors.dense(new double[] {0.5729819d, -0.6933752d, 0.5773503d}),
                    Vectors.dense(new double[] {0.8185455d, -0.138675d, -1.1547005d}));

    /** Expected output with withMean=true and withStd left at its default of true. */
    private final List<DenseVector> expectedResWithMeanAndStd =
            Arrays.asList(
                    Vectors.dense(new double[] {-1.1459637d, 1.1094004d, 0.5773503d}),
                    Vectors.dense(new double[] {0.45020003d, -0.8320503d, 0.5773503d}),
                    Vectors.dense(new double[] {0.69576368d, -0.2773501d, -1.1547005d}));

    /** Expected per-dimension mean in the model data fitted on {@link #denseInput}. */
    private final double[] expectedMean = {0.3d, 1.0d, 0.0d};

    /** Expected per-dimension standard deviation in the model data fitted on the fixture. */
    private final double[] expectedStd = {2.4433583d, 7.2111026d, 1.7320508d};

    /** Creates fresh environments and rebuilds {@link #denseTable} before each test. */
    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        denseTable = tEnv.fromDataStream(env.fromCollection(denseInput)).as("input", new String[0]);
    }

    /**
     * Collects {@code outputTable}, extracts column {@code outputCol} as dense vectors and compares
     * them with {@code expected}.
     *
     * <p>The collected vectors are sorted lexicographically before the comparison, so the result
     * does not depend on the order in which rows are collected; {@code expected} must therefore be
     * listed in that same sorted order.
     *
     * @param expected expected vectors, in lexicographic order
     * @param outputTable table produced by the model under test
     * @param outputCol name of the vector column to verify
     * @throws Exception if executing and collecting the table fails
     */
    private void verifyPredictionResult(List<DenseVector> expected, Table outputTable, String outputCol)
            throws Exception {
        // IteratorUtils.toList returns a raw list; List<?> avoids an unchecked assignment.
        List<?> rows = IteratorUtils.toList(tEnv.toDataStream(outputTable).executeAndCollect());
        List<DenseVector> actual = new ArrayList<>(rows.size());
        for (Object row : rows) {
            actual.add(((Vector) ((Row) row).getField(outputCol)).toDense());
        }

        Assert.assertEquals(expected.size(), actual.size());
        actual.sort(StandardScalerTest::compareLexicographically);
        for (int i = 0; i < actual.size(); i++) {
            Assert.assertArrayEquals(expected.get(i).values, actual.get(i).values, TOLERANCE);
        }
    }

    /**
     * Orders two vectors by their first differing component, looking only at the leading
     * components both vectors have; returns 0 when all of those are equal.
     */
    private static int compareLexicographically(DenseVector left, DenseVector right) {
        int commonSize = Math.min(left.size(), right.size());
        for (int i = 0; i < commonSize; i++) {
            int order = Double.compare(left.get(i), right.get(i));
            if (order != 0) {
                return order;
            }
        }
        return 0;
    }

    /** Checks the default parameter values and that each setter is reflected by its getter. */
    @Test
    public void testParam() {
        StandardScaler standardScaler = new StandardScaler();
        Assert.assertEquals("input", standardScaler.getInputCol());
        Assert.assertEquals(false, standardScaler.getWithMean());
        Assert.assertEquals(true, standardScaler.getWithStd());
        Assert.assertEquals("output", standardScaler.getOutputCol());

        ((StandardScaler)
                        ((StandardScaler)
                                        ((StandardScaler) standardScaler.setInputCol("test_input"))
                                                .setWithMean(true))
                                .setWithStd(false))
                .setOutputCol("test_output");

        Assert.assertEquals("test_input", standardScaler.getInputCol());
        Assert.assertEquals(true, standardScaler.getWithMean());
        Assert.assertEquals(false, standardScaler.getWithStd());
        Assert.assertEquals("test_output", standardScaler.getOutputCol());
    }

    /** Checks that the transformed table has exactly the input column followed by the output column. */
    @Test
    public void testOutputSchema() {
        Table renamedInput = denseTable.as("test_input", new String[0]);
        StandardScaler standardScaler =
                (StandardScaler)
                        ((StandardScaler) new StandardScaler().setInputCol("test_input"))
                                .setOutputCol("test_output");

        Table output =
                standardScaler.fit(new Table[] {renamedInput}).transform(new Table[] {renamedInput})[0];

        Assert.assertEquals(
                Arrays.asList("test_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Fits and predicts with the default parameters. */
    @Test
    public void testFitAndPredictWithStd() throws Exception {
        StandardScaler standardScaler = new StandardScaler();
        Table output =
                standardScaler.fit(new Table[] {denseTable}).transform(new Table[] {denseTable})[0];
        verifyPredictionResult(expectedResWithStd, output, standardScaler.getOutputCol());
    }

    /** Fits and predicts with withStd=false and withMean=true. */
    @Test
    public void testFitAndPredictWithMean() throws Exception {
        StandardScaler standardScaler =
                (StandardScaler) ((StandardScaler) new StandardScaler().setWithStd(false)).setWithMean(true);
        Table output =
                standardScaler.fit(new Table[] {denseTable}).transform(new Table[] {denseTable})[0];
        verifyPredictionResult(expectedResWithMean, output, standardScaler.getOutputCol());
    }

    /** Fits and predicts with withMean=true on top of the default withStd. */
    @Test
    public void testFitAndPredictWithMeanAndStd() throws Exception {
        StandardScaler standardScaler = (StandardScaler) new StandardScaler().setWithMean(true);
        Table output =
                standardScaler.fit(new Table[] {denseTable}).transform(new Table[] {denseTable})[0];
        verifyPredictionResult(expectedResWithMeanAndStd, output, standardScaler.getOutputCol());
    }

    /**
     * Converts the fixture with {@code TestUtils.convertDataTypesToSparseInt}, asserts the column
     * type became {@link SparseVector}, and expects the same results as the dense input.
     */
    @Test
    public void testInputTypeConversion() throws Exception {
        denseTable = TestUtils.convertDataTypesToSparseInt(tEnv, denseTable);
        Assert.assertArrayEquals(
                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(denseTable));

        StandardScaler standardScaler = (StandardScaler) new StandardScaler().setWithMean(true);
        Table output =
                standardScaler.fit(new Table[] {denseTable}).transform(new Table[] {denseTable})[0];
        verifyPredictionResult(expectedResWithMeanAndStd, output, standardScaler.getOutputCol());
    }

    /** Saves and reloads both the estimator and the fitted model, then predicts with the reloaded model. */
    @Test
    public void testSaveLoadAndPredict() throws Exception {
        StandardScaler reloadedScaler =
                TestUtils.saveAndReload(
                        tEnv,
                        new StandardScaler(),
                        tempFolder.newFolder().getAbsolutePath(),
                        StandardScaler::load);
        StandardScalerModel reloadedModel =
                TestUtils.saveAndReload(
                        tEnv,
                        reloadedScaler.fit(new Table[] {denseTable}),
                        tempFolder.newFolder().getAbsolutePath(),
                        StandardScalerModel::load);

        Assert.assertEquals(
                Arrays.asList("mean", "std"),
                reloadedModel.getModelData()[0].getResolvedSchema().getColumnNames().subList(0, 2));
        verifyPredictionResult(
                expectedResWithStd,
                reloadedModel.transform(new Table[] {denseTable})[0],
                reloadedScaler.getOutputCol());
    }

    /** Checks the schema and the values of the model data produced by fitting on the fixture. */
    @Test
    public void testGetModelData() throws Exception {
        Table modelDataTable = new StandardScaler().fit(new Table[] {denseTable}).getModelData()[0];
        Assert.assertEquals(
                Arrays.asList("mean", "std"),
                modelDataTable.getResolvedSchema().getColumnNames().subList(0, 2));

        // IteratorUtils.toList returns a raw list; List<?> avoids an unchecked assignment.
        List<?> collected =
                IteratorUtils.toList(
                        StandardScalerModelData.getModelDataStream(modelDataTable).executeAndCollect());
        Assert.assertEquals(1L, collected.size());

        StandardScalerModelData modelData = (StandardScalerModelData) collected.get(0);
        Assert.assertArrayEquals(expectedMean, modelData.mean.values, TOLERANCE);
        Assert.assertArrayEquals(expectedStd, modelData.std.values, TOLERANCE);
    }

    /** Copies params and model data from a fitted model into a new model and predicts with it. */
    @Test
    public void testSetModelData() throws Exception {
        StandardScaler standardScaler = new StandardScaler();
        StandardScalerModel fittedModel = standardScaler.fit(new Table[] {denseTable});

        StandardScalerModel copiedModel = new StandardScalerModel();
        ParamUtils.updateExistingParams(copiedModel, fittedModel.getParamMap());
        copiedModel.setModelData(fittedModel.getModelData());

        verifyPredictionResult(
                expectedResWithStd,
                copiedModel.transform(new Table[] {denseTable})[0],
                standardScaler.getOutputCol());
    }

    /** Fits and predicts with default parameters on an input made of sparse vectors. */
    @Test
    public void testSparseInput() throws Exception {
        List<Row> sparseInput =
                Arrays.asList(
                        Row.of(new Object[] {Vectors.sparse(3, new int[] {0, 1}, new double[] {-2.5d, 1.0d})}),
                        Row.of(new Object[] {Vectors.sparse(3, new int[] {1, 2}, new double[] {2.0d, -2.0d})}),
                        Row.of(new Object[] {Vectors.sparse(3, new int[] {0, 2}, new double[] {1.4d, 1.0d})}));
        Table sparseTable =
                tEnv.fromDataStream(env.fromCollection(sparseInput)).as("input", new String[0]);

        List<DenseVector> expected =
                Arrays.asList(
                        Vectors.dense(new double[] {-1.2653836d, 1.0d, 0.0d}),
                        Vectors.dense(new double[] {0.0d, 2.0d, -1.30930734d}),
                        Vectors.dense(new double[] {0.7086148d, 0.0d, 0.6546537d}));

        StandardScaler standardScaler = new StandardScaler();
        Table output =
                standardScaler.fit(new Table[] {sparseTable}).transform(new Table[] {sparseTable})[0];
        verifyPredictionResult(expected, output, standardScaler.getOutputCol());
    }

    /**
     * Fitting on an empty table must fail, and the root cause must carry the message
     * {@code "The training set is empty."}.
     */
    @Test
    public void testFitOnEmptyData() {
        // Every fixture row holds exactly one field, so this filter drops all of them.
        Table emptyTable =
                tEnv.fromDataStream(env.fromCollection(denseInput).filter(row -> row.getArity() == 0))
                        .as("input", new String[0]);

        Throwable thrown = null;
        try {
            IteratorUtils.toList(
                    StandardScalerModelData.getModelDataStream(
                                    new StandardScaler().fit(new Table[] {emptyTable}).getModelData()[0])
                            .executeAndCollect());
        } catch (Throwable t) {
            thrown = t;
        }

        // Checked outside the try block: a fail() inside it would be swallowed by the
        // catch (Throwable) above and reported as a confusing message mismatch instead.
        Assert.assertNotNull("Fitting on an empty training set should have failed.", thrown);
        Assert.assertEquals(
                "The training set is empty.", ExceptionUtils.getRootCause(thrown).getMessage());
    }
}
