package com.linkedin.dagli.dag;

import com.linkedin.dagli.placeholder.Placeholder;
import com.linkedin.dagli.producer.ChildProducer;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.producer.RootProducer;
import com.linkedin.dagli.producer.internal.AncestorSpliterator;
import com.linkedin.dagli.reducer.ClassReducerTable;
import com.linkedin.dagli.reducer.Reducer;
import com.linkedin.dagli.transformer.PreparableTransformer;
import com.linkedin.dagli.transformer.PreparedTransformer;
import com.linkedin.dagli.util.cloneable.AbstractCloneable;
import com.linkedin.dagli.util.collection.Iterables;
import com.linkedin.dagli.util.collection.LinkedStack;
import com.linkedin.dagli.util.invariant.Arguments;
import com.linkedin.dagli.view.TransformerView;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import it.unimi.dsi.fastutil.objects.ReferenceArraySet;
import it.unimi.dsi.fastutil.objects.ReferenceOpenHashSet;
import it.unimi.dsi.fastutil.objects.ReferenceSet;
import it.unimi.dsi.fastutil.objects.ReferenceSets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Applies graph {@link Reducer}s to a DAG until the DAG stops changing.
 *
 * <p>Reduction operates on a mutable "working graph" ({@link State}) built from a {@link DeduplicatedDAG}. Every
 * output is reduced depth-first: a producer's parents are reduced before the producer itself. Reducers may modify
 * the working graph through a {@link ReducerContext}. If a pass modifies the working graph, a new
 * {@link DeduplicatedDAG} is instantiated from it and another pass is run. This repeats until a pass makes no
 * modifications.</p>
 */
public abstract class DAGReducer {

    /**
     * State that is carried across the successive passes performed by
     * {@link #reduce(DeduplicatedDAG, Reducer.Level)}.
     */
    public static class AccumulatedState {
        /** Producers (compared by reference) that have already been reduced. */
        final ReferenceOpenHashSet<Producer<?>> _visited = new ReferenceOpenHashSet<>();

        /** Class-keyed reducers gathered from every producer seen so far. */
        final ClassReducerTable _classReducerTable = new ClassReducerTable();

        AccumulatedState() {
            // Registered against Producer.class, so it is presumably consulted for every producer.
            _classReducerTable.add(ConstantResultReducer.INSTANCE, Producer.class);
        }

        /**
         * Forgets visited producers that are no longer part of the given working graph.
         *
         * @param childrenMap the working graph's producer-to-children map; only its keys are consulted
         */
        void clean(IdentityHashMap<Producer<?>, ReferenceSet<ChildProducer<?>>> childrenMap) {
            _visited.removeIf(visited -> !childrenMap.containsKey(visited));
        }
    }

    /**
     * {@link Reducer.Context} through which reducers inspect and modify a {@link State}.
     */
    public static class ReducerContext implements Reducer.Context {
        final State _state;

        ReducerContext(State state) {
            this._state = state;
        }

        @Override
        public Reducer.Level getMinimumImportance() {
            return _state._minimumImportance;
        }

        @Override
        public boolean isCompleteGraphReduction() {
            return _state._isCompleteGraphReduction;
        }

        @Override
        public boolean isPreparedDAG() {
            return _state._isPreparedDAG;
        }

        /**
         * Reports whether the producer is viewed.
         *
         * <p>A producer is "viewed" when it is a {@link PreparableTransformer} with at least one
         * {@link TransformerView} child in the working graph. A producer that is not in the working graph is
         * reported as not viewed; it does not throw.</p>
         */
        @Override
        public boolean isViewed(Producer<?> producer) {
            if (!(producer instanceof PreparableTransformer)) {
                return false;
            }
            ReferenceSet<ChildProducer<?>> children = _state._childrenMap.get(producer);
            return children != null && children.stream().anyMatch(child -> child instanceof TransformerView);
        }

        /** Replaces {@code existing} with {@code replacement} without any type/preparability checks. */
        private void replaceUnconditionally(Producer<?> existing, Producer<?> replacement) {
            _state.replace(existing, replacement);
        }

        /** Rejects introducing a preparable transformer into a DAG that contains none. */
        private void checkForIllegalPreparable(Producer<?> producer) {
            if (isPreparedDAG() && (producer instanceof PreparableTransformer)) {
                throw new IllegalArgumentException("Cannot add PreparableTransformer to prepared DAG");
            }
        }

        @Override
        public <T> boolean hasReducer(Class<T> cls, Reducer<? super T> reducer) {
            return _state._accumulated._classReducerTable.hasReducer(cls, reducer);
        }

        @Override
        public <T extends AbstractCloneable<T> & Producer<?>> void replaceWithSameClass(T t, T t2) {
            Arguments.check(t.getClass().equals(t2.getClass()),
                "Existing and replacement producers must be of the same type");
            replaceUnconditionally((Producer<?>) t, (Producer<?>) t2);
        }

        @Override
        public <R> void replace(RootProducer<R> rootProducer, Producer<? extends R> producer) {
            checkForIllegalPreparable(producer);
            replaceUnconditionally(rootProducer, producer);
        }

        @Override
        public <R> void replace(TransformerView<R, ?> transformerView, Producer<? extends R> producer) {
            replaceUnconditionally(transformerView, producer);
        }

        @Override
        public <R> void replace(PreparedTransformer<R> preparedTransformer, Producer<? extends R> producer) {
            checkForIllegalPreparable(producer);
            replaceUnconditionally(preparedTransformer, producer);
        }

        @Override
        public <R, N extends PreparedTransformer<? extends R>> void replace(
            PreparableTransformer<R, N> preparableTransformer,
            PreparableTransformer<? extends R, ? extends N> preparableTransformer2) {
            replaceUnconditionally(preparableTransformer, preparableTransformer2);
        }

        @Override
        public <R> void replaceUnviewed(Producer<R> producer, Producer<? extends R> producer2) {
            if (!tryReplaceUnviewed(producer, () -> producer2)) {
                throw new IllegalArgumentException(
                    "Attempt to use replaceUnviewed(...) on a preparable transformer with one or more TransformerView "
                        + "children");
            }
        }

        /**
         * Replaces {@code producer} with the supplied producer unless it is viewed.
         *
         * <p>The supplier is invoked only when the replacement actually happens.</p>
         */
        @Override
        public <R> boolean tryReplaceUnviewed(Producer<R> producer, Supplier<Producer<? extends R>> supplier) {
            if (isViewed(producer)) {
                return false;
            }
            replaceUnconditionally(producer, supplier.get());
            return true;
        }

        @Override
        public <T extends Producer<?>> T withCurrentParents(T t) {
            return _state.withCurrentParents(t);
        }

        @Override
        public List<? extends Producer<?>> getParents(Producer<?> producer) {
            return _state.getParents(producer);
        }

        /** Returns the producer's working-graph parents that are instances of {@code cls}, as a reference set. */
        @Override
        public <T> ReferenceSet<T> getParentsByClass(Producer<?> producer, Class<T> cls) {
            List<? extends Producer<?>> parents = getParents(producer);
            ReferenceArraySet<T> matches = new ReferenceArraySet<>(parents.size());
            for (Producer<?> parent : parents) {
                if (cls.isInstance(parent)) {
                    matches.add(cls.cast(parent));
                }
            }
            return matches;
        }

        /**
         * Collects the ancestors of {@code producer} that are instances of {@code cls}.
         *
         * <p>The candidates are the top of each stack produced by {@link #ancestors(Producer, int)}. The value of
         * {@code i} is passed through to that method unchanged.</p>
         */
        @Override
        public <T> ReferenceSet<T> getAncestorsByClass(Producer<?> producer, Class<T> cls, int i) {
            ReferenceOpenHashSet<T> matches = new ReferenceOpenHashSet<>();
            ancestors(producer, i)
                .map(path -> path.peek())
                .filter(cls::isInstance)
                .map(cls::cast)
                .forEach(matches::add);
            return matches;
        }

        /**
         * Streams the ancestors of {@code producer} in the working graph, via an {@link AncestorSpliterator}.
         *
         * <p>Parents are looked up in the working graph, not in the producer's original inputs. A producer that is
         * not a {@link ChildProducer} has no ancestors, so an empty stream is returned.</p>
         */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public Stream<LinkedStack<Producer<?>>> ancestors(Producer<?> producer, int i) {
            if (!(producer instanceof ChildProducer)) {
                return Stream.empty();
            }
            State state = _state;
            Function<? super ChildProducer<?>, ? extends List<? extends Producer<?>>> parentsOf = state::getParents;
            return StreamSupport.stream(new AncestorSpliterator((ChildProducer<?>) producer, i, parentsOf), false);
        }
    }

    /**
     * Mutable working graph for a single reduction pass.
     *
     * <p>Edges are tracked in both directions: {@link #_childrenMap} maps each producer to its children, and
     * {@link #_parentMap} maps each child producer to its ordered parent list. Both are keyed by reference
     * identity.</p>
     */
    public static class State {
        final Reducer.Level _minimumImportance;
        final boolean _isCompleteGraphReduction;
        /** True iff no producer in the initial graph was a {@link PreparableTransformer}. */
        final boolean _isPreparedDAG;
        final IdentityHashMap<Producer<?>, ReferenceSet<ChildProducer<?>>> _childrenMap;
        final IdentityHashMap<ChildProducer<?>, List<Producer<?>>> _parentMap;
        /** Set whenever {@link #add} or {@link #replace} changes (or may change) the working graph. */
        boolean _modified = false;
        final ArrayList<Placeholder<?>> _placeholders;
        final ArrayList<Producer<?>> _outputs;
        final AccumulatedState _accumulated;

        /**
         * Builds a working graph.
         *
         * <p>The placeholder list, output list and per-producer children lists are copied. Every class reducer
         * table found on the producers is merged into {@code accumulatedState}.</p>
         *
         * @param list the DAG's placeholders
         * @param list2 the DAG's outputs
         * @param hashMap every producer in the DAG mapped to its children
         * @param level the minimum importance of reducers that will be applied
         * @param z whether this is a complete-graph reduction
         * @param accumulatedState state shared across passes
         */
        State(List<Placeholder<?>> list, List<Producer<?>> list2, HashMap<Producer<?>, ArrayList<ChildProducer<?>>> hashMap,
            Reducer.Level level, boolean z, AccumulatedState accumulatedState) {
            this._accumulated = accumulatedState;
            this._minimumImportance = level;
            this._isCompleteGraphReduction = z;
            this._placeholders = new ArrayList<>(list);
            this._outputs = new ArrayList<>(list2);

            this._childrenMap = new IdentityHashMap<>(hashMap.size());
            this._parentMap = new IdentityHashMap<>(hashMap.size());
            hashMap.forEach((producer, children) -> {
                this._childrenMap.put(producer, new ReferenceOpenHashSet<ChildProducer<?>>(children));
                if (producer instanceof ChildProducer) {
                    ChildProducer<?> child = (ChildProducer<?>) producer;
                    this._parentMap.put(child, new ArrayList<Producer<?>>(child.internalAPI().getInputList()));
                }
            });

            hashMap.keySet().stream()
                .map(producer -> producer.internalAPI().getClassReducerTable())
                .filter(Objects::nonNull)
                .forEach(this._accumulated._classReducerTable::addAll);

            this._isPreparedDAG = hashMap.keySet().stream().noneMatch(producer -> producer instanceof PreparableTransformer);
        }

        boolean inWorkingGraph(Producer<?> producer) {
            return _childrenMap.containsKey(producer);
        }

        /**
         * Adds a producer to the working graph and records {@code referenceSet} as (additional) children of it.
         *
         * <p>If the producer is new to the graph, it is validated. If it is also a {@link ChildProducer}, its
         * inputs are added recursively, with the producer recorded as their child.</p>
         */
        void add(Producer<?> producer, ReferenceSet<ChildProducer<?>> referenceSet) {
            _modified = true;
            ReferenceSet<ChildProducer<?>> existingChildren = _childrenMap.get(producer);
            if (existingChildren != null) {
                existingChildren.addAll(referenceSet);
                return;
            }

            producer.validate();
            _childrenMap.put(producer, new ReferenceOpenHashSet<ChildProducer<?>>(referenceSet));
            if (producer instanceof ChildProducer) {
                ChildProducer<?> child = (ChildProducer<?>) producer;
                List<? extends Producer<?>> inputs = child.internalAPI().getInputList();
                _parentMap.put(child, new ArrayList<Producer<?>>(inputs));
                for (Producer<?> input : inputs) {
                    add(input, ReferenceSets.singleton(child));
                }
            }
        }

        /**
         * Replaces {@code producer} with {@code producer2} everywhere in the working graph.
         *
         * <p>The replacement covers the output list, the parent lists of {@code producer}'s children, and the
         * children sets of {@code producer}'s former parents.</p>
         *
         * @throws IllegalArgumentException if {@code producer} is not in the working graph, or if {@code producer2}
         *         is a placeholder that is not already part of the graph
         */
        public void replace(Producer<?> producer, Producer<?> producer2) {
            Objects.requireNonNull(producer);
            Objects.requireNonNull(producer2);
            if (producer == producer2) {
                return;
            }
            // Fail before mutating anything; previously this surfaced as an NPE after partial modification.
            Arguments.check(_childrenMap.containsKey(producer),
                "Attempting to replace a producer that is not part of the working graph");

            _modified = true;
            if (producer2 instanceof Placeholder) {
                Arguments.check(_placeholders.contains(producer2),
                    "Attempting to introduce a new placeholder that is not already part of the graph");
            }
            _outputs.replaceAll(output -> output == producer ? producer2 : output);
            add(producer2, ReferenceSets.emptySet());

            // Re-point each child of the replaced producer at the replacement.
            ReferenceSet<ChildProducer<?>> orphanedChildren = _childrenMap.remove(producer);
            for (ChildProducer<?> child : orphanedChildren) {
                _parentMap.get(child).replaceAll(parent -> parent == producer ? producer2 : parent);
            }

            if (producer instanceof ChildProducer) {
                List<Producer<?>> formerParents = _parentMap.remove(producer);
                for (Producer<?> parent : formerParents) {
                    _childrenMap.get(parent).remove(producer);
                }
                if (producer instanceof TransformerView) {
                    // A parent that has lost its last TransformerView child is no longer "viewed", so it is marked
                    // unvisited to be reduced again.
                    for (Producer<?> parent : formerParents) {
                        boolean stillViewed =
                            _childrenMap.get(parent).stream().anyMatch(child -> child instanceof TransformerView);
                        if (!stillViewed) {
                            _accumulated._visited.remove(parent);
                        }
                    }
                }
            }
            _childrenMap.get(producer2).addAll(orphanedChildren);
        }

        /** Returns the producer's parents in the working graph (an empty list for non-child producers). */
        List<Producer<?>> getParents(Producer<?> producer) {
            return producer instanceof ChildProducer ? _parentMap.get(producer) : Collections.emptyList();
        }

        /**
         * Returns {@code t} if its inputs already match its working-graph parents (compared by reference).
         * Otherwise returns a copy of {@code t} whose inputs are those parents.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        <T extends Producer<?>> T withCurrentParents(T t) {
            if (!(t instanceof ChildProducer)) {
                return t;
            }
            ChildProducer child = (ChildProducer) t;
            List<Producer<?>> parents = getParents(child);
            if (Iterables.elementsAreReferenceEqual(parents, child.internalAPI().getInputList())) {
                return t;
            }
            return (T) child.internalAPI().withInputsUnsafe(parents);
        }
    }

    private DAGReducer() {
    }

    /**
     * Reduces {@code producer} after first reducing all of its not-yet-visited parents.
     *
     * <p>Parents are revisited until a full scan reduces none of them, because reducing one parent may change the
     * parent list. Then every graph reducer of the producer, followed by every class-table reducer registered for
     * its class, is applied. Only reducers whose level is at least the state's minimum importance are applied.
     * Application stops as soon as the producer leaves the working graph.</p>
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    static void reduce(Producer<?> producer, State state) {
        assert state.inWorkingGraph(producer);
        if (!state._accumulated._visited.add(producer)) {
            return; // already visited
        }

        int reducedParentCount;
        do {
            reducedParentCount = 0;
            for (Producer<?> parent : new ArrayList<>(state.getParents(producer))) {
                if (!state._accumulated._visited.contains(parent) && state.inWorkingGraph(parent)) {
                    reduce(parent, state);
                    reducedParentCount++;
                }
            }
        } while (reducedParentCount > 0);

        ReducerContext context = new ReducerContext(state);
        // Single pass over both reducer sources: each graph reducer runs at most once, and the importance filter
        // applies to class-table reducers too.
        Iterator<?> reducers = Iterables.lazyConcatenate(new Supplier[]{
            () -> producer.internalAPI().getGraphReducers(),
            () -> state._accumulated._classReducerTable.getReducers(producer.getClass())}).iterator();
        while (reducers.hasNext()) {
            Reducer reducer = (Reducer) reducers.next();
            if (reducer.getLevel().compareTo(state._minimumImportance) >= 0) {
                reducer.reduce(producer, context);
                if (!state.inWorkingGraph(producer)) {
                    return; // replaced or removed; remaining reducers no longer apply
                }
            }
        }
    }

    /**
     * Materializes a producer from the working graph.
     *
     * <p>Returns {@code producer} itself if its inputs already equal its instantiated working-graph parents
     * (compared by reference). Otherwise returns a copy with those parents as inputs. Root producers are returned
     * unchanged.</p>
     *
     * @param identityHashMap memoization cache from working-graph producer to its instantiated counterpart
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Producer<?> instantiateFromWorkingGraph(State state, Producer<?> producer,
        IdentityHashMap<Producer<?>, Producer<?>> identityHashMap) {
        if (producer instanceof RootProducer) {
            return producer;
        }
        Producer<?> cached = identityHashMap.get(producer);
        if (cached != null) {
            return cached;
        }

        ChildProducer child = (ChildProducer) producer;
        List<Producer<?>> instantiatedParents = new ArrayList<>();
        for (Producer<?> parent : state.getParents(child)) {
            instantiatedParents.add(instantiateFromWorkingGraph(state, parent, identityHashMap));
        }
        Producer<?> instantiated =
            Iterables.elementsAreReferenceEqual(instantiatedParents, child.internalAPI().getInputList())
                ? child
                : (Producer<?>) child.internalAPI().withInputsUnsafe(instantiatedParents);
        identityHashMap.put(producer, instantiated);
        return instantiated;
    }

    /** Convenience overload: deduplicates the DAG defined by the placeholders and outputs, then reduces it. */
    public static DeduplicatedDAG reduce(List<Placeholder<?>> list, List<Producer<?>> list2, Reducer.Level level) {
        return reduce(new DeduplicatedDAG(list, list2), level);
    }

    /**
     * Repeatedly reduces the DAG until a pass leaves the working graph unmodified.
     *
     * @param deduplicatedDAG the DAG to reduce
     * @param level the minimum importance of reducers to apply; {@code null} disables reduction
     * @return the reduced DAG, or {@code deduplicatedDAG} itself if nothing changed
     */
    public static DeduplicatedDAG reduce(DeduplicatedDAG deduplicatedDAG, Reducer.Level level) {
        if (level == null) {
            return deduplicatedDAG;
        }

        AccumulatedState accumulated = new AccumulatedState();
        DeduplicatedDAG current = deduplicatedDAG;
        while (true) {
            State state = new State(current._placeholders, current._outputs, current._childrenMap, level, true,
                accumulated);

            // Reducing one output may replace others, so rescan until no output is reduced.
            int reducedOutputCount;
            do {
                reducedOutputCount = 0;
                for (Producer<?> output : new ArrayList<>(state._outputs)) {
                    if (!accumulated._visited.contains(output) && state.inWorkingGraph(output)) {
                        reduce(output, state);
                        reducedOutputCount++;
                    }
                }
            } while (reducedOutputCount > 0);

            if (!state._modified) {
                return current;
            }

            accumulated.clean(state._childrenMap);
            IdentityHashMap<Producer<?>, Producer<?>> instantiated = new IdentityHashMap<>();
            ArrayList<Producer<?>> outputs = state._outputs;
            outputs.replaceAll(output -> instantiateFromWorkingGraph(state, output, instantiated));
            current = new DeduplicatedDAG(state._placeholders, outputs);
        }
    }
}
