/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import hex.DataInfo;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.optimization.L_BFGS;
import hex.optimization.OptimizationUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Iced;
import water.Key;
import water.MemoryManager;
import water.TestUtil;
import water.Value;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;

public class L_BFGS_Test
extends TestUtil {
    @BeforeClass
    public static void setup() {
        L_BFGS_Test.stall_till_cloudsize((int)1);
        try {
            Thread.sleep(100L);
        }
        catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    @Test
    public void rosenbrock() {
        double a = 1.0;
        double b = 100.0;
        OptimizationUtils.GradientSolver gs = new OptimizationUtils.GradientSolver(){

            public OptimizationUtils.GradientInfo getGradient(double[] beta) {
                double[] g = new double[2];
                double x = beta[0];
                double y = beta[1];
                double xx = x * x;
                g[0] = -2.0 + 2.0 * x - 400.0 * (y * x - x * xx);
                g[1] = 200.0 * (y - xx);
                double objVal = (1.0 - x) * (1.0 - x) + 100.0 * (y - xx) * (y - xx);
                return new OptimizationUtils.GradientInfo(objVal, g);
            }

            public OptimizationUtils.GradientInfo getObjective(double[] beta) {
                return this.getGradient(beta);
            }
        };
        L_BFGS lbfgs = new L_BFGS().setGradEps(1.0E-12);
        L_BFGS.Result r = lbfgs.solve(gs, L_BFGS.startCoefs((int)2, (long)987654321L));
        Assert.assertTrue((String)"LBFGS failed to solve Rosenbrock function optimization", (r.ginfo._objVal < 1.0E-4 ? 1 : 0) != 0);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void logistic() {
        Key parsedKey = Key.make((String)"prostate");
        DataInfo dinfo = null;
        try {
            GLMModel.GLMParameters glmp = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial, GLMModel.GLMParameters.Family.binomial.defaultLink);
            glmp._alpha = new double[]{0.0};
            glmp._lambda = new double[]{1.0E-5};
            Frame source = L_BFGS_Test.parse_test_file((Key)parsedKey, (String)"smalldata/glm_test/prostate_cat_replaced.csv");
            source.add("CAPSULE", source.remove("CAPSULE"));
            source.remove("ID").remove();
            Frame valid = new Frame((String[])source._names.clone(), (Vec[])source.vecs().clone());
            dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            DKV.put((Key)dinfo._key, (Iced)dinfo);
            glmp._obj_reg = 0.002631578947368421;
            GLM.GLMGradientSolver solver = new GLM.GLMGradientSolver(null, glmp, dinfo, 1.0E-5, null);
            L_BFGS lbfgs = new L_BFGS().setGradEps(1.0E-8);
            double[] beta = MemoryManager.malloc8d((int)(dinfo.fullN() + 1));
            beta[beta.length - 1] = new GLMModel.GLMWeightsFun(glmp).link(source.vec("CAPSULE").mean());
            L_BFGS.Result r = lbfgs.solve((OptimizationUtils.GradientSolver)solver, beta, (OptimizationUtils.GradientInfo)solver.getGradient(beta), new L_BFGS.ProgressMonitor(){
                int _i = 0;

                public boolean progress(double[] beta, OptimizationUtils.GradientInfo ginfo) {
                    System.out.println(++this._i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2((double[])ginfo._gradient, (boolean)false));
                    return true;
                }
            });
            Assert.assertEquals((double)378.34, (double)(2.0 * r.ginfo._objVal * (double)source.numRows()), (double)0.1);
        }
        finally {
            Value v;
            if (dinfo != null) {
                DKV.remove((Key)dinfo._key);
            }
            if ((v = DKV.get((Key)parsedKey)) != null) {
                ((Frame)v.get()).delete();
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testArcene() {
        Key parsedKey = Key.make((String)"arcene_parsed");
        DataInfo dinfo = null;
        try {
            Frame source = L_BFGS_Test.parse_test_file((Key)parsedKey, (String)"smalldata/glm_test/arcene.csv");
            Frame valid = new Frame((String[])source._names.clone(), (Vec[])source.vecs().clone());
            GLMModel.GLMParameters glmp = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            glmp._lambda = new double[]{1.0E-5};
            glmp._alpha = new double[]{0.0};
            glmp._obj_reg = 0.01;
            dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            DKV.put((Key)dinfo._key, (Iced)dinfo);
            GLM.GLMGradientSolver solver = new GLM.GLMGradientSolver(null, glmp, dinfo, 1.0E-5, null);
            L_BFGS lbfgs = new L_BFGS().setMaxIter(20);
            double[] beta = MemoryManager.malloc8d((int)(dinfo.fullN() + 1));
            beta[beta.length - 1] = new GLMModel.GLMWeightsFun(glmp).link(source.lastVec().mean());
            L_BFGS.Result r1 = lbfgs.solve((OptimizationUtils.GradientSolver)solver, (double[])beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor(){
                int _i = 0;

                public boolean progress(double[] beta, OptimizationUtils.GradientInfo ginfo) {
                    System.out.println(++this._i + ":" + ginfo._objVal);
                    return true;
                }
            });
            lbfgs.setMaxIter(50);
            final int iter = r1.iter;
            L_BFGS.Result r2 = lbfgs.solve((OptimizationUtils.GradientSolver)solver, r1.coefs, r1.ginfo, new L_BFGS.ProgressMonitor(){
                int _i = 0;

                public boolean progress(double[] beta, OptimizationUtils.GradientInfo ginfo) {
                    System.out.println(iter + " + " + ++this._i + ":" + ginfo._objVal);
                    return true;
                }
            });
            System.out.println();
            lbfgs = new L_BFGS().setMaxIter(100);
            L_BFGS.Result r3 = lbfgs.solve((OptimizationUtils.GradientSolver)solver, (double[])beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor(){
                int _i = 0;

                public boolean progress(double[] beta, OptimizationUtils.GradientInfo ginfo) {
                    System.out.println(++this._i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2((double[])ginfo._gradient, (boolean)false));
                    return true;
                }
            });
            Assert.assertEquals((long)r1.iter, (long)20L);
            Assert.assertEquals((double)r2.ginfo._objVal, (double)r3.ginfo._objVal, (double)1.0E-8);
            Assert.assertEquals((double)(0.5 * glmp._lambda[0] * ArrayUtils.l2norm((double[])r3.coefs, (boolean)true) + r3.ginfo._objVal), (double)1.0E-4, (double)5.0E-4);
            Assert.assertTrue((String)("iter# expected < 100, got " + r3.iter), (r3.iter < 100 ? 1 : 0) != 0);
        }
        finally {
            Value v;
            if (dinfo != null) {
                DKV.remove((Key)dinfo._key);
            }
            if ((v = DKV.get((Key)parsedKey)) != null) {
                ((Frame)v.get()).delete();
            }
        }
    }
}

