Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions av/_hwdevice_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_cuda_hwdevice_data_ptr_to_device_id: dict[int, int] = {}


def register_cuda_hwdevice_data_ptr(hwdevice_data_ptr: int, device_id: int) -> None:
if hwdevice_data_ptr:
_cuda_hwdevice_data_ptr_to_device_id[int(hwdevice_data_ptr)] = int(device_id)


def lookup_cuda_device_id(hwdevice_data_ptr: int) -> int:
if not hwdevice_data_ptr:
return 0
return _cuda_hwdevice_data_ptr_to_device_id.get(int(hwdevice_data_ptr), 0)
1 change: 1 addition & 0 deletions av/codec/hwaccel.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ cdef class HWAccel:
cdef public bint allow_software_fallback
cdef public dict options
cdef public int flags
cdef public str output_format
31 changes: 31 additions & 0 deletions av/codec/hwaccel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from cython.cimports.av.error import err_check
from cython.cimports.av.video.format import get_video_format

import av._hwdevice_registry as _hwreg


class HWDeviceType(IntEnum):
none = lib.AV_HWDEVICE_TYPE_NONE
Expand Down Expand Up @@ -110,7 +112,11 @@ def __init__(
allow_software_fallback=True,
options=None,
flags=None,
output_format=None,
):
if isinstance(device, int):
device = str(device)

if isinstance(device_type, HWDeviceType):
self._device_type = device_type
elif isinstance(device_type, str):
Expand All @@ -120,9 +126,20 @@ def __init__(
else:
raise ValueError("Unknown type for device_type")

if output_format is None:
output_format = "sw"
if isinstance(output_format, str):
output_format = output_format.lower()
if output_format not in {"sw", "hw"}:
raise ValueError("output_format must be 'sw' or 'hw'")
self.output_format = output_format

self._device = device
self.allow_software_fallback = allow_software_fallback

self.options = {} if not options else dict(options)
if self._device_type == HWDeviceType.cuda and self.output_format == "hw":
self.options.setdefault("primary_ctx", "1")
self.flags = 0 if not flags else flags
self.ptr = cython.NULL
self.config = None
Expand Down Expand Up @@ -155,6 +172,19 @@ def _initialize_hw_context(self, codec: Codec):
)
)

if config.ptr.device_type == lib.AV_HWDEVICE_TYPE_CUDA:
device_id = 0
if self._device:
try:
device_id = int(self._device)
except ValueError:
device_id = 0

_hwreg.register_cuda_hwdevice_data_ptr(
cython.cast(cython.size_t, self.ptr.data),
device_id,
)

def create(self, codec: Codec):
"""Create a new hardware accelerator context with the given codec"""
if self.ptr:
Expand All @@ -165,6 +195,7 @@ def create(self, codec: Codec):
device=self._device,
allow_software_fallback=self.allow_software_fallback,
options=self.options,
output_format=self.output_format,
)
ret._initialize_hw_context(codec)
return ret
Expand Down
8 changes: 6 additions & 2 deletions av/codec/hwaccel.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import cast
from typing import Literal, cast

from av.codec.codec import Codec
from av.video.format import VideoFormat
Expand Down Expand Up @@ -37,13 +37,17 @@ class HWConfig:
def is_supported(self) -> bool: ...

class HWAccel:
output_format: Literal["sw", "hw"]
options: dict[str, object]

def __init__(
self,
device_type: str | HWDeviceType,
device: str | None = None,
device: str | int | None = None,
allow_software_fallback: bool = False,
options: dict[str, object] | None = None,
flags: int | None = None,
output_format: Literal["sw", "hw"] | None = None,
) -> None: ...
def create(self, codec: Codec) -> HWAccel: ...

Expand Down
3 changes: 3 additions & 0 deletions av/video/codeccontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def _transfer_hwframe(self, frame: Frame):
# need to transfer.
return frame

if self.hwaccel_ctx.output_format == "hw":
return frame

frame_sw: Frame = self._alloc_next_frame()
err_check(lib.av_hwframe_transfer_data(frame_sw.ptr, frame.ptr, 0))
# TODO: Is there anything else to transfer?
Expand Down
Loading
Loading