44
55from __future__ import annotations
66
7- from libc.stdint cimport uintptr_t
8-
9- from cuda.bindings cimport cydriver
10- from cuda.core._memory._buffer cimport Buffer, _MemAttrs, _init_mem_attrs, _query_memory_attrs
7+ from cuda.core._memory._buffer cimport Buffer, _init_mem_attrs
118from cuda.core._stream cimport Stream, Stream_accept
129
1310from cuda.core._utils.cuda_utils import driver, get_binding_version, handle_return
@@ -56,7 +53,6 @@ cdef dict _MANAGED_ADVICE_ALLOWED_LOCTYPES = {
5653 " unset_accessed_by" : _DEVICE_HOST_ONLY,
5754}
5855
59- cdef int _MANAGED_SIZE_NOT_PROVIDED = - 1
6056cdef int _HOST_NUMA_CURRENT_ID = 0
6157cdef int _FIRST_PREFETCH_LOCATION_INDEX = 0
6258cdef size_t _SINGLE_RANGE_COUNT = 1
@@ -241,71 +237,19 @@ cdef void _require_managed_discard_prefetch_support(str what):
241237 )
242238
243239
244- cdef tuple _managed_range_from_buffer(
245- Buffer buffer ,
246- int size,
247- str what,
248- ):
249- if size != _MANAGED_SIZE_NOT_PROVIDED:
250- raise TypeError (f" {what} does not accept size= when target is a Buffer" )
251- _require_managed_buffer(buffer , what)
252- return buffer .handle, buffer ._size
253-
254-
255- cdef uintptr_t _coerce_raw_pointer(object target, str what) except ? 0 :
256- cdef object ptr_obj
257- try :
258- ptr_obj = int (target)
259- except Exception as exc:
260- raise TypeError (
261- f" {what} target must be a Buffer or a raw pointer, got {type(target).__name__}"
262- ) from exc
263- if ptr_obj < 0 :
264- raise ValueError (f" {what} target pointer must be >= 0, got {target!r}" )
265- return < uintptr_t> ptr_obj
266-
267-
268- cdef int _require_managed_pointer(uintptr_t ptr, str what) except - 1 :
269- cdef _MemAttrs mem_attrs
270- with nogil:
271- _query_memory_attrs(mem_attrs, < cydriver.CUdeviceptr> ptr)
272- if not mem_attrs.is_managed:
273- raise ValueError (f" {what} requires a managed-memory allocation" )
274- return 0
275-
276-
277- cdef tuple _normalize_managed_target_range(
278- object target,
279- int size,
280- str what,
281- ):
282- cdef uintptr_t ptr
283-
284- if isinstance (target, Buffer):
285- return _managed_range_from_buffer(< Buffer> target, size, what)
286-
287- if size == _MANAGED_SIZE_NOT_PROVIDED:
288- raise TypeError (f" {what} requires size= when target is a raw pointer" )
289- ptr = _coerce_raw_pointer(target, what)
290- _require_managed_pointer(ptr, what)
291- return ptr, < size_t> size
292-
293-
294240def advise (
295- target ,
241+ target: Buffer ,
296242 advice: driver.CUmem_advise | str ,
297243 location: Device | int | None = None ,
298244 *,
299- int size = _MANAGED_SIZE_NOT_PROVIDED,
300245 location_type: str | None = None ,
301246):
302247 """ Apply managed-memory advice to an allocation range.
303248
304249 Parameters
305250 ----------
306- target : :class:`Buffer` | int | object
307- Managed allocation to operate on. This may be a :class:`Buffer` or a
308- raw pointer (requires ``size=``).
251+ target : :class:`Buffer`
252+ Managed allocation to operate on.
309253 advice : :obj:`~driver.CUmem_advise` | str
310254 Managed-memory advice to apply. String aliases such as
311255 ``"set_read_mostly"``, ``"set_preferred_location"``, and
@@ -314,17 +258,18 @@ def advise(
314258 Target location. When ``location_type`` is ``None``, values are
315259 interpreted as a device ordinal, ``-1`` for host, or ``None`` for
316260 advice values that ignore location.
317- size : int, optional
318- Allocation size in bytes. Required when ``target`` is a raw pointer.
319261 location_type : str | None, optional
320262 Explicit location kind. Supported values are ``"device"``, ``"host"``,
321263 ``"host_numa"``, and ``"host_numa_current"``.
322264 """
265+ if not isinstance (target, Buffer):
266+ raise TypeError (f" advise target must be a Buffer, got {type(target).__name__}" )
267+ cdef Buffer buf = < Buffer> target
268+ _require_managed_buffer(buf, " advise" )
323269 cdef str advice_name
324- cdef object ptr
325- cdef size_t nbytes
270+ cdef object ptr = buf.handle
271+ cdef size_t nbytes = buf._size
326272
327- ptr, nbytes = _normalize_managed_target_range(target, size, " advise" )
328273 advice_name, advice = _normalize_managed_advice(advice)
329274 location = _normalize_managed_location(
330275 location,
@@ -347,37 +292,36 @@ def advise(
347292
348293
349294def prefetch (
350- target ,
295+ target: Buffer ,
351296 location: Device | int | None = None ,
352297 *,
353298 stream: Stream | GraphBuilder ,
354- int size = _MANAGED_SIZE_NOT_PROVIDED,
355299 location_type: str | None = None ,
356300):
357301 """ Prefetch a managed-memory allocation range to a target location.
358302
359303 Parameters
360304 ----------
361- target : :class:`Buffer` | int | object
362- Managed allocation to operate on. This may be a :class:`Buffer` or a
363- raw pointer (requires ``size=``).
305+ target : :class:`Buffer`
306+ Managed allocation to operate on.
364307 location : :obj:`~_device.Device` | int | None, optional
365308 Target location. When ``location_type`` is ``None``, values are
366309 interpreted as a device ordinal, ``-1`` for host, or ``None``.
367310 A location is required for prefetch.
368311 stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
369312 Keyword argument specifying the stream for the asynchronous prefetch.
370- size : int, optional
371- Allocation size in bytes. Required when ``target`` is a raw pointer.
372313 location_type : str | None, optional
373314 Explicit location kind. Supported values are ``"device"``, ``"host"``,
374315 ``"host_numa"``, and ``"host_numa_current"``.
375316 """
317+ if not isinstance (target, Buffer):
318+ raise TypeError (f" prefetch target must be a Buffer, got {type(target).__name__}" )
319+ cdef Buffer buf = < Buffer> target
320+ _require_managed_buffer(buf, " prefetch" )
376321 cdef Stream s = Stream_accept(stream)
377- cdef object ptr
378- cdef size_t nbytes
322+ cdef object ptr = buf.handle
323+ cdef size_t nbytes = buf._size
379324
380- ptr, nbytes = _normalize_managed_target_range(target, size, " prefetch" )
381325 location = _normalize_managed_location(
382326 location,
383327 location_type,
@@ -405,40 +349,37 @@ def prefetch(
405349
406350
407351def discard_prefetch (
408- target ,
352+ target: Buffer ,
409353 location: Device | int | None = None ,
410354 *,
411355 stream: Stream | GraphBuilder ,
412- int size = _MANAGED_SIZE_NOT_PROVIDED,
413356 location_type: str | None = None ,
414357):
415358 """ Discard a managed-memory allocation range and prefetch it to a target location.
416359
417360 Parameters
418361 ----------
419- target : :class:`Buffer` | int | object
420- Managed allocation to operate on. This may be a :class:`Buffer` or a
421- raw pointer (requires ``size=``).
362+ target : :class:`Buffer`
363+ Managed allocation to operate on.
422364 location : :obj:`~_device.Device` | int | None, optional
423365 Target location. When ``location_type`` is ``None``, values are
424366 interpreted as a device ordinal, ``-1`` for host, or ``None``.
425367 A location is required for discard_prefetch.
426368 stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
427369 Keyword argument specifying the stream for the asynchronous operation.
428- size : int, optional
429- Allocation size in bytes. Required when ``target`` is a raw pointer.
430370 location_type : str | None, optional
431371 Explicit location kind. Supported values are ``"device"``, ``"host"``,
432372 ``"host_numa"``, and ``"host_numa_current"``.
433373 """
434- cdef object ptr
435- cdef object batch_ptr
436- cdef size_t nbytes
437-
438- ptr, nbytes = _normalize_managed_target_range(target, size, " discard_prefetch" )
374+ if not isinstance (target, Buffer):
375+ raise TypeError (f" discard_prefetch target must be a Buffer, got {type(target).__name__}" )
376+ cdef Buffer buf = < Buffer> target
377+ _require_managed_buffer(buf, " discard_prefetch" )
439378 _require_managed_discard_prefetch_support(" discard_prefetch" )
440379 cdef Stream s = Stream_accept(stream)
441- batch_ptr = driver.CUdeviceptr(int (ptr))
380+ cdef object ptr = buf.handle
381+ cdef size_t nbytes = buf._size
382+ cdef object batch_ptr = driver.CUdeviceptr(int (ptr))
442383 location = _normalize_managed_location(
443384 location,
444385 location_type,
0 commit comments