package ai.djl.tensorflow.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;
import ai.djl.util.RandomUtils;
import java.util.Locale;
import org.tensorflow.EagerSession;
import org.tensorflow.TensorFlow;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_DeviceList;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;

/* loaded from: input_file:ai/djl/tensorflow/engine/TfEngine.class */
/**
 * The TensorFlow implementation of the DJL {@link Engine}.
 *
 * <p>Instances are created through {@link #newInstance()}, which loads the native TensorFlow
 * library before constructing the engine.
 */
public final class TfEngine extends Engine {
    public static final String ENGINE_NAME = "TensorFlow";

    private TfEngine() {
    }

    /**
     * Loads the TensorFlow native library, initializes the default eager session and returns a
     * new engine.
     *
     * <p>NOTE(review): the decompiler reports this was package-private in the original source.
     *
     * @return a new {@code TfEngine}
     * @throws EngineException if loading the native library or creating the eager session fails
     */
    public static TfEngine newInstance() {
        try {
            LibUtils.loadLibrary();
            EagerSession.getDefault();
            return new TfEngine();
        } catch (Throwable th) {
            // Throwable (not Exception) so that native-link errors such as
            // UnsatisfiedLinkError are also wrapped with context.
            throw new EngineException("Failed to load TensorFlow native library", th);
        }
    }

    /**
     * Creates a new TensorFlow-backed model.
     *
     * @param str the model name
     * @param device the device the model is associated with
     * @return a new {@link TfModel}
     */
    public Model newModel(String str, Device device) {
        return new TfModel(str, device);
    }

    /** Returns {@value #ENGINE_NAME}. */
    public String getEngineName() {
        return ENGINE_NAME;
    }

    /** Returns the fixed rank value 3 for this engine. */
    public int getRank() {
        return 3;
    }

    /** Returns the version string reported by the loaded TensorFlow library. */
    public String getVersion() {
        return TensorFlow.version();
    }

    /**
     * Reports whether this engine supports the named capability.
     *
     * <p>{@code "MKL"} is always reported as supported. {@code "CUDA"} is supported when a
     * temporary eager context lists at least one device whose name contains "gpu"
     * (case-insensitive). Every other capability name yields {@code false}.
     *
     * @param str the capability name
     * @return whether the capability is available
     */
    public boolean hasCapability(String str) {
        if ("MKL".equals(str)) {
            return true;
        }
        if (!"CUDA".equals(str)) {
            return false;
        }
        return hasGpuDevice();
    }

    /**
     * Probes for a GPU device by creating a short-lived eager context and scanning its device
     * list. All native handles created here are released before returning, so repeated calls
     * (for example from {@link #toString()}) do not leak native memory.
     *
     * @return {@code true} if any listed device name contains "gpu"; {@code false} otherwise,
     *     including when the context or device list could not be created
     */
    private static boolean hasGpuDevice() {
        TF_Status status = tensorflow.TF_NewStatus();
        TFE_ContextOptions options = null;
        TFE_Context context = null;
        TF_DeviceList devices = null;
        try {
            options = tensorflow.TFE_NewContextOptions();
            context = tensorflow.TFE_NewContext(options, status);
            if (tensorflow.TF_GetCode(status) != tensorflow.TF_OK
                    || context == null
                    || context.isNull()) {
                // Context creation failed: CUDA availability cannot be confirmed.
                return false;
            }
            devices = tensorflow.TFE_ContextListDevices(context, status);
            if (tensorflow.TF_GetCode(status) != tensorflow.TF_OK
                    || devices == null
                    || devices.isNull()) {
                return false;
            }
            int count = tensorflow.TF_DeviceListCount(devices);
            for (int i = 0; i < count; i++) {
                String name = tensorflow.TF_DeviceListName(devices, i, status).getString();
                // Locale.ROOT keeps the comparison independent of the default locale.
                if (name.toLowerCase(Locale.ROOT).contains("gpu")) {
                    return true;
                }
            }
            return false;
        } finally {
            // Release in reverse order of creation.
            if (devices != null && !devices.isNull()) {
                tensorflow.TF_DeleteDeviceList(devices);
            }
            if (context != null && !context.isNull()) {
                tensorflow.TFE_DeleteContext(context);
            }
            if (options != null && !options.isNull()) {
                tensorflow.TFE_DeleteContextOptions(options);
            }
            tensorflow.TF_DeleteStatus(status);
        }
    }

    /** Returns a new sub-manager of the TensorFlow system manager. */
    public NDManager newBaseManager() {
        return TfNDManager.getSystemManager().newSubManager();
    }

    /**
     * Returns a new sub-manager of the TensorFlow system manager bound to {@code device}.
     *
     * <p>NOTE(review): {@code mo5newSubManager} is a decompiler-generated name; it presumably
     * corresponds to {@code newSubManager(Device)} in the original source — confirm.
     *
     * @param device the device for the new manager
     * @return the new manager
     */
    public NDManager newBaseManager(Device device) {
        return TfNDManager.getSystemManager().mo5newSubManager(device);
    }

    /**
     * Training is not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public GradientCollector newGradientCollector() {
        throw new UnsupportedOperationException("TensorFlow does not support training yet");
    }

    /**
     * Seeds both the TensorFlow manager's random operations and DJL's shared {@code
     * RandomUtils.RANDOM}.
     *
     * @param i the seed
     */
    public void setRandomSeed(int i) {
        TfNDManager.setRandomSeed(Integer.valueOf(i));
        RandomUtils.RANDOM.setSeed(i);
    }

    /** Returns the engine name, version, capability list and native library name. */
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(getEngineName()).append(':').append(getVersion()).append(", capabilities: [\n\tMKL,\n");
        if (hasCapability("CUDA")) {
            sb.append("\t").append("CUDA").append(",\n");
        }
        sb.append("]\nTensorFlow Library: ").append(LibUtils.getLibName());
        return sb.toString();
    }
}
