/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nemo.compiler.frontend.spark.core;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.Stack;
import org.apache.nemo.client.JobLauncher;
import org.apache.nemo.common.KeyExtractor;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.dag.Edge;
import org.apache.nemo.common.dag.Vertex;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.compiler.frontend.spark.SparkBroadcastVariables;
import org.apache.nemo.compiler.frontend.spark.SparkKeyExtractor;
import org.apache.nemo.compiler.frontend.spark.coder.SparkDecoderFactory;
import org.apache.nemo.compiler.frontend.spark.coder.SparkEncoderFactory;
import org.apache.nemo.compiler.frontend.spark.transform.CollectTransform;
import org.apache.nemo.compiler.frontend.spark.transform.GroupByKeyTransform;
import org.apache.nemo.compiler.frontend.spark.transform.ReduceByKeyTransform;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.serializer.JavaSerializer;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import scala.Function1;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.JavaConverters;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

/**
 * Static helpers used by the Spark frontend: serializer selection, launching a DAG that ends in a
 * collect step, choosing edge communication patterns, and adapting Scala functions to Spark's Java
 * function interfaces.
 */
public final class SparkFrontendUtils {
    /** Key-extractor property attached to every collect edge built by {@link #collect}. */
    private static final KeyExtractorProperty SPARK_KEY_EXTRACTOR_PROP = KeyExtractorProperty.of(new SparkKeyExtractor());

    /** Utility class; not instantiable. */
    private SparkFrontendUtils() {
    }

    /**
     * Picks the serializer configured for the given Spark context.
     *
     * @param sparkContext the context whose {@code spark.serializer} setting is read.
     * @return a {@link KryoSerializer} if {@code spark.serializer} is exactly
     *         {@code org.apache.spark.serializer.KryoSerializer}; otherwise a {@link JavaSerializer}.
     */
    public static Serializer deriveSerializerFrom(final SparkContext sparkContext) {
        if (sparkContext.conf().get("spark.serializer", "").equals("org.apache.spark.serializer.KryoSerializer")) {
            return new KryoSerializer(sparkContext.conf());
        }
        return new JavaSerializer(sparkContext.conf());
    }

    /**
     * Appends a collect vertex after {@code lastVertex}, launches the resulting DAG and returns the
     * data gathered by the launcher.
     *
     * @param dag             the DAG built so far.
     * @param loopVertexStack the loop-vertex stack passed to the builder when adding the collect vertex.
     * @param lastVertex      the vertex whose output is collected.
     * @param serializer      the serializer used for the encoder/decoder of the new edge.
     * @param <T>             the element type of the collected data.
     * @return the data returned by {@code JobLauncher.getCollectedData()} after the launch.
     */
    @SuppressWarnings("unchecked")
    public static <T> List<T> collect(final DAG<IRVertex, IREdge> dag,
                                      final Stack<LoopVertex> loopVertexStack,
                                      final IRVertex lastVertex,
                                      final Serializer serializer) {
        final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<IRVertex, IREdge>(dag);

        final OperatorVertex collectVertex = new OperatorVertex(new CollectTransform());
        builder.addVertex(collectVertex, loopVertexStack);

        final IREdge newEdge = new IREdge(getEdgeCommunicationPattern(lastVertex, collectVertex), lastVertex, collectVertex);
        newEdge.setProperty(EncoderProperty.of(new SparkEncoderFactory(serializer)));
        newEdge.setProperty(DecoderProperty.of(new SparkDecoderFactory(serializer)));
        newEdge.setProperty(SPARK_KEY_EXTRACTOR_PROP);
        builder.connectVertices(newEdge);

        JobLauncher.launchDAG(new IRDAG(builder.build()), SparkBroadcastVariables.getAll(), "");
        // The launcher's result is not typed by T; the caller is responsible for the element type.
        return (List<T>) JobLauncher.getCollectedData();
    }

    /**
     * Chooses the communication pattern for an edge from {@code src} to {@code dst}.
     *
     * @param src the source vertex (currently not consulted).
     * @param dst the destination vertex.
     * @return {@code SHUFFLE} if {@code dst} is an {@link OperatorVertex} whose transform is a
     *         {@link ReduceByKeyTransform} or {@link GroupByKeyTransform}; {@code ONE_TO_ONE} otherwise.
     */
    public static CommunicationPatternProperty.Value getEdgeCommunicationPattern(final IRVertex src, final IRVertex dst) {
        if (dst instanceof OperatorVertex) {
            final Object transform = ((OperatorVertex) dst).getTransform();
            if (transform instanceof ReduceByKeyTransform || transform instanceof GroupByKeyTransform) {
                return CommunicationPatternProperty.Value.SHUFFLE;
            }
        }
        return CommunicationPatternProperty.Value.ONE_TO_ONE;
    }

    /**
     * Converts a Scala {@link Function1} into a Spark Java {@link Function}.
     *
     * <p>The Scala function is serialized eagerly with Spark's {@link JavaSerializer}. The returned
     * wrapper carries only those bytes and deserializes them on first {@code call}.
     *
     * @param scalaFunction the Scala function to wrap.
     * @param <I>           the input type.
     * @param <O>           the output type.
     * @return a Java function delegating to a deserialized copy of {@code scalaFunction}.
     */
    public static <I, O> Function<I, O> toJavaFunction(final Function1<I, O> scalaFunction) {
        final ClassTag<Function1<I, O>> classTag = ClassTag$.MODULE$.apply(scalaFunction.getClass());
        final ByteBuffer serializedBuffer = new JavaSerializer().newInstance().serialize(scalaFunction, classTag);
        // Copy only the valid region: ByteBuffer.array() would return the whole backing array
        // (ignoring position/limit), including any unused trailing capacity.
        final byte[] serializedFunction = new byte[serializedBuffer.remaining()];
        serializedBuffer.duplicate().get(serializedFunction);

        return new Function<I, O>() {
            // Transient: only the byte payload should travel when this wrapper is serialized;
            // the cached instance is rebuilt lazily from the bytes after deserialization.
            private transient Function1<I, O> deserializedFunction;

            @Override
            public O call(final I v1) throws Exception {
                if (deserializedFunction == null) {
                    final SerializerInstance js = new JavaSerializer().newInstance();
                    deserializedFunction = js.deserialize(ByteBuffer.wrap(serializedFunction), classTag);
                }
                return deserializedFunction.apply(v1);
            }
        };
    }

    /**
     * Converts a Scala {@link scala.Function2} into a Spark Java {@link Function2}.
     *
     * @param scalaFunction the Scala function to wrap.
     * @param <I1>          the first input type.
     * @param <I2>          the second input type.
     * @param <O>           the output type.
     * @return a Java function that calls {@code scalaFunction.apply(v1, v2)}.
     */
    public static <I1, I2, O> Function2<I1, I2, O> toJavaFunction(final scala.Function2<I1, I2, O> scalaFunction) {
        return new Function2<I1, I2, O>() {
            @Override
            public O call(final I1 v1, final I2 v2) throws Exception {
                return scalaFunction.apply(v1, v2);
            }
        };
    }

    /**
     * Converts a Scala function producing a {@link TraversableOnce} into a Spark {@link FlatMapFunction}.
     *
     * @param scalaFunction the Scala function to wrap.
     * @param <I>           the input type.
     * @param <O>           the element type of the produced collection.
     * @return a flat-map function returning a Java iterator view over the Scala function's result.
     */
    public static <I, O> FlatMapFunction<I, O> toJavaFlatMapFunction(final Function1<I, TraversableOnce<O>> scalaFunction) {
        return new FlatMapFunction<I, O>() {
            @Override
            public java.util.Iterator<O> call(final I i) throws Exception {
                final Iterator<O> scalaIterator = scalaFunction.apply(i).toIterator();
                return JavaConverters.asJavaIteratorConverter(scalaIterator).asJava();
            }
        };
    }

    /**
     * Adapts a Spark {@link PairFunction} into a plain {@link Function} returning {@link Tuple2}.
     *
     * @param pairFunction the pair function to wrap.
     * @param <T>          the input type.
     * @param <K>          the key type of the produced pair.
     * @param <V>          the value type of the produced pair.
     * @return a function that calls {@code pairFunction.call(elem)}.
     */
    public static <T, K, V> Function<T, Tuple2<K, V>> pairFunctionToPlainFunction(final PairFunction<T, K, V> pairFunction) {
        return new Function<T, Tuple2<K, V>>() {
            @Override
            public Tuple2<K, V> call(final T elem) throws Exception {
                return pairFunction.call(elem);
            }
        };
    }
}

