Skip to content

Commit 46056e6

Browse files
committed
Simplify resource handle patterns and clean up tests
- Remove Kernel._module (ObjectCode reference no longer needed since KernelHandle keeps library alive via LibraryHandle dependency) - Simplify Kernel._from_obj signature (remove unused ObjectCode param) - Replace weakref patterns with direct handle storage: - KernelAttributes: store KernelHandle instead of weakref to Kernel - _MemPoolAttributes: store MemoryPoolHandle instead of weakref to _MemPool - Rename get_kernel_from_library to create_kernel_handle for consistency - Remove fragile annotation introspection from test_saxpy_arguments - Update test_mempool_attributes_ownership to reflect new ownership semantics
1 parent b262fa0 commit 46056e6

10 files changed

Lines changed: 42 additions & 62 deletions

File tree

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ struct KernelBox {
749749
};
750750
} // namespace
751751

752-
KernelHandle get_kernel_from_library(LibraryHandle h_library, const char* name) {
752+
KernelHandle create_kernel_handle(LibraryHandle h_library, const char* name) {
753753
GILReleaseGuard gil;
754754
CUkernel kernel;
755755
if (CUDA_SUCCESS != (err = p_cuLibraryGetKernel(&kernel, *h_library, name))) {

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ LibraryHandle create_library_handle_ref(CUlibrary library);
254254
// The kernel structurally depends on the provided library handle.
255255
// Kernels have no explicit destroy - their lifetime is tied to the library.
256256
// Returns empty handle on error (caller must check).
257-
KernelHandle get_kernel_from_library(LibraryHandle h_library, const char* name);
257+
KernelHandle create_kernel_handle(LibraryHandle h_library, const char* name);
258258

259259
// Create a non-owning kernel handle with library dependency.
260260
// Use for borrowed kernels. The library handle keeps the library alive.

cuda_core/cuda/core/_memory/_memory_pool.pxd

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@ cdef class _MemPool(MemoryResource):
1616
IPCDataForMR _ipc_data
1717
object _attributes
1818
object _peer_accessible_by
19-
object __weakref__
19+
20+
21+
cdef class _MemPoolAttributes:
22+
cdef:
23+
MemoryPoolHandle _h_pool
24+
25+
@staticmethod
26+
cdef _MemPoolAttributes _init(MemoryPoolHandle h_pool)
27+
28+
cdef int _getattribute(self, cydriver.CUmemPool_attribute attr_enum, void* value) except? -1
2029

2130

2231
cdef class _MemPoolOptions:

cuda_core/cuda/core/_memory/_memory_pool.pyx

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ from cuda.core._utils.cuda_utils cimport (
3030

3131
from typing import TYPE_CHECKING
3232
import platform # no-cython-lint
33-
import weakref
3433

3534
from cuda.core._utils.cuda_utils import driver
3635

@@ -50,16 +49,15 @@ cdef class _MemPoolOptions:
5049

5150

5251
cdef class _MemPoolAttributes:
53-
cdef:
54-
object _mr_weakref
52+
"""Provides access to memory pool attributes."""
5553

5654
def __init__(self, *args, **kwargs):
5755
raise RuntimeError("_MemPoolAttributes cannot be instantiated directly. Please use MemoryResource APIs.")
5856

59-
@classmethod
60-
def _init(cls, mr):
61-
cdef _MemPoolAttributes self = _MemPoolAttributes.__new__(cls)
62-
self._mr_weakref = mr
57+
@staticmethod
58+
cdef _MemPoolAttributes _init(MemoryPoolHandle h_pool):
59+
cdef _MemPoolAttributes self = _MemPoolAttributes.__new__(_MemPoolAttributes)
60+
self._h_pool = h_pool
6361
return self
6462

6563
def __repr__(self):
@@ -69,12 +67,8 @@ cdef class _MemPoolAttributes:
6967
)
7068

7169
cdef int _getattribute(self, cydriver.CUmemPool_attribute attr_enum, void* value) except?-1:
72-
cdef _MemPool mr = <_MemPool>(self._mr_weakref())
73-
if mr is None:
74-
raise RuntimeError("_MemPool is expired")
75-
cdef cydriver.CUmemoryPool pool_handle = as_cu(mr._h_pool)
7670
with nogil:
77-
HANDLE_RETURN(cydriver.cuMemPoolGetAttribute(pool_handle, attr_enum, value))
71+
HANDLE_RETURN(cydriver.cuMemPoolGetAttribute(as_cu(self._h_pool), attr_enum, value))
7872
return 0
7973

8074
@property
@@ -202,8 +196,7 @@ cdef class _MemPool(MemoryResource):
202196
def attributes(self) -> _MemPoolAttributes:
203197
"""Memory pool attributes."""
204198
if self._attributes is None:
205-
ref = weakref.ref(self)
206-
self._attributes = _MemPoolAttributes._init(ref)
199+
self._attributes = _MemPoolAttributes._init(self._h_pool)
207200
return self._attributes
208201

209202
@property

cuda_core/cuda/core/_module.pxd

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@ cdef class KernelAttributes
1414
cdef class Kernel:
1515
cdef:
1616
KernelHandle _h_kernel
17-
ObjectCode _module # ObjectCode reference
1817
KernelAttributes _attributes # lazy
1918
KernelOccupancy _occupancy # lazy
20-
object __weakref__ # Enable weak references
2119

2220
@staticmethod
23-
cdef Kernel _from_obj(KernelHandle h_kernel, ObjectCode mod)
21+
cdef Kernel _from_obj(KernelHandle h_kernel)
2422

2523
cdef tuple _get_arguments_info(self, bint param_info=*)
2624

@@ -46,8 +44,11 @@ cdef class KernelOccupancy:
4644

4745
cdef class KernelAttributes:
4846
cdef:
49-
object _kernel_weakref
47+
KernelHandle _h_kernel
5048
dict _cache
5149

50+
@staticmethod
51+
cdef KernelAttributes _init(KernelHandle h_kernel)
52+
5253
cdef int _get_cached_attribute(self, int device_id, cydriver.CUfunction_attribute attribute) except? -1
5354
cdef int _resolve_device_id(self, device_id) except? -1

cuda_core/cuda/core/_module.pyx

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ from libc.stddef cimport size_t
88

99
import functools
1010
import threading
11-
import weakref
1211
from collections import namedtuple
1312

1413
from cuda.core._device import Device
@@ -21,7 +20,7 @@ from cuda.core._resource_handles cimport (
2120
create_library_handle_from_file,
2221
create_library_handle_from_data,
2322
create_library_handle_ref,
24-
get_kernel_from_library,
23+
create_kernel_handle,
2524
create_kernel_handle_ref,
2625
get_last_error,
2726
as_cu,
@@ -139,15 +138,15 @@ cdef inline LibraryHandle _make_empty_library_handle():
139138

140139

141140
cdef class KernelAttributes:
142-
"""Provides access to kernel attributes. Uses weakref to avoid preventing Kernel GC."""
141+
"""Provides access to kernel attributes."""
143142

144143
def __init__(self, *args, **kwargs):
145144
raise RuntimeError("KernelAttributes cannot be instantiated directly. Please use Kernel APIs.")
146145

147-
@classmethod
148-
def _init(cls, kernel):
149-
cdef KernelAttributes self = KernelAttributes.__new__(cls)
150-
self._kernel_weakref = weakref.ref(kernel)
146+
@staticmethod
147+
cdef KernelAttributes _init(KernelHandle h_kernel):
148+
cdef KernelAttributes self = KernelAttributes.__new__(KernelAttributes)
149+
self._h_kernel = h_kernel
151150
self._cache = {}
152151
_lazy_init()
153152
return self
@@ -158,12 +157,9 @@ cdef class KernelAttributes:
158157
cached = self._cache.get(cache_key, cache_key)
159158
if cached is not cache_key:
160159
return cached
161-
cdef Kernel kernel = <Kernel>(self._kernel_weakref())
162-
if kernel is None:
163-
raise RuntimeError("Cannot access kernel attributes for expired Kernel object")
164160
cdef int result
165161
with nogil:
166-
HANDLE_RETURN(cydriver.cuKernelGetAttribute(&result, attribute, as_cu(kernel._h_kernel), device_id))
162+
HANDLE_RETURN(cydriver.cuKernelGetAttribute(&result, attribute, as_cu(self._h_kernel), device_id))
167163
self._cache[cache_key] = result
168164
return result
169165

@@ -496,10 +492,9 @@ cdef class Kernel:
496492
raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.")
497493

498494
@staticmethod
499-
cdef Kernel _from_obj(KernelHandle h_kernel, ObjectCode mod):
495+
cdef Kernel _from_obj(KernelHandle h_kernel):
500496
cdef Kernel ker = Kernel.__new__(Kernel)
501497
ker._h_kernel = h_kernel
502-
ker._module = mod
503498
ker._attributes = None
504499
ker._occupancy = None
505500
return ker
@@ -508,7 +503,7 @@ cdef class Kernel:
508503
def attributes(self) -> KernelAttributes:
509504
"""Get the read-only attributes of this kernel."""
510505
if self._attributes is None:
511-
self._attributes = KernelAttributes._init(self)
506+
self._attributes = KernelAttributes._init(self._h_kernel)
512507
return self._attributes
513508

514509
cdef tuple _get_arguments_info(self, bint param_info=False):
@@ -607,7 +602,7 @@ cdef class Kernel:
607602
if not h_kernel:
608603
HANDLE_RETURN(get_last_error())
609604

610-
return Kernel._from_obj(h_kernel, mod)
605+
return Kernel._from_obj(h_kernel)
611606

612607

613608
CodeTypeT = bytes | bytearray | str
@@ -812,10 +807,10 @@ cdef class ObjectCode:
812807
except KeyError:
813808
name = name.encode()
814809

815-
cdef KernelHandle h_kernel = get_kernel_from_library(self._h_library, <const char*>name)
810+
cdef KernelHandle h_kernel = create_kernel_handle(self._h_library, <const char*>name)
816811
if not h_kernel:
817812
HANDLE_RETURN(get_last_error())
818-
return Kernel._from_obj(h_kernel, self)
813+
return Kernel._from_obj(h_kernel)
819814

820815
@property
821816
def code(self) -> CodeTypeT:

cuda_core/cuda/core/_resource_handles.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,6 @@ cdef LibraryHandle create_library_handle_from_data(const void* data) nogil excep
109109
cdef LibraryHandle create_library_handle_ref(cydriver.CUlibrary library) nogil except+
110110

111111
# Kernel handles
112-
cdef KernelHandle get_kernel_from_library(LibraryHandle h_library, const char* name) nogil except+
112+
cdef KernelHandle create_kernel_handle(LibraryHandle h_library, const char* name) nogil except+
113113
cdef KernelHandle create_kernel_handle_ref(
114114
cydriver.CUkernel kernel, LibraryHandle h_library) nogil except+

cuda_core/cuda/core/_resource_handles.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
102102
cydriver.CUlibrary library) nogil except+
103103

104104
# Kernel handles
105-
KernelHandle get_kernel_from_library "cuda_core::get_kernel_from_library" (
105+
KernelHandle create_kernel_handle "cuda_core::create_kernel_handle" (
106106
LibraryHandle h_library, const char* name) nogil except+
107107
KernelHandle create_kernel_handle_ref "cuda_core::create_kernel_handle_ref" (
108108
cydriver.CUkernel kernel, LibraryHandle h_library) nogil except+

cuda_core/tests/test_memory.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,7 @@ def test_mempool_attributes_repr(memory_resource_factory):
11621162

11631163

11641164
def test_mempool_attributes_ownership(memory_resource_factory):
1165-
"""Ensure the attributes bundle handles references correctly for all memory resource types."""
1165+
"""Ensure the attributes bundle keeps the pool alive via the handle."""
11661166
MR, MRops = memory_resource_factory
11671167
device = Device()
11681168

@@ -1190,21 +1190,9 @@ def test_mempool_attributes_ownership(memory_resource_factory):
11901190
mr.close()
11911191
del mr
11921192

1193-
# After deleting the memory resource, the attributes suite is disconnected.
1194-
with pytest.raises(RuntimeError, match="is expired"):
1195-
_ = attributes.used_mem_high
1196-
1197-
# Even when a new object is created (we found a case where the same
1198-
# mempool handle was really reused).
1199-
if MR is DeviceMemoryResource:
1200-
mr = MR(device, dict(max_size=POOL_SIZE)) # noqa: F841
1201-
elif MR is PinnedMemoryResource:
1202-
mr = MR(dict(max_size=POOL_SIZE)) # noqa: F841
1203-
elif MR is ManagedMemoryResource:
1204-
mr = create_managed_memory_resource_or_skip(dict()) # noqa: F841
1205-
1206-
with pytest.raises(RuntimeError, match="is expired"):
1207-
_ = attributes.used_mem_high
1193+
# The attributes bundle keeps the pool alive via MemoryPoolHandle,
1194+
# so accessing attributes still works even after the MR is deleted.
1195+
_ = attributes.used_mem_high # Should not raise
12081196

12091197

12101198
# Ensure that memory views dellocate their reference to dlpack tensors

cuda_core/tests/test_module.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,6 @@ def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check):
229229
_ = krn.num_arguments
230230
return
231231

232-
# Check that arguments_info returns ParamInfo objects (works for both Python and Cython classes)
233-
# For Python classes: type(krn).arguments_info.fget.__annotations__ contains ParamInfo
234-
# For Cython cdef classes: property descriptors don't have .fget, so we check the actual values
235-
prop = type(krn).arguments_info
236-
if hasattr(prop, "fget") and hasattr(prop.fget, "__annotations__"):
237-
assert "ParamInfo" in str(prop.fget.__annotations__)
238232
arg_info = krn.arguments_info
239233
n_args = len(arg_info)
240234
assert n_args == krn.num_arguments

0 commit comments

Comments
 (0)