package org.apache.reef.examples.group.bgd;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.apache.reef.annotations.audience.DriverSide;
import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.context.ServiceConfiguration;
import org.apache.reef.driver.task.CompletedTask;
import org.apache.reef.driver.task.FailedTask;
import org.apache.reef.driver.task.RunningTask;
import org.apache.reef.driver.task.TaskConfiguration;
import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
import org.apache.reef.examples.group.bgd.data.parser.Parser;
import org.apache.reef.examples.group.bgd.data.parser.SVMLightParser;
import org.apache.reef.examples.group.bgd.loss.LossFunction;
import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.DescentDirectionBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.LineSearchEvaluationsReducer;
import org.apache.reef.examples.group.bgd.operatornames.LossAndGradientReducer;
import org.apache.reef.examples.group.bgd.operatornames.MinEtaBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.ModelAndDescentDirectionBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.ModelBroadcaster;
import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
import org.apache.reef.examples.group.bgd.parameters.BGDControlParameters;
import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure;
import org.apache.reef.io.data.loading.api.DataLoadingService;
import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
import org.apache.reef.io.network.group.api.driver.GroupCommDriver;
import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
import org.apache.reef.io.serialization.Codec;
import org.apache.reef.io.serialization.SerializableCodec;
import org.apache.reef.poison.PoisonedConfiguration;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.Configurations;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.tang.formats.ConfigurationSerializer;
import org.apache.reef.wake.EventHandler;

@DriverSide
@Unit
/* loaded from: input_file:org/apache/reef/examples/group/bgd/BGDDriver.class */
public class BGDDriver {
    private static final Logger LOG = Logger.getLogger(BGDDriver.class.getName());
    private static final Tang TANG = Tang.Factory.getTang();

    /** Crash probability handed to the task poison configuration. */
    private static final double STARTUP_FAILURE_PROB = 0.01d;

    private final DataLoadingService dataLoadingService;
    private final GroupCommDriver groupCommDriver;
    private final ConfigurationSerializer confSerializer;
    private final CommunicationGroupDriver communicationsGroup;

    /** Flipped to true by the first call to {@code masterTaskSubmitted()}. */
    private final AtomicBoolean masterSubmitted = new AtomicBoolean(false);

    /** Source of the numeric suffix used when building slave task ids. */
    private final AtomicInteger slaveIds = new AtomicInteger(0);

    /**
     * Running tasks keyed by task id. Every read and write in this class happens
     * while synchronized on this map, which is also the lock that guards the
     * job-completion check.
     */
    private final Map<String, RunningTask> runningTasks = new HashMap<String, RunningTask>();

    /** Set to true once the task with id "MasterTask" has completed. */
    private final AtomicBoolean jobComplete = new AtomicBoolean(false);

    /** Decodes the byte[] result returned by a completed task into a list of loss values. */
    private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<ArrayList<Double>>();

    private final BGDControlParameters bgdControlParameters;

    /**
     * Id of the group-communication context that was submitted on a context which
     * is not a data-loaded context; a task is only considered for the master role
     * when it arrives on the context with this id.
     */
    private String communicationsGroupMasterContextId;

    /* loaded from: input_file:org/apache/reef/examples/group/bgd/BGDDriver$ContextActiveHandler.class */
    final class ContextActiveHandler implements EventHandler<ActiveContext> {
        static final /* synthetic */ boolean $assertionsDisabled;

        ContextActiveHandler() {
        }

        public void onNext(ActiveContext activeContext) {
            BGDDriver.LOG.log(Level.INFO, "Got active context: {0}", activeContext.getId());
            if (jobRunning(activeContext)) {
                if (BGDDriver.this.groupCommDriver.isConfigured(activeContext)) {
                    submitTask(activeContext);
                } else {
                    submitGroupCommunicationsService(activeContext);
                }
            }
        }

        private void submitGroupCommunicationsService(ActiveContext activeContext) {
            Configuration build;
            Configuration contextConfiguration = BGDDriver.this.groupCommDriver.getContextConfiguration();
            String contextId = BGDDriver.this.getContextId(contextConfiguration);
            if (BGDDriver.this.dataLoadingService.isDataLoadedContext(activeContext)) {
                build = Tang.Factory.getTang().newConfigurationBuilder(new Configuration[]{BGDDriver.this.groupCommDriver.getServiceConfiguration(), ServiceConfiguration.CONF.set(ServiceConfiguration.SERVICES, ExampleList.class).build()}).bindImplementation(Parser.class, SVMLightParser.class).build();
            } else {
                BGDDriver.this.communicationsGroupMasterContextId = contextId;
                build = BGDDriver.this.groupCommDriver.getServiceConfiguration();
            }
            BGDDriver.LOG.log(Level.FINEST, "Submit GCContext conf: {0} and Service conf: {1}", new Object[]{BGDDriver.this.confSerializer.toString(contextConfiguration), BGDDriver.this.confSerializer.toString(build)});
            activeContext.submitContextAndService(contextConfiguration, build);
        }

        private void submitTask(ActiveContext activeContext) {
            Configuration slaveTaskConfiguration;
            if (!$assertionsDisabled && !BGDDriver.this.groupCommDriver.isConfigured(activeContext)) {
                throw new AssertionError();
            }
            if (!activeContext.getId().equals(BGDDriver.this.communicationsGroupMasterContextId) || BGDDriver.this.masterTaskSubmitted()) {
                slaveTaskConfiguration = BGDDriver.this.getSlaveTaskConfiguration(BGDDriver.this.getSlaveId(activeContext));
                BGDDriver.LOG.info("Submitting SlaveTask conf");
            } else {
                slaveTaskConfiguration = BGDDriver.this.getMasterTaskConfiguration();
                BGDDriver.LOG.info("Submitting MasterTask conf");
            }
            BGDDriver.this.communicationsGroup.addTask(slaveTaskConfiguration);
            Configuration taskConfiguration = BGDDriver.this.groupCommDriver.getTaskConfiguration(slaveTaskConfiguration);
            BGDDriver.LOG.log(Level.FINEST, "{0}", BGDDriver.this.confSerializer.toString(taskConfiguration));
            activeContext.submitTask(taskConfiguration);
        }

        private boolean jobRunning(ActiveContext activeContext) {
            synchronized (BGDDriver.this.runningTasks) {
                if (!BGDDriver.this.jobComplete.get()) {
                    return true;
                }
                BGDDriver.LOG.log(Level.INFO, "Job complete. Not submitting any task. Closing context {0}", activeContext);
                activeContext.close();
                return false;
            }
        }

        static {
            $assertionsDisabled = !BGDDriver.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:org/apache/reef/examples/group/bgd/BGDDriver$TaskCompletedHandler.class */
    /**
     * Reacts to completed tasks: logs any loss values the task returned, closes the
     * task's context if the task was still tracked, and, when the completed task is
     * the master, marks the job complete and closes the contexts of all remaining
     * tracked tasks.
     */
    final class TaskCompletedHandler implements EventHandler<CompletedTask> {
        TaskCompletedHandler() {
        }

        public void onNext(CompletedTask completedTask) {
            final String taskId = completedTask.getId();
            LOG.log(Level.INFO, "Got CompletedTask: {0}", taskId);

            // A task may return no payload; only decode when there is one.
            final byte[] payload = completedTask.get();
            if (payload != null) {
                for (final Double loss : lossCodec.decode(payload)) {
                    LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
                }
            }

            synchronized (runningTasks) {
                LOG.log(Level.INFO, "Acquired lock on runningTasks. Removing {0}", taskId);
                final RunningTask tracked = runningTasks.remove(taskId);
                if (tracked == null) {
                    LOG.log(Level.INFO, "Master must have closed active context already for task {0}", taskId);
                } else {
                    LOG.log(Level.INFO, "Closing active context: {0}", completedTask.getActiveContext().getId());
                    completedTask.getActiveContext().close();
                }

                if (!"MasterTask".equals(taskId)) {
                    return;
                }

                // The master finishing ends the whole job: shut down everything still tracked.
                jobComplete.set(true);
                LOG.log(Level.INFO, "Master(=>Job) complete. Closing other running tasks: {0}", runningTasks.values());
                for (final RunningTask remaining : runningTasks.values()) {
                    remaining.getActiveContext().close();
                }
                LOG.finest("Clearing runningTasks");
                runningTasks.clear();
            }
        }
    }

    /* loaded from: input_file:org/apache/reef/examples/group/bgd/BGDDriver$TaskFailedHandler.class */
    final class TaskFailedHandler implements EventHandler<FailedTask> {
        TaskFailedHandler() {
        }

        public void onNext(FailedTask failedTask) {
            String id = failedTask.getId();
            BGDDriver.LOG.log(Level.WARNING, "Got failed Task: " + id);
            if (jobRunning(id)) {
                ActiveContext activeContext = (ActiveContext) failedTask.getActiveContext().get();
                Configuration taskConfiguration = BGDDriver.this.groupCommDriver.getTaskConfiguration(BGDDriver.this.getSlaveTaskConfiguration(id));
                BGDDriver.LOG.log(Level.FINEST, "Submit SlaveTask conf: {0}", BGDDriver.this.confSerializer.toString(taskConfiguration));
                activeContext.submitTask(taskConfiguration);
            }
        }

        private boolean jobRunning(String str) {
            synchronized (BGDDriver.this.runningTasks) {
                if (!BGDDriver.this.jobComplete.get()) {
                    return true;
                }
                RunningTask runningTask = (RunningTask) BGDDriver.this.runningTasks.remove(str);
                BGDDriver.LOG.log(Level.INFO, "Job has completed. Not resubmitting");
                if (runningTask != null) {
                    BGDDriver.LOG.log(Level.INFO, "Closing activecontext");
                    runningTask.getActiveContext().close();
                } else {
                    BGDDriver.LOG.log(Level.INFO, "Master must have closed my context");
                }
                return false;
            }
        }
    }

    /* loaded from: input_file:org/apache/reef/examples/group/bgd/BGDDriver$TaskRunningHandler.class */
    /**
     * Reacts to tasks entering the running state: tracks the task while the job is
     * in progress, or closes its context right away if the job has already completed.
     */
    final class TaskRunningHandler implements EventHandler<RunningTask> {
        TaskRunningHandler() {
        }

        public void onNext(RunningTask runningTask) {
            synchronized (runningTasks) {
                if (!jobComplete.get()) {
                    LOG.log(Level.INFO, "Job has not completed yet. Adding to runningTasks: {0}", runningTask);
                    runningTasks.put(runningTask.getId(), runningTask);
                    return;
                }
                // Late arrival: the master already finished, so this task has nothing to do.
                LOG.log(Level.INFO, "Job complete. Closing context: {0}", runningTask.getActiveContext().getId());
                runningTask.getActiveContext().close();
            }
        }
    }

    /**
     * Creates the driver and sets up the communication group with all broadcast and
     * reduce operators. Every broadcast is sent by, and every reduce is received by,
     * the task with id "MasterTask".
     *
     * @param dataLoadingService source of the partition count and the data-loaded-context check
     * @param groupCommDriver factory for the communication group and its configurations
     * @param configurationSerializer used to render configurations for logging
     * @param bGDControlParameters job control parameters (ramp-up, minimum parts, and others)
     */
    @Inject
    public BGDDriver(DataLoadingService dataLoadingService, GroupCommDriver groupCommDriver, ConfigurationSerializer configurationSerializer, BGDControlParameters bGDControlParameters) {
        this.dataLoadingService = dataLoadingService;
        this.groupCommDriver = groupCommDriver;
        this.confSerializer = configurationSerializer;
        this.bgdControlParameters = bGDControlParameters;

        final int partCount;
        if (bGDControlParameters.isRampup()) {
            partCount = bGDControlParameters.getMinParts();
        } else {
            partCount = dataLoadingService.getNumberOfPartitions();
        }
        // One extra member on top of the part count.
        final int groupSize = partCount + 1;

        this.communicationsGroup = this.groupCommDriver.newCommunicationGroup(AllCommunicationGroup.class, groupSize);
        LOG.log(Level.INFO, "Obtained entire communication group: start with {0} partitions", Integer.valueOf(groupSize));

        this.communicationsGroup
                .addBroadcast(ControlMessageBroadcaster.class,
                        BroadcastOperatorSpec.newBuilder()
                                .setSenderId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .build())
                .addBroadcast(ModelBroadcaster.class,
                        BroadcastOperatorSpec.newBuilder()
                                .setSenderId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .build())
                .addReduce(LossAndGradientReducer.class,
                        ReduceOperatorSpec.newBuilder()
                                .setReceiverId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .setReduceFunctionClass(LossAndGradientReduceFunction.class)
                                .build())
                .addBroadcast(ModelAndDescentDirectionBroadcaster.class,
                        BroadcastOperatorSpec.newBuilder()
                                .setSenderId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .build())
                .addBroadcast(DescentDirectionBroadcaster.class,
                        BroadcastOperatorSpec.newBuilder()
                                .setSenderId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .build())
                .addReduce(LineSearchEvaluationsReducer.class,
                        ReduceOperatorSpec.newBuilder()
                                .setReceiverId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .setReduceFunctionClass(LineSearchReduceFunction.class)
                                .build())
                .addBroadcast(MinEtaBroadcaster.class,
                        BroadcastOperatorSpec.newBuilder()
                                .setSenderId("MasterTask")
                                .setDataCodecClass(SerializableCodec.class)
                                .build())
                .finalise();

        LOG.log(Level.INFO, "Added operators to communicationsGroup");
    }

    /**
     * Builds the partial task configuration for the master task: the task
     * identifier and class, merged with the BGD control parameters.
     *
     * @return the merged master task configuration
     */
    public Configuration getMasterTaskConfiguration() {
        final Configuration masterTaskConf = TaskConfiguration.CONF
                .set(TaskConfiguration.IDENTIFIER, "MasterTask")
                .set(TaskConfiguration.TASK, MasterTask.class)
                .build();
        final Configuration controlConf = this.bgdControlParameters.getConfiguration();
        final Configuration[] parts = {masterTaskConf, controlConf};
        return Configurations.merge(parts);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds the partial task configuration for a slave task with the given id,
     * binding the model dimensions, the loss function and the failure probability.
     *
     * <p>The failure probability is {@code 1 - p^(1/n)}, where {@code p} is the
     * configured probability of a successful iteration and {@code n} is the number
     * of partitions reported by the data loading service.
     *
     * @param str the task identifier to assign
     * @return the slave task configuration
     */
    public Configuration getSlaveTaskConfiguration(String str) {
        final Configuration slaveTaskConf = TaskConfiguration.CONF
                .set(TaskConfiguration.IDENTIFIER, str)
                .set(TaskConfiguration.TASK, SlaveTask.class)
                .build();
        return Tang.Factory.getTang()
                .newConfigurationBuilder(new Configuration[]{slaveTaskConf})
                .bindNamedParameter(ModelDimensions.class,
                        "" + this.bgdControlParameters.getDimensions())
                .bindImplementation(LossFunction.class,
                        this.bgdControlParameters.getLossFunction())
                .bindNamedParameter(ProbabilityOfFailure.class,
                        Double.toString(1.0d - Math.pow(
                                this.bgdControlParameters.getProbOfSuccessfulIteration(),
                                1.0d / this.dataLoadingService.getNumberOfPartitions())))
                .build();
    }

    /**
     * Builds a poisoned task configuration using the startup failure probability
     * and a crash timeout of 1. Not referenced elsewhere in the visible code.
     *
     * @return the poisoned task configuration
     */
    private Configuration getTaskPoisonConfiguration() {
        return PoisonedConfiguration.TASK_CONF
                .set(PoisonedConfiguration.CRASH_PROBABILITY, Double.valueOf(STARTUP_FAILURE_PROB))
                .set(PoisonedConfiguration.CRASH_TIMEOUT, 1)
                .build();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Extracts the context identifier bound in the given context configuration.
     *
     * @param configuration a context configuration
     * @return the value of the {@code ContextIdentifier} named parameter
     * @throws RuntimeException if the identifier cannot be injected
     */
    public String getContextId(Configuration configuration) {
        final String contextId;
        try {
            contextId = (String) TANG.newInjector(configuration).getNamedInstance(ContextIdentifier.class);
        } catch (InjectionException injectionFailure) {
            throw new RuntimeException("Unable to inject context identifier from context conf", injectionFailure);
        }
        return contextId;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Produces the next slave task id of the form {@code SlaveTask-<n>}, where
     * {@code n} comes from a shared counter.
     *
     * @param activeContext not used when generating the id
     * @return a fresh slave task id
     */
    public String getSlaveId(ActiveContext activeContext) {
        final int sequenceNumber = this.slaveIds.getAndIncrement();
        return "SlaveTask-" + sequenceNumber;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Atomically claims the master slot and reports whether it had already been claimed.
     *
     * @return false for the first caller (which thereby marks the master as
     *         submitted), true for every later caller
     */
    public boolean masterTaskSubmitted() {
        // The previous flag value is exactly "had the master already been submitted".
        final boolean alreadySubmitted = this.masterSubmitted.getAndSet(true);
        return alreadySubmitted;
    }
}
