package org.apache.flink.ml.builder;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.GraphNode;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;

@PublicEvolving
/**
 * A {@link Model} backed by a graph of stages ({@link GraphNode}s).
 *
 * <p>{@link #transform} feeds the provided tables in under {@code inputIds}. It then repeatedly
 * asks a {@link GraphExecutionHelper} for the next ready node and runs that node's stage.
 * Estimator nodes are first fit on their estimator inputs. Finally it returns the tables
 * registered under {@code outputIds}. Model data can optionally be injected
 * ({@link #setModelData}) and extracted ({@link #getModelData}) through the
 * {@code inputModelDataIds} / {@code outputModelDataIds} table ids; when those ids are
 * {@code null} the corresponding operation is rejected.
 */
public final class GraphModel implements Model<GraphModel> {
    private static final long serialVersionUID = 6354856913812529398L;

    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** The nodes of the graph; also handed to {@link GraphExecutionHelper} and persisted by {@link #save}. */
    private final List<GraphNode> nodes;

    /** Ids under which the tables passed to {@link #transform} are registered. */
    private final TableId[] inputIds;

    /** Ids of the tables returned by {@link #transform}. */
    private final TableId[] outputIds;

    /** Ids under which tables passed to {@link #setModelData} are registered; null if unsupported. */
    @Nullable private final TableId[] inputModelDataIds;

    /** Ids of the tables returned by {@link #getModelData}; null if unsupported. */
    @Nullable private final TableId[] outputModelDataIds;

    /** Holds the id-to-table bindings and yields nodes whose inputs are available. */
    private final GraphExecutionHelper executionHelper;

    /**
     * Creates a GraphModel.
     *
     * @param nodes the graph nodes; must not be null
     * @param inputIds ids bound to the tables passed to {@link #transform}; must not be null
     * @param outputIds ids of the tables returned by {@link #transform}; must not be null
     * @param inputModelDataIds ids bound to the tables passed to {@link #setModelData}, or null
     *     if setting model data is not supported
     * @param outputModelDataIds ids of the tables returned by {@link #getModelData}, or null if
     *     getting model data is not supported
     * @throws NullPointerException if {@code nodes}, {@code inputIds} or {@code outputIds} is null
     */
    public GraphModel(
            List<GraphNode> nodes,
            TableId[] inputIds,
            TableId[] outputIds,
            @Nullable TableId[] inputModelDataIds,
            @Nullable TableId[] outputModelDataIds) {
        this.nodes = Preconditions.checkNotNull(nodes);
        this.inputIds = Preconditions.checkNotNull(inputIds);
        this.outputIds = Preconditions.checkNotNull(outputIds);
        this.inputModelDataIds = inputModelDataIds;
        this.outputModelDataIds = outputModelDataIds;
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
        this.executionHelper = new GraphExecutionHelper(nodes);
    }

    /**
     * Runs every node of the graph on the given inputs and returns the graph's output tables.
     *
     * @param inputTables one table per entry of {@code inputIds}, in the same order
     * @return the tables registered under {@code outputIds}
     * @throws IllegalArgumentException if the number of tables does not match {@code inputIds}
     */
    @Override
    public Table[] transform(Table... inputTables) {
        Preconditions.checkArgument(
                inputIds.length == inputTables.length,
                "number of provided tables %s does not match the expected number of tables %s",
                inputTables.length,
                inputIds.length);
        executionHelper.setTables(inputIds, inputTables);

        // The helper decides scheduling order; a null node means no further node is ready.
        GraphNode node;
        while ((node = executionHelper.pollNextReadyNode()) != null) {
            Stage<?> stage = node.stage;
            if (node.stageType == GraphNode.StageType.ESTIMATOR) {
                // An estimator node must be fit first; the resulting model is what gets applied.
                stage =
                        ((Estimator<?, ?>) stage)
                                .fit(executionHelper.getTables(node.estimatorInputIds));
            }
            if (node.inputModelDataIds != null) {
                ((Model<?>) stage).setModelData(executionHelper.getTables(node.inputModelDataIds));
            }
            Table[] nodeOutputs =
                    ((AlgoOperator<?>) stage).transform(executionHelper.getTables(node.algoOpInputIds));
            executionHelper.setTables(node.outputIds, nodeOutputs);
            if (node.outputModelDataIds != null) {
                executionHelper.setTables(
                        node.outputModelDataIds, ((Model<?>) stage).getModelData());
            }
        }
        return executionHelper.getTables(outputIds);
    }

    /**
     * Registers the given tables as this model's input model data.
     *
     * @param inputs one table per entry of {@code inputModelDataIds}, in the same order
     * @return this instance
     * @throws IllegalArgumentException if this model has no {@code inputModelDataIds}, or if the
     *     number of tables does not match them
     */
    @Override
    public GraphModel setModelData(Table... inputs) {
        Preconditions.checkArgument(inputModelDataIds != null, "setModelData() is not supported");
        // Report the length actually compared against (inputModelDataIds), not inputIds.
        Preconditions.checkArgument(
                inputModelDataIds.length == inputs.length,
                "number of provided tables %s does not match the expected number of tables %s",
                inputs.length,
                inputModelDataIds.length);
        executionHelper.setTables(inputModelDataIds, inputs);
        return this;
    }

    /**
     * Returns the tables registered under {@code outputModelDataIds}.
     *
     * @throws IllegalArgumentException if this model has no {@code outputModelDataIds}
     */
    @Override
    public Table[] getModelData() {
        Preconditions.checkArgument(outputModelDataIds != null, "getModelData() is not supported");
        return executionHelper.getTables(outputModelDataIds);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Persists this model's graph structure and parameters to {@code path}.
     *
     * @param path destination directory/path understood by {@link ReadWriteUtils#saveGraph}
     * @throws IOException if writing fails
     */
    @Override
    public void save(String path) throws IOException {
        // NOTE(review): the null second argument presumably stands for estimator-level input ids,
        // which a GraphModel does not have — confirm against GraphData's constructor.
        GraphData graphData =
                new GraphData(
                        nodes, null, inputIds, outputIds, inputModelDataIds, outputModelDataIds);
        ReadWriteUtils.saveGraph(this, graphData, path);
    }

    /**
     * Loads a GraphModel previously written by {@link #save}.
     *
     * @param tEnv the table environment used to reconstruct the graph
     * @param path the path passed to {@link #save}
     * @return the loaded model
     * @throws IOException if reading fails
     */
    public static GraphModel load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (GraphModel) ReadWriteUtils.loadGraph(tEnv, path, GraphModel.class.getName());
    }
}
