package hex.tree.xgboost.matrix;

import ai.h2o.xgboost4j.java.util.BigDenseMatrix;
import hex.tree.xgboost.matrix.DenseMatrixFactory;
import hex.tree.xgboost.matrix.MatrixLoader;
import hex.tree.xgboost.matrix.SparseMatrixFactory;
import hex.tree.xgboost.task.XGBoostUploadMatrixTask;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import water.Key;

/* loaded from: input_file:hex/tree/xgboost/matrix/RemoteMatrixLoader.class */
/**
 * {@link MatrixLoader} that hands out DMatrix providers built from matrices assembled in a
 * node-local, static registry.
 *
 * <p>Matrices are registered with {@link #initSparse} / {@link #initDense}, filled piecewise via
 * {@link #sparseChunk} / {@link #denseChunk}, completed with {@link #matrixData}, and finally
 * consumed (and removed from the registry) by {@link #makeLocalTrainMatrix} /
 * {@link #makeLocalValidMatrix}. Registry keys for a model are derived with
 * {@link #trainMatrixKey} and {@link #validMatrixKey}.
 */
public class RemoteMatrixLoader extends MatrixLoader {

    /**
     * Matrices under construction, keyed by {@link #trainMatrixKey}/{@link #validMatrixKey}.
     * The map is static and therefore shared by every model on this node; its entry points are
     * presumably invoked from upload tasks running on a thread pool (TODO confirm), so a
     * concurrent map is used to keep independent registrations/removals from corrupting it.
     */
    private static final Map<String, RemoteMatrix> REGISTRY = new ConcurrentHashMap<>();

    /** Model whose train/valid matrices this loader serves. */
    private final Key<?> modelKey;

    /** A matrix being assembled in the registry plus its accompanying per-row metadata. */
    static abstract class RemoteMatrix {
        /** Per-row metadata (responses, weights, offsets, ...); set by {@link #matrixData}. */
        XGBoostUploadMatrixTask.MatrixData data;

        RemoteMatrix() {
        }

        /** Builds the DMatrix provider from the assembled matrix and {@link #data}. */
        abstract MatrixLoader.DMatrixProvider make();
    }

    /** Dense matrix preallocated from its dimensions and filled chunk by chunk. */
    static class RemoteDenseMatrix extends RemoteMatrix {
        final XGBoostUploadMatrixTask.DenseMatrixDimensions dims;
        final BigDenseMatrix matrix;

        RemoteDenseMatrix(XGBoostUploadMatrixTask.DenseMatrixDimensions denseMatrixDimensions) {
            this.dims = denseMatrixDimensions;
            this.matrix = new BigDenseMatrix(denseMatrixDimensions.rows, denseMatrixDimensions.cols);
        }

        @Override
        MatrixLoader.DMatrixProvider make() {
            return new DenseMatrixFactory.DenseDMatrixProvider(
                    data.actualRows, data.resp, data.weights, data.offsets, matrix);
        }
    }

    /** CSR sparse matrix preallocated from its dimensions and filled chunk by chunk. */
    static class RemoteSparseMatrix extends RemoteMatrix {
        final SparseMatrixDimensions dims;
        final SparseMatrix matrix;

        RemoteSparseMatrix(SparseMatrixDimensions sparseMatrixDimensions) {
            this.dims = sparseMatrixDimensions;
            this.matrix = SparseMatrixFactory.allocateCSRMatrix(sparseMatrixDimensions);
        }

        @Override
        MatrixLoader.DMatrixProvider make() {
            return SparseMatrixFactory.toDMatrix(
                    matrix, dims, data.actualRows, data.shape, data.resp, data.weights, data.offsets);
        }
    }

    /**
     * Returns the registered matrix for {@code key}, checked against the expected subtype.
     *
     * @throws IllegalStateException if nothing is registered under {@code key} or the entry is
     *     of a different matrix kind
     */
    private static <T extends RemoteMatrix> T lookup(String key, Class<T> type) {
        RemoteMatrix matrix = REGISTRY.get(key);
        if (matrix == null) {
            throw new IllegalStateException("No remote matrix registered under key '" + key + "'");
        }
        if (!type.isInstance(matrix)) {
            throw new IllegalStateException("Remote matrix under key '" + key + "' is a "
                    + matrix.getClass().getSimpleName() + ", expected " + type.getSimpleName());
        }
        return type.cast(matrix);
    }

    /**
     * Removes the matrix registered under {@code key} and builds its DMatrix provider.
     *
     * @throws IllegalStateException if nothing is registered under {@code key} or its
     *     {@link RemoteMatrix#data} was never supplied via {@link #matrixData}
     */
    private static MatrixLoader.DMatrixProvider take(String key) {
        RemoteMatrix matrix = REGISTRY.remove(key);
        if (matrix == null) {
            throw new IllegalStateException("No remote matrix registered under key '" + key + "'");
        }
        if (matrix.data == null) {
            throw new IllegalStateException("Matrix data was never provided for key '" + key + "'");
        }
        return matrix.make();
    }

    /** Registers a new, empty sparse matrix under {@code key}, replacing any previous entry. */
    public static void initSparse(String key, SparseMatrixDimensions sparseMatrixDimensions) {
        REGISTRY.put(key, new RemoteSparseMatrix(sparseMatrixDimensions));
    }

    /**
     * Copies one sparse chunk into the registered CSR matrix. Row headers are written starting
     * at the chunk's preceding-row count; values and column indices are written in lockstep
     * starting at the chunk's preceding-non-zero count.
     *
     * @throws IllegalStateException if no sparse matrix is registered under {@code key}
     */
    public static void sparseChunk(String key, XGBoostUploadMatrixTask.SparseMatrixChunk chunk) {
        RemoteSparseMatrix target = lookup(key, RemoteSparseMatrix.class);
        SparseMatrixFactory.NestedArrayPointer rowPointer =
                new SparseMatrixFactory.NestedArrayPointer(target.dims._precedingRowCounts[chunk.id]);
        SparseMatrixFactory.NestedArrayPointer valuePointer =
                new SparseMatrixFactory.NestedArrayPointer(target.dims._precedingNonZeroElementsCounts[chunk.id]);

        for (int i = 0; i < chunk.rowHeader.length; i++) {
            rowPointer.setAndIncrement(target.matrix._rowHeaders, chunk.rowHeader[i]);
        }
        for (int i = 0; i < chunk.data.length; i++) {
            // Value and its column index share the same position, so advance only after both.
            valuePointer.set(target.matrix._sparseData, chunk.data[i]);
            valuePointer.set(target.matrix._colIndices, chunk.colIndices[i]);
            valuePointer.increment();
        }
    }

    /** Registers a new, empty dense matrix under {@code key}, replacing any previous entry. */
    public static void initDense(String key, XGBoostUploadMatrixTask.DenseMatrixDimensions denseMatrixDimensions) {
        REGISTRY.put(key, new RemoteDenseMatrix(denseMatrixDimensions));
    }

    /**
     * Copies one dense chunk into the registered matrix, starting at flat index
     * {@code rowOffsets[chunk.id] * cols}.
     *
     * @throws IllegalStateException if no dense matrix is registered under {@code key}
     */
    public static void denseChunk(String key, XGBoostUploadMatrixTask.DenseMatrixChunk chunk) {
        RemoteDenseMatrix target = lookup(key, RemoteDenseMatrix.class);
        // Loop-invariant base index; widened to long before multiplying so it cannot overflow int.
        long base = (long) target.dims.rowOffsets[chunk.id] * target.dims.cols;
        for (int i = 0; i < chunk.data.length; i++) {
            target.matrix.set(base + i, chunk.data[i]);
        }
    }

    /**
     * Attaches per-row metadata to the matrix registered under {@code key}.
     *
     * @throws IllegalStateException if no matrix is registered under {@code key}
     */
    public static void matrixData(String key, XGBoostUploadMatrixTask.MatrixData matrixData) {
        lookup(key, RemoteMatrix.class).data = matrixData;
    }

    /** Drops the registry entry for {@code key}, if any. */
    public static void cleanup(String key) {
        REGISTRY.remove(key);
    }

    /** Creates a loader serving the registered matrices of model {@code key}. */
    public RemoteMatrixLoader(Key<?> key) {
        this.modelKey = key;
    }

    /**
     * Removes this model's training matrix from the registry and builds its DMatrix provider.
     *
     * @throws IllegalStateException if the training matrix was not registered or is incomplete
     */
    @Override
    public MatrixLoader.DMatrixProvider makeLocalTrainMatrix() {
        return take(trainMatrixKey(modelKey));
    }

    /** Registry key of the training matrix for model {@code key}. */
    public static String trainMatrixKey(Key<?> key) {
        return key.toString() + "_train";
    }

    /** Whether a validation matrix is currently registered for this model. */
    @Override
    public boolean hasValidationFrame() {
        return REGISTRY.containsKey(validMatrixKey(modelKey));
    }

    /**
     * Removes this model's validation matrix from the registry and builds its DMatrix provider.
     *
     * @throws IllegalStateException if the validation matrix was not registered or is incomplete
     */
    @Override
    public MatrixLoader.DMatrixProvider makeLocalValidMatrix() {
        return take(validMatrixKey(modelKey));
    }

    /** Registry key of the validation matrix for model {@code key}. */
    public static String validMatrixKey(Key<?> key) {
        return key.toString() + "_valid";
    }
}
