package org.apache.flink.ml.evaluation;

import java.util.Arrays;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.class */
/**
 * Tests {@link BinaryClassificationEvaluator}: parameter defaults and setters, metric evaluation
 * on vector-valued and scalar raw predictions, tied scores, weighted samples, input type
 * conversion, high parallelism, and save/reload.
 *
 * <p>Every evaluation test follows the same pattern: run the evaluator on one fixture table, check
 * that the output columns are exactly the requested metric names (in the requested order), and
 * compare the values of the first result row against precomputed expectations within {@link #EPS}.
 */
public class BinaryClassificationEvaluatorTest {

    /** Input rows of (label, rawPrediction vector) with labels given as doubles. */
    private static final List<Row> INPUT_DATA =
            Arrays.asList(
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
                    Row.of(1.0, Vectors.dense(0.9, 0.1)));

    /**
     * Same labels as {@link #INPUT_DATA} (as integers) paired with a scalar raw prediction equal to
     * the second component of the corresponding vector in {@link #INPUT_DATA}. The tests expect the
     * same metric values ({@link #EXPECTED_DATA}) for both fixtures.
     */
    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
            Arrays.asList(
                    Row.of(1, 0.9),
                    Row.of(1, 0.8),
                    Row.of(1, 0.7),
                    Row.of(0, 0.75),
                    Row.of(0, 0.6),
                    Row.of(1, 0.65),
                    Row.of(1, 0.55),
                    Row.of(0, 0.4),
                    Row.of(0, 0.3),
                    Row.of(1, 0.35),
                    Row.of(0, 0.2),
                    Row.of(1, 0.1));

    /** Input rows of (label, rawPrediction vector) in which several rows share the same vector. */
    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
            Arrays.asList(
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
                    Row.of(1.0, Vectors.dense(0.9, 0.1)));

    /** Same (label, rawPrediction) rows as the multi-score fixture plus a per-row weight. */
    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
            Arrays.asList(
                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));

    /** Metric names requested by most tests; {@link #EXPECTED_DATA} is listed in this order. */
    private static final String[] PR_KS_ROC_METRICS = {"areaUnderPR", "ks", "areaUnderROC"};

    /** Metric names requested by the multi-score test; {@link #EXPECTED_DATA_M} uses this order. */
    private static final String[] ROC_PR_KS_LORENZ_METRICS = {
        "areaUnderROC", "areaUnderPR", "ks", "areaUnderLorenz"
    };

    /** Metric names requested by the weighted test. */
    private static final String[] ROC_METRICS = {"areaUnderROC"};

    /** Expected areaUnderPR, ks, areaUnderROC for INPUT_DATA and INPUT_DATA_DOUBLE_RAW. */
    private static final double[] EXPECTED_DATA = {
        0.7691481137909708, 0.3714285714285714, 0.6571428571428571
    };

    /** Expected areaUnderROC, areaUnderPR, ks, areaUnderLorenz for INPUT_DATA_WITH_MULTI_SCORE. */
    private static final double[] EXPECTED_DATA_M = {
        0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
    };

    /** Expected areaUnderROC for INPUT_DATA_WITH_WEIGHT when the weight column is used. */
    private static final double EXPECTED_DATA_W = 0.8911680911680911;

    /** Absolute tolerance used when comparing metric values. */
    private static final double EPS = 1.0e-5;

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamTableEnvironment tEnv;
    private StreamExecutionEnvironment env;
    private Table inputDataTable;
    private Table inputDataTableScore;
    private Table inputDataTableWithMultiScore;
    private Table inputDataTableWithWeight;

    /**
     * Creates a fresh execution environment (parallelism 3, checkpointing with interval 100 and
     * checkpoints-after-tasks-finish enabled, no restarts) and registers the fixture tables with
     * the column names {@code label}, {@code rawPrediction} and, where present, {@code weight}.
     */
    @Before
    public void before() {
        Configuration config = new Configuration();
        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);

        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
        env.setParallelism(3);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);

        inputDataTable =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
        inputDataTableScore =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW))
                        .as("label", "rawPrediction");
        inputDataTableWithMultiScore =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
                        .as("label", "rawPrediction");
        inputDataTableWithWeight =
                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
                        .as("label", "rawPrediction", "weight");
    }

    /** Checks the default parameter values and that each setter is reflected by its getter. */
    @Test
    public void testParam() {
        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();

        // Defaults of a freshly constructed evaluator.
        Assert.assertEquals("label", evaluator.getLabelCol());
        Assert.assertNull(evaluator.getWeightCol());
        Assert.assertEquals("rawPrediction", evaluator.getRawPredictionCol());
        Assert.assertArrayEquals(
                new String[] {"areaUnderROC", "areaUnderPR"}, evaluator.getMetricsNames());

        evaluator
                .setLabelCol("labelCol")
                .setRawPredictionCol("raw")
                .setMetricsNames("areaUnderROC")
                .setWeightCol("weight");

        // Values after overriding every parameter.
        Assert.assertEquals("labelCol", evaluator.getLabelCol());
        Assert.assertEquals("weight", evaluator.getWeightCol());
        Assert.assertEquals("raw", evaluator.getRawPredictionCol());
        Assert.assertArrayEquals(new String[] {"areaUnderROC"}, evaluator.getMetricsNames());
    }

    /** Evaluates the vector-valued fixture and checks areaUnderPR, ks and areaUnderROC. */
    @Test
    public void testEvaluate() {
        Table result = evaluate(PR_KS_ROC_METRICS, inputDataTable);
        assertMetrics(result, PR_KS_ROC_METRICS, EXPECTED_DATA);
    }

    /**
     * Converts the input columns to (Integer, SparseVector) and checks that the evaluator yields
     * the same metrics as for the original (Double, dense vector) input.
     */
    @Test
    public void testInputTypeConversion() {
        Table converted = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
        Assert.assertArrayEquals(
                new Class<?>[] {Integer.class, SparseVector.class},
                TestUtils.getColumnDataTypes(converted));

        Table result = evaluate(PR_KS_ROC_METRICS, converted);
        assertMetrics(result, PR_KS_ROC_METRICS, EXPECTED_DATA);
    }

    /** Evaluates the fixture whose raw prediction is a scalar double instead of a vector. */
    @Test
    public void testEvaluateWithDoubleRaw() {
        Table result = evaluate(PR_KS_ROC_METRICS, inputDataTableScore);
        assertMetrics(result, PR_KS_ROC_METRICS, EXPECTED_DATA);
    }

    /**
     * Raises the environment parallelism to 15, which exceeds the 12 input rows, and checks that
     * the metrics are unchanged.
     */
    @Test
    public void testMoreSubtaskThanData() {
        // Set before transform() so that operators created from here on pick up the new default.
        env.setParallelism(15);

        Table result = evaluate(PR_KS_ROC_METRICS, inputDataTableScore);
        assertMetrics(result, PR_KS_ROC_METRICS, EXPECTED_DATA);
    }

    /** Evaluates the fixture with repeated raw predictions and also requests areaUnderLorenz. */
    @Test
    public void testEvaluateWithMultiScore() {
        Table result = evaluate(ROC_PR_KS_LORENZ_METRICS, inputDataTableWithMultiScore);
        assertMetrics(result, ROC_PR_KS_LORENZ_METRICS, EXPECTED_DATA_M);
    }

    /** Evaluates areaUnderROC with the {@code weight} column configured as sample weight. */
    @Test
    public void testEvaluateWithWeight() {
        Table result =
                new BinaryClassificationEvaluator()
                        .setMetricsNames(ROC_METRICS)
                        .setWeightCol("weight")
                        .transform(inputDataTableWithWeight)[0];
        assertMetrics(result, ROC_METRICS, new double[] {EXPECTED_DATA_W});
    }

    /**
     * Saves a configured evaluator to a temporary folder, reloads it, and checks that the reloaded
     * instance produces the expected metrics.
     *
     * @throws Exception if creating the temporary folder or saving/reloading the evaluator fails
     */
    @Test
    public void testSaveLoadAndEvaluate() throws Exception {
        BinaryClassificationEvaluator reloaded =
                TestUtils.saveAndReload(
                        tEnv,
                        new BinaryClassificationEvaluator().setMetricsNames(PR_KS_ROC_METRICS),
                        tempFolder.newFolder().getAbsolutePath());

        Table result = reloaded.transform(inputDataTable)[0];
        assertMetrics(result, PR_KS_ROC_METRICS, EXPECTED_DATA);
    }

    /** Runs a default evaluator configured with the given metric names on {@code input}. */
    private static Table evaluate(String[] metricsNames, Table input) {
        return new BinaryClassificationEvaluator().setMetricsNames(metricsNames).transform(input)[0];
    }

    /**
     * Executes {@code result} and verifies its schema and the metric values of its first row.
     *
     * @param result table returned by the evaluator
     * @param expectedNames expected output column names, in order
     * @param expectedValues expected metric values, aligned with {@code expectedNames}
     */
    private static void assertMetrics(
            Table result, String[] expectedNames, double[] expectedValues) {
        List<Row> rows = collectRows(result);
        Assert.assertArrayEquals(
                expectedNames, result.getResolvedSchema().getColumnNames().toArray());

        // Fail with a readable message instead of an IndexOutOfBoundsException from get(0).
        Assert.assertFalse("Evaluation produced no result row", rows.isEmpty());

        Row metrics = rows.get(0);
        for (int i = 0; i < expectedValues.length; i++) {
            double actual = metrics.<Double>getFieldAs(i);
            // The metric name is used as the message so a failure identifies the metric.
            Assert.assertEquals(expectedNames[i], expectedValues[i], actual, EPS);
        }
    }

    /** Executes the table and gathers every produced row into a list. */
    @SuppressWarnings("unchecked") // commons-collections IteratorUtils.toList returns a raw List
    private static List<Row> collectRows(Table table) {
        return IteratorUtils.toList(table.execute().collect());
    }
}
