From 84fd38e20a90348929d88d428014fcb746f28854 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 2 Oct 2025 11:09:42 +0200 Subject: [PATCH 1/3] [no-relnote] Add ResourceManager mock for testing Signed-off-by: Evan Lezar --- internal/rm/rm.go | 2 + internal/rm/rm_mock.go | 319 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 321 insertions(+) create mode 100644 internal/rm/rm_mock.go diff --git a/internal/rm/rm.go b/internal/rm/rm.go index 372165df0..33f44b9d8 100644 --- a/internal/rm/rm.go +++ b/internal/rm/rm.go @@ -37,6 +37,8 @@ type resourceManager struct { } // ResourceManager provides an interface for listing a set of Devices and checking health on them +// +//go:generate moq -rm -fmt=goimports -stub -out rm_mock.go . ResourceManager type ResourceManager interface { Resource() spec.ResourceName Devices() Devices diff --git a/internal/rm/rm_mock.go b/internal/rm/rm_mock.go new file mode 100644 index 000000000..4efee5fd9 --- /dev/null +++ b/internal/rm/rm_mock.go @@ -0,0 +1,319 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package rm + +import ( + "sync" + + spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1" +) + +// Ensure, that ResourceManagerMock does implement ResourceManager. +// If this is not the case, regenerate this file with moq. +var _ ResourceManager = &ResourceManagerMock{} + +// ResourceManagerMock is a mock implementation of ResourceManager. +// +// func TestSomethingThatUsesResourceManager(t *testing.T) { +// +// // make and configure a mocked ResourceManager +// mockedResourceManager := &ResourceManagerMock{ +// CheckHealthFunc: func(stop <-chan interface{}, unhealthy chan<- *Device) error { +// panic("mock out the CheckHealth method") +// }, +// DevicesFunc: func() Devices { +// panic("mock out the Devices method") +// }, +// GetDevicePathsFunc: func(strings []string) []string { +// panic("mock out the GetDevicePaths method") +// }, +// GetPreferredAllocationFunc: func(available []string, required []string, size int) ([]string, error) { +// panic("mock out the GetPreferredAllocation method") +// }, +// ResourceFunc: func() spec.ResourceName { +// panic("mock out the Resource method") +// }, +// ValidateRequestFunc: func(annotatedIDs AnnotatedIDs) error { +// panic("mock out the ValidateRequest method") +// }, +// } +// +// // use mockedResourceManager in code that requires ResourceManager +// // and then make assertions. +// +// } +type ResourceManagerMock struct { + // CheckHealthFunc mocks the CheckHealth method. + CheckHealthFunc func(stop <-chan interface{}, unhealthy chan<- *Device) error + + // DevicesFunc mocks the Devices method. + DevicesFunc func() Devices + + // GetDevicePathsFunc mocks the GetDevicePaths method. + GetDevicePathsFunc func(strings []string) []string + + // GetPreferredAllocationFunc mocks the GetPreferredAllocation method. + GetPreferredAllocationFunc func(available []string, required []string, size int) ([]string, error) + + // ResourceFunc mocks the Resource method. + ResourceFunc func() spec.ResourceName + + // ValidateRequestFunc mocks the ValidateRequest method. + ValidateRequestFunc func(annotatedIDs AnnotatedIDs) error + + // calls tracks calls to the methods. + calls struct { + // CheckHealth holds details about calls to the CheckHealth method. + CheckHealth []struct { + // Stop is the stop argument value. + Stop <-chan interface{} + // Unhealthy is the unhealthy argument value. + Unhealthy chan<- *Device + } + // Devices holds details about calls to the Devices method. + Devices []struct { + } + // GetDevicePaths holds details about calls to the GetDevicePaths method. + GetDevicePaths []struct { + // Strings is the strings argument value. + Strings []string + } + // GetPreferredAllocation holds details about calls to the GetPreferredAllocation method. + GetPreferredAllocation []struct { + // Available is the available argument value. + Available []string + // Required is the required argument value. + Required []string + // Size is the size argument value. + Size int + } + // Resource holds details about calls to the Resource method. + Resource []struct { + } + // ValidateRequest holds details about calls to the ValidateRequest method. + ValidateRequest []struct { + // AnnotatedIDs is the annotatedIDs argument value. + AnnotatedIDs AnnotatedIDs + } + } + lockCheckHealth sync.RWMutex + lockDevices sync.RWMutex + lockGetDevicePaths sync.RWMutex + lockGetPreferredAllocation sync.RWMutex + lockResource sync.RWMutex + lockValidateRequest sync.RWMutex +} + +// CheckHealth calls CheckHealthFunc. +func (mock *ResourceManagerMock) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error { + callInfo := struct { + Stop <-chan interface{} + Unhealthy chan<- *Device + }{ + Stop: stop, + Unhealthy: unhealthy, + } + mock.lockCheckHealth.Lock() + mock.calls.CheckHealth = append(mock.calls.CheckHealth, callInfo) + mock.lockCheckHealth.Unlock() + if mock.CheckHealthFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.CheckHealthFunc(stop, unhealthy) +} + +// CheckHealthCalls gets all the calls that were made to CheckHealth. +// Check the length with: +// +// len(mockedResourceManager.CheckHealthCalls()) +func (mock *ResourceManagerMock) CheckHealthCalls() []struct { + Stop <-chan interface{} + Unhealthy chan<- *Device +} { + var calls []struct { + Stop <-chan interface{} + Unhealthy chan<- *Device + } + mock.lockCheckHealth.RLock() + calls = mock.calls.CheckHealth + mock.lockCheckHealth.RUnlock() + return calls +} + +// Devices calls DevicesFunc. +func (mock *ResourceManagerMock) Devices() Devices { + callInfo := struct { + }{} + mock.lockDevices.Lock() + mock.calls.Devices = append(mock.calls.Devices, callInfo) + mock.lockDevices.Unlock() + if mock.DevicesFunc == nil { + var ( + devicesOut Devices + ) + return devicesOut + } + return mock.DevicesFunc() +} + +// DevicesCalls gets all the calls that were made to Devices. +// Check the length with: +// +// len(mockedResourceManager.DevicesCalls()) +func (mock *ResourceManagerMock) DevicesCalls() []struct { +} { + var calls []struct { + } + mock.lockDevices.RLock() + calls = mock.calls.Devices + mock.lockDevices.RUnlock() + return calls +} + +// GetDevicePaths calls GetDevicePathsFunc. +func (mock *ResourceManagerMock) GetDevicePaths(strings []string) []string { + callInfo := struct { + Strings []string + }{ + Strings: strings, + } + mock.lockGetDevicePaths.Lock() + mock.calls.GetDevicePaths = append(mock.calls.GetDevicePaths, callInfo) + mock.lockGetDevicePaths.Unlock() + if mock.GetDevicePathsFunc == nil { + var ( + stringsOut []string + ) + return stringsOut + } + return mock.GetDevicePathsFunc(strings) +} + +// GetDevicePathsCalls gets all the calls that were made to GetDevicePaths. +// Check the length with: +// +// len(mockedResourceManager.GetDevicePathsCalls()) +func (mock *ResourceManagerMock) GetDevicePathsCalls() []struct { + Strings []string +} { + var calls []struct { + Strings []string + } + mock.lockGetDevicePaths.RLock() + calls = mock.calls.GetDevicePaths + mock.lockGetDevicePaths.RUnlock() + return calls +} + +// GetPreferredAllocation calls GetPreferredAllocationFunc. +func (mock *ResourceManagerMock) GetPreferredAllocation(available []string, required []string, size int) ([]string, error) { + callInfo := struct { + Available []string + Required []string + Size int + }{ + Available: available, + Required: required, + Size: size, + } + mock.lockGetPreferredAllocation.Lock() + mock.calls.GetPreferredAllocation = append(mock.calls.GetPreferredAllocation, callInfo) + mock.lockGetPreferredAllocation.Unlock() + if mock.GetPreferredAllocationFunc == nil { + var ( + stringsOut []string + errOut error + ) + return stringsOut, errOut + } + return mock.GetPreferredAllocationFunc(available, required, size) +} + +// GetPreferredAllocationCalls gets all the calls that were made to GetPreferredAllocation. +// Check the length with: +// +// len(mockedResourceManager.GetPreferredAllocationCalls()) +func (mock *ResourceManagerMock) GetPreferredAllocationCalls() []struct { + Available []string + Required []string + Size int +} { + var calls []struct { + Available []string + Required []string + Size int + } + mock.lockGetPreferredAllocation.RLock() + calls = mock.calls.GetPreferredAllocation + mock.lockGetPreferredAllocation.RUnlock() + return calls +} + +// Resource calls ResourceFunc. +func (mock *ResourceManagerMock) Resource() spec.ResourceName { + callInfo := struct { + }{} + mock.lockResource.Lock() + mock.calls.Resource = append(mock.calls.Resource, callInfo) + mock.lockResource.Unlock() + if mock.ResourceFunc == nil { + var ( + resourceNameOut spec.ResourceName + ) + return resourceNameOut + } + return mock.ResourceFunc() +} + +// ResourceCalls gets all the calls that were made to Resource. +// Check the length with: +// +// len(mockedResourceManager.ResourceCalls()) +func (mock *ResourceManagerMock) ResourceCalls() []struct { +} { + var calls []struct { + } + mock.lockResource.RLock() + calls = mock.calls.Resource + mock.lockResource.RUnlock() + return calls +} + +// ValidateRequest calls ValidateRequestFunc. +func (mock *ResourceManagerMock) ValidateRequest(annotatedIDs AnnotatedIDs) error { + callInfo := struct { + AnnotatedIDs AnnotatedIDs + }{ + AnnotatedIDs: annotatedIDs, + } + mock.lockValidateRequest.Lock() + mock.calls.ValidateRequest = append(mock.calls.ValidateRequest, callInfo) + mock.lockValidateRequest.Unlock() + if mock.ValidateRequestFunc == nil { + var ( + errOut error + ) + return errOut + } + return mock.ValidateRequestFunc(annotatedIDs) +} + +// ValidateRequestCalls gets all the calls that were made to ValidateRequest. +// Check the length with: +// +// len(mockedResourceManager.ValidateRequestCalls()) +func (mock *ResourceManagerMock) ValidateRequestCalls() []struct { + AnnotatedIDs AnnotatedIDs +} { + var calls []struct { + AnnotatedIDs AnnotatedIDs + } + mock.lockValidateRequest.RLock() + calls = mock.calls.ValidateRequest + mock.lockValidateRequest.RUnlock() + return calls +} From 521a12255900fe77b13159bb7a4e9d70d4aaee28 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 2 Oct 2025 11:09:03 +0200 Subject: [PATCH 2/3] Check for nil before reading boolean config values Signed-off-by: Evan Lezar --- internal/plugin/server.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/plugin/server.go b/internal/plugin/server.go index 46f4f0de3..f81b0143d 100644 --- a/internal/plugin/server.go +++ b/internal/plugin/server.go @@ -348,16 +348,16 @@ func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyVolumeMounts) { plugin.updateResponseForDeviceMounts(response, deviceIDs...) } - if *plugin.config.Flags.Plugin.PassDeviceSpecs { + if plugin.config.Flags.Plugin.PassDeviceSpecs != nil && *plugin.config.Flags.Plugin.PassDeviceSpecs { response.Devices = append(response.Devices, plugin.apiDeviceSpecs(*plugin.config.Flags.NvidiaDevRoot, requestIds)...) } - if *plugin.config.Flags.GDRCopyEnabled { + if plugin.config.Flags.GDRCopyEnabled != nil && *plugin.config.Flags.GDRCopyEnabled { response.Envs["NVIDIA_GDRCOPY"] = "enabled" } - if *plugin.config.Flags.GDSEnabled { + if plugin.config.Flags.GDSEnabled != nil && *plugin.config.Flags.GDSEnabled { response.Envs["NVIDIA_GDS"] = "enabled" } - if *plugin.config.Flags.MOFEDEnabled { + if plugin.config.Flags.MOFEDEnabled != nil && *plugin.config.Flags.MOFEDEnabled { response.Envs["NVIDIA_MOFED"] = "enabled" } return response, nil From 75633e25beb2d5a7766adb969fdb7578b05a22f2 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 2 Oct 2025 10:45:24 +0200 Subject: [PATCH 3/3] Deduplicate requested device IDs This change ensures that the incoming device IDs are deduplicated before updating the AllocateResponse. This avoids cases where the NVIDIA_VISIBLE_DEVICES envvar or CDI annotations contain repeated device UUIDs or INDICES that do not add additional modifications to the container. Signed-off-by: Evan Lezar --- internal/plugin/server.go | 15 ++++-- internal/plugin/server_test.go | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/internal/plugin/server.go b/internal/plugin/server.go index f81b0143d..29fbd2a29 100644 --- a/internal/plugin/server.go +++ b/internal/plugin/server.go @@ -319,7 +319,7 @@ func (plugin *nvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi. } func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*pluginapi.ContainerAllocateResponse, error) { - deviceIDs := plugin.deviceIDsFromAnnotatedDeviceIDs(requestIds) + deviceIDs := plugin.uniqueDeviceIDsFromAnnotatedDeviceIDs(requestIds) // Create an empty response that will be updated as required below. response := &pluginapi.ContainerAllocateResponse{ @@ -451,7 +451,7 @@ func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Durat return c, nil } -func (plugin *nvidiaDevicePlugin) deviceIDsFromAnnotatedDeviceIDs(ids []string) []string { +func (plugin *nvidiaDevicePlugin) uniqueDeviceIDsFromAnnotatedDeviceIDs(ids []string) []string { var deviceIDs []string if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyUUID { deviceIDs = rm.AnnotatedIDs(ids).GetIDs() @@ -459,7 +459,16 @@ func (plugin *nvidiaDevicePlugin) deviceIDsFromAnnotatedDeviceIDs(ids []string) if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyIndex { deviceIDs = plugin.rm.Devices().Subset(ids).GetIndices() } - return deviceIDs + var uniqueIDs []string + seen := make(map[string]bool) + for _, id := range deviceIDs { + if seen[id] { + continue + } + seen[id] = true + uniqueIDs = append(uniqueIDs, id) + } + return uniqueIDs } func (plugin *nvidiaDevicePlugin) apiDevices() []*pluginapi.Device { diff --git a/internal/plugin/server_test.go b/internal/plugin/server_test.go index 3954e36e3..fb3693b5e 100644 --- a/internal/plugin/server_test.go +++ b/internal/plugin/server_test.go @@ -17,6 +17,7 @@ package plugin import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -25,8 +26,88 @@ import ( v1 "github.com/NVIDIA/k8s-device-plugin/api/config/v1" "github.com/NVIDIA/k8s-device-plugin/internal/cdi" "github.com/NVIDIA/k8s-device-plugin/internal/imex" + "github.com/NVIDIA/k8s-device-plugin/internal/rm" ) +func TestAllocate(t *testing.T) { + testCases := []struct { + description string + request *pluginapi.AllocateRequest + expectedError error + expectedResponse *pluginapi.AllocateResponse + }{ + { + description: "single device", + request: &pluginapi.AllocateRequest{ + ContainerRequests: []*pluginapi.ContainerAllocateRequest{ + { + DevicesIDs: []string{"foo"}, + }, + }, + }, + expectedResponse: &pluginapi.AllocateResponse{ + ContainerResponses: []*pluginapi.ContainerAllocateResponse{ + { + Envs: map[string]string{ + "NVIDIA_VISIBLE_DEVICES": "foo", + }, + }, + }, + }, + }, + { + description: "duplicate device IDs", + request: &pluginapi.AllocateRequest{ + ContainerRequests: []*pluginapi.ContainerAllocateRequest{ + { + DevicesIDs: []string{"foo", "bar", "foo"}, + }, + }, + }, + expectedResponse: &pluginapi.AllocateResponse{ + ContainerResponses: []*pluginapi.ContainerAllocateResponse{ + { + Envs: map[string]string{ + "NVIDIA_VISIBLE_DEVICES": "foo,bar", + }, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + plugin := nvidiaDevicePlugin{ + rm: &rm.ResourceManagerMock{ + ValidateRequestFunc: func(annotatedIDs rm.AnnotatedIDs) error { + return nil + }, + }, + config: &v1.Config{ + Flags: v1.Flags{ + CommandLineFlags: v1.CommandLineFlags{ + Plugin: &v1.PluginCommandLineFlags{ + DeviceIDStrategy: ptr(v1.DeviceIDStrategyUUID), + }, + }, + }, + }, + cdiHandler: &cdi.InterfaceMock{ + QualifiedNameFunc: func(c string, s string) string { + return "nvidia.com/" + c + "=" + s + }, + }, + deviceListStrategies: v1.DeviceListStrategies{"envvar": true}, + } + + response, err := plugin.Allocate(context.TODO(), tc.request) + require.EqualValues(t, tc.expectedError, err) + require.EqualValues(t, tc.expectedResponse, response) + }) + } +} + func TestCDIAllocateResponse(t *testing.T) { testCases := []struct { description string @@ -169,3 +250,7 @@ func TestCDIAllocateResponse(t *testing.T) { }) } } + +func ptr[T any](x T) *T { + return &x +}