Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9d7f4bf
Setup tables for streams
joefreeman Apr 18, 2026
d43bcf4
Setup streams module
joefreeman Apr 18, 2026
ab58549
Close streams when execution completes
joefreeman Apr 18, 2026
326328b
Update wire protocol
joefreeman Apr 18, 2026
aba392e
Implement generator detector and driver
joefreeman Apr 18, 2026
24ddfc1
Setup stream consumers
joefreeman Apr 18, 2026
6afcba2
Update topic, and update epoch generation
joefreeman Apr 18, 2026
e87ab2f
Add serialiser for stream
joefreeman Apr 18, 2026
b340d49
Fix typing
joefreeman Apr 18, 2026
f2b3f4d
Add tests
joefreeman Apr 18, 2026
5d8d86c
Various fixes
joefreeman Apr 19, 2026
9870be5
Add stream topic
joefreeman Apr 19, 2026
90d5b6f
Tidy stream IDs/terminology
joefreeman Apr 19, 2026
589612b
Support configuring buffer/backpressure
joefreeman Apr 19, 2026
4932cb3
Use completions to determine execution state
joefreeman Apr 19, 2026
aecdb71
Consolidate stream filters into a computed stride
joefreeman Apr 19, 2026
ab62a2f
Support configuring timeouts on streams
joefreeman Apr 20, 2026
f870911
Tidy imports
joefreeman Apr 20, 2026
d0611ed
Fix replay after producer terminated
joefreeman Apr 20, 2026
a77d391
Handle stream error/timeout as dedicated completion kinds
joefreeman Apr 20, 2026
ef18a4a
Capture full stream error
joefreeman Apr 20, 2026
71ef39b
Fix resolving result for execution with stream error
joefreeman Apr 20, 2026
982d75c
Dispatch worker requests async
joefreeman Apr 21, 2026
9dd66db
Don't use memoised execution with errored/timed-out stream
joefreeman Apr 21, 2026
d007aa0
Track stream dependencies
joefreeman Apr 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions adapters/python/coflux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@
from .errors import (
ExecutionAbandoned,
ExecutionCancelled,
ExecutionCrashed,
ExecutionError,
ExecutionTerminated,
ExecutionTimeout,
InputDismissed,
)
from .metric import Metric, MetricGroup, MetricScale, progress
from .models import Asset, AssetEntry, AssetMetadata, Execution, Input, ModelSchema
from .models import Asset, AssetEntry, AssetMetadata, Execution, Input, Stream
from .prompt import Prompt
from .state import get_context
from .target import Cache, Defer, Retries
from .streams import stream
from .target import Cache, Defer, Retries, Streams

__all__ = [
# Version
Expand All @@ -41,19 +43,23 @@
"ExecutionCancelled",
"ExecutionTimeout",
"ExecutionAbandoned",
"ExecutionCrashed",
"InputDismissed",
"Input",
"ModelSchema",
"Metric",
"MetricGroup",
"MetricScale",
"Prompt",
"Cache",
"Defer",
"Retries",
"Streams",
"Asset",
"AssetEntry",
"AssetMetadata",
"Stream",
# Producer-side stream helper
"stream",
# Context functions
"group",
"suspense",
Expand Down
122 changes: 93 additions & 29 deletions adapters/python/coflux/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@
import fnmatch as fnmatch
import hashlib
import json
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Iterator

from . import protocol
from .dispatcher import get_dispatcher
from .errors import (
ExecutionAbandoned,
ExecutionCancelled,
ExecutionCrashed,
ExecutionTimeout,
InputDismissed,
create_execution_error,
)
from .models import Asset, AssetEntry, AssetMetadata, Execution, Input
from .serialization import deserialize_value, serialize_value
from .streams import StreamDriver
from .target import Streams


def _handle_key(handle: Any) -> tuple[str, str]:
Expand Down Expand Up @@ -60,9 +65,19 @@ def _unwrap_response(
raise ExecutionTimeout()
if status == "abandoned":
raise ExecutionAbandoned()
if status == "crashed":
raise ExecutionCrashed()
raise RuntimeError(f"Unexpected select status: {status}")


def _timeout_to_ms(timeout: float | dt.timedelta | None) -> int | None:
if timeout is None:
return None
if isinstance(timeout, dt.timedelta):
return int(timeout.total_seconds() * 1000)
return int(timeout * 1000)


# Context variable for group tracking
_group_id: contextvars.ContextVar[int | None] = contextvars.ContextVar(
"_group_id", default=None
Expand All @@ -78,7 +93,6 @@ class ExecutorContext:

def __init__(self, execution_id: str, working_dir: Path | None = None):
self.execution_id = execution_id
self._pending_requests: dict[int, Any] = {}
self._groups: list[str | None] = []
self._working_dir = working_dir or Path.cwd()
self._defined_metrics: dict[str, dict] = {}
Expand All @@ -89,6 +103,55 @@ def __init__(self, execution_id: str, working_dir: Path | None = None):
# poll_handle to avoid a round-trip for handles that have already
# been seen in this context's lifetime.
self._resolved: dict[tuple[str, str], dict[str, Any]] = {}
# Guards the mutable collections above. Stream driver threads, the
# main task thread, and any user-spawned threads may call methods
# on this context concurrently; the lock protects check-then-set
# patterns (metric definition, group registration, resolve cache)
# from racing.
self._lock = threading.Lock()
# Owns generator streams for this execution. Generators encountered
# during serialization (of the return value OR of submit arguments)
# are registered here and driven in background threads.
self._stream_driver = StreamDriver(execution_id)
# Default stream config for this execution, populated by the
# executor from the target's ``@cf.task(streams=...)`` setting.
# Used by ``cf.stream(...)`` to fill in unspecified options.
self._default_streams: Streams | None = None

def set_default_streams(self, streams: Streams | None) -> None:
    """Store the stream configuration taken from the target's decorator.

    ``cf.stream(...)`` falls back to this value for any option the caller
    leaves unspecified. The executor is expected to call this once, before
    the target function starts running.
    """
    self._default_streams = streams

def get_default_streams(self) -> Streams | None:
    """Return the default stream configuration recorded for this execution.

    The result is ``None`` when no configuration has been recorded.
    """
    defaults = self._default_streams
    return defaults

def register_stream(
    self,
    generator: Any,
    buffer: int | None,
    timeout: float | dt.timedelta | None = None,
) -> str:
    """Hand a generator to this execution's stream driver.

    Used by ``cf.stream(...)``, and by the executor when the task body is
    itself a generator.

    Args:
        generator: The generator to register with the driver.
        buffer: Buffer setting forwarded to the driver unchanged.
        timeout: Seconds or a ``timedelta``; converted to milliseconds
            (or left as ``None``) before being passed to the driver.

    Returns:
        The opaque stream id produced by the driver.
    """
    return self._stream_driver.register(
        generator,
        buffer,
        _timeout_to_ms(timeout),
    )

def wait_streams(self) -> None:
    """Wait for all streams produced by this execution to drain.

    Delegates to the stream driver's ``wait_all()``.
    """
    driver = self._stream_driver
    driver.wait_all()

def close_streams(self) -> None:
    """Close all registered generators so their driver threads exit promptly.

    Intended for the error path, before execution_error is reported.
    Delegates to the stream driver's ``close_all()``.
    """
    driver = self._stream_driver
    driver.close_all()

def submit_execution(
self,
Expand All @@ -106,6 +169,7 @@ def submit_execution(
recurrent: bool = False,
requires: dict[str, list[str]] | None = None,
timeout: int = 0,
streams: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Submit a child execution and return its details.

Expand All @@ -130,6 +194,7 @@ def submit_execution(
recurrent=recurrent,
requires=requires,
timeout=timeout,
streams=streams,
)
return self._wait_response(request_id)

Expand Down Expand Up @@ -183,7 +248,8 @@ def select(
if winner is None:
raise RuntimeError(f"Unexpected select response: {response}")

self._resolved[_handle_key(handles[winner])] = response
with self._lock:
self._resolved[_handle_key(handles[winner])] = response
return winner

def resolve_handle(self, handle: Any) -> Any:
Expand All @@ -195,13 +261,17 @@ def resolve_handle(self, handle: Any) -> Any:
to the deserialized value.
"""
key = _handle_key(handle)
if key not in self._resolved:
with self._lock:
cached = self._resolved.get(key)
if cached is None:
if self.select([handle]) is None:
# The wait expired before the handle resolved. Only reachable
# from inside a `cf.suspense(timeout=...)` scope; otherwise the
# server either resolves or kills the process.
raise TimeoutError("timed out waiting for handle to resolve")
return _unwrap_response(self._resolved[key], handle._parser)
with self._lock:
cached = self._resolved[key]
return _unwrap_response(cached, handle._parser)

def poll_handle(
self,
Expand All @@ -216,11 +286,15 @@ def poll_handle(
is applied to it.
"""
key = _handle_key(handle)
if key not in self._resolved:
with self._lock:
cached = self._resolved.get(key)
if cached is None:
timeout_ms = int(timeout * 1000) if timeout else 0
if self.select([handle], suspend=False, timeout_ms=timeout_ms) is None:
return default
return _unwrap_response(self._resolved[key], handle._parser)
with self._lock:
cached = self._resolved[key]
return _unwrap_response(cached, handle._parser)

def get_asset_entries(self, asset_id: str) -> list[AssetEntry]:
"""Get all entries for an asset by ID."""
Expand Down Expand Up @@ -483,8 +557,9 @@ def log_error(self, message: str) -> None:
@contextmanager
def group(self, name: str | None = None) -> Iterator[None]:
"""Context manager for grouping child executions."""
group_id = len(self._groups)
self._groups.append(name)
with self._lock:
group_id = len(self._groups)
self._groups.append(name)
protocol.send_register_group(self.execution_id, group_id, name)
token = _group_id.set(group_id)
try:
Expand Down Expand Up @@ -517,8 +592,7 @@ def suspend_execution(
request_id = protocol.request_suspend(self.execution_id, execute_after)
self._wait_response(request_id)
# Suspension confirmed. Block until the server aborts this execution.
while protocol.receive_message() is not None:
pass
get_dispatcher().wait_closed()
raise SystemExit(0)

def _parse_response(self, msg: dict) -> Any:
Expand All @@ -529,22 +603,12 @@ def _parse_response(self, msg: dict) -> Any:
return msg.get("result", {})

def _wait_response(self, request_id: int) -> Any:
"""Wait for a response to a request."""
if request_id in self._pending_requests:
return self._parse_response(self._pending_requests.pop(request_id))
while True:
msg = protocol.receive_message()
if msg is None:
raise RuntimeError("Connection closed while waiting for response")

# Check if this is a response
if "id" in msg:
if msg["id"] == request_id:
return self._parse_response(msg)
# Store other responses for later
self._pending_requests[msg["id"]] = msg
else:
# Unexpected message during wait
raise RuntimeError(
f"Unexpected message while waiting for response: {msg}"
)
"""Wait for a response to a request.

Delegates to the dispatcher, which owns stdin and routes the matching
response to this caller. Safe to call from any thread.
"""
msg = get_dispatcher().wait_for_response(request_id)
if msg is None:
raise RuntimeError("Timed out waiting for response")
return self._parse_response(msg)
37 changes: 31 additions & 6 deletions adapters/python/coflux/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,32 @@
import datetime as dt
import typing as t

from .target import Cache, Defer, Retries, Target
from .target import _STREAMS_UNSET, Cache, Defer, Retries, Streams, Target

if t.TYPE_CHECKING:
from .models import Stream

P = t.ParamSpec("P")
T = t.TypeVar("T")


class _TargetDecorator(t.Protocol):
"""Decorator protocol that unwraps ``Coroutine`` for async functions.
"""Decorator protocol that unwraps ``Coroutine`` and collapses generator
return types to ``Stream``.

Overloading on ``__call__`` (rather than on the factory function) lets
the type checker pick the right overload based on the decorated
function's return type — which is visible at application time, but not
at the factory call.

Overloading on ``__call__`` (rather than on the factory function)
lets the type checker pick the right overload based on the decorated
function's return type — which is visible at application time, but
not at the factory call.
Overload resolution order matters: generator functions match first (so
``-> Iterator[T]`` gives a ``Target[P, Stream[T]]``), then async
coroutines (unwrapped), then the general case.
"""

@t.overload
def __call__(self, fn: t.Callable[P, t.Iterator[T]]) -> Target[P, "Stream[T]"]: ...

@t.overload
def __call__(
self, fn: t.Callable[P, t.Coroutine[t.Any, t.Any, T]]
Expand All @@ -41,12 +52,21 @@ def task(
memo: bool | t.Iterable[str] = False,
requires: dict[str, str | bool | list[str]] | None = None,
timeout: float | dt.timedelta = 0,
streams: Streams | None = _STREAMS_UNSET, # type: ignore[assignment]
) -> _TargetDecorator:
"""Decorator for defining a task.

For ``async def`` functions, the coroutine is run to completion by
the executor; the task's return type is the coroutine's resolved
value (not the coroutine itself).

``streams`` only applies to tasks that produce streams — either
generator-bodied tasks (``def`` + ``yield`` / ``async def`` +
``yield``) or tasks that call ``cf.stream(...)`` internally. It
configures the default ``buffer`` and ``timeout`` for those
streams; per-call overrides on ``cf.stream(...)`` win. Passing
``streams=`` on a non-generator task raises ``TypeError`` at
decoration time.
"""

def decorator(fn):
Expand All @@ -63,6 +83,7 @@ def decorator(fn):
memo=memo,
requires=requires,
timeout=timeout,
streams=streams,
)

return decorator # type: ignore[return-value]
Expand All @@ -80,12 +101,15 @@ def workflow(
memo: bool = False,
requires: dict[str, str | bool | list[str]] | None = None,
timeout: float | dt.timedelta = 0,
streams: Streams | None = _STREAMS_UNSET, # type: ignore[assignment]
) -> _TargetDecorator:
"""Decorator for defining a workflow.

For ``async def`` functions, the coroutine is run to completion by
the executor; the workflow's return type is the coroutine's resolved
value (not the coroutine itself).

See ``@cf.task`` for ``streams=`` semantics.
"""

def decorator(fn):
Expand All @@ -102,6 +126,7 @@ def decorator(fn):
memo=memo,
requires=requires,
timeout=timeout,
streams=streams,
)

return decorator # type: ignore[return-value]
Expand Down
14 changes: 13 additions & 1 deletion adapters/python/coflux/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
import sys
from typing import Any

from .target import Target, _to_ms, serialize_cache, serialize_defer, serialize_retries
from .target import (
Target,
_to_ms,
serialize_cache,
serialize_defer,
serialize_retries,
serialize_streams,
)


def _expand_modules(module_names: list[str]) -> list[str]:
Expand Down Expand Up @@ -139,6 +146,11 @@ def _build_target_definition(target: Any, module_name: str) -> dict[str, Any]:
if definition.instruction:
result["instruction"] = definition.instruction

if definition.streams is not None:
streams_dict = serialize_streams(definition.streams)
if streams_dict is not None:
result["streams"] = streams_dict

return result


Expand Down
Loading
Loading