Skip to content

Commit d75a7bd

Browse files
rparolinclaude
andcommitted
feat(cuda.core): cu12 fallback for prefetch_batch (N3)
Per Leo's review on PR #1775 (_managed_memory_ops.pyx:228), raising NotImplementedError on cu12 forces users to write their own loop. The CUDA driver semantics for cuMemPrefetchBatchAsync are equivalent to per-range cuMemPrefetchAsync calls — just more efficient when batched at the driver level. On cu12 builds (where cuMemPrefetchBatchAsync is not exposed), fall back to a Python-level loop calling cuMemPrefetchAsync per buffer. The single-range path (_do_single_prefetch) already works on cu12 via the IF/ELSE split inside it. Note this fallback applies only to prefetch_batch — discard_batch and discard_prefetch_batch keep the cu12 NotImplementedError because the driver has no single-range cuMemDiscard{,AndPrefetch}Async to fall back to. Test skips for cuMemPrefetchBatchAsync unavailability dropped from TestPrefetchBatch.test_same_location and test_per_buffer_location; the fallback path now runs on cu12 builds too. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a9cd713 commit d75a7bd

2 files changed

Lines changed: 12 additions & 11 deletions

File tree

cuda_core/cuda/core/_memory/_managed_memory_ops.pyx

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,11 @@ def prefetch_batch(buffers, locations, *, stream):
280280
stream : :class:`~_stream.Stream` | :class:`~graph.GraphBuilder`
281281
Stream for the asynchronous prefetch (keyword-only).
282282
283-
Raises
284-
------
285-
NotImplementedError
286-
On a CUDA 12 build of ``cuda.core``.
283+
Notes
284+
-----
285+
On a CUDA 12 build, falls back to a Python-level loop calling
286+
``cuMemPrefetchAsync`` per buffer (no batched driver entry point on
287+
CUDA 12). CUDA 13 builds use ``cuMemPrefetchBatchAsync`` directly.
287288
"""
288289
cdef tuple bufs = _coerce_batch_buffers(buffers, "prefetch_batch")
289290
cdef Py_ssize_t n = len(bufs)
@@ -364,9 +365,13 @@ cdef void _do_batch_prefetch(tuple bufs, tuple locs, Stream s):
364365
IF CUDA_CORE_BUILD_MAJOR >= 13:
365366
_do_batch_prefetch_op(bufs, locs, s, cydriver.cuMemPrefetchBatchAsync)
366367
ELSE:
367-
raise NotImplementedError(
368-
"batched prefetch requires a CUDA 13 build of cuda.core"
369-
)
368+
# cu12 has no cuMemPrefetchBatchAsync; loop per-range.
369+
cdef Buffer buf
370+
cdef Py_ssize_t i
371+
cdef Py_ssize_t n = len(bufs)
372+
for i in range(n):
373+
buf = <Buffer>bufs[i]
374+
_do_single_prefetch(buf, locs[i], s)
370375

371376

372377
def discard_prefetch_batch(buffers, locations, *, stream):

cuda_core/tests/memory/test_managed_ops.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,6 @@ def test_same_location(self, init_cuda):
349349

350350
device = Device()
351351
skip_if_managed_memory_unsupported(device)
352-
if not hasattr(driver, "cuMemPrefetchBatchAsync"):
353-
pytest.skip("cuMemPrefetchBatchAsync unavailable")
354352
device.set_current()
355353
mr = create_managed_memory_resource_or_skip()
356354
bufs = [mr.allocate(_MANAGED_TEST_ALLOCATION_SIZE) for _ in range(3)]
@@ -372,8 +370,6 @@ def test_per_buffer_location(self, init_cuda):
372370

373371
device = Device()
374372
skip_if_managed_memory_unsupported(device)
375-
if not hasattr(driver, "cuMemPrefetchBatchAsync"):
376-
pytest.skip("cuMemPrefetchBatchAsync unavailable")
377373
device.set_current()
378374
mr = create_managed_memory_resource_or_skip()
379375
bufs = [mr.allocate(_MANAGED_TEST_ALLOCATION_SIZE) for _ in range(2)]

0 commit comments

Comments
 (0)