package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.common.window.CountTumblingWindows;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.common.window.GlobalWindows;
import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests {@link OnlineStandardScaler} and {@link OnlineStandardScalerModel}.
 *
 * <p>Each input row holds a {@code long} id and a dense vector. After renaming, the columns are
 * {@code id} and {@code input}. For {@link #inputTableWithEventTime} the id is also used as the
 * event-time timestamp in milliseconds.
 */
public class OnlineStandardScalerTest extends AbstractTestBase {
    private static final double TOLERANCE = 1.0E-7d;

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    /** {@link #inputData} as a table with columns {@code id} and {@code input}, no time attribute. */
    private Table inputTable;

    /** {@link #inputData} passed through {@link SleepEveryThirdRecord} with parallelism 1. */
    private Table inputTableWithProcessingTime;

    /** {@link #inputData} with a {@code rowtime} metadata column whose timestamps come from the id. */
    private Table inputTableWithEventTime;

    private final List<Row> inputData =
            Arrays.asList(
                    Row.of(0L, Vectors.dense(-2.5d, 9.0d, 1.0d)),
                    Row.of(1000L, Vectors.dense(1.4d, -5.0d, 1.0d)),
                    Row.of(2000L, Vectors.dense(2.0d, -1.0d, -2.0d)),
                    Row.of(6000L, Vectors.dense(0.7d, 3.0d, 1.0d)),
                    Row.of(7000L, Vectors.dense(0.0d, 1.0d, 1.0d)),
                    Row.of(8000L, Vectors.dense(0.5d, 0.0d, -2.0d)),
                    Row.of(9000L, Vectors.dense(0.4d, 1.0d, 1.0d)),
                    Row.of(10000L, Vectors.dense(0.3d, 2.0d, 1.0d)),
                    Row.of(11000L, Vectors.dense(0.5d, 1.0d, -2.0d)));

    /**
     * Expected model data, with versions 0, 1 and 2, for the windowed tests over {@link #inputData}.
     *
     * <p>The timestamps 2999, 8999 and 11999 equal the last millisecond of the event-time windows
     * [0, 3000), [6000, 9000) and [9000, 12000). Timestamps are only compared when {@link
     * #verifyModelData} is called with {@code checkTimestamp == true}.
     */
    private final List<StandardScalerModelData> expectedModelData =
            Arrays.asList(
                    new StandardScalerModelData(
                            Vectors.dense(0.3d, 1.0d, 0.0d),
                            Vectors.dense(2.4433583d, 7.2111026d, 1.7320508d),
                            0,
                            2999),
                    new StandardScalerModelData(
                            Vectors.dense(0.35d, 1.1666667d, 0.0d),
                            Vectors.dense(1.5630099d, 4.665476d, 1.5491933d),
                            1,
                            8999),
                    new StandardScalerModelData(
                            Vectors.dense(0.3666667d, 1.2222222d, 0.0d),
                            Vectors.dense(1.2369316d, 3.7006005d, 1.5d),
                            2,
                            11999));

    /**
     * Pass-through map that sleeps for one second before emitting every third record.
     *
     * <p>It is static so that it does not capture the (non-serializable) enclosing test instance.
     * The counter is per function instance; callers run it with parallelism 1.
     */
    private static class SleepEveryThirdRecord implements MapFunction<Row, Row> {
        private int count = 0;

        @Override
        public Row map(Row row) throws Exception {
            count++;
            if (count % 3 == 0) {
                Thread.sleep(1000L);
            }
            return row;
        }
    }

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        DataStreamSource<Row> source = env.fromCollection(inputData);

        inputTable =
                tEnv.fromDataStream(
                                source,
                                Schema.newBuilder()
                                        .column("f0", DataTypes.BIGINT())
                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
                                        .build())
                        .as("id", "input");

        inputTableWithProcessingTime =
                tEnv.fromDataStream(
                        source.map(
                                        new SleepEveryThirdRecord(),
                                        new RowTypeInfo(
                                                new TypeInformation<?>[] {
                                                    Types.LONG, DenseVectorTypeInfo.INSTANCE
                                                },
                                                new String[] {"id", "input"}))
                                .setParallelism(1));

        // The first field (the id) doubles as the event timestamp in milliseconds.
        inputTableWithEventTime =
                tEnv.fromDataStream(
                                source.assignTimestampsAndWatermarks(
                                        WatermarkStrategy.<Row>forMonotonousTimestamps()
                                                .withTimestampAssigner(
                                                        (row, recordTimestamp) ->
                                                                row.<Long>getFieldAs(0))),
                                Schema.newBuilder()
                                        .column("f0", DataTypes.BIGINT())
                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
                                        .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
                                        .watermark("rowtime", "SOURCE_WATERMARK()")
                                        .build())
                        .as("id", "input");
    }

    @Test
    public void testParam() {
        OnlineStandardScaler onlineStandardScaler = new OnlineStandardScaler();
        Assert.assertEquals("input", onlineStandardScaler.getInputCol());
        Assert.assertEquals(false, onlineStandardScaler.getWithMean());
        Assert.assertEquals(true, onlineStandardScaler.getWithStd());
        Assert.assertEquals("output", onlineStandardScaler.getOutputCol());
        Assert.assertEquals("version", onlineStandardScaler.getModelVersionCol());
        Assert.assertEquals(GlobalWindows.getInstance(), onlineStandardScaler.getWindows());
        Assert.assertEquals(0L, onlineStandardScaler.getMaxAllowedModelDelayMs());

        onlineStandardScaler
                .setInputCol("test_input")
                .setWithMean(true)
                .setWithStd(false)
                .setOutputCol("test_output")
                .setModelVersionCol("test_version")
                .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(3000L)))
                .setMaxAllowedModelDelayMs(3000L);

        Assert.assertEquals("test_input", onlineStandardScaler.getInputCol());
        Assert.assertEquals(true, onlineStandardScaler.getWithMean());
        Assert.assertEquals(false, onlineStandardScaler.getWithStd());
        Assert.assertEquals("test_output", onlineStandardScaler.getOutputCol());
        Assert.assertEquals("test_version", onlineStandardScaler.getModelVersionCol());
        Assert.assertEquals(
                EventTimeTumblingWindows.of(Time.milliseconds(3000L)),
                onlineStandardScaler.getWindows());
        Assert.assertEquals(3000L, onlineStandardScaler.getMaxAllowedModelDelayMs());
    }

    @Test
    public void testOutputSchema() {
        Table renamedInput = inputTable.as("test_id", "test_input");
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler().setInputCol("test_input").setOutputCol("test_output");

        Table output = onlineStandardScaler.fit(renamedInput).transform(renamedInput)[0];
        Assert.assertEquals(
                Arrays.asList("test_id", "test_input", "test_output", "version"),
                output.getResolvedSchema().getColumnNames());

        // Without a model version column, no version column is appended.
        onlineStandardScaler.setModelVersionCol(null);
        output = onlineStandardScaler.fit(renamedInput).transform(renamedInput)[0];
        Assert.assertEquals(
                Arrays.asList("test_id", "test_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredictWithEventTimeWindow() throws Exception {
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler()
                        .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(3000L)));

        // First with the default maximum allowed model delay.
        Table output =
                onlineStandardScaler
                        .fit(inputTableWithEventTime)
                        .transform(inputTableWithEventTime)[0];
        verifyUsedModelVersion(
                output,
                onlineStandardScaler.getModelVersionCol(),
                onlineStandardScaler.getMaxAllowedModelDelayMs());

        // Then allowing the model to lag behind the data by up to 3000 ms.
        onlineStandardScaler.setMaxAllowedModelDelayMs(3000L);
        output =
                onlineStandardScaler
                        .fit(inputTableWithEventTime)
                        .transform(inputTableWithEventTime)[0];
        verifyUsedModelVersion(
                output,
                onlineStandardScaler.getModelVersionCol(),
                onlineStandardScaler.getMaxAllowedModelDelayMs());
    }

    @Test
    public void testFitAndPredictWithProcessingTimeWindow() throws Exception {
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler()
                        .setWindows(ProcessingTimeTumblingWindows.of(Time.milliseconds(1000L)));
        OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTableWithProcessingTime);

        // Processing-time windowing is not deterministic, so only require at least one model.
        List<StandardScalerModelData> actualModelData =
                collectToList(
                        StandardScalerModelData.getModelDataStream(model.getModelData()[0])
                                .executeAndCollect());
        Assert.assertTrue(actualModelData.size() >= 1);

        Table output = model.transform(inputTableWithProcessingTime)[0];
        List<Row> rows = collectToList(tEnv.toDataStream(output).executeAndCollect());
        Assert.assertEquals(inputData.size(), rows.size());
    }

    @Test
    public void testFitAndPredictWithCountWindow() throws Exception {
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler().setWindows(CountTumblingWindows.of(3L));
        OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable);

        List<StandardScalerModelData> actualModelData =
                collectToList(
                        StandardScalerModelData.getModelDataStream(model.getModelData()[0])
                                .executeAndCollect());
        Assert.assertEquals(expectedModelData.size(), actualModelData.size());
        for (int i = 0; i < expectedModelData.size(); i++) {
            // Count windows carry no meaningful event-time timestamp, so skip that check.
            verifyModelData(expectedModelData.get(i), actualModelData.get(i), false);
        }

        Table output = model.transform(inputTableWithEventTime)[0];
        List<Row> rows = collectToList(tEnv.toDataStream(output).executeAndCollect());
        Assert.assertEquals(inputData.size(), rows.size());
    }

    @Test
    public void testFitAndPredictWithGlobalWindow() throws Exception {
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler().setWindows(GlobalWindows.getInstance());
        Table input =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Row.of(Vectors.dense(-2.5d, 9.0d, 1.0d)),
                                                Row.of(Vectors.dense(1.4d, -5.0d, 1.0d)),
                                                Row.of(Vectors.dense(2.0d, -1.0d, -2.0d)))))
                        .as("input");

        // Centering only.
        onlineStandardScaler.setWithMean(true).setWithStd(false);
        verifyPredictionResult(
                Arrays.asList(
                        Vectors.dense(-2.8d, 8.0d, 1.0d),
                        Vectors.dense(1.1d, -6.0d, 1.0d),
                        Vectors.dense(1.7d, -2.0d, -2.0d)),
                onlineStandardScaler.fit(input).transform(input)[0],
                onlineStandardScaler.getOutputCol());

        // Scaling only.
        onlineStandardScaler.setWithMean(false).setWithStd(true);
        verifyPredictionResult(
                Arrays.asList(
                        Vectors.dense(-1.0231819d, 1.2480754d, 0.5773502d),
                        Vectors.dense(0.5729819d, -0.6933752d, 0.5773503d),
                        Vectors.dense(0.8185455d, -0.138675d, -1.1547005d)),
                onlineStandardScaler.fit(input).transform(input)[0],
                onlineStandardScaler.getOutputCol());

        // Centering and scaling.
        onlineStandardScaler.setWithMean(true).setWithStd(true);
        verifyPredictionResult(
                Arrays.asList(
                        Vectors.dense(-1.1459637d, 1.1094004d, 0.5773503d),
                        Vectors.dense(0.45020003d, -0.8320503d, 0.5773503d),
                        Vectors.dense(0.69576368d, -0.2773501d, -1.1547005d)),
                onlineStandardScaler.fit(input).transform(input)[0],
                onlineStandardScaler.getOutputCol());
    }

    @Test
    public void testGetModelData() throws Exception {
        Table modelDataTable =
                new OnlineStandardScaler()
                        .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(3000L)))
                        .fit(inputTableWithEventTime)
                        .getModelData()[0];
        Assert.assertEquals(
                Arrays.asList("mean", "std", "version", "timestamp"),
                modelDataTable.getResolvedSchema().getColumnNames());

        List<StandardScalerModelData> actualModelData =
                collectToList(
                        StandardScalerModelData.getModelDataStream(modelDataTable)
                                .executeAndCollect());
        Assert.assertEquals(expectedModelData.size(), actualModelData.size());
        for (int i = 0; i < expectedModelData.size(); i++) {
            verifyModelData(expectedModelData.get(i), actualModelData.get(i), true);
        }
    }

    @Test
    public void testSetModelData() throws Exception {
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler()
                        .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(3000L)));
        OnlineStandardScalerModel trainedModel = onlineStandardScaler.fit(inputTableWithEventTime);

        // Build a fresh model from the trained model's params and model data only.
        OnlineStandardScalerModel model = new OnlineStandardScalerModel();
        ParamUtils.updateExistingParams(model, trainedModel.getParamMap());
        model.setModelData(trainedModel.getModelData());

        verifyUsedModelVersion(
                model.transform(inputTableWithEventTime)[0],
                onlineStandardScaler.getModelVersionCol(),
                onlineStandardScaler.getMaxAllowedModelDelayMs());
    }

    @Test
    public void testSaveLoadPredict() throws Exception {
        OnlineStandardScaler loadedEstimator =
                TestUtils.saveAndReload(
                        tEnv,
                        new OnlineStandardScaler()
                                .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(3000L))),
                        tempFolder.newFolder().getAbsolutePath(),
                        OnlineStandardScaler::load);
        OnlineStandardScalerModel trainedModel = loadedEstimator.fit(inputTableWithEventTime);
        Table[] modelData = trainedModel.getModelData();

        OnlineStandardScalerModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv,
                        trainedModel,
                        tempFolder.newFolder().getAbsolutePath(),
                        OnlineStandardScalerModel::load);
        // Model data is attached separately after reloading the model.
        loadedModel.setModelData(modelData);

        verifyUsedModelVersion(
                loadedModel.transform(inputTableWithEventTime)[0],
                loadedEstimator.getModelVersionCol(),
                loadedEstimator.getMaxAllowedModelDelayMs());
    }

    /**
     * Drains {@code iterator} into a list.
     *
     * <p>This wraps the raw-typed {@link IteratorUtils#toList(Iterator)} so the unchecked
     * conversion is confined to one place.
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> collectToList(Iterator<T> iterator) {
        return IteratorUtils.toList(iterator);
    }

    /**
     * Asserts that {@code actual} matches {@code expected}.
     *
     * <p>Mean and std are compared within {@link #TOLERANCE}. The version is compared exactly. The
     * timestamp is compared exactly only when {@code checkTimestamp} is true.
     */
    private static void verifyModelData(
            StandardScalerModelData expected,
            StandardScalerModelData actual,
            boolean checkTimestamp) {
        Assert.assertArrayEquals(expected.mean.values, actual.mean.values, TOLERANCE);
        Assert.assertArrayEquals(expected.std.values, actual.std.values, TOLERANCE);
        Assert.assertEquals(expected.version, actual.version);
        if (checkTimestamp) {
            Assert.assertEquals(expected.timestamp, actual.timestamp);
        }
    }

    /**
     * Asserts that {@code output} is non-empty and that every row used a recent enough model.
     *
     * <p>For each row, the first field (the id, used as event time) minus the timestamp of the
     * expected model with the row's version must be at most {@code maxAllowedModelDelayMs}.
     */
    private void verifyUsedModelVersion(
            Table output, String modelVersionCol, long maxAllowedModelDelayMs) throws Exception {
        List<Row> rows = collectToList(tEnv.toDataStream(output).executeAndCollect());
        Assert.assertFalse("Expected at least one output row", rows.isEmpty());
        for (Row row : rows) {
            long dataTimestamp = row.<Long>getFieldAs(0);
            long modelVersion = row.<Long>getFieldAs(modelVersionCol);
            long modelTimestamp = expectedModelData.get((int) modelVersion).timestamp;
            Assert.assertTrue(
                    String.format(
                            "Row with timestamp %d used model version %d (timestamp %d), "
                                    + "exceeding the allowed delay of %d ms",
                            dataTimestamp, modelVersion, modelTimestamp, maxAllowedModelDelayMs),
                    dataTimestamp - modelTimestamp <= maxAllowedModelDelayMs);
        }
    }

    /**
     * Asserts that the vectors in {@code outputCol} of {@code output} equal {@code expected}.
     *
     * <p>Actual vectors are sorted with {@link TestUtils#compare} before the element-wise
     * comparison. {@code expected} must already be in that order.
     */
    private void verifyPredictionResult(
            List<DenseVector> expected, Table output, String outputCol) throws Exception {
        List<Row> rows = collectToList(tEnv.toDataStream(output).executeAndCollect());
        List<DenseVector> actual = new ArrayList<>(rows.size());
        for (Row row : rows) {
            actual.add(((Vector) row.getField(outputCol)).toDense());
        }
        Assert.assertEquals(expected.size(), actual.size());

        actual.sort((first, second) -> TestUtils.compare(first, second));
        for (int i = 0; i < actual.size(); i++) {
            Assert.assertArrayEquals(expected.get(i).values, actual.get(i).values, TOLERANCE);
        }
    }
}
