package org.apache.flink.ml.api;

import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.flink.ml.api.ExampleStages;
import org.apache.flink.ml.builder.Pipeline;
import org.apache.flink.ml.builder.PipelineModel;
import org.apache.flink.ml.servable.api.DataFrame;
import org.apache.flink.ml.servable.api.Row;
import org.apache.flink.ml.servable.types.DataTypes;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/flink/ml/api/PipelineTest.class */
public class PipelineTest extends AbstractTestBase {
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    @Rule
    public final TemporaryFolder tempFolder = new TemporaryFolder();

    @Before
    public void before() {
        this.env = TestUtils.getExecutionEnvironment();
        this.tEnv = StreamTableEnvironment.create(this.env);
    }

    @Test
    public void testPipelineModel() throws Exception {
        PipelineModel pipelineModel = new PipelineModel(Arrays.asList(new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{10})), new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{20})), new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{30}))));
        List singletonList = Collections.singletonList(Arrays.asList(1, 2, 3));
        List asList = Arrays.asList(61, 62, 63);
        TestUtils.executeAndCheckOutput(this.env, pipelineModel, singletonList, asList, null, null);
        String path = Files.createTempDirectory("", new FileAttribute[0]).toString();
        pipelineModel.save(path);
        this.env.execute();
        TestUtils.executeAndCheckOutput(this.env, PipelineModel.load(this.tEnv, path), singletonList, asList, null, null);
    }

    @Test
    public void testPipeline() throws Exception {
        Pipeline pipeline = new Pipeline(Arrays.asList(new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{10})), new ExampleStages.SumEstimator(), new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{30}))));
        List singletonList = Collections.singletonList(Arrays.asList(1, 2, 3));
        List asList = Arrays.asList(77, 78, 79);
        TestUtils.executeAndCheckOutput(this.env, pipeline, singletonList, asList, null, null);
        String path = Files.createTempDirectory("", new FileAttribute[0]).toString();
        pipeline.save(path);
        this.env.execute();
        TestUtils.executeAndCheckOutput(this.env, Pipeline.load(this.tEnv, path), singletonList, asList, null, null);
    }

    @Test
    public void testSupportServable() {
        Stage sumEstimator = new ExampleStages.SumEstimator();
        Stage unionAlgoOperator = new ExampleStages.UnionAlgoOperator();
        Stage sumModel = new ExampleStages.SumModel();
        Assert.assertTrue(new PipelineModel(Arrays.asList(sumModel, new ExampleStages.SumModel())).supportServable());
        Assert.assertFalse(new PipelineModel(Arrays.asList(sumEstimator, sumModel)).supportServable());
        Assert.assertFalse(new PipelineModel(Arrays.asList(unionAlgoOperator, sumModel)).supportServable());
    }

    @Test
    public void testPipelineModelServable() throws Exception {
        org.apache.flink.ml.servable.TestUtils.assertDataFrameEquals(new DataFrame(Collections.singletonList("input"), Collections.singletonList(DataTypes.INT), Arrays.asList(new Row(Collections.singletonList(61)), new Row(Collections.singletonList(62)), new Row(Collections.singletonList(63)))), TestUtils.saveAndLoadServable(this.tEnv, new PipelineModel(Arrays.asList(new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{10})), new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{20})), new ExampleStages.SumModel().m1setModelData(this.tEnv.fromValues(new Object[]{30})))), this.tempFolder.newFolder().getAbsolutePath(), PipelineModel::loadServable).transform(new DataFrame(Collections.singletonList("input"), Collections.singletonList(DataTypes.INT), Arrays.asList(new Row(Collections.singletonList(1)), new Row(Collections.singletonList(2)), new Row(Collections.singletonList(3))))));
    }
}
