/*
 * Decompiled with CFR 0.152.
 */
package hex.glrm;

import hex.DataInfo;
import hex.Model;
import hex.ModelMetrics;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import hex.glrm.GLRM;
import hex.glrm.GLRMModel;
import hex.glrm.GLRMTest;
import hex.glrm.ModelMetricsGLRM;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Iced;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

public class GLRMCategoricalTest
extends TestUtil {
    public final double TOLERANCE = 1.0E-6;

    @BeforeClass
    public static void setup() {
        GLRMCategoricalTest.stall_till_cloudsize((int)1);
    }

    private static String colFormat(String[] cols, String format) {
        int[] idx = new int[cols.length];
        for (int i = 0; i < idx.length; ++i) {
            idx[i] = i;
        }
        return GLRMCategoricalTest.colFormat(cols, format, idx);
    }

    private static String colFormat(String[] cols, String format, int[] idx) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < cols.length; ++i) {
            sb.append(String.format(format, cols[idx[i]]));
        }
        sb.append("\n");
        return sb.toString();
    }

    private static String colExpFormat(String[] cols, String[][] domains, String format) {
        int[] idx = new int[cols.length];
        for (int i = 0; i < idx.length; ++i) {
            idx[i] = i;
        }
        return GLRMCategoricalTest.colExpFormat(cols, domains, format, idx);
    }

    private static String colExpFormat(String[] cols, String[][] domains, String format, int[] idx) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < domains.length; ++i) {
            int c = idx[i];
            if (domains[c] == null) {
                sb.append(String.format(format, cols[c]));
                continue;
            }
            for (int j = 0; j < domains[c].length; ++j) {
                sb.append(String.format(format, domains[c][j]));
            }
        }
        sb.append("\n");
        return sb.toString();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testCategoricalIris() throws InterruptedException, ExecutionException {
        GLRMModel model = null;
        Frame train = null;
        try {
            train = GLRMCategoricalTest.parse_test_file((Key)Key.make((String)"iris.hex"), (String)"smalldata/iris/iris_wheader.csv");
            GLRMModel.GLRMParameters parms = new GLRMModel.GLRMParameters();
            parms._train = train._key;
            parms._k = 4;
            parms._loss = GlrmLoss.Absolute;
            parms._init = GlrmInitialization.SVD;
            parms._transform = DataInfo.TransformType.NONE;
            parms._recover_svd = true;
            parms._max_iterations = 1000;
            model = (GLRMModel)new GLRM(parms).trainModel().get();
            Log.info((Object[])new Object[]{"Iteration " + ((GLRMModel.GLRMOutput)model._output)._iterations + ": Objective value = " + ((GLRMModel.GLRMOutput)model._output)._objective});
            model.score(train).delete();
            ModelMetricsGLRM mm = (ModelMetricsGLRM)ModelMetrics.getFromDKV((Model)model, (Frame)train);
            Log.info((Object[])new Object[]{"Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr});
        }
        finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testCategoricalProstate() throws InterruptedException, ExecutionException {
        GLRMModel model = null;
        Frame train = null;
        int[] cats = new int[]{1, 3, 4, 5};
        try {
            Scope.enter();
            train = GLRMCategoricalTest.parse_test_file((Key)Key.make((String)"prostate.hex"), (String)"smalldata/logreg/prostate.csv");
            for (int i = 0; i < cats.length; ++i) {
                Scope.track((Vec)train.replace(cats[i], train.vec(cats[i]).toCategoricalVec()));
            }
            train.remove("ID").remove();
            DKV.put((Key)train._key, (Iced)train);
            GLRMModel.GLRMParameters parms = new GLRMModel.GLRMParameters();
            parms._train = train._key;
            parms._k = 8;
            parms._gamma_y = 0.1;
            parms._gamma_x = 0.1;
            parms._regularization_x = GlrmRegularizer.Quadratic;
            parms._regularization_y = GlrmRegularizer.Quadratic;
            parms._init = GlrmInitialization.PlusPlus;
            parms._transform = DataInfo.TransformType.STANDARDIZE;
            parms._recover_svd = false;
            parms._max_iterations = 200;
            model = (GLRMModel)new GLRM(parms).trainModel().get();
            Log.info((Object[])new Object[]{"Iteration " + ((GLRMModel.GLRMOutput)model._output)._iterations + ": Objective value = " + ((GLRMModel.GLRMOutput)model._output)._objective});
            model.score(train).delete();
            ModelMetricsGLRM mm = (ModelMetricsGLRM)ModelMetrics.getFromDKV((Model)model, (Frame)train);
            Log.info((Object[])new Object[]{"Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr});
        }
        finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit((Key[])new Key[0]);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testLosses() throws InterruptedException, ExecutionException {
        long seed = 912559L;
        Random rng = new Random(seed);
        Frame train = null;
        int[] cats = new int[]{1, 3, 4, 5};
        GlrmRegularizer[] regs = new GlrmRegularizer[]{GlrmRegularizer.Quadratic, GlrmRegularizer.L1, GlrmRegularizer.NonNegative, GlrmRegularizer.OneSparse, GlrmRegularizer.UnitOneSparse, GlrmRegularizer.Simplex};
        Scope.enter();
        try {
            train = GLRMCategoricalTest.parse_test_file((Key)Key.make((String)"prostate.hex"), (String)"smalldata/logreg/prostate.csv");
            for (int i = 0; i < cats.length; ++i) {
                Scope.track((Vec)train.replace(cats[i], train.vec(cats[i]).toCategoricalVec()));
            }
            train.remove("ID").remove();
            DKV.put((Key)train._key, (Iced)train);
            for (GlrmLoss loss : new GlrmLoss[]{GlrmLoss.Quadratic, GlrmLoss.Absolute, GlrmLoss.Huber, GlrmLoss.Poisson}) {
                for (GlrmLoss multiloss : new GlrmLoss[]{GlrmLoss.Categorical, GlrmLoss.Ordinal}) {
                    GLRMModel model = null;
                    try {
                        Scope.enter();
                        long myseed = rng.nextLong();
                        Log.info((Object[])new Object[]{"GLRM using seed = " + myseed});
                        GLRMModel.GLRMParameters parms = new GLRMModel.GLRMParameters();
                        parms._train = train._key;
                        parms._transform = DataInfo.TransformType.NONE;
                        parms._k = 5;
                        parms._loss = loss;
                        parms._multi_loss = multiloss;
                        parms._init = GlrmInitialization.SVD;
                        parms._regularization_x = regs[rng.nextInt(regs.length)];
                        parms._regularization_y = regs[rng.nextInt(regs.length)];
                        parms._gamma_x = Math.abs(rng.nextDouble());
                        parms._gamma_y = Math.abs(rng.nextDouble());
                        parms._recover_svd = false;
                        parms._seed = myseed;
                        parms._verbose = false;
                        parms._max_iterations = 500;
                        model = (GLRMModel)new GLRM(parms).trainModel().get();
                        Log.info((Object[])new Object[]{"Iteration " + ((GLRMModel.GLRMOutput)model._output)._iterations + ": Objective value = " + ((GLRMModel.GLRMOutput)model._output)._objective});
                        model.score(train).delete();
                        ModelMetricsGLRM mm = (ModelMetricsGLRM)ModelMetrics.getFromDKV((Model)model, (Frame)train);
                        Log.info((Object[])new Object[]{"Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr});
                    }
                    finally {
                        if (model != null) {
                            model.delete();
                        }
                        Scope.exit((Key[])new Key[0]);
                    }
                }
            }
        }
        finally {
            if (train != null) {
                train.delete();
            }
            Scope.exit((Key[])new Key[0]);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testSetColumnLossCats() throws InterruptedException, ExecutionException {
        GLRMModel model = null;
        Frame train = null;
        int[] cats = new int[]{1, 3, 4, 5};
        Scope.enter();
        try {
            train = GLRMCategoricalTest.parse_test_file((Key)Key.make((String)"prostate.hex"), (String)"smalldata/logreg/prostate.csv");
            for (int i = 0; i < cats.length; ++i) {
                Scope.track((Vec)train.replace(cats[i], train.vec(cats[i]).toCategoricalVec()));
            }
            train.remove("ID").remove();
            DKV.put((Key)train._key, (Iced)train);
            GLRMModel.GLRMParameters parms = new GLRMModel.GLRMParameters();
            parms._train = train._key;
            parms._k = 12;
            parms._loss = GlrmLoss.Quadratic;
            parms._multi_loss = GlrmLoss.Categorical;
            parms._loss_by_col = new GlrmLoss[]{GlrmLoss.Ordinal, GlrmLoss.Poisson, GlrmLoss.Absolute};
            parms._loss_by_col_idx = new int[]{3, 1, 6};
            parms._init = GlrmInitialization.PlusPlus;
            parms._min_step_size = 1.0E-5;
            parms._recover_svd = false;
            parms._max_iterations = 2000;
            model = (GLRMModel)new GLRM(parms).trainModel().get();
            Log.info((Object[])new Object[]{"Iteration " + ((GLRMModel.GLRMOutput)model._output)._iterations + ": Objective value = " + ((GLRMModel.GLRMOutput)model._output)._objective});
            GLRMTest.checkLossbyCol(parms, model);
            model.score(train).delete();
            ModelMetricsGLRM mm = (ModelMetricsGLRM)ModelMetrics.getFromDKV((Model)model, (Frame)train);
            Log.info((Object[])new Object[]{"Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr});
        }
        finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit((Key[])new Key[0]);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testExpandCatsIris() throws InterruptedException, ExecutionException {
        double[][] iris = GLRMCategoricalTest.ard((double[][])new double[][]{GLRMCategoricalTest.ard((double[])new double[]{6.3, 2.5, 4.9, 1.5, 1.0}), GLRMCategoricalTest.ard((double[])new double[]{5.7, 2.8, 4.5, 1.3, 1.0}), GLRMCategoricalTest.ard((double[])new double[]{5.6, 2.8, 4.9, 2.0, 2.0}), GLRMCategoricalTest.ard((double[])new double[]{5.0, 3.4, 1.6, 0.4, 0.0}), GLRMCategoricalTest.ard((double[])new double[]{6.0, 2.2, 5.0, 1.5, 2.0})});
        double[][] iris_expandR = GLRMCategoricalTest.ard((double[][])new double[][]{GLRMCategoricalTest.ard((double[])new double[]{0.0, 1.0, 0.0, 6.3, 2.5, 4.9, 1.5}), GLRMCategoricalTest.ard((double[])new double[]{0.0, 1.0, 0.0, 5.7, 2.8, 4.5, 1.3}), GLRMCategoricalTest.ard((double[])new double[]{0.0, 0.0, 1.0, 5.6, 2.8, 4.9, 2.0}), GLRMCategoricalTest.ard((double[])new double[]{1.0, 0.0, 0.0, 5.0, 3.4, 1.6, 0.4}), GLRMCategoricalTest.ard((double[])new double[]{0.0, 0.0, 1.0, 6.0, 2.2, 5.0, 1.5})});
        String[] iris_cols = new String[]{"sepal_len", "sepal_wid", "petal_len", "petal_wid", "class"};
        String[][] iris_domains = new String[][]{null, null, null, null, {"setosa", "versicolor", "virginica"}};
        Frame fr = null;
        try {
            fr = GLRMCategoricalTest.parse_test_file((Key)Key.make((String)"iris.hex"), (String)"smalldata/iris/iris_wheader.csv");
            DataInfo dinfo = new DataInfo(fr, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, false, false, false, false);
            Log.info((Object[])new Object[]{"Original matrix:\n" + GLRMCategoricalTest.colFormat(iris_cols, "%8.7s") + ArrayUtils.pprint((double[][])iris)});
            double[][] iris_perm = ArrayUtils.permuteCols((double[][])iris, (int[])dinfo._permutation);
            Log.info((Object[])new Object[]{"Permuted matrix:\n" + GLRMCategoricalTest.colFormat(iris_cols, "%8.7s", dinfo._permutation) + ArrayUtils.pprint((double[][])iris_perm)});
            double[][] iris_exp = GLRM.expandCats((double[][])iris_perm, (DataInfo)dinfo);
            Log.info((Object[])new Object[]{"Expanded matrix:\n" + GLRMCategoricalTest.colExpFormat(iris_cols, iris_domains, "%8.7s", dinfo._permutation) + ArrayUtils.pprint((double[][])iris_exp)});
            Assert.assertArrayEquals((Object[])iris_expandR, (Object[])iris_exp);
        }
        finally {
            if (fr != null) {
                fr.delete();
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testExpandCatsProstate() throws InterruptedException, ExecutionException {
        double[][] prostate = GLRMCategoricalTest.ard((double[][])new double[][]{GLRMCategoricalTest.ard((double[])new double[]{0.0, 71.0, 1.0, 0.0, 0.0, 4.8, 14.0, 7.0}), GLRMCategoricalTest.ard((double[])new double[]{1.0, 70.0, 1.0, 1.0, 0.0, 8.4, 21.8, 5.0}), GLRMCategoricalTest.ard((double[])new double[]{0.0, 73.0, 1.0, 3.0, 0.0, 10.0, 27.4, 6.0}), GLRMCategoricalTest.ard((double[])new double[]{1.0, 68.0, 1.0, 0.0, 0.0, 6.7, 16.7, 6.0})});
        double[][] pros_expandR = GLRMCategoricalTest.ard((double[][])new double[][]{GLRMCategoricalTest.ard((double[])new double[]{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 71.0, 4.8, 14.0, 7.0}), GLRMCategoricalTest.ard((double[])new double[]{0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 70.0, 8.4, 21.8, 5.0}), GLRMCategoricalTest.ard((double[])new double[]{0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 73.0, 10.0, 27.4, 6.0}), GLRMCategoricalTest.ard((double[])new double[]{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 68.0, 6.7, 16.7, 6.0})});
        String[] pros_cols = new String[]{"Capsule", "Age", "Race", "Dpros", "Dcaps", "PSA", "Vol", "Gleason"};
        String[][] pros_domains = new String[][]{{"No", "Yes"}, null, {"Other", "White", "Black"}, {"None", "UniLeft", "UniRight", "Bilobar"}, {"No", "Yes"}, null, null, null};
        int[] cats = new int[]{1, 3, 4, 5};
        Frame fr = null;
        try {
            Scope.enter();
            fr = GLRMCategoricalTest.parse_test_file((Key)Key.make((String)"prostate.hex"), (String)"smalldata/logreg/prostate.csv");
            for (int i = 0; i < cats.length; ++i) {
                Scope.track((Vec)fr.replace(cats[i], fr.vec(cats[i]).toCategoricalVec()));
            }
            fr.remove("ID").remove();
            DKV.put((Key)fr._key, (Iced)fr);
            DataInfo dinfo = new DataInfo(fr, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, false, false, false, false);
            Log.info((Object[])new Object[]{"Original matrix:\n" + GLRMCategoricalTest.colFormat(pros_cols, "%8.7s") + ArrayUtils.pprint((double[][])prostate)});
            double[][] pros_perm = ArrayUtils.permuteCols((double[][])prostate, (int[])dinfo._permutation);
            Log.info((Object[])new Object[]{"Permuted matrix:\n" + GLRMCategoricalTest.colFormat(pros_cols, "%8.7s", dinfo._permutation) + ArrayUtils.pprint((double[][])pros_perm)});
            double[][] pros_exp = GLRM.expandCats((double[][])pros_perm, (DataInfo)dinfo);
            Log.info((Object[])new Object[]{"Expanded matrix:\n" + GLRMCategoricalTest.colExpFormat(pros_cols, pros_domains, "%8.7s", dinfo._permutation) + ArrayUtils.pprint((double[][])pros_exp)});
            Assert.assertArrayEquals((Object[])pros_expandR, (Object[])pros_exp);
        }
        finally {
            if (fr != null) {
                fr.delete();
            }
            Scope.exit((Key[])new Key[0]);
        }
    }
}

