package com.linkedin.dagli.reducer;

import com.linkedin.dagli.util.type.Classes;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;

/**
 * A mutable table that associates {@link Reducer}s with the classes they were registered against.
 *
 * <p>Reducers are registered per class via {@link #add(Reducer, Class)}. Lookups via {@link #getReducers(Class)} and
 * {@link #hasReducer(Class, Reducer)} consult every class produced by {@code Classes.walkHierarchy(cls)}
 * (presumably {@code cls} together with its supertypes — NOTE(review): confirm against {@code Classes}).
 *
 * <p>This class performs no synchronization; an instance shared across threads must be guarded externally.
 */
public class ClassReducerTable {
    // Each registered class -> the reducers registered directly against that exact class.
    private final HashMap<Class<?>, HashSet<Reducer<?>>> _reductionMap = new HashMap<>();

    /**
     * Registers {@code reducer} against each of the given classes.
     *
     * @param reducer the reducer to register
     * @param clsArr the classes the reducer should be associated with
     * @param <T> the type of object the reducer operates on
     */
    @SafeVarargs
    public final <T> void add(Reducer<T> reducer, Class<? extends T>... clsArr) {
        for (Class<? extends T> cls : clsArr) {
            add(reducer, cls);
        }
    }

    /**
     * Registers {@code reducer} against {@code cls}. Registering the same reducer for the same class more than once
     * has no additional effect, because reducers for a class are kept in a {@link HashSet}.
     *
     * @param reducer the reducer to register
     * @param cls the class the reducer should be associated with
     * @param <T> the type of object the reducer operates on
     */
    public <T> void add(Reducer<T> reducer, Class<? extends T> cls) {
        _reductionMap.computeIfAbsent(cls, ignored -> new HashSet<>()).add(reducer);
    }

    /**
     * Copies every class-to-reducer registration from {@code other} into this table. A {@code null} argument is
     * ignored.
     *
     * <p>The reducer sets of {@code other} are never shared with this table. New entries get their own fresh set,
     * so later changes to either table do not leak into the other.
     *
     * @param other the table whose registrations should be added; may be {@code null}
     */
    public void addAll(ClassReducerTable other) {
        if (other == null) {
            return;
        }
        other._reductionMap.forEach(
            (cls, reducers) -> _reductionMap.computeIfAbsent(cls, ignored -> new HashSet<>()).addAll(reducers));
    }

    /**
     * Collects all reducers registered against any class produced by {@code Classes.walkHierarchy(cls)}.
     *
     * @param cls the class whose applicable reducers are requested
     * @param <T> the type represented by {@code cls}
     * @return a new, mutable set containing the matching reducers (empty if there are none)
     */
    @SuppressWarnings("unchecked")
    public <T> Set<? extends Reducer<? super T>> getReducers(Class<T> cls) {
        HashSet<Reducer<? super T>> result = new HashSet<>();
        // The cast below is unchecked. It is sound provided walkHierarchy only yields cls and its supertypes.
        // A reducer stored under class S came from add(Reducer<U>, Class<? extends U>), so S <: U and hence
        // T <: S <: U. NOTE(review): confirm this holds for Classes.walkHierarchy.
        Classes.walkHierarchy(cls)
            .map(_reductionMap::get)
            .filter(Objects::nonNull)
            .flatMap(Collection::stream)
            .forEach(reducer -> result.add((Reducer<? super T>) reducer));
        return result;
    }

    /**
     * Checks whether {@code reducer} is registered against any class produced by {@code Classes.walkHierarchy(cls)}.
     *
     * @param cls the class to check
     * @param reducer the reducer to look for
     * @param <T> the type represented by {@code cls}
     * @return {@code true} if some visited class has {@code reducer} registered against it
     */
    public <T> boolean hasReducer(Class<T> cls, Reducer<? super T> reducer) {
        return Classes.walkHierarchy(cls)
            .map(_reductionMap::get)
            .filter(Objects::nonNull)
            .anyMatch(reducers -> reducers.contains(reducer));
    }
}
