/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.api.core;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.ml.api.core.Estimator;
import org.apache.flink.ml.api.core.Model;
import org.apache.flink.ml.api.core.PipelineStage;
import org.apache.flink.ml.api.core.Transformer;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.util.InstantiationUtil;

/**
 * A linear sequence of {@link PipelineStage}s that can itself be used as an {@link Estimator},
 * a {@link Transformer} or a {@link Model}.
 *
 * <p>Stages are applied in the order in which they were appended. A pipeline that contains at
 * least one stage that needs fitting (see {@link #needFit()}) must be fitted with
 * {@link #fit(TableEnvironment, Table)} before {@link #transform(TableEnvironment, Table)} can be
 * called.
 */
@PublicEvolving
public final class Pipeline
        implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>, Model<Pipeline> {
    private static final long serialVersionUID = 1L;

    /** JSON key under which the class name of a stage is stored. */
    private static final String STAGE_CLASS_NAME_KEY = "stageClassName";

    /** JSON key under which the stage's own JSON representation is stored. */
    private static final String STAGE_JSON_KEY = "stageJson";

    // Shared instead of re-created per call; Jackson documents a configured ObjectMapper as
    // thread-safe (see the ObjectMapper class documentation).
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private final List<PipelineStage> stages = new ArrayList<>();
    private final Params params = new Params();

    /** Index of the last stage that needs fitting, or -1 if no stage needs fitting. */
    private int lastEstimatorIndex = -1;

    /** Creates an empty pipeline. */
    public Pipeline() {
    }

    /**
     * Creates a pipeline and populates it via {@link #loadJson(String)}.
     *
     * @param pipelineJson JSON in the format written by {@link #toJson()}
     * @throws RuntimeException if the JSON cannot be parsed or a stage cannot be restored
     */
    public Pipeline(String pipelineJson) {
        this.loadJson(pipelineJson);
    }

    /**
     * Creates a pipeline containing the given stages, appended in iteration order.
     *
     * @param stages the stages to append; each one is validated by
     *     {@link #appendStage(PipelineStage)}
     */
    public Pipeline(List<PipelineStage> stages) {
        for (PipelineStage s : stages) {
            this.appendStage(s);
        }
    }

    /**
     * Tells whether a stage has to be fitted before it can transform data.
     *
     * <p>A nested {@code Pipeline} implements both {@link Estimator} and {@link Transformer}, so
     * for it the answer is delegated to {@link #needFit()}; any other stage needs fitting exactly
     * when it is an {@link Estimator}.
     */
    private static boolean isStageNeedFit(PipelineStage stage) {
        if (stage instanceof Pipeline) {
            return ((Pipeline) stage).needFit();
        }
        return stage instanceof Estimator;
    }

    /**
     * Appends a stage to the end of this pipeline.
     *
     * @param stage the stage to append; must be an {@link Estimator} or a {@link Transformer}
     * @return this pipeline, to allow call chaining
     * @throws NullPointerException if {@code stage} is {@code null}
     * @throws RuntimeException if {@code stage} is neither an Estimator nor a Transformer
     */
    public Pipeline appendStage(PipelineStage stage) {
        if (stage == null) {
            // Previously surfaced as a message-less NPE from stage.getClass(); same exception
            // type, but now with an explanation.
            throw new NullPointerException("PipelineStage to append must not be null");
        }
        if (isStageNeedFit(stage)) {
            // The stage is about to be added at index stages.size().
            this.lastEstimatorIndex = this.stages.size();
        } else if (!(stage instanceof Transformer)) {
            throw new RuntimeException("All PipelineStages should be Estimator or Transformer, got:" + stage.getClass().getSimpleName());
        }
        this.stages.add(stage);
        return this;
    }

    /**
     * Returns a read-only view of the stages of this pipeline, in execution order.
     *
     * @return an unmodifiable list backed by the internal stage list
     */
    public List<PipelineStage> getStages() {
        return Collections.unmodifiableList(this.stages);
    }

    /**
     * Tells whether this pipeline contains at least one stage that needs fitting.
     *
     * @return {@code true} if {@link #fit(TableEnvironment, Table)} must be called before
     *     {@link #transform(TableEnvironment, Table)}
     */
    public boolean needFit() {
        return this.getIndexOfLastEstimator() >= 0;
    }

    /** Returns the parameter container of this pipeline. */
    @Override
    public Params getParams() {
        return this.params;
    }

    /** Returns the index of the last stage that needs fitting, or -1 if there is none. */
    private int getIndexOfLastEstimator() {
        return this.lastEstimatorIndex;
    }

    /**
     * Fits every stage that needs fitting and returns a new pipeline made of the resulting
     * transformers.
     *
     * <p>Stages up to and including the last one that needs fitting are processed in order: a
     * stage that needs fitting is fitted on the current table, any other stage is used as a
     * transformer directly, and the current table is then replaced by the transformer's output.
     * Stages after the last one that needs fitting are copied to the result unchanged and are not
     * invoked.
     *
     * @param tEnv the table environment passed to each stage
     * @param input the table fed to the first stage
     * @return a new pipeline holding the fitted stages; this pipeline is not modified
     */
    @Override
    public Pipeline fit(TableEnvironment tEnv, Table input) {
        List<PipelineStage> fittedStages = new ArrayList<>(this.stages.size());
        int lastEstimatorIdx = this.getIndexOfLastEstimator();
        Table current = input;
        for (int i = 0; i < this.stages.size(); ++i) {
            PipelineStage stage = this.stages.get(i);
            if (i > lastEstimatorIdx) {
                // Nothing left to fit downstream, so this stage's output is not needed here.
                fittedStages.add(stage);
                continue;
            }
            // A stage that does not need fitting is a Transformer; appendStage() enforces this.
            Transformer transformer = isStageNeedFit(stage)
                    ? ((Estimator) stage).fit(tEnv, current)
                    : (Transformer) stage;
            fittedStages.add(transformer);
            current = transformer.transform(tEnv, current);
        }
        return new Pipeline(fittedStages);
    }

    /**
     * Applies every stage, in order, to the input table.
     *
     * @param tEnv the table environment passed to each stage
     * @param input the table fed to the first stage
     * @return the output of the last stage, or {@code input} itself if the pipeline is empty
     * @throws RuntimeException if this pipeline still contains a stage that needs fitting
     */
    @Override
    public Table transform(TableEnvironment tEnv, Table input) {
        if (this.needFit()) {
            throw new RuntimeException("Pipeline contains Estimator, need to fit first.");
        }
        Table current = input;
        for (PipelineStage s : this.stages) {
            current = ((Transformer) s).transform(tEnv, current);
        }
        return current;
    }

    /**
     * Serializes this pipeline as a JSON array with one object per stage. Each object stores the
     * stage's class name under {@code "stageClassName"} and the stage's own {@code toJson()}
     * output under {@code "stageJson"}.
     *
     * @return the JSON representation of all stages
     * @throws RuntimeException if Jackson fails to write the JSON
     */
    @Override
    public String toJson() {
        List<Map<String, String>> stageJsons = new ArrayList<>(this.stages.size());
        for (PipelineStage s : this.stages) {
            Map<String, String> stageMap = new HashMap<>();
            stageMap.put(STAGE_CLASS_NAME_KEY, s.getClass().getTypeName());
            stageMap.put(STAGE_JSON_KEY, s.toJson());
            stageJsons.add(stageMap);
        }
        try {
            return MAPPER.writeValueAsString(stageJsons);
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException("Failed to serialize pipeline", e);
        }
    }

    /**
     * Restores stages from JSON in the format written by {@link #toJson()} and appends them to
     * the stages already present in this pipeline.
     *
     * <p>All stages are restored before the first one is appended, so a failure while parsing or
     * restoring leaves this pipeline unchanged.
     *
     * @param json a JSON array of stage objects
     * @throws RuntimeException if the JSON cannot be parsed, is not an array of objects, or a
     *     stage cannot be restored
     */
    @Override
    public void loadJson(String json) {
        List<?> stageJsons;
        try {
            stageJsons = MAPPER.readValue(json, List.class);
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to deserialize pipeline json:" + json, e);
        }
        if (stageJsons == null) {
            // The JSON literal `null` parses successfully but yields no list.
            throw new RuntimeException("Failed to deserialize pipeline json:" + json);
        }
        List<PipelineStage> restored = new ArrayList<>(stageJsons.size());
        for (Object entry : stageJsons) {
            if (!(entry instanceof Map)) {
                throw new RuntimeException("Failed to deserialize pipeline json, stage entry is not a JSON object:" + entry);
            }
            restored.add(this.restoreInnerStage((Map<?, ?>) entry));
        }
        for (PipelineStage stage : restored) {
            this.appendStage(stage);
        }
    }

    /**
     * Re-creates a single stage from its JSON object.
     *
     * <p>NOTE(review): the class name is taken verbatim from the JSON, so pipeline JSON should
     * only be loaded from trusted sources. To limit the exposure, the class is loaded without
     * initialization and verified to be a {@link PipelineStage} before any of its code runs.
     *
     * @param stageMap the parsed stage object holding {@code "stageClassName"} and
     *     {@code "stageJson"}
     * @return a new stage instance on which {@code loadJson} has been called
     * @throws RuntimeException if the entry is malformed, or the class cannot be found, is not a
     *     PipelineStage, or cannot be instantiated
     */
    private PipelineStage<?> restoreInnerStage(Map<?, ?> stageMap) {
        Object rawClassName = stageMap.get(STAGE_CLASS_NAME_KEY);
        if (!(rawClassName instanceof String)) {
            throw new RuntimeException("Pipeline stage json has no valid stageClassName:" + stageMap);
        }
        String className = (String) rawClassName;

        // A null stageJson is passed through to the stage unchanged, as before; only values of
        // the wrong type are rejected here.
        Object rawStageJson = stageMap.get(STAGE_JSON_KEY);
        if (rawStageJson != null && !(rawStageJson instanceof String)) {
            throw new RuntimeException("Pipeline stage json has a non-string stageJson for class " + className);
        }

        Class<?> clz;
        try {
            // Same class loader as Class.forName(String) would use from this class, but without
            // running static initializers before the type check below.
            clz = Class.forName(className, false, Pipeline.class.getClassLoader());
        }
        catch (ClassNotFoundException e) {
            throw new RuntimeException("PipelineStage class " + className + " not exists", e);
        }
        if (!PipelineStage.class.isAssignableFrom(clz)) {
            throw new RuntimeException("Class " + className + " is not a PipelineStage");
        }
        InstantiationUtil.checkForInstantiation(clz);

        PipelineStage<?> stage;
        try {
            // Class.newInstance() is deprecated and can propagate undeclared checked exceptions;
            // Constructor.newInstance() wraps them in InvocationTargetException instead.
            stage = (PipelineStage<?>) clz.getDeclaredConstructor().newInstance();
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException("Class is instantiable but failed to new an instance", e);
        }
        stage.loadJson((String) rawStageJson);
        return stage;
    }
}

