package com.linkedin.dagli.dag;

import com.linkedin.dagli.annotation.equality.ValueEquality;
import com.linkedin.dagli.generator.Constant;
import com.linkedin.dagli.generator.Generator;
import com.linkedin.dagli.placeholder.Placeholder;
import com.linkedin.dagli.producer.ChildProducer;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.AbstractPreparedTransformerDynamic;
import com.linkedin.dagli.transformer.PreparableTransformer;
import com.linkedin.dagli.transformer.PreparedTransformer;
import com.linkedin.dagli.transformer.Transformer;
import com.linkedin.dagli.tuple.Tuple;
import com.linkedin.dagli.util.collection.LinkedStack;
import com.linkedin.dagli.view.TransformerView;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/linkedin/dagli/dag/DAGStructure.class */
/**
 * Index-based representation of a DAG of producers.
 *
 * <p>Every node is assigned a node index. Placeholders come first (in the order of the placeholder list), then
 * generators, then all remaining nodes in the order in which their inputs become available. The result is a
 * topological order in which each node's phase ({@code _phases}) never decreases. Parent and child relationships are
 * stored as arrays of node indices.
 *
 * <p>Two structures are {@link #equals(Object) equal} when they have the same number of placeholders and their
 * outputs are equal after every placeholder has been replaced by a placeholder that identifies its position.
 *
 * @param <R> the type of the DAG's (possibly tupled) result
 */
public class DAGStructure<R> implements Serializable, Graph<Producer<?>> {
    private static final long serialVersionUID = 1;

    /** Index returned by {@link #getNodeIndex(Producer)} for a producer that is not a node of this DAG. */
    private static final int MISSING_NODE_INDEX = -1;

    /** The DAG's placeholders (inputs); they occupy the first node indices, in list order. */
    final List<Placeholder<?>> _placeholders;

    /** The DAG's outputs, in output order. */
    final List<Producer<?>> _outputs;

    /** Maps every node of the DAG to the list of its children. */
    final HashMap<Producer<?>, ArrayList<ChildProducer<?>>> _childrenMap;

    /** The DAG's generators; they occupy the node indices immediately following the placeholders. */
    final List<Generator<?>> _generators;

    /** All nodes, indexed by node index (a topological order with non-decreasing phase). */
    final Producer<?>[] _nodes;

    /** Inverse of {@link #_nodes}: maps each node to its node index ({@link #MISSING_NODE_INDEX} if absent). */
    private final Object2IntOpenHashMap<Producer<?>> _nodeIndexMap;

    /** The phase of each node, indexed by node index; non-decreasing, which allows binary search. */
    final int[] _phases;

    /** For each node, the node indices of its inputs in input order (empty for roots). */
    final int[][] _parents;

    /** For each node, the node indices of its children in ascending order. */
    final int[][] _children;

    /** The node index of each output, in output order. */
    final int[] _outputIndices;

    /** True when every non-root node is a {@link PreparedTransformer}. */
    final boolean _isPrepared;

    /** Largest preferred minibatch size among the prepared transformers (1 if there are none). */
    final int _maxMinibatchSize;

    /** Largest number of parents of any node (0 if there are no nodes). */
    final int _maxParentCount;

    /** True when every output reports a constant result. */
    final boolean _isAlwaysConstant;

    /** True when the DAG is prepared or every preparable transformer reports an idempotent preparer. */
    final boolean _hasIdempotentPreparer;

    /** Copy of the outputs with placeholders replaced by positional placeholders; backs equals/hashCode. */
    final EqualityLeaf _equalityDAG;

    /**
     * Leaf transformer whose inputs are a DAG's outputs. It exists so that two DAG structures can be compared
     * through the leaf's value equality; it carries no computation and always produces {@code null}.
     */
    @ValueEquality
    public static class EqualityLeaf extends AbstractPreparedTransformerDynamic<Void, EqualityLeaf> {
        private static final long serialVersionUID = 1;

        /**
         * Creates a leaf over the given producers.
         *
         * @param list the producers that become this leaf's inputs
         */
        public EqualityLeaf(List<? extends Producer<?>> list) {
            super(list);
        }

        /**
         * Ignores its input values.
         *
         * @param list the input values (unused)
         * @return always {@code null}
         */
        // NOTE(review): the raw parameter type mirrors the decompiled signature so that it keeps overriding the
        // superclass method; tighten it to the superclass's declared type once that is confirmed.
        @Override
        @SuppressWarnings("rawtypes")
        public Void apply(List list) {
            return null;
        }
    }

    /**
     * Creates the structure of an already-deduplicated DAG.
     *
     * @param deduplicatedDAG supplies the placeholders, outputs and children map
     */
    public DAGStructure(DeduplicatedDAG deduplicatedDAG) {
        this(deduplicatedDAG._placeholders, deduplicatedDAG._outputs, deduplicatedDAG._childrenMap);
    }

    /**
     * Creates the structure of a DAG.
     *
     * <p>Every node in {@code childrenMap} is validated first.
     *
     * @param placeholders the DAG's placeholders, in input order
     * @param outputs the DAG's outputs, in output order
     * @param childrenMap maps every node of the DAG (including placeholders and generators) to its children
     * @throws IllegalArgumentException if a node has an unsupported type, or if some nodes can never be ordered
     *         (for example because they are unreachable from the roots or participate in a cycle)
     */
    public DAGStructure(List<Placeholder<?>> placeholders, List<Producer<?>> outputs,
        HashMap<Producer<?>, ArrayList<ChildProducer<?>>> childrenMap) {
        childrenMap.keySet().forEach(Producer::validate);

        _placeholders = placeholders;
        _outputs = outputs;
        _childrenMap = childrenMap;
        _generators = childrenMap.keySet()
            .stream()
            .filter(producer -> producer instanceof Generator)
            .map(producer -> (Generator<?>) producer)
            .collect(Collectors.toList());

        final int nodeCount = childrenMap.size();
        _nodes = new Producer<?>[nodeCount];
        _nodeIndexMap = new Object2IntOpenHashMap<>(nodeCount);
        _nodeIndexMap.defaultReturnValue(MISSING_NODE_INDEX);
        _phases = new int[nodeCount];
        // one (initially null) row per node; the rows are filled in below
        _parents = new int[nodeCount][];
        _children = new int[nodeCount][];

        assignNodeIndicesAndPhases();

        for (int nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) {
            _children[nodeIndex] = _childrenMap.get(_nodes[nodeIndex])
                .stream()
                .mapToInt(_nodeIndexMap::getInt)
                .sorted()
                .toArray();
        }
        _outputIndices = outputs.stream().mapToInt(_nodeIndexMap::getInt).toArray();

        _isPrepared = Arrays.stream(_nodes, rootCount(), nodeCount)
            .allMatch(node -> node instanceof PreparedTransformer);
        _isAlwaysConstant = outputs.stream().allMatch(Producer::hasConstantResult);
        _hasIdempotentPreparer = _isPrepared || Arrays.stream(_nodes)
            .allMatch(node -> !(node instanceof PreparableTransformer)
                || ((PreparableTransformer<?, ?>) node).internalAPI().hasIdempotentPreparer());
        _maxMinibatchSize = Arrays.stream(_nodes)
            .filter(node -> node instanceof PreparedTransformer)
            .mapToInt(node -> ((PreparedTransformer<?>) node).internalAPI().getPreferredMinibatchSize())
            .max()
            .orElse(1);
        _maxParentCount = Arrays.stream(_parents)
            .mapToInt(parentIndices -> parentIndices.length)
            .max()
            .orElse(0);
        _equalityDAG = createEqualityDAG();
    }

    /** Returns the number of root nodes (placeholders plus generators), which occupy the lowest node indices. */
    private int rootCount() {
        return _placeholders.size() + _generators.size();
    }

    /**
     * Adds every node to {@link #_nodes} in readiness order, starting from the roots.
     *
     * <p>Views and prepared transformers receive the phase that is current when they become ready. Preparable
     * transformers are deferred to the next phase.
     *
     * @throws IllegalArgumentException if nodes remain but none of them can become ready
     */
    private void assignNodeIndicesAndPhases() {
        ArrayDeque<PreparableTransformer<?, ?>> readyPreparables = new ArrayDeque<>();
        ArrayDeque<PreparedTransformer<?>> readyPrepared = new ArrayDeque<>();
        ArrayDeque<TransformerView<?, ?>> readyViews = new ArrayDeque<>();
        // Per child, the inputs that have not been added yet (as built by DAGUtil.producerToInputSetMap).
        // A child becomes ready when its set empties.
        IdentityHashMap<ChildProducer<?>, Set<Producer<?>>> pendingInputs =
            DAGUtil.producerToInputSetMap(_childrenMap.keySet());

        // Roots come first, all in phase 0: placeholders (in list order), then generators.
        for (Placeholder<?> placeholder : _placeholders) {
            addNode(placeholder, 0, pendingInputs, readyPreparables, readyPrepared, readyViews);
        }
        for (Generator<?> generator : _generators) {
            addNode(generator, 0, pendingInputs, readyPreparables, readyPrepared, readyViews);
        }

        int phase = 0;
        while (_nodeIndexMap.size() < _nodes.length) {
            if (readyViews.isEmpty() && readyPrepared.isEmpty() && readyPreparables.isEmpty()) {
                // Nothing can ever become ready again; without this check the loop would never terminate.
                throw new IllegalArgumentException((_nodes.length - _nodeIndexMap.size()) + " of " + _nodes.length
                    + " nodes can never be ordered (unreachable from the DAG's roots or part of a cycle)");
            }
            // Views, and then prepared transformers (including any that become ready meanwhile), join the
            // current phase.
            while (!readyViews.isEmpty()) {
                addNode(readyViews.remove(), phase, pendingInputs, readyPreparables, readyPrepared, readyViews);
            }
            while (!readyPrepared.isEmpty()) {
                addNode(readyPrepared.remove(), phase, pendingInputs, readyPreparables, readyPrepared, readyViews);
            }
            // The preparable transformers that are ready so far form the next phase. Preparables they make ready
            // go into a fresh queue and wait for a later iteration.
            phase++;
            ArrayDeque<PreparableTransformer<?, ?>> preparablesInThisPhase = readyPreparables;
            readyPreparables = new ArrayDeque<>();
            for (PreparableTransformer<?, ?> preparable : preparablesInThisPhase) {
                addNode(preparable, phase, pendingInputs, readyPreparables, readyPrepared, readyViews);
            }
        }
    }

    /**
     * Appends {@code node} at the next free node index, records its phase and parents, and enqueues any children
     * whose inputs have now all been added.
     *
     * @param node the node to add; all of its inputs must already have been added
     * @param phase the phase assigned to the node
     * @param pendingInputs per child, the inputs not yet added; updated in place
     * @param readyPreparables receives preparable transformers that became ready
     * @param readyPrepared receives prepared transformers that became ready
     * @param readyViews receives transformer views that became ready
     * @throws IllegalArgumentException if a child is not a prepared/preparable transformer or a transformer view
     */
    private void addNode(Producer<?> node, int phase,
        IdentityHashMap<ChildProducer<?>, Set<Producer<?>>> pendingInputs,
        ArrayDeque<PreparableTransformer<?, ?>> readyPreparables,
        ArrayDeque<PreparedTransformer<?>> readyPrepared,
        ArrayDeque<TransformerView<?, ?>> readyViews) {
        final int nodeIndex = _nodeIndexMap.size();
        _nodes[nodeIndex] = node;
        _nodeIndexMap.put(node, nodeIndex);
        _phases[nodeIndex] = phase;
        _parents[nodeIndex] = parentIndicesOf(node, phase);

        for (ChildProducer<?> child : _childrenMap.get(node)) {
            Set<Producer<?>> pending = pendingInputs.get(child);
            if (pending.isEmpty()) {
                continue; // nothing left to satisfy; never enqueue a child twice
            }
            pending.remove(node);
            if (!pending.isEmpty()) {
                continue; // still waiting on other inputs
            }
            if (child instanceof PreparedTransformer) {
                readyPrepared.add((PreparedTransformer<?>) child);
            } else if (child instanceof PreparableTransformer) {
                readyPreparables.add((PreparableTransformer<?, ?>) child);
            } else if (child instanceof TransformerView) {
                readyViews.add((TransformerView<?, ?>) child);
            } else {
                throw new IllegalArgumentException("Unknown dependency type");
            }
        }
    }

    /**
     * Returns the node indices of {@code node}'s parents; all of them must already have been added.
     *
     * @param node a transformer, a transformer view, or a root
     * @param phase the node's phase (roots must be in phase 0)
     * @return the parents' node indices in input order; empty for roots
     */
    private int[] parentIndicesOf(Producer<?> node, int phase) {
        if (node instanceof Transformer) {
            List<? extends Producer<?>> inputs = ((Transformer<?>) node).internalAPI().getInputList();
            return inputs.stream().mapToInt(_nodeIndexMap::getInt).toArray();
        }
        if (node instanceof TransformerView) {
            return new int[] {_nodeIndexMap.getInt(((TransformerView<?, ?>) node).internalAPI().getViewed())};
        }
        assert phase == 0 : "root nodes must be in phase 0";
        return new int[0];
    }

    /** Returns the number of inputs (placeholders) of the DAG. */
    public int getInputArity() {
        return _placeholders.size();
    }

    /** Returns the number of outputs of the DAG. */
    public int getOutputArity() {
        return _outputIndices.length;
    }

    /**
     * Returns the node index of a producer.
     *
     * @param producer the producer to look up
     * @return its node index, or -1 if it is not a node of this DAG
     */
    public int getNodeIndex(Producer<?> producer) {
        return _nodeIndexMap.getInt(producer);
    }

    /**
     * Checks whether a node is one of the DAG's outputs.
     *
     * @param nodeIndex a node index
     * @return true if some output has this node index
     */
    public boolean isOutput(int nodeIndex) {
        return Arrays.stream(_outputIndices).anyMatch(outputIndex -> outputIndex == nodeIndex);
    }

    /**
     * Checks whether a node is a root (a placeholder or generator).
     *
     * @param nodeIndex a node index
     * @return true if the index falls within the leading range reserved for placeholders and generators
     */
    public boolean isRoot(int nodeIndex) {
        return nodeIndex < rootCount();
    }

    /** Returns the phase of the last node, which is the highest phase because phases never decrease. */
    public int getLastPhase() {
        return _phases[_phases.length - 1];
    }

    /**
     * Checks whether a node belongs to the last phase.
     *
     * @param nodeIndex a node index (not a phase number)
     * @return true if the node's phase equals {@link #getLastPhase()}
     */
    public boolean isLastPhase(int nodeIndex) {
        return _phases[nodeIndex] == getLastPhase();
    }

    /**
     * Finds the first node of a phase.
     *
     * @param phase the phase to look up
     * @return the smallest node index whose phase is {@code phase}. If no node has that phase, returns the index at
     *         which such a node would be inserted (possibly {@code _nodes.length}). Phase 0 always yields 0.
     */
    public int firstNodeInPhase(int phase) {
        if (phase == 0) {
            return 0;
        }
        int index = Arrays.binarySearch(_phases, phase);
        if (index < 0) {
            return -(index + 1); // absent phase: binarySearch encodes the insertion point as -(point) - 1
        }
        // binarySearch may land on any node of the phase; walk back to the first one
        while (index > 0 && _phases[index - 1] == phase) {
            index--;
        }
        return index;
    }

    /**
     * Finds the first prepared transformer within a phase.
     *
     * @param phase the phase to look up
     * @return the node index of the phase's first {@link PreparedTransformer}. If the phase has none, returns the
     *         first index past the phase.
     */
    public int firstPreparedTransformerInPhase(int phase) {
        int nodeIndex = firstNodeInPhase(phase);
        while (nodeIndex < _phases.length && _phases[nodeIndex] == phase
            && !(_nodes[nodeIndex] instanceof PreparedTransformer)) {
            nodeIndex++;
        }
        return nodeIndex;
    }

    /** Returns the DAG's nodes (a live view of the children map's key set). */
    @Override
    public Set<? extends Producer<?>> nodes() {
        return _childrenMap.keySet();
    }

    /**
     * Returns the children of a node.
     *
     * @param producer a node of the DAG
     * @return the internal (mutable) children list, which callers must not modify, or {@code null} if
     *         {@code producer} is not a node
     */
    @Override
    public List<? extends ChildProducer<?>> children(Producer<?> producer) {
        return _childrenMap.get(producer);
    }

    /**
     * Returns the parents (inputs) of a node.
     *
     * @param producer a node of the DAG (checked only when assertions are enabled)
     * @return the node's input list if it is a {@link ChildProducer}, otherwise an empty list
     */
    @Override
    public List<? extends Producer<?>> parents(Producer<?> producer) {
        assert _childrenMap.containsKey(producer);
        return producer instanceof ChildProducer
            ? ((ChildProducer<?>) producer).internalAPI().getInputList()
            : Collections.emptyList();
    }

    /**
     * Builds the leaf used for equality: an {@link EqualityLeaf} over the outputs, with each placeholder replaced by
     * a {@code PositionPlaceholder} carrying its index. Structurally identical DAGs therefore compare equal
     * regardless of placeholder identity.
     */
    private EqualityLeaf createEqualityDAG() {
        EqualityLeaf equalityLeaf = new EqualityLeaf(_outputs);
        // NOTE(review): kept raw to match whatever map type DAGUtil.replaceInputs declares.
        IdentityHashMap replacements = new IdentityHashMap(_placeholders.size());
        for (int position = 0; position < _placeholders.size(); position++) {
            replacements.put(_placeholders.get(position), new PositionPlaceholder(position));
        }
        return (EqualityLeaf) DAGUtil.replaceInputs(equalityLeaf, replacements);
    }

    /**
     * Two structures are equal when they have the same number of placeholders and equal equality leaves.
     *
     * @param obj the object to compare with
     * @return true if {@code obj} is a structurally identical {@code DAGStructure}
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof DAGStructure)) {
            return false;
        }
        DAGStructure<?> other = (DAGStructure<?>) obj;
        return _placeholders.size() == other._placeholders.size() && _equalityDAG.equals(other._equalityDAG);
    }

    /** Returns a hash consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(_placeholders.size(), _equalityDAG);
    }

    /**
     * Creates one execution cache per non-root node; root slots are left {@code null}.
     *
     * @param executionCacheArgument passed unchanged to each prepared transformer's {@code createExecutionCache}
     * @return an array indexed by node index
     * @throws IllegalStateException if the DAG is not prepared (some non-root node is not a prepared transformer)
     */
    public Object[] createExecutionStateArray(long executionCacheArgument) {
        if (!_isPrepared) {
            throw new IllegalStateException("Execution state can only be created for a prepared DAG");
        }
        Object[] states = new Object[_nodes.length];
        for (int nodeIndex = rootCount(); nodeIndex < states.length; nodeIndex++) {
            states[nodeIndex] =
                ((PreparedTransformer<?>) _nodes[nodeIndex]).internalAPI().createExecutionCache(executionCacheArgument);
        }
        return states;
    }

    /** Formats integers as a comma-separated list without spaces (e.g. {@code "1,2,3"}). */
    private static String intSequenceString(int[] values) {
        return Arrays.stream(values).mapToObj(Integer::toString).collect(Collectors.joining(","));
    }

    /** Renders a human-readable table of every node's index, name, children and parents. */
    public String toProducerTable() {
        final String rowFormat = "%-5s%-35s%-25s%-25s\n";
        StringBuilder table = new StringBuilder();
        table.append(String.format(rowFormat, "ID", "Name", "Children", "Parents"));
        for (int column = 0; column < 85; column++) {
            table.append('-');
        }
        table.append('\n');
        for (int nodeIndex = 0; nodeIndex < _nodes.length; nodeIndex++) {
            table.append(String.format(rowFormat, nodeIndex, _nodes[nodeIndex].getName(),
                intSequenceString(_children[nodeIndex]), intSequenceString(_parents[nodeIndex])));
        }
        return table.toString();
    }

    /** Returns the producers of the subgraph that the outputs depend on, as computed by {@link Producer}. */
    public Stream<LinkedStack<Producer<?>>> producers() {
        return Producer.subgraphProducers(_outputs);
    }

    /**
     * Returns the DAG's constant result, as reported by {@code Constant.tryGetValue} for each output.
     *
     * @return the single output's value, or a tuple of all outputs' values when there are several
     */
    @SuppressWarnings("unchecked")
    public R getConstantOutput() {
        Iterator<Object> values = _outputs.stream().<Object>map(output -> Constant.tryGetValue(output)).iterator();
        return _outputs.size() == 1
            ? (R) values.next()
            : (R) Tuple.generator(_outputs.size()).fromIterator(values);
    }
}
