package io.kgraph.library.maxbmatching;

import io.kgraph.EdgeWithValue;
import io.kgraph.VertexWithValue;
import io.kgraph.library.maxbmatching.MBMEdgeValue;
import io.kgraph.pregel.ComputeFunction;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeSet;
import org.apache.log4j.Logger;

/* loaded from: input_file:BOOT-INF/lib/kafka-graphs-core-1.1.1.jar:io/kgraph/library/maxbmatching/MaxBMatching.class */
/**
 * Pregel-style compute function for b-matching.
 *
 * <p>The vertex value is treated as the vertex's remaining capacity. Every edge carries an
 * {@link MBMEdgeValue} holding a weight and a state. On each superstep a vertex with spare
 * capacity marks its heaviest still-available edges as {@code PROPOSED} and notifies the
 * neighbours; an edge that is proposed from both ends becomes {@code INCLUDED} and consumes
 * one unit of capacity. A vertex whose capacity is zero announces {@code REMOVED} on its
 * {@code DEFAULT} edges, removes them and votes to halt.
 */
public class MaxBMatching implements ComputeFunction<Long, Integer, MBMEdgeValue, MBMMessage> {
    private static final Logger LOG = Logger.getLogger(MaxBMatching.class);

    /**
     * Runs one superstep for a single vertex.
     *
     * @param i               superstep number; incoming messages are only processed when it is
     *                        greater than zero
     * @param vertexWithValue the vertex and its current capacity
     * @param iterable        messages delivered to this vertex
     * @param iterable2       outgoing edges of this vertex
     * @param callback        sink for vertex/edge updates, outgoing messages and halt votes
     * @throws AssertionError if the capacity is negative, or if the vertex has no edge left to
     *                        propose on while some edge is not {@code INCLUDED}
     */
    @Override // io.kgraph.pregel.ComputeFunction
    public void compute(int i, VertexWithValue<Long, Integer> vertexWithValue, Iterable<MBMMessage> iterable, Iterable<EdgeWithValue<Long, MBMEdgeValue>> iterable2, ComputeFunction.Callback<Long, Integer, MBMEdgeValue, MBMMessage> callback) {
        if (LOG.isDebugEnabled()) {
            debug(vertexWithValue, iterable2);
        }
        int capacity = vertexWithValue.value();
        if (capacity < 0) {
            throw new AssertionError("Capacity should never be negative: " + vertexWithValue);
        }
        if (capacity == 0) {
            // Saturated vertex: tell neighbours, drop unused edges and stop.
            removeVertex(vertexWithValue, iterable2, callback);
            callback.voteToHalt();
            return;
        }
        VertexWithValue<Long, Integer> current = vertexWithValue;
        if (i > 0) {
            current = processUpdates(i, vertexWithValue, iterable, iterable2, callback);
            callback.setNewVertexValue(current.value());
        }
        if (current.value() > 0) {
            sendUpdates(current, iterable2, callback);
        }
    }

    /**
     * Orders candidate edges by descending weight and, for equal weights, by ascending target
     * id. The tie-break is essential: a {@link TreeSet} treats elements that compare equal as
     * duplicates, so a weight-only ordering would silently discard every additional edge that
     * shares a weight with an edge already in the set.
     */
    private static int compareCandidates(Map.Entry<Long, MBMEdgeValue> left, Map.Entry<Long, MBMEdgeValue> right) {
        int byWeight = (-1) * Double.compare(left.getValue().getWeight(), right.getValue().getWeight());
        if (byWeight != 0) {
            return byWeight;
        }
        return left.getKey().compareTo(right.getKey());
    }

    /**
     * Proposes on the heaviest available edges, at most as many as the vertex has capacity.
     *
     * <p>Edges in state {@code DEFAULT} or {@code PROPOSED} are candidates. If there is no
     * candidate, the remaining edges are verified with {@link #checkSolution} and the vertex
     * votes to halt.
     */
    private void sendUpdates(VertexWithValue<Long, Integer> vertex, Iterable<EdgeWithValue<Long, MBMEdgeValue>> edges, ComputeFunction.Callback<Long, Integer, MBMEdgeValue, MBMMessage> callback) {
        MBMMessage proposal = new MBMMessage(vertex.id(), MBMEdgeValue.State.PROPOSED);
        int capacity = vertex.value();

        // Bounded "top capacity" selection: after each insert, evict the worst candidate
        // once the set grows past the capacity.
        TreeSet<Map.Entry<Long, MBMEdgeValue>> best = new TreeSet<Map.Entry<Long, MBMEdgeValue>>(MaxBMatching::compareCandidates);
        for (EdgeWithValue<Long, MBMEdgeValue> edge : edges) {
            MBMEdgeValue.State state = edge.value().getState();
            if (state == MBMEdgeValue.State.DEFAULT || state == MBMEdgeValue.State.PROPOSED) {
                best.add(new AbstractMap.SimpleImmutableEntry<>(edge.target(), edge.value()));
                if (best.size() > capacity) {
                    best.pollLast();
                }
            }
        }

        if (best.isEmpty()) {
            // Nothing left to propose on: every surviving edge must already be matched.
            checkSolution(edges);
            callback.voteToHalt();
            return;
        }

        // Emit proposals from heaviest to lightest.
        while (!best.isEmpty()) {
            Map.Entry<Long, MBMEdgeValue> candidate = best.pollFirst();
            Long target = candidate.getKey();
            callback.setNewEdgeValue(target, new MBMEdgeValue(candidate.getValue().getWeight(), MBMEdgeValue.State.PROPOSED));
            callback.sendMessageTo(target, proposal);
        }
    }

    /**
     * Applies the incoming messages to this vertex's edges.
     *
     * <p>A {@code PROPOSED} message for an edge that is locally {@code PROPOSED} turns the edge
     * into {@code INCLUDED} and uses one unit of capacity. A {@code REMOVED} message schedules
     * the edge to the sender for removal. Messages whose sender has no matching local edge are
     * only logged at debug level.
     *
     * @return a copy of the vertex whose value is reduced by the number of newly included edges
     */
    private VertexWithValue<Long, Integer> processUpdates(int superstep, VertexWithValue<Long, Integer> vertex, Iterable<MBMMessage> messages, Iterable<EdgeWithValue<Long, MBMEdgeValue>> edges, ComputeFunction.Callback<Long, Integer, MBMEdgeValue, MBMMessage> callback) {
        Map<Long, MBMEdgeValue> edgeByTarget = new HashMap<>();
        for (EdgeWithValue<Long, MBMEdgeValue> edge : edges) {
            edgeByTarget.put(edge.target(), edge.value());
        }

        HashSet<Long> toRemove = new HashSet<>();
        int included = 0;
        for (MBMMessage message : messages) {
            MBMEdgeValue local = edgeByTarget.get(message.getId());
            if (local == null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug(String.format("Superstep %d Vertex %d: message for removed edge from vertex %d", superstep, vertex.id(), message.getId()));
                }
            } else if (message.getState() == MBMEdgeValue.State.PROPOSED && local.getState() == MBMEdgeValue.State.PROPOSED) {
                // Both endpoints proposed on this edge: it joins the matching.
                callback.setNewEdgeValue(message.getId(), new MBMEdgeValue(local.getWeight(), MBMEdgeValue.State.INCLUDED));
                included++;
            } else if (message.getState() == MBMEdgeValue.State.REMOVED) {
                toRemove.add(message.getId());
            }
        }

        VertexWithValue<Long, Integer> updated = new VertexWithValue<>(vertex.id(), vertex.value() - included);
        for (Long target : toRemove) {
            callback.removeEdge(target);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Superstep %d Vertex %d: included %d edges, removed %d edges", superstep, vertex.id(), included, toRemove.size()));
        }
        return updated;
    }

    /**
     * Sends a {@code REMOVED} message along every edge still in state {@code DEFAULT} and then
     * removes those edges. Removal is deferred until the edge iteration has finished.
     */
    private void removeVertex(VertexWithValue<Long, Integer> vertex, Iterable<EdgeWithValue<Long, MBMEdgeValue>> edges, ComputeFunction.Callback<Long, Integer, MBMEdgeValue, MBMMessage> callback) {
        MBMMessage removal = new MBMMessage(vertex.id(), MBMEdgeValue.State.REMOVED);
        HashSet<Long> toRemove = new HashSet<>();
        for (EdgeWithValue<Long, MBMEdgeValue> edge : edges) {
            if (edge.value().getState() == MBMEdgeValue.State.DEFAULT) {
                callback.sendMessageTo(edge.target(), removal);
                toRemove.add(edge.target());
            }
        }
        for (Long target : toRemove) {
            callback.removeEdge(target);
        }
    }

    /**
     * Verifies that every given edge is {@code INCLUDED}.
     *
     * @throws AssertionError on the first edge found in any other state
     */
    private void checkSolution(Iterable<EdgeWithValue<Long, MBMEdgeValue>> edges) {
        for (EdgeWithValue<Long, MBMEdgeValue> edge : edges) {
            if (edge.value().getState() != MBMEdgeValue.State.INCLUDED) {
                throw new AssertionError(String.format("All the edges in the matching should be %s, %s was %s", MBMEdgeValue.State.INCLUDED, edge, edge.value().getState()));
            }
        }
    }

    /** Logs the vertex and each of its edges at debug level. */
    private void debug(VertexWithValue<Long, Integer> vertex, Iterable<EdgeWithValue<Long, MBMEdgeValue>> edges) {
        LOG.debug(vertex);
        for (EdgeWithValue<Long, MBMEdgeValue> edge : edges) {
            LOG.debug(String.format("Edge(%d, %s)", edge.target(), edge.value().toString()));
        }
    }
}
