Skip to content

Commit 98bf0d4

Browse files
committed
fix: Make legacy_default() and per_thread_default() return singletons - Fixes #1494
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
1 parent 264524b commit 98bf0d4

2 files changed

Lines changed: 72 additions & 15 deletions

File tree

cuda_core/cuda/core/_stream.pyx

Lines changed: 46 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -117,17 +117,38 @@ cdef class Stream:
117117
complete, and all subsequent operations in blocking streams wait for
118118
the legacy default stream operation to complete.
119119
120+
This stream is useful for ensuring strict ordering of operations but
121+
may limit concurrency. For better performance in concurrent scenarios,
122+
consider using per_thread_default() or creating explicit streams.
123+
124+
This method returns the same singleton instance on every call for the
125+
base Stream class. Subclasses will receive new instances of the subclass
126+
type that wrap the same underlying CUDA stream.
127+
120128
Returns
121129
-------
122130
Stream
123-
The legacy default stream instance for the current context.
131+
The legacy default stream singleton instance for the current context.
124132
125133
See Also
126134
--------
127135
per_thread_default : Per-thread default stream alternative.
136+
from_handle : Create stream from existing handle.
128137
138+
Examples
139+
--------
140+
>>> from cuda.core import Stream
141+
>>> stream1 = Stream.legacy_default()
142+
>>> stream2 = Stream.legacy_default()
143+
>>> stream1 is stream2  # both names refer to the same singleton object
144+
True
129145
"""
130-
return Stream._from_handle(cls, get_legacy_stream())
146+
# Return the singleton for the base Stream class
147+
if cls is Stream:
148+
return C_LEGACY_DEFAULT_STREAM
149+
# For subclasses, create a new instance of the subclass type
150+
else:
151+
return Stream._from_handle(cls, get_legacy_stream())
131152

132153
@classmethod
133154
def per_thread_default(cls):
@@ -139,18 +160,38 @@ cdef class Stream:
139160
non-blocking stream. This allows for better concurrency in multi-threaded
140161
applications.
141162
163+
Each thread has its own per-thread default stream, enabling true
164+
concurrent execution without implicit synchronization barriers.
165+
166+
This method returns the same singleton instance on every call for the
167+
base Stream class. Subclasses will receive new instances of the subclass
168+
type that wrap the same underlying CUDA stream.
169+
142170
Returns
143171
-------
144172
Stream
145-
The per-thread default stream instance for the current thread
146-
and context.
173+
The per-thread default stream singleton: one shared Python object whose
174+
handle CUDA resolves to the calling thread's stream in the current context.
147175
148176
See Also
149177
--------
150178
legacy_default : Legacy default stream alternative.
179+
from_handle : Create stream from existing handle.
151180
181+
Examples
182+
--------
183+
>>> from cuda.core import Stream
184+
>>> stream1 = Stream.per_thread_default()
185+
>>> stream2 = Stream.per_thread_default()
186+
>>> stream1 is stream2  # both names refer to the same singleton object
187+
True
152188
"""
153-
return Stream._from_handle(cls, get_per_thread_stream())
189+
# Return the singleton for the base Stream class
190+
if cls is Stream:
191+
return C_PER_THREAD_DEFAULT_STREAM
192+
# For subclasses, create a new instance of the subclass type
193+
else:
194+
return Stream._from_handle(cls, get_per_thread_stream())
154195

155196
@classmethod
156197
def _init(cls, obj: IsStreamT | None = None, options=None, device_id: int = None,

cuda_core/tests/test_stream.py

Lines changed: 26 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -130,19 +130,35 @@ class MyStream(Stream):
130130

131131

132132
def test_stream_legacy_default_public_api(init_cuda):
133-
"""Test public legacy_default() method."""
134-
stream = Stream.legacy_default()
135-
assert isinstance(stream, Stream)
136-
# Verify it's the same as LEGACY_DEFAULT_STREAM
137-
assert stream == LEGACY_DEFAULT_STREAM
133+
"""Test public legacy_default() method returns singleton."""
134+
stream1 = Stream.legacy_default()
135+
stream2 = Stream.legacy_default()
136+
137+
assert isinstance(stream1, Stream)
138+
assert isinstance(stream2, Stream)
139+
140+
# Verify singleton behavior - same Python object
141+
assert stream1 is stream2, "Should return same singleton instance"
142+
143+
# Verify it's the same as the module constant
144+
assert stream1 is LEGACY_DEFAULT_STREAM, "Should be the same object as LEGACY_DEFAULT_STREAM"
145+
assert stream2 is LEGACY_DEFAULT_STREAM, "Should be the same object as LEGACY_DEFAULT_STREAM"
138146

139147

140148
def test_stream_per_thread_default_public_api(init_cuda):
141-
"""Test public per_thread_default() method."""
142-
stream = Stream.per_thread_default()
143-
assert isinstance(stream, Stream)
144-
# Verify it's the same as PER_THREAD_DEFAULT_STREAM
145-
assert stream == PER_THREAD_DEFAULT_STREAM
149+
"""Test public per_thread_default() method returns singleton."""
150+
stream1 = Stream.per_thread_default()
151+
stream2 = Stream.per_thread_default()
152+
153+
assert isinstance(stream1, Stream)
154+
assert isinstance(stream2, Stream)
155+
156+
# Verify singleton behavior - same Python object
157+
assert stream1 is stream2, "Should return same singleton instance"
158+
159+
# Verify it's the same as the module constant
160+
assert stream1 is PER_THREAD_DEFAULT_STREAM, "Should be the same object as PER_THREAD_DEFAULT_STREAM"
161+
assert stream2 is PER_THREAD_DEFAULT_STREAM, "Should be the same object as PER_THREAD_DEFAULT_STREAM"
146162

147163

148164
# ============================================================================

0 commit comments

Comments
 (0)