package org.apache.flink.ml.api;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Stream;
import org.apache.flink.ml.api.ExampleStages;
import org.apache.flink.ml.builder.Graph;
import org.apache.flink.ml.builder.GraphBuilder;
import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.TableId;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.test.util.AbstractTestBase;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/api/GraphTest.class */
/**
 * Tests for {@link Graph} and {@link GraphModel} stages assembled with {@link GraphBuilder}.
 *
 * <p>Each test runs the built stage and checks its output. It then saves the stage, loads it back
 * and checks that the reloaded stage produces the same output.
 */
public class GraphTest extends AbstractTestBase {
    private StreamExecutionEnvironment env;
    private StreamTableEnvironment tEnv;

    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
    }

    /**
     * Runs {@code stage} and checks its output, then saves it to a fresh temporary directory,
     * loads it back and checks the loaded stage against the same expectations.
     *
     * <p>The stage is reloaded as a {@link Graph} when {@code stage} is an {@link Estimator}, and
     * as a {@link GraphModel} otherwise. The temporary directory is deleted afterwards, whether or
     * not the checks pass.
     *
     * @param tableEnv table environment the stage's tables belong to
     * @param stage the stage under test
     * @param inputs input values, one list per input table
     * @param expectedOutput expected output values of the stage
     * @param modelDataInputs model data inputs; {@code null} in tests that do not exercise them
     * @param expectedModelDataOutput expected model data output; {@code null} in tests that do not
     *     exercise it
     * @param executeAfterSave whether to call {@code env.execute()} right after {@link Stage#save}.
     *     The tests below set it for graphs that carry model data. NOTE(review): presumably saving
     *     such data only registers jobs that run on execute(); confirm against {@code Stage#save}.
     * @throws Exception if execution, saving or loading fails, or an output check fails
     */
    private static void executeSaveLoadAndCheckOutput(
            StreamTableEnvironment tableEnv,
            Stage<?> stage,
            List<List<Integer>> inputs,
            List<Integer> expectedOutput,
            List<List<Integer>> modelDataInputs,
            List<Integer> expectedModelDataOutput,
            boolean executeAfterSave)
            throws Exception {
        StreamExecutionEnvironment execEnv = TableUtils.getExecutionEnvironment(tableEnv);

        // Checks the original stage before it is persisted.
        TestUtils.executeAndCheckOutput(
                execEnv, stage, inputs, expectedOutput, modelDataInputs, expectedModelDataOutput);

        Path saveDir = Files.createTempDirectory("");
        try {
            String path = saveDir.toString();
            stage.save(path);
            if (executeAfterSave) {
                execEnv.execute();
            }

            Stage<?> loadedStage;
            if (stage instanceof Estimator) {
                loadedStage = Graph.load(tableEnv, path);
            } else {
                loadedStage = GraphModel.load(tableEnv, path);
            }
            TestUtils.executeAndCheckOutput(
                    execEnv,
                    loadedStage,
                    inputs,
                    expectedOutput,
                    modelDataInputs,
                    expectedModelDataOutput);
        } finally {
            // Without this, every call leaks a directory under the system temp location.
            deleteQuietly(saveDir);
        }
    }

    /**
     * Deletes {@code root} and everything below it on a best-effort basis.
     *
     * <p>Failures are ignored. This method is called from a {@code finally} block, so a cleanup
     * error must never fail the test or replace the original assertion failure.
     */
    private static void deleteQuietly(Path root) {
        try (Stream<Path> paths = Files.walk(root)) {
            // A child path sorts after its parent, so reverse order deletes contents before dirs.
            paths.sorted(Comparator.reverseOrder()).forEach(p -> p.toFile().delete());
        } catch (IOException | UncheckedIOException ignored) {
            // Best effort: a leftover temporary directory does not affect test results.
        }
    }

    /** Returns a new mutable list holding the two input tables shared by several tests. */
    private static List<List<Integer>> twoInputTables() {
        List<List<Integer>> inputs = new ArrayList<>();
        inputs.add(Arrays.asList(1, 2, 3));
        inputs.add(Arrays.asList(10, 11, 12));
        return inputs;
    }

    @Test
    public void testGraphModelWithoutEstimator() throws Exception {
        GraphBuilder builder = new GraphBuilder();
        // NOTE(review): m1setModelData looks like a decompiler-renamed setModelData; kept as-is so
        // it matches the (also decompiled) ExampleStages.
        ExampleStages.SumModel model1 =
                new ExampleStages.SumModel().m1setModelData(tEnv.fromValues(2));
        ExampleStages.SumModel model2 =
                new ExampleStages.SumModel().m1setModelData(tEnv.fromValues(1));
        ExampleStages.UnionAlgoOperator union = new ExampleStages.UnionAlgoOperator();

        TableId input1 = builder.createTableId();
        TableId input2 = builder.createTableId();
        TableId output1 = builder.addAlgoOperator(model1, new TableId[] {input1})[0];
        TableId output2 = builder.addAlgoOperator(model2, new TableId[] {input2})[0];
        TableId unionOutput = builder.addAlgoOperator(union, new TableId[] {output1, output2})[0];
        Model<?> model =
                builder.buildModel(new TableId[] {input1, input2}, new TableId[] {unionOutput});

        executeSaveLoadAndCheckOutput(
                tEnv,
                model,
                twoInputTables(),
                Arrays.asList(3, 4, 5, 11, 12, 13),
                null,
                null,
                true);
    }

    @Test
    public void testGraphModelWithEstimator() throws Exception {
        GraphBuilder builder = new GraphBuilder();
        ExampleStages.SumEstimator estimator1 = new ExampleStages.SumEstimator();
        ExampleStages.SumEstimator estimator2 = new ExampleStages.SumEstimator();
        ExampleStages.UnionAlgoOperator union = new ExampleStages.UnionAlgoOperator();

        TableId input1 = builder.createTableId();
        TableId input2 = builder.createTableId();
        TableId output1 = builder.addEstimator(estimator1, new TableId[] {input1})[0];
        TableId output2 = builder.addEstimator(estimator2, new TableId[] {input2})[0];
        TableId unionOutput = builder.addAlgoOperator(union, new TableId[] {output1, output2})[0];
        Model<?> model =
                builder.buildModel(new TableId[] {input1, input2}, new TableId[] {unionOutput});

        executeSaveLoadAndCheckOutput(
                tEnv,
                model,
                twoInputTables(),
                Arrays.asList(7, 8, 9, 43, 44, 45),
                null,
                null,
                false);
    }

    @Test
    public void testGraphModelWithSetGetModelData() throws Exception {
        GraphBuilder builder = new GraphBuilder();
        ExampleStages.SumModel model1 =
                new ExampleStages.SumModel().m1setModelData(tEnv.fromValues(1));
        ExampleStages.SumModel model2 = new ExampleStages.SumModel();
        ExampleStages.SumModel model3 =
                new ExampleStages.SumModel().m1setModelData(tEnv.fromValues(3));

        TableId input = builder.createTableId();
        TableId modelDataInput = builder.createTableId();
        TableId output1 = builder.addAlgoOperator(model1, new TableId[] {input})[0];
        TableId output2 = builder.addAlgoOperator(model2, new TableId[] {output1})[0];
        // model2 is created without model data. Its data is bound to modelDataInput, which is
        // passed to buildModel below (presumably as the graph's model-data input).
        builder.setModelDataOnModel(model2, new TableId[] {modelDataInput});
        TableId output3 = builder.addAlgoOperator(model3, new TableId[] {output2})[0];
        TableId modelDataOutput = builder.getModelDataFromModel(model3)[0];
        Model<?> model =
                builder.buildModel(
                        new TableId[] {input},
                        new TableId[] {output3},
                        new TableId[] {modelDataInput},
                        new TableId[] {modelDataOutput});

        executeSaveLoadAndCheckOutput(
                tEnv,
                model,
                Collections.singletonList(Arrays.asList(1, 2, 3)),
                Arrays.asList(7, 8, 9),
                Collections.singletonList(Collections.singletonList(2)),
                Collections.singletonList(3),
                true);
    }

    @Test
    public void testGraphWithEstimator() throws Exception {
        GraphBuilder builder = new GraphBuilder();
        ExampleStages.SumEstimator estimator1 = new ExampleStages.SumEstimator();
        ExampleStages.SumEstimator estimator2 = new ExampleStages.SumEstimator();
        ExampleStages.UnionAlgoOperator union = new ExampleStages.UnionAlgoOperator();

        TableId input1 = builder.createTableId();
        TableId input2 = builder.createTableId();
        TableId output1 = builder.addEstimator(estimator1, new TableId[] {input1})[0];
        TableId output2 = builder.addEstimator(estimator2, new TableId[] {input2})[0];
        TableId unionOutput = builder.addAlgoOperator(union, new TableId[] {output1, output2})[0];
        Estimator<?, ?> graph =
                builder.buildEstimator(new TableId[] {input1, input2}, new TableId[] {unionOutput});

        executeSaveLoadAndCheckOutput(
                tEnv,
                graph,
                twoInputTables(),
                Arrays.asList(7, 8, 9, 43, 44, 45),
                null,
                null,
                false);
    }

    @Test
    public void testGraphWithSetGetModelData() throws Exception {
        GraphBuilder builder = new GraphBuilder();
        ExampleStages.SumEstimator estimator = new ExampleStages.SumEstimator();
        ExampleStages.SumModel model = new ExampleStages.SumModel();
        ExampleStages.UnionAlgoOperator union = new ExampleStages.UnionAlgoOperator();

        TableId input1 = builder.createTableId();
        TableId input2 = builder.createTableId();
        TableId estimatorOutput = builder.addEstimator(estimator, new TableId[] {input1})[0];
        TableId estimatorModelData = builder.getModelDataFromEstimator(estimator)[0];
        TableId modelOutput = builder.addAlgoOperator(model, new TableId[] {input2})[0];
        // The standalone model takes its data from the table the estimator exposes as model data.
        builder.setModelDataOnModel(model, new TableId[] {estimatorModelData});
        TableId unionOutput =
                builder.addAlgoOperator(union, new TableId[] {estimatorOutput, modelOutput})[0];
        Estimator<?, ?> graph =
                builder.buildEstimator(
                        new TableId[] {input1, input2},
                        new TableId[] {unionOutput},
                        (TableId[]) null,
                        new TableId[] {estimatorModelData});

        executeSaveLoadAndCheckOutput(
                tEnv,
                graph,
                twoInputTables(),
                Arrays.asList(7, 8, 9, 16, 17, 18),
                null,
                Collections.singletonList(6),
                true);
    }
}
