package hex.tree.xgboost;

import hex.CVModelBuilder;
import hex.ModelBuilder;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.util.GpuUtils;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.log4j.Logger;
import water.Job;

/* loaded from: input_file:hex/tree/xgboost/XGBoostGPUCVModelBuilder.class */
/**
 * Cross-validation model builder that spreads XGBoost CV models over GPUs.
 *
 * <p>Each model handed to {@link #prepare(ModelBuilder)} is assigned the GPU that currently has
 * the fewest models assigned to it; {@link #finished(ModelBuilder)} gives that assignment back.
 */
public class XGBoostGPUCVModelBuilder extends CVModelBuilder {
    private static final Logger LOG = Logger.getLogger(XGBoostGPUCVModelBuilder.class);

    /** Bookkeeping of how many models are currently assigned to each usable GPU. */
    private final GPUAllocator _allocator;

    /**
     * Creates the builder and determines the pool of GPUs the CV models may use.
     *
     * @param job             passed through unchanged to {@code CVModelBuilder}
     * @param modelBuilders   passed through unchanged to {@code CVModelBuilder}
     * @param parallelization passed through unchanged to {@code CVModelBuilder}
     *                        (presumably the number of CV models built concurrently - confirm
     *                        against CVModelBuilder)
     * @param gpuIds          ids of the GPUs to use; when {@code null} or empty, every GPU
     *                        reported by {@code GpuUtils.allGPUs()} is used instead
     * @throws IllegalStateException    if the resulting GPU pool is empty
     * @throws IllegalArgumentException if the resulting GPU pool contains a negative id
     */
    public XGBoostGPUCVModelBuilder(Job<?> job, ModelBuilder<?, ?, ?>[] modelBuilders, int parallelization, int[] gpuIds) {
        super(job, modelBuilders, parallelization);
        final List<Integer> availableGpus;
        if (gpuIds != null && gpuIds.length > 0) {
            availableGpus = new LinkedList<>();
            for (int gpuId : gpuIds) {
                availableGpus.add(gpuId);
            }
        } else {
            // No explicit selection: fall back to whatever GpuUtils reports.
            availableGpus = new LinkedList<>(GpuUtils.allGPUs());
        }
        LOG.info("Available #GPUs for CV model training: " + availableGpus.size());
        this._allocator = new GPUAllocator(availableGpus);
    }

    /**
     * Assigns the least-utilized GPU to the given model builder by overwriting its
     * {@code _gpu_id} parameter with a single-element array.
     *
     * @param modelBuilder builder about to be run; must be an {@code XGBoost} instance
     */
    protected void prepare(ModelBuilder<?, ?, ?> modelBuilder) {
        XGBoost xgb = (XGBoost) modelBuilder;
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xgb._parms;
        parms._gpu_id = new int[]{this._allocator.takeLeastUtilizedGPU()};
        LOG.info("Building " + xgb.dest() + " on GPU " + parms._gpu_id[0]);
    }

    /**
     * Releases the GPU recorded in the first element of the builder's {@code _gpu_id} parameter.
     *
     * @param modelBuilder builder that has finished; must be an {@code XGBoost} instance whose
     *                     {@code _gpu_id} was set by {@link #prepare(ModelBuilder)}
     */
    protected void finished(ModelBuilder<?, ?, ?> modelBuilder) {
        XGBoost xgb = (XGBoost) modelBuilder;
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xgb._parms;
        this._allocator.releaseGPU(parms._gpu_id[0]);
    }

    /**
     * Per-GPU usage counter indexed by GPU id.
     *
     * <p>A slot holding {@link #UNAVAILABLE} marks a GPU id that must never be handed out; any
     * other value is the number of outstanding assignments for that GPU.
     *
     * <p>NOTE(review): the methods are not synchronized; this assumes all calls come from a
     * single thread - confirm against the caller.
     */
    static class GPUAllocator {
        /** Sentinel stored for GPU ids that are not part of the usable pool. */
        private static final int UNAVAILABLE = -1;

        final int[] _gpu_utilization;

        /**
         * @param gpuIds ids of the GPUs that may be handed out
         * @throws IllegalStateException    if {@code gpuIds} is empty
         * @throws IllegalArgumentException if {@code gpuIds} contains a negative id
         */
        GPUAllocator(List<Integer> gpuIds) {
            this(initUtilization(gpuIds));
        }

        /**
         * @param gpuUtilization initial utilization table, used as-is (not copied)
         */
        GPUAllocator(int[] gpuUtilization) {
            this._gpu_utilization = gpuUtilization;
        }

        /**
         * Builds a utilization table large enough to be indexed by the highest GPU id.
         *
         * @param gpuIds ids of the usable GPUs
         * @return array of length {@code max(gpuIds) + 1}: 0 for every listed id,
         *         {@link #UNAVAILABLE} for every other index
         * @throws IllegalStateException    if {@code gpuIds} is empty
         * @throws IllegalArgumentException if {@code gpuIds} contains a negative id
         */
        static int[] initUtilization(List<Integer> gpuIds) {
            final int maxGpuId = gpuIds.stream()
                .max(Integer::compareTo)
                .orElseThrow(() -> new IllegalStateException("There are no GPUs available for XGBoost (" + gpuIds + ")."));
            // Reject negative ids up front; otherwise they surface as an opaque
            // ArrayIndexOutOfBoundsException / NegativeArraySizeException below.
            for (int gpuId : gpuIds) {
                if (gpuId < 0) {
                    throw new IllegalArgumentException("GPU ids must be non-negative (" + gpuIds + ").");
                }
            }
            final int[] utilization = new int[maxGpuId + 1];
            Arrays.fill(utilization, UNAVAILABLE);
            for (int gpuId : gpuIds) {
                utilization[gpuId] = 0;
            }
            return utilization;
        }

        /**
         * Gives back one assignment previously obtained from {@link #takeLeastUtilizedGPU()}.
         *
         * @param id GPU id to release
         */
        void releaseGPU(int id) {
            // Decrementing a count of 0 would turn the slot into the UNAVAILABLE sentinel and
            // silently remove the GPU from the pool, so unbalanced releases are flagged.
            assert id >= 0 && id < this._gpu_utilization.length && this._gpu_utilization[id] > 0
                : "GPU " + id + " is not currently taken";
            this._gpu_utilization[id]--;
        }

        /**
         * Picks the available GPU with the lowest utilization and increments its counter.
         * Ties are resolved in favor of the lowest GPU id.
         *
         * @return id of the chosen GPU
         */
        int takeLeastUtilizedGPU() {
            int best = -1; // -1 = no candidate found yet
            for (int id = 0; id < this._gpu_utilization.length; id++) {
                if (this._gpu_utilization[id] == UNAVAILABLE) {
                    continue;
                }
                if (best == -1 || this._gpu_utilization[id] < this._gpu_utilization[best]) {
                    best = id;
                }
            }
            assert best != -1;
            this._gpu_utilization[best]++;
            return best;
        }
    }
}
