package hex.glm;

import hex.AUC2;
import hex.DataInfo;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsBinomialGLM;
import hex.ModelMetricsRegressionGLM;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.glm.GLMTask;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.TestUtil;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.FVecTest;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.parser.ParseDataset;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/glm/GLMTest.class */
public class GLMTest extends TestUtil {
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/glm/GLMTest$GLMIterationTaskTest.class */
    private static final class GLMIterationTaskTest extends GLMTask.GLMIterationTask {
        final GLMModel _m;
        GLMMetricBuilder _val2;

        public GLMIterationTaskTest(Key key, DataInfo dataInfo, double d, GLMModel.GLMParameters gLMParameters, boolean z, double[] dArr, double d2, GLMModel gLMModel) {
            super(key, dataInfo, new GLMModel.GLMWeightsFun(gLMParameters), dArr);
            this._m = gLMModel;
        }

        public void map(Chunk[] chunkArr) {
            super.map(chunkArr);
            this._val2 = this._m.makeMetricBuilder(chunkArr[chunkArr.length - 1].vec().domain());
            double[] dArr = new double[3];
            float[] fArr = new float[1];
            for (int i = 0; i < chunkArr[0]._len; i++) {
                this._m.score0(chunkArr, i, (double[]) null, dArr);
                fArr[0] = (float) chunkArr[chunkArr.length - 1].atd(i);
                this._val2.perRow(dArr, fArr, this._m);
            }
        }

        public void reduce(GLMTask.GLMIterationTask gLMIterationTask) {
            super.reduce(gLMIterationTask);
            this._val2.reduce(((GLMIterationTaskTest) gLMIterationTask)._val2);
        }
    }

    /* loaded from: input_file:hex/glm/GLMTest$TestScore0.class */
    public static class TestScore0 extends MRTask {
        final GLMModel _m;
        final boolean _weights;
        final boolean _offset;

        public TestScore0(GLMModel gLMModel, boolean z, boolean z2) {
            this._m = gLMModel;
            this._weights = z;
            this._offset = z2;
        }

        private void checkScore(long j, double[] dArr, double[] dArr2) {
            int i = 0;
            if (this._m._parms._family == GLMModel.GLMParameters.Family.binomial && Math.abs(dArr[2] - this._m.defaultThreshold()) < 1.0E-10d) {
                i = 1;
            }
            if (this._m._parms._family == GLMModel.GLMParameters.Family.multinomial) {
                double[] dArr3 = new double[2];
                for (int i2 = 1; i2 < dArr.length; i2++) {
                    if (dArr[i2] > dArr3[0]) {
                        if (dArr[i2] > dArr3[1]) {
                            dArr3[0] = dArr3[1];
                            dArr3[1] = dArr[i2];
                        } else {
                            dArr3[0] = dArr[i2];
                        }
                    }
                }
                if (dArr3[1] - dArr3[0] < 1.0E-10d) {
                    i = 1;
                }
            }
            for (int i3 = i; i3 < dArr.length; i3++) {
                Assert.assertEquals("mismatch at row " + j + ", p = " + i3 + ": " + dArr2[i3] + " != " + dArr[i3] + ", predictions = " + Arrays.toString(dArr) + ", output = " + Arrays.toString(dArr2), dArr2[i3], dArr[i3], 1.0E-6d);
            }
        }

        public void map(Chunk[] chunkArr) {
            int nclasses = this._m._parms._family == GLMModel.GLMParameters.Family.multinomial ? this._m._output.nclasses() + 1 : this._m._parms._family == GLMModel.GLMParameters.Family.binomial ? 3 : 1;
            Chunk[] chunkArr2 = (Chunk[]) Arrays.copyOfRange(chunkArr, chunkArr.length - nclasses, chunkArr.length);
            Chunk[] chunkArr3 = (Chunk[]) Arrays.copyOf(chunkArr, chunkArr.length - nclasses);
            Chunk c0DChunk = new C0DChunk(0.0d, chunkArr3[0]._len);
            double[] dArr = new double[this._m._output._dinfo._cats + this._m._output._dinfo._nums];
            double[] dArr2 = new double[nclasses];
            double[] dArr3 = new double[nclasses];
            if (this._offset) {
                c0DChunk = chunkArr3[chunkArr3.length - 1];
                chunkArr3 = (Chunk[]) Arrays.copyOf(chunkArr3, chunkArr3.length - 1);
            }
            if (this._weights) {
                chunkArr3 = (Chunk[]) Arrays.copyOf(chunkArr3, chunkArr3.length - 1);
            }
            for (int i = 0; i < chunkArr3[0]._len; i++) {
                if (this._weights || this._offset) {
                    this._m.score0(chunkArr3, c0DChunk.atd(i), i, dArr, dArr2);
                } else {
                    this._m.score0(chunkArr3, i, dArr, dArr2);
                }
                for (int i2 = 0; i2 < dArr2.length; i2++) {
                    dArr3[i2] = chunkArr2[i2].atd(i);
                }
                checkScore(i + chunkArr3[0].start(), dArr2, dArr3);
            }
        }
    }

    /** Calls {@code stall_till_cloudsize(1)} once, before any test in this class runs. */
    @BeforeClass
    public static void setup() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Scores {@code frame} with {@code gLMModel} and cross-checks the predictions.
     *
     * <ol>
     *   <li>Bulk predictions from {@code gLMModel.score} are compared row-by-row against
     *       {@code score0} via {@link TestScore0}.</li>
     *   <li>When the model uses neither weights nor offset, {@code gLMModel.testJavaScoring} must
     *       also succeed (tolerance 1e-15).</li>
     * </ol>
     *
     * <p>All scoped resources are released even when an assertion fails.
     */
    public static void testScoring(GLMModel gLMModel, Frame frame) {
        Scope.enter();
        try {
            // New Frame wrapper so dropping the response does not alter the caller's frame object.
            Frame input = new Frame(frame);
            input.remove(gLMModel._output.responseName());
            Frame predictions = Scope.track(gLMModel.score(input));
            gLMModel.adaptTestForTrain(input, true, false);
            // NOTE(review): drops the last column after adaptation -- presumably a response
            // placeholder re-added by adaptTestForTrain; confirm.
            input.remove(input.numCols() - 1);
            final DataInfo dinfo = gLMModel._output._dinfo;
            int expectedCols = dinfo._cats + dinfo._nums;
            int actualCols = input.numCols() - (dinfo._weights ? 1 : 0) - (dinfo._offset ? 1 : 0);
            if (!$assertionsDisabled && expectedCols != actualCols) {
                throw new AssertionError(expectedCols + " != " + actualCols);
            }
            input.add(predictions.names(), predictions.vecs());
            new TestScore0(gLMModel, dinfo._weights, dinfo._offset).doAll(input);
            if (!dinfo._weights && !dinfo._offset) {
                Assert.assertTrue(gLMModel.testJavaScoring(frame, predictions, 1.0E-15d));
            }
        } finally {
            // Previously skipped on assertion failure, leaking everything tracked in this scope.
            Scope.exit();
        }
    }

    /** Runs {@link #testCoeffs} for the multinomial, binomial and gaussian families on their datasets. */
    @Test
    public void testStandardizedCoeff() {
        final GLMModel.GLMParameters.Family[] families = {
                GLMModel.GLMParameters.Family.multinomial,
                GLMModel.GLMParameters.Family.binomial,
                GLMModel.GLMParameters.Family.gaussian
        };
        final String[] datasets = {
                "smalldata/glm_test/multinomial_10_classes_10_cols_10000_Rows_train.csv",
                "smalldata/glm_test/binomial_20_cols_10KRows.csv",
                "smalldata/glm_test/gaussian_20cols_10000Rows.csv"
        };
        final String[] responses = {"C11", "C21", "C21"};
        for (int k = 0; k < families.length; k++) {
            testCoeffs(families[k], datasets[k], responses[k]);
        }
    }

    /**
     * Verifies GLM coefficient standardization for {@code family}.
     *
     * <p>Data comes from the CSV at {@code str}, with response column {@code str2}. The first
     * {@code (numCols - 1) / 2} columns are converted to categoricals, and the remaining predictors
     * stay numeric. For binomial and multinomial the last column is also made categorical.
     * Three models are trained:
     * <ol>
     *   <li>{@code _standardize = true} on the frame;</li>
     *   <li>{@code _standardize = false} on the same frame. For every coefficient whose
     *       {@code coefficients(true)} value differs from {@code coefficients()}, the standardized
     *       value must equal raw * sigma. Each intercept must equal the raw intercept plus the sum
     *       of raw * mean.</li>
     *   <li>{@code _standardize = false} after the numeric columns were transformed with
     *       {@code TestUtil.StandardizeColumns} (means and 1/sigma). Its normalized beta and
     *       {@code coefficients(true)} must match model 1.</li>
     * </ol>
     *
     * <p>Checks are guarded by {@code $assertionsDisabled}, so they only fire with {@code -ea}.
     */
    public void testCoeffs(GLMModel.GLMParameters.Family family, String str, String str2) {
        final double tol = 1.0E-6d;
        Scope.enter();
        try {
            Frame train = parse_test_file(str);
            final int numCols = train.numCols();
            final int firstNumeric = (numCols - 1) / 2;
            final int responseIdx = numCols - 1;
            for (int c = 0; c < firstNumeric; c++) {
                train.replace(c, train.vec(c).toCategoricalVec()).remove();
            }
            if (family == GLMModel.GLMParameters.Family.binomial
                    || family == GLMModel.GLMParameters.Family.multinomial) {
                train.replace(responseIdx, train.vec(responseIdx).toCategoricalVec()).remove();
            }
            DKV.put(train);
            Scope.track(train);

            GLMModel.GLMParameters params = new GLMModel.GLMParameters(family);
            params._standardize = true;
            params._response_column = str2;
            params._train = train._key;
            GLMModel standardizedModel = new GLM(params).trainModel().get();
            Scope.track_generic(standardizedModel);

            // Sized by the actual number of numeric predictors. (numCols - 1) / 2 undercounts them
            // when the predictor count is odd, which overflowed these arrays.
            final int numNumeric = responseIdx - firstNumeric;
            int[] numericIdx = new int[numNumeric];
            double[] means = new double[numNumeric];
            double[] invSigmas = new double[numNumeric];
            HashMap<String, Double> meanByName = new HashMap<>();
            HashMap<String, Double> sigmaByName = new HashMap<>();
            String[] names = train.names();
            for (int k = 0; k < numNumeric; k++) {
                final int col = firstNumeric + k;
                final double sigma = train.vec(col).sigma();
                numericIdx[k] = col;
                means[k] = train.vec(col).mean();
                invSigmas[k] = 1.0d / sigma;
                meanByName.put(names[col], means[k]);
                sigmaByName.put(names[col], sigma);
            }

            params._standardize = false;
            GLMModel rawModel = new GLM(params).trainModel().get();
            Scope.track_generic(rawModel);
            HashMap<String, Double> stdCoefs = rawModel.coefficients(true);
            HashMap<String, Double> rawCoefs = rawModel.coefficients();

            if (family == GLMModel.GLMParameters.Family.multinomial) {
                double[] interceptShift = new double[rawModel._output._nclasses];
                for (String name : stdCoefs.keySet()) {
                    final double stdCoef = stdCoefs.get(name);
                    final double coef = rawCoefs.get(name);
                    if (Math.abs(stdCoef - coef) > tol) {
                        // Multinomial coefficient names look like "<classIdx>_<column>".
                        String[] parts = name.split("_", 2);
                        if (!parts[1].equals("Intercept")) {
                            final String col = parts[1];
                            final int cls = Integer.parseInt(parts[0]);
                            interceptShift[cls] += coef * meanByName.get(col);
                            final double expected = coef * sigmaByName.get(col);
                            if (!$assertionsDisabled && Math.abs(stdCoef - expected) >= tol) {
                                throw new AssertionError("Expected coefficients for " + col + " is " + stdCoef + " but actual " + expected);
                            }
                        }
                    }
                }
                for (int cls = 0; cls < rawModel._output._nclasses; cls++) {
                    final String name = cls + "_Intercept";
                    final double stdIntercept = stdCoefs.get(name);
                    final double expected = rawCoefs.get(name) + interceptShift[cls];
                    if (!$assertionsDisabled && Math.abs(stdIntercept - expected) >= tol) {
                        throw new AssertionError("Expected coefficients for " + name + " is " + stdIntercept + " but actual " + expected);
                    }
                }
            } else {
                double interceptShift = 0.0d;
                for (String name : rawCoefs.keySet()) {
                    final double stdCoef = stdCoefs.get(name);
                    final double coef = rawCoefs.get(name);
                    if (Math.abs(stdCoef - coef) > tol && !name.equals("Intercept")) {
                        interceptShift += coef * meanByName.get(name);
                        final double expected = coef * sigmaByName.get(name);
                        if (!$assertionsDisabled && Math.abs(stdCoef - expected) >= tol) {
                            throw new AssertionError("Expected coefficients for " + name + " is " + stdCoef + " but actual " + expected);
                        }
                    }
                }
                final double stdIntercept = stdCoefs.get("Intercept");
                final double expected = rawCoefs.get("Intercept") + interceptShift;
                if (!$assertionsDisabled && Math.abs(stdIntercept - expected) >= tol) {
                    throw new AssertionError("Expected coefficients for Intercept is " + stdIntercept + " but actual " + expected);
                }
            }

            // Standardize the numeric columns ourselves and retrain without GLM's own standardization.
            new TestUtil.StandardizeColumns(numericIdx, means, invSigmas, train).doAll(train);
            DKV.put(train);
            Scope.track(train);
            params._standardize = false;
            params._train = train._key;
            GLMModel manualModel = new GLM(params).trainModel().get();
            Scope.track_generic(manualModel);

            if (family == GLMModel.GLMParameters.Family.multinomial) {
                double[][] expectedBeta = standardizedModel._output.getNormBetaMultinomial();
                double[][] actualBeta = manualModel._output.getNormBetaMultinomial();
                for (int cls = 0; cls < expectedBeta.length; cls++) {
                    if (!$assertionsDisabled && !TestUtil.equalTwoArrays(expectedBeta[cls], actualBeta[cls], tol)) {
                        throw new AssertionError();
                    }
                }
            } else if (!$assertionsDisabled
                    && !TestUtil.equalTwoArrays(standardizedModel._output.getNormBeta(), manualModel._output.getNormBeta(), tol)) {
                throw new AssertionError();
            }
            HashMap<String, Double> expectedStd = standardizedModel.coefficients(true);
            HashMap<String, Double> actualStd = manualModel.coefficients(true);
            if (!$assertionsDisabled && !TestUtil.equalTwoHashMaps(expectedStd, actualStd, tol)) {
                throw new AssertionError();
            }
        } finally {
            Scope.exit();
        }
    }

    /** Fits an unregularized gaussian GLM on y = 0.1 * x and expects intercept 0 and slope 0.1. */
    @Test
    public void testGaussianRegression() throws InterruptedException, ExecutionException {
        Key raw = Key.make("gaussian_test_data_raw");
        Key parsed = Key.make("gaussian_test_data_parsed");
        Frame frame = null;
        GLMModel model = null;
        try {
            FVecTest.makeByteVec(raw, new String[]{"x,y\n0,0\n1,0.1\n2,0.2\n3,0.3\n4,0.4\n5,0.5\n6,0.6\n7,0.7\n8,0.8\n9,0.9"});
            frame = ParseDataset.parse(parsed, new Key[]{raw});
            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            params._train = frame._key;
            params._response_column = frame._names[1];
            params._lambda = new double[]{0.0d};
            model = new GLM(params).trainModel().get();
            HashMap<String, Double> coefs = model.coefficients();
            Assert.assertEquals(0.0d, coefs.get("Intercept").doubleValue(), 1.0E-4d);
            Assert.assertEquals(0.1d, coefs.get("x").doubleValue(), 1.0E-4d);
            testScoring(model, frame);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Poisson GLM checks on two small datasets.
     *
     * <p>Stage 1: y = 2^(x+1), so log(y) = log(2) + log(2) * x, and every coefficient should be
     * about log(2) (unstandardized, lambda 0).
     *
     * <p>Stage 2: a standardized fit on a dataset with a very large response at x = 150. The
     * coefficients are compared with 0.3396 / 0.2565 (tolerance 0.1).
     *
     * <p>Both stages' models and frames are released even when an assertion fails.
     */
    @Test
    public void testPoissonRegression() throws InterruptedException, ExecutionException {
        Key raw = Key.make("poisson_test_data_raw");
        Key parsed = Key.make("poisson_test_data_parsed");
        Frame frame = null;
        GLMModel model = null;
        try {
            // Stage 1
            FVecTest.makeByteVec(raw, new String[]{"x,y\n0,2\n1,4\n2,8\n3,16\n4,32\n5,64\n6,128\n7,256"});
            frame = ParseDataset.parse(parsed, new Key[]{raw});
            Vec x = frame.vec(0);
            System.out.println(x.min() + ", " + x.max() + ", mean = " + x.mean());
            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson);
            params._train = frame._key;
            params._response_column = frame._names[1];
            params._lambda = new double[]{0.0d};
            params._standardize = false;
            model = new GLM(params).trainModel().get();
            for (double b : model.beta()) {
                Assert.assertEquals(Math.log(2.0d), b, 0.01d);
            }
            testScoring(model, frame);
            // Release stage-1 resources now; nulling them keeps the finally block from double-deleting.
            model.delete();
            model = null;
            frame.delete();
            frame = null;

            // Stage 2
            FVecTest.makeByteVec(raw, new String[]{"x,y\n1,0\n2,1\n3,2\n4,3\n5,1\n6,4\n7,9\n8,18\n9,23\n10,31\n11,20\n12,25\n13,37\n14,45\n150,7.193936e+16\n"});
            frame = ParseDataset.parse(parsed, new Key[]{raw});
            GLMModel.GLMParameters params2 = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson);
            params2._train = frame._key;
            params2._response_column = frame._names[1];
            params2._lambda = new double[]{0.0d};
            params2._standardize = true;
            params2._beta_epsilon = 1.0E-5d;
            model = new GLM(params2).trainModel().get();
            Assert.assertEquals(0.3396d, model.beta()[1], 0.1d);
            Assert.assertEquals(0.2565d, model.beta()[0], 0.1d);
            testScoring(model, frame);
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Fits an unregularized gamma GLM on y = 1 / (x + 1) and expects every coefficient to be about 1.
     *
     * <p>This relies on the family's default link; presumably inverse, which makes 1/y = 1 + x linear.
     */
    @Test
    public void testGammaRegression() throws InterruptedException, ExecutionException {
        Frame frame = null;
        GLMModel model = null;
        try {
            Key raw = Key.make("gamma_test_data_raw");
            Key parsed = Key.make("gamma_test_data_parsed");
            FVecTest.makeByteVec(raw, new String[]{"x,y\n0,1\n1,0.5\n2,0.3333333\n3,0.25\n4,0.2\n5,0.1666667\n6,0.1428571\n7,0.125"});
            frame = ParseDataset.parse(parsed, new Key[]{raw});
            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gamma);
            params._response_column = frame._names[1];
            params._train = parsed;
            params._lambda = new double[]{0.0d};
            model = new GLM(params).trainModel().get();
            for (double b : model.beta()) {
                Assert.assertEquals(1.0d, b, 1.0E-4d);
            }
            testScoring(model, frame);
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Every row of this dataset contains at least one NA.
     *
     * <p>With {@code MissingValuesHandling.Skip}, training must therefore fail with an
     * IllegalArgumentException whose message contains "No rows left in the dataset". If training
     * unexpectedly succeeds, the model is removed before the failure propagates.
     */
    @Test
    public void testAllNAs() {
        Key raw = Key.make("gamma_test_data_raw");
        Key parsed = Key.make("gamma_test_data_parsed");
        FVecTest.makeByteVec(raw, new String[]{"x,y,z\n1,0,NA\n2,NA,1\nNA,3,2\n4,3,NA\n5,NA,1\nNA,6,4\n7,NA,9\n8,NA,18\nNA,9,23\n10,31,NA\nNA,11,20\n12,NA,25\nNA,13,37\n14,45,NA\n"});
        Frame frame = ParseDataset.parse(parsed, new Key[]{raw});
        GLMModel model = null;
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson);
            params._response_column = frame._names[1];
            params._train = parsed;
            params._lambda = new double[]{0.0d};
            params._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.Skip;
            try {
                model = new GLM(params).trainModel().get();
                Assert.fail("should've thrown IAE");
            } catch (IllegalArgumentException e) {
                Assert.assertTrue(e.getMessage(), e.getMessage().contains("No rows left in the dataset"));
            }
        } finally {
            frame.delete();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Checks that the specialised gradient tasks agree with the generic one (tolerance 1e-4).
     *
     * <p>Data is mixcat_train.csv. Two comparisons are made:
     * <ul>
     *   <li>{@code GLMBinomialGradientTask} vs {@code GLMGenericGradientTask}, with binomial params
     *       and all factor levels;</li>
     *   <li>{@code GLMGaussianGradientTask} vs {@code GLMGenericGradientTask}, with gaussian params.</li>
     * </ul>
     * Each comparison uses a seeded random coefficient vector in [-1, 1).
     *
     * <p>Both DataInfo objects and the frame are released even when an assertion fails.
     */
    @Test
    public void testGradientTask() {
        Key key = Key.make("cars_parsed");
        Frame frame = null;
        DataInfo binomialInfo = null;
        DataInfo gaussianInfo = null;
        try {
            frame = parse_test_file(key, "smalldata/junit/mixcat_train.csv");
            GLMModel.GLMParameters binParams = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial, GLMModel.GLMParameters.Family.binomial.defaultLink, new double[]{0.0d}, new double[]{0.0d}, 0.0d, 0.0d);
            binParams._train = key;
            binParams._lambda = new double[]{0.0d};
            binParams._use_all_factor_levels = true;
            // Move "Useless" to the last position; presumably so DataInfo (1 response) uses it as the response.
            frame.add("Useless", frame.remove("Useless"));
            binomialInfo = new DataInfo(frame, (Frame) null, 1, binParams._use_all_factor_levels || binParams._lambda_search, binParams._standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            DKV.put(binomialInfo._key, binomialInfo);
            double[] binBeta = MemoryManager.malloc8d(binomialInfo.fullN() + 1);
            Random rnd = new Random(987654321L);
            for (int i = 0; i < binBeta.length; i++) {
                binBeta[i] = 1.0d - (2.0d * rnd.nextDouble());
            }
            GLMTask.GLMGradientTask binSpecific = new GLMTask.GLMBinomialGradientTask((Key) null, binomialInfo, binParams, binParams._lambda[0], binBeta).doAll(binomialInfo._adaptedFrame);
            GLMTask.GLMGradientTask binGeneric = new GLMTask.GLMGenericGradientTask((Key) null, binomialInfo, binParams, binParams._lambda[0], binBeta).doAll(binomialInfo._adaptedFrame);
            for (int i = 0; i < binBeta.length; i++) {
                Assert.assertEquals("gradients differ", binSpecific._gradient[i], binGeneric._gradient[i], 1.0E-4d);
            }

            GLMModel.GLMParameters gaussParams = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian, GLMModel.GLMParameters.Family.gaussian.defaultLink, new double[]{0.0d}, new double[]{0.0d}, 0.0d, 0.0d);
            gaussParams._use_all_factor_levels = false;
            binomialInfo.remove();
            binomialInfo = null;  // already removed; keep the finally block from removing it again
            gaussianInfo = new DataInfo(frame, (Frame) null, 1, gaussParams._use_all_factor_levels || gaussParams._lambda_search, gaussParams._standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            DKV.put(gaussianInfo._key, gaussianInfo);
            double[] gaussBeta = MemoryManager.malloc8d(gaussianInfo.fullN() + 1);
            Random rnd2 = new Random(1987654321L);
            for (int i = 0; i < gaussBeta.length; i++) {
                gaussBeta[i] = 1.0d - (2.0d * rnd2.nextDouble());
            }
            GLMTask.GLMGradientTask gaussSpecific = new GLMTask.GLMGaussianGradientTask((Key) null, gaussianInfo, gaussParams, gaussParams._lambda[0], gaussBeta).doAll(gaussianInfo._adaptedFrame);
            GLMTask.GLMGradientTask gaussGeneric = new GLMTask.GLMGenericGradientTask((Key) null, gaussianInfo, gaussParams, gaussParams._lambda[0], gaussBeta).doAll(gaussianInfo._adaptedFrame);
            for (int i = 0; i < gaussBeta.length; i++) {
                Assert.assertEquals("gradients differ: " + Arrays.toString(gaussSpecific._gradient) + " != " + Arrays.toString(gaussGeneric._gradient), gaussSpecific._gradient[i], gaussGeneric._gradient[i], 1.0E-4d);
            }
        } finally {
            // DataInfos first, then the frame, matching the original success-path order.
            if (binomialInfo != null) {
                binomialInfo.remove();
            }
            if (gaussianInfo != null) {
                gaussianInfo.remove();
            }
            if (frame != null) {
                frame.delete();
            }
        }
    }

    /* JADX WARN: Type inference failed for: r0v4, types: [double[], double[][]] */
    @Test
    public void testMultinomialGradient() {
        Key make = Key.make("covtype");
        Frame frame = null;
        ?? r0 = {new double[]{5.886754459d, -0.27047962d, -0.075466082d, -0.157524534d, -0.225843747d, -0.975387326d, -0.018808013d, -0.597839451d, 0.931896624d, 1.06000601d, 1.513888539d, 0.58880278d, 0.157815155d, -2.158268564d, -0.504962385d, -1.218970183d, -0.840958642d, -0.425931637d, -0.355548831d, -0.845035489d, -0.065364107d, 0.215897656d, 0.213009374d, 0.006831714d, 1.212368946d, 0.006106444d, -0.350643486d, -0.268207009d, -0.252099054d, -1.374010836d, 0.25793586d, 0.397459631d, 0.411530391d, 0.728368253d, 0.292076224d, 0.170774269d, -0.059574793d, 0.273670163d, 0.180844505d, -0.186483071d, 0.369186813d, 0.161909512d, 0.249411716d, -0.094481604d, 0.41335436d, -0.419043967d, 0.044517794d, -0.252596992d, -0.371926422d, 0.253835004d, 0.58816209d, 0.123330837d, 2.856812217d}, new double[]{1.89790254d, -0.29776886d, 0.15613197d, 0.37602123d, -0.36464436d, -0.30240244d, -0.5728437d, 0.62408956d, -0.22369305d, 0.33644602d, 0.798864d, 0.65351945d, -0.53682819d, -0.58319898d, -1.07762513d, -0.2852747d, 0.46563482d, -0.76956081d, -0.72513805d, 0.29857876d, 0.03993456d, 0.15835864d, -0.24797599d, -0.02483503d, 0.9382249d, -0.12406087d, -0.75837978d, -0.23516944d, -0.48520212d, 0.73571466d, 0.19652011d, 0.21602846d, -0.32743154d, 0.49421903d, -0.02262943d, 0.08093216d, 0.11524497d, 0.21657128d, 0.18072853d, 0.30872666d, 0.17947687d, 0.20156151d, 0.16812179d, -0.12286908d, 0.29630502d, 0.09992565d, -0.00603293d, 0.20700058d, -0.49706211d, -0.14534034d, -0.18819217d, 0.0364268d, 7.3182834d}, new double[]{-6.098728943d, 0.284144173d, 0.114373474d, 0.328977319d, 0.417830082d, 0.28569615d, -0.652674822d, 0.319136906d, -0.942440279d, -1.619235397d, -1.272568201d, -0.079855555d, 1.19126355d, 0.205102353d, 0.991773314d, 0.930363203d, 1.014021007d, 0.651243292d, 0.646532457d, 0.91433603d, 0.012171754d, -0.053042102d, 0.777710362d, 0.527369151d, -0.019496049d, 0.186290583d, 0.554926655d, 0.476911685d, 0.52920752d, -0.13324306d, -0.198957274d, -0.561552913d, -0.069239959d, 
-0.23660087d, -0.969503908d, -0.848089244d, 0.001498592d, -0.241007311d, -0.129271912d, -0.259961677d, -0.895676033d, -0.865827509d, -0.972629899d, 0.307756211d, -1.809423763d, -0.199557594d, 0.024221965d, -0.024834485d, 0.047044475d, 0.028951561d, -0.157701002d, 0.007940593d, -2.073329675d}, new double[]{-8.3604444d, 0.10541672d, -0.0162868d, -0.43787017d, 0.42383466d, 2.45802808d, 0.59818831d, 0.61971728d, -0.62598983d, 0.20261555d, -0.21909545d, 0.35125447d, -3.29155913d, 3.74668257d, 0.18126128d, -0.13948924d, 0.20465077d, -0.39930635d, 0.1570457d, -0.01036891d, 0.02822546d, -0.02349234d, -0.93922249d, -0.2002591d, 0.25184125d, 0.06415974d, 0.3527129d, 0.0460906d, 0.03018497d, -0.1064154d, 0.00354805d, -0.12194129d, 0.05115876d, 0.23981864d, -0.10007012d, 0.04773226d, 0.01217421d, 0.02367464d, 0.05552397d, 0.05343606d, -0.05818705d, -0.30055029d, -0.03898723d, 0.02322906d, -0.04908215d, 0.04274038d, 0.25045428d, 0.08561191d, 0.1522816d, 0.67005377d, 0.59311621d, 0.58814959d, -4.83776046d}, new double[]{-0.39251919d, 0.07053038d, 0.09397355d, 0.19394977d, -0.02030732d, -0.87489691d, 0.21295049d, 0.31800509d, -0.05347208d, -1.03491602d, 2.20106706d, -1.20895873d, 1.06158893d, -3.29214054d, -0.69334082d, 0.62309414d, -1.64753442d, 0.10189669d, -0.44746013d, -1.04084383d, -0.01997483d, -0.2335618d, 0.34384724d, 0.37566329d, -1.7931651d, 0.46183758d, -0.58814389d, 0.12072985d, 0.48349078d, 1.18956325d, 0.41962148d, 0.1876716d, -0.25252495d, -1.1367154d, 0.71488183d, 0.27405258d, -0.03527945d, 0.43124949d, -0.28740586d, 0.35165348d, 1.17594079d, 1.13893507d, 0.49423372d, 0.30525649d, 0.7080968d, 0.1666033d, -0.37726163d, -0.14687217d, -0.17079711d, -1.01897715d, -1.17494223d, -0.72698683d, 1.64022531d}, new double[]{-5.892381502d, 0.295534637d, -0.112763568d, 0.080283203d, 0.197113227d, 0.525435203d, 0.727252262d, -1.190672917d, 1.137103389d, -0.648526151d, -2.581362158d, -0.268338673d, 2.010179009d, 0.90207445d, 0.816138328d, 0.55707147d, 0.389932578d, 0.009422297d, 
0.542270816d, 0.550653667d, 0.00521172d, -0.071954379d, 0.320008238d, 0.155814784d, -0.264213966d, 0.320538295d, 0.569730803d, 0.444518874d, 0.247279544d, -0.31948433d, -0.372129988d, 0.340944707d, -0.158424299d, -0.479426774d, 0.026966661d, 0.273389077d, -0.004744599d, -0.339321329d, -0.119323949d, -0.210123558d, -1.218998166d, -0.740525896d, 0.134778587d, 0.252701229d, 0.527468284d, 0.214164427d, -0.080104361d, -0.021448994d, 0.004509104d, -0.189729053d, -0.335041198d, -0.080698796d, -1.192518082d}, new double[]{12.9594170391d, -0.18737743d, -0.159962536d, -0.3838368119d, -0.427982539d, -1.1164727575d, -0.2940645257d, -0.0924364781d, -0.223404772d, 1.7036099945d, -0.4407937881d, -0.0364237384d, -0.5924593214d, 1.1797487023d, 0.2867554171d, -0.46679469d, 0.4142538835d, 0.8322365174d, 0.1822980332d, 0.1326797653d, -2.045542E-4d, 0.0077943238d, -0.4673767424d, -0.840584814d, -0.3255599769d, -0.9148717663d, 0.2197967986d, -0.5848745645d, -0.552861643d, 0.0078757154d, -0.3065382365d, -0.4586101971d, 0.3449315968d, 0.39033712d, 0.0582787537d, 0.0012089013d, -0.0293189213d, -0.3648369414d, 0.1189047254d, -0.0572478953d, 0.4482567793d, 0.4044976082d, -0.0349286763d, -0.6715923088d, -0.0867185553d, 0.0951677966d, 0.1442048837d, 0.1531401571d, 0.8359504674d, 0.4012062075d, 0.6745982951d, 0.051837806d, -3.7117127004d}};
        double[] dArr = {-8.955455E-5d, 6.429112E-4d, 4.384381E-4d, 0.001363695d, 4.714468E-4d, -0.002264769d, 4.412849E-4d, 0.00146176d, -2.957754E-5d, -0.002244325d, -0.002744438d, 9.109376E-4d, 0.001920764d, 7.562221E-4d, 1.840414E-4d, 2.455081E-4d, 3.077885E-4d, 2.833261E-4d, 1.248686E-4d, 2.509248E-4d, 9.68126E-6d, -1.097335E-4d, 0.001005934d, 5.623159E-4d, -0.002568397d, 0.0011139d, 1.263858E-4d, 9.075801E-5d, 8.056571E-5d, 1.848318E-4d, -1.291357E-4d, -3.71057E-4d, 5.693621E-5d, 1.328082E-4d, 3.244018E-4d, 4.130594E-4d, 9.681066E-6d, 5.21526E-4d, 4.054695E-4d, 2.904901E-5d, -0.003074865d, -1.247025E-4d, 0.001044981d, 8.612937E-4d, 0.001376526d, 4.543256E-5d, -4.596319E-6d, 3.062111E-5d, 5.649646E-5d, 5.392599E-4d, 9.681357E-4d, 2.298219E-4d, -0.001369109d, -6.884926E-4d, -9.921529E-4d, -5.369346E-4d, -0.001732447d, 5.677645E-4d, 0.001655432d, -4.78689E-4d, -8.688757E-4d, 2.922016E-4d, 0.00360121d, 0.004050781d, -6.409806E-4d, -0.002788663d, -0.001426483d, -1.946904E-4d, -8.279536E-4d, -3.148338E-4d, 2.263577E-6d, -1.320917E-4d, 3.635088E-4d, -1.024655E-5d, 1.079612E-4d, -0.001607591d, -1.801967E-4d, 0.002548311d, -0.001007139d, -1.33699E-4d, 2.538803E-4d, -4.851292E-4d, -9.168206E-4d, 1.027708E-4d, 0.001061545d, -4.098038E-5d, 1.070448E-4d, 3.220238E-4d, -7.011285E-4d, -1.024153E-5d, -7.96738E-4d, -2.708138E-4d, -2.698165E-4d, 0.003088978d, 4.260939E-4d, -5.868815E-4d, -0.001562233d, -0.001007565d, -2.034456E-4d, -6.198011E-4d, -3.277194E-5d, -5.976557E-5d, -0.001143198d, -0.001025416d, 3.671158E-4d, 0.001448332d, 0.001940231d, -6.130695E-4d, -0.00208646d, -2.969848E-4d, 1.455597E-4d, 0.001745515d, 0.002123991d, 9.036201E-4d, -5.270206E-4d, 0.001053891d, 0.001358911d, 2.528711E-4d, 1.326987E-4d, -0.001825879d, -6.085616E-4d, -1.347628E-4d, 3.499544E-4d, 3.616313E-4d, -7.008672E-4d, -0.001211077d, 1.117824E-5d, 3.535679E-5d, -0.002668903d, -2.399884E-4d, 3.979678E-4d, 2.519517E-4d, 1.113206E-4d, 6.029871E-4d, 3.512828E-4d, 2.134159E-4d, 7.590052E-5d, 
1.729959E-4d, 4.472972E-5d, 2.094373E-4d, 3.136961E-4d, 1.83553E-4d, 1.117824E-5d, 8.225263E-5d, 4.330828E-5d, 3.354142E-5d, 7.452883E-4d, 4.631413E-4d, 2.054077E-4d, -5.520636E-5d, 2.818063E-4d, 5.246077E-5d, 1.131811E-4d, 3.535664E-5d, 6.52336E-5d, 3.072416E-4d, 2.913399E-4d, 2.42276E-4d, -0.001580841d, -1.117356E-4d, 2.573351E-4d, 8.117137E-4d, 1.168873E-4d, -4.216143E-4d, -5.847717E-5d, 3.501109E-4d, 2.344622E-4d, -1.330097E-4d, -5.948309E-4d, -2.349808E-4d, -4.495448E-5d, -1.916493E-4d, 5.017336E-4d, -8.440468E-5d, 4.767465E-4d, 2.485018E-4d, 2.060573E-4d, -1.527142E-4d, -9.268231E-6d, -1.985972E-6d, -6.285478E-6d, -2.214673E-5d, 5.82225E-4d, -7.069316E-5d, -4.387924E-5d, -2.774128E-4d, -5.455282E-4d, 3.186328E-4d, -3.793242E-5d, -1.349306E-5d, -3.070112E-5d, -7.951882E-6d, -3.723186E-5d, -5.571437E-5d, -3.26078E-5d, -1.987225E-6d, -1.462245E-5d, -7.699184E-6d, -5.962867E-6d, -1.316053E-4d, -8.10857E-5d, -3.651228E-5d, -5.312255E-5d, -5.009791E-5d, -9.325808E-6d, -2.012086E-5d, -6.285571E-6d, -1.159698E-5d, -5.462022E-5d, -5.17931E-5d, -4.307092E-5d, 2.81036E-4d, 3.869942E-4d, -3.450936E-5d, -7.805675E-5d, 6.405561E-4d, -2.284402E-4d, -1.866295E-4d, -4.858359E-4d, 3.49689E-4d, 7.35278E-4d, 5.767877E-4d, -8.477014E-4d, -5.512698E-5d, 0.001091158d, -1.900036E-4d, -4.632766E-5d, 1.086153E-5d, -7.743051E-5d, -7.545391E-4d, -3.143243E-5d, -6.316374E-5d, -2.435782E-6d, -7.707894E-6d, 4.451785E-4d, 2.043479E-4d, -8.673378E-5d, -3.314975E-5d, -3.181369E-5d, -5.422704E-4d, -9.020739E-5d, 6.747588E-4d, 5.997742E-6d, -9.729086E-4d, -9.75149E-6d, -4.565744E-5d, -4.181943E-4d, 7.522183E-4d, -2.436958E-6d, 2.531532E-4d, -9.4416E-6d, 2.317743E-4d, 4.254207E-4d, -3.224488E-4d, 3.979052E-4d, 2.066697E-4d, 2.486194E-5d, 1.189306E-4d, -2.465884E-5d, -7.708071E-6d, -1.422152E-5d, -6.697064E-5d, -6.351172E-5d, -5.28106E-5d, 3.446379E-4d, -0.001212986d, 9.206612E-4d, 6.469824E-4d, -6.605882E-4d, -1.646537E-5d, -6.854543E-4d, -0.002079925d, -0.001031449d, 3.926585E-4d, 
-0.001556234d, -0.001129748d, -2.11348E-4d, -4.922559E-4d, 0.001938461d, 6.900824E-4d, 1.497533E-4d, -6.140808E-4d, -3.365137E-4d, 8.516225E-4d, 5.874586E-4d, -9.342693E-6d, -2.955083E-5d, 0.002692614d, -9.928211E-4d, -3.326157E-4d, -3.572773E-4d, 1.641113E-4d, 7.442831E-5d, -2.543959E-4d, -1.783712E-4d, -6.343638E-5d, 9.077554E-5d, -3.73848E-5d, -1.750387E-4d, -6.56848E-4d, -2.035799E-4d, -9.342694E-6d, -6.874421E-5d, -3.619677E-5d, -2.803369E-5d, -6.228932E-4d, -3.870861E-4d, -0.001103792d, 9.58536E-4d, -7.037269E-5d, 2.736606E-4d, -9.459508E-5d, -2.955084E-5d, -5.45218E-5d, -2.567899E-4d, -2.43493E-4d, -2.024919E-4d, 0.001321256d, -2.244563E-4d, -1.811758E-4d, 8.043173E-4d, 5.68882E-4d, -5.182511E-4d, -2.056167E-4d, 1.290635E-4d, -0.001049207d, -7.305304E-4d, -8.364983E-4d, -4.528248E-4d, -2.113987E-4d, 3.279472E-4d, 2.459491E-4d, 5.986061E-5d, 7.984705E-5d, 1.001005E-4d, 2.377746E-4d, 4.061439E-5d, 8.161668E-5d, 3.151497E-6d, 9.959707E-6d, 1.54914E-4d, 6.411739E-5d, 1.121613E-4d, 7.559378E-5d, 4.110778E-5d, 6.574476E-5d, 7.925128E-5d, 6.01177E-5d, 2.139605E-5d, 4.934971E-5d, -5.597385E-6d, -1.913622E-4d, 1.706349E-4d, -4.115145E-4d, 3.149101E-6d, 2.317293E-5d, -1.246264E-4d, 9.448371E-6d, -4.303234E-4d, 2.608783E-5d, 7.889196E-5d, -3.559375E-4d, -5.551586E-4d, -2.777131E-4d, 6.505911E-4d, 1.033867E-5d, 1.837583E-5d, 6.750772E-4d, 1.247379E-4d, -5.408403E-4d, -4.453114E-4d};
        Vec vec = null;
        try {
            frame = parse_test_file(make, "smalldata/covtype/covtype.20k.data");
            frame.remove("C21").remove();
            frame.remove("C29").remove();
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
            gLMParameters._response_column = "C55";
            gLMParameters._ignored_columns = new String[0];
            gLMParameters._train = make;
            gLMParameters._lambda = new double[]{0.0d};
            gLMParameters._alpha = new double[]{0.0d};
            vec = frame.remove("C55");
            Vec add = frame.add("C55", vec.toCategoricalVec());
            double[] dArr2 = new double[add.domain().length];
            long[] bins = add.bins();
            double sum = 1.0d / ArrayUtils.sum(bins);
            for (int i = 0; i < bins.length; i++) {
                dArr2[i] = bins[i] * sum;
            }
            DataInfo dataInfo = new DataInfo(frame, (Frame) null, 1, true, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            GLMTask.GLMMultinomialGradientBaseTask doAll = new GLMTask.GLMMultinomialGradientTask((Job) null, dataInfo, 0.0d, (double[][]) r0, 1.0d / frame.numRows()).doAll(dataInfo._adaptedFrame);
            Assert.assertEquals(0.6421113d, doAll._likelihood / frame.numRows(), 1.0E-8d);
            System.out.println("likelihood = " + (doAll._likelihood / frame.numRows()));
            double[] gradient = doAll.gradient();
            for (int i2 = 0; i2 < gradient.length; i2++) {
                Assert.assertEquals("Mismatch at coefficient '' (" + i2 + ")", dArr[i2], gradient[i2], 1.0E-8d);
            }
            if (vec != null) {
                vec.remove();
            }
            if (frame != null) {
                frame.delete();
            }
        } catch (Throwable th) {
            if (vec != null) {
                vec.remove();
            }
            if (frame != null) {
                frame.delete();
            }
            throw th;
        }
    }

    /**
     * Fits Poisson, gamma and Gaussian GLMs with lambda = 0 on cars.csv.
     *
     * <p>Each model predicts "power (hp)", ignores "name" and skips rows with missing values.
     * The fitted coefficients are compared against reference values to within 1e-4. The Poisson
     * and gamma models are also passed through {@code testScoring}.
     *
     * <p>The parsed frame and every model are released in {@code finally}. A failing assertion
     * therefore does not leave models behind in the DKV.
     */
    @Test
    public void testCars() throws InterruptedException, ExecutionException {
        Scope.enter();
        Key make = Key.make("cars_parsed");
        Frame frame = null;
        GLMModel poissonModel = null;
        GLMModel gammaModel = null;
        GLMModel gaussianModel = null;
        try {
            frame = parse_test_file(make, "smalldata/junit/cars.csv");
            String[] coefNames = {"Intercept", "economy (mpg)", "cylinders", "displacement (cc)", "weight (lb)", "0-60 mph (s)", "year"};

            // Poisson family with its default link, unregularized.
            GLMModel.GLMParameters poissonParams = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson, GLMModel.GLMParameters.Family.poisson.defaultLink, new double[]{0.0d}, new double[]{0.0d}, 0.0d, 0.0d);
            poissonParams._response_column = "power (hp)";
            poissonParams._ignored_columns = new String[]{"name"};
            poissonParams._train = make;
            poissonParams._lambda = new double[]{0.0d};
            poissonParams._alpha = new double[]{0.0d};
            poissonParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.Skip;
            poissonModel = (GLMModel) new GLM(poissonParams).trainModel().get();
            HashMap poissonCoefs = poissonModel.coefficients();
            double[] poissonExpected = {4.9504805d, -0.0095859d, -0.0063046d, 4.392E-4d, 1.762E-4d, -0.046981d, 2.891E-4d};
            for (int i = 0; i < coefNames.length; i++) {
                Assert.assertEquals(poissonExpected[i], ((Double) poissonCoefs.get(coefNames[i])).doubleValue(), 1.0E-4d);
            }
            testScoring(poissonModel, frame);
            poissonModel.delete();
            poissonModel = null;

            // Gamma family with its default link, unregularized, with a tighter beta epsilon.
            GLMModel.GLMParameters gammaParams = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gamma, GLMModel.GLMParameters.Family.gamma.defaultLink, new double[]{0.0d}, new double[]{0.0d}, 0.0d, 0.0d);
            gammaParams._response_column = "power (hp)";
            gammaParams._ignored_columns = new String[]{"name"};
            gammaParams._train = make;
            gammaParams._lambda = new double[]{0.0d};
            gammaParams._beta_epsilon = 1.0E-5d;
            gammaParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.Skip;
            gammaModel = (GLMModel) new GLM(gammaParams).trainModel().get();
            HashMap gammaCoefs = gammaModel.coefficients();
            double[] gammaExpected = {0.008992d, 1.818E-4d, -1.125E-4d, 1.505E-6d, -1.284E-6d, 4.51E-4d, -7.254E-5d};
            for (int i = 0; i < coefNames.length; i++) {
                Assert.assertEquals(gammaExpected[i], ((Double) gammaCoefs.get(coefNames[i])).doubleValue(), 1.0E-4d);
            }
            testScoring(gammaModel, frame);
            gammaModel.delete();
            gammaModel = null;

            // Gaussian family, unregularized.
            double[] gaussianExpected = {166.95862d, -0.00531d, -2.4669d, 0.12635d, 0.02159d, -4.66995d, -0.85724d};
            GLMModel.GLMParameters gaussianParams = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            gaussianParams._response_column = "power (hp)";
            gaussianParams._ignored_columns = new String[]{"name"};
            gaussianParams._train = make;
            gaussianParams._lambda = new double[]{0.0d};
            gaussianParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.Skip;
            gaussianModel = (GLMModel) new GLM(gaussianParams).trainModel().get();
            HashMap gaussianCoefs = gaussianModel.coefficients();
            for (int i = 0; i < coefNames.length; i++) {
                Assert.assertEquals(gaussianExpected[i], ((Double) gaussianCoefs.get(coefNames[i])).doubleValue(), 1.0E-4d);
            }
        } finally {
            if (frame != null) {
                frame.delete();
            }
            // Models still non-null here were not released on the normal path, e.g. because an
            // assertion failed.
            if (poissonModel != null) {
                poissonModel.delete();
            }
            if (gammaModel != null) {
                gammaModel.delete();
            }
            if (gaussianModel != null) {
                gaussianModel.delete();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains binomial GLMs on prostate.csv under per-coefficient lower/upper bounds supplied
     * through a beta-constraints frame.
     *
     * <p>The first run uses lasso (alpha = 1, lambda = 0.001607). It checks the null deviance and
     * bounds the residual deviance. The second run is unregularized with a reduced constraint
     * set. It recomputes the binomial gradient at the fitted beta and requires it to be
     * (approximately) zero for every coefficient that is not sitting on one of its bounds.
     *
     * <p>The previous version looped forever whenever a constrained coefficient was at a bound:
     * the decompiled loop neither advanced nor exited in that case. The constraints frame from
     * the first run is now deleted before the second parse, so it is not orphaned. All frames
     * and models are released in {@code finally}.
     */
    @Test
    public void testBounds() {
        GLMModel gLMModel = null;
        GLMModel firstModel = null;
        Key make = Key.make("prostate_parsed");
        Key make2 = Key.make("prostate_model");
        Frame parse_test_file = parse_test_file(make, "smalldata/logreg/prostate.csv");
        Key make3 = Key.make("beta_constraints");
        FVecTest.makeByteVec(make3, new String[]{"names, lower_bounds, upper_bounds\n AGE, -.5, .5\n RACE, -.5, .5\n DCAPS, -.4, .4\n DPROS, -.5, .5 \nPSA, -.5, .5\n VOL, -.5, .5\nGLEASON, -.5, .5"});
        Frame parse = ParseDataset.parse(Key.make("beta_constraints.hex"), new Key[]{make3});
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters();
            gLMParameters._standardize = true;
            gLMParameters._family = GLMModel.GLMParameters.Family.binomial;
            gLMParameters._beta_constraints = parse._key;
            gLMParameters._response_column = "CAPSULE";
            gLMParameters._ignored_columns = new String[]{"ID"};
            gLMParameters._train = parse_test_file._key;
            gLMParameters._objective_epsilon = 0.0d;
            gLMParameters._alpha = new double[]{1.0d};
            gLMParameters._lambda = new double[]{0.001607d};
            gLMParameters._obj_reg = 0.002631578947368421d;
            GLM glm = new GLM(gLMParameters, make2);
            firstModel = (GLMModel) glm.trainModel().get();
            Assert.assertTrue(glm.isStopped());
            ModelMetricsBinomialGLM modelMetricsBinomialGLM = firstModel._output._training_metrics;
            Assert.assertEquals(512.2888d, modelMetricsBinomialGLM._nullDev, 0.1d);
            Assert.assertTrue(modelMetricsBinomialGLM._resDev <= 388.5d);
            firstModel.delete();
            firstModel = null;

            // Second run: unregularized, with a reduced set of constraints parsed into the same key.
            gLMParameters._lambda = new double[]{0.0d};
            gLMParameters._alpha = new double[]{0.0d};
            FVecTest.makeByteVec(make3, new String[]{"names, lower_bounds, upper_bounds\n RACE, -.5, .5\n DCAPS, -.4, .4\n DPROS, -.5, .5 \nPSA, -.5, .5\n VOL, -.5, .5"});
            // Release the first constraints frame before reparsing so its vecs are not orphaned.
            parse.delete();
            parse = null;
            parse = ParseDataset.parse(Key.make("beta_constraints.hex"), new Key[]{make3});
            GLM glm2 = new GLM(gLMParameters, make2);
            gLMModel = (GLMModel) glm2.trainModel().get();
            Assert.assertTrue(glm2.isStopped());
            double[] beta = gLMModel.beta();
            System.out.println("beta = " + Arrays.toString(beta));

            // Move the response to the last column and drop ID so the DataInfo layout matches the model.
            parse_test_file.add("CAPSULE", parse_test_file.remove("CAPSULE"));
            parse_test_file.remove("ID").remove();
            DKV.put(parse_test_file._key, parse_test_file);
            DataInfo dataInfo = new DataInfo(parse_test_file, (Frame) null, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            double[] gradient = new GLMTask.GLMBinomialGradientTask((Key) null, dataInfo, gLMParameters, 0.0d, beta).doAll(dataInfo._adaptedFrame)._gradient;
            String[] coefNames = gLMModel.dinfo().coefNames();
            BufferedString bufferedString = new BufferedString();
            for (int i = 0; i < coefNames.length; i++) {
                // A coefficient pinned at one of its bounds may have a non-zero gradient; skip it.
                boolean atBound = false;
                for (int row = 0; row < parse.numRows() && !atBound; row++) {
                    boolean sameName = parse.vec("names").atStr(bufferedString, row).toString().equals(coefNames[i]);
                    atBound = sameName
                            && (Math.abs(beta[i] - parse.vec("lower_bounds").at(row)) < 1.0E-4d
                                || Math.abs(beta[i] - parse.vec("upper_bounds").at(row)) < 1.0E-4d);
                }
                if (!atBound) {
                    Assert.assertEquals("Gradient should be zero for coefficient " + coefNames[i], 0.0d, gradient[i], 0.01d);
                }
            }
        } finally {
            parse_test_file.delete();
            if (parse != null) {
                parse.delete();
            }
            if (firstModel != null) {
                firstModel.delete();
            }
            if (gLMModel != null) {
                gLMModel.delete();
            }
        }
    }

    /**
     * Trains a standardized binomial GLM on AirlinesTrain.csv.zip with the naive
     * coordinate-descent solver.
     *
     * <p>The response is "IsDepDelayed" and "IsDepDelayed_REC" is ignored. The test asserts that
     * the job has stopped and prints the training metrics. The frame and model are always
     * released.
     */
    @Test
    public void testCoordinateDescent_airlines() {
        Key frameKey = Key.make("airlines_parsed");
        Key modelKey = Key.make("airlines_model");
        Frame train = parse_test_file(frameKey, "smalldata/airlines/AirlinesTrain.csv.zip");
        GLMModel model = null;
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters();
            params._standardize = true;
            params._family = GLMModel.GLMParameters.Family.binomial;
            params._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;
            params._response_column = "IsDepDelayed";
            params._ignored_columns = new String[]{"IsDepDelayed_REC"};
            params._train = train._key;
            GLM job = new GLM(params, modelKey);
            model = (GLMModel) job.trainModel().get();
            Assert.assertTrue(job.isStopped());
            System.out.println(model._output._training_metrics);
        } finally {
            train.delete();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Same setup as the naive airlines test, but with the covariance-updates
     * {@code COORDINATE_DESCENT} solver.
     *
     * <p>The test asserts that the job has stopped and prints the training metrics. The frame
     * and model are always released.
     */
    @Test
    public void testCoordinateDescent_airlines_CovUpdates() {
        Key frameKey = Key.make("airlines_parsed");
        Key modelKey = Key.make("airlines_model");
        Frame train = parse_test_file(frameKey, "smalldata/airlines/AirlinesTrain.csv.zip");
        GLMModel model = null;
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters();
            params._standardize = true;
            params._family = GLMModel.GLMParameters.Family.binomial;
            params._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            params._response_column = "IsDepDelayed";
            params._ignored_columns = new String[]{"IsDepDelayed_REC"};
            params._train = train._key;
            GLM job = new GLM(params, modelKey);
            model = (GLMModel) job.trainModel().get();
            Assert.assertTrue(job.isStopped());
            System.out.println(model._output._training_metrics);
        } finally {
            train.delete();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Trains a standardized Gaussian GLM on ecg_discord_train.csv (response "C1") with the naive
     * coordinate-descent solver.
     *
     * <p>The test asserts that the job has stopped and prints the training metrics. The frame
     * and model are always released.
     */
    @Test
    public void testCoordinateDescent_anomaly() {
        Key frameKey = Key.make("anomaly_parsed");
        Key modelKey = Key.make("anomaly_model");
        Frame train = parse_test_file(frameKey, "smalldata/anomaly/ecg_discord_train.csv");
        GLMModel model = null;
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters();
            params._standardize = true;
            params._family = GLMModel.GLMParameters.Family.gaussian;
            params._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;
            params._response_column = "C1";
            params._train = train._key;
            GLM job = new GLM(params, modelKey);
            model = (GLMModel) job.trainModel().get();
            Assert.assertTrue(job.isStopped());
            System.out.println(model._output._training_metrics);
        } finally {
            train.delete();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Same setup as the naive anomaly test, but with the covariance-updates
     * {@code COORDINATE_DESCENT} solver.
     *
     * <p>The test asserts that the job has stopped and prints the training metrics. The frame
     * and model are always released.
     */
    @Test
    public void testCoordinateDescent_anomaly_CovUpdates() {
        Key frameKey = Key.make("anomaly_parsed");
        Key modelKey = Key.make("anomaly_model");
        Frame train = parse_test_file(frameKey, "smalldata/anomaly/ecg_discord_train.csv");
        GLMModel model = null;
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters();
            params._standardize = true;
            params._family = GLMModel.GLMParameters.Family.gaussian;
            params._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            params._response_column = "C1";
            params._train = train._key;
            GLM job = new GLM(params, modelKey);
            model = (GLMModel) job.trainModel().get();
            Assert.assertTrue(job.isStopped());
            System.out.println(model._output._training_metrics);
        } finally {
            train.delete();
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Checks a proximal penalty supplied through a beta-constraints frame.
     *
     * <p>The constraints frame gives "beta_given" and "rho" per coefficient. It is applied to an
     * unstandardized, unregularized binomial GLM on prostate.csv. For the default-solver fit, the
     * binomial gradient at the fitted beta plus {@code rho * (beta - beta_given)} must be zero
     * (to 1e-4) for every coefficient. An L_BFGS model with the same settings is also trained
     * under the same model key.
     *
     * <p>The first model is now deleted after its beta is extracted, so it is not orphaned when
     * the second model reuses the key. Frames and models are released in {@code finally}.
     */
    @Test
    public void testProximal() {
        Key make = Key.make("prostate_parsed");
        Key make2 = Key.make("prostate_model");
        GLMModel gLMModel = null;
        GLMModel firstModel = null;
        Frame parse_test_file = parse_test_file(make, "smalldata/logreg/prostate.csv");
        parse_test_file.remove("ID").remove();
        DKV.put(parse_test_file._key, parse_test_file);
        Key make3 = Key.make("beta_constraints");
        FVecTest.makeByteVec(make3, new String[]{"names, beta_given, rho\n AGE, 0.1, 1\n RACE, -0.1, 1 \n DPROS, 10, 1 \n DCAPS, -10, 1 \n PSA, 0, 1\n VOL, 0, 1\nGLEASON, 0, 1\n Intercept, 0, 0 \n"});
        Frame parse = ParseDataset.parse(Key.make("beta_constraints.hex"), new Key[]{make3});
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters();
            gLMParameters._standardize = false;
            gLMParameters._family = GLMModel.GLMParameters.Family.binomial;
            gLMParameters._beta_constraints = parse._key;
            gLMParameters._response_column = "CAPSULE";
            gLMParameters._ignored_columns = new String[]{"ID"};
            gLMParameters._train = parse_test_file._key;
            gLMParameters._alpha = new double[]{0.0d};
            gLMParameters._lambda = new double[]{0.0d};
            gLMParameters._obj_reg = 0.002631578947368421d;
            gLMParameters._objective_epsilon = 0.0d;
            firstModel = (GLMModel) new GLM(gLMParameters, make2).trainModel().get();
            double[] beta = firstModel.beta();
            // The next model is trained under the same key; release this one first.
            firstModel.delete();
            firstModel = null;
            gLMParameters._solver = GLMModel.GLMParameters.Solver.L_BFGS;
            gLMParameters._max_iterations = 1000;
            gLMModel = (GLMModel) new GLM(gLMParameters, make2).trainModel().get();
            // Move the response to the last column so the DataInfo layout matches the model.
            parse_test_file.add("CAPSULE", parse_test_file.remove("CAPSULE"));
            DataInfo dataInfo = new DataInfo(parse_test_file, (Frame) null, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            double[] gradient = new GLMTask.GLMBinomialGradientTask((Key) null, dataInfo, gLMParameters, 0.0d, beta).doAll(dataInfo._adaptedFrame)._gradient;
            for (int i = 0; i < beta.length; i++) {
                double proximal = parse.vec("rho").at(i) * (beta[i] - parse.vec("beta_given").at(i));
                Assert.assertEquals(0.0d, gradient[i] + proximal, 1.0E-4d);
            }
        } finally {
            parse.delete();
            parse_test_file.delete();
            if (firstModel != null) {
                firstModel.delete();
            }
            if (gLMModel != null) {
                gLMModel.delete();
            }
        }
    }

    /**
     * Checks that the sparse and dense code paths of {@code GLMIterationTask} agree.
     *
     * <p>A seeded random frame is built from two categorical columns, two dense numeric columns
     * and several mostly-zero numeric columns. The Gram matrix (lower triangle) and X'y must
     * match to within 1e-8, first with no beta and then with the beta solved from the dense
     * Gram.
     *
     * <p>The previous version cast the submitted task to {@code AnonymousClass1}, a decompiler
     * name that does not exist in source, so it did not compile. The frame and DataInfo are now
     * also released if an assertion fails.
     */
    @Test
    public void testSparseGramComputation() {
        Random random = new Random(123456789L);
        double[] malloc8d = MemoryManager.malloc8d(1000);
        double[] malloc8d2 = MemoryManager.malloc8d(1000);
        double[] malloc8d3 = MemoryManager.malloc8d(1000);
        double[] malloc8d4 = MemoryManager.malloc8d(1000);
        double[] malloc8d5 = MemoryManager.malloc8d(1000);
        double[] malloc8d6 = MemoryManager.malloc8d(1000);
        double[] malloc8d7 = MemoryManager.malloc8d(1000);
        double[] malloc8d8 = MemoryManager.malloc8d(1000);
        double[] malloc8d9 = MemoryManager.malloc8d(1000);
        double[] malloc8d10 = MemoryManager.malloc8d(1000);
        long[] malloc8 = MemoryManager.malloc8(1000);
        long[] malloc82 = MemoryManager.malloc8(1000);
        String[] strArr = {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"};
        // Dense columns: two categoricals and two numerics, filled for every row.
        for (int i = 0; i < malloc8d2.length; i++) {
            malloc8[i] = random.nextInt(strArr.length);
            malloc82[i] = random.nextInt(strArr.length);
            malloc8d[i] = random.nextDouble();
            malloc8d2[i] = random.nextDouble();
        }
        // Sparse columns: at most 30 non-zero rows each (the draw order fixes the data for this seed).
        for (int i2 = 0; i2 < 30; i2++) {
            malloc8d3[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d4[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d5[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d6[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d7[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d8[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d9[random.nextInt(malloc8d3.length)] = random.nextDouble();
            malloc8d10[random.nextInt(malloc8d3.length)] = 1.0d;
        }
        Vec.VectorGroup vectorGroup = Vec.VectorGroup.VG_LEN1;
        Vec makeVec = Vec.makeVec(malloc8, strArr, vectorGroup.addVec());
        Vec makeVec2 = Vec.makeVec(malloc82, strArr, vectorGroup.addVec());
        Vec makeVec3 = Vec.makeVec(malloc8d, vectorGroup.addVec());
        Vec makeVec4 = Vec.makeVec(malloc8d2, vectorGroup.addVec());
        Vec makeVec5 = Vec.makeVec(malloc8d3, vectorGroup.addVec());
        // NOTE(review): makeVec5 appears twice in this frame; kept as-is, presumably to exercise a duplicated column.
        Frame frame = new Frame(Key.make("TestData"), (String[]) null, new Vec[]{makeVec, makeVec2, makeVec3, makeVec4, makeVec5, makeVec5, Vec.makeVec(malloc8d4, vectorGroup.addVec()), Vec.makeVec(malloc8d5, vectorGroup.addVec()), Vec.makeVec(malloc8d6, vectorGroup.addVec()), Vec.makeVec(malloc8d7, vectorGroup.addVec()), Vec.makeVec(malloc8d8, vectorGroup.addVec()), Vec.makeVec(malloc8d9, vectorGroup.addVec()), Vec.makeVec(malloc8d10, vectorGroup.addVec())});
        DKV.put(frame);
        DataInfo dataInfo = null;
        try {
            dataInfo = new DataInfo(frame, (Frame) null, 1, true, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);

            // Pass 1: no beta supplied.
            GLMTask.GLMIterationTask sparseNoBeta = new GLMTask.GLMIterationTask((Key) null, dataInfo, new GLMModel.GLMWeightsFun(gLMParameters), (double[]) null).setSparse(true).doAll(dataInfo._adaptedFrame);
            final GLMTask.GLMIterationTask denseNoBeta = new GLMTask.GLMIterationTask((Key) null, dataInfo, new GLMModel.GLMWeightsFun(gLMParameters), (double[]) null).setSparse(false).doAll(dataInfo._adaptedFrame);
            for (int i3 = 0; i3 < denseNoBeta._xy.length; i3++) {
                for (int i4 = 0; i4 <= i3; i4++) {
                    Assert.assertEquals(denseNoBeta._gram.get(i3, i4), sparseNoBeta._gram.get(i3, i4), 1.0E-8d);
                }
                Assert.assertEquals(denseNoBeta._xy[i3], sparseNoBeta._xy[i3], 1.0E-8d);
            }

            // Solve the dense Gram system inside an H2O task to obtain a beta for pass 2.
            final double[] solvedBeta = MemoryManager.malloc8d(dataInfo.fullN() + 1);
            H2O.submitTask(new H2O.H2OCountedCompleter() {
                @Override
                public void compute2() {
                    new GLM.GramSolver(denseNoBeta._gram, denseNoBeta._xy, true, 1.0E-5d, 0.0d, (double[]) null, (double[]) null, (double[]) null, (double[]) null).solve((double[]) null, solvedBeta);
                    tryComplete();
                }
            }).join();

            // Pass 2: with the solved beta.
            GLMTask.GLMIterationTask sparseWithBeta = new GLMTask.GLMIterationTask((Key) null, dataInfo, new GLMModel.GLMWeightsFun(gLMParameters), solvedBeta).setSparse(true).doAll(dataInfo._adaptedFrame);
            GLMTask.GLMIterationTask denseWithBeta = new GLMTask.GLMIterationTask((Key) null, dataInfo, new GLMModel.GLMWeightsFun(gLMParameters), solvedBeta).setSparse(false).doAll(dataInfo._adaptedFrame);
            for (int i5 = 0; i5 < denseWithBeta._xy.length; i5++) {
                for (int i6 = 0; i6 <= i5; i6++) {
                    Assert.assertEquals(denseWithBeta._gram.get(i5, i6), sparseWithBeta._gram.get(i5, i6), 1.0E-8d);
                }
                Assert.assertEquals(denseWithBeta._xy[i5], sparseWithBeta._xy[i5], 1.0E-8d);
            }
        } finally {
            if (dataInfo != null) {
                dataInfo.remove();
            }
            frame.delete();
        }
    }

    /**
     * Trains a Gaussian GLM with lambda search on allyears2k_headers.zip, using a 0/1 weights
     * column.
     *
     * <p>The weight is 1 for rows whose global index is below 1999 and 0 for the rest. Many
     * delay-related columns are ignored. The test only checks that training completes. The
     * "IsDepDelayed" column is first replaced by a copy built with a null domain.
     *
     * <p>The parsed frame is now released in {@code finally}, so it is cleaned up even if
     * training throws.
     */
    @Test
    @Ignore
    public void testConstantColumns() {
        Frame parse_test_file = parse_test_file(Key.make("Airlines"), "smalldata/airlines/allyears2k_headers.zip");
        try {
            parse_test_file.replace(parse_test_file.find("IsDepDelayed"), parse_test_file.vec("IsDepDelayed").makeCopy((String[]) null)).remove();
            Vec makeZero = parse_test_file.anyVec().makeZero();
            // Set weight 1 on rows with global index < 1999; all other rows keep weight 0.
            new MRTask() {
                @Override
                public void map(Chunk chunk) {
                    for (int i = 0; i < chunk._len && chunk.start() + i < 1999; i++) {
                        chunk.set(i, 1L);
                    }
                }
            }.doAll(new Vec[]{makeZero});
            parse_test_file.add("weights", makeZero);
            DKV.put(parse_test_file);
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            gLMParameters._train = parse_test_file._key;
            gLMParameters._weights_column = "weights";
            gLMParameters._lambda_search = true;
            gLMParameters._alpha = new double[]{0.0d};
            gLMParameters._response_column = "IsDepDelayed";
            gLMParameters._ignored_columns = new String[]{"DepTime", "ArrTime", "Cancelled", "CancellationCode", "DepDelay", "Diverted", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"};
            gLMParameters._standardize = true;
            new GLM(gLMParameters).trainModel().get().delete();
        } finally {
            parse_test_file.delete();
        }
    }

    @Test
    public void testAirlines() {
        GLMModel gLMModel = null;
        GLMModel gLMModel2 = null;
        GLMModel gLMModel3 = null;
        GLMModel gLMModel4 = null;
        Frame parse_test_file = parse_test_file(Key.make("AirlinesMM"), "smalldata/airlines/AirlinesTrainMM.csv.zip");
        Frame parse_test_file2 = parse_test_file(Key.make("gram"), "smalldata/airlines/gram_std.csv");
        Vec remove = parse_test_file2.remove("xy");
        parse_test_file.remove("C1").remove();
        Vec remove2 = parse_test_file.remove("IsDepDelayed");
        parse_test_file.add("IsDepDelayed", remove2.makeCopy((String[]) null));
        remove2.remove();
        DKV.put(parse_test_file._key, parse_test_file);
        Frame parse_test_file3 = parse_test_file(Key.make("Airlines"), "smalldata/airlines/AirlinesTrain.csv.zip");
        Frame frame = null;
        Vec remove3 = parse_test_file3.remove("IsDepDelayed");
        parse_test_file3.add("IsDepDelayed", remove3.makeCopy((String[]) null));
        remove3.remove();
        DKV.put(parse_test_file3._key, parse_test_file3);
        String[] strArr = {"fYear", "fMonth", "fDayofMonth", "fDayOfWeek", "DepTime", "ArrTime", "IsDepDelayed_REC"};
        try {
            Scope.enter();
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            gLMParameters._response_column = "IsDepDelayed";
            gLMParameters._ignored_columns = strArr;
            gLMParameters._train = parse_test_file3._key;
            gLMParameters._lambda = new double[]{0.0d};
            gLMParameters._alpha = new double[]{0.0d};
            gLMParameters._standardize = false;
            gLMParameters._use_all_factor_levels = false;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            testScoring(gLMModel, parse_test_file3);
            Frame score = gLMModel.score(parse_test_file3);
            ModelMetricsRegressionGLM fromDKV = ModelMetrics.getFromDKV(gLMModel, parse_test_file3);
            Assert.assertEquals(gLMModel._output._training_metrics._resDev, fromDKV._resDev, 1.0E-4d);
            Assert.assertEquals(gLMModel._output._training_metrics._resDev, fromDKV._MSE * score.numRows(), 1.0E-4d);
            score.delete();
            fromDKV.remove();
            frame = gLMModel.score(parse_test_file3);
            gLMParameters._train = parse_test_file._key;
            gLMParameters._ignored_columns = new String[]{"X"};
            gLMModel2 = (GLMModel) new GLM(gLMParameters).trainModel().get();
            HashMap coefficients = gLMModel.coefficients();
            testScoring(gLMModel2, parse_test_file);
            HashMap coefficients2 = gLMModel2.coefficients();
            boolean z = false;
            for (String str : coefficients2.keySet()) {
                String str2 = str;
                if (str.startsWith("Origin")) {
                    str2 = "Origin." + str.substring(6);
                }
                if (str.startsWith("Dest")) {
                    str2 = "Dest." + str.substring(4);
                }
                if (str.startsWith("UniqueCarrier")) {
                    str2 = "UniqueCarrier." + str.substring(13);
                }
                if (Math.abs(((Double) coefficients.get(str2)).doubleValue() - ((Double) coefficients2.get(str)).doubleValue()) > 1.0E-4d) {
                    System.out.println("coeff " + str2 + " differs, " + coefficients.get(str2) + " != " + coefficients2.get(str));
                    z = true;
                }
            }
            Assert.assertFalse(z);
            gLMParameters._standardize = true;
            gLMParameters._train = parse_test_file._key;
            gLMParameters._use_all_factor_levels = true;
            DataInfo dataInfo = new DataInfo(parse_test_file, (Frame) null, 1, true, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            GLMTask.GLMIterationTask doAll = new GLMTask.GLMIterationTask((Key) null, dataInfo, new GLMModel.GLMWeightsFun(gLMParameters), (double[]) null).doAll(dataInfo._adaptedFrame);
            for (int i = 0; i < doAll._xy.length; i++) {
                for (int i2 = 0; i2 <= i; i2++) {
                    Assert.assertEquals(parse_test_file2.vec(i2).at(i), doAll._gram.get(i, i2), 1.0E-5d);
                }
                Assert.assertEquals(remove.at(i), doAll._xy[i], 1.0E-5d);
            }
            remove.remove();
            GLMModel.GLMParameters clone = gLMParameters.clone();
            clone._standardize = false;
            clone._family = GLMModel.GLMParameters.Family.binomial;
            clone._link = GLMModel.GLMParameters.Link.logit;
            gLMModel3 = (GLMModel) new GLM(clone).trainModel().get();
            testScoring(gLMModel3, parse_test_file);
            clone._train = parse_test_file3._key;
            clone._ignored_columns = strArr;
            gLMModel4 = (GLMModel) new GLM(clone).trainModel().get();
            testScoring(gLMModel4, parse_test_file3);
            Assert.assertEquals(nullDeviance(gLMModel3), nullDeviance(gLMModel4), 1.0E-4d);
            Assert.assertEquals(residualDeviance(gLMModel4), residualDeviance(gLMModel3), nullDeviance(gLMModel3) * 0.001d);
            Assert.assertEquals(nullDeviance(gLMModel), nullDeviance(gLMModel2), 1.0E-4d);
            Assert.assertEquals(residualDeviance(gLMModel), residualDeviance(gLMModel2), 1.0E-4d);
            Assert.assertEquals(5336.918d, residualDeviance(gLMModel), 1.0d);
            Assert.assertEquals(6051.613d, nullDeviance(gLMModel2), 1.0d);
            parse_test_file3.delete();
            parse_test_file.delete();
            parse_test_file2.delete();
            if (frame != null) {
                frame.delete();
            }
            if (gLMModel != null) {
                gLMModel.delete();
            }
            if (gLMModel2 != null) {
                gLMModel2.delete();
            }
            if (gLMModel3 != null) {
                gLMModel3.delete();
            }
            if (gLMModel4 != null) {
                gLMModel4.delete();
            }
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            parse_test_file3.delete();
            parse_test_file.delete();
            parse_test_file2.delete();
            if (frame != null) {
                frame.delete();
            }
            if (gLMModel != null) {
                gLMModel.delete();
            }
            if (gLMModel2 != null) {
                gLMModel2.delete();
            }
            if (gLMModel3 != null) {
                gLMModel3.delete();
            }
            if (gLMModel4 != null) {
                gLMModel4.delete();
            }
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    /**
     * Trains a binomial lasso GLM (alpha = 1) on the airlines training set with the naive
     * coordinate-descent solver, a user lambda of 0.01 and lambda search over 5 lambdas, using the
     * training frame as the validation frame.
     *
     * <p>NOTE(review): the L1 and squared-L2 norms of the fitted coefficients are computed but never
     * asserted, so this test only verifies that training runs to completion.
     */
    @Test
    public void test_COD_Airlines_SingleLambda() {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file(Key.make("Airlines"), "smalldata/airlines/AirlinesTrain.csv.zip");
        String[] strArr = {"IsDepDelayed_REC"};
        Scope.enter();
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "IsDepDelayed";
            gLMParameters._ignored_columns = strArr;
            gLMParameters._train = parse_test_file._key;
            gLMParameters._valid = parse_test_file._key;
            gLMParameters._lambda = new double[]{0.01d};
            gLMParameters._alpha = new double[]{1.0d};
            gLMParameters._standardize = false;
            gLMParameters._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;
            gLMParameters._lambda_search = true;
            gLMParameters._nlambdas = 5;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            double[] beta = gLMModel.beta();
            ArrayUtils.l1norm(beta, true);
            ArrayUtils.l2norm2(beta, true);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            // Balance the Scope.enter() above; without it the scope was left open.
            Scope.exit();
        }
    }

    /**
     * Trains a binomial lasso GLM (alpha = 1) on the airlines training set with the
     * COORDINATE_DESCENT solver, a user lambda of 0.01 and lambda search enabled, using the
     * training frame as the validation frame.
     *
     * <p>NOTE(review): the L1 and squared-L2 norms of the fitted coefficients are computed but never
     * asserted, so this test only verifies that training runs to completion.
     */
    @Test
    public void test_COD_Airlines_SingleLambda_CovUpdates() {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file(Key.make("Airlines"), "smalldata/airlines/AirlinesTrain.csv.zip");
        String[] strArr = {"IsDepDelayed_REC"};
        Scope.enter();
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "IsDepDelayed";
            gLMParameters._ignored_columns = strArr;
            gLMParameters._train = parse_test_file._key;
            gLMParameters._valid = parse_test_file._key;
            gLMParameters._lambda = new double[]{0.01d};
            gLMParameters._alpha = new double[]{1.0d};
            gLMParameters._standardize = false;
            gLMParameters._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            gLMParameters._lambda_search = true;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            double[] beta = gLMModel.beta();
            ArrayUtils.l1norm(beta, true);
            ArrayUtils.l2norm2(beta, true);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            // Balance the Scope.enter() above; without it the scope was left open.
            Scope.exit();
        }
    }

    /**
     * Runs a binomial lasso (alpha = 1) lambda search over 5 automatically chosen lambdas on the
     * airlines training set with the naive coordinate-descent solver, and prints the lambda of the
     * last submodel on the path.
     *
     * <p>NOTE(review): the L1 and squared-L2 norms of the last submodel's coefficients are computed
     * but never asserted, so this test only verifies that training runs to completion.
     */
    @Test
    public void test_COD_Airlines_LambdaSearch() {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file(Key.make("Airlines"), "smalldata/airlines/AirlinesTrain.csv.zip");
        String[] strArr = {"IsDepDelayed_REC"};
        Scope.enter();
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "IsDepDelayed";
            gLMParameters._ignored_columns = strArr;
            gLMParameters._train = parse_test_file._key;
            gLMParameters._valid = parse_test_file._key;
            gLMParameters._lambda = null;
            gLMParameters._alpha = new double[]{1.0d};
            gLMParameters._standardize = false;
            gLMParameters._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;
            gLMParameters._lambda_search = true;
            gLMParameters._nlambdas = 5;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            GLMModel.Submodel submodel = gLMModel._output._submodels[gLMModel._output._submodels.length - 1];
            double[] dArr = submodel.beta;
            System.out.println("lambda " + submodel.lambda_value);
            ArrayUtils.l1norm(dArr, true);
            ArrayUtils.l2norm2(dArr, true);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            // Balance the Scope.enter() above; without it the scope was left open.
            Scope.exit();
        }
    }

    /**
     * Runs a binomial lasso (alpha = 1) lambda search over 5 automatically chosen lambdas on the
     * airlines training set with the COORDINATE_DESCENT solver, and prints the lambda of the last
     * submodel on the path.
     *
     * <p>NOTE(review): the L1 and squared-L2 norms of the last submodel's coefficients are computed
     * but never asserted, so this test only verifies that training runs to completion.
     */
    @Test
    public void test_COD_Airlines_LambdaSearch_CovUpdates() {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file(Key.make("Airlines"), "smalldata/airlines/AirlinesTrain.csv.zip");
        String[] strArr = {"IsDepDelayed_REC"};
        Scope.enter();
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "IsDepDelayed";
            gLMParameters._ignored_columns = strArr;
            gLMParameters._train = parse_test_file._key;
            gLMParameters._valid = parse_test_file._key;
            gLMParameters._lambda = null;
            gLMParameters._alpha = new double[]{1.0d};
            gLMParameters._standardize = false;
            gLMParameters._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            gLMParameters._lambda_search = true;
            gLMParameters._nlambdas = 5;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            GLMModel.Submodel submodel = gLMModel._output._submodels[gLMModel._output._submodels.length - 1];
            double[] dArr = submodel.beta;
            System.out.println("lambda " + submodel.lambda_value);
            ArrayUtils.l1norm(dArr, true);
            ArrayUtils.l2norm2(dArr, true);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            // Balance the Scope.enter() above; without it the scope was left open.
            Scope.exit();
        }
    }

    /**
     * Returns the residual deviance recorded in the model's training metrics.
     *
     * <p>The deviance is read through the concrete GLM metrics class: binomial-style models
     * (binomial and quasibinomial) carry {@link ModelMetricsBinomialGLM}, and other models are
     * expected to carry {@link ModelMetricsRegressionGLM}. Dispatching on the runtime type replaces
     * the previous conditional, whose two branches were identical.
     */
    public static double residualDeviance(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._training_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._resDev;
        }
        return ((ModelMetricsRegressionGLM) metrics)._resDev;
    }

    /**
     * Returns the residual deviance recorded in the model's validation metrics.
     *
     * <p>The deviance is read through the concrete GLM metrics class
     * ({@link ModelMetricsBinomialGLM} for binomial-style models, otherwise
     * {@link ModelMetricsRegressionGLM}), chosen by the metrics object's runtime type.
     */
    public static double residualDevianceTest(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._validation_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._resDev;
        }
        return ((ModelMetricsRegressionGLM) metrics)._resDev;
    }

    /**
     * Returns the null deviance recorded in the model's validation metrics.
     *
     * <p>The deviance is read through the concrete GLM metrics class
     * ({@link ModelMetricsBinomialGLM} for binomial-style models, otherwise
     * {@link ModelMetricsRegressionGLM}), chosen by the metrics object's runtime type.
     */
    public static double nullDevianceTest(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._validation_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._nullDev;
        }
        return ((ModelMetricsRegressionGLM) metrics)._nullDev;
    }

    /**
     * Returns the AIC recorded in the model's training metrics.
     *
     * <p>The AIC is read through the concrete GLM metrics class
     * ({@link ModelMetricsBinomialGLM} for binomial-style models, otherwise
     * {@link ModelMetricsRegressionGLM}), chosen by runtime type. The old check only matched the
     * {@code binomial} family, so quasibinomial models took the regression branch.
     */
    public static double aic(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._training_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._AIC;
        }
        return ((ModelMetricsRegressionGLM) metrics)._AIC;
    }

    /**
     * Returns the null degrees of freedom recorded in the model's training metrics.
     *
     * <p>The value is read through the concrete GLM metrics class
     * ({@link ModelMetricsBinomialGLM} for binomial-style models, otherwise
     * {@link ModelMetricsRegressionGLM}), chosen by runtime type. The old check only matched the
     * {@code binomial} family, so quasibinomial models took the regression branch.
     */
    public static double nullDOF(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._training_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._nullDegressOfFreedom;
        }
        return ((ModelMetricsRegressionGLM) metrics)._nullDegressOfFreedom;
    }

    /**
     * Returns the residual degrees of freedom recorded in the model's training metrics.
     *
     * <p>The value is read through the concrete GLM metrics class
     * ({@link ModelMetricsBinomialGLM} for binomial-style models, otherwise
     * {@link ModelMetricsRegressionGLM}), chosen by the metrics object's runtime type.
     */
    public static double resDOF(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._training_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._residualDegressOfFreedom;
        }
        return ((ModelMetricsRegressionGLM) metrics)._residualDegressOfFreedom;
    }

    /** Returns the AUC stored in the AUC object of the model's training metrics. */
    public static double auc(GLMModel gLMModel) {
        double trainingAuc = gLMModel._output._training_metrics.auc_obj()._auc;
        return trainingAuc;
    }

    /** Returns the log-loss stored in the model's training metrics. */
    public static double logloss(GLMModel gLMModel) {
        double trainingLogloss = gLMModel._output._training_metrics._logloss;
        return trainingLogloss;
    }

    /**
     * Returns the mean squared error stored in the model's training metrics.
     *
     * <p>The previous family check selected between two identical expressions. {@code _MSE} is read
     * the same way for every family, so it is returned directly.
     */
    public static double mse(GLMModel gLMModel) {
        return gLMModel._output._training_metrics._MSE;
    }

    /**
     * Returns the null deviance recorded in the model's training metrics.
     *
     * <p>The deviance is read through the concrete GLM metrics class: binomial-style models
     * (binomial and quasibinomial) carry {@link ModelMetricsBinomialGLM}, and other models are
     * expected to carry {@link ModelMetricsRegressionGLM}. Dispatching on the runtime type replaces
     * the previous conditional, whose two branches were identical.
     */
    public static double nullDeviance(GLMModel gLMModel) {
        ModelMetrics metrics = gLMModel._output._training_metrics;
        if (metrics instanceof ModelMetricsBinomialGLM) {
            return ((ModelMetricsBinomialGLM) metrics)._nullDev;
        }
        return ((ModelMetricsRegressionGLM) metrics)._nullDev;
    }

    /**
     * Fits binomial GLMs on the prostate data set and checks them against fixed reference values.
     *
     * <p>Covered:
     * <ul>
     *   <li>the iteration count, with and without {@code _max_iterations};</li>
     *   <li>coefficients, null/residual deviance, residual degrees of freedom and AIC of an
     *       unregularized fit;</li>
     *   <li>metrics recomputed by scoring the training frame (twice) matching the training metrics;</li>
     *   <li>{@code _prior} leaving all but the last beta entry unchanged and shifting the last one by
     *       the prior log-odds correction;</li>
     *   <li>a refit at the best lambda of a ridge lambda search reproducing that model's MSE and
     *       residual deviance, including after the response column is moved and ID is dropped.</li>
     * </ul>
     */
    @Test
    public void testProstate() throws InterruptedException, ExecutionException {
        GLMModel gLMModel = null;
        GLMModel gLMModel2 = null;
        GLMModel gLMModel3 = null;
        GLMModel gLMModel4 = null;
        GLMModel gLMModel5 = null;
        Frame parse_test_file = parse_test_file("smalldata/glm_test/prostate_cat_replaced.csv");
        Scope.enter();
        try {
            String[] strArr = {"Intercept", "AGE", "RACE.R2", "RACE.R3", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON"};
            double[] dArr = {-8.14867d, -0.01368d, 0.32337d, -0.38028d, 0.55964d, 0.49548d, 0.02794d, -0.01104d, 0.97704d};
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "CAPSULE";
            gLMParameters._ignored_columns = new String[]{"ID"};
            gLMParameters._train = parse_test_file._key;
            gLMParameters._lambda = new double[]{0.0d};
            gLMParameters._standardize = false;

            // With no iteration cap the best submodel is expected at iteration 5. The model is held
            // in a variable so the finally block still deletes it if the assertion fails.
            gLMModel5 = (GLMModel) new GLM(gLMParameters).trainModel().get();
            Assert.assertEquals(5, gLMModel5._output.bestSubmodel().iteration);
            gLMModel5.delete();
            gLMModel5 = null;

            gLMParameters._max_iterations = 4;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            Assert.assertEquals(4, gLMModel._output.bestSubmodel().iteration);
            System.out.println(gLMModel._output._model_summary);
            HashMap coefficients = gLMModel.coefficients();
            System.out.println(coefficients);
            for (int i = 0; i < strArr.length; i++) {
                Assert.assertEquals(dArr[i], ((Double) coefficients.get(strArr[i])).doubleValue(), 1.0E-4d);
            }
            Assert.assertEquals(512.3d, nullDeviance(gLMModel), 0.1d);
            Assert.assertEquals(378.3d, residualDeviance(gLMModel), 0.1d);
            Assert.assertEquals(371.0d, resDOF(gLMModel), 0.0d);
            Assert.assertEquals(396.3d, aic(gLMModel), 0.1d);
            testScoring(gLMModel, parse_test_file);

            // Score the training frame and compare the metrics fetched from the DKV with the training metrics.
            gLMModel.score(parse_test_file).delete();
            ModelMetricsBinomialGLM fromDKV = ModelMetricsBinomial.getFromDKV(gLMModel, parse_test_file);
            AUC2 auc2 = ((ModelMetricsBinomial) fromDKV)._auc;
            Assert.assertEquals(gLMModel._output._training_metrics.auc_obj()._auc, auc2._auc, 1.0E-8d);
            Assert.assertEquals(0.7654038154645615d, auc2.pr_auc(), 1.0E-8d);
            Assert.assertEquals(gLMModel._output._training_metrics._MSE, ((ModelMetricsBinomial) fromDKV)._MSE, 1.0E-8d);
            Assert.assertEquals(gLMModel._output._training_metrics._resDev, fromDKV._resDev, 1.0E-8d);

            // Scoring again must produce the same metrics; check the AUC of the new metrics object,
            // not the one captured from the first scoring.
            gLMModel.score(parse_test_file).delete();
            ModelMetricsBinomialGLM fromDKV2 = ModelMetricsBinomial.getFromDKV(gLMModel, parse_test_file);
            Assert.assertEquals(gLMModel._output._training_metrics.auc_obj()._auc, ((ModelMetricsBinomial) fromDKV2)._auc._auc, 1.0E-8d);
            Assert.assertEquals(gLMModel._output._training_metrics._MSE, ((ModelMetricsBinomial) fromDKV2)._MSE, 1.0E-8d);
            Assert.assertEquals(gLMModel._output._training_metrics._resDev, fromDKV2._resDev, 1.0E-8d);

            // A prior leaves every beta entry except the last unchanged; the last entry shifts by
            // log(ymu * (1 - prior) / (prior * (1 - ymu))).
            gLMParameters._prior = 1.0E-5d;
            gLMModel2 = (GLMModel) new GLM(gLMParameters).trainModel().get();
            for (int i2 = 0; i2 < gLMModel2.beta().length - 1; i2++) {
                Assert.assertEquals(gLMModel.beta()[i2], gLMModel2.beta()[i2], 1.0E-8d);
            }
            Assert.assertEquals(gLMModel.beta()[gLMModel.beta().length - 1] - Math.log((gLMModel._ymu[0] * (1.0d - 1.0E-5d)) / (1.0E-5d * (1.0d - gLMModel._ymu[0]))), gLMModel2.beta()[gLMModel.beta().length - 1], 1.0E-10d);

            // Ridge lambda search; its best lambda is reused below for a single-lambda refit.
            gLMParameters._lambda_search = true;
            gLMParameters._lambda = null;
            gLMParameters._alpha = new double[]{0.0d};
            gLMParameters._prior = -1.0d;
            gLMParameters._obj_reg = -1.0d;
            gLMParameters._max_iterations = 500;
            gLMParameters._objective_epsilon = 1.0E-6d;
            gLMModel3 = (GLMModel) new GLM(gLMParameters).trainModel().get();
            double d = gLMModel3._output._submodels[gLMModel3._output._best_lambda_idx].lambda_value;
            gLMParameters._lambda_search = false;
            gLMParameters._lambda = new double[]{d};
            ModelMetricsBinomialGLM fromDKV3 = ModelMetrics.getFromDKV(gLMModel3, parse_test_file);
            Assert.assertEquals("mse don't match, " + gLMModel3._output._training_metrics._MSE + " != " + ((ModelMetrics) fromDKV3)._MSE, gLMModel3._output._training_metrics._MSE, ((ModelMetrics) fromDKV3)._MSE, 1.0E-8d);
            Assert.assertEquals("res-devs don't match, " + gLMModel3._output._training_metrics._resDev + " != " + fromDKV3._resDev, gLMModel3._output._training_metrics._resDev, fromDKV3._resDev, 1.0E-4d);

            // Move the response to the last column, drop ID, and republish the modified frame.
            parse_test_file.add("CAPSULE", parse_test_file.remove("CAPSULE"));
            parse_test_file.remove("ID").remove();
            DKV.put(parse_test_file._key, parse_test_file);
            // NOTE(review): this DataInfo is constructed but its result is never used.
            new DataInfo(parse_test_file, (Frame) null, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false);
            gLMModel3.score(parse_test_file).delete();
            ModelMetricsBinomialGLM fromDKV4 = ModelMetrics.getFromDKV(gLMModel3, parse_test_file);
            Assert.assertEquals("mse don't match, " + gLMModel3._output._training_metrics._MSE + " != " + ((ModelMetrics) fromDKV4)._MSE, gLMModel3._output._training_metrics._MSE, ((ModelMetrics) fromDKV4)._MSE, 1.0E-8d);
            Assert.assertEquals("res-devs don't match, " + gLMModel3._output._training_metrics._resDev + " != " + fromDKV4._resDev, gLMModel3._output._training_metrics._resDev, fromDKV4._resDev, 1.0E-4d);

            // Single-lambda refit at the best lambda must match the lambda-search model.
            gLMModel4 = (GLMModel) new GLM(gLMParameters).trainModel().get();
            Assert.assertEquals("mse don't match, " + gLMModel3._output._training_metrics._MSE + " != " + gLMModel4._output._training_metrics._MSE, gLMModel3._output._training_metrics._MSE, gLMModel4._output._training_metrics._MSE, 1.0E-6d);
            Assert.assertEquals("res-devs don't match, " + gLMModel3._output._training_metrics._resDev + " != " + gLMModel4._output._training_metrics._resDev, gLMModel3._output._training_metrics._resDev, gLMModel4._output._training_metrics._resDev, 1.0E-4d);
            gLMModel4.score(parse_test_file).delete();
            ModelMetricsBinomialGLM fromDKV5 = ModelMetrics.getFromDKV(gLMModel4, parse_test_file);
            Assert.assertEquals("mse don't match, " + ((ModelMetrics) fromDKV4)._MSE + " != " + ((ModelMetrics) fromDKV5)._MSE, ((ModelMetrics) fromDKV4)._MSE, ((ModelMetrics) fromDKV5)._MSE, 1.0E-6d);
            Assert.assertEquals("res-devs don't match, " + fromDKV4._resDev + " != " + fromDKV5._resDev, fromDKV4._resDev, fromDKV5._resDev, 1.0E-4d);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel5 != null) {
                gLMModel5.delete();
            }
            if (gLMModel != null) {
                gLMModel.delete();
            }
            if (gLMModel2 != null) {
                gLMModel2.delete();
            }
            if (gLMModel3 != null) {
                gLMModel3.delete();
            }
            if (gLMModel4 != null) {
                gLMModel4.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Checks quasibinomial parameter validation, then fits a quasibinomial GLM and compares it with
     * reference values.
     *
     * <p>Validation: the default parameters validate, and a log link is rejected with
     * {@link IllegalArgumentException}. The fit is an unregularized quasibinomial/logit GLM with
     * 5-fold cross-validation on the prostate data; its coefficients, null/residual deviance and
     * residual degrees of freedom are compared with fixed reference values.
     */
    @Test
    public void testQuasibinomial() {
        GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.quasibinomial);
        GLM glm = new GLM(gLMParameters);
        gLMParameters.validate(glm);
        gLMParameters._link = GLMModel.GLMParameters.Link.log;
        try {
            gLMParameters.validate(glm);
            Assert.fail("should've thrown IAE");
        } catch (IllegalArgumentException expected) {
            // Expected: validation must reject the log link for the quasibinomial family.
        }
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file("smalldata/glm_test/prostate_cat_replaced.csv");
        Scope.enter();
        try {
            String[] strArr = {"Intercept", "AGE", "RACE.R2", "RACE.R3", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON"};
            double[] dArr = {-8.14867d, -0.01368d, 0.32337d, -0.38028d, 0.55964d, 0.49548d, 0.02794d, -0.01104d, 0.97704d};
            GLMModel.GLMParameters gLMParameters2 = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.quasibinomial);
            gLMParameters2._response_column = "CAPSULE";
            gLMParameters2._ignored_columns = new String[]{"ID"};
            gLMParameters2._train = parse_test_file._key;
            gLMParameters2._lambda = new double[]{0.0d};
            gLMParameters2._nfolds = 5;
            gLMParameters2._standardize = false;
            gLMParameters2._link = GLMModel.GLMParameters.Link.logit;
            gLMModel = (GLMModel) new GLM(gLMParameters2).trainModel().get();
            HashMap coefficients = gLMModel.coefficients();
            System.out.println(coefficients);
            for (int i = 0; i < strArr.length; i++) {
                Assert.assertEquals(dArr[i], ((Double) coefficients.get(strArr[i])).doubleValue(), 1.0E-4d);
            }
            Assert.assertEquals(512.3d, nullDeviance(gLMModel), 0.1d);
            Assert.assertEquals(378.3d, residualDeviance(gLMModel), 0.1d);
            Assert.assertEquals(371.0d, resDOF(gLMModel), 0.0d);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel != null) {
                gLMModel.deleteCrossValidationModels();
                gLMModel.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Fits an unregularized binomial GLM (at most 20 iterations) on the synthetic glm_test2 data.
     * Expects a training AUC of 1, and expects the AUC recomputed by scoring the training frame to
     * agree within 0.01.
     */
    @Test
    public void testSynthetic() throws Exception {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file("smalldata/glm_test/glm_test2.csv");
        Frame frame = null;
        Scope.enter();
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "response";
            gLMParameters._ignored_columns = new String[]{"ID"};
            gLMParameters._train = parse_test_file._key;
            gLMParameters._lambda = new double[]{0.0d};
            gLMParameters._standardize = false;
            gLMParameters._max_iterations = 20;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            System.out.println("beta = " + Arrays.toString(gLMModel.beta()));
            // Expected value first, so a failure message reports the two values the right way round.
            Assert.assertEquals(1.0d, auc(gLMModel), 1.0E-4d);
            frame = gLMModel.score(parse_test_file);
            Assert.assertEquals(auc(gLMModel), ModelMetricsBinomial.getFromDKV(gLMModel, parse_test_file)._auc._auc, 0.01d);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.remove();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            if (frame != null) {
                frame.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Regression test for PUBDEV-1839. Trains a Poisson GLM for {@code bikes} on the repro training
     * set, uses the repro test set as the validation frame, and runs {@code testScoring} on the test
     * frame.
     */
    @Test
    public void testCitibikeReproPUBDEV1839() throws Exception {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file("smalldata/jira/pubdev_1839_repro_train.csv");
        Frame parse_test_file2 = parse_test_file("smalldata/jira/pubdev_1839_repro_test.csv");
        Scope.enter();
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson);
            gLMParameters._response_column = "bikes";
            gLMParameters._train = parse_test_file._key;
            gLMParameters._valid = parse_test_file2._key;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            testScoring(gLMModel, parse_test_file2);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.remove();
            parse_test_file2.remove();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Regression test for PUBDEV-1953. Trains a Poisson GLM for {@code bikes} on the small citibike
     * training set, uses the small test set as the validation frame, and runs {@code testScoring}
     * on the test frame.
     */
    @Test
    public void testCitibikeReproPUBDEV1953() throws Exception {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file("smalldata/glm_test/citibike_small_train.csv");
        Frame parse_test_file2 = parse_test_file("smalldata/glm_test/citibike_small_test.csv");
        Scope.enter();
        try {
            // The family is given to the constructor; it is not reassigned afterwards.
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson);
            gLMParameters._response_column = "bikes";
            gLMParameters._train = parse_test_file._key;
            gLMParameters._valid = parse_test_file2._key;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
            testScoring(gLMModel, parse_test_file2);
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.remove();
            parse_test_file2.remove();
            if (gLMModel != null) {
                gLMModel.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 3-fold cross-validated binomial GLM with lambda search on the prostate data, keeping
     * the cross-validation models. No metrics are asserted; the test only checks that training
     * completes.
     */
    @Test
    public void testXval() {
        GLMModel gLMModel = null;
        Frame parse_test_file = parse_test_file("smalldata/glm_test/prostate_cat_replaced.csv");
        try {
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
            gLMParameters._response_column = "CAPSULE";
            gLMParameters._ignored_columns = new String[]{"ID"};
            gLMParameters._train = parse_test_file._key;
            gLMParameters._lambda_search = true;
            gLMParameters._nfolds = 3;
            gLMParameters._standardize = false;
            gLMParameters._keep_cross_validation_models = true;
            gLMModel = (GLMModel) new GLM(gLMParameters).trainModel().get();
        } finally {
            // Runs exactly once, whether or not the test body failed.
            parse_test_file.delete();
            if (gLMModel != null) {
                // Remove the kept fold models before the main model itself.
                gLMModel.deleteCrossValidationModels();
                gLMModel.delete();
            }
        }
    }

    /**
     * Checks that a lambda search over a user-supplied lambda sequence matches standalone fits.
     *
     * <p>The check runs for the multinomial (response RACE) and binomial (response CAPSULE)
     * families, alpha in {0, 0.5, 1}, and every solver except AUTO, COORDINATE_DESCENT_NAIVE,
     * GRADIENT_DESCENT_LH and GRADIENT_DESCENT_SQERR. For each lambda on the path it compares the
     * penalized objective of the path's standardized coefficients with that of a model fit at that
     * single lambda. The tolerance is twice the single fit's objective epsilon.
     */
    @Test
    public void testCustomLambdaSearch() {
        Key make = Key.make("prostate");
        Frame parse_test_file = parse_test_file(make, "smalldata/glm_test/prostate_cat_replaced.csv");
        try {
            GLMModel.GLMParameters.Family[] familyArr = {GLMModel.GLMParameters.Family.multinomial, GLMModel.GLMParameters.Family.binomial};
            for (GLMModel.GLMParameters.Family family : familyArr) {
                String response = family == GLMModel.GLMParameters.Family.multinomial ? "RACE" : "CAPSULE";
                for (double d : new double[]{0.0d, 0.5d, 1.0d}) {
                    for (GLMModel.GLMParameters.Solver solver : GLMModel.GLMParameters.Solver.values()) {
                        if (solver == GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE
                                || solver == GLMModel.GLMParameters.Solver.AUTO
                                || solver == GLMModel.GLMParameters.Solver.GRADIENT_DESCENT_LH
                                || solver == GLMModel.GLMParameters.Solver.GRADIENT_DESCENT_SQERR) {
                            continue;
                        }
                        GLMModel gLMModel = null;
                        GLMModel gLMModel2 = null;
                        Scope.enter();
                        try {
                            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters(family);
                            gLMParameters._train = make;
                            gLMParameters._alpha = new double[]{d};
                            gLMParameters._solver = solver;
                            gLMParameters._lambda = new double[]{10.0d, 1.0d, 0.1d, 1.0E-5d, 0.0d};
                            gLMParameters._lambda_search = true;
                            gLMParameters._response_column = response;
                            gLMModel = new GLM(gLMParameters).trainModel().get();
                            GLMModel.RegularizationPath regularizationPath = gLMModel.getRegularizationPath();
                            for (int i2 = 0; i2 < gLMParameters._lambda.length; i2++) {
                                double lambda = gLMParameters._lambda[i2];
                                GLMModel.GLMParameters gLMParameters2 = new GLMModel.GLMParameters(family);
                                gLMParameters2._train = make;
                                gLMParameters2._alpha = new double[]{d};
                                gLMParameters2._solver = solver;
                                gLMParameters2._lambda = new double[]{lambda};
                                gLMParameters2._lambda_search = false;
                                gLMParameters2._response_column = response;
                                gLMParameters2._beta_epsilon = 1.0E-5d;
                                gLMParameters2._objective_epsilon = 1.0E-8d;
                                gLMModel2 = new GLM(gLMParameters2).trainModel().get();
                                double[] pathBeta = regularizationPath._coefficients_std[i2];
                                double[] singleBeta = family == GLMModel.GLMParameters.Family.multinomial
                                        ? ArrayUtils.flat(gLMModel2._output.getNormBetaMultinomial())
                                        : gLMModel2._output.getNormBeta();
                                System.out.println(ArrayUtils.pprint(new double[][]{singleBeta, pathBeta}));
                                // Path deviance = single fit's null deviance * (1 - explained deviance at
                                // this path index); both deviances are halved for the objective.
                                double pathDeviance = 0.5d * gLMModel2._output._training_metrics.null_deviance() * (1.0d - regularizationPath._explained_deviance_train[i2]);
                                double singleDeviance = 0.5d * gLMModel2._output._training_metrics.residual_deviance();
                                double nobs = gLMModel._nobs;
                                if (family == GLMModel.GLMParameters.Family.multinomial) {
                                    // Work on copies so the zeroing below does not touch the model's arrays.
                                    singleBeta = singleBeta.clone();
                                    pathBeta = pathBeta.clone();
                                    int nclasses = gLMModel._output.nclasses();
                                    int perClass = singleBeta.length / nclasses;
                                    Assert.assertEquals(perClass * nclasses, singleBeta.length);
                                    // Zero the last entry of every per-class block in both vectors so it is
                                    // excluded from the penalty (presumably the per-class intercept).
                                    for (int i3 = perClass - 1; i3 < singleBeta.length; i3 += perClass) {
                                        singleBeta[i3] = 0.0d;
                                        pathBeta[i3] = 0.0d;
                                    }
                                }
                                double singleObjective = (singleDeviance / nobs)
                                        + ((1.0d - d) * lambda * 0.5d * ArrayUtils.l2norm2(singleBeta, true))
                                        + (d * lambda * ArrayUtils.l1norm(singleBeta, true));
                                double pathObjective = (pathDeviance / nobs)
                                        + ((1.0d - d) * lambda * 0.5d * ArrayUtils.l2norm2(pathBeta, true))
                                        + (d * lambda * ArrayUtils.l1norm(pathBeta, true));
                                Assert.assertEquals(singleObjective, pathObjective, 2.0d * gLMParameters._objective_epsilon);
                                gLMModel2.delete();
                                gLMModel2 = null;
                            }
                        } finally {
                            // Release both models even when an assertion above fails.
                            if (gLMModel2 != null) {
                                gLMModel2.delete();
                            }
                            if (gLMModel != null) {
                                gLMModel.delete();
                            }
                            Scope.exit();
                        }
                    }
                }
            }
        } finally {
            parse_test_file.delete();
        }
    }

    /**
     * Fits lasso (alpha = 1) gaussian GLMs on {@code smalldata/glm_test/arcene.csv}, using the
     * first column as the response, and checks the active-predictor limits:
     * <ul>
     *   <li>with lambda search ({@code _nlambdas = 35}) under IRLSM and COORDINATE_DESCENT the
     *       model must contain exactly {@code _nlambdas} submodels;</li>
     *   <li>with COORDINATE_DESCENT, {@code _max_active_predictors = 100} and 100 lambdas the model
     *       rank must not exceed 100;</li>
     *   <li>without lambda search and {@code _max_active_predictors = 250} the rank must not
     *       exceed 250.</li>
     * </ul>
     *
     * <p>Every run trains into the same model key and is tracked by the single {@code model}
     * local. The local is reset to {@code null} right after each explicit delete. As a result,
     * the {@code finally} block deletes each model at most once, and it still removes a model
     * whose assertion failed.
     */
    @Test
    public void testArcene() throws InterruptedException, ExecutionException {
        Key parsedKey = Key.make("arcene_parsed");
        Key modelKey = Key.make("arcene_model");
        GLMModel model = null;
        Frame fr = parse_test_file(parsedKey, "smalldata/glm_test/arcene.csv");
        Scope.enter();
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            params._lambda = null;
            params._response_column = fr._names[0];
            params._train = parsedKey;
            params._lambda_search = true;
            params._nlambdas = 35;
            params._lambda_min_ratio = 0.18d;
            params._max_iterations = 100000;
            params._max_active_predictors = 10000;
            params._alpha = new double[]{1.0d};
            for (GLMModel.GLMParameters.Solver solver : new GLMModel.GLMParameters.Solver[]{GLMModel.GLMParameters.Solver.IRLSM, GLMModel.GLMParameters.Solver.COORDINATE_DESCENT}) {
                params._solver = solver;
                // Both solvers train into modelKey; `model` is re-read from the DKV after each run.
                new GLM(params, modelKey).trainModel().get();
                model = (GLMModel) DKV.get(modelKey).get();
                System.out.println(model._output._model_summary);
                Assert.assertEquals(params._nlambdas, model._output._submodels.length);
                System.out.println(model._output._training_metrics);
            }
            model.delete();
            model = null; // already deleted; keep the finally block from deleting it again

            // Lambda search capped at 100 active predictors.
            params._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            params._max_active_predictors = 100;
            params._lambda_min_ratio = 0.01d;
            params._nlambdas = 100;
            new GLM(params, modelKey).trainModel().get();
            model = (GLMModel) DKV.get(modelKey).get();
            Assert.assertTrue(model._output.rank() <= params._max_active_predictors);
            System.out.println(model._output._model_summary);
            System.out.println(model._output._training_metrics);
            System.out.println("============================================================================================================");
            model.delete();
            model = null;

            // Single fit without lambda search, capped at 250 active predictors.
            params._max_active_predictors = 250;
            params._lambda = null;
            params._lambda_search = false;
            new GLM(params, modelKey).trainModel().get();
            model = (GLMModel) DKV.get(modelKey).get();
            Assert.assertTrue(model._output.rank() <= params._max_active_predictors);
            System.out.println(model._output._model_summary);
            System.out.println(model._output._training_metrics);
            System.out.println("============================================================================================================");
            model.delete();
            model = null;
        } finally {
            fr.delete();
            if (model != null) {
                model.delete();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains an L_BFGS gaussian GLM on {@code smalldata/glm_test/arcene.csv}, using the first
     * column as the response, with alpha = 0, no explicit lambda and up to 100000 active
     * predictors. It then scores the training frame and passes those predictions to
     * {@code testJavaScoring} with a relative epsilon of 0.0. Presumably that call checks the
     * predictions against the generated POJO scorer.
     *
     * <p>The training frame, model and predictions are removed on both the success path and the
     * failure path, and the failure is re-thrown.
     *
     * <p>NOTE(review): any value returned by {@code testJavaScoring} is discarded here. Confirm
     * whether it should be asserted.
     */
    @Test
    public void testBigPOJO() {
        Frame arcene = parse_test_file(Key.make("arcene_parsed"), "smalldata/glm_test/arcene.csv");
        GLMModel model = null;
        Frame predictions = null;
        try {
            Scope.enter();
            GLMModel.GLMParameters parms = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            parms._train = arcene._key;
            parms._response_column = arcene._names[0];
            parms._solver = GLMModel.GLMParameters.Solver.L_BFGS;
            parms._alpha = new double[]{0.0d};
            parms._lambda = null;
            parms._max_active_predictors = 100000;

            model = (GLMModel) new GLM(parms).trainModel().get();
            predictions = model.score(arcene);
            model.testJavaScoring(arcene, predictions, 0.0d);

            arcene.delete();
            if (model != null) {
                model.delete();
            }
            if (predictions != null) {
                predictions.delete();
            }
            Scope.exit(new Key[0]);
        } catch (Throwable failure) {
            arcene.delete();
            if (model != null) {
                model.delete();
            }
            if (predictions != null) {
                predictions.delete();
            }
            Scope.exit(new Key[0]);
            throw failure;
        }
    }

    /**
     * Fits a lasso (alpha = 1) gaussian GLM with lambda search on
     * {@code smalldata/glm_test/Abalone.gz}. The column at index 8 is the response. The fitted
     * model and the frame are then handed to {@code testScoring}.
     *
     * <p>The parsed frame is registered with {@link Scope} so that {@code Scope.exit} cleans it
     * up. The model is deleted explicitly on both the success path and the failure path.
     */
    @Test
    public void testAbalone() {
        Scope.enter();
        GLMModel model = null;
        try {
            Frame abalone = parse_test_file("smalldata/glm_test/Abalone.gz");
            Scope.track(new Frame[]{abalone});
            String response = abalone._names[8];

            GLMModel.GLMParameters parms = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            parms._train = abalone._key;
            parms._response_column = response;
            parms._lambda_search = true;
            parms._alpha = new double[]{1.0d};

            model = (GLMModel) new GLM(parms).trainModel().get();
            testScoring(model, abalone);

            if (model != null) {
                model.delete();
            }
            Scope.exit(new Key[0]);
        } catch (Throwable failure) {
            if (model != null) {
                model.delete();
            }
            Scope.exit(new Key[0]);
            throw failure;
        }
    }

    /**
     * Fits an unpenalized (lambda = 0, alpha = 0) gaussian GLM with p-values enabled on a
     * 5-row frame and prints its coefficients. No assertion is made on the result.
     *
     * <p>The response is {@code z}, and {@code w} (1, 0, 1, 0, 1) is the weights column. Every
     * row with non-zero weight (indices 0, 2 and 4) has {@code y == 0}, so {@code y} is
     * effectively an all-zero predictor. {@code x} holds the same values as the response.
     *
     * <p>The frame (stored in the DKV under key "test") and the model are removed in a
     * {@code finally} block. A failing fit therefore does not leave keys behind for later tests.
     */
    @Test
    public void testZeroedColumn() {
        Vec x = Vec.makeCon(Vec.newKey(), new double[]{1.0d, 2.0d, 3.0d, 4.0d, 5.0d});
        Vec y = Vec.makeCon(x.group().addVec(), new double[]{0.0d, 1.0d, 0.0d, 1.0d, 0.0d});
        Vec z = Vec.makeCon(Vec.newKey(), new double[]{1.0d, 2.0d, 3.0d, 4.0d, 5.0d});
        Vec w = Vec.makeCon(x.group().addVec(), new double[]{1.0d, 0.0d, 1.0d, 0.0d, 1.0d});
        Frame fr = new Frame(Key.make("test"), new String[]{"x", "y", "z", "w"}, new Vec[]{x, y, z, w});
        DKV.put(fr);
        GLMModel model = null;
        try {
            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            params._train = fr._key;
            params._lambda = new double[]{0.0d};
            params._alpha = new double[]{0.0d};
            params._compute_p_values = true;
            params._response_column = "z";
            params._weights_column = "w";
            model = new GLM(params).trainModel().get();
            System.out.println(model.coefficients());
        } finally {
            if (model != null) {
                model.delete();
            }
            fr.delete();
        }
    }

    /**
     * For every GLM family except quasibinomial and ordinal, fits a GLM on
     * {@code ./smalldata/gbm_test/BostonHousing.csv}. It then checks that the mean of the
     * per-row deviances from {@code computeDeviances} matches the training metric, within a
     * relative tolerance of 1e-6.
     *
     * <p>The training metric is the logloss when the model has two or more classes, and the
     * mean residual deviance otherwise.
     *
     * <p>Response selection depends on the family:
     * <ul>
     *   <li>binomial and fractionalbinomial use "chas", converted to categorical;</li>
     *   <li>multinomial uses "rad", converted to categorical;</li>
     *   <li>every other family uses the last column.</li>
     * </ul>
     *
     * <p>All frames and the model created in an iteration are declared outside the
     * {@code try}. The {@code finally} block therefore removes them even when training, scoring
     * or an assertion fails.
     */
    @Test
    public void testDeviances() {
        for (GLMModel.GLMParameters.Family family : GLMModel.GLMParameters.Family.values()) {
            if (family == GLMModel.GLMParameters.Family.quasibinomial || family == GLMModel.GLMParameters.Family.ordinal) {
                continue;
            }
            Frame frame = null;
            Frame preds = null;
            Frame deviances = null;
            GLMModel model = null;
            try {
                frame = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");
                GLMModel.GLMParameters params = new GLMModel.GLMParameters();
                params._train = frame._key;
                String response = frame.lastVecName();
                if (family == GLMModel.GLMParameters.Family.binomial || family == GLMModel.GLMParameters.Family.multinomial || family == GLMModel.GLMParameters.Family.fractionalbinomial) {
                    response = family == GLMModel.GLMParameters.Family.multinomial ? "rad" : "chas";
                    // Replace the numeric column with a categorical copy and drop the original vec.
                    Vec numeric = frame.remove(response);
                    frame.add(response, numeric.toCategoricalVec());
                    numeric.remove();
                    DKV.put(frame);
                }
                params._response_column = response;
                params._family = family;
                model = new GLM(params).trainModel().get();
                preds = model.score(frame);
                deviances = model.computeDeviances(frame, preds, "myDeviances");
                double mean = deviances.anyVec().mean();
                double tolerance = 1.0E-6d * Math.abs(mean);
                // NOTE(review): these field accesses depend on the static type of
                // _training_metrics. The binomial, multinomial and regression metric classes
                // presumably differ, so confirm that the original source cast to each of them.
                if (model._output.nclasses() == 2) {
                    Assert.assertEquals(mean, model._output._training_metrics._logloss, tolerance);
                } else if (model._output.nclasses() > 2) {
                    Assert.assertEquals(mean, model._output._training_metrics._logloss, tolerance);
                } else {
                    Assert.assertEquals(mean, model._output._training_metrics._mean_residual_deviance, tolerance);
                }
            } finally {
                if (frame != null) {
                    frame.delete();
                }
                if (deviances != null) {
                    deviances.delete();
                }
                if (preds != null) {
                    preds.delete();
                }
                if (model != null) {
                    model.delete();
                }
            }
        }
    }

    /**
     * Checks how gaussian GLM scoring handles a categorical level that was never seen in
     * training.
     *
     * <p>The model is trained on "color" with levels {blue, red}. It is scored on a frame whose
     * "color" has levels {blue, red, yellow}, and the last row is "yellow".
     * <ul>
     *   <li>With {@code MissingValuesHandling.Skip}, the yellow row is expected to score as the
     *       intercept alone.</li>
     *   <li>With {@code MissingValuesHandling.MeanImputation}, it is expected to score as
     *       intercept + color.red.</li>
     * </ul>
     *
     * <p>All frames, models and prediction frames are declared outside the {@code try} and
     * deleted in {@code finally}. A failed assertion therefore no longer leaks them, including
     * the DKV keys "train" and "test".
     */
    @Test
    public void testUnseenLevels() {
        Scope.enter();
        Frame train = null;
        Frame test = null;
        GLMModel skipModel = null;
        Frame skipPreds = null;
        Frame imputePreds = null;
        GLMModel imputeModel = null;
        try {
            Vec trainColor = Vec.makeCon(Vec.newKey(), new double[]{1.0d, 0.0d, 0.0d, 1.0d, 1.0d});
            trainColor.setDomain(new String[]{"blue", "red"});
            train = new Frame(Key.make("train"), new String[]{"color", "label"}, new Vec[]{trainColor, trainColor.makeCopy((String[]) null)});
            DKV.put(train);
            Vec testColor = Vec.makeCon(Vec.newKey(), new double[]{1.0d, 0.0d, 0.0d, 2.0d});
            testColor.setDomain(new String[]{"blue", "red", "yellow"});
            test = new Frame(Key.make("test"), new String[]{"color", "label"}, new Vec[]{testColor, Vec.makeCon(testColor.group().addVec(), new double[]{1.0d, 0.0d, 0.0d, 0.0d})});
            DKV.put(test);

            GLMModel.GLMParameters params = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
            params._train = train._key;
            params._response_column = "label";
            params._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.Skip;
            skipModel = new GLM(params).trainModel().get();
            System.out.println("coefficients = " + skipModel.coefficients());
            double intercept = ((Double) skipModel.coefficients().get("Intercept")).doubleValue();
            skipPreds = skipModel.score(test);
            Assert.assertEquals(intercept + ((Double) skipModel.coefficients().get("color.red")).doubleValue(), skipPreds.vec(0).at(0L), 0.0d);
            Assert.assertEquals(intercept + ((Double) skipModel.coefficients().get("color.blue")).doubleValue(), skipPreds.vec(0).at(1L), 0.0d);
            Assert.assertEquals(intercept + ((Double) skipModel.coefficients().get("color.blue")).doubleValue(), skipPreds.vec(0).at(2L), 0.0d);
            // Unseen "yellow" level under Skip: only the intercept contributes.
            Assert.assertEquals(intercept, skipPreds.vec(0).at(3L), 0.0d);

            params._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
            imputeModel = new GLM(params).trainModel().get();
            imputePreds = imputeModel.score(test);
            double intercept2 = ((Double) imputeModel.coefficients().get("Intercept")).doubleValue();
            System.out.println("coefficients = " + imputeModel.coefficients());
            Assert.assertEquals(intercept2 + ((Double) imputeModel.coefficients().get("color.red")).doubleValue(), imputePreds.vec(0).at(0L), 0.0d);
            Assert.assertEquals(intercept2 + ((Double) imputeModel.coefficients().get("color.blue")).doubleValue(), imputePreds.vec(0).at(1L), 0.0d);
            Assert.assertEquals(intercept2 + ((Double) imputeModel.coefficients().get("color.blue")).doubleValue(), imputePreds.vec(0).at(2L), 0.0d);
            // Unseen "yellow" level under MeanImputation: expected to score like "red".
            Assert.assertEquals(intercept2 + ((Double) imputeModel.coefficients().get("color.red")).doubleValue(), imputePreds.vec(0).at(3L), 0.0d);
        } finally {
            if (train != null) {
                train.delete();
            }
            if (test != null) {
                test.delete();
            }
            if (skipModel != null) {
                skipModel.delete();
            }
            if (skipPreds != null) {
                skipPreds.delete();
            }
            if (imputePreds != null) {
                imputePreds.delete();
            }
            if (imputeModel != null) {
                imputeModel.delete();
            }
            Scope.exit(new Key[0]);
        }
    }

    static {
        // Initializes the synthetic `$assertionsDisabled` flag from the class's desired
        // assertion status. The decompiled `assert` checks elsewhere in this class test it as
        // `if (!$assertionsDisabled && ...) throw new AssertionError()`.
        boolean assertionsOn = GLMTest.class.desiredAssertionStatus();
        $assertionsDisabled = !assertionsOn;
    }
}
