package org.apache.flink.ml.util;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.connector.file.sink.FileSink;
import org.apache.flink.connector.file.src.FileSource;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.core.fs.Path;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.Graph;
import org.apache.flink.ml.builder.GraphData;
import org.apache.flink.ml.builder.GraphModel;
import org.apache.flink.ml.builder.GraphNode;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.metrics.MLMetrics;
import org.apache.flink.ml.param.Param;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonParser;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;

/**
 * Utility methods for persisting {@link Stage}s, pipelines, graphs and model data to a path and
 * loading them back.
 */
public class ReadWriteUtils {
    /** Shared JSON mapper used to serialize stage metadata; it is configured to allow comments. */
    public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().enable(JsonParser.Feature.ALLOW_COMMENTS);

    /**
     * Encodes one parameter value with the JSON encoder of its {@link Param}.
     *
     * <p>The cast is unchecked because a param map stores its values as {@code Object}. A value
     * stored under {@code param} is assumed to be of that param's value type.
     */
    @SuppressWarnings("unchecked")
    private static <T> Object jsonEncodeHelper(Param<T> param, Object value) throws IOException {
        return param.jsonEncode((T) value);
    }

    /**
     * Converts a param map into a map keyed by param name whose values are the JSON-encoded param
     * values.
     */
    private static Map<String, Object> jsonEncode(Map<Param<?>, Object> paramMap) throws IOException {
        Map<String, Object> encoded = new HashMap<>(paramMap.size());
        for (Map.Entry<Param<?>, Object> entry : paramMap.entrySet()) {
            encoded.put(entry.getKey().name, jsonEncodeHelper(entry.getKey(), entry.getValue()));
        }
        return encoded;
    }

    /**
     * Writes the metadata of {@code stage} as JSON to the file {@code <path>/metadata}.
     *
     * <p>The written map contains every entry of {@code extraMetadata} plus these entries, which
     * override extra entries with the same key:
     * <ul>
     *   <li>{@code "className"}: the stage's class name;</li>
     *   <li>the current time in milliseconds;</li>
     *   <li>{@code "paramMap"}: the JSON-encoded param values.</li>
     * </ul>
     * The time entry uses {@code MLMetrics.TIMESTAMP} as its key. NOTE(review): presumably that
     * constant equals the key readers expect; confirm.
     *
     * @param stage the stage whose class name and params are recorded
     * @param path the directory the metadata file is written into
     * @param extraMetadata additional entries to include; the map itself is not modified
     * @throws IOException if encoding params, serializing or writing the file fails
     */
    public static void saveMetadata(Stage<?> stage, String path, Map<String, ?> extraMetadata) throws IOException {
        Map<String, Object> metadata = new HashMap<>(extraMetadata);
        metadata.put("className", stage.getClass().getName());
        metadata.put(MLMetrics.TIMESTAMP, System.currentTimeMillis());
        metadata.put("paramMap", jsonEncode(stage.getParamMap()));
        FileUtils.saveToFile(new Path(path, "metadata").toUri().toString(), OBJECT_MAPPER.writeValueAsString(metadata), false);
    }

    /** Same as {@link #saveMetadata(Stage, String, Map)} with no extra metadata entries. */
    public static void saveMetadata(Stage<?> stage, String path) throws IOException {
        saveMetadata(stage, path, new HashMap<>());
    }

    /**
     * Saves a pipeline-like stage: its metadata (including {@code "numStages"}) followed by each
     * sub-stage, in a per-index path produced by {@code FileUtils.getPathForPipelineStage}.
     *
     * @param pipeline the stage whose metadata is written at {@code path}
     * @param stages the sub-stages, saved in list order
     * @param path the target directory; it is created first
     * @throws IOException if creating the directory or saving any part fails
     */
    public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages, String path) throws IOException {
        // Ensure the target directory exists before anything is written into it.
        FileUtils.mkdirs(new Path(path));
        int numStages = stages.size();
        Map<String, Object> extraMetadata = new HashMap<>();
        extraMetadata.put("numStages", numStages);
        saveMetadata(pipeline, path, extraMetadata);
        for (int i = 0; i < numStages; i++) {
            stages.get(i).save(FileUtils.getPathForPipelineStage(i, numStages, path));
        }
    }

    /**
     * Loads the sub-stages written by {@link #savePipeline(Stage, List, String)}.
     *
     * @param tEnv the table environment passed to each stage's {@code load} method
     * @param path the directory the pipeline was saved to
     * @param expectedClassName forwarded to {@code FileUtils.loadMetadata}
     * @return the loaded sub-stages, in their saved order
     * @throws IOException if reading metadata or loading a stage fails
     */
    public static List<Stage<?>> loadPipeline(StreamTableEnvironment tEnv, String path, String expectedClassName) throws IOException {
        int numStages = (Integer) FileUtils.loadMetadata(path, expectedClassName).get("numStages");
        List<Stage<?>> stages = new ArrayList<>(numStages);
        for (int i = 0; i < numStages; i++) {
            stages.add(loadStage(tEnv, FileUtils.getPathForPipelineStage(i, numStages, path)));
        }
        return stages;
    }

    /**
     * Saves a graph: its metadata (including {@code "graphData"}) followed by every node's stage.
     *
     * <p>Each node's stage is saved to a path derived from the node's id.
     *
     * @param graph the stage whose metadata is written at {@code path}
     * @param graphData the graph structure and node stages to save
     * @param path the target directory; it is created first
     * @throws IOException if creating the directory or saving any part fails
     */
    public static void saveGraph(Stage<?> graph, GraphData graphData, String path) throws IOException {
        FileUtils.mkdirs(new Path(path));
        Map<String, Object> extraMetadata = new HashMap<>();
        extraMetadata.put("graphData", graphData.toMap());
        saveMetadata(graph, path, extraMetadata);
        int numStages = getNumStages(graphData);
        for (GraphNode node : graphData.nodes) {
            node.stage.save(FileUtils.getPathForPipelineStage(node.nodeId, numStages, path));
        }
    }

    /**
     * Loads a graph written by {@link #saveGraph(Stage, GraphData, String)}.
     *
     * @param tEnv the table environment passed to each node stage's {@code load} method
     * @param path the directory the graph was saved to
     * @param expectedClassName must be the class name of {@link GraphModel} or {@link Graph};
     *     it selects which of the two is constructed
     * @return a {@link GraphModel} or a {@link Graph}, according to {@code expectedClassName}
     * @throws IOException if reading metadata or loading a node stage fails
     * @throws IllegalStateException if {@code expectedClassName} names neither graph class
     */
    public static Stage<?> loadGraph(StreamTableEnvironment tEnv, String path, String expectedClassName) throws IOException {
        @SuppressWarnings("unchecked")
        Map<String, Object> graphDataMap = (Map<String, Object>) FileUtils.loadMetadata(path, expectedClassName).get("graphData");
        GraphData graphData = GraphData.fromMap(graphDataMap);
        int numStages = getNumStages(graphData);
        for (GraphNode node : graphData.nodes) {
            node.stage = loadStage(tEnv, FileUtils.getPathForPipelineStage(node.nodeId, numStages, path));
        }
        if (expectedClassName.equals(GraphModel.class.getName())) {
            return new GraphModel(graphData.nodes, graphData.modelInputIds, graphData.outputIds, graphData.inputModelDataIds, graphData.outputModelDataIds);
        }
        Preconditions.checkState(expectedClassName.equals(Graph.class.getName()), "Unexpected graph class name: %s", expectedClassName);
        return new Graph(graphData.nodes, graphData.estimatorInputIds, graphData.modelInputIds, graphData.outputIds, graphData.inputModelDataIds, graphData.outputModelDataIds);
    }

    /**
     * Returns the largest node id in {@code graphData} plus one, or 0 when there are no nodes.
     *
     * <p>Used as the stage count when deriving per-node paths. It is shared by save and load so
     * both compute identical paths.
     */
    private static int getNumStages(GraphData graphData) {
        return graphData.nodes.stream().mapToInt(node -> node.nodeId).max().orElse(-1) + 1;
    }

    /**
     * Instantiates a stage from the metadata at {@code path} via
     * {@code ParamUtils.instantiateWithParams}.
     *
     * <p>An empty expected class name is passed to {@code FileUtils.loadMetadata}. NOTE(review):
     * presumably this disables the class-name check; confirm.
     *
     * @throws IOException if reading the metadata fails
     * @throws RuntimeException wrapping a {@link ClassNotFoundException}
     */
    public static <T extends Stage<T>> T loadStageParam(String path) throws IOException {
        try {
            return (T) ParamUtils.instantiateWithParams(FileUtils.loadMetadata(path, ""));
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Failed to load stage.", e);
        }
    }

    /**
     * Loads the stage saved at {@code path} by reflectively invoking the public static
     * {@code load(StreamTableEnvironment, String)} method of the class recorded under
     * {@code "className"} in its metadata.
     *
     * @throws IOException if reading the metadata fails
     * @throws NullPointerException if the metadata has no {@code "className"} entry
     * @throws RuntimeException if the class or its {@code load} method cannot be found, or if the
     *     invocation fails
     */
    public static Stage<?> loadStage(StreamTableEnvironment tEnv, String path) throws IOException {
        String className = (String) FileUtils.loadMetadata(path, "").get("className");
        Preconditions.checkNotNull(className, "Metadata at " + path + " has no className entry.");
        try {
            Method loadMethod = Class.forName(className).getMethod("load", StreamTableEnvironment.class, String.class);
            // Allows invoking a public method declared on a class that is not itself public.
            loadMethod.setAccessible(true);
            return (Stage<?>) loadMethod.invoke(null, tEnv, path);
        } catch (ClassNotFoundException | IllegalAccessException | InvocationTargetException e) {
            throw new RuntimeException("Failed to load stage.", e);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException("Failed to load stage because the static method " + String.format("%s::load(StreamTableEnvironment, String)", className) + " is not implemented.", e);
        }
    }

    /**
     * Attaches a row-format {@link FileSink} to {@code modelData}.
     *
     * <p>The sink writes to {@code FileUtils.getDataPath(path)}. It rolls files on checkpoint and
     * writes into the base path, without date-based buckets.
     *
     * @param modelData the stream of model data records
     * @param path the model directory
     * @param modelEncoder encodes each record into the output files
     */
    public static <T> void saveModelData(DataStream<T> modelData, String path, Encoder<T> modelEncoder) {
        FileSink<T> sink = FileSink.forRowFormat(FileUtils.getDataPath(path), modelEncoder)
                .withRollingPolicy(OnCheckpointRollingPolicy.build())
                .withBucketAssigner(new BasePathBucketAssigner<>())
                .build();
        modelData.sinkTo(sink);
    }

    /**
     * Reads model data written under {@code FileUtils.getDataPath(path)} and returns it as a
     * {@link Table}.
     *
     * <p>The data is read as a stream without watermarks, using a source named
     * {@code "modelData"}.
     *
     * @param tEnv the table environment that provides the execution environment and the table
     * @param path the model directory
     * @param modelDecoder decodes records from the data files
     * @return a table backed by the decoded records
     */
    public static <T> Table loadModelData(StreamTableEnvironment tEnv, String path, SimpleStreamFormat<T> modelDecoder) {
        FileSource<T> source = FileSource.forRecordStreamFormat(modelDecoder, FileUtils.getDataPath(path)).build();
        DataStream<T> modelData = TableUtils.getExecutionEnvironment(tEnv).fromSource(source, WatermarkStrategy.noWatermarks(), "modelData");
        return tEnv.fromDataStream(modelData);
    }
}
