package org.apache.samza.sql.translator;

import com.google.common.annotations.VisibleForTesting;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.apache.calcite.adapter.enumerable.EnumerableTableScan;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang.Validate;
import org.apache.samza.SamzaException;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.serializers.KVSerde;
import org.apache.samza.sql.data.SamzaSqlRelMessage;
import org.apache.samza.sql.interfaces.SqlIOConfig;
import org.apache.samza.sql.runner.SamzaSqlApplicationConfig;
import org.apache.samza.sql.serializers.SamzaSqlRelMessageSerdeFactory;
import org.apache.samza.sql.serializers.SamzaSqlRelRecordSerdeFactory;
import org.apache.samza.sql.translator.JoinInputNode;
import org.apache.samza.table.Table;
import org.apache.samza.table.descriptors.CachingTableDescriptor;
import org.apache.samza.table.descriptors.RemoteTableDescriptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/samza/sql/translator/JoinTranslator.class */
/**
 * Translates a Calcite {@link LogicalJoin} into a Samza stream-table join.
 *
 * <p>Exactly one side of the join must resolve to a stream and the other to a table, and only
 * INNER and LEFT/RIGHT OUTER joins are accepted. For a local table, both the table's input stream
 * and the joining stream are repartitioned on the join key before the join. For a remote (or
 * caching) table, the stream is joined directly without repartitioning.
 *
 * <p>The lambdas passed to {@code partitionBy}/{@code map} below capture only local values (never
 * {@code this}), because they must be serializable Samza map functions.
 */
class JoinTranslator {
    private static final Logger log = LoggerFactory.getLogger(JoinTranslator.class);

    /** SQL types accepted as join keys; any other type is rejected by {@link #validateJoinKeys}. */
    private static final Set<SqlTypeName> SUPPORTED_JOIN_KEY_TYPES = EnumSet.of(
        SqlTypeName.BOOLEAN, SqlTypeName.TINYINT, SqlTypeName.SMALLINT, SqlTypeName.INTEGER,
        SqlTypeName.CHAR, SqlTypeName.BIGINT, SqlTypeName.VARCHAR, SqlTypeName.DOUBLE,
        SqlTypeName.FLOAT, SqlTypeName.ANY, SqlTypeName.OTHER);

    private final String logicalOpId;
    private final String intermediateStreamPrefix;
    private final int queryId;
    private final TranslatorInputMetricsMapFunction inputMetricsMF;
    private final TranslatorOutputMetricsMapFunction outputMetricsMF;

    /**
     * Creates a translator for one join operator.
     *
     * @param logicalOpId id of this operator; passed to the metrics map functions and used as the
     *     suffix of intermediate stream names
     * @param intermediateStreamPrefix prefix for intermediate (repartition) stream names; a non-empty
     *     prefix is separated from the generated name by {@code '_'}
     * @param queryId id of the query being translated; passed to the remote-table join function
     */
    public JoinTranslator(String logicalOpId, String intermediateStreamPrefix, int queryId) {
        this.logicalOpId = logicalOpId;
        String separator = intermediateStreamPrefix.isEmpty()
            ? SamzaSqlApplicationConfig.DEFAULT_METADATA_TOPIC_PREFIX
            : "_";
        this.intermediateStreamPrefix = intermediateStreamPrefix + separator;
        this.queryId = queryId;
        this.inputMetricsMF = new TranslatorInputMetricsMapFunction(logicalOpId);
        this.outputMetricsMF = new TranslatorOutputMetricsMapFunction(logicalOpId);
    }

    /**
     * Translates {@code join} into a stream-table join and registers the resulting stream under the
     * join's id.
     *
     * @param join the Calcite join node to translate
     * @param translatorContext context holding already-translated input streams
     * @throws SamzaException if the join shape, type, condition or key types are not supported
     */
    public void translate(LogicalJoin join, TranslatorContext translatorContext) {
        JoinInputNode.InputType leftType = getInputType(join.getLeft(), translatorContext);
        JoinInputNode.InputType rightType = getInputType(join.getRight(), translatorContext);
        validateJoinQuery(join, leftType, rightType);

        // After validation exactly one side is a table, so the table is on the right iff the right side is not a stream.
        boolean isTablePosOnRight = rightType != JoinInputNode.InputType.STREAM;

        List<Integer> streamKeyIds = new ArrayList<>();
        List<Integer> tableKeyIds = new ArrayList<>();
        // validateJoinQuery has already verified that the condition is a RexCall.
        populateStreamAndTableKeyIds(((RexCall) join.getCondition()).getOperands(), join, isTablePosOnRight,
            streamKeyIds, tableKeyIds);

        JoinInputNode streamNode = new JoinInputNode(isTablePosOnRight ? join.getLeft() : join.getRight(),
            streamKeyIds, isTablePosOnRight ? leftType : rightType, !isTablePosOnRight);
        JoinInputNode tableNode = new JoinInputNode(isTablePosOnRight ? join.getRight() : join.getLeft(),
            tableKeyIds, isTablePosOnRight ? rightType : leftType, isTablePosOnRight);

        MessageStream<SamzaSqlRelMessage> inputStream =
            translatorContext.getMessageStream(streamNode.getRelNode().getId());
        Table table = getTable(tableNode, translatorContext);
        MessageStream<SamzaSqlRelMessage> outputStream =
            joinStreamWithTable(inputStream, table, streamNode, tableNode, join, translatorContext);
        translatorContext.registerMessageStream(join.getId(), outputStream);
        // The metrics map is a side branch; its output stream is intentionally not used further.
        outputStream.map(this.outputMetricsMF);
    }

    /**
     * Joins {@code inputStream} with {@code table}.
     *
     * <p>For a remote table the stream is joined directly. For a local table the stream is first
     * repartitioned on its join key, which is built with the table's key field names.
     *
     * @return the joined stream
     */
    private MessageStream<SamzaSqlRelMessage> joinStreamWithTable(MessageStream<SamzaSqlRelMessage> inputStream,
            Table table, JoinInputNode streamNode, JoinInputNode tableNode, LogicalJoin join,
            TranslatorContext translatorContext) {
        List<Integer> streamKeyIds = streamNode.getKeyIds();
        List<Integer> tableKeyIds = tableNode.getKeyIds();
        Validate.isTrue(streamKeyIds.size() == tableKeyIds.size());

        log.info("Joining on the following Stream and Table field(s): ");
        // Copied into ArrayLists: tableFieldNames is captured by a serializable lambda below.
        List<String> streamFieldNames = new ArrayList<>(streamNode.getFieldNames());
        List<String> tableFieldNames = new ArrayList<>(tableNode.getFieldNames());
        for (int i = 0; i < streamKeyIds.size(); i++) {
            log.info("{} with {}", streamFieldNames.get(streamKeyIds.get(i)), tableFieldNames.get(tableKeyIds.get(i)));
        }

        if (tableNode.isRemoteTable()) {
            // Remote and caching tables are joined without repartitioning the stream.
            String sourceName = tableNode.getSourceName();
            return inputStream.map(this.inputMetricsMF)
                .join(table, new SamzaSqlRemoteTableJoinFunction(translatorContext.getMsgConverter(sourceName),
                    translatorContext.getTableKeyConverter(sourceName), streamNode, tableNode, join.getJoinType(),
                    this.queryId));
        }

        // Local table: repartition the stream on the join key. The key uses the table's key field names,
        // presumably so that stream keys match the keys getTable() writes into the table (TODO confirm).
        SamzaSqlRelRecordSerdeFactory.SamzaSqlRelRecordSerde keySerde =
            (SamzaSqlRelRecordSerdeFactory.SamzaSqlRelRecordSerde) new SamzaSqlRelRecordSerdeFactory().getSerde(null, null);
        SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde valueSerde =
            (SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde) new SamzaSqlRelMessageSerdeFactory().getSerde(null, null);
        return inputStream.map(this.inputMetricsMF)
            .partitionBy(
                message -> SamzaSqlRelMessage.createSamzaSqlCompositeKey(message, streamKeyIds,
                    SamzaSqlRelMessage.getSamzaSqlCompositeKeyFieldNames(tableFieldNames, tableKeyIds)),
                message -> message,
                KVSerde.of(keySerde, valueSerde),
                this.intermediateStreamPrefix + "stream_" + this.logicalOpId)
            .map(kv -> kv.getValue())
            .join(table, new SamzaSqlLocalTableJoinFunction(streamNode, tableNode, join.getJoinType()));
    }

    /**
     * Validates that the join is supported.
     *
     * <p>A supported join is an INNER or LEFT/RIGHT OUTER join between exactly one stream and one
     * table. For an outer join, the stream must be on the preserved side. The join condition must
     * also pass {@link #validateJoinCondition}.
     *
     * @throws SamzaException if any of these constraints is violated
     */
    private void validateJoinQuery(LogicalJoin join, JoinInputNode.InputType leftType,
            JoinInputNode.InputType rightType) {
        JoinRelType joinType = join.getJoinType();
        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT && joinType != JoinRelType.RIGHT) {
            throw new SamzaException("Query with only INNER and LEFT/RIGHT OUTER join are supported.");
        }

        boolean isLeftTable = leftType != JoinInputNode.InputType.STREAM;
        boolean isRightTable = rightType != JoinInputNode.InputType.STREAM;
        if (!isLeftTable && !isRightTable) {
            throw new SamzaException("Invalid query with both sides of join being denoted as 'stream'. Stream-stream join is not yet supported. " + dumpRelPlanForNode(join));
        }
        if (isLeftTable && isRightTable) {
            throw new SamzaException("Invalid query with both sides of join being denoted as 'table'. " + dumpRelPlanForNode(join));
        }
        // From here on exactly one side is a table.
        if (joinType == JoinRelType.LEFT && isLeftTable) {
            throw new SamzaException("Invalid query for outer left join. Left side of the join should be a 'stream' and right side of join should be a 'table'. " + dumpRelPlanForNode(join));
        }
        if (joinType == JoinRelType.RIGHT && isRightTable) {
            throw new SamzaException("Invalid query for outer right join. Left side of the join should be a 'table' and right side of join should be a 'stream'. " + dumpRelPlanForNode(join));
        }

        validateJoinCondition(join.getCondition());
    }

    /**
     * Validates one join condition node.
     *
     * <p>The node must be a non-trivially-true {@link RexCall} of kind EQUALS or AND.
     *
     * @throws SamzaException otherwise
     */
    private void validateJoinCondition(RexNode condition) {
        if (!(condition instanceof RexCall)) {
            throw new SamzaException("SQL Query is not supported. Join condition operand " + condition + " is of type " + condition.getClass());
        }
        RexCall call = (RexCall) condition;
        if (call.isAlwaysTrue()) {
            throw new SamzaException("Query results in a cross join, which is not supported. Please optimize the query. It is expected that the joins should include JOIN ON operator in the sql query.");
        }
        if (call.getKind() != SqlKind.EQUALS && call.getKind() != SqlKind.AND) {
            throw new SamzaException("Only equi-joins and AND operator is supported in join condition.");
        }
    }

    /**
     * Walks the join condition operands and records, for each equality, the key field index on the
     * stream side and on the table side.
     *
     * <p>Nested calls (the operands of an AND) are validated and visited recursively. A leaf
     * equality must compare exactly two field references, one from each join input.
     *
     * @param operands operands of the current condition call
     * @param join the join being translated
     * @param isTablePosOnRight whether the table is the right input of the join
     * @param streamKeyIds receives key field indexes relative to the stream input
     * @param tableKeyIds receives key field indexes relative to the table input
     * @throws SamzaException if an operand is not a supported field reference or both references
     *     point at the same join input
     */
    private void populateStreamAndTableKeyIds(List<RexNode> operands, LogicalJoin join, boolean isTablePosOnRight,
            List<Integer> streamKeyIds, List<Integer> tableKeyIds) {
        if (!operands.isEmpty() && operands.get(0) instanceof RexCall) {
            for (RexNode operand : operands) {
                validateJoinCondition(operand);
                populateStreamAndTableKeyIds(((RexCall) operand).getOperands(), join, isTablePosOnRight,
                    streamKeyIds, tableKeyIds);
            }
            return;
        }

        Validate.isTrue(operands.size() == 2);
        if (!(operands.get(0) instanceof RexInputRef) || !(operands.get(1) instanceof RexInputRef)) {
            throw new SamzaException("SQL query is not supported. Join condition " + join.getCondition() + " should have reference operands but the types are " + operands.get(0).getClass() + " and " + operands.get(1).getClass());
        }

        RexInputRef leftRef = (RexInputRef) operands.get(0);
        RexInputRef rightRef = (RexInputRef) operands.get(1);
        validateJoinKeys(leftRef);
        validateJoinKeys(rightRef);

        // Join field indexes list the left input's fields first and then the right input's.
        // Order the pair so that leftRef has the smaller index.
        if (leftRef.getIndex() > rightRef.getIndex()) {
            RexInputRef tmp = leftRef;
            leftRef = rightRef;
            rightRef = tmp;
        }

        int leftFieldCount = join.getLeft().getRowType().getFieldCount();
        if (leftRef.getIndex() >= leftFieldCount || rightRef.getIndex() < leftFieldCount) {
            throw new SamzaException("SQL query is not supported. Join condition " + join.getCondition()
                + " must compare a field of the left input with a field of the right input.");
        }

        int leftKeyId = leftRef.getIndex();
        // Right-side references are offset by the number of fields in the left input.
        int rightKeyId = rightRef.getIndex() - leftFieldCount;
        streamKeyIds.add(isTablePosOnRight ? leftKeyId : rightKeyId);
        tableKeyIds.add(isTablePosOnRight ? rightKeyId : leftKeyId);
    }

    /**
     * Ensures a join key reference has one of the {@link #SUPPORTED_JOIN_KEY_TYPES}.
     *
     * @throws SamzaException if the key type is not supported
     */
    private void validateJoinKeys(RexInputRef keyRef) {
        SqlTypeName keyType = keyRef.getType().getSqlTypeName();
        // EnumSet.contains(null) is false, so a null type is rejected as unsupported.
        if (!SUPPORTED_JOIN_KEY_TYPES.contains(keyType)) {
            log.error("Unsupported key type {} used in join condition.", keyType);
            throw new SamzaException("Unsupported key type used in join condition.");
        }
    }

    /** Renders {@code relNode}'s plan (with attributes) for inclusion in error messages. */
    private String dumpRelPlanForNode(RelNode relNode) {
        return RelOptUtil.dumpPlan("Rel expression: ", relNode, SqlExplainFormat.TEXT, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
    }

    /**
     * Resolves the input source config backing {@code relNode}, looking through projections.
     *
     * @return the source config, or {@code null} when the node is not backed by a single table
     *     (multiple inputs, e.g. another join, or no table, e.g. a filter); callers treat null as a
     *     non-table input
     * @throws SamzaException if the node's table does not match any configured input source
     */
    private SqlIOConfig resolveSourceConfigForTable(RelNode relNode, TranslatorContext translatorContext) {
        if (relNode instanceof LogicalProject) {
            return resolveSourceConfigForTable(((LogicalProject) relNode).getInput(), translatorContext);
        }
        if (relNode.getInputs().size() > 1 || relNode.getTable() == null) {
            return null;
        }

        String sourceName = SqlIOConfig.getSourceFromSourceParts(relNode.getTable().getQualifiedName());
        SqlIOConfig sourceConfig = translatorContext.getExecutionContext().getSamzaSqlApplicationConfig()
            .getInputSystemStreamConfigBySource().get(sourceName);
        if (sourceConfig == null) {
            throw new SamzaException("Unsupported source found in join statement: " + sourceName);
        }
        return sourceConfig;
    }

    /**
     * Classifies a join input.
     *
     * <p>Only table scans and projections can be tables. They are tables when their source config
     * has a table descriptor: remote and caching descriptors map to {@code REMOTE_TABLE}, and any
     * other descriptor maps to {@code LOCAL_TABLE}. Everything else is a {@code STREAM}.
     */
    private JoinInputNode.InputType getInputType(RelNode relNode, TranslatorContext translatorContext) {
        if (!(relNode instanceof EnumerableTableScan) && !(relNode instanceof LogicalProject)) {
            return JoinInputNode.InputType.STREAM;
        }

        SqlIOConfig sourceConfig = resolveSourceConfigForTable(relNode, translatorContext);
        if (sourceConfig == null || !sourceConfig.getTableDescriptor().isPresent()) {
            return JoinInputNode.InputType.STREAM;
        }

        Object descriptor = sourceConfig.getTableDescriptor().get();
        if (descriptor instanceof RemoteTableDescriptor || descriptor instanceof CachingTableDescriptor) {
            return JoinInputNode.InputType.REMOTE_TABLE;
        }
        return JoinInputNode.InputType.LOCAL_TABLE;
    }

    /**
     * Obtains the Samza table for the table side of the join.
     *
     * <p>For a local table, this also wires the table's input stream into the table. That stream is
     * repartitioned on the join key through an intermediate stream named
     * {@code <prefix>table_<logicalOpId>}.
     *
     * @throws SamzaException if no table source can be resolved for the node
     */
    private Table getTable(JoinInputNode tableNode, TranslatorContext translatorContext) {
        SqlIOConfig sourceConfig = resolveSourceConfigForTable(tableNode.getRelNode(), translatorContext);
        if (sourceConfig == null || !sourceConfig.getTableDescriptor().isPresent()) {
            String errMsg = "Failed to resolve table source in join operation: node=" + tableNode.getRelNode();
            log.error(errMsg);
            throw new SamzaException(errMsg);
        }

        Table table = translatorContext.getStreamAppDescriptor().getTable(sourceConfig.getTableDescriptor().get());
        if (tableNode.isRemoteTable()) {
            return table;
        }

        MessageStream<SamzaSqlRelMessage> relOutputStream =
            translatorContext.getMessageStream(tableNode.getRelNode().getId());
        SamzaSqlRelRecordSerdeFactory.SamzaSqlRelRecordSerde keySerde =
            (SamzaSqlRelRecordSerdeFactory.SamzaSqlRelRecordSerde) new SamzaSqlRelRecordSerdeFactory().getSerde(null, null);
        SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde valueSerde =
            (SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde) new SamzaSqlRelMessageSerdeFactory().getSerde(null, null);
        List<Integer> tableKeyIds = tableNode.getKeyIds();

        relOutputStream
            .partitionBy(
                message -> SamzaSqlRelMessage.createSamzaSqlCompositeKey(message, tableKeyIds),
                message -> message,
                KVSerde.of(keySerde, valueSerde),
                this.intermediateStreamPrefix + "table_" + this.logicalOpId)
            .sendTo(table);
        return table;
    }

    @VisibleForTesting
    public TranslatorInputMetricsMapFunction getInputMetricsMF() {
        return this.inputMetricsMF;
    }

    @VisibleForTesting
    public TranslatorOutputMetricsMapFunction getOutputMetricsMF() {
        return this.outputMetricsMF;
    }
}
