diff --git a/gpuallocator/allocator.go b/gpuallocator/allocator.go index e14405d..08e64ad 100644 --- a/gpuallocator/allocator.go +++ b/gpuallocator/allocator.go @@ -117,6 +117,17 @@ func (a *Allocator) AllocateSpecific(devices ...*Device) error { // Free a set of GPUs back to the allocator. func (a *Allocator) Free(devices ...*Device) { - a.remaining.Insert(devices...) - a.allocated.Delete(devices...) + for _, device := range devices { + if device == nil { + continue + } + + allocated, ok := a.allocated[device.UUID] + if !ok || allocated != device { + continue + } + + a.remaining.Insert(device) + a.allocated.Delete(device) + } } diff --git a/gpuallocator/allocator_test.go b/gpuallocator/allocator_test.go new file mode 100644 index 0000000..5deace9 --- /dev/null +++ b/gpuallocator/allocator_test.go @@ -0,0 +1,50 @@ +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + +package gpuallocator + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAllocatorFreeReturnsAllocatedDevice(t *testing.T) { + allocator := newAllocatorFrom(New4xRTX8000Node().Devices(), NewSimplePolicy()) + + allocated := allocator.Allocate(1) + require.Len(t, allocated, 1) + + allocator.Free(allocated[0]) + + require.False(t, allocator.allocated.Contains(allocated[0])) + require.True(t, allocator.remaining.Contains(allocated[0])) + require.Same(t, allocated[0], allocator.remaining[allocated[0].UUID]) +} + +func TestAllocatorFreeIgnoresUnknownDevice(t *testing.T) { + allocator := newAllocatorFrom(New4xRTX8000Node().Devices(), NewSimplePolicy()) + + allocated := allocator.Allocate(1) + require.Len(t, allocated, 1) + + unknown := (*Device)(NewTestGPU(99)) + allocator.Free(unknown) + + require.False(t, allocator.remaining.Contains(unknown)) + require.True(t, allocator.allocated.Contains(allocated[0])) + require.False(t, allocator.remaining.Contains(allocated[0])) +} + +func TestAllocatorFreeIgnoresFabricatedAllocatedDevice(t *testing.T) { + allocator := newAllocatorFrom(New4xRTX8000Node().Devices(), NewSimplePolicy()) + + allocated := allocator.Allocate(1) + require.Len(t, allocated, 1) + + fabricated := (*Device)(NewTestGPU(allocated[0].Index)) + allocator.Free(fabricated) + + require.True(t, allocator.allocated.Contains(allocated[0])) + require.Same(t, allocated[0], allocator.allocated[allocated[0].UUID]) + require.False(t, allocator.remaining.Contains(allocated[0])) +}