package org.apache.flink.ml.builder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.GraphNode;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;

/**
 * An {@link Estimator} described by a list of {@link GraphNode}s, each wrapping a {@link Stage},
 * whose inputs and outputs are wired together through {@link TableId}s.
 *
 * <p>{@link #fit(Table...)} executes the nodes and returns a {@link GraphModel}. Every node of that
 * model is an algorithm operator: estimator nodes are replaced by the models they produce.
 */
@PublicEvolving
public final class Graph implements Estimator<Graph, GraphModel> {
    private static final long serialVersionUID = 6354253958813529308L;

    /** Parameter values of this stage; filled with default values in the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** The nodes making up this graph. */
    private final List<GraphNode> nodes;

    /** Ids bound, position by position, to the tables passed to {@link #fit(Table...)}. */
    private final TableId[] estimatorInputIds;

    /** Input ids forwarded to the {@link GraphModel} produced by {@link #fit(Table...)}. */
    private final TableId[] modelInputIds;

    /** Output ids forwarded to the {@link GraphModel} produced by {@link #fit(Table...)}. */
    private final TableId[] outputIds;

    /** Optional model-data input ids forwarded to the produced {@link GraphModel}. */
    @Nullable private final TableId[] inputModelDataIds;

    /** Optional model-data output ids forwarded to the produced {@link GraphModel}. */
    @Nullable private final TableId[] outputModelDataIds;

    /**
     * Creates a graph estimator.
     *
     * <p>The id arrays are stored by reference, not copied.
     *
     * @param nodes the nodes making up the graph
     * @param estimatorInputIds ids bound to the tables passed to {@link #fit(Table...)}
     * @param modelInputIds input ids of the {@link GraphModel} produced by fitting
     * @param outputIds output ids of the {@link GraphModel} produced by fitting
     * @param inputModelDataIds optional model-data input ids of the produced {@link GraphModel}
     * @param outputModelDataIds optional model-data output ids of the produced {@link GraphModel}
     * @throws NullPointerException if {@code nodes}, {@code estimatorInputIds}, {@code
     *     modelInputIds} or {@code outputIds} is null
     */
    public Graph(
            List<GraphNode> nodes,
            TableId[] estimatorInputIds,
            TableId[] modelInputIds,
            TableId[] outputIds,
            @Nullable TableId[] inputModelDataIds,
            @Nullable TableId[] outputModelDataIds) {
        this.nodes = Preconditions.checkNotNull(nodes);
        this.estimatorInputIds = Preconditions.checkNotNull(estimatorInputIds);
        this.modelInputIds = Preconditions.checkNotNull(modelInputIds);
        this.outputIds = Preconditions.checkNotNull(outputIds);
        this.inputModelDataIds = inputModelDataIds;
        this.outputModelDataIds = outputModelDataIds;
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Fits the graph on the given tables and returns the resulting {@link GraphModel}.
     *
     * <p>The input tables are bound to {@link #estimatorInputIds}. Nodes are then processed one at
     * a time, in the order yielded by {@code GraphExecutionHelper#pollNextReadyNode()}, until it
     * returns {@code null}. Estimator nodes are fitted first; every resulting stage is then used
     * to transform its inputs.
     *
     * @param inputs the tables to bind to {@link #estimatorInputIds}, in the same order
     * @return a model whose nodes are all {@link GraphNode.StageType#ALGO_OPERATOR} nodes
     * @throws IllegalArgumentException if the number of tables differs from the number of
     *     estimator input ids
     */
    @Override
    public GraphModel fit(Table... inputs) {
        Preconditions.checkArgument(
                estimatorInputIds.length == inputs.length,
                "number of provided tables %s does not match the expected number of tables %s",
                inputs.length,
                estimatorInputIds.length);

        List<GraphNode> modelNodes = new ArrayList<>(nodes.size());
        GraphExecutionHelper executionHelper = new GraphExecutionHelper(nodes);
        executionHelper.setTables(estimatorInputIds, inputs);

        GraphNode node;
        while ((node = executionHelper.pollNextReadyNode()) != null) {
            // Non-estimator stages are reused as-is, so the returned model shares those stage
            // instances with this graph.
            Stage<?> stage = node.stage;
            if (node.stageType == GraphNode.StageType.ESTIMATOR) {
                stage = ((Estimator<?, ?>) stage).fit(executionHelper.getTables(node.estimatorInputIds));
            }

            // NOTE(review): the casts to Model below assume that any node declaring model-data
            // ids wraps (or fits into) a Model; otherwise a ClassCastException is thrown here.
            if (node.inputModelDataIds != null) {
                ((Model<?>) stage).setModelData(executionHelper.getTables(node.inputModelDataIds));
            }

            Table[] outputs =
                    ((AlgoOperator<?>) stage).transform(executionHelper.getTables(node.algoOpInputIds));
            executionHelper.setTables(node.outputIds, outputs);

            if (node.outputModelDataIds != null) {
                executionHelper.setTables(node.outputModelDataIds, ((Model<?>) stage).getModelData());
            }

            // The model-side node has no estimator inputs: it is always an algorithm operator.
            modelNodes.add(
                    new GraphNode(
                            node.nodeId,
                            stage,
                            GraphNode.StageType.ALGO_OPERATOR,
                            null,
                            node.algoOpInputIds,
                            node.outputIds,
                            node.inputModelDataIds,
                            node.outputModelDataIds));
        }

        return new GraphModel(
                modelNodes, modelInputIds, outputIds, inputModelDataIds, outputModelDataIds);
    }

    /** Returns the live, mutable map of this stage's parameter values. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Saves this graph, including its nodes and id wiring, to the given path.
     *
     * @param path the destination path
     * @throws IOException if writing fails
     */
    @Override
    public void save(String path) throws IOException {
        GraphData graphData =
                new GraphData(
                        nodes,
                        estimatorInputIds,
                        modelInputIds,
                        outputIds,
                        inputModelDataIds,
                        outputModelDataIds);
        ReadWriteUtils.saveGraph(this, graphData, path);
    }

    /**
     * Loads a graph previously written by {@link #save(String)}.
     *
     * @param tEnv the table environment to load into
     * @param path the path the graph was saved to
     * @return the loaded graph
     * @throws IOException if reading fails
     */
    public static Graph load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (Graph) ReadWriteUtils.loadGraph(tEnv, path, Graph.class.getName());
    }
}
