Merged
1 change: 1 addition & 0 deletions docs/source/user_guide/index.md
@@ -63,6 +63,7 @@ tile16

fastcache
graph
streams
perf_dispatch
init_options
```
137 changes: 137 additions & 0 deletions docs/source/user_guide/streams.md
@@ -0,0 +1,137 @@
# Streams

Streams allow concurrent execution of GPU operations. By default, all Quadrants kernels launch on the default stream, which serializes everything. With streams, you can run multiple top-level `for` loops in parallel.

## Supported platforms

| Backend | Supported |
|---------|-----------|
| CUDA | Yes |
| AMDGPU | Yes |
| CPU | No-op |
| Metal | No-op |
| Vulkan | No-op |

On backends without native stream support, stream operations are no-ops and `for` loops run serially. Code using streams is portable across all backends — it runs without modification, just serially.

## Stream parallelism

Inside a `@qd.kernel`, each `with qd.stream_parallel():` block runs on its own GPU stream.

```python
import quadrants as qd

qd.init(arch=qd.cuda)

N = 1024
a = qd.field(qd.f32, shape=(N,))
b = qd.field(qd.f32, shape=(N,))
c = qd.field(qd.f32, shape=(N,))

@qd.kernel
def compute_ab():
with qd.stream_parallel():
for i in range(N):
a[i] = compute_a(i)
with qd.stream_parallel():
for j in range(N):
b[j] = compute_b(j)

@qd.kernel
def combine():
for i in range(N):
c[i] = a[i] + b[i]

compute_ab() # the two stream_parallel blocks run concurrently
combine() # runs after compute_ab() returns — a[] and b[] are ready
```

Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns.
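
For instance (reusing the fields from the example above), a sketch of the difference between loops sharing one block and loops split across blocks:

```python
@qd.kernel
def mixed():
    with qd.stream_parallel():   # stream 1
        for i in range(N):
            a[i] = 1.0
        for i in range(N):       # same block: runs after the loop above, on the same stream
            a[i] += 1.0
    with qd.stream_parallel():   # stream 2: may overlap with stream 1
        for j in range(N):
            b[j] = 2.0
```

Both streams have completed by the time `mixed()` returns, so no extra synchronization is needed by the caller.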

### Restrictions

- All top-level statements in a kernel must be either all `stream_parallel` blocks or all regular statements. Mixing the two at the top level is a compile-time error.
- Nesting `stream_parallel` blocks is not supported.

## Explicit streams

For cases that require manual control — such as launching separate kernels on different streams or interoperating with PyTorch — you can create and manage streams directly.

### Creating and using streams

Any `@qd.kernel` function accepts a special `qd_stream` keyword argument — you do not need to declare it in the kernel signature. The `@qd.kernel` decorator handles it automatically.

```python
@qd.kernel
def my_kernel():
for i in range(N):
a[i] = i

s1 = qd.create_stream()
s2 = qd.create_stream()

my_kernel(qd_stream=s1)
my_kernel(qd_stream=s2)

s1.synchronize()
s2.synchronize()

s1.destroy()
s2.destroy()
```

Kernels on different streams may execute concurrently. Call `synchronize()` to block until all work on a stream completes.

### Events

Events let you express dependencies between streams without full synchronization.

```python
s1 = qd.create_stream()
s2 = qd.create_stream()

@qd.kernel
def produce():
for i in range(N):
a[i] = 10.0

@qd.kernel
def consume():
for i in range(N):
b[i] = a[i]

produce(qd_stream=s1)

e = qd.create_event()
e.record(s1) # record when s1 finishes produce()
e.wait(qd_stream=s2) # s2 waits for that event before proceeding

consume(qd_stream=s2) # safe to read a[] — produce() is guaranteed complete
s2.synchronize()

e.destroy()
s1.destroy()
s2.destroy()
```

`e.record(stream)` captures the current point in `stream`'s execution. `e.wait(qd_stream=stream)` makes `stream` wait until the recorded point has been reached. If `qd_stream` is omitted, the default stream waits.

### Context managers

Streams and events support `with` blocks for automatic cleanup:

```python
with qd.create_stream() as s:
some_func1(qd_stream=s)
# s.destroy() called automatically — waits for in-flight work
```
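
Events support the same pattern. Assuming `produce` and `consume` are kernels like those in the events example above, the whole dependency chain can be written with automatic cleanup (a sketch; combining the three context managers in one `with` statement is ordinary Python):

```python
with qd.create_stream() as s1, qd.create_stream() as s2, qd.create_event() as e:
    produce(qd_stream=s1)
    e.record(s1)            # mark the point where produce() finishes on s1
    e.wait(qd_stream=s2)    # s2 blocks until that point is reached
    consume(qd_stream=s2)
    s2.synchronize()
# e, s1, and s2 are destroyed automatically here
```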

## Synchronization notes

- **`qd.sync()` only waits on the default stream.** It does not drain explicit streams. Call `stream.synchronize()` on each stream you need to wait for.
- **No automatic synchronization with explicit streams.** When using explicit streams, you are responsible for inserting events or `synchronize()` calls when one stream's output is another stream's input. `stream_parallel` handles this automatically.
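
A minimal sketch of the resulting pattern when the default stream and an explicit stream are mixed (`my_kernel` as defined earlier):

```python
s = qd.create_stream()
my_kernel()              # launches on the default stream
my_kernel(qd_stream=s)   # launches on the explicit stream

qd.sync()         # waits only for the default-stream launch
s.synchronize()   # still required to drain the explicit stream
s.destroy()
```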

## Limitations

- **Not compatible with graphs.** Do not pass `qd_stream` to a kernel decorated with `graph=True` (if you do, a `RuntimeError` will be raised).
- **Not compatible with autodiff.** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised).
2 changes: 2 additions & 0 deletions python/quadrants/lang/__init__.py
@@ -16,6 +16,7 @@
from quadrants.lang.runtime_ops import *
from quadrants.lang.snode import *
from quadrants.lang.source_builder import *
from quadrants.lang.stream import *
from quadrants.lang.struct import *
from quadrants.types.enums import DeviceCapability, Format, Layout # noqa: F401

@@ -47,6 +48,7 @@
"shell",
"snode",
"source_builder",
"stream",
"struct",
"util",
]
32 changes: 29 additions & 3 deletions python/quadrants/lang/ast/ast_transformer.py
@@ -119,7 +119,11 @@ def build_AnnAssign(ctx: ASTTransformerFuncContext, node: ast.AnnAssign):

@staticmethod
def build_assign_annotated(
ctx: ASTTransformerFuncContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
ctx: ASTTransformerFuncContext,
target: ast.Name,
value,
is_static_assign: bool,
annotation: Type,
):
"""Build an annotated assignment like this: target: annotation = value.

@@ -165,7 +169,10 @@ def build_Assign(ctx: ASTTransformerFuncContext, node: ast.Assign) -> None:

@staticmethod
def build_assign_unpack(
ctx: ASTTransformerFuncContext, node_target: list | ast.Tuple, values, is_static_assign: bool
ctx: ASTTransformerFuncContext,
node_target: list | ast.Tuple,
values,
is_static_assign: bool,
):
"""Build the unpack assignments like this: (target1, target2) = (value1, value2).
The function should be called only if the node target is a tuple.
@@ -591,7 +598,8 @@ def build_Return(ctx: ASTTransformerFuncContext, node: ast.Return) -> None:
else:
raise QuadrantsSyntaxError("The return type is not supported now!")
ctx.ast_builder.create_kernel_exprgroup_return(
expr.make_expr_group(return_exprs), _qd_core.DebugInfo(ctx.get_pos_info(node))
expr.make_expr_group(return_exprs),
_qd_core.DebugInfo(ctx.get_pos_info(node)),
)
else:
ctx.return_data = node.value.ptr
@@ -1520,6 +1528,24 @@ def build_Continue(ctx: ASTTransformerFuncContext, node: ast.Continue) -> None:
ctx.ast_builder.insert_continue_stmt(_qd_core.DebugInfo(ctx.get_pos_info(node)))
return None

@staticmethod
def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None:
if len(node.items) != 1:
raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports a single context manager")
item = node.items[0]
if item.optional_vars is not None:
raise QuadrantsSyntaxError("'with ... as ...' is not supported in Quadrants kernels")
if not isinstance(item.context_expr, ast.Call):
raise QuadrantsSyntaxError("'with' in Quadrants kernels requires a call expression")
if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars):
raise QuadrantsSyntaxError("'with' in Quadrants kernels only supports qd.stream_parallel()")
if not ctx.is_kernel:
raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func")
ctx.ast_builder.begin_stream_parallel()
build_stmts(ctx, node.body)
ctx.ast_builder.end_stream_parallel()
return None

@staticmethod
def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None:
return None
@@ -26,11 +26,13 @@
from quadrants.lang.ast.ast_transformer_utils import (
ASTTransformerFuncContext,
)
from quadrants.lang.ast.symbol_resolver import ASTResolver
from quadrants.lang.buffer_view import BufferView
from quadrants.lang.exception import (
QuadrantsSyntaxError,
)
from quadrants.lang.matrix import MatrixType
from quadrants.lang.stream import stream_parallel
from quadrants.lang.struct import StructType
from quadrants.lang.util import to_quadrants_type
from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types
@@ -317,7 +319,11 @@ def _transform_func_arg(
# polymorphic).
if field.type is not _TensorClass and hasattr(field.type, "check_matched"):
field.type.check_matched(data_child.get_type(), field.name)
_cache = getattr(getattr(ctx, "global_context", None), "ndarray_to_any_array", None)
_cache = getattr(
getattr(ctx, "global_context", None),
"ndarray_to_any_array",
None,
)
promoted = _cache.get(id(data_child)) if _cache else None
ctx.create_variable(flat_name, promoted if promoted is not None else data_child)
elif dataclasses.is_dataclass(data_child):
@@ -336,7 +342,13 @@ def _transform_func_arg(
# Ndarray arguments are passed by reference.
if isinstance(argument_type, (ndarray_type.NdarrayType)):
if not isinstance(
data, (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray)
data,
(
_ndarray.ScalarNdarray,
matrix.VectorNdarray,
matrix.MatrixNdarray,
any_array.AnyArray,
),
):
raise QuadrantsSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
argument_type.check_matched(data.get_type(), argument_name)
@@ -443,7 +455,70 @@ def build_FunctionDef(
else:
FunctionDefTransformer._transform_as_func(ctx, node, args)

if ctx.is_kernel:
FunctionDefTransformer._validate_stream_parallel_exclusivity(node.body, ctx.global_vars)

with ctx.variable_scope_guard():
build_stmts(ctx, node.body)

return None

@staticmethod
def _is_stream_parallel_with(stmt: ast.stmt, global_vars: dict[str, Any]) -> bool:
if not isinstance(stmt, ast.With):
return False
if len(stmt.items) != 1:
return False
item = stmt.items[0]
if not isinstance(item.context_expr, ast.Call):
return False
func_node = item.context_expr.func
if ASTResolver.resolve_to(func_node, stream_parallel, global_vars):
return True
resolved = ASTResolver.resolve_value(func_node, global_vars)
if resolved is not None:
return getattr(resolved, "__name__", None) == "stream_parallel" and getattr(
resolved, "__module__", ""
).startswith("quadrants")
if isinstance(func_node, ast.Attribute) and func_node.attr == "stream_parallel":
return True
if isinstance(func_node, ast.Name) and func_node.id == "stream_parallel":
return True
return False

@staticmethod
def _is_docstring(stmt: ast.stmt, index: int) -> bool:
return index == 0 and isinstance(stmt, ast.Expr) and isinstance(stmt.value, (ast.Constant, ast.Str))

@staticmethod
def _is_coverage_probe(stmt: ast.stmt) -> bool:
if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1:
return False
target = stmt.targets[0]
return (
isinstance(target, ast.Subscript)
and isinstance(target.value, ast.Name)
and target.value.id.startswith("_qd_cov")
)

@staticmethod
def _validate_stream_parallel_exclusivity(body: list[ast.stmt], global_vars: dict[str, Any]) -> None:
if not any(FunctionDefTransformer._is_stream_parallel_with(s, global_vars) for s in body):
return
for i, stmt in enumerate(body):
if FunctionDefTransformer._is_docstring(stmt, i):
continue
if FunctionDefTransformer._is_coverage_probe(stmt):
continue
if not FunctionDefTransformer._is_stream_parallel_with(stmt, global_vars):
stmt_desc = f"{type(stmt).__name__}"
if isinstance(stmt, ast.With) and stmt.items:
ctx_expr = stmt.items[0].context_expr
if isinstance(ctx_expr, ast.Call) and isinstance(ctx_expr.func, ast.Attribute):
stmt_desc += f"(with {ast.dump(ctx_expr.func)})"
raise QuadrantsSyntaxError(
"When using qd.stream_parallel(), all top-level statements "
"in the kernel must be 'with qd.stream_parallel():' blocks. "
f"Move non-parallel code to a separate kernel. "
f"[stmt {i}: {stmt_desc}, body_len={len(body)}]"
)
32 changes: 32 additions & 0 deletions python/quadrants/lang/ast/symbol_resolver.py
@@ -55,3 +55,35 @@ def resolve_to(node, wanted, scope):
return False
# The name ``scope`` here could be a bit confusing
return scope is wanted

@staticmethod
def resolve_value(node, scope):
"""Resolve an AST Name/Attribute node to a Python object.

Same traversal as resolve_to but returns the resolved object (or None) instead of comparing against a wanted
value.
"""
if isinstance(node, ast.Name):
return scope.get(node.id) if isinstance(scope, dict) else None

if not isinstance(node, ast.Attribute):
return None

v = node.value
chain = [node.attr]
while isinstance(v, ast.Attribute):
chain.append(v.attr)
v = v.value
if not isinstance(v, ast.Name):
return None
chain.append(v.id)

for attr in reversed(chain):
try:
if isinstance(scope, dict):
scope = scope[attr]
else:
scope = getattr(scope, attr)
except (KeyError, AttributeError):
return None
return scope