package com.linkedin.dagli.dag;

import com.linkedin.dagli.generator.Generator;
import com.linkedin.dagli.placeholder.Placeholder;
import com.linkedin.dagli.producer.ChildProducer;
import com.linkedin.dagli.producer.MissingInput;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.producer.RootProducer;
import com.linkedin.dagli.transformer.Transformer;
import com.linkedin.dagli.util.invariant.Arguments;
import com.linkedin.dagli.view.TransformerView;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import it.unimi.dsi.fastutil.objects.ReferenceOpenHashSet;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/linkedin/dagli/dag/DeduplicatedDAG.class */
public class DeduplicatedDAG {
    final List<Placeholder<?>> _placeholders;
    final List<Producer<?>> _outputs;
    final HashMap<Producer<?>, ArrayList<ChildProducer<?>>> _childrenMap;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates an instance that reuses the structure already held by {@code dAGStructure}.
     *
     * <p>The placeholder list, output list and children map are taken by reference; no copies are made and no
     * validation or deduplication is performed by this constructor.
     *
     * @param dAGStructure the structure whose placeholders, outputs and children map are adopted
     */
    public DeduplicatedDAG(DAGStructure<?> dAGStructure) {
        // Plain reference hand-over: this object and dAGStructure share the same collections afterwards.
        _placeholders = dAGStructure._placeholders;
        _outputs = dAGStructure._outputs;
        _childrenMap = dAGStructure._childrenMap;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds a DAG from the given placeholders and outputs, replacing duplicated nodes by a single canonical
     * instance.
     *
     * <p>Steps performed, in order:
     * <ol>
     *   <li>argument checks (non-null, non-empty, no duplicate placeholders);</li>
     *   <li>discovery of every producer reachable from the outputs, keyed by identity;</li>
     *   <li>validation of each discovered producer;</li>
     *   <li>computation of the original-to-canonical mapping and translation of the outputs through it;</li>
     *   <li>re-discovery of the graph from the canonical outputs and conversion to an equality-keyed map.</li>
     * </ol>
     *
     * @param list the placeholders (inputs) of the DAG; must be non-null, non-empty and free of duplicates
     * @param list2 the outputs of the DAG; must be non-null and non-empty
     * @throws NullPointerException if either list is null
     * @throws IllegalStateException if a node fails validation, or if two distinct nodes that compare equal
     *         remain after deduplication
     */
    public DeduplicatedDAG(List<? extends Placeholder<?>> list, List<? extends Producer<?>> list2) {
        Objects.requireNonNull(list, "Inputs may not be null");
        Objects.requireNonNull(list2, "Outputs may not be null");
        Arguments.check(list.size() > 0, "Must have at least one input");
        Arguments.check(list2.size() > 0, "Must have at least one output");
        Arguments.check(list.stream().distinct().count() == list.size(), "The list of placeholders contains duplicates");

        // First pass: identity-keyed view of the graph exactly as the caller supplied it.
        IdentityHashMap<Producer<?>, ArrayList<ChildProducer<?>>> originalChildren = parentToChildrenMap(list, list2);
        validate(originalChildren.keySet());

        // Translate every requested output to its canonical (deduplicated) instance.
        IdentityHashMap<Producer<?>, Producer<?>> canonical = deduplicationMap(originalChildren);
        List<Producer<?>> canonicalOutputs =
            list2.stream().<Producer<?>>map(canonical::get).collect(Collectors.toList());

        // Second pass: rebuild the graph from the canonical outputs only.
        IdentityHashMap<Producer<?>, ArrayList<ChildProducer<?>>> canonicalChildren =
            parentToChildrenMap(list, canonicalOutputs);

        // Copying into an equals()-keyed map collapses keys that are equal but not identical; a size mismatch
        // therefore means deduplication left equal-but-distinct nodes behind.
        HashMap<Producer<?>, ArrayList<ChildProducer<?>>> childrenByEquality = new HashMap<>(canonicalChildren);
        if (canonicalChildren.size() != childrenByEquality.size()) {
            throw new IllegalStateException("Failed to correctly deduplicate nodes while building DAG");
        }

        this._placeholders = new ArrayList<>(list);
        this._outputs = canonicalOutputs;
        this._childrenMap = childrenByEquality;
    }

    /**
     * Walks the graph backwards from the outputs and records, for every producer reached, the list of its
     * direct children. Keys are compared by identity.
     *
     * <p>Every placeholder in {@code list} gets an entry even if nothing depends on it. Each distinct output
     * (by identity) that is not a placeholder gets an entry as well. Child producers are then processed in
     * first-in-first-out order; for each of their inputs the child is appended to that input's children list,
     * and transformer/view inputs seen for the first time are queued for processing.
     *
     * @param list the placeholders that the graph is allowed to depend on
     * @param list2 the outputs from which the traversal starts
     * @return an identity-keyed map from each reachable producer to its direct children
     * @throws IllegalArgumentException if an output or ancestor is of an unsupported producer type, or if a
     *         child has a {@link MissingInput} among its inputs
     */
    private static IdentityHashMap<Producer<?>, ArrayList<ChildProducer<?>>> parentToChildrenMap(List<? extends Placeholder<?>> list, List<? extends Producer<?>> list2) {
        IdentityHashMap<Producer<?>, ArrayList<ChildProducer<?>>> childrenByParent =
            new IdentityHashMap<>(list.size() + list2.size());
        for (Placeholder<?> placeholder : list) {
            childrenByParent.put(placeholder, new ArrayList<>());
        }

        // Work queue of children whose inputs still have to be examined (add at tail, pop from head).
        LinkedList<ChildProducer<?>> pending = new LinkedList<>();

        // The reference set drops outputs that are listed more than once (same instance).
        for (Producer<?> output : new ReferenceOpenHashSet<Producer<?>>(list2)) {
            if (output instanceof Placeholder) {
                Arguments.check(childrenByParent.containsKey(output), "Outputs list includes Placeholder not present in the placeholders list");
            } else {
                if (output instanceof ChildProducer) {
                    pending.add((ChildProducer<?>) output);
                } else if (!(output instanceof Generator)) {
                    throw new IllegalArgumentException("Outputs list contains an object that is an unsupported type of Producer: " + output);
                }
                childrenByParent.put(output, new ArrayList<>());
            }
        }

        while (!pending.isEmpty()) {
            ChildProducer<?> child = pending.pop();
            int inputIndex = -1; // advanced at the top of each iteration so it is correct in error messages
            for (Producer<?> input : child.internalAPI().getInputList()) {
                inputIndex++;
                if (input instanceof Placeholder) {
                    Arguments.check(childrenByParent.containsKey(input), "The outputs list requires a Placeholder that was not provided: " + input.toString() + "; proximate dependent child is " + child.toString());
                    childrenByParent.get(input).add(child);
                } else if (input instanceof Generator) {
                    // Generators have no inputs of their own, so they are never queued.
                    childrenByParent.computeIfAbsent(input, key -> new ArrayList<>()).add(child);
                } else {
                    if (!(input instanceof Transformer) && !(input instanceof TransformerView)) {
                        if (input instanceof MissingInput) {
                            throw new IllegalArgumentException("The transformer " + child + " has an MissingInput at input number " + inputIndex + ".  This probably means you forgot to set an input on this transformer, e.g. using withInput(...).");
                        }
                        throw new IllegalArgumentException("Outputs list has ancestor that is not a supported Producer type");
                    }
                    if (childrenByParent.containsKey(input)) {
                        childrenByParent.get(input).add(child);
                    } else {
                        // First time this ancestor is seen: queue it so its own inputs are visited too.
                        pending.add((ChildProducer<?>) input);
                        ArrayList<ChildProducer<?>> children = new ArrayList<>(1);
                        children.add(child);
                        childrenByParent.put(input, children);
                    }
                }
            }
        }
        return childrenByParent;
    }

    /**
     * Computes, for every producer in the graph, the canonical instance that should stand in for it.
     *
     * <p>Processing starts from the {@link RootProducer} keys and moves on to a child only once all of its
     * inputs have been processed. When several producers are ready at the same time, those whose class has
     * the longest superclass chain are taken first. Each child producer is first passed to
     * {@code DAGUtil.remappedInputs} together with a lookup into the mapping built so far (presumably yielding
     * a copy whose inputs are the canonical ones; see {@code DAGUtil}). The result is then looked up in an
     * equality-keyed map, so that the first of several equal producers becomes the canonical one.
     *
     * @param identityHashMap identity-keyed map from each producer in the graph to its direct children
     * @return identity-keyed map from each original producer to its canonical replacement
     * @throws IllegalStateException if a child listed for a producer does not have that producer among its
     *         inputs
     */
    private static IdentityHashMap<Producer<?>, Producer<?>> deduplicationMap(IdentityHashMap<Producer<?>, ArrayList<ChildProducer<?>>> identityHashMap) {
        IdentityHashMap<Producer<?>, Producer<?>> canonical = new IdentityHashMap<>(identityHashMap.size());

        // For each child, the set of inputs that have not been processed yet.
        IdentityHashMap<ChildProducer<?>, Set<Producer<?>>> unprocessedInputs =
            DAGUtil.producerToInputSetMap(identityHashMap.keySet());

        // Equality-keyed: maps a (remapped) producer to the first equal instance encountered.
        HashMap<Producer<?>, Producer<?>> firstEqualInstance = new HashMap<>();

        // reversed() => larger class depth is polled first.
        Comparator<Producer<?>> deepestClassFirst =
            Comparator.comparingInt((Producer<?> p) -> classDepth(p.getClass())).reversed();
        PriorityQueue<Producer<?>> ready = identityHashMap.keySet().stream()
            .filter(p -> p instanceof RootProducer)
            .collect(Collectors.toCollection(() -> new PriorityQueue<Producer<?>>(deepestClassFirst)));

        while (!ready.isEmpty()) {
            final Producer<?> current = ready.poll();
            final Producer<?> remapped;
            if (current instanceof ChildProducer) {
                // All inputs of `current` were processed earlier, so `canonical` already has entries for them.
                remapped = DAGUtil.remappedInputs((ChildProducer<?>) current, canonical::get);
            } else {
                remapped = current;
            }
            canonical.put(current, firstEqualInstance.computeIfAbsent(remapped, Function.identity()));

            // A child may list the same parent more than once; the reference set visits it only once.
            new ReferenceOpenHashSet<ChildProducer<?>>(identityHashMap.get(current)).forEach(child -> {
                Set<Producer<?>> remaining = unprocessedInputs.get(child);
                if (!remaining.remove(current)) {
                    throw new IllegalStateException("A Producer's child does not have the expected dependency on that Producer");
                }
                if (remaining.isEmpty()) {
                    ready.add(child);
                }
            });
        }
        return canonical;
    }

    /**
     * Counts how many classes lie on the superclass chain of {@code cls}, including {@code cls} itself.
     *
     * @param cls the class to measure; may be null
     * @return the number of {@code getSuperclass()} steps until null is reached, e.g. 1 for
     *         {@code Object.class} and 0 for a null argument
     */
    private static int classDepth(Class<?> cls) {
        int depth = 0;
        for (Class<?> current = cls; current != null; current = current.getSuperclass()) {
            depth++;
        }
        return depth;
    }

    /**
     * Calls {@code validate()} on each producer in turn, stopping at the first failure.
     *
     * @param iterable the producers to validate
     * @throws IllegalStateException wrapping the first {@link RuntimeException} thrown by a producer's
     *         {@code validate()}; the message names the node's class, name and string form, and the original
     *         exception is kept as the cause
     */
    private static void validate(Iterable<Producer<?>> iterable) {
        Iterator<Producer<?>> nodes = iterable.iterator();
        while (nodes.hasNext()) {
            Producer<?> node = nodes.next();
            try {
                node.validate();
            } catch (RuntimeException failure) {
                // Identify the offending node first, then append the underlying reason.
                String nodeDescription = node.getClass() + ", " + node.getName() + " (" + node + "): ";
                String message = "While building a DAG, encountered an exception validating node of type "
                    + nodeDescription + failure.getMessage();
                throw new IllegalStateException(message, failure);
            }
        }
    }
}
