package hex.tree.xgboost;

import hex.CVModelBuilder;
import hex.ModelBuilder;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.util.GpuUtils;
import java.util.Deque;
import java.util.LinkedList;
import org.apache.log4j.Logger;
import water.Job;

/* loaded from: input_file:hex/tree/xgboost/XGBoostGPUCVModelBuilder.class */
/**
 * A {@link CVModelBuilder} that hands every model builder its own GPU id.
 *
 * <p>A pool of free GPU ids is kept in a deque. {@link #prepare(ModelBuilder)} takes an id from
 * the head of the pool and stores it in the builder's {@code _gpu_id} parameter;
 * {@link #finished(ModelBuilder)} puts that id back at the head, so the most recently released
 * GPU is the next one to be handed out.
 *
 * <p>NOTE(review): the pool is a plain {@link LinkedList} with no synchronization. This class
 * presumably relies on {@code prepare}/{@code finished} never being invoked concurrently --
 * confirm against {@code CVModelBuilder}.
 */
public class XGBoostGPUCVModelBuilder extends CVModelBuilder {
    private static final Logger LOG = Logger.getLogger(XGBoostGPUCVModelBuilder.class);

    /** GPU ids that are not currently assigned to a model builder. */
    private final Deque<Integer> availableGpus;

    /**
     * Creates the builder and fills the pool of free GPU ids.
     *
     * @param job             forwarded unchanged to the {@code CVModelBuilder} constructor
     * @param modelBuilderArr forwarded unchanged to the {@code CVModelBuilder} constructor
     * @param i               forwarded unchanged to the {@code CVModelBuilder} constructor
     *                        (presumably the parallelism level -- confirm against the superclass)
     * @param iArr            GPU ids to use; when {@code null} or empty, the ids returned by
     *                        {@code GpuUtils.allGPUs()} are used instead
     */
    public XGBoostGPUCVModelBuilder(Job job, ModelBuilder<?, ?, ?>[] modelBuilderArr, int i, int[] iArr) {
        super(job, modelBuilderArr, i);
        if (iArr == null || iArr.length == 0) {
            // No explicit selection: fall back to whatever GpuUtils reports.
            this.availableGpus = new LinkedList<>(GpuUtils.allGPUs());
        } else {
            this.availableGpus = new LinkedList<>();
            for (int gpuId : iArr) {
                this.availableGpus.add(gpuId);
            }
        }
        LOG.info("Using parallel GPU building on " + this.availableGpus.size() + " GPUs.");
    }

    /**
     * Assigns a free GPU id to the given builder by overwriting its {@code _gpu_id} parameter
     * with a single-element array.
     *
     * @param modelBuilder the builder about to be trained; must be an {@code XGBoost} instance
     * @throws ClassCastException               if {@code modelBuilder} is not an {@code XGBoost}
     * @throws java.util.NoSuchElementException if the pool holds no free GPU id
     */
    @Override // hex.CVModelBuilder
    protected void prepare(ModelBuilder<?, ?, ?> modelBuilder) {
        XGBoost xgb = (XGBoost) modelBuilder;
        if (this.availableGpus.isEmpty()) {
            // Same exception type Deque.pop() would raise, but with a message that can be diagnosed.
            throw new java.util.NoSuchElementException(
                    "No free GPU id available: every GPU in the pool is already assigned or the pool was empty.");
        }
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xgb._parms;
        parms._gpu_id = new int[]{this.availableGpus.pop()};
        LOG.info("Building " + xgb.dest() + " on GPU " + parms._gpu_id[0]);
    }

    /**
     * Returns the GPU id stored in the builder's {@code _gpu_id[0]} to the head of the pool.
     *
     * @param modelBuilder a builder that was previously passed to {@link #prepare(ModelBuilder)};
     *                     must be an {@code XGBoost} instance
     */
    @Override // hex.CVModelBuilder
    protected void finished(ModelBuilder<?, ?, ?> modelBuilder) {
        XGBoost xgb = (XGBoost) modelBuilder;
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xgb._parms;
        this.availableGpus.push(parms._gpu_id[0]);
    }
}
