Skip to content

Commit faaa1d8

Browse files
committed
wip
1 parent b4d252c commit faaa1d8

2 files changed

Lines changed: 29 additions & 130 deletions

File tree

cuda_core/cuda/core/_memory/_managed_memory_ops.pyx

Lines changed: 29 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@
44

55
from __future__ import annotations
66

7-
from libc.stdint cimport uintptr_t
8-
9-
from cuda.bindings cimport cydriver
10-
from cuda.core._memory._buffer cimport Buffer, _MemAttrs, _init_mem_attrs, _query_memory_attrs
7+
from cuda.core._memory._buffer cimport Buffer, _init_mem_attrs
118
from cuda.core._stream cimport Stream, Stream_accept
129

1310
from cuda.core._utils.cuda_utils import driver, get_binding_version, handle_return
@@ -56,7 +53,6 @@ cdef dict _MANAGED_ADVICE_ALLOWED_LOCTYPES = {
5653
"unset_accessed_by": _DEVICE_HOST_ONLY,
5754
}
5855

59-
cdef int _MANAGED_SIZE_NOT_PROVIDED = -1
6056
cdef int _HOST_NUMA_CURRENT_ID = 0
6157
cdef int _FIRST_PREFETCH_LOCATION_INDEX = 0
6258
cdef size_t _SINGLE_RANGE_COUNT = 1
@@ -241,71 +237,19 @@ cdef void _require_managed_discard_prefetch_support(str what):
241237
)
242238

243239

244-
cdef tuple _managed_range_from_buffer(
245-
Buffer buffer,
246-
int size,
247-
str what,
248-
):
249-
if size != _MANAGED_SIZE_NOT_PROVIDED:
250-
raise TypeError(f"{what} does not accept size= when target is a Buffer")
251-
_require_managed_buffer(buffer, what)
252-
return buffer.handle, buffer._size
253-
254-
255-
cdef uintptr_t _coerce_raw_pointer(object target, str what) except? 0:
256-
cdef object ptr_obj
257-
try:
258-
ptr_obj = int(target)
259-
except Exception as exc:
260-
raise TypeError(
261-
f"{what} target must be a Buffer or a raw pointer, got {type(target).__name__}"
262-
) from exc
263-
if ptr_obj < 0:
264-
raise ValueError(f"{what} target pointer must be >= 0, got {target!r}")
265-
return <uintptr_t>ptr_obj
266-
267-
268-
cdef int _require_managed_pointer(uintptr_t ptr, str what) except -1:
269-
cdef _MemAttrs mem_attrs
270-
with nogil:
271-
_query_memory_attrs(mem_attrs, <cydriver.CUdeviceptr>ptr)
272-
if not mem_attrs.is_managed:
273-
raise ValueError(f"{what} requires a managed-memory allocation")
274-
return 0
275-
276-
277-
cdef tuple _normalize_managed_target_range(
278-
object target,
279-
int size,
280-
str what,
281-
):
282-
cdef uintptr_t ptr
283-
284-
if isinstance(target, Buffer):
285-
return _managed_range_from_buffer(<Buffer>target, size, what)
286-
287-
if size == _MANAGED_SIZE_NOT_PROVIDED:
288-
raise TypeError(f"{what} requires size= when target is a raw pointer")
289-
ptr = _coerce_raw_pointer(target, what)
290-
_require_managed_pointer(ptr, what)
291-
return ptr, <size_t>size
292-
293-
294240
def advise(
295-
target,
241+
target: Buffer,
296242
advice: driver.CUmem_advise | str,
297243
location: Device | int | None = None,
298244
*,
299-
int size=_MANAGED_SIZE_NOT_PROVIDED,
300245
location_type: str | None = None,
301246
):
302247
"""Apply managed-memory advice to an allocation range.
303248
304249
Parameters
305250
----------
306-
target : :class:`Buffer` | int | object
307-
Managed allocation to operate on. This may be a :class:`Buffer` or a
308-
raw pointer (requires ``size=``).
251+
target : :class:`Buffer`
252+
Managed allocation to operate on.
309253
advice : :obj:`~driver.CUmem_advise` | str
310254
Managed-memory advice to apply. String aliases such as
311255
``"set_read_mostly"``, ``"set_preferred_location"``, and
@@ -314,17 +258,18 @@ def advise(
314258
Target location. When ``location_type`` is ``None``, values are
315259
interpreted as a device ordinal, ``-1`` for host, or ``None`` for
316260
advice values that ignore location.
317-
size : int, optional
318-
Allocation size in bytes. Required when ``target`` is a raw pointer.
319261
location_type : str | None, optional
320262
Explicit location kind. Supported values are ``"device"``, ``"host"``,
321263
``"host_numa"``, and ``"host_numa_current"``.
322264
"""
265+
if not isinstance(target, Buffer):
266+
raise TypeError(f"advise target must be a Buffer, got {type(target).__name__}")
267+
cdef Buffer buf = <Buffer>target
268+
_require_managed_buffer(buf, "advise")
323269
cdef str advice_name
324-
cdef object ptr
325-
cdef size_t nbytes
270+
cdef object ptr = buf.handle
271+
cdef size_t nbytes = buf._size
326272

327-
ptr, nbytes = _normalize_managed_target_range(target, size, "advise")
328273
advice_name, advice = _normalize_managed_advice(advice)
329274
location = _normalize_managed_location(
330275
location,
@@ -347,37 +292,36 @@ def advise(
347292

348293

349294
def prefetch(
350-
target,
295+
target: Buffer,
351296
location: Device | int | None = None,
352297
*,
353298
stream: Stream | GraphBuilder,
354-
int size=_MANAGED_SIZE_NOT_PROVIDED,
355299
location_type: str | None = None,
356300
):
357301
"""Prefetch a managed-memory allocation range to a target location.
358302
359303
Parameters
360304
----------
361-
target : :class:`Buffer` | int | object
362-
Managed allocation to operate on. This may be a :class:`Buffer` or a
363-
raw pointer (requires ``size=``).
305+
target : :class:`Buffer`
306+
Managed allocation to operate on.
364307
location : :obj:`~_device.Device` | int | None, optional
365308
Target location. When ``location_type`` is ``None``, values are
366309
interpreted as a device ordinal, ``-1`` for host, or ``None``.
367310
A location is required for prefetch.
368311
stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
369312
Keyword argument specifying the stream for the asynchronous prefetch.
370-
size : int, optional
371-
Allocation size in bytes. Required when ``target`` is a raw pointer.
372313
location_type : str | None, optional
373314
Explicit location kind. Supported values are ``"device"``, ``"host"``,
374315
``"host_numa"``, and ``"host_numa_current"``.
375316
"""
317+
if not isinstance(target, Buffer):
318+
raise TypeError(f"prefetch target must be a Buffer, got {type(target).__name__}")
319+
cdef Buffer buf = <Buffer>target
320+
_require_managed_buffer(buf, "prefetch")
376321
cdef Stream s = Stream_accept(stream)
377-
cdef object ptr
378-
cdef size_t nbytes
322+
cdef object ptr = buf.handle
323+
cdef size_t nbytes = buf._size
379324

380-
ptr, nbytes = _normalize_managed_target_range(target, size, "prefetch")
381325
location = _normalize_managed_location(
382326
location,
383327
location_type,
@@ -405,40 +349,37 @@ def prefetch(
405349

406350

407351
def discard_prefetch(
408-
target,
352+
target: Buffer,
409353
location: Device | int | None = None,
410354
*,
411355
stream: Stream | GraphBuilder,
412-
int size=_MANAGED_SIZE_NOT_PROVIDED,
413356
location_type: str | None = None,
414357
):
415358
"""Discard a managed-memory allocation range and prefetch it to a target location.
416359
417360
Parameters
418361
----------
419-
target : :class:`Buffer` | int | object
420-
Managed allocation to operate on. This may be a :class:`Buffer` or a
421-
raw pointer (requires ``size=``).
362+
target : :class:`Buffer`
363+
Managed allocation to operate on.
422364
location : :obj:`~_device.Device` | int | None, optional
423365
Target location. When ``location_type`` is ``None``, values are
424366
interpreted as a device ordinal, ``-1`` for host, or ``None``.
425367
A location is required for discard_prefetch.
426368
stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
427369
Keyword argument specifying the stream for the asynchronous operation.
428-
size : int, optional
429-
Allocation size in bytes. Required when ``target`` is a raw pointer.
430370
location_type : str | None, optional
431371
Explicit location kind. Supported values are ``"device"``, ``"host"``,
432372
``"host_numa"``, and ``"host_numa_current"``.
433373
"""
434-
cdef object ptr
435-
cdef object batch_ptr
436-
cdef size_t nbytes
437-
438-
ptr, nbytes = _normalize_managed_target_range(target, size, "discard_prefetch")
374+
if not isinstance(target, Buffer):
375+
raise TypeError(f"discard_prefetch target must be a Buffer, got {type(target).__name__}")
376+
cdef Buffer buf = <Buffer>target
377+
_require_managed_buffer(buf, "discard_prefetch")
439378
_require_managed_discard_prefetch_support("discard_prefetch")
440379
cdef Stream s = Stream_accept(stream)
441-
batch_ptr = driver.CUdeviceptr(int(ptr))
380+
cdef object ptr = buf.handle
381+
cdef size_t nbytes = buf._size
382+
cdef object batch_ptr = driver.CUdeviceptr(int(ptr))
442383
location = _normalize_managed_location(
443384
location,
444385
location_type,

cuda_core/tests/test_memory.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,20 +1441,6 @@ def test_managed_memory_advise_accepts_enum_value(init_cuda):
14411441
buffer.close()
14421442

14431443

1444-
def test_managed_memory_advise_size_rejected_for_buffer(init_cuda):
1445-
"""advise() raises TypeError when size= is given with a Buffer target."""
1446-
device = Device()
1447-
_skip_if_managed_allocation_unsupported(device)
1448-
device.set_current()
1449-
1450-
buffer = DummyUnifiedMemoryResource(device).allocate(_MANAGED_TEST_ALLOCATION_SIZE)
1451-
1452-
with pytest.raises(TypeError, match="does not accept size="):
1453-
managed_memory.advise(buffer, "set_read_mostly", size=1024)
1454-
1455-
buffer.close()
1456-
1457-
14581444
def test_managed_memory_advise_invalid_advice_values(init_cuda):
14591445
"""advise() rejects invalid advice strings and wrong types."""
14601446
device = Device()
@@ -1472,34 +1458,6 @@ def test_managed_memory_advise_invalid_advice_values(init_cuda):
14721458
buffer.close()
14731459

14741460

1475-
def test_managed_memory_functions_accept_raw_pointer_ranges(init_cuda):
1476-
device = Device()
1477-
_skip_if_managed_location_ops_unsupported(device)
1478-
device.set_current()
1479-
1480-
buffer = DummyUnifiedMemoryResource(device).allocate(_MANAGED_TEST_ALLOCATION_SIZE)
1481-
stream = device.create_stream()
1482-
1483-
managed_memory.advise(buffer.handle, "set_read_mostly", size=buffer.size)
1484-
assert (
1485-
_get_int_mem_range_attr(
1486-
buffer,
1487-
driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY,
1488-
)
1489-
== _READ_MOSTLY_ENABLED
1490-
)
1491-
1492-
managed_memory.prefetch(buffer.handle, device, size=buffer.size, stream=stream)
1493-
stream.sync()
1494-
last_location = _get_int_mem_range_attr(
1495-
buffer,
1496-
driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION,
1497-
)
1498-
assert last_location == device.device_id
1499-
1500-
buffer.close()
1501-
1502-
15031461
def test_managed_memory_resource_host_numa_auto_resolve_failure(init_cuda):
15041462
"""host_numa with None raises RuntimeError when NUMA ID cannot be determined."""
15051463
from unittest.mock import MagicMock, patch

0 commit comments

Comments
 (0)