package org.apache.flink.ml.builder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.GraphNode;

/**
 * Incrementally assembles a graph of stages and then wraps the collected nodes into a
 * {@link Graph} (an {@link Estimator}) or a {@link GraphModel} (a {@link Model}).
 *
 * <p>Stages are connected through {@link TableId}s handed out by this builder. Each stage may be
 * added at most once; membership is tracked in a {@link HashMap} keyed by the stage instance, so
 * it follows the stage's {@code equals}/{@code hashCode}.
 *
 * <p>This class performs no synchronization of its own.
 */
@PublicEvolving
public final class GraphBuilder {
    /** Number of {@link TableId}s allocated for every added stage and every model-data fetch. */
    private int maxOutputLength = 20;

    /** Value handed to the next {@link TableId} created by this builder. */
    private int nextTableId = 0;

    /** Id handed to the next {@link GraphNode} created by this builder. */
    private int nextNodeId = 0;

    /** Nodes in the order in which their stages were added. */
    private final List<GraphNode> nodes = new ArrayList<>();

    /** Maps every stage that has been added to the node created for it. */
    private final Map<Stage<?>, GraphNode> existingNodes = new HashMap<>();

    /**
     * Sets how many {@link TableId}s are allocated for each stage added afterwards and for each
     * subsequent model-data fetch. The default is 20.
     *
     * <p>NOTE(review): judging by the method name this is meant as an upper bound on the number of
     * tables a stage really outputs -- confirm against the consumers of the generated ids.
     *
     * @param i the number of ids to allocate; must not be negative
     * @return this builder, to allow call chaining
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public GraphBuilder setMaxOutputTableNum(int i) {
        // Fail here rather than with a NegativeArraySizeException on the next allocation.
        if (i < 0) {
            throw new IllegalArgumentException(
                    "maxOutputTableNum must be non-negative, but was " + i);
        }
        this.maxOutputLength = i;
        return this;
    }

    /**
     * Creates a new {@link TableId} from this builder's running counter.
     *
     * @return a {@link TableId} built from the next unused counter value
     */
    public TableId createTableId() {
        return new TableId(nextTableId++);
    }

    /**
     * Adds the given {@link AlgoOperator} to the graph as a node of type
     * {@link GraphNode.StageType#ALGO_OPERATOR}.
     *
     * @param algoOperator the stage to add
     * @param tableIdArr the ids of the tables fed to the stage
     * @return the newly allocated ids associated with the new node
     * @throws RuntimeException if the stage has already been added to this builder
     */
    public TableId[] addAlgoOperator(AlgoOperator<?> algoOperator, TableId... tableIdArr) {
        return addStage(algoOperator, GraphNode.StageType.ALGO_OPERATOR, null, tableIdArr);
    }

    /**
     * Adds the given {@link Estimator} to the graph, using the same ids for both of the node's
     * input arrays.
     *
     * @param estimator the stage to add
     * @param tableIdArr the ids used for both input arrays
     * @return the newly allocated ids associated with the new node
     * @throws RuntimeException if the stage has already been added to this builder
     */
    public TableId[] addEstimator(Estimator<?, ?> estimator, TableId... tableIdArr) {
        return addEstimator(estimator, tableIdArr, tableIdArr);
    }

    /**
     * Adds the given {@link Estimator} to the graph as a node of type
     * {@link GraphNode.StageType#ESTIMATOR} with two separately specified input arrays.
     *
     * <p>NOTE(review): presumably {@code tableIdArr} feeds the Estimator itself and
     * {@code tableIdArr2} feeds the Model it produces -- confirm against {@link GraphNode}.
     *
     * @param estimator the stage to add
     * @param tableIdArr the first input-id array of the node
     * @param tableIdArr2 the second input-id array of the node
     * @return the newly allocated ids associated with the new node
     * @throws RuntimeException if the stage has already been added to this builder
     */
    public TableId[] addEstimator(Estimator<?, ?> estimator, TableId[] tableIdArr, TableId[] tableIdArr2) {
        return addStage(estimator, GraphNode.StageType.ESTIMATOR, tableIdArr, tableIdArr2);
    }

    /**
     * Records the given ids as the input model data of a previously added {@link Estimator}.
     *
     * @param estimator an Estimator that was added to this builder
     * @param tableIdArr the ids to store as the node's input model data
     * @throws RuntimeException if the Estimator was never added, was not added as an Estimator,
     *     or already has input model data
     */
    public void setModelDataOnEstimator(Estimator<?, ?> estimator, TableId... tableIdArr) {
        GraphNode node =
                requireNode(estimator, GraphNode.StageType.ESTIMATOR, "Estimator", "AlgoOperator");
        if (node.inputModelDataIds != null) {
            throw new RuntimeException("the model data of this Estimator has already been set");
        }
        node.inputModelDataIds = tableIdArr;
    }

    /**
     * Records the given ids as the input model data of a previously added {@link Model}.
     *
     * @param model a Model that was added to this builder as an AlgoOperator
     * @param tableIdArr the ids to store as the node's input model data
     * @throws RuntimeException if the Model was never added, was added as an Estimator, or
     *     already has input model data
     */
    public void setModelDataOnModel(Model<?> model, TableId... tableIdArr) {
        GraphNode node =
                requireNode(model, GraphNode.StageType.ALGO_OPERATOR, "Model", "Estimator");
        if (node.inputModelDataIds != null) {
            throw new RuntimeException("the model data of this Model has already been set");
        }
        node.inputModelDataIds = tableIdArr;
    }

    /**
     * Allocates the output model data ids of a previously added {@link Estimator}. This may be
     * done only once per Estimator.
     *
     * @param estimator an Estimator that was added to this builder
     * @return the newly allocated ids, which are also stored on the node
     * @throws RuntimeException if the Estimator was never added, was not added as an Estimator,
     *     or its model data has already been fetched
     */
    public TableId[] getModelDataFromEstimator(Estimator<?, ?> estimator) {
        GraphNode node =
                requireNode(estimator, GraphNode.StageType.ESTIMATOR, "Estimator", "AlgoOperator");
        if (node.outputModelDataIds != null) {
            throw new RuntimeException("the model data of this Estimator has already been fetched");
        }
        node.outputModelDataIds = createTableIds(maxOutputLength);
        return node.outputModelDataIds;
    }

    /**
     * Allocates the output model data ids of a previously added {@link Model}. This may be done
     * only once per Model.
     *
     * @param model a Model that was added to this builder as an AlgoOperator
     * @return the newly allocated ids, which are also stored on the node
     * @throws RuntimeException if the Model was never added, was added as an Estimator, or its
     *     model data has already been fetched
     */
    public TableId[] getModelDataFromModel(Model<?> model) {
        GraphNode node =
                requireNode(model, GraphNode.StageType.ALGO_OPERATOR, "Model", "Estimator");
        if (node.outputModelDataIds != null) {
            throw new RuntimeException("the model data of this Model has already been fetched");
        }
        node.outputModelDataIds = createTableIds(maxOutputLength);
        return node.outputModelDataIds;
    }

    /**
     * Builds an {@link Estimator} from the nodes added so far, using {@code tableIdArr} for both
     * input arrays and no model data ids.
     *
     * @param tableIdArr the ids used for both input arrays of the graph
     * @param tableIdArr2 the third id array passed to the {@link Graph} (presumably its outputs)
     * @return a new {@link Graph}
     */
    public Estimator<?, ?> buildEstimator(TableId[] tableIdArr, TableId[] tableIdArr2) {
        return buildEstimator(tableIdArr, tableIdArr, tableIdArr2, null, null);
    }

    /**
     * Builds an {@link Estimator} from the nodes added so far, using {@code tableIdArr} for both
     * input arrays.
     *
     * @param tableIdArr the ids used for both input arrays of the graph
     * @param tableIdArr2 the third id array passed to the {@link Graph} (presumably its outputs)
     * @param tableIdArr3 the fourth id array passed to the {@link Graph} (presumably input model
     *     data); may be {@code null}
     * @param tableIdArr4 the fifth id array passed to the {@link Graph} (presumably output model
     *     data); may be {@code null}
     * @return a new {@link Graph}
     */
    public Estimator<?, ?> buildEstimator(TableId[] tableIdArr, TableId[] tableIdArr2, TableId[] tableIdArr3, TableId[] tableIdArr4) {
        return buildEstimator(tableIdArr, tableIdArr, tableIdArr2, tableIdArr3, tableIdArr4);
    }

    /**
     * Builds an {@link Estimator} by wrapping the nodes added so far in a new {@link Graph}.
     *
     * <p>NOTE(review): the builder's internal node list is passed directly; whether
     * {@link Graph} copies it is not visible from this class. The meaning of each array follows
     * the {@link Graph} constructor's argument order -- confirm there.
     *
     * @param tableIdArr the first id array passed to the {@link Graph}
     * @param tableIdArr2 the second id array passed to the {@link Graph}
     * @param tableIdArr3 the third id array passed to the {@link Graph}
     * @param tableIdArr4 the fourth id array passed to the {@link Graph}
     * @param tableIdArr5 the fifth id array passed to the {@link Graph}
     * @return a new {@link Graph}
     */
    public Estimator<?, ?> buildEstimator(TableId[] tableIdArr, TableId[] tableIdArr2, TableId[] tableIdArr3, TableId[] tableIdArr4, TableId[] tableIdArr5) {
        return new Graph(nodes, tableIdArr, tableIdArr2, tableIdArr3, tableIdArr4, tableIdArr5);
    }

    /**
     * Builds an {@link AlgoOperator} from the nodes added so far. The result is the same
     * {@link GraphModel} that {@link #buildModel(TableId[], TableId[])} would return.
     *
     * @param tableIdArr the first id array passed to the {@link GraphModel}
     * @param tableIdArr2 the second id array passed to the {@link GraphModel}
     * @return a new {@link GraphModel}
     */
    public AlgoOperator<?> buildAlgoOperator(TableId[] tableIdArr, TableId[] tableIdArr2) {
        return buildModel(tableIdArr, tableIdArr2, null, null);
    }

    /**
     * Builds a {@link Model} from the nodes added so far, without model data ids.
     *
     * @param tableIdArr the first id array passed to the {@link GraphModel}
     * @param tableIdArr2 the second id array passed to the {@link GraphModel}
     * @return a new {@link GraphModel}
     */
    public Model<?> buildModel(TableId[] tableIdArr, TableId[] tableIdArr2) {
        return buildModel(tableIdArr, tableIdArr2, null, null);
    }

    /**
     * Builds a {@link Model} by wrapping the nodes added so far in a new {@link GraphModel}.
     *
     * <p>NOTE(review): the builder's internal node list is passed directly; whether
     * {@link GraphModel} copies it is not visible from this class.
     *
     * @param tableIdArr the first id array passed to the {@link GraphModel}
     * @param tableIdArr2 the second id array passed to the {@link GraphModel}
     * @param tableIdArr3 the third id array passed to the {@link GraphModel}; may be {@code null}
     * @param tableIdArr4 the fourth id array passed to the {@link GraphModel}; may be {@code null}
     * @return a new {@link GraphModel}
     */
    public Model<?> buildModel(TableId[] tableIdArr, TableId[] tableIdArr2, TableId[] tableIdArr3, TableId[] tableIdArr4) {
        return new GraphModel(nodes, tableIdArr, tableIdArr2, tableIdArr3, tableIdArr4);
    }

    /** Allocates {@code count} consecutive new {@link TableId}s. */
    private TableId[] createTableIds(int count) {
        TableId[] ids = new TableId[count];
        for (int idx = 0; idx < count; idx++) {
            ids[idx] = createTableId();
        }
        return ids;
    }

    /**
     * Looks up the node of a previously added stage and verifies it was added with the expected
     * stage type.
     *
     * @param stage the stage to look up
     * @param expectedType the type the stage must have been added as
     * @param stageKind how the stage is called in error messages
     * @param otherKind how the conflicting kind is called in error messages
     * @return the node registered for {@code stage}
     * @throws RuntimeException if the stage is unknown or was added with another type
     */
    private GraphNode requireNode(
            Stage<?> stage, GraphNode.StageType expectedType, String stageKind, String otherKind) {
        GraphNode node = existingNodes.get(stage);
        if (node == null) {
            throw new RuntimeException("the " + stageKind + " has not been added to the graph");
        }
        if (node.stageType != expectedType) {
            throw new RuntimeException(
                    "the " + stageKind + " was previously added as an " + otherKind);
        }
        return node;
    }

    /**
     * Registers a stage as a new node and allocates the {@link TableId}s returned for it.
     *
     * <p>NOTE(review): the two input arrays and the allocated ids are passed to
     * {@link GraphNode} positionally; their exact roles are defined by that constructor.
     *
     * @param stage the stage to add
     * @param stageType the type the stage is added as
     * @param estimatorInputs the first input-id array; {@code null} when adding an AlgoOperator
     * @param algoOpInputs the second input-id array
     * @return the newly allocated ids associated with the new node
     * @throws RuntimeException if the stage has already been added
     */
    private TableId[] addStage(Stage<?> stage, GraphNode.StageType stageType, TableId[] estimatorInputs, TableId[] algoOpInputs) {
        // Reject duplicates before touching any counter, so a failed call consumes no ids.
        if (existingNodes.containsKey(stage)) {
            throw new RuntimeException("The stage " + stage + " has already been added.");
        }
        TableId[] outputs = createTableIds(maxOutputLength);
        GraphNode node =
                new GraphNode(
                        nextNodeId++,
                        stage,
                        stageType,
                        estimatorInputs,
                        algoOpInputs,
                        outputs,
                        null,
                        null);
        nodes.add(node);
        existingNodes.put(stage, node);
        return outputs;
    }
}
