package org.apache.reef.examples.group.bgd;

import java.util.List;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.apache.reef.examples.group.bgd.data.Example;
import org.apache.reef.examples.group.bgd.loss.LossFunction;
import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.DescentDirectionBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.LineSearchEvaluationsReducer;
import org.apache.reef.examples.group.bgd.operatornames.LossAndGradientReducer;
import org.apache.reef.examples.group.bgd.operatornames.MinEtaBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.ModelAndDescentDirectionBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.ModelBroadcaster;
import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure;
import org.apache.reef.examples.group.bgd.utils.StepSizes;
import org.apache.reef.examples.group.utils.math.DenseVector;
import org.apache.reef.examples.group.utils.math.Vector;
import org.apache.reef.io.network.group.api.operators.Broadcast;
import org.apache.reef.io.network.group.api.operators.Reduce;
import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
import org.apache.reef.io.network.group.api.task.GroupCommClient;
import org.apache.reef.io.network.util.Pair;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.task.Task;

/* loaded from: input_file:org/apache/reef/examples/group/bgd/SlaveTask.class */
/**
 * Worker ("slave") task of the batch gradient descent example.
 *
 * <p>The task loads its local {@link Example}s and then loops. On each iteration it receives a
 * {@link ControlMessages} value through the control-message broadcast of the
 * {@link AllCommunicationGroup} communication group. Depending on the message it does one of the
 * following:
 * <ul>
 *   <li>sums loss and gradient over the local examples and sends them to the loss/gradient
 *       reducer ({@code ComputeGradientWithModel}, {@code ComputeGradientWithMinEta});</li>
 *   <li>sums the loss at every candidate step size from {@link StepSizes} along the current
 *       descent direction and sends those values to the line-search reducer
 *       ({@code DoLineSearch}, {@code DoLineSearchWithModel});</li>
 *   <li>leaves the loop ({@code Stop}).</li>
 * </ul>
 *
 * <p>Before doing the work for any non-{@code Stop} message, the task may throw a
 * {@link RuntimeException} with the configured {@link ProbabilityOfFailure}. This simulates
 * failures.
 */
public class SlaveTask implements Task {
    private static final Logger LOG = Logger.getLogger(SlaveTask.class.getName());

    /** Probability that a single call to {@link #failPerhaps()} throws. */
    private final double failureProb;
    private final CommunicationGroupClient communicationGroup;
    private final Broadcast.Receiver<ControlMessages> controlMessageBroadcaster;
    private final Broadcast.Receiver<Vector> modelBroadcaster;
    /** Receives ((summed loss, example count), summed gradient). */
    private final Reduce.Sender<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
    private final Broadcast.Receiver<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
    private final Broadcast.Receiver<Vector> descentDirectionBroadcaster;
    /** Receives (loss per candidate step size, example count). */
    private final Reduce.Sender<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
    private final Broadcast.Receiver<Double> minEtaBroadcaster;
    private final ExampleList dataSet;
    private final LossFunction lossFunction;
    /** Source of the candidate step sizes evaluated during line search. */
    private final StepSizes ts;

    /** Local training examples. Loaded from {@link #dataSet} on demand. */
    private List<Example> examples;
    /** Current model, as last received or updated. */
    private Vector model;
    /** Current descent direction, as last received (scaled in place on min-eta updates). */
    private Vector descentDirection;

    /**
     * Creates the task and resolves every group-communication operator it uses.
     *
     * @param groupCommClient      client used to look up the {@link AllCommunicationGroup} group
     * @param exampleList          provider of the local training examples
     * @param lossFunction         loss used for loss, gradient and line-search evaluations
     * @param probabilityOfFailure probability that each unit of work throws a simulated failure
     * @param stepSizes            provider of the line-search candidate step sizes
     */
    @Inject
    public SlaveTask(final GroupCommClient groupCommClient,
                     final ExampleList exampleList,
                     final LossFunction lossFunction,
                     @Parameter(ProbabilityOfFailure.class) final double probabilityOfFailure,
                     final StepSizes stepSizes) {
        this.dataSet = exampleList;
        this.lossFunction = lossFunction;
        this.failureProb = probabilityOfFailure;
        LOG.info("Using pFailure=" + this.failureProb);
        this.ts = stepSizes;
        this.communicationGroup = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
        this.controlMessageBroadcaster =
            communicationGroup.getBroadcastReceiver(ControlMessageBroadcaster.class);
        this.modelBroadcaster = communicationGroup.getBroadcastReceiver(ModelBroadcaster.class);
        this.lossAndGradientReducer =
            communicationGroup.getReduceSender(LossAndGradientReducer.class);
        this.modelAndDescentDirectionBroadcaster =
            communicationGroup.getBroadcastReceiver(ModelAndDescentDirectionBroadcaster.class);
        this.descentDirectionBroadcaster =
            communicationGroup.getBroadcastReceiver(DescentDirectionBroadcaster.class);
        this.lineSearchEvaluationsReducer =
            communicationGroup.getReduceSender(LineSearchEvaluationsReducer.class);
        this.minEtaBroadcaster = communicationGroup.getBroadcastReceiver(MinEtaBroadcaster.class);
    }

    /**
     * Runs the control loop until a {@code Stop} message is received.
     *
     * @param memento ignored
     * @return always {@code null}
     * @throws RuntimeException for a simulated failure (see {@link #failPerhaps()})
     * @throws Exception        if a group-communication operation fails
     */
    @Override
    public byte[] call(final byte[] memento) throws Exception {
        // Load the data up front rather than on the first computation.
        loadData();

        boolean repeat = true;
        while (repeat) {
            final ControlMessages controlMessage = controlMessageBroadcaster.receive();
            switch (controlMessage) {
                case Stop:
                    repeat = false;
                    break;

                case ComputeGradientWithModel:
                    failPerhaps();
                    model = modelBroadcaster.receive();
                    lossAndGradientReducer.send(computeLossAndGradient());
                    break;

                case ComputeGradientWithMinEta: {
                    failPerhaps();
                    final double minEta = minEtaBroadcaster.receive();
                    // Scale the stored direction by the broadcast step and apply it to the stored
                    // model. Return values are unused, so scale/add are expected to mutate in place.
                    assert descentDirection != null;
                    descentDirection.scale(minEta);
                    assert model != null;
                    model.add(descentDirection);
                    lossAndGradientReducer.send(computeLossAndGradient());
                    break;
                }

                case DoLineSearch:
                    failPerhaps();
                    descentDirection = descentDirectionBroadcaster.receive();
                    lineSearchEvaluationsReducer.send(lineSearchEvals());
                    break;

                case DoLineSearchWithModel: {
                    failPerhaps();
                    final Pair<Vector, Vector> modelAndDirection =
                        modelAndDescentDirectionBroadcaster.receive();
                    model = modelAndDirection.getFirst();
                    descentDirection = modelAndDirection.getSecond();
                    lineSearchEvaluationsReducer.send(lineSearchEvals());
                    break;
                }

                default:
                    // Previously ignored silently. Keep looping, but make the anomaly visible.
                    LOG.warning("Ignoring unexpected control message: " + controlMessage);
                    break;
            }
        }
        return null;
    }

    /**
     * Throws a {@link RuntimeException} with probability {@link #failureProb}.
     */
    private void failPerhaps() {
        if (Math.random() < failureProb) {
            throw new RuntimeException("Simulated Failure");
        }
    }

    /**
     * Evaluates the summed loss over the local examples for every candidate step size
     * {@code t} in {@code ts.getT()}.
     *
     * <p>For each example, the prediction at step {@code t} is taken as
     * {@code predict(model) + t * predict(descentDirection)}. This matches the prediction at
     * {@code model + t * descentDirection} only if {@code predict} is linear in the model
     * (presumably the case here; TODO confirm). Both per-example predictions are computed once
     * and reused for every {@code t}.
     *
     * @return the loss for each candidate step size (same order as {@code ts.getT()}), paired
     *         with the number of local examples
     */
    private Pair<Vector, Integer> lineSearchEvals() {
        if (examples == null) {
            loadData();
        }
        final int numExamples = examples.size();

        final DenseVector modelPredictions = new DenseVector(numExamples);
        final DenseVector directionPredictions = new DenseVector(numExamples);
        for (int i = 0; i < numExamples; i++) {
            final Example example = examples.get(i);
            modelPredictions.set(i, example.predict(model));
            directionPredictions.set(i, example.predict(descentDirection));
        }

        final double[] stepSizes = ts.getT();
        final DenseVector lossPerStep = new DenseVector(stepSizes.length);
        int stepIndex = 0;
        for (final double t : stepSizes) {
            double loss = 0.0;
            for (int i = 0; i < numExamples; i++) {
                loss += lossFunction.computeLoss(examples.get(i).getLabel(),
                    modelPredictions.get(i) + t * directionPredictions.get(i));
            }
            lossPerStep.set(stepIndex++, loss);
        }
        return new Pair<>(lossPerStep, numExamples);
    }

    /**
     * Computes the loss and the gradient of the current model, both summed over the local
     * examples.
     *
     * @return ((summed loss, number of local examples), summed gradient)
     */
    private Pair<Pair<Double, Integer>, Vector> computeLossAndGradient() {
        if (examples == null) {
            loadData();
        }
        final DenseVector gradient = new DenseVector(model.size());
        double loss = 0.0;
        for (final Example example : examples) {
            final double prediction = example.predict(model);
            example.addGradient(gradient,
                lossFunction.computeGradient(example.getLabel(), prediction));
            loss += lossFunction.computeLoss(example.getLabel(), prediction);
        }
        final Pair<Double, Integer> lossAndCount = new Pair<>(loss, examples.size());
        return new Pair<>(lossAndCount, gradient);
    }

    /** (Re)loads the local examples from {@link #dataSet}. */
    private void loadData() {
        LOG.info("Loading data");
        examples = dataSet.getExamples();
    }
}
