/*
 * Decompiled with CFR 0.152.
 */
package org.apache.geaflow.cluster.fetcher;

import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.geaflow.common.exception.GeaflowRuntimeException;
import org.apache.geaflow.shuffle.desc.ShardInputDesc;
import org.apache.geaflow.shuffle.message.PipelineBarrier;

/**
 * Collects {@link PipelineBarrier}s received by a task from its input edges and decides when a
 * barrier completes an edge or a whole window.
 *
 * <p>The completion rule is chosen per edge by {@link ShardInputDesc#isPrefetchWrite()}:
 * <ul>
 *   <li>prefetch-write edge: the edge completes once the number of distinct barriers seen on it
 *       equals the edge's slice number; the window finishes once every tracked edge has
 *       completed;</li>
 *   <li>other edges: the window finishes once the number of distinct barriers seen for the
 *       window equals the sum of the slice numbers of all input edges.</li>
 * </ul>
 *
 * <p>No synchronization is performed. NOTE(review): presumably driven by a single fetcher
 * thread — confirm.
 */
public class BarrierHandler implements Serializable {
    private final int taskId;
    private final Map<Integer, ShardInputDesc> inputShards;
    private final Map<Integer, Integer> edgeId2sliceNum;
    /** Sum of the slice numbers of all input edges. */
    private final int inputSliceNum;
    /** Per-edge barriers of the current window; only filled for prefetch-write edges. */
    private final Map<Integer, Set<PipelineBarrier>> edgeBarrierCache;
    /** Barriers received so far, grouped by window id. */
    private final Map<Long, Set<PipelineBarrier>> windowBarrierCache;
    private long finishedWindowId;
    private long totalWindowCount;

    /**
     * Creates a handler for the given task.
     *
     * @param taskId      id of the owning task, used in error messages
     * @param inputShards input descriptors keyed by edge id
     */
    public BarrierHandler(int taskId, Map<Integer, ShardInputDesc> inputShards) {
        this.taskId = taskId;
        this.inputShards = inputShards;
        this.edgeId2sliceNum = inputShards.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getSliceNum()));
        this.inputSliceNum = this.edgeId2sliceNum.values().stream().mapToInt(Integer::intValue).sum();
        this.edgeBarrierCache = new HashMap<>();
        this.resetEdgeBarrierCache();
        this.windowBarrierCache = new HashMap<>();
        this.finishedWindowId = -1L;
        this.totalWindowCount = 0L;
    }

    /**
     * Records a barrier and reports whether it completed something.
     *
     * <p>For a prefetch-write edge, {@code true} means the barrier completed that edge for its
     * window. The window itself may still be unfinished if other edges are outstanding. For other
     * edges, {@code true} means the barrier completed the whole window. Whenever a window
     * finishes, {@link #getTotalWindowCount()} is updated to the sum of its barriers' counts.
     *
     * @param barrier the received barrier
     * @return whether the barrier completed its edge (prefetch-write) or its window (otherwise)
     * @throws GeaflowRuntimeException if the barrier belongs to an already finished window or to
     *                                 an edge that is not one of this task's inputs
     */
    public boolean checkCompleted(PipelineBarrier barrier) {
        long windowId = barrier.getWindowId();
        if (windowId <= this.finishedWindowId) {
            throw new GeaflowRuntimeException(String.format("illegal state: taskId %s window %s has finished, last finished window is: %s", this.taskId, windowId, this.finishedWindowId));
        }
        int edgeId = barrier.getEdgeId();
        ShardInputDesc inputDesc = this.inputShards.get(edgeId);
        if (inputDesc == null) {
            // Reject before touching any cache so an unexpected edge cannot corrupt the state.
            throw new GeaflowRuntimeException(String.format("illegal state: taskId %s received barrier of unknown edge %s, known edges are: %s", this.taskId, edgeId, this.inputShards.keySet()));
        }
        Set<PipelineBarrier> windowBarriers = this.windowBarrierCache.computeIfAbsent(windowId, k -> new HashSet<>());
        windowBarriers.add(barrier);
        if (inputDesc.isPrefetchWrite()) {
            return this.checkPrefetchEdgeCompleted(edgeId, windowId, barrier, windowBarriers);
        }
        if (windowBarriers.size() == this.inputSliceNum) {
            this.finishWindow(windowId, windowBarriers);
            return true;
        }
        return false;
    }

    /**
     * Returns the summed barrier count of the most recently finished window, or 0 if none has
     * finished yet.
     */
    public long getTotalWindowCount() {
        return this.totalWindowCount;
    }

    /** Tracks a barrier of a prefetch-write edge; returns whether it completed that edge. */
    private boolean checkPrefetchEdgeCompleted(int edgeId, long windowId, PipelineBarrier barrier, Set<PipelineBarrier> windowBarriers) {
        Set<PipelineBarrier> edgeBarriers = this.edgeBarrierCache.computeIfAbsent(edgeId, k -> new HashSet<>());
        edgeBarriers.add(barrier);
        if (edgeBarriers.size() != this.edgeId2sliceNum.get(edgeId)) {
            return false;
        }
        this.edgeBarrierCache.remove(edgeId);
        if (this.edgeBarrierCache.isEmpty()) {
            this.finishWindow(windowId, windowBarriers);
            // Re-seed every edge so the next window again waits for all of them. Otherwise an edge
            // that has not reported yet would be missing and the window could finish prematurely.
            this.resetEdgeBarrierCache();
        }
        return true;
    }

    /** Marks {@code windowId} as finished and records the summed count of its barriers. */
    private void finishWindow(long windowId, Set<PipelineBarrier> windowBarriers) {
        this.windowBarrierCache.remove(windowId);
        this.finishedWindowId = windowId;
        this.totalWindowCount = windowBarriers.stream().mapToLong(PipelineBarrier::getCount).sum();
        windowBarriers.clear();
    }

    /** Puts an empty barrier set into the edge cache for every input edge. */
    private void resetEdgeBarrierCache() {
        for (Integer edgeId : this.inputShards.keySet()) {
            this.edgeBarrierCache.put(edgeId, new HashSet<>());
        }
    }
}

