package org.apache.gobblin.service.modules.orchestration;

import com.google.inject.Inject;
import com.typesafe.config.Config;
import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import javax.inject.Singleton;
import org.apache.gobblin.exception.QuotaExceededException;
import org.apache.gobblin.service.ExecutionStatus;
import org.apache.gobblin.service.modules.flowgraph.Dag;
import org.apache.gobblin.service.modules.orchestration.AbstractUserQuotaManager;
import org.apache.gobblin.service.modules.spec.JobExecutionPlan;
import org.apache.gobblin.util.ConfigUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link AbstractUserQuotaManager} that keeps quota usage in in-process concurrent maps.
 *
 * <p>Three independent counters are maintained:
 * <ul>
 *   <li>one per proxy user ({@code user.to.proxy}),</li>
 *   <li>one per requester,</li>
 *   <li>one per flow group ({@code flow.group}).</li>
 * </ul>
 * Each counter is keyed by the quota key that {@link DagManagerUtils} derives from the name and the dag node.
 *
 * <p>A set of dag ids (from {@link DagManagerUtils#generateDagId}) records which dags currently hold
 * quota. A dag is therefore counted at most once until {@link #releaseQuota} or
 * {@link #rollbackIncrements} removes its id.
 */
@Singleton
public class InMemoryUserQuotaManager extends AbstractUserQuotaManager {
    private static final Logger log = LoggerFactory.getLogger(InMemoryUserQuotaManager.class);

    /** Job-spec config key holding the proxy user whose quota a job consumes. */
    private static final String USER_TO_PROXY_KEY = "user.to.proxy";
    /** Job-spec config key holding the flow group whose quota a job consumes. */
    private static final String FLOW_GROUP_KEY = "flow.group";

    // Declared as Map, but always ConcurrentHashMap instances: the counter helpers below rely on
    // ConcurrentHashMap's atomic merge/replace implementations.
    private final Map<String, Integer> proxyUserToJobCount;
    private final Map<String, Integer> flowGroupToJobCount;
    private final Map<String, Integer> requesterToJobCount;
    /** Ids of the dags whose quota is currently held. */
    private final Set<String> runningDagIds;

    /**
     * Creates a quota manager with no recorded usage.
     *
     * @param config service configuration, handed to {@link AbstractUserQuotaManager}
     */
    @Inject
    public InMemoryUserQuotaManager(Config config) {
        super(config);
        this.proxyUserToJobCount = new ConcurrentHashMap<>();
        this.flowGroupToJobCount = new ConcurrentHashMap<>();
        this.requesterToJobCount = new ConcurrentHashMap<>();
        this.runningDagIds = ConcurrentHashMap.newKeySet();
    }

    /**
     * Claims quota for {@code dagNode} and reports which limits, if any, it exceeds.
     *
     * <p>If the node's dag id is already tracked, the call changes nothing and reports success.
     * Otherwise the dag id is recorded. Then, only when the node's current attempt count is at most 1,
     * each of the following counters is incremented:
     * <ul>
     *   <li>the proxy-user counter (only if {@code user.to.proxy} is set),</li>
     *   <li>each distinct requester's counter,</li>
     *   <li>the flow-group counter.</li>
     * </ul>
     * A job fits a limit when the count before its increment is strictly below the quota. With a
     * quota of 0 or less, the job is therefore always rejected.
     *
     * <p>Counters are incremented even when a limit is exceeded. {@link #checkQuota} undoes them
     * through {@link #rollbackIncrements}.
     *
     * @param dagNode node about to be submitted
     * @return the per-category results plus a message describing every exceeded quota
     * @throws IOException propagated from {@link DagManagerUtils} (e.g. resolving the requester list)
     *     or from the counter helpers
     */
    protected AbstractUserQuotaManager.QuotaCheck increaseAndCheckQuota(Dag.DagNode<JobExecutionPlan> dagNode) throws IOException {
        AbstractUserQuotaManager.QuotaCheck quotaCheck = new AbstractUserQuotaManager.QuotaCheck(true, true, true, "");
        // Set.add on a ConcurrentHashMap key set is atomic, so concurrent calls for the same dag cannot
        // both get past this guard. A separate contains-then-add could count the dag twice.
        if (!this.runningDagIds.add(DagManagerUtils.generateDagId(dagNode).toString())) {
            return quotaCheck;
        }
        // Only a node's first attempt consumes quota (presumably retries were counted by that attempt).
        if (dagNode.getValue().getCurrentAttempts() > 1) {
            return quotaCheck;
        }

        Config jobConfig = dagNode.getValue().getJobSpec().getConfig();
        String proxyUser = ConfigUtils.getString(jobConfig, USER_TO_PROXY_KEY, (String) null);
        String flowGroup = ConfigUtils.getString(jobConfig, FLOW_GROUP_KEY, "");
        String specExecutorUri = DagManagerUtils.getSpecExecutorUri(dagNode);
        StringBuilder message = new StringBuilder();

        if (proxyUser != null) {
            int quota = getQuotaForUser(proxyUser);
            int overQuota = incrementJobCountAndCheckQuota(
                DagManagerUtils.getUserQuotaKey(proxyUser, dagNode), quota, AbstractUserQuotaManager.CountType.USER_COUNT);
            quotaCheck.setProxyUserCheck(overQuota == 0);
            if (overQuota > 0) {
                message.append(String.format(
                    "Quota exceeded for proxy user %s on executor %s : quota=%s, requests above quota=%d%n",
                    proxyUser, specExecutorUri, quota, overQuota));
            }
        }

        String serializedRequesterList = DagManagerUtils.getSerializedRequesterList(dagNode);
        boolean requestersWithinQuota = true;
        for (String requester : DagManagerUtils.getDistinctUniqueRequesters(serializedRequesterList)) {
            int quota = getQuotaForUser(requester);
            int overQuota = incrementJobCountAndCheckQuota(
                DagManagerUtils.getUserQuotaKey(requester, dagNode), quota, AbstractUserQuotaManager.CountType.REQUESTER_COUNT);
            if (overQuota > 0) {
                requestersWithinQuota = false;
                message.append(String.format(
                    "Quota exceeded for requester %s on executor %s : quota=%s, requests above quota=%d%n. ",
                    requester, specExecutorUri, quota, overQuota));
            }
        }
        quotaCheck.setRequesterCheck(requestersWithinQuota);

        int flowGroupQuota = getQuotaForFlowGroup(flowGroup);
        int flowGroupOverQuota = incrementJobCountAndCheckQuota(
            DagManagerUtils.getFlowGroupQuotaKey(flowGroup, dagNode), flowGroupQuota, AbstractUserQuotaManager.CountType.FLOWGROUP_COUNT);
        quotaCheck.setFlowGroupCheck(flowGroupOverQuota == 0);
        if (flowGroupOverQuota > 0) {
            message.append(String.format(
                "Quota exceeded for flowgroup %s on executor %s : quota=%s, requests above quota=%d%n",
                flowGroup, specExecutorUri, flowGroupQuota, flowGroupOverQuota));
        }

        quotaCheck.setRequesterMessage(message.toString());
        return quotaCheck;
    }

    /**
     * Undoes the counter increments made by {@link #increaseAndCheckQuota} for {@code dagNode} and
     * forgets its dag id.
     *
     * <p>Requester counts are decremented under the same per-node quota keys that were incremented.
     * Decrementing the bare requester names would miss those keys and leak the counts. The proxy-user
     * count is only touched when a proxy user is configured, mirroring the increment.
     *
     * @param dagNode node whose quota claim failed
     * @throws IOException if the requester list cannot be resolved; nothing is decremented in that case
     */
    protected void rollbackIncrements(Dag.DagNode<JobExecutionPlan> dagNode) throws IOException {
        Config jobConfig = dagNode.getValue().getJobSpec().getConfig();
        String proxyUser = ConfigUtils.getString(jobConfig, USER_TO_PROXY_KEY, (String) null);
        String flowGroup = ConfigUtils.getString(jobConfig, FLOW_GROUP_KEY, "");
        List<String> distinctUniqueRequesters =
            DagManagerUtils.getDistinctUniqueRequesters(DagManagerUtils.getSerializedRequesterList(dagNode));
        if (proxyUser != null) {
            decrementJobCount(DagManagerUtils.getUserQuotaKey(proxyUser, dagNode), AbstractUserQuotaManager.CountType.USER_COUNT);
        }
        decrementQuotaUsageForUsers(distinctUniqueRequesters, dagNode);
        decrementJobCount(DagManagerUtils.getFlowGroupQuotaKey(flowGroup, dagNode), AbstractUserQuotaManager.CountType.FLOWGROUP_COUNT);
        removeDagId(DagManagerUtils.generateDagId(dagNode).toString());
    }

    /**
     * Increments the counter stored under {@code key} and compares it with {@code quota}.
     *
     * @return 0 if the count before this increment was below {@code quota}; otherwise the number of
     *     requests above quota including this one (always positive)
     */
    private int incrementJobCountAndCheckQuota(String key, int quota, AbstractUserQuotaManager.CountType countType) throws IOException {
        int previousCount = incrementJobCount(key, countType);
        // An explicit "over quota" amount avoids the old sign encoding, where -0 == 0 let the first
        // job through a quota of 0.
        return previousCount < quota ? 0 : previousCount + 1 - quota;
    }

    /** Decrements the requester counter for each requester, using the same keys as the increment. */
    private void decrementQuotaUsageForUsers(List<String> requesters, Dag.DagNode<JobExecutionPlan> dagNode) throws IOException {
        for (String requester : requesters) {
            decrementJobCount(DagManagerUtils.getUserQuotaKey(requester, dagNode), AbstractUserQuotaManager.CountType.REQUESTER_COUNT);
        }
    }

    /**
     * Releases the quota held by {@code dagNode}'s dag.
     *
     * @return {@code false} if the dag id was not tracked (nothing is released), or if the requester
     *     list could not be resolved. In the latter case the proxy-user and flow-group counts have
     *     already been released. Otherwise returns {@code true}.
     */
    @Override // org.apache.gobblin.service.modules.orchestration.UserQuotaManager
    public boolean releaseQuota(Dag.DagNode<JobExecutionPlan> dagNode) throws IOException {
        if (!removeDagId(DagManagerUtils.generateDagId(dagNode).toString())) {
            return false;
        }
        Config jobConfig = dagNode.getValue().getJobSpec().getConfig();
        String proxyUser = ConfigUtils.getString(jobConfig, USER_TO_PROXY_KEY, (String) null);
        if (proxyUser != null) {
            decrementJobCount(DagManagerUtils.getUserQuotaKey(proxyUser, dagNode), AbstractUserQuotaManager.CountType.USER_COUNT);
        }
        String flowGroup = ConfigUtils.getString(jobConfig, FLOW_GROUP_KEY, "");
        decrementJobCount(DagManagerUtils.getFlowGroupQuotaKey(flowGroup, dagNode), AbstractUserQuotaManager.CountType.FLOWGROUP_COUNT);
        String serializedRequesterList = DagManagerUtils.getSerializedRequesterList(dagNode);
        try {
            for (String requester : DagManagerUtils.getDistinctUniqueRequesters(serializedRequesterList)) {
                decrementJobCount(DagManagerUtils.getUserQuotaKey(requester, dagNode), AbstractUserQuotaManager.CountType.REQUESTER_COUNT);
            }
            return true;
        } catch (IOException e) {
            log.error("Failed to release quota for requester list " + serializedRequesterList, e);
            return false;
        }
    }

    /** Records {@code dagId} as holding quota. */
    void addDagId(String dagId) {
        this.runningDagIds.add(dagId);
    }

    /** Returns whether {@code dagId} is currently recorded as holding quota. */
    @Override // org.apache.gobblin.service.modules.orchestration.AbstractUserQuotaManager
    boolean containsDagId(String dagId) {
        return this.runningDagIds.contains(dagId);
    }

    /** Forgets {@code dagId}; returns whether it was recorded. */
    boolean removeDagId(String dagId) {
        return this.runningDagIds.remove(dagId);
    }

    /**
     * Re-registers quota usage for every node in {@code dags} whose execution status is RUNNING.
     * The quota check results are ignored, so already-running work is always counted.
     */
    @Override // org.apache.gobblin.service.modules.orchestration.UserQuotaManager
    public void init(Collection<Dag<JobExecutionPlan>> dags) throws IOException {
        for (Dag<JobExecutionPlan> dag : dags) {
            for (Dag.DagNode<JobExecutionPlan> dagNode : dag.getNodes()) {
                if (DagManagerUtils.getExecutionStatus(dagNode) == ExecutionStatus.RUNNING) {
                    increaseAndCheckQuota(dagNode);
                }
            }
        }
    }

    /**
     * Claims quota for each node in order.
     *
     * <p>On the first node that exceeds any limit, this rolls back that node's increments and throws.
     * NOTE(review): nodes processed earlier in the same call keep their quota.
     *
     * @throws QuotaExceededException carrying the message that describes the exceeded quotas
     */
    @Override // org.apache.gobblin.service.modules.orchestration.UserQuotaManager
    public void checkQuota(Collection<Dag.DagNode<JobExecutionPlan>> dagNodes) throws IOException {
        for (Dag.DagNode<JobExecutionPlan> dagNode : dagNodes) {
            AbstractUserQuotaManager.QuotaCheck quotaCheck = increaseAndCheckQuota(dagNode);
            if (!quotaCheck.proxyUserCheck || !quotaCheck.requesterCheck || !quotaCheck.flowGroupCheck) {
                rollbackIncrements(dagNode);
                throw new QuotaExceededException(quotaCheck.requesterMessage);
            }
        }
    }

    /**
     * Atomically increments {@code counts[key]}, treating an absent key as 0.
     *
     * @return the count before the increment
     */
    private int incrementJobCount(String key, Map<String, Integer> counts) {
        // merge is atomic on ConcurrentHashMap and returns the new value.
        return counts.merge(key, 1, Integer::sum) - 1;
    }

    /**
     * Atomically decrements {@code counts[key]} without going below zero. A {@code null} key is
     * ignored. A warning is logged when the key is absent or its count is already zero.
     */
    private void decrementJobCount(String key, Map<String, Integer> counts) {
        if (key == null) {
            return;
        }
        Integer current;
        do {
            current = counts.get(key);
            if (current == null || current <= 0) {
                break;
            }
            // Compare-and-set retry: another thread may have changed the count since the read.
        } while (!counts.replace(key, current, current - 1));
        if (current == null || current <= 0) {
            log.warn("Decrement job count was called for {} when the count was already zero/absent.", key);
        }
    }

    /**
     * Increments the counter for {@code key} in the map selected by {@code countType}.
     *
     * @return the count before the increment
     * @throws IOException for an unsupported count type
     */
    int incrementJobCount(String key, AbstractUserQuotaManager.CountType countType) throws IOException {
        switch (countType) {
            case USER_COUNT:
                return incrementJobCount(key, this.proxyUserToJobCount);
            case REQUESTER_COUNT:
                return incrementJobCount(key, this.requesterToJobCount);
            case FLOWGROUP_COUNT:
                return incrementJobCount(key, this.flowGroupToJobCount);
            default:
                throw new IOException("Invalid count type " + countType);
        }
    }

    /**
     * Decrements the counter for {@code key} in the map selected by {@code countType}.
     *
     * @throws IOException for an unsupported count type
     */
    void decrementJobCount(String key, AbstractUserQuotaManager.CountType countType) throws IOException {
        switch (countType) {
            case USER_COUNT:
                decrementJobCount(key, this.proxyUserToJobCount);
                return;
            case REQUESTER_COUNT:
                decrementJobCount(key, this.requesterToJobCount);
                return;
            case FLOWGROUP_COUNT:
                decrementJobCount(key, this.flowGroupToJobCount);
                return;
            default:
                throw new IOException("Invalid count type " + countType);
        }
    }
}
