package com.linkedin.dagli.xgboost;

import com.linkedin.dagli.math.vector.DenseVector;
import java.util.Collections;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Static helpers for running single-example inference with an XGBoost {@link Booster}.
 *
 * <p>A single process-wide {@link ReentrantReadWriteLock} coordinates the two operations performed here:
 * {@link Booster#setParam} (in {@code configureBooster}) runs under the write lock, while predictions (in
 * {@link #predictAsFloats}) run under the read lock. Consequently, for calls made through this class, a booster is
 * never reconfigured while a prediction is in progress, but multiple predictions may run concurrently.</p>
 *
 * <p>The class has no instance state; all functionality is exposed through static members.</p>
 */
public abstract class XGBoostModel {
    /**
     * Guards booster configuration (write lock) against concurrent prediction (read lock).
     */
    private static final ReentrantReadWriteLock LOCK = new ReentrantReadWriteLock();

    /**
     * Per-thread flag recording whether {@code configureBooster} has already successfully set {@code nthread=1}
     * from the current thread. Starts out {@code false} for every thread.
     */
    static final ThreadLocal<Boolean> IS_THREAD_CONFIGURED_FOR_SINGLE_THREADED_PREDICTION =
        ThreadLocal.withInitial(() -> false);

    XGBoostModel() {
    }

    /**
     * Sets {@code nthread=1} on {@code booster}, at most once per calling thread.
     *
     * <p>Whether configuration has already happened is tracked per thread via
     * {@link #IS_THREAD_CONFIGURED_FOR_SINGLE_THREADED_PREDICTION}, not per booster: once a thread has configured any
     * booster, later calls on that thread return immediately, even when passed a different booster.
     * NOTE(review): presumably intentional (e.g. if XGBoost's thread-count setting effectively applies to the calling
     * thread rather than the booster) — confirm, since otherwise a second booster used on the same thread would keep
     * its original thread count.</p>
     *
     * @param booster the booster to configure
     * @throws RuntimeException wrapping the {@link XGBoostError} if XGBoost rejects the parameter; the thread is then
     *                          left unmarked, so the next call will try again
     */
    private static void configureBooster(Booster booster) {
        if (IS_THREAD_CONFIGURED_FOR_SINGLE_THREADED_PREDICTION.get()) {
            return;
        }

        // Exclusive lock: no prediction through this class may be running while the booster's params change.
        LOCK.writeLock().lock();
        try {
            booster.setParam("nthread", 1);
            IS_THREAD_CONFIGURED_FOR_SINGLE_THREADED_PREDICTION.set(true);
        } catch (XGBoostError e) {
            throw new RuntimeException(
                "XGBoost threw an exception while configuring the booster for single-threaded prediction", e);
        } finally {
            LOCK.writeLock().unlock();
        }
    }

    /**
     * Runs inference on a single dense example.
     *
     * <p>The booster is first configured for single-threaded prediction on the calling thread (see
     * {@code configureBooster}). The example is then wrapped in a one-row {@link DMatrix}, which is passed to
     * {@code predictAsFloatsMethod} while the shared read lock is held. The matrix is always disposed, and the lock
     * always released, before this method returns or throws.</p>
     *
     * @param booster the booster to predict with
     * @param denseVector the example's feature values
     * @param predictAsFloatsMethod the prediction strategy applied to {@code booster} and the one-row matrix
     * @return whatever {@code predictAsFloatsMethod} returns for the example
     * @throws RuntimeException if XGBoost throws an {@link XGBoostError} while building the matrix or predicting
     */
    public static float[] predictAsFloats(Booster booster, DenseVector denseVector,
        PredictAsFloatsMethod predictAsFloatsMethod) {
        configureBooster(booster);

        LOCK.readLock().lock();
        DMatrix dMatrix = null;
        try {
            // The 0.0f label is a placeholder; presumably it is ignored at prediction time (TODO confirm).
            dMatrix = new DMatrix(
                Collections.singleton(AbstractXGBoostModel.makeDenseLabeledPoint(null, 0.0f, denseVector)).iterator(),
                (String) null);
            return predictAsFloatsMethod.predictAsFloats(booster, dMatrix);
        } catch (XGBoostError e) {
            throw new RuntimeException("XGBoost threw an exception during inference", e);
        } finally {
            // Nested finally: the read lock is released even if dispose() itself throws, and dispose() runs at most
            // once (dMatrix stays null if its construction failed).
            try {
                if (dMatrix != null) {
                    dMatrix.dispose();
                }
            } finally {
                LOCK.readLock().unlock();
            }
        }
    }
}
