Skip to content

Commit 84e10a4

Browse files
committed
Use indirect fixtures for a nicer pattern and avoid thread issues
After my first AI try was a crazy mess, the second run actually found a neat solution... These objects can be created in the main thread, but we can't create them on the fly in many threads as it was... Signed-off-by: Sebastian Berg <sebastianb@nvidia.com>
1 parent d1c615a commit 84e10a4

1 file changed

Lines changed: 85 additions & 81 deletions

File tree

cuda_core/tests/test_object_protocols.py

Lines changed: 85 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,11 @@ def sample_ipc_buffer_descriptor(ipc_device):
233233
options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True)
234234
mr = DeviceMemoryResource(ipc_device, options=options)
235235
buf = mr.allocate(64, stream=ipc_device.default_stream)
236-
return buf.ipc_descriptor
236+
descriptor = buf.ipc_descriptor
237+
buf.close()
238+
# TODO(seberg): 2026-06: mr close may be unsafe with incomplete `buf.close()`
239+
ipc_device.sync()
240+
return descriptor
237241

238242

239243
@pytest.fixture
@@ -523,6 +527,26 @@ def sample_switch_node_alt(sample_graphdef):
523527
return sample_graphdef.switch(condition, 3)
524528

525529

530+
# Indirect-parametrize helpers: request.getfixturevalue() runs here, in the
531+
# fixture (main thread), so the resolved object is already available when the
532+
# test function runs in a worker thread.
533+
534+
535+
@pytest.fixture
536+
def sample_object(request):
537+
return request.getfixturevalue(request.param)
538+
539+
540+
@pytest.fixture
541+
def sample_object_a(request):
542+
return request.getfixturevalue(request.param)
543+
544+
545+
@pytest.fixture
546+
def sample_object_b(request):
547+
return request.getfixturevalue(request.param)
548+
549+
526550
# =============================================================================
527551
# Type groupings
528552
# =============================================================================
@@ -718,144 +742,125 @@ def sample_switch_node_alt(sample_graphdef):
718742
# =============================================================================
719743

720744

721-
@pytest.mark.parametrize("fixture_name", WEAKREF_TYPES)
722-
def test_weakref_supported(fixture_name, request):
745+
@pytest.mark.parametrize("sample_object", WEAKREF_TYPES, indirect=True)
746+
def test_weakref_supported(sample_object):
723747
"""Object supports weak references."""
724-
obj = request.getfixturevalue(fixture_name)
725-
ref = weakref.ref(obj)
726-
assert ref() is obj
748+
ref = weakref.ref(sample_object)
749+
assert ref() is sample_object
727750

728751

729752
# =============================================================================
730753
# Hash tests
731754
# =============================================================================
732755

733756

734-
@pytest.mark.parametrize("fixture_name", HASH_TYPES)
735-
def test_hash_consistency(fixture_name, request):
757+
@pytest.mark.parametrize("sample_object", HASH_TYPES, indirect=True)
758+
def test_hash_consistency(sample_object):
736759
"""Hash is consistent across multiple calls."""
737-
obj = request.getfixturevalue(fixture_name)
738-
assert hash(obj) == hash(obj)
760+
assert hash(sample_object) == hash(sample_object)
739761

740762

741-
@pytest.mark.parametrize("a_name,b_name", SAME_TYPE_PAIRS)
742-
def test_hash_distinct_same_type(a_name, b_name, request):
763+
@pytest.mark.parametrize("sample_object_a,sample_object_b", SAME_TYPE_PAIRS, indirect=True)
764+
def test_hash_distinct_same_type(sample_object_a, sample_object_b):
743765
"""Distinct objects of the same type have different hashes."""
744-
obj_a = request.getfixturevalue(a_name)
745-
obj_b = request.getfixturevalue(b_name)
746-
assert hash(obj_a) != hash(obj_b) # extremely unlikely
766+
assert hash(sample_object_a) != hash(sample_object_b) # extremely unlikely
747767

748768

749-
@pytest.mark.parametrize("a_name,b_name", itertools.combinations(HASH_TYPES, 2))
750-
def test_hash_distinct_cross_type(a_name, b_name, request):
769+
@pytest.mark.parametrize("sample_object_a,sample_object_b", itertools.combinations(HASH_TYPES, 2), indirect=True)
770+
def test_hash_distinct_cross_type(sample_object_a, sample_object_b):
751771
"""Distinct objects of different types have different hashes."""
752-
obj_a = request.getfixturevalue(a_name)
753-
obj_b = request.getfixturevalue(b_name)
754-
assert hash(obj_a) != hash(obj_b) # extremely unlikely
772+
assert hash(sample_object_a) != hash(sample_object_b) # extremely unlikely
755773

756774

757775
# =============================================================================
758776
# Equality tests
759777
# =============================================================================
760778

761779

762-
@pytest.mark.parametrize("fixture_name", EQ_TYPES)
763-
def test_equality_basic(fixture_name, request):
780+
@pytest.mark.parametrize("sample_object", EQ_TYPES, indirect=True)
781+
def test_equality_basic(sample_object):
764782
"""Object equality: reflexive, not equal to None or other types."""
765-
obj = request.getfixturevalue(fixture_name)
766-
assert obj == obj
767-
assert obj is not None
768-
assert obj != "string"
769-
if hasattr(obj, "handle"):
770-
assert obj != obj.handle
783+
assert sample_object == sample_object
784+
assert sample_object is not None
785+
assert sample_object != "string"
786+
if hasattr(sample_object, "handle"):
787+
assert sample_object != sample_object.handle
771788

772789

773-
@pytest.mark.parametrize("a_name,b_name", itertools.combinations(EQ_TYPES, 2))
774-
def test_no_cross_type_equality(a_name, b_name, request):
790+
@pytest.mark.parametrize("sample_object_a,sample_object_b", itertools.combinations(EQ_TYPES, 2), indirect=True)
791+
def test_no_cross_type_equality(sample_object_a, sample_object_b):
775792
"""No two distinct objects of different types should compare equal."""
776-
obj_a = request.getfixturevalue(a_name)
777-
obj_b = request.getfixturevalue(b_name)
778-
assert obj_a != obj_b
793+
assert sample_object_a != sample_object_b
779794

780795

781-
@pytest.mark.parametrize("a_name,b_name", SAME_TYPE_PAIRS)
782-
def test_same_type_inequality(a_name, b_name, request):
796+
@pytest.mark.parametrize("sample_object_a,sample_object_b", SAME_TYPE_PAIRS, indirect=True)
797+
def test_same_type_inequality(sample_object_a, sample_object_b):
783798
"""Two distinct objects of the same type should not compare equal."""
784-
obj_a = request.getfixturevalue(a_name)
785-
obj_b = request.getfixturevalue(b_name)
786-
assert obj_a is not obj_b
787-
assert obj_a != obj_b
799+
assert sample_object_a is not sample_object_b
800+
assert sample_object_a != sample_object_b
788801

789802

790-
@pytest.mark.parametrize("fixture_name,copy_fn", FROM_HANDLE_COPIES)
791-
def test_equality_same_handle(fixture_name, copy_fn, request):
803+
@pytest.mark.parametrize("sample_object,copy_fn", FROM_HANDLE_COPIES, indirect=["sample_object"])
804+
def test_equality_same_handle(sample_object, copy_fn):
792805
"""Two wrappers around the same handle should compare equal."""
793-
obj = request.getfixturevalue(fixture_name)
794-
obj2 = copy_fn(obj)
795-
assert obj == obj2
796-
assert hash(obj) == hash(obj2)
806+
obj2 = copy_fn(sample_object)
807+
assert sample_object == obj2
808+
assert hash(sample_object) == hash(obj2)
797809

798810

799811
# =============================================================================
800812
# Collection usage tests
801813
# =============================================================================
802814

803815

804-
@pytest.mark.parametrize("fixture_name", DICT_KEY_TYPES)
805-
def test_usable_as_dict_key(fixture_name, request):
816+
@pytest.mark.parametrize("sample_object", DICT_KEY_TYPES, indirect=True)
817+
def test_usable_as_dict_key(sample_object):
806818
"""Object can be used as a dictionary key."""
807-
obj = request.getfixturevalue(fixture_name)
808-
d = {obj: "value"}
809-
assert d[obj] == "value"
810-
assert obj in d
819+
d = {sample_object: "value"}
820+
assert d[sample_object] == "value"
821+
assert sample_object in d
811822

812823

813-
@pytest.mark.parametrize("fixture_name", DICT_KEY_TYPES)
814-
def test_usable_in_set(fixture_name, request):
824+
@pytest.mark.parametrize("sample_object", DICT_KEY_TYPES, indirect=True)
825+
def test_usable_in_set(sample_object):
815826
"""Object can be added to a set."""
816-
obj = request.getfixturevalue(fixture_name)
817-
s = {obj}
818-
assert obj in s
827+
s = {sample_object}
828+
assert sample_object in s
819829

820830

821-
@pytest.mark.parametrize("fixture_name", WEAKREF_TYPES)
822-
def test_usable_in_weak_value_dict(fixture_name, request):
831+
@pytest.mark.parametrize("sample_object", WEAKREF_TYPES, indirect=True)
832+
def test_usable_in_weak_value_dict(sample_object):
823833
"""Object can be used as a WeakValueDictionary value."""
824-
obj = request.getfixturevalue(fixture_name)
825834
wvd = weakref.WeakValueDictionary()
826-
wvd["key"] = obj
827-
assert wvd["key"] is obj
835+
wvd["key"] = sample_object
836+
assert wvd["key"] is sample_object
828837

829838

830-
@pytest.mark.parametrize("fixture_name", WEAK_KEY_TYPES)
831-
def test_usable_in_weak_key_dict(fixture_name, request):
839+
@pytest.mark.parametrize("sample_object", WEAK_KEY_TYPES, indirect=True)
840+
def test_usable_in_weak_key_dict(sample_object):
832841
"""Object can be used as a WeakKeyDictionary key."""
833-
obj = request.getfixturevalue(fixture_name)
834842
wkd = weakref.WeakKeyDictionary()
835-
wkd[obj] = "value"
836-
assert wkd[obj] == "value"
843+
wkd[sample_object] = "value"
844+
assert wkd[sample_object] == "value"
837845

838846

839-
@pytest.mark.parametrize("fixture_name", WEAK_KEY_TYPES)
840-
def test_usable_in_weak_set(fixture_name, request):
847+
@pytest.mark.parametrize("sample_object", WEAK_KEY_TYPES, indirect=True)
848+
def test_usable_in_weak_set(sample_object):
841849
"""Object can be added to a WeakSet."""
842-
obj = request.getfixturevalue(fixture_name)
843850
ws = weakref.WeakSet()
844-
ws.add(obj)
845-
assert obj in ws
851+
ws.add(sample_object)
852+
assert sample_object in ws
846853

847854

848855
# =============================================================================
849856
# Repr tests
850857
# =============================================================================
851858

852859

853-
@pytest.mark.parametrize("fixture_name,pattern", REPR_PATTERNS)
854-
def test_repr_format(fixture_name, pattern, request):
860+
@pytest.mark.parametrize("sample_object,pattern", REPR_PATTERNS, indirect=["sample_object"])
861+
def test_repr_format(sample_object, pattern):
855862
"""repr() returns a properly formatted string."""
856-
obj = request.getfixturevalue(fixture_name)
857-
result = repr(obj)
858-
assert re.fullmatch(pattern, result)
863+
assert re.fullmatch(pattern, repr(sample_object))
859864

860865

861866
# =============================================================================
@@ -864,10 +869,9 @@ def test_repr_format(fixture_name, pattern, request):
864869

865870

866871
@pytest.mark.parametrize("pickle_module", PICKLE_MODULES)
867-
@pytest.mark.parametrize("fixture_name", PICKLE_TYPES)
868-
def test_pickle_roundtrip(fixture_name, pickle_module, request):
872+
@pytest.mark.parametrize("sample_object", PICKLE_TYPES, indirect=True)
873+
def test_pickle_roundtrip(sample_object, pickle_module):
869874
"""Object survives a pickle/cloudpickle roundtrip."""
870875
mod = pytest.importorskip(pickle_module)
871-
obj = request.getfixturevalue(fixture_name)
872-
result = mod.loads(mod.dumps(obj))
873-
assert type(result) is type(obj)
876+
result = mod.loads(mod.dumps(sample_object))
877+
assert type(result) is type(sample_object)

0 commit comments

Comments
 (0)