package org.apache.reef.examples.group.bgd;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.DescentDirectionBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.LineSearchEvaluationsReducer;
import org.apache.reef.examples.group.bgd.operatornames.LossAndGradientReducer;
import org.apache.reef.examples.group.bgd.operatornames.MinEtaBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.ModelAndDescentDirectionBroadcaster;
import org.apache.reef.examples.group.bgd.operatornames.ModelBroadcaster;
import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
import org.apache.reef.examples.group.bgd.parameters.EnableRampup;
import org.apache.reef.examples.group.bgd.parameters.Iterations;
import org.apache.reef.examples.group.bgd.parameters.Lambda;
import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
import org.apache.reef.examples.group.bgd.utils.StepSizes;
import org.apache.reef.examples.group.utils.math.DenseVector;
import org.apache.reef.examples.group.utils.math.Vector;
import org.apache.reef.examples.group.utils.timer.Timer;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.Tuple;
import org.apache.reef.io.network.group.api.GroupChanges;
import org.apache.reef.io.network.group.api.operators.Broadcast;
import org.apache.reef.io.network.group.api.operators.Reduce;
import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
import org.apache.reef.io.network.group.api.task.GroupCommClient;
import org.apache.reef.io.network.util.Pair;
import org.apache.reef.io.serialization.Codec;
import org.apache.reef.io.serialization.SerializableCodec;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.task.Task;

/* loaded from: input_file:org/apache/reef/examples/group/bgd/MasterTask.class */
/**
 * Master task of the group-communication batch gradient descent (BGD) example.
 *
 * <p>Each iteration of {@link #call(byte[])}:
 * <ol>
 *   <li>broadcasts the model (or only the last chosen step size) and reduces the summed loss,
 *       example count and gradient from the communication group, averaging them;</li>
 *   <li>derives a regularized descent direction from the gradient and the current model;</li>
 *   <li>broadcasts that direction and reduces loss evaluations for every candidate step size
 *       returned by {@link StepSizes#getT()};</li>
 *   <li>adds the L2 regularization term to those evaluations and moves the model by the
 *       candidate step size with the smallest regularized value.</li>
 * </ol>
 * Iteration stops once the iteration cap is reached or the descent direction's norm
 * ({@code Vector.norm2()}) is at most 1e-3. The master then broadcasts
 * {@link ControlMessages#Stop} and returns the serialized per-iteration objective values.
 *
 * <p>The full model is broadcast on the first round and again after every detected topology
 * change; otherwise only the incremental quantity (step size or descent direction) is sent.
 */
public class MasterTask implements Task {
    public static final String TASK_ID = "MasterTask";
    private static final Logger LOG = Logger.getLogger(MasterTask.class.getName());

    /** Iteration stops once the descent direction's norm falls to or below this value. */
    private static final double CONVERGENCE_TOLERANCE = 1e-3;

    private final CommunicationGroupClient communicationGroupClient;
    private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster;
    private final Broadcast.Sender<Vector> modelBroadcaster;
    private final Reduce.Receiver<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
    private final Broadcast.Sender<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
    private final Broadcast.Sender<Vector> descentDirectionBroadcaster;
    private final Reduce.Receiver<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
    private final Broadcast.Sender<Double> minEtaBroadcaster;
    /**
     * When {@code false}, a broadcast/reduce round is repeated whenever the topology changed after it.
     * When {@code true}, that round's result is kept and only a missing reduce result forces a retry.
     */
    private final boolean ignoreAndContinue;
    private final StepSizes ts;
    private final double lambda;
    private final int maxIters;
    private final Vector model;
    /** Objective value (regularized loss) of every completed iteration, in order. */
    private final ArrayList<Double> losses = new ArrayList<>();
    private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>();
    /** Whether the next broadcast must carry the full model (first round / after a topology change). */
    private boolean sendModel = true;
    /** Step size chosen by the most recent line search. */
    private double minEta = 0.0;

    /**
     * Creates the master task and resolves all group-communication operators it uses.
     *
     * @param groupCommClient client used to obtain the {@link AllCommunicationGroup} and its operators
     * @param dimensions      number of model dimensions; the model is a {@link DenseVector} of this size
     * @param lambda          L2 regularization weight
     * @param maxIters        maximum number of iterations
     * @param rampup          stored as {@code ignoreAndContinue}; see that field
     * @param stepSizes       candidate step sizes evaluated by the line search
     */
    @Inject
    public MasterTask(final GroupCommClient groupCommClient,
                      @Parameter(ModelDimensions.class) final int dimensions,
                      @Parameter(Lambda.class) final double lambda,
                      @Parameter(Iterations.class) final int maxIters,
                      @Parameter(EnableRampup.class) final boolean rampup,
                      final StepSizes stepSizes) {
        this.lambda = lambda;
        this.maxIters = maxIters;
        this.ts = stepSizes;
        this.ignoreAndContinue = rampup;
        this.model = new DenseVector(dimensions);
        this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
        this.controlMessageBroadcaster =
            this.communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class);
        this.modelBroadcaster = this.communicationGroupClient.getBroadcastSender(ModelBroadcaster.class);
        this.lossAndGradientReducer = this.communicationGroupClient.getReduceReceiver(LossAndGradientReducer.class);
        this.modelAndDescentDirectionBroadcaster =
            this.communicationGroupClient.getBroadcastSender(ModelAndDescentDirectionBroadcaster.class);
        this.descentDirectionBroadcaster =
            this.communicationGroupClient.getBroadcastSender(DescentDirectionBroadcaster.class);
        this.lineSearchEvaluationsReducer =
            this.communicationGroupClient.getReduceReceiver(LineSearchEvaluationsReducer.class);
        this.minEtaBroadcaster = this.communicationGroupClient.getBroadcastSender(MinEtaBroadcaster.class);
    }

    /**
     * Runs gradient-descent iterations until {@link #converged(int, double)} holds, then
     * broadcasts {@link ControlMessages#Stop} and logs every recorded objective value.
     *
     * @param memento ignored
     * @return the per-iteration objective values, encoded with a {@link SerializableCodec}
     * @throws Exception propagated from the group-communication operations
     */
    @Override // org.apache.reef.task.Task
    public byte[] call(final byte[] memento) throws Exception {
        double descentNorm = Double.MAX_VALUE;
        for (int iteration = 1; !converged(iteration, descentNorm); iteration++) {
            try (Timer timer = new Timer("Current Iteration(" + iteration + ")")) {
                final Pair<Double, Vector> lossAndGradient = computeLossAndGradient();
                this.losses.add(lossAndGradient.getFirst());
                final Vector descentDirection = getDescentDirection(lossAndGradient.getSecond());
                updateModel(descentDirection);
                descentNorm = descentDirection.norm2();
            }
        }
        LOG.log(Level.INFO, "OUT: Stop");
        this.controlMessageBroadcaster.send(ControlMessages.Stop);
        for (final Double loss : this.losses) {
            LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
        }
        return this.lossCodec.encode(this.losses);
    }

    /**
     * Performs a line search along {@code descentDirection}, stores the winning step size in
     * {@link #minEta} and applies {@code model.multAdd(minEta, descentDirection)}.
     */
    private void updateModel(final Vector descentDirection) throws NetworkException, InterruptedException {
        try (Timer timer = new Timer("GetDescentDirection + FindMinEta + UpdateModel")) {
            final Vector lineSearchEvals = lineSearch(descentDirection);
            this.minEta = findMinEta(this.model, descentDirection, lineSearchEvals);
            this.model.multAdd(this.minEta, descentDirection);
        }
        LOG.log(Level.INFO, "OUT: New Model = {0}", this.model);
    }

    /**
     * Asks the group to evaluate the loss for every candidate step size along
     * {@code descentDirection} and returns those evaluations divided by the example count.
     *
     * <p>The model is sent together with the direction when {@link #sendModel} is set. The round is
     * retried while the reducer yields {@code null}. Unless {@link #ignoreAndContinue} is set, the
     * round is also retried after a topology change, in which case the model is re-sent.
     */
    private Vector lineSearch(final Vector descentDirection) throws NetworkException, InterruptedException {
        Vector lineSearchResults = null;
        boolean noResult;
        do {
            try (Timer timer = new Timer("LineSearch - Broadcast("
                    + (this.sendModel ? "ModelAndDescentDirection" : "DescentDirection")
                    + ") + Reduce(LossEvalsInLineSearch)")) {
                if (this.sendModel) {
                    LOG.log(Level.INFO, "OUT: DoLineSearchWithModel");
                    this.controlMessageBroadcaster.send(ControlMessages.DoLineSearchWithModel);
                    this.modelAndDescentDirectionBroadcaster.send(new Pair<>(this.model, descentDirection));
                } else {
                    LOG.log(Level.INFO, "OUT: DoLineSearch");
                    this.controlMessageBroadcaster.send(ControlMessages.DoLineSearch);
                    this.descentDirectionBroadcaster.send(descentDirection);
                }
                final Pair<Vector, Integer> reduced = this.lineSearchEvaluationsReducer.reduce();
                // A null reduce result is treated as "nothing received"; the round is retried.
                noResult = reduced == null;
                if (!noResult) {
                    final int numExamples = reduced.getSecond();
                    lineSearchResults = reduced.getFirst();
                    lineSearchResults.scale(1.0 / numExamples);
                    LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
                    LOG.log(Level.INFO, "OUT: LineSearchEvals: {0}", lineSearchResults);
                }
            }
            // Checked outside the timed section so a failure here cannot double-close the timer.
            this.sendModel = chkAndUpdate();
        } while (noResult || (!this.ignoreAndContinue && this.sendModel));
        return lineSearchResults;
    }

    /**
     * Asks the group for the loss and gradient of the current model.
     *
     * <p>The method broadcasts either the full model (when {@link #sendModel} is set) or only the
     * last step size {@link #minEta}. Both the reduced loss and the reduced gradient are divided by
     * the reduced example count.
     *
     * @return a pair of (objective = {@code lambda/2 * model.norm2Sqr() + averaged loss},
     *         averaged gradient). The round is retried under the same conditions as
     *         {@link #lineSearch(Vector)}.
     */
    private Pair<Double, Vector> computeLossAndGradient() throws NetworkException, InterruptedException {
        Pair<Double, Vector> result = null;
        boolean noResult;
        do {
            try (Timer timer = new Timer("Broadcast(" + (this.sendModel ? "Model" : "MinEta")
                    + ") + Reduce(LossAndGradient)")) {
                if (this.sendModel) {
                    LOG.log(Level.INFO, "OUT: ComputeGradientWithModel");
                    this.controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithModel);
                    this.modelBroadcaster.send(this.model);
                } else {
                    LOG.log(Level.INFO, "OUT: ComputeGradientWithMinEta");
                    this.controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithMinEta);
                    this.minEtaBroadcaster.send(this.minEta);
                }
                final Pair<Pair<Double, Integer>, Vector> reduced = this.lossAndGradientReducer.reduce();
                // A null reduce result is treated as "nothing received"; the round is retried.
                noResult = reduced == null;
                if (!noResult) {
                    final int numExamples = reduced.getFirst().getSecond();
                    LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
                    final double loss = reduced.getFirst().getFirst() / numExamples;
                    LOG.log(Level.INFO, "OUT: Loss: {0}", loss);
                    final double objective = ((this.lambda / 2.0) * this.model.norm2Sqr()) + loss;
                    LOG.log(Level.INFO, "OUT: Objective Func Value: {0}", objective);
                    final Vector gradient = reduced.getSecond();
                    gradient.scale(1.0 / numExamples);
                    LOG.log(Level.INFO, "OUT: Gradient: {0}", gradient);
                    result = new Pair<>(objective, gradient);
                }
            }
            // Checked outside the timed section so a failure here cannot double-close the timer.
            this.sendModel = chkAndUpdate();
        } while (noResult || (!this.ignoreAndContinue && this.sendModel));
        return result;
    }

    /**
     * Queries the communication group for topology changes and applies them if any exist.
     *
     * @return {@code true} if changes existed and {@code updateTopology()} was called
     */
    private boolean chkAndUpdate() {
        final long queryStart = System.currentTimeMillis();
        final GroupChanges topologyChanges = this.communicationGroupClient.getTopologyChanges();
        LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + secondsSince(queryStart) + " sec");
        if (!topologyChanges.exist()) {
            LOG.log(Level.INFO, "OUT: No changes in topology exist. So not updating topology");
            return false;
        }
        LOG.log(Level.INFO, "OUT: There exist topology changes. Asking to update Topology");
        final long updateStart = System.currentTimeMillis();
        this.communicationGroupClient.updateTopology();
        LOG.log(Level.INFO, "OUT: Time to update Topology = " + secondsSince(updateStart) + " sec");
        return true;
    }

    /** Returns the wall-clock seconds elapsed since {@code startMillis}. */
    private static double secondsSince(final long startMillis) {
        return (System.currentTimeMillis() - startMillis) / 1000.0;
    }

    /** Returns whether the iteration cap is reached or the descent norm is within tolerance. */
    private boolean converged(final int iteration, final double descentNorm) {
        return iteration >= this.maxIters || Math.abs(descentNorm) <= CONVERGENCE_TOLERANCE;
    }

    /**
     * Adds the L2 term {@code lambda/2 * (|w|^2 + t^2 |d|^2 + 2 t w.d)} to each line-search
     * evaluation in place and returns the candidate step size {@code t} with the smallest result.
     * That term equals {@code lambda/2 * |w + t d|^2}.
     *
     * @param weights          current model {@code w}
     * @param direction        descent direction {@code d}
     * @param lineSearchEvals  averaged loss per candidate step size; modified in place.
     *                         Assumed to have one entry per element of {@code ts.getT()}.
     * @return the step size from {@code ts.getT()} at the index returned by {@code min()}
     */
    private double findMinEta(final Vector weights, final Vector direction, final Vector lineSearchEvals) {
        final double weightsNormSqr = weights.norm2Sqr();
        final double directionNormSqr = direction.norm2Sqr();
        final double weightsDotDirection = weights.dot(direction);
        final double[] stepSizes = this.ts.getT();
        for (int idx = 0; idx < stepSizes.length; idx++) {
            final double t = stepSizes[idx];
            final double regularizer =
                (this.lambda / 2.0) * (weightsNormSqr + (t * t * directionNormSqr) + (2.0 * t * weightsDotDirection));
            lineSearchEvals.set(idx, lineSearchEvals.get(idx) + regularizer);
        }
        LOG.log(Level.INFO, "OUT: Regularized LineSearchEvals: {0}", lineSearchEvals);
        final Tuple<Integer, Double> minTuple = lineSearchEvals.min();
        LOG.log(Level.INFO, "OUT: MinTup: {0}", minTuple);
        final double minT = stepSizes[minTuple.getKey()];
        LOG.log(Level.INFO, "OUT: MinT: {0}", minT);
        return minT;
    }

    /**
     * Converts the averaged gradient into a descent direction in place.
     * It applies {@code gradient.multAdd(lambda, model)} and then negates the result.
     * Presumably this yields {@code -(gradient + lambda * model)}, assuming {@code multAdd}
     * adds a scaled vector in place.
     *
     * @param gradient averaged loss gradient; mutated and returned
     * @return the same vector instance, now holding the descent direction
     */
    private Vector getDescentDirection(final Vector gradient) {
        gradient.multAdd(this.lambda, this.model);
        gradient.scale(-1.0);
        LOG.log(Level.INFO, "OUT: DescentDirection: {0}", gradient);
        return gradient;
    }
}
