package org.apache.flink.ml.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.RestOptions;
import org.apache.flink.metrics.Gauge;
import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.InMemorySinkFunction;
import org.apache.flink.ml.util.InMemorySourceFunction;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.runtime.client.JobStatusMessage;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
import org.apache.flink.runtime.testutils.InMemoryReporter;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.TestLogger;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/classification/OnlineLogisticRegressionTest.class */
/**
 * Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}.
 *
 * <p>Streaming tests submit their job graph to a shared {@link MiniCluster}, feed training and
 * prediction rows through {@link InMemorySourceFunction}s, and poll the {@code modelDataVersion}
 * metric exposed through an {@link InMemoryReporter} to detect when the model has been updated.
 */
public class OnlineLogisticRegressionTest extends TestLogger {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private static final int NUM_TASK_MANAGERS = 2;
    private static final int NUM_SLOTS_PER_TASK_MANAGER = 2;

    /** Job parallelism; matches the total number of slots offered by the mini cluster. */
    private static final int DEFAULT_PARALLELISM = NUM_TASK_MANAGERS * NUM_SLOTS_PER_TASK_MANAGER;

    /** Absolute tolerance used when comparing floating-point results. */
    private static final double TOLERANCE = 1e-5;

    /** Dimension of every sparse feature vector (and of the sparse initial model). */
    private static final int SPARSE_VECTOR_SIZE = 10;

    private static final double[] ONE_ARRAY = {1.0, 1.0, 1.0};

    private static final Row[] TRAIN_DENSE_ROWS_1 = {
        Row.of(Vectors.dense(0.1, 2.0), 0.0),
        Row.of(Vectors.dense(0.2, 2.0), 0.0),
        Row.of(Vectors.dense(0.3, 2.0), 0.0),
        Row.of(Vectors.dense(0.4, 2.0), 0.0),
        Row.of(Vectors.dense(0.5, 2.0), 0.0),
        Row.of(Vectors.dense(11.0, 12.0), 1.0),
        Row.of(Vectors.dense(12.0, 11.0), 1.0),
        Row.of(Vectors.dense(13.0, 12.0), 1.0),
        Row.of(Vectors.dense(14.0, 12.0), 1.0),
        Row.of(Vectors.dense(15.0, 12.0), 1.0)
    };

    private static final Row[] TRAIN_DENSE_ROWS_2 = {
        Row.of(Vectors.dense(0.2, 3.0), 0.0),
        Row.of(Vectors.dense(0.8, 1.0), 0.0),
        Row.of(Vectors.dense(0.7, 1.0), 0.0),
        Row.of(Vectors.dense(0.6, 2.0), 0.0),
        Row.of(Vectors.dense(0.2, 2.0), 0.0),
        Row.of(Vectors.dense(14.0, 17.0), 1.0),
        Row.of(Vectors.dense(15.0, 10.0), 1.0),
        Row.of(Vectors.dense(16.0, 16.0), 1.0),
        Row.of(Vectors.dense(17.0, 10.0), 1.0),
        Row.of(Vectors.dense(18.0, 13.0), 1.0)
    };

    private static final Row[] PREDICT_DENSE_ROWS = {
        Row.of(Vectors.dense(0.8, 2.7), 0.0), Row.of(Vectors.dense(15.5, 11.2), 1.0)
    };

    // Sparse rows: (features, label, weight).
    private static final Row[] TRAIN_SPARSE_ROWS_1 = {
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {1, 3, 4}, ONE_ARRAY), 0.0, 1.0),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {0, 2, 3}, ONE_ARRAY), 0.0, 1.4),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {0, 3, 4}, ONE_ARRAY), 0.0, 1.3),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {2, 3, 4}, ONE_ARRAY), 0.0, 1.4),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {1, 3, 4}, ONE_ARRAY), 0.0, 1.6),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {6, 7, 8}, ONE_ARRAY), 1.0, 1.8),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {6, 8, 9}, ONE_ARRAY), 1.0, 1.9),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {5, 8, 9}, ONE_ARRAY), 1.0, 1.0),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {5, 6, 7}, ONE_ARRAY), 1.0, 1.1)
    };

    private static final Row[] TRAIN_SPARSE_ROWS_2 = {
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {1, 2, 4}, ONE_ARRAY), 0.0, 1.0),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {2, 3, 4}, ONE_ARRAY), 0.0, 1.3),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {0, 2, 4}, ONE_ARRAY), 0.0, 1.4),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {1, 3, 4}, ONE_ARRAY), 0.0, 1.0),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {6, 7, 9}, ONE_ARRAY), 1.0, 1.6),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {7, 8, 9}, ONE_ARRAY), 1.0, 1.8),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {5, 7, 9}, ONE_ARRAY), 1.0, 1.0),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {5, 6, 7}, ONE_ARRAY), 1.0, 1.5),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {5, 8, 9}, ONE_ARRAY), 1.0, 1.0)
    };

    private static final Row[] PREDICT_SPARSE_ROWS = {
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {1, 3, 5}, ONE_ARRAY), 0.0),
        Row.of(Vectors.sparse(SPARSE_VECTOR_SIZE, new int[] {5, 8, 9}, ONE_ARRAY), 1.0)
    };

    private static InMemoryReporter reporter;
    private static MiniCluster miniCluster;
    private static StreamExecutionEnvironment env;
    private static StreamTableEnvironment tEnv;

    /** Minimum {@code modelDataVersion} observed by the last successful wait. */
    private long currentModelDataVersion;

    private InMemorySourceFunction<Row> trainDenseSource;
    private InMemorySourceFunction<Row> predictDenseSource;
    private InMemorySourceFunction<Row> trainSparseSource;
    private InMemorySourceFunction<Row> predictSparseSource;
    private InMemorySinkFunction<Row> outputSink;
    private InMemorySinkFunction<LogisticRegressionModelData> modelDataSink;

    private Table offlineTrainDenseTable;
    private Table onlineTrainDenseTable;
    private Table onlinePredictDenseTable;
    private Table onlineTrainSparseTable;
    private Table onlinePredictSparseTable;
    private Table initDenseModel;
    private Table initSparseModel;

    /** Starts the shared mini cluster (with an in-memory metric reporter) and environments. */
    @BeforeClass
    public static void beforeClass() throws Exception {
        Configuration configuration = new Configuration();
        configuration.set(RestOptions.BIND_PORT, "18081-19091");
        configuration.set(
                ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        reporter = InMemoryReporter.create();
        reporter.addToConfiguration(configuration);

        miniCluster =
                new MiniCluster(
                        new MiniClusterConfiguration.Builder()
                                .setConfiguration(configuration)
                                .setNumTaskManagers(NUM_TASK_MANAGERS)
                                .setNumSlotsPerTaskManager(NUM_SLOTS_PER_TASK_MANAGER)
                                .build());
        miniCluster.start();

        env = StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.getConfig().enableObjectReuse();
        env.setParallelism(DEFAULT_PARALLELISM);
        env.enableCheckpointing(100);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);
    }

    /** Creates fresh in-memory sources/sinks and the tables built on top of them. */
    @Before
    public void before() throws Exception {
        currentModelDataVersion = 0;

        trainDenseSource = new InMemorySourceFunction<>();
        predictDenseSource = new InMemorySourceFunction<>();
        trainSparseSource = new InMemorySourceFunction<>();
        predictSparseSource = new InMemorySourceFunction<>();
        outputSink = new InMemorySinkFunction<>();
        modelDataSink = new InMemorySinkFunction<>();

        RowTypeInfo denseRowType =
                new RowTypeInfo(
                        new TypeInformation<?>[] {
                            TypeInformation.of(DenseVector.class), Types.DOUBLE
                        },
                        new String[] {"features", "label"});
        RowTypeInfo sparseRowType =
                new RowTypeInfo(
                        new TypeInformation<?>[] {
                            TypeInformation.of(SparseVector.class), Types.DOUBLE
                        },
                        new String[] {"features", "label"});
        RowTypeInfo weightedSparseRowType =
                new RowTypeInfo(
                        new TypeInformation<?>[] {
                            TypeInformation.of(SparseVector.class), Types.DOUBLE, Types.DOUBLE
                        },
                        new String[] {"features", "label", "weight"});

        offlineTrainDenseTable =
                tEnv.fromDataStream(env.fromElements(TRAIN_DENSE_ROWS_1)).as("features", "label");
        onlineTrainDenseTable =
                tEnv.fromDataStream(env.addSource(trainDenseSource, denseRowType));
        onlinePredictDenseTable =
                tEnv.fromDataStream(env.addSource(predictDenseSource, denseRowType));
        onlineTrainSparseTable =
                tEnv.fromDataStream(env.addSource(trainSparseSource, weightedSparseRowType));
        onlinePredictSparseTable =
                tEnv.fromDataStream(env.addSource(predictSparseSource, sparseRowType));

        initDenseModel =
                tEnv.fromDataStream(
                        env.fromElements(
                                Row.of(
                                        new DenseVector(
                                                new double[] {
                                                    0.41233679404769874, -0.18088118293232122
                                                }),
                                        0L)));

        // One coefficient of 0.01 per sparse feature dimension.
        double[] initSparseCoefficients = new double[SPARSE_VECTOR_SIZE];
        Arrays.fill(initSparseCoefficients, 0.01);
        initSparseModel =
                tEnv.fromDataStream(
                        env.fromElements(Row.of(new DenseVector(initSparseCoefficients), 0L)));
    }

    /** Cancels every job known to the mini cluster so tests do not leak running jobs. */
    @After
    public void after() throws Exception {
        for (JobStatusMessage job : miniCluster.listJobs().get()) {
            miniCluster.cancelJob(job.getJobId());
        }
    }

    @AfterClass
    public static void afterClass() throws Exception {
        miniCluster.close();
    }

    /**
     * Routes the model's predictions on the dense or sparse prediction table into {@link
     * #outputSink}, and its model data stream into {@link #modelDataSink}.
     */
    private void transformAndOutputData(OnlineLogisticRegressionModel model, boolean isSparse) {
        Table predictTable = isSparse ? onlinePredictSparseTable : onlinePredictDenseTable;
        tEnv.toDataStream(model.transform(predictTable)[0]).addSink(outputSink);
        LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                .addSink(modelDataSink);
    }

    /**
     * Blocks until at least {@link #DEFAULT_PARALLELISM} {@code modelDataVersion} gauges are
     * registered for the job (presumably one per subtask), then waits for a version change.
     */
    private void waitInitModelDataSetup(JobID jobId) throws InterruptedException {
        while (reporter.findMetrics(jobId, "modelDataVersion").size() < DEFAULT_PARALLELISM) {
            Thread.sleep(100);
        }
        waitModelDataUpdate(jobId);
    }

    /**
     * Polls until the minimum {@code modelDataVersion} across all registered gauges differs from
     * {@link #currentModelDataVersion}, then records the new minimum.
     *
     * @throws IllegalStateException if no {@code modelDataVersion} gauge is registered for the job
     */
    private void waitModelDataUpdate(JobID jobId) throws InterruptedException {
        while (true) {
            long minVersion =
                    reporter.findMetrics(jobId, "modelDataVersion").values().stream()
                            .mapToLong(
                                    metric ->
                                            Long.parseLong(
                                                    (String) ((Gauge<?>) metric).getValue()))
                            .min()
                            .orElseThrow(
                                    () ->
                                            new IllegalStateException(
                                                    "No modelDataVersion metric registered for job "
                                                            + jobId));
            if (minVersion != currentModelDataVersion) {
                currentModelDataVersion = minVersion;
                return;
            }
            Thread.sleep(100);
        }
    }

    /**
     * Pushes the prediction rows into the matching source, collects the same number of outputs,
     * and compares their raw predictions with {@code expected} (order-insensitive).
     *
     * @param expected expected raw prediction vectors; not modified
     * @param isSparse whether to use the sparse prediction source instead of the dense one
     */
    private void predictAndAssert(List<DenseVector> expected, boolean isSparse) throws Exception {
        Row[] predictRows = isSparse ? PREDICT_SPARSE_ROWS : PREDICT_DENSE_ROWS;
        if (isSparse) {
            predictSparseSource.addAll(predictRows);
        } else {
            predictDenseSource.addAll(predictRows);
        }

        List<Row> results = outputSink.poll(predictRows.length);
        List<DenseVector> actual = new ArrayList<>(results.size());
        for (Row row : results) {
            // Field 3 is read as the raw prediction vector (appended after the input columns
            // and the prediction column).
            DenseVector rawPrediction = row.getFieldAs(3);
            actual.add(rawPrediction);
        }

        // Output order is not deterministic across subtasks, so compare sorted copies.
        List<DenseVector> sortedExpected = new ArrayList<>(expected);
        sortedExpected.sort(TestUtils::compare);
        actual.sort(TestUtils::compare);

        Assert.assertEquals(sortedExpected.size(), actual.size());
        for (int i = 0; i < actual.size(); i++) {
            Assert.assertArrayEquals(
                    sortedExpected.get(i).values, actual.get(i).values, TOLERANCE);
        }
    }

    /** Submits the job graph to the mini cluster and returns its id (1 second timeout). */
    private JobID submitJob(JobGraph jobGraph)
            throws ExecutionException, InterruptedException, TimeoutException {
        return miniCluster
                .submitJob(jobGraph)
                .thenApply(result -> result.getJobID())
                .get(1, TimeUnit.SECONDS);
    }

    @Test
    public void testParam() {
        OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression();
        Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol());
        Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy());
        Assert.assertEquals("label", onlineLogisticRegression.getLabelCol());
        Assert.assertEquals(0.0, onlineLogisticRegression.getReg(), TOLERANCE);
        Assert.assertEquals(0.0, onlineLogisticRegression.getElasticNet(), TOLERANCE);
        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha().doubleValue(), TOLERANCE);
        Assert.assertEquals(0.1, onlineLogisticRegression.getBeta().doubleValue(), TOLERANCE);
        Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize());

        onlineLogisticRegression
                .setFeaturesCol("test_feature")
                .setLabelCol("test_label")
                .setGlobalBatchSize(5)
                .setReg(0.5)
                .setElasticNet(0.25)
                .setAlpha(0.1)
                .setBeta(0.2);

        Assert.assertEquals("test_feature", onlineLogisticRegression.getFeaturesCol());
        Assert.assertEquals("test_label", onlineLogisticRegression.getLabelCol());
        Assert.assertEquals(0.5, onlineLogisticRegression.getReg(), TOLERANCE);
        Assert.assertEquals(0.25, onlineLogisticRegression.getElasticNet(), TOLERANCE);
        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha().doubleValue(), TOLERANCE);
        Assert.assertEquals(0.2, onlineLogisticRegression.getBeta().doubleValue(), TOLERANCE);
        Assert.assertEquals(5, onlineLogisticRegression.getGlobalBatchSize());

        OnlineLogisticRegressionModel model = new OnlineLogisticRegressionModel();
        Assert.assertEquals("features", model.getFeaturesCol());
        Assert.assertEquals("modelVersion", model.getModelVersionCol());
        Assert.assertEquals("prediction", model.getPredictionCol());
        Assert.assertEquals("rawPrediction", model.getRawPredictionCol());

        model.setFeaturesCol("test_feature")
                .setPredictionCol("pred")
                .setModelVersionCol("version")
                .setRawPredictionCol("raw");

        Assert.assertEquals("test_feature", model.getFeaturesCol());
        Assert.assertEquals("version", model.getModelVersionCol());
        Assert.assertEquals("pred", model.getPredictionCol());
        Assert.assertEquals("raw", model.getRawPredictionCol());
    }

    @Test
    public void testDenseFitAndPredict() throws Exception {
        List<DenseVector> expectedRound1 =
                Arrays.asList(
                        new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}),
                        new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509}));
        List<DenseVector> expectedRound2 =
                Arrays.asList(
                        new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}),
                        new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307}));

        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(initDenseModel)
                        .fit(onlineTrainDenseTable);
        transformAndOutputData(model, false);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(expectedRound1, false);

        trainDenseSource.addAll(TRAIN_DENSE_ROWS_2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedRound2, false);
    }

    @Test
    public void testSparseFitAndPredict() throws Exception {
        List<DenseVector> expectedRound1 =
                Arrays.asList(
                        new DenseVector(new double[] {0.4452309884735286, 0.5547690115264714}),
                        new DenseVector(new double[] {0.5105551725414953, 0.4894448274585047}));
        List<DenseVector> expectedRound2 =
                Arrays.asList(
                        new DenseVector(new double[] {0.40310431554310666, 0.5968956844568933}),
                        new DenseVector(new double[] {0.5249618837373886, 0.4750381162626114}));

        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(9)
                        .setInitialModelData(initSparseModel)
                        .fit(onlineTrainSparseTable);
        transformAndOutputData(model, true);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(expectedRound1, true);

        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedRound2, true);
    }

    @Test
    public void testFitAndPredictWithWeightCol() throws Exception {
        List<DenseVector> expectedRound1 =
                Arrays.asList(
                        new DenseVector(new double[] {0.452491993753382, 0.547508006246618}),
                        new DenseVector(new double[] {0.5069192929506545, 0.4930807070493455}));
        List<DenseVector> expectedRound2 =
                Arrays.asList(
                        new DenseVector(new double[] {0.41108882806164193, 0.5889111719383581}),
                        new DenseVector(new double[] {0.5247727600974581, 0.4752272399025419}));

        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setWeightCol("weight")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(9)
                        .setInitialModelData(initSparseModel)
                        .fit(onlineTrainSparseTable);
        transformAndOutputData(model, true);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(expectedRound1, true);

        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedRound2, true);
    }

    @Test
    public void testGenerateRandomModelData() throws Exception {
        Table randomModelData = LogisticRegressionModelDataUtil.generateRandomModelData(tEnv, 2, 2022);
        Row row = (Row) IteratorUtils.toList(tEnv.toDataStream(randomModelData).executeAndCollect()).get(0);
        Assert.assertEquals(2, ((DenseVector) row.getField(0)).size());
        Assert.assertEquals(0L, row.getField(1));
    }

    @Test
    public void testInitWithLogisticRegression() throws Exception {
        List<DenseVector> expectedRound1 =
                Arrays.asList(
                        new DenseVector(new double[] {0.037327343811250024, 0.96267265618875}),
                        new DenseVector(new double[] {0.5684728224189707, 0.4315271775810293}));
        List<DenseVector> expectedRound2 =
                Arrays.asList(
                        new DenseVector(new double[] {0.007758574555505882, 0.9922414254444941}),
                        new DenseVector(new double[] {0.5257216567388069, 0.4742783432611931}));

        // Initial model data comes from an offline (batch) LogisticRegression fit.
        Table offlineModelData =
                new LogisticRegression()
                        .setLabelCol("label")
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .fit(offlineTrainDenseTable)
                        .getModelData()[0];

        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(offlineModelData)
                        .fit(onlineTrainDenseTable);
        transformAndOutputData(model, false);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(expectedRound1, false);

        trainDenseSource.addAll(TRAIN_DENSE_ROWS_2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedRound2, false);
    }

    @Test
    public void testBatchSizeLessThanParallelism() {
        try {
            new OnlineLogisticRegression()
                    .setInitialModelData(initDenseModel)
                    .setReg(0.2)
                    .setElasticNet(0.5)
                    .setGlobalBatchSize(2)
                    .setLabelCol("label")
                    .fit(onlineTrainDenseTable);
            // AssertionError is an Error, so it is not swallowed by the catch below.
            Assert.fail("Expected IllegalStateException");
        } catch (Exception e) {
            Throwable rootCause = e;
            while (rootCause.getCause() != null) {
                rootCause = rootCause.getCause();
            }
            Assert.assertEquals(IllegalStateException.class, rootCause.getClass());
            Assert.assertEquals(
                    "There are more subtasks in the training process than the number of elements in each batch. Some subtasks might be idling forever.",
                    rootCause.getMessage());
        }
    }

    @Test
    public void testSaveAndReload() throws Exception {
        List<DenseVector> expectedRound1 =
                Arrays.asList(
                        new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}),
                        new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509}));
        List<DenseVector> expectedRound2 =
                Arrays.asList(
                        new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}),
                        new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307}));

        OnlineLogisticRegression estimator =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(initDenseModel);

        String estimatorPath = tempFolder.newFolder().getAbsolutePath();
        estimator.save(estimatorPath);
        // Run whatever save() added to the stream graph before building the next job.
        // NOTE(review): presumably this writes the initial model data table to disk — confirm.
        miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());

        OnlineLogisticRegressionModel fittedModel =
                OnlineLogisticRegression.load(tEnv, estimatorPath).fit(onlineTrainDenseTable);
        String modelPath = tempFolder.newFolder().getAbsolutePath();
        fittedModel.save(modelPath);

        OnlineLogisticRegressionModel loadedModel =
                OnlineLogisticRegressionModel.load(tEnv, modelPath);
        loadedModel.setModelData(fittedModel.getModelData());
        transformAndOutputData(loadedModel, false);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(expectedRound1, false);

        trainDenseSource.addAll(TRAIN_DENSE_ROWS_2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedRound2, false);
    }

    @Test
    public void testGetModelData() throws Exception {
        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(initDenseModel)
                        .fit(onlineTrainDenseTable);
        transformAndOutputData(model, false);

        submitJob(env.getStreamGraph().getJobGraph());
        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);

        LogisticRegressionModelData actual = modelDataSink.poll();
        LogisticRegressionModelData expected =
                new LogisticRegressionModelData(
                        new DenseVector(new double[] {0.2994527071464283, -0.1412541067743284}),
                        1L);
        Assert.assertArrayEquals(
                expected.coefficient.values, actual.coefficient.values, TOLERANCE);
        Assert.assertEquals(expected.modelVersion, actual.modelVersion);
    }

    @Test
    public void testSetModelData() throws Exception {
        LogisticRegressionModelData modelData1 =
                new LogisticRegressionModelData(new DenseVector(new double[] {0.085, -0.22}), 1L);
        LogisticRegressionModelData modelData2 =
                new LogisticRegressionModelData(new DenseVector(new double[] {0.075, -0.28}), 2L);
        List<DenseVector> expectedRound1 =
                Arrays.asList(
                        new DenseVector(new double[] {0.6285496932692606, 0.3714503067307394}),
                        new DenseVector(new double[] {0.7588710471221473, 0.24112895287785274}));
        List<DenseVector> expectedRound2 =
                Arrays.asList(
                        new DenseVector(new double[] {0.6673003248270917, 0.3326996751729083}),
                        new DenseVector(new double[] {0.8779865510655934, 0.12201344893440658}));

        InMemorySourceFunction<LogisticRegressionModelData> modelDataSource =
                new InMemorySourceFunction<>();
        Table modelDataTable =
                tEnv.fromDataStream(
                        env.addSource(
                                modelDataSource,
                                TypeInformation.of(LogisticRegressionModelData.class)));
        OnlineLogisticRegressionModel model =
                new OnlineLogisticRegressionModel()
                        .setModelData(modelDataTable)
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction");
        transformAndOutputData(model, false);

        JobID jobId = submitJob(env.getStreamGraph().getJobGraph());
        modelDataSource.addAll(modelData1);
        waitInitModelDataSetup(jobId);
        predictAndAssert(expectedRound1, false);

        modelDataSource.addAll(modelData2);
        waitModelDataUpdate(jobId);
        predictAndAssert(expectedRound2, false);
    }
}
