package com.linkedin.feathr.compute;

import com.linkedin.data.template.IntegerMap;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Static utilities for validating, merging and simplifying {@link ComputeGraph}s.
 *
 * <p>Throughout this class a node's ID is its index in {@link ComputeGraph#getNodes()};
 * {@link #validate(ComputeGraph)} enforces that invariant.
 */
public class ComputeGraphs {

    /** Traversal state used by the iterative depth-first walk in {@link #dependencyOrder}. */
    public enum VisitedState {
        NOT_VISITED,
        IN_PROGRESS,
        VISITED
    }

    private ComputeGraphs() {
    }

    /**
     * Checks the structural invariants of a compute graph and returns it unchanged.
     *
     * @param computeGraph the graph to check
     * @return {@code computeGraph} itself, for call chaining
     * @throws RuntimeException if node IDs are not sequential, a dependency points outside the
     *     graph, there is a dependency cycle, or an External node names a feature the graph defines
     */
    public static ComputeGraph validate(ComputeGraph computeGraph) {
        ensureNodeIdsAreSequential(computeGraph);
        ensureNodeReferencesExist(computeGraph);
        ensureNoDependencyCycles(computeGraph);
        ensureNoExternalReferencesToSelf(computeGraph);
        return computeGraph;
    }

    /**
     * Merges several graphs into one.
     *
     * <p>Each input's nodes are copied and their IDs (and dependency IDs) shifted past the nodes
     * already added, so the inputs themselves are not modified. External nodes that name a feature
     * defined somewhere in the merged graph are then replaced by that feature's node, and the
     * result is validated.
     *
     * @param collection graphs to merge, in the order their nodes should appear
     * @return the merged, validated graph
     * @throws RuntimeException if the merged graph fails {@link #validate(ComputeGraph)}
     */
    public static ComputeGraph merge(Collection<ComputeGraph> collection) {
        ComputeGraphBuilder builder = new ComputeGraphBuilder();
        for (ComputeGraph graph : collection) {
            // IDs of this graph are shifted by the number of nodes added so far.
            int offset = builder.peekNextNodeId();
            for (AnyNode node : graph.getNodes()) {
                AnyNode copy = PegasusUtils.copy(node);
                Dependencies.remapDependencies(copy, id -> id + offset);
                builder.addNode(copy);
            }
            graph.getFeatureNames().forEach(
                    (featureName, nodeId) -> builder.addFeatureName(featureName, nodeId + offset));
        }
        return validate(removeExternalNodesForFeaturesDefinedInThisGraph(builder.build(new ComputeGraph(), false)));
    }

    /**
     * Returns a new graph in which structurally identical nodes are collapsed into one.
     *
     * <p>Nodes are compared on copies whose node ID has been zeroed. They are canonicalized in
     * dependency order, so nodes that only become identical after their dependencies were merged
     * are collapsed too. Dependents and feature names that referred to a removed duplicate are
     * redirected to the surviving node. Surviving nodes keep their original relative order. The
     * input graph is not modified.
     *
     * @param computeGraph the graph to simplify
     * @return the de-duplicated graph
     * @throws CloneNotSupportedException if the feature-name table cannot be copied
     * @throws RuntimeException if the graph contains a dependency cycle
     */
    public static ComputeGraph removeRedundancies(ComputeGraph computeGraph) throws CloneNotSupportedException {
        Map<Integer, Set<Integer>> dependentsById = getReverseDependencyIndex(computeGraph);
        Map<Integer, Set<String>> featuresById = getReverseFeatureDependencyIndex(computeGraph);

        // Work on copies with IDs zeroed so that equal content means equal nodes.
        List<AnyNode> nodes = computeGraph.getNodes().stream()
                .map(PegasusUtils::copy)
                .collect(Collectors.toList());
        nodes.forEach(node -> PegasusUtils.setNodeId(node, 0));
        // Copy the table: duplicates' feature names are redirected below, and the caller's graph
        // must not be changed.
        IntegerMap featureNames = computeGraph.getFeatureNames().copy();

        Map<AnyNode, Integer> canonicalIdByNode = new HashMap<>();
        for (int nodeId : dependencyOrder(computeGraph, "Dependency cycle detected at node ")) {
            AnyNode node = nodes.get(nodeId);
            Integer canonicalId = canonicalIdByNode.get(node);
            if (canonicalId == null) {
                canonicalIdByNode.put(node, nodeId);
                continue;
            }
            // Duplicate. Dependents come later in dependency order, so they are not yet hash keys
            // and may safely be mutated here.
            dependentsById.getOrDefault(nodeId, Collections.emptySet()).forEach(dependentId ->
                    Dependencies.remapDependencies(nodes.get(dependentId), dep -> dep == nodeId ? canonicalId : dep));
            featuresById.getOrDefault(nodeId, Collections.emptySet())
                    .forEach(featureName -> featureNames.put(featureName, canonicalId));
        }

        // Restore each survivor's original ID (reindexNodes maps from it), in original ID order.
        List<Integer> survivingIds = new ArrayList<>(canonicalIdByNode.values());
        Collections.sort(survivingIds);
        List<AnyNode> survivingNodes = new ArrayList<>(survivingIds.size());
        for (int id : survivingIds) {
            AnyNode survivor = nodes.get(id);
            PegasusUtils.setNodeId(survivor, id);
            survivingNodes.add(survivor);
        }
        return reindexNodes(survivingNodes, featureNames);
    }

    /**
     * Replaces External nodes whose name is a feature defined in this same graph with a direct
     * reference to that feature's node, then drops those External nodes.
     *
     * <p>Dependencies are remapped in place on {@code computeGraph}'s own nodes.
     *
     * @param computeGraph the graph to rewrite
     * @return {@code computeGraph} itself if nothing was replaced, otherwise a reindexed graph
     */
    private static ComputeGraph removeExternalNodesForFeaturesDefinedInThisGraph(ComputeGraph computeGraph) {
        AnyNodeArray nodes = computeGraph.getNodes();
        IntegerMap featureNames = computeGraph.getFeatureNames();
        Map<Integer, Integer> replacementByExternalId = new HashMap<>();
        for (int nodeId = 0; nodeId < nodes.size(); nodeId++) {
            AnyNode node = nodes.get(nodeId);
            if (node.isExternal()) {
                Integer featureNodeId = featureNames.get(node.getExternal().getName());
                if (featureNodeId != null) {
                    replacementByExternalId.put(nodeId, featureNodeId);
                }
            }
        }
        if (replacementByExternalId.isEmpty()) {
            return computeGraph;
        }
        // Anything that depended on a replaced External node now depends on the feature's node.
        nodes.forEach(node -> Dependencies.remapDependencies(node, id -> replacementByExternalId.getOrDefault(id, id)));
        return removeNodes(computeGraph, replacementByExternalId::containsKey);
    }

    /**
     * Returns a reindexed copy of the graph without the nodes whose ID matches {@code predicate}.
     *
     * @param computeGraph the source graph
     * @param predicate selects, by node ID, the nodes to remove
     * @return the reindexed graph
     * @throws RuntimeException if a kept node or feature name still refers to a removed node
     */
    static ComputeGraph removeNodes(ComputeGraph computeGraph, Predicate<Integer> predicate) {
        AnyNodeArray nodes = computeGraph.getNodes();
        List<AnyNode> kept = IntStream.range(0, nodes.size())
                .boxed()
                .filter(predicate.negate())
                .map(nodes::get)
                .collect(Collectors.toList());
        return reindexNodes(kept, computeGraph.getFeatureNames());
    }

    /**
     * Builds a new graph from {@code collection}, renumbering node IDs in iteration order and
     * rewriting dependencies and feature names to the new IDs.
     *
     * <p>Each node's current ID (via {@code PegasusUtils.getNodeId}) is treated as its old ID.
     * Dependencies are remapped in place on the passed-in node objects.
     *
     * @param collection nodes to keep, in the order they should be numbered
     * @param integerMap feature name to old node ID
     * @return the reindexed graph
     * @throws RuntimeException if a dependency or feature name refers to an ID not in
     *     {@code collection}
     */
    static ComputeGraph reindexNodes(Collection<AnyNode> collection, IntegerMap integerMap) {
        Map<Integer, Integer> newIdByOldId = new HashMap<>();
        ComputeGraphBuilder builder = new ComputeGraphBuilder();
        for (AnyNode node : collection) {
            // Read the old ID before handing the node to the builder (addNode may rewrite it).
            int oldId = PegasusUtils.getNodeId(node);
            newIdByOldId.put(oldId, builder.addNode(node));
        }
        Function<Integer, Integer> toNewId = id -> {
            Integer newId = newIdByOldId.get(id);
            if (newId == null) {
                throw new RuntimeException("Node " + id + " not found in subgraph.");
            }
            return newId;
        };
        collection.forEach(node -> Dependencies.remapDependencies(node, toNewId));
        integerMap.forEach((featureName, nodeId) -> builder.addFeatureName(featureName, toNewId.apply(nodeId)));
        return builder.build();
    }

    /**
     * Maps each node ID to the set of node IDs that directly depend on it.
     * Nodes with no dependents are absent from the map.
     */
    private static Map<Integer, Set<Integer>> getReverseDependencyIndex(ComputeGraph computeGraph) {
        Map<Integer, Set<Integer>> dependentsById = new HashMap<>();
        AnyNodeArray nodes = computeGraph.getNodes();
        for (int nodeId = 0; nodeId < nodes.size(); nodeId++) {
            for (Integer dependencyId : new Dependencies().getDependencies(nodes.get(nodeId))) {
                dependentsById.computeIfAbsent(dependencyId, ignored -> new HashSet<>()).add(nodeId);
            }
        }
        return dependentsById;
    }

    /**
     * Maps each node ID to the feature names defined by that node.
     * Nodes that define no feature are absent from the map.
     */
    static Map<Integer, Set<String>> getReverseFeatureDependencyIndex(ComputeGraph computeGraph) {
        Map<Integer, Set<String>> featuresById = new HashMap<>();
        computeGraph.getFeatureNames().forEach((featureName, nodeId) ->
                featuresById.computeIfAbsent(nodeId, ignored -> new HashSet<>(1)).add(featureName));
        return featuresById;
    }

    /**
     * Ensures node {@code i} of the node list carries ID {@code i}.
     *
     * @throws RuntimeException on the first node whose ID differs from its position
     */
    static void ensureNodeIdsAreSequential(ComputeGraph computeGraph) {
        AnyNodeArray nodes = computeGraph.getNodes();
        for (int position = 0; position < nodes.size(); position++) {
            if (PegasusUtils.getNodeId(nodes.get(position)) != position) {
                throw new RuntimeException("Graph nodes must be ID'd sequentially from 0 to N-1 where N is the number of nodes.");
            }
        }
    }

    /**
     * Ensures every dependency ID lies in {@code [0, N)}, where N is the number of nodes.
     *
     * @throws RuntimeException listing the out-of-range IDs of the first offending node
     */
    static void ensureNodeReferencesExist(ComputeGraph computeGraph) {
        int nodeCount = computeGraph.getNodes().size();
        for (AnyNode node : computeGraph.getNodes()) {
            List<Integer> missing = new Dependencies().getDependencies(node).stream()
                    .filter(id -> id < 0 || id >= nodeCount)
                    .collect(Collectors.toList());
            if (!missing.isEmpty()) {
                throw new RuntimeException("Encountered missing dependencies " + missing + " for node " + node + ". Graph = " + computeGraph);
            }
        }
    }

    /**
     * Ensures no node has a concrete key set yet.
     *
     * @throws RuntimeException naming the first node that already has one
     */
    public static void ensureNoConcreteKeys(ComputeGraph computeGraph) {
        for (AnyNode node : computeGraph.getNodes()) {
            if (hasConcreteKey(node)) {
                throw new RuntimeException("A concrete key has already been set for the node " + node);
            }
        }
    }

    /** True if {@code node} is one of the keyed node kinds and its concrete key is set. */
    private static boolean hasConcreteKey(AnyNode node) {
        return (node.isExternal() && node.getExternal().hasConcreteKey())
                || (node.isAggregation() && node.getAggregation().hasConcreteKey())
                || (node.isDataSource() && node.getDataSource().hasConcreteKey())
                || (node.isLookup() && node.getLookup().hasConcreteKey())
                || (node.isTransformation() && node.getTransformation().hasConcreteKey());
    }

    /**
     * Ensures no External node names a feature that this graph itself defines. Such a reference
     * should point at the defining node directly.
     *
     * @throws RuntimeException naming the first offending External node
     */
    static void ensureNoExternalReferencesToSelf(ComputeGraph computeGraph) {
        computeGraph.getNodes().stream()
                .filter(AnyNode::isExternal)
                .forEach(node -> {
                    String name = node.getExternal().getName();
                    if (computeGraph.getFeatureNames().containsKey(name)) {
                        throw new RuntimeException("Graph contains External node " + node + " but also contains feature " + name + " in its feature name table: " + computeGraph.getFeatureNames() + ". Graph = " + computeGraph);
                    }
                });
    }

    /**
     * Ensures the dependency graph is acyclic.
     *
     * @throws RuntimeException identifying a node on a cycle
     */
    static void ensureNoDependencyCycles(ComputeGraph computeGraph) {
        dependencyOrder(computeGraph, "Dependency cycle involving node ");
    }

    /**
     * Lists every node ID such that each node appears after all of its dependencies, using an
     * iterative depth-first traversal that starts from node 0.
     *
     * <p>A node is re-pushed beneath its unfinished dependencies and marked IN_PROGRESS. Popping
     * an IN_PROGRESS node that still has unfinished dependencies therefore means one of them
     * leads back to it, i.e. a cycle.
     *
     * @param computeGraph graph whose dependency IDs are assumed to be in range
     * @param cycleMessagePrefix exception message prefix; the offending node ID is appended
     * @return node IDs in dependency order (dependencies before dependents)
     * @throws RuntimeException if a dependency cycle is found
     */
    private static List<Integer> dependencyOrder(ComputeGraph computeGraph, String cycleMessagePrefix) {
        AnyNodeArray nodes = computeGraph.getNodes();
        Deque<Integer> toProcess = IntStream.range(0, nodes.size())
                .boxed()
                .collect(Collectors.toCollection(ArrayDeque::new));
        List<VisitedState> states = new ArrayList<>(Collections.nCopies(nodes.size(), VisitedState.NOT_VISITED));
        List<Integer> order = new ArrayList<>(nodes.size());
        while (!toProcess.isEmpty()) {
            int nodeId = toProcess.pop();
            if (states.get(nodeId) == VisitedState.VISITED) {
                continue;
            }
            List<Integer> pending = new Dependencies().getDependencies(nodes.get(nodeId)).stream()
                    .filter(dep -> states.get(dep) != VisitedState.VISITED)
                    .collect(Collectors.toList());
            if (pending.isEmpty()) {
                states.set(nodeId, VisitedState.VISITED);
                order.add(nodeId);
            } else {
                if (states.get(nodeId) == VisitedState.IN_PROGRESS) {
                    throw new RuntimeException(cycleMessagePrefix + nodeId);
                }
                states.set(nodeId, VisitedState.IN_PROGRESS);
                toProcess.push(nodeId);
                pending.forEach(toProcess::push);
            }
        }
        return order;
    }
}
