Skip to content

Commit 7126324

Browse files
committed
refactor(cuda.core): rename AccessedBySet -> AccessedBySetProxy
Align with the graph module's AdjacencySetProxy: rename the class and inherit from collections.abc.MutableSet so the full set interface (remove, pop, clear, |=, &=, -=, ^=, isdisjoint, subset/superset operators, etc.) is filled in automatically from the existing add / discard / __contains__ / __iter__ / __len__ primitives. Add classmethod _from_iterable so binary set operators (&|^) produce plain sets rather than constructing a buffer-less proxy. Tighten add to TypeError on non-Device/Host inputs and discard / __contains__ to silently ignore them, matching MutableSet contracts. The hand-rolled __eq__ (set/frozenset comparison) is dropped: Set ABC's default implementation handles it correctly. Resolves PR #1775 review (Andy-Jost, 2026-05-04): naming consistency with AdjacencySetProxy and full MutableSet conformance.
1 parent 5743e05 commit 7126324

3 files changed

Lines changed: 24 additions & 17 deletions

File tree

cuda_core/cuda/core/_memory/_managed_buffer.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
from collections.abc import MutableSet
67
from typing import TYPE_CHECKING
78

89
from cuda.core._device import Device
@@ -56,7 +57,7 @@ def _query_accessed_by(buf: Buffer) -> list[Device | Host]:
5657
return [Host() if v == -1 else Device(v) for v in raw if v != -2]
5758

5859

59-
class AccessedBySet:
60+
class AccessedBySetProxy(MutableSet):
6061
"""Live driver-backed view of ``set_accessed_by`` advice for a managed buffer.
6162
6263
Reads (``__contains__``, ``__iter__``, ``len(...)``) call
@@ -76,7 +77,16 @@ class AccessedBySet:
7677
def __init__(self, buf: ManagedBuffer):
7778
self._buf = buf
7879

80+
# Operators such as &|^ produce a plain set, not another proxy.
81+
@classmethod
82+
def _from_iterable(cls, it):
83+
return set(it)
84+
85+
# --- abstract methods required by MutableSet ---
86+
7987
def __contains__(self, location) -> bool:
88+
if not isinstance(location, (Device, Host)):
89+
return False
8090
return location in _query_accessed_by(self._buf)
8191

8292
def __iter__(self):
@@ -85,24 +95,21 @@ def __iter__(self):
8595
def __len__(self) -> int:
8696
return len(_query_accessed_by(self._buf))
8797

88-
def __eq__(self, other) -> bool:
89-
if isinstance(other, AccessedBySet):
90-
return set(_query_accessed_by(self._buf)) == set(_query_accessed_by(other._buf))
91-
if isinstance(other, (set, frozenset)):
92-
return set(_query_accessed_by(self._buf)) == other
93-
return NotImplemented
94-
95-
def __repr__(self) -> str:
96-
return f"AccessedBySet({set(_query_accessed_by(self._buf))!r})"
97-
9898
def add(self, location: Device | Host) -> None:
9999
"""Apply ``set_accessed_by`` advice for ``location``."""
100+
if not isinstance(location, (Device, Host)):
101+
raise TypeError(f"expected Device or Host, got {type(location).__name__}")
100102
_advise_one(self._buf, _SET_ACCESSED_BY, location)
101103

102104
def discard(self, location: Device | Host) -> None:
103105
"""Apply ``unset_accessed_by`` advice for ``location``."""
106+
if not isinstance(location, (Device, Host)):
107+
return
104108
_advise_one(self._buf, _UNSET_ACCESSED_BY, location)
105109

110+
def __repr__(self) -> str:
111+
return f"AccessedBySetProxy({set(_query_accessed_by(self._buf))!r})"
112+
106113

107114
class ManagedBuffer(Buffer):
108115
"""Managed (unified) memory buffer with a property-style advice API.
@@ -194,9 +201,9 @@ def preferred_location(self, value: Device | Host | None) -> None:
194201
_advise_one(self, _SET_PREFERRED, value)
195202

196203
@property
197-
def accessed_by(self) -> AccessedBySet:
204+
def accessed_by(self) -> AccessedBySetProxy:
198205
"""Live set-like view of ``set_accessed_by`` locations."""
199-
return AccessedBySet(self)
206+
return AccessedBySetProxy(self)
200207

201208
@accessed_by.setter
202209
def accessed_by(self, locations) -> None:

cuda_core/docs/source/api_private.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ CUDA runtime
3232
_device.DeviceProperties
3333
_memory._ipc.IPCAllocationHandle
3434
_memory._ipc.IPCBufferDescriptor
35-
_memory._managed_buffer.AccessedBySet
35+
_memory._managed_buffer.AccessedBySetProxy
3636

3737

3838
CUDA graphs

cuda_core/tests/memory/test_managed_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def test_accessed_by_add_discard(self, init_cuda):
619619
plain.close()
620620

621621
def test_accessed_by_read_methods(self, init_cuda):
622-
"""Cover __iter__, __len__, __eq__, __repr__ on AccessedBySet."""
622+
"""Cover __iter__, __len__, __eq__, __repr__ on AccessedBySetProxy."""
623623
device = Device()
624624
_skip_if_managed_location_ops_unsupported(device)
625625
device.set_current()
@@ -631,7 +631,7 @@ def test_accessed_by_read_methods(self, init_cuda):
631631
assert len(buf.accessed_by) == 0
632632
assert list(buf.accessed_by) == []
633633
assert buf.accessed_by == set()
634-
assert "AccessedBySet" in repr(buf.accessed_by)
634+
assert "AccessedBySetProxy" in repr(buf.accessed_by)
635635

636636
# After add
637637
buf.accessed_by.add(device)
@@ -640,7 +640,7 @@ def test_accessed_by_read_methods(self, init_cuda):
640640
assert buf.accessed_by == {device}
641641
assert buf.accessed_by != frozenset()
642642

643-
# __eq__ vs another AccessedBySet on the same buffer
643+
# __eq__ vs another AccessedBySetProxy on the same buffer
644644
assert buf.accessed_by == buf.accessed_by
645645
finally:
646646
plain.close()

0 commit comments

Comments
 (0)