package org.apache.flink.ml.feature;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.test.util.TestBaseUtils;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/feature/KBinsDiscretizerTest.class */
/** Tests {@link KBinsDiscretizer} and {@link KBinsDiscretizerModel}. */
public class KBinsDiscretizerTest extends AbstractTestBase {

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainTable;
    private Table testTable;

    private static final List<Row> TRAIN_INPUT =
            Arrays.asList(
                    vectorRow(1.0, 10.0, 0.0),
                    vectorRow(1.0, 10.0, 0.0),
                    vectorRow(1.0, 10.0, 0.0),
                    vectorRow(4.0, 10.0, 0.0),
                    vectorRow(5.0, 10.0, 0.0),
                    vectorRow(6.0, 10.0, 0.0),
                    vectorRow(7.0, 10.0, 0.0),
                    vectorRow(10.0, 10.0, 0.0),
                    vectorRow(13.0, 10.0, 3.0));

    private static final List<Row> TEST_INPUT =
            Arrays.asList(
                    vectorRow(-1.0, 0.0, 0.0),
                    vectorRow(1.0, 1.0, 1.0),
                    vectorRow(1.5, 1.0, 2.0),
                    vectorRow(5.0, 2.0, 3.0),
                    vectorRow(7.25, 3.0, 4.0),
                    vectorRow(13.0, 4.0, 5.0),
                    vectorRow(15.0, 4.0, 6.0));

    /** Expected bin edges per feature when fitting TRAIN_INPUT with 3 uniform bins. */
    private static final double[][] UNIFORM_MODEL_DATA = {
        {1.0, 5.0, 9.0, 13.0},
        // The second feature is constant in TRAIN_INPUT (always 10.0), so it gets one open bin.
        {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY},
        {0.0, 1.0, 2.0, 3.0}
    };

    private static final List<Row> UNIFORM_OUTPUT =
            Arrays.asList(
                    vectorRow(0.0, 0.0, 0.0),
                    vectorRow(0.0, 0.0, 1.0),
                    vectorRow(0.0, 0.0, 2.0),
                    vectorRow(1.0, 0.0, 2.0),
                    vectorRow(1.0, 0.0, 2.0),
                    vectorRow(2.0, 0.0, 2.0),
                    vectorRow(2.0, 0.0, 2.0));

    private static final List<Row> QUANTILE_OUTPUT =
            Arrays.asList(
                    vectorRow(0.0, 0.0, 0.0),
                    vectorRow(0.0, 0.0, 0.0),
                    vectorRow(0.0, 0.0, 1.0),
                    vectorRow(1.0, 0.0, 1.0),
                    vectorRow(2.0, 0.0, 1.0),
                    vectorRow(2.0, 0.0, 1.0),
                    vectorRow(2.0, 0.0, 1.0));

    private static final List<Row> KMEANS_OUTPUT =
            Arrays.asList(
                    vectorRow(0.0, 0.0, 0.0),
                    vectorRow(0.0, 0.0, 1.0),
                    vectorRow(0.0, 0.0, 2.0),
                    vectorRow(1.0, 0.0, 2.0),
                    vectorRow(1.0, 0.0, 2.0),
                    vectorRow(2.0, 0.0, 2.0),
                    vectorRow(2.0, 0.0, 2.0));

    private static final double TOLERANCE = 1e-7;

    /** Builds a single-column row whose only field is a dense vector with the given values. */
    private static Row vectorRow(double... values) {
        return Row.of(Vectors.dense(values));
    }

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        trainTable = tEnv.fromDataStream(env.fromCollection(TRAIN_INPUT)).as("input");
        testTable = tEnv.fromDataStream(env.fromCollection(TEST_INPUT)).as("input");
    }

    /**
     * Collects {@code outputCol} from {@code output} and checks it matches {@code expected},
     * ignoring row order.
     *
     * @param expected the expected rows, each holding a single dense vector
     * @param output the table produced by the model
     * @param outputCol the name of the column holding the discretized vectors
     */
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
    private void verifyPredictionResult(List<Row> expected, Table output, String outputCol)
            throws Exception {
        List<Row> actual =
                IteratorUtils.toList(
                        tEnv.toDataStream(output.select(Expressions.$(outputCol)))
                                .executeAndCollect());
        TestBaseUtils.compareResultCollections(
                expected,
                actual,
                (first, second) ->
                        TestUtils.compare(
                                (DenseVector) first.getField(0),
                                (DenseVector) second.getField(0)));
    }

    @Test
    public void testParam() {
        KBinsDiscretizer kBinsDiscretizer = new KBinsDiscretizer();
        Assert.assertEquals("input", kBinsDiscretizer.getInputCol());
        Assert.assertEquals(5L, kBinsDiscretizer.getNumBins());
        Assert.assertEquals("quantile", kBinsDiscretizer.getStrategy());
        Assert.assertEquals(200000L, kBinsDiscretizer.getSubSamples());
        Assert.assertEquals("output", kBinsDiscretizer.getOutputCol());

        kBinsDiscretizer
                .setInputCol("test_input")
                .setNumBins(10)
                .setStrategy("kmeans")
                .setSubSamples(1000)
                .setOutputCol("test_output");

        Assert.assertEquals("test_input", kBinsDiscretizer.getInputCol());
        Assert.assertEquals(10L, kBinsDiscretizer.getNumBins());
        Assert.assertEquals("kmeans", kBinsDiscretizer.getStrategy());
        Assert.assertEquals(1000L, kBinsDiscretizer.getSubSamples());
        Assert.assertEquals("test_output", kBinsDiscretizer.getOutputCol());
    }

    @Test
    public void testOutputSchema() {
        Table input =
                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
                        .as("test_input", "dummy_input");
        KBinsDiscretizer kBinsDiscretizer =
                new KBinsDiscretizer().setInputCol("test_input").setOutputCol("test_output");
        Table output = kBinsDiscretizer.fit(input).transform(input)[0];

        Assert.assertEquals(
                Arrays.asList("test_input", "dummy_input", "test_output"),
                output.getResolvedSchema().getColumnNames());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        KBinsDiscretizer kBinsDiscretizer = new KBinsDiscretizer().setNumBins(3);

        kBinsDiscretizer.setStrategy("uniform");
        verifyPredictionResult(
                UNIFORM_OUTPUT,
                kBinsDiscretizer.fit(trainTable).transform(testTable)[0],
                kBinsDiscretizer.getOutputCol());

        kBinsDiscretizer.setStrategy("quantile");
        verifyPredictionResult(
                QUANTILE_OUTPUT,
                kBinsDiscretizer.fit(trainTable).transform(testTable)[0],
                kBinsDiscretizer.getOutputCol());

        kBinsDiscretizer.setStrategy("kmeans");
        verifyPredictionResult(
                KMEANS_OUTPUT,
                kBinsDiscretizer.fit(trainTable).transform(testTable)[0],
                kBinsDiscretizer.getOutputCol());
    }

    @Test
    public void testSaveLoadAndPredict() throws Exception {
        KBinsDiscretizer kBinsDiscretizer =
                TestUtils.saveAndReload(
                        tEnv,
                        new KBinsDiscretizer().setNumBins(3).setStrategy("uniform"),
                        tempFolder.newFolder().getAbsolutePath(),
                        KBinsDiscretizer::load);

        KBinsDiscretizerModel model =
                TestUtils.saveAndReload(
                        tEnv,
                        kBinsDiscretizer.fit(trainTable),
                        tempFolder.newFolder().getAbsolutePath(),
                        KBinsDiscretizerModel::load);

        Assert.assertEquals(
                Collections.singletonList("binEdges"),
                model.getModelData()[0].getResolvedSchema().getColumnNames());
        verifyPredictionResult(
                UNIFORM_OUTPUT, model.transform(testTable)[0], kBinsDiscretizer.getOutputCol());
    }

    @Test
    @SuppressWarnings("unchecked") // IteratorUtils.toList returns a raw List.
    public void testGetModelData() throws Exception {
        Table modelDataTable =
                new KBinsDiscretizer()
                        .setNumBins(3)
                        .setStrategy("uniform")
                        .fit(trainTable)
                        .getModelData()[0];
        Assert.assertEquals(
                Collections.singletonList("binEdges"),
                modelDataTable.getResolvedSchema().getColumnNames());

        List<KBinsDiscretizerModelData> modelData =
                IteratorUtils.toList(
                        KBinsDiscretizerModelData.getModelDataStream(modelDataTable)
                                .executeAndCollect());
        Assert.assertEquals(1, modelData.size());

        double[][] binEdges = modelData.get(0).binEdges;
        Assert.assertEquals(UNIFORM_MODEL_DATA.length, binEdges.length);
        for (int i = 0; i < binEdges.length; i++) {
            Assert.assertArrayEquals(UNIFORM_MODEL_DATA[i], binEdges[i], TOLERANCE);
        }
    }

    @Test
    public void testSetModelData() throws Exception {
        KBinsDiscretizer kBinsDiscretizer =
                new KBinsDiscretizer().setNumBins(3).setStrategy("uniform");
        KBinsDiscretizerModel fittedModel = kBinsDiscretizer.fit(trainTable);

        // Rebuild an equivalent model from the fitted model's params and model data only.
        KBinsDiscretizerModel rebuiltModel = new KBinsDiscretizerModel();
        ParamUtils.updateExistingParams(rebuiltModel, fittedModel.getParamMap());
        rebuiltModel.setModelData(fittedModel.getModelData());

        verifyPredictionResult(
                UNIFORM_OUTPUT,
                rebuiltModel.transform(testTable)[0],
                kBinsDiscretizer.getOutputCol());
    }

    @Test
    public void testFitOnEmptyData() {
        // Filtering on arity 0 drops every row, yielding an empty table with the right schema.
        Table emptyTable =
                tEnv.fromDataStream(
                                env.fromCollection(TRAIN_INPUT)
                                        .filter(row -> row.getArity() == 0))
                        .as("input");

        // Capture the failure outside the try block. Calling Assert.fail() inside a
        // catch (Throwable) would have its AssertionError swallowed by that same catch.
        Throwable thrown = null;
        try {
            new KBinsDiscretizer().fit(emptyTable).getModelData()[0].execute().collect().next();
        } catch (Throwable t) {
            thrown = t;
        }

        Assert.assertNotNull("Fitting on an empty training set should fail.", thrown);
        Assert.assertEquals(
                "The training set is empty.", ExceptionUtils.getRootCause(thrown).getMessage());
    }

    @Test
    public void testBinsWithWidthAsZero() throws Exception {
        List<Row> expected =
                Arrays.asList(
                        vectorRow(0.0, 0.0, 0.0),
                        vectorRow(0.0, 0.0, 0.0),
                        vectorRow(0.0, 0.0, 1.0),
                        vectorRow(3.0, 0.0, 1.0),
                        vectorRow(5.0, 0.0, 1.0),
                        vectorRow(6.0, 0.0, 1.0),
                        vectorRow(6.0, 0.0, 1.0));

        KBinsDiscretizer kBinsDiscretizer =
                new KBinsDiscretizer().setNumBins(10).setStrategy("quantile");
        verifyPredictionResult(
                expected,
                kBinsDiscretizer.fit(trainTable).transform(testTable)[0],
                kBinsDiscretizer.getOutputCol());
    }
}
