package org.apache.flink.ml.pipeline;

import org.apache.flink.ml.api.core.Estimator;
import org.apache.flink.ml.api.core.Transformer;
import org.apache.flink.ml.common.MLEnvironment;
import org.apache.flink.ml.common.MLEnvironmentFactory;
import org.apache.flink.table.api.Table;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/pipeline/PipelineStageTestBase.class */
/**
 * Shared negative tests for pipeline stages. Subclasses supply the stage under test via
 * {@link #createPipelineStage()}; each test here expects an {@link IllegalArgumentException}.
 *
 * <p>A stage is driven through {@code fit} when it is an {@link Estimator} and through
 * {@code transform} otherwise, in which case it is cast to {@link Transformer}.
 */
abstract class PipelineStageTestBase {

    /**
     * Binds the stage to a freshly created ML environment, then invokes it with the
     * <em>default</em> environment's batch table environment on a table built in the stage's
     * own environment. The mismatch is expected to raise {@link IllegalArgumentException}.
     */
    @Test(expected = IllegalArgumentException.class)
    public void testMismatchTableEnvironment() {
        Long stageEnvId = MLEnvironmentFactory.getNewMLEnvironmentId();
        MLEnvironment stageEnv = MLEnvironmentFactory.get(stageEnvId);
        // The input table lives in the stage's own (new) environment...
        Table input = stageEnv.getBatchTableEnvironment()
                .fromDataSet(stageEnv.getExecutionEnvironment().fromElements(1, 2, 3));

        PipelineStageBase stage = createPipelineStage();
        stage.setMLEnvironmentId(stageEnvId);
        // ...but the call is made with the default environment's table environment.
        fitOrTransform(stage, MLEnvironmentFactory.getDefault(), input);
    }

    /**
     * Invokes the stage with a {@code null} input table in its own environment and expects
     * {@link IllegalArgumentException}.
     */
    @Test(expected = IllegalArgumentException.class)
    public void testNullInputTable() {
        Long stageEnvId = MLEnvironmentFactory.getNewMLEnvironmentId();
        MLEnvironment stageEnv = MLEnvironmentFactory.get(stageEnvId);

        PipelineStageBase stage = createPipelineStage();
        stage.setMLEnvironmentId(stageEnvId);
        fitOrTransform(stage, stageEnv, (Table) null);
    }

    /**
     * Runs {@code stage} on {@code input} using {@code env}'s batch table environment.
     *
     * <p>Calls {@code fit} if the stage is an {@link Estimator}; otherwise casts it to
     * {@link Transformer} and calls {@code transform}. Both tests use this single check so
     * that they agree on how a stage is exercised.
     *
     * @param stage the stage under test
     * @param env the ML environment whose batch table environment is passed to the stage
     * @param input the input table (may be {@code null} for the null-input test)
     */
    private static void fitOrTransform(PipelineStageBase stage, MLEnvironment env, Table input) {
        if (stage instanceof Estimator) {
            ((Estimator) stage).fit(env.getBatchTableEnvironment(), input);
        } else {
            ((Transformer) stage).transform(env.getBatchTableEnvironment(), input);
        }
    }

    /**
     * Creates the stage under test.
     *
     * @return a new stage instance; presumably an {@link Estimator} or a {@link Transformer}
     *     (TODO confirm with subclasses)
     */
    protected abstract PipelineStageBase createPipelineStage();
}
