package hex.tree.xgboost.util;

import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.INativeLibLoader;
import ai.h2o.xgboost4j.java.NativeLibLoader;
import ai.h2o.xgboost4j.java.Rabit;
import ai.h2o.xgboost4j.java.XGBoost;
import ai.h2o.xgboost4j.java.XGBoostError;
import hex.tree.xgboost.util.NativeLibrary;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import water.DTask;
import water.H2O;
import water.H2ONode;
import water.RPC;

/**
 * Utilities for detecting which GPUs XGBoost can use, either on this node or on a remote H2O node.
 *
 * <p>A GPU id is considered usable when a one-round XGBoost training with the
 * {@code grow_gpu_hist} updater on a tiny in-memory matrix succeeds on it. Successful probes are
 * cached; failed probes are not (apart from the {@code defaultGpuIdNotValid} flag for the default
 * device). The whole probing machinery can be bypassed with the system property
 * {@code xgboost.gpu.check.enabled} (read via {@link H2O#getSysBoolProperty}).
 */
public class GpuUtils {
    private static final Logger LOG = Logger.getLogger(GpuUtils.class);

    /** GPU ids checked when {@link #hasGPU(int[])} is given {@code null}. Treat as read-only. */
    public static final int[] DEFAULT_GPU_ID = {0};

    /** Becomes true once a probe of {@link #DEFAULT_GPU_ID} failed, letting hasGPU(null) short-circuit. */
    private static volatile boolean defaultGpuIdNotValid = false;
    /** Becomes true once {@link #allGPUs()} has enumerated the local GPUs. */
    private static volatile boolean gpuSearchPerformed = false;
    /** Ids of GPUs whose probe succeeded; guarded by the {@code GpuUtils.class} monitor. */
    private static final Set<Integer> GPUS = new HashSet<>();

    /** Remote task that enumerates the GPUs usable by XGBoost on the node it executes on. */
    public static class AllGPUsTask extends DTask<AllGPUsTask> {
        private Integer[] gpuIds;

        private AllGPUsTask() {
        }

        @Override
        public void compute2() {
            this.gpuIds = GpuUtils.allGPUs().toArray(new Integer[0]);
            tryComplete();
        }
    }

    /** Remote task that checks whether the given GPU ids are usable on the node it executes on. */
    public static class HasGPUTask extends DTask<HasGPUTask> {
        private final int[] _gpu_id;
        private boolean _hasGPU;

        private HasGPUTask(int[] gpuIds) {
            this._gpu_id = gpuIds;
        }

        @Override
        public void compute2() {
            this._hasGPU = GpuUtils.hasGPU(this._gpu_id);
            tryComplete();
        }
    }

    /**
     * Reports whether the loaded XGBoost native library was compiled with GPU support.
     *
     * @return {@code false} when the loader is not a {@link NativeLibraryLoaderChain}, when the
     *     flag is absent, or when the loader cannot be obtained
     */
    static boolean isGpuSupportEnabled() {
        try {
            INativeLibLoader loader = NativeLibLoader.getLoader();
            if (loader instanceof NativeLibraryLoaderChain) {
                return ((NativeLibraryLoaderChain) loader).getLoadedLibrary()
                        .hasCompilationFlag(NativeLibrary.CompilationFlags.WITH_GPU);
            }
            return false;
        } catch (IOException e) {
            LOG.debug("Cannot determine whether the XGBoost native library supports GPUs.", e);
            return false;
        }
    }

    /** Whether GPU probing is enabled ({@code xgboost.gpu.check.enabled}, default true). */
    private static boolean gpuCheckEnabled() {
        return H2O.getSysBoolProperty("xgboost.gpu.check.enabled", true);
    }

    /**
     * Counts the GPUs usable by XGBoost on the given node.
     *
     * @param node node to query; {@link H2O#SELF} is answered locally, others via RPC
     * @return number of usable GPUs found by {@link #allGPUs(H2ONode)}
     */
    public static int numGPUs(H2ONode node) {
        return allGPUs(node).size();
    }

    /**
     * Lists the GPU ids usable by XGBoost on the given node.
     *
     * @param node node to query; {@link H2O#SELF} is answered locally, others via a blocking RPC
     * @return the usable GPU ids
     */
    public static Set<Integer> allGPUs(H2ONode node) {
        if (H2O.SELF.equals(node)) {
            return allGPUs();
        }
        AllGPUsTask task = new AllGPUsTask();
        new RPC<>(node, task).call().get();
        return new HashSet<>(Arrays.asList(task.gpuIds));
    }

    /**
     * Enumerates the local GPUs usable by XGBoost, probing ids 0, 1, 2, ... until one fails.
     * The enumeration runs once; later calls return the cached result.
     *
     * <p>Synchronized on the same monitor as the probe so that the cache is read consistently.
     *
     * @return an unmodifiable snapshot of the usable GPU ids
     */
    public static synchronized Set<Integer> allGPUs() {
        if (!gpuSearchPerformed) {
            // Probe via hasGPU_impl rather than hasGPU(int[]): when the GPU check is disabled,
            // hasGPU(int[]) answers true for every id and this loop would never terminate.
            int nextId = 0;
            while (hasGPU_impl(nextId)) {
                nextId++;
            }
            if (nextId == 0) {
                // Id 0 (the DEFAULT_GPU_ID device) failed; record it like hasGPU(int[]) would.
                defaultGpuIdNotValid = true;
            }
            gpuSearchPerformed = true;
        }
        // Return a copy: GPUS may still grow when other ids are probed later.
        return Collections.unmodifiableSet(new HashSet<>(GPUS));
    }

    /**
     * Checks whether all the given GPU ids are usable on the given node.
     *
     * @param node node to query; {@link H2O#SELF} is answered locally, others via a blocking RPC
     * @param gpuIds ids to check, or {@code null} for {@link #DEFAULT_GPU_ID}
     * @return {@code true} if every id is usable
     */
    public static boolean hasGPU(H2ONode node, int[] gpuIds) {
        boolean available;
        if (H2O.SELF.equals(node)) {
            available = hasGPU(gpuIds);
        } else {
            HasGPUTask task = new HasGPUTask(gpuIds);
            new RPC<>(node, task).call().get();
            available = task._hasGPU;
        }
        LOG.debug("Availability of GPU (id=" + Arrays.toString(gpuIds) + ") on node " + node + ": " + available);
        return available;
    }

    /**
     * Checks whether all the given GPU ids are usable locally.
     *
     * @param gpuIds ids to check, or {@code null} for {@link #DEFAULT_GPU_ID}
     * @return {@code true} if every id is usable, or unconditionally when the GPU check is disabled
     */
    public static boolean hasGPU(int[] gpuIds) {
        if (!gpuCheckEnabled()) {
            return true;
        }
        if (gpuIds == null && defaultGpuIdNotValid) {
            return false;
        }
        int[] ids = gpuIds == null ? DEFAULT_GPU_ID : gpuIds;
        boolean available = true;
        for (int i = 0; available && i < ids.length; i++) {
            available = hasGPU_impl(ids[i]);
        }
        if (!available && Arrays.equals(ids, DEFAULT_GPU_ID)) {
            defaultGpuIdNotValid = true;
        }
        return available;
    }

    /** Checks whether the default GPU ({@link #DEFAULT_GPU_ID}) is usable locally. */
    public static boolean hasGPU() {
        return hasGPU(null);
    }

    /**
     * Probes a single local GPU by training a one-round model on it. Successful results are
     * cached; failures are re-probed on the next call.
     *
     * @param gpuId device id to probe
     * @return {@code true} if the id is known to work or the probe training succeeded
     * @throws IllegalStateException if the probe matrix cannot be built
     */
    private static synchronized boolean hasGPU_impl(int gpuId) {
        if (!isGpuSupportEnabled()) {
            return false;
        }
        if (GPUS.contains(gpuId)) {
            return true;
        }
        DMatrix trainMat = createProbeMatrix();
        try {
            Map<String, Object> params = new HashMap<>();
            params.put("updater", "grow_gpu_hist");
            params.put("silent", 1);
            params.put("fail_on_invalid_gpu_id", true);
            params.put("gpu_id", gpuId);
            Map<String, DMatrix> watches = new HashMap<>();
            watches.put("train", trainMat);
            if (trainOnce(trainMat, params, watches)) {
                GPUS.add(gpuId);
                return true;
            }
            return false;
        } finally {
            // Free the native matrix handle eagerly instead of leaving it to finalization.
            trainMat.dispose();
        }
    }

    /**
     * Builds the 2x2 labelled matrix used by the GPU probe.
     *
     * @throws IllegalStateException if XGBoost cannot create or label the matrix
     */
    private static DMatrix createProbeMatrix() {
        DMatrix matrix;
        try {
            matrix = new DMatrix(new float[]{1.0f, 2.0f, 1.0f, 2.0f}, 2, 2);
        } catch (XGBoostError e) {
            throw new IllegalStateException("Couldn't prepare training matrix for XGBoost.", e);
        }
        try {
            matrix.setLabel(new float[]{1.0f, 0.0f});
            return matrix;
        } catch (XGBoostError e) {
            matrix.dispose();
            throw new IllegalStateException("Couldn't prepare training matrix for XGBoost.", e);
        }
    }

    /**
     * Runs one boosting round inside a thread-local Rabit context, shutting Rabit down exactly once.
     *
     * @return {@code true} if training completed, {@code false} if XGBoost reported an error
     */
    private static boolean trainOnce(DMatrix trainMat, Map<String, Object> params,
                                     Map<String, DMatrix> watches) {
        try {
            Map<String, String> rabitEnv = new HashMap<>();
            Rabit.init(rabitEnv);
            // The booster is only a probe artifact; release its native handle right away.
            XGBoost.train(trainMat, params, 1, watches, null, null).dispose();
            return true;
        } catch (XGBoostError e) {
            LOG.debug("XGBoost GPU probe training failed for params " + params, e);
            return false;
        } finally {
            try {
                Rabit.shutdown();
            } catch (XGBoostError e) {
                LOG.warn("Cannot shutdown XGBoost Rabit for current thread.");
            }
        }
    }
}
