package hex.tree.xgboost;

import hex.DataInfo;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.security.SecureRandom;
import java.text.NumberFormat;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.junit.After;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Key;
import water.MRTask;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.util.VecUtils;

@Ignore("Parent for XGBoostUtilsTest, no actual tests here")
/* loaded from: input_file:hex/tree/xgboost/XGBoostUtilsTest.class */
public class XGBoostUtilsTest extends TestUtil {
    /**
     * Value of {@code XGBoostUtils.SPARSE_MATRIX_DIM} captured when this class is initialized, before any
     * test overrides it (presumably what the revert-to-default helper restores — confirm).
     */
    protected static final int DEFAULT_SPARSE_MATRIX_SIZE = XGBoostUtils.SPARSE_MATRIX_DIM;
    /** Largest array length the tests use as "unbounded": {@code Integer.MAX_VALUE - 10} (= 2147483637). */
    protected static final int MAX_ARR_SIZE = Integer.MAX_VALUE - 10;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoostUtilsTest$CsrLayout.class */
    public static class CsrLayout {
        int _numRegRows;
        int _regRowLen;
        int _lastRowLen;

        private CsrLayout() {
        }

        /* JADX WARN: Type inference failed for: r0v3, types: [long[], long[][]] */
        long[][] allocateLong() {
            ?? r0 = new long[this._numRegRows + 1];
            for (int i = 0; i < this._numRegRows; i++) {
                r0[i] = new long[this._regRowLen];
            }
            r0[r0.length - 1] = new long[this._lastRowLen];
            return r0;
        }

        /* JADX WARN: Type inference failed for: r0v3, types: [int[], int[][]] */
        int[][] allocateInt() {
            ?? r0 = new int[this._numRegRows + 1];
            for (int i = 0; i < this._numRegRows; i++) {
                r0[i] = new int[this._regRowLen];
            }
            r0[r0.length - 1] = new int[this._lastRowLen];
            return r0;
        }

        /* JADX WARN: Type inference failed for: r0v3, types: [float[], float[][]] */
        float[][] allocateFloat() {
            ?? r0 = new float[this._numRegRows + 1];
            for (int i = 0; i < this._numRegRows; i++) {
                r0[i] = new float[this._regRowLen];
            }
            r0[r0.length - 1] = new float[this._lastRowLen];
            return r0;
        }
    }

    /**
     * Immutable pair of a native xgboost4j {@link DMatrix} and an H2O {@link Frame}. The tests build both
     * from the same identity-matrix data (presumably — the factory is not visible here).
     */
    public static class Matrices {
        /** Native xgboost4j representation. */
        private final DMatrix _dmatrix;
        /** H2O representation. */
        private final Frame _h2oFrame;

        public Matrices(DMatrix dmatrix, Frame h2oFrame) {
            _dmatrix = dmatrix;
            _h2oFrame = h2oFrame;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoostUtilsTest$XGBSparseMatrixDimTask.class */
    public static final class XGBSparseMatrixDimTask extends MRTask<XGBSparseMatrixDimTask> {
        private final int _maxMatrixDimension;

        private XGBSparseMatrixDimTask(int i) {
            if (i < 1) {
                throw new IllegalArgumentException("Max matrix dimension must be greater than 0.");
            }
            this._maxMatrixDimension = i;
        }

        protected void setupLocal() {
            XGBoostUtils.SPARSE_MATRIX_DIM = this._maxMatrixDimension;
            Assert.assertEquals(XGBoostUtils.SPARSE_MATRIX_DIM, this._maxMatrixDimension);
        }
    }

    /**
     * Parameterized check that {@code XGBoostUtils.allocateCSRMatrix} splits the sparse-data buffer into
     * the expected number of rows and row lengths for a given {@code SPARSE_MATRIX_DIM}.
     */
    @RunWith(Parameterized.class)
    public static final class XGBoostSparseMatrixAllocationTest extends XGBoostUtilsTest {

        /** Number of non-zero values to allocate room for. */
        @Parameterized.Parameter(0)
        public int nonZeroElementsCount;

        /** Number of row-header entries passed to the dimensions object. */
        @Parameterized.Parameter(1)
        public int rowIndicesCount;

        /** Not read by {@link #testAllocateCSR()}; kept so the parameter indices match {@link #data()}. */
        @Parameterized.Parameter(2)
        public int sparseDataMatrixNumRows;

        /** Expected number of rows in {@code _sparseData}. */
        @Parameterized.Parameter(3)
        public int arrNumRows;

        /** Expected length of every row except the last. */
        @Parameterized.Parameter(4)
        public int arrNumCols;

        /** Expected length of the last row (checked only when there is at least one row). */
        @Parameterized.Parameter(5)
        public int arrNumColsLastRow;

        /** Value installed as {@code SPARSE_MATRIX_DIM} before allocating. */
        @Parameterized.Parameter(6)
        public int sparseMatrixDimensions;

        @Parameterized.Parameters
        public static Collection<Object[]> data() {
            return Arrays.asList(
                    new Object[]{9, 3, 3, 3, 3, 3, 3},
                    new Object[]{9, 3, 3, 5, 2, 1, 2},
                    new Object[]{0, 2, 1, 0, 1, 1, 3});
        }

        @Test
        public void testAllocateCSR() {
            XGBoostUtilsTest.setSparseMatrixMaxDimensions(sparseMatrixDimensions);
            XGBoostUtils.SparseMatrixDimensions dims =
                    new XGBoostUtils.SparseMatrixDimensions(nonZeroElementsCount, rowIndicesCount);
            float[][] data = XGBoostUtils.allocateCSRMatrix(dims)._sparseData;

            Assert.assertEquals(arrNumRows, data.length);
            if (data.length == 0) {
                return; // nothing allocated, so there are no row lengths to check
            }
            int last = data.length - 1;
            for (int row = 0; row < last; row++) {
                Assert.assertEquals(arrNumCols, data[row].length);
            }
            Assert.assertEquals(arrNumColsLastRow, data[last].length);
        }
    }

    @RunWith(Parameterized.class)
    /* loaded from: input_file:hex/tree/xgboost/XGBoostUtilsTest$XGBoostSparseMatrixTest.class */
    public static final class XGBoostSparseMatrixTest extends XGBoostUtilsTest {

        @Parameterized.Parameter(0)
        public int matrixDimension;

        @Parameterized.Parameter(1)
        public int maxArrayLen;

        @Parameterized.Parameter(2)
        public int maxNativeArrayLen;

        @Parameterized.Parameters
        public static Collection<Object[]> data() {
            return Arrays.asList(new Object[]{30, 10, Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE)}, new Object[]{30, 10, 10}, new Object[]{30, Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE), Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE)}, new Object[]{300, 10, Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE)}, new Object[]{300, 10, 10}, new Object[]{300, Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE), Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE)}, new Object[]{1000, 10, Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE)}, new Object[]{1000, 10, 10}, new Object[]{1000, Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE), Integer.valueOf(XGBoostUtilsTest.MAX_ARR_SIZE)});
        }

        @Test
        public void testCSRPredictions_compare_with_native() throws XGBoostError {
            Frame frame = null;
            XGBoostModel xGBoostModel = null;
            Booster booster = null;
            Frame frame2 = null;
            try {
                HashMap hashMap = new HashMap();
                hashMap.put("DMLC_TASK_ID", "0");
                Rabit.init(hashMap);
                Matrices createIdentityMatrices = createIdentityMatrices(this.matrixDimension, XGBoostUtilsTest.MAX_ARR_SIZE);
                frame = createIdentityMatrices._h2oFrame;
                final DMatrix dMatrix = createIdentityMatrices._dmatrix;
                float[] createRandomLabelCol = createRandomLabelCol(this.matrixDimension);
                dMatrix.setLabel(createRandomLabelCol);
                attachLabelToFrame(createIdentityMatrices._h2oFrame, createRandomLabelCol);
                booster = XGBoost.train(dMatrix, new HashMap<String, Object>() { // from class: hex.tree.xgboost.XGBoostUtilsTest.XGBoostSparseMatrixTest.1
                    {
                        put("objective", "reg:linear");
                        put("eta", Double.valueOf(1.0d));
                        put("max_depth", 16);
                        put("ntrees", 5);
                        put("colsample_bytree", Double.valueOf(1.0d));
                        put("tree_method", "exact");
                        put("backend", "cpu");
                        put("booster", "gbtree");
                        put("lambda", Double.valueOf(1.0d));
                        put("grow_policy", "depthwise");
                        put("nthread", 12);
                        put("subsample", Double.valueOf(1.0d));
                        put("colsample_bylevel", Double.valueOf(1.0d));
                        put("max_delta_step", Double.valueOf(0.0d));
                        put("min_child_weight", Double.valueOf(1.0d));
                        put("gamma", Double.valueOf(0.0d));
                        put("seed", 1);
                    }
                }, 5, new HashMap<String, DMatrix>() { // from class: hex.tree.xgboost.XGBoostUtilsTest.XGBoostSparseMatrixTest.2
                    {
                        put("train", dMatrix);
                    }
                }, (IObjective) null, (IEvaluation) null);
                Assert.assertNotNull(booster);
                float[][] predict = booster.predict(dMatrix);
                Assert.assertNotNull(predict);
                XGBoostUtilsTest.setSparseMatrixMaxDimensions(this.maxArrayLen);
                XGBoostModel.XGBoostParameters xGBoostParameters = new XGBoostModel.XGBoostParameters();
                xGBoostParameters._ntrees = 5;
                xGBoostParameters._eta = 1.0d;
                xGBoostParameters._max_depth = 16;
                xGBoostParameters._stopping_rounds = 5;
                xGBoostParameters._train = createIdentityMatrices._h2oFrame._key;
                xGBoostParameters._response_column = "response";
                xGBoostParameters._backend = XGBoostModel.XGBoostParameters.Backend.cpu;
                xGBoostParameters._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.exact;
                xGBoostParameters._seed = 1L;
                xGBoostModel = (XGBoostModel) new XGBoost(xGBoostParameters).trainModel().get();
                Assert.assertNotNull(xGBoostModel);
                frame2 = xGBoostModel.score(frame);
                XGBoostUtilsTest.comparePreds(predict, frame2.vec("predict"), 1.0E-6f);
                Rabit.shutdown();
                if (frame != null) {
                    frame.delete();
                }
                if (xGBoostModel != null) {
                    xGBoostModel.delete();
                }
                if (booster != null) {
                    booster.dispose();
                }
                if (frame2 != null) {
                    frame2.delete();
                }
            } catch (Throwable th) {
                Rabit.shutdown();
                if (frame != null) {
                    frame.delete();
                }
                if (xGBoostModel != null) {
                    xGBoostModel.delete();
                }
                if (booster != null) {
                    booster.dispose();
                }
                if (frame2 != null) {
                    frame2.delete();
                }
                throw th;
            }
        }
    }

    /**
     * Non-parameterized tests for {@link XGBoostUtils}:
     * <ul>
     *   <li>feature-score parsing;</li>
     *   <li>CSR vs. native prediction parity on real datasets;</li>
     *   <li>CSR matrix sizing and initialization on tiny hand-built frames.</li>
     * </ul>
     */
    public static final class XGBoostUtilsTestSingleRun extends XGBoostUtilsTest {

        /**
         * Parses a model dump and checks each feature's gain, cover and frequency against the reference
         * values in {@code xgbvarimps.txt}. Each value is normalized by its total over all features.
         */
        @Test
        public void parseFeatureScores() throws IOException, ParseException {
            String[] dumpLines = XGBoostUtilsTest.readLines(getClass().getResource("xgbdump.txt"));
            String[] expectedLines = XGBoostUtilsTest.readLines(getClass().getResource("xgbvarimps.txt"));
            Map<String, XGBoostUtils.FeatureScore> scores = XGBoostUtils.parseFeatureScores(dumpLines);

            double totalGain = 0.0d;
            double totalCover = 0.0d;
            double totalFrequency = 0.0d;
            for (XGBoostUtils.FeatureScore score : scores.values()) {
                totalGain += score._gain;
                totalCover += score._cover;
                totalFrequency += score._frequency;
            }

            // Fixed locale so parsing does not depend on the JVM default locale.
            NumberFormat numberFormat = NumberFormat.getInstance(Locale.US);
            for (String line : expectedLines) {
                // Reference line layout: "<feature> <gain> <cover> <frequency>"
                String[] fields = line.split(" ");
                String feature = fields[0];
                XGBoostUtils.FeatureScore score = scores.get(feature);
                Assert.assertNotNull("Score " + feature + " should be calculated", score);
                Assert.assertEquals("Gain of " + feature,
                        numberFormat.parse(fields[1]).floatValue(), score._gain / totalGain, 1.0E-6d);
                Assert.assertEquals("Cover of " + feature,
                        numberFormat.parse(fields[2]).floatValue(), score._cover / totalCover, 1.0E-6d);
                Assert.assertEquals("Frequency of " + feature,
                        numberFormat.parse(fields[3]).floatValue(), score._frequency / totalFrequency, 1.0E-6d);
            }
        }

        @Test
        public void testCSRPredictionComparison_cars() {
            runCSRPredictionComparison("smalldata/junit/cars.csv", "cylinders", "smalldata/testng/cars_test.csv");
        }

        @Test
        public void testCSRPredictionComparison_airlines() {
            runCSRPredictionComparison("smalldata/testng/airlines.csv", "IsDepDelayed",
                    "smalldata/testng/airlines_test.csv");
        }

        @Test
        public void testCSRPredictionComparison_airQuality() {
            runCSRPredictionComparison("smalldata/testng/airquality_train1.csv", "Ozone",
                    "smalldata/testng/airquality_validation1.csv");
        }

        @Test
        public void testCSRPredictionComparison_prostate() {
            runCSRPredictionComparison("smalldata/testng/prostate_train.csv", "GLEASON",
                    "smalldata/testng/prostate_test.csv");
        }

        /** Converting a 3-row frame yields a 3-row DMatrix whose labels come from column C3. */
        @Test
        public void testSparsematrixNumLines() throws XGBoostError {
            Frame frame = buildNumericTestFrame(
                    new double[]{0.0d, 1.0d, 0.0d},
                    new double[]{0.0d, 2.0d, 0.0d},
                    new double[]{0.0d, 3.0d, 0.0d});
            DMatrix matrix = null;
            try {
                DataInfo dataInfo = new DataInfo(frame, (Frame) null, true, DataInfo.TransformType.NONE,
                        false, false, false);
                matrix = XGBoostUtils.convertFrameToDMatrix(dataInfo, frame, "C3", (String) null, true);
                Assert.assertNotNull(matrix);
                Assert.assertEquals(3L, matrix.rowNum());
                Assert.assertArrayEquals(new float[]{0.0f, 3.0f, 0.0f}, matrix.getLabel(), 0.0f);
            } finally {
                if (matrix != null) {
                    matrix.dispose(); // release native memory held by the DMatrix
                }
                frame.remove();
            }
        }

        /** All-zero rows must still get row headers (repeated offsets) in the CSR layout. */
        @Test
        public void testSparsematrixInit_emptyRowHandling() throws XGBoostError {
            checkCSRInitialization(
                    buildNumericTestFrame(
                            new double[]{0.0d, 1.0d, 0.0d},
                            new double[]{0.0d, 2.0d, 0.0d},
                            new double[]{0.0d, 3.0d, 0.0d}),
                    3, 3L,
                    new float[]{1.0f, 2.0f, 3.0f},
                    new long[]{0, 0, 3, 3},
                    new int[]{0, 1, 2});
        }

        @Test
        public void testSparsematrixInit_identity() throws XGBoostError {
            checkCSRInitialization(
                    buildNumericTestFrame(
                            new double[]{1.0d, 0.0d, 0.0d},
                            new double[]{0.0d, 1.0d, 0.0d},
                            new double[]{0.0d, 0.0d, 1.0d}),
                    3, 3L,
                    new float[]{1.0f, 1.0f, 1.0f},
                    new long[]{0, 1, 2, 3},
                    new int[]{0, 1, 2});
        }

        /** Uses SPARSE_MATRIX_DIM = 1, so every value lands in its own row of the jagged arrays. */
        @Test
        public void testSparsematrixInit_dimensions_test() throws XGBoostError {
            checkCSRInitialization(
                    buildNumericTestFrame(
                            new double[]{10.0d, 0.0d, 0.0d},
                            new double[]{0.0d, 20.0d, 0.0d},
                            new double[]{0.0d, 0.0d, 30.0d}),
                    1, 3L,
                    new float[]{10.0f, 20.0f, 30.0f},
                    new long[]{0, 1, 2, 3},
                    new int[]{0, 1, 2});
        }

        /** Categorical C2 is one-hot expanded and its levels occupy the first column indices. */
        @Test
        public void testSparsematrixInit_categoricals_2D() throws XGBoostError {
            Frame frame = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"C1", "C2", "C3"})
                    .withVecTypes(new byte[]{Vec.T_NUM, Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, new double[]{10.0d, 0.0d, 0.0d})
                    .withDataForCol(1, new String[]{"a", "b", "c"})
                    .withDataForCol(2, new double[]{0.0d, 0.0d, 30.0d})
                    .build();
            checkCSRInitialization(frame, 1, 5L,
                    new float[]{1.0f, 10.0f, 1.0f, 1.0f, 30.0f},
                    new long[]{0, 2, 3, 5},
                    new int[]{0, 3, 1, 2, 4});
        }

        /**
         * Parses the train/test files inside a fresh {@link Scope} and delegates to
         * {@code XGBoostUtilsTest.testCSRPredictions}. Both frames are tracked, so {@code Scope.exit()}
         * cleans them up.
         */
        private static void runCSRPredictionComparison(String trainPath, String responseColumn, String testPath) {
            Scope.enter();
            try {
                Frame train = TestUtil.parse_test_file(trainPath);
                Scope.track(train);
                Frame test = TestUtil.parse_test_file(testPath);
                Scope.track(test);
                XGBoostUtilsTest.testCSRPredictions(train, responseColumn, test);
            } finally {
                Scope.exit();
            }
        }

        /** Builds a numeric frame named "testFrame" with columns C1, C2, C3 holding the given values. */
        private static Frame buildNumericTestFrame(double[] c1, double[] c2, double[] c3) {
            return new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"C1", "C2", "C3"})
                    .withVecTypes(new byte[]{Vec.T_NUM, Vec.T_NUM, Vec.T_NUM})
                    .withDataForCol(0, c1)
                    .withDataForCol(1, c2)
                    .withDataForCol(2, c3)
                    .build();
        }

        /**
         * Runs the CSR pipeline on {@code frame} and asserts every intermediate result. The pipeline steps
         * are: dimension calculation, allocation, initialization from local chunks (C3 as response), and
         * DMatrix construction. The frame is removed afterwards.
         *
         * @param frame               frame to convert; ownership passes to this method
         * @param maxMatrixDim        value installed as {@code SPARSE_MATRIX_DIM} before sizing
         * @param expectedNonZeros    expected non-zero element count
         * @param expectedData        expected sparse values, flattened in row order
         * @param expectedRowHeaders  expected CSR row offsets, flattened ({@code nRows + 1} entries)
         * @param expectedColIndices  expected column index of each value, flattened
         */
        private static void checkCSRInitialization(Frame frame, int maxMatrixDim, long expectedNonZeros,
                                                   float[] expectedData, long[] expectedRowHeaders,
                                                   int[] expectedColIndices) throws XGBoostError {
            try {
                Vec anyVec = frame.anyVec();
                int[] chunkIds = VecUtils.getLocalChunkIds(anyVec);
                int nRows = (int) anyVec.length();
                float[] responses = new float[nRows];
                // Vec.Reader is an inner class of Vec, hence the qualified instance creation.
                Vec.Reader responseReader = frame.vec("C3").new Reader();
                DataInfo dataInfo = new DataInfo(frame, (Frame) null, true, DataInfo.TransformType.NONE,
                        false, false, false);
                Vec.Reader[] featureReaders = new Vec.Reader[frame.numCols()];
                for (int i = 0; i < featureReaders.length; i++) {
                    featureReaders[i] = frame.vec(i).new Reader();
                }

                XGBoostUtilsTest.setSparseMatrixMaxDimensions(maxMatrixDim);
                XGBoostUtils.SparseMatrixDimensions dims =
                        XGBoostUtils.calculateCSRMatrixDimensions(frame, chunkIds, (Vec) null, dataInfo);
                Assert.assertNotNull(dims);
                Assert.assertEquals(expectedNonZeros, dims._nonZeroElementsCount);
                Assert.assertEquals((long) expectedRowHeaders.length, dims._rowHeadersCount);

                XGBoostUtils.SparseMatrix matrix = XGBoostUtils.allocateCSRMatrix(dims);
                XGBoostUtilsTest.checkSparseDataStructuresAllocation(matrix, dims._nonZeroElementsCount, nRows);

                int initializedRows = XGBoostUtils.initalizeFromChunkIds(frame, chunkIds, featureReaders,
                        (Vec.Reader) null, dataInfo, matrix._rowHeaders, matrix._sparseData, matrix._colIndices,
                        responseReader, responses, (float[]) null);
                Assert.assertEquals(nRows, initializedRows);
                XGBoostUtilsTest.checkSparseDataInitialization(matrix, expectedData, expectedRowHeaders,
                        expectedColIndices);

                DMatrix dMatrix = new DMatrix(matrix._rowHeaders, matrix._colIndices, matrix._sparseData,
                        DMatrix.SparseType.CSR, dataInfo.fullN(), initializedRows + 1, dims._nonZeroElementsCount);
                try {
                    Assert.assertEquals(nRows, dMatrix.rowNum());
                } finally {
                    dMatrix.dispose(); // release native memory held by the DMatrix
                }
            } finally {
                frame.remove();
            }
        }
    }

    /** Blocks until the H2O test cloud reaches size 1, before any test in this hierarchy runs. */
    @BeforeClass
    public static void beforeClass() {
        // inherited static helper from TestUtil
        stall_till_cloudsize(1);
    }

    /**
     * Runs after every test. It undoes any change a test made to the sparse-matrix size by calling
     * {@code revertDefaultSparseMatrixMaxSize()}, so the value does not leak into later tests.
     * The helper is defined elsewhere in this class.
     */
    @After
    public void tearDown() {
        revertDefaultSparseMatrixMaxSize();
    }

    /**
     * Asserts that the jagged CSR buffers in {@code sparseMatrix} are sized for the given number of
     * non-zero elements and rows under the current {@code XGBoostUtils.SPARSE_MATRIX_DIM}.
     *
     * <p>The data and column-index buffers must have {@code ceil(nonZeroElementsCount / DIM)} rows. Every
     * row except the last holds {@code min(MAX_ARR_SIZE, nonZeroElementsCount, DIM)} entries; the last row
     * holds the remainder. The row headers must hold {@code numRows + 1} entries in total. With zero
     * non-zero elements, no data rows are expected.
     *
     * @param sparseMatrix         allocated matrix to inspect
     * @param nonZeroElementsCount number of values the data/column-index buffers must hold
     * @param numRows              number of matrix rows
     */
    public static void checkSparseDataStructuresAllocation(XGBoostUtils.SparseMatrix sparseMatrix,
                                                           long nonZeroElementsCount, int numRows) {
        Assert.assertNotNull(sparseMatrix);
        Assert.assertNotNull(sparseMatrix._colIndices);
        Assert.assertNotNull(sparseMatrix._rowHeaders);
        float[][] sparseData = sparseMatrix._sparseData;
        int[][] colIndices = sparseMatrix._colIndices;
        Assert.assertNotNull(sparseData);

        final long dim = XGBoostUtils.SPARSE_MATRIX_DIM;
        // ceil(nonZeroElementsCount / dim) rows are needed to hold every value
        long expectedRows = nonZeroElementsCount / dim + (nonZeroElementsCount % dim != 0 ? 1 : 0);
        Assert.assertEquals(expectedRows, sparseData.length);
        Assert.assertEquals(expectedRows, colIndices.length);

        // With no rows there is no "last row" to inspect (indexing it would throw).
        if (sparseData.length > 0) {
            long regularRowLen = Math.min(Math.min(MAX_ARR_SIZE, nonZeroElementsCount), dim);
            int lastRow = sparseData.length - 1;
            for (int row = 0; row < lastRow; row++) {
                Assert.assertEquals(regularRowLen, sparseData[row].length);
                Assert.assertEquals(regularRowLen, colIndices[row].length);
            }
            long remainder = nonZeroElementsCount % dim;
            long expectedLastRowLen = remainder == 0 ? regularRowLen : remainder;
            Assert.assertEquals(expectedLastRowLen, sparseData[lastRow].length);
            Assert.assertEquals(expectedLastRowLen, colIndices[lastRow].length);
        }

        long rowHeaderEntries = 0;
        for (long[] headers : sparseMatrix._rowHeaders) {
            rowHeaderEntries += headers.length;
        }
        Assert.assertEquals(numRows + 1, rowHeaderEntries);
    }

    /**
     * Asserts that the jagged CSR buffers of {@code sparseMatrix} match the expected values exactly.
     * Each buffer is compared after concatenating its rows in order. Lengths must match too, so missing
     * or extra entries fail with a clear message instead of passing silently or throwing an index error.
     *
     * @param sparseMatrix       initialized matrix to inspect
     * @param expectedData       expected non-zero values, flattened
     * @param expectedRowHeaders expected CSR row offsets, flattened
     * @param expectedColIndices expected column index of each value, flattened
     */
    public static void checkSparseDataInitialization(XGBoostUtils.SparseMatrix sparseMatrix, float[] expectedData,
                                                     long[] expectedRowHeaders, int[] expectedColIndices) {
        Assert.assertArrayEquals("sparse data", expectedData, concatRows(sparseMatrix._sparseData), 0.0f);
        Assert.assertArrayEquals("row headers", expectedRowHeaders, concatRows(sparseMatrix._rowHeaders));
        Assert.assertArrayEquals("column indices", expectedColIndices, concatRows(sparseMatrix._colIndices));
    }

    /** Concatenates the rows of a jagged {@code float} array into one array. */
    private static float[] concatRows(float[][] rows) {
        int total = 0;
        for (float[] row : rows) {
            total += row.length;
        }
        float[] flat = new float[total];
        int pos = 0;
        for (float[] row : rows) {
            System.arraycopy(row, 0, flat, pos, row.length);
            pos += row.length;
        }
        return flat;
    }

    /** Concatenates the rows of a jagged {@code long} array into one array. */
    private static long[] concatRows(long[][] rows) {
        int total = 0;
        for (long[] row : rows) {
            total += row.length;
        }
        long[] flat = new long[total];
        int pos = 0;
        for (long[] row : rows) {
            System.arraycopy(row, 0, flat, pos, row.length);
            pos += row.length;
        }
        return flat;
    }

    /** Concatenates the rows of a jagged {@code int} array into one array. */
    private static int[] concatRows(int[][] rows) {
        int total = 0;
        for (int[] row : rows) {
            total += row.length;
        }
        int[] flat = new int[total];
        int pos = 0;
        for (int[] row : rows) {
            System.arraycopy(row, 0, flat, pos, row.length);
            pos += row.length;
        }
        return flat;
    }

    /**
     * Reads every line of the text resource at {@code url}, decoding it as UTF-8.
     *
     * @param url location of the text resource
     * @return the lines in order, without line terminators; empty for an empty resource
     * @throws IOException if the resource cannot be opened or read
     */
    public static String[] readLines(URL url) throws IOException {
        ArrayList<String> lines = new ArrayList<>();
        // try-with-resources closes the reader (and the underlying stream) on every exit path;
        // an explicit charset keeps the result independent of the platform default encoding
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(url.openStream(), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        }
        return lines.toArray(new String[0]);
    }

    /**
     * Runs an {@code XGBSparseMatrixDimTask} carrying {@code i} via {@code doAllNodes()}.
     * Presumably this sets the sparse-matrix chunk dimension on every node; confirm against
     * {@code XGBSparseMatrixDimTask}.
     *
     * @param i the new maximum sparse-matrix dimension
     */
    public static void setSparseMatrixMaxDimensions(int i) {
        final XGBSparseMatrixDimTask dimTask = new XGBSparseMatrixDimTask(i);
        dimTask.doAllNodes();
    }

    /**
     * Restores the sparse-matrix dimension to {@code DEFAULT_SPARSE_MATRIX_SIZE}, the value of
     * {@code XGBoostUtils.SPARSE_MATRIX_DIM} captured when this class was initialized.
     */
    private static void revertDefaultSparseMatrixMaxSize() {
        setSparseMatrixMaxDimensions(DEFAULT_SPARSE_MATRIX_SIZE);
    }

    /**
     * Trains three XGBoost models on {@code frame} and asserts that their predictions on
     * {@code frame2} agree:
     * <ol>
     *   <li>sparse DMatrix with the sparse-matrix max dimension set to 10,</li>
     *   <li>sparse DMatrix with the max dimension set to {@code MAX_ARR_SIZE},</li>
     *   <li>dense DMatrix (max dimension still {@code MAX_ARR_SIZE}).</li>
     * </ol>
     * Models and prediction frames are tracked in a {@link Scope} that is always exited. The
     * sparse-matrix max dimension is restored to its default afterwards, even when an assertion
     * fails, so later tests do not inherit the small value.
     *
     * @param frame  training frame
     * @param str    name of the response column in {@code frame}
     * @param frame2 frame scored by each model
     */
    public static void testCSRPredictions(Frame frame, String str, Frame frame2) {
        // enter outside the try so exit() only ever runs for a scope that was actually opened
        Scope.enter();
        try {
            XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
            parms._response_column = str;
            parms._train = frame._key;
            parms._ntrees = 10;
            parms._backend = XGBoostModel.XGBoostParameters.Backend.cpu;
            parms._dmatrix_type = XGBoostModel.XGBoostParameters.DMatrixType.sparse;

            setSparseMatrixMaxDimensions(10);
            XGBoostModel smallDimModel = new XGBoost(parms).trainModel().get();
            Assert.assertNotNull(smallDimModel);
            Scope.track_generic(smallDimModel);
            Frame smallDimPreds = smallDimModel.score(frame2);
            Assert.assertNotNull(smallDimPreds);
            Scope.track(new Frame[]{smallDimPreds});

            setSparseMatrixMaxDimensions(MAX_ARR_SIZE);
            XGBoostModel maxDimModel = new XGBoost(parms).trainModel().get();
            Assert.assertNotNull(maxDimModel);
            Scope.track_generic(maxDimModel);
            Frame maxDimPreds = maxDimModel.score(frame2);
            Scope.track(new Frame[]{maxDimPreds});
            Assert.assertNotEquals(smallDimPreds, maxDimPreds);
            Assert.assertTrue(TestUtil.compareFrames(smallDimPreds, maxDimPreds));

            parms._dmatrix_type = XGBoostModel.XGBoostParameters.DMatrixType.dense;
            XGBoostModel denseModel = new XGBoost(parms).trainModel().get();
            Assert.assertNotNull(denseModel);
            Scope.track_generic(denseModel);
            Frame densePreds = denseModel.score(frame2);
            Scope.track(new Frame[]{densePreds});
            Assert.assertTrue(TestUtil.compareFrames(densePreds, maxDimPreds));
        } finally {
            try {
                // the dimension is cluster-wide state; never leave it at the test value
                revertDefaultSparseMatrixMaxSize();
            } finally {
                Scope.exit(new Key[0]);
            }
        }
    }

    /**
     * Creates {@code i} labels drawn from a standard normal distribution using an unseeded
     * {@link SecureRandom}, so every call yields different data.
     *
     * @param i number of labels
     * @return a new array of length {@code i}
     */
    protected static float[] createRandomLabelCol(int i) {
        return fillWithGaussianNoise(new float[i], new SecureRandom());
    }

    /**
     * Reproducible variant of {@link #createRandomLabelCol(int)}: the same {@code seed} always
     * yields the same labels, which makes a failing test replayable.
     *
     * @param i    number of labels
     * @param seed seed for {@link java.util.Random}
     * @return a new array of length {@code i}
     */
    protected static float[] createRandomLabelCol(int i, long seed) {
        return fillWithGaussianNoise(new float[i], new java.util.Random(seed));
    }

    /** Overwrites every slot of {@code target} with {@code rng.nextGaussian()} and returns it. */
    private static float[] fillWithGaussianNoise(float[] target, java.util.Random rng) {
        for (int k = 0; k < target.length; k++) {
            target[k] = (float) rng.nextGaussian();
        }
        return target;
    }

    /**
     * Appends {@code fArr} to {@code frame} as a new column named {@code "response"}.
     *
     * <p>The new vec's key is obtained from the vector group of an existing column
     * ({@code frame.anyVec().group().addVec()}), presumably so the label is compatible with the
     * frame's other columns.
     *
     * @param frame frame to extend in place
     * @param fArr  label values, one per row (assumed to match the frame's row count)
     */
    protected static void attachLabelToFrame(Frame frame, float[] fArr) {
        final Key<Vec> labelKey = frame.anyVec().group().addVec();
        final Vec labelVec = Vec.makeVec(fArr, labelKey);
        frame.add("response", labelVec);
    }

    /**
     * Builds the same {@code i x i} identity matrix in two forms: a CSR {@link DMatrix} whose
     * arrays are split into chunks of {@code i2} entries, and an H2O frame with {@code i} numeric
     * columns.
     *
     * @param i  matrix size (rows = columns = non-zero entries)
     * @param i2 chunk length used for the CSR arrays
     * @return both representations wrapped in a {@code Matrices}
     * @throws XGBoostError if the native DMatrix cannot be created
     */
    protected static Matrices createIdentityMatrices(int i, int i2) throws XGBoostError {
        final int size = i;
        final int chunkLen = i2;
        // CSR needs one more row header than there are rows
        long[][] rowHeaders = createLayout(size + 1, chunkLen).allocateLong();
        int[][] colIndices = createLayout(size, chunkLen).allocateInt();
        float[][] values = createLayout(size, chunkLen).allocateFloat();
        TestFrameBuilder builder = new TestFrameBuilder().withUniformVecTypes(size, (byte) 3);

        long nnz = 0;
        for (int row = 0; row < size; row++, nnz++) {
            final int chunk = (int) (nnz / chunkLen);
            final int offset = (int) (nnz % chunkLen);
            builder = builder.withDataForCol(row, genIdentityMatrixFrameCol(row, size));
            values[chunk][offset] = 1.0f;
            colIndices[chunk][offset] = row;
            rowHeaders[chunk][offset] = nnz;
        }
        // closing row header: total number of stored entries
        rowHeaders[(int) (nnz / chunkLen)][(int) (nnz % chunkLen)] = nnz;
        Assert.assertEquals(size, nnz);

        DMatrix csr = new DMatrix(rowHeaders, colIndices, values, DMatrix.SparseType.CSR, size, size + 1, size);
        return new Matrices(csr, builder.build());
    }

    /**
     * Returns column {@code i} of an identity matrix with {@code i2} rows: all zeros except a
     * {@code 1.0} at index {@code i}.
     *
     * @param i  index of the column (and of its single non-zero entry)
     * @param i2 number of rows
     * @return a new array of length {@code i2}
     */
    private static double[] genIdentityMatrixFrameCol(int i, int i2) {
        final double[] column = new double[i2];
        column[i] = 1.0;
        return column;
    }

    /**
     * Describes how {@code j} elements are split into rows (chunks) of {@code i} elements:
     * {@code j / i} full rows plus a final row holding the {@code j % i} leftover elements.
     *
     * @param j total number of elements
     * @param i length of each full row
     * @return the resulting layout
     */
    private static CsrLayout createLayout(long j, int i) {
        CsrLayout csrLayout = new CsrLayout();
        csrLayout._numRegRows = (int) (j / i);
        csrLayout._regRowLen = i;
        // computed directly in long arithmetic; the former int product
        // _numRegRows * _regRowLen overflowed once j exceeded Integer.MAX_VALUE
        csrLayout._lastRowLen = (int) (j % i);
        return csrLayout;
    }

    /**
     * Asserts that native XGBoost predictions match H2O predictions row by row.
     *
     * @param fArr native predictions; only the first entry of each row ({@code fArr[row][0]}) is compared
     * @param vec  H2O prediction vector
     * @param f    absolute tolerance for each comparison
     * @throws IllegalStateException if the two prediction sets have different lengths
     */
    public static void comparePreds(float[][] fArr, Vec vec, float f) {
        final long h2oLen = vec.length();
        if (fArr.length != h2oLen) {
            // %d, not %x: the lengths are meant to be read as decimal counts
            throw new IllegalStateException(String.format(
                    "Predictions do not have the same length. Native: %d, H2O: %d", fArr.length, h2oLen));
        }
        for (int row = 0; row < fArr.length; row++) {
            Assert.assertEquals("Prediction mismatch at row " + row, fArr[row][0], (float) vec.at(row), f);
        }
    }
}
