Skip to content

Commit f569e6c

Browse files
committed
Simplify _MemPoolAttributes to use direct MemoryPoolHandle
Replace weakref pattern with direct MemoryPoolHandle storage in _MemPoolAttributes. The handle's shared_ptr keeps the underlying pool alive, so attributes remain accessible after the MR is deleted. Note: _MemPool retains __weakref__ because the IPC subsystem uses WeakValueDictionary to track memory resources across processes.
1 parent 8053ee5 commit f569e6c

3 files changed

Lines changed: 21 additions & 30 deletions

File tree

cuda_core/cuda/core/_memory/_memory_pool.pxd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ cdef class _MemPool(MemoryResource):
1919
object __weakref__
2020

2121

22+
cdef class _MemPoolAttributes:
23+
cdef:
24+
MemoryPoolHandle _h_pool
25+
26+
@staticmethod
27+
cdef _MemPoolAttributes _init(MemoryPoolHandle h_pool)
28+
29+
cdef int _getattribute(self, cydriver.CUmemPool_attribute attr_enum, void* value) except? -1
30+
31+
2232
cdef class _MemPoolOptions:
2333

2434
cdef:

cuda_core/cuda/core/_memory/_memory_pool.pyx

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ from cuda.core._utils.cuda_utils cimport (
2929
)
3030

3131
import platform # no-cython-lint
32-
import weakref
3332

3433
from cuda.core._utils.cuda_utils import driver
3534

@@ -45,16 +44,15 @@ cdef class _MemPoolOptions:
4544

4645

4746
cdef class _MemPoolAttributes:
48-
cdef:
49-
object _mr_weakref
47+
"""Provides access to memory pool attributes."""
5048

5149
def __init__(self, *args, **kwargs):
5250
raise RuntimeError("_MemPoolAttributes cannot be instantiated directly. Please use MemoryResource APIs.")
5351

54-
@classmethod
55-
def _init(cls, mr):
56-
cdef _MemPoolAttributes self = _MemPoolAttributes.__new__(cls)
57-
self._mr_weakref = mr
52+
@staticmethod
53+
cdef _MemPoolAttributes _init(MemoryPoolHandle h_pool):
54+
cdef _MemPoolAttributes self = _MemPoolAttributes.__new__(_MemPoolAttributes)
55+
self._h_pool = h_pool
5856
return self
5957

6058
def __repr__(self):
@@ -64,12 +62,8 @@ cdef class _MemPoolAttributes:
6462
)
6563

6664
cdef int _getattribute(self, cydriver.CUmemPool_attribute attr_enum, void* value) except?-1:
67-
cdef _MemPool mr = <_MemPool>(self._mr_weakref())
68-
if mr is None:
69-
raise RuntimeError("_MemPool is expired")
70-
cdef cydriver.CUmemoryPool pool_handle = as_cu(mr._h_pool)
7165
with nogil:
72-
HANDLE_RETURN(cydriver.cuMemPoolGetAttribute(pool_handle, attr_enum, value))
66+
HANDLE_RETURN(cydriver.cuMemPoolGetAttribute(as_cu(self._h_pool), attr_enum, value))
7367
return 0
7468

7569
@property
@@ -197,8 +191,7 @@ cdef class _MemPool(MemoryResource):
197191
def attributes(self) -> _MemPoolAttributes:
198192
"""Memory pool attributes."""
199193
if self._attributes is None:
200-
ref = weakref.ref(self)
201-
self._attributes = _MemPoolAttributes._init(ref)
194+
self._attributes = _MemPoolAttributes._init(self._h_pool)
202195
return self._attributes
203196

204197
@property

cuda_core/tests/test_memory.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,7 @@ def test_mempool_attributes_repr(memory_resource_factory):
11621162

11631163

11641164
def test_mempool_attributes_ownership(memory_resource_factory):
1165-
"""Ensure the attributes bundle handles references correctly for all memory resource types."""
1165+
"""Ensure the attributes bundle keeps the pool alive via the handle."""
11661166
MR, MRops = memory_resource_factory
11671167
device = Device()
11681168

@@ -1190,21 +1190,9 @@ def test_mempool_attributes_ownership(memory_resource_factory):
11901190
mr.close()
11911191
del mr
11921192

1193-
# After deleting the memory resource, the attributes suite is disconnected.
1194-
with pytest.raises(RuntimeError, match="is expired"):
1195-
_ = attributes.used_mem_high
1196-
1197-
# Even when a new object is created (we found a case where the same
1198-
# mempool handle was really reused).
1199-
if MR is DeviceMemoryResource:
1200-
mr = MR(device, dict(max_size=POOL_SIZE)) # noqa: F841
1201-
elif MR is PinnedMemoryResource:
1202-
mr = MR(dict(max_size=POOL_SIZE)) # noqa: F841
1203-
elif MR is ManagedMemoryResource:
1204-
mr = create_managed_memory_resource_or_skip(dict()) # noqa: F841
1205-
1206-
with pytest.raises(RuntimeError, match="is expired"):
1207-
_ = attributes.used_mem_high
1193+
# The attributes bundle keeps the pool alive via MemoryPoolHandle,
1194+
# so accessing attributes still works even after the MR is deleted.
1195+
_ = attributes.used_mem_high # Should not raise
12081196

12091197

12101198
# Ensure that memory views dellocate their reference to dlpack tensors

0 commit comments

Comments
 (0)