package org.apache.flink.ml.benchmark;

import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.benchmark.datagenerator.DataGenerator;
import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.metrics.MLMetrics;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.Preconditions;

import java.io.FileInputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/* loaded from: input_file:org/apache/flink/ml/benchmark/BenchmarkUtils.class */
public class BenchmarkUtils {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/benchmark/BenchmarkUtils$CountingAndDiscardingSink.class */
    public static class CountingAndDiscardingSink<T> extends RichSinkFunction<T> {
        public static final String COUNTER_NAME = "numElements";
        private static final long serialVersionUID = 1;
        private final LongCounter numElementsCounter;

        private CountingAndDiscardingSink() {
            this.numElementsCounter = new LongCounter();
        }

        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            getRuntimeContext().addAccumulator(COUNTER_NAME, this.numElementsCounter);
        }

        public void invoke(T t, SinkFunction.Context context) {
            this.numElementsCounter.add(serialVersionUID);
        }
    }

    public static Map<String, Map<String, Map<String, ?>>> parseJsonFile(String str) throws IOException {
        Map map = (Map) ReadWriteUtils.OBJECT_MAPPER.readValue(new FileInputStream(str), Map.class);
        Preconditions.checkArgument(map.containsKey(MLMetrics.VERSION) && map.get(MLMetrics.VERSION).equals(1));
        HashMap hashMap = new HashMap();
        for (Map.Entry entry : map.entrySet()) {
            if (!((String) entry.getKey()).equals(MLMetrics.VERSION)) {
                hashMap.put((String) entry.getKey(), (Map) entry.getValue());
            }
        }
        return hashMap;
    }

    public static BenchmarkResult runBenchmark(StreamTableEnvironment streamTableEnvironment, String str, Map<String, Map<String, ?>> map, boolean z) throws Exception {
        Stage stage = (Stage) ParamUtils.instantiateWithParams(map.get("stage"));
        InputDataGenerator inputDataGenerator = (InputDataGenerator) ParamUtils.instantiateWithParams(map.get("inputData"));
        DataGenerator dataGenerator = null;
        if (map.containsKey("modelData")) {
            dataGenerator = (DataGenerator) ParamUtils.instantiateWithParams(map.get("modelData"));
        }
        return runBenchmark(streamTableEnvironment, str, stage, inputDataGenerator, dataGenerator, z);
    }

    private static BenchmarkResult runBenchmark(StreamTableEnvironment streamTableEnvironment, String str, Stage<?> stage, InputDataGenerator<?> inputDataGenerator, DataGenerator<?> dataGenerator, boolean z) throws Exception {
        Table[] transform;
        StreamExecutionEnvironment executionEnvironment = TableUtils.getExecutionEnvironment(streamTableEnvironment);
        Table[] data = inputDataGenerator.getData(streamTableEnvironment);
        if (dataGenerator != null) {
            ((Model) stage).setModelData(dataGenerator.getData(streamTableEnvironment));
        }
        if (stage instanceof Estimator) {
            transform = ((Estimator) stage).fit(data).getModelData();
        } else {
            if (!(stage instanceof AlgoOperator)) {
                throw new IllegalArgumentException("Unsupported Stage class " + stage.getClass());
            }
            transform = ((AlgoOperator) stage).transform(data);
        }
        for (Table table : transform) {
            streamTableEnvironment.toDataStream(table).addSink(new CountingAndDiscardingSink());
        }
        if (z) {
            return null;
        }
        JobExecutionResult execute = executionEnvironment.execute("Flink ML Benchmark Job " + str);
        double netRuntime = execute.getNetRuntime(TimeUnit.MILLISECONDS);
        long numValues = inputDataGenerator.getNumValues();
        double d = (numValues * 1000.0d) / netRuntime;
        long longValue = ((Long) execute.getAccumulatorResult(CountingAndDiscardingSink.COUNTER_NAME)).longValue();
        return new BenchmarkResult(str, Double.valueOf(netRuntime), Long.valueOf(numValues), Double.valueOf(d), Long.valueOf(longValue), Double.valueOf((longValue * 1000.0d) / netRuntime));
    }
}
