package org.apache.flink.ml.classification;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */
public class NaiveBayesTest {

    /** Tolerance used for every floating-point comparison in this test. */
    private static final double TOLERANCE = 1e-5;

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;
    private Table trainTable;
    private Table predictTable;
    /** Maps each (dense) feature vector in {@link #predictTable} to its expected prediction. */
    private Map<Vector, Double> expectedOutput;
    private NaiveBayes estimator;

    /**
     * Builds the execution environment, the shared training/prediction tables, the expected
     * predictions and a fully configured {@link NaiveBayes} estimator.
     *
     * <p>The environment runs with parallelism 4, checkpointing every 100 ms and no restarts.
     */
    @Before
    public void before() {
        Configuration configuration = new Configuration();
        // Allow checkpoints to continue after some tasks (e.g. the bounded sources) have finished.
        configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.setParallelism(4);
        env.enableCheckpointing(100L);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);

        trainTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Row.of(Vectors.dense(0.0, 0.0), 11),
                                                Row.of(Vectors.dense(1.0, 0.0), 10),
                                                Row.of(Vectors.dense(1.0, 1.0), 10))))
                        .as("features", "label");

        predictTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Row.of(Vectors.dense(0.0, 1.0)),
                                                Row.of(Vectors.dense(0.0, 0.0)),
                                                Row.of(Vectors.dense(1.0, 0.0)),
                                                Row.of(Vectors.dense(1.0, 1.0)))))
                        .as("features");

        // A plain map instead of double-brace initialization: no anonymous subclass that
        // captures a reference to this test instance.
        expectedOutput = new HashMap<>();
        expectedOutput.put(Vectors.dense(0.0, 1.0), 11.0);
        expectedOutput.put(Vectors.dense(0.0, 0.0), 11.0);
        expectedOutput.put(Vectors.dense(1.0, 0.0), 10.0);
        expectedOutput.put(Vectors.dense(1.0, 1.0), 10.0);

        estimator =
                new NaiveBayes()
                        .setSmoothing(1.0)
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setModelType("multinomial");
    }

    /**
     * Executes {@code table} and collects its rows into a map from feature vector to prediction.
     *
     * <p>Feature vectors are converted to dense form so they compare equal to the dense keys of
     * {@link #expectedOutput}. If several rows share the same features, the last row collected
     * wins. The result iterator is always closed, even if reading a row fails.
     *
     * @param table the table to execute
     * @param featuresCol name of the column holding the feature {@link Vector}
     * @param predictionCol name of the column holding the {@link Double} prediction
     * @return the collected feature-to-prediction mapping
     * @throws Exception if the job fails or the result iterator cannot be closed
     */
    private static Map<Vector, Double> executeAndCollect(
            Table table, String featuresCol, String predictionCol) throws Exception {
        Map<Vector, Double> predictions = new HashMap<>();
        try (CloseableIterator<Row> rows = table.execute().collect()) {
            while (rows.hasNext()) {
                Row row = rows.next();
                predictions.put(
                        ((Vector) row.getField(featuresCol)).toDense(),
                        (Double) row.getField(predictionCol));
            }
        }
        return predictions;
    }

    /**
     * Fits {@link #estimator} on {@link #trainTable}, transforms that same table, and asserts the
     * job fails with an {@link IllegalArgumentException} about unequal feature-vector lengths.
     */
    private void assertFailsWithUnequalFeatureLengths() {
        try (CloseableIterator<Row> results =
                estimator.fit(trainTable).transform(trainTable)[0].execute().collect()) {
            results.next();
            // AssertionError is not an Exception, so the catch below does not swallow it.
            Assert.fail("Expected IllegalArgumentException");
        } catch (Exception e) {
            Throwable rootCause = ExceptionUtils.getRootCause(e);
            Assert.assertEquals(IllegalArgumentException.class, rootCause.getClass());
            Assert.assertEquals("Feature vectors should be of equal length.", rootCause.getMessage());
        }
    }

    @Test
    public void testParam() {
        NaiveBayes naiveBayes = new NaiveBayes();
        Assert.assertEquals("features", naiveBayes.getFeaturesCol());
        Assert.assertEquals("label", naiveBayes.getLabelCol());
        Assert.assertEquals("multinomial", naiveBayes.getModelType());
        Assert.assertEquals("prediction", naiveBayes.getPredictionCol());
        Assert.assertEquals(1.0, naiveBayes.getSmoothing().doubleValue(), TOLERANCE);

        naiveBayes
                .setFeaturesCol("test_feature")
                .setLabelCol("test_label")
                .setPredictionCol("test_prediction")
                .setSmoothing(2.0);

        Assert.assertEquals("test_feature", naiveBayes.getFeaturesCol());
        Assert.assertEquals("test_label", naiveBayes.getLabelCol());
        Assert.assertEquals("test_prediction", naiveBayes.getPredictionCol());
        Assert.assertEquals(2.0, naiveBayes.getSmoothing().doubleValue(), TOLERANCE);

        NaiveBayesModel naiveBayesModel = new NaiveBayesModel();
        Assert.assertEquals("features", naiveBayesModel.getFeaturesCol());
        Assert.assertEquals("multinomial", naiveBayesModel.getModelType());
        Assert.assertEquals("prediction", naiveBayesModel.getPredictionCol());

        naiveBayesModel.setFeaturesCol("test_feature").setPredictionCol("test_prediction");

        Assert.assertEquals("test_feature", naiveBayesModel.getFeaturesCol());
        Assert.assertEquals("test_prediction", naiveBayesModel.getPredictionCol());
    }

    @Test
    public void testFitAndPredict() throws Exception {
        NaiveBayesModel model = estimator.fit(trainTable);
        Assert.assertEquals(
                expectedOutput,
                executeAndCollect(
                        model.transform(predictTable)[0],
                        model.getFeaturesCol(),
                        model.getPredictionCol()));
    }

    @Test
    public void testInputTypeConversion() throws Exception {
        trainTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainTable);
        predictTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictTable);
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class, Integer.class},
                TestUtils.getColumnDataTypes(trainTable));
        Assert.assertArrayEquals(
                new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(predictTable));

        NaiveBayesModel model = estimator.fit(trainTable);
        Assert.assertEquals(
                expectedOutput,
                executeAndCollect(
                        model.transform(predictTable)[0],
                        model.getFeaturesCol(),
                        model.getPredictionCol()));
    }

    @Test
    public void testOutputSchema() throws Exception {
        trainTable = trainTable.as("test_features", "test_label");
        predictTable = predictTable.as("test_features");
        estimator
                .setFeaturesCol("test_features")
                .setLabelCol("test_label")
                .setPredictionCol("test_prediction");

        NaiveBayesModel model = estimator.fit(trainTable);
        Assert.assertEquals(
                expectedOutput,
                executeAndCollect(
                        model.transform(predictTable)[0],
                        model.getFeaturesCol(),
                        model.getPredictionCol()));
    }

    @Test
    public void testPredictUnseenFeature() {
        // 2.0 never occurs in the first feature column of the training data (only 0.0 and 1.0).
        predictTable =
                tEnv.fromDataStream(env.fromElements(Row.of(Vectors.dense(2.0, 1.0))))
                        .as("features");

        try (CloseableIterator<Row> results =
                estimator.fit(trainTable).transform(predictTable)[0].execute().collect()) {
            results.next();
            // AssertionError is not an Exception, so the catch below does not swallow it.
            Assert.fail("Expected NullPointerException");
        } catch (Exception e) {
            Throwable rootCause = ExceptionUtils.getRootCause(e);
            StackTraceElement origin = rootCause.getStackTrace()[0];
            Assert.assertEquals(NaiveBayesModel.class.getName(), origin.getClassName());
            Assert.assertEquals("calculateProb", origin.getMethodName());
            Assert.assertEquals(NullPointerException.class, rootCause.getClass());
        }
    }

    @Test
    public void testVectorWithDiffLen() {
        trainTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Row.of(Vectors.dense(0.0, 0.0), 11.0),
                                                Row.of(Vectors.dense(1.0, 0.0), 10.0),
                                                Row.of(Vectors.dense(1.0), 10.0))))
                        .as("features", "label");
        assertFailsWithUnequalFeatureLengths();
    }

    @Test
    public void testVectorWithDiffLen2() {
        trainTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Row.of(Vectors.dense(0.0, 0.0), 11.0),
                                                Row.of(Vectors.dense(1.0), 10.0))))
                        .as("features", "label");
        assertFailsWithUnequalFeatureLengths();
    }

    @Test
    public void testSaveLoad() throws Exception {
        estimator =
                TestUtils.saveAndReload(
                        tEnv, estimator, tempFolder.newFolder().getAbsolutePath());
        NaiveBayesModel model =
                TestUtils.saveAndReload(
                        tEnv, estimator.fit(trainTable), tempFolder.newFolder().getAbsolutePath());
        Assert.assertEquals(
                expectedOutput,
                executeAndCollect(
                        model.transform(predictTable)[0],
                        model.getFeaturesCol(),
                        model.getPredictionCol()));
    }

    @Test
    public void testGetModelData() throws Exception {
        trainTable =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        Arrays.asList(
                                                Row.of(Vectors.dense(1.0, 1.0), 11.0),
                                                Row.of(Vectors.dense(2.0, 1.0), 11.0))))
                        .as("features", "label");

        NaiveBayesModel model = estimator.fit(trainTable);
        NaiveBayesModelData modelData;
        try (CloseableIterator<NaiveBayesModelData> it =
                NaiveBayesModelData.getModelDataStream(model.getModelData()[0])
                        .executeAndCollect()) {
            modelData = it.next();
        }

        Assert.assertArrayEquals(new double[] {11.0}, modelData.labels.toArray(), TOLERANCE);
        Assert.assertArrayEquals(new double[] {0.0}, modelData.piArray.toArray(), TOLERANCE);
        // -0.6931471805599453 is ln(0.5).
        Assert.assertEquals(
                -0.6931471805599453,
                ((Double) modelData.theta[0][0].get(1.0)).doubleValue(),
                TOLERANCE);
        Assert.assertEquals(
                -0.6931471805599453,
                ((Double) modelData.theta[0][0].get(2.0)).doubleValue(),
                TOLERANCE);
        Assert.assertEquals(
                0.0, ((Double) modelData.theta[0][1].get(1.0)).doubleValue(), TOLERANCE);
    }

    @Test
    public void testSetModelData() throws Exception {
        NaiveBayesModel model = estimator.fit(trainTable);
        NaiveBayesModel newModel = new NaiveBayesModel().setModelData(model.getModelData()[0]);
        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
        Assert.assertEquals(
                expectedOutput,
                executeAndCollect(
                        newModel.transform(predictTable)[0],
                        newModel.getFeaturesCol(),
                        newModel.getPredictionCol()));
    }
}
