/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.inlong.manager.service.sort.util;

import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.curator.shaded.com.google.common.collect.Maps;
import org.apache.inlong.manager.common.enums.TransformType;
import org.apache.inlong.manager.common.pojo.sink.StreamSink;
import org.apache.inlong.manager.common.pojo.source.StreamSource;
import org.apache.inlong.manager.common.pojo.stream.InlongStreamInfo;
import org.apache.inlong.manager.common.pojo.stream.StreamField;
import org.apache.inlong.manager.common.pojo.stream.StreamNode;
import org.apache.inlong.manager.common.pojo.stream.StreamPipeline;
import org.apache.inlong.manager.common.pojo.stream.StreamTransform;
import org.apache.inlong.manager.common.pojo.transform.TransformDefinition;
import org.apache.inlong.manager.common.pojo.transform.TransformResponse;
import org.apache.inlong.manager.common.pojo.transform.joiner.JoinerDefinition;
import org.apache.inlong.manager.common.pojo.transform.joiner.JoinerDefinition.JoinMode;
import org.apache.inlong.manager.common.util.StreamParseUtils;
import org.apache.inlong.sort.protocol.FieldInfo;
import org.apache.inlong.sort.protocol.StreamInfo;
import org.apache.inlong.sort.protocol.node.Node;
import org.apache.inlong.sort.protocol.node.transform.TransformNode;
import org.apache.inlong.sort.protocol.transformation.FilterFunction;
import org.apache.inlong.sort.protocol.transformation.LogicOperator;
import org.apache.inlong.sort.protocol.transformation.function.SingleValueFilterFunction;
import org.apache.inlong.sort.protocol.transformation.operator.AndOperator;
import org.apache.inlong.sort.protocol.transformation.operator.EmptyOperator;
import org.apache.inlong.sort.protocol.transformation.operator.EqualOperator;
import org.apache.inlong.sort.protocol.transformation.relation.InnerJoinNodeRelation;
import org.apache.inlong.sort.protocol.transformation.relation.LeftOuterJoinNodeRelation;
import org.apache.inlong.sort.protocol.transformation.relation.NodeRelation;
import org.apache.inlong.sort.protocol.transformation.relation.RightOuterJoinNodeRelation;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Util for creating node relations.
 */
@Slf4j
public class NodeRelationUtils {

    /**
     * Create node relations for the given stream.
     *
     * <p>The pipeline is parsed from the stream's ext params via {@link StreamParseUtils#parseStreamPipeline};
     * each pipeline entry becomes a plain {@link NodeRelation} from its input nodes to its output nodes.
     *
     * @param streamInfo inlong stream whose ext params hold the pipeline definition
     * @return the node relations of the stream, or an empty list when the ext params are empty
     */
    public static List<NodeRelation> createNodeRelationsForStream(InlongStreamInfo streamInfo) {
        String extParams = streamInfo.getExtParams();
        if (StringUtils.isEmpty(extParams)) {
            log.warn("stream node relation is empty for {}", streamInfo);
            return Lists.newArrayList();
        }
        StreamPipeline pipeline = StreamParseUtils.parseStreamPipeline(extParams, streamInfo.getInlongStreamId());
        return pipeline.getPipeline().stream()
                .map(nodeRelation -> new NodeRelation(
                        Lists.newArrayList(nodeRelation.getInputNodes()),
                        Lists.newArrayList(nodeRelation.getOutputNodes())))
                .collect(Collectors.toList());
    }

    /**
     * Optimize the relations of the given stream by rebuilding join relations.
     *
     * <p>Every relation whose single output is a transform node with a joiner definition is replaced by
     * a join relation (left outer, inner or right outer) built from that joiner definition. The stream's
     * relation list is modified in place: replaced relations are removed and the rebuilt join relations
     * are appended to the end of the list.
     *
     * @param streamInfo sort stream info whose nodes and relations are inspected and updated
     * @param transformResponses transforms of the stream, used to look up each transform node's definition
     * @throws IllegalArgumentException if a joiner definition is invalid (see {@code getNodeRelation})
     */
    public static void optimizeNodeRelation(StreamInfo streamInfo, List<TransformResponse> transformResponses) {
        if (CollectionUtils.isEmpty(transformResponses)) {
            return;
        }
        Map<String, TransformDefinition> transformTypeMap = transformResponses.stream().collect(
                Collectors.toMap(TransformResponse::getTransformName, transformResponse -> {
                    TransformType transformType = TransformType.forType(transformResponse.getTransformType());
                    return StreamParseUtils.parseTransformDefinition(transformResponse.getTransformDefinition(),
                            transformType);
                }));
        List<Node> nodes = streamInfo.getNodes();
        List<NodeRelation> relations = streamInfo.getRelations();
        if (CollectionUtils.isEmpty(nodes) || CollectionUtils.isEmpty(relations)) {
            // nothing to rebuild
            return;
        }

        // Transform nodes whose definition is a joiner, keyed by node name. Transform nodes without a
        // matching transform response are skipped instead of failing with a NullPointerException.
        Map<String, TransformNode> joinNodes = nodes.stream()
                .filter(node -> node instanceof TransformNode)
                .map(node -> (TransformNode) node)
                .filter(transformNode -> {
                    TransformDefinition transformDefinition = transformTypeMap.get(transformNode.getName());
                    return transformDefinition != null
                            && transformDefinition.getTransformType() == TransformType.JOINER;
                })
                .collect(Collectors.toMap(TransformNode::getName, transformNode -> transformNode));

        Iterator<NodeRelation> relationIterator = relations.iterator();
        List<NodeRelation> joinRelations = Lists.newArrayList();
        while (relationIterator.hasNext()) {
            NodeRelation relation = relationIterator.next();
            List<String> outputs = relation.getOutputs();
            // only a relation feeding exactly one joiner node is rebuilt as a join relation
            if (outputs.size() != 1 || !joinNodes.containsKey(outputs.get(0))) {
                continue;
            }
            JoinerDefinition joinerDefinition = (JoinerDefinition) transformTypeMap.get(outputs.get(0));
            joinRelations.add(getNodeRelation(joinerDefinition, relation));
            // remove through the iterator so the original relation is replaced rather than duplicated
            relationIterator.remove();
        }
        relations.addAll(joinRelations);
    }

    /**
     * Build the join relation described by a joiner definition.
     *
     * <p>The inputs of the returned relation are the joiner's left and right nodes. The outputs are
     * taken from {@code nodeRelation}. The join conditions are keyed by the right node name and contain
     * one equality condition per pair of left/right join fields.
     *
     * @param joinerDefinition joiner definition holding the join mode, the left/right nodes and join fields
     * @param nodeRelation the original relation whose outputs are kept
     * @return a left outer, inner or right outer join relation
     * @throws IllegalArgumentException if the join mode is missing or unsupported, or the left/right join
     *         field lists are null or differ in size
     */
    private static NodeRelation getNodeRelation(JoinerDefinition joinerDefinition, NodeRelation nodeRelation) {
        JoinMode joinMode = joinerDefinition.getJoinMode();
        if (joinMode == null) {
            // switching on a null enum would throw an uninformative NullPointerException
            throw new IllegalArgumentException(String.format("Unsupported join mode=%s for inlong", joinMode));
        }
        String leftNode = getNodeName(joinerDefinition.getLeftNode());
        String rightNode = getNodeName(joinerDefinition.getRightNode());
        List<String> preNodes = Lists.newArrayList(leftNode, rightNode);

        List<StreamField> leftJoinFields = joinerDefinition.getLeftJoinFields();
        List<StreamField> rightJoinFields = joinerDefinition.getRightJoinFields();
        if (leftJoinFields == null || rightJoinFields == null
                || leftJoinFields.size() != rightJoinFields.size()) {
            throw new IllegalArgumentException(String.format(
                    "left join fields %s and right join fields %s must be non-null and of the same size",
                    leftJoinFields, rightJoinFields));
        }
        List<FilterFunction> filterFunctions = Lists.newArrayList();
        for (int index = 0; index < leftJoinFields.size(); index++) {
            // SingleValueFilterFunction takes its logic operator ahead of the compared fields, so the
            // operator prefixes its own condition: the first condition gets none, and every following
            // condition is chained with AND.
            LogicOperator operator = index == 0 ? EmptyOperator.getInstance() : AndOperator.getInstance();
            filterFunctions.add(createFilterFunction(leftJoinFields.get(index), rightJoinFields.get(index),
                    operator));
        }
        Map<String, List<FilterFunction>> joinConditions = Maps.newHashMap();
        joinConditions.put(rightNode, filterFunctions);

        switch (joinMode) {
            case LEFT_JOIN:
                return new LeftOuterJoinNodeRelation(preNodes, nodeRelation.getOutputs(), joinConditions);
            case INNER_JOIN:
                return new InnerJoinNodeRelation(preNodes, nodeRelation.getOutputs(), joinConditions);
            case RIGHT_JOIN:
                return new RightOuterJoinNodeRelation(preNodes, nodeRelation.getOutputs(), joinConditions);
            default:
                throw new IllegalArgumentException(String.format("Unsupported join mode=%s for inlong", joinMode));
        }
    }

    /**
     * Create an equality condition {@code leftField = rightField}, prefixed by the given logic operator.
     *
     * @param leftField join field of the left node
     * @param rightField join field of the right node
     * @param operator logic operator passed as the first argument of the filter function
     * @return the filter function comparing the two origin fields for equality
     */
    private static SingleValueFilterFunction createFilterFunction(StreamField leftField, StreamField rightField,
            LogicOperator operator) {
        FieldInfo sourceField = new FieldInfo(leftField.getOriginFieldName(), leftField.getOriginNodeName(),
                FieldInfoUtils.convertFieldFormat(leftField.getFieldType(), leftField.getFieldFormat()));
        FieldInfo targetField = new FieldInfo(rightField.getOriginFieldName(), rightField.getOriginNodeName(),
                FieldInfoUtils.convertFieldFormat(rightField.getFieldType(), rightField.getFieldFormat()));
        return new SingleValueFilterFunction(operator, sourceField, EqualOperator.getInstance(), targetField);
    }

    /**
     * Resolve the name of a source, sink or transform node.
     *
     * @param node the stream node to name
     * @return the source name, sink name or transform name of the node
     * @throws IllegalArgumentException if the node is null or of an unsupported type
     */
    private static String getNodeName(StreamNode node) {
        if (node instanceof StreamSource) {
            return ((StreamSource) node).getSourceName();
        }
        if (node instanceof StreamSink) {
            return ((StreamSink) node).getSinkName();
        }
        if (node instanceof StreamTransform) {
            return ((StreamTransform) node).getTransformName();
        }
        throw new IllegalArgumentException(String.format("Unsupported stream node=%s for join", node));
    }

}
