package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/** Tests {@link VarianceThresholdSelector} and {@link VarianceThresholdSelectorModel}. */
public class VarianceThresholdSelectorTest extends AbstractTestBase {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainDataTable;
    private Table predictDataTable;

    private static final double EPS = 1.0e-5;

    /** Row type of {@link #TRAIN_DATA}: (id INT, input VECTOR). */
    private static final TypeInformation<Row> TRAIN_DATA_TYPE_INFO =
            Types.ROW(Types.INT, VectorTypeInfo.INSTANCE);

    /** Training rows of (id, features); dense and sparse vectors are mixed. */
    private static final List<Row> TRAIN_DATA =
            Arrays.asList(
                    Row.of(1, Vectors.dense(5.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
                    Row.of(2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0).toSparse()),
                    Row.of(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
                    Row.of(4, Vectors.dense(1.0, 9.0, 8.0, 5.0, 7.0, 4.0).toSparse()),
                    Row.of(5, Vectors.dense(9.0, 8.0, 6.0, 5.0, 4.0, 4.0)),
                    Row.of(6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0).toSparse()));

    /** Single-column rows of 6-dimensional feature vectors to transform. */
    private static final List<Row> PREDICT_DATA =
            Arrays.asList(
                    Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
                    Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)),
                    Row.of(Vectors.sparse(6, new int[] {0, 3, 4}, new double[] {0.1, 0.3, 0.5})));

    /**
     * Expected transform output of {@link #PREDICT_DATA} for a variance threshold of 8.0, i.e. only
     * features 0, 3 and 5 are kept.
     */
    private static final List<Vector> EXPECTED_OUTPUT =
            Arrays.asList(
                    Vectors.dense(1.0, 4.0, 6.0),
                    Vectors.dense(0.1, 0.4, 0.6),
                    Vectors.sparse(3, new int[] {0, 1}, new double[] {0.1, 0.3}));

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        trainDataTable =
                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA, TRAIN_DATA_TYPE_INFO))
                        .as("id", "input");
        predictDataTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        PREDICT_DATA, Types.ROW(VectorTypeInfo.INSTANCE)))
                        .as("input");
    }

    /**
     * Collects the vectors stored in column {@code outputCol} of {@code output} and compares them
     * with {@code expected} via {@link TestBaseUtils#compareResultCollections} using {@link
     * TestUtils#compare}.
     *
     * @param output table produced by a transform; must have been created by a {@link
     *     StreamTableEnvironment}
     * @param outputCol name of the vector column to check
     * @param expected expected vectors
     */
    private static void verifyPredictionResult(
            Table output, String outputCol, List<Vector> expected) throws Exception {
        // TableImpl only exposes the generic TableEnvironment; toDataStream lives on the
        // streaming sub-interface, so the cast is required.
        StreamTableEnvironment tableEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
        @SuppressWarnings("unchecked") // commons-collections IteratorUtils returns a raw List.
        List<Vector> actual =
                IteratorUtils.toList(
                        tableEnv.toDataStream(output)
                                .map(row -> (Vector) row.getField(outputCol), VectorTypeInfo.INSTANCE)
                                .executeAndCollect());
        TestBaseUtils.compareResultCollections(expected, actual, TestUtils::compare);
    }

    /** Checks parameter defaults and that setters are reflected by getters. */
    @Test
    public void testParam() {
        VarianceThresholdSelector selector = new VarianceThresholdSelector();
        Assert.assertEquals("input", selector.getInputCol());
        Assert.assertEquals("output", selector.getOutputCol());
        Assert.assertEquals(0.0, selector.getVarianceThreshold(), EPS);

        selector.setInputCol("test_input").setOutputCol("test_output").setVarianceThreshold(0.5);
        Assert.assertEquals("test_input", selector.getInputCol());
        Assert.assertEquals(0.5, selector.getVarianceThreshold(), EPS);
        Assert.assertEquals("test_output", selector.getOutputCol());
    }

    /** Checks that the transform output appends the output column to the input columns. */
    @Test
    public void testOutputSchema() {
        Table input = trainDataTable.as("id", "test_input");
        VarianceThresholdSelector selector =
                new VarianceThresholdSelector()
                        .setInputCol("test_input")
                        .setOutputCol("test_output");
        Table output = selector.fit(input).transform(input)[0];
        Assert.assertEquals(
                Arrays.asList("id", "test_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    /** Fits with threshold 8.0 and checks the selected features of the predict data. */
    @Test
    public void testFitAndPredict() throws Exception {
        VarianceThresholdSelector selector =
                new VarianceThresholdSelector().setVarianceThreshold(8.0);
        Table output = selector.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(output, selector.getOutputCol(), EXPECTED_OUTPUT);
    }

    /** A threshold above every feature's variance yields empty output vectors. */
    @Test
    public void testNonSelectedFeatures() throws Exception {
        VarianceThresholdSelector selector =
                new VarianceThresholdSelector().setVarianceThreshold(100.0);
        Table output = selector.fit(trainDataTable).transform(predictDataTable)[0];
        verifyPredictionResult(
                output,
                selector.getOutputCol(),
                Arrays.asList(
                        Vectors.dense(new double[0]),
                        Vectors.dense(new double[0]),
                        Vectors.dense(new double[0])));
    }

    /** Round-trips both the estimator and the fitted model through save/load. */
    @Test
    public void testSaveLoadAndPredict() throws Exception {
        VarianceThresholdSelector selector =
                new VarianceThresholdSelector().setVarianceThreshold(8.0);
        VarianceThresholdSelector loadedSelector =
                TestUtils.saveAndReload(
                        tEnv,
                        selector,
                        tempFolder.newFolder().getAbsolutePath(),
                        VarianceThresholdSelector::load);
        VarianceThresholdSelectorModel model = loadedSelector.fit(trainDataTable);
        VarianceThresholdSelectorModel loadedModel =
                TestUtils.saveAndReload(
                        tEnv,
                        model,
                        tempFolder.newFolder().getAbsolutePath(),
                        VarianceThresholdSelectorModel::load);

        Assert.assertEquals(
                Arrays.asList("numOfFeatures", "indices"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        verifyPredictionResult(
                loadedModel.transform(predictDataTable)[0],
                selector.getOutputCol(),
                EXPECTED_OUTPUT);
    }

    /** Fitting on an empty training set must fail with a descriptive message. */
    @Test
    public void testFitOnEmptyData() {
        try {
            Table emptyTrainTable =
                    tEnv.fromDataStream(
                                    env.fromCollection(TRAIN_DATA, TRAIN_DATA_TYPE_INFO)
                                            .filter(row -> row.getArity() == 0))
                            .as("id", "input");
            new VarianceThresholdSelector()
                    .fit(emptyTrainTable)
                    .getModelData()[0]
                    .execute()
                    .print();
        } catch (Throwable e) {
            Assert.assertEquals(
                    "The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
            return;
        }
        // Kept outside the try block so the AssertionError is not swallowed by the catch above.
        Assert.fail("Expected fitting on an empty training set to fail.");
    }

    /** Transforming vectors whose size differs from the training data must fail. */
    @Test
    public void testIncompatibleNumOfFeatures() {
        try {
            VarianceThresholdSelectorModel model =
                    new VarianceThresholdSelector()
                            .setVarianceThreshold(8.0)
                            .fit(trainDataTable);
            Table predictTable =
                    tEnv.fromDataStream(
                                    env.fromCollection(
                                            Arrays.asList(
                                                    Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0)),
                                                    Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4)))))
                            .as("input");
            model.transform(predictTable)[0].execute().print();
        } catch (Throwable e) {
            Assert.assertTrue(
                    ExceptionUtils.getRootCause(e)
                            .getMessage()
                            .contains("but VarianceThresholdSelector is expecting"));
            return;
        }
        // Kept outside the try block so the AssertionError is not swallowed by the catch above.
        Assert.fail("Expected a mismatched number of features to fail.");
    }

    /** Checks the schema and content of the fitted model data. */
    @Test
    public void testGetModelData() throws Exception {
        Table modelData =
                new VarianceThresholdSelector()
                        .setVarianceThreshold(8.0)
                        .fit(trainDataTable)
                        .getModelData()[0];
        Assert.assertEquals(
                Arrays.asList("numOfFeatures", "indices"),
                modelData.getResolvedSchema().getColumnNames());

        @SuppressWarnings("unchecked") // commons-collections IteratorUtils returns a raw List.
        List<Row> rows = IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect());
        Row modelRow = rows.get(0);
        Assert.assertEquals(6, ((Integer) modelRow.getField(0)).intValue());
        // Compares lengths too, so a result with missing or extra indices cannot pass.
        Assert.assertArrayEquals(new int[] {0, 3, 5}, (int[]) modelRow.getField(1));
    }

    /** A model built from another model's data must predict identically. */
    @Test
    public void testSetModelData() throws Exception {
        VarianceThresholdSelector selector =
                new VarianceThresholdSelector().setVarianceThreshold(8.0);
        Table modelData = selector.fit(trainDataTable).getModelData()[0];
        VarianceThresholdSelectorModel model =
                new VarianceThresholdSelectorModel().setModelData(modelData);
        verifyPredictionResult(
                model.transform(predictDataTable)[0], selector.getOutputCol(), EXPECTED_OUTPUT);
    }
}
