package hex.tree.drf;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsRegression;
import hex.SplitFrame;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRFModel;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.H2O;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.Triple;
import water.util.VecUtils;

/* loaded from: input_file:hex/tree/drf/DRFTest.class */
public class DRFTest extends TestUtil {
    // Expected training metrics asserted by testNoRowWeights (and presumably by the other row-weight tests).
    // NOTE(review): these are non-final and never assigned in the visible code; presumably a static
    // initializer outside this view sets them — confirm.
    static double _AUC;
    static double _MSE;
    static double _LogLoss;
    // Decompiler-exposed form of the javac-synthesized flag that backs `assert` statements; it is read
    // explicitly in testAutoRebalance. Presumably initialized in a static block outside this view.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Per-dataset hook that reshapes a freshly parsed frame before training and reports which
     * column holds the response. unifyFrame also accepts a negative result {@code ~idx}.
     */
    public static abstract class PrepData {
        /** Mutates {@code fr} in place and returns the response column index. */
        abstract int prep(Frame fr);

        PrepData() {
            // nothing to initialize
        }
    }

    /** Before any test runs, waits (via TestUtil) until the H2O cloud has size 1. */
    @BeforeClass
    public static void stall() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /** Varargs shorthand for a {@code String[]} literal; hands back the very array it was given. */
    static String[] s(String... values) {
        final String[] result = values;
        return result;
    }

    /**
     * OOB confusion-matrix check on iris with one tree.
     * Settings: ntree=1, nbins=20, min_rows=1, max_depth=20.
     */
    @Test
    public void testClassIris1() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which put double[] elements into a double[] initializer and did not compile.
        double[][] expectedCM = new double[][]{
                {15, 0, 0},
                {0, 18, 0},
                {0, 1, 17}};
        basicDRFTestOOBE_Classification("./smalldata/iris/iris.csv", "iris.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                // Response is the last column.
                return frame.numCols() - 1;
            }
        }, 1, 20, 1, 20, expectedCM, s("Iris-setosa", "Iris-versicolor", "Iris-virginica"));
    }

    /**
     * OOB confusion-matrix check on iris with five trees.
     * Settings: ntree=5, nbins=20, min_rows=1, max_depth=20.
     */
    @Test
    public void testClassIris5() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {43, 0, 0},
                {0, 37, 4},
                {0, 4, 39}};
        basicDRFTestOOBE_Classification("./smalldata/iris/iris.csv", "iris5.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                // Response is the last column.
                return frame.numCols() - 1;
            }
        }, 5, 20, 1, 20, expectedCM, s("Iris-setosa", "Iris-versicolor", "Iris-virginica"));
    }

    /**
     * OOB confusion-matrix check on cars, predicting "cylinders" with one tree.
     * The "name" column is dropped first.
     */
    @Test
    public void testClassCars1() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {0, 2, 0, 0, 0},
                {0, 58, 6, 4, 0},
                {0, 1, 0, 0, 0},
                {1, 3, 4, 25, 1},
                {0, 0, 0, 2, 37}};
        basicDRFTestOOBE_Classification("./smalldata/junit/cars.csv", "cars.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                frame.remove("name").remove();
                return frame.find("cylinders");
            }
        }, 1, 20, 1, 20, expectedCM, s("3", "4", "5", "6", "8"));
    }

    /**
     * OOB confusion-matrix check on cars, predicting "cylinders" with five trees.
     * The "name" column is dropped first.
     */
    @Test
    public void testClassCars5() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {1, 2, 0, 0, 0},
                {0, 177, 1, 5, 0},
                {0, 2, 0, 0, 0},
                {0, 6, 1, 67, 1},
                {0, 0, 0, 2, 84}};
        basicDRFTestOOBE_Classification("./smalldata/junit/cars.csv", "cars5.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                frame.remove("name").remove();
                return frame.find("cylinders");
            }
        }, 5, 20, 1, 20, expectedCM, s("3", "4", "5", "6", "8"));
    }

    /**
     * Drops seven columns from poker100, then expects the model builder to reject the training run
     * with {@link H2OModelBuilderIllegalArgumentException}. Per the test name, presumably only
     * constant predictors remain.
     */
    @Test
    public void testConstantCols() throws Throwable {
        PrepData dropSevenAtIndexThree = new PrepData() {
            @Override
            int prep(Frame fr) {
                // Whatever currently sits at index 3 is removed, seven times in a row.
                int remaining = 7;
                while (remaining-- > 0) {
                    fr.remove(3).remove();
                }
                return 3;
            }
        };
        try {
            basicDRFTestOOBE_Classification("./smalldata/poker/poker100", "poker.hex", dropSevenAtIndexThree,
                    1, 20, 1, 20, (double[][]) null, null);
            Assert.fail();
        } catch (H2OModelBuilderIllegalArgumentException expected) {
            // Expected outcome: the builder refuses this configuration.
        }
    }

    /** OOB confusion-matrix check on a file containing infinities (currently ignored). */
    @Test
    @Ignore
    public void testBadData() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {6, 0},
                {9, 1}};
        basicDRFTestOOBE_Classification("./smalldata/junit/drf_infinities.csv", "infinitys.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                return frame.find("DateofBirth");
            }
        }, 1, 20, 1, 20, expectedCM, s("0", "1"));
    }

    /**
     * OOB confusion-matrix check on the Kaggle credit sample, with "MonthlyIncome" dropped.
     * NOTE(review): this method carries no @Test annotation, so JUnit 4 does not run it —
     * confirm whether that is intentional.
     */
    public void testCreditSample1() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {46294, 202},
                {3187, 107}};
        basicDRFTestOOBE_Classification("./smalldata/kaggle/creditsample-training.csv.gz", "credit.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                frame.remove("MonthlyIncome").remove();
                return frame.find("SeriousDlqin2yrs");
            }
        }, 1, 20, 1, 20, expectedCM, s("0", "1"));
    }

    /** OOB confusion-matrix check on prostate, predicting CAPSULE with one tree after dropping ID. */
    @Test
    public void testCreditProstate1() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {0, 70},
                {0, 59}};
        basicDRFTestOOBE_Classification("./smalldata/logreg/prostate.csv", "prostate.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                frame.remove("ID").remove();
                return frame.find("CAPSULE");
            }
        }, 1, 20, 1, 20, expectedCM, s("0", "1"));
    }

    /**
     * OOB MSE check for regressing prostate AGE (ID dropped).
     * Settings: ntree=1, nbins=20, min_rows=1, max_depth=10.
     */
    @Test
    public void testCreditProstateRegression1() throws Throwable {
        final double expectedMSE = 63.13182273942728d;
        PrepData dropIdPredictAge = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("ID").remove();
                return fr.find("AGE");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/logreg/prostate.csv", "prostateRegression.hex", dropIdPredictAge,
                1, 20, 1, 10, expectedMSE);
    }

    /**
     * OOB MSE check for regressing prostate AGE (ID dropped).
     * Settings: ntree=5, nbins=20, min_rows=1, max_depth=10.
     */
    @Test
    public void testCreditProstateRegression5() throws Throwable {
        final double expectedMSE = 59.713095855920244d;
        PrepData dropIdPredictAge = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("ID").remove();
                return fr.find("AGE");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/logreg/prostate.csv", "prostateRegression5.hex", dropIdPredictAge,
                5, 20, 1, 10, expectedMSE);
    }

    /**
     * OOB MSE check for regressing prostate AGE (ID dropped).
     * Settings: ntree=50, nbins=20, min_rows=1, max_depth=10.
     */
    @Test
    public void testCreditProstateRegression50() throws Throwable {
        final double expectedMSE = 46.88452885668735d;
        PrepData dropIdPredictAge = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("ID").remove();
                return fr.find("AGE");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/logreg/prostate.csv", "prostateRegression50.hex", dropIdPredictAge,
                50, 20, 1, 10, expectedMSE);
    }

    /**
     * OOB confusion-matrix check on the 300x300 czechboard, predicting C3 with 50 trees.
     * C2 is converted to categorical before training.
     */
    @Test
    public void testCzechboard() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {0, 45000},
                {0, 45000}};
        basicDRFTestOOBE_Classification("./smalldata/gbm_test/czechboard_300x300.csv", "czechboard_300x300.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                Vec numericC2 = frame.remove("C2");
                frame.add("C2", VecUtils.toCategoricalVec(numericC2));
                numericC2.remove();
                return frame.find("C3");
            }
        }, 50, 20, 1, 20, expectedCM, s("0", "1"));
    }

    /**
     * OOB MSE check for regressing C3 on the 30k categorical-level dataset.
     * Settings: ntree=50, nbins=20, min_rows=10, max_depth=5.
     */
    @Test
    public void test30kUnseenLevels() throws Throwable {
        final double expectedMSE = 0.25040633586487d;
        PrepData responseC3 = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("C3");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/gbm_test/30k_cattest.csv", "cat30k", responseC3,
                50, 20, 10, 5, expectedMSE);
    }

    /**
     * Classification smoke test on zipped prostate data. Columns 1, 4, 5 and 8 are made
     * categorical and column 0 is dropped. No confusion matrix is asserted.
     */
    @Test
    public void testProstate() throws Throwable {
        PrepData categoricalize = new PrepData() {
            @Override
            int prep(Frame fr) {
                final int[] catCols = {1, 4, 5, 8};
                String[] originalNames = fr.names().clone();
                // Pass a copy so the index list stays untouched regardless of what remove(int[]) does.
                Vec[] numericVecs = fr.remove(catCols.clone());
                for (int k = 0; k < catCols.length; k++) {
                    fr.add(originalNames[catCols[k]], VecUtils.toCategoricalVec(numericVecs[k]));
                }
                for (Vec old : numericVecs) {
                    old.remove();
                }
                fr.remove(0).remove();
                return 4;
            }
        };
        basicDRFTestOOBE_Classification("./smalldata/prostate/prostate.csv.zip", "prostate2.zip.hex", categoricalize,
                4, 2, 1, 1, (double[][]) null, s("0", "1"));
    }

    /** OOB confusion-matrix check on the alphabet categorical dataset, predicting y with one tree. */
    @Test
    public void testAlphabet() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {670, 0},
                {0, 703}};
        basicDRFTestOOBE_Classification("./smalldata/gbm_test/alphabet_cattest.csv", "alphabetClassification.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                return frame.find("y");
            }
        }, 1, 20, 1, 20, expectedCM, s("0", "1"));
    }

    /**
     * Expects zero OOB MSE when regressing y on the alphabet dataset.
     * Settings: ntree=1, nbins=20, min_rows=1, max_depth=10.
     */
    @Test
    public void testAlphabetRegression() throws Throwable {
        final double expectedMSE = 0.0d;
        PrepData responseY = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("y");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/gbm_test/alphabet_cattest.csv", "alphabetRegression.hex", responseY,
                1, 20, 1, 10, expectedMSE);
    }

    /**
     * Expects zero OOB MSE with 26 bins and a depth-1 tree.
     * Settings: ntree=1, nbins=26, min_rows=1, max_depth=1.
     */
    @Test
    public void testAlphabetRegression2() throws Throwable {
        final double expectedMSE = 0.0d;
        PrepData responseY = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("y");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/gbm_test/alphabet_cattest.csv", "alphabetRegression2.hex", responseY,
                1, 26, 1, 1, expectedMSE);
    }

    /**
     * Like testAlphabetRegression2 but with 25 bins, which yields a nonzero OOB MSE.
     * Settings: ntree=1, nbins=25, min_rows=1, max_depth=1.
     */
    @Test
    public void testAlphabetRegression3() throws Throwable {
        final double expectedMSE = 0.24007225096411577d;
        PrepData responseY = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("y");
            }
        };
        basicDRFTestOOBE_Regression("./smalldata/gbm_test/alphabet_cattest.csv", "alphabetRegression3.hex", responseY,
                1, 25, 1, 1, expectedMSE);
    }

    /**
     * OOB confusion-matrix check on airlines, predicting IsDepDelayed with 7 trees (currently ignored).
     * Columns not known before departure are dropped first.
     */
    @Test
    @Ignore
    public void testAirlines() throws Throwable {
        // A plain 2-D literal replaces the decompiled ard(new double[]{ard(...), ...}) form,
        // which did not compile.
        double[][] expectedCM = new double[][]{
                {7958, 11707},
                {2709, 19024}};
        basicDRFTestOOBE_Classification("./smalldata/airlines/allyears2k_headers.zip", "airlines.hex", new PrepData() {
            @Override
            int prep(Frame frame) {
                for (String col : new String[]{"DepTime", "ArrTime", "ActualElapsedTime", "AirTime", "ArrDelay", "DepDelay", "Cancelled", "CancellationCode", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"}) {
                    frame.remove(col).remove();
                }
                return frame.find("IsDepDelayed");
            }
        }, 7, 20, 1, 20, expectedCM, s("NO", "YES"));
    }

    /**
     * Runs {@code prepData} on {@code frame}, records the response column name in the parameters,
     * and moves the response to the last column. For classification the moved column is converted
     * to categorical.
     *
     * @return for classification, the original (pre-conversion) response Vec that the caller should
     *         dispose of; for regression, {@code null}
     */
    static Vec unifyFrame(DRFModel.DRFParameters dRFParameters, Frame frame, PrepData prepData, boolean z) {
        int idx = prepData.prep(frame);
        if (idx < 0) {
            idx = ~idx; // a negative result encodes the index as its bitwise complement
        }
        final String respName = frame._names[idx];
        dRFParameters._response_column = frame.names()[idx];
        final Vec resp = frame.vecs()[idx];
        if (!z) {
            // Regression: reattach the same Vec at the end.
            frame.remove(idx);
            frame.add(respName, resp);
            return null;
        }
        Vec original = frame.remove(idx);
        frame.add(respName, VecUtils.toCategoricalVec(resp));
        return original;
    }

    /**
     * Classification entry point for {@link #basicDRF}. The arguments are reordered into
     * basicDRF's (ntree, max_depth, nbins, ..., min_rows) order, and there is no test file.
     *
     * @param fnametrain   training file path
     * @param hexnametrain unused name, forwarded to basicDRF
     * @param prep         frame preparation hook
     * @param ntree        number of trees
     * @param nbins        histogram bins (numeric and categorical)
     * @param min_rows     minimum rows per leaf
     * @param max_depth    maximum tree depth
     * @param expCM        expected OOB confusion matrix, or {@code null} to skip the check
     * @param expRespDom   expected response domain
     */
    public void basicDRFTestOOBE_Classification(String fnametrain, String hexnametrain, PrepData prep, int ntree, int nbins, int min_rows, int max_depth, double[][] expCM, String[] expRespDom) throws Throwable {
        final String noTestFile = null;
        final double mseNotChecked = -1.0d;
        basicDRF(fnametrain, hexnametrain, noTestFile, prep, ntree, max_depth, nbins, true, min_rows, expCM, mseNotChecked, expRespDom);
    }

    /**
     * Regression entry point for {@link #basicDRF}. The arguments are reordered into basicDRF's
     * order; there is no test file, confusion matrix or domain.
     *
     * @param fnametrain   training file path
     * @param hexnametrain unused name, forwarded to basicDRF
     * @param prep         frame preparation hook
     * @param ntree        number of trees
     * @param nbins        histogram bins (numeric and categorical)
     * @param min_rows     minimum rows per leaf
     * @param max_depth    maximum tree depth
     * @param expMSE       expected OOB MSE
     */
    public void basicDRFTestOOBE_Regression(String fnametrain, String hexnametrain, PrepData prep, int ntree, int nbins, int min_rows, int max_depth, double expMSE) throws Throwable {
        final String noTestFile = null;
        final double[][] noCM = null;
        final String[] noDomain = null;
        basicDRF(fnametrain, hexnametrain, noTestFile, prep, ntree, max_depth, nbins, false, min_rows, noCM, expMSE, noDomain);
    }

    /**
     * Trains one DRF model on {@code fnametrain} and checks its metrics.
     *
     * <p>The training frame is reshaped by {@code prep} via unifyFrame, which moves the response to
     * the last column and, for classification, makes it categorical. The method asserts the tree
     * count and checks {@code testJavaScoring} on a freshly re-parsed copy of the training file.
     * It then checks either:
     * <ul>
     *   <li>the confusion matrix and response domain, for classification with a non-null {@code expCM}; or</li>
     *   <li>the MSE against {@code expMSE} with relative tolerance 1e-10, for regression.</li>
     * </ul>
     * Metrics come from the training frame, or from {@code fnametest} when it is non-null.
     *
     * @param fnametrain     training file path
     * @param hexnametrain   not used by this method (kept for call-site compatibility)
     * @param fnametest      optional scoring file, or {@code null}
     * @param prep           frame preparation hook
     * @param ntree          number of trees
     * @param max_depth      maximum tree depth
     * @param nbins          histogram bins (numeric and categorical)
     * @param classification whether to treat the response as categorical
     * @param min_rows       minimum rows per leaf
     * @param expCM          expected confusion matrix (classification only), or {@code null}
     * @param expMSE         expected MSE (regression only)
     * @param expRespDom     expected response domain (classification only)
     * @throws Throwable any assertion or H2O failure; all frames and the model are cleaned up first
     */
    public void basicDRF(String fnametrain, String hexnametrain, String fnametest, PrepData prep, int ntree, int max_depth, int nbins, boolean classification, int min_rows, double[][] expCM, double expMSE, String[] expRespDom) throws Throwable {
        Scope.enter();
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        Frame testFrame = null;
        Frame testPreds = null;
        Frame trainFrame = null;
        Frame scoreFrame = null;
        Frame scorePreds = null;
        DRFModel model = null;
        try {
            trainFrame = parse_test_file(fnametrain);
            Vec replacedResponse = unifyFrame(parms, trainFrame, prep, classification);
            if (replacedResponse != null) {
                Scope.track(replacedResponse);
            }
            DKV.put(trainFrame._key, trainFrame);
            parms._train = trainFrame._key;
            parms._response_column = ((Frame) DKV.getGet(parms._train)).lastVecName();
            parms._ntrees = ntree;
            parms._max_depth = max_depth;
            parms._min_rows = min_rows;
            parms._stopping_rounds = 0;
            parms._nbins = nbins;
            parms._nbins_cats = nbins;
            parms._mtries = -1;
            parms._sample_rate = 0.666670024394989d; // presumably the float literal 0.66667f widened to double
            parms._seed = 4294967298L;               // (1L << 32) + 2
            DRF job = new DRF(parms);
            model = (DRFModel) job.trainModel().get();
            Log.info(new Object[]{model._output});
            Assert.assertTrue(job.isStopped());

            ModelMetrics mm;
            if (fnametest != null) {
                testFrame = parse_test_file(fnametest);
                testPreds = model.score(testFrame);
                mm = ModelMetrics.getFromDKV(model, testFrame);
            } else {
                mm = ModelMetrics.getFromDKV(model, trainFrame);
            }
            Assert.assertEquals("Number of trees differs!", ntree, model._output._ntrees);

            // Score a freshly parsed copy of the training file and compare against the generated Java scorer.
            scoreFrame = parse_test_file(fnametrain);
            scorePreds = model.score(scoreFrame);
            Assert.assertTrue(model.testJavaScoring(scoreFrame, scorePreds, 1.0E-15d));

            if (classification && expCM != null) {
                double[][] actualCM = mm.cm()._cm;
                Assert.assertTrue("Expected: " + Arrays.deepToString(expCM) + ", Got: " + Arrays.deepToString(actualCM), Arrays.deepEquals(actualCM, expCM));
                String[][] domains = model._output._domains;
                Assert.assertArrayEquals("CM domain differs!", expRespDom, domains[domains.length - 1]);
                Log.info(new Object[]{"\nOOB Training CM:\n" + mm.cm().toASCII()});
                Log.info(new Object[]{"\nTraining CM:\n" + ModelMetrics.getFromDKV(model, scoreFrame).cm().toASCII()});
            } else if (!classification) {
                double mse = mm.mse();
                Assert.assertTrue("Expected: " + expMSE + ", Got: " + mse, Math.abs(expMSE - mse) <= 1.0E-10d * Math.abs(expMSE + mse));
                Log.info(new Object[]{"\nOOB Training MSE: " + mse});
                Log.info(new Object[]{"\nTraining MSE: " + ModelMetrics.getFromDKV(model, scoreFrame).mse()});
            }
        } finally {
            // Single cleanup path for success and failure (previously duplicated in try and catch).
            if (trainFrame != null) {
                trainFrame.remove();
            }
            if (testFrame != null) {
                testFrame.remove();
            }
            if (model != null) {
                model.delete();
            }
            if (testPreds != null) {
                testPreds.delete();
            }
            if (scoreFrame != null) {
                scoreFrame.delete();
            }
            if (scorePreds != null) {
                scorePreds.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Manual benchmark (ignored): trains DRF regressions for "Sales" on a local file over a grid
     * of depths. Records the chunk count and wall-clock time per run and writes them as CSV to a
     * local desktop path.
     */
    @Test
    @Ignore
    public void testAutoRebalance() {
        final String trainPath = "/Users/ludirehak/Downloads/train.csv.zip";
        final boolean warmUp = true;
        if (warmUp) {
            // Five warm-up passes; their models are discarded immediately.
            for (int pass = 0; pass < 5; pass++) {
                Frame frame = null;
                Scope.enter();
                try {
                    frame = parse_test_file(trainPath);
                    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                    parms._train = frame._key;
                    parms._response_column = "Sales";
                    parms._nbins = 1000;
                    parms._ntrees = 10;
                    parms._max_depth = 20;
                    parms._mtries = -1;
                    parms._min_rows = 10.0d;
                    parms._seed = 1234L;
                    new DRF(parms).trainModel().get().delete();
                } finally {
                    if (frame != null) {
                        frame.remove();
                    }
                    // Previously skipped on failure, which left the Scope unbalanced.
                    Scope.exit();
                }
            }
        }
        int[] depths = {2, 5, 10, 15, 20};
        int[] chunkCounts = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
        boolean[] rebalanceFlags = {true};
        int[] ntreesValues = {10};
        int total = chunkCounts.length * depths.length * rebalanceFlags.length * ntreesValues.length;
        double[] execTimes = new double[total];
        int[] chunksSeen = new int[total];
        int[] depthUsed = new int[total];
        boolean[] rebalanced = new boolean[total];
        int[] treesBuilt = new int[total];
        int run = 0;
        for (int depth : depths) {
            for (int ntrees : ntreesValues) {
                for (boolean rebalanceMe : rebalanceFlags) {
                    // NOTE(review): the chunk count is never applied to the frame or the parameters,
                    // so every inner iteration trains on identical input — confirm intent.
                    for (int ignoredChunkCount : chunkCounts) {
                        long start = System.currentTimeMillis();
                        Scope.enter();
                        Frame frame = parse_test_file(trainPath);
                        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                        parms._train = frame._key;
                        parms._response_column = "Sales";
                        parms._nbins = 1000;
                        parms._mtries = -1;
                        parms._min_rows = 10.0d;
                        parms._seed = 1234L;
                        parms._ntrees = ntrees;
                        parms._max_depth = depth;
                        DRF job = new DRF(parms);
                        DRFModel model = job.trainModel().get();
                        Assert.assertEquals(model._output._ntrees, parms._ntrees);
                        int nChunks = job.train().anyVec().nChunks();
                        model.delete();
                        frame.remove();
                        Scope.exit();
                        execTimes[run] = (System.currentTimeMillis() - start) / 1000.0d;
                        // Hand-expanded `assert rebalanceMe || nChunks == 22`. A real `assert` cannot be used
                        // because this class declares the $assertionsDisabled field explicitly.
                        if (!rebalanceMe && !$assertionsDisabled && nChunks != 22) {
                            throw new AssertionError();
                        }
                        chunksSeen[run] = nChunks;
                        depthUsed[run] = depth;
                        rebalanced[run] = rebalanceMe;
                        treesBuilt[run] = model._output._ntrees;
                        Log.info(new Object[]{"Iteration " + (run + 1) + " out of " + execTimes.length});
                        Log.info(new Object[]{" DEPTH: " + depthUsed[run] + " NTREES: " + treesBuilt[run] + " CHUNKS: " + chunksSeen[run] + " EXECUTION TIME: " + execTimes[run] + " Rebalanced: " + rebalanceMe + " WarmedUp: " + warmUp});
                        run++;
                    }
                }
            }
        }
        BufferedWriter out = null;
        try {
            out = new BufferedWriter(new FileWriter("/Users/ludirehak/Desktop/DRFTestRebalance3.txt"));
            out.write("max_depth,ntrees,nbins,min_rows,chunks,execution_time,rebalanceMe,warmUp");
            out.newLine();
            for (int k = 0; k < execTimes.length; k++) {
                // Exactly one field per header column (the old ",," emitted an extra empty column).
                out.write(depthUsed[k] + "," + treesBuilt[k] + ",1000,10," + chunksSeen[k] + "," + execTimes[k] + "," + (rebalanced[k] ? 1 : 0) + "," + (warmUp ? 1 : 0));
                out.newLine();
            }
        } catch (Exception e) {
            Log.info(new Object[]{"Fail", e});
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (Exception closeFailure) {
                    Log.info(new Object[]{"Fail", closeFailure});
                }
            }
        }
    }

    /**
     * Rebalances covtype.20k to several chunk counts with auto-rebalance disabled. Asserts that the
     * final training MSE agrees to within 1e-10 across chunkings.
     */
    @Test
    public void testChunks() {
        // Only the first mses.length entries are exercised (the trailing 500 is unused).
        int[] chunkCounts = {1, 13, 19, 39, 500};
        double[] mses = new double[4];
        for (int i = 0; i < mses.length; i++) {
            Scope.enter();
            Frame raw = null;
            Frame frame = null;
            DRFModel model = null;
            try {
                raw = parse_test_file("smalldata/covtype/covtype.20k.data");
                Key dest = Key.make("df.rebalanced.hex");
                RebalanceDataSet rb = new RebalanceDataSet(raw, dest, chunkCounts[i]);
                H2O.submitTask(rb);
                rb.join();
                raw.delete();
                raw = null;
                frame = (Frame) DKV.get(dest).get();
                // Column 54 (C55) is the response; the replaced numeric Vec is tracked for cleanup.
                Scope.track(frame.replace(54, frame.vecs()[54].toCategoricalVec()));
                DKV.put(frame);
                DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                parms._train = frame._key;
                parms._response_column = "C55";
                parms._ntrees = 10;
                parms._seed = 1234L;
                parms._auto_rebalance = false;
                model = new DRF(parms).trainModel().get();
                Assert.assertEquals(parms._ntrees, model._output._ntrees);
                mses[i] = model._output._scored_train[model._output._scored_train.length - 1]._mse;
            } finally {
                // Previously a failure leaked the frames/model into DKV and left the Scope open.
                if (model != null) {
                    model.delete();
                }
                if (raw != null) {
                    raw.delete();
                }
                if (frame != null) {
                    frame.remove();
                }
                Scope.exit();
            }
        }
        for (int i = 0; i < mses.length; i++) {
            Log.info(new Object[]{"trial: " + i + " -> MSE: " + mses[i]});
        }
        for (double mse : mses) {
            Assert.assertEquals(mses[0], mse, 1.0E-10d);
        }
    }

    /**
     * Trains the same single-tree DRF five times on covtype.20k rebalanced to 256 chunks. Asserts
     * that the final training MSE is identical to within 1e-15.
     */
    @Test
    public void testReproducibility() {
        final int trials = 5;
        double[] mses = new double[trials];
        Frame raw = null;
        Frame frame = null;
        Scope.enter();
        try {
            raw = parse_test_file("smalldata/covtype/covtype.20k.data");
            Key dest = Key.make("df.rebalanced.hex");
            RebalanceDataSet rb = new RebalanceDataSet(raw, dest, 256);
            H2O.submitTask(rb);
            rb.join();
            raw.delete();
            raw = null;
            frame = (Frame) DKV.get(dest).get();
            for (int i = 0; i < trials; i++) {
                DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                parms._train = frame._key;
                parms._response_column = "C55";
                parms._nbins = 1000;
                parms._ntrees = 1;
                parms._max_depth = 8;
                parms._mtries = -1;
                parms._min_rows = 10.0d;
                parms._seed = 1234L;
                DRFModel model = new DRF(parms).trainModel().get();
                try {
                    Assert.assertEquals(parms._ntrees, model._output._ntrees);
                    mses[i] = model._output._scored_train[model._output._scored_train.length - 1]._mse;
                } finally {
                    model.delete();
                }
            }
        } finally {
            if (raw != null) {
                raw.delete();
            }
            if (frame != null) {
                frame.remove();
            }
            Scope.exit();
        }
        // Assertions run after cleanup. Previously a failing assertion triggered a second
        // frame.remove() on an already-removed frame.
        for (int i = 0; i < mses.length; i++) {
            Log.info(new Object[]{"trial: " + i + " -> MSE: " + mses[i]});
        }
        for (double mse : mses) {
            Assert.assertEquals(mses[0], mse, 1.0E-15d);
        }
    }

    /**
     * Trains a 7-tree balanced-class DRF on airlines rebalanced to 256 chunks. Columns not known
     * before departure are dropped first. Asserts the training MSE against a recorded value.
     */
    @Test
    public void testReproducibilityAirline() {
        final int trials = 1;
        double[] mses = new double[trials];
        Frame raw = null;
        Frame frame = null;
        Scope.enter();
        try {
            raw = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
            Key dest = Key.make("df.rebalanced.hex");
            RebalanceDataSet rb = new RebalanceDataSet(raw, dest, 256);
            H2O.submitTask(rb);
            rb.join();
            raw.delete();
            raw = null;
            frame = (Frame) DKV.get(dest).get();
            for (String col : new String[]{"DepTime", "ArrTime", "ActualElapsedTime", "AirTime", "ArrDelay", "DepDelay", "Cancelled", "CancellationCode", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"}) {
                frame.remove(col).remove();
            }
            DKV.put(frame);
            for (int i = 0; i < trials; i++) {
                DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                parms._train = frame._key;
                parms._response_column = "IsDepDelayed";
                parms._nbins = 10;
                parms._nbins_cats = 1024;
                parms._ntrees = 7;
                parms._max_depth = 10;
                parms._binomial_double_trees = false;
                parms._mtries = -1;
                parms._min_rows = 1.0d;
                parms._sample_rate = 0.6320000290870667d; // presumably the float literal 0.632f widened to double
                parms._balance_classes = true;
                parms._seed = 4294967298L;                // (1L << 32) + 2
                DRFModel model = new DRF(parms).trainModel().get();
                try {
                    Assert.assertEquals(parms._ntrees, model._output._ntrees);
                    mses[i] = model._output._training_metrics.mse();
                } finally {
                    model.delete();
                }
            }
        } finally {
            if (raw != null) {
                raw.delete();
            }
            if (frame != null) {
                frame.remove();
            }
            Scope.exit();
        }
        // Assertions run after cleanup, so a failure no longer triggers a double frame.remove().
        for (int i = 0; i < mses.length; i++) {
            Log.info(new Object[]{"trial: " + i + " -> MSE: " + mses[i]});
        }
        for (double mse : mses) {
            Assert.assertEquals(0.20377446328850304d, mse, 1.0E-4d);
        }
    }

    /**
     * Ignored benchmark: trains one depth-3 tree on a local 1M-row airline file, validates on a
     * local test file, and asserts recorded training and validation AUCs.
     */
    @Test
    @Ignore
    public void testAirline() {
        Frame train = null;
        Frame valid = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file(Key.make("air.hex"), "/users/arno/sz_bench_data/train-1m.csv");
            valid = parse_test_file(Key.make("airt.hex"), "/users/arno/sz_bench_data/test.csv");
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._valid = valid._key;
            parms._ignored_columns = new String[]{"Origin", "Dest"};
            parms._response_column = "dep_delayed_15min";
            parms._nbins = 20;
            parms._nbins_cats = 1024;
            // NOTE(review): a random choice here makes the fixed AUC expectations below
            // nondeterministic — confirm which setting the recorded values correspond to.
            parms._binomial_double_trees = new Random().nextBoolean();
            parms._ntrees = 1;
            parms._max_depth = 3;
            parms._mtries = -1;
            parms._sample_rate = 0.6320000290870667d;
            parms._min_rows = 10.0d;
            parms._seed = 12L;
            model = new DRF(parms).trainModel().get();
            Log.info(new Object[]{"Training set AUC:   " + model._output._training_metrics.auc_obj()._auc});
            Log.info(new Object[]{"Validation set AUC: " + model._output._validation_metrics.auc_obj()._auc});
            Assert.assertEquals(0.6498819479528417d, model._output._training_metrics.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(0.6479974533672835d, model._output._validation_metrics.auc_obj()._auc, 1.0E-8d);
        } finally {
            // Previously the model leaked and Scope.exit was skipped when an assertion failed.
            if (model != null) {
                model.delete();
            }
            if (train != null) {
                train.remove();
            }
            if (valid != null) {
                valid.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a small DRF on unweighted data. Asserts the training AUC, MSE and logloss against the
     * class-level expectations {@code _AUC}, {@code _MSE} and {@code _LogLoss}.
     */
    @Test
    public void testNoRowWeights() {
        Frame tfr = null;
        DRFModel model = null;
        Scope.enter();
        try {
            tfr = parse_test_file("smalldata/junit/no_weights.csv");
            DKV.put(tfr);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = tfr._key;
            parms._response_column = "response";
            parms._seed = 234L;
            parms._min_rows = 1.0d;
            parms._max_depth = 2;
            parms._ntrees = 3;
            model = (DRFModel) new DRF(parms).trainModel().get();
            // _training_metrics is declared with the base ModelMetrics type, so the binomial view needs
            // an explicit cast. The decompiled assignment without it does not compile.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._training_metrics;
            Assert.assertEquals(_AUC, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(_MSE, mm.mse(), 1.0E-8d);
            Assert.assertEquals(_LogLoss, mm.logloss(), 1.0E-6d);
        } finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 3-tree, depth-2 DRF on {@code weights_all_ones.csv} using the {@code "weight"}
     * column as row weights and checks the training AUC, MSE and logloss against the shared
     * reference values {@link #_AUC}, {@link #_MSE} and {@link #_LogLoss}.
     * Presumably verifies that unit weights reproduce the unweighted result; confirm against data.
     */
    @Test
    public void testRowWeightsOne() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/weights_all_ones.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "response";
            parms._weights_column = "weight";
            parms._seed = 234L;
            parms._min_rows = 1.0d;
            parms._max_depth = 2;
            parms._ntrees = 3;
            model = (DRFModel) new DRF(parms).trainModel().get();
            // Cast needed to reach the binomial-specific auc_obj() accessor.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._training_metrics;
            Assert.assertEquals(_AUC, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(_MSE, mm.mse(), 1.0E-8d);
            Assert.assertEquals(_LogLoss, mm.logloss(), 1.0E-6d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 3-tree, depth-2 DRF on {@code weights_all_twos.csv} with the {@code "weight"}
     * column as row weights and {@code _min_rows = 2}, then checks the training AUC, MSE and
     * logloss against the shared reference values {@link #_AUC}, {@link #_MSE},
     * {@link #_LogLoss}. Presumably the doubled min_rows offsets the doubled weights; confirm.
     */
    @Test
    public void testRowWeightsTwo() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/weights_all_twos.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "response";
            parms._weights_column = "weight";
            parms._seed = 234L;
            parms._min_rows = 2.0d;
            parms._max_depth = 2;
            parms._ntrees = 3;
            model = (DRFModel) new DRF(parms).trainModel().get();
            // Cast needed to reach the binomial-specific auc_obj() accessor.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._training_metrics;
            Assert.assertEquals(_AUC, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(_MSE, mm.mse(), 1.0E-8d);
            Assert.assertEquals(_LogLoss, mm.logloss(), 1.0E-6d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 3-tree, depth-2 DRF on {@code weights_all_tiny.csv} with the {@code "weight"}
     * column as row weights and a small {@code _min_rows} (0.01242), then checks the training
     * AUC, MSE and logloss against the shared reference values {@link #_AUC}, {@link #_MSE},
     * {@link #_LogLoss}.
     */
    @Test
    public void testRowWeightsTiny() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/weights_all_tiny.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "response";
            parms._weights_column = "weight";
            parms._seed = 234L;
            parms._min_rows = 0.01242d;
            parms._max_depth = 2;
            parms._ntrees = 3;
            model = (DRFModel) new DRF(parms).trainModel().get();
            // Cast needed to reach the binomial-specific auc_obj() accessor.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._training_metrics;
            Assert.assertEquals(_AUC, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(_MSE, mm.mse(), 1.0E-8d);
            Assert.assertEquals(_LogLoss, mm.logloss(), 1.0E-6d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 3-tree, depth-2 DRF on {@code no_weights_shuffled.csv} without weights and checks
     * the training AUC, MSE and logloss against hard-coded expected values.
     */
    @Test
    public void testNoRowWeightsShuffled() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/no_weights_shuffled.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "response";
            parms._seed = 234L;
            parms._min_rows = 1.0d;
            parms._max_depth = 2;
            parms._ntrees = 3;
            model = (DRFModel) new DRF(parms).trainModel().get();
            // Cast needed to reach the binomial-specific auc_obj() accessor.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._training_metrics;
            Assert.assertEquals(1.0d, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(0.029017857142857144d, mm.mse(), 1.0E-8d);
            Assert.assertEquals(0.10824081452821664d, mm.logloss(), 1.0E-6d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 3-tree, depth-2 DRF on {@code weights.csv} with the {@code "weight"} column as row
     * weights. It checks two sets of metrics against hard-coded values:
     * <ul>
     *   <li>the training metrics recorded on the model;</li>
     *   <li>the metrics looked up via {@link ModelMetricsBinomial#getFromDKV} after re-scoring
     *       the training frame.</li>
     * </ul>
     */
    @Test
    public void testRowWeights() {
        Frame train = null;
        Frame preds = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/weights.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "response";
            parms._weights_column = "weight";
            parms._seed = 234L;
            parms._min_rows = 1.0d;
            parms._max_depth = 2;
            parms._ntrees = 3;
            model = (DRFModel) new DRF(parms).trainModel().get();

            // Cast needed to reach the binomial-specific auc_obj() accessor.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._training_metrics;
            Assert.assertEquals(1.0d, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(0.05823863636363636d, mm.mse(), 1.0E-8d);
            Assert.assertEquals(0.21035264541934587d, mm.logloss(), 1.0E-6d);

            // Tracked in a local so the prediction frame is released even if an assertion fails.
            preds = model.score(parms.train());
            // NOTE(review): presumably score() registers metrics in the DKV that getFromDKV
            // looks up for this (model, frame) pair; verify.
            ModelMetricsBinomial scored = ModelMetricsBinomial.getFromDKV(model, parms.train());
            Assert.assertEquals(1.0d, scored.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(0.0154320987654321d, scored.mse(), 1.0E-8d);
            Assert.assertEquals(0.08349430638608361d, scored.logloss(), 1.0E-8d);
        } finally {
            if (preds != null) {
                preds.remove();
            }
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 5-tree, depth-5 DRF with 3-fold cross-validation on the airlines sample.
     * Several delay/time columns are dropped and {@code IsDepDelayed} is the response.
     * The test checks the cross-validation AUC, MSE and logloss against hard-coded values.
     * Currently disabled via {@link Ignore}.
     */
    @Test
    @Ignore
    public void testNFold() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
            String[] droppedColumns = {"DepTime", "ArrTime", "ActualElapsedTime", "AirTime", "ArrDelay", "DepDelay", "Cancelled", "CancellationCode", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"};
            for (String column : droppedColumns) {
                train.remove(column).remove();
            }
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "IsDepDelayed";
            parms._seed = 234L;
            parms._min_rows = 2.0d;
            parms._nfolds = 3;
            parms._max_depth = 5;
            parms._ntrees = 5;
            model = (DRFModel) new DRF(parms).trainModel().get();
            // Cast needed to reach the binomial-specific auc_obj() accessor.
            ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._cross_validation_metrics;
            Assert.assertEquals(0.7276154565296726d, mm.auc_obj()._auc, 1.0E-8d);
            Assert.assertEquals(0.21211607823987555d, mm.mse(), 1.0E-8d);
            Assert.assertEquals(0.6121968624307211d, mm.logloss(), 1.0E-6d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.deleteCrossValidationModels();
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Smoke test: trains a 5-tree, depth-5 DRF on the airlines sample with 3-fold
     * cross-validation and {@code _balance_classes = true}. Nothing is asserted; the test fails
     * only if training throws.
     */
    @Test
    public void testNFoldBalanceClasses() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
            String[] droppedColumns = {"DepTime", "ArrTime", "ActualElapsedTime", "AirTime", "ArrDelay", "DepDelay", "Cancelled", "CancellationCode", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"};
            for (String column : droppedColumns) {
                train.remove(column).remove();
            }
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "IsDepDelayed";
            parms._seed = 234L;
            parms._min_rows = 2.0d;
            parms._nfolds = 3;
            parms._max_depth = 5;
            parms._balance_classes = true;
            parms._ntrees = 5;
            model = (DRFModel) new DRF(parms).trainModel().get();
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.deleteCrossValidationModels();
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains the same DRF twice on {@code weights.csv} and checks that both runs produce the same
     * cross-validation AUC, MSE and logloss (within 1e-12). The run uses leave-one-out style
     * cross-validation: {@code _nfolds} equals the row count, with {@code Modulo} fold
     * assignment.
     */
    @Test
    public void testNfoldsOneVsRest() {
        Frame train = null;
        DRFModel model1 = null;
        DRFModel model2 = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/weights.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "response";
            parms._seed = 9999L;
            parms._min_rows = 2.0d;
            // Typed as Frame (not Keyed) so numRows() is available.
            parms._nfolds = (int) train.numRows();
            parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            parms._max_depth = 5;
            parms._ntrees = 5;
            model1 = (DRFModel) new DRF(parms).trainModel().get();
            model2 = (DRFModel) new DRF(parms).trainModel().get();
            ModelMetricsBinomial mm1 = (ModelMetricsBinomial) model1._output._cross_validation_metrics;
            ModelMetricsBinomial mm2 = (ModelMetricsBinomial) model2._output._cross_validation_metrics;
            Assert.assertEquals(mm1.auc_obj()._auc, mm2.auc_obj()._auc, 1.0E-12d);
            Assert.assertEquals(mm1.mse(), mm2.mse(), 1.0E-12d);
            Assert.assertEquals(mm1.logloss(), mm2.logloss(), 1.0E-12d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model1 != null) {
                model1.deleteCrossValidationModels();
                model1.delete();
            }
            if (model2 != null) {
                model2.deleteCrossValidationModels();
                model2.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Checks how DRF validates {@code _nfolds}. Training with {@code _nfolds = 0} must succeed.
     * Training with {@code 1} or {@code -99} must throw
     * {@link H2OModelBuilderIllegalArgumentException} from {@code trainModel()}.
     */
    @Test
    public void testNfoldsInvalidValues() {
        Frame train = null;
        DRFModel model = null;
        DRFModel modelNfolds1 = null;
        DRFModel modelNfoldsNeg = null;
        Scope.enter();
        try {
            train = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
            String[] droppedColumns = {"DepTime", "ArrTime", "ActualElapsedTime", "AirTime", "ArrDelay", "DepDelay", "Cancelled", "CancellationCode", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"};
            for (String column : droppedColumns) {
                train.remove(column).remove();
            }
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "IsDepDelayed";
            parms._seed = 234L;
            parms._min_rows = 2.0d;
            parms._max_depth = 5;
            parms._ntrees = 5;

            parms._nfolds = 0;
            model = (DRFModel) new DRF(parms).trainModel().get();

            parms._nfolds = 1;
            try {
                Log.info("Trying nfolds==1.");
                // Assigned so the finally block cleans up if validation wrongly lets it through.
                modelNfolds1 = (DRFModel) new DRF(parms).trainModel().get();
                Assert.fail("Should toss H2OModelBuilderIllegalArgumentException instead of reaching here");
            } catch (H2OModelBuilderIllegalArgumentException ignored) {
                // Expected: nfolds == 1 is rejected.
            }

            parms._nfolds = -99;
            try {
                Log.info("Trying nfolds==-99.");
                modelNfoldsNeg = (DRFModel) new DRF(parms).trainModel().get();
                Assert.fail("Should toss H2OModelBuilderIllegalArgumentException instead of reaching here");
            } catch (H2OModelBuilderIllegalArgumentException ignored) {
                // Expected: a negative nfolds is rejected.
            }
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            if (modelNfolds1 != null) {
                modelNfolds1.delete();
            }
            if (modelNfoldsNeg != null) {
                modelNfoldsNeg.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Checks that requesting N-fold cross-validation ({@code _nfolds = 2}) together with a
     * validation frame does not make {@code trainModel()} throw
     * {@link H2OModelBuilderIllegalArgumentException}.
     */
    @Test
    public void testNfoldsCVAndValidation() {
        Frame train = null;
        Frame valid = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/weights.csv");
            valid = parse_test_file("smalldata/junit/weights.csv");
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._valid = valid._key;
            parms._response_column = "response";
            parms._min_rows = 2.0d;
            parms._max_depth = 2;
            parms._nfolds = 2;
            parms._ntrees = 3;
            parms._seed = 11233L;
            try {
                Log.info("Trying N-fold cross-validation AND Validation dataset provided.");
                model = (DRFModel) new DRF(parms).trainModel().get();
            } catch (H2OModelBuilderIllegalArgumentException e) {
                // Same failure type and message as Assert.fail, but keeps the builder's
                // exception as the cause so the validation error is visible in the report.
                throw new AssertionError("Should not toss H2OModelBuilderIllegalArgumentException.", e);
            }
        } finally {
            if (train != null) {
                train.remove();
            }
            if (valid != null) {
                valid.remove();
            }
            if (model != null) {
                model.deleteCrossValidationModels();
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains two identically-seeded DRFs with 3-fold cross-validation on {@code cars_20mpg.csv}.
     * The {@code name} and {@code economy} columns are dropped and {@code economy_20mpg} is
     * converted to categorical. The test checks that both runs report the same
     * cross-validation AUC, MSE and logloss (within 1e-12).
     */
    @Test
    public void testNfoldsConsecutiveModelsSame() {
        Frame train = null;
        Vec numericResponse = null;
        DRFModel model1 = null;
        DRFModel model2 = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/cars_20mpg.csv");
            train.remove("name").remove();
            train.remove("economy").remove();
            // Swap the numeric response for a categorical copy; the original Vec is freed below.
            numericResponse = train.remove("economy_20mpg");
            train.add("economy_20mpg", VecUtils.toCategoricalVec(numericResponse));
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "economy_20mpg";
            parms._min_rows = 2.0d;
            parms._max_depth = 2;
            parms._nfolds = 3;
            parms._ntrees = 3;
            parms._seed = 77777L;
            model1 = (DRFModel) new DRF(parms).trainModel().get();
            model2 = (DRFModel) new DRF(parms).trainModel().get();
            ModelMetricsBinomial mm1 = (ModelMetricsBinomial) model1._output._cross_validation_metrics;
            ModelMetricsBinomial mm2 = (ModelMetricsBinomial) model2._output._cross_validation_metrics;
            Assert.assertEquals(mm1.auc_obj()._auc, mm2.auc_obj()._auc, 1.0E-12d);
            Assert.assertEquals(mm1.mse(), mm2.mse(), 1.0E-12d);
            Assert.assertEquals(mm1.logloss(), mm2.logloss(), 1.0E-12d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (numericResponse != null) {
                numericResponse.remove();
            }
            if (model1 != null) {
                model1.deleteCrossValidationModels();
                model1.delete();
            }
            if (model2 != null) {
                model2.deleteCrossValidationModels();
                model2.delete();
            }
            Scope.exit();
        }
    }

    /**
     * For each {@code _mtries} from 1 to 6, trains a DRF with 3-fold cross-validation on
     * {@code cars_20mpg.csv} (categorical {@code economy_20mpg} response). It checks that the
     * cross-validation metrics carry a non-null AUC.
     */
    @Test
    public void testMTrys() {
        for (int mtries = 1; mtries <= 6; mtries++) {
            // Declared per iteration so a failure in one iteration never re-removes objects
            // already cleaned up by the previous one.
            Frame train = null;
            Vec numericResponse = null;
            DRFModel model = null;
            Scope.enter();
            try {
                train = parse_test_file("smalldata/junit/cars_20mpg.csv");
                train.remove("name").remove();
                train.remove("economy").remove();
                numericResponse = train.remove("economy_20mpg");
                train.add("economy_20mpg", VecUtils.toCategoricalVec(numericResponse));
                DKV.put(train);
                DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                parms._train = train._key;
                parms._response_column = "economy_20mpg";
                parms._min_rows = 2.0d;
                parms._ntrees = 5;
                parms._max_depth = 5;
                parms._nfolds = 3;
                parms._mtries = mtries;
                model = (DRFModel) new DRF(parms).trainModel().get();
                Assert.assertNotNull(((ModelMetricsBinomial) model._output._cross_validation_metrics)._auc);
            } finally {
                if (train != null) {
                    train.remove();
                }
                if (numericResponse != null) {
                    numericResponse.remove();
                }
                if (model != null) {
                    model.deleteCrossValidationModels();
                    model.delete();
                }
                Scope.exit();
            }
        }
    }

    /**
     * Trains a DRF with {@code _mtries = -2} and 3-fold cross-validation on {@code cars_20mpg.csv}.
     * A constant column {@code constantCol} is added and {@code year} is ignored. The test
     * checks that the cross-validation metrics carry a non-null AUC.
     */
    @Test
    public void testMTryNegTwo() {
        Frame train = null;
        Vec numericResponse = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("smalldata/junit/cars_20mpg.csv");
            train.remove("name").remove();
            train.remove("economy").remove();
            numericResponse = train.remove("economy_20mpg");
            train.add("economy_20mpg", VecUtils.toCategoricalVec(numericResponse));
            train.add("constantCol", train.anyVec().makeCon(1.0d));
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "economy_20mpg";
            parms._ignored_columns = new String[]{"year"};
            parms._min_rows = 2.0d;
            parms._ntrees = 5;
            parms._max_depth = 5;
            parms._nfolds = 3;
            parms._mtries = -2;
            model = (DRFModel) new DRF(parms).trainModel().get();
            Assert.assertNotNull(((ModelMetricsBinomial) model._output._cross_validation_metrics)._auc);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (numericResponse != null) {
                numericResponse.remove();
            }
            if (model != null) {
                model.deleteCrossValidationModels();
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a 5-tree, depth-5 DRF on {@code cars.csv} ({@code name} dropped, response
     * {@code cylinders}) with {@code _mtries = 3} and {@code _sample_rate = 0.5}. It checks the
     * training MSE against a hard-coded value within 1e-4.
     */
    @Test
    public void testStochasticDRFEquivalent() {
        Frame train = null;
        DRFModel model = null;
        Scope.enter();
        try {
            train = parse_test_file("./smalldata/junit/cars.csv");
            train.remove("name").remove();
            DKV.put(train);
            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = train._key;
            parms._response_column = "cylinders";
            parms._seed = 234L;
            parms._min_rows = 2.0d;
            parms._max_depth = 5;
            parms._ntrees = 5;
            parms._mtries = 3;
            parms._sample_rate = 0.5d;
            model = (DRFModel) new DRF(parms).trainModel().get();
            Assert.assertEquals(0.12358322821934015d, model._output._training_metrics.mse(), 1.0E-4d);
        } finally {
            if (train != null) {
                train.remove();
            }
            if (model != null) {
                model.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Splits {@code ecology_model.csv} 50/50 into {@code train.hex}/{@code test.hex}. For every
     * combination of row sample rate, mtries fraction and per-tree column sample rate, it trains
     * a 2-tree DRF and logs each validation MSE with its settings, in ascending MSE order.
     * Nothing is asserted; the test fails only if training throws.
     */
    @Test
    public void testColSamplingPerTree() {
        Frame frame = null;
        Key[] splitKeys = new Key[0];
        try {
            frame = parse_test_file("./smalldata/gbm_test/ecology_model.csv");
            SplitFrame splitFrame = new SplitFrame(frame, new double[]{0.5d, 0.5d}, new Key[]{Key.make("train.hex"), Key.make("test.hex")});
            splitFrame.exec().get();
            splitKeys = splitFrame._destination_frames;

            float[] sampleRates = {0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
            float[] colSampleRates = {0.4f, 0.6f, 0.8f, 1.0f};
            float[] colSampleRatesPerTree = {0.4f, 0.6f, 0.8f, 1.0f};
            // Keyed by validation MSE so iteration is sorted; configurations with an identical
            // MSE overwrite each other.
            Map<Double, Triple<Float>> mseToRates = new TreeMap<>();
            // Number of non-response columns; loop-invariant.
            int numPredictors = frame.numCols() - 1;

            for (float sampleRate : sampleRates) {
                for (float colSampleRate : colSampleRates) {
                    for (float colSampleRatePerTree : colSampleRatesPerTree) {
                        // Fresh per iteration so a failed run never re-deletes the previous model.
                        DRFModel model = null;
                        Scope.enter();
                        try {
                            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                            parms._train = splitKeys[0];
                            parms._valid = splitKeys[1];
                            parms._response_column = "Angaus";
                            parms._seed = 12345L;
                            parms._min_rows = 1.0d;
                            parms._max_depth = 15;
                            parms._ntrees = 2;
                            parms._mtries = Math.max(1, (int) (colSampleRate * numPredictors));
                            parms._col_sample_rate_per_tree = colSampleRatePerTree;
                            parms._sample_rate = sampleRate;
                            model = (DRFModel) new DRF(parms).trainModel().get();
                            mseToRates.put(model._output._validation_metrics.mse(),
                                    new Triple<Float>(sampleRate, colSampleRate, colSampleRatePerTree));
                        } finally {
                            if (model != null) {
                                model.delete();
                            }
                            Scope.exit();
                        }
                    }
                }
            }

            for (Map.Entry<Double, Triple<Float>> entry : mseToRates.entrySet()) {
                Triple<Float> rates = entry.getValue();
                Log.info("MSE: " + entry.getKey() + ", row sample: " + rates.v1 + ", col sample: " + rates.v2 + ", col sample per tree: " + rates.v3);
            }
        } finally {
            if (frame != null) {
                frame.remove();
            }
            for (Key key : splitKeys) {
                if (key != null) {
                    key.remove();
                }
            }
        }
    }

    /**
     * Trains one DRF model per candidate {@code _min_split_improvement} value on a 50/50 split of
     * covtype.20k, using column 54 (converted to categorical) as the response. It then asserts
     * that the candidate with the lowest final validation logloss is not the first one (0.0).
     *
     * <p>Each model is deleted as soon as its logloss has been read. The reference is then cleared
     * so the {@code finally} cleanup only deletes a model that is still live (one whose scoring
     * step failed), never the same model twice.
     */
    @Test
    public void minSplitImprovement() {
        Frame frame = null;
        Key[] keyArr = null;
        DRFModel dRFModel = null;
        Scope.enter();
        try {
            frame = parse_test_file("smalldata/covtype/covtype.20k.data");
            // Register the vec returned by replace() with the current Scope so Scope.exit()
            // cleans it up (presumably the displaced numeric response vec).
            Scope.track(frame.replace(54, frame.vecs()[54].toCategoricalVec()));
            DKV.put(frame);
            SplitFrame splitFrame = new SplitFrame(frame, new double[]{0.5d, 0.5d},
                    new Key[]{Key.make("train.hex"), Key.make("valid.hex")});
            splitFrame.exec().get();
            keyArr = splitFrame._destination_frames;

            double[] minSplitImprovements = {0.0d, 1.0E-10d, 1.0E-8d, 1.0E-6d, 1.0E-4d, 0.01d};
            double[] validLoglosses = new double[minSplitImprovements.length];
            for (int i = 0; i < minSplitImprovements.length; i++) {
                DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                parms._train = keyArr[0];
                parms._valid = keyArr[1];
                parms._response_column = frame.names()[54];
                parms._min_split_improvement = minSplitImprovements[i];
                parms._ntrees = 20;
                // Score only once, after the last tree.
                parms._score_tree_interval = parms._ntrees;
                parms._max_depth = 15;
                parms._seed = 1234L;
                dRFModel = (DRFModel) new DRF(parms).trainModel().get();
                // Validation logloss from the last scoring-history entry.
                validLoglosses[i] = dRFModel._output._scored_valid[dRFModel._output._scored_valid.length - 1]._logloss;
                dRFModel.delete();
                // Already deleted: clear so neither a later failure nor finally deletes it again.
                dRFModel = null;
            }
            for (int i = 0; i < minSplitImprovements.length; i++) {
                Log.info(new Object[]{"min_split_improvement: " + minSplitImprovements[i] + " -> validation logloss: " + validLoglosses[i]});
            }
            int minIndex = ArrayUtils.minIndex(validLoglosses);
            Log.info(new Object[]{"Optimal min_split_improvement: " + minSplitImprovements[minIndex]});
            Assert.assertTrue("expected a non-zero min_split_improvement to give the lowest validation logloss",
                    0 != minIndex);
        } finally {
            if (dRFModel != null) {
                dRFModel.delete();
            }
            if (frame != null) {
                frame.delete();
            }
            // keyArr is still null if parsing or splitting failed; guard so cleanup cannot mask
            // the original exception with an NPE.
            if (keyArr != null) {
                for (Key key : keyArr) {
                    if (key != null) {
                        key.remove();
                    }
                }
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains one DRF model for every {@code HistogramType} constant on a 50/50 split of
     * covtype.20k, using column 54 (converted to categorical) as the response. It then asserts
     * that the constant at position 4 of {@code HistogramType.values()} gives the lowest final
     * validation logloss.
     *
     * <p>Each model is deleted as soon as its logloss has been read. The reference is then cleared
     * so the {@code finally} cleanup never deletes the same model twice.
     */
    @Test
    public void histoTypes() {
        Frame frame = null;
        Key[] keyArr = null;
        DRFModel dRFModel = null;
        Scope.enter();
        try {
            frame = parse_test_file("smalldata/covtype/covtype.20k.data");
            // Register the vec returned by replace() with the current Scope so Scope.exit()
            // cleans it up (presumably the displaced numeric response vec).
            Scope.track(frame.replace(54, frame.vecs()[54].toCategoricalVec()));
            DKV.put(frame);
            SplitFrame splitFrame = new SplitFrame(frame, new double[]{0.5d, 0.5d},
                    new Key[]{Key.make("train.hex"), Key.make("valid.hex")});
            splitFrame.exec().get();
            keyArr = splitFrame._destination_frames;

            SharedTreeModel.SharedTreeParameters.HistogramType[] histoTypes =
                    SharedTreeModel.SharedTreeParameters.HistogramType.values();
            double[] validLoglosses = new double[histoTypes.length];
            for (int i = 0; i < histoTypes.length; i++) {
                DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
                parms._train = keyArr[0];
                parms._valid = keyArr[1];
                parms._response_column = frame.names()[54];
                parms._histogram_type = histoTypes[i];
                parms._ntrees = 10;
                // Score only once, after the last tree.
                parms._score_tree_interval = parms._ntrees;
                parms._max_depth = 10;
                parms._seed = 12345L;
                parms._nbins = 20;
                parms._nbins_top_level = 20;
                dRFModel = (DRFModel) new DRF(parms).trainModel().get();
                // Validation logloss from the last scoring-history entry.
                validLoglosses[i] = dRFModel._output._scored_valid[dRFModel._output._scored_valid.length - 1]._logloss;
                dRFModel.delete();
                // Already deleted: clear so neither a later failure nor finally deletes it again.
                dRFModel = null;
            }
            for (int i = 0; i < histoTypes.length; i++) {
                Log.info(new Object[]{"histoType: " + histoTypes[i] + " -> validation logloss: " + validLoglosses[i]});
            }
            int minIndex = ArrayUtils.minIndex(validLoglosses);
            Log.info(new Object[]{"Optimal randomization: " + histoTypes[minIndex]});
            // NOTE(review): the expectation is positional (index 4 of HistogramType.values()).
            // Reordering or inserting enum constants silently changes which type is expected.
            Assert.assertEquals("index of the best-scoring HistogramType", 4, minIndex);
        } finally {
            if (dRFModel != null) {
                dRFModel.delete();
            }
            if (frame != null) {
                frame.delete();
            }
            // keyArr is still null if parsing or splitting failed; guard so cleanup cannot mask
            // the original exception with an NPE.
            if (keyArr != null) {
                for (Key key : keyArr) {
                    if (key != null) {
                        key.remove();
                    }
                }
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Smoke test: trains a DRF model with per-class row sampling rates on the training half of a
     * 50/50 split of covtype.20k, using column 54 (converted to categorical) as the response. It
     * makes no assertions; it only checks that training completes without throwing.
     *
     * <p>The trained model is deleted exactly once, in {@code finally}.
     */
    @Test
    public void sampleRatePerClass() {
        Frame frame = null;
        Key[] keyArr = null;
        DRFModel dRFModel = null;
        Scope.enter();
        try {
            frame = parse_test_file("smalldata/covtype/covtype.20k.data");
            // Register the vec returned by replace() with the current Scope so Scope.exit()
            // cleans it up (presumably the displaced numeric response vec).
            Scope.track(frame.replace(54, frame.vecs()[54].toCategoricalVec()));
            DKV.put(frame);
            SplitFrame splitFrame = new SplitFrame(frame, new double[]{0.5d, 0.5d},
                    new Key[]{Key.make("train.hex"), Key.make("valid.hex")});
            splitFrame.exec().get();
            keyArr = splitFrame._destination_frames;

            DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
            parms._train = keyArr[0];
            parms._valid = keyArr[1];
            parms._response_column = frame.names()[54];
            parms._min_split_improvement = 1.0E-5d;
            parms._ntrees = 20;
            // Score only once, after the last tree.
            parms._score_tree_interval = parms._ntrees;
            parms._max_depth = 15;
            parms._seed = 1234L;
            // float literals widened to double: bit-identical to the decompiled values
            // 0.10000000149011612, 0.20000000298023224, 0.4000000059604645, 0.30000001192092896.
            parms._sample_rate_per_class = new double[]{0.1f, 0.1f, 0.2f, 0.4f, 1.0f, 0.3f, 0.2f};
            dRFModel = (DRFModel) new DRF(parms).trainModel().get();
        } finally {
            // Single deletion point; the original deleted the model twice on success.
            if (dRFModel != null) {
                dRFModel.delete();
            }
            if (frame != null) {
                frame.delete();
            }
            // keyArr is still null if parsing or splitting failed; guard so cleanup cannot mask
            // the original exception with an NPE.
            if (keyArr != null) {
                for (Key key : keyArr) {
                    if (key != null) {
                        key.remove();
                    }
                }
            }
            Scope.exit(new Key[0]);
        }
    }

    static {
        // Values for the static metric fields declared at the top of the class
        // (presumably expected metrics compared against by other tests in this class; confirm).
        _LogLoss = 0.14472835908293025d;
        _MSE = 0.041294642857142856d;
        _AUC = 1.0d;
        // Synthetic flag emitted by the compiler for classes containing `assert` statements:
        // true when assertions are disabled for DRFTest.
        $assertionsDisabled = !DRFTest.class.desiredAssertionStatus();
    }
}
