package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.gpu;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDevice;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.gpu.GpuDiscoverer;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
import org.apache.hadoop.yarn.server.nodemanager.nodelabels.AbstractNodeLabelsProvider;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMNullStateStoreService;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
import org.apache.hadoop.yarn.util.resource.TestResourceUtils;
import org.apache.xerces.xs.XSSimpleTypeDefinition;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;

/* loaded from: input_file:test-classes/org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/resources/gpu/TestGpuResourceHandler.class */
/**
 * Unit tests for {@link GpuResourceHandlerImpl}: bootstrap of the devices
 * cgroup controller, GPU allocation and denial of unassigned devices in
 * {@code preStart}, release in {@code postComplete}, persistence of
 * assignments to the NM state store, and recovery of assignments through
 * {@code reacquireContainer}.
 *
 * <p>The cgroups handler, privileged-operation executor, NM context and NM
 * state store are all Mockito mocks.
 */
public class TestGpuResourceHandler {
    /** Resource name under which GPU requests and assignments are recorded. */
    private static final String GPU_URI = "yarn.io/gpu";

    /**
     * Configuration key holding the GPUs YARN may allocate, as comma-separated
     * {@code a:b} pairs; e.g. {@code 2:3} corresponds to
     * {@code new GpuDevice(2, 3)} in the assertions below.
     */
    private static final String ALLOWED_GPU_DEVICES_KEY =
        "yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices";

    /** Four allowed GPUs, shared by most tests. */
    private static final String FOUR_ALLOWED_GPUS = "0:0,1:1,2:3,3:4";

    private CGroupsHandler mockCGroupsHandler;
    private PrivilegedOperationExecutor mockPrivilegedExecutor;
    private GpuResourceHandlerImpl gpuResourceHandler;
    private NMStateStoreService mockNMStateStore;
    /** Map returned by the mocked {@code Context#getContainers()}. */
    private ConcurrentHashMap<ContainerId, Container> runningContainersMap;

    @Before
    public void setup() {
        // Register the GPU resource type used by the container requests below.
        TestResourceUtils.addNewTypesToResources(new String[]{GPU_URI});

        mockCGroupsHandler = Mockito.mock(CGroupsHandler.class);
        mockPrivilegedExecutor = Mockito.mock(PrivilegedOperationExecutor.class);
        mockNMStateStore = Mockito.mock(NMStateStoreService.class);
        runningContainersMap = new ConcurrentHashMap<>();

        Context context = Mockito.mock(Context.class);
        Mockito.when(context.getNMStateStore()).thenReturn(mockNMStateStore);
        Mockito.when(context.getContainers()).thenReturn(runningContainersMap);

        gpuResourceHandler = new GpuResourceHandlerImpl(
            context, mockCGroupsHandler, mockPrivilegedExecutor);
    }

    /**
     * Builds a configuration that allows the given GPUs and initializes
     * {@code GpuDiscoverer.getInstance()} with it.
     *
     * @param allowedDevices value for the allowed-GPU-devices key
     * @return the configuration, ready to pass to {@code bootstrap}
     */
    private static Configuration initializeGpuDiscoverer(String allowedDevices)
            throws Exception {
        Configuration conf = new YarnConfiguration();
        conf.set(ALLOWED_GPU_DEVICES_KEY, allowedDevices);
        GpuDiscoverer.getInstance().initialize(conf);
        return conf;
    }

    @Test
    public void testBootStrap() throws Exception {
        Configuration conf = initializeGpuDiscoverer("0:0");
        gpuResourceHandler.bootstrap(conf);
        Mockito.verify(mockCGroupsHandler, Mockito.times(1))
            .initializeCGroupController(CGroupsHandler.CGroupController.DEVICES);
    }

    /** Returns the id of container {@code i} of attempt 1 of application 1234_1. */
    private static ContainerId getContainerId(int i) {
        return ContainerId.newContainerId(ApplicationAttemptId.newInstance(ApplicationId.newInstance(1234L, 1), 1), i);
    }

    /**
     * Mocks a container whose resource is {@code Resource.newInstance(1024, 1)}
     * plus {@code numGpus} GPUs, with empty (mutable) resource mappings.
     *
     * @param id container sequence number, see {@link #getContainerId(int)}
     * @param numGpus number of GPUs requested
     * @param dockerContainerEnabled if true, the launch environment sets
     *     {@code ContainerRuntimeConstants.ENV_CONTAINER_TYPE} to "docker"
     */
    private static Container mockContainerWithGpuRequest(int id, int numGpus,
            boolean dockerContainerEnabled) {
        Container container = Mockito.mock(Container.class);
        Mockito.when(container.getContainerId()).thenReturn(getContainerId(id));

        Resource resource = Resource.newInstance(1024, 1);
        resource.setResourceValue(GPU_URI, numGpus);
        Mockito.when(container.getResource()).thenReturn(resource);
        // The same instance is returned on every call so tests can inspect
        // what the handler recorded into it.
        Mockito.when(container.getResourceMappings()).thenReturn(new ResourceMappings());

        Map<String, String> environment = new HashMap<>();
        if (dockerContainerEnabled) {
            environment.put(ContainerRuntimeConstants.ENV_CONTAINER_TYPE, "docker");
        }
        ContainerLaunchContext launchContext = Mockito.mock(ContainerLaunchContext.class);
        Mockito.when(launchContext.getEnvironment()).thenReturn(environment);
        Mockito.when(container.getLaunchContext()).thenReturn(launchContext);
        return container;
    }

    /** Same as {@link #mockContainerWithGpuRequest(int, int, boolean)} for a non-docker container. */
    private static Container mockContainerWithGpuRequest(int id, int numGpus) {
        return mockContainerWithGpuRequest(id, numGpus, false);
    }

    /**
     * Mocks a recovered container whose resource mappings hold {@code devices}
     * under {@code resourceName}.
     */
    private static Container mockContainerWithAssignedGpus(String resourceName,
            GpuDevice... devices) {
        ResourceMappings.AssignedResources assigned = new ResourceMappings.AssignedResources();
        assigned.updateAssignedResources(new ArrayList<Serializable>(Arrays.asList(devices)));
        ResourceMappings mappings = new ResourceMappings();
        mappings.addAssignedResources(resourceName, assigned);

        Container container = Mockito.mock(Container.class);
        Mockito.when(container.getResourceMappings()).thenReturn(mappings);
        return container;
    }

    /**
     * Verifies that one devices cgroup was created for the container and, when
     * {@code deniedDevices} is non-empty, that the executor was called once
     * with a GPU operation excluding exactly those devices' minor numbers.
     * An empty or null list skips the privileged-operation check entirely.
     */
    private void verifyDeniedDevices(ContainerId containerId, List<GpuDevice> deniedDevices)
            throws ResourceHandlerException, PrivilegedOperationException {
        Mockito.verify(mockCGroupsHandler, Mockito.times(1))
            .createCGroup(CGroupsHandler.CGroupController.DEVICES, containerId.toString());

        if (deniedDevices == null || deniedDevices.isEmpty()) {
            return;
        }

        List<Integer> deniedMinorNumbers = new ArrayList<>();
        for (GpuDevice device : deniedDevices) {
            deniedMinorNumbers.add(device.getMinorNumber());
        }

        PrivilegedOperation expectedOperation = new PrivilegedOperation(
            PrivilegedOperation.OperationType.GPU,
            Arrays.asList(
                GpuResourceHandlerImpl.CONTAINER_ID_CLI_OPTION,
                containerId.toString(),
                GpuResourceHandlerImpl.EXCLUDED_GPUS_CLI_OPTION,
                StringUtils.join(",", deniedMinorNumbers)));
        Mockito.verify(mockPrivilegedExecutor, Mockito.times(1))
            .executePrivilegedOperation(expectedOperation, true);
    }

    /**
     * Allocates GPUs to several containers out of four allowed devices and
     * checks the cgroup / privileged-operation calls, the failure when too few
     * GPUs remain, and that {@code postComplete} frees the GPUs and deletes
     * the container's devices cgroup.
     *
     * @param dockerContainerEnabled if true, containers are docker containers
     *     and the device-exclusion operation is not checked
     */
    private void commonTestAllocation(boolean dockerContainerEnabled) throws Exception {
        Configuration conf = initializeGpuDiscoverer(FOUR_ALLOWED_GPUS);
        gpuResourceHandler.bootstrap(conf);
        Assert.assertEquals(4L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());

        // Container 1 takes 3 GPUs; the remaining GpuDevice(3, 4) is denied to it.
        gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 3, dockerContainerEnabled));
        if (dockerContainerEnabled) {
            verifyDeniedDevices(getContainerId(1), Collections.<GpuDevice>emptyList());
        } else {
            verifyDeniedDevices(getContainerId(1), Arrays.asList(new GpuDevice(3, 4)));
        }

        // Only one GPU is left, so a request for two must be rejected.
        try {
            gpuResourceHandler.preStart(mockContainerWithGpuRequest(2, 2, dockerContainerEnabled));
            Assert.fail("preStart should fail when fewer GPUs are available than requested");
        } catch (ResourceHandlerException expected) {
            // expected
        }

        // Container 3 takes the last GPU, GpuDevice(3, 4).
        gpuResourceHandler.preStart(mockContainerWithGpuRequest(3, 1, dockerContainerEnabled));
        if (dockerContainerEnabled) {
            verifyDeniedDevices(getContainerId(3), Collections.<GpuDevice>emptyList());
        } else {
            verifyDeniedDevices(getContainerId(3),
                Arrays.asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3)));
        }

        // Container 4 requests no GPUs, so every device is denied to it.
        gpuResourceHandler.preStart(mockContainerWithGpuRequest(4, 0, dockerContainerEnabled));
        if (dockerContainerEnabled) {
            verifyDeniedDevices(getContainerId(4), Collections.<GpuDevice>emptyList());
        } else {
            verifyDeniedDevices(getContainerId(4), Arrays.asList(new GpuDevice(0, 0),
                new GpuDevice(1, 1), new GpuDevice(2, 3), new GpuDevice(3, 4)));
        }

        // Releasing a container deletes its devices cgroup and frees its GPUs.
        gpuResourceHandler.postComplete(getContainerId(1));
        Mockito.verify(mockCGroupsHandler, Mockito.times(1))
            .deleteCGroup(CGroupsHandler.CGroupController.DEVICES, getContainerId(1).toString());
        Assert.assertEquals(3L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());

        gpuResourceHandler.postComplete(getContainerId(3));
        Mockito.verify(mockCGroupsHandler, Mockito.times(1))
            .deleteCGroup(CGroupsHandler.CGroupController.DEVICES, getContainerId(3).toString());
        Assert.assertEquals(4L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());
    }

    @Test
    public void testAllocationWhenDockerContainerEnabled() throws Exception {
        commonTestAllocation(true);
    }

    @Test
    public void testAllocation() throws Exception {
        commonTestAllocation(false);
    }

    @Test
    public void testAssignedGpuWillBeCleanedupWhenStoreOpFails() throws Exception {
        Configuration conf = initializeGpuDiscoverer(FOUR_ALLOWED_GPUS);
        gpuResourceHandler.bootstrap(conf);
        Assert.assertEquals(4L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());

        Mockito.doThrow(new IOException("Exception ..."))
            .when(mockNMStateStore)
            .storeAssignedResources(Matchers.any(Container.class), Matchers.anyString(), Matchers.anyList());

        try {
            gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 3));
            Assert.fail("preStart should throw exception");
        } catch (ResourceHandlerException expected) {
            // expected: persisting the assignment failed
        }
        // The failed assignment must have been rolled back.
        Assert.assertEquals(4L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());
    }

    @Test
    public void testAllocationWithoutAllowedGpus() throws Exception {
        Configuration conf = initializeGpuDiscoverer(" ");
        try {
            gpuResourceHandler.bootstrap(conf);
            Assert.fail("Should fail because no GPU available");
        } catch (ResourceHandlerException expected) {
            // expected: no usable GPUs configured
        }

        // A container that requests no GPUs can still start.
        gpuResourceHandler.preStart(mockContainerWithGpuRequest(1, 0));
        verifyDeniedDevices(getContainerId(1), Collections.<GpuDevice>emptyList());

        try {
            gpuResourceHandler.preStart(mockContainerWithGpuRequest(2, 1));
            Assert.fail("preStart should fail when no GPU is available");
        } catch (ResourceHandlerException expected) {
            // expected
        }

        gpuResourceHandler.postComplete(getContainerId(1));
        Mockito.verify(mockCGroupsHandler, Mockito.times(1))
            .deleteCGroup(CGroupsHandler.CGroupController.DEVICES, getContainerId(1).toString());
        Assert.assertEquals(0L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());
    }

    @Test
    public void testAllocationStored() throws Exception {
        Configuration conf = initializeGpuDiscoverer(FOUR_ALLOWED_GPUS);
        gpuResourceHandler.bootstrap(conf);
        Assert.assertEquals(4L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());

        // A GPU-requesting container has its assignment persisted.
        Container container1 = mockContainerWithGpuRequest(1, 3);
        gpuResourceHandler.preStart(container1);
        Mockito.verify(mockNMStateStore).storeAssignedResources(container1, GPU_URI,
            Arrays.asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3)));
        verifyDeniedDevices(getContainerId(1), Arrays.asList(new GpuDevice(3, 4)));

        // A container without GPUs gets an empty assignment that is never persisted.
        Container container2 = mockContainerWithGpuRequest(2, 0);
        gpuResourceHandler.preStart(container2);
        verifyDeniedDevices(getContainerId(2), Arrays.asList(new GpuDevice(0, 0),
            new GpuDevice(1, 1), new GpuDevice(2, 3), new GpuDevice(3, 4)));
        Assert.assertEquals(0L, container2.getResourceMappings().getAssignedResources(GPU_URI).size());
        Mockito.verify(mockNMStateStore, Mockito.never()).storeAssignedResources(
            Matchers.eq(container2), Matchers.eq(GPU_URI), Matchers.anyListOf(Serializable.class));
    }

    @Test
    public void testAllocationStoredWithNULLStateStore() throws Exception {
        NMNullStateStoreService nullStateStore = Mockito.mock(NMNullStateStoreService.class);
        Context context = Mockito.mock(Context.class);
        Mockito.when(context.getNMStateStore()).thenReturn(nullStateStore);
        GpuResourceHandlerImpl handler =
            new GpuResourceHandlerImpl(context, mockCGroupsHandler, mockPrivilegedExecutor);

        Configuration conf = initializeGpuDiscoverer(FOUR_ALLOWED_GPUS);
        handler.bootstrap(conf);
        Assert.assertEquals(4L, handler.getGpuAllocator().getAvailableGpus());

        Container container = mockContainerWithGpuRequest(1, 3);
        handler.preStart(container);
        Mockito.verify(nullStateStore).storeAssignedResources(container, GPU_URI,
            Arrays.asList(new GpuDevice(0, 0), new GpuDevice(1, 1), new GpuDevice(2, 3)));
    }

    /**
     * Asserts that the allocator maps exactly GpuDevice(1, 1) and
     * GpuDevice(2, 3), with GpuDevice(1, 1) owned by container 1.
     */
    private void assertOnlyContainer1Recovered() {
        Map<GpuDevice, ContainerId> mapping =
            gpuResourceHandler.getGpuAllocator().getDeviceAllocationMappingCopy();
        Assert.assertEquals(2L, mapping.size());
        Assert.assertTrue(mapping.containsKey(new GpuDevice(1, 1)));
        Assert.assertTrue(mapping.containsKey(new GpuDevice(2, 3)));
        Assert.assertEquals(getContainerId(1), mapping.get(new GpuDevice(1, 1)));
    }

    @Test
    public void testRecoverResourceAllocation() throws Exception {
        Configuration conf = initializeGpuDiscoverer(FOUR_ALLOWED_GPUS);
        gpuResourceHandler.bootstrap(conf);
        Assert.assertEquals(4L, gpuResourceHandler.getGpuAllocator().getAvailableGpus());

        // Case 1: reacquiring a container restores its GPU assignment.
        runningContainersMap.put(getContainerId(1),
            mockContainerWithAssignedGpus(GPU_URI, new GpuDevice(1, 1), new GpuDevice(2, 3)));
        gpuResourceHandler.reacquireContainer(getContainerId(1));
        assertOnlyContainer1Recovered();

        // Case 2: a container assigned GpuDevice(4, 5), which is not allowed.
        runningContainersMap.put(getContainerId(2),
            mockContainerWithAssignedGpus(GPU_URI, new GpuDevice(3, 4), new GpuDevice(4, 5)));
        // NOTE(review): this reacquires container 1 again, not the container
        // just registered under id 2, so the expected exception presumably
        // comes from container 1's devices already being recorded rather than
        // from GpuDevice(4, 5) being disallowed. Switching to id 2 is not a
        // drop-in fix (the allowed GpuDevice(3, 4) might be recorded before
        // the failure) -- TODO confirm against GpuResourceAllocator.
        try {
            gpuResourceHandler.reacquireContainer(getContainerId(1));
            Assert.fail("Should fail since requested device Id is not in allowed list");
        } catch (ResourceHandlerException expected) {
            // expected
        }
        assertOnlyContainer1Recovered();

        // Case 3: a container claiming GpuDevice(2, 3), already owned by container 1.
        // NOTE(review): the assignment is keyed "gpu" rather than GPU_URI and,
        // as above, container 1 (not id 2) is reacquired -- TODO confirm intent.
        runningContainersMap.put(getContainerId(2),
            mockContainerWithAssignedGpus("gpu", new GpuDevice(3, 4), new GpuDevice(2, 3)));
        try {
            gpuResourceHandler.reacquireContainer(getContainerId(1));
            Assert.fail("Should fail since requested device Id is not in allowed list");
        } catch (ResourceHandlerException expected) {
            // expected
        }
        assertOnlyContainer1Recovered();
    }
}
