package jitk.spline;

import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.NormOps;

/* loaded from: input_file:jitk/spline/TransformInverseGradientDescent.class */
/**
 * Iteratively searches for a point {@code x} whose image under a
 * {@link ThinPlateR2LogRSplineKernelTransform} matches a target point.
 *
 * <p>The caller supplies the current estimate {@code x}, its forward image (presumably
 * {@code xfm.apply(x)}), the target, and a Jacobian matrix (presumably the Jacobian of the
 * transform at {@code x}). This class computes a normalized step direction, applies the step to the
 * estimate, and tracks the residual {@code target - estimateXfm}. Instances are mutable and carry
 * state between calls.
 */
public class TransformInverseGradientDescent {
    /** Dimensionality of the points handled by this instance. */
    int ndims;
    /** Forward transform; evaluated at trial points by {@link #armijoCondition(double, double)}. */
    ThinPlateR2LogRSplineKernelTransform xfm;
    /** Jacobian matrix set via {@link #setJacobian(double[][])}. */
    DenseMatrix64F jacobian;
    /** {@code jacobian * dir}, refreshed by the direction computations. */
    DenseMatrix64F directionalDeriv;
    /** Current step direction (ndims x 1). */
    DenseMatrix64F dir;
    /** Residual {@code target - estimateXfm} (ndims x 1), refreshed by {@link #updateError()}. */
    DenseMatrix64F errorV;
    /** Current estimate (ndims x 1). */
    DenseMatrix64F estimate;
    /** Forward image of the estimate as supplied by the caller (ndims x 1). */
    DenseMatrix64F estimateXfm;
    /** Point the transformed estimate should match (ndims x 1). */
    DenseMatrix64F target;
    protected static Logger logger = LogManager.getLogger(TransformInverseGradientDescent.class.getName());
    /** Largest absolute component of {@link #errorV} after the last {@link #updateError()}. */
    double error = 9999.0d;
    /** Step size used by {@link #oneIteration(boolean)}. */
    double stepSz = 1.0d;
    // NOTE(review): maxIters, eps and beta are not read inside this class; presumably they are
    // consumed by callers in the same package. Kept for compatibility.
    int maxIters = 20;
    double eps = 1.0E-6d;
    double beta = 0.7d;
    /** 1x1 matrix holding {@code dir^T * jacobian * dir}, set by {@link #computeDirection()}. */
    DenseMatrix64F descentDirectionMag = new DenseMatrix64F(1, 1);

    /**
     * Creates a solver for {@code i}-dimensional points.
     *
     * @param i number of dimensions
     * @param thinPlateR2LogRSplineKernelTransform forward transform used during line search
     */
    public TransformInverseGradientDescent(int i, ThinPlateR2LogRSplineKernelTransform thinPlateR2LogRSplineKernelTransform) {
        this.ndims = i;
        this.xfm = thinPlateR2LogRSplineKernelTransform;
        this.dir = new DenseMatrix64F(i, 1);
        this.errorV = new DenseMatrix64F(i, 1);
        this.directionalDeriv = new DenseMatrix64F(i, 1);
    }

    /** Sets {@link #eps}. */
    public void setEps(double d) {
        this.eps = d;
    }

    /** Sets the step size used by {@link #oneIteration(boolean)}. */
    public void setStepSize(double d) {
        this.stepSz = d;
    }

    /**
     * Replaces the Jacobian with a new matrix built from {@code dArr} (row-major 2D array).
     *
     * @param dArr Jacobian entries
     */
    public void setJacobian(double[][] dArr) {
        this.jacobian = new DenseMatrix64F(dArr);
        if (logger.isTraceEnabled()) {
            logger.trace("setJacobian:\n" + this.jacobian);
        }
    }

    /**
     * Sets the target point as an ndims x 1 matrix backed by {@code dArr} via
     * {@code DenseMatrix64F.setData}.
     *
     * @param dArr target coordinates; expected to have {@code ndims} entries
     */
    public void setTarget(double[] dArr) {
        this.target = new DenseMatrix64F(this.ndims, 1);
        this.target.setData(dArr);
    }

    /** @return the residual vector {@code target - estimateXfm} (live internal matrix) */
    public DenseMatrix64F getErrorVector() {
        return this.errorV;
    }

    /** @return the current step direction (live internal matrix) */
    public DenseMatrix64F getDirection() {
        return this.dir;
    }

    /** @return the current Jacobian (live internal matrix; may be {@code null} before set) */
    public DenseMatrix64F getJacobian() {
        return this.jacobian;
    }

    /**
     * Sets the current estimate as an ndims x 1 matrix backed by {@code dArr} via
     * {@code DenseMatrix64F.setData}.
     *
     * @param dArr estimate coordinates; expected to have {@code ndims} entries
     */
    public void setEstimate(double[] dArr) {
        this.estimate = new DenseMatrix64F(this.ndims, 1);
        this.estimate.setData(dArr);
    }

    /**
     * Sets the forward image of the current estimate and recomputes the residual.
     *
     * @param dArr transformed estimate; expected to have {@code ndims} entries
     */
    public void setEstimateXfm(double[] dArr) {
        this.estimateXfm = new DenseMatrix64F(this.ndims, 1);
        this.estimateXfm.setData(dArr);
        updateError();
    }

    /** @return the current estimate (live internal matrix; may be {@code null} before set) */
    public DenseMatrix64F getEstimate() {
        return this.estimate;
    }

    /** @return the max-abs residual component from the last {@link #updateError()} */
    public double getError() {
        return this.error;
    }

    /** Performs one iteration and refreshes the error; see {@link #oneIteration(boolean)}. */
    public void oneIteration() {
        oneIteration(true);
    }

    /**
     * Computes a direction with {@link #computeDirection()} and steps the estimate by
     * {@link #stepSz} along it.
     *
     * @param z if {@code true}, also call {@link #updateError()} afterwards. Note that the residual
     *     is based on {@link #estimateXfm}, which this method does not refresh.
     */
    public void oneIteration(boolean z) {
        computeDirection();
        updateEstimate(this.stepSz);
        if (z) {
            updateError();
        }
    }

    /**
     * Computes a unit-length steepest-descent-style direction into {@link #dir} and stores
     * {@code jacobian * dir} (taken before the final negation) in {@link #directionalDeriv}.
     */
    public void computeDirectionSteepest() {
        DenseMatrix64F denseMatrix64F = new DenseMatrix64F(this.ndims, 1);
        if (logger.isTraceEnabled()) {
            logger.trace("\nerrorV:\n" + this.errorV);
        }
        CommonOps.mult(this.jacobian, this.estimate, denseMatrix64F);
        CommonOps.subEquals(denseMatrix64F, this.errorV);
        CommonOps.multTransA(2.0d, this.jacobian, denseMatrix64F, this.dir);
        normalizeInPlace(this.dir);
        CommonOps.mult(this.jacobian, this.dir, this.directionalDeriv);
        // Negate so that dir is added (not subtracted) by the update methods.
        CommonOps.scale(-1.0d, this.dir);
    }

    /**
     * Solves {@code jacobian * dir = errorV} for the step direction and normalizes it to unit
     * length. Also stores {@code jacobian * dir} in {@link #directionalDeriv} and
     * {@code dir^T * jacobian * dir} in {@link #descentDirectionMag}.
     */
    public void computeDirection() {
        CommonOps.solve(this.jacobian, this.errorV, this.dir);
        normalizeInPlace(this.dir);
        CommonOps.mult(this.jacobian, this.dir, this.directionalDeriv);
        CommonOps.multTransA(this.dir, this.directionalDeriv, this.descentDirectionMag);
        logger.debug("descentDirectionMag: " + this.descentDirectionMag.get(0));
    }

    /**
     * Divides every element of {@code m} by its L2 norm.
     *
     * <p>The division is written out explicitly rather than using {@code CommonOps.divide(double,
     * matrix)}, because the operand order of that overload is not the same across EJML releases.
     * A zero vector is left unchanged, instead of becoming NaN through {@code 0/0}.
     */
    private static void normalizeInPlace(DenseMatrix64F m) {
        double norm = NormOps.normP2(m);
        if (norm == 0.0d) {
            return;
        }
        int n = m.getNumElements();
        for (int k = 0; k < n; k++) {
            m.data[k] /= norm;
        }
    }

    /**
     * Backtracking line search: starting from step {@code d3}, multiplies the step by {@code d2}
     * until {@link #armijoCondition(double, double)} holds or {@code i} shrinks have been tried.
     *
     * @param d Armijo constant passed to {@link #armijoCondition(double, double)}
     * @param d2 factor applied to the step after each failed test
     * @param i maximum number of shrinks
     * @param d3 initial step size
     * @return the selected step size (the last one tried if the condition never held)
     */
    public double backtrackingLineSearch(double d, double d2, int i, double d3) {
        double d4 = d3;
        int i2 = 0;
        while (i2 < i && !armijoCondition(d, d4)) {
            d4 *= d2;
            i2++;
        }
        logger.trace("selected step size after " + i2 + " tries");
        return d4;
    }

    /**
     * Tests a sufficient-decrease condition for the trial point {@code estimate + d2 * dir}.
     *
     * <p>Returns {@code f(trial) < f(x) + d * d2 * m}, where {@code f} is {@link #squaredError}
     * against the target. Here {@code m = 2 * ||target - estimateXfm||^2 * descentDirectionMag}.
     * NOTE(review): this slope term differs from the textbook gradient-dot-direction; confirm it is
     * intended. It relies on {@link #descentDirectionMag} from {@link #computeDirection()}.
     *
     * @param d Armijo constant
     * @param d2 trial step size
     * @return whether the trial point satisfies the condition
     */
    public boolean armijoCondition(double d, double d2) {
        double[] direction = this.dir.data;
        double[] current = this.estimate.data;
        double[] trial = new double[this.ndims];
        for (int k = 0; k < this.ndims; k++) {
            trial[k] = current[k] + (d2 * direction[k]);
        }
        double[] currentXfm = this.estimateXfm.data;
        double[] trialXfm = this.xfm.apply(trial);
        double fCurrent = squaredError(currentXfm);
        double fTrial = squaredError(trialXfm);
        double slope = sumSquaredErrorsDeriv(this.target.data, currentXfm) * this.descentDirectionMag.get(0);
        logger.trace("   f( x )     : " + fCurrent);
        logger.trace("   f( x + ap ): " + fTrial);
        logger.trace("   f( x ) + c * m * t: " + (fCurrent + (d * d2 * slope)));
        return fTrial < fCurrent + ((d * d2) * slope);
    }

    /**
     * @param dArr point with at least {@code ndims} entries
     * @return squared Euclidean distance from {@code dArr} to the target over the first ndims entries
     */
    public double squaredError(double[] dArr) {
        double d = 0.0d;
        for (int i = 0; i < this.ndims; i++) {
            d += (dArr[i] - this.target.get(i)) * (dArr[i] - this.target.get(i));
        }
        return d;
    }

    /**
     * Moves the estimate in place: {@code estimate += d * dir}.
     *
     * @param d step size
     */
    public void updateEstimate(double d) {
        boolean trace = logger.isTraceEnabled();
        if (trace) {
            logger.trace("step size: " + d);
            logger.trace("estimate:\n" + this.estimate);
        }
        CommonOps.addEquals(this.estimate, d, this.dir);
        if (trace) {
            logger.trace("new estimate:\n" + this.estimate);
        }
    }

    /**
     * Adds {@link #dir} to the estimate, first shrinking {@code dir} in place so that its L2 norm
     * does not exceed {@code d}.
     *
     * <p>The clamp preserves the sign of {@code dir}. Previously it scaled by {@code -d/norm},
     * which reversed the step only when clamping applied.
     *
     * @param d maximum step length
     */
    public void updateEstimateNormBased(double d) {
        logger.debug("step size: " + d);
        if (logger.isTraceEnabled()) {
            logger.trace("estimate:\n" + this.estimate);
        }
        double normP2 = NormOps.normP2(this.dir);
        logger.debug("norm: " + normP2);
        if (normP2 > d) {
            CommonOps.scale(d / normP2, this.dir);
        }
        CommonOps.addEquals(this.estimate, this.dir);
        if (logger.isTraceEnabled()) {
            logger.trace("new estimate:\n" + this.estimate);
        }
    }

    /**
     * Recomputes {@code errorV = target - estimateXfm} and sets {@link #error} to its largest
     * absolute component. If the estimate, its transform, or the target is missing, prints a
     * warning to stderr and leaves the state unchanged.
     */
    public void updateError() {
        // estimateXfm is dereferenced below, so it must be part of the null guard.
        if (this.estimate == null || this.estimateXfm == null || this.target == null) {
            System.err.println("WARNING: Call to updateError with null target or estimate");
            return;
        }
        CommonOps.sub(this.target, this.estimateXfm, this.errorV);
        if (logger.isTraceEnabled()) {
            logger.trace("#########################");
            logger.trace("updateError, estimate   :\n" + this.estimate);
            logger.trace("updateError, estimateXfm:\n" + this.estimateXfm);
            logger.trace("updateError, target     :\n" + this.target);
            logger.trace("updateError, error      :\n" + this.errorV);
            logger.trace("#########################");
        }
        this.error = Math.abs(this.errorV.get(0));
        for (int i = 1; i < this.ndims; i++) {
            if (Math.abs(this.errorV.get(i)) > this.error) {
                this.error = Math.abs(this.errorV.get(i));
            }
        }
    }

    /** @return {@code 2 * sum((dArr[i] - dArr2[i])^2)} over the first {@code ndims} entries */
    private double sumSquaredErrorsDeriv(double[] dArr, double[] dArr2) {
        double d = 0.0d;
        for (int i = 0; i < this.ndims; i++) {
            d += (dArr[i] - dArr2[i]) * (dArr[i] - dArr2[i]);
        }
        return 2.0d * d;
    }

    /**
     * @return {@code sum((dArr[i] - dArr2[i])^2)} over all entries of {@code dArr}; {@code dArr2}
     *     must be at least as long
     */
    public static double sumSquaredErrors(double[] dArr, double[] dArr2) {
        int length = dArr.length;
        double d = 0.0d;
        for (int i = 0; i < length; i++) {
            d += (dArr[i] - dArr2[i]) * (dArr[i] - dArr2[i]);
        }
        return d;
    }

    /** Copies all elements of {@code denseMatrix64F} into the front of {@code dArr}. */
    public static void copyVectorIntoArray(DenseMatrix64F denseMatrix64F, double[] dArr) {
        System.arraycopy(denseMatrix64F.data, 0, dArr, 0, denseMatrix64F.getNumElements());
    }
}
