package hex.glm;

import hex.CreateFrame;
import hex.FrameSplitter;
import hex.ModelMetricsMultinomial;
import hex.SplitFrame;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLMModel;
import java.util.Random;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import water.DKV;
import water.H2O;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;

/* loaded from: input_file:hex/glm/GLMBasicTestMultinomial.class */
public class GLMBasicTestMultinomial extends TestUtil {
    // Covtype data set parsed once in setup(); its last column is converted to categorical there.
    static Frame _covtype;
    // Frames fetched in setup() from the "train" and "test" keys handed to the FrameSplitter
    // (split ratio 0.8). Within this class they are only referenced by setup() and cleanUp().
    static Frame _train;
    static Frame _test;
    // Tolerance passed to GLMModel.testJavaScoring in testMultinomialPredMojoPojo.
    double _tol = 1.0E-10d;

    // JUnit rule used by testNaiveCoordinateDescent to declare the exception it expects.
    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /**
     * Loads the shared covtype frame once for the whole class, turns its last column into a
     * categorical column and splits the frame (ratio 0.8) into the "train" and "test" keys.
     */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
        _covtype = parse_test_file("smalldata/covtype/covtype.20k.data");

        // Swap the last column for its categorical version; remove() is invoked on whatever
        // replace(...) hands back (presumably the displaced vec - confirm against Frame.replace).
        int lastColumn = _covtype.numCols() - 1;
        Vec categoricalResponse = _covtype.lastVec().toCategoricalVec();
        _covtype.replace(lastColumn, categoricalResponse).remove();

        Key[] splitKeys = new Key[]{Key.make("train"), Key.make("test")};
        FrameSplitter splitter = new FrameSplitter(_covtype, new double[]{0.8d}, splitKeys, (Key) null);
        H2O.submitTask(splitter).join();

        _train = DKV.getGet(splitKeys[0]);
        _test = DKV.getGet(splitKeys[1]);
    }

    /** Deletes each shared frame that setup() managed to create; frames left null are skipped. */
    @AfterClass
    public static void cleanUp() {
        for (Frame shared : new Frame[]{_covtype, _train, _test}) {
            if (shared == null) {
                continue;
            }
            shared.delete();
        }
    }

    /**
     * Builds a random frame with a multi-class response, trains a multinomial GLM on a split
     * of it and asserts that {@code GLMModel.testJavaScoring} accepts the in-cluster
     * predictions on the held-out split within {@code _tol}.
     *
     * <p>The frame shape and the CreateFrame seed are random; they are printed so that a
     * failing run can be inspected.
     */
    @Test
    public void testMultinomialPredMojoPojo() {
        try {
            Scope.enter();
            CreateFrame cf = new CreateFrame();
            Random generator = new Random();
            // The three draws are kept in this order: rows, columns, response levels.
            int numRows = generator.nextInt(10000) + 15000 + 200;
            int numCols = generator.nextInt(17) + 3;
            int responseFactors = generator.nextInt(7) + 3;
            cf.rows = numRows;
            cf.cols = numCols;
            cf.factors = 10;
            cf.has_response = true;
            cf.response_factors = responseFactors;
            cf.positive_response = true;
            cf.missing_fraction = 0.0d;
            cf.seed = System.currentTimeMillis();
            System.out.println("Createframe parameters: rows: " + numRows + " cols:" + numCols + " response number:" + responseFactors + " seed: " + cf.seed);

            Frame generated = Scope.track((Frame) cf.execImpl().get());
            SplitFrame splitter = new SplitFrame(generated, new double[]{0.8d, 0.2d}, new Key[]{Key.make("train.hex"), Key.make("test.hex")});
            splitter.exec().get();
            Key[] splitKeys = splitter._destination_frames;
            Frame train = DKV.get(splitKeys[0]).get();
            Frame test = DKV.get(splitKeys[1]).get();
            Scope.track(train);
            Scope.track(test);

            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial, GLMModel.GLMParameters.Family.multinomial.defaultLink, new double[]{0.0d}, new double[]{0.0d}, 0.0d, 0.0d);
            params._train = train._key;
            params._lambda_search = false;
            params._response_column = "response";
            params._lambda = new double[]{0.0d};
            params._alpha = new double[]{0.001d};
            params._objective_epsilon = 1.0E-6d;
            params._beta_epsilon = 1.0E-4d;
            params._standardize = false;

            GLMModel model = new GLM(params).trainModel().get();
            Scope.track_generic(model);
            Frame predictions = model.score(test);
            Scope.track(predictions);
            Assert.assertTrue(model.testJavaScoring(test, predictions, this._tol));
        } finally {
            // Single exit point: runs exactly once on both success and failure.
            Scope.exit();
        }
    }

    /**
     * Trains multinomial GLMs without an intercept (L-BFGS, several alpha values) on covtype
     * extended with a constant "weights" column, and asserts that the last entry of every
     * per-class normalized coefficient row is exactly zero and that the training metrics equal
     * the metrics stored for scoring the original covtype frame.
     *
     * <p>The model and prediction frame currently alive are kept in {@code model} and
     * {@code preds} so that the {@code finally} block deletes them when an assertion fails.
     */
    @Test
    public void testCovtypeNoIntercept() {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
        GLMModel model = null;
        Frame preds = null;
        Vec weights = _covtype.anyVec().makeCon(1.0d);
        Key weightedKey = Key.make("cov_with_weights");
        Frame weighted = new Frame(weightedKey, _covtype.names(), _covtype.vecs());
        weighted.add("weights", weights);
        DKV.put(weighted);
        try {
            params._response_column = "C55";
            params._train = weightedKey;
            params._valid = _covtype._key;
            params._objective_epsilon = 1.0E-6d;
            params._beta_epsilon = 1.0E-4d;
            params._weights_column = "weights";
            params._missing_values_handling = DeepLearningModel.DeepLearningParameters.MissingValuesHandling.Skip;
            params._intercept = false;
            GLMModel.GLMParameters.Solver solver = GLMModel.GLMParameters.Solver.L_BFGS;
            System.out.println("solver = " + solver);
            params._solver = solver;
            params._max_iterations = 5000;
            for (double alpha : new double[]{0.0d, 0.5d, 0.1d}) {
                params._alpha = new double[]{alpha};
                model = new GLM(params).trainModel().get();
                System.out.println(model.coefficients());
                // With _intercept = false the last coefficient of each class must stay at zero.
                for (double[] classBeta : model._output.getNormBetaMultinomial()) {
                    Assert.assertEquals(0.0d, classBeta[classBeta.length - 1], 0.0d);
                }
                System.out.println(model._output._model_summary);
                System.out.println(model._output._training_metrics);
                System.out.println(model._output._validation_metrics);
                preds = model.score(_covtype);
                Assert.assertTrue(model._output._training_metrics.equals(ModelMetricsMultinomial.getFromDKV(model, _covtype)));
                model.delete();
                model = null;
                preds.delete();
                preds = null;
            }
        } finally {
            weights.remove();
            DKV.remove(weightedKey);
            if (model != null) {
                model.delete();
            }
            if (preds != null) {
                preds.delete();
            }
        }
    }

    /**
     * Trains a multinomial GLM on covtype (with a constant "weights" column) once per solver
     * (IRLSM, COORDINATE_DESCENT, L_BFGS) and asserts that training and validation metrics
     * agree, that the residual deviance stays within 10% of the reference value, and that the
     * training metrics equal the metrics stored for scoring the original covtype frame.
     *
     * <p>The model and prediction frame currently alive are kept in {@code model} and
     * {@code preds} so that the {@code finally} block deletes them when an assertion fails.
     */
    @Test
    public void testCovtypeBasic() {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
        GLMModel model = null;
        Frame preds = null;
        Vec weights = _covtype.anyVec().makeCon(1.0d);
        Key weightedKey = Key.make("cov_with_weights");
        Frame weighted = new Frame(weightedKey, _covtype.names(), _covtype.vecs());
        weighted.add("weights", weights);
        DKV.put(weighted);
        try {
            params._response_column = "C55";
            params._train = weightedKey;
            params._valid = _covtype._key;
            params._lambda = new double[]{4.881E-5d};
            params._alpha = new double[]{1.0d};
            params._objective_epsilon = 1.0E-6d;
            params._beta_epsilon = 1.0E-4d;
            params._weights_column = "weights";
            params._missing_values_handling = DeepLearningModel.DeepLearningParameters.MissingValuesHandling.Skip;
            // Parallel arrays: configuration i uses alphas[i] / lambdas[i] and must reach a
            // residual deviance of at most 1.1 * referenceDeviances[i].
            double[] alphas = {1.0d};
            double[] referenceDeviances = {25499.76d};
            double[] lambdas = {2.54475E-5d};
            GLMModel.GLMParameters.Solver[] solvers = {GLMModel.GLMParameters.Solver.IRLSM, GLMModel.GLMParameters.Solver.COORDINATE_DESCENT, GLMModel.GLMParameters.Solver.L_BFGS};
            for (GLMModel.GLMParameters.Solver solver : solvers) {
                System.out.println("solver = " + solver);
                params._solver = solver;
                params._max_iterations = params._solver == GLMModel.GLMParameters.Solver.L_BFGS ? 300 : 10;
                for (int i = 0; i < alphas.length; i++) {
                    params._alpha[0] = alphas[i];
                    params._lambda[0] = lambdas[i];
                    model = new GLM(params).trainModel().get();
                    System.out.println(model._output._model_summary);
                    System.out.println(model._output._training_metrics);
                    System.out.println(model._output._validation_metrics);
                    Assert.assertTrue(model._output._training_metrics.equals(model._output._validation_metrics));
                    Assert.assertTrue(model._output._training_metrics._resDev <= referenceDeviances[i] * 1.1d);
                    preds = model.score(_covtype);
                    Assert.assertTrue(model._output._training_metrics.equals(ModelMetricsMultinomial.getFromDKV(model, _covtype)));
                    model.delete();
                    model = null;
                    preds.delete();
                    preds = null;
                }
            }
        } finally {
            weights.remove();
            DKV.remove(weightedKey);
            if (model != null) {
                model.delete();
            }
            if (preds != null) {
                preds.delete();
            }
        }
    }

    /**
     * Trains a multinomial GLM on covtype with COORDINATE_DESCENT and
     * {@code _max_active_predictors = 50}, then asserts that the model rank does not exceed
     * {@code _max_active_predictors + nclasses}, that training and validation metrics agree,
     * that the residual deviance is at most {@code 33000 * 1.1}, and that the training metrics
     * equal the metrics stored for scoring covtype.
     *
     * <p>The model and prediction frame are released in {@code finally} so a failing assertion
     * does not leave them behind.
     */
    @Test
    public void testCovtypeMinActivePredictors() {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
        GLMModel model = null;
        Frame preds = null;
        try {
            params._response_column = "C55";
            params._train = _covtype._key;
            params._valid = _covtype._key;
            params._lambda = new double[]{4.881E-5d};
            params._alpha = new double[]{1.0d};
            params._objective_epsilon = 1.0E-6d;
            params._beta_epsilon = 1.0E-4d;
            params._max_active_predictors = 50;
            params._max_iterations = 10;
            GLMModel.GLMParameters.Solver solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            System.out.println("solver = " + solver);
            params._solver = solver;
            model = new GLM(params).trainModel().get();
            System.out.println(model._output._model_summary);
            System.out.println(model._output._training_metrics);
            System.out.println(model._output._validation_metrics);
            // The allowed rank is the predictor cap plus one extra slot per class.
            int rankLimit = params._max_active_predictors + model._output.nclasses();
            System.out.println("rank = " + model._output.rank() + ", max active preds = " + rankLimit);
            Assert.assertTrue(model._output.rank() <= rankLimit);
            Assert.assertTrue(model._output._training_metrics.equals(model._output._validation_metrics));
            Assert.assertTrue(model._output._training_metrics._resDev <= 33000.0d * 1.1d);
            preds = model.score(_covtype);
            Assert.assertTrue(model._output._training_metrics.equals(ModelMetricsMultinomial.getFromDKV(model, _covtype)));
        } finally {
            if (model != null) {
                model.delete();
            }
            if (preds != null) {
                preds.delete();
            }
        }
    }

    /**
     * Trains a multinomial GLM on covtype with lambda search enabled (3 lambdas, AUTO solver)
     * and asserts that training and validation metrics agree, that the training metrics equal
     * the metrics stored for scoring covtype, and that the residual deviance is at most 33000.
     *
     * <p>The model and prediction frame are released in {@code finally} so a failing assertion
     * does not leave them behind.
     */
    @Test
    public void testCovtypeLS() {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
        GLMModel model = null;
        Frame preds = null;
        try {
            params._nlambdas = 3;
            params._response_column = "C55";
            params._train = _covtype._key;
            params._valid = _covtype._key;
            params._alpha = new double[]{0.99d};
            params._objective_epsilon = 1.0E-6d;
            params._beta_epsilon = 1.0E-4d;
            params._max_active_predictors = 50;
            params._max_iterations = 500;
            params._solver = GLMModel.GLMParameters.Solver.AUTO;
            params._lambda_search = true;
            model = new GLM(params).trainModel().get();
            System.out.println(model._output._training_metrics);
            System.out.println(model._output._validation_metrics);
            Assert.assertTrue(model._output._training_metrics.equals(model._output._validation_metrics));
            preds = model.score(_covtype);
            Assert.assertTrue(model._output._training_metrics.equals(ModelMetricsMultinomial.getFromDKV(model, _covtype)));
            Assert.assertTrue(model._output._training_metrics._resDev <= 33000.0d);
            System.out.println(model._output._model_summary);
        } finally {
            if (model != null) {
                model.delete();
            }
            if (preds != null) {
                preds.delete();
            }
        }
    }

    /**
     * Checks multinomial GLM training when the response contains missing values and
     * {@code MissingValuesHandling.Skip} is used.
     *
     * <p>Three response cells (rows 10, 20 and 30 of column index 54) of a deep copy of covtype
     * are set to NA. A model is trained on the full copy and another on a five-column subset
     * (C51..C55); for both, the null degrees of freedom must equal {@code numRows - 3 - 1} and
     * training and validation metrics must agree. Residual deviance bounds are 26000 and 66000
     * respectively.
     *
     * <p>All frames, models and predictions are released in {@code finally} so a failing
     * assertion does not leave them behind.
     */
    @Test
    public void testCovtypeNAs() {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
        GLMModel model = null;
        Frame preds = null;
        Frame subset = null;
        Frame copy = null;
        try {
            copy = _covtype.deepCopy("covtype_copy");
            DKV.put(copy);
            Vec.Writer writer = copy.vec(54).open();
            writer.setNA(10L);
            writer.setNA(20L);
            writer.setNA(30L);
            writer.close();
            // The subset frame is built from vecs of the copy, so it carries the same NAs.
            subset = new Frame(Key.make("covtype_subset"), new String[]{"C51", "C52", "C53", "C54", "C55"}, copy.vecs(new int[]{50, 51, 52, 53, 54}));
            DKV.put(subset);

            params._response_column = "C55";
            params._train = copy._key;
            params._valid = copy._key;
            params._alpha = new double[]{0.99d};
            params._objective_epsilon = 1.0E-6d;
            params._beta_epsilon = 1.0E-4d;
            params._max_active_predictors = 50;
            params._max_iterations = 500;
            params._solver = GLMModel.GLMParameters.Solver.L_BFGS;
            params._missing_values_handling = DeepLearningModel.DeepLearningParameters.MissingValuesHandling.Skip;

            // First model: all columns of the copy.
            model = new GLM(params).trainModel().get();
            Assert.assertEquals((copy.numRows() - 3) - 1, model._nullDOF);
            System.out.println(model._output._training_metrics);
            System.out.println(model._output._validation_metrics);
            Assert.assertTrue(model._output._training_metrics.equals(model._output._validation_metrics));
            preds = model.score(copy);
            Assert.assertTrue(model._output._training_metrics.equals(ModelMetricsMultinomial.getFromDKV(model, copy)));
            Assert.assertTrue(model._output._training_metrics._resDev <= 26000.0d);
            System.out.println(model._output._model_summary);
            model.delete();
            model = null;
            preds.delete();
            preds = null;

            // Second model: trained on the subset, still validated against the full copy.
            params._train = subset._key;
            model = new GLM(params).trainModel().get();
            Assert.assertEquals((copy.numRows() - 3) - 1, model._nullDOF);
            System.out.println(model._output._training_metrics);
            System.out.println(model._output._validation_metrics);
            Assert.assertTrue(model._output._training_metrics.equals(model._output._validation_metrics));
            preds = model.score(_covtype);
            System.out.println(model._output._model_summary);
            Assert.assertTrue(model._output._training_metrics._resDev <= 66000.0d);
            model.delete();
            model = null;
            preds.delete();
            preds = null;
        } finally {
            if (subset != null) {
                subset.delete();
            }
            if (copy != null) {
                copy.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (preds != null) {
                preds.delete();
            }
        }
    }

    /**
     * Expects that constructing a GLM for the multinomial family with the
     * COORDINATE_DESCENT_NAIVE solver throws an {@code H2OIllegalArgumentException}
     * carrying the "not supported for multinomial" message.
     */
    @Test
    public void testNaiveCoordinateDescent() {
        expectedException.expect(H2OIllegalArgumentException.class);
        expectedException.expectMessage("Naive coordinate descent is not supported for multinomial.");

        GLMModel.GLMParameters.Family family = GLMModel.GLMParameters.Family.multinomial;
        GLMModel.GLMParameters multinomialParams = new GLMModel.GLMParameters(family);
        multinomialParams._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;
        new GLM(multinomialParams);
    }

    /**
     * Constructs a GLM with the COORDINATE_DESCENT_NAIVE solver for each listed family/link
     * pair. The test has no assertions: it passes when none of the constructions throws.
     */
    @Test
    public void testNaiveCoordinateDescent_families() {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
        params._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;

        // Parallel arrays: links[i] is the link used together with families[i].
        GLMModel.GLMParameters.Family[] families = {
            GLMModel.GLMParameters.Family.binomial,
            GLMModel.GLMParameters.Family.gaussian,
            GLMModel.GLMParameters.Family.gamma,
            GLMModel.GLMParameters.Family.tweedie,
            GLMModel.GLMParameters.Family.poisson,
            GLMModel.GLMParameters.Family.ordinal,
            GLMModel.GLMParameters.Family.quasibinomial
        };
        GLMModel.GLMParameters.Link[] links = {
            GLMModel.GLMParameters.Link.logit,
            GLMModel.GLMParameters.Link.identity,
            GLMModel.GLMParameters.Link.log,
            GLMModel.GLMParameters.Link.tweedie,
            GLMModel.GLMParameters.Link.log,
            GLMModel.GLMParameters.Link.ologit,
            GLMModel.GLMParameters.Link.logit
        };

        int position = 0;
        for (GLMModel.GLMParameters.Family family : families) {
            params._family = family;
            params._link = links[position++];
            new GLM(params);
        }
    }
}
