package com.linkedin.dagli.reducer;

import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.reducer.Reducer;
import com.linkedin.dagli.transformer.TransformerVariadic;
import com.linkedin.dagli.transformer.TransformerWithInputBound;
import it.unimi.dsi.fastutil.objects.ObjectArraySet;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:com/linkedin/dagli/reducer/AssociativeClassReducer.class */
/**
 * Reducer that "flattens" the inputs of a {@link TransformerVariadic}: every parent whose exact runtime class is one of
 * the configured classes is replaced, in place and in order, by that parent's own parents.
 *
 * <p>For example, with a hypothetical configured class {@code Sum}, a target {@code Sum(a, Sum(b, c))} would have its
 * inputs rewritten to {@code (a, b, c)}. NOTE(review): this is presumably only correct when the combined operation is
 * associative (as the class name suggests); that property is not checked here and must be guaranteed by the caller.</p>
 *
 * <p>Class matching uses {@link Object#getClass()} equality, so subclasses of a configured class are not flattened.
 * Only one level is flattened per {@link #reduce} call; the grandparents spliced in are not themselves re-examined.</p>
 *
 * @param <V> the input value type of the variadic transformers this reducer rewrites
 */
public class AssociativeClassReducer<V> implements Reducer<TransformerVariadic<V, ?>> {
    /** Exact classes of parents that should be replaced by their own parents; never contains duplicates. */
    private final Set<Class<? extends TransformerWithInputBound<? extends V, ?>>> _parentClasses;

    /**
     * Returns the level of this reducer.
     *
     * @return always {@link Reducer.Level#ESSENTIAL}
     */
    @Override
    public Reducer.Level getLevel() {
        return Reducer.Level.ESSENTIAL;
    }

    /**
     * Creates a reducer that flattens parents whose exact class is one of {@code parentClasses}.
     *
     * <p>The classes are copied into a private set, so later changes to an explicitly-passed array do not affect this
     * reducer, and duplicate entries are collapsed (keeping {@link #equals}/{@link #hashCode} consistent).</p>
     *
     * @param parentClasses the exact parent classes to flatten
     */
    @SafeVarargs
    public AssociativeClassReducer(Class<? extends TransformerWithInputBound<? extends V, ?>>... parentClasses) {
        // Do NOT hand the varargs array to ObjectArraySet(Object[]): that constructor aliases the array and requires the
        // caller to guarantee distinct elements. Adding one by one copies and deduplicates.
        this._parentClasses = new ObjectArraySet<>(parentClasses.length);
        for (Class<? extends TransformerWithInputBound<? extends V, ?>> parentClass : parentClasses) {
            this._parentClasses.add(parentClass);
        }
    }

    /**
     * Replaces each flattenable parent of {@code target} with that parent's own parents, preserving input order, and asks
     * the context to swap in the rewritten transformer. Does nothing if no parent is flattenable.
     *
     * @param target the variadic transformer whose inputs may be flattened
     * @param context the reduction context used to look up parents and request the replacement
     */
    @Override
    public void reduce(TransformerVariadic<V, ?> target, Reducer.Context context) {
        List<Producer<?>> parents = context.getParents((TransformerWithInputBound) target);

        // Cheap pre-check so the common case (nothing to flatten) allocates no replacement list.
        if (parents.stream().noneMatch(this::isFlattenable)) {
            return;
        }

        // NOTE(review): raw type kept from the original; the parents are wildcard-typed while withInputs presumably
        // expects V-bounded producers, so this relies on an unchecked conversion.
        ArrayList flattenedInputs = new ArrayList(parents.size());
        for (Producer<?> parent : parents) {
            if (isFlattenable(parent)) {
                flattenedInputs.addAll(context.getParents(parent));
            } else {
                flattenedInputs.add(parent);
            }
        }

        // The rewritten transformer is only built if the context actually accepts the replacement.
        context.tryReplaceUnviewed(target, () -> target.withInputs(flattenedInputs));
    }

    /**
     * Tests whether a parent should be replaced by its own parents.
     *
     * @param parent the candidate parent
     * @return true iff the parent's exact runtime class is one of the configured classes
     */
    private boolean isFlattenable(Producer<?> parent) {
        return this._parentClasses.contains(parent.getClass());
    }

    /**
     * Two reducers are equal when they have the same runtime class and the same set of configured parent classes.
     *
     * @param obj the object to compare with
     * @return whether {@code obj} is an equivalent reducer
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return this._parentClasses.equals(((AssociativeClassReducer<?>) obj)._parentClasses);
    }

    /**
     * Returns a hash code consistent with {@link #equals}.
     *
     * @return the hash of the configured parent-class set
     */
    @Override
    public int hashCode() {
        return this._parentClasses.hashCode();
    }
}
