diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index cf6658c433..e8c7ce278b 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -63,6 +63,7 @@ tile16 fastcache graph +streams perf_dispatch init_options ``` diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md new file mode 100644 index 0000000000..a8db331bcc --- /dev/null +++ b/docs/source/user_guide/streams.md @@ -0,0 +1,137 @@ +# Streams + +Streams allow concurrent execution of GPU operations. By default, all Quadrants kernels launch on the default stream, which serializes everything. With streams, you can run multiple top-level for loops in parallel. + +## Supported platforms + +| Backend | Supported | +|---------|-----------| +| CUDA | Yes | +| AMDGPU | Yes | +| CPU | No-op | +| Metal | No-op | +| Vulkan | No-op | + +On backends without native stream support, stream operations are no-ops and for loops run serially. Code using streams is portable across all backends — it will run without modifications, but serially. + +## Stream parallelism + +Inside a `@qd.kernel`, each `with qd.stream_parallel():` block runs on its own GPU stream. + +```python +import quadrants as qd + +qd.init(arch=qd.cuda) + +N = 1024 +a = qd.field(qd.f32, shape=(N,)) +b = qd.field(qd.f32, shape=(N,)) +c = qd.field(qd.f32, shape=(N,)) + +@qd.kernel +def compute_ab(): + with qd.stream_parallel(): + for i in range(N): + a[i] = compute_a(i) + with qd.stream_parallel(): + for j in range(N): + b[j] = compute_b(j) + +@qd.kernel +def combine(): + for i in range(N): + c[i] = a[i] + b[i] + +compute_ab() # the two stream_parallel blocks run concurrently +combine() # runs after compute_ab() returns — a[] and b[] are ready +``` + +Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. + +### Restrictions + +- All top-level statements in a kernel must be either all `stream_parallel` blocks or all regular statements. Mixing the two at the top level is a compile-time error. +- Nesting `stream_parallel` blocks is not supported. + +## Explicit streams + +For cases that require manual control — such as launching separate kernels on different streams or interoperating with PyTorch — you can create and manage streams directly. + +### Creating and using streams + +Any `@qd.kernel` function accepts a special `qd_stream` keyword argument — you do not need to declare it in the kernel signature. The `@qd.kernel` decorator handles it automatically. + +```python +@qd.kernel +def my_kernel(): + for i in range(N): + a[i] = i + +s1 = qd.create_stream() +s2 = qd.create_stream() + +my_kernel(qd_stream=s1) +my_kernel(qd_stream=s2) + +s1.synchronize() +s2.synchronize() + +s1.destroy() +s2.destroy() +``` + +Kernels on different streams may execute concurrently. Call `synchronize()` to block until all work on a stream completes. + +### Events + +Events let you express dependencies between streams without full synchronization. 
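+In the sketch below, `produce()` fills `a[]` on stream `s1`; recording an event on `s1` and making `s2` wait on it
+orders `consume()` after that write without blocking the host: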
+ +```python +s1 = qd.create_stream() +s2 = qd.create_stream() + +@qd.kernel +def produce(): + for i in range(N): + a[i] = 10.0 + +@qd.kernel +def consume(): + for i in range(N): + b[i] = a[i] + +produce(qd_stream=s1) + +e = qd.create_event() +e.record(s1) # record when s1 finishes produce() +e.wait(qd_stream=s2) # s2 waits for that event before proceeding + +consume(qd_stream=s2) # safe to read a[] — produce() is guaranteed complete +s2.synchronize() + +e.destroy() +s1.destroy() +s2.destroy() +``` + +`e.record(stream)` captures the point in `stream`'s execution. `e.wait(qd_stream=stream)` makes `stream` wait until the recorded point is reached. If `qd_stream` is omitted, the default stream waits. + +### Context managers + +Streams and events support `with` blocks for automatic cleanup: + +```python +with qd.create_stream() as s: + some_func1(qd_stream=s) +# s.destroy() called automatically — waits for in-flight work +``` + +## Synchronization notes + +- **`qd.sync()` only waits on the default stream.** It does not drain explicit streams. Call `stream.synchronize()` on each stream you need to wait for. +- **No automatic synchronization with explicit streams.** When using explicit streams, you are responsible for inserting events or `synchronize()` calls when one stream's output is another stream's input. `stream_parallel` handles this automatically. + +## Limitations + +- **Not compatible with graphs.** Do not pass `qd_stream` to a kernel decorated with `graph=True` (if you do, a `RuntimeError` will be raised). +- **Not compatible with autodiff.** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised). diff --git a/python/quadrants/lang/__init__.py b/python/quadrants/lang/__init__.py index 2fd0f8dd3f..12773e45c6 100644 --- a/python/quadrants/lang/__init__.py +++ b/python/quadrants/lang/__init__.py @@ -16,6 +16,7 @@ from quadrants.lang.runtime_ops import * from quadrants.lang.snode import * from quadrants.lang.source_builder import * +from quadrants.lang.stream import * from quadrants.lang.struct import * from quadrants.types.enums import DeviceCapability, Format, Layout # noqa: F401 @@ -47,6 +48,7 @@ "shell", "snode", "source_builder", + "stream", "struct", "util", ] diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 8694a4e94e..263a4a11a3 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -119,7 +119,11 @@ def build_AnnAssign(ctx: ASTTransformerFuncContext, node: ast.AnnAssign): @staticmethod def build_assign_annotated( - ctx: ASTTransformerFuncContext, target: ast.Name, value, is_static_assign: bool, annotation: Type + ctx: ASTTransformerFuncContext, + target: ast.Name, + value, + is_static_assign: bool, + annotation: Type, ): """Build an annotated assignment like this: target: annotation = value. @@ -165,7 +169,10 @@ def build_Assign(ctx: ASTTransformerFuncContext, node: ast.Assign) -> None: @staticmethod def build_assign_unpack( - ctx: ASTTransformerFuncContext, node_target: list | ast.Tuple, values, is_static_assign: bool + ctx: ASTTransformerFuncContext, + node_target: list | ast.Tuple, + values, + is_static_assign: bool, ): """Build the unpack assignments like this: (target1, target2) = (value1, value2). The function should be called only if the node target is a tuple. 
@@ -591,7 +598,8 @@ def build_Return(ctx: ASTTransformerFuncContext, node: ast.Return) -> None: else: raise QuadrantsSyntaxError("The return type is not supported now!") ctx.ast_builder.create_kernel_exprgroup_return( - expr.make_expr_group(return_exprs), _qd_core.DebugInfo(ctx.get_pos_info(node)) + expr.make_expr_group(return_exprs), + _qd_core.DebugInfo(ctx.get_pos_info(node)), ) else: ctx.return_data = node.value.ptr @@ -1520,6 +1528,24 @@ def build_Continue(ctx: ASTTransformerFuncContext, node: ast.Continue) -> None: ctx.ast_builder.insert_continue_stmt(_qd_core.DebugInfo(ctx.get_pos_info(node))) return None + @staticmethod + def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: + if len(node.items) != 1: + raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports a single context manager") + item = node.items[0] + if item.optional_vars is not None: + raise QuadrantsSyntaxError("'with ... as ...' is not supported in Quadrants kernels") + if not isinstance(item.context_expr, ast.Call): + raise QuadrantsSyntaxError("'with' in Quadrants kernels requires a call expression") + if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): + raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports qd.stream_parallel()") + if not ctx.is_kernel: + raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func") + ctx.ast_builder.begin_stream_parallel() + build_stmts(ctx, node.body) + ctx.ast_builder.end_stream_parallel() + return None + @staticmethod def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None: return None diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4c9bd5115b..60be2e916d 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -26,11 +26,13 @@ from quadrants.lang.ast.ast_transformer_utils import ( ASTTransformerFuncContext, ) +from quadrants.lang.ast.symbol_resolver import ASTResolver from quadrants.lang.buffer_view import BufferView from quadrants.lang.exception import ( QuadrantsSyntaxError, ) from quadrants.lang.matrix import MatrixType +from quadrants.lang.stream import stream_parallel from quadrants.lang.struct import StructType from quadrants.lang.util import to_quadrants_type from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types @@ -317,7 +319,11 @@ def _transform_func_arg( # polymorphic). if field.type is not _TensorClass and hasattr(field.type, "check_matched"): field.type.check_matched(data_child.get_type(), field.name) - _cache = getattr(getattr(ctx, "global_context", None), "ndarray_to_any_array", None) + _cache = getattr( + getattr(ctx, "global_context", None), + "ndarray_to_any_array", + None, + ) promoted = _cache.get(id(data_child)) if _cache else None ctx.create_variable(flat_name, promoted if promoted is not None else data_child) elif dataclasses.is_dataclass(data_child): @@ -336,7 +342,13 @@ def _transform_func_arg( # Ndarray arguments are passed by reference. 
if isinstance(argument_type, (ndarray_type.NdarrayType)): if not isinstance( - data, (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray) + data, + ( + _ndarray.ScalarNdarray, + matrix.VectorNdarray, + matrix.MatrixNdarray, + any_array.AnyArray, + ), ): raise QuadrantsSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.") argument_type.check_matched(data.get_type(), argument_name) @@ -443,7 +455,70 @@ def build_FunctionDef( else: FunctionDefTransformer._transform_as_func(ctx, node, args) + if ctx.is_kernel: + FunctionDefTransformer._validate_stream_parallel_exclusivity(node.body, ctx.global_vars) + with ctx.variable_scope_guard(): build_stmts(ctx, node.body) return None + + @staticmethod + def _is_stream_parallel_with(stmt: ast.stmt, global_vars: dict[str, Any]) -> bool: + if not isinstance(stmt, ast.With): + return False + if len(stmt.items) != 1: + return False + item = stmt.items[0] + if not isinstance(item.context_expr, ast.Call): + return False + func_node = item.context_expr.func + if ASTResolver.resolve_to(func_node, stream_parallel, global_vars): + return True + resolved = ASTResolver.resolve_value(func_node, global_vars) + if resolved is not None: + return getattr(resolved, "__name__", None) == "stream_parallel" and getattr( + resolved, "__module__", "" + ).startswith("quadrants") + if isinstance(func_node, ast.Attribute) and func_node.attr == "stream_parallel": + return True + if isinstance(func_node, ast.Name) and func_node.id == "stream_parallel": + return True + return False + + @staticmethod + def _is_docstring(stmt: ast.stmt, index: int) -> bool: + return index == 0 and isinstance(stmt, ast.Expr) and isinstance(stmt.value, (ast.Constant, ast.Str)) + + @staticmethod + def _is_coverage_probe(stmt: ast.stmt) -> bool: + if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1: + return False + target = stmt.targets[0] + return ( + isinstance(target, ast.Subscript) + and isinstance(target.value, ast.Name) + and target.value.id.startswith("_qd_cov") + ) + + @staticmethod + def _validate_stream_parallel_exclusivity(body: list[ast.stmt], global_vars: dict[str, Any]) -> None: + if not any(FunctionDefTransformer._is_stream_parallel_with(s, global_vars) for s in body): + return + for i, stmt in enumerate(body): + if FunctionDefTransformer._is_docstring(stmt, i): + continue + if FunctionDefTransformer._is_coverage_probe(stmt): + continue + if not FunctionDefTransformer._is_stream_parallel_with(stmt, global_vars): + stmt_desc = f"{type(stmt).__name__}" + if isinstance(stmt, ast.With) and stmt.items: + ctx_expr = stmt.items[0].context_expr + if isinstance(ctx_expr, ast.Call) and isinstance(ctx_expr.func, ast.Attribute): + stmt_desc += f"(with {ast.dump(ctx_expr.func)})" + raise QuadrantsSyntaxError( + "When using qd.stream_parallel(), all top-level statements " + "in the kernel must be 'with qd.stream_parallel():' blocks. " + f"Move non-parallel code to a separate kernel. " + f"[stmt {i}: {stmt_desc}, body_len={len(body)}]" + ) diff --git a/python/quadrants/lang/ast/symbol_resolver.py b/python/quadrants/lang/ast/symbol_resolver.py index 81296fcefb..c2b4fcaffe 100644 --- a/python/quadrants/lang/ast/symbol_resolver.py +++ b/python/quadrants/lang/ast/symbol_resolver.py @@ -55,3 +55,35 @@ def resolve_to(node, wanted, scope): return False # The name ``scope`` here could be a bit confusing return scope is wanted + + @staticmethod + def resolve_value(node, scope): + """Resolve an AST Name/Attribute node to a Python object. 
+ + Same traversal as resolve_to but returns the resolved object (or None) instead of comparing against a wanted + value. + """ + if isinstance(node, ast.Name): + return scope.get(node.id) if isinstance(scope, dict) else None + + if not isinstance(node, ast.Attribute): + return None + + v = node.value + chain = [node.attr] + while isinstance(v, ast.Attribute): + chain.append(v.attr) + v = v.value + if not isinstance(v, ast.Name): + return None + chain.append(v.id) + + for attr in reversed(chain): + try: + if isinstance(scope, dict): + scope = scope[attr] + else: + scope = getattr(scope, attr) + except (KeyError, AttributeError): + return None + return scope diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 5ebd5ff70d..0b45a5816b 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -453,7 +453,9 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . ] runtime._current_global_context = None - def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args) -> Any: + def launch_kernel( + self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args, qd_stream=None + ) -> Any: assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided" callbacks: list[Callable[[], None]] = [] @@ -567,9 +569,21 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled self.src_ll_cache_observations.cache_stored = True self._last_compiled_kernel_data = compiled_kernel_data launch_ctx.use_graph = self.use_graph and _GRAPH_ENABLED + if self.use_graph and qd_stream is not None: + raise RuntimeError( + "qd_stream is not compatible with graph=True kernels. " + "See docs/source/user_guide/streams.md for details." + ) if self.graph_do_while_arg is not None and hasattr(self, "_graph_do_while_cpp_arg_id"): launch_ctx.graph_do_while_arg_id = self._graph_do_while_cpp_arg_id - prog.launch_kernel(compiled_kernel_data, launch_ctx) + stream_handle = qd_stream.handle if qd_stream is not None else 0 + if stream_handle: + prog.set_current_cuda_stream(stream_handle) + try: + prog.launch_kernel(compiled_kernel_data, launch_ctx) + finally: + if stream_handle: + prog.set_current_cuda_stream(0) except Exception as e: e = handle_exception_from_cpp(e) if impl.get_runtime().print_full_traceback: @@ -581,6 +595,8 @@ def launch_kernel(self, key, t_kernel: KernelCxx, compiled_kernel_data: Compiled return_type = self.return_type if return_type or self.has_print: + if qd_stream is not None and self.has_print and not return_type: + qd_stream.synchronize() runtime_ops.sync() if not return_type: @@ -647,6 +663,17 @@ def ensure_compiled(self, *py_args: tuple[Any, ...]) -> tuple[Callable, int, Aut # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU) @_shell_pop_print def __call__(self, *py_args, **kwargs) -> Any: + qd_stream = kwargs.pop("qd_stream", None) + if qd_stream is not None and self.autodiff_mode != _NONE: + raise RuntimeError( + "qd_stream is not compatible with autodiff kernels. Streams cannot be used with " + "reverse-mode or forward-mode differentiation." + ) + if qd_stream is not None and self.runtime.target_tape: + raise RuntimeError( + "qd_stream is not compatible with autograd Tape. Launch the kernel outside the Tape " + "context, or omit qd_stream." 
+ ) if impl.get_runtime()._arch == _ARCH_PYTHON: return self.func(*py_args, **kwargs) config = impl.current_cfg() @@ -709,7 +736,7 @@ def __call__(self, *py_args, **kwargs) -> Any: kernel_cpp = self.materialized_kernels[key] compiled_kernel_data = self.compiled_kernel_data_by_key.get(key, None) self.launch_observations.found_kernel_in_materialize_cache = compiled_kernel_data is not None - ret = self.launch_kernel(key, kernel_cpp, compiled_kernel_data, *py_args) + ret = self.launch_kernel(key, kernel_cpp, compiled_kernel_data, *py_args, qd_stream=qd_stream) if compiled_kernel_data is None: assert self._last_compiled_kernel_data is not None self.compiled_kernel_data_by_key[key] = self._last_compiled_kernel_data diff --git a/python/quadrants/lang/runtime_ops.py b/python/quadrants/lang/runtime_ops.py index 0ecd122f56..71919e2379 100644 --- a/python/quadrants/lang/runtime_ops.py +++ b/python/quadrants/lang/runtime_ops.py @@ -4,8 +4,10 @@ def sync(): - """Blocks the calling thread until all the previously - launched Quadrants kernels have completed. + """Synchronizes the default stream. + + Blocks the calling thread until all work on the default GPU stream has completed. Kernels launched on explicit + streams created via :func:`quadrants.create_stream` are **not** waited on — call ``stream.synchronize()`` for those. """ impl.get_runtime().sync() diff --git a/python/quadrants/lang/stream.py b/python/quadrants/lang/stream.py new file mode 100644 index 0000000000..3f734587b3 --- /dev/null +++ b/python/quadrants/lang/stream.py @@ -0,0 +1,188 @@ +import weakref +from contextlib import contextmanager + +from quadrants.lang import impl + + +def _get_prog_weakref(): + return weakref.ref(impl.get_runtime().prog) + + +class Stream: + """Wraps a backend-specific GPU stream for concurrent kernel execution. + + On backends without native streams (e.g. CPU), this is a no-op object. Call destroy() explicitly or use as + a context manager to ensure cleanup. + """ + + def __init__(self, handle: int, prog_ref: weakref.ref | None = None): + self._handle = handle + self._prog_ref = prog_ref + + @property + def handle(self) -> int: + return self._handle + + def _prog(self): + """Resolve the owning Program, or None if the owner was collected.""" + if self._prog_ref is not None: + return self._prog_ref() + return impl.get_runtime().prog + + def synchronize(self): + """Block until all operations on this stream complete.""" + prog = self._prog() + if prog is None: + raise RuntimeError("Stream's owning Program has been destroyed (e.g. after qd.reset())") + prog.stream_synchronize(self._handle) + + def _destroy_prog(self): + """Resolve a Program for resource cleanup. + + Falls back to the current runtime when the owner has been collected, which is safe because CUDAContext is a + singleton so the CUDA stream handle remains valid. + """ + prog = self._prog() + if prog is None: + try: + return impl.get_runtime().prog + except Exception: + return None + return prog + + def destroy(self): + """Explicitly destroy the stream. Safe to call multiple times. + + No-op for streams wrapping external handles (created via Stream(ptr) without a prog_ref). 
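+
+        A hypothetical interop sketch (assumes PyTorch with CUDA; ``torch.cuda.current_stream().cuda_stream``
+        exposes the raw driver handle)::
+
+            import torch
+
+            ext = Stream(torch.cuda.current_stream().cuda_stream)  # wrapped, not owned (no prog_ref)
+            my_kernel(qd_stream=ext)
+            ext.destroy()  # no-op: the handle remains owned by PyTorch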
+ """ + if self._handle != 0 and self._prog_ref is not None: + prog = self._destroy_prog() + if prog is not None: + prog.stream_destroy(self._handle) + self._handle = 0 + + def __del__(self): + if self._handle != 0 and self._prog_ref is not None: + prog = self._destroy_prog() + if prog is not None: + try: + prog.stream_destroy(self._handle) + self._handle = 0 + except Exception: + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() + + +class Event: + """Wraps a backend-specific GPU event for stream synchronization. + + On backends without native events (e.g. CPU), this is a no-op object. Call destroy() explicitly or use as + a context manager to ensure cleanup. + """ + + def __init__(self, handle: int, prog_ref: weakref.ref | None = None): + self._handle = handle + self._prog_ref = prog_ref + + @property + def handle(self) -> int: + return self._handle + + def _prog(self): + """Resolve the owning Program, or None if the owner was collected.""" + if self._prog_ref is not None: + return self._prog_ref() + return impl.get_runtime().prog + + def _require_prog(self): + prog = self._prog() + if prog is None: + raise RuntimeError("Event's owning Program has been destroyed (e.g. after qd.reset())") + return prog + + def record(self, qd_stream: Stream | None = None): + """Record this event on a stream. None means the default stream.""" + stream_handle = qd_stream.handle if qd_stream is not None else 0 + self._require_prog().event_record(self._handle, stream_handle) + + def wait(self, qd_stream: Stream | None = None): + """Make a stream wait for this event. None means the default stream.""" + stream_handle = qd_stream.handle if qd_stream is not None else 0 + self._require_prog().stream_wait_event(stream_handle, self._handle) + + def synchronize(self): + """Block the host until this event has been reached.""" + self._require_prog().event_synchronize(self._handle) + + def _destroy_prog(self): + """Resolve a Program for resource cleanup. + + Falls back to the current runtime when the owner has been collected, which is safe because CUDAContext is a + singleton so the CUDA event handle remains valid. + """ + prog = self._prog() + if prog is None: + try: + return impl.get_runtime().prog + except Exception: + return None + return prog + + def destroy(self): + """Explicitly destroy the event. Safe to call multiple times. + + No-op for events wrapping external handles (created via Event(ptr) without a prog_ref). + """ + if self._handle != 0 and self._prog_ref is not None: + prog = self._destroy_prog() + if prog is not None: + prog.event_destroy(self._handle) + self._handle = 0 + + def __del__(self): + if self._handle != 0 and self._prog_ref is not None: + prog = self._destroy_prog() + if prog is not None: + try: + prog.event_destroy(self._handle) + self._handle = 0 + except Exception: + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + self.destroy() + + +def create_stream() -> Stream: + """Create a new GPU stream for concurrent kernel execution.""" + prog = impl.get_runtime().prog + handle = prog.stream_create() + return Stream(handle, _get_prog_weakref()) + + +def create_event() -> Event: + """Create a new GPU event for stream synchronization.""" + prog = impl.get_runtime().prog + handle = prog.event_create() + return Event(handle, _get_prog_weakref()) + + +@contextmanager +def stream_parallel(): + """Run top-level for loops in this block on separate GPU streams. + + Used inside @qd.kernel. 
At Python runtime (outside kernels), this is a no-op. During kernel compilation, the AST + transformer calls into the C++ ASTBuilder to tag loops with a stream-parallel group ID. + """ + yield + + +__all__ = ["Stream", "Event", "create_stream", "create_event", "stream_parallel"] diff --git a/quadrants/analysis/gen_offline_cache_key.cpp b/quadrants/analysis/gen_offline_cache_key.cpp index 66f03aab20..24fa3ce435 100644 --- a/quadrants/analysis/gen_offline_cache_key.cpp +++ b/quadrants/analysis/gen_offline_cache_key.cpp @@ -377,6 +377,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(stmt->strictly_serialized); emit(stmt->mem_access_opt); emit(stmt->block_dim); + emit(stmt->stream_parallel_group_id); emit(stmt->body.get()); } diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 7eb23a7a2e..df8a9aaeae 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -351,6 +351,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { current_task->grid_dim = num_SMs * query_max_block_per_sm; } current_task->block_dim = stmt->block_dim; + current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); // Host-side adstack sizing, same scheme as codegen_cuda: tight `grid_dim * block_dim` for diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 64d5b0f283..87fbb9abc9 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -638,6 +638,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } current_task->block_dim = stmt->block_dim; current_task->dynamic_shared_array_bytes = dynamic_shared_array_bytes; + current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); // Host-side adstack sizing. For non-range_for and for const-bound range_for the launcher uses diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h index 4da1510113..ef5fd19201 100644 --- a/quadrants/codegen/llvm/llvm_compiled_data.h +++ b/quadrants/codegen/llvm/llvm_compiled_data.h @@ -107,6 +107,7 @@ class OffloadedTask { int block_dim{0}; int grid_dim{0}; int dynamic_shared_array_bytes{0}; + int stream_parallel_group_id{0}; AdStackSizingInfo ad_stack{}; // Snode IDs this task writes to (read-modify-write counts as a write). 
Computed at codegen time @@ -132,9 +133,22 @@ class OffloadedTask { explicit OffloadedTask(const std::string &name = "", int block_dim = 0, int grid_dim = 0, - int dynamic_shared_array_bytes = 0) - : name(name), block_dim(block_dim), grid_dim(grid_dim), dynamic_shared_array_bytes(dynamic_shared_array_bytes) {}; - QD_IO_DEF(name, block_dim, grid_dim, dynamic_shared_array_bytes, ad_stack, snode_writes, arr_writes, arr_reads); + int dynamic_shared_array_bytes = 0, + int stream_parallel_group_id = 0) + : name(name), + block_dim(block_dim), + grid_dim(grid_dim), + dynamic_shared_array_bytes(dynamic_shared_array_bytes), + stream_parallel_group_id(stream_parallel_group_id) {}; + QD_IO_DEF(name, + block_dim, + grid_dim, + dynamic_shared_array_bytes, + stream_parallel_group_id, + ad_stack, + snode_writes, + arr_writes, + arr_reads); }; struct LLVMCompiledTask { diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index 4e118753ee..3c750d4ff9 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -110,6 +110,7 @@ FrontendForStmt::FrontendForStmt(const FrontendForStmt &o) strictly_serialized(o.strictly_serialized), mem_access_opt(o.mem_access_opt), block_dim(o.block_dim), + stream_parallel_group_id(o.stream_parallel_group_id), loop_name(o.loop_name) { } @@ -118,6 +119,7 @@ void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { strictly_serialized = config.strictly_serialized; mem_access_opt = config.mem_access_opt; block_dim = config.block_dim; + stream_parallel_group_id = config.stream_parallel_group_id; loop_name = config.loop_name; if (arch == Arch::cuda || arch == Arch::amdgpu) { num_cpu_threads = 1; @@ -1390,6 +1392,7 @@ void ASTBuilder::create_assert_stmt(const Expr &cond, } void ASTBuilder::begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e, const DebugInfo &dbg_info) { + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(i, s, e, arch_, for_loop_dec_.config, dbg_info); auto stmt = stmt_unique.get(); this->insert(std::move(stmt_unique)); @@ -1403,6 +1406,7 @@ void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars, QD_WARN_IF(for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(loop_vars, snode, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); auto stmt = stmt_unique.get(); @@ -1416,6 +1420,7 @@ void ASTBuilder::begin_frontend_struct_for_on_external_tensor(const ExprGroup &l QD_WARN_IF(for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(loop_vars, external_tensor, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); @@ -1431,6 +1436,7 @@ void ASTBuilder::begin_frontend_mesh_for(const Expr &i, QD_WARN_IF(for_loop_dec_.config.strictly_serialized, "ti.loop_config(serialize=True) does not have effect on the mesh for. 
" "The execution order is not guaranteed."); + for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; auto stmt_unique = std::make_unique(ExprGroup(i), mesh_ptr, element_type, arch_, for_loop_dec_.config, dbg_info); for_loop_dec_.reset(); diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 7d2c7bd9df..b4ad04a9b5 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -23,6 +23,7 @@ struct ForLoopConfig { MemoryAccessOptions mem_access_opt; int block_dim{0}; bool uniform{false}; + int stream_parallel_group_id{0}; std::string loop_name{""}; }; @@ -198,6 +199,7 @@ class FrontendForStmt : public Stmt { bool strictly_serialized; MemoryAccessOptions mem_access_opt; int block_dim; + int stream_parallel_group_id{0}; std::string loop_name; FrontendForStmt(const ExprGroup &loop_vars, @@ -887,6 +889,7 @@ class ASTBuilder { config.mem_access_opt.clear(); config.block_dim = 0; config.strictly_serialized = false; + config.stream_parallel_group_id = 0; config.loop_name.clear(); } }; @@ -897,6 +900,8 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; + int stream_parallel_group_counter_{0}; + int current_stream_parallel_group_id_{0}; public: ASTBuilder(Block *initial, Arch arch, bool is_kernel) : is_kernel_(is_kernel), arch_(arch) { @@ -1022,6 +1027,15 @@ class ASTBuilder { for_loop_dec_.reset(); } + void begin_stream_parallel() { + QD_ERROR_IF(current_stream_parallel_group_id_ != 0, "Nested stream_parallel blocks are not supported"); + current_stream_parallel_group_id_ = ++stream_parallel_group_counter_; + } + + void end_stream_parallel() { + current_stream_parallel_group_id_ = 0; + } + Identifier get_next_id(const std::string &name = "") { return Identifier(id_counter_++, name); } diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index 75f66a7475..9adebe8e87 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -222,6 +222,7 @@ std::unique_ptr RangeForStmt::clone() const { auto new_stmt = std::make_unique(begin, end, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim, strictly_serialized); new_stmt->reversed = reversed; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; new_stmt->loop_name = loop_name; return new_stmt; } @@ -243,6 +244,7 @@ StructForStmt::StructForStmt(SNode *snode, std::unique_ptr StructForStmt::clone() const { auto new_stmt = std::make_unique(snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; new_stmt->loop_name = loop_name; return new_stmt; } @@ -402,6 +404,7 @@ std::unique_ptr OffloadedStmt::clone() const { new_stmt->tls_size = tls_size; new_stmt->bls_size = bls_size; new_stmt->mem_access_opt = mem_access_opt; + new_stmt->stream_parallel_group_id = stream_parallel_group_id; new_stmt->loop_name = loop_name; // Shared-pointer copy: the captured trip-count `SizeExpr` is read-only after `determine_ad_stack_size` // populates it in `compile_to_offloads`, and LLVM codegen clones each offload at `codegen.cpp:68` diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 3768f52bf1..2426396dab 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -955,6 +955,7 @@ class RangeForStmt : public Stmt { int block_dim; bool strictly_serialized; std::string range_hint; + int stream_parallel_group_id{0}; std::string loop_name; RangeForStmt(Stmt *begin, @@ -977,7 
+978,14 @@ class RangeForStmt : public Stmt { std::unique_ptr clone() const override; - QD_STMT_DEF_FIELDS(begin, end, reversed, is_bit_vectorized, num_cpu_threads, block_dim, strictly_serialized); + QD_STMT_DEF_FIELDS(begin, + end, + reversed, + is_bit_vectorized, + num_cpu_threads, + block_dim, + strictly_serialized, + stream_parallel_group_id); QD_DEFINE_ACCEPT }; @@ -996,6 +1004,7 @@ class StructForStmt : public Stmt { int num_cpu_threads; int block_dim; MemoryAccessOptions mem_access_opt; + int stream_parallel_group_id{0}; std::string loop_name; StructForStmt(SNode *snode, @@ -1010,7 +1019,13 @@ class StructForStmt : public Stmt { std::unique_ptr clone() const override; - QD_STMT_DEF_FIELDS(snode, index_offsets, is_bit_vectorized, num_cpu_threads, block_dim, mem_access_opt); + QD_STMT_DEF_FIELDS(snode, + index_offsets, + is_bit_vectorized, + num_cpu_threads, + block_dim, + mem_access_opt, + stream_parallel_group_id); QD_DEFINE_ACCEPT }; @@ -1352,6 +1367,7 @@ class OffloadedStmt : public Stmt { std::size_t tls_size{1}; // avoid allocating dynamic memory with 0 byte std::size_t bls_size{0}; MemoryAccessOptions mem_access_opt; + int stream_parallel_group_id{0}; // Pre-chunking loop trip-count `SizeExpr` captured by `determine_ad_stack_size`. Set on adstack-bearing // range-for tasks before `make_cpu_multithreaded_range_for` rewrites the loop into per-thread chunks, so the @@ -1399,7 +1415,8 @@ class OffloadedStmt : public Stmt { reversed, num_cpu_threads, index_offsets, - mem_access_opt); + mem_access_opt, + stream_parallel_group_id); QD_DEFINE_ACCEPT }; diff --git a/quadrants/program/program.cpp b/quadrants/program/program.cpp index d994cc2fac..17b0513e51 100644 --- a/quadrants/program/program.cpp +++ b/quadrants/program/program.cpp @@ -66,6 +66,7 @@ Program::Program(Arch desired_arch) config = default_compile_config; config.arch = desired_arch; config.fit(); + stream_manager_ = StreamManager(config.arch); profiler = make_profiler(config.arch, config.kernel_profiler); if (arch_uses_llvm(config.arch)) { diff --git a/quadrants/program/program.h b/quadrants/program/program.h index 3d8b7b7425..ce4e5142f4 100644 --- a/quadrants/program/program.h +++ b/quadrants/program/program.h @@ -25,6 +25,7 @@ #include "quadrants/program/kernel_profiler.h" #include "quadrants/program/snode_expr_utils.h" #include "quadrants/program/snode_rw_accessors_bank.h" +#include "quadrants/program/program_stream.h" #include "quadrants/program/context.h" #include "quadrants/struct/snode_tree.h" #include "quadrants/system/threading.h" @@ -343,6 +344,10 @@ class QD_DLL_EXPORT Program { return ndarrays_.size(); } + StreamManager &stream_manager() { + return stream_manager_; + } + // TODO(zhanlue): Move these members and corresponding interfaces to ProgramImpl Ideally, Program should serve as a // pure interface class and all the implementations should fall inside ProgramImpl // @@ -351,6 +356,7 @@ class QD_DLL_EXPORT Program { private: CompileConfig compile_config_; + StreamManager stream_manager_{Arch::x64}; // re-initialized in constructor after arch is known uint64 ndarray_writer_counter_{0}; uint64 ndarray_reader_counter_{0}; diff --git a/quadrants/program/program_stream.cpp b/quadrants/program/program_stream.cpp new file mode 100644 index 0000000000..9686a86332 --- /dev/null +++ b/quadrants/program/program_stream.cpp @@ -0,0 +1,170 @@ +// StreamManager implementation and Program delegation. 
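+// All entry points make the owning device context current before issuing driver calls, so they are safe to use
+// from any host thread.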
+ +#include "program_stream.h" + +#ifdef QD_WITH_CUDA +#include "quadrants/rhi/cuda/cuda_driver.h" +#include "quadrants/rhi/cuda/cuda_context.h" +#endif + +#ifdef QD_WITH_AMDGPU +#include "quadrants/rhi/amdgpu/amdgpu_driver.h" +#include "quadrants/rhi/amdgpu/amdgpu_context.h" +#endif + +namespace quadrants::lang { + +// --------------------------------------------------------------------------- +// StreamManager +// --------------------------------------------------------------------------- + +uint64 StreamManager::create_stream() { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda) { + CUDAContext::get_instance().make_current(); + void *stream = nullptr; + CUDADriver::get_instance().stream_create(&stream, 0x1 /*CU_STREAM_NON_BLOCKING*/); + return reinterpret_cast(stream); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu) { + AMDGPUContext::get_instance().make_current(); + void *stream = nullptr; + AMDGPUDriver::get_instance().stream_create(&stream, 0x1 /*HIP_STREAM_NON_BLOCKING*/); + return reinterpret_cast(stream); + } +#endif + return 0; +} + +void StreamManager::destroy_stream(uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda && stream_handle != 0) { + CUDAContext::get_instance().make_current(); + CUDADriver::get_instance().stream_destroy(reinterpret_cast(stream_handle)); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu && stream_handle != 0) { + AMDGPUContext::get_instance().make_current(); + AMDGPUDriver::get_instance().stream_destroy(reinterpret_cast(stream_handle)); + } +#endif +} + +void StreamManager::synchronize_stream(uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda) { + CUDAContext::get_instance().make_current(); + CUDADriver::get_instance().stream_synchronize(reinterpret_cast(stream_handle)); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu) { + AMDGPUContext::get_instance().make_current(); + AMDGPUDriver::get_instance().stream_synchronize(reinterpret_cast(stream_handle)); + } +#endif +} + +void StreamManager::set_current_stream(uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda) { + CUDAContext::get_instance().make_current(); + CUDAContext::get_instance().set_stream(reinterpret_cast(stream_handle)); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu) { + AMDGPUContext::get_instance().make_current(); + AMDGPUContext::get_instance().set_stream(reinterpret_cast(stream_handle)); + } +#endif +} + +uint64 StreamManager::create_event() { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda) { + CUDAContext::get_instance().make_current(); + void *event = nullptr; + CUDADriver::get_instance().event_create(&event, 0x02 /*CU_EVENT_DISABLE_TIMING*/); + return reinterpret_cast(event); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu) { + AMDGPUContext::get_instance().make_current(); + void *event = nullptr; + AMDGPUDriver::get_instance().event_create(&event, 0x02 /*hipEventDisableTiming*/); + return reinterpret_cast(event); + } +#endif + return 0; +} + +void StreamManager::destroy_event(uint64 event_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda && event_handle != 0) { + CUDAContext::get_instance().make_current(); + CUDADriver::get_instance().event_destroy(reinterpret_cast(event_handle)); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu && event_handle != 0) { + AMDGPUContext::get_instance().make_current(); + AMDGPUDriver::get_instance().event_destroy(reinterpret_cast(event_handle)); + } +#endif +} + +void StreamManager::record_event(uint64 
event_handle, uint64 stream_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda && event_handle != 0) { + CUDAContext::get_instance().make_current(); + CUDADriver::get_instance().event_record(reinterpret_cast(event_handle), + reinterpret_cast(stream_handle)); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu && event_handle != 0) { + AMDGPUContext::get_instance().make_current(); + AMDGPUDriver::get_instance().event_record(reinterpret_cast(event_handle), + reinterpret_cast(stream_handle)); + } +#endif +} + +void StreamManager::synchronize_event(uint64 event_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda && event_handle != 0) { + CUDAContext::get_instance().make_current(); + CUDADriver::get_instance().event_synchronize(reinterpret_cast(event_handle)); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu && event_handle != 0) { + AMDGPUContext::get_instance().make_current(); + AMDGPUDriver::get_instance().event_synchronize(reinterpret_cast(event_handle)); + } +#endif +} + +void StreamManager::stream_wait_event(uint64 stream_handle, uint64 event_handle) { +#ifdef QD_WITH_CUDA + if (arch_ == Arch::cuda && event_handle != 0) { + CUDAContext::get_instance().make_current(); + CUDADriver::get_instance().stream_wait_event(reinterpret_cast(stream_handle), + reinterpret_cast(event_handle), 0 /*flags*/); + } +#endif +#ifdef QD_WITH_AMDGPU + if (arch_ == Arch::amdgpu && event_handle != 0) { + AMDGPUContext::get_instance().make_current(); + AMDGPUDriver::get_instance().stream_wait_event(reinterpret_cast(stream_handle), + reinterpret_cast(event_handle), 0 /*flags*/); + } +#endif +} + +} // namespace quadrants::lang diff --git a/quadrants/program/program_stream.h b/quadrants/program/program_stream.h new file mode 100644 index 0000000000..69265c26b3 --- /dev/null +++ b/quadrants/program/program_stream.h @@ -0,0 +1,31 @@ +// StreamManager — manages CUDA stream and event lifecycle, isolated from Program so that backend-specific GPU +// plumbing does not pollute the core Program interface. 
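+// Handles are opaque uint64 values (bit-cast CUstream/hipStream_t and CUevent/hipEvent_t pointers). A handle of 0
+// denotes the default stream, and every operation is a no-op on backends without native stream support.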
+ +#pragma once + +#include "quadrants/common/core.h" +#include "quadrants/util/lang_util.h" + +namespace quadrants::lang { + +class StreamManager { + public: + explicit StreamManager(Arch arch) : arch_(arch) { + } + + uint64 create_stream(); + void destroy_stream(uint64 stream_handle); + void synchronize_stream(uint64 stream_handle); + void set_current_stream(uint64 stream_handle); + + uint64 create_event(); + void destroy_event(uint64 event_handle); + void record_event(uint64 event_handle, uint64 stream_handle); + void synchronize_event(uint64 event_handle); + void stream_wait_event(uint64 stream_handle, uint64 event_handle); + + private: + Arch arch_; +}; + +} // namespace quadrants::lang diff --git a/quadrants/python/export.h b/quadrants/python/export.h index 331c35b4b6..92736daedf 100644 --- a/quadrants/python/export.h +++ b/quadrants/python/export.h @@ -21,6 +21,10 @@ #include "quadrants/common/core.h" +namespace quadrants::lang { +class Program; +} // namespace quadrants::lang + namespace quadrants { namespace py = pybind11; @@ -33,4 +37,6 @@ void export_math(py::module &m); void export_misc(py::module &m); +void export_stream(py::module &m, py::class_ &program_class); + } // namespace quadrants diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index 2d352f4473..647f1f5b70 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -310,7 +310,9 @@ void export_lang(py::module &m) { .def("strictly_serialize", &ASTBuilder::strictly_serialize) .def("block_dim", &ASTBuilder::block_dim) .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag) - .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag); + .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag) + .def("begin_stream_parallel", &ASTBuilder::begin_stream_parallel) + .def("end_stream_parallel", &ASTBuilder::end_stream_parallel); auto device_capability_config = py::class_(m, "DeviceCapabilityConfig").def("get", &DeviceCapabilityConfig::get); @@ -318,8 +320,8 @@ void export_lang(py::module &m) { auto compiled_kernel_data = py::class_(m, "CompiledKernelData") .def("_debug_dump_to_string", &CompiledKernelData::debug_dump_to_string); - py::class_(m, "Program") - .def(py::init<>()) + auto program_class = py::class_(m, "Program"); + program_class.def(py::init<>()) .def( "ndarray_to_dlpack", [](Program *program, pybind11::object owner, Ndarray *ndarray, const std::vector &layout, @@ -422,6 +424,7 @@ void export_lang(py::module &m) { [](Program *program) { return program->adstack_cache().max_reducer_dispatch_count(); }) .def("_reset_max_reducer_dispatch_count", [](Program *program) { program->adstack_cache().reset_max_reducer_dispatch_count(); }); + export_stream(m, program_class); py::class_(m, "CompileResult") .def_property_readonly( diff --git a/quadrants/python/export_stream.cpp b/quadrants/python/export_stream.cpp new file mode 100644 index 0000000000..66b3c8a3d7 --- /dev/null +++ b/quadrants/python/export_stream.cpp @@ -0,0 +1,25 @@ +/******************************************************************************* + Copyright (c) The Quadrants Authors (2016- ). All Rights Reserved. + The use of this software is governed by the LICENSE file. 
+*******************************************************************************/ + +#include "quadrants/python/export.h" +#include "quadrants/program/program.h" + +namespace quadrants { + +void export_stream(py::module &m, py::class_ &program_class) { + using lang::Program; + program_class.def("stream_create", [](Program *p) { return p->stream_manager().create_stream(); }) + .def("stream_destroy", [](Program *p, uint64 h) { p->stream_manager().destroy_stream(h); }) + .def("stream_synchronize", [](Program *p, uint64 h) { p->stream_manager().synchronize_stream(h); }) + .def("set_current_cuda_stream", [](Program *p, uint64 h) { p->stream_manager().set_current_stream(h); }) + .def("event_create", [](Program *p) { return p->stream_manager().create_event(); }) + .def("event_destroy", [](Program *p, uint64 h) { p->stream_manager().destroy_event(h); }) + .def("event_record", [](Program *p, uint64 eh, uint64 sh) { p->stream_manager().record_event(eh, sh); }) + .def("event_synchronize", [](Program *p, uint64 h) { p->stream_manager().synchronize_event(h); }) + .def("stream_wait_event", + [](Program *p, uint64 sh, uint64 eh) { p->stream_manager().stream_wait_event(sh, eh); }); +} + +} // namespace quadrants diff --git a/quadrants/rhi/amdgpu/amdgpu_context.cpp b/quadrants/rhi/amdgpu/amdgpu_context.cpp index ae5f40d1f2..5748895a41 100644 --- a/quadrants/rhi/amdgpu/amdgpu_context.cpp +++ b/quadrants/rhi/amdgpu/amdgpu_context.cpp @@ -13,6 +13,8 @@ namespace quadrants { namespace lang { +thread_local void *AMDGPUContext::stream_ = nullptr; + AMDGPUContext::AMDGPUContext() : driver_(AMDGPUDriver::get_instance_without_context()) { dev_count_ = 0; driver_.init(0); @@ -190,7 +192,7 @@ void AMDGPUContext::launch(void *func, if (grid_dim > 0) { std::lock_guard _(lock_); void *config[] = {(void *)0x01, (void *)packed_arg, (void *)0x02, (void *)&pack_size, (void *)0x03}; - driver_.launch_kernel(func, grid_dim, 1, 1, block_dim, 1, 1, dynamic_shared_mem_bytes, nullptr, nullptr, + driver_.launch_kernel(func, grid_dim, 1, 1, block_dim, 1, 1, dynamic_shared_mem_bytes, stream_, nullptr, reinterpret_cast(&config)); } std::free(packed_arg); @@ -199,11 +201,16 @@ void AMDGPUContext::launch(void *func, profiler_->stop(task_handle); if (debug_) { - driver_.stream_synchronize(nullptr); + driver_.stream_synchronize(stream_); } } AMDGPUContext::~AMDGPUContext() { + // Currently unreachable: singleton is heap-allocated via `new` in get_instance() and never deleted. 
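+  // Drain the pooled streams anyway so the cleanup path stays correct if the lifetime model ever changes.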
+ for (auto *s : stream_pool_) { + driver_.stream_destroy(s); + } + stream_pool_.clear(); if (context_) { driver_.device_primary_ctx_release(device_); } diff --git a/quadrants/rhi/amdgpu/amdgpu_context.h b/quadrants/rhi/amdgpu/amdgpu_context.h index 269106b077..9283afa078 100644 --- a/quadrants/rhi/amdgpu/amdgpu_context.h +++ b/quadrants/rhi/amdgpu/amdgpu_context.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "quadrants/program/kernel_profiler.h" #include "quadrants/rhi/amdgpu/amdgpu_driver.h" @@ -24,6 +25,8 @@ class AMDGPUContext { AMDGPUDriver &driver_; bool debug_{false}; bool supports_mem_pool_{false}; + static thread_local void *stream_; + std::vector stream_pool_; public: AMDGPUContext(); @@ -113,6 +116,31 @@ class AMDGPUContext { return std::unique_lock(lock_); } + void set_stream(void *stream) { + stream_ = stream; + } + + void *get_stream() const { + return stream_; + } + + void *acquire_stream() { + std::lock_guard _(lock_); + if (!stream_pool_.empty()) { + auto s = stream_pool_.back(); + stream_pool_.pop_back(); + return s; + } + void *s = nullptr; + AMDGPUDriver::get_instance().stream_create(&s, 0x1 /*HIP_STREAM_NON_BLOCKING*/); + return s; + } + + void release_stream(void *s) { + std::lock_guard _(lock_); + stream_pool_.push_back(s); + } + static AMDGPUContext &get_instance(); }; diff --git a/quadrants/rhi/amdgpu/amdgpu_device.cpp b/quadrants/rhi/amdgpu/amdgpu_device.cpp index 68c377a73a..280cd9f7e1 100644 --- a/quadrants/rhi/amdgpu/amdgpu_device.cpp +++ b/quadrants/rhi/amdgpu/amdgpu_device.cpp @@ -1,4 +1,5 @@ #include "quadrants/rhi/amdgpu/amdgpu_device.h" +#include "quadrants/rhi/amdgpu/amdgpu_context.h" #include "quadrants/rhi/llvm/device_memory_pool.h" #include "quadrants/jit/jit_module.h" @@ -93,11 +94,12 @@ uint64_t *AmdgpuDevice::allocate_llvm_runtime_memory_jit(const LlvmRuntimeAllocP // the kernel without writing to *result. To detect that here, zero the slot first so a null readback unambiguously // means "allocation failed" and we can surface a helpful host-side message instead of letting the downstream // hipMemset trip on the stale pointer with a cryptic hipErrorInvalidValue. 
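+  // Issuing the zeroing copy on the active stream orders it before the sizer kernel, and the stream_synchronize()
+  // below makes the device-to-host readback of the result slot safe.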
+ void *active_stream = AMDGPUContext::get_instance().get_stream(); uint64 zero = 0; - AMDGPUDriver::get_instance().memcpy_host_to_device(params.result_buffer, &zero, sizeof(uint64)); + AMDGPUDriver::get_instance().memcpy_host_to_device_async(params.result_buffer, &zero, sizeof(uint64), active_stream); params.runtime_jit->call("runtime_memory_allocate_aligned", params.runtime, params.size, quadrants_page_size, params.result_buffer); - AMDGPUDriver::get_instance().stream_synchronize(nullptr); + AMDGPUDriver::get_instance().stream_synchronize(active_stream); uint64 *ret{nullptr}; AMDGPUDriver::get_instance().memcpy_device_to_host(&ret, params.result_buffer, sizeof(uint64)); QD_ERROR_IF(ret == nullptr, @@ -123,7 +125,7 @@ void AmdgpuDevice::dealloc_memory(DeviceAllocation handle) { } QD_ASSERT(!info.is_imported); if (info.use_memory_pool) { - AMDGPUDriver::get_instance().mem_free_async(info.ptr, nullptr); + AMDGPUDriver::get_instance().mem_free(info.ptr); } else if (info.use_cached) { DeviceMemoryPool::get_instance(Arch::amdgpu, false /*merge_upon_release*/) .release(info.size, (uint64_t *)info.ptr, false); diff --git a/quadrants/rhi/amdgpu/amdgpu_driver_functions.inc.h b/quadrants/rhi/amdgpu/amdgpu_driver_functions.inc.h index 5665e4b588..c94a7f14db 100644 --- a/quadrants/rhi/amdgpu/amdgpu_driver_functions.inc.h +++ b/quadrants/rhi/amdgpu/amdgpu_driver_functions.inc.h @@ -15,8 +15,12 @@ PER_AMDGPU_FUNCTION(context_create, hipCtxCreate, void *, int, void *); PER_AMDGPU_FUNCTION(context_set_current, hipCtxSetCurrent, void *); PER_AMDGPU_FUNCTION(context_get_current, hipCtxGetCurrent, void **); +// Device synchronization +PER_AMDGPU_FUNCTION(device_synchronize, hipDeviceSynchronize); + // Stream management -PER_AMDGPU_FUNCTION(stream_create, hipStreamCreate, void **, uint32); +PER_AMDGPU_FUNCTION(stream_create, hipStreamCreateWithFlags, void **, uint32); +PER_AMDGPU_FUNCTION(stream_destroy, hipStreamDestroy, void *); // Memory management PER_AMDGPU_FUNCTION(memcpy_host_to_device, hipMemcpyHtoD, void *, void *, std::size_t); @@ -27,6 +31,8 @@ PER_AMDGPU_FUNCTION(memcpy_async, hipMemcpyAsync, void *, void *, std::size_t, u PER_AMDGPU_FUNCTION(memcpy_host_to_device_async, hipMemcpyHtoDAsync, void *, void *, std::size_t, void *); PER_AMDGPU_FUNCTION(memcpy_device_to_host_async, hipMemcpyDtoHAsync, void *, void *, std::size_t, void *); PER_AMDGPU_FUNCTION(malloc, hipMalloc, void **, std::size_t); +// hipMallocAsync/hipFreeAsync require ROCm >= 5.4; the AMDGPUDriver wrappers fall back to the synchronous variants +// on devices without memory-pool support. 
PER_AMDGPU_FUNCTION(malloc_async_impl, hipMallocAsync, void **, std::size_t, void *); PER_AMDGPU_FUNCTION(malloc_managed, hipMallocManaged, void **, std::size_t, uint32); PER_AMDGPU_FUNCTION(memset, hipMemset, void *, uint8, std::size_t); @@ -61,6 +67,7 @@ PER_AMDGPU_FUNCTION(kernel_get_occupancy, hipOccupancyMaxActiveBlocksPerMultipro // Stream management PER_AMDGPU_FUNCTION(stream_synchronize, hipStreamSynchronize, void *); +PER_AMDGPU_FUNCTION(stream_wait_event, hipStreamWaitEvent, void *, void *, uint32); // Event management PER_AMDGPU_FUNCTION(event_create, hipEventCreateWithFlags, void **, uint32); diff --git a/quadrants/rhi/amdgpu/amdgpu_profiler.cpp b/quadrants/rhi/amdgpu/amdgpu_profiler.cpp index 731d536bca..e963f7df20 100644 --- a/quadrants/rhi/amdgpu/amdgpu_profiler.cpp +++ b/quadrants/rhi/amdgpu/amdgpu_profiler.cpp @@ -59,8 +59,9 @@ void KernelProfilerAMDGPU::trace(KernelProfilerBase::TaskHandle &task_handle, } void KernelProfilerAMDGPU::stop(KernelProfilerBase::TaskHandle handle) { - AMDGPUDriver::get_instance().event_record(handle, 0); - AMDGPUDriver::get_instance().stream_synchronize(nullptr); + void *active_stream = AMDGPUContext::get_instance().get_stream(); + AMDGPUDriver::get_instance().event_record(handle, active_stream); + AMDGPUDriver::get_instance().stream_synchronize(active_stream); // get elapsed time and destroy events auto record = event_toolkit_->get_current_event_record(); @@ -154,7 +155,8 @@ KernelProfilerBase::TaskHandle EventToolkitAMDGPU::start_with_handle(const std:: AMDGPUDriver::get_instance().event_create(&(record.start_event), HIP_EVENT_DEFAULT); AMDGPUDriver::get_instance().event_create(&(record.stop_event), HIP_EVENT_DEFAULT); - AMDGPUDriver::get_instance().event_record((record.start_event), 0); + void *active_stream = AMDGPUContext::get_instance().get_stream(); + AMDGPUDriver::get_instance().event_record((record.start_event), active_stream); event_records_.push_back(record); if (!base_event_) { @@ -163,7 +165,7 @@ KernelProfilerBase::TaskHandle EventToolkitAMDGPU::start_with_handle(const std:: for (int i = 0; i < n_iters; i++) { void *e; AMDGPUDriver::get_instance().event_create(&e, HIP_EVENT_DEFAULT); - AMDGPUDriver::get_instance().event_record(e, 0); + AMDGPUDriver::get_instance().event_record(e, active_stream); AMDGPUDriver::get_instance().event_synchronize(e); auto final_t = Time::get_time(); if (i == n_iters - 1) { diff --git a/quadrants/rhi/cuda/cuda_context.cpp b/quadrants/rhi/cuda/cuda_context.cpp index a50b789650..d1a266a1a9 100644 --- a/quadrants/rhi/cuda/cuda_context.cpp +++ b/quadrants/rhi/cuda/cuda_context.cpp @@ -11,7 +11,9 @@ namespace quadrants::lang { -CUDAContext::CUDAContext() : profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()), stream_(nullptr) { +thread_local void *CUDAContext::stream_ = nullptr; + +CUDAContext::CUDAContext() : profiler_(nullptr), driver_(CUDADriver::get_instance_without_context()) { // CUDA initialization dev_count_ = 0; driver_.init(0); @@ -172,13 +174,11 @@ void CUDAContext::launch(void *func, } CUDAContext::~CUDAContext() { - // TODO: restore these? - /* - CUDADriver::get_instance().cuMemFree(context_buffer); - for (auto cudaModule: cudaModules) - CUDADriver::get_instance().cuModuleUnload(cudaModule); - CUDADriver::get_instance().cuCtxDestroy(context); - */ + // Currently unreachable: singleton is heap-allocated via `new` in get_instance() and never deleted. 
+ for (auto *s : stream_pool_) { + driver_.stream_destroy(s); + } + stream_pool_.clear(); } CUDAContext &CUDAContext::get_instance() { diff --git a/quadrants/rhi/cuda/cuda_context.h b/quadrants/rhi/cuda/cuda_context.h index 61b7ae0c72..aafb3ed12b 100644 --- a/quadrants/rhi/cuda/cuda_context.h +++ b/quadrants/rhi/cuda/cuda_context.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "quadrants/program/kernel_profiler.h" #include "quadrants/rhi/cuda/cuda_driver.h" @@ -32,7 +33,8 @@ class CUDAContext { bool supports_mem_pool_; bool supports_pageable_memory_access_; bool uses_host_page_tables_; - void *stream_; + static thread_local void *stream_; + std::vector stream_pool_; public: CUDAContext(); @@ -144,6 +146,23 @@ class CUDAContext { void *get_stream() const { return stream_; } + + void *acquire_stream() { + std::lock_guard _(lock_); + if (!stream_pool_.empty()) { + auto s = stream_pool_.back(); + stream_pool_.pop_back(); + return s; + } + void *s = nullptr; + CUDADriver::get_instance().stream_create(&s, 0x1 /*CU_STREAM_NON_BLOCKING*/); + return s; + } + + void release_stream(void *s) { + std::lock_guard _(lock_); + stream_pool_.push_back(s); + } }; } // namespace quadrants::lang diff --git a/quadrants/rhi/cuda/cuda_driver_functions.inc.h b/quadrants/rhi/cuda/cuda_driver_functions.inc.h index 2847f136c4..b4164b7c33 100644 --- a/quadrants/rhi/cuda/cuda_driver_functions.inc.h +++ b/quadrants/rhi/cuda/cuda_driver_functions.inc.h @@ -20,6 +20,7 @@ PER_CUDA_FUNCTION(context_set_limit, cuCtxSetLimit, int, std::size_t); // Stream management PER_CUDA_FUNCTION(stream_create, cuStreamCreate, void **, uint32); +PER_CUDA_FUNCTION(stream_destroy, cuStreamDestroy_v2, void *); // Memory management PER_CUDA_FUNCTION(memcpy_host_to_device, cuMemcpyHtoD_v2, void *, void *, std::size_t); @@ -52,8 +53,12 @@ PER_CUDA_FUNCTION(kernel_get_occupancy, cuOccupancyMaxActiveBlocksPerMultiproces PER_CUDA_FUNCTION(kernel_set_attribute, cuFuncSetAttribute, void *, CUfunction_attribute_enum, int); +// Context management +PER_CUDA_FUNCTION(context_synchronize, cuCtxSynchronize); + // Stream management PER_CUDA_FUNCTION(stream_synchronize, cuStreamSynchronize, void *); +PER_CUDA_FUNCTION(stream_wait_event, cuStreamWaitEvent, void *, void *, uint32); // Event management PER_CUDA_FUNCTION(event_create, cuEventCreate, void **, uint32) diff --git a/quadrants/runtime/amdgpu/kernel_launcher.cpp b/quadrants/runtime/amdgpu/kernel_launcher.cpp index deff808b2b..81a1f3ca70 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.cpp +++ b/quadrants/runtime/amdgpu/kernel_launcher.cpp @@ -1,3 +1,5 @@ +#include + #include "quadrants/runtime/amdgpu/kernel_launcher.h" #include "quadrants/rhi/amdgpu/amdgpu_context.h" #include "quadrants/program/adstack_size_expr_eval.h" @@ -78,8 +80,9 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, if (any_max_reducer_task) { executor->dispatch_max_reducers_for_tasks(offloaded_tasks, &ctx, context_pointer); } - std::size_t task_index = 0; - for (const auto &task : offloaded_tasks) { + + // Per-task adstack setup + grid-dim capping. Shared by serial and stream-parallel paths. + auto prepare_task = [&](std::size_t task_index, const OffloadedTask &task) -> int { int effective_grid_dim = task.grid_dim; if (!task.ad_stack.allocas.empty()) { // Pass the device-side `RuntimeContext` pointer through to the adstack sizer kernel. 
@@ -117,7 +120,6 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx,
         executor->ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n_threads_amdgpu, &ctx);
       }
     }
-    ++task_index;
     // Match the heap-row count resolved above: adstack-bearing tasks dispatch at most `kAdStackMaxConcurrentThreads`.
     // The runtime grid-strided loop walks the full element list / range with `i += grid_dim()` so a smaller grid
     // completes the same workload sequentially per slot.
@@ -130,9 +132,58 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx,
         effective_grid_dim = 1;
       }
     }
-    QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim);
-    amdgpu_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes,
-                          {(void *)&context_pointer}, {arg_size});
+    return effective_grid_dim;
+  };
+
+  auto *active_stream = AMDGPUContext::get_instance().get_stream();
+  for (size_t i = 0; i < offloaded_tasks.size();) {
+    const auto &task = offloaded_tasks[i];
+    if (task.stream_parallel_group_id == 0) {
+      int effective_grid_dim = prepare_task(i, task);
+      QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim);
+      amdgpu_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes,
+                            {(void *)&context_pointer}, {arg_size});
+      i++;
+    } else {
+      size_t group_start = i;
+      while (i < offloaded_tasks.size() && offloaded_tasks[i].stream_parallel_group_id != 0) {
+        i++;
+      }
+
+      std::map<int, void *> stream_by_id;
+      for (size_t j = group_start; j < i; j++) {
+        int sid = offloaded_tasks[j].stream_parallel_group_id;
+        if (stream_by_id.find(sid) == stream_by_id.end()) {
+          stream_by_id[sid] = AMDGPUContext::get_instance().acquire_stream();
+        }
+      }
+
+      try {
+        for (size_t j = group_start; j < i; j++) {
+          const auto &t = offloaded_tasks[j];
+          int effective_grid_dim = prepare_task(j, t);
+          AMDGPUContext::get_instance().set_stream(stream_by_id[t.stream_parallel_group_id]);
+          QD_TRACE("Launching kernel {}<<<{}, {}>>>", t.name, effective_grid_dim, t.block_dim);
+          amdgpu_module->launch(t.name, effective_grid_dim, t.block_dim, t.dynamic_shared_array_bytes,
+                                {(void *)&context_pointer}, {arg_size});
+        }
+
+        for (auto &[sid, s] : stream_by_id) {
+          AMDGPUDriver::get_instance().stream_synchronize(s);
+        }
+      } catch (...)
{ + for (auto &[sid, s] : stream_by_id) { + AMDGPUContext::get_instance().release_stream(s); + } + AMDGPUContext::get_instance().set_stream(active_stream); + throw; + } + for (auto &[sid, s] : stream_by_id) { + AMDGPUContext::get_instance().release_stream(s); + } + + AMDGPUContext::get_instance().set_stream(active_stream); + } } } @@ -145,7 +196,8 @@ void KernelLauncher::launch_offloaded_tasks_with_do_while(LaunchContextBuilder & do { launch_offloaded_tasks(ctx, amdgpu_module, offloaded_tasks, context_pointer, arg_size); counter_val = 0; - AMDGPUDriver::get_instance().stream_synchronize(nullptr); + auto *stream = AMDGPUContext::get_instance().get_stream(); + AMDGPUDriver::get_instance().stream_synchronize(stream); AMDGPUDriver::get_instance().memcpy_device_to_host(&counter_val, ctx.graph_do_while_flag_dev_ptr, sizeof(int32_t)); } while (counter_val != 0); } @@ -176,6 +228,8 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx std::unordered_map, ArgArrayPtrKeyHasher> transfers; std::unordered_map device_ptrs; + auto *active_stream = AMDGPUContext::get_instance().get_stream(); + char *device_result_buffer{nullptr}; // Here we have to guarantee the result_result_buffer isn't nullptr // It is interesting - The code following @@ -222,14 +276,16 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(devalloc); transfers[data_ptr_idx] = {data_ptr, devalloc}; - AMDGPUDriver::get_instance().memcpy_host_to_device((void *)device_ptrs[data_ptr_idx], data_ptr, arr_sz); + AMDGPUDriver::get_instance().memcpy_host_to_device_async((void *)device_ptrs[data_ptr_idx], data_ptr, arr_sz, + active_stream); if (grad_ptr != nullptr) { DeviceAllocation grad_devalloc = executor->allocate_memory_on_device(arr_sz, (uint64 *)device_result_buffer); device_ptrs[grad_ptr_idx] = executor->get_device_alloc_info_ptr(grad_devalloc); transfers[grad_ptr_idx] = {grad_ptr, grad_devalloc}; - AMDGPUDriver::get_instance().memcpy_host_to_device((void *)device_ptrs[grad_ptr_idx], grad_ptr, arr_sz); + AMDGPUDriver::get_instance().memcpy_host_to_device_async((void *)device_ptrs[grad_ptr_idx], grad_ptr, + arr_sz, active_stream); } else { device_ptrs[grad_ptr_idx] = nullptr; } @@ -259,35 +315,50 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx } } if (transfers.size() > 0) { - AMDGPUDriver::get_instance().stream_synchronize(nullptr); + AMDGPUDriver::get_instance().stream_synchronize(active_stream); } char *host_result_buffer = (char *)ctx.get_context().result_buffer; if (ctx.result_buffer_size > 0) { - // Malloc_Async and Free_Async are available after ROCm 5.4 ctx.get_context().result_buffer = (uint64 *)device_result_buffer; } + // Same explicit-stream race avoidance as the CUDA launcher: when active_stream != nullptr, allocate per-call + // ephemeral buffers so concurrent launches on different streams can't clobber each other. 
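The group-launch paths above take one stream per `stream_parallel_group_id` from a context-owned pool and return every stream after the per-group synchronize, including on the exception path. The sketch below restates that acquire/release pattern in isolation. It is a minimal model, not project code: `NativeStream` and `make_native_stream` are illustrative stand-ins for the driver handle and the `stream_create(..., CU_STREAM_NON_BLOCKING)` call, and the real pool lives behind the `CUDAContext` / `AMDGPUContext` lock.

```cpp
#include <mutex>
#include <vector>

using NativeStream = void *;

// Stand-in for the driver call (stream_create) that the real pool makes.
NativeStream make_native_stream() {
  return new int(0);  // placeholder handle; a real pool stores CUstream / hipStream_t
}

class StreamPool {
 public:
  NativeStream acquire() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!free_.empty()) {
      NativeStream s = free_.back();
      free_.pop_back();
      return s;  // recycle: stream creation is comparatively expensive
    }
    return make_native_stream();  // grow lazily on first demand
  }

  void release(NativeStream s) {
    std::lock_guard<std::mutex> guard(mutex_);
    free_.push_back(s);  // never destroyed here; a dtor would drain the pool
  }

 private:
  std::mutex mutex_;
  std::vector<NativeStream> free_;
};

int main() {
  StreamPool pool;
  NativeStream a = pool.acquire();  // created on demand
  pool.release(a);
  NativeStream b = pool.acquire();  // same handle, reused from the pool
  (void)b;
}
```

Releasing into the pool instead of destroying is what makes repeated `stream_parallel` invocations cheap: after the first group launch, later groups reuse the same non-blocking streams.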
+  const bool use_persistent_scratch = (active_stream == nullptr);
   char *device_arg_buffer = nullptr;
+  void *ephemeral_arg_buffer = nullptr;
   if (ctx.arg_buffer_size > 0) {
-    if (ctx.arg_buffer_size > launcher_ctx.arg_buffer_capacity) {
-      if (launcher_ctx.arg_buffer_dev_ptr != nullptr) {
-        AMDGPUDriver::get_instance().mem_free_async(launcher_ctx.arg_buffer_dev_ptr, nullptr);
+    if (use_persistent_scratch) {
+      if (ctx.arg_buffer_size > launcher_ctx.arg_buffer_capacity) {
+        if (launcher_ctx.arg_buffer_dev_ptr != nullptr) {
+          AMDGPUDriver::get_instance().mem_free_async(launcher_ctx.arg_buffer_dev_ptr, nullptr);
+        }
+        const std::size_t new_cap = std::max(ctx.arg_buffer_size, 2 * launcher_ctx.arg_buffer_capacity);
+        AMDGPUDriver::get_instance().malloc_async(&launcher_ctx.arg_buffer_dev_ptr, new_cap, nullptr);
+        launcher_ctx.arg_buffer_capacity = new_cap;
       }
-      const std::size_t new_cap = std::max(ctx.arg_buffer_size, 2 * launcher_ctx.arg_buffer_capacity);
-      AMDGPUDriver::get_instance().malloc_async(&launcher_ctx.arg_buffer_dev_ptr, new_cap, nullptr);
-      launcher_ctx.arg_buffer_capacity = new_cap;
+      device_arg_buffer = static_cast<char *>(launcher_ctx.arg_buffer_dev_ptr);
+    } else {
+      AMDGPUDriver::get_instance().malloc_async(&ephemeral_arg_buffer, ctx.arg_buffer_size, active_stream);
+      device_arg_buffer = static_cast<char *>(ephemeral_arg_buffer);
     }
-    device_arg_buffer = static_cast<char *>(launcher_ctx.arg_buffer_dev_ptr);
     AMDGPUDriver::get_instance().memcpy_host_to_device_async(device_arg_buffer, ctx.get_context().arg_buffer,
-                                                             ctx.arg_buffer_size, nullptr);
+                                                             ctx.arg_buffer_size, active_stream);
     ctx.get_context().arg_buffer = device_arg_buffer;
   }
 
   int arg_size = sizeof(RuntimeContext *);
-  if (launcher_ctx.runtime_context_dev_ptr == nullptr) {
-    AMDGPUDriver::get_instance().malloc_async(&launcher_ctx.runtime_context_dev_ptr, sizeof(RuntimeContext), nullptr);
+  void *ephemeral_context_ptr = nullptr;
+  void *context_pointer = nullptr;
+  if (use_persistent_scratch) {
+    if (launcher_ctx.runtime_context_dev_ptr == nullptr) {
+      AMDGPUDriver::get_instance().malloc_async(&launcher_ctx.runtime_context_dev_ptr, sizeof(RuntimeContext), nullptr);
+    }
+    context_pointer = launcher_ctx.runtime_context_dev_ptr;
+  } else {
+    AMDGPUDriver::get_instance().malloc_async(&ephemeral_context_ptr, sizeof(RuntimeContext), active_stream);
+    context_pointer = ephemeral_context_ptr;
   }
-  void *context_pointer = launcher_ctx.runtime_context_dev_ptr;
   AMDGPUDriver::get_instance().memcpy_host_to_device_async(context_pointer, &ctx.get_context(), sizeof(RuntimeContext),
-                                                           nullptr);
+                                                           active_stream);
 
   // Adstack-cache invalidation bump - see `bump_writes_for_kernel_llvm` in `program/adstack_size_expr_eval.{h,cpp}`.
   bump_writes_for_kernel_llvm(executor->get_program(), &ctx, offloaded_tasks);
@@ -299,22 +370,36 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx
     launch_offloaded_tasks(ctx, amdgpu_module, offloaded_tasks, context_pointer, arg_size);
   }
   QD_TRACE("Launching kernel");
-  // Persistent scratch: no per-launch free for arg_buffer. The scratch lives until the launcher is destroyed.
+  // Persistent scratch (default-stream path): no per-launch free for the per-handle `arg_buffer` / `runtime_context`
+  // or the launcher-global `result_buffer`. All live until launcher destruction; the dtor handles the final
+  // `mem_free_async`. Ephemeral buffers (explicit-stream path) are freed below.
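The grow-only branch above is the usual amortised-doubling policy: reallocate only when a request exceeds the current capacity, and grow to at least double so that repeated launches trigger only O(log n) reallocations. Below is a minimal host-side sketch of the same policy, with `std::malloc` / `std::free` standing in for the stream-ordered device allocator; the names are illustrative, not project API.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdlib>

struct PersistentScratch {
  void *ptr = nullptr;
  std::size_t capacity = 0;

  // Grow-only, amortised doubling; the returned pointer is reused across
  // launches and released once at teardown (the launcher dtor in the real code).
  void *ensure(std::size_t needed) {
    if (needed > capacity) {
      std::free(ptr);  // safe on first call: free(nullptr) is a no-op
      const std::size_t new_cap = std::max(needed, 2 * capacity);
      ptr = std::malloc(new_cap);
      capacity = new_cap;
    }
    return ptr;
  }
};

int main() {
  PersistentScratch scratch;
  scratch.ensure(100);  // capacity: 100
  scratch.ensure(150);  // grows to max(150, 200) = 200
  scratch.ensure(180);  // fits: no reallocation, same pointer reused
  std::free(scratch.ptr);
}
```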
if (ctx.result_buffer_size > 0) { - AMDGPUDriver::get_instance().memcpy_device_to_host(host_result_buffer, device_result_buffer, - ctx.result_buffer_size); + AMDGPUDriver::get_instance().memcpy_device_to_host_async(host_result_buffer, device_result_buffer, + ctx.result_buffer_size, active_stream); } - if (transfers.size()) { + if (transfers.size() > 0) { + AMDGPUDriver::get_instance().stream_synchronize(active_stream); for (auto itr = transfers.begin(); itr != transfers.end(); itr++) { auto &idx = itr->first; - auto arg_id = idx.arg_id; - AMDGPUDriver::get_instance().memcpy_device_to_host(itr->second.first, (void *)device_ptrs[idx], - ctx.array_runtime_sizes[arg_id]); + AMDGPUDriver::get_instance().memcpy_device_to_host_async(itr->second.first, (void *)device_ptrs[idx], + ctx.array_runtime_sizes[idx.arg_id], active_stream); + } + AMDGPUDriver::get_instance().stream_synchronize(active_stream); + for (auto itr = transfers.begin(); itr != transfers.end(); itr++) { executor->deallocate_memory_on_device(itr->second.second); } + } else if (ctx.result_buffer_size > 0) { + AMDGPUDriver::get_instance().stream_synchronize(active_stream); } // Persistent scratch: no per-launch free for the per-handle `arg_buffer` / `runtime_context` or the launcher-global // `result_buffer`. All three live until the launcher is destroyed; the dtor handles the final `mem_free_async`. + // Ephemeral buffers (explicit-stream path) are freed here. + if (ephemeral_arg_buffer != nullptr) { + AMDGPUDriver::get_instance().mem_free_async(ephemeral_arg_buffer, active_stream); + } + if (ephemeral_context_ptr != nullptr) { + AMDGPUDriver::get_instance().mem_free_async(ephemeral_context_ptr, active_stream); + } } KernelLauncher::~KernelLauncher() { diff --git a/quadrants/runtime/amdgpu/kernel_launcher.h b/quadrants/runtime/amdgpu/kernel_launcher.h index 0b12bba660..08e061ec93 100644 --- a/quadrants/runtime/amdgpu/kernel_launcher.h +++ b/quadrants/runtime/amdgpu/kernel_launcher.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "quadrants/codegen/llvm/compiled_kernel_data.h" #include "quadrants/runtime/llvm/kernel_launcher.h" @@ -49,7 +51,11 @@ class KernelLauncher : public LLVM::KernelLauncher { // child completes, before the parent kernel that would be the next reader). Grown amortised-doubling. void *persistent_result_buffer_dev_ptr_{nullptr}; std::size_t persistent_result_buffer_capacity_{0}; - std::vector contexts_; + // std::deque (not std::vector): `publish_adstack_metadata`'s host-eval branch recursively registers snode-reader + // kernels via this same launcher, calling `contexts_.resize()` while a parent `launch_llvm_kernel` frame still + // holds a reference into the container. std::deque never invalidates references on push_back / resize, so the + // parent's `launcher_ctx` reference survives the child's registration. 
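The reference-invalidation hazard this comment describes is easy to reproduce on the host, independent of the launcher. In the sketch below (illustrative `int` elements; the real container holds the launcher's per-handle context structs), growing a `std::vector` may move its storage and dangle an outstanding reference, while `std::deque` growth leaves existing elements in place:

```cpp
#include <deque>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> vec{1, 2, 3};
  int &vref = vec[0];
  vec.resize(1000);  // may reallocate: vref is now dangling, reading it is UB
  (void)vref;        // so we deliberately never use it again

  std::deque<int> deq{1, 2, 3};
  int &dref = deq[0];
  deq.resize(1000);           // deque growth never moves existing elements
  std::cout << dref << "\n";  // OK: prints 1
}
```

This mirrors the recursive-registration scenario: the parent frame's `launcher_ctx` reference must survive a `resize()` performed by a nested snode-reader registration.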
+ std::deque contexts_; public: ~KernelLauncher() override; diff --git a/quadrants/runtime/cuda/kernel_launcher.cpp b/quadrants/runtime/cuda/kernel_launcher.cpp index 122ea1e60c..ff9fbc03fd 100644 --- a/quadrants/runtime/cuda/kernel_launcher.cpp +++ b/quadrants/runtime/cuda/kernel_launcher.cpp @@ -1,6 +1,9 @@ +#include + #include "quadrants/runtime/cuda/kernel_launcher.h" #include "quadrants/runtime/cuda/cuda_utils.h" #include "quadrants/rhi/cuda/cuda_context.h" +#include "quadrants/rhi/cuda/cuda_driver.h" #include "quadrants/runtime/llvm/llvm_runtime_executor.h" #include "quadrants/program/adstack_size_expr_eval.h" #include "quadrants/program/program.h" @@ -34,15 +37,17 @@ std::size_t resolve_num_threads(const AdStackSizingInfo &info, LlvmRuntimeExecut std::int32_t begin = info.begin_const_value; std::int32_t end = info.end_const_value; if (info.begin_offset_bytes >= 0 || info.end_offset_bytes >= 0) { + auto *active_stream = CUDAContext::get_instance().get_stream(); auto *temp_dev_ptr = reinterpret_cast(executor->get_runtime_temporaries_device_ptr()); if (info.begin_offset_bytes >= 0) { - CUDADriver::get_instance().memcpy_device_to_host(&begin, temp_dev_ptr + info.begin_offset_bytes, - sizeof(std::int32_t)); + CUDADriver::get_instance().memcpy_device_to_host_async(&begin, temp_dev_ptr + info.begin_offset_bytes, + sizeof(std::int32_t), active_stream); } if (info.end_offset_bytes >= 0) { - CUDADriver::get_instance().memcpy_device_to_host(&end, temp_dev_ptr + info.end_offset_bytes, - sizeof(std::int32_t)); + CUDADriver::get_instance().memcpy_device_to_host_async(&end, temp_dev_ptr + info.end_offset_bytes, + sizeof(std::int32_t), active_stream); } + CUDADriver::get_instance().stream_synchronize(active_stream); } // Clamp the logical iteration count to the launched thread count: adstack slices are indexed by // `linear_thread_idx()` (`block_idx * block_dim + thread_idx`), so only `static_num_threads = grid_dim * block_dim` @@ -93,8 +98,9 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, if (any_max_reducer_task) { executor->dispatch_max_reducers_for_tasks(offloaded_tasks, &ctx, device_context_ptr); } - std::size_t task_index = 0; - for (const auto &task : offloaded_tasks) { + + // Per-task adstack setup + grid-dim capping. Shared by serial and stream-parallel paths. + auto prepare_task = [&](std::size_t task_index, const OffloadedTask &task) -> int { int effective_grid_dim = task.grid_dim; if (!task.ad_stack.allocas.empty()) { std::size_t n = resolve_num_threads(task.ad_stack, executor); @@ -145,29 +151,79 @@ void KernelLauncher::launch_offloaded_tasks(LaunchContextBuilder &ctx, // dispatched-threads worst case on sparse-grid workloads. executor->ensure_per_task_float_heap_post_reducer(task_index, task.ad_stack, n, &ctx); } - } - ++task_index; - // For adstack-bearing tasks, dispatch at most `kAdStackMaxConcurrentThreads` (matching the heap row count resolved - // above). The runtime's grid-strided loop (`gpu_parallel_struct_for` / `gpu_parallel_range_for`, - // `quadrants/runtime/llvm/runtime_module/runtime.cpp`) walks the full element list / range with `i += grid_dim()`, - // so a smaller grid completes the same workload sequentially per slot. Tasks without an adstack keep the - // codegen-emitted `task.grid_dim` (saturating_grid_dim) for max throughput. - if (!task.ad_stack.allocas.empty() && task.block_dim > 0) { + // For adstack-bearing tasks, dispatch at most `kAdStackMaxConcurrentThreads` (matching the heap row count + // resolved above). 
The runtime's grid-strided loop (`gpu_parallel_struct_for` / `gpu_parallel_range_for`,
+      // `quadrants/runtime/llvm/runtime_module/runtime.cpp`) walks the full element list / range with
+      // `i += grid_dim()`, so a smaller grid completes the same workload sequentially per slot. Tasks without an
+      // adstack keep the codegen-emitted `task.grid_dim` (saturating_grid_dim) for max throughput.
+      //
       // Floor division (not ceiling): the heap-row count `n` resolved by `resolve_num_threads` floors at
       // `kAdStackMaxConcurrentThreads`, so dispatching `cap_blocks * block_dim` threads must not exceed that count.
       // Ceiling division would over-dispatch by `block_dim - 1` threads when `block_dim` does not divide
-      // `kAdStackMaxConcurrentThreads` evenly (e.g. `block_dim=192`: `ceil(65536/192)*192 = 65664`), and threads with
-      // `linear_thread_idx >= 65536` would index past the heap end.
-      const std::size_t cap_blocks =
-          std::max(1u, kAdStackMaxConcurrentThreads / static_cast<uint32>(task.block_dim));
-      effective_grid_dim = static_cast<int>(std::min(static_cast<std::size_t>(task.grid_dim), cap_blocks));
-      if (effective_grid_dim < 1) {
-        effective_grid_dim = 1;
+      // `kAdStackMaxConcurrentThreads` evenly (e.g. `block_dim=192`: `ceil(65536/192)*192 = 65664`), and threads
+      // with `linear_thread_idx >= 65536` would index past the heap end.
+      if (task.block_dim > 0) {
+        const std::size_t cap_blocks =
+            std::max(1u, kAdStackMaxConcurrentThreads / static_cast<uint32>(task.block_dim));
+        effective_grid_dim =
+            static_cast<int>(std::min(static_cast<std::size_t>(task.grid_dim), cap_blocks));
+        if (effective_grid_dim < 1) {
+          effective_grid_dim = 1;
+        }
       }
     }
-    QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim);
-    cuda_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes,
-                        {&ctx.get_context()}, {});
+    return effective_grid_dim;
+  };
+
+  auto *active_stream = CUDAContext::get_instance().get_stream();
+  for (size_t i = 0; i < offloaded_tasks.size();) {
+    const auto &task = offloaded_tasks[i];
+    if (task.stream_parallel_group_id == 0) {
+      int effective_grid_dim = prepare_task(i, task);
+      QD_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, effective_grid_dim, task.block_dim);
+      cuda_module->launch(task.name, effective_grid_dim, task.block_dim, task.dynamic_shared_array_bytes,
+                          {&ctx.get_context()}, {});
+      i++;
+    } else {
+      size_t group_start = i;
+      while (i < offloaded_tasks.size() && offloaded_tasks[i].stream_parallel_group_id != 0) {
+        i++;
+      }
+
+      std::map<int, void *> stream_by_id;
+      for (size_t j = group_start; j < i; j++) {
+        int sid = offloaded_tasks[j].stream_parallel_group_id;
+        if (stream_by_id.find(sid) == stream_by_id.end()) {
+          stream_by_id[sid] = CUDAContext::get_instance().acquire_stream();
+        }
+      }
+
+      try {
+        for (size_t j = group_start; j < i; j++) {
+          const auto &t = offloaded_tasks[j];
+          int effective_grid_dim = prepare_task(j, t);
+          CUDAContext::get_instance().set_stream(stream_by_id[t.stream_parallel_group_id]);
+          QD_TRACE("Launching kernel {}<<<{}, {}>>>", t.name, effective_grid_dim, t.block_dim);
+          cuda_module->launch(t.name, effective_grid_dim, t.block_dim, t.dynamic_shared_array_bytes,
+                              {&ctx.get_context()}, {});
+        }
+
+        for (auto &[sid, s] : stream_by_id) {
+          CUDADriver::get_instance().stream_synchronize(s);
+        }
+      } catch (...)
{ + for (auto &[sid, s] : stream_by_id) { + CUDAContext::get_instance().release_stream(s); + } + CUDAContext::get_instance().set_stream(active_stream); + throw; + } + for (auto &[sid, s] : stream_by_id) { + CUDAContext::get_instance().release_stream(s); + } + + CUDAContext::get_instance().set_stream(active_stream); + } } } @@ -180,8 +236,9 @@ void KernelLauncher::launch_offloaded_tasks_with_do_while(LaunchContextBuilder & launch_offloaded_tasks(ctx, cuda_module, offloaded_tasks, device_context_ptr); counter_val = 0; auto *stream = CUDAContext::get_instance().get_stream(); + CUDADriver::get_instance().memcpy_device_to_host_async(&counter_val, ctx.graph_do_while_flag_dev_ptr, + sizeof(int32_t), stream); CUDADriver::get_instance().stream_synchronize(stream); - CUDADriver::get_instance().memcpy_device_to_host(&counter_val, ctx.graph_do_while_flag_dev_ptr, sizeof(int32_t)); } while (counter_val != 0); } @@ -226,6 +283,8 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx // kernels. std::unordered_map device_ptrs; + auto *active_stream = CUDAContext::get_instance().get_stream(); + char *device_result_buffer{nullptr}; // Launcher-global persistent `result_buffer`. See `kernel_launcher.h` for why this one is shared across handles // (kernel writes + synchronous host readback before any other reader runs). `arg_buffer` and `runtime_context` @@ -272,14 +331,16 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(devalloc); transfers[data_ptr_idx] = {data_ptr, devalloc}; - CUDADriver::get_instance().memcpy_host_to_device((void *)device_ptrs[data_ptr_idx], data_ptr, arr_sz); + CUDADriver::get_instance().memcpy_host_to_device_async((void *)device_ptrs[data_ptr_idx], data_ptr, arr_sz, + active_stream); if (grad_ptr != nullptr) { DeviceAllocation grad_devalloc = executor->allocate_memory_on_device(arr_sz, (uint64 *)device_result_buffer); device_ptrs[grad_ptr_idx] = executor->get_device_alloc_info_ptr(grad_devalloc); transfers[grad_ptr_idx] = {grad_ptr, grad_devalloc}; - CUDADriver::get_instance().memcpy_host_to_device((void *)device_ptrs[grad_ptr_idx], grad_ptr, arr_sz); + CUDADriver::get_instance().memcpy_host_to_device_async((void *)device_ptrs[grad_ptr_idx], grad_ptr, arr_sz, + active_stream); } else { device_ptrs[grad_ptr_idx] = nullptr; } @@ -310,25 +371,36 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx } } if (transfers.size() > 0) { - CUDADriver::get_instance().stream_synchronize(nullptr); + CUDADriver::get_instance().stream_synchronize(active_stream); } char *host_result_buffer = (char *)ctx.get_context().result_buffer; if (ctx.result_buffer_size > 0) { ctx.get_context().result_buffer = (uint64 *)device_result_buffer; } + // When launching on an explicit stream (active_stream != nullptr), two calls to the same kernel on different streams + // would race on the shared per-handle arg_buffer: the second call's memcpy can overwrite the buffer while the first + // kernel is still reading it. Allocate a per-call ephemeral buffer in that case; the stream-ordered free below + // ensures the memory stays live until the kernel finishes. 
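The stream-ordered guarantee the comment above leans on (an async free enqueued after a kernel cannot run before that kernel finishes) is a property of the driver's stream-ordered memory allocator. Here is a standalone CUDA driver-API sketch of that lifetime, not project code; it assumes CUDA 11.2+ for `cuMemAllocAsync` / `cuMemFreeAsync`, and error checking is elided:

```cpp
#include <cuda.h>
#include <cstdio>

int main() {
  cuInit(0);
  CUdevice dev;
  cuDeviceGet(&dev, 0);
  CUcontext ctx;
  cuCtxCreate(&ctx, 0, dev);

  CUstream stream;
  cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);

  int host_args[4] = {1, 2, 3, 4};
  CUdeviceptr dev_args;
  cuMemAllocAsync(&dev_args, sizeof(host_args), stream);              // stream-ordered alloc
  cuMemcpyHtoDAsync(dev_args, host_args, sizeof(host_args), stream);  // staged on the same stream
  // ... a cuLaunchKernel on `stream` would read dev_args here ...
  cuMemFreeAsync(dev_args, stream);  // ordered after the copy/launch: no use-after-free
  cuStreamSynchronize(stream);       // only the host needs an explicit wait

  cuStreamDestroy(stream);
  cuCtxDestroy(ctx);
  std::printf("done\n");
  return 0;
}
```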
+  const bool use_persistent_scratch = (active_stream == nullptr);
   char *device_arg_buffer = nullptr;
+  void *ephemeral_arg_buffer = nullptr;
   if (ctx.arg_buffer_size > 0) {
-    if (ctx.arg_buffer_size > launcher_ctx.arg_buffer_capacity) {
-      if (launcher_ctx.arg_buffer_dev_ptr != nullptr) {
-        CUDADriver::get_instance().mem_free_async(launcher_ctx.arg_buffer_dev_ptr, nullptr);
+    if (use_persistent_scratch) {
+      if (ctx.arg_buffer_size > launcher_ctx.arg_buffer_capacity) {
+        if (launcher_ctx.arg_buffer_dev_ptr != nullptr) {
+          CUDADriver::get_instance().mem_free_async(launcher_ctx.arg_buffer_dev_ptr, nullptr);
+        }
+        const std::size_t new_cap = std::max(ctx.arg_buffer_size, 2 * launcher_ctx.arg_buffer_capacity);
+        CUDADriver::get_instance().malloc_async(&launcher_ctx.arg_buffer_dev_ptr, new_cap, nullptr);
+        launcher_ctx.arg_buffer_capacity = new_cap;
       }
-      const std::size_t new_cap = std::max(ctx.arg_buffer_size, 2 * launcher_ctx.arg_buffer_capacity);
-      CUDADriver::get_instance().malloc_async(&launcher_ctx.arg_buffer_dev_ptr, new_cap, nullptr);
-      launcher_ctx.arg_buffer_capacity = new_cap;
+      device_arg_buffer = static_cast<char *>(launcher_ctx.arg_buffer_dev_ptr);
+    } else {
+      CUDADriver::get_instance().malloc_async(&ephemeral_arg_buffer, ctx.arg_buffer_size, active_stream);
+      device_arg_buffer = static_cast<char *>(ephemeral_arg_buffer);
     }
-    device_arg_buffer = static_cast<char *>(launcher_ctx.arg_buffer_dev_ptr);
     CUDADriver::get_instance().memcpy_host_to_device_async(device_arg_buffer, ctx.get_context().arg_buffer,
-                                                           ctx.arg_buffer_size, nullptr);
+                                                           ctx.arg_buffer_size, active_stream);
     ctx.get_context().arg_buffer = device_arg_buffer;
   }
 
@@ -356,13 +428,19 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx
   // memory through host page tables and the staging is redundant; on Turing/Volta we always stage.
   needs_sizer_device_ctx = needs_sizer_device_ctx && !CUDAContext::get_instance().uses_host_page_tables();
   void *device_context_ptr = nullptr;
+  void *ephemeral_context_ptr = nullptr;
   if (needs_sizer_device_ctx) {
-    if (launcher_ctx.runtime_context_dev_ptr == nullptr) {
-      CUDADriver::get_instance().malloc_async(&launcher_ctx.runtime_context_dev_ptr, sizeof(RuntimeContext), nullptr);
+    if (use_persistent_scratch) {
+      if (launcher_ctx.runtime_context_dev_ptr == nullptr) {
+        CUDADriver::get_instance().malloc_async(&launcher_ctx.runtime_context_dev_ptr, sizeof(RuntimeContext), nullptr);
+      }
+      device_context_ptr = launcher_ctx.runtime_context_dev_ptr;
+    } else {
+      CUDADriver::get_instance().malloc_async(&ephemeral_context_ptr, sizeof(RuntimeContext), active_stream);
+      device_context_ptr = ephemeral_context_ptr;
     }
-    device_context_ptr = launcher_ctx.runtime_context_dev_ptr;
     CUDADriver::get_instance().memcpy_host_to_device_async(device_context_ptr, &ctx.get_context(),
-                                                           sizeof(RuntimeContext), nullptr);
+                                                           sizeof(RuntimeContext), active_stream);
   }
 
   // Adstack-cache invalidation bump - see `bump_writes_for_kernel_llvm` in `program/adstack_size_expr_eval.{h,cpp}`.
@@ -374,21 +452,35 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, LaunchContextBuilder &ctx
   } else {
     launch_offloaded_tasks(ctx, cuda_module, offloaded_tasks, device_context_ptr);
   }
-  // Persistent scratch: no per-launch free for the per-handle `arg_buffer` / `runtime_context` or the launcher-
-  // global `result_buffer`. All live until launcher destruction; the dtor handles the final `mem_free_async`.
+ // Persistent scratch (default-stream path): no per-launch free for the per-handle `arg_buffer` / `runtime_context` + // or the launcher-global `result_buffer`. All live until launcher destruction; the dtor handles the final + // `mem_free_async`. Ephemeral buffers (explicit-stream path) are freed below. if (ctx.result_buffer_size > 0) { CUDADriver::get_instance().memcpy_device_to_host_async(host_result_buffer, device_result_buffer, - ctx.result_buffer_size, nullptr); + ctx.result_buffer_size, active_stream); } // copy data back to host if (transfers.size() > 0) { - CUDADriver::get_instance().stream_synchronize(nullptr); + CUDADriver::get_instance().stream_synchronize(active_stream); for (auto itr = transfers.begin(); itr != transfers.end(); itr++) { auto &idx = itr->first; - CUDADriver::get_instance().memcpy_device_to_host(itr->second.first, (void *)device_ptrs[idx], - ctx.array_runtime_sizes[idx.arg_id]); + CUDADriver::get_instance().memcpy_device_to_host_async(itr->second.first, (void *)device_ptrs[idx], + ctx.array_runtime_sizes[idx.arg_id], active_stream); + } + CUDADriver::get_instance().stream_synchronize(active_stream); + for (auto itr = transfers.begin(); itr != transfers.end(); itr++) { executor->deallocate_memory_on_device(itr->second.second); } + } else if (ctx.result_buffer_size > 0) { + CUDADriver::get_instance().stream_synchronize(active_stream); + } + // Free per-call ephemeral buffers (explicit-stream path). The free is stream-ordered: it won't execute until all + // preceding work on active_stream (including the kernel reads) has completed. + if (ephemeral_arg_buffer != nullptr) { + CUDADriver::get_instance().mem_free_async(ephemeral_arg_buffer, active_stream); + } + if (ephemeral_context_ptr != nullptr) { + CUDADriver::get_instance().mem_free_async(ephemeral_context_ptr, active_stream); } } diff --git a/quadrants/runtime/cuda/kernel_launcher.h b/quadrants/runtime/cuda/kernel_launcher.h index e56f064857..eb5b4763df 100644 --- a/quadrants/runtime/cuda/kernel_launcher.h +++ b/quadrants/runtime/cuda/kernel_launcher.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -55,7 +56,11 @@ class KernelLauncher : public LLVM::KernelLauncher { const std::vector &offloaded_tasks, void *device_context_ptr); - std::vector contexts_; + // std::deque (not std::vector): `publish_adstack_metadata`'s host-eval branch recursively registers snode-reader + // kernels via this same launcher, calling `contexts_.resize()` while a parent `launch_llvm_kernel` frame still + // holds a reference into the container. std::deque never invalidates references on push_back / resize, so the + // parent's `launcher_ctx` reference survives the child's registration. 
+  std::deque contexts_;
   GraphManager graph_manager_;
   // `result_buffer` stays launcher-global: kernels write to it, the host reads it back synchronously before any
   // other kernel runs as a reader, so recursive snode-reader launches that reuse the buffer cannot smuggle stale
diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp b/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp
index c53b52c83a..a597090a57 100644
--- a/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp
+++ b/quadrants/runtime/llvm/adstack_lazy_claim/heap_grow.cpp
@@ -237,13 +237,13 @@ void LlvmRuntimeExecutor::ensure_adstack_heap_float(std::size_t needed_bytes) {
 }
 
 void LlvmRuntimeExecutor::check_adstack_overflow() {
-  // Called from `synchronize()` on every sync, plus other Quadrants Python entry points wired in
-  // `Program::check_adstack_overflow_and_raise`. The flag lives in pinned host memory (allocated at
-  // `materialize_runtime`); polling is a relaxed atomic exchange on the cached host pointer via
-  // `std::atomic<int64_t>` reinterpret_cast - no DtoH, no JIT call, no sync drain. Available on all backends because
-  // the pinned-host memory is in the host process address space regardless of where the kernel that wrote it ran.
-  // The reinterpret_cast is portable because `std::atomic<int64_t>` is layout-compatible with `int64_t` on every
-  // target (verified by the static_assert below); see also Itanium ABI / MSVC ABI lock-free guarantees.
+  // Called from `synchronize_and_assert()` on every qd.sync(), plus per-launch from `Program::launch_kernel`. The
+  // flag lives in pinned host memory (allocated at `materialize_runtime`); polling is a relaxed atomic load/exchange
+  // on the cached host pointer via `std::atomic<int64_t>` reinterpret_cast - no DtoH, no JIT call, no sync drain.
+  // Available on all backends because the pinned-host memory is in the host process address space regardless of
+  // where the kernel that wrote it ran. The reinterpret_cast is portable because `std::atomic<int64_t>` is
+  // layout-compatible with `int64_t` on every target (verified by the static_assert below); see also Itanium ABI /
+  // MSVC ABI lock-free guarantees.
   //
   // Returns early when the slot has not been allocated yet (e.g. a C++ test that constructs Program without
   // materializing the runtime and then triggers `Program::finalize -> synchronize`).
@@ -252,15 +252,24 @@ void LlvmRuntimeExecutor::check_adstack_overflow() {
   if (adstack_overflow_flag_host_ptr_ == nullptr) {
     return;
   }
+  // Peek first: a relaxed load is cheaper than an exchange and avoids consuming the flag when the companion task_id
+  // slot has not yet been flushed from the device. The per-launch call site does NOT synchronize before polling, so
+  // the device's two atomic writes (flag OR, then task_id cmpxchg) may arrive at the host out of order. If we
+  // consumed the flag here but the task_id hadn't landed, the diagnostic would lack the kernel name and the later
+  // qd.sync() would see both slots clean — losing the identity forever.
   int64_t flag =
-      reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_flag_host_ptr_)->exchange(0, std::memory_order_relaxed);
+      reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_flag_host_ptr_)->load(std::memory_order_relaxed);
   if (flag == 0) {
     return;
   }
-  // Drain the companion task-id slot in the same poll. Both slots cleared so the next overflow records a fresh
-  // identity. `task_id == 0` means the kernel that overflowed pre-dates the registry wiring or its
-  // `ad_stack.registry_id` was unset for any reason (e.g. a deserialised offline-cache task that has not yet been
-  // re-registered); the diagnose helper falls through to the generic dual-cause message in that case.
+  // Flag is set — drain the default stream so that the companion task_id write is guaranteed to be host-visible
+  // before we read it. This sync only fires on the rare overflow path, so it has zero cost on the fast path.
+  synchronize();
+  // Now consume both slots. Both cleared so the next overflow records a fresh identity. `task_id == 0` means the
+  // kernel that overflowed pre-dates the registry wiring or its `ad_stack.registry_id` was unset for any reason
+  // (e.g. a deserialised offline-cache task that has not yet been re-registered); the diagnose helper falls through
+  // to the generic dual-cause message in that case.
+  reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_flag_host_ptr_)->store(0, std::memory_order_relaxed);
   uint32_t task_id = 0;
   if (adstack_overflow_task_id_host_ptr_ != nullptr) {
     int64_t recorded = reinterpret_cast<std::atomic<int64_t> *>(adstack_overflow_task_id_host_ptr_)
diff --git a/quadrants/transforms/lower_ast.cpp b/quadrants/transforms/lower_ast.cpp
index 72b45cb976..6818ad2f90 100644
--- a/quadrants/transforms/lower_ast.cpp
+++ b/quadrants/transforms/lower_ast.cpp
@@ -222,6 +222,7 @@ class LowerAST : public IRVisitor {
                          stmt->num_cpu_threads, stmt->block_dim);
     new_for->loop_name = stmt->loop_name;
     new_for->index_offsets = offsets;
+    new_for->stream_parallel_group_id = stmt->stream_parallel_group_id;
     VecStatement new_statements;
     for (int i = 0; i < (int)stmt->loop_var_ids.size(); i++) {
       Stmt *loop_index = new_statements.push_back(new_for.get(), snode->physical_index_position[i]);
@@ -256,6 +257,7 @@ class LowerAST : public IRVisitor {
         stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized,
         /*range_hint=*/fmt::format("arg ({})", fmt::join(arg_id, ", ")),
         /*loop_name=*/stmt->loop_name);
+    new_for->stream_parallel_group_id = stmt->stream_parallel_group_id;
     VecStatement new_statements;
     Stmt *loop_index = new_statements.push_back(new_for.get(), 0);
     for (int i = (int)shape.size() - 1; i >= 0; i--) {
@@ -289,6 +291,7 @@ class LowerAST : public IRVisitor {
         stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim,
         stmt->strictly_serialized, /*range_hint=*/"",
         /*loop_name=*/stmt->loop_name);
+    new_for->stream_parallel_group_id = stmt->stream_parallel_group_id;
     new_for->body->insert(std::make_unique(new_for.get(), 0), 0);
     new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] = new_for->body->statements[0].get();
     fctx.push_back(std::move(new_for));
diff --git a/quadrants/transforms/offload.cpp b/quadrants/transforms/offload.cpp
index b9027e95b3..c8a62a1409 100644
--- a/quadrants/transforms/offload.cpp
+++ b/quadrants/transforms/offload.cpp
@@ -126,6 +126,7 @@ class Offloader {
         offloaded->body->insert(std::move(s->body->statements[j]));
       }
       offloaded->range_hint = s->range_hint;
+      offloaded->stream_parallel_group_id = s->stream_parallel_group_id;
       offloaded->loop_name = s->loop_name;
       root_block->insert(std::move(offloaded));
     } else if (auto st = stmt->cast()) {
@@ -237,6 +238,7 @@ class Offloader {
       offloaded_struct_for->is_bit_vectorized = for_stmt->is_bit_vectorized;
       offloaded_struct_for->num_cpu_threads = std::min(for_stmt->num_cpu_threads, config.cpu_max_num_threads);
       offloaded_struct_for->mem_access_opt = mem_access_opt;
+      offloaded_struct_for->stream_parallel_group_id = for_stmt->stream_parallel_group_id;
       offloaded_struct_for->loop_name = for_stmt->loop_name;
       root_block->insert(std::move(offloaded_struct_for));
diff --git
a/tests/python/test_api.py b/tests/python/test_api.py index 4aef94ca59..c0f129fe99 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -62,6 +62,7 @@ def _get_expected_matrix_apis(): "DEBUG", "DeviceCapability", "ERROR", + "Event", "Field", "FieldsBuilder", "Format", @@ -76,6 +77,7 @@ def _get_expected_matrix_apis(): "SNode", "ScalarField", "ScalarNdarray", + "Stream", "Struct", "StructField", "TRACE", @@ -124,6 +126,8 @@ def _get_expected_matrix_apis(): "clock_freq_hz", "cos", "cpu", + "create_event", + "create_stream", "cuda", "data_oriented", "dataclass", @@ -223,6 +227,7 @@ def _get_expected_matrix_apis(): "static_assert", "static_print", "stop_grad", + "stream_parallel", "svd", "sym_eig", "sync", diff --git a/tests/python/test_cache.py b/tests/python/test_cache.py index c3821e44c5..e31daf61e7 100644 --- a/tests/python/test_cache.py +++ b/tests/python/test_cache.py @@ -216,11 +216,11 @@ def test_fastcache(tmp_path: pathlib.Path, monkeypatch): qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) is_valid = False - def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args): + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): nonlocal is_valid is_valid = True assert compiled_kernel_data is None - return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) @@ -242,11 +242,11 @@ def fun(value: qd.types.ndarray(), offset: qd.template()): qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) is_valid = False - def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args): + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): nonlocal is_valid is_valid = True assert compiled_kernel_data is not None - return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) diff --git a/tests/python/test_perf_dispatch.py b/tests/python/test_perf_dispatch.py index 257775e6e1..55e96a308a 100644 --- a/tests/python/test_perf_dispatch.py +++ b/tests/python/test_perf_dispatch.py @@ -143,7 +143,7 @@ def my_func1_impl_a_shape0_ge_2( assert len(speed_checker._trial_count_by_dispatch_impl_by_geometry_hash[geometry]) == 2 -@test_utils.test() +@test_utils.test(exclude=[qd.vulkan]) def test_perf_dispatch_python() -> None: WARMUP = 1 ACTIVE = 1 diff --git a/tests/python/test_streams.py b/tests/python/test_streams.py new file mode 100644 index 0000000000..b89a3b4a42 --- /dev/null +++ b/tests/python/test_streams.py @@ -0,0 +1,648 @@ +"""Tests for GPU stream and event support.""" + +import numpy as np +import pytest + +import quadrants as qd +from quadrants.lang.stream import Event, Stream + +from tests import test_utils + + +@test_utils.test(arch=[qd.cuda, qd.amdgpu]) +def test_create_and_destroy_stream(): + s = qd.create_stream() + assert isinstance(s, Stream) + assert s.handle != 0 + s.destroy() + assert s.handle == 0 + + +@test_utils.test(arch=[qd.cuda, qd.amdgpu]) +def test_create_and_destroy_event(): + e = qd.create_event() + assert isinstance(e, Event) + assert e.handle != 0 + e.destroy() + assert e.handle == 0 + + +@test_utils.test() +def test_kernel_on_stream(): + N = 1024 + 
x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 42.0 + + s = qd.create_stream() + fill(qd_stream=s) + s.synchronize() + assert np.allclose(x.to_numpy(), 42.0) + s.destroy() + + +@test_utils.test() +def test_two_streams(): + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_a(): + for i in range(N): + a[i] = 1.0 + + @qd.kernel + def fill_b(): + for i in range(N): + b[i] = 2.0 + + s1 = qd.create_stream() + s2 = qd.create_stream() + fill_a(qd_stream=s1) + fill_b(qd_stream=s2) + s1.synchronize() + s2.synchronize() + assert np.allclose(a.to_numpy(), 1.0) + assert np.allclose(b.to_numpy(), 2.0) + s1.destroy() + s2.destroy() + + +@test_utils.test() +def test_event_synchronization(): + N = 1024 + x = qd.field(qd.f32, shape=(N,)) + y = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_x(): + for i in range(N): + x[i] = 10.0 + + @qd.kernel + def copy_x_to_y(): + for i in range(N): + y[i] = x[i] + + s1 = qd.create_stream() + fill_x(qd_stream=s1) + + e = qd.create_event() + e.record(s1) + + # Default stream waits for s1 to finish fill_x + e.wait() + copy_x_to_y() + qd.sync() + + assert np.allclose(y.to_numpy(), 10.0) + + e.destroy() + s1.destroy() + + +@test_utils.test() +def test_event_wait_on_stream(): + N = 1024 + x = qd.field(qd.f32, shape=(N,)) + y = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_x(): + for i in range(N): + x[i] = 5.0 + + @qd.kernel + def copy_x_to_y(): + for i in range(N): + y[i] = x[i] + + s1 = qd.create_stream() + s2 = qd.create_stream() + + fill_x(qd_stream=s1) + + e = qd.create_event() + e.record(s1) + + # s2 waits for s1's event before running + e.wait(qd_stream=s2) + copy_x_to_y(qd_stream=s2) + s2.synchronize() + + assert np.allclose(y.to_numpy(), 5.0) + + e.destroy() + s1.destroy() + s2.destroy() + + +@test_utils.test() +def test_default_stream_kernel(): + N = 1024 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 7.0 + + fill() + qd.sync() + assert np.allclose(x.to_numpy(), 7.0) + + +@test_utils.test(arch=[qd.cpu]) +def test_stream_noop_on_cpu(): + """Streams should be no-ops on CPU without errors.""" + N = 64 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 3.0 + + s = qd.create_stream() + assert s.handle == 0 + fill(qd_stream=s) + qd.sync() + assert np.allclose(x.to_numpy(), 3.0) + + e = qd.create_event() + assert e.handle == 0 + e.record(s) + e.wait() + s.destroy() + e.destroy() + + +@test_utils.test() +def test_concurrent_streams_with_events(): + """Two slow kernels on separate streams run concurrently (~1s on GPU), serial fallback on CPU/Metal.""" + SPIN_ITERS = 5_000_000 + + @qd.kernel + def slow_fill( + a: qd.types.ndarray(dtype=qd.f32, ndim=1), + lcg_state: qd.types.ndarray(dtype=qd.i32, ndim=1), + index: qd.i32, + value: qd.f32, + ): + qd.loop_config(block_dim=1) + for _ in range(1): + x = lcg_state[index] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + lcg_state[index] = x + a[index] = value + + @qd.kernel + def add_first_two(a: qd.types.ndarray(dtype=qd.f32, ndim=1)): + qd.loop_config(block_dim=1) + for _ in range(1): + a[2] = a[0] + a[1] + + import time + + # Warm up JIT + a_warmup = qd.ndarray(qd.f32, shape=(3,)) + lcg_warmup = qd.ndarray(qd.i32, shape=(3,)) + slow_fill(a_warmup, lcg_warmup, 0, 0.0) + add_first_two(a_warmup) + qd.sync() + + # Serial baseline + a = qd.ndarray(qd.f32, shape=(3,)) + lcg = qd.ndarray(qd.i32, 
shape=(3,)) + qd.sync() + t0 = time.perf_counter() + slow_fill(a, lcg, 0, 5.0) + slow_fill(a, lcg, 1, 7.0) + add_first_two(a) + qd.sync() + serial_time = time.perf_counter() - t0 + assert np.isclose(a.to_numpy()[2], 12.0) + + # Streams + a = qd.ndarray(qd.f32, shape=(3,)) + lcg = qd.ndarray(qd.i32, shape=(3,)) + s1 = qd.create_stream() + s2 = qd.create_stream() + e1 = qd.create_event() + e2 = qd.create_event() + qd.sync() + t0 = time.perf_counter() + slow_fill(a, lcg, 0, 5.0, qd_stream=s1) + slow_fill(a, lcg, 1, 7.0, qd_stream=s2) + e1.record(s1) + e2.record(s2) + e1.wait() + e2.wait() + add_first_two(a) + qd.sync() + stream_time = time.perf_counter() - t0 + assert np.isclose(a.to_numpy()[2], 12.0) + + speedup = serial_time / stream_time + if qd.lang.impl.current_cfg().arch in (qd.cuda, qd.amdgpu): + assert speedup > 1.5, f"Expected >1.5x speedup, got {speedup:.2f}x" + else: + assert speedup > 0.75, f"Expected >=0.75x (serial fallback), got {speedup:.2f}x" + + s1.destroy() + s2.destroy() + e1.destroy() + e2.destroy() + + +@test_utils.test() +def test_stream_context_manager(): + N = 64 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 11.0 + + with qd.create_stream() as s: + fill(qd_stream=s) + s.synchronize() + assert s.handle == 0 + assert np.allclose(x.to_numpy(), 11.0) + + +@test_utils.test() +def test_event_context_manager(): + with qd.create_event() as e: + assert isinstance(e, Event) + assert e.handle == 0 + + +@test_utils.test() +def test_event_synchronize(): + N = 64 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill(): + for i in range(N): + x[i] = 13.0 + + s = qd.create_stream() + fill(qd_stream=s) + e = qd.create_event() + e.record(s) + e.synchronize() + assert np.allclose(x.to_numpy(), 13.0) + e.destroy() + s.destroy() + + +@test_utils.test(arch=[qd.cuda]) +def test_stream_with_tape_raises(): + x = qd.field(qd.f32, shape=(), needs_grad=True) + loss = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + loss[None] = x[None] ** 2 + + s = qd.create_stream() + with pytest.raises(RuntimeError, match="not compatible with autograd Tape"): + with qd.ad.Tape(loss): + compute(qd_stream=s) + s.destroy() + + +@test_utils.test(arch=[qd.cuda]) +def test_stream_with_autodiff_kernel_raises(): + x = qd.field(qd.f32, shape=(), needs_grad=True) + loss = qd.field(qd.f32, shape=(), needs_grad=True) + + @qd.kernel + def compute(): + loss[None] = x[None] ** 2 + + s = qd.create_stream() + with pytest.raises(RuntimeError, match="not compatible with autodiff"): + compute.grad(qd_stream=s) + s.destroy() + + +@test_utils.test(arch=[qd.cuda]) +def test_stream_with_graph_raises(): + N = 64 + x = qd.field(qd.f32, shape=(N,)) + + @qd.kernel(graph=True) + def fill(): + for i in range(N): + x[i] = 1.0 + + s = qd.create_stream() + with pytest.raises(RuntimeError, match="not compatible with graph=True"): + fill(qd_stream=s) + s.destroy() + + +@test_utils.test() +def test_stream_parallel_basic(): + """Each with qd.stream_parallel() block runs on its own stream (serial fallback on CPU/Metal).""" + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def fill_parallel(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + with qd.stream_parallel(): + for j in range(N): + b[j] = 2.0 + + fill_parallel() + qd.sync() + assert np.allclose(a.to_numpy(), 1.0) + assert np.allclose(b.to_numpy(), 2.0) + + +@test_utils.test() +def test_stream_parallel_multiple_loops_per_stream(): + """Multiple 
for loops inside one stream_parallel block share a stream (serial fallback on CPU/Metal).""" + N = 1024 + a = qd.field(qd.f32, shape=(N,)) + b = qd.field(qd.f32, shape=(N,)) + c = qd.field(qd.f32, shape=(N,)) + + @qd.kernel + def parallel_phase(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + for i in range(N): + a[i] = a[i] + 1.0 + with qd.stream_parallel(): + for j in range(N): + b[j] = 10.0 + + @qd.kernel + def combine(): + for i in range(N): + c[i] = a[i] + b[i] + + parallel_phase() + combine() + qd.sync() + assert np.allclose(a.to_numpy(), 2.0) + assert np.allclose(b.to_numpy(), 10.0) + assert np.allclose(c.to_numpy(), 12.0) + + +@test_utils.test() +def test_stream_parallel_timing(): + """stream_parallel achieves speedup on GPU, serial fallback elsewhere.""" + SPIN_ITERS = 5_000_000 + + a = qd.field(qd.i32, shape=(2,)) + b = qd.field(qd.i32, shape=(2,)) + + @qd.kernel + def serial_spin(): + for _ in range(1): + x = a[0] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + a[0] = x + for _ in range(1): + x = a[1] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + a[1] = x + + @qd.kernel + def parallel_spin(): + with qd.stream_parallel(): + for _ in range(1): + x = b[0] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + b[0] = x + with qd.stream_parallel(): + for _ in range(1): + x = b[1] + for _j in range(SPIN_ITERS): + x = (1664525 * x + 1013904223) % 2147483647 + b[1] = x + + import time + + # Warm up + serial_spin() + parallel_spin() + qd.sync() + + qd.sync() + t0 = time.perf_counter() + serial_spin() + qd.sync() + serial_time = time.perf_counter() - t0 + + qd.sync() + t0 = time.perf_counter() + parallel_spin() + qd.sync() + stream_time = time.perf_counter() - t0 + + speedup = serial_time / stream_time + if qd.lang.impl.current_cfg().arch in (qd.cuda, qd.amdgpu): + assert speedup > 1.5, ( + f"Expected >1.5x speedup, got {speedup:.2f}x " f"(serial={serial_time:.3f}s, stream={stream_time:.3f}s)" + ) + else: + assert speedup > 0.75, ( + f"Expected >=0.75x (serial fallback), got {speedup:.2f}x " + f"(serial={serial_time:.3f}s, stream={stream_time:.3f}s)" + ) + + +@test_utils.test() +def test_stream_parallel_rejects_mixed_top_level(): + """Mixing stream_parallel and non-stream_parallel at top level is an error.""" + import pytest # noqa: I001 + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + with pytest.raises(QuadrantsSyntaxError, match="all top-level statements"): + + @qd.kernel + def bad_kernel(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + for i in range(N): + a[i] = 2.0 + + bad_kernel() + + +@test_utils.test() +def test_stream_with_ndarray(): + N = 1024 + + @qd.kernel + def fill(arr: qd.types.ndarray(dtype=qd.f32, ndim=1)): + for i in range(N): + arr[i] = 99.0 + + arr = qd.ndarray(qd.f32, shape=(N,)) + s = qd.create_stream() + fill(arr, qd_stream=s) + s.synchronize() + assert np.allclose(arr.to_numpy(), 99.0) + s.destroy() + + +@test_utils.test() +def test_stream_pool_reuse(): + """Repeated stream_parallel invocations reuse pooled streams correctly.""" + N = 128 + a = qd.ndarray(qd.f32, shape=(N,)) + b = qd.ndarray(qd.f32, shape=(N,)) + + @qd.kernel + def parallel_fill( + x: qd.types.ndarray(dtype=qd.f32, ndim=1), + y: qd.types.ndarray(dtype=qd.f32, ndim=1), + val: qd.f32, + ): + with qd.stream_parallel(): + for i in range(N): + x[i] = val + with qd.stream_parallel(): + for i in range(N): + y[i] = val 
* 2.0 + + for iteration in range(5): + v = float(iteration + 1) + parallel_fill(a, b, v) + qd.sync() + assert np.allclose(a.to_numpy(), v), f"iteration {iteration}" + assert np.allclose(b.to_numpy(), v * 2.0), f"iteration {iteration}" + + +@test_utils.test() +def test_with_multiple_context_managers_rejected(): + import pytest + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + with pytest.raises(QuadrantsSyntaxError, match="single context manager"): + + @qd.kernel + def bad(): + with qd.stream_parallel(), qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + + bad() + + +@test_utils.test() +def test_with_as_rejected(): + import pytest + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + with pytest.raises(QuadrantsSyntaxError, match="with .* as"): + + @qd.kernel + def bad(): + with qd.stream_parallel() as s: + for i in range(N): + a[i] = 1.0 + + bad() + + +@test_utils.test() +def test_with_non_call_expression_rejected(): + import pytest + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + dummy = qd.stream_parallel + + with pytest.raises(QuadrantsSyntaxError, match="requires a call expression"): + + @qd.kernel + def bad(): + with dummy: + for i in range(N): + a[i] = 1.0 + + bad() + + +@test_utils.test() +def test_with_non_stream_parallel_rejected(): + import pytest + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + def other_ctx(): + pass + + with pytest.raises(QuadrantsSyntaxError, match="only supports qd.stream_parallel"): + + @qd.kernel + def bad(): + with other_ctx(): + for i in range(N): + a[i] = 1.0 + + bad() + + +@test_utils.test() +def test_stream_parallel_in_func_rejected(): + import pytest + + from quadrants.lang.exception import QuadrantsSyntaxError + + N = 64 + a = qd.field(qd.f32, shape=(N,)) + + with pytest.raises(QuadrantsSyntaxError, match="only be used inside @qd.kernel"): + + @qd.func + def helper(): + with qd.stream_parallel(): + for i in range(N): + a[i] = 1.0 + + @qd.kernel + def bad(): + helper() + + bad()