package org.apache.flink.ml.pipeline;

import org.apache.flink.ml.common.MLEnvironment;
import org.apache.flink.ml.common.MLEnvironmentFactory;
import org.apache.flink.ml.operator.batch.BatchOperator;
import org.apache.flink.ml.operator.stream.StreamOperator;
import org.apache.flink.table.api.Table;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link EstimatorBase}: checks that fitting on a batch {@link Table} reaches only the
 * {@code fit(BatchOperator)} overload and fitting on a stream {@link Table} reaches only the
 * {@code fit(StreamOperator)} overload.
 *
 * <p>{@link #createPipelineStage()} supplies the stage for the tests declared in
 * {@link PipelineStageTestBase} (presumably used by those inherited tests — confirm in the base).
 */
public class EstimatorBaseTest extends PipelineStageTestBase {

    /**
     * Minimal estimator that only records which {@code fit} overload was invoked. Both overloads
     * return {@code null} because the tests inspect the flags, never the resulting model.
     */
    private static class FakeEstimator extends EstimatorBase {
        /** Set to {@code true} once {@link #fit(BatchOperator)} has been called. */
        boolean batchFitted = false;
        /** Set to {@code true} once {@link #fit(StreamOperator)} has been called. */
        boolean streamFitted = false;

        // @Override makes the compiler verify these really replace the base overloads that
        // fit(TableEnvironment, Table) dispatches to; otherwise the tests would silently break.
        @Override
        public ModelBase fit(BatchOperator input) {
            batchFitted = true;
            return null;
        }

        @Override
        public ModelBase fit(StreamOperator input) {
            streamFitted = true;
            return null;
        }
    }

    /** Returns a fresh {@link FakeEstimator} as the stage under test for the base-class tests. */
    @Override
    protected PipelineStageBase createPipelineStage() {
        return new FakeEstimator();
    }

    /** Fitting on a table from the batch table environment must invoke only the batch overload. */
    @Test
    public void testFitBatchTable() {
        Long id = MLEnvironmentFactory.getNewMLEnvironmentId();
        MLEnvironment env = MLEnvironmentFactory.get(id);
        Table table = env.getBatchTableEnvironment()
                .fromDataSet(env.getExecutionEnvironment().fromElements(1, 2, 3));

        FakeEstimator estimator = new FakeEstimator();
        // Bind the estimator to the same environment that produced the table.
        estimator.setMLEnvironmentId(id);
        estimator.fit(env.getBatchTableEnvironment(), table);

        Assert.assertTrue("batch fit should have been invoked", estimator.batchFitted);
        Assert.assertFalse("stream fit should not have been invoked", estimator.streamFitted);
    }

    /** Fitting on a table from the stream table environment must invoke only the stream overload. */
    @Test
    public void testFitStreamTable() {
        Long id = MLEnvironmentFactory.getNewMLEnvironmentId();
        MLEnvironment env = MLEnvironmentFactory.get(id);
        Table table = env.getStreamTableEnvironment()
                .fromDataStream(env.getStreamExecutionEnvironment().fromElements(1, 2, 3));

        FakeEstimator estimator = new FakeEstimator();
        // Bind the estimator to the same environment that produced the table.
        estimator.setMLEnvironmentId(id);
        estimator.fit(env.getStreamTableEnvironment(), table);

        Assert.assertFalse("batch fit should not have been invoked", estimator.batchFitted);
        Assert.assertTrue("stream fit should have been invoked", estimator.streamFitted);
    }

    /**
     * Delegates to {@link PipelineStageTestBase#testNullInputTable()}; expected to throw
     * {@link IllegalArgumentException}.
     *
     * <p>NOTE(review): this override and {@link #testMismatchTableEnvironment()} look like
     * decompiler-emitted bridge methods. If the base class already annotates these methods with
     * {@code @Test}, both overrides are redundant and can be removed — confirm against the base.
     */
    @Override
    @Test(expected = IllegalArgumentException.class)
    public void testNullInputTable() {
        super.testNullInputTable();
    }

    /**
     * Delegates to {@link PipelineStageTestBase#testMismatchTableEnvironment()}; expected to throw
     * {@link IllegalArgumentException}.
     */
    @Override
    @Test(expected = IllegalArgumentException.class)
    public void testMismatchTableEnvironment() {
        super.testMismatchTableEnvironment();
    }
}
