package org.apache.giraph.worker;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.giraph.bsp.CentralizedServiceWorker;
import org.apache.giraph.comm.GlobalCommType;
import org.apache.giraph.comm.aggregators.AggregatorUtils;
import org.apache.giraph.comm.aggregators.AllAggregatorServerData;
import org.apache.giraph.comm.aggregators.GlobalCommValueOutputStream;
import org.apache.giraph.comm.aggregators.OwnerAggregatorServerData;
import org.apache.giraph.comm.aggregators.WorkerAggregatorRequestProcessor;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.reducers.ReduceOperation;
import org.apache.giraph.reducers.Reducer;
import org.apache.giraph.utils.UnsafeByteArrayOutputStream;
import org.apache.giraph.utils.UnsafeReusableByteArrayInput;
import org.apache.giraph.utils.WritableUtils;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.Progressable;
import org.apache.log4j.Logger;

/* loaded from: input_file:org/apache/giraph/worker/WorkerAggregatorHandler.class */
/**
 * Worker-side handler for reducers and broadcast values.
 *
 * <p>Lifecycle per superstep:
 * <ol>
 *   <li>{@link #prepareSuperstep} clears the previous superstep's state. It then fills
 *       {@code broadcastedMap} and {@code reducerMap} from the data made available by the master.</li>
 *   <li>During computation, {@link #reduce} folds values into the shared reducers.
 *       Alternatively, threads obtain a {@link ThreadLocalWorkerGlobalCommUsage} via
 *       {@link #newThreadAggregatorUsage()} and merge at {@link #finishThreadComputation()}.</li>
 *   <li>{@link #finishSuperstep} sends each partially reduced value to its owner.
 *       It then waits for the values this worker owns and forwards those final values to the master.</li>
 * </ol>
 */
public class WorkerAggregatorHandler implements WorkerThreadGlobalCommUsage {
    /** Class logger. */
    private static final Logger LOG = Logger.getLogger(WorkerAggregatorHandler.class);
    /** Values broadcast for the current superstep, keyed by name. */
    private final Map<String, Writable> broadcastedMap = Maps.newHashMap();
    /** Shared reducers for the current superstep, keyed by name. */
    private final Map<String, Reducer<Object, Writable>> reducerMap = Maps.newHashMap();
    /** Service worker used to reach server data, worker info and the worker client. */
    private final CentralizedServiceWorker<?, ?, ?> serviceWorker;
    /** Progress reporter, ticked while reducing and while sending values. */
    private final Progressable progressable;
    /** Size threshold (in bytes) after which buffered reduced values are flushed to the master. */
    private final int maxBytesPerAggregatorRequest;
    /** Job configuration. */
    private final ImmutableClassesGiraphConfiguration conf;

    /**
     * Per-thread view of the reducers that avoids contention on the shared reducers.
     *
     * <p>Each instance holds its own reducers, built from copies of the reduce operations
     * registered with the enclosing handler at construction time. Because of that,
     * {@link #reduce} takes no lock. The values accumulated locally are folded into the
     * shared reducers by {@link #finishThreadComputation()}.
     *
     * <p>Broadcast lookups are delegated to the enclosing handler. This class itself performs
     * no synchronization, so an instance is intended to be used by one thread.
     */
    public class ThreadLocalWorkerGlobalCommUsage implements WorkerThreadGlobalCommUsage {
        /** Thread-private reducers, keyed by reducer name. */
        private final Map<String, Reducer<Object, Writable>> threadReducerMap;

        /**
         * Builds thread-private reducers.
         *
         * <p>Copies the reduce operation of every reducer currently registered with the
         * enclosing handler.
         */
        @SuppressWarnings("unchecked")
        public ThreadLocalWorkerGlobalCommUsage() {
            Map<String, Reducer<Object, Writable>> sharedReducers = WorkerAggregatorHandler.this.reducerMap;
            this.threadReducerMap = Maps.newHashMapWithExpectedSize(sharedReducers.size());
            // Scratch buffers shared by all the copies made below.
            UnsafeByteArrayOutputStream reusableOut = new UnsafeByteArrayOutputStream();
            UnsafeReusableByteArrayInput reusableIn = new UnsafeReusableByteArrayInput();
            for (Map.Entry<String, Reducer<Object, Writable>> entry : sharedReducers.entrySet()) {
                // The cast restores the generic type in case createCopy's result is
                // erased because the configuration argument is a raw type.
                ReduceOperation<Object, Writable> threadLocalOp =
                        (ReduceOperation<Object, Writable>) WritableUtils.createCopy(
                                reusableOut, reusableIn, entry.getValue().getReduceOp(),
                                WorkerAggregatorHandler.this.conf);
                this.threadReducerMap.put(entry.getKey(), new Reducer<Object, Writable>(threadLocalOp));
            }
        }

        /**
         * Reduces a single value into this thread's private reducer.
         *
         * @param name name of the reducer
         * @param value value to reduce
         * @throws IllegalStateException if no reducer with this name was registered
         */
        @Override
        public void reduce(String name, Object value) {
            Reducer<Object, Writable> reducer = this.threadReducerMap.get(name);
            if (reducer == null) {
                throw new IllegalStateException("reduce: " + AggregatorUtils.getUnregisteredReducerMessage(
                        name, !this.threadReducerMap.isEmpty(), WorkerAggregatorHandler.this.conf));
            }
            WorkerAggregatorHandler.this.progressable.progress();
            reducer.reduceSingle(value);
        }

        /**
         * Returns the broadcast value with the given name, as seen by the enclosing handler.
         *
         * @param name name of the broadcast value
         * @param <B> expected type of the value
         * @return the value, or {@code null} if it was not broadcast
         */
        @Override
        public <B extends Writable> B getBroadcast(String name) {
            return WorkerAggregatorHandler.this.getBroadcast(name);
        }

        /** Folds every thread-local value into the corresponding shared reducer. */
        @Override
        public void finishThreadComputation() {
            for (Map.Entry<String, Reducer<Object, Writable>> entry : this.threadReducerMap.entrySet()) {
                WorkerAggregatorHandler.this.reducePartial(entry.getKey(), entry.getValue().getCurrentValue());
            }
        }
    }

    /**
     * Creates a handler.
     *
     * @param centralizedServiceWorker service worker providing server data and worker info
     * @param immutableClassesGiraphConfiguration job configuration; also supplies
     *        {@code AggregatorUtils.MAX_BYTES_PER_AGGREGATOR_REQUEST} (default 1 MiB)
     * @param progressable progress reporter
     */
    public WorkerAggregatorHandler(CentralizedServiceWorker<?, ?, ?> centralizedServiceWorker, ImmutableClassesGiraphConfiguration immutableClassesGiraphConfiguration, Progressable progressable) {
        this.serviceWorker = centralizedServiceWorker;
        this.progressable = progressable;
        this.conf = immutableClassesGiraphConfiguration;
        this.maxBytesPerAggregatorRequest = immutableClassesGiraphConfiguration.getInt(
                AggregatorUtils.MAX_BYTES_PER_AGGREGATOR_REQUEST, 1024 * 1024);
    }

    /**
     * Returns the broadcast value with the given name.
     *
     * <p>Logs a warning when the name is unknown.
     *
     * @param name name of the broadcast value
     * @param <B> expected type of the value
     * @return the value, or {@code null} if it was not broadcast this superstep
     */
    @Override
    public <B extends Writable> B getBroadcast(String name) {
        @SuppressWarnings("unchecked")
        B value = (B) this.broadcastedMap.get(name);
        if (value == null) {
            LOG.warn("getBroadcast: " + AggregatorUtils.getUnregisteredBroadcastMessage(
                    name, !this.broadcastedMap.isEmpty(), this.conf));
        }
        return value;
    }

    /**
     * Reduces a single value into the shared reducer with the given name.
     *
     * <p>Synchronizes on that reducer, so several threads may call this concurrently.
     *
     * @param name name of the reducer
     * @param value value to reduce
     * @throws IllegalStateException if no reducer with this name was registered
     */
    @Override
    public void reduce(String name, Object value) {
        Reducer<Object, Writable> reducer = this.reducerMap.get(name);
        if (reducer == null) {
            throw new IllegalStateException("reduce: " + AggregatorUtils.getUnregisteredReducerMessage(
                    name, !this.reducerMap.isEmpty(), this.conf));
        }
        this.progressable.progress();
        synchronized (reducer) {
            reducer.reduceSingle(value);
        }
    }

    /**
     * Combines an already partially reduced value into the shared reducer with the given name.
     *
     * <p>Synchronizes on that reducer.
     *
     * @param name name of the reducer
     * @param partialValue partially reduced value
     * @throws IllegalStateException if no reducer with this name was registered
     */
    protected void reducePartial(String name, Writable partialValue) {
        Reducer<Object, Writable> reducer = this.reducerMap.get(name);
        if (reducer == null) {
            throw new IllegalStateException("reducePartial: " + AggregatorUtils.getUnregisteredReducerMessage(
                    name, !this.reducerMap.isEmpty(), this.conf));
        }
        this.progressable.progress();
        synchronized (reducer) {
            reducer.reducePartial(partialValue);
        }
    }

    /**
     * Resets state and loads the broadcast values and reducers for the coming superstep.
     *
     * @param requestProcessor processor used to distribute the data received from the master
     * @throws IllegalStateException wrapping any {@link IOException} raised while distributing
     */
    public void prepareSuperstep(WorkerAggregatorRequestProcessor requestProcessor) {
        this.broadcastedMap.clear();
        this.reducerMap.clear();
        if (LOG.isDebugEnabled()) {
            LOG.debug("prepareSuperstep: Start preparing aggregators");
        }
        AllAggregatorServerData allAggregatorData = this.serviceWorker.getServerData().getAllAggregatorData();
        try {
            // Hand the master's data to the request processor for distribution, then fill
            // this superstep's maps (the "WhenReady" calls presumably block until data arrives).
            requestProcessor.distributeReducedValues(
                    allAggregatorData.getDataFromMasterWhenReady(this.serviceWorker.getMasterInfo()));
            allAggregatorData.fillNextSuperstepMapsWhenReady(
                    getOtherWorkerIdsSet(), this.broadcastedMap, this.reducerMap);
            if (LOG.isDebugEnabled()) {
                LOG.debug("prepareSuperstep: Aggregators prepared");
            }
        } catch (IOException e) {
            throw new IllegalStateException("prepareSuperstep: IOException occurred while trying to distribute aggregators", e);
        }
    }

    /**
     * Gathers this superstep's reduced values and forwards the ones this worker owns to the master.
     *
     * @param requestProcessor processor used to send values to owners and to the master
     * @throws IllegalStateException wrapping any {@link IOException} raised while sending
     */
    public void finishSuperstep(WorkerAggregatorRequestProcessor requestProcessor) {
        if (LOG.isInfoEnabled()) {
            LOG.info("finishSuperstep: Start gathering aggregators, "
                    + "workers will send their aggregated values "
                    + "once they are done with superstep computation");
        }
        OwnerAggregatorServerData ownerAggregatorData = this.serviceWorker.getServerData().getOwnerAggregatorData();

        // Phase 1: send every partially reduced value to its owner. When the processor
        // reports it did not send a value (presumably because this worker owns it),
        // reduce it straight into the local owner data instead.
        for (Map.Entry<String, Reducer<Object, Writable>> entry : this.reducerMap.entrySet()) {
            String name = entry.getKey();
            Writable partialValue = entry.getValue().getCurrentValue();
            try {
                if (!requestProcessor.sendReducedValue(name, partialValue)) {
                    ownerAggregatorData.reduce(name, partialValue);
                }
                this.progressable.progress();
            } catch (IOException e) {
                throw new IllegalStateException("finishSuperstep: IOException occurred while sending aggregator " + name + " to its owner", e);
            }
        }
        try {
            requestProcessor.flush();
        } catch (IOException e) {
            throw new IllegalStateException("finishSuperstep: IOException occurred while sending aggregators to owners", e);
        }

        // Phase 2: collect the values this worker owns once all other workers have contributed.
        Iterable<Map.Entry<String, Writable>> myReducedValues =
                ownerAggregatorData.getMyReducedValuesWhenReady(getOtherWorkerIdsSet());

        // Phase 3: stream the owned values to the master, flushing whenever the buffer
        // exceeds the configured request size.
        GlobalCommValueOutputStream globalOutput = new GlobalCommValueOutputStream(false);
        for (Map.Entry<String, Writable> entry : myReducedValues) {
            try {
                int bufferedBytes = globalOutput.addValue(entry.getKey(), GlobalCommType.REDUCED_VALUE, entry.getValue());
                if (bufferedBytes > this.maxBytesPerAggregatorRequest) {
                    requestProcessor.sendReducedValuesToMaster(globalOutput.flush());
                }
                this.progressable.progress();
            } catch (IOException e) {
                throw new IllegalStateException("finishSuperstep: IOException occurred while writing aggregator " + entry.getKey(), e);
            }
        }
        try {
            requestProcessor.sendReducedValuesToMaster(globalOutput.flush());
        } catch (IOException e) {
            throw new IllegalStateException("finishSuperstep: IOException occured while sending aggregators to master", e);
        }
        // Wait for all outstanding requests before resetting the owner data.
        this.serviceWorker.getWorkerClient().waitAllRequests();
        ownerAggregatorData.reset();
        if (LOG.isDebugEnabled()) {
            LOG.debug("finishSuperstep: Aggregators finished");
        }
    }

    /**
     * Returns the reduce/broadcast view a computation thread should use.
     *
     * @return a new {@link ThreadLocalWorkerGlobalCommUsage} when
     *         {@code AggregatorUtils.useThreadLocalAggregators(conf)} is true; otherwise this
     *         shared handler
     */
    public WorkerThreadGlobalCommUsage newThreadAggregatorUsage() {
        return AggregatorUtils.useThreadLocalAggregators(this.conf) ? new ThreadLocalWorkerGlobalCommUsage() : this;
    }

    /** No-op: when this shared handler is used directly, values are already reduced into the shared reducers. */
    @Override
    public void finishThreadComputation() {
    }

    /**
     * Returns the task ids of all workers except this one.
     *
     * @return a new mutable set of the other workers' task ids
     */
    public Set<Integer> getOtherWorkerIdsSet() {
        int myTaskId = this.serviceWorker.getWorkerInfo().getTaskId();
        Set<Integer> otherTaskIds = Sets.newHashSetWithExpectedSize(this.serviceWorker.getWorkerInfoList().size());
        for (WorkerInfo workerInfo : this.serviceWorker.getWorkerInfoList()) {
            if (workerInfo.getTaskId() != myTaskId) {
                otherTaskIds.add(workerInfo.getTaskId());
            }
        }
        return otherTaskIds;
    }
}
