/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.StringPair;
import hex.glm.GLM;
import hex.glm.GLMModel;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.util.ArrayUtils;

public class GLMPlugValuesTest
extends TestUtil {
    /**
     * Runs once before the tests in this class and asks the inherited
     * {@code stall_till_cloudsize} helper for a cloud of size 1.
     */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains one GLM on a frame with numeric NAs using {@code PlugValues} handling and a
     * second GLM on a frame where those cells already contain the plug values, then
     * asserts that both models report equal coefficients.
     */
    @Test
    public void testNumeric() {
        Scope.enter();
        try {
            // "x" and "y" each contain one NaN; "z" is the response.
            Frame frameWithNAs = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{1.0, Double.NaN}))
                    .withDataForCol(1, ard(new double[]{Double.NaN, 2.0}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0}))
                    .build();
            // Same data with the NaNs already replaced by the plug values (x=4.0, y=0.5).
            Frame frameFilled = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{1.0, 4.0}))
                    .withDataForCol(1, ard(new double[]{0.5, 2.0}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0}))
                    .build();
            Frame plugFrame = oneRowFrame(new String[]{"x", "y"}, new double[]{4.0, 0.5}, new String[0]);

            // Settings shared by both models.
            GLMModel.GLMParameters plugParams = new GLMModel.GLMParameters();
            plugParams._response_column = "z";
            plugParams._family = GLMModel.GLMParameters.Family.gaussian;
            plugParams._standardize = false;
            plugParams._train = frameWithNAs._key;
            plugParams._ignore_const_cols = false;
            plugParams._intercept = false;
            plugParams._seed = 42L;

            // Clone before the plug-value settings are applied so the reference model
            // keeps the default missing-values handling.
            GLMModel.GLMParameters referenceParams = (GLMModel.GLMParameters) plugParams.clone();
            referenceParams._train = frameFilled._key;

            plugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            plugParams._plug_values = plugFrame._key;

            GLMModel plugModel = (GLMModel) new GLM(plugParams).trainModel().get();
            Scope.track_generic(plugModel);
            GLMModel referenceModel = (GLMModel) new GLM(referenceParams).trainModel().get();
            Scope.track_generic(referenceModel);

            Assert.assertEquals(referenceModel.coefficients(), plugModel.coefficients());
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains one GLM on a frame whose string-valued column "x" has an NA, using
     * {@code PlugValues} handling with plug value "b", and a second GLM on the untouched
     * frame where that cell is "b"; asserts that both models report equal coefficients.
     */
    @Test
    public void testCategorical() {
        Scope.enter();
        try {
            // Type codes 4 and 3: column "x" is fed strings, column "y" doubles.
            Frame completeFrame = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y"})
                    .withVecTypes(new byte[]{4, 3})
                    .withDataForCol(0, ar(new String[]{"a", "b"}))
                    .withDataForCol(1, ard(new double[]{1.0, 2.0}))
                    .build();

            // Deep copy under a fresh key, then blank out row 1 of "x" (originally "b").
            Frame frameWithNA = completeFrame.deepCopy(Key.make().toString());
            DKV.put(frameWithNA);
            Scope.track(new Frame[]{frameWithNA});
            frameWithNA.vec(0).setNA(1L);

            Frame plugFrame = oneRowFrame(new String[]{"x"}, new double[0], "b");

            // Settings shared by both models.
            GLMModel.GLMParameters plugParams = new GLMModel.GLMParameters();
            plugParams._response_column = "y";
            plugParams._family = GLMModel.GLMParameters.Family.gaussian;
            plugParams._standardize = false;
            plugParams._train = frameWithNA._key;
            plugParams._ignore_const_cols = false;
            plugParams._intercept = false;
            plugParams._seed = 42L;

            // Clone before the plug-value settings are applied.
            GLMModel.GLMParameters referenceParams = (GLMModel.GLMParameters) plugParams.clone();
            referenceParams._train = completeFrame._key;

            plugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            plugParams._plug_values = plugFrame._key;

            GLMModel plugModel = (GLMModel) new GLM(plugParams).trainModel().get();
            Scope.track_generic(plugModel);
            GLMModel referenceModel = (GLMModel) new GLM(referenceParams).trainModel().get();
            Scope.track_generic(referenceModel);

            Assert.assertEquals(referenceModel.coefficients(), plugModel.coefficients());
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains a GLM with an ("x", "y") interaction pair and {@code PlugValues} handling on
     * a frame where "y" has a NaN, and a second GLM on a frame that already contains the
     * plugged "y" value plus an explicit "x_y" column. Asserts that the first model has a
     * non-zero "x_y" coefficient and that both models report equal coefficients.
     */
    @Test
    public void testNumericInteraction() {
        Scope.enter();
        try {
            // "y" contains one NaN; "z" is the response.
            Frame fr = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{1.0, 4.0}))
                    .withDataForCol(1, ard(new double[]{Double.NaN, 2.0}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0}))
                    .build();
            // Reference frame: NaN replaced by 0.5 and "x_y" = x * y written out by hand.
            Frame fr2 = new TestFrameBuilder()
                    .withColNames(new String[]{"x_y", "x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{0.5, 8.0}))
                    .withDataForCol(1, ard(new double[]{1.0, 4.0}))
                    .withDataForCol(2, ard(new double[]{0.5, 2.0}))
                    .withDataForCol(3, ard(new double[]{2.0, 8.0}))
                    .build();
            Frame plugValues = oneRowFrame(new String[]{"x_y", "x", "y"}, new double[]{0.5, 4.0, 0.5}, new String[0]);

            // Settings shared by both models.
            GLMModel.GLMParameters params = new GLMModel.GLMParameters();
            params._response_column = "z";
            params._family = GLMModel.GLMParameters.Family.gaussian;
            params._standardize = false;
            params._train = fr._key;
            params._ignore_const_cols = false;
            params._intercept = false;
            params._seed = 42L;

            // Clone before plug values and interaction pairs are configured: the
            // reference model trains on the hand-built frame with default settings.
            GLMModel.GLMParameters params2 = (GLMModel.GLMParameters) params.clone();
            params2._train = fr2._key;

            params._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            params._plug_values = plugValues._key;
            params._interaction_pairs = new StringPair[]{new StringPair("x", "y")};

            GLMModel model = (GLMModel) new GLM(params).trainModel().get();
            Scope.track_generic(model);
            GLMModel model2 = (GLMModel) new GLM(params2).trainModel().get();
            Scope.track_generic(model2);

            // The previous check compared a boxed Integer 0 with the coefficient through
            // assertNotEquals(Object, Object); Integer.equals is false for any
            // non-Integer argument, so it could never fail. Compare as doubles instead.
            Double interactionCoef = model.coefficients().get("x_y");
            Assert.assertNotNull(interactionCoef);
            Assert.assertNotEquals(0.0, interactionCoef.doubleValue(), 0.0);

            Assert.assertEquals(model2.coefficients(), model.coefficients());
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Smoke test: trains a GLM with {@code PlugValues} handling and an interaction pair
     * of the two string-valued columns "x" and "y". There are no assertions; the test
     * only requires that training finishes without throwing.
     */
    @Test
    public void testCatCatInteraction_smoke() {
        Scope.enter();
        try {
            // Type codes 3, 4, 4, 3: "n" and "z" get numeric data, "x" and "y" strings.
            Frame trainingFrame = new TestFrameBuilder()
                    .withColNames(new String[]{"n", "x", "y", "z"})
                    .withVecTypes(new byte[]{3, 4, 4, 3})
                    .withDataForCol(0, ar(new long[]{0L, 1L, 0L, 1L}))
                    .withDataForCol(1, ar(new String[]{"a", "b", "a", "b"}))
                    .withDataForCol(2, ar(new String[]{"A", "B", "B", "A"}))
                    .withDataForCol(3, ard(new double[]{2.0, 8.0, 4.0, 1.0}))
                    .build();
            // Plug values for "n", the "x_y" interaction column, "x" and "y".
            Frame plugFrame = oneRowFrame(new String[]{"n", "x_y", "x", "y"}, new double[]{0.0}, "a_A", "a", "B");

            GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
            glmParams._response_column = "z";
            glmParams._family = GLMModel.GLMParameters.Family.gaussian;
            glmParams._standardize = false;
            glmParams._train = trainingFrame._key;
            glmParams._ignore_const_cols = false;
            glmParams._intercept = false;
            glmParams._seed = 42L;
            glmParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            glmParams._plug_values = plugFrame._key;
            glmParams._interaction_pairs = new StringPair[]{new StringPair("x", "y")};

            GLMModel trained = (GLMModel) new GLM(glmParams).trainModel().get();
            Scope.track_generic(trained);
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Smoke test: trains a GLM with {@code PlugValues} handling and an interaction pair
     * of the numeric column "x" (which contains a NaN) and the string-valued column "y".
     * There are no assertions; the test only requires that training finishes without
     * throwing.
     */
    @Test
    public void testNumCatInteraction_smoke() {
        Scope.enter();
        try {
            // Type codes 3, 4, 3: "x" and "z" get doubles, "y" gets strings.
            Frame trainingFrame = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withVecTypes(new byte[]{3, 4, 3})
                    .withDataForCol(0, ard(new double[]{0.0, Double.NaN, 0.0, 1.0}))
                    .withDataForCol(1, ar(new String[]{"a", "b", "a", "b"}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0, 4.0, 1.0}))
                    .build();
            // Plug values for "x", the "x_y.a" / "x_y.b" interaction columns, and "y".
            Frame plugFrame = oneRowFrame(new String[]{"x", "x_y.a", "x_y.b", "y"}, new double[]{0.0, 1.0, 2.0}, "b");

            GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
            glmParams._response_column = "z";
            glmParams._family = GLMModel.GLMParameters.Family.gaussian;
            glmParams._standardize = false;
            glmParams._train = trainingFrame._key;
            glmParams._ignore_const_cols = false;
            glmParams._intercept = false;
            glmParams._seed = 42L;
            glmParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            glmParams._plug_values = plugFrame._key;
            glmParams._interaction_pairs = new StringPair[]{new StringPair("x", "y")};

            GLMModel trained = (GLMModel) new GLM(glmParams).trainModel().get();
            Scope.track_generic(trained);
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains three poisson GLMs on the cars data set (response "power (hp)"):
     * one with {@code MeanImputation}, one with {@code PlugValues} set to the predictor
     * means, and one with {@code PlugValues} set to all zeros. Asserts that the first two
     * have identical betas and NA-fill values, that the zero-plug model's NA-fill values
     * are all zero, and that its "economy (mpg)" coefficient differs from the
     * mean-imputed model's.
     */
    @Test
    public void testPlugValues_zeros() {
        Scope.enter();
        try {
            Frame cars = parse_test_file("smalldata/junit/cars.csv");
            Scope.track(new Frame[]{cars});
            cars.remove("name");
            DKV.put(cars);
            // The comparison below is only meaningful if this column really has NAs.
            Assert.assertTrue(cars.vec("economy (mpg)").naCnt() > 0L);

            GLMModel.GLMParameters.Family poisson = GLMModel.GLMParameters.Family.poisson;
            GLMModel.GLMParameters meanImputeParams = new GLMModel.GLMParameters(
                    poisson, poisson.defaultLink, new double[]{0.0}, new double[]{0.0}, 0.0, 0.0);
            meanImputeParams._response_column = "power (hp)";
            meanImputeParams._train = cars._key;
            meanImputeParams._lambda = new double[]{0.0};
            meanImputeParams._alpha = new double[]{0.0};
            meanImputeParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
            meanImputeParams._seed = 42L;

            // Both plug-value variants start from the same settings.
            GLMModel.GLMParameters plugMeansParams = (GLMModel.GLMParameters) meanImputeParams.clone();
            GLMModel.GLMParameters plugZerosParams = (GLMModel.GLMParameters) meanImputeParams.clone();

            GLMModel meanImputeModel = (GLMModel) new GLM(meanImputeParams).trainModel().get();
            Scope.track_generic(meanImputeModel);

            // Predictor view: every column except the response.
            Frame predictorCols = (Frame) cars.clone();
            predictorCols.remove(meanImputeParams._response_column);

            // Variant 1: plug values equal to the column means.
            Frame meansPlugFrame = oneRowFrame(predictorCols.names(), predictorCols.means(), new String[0]);
            plugMeansParams._plug_values = meansPlugFrame._key;
            plugMeansParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            GLMModel plugMeansModel = (GLMModel) new GLM(plugMeansParams).trainModel().get();
            Scope.track_generic(plugMeansModel);
            Assert.assertArrayEquals(meanImputeModel.beta(), plugMeansModel.beta(), 0.0);
            Assert.assertArrayEquals(meanImputeModel.dinfo()._numNAFill, plugMeansModel.dinfo()._numNAFill, 0.0);

            // Variant 2: plug values that are all zero.
            Frame zerosPlugFrame = oneRowFrame(predictorCols.names(), new double[predictorCols.numCols()], new String[0]);
            plugZerosParams._plug_values = zerosPlugFrame._key;
            plugZerosParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            GLMModel plugZerosModel = (GLMModel) new GLM(plugZerosParams).trainModel().get();
            Scope.track_generic(plugZerosModel);
            Assert.assertArrayEquals(plugZerosModel.dinfo()._numNAFill, new double[predictorCols.numCols()], 0.0);
            Assert.assertNotEquals(
                    meanImputeModel.coefficients().get("economy (mpg)"),
                    plugZerosModel.coefficients().get("economy (mpg)"));
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Builds a frame with exactly one row.
     *
     * <p>The first {@code values.length} columns are declared with type code 3 and hold
     * the corresponding entry of {@code values}; the following {@code svalues.length}
     * columns are declared with type code 4 and hold the corresponding entry of
     * {@code svalues}.
     *
     * @param names   column names, passed to the builder as-is; callers in this file
     *                supply numeric columns first, then string columns
     * @param values  one numeric cell per leading column
     * @param svalues one string cell per trailing column
     * @return the frame produced by {@code TestFrameBuilder.build()}
     */
    private static Frame oneRowFrame(String[] names, double[] values, String ... svalues) {
        final int numericCount = values.length;
        final int stringCount = svalues.length;

        byte[] numericTypeCodes = new byte[numericCount];
        Arrays.fill(numericTypeCodes, (byte) 3);
        byte[] stringTypeCodes = new byte[stringCount];
        Arrays.fill(stringTypeCodes, (byte) 4);

        TestFrameBuilder frameBuilder = new TestFrameBuilder().withColNames(names);
        frameBuilder.withVecTypes(ArrayUtils.append(numericTypeCodes, stringTypeCodes));

        for (int col = 0; col < numericCount; col++) {
            frameBuilder.withDataForCol(col, new double[]{values[col]});
        }
        // String columns are placed directly after the numeric ones.
        for (int offset = 0; offset < stringCount; offset++) {
            frameBuilder.withDataForCol(numericCount + offset, new String[]{svalues[offset]});
        }
        return frameBuilder.build();
    }
}

