package org.apache.flink.ml.api.core;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.util.InstantiationUtil;

/**
 * A linear sequence of {@link PipelineStage}s, each of which must be an {@link Estimator} or a
 * {@link Transformer}.
 *
 * <p>A pipeline is itself an {@link Estimator}, a {@link Transformer} and a {@link Model}.
 * {@link #fit} fits every stage up to and including the last stage that needs fitting, feeding
 * each stage the output of the previous one, and returns a new pipeline made of transformers.
 * {@link #transform} applies all stages in order and is only allowed when {@link #needFit()} is
 * false. Stages can be appended but never removed.
 */
@PublicEvolving
public final class Pipeline implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>, Model<Pipeline> {
    private static final long serialVersionUID = 1L;

    /** JSON key holding the fully qualified class name of a serialized stage. */
    private static final String STAGE_CLASS_NAME_KEY = "stageClassName";

    /** JSON key holding the string produced by the stage's own {@link PipelineStage#toJson()}. */
    private static final String STAGE_JSON_KEY = "stageJson";

    /**
     * Shared JSON mapper. ObjectMapper is thread-safe once configured and costly to create, so
     * one instance is reused instead of building a new one per call. Static, hence not part of
     * the serialized form.
     */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Stages in execution order. Only ever appended to, through {@link #appendStage}. */
    private final List<PipelineStage> stages = new ArrayList<>();

    private final Params params = new Params();

    /** Creates an empty pipeline. */
    public Pipeline() {
    }

    /**
     * Creates a pipeline restored from a JSON string produced by {@link #toJson()}.
     *
     * @param pipelineJson JSON array describing the stages
     * @throws RuntimeException if the JSON cannot be parsed or a stage cannot be restored
     */
    public Pipeline(String pipelineJson) {
        loadJson(pipelineJson);
    }

    /**
     * Creates a pipeline containing the given stages, in order.
     *
     * @param initialStages stages to append; each must be an {@link Estimator} or a
     *     {@link Transformer}
     * @throws IllegalArgumentException if a stage is neither an Estimator nor a Transformer
     */
    public Pipeline(List<PipelineStage> initialStages) {
        for (PipelineStage stage : initialStages) {
            appendStage(stage);
        }
    }

    /**
     * Returns whether the stage must be fitted before it can transform data. A nested
     * {@link Pipeline} needs fitting iff {@link #needFit()} is true for it; any other stage needs
     * fitting iff it is an {@link Estimator}.
     */
    private static boolean isStageNeedFit(PipelineStage stage) {
        if (stage instanceof Pipeline) {
            return ((Pipeline) stage).needFit();
        }
        return stage instanceof Estimator;
    }

    /**
     * Appends a stage to the end of this pipeline.
     *
     * @param pipelineStage the stage to append; must be an {@link Estimator} or a
     *     {@link Transformer}
     * @return this pipeline, for call chaining
     * @throws NullPointerException if {@code pipelineStage} is null
     * @throws IllegalArgumentException if the stage is this pipeline itself, or is neither an
     *     Estimator nor a Transformer
     */
    public Pipeline appendStage(PipelineStage pipelineStage) {
        Objects.requireNonNull(pipelineStage, "pipelineStage");
        if (pipelineStage == this) {
            // A self-containing pipeline would recurse forever in needFit(), fit() and transform().
            throw new IllegalArgumentException("A Pipeline cannot be appended to itself");
        }
        if (!isStageNeedFit(pipelineStage) && !(pipelineStage instanceof Transformer)) {
            throw new IllegalArgumentException(
                    "All PipelineStages should be Estimator or Transformer, got:"
                            + pipelineStage.getClass().getSimpleName());
        }
        stages.add(pipelineStage);
        return this;
    }

    /** Returns an unmodifiable view of this pipeline's stages, in execution order. */
    public List<PipelineStage> getStages() {
        return Collections.unmodifiableList(stages);
    }

    /**
     * Returns whether any stage (looking into nested pipelines) still needs fitting. While this
     * is true, {@link #transform} throws and {@link #fit} must be called first.
     */
    public boolean needFit() {
        return getIndexOfLastEstimator() >= 0;
    }

    /**
     * Returns the parameters of this pipeline. NOTE: {@link #toJson()} only serializes the
     * stages, not these params.
     */
    @Override
    public Params getParams() {
        return params;
    }

    /**
     * Returns the index of the last stage that needs fitting, or -1 if there is none.
     *
     * <p>Computed on demand rather than cached at append time. A nested {@link Pipeline} may gain
     * an {@link Estimator} after it was appended here, and a cached index would then miss it.
     */
    private int getIndexOfLastEstimator() {
        for (int i = stages.size() - 1; i >= 0; i--) {
            if (isStageNeedFit(stages.get(i))) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Fits this pipeline on the given input and returns a pipeline with no stage left to fit.
     *
     * <p>Stages up to and including the last one that needs fitting are processed in order.
     * Stages that need fitting are fitted on the current table and replaced by the resulting
     * model; other stages are kept as-is. The current table is then replaced by that stage's
     * output. Stages after the last fitted one are copied over unchanged.
     *
     * @param tableEnvironment the table environment passed to every stage
     * @param table the input table for the first stage
     * @return a new pipeline whose stages are all transformers
     */
    @Override
    public Pipeline fit(TableEnvironment tableEnvironment, Table table) {
        List<PipelineStage> fittedStages = new ArrayList<>(stages.size());
        int lastEstimatorIdx = getIndexOfLastEstimator();
        Table input = table;
        for (int i = 0; i < stages.size(); i++) {
            PipelineStage stage = stages.get(i);
            if (i > lastEstimatorIdx) {
                fittedStages.add(stage);
                continue;
            }
            // appendStage() guarantees that a stage which does not need fitting is a Transformer.
            Transformer transformer = isStageNeedFit(stage)
                    ? ((Estimator) stage).fit(tableEnvironment, input)
                    : (Transformer) stage;
            fittedStages.add(transformer);
            input = transformer.transform(tableEnvironment, input);
        }
        return new Pipeline(fittedStages);
    }

    /**
     * Applies every stage, in order, to the input table.
     *
     * @param tableEnvironment the table environment passed to every stage
     * @param table the input table for the first stage
     * @return the output of the last stage, or {@code table} itself if the pipeline is empty
     * @throws IllegalStateException if some stage still needs fitting (see {@link #needFit()})
     */
    @Override
    public Table transform(TableEnvironment tableEnvironment, Table table) {
        if (needFit()) {
            throw new IllegalStateException("Pipeline contains Estimator, need to fit first.");
        }
        Table output = table;
        for (PipelineStage stage : stages) {
            // With needFit() false, appendStage() guarantees every stage is a Transformer.
            output = ((Transformer) stage).transform(tableEnvironment, output);
        }
        return output;
    }

    /**
     * Serializes the stages as a JSON array. Each element is an object holding the stage's class
     * name (key {@code "stageClassName"}) and the stage's own JSON (key {@code "stageJson"}).
     *
     * @throws RuntimeException if JSON serialization fails
     */
    @Override
    public String toJson() {
        List<Map<String, String>> stageJsons = new ArrayList<>(stages.size());
        for (PipelineStage stage : stages) {
            Map<String, String> stageJson = new HashMap<>();
            stageJson.put(STAGE_CLASS_NAME_KEY, stage.getClass().getTypeName());
            stageJson.put(STAGE_JSON_KEY, stage.toJson());
            stageJsons.add(stageJson);
        }
        try {
            return MAPPER.writeValueAsString(stageJsons);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Failed to serialize pipeline", e);
        }
    }

    /**
     * Restores stages from JSON produced by {@link #toJson()} and appends them, in order, after
     * any stages this pipeline already has.
     *
     * @param json JSON array of serialized stages
     * @throws RuntimeException if the JSON is malformed or a stage cannot be restored
     */
    @Override
    public void loadJson(String json) {
        List<?> stageJsons;
        try {
            stageJsons = MAPPER.readValue(json, List.class);
        } catch (IOException e) {
            throw new RuntimeException("Failed to deserialize pipeline json:" + json, e);
        }
        if (stageJsons == null) {
            throw new RuntimeException("Failed to deserialize pipeline json:" + json);
        }
        for (Object stageJson : stageJsons) {
            if (!(stageJson instanceof Map)) {
                throw new RuntimeException("Failed to deserialize pipeline json:" + json
                        + ", expected a JSON object per stage but got: " + stageJson);
            }
            appendStage(restoreInnerStage((Map<?, ?>) stageJson));
        }
    }

    /**
     * Instantiates one stage from its serialized description and loads the stage's own JSON.
     *
     * @param stageMeta a parsed element of the {@link #toJson()} array
     * @return the restored stage
     * @throws RuntimeException if the class name is missing, unknown, not a
     *     {@link PipelineStage}, not instantiable, or if the stage rejects its JSON
     */
    private PipelineStage<?> restoreInnerStage(Map<?, ?> stageMeta) {
        Object classNameValue = stageMeta.get(STAGE_CLASS_NAME_KEY);
        if (!(classNameValue instanceof String)) {
            throw new RuntimeException("Missing or invalid \"" + STAGE_CLASS_NAME_KEY
                    + "\" in pipeline stage json: " + stageMeta);
        }
        String className = (String) classNameValue;

        Class<?> cls;
        try {
            // Same loader as Class.forName(name) would use from here. The class is loaded without
            // being initialized because its name comes from JSON and is type-checked below first.
            cls = Class.forName(className, false, Pipeline.class.getClassLoader());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("PipelineStage class " + className + " not exists", e);
        }
        if (!PipelineStage.class.isAssignableFrom(cls)) {
            throw new RuntimeException("Class " + className + " is not a PipelineStage");
        }
        InstantiationUtil.checkForInstantiation(cls);

        PipelineStage<?> stage;
        try {
            stage = (PipelineStage<?>) cls.getConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Class is instantiable but failed to new an instance", e);
        }

        Object stageJson = stageMeta.get(STAGE_JSON_KEY);
        if (stageJson != null && !(stageJson instanceof String)) {
            throw new RuntimeException("Invalid \"" + STAGE_JSON_KEY + "\" for PipelineStage "
                    + className + ": expected a string but got: " + stageJson);
        }
        try {
            stage.loadJson((String) stageJson);
        } catch (RuntimeException e) {
            throw new RuntimeException("Failed to load PipelineStage " + className + " from json", e);
        }
        return stage;
    }
}
