/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.math.minimize;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.AbstractMinimizer;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import gnu.trove.list.array.TDoubleArrayList;
import java.util.ArrayList;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class OWLQN
extends AbstractMinimizer {
    private static final Logger LOG = LogManager.getLogger(OWLQN.class);
    private DoubleVector x;
    private DoubleVector grad;
    private DoubleVector newX;
    private DoubleVector newGrad;
    private DoubleVector dir;
    private DoubleVector steepestDescDir;
    private double[] alphas;
    private TDoubleArrayList roList;
    private TDoubleArrayList costs;
    private ArrayList<DoubleVector> sList;
    private ArrayList<DoubleVector> yList;
    private double value;
    private int m = 10;
    private double l1weight = 0.0;
    private double tol = 1.0E-4;
    private boolean gradCheck = false;

    @Override
    public DoubleVector minimize(CostFunction f, DoubleVector theta, int maxIterations, boolean verbose) {
        DenseDoubleVector zeros = new DenseDoubleVector(theta.getDimension());
        this.x = theta;
        this.grad = zeros;
        this.newX = theta.deepCopy();
        this.newGrad = zeros;
        this.dir = zeros;
        this.steepestDescDir = this.newGrad;
        this.alphas = new double[this.m];
        this.roList = new TDoubleArrayList(this.m);
        this.costs = new TDoubleArrayList(this.m);
        this.sList = new ArrayList();
        this.yList = new ArrayList();
        this.value = this.evaluateL1(f);
        this.grad = this.newGrad;
        for (int i = 0; i < maxIterations; ++i) {
            this.updateDir(f, verbose);
            boolean continueIterations = this.backTrackingLineSearch(i, f);
            this.shift();
            this.costs.add(this.value);
            if (this.costs.size() > 5) {
                double first = this.costs.get(0);
                while (this.costs.size() > 5) {
                    this.costs.removeAt(0);
                }
                double avgImprovement = (first - this.value) / (double)this.costs.size();
                double perc = avgImprovement / Math.abs(this.value);
                if (perc < this.tol) break;
            }
            if (!continueIterations) break;
            if (!verbose) continue;
            LOG.info("Iteration " + i + " | Cost: " + this.value);
        }
        this.x = null;
        this.grad = null;
        this.newGrad = null;
        this.dir = null;
        this.steepestDescDir = null;
        this.alphas = null;
        this.roList = null;
        this.costs = null;
        this.sList = null;
        this.yList = null;
        return this.newX;
    }

    private void updateDir(CostFunction f, boolean verbose) {
        this.makeSteepestDescDir();
        this.mapDirectionByInverseHessian();
        this.fixDirectionSigns();
        if (this.gradCheck) {
            this.testDirectionDerivation(f);
        }
    }

    private void testDirectionDerivation(CostFunction f) {
        double dirNorm = FastMath.sqrt((double)this.dir.dot(this.dir));
        if (dirNorm != 0.0) {
            double eps = 1.05E-8 / dirNorm;
            this.getNextPoint(eps);
            double val2 = this.evaluateL1(f);
            double numDeriv = (val2 - this.value) / eps;
            double deriv = this.directionDerivation();
            LOG.info("GradCheck: expected= " + numDeriv + " vs. " + deriv + "! AbsDiff= " + Math.abs(numDeriv - deriv));
        }
    }

    private void fixDirectionSigns() {
        if (this.l1weight > 0.0) {
            for (int i = 0; i < this.dir.getDimension(); ++i) {
                if (!(this.dir.get(i) * this.steepestDescDir.get(i) <= 0.0)) continue;
                this.dir.set(i, 0.0);
            }
        }
    }

    private void mapDirectionByInverseHessian() {
        int count = this.sList.size();
        if (count != 0) {
            for (int i = count - 1; i >= 0; --i) {
                this.alphas[i] = -this.sList.get(i).dot(this.dir) / this.roList.get(i);
                this.addMult(this.dir, this.yList.get(i), this.alphas[i]);
            }
            DoubleVector lastY = this.yList.get(count - 1);
            double yDotY = lastY.dot(lastY);
            double scalar = this.roList.get(count - 1) / yDotY;
            this.scale(this.dir, scalar);
            for (int i = 0; i < count; ++i) {
                double beta = this.yList.get(i).dot(this.dir) / this.roList.get(i);
                this.addMult(this.dir, this.sList.get(i), -this.alphas[i] - beta);
            }
        }
    }

    private void makeSteepestDescDir() {
        if (this.l1weight == 0.0) {
            this.scaleInto(this.dir, this.grad, -1.0);
        } else {
            for (int i = 0; i < this.dir.getDimension(); ++i) {
                if (this.x.get(i) < 0.0) {
                    this.dir.set(i, -this.grad.get(i) + this.l1weight);
                    continue;
                }
                if (this.x.get(i) > 0.0) {
                    this.dir.set(i, -this.grad.get(i) - this.l1weight);
                    continue;
                }
                if (this.grad.get(i) < -this.l1weight) {
                    this.dir.set(i, -this.grad.get(i) - this.l1weight);
                    continue;
                }
                if (this.grad.get(i) > this.l1weight) {
                    this.dir.set(i, -this.grad.get(i) + this.l1weight);
                    continue;
                }
                this.dir.set(i, 0.0);
            }
        }
        this.steepestDescDir = this.dir;
    }

    private boolean backTrackingLineSearch(int iter, CostFunction f) {
        double origDirDeriv = this.directionDerivation();
        if (origDirDeriv > 0.0) {
            throw new RuntimeException("L-BFGS chose a non-descent direction: check your gradient!");
        }
        if (origDirDeriv == 0.0 || Double.isNaN(origDirDeriv)) {
            LOG.info("L-BFGS apparently found the minimum. No direction to descent anymore.");
            return false;
        }
        double alpha = 1.0;
        double backoff = 0.5;
        if (iter == 0) {
            double normDir = FastMath.sqrt((double)this.dir.dot(this.dir));
            alpha = 1.0 / normDir;
            backoff = 0.1;
        }
        double c1 = 1.0E-4;
        double oldValue = this.value;
        while (true) {
            this.getNextPoint(alpha);
            this.value = this.evaluateL1(f);
            if (Double.isNaN(this.value) || this.value <= oldValue + c1 * origDirDeriv * alpha) break;
            alpha *= backoff;
        }
        return true;
    }

    private void getNextPoint(double alpha) {
        this.addMultInto(this.newX, this.x, this.dir, alpha);
        if (this.l1weight > 0.0) {
            for (int i = 0; i < this.x.getDimension(); ++i) {
                if (!(this.x.get(i) * this.newX.get(i) < 0.0)) continue;
                this.newX.set(i, 0.0);
            }
        }
    }

    private void addMultInto(DoubleVector a, DoubleVector b, DoubleVector c, double d) {
        for (int i = 0; i < a.getDimension(); ++i) {
            a.set(i, b.get(i) + c.get(i) * d);
        }
    }

    private void addMult(DoubleVector a, DoubleVector b, double c) {
        for (int i = 0; i < a.getDimension(); ++i) {
            a.set(i, a.get(i) + b.get(i) * c);
        }
    }

    private void scale(DoubleVector a, double b) {
        for (int i = 0; i < a.getDimension(); ++i) {
            a.set(i, a.get(i) * b);
        }
    }

    void scaleInto(DoubleVector a, DoubleVector b, double c) {
        for (int i = 0; i < a.getDimension(); ++i) {
            a.set(i, b.get(i) * c);
        }
    }

    private double directionDerivation() {
        if (this.l1weight == 0.0) {
            return this.dir.dot(this.grad);
        }
        double val = 0.0;
        for (int i = 0; i < this.dir.getDimension(); ++i) {
            if (this.dir.get(i) == 0.0) continue;
            if (this.x.get(i) < 0.0) {
                val += this.dir.get(i) * (this.grad.get(i) - this.l1weight);
                continue;
            }
            if (this.x.get(i) > 0.0) {
                val += this.dir.get(i) * (this.grad.get(i) + this.l1weight);
                continue;
            }
            if (this.dir.get(i) < 0.0) {
                val += this.dir.get(i) * (this.grad.get(i) - this.l1weight);
                continue;
            }
            if (!(this.dir.get(i) > 0.0)) continue;
            val += this.dir.get(i) * (this.grad.get(i) + this.l1weight);
        }
        return val;
    }

    private double evaluateL1(CostFunction f) {
        CostGradientTuple evaluateCost = f.evaluateCost(this.newX);
        this.newGrad = evaluateCost.getGradient();
        double val = evaluateCost.getCost();
        if (this.l1weight > 0.0) {
            for (int i = 0; i < this.newGrad.getDimension(); ++i) {
                val += Math.abs(this.newX.get(i)) * this.l1weight;
            }
        }
        return val;
    }

    private void shift() {
        DenseDoubleVector nextS = null;
        DenseDoubleVector nextY = null;
        int listSize = this.sList.size();
        if (listSize < this.m) {
            nextS = new DenseDoubleVector(this.x.getDimension());
            nextY = new DenseDoubleVector(this.x.getDimension());
        }
        if (nextS == null) {
            nextS = this.sList.get(0);
            this.sList.remove(0);
            nextY = this.yList.get(0);
            this.yList.remove(0);
            this.roList.removeAt(0);
        }
        this.addMultInto((DoubleVector)nextS, this.newX, this.x, -1.0);
        this.addMultInto((DoubleVector)nextY, this.newGrad, this.grad, -1.0);
        double ro = nextS.dot((DoubleVector)nextY);
        this.sList.add((DoubleVector)nextS);
        this.yList.add((DoubleVector)nextY);
        this.roList.add(ro);
        DoubleVector tmpNewX = this.newX.deepCopy();
        this.newX = this.x.deepCopy();
        this.x = tmpNewX;
        DoubleVector tmpNewGrad = this.newGrad.deepCopy();
        this.newGrad = this.grad.deepCopy();
        this.grad = tmpNewGrad;
    }

    public OWLQN doGradChecks() {
        this.gradCheck = true;
        return this;
    }

    public OWLQN setM(int m) {
        this.m = m;
        return this;
    }

    public OWLQN setL1Weight(double l1weight) {
        this.l1weight = l1weight;
        return this;
    }

    public OWLQN setTolerance(double tol) {
        this.tol = tol;
        return this;
    }

    public static DoubleVector minimizeFunction(CostFunction f, DoubleVector theta, int maxIterations, boolean verbose) {
        return new OWLQN().minimize(f, theta, maxIterations, verbose);
    }
}

