package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping;

import java.lang.invoke.SerializedLambda;
import java.util.Map;
import org.apache.nemo.common.KeyExtractor;
import org.apache.nemo.common.coder.DecoderFactory;
import org.apache.nemo.common.coder.EncoderFactory;
import org.apache.nemo.common.coder.LongDecoderFactory;
import org.apache.nemo.common.coder.LongEncoderFactory;
import org.apache.nemo.common.coder.PairDecoderFactory;
import org.apache.nemo.common.coder.PairEncoderFactory;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyDecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyEncoderProperty;
import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex;
import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex;

/**
 * Static helpers used when reshaping a DAG for skew handling.
 *
 * <p>Provides:
 * <ul>
 *   <li>the function run by a {@link MessageGeneratorVertex}, which counts elements per key;</li>
 *   <li>the function run by a {@link MessageAggregatorVertex}, which sums those per-key counts;</li>
 *   <li>the encoder/decoder properties for the (key, {@code Long} count) pairs exchanged between them.</li>
 * </ul>
 */
public final class SkewHandlingUtil {
    /** Not instantiable: this class only exposes static helpers. */
    private SkewHandlingUtil() {
    }

    /**
     * Builds the per-element function that counts how many elements map to each key.
     *
     * <p>For every element, the key is obtained through {@code keyExtractor}. That key's counter in
     * the supplied map is then incremented; a key not yet in the map starts at {@code 1}. The same
     * (mutated) map instance is returned.
     *
     * @param keyExtractor extracts the key of each incoming element.
     * @return the message generator function.
     */
    public static MessageGeneratorVertex.MessageGeneratorFunction<Object, Object, Long> getMessageGenerator(
            final KeyExtractor keyExtractor) {
        return (element, countsPerKey) -> {
            final Object key = keyExtractor.extractKey(element);
            // merge == "put 1 if absent, otherwise add 1 to the existing count".
            countsPerKey.merge(key, 1L, Long::sum);
            return countsPerKey;
        };
    }

    /**
     * Builds the function that folds (key, count) pairs into a running per-key total.
     *
     * <p>The count in each pair is added to the total already stored for its key. A key not yet in
     * the map is stored with the pair's count. The same (mutated) map instance is returned.
     *
     * @return the message aggregator function.
     */
    public static MessageAggregatorVertex.MessageAggregatorFunction<Object, Long, Map<Object, Long>>
            getMessageAggregator() {
        return (keyAndCount, totalsPerKey) -> {
            final Object key = keyAndCount.left();
            // The cast is kept explicit, mirroring the decompiled original. It is redundant if the
            // pair's right side is already typed Long.
            final Long count = (Long) keyAndCount.right();
            totalsPerKey.merge(key, count, Long::sum);
            return totalsPerKey;
        };
    }

    /**
     * Builds the encoder property for (key, count) pairs on the given edge.
     *
     * <p>The key part uses the edge's {@link KeyEncoderProperty}; the count part uses a
     * {@link LongEncoderFactory}.
     *
     * @param edge the edge whose key encoder is used.
     * @return an {@link EncoderProperty} wrapping a {@link PairEncoderFactory}.
     * @throws IllegalStateException if the edge has no {@link KeyEncoderProperty}.
     */
    public static EncoderProperty getEncoder(final IREdge edge) {
        final EncoderFactory keyEncoderFactory = (EncoderFactory) edge.getPropertyValue(KeyEncoderProperty.class)
            .orElseThrow(() -> new IllegalStateException(
                "KeyEncoderProperty is not set on edge " + edge.getId()));
        return EncoderProperty.of(PairEncoderFactory.of(keyEncoderFactory, LongEncoderFactory.of()));
    }

    /**
     * Builds the decoder property for (key, count) pairs on the given edge.
     *
     * <p>The key part uses the edge's {@link KeyDecoderProperty}; the count part uses a
     * {@link LongDecoderFactory}.
     *
     * @param edge the edge whose key decoder is used.
     * @return a {@link DecoderProperty} wrapping a {@link PairDecoderFactory}.
     * @throws IllegalStateException if the edge has no {@link KeyDecoderProperty}.
     */
    public static DecoderProperty getDecoder(final IREdge edge) {
        final DecoderFactory keyDecoderFactory = (DecoderFactory) edge.getPropertyValue(KeyDecoderProperty.class)
            .orElseThrow(() -> new IllegalStateException(
                "KeyDecoderProperty is not set on edge " + edge.getId()));
        return DecoderProperty.of(PairDecoderFactory.of(keyDecoderFactory, LongDecoderFactory.of()));
    }

    // NOTE: there is intentionally no hand-written $deserializeLambda$ here. javac synthesizes that
    // method automatically for serializable lambdas, and a decompiled copy does not compile.
}
