diff --git a/av/_hwdevice_registry.py b/av/_hwdevice_registry.py new file mode 100644 index 000000000..7557a6a02 --- /dev/null +++ b/av/_hwdevice_registry.py @@ -0,0 +1,12 @@ +_cuda_hwdevice_data_ptr_to_device_id: dict[int, int] = {} + + +def register_cuda_hwdevice_data_ptr(hwdevice_data_ptr: int, device_id: int) -> None: + if hwdevice_data_ptr: + _cuda_hwdevice_data_ptr_to_device_id[int(hwdevice_data_ptr)] = int(device_id) + + +def lookup_cuda_device_id(hwdevice_data_ptr: int) -> int: + if not hwdevice_data_ptr: + return 0 + return _cuda_hwdevice_data_ptr_to_device_id.get(int(hwdevice_data_ptr), 0) diff --git a/av/codec/hwaccel.pxd b/av/codec/hwaccel.pxd index 46efdaf3b..41a6bdc86 100644 --- a/av/codec/hwaccel.pxd +++ b/av/codec/hwaccel.pxd @@ -19,3 +19,4 @@ cdef class HWAccel: cdef public bint allow_software_fallback cdef public dict options cdef public int flags + cdef public str output_format diff --git a/av/codec/hwaccel.py b/av/codec/hwaccel.py index ffa196e48..e717bb070 100644 --- a/av/codec/hwaccel.py +++ b/av/codec/hwaccel.py @@ -8,6 +8,8 @@ from cython.cimports.av.error import err_check from cython.cimports.av.video.format import get_video_format +import av._hwdevice_registry as _hwreg + class HWDeviceType(IntEnum): none = lib.AV_HWDEVICE_TYPE_NONE @@ -110,7 +112,11 @@ def __init__( allow_software_fallback=True, options=None, flags=None, + output_format=None, ): + if isinstance(device, int): + device = str(device) + if isinstance(device_type, HWDeviceType): self._device_type = device_type elif isinstance(device_type, str): @@ -120,9 +126,20 @@ def __init__( else: raise ValueError("Unknown type for device_type") + if output_format is None: + output_format = "sw" + if isinstance(output_format, str): + output_format = output_format.lower() + if output_format not in {"sw", "hw"}: + raise ValueError("output_format must be 'sw' or 'hw'") + self.output_format = output_format + self._device = device self.allow_software_fallback = allow_software_fallback + 
self.options = {} if not options else dict(options) + if self._device_type == HWDeviceType.cuda and self.output_format == "hw": + self.options.setdefault("primary_ctx", "1") self.flags = 0 if not flags else flags self.ptr = cython.NULL self.config = None @@ -155,6 +172,19 @@ def _initialize_hw_context(self, codec: Codec): ) ) + if config.ptr.device_type == lib.AV_HWDEVICE_TYPE_CUDA: + device_id = 0 + if self._device: + try: + device_id = int(self._device) + except ValueError: + device_id = 0 + + _hwreg.register_cuda_hwdevice_data_ptr( + cython.cast(cython.size_t, self.ptr.data), + device_id, + ) + def create(self, codec: Codec): """Create a new hardware accelerator context with the given codec""" if self.ptr: @@ -165,6 +195,7 @@ def create(self, codec: Codec): device=self._device, allow_software_fallback=self.allow_software_fallback, options=self.options, + output_format=self.output_format, ) ret._initialize_hw_context(codec) return ret diff --git a/av/codec/hwaccel.pyi b/av/codec/hwaccel.pyi index 8bdc0a6e0..549cc7eaf 100644 --- a/av/codec/hwaccel.pyi +++ b/av/codec/hwaccel.pyi @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import cast +from typing import Literal, cast from av.codec.codec import Codec from av.video.format import VideoFormat @@ -37,13 +37,17 @@ class HWConfig: def is_supported(self) -> bool: ... class HWAccel: + output_format: Literal["sw", "hw"] + options: dict[str, object] + def __init__( self, device_type: str | HWDeviceType, - device: str | None = None, + device: str | int | None = None, allow_software_fallback: bool = False, options: dict[str, object] | None = None, flags: int | None = None, + output_format: Literal["sw", "hw"] | None = None, ) -> None: ... def create(self, codec: Codec) -> HWAccel: ... 
diff --git a/av/video/codeccontext.py b/av/video/codeccontext.py index b7842a7d5..aa7c49560 100644 --- a/av/video/codeccontext.py +++ b/av/video/codeccontext.py @@ -127,6 +127,9 @@ def _transfer_hwframe(self, frame: Frame): # need to transfer. return frame + if self.hwaccel_ctx.output_format == "hw": + return frame + frame_sw: Frame = self._alloc_next_frame() err_check(lib.av_hwframe_transfer_data(frame_sw.ptr, frame.ptr, 0)) # TODO: Is there anything else to transfer? diff --git a/av/video/frame.py b/av/video/frame.py index 1320afad3..7914b1aa6 100644 --- a/av/video/frame.py +++ b/av/video/frame.py @@ -2,12 +2,158 @@ from enum import IntEnum import cython +import cython.cimports.libav as lib +from cython.cimports.av.dictionary import Dictionary from cython.cimports.av.error import err_check from cython.cimports.av.sidedata.sidedata import get_display_rotation from cython.cimports.av.utils import check_ndarray from cython.cimports.av.video.format import get_pix_fmt, get_video_format from cython.cimports.av.video.plane import VideoPlane -from cython.cimports.libc.stdint import uint8_t +from cython.cimports.cpython.exc import PyErr_Clear +from cython.cimports.cpython.pycapsule import ( + PyCapsule_GetPointer, + PyCapsule_IsValid, + PyCapsule_SetName, +) +from cython.cimports.cpython.ref import Py_DECREF, Py_INCREF, PyObject +from cython.cimports.dlpack import DLManagedTensor, kDLCPU, kDLCUDA, kDLUInt +from cython.cimports.libc.stdint import int64_t, uint8_t + +import av._hwdevice_registry as _hwreg + +_cuda_device_ctx_cache = {} +_cuda_frames_ctx_cache = {} + + +@cython.cfunc +def _consume_dlpack(obj: object, stream: object) -> cython.pointer[DLManagedTensor]: + capsule: object + managed: cython.pointer[DLManagedTensor] + + if hasattr(obj, "__dlpack__"): + capsule = obj.__dlpack__() if stream is None else obj.__dlpack__(stream=stream) + else: + capsule = obj + + if not PyCapsule_IsValid(capsule, b"dltensor"): + PyErr_Clear() + raise TypeError( + "expected a DLPack 
capsule or an object implementing __dlpack__" + ) + + managed = cython.cast( + cython.pointer[DLManagedTensor], + PyCapsule_GetPointer(capsule, b"dltensor"), + ) + if managed == cython.NULL: + raise ValueError("PyCapsule_GetPointer returned NULL") + + if PyCapsule_SetName(capsule, b"used_dltensor") != 0: + raise RuntimeError("PyCapsule_SetName failed") + + return managed + + +@cython.cfunc +@cython.nogil +@cython.exceptval(check=False) +def _dlpack_avbuffer_free( + opaque: cython.p_void, + data: cython.pointer[uint8_t], +) -> cython.void: + managed: cython.pointer[DLManagedTensor] = cython.cast( + cython.pointer[DLManagedTensor], opaque + ) + if managed != cython.NULL: + managed.deleter(managed) + + +@cython.cfunc +@cython.nogil +@cython.exceptval(check=False) +def _numpy_avbuffer_free( + opaque: cython.p_void, + data: cython.pointer[uint8_t], +) -> cython.void: + if opaque != cython.NULL: + with cython.gil: + Py_DECREF(cython.cast(object, opaque)) + + +@cython.cfunc +def _get_cuda_device_ctx( + device_id: cython.int, + primary_ctx: cython.bint, +) -> cython.pointer[lib.AVBufferRef]: + key = (int(device_id), int(primary_ctx)) + cached = _cuda_device_ctx_cache.get(key) + if cached is not None: + return cython.cast( + cython.pointer[lib.AVBufferRef], + cython.cast(cython.size_t, cached), + ) + + device_ref: cython.pointer[lib.AVBufferRef] = cython.NULL + device_bytes = str(device_id).encode() + c_device: cython.p_char = device_bytes + options: Dictionary = Dictionary({"primary_ctx": "1" if primary_ctx else "0"}) + + err_check( + lib.av_hwdevice_ctx_create( + cython.address(device_ref), + lib.AV_HWDEVICE_TYPE_CUDA, + c_device, + options.ptr, + 0, + ) + ) + + _hwreg.register_cuda_hwdevice_data_ptr( + cython.cast(cython.size_t, device_ref.data), + device_id, + ) + + _cuda_device_ctx_cache[key] = cython.cast(cython.size_t, device_ref) + return device_ref + + +@cython.cfunc +def _get_cuda_frames_ctx( + device_id: cython.int, + primary_ctx: cython.bint, + sw_fmt: 
lib.AVPixelFormat, + width: cython.int, + height: cython.int, +) -> cython.pointer[lib.AVBufferRef]: + key = (int(device_id), int(primary_ctx), int(sw_fmt), int(width), int(height)) + cached = _cuda_frames_ctx_cache.get(key) + if cached is not None: + return cython.cast( + cython.pointer[lib.AVBufferRef], + cython.cast(cython.size_t, cached), + ) + + device_ref = _get_cuda_device_ctx(device_id, primary_ctx) + frames_ref = lib.av_hwframe_ctx_alloc(device_ref) + if frames_ref == cython.NULL: + raise MemoryError("av_hwframe_ctx_alloc() failed") + + try: + frames_ctx: cython.pointer[lib.AVHWFramesContext] = cython.cast( + cython.pointer[lib.AVHWFramesContext], frames_ref.data + ) + frames_ctx.format = get_pix_fmt(b"cuda") + frames_ctx.sw_format = sw_fmt + frames_ctx.width = width + frames_ctx.height = height + err_check(lib.av_hwframe_ctx_init(frames_ref)) + except Exception: + lib.av_buffer_unref(cython.address(frames_ref)) + raise + + _cuda_frames_ctx_cache[key] = cython.cast(cython.size_t, frames_ref) + return frames_ref + _cinit_bypass_sentinel = object() @@ -222,10 +368,7 @@ def _init(self, format: lib.AVPixelFormat, width: cython.uint, height: cython.ui # We enforce aligned buffers, otherwise `sws_scale` can perform # poorly or even cause out-of-bounds reads and writes. if width and height: - res = lib.av_image_alloc( - self.ptr.data, self.ptr.linesize, width, height, format, 16 - ) - self._buffer = self.ptr.data[0] + res = lib.av_frame_get_buffer(self.ptr, 16) if res: err_check(res) @@ -243,7 +386,7 @@ def _init_user_attributes(self): def __dealloc__(self): # The `self._buffer` member is only set if *we* allocated the buffer in `_init`, # as opposed to a buffer allocated by a decoder. 
- lib.av_freep(cython.address(self._buffer)) + lib.av_frame_unref(self.ptr) # Let go of the reference from the numpy buffers if we made one self._np_buffer = None @@ -261,12 +404,21 @@ def planes(self): # We need to detect which planes actually exist, but also constrain ourselves to # the maximum plane count (as determined only by VideoFrames so far), in case # the library implementation does not set the last plane to NULL. + fmt = self.format + if self.ptr.hw_frames_ctx: + frames_ctx: cython.pointer[lib.AVHWFramesContext] = cython.cast( + cython.pointer[lib.AVHWFramesContext], self.ptr.hw_frames_ctx.data + ) + fmt = get_video_format( + frames_ctx.sw_format, self.ptr.width, self.ptr.height + ) + max_plane_count: cython.int = 0 - for i in range(self.format.ptr.nb_components): - count = self.format.ptr.comp[i].plane + 1 + for i in range(fmt.ptr.nb_components): + count = fmt.ptr.comp[i].plane + 1 if max_plane_count < count: max_plane_count = count - if self.format.name == "pal8": + if fmt.name == "pal8": max_plane_count = 2 plane_count: cython.int = 0 @@ -446,7 +598,21 @@ def to_ndarray(self, channel_last=False, **kwargs): .. note:: For ``gbrp`` formats, channels are flipped to RGB order. """ - frame: VideoFrame = self.reformat(**kwargs) + kwargs2 = dict(kwargs) + if self.ptr.hw_frames_ctx and "format" not in kwargs2: + frames_ctx: cython.pointer[lib.AVHWFramesContext] = cython.cast( + cython.pointer[lib.AVHWFramesContext], self.ptr.hw_frames_ctx.data + ) + kwargs2["format"] = get_video_format( + frames_ctx.sw_format, self.ptr.width, self.ptr.height + ).name + + frame: VideoFrame = self.reformat(**kwargs2) + if frame.ptr.hw_frames_ctx: + raise ValueError( + "Cannot convert a hardware frame to numpy directly. " + "Specify a software format (e.g. format='rgb24') or decode with HWAccel(output_format='sw')." 
+ ) import numpy as np @@ -859,9 +1025,9 @@ def from_numpy_buffer(array, format="rgb24", width=0): return frame def _image_fill_pointers_numpy(self, buffer, width, height, linesizes, format): - c_format: lib.AVPixelFormat - c_ptr: cython.pointer[uint8_t] - c_data: cython.size_t + c_data: cython.size_t = buffer.ctypes.data + c_ptr: cython.pointer[uint8_t] = cython.cast(cython.pointer[uint8_t], c_data) + c_format: lib.AVPixelFormat = get_pix_fmt(format) # If you want to use the numpy notation, then you need to include the following lines at the top of the file: # cimport numpy as cnp @@ -882,26 +1048,41 @@ def _image_fill_pointers_numpy(self, buffer, width, height, linesizes, format): c_data = buffer.ctypes.data c_ptr = cython.cast(cython.pointer[uint8_t], c_data) c_format = get_pix_fmt(format) - lib.av_freep(cython.address(self._buffer)) + lib.av_frame_unref(self.ptr) + self._np_buffer = None # Hold on to a reference for the numpy buffer so that it doesn't get accidentally garbage collected - self._np_buffer = buffer self.ptr.format = c_format self.ptr.width = width self.ptr.height = height for i, linesize in enumerate(linesizes): self.ptr.linesize[i] = linesize - res = lib.av_image_fill_pointers( - self.ptr.data, - cython.cast(lib.AVPixelFormat, self.ptr.format), - self.ptr.height, + required = err_check( + lib.av_image_fill_pointers( + self.ptr.data, + cython.cast(lib.AVPixelFormat, self.ptr.format), + self.ptr.height, + c_ptr, + self.ptr.linesize, + ) + ) + + py_buf = cython.cast(object, buffer) + Py_INCREF(py_buf) + + self.ptr.buf[0] = lib.av_buffer_create( c_ptr, - self.ptr.linesize, + required, + _numpy_avbuffer_free, + cython.cast(cython.p_void, py_buf), + 0, ) + if self.ptr.buf[0] == cython.NULL: + Py_DECREF(py_buf) + raise MemoryError("av_buffer_create failed") - if res: - err_check(res) + self._np_buffer = buffer self._init_user_attributes() @staticmethod @@ -1179,3 +1360,209 @@ def from_bytes( else: raise NotImplementedError(f"Format '{format}' is not 
supported.") return frame + + @staticmethod + def from_dlpack( + planes, + format: str = "nv12", + width: int = 0, + height: int = 0, + stream=None, + device_id: int | None = None, + primary_ctx: bool = True, + ): + if not isinstance(planes, (tuple, list)): + planes = (planes,) + + if len(planes) != 2: + raise ValueError( + "from_dlpack currently supports 2-plane formats only (nv12/p010le/p016le)" + ) + + sw_fmt: lib.AVPixelFormat = get_pix_fmt(format) + nv12 = get_pix_fmt(b"nv12") + p010le = get_pix_fmt(b"p010le") + p016le = get_pix_fmt(b"p016le") + + if sw_fmt not in {nv12, p010le, p016le}: + raise NotImplementedError("from_dlpack supports nv12, p010le, p016le only") + + expected_bits = 8 if sw_fmt == nv12 else 16 + itemsize = 1 if expected_bits == 8 else 2 + + m0: cython.pointer[DLManagedTensor] = cython.NULL + m1: cython.pointer[DLManagedTensor] = cython.NULL + frame: VideoFrame = None + + try: + m0 = _consume_dlpack(planes[0], stream) + m1 = _consume_dlpack(planes[1], stream) + + dev_type0 = m0.dl_tensor.device.device_type + dev_type1 = m1.dl_tensor.device.device_type + if dev_type0 != dev_type1: + raise ValueError("plane tensors must have the same device_type") + if dev_type0 not in {kDLCUDA, kDLCPU}: + raise NotImplementedError( + "only CPU and CUDA DLPack tensors are supported" + ) + + dev0 = m0.dl_tensor.device.device_id + dev1 = m1.dl_tensor.device.device_id + if dev0 != dev1: + raise ValueError("plane tensors must be on the same CUDA device") + if dev_type0 == kDLCUDA: + if dev0 != dev1: + raise ValueError("plane tensors must be on the same CUDA device") + if device_id is None: + device_id = dev0 + elif device_id != dev0: + raise ValueError( + "device_id does not match the DLPack tensor device_id" + ) + else: + if device_id not in (None, 0): + raise ValueError("device_id must be 0 for CPU tensors") + device_id = 0 + if dev_type0 == kDLCPU and (dev0 != 0 or dev1 != 0): + raise ValueError("CPU DLPack tensors must have device_id == 0") + + if ( + 
m0.dl_tensor.dtype.code != kDLUInt + or m0.dl_tensor.dtype.bits != expected_bits + or m0.dl_tensor.dtype.lanes != 1 + ): + raise TypeError("unexpected dtype for plane 0") + + if ( + m1.dl_tensor.dtype.code != kDLUInt + or m1.dl_tensor.dtype.bits != expected_bits + or m1.dl_tensor.dtype.lanes != 1 + ): + raise TypeError("unexpected dtype for plane 1") + + if m0.dl_tensor.ndim != 2: + raise ValueError("plane 0 must be 2D (H, W)") + + y_h = cython.cast(int64_t, m0.dl_tensor.shape[0]) + y_w = cython.cast(int64_t, m0.dl_tensor.shape[1]) + + if width == 0 and height == 0: + width = cython.cast(int, y_w) + height = cython.cast(int, y_h) + elif width == 0 or height == 0: + raise ValueError("either specify both width/height or neither") + else: + if y_w != width or y_h != height: + raise ValueError("plane 0 shape does not match width/height") + + if width % 2 or height % 2: + raise ValueError("width/height must be even for nv12/p010le/p016le") + + if m0.dl_tensor.strides != cython.NULL: + if m0.dl_tensor.strides[1] != 1: + raise ValueError("plane 0 must be contiguous in the last dimension") + y_pitch_elems = cython.cast(int64_t, m0.dl_tensor.strides[0]) + else: + y_pitch_elems = cython.cast(int64_t, width) + + y_linesize = cython.cast(int, y_pitch_elems * itemsize) + y_size = cython.cast(int, y_linesize * height) + + uv_ndim = m1.dl_tensor.ndim + uv_h_expected = height // 2 + + if uv_ndim == 2: + uv_h = cython.cast(int, m1.dl_tensor.shape[0]) + uv_w = cython.cast(int, m1.dl_tensor.shape[1]) + if uv_h != uv_h_expected or uv_w != width: + raise ValueError("plane 1 must have shape (H/2, W) for 2D UV") + if m1.dl_tensor.strides != cython.NULL: + if m1.dl_tensor.strides[1] != 1: + raise ValueError( + "plane 1 must be contiguous in the last dimension" + ) + uv_pitch_elems = cython.cast(int64_t, m1.dl_tensor.strides[0]) + else: + uv_pitch_elems = cython.cast(int64_t, uv_w) + elif uv_ndim == 3: + uv_h = cython.cast(int, m1.dl_tensor.shape[0]) + uv_w2 = cython.cast(int, 
m1.dl_tensor.shape[1]) + uv_c = cython.cast(int, m1.dl_tensor.shape[2]) + if uv_h != uv_h_expected or uv_w2 != (width // 2) or uv_c != 2: + raise ValueError("plane 1 must have shape (H/2, W/2, 2) for 3D UV") + if m1.dl_tensor.strides != cython.NULL: + if m1.dl_tensor.strides[2] != 1 or m1.dl_tensor.strides[1] != 2: + raise ValueError( + "unexpected UV plane strides for (H/2, W/2, 2)" + ) + uv_pitch_elems = cython.cast(int64_t, m1.dl_tensor.strides[0]) + else: + uv_pitch_elems = cython.cast(int64_t, width) + else: + raise ValueError("plane 1 must be 2D or 3D") + + uv_linesize = cython.cast(int, uv_pitch_elems * itemsize) + uv_size = cython.cast(int, uv_linesize * (height // 2)) + + frame = alloc_video_frame() + frame.ptr.width = width + frame.ptr.height = height + if dev_type0 == kDLCUDA: + if primary_ctx is None: + primary_ctx = True + if not isinstance(primary_ctx, (bool, int)): + raise TypeError("primary_ctx must be a bool") + primary_ctx = bool(primary_ctx) + + frames_ref = _get_cuda_frames_ctx( + device_id, primary_ctx, sw_fmt, width, height + ) + + frame.ptr.format = get_pix_fmt(b"cuda") + frame.ptr.hw_frames_ctx = lib.av_buffer_ref(frames_ref) + if frame.ptr.hw_frames_ctx == cython.NULL: + raise MemoryError("av_buffer_ref(hw_frames_ctx) failed") + else: + frame.ptr.format = sw_fmt + + y_ptr = cython.cast( + cython.pointer[uint8_t], m0.dl_tensor.data + ) + cython.cast(cython.size_t, m0.dl_tensor.byte_offset) + uv_ptr = cython.cast( + cython.pointer[uint8_t], m1.dl_tensor.data + ) + cython.cast(cython.size_t, m1.dl_tensor.byte_offset) + + frame.ptr.buf[0] = lib.av_buffer_create( + y_ptr, y_size, _dlpack_avbuffer_free, cython.cast(cython.p_void, m0), 0 + ) + if frame.ptr.buf[0] == cython.NULL: + raise MemoryError("av_buffer_create failed for plane 0") + frame.ptr.data[0] = y_ptr + frame.ptr.linesize[0] = y_linesize + m0 = cython.NULL + + frame.ptr.buf[1] = lib.av_buffer_create( + uv_ptr, + uv_size, + _dlpack_avbuffer_free, + cython.cast(cython.p_void, m1), + 0, 
+ ) + if frame.ptr.buf[1] == cython.NULL: + raise MemoryError("av_buffer_create failed for plane 1") + frame.ptr.data[1] = uv_ptr + frame.ptr.linesize[1] = uv_linesize + m1 = cython.NULL + + frame._init_user_attributes() + return frame + + except Exception: + if frame is not None: + lib.av_frame_unref(frame.ptr) + if m0 != cython.NULL: + m0.deleter(m0) + if m1 != cython.NULL: + m1.deleter(m1) + raise diff --git a/av/video/frame.pyi b/av/video/frame.pyi index a7575e3bd..0102b1472 100644 --- a/av/video/frame.pyi +++ b/av/video/frame.pyi @@ -84,3 +84,13 @@ class VideoFrame(Frame): flip_horizontal: bool = False, flip_vertical: bool = False, ) -> VideoFrame: ... + @staticmethod + def from_dlpack( + planes: object | tuple[object, ...], + format: str = "nv12", + width: int = 0, + height: int = 0, + stream: int | None = None, + device_id: int | None = None, + primary_ctx: bool = True, + ) -> "VideoFrame": ... diff --git a/av/video/plane.py b/av/video/plane.py index 495a9de4c..771ffe031 100644 --- a/av/video/plane.py +++ b/av/video/plane.py @@ -1,25 +1,53 @@ +from typing import Any + import cython +import cython.cimports.libav as lib +from cython.cimports.av.buffer import Buffer +from cython.cimports.av.error import err_check +from cython.cimports.av.video.format import get_pix_fmt, get_video_format from cython.cimports.av.video.frame import VideoFrame +from cython.cimports.cpython import PyBUF_WRITABLE, PyBuffer_FillInfo +from cython.cimports.cpython.buffer import Py_buffer +from cython.cimports.cpython.pycapsule import ( + PyCapsule_GetPointer, + PyCapsule_IsValid, + PyCapsule_New, +) +from cython.cimports.cpython.ref import PyObject +from cython.cimports.dlpack import DLManagedTensor, kDLCPU, kDLCUDA, kDLUInt +from cython.cimports.libc.stdint import int64_t +from cython.cimports.libc.stdlib import free, malloc + +import av._hwdevice_registry as _hwreg @cython.cclass class VideoPlane(Plane): def __cinit__(self, frame: VideoFrame, index: cython.int): # The palette plane 
has no associated component or linesize; set fields manually - if frame.format.name == "pal8" and index == 1: + fmt = frame.format + if frame.ptr.hw_frames_ctx: + frames_ctx: cython.pointer[lib.AVHWFramesContext] = cython.cast( + cython.pointer[lib.AVHWFramesContext], frame.ptr.hw_frames_ctx.data + ) + fmt = get_video_format( + frames_ctx.sw_format, frame.ptr.width, frame.ptr.height + ) + + if fmt.name == "pal8" and index == 1: self.width = 256 self.height = 1 self.buffer_size = 256 * 4 return - for i in range(frame.format.ptr.nb_components): - if frame.format.ptr.comp[i].plane == index: - component = frame.format.components[i] + for i in range(fmt.ptr.nb_components): + if fmt.ptr.comp[i].plane == index: + component = fmt.components[i] self.width = component.width self.height = component.height break - else: # nobreak - raise RuntimeError(f"could not find plane {index} of {frame.format!r}") + else: + raise RuntimeError(f"could not find plane {index} of {fmt!r}") # Sometimes, linesize is negative (and that is meaningful). We are only # insisting that the buffer size be based on the extent of linesize, and @@ -38,3 +66,257 @@ def line_size(self): :type: int """ return self.frame.ptr.linesize[self.index] + + @cython.cfunc + def _buffer_writable(self) -> cython.bint: + if self.frame.ptr.hw_frames_ctx: + return False + return True + + def __getbuffer__(self, view: cython.pointer[Py_buffer], flags: cython.int): + if self.frame.ptr.hw_frames_ctx: + raise TypeError( + "Hardware frame planes do not support the Python buffer protocol. " + "Use DLPack (__dlpack__) or download to a software frame." 
+ ) + if flags & PyBUF_WRITABLE and not self._buffer_writable(): + raise ValueError("buffer is not writable") + PyBuffer_FillInfo(view, self, self._buffer_ptr(), self._buffer_size(), 0, flags) + + def __dlpack_device__(self): + if self.frame.ptr.hw_frames_ctx: + if cython.cast(lib.AVPixelFormat, self.frame.ptr.format) != get_pix_fmt( + b"cuda" + ): + raise NotImplementedError( + "DLPack export is only implemented for CUDA hw frames" + ) + + frames_ctx: cython.pointer[lib.AVHWFramesContext] = cython.cast( + cython.pointer[lib.AVHWFramesContext], self.frame.ptr.hw_frames_ctx.data + ) + device_id = _hwreg.lookup_cuda_device_id( + cython.cast(cython.size_t, frames_ctx.device_ref.data) + ) + return (kDLCUDA, device_id) + + return (kDLCPU, 0) + + def __dlpack__(self, stream: int | Any | None = None): + if self.frame.ptr.buf[0] == cython.NULL: + raise TypeError( + "DLPack export requires a refcounted AVFrame (frame.buf[0] is NULL)" + ) + + device_type: cython.int + device_id: cython.int + sw_fmt: lib.AVPixelFormat + + if self.frame.ptr.hw_frames_ctx: + if cython.cast(lib.AVPixelFormat, self.frame.ptr.format) != get_pix_fmt( + b"cuda" + ): + raise NotImplementedError( + "DLPack export is only implemented for CUDA hw frames" + ) + + frames_ctx: cython.pointer[lib.AVHWFramesContext] = cython.cast( + cython.pointer[lib.AVHWFramesContext], self.frame.ptr.hw_frames_ctx.data + ) + sw_fmt = frames_ctx.sw_format + device_type = kDLCUDA + device_id = _hwreg.lookup_cuda_device_id( + cython.cast(cython.size_t, frames_ctx.device_ref.data) + ) + else: + sw_fmt = cython.cast(lib.AVPixelFormat, self.frame.ptr.format) + device_type = kDLCPU + device_id = 0 + + line_size = self.line_size + if line_size < 0: + raise NotImplementedError( + "negative linesize is not supported for DLPack export" + ) + + nv12 = get_pix_fmt(b"nv12") + p010le = get_pix_fmt(b"p010le") + p016le = get_pix_fmt(b"p016le") + + ndim: cython.int + bits: cython.int + itemsize: cython.int + + s0: int64_t + s1: int64_t + 
s2: int64_t + st0: int64_t + st1: int64_t + st2: int64_t + + if sw_fmt == nv12: + itemsize = 1 + bits = 8 + if self.index == 0: + ndim = 2 + s0 = self.frame.ptr.height + s1 = self.frame.ptr.width + st0 = line_size + st1 = 1 + elif self.index == 1: + ndim = 3 + s0 = self.frame.ptr.height // 2 + s1 = self.frame.ptr.width // 2 + s2 = 2 + st0 = line_size + st1 = 2 + st2 = 1 + else: + raise ValueError("invalid plane index for NV12") + elif sw_fmt == p010le or sw_fmt == p016le: + itemsize = 2 + bits = 16 + if line_size % itemsize: + raise ValueError("linesize is not aligned to dtype") + if self.index == 0: + ndim = 2 + s0 = self.frame.ptr.height + s1 = self.frame.ptr.width + st0 = line_size // itemsize + st1 = 1 + elif self.index == 1: + ndim = 3 + s0 = self.frame.ptr.height // 2 + s1 = self.frame.ptr.width // 2 + s2 = 2 + st0 = line_size // itemsize + st1 = 2 + st2 = 1 + else: + raise ValueError("invalid plane index for P010/P016") + else: + raise NotImplementedError("unsupported sw_format for DLPack export") + + frame_ref: cython.pointer[lib.AVFrame] = lib.av_frame_alloc() + if frame_ref == cython.NULL: + raise MemoryError("av_frame_alloc() failed") + err_check(lib.av_frame_ref(frame_ref, self.frame.ptr)) + + shape = cython.cast( + cython.pointer[int64_t], malloc(ndim * cython.sizeof(int64_t)) + ) + strides = cython.cast( + cython.pointer[int64_t], malloc(ndim * cython.sizeof(int64_t)) + ) + if shape == cython.NULL or strides == cython.NULL: + if shape != cython.NULL: + free(shape) + if strides != cython.NULL: + free(strides) + lib.av_frame_free(cython.address(frame_ref)) + raise MemoryError("malloc() failed") + + if ndim == 2: + shape[0] = s0 + shape[1] = s1 + strides[0] = st0 + strides[1] = st1 + else: + shape[0] = s0 + shape[1] = s1 + shape[2] = s2 + strides[0] = st0 + strides[1] = st1 + strides[2] = st2 + + ctx = cython.cast( + cython.pointer[cython.p_void], malloc(3 * cython.sizeof(cython.p_void)) + ) + if ctx == cython.NULL: + free(shape) + free(strides) + 
lib.av_frame_free(cython.address(frame_ref)) + raise MemoryError("malloc() failed") + + ctx[0] = cython.cast(cython.p_void, frame_ref) + ctx[1] = cython.cast(cython.p_void, shape) + ctx[2] = cython.cast(cython.p_void, strides) + + managed = cython.cast( + cython.pointer[DLManagedTensor], malloc(cython.sizeof(DLManagedTensor)) + ) + if managed == cython.NULL: + free(ctx) + free(shape) + free(strides) + lib.av_frame_free(cython.address(frame_ref)) + raise MemoryError("malloc() failed") + + managed.dl_tensor.data = cython.cast(cython.p_void, frame_ref.data[self.index]) + managed.dl_tensor.device.device_type = device_type + managed.dl_tensor.device.device_id = device_id + managed.dl_tensor.ndim = ndim + managed.dl_tensor.dtype.code = kDLUInt + managed.dl_tensor.dtype.bits = bits + managed.dl_tensor.dtype.lanes = 1 + managed.dl_tensor.shape = shape + managed.dl_tensor.strides = strides + managed.dl_tensor.byte_offset = 0 + managed.manager_ctx = cython.cast(cython.p_void, ctx) + managed.deleter = _dlpack_managed_tensor_deleter + + try: + capsule = PyCapsule_New( + cython.cast(cython.p_void, managed), + b"dltensor", + _dlpack_capsule_destructor, + ) + except Exception: + _dlpack_managed_tensor_deleter(managed) + raise + + return capsule + + +@cython.cfunc +@cython.nogil +@cython.exceptval(check=False) +def _dlpack_managed_tensor_deleter( + managed: cython.pointer[DLManagedTensor], +) -> cython.void: + ctx: cython.pointer[cython.p_void] + frame_ref: cython.pointer[lib.AVFrame] + shape: cython.pointer[int64_t] + strides: cython.pointer[int64_t] + + if managed == cython.NULL: + return + + ctx = cython.cast(cython.pointer[cython.p_void], managed.manager_ctx) + if ctx != cython.NULL: + frame_ref = cython.cast(cython.pointer[lib.AVFrame], ctx[0]) + shape = cython.cast(cython.pointer[int64_t], ctx[1]) + strides = cython.cast(cython.pointer[int64_t], ctx[2]) + + if frame_ref != cython.NULL: + lib.av_frame_free(cython.address(frame_ref)) + if shape != cython.NULL: + free(shape) + 
if strides != cython.NULL: + free(strides) + free(ctx) + + free(managed) + + +@cython.cfunc +@cython.exceptval(check=False) +def _dlpack_capsule_destructor(capsule: object) -> cython.void: + managed: cython.pointer[DLManagedTensor] + if PyCapsule_IsValid(capsule, b"dltensor"): + managed = cython.cast( + cython.pointer[DLManagedTensor], + PyCapsule_GetPointer(capsule, b"dltensor"), + ) + if managed != cython.NULL: + managed.deleter(managed) diff --git a/av/video/plane.pyi b/av/video/plane.pyi index e4a0a206c..fcbf8e6ed 100644 --- a/av/video/plane.pyi +++ b/av/video/plane.pyi @@ -1,3 +1,5 @@ +from types import CapsuleType + from av.plane import Plane from .frame import VideoFrame @@ -9,3 +11,5 @@ class VideoPlane(Plane): buffer_size: int def __init__(self, frame: VideoFrame, index: int) -> None: ... + def __dlpack_device__(self) -> tuple[int, int]: ... + def __dlpack__(self, *, stream: int | None = None) -> CapsuleType: ... diff --git a/av/video/reformatter.py b/av/video/reformatter.py index 786543744..5a30d4d9b 100644 --- a/av/video/reformatter.py +++ b/av/video/reformatter.py @@ -185,6 +185,23 @@ def _reformat( src_format = cython.cast(lib.AVPixelFormat, frame.ptr.format) # Shortcut! 
+ if frame.ptr.hw_frames_ctx: + if ( + dst_format == src_format + and width == frame.ptr.width + and height == frame.ptr.height + and dst_colorspace == src_colorspace + and src_color_range == dst_color_range + ): + return frame + + frame_sw = alloc_video_frame() + err_check(lib.av_hwframe_transfer_data(frame_sw.ptr, frame.ptr, 0)) + frame_sw.pts = frame.pts + frame_sw._init_user_attributes() + frame = frame_sw + src_format = cython.cast(lib.AVPixelFormat, frame.ptr.format) + if ( dst_format == src_format and width == frame.ptr.width diff --git a/include/dlpack.pxd b/include/dlpack.pxd new file mode 100644 index 000000000..5bcefd3bd --- /dev/null +++ b/include/dlpack.pxd @@ -0,0 +1,41 @@ +from libc.stdint cimport int64_t, uint8_t, uint16_t, uint64_t + + +cdef enum DLDeviceType: + kDLCPU = 1 + kDLCUDA = 2 + +cdef enum DLDataTypeCode: + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLBfloat = 4 + kDLComplex = 5 + kDLBool = 6 + +cdef struct DLDevice: + int device_type + int device_id + +cdef struct DLDataType: + uint8_t code + uint8_t bits + uint16_t lanes + +cdef struct DLTensor: + void* data + DLDevice device + int ndim + DLDataType dtype + int64_t* shape + int64_t* strides + uint64_t byte_offset + +cdef struct DLManagedTensor + +ctypedef void (*DLManagedTensorDeleter)(DLManagedTensor*) noexcept nogil + +cdef struct DLManagedTensor: + DLTensor dl_tensor + void* manager_ctx + DLManagedTensorDeleter deleter diff --git a/include/libavcodec/avcodec.pxd b/include/libavcodec/avcodec.pxd index b75769945..a0a79df8d 100644 --- a/include/libavcodec/avcodec.pxd +++ b/include/libavcodec/avcodec.pxd @@ -1,4 +1,5 @@ -from libc.stdint cimport int64_t, uint16_t, uint32_t, uint8_t +from libc.stdint cimport int64_t, uint8_t, uint16_t, uint32_t + cdef extern from "libavcodec/packet.h" nogil: const AVPacketSideData *av_packet_side_data_get( @@ -353,8 +354,12 @@ cdef extern from "libavcodec/avcodec.h" nogil: int64_t pkt_dts void *opaque int sample_rate - int nb_side_data + AVBufferRef 
*buf[8] + AVBufferRef **extended_buf + int nb_extended_buf + AVFrameSideData **side_data + int nb_side_data int flags AVColorRange color_range AVColorPrimaries color_primaries @@ -364,6 +369,7 @@ cdef extern from "libavcodec/avcodec.h" nogil: AVDictionary *metadata int decode_error_flags + AVBufferRef *hw_frames_ctx AVBufferRef *opaque_ref AVChannelLayout ch_layout int64_t duration diff --git a/include/libavutil/avutil.pxd b/include/libavutil/avutil.pxd index 9d8486e1a..30de30720 100644 --- a/include/libavutil/avutil.pxd +++ b/include/libavutil/avutil.pxd @@ -1,4 +1,12 @@ -from libc.stdint cimport int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t +from libc.stdint cimport ( + int16_t, + int32_t, + int64_t, + uint8_t, + uint16_t, + uint32_t, + uint64_t, +) cdef extern from "libavutil/audio_fifo.h" nogil: @@ -164,7 +172,9 @@ cdef extern from "libavutil/error.h" nogil: cdef extern from "libavutil/frame.h" nogil: cdef AVFrame* av_frame_alloc() cdef void av_frame_free(AVFrame**) + cdef int av_frame_ref(AVFrame *dst, const AVFrame *src) cdef void av_frame_unref(AVFrame *frame) + cdef int av_frame_get_buffer(AVFrame *frame, int align) cdef int av_frame_make_writable(AVFrame *frame) cdef int av_frame_copy_props(AVFrame *dst, const AVFrame *src) cdef AVFrameSideData* av_frame_get_side_data(AVFrame *frame, AVFrameSideDataType type) @@ -185,12 +195,25 @@ cdef extern from "libavutil/hwcontext.h" nogil: AV_HWDEVICE_TYPE_VULKAN AV_HWDEVICE_TYPE_D3D12VA + ctypedef struct AVHWFramesContext: + const void *av_class + AVBufferRef *device_ref + void *device_ctx + void *hwctx + AVPixelFormat format + AVPixelFormat sw_format + int width + int height + cdef int av_hwdevice_ctx_create(AVBufferRef **device_ctx, AVHWDeviceType type, const char *device, AVDictionary *opts, int flags) cdef AVHWDeviceType av_hwdevice_find_type_by_name(const char *name) cdef const char *av_hwdevice_get_type_name(AVHWDeviceType type) cdef AVHWDeviceType av_hwdevice_iterate_types(AVHWDeviceType 
def _make_ramp(shape: tuple[int, ...], dtype) -> numpy.ndarray:
    """Return ``numpy.arange`` values of *dtype* reshaped to *shape*."""
    count = int(numpy.prod(shape))
    flat = numpy.arange(count, dtype=dtype)
    return flat.reshape(shape)


def _make_u8(shape: tuple[int, ...]) -> numpy.ndarray:
    """Build a ``uint8`` test array of *shape* filled by ``numpy.arange``."""
    return _make_ramp(shape, numpy.uint8)


def _make_u16(shape: tuple[int, ...]) -> numpy.ndarray:
    """Build a ``uint16`` test array of *shape* filled by ``numpy.arange``."""
    return _make_ramp(shape, numpy.uint16)


def _plane_to_2d(plane, height: int, width: int, dtype) -> numpy.ndarray:
    """View *plane*'s buffer as ``height`` rows and keep the first ``width`` columns.

    The row pitch is taken from ``plane.line_size`` (bytes) and must be a
    whole number of *dtype* elements.
    """
    elem_size = numpy.dtype(dtype).itemsize
    pitch_elems, leftover = divmod(plane.line_size, elem_size)
    assert leftover == 0

    flat = numpy.frombuffer(memoryview(plane), dtype=dtype)
    padded_rows = flat.reshape(height, pitch_elems)
    # Drop the padding columns at the end of every row.
    return padded_rows[:, :width]


def _get_cuda_backend():
    """Return ``("cupy", module)`` or ``("torch", module)`` if usable, else ``None``.

    cupy is tried first and must report at least one device; torch is tried
    next and must report CUDA as available.
    """

    def _probe_cupy():
        import cupy  # type: ignore

        return cupy if cupy.cuda.runtime.getDeviceCount() > 0 else None

    def _probe_torch():
        import torch  # type: ignore

        return torch if torch.cuda.is_available() else None

    for label, probe in (("cupy", _probe_cupy), ("torch", _probe_torch)):
        try:
            module = probe()
        except Exception:
            # Best-effort probing: any import or runtime failure just means
            # this backend is not usable here.
            continue
        if module is not None:
            return (label, module)

    return None
def test_hwaccel_output_format_validation_and_primary_ctx() -> None:
    """Check ``output_format`` handling and the CUDA ``primary_ctx`` default."""
    # Passing None selects "sw" and adds no primary_ctx option.
    default_accel = HWAccel(device_type="cuda", output_format=None)
    assert default_accel.output_format == "sw"
    assert "primary_ctx" not in default_accel.options

    # "hw" output on CUDA fills in primary_ctx="1" ...
    hw_accel = HWAccel(device_type="cuda", output_format="hw")
    assert hw_accel.output_format == "hw"
    assert hw_accel.options.get("primary_ctx") == "1"

    # ... but a value supplied by the caller is kept.
    caller_options = {"primary_ctx": "0"}
    explicit_accel = HWAccel(
        device_type="cuda", output_format="hw", options=caller_options
    )
    assert explicit_accel.options.get("primary_ctx") == "0"

    # An integer device is accepted.
    int_device_accel = HWAccel(device_type="cuda", device=0, output_format="hw")
    assert int_device_accel.output_format == "hw"

    rejects_bad_value = pytest.raises(
        ValueError, match="output_format must be 'sw' or 'hw'"
    )
    with rejects_bad_value:
        HWAccel(device_type="cuda", output_format="invalid")  # type: ignore[arg-type]
def test_video_frame_from_dlpack_nv12_cpu_with_pitch_and_dlpack_export() -> None:
    """Import padded (pitched) NV12 planes, export them again, and check lifetime.

    The source arrays are column slices of wider buffers, so each row carries
    ``pad`` bytes of padding. The frame must report that pitch, the DLPack
    export must expose matching shapes/strides, and the exported arrays must
    stay readable after every other reference has been dropped.
    """
    width, height = 64, 48
    pad = 16

    y_base = _make_u8((height, width + pad))
    y = y_base[:, :width]
    uv_base = _make_u8((height // 2, (width + pad) // 2, 2))
    uv = uv_base[:, : width // 2, :]

    frame = VideoFrame.from_dlpack((y, uv), format="nv12")

    assert frame.planes[0].line_size == width + pad
    assert frame.planes[1].line_size == width + pad
    assert frame.planes[0].buffer_size == (width + pad) * height
    assert frame.planes[1].buffer_size == (width + pad) * (height // 2)

    y_plane = _plane_to_2d(frame.planes[0], height, width, numpy.uint8)
    uv_plane = _plane_to_2d(frame.planes[1], height // 2, width, numpy.uint8)
    assertNdarraysEqual(y_plane, y)
    assertNdarraysEqual(uv_plane, uv.reshape(height // 2, width))

    assert frame.planes[0].__dlpack_device__() == (1, 0)

    y_dl = numpy.from_dlpack(frame.planes[0])
    uv_dl = numpy.from_dlpack(frame.planes[1])

    assert y_dl.shape == (height, width)
    assert y_dl.dtype == numpy.uint8
    assert y_dl.strides == (width + pad, 1)
    assertNdarraysEqual(y_dl, y)

    assert uv_dl.shape == (height // 2, width // 2, 2)
    assert uv_dl.dtype == numpy.uint8
    assert uv_dl.strides == (width + pad, 2, 1)
    assertNdarraysEqual(uv_dl, uv)

    expected_y = numpy.array(y, copy=True)
    expected_uv = numpy.array(uv, copy=True)

    # y_plane / uv_plane were created through memoryview(plane) and therefore
    # still reference the plane objects. Drop them too, otherwise they could
    # keep the underlying memory alive on their own and the check below would
    # not depend on the DLPack exports at all.
    del y_plane
    del uv_plane
    del frame
    del y
    del uv
    del y_base
    del uv_base
    gc.collect()

    # Only the DLPack-exported arrays are left; they must still be readable.
    assertNdarraysEqual(y_dl, expected_y)
    assertNdarraysEqual(uv_dl, expected_uv)
def test_video_frame_from_dlpack_accepts_video_plane_objects() -> None:
    """Planes of an existing frame can be passed straight back to ``from_dlpack``."""
    width, height = 64, 48
    luma = _make_u8((height, width))
    chroma = _make_u8((height // 2, width // 2, 2))

    original = VideoFrame.from_dlpack((luma, chroma), format="nv12")
    source_planes = (original.planes[0], original.planes[1])
    rebuilt = VideoFrame.from_dlpack(source_planes, format="nv12")

    # Both planes of the rebuilt frame must carry the same bytes.
    for index in range(2):
        assert bytes(rebuilt.planes[index]) == bytes(original.planes[index])
def test_video_plane_dlpack_unsupported_format_raises() -> None:
    """An rgb24 plane reports the CPU device but refuses DLPack export."""
    pixels = numpy.zeros((16, 16, 3), dtype=numpy.uint8)
    frame = VideoFrame.from_ndarray(pixels, format="rgb24")

    device = frame.planes[0].__dlpack_device__()
    assert device == (1, 0)

    export_refused = pytest.raises(
        NotImplementedError, match="unsupported sw_format for DLPack export"
    )
    with export_refused:
        frame.planes[0].__dlpack__()
def test_video_frame_from_dlpack_rejects_odd_dimensions() -> None:
    """An odd width must be rejected with a ``ValueError``."""
    odd_width, height = 63, 48
    planes = (
        numpy.zeros((height, odd_width), dtype=numpy.uint8),
        numpy.zeros((height // 2, odd_width), dtype=numpy.uint8),
    )

    rejects_odd = pytest.raises(ValueError, match="width/height must be even")
    with rejects_odd:
        VideoFrame.from_dlpack(planes, format="nv12")
def test_video_frame_from_dlpack_rejects_plane0_ndim_not_2() -> None:
    """A 3-D first plane must be rejected with a ``ValueError``."""
    planes = (
        numpy.zeros((4, 4, 1), dtype=numpy.uint8),
        numpy.zeros((2, 4), dtype=numpy.uint8),
    )

    rejects_rank = pytest.raises(ValueError, match="plane 0 must be 2D")
    with rejects_rank:
        VideoFrame.from_dlpack(planes, format="nv12", width=4, height=4)
def _check_cuda_nv12_frame(frame, y, expected_device, from_capsule, arrays_equal):
    """Run the assertions shared by the cupy and torch variants of the CUDA test.

    ``from_capsule`` turns a DLPack capsule into a backend array and
    ``arrays_equal`` returns a truthy value when two backend arrays match.
    """
    assert frame.format.name == "cuda"
    assert len(frame.planes) == 2

    with pytest.raises(TypeError, match="Hardware frame planes do not support"):
        memoryview(frame.planes[0])

    assert frame.planes[0].__dlpack_device__() == expected_device

    cap_y = frame.planes[0].__dlpack__()
    y2 = from_capsule(cap_y)

    assert tuple(y2.shape) == tuple(y.shape)
    assert arrays_equal(y2, y)

    with pytest.raises(
        ValueError, match="Cannot convert a hardware frame to numpy directly"
    ):
        frame.to_ndarray(format="cuda")


def test_video_frame_from_dlpack_cuda_hw_frame_behavior_if_available() -> None:
    """Build a frame from CUDA tensors (cupy or torch) and check its behavior.

    Skipped when no CUDA backend is importable/usable, or when creating the
    frame raises ``av.FFmpegError``. When more than one device is reported,
    device 1 is used, otherwise device 0.
    """
    backend = _get_cuda_backend()
    if backend is None:
        pytest.skip("CUDA backend (cupy/torch) not available.")

    width, height = 64, 48
    name, mod = backend
    uv_count = (height // 2) * (width // 2) * 2

    try:
        if name == "cupy":
            try:
                ndev = int(mod.cuda.runtime.getDeviceCount())
            except Exception:
                # Best effort: fall back to assuming a single device.
                ndev = 1

            device_id = 1 if ndev > 1 else 0
            # Everything, including the checks, runs with this device current.
            with mod.cuda.Device(device_id):
                y = mod.arange(height * width, dtype=mod.uint8).reshape(height, width)
                uv = mod.arange(uv_count, dtype=mod.uint8).reshape(
                    height // 2, width // 2, 2
                )
                expected_device = y.__dlpack_device__()
                frame = VideoFrame.from_dlpack((y, uv), format="nv12")

                if hasattr(mod, "fromDlpack"):
                    from_capsule = mod.fromDlpack
                else:
                    from_capsule = mod.from_dlpack

                _check_cuda_nv12_frame(
                    frame,
                    y,
                    expected_device,
                    from_capsule,
                    lambda a, b: mod.all(a == b).item(),
                )

        else:
            try:
                ndev = int(mod.cuda.device_count())
            except Exception:
                # Best effort: fall back to assuming a single device.
                ndev = 1

            device_id = 1 if ndev > 1 else 0
            device = f"cuda:{device_id}"

            y = mod.arange(height * width, device=device, dtype=mod.uint8).reshape(
                height, width
            )
            uv = mod.arange(uv_count, device=device, dtype=mod.uint8).reshape(
                height // 2, width // 2, 2
            )

            expected_device = y.__dlpack_device__()
            frame = VideoFrame.from_dlpack((y, uv), format="nv12")

            _check_cuda_nv12_frame(
                frame,
                y,
                expected_device,
                mod.utils.dlpack.from_dlpack,
                mod.equal,
            )
    except av.FFmpegError as e:
        pytest.skip(f"CUDA hwcontext not available in this build/runtime: {e}")