/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.math.minimize;

import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.minimize.AbstractMinimizer;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import java.util.Arrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Batch gradient descent minimizer with optional heavy-ball momentum, bold-driver learning-rate
 * adaptation, step-wise learning-rate annealing, and early stopping on convergence or divergence.
 *
 * <p>Instances are immutable; configure them through {@link GradientDescentBuilder}.
 */
public final class GradientDescent extends AbstractMinimizer {
    private static final Logger LOG = LogManager.getLogger(GradientDescent.class);
    /** Number of most recent cost values kept for the convergence/divergence checks. */
    private static final int COST_HISTORY = 3;
    /** Stop as soon as the newest cost is higher than the previous one. */
    private final boolean breakOnDivergence;
    /** Stop when |cost(t) - cost(t-1)| drops below this value. */
    private final double breakDifference;
    /** Heavy-ball momentum coefficient in [0, 1]; 0 disables momentum. */
    private final double momentum;
    /** Initial (base) learning rate. */
    private final double alpha;
    /** Whether the bold-driver learning-rate adaptation is enabled. */
    private final boolean boldDriver;
    /** Fraction by which alpha is reduced after a step that did not lower the cost. */
    private final double boldIncreasePercentage;
    /** Fraction by which alpha is grown after a step that lowered the cost. */
    private final double boldDecreasePercentage;
    /** Iteration interval for step-wise annealing; values <= 0 disable annealing. */
    private final int annealingIteration;

    private GradientDescent(GradientDescentBuilder builder) {
        this.alpha = builder.alpha;
        this.breakDifference = builder.breakDifference;
        this.momentum = builder.momentum;
        this.breakOnDivergence = builder.breakOnDivergence;
        this.boldDriver = builder.boldDriver;
        this.boldIncreasePercentage = builder.boldIncreasePercentage;
        this.boldDecreasePercentage = builder.boldDecreasePercentage;
        this.annealingIteration = builder.annealingIteration;
    }

    /**
     * Creates a plain gradient descent minimizer without momentum, bold driver or annealing.
     *
     * @param alpha the learning rate.
     * @param limit the absolute difference between two consecutive costs below which the
     *     optimization is treated as converged.
     */
    public GradientDescent(double alpha, double limit) {
        this(GradientDescentBuilder.create(alpha).breakOnDifference(limit));
    }

    /**
     * Runs at most {@code maxIterations} gradient descent steps starting at {@code pInput}.
     *
     * <p>Each iteration evaluates the cost and gradient at the current parameters and stops early
     * if the cost converged or (when configured) increased. Otherwise it optionally adapts the
     * learning rate (bold driver / annealing) and takes a step, adding heavy-ball momentum
     * {@code momentum * (theta - previousTheta)} when momentum is non-zero.
     *
     * @param f the cost function to minimize.
     * @param pInput the starting parameters.
     * @param maxIterations the maximum number of iterations.
     * @param verbose whether to log cost (and bold-driver learning rate) per iteration.
     * @return the parameters after the last step; when the loop stops early, the parameters whose
     *     cost triggered the stop.
     */
    @Override
    public final DoubleVector minimize(CostFunction f, DoubleVector pInput, int maxIterations, boolean verbose) {
        double[] lastCosts = new double[COST_HISTORY];
        // MAX_VALUE marks "no cost seen yet", so the first iterations never look converged
        Arrays.fill(lastCosts, Double.MAX_VALUE);
        final int lastIndex = lastCosts.length - 1;
        // previous iterate and the gradient evaluated at the iterate a rollback returns to
        DoubleVector lastTheta = null;
        DoubleVector lastGradient = null;
        DoubleVector theta = pInput;
        double alpha = this.alpha;
        for (int iteration = 0; iteration < maxIterations; ++iteration) {
            CostGradientTuple evaluateCost = f.evaluateCost(theta);
            double cost = evaluateCost.getCost();
            if (verbose) {
                LOG.info("Iteration " + iteration + " | Cost: " + cost);
            }
            shiftLeft(lastCosts);
            lastCosts[lastIndex] = cost;
            if (converged(lastCosts, this.breakDifference)
                    || (this.breakOnDivergence && ascending(lastCosts))) {
                break;
            }
            DoubleVector gradient = evaluateCost.getGradient();
            if (this.boldDriver) {
                if (lastGradient != null) {
                    if (getCostDifference(lastCosts) < 0.0) {
                        // the last step lowered the cost: be bolder
                        alpha += alpha * this.boldDecreasePercentage;
                    } else {
                        // the last step did not lower the cost: reject it, return to the previous
                        // iterate and retry with a smaller learning rate
                        theta = lastTheta;
                        gradient = lastGradient;
                        alpha -= alpha * this.boldIncreasePercentage;
                        // the rejected cost must not become the reference for the next comparison,
                        // otherwise a later point still worse than the restored one is accepted
                        lastCosts[lastIndex] = lastCosts[lastIndex - 1];
                    }
                    if (verbose) {
                        LOG.info("Iteration " + iteration + " | Alpha: " + alpha + "\n");
                    }
                }
                lastGradient = gradient;
            }
            if (this.annealingIteration > 0) {
                // integer division makes this a step schedule; note that it recomputes alpha from
                // the base rate and therefore overrides any bold-driver adjustment
                alpha = this.alpha / (1.0 + (double) (iteration / this.annealingIteration));
            }
            DoubleVector nextTheta = theta.subtract(gradient.multiply(alpha));
            if (lastTheta != null && this.momentum != 0.0) {
                // heavy-ball momentum: keep moving along the previous step (theta - lastTheta)
                nextTheta = nextTheta.add(theta.subtract(lastTheta).multiply(this.momentum));
            }
            lastTheta = theta;
            theta = nextTheta;
            this.onIterationFinished(iteration, cost, theta);
        }
        return theta;
    }

    /**
     * Convenience wrapper: minimizes {@code f} with a plain {@link #GradientDescent(double, double)}.
     *
     * @param f the cost function to minimize.
     * @param pInput the starting parameters.
     * @param alpha the learning rate.
     * @param limit the convergence threshold on the cost difference.
     * @param length the maximum number of iterations.
     * @param verbose whether to log progress.
     * @return the optimized parameters.
     */
    public static DoubleVector minimizeFunction(CostFunction f, DoubleVector pInput, double alpha, double limit, int length, boolean verbose) {
        return new GradientDescent(alpha, limit).minimize(f, pInput, length, verbose);
    }

    /**
     * Shifts all costs one slot to the left (dropping the oldest) and sets the last slot to
     * {@link Double#MAX_VALUE}. Does nothing for an empty array.
     */
    static void shiftLeft(double[] lastCosts) {
        if (lastCosts.length == 0) {
            return;
        }
        int lastIndex = lastCosts.length - 1;
        System.arraycopy(lastCosts, 1, lastCosts, 0, lastIndex);
        lastCosts[lastIndex] = Double.MAX_VALUE;
    }

    /**
     * Returns whether the absolute difference between the two most recent costs is below
     * {@code limit}. Requires at least two entries.
     */
    static boolean converged(double[] lastCosts, double limit) {
        return Math.abs(getCostDifference(lastCosts)) < limit;
    }

    /**
     * Returns whether the most recent cost is strictly greater than the one before it. Only the
     * last pair is compared, so a single cost increase counts as ascending; arrays with fewer than
     * two entries are never ascending.
     */
    static boolean ascending(double[] lastCosts) {
        int n = lastCosts.length;
        return n >= 2 && lastCosts[n - 2] < lastCosts[n - 1];
    }

    /** Returns the newest cost minus the previous cost (negative means the cost went down). */
    private static double getCostDifference(double[] lastCosts) {
        return lastCosts[lastCosts.length - 1] - lastCosts[lastCosts.length - 2];
    }

    /** Fluent builder for {@link GradientDescent}. */
    public static class GradientDescentBuilder {
        private final double alpha;
        private double breakDifference;
        private double momentum;
        private boolean breakOnDivergence;
        private boolean boldDriver;
        private double boldIncreasePercentage;
        private double boldDecreasePercentage;
        // -1 means annealing is disabled
        private int annealingIteration = -1;

        private GradientDescentBuilder(double alpha) {
            this.alpha = alpha;
        }

        /** Builds an immutable {@link GradientDescent} from the current settings. */
        public GradientDescent build() {
            return new GradientDescent(this);
        }

        /**
         * Sets the heavy-ball momentum coefficient.
         *
         * @param momentum a value in [0, 1]; 0 disables momentum.
         * @throws IllegalArgumentException if {@code momentum} is outside [0, 1].
         */
        public GradientDescentBuilder momentum(double momentum) {
            Preconditions.checkArgument(momentum >= 0.0 && momentum <= 1.0, "Momentum must be between 0 and 1.");
            this.momentum = momentum;
            return this;
        }

        /** Enables the bold driver with a 50% decrease on cost increase and 5% increase otherwise. */
        public GradientDescentBuilder boldDriver() {
            return this.boldDriver(0.5, 0.05);
        }

        /**
         * Enables the bold driver learning-rate adaptation.
         *
         * @param increasedCostPercentage fraction by which alpha shrinks after the cost did not go
         *     down; must be in [0, 1].
         * @param decreasedCostPercentage fraction by which alpha grows after the cost went down;
         *     must be in [0, 1].
         * @throws IllegalArgumentException if a percentage is outside [0, 1].
         */
        public GradientDescentBuilder boldDriver(double increasedCostPercentage, double decreasedCostPercentage) {
            Preconditions.checkArgument(increasedCostPercentage >= 0.0 && increasedCostPercentage <= 1.0, "increasedCostPercentage must be between 0 and 1.");
            Preconditions.checkArgument(decreasedCostPercentage >= 0.0 && decreasedCostPercentage <= 1.0, "decreasedCostPercentage must be between 0 and 1.");
            this.boldDriver = true;
            this.boldIncreasePercentage = increasedCostPercentage;
            this.boldDecreasePercentage = decreasedCostPercentage;
            return this;
        }

        /** Stops the optimization as soon as the cost increases from one iteration to the next. */
        public GradientDescentBuilder breakOnDivergence() {
            this.breakOnDivergence = true;
            return this;
        }

        /**
         * Stops the optimization once the absolute difference of two consecutive costs is below
         * {@code delta}.
         */
        public GradientDescentBuilder breakOnDifference(double delta) {
            this.breakDifference = delta;
            return this;
        }

        /**
         * Enables step-wise annealing: alpha becomes {@code alpha / (1 + iteration / interval)}
         * with integer division.
         *
         * @param iteration the annealing interval; must be positive.
         * @throws IllegalArgumentException if {@code iteration <= 0}.
         */
        public GradientDescentBuilder annealingAfter(int iteration) {
            Preconditions.checkArgument(iteration > 0, "Annealing can only kick in after the first iteration! Given: " + iteration);
            this.annealingIteration = iteration;
            return this;
        }

        /** Creates a builder with the given base learning rate. */
        public static GradientDescentBuilder create(double alpha) {
            return new GradientDescentBuilder(alpha);
        }
    }
}

