From 9d7f4bf0cb5934a8fca1f6320ec16b84d11ee6be Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 18:31:52 +0100 Subject: [PATCH 01/25] Setup tables for streams --- adapters/python/coflux/__init__.py | 2 + adapters/python/coflux/context.py | 3 + adapters/python/coflux/errors.py | 12 + server/lib/coflux/handlers/worker.ex | 6 +- server/lib/coflux/orchestration/epoch.ex | 28 +- server/lib/coflux/orchestration/results.ex | 83 +++- server/lib/coflux/orchestration/runs.ex | 15 +- server/lib/coflux/orchestration/server.ex | 540 ++++++++++++--------- server/lib/coflux/topics/run.ex | 21 +- server/priv/migrations/orchestration/4.sql | 67 +++ 10 files changed, 533 insertions(+), 244 deletions(-) create mode 100644 server/priv/migrations/orchestration/4.sql diff --git a/adapters/python/coflux/__init__.py b/adapters/python/coflux/__init__.py index 50360eab..6a9e9cf6 100644 --- a/adapters/python/coflux/__init__.py +++ b/adapters/python/coflux/__init__.py @@ -16,6 +16,7 @@ from .errors import ( ExecutionAbandoned, ExecutionCancelled, + ExecutionCrashed, ExecutionError, ExecutionTerminated, ExecutionTimeout, @@ -41,6 +42,7 @@ "ExecutionCancelled", "ExecutionTimeout", "ExecutionAbandoned", + "ExecutionCrashed", "InputDismissed", "Input", "ModelSchema", diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index eb349907..fc781aa7 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -15,6 +15,7 @@ from .errors import ( ExecutionAbandoned, ExecutionCancelled, + ExecutionCrashed, ExecutionTimeout, InputDismissed, create_execution_error, @@ -60,6 +61,8 @@ def _unwrap_response( raise ExecutionTimeout() if status == "abandoned": raise ExecutionAbandoned() + if status == "crashed": + raise ExecutionCrashed() raise RuntimeError(f"Unexpected select status: {status}") diff --git a/adapters/python/coflux/errors.py b/adapters/python/coflux/errors.py index 559e58e9..46aead92 100644 --- a/adapters/python/coflux/errors.py +++ b/adapters/python/coflux/errors.py @@ -65,6 +65,18 @@ def __init__(self, message: str = "execution was abandoned"): super().__init__(message) +class ExecutionCrashed(ExecutionTerminated): + """Raised when a child execution's worker terminated without reporting. + + The worker process ended (sent notify_terminated) but never reported + a result or error — typically indicates a process crash or shutdown + that didn't give the task code a chance to report. + """ + + def __init__(self, message: str = "worker terminated without reporting a result"): + super().__init__(message) + + class InputDismissed(Exception): """Raised when an input request was dismissed.""" diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index f227d4fa..1f70f4ce 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -648,7 +648,8 @@ defmodule Coflux.Handlers.Worker do # where result_detail is one of: # {:value, value} # {:error, type, message, frames, retry_id, retryable?} - # :cancelled | :dismissed | {:abandoned, nil} | {:timeout, nil} | :suspended + # :cancelled | :dismissed | {:abandoned, _} | {:crashed, _} | + # {:timeout, _} | :suspended defp compose_select_result(:timeout) do nil end @@ -681,6 +682,9 @@ defmodule Coflux.Handlers.Worker do {:abandoned, _} -> Map.put(base, "status", "abandoned") + {:crashed, _} -> + Map.put(base, "status", "crashed") + {:timeout, _} -> Map.put(base, "status", "timeout") diff --git a/server/lib/coflux/orchestration/epoch.ex b/server/lib/coflux/orchestration/epoch.ex index e585f55b..f0d6eeac 100644 --- a/server/lib/coflux/orchestration/epoch.ex +++ b/server/lib/coflux/orchestration/epoch.ex @@ -261,15 +261,16 @@ defmodule Coflux.Orchestration.Epoch do case query_one( source_db, """ - SELECT type, error_id, value_id, successor_id, successor_ref_id, created_at, created_by + SELECT type, error_id, value_id, successor_id, successor_ref_id, + retryable, created_at, created_by FROM results WHERE execution_id = ?1 """, {old_exec_id} ) do {:ok, - {type, error_id, value_id, successor_id, successor_ref_id, result_created_at, - result_created_by}} -> + {type, error_id, value_id, successor_id, successor_ref_id, retryable, + result_created_at, result_created_by}} -> new_error_id = if error_id, do: ensure_error(source_db, target_db, error_id) {new_value_id, new_successor_id, new_successor_ref_id, visited} = @@ -299,6 +300,7 @@ defmodule Coflux.Orchestration.Epoch do value_id: new_value_id, successor_id: new_successor_id, successor_ref_id: new_successor_ref_id, + retryable: retryable, created_at: result_created_at, created_by: ensure_principal(source_db, target_db, result_created_by) }) @@ -310,6 +312,26 @@ defmodule Coflux.Orchestration.Epoch do end end) + # Copy completions (where present — an execution may have results + # but no completion yet if it's mid-termination). + Enum.each(execution_ids, fn {old_exec_id, new_exec_id} -> + case query_one( + source_db, + "SELECT created_at FROM completions WHERE execution_id = ?1", + {old_exec_id} + ) do + {:ok, {completion_created_at}} -> + {:ok, _} = + insert_one(target_db, :completions, %{ + execution_id: new_exec_id, + created_at: completion_created_at + }) + + {:ok, nil} -> + :ok + end + end) + # Copy children — same-run internal IDs {:ok, children} = query( diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index b41ffb23..26f95823 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -3,8 +3,18 @@ defmodule Coflux.Orchestration.Results do alias Coflux.Orchestration.Values + # Writes the results row capturing the disposition (value/error/retryable) + # and any server-decided successor. Written at the time the disposition is + # known — for worker-reported results that's put_result/put_error/etc.; + # for server-initiated dispositions (abandonment, defer/cache/spawn) it's + # when the server makes the decision. + # + # A matching completion row is written separately via record_completion, + # typically when the worker's process confirms it has terminated (via + # notify_terminated). For server-initiated cases that never involve a + # worker, the caller writes both in sequence. def record_result(db, execution_id, result, created_by \\ nil) do - with_transaction(db, fn -> + with_snapshot(db, fn -> now = current_timestamp() {type, error_id, value_id, successor_id, successor_ref_id, retryable} = @@ -34,7 +44,8 @@ defmodule Coflux.Orchestration.Results do {:deferred, defer_id} -> {4, nil, nil, defer_id, nil, nil} - # Resolved deferred (successor resolved to a value) + # Resolved deferred (successor resolved to a value — from epoch copy + # or runtime cache hit) {:deferred, ref_id, value} -> {:ok, value_id} = Values.get_or_create_value(db, value) {4, nil, value_id, nil, ref_id, nil} @@ -43,7 +54,7 @@ defmodule Coflux.Orchestration.Results do {:cached, cached_id} -> {5, nil, nil, cached_id, nil, nil} - # Resolved cached (successor resolved to a value) + # Resolved cached {:cached, ref_id, value} -> {:ok, value_id} = Values.get_or_create_value(db, value) {5, nil, value_id, nil, ref_id, nil} @@ -58,7 +69,7 @@ defmodule Coflux.Orchestration.Results do {:spawned, execution_id} -> {7, nil, nil, execution_id, nil, nil} - # Resolved spawned (successor resolved to a value) + # Resolved spawned {:spawned, ref_id, value} -> {:ok, value_id} = Values.get_or_create_value(db, value) {7, nil, value_id, nil, ref_id, nil} @@ -76,11 +87,27 @@ defmodule Coflux.Orchestration.Results do now, created_by ) do - {:ok, _} -> - {:ok, now} + {:ok, _} -> {:ok, now} + {:error, "UNIQUE constraint failed: " <> _field} -> {:error, :already_recorded} + end + end) + end - {:error, "UNIQUE constraint failed: " <> _field} -> - {:error, :already_recorded} + # Writes the completion row — a simple timestamp marker recording that the + # execution's process has fully terminated. For worker-involved cases this + # is triggered by notify_terminated; for server-initiated dispositions + # (abandonment, cache-hit scheduling) the caller writes this right after + # record_result. + def record_completion(db, execution_id) do + with_transaction(db, fn -> + now = current_timestamp() + + case insert_one(db, :completions, %{ + execution_id: execution_id, + created_at: now + }) do + {:ok, _} -> {:ok, now} + {:error, "UNIQUE constraint failed: " <> _field} -> {:error, :already_completed} end end) end @@ -92,23 +119,51 @@ defmodule Coflux.Orchestration.Results do end end + def has_completion?(db, execution_id) do + case query_one( + db, + "SELECT count(*) FROM completions WHERE execution_id = ?1", + {execution_id} + ) do + {:ok, {0}} -> {:ok, false} + {:ok, {1}} -> {:ok, true} + end + end + def get_result(db, execution_id) do case query_one( db, """ - SELECT r.type, r.error_id, r.value_id, r.successor_id, r.successor_ref_id, r.retryable, r.created_at, + SELECT r.type, r.error_id, r.value_id, r.successor_id, r.successor_ref_id, + r.retryable, r.created_at, c.created_at AS completion_created_at, p.user_external_id AS created_by_user_external_id, t.external_id AS created_by_token_external_id FROM results AS r + LEFT JOIN completions AS c ON c.execution_id = r.execution_id LEFT JOIN principals AS p ON r.created_by = p.id LEFT JOIN tokens AS t ON p.token_id = t.id WHERE r.execution_id = ?1 """, {execution_id} ) do + {:ok, nil} -> + # No results row. If the execution has a completion row anyway, + # the worker terminated without ever reporting — treat as crashed. + case query_one( + db, + "SELECT created_at FROM completions WHERE execution_id = ?1", + {execution_id} + ) do + {:ok, {completion_created_at}} -> + {:ok, {{:crashed, nil}, nil, completion_created_at, nil}} + + {:ok, nil} -> + {:ok, nil} + end + {:ok, {type, error_id, value_id, successor_id, successor_ref_id, retryable, created_at, - created_by_user_ext_id, created_by_token_ext_id}} -> + completion_created_at, created_by_user_ext_id, created_by_token_ext_id}} -> created_by = case {created_by_user_ext_id, created_by_token_ext_id} do {nil, nil} -> nil @@ -145,7 +200,6 @@ defmodule Coflux.Orchestration.Results do {8, nil, nil, retry_id, nil} -> {:timeout, retry_id} - # Deferred: in-flight (successor_id set) or resolved (ref + value) {4, nil, nil, defer_id, nil} -> {:deferred, defer_id} @@ -154,7 +208,6 @@ defmodule Coflux.Orchestration.Results do {:ok, value} -> {:deferred, ref_id, value} end - # Cached: in-flight or resolved {5, nil, nil, cached_id, nil} -> {:cached, cached_id} @@ -166,7 +219,6 @@ defmodule Coflux.Orchestration.Results do {6, nil, nil, successor_id, nil} -> {:suspended, successor_id} - # Spawned: in-flight or resolved {7, nil, nil, execution_id, nil} -> {:spawned, execution_id} @@ -179,10 +231,7 @@ defmodule Coflux.Orchestration.Results do {:recurred, successor_id} end - {:ok, {result, created_at, created_by}} - - {:ok, nil} -> - {:ok, nil} + {:ok, {result, created_at, completion_created_at, created_by}} end end diff --git a/server/lib/coflux/orchestration/runs.ex b/server/lib/coflux/orchestration/runs.ex index cf66b397..9643e364 100644 --- a/server/lib/coflux/orchestration/runs.ex +++ b/server/lib/coflux/orchestration/runs.ex @@ -950,21 +950,28 @@ defmodule Coflux.Orchestration.Runs do @doc """ Gets result types for executions of a step, ordered most recent first. """ + # Returns {execution_id, type} for each execution of the step that has + # reached a terminal state, most-recent first. A row is considered terminal + # if it has either a results row or a completions row (the latter without + # the former indicates the worker crashed without reporting). + # For crashed executions, type is NULL. def get_step_result_types(db, step_id, limit) do case query( db, """ - SELECT r.type + SELECT e.id, r.type FROM executions AS e - INNER JOIN results AS r ON r.execution_id = e.id + LEFT JOIN results AS r ON r.execution_id = e.id + LEFT JOIN completions AS c ON c.execution_id = e.id WHERE e.step_id = ?1 + AND (r.execution_id IS NOT NULL OR c.execution_id IS NOT NULL) ORDER BY e.created_at DESC LIMIT ?2 """, {step_id, limit} ) do {:ok, rows} -> - {:ok, Enum.map(rows, fn {type} -> type end)} + {:ok, rows} end end @@ -1182,7 +1189,7 @@ defmodule Coflux.Orchestration.Runs do # Also find predecessors that reference this execution via successor_ref_id. # This only searches the active epoch, which is sufficient because - # successor_ref_id results are always written to the active epoch (either + # successor_ref_id rows are always written to the active epoch (either # during runtime cache hits or epoch copy), and in-flight runs are copied # forward during rotation. {:ok, {run_ext, step_num, attempt}} = get_execution_key(db, execution_id) diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 643190df..722754a1 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -346,10 +346,13 @@ defmodule Coflux.Orchestration.Server do end) ) else - # Session no longer active - abandon these executions + # Session no longer active - abandon these executions. Server-initiated + # so we write both the results row (via process_result) and the + # completion row (via complete_execution) here — no worker is + # going to send notify_terminated for this execution. Enum.reduce(execution_ids, state, fn execution_id, state -> {:ok, state} = process_result(state, execution_id, :abandoned) - state + complete_execution(state, execution_id) end) end end) @@ -1661,8 +1664,10 @@ defmodule Coflux.Orchestration.Server do case Results.has_result?(state.db, execution_id) do {:ok, false} -> + # Server-detected abandonment. Write both tables — no worker + # will send notify_terminated for this execution. {:ok, state} = process_result(state, execution_id, :abandoned) - state + complete_execution(state, execution_id) {:ok, true} -> state @@ -1731,18 +1736,14 @@ defmodule Coflux.Orchestration.Server do state = external_execution_ids |> Enum.reduce(state, fn ext_id, state -> - # If execution has no result recorded, mark it as abandoned + # Finalize the execution — writes completion (plus creating any + # successor) if it hasn't already been done. For worker-reported + # error/timeout this runs the retry decision now. For executions + # with no results row yet, falls back to the :abandoned path. state = case Map.fetch(state.execution_ids, ext_id) do {:ok, execution_id} -> - case Results.has_result?(state.db, execution_id) do - {:ok, false} -> - {:ok, state} = process_result(state, execution_id, :abandoned) - state - - {:ok, true} -> - state - end + complete_execution(state, execution_id) :error -> state @@ -2707,8 +2708,10 @@ defmodule Coflux.Orchestration.Server do {:cached, ref_id, value} end - {:ok, state} = - process_result(state, execution.execution_id, result) + # Cache hit during scheduling — server-only, no worker runs this + # execution. Write results + completion together. + {:ok, state} = process_result(state, execution.execution_id, result) + state = complete_execution(state, execution.execution_id) {state, assigned, unassigned} else @@ -3447,7 +3450,7 @@ defmodule Coflux.Orchestration.Server do # execution that doesn't yet have a result. defp resolve_active_execution(db, execution_id) do case Results.get_result(db, execution_id) do - {:ok, {{:spawned, successor_id}, _created_at, _created_by}} -> + {:ok, {{:spawned, successor_id}, _created_at, _completion_created_at, _created_by}} -> resolve_active_execution(db, successor_id) _ -> @@ -3457,6 +3460,11 @@ defmodule Coflux.Orchestration.Server do # Cancel a single execution: record :cancelled, abort if assigned, cancel descendants. defp do_cancel_execution(state, execution_id, workspace_id) do + # Write the results row and fire result-time notifications to mark the + # execution as cancelled. The completion row isn't written until the + # worker confirms termination (via notify_terminated), so consumers can + # distinguish "cancelling" (results present, completion absent) from + # "cancelled" (both present). state = case record_and_notify_result(state, execution_id, :cancelled, nil) do {:ok, state} -> state @@ -3707,14 +3715,16 @@ defmodule Coflux.Orchestration.Server do {session, state} = pop_in(state.sessions[session_id]) state = Map.update!(state, :session_expiries, &Map.delete(&1, session_id)) - # starting/executing now contain external IDs - resolve to internal for process_result + # starting/executing now contain external IDs - resolve to internal for process_result. + # Session removal means no more notify_terminated for these executions, so we + # write both results + completion here. state = session.executing |> MapSet.union(session.starting) |> Enum.reduce(state, fn ext_id, state -> execution_id = Map.fetch!(state.execution_ids, ext_id) {:ok, state} = process_result(state, execution_id, :abandoned) - state + complete_execution(state, execution_id) end) |> Map.update!(:targets, fn all_targets -> Enum.reduce( @@ -4078,6 +4088,7 @@ defmodule Coflux.Orchestration.Server do {:error, _, _, _, false} -> false {:error, _, _, _, _} -> true :abandoned -> true + :crashed -> true :timeout -> true _ -> false end @@ -4694,17 +4705,17 @@ defmodule Coflux.Orchestration.Server do run_executions |> Enum.map(&elem(&1, 0)) |> Enum.reduce(%{}, fn execution_id, results -> - {result, completed_at, result_created_by} = + {result, result_at, completed_at, result_created_by} = case Results.get_result(db, execution_id) do - {:ok, {result, completed_at, created_by}} -> + {:ok, {result, result_at, completion_at, created_by}} -> result = build_result(result, db) - {result, completed_at, created_by} + {result, result_at, completion_at, created_by} {:ok, nil} -> - {nil, nil, nil} + {nil, nil, nil, nil} end - Map.put(results, execution_id, {result, completed_at, result_created_by}) + Map.put(results, execution_id, {result, result_at, completed_at, result_created_by}) end) steps = @@ -4764,7 +4775,8 @@ defmodule Coflux.Orchestration.Server do {:ok, workspace_external_id} = Workspaces.get_workspace_external_id(db, workspace_id) - {result, completed_at, result_created_by} = Map.fetch!(results, execution_id) + {result, result_at, completed_at, result_created_by} = + Map.fetch!(results, execution_id) execution_groups = groups @@ -4810,6 +4822,7 @@ defmodule Coflux.Orchestration.Server do created_by: execution_created_by, execute_after: execute_after, assigned_at: assigned_at, + result_at: result_at, completed_at: completed_at, groups: execution_groups, assets: assets, @@ -5403,6 +5416,7 @@ defmodule Coflux.Orchestration.Server do {:error, _, _, _, retry_id} -> is_nil(retry_id) {:value, _} -> true {:abandoned, retry_id} -> is_nil(retry_id) + {:crashed, retry_id} -> is_nil(retry_id) :cancelled -> true {:timeout, retry_id} -> is_nil(retry_id) {:suspended, _} -> false @@ -5434,6 +5448,10 @@ defmodule Coflux.Orchestration.Server do retry = if retry_id, do: resolve_execution(db, retry_id) {:abandoned, retry} + {:crashed, retry_id} -> + retry = if retry_id, do: resolve_execution(db, retry_id) + {:crashed, retry} + :cancelled -> :cancelled @@ -5471,129 +5489,199 @@ defmodule Coflux.Orchestration.Server do end end + # Write the results row (with successor info baked in) and fire + # result-time notifications: wake waiters, update dependencies, send the + # Studio :result event carrying the result tuple and result_at timestamp. + # The completion row is written later via complete_execution (triggered by + # notify_terminated for worker-involved cases, or by the server-initiated + # paths directly when no worker is involved). defp record_and_notify_result(state, execution_id, result, _module, created_by \\ nil) do + result = + case result do + {:value, value} -> {:value, normalize_value(value)} + other -> other + end + + case Results.record_result(state.db, execution_id, result, created_by) do + {:ok, result_at} -> + state = fire_result_notifications(state, execution_id, result, result_at, created_by) + {:ok, state} + + {:error, reason} -> + {:error, reason} + end + end + + defp fire_result_notifications(state, execution_id, result, result_at, created_by) do {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) {:ok, successors} = Runs.get_result_successors(state.db, execution_id) {:ok, {r, s, a}} = Runs.get_execution_key(state.db, execution_id) execution_external_id = execution_external_id(r, s, a) - result = - case result do - {:value, value} -> - {:value, normalize_value(value)} + state = + state + |> notify_waiting(execution_id) + |> update_dependencies_on_result(execution_id) + |> unregister_pending_dependencies(execution_id) + + final = is_result_final?(result) + built_result = build_result(result, state.db) - other -> - other + principal = + case Principals.get_principal(state.db, created_by) do + {:ok, {type, external_id}} -> %{type: type, external_id: external_id} + {:ok, nil} -> nil end - case Results.record_result(state.db, execution_id, result, created_by) do - {:ok, created_at} -> - state = - state - |> notify_waiting(execution_id) - |> update_dependencies_on_result(execution_id) - |> unregister_pending_dependencies(execution_id) + ws_ext_id = workspace_external_id(state, workspace_id) - final = is_result_final?(result) - result = build_result(result, state.db) + state = + successors + |> Enum.reduce(state, fn {run_external_id, successor_id}, state -> + cond do + successor_id == execution_id -> + notify_listeners( + state, + {:run, run_external_id}, + {:result, execution_external_id, built_result, result_at, principal} + ) - principal = - case Principals.get_principal(state.db, created_by) do - {:ok, {type, external_id}} -> %{type: type, external_id: external_id} - {:ok, nil} -> nil - end + final -> + {:ok, {r2, s2, a2}} = Runs.get_execution_key(state.db, successor_id) + successor_external_id = execution_external_id(r2, s2, a2) - ws_ext_id = workspace_external_id(state, workspace_id) + notify_listeners( + state, + {:run, run_external_id}, + # TODO: better name? + {:result_result, successor_external_id, built_result, result_at, principal} + ) - # get_result_successors now returns {run_external_id, successor_id} - state = - successors - |> Enum.reduce(state, fn {run_external_id, successor_id}, state -> - cond do - successor_id == execution_id -> - notify_listeners( - state, - {:run, run_external_id}, - {:result, execution_external_id, result, created_at, principal} - ) + true -> + state + end + end) + |> then(fn state -> + case untrack_run_execution(state, r, execution_id) do + {{root_module, root_target}, state} -> + notify_listeners( + state, + {:modules, ws_ext_id}, + {:completed, {root_module, root_target}, r, execution_external_id} + ) - final -> - {:ok, {r, s, a}} = Runs.get_execution_key(state.db, successor_id) - successor_external_id = execution_external_id(r, s, a) + {nil, state} -> + state + end + end) + |> notify_listeners( + {:queue, ws_ext_id}, + {:completed, execution_external_id} + ) - notify_listeners( - state, - {:run, run_external_id}, - # TODO: better name? - {:result_result, successor_external_id, result, created_at, principal} - ) + # Check if any input dependencies became inactive. Route the + # :inputs topic notification to the INPUT's workspace (matching + # :input_dependency_active in the resolve_input handler), not the + # completing execution's workspace — these differ when an execution + # in a child workspace resolved an input created in a parent. + state = + case Inputs.get_input_dependencies_for_execution(state.db, execution_id) do + {:ok, deps} -> + Enum.reduce(deps, state, fn {input_id, input_ws_id}, state -> + if Inputs.has_active_dependency?(state.db, input_id) do + state + else + {:ok, run_ext_id, input_number} = + Inputs.get_input_run_and_number(state.db, input_id) - true -> - state + input_ext_id = input_external_id(run_ext_id, input_number) + input_ws_ext_id = workspace_external_id(state, input_ws_id) + + state + |> notify_listeners( + {:inputs, input_ws_ext_id}, + {:input_dependency_inactive, input_ext_id} + ) + |> notify_listeners( + {:input, input_ext_id}, + {:active, false} + ) end end) - |> then(fn state -> - case untrack_run_execution(state, r, execution_id) do - {{root_module, root_target}, state} -> - notify_listeners( - state, - {:modules, ws_ext_id}, - {:completed, {root_module, root_target}, r, execution_external_id} - ) - {nil, state} -> + _ -> + state + end + + # TODO: only if there's an execution waiting for this result? + send(self(), :tick) + + state + end + + # Write the completion row and fire a completion-time notification. For + # "crashed" cases (notify_terminated with no prior results row), also + # decides retry and fires result-time notifications with a synthesised + # :crashed shape. + defp complete_execution(state, execution_id) do + case Results.has_completion?(state.db, execution_id) do + {:ok, true} -> + state + + {:ok, false} -> + case Results.has_result?(state.db, execution_id) do + {:ok, true} -> + case Results.record_completion(state.db, execution_id) do + {:ok, completion_at} -> + fire_completion_notification(state, execution_id, completion_at) + + {:error, :already_completed} -> state end - end) - |> notify_listeners( - {:queue, ws_ext_id}, - {:completed, execution_external_id} - ) - - # Check if any input dependencies became inactive. Route the - # :inputs topic notification to the INPUT's workspace (matching - # :input_dependency_active in the resolve_input handler), not the - # completing execution's workspace — these differ when an execution - # in a child workspace resolved an input created in a parent. - state = - case Inputs.get_input_dependencies_for_execution(state.db, execution_id) do - {:ok, deps} -> - Enum.reduce(deps, state, fn {input_id, input_ws_id}, state -> - if Inputs.has_active_dependency?(state.db, input_id) do - state - else - {:ok, run_ext_id, input_number} = - Inputs.get_input_run_and_number(state.db, input_id) - input_ext_id = input_external_id(run_ext_id, input_number) - input_ws_ext_id = workspace_external_id(state, input_ws_id) + {:ok, false} -> + handle_crashed(state, execution_id) + end + end + end - state - |> notify_listeners( - {:inputs, input_ws_ext_id}, - {:input_dependency_inactive, input_ext_id} - ) - |> notify_listeners( - {:input, input_ext_id}, - {:active, false} - ) - end - end) + # No results row exists for this execution but notify_terminated has + # arrived — the worker terminated without reporting. Decide retry, write + # completion (no results row), fire notifications. + defp handle_crashed(state, execution_id) do + {:ok, step} = Runs.get_step_for_execution(state.db, execution_id) + {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) - _ -> - state - end + # Decide retry as if this were an abandoned-like failure. result_retryable? + # treats :crashed as retryable so the step's retry policy applies. + {retry_id, _recurred?, state} = + decide_and_create_successor(state, execution_id, step, workspace_id, :crashed) - # TODO: only if there's an execution waiting for this result? - send(self(), :tick) + case Results.record_completion(state.db, execution_id) do + {:ok, completion_at} -> + # Result-time notifications weren't fired (no results row was ever + # written), so fire them now alongside the completion notification. + state = + fire_result_notifications(state, execution_id, {:crashed, retry_id}, nil, nil) - {:ok, state} + fire_completion_notification(state, execution_id, completion_at) - {:error, reason} -> - {:error, reason} + {:error, :already_completed} -> + state end end + defp fire_completion_notification(state, execution_id, completion_at) do + {:ok, {r, s, a}} = Runs.get_execution_key(state.db, execution_id) + execution_external_id = execution_external_id(r, s, a) + + notify_listeners( + state, + {:run, r}, + {:completion, execution_external_id, completion_at} + ) + end + defp process_result(state, execution_id, result, created_by \\ nil) do case Results.has_result?(state.db, execution_id) do {:ok, true} -> @@ -5603,136 +5691,151 @@ defmodule Coflux.Orchestration.Server do {:ok, step} = Runs.get_step_for_execution(state.db, execution_id) {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) - execution_ext_id = - case Runs.get_execution_key(state.db, execution_id) do - {:ok, {r, s, a}} -> execution_external_id(r, s, a) - {:error, :not_found} -> nil - end - {retry_id, recurred?, state} = - cond do - match?({:suspended, _, _}, result) -> - {:suspended, execute_after, dependency_keys} = result + decide_and_create_successor(state, execution_id, step, workspace_id, result) - # TODO: limit the number of times a step can suspend? (or rate?) + result = transform_result_with_successor(result, retry_id, recurred?) - {:ok, retry_id, _, state} = - rerun_step(state, step, workspace_id, - execute_after: execute_after, - dependency_keys: dependency_keys - ) + state = + case record_and_notify_result( + state, + execution_id, + result, + step.module, + created_by + ) do + {:ok, state} -> state + {:error, :already_recorded} -> state + end - state = - if execution_ext_id do - abort_execution(state, execution_ext_id) - else - state - end + # Cancel descendant executions for timeouts and cancellations + state = + if match?({:timeout, _}, result) or result == :cancelled do + cancel_descendants(state, execution_id, workspace_id) + else + state + end - {retry_id, false, state} + {:ok, state} + end + end - result_retryable?(result) && step.retry_limit == -1 -> - # Unlimited retries - random delay between min and max - delay_ms = - step.retry_backoff_min + - :rand.uniform() * (step.retry_backoff_max - step.retry_backoff_min) + defp decide_and_create_successor(state, execution_id, step, workspace_id, result) do + execution_ext_id = + case Runs.get_execution_key(state.db, execution_id) do + {:ok, {r, s, a}} -> execution_external_id(r, s, a) + {:error, :not_found} -> nil + end + + cond do + match?({:suspended, _, _}, result) -> + {:suspended, execute_after, dependency_keys} = result + + # TODO: limit the number of times a step can suspend? (or rate?) - execute_after = System.os_time(:millisecond) + delay_ms + {:ok, retry_id, _, state} = + rerun_step(state, step, workspace_id, + execute_after: execute_after, + dependency_keys: dependency_keys + ) - {:ok, retry_id, _, state} = - rerun_step(state, step, workspace_id, execute_after: execute_after) + state = + if execution_ext_id do + abort_execution(state, execution_ext_id) + else + state + end - {retry_id, false, state} + {retry_id, false, state} - result_retryable?(result) && step.retry_limit > 0 -> - # Limited retries - check consecutive failures - {:ok, result_types} = - Runs.get_step_result_types(state.db, step.id, step.retry_limit + 1) + result_retryable?(result) && step.retry_limit == -1 -> + # Unlimited retries - random delay between min and max + delay_ms = + step.retry_backoff_min + + :rand.uniform() * (step.retry_backoff_max - step.retry_backoff_min) - consecutive_failures = - result_types - |> Enum.take_while(&(&1 in [0, 2, 8])) - |> Enum.count() + execute_after = System.os_time(:millisecond) + delay_ms - if consecutive_failures < step.retry_limit do - # TODO: add jitter (within min/max delay) - delay_ms = - step.retry_backoff_min + - consecutive_failures / max(step.retry_limit - 1, 1) * - (step.retry_backoff_max - step.retry_backoff_min) + {:ok, retry_id, _, state} = + rerun_step(state, step, workspace_id, execute_after: execute_after) - execute_after = System.os_time(:millisecond) + delay_ms + {retry_id, false, state} - {:ok, retry_id, _, state} = - rerun_step(state, step, workspace_id, execute_after: execute_after) + result_retryable?(result) && step.retry_limit > 0 -> + # Limited retries - check consecutive failures. Exclude the current + # execution's row so this works regardless of whether its results row + # has already been written (deferred path) or not (immediate path). + # A nil type indicates a crashed execution (completion without a + # results row) — counted as a failure. + {:ok, rows} = + Runs.get_step_result_types(state.db, step.id, step.retry_limit + 2) - {retry_id, false, state} - else - {nil, false, state} - end + consecutive_failures = + rows + |> Enum.reject(fn {id, _type} -> id == execution_id end) + |> Enum.take_while(fn {_id, type} -> type in [0, 2, 8] or is_nil(type) end) + |> Enum.count() - step.recurrent == 1 and match?({:value, {:raw, nil, []}}, result) -> - # Null return from recurrent step: schedule next iteration via :recurred - execute_after = - if step.delay > 0 do - System.os_time(:millisecond) + step.delay - end + if consecutive_failures < step.retry_limit do + # TODO: add jitter (within min/max delay) + delay_ms = + step.retry_backoff_min + + consecutive_failures / max(step.retry_limit - 1, 1) * + (step.retry_backoff_max - step.retry_backoff_min) - {:ok, retry_id, _, state} = - rerun_step(state, step, workspace_id, execute_after: execute_after) + execute_after = System.os_time(:millisecond) + delay_ms - {retry_id, true, state} + {:ok, retry_id, _, state} = + rerun_step(state, step, workspace_id, execute_after: execute_after) - step.recurrent == 1 and match?({:value, _}, result) -> - # Non-null return from recurrent step: stop recurrence - {nil, false, state} + {retry_id, false, state} + else + {nil, false, state} + end - true -> - {nil, false, state} + step.recurrent == 1 and match?({:value, {:raw, nil, []}}, result) -> + # Null return from recurrent step: schedule next iteration via :recurred + execute_after = + if step.delay > 0 do + System.os_time(:millisecond) + step.delay end - result = - case result do - {:error, type, message, frames, retryable} -> - {:error, type, message, frames, retry_id, retryable} + {:ok, retry_id, _, state} = + rerun_step(state, step, workspace_id, execute_after: execute_after) + + {retry_id, true, state} + + step.recurrent == 1 and match?({:value, _}, result) -> + # Non-null return from recurrent step: stop recurrence + {nil, false, state} - :abandoned -> - {:abandoned, retry_id} + true -> + {nil, false, state} + end + end - :timeout -> - {:timeout, retry_id} + defp transform_result_with_successor(result, retry_id, recurred?) do + case result do + {:error, type, message, frames, retryable} -> + {:error, type, message, frames, retry_id, retryable} - {:suspended, _, _} -> - {:suspended, retry_id} + :abandoned -> + {:abandoned, retry_id} - {:value, _} when recurred? -> - {:recurred, retry_id} + :crashed -> + {:crashed, retry_id} - other -> - other - end + :timeout -> + {:timeout, retry_id} - state = - case record_and_notify_result( - state, - execution_id, - result, - step.module, - created_by - ) do - {:ok, state} -> state - {:error, :already_recorded} -> state - end + {:suspended, _, _} -> + {:suspended, retry_id} - # Cancel descendant executions for timeouts and cancellations - state = - if match?({:timeout, _}, result) or result == :cancelled do - cancel_descendants(state, execution_id, workspace_id) - else - state - end + {:value, _} when recurred? -> + {:recurred, retry_id} - {:ok, state} + other -> + other end end @@ -5742,7 +5845,7 @@ defmodule Coflux.Orchestration.Server do {:ok, nil} -> {:pending, execution_id} - {:ok, {result, _created_at, _created_by}} -> + {:ok, {result, _created_at, _completion_created_at, _created_by}} -> case result do {:error, _, _, _, execution_id, _retryable} when not is_nil(execution_id) -> resolve_result(db, execution_id) @@ -5750,6 +5853,9 @@ defmodule Coflux.Orchestration.Server do {:abandoned, execution_id} when not is_nil(execution_id) -> resolve_result(db, execution_id) + {:crashed, execution_id} when not is_nil(execution_id) -> + resolve_result(db, execution_id) + {:timeout, execution_id} when not is_nil(execution_id) -> resolve_result(db, execution_id) diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index 902b429c..613f359a 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -186,14 +186,23 @@ defmodule Coflux.Topics.Run do defp process_notification( topic, - {:result, execution_external_id, result, created_at, created_by} + {:result, execution_external_id, result, result_at, created_by} ) do result = build_result(result, created_by) update_execution(topic, execution_external_id, fn topic, base_path -> topic |> Topic.set(base_path ++ [:result], result) - |> Topic.set(base_path ++ [:completedAt], created_at) + |> Topic.set(base_path ++ [:resultAt], result_at) + end) + end + + defp process_notification( + topic, + {:completion, execution_external_id, completion_at} + ) do + update_execution(topic, execution_external_id, fn topic, base_path -> + Topic.set(topic, base_path ++ [:completedAt], completion_at) end) end @@ -346,6 +355,7 @@ defmodule Coflux.Topics.Run do createdBy: build_principal(execution.created_by), executeAfter: execution.execute_after, assignedAt: execution.assigned_at, + resultAt: execution.result_at, completedAt: execution.completed_at, groups: execution.groups, assets: @@ -449,6 +459,13 @@ defmodule Coflux.Topics.Run do retry: if(retry, do: execution_attempt(retry)) } + {:crashed, retry} -> + %{ + type: "crashed", + createdBy: created_by, + retry: if(retry, do: execution_attempt(retry)) + } + :cancelled -> %{type: "cancelled", createdBy: created_by} diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql new file mode 100644 index 00000000..2dbc273f --- /dev/null +++ b/server/priv/migrations/orchestration/4.sql @@ -0,0 +1,67 @@ +-- Add completions table — a pure termination marker. The existing results +-- table continues to hold the disposition (including any successor), written +-- at result-arrival time. A completions row is written separately at +-- notify_terminated time, so its timestamp reflects when the worker's +-- process actually finished shutting down. +-- +-- This enables streaming support: a results row can be written with stream +-- handles while the process keeps running, with completions written later +-- when streams have drained. + +CREATE TABLE completions ( + execution_id INTEGER PRIMARY KEY, + created_at INTEGER NOT NULL, + FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE +) STRICT; + +-- Every existing results row represents a terminated execution, so each +-- produces a completions row with the same timestamp. +INSERT INTO completions (execution_id, created_at) + SELECT execution_id, created_at FROM results; + +-- Streams — ordered, append-only sequences of values produced by an +-- execution. Each stream is identified by (execution_id, sequence), where +-- sequence is assigned monotonically by the worker when serialising the +-- execution's return value. The worker manages allocation locally, so no +-- server round-trip is needed to mint an id. +-- +-- Invariants: +-- • A stream is owned by exactly one execution (its producer). +-- • stream_items are append-only with monotonic position starting at 0. +-- • stream_closures are terminal — no items may be appended after closure. +-- • On execution completion / cancellation / crash, every owned stream +-- that lacks a closure receives one (clean, cancelled, or crashed). +-- • Re-running a producer execution creates fresh streams (new attempt ⇒ +-- new execution_id ⇒ new rows). Consumer references are concrete to +-- the original streams. +-- • Consumer cursors are kept in-memory only; re-run consumers subscribe +-- fresh from position 0. + +CREATE TABLE streams ( + execution_id INTEGER NOT NULL, + sequence INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (execution_id, sequence), + FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE +) STRICT; + +CREATE TABLE stream_items ( + execution_id INTEGER NOT NULL, + sequence INTEGER NOT NULL, + position INTEGER NOT NULL, + value_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (execution_id, sequence, position), + FOREIGN KEY (execution_id, sequence) REFERENCES streams (execution_id, sequence) ON DELETE CASCADE, + FOREIGN KEY (value_id) REFERENCES values_ ON DELETE RESTRICT +) STRICT; + +CREATE TABLE stream_closures ( + execution_id INTEGER NOT NULL, + sequence INTEGER NOT NULL, + error_id INTEGER, + created_at INTEGER NOT NULL, + PRIMARY KEY (execution_id, sequence), + FOREIGN KEY (execution_id, sequence) REFERENCES streams (execution_id, sequence) ON DELETE CASCADE, + FOREIGN KEY (error_id) REFERENCES errors ON DELETE RESTRICT +) STRICT; From d43bcf4965a28cb8d281d41473f1b9e7eb5d2d53 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 18:33:36 +0100 Subject: [PATCH 02/25] Setup streams module --- server/lib/coflux/orchestration/errors.ex | 66 ++++++ server/lib/coflux/orchestration/results.ex | 63 +----- server/lib/coflux/orchestration/streams.ex | 230 +++++++++++++++++++++ 3 files changed, 300 insertions(+), 59 deletions(-) create mode 100644 server/lib/coflux/orchestration/errors.ex create mode 100644 server/lib/coflux/orchestration/streams.ex diff --git a/server/lib/coflux/orchestration/errors.ex b/server/lib/coflux/orchestration/errors.ex new file mode 100644 index 00000000..8bd1ea4e --- /dev/null +++ b/server/lib/coflux/orchestration/errors.ex @@ -0,0 +1,66 @@ +defmodule Coflux.Orchestration.Errors do + @moduledoc """ + Shared helpers for deduping errors in the `errors` + `error_frames` tables. + Used by `Results` (execution errors) and `Streams` (stream closure errors). + """ + + import Coflux.Store + + # Inserts or returns an existing error matching (type, message, frames). + # Returns the error id as an integer. + def get_or_create(db, type, message, frames) do + hash = hash(type, message, frames) + + case query_one(db, "SELECT id FROM errors WHERE hash = ?1", {{:blob, hash}}) do + {:ok, {id}} -> + id + + {:ok, nil} -> + {:ok, error_id} = + insert_one(db, :errors, %{ + hash: {:blob, hash}, + type: type, + message: message + }) + + {:ok, _} = + insert_many( + db, + :error_frames, + {:error_id, :depth, :file, :line, :name, :code}, + frames + |> Enum.with_index() + |> Enum.map(fn {{file, line, name, code}, index} -> + {error_id, index, file, line, name, code} + end) + ) + + error_id + end + end + + # Returns `{:ok, {type, message, frames}}`. + def get_by_id(db, error_id) do + {:ok, {type, message}} = + query_one!(db, "SELECT type, message FROM errors WHERE id = ?1", {error_id}) + + {:ok, frames} = + query( + db, + "SELECT file, line, name, code FROM error_frames WHERE error_id = ?1 ORDER BY depth", + {error_id} + ) + + {:ok, {type, message, frames}} + end + + defp hash(type, message, frames) do + frame_parts = + Enum.flat_map(frames, fn {file, line, name, code} -> + [file, Integer.to_string(line), name || 0, code || 0] + end) + + parts = Enum.concat([type, message], frame_parts) + :crypto.hash(:sha256, Enum.intersperse(parts, 0)) + end +end diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index 26f95823..05d97b0d 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -1,7 +1,7 @@ defmodule Coflux.Orchestration.Results do import Coflux.Store - alias Coflux.Orchestration.Values + alias Coflux.Orchestration.{Errors, Values} # Writes the results row capturing the disposition (value/error/retryable) # and any server-decided successor. Written at the time the disposition is @@ -20,11 +20,11 @@ defmodule Coflux.Orchestration.Results do {type, error_id, value_id, successor_id, successor_ref_id, retryable} = case result do {:error, type, message, frames, retry_id, retryable} -> - {:ok, error_id} = get_or_create_error(db, type, message, frames) + error_id = Errors.get_or_create(db, type, message, frames) {0, error_id, nil, retry_id, nil, retryable} {:error, type, message, frames, retry_id} -> - {:ok, error_id} = get_or_create_error(db, type, message, frames) + error_id = Errors.get_or_create(db, type, message, frames) {0, error_id, nil, retry_id, nil, nil} {:value, value} -> @@ -181,7 +181,7 @@ defmodule Coflux.Orchestration.Results do result = case {type, error_id, value_id, successor_id, successor_ref_id} do {0, error_id, nil, retry_id, nil} -> - case get_error_by_id(db, error_id) do + case Errors.get_by_id(db, error_id) do {:ok, {type, message, frames}} -> {:error, type, message, frames, retry_id, retryable} end @@ -235,61 +235,6 @@ defmodule Coflux.Orchestration.Results do end end - defp get_error_by_id(db, error_id) do - {:ok, {type, message}} = - query_one!(db, "SELECT type, message FROM errors WHERE id = ?1", {error_id}) - - {:ok, frames} = - query( - db, - "SELECT file, line, name, code FROM error_frames WHERE error_id = ?1 ORDER BY depth", - {error_id} - ) - - {:ok, {type, message, frames}} - end - - defp hash_error(type, message, frames) do - frame_parts = - Enum.flat_map(frames, fn {file, line, name, code} -> - [file, Integer.to_string(line), name || 0, code || 0] - end) - - parts = Enum.concat([type, message], frame_parts) - :crypto.hash(:sha256, Enum.intersperse(parts, 0)) - end - - defp get_or_create_error(db, type, message, frames) do - hash = hash_error(type, message, frames) - - case query_one(db, "SELECT id FROM errors WHERE hash = ?1", {{:blob, hash}}) do - {:ok, {id}} -> - {:ok, id} - - {:ok, nil} -> - {:ok, error_id} = - insert_one(db, :errors, %{ - hash: {:blob, hash}, - type: type, - message: message - }) - - {:ok, _} = - insert_many( - db, - :error_frames, - {:error_id, :depth, :file, :line, :name, :code}, - frames - |> Enum.with_index() - |> Enum.map(fn {{file, line, name, code}, index} -> - {error_id, index, file, line, name, code} - end) - ) - - {:ok, error_id} - end - end - def put_execution_asset(db, execution_id, asset_id) do now = current_timestamp() diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex new file mode 100644 index 00000000..08988717 --- /dev/null +++ b/server/lib/coflux/orchestration/streams.ex @@ -0,0 +1,230 @@ +defmodule Coflux.Orchestration.Streams do + @moduledoc """ + Storage for execution-produced streams. + + A stream is an ordered, append-only sequence of values produced by an + execution. Each stream is identified by `(execution_id, sequence)` where + sequence is assigned monotonically by the worker during return-value + serialisation — the worker mints ids locally, no server round-trip. + + Invariants enforced here (and by schema FKs): + + * A stream is owned by exactly one execution (its producer). + * Items are append-only with monotonic `position` starting at 0. + * A closure is terminal — no items may be appended after one is recorded. + * On execution completion / cancel / crash, every owned stream that lacks + a closure receives one (clean, cancelled, or crashed). Enforced by the + lifecycle code in `Server`, not by this module. + * Re-running a producer execution creates fresh streams (new attempt ⇒ + new execution_id ⇒ new rows). Consumer refs pin to the original streams. + * Consumer cursors are kept in-memory only; re-run consumers subscribe + fresh from position 0. + """ + + import Coflux.Store + + alias Coflux.Orchestration.{Errors, Values} + + # Registers a new stream owned by `execution_id` with the given `sequence` + # (monotonic per-execution, worker-assigned). Returns `{:error, :already_registered}` + # if the sequence was already used. + def register_stream(db, execution_id, sequence) do + now = current_timestamp() + + case insert_one(db, :streams, %{ + execution_id: execution_id, + sequence: sequence, + created_at: now + }) do + {:ok, _} -> {:ok, now} + {:error, "UNIQUE constraint failed: " <> _} -> {:error, :already_registered} + end + end + + # Appends an item at `position` to the stream. Caller supplies the position + # (worker-assigned, monotonic). Returns: + # * `{:error, :not_registered}` if the stream doesn't exist + # * `{:error, :closed}` if the stream has a closure row + # * `{:error, :already_appended}` if position collides with an existing item + def append_item(db, execution_id, sequence, position, value) do + with_transaction(db, fn -> + case has_closure?(db, execution_id, sequence) do + {:ok, true} -> + {:error, :closed} + + {:ok, false} -> + case exists?(db, execution_id, sequence) do + {:ok, false} -> + {:error, :not_registered} + + {:ok, true} -> + {:ok, value_id} = Values.get_or_create_value(db, value) + now = current_timestamp() + + case insert_one(db, :stream_items, %{ + execution_id: execution_id, + sequence: sequence, + position: position, + value_id: value_id, + created_at: now + }) do + {:ok, _} -> {:ok, now} + {:error, "UNIQUE constraint failed: " <> _} -> {:error, :already_appended} + end + end + end + end) + end + + # Closes the stream. `error` is either `nil` (clean close) or a + # `{type, message, frames}` triple (error close — re-uses the errors table + # via the same path as Results). + def close_stream(db, execution_id, sequence, error \\ nil) do + with_transaction(db, fn -> + case exists?(db, execution_id, sequence) do + {:ok, false} -> + {:error, :not_registered} + + {:ok, true} -> + now = current_timestamp() + + error_id = + case error do + nil -> nil + {type, message, frames} -> Errors.get_or_create(db, type, message, frames) + end + + case insert_one(db, :stream_closures, %{ + execution_id: execution_id, + sequence: sequence, + error_id: error_id, + created_at: now + }) do + {:ok, _} -> {:ok, now} + {:error, "UNIQUE constraint failed: " <> _} -> {:error, :already_closed} + end + end + end) + end + + def exists?(db, execution_id, sequence) do + case query_one( + db, + "SELECT 1 FROM streams WHERE execution_id = ?1 AND sequence = ?2", + {execution_id, sequence} + ) do + {:ok, nil} -> {:ok, false} + {:ok, {1}} -> {:ok, true} + end + end + + def has_closure?(db, execution_id, sequence) do + case query_one( + db, + "SELECT 1 FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", + {execution_id, sequence} + ) do + {:ok, nil} -> {:ok, false} + {:ok, {1}} -> {:ok, true} + end + end + + # Returns `{:ok, [sequence, ...]}` for every stream owned by `execution_id`, + # in sequence order. + def get_streams_for_execution(db, execution_id) do + case query( + db, + "SELECT sequence FROM streams WHERE execution_id = ?1 ORDER BY sequence", + {execution_id} + ) do + {:ok, rows} -> + {:ok, Enum.map(rows, fn {sequence} -> sequence end)} + end + end + + # Returns sequences of streams owned by `execution_id` that don't yet have + # a closure row. Used by the lifecycle code to discover which streams to + # close on completion / cancel / crash. + def get_open_streams_for_execution(db, execution_id) do + case query( + db, + """ + SELECT s.sequence + FROM streams AS s + LEFT JOIN stream_closures AS c + ON c.execution_id = s.execution_id AND c.sequence = s.sequence + WHERE s.execution_id = ?1 AND c.execution_id IS NULL + ORDER BY s.sequence + """, + {execution_id} + ) do + {:ok, rows} -> + {:ok, Enum.map(rows, fn {sequence} -> sequence end)} + end + end + + # Returns closure info or `{:ok, nil}` if the stream is still open. + # Closure info: `{error | nil, created_at}` where error is + # `{type, message, frames}` when present. + def get_stream_closure(db, execution_id, sequence) do + case query_one( + db, + "SELECT error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", + {execution_id, sequence} + ) do + {:ok, nil} -> + {:ok, nil} + + {:ok, {nil, created_at}} -> + {:ok, {nil, created_at}} + + {:ok, {error_id, created_at}} -> + {:ok, error} = Errors.get_by_id(db, error_id) + {:ok, {error, created_at}} + end + end + + # Fetches up to `max_items` items from the stream starting at `from_position`. + # Returns `{:ok, [{position, value, created_at}, ...]}` in position order. + # The caller (Server) layers filter logic (slice / partition) on top of this. + def get_stream_items(db, execution_id, sequence, from_position, max_items) do + case query( + db, + """ + SELECT position, value_id, created_at + FROM stream_items + WHERE execution_id = ?1 AND sequence = ?2 AND position >= ?3 + ORDER BY position + LIMIT ?4 + """, + {execution_id, sequence, from_position, max_items} + ) do + {:ok, rows} -> + items = + Enum.map(rows, fn {position, value_id, created_at} -> + {:ok, value} = Values.get_value_by_id(db, value_id) + {position, value, created_at} + end) + + {:ok, items} + end + end + + # Returns the highest position recorded for the stream, or `-1` if empty. + # Used by the worker protocol to report "head" for flow control without + # requiring the caller to scan all items. + def get_stream_head(db, execution_id, sequence) do + case query_one( + db, + "SELECT MAX(position) FROM stream_items WHERE execution_id = ?1 AND sequence = ?2", + {execution_id, sequence} + ) do + {:ok, {nil}} -> {:ok, -1} + {:ok, {position}} -> {:ok, position} + end + end + + defp current_timestamp() do + System.os_time(:millisecond) + end +end From ab585495abe220cfe00abee36685398367afad63 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 18:39:38 +0100 Subject: [PATCH 03/25] Close streams when execution completes --- server/lib/coflux/orchestration/server.ex | 34 +++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 722754a1..0a5df804 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -10,6 +10,7 @@ defmodule Coflux.Orchestration.Server do Sessions, Runs, Results, + Streams, Assets, Inputs, Values, @@ -3471,6 +3472,12 @@ defmodule Coflux.Orchestration.Server do {:error, :already_recorded} -> state end + # Close any open streams so iterating consumers stop waiting. Any + # subsequent `append_item` from the producer will fail with `:closed`, + # signalling the worker to stop. The closure carries no error; consumers + # resolve the cancel from the execution's own disposition. + close_open_streams(state, execution_id) + state = case Runs.get_execution_key(state.db, execution_id) do {:ok, {r, s, a}} -> @@ -5631,6 +5638,12 @@ defmodule Coflux.Orchestration.Server do {:ok, false} -> case Results.has_result?(state.db, execution_id) do {:ok, true} -> + # Close any streams left open by the producer. Generator tasks + # normally close their streams explicitly; this is the backstop + # for ones that didn't. Consumers resolve the close reason from + # the execution's own disposition. + close_open_streams(state, execution_id) + case Results.record_completion(state.db, execution_id) do {:ok, completion_at} -> fire_completion_notification(state, execution_id, completion_at) @@ -5657,6 +5670,11 @@ defmodule Coflux.Orchestration.Server do {retry_id, _recurred?, state} = decide_and_create_successor(state, execution_id, step, workspace_id, :crashed) + # Streams that had been appended to before the worker died need to be + # closed so consumers don't wait forever. The closure carries no error — + # the execution's own :crashed disposition is the source of truth. + close_open_streams(state, execution_id) + case Results.record_completion(state.db, execution_id) do {:ok, completion_at} -> # Result-time notifications weren't fired (no results row was ever @@ -5671,6 +5689,22 @@ defmodule Coflux.Orchestration.Server do end end + # Closes every stream owned by `execution_id` that doesn't yet have a + # closure row. The closure carries no error — the consumer resolves the + # reason from the execution's result / completion state (clean, crashed, + # cancelled). If a generator closed its stream with an explicit error, + # that closure already exists and is left untouched. + defp close_open_streams(state, execution_id) do + {:ok, sequences} = Streams.get_open_streams_for_execution(state.db, execution_id) + + Enum.each(sequences, fn sequence -> + case Streams.close_stream(state.db, execution_id, sequence) do + {:ok, _} -> :ok + {:error, :already_closed} -> :ok + end + end) + end + defp fire_completion_notification(state, execution_id, completion_at) do {:ok, {r, s, a}} = Runs.get_execution_key(state.db, execution_id) execution_external_id = execution_external_id(r, s, a) From 326328ba589f8390d1473c70ac8fcbeb0b861c52 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 18:48:26 +0100 Subject: [PATCH 04/25] Update wire protocol --- adapters/python/coflux/protocol.py | 59 ++++++++++++++++++ cli/internal/adapter/protocol.go | 31 ++++++++++ cli/internal/pool/pool.go | 53 ++++++++++++++++ cli/internal/worker/worker.go | 37 ++++++++++++ server/lib/coflux/handlers/worker.ex | 74 +++++++++++++++++++++++ server/lib/coflux/orchestration.ex | 16 +++++ server/lib/coflux/orchestration/server.ex | 49 +++++++++++++++ 7 files changed, 319 insertions(+) diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index e55824f7..914f6c51 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -454,6 +454,65 @@ def send_metric( get_protocol().send_message("metric", params) +def send_stream_register(execution_id: str, sequence: int) -> None: + """Register a stream owned by this execution. + + Sequence is worker-assigned and monotonic per execution (0, 1, 2, ...). + """ + get_protocol().send_message( + "stream_register", + {"execution_id": execution_id, "sequence": sequence}, + ) + + +def send_stream_append( + execution_id: str, + sequence: int, + position: int, + value: dict[str, Any], +) -> None: + """Append an item to a stream at the given (worker-assigned) position. + + Position is monotonic per stream (0, 1, 2, ...). Value uses the same + Value shape as execution results (type + format + value/path + refs). + """ + get_protocol().send_message( + "stream_append", + { + "execution_id": execution_id, + "sequence": sequence, + "position": position, + "value": value, + }, + ) + + +def send_stream_close( + execution_id: str, + sequence: int, + error_type: str | None = None, + error_message: str = "", + traceback: str = "", +) -> None: + """Close a stream. + + With no error args, signals a clean close (generator exhausted). With + error args set, signals that the generator raised — consumers will see + the exception on their next iteration. + """ + params: dict[str, Any] = { + "execution_id": execution_id, + "sequence": sequence, + } + if error_type is not None: + params["error"] = { + "type": error_type, + "message": error_message, + "traceback": traceback, + } + get_protocol().send_message("stream_close", params) + + def receive_message() -> dict[str, Any] | None: """Receive the next message from the CLI.""" return get_protocol().receive() diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index 571dc097..594788bc 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -253,6 +253,37 @@ type RegisterGroupParams struct { Name *string `json:"name,omitempty"` } +// StreamRegisterParams for stream_register notification. +// Sequence is worker-assigned, monotonic per execution. +type StreamRegisterParams struct { + ExecutionID string `json:"execution_id"` + Sequence int `json:"sequence"` +} + +// StreamAppendParams for stream_append notification. +// Position is worker-assigned, monotonic per stream. +type StreamAppendParams struct { + ExecutionID string `json:"execution_id"` + Sequence int `json:"sequence"` + Position int `json:"position"` + Value *Value `json:"value"` +} + +// StreamCloseParams for stream_close notification. Error is present only +// when the producer's generator raised an exception. +type StreamCloseParams struct { + ExecutionID string `json:"execution_id"` + Sequence int `json:"sequence"` + Error *StreamCloseError `json:"error,omitempty"` +} + +// StreamCloseError describes an error that terminated a stream. +type StreamCloseError struct { + Type string `json:"type"` + Message string `json:"message"` + Traceback string `json:"traceback"` +} + // DownloadBlobParams for download_blob request type DownloadBlobParams struct { ExecutionID string `json:"execution_id"` diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index aa1a932e..9ef56572 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -49,6 +49,14 @@ type ExecutionHandler interface { SubmitInput(ctx context.Context, params *adapter.SubmitInputParams) (string, error) // NotifyTerminated notifies the server that an execution's process has exited NotifyTerminated(ctx context.Context, executionID string) error + // StreamRegister declares a new stream owned by an execution. + // Sequence is worker-assigned, monotonic per execution. + StreamRegister(ctx context.Context, executionID string, sequence int) error + // StreamAppend appends an item to a stream at the given (worker-assigned) position. + StreamAppend(ctx context.Context, executionID string, sequence int, position int, value *adapter.Value) error + // StreamClose closes a stream. Error is nil for a clean close, or a (type, message, traceback) + // triple when the producer's generator raised. + StreamClose(ctx context.Context, executionID string, sequence int, err *adapter.StreamCloseError) error } // Pool manages executor processes. Each executor handles one execution then @@ -264,6 +272,15 @@ loop: case "register_group": p.handleRegisterGroup(execCtx, executionID, params, logger) + case "stream_register": + p.handleStreamRegister(execCtx, executionID, params, logger) + + case "stream_append": + p.handleStreamAppend(execCtx, executionID, params, logger) + + case "stream_close": + p.handleStreamClose(execCtx, executionID, params, logger) + default: err := fmt.Errorf("unknown message method: %s", method) logger.Error("unknown message method", "method", method) @@ -435,6 +452,42 @@ func (p *Pool) handleRegisterGroup(ctx context.Context, executionID string, para } } +func (p *Pool) handleStreamRegister(ctx context.Context, executionID string, params json.RawMessage, logger *slog.Logger) { + var req adapter.StreamRegisterParams + if err := json.Unmarshal(params, &req); err != nil { + logger.Error("failed to parse stream_register message", "error", err) + return + } + + if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Sequence); err != nil { + logger.Error("failed to register stream", "error", err) + } +} + +func (p *Pool) handleStreamAppend(ctx context.Context, executionID string, params json.RawMessage, logger *slog.Logger) { + var req adapter.StreamAppendParams + if err := json.Unmarshal(params, &req); err != nil { + logger.Error("failed to parse stream_append message", "error", err) + return + } + + if err := p.handler.StreamAppend(ctx, req.ExecutionID, req.Sequence, req.Position, req.Value); err != nil { + logger.Error("failed to append stream item", "error", err) + } +} + +func (p *Pool) handleStreamClose(ctx context.Context, executionID string, params json.RawMessage, logger *slog.Logger) { + var req adapter.StreamCloseParams + if err := json.Unmarshal(params, &req); err != nil { + logger.Error("failed to parse stream_close message", "error", err) + return + } + + if err := p.handler.StreamClose(ctx, req.ExecutionID, req.Sequence, req.Error); err != nil { + logger.Error("failed to close stream", "error", err) + } +} + func (p *Pool) handleRequest(ctx context.Context, exec *adapter.Executor, method string, id int, params json.RawMessage, logger *slog.Logger) { var result any var errInfo *adapter.ErrorInfo diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index da2757e3..3058c819 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -1101,6 +1101,43 @@ func (w *Worker) RegisterGroup(ctx context.Context, executionID string, groupID return conn.Notify("register_group", executionID, groupID, name) } +func (w *Worker) StreamRegister(ctx context.Context, executionID string, sequence int) error { + conn, err := w.requireConn() + if err != nil { + return err + } + return conn.Notify("stream_register", executionID, sequence) +} + +func (w *Worker) StreamAppend(ctx context.Context, executionID string, sequence int, position int, value *adapter.Value) error { + conn, err := w.requireConn() + if err != nil { + return err + } + // Apply blob threshold + upload fragment references just like ReportResult. + serverValue, err := w.convertValueToServerFormat(value) + if err != nil { + return err + } + return conn.Notify("stream_append", executionID, sequence, position, serverValue) +} + +func (w *Worker) StreamClose(ctx context.Context, executionID string, sequence int, streamErr *adapter.StreamCloseError) error { + conn, err := w.requireConn() + if err != nil { + return err + } + var errTuple any + if streamErr != nil { + // Match the shape used for put_error: [type, message, frames]. + // Stream closures never carry a retryable flag — retry decisions + // live at the execution level, not per-stream. + frames := parseTraceback(streamErr.Traceback) + errTuple = []any{streamErr.Type, streamErr.Message, frames} + } + return conn.Notify("stream_close", executionID, sequence, errTuple) +} + func (w *Worker) Cancel(ctx context.Context, executionID string, handles []adapter.SelectHandle) error { conn, err := w.waitForConn(ctx) if err != nil { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index 1f70f4ce..3c4f29a4 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -253,6 +253,80 @@ defmodule Coflux.Handlers.Worker do {[{:close, 4000, "execution_invalid"}], nil} end + "stream_register" -> + [execution_id, sequence] = message["params"] + + if is_recognised_execution?(execution_id, state) do + case Orchestration.register_stream(state.project_id, execution_id, sequence) do + :ok -> {[], state} + # Idempotent — a duplicate register is harmless. + {:error, :already_registered} -> {[], state} + {:error, :not_found} -> {[{:close, 4000, "execution_invalid"}], nil} + end + else + {[{:close, 4000, "execution_invalid"}], nil} + end + + "stream_append" -> + [execution_id, sequence, position, value] = message["params"] + + if is_recognised_execution?(execution_id, state) do + case Orchestration.append_stream_item( + state.project_id, + execution_id, + sequence, + position, + parse_value(value) + ) do + :ok -> + {[], state} + + # Worker is trying to append to a stream the server has already + # closed (e.g., owner execution was cancelled). Surfacing this to + # the adapter would let it stop producing; for now, swallow so + # the stream-close propagation to the worker (task #10) is the + # canonical signal. + {:error, :closed} -> + {[], state} + + {:error, :not_registered} -> + {[], state} + + {:error, :already_appended} -> + {[], state} + + {:error, :not_found} -> + {[{:close, 4000, "execution_invalid"}], nil} + end + else + {[{:close, 4000, "execution_invalid"}], nil} + end + + "stream_close" -> + [execution_id, sequence, error] = message["params"] + + if is_recognised_execution?(execution_id, state) do + parsed_error = + case parse_error(error) do + nil -> nil + {type, message, frames, _retryable} -> {type, message, frames} + end + + case Orchestration.close_stream( + state.project_id, + execution_id, + sequence, + parsed_error + ) do + :ok -> {[], state} + {:error, :already_closed} -> {[], state} + {:error, :not_registered} -> {[], state} + {:error, :not_found} -> {[{:close, 4000, "execution_invalid"}], nil} + end + else + {[{:close, 4000, "execution_invalid"}], nil} + end + "put_error" -> [execution_id, error] = message["params"] diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index 6a57e584..d6bb80e9 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -181,6 +181,22 @@ defmodule Coflux.Orchestration do call_server(project_id, {:record_result, execution_id, result}) end + # Stream producer messages — worker registers a stream, appends items, + # and closes the stream. Sequence and position are worker-assigned and + # monotonic per-execution / per-stream. + + def register_stream(project_id, execution_id, sequence) do + call_server(project_id, {:register_stream, execution_id, sequence}) + end + + def append_stream_item(project_id, execution_id, sequence, position, value) do + call_server(project_id, {:append_stream_item, execution_id, sequence, position, value}) + end + + def close_stream(project_id, execution_id, sequence, error) do + call_server(project_id, {:close_stream, execution_id, sequence, error}) + end + def select( project_id, handles, diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 0a5df804..6732e45f 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1807,6 +1807,55 @@ defmodule Coflux.Orchestration.Server do end end + def handle_call({:register_stream, execution_external_id, sequence}, _from, state) do + case Map.fetch(state.execution_ids, execution_external_id) do + {:ok, execution_id} -> + case Streams.register_stream(state.db, execution_id, sequence) do + {:ok, _} -> {:reply, :ok, state} + {:error, :already_registered} -> {:reply, {:error, :already_registered}, state} + end + + :error -> + {:reply, {:error, :not_found}, state} + end + end + + def handle_call( + {:append_stream_item, execution_external_id, sequence, position, value}, + _from, + state + ) do + case Map.fetch(state.execution_ids, execution_external_id) do + {:ok, execution_id} -> + case Streams.append_item( + state.db, + execution_id, + sequence, + position, + normalize_value(value) + ) do + {:ok, _} -> {:reply, :ok, state} + {:error, reason} -> {:reply, {:error, reason}, state} + end + + :error -> + {:reply, {:error, :not_found}, state} + end + end + + def handle_call({:close_stream, execution_external_id, sequence, error}, _from, state) do + case Map.fetch(state.execution_ids, execution_external_id) do + {:ok, execution_id} -> + case Streams.close_stream(state.db, execution_id, sequence, error) do + {:ok, _} -> {:reply, :ok, state} + {:error, reason} -> {:reply, {:error, reason}, state} + end + + :error -> + {:reply, {:error, :not_found}, state} + end + end + def handle_call( {:select, handles, from_execution_external_id, timeout_ms, suspend, cancel_remaining, request_id}, From aba392e4fd155c0b36a5b7436d86c2fd78ff1301 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 19:38:30 +0100 Subject: [PATCH 05/25] Implement generator detector and driver --- adapters/python/coflux/context.py | 51 ++++--- adapters/python/coflux/dispatcher.py | 176 ++++++++++++++++++++++++ adapters/python/coflux/executor.py | 24 +++- adapters/python/coflux/protocol.py | 10 +- adapters/python/coflux/serialization.py | 31 ++++- adapters/python/coflux/streams.py | 97 +++++++++++++ adapters/python/coflux/target.py | 8 +- 7 files changed, 363 insertions(+), 34 deletions(-) create mode 100644 adapters/python/coflux/dispatcher.py create mode 100644 adapters/python/coflux/streams.py diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index fc781aa7..2fdf028e 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Iterator from . import protocol +from .dispatcher import get_dispatcher from .errors import ( ExecutionAbandoned, ExecutionCancelled, @@ -22,6 +23,7 @@ ) from .models import Asset, AssetEntry, AssetMetadata, Execution, Input from .serialization import deserialize_value, serialize_value +from .streams import StreamDriver def _handle_key(handle: Any) -> tuple[str, str]: @@ -81,7 +83,6 @@ class ExecutorContext: def __init__(self, execution_id: str, working_dir: Path | None = None): self.execution_id = execution_id - self._pending_requests: dict[int, Any] = {} self._groups: list[str | None] = [] self._working_dir = working_dir or Path.cwd() self._defined_metrics: dict[str, dict] = {} @@ -92,6 +93,23 @@ def __init__(self, execution_id: str, working_dir: Path | None = None): # poll_handle to avoid a round-trip for handles that have already # been seen in this context's lifetime. self._resolved: dict[tuple[str, str], dict[str, Any]] = {} + # Owns generator streams for this execution. Generators encountered + # during serialization (of the return value OR of submit arguments) + # are registered here and driven in background threads. + self._stream_driver = StreamDriver(execution_id) + + def register_stream(self, generator: Any) -> tuple[str, int]: + """Callback for ``serialize_value(on_generator=...)``. + + Registers a generator with this execution's driver and returns the + ``(execution_id, sequence)`` stream reference to embed in the + serialized value. + """ + return self._stream_driver.register(generator) + + def wait_streams(self) -> None: + """Block until every stream produced by this execution has drained.""" + self._stream_driver.wait_all() def submit_execution( self, @@ -520,8 +538,7 @@ def suspend_execution( request_id = protocol.request_suspend(self.execution_id, execute_after) self._wait_response(request_id) # Suspension confirmed. Block until the server aborts this execution. - while protocol.receive_message() is not None: - pass + get_dispatcher().wait_closed() raise SystemExit(0) def _parse_response(self, msg: dict) -> Any: @@ -532,22 +549,12 @@ def _parse_response(self, msg: dict) -> Any: return msg.get("result", {}) def _wait_response(self, request_id: int) -> Any: - """Wait for a response to a request.""" - if request_id in self._pending_requests: - return self._parse_response(self._pending_requests.pop(request_id)) - while True: - msg = protocol.receive_message() - if msg is None: - raise RuntimeError("Connection closed while waiting for response") - - # Check if this is a response - if "id" in msg: - if msg["id"] == request_id: - return self._parse_response(msg) - # Store other responses for later - self._pending_requests[msg["id"]] = msg - else: - # Unexpected message during wait - raise RuntimeError( - f"Unexpected message while waiting for response: {msg}" - ) + """Wait for a response to a request. + + Delegates to the dispatcher, which owns stdin and routes the matching + response to this caller. Safe to call from any thread. + """ + msg = get_dispatcher().wait_for_response(request_id) + if msg is None: + raise RuntimeError("Timed out waiting for response") + return self._parse_response(msg) diff --git a/adapters/python/coflux/dispatcher.py b/adapters/python/coflux/dispatcher.py new file mode 100644 index 00000000..00ddc594 --- /dev/null +++ b/adapters/python/coflux/dispatcher.py @@ -0,0 +1,176 @@ +"""Message dispatcher for concurrent stdio access. + +The CLI can send messages to the adapter at any time: + + * **Responses** — `{"id": N, "result": ...}` or `{"id": N, "error": ...}` + replying to an earlier request. + * **Notifications** — `{"method": "X", "params": {...}}` pushed without a + prior request (e.g. `stream_produce_until` flow-control signals for + producers, `stream_items` pushes for consumers). + +Multiple threads in the adapter may be waiting on different responses +simultaneously (e.g. each generator driver thread calling a Coflux subtask). +Reading stdin from multiple threads would corrupt the protocol, so one +dedicated reader thread owns stdin and dispatches what it reads: + + * Responses are routed to the specific waiter by request id. + * Notifications are routed by method name to registered handlers. + +Writers are separately protected by ``Protocol._write_lock``. + +Lifecycle: ``start()`` spawns the reader thread (daemon — dies with the +process on EOF). ``wait_for_response`` and ``wait_closed`` are the +blocking APIs used by the rest of the adapter. +""" + +from __future__ import annotations + +import threading +from typing import Any, Callable + +from .protocol import Protocol + + +class Dispatcher: + """Owns stdin; routes responses by request id, notifications by method.""" + + def __init__(self, protocol: Protocol) -> None: + self._protocol = protocol + self._lock = threading.Lock() + + # request_id → (Event, [response_msg | None]) + # The list is a one-slot mutable box because closures can't rebind. + self._waiting: dict[int, tuple[threading.Event, list[Any]]] = {} + + # Responses that arrived before their waiter registered — rare but + # possible if the writer thread and reader thread interleave. + self._early_responses: dict[int, dict[str, Any]] = {} + + # method → handler(params). Handlers run on the reader thread; keep + # them fast and non-blocking (delegate heavy work to queues or + # other threads). + self._notification_handlers: dict[str, Callable[[dict[str, Any]], None]] = {} + + # Set when stdin reaches EOF. Wakes all pending waiters. + self._closed = threading.Event() + + self._thread = threading.Thread( + target=self._run, + name="coflux-dispatcher", + daemon=True, + ) + + def start(self) -> None: + self._thread.start() + + def wait_closed(self) -> None: + """Block until stdin closes (server disconnected / aborted us).""" + self._closed.wait() + + def register_notification( + self, + method: str, + handler: Callable[[dict[str, Any]], None], + ) -> None: + """Register a handler for an incoming notification method. + + Handlers run on the reader thread — they must not block. For any + real work, enqueue to another thread. + """ + with self._lock: + self._notification_handlers[method] = handler + + def unregister_notification(self, method: str) -> None: + with self._lock: + self._notification_handlers.pop(method, None) + + def wait_for_response( + self, + request_id: int, + timeout: float | None = None, + ) -> dict[str, Any] | None: + """Block until the response for ``request_id`` arrives. + + Returns the raw response dict (``{"id": ..., "result": ...}`` or + ``{"id": ..., "error": ...}``). Returns ``None`` if the wait times + out. Raises ``RuntimeError`` if the connection closes before the + response arrives. + """ + event = threading.Event() + slot: list[Any] = [None] + + with self._lock: + if request_id in self._early_responses: + return self._early_responses.pop(request_id) + if self._closed.is_set(): + raise RuntimeError("Connection closed while waiting for response") + self._waiting[request_id] = (event, slot) + + try: + ready = event.wait(timeout) if timeout is not None else event.wait() + finally: + with self._lock: + self._waiting.pop(request_id, None) + + if not ready: + return None + if slot[0] is None: + # Woken by EOF rather than a real response. + raise RuntimeError("Connection closed while waiting for response") + return slot[0] + + def _run(self) -> None: + while True: + msg = self._protocol.receive() + if msg is None: + # EOF — wake all waiters with a null slot; they'll raise. + self._closed.set() + with self._lock: + for event, _slot in self._waiting.values(): + event.set() + return + + if "id" in msg: + request_id = msg["id"] + with self._lock: + entry = self._waiting.get(request_id) + if entry is not None: + _event, slot = entry + slot[0] = msg + _event.set() + else: + # Buffer for a waiter that registers later. + self._early_responses[request_id] = msg + elif "method" in msg: + with self._lock: + handler = self._notification_handlers.get(msg["method"]) + if handler is not None: + try: + handler(msg.get("params", {})) + except Exception: # noqa: BLE001 + # Don't let a handler fault kill the dispatcher. + # Adapter-side logging hooks into protocol anyway, + # but we swallow here rather than taking the loop down. + pass + # Silently ignore malformed messages rather than killing the + # reader — log once per session if we want to be strict. + + +# Module-level singleton, mirroring how Protocol is handled. +_dispatcher: Dispatcher | None = None + + +def get_dispatcher() -> Dispatcher: + """Return the active dispatcher. Raises if ``start_dispatcher`` hasn't run.""" + if _dispatcher is None: + raise RuntimeError("Dispatcher hasn't been started") + return _dispatcher + + +def start_dispatcher(protocol: Protocol) -> Dispatcher: + """Create and start the dispatcher. Idempotent.""" + global _dispatcher + if _dispatcher is None: + _dispatcher = Dispatcher(protocol) + _dispatcher.start() + return _dispatcher diff --git a/adapters/python/coflux/executor.py b/adapters/python/coflux/executor.py index 001069b3..ad58e545 100644 --- a/adapters/python/coflux/executor.py +++ b/adapters/python/coflux/executor.py @@ -13,6 +13,7 @@ from . import protocol from .context import ExecutorContext +from .dispatcher import start_dispatcher from .state import set_context from .output import capture_output from .models import Input @@ -71,6 +72,10 @@ def execute_target( ) -> None: """Execute a target with the given arguments.""" original_dir = os.getcwd() + # Start the stdin dispatcher. From here on, all incoming messages flow + # through it — individual threads block on the dispatcher rather than + # racing on stdin directly. + start_dispatcher(protocol.get_protocol()) try: if working_dir: os.chdir(working_dir) @@ -92,11 +97,10 @@ def execute_target( deserialized_args = _apply_type_hints(fn, deserialized_args) # Set up execution context - set_context( - ExecutorContext( - execution_id, working_dir=Path(working_dir) if working_dir else None - ) + ctx = ExecutorContext( + execution_id, working_dir=Path(working_dir) if working_dir else None ) + set_context(ctx) with capture_output(execution_id): if inspect.iscoroutinefunction(fn): @@ -107,10 +111,18 @@ def execute_target( else: result = fn(*deserialized_args) - # Serialize and send result - result_value = serialize_value(result) + # Serialize result. Generators anywhere in the return value (or that + # were passed to submitted child executions as args) have already + # been registered with the context's stream driver. + result_value = serialize_value(result, on_generator=ctx.register_stream) protocol.send_execution_result(execution_id, result_value) + # Hold the process open until every stream has drained. Thread + # safety: stdin access goes through the dispatcher (so subtask + # calls from generator bodies don't race), and stdout writes are + # serialised by Protocol._write_lock. + ctx.wait_streams() + except Exception as e: # Evaluate retry 'when' callback if present # None = no callback configured, True/False = callback result diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index 914f6c51..9d02d992 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -4,6 +4,7 @@ import json import sys +import threading from typing import Any from ._version import __version__ @@ -18,6 +19,10 @@ def __init__(self) -> None: # even when stdout/stderr are redirected for output capture self._stdout = sys.stdout self._stdin = sys.stdin + # Multiple threads (main + stream drivers + dispatcher-invoked + # handlers) can emit messages concurrently; serialize writes so JSON + # lines don't interleave. + self._write_lock = threading.Lock() def send_message(self, method: str, params: dict[str, Any] | None = None) -> None: """Send a notification message (no response expected).""" @@ -61,8 +66,9 @@ def receive(self) -> dict[str, Any] | None: def _write(self, obj: dict[str, Any]) -> None: """Write a JSON object as a line to stdout.""" line = json.dumps(obj, separators=(",", ":")) - self._stdout.write(line + "\n") - self._stdout.flush() + with self._write_lock: + self._stdout.write(line + "\n") + self._stdout.flush() # Global protocol instance for convenience diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index 8f8380bb..d73b64a7 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -7,6 +7,7 @@ import datetime import decimal import importlib +import inspect import io import json import pickle @@ -42,6 +43,7 @@ def _write_temp_file(data: bytes) -> str: def _encode_value( value: Any, write_temp_file: Callable[[bytes], str] = _write_temp_file, + on_generator: Callable[[Any], tuple[str, int]] | None = None, ) -> tuple[Any, list[list[Any]]]: """Encode a Python value using the custom JSON value format. @@ -53,6 +55,10 @@ def _encode_value( value: The Python value to encode. write_temp_file: Callable that writes bytes to a temp file and returns the path. Used for pickle fragment references. + on_generator: Callback invoked for each generator encountered. Should + register the generator (spawn its driver) and return the + `(execution_id, sequence)` identifying the stream. If None, + encountering a generator raises TypeError. Returns: Tuple of (data, references) where data is JSON-serializable and @@ -63,6 +69,22 @@ def _encode_value( def _encode(v: Any) -> Any: if v is None or isinstance(v, (str, bool, int, float)): return v + elif inspect.isgenerator(v): + if on_generator is None: + raise TypeError( + "Cannot serialize a generator: no stream driver is active." + ) + execution_id, sequence = on_generator(v) + return { + "type": "stream", + "execution_id": execution_id, + "sequence": sequence, + } + elif inspect.isasyncgen(v): + raise TypeError( + "Async generators aren't supported yet — use a sync generator " + "(def + yield) for now." + ) elif isinstance(v, list): return [_encode(x) for x in v] elif isinstance(v, dict): @@ -155,7 +177,10 @@ def _encode(v: Any) -> Any: return data, references -def serialize_value(value: Any) -> dict[str, Any]: +def serialize_value( + value: Any, + on_generator: Callable[[Any], tuple[str, int]] | None = None, +) -> dict[str, Any]: """Serialize a result value to the protocol format. Uses the custom JSON value encoding (dict/set/tuple types, fragment refs @@ -163,11 +188,13 @@ def serialize_value(value: Any) -> dict[str, Any]: Args: value: The Python value to serialize. + on_generator: Optional callback for generator objects. See + `_encode_value` for the contract. Without it, generators raise. Returns: Serialized value dict. """ - data, references = _encode_value(value) + data, references = _encode_value(value, on_generator=on_generator) encoded = json.dumps(data, separators=(",", ":")).encode() if len(encoded) > TRANSFER_THRESHOLD: diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py new file mode 100644 index 00000000..3cd23fa2 --- /dev/null +++ b/adapters/python/coflux/streams.py @@ -0,0 +1,97 @@ +"""Producer-side stream management. + +Each execution that returns a value containing generators spins up a +``StreamDriver``. The driver: + + * Assigns monotonic sequence numbers to generators in encounter order. + * Registers each stream with the server (``stream_register``) at + encoding time so consumers can subscribe as soon as the result is + visible. + * Spawns one background thread per generator — a slow generator + doesn't block sibling streams. + * Joins all driver threads before the executor exits, so the process + stays alive until every stream has drained naturally. + +Threading is safe here because the ``Dispatcher`` owns stdin: any +subtask call from a generator body gets its response routed back to the +right driver thread. Writes to stdout go through ``Protocol._write_lock``. +""" + +from __future__ import annotations + +import threading +import traceback +from typing import Any + +from . import protocol +from .serialization import serialize_value + + +class StreamDriver: + """Manages streams produced by a single execution.""" + + def __init__(self, execution_id: str) -> None: + self._execution_id = execution_id + self._next_sequence = 0 + self._threads: list[threading.Thread] = [] + self._lock = threading.Lock() + + def register(self, generator: Any) -> tuple[str, int]: + """Register a generator, spawn its driver thread. + + Returns ``(execution_id, sequence)`` for embedding in the serialized + value as a stream reference. + """ + with self._lock: + sequence = self._next_sequence + self._next_sequence += 1 + + protocol.send_stream_register(self._execution_id, sequence) + + thread = threading.Thread( + target=self._drive, + args=(sequence, generator), + name=f"stream-{self._execution_id}-{sequence}", + daemon=False, + ) + thread.start() + self._threads.append(thread) + + return self._execution_id, sequence + + def _drive(self, sequence: int, generator: Any) -> None: + """Pump one generator to exhaustion (or error).""" + position = 0 + try: + for item in generator: + serialized = serialize_value(item) + protocol.send_stream_append( + self._execution_id, + sequence, + position, + serialized, + ) + position += 1 + except GeneratorExit: + # Generator explicitly closed (e.g. execution cancelled). The + # server already knows — no close message needed. + return + except BaseException as e: # noqa: BLE001 - we propagate all + error_type = f"{type(e).__module__}.{type(e).__qualname__}" + tb = traceback.format_exc() + protocol.send_stream_close( + self._execution_id, + sequence, + error_type=error_type, + error_message=str(e), + traceback=tb, + ) + else: + protocol.send_stream_close(self._execution_id, sequence) + + def wait_all(self) -> None: + """Block until every driver thread has finished.""" + with self._lock: + threads = list(self._threads) + for t in threads: + t.join() diff --git a/adapters/python/coflux/target.py b/adapters/python/coflux/target.py index 4d33f492..e65a54cc 100644 --- a/adapters/python/coflux/target.py +++ b/adapters/python/coflux/target.py @@ -409,8 +409,12 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Execution[T]: ctx = get_context() - # Serialize arguments - serialized_args = [serialize_value(arg) for arg in args] + # Serialize arguments. Generators passed as args are registered with + # the current execution's stream driver — the caller becomes the + # producer, the callee gets a Stream handle. + serialized_args = [ + serialize_value(arg, on_generator=ctx.register_stream) for arg in args + ] # Use only the declared wait_for from the decorator wait_for_val = ( From 24ddfc14f80da8eaba1f00067d386627c08bc44a Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 20:04:43 +0100 Subject: [PATCH 06/25] Setup stream consumers --- adapters/python/coflux/__init__.py | 11 +- adapters/python/coflux/models.py | 69 ++++ adapters/python/coflux/protocol.py | 33 ++ adapters/python/coflux/serialization.py | 6 +- adapters/python/coflux/streams.py | 196 ++++++++++-- cli/internal/adapter/protocol.go | 34 ++ cli/internal/pool/pool.go | 63 ++++ cli/internal/worker/worker.go | 149 +++++++-- server/lib/coflux/handlers/worker.ex | 70 +++++ server/lib/coflux/orchestration.ex | 25 ++ server/lib/coflux/orchestration/server.ex | 365 ++++++++++++++++++++-- 11 files changed, 947 insertions(+), 74 deletions(-) diff --git a/adapters/python/coflux/__init__.py b/adapters/python/coflux/__init__.py index 6a9e9cf6..ea7e4e0a 100644 --- a/adapters/python/coflux/__init__.py +++ b/adapters/python/coflux/__init__.py @@ -23,7 +23,15 @@ InputDismissed, ) from .metric import Metric, MetricGroup, MetricScale, progress -from .models import Asset, AssetEntry, AssetMetadata, Execution, Input, ModelSchema +from .models import ( + Asset, + AssetEntry, + AssetMetadata, + Execution, + Input, + ModelSchema, + Stream, +) from .prompt import Prompt from .state import get_context from .target import Cache, Defer, Retries @@ -56,6 +64,7 @@ "Asset", "AssetEntry", "AssetMetadata", + "Stream", # Context functions "group", "suspense", diff --git a/adapters/python/coflux/models.py b/adapters/python/coflux/models.py index f99d5512..3048a87d 100644 --- a/adapters/python/coflux/models.py +++ b/adapters/python/coflux/models.py @@ -208,3 +208,72 @@ def module(self) -> str: @property def target(self) -> str: return self._target + + +class Stream(t.Iterable[T]): + """A handle to a stream produced by another execution. + + Iterating a ``Stream`` opens a subscription with the server; items arrive + pushed over the WebSocket and yield from the iterator. Each ``__iter__`` + starts a fresh subscription from position 0, so a stream can be iterated + multiple times and each iteration sees the whole sequence. + + ``partition`` and ``slice`` return new ``Stream`` views with an additional + filter; no server round-trip happens until iteration begins. + """ + + def __init__( + self, + producer_execution_id: str, + sequence: int, + filters: tuple[dict[str, t.Any], ...] = (), + ): + self._producer_execution_id = producer_execution_id + self._sequence = sequence + self._filters = filters + + @property + def producer_execution_id(self) -> str: + return self._producer_execution_id + + @property + def sequence(self) -> int: + return self._sequence + + def partition(self, n: int, i: int) -> "Stream[T]": + """Return a view of this stream where only positions ``p`` with + ``p % n == i`` are delivered. Round-robin partitioning for parallel + consumers. + """ + if n < 1 or i < 0 or i >= n: + raise ValueError(f"invalid partition args: n={n}, i={i}") + return Stream( + self._producer_execution_id, + self._sequence, + self._filters + ({"type": "partition", "n": n, "i": i},), + ) + + def slice(self, start: int, stop: int | None = None) -> "Stream[T]": + """Return a view of this stream restricted to positions ``[start, stop)``. + + ``stop=None`` means unbounded. Equivalent to ``itertools.islice`` on + the source stream's positions. + """ + if start < 0 or (stop is not None and stop < start): + raise ValueError(f"invalid slice args: start={start}, stop={stop}") + return Stream( + self._producer_execution_id, + self._sequence, + self._filters + ({"type": "slice", "start": start, "stop": stop},), + ) + + def __iter__(self) -> t.Iterator[T]: + # Deferred import to avoid a cycle (streams.py imports serialization + # which imports models for Execution/Input/Asset). + from .streams import open_subscription + + return open_subscription( + self._producer_execution_id, + self._sequence, + self._filters, + ) diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index 9d02d992..14000574 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -519,6 +519,39 @@ def send_stream_close( get_protocol().send_message("stream_close", params) +def send_stream_subscribe( + execution_id: str, + subscription_id: int, + producer_execution_id: str, + sequence: int, + from_position: int, + filter: dict[str, Any] | None = None, +) -> None: + """Open a consumer subscription to a stream owned by another execution. + + ``execution_id`` is the consumer's own execution — the server uses it to + track who's subscribed and where to push items. + """ + params: dict[str, Any] = { + "execution_id": execution_id, + "subscription_id": subscription_id, + "producer_execution_id": producer_execution_id, + "sequence": sequence, + "from_position": from_position, + } + if filter is not None: + params["filter"] = filter + get_protocol().send_message("stream_subscribe", params) + + +def send_stream_unsubscribe(execution_id: str, subscription_id: int) -> None: + """Drop a consumer subscription.""" + get_protocol().send_message( + "stream_unsubscribe", + {"execution_id": execution_id, "subscription_id": subscription_id}, + ) + + def receive_message() -> dict[str, Any] | None: """Receive the next message from the CLI.""" return get_protocol().receive() diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index d73b64a7..d2746370 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -16,7 +16,7 @@ from pathlib import Path from typing import Any, Callable -from .models import Asset, AssetMetadata, Execution, Input +from .models import Asset, AssetMetadata, Execution, Input, Stream # Try to import pydantic try: @@ -259,6 +259,10 @@ def _decode(v: Any) -> Any: return uuid.UUID(v["value"]) elif t == "ref": return _resolve_ref(v["index"]) + elif t == "stream": + # Producer-owned stream reference. Self-contained — + # execution_id and sequence are both in the descriptor. + return Stream(v["execution_id"], v["sequence"]) else: return v else: diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index 3cd23fa2..cbb96c64 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -1,30 +1,35 @@ -"""Producer-side stream management. - -Each execution that returns a value containing generators spins up a -``StreamDriver``. The driver: - - * Assigns monotonic sequence numbers to generators in encounter order. - * Registers each stream with the server (``stream_register``) at - encoding time so consumers can subscribe as soon as the result is - visible. - * Spawns one background thread per generator — a slow generator - doesn't block sibling streams. - * Joins all driver threads before the executor exits, so the process - stays alive until every stream has drained naturally. - -Threading is safe here because the ``Dispatcher`` owns stdin: any -subtask call from a generator body gets its response routed back to the -right driver thread. Writes to stdout go through ``Protocol._write_lock``. +"""Producer and consumer stream plumbing. + +The producer side owns ``StreamDriver``: each execution whose return value +(or submitted arguments) contains generators uses one to drive each +generator in a background thread. + +The consumer side owns a module-level ``StreamRegistry``: open consumer +subscriptions are keyed by subscription id. The registry's dispatcher +handlers (``stream_items``/``stream_closed``) route incoming pushes from +the server to the right iterator's queue, which yields as the user +iterates. + +Both sides are thread-safe: the ``Dispatcher`` owns stdin (so subtask +calls from generator bodies don't race), and stdout writes go through +``Protocol._write_lock``. """ from __future__ import annotations +import queue import threading import traceback -from typing import Any +from typing import Any, Iterator from . import protocol -from .serialization import serialize_value +from .dispatcher import get_dispatcher +from .errors import create_execution_error +from .serialization import deserialize_value, serialize_value +from .state import get_context + + +# --- Producer side --- class StreamDriver: @@ -95,3 +100,156 @@ def wait_all(self) -> None: threads = list(self._threads) for t in threads: t.join() + + +# --- Consumer side --- + + +# Sentinel pushed onto a subscriber's queue to signal close. Carries the +# optional error dict ({"type": str, "message": str} or None). +class _Closed: + __slots__ = ("error",) + + def __init__(self, error: dict[str, Any] | None) -> None: + self.error = error + + +class _StreamIterator(Iterator[Any]): + """Drains items for one active subscription via a bounded-free queue.""" + + def __init__(self, subscription_id: int, execution_id: str) -> None: + self._subscription_id = subscription_id + self._execution_id = execution_id + self._queue: queue.Queue[Any] = queue.Queue() + self._done = False + + def on_items(self, items: list[list[Any]]) -> None: + """Called by the registry when the server pushes items for this + subscription. ``items`` is a list of ``[position, value_wire]``. + """ + for _position, value in items: + # Decode eagerly so iteration cost is paid per-item as it arrives. + self._queue.put(deserialize_value(value)) + + def on_closed(self, error: dict[str, Any] | None) -> None: + """Called by the registry when the stream closes.""" + self._queue.put(_Closed(error)) + + def __iter__(self) -> "_StreamIterator": + return self + + def __next__(self) -> Any: + if self._done: + raise StopIteration + item = self._queue.get() + if isinstance(item, _Closed): + self._done = True + _stream_registry().drop(self._subscription_id) + protocol.send_stream_unsubscribe(self._execution_id, self._subscription_id) + if item.error is not None: + raise create_execution_error( + item.error.get("type", ""), + item.error.get("message", ""), + ) + raise StopIteration + return item + + +class StreamRegistry: + """Per-process registry of open consumer subscriptions.""" + + def __init__(self) -> None: + self._lock = threading.Lock() + self._next_id = 0 + self._iterators: dict[int, _StreamIterator] = {} + self._installed = False + + def _ensure_installed(self) -> None: + # Register dispatcher handlers on first use. Deferred so importing + # this module is free until a task actually iterates a stream. + if self._installed: + return + d = get_dispatcher() + d.register_notification("stream_items", self._on_items) + d.register_notification("stream_closed", self._on_closed) + self._installed = True + + def allocate(self, execution_id: str) -> tuple[int, _StreamIterator]: + """Claim a subscription id and iterator.""" + self._ensure_installed() + with self._lock: + subscription_id = self._next_id + self._next_id += 1 + it = _StreamIterator(subscription_id, execution_id) + self._iterators[subscription_id] = it + return subscription_id, it + + def drop(self, subscription_id: int) -> None: + with self._lock: + self._iterators.pop(subscription_id, None) + + def _on_items(self, params: dict[str, Any]) -> None: + subscription_id = params.get("subscription_id") + items = params.get("items") or [] + with self._lock: + it = self._iterators.get(subscription_id) + if it is not None: + it.on_items(items) + + def _on_closed(self, params: dict[str, Any]) -> None: + subscription_id = params.get("subscription_id") + error = params.get("error") + with self._lock: + it = self._iterators.get(subscription_id) + if it is not None: + it.on_closed(error) + + +_registry_instance: StreamRegistry | None = None + + +def _stream_registry() -> StreamRegistry: + global _registry_instance + if _registry_instance is None: + _registry_instance = StreamRegistry() + return _registry_instance + + +def open_subscription( + producer_execution_id: str, + sequence: int, + filters: tuple[dict[str, Any], ...], +) -> Iterator[Any]: + """Begin iterating a stream. Called by ``Stream.__iter__``. + + Allocates a subscription id, sends the subscribe message, and returns + an iterator that yields as items arrive. + """ + ctx = get_context() + execution_id = ctx.execution_id + subscription_id, iterator = _stream_registry().allocate(execution_id) + + filter = _compose_filter(filters) + protocol.send_stream_subscribe( + execution_id, + subscription_id, + producer_execution_id, + sequence, + 0, + filter, + ) + return iterator + + +def _compose_filter( + filters: tuple[dict[str, Any], ...], +) -> dict[str, Any] | None: + """Collapse a list of filters for the wire. + + Empty → null. Single → pass through. Many → wrap in {"type": "chain"}. + """ + if not filters: + return None + if len(filters) == 1: + return filters[0] + return {"type": "chain", "filters": list(filters)} diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index 594788bc..eb1903c1 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -284,6 +284,40 @@ type StreamCloseError struct { Traceback string `json:"traceback"` } +// StreamSubscribeParams for stream_subscribe notification. +// Filter is one of nil, {"type": "slice", "start", "stop"}, +// or {"type": "partition", "n", "i"}. +type StreamSubscribeParams struct { + ExecutionID string `json:"execution_id"` // consumer + SubscriptionID int `json:"subscription_id"` + ProducerExecutionID string `json:"producer_execution_id"` + Sequence int `json:"sequence"` + FromPosition int `json:"from_position"` + Filter map[string]any `json:"filter,omitempty"` +} + +// StreamUnsubscribeParams for stream_unsubscribe notification. +type StreamUnsubscribeParams struct { + ExecutionID string `json:"execution_id"` + SubscriptionID int `json:"subscription_id"` +} + +// StreamItemsParams for stream_items notification pushed CLI → adapter. +// Items are [[position, value], ...] where value is a wire Value. +type StreamItemsParams struct { + ExecutionID string `json:"execution_id"` + SubscriptionID int `json:"subscription_id"` + Items []any `json:"items"` +} + +// StreamClosedParams for stream_closed notification pushed CLI → adapter. +// Error is nil for clean close or a {type, message} dict for errored close. +type StreamClosedParams struct { + ExecutionID string `json:"execution_id"` + SubscriptionID int `json:"subscription_id"` + Error map[string]any `json:"error,omitempty"` +} + // DownloadBlobParams for download_blob request type DownloadBlobParams struct { ExecutionID string `json:"execution_id"` diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index 9ef56572..d0fae65c 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -57,6 +57,11 @@ type ExecutionHandler interface { // StreamClose closes a stream. Error is nil for a clean close, or a (type, message, traceback) // triple when the producer's generator raised. StreamClose(ctx context.Context, executionID string, sequence int, err *adapter.StreamCloseError) error + // StreamSubscribe opens a consumer subscription to a stream owned by another execution. + // Filter is nil or a {"type": "slice", ...}/{"type": "partition", ...} map. + StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, sequence int, fromPosition int, filter map[string]any) error + // StreamUnsubscribe drops a consumer subscription. + StreamUnsubscribe(ctx context.Context, executionID string, subscriptionID int) error } // Pool manages executor processes. Each executor handles one execution then @@ -281,6 +286,12 @@ loop: case "stream_close": p.handleStreamClose(execCtx, executionID, params, logger) + case "stream_subscribe": + p.handleStreamSubscribe(execCtx, executionID, params, logger) + + case "stream_unsubscribe": + p.handleStreamUnsubscribe(execCtx, executionID, params, logger) + default: err := fmt.Errorf("unknown message method: %s", method) logger.Error("unknown message method", "method", method) @@ -488,6 +499,38 @@ func (p *Pool) handleStreamClose(ctx context.Context, executionID string, params } } +func (p *Pool) handleStreamSubscribe(ctx context.Context, executionID string, params json.RawMessage, logger *slog.Logger) { + var req adapter.StreamSubscribeParams + if err := json.Unmarshal(params, &req); err != nil { + logger.Error("failed to parse stream_subscribe message", "error", err) + return + } + + if err := p.handler.StreamSubscribe( + ctx, + req.ExecutionID, + req.SubscriptionID, + req.ProducerExecutionID, + req.Sequence, + req.FromPosition, + req.Filter, + ); err != nil { + logger.Error("failed to subscribe to stream", "error", err) + } +} + +func (p *Pool) handleStreamUnsubscribe(ctx context.Context, executionID string, params json.RawMessage, logger *slog.Logger) { + var req adapter.StreamUnsubscribeParams + if err := json.Unmarshal(params, &req); err != nil { + logger.Error("failed to parse stream_unsubscribe message", "error", err) + return + } + + if err := p.handler.StreamUnsubscribe(ctx, req.ExecutionID, req.SubscriptionID); err != nil { + logger.Error("failed to unsubscribe from stream", "error", err) + } +} + func (p *Pool) handleRequest(ctx context.Context, exec *adapter.Executor, method string, id int, params json.RawMessage, logger *slog.Logger) { var result any var errInfo *adapter.ErrorInfo @@ -646,6 +689,26 @@ func (p *Pool) Abort(executionID string) error { return nil } +// PushToExecutor forwards a server-originated notification to the adapter +// process handling the given execution. Used for stream_items / stream_closed +// pushes and, later, stream_produce_until flow-control signals. Silently +// no-ops if the execution isn't active — a late push after the adapter exited +// isn't an error condition. +func (p *Pool) PushToExecutor(executionID, method string, params any) error { + p.mu.Lock() + exec, ok := p.busy[executionID] + p.mu.Unlock() + + if !ok { + return nil + } + + return exec.Send(map[string]any{ + "method": method, + "params": params, + }) +} + // Drain marks the pool as draining (stops spawning warm executors), // closes any existing warm executors, and waits up to timeout for // in-flight executions to finish on their own. A timeout of 0 means diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index 3058c819..cbf6e72e 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -360,6 +360,8 @@ func (w *Worker) runConnection(ctx context.Context, targets map[string]map[strin ) conn.RegisterHandler("execute", w.handleExecute) conn.RegisterHandler("abort", w.handleAbort) + conn.RegisterHandler("stream_items", w.handleStreamItems) + conn.RegisterHandler("stream_closed", w.handleStreamClosed) conn.SetOnSession(w.handleSession) if err := conn.Connect(ctx); err != nil { @@ -553,45 +555,55 @@ func (w *Worker) handleExecute(params []any) error { func (w *Worker) convertArguments(args []any) ([]adapter.Argument, error) { result := make([]adapter.Argument, len(args)) for i, arg := range args { - arr, ok := arg.([]any) - if !ok { - return nil, fmt.Errorf("argument %d: expected array", i) - } - - value, err := api.ParseValue(arr) + value, err := w.convertValueFromServer(arg) if err != nil { return nil, fmt.Errorf("argument %d: %w", i, err) } + result[i] = *value + } + return result, nil +} - // Convert to adapter argument - adapterRefs, err := w.refsToAdapter(value.References) +// convertValueFromServer turns a wire-form value array (["raw", data, refs] +// or ["blob", key, size, refs]) into an adapter-side Value struct suitable +// for forwarding to the Python adapter. +func (w *Worker) convertValueFromServer(arg any) (*adapter.Value, error) { + arr, ok := arg.([]any) + if !ok { + return nil, fmt.Errorf("expected array") + } + + value, err := api.ParseValue(arr) + if err != nil { + return nil, err + } + + adapterRefs, err := w.refsToAdapter(value.References) + if err != nil { + return nil, err + } + + switch value.Type { + case api.ValueTypeRaw: + return &adapter.Value{ + Type: "inline", + Format: "json", + Value: value.Content, + References: adapterRefs, + }, nil + case api.ValueTypeBlob: + path, err := w.blobs.Download(value.Key) if err != nil { - return nil, fmt.Errorf("argument %d: %w", i, err) - } - switch value.Type { - case api.ValueTypeRaw: - result[i] = adapter.Argument{ - Type: "inline", - Format: "json", - Value: value.Content, - References: adapterRefs, - } - case api.ValueTypeBlob: - // Download blob to cache - path, err := w.blobs.Download(value.Key) - if err != nil { - return nil, fmt.Errorf("argument %d: failed to download blob: %w", i, err) - } - format := "json" - result[i] = adapter.Argument{ - Type: "file", - Format: format, - Path: path, - References: adapterRefs, - } + return nil, fmt.Errorf("failed to download blob: %w", err) } + return &adapter.Value{ + Type: "file", + Format: "json", + Path: path, + References: adapterRefs, + }, nil } - return result, nil + return nil, fmt.Errorf("unknown value type: %v", value.Type) } func (w *Worker) refsToAdapter(refs []api.Reference) ([][]any, error) { @@ -637,6 +649,62 @@ func (w *Worker) handleAbort(params []any) error { return w.pool.Abort(executionID) } +// handleStreamItems forwards a server-pushed batch of stream items to the +// adapter process owning the target execution. Params: [execution_id, +// subscription_id, items]. Each item arrives as [position, value_array] +// and is converted to [position, adapter.Value dict] so the Python side +// can deserialize_value it directly. +func (w *Worker) handleStreamItems(params []any) error { + if len(params) < 3 { + return fmt.Errorf("stream_items: insufficient params") + } + executionID := getString(params[0]) + subscriptionID, _ := params[1].(float64) + rawItems, ok := params[2].([]any) + if !ok { + return fmt.Errorf("stream_items: items is not an array") + } + + converted := make([]any, len(rawItems)) + for i, raw := range rawItems { + itemArr, ok := raw.([]any) + if !ok || len(itemArr) != 2 { + return fmt.Errorf("stream_items: item %d malformed", i) + } + value, err := w.convertValueFromServer(itemArr[1]) + if err != nil { + return fmt.Errorf("stream_items: item %d value: %w", i, err) + } + converted[i] = []any{itemArr[0], value} + } + + return w.pool.PushToExecutor(executionID, "stream_items", map[string]any{ + "execution_id": executionID, + "subscription_id": int(subscriptionID), + "items": converted, + }) +} + +// handleStreamClosed forwards a server-pushed stream-closed notification. +// Params: [execution_id, subscription_id, error_or_null]. +func (w *Worker) handleStreamClosed(params []any) error { + if len(params) < 3 { + return fmt.Errorf("stream_closed: insufficient params") + } + executionID := getString(params[0]) + subscriptionID, _ := params[1].(float64) + errField := params[2] + + forwarded := map[string]any{ + "execution_id": executionID, + "subscription_id": int(subscriptionID), + } + if errField != nil { + forwarded["error"] = errField + } + return w.pool.PushToExecutor(executionID, "stream_closed", forwarded) +} + func (w *Worker) heartbeatLoop(ctx context.Context) { ticker := time.NewTicker(heartbeatInterval) defer ticker.Stop() @@ -1138,6 +1206,23 @@ func (w *Worker) StreamClose(ctx context.Context, executionID string, sequence i return conn.Notify("stream_close", executionID, sequence, errTuple) } +func (w *Worker) StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, sequence int, fromPosition int, filter map[string]any) error { + conn, err := w.requireConn() + if err != nil { + return err + } + // Params: [subscription_id, consumer_execution_id, producer_execution_id, sequence, from_position, filter] + return conn.Notify("stream_subscribe", subscriptionID, executionID, producerExecutionID, sequence, fromPosition, filter) +} + +func (w *Worker) StreamUnsubscribe(ctx context.Context, executionID string, subscriptionID int) error { + conn, err := w.requireConn() + if err != nil { + return err + } + return conn.Notify("stream_unsubscribe", subscriptionID) +} + func (w *Worker) Cancel(ctx context.Context, executionID string, handles []adapter.SelectHandle) error { conn, err := w.waitForConn(ctx) if err != nil { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index 3c4f29a4..9397a0b8 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -327,6 +327,61 @@ defmodule Coflux.Handlers.Worker do {[{:close, 4000, "execution_invalid"}], nil} end + "stream_subscribe" -> + [ + subscription_id, + consumer_execution_id, + producer_execution_id, + sequence, + from_position, + filter + ] = message["params"] + + if is_recognised_execution?(consumer_execution_id, state) do + case Orchestration.subscribe_stream( + state.project_id, + state.session_id, + subscription_id, + consumer_execution_id, + producer_execution_id, + sequence, + from_position, + filter + ) do + :ok -> + {[], state} + + # If the stream doesn't exist yet (or producer vanished), push an + # immediate close so the consumer doesn't wait forever. + {:error, reason} + when reason in [:stream_not_found, :producer_not_found, :already_subscribed] -> + {[ + command_message("stream_closed", [ + consumer_execution_id, + subscription_id, + %{"type" => "Coflux.StreamNotFound", "message" => Atom.to_string(reason)} + ]) + ], state} + + {:error, :consumer_not_found} -> + {[{:close, 4000, "execution_invalid"}], nil} + end + else + {[{:close, 4000, "execution_invalid"}], nil} + end + + "stream_unsubscribe" -> + [subscription_id] = message["params"] + + :ok = + Orchestration.unsubscribe_stream( + state.project_id, + state.session_id, + subscription_id + ) + + {[], state} + "put_error" -> [execution_id, error] = message["params"] @@ -572,6 +627,21 @@ defmodule Coflux.Handlers.Worker do {[command_message("abort", [execution_external_id])], state} end + def websocket_info({:stream_items, execution_external_id, subscription_id, items}, state) do + # Items arrive in resolved form ([[position, value_tuple], ...]); compose + # each value tuple to wire JSON here. + encoded = + Enum.map(items, fn [position, value] -> + [position, compose_value(value)] + end) + + {[command_message("stream_items", [execution_external_id, subscription_id, encoded])], state} + end + + def websocket_info({:stream_closed, execution_external_id, subscription_id, error}, state) do + {[command_message("stream_closed", [execution_external_id, subscription_id, error])], state} + end + def websocket_info(:stop, state) do {[{:close, 4000, "workspace_not_found"}], state} end diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index d6bb80e9..9a90986d 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -197,6 +197,31 @@ defmodule Coflux.Orchestration do call_server(project_id, {:close_stream, execution_id, sequence, error}) end + # Stream consumer messages — consumer opens a subscription to receive + # items from a producer's stream; server pushes stream_items / + # stream_closed commands to the consumer's session. + + def subscribe_stream( + project_id, + session_id, + subscription_id, + consumer_execution_id, + producer_execution_id, + sequence, + from_position, + filter + ) do + call_server( + project_id, + {:subscribe_stream, session_id, subscription_id, consumer_execution_id, + producer_execution_id, sequence, from_position, filter} + ) + end + + def unsubscribe_stream(project_id, session_id, subscription_id) do + call_server(project_id, {:unsubscribe_stream, session_id, subscription_id}) + end + def select( project_id, handles, diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 6732e45f..2ae912de 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -115,7 +115,23 @@ defmodule Coflux.Orchestration.Server do pending_dependencies: %{}, # execution_id -> MapSet of execution_ids that are waiting on this execution - dependency_waiters: %{} + dependency_waiters: %{}, + + # Active stream subscriptions — in-memory, session-scoped. + # A consumer adapter opens a subscription by sending stream_subscribe + # with a session-unique subscription_id; we push items (stream_items + # command) as they arrive on the producer side, and a terminal + # stream_closed command when the stream ends. Dropped when the + # session disconnects, when the consumer unsubscribes, or when the + # stream closes. + # + # stream_subscriptions: {session_id, subscription_id} -> %{ + # consumer_execution_id, producer_execution_id, sequence, + # cursor, filter} + # stream_subscribers: {producer_execution_id, sequence} -> MapSet of + # {session_id, subscription_id} + stream_subscriptions: %{}, + stream_subscribers: %{} end def start_link(opts) do @@ -1834,8 +1850,12 @@ defmodule Coflux.Orchestration.Server do position, normalize_value(value) ) do - {:ok, _} -> {:reply, :ok, state} - {:error, reason} -> {:reply, {:error, reason}, state} + {:ok, _} -> + state = push_stream_item(state, execution_id, sequence, position, value) + {:reply, :ok, state} + + {:error, reason} -> + {:reply, {:error, reason}, state} end :error -> @@ -1847,8 +1867,12 @@ defmodule Coflux.Orchestration.Server do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> case Streams.close_stream(state.db, execution_id, sequence, error) do - {:ok, _} -> {:reply, :ok, state} - {:error, reason} -> {:reply, {:error, reason}, state} + {:ok, _} -> + state = push_stream_closed(state, execution_id, sequence, error) + {:reply, :ok, state} + + {:error, reason} -> + {:reply, {:error, reason}, state} end :error -> @@ -1856,6 +1880,68 @@ defmodule Coflux.Orchestration.Server do end end + def handle_call( + {:subscribe_stream, session_external_id, subscription_id, consumer_execution_external_id, + producer_execution_external_id, sequence, from_position, filter}, + _from, + state + ) do + with {:ok, session_id} <- + Map.fetch(state.session_ids, session_external_id) + |> ok_or(:session_not_found), + {:ok, consumer_execution_id} <- + Map.fetch(state.execution_ids, consumer_execution_external_id) + |> ok_or(:consumer_not_found), + {:ok, producer_execution_id} <- + Map.fetch(state.execution_ids, producer_execution_external_id) + |> ok_or(:producer_not_found), + {:ok, true} <- Streams.exists?(state.db, producer_execution_id, sequence), + key = {session_id, subscription_id}, + false <- Map.has_key?(state.stream_subscriptions, key) do + subscription = %{ + consumer_execution_id: consumer_execution_id, + consumer_execution_external_id: consumer_execution_external_id, + producer_execution_id: producer_execution_id, + sequence: sequence, + cursor: from_position, + filter: filter + } + + state = + state + |> Map.update!(:stream_subscriptions, &Map.put(&1, key, subscription)) + |> Map.update!(:stream_subscribers, fn m -> + Map.update( + m, + {producer_execution_id, sequence}, + MapSet.new([key]), + &MapSet.put(&1, key) + ) + end) + + # Push any items already in the log that match the filter, then (if + # the stream has already closed) the terminal close record. + state = push_backlog(state, session_id, subscription_id) + state = maybe_push_closure_if_closed(state, session_id, subscription_id) + + {:reply, :ok, state} + else + {:ok, false} -> {:reply, {:error, :stream_not_found}, state} + true -> {:reply, {:error, :already_subscribed}, state} + {:error, reason} -> {:reply, {:error, reason}, state} + end + end + + def handle_call({:unsubscribe_stream, session_external_id, subscription_id}, _from, state) do + case Map.fetch(state.session_ids, session_external_id) do + {:ok, session_id} -> + {:reply, :ok, drop_subscription(state, {session_id, subscription_id})} + + :error -> + {:reply, :ok, state} + end + end + def handle_call( {:select, handles, from_execution_external_id, timeout_ms, suspend, cancel_remaining, request_id}, @@ -3523,9 +3609,9 @@ defmodule Coflux.Orchestration.Server do # Close any open streams so iterating consumers stop waiting. Any # subsequent `append_item` from the producer will fail with `:closed`, - # signalling the worker to stop. The closure carries no error; consumers - # resolve the cancel from the execution's own disposition. - close_open_streams(state, execution_id) + # signalling the worker to stop. Push :cancelled so consumers raise + # ExecutionCancelled on iteration. + state = close_open_streams(state, execution_id, :cancelled) state = case Runs.get_execution_key(state.db, execution_id) do @@ -3770,6 +3856,9 @@ defmodule Coflux.Orchestration.Server do {:ok, _} = Sessions.expire_session(state.db, session_id) {session, state} = pop_in(state.sessions[session_id]) state = Map.update!(state, :session_expiries, &Map.delete(&1, session_id)) + # Drop any stream subscriptions this session held — consumer has gone + # away, so there's no one to push to. + state = drop_session_subscriptions(state, session_id) # starting/executing now contain external IDs - resolve to internal for process_result. # Session removal means no more notify_terminated for these executions, so we @@ -5689,9 +5778,8 @@ defmodule Coflux.Orchestration.Server do {:ok, true} -> # Close any streams left open by the producer. Generator tasks # normally close their streams explicitly; this is the backstop - # for ones that didn't. Consumers resolve the close reason from - # the execution's own disposition. - close_open_streams(state, execution_id) + # for ones that didn't. + state = close_open_streams(state, execution_id, :complete) case Results.record_completion(state.db, execution_id) do {:ok, completion_at} -> @@ -5720,9 +5808,9 @@ defmodule Coflux.Orchestration.Server do decide_and_create_successor(state, execution_id, step, workspace_id, :crashed) # Streams that had been appended to before the worker died need to be - # closed so consumers don't wait forever. The closure carries no error — - # the execution's own :crashed disposition is the source of truth. - close_open_streams(state, execution_id) + # closed so consumers don't wait forever. Push :crashed so consumers + # raise ExecutionCrashed on iteration. + state = close_open_streams(state, execution_id, :crashed) case Results.record_completion(state.db, execution_id) do {:ok, completion_at} -> @@ -5739,17 +5827,31 @@ defmodule Coflux.Orchestration.Server do end # Closes every stream owned by `execution_id` that doesn't yet have a - # closure row. The closure carries no error — the consumer resolves the - # reason from the execution's result / completion state (clean, crashed, - # cancelled). If a generator closed its stream with an explicit error, + # closure row, and pushes a `stream_closed` notification to every active + # subscriber. If a generator closed its stream with an explicit error, # that closure already exists and is left untouched. - defp close_open_streams(state, execution_id) do + # + # ``reason`` annotates the closure for the consumer's benefit. It's stored + # in the DB closure row as nil (execution disposition is the source of + # truth), but pushed synchronously to subscribers so they can raise the + # right exception on iteration: + # * :complete – no error, StopIteration on next() + # * :cancelled – ExecutionCancelled on next() + # * :crashed – ExecutionCrashed on next() + defp close_open_streams(state, execution_id, reason) do {:ok, sequences} = Streams.get_open_streams_for_execution(state.db, execution_id) - Enum.each(sequences, fn sequence -> + push_error = + case reason do + :complete -> nil + :cancelled -> {"Coflux.ExecutionCancelled", "execution cancelled", []} + :crashed -> {"Coflux.ExecutionCrashed", "worker terminated", []} + end + + Enum.reduce(sequences, state, fn sequence, state -> case Streams.close_stream(state.db, execution_id, sequence) do - {:ok, _} -> :ok - {:error, :already_closed} -> :ok + {:ok, _} -> push_stream_closed(state, execution_id, sequence, push_error) + {:error, :already_closed} -> state end end) end @@ -7022,6 +7124,227 @@ defmodule Coflux.Orchestration.Server do end end + # --- Stream subscription helpers --- + + defp ok_or({:ok, val}, _reason), do: {:ok, val} + defp ok_or(:error, reason), do: {:error, reason} + + # Does a `position` pass a subscription's filter? + defp filter_matches?(nil, _position), do: true + + defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => nil}, position), + do: position >= s + + defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => e}, position), + do: position >= s and position < e + + defp filter_matches?(%{"type" => "partition", "n" => n, "i" => i}, position), + do: rem(position, n) == i + + defp filter_matches?(%{"type" => "chain", "filters" => fs}, position), + do: Enum.all?(fs, &filter_matches?(&1, position)) + + defp filter_matches?(_filter, _position), do: true + + # Is `position` past the end of the filter's effective range? + # Lets us close streams early once a slice's stop is reached. + defp filter_exhausted?(%{"type" => "slice", "stop" => stop}, cursor) when is_integer(stop), + do: cursor >= stop + + defp filter_exhausted?(%{"type" => "chain", "filters" => fs}, cursor), + do: Enum.any?(fs, &filter_exhausted?(&1, cursor)) + + defp filter_exhausted?(_filter, _cursor), do: false + + # Send backlog items (those already in the DB) for a newly subscribed consumer. + defp push_backlog(state, session_id, subscription_id) do + key = {session_id, subscription_id} + sub = Map.fetch!(state.stream_subscriptions, key) + + # Conservative fetch cap: stream more in batches if needed on advance. + # For v1 (no flow control) we just drain the whole tail. + {:ok, items} = + Streams.get_stream_items( + state.db, + sub.producer_execution_id, + sub.sequence, + sub.cursor, + 1_000_000 + ) + + filtered = + items + |> Enum.filter(fn {position, _value, _at} -> filter_matches?(sub.filter, position) end) + |> Enum.take_while(fn {position, _, _} -> not filter_exhausted?(sub.filter, position) end) + + if filtered == [] do + state + else + last_pos = elem(List.last(filtered), 0) + next_cursor = last_pos + 1 + + # Values are already in internal form (from DB) — resolve refs to + # external IDs. Final JSON encoding happens in the WS handler. + resolved_items = + Enum.map(filtered, fn {position, value, _at} -> + [position, build_value(value, state.db)] + end) + + state = + send_session( + state, + session_id, + {:stream_items, sub.consumer_execution_external_id, subscription_id, resolved_items} + ) + + update_in( + state.stream_subscriptions[key], + &Map.put(&1, :cursor, next_cursor) + ) + end + end + + # Push a freshly-appended item to every subscriber of this stream. + defp push_stream_item(state, producer_execution_id, sequence, position, value) do + subscribers = + Map.get(state.stream_subscribers, {producer_execution_id, sequence}, MapSet.new()) + + Enum.reduce(subscribers, state, fn key, state -> + {session_id, subscription_id} = key + sub = Map.fetch!(state.stream_subscriptions, key) + + cond do + position < sub.cursor -> + # Consumer already has this position via backlog; skip. + state + + not filter_matches?(sub.filter, position) -> + state + + true -> + # Value came off the wire in parse form (ext-id refs, no metadata). + # Normalise + resolve to match the form push_backlog sends; the WS + # handler composes to wire JSON. + resolved = build_value(normalize_value(value), state.db) + item = [position, resolved] + + state = + send_session( + state, + session_id, + {:stream_items, sub.consumer_execution_external_id, subscription_id, [item]} + ) + + state = + update_in( + state.stream_subscriptions[key], + &Map.put(&1, :cursor, position + 1) + ) + + # If the filter is exhausted (e.g. slice reached its stop), close + # the subscription early — no more items will match. + if filter_exhausted?(sub.filter, position + 1) do + state + |> send_session( + session_id, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + ) + |> drop_subscription(key) + else + state + end + end + end) + end + + # On close, tell every subscriber. Error is either nil (clean close) or a + # {type, message, frames} triple — same shape as Streams.close_stream takes. + defp push_stream_closed(state, producer_execution_id, sequence, error) do + subscribers = + Map.get(state.stream_subscribers, {producer_execution_id, sequence}, MapSet.new()) + + encoded_error = + case error do + nil -> nil + {type, message, _frames} -> %{"type" => type, "message" => message} + end + + Enum.reduce(subscribers, state, fn key, state -> + {session_id, subscription_id} = key + sub = Map.fetch!(state.stream_subscriptions, key) + + state + |> send_session( + session_id, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, encoded_error} + ) + |> drop_subscription(key) + end) + end + + # If a subscription attaches to an already-closed stream, emit closure now. + defp maybe_push_closure_if_closed(state, session_id, subscription_id) do + key = {session_id, subscription_id} + sub = Map.fetch!(state.stream_subscriptions, key) + + case Streams.get_stream_closure(state.db, sub.producer_execution_id, sub.sequence) do + {:ok, nil} -> + state + + {:ok, {error, _closed_at}} -> + encoded_error = + case error do + nil -> nil + {type, message, _frames} -> %{"type" => type, "message" => message} + end + + state + |> send_session( + session_id, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, encoded_error} + ) + |> drop_subscription(key) + end + end + + defp drop_subscription(state, key) do + case Map.fetch(state.stream_subscriptions, key) do + :error -> + state + + {:ok, sub} -> + stream_key = {sub.producer_execution_id, sub.sequence} + + state + |> Map.update!(:stream_subscriptions, &Map.delete(&1, key)) + |> Map.update!(:stream_subscribers, fn m -> + case Map.get(m, stream_key) do + nil -> + m + + subs -> + remaining = MapSet.delete(subs, key) + + if MapSet.size(remaining) == 0 do + Map.delete(m, stream_key) + else + Map.put(m, stream_key, remaining) + end + end + end) + end + end + + # Drop every subscription owned by a disconnected session. + defp drop_session_subscriptions(state, session_id) do + keys = + state.stream_subscriptions + |> Map.keys() + |> Enum.filter(fn {sid, _} -> sid == session_id end) + + Enum.reduce(keys, state, &drop_subscription(&2, &1)) + end + # Clean up an execution's state and send an abort message to the worker. defp abort_execution(state, execution_ext_id) do state = cleanup_execution(state, execution_ext_id) From 6afcba254faea300a20c3676b89e329ff62493ae Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 20:32:33 +0100 Subject: [PATCH 07/25] Update topic, and update epoch generation --- server/lib/coflux/orchestration/epoch.ex | 66 ++++++++++++++++ server/lib/coflux/orchestration/server.ex | 87 ++++++++++++++++++++-- server/lib/coflux/orchestration/streams.ex | 36 +++++++++ server/lib/coflux/topics/run.ex | 50 ++++++++++++- 4 files changed, 230 insertions(+), 9 deletions(-) diff --git a/server/lib/coflux/orchestration/epoch.ex b/server/lib/coflux/orchestration/epoch.ex index f0d6eeac..e7edc0f1 100644 --- a/server/lib/coflux/orchestration/epoch.ex +++ b/server/lib/coflux/orchestration/epoch.ex @@ -332,6 +332,72 @@ defmodule Coflux.Orchestration.Epoch do end end) + # Copy streams, their items, and any closure rows. An execution's + # streams may be mid-production (items appended, no closure) — + # carry them forward so consumers can keep reading after rotation. + Enum.each(execution_ids, fn {old_exec_id, new_exec_id} -> + {:ok, streams} = + query( + source_db, + "SELECT sequence, created_at FROM streams WHERE execution_id = ?1", + {old_exec_id} + ) + + Enum.each(streams, fn {sequence, stream_created_at} -> + {:ok, _} = + insert_one(target_db, :streams, %{ + execution_id: new_exec_id, + sequence: sequence, + created_at: stream_created_at + }) + + {:ok, items} = + query( + source_db, + """ + SELECT position, value_id, created_at + FROM stream_items + WHERE execution_id = ?1 AND sequence = ?2 + """, + {old_exec_id, sequence} + ) + + Enum.each(items, fn {position, value_id, item_created_at} -> + new_value_id = ensure_value(source_db, target_db, value_id) + + {:ok, _} = + insert_one(target_db, :stream_items, %{ + execution_id: new_exec_id, + sequence: sequence, + position: position, + value_id: new_value_id, + created_at: item_created_at + }) + end) + + case query_one( + source_db, + "SELECT error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", + {old_exec_id, sequence} + ) do + {:ok, {error_id, closure_created_at}} -> + new_error_id = + if error_id, do: ensure_error(source_db, target_db, error_id) + + {:ok, _} = + insert_one(target_db, :stream_closures, %{ + execution_id: new_exec_id, + sequence: sequence, + error_id: new_error_id, + created_at: closure_created_at + }) + + {:ok, nil} -> + :ok + end + end) + end) + # Copy children — same-run internal IDs {:ok, children} = query( diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 2ae912de..2124864a 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1827,8 +1827,15 @@ defmodule Coflux.Orchestration.Server do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> case Streams.register_stream(state.db, execution_id, sequence) do - {:ok, _} -> {:reply, :ok, state} - {:error, :already_registered} -> {:reply, {:error, :already_registered}, state} + {:ok, created_at} -> + state = + notify_stream_opened(state, execution_id, sequence, created_at) + |> flush_notifications() + + {:reply, :ok, state} + + {:error, :already_registered} -> + {:reply, {:error, :already_registered}, state} end :error -> @@ -1867,8 +1874,13 @@ defmodule Coflux.Orchestration.Server do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> case Streams.close_stream(state.db, execution_id, sequence, error) do - {:ok, _} -> - state = push_stream_closed(state, execution_id, sequence, error) + {:ok, closed_at} -> + state = + state + |> push_stream_closed(execution_id, sequence, error) + |> notify_stream_closed(execution_id, sequence, error, closed_at) + |> flush_notifications() + {:reply, :ok, state} {:error, reason} -> @@ -4959,6 +4971,8 @@ defmodule Coflux.Orchestration.Server do ) ) + {:ok, streams} = Streams.get_streams_with_closures_for_execution(db, execution_id) + {attempt, %{ execution_id: exec_external_id, @@ -4976,7 +4990,8 @@ defmodule Coflux.Orchestration.Server do result: result, result_created_by: result_created_by, children: Map.get(run_children, execution_id, []), - metric_definitions: Map.get(metric_definitions_by_execution, execution_id, %{}) + metric_definitions: Map.get(metric_definitions_by_execution, execution_id, %{}), + streams: streams }} end) }} @@ -5850,8 +5865,13 @@ defmodule Coflux.Orchestration.Server do Enum.reduce(sequences, state, fn sequence, state -> case Streams.close_stream(state.db, execution_id, sequence) do - {:ok, _} -> push_stream_closed(state, execution_id, sequence, push_error) - {:error, :already_closed} -> state + {:ok, closed_at} -> + state + |> push_stream_closed(execution_id, sequence, push_error) + |> notify_stream_closed(execution_id, sequence, push_error, closed_at) + + {:error, :already_closed} -> + state end end) end @@ -7124,6 +7144,59 @@ defmodule Coflux.Orchestration.Server do end end + # --- Producer flow control --- + # + # Not yet wired. Implicit backpressure today: a slow consumer's WS push + # blocks the GenServer, which blocks append_item, which blocks the + # producer. That's usually enough. When a real use case surfaces for + # explicit pause/resume (e.g. an infinite producer with no subscribers + # filling disk), the hooks go here: + # * On first subscriber for a stream: send_session(producer_session, + # {:stream_resume, producer_exec_ext, sequence}). + # * On last subscriber dropping: {:stream_pause, ...}. + # Dispatcher on the adapter side already routes notifications by method; + # the StreamDriver gets a per-stream Event to gate its next() calls. + + # --- Stream topic notifications (for Studio subscribers) --- + # These flow through `notify_listeners` → the run topic, distinct from the + # session-directed `push_stream_*` helpers which target subscribed consumer + # sessions' WebSockets. + + defp notify_stream_opened(state, execution_id, sequence, created_at) do + {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) + {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) + + notify_listeners( + state, + {:run, r}, + {:stream_opened, execution_ext_id, sequence, created_at} + ) + end + + defp notify_stream_closed(state, execution_id, sequence, error, closed_at) do + {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) + {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) + + encoded_error = + case error do + nil -> nil + {type, message, _frames} -> %{type: type, message: message} + end + + notify_listeners( + state, + {:run, r}, + {:stream_closed, execution_ext_id, sequence, encoded_error, closed_at} + ) + end + + defp execution_external_id_for(db, execution_id) do + case Runs.get_execution_key(db, execution_id) do + {:ok, {r, s, a}} -> {:ok, execution_external_id(r, s, a)} + err -> err + end + end + # --- Stream subscription helpers --- defp ok_or({:ok, val}, _reason), do: {:ok, val} diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 08988717..623a2d2c 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -210,6 +210,42 @@ defmodule Coflux.Orchestration.Streams do end end + # Returns one row per stream owned by `execution_id`: + # `{sequence, created_at, closed_at | nil, error | nil}` where error, when + # present, is a `{type, message, frames}` triple. Used when populating the + # topic state for a run — lets the UI render streams and their state in + # one query. + def get_streams_with_closures_for_execution(db, execution_id) do + case query( + db, + """ + SELECT s.sequence, s.created_at, c.created_at, c.error_id + FROM streams AS s + LEFT JOIN stream_closures AS c + ON c.execution_id = s.execution_id AND c.sequence = s.sequence + WHERE s.execution_id = ?1 + ORDER BY s.sequence + """, + {execution_id} + ) do + {:ok, rows} -> + streams = + Enum.map(rows, fn + {sequence, created_at, nil, nil} -> + {sequence, created_at, nil, nil} + + {sequence, created_at, closed_at, nil} -> + {sequence, created_at, closed_at, nil} + + {sequence, created_at, closed_at, error_id} -> + {:ok, error} = Errors.get_by_id(db, error_id) + {sequence, created_at, closed_at, error} + end) + + {:ok, streams} + end + end + # Returns the highest position recorded for the stream, or `-1` if empty. # Used by the worker protocol to report "head" for flow control without # requiring the caller to scan all items. diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index 613f359a..c8a3dea4 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -103,7 +103,8 @@ defmodule Coflux.Topics.Run do children: [], inputs: %{}, result: nil, - metrics: %{} + metrics: %{}, + streams: %{} } ) else @@ -206,6 +207,32 @@ defmodule Coflux.Topics.Run do end) end + defp process_notification( + topic, + {:stream_opened, execution_external_id, sequence, created_at} + ) do + update_execution(topic, execution_external_id, fn topic, base_path -> + Topic.set(topic, base_path ++ [:streams, Integer.to_string(sequence)], %{ + openedAt: created_at, + closedAt: nil, + error: nil + }) + end) + end + + defp process_notification( + topic, + {:stream_closed, execution_external_id, sequence, error, closed_at} + ) do + seq_key = Integer.to_string(sequence) + + update_execution(topic, execution_external_id, fn topic, base_path -> + topic + |> Topic.set(base_path ++ [:streams, seq_key, :closedAt], closed_at) + |> Topic.set(base_path ++ [:streams, seq_key, :error], error) + end) + end + defp process_notification( topic, {:result_result, execution_external_id, result, _created_at, created_by} @@ -380,7 +407,8 @@ defmodule Coflux.Topics.Run do lower: def_data.lower, upper: def_data.upper }} - end) + end), + streams: build_streams(execution.streams) }} end) }} @@ -519,6 +547,24 @@ defmodule Coflux.Topics.Run do end end + defp build_streams(streams) do + Map.new(streams, fn + {sequence, opened_at, nil, nil} -> + {Integer.to_string(sequence), %{openedAt: opened_at, closedAt: nil, error: nil}} + + {sequence, opened_at, closed_at, nil} -> + {Integer.to_string(sequence), %{openedAt: opened_at, closedAt: closed_at, error: nil}} + + {sequence, opened_at, closed_at, {type, message, _frames}} -> + {Integer.to_string(sequence), + %{ + openedAt: opened_at, + closedAt: closed_at, + error: %{type: type, message: message} + }} + end) + end + defp execution_attempt({ext_id, _module, _target}) do ext_id |> String.split(":") |> List.last() |> String.to_integer() end From e87ab2f0b547fbfe4573ad568c02d0e6a825fe1a Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 23:05:22 +0100 Subject: [PATCH 08/25] Add serialiser for stream --- adapters/python/coflux/serialization.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index d2746370..d1ab888a 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -80,6 +80,19 @@ def _encode(v: Any) -> Any: "execution_id": execution_id, "sequence": sequence, } + elif isinstance(v, Stream): + # Pass-through: a Stream handle received from another execution + # (possibly with partition/slice filters layered on top) is + # being forwarded as an argument. Preserve the filter chain so + # the downstream consumer subscribes with the same filters. + encoded: dict[str, Any] = { + "type": "stream", + "execution_id": v.producer_execution_id, + "sequence": v.sequence, + } + if v._filters: + encoded["filters"] = list(v._filters) + return encoded elif inspect.isasyncgen(v): raise TypeError( "Async generators aren't supported yet — use a sync generator " @@ -261,8 +274,11 @@ def _decode(v: Any) -> Any: return _resolve_ref(v["index"]) elif t == "stream": # Producer-owned stream reference. Self-contained — - # execution_id and sequence are both in the descriptor. - return Stream(v["execution_id"], v["sequence"]) + # execution_id, sequence, and any filter chain (when the + # Stream was forwarded with partition/slice filters + # already applied) are all in the descriptor. + filters = tuple(v.get("filters") or ()) + return Stream(v["execution_id"], v["sequence"], filters) else: return v else: From b340d4901ca5783b4cb24e9fa00014aa906430a3 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 23:05:57 +0100 Subject: [PATCH 09/25] Fix typing --- adapters/python/coflux/decorators.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/adapters/python/coflux/decorators.py b/adapters/python/coflux/decorators.py index 5924353d..272c5f39 100644 --- a/adapters/python/coflux/decorators.py +++ b/adapters/python/coflux/decorators.py @@ -7,19 +7,30 @@ from .target import Cache, Defer, Retries, Target +if t.TYPE_CHECKING: + from .models import Stream + P = t.ParamSpec("P") T = t.TypeVar("T") class _TargetDecorator(t.Protocol): - """Decorator protocol that unwraps ``Coroutine`` for async functions. + """Decorator protocol that unwraps ``Coroutine`` and collapses generator + return types to ``Stream``. + + Overloading on ``__call__`` (rather than on the factory function) lets + the type checker pick the right overload based on the decorated + function's return type — which is visible at application time, but not + at the factory call. - Overloading on ``__call__`` (rather than on the factory function) - lets the type checker pick the right overload based on the decorated - function's return type — which is visible at application time, but - not at the factory call. + Overload resolution order matters: generator functions match first (so + ``-> Iterator[T]`` gives a ``Target[P, Stream[T]]``), then async + coroutines (unwrapped), then the general case. """ + @t.overload + def __call__(self, fn: t.Callable[P, t.Iterator[T]]) -> Target[P, "Stream[T]"]: ... + @t.overload def __call__( self, fn: t.Callable[P, t.Coroutine[t.Any, t.Any, T]] From f2b3f4d517b884e168b06ab692d1eb91d06ea30a Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sat, 18 Apr 2026 23:07:10 +0100 Subject: [PATCH 10/25] Add tests --- cli/internal/pool/pool.go | 16 +- server/lib/coflux/orchestration/server.ex | 49 +- tests/support/adapter.py | 37 +- tests/support/executor.py | 138 +++++- tests/support/protocol.py | 88 ++++ tests/test_streams.py | 533 ++++++++++++++++++++++ 6 files changed, 834 insertions(+), 27 deletions(-) create mode 100644 tests/test_streams.py diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index d0fae65c..e722ed4b 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -214,6 +214,10 @@ func (p *Pool) runExecution(ctx context.Context, exec *adapter.Executor, executi defer cancelTimeout() } timedOut := false + // Once the adapter has reported a result, a subsequent EOF is the + // expected natural exit (after any stream drain completes). Don't + // treat that as a failure to receive. + resultReported := false // Handle messages until execution completes loop: @@ -240,6 +244,9 @@ loop: if aborted { logger.Info("execution aborted") logger.Debug("aborted executor exit", "error", err) + } else if resultReported { + // Clean exit after result + stream drain. + logger.Debug("executor exited after result", "error", err) } else { logger.Error("failed to receive message", "error", err) p.handler.ReportError(ctx, executionID, "internal", err.Error(), "", nil) @@ -255,10 +262,17 @@ loop: switch method { case "execution_result": + // Don't break the loop here — a streaming task keeps sending + // stream_append/stream_close messages after the result is + // committed. Stop reading only when the adapter process exits + // (Receive returns an error), which happens after wait_all in + // the adapter drains every generator. p.handleExecutionResult(execCtx, executionID, params, logger) - break loop + resultReported = true case "execution_error": + // Error is terminal — a task that raised before yielding any + // generators has nothing to stream. Stop reading. p.handleExecutionError(execCtx, executionID, params, logger) break loop diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 2124864a..604deb14 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1898,15 +1898,21 @@ defmodule Coflux.Orchestration.Server do _from, state ) do + # Producer may already have terminated — resolve from DB (active epoch) + # rather than the in-memory active-execution cache. + producer_result = + case resolve_internal_execution_id(state, producer_execution_external_id) do + {:ok, id} -> {:ok, id} + {:error, :not_found} -> {:error, :producer_not_found} + end + with {:ok, session_id} <- Map.fetch(state.session_ids, session_external_id) |> ok_or(:session_not_found), {:ok, consumer_execution_id} <- Map.fetch(state.execution_ids, consumer_execution_external_id) |> ok_or(:consumer_not_found), - {:ok, producer_execution_id} <- - Map.fetch(state.execution_ids, producer_execution_external_id) - |> ok_or(:producer_not_found), + {:ok, producer_execution_id} <- producer_result, {:ok, true} <- Streams.exists?(state.db, producer_execution_id, sequence), key = {session_id, subscription_id}, false <- Map.has_key?(state.stream_subscriptions, key) do @@ -7270,10 +7276,26 @@ defmodule Coflux.Orchestration.Server do {:stream_items, sub.consumer_execution_external_id, subscription_id, resolved_items} ) - update_in( - state.stream_subscriptions[key], - &Map.put(&1, :cursor, next_cursor) - ) + state = + update_in( + state.stream_subscriptions[key], + &Map.put(&1, :cursor, next_cursor) + ) + + # If the filter is now exhausted (slice's stop reached), close the + # subscription synchronously — matches push_stream_item's behaviour. + # Without this, a consumer that subscribed after appends with a + # bounded filter would wait forever for a close that never comes. + if filter_exhausted?(sub.filter, next_cursor) do + state + |> send_session( + session_id, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + ) + |> drop_subscription(key) + else + state + end end end @@ -7356,10 +7378,21 @@ defmodule Coflux.Orchestration.Server do end # If a subscription attaches to an already-closed stream, emit closure now. + # If push_backlog already closed the subscription (e.g., a bounded filter + # was exhausted by the backlog itself), this is a no-op. defp maybe_push_closure_if_closed(state, session_id, subscription_id) do key = {session_id, subscription_id} - sub = Map.fetch!(state.stream_subscriptions, key) + case Map.fetch(state.stream_subscriptions, key) do + :error -> + state + + {:ok, sub} -> + do_maybe_push_closure(state, sub, session_id, subscription_id, key) + end + end + + defp do_maybe_push_closure(state, sub, session_id, subscription_id, key) do case Streams.get_stream_closure(state.db, sub.producer_execution_id, sub.sequence) do {:ok, nil} -> state diff --git a/tests/support/adapter.py b/tests/support/adapter.py index 0e54557b..77cffa44 100644 --- a/tests/support/adapter.py +++ b/tests/support/adapter.py @@ -23,15 +23,25 @@ def execute(args): stdin_fd = sys.stdin.buffer.fileno() stdout_fd = sys.stdout.buffer.fileno() + # Exit when EITHER direction ends: + # * stdin closes — the CLI's old pattern (break loop, Wait, close stdin). + # * the test's socket closes — the new pattern, simulating the real + # adapter exiting after its task finishes. The CLI's post-fix receive + # loop stays open until EOF on adapter stdout, so mock-side socket + # close is how tests signal "this execution's work is done". + done = threading.Event() + def stdin_to_socket(): - while data := os.read(stdin_fd, 4096): - conn.sendall(data) - # Stdin closed (Go pool called Wait/Close). Shut down writes to - # tell the test socket we're done sending. + try: + while data := os.read(stdin_fd, 4096): + conn.sendall(data) + except OSError: + pass try: conn.shutdown(socket.SHUT_WR) except OSError: pass + done.set() def socket_to_stdout(): try: @@ -39,17 +49,18 @@ def socket_to_stdout(): os.write(stdout_fd, data) except OSError: pass + # Test closed the socket. Close stdout so CLI's readLoop gets EOF, + # breaks its receive loop, and runs Wait (which will close stdin). + try: + os.close(stdout_fd) + except OSError: + pass + done.set() - t1 = threading.Thread(target=stdin_to_socket) - t2 = threading.Thread(target=socket_to_stdout, daemon=True) - t1.start() - t2.start() + threading.Thread(target=stdin_to_socket, daemon=True).start() + threading.Thread(target=socket_to_stdout, daemon=True).start() - # Wait for stdin to close (Go pool called Wait/Close), then exit. - # t2 is a daemon thread — it will be cleaned up when the process exits. - # By this point, any data from the test socket has already been relayed - # to stdout (the Go pool reads the result before closing stdin). - t1.join() + done.wait() if __name__ == "__main__": diff --git a/tests/support/executor.py b/tests/support/executor.py index ae756be8..e282d8ed 100644 --- a/tests/support/executor.py +++ b/tests/support/executor.py @@ -49,11 +49,20 @@ def __init__(self, conn): self._conn = conn self._file = conn.makefile("rb") self._next_request_id = 1 + # When a test mixes RPCs with async pushes (stream_items, stream_closed), + # we may read a push while waiting for a response and vice-versa. Park + # mismatched messages here so the next `recv` / helper picks them up. + self._buffer = [] def send(self, msg: dict): self._conn.sendall(protocol.encode_message(msg)) def recv(self, timeout=10) -> dict: + if self._buffer: + return self._buffer.pop(0) + return self._recv_raw(timeout) + + def _recv_raw(self, timeout) -> dict: self._conn.settimeout(timeout) try: line = self._file.readline() @@ -90,14 +99,22 @@ def recv_execute(self, **kwargs): return p["execution_id"], p.get("module", ""), p["target"], p.get("arguments", []) def _request(self, msg): - """Send a request message (with auto-assigned ID) and return the response.""" + """Send a request message (with auto-assigned ID) and return the response. + + If async pushes (stream_items / stream_closed) arrive ahead of the + response, they're parked in the buffer so tests can fetch them later + via recv_push. + """ rid = self._next_request_id self._next_request_id += 1 msg["id"] = rid self.send(msg) - resp = self.recv() - assert resp["id"] == rid - return resp + while True: + incoming = self.recv() + if incoming.get("id") == rid: + return incoming + # Park non-matching messages (typically notifications). + self._buffer.append(incoming) def submit_task(self, execution_id, module, target, arguments, **kwargs): """Submit a child task execution and return the target execution ID.""" @@ -239,15 +256,126 @@ def resolve_input( raise RuntimeError(f"select error: {resp['error']}") return _unwrap_select_result(resp.get("result")) + # --- Stream producer helpers --- + + def stream_register(self, execution_id, sequence): + """Notify that a new stream exists.""" + self.send(protocol.stream_register(execution_id, sequence)) + + def stream_append(self, execution_id, sequence, position, value, format="json"): + """Append an item (raw JSON value) to a stream.""" + self.send(protocol.stream_append(execution_id, sequence, position, value, format=format)) + + def stream_close(self, execution_id, sequence, error=None): + """Close a stream (optionally with an error {type, message, traceback}).""" + self.send(protocol.stream_close(execution_id, sequence, error=error)) + + # --- Stream consumer helpers --- + + def stream_subscribe( + self, + execution_id, + subscription_id, + producer_execution_id, + sequence, + from_position=0, + filter=None, + ): + """Subscribe to a stream. ``filter`` is an optional dict built via + protocol.slice_filter / partition_filter / chain_filter.""" + self.send( + protocol.stream_subscribe( + execution_id, + subscription_id, + producer_execution_id, + sequence, + from_position=from_position, + filter=filter, + ) + ) + + def stream_unsubscribe(self, execution_id, subscription_id): + self.send(protocol.stream_unsubscribe(execution_id, subscription_id)) + + def recv_push(self, method, subscription_id=None, timeout=10): + """Read messages until one matching ``method`` (and subscription) arrives. + + Returns the params dict. Non-matching messages are re-buffered in + order so later calls can consume them. + """ + held = [] + deadline = time.time() + timeout + try: + while True: + remaining = max(0.01, deadline - time.time()) + msg = self.recv(timeout=remaining) + if msg.get("method") == method: + params = msg.get("params", {}) + if ( + subscription_id is None + or params.get("subscription_id") == subscription_id + ): + # Put held messages back (preserve order) before returning. + self._buffer[:0] = held + return params + held.append(msg) + except TimeoutError: + # Restore buffer and propagate. + self._buffer[:0] = held + raise + + def drain_stream(self, subscription_id, timeout=10): + """Collect every pushed item + final closure for ``subscription_id``. + + Returns ``(items, closed_params)`` where ``items`` is a list of + ``[position, value_dict]`` pairs in arrival order. Messages for other + subscriptions are re-buffered so later calls can fetch them. + """ + items = [] + deadline = time.time() + timeout + while True: + remaining = max(0.01, deadline - time.time()) + msg = self.recv(timeout=remaining) + method = msg.get("method") + params = msg.get("params", {}) + if params.get("subscription_id") != subscription_id or method not in ( + "stream_items", + "stream_closed", + ): + self._buffer.append(msg) + continue + if method == "stream_items": + items.extend(params.get("items", [])) + continue + # stream_closed — terminal + return items, params + def complete(self, execution_id, value=None): - """Send execution_result.""" + """Send execution_result and signal the mock adapter we're done. + + Closing the socket mirrors what a real adapter does — it exits after + the task finishes (and streams drain). Without this, the CLI's + post-result receive loop never hits EOF, so the mock adapter hangs + and ties up the worker's concurrency slot. + """ self.send(protocol.execution_result(execution_id, value=value)) + self._close_sending_side() def fail(self, execution_id, error_type, message, traceback="", retryable=None): """Send execution_error.""" self.send( protocol.execution_error(execution_id, error_type, message, traceback, retryable=retryable) ) + self._close_sending_side() + + def _close_sending_side(self): + """Shut down the socket's write side; reads stay open for any pushes + in flight (e.g. final stream_closed) so asserts can still collect. + """ + try: + self._conn.shutdown(socket.SHUT_WR) + except OSError: + pass def close(self): try: diff --git a/tests/support/protocol.py b/tests/support/protocol.py index 56ae2b6b..d0512e7a 100644 --- a/tests/support/protocol.py +++ b/tests/support/protocol.py @@ -231,6 +231,94 @@ def register_group_notification(execution_id, group_id, name=None): return {"method": "register_group", "params": params} +# --- Stream messages (producer side: adapter → server) --- + + +def stream_register(execution_id, sequence): + return { + "method": "stream_register", + "params": {"execution_id": execution_id, "sequence": sequence}, + } + + +def stream_append(execution_id, sequence, position, value, format="json"): + """Append an item to a stream. ``value`` is the raw JSON value. + + Builds a Value wire-form message with an empty references list. Tests + that need references should build the Value dict manually. + """ + return { + "method": "stream_append", + "params": { + "execution_id": execution_id, + "sequence": sequence, + "position": position, + "value": { + "type": "inline", + "format": format, + "value": value, + "references": [], + }, + }, + } + + +def stream_close(execution_id, sequence, error=None): + """Close a stream. ``error`` is optional {type, message, traceback}.""" + params = {"execution_id": execution_id, "sequence": sequence} + if error is not None: + params["error"] = error + return {"method": "stream_close", "params": params} + + +# --- Stream messages (consumer side: adapter → server) --- + + +def stream_subscribe( + execution_id, + subscription_id, + producer_execution_id, + sequence, + from_position=0, + filter=None, +): + params = { + "execution_id": execution_id, + "subscription_id": subscription_id, + "producer_execution_id": producer_execution_id, + "sequence": sequence, + "from_position": from_position, + } + if filter is not None: + params["filter"] = filter + return {"method": "stream_subscribe", "params": params} + + +def stream_unsubscribe(execution_id, subscription_id): + return { + "method": "stream_unsubscribe", + "params": { + "execution_id": execution_id, + "subscription_id": subscription_id, + }, + } + + +# --- Filter builders --- + + +def slice_filter(start, stop=None): + return {"type": "slice", "start": start, "stop": stop} + + +def partition_filter(n, i): + return {"type": "partition", "n": n, "i": i} + + +def chain_filter(*filters): + return {"type": "chain", "filters": list(filters)} + + def submit_input_request( request_id, execution_id, diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 00000000..ebd62f5f --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,533 @@ +"""Integration tests for the streaming protocol. + +These tests drive the mock adapter directly — they send/receive the wire +messages the real Python adapter would use, so they exercise the full +server + CLI relay + subscription/push pipeline. + +Two common patterns: + + * Producer-only: a single execution registers a stream, appends items, + and closes. Verification is done by a subsequent consumer subscription + (also driven by the same test). + + * Producer + consumer interleaved: both are driven from the same test, + taking turns over different connections. +""" + +from support.manifest import workflow +from support.protocol import ( + execution_result, + json_args, + partition_filter, + slice_filter, + chain_filter, +) + + +def _run_and_handle_stream(ctx, targets, produce_fn): + """Submit a no-arg workflow and hand the executor connection to `produce_fn`. + + ``produce_fn(conn, execution_id)`` does whatever stream work the test + needs (register / append / close) and then sends an execution_result. + Returns the run_id so tests can assert on topic state if desired. + """ + resp = ctx.submit("test", targets[0]["name"]) + ex = ctx.executor.next_execute() + produce_fn(ex.conn, ex.execution_id) + return resp["runId"], ex.execution_id + + +def test_producer_writes_and_consumer_reads_backlog(worker): + """Producer registers, appends 3 items, closes. Then a consumer in a + separate execution subscribes and drains the backlog plus close. + """ + targets = [ + workflow("test", "producer"), + workflow("test", "consumer"), + ] + + with worker(targets) as ctx: + # Producer + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "a") + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 1, "b") + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 2, "c") + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id, value=42) + + # Producer's run must finish before the consumer can subscribe — + # otherwise the execution_id isn't known to the consumer's workflow. + ctx.result(prod_resp["runId"]) + + # Consumer in a separate workflow subscribes to the producer's stream. + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[0] for item in items] == [0, 1, 2] + assert [item[1]["value"] for item in items] == ["a", "b", "c"] + assert closed.get("error") is None + + +def test_consumer_sees_live_push(worker): + """Consumer subscribes *before* the producer appends. Items arrive live.""" + targets = [ + workflow("test", "producer"), + workflow("test", "consumer"), + ] + + with worker(targets, concurrency=2) as ctx: + # Producer registers but doesn't append yet. + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + + # Consumer subscribes now — stream is open with no items yet. + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + # Now producer appends + closes. + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, 10) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 1, 20) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + + prod_ex.conn.complete(prod_ex.execution_id) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[1]["value"] for item in items] == [10, 20] + assert closed.get("error") is None + + +def test_slice_filter_restricts_items(worker): + """Slice filter ``[1, 3)`` delivers only positions 1 and 2.""" + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + for i in range(5): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, i * 10) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + filter=slice_filter(1, 3), + ) + items, _ = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[0] for item in items] == [1, 2] + assert [item[1]["value"] for item in items] == [10, 20] + + +def test_partition_filter_round_robin(worker): + """Partition filter ``(n=3, i=1)`` delivers positions 1, 4, 7.""" + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + for i in range(9): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, i) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + filter=partition_filter(n=3, i=1), + ) + items, _ = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[0] for item in items] == [1, 4, 7] + + +def test_producer_error_closes_with_error_info(worker): + """Generator raises mid-stream: subscriber sees items-so-far then an + errored closure carrying {type, message}. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "ok") + prod_ex.conn.stream_close( + prod_ex.execution_id, + 0, + error={"type": "ValueError", "message": "boom", "traceback": ""}, + ) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[1]["value"] for item in items] == ["ok"] + err = closed.get("error") + assert err is not None + assert err["type"] == "ValueError" + assert err["message"] == "boom" + + +def test_subscribe_to_unknown_producer_closes_immediately(worker): + """Subscribing to a stream that doesn't exist yields an immediate error + closure (not an indefinite wait). + """ + targets = [workflow("test", "consumer")] + + with worker(targets) as ctx: + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id="00000000:0:0", + sequence=0, + ) + _items, closed = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + err = closed.get("error") + assert err is not None + assert err["type"] == "Coflux.StreamNotFound" + + +def test_topic_exposes_stream_state(worker): + """Studio topic gets `streams` per execution: opened, closed, error.""" + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_register(prod_ex.execution_id, 1) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.stream_close( + prod_ex.execution_id, + 1, + error={"type": "RuntimeError", "message": "bad", "traceback": ""}, + ) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + snapshot = ctx.inspect(prod_resp["runId"]) + # The run snapshot has a `steps → {run:step → {executions → {attempt → {...}}}}` shape. + step = next(iter(snapshot["steps"].values())) + execution = next(iter(step["executions"].values())) + streams = execution["streams"] + + assert "0" in streams and "1" in streams + assert streams["0"]["openedAt"] is not None + assert streams["0"]["closedAt"] is not None + assert streams["0"]["error"] is None + assert streams["1"]["closedAt"] is not None + assert streams["1"]["error"] == {"type": "RuntimeError", "message": "bad"} + + +def test_cancellation_closes_streams_with_cancelled_error(worker): + """Cancel an execution mid-stream: the subscriber receives a closure + carrying the ExecutionCancelled error synthesised by close_open_streams. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "before") + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + ctx.cancel(prod_ex.execution_id) + + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[1]["value"] for item in items] == ["before"] + err = closed.get("error") + assert err is not None + assert err["type"] == "Coflux.ExecutionCancelled" + + +def test_multiple_subscribers_get_independent_delivery(worker): + """Two consumers subscribe to the same stream — each gets the full + sequence independently. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=3) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + for i in range(3): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, i) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + # Subscriber A + a_resp = ctx.submit("test", "consumer") + a_ex = ctx.executor.next_execute() + a_ex.conn.stream_subscribe( + a_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + # Subscriber B + b_resp = ctx.submit("test", "consumer") + b_ex = ctx.executor.next_execute() + b_ex.conn.stream_subscribe( + b_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + a_items, _ = a_ex.conn.drain_stream(subscription_id=1) + b_items, _ = b_ex.conn.drain_stream(subscription_id=1) + a_ex.conn.complete(a_ex.execution_id) + b_ex.conn.complete(b_ex.execution_id) + + assert [item[1]["value"] for item in a_items] == [0, 1, 2] + assert [item[1]["value"] for item in b_items] == [0, 1, 2] + + +def test_slice_with_stop_closes_early(worker): + """slice(0, 2) on a stream that has more items should close the + subscriber as soon as position 2 is reached, not wait for the full + stream to drain. The early-close path is the `filter_exhausted?` branch + in push_stream_item. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + + # Subscriber gets first 2 items then close. + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + filter=slice_filter(0, 2), + ) + + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "a") + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 1, "b") + + items, closed = cons_ex.conn.drain_stream(subscription_id=1, timeout=5) + + # items 2+ should NOT reach the subscriber — its slice is satisfied. + # Finish the producer so its run wraps up cleanly. + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 2, "c") + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[1]["value"] for item in items] == ["a", "b"] + assert closed.get("error") is None + + +def test_unsubscribe_prevents_receiving_full_stream(worker): + """Consumer unsubscribes partway through and doesn't receive every item. + + Ordering note: with the producer and consumer on separate sessions, + a few items appended immediately after unsubscribe can still reach the + consumer if they're in flight when the server processes unsubscribe. + The meaningful check is that the consumer stops seeing items before + the full stream is delivered — not that unsubscribe is synchronous. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, 0) + first = cons_ex.conn.recv_push("stream_items", subscription_id=1, timeout=3) + assert first["items"][0][1]["value"] == 0 + cons_ex.conn.stream_unsubscribe(cons_ex.execution_id, subscription_id=1) + + # Producer keeps appending after unsubscribe. + for i in range(1, 10): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, i) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + + # Collect anything still in flight. It might include the next + # one or two items (racing with unsubscribe) but MUST NOT include + # the tail — some positions drop out between unsubscribe and close. + received_positions = [0] + try: + while True: + msg = cons_ex.conn.recv_push( + "stream_items", subscription_id=1, timeout=0.5 + ) + for item in msg["items"]: + received_positions.append(item[0]) + except TimeoutError: + pass + + cons_ex.conn.complete(cons_ex.execution_id) + + # The consumer should have received strictly fewer than all 10 items. + assert len(received_positions) < 10, ( + f"unsubscribe should stop further delivery; got {received_positions}" + ) + + +def test_close_while_subscribed_delivers_closure(worker): + """Close the stream after a subscriber is already connected — the + closure gets pushed to the live subscriber (not just stored). + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "only") + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + prod_ex.conn.complete(prod_ex.execution_id) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[1]["value"] for item in items] == ["only"] + assert closed.get("error") is None + + +def test_lifecycle_close_on_completion_delivers_to_subscriber(worker): + """Producer registers a stream but never explicitly closes it. + When the execution completes, close_open_streams backstops — subscriber + gets a clean close. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, 1) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + # Producer completes *without* closing the stream. + prod_ex.conn.complete(prod_ex.execution_id) + + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[1]["value"] for item in items] == [1] + assert closed.get("error") is None # clean close — execution completed normally + + +def test_filter_chain_combines_slice_and_partition(worker): + """``chain(slice(0, 6), partition(2, 0))`` → positions 0, 2, 4.""" + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + for i in range(10): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, i) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + filter=chain_filter(slice_filter(0, 6), partition_filter(n=2, i=0)), + ) + items, _ = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[0] for item in items] == [0, 2, 4] From 5d8d86cf19d224670928b5e24de3204864c5801c Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sun, 19 Apr 2026 10:56:50 +0100 Subject: [PATCH 11/25] Various fixes --- adapters/python/coflux/context.py | 7 + adapters/python/coflux/dispatcher.py | 32 ++ adapters/python/coflux/errors.py | 45 +++ adapters/python/coflux/executor.py | 12 + adapters/python/coflux/serialization.py | 7 +- adapters/python/coflux/streams.py | 170 ++++++++-- cli/internal/worker/worker.go | 27 +- server/lib/coflux/handlers/worker.ex | 3 +- server/lib/coflux/orchestration.ex | 7 +- server/lib/coflux/orchestration/server.ex | 370 ++++++++++++++------- server/lib/coflux/orchestration/streams.ex | 83 +++-- server/priv/migrations/orchestration/4.sql | 11 +- tests/test_streams.py | 101 +++++- 13 files changed, 688 insertions(+), 187 deletions(-) diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index 2fdf028e..e882f2a2 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -111,6 +111,13 @@ def wait_streams(self) -> None: """Block until every stream produced by this execution has drained.""" self._stream_driver.wait_all() + def close_streams(self) -> None: + """Close every registered generator so driver threads exit promptly. + + Used on the error path before reporting execution_error. + """ + self._stream_driver.close_all() + def submit_execution( self, module: str, diff --git a/adapters/python/coflux/dispatcher.py b/adapters/python/coflux/dispatcher.py index 00ddc594..6ceb32f7 100644 --- a/adapters/python/coflux/dispatcher.py +++ b/adapters/python/coflux/dispatcher.py @@ -51,6 +51,10 @@ def __init__(self, protocol: Protocol) -> None: # other threads). self._notification_handlers: dict[str, Callable[[dict[str, Any]], None]] = {} + # Callbacks invoked on EOF so subsystems with their own blocking + # queues (e.g. the stream registry) can wake their waiters. + self._close_callbacks: list[Callable[[], None]] = [] + # Set when stdin reaches EOF. Wakes all pending waiters. self._closed = threading.Event() @@ -84,6 +88,27 @@ def unregister_notification(self, method: str) -> None: with self._lock: self._notification_handlers.pop(method, None) + def add_close_callback(self, callback: Callable[[], None]) -> None: + """Register a callback to fire when stdin reaches EOF. + + Fires on the reader thread, after ``_closed`` is set and response + waiters are woken. Use this to unblock subsystems that queue on + something other than a response event (e.g. per-iterator queues in + the stream registry). + """ + with self._lock: + already_closed = self._closed.is_set() + if not already_closed: + self._close_callbacks.append(callback) + if already_closed: + try: + callback() + except Exception: + pass + + def is_closed(self) -> bool: + return self._closed.is_set() + def wait_for_response( self, request_id: int, @@ -128,6 +153,13 @@ def _run(self) -> None: with self._lock: for event, _slot in self._waiting.values(): event.set() + callbacks = list(self._close_callbacks) + self._close_callbacks.clear() + for cb in callbacks: + try: + cb() + except Exception: + pass return if "id" in msg: diff --git a/adapters/python/coflux/errors.py b/adapters/python/coflux/errors.py index 46aead92..e03f73f7 100644 --- a/adapters/python/coflux/errors.py +++ b/adapters/python/coflux/errors.py @@ -161,3 +161,48 @@ def create_execution_error(error_type: str, error_message: str) -> ExecutionErro error_type=error_type, error_message=error_message, ) + + +# --- Stream error helpers --- + +# Server synthesises these types when closing streams due to the producer's +# disposition (cancel, crash, etc.). The wire carries them as regular +# {type, message, frames} errors; we route them here to specific +# ExecutionTerminated subclasses rather than generic ExecutionError. +_STREAM_SYNTHETIC_ERRORS: dict[str, type[Exception]] = { + "Coflux.ExecutionCancelled": ExecutionCancelled, + "Coflux.ExecutionAbandoned": ExecutionAbandoned, + "Coflux.ExecutionCrashed": ExecutionCrashed, + "Coflux.ExecutionErrored": ExecutionError, +} + + +def create_stream_error(error: dict) -> Exception: + """Build an exception for a stream closure. + + Server-synthesised types (``Coflux.Execution*``) map to + ``ExecutionTerminated`` subclasses. Real user exceptions (raised by the + producer's generator) go through ``create_execution_error`` and get + the producer's frames attached as ``.frames`` for debuggability. + """ + error_type = error.get("type", "") + error_message = error.get("message", "") + frames = error.get("frames") or [] + + synthetic = _STREAM_SYNTHETIC_ERRORS.get(error_type) + if synthetic is ExecutionError: + exc = ExecutionError( + error_message, + error_type=error_type, + error_message=error_message, + ) + elif synthetic is not None: + exc = synthetic() + else: + exc = create_execution_error(error_type, error_message) + + if frames: + # Frames are [file, line, name, code] lists — the same wire shape + # used by execution errors. Expose on the exception for inspection. + exc.frames = frames + return exc diff --git a/adapters/python/coflux/executor.py b/adapters/python/coflux/executor.py index ad58e545..5bc51df9 100644 --- a/adapters/python/coflux/executor.py +++ b/adapters/python/coflux/executor.py @@ -76,6 +76,7 @@ def execute_target( # through it — individual threads block on the dispatcher rather than # racing on stdin directly. start_dispatcher(protocol.get_protocol()) + ctx: ExecutorContext | None = None try: if working_dir: os.chdir(working_dir) @@ -136,6 +137,17 @@ def execute_target( except Exception as callback_exc: e = callback_exc + # Stop any in-flight stream producers and wait for their driver + # threads to exit before reporting the execution error. The server's + # close_open_streams will then synthesise a Coflux.ExecutionErrored + # close for any streams still open when the error is recorded. + if ctx is not None: + try: + ctx.close_streams() + ctx.wait_streams() + except Exception: + pass + error_type = f"{type(e).__module__}.{type(e).__qualname__}" tb = _format_filtered_traceback(e) protocol.send_execution_error( diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index d1ab888a..d7cd6405 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -69,7 +69,7 @@ def _encode_value( def _encode(v: Any) -> Any: if v is None or isinstance(v, (str, bool, int, float)): return v - elif inspect.isgenerator(v): + elif inspect.isgenerator(v) or inspect.isasyncgen(v): if on_generator is None: raise TypeError( "Cannot serialize a generator: no stream driver is active." @@ -93,11 +93,6 @@ def _encode(v: Any) -> Any: if v._filters: encoded["filters"] = list(v._filters) return encoded - elif inspect.isasyncgen(v): - raise TypeError( - "Async generators aren't supported yet — use a sync generator " - "(def + yield) for now." - ) elif isinstance(v, list): return [_encode(x) for x in v] elif isinstance(v, dict): diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index cbb96c64..5f105550 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -1,14 +1,17 @@ """Producer and consumer stream plumbing. The producer side owns ``StreamDriver``: each execution whose return value -(or submitted arguments) contains generators uses one to drive each -generator in a background thread. +(or submitted arguments) contains generators uses one to run each +generator in a background thread. Both sync (``def`` + ``yield``) and +async (``async def`` + ``yield``) generators are supported; async +generators get a fresh event loop confined to their worker thread. The consumer side owns a module-level ``StreamRegistry``: open consumer subscriptions are keyed by subscription id. The registry's dispatcher handlers (``stream_items``/``stream_closed``) route incoming pushes from the server to the right iterator's queue, which yields as the user -iterates. +iterates. On dispatcher EOF every active iterator is woken with a +synthetic abandoned-close so user code doesn't hang forever. Both sides are thread-safe: the ``Dispatcher`` owns stdin (so subtask calls from generator bodies don't race), and stdout writes go through @@ -17,6 +20,8 @@ from __future__ import annotations +import asyncio +import inspect import queue import threading import traceback @@ -24,7 +29,7 @@ from . import protocol from .dispatcher import get_dispatcher -from .errors import create_execution_error +from .errors import create_stream_error from .serialization import deserialize_value, serialize_value from .state import get_context @@ -39,10 +44,16 @@ def __init__(self, execution_id: str) -> None: self._execution_id = execution_id self._next_sequence = 0 self._threads: list[threading.Thread] = [] + self._generators: list[Any] = [] self._lock = threading.Lock() def register(self, generator: Any) -> tuple[str, int]: - """Register a generator, spawn its driver thread. + """Register a generator and start running it in a worker thread. + + Accepts both sync generators (``def`` + ``yield``) and async + generators (``async def`` + ``yield``). Each gets its own thread; + async generators run inside a fresh event loop confined to that + thread. Returns ``(execution_id, sequence)`` for embedding in the serialized value as a stream reference. @@ -53,19 +64,24 @@ def register(self, generator: Any) -> tuple[str, int]: protocol.send_stream_register(self._execution_id, sequence) + is_async = inspect.isasyncgen(generator) + target = self._run_async if is_async else self._run thread = threading.Thread( - target=self._drive, + target=target, args=(sequence, generator), name=f"stream-{self._execution_id}-{sequence}", daemon=False, ) + entry = {"generator": generator, "is_async": is_async, "loop": None} + with self._lock: + self._generators.append(entry) + self._threads.append(thread) thread.start() - self._threads.append(thread) return self._execution_id, sequence - def _drive(self, sequence: int, generator: Any) -> None: - """Pump one generator to exhaustion (or error).""" + def _run(self, sequence: int, generator: Any) -> None: + """Run one sync generator to exhaustion (or error).""" position = 0 try: for item in generator: @@ -78,8 +94,10 @@ def _drive(self, sequence: int, generator: Any) -> None: ) position += 1 except GeneratorExit: - # Generator explicitly closed (e.g. execution cancelled). The - # server already knows — no close message needed. + # Generator explicitly closed (via close_all on error path, or + # server-initiated cancel). Skip send_stream_close — the server + # records a lifecycle closure when the execution terminates and + # derives the error from the execution's outcome. return except BaseException as e: # noqa: BLE001 - we propagate all error_type = f"{type(e).__module__}.{type(e).__qualname__}" @@ -94,13 +112,99 @@ def _drive(self, sequence: int, generator: Any) -> None: else: protocol.send_stream_close(self._execution_id, sequence) + def _run_async(self, sequence: int, generator: Any) -> None: + """Run one async generator in a fresh event loop on this thread. + + The loop handle is recorded so ``close_all`` can schedule aclose() + from another thread via ``run_coroutine_threadsafe``. + """ + loop = asyncio.new_event_loop() + self._record_loop(generator, loop) + asyncio.set_event_loop(loop) + + async def iterate() -> None: + position = 0 + async for item in generator: + serialized = serialize_value(item) + protocol.send_stream_append( + self._execution_id, + sequence, + position, + serialized, + ) + position += 1 + + try: + loop.run_until_complete(iterate()) + except (GeneratorExit, asyncio.CancelledError): + return + except BaseException as e: # noqa: BLE001 - we propagate all + error_type = f"{type(e).__module__}.{type(e).__qualname__}" + tb = traceback.format_exc() + protocol.send_stream_close( + self._execution_id, + sequence, + error_type=error_type, + error_message=str(e), + traceback=tb, + ) + else: + protocol.send_stream_close(self._execution_id, sequence) + finally: + try: + loop.run_until_complete(generator.aclose()) + except Exception: + pass + try: + loop.close() + except Exception: + pass + + def _record_loop(self, generator: Any, loop: asyncio.AbstractEventLoop) -> None: + with self._lock: + for entry in self._generators: + if entry["generator"] is generator: + entry["loop"] = loop + return + def wait_all(self) -> None: - """Block until every driver thread has finished.""" + """Block until every worker thread has finished.""" with self._lock: threads = list(self._threads) for t in threads: t.join() + def close_all(self) -> None: + """Close every registered generator so worker threads exit promptly. + + Used on the error path: when the task body raises, we want in-flight + streams to stop producing rather than racing the execution_error + notification. For sync generators, ``generator.close()`` raises + ``GeneratorExit`` at the current yield point. For async generators, + we schedule ``aclose()`` onto the generator's own event loop so the + awaiting coroutine is cancelled cleanly. + """ + with self._lock: + entries = list(self._generators) + for entry in entries: + try: + if entry["is_async"]: + loop = entry["loop"] + if loop is not None and not loop.is_closed(): + gen = entry["generator"] + + async def _close(g=gen) -> None: + try: + await g.aclose() + except Exception: + pass + + asyncio.run_coroutine_threadsafe(_close(), loop) + else: + entry["generator"].close() + except Exception: + pass + # --- Consumer side --- @@ -126,10 +230,14 @@ def __init__(self, subscription_id: int, execution_id: str) -> None: def on_items(self, items: list[list[Any]]) -> None: """Called by the registry when the server pushes items for this subscription. ``items`` is a list of ``[position, value_wire]``. + + Runs on the dispatcher reader thread — keep it cheap. The raw wire + value goes onto the queue unmodified; deserialization happens in + ``__next__`` on the consumer's thread so heavy decode work doesn't + stall stdin reads. """ for _position, value in items: - # Decode eagerly so iteration cost is paid per-item as it arrives. - self._queue.put(deserialize_value(value)) + self._queue.put(value) def on_closed(self, error: dict[str, Any] | None) -> None: """Called by the registry when the stream closes.""" @@ -145,14 +253,20 @@ def __next__(self) -> Any: if isinstance(item, _Closed): self._done = True _stream_registry().drop(self._subscription_id) - protocol.send_stream_unsubscribe(self._execution_id, self._subscription_id) + # Skip the unsubscribe roundtrip when the dispatcher is gone — + # stdout may still be writable but there's no one to receive it, + # and a closed pipe would raise from send_*. + if not get_dispatcher().is_closed(): + try: + protocol.send_stream_unsubscribe( + self._execution_id, self._subscription_id + ) + except Exception: + pass if item.error is not None: - raise create_execution_error( - item.error.get("type", ""), - item.error.get("message", ""), - ) + raise create_stream_error(item.error) raise StopIteration - return item + return deserialize_value(item) class StreamRegistry: @@ -172,8 +286,24 @@ def _ensure_installed(self) -> None: d = get_dispatcher() d.register_notification("stream_items", self._on_items) d.register_notification("stream_closed", self._on_closed) + # If stdin goes away before the server sends close messages, + # blocked iterators would hang on their queues forever. Push a + # synthetic closed sentinel into each so ``__next__`` raises. + d.add_close_callback(self._on_dispatcher_closed) self._installed = True + def _on_dispatcher_closed(self) -> None: + """Wake all active iterators with a connection-closed error.""" + error = { + "type": "Coflux.ExecutionAbandoned", + "message": "connection closed", + "frames": [], + } + with self._lock: + iterators = list(self._iterators.values()) + for it in iterators: + it.on_closed(error) + def allocate(self, execution_id: str) -> tuple[int, _StreamIterator]: """Claim a subscription id and iterator.""" self._ensure_installed() diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index cbf6e72e..90d03ec4 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -658,11 +658,17 @@ func (w *Worker) handleStreamItems(params []any) error { if len(params) < 3 { return fmt.Errorf("stream_items: insufficient params") } - executionID := getString(params[0]) - subscriptionID, _ := params[1].(float64) + executionID, ok := params[0].(string) + if !ok { + return fmt.Errorf("stream_items: execution_id is not a string (got %T)", params[0]) + } + subscriptionID, ok := params[1].(float64) + if !ok { + return fmt.Errorf("stream_items: subscription_id is not a number (got %T)", params[1]) + } rawItems, ok := params[2].([]any) if !ok { - return fmt.Errorf("stream_items: items is not an array") + return fmt.Errorf("stream_items: items is not an array (got %T)", params[2]) } converted := make([]any, len(rawItems)) @@ -691,8 +697,14 @@ func (w *Worker) handleStreamClosed(params []any) error { if len(params) < 3 { return fmt.Errorf("stream_closed: insufficient params") } - executionID := getString(params[0]) - subscriptionID, _ := params[1].(float64) + executionID, ok := params[0].(string) + if !ok { + return fmt.Errorf("stream_closed: execution_id is not a string (got %T)", params[0]) + } + subscriptionID, ok := params[1].(float64) + if !ok { + return fmt.Errorf("stream_closed: subscription_id is not a number (got %T)", params[1]) + } errField := params[2] forwarded := map[string]any{ @@ -1220,7 +1232,10 @@ func (w *Worker) StreamUnsubscribe(ctx context.Context, executionID string, subs if err != nil { return err } - return conn.Notify("stream_unsubscribe", subscriptionID) + // Server params: [consumer_execution_id, subscription_id]. The consumer + // id scopes the subscription key server-side, so two adapters in the + // same session can reuse subscription_id without colliding. + return conn.Notify("stream_unsubscribe", executionID, subscriptionID) } func (w *Worker) Cancel(ctx context.Context, executionID string, handles []adapter.SelectHandle) error { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index 9397a0b8..177daedd 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -371,12 +371,13 @@ defmodule Coflux.Handlers.Worker do end "stream_unsubscribe" -> - [subscription_id] = message["params"] + [consumer_execution_id, subscription_id] = message["params"] :ok = Orchestration.unsubscribe_stream( state.project_id, state.session_id, + consumer_execution_id, subscription_id ) diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index 9a90986d..1d53fcba 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -218,8 +218,11 @@ defmodule Coflux.Orchestration do ) end - def unsubscribe_stream(project_id, session_id, subscription_id) do - call_server(project_id, {:unsubscribe_stream, session_id, subscription_id}) + def unsubscribe_stream(project_id, session_id, consumer_execution_id, subscription_id) do + call_server( + project_id, + {:unsubscribe_stream, session_id, consumer_execution_id, subscription_id} + ) end def select( diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 604deb14..7f19c404 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -119,17 +119,22 @@ defmodule Coflux.Orchestration.Server do # Active stream subscriptions — in-memory, session-scoped. # A consumer adapter opens a subscription by sending stream_subscribe - # with a session-unique subscription_id; we push items (stream_items - # command) as they arrive on the producer side, and a terminal - # stream_closed command when the stream ends. Dropped when the - # session disconnects, when the consumer unsubscribes, or when the - # stream closes. + # with a subscription_id unique within that consumer's adapter + # process; we push items (stream_items command) as they arrive on + # the producer side, and a terminal stream_closed command when the + # stream ends. Dropped when the session disconnects, when the + # consumer unsubscribes, when the consumer execution terminates, + # or when the stream closes. # - # stream_subscriptions: {session_id, subscription_id} -> %{ - # consumer_execution_id, producer_execution_id, sequence, - # cursor, filter} + # The key includes consumer_execution_id so concurrent consumer + # adapters (each starting subscription counters from 0) can't + # collide. + # + # stream_subscriptions: {consumer_execution_id, subscription_id} -> + # %{consumer_execution_external_id, producer_execution_id, + # sequence, cursor, filter} # stream_subscribers: {producer_execution_id, sequence} -> MapSet of - # {session_id, subscription_id} + # {consumer_execution_id, subscription_id} stream_subscriptions: %{}, stream_subscribers: %{} end @@ -1760,7 +1765,9 @@ defmodule Coflux.Orchestration.Server do state = case Map.fetch(state.execution_ids, ext_id) do {:ok, execution_id} -> - complete_execution(state, execution_id) + state + |> complete_execution(execution_id) + |> drop_execution_subscriptions(execution_id) :error -> state @@ -1873,7 +1880,13 @@ defmodule Coflux.Orchestration.Server do def handle_call({:close_stream, execution_external_id, sequence, error}, _from, state) do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> - case Streams.close_stream(state.db, execution_id, sequence, error) do + spec = + case error do + nil -> :complete + {type, message, frames} -> {:errored, type, message, frames} + end + + case Streams.close_stream(state.db, execution_id, sequence, spec) do {:ok, closed_at} -> state = state @@ -1906,7 +1919,7 @@ defmodule Coflux.Orchestration.Server do {:error, :not_found} -> {:error, :producer_not_found} end - with {:ok, session_id} <- + with {:ok, _session_id} <- Map.fetch(state.session_ids, session_external_id) |> ok_or(:session_not_found), {:ok, consumer_execution_id} <- @@ -1914,10 +1927,9 @@ defmodule Coflux.Orchestration.Server do |> ok_or(:consumer_not_found), {:ok, producer_execution_id} <- producer_result, {:ok, true} <- Streams.exists?(state.db, producer_execution_id, sequence), - key = {session_id, subscription_id}, + key = {consumer_execution_id, subscription_id}, false <- Map.has_key?(state.stream_subscriptions, key) do subscription = %{ - consumer_execution_id: consumer_execution_id, consumer_execution_external_id: consumer_execution_external_id, producer_execution_id: producer_execution_id, sequence: sequence, @@ -1939,8 +1951,8 @@ defmodule Coflux.Orchestration.Server do # Push any items already in the log that match the filter, then (if # the stream has already closed) the terminal close record. - state = push_backlog(state, session_id, subscription_id) - state = maybe_push_closure_if_closed(state, session_id, subscription_id) + state = push_backlog(state, key) + state = maybe_push_closure_if_closed(state, key) {:reply, :ok, state} else @@ -1950,11 +1962,17 @@ defmodule Coflux.Orchestration.Server do end end - def handle_call({:unsubscribe_stream, session_external_id, subscription_id}, _from, state) do - case Map.fetch(state.session_ids, session_external_id) do - {:ok, session_id} -> - {:reply, :ok, drop_subscription(state, {session_id, subscription_id})} - + def handle_call( + {:unsubscribe_stream, session_external_id, consumer_execution_external_id, + subscription_id}, + _from, + state + ) do + with {:ok, _session_id} <- Map.fetch(state.session_ids, session_external_id), + {:ok, consumer_execution_id} <- + Map.fetch(state.execution_ids, consumer_execution_external_id) do + {:reply, :ok, drop_subscription(state, {consumer_execution_id, subscription_id})} + else :error -> {:reply, :ok, state} end @@ -3627,9 +3645,9 @@ defmodule Coflux.Orchestration.Server do # Close any open streams so iterating consumers stop waiting. Any # subsequent `append_item` from the producer will fail with `:closed`, - # signalling the worker to stop. Push :cancelled so consumers raise - # ExecutionCancelled on iteration. - state = close_open_streams(state, execution_id, :cancelled) + # signalling the worker to stop. Recorded as :lifecycle — consumers + # derive the ExecutionCancelled error from the recorded result. + state = close_open_streams(state, execution_id) state = case Runs.get_execution_key(state.db, execution_id) do @@ -4977,7 +4995,7 @@ defmodule Coflux.Orchestration.Server do ) ) - {:ok, streams} = Streams.get_streams_with_closures_for_execution(db, execution_id) + streams = streams_with_resolved_errors(db, execution_id) {attempt, %{ @@ -5798,9 +5816,11 @@ defmodule Coflux.Orchestration.Server do case Results.has_result?(state.db, execution_id) do {:ok, true} -> # Close any streams left open by the producer. Generator tasks - # normally close their streams explicitly; this is the backstop - # for ones that didn't. - state = close_open_streams(state, execution_id, :complete) + # normally close their streams explicitly; this is the backstop. + # We record a :lifecycle closure — the error surfaced to + # consumers is derived from the execution's recorded result + # rather than stored separately. + state = close_open_streams(state, execution_id) case Results.record_completion(state.db, execution_id) do {:ok, completion_at} -> @@ -5829,9 +5849,9 @@ defmodule Coflux.Orchestration.Server do decide_and_create_successor(state, execution_id, step, workspace_id, :crashed) # Streams that had been appended to before the worker died need to be - # closed so consumers don't wait forever. Push :crashed so consumers - # raise ExecutionCrashed on iteration. - state = close_open_streams(state, execution_id, :crashed) + # closed so consumers don't wait forever. Recorded as :lifecycle — + # consumers derive the specific error from the execution's outcome. + state = close_open_streams(state, execution_id) case Results.record_completion(state.db, execution_id) do {:ok, completion_at} -> @@ -5849,28 +5869,18 @@ defmodule Coflux.Orchestration.Server do # Closes every stream owned by `execution_id` that doesn't yet have a # closure row, and pushes a `stream_closed` notification to every active - # subscriber. If a generator closed its stream with an explicit error, - # that closure already exists and is left untouched. + # subscriber. Streams already closed by the producer (clean or errored) + # are left untouched. # - # ``reason`` annotates the closure for the consumer's benefit. It's stored - # in the DB closure row as nil (execution disposition is the source of - # truth), but pushed synchronously to subscribers so they can raise the - # right exception on iteration: - # * :complete – no error, StopIteration on next() - # * :cancelled – ExecutionCancelled on next() - # * :crashed – ExecutionCrashed on next() - defp close_open_streams(state, execution_id, reason) do + # The closure is recorded with reason :lifecycle — no error is stored + # on the closure row. Consumers that need to surface an error derive + # it from the execution's recorded result (see derive_lifecycle_error). + defp close_open_streams(state, execution_id) do {:ok, sequences} = Streams.get_open_streams_for_execution(state.db, execution_id) - - push_error = - case reason do - :complete -> nil - :cancelled -> {"Coflux.ExecutionCancelled", "execution cancelled", []} - :crashed -> {"Coflux.ExecutionCrashed", "worker terminated", []} - end + push_error = derive_lifecycle_error(state.db, execution_id) Enum.reduce(sequences, state, fn sequence, state -> - case Streams.close_stream(state.db, execution_id, sequence) do + case Streams.close_stream(state.db, execution_id, sequence, :lifecycle) do {:ok, closed_at} -> state |> push_stream_closed(execution_id, sequence, push_error) @@ -5882,6 +5892,55 @@ defmodule Coflux.Orchestration.Server do end) end + # Returns the streams list for `execution_id` with :lifecycle closures' + # errors resolved from the execution's recorded result. Shape: + # `{sequence, opened_at, closed_at | nil, error | nil}` — the same + # shape the topic module expects (reason is collapsed into error/nil + # once we've derived it). + defp streams_with_resolved_errors(db, execution_id) do + {:ok, rows} = Streams.get_streams_with_closures_for_execution(db, execution_id) + + Enum.map(rows, fn + {sequence, opened_at, nil, nil, nil} -> + {sequence, opened_at, nil, nil} + + {sequence, opened_at, closed_at, :lifecycle, _} -> + {sequence, opened_at, closed_at, derive_lifecycle_error(db, execution_id)} + + {sequence, opened_at, closed_at, _reason, error} -> + {sequence, opened_at, closed_at, error} + end) + end + + # Build a {type, message, frames} triple describing why a lifecycle + # closure happened, by looking at the execution's recorded result. Used + # both when pushing live closures to subscribers and when late + # subscribers attach to an already-closed stream. Returns nil if the + # execution has no result yet (shouldn't happen in practice — lifecycle + # closures are driven by complete_execution which only runs after a + # result is recorded). + defp derive_lifecycle_error(db, execution_id) do + case Results.get_result(db, execution_id) do + {:ok, {{:error, type, message, frames, _, _}, _, _, _}} -> + {type, message, frames} + + {:ok, {:cancelled, _, _, _}} -> + {"Coflux.ExecutionCancelled", "execution cancelled", []} + + {:ok, {{:abandoned, _}, _, _, _}} -> + {"Coflux.ExecutionAbandoned", "execution abandoned", []} + + {:ok, {{:crashed, _}, _, _, _}} -> + {"Coflux.ExecutionCrashed", "worker terminated", []} + + {:ok, {{:timeout, _}, _, _, _}} -> + {"Coflux.ExecutionTimeout", "execution timed out", []} + + _ -> + nil + end + end + defp fire_completion_notification(state, execution_id, completion_at) do {:ok, {r, s, a}} = Runs.get_execution_key(state.db, execution_id) execution_external_id = execution_external_id(r, s, a) @@ -7235,67 +7294,113 @@ defmodule Coflux.Orchestration.Server do defp filter_exhausted?(_filter, _cursor), do: false - # Send backlog items (those already in the DB) for a newly subscribed consumer. - defp push_backlog(state, session_id, subscription_id) do - key = {session_id, subscription_id} + # Per-fetch page size when draining backlog for a newly subscribed + # consumer. Keeps any single DB read bounded and lets us push each page + # to the session before loading the next. + @backlog_page_size 1024 + + # Send backlog items (those already in the DB) for a newly subscribed + # consumer. Pages the DB reads + session pushes so a very long stream + # doesn't materialise the entire tail in memory at once. + defp push_backlog(state, key) do + push_backlog_page(state, key) + end + + defp push_backlog_page(state, key) do sub = Map.fetch!(state.stream_subscriptions, key) - # Conservative fetch cap: stream more in batches if needed on advance. - # For v1 (no flow control) we just drain the whole tail. {:ok, items} = Streams.get_stream_items( state.db, sub.producer_execution_id, sub.sequence, sub.cursor, - 1_000_000 + @backlog_page_size ) + if items == [] do + state + else + state = push_backlog_items(state, key, sub, items) + + # If the subscription was dropped (filter exhausted) or didn't + # advance (nothing in this page matched), stop. Otherwise keep + # paging until the tail empties. + case Map.fetch(state.stream_subscriptions, key) do + :error -> + state + + {:ok, next_sub} -> + cond do + next_sub.cursor == sub.cursor -> state + length(items) < @backlog_page_size -> state + true -> push_backlog_page(state, key) + end + end + end + end + + defp push_backlog_items(state, key, sub, items) do + {_consumer_execution_id, subscription_id} = key + filtered = items |> Enum.filter(fn {position, _value, _at} -> filter_matches?(sub.filter, position) end) |> Enum.take_while(fn {position, _, _} -> not filter_exhausted?(sub.filter, position) end) - if filtered == [] do - state - else - last_pos = elem(List.last(filtered), 0) - next_cursor = last_pos + 1 - - # Values are already in internal form (from DB) — resolve refs to - # external IDs. Final JSON encoding happens in the WS handler. - resolved_items = - Enum.map(filtered, fn {position, value, _at} -> - [position, build_value(value, state.db)] - end) + # Advance cursor past the page even if no items matched this filter — + # otherwise we'd re-fetch the same positions forever. + advance_to = + if filtered == [] do + elem(List.last(items), 0) + 1 + else + elem(List.last(filtered), 0) + 1 + end - state = - send_session( + state = + if filtered == [] do + state + else + resolved_items = + Enum.map(filtered, fn {position, value, _at} -> + [position, build_value(value, state.db)] + end) + + send_to_consumer( state, - session_id, + sub, {:stream_items, sub.consumer_execution_external_id, subscription_id, resolved_items} ) + end - state = - update_in( - state.stream_subscriptions[key], - &Map.put(&1, :cursor, next_cursor) - ) + state = + update_in( + state.stream_subscriptions[key], + &Map.put(&1, :cursor, advance_to) + ) - # If the filter is now exhausted (slice's stop reached), close the - # subscription synchronously — matches push_stream_item's behaviour. - # Without this, a consumer that subscribed after appends with a - # bounded filter would wait forever for a close that never comes. - if filter_exhausted?(sub.filter, next_cursor) do - state - |> send_session( - session_id, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} - ) - |> drop_subscription(key) - else - state - end + # If the filter is now exhausted (slice's stop reached), close the + # subscription synchronously — matches push_stream_item's behaviour. + # Without this, a consumer that subscribed after appends with a + # bounded filter would wait forever for a close that never comes. + if filter_exhausted?(sub.filter, advance_to) do + state + |> send_to_consumer( + sub, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + ) + |> drop_subscription(key) + else + state + end + end + + # Resolve the consumer's current session and send, skipping if the + # execution is no longer live on any session (reconnect window, etc.). + defp send_to_consumer(state, sub, payload) do + case find_session_for_execution(state, sub.consumer_execution_external_id) do + {:ok, session_id} -> send_session(state, session_id, payload) + :error -> state end end @@ -7305,7 +7410,7 @@ defmodule Coflux.Orchestration.Server do Map.get(state.stream_subscribers, {producer_execution_id, sequence}, MapSet.new()) Enum.reduce(subscribers, state, fn key, state -> - {session_id, subscription_id} = key + {_consumer_execution_id, subscription_id} = key sub = Map.fetch!(state.stream_subscriptions, key) cond do @@ -7324,9 +7429,9 @@ defmodule Coflux.Orchestration.Server do item = [position, resolved] state = - send_session( + send_to_consumer( state, - session_id, + sub, {:stream_items, sub.consumer_execution_external_id, subscription_id, [item]} ) @@ -7340,8 +7445,8 @@ defmodule Coflux.Orchestration.Server do # the subscription early — no more items will match. if filter_exhausted?(sub.filter, position + 1) do state - |> send_session( - session_id, + |> send_to_consumer( + sub, {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} ) |> drop_subscription(key) @@ -7358,56 +7463,68 @@ defmodule Coflux.Orchestration.Server do subscribers = Map.get(state.stream_subscribers, {producer_execution_id, sequence}, MapSet.new()) - encoded_error = - case error do - nil -> nil - {type, message, _frames} -> %{"type" => type, "message" => message} - end + encoded_error = encode_stream_error(error) Enum.reduce(subscribers, state, fn key, state -> - {session_id, subscription_id} = key + {_consumer_execution_id, subscription_id} = key sub = Map.fetch!(state.stream_subscriptions, key) state - |> send_session( - session_id, + |> send_to_consumer( + sub, {:stream_closed, sub.consumer_execution_external_id, subscription_id, encoded_error} ) |> drop_subscription(key) end) end + # Common wire encoding for stream_closed errors. Frames are included so + # consumers can reconstruct tracebacks for debuggability. + defp encode_stream_error(nil), do: nil + + defp encode_stream_error({type, message, frames}) do + %{ + "type" => type, + "message" => message, + "frames" => + Enum.map(frames, fn {file, line, name, code} -> + [file, line, name, code] + end) + } + end + # If a subscription attaches to an already-closed stream, emit closure now. # If push_backlog already closed the subscription (e.g., a bounded filter # was exhausted by the backlog itself), this is a no-op. - defp maybe_push_closure_if_closed(state, session_id, subscription_id) do - key = {session_id, subscription_id} - + defp maybe_push_closure_if_closed(state, key) do case Map.fetch(state.stream_subscriptions, key) do :error -> state {:ok, sub} -> - do_maybe_push_closure(state, sub, session_id, subscription_id, key) + do_maybe_push_closure(state, sub, key) end end - defp do_maybe_push_closure(state, sub, session_id, subscription_id, key) do + defp do_maybe_push_closure(state, sub, key) do + {_consumer_execution_id, subscription_id} = key + case Streams.get_stream_closure(state.db, sub.producer_execution_id, sub.sequence) do {:ok, nil} -> state - {:ok, {error, _closed_at}} -> - encoded_error = - case error do - nil -> nil - {type, message, _frames} -> %{"type" => type, "message" => message} + {:ok, {reason, error, _closed_at}} -> + resolved_error = + case reason do + :lifecycle -> derive_lifecycle_error(state.db, sub.producer_execution_id) + _ -> error end state - |> send_session( - session_id, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, encoded_error} + |> send_to_consumer( + sub, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, + encode_stream_error(resolved_error)} ) |> drop_subscription(key) end @@ -7441,12 +7558,31 @@ defmodule Coflux.Orchestration.Server do end end - # Drop every subscription owned by a disconnected session. + # Drop every subscription held by the disconnected session's executions. + # Called just before the session is removed, so we can read its live + # execution set directly. defp drop_session_subscriptions(state, session_id) do + session = Map.fetch!(state.sessions, session_id) + + session.starting + |> MapSet.union(session.executing) + |> Enum.reduce(state, fn ext_id, state -> + case Map.fetch(state.execution_ids, ext_id) do + {:ok, execution_id} -> drop_execution_subscriptions(state, execution_id) + :error -> state + end + end) + end + + # Drop every subscription owned by a terminated consumer execution so the + # server stops pushing items and the subscription map doesn't leak. Called + # from notify_terminated — by that point the consumer's generator iterator + # (and thus the subscription) is definitely gone. + defp drop_execution_subscriptions(state, consumer_execution_id) do keys = state.stream_subscriptions |> Map.keys() - |> Enum.filter(fn {sid, _} -> sid == session_id end) + |> Enum.filter(fn {cons_id, _sub_id} -> cons_id == consumer_execution_id end) Enum.reduce(keys, state, &drop_subscription(&2, &1)) end diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 623a2d2c..27e51f45 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -76,10 +76,16 @@ defmodule Coflux.Orchestration.Streams do end) end - # Closes the stream. `error` is either `nil` (clean close) or a - # `{type, message, frames}` triple (error close — re-uses the errors table - # via the same path as Results). - def close_stream(db, execution_id, sequence, error \\ nil) do + # Closes the stream. `spec` describes *why* it closed: + # + # * `:complete` — producer finished normally + # * `{:errored, type, message, frames}` — producer raised an error; the + # error is stored via the errors table, same as Results + # * `:lifecycle` — closed implicitly because the producer execution + # ended (cancel/crash/abandon/error). No error is recorded here — + # callers that need to surface an error derive it from the + # execution's recorded result at read time. + def close_stream(db, execution_id, sequence, spec \\ :complete) do with_transaction(db, fn -> case exists?(db, execution_id, sequence) do {:ok, false} -> @@ -87,16 +93,12 @@ defmodule Coflux.Orchestration.Streams do {:ok, true} -> now = current_timestamp() - - error_id = - case error do - nil -> nil - {type, message, frames} -> Errors.get_or_create(db, type, message, frames) - end + {reason, error_id} = resolve_close_spec(db, spec) case insert_one(db, :stream_closures, %{ execution_id: execution_id, sequence: sequence, + reason: reason, error_id: error_id, created_at: now }) do @@ -107,6 +109,26 @@ defmodule Coflux.Orchestration.Streams do end) end + # Closure reason codes — kept in sync with the CHECK constraint in 4.sql. + @reason_complete 0 + @reason_errored 1 + @reason_lifecycle 2 + + defp resolve_close_spec(_db, :complete), do: {@reason_complete, nil} + defp resolve_close_spec(_db, :lifecycle), do: {@reason_lifecycle, nil} + + defp resolve_close_spec(db, {:errored, type, message, frames}) do + error_id = Errors.get_or_create(db, type, message, frames) + {@reason_errored, error_id} + end + + # Atom form of the reason integer — used by callers that want to decide + # whether to derive an error from the execution's result (:lifecycle) + # or use the stored one (:errored / :complete). + def reason_from_int(@reason_complete), do: :complete + def reason_from_int(@reason_errored), do: :errored + def reason_from_int(@reason_lifecycle), do: :lifecycle + def exists?(db, execution_id, sequence) do case query_one( db, @@ -164,23 +186,26 @@ defmodule Coflux.Orchestration.Streams do end # Returns closure info or `{:ok, nil}` if the stream is still open. - # Closure info: `{error | nil, created_at}` where error is - # `{type, message, frames}` when present. + # Closure info: `{reason, error | nil, created_at}` where + # * reason is :complete | :errored | :lifecycle + # * error is the `{type, message, frames}` triple for :errored, nil + # otherwise (callers derive it from the execution's result on + # :lifecycle) def get_stream_closure(db, execution_id, sequence) do case query_one( db, - "SELECT error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", + "SELECT reason, error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", {execution_id, sequence} ) do {:ok, nil} -> {:ok, nil} - {:ok, {nil, created_at}} -> - {:ok, {nil, created_at}} + {:ok, {reason_int, nil, created_at}} -> + {:ok, {reason_from_int(reason_int), nil, created_at}} - {:ok, {error_id, created_at}} -> + {:ok, {reason_int, error_id, created_at}} -> {:ok, error} = Errors.get_by_id(db, error_id) - {:ok, {error, created_at}} + {:ok, {reason_from_int(reason_int), error, created_at}} end end @@ -211,15 +236,17 @@ defmodule Coflux.Orchestration.Streams do end # Returns one row per stream owned by `execution_id`: - # `{sequence, created_at, closed_at | nil, error | nil}` where error, when - # present, is a `{type, message, frames}` triple. Used when populating the - # topic state for a run — lets the UI render streams and their state in - # one query. + # `{sequence, created_at, closed_at | nil, reason | nil, error | nil}`. + # * reason is :complete | :errored | :lifecycle when closed, nil when open + # * error is the stored `{type, message, frames}` triple for :errored + # closures only — callers that need to surface an error for a + # :lifecycle closure derive it from the execution's result. + # Used when populating the topic state for a run. def get_streams_with_closures_for_execution(db, execution_id) do case query( db, """ - SELECT s.sequence, s.created_at, c.created_at, c.error_id + SELECT s.sequence, s.created_at, c.created_at, c.reason, c.error_id FROM streams AS s LEFT JOIN stream_closures AS c ON c.execution_id = s.execution_id AND c.sequence = s.sequence @@ -231,15 +258,15 @@ defmodule Coflux.Orchestration.Streams do {:ok, rows} -> streams = Enum.map(rows, fn - {sequence, created_at, nil, nil} -> - {sequence, created_at, nil, nil} + {sequence, created_at, nil, nil, nil} -> + {sequence, created_at, nil, nil, nil} - {sequence, created_at, closed_at, nil} -> - {sequence, created_at, closed_at, nil} + {sequence, created_at, closed_at, reason_int, nil} -> + {sequence, created_at, closed_at, reason_from_int(reason_int), nil} - {sequence, created_at, closed_at, error_id} -> + {sequence, created_at, closed_at, reason_int, error_id} -> {:ok, error} = Errors.get_by_id(db, error_id) - {sequence, created_at, closed_at, error} + {sequence, created_at, closed_at, reason_from_int(reason_int), error} end) {:ok, streams} diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index 2dbc273f..8bac925a 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -56,12 +56,21 @@ CREATE TABLE stream_items ( FOREIGN KEY (value_id) REFERENCES values_ ON DELETE RESTRICT ) STRICT; +-- Closure of a stream. `reason` records *why* it closed: +-- 0 = complete — producer finished normally (no error) +-- 1 = errored — producer raised an error (stored in errors via error_id) +-- 2 = lifecycle — closed implicitly because the producer execution ended +-- (cancel/crash/abandon/error). The specific error is +-- derived on read by looking up the execution's result, +-- so we don't duplicate that state here. CREATE TABLE stream_closures ( execution_id INTEGER NOT NULL, sequence INTEGER NOT NULL, + reason INTEGER NOT NULL, error_id INTEGER, created_at INTEGER NOT NULL, PRIMARY KEY (execution_id, sequence), FOREIGN KEY (execution_id, sequence) REFERENCES streams (execution_id, sequence) ON DELETE CASCADE, - FOREIGN KEY (error_id) REFERENCES errors ON DELETE RESTRICT + FOREIGN KEY (error_id) REFERENCES errors ON DELETE RESTRICT, + CHECK ((reason = 1) = (error_id IS NOT NULL)) ) STRICT; diff --git a/tests/test_streams.py b/tests/test_streams.py index ebd62f5f..923af984 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -314,28 +314,29 @@ def test_multiple_subscribers_get_independent_delivery(worker): prod_ex.conn.complete(prod_ex.execution_id) ctx.result(prod_resp["runId"]) - # Subscriber A + # Each consumer picks its own subscription id locally; they only + # need to be unique within each consumer execution. Use different + # values here so we'd also catch any stale cross-consumer routing. a_resp = ctx.submit("test", "consumer") a_ex = ctx.executor.next_execute() a_ex.conn.stream_subscribe( a_ex.execution_id, - subscription_id=1, + subscription_id=7, producer_execution_id=prod_ex.execution_id, sequence=0, ) - # Subscriber B b_resp = ctx.submit("test", "consumer") b_ex = ctx.executor.next_execute() b_ex.conn.stream_subscribe( b_ex.execution_id, - subscription_id=1, + subscription_id=42, producer_execution_id=prod_ex.execution_id, sequence=0, ) - a_items, _ = a_ex.conn.drain_stream(subscription_id=1) - b_items, _ = b_ex.conn.drain_stream(subscription_id=1) + a_items, _ = a_ex.conn.drain_stream(subscription_id=7) + b_items, _ = b_ex.conn.drain_stream(subscription_id=42) a_ex.conn.complete(a_ex.execution_id) b_ex.conn.complete(b_ex.execution_id) @@ -343,6 +344,94 @@ def test_multiple_subscribers_get_independent_delivery(worker): assert [item[1]["value"] for item in b_items] == [0, 1, 2] +def test_subscription_ids_can_collide_across_consumers(worker): + """Two different consumer executions can each allocate the same + subscription id locally — the server scopes its routing map by + consumer_execution_id, so items for each consumer's subscription + reach the right executor. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=3) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + for i in range(3): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, f"v{i}") + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + # Both consumers use subscription_id=1 — they must not collide. + a_resp = ctx.submit("test", "consumer") + a_ex = ctx.executor.next_execute() + a_ex.conn.stream_subscribe( + a_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + b_resp = ctx.submit("test", "consumer") + b_ex = ctx.executor.next_execute() + b_ex.conn.stream_subscribe( + b_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + a_items, a_closed = a_ex.conn.drain_stream(subscription_id=1) + b_items, b_closed = b_ex.conn.drain_stream(subscription_id=1) + a_ex.conn.complete(a_ex.execution_id) + b_ex.conn.complete(b_ex.execution_id) + + assert [item[1]["value"] for item in a_items] == ["v0", "v1", "v2"] + assert [item[1]["value"] for item in b_items] == ["v0", "v1", "v2"] + assert a_closed.get("error") is None + assert b_closed.get("error") is None + + +def test_consumer_termination_drops_subscription(worker): + """When a consumer's notify_terminated arrives, the server must drop + its stream subscriptions so subsequent producer appends don't try to + route to a gone consumer. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + sequence=0, + ) + + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "before") + first = cons_ex.conn.recv_push("stream_items", subscription_id=1, timeout=3) + assert first["items"][0][1]["value"] == "before" + + # Consumer finishes without explicit unsubscribe — notify_terminated + # from the session should drop the subscription on the server side. + cons_ex.conn.complete(cons_ex.execution_id) + ctx.result(cons_resp["runId"]) + + # Producer keeps appending; these should not cause the server to + # error trying to route to the dead consumer. The producer finishes + # cleanly — the assertion is that the server doesn't crash. + for i in range(1, 5): + prod_ex.conn.stream_append(prod_ex.execution_id, 0, i, i) + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + def test_slice_with_stop_closes_early(worker): """slice(0, 2) on a stream that has more items should close the subscriber as soon as position 2 is reached, not wait for the full From 9870be52a6e65a99801c26eb7e03bf78a69ea431 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sun, 19 Apr 2026 10:58:30 +0100 Subject: [PATCH 12/25] Add stream topic --- server/lib/coflux/application.ex | 3 +- server/lib/coflux/orchestration.ex | 7 + server/lib/coflux/orchestration/server.ex | 163 ++++++++++++++++++++- server/lib/coflux/orchestration/streams.ex | 47 ++++++ server/lib/coflux/topics/stream.ex | 135 +++++++++++++++++ 5 files changed, 347 insertions(+), 8 deletions(-) create mode 100644 server/lib/coflux/topics/stream.ex diff --git a/server/lib/coflux/application.ex b/server/lib/coflux/application.ex index a858dd8f..fda04af5 100644 --- a/server/lib/coflux/application.ex +++ b/server/lib/coflux/application.ex @@ -46,7 +46,8 @@ defmodule Coflux.Application do Topics.Asset, Topics.Queue, Topics.Inputs, - Topics.Input + Topics.Input, + Topics.Stream ] end end diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index 1d53fcba..8ac13dc6 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -280,6 +280,13 @@ defmodule Coflux.Orchestration do call_server(project_id, {:subscribe_run, run_id, pid}) end + def subscribe_stream_topic(project_id, execution_external_id, sequence, pid) do + call_server( + project_id, + {:subscribe_stream_topic, execution_external_id, sequence, pid} + ) + end + def subscribe_targets(project_id, workspace_id, pid) do call_server(project_id, {:subscribe_targets, workspace_id, pid}) end diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 7f19c404..e31864ce 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1864,8 +1864,19 @@ defmodule Coflux.Orchestration.Server do position, normalize_value(value) ) do - {:ok, _} -> - state = push_stream_item(state, execution_id, sequence, position, value) + {:ok, created_at} -> + state = + state + |> push_stream_item(execution_id, sequence, position, value) + |> notify_stream_item_appended( + execution_id, + sequence, + position, + value, + created_at + ) + |> flush_notifications() + {:reply, :ok, state} {:error, reason} -> @@ -2673,6 +2684,23 @@ defmodule Coflux.Orchestration.Server do end end + def handle_call( + {:subscribe_stream_topic, execution_external_id, sequence, pid}, + _from, + state + ) do + case build_stream_topic_initial(state, execution_external_id, sequence) do + {:ok, initial} -> + {:ok, ref, state} = + add_listener(state, {:stream, execution_external_id, sequence}, pid) + + {:reply, {:ok, initial, ref}, state} + + {:error, reason} -> + {:reply, {:error, reason}, state} + end + end + def handle_call({:subscribe_targets, workspace_external_id, pid}, _from, state) do case resolve_workspace_external_id(state, workspace_external_id) do {:error, error} -> @@ -7223,9 +7251,15 @@ defmodule Coflux.Orchestration.Server do # the StreamDriver gets a per-stream Event to gate its next() calls. # --- Stream topic notifications (for Studio subscribers) --- - # These flow through `notify_listeners` → the run topic, distinct from the - # session-directed `push_stream_*` helpers which target subscribed consumer - # sessions' WebSockets. + # These flow through `notify_listeners` → the run topic and the + # per-stream `{:stream, execution_ext_id, sequence}` inspection topic, + # distinct from the session-directed `push_stream_*` helpers which + # target subscribed consumer sessions' WebSockets. + + # Bounded tail of items held by the stream inspection topic. Long + # streams don't need to materialise every item — the UI loads older + # items on demand. + @stream_topic_tail_size 200 defp notify_stream_opened(state, execution_id, sequence, created_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) @@ -7238,6 +7272,32 @@ defmodule Coflux.Orchestration.Server do ) end + defp notify_stream_item_appended( + state, + execution_id, + sequence, + position, + value, + created_at + ) do + {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) + topic = {:stream, execution_ext_id, sequence} + + # Skip the build_value (which hits the DB to resolve refs) when the + # inspection topic has no active subscribers. + if Map.has_key?(state.topics, topic) do + resolved = build_value(normalize_value(value), state.db) + + notify_listeners( + state, + topic, + {:item_appended, position, resolved, created_at} + ) + else + state + end + end + defp notify_stream_closed(state, execution_id, sequence, error, closed_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) @@ -7248,13 +7308,102 @@ defmodule Coflux.Orchestration.Server do {type, message, _frames} -> %{type: type, message: message} end + state = + notify_listeners( + state, + {:run, r}, + {:stream_closed, execution_ext_id, sequence, encoded_error, closed_at} + ) + notify_listeners( state, - {:run, r}, - {:stream_closed, execution_ext_id, sequence, encoded_error, closed_at} + {:stream, execution_ext_id, sequence}, + {:closed, encoded_error, closed_at} ) end + # Build the initial state for a newly-opened stream inspection topic. + # Returns {:ok, state} with producer metadata, opened/closed timestamps, + # closure info (with lifecycle errors already derived), bounded tail of + # items, and the total item count. + defp build_stream_topic_initial(state, execution_ext_id, sequence) do + with {:ok, execution_id} <- resolve_internal_execution_id(state, execution_ext_id), + {:ok, true} <- Streams.exists?(state.db, execution_id, sequence), + {:ok, opened_at} <- Streams.get_opened_at(state.db, execution_id, sequence), + {:ok, {items, total_count}} <- + Streams.get_stream_tail(state.db, execution_id, sequence, @stream_topic_tail_size) do + # Keep the tuple shape here — the topic module runs TopicUtils.build_value + # on each item's value to produce the JSON-encodable form, matching + # how live :item_appended notifications are handled. + resolved_items = + Enum.map(items, fn {position, value, created_at} -> + {position, build_value(value, state.db), created_at} + end) + + first_position = + case resolved_items do + [] -> nil + [{pos, _, _} | _] -> pos + end + + closure = build_stream_topic_closure(state, execution_id, sequence) + + {:ok, + %{ + producer: build_stream_producer(state.db, execution_ext_id, execution_id), + openedAt: opened_at, + closedAt: closure && closure.closedAt, + closure: closure, + items: resolved_items, + firstPosition: first_position, + totalCount: total_count, + tailSize: @stream_topic_tail_size + }} + else + {:ok, false} -> {:error, :not_found} + {:error, reason} -> {:error, reason} + end + end + + defp build_stream_topic_closure(state, execution_id, sequence) do + case Streams.get_stream_closure(state.db, execution_id, sequence) do + {:ok, nil} -> + nil + + {:ok, {reason, stored_error, closed_at}} -> + error = + case reason do + :lifecycle -> derive_lifecycle_error(state.db, execution_id) + _ -> stored_error + end + + %{ + reason: Atom.to_string(reason), + error: encode_stream_error_summary(error), + closedAt: closed_at + } + end + end + + defp encode_stream_error_summary(nil), do: nil + + defp encode_stream_error_summary({type, message, _frames}) do + %{type: type, message: message} + end + + defp build_stream_producer(db, execution_ext_id, execution_id) do + {:ok, step} = Runs.get_step_for_execution(db, execution_id) + {:ok, {run_ext_id}} = Runs.get_external_run_id_for_execution(db, execution_id) + + %{ + executionId: execution_ext_id, + runId: run_ext_id, + stepId: "#{run_ext_id}:#{step.number}", + module: step.module, + target: step.target + } + end + defp execution_external_id_for(db, execution_id) do case Runs.get_execution_key(db, execution_id) do {:ok, {r, s, a}} -> {:ok, execution_external_id(r, s, a)} diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 27e51f45..1458edd5 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -140,6 +140,18 @@ defmodule Coflux.Orchestration.Streams do end end + # Returns the stream's registration timestamp, or `{:error, :not_found}`. + def get_opened_at(db, execution_id, sequence) do + case query_one( + db, + "SELECT created_at FROM streams WHERE execution_id = ?1 AND sequence = ?2", + {execution_id, sequence} + ) do + {:ok, nil} -> {:error, :not_found} + {:ok, {created_at}} -> {:ok, created_at} + end + end + def has_closure?(db, execution_id, sequence) do case query_one( db, @@ -287,6 +299,41 @@ defmodule Coflux.Orchestration.Streams do end end + # Returns the last `max_items` items of the stream, in position order, + # alongside the total item count. Used by the inspection topic to + # bootstrap its bounded tail buffer without materialising the full log. + def get_stream_tail(db, execution_id, sequence, max_items) do + {:ok, {total_count}} = + query_one( + db, + "SELECT COUNT(*) FROM stream_items WHERE execution_id = ?1 AND sequence = ?2", + {execution_id, sequence} + ) + + case query( + db, + """ + SELECT position, value_id, created_at + FROM stream_items + WHERE execution_id = ?1 AND sequence = ?2 + ORDER BY position DESC + LIMIT ?3 + """, + {execution_id, sequence, max_items} + ) do + {:ok, rows} -> + items = + rows + |> Enum.reverse() + |> Enum.map(fn {position, value_id, created_at} -> + {:ok, value} = Values.get_value_by_id(db, value_id) + {position, value, created_at} + end) + + {:ok, {items, total_count}} + end + end + defp current_timestamp() do System.os_time(:millisecond) end diff --git a/server/lib/coflux/topics/stream.ex b/server/lib/coflux/topics/stream.ex new file mode 100644 index 00000000..5b63bf0f --- /dev/null +++ b/server/lib/coflux/topics/stream.ex @@ -0,0 +1,135 @@ +defmodule Coflux.Topics.Stream do + @moduledoc """ + Inspection topic for a single stream, keyed by the producer's external + execution id and sequence. Used by the Studio UI when a user opens a + stream dialog — the topic keeps a bounded tail of items (with resolved + values) plus closure state, and receives live updates as items are + appended or the stream is closed. + """ + use Topical.Topic, route: ["streams", :execution_id, :sequence] + + alias Coflux.Orchestration + alias Coflux.TopicUtils + + def connect(params, context) do + {:ok, Map.put(params, :project, context.project)} + end + + def init(params) do + project_id = Map.fetch!(params, :project) + execution_id = Map.fetch!(params, :execution_id) + sequence = parse_sequence(Map.fetch!(params, :sequence)) + + case Orchestration.subscribe_stream_topic( + project_id, + execution_id, + sequence, + self() + ) do + {:ok, initial, ref} -> + {:ok, + Topic.new( + %{ + producer: initial.producer, + openedAt: initial.openedAt, + closedAt: initial.closedAt, + closure: build_closure(initial.closure), + items: Enum.map(initial.items, &build_item/1), + firstPosition: initial.firstPosition, + totalCount: initial.totalCount, + tailSize: initial.tailSize + }, + %{ref: ref, tail_size: initial.tailSize} + )} + + {:error, :not_found} -> + {:error, :not_found} + end + end + + def handle_info({:topic, _ref, notifications}, topic) do + topic = Enum.reduce(notifications, topic, &process_notification/2) + {:ok, topic} + end + + defp process_notification({:item_appended, position, value, created_at}, topic) do + tail_size = topic.state.tail_size || 200 + + item = build_item({position, value, created_at}) + + existing = topic.value.items + total = topic.value.totalCount + 1 + + # Keep items bounded: if we're already at capacity, drop the head. + # Otherwise append and leave firstPosition alone (or set it if empty). + {new_items, new_first_position} = + cond do + existing == [] -> + {[item], position} + + length(existing) >= tail_size -> + [_dropped | rest] = existing + new_items = rest ++ [item] + [first_item | _] = new_items + {new_items, first_item.position} + + true -> + {existing ++ [item], topic.value.firstPosition} + end + + topic + |> Topic.set([:items], new_items) + |> Topic.set([:firstPosition], new_first_position) + |> Topic.set([:totalCount], total) + end + + defp process_notification({:closed, error, closed_at}, topic) do + closure = build_closure_from_notification(error, closed_at) + + topic + |> Topic.set([:closedAt], closed_at) + |> Topic.set([:closure], closure) + end + + defp build_item({position, value, created_at}) do + %{ + position: position, + value: TopicUtils.build_value(value), + createdAt: created_at + } + end + + defp build_closure(nil), do: nil + + defp build_closure(%{reason: reason, error: error, closedAt: closed_at}) do + %{ + reason: reason, + error: error, + closedAt: closed_at + } + end + + # The live-close notification carries the already-encoded error summary + # (type/message) and closedAt — it doesn't include the reason because + # the notification is the same shape used by the run topic. Default to + # "errored" when there's an error, "complete" otherwise; the distinction + # between errored/lifecycle is resolved server-side before we get here. + defp build_closure_from_notification(error, closed_at) do + reason = if error, do: "errored", else: "complete" + + %{ + reason: reason, + error: error, + closedAt: closed_at + } + end + + defp parse_sequence(s) when is_integer(s), do: s + + defp parse_sequence(s) when is_binary(s) do + case Integer.parse(s) do + {n, ""} -> n + _ -> raise ArgumentError, "invalid stream sequence: #{inspect(s)}" + end + end +end From 90d5b6f6f7d1875afa4679211b64f9d2a28d1bb8 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sun, 19 Apr 2026 13:11:32 +0100 Subject: [PATCH 13/25] Tidy stream IDs/terminology --- adapters/python/coflux/context.py | 5 +- adapters/python/coflux/models.py | 38 ++-- adapters/python/coflux/protocol.py | 30 +-- adapters/python/coflux/serialization.py | 29 ++- adapters/python/coflux/streams.py | 76 +++++--- cli/internal/adapter/protocol.go | 18 +- cli/internal/pool/pool.go | 24 +-- cli/internal/worker/worker.go | 22 +-- server/lib/coflux/handlers/worker.ex | 26 +-- server/lib/coflux/orchestration.ex | 27 +-- server/lib/coflux/orchestration/epoch.ex | 25 +-- server/lib/coflux/orchestration/server.ex | 203 ++++++++++----------- server/lib/coflux/orchestration/streams.ex | 154 ++++++++-------- server/lib/coflux/topics/run.ex | 34 ++-- server/lib/coflux/topics/stream.ex | 115 ++++++------ server/priv/migrations/orchestration/4.sql | 29 +-- tests/support/executor.py | 20 +- tests/support/protocol.py | 26 +-- tests/test_streams.py | 44 ++--- 19 files changed, 482 insertions(+), 463 deletions(-) diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index e882f2a2..ca316088 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -98,12 +98,11 @@ def __init__(self, execution_id: str, working_dir: Path | None = None): # are registered here and driven in background threads. self._stream_driver = StreamDriver(execution_id) - def register_stream(self, generator: Any) -> tuple[str, int]: + def register_stream(self, generator: Any) -> str: """Callback for ``serialize_value(on_generator=...)``. Registers a generator with this execution's driver and returns the - ``(execution_id, sequence)`` stream reference to embed in the - serialized value. + opaque stream ``id`` to embed in the serialized value. """ return self._stream_driver.register(generator) diff --git a/adapters/python/coflux/models.py b/adapters/python/coflux/models.py index 3048a87d..193d76d8 100644 --- a/adapters/python/coflux/models.py +++ b/adapters/python/coflux/models.py @@ -215,7 +215,7 @@ class Stream(t.Iterable[T]): Iterating a ``Stream`` opens a subscription with the server; items arrive pushed over the WebSocket and yield from the iterator. Each ``__iter__`` - starts a fresh subscription from position 0, so a stream can be iterated + starts a fresh subscription from sequence 0, so a stream can be iterated multiple times and each iteration sees the whole sequence. ``partition`` and ``slice`` return new ``Stream`` views with an additional @@ -224,46 +224,40 @@ class Stream(t.Iterable[T]): def __init__( self, - producer_execution_id: str, - sequence: int, + id: str, filters: tuple[dict[str, t.Any], ...] = (), ): - self._producer_execution_id = producer_execution_id - self._sequence = sequence + # Opaque identifier of the form ``_``. + # Users may see this in the CLI/Studio but shouldn't need to parse it. + self._id = id self._filters = filters @property - def producer_execution_id(self) -> str: - return self._producer_execution_id - - @property - def sequence(self) -> int: - return self._sequence + def id(self) -> str: + return self._id def partition(self, n: int, i: int) -> "Stream[T]": - """Return a view of this stream where only positions ``p`` with - ``p % n == i`` are delivered. Round-robin partitioning for parallel + """Return a view of this stream where only sequences ``s`` with + ``s % n == i`` are delivered. Round-robin partitioning for parallel consumers. """ if n < 1 or i < 0 or i >= n: raise ValueError(f"invalid partition args: n={n}, i={i}") return Stream( - self._producer_execution_id, - self._sequence, + self._id, self._filters + ({"type": "partition", "n": n, "i": i},), ) def slice(self, start: int, stop: int | None = None) -> "Stream[T]": - """Return a view of this stream restricted to positions ``[start, stop)``. + """Return a view of this stream restricted to sequences ``[start, stop)``. ``stop=None`` means unbounded. Equivalent to ``itertools.islice`` on - the source stream's positions. + the source stream's items. """ if start < 0 or (stop is not None and stop < start): raise ValueError(f"invalid slice args: start={start}, stop={stop}") return Stream( - self._producer_execution_id, - self._sequence, + self._id, self._filters + ({"type": "slice", "start": start, "stop": stop},), ) @@ -272,8 +266,4 @@ def __iter__(self) -> t.Iterator[T]: # which imports models for Execution/Input/Asset). from .streams import open_subscription - return open_subscription( - self._producer_execution_id, - self._sequence, - self._filters, - ) + return open_subscription(self._id, self._filters) diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index 14000574..dc78bdc4 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -460,34 +460,36 @@ def send_metric( get_protocol().send_message("metric", params) -def send_stream_register(execution_id: str, sequence: int) -> None: +def send_stream_register(execution_id: str, index: int) -> None: """Register a stream owned by this execution. - Sequence is worker-assigned and monotonic per execution (0, 1, 2, ...). + ``index`` is worker-assigned and monotonic per execution (0, 1, 2, ...); + it identifies the stream within its producer execution. """ get_protocol().send_message( "stream_register", - {"execution_id": execution_id, "sequence": sequence}, + {"execution_id": execution_id, "index": index}, ) def send_stream_append( execution_id: str, + index: int, sequence: int, - position: int, value: dict[str, Any], ) -> None: - """Append an item to a stream at the given (worker-assigned) position. + """Append an item to a stream. - Position is monotonic per stream (0, 1, 2, ...). Value uses the same - Value shape as execution results (type + format + value/path + refs). + ``sequence`` is monotonic per stream (0, 1, 2, ...); it identifies the + item within its stream. Value uses the same Value shape as execution + results (type + format + value/path + refs). """ get_protocol().send_message( "stream_append", { "execution_id": execution_id, + "index": index, "sequence": sequence, - "position": position, "value": value, }, ) @@ -495,7 +497,7 @@ def send_stream_append( def send_stream_close( execution_id: str, - sequence: int, + index: int, error_type: str | None = None, error_message: str = "", traceback: str = "", @@ -508,7 +510,7 @@ def send_stream_close( """ params: dict[str, Any] = { "execution_id": execution_id, - "sequence": sequence, + "index": index, } if error_type is not None: params["error"] = { @@ -523,8 +525,8 @@ def send_stream_subscribe( execution_id: str, subscription_id: int, producer_execution_id: str, - sequence: int, - from_position: int, + index: int, + from_sequence: int, filter: dict[str, Any] | None = None, ) -> None: """Open a consumer subscription to a stream owned by another execution. @@ -536,8 +538,8 @@ def send_stream_subscribe( "execution_id": execution_id, "subscription_id": subscription_id, "producer_execution_id": producer_execution_id, - "sequence": sequence, - "from_position": from_position, + "index": index, + "from_sequence": from_sequence, } if filter is not None: params["filter"] = filter diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index d7cd6405..2f83be2e 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -57,8 +57,8 @@ def _encode_value( the path. Used for pickle fragment references. on_generator: Callback invoked for each generator encountered. Should register the generator (spawn its driver) and return the - `(execution_id, sequence)` identifying the stream. If None, - encountering a generator raises TypeError. + stream's opaque `id`. If None, encountering a generator + raises TypeError. Returns: Tuple of (data, references) where data is JSON-serializable and @@ -74,22 +74,14 @@ def _encode(v: Any) -> Any: raise TypeError( "Cannot serialize a generator: no stream driver is active." ) - execution_id, sequence = on_generator(v) - return { - "type": "stream", - "execution_id": execution_id, - "sequence": sequence, - } + stream_id = on_generator(v) + return {"type": "stream", "id": stream_id} elif isinstance(v, Stream): # Pass-through: a Stream handle received from another execution # (possibly with partition/slice filters layered on top) is # being forwarded as an argument. Preserve the filter chain so # the downstream consumer subscribes with the same filters. - encoded: dict[str, Any] = { - "type": "stream", - "execution_id": v.producer_execution_id, - "sequence": v.sequence, - } + encoded: dict[str, Any] = {"type": "stream", "id": v.id} if v._filters: encoded["filters"] = list(v._filters) return encoded @@ -268,12 +260,13 @@ def _decode(v: Any) -> Any: elif t == "ref": return _resolve_ref(v["index"]) elif t == "stream": - # Producer-owned stream reference. Self-contained — - # execution_id, sequence, and any filter chain (when the - # Stream was forwarded with partition/slice filters - # already applied) are all in the descriptor. + # Producer-owned stream reference. Self-contained — the + # opaque `id` encodes the producer's execution id + the + # stream's index, and any filter chain (when the Stream + # was forwarded with partition/slice filters already + # applied) rides alongside. filters = tuple(v.get("filters") or ()) - return Stream(v["execution_id"], v["sequence"], filters) + return Stream(v["id"], filters) else: return v else: diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index 5f105550..b09fa170 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -42,12 +42,12 @@ class StreamDriver: def __init__(self, execution_id: str) -> None: self._execution_id = execution_id - self._next_sequence = 0 + self._next_index = 0 self._threads: list[threading.Thread] = [] self._generators: list[Any] = [] self._lock = threading.Lock() - def register(self, generator: Any) -> tuple[str, int]: + def register(self, generator: Any) -> str: """Register a generator and start running it in a worker thread. Accepts both sync generators (``def`` + ``yield``) and async @@ -55,21 +55,21 @@ def register(self, generator: Any) -> tuple[str, int]: async generators run inside a fresh event loop confined to that thread. - Returns ``(execution_id, sequence)`` for embedding in the serialized - value as a stream reference. + Returns the stream's opaque ``id`` (``_``) + for embedding in the serialized value as a stream reference. """ with self._lock: - sequence = self._next_sequence - self._next_sequence += 1 + index = self._next_index + self._next_index += 1 - protocol.send_stream_register(self._execution_id, sequence) + protocol.send_stream_register(self._execution_id, index) is_async = inspect.isasyncgen(generator) target = self._run_async if is_async else self._run thread = threading.Thread( target=target, - args=(sequence, generator), - name=f"stream-{self._execution_id}-{sequence}", + args=(index, generator), + name=f"stream-{self._execution_id}-{index}", daemon=False, ) entry = {"generator": generator, "is_async": is_async, "loop": None} @@ -78,21 +78,21 @@ def register(self, generator: Any) -> tuple[str, int]: self._threads.append(thread) thread.start() - return self._execution_id, sequence + return compose_stream_id(self._execution_id, index) - def _run(self, sequence: int, generator: Any) -> None: + def _run(self, index: int, generator: Any) -> None: """Run one sync generator to exhaustion (or error).""" - position = 0 + sequence = 0 try: for item in generator: serialized = serialize_value(item) protocol.send_stream_append( self._execution_id, + index, sequence, - position, serialized, ) - position += 1 + sequence += 1 except GeneratorExit: # Generator explicitly closed (via close_all on error path, or # server-initiated cancel). Skip send_stream_close — the server @@ -104,15 +104,15 @@ def _run(self, sequence: int, generator: Any) -> None: tb = traceback.format_exc() protocol.send_stream_close( self._execution_id, - sequence, + index, error_type=error_type, error_message=str(e), traceback=tb, ) else: - protocol.send_stream_close(self._execution_id, sequence) + protocol.send_stream_close(self._execution_id, index) - def _run_async(self, sequence: int, generator: Any) -> None: + def _run_async(self, index: int, generator: Any) -> None: """Run one async generator in a fresh event loop on this thread. The loop handle is recorded so ``close_all`` can schedule aclose() @@ -123,16 +123,16 @@ def _run_async(self, sequence: int, generator: Any) -> None: asyncio.set_event_loop(loop) async def iterate() -> None: - position = 0 + sequence = 0 async for item in generator: serialized = serialize_value(item) protocol.send_stream_append( self._execution_id, + index, sequence, - position, serialized, ) - position += 1 + sequence += 1 try: loop.run_until_complete(iterate()) @@ -143,13 +143,13 @@ async def iterate() -> None: tb = traceback.format_exc() protocol.send_stream_close( self._execution_id, - sequence, + index, error_type=error_type, error_message=str(e), traceback=tb, ) else: - protocol.send_stream_close(self._execution_id, sequence) + protocol.send_stream_close(self._execution_id, index) finally: try: loop.run_until_complete(generator.aclose()) @@ -229,14 +229,14 @@ def __init__(self, subscription_id: int, execution_id: str) -> None: def on_items(self, items: list[list[Any]]) -> None: """Called by the registry when the server pushes items for this - subscription. ``items`` is a list of ``[position, value_wire]``. + subscription. ``items`` is a list of ``[sequence, value_wire]``. Runs on the dispatcher reader thread — keep it cheap. The raw wire value goes onto the queue unmodified; deserialization happens in ``__next__`` on the consumer's thread so heavy decode work doesn't stall stdin reads. """ - for _position, value in items: + for _sequence, value in items: self._queue.put(value) def on_closed(self, error: dict[str, Any] | None) -> None: @@ -345,9 +345,27 @@ def _stream_registry() -> StreamRegistry: return _registry_instance +def compose_stream_id(execution_id: str, index: int) -> str: + """Build the opaque stream id from its two components. + + Joined with ``_`` because the alternatives are overloaded: ``:`` is + used inside the execution id, ``#`` is used for attempt numbers, ``/`` + separates module/target. Execution ids use only alphanumerics, so + ``rpartition('_')`` is unambiguous on the parse side. + """ + return f"{execution_id}_{index}" + + +def parse_stream_id(id: str) -> tuple[str, int]: + """Reverse of ``compose_stream_id``. Raises ValueError on bad input.""" + exec_id, sep, index = id.rpartition("_") + if not sep or not exec_id: + raise ValueError(f"invalid stream id: {id!r}") + return exec_id, int(index) + + def open_subscription( - producer_execution_id: str, - sequence: int, + stream_id: str, filters: tuple[dict[str, Any], ...], ) -> Iterator[Any]: """Begin iterating a stream. Called by ``Stream.__iter__``. @@ -359,12 +377,16 @@ def open_subscription( execution_id = ctx.execution_id subscription_id, iterator = _stream_registry().allocate(execution_id) + # Split the opaque id for the wire message, which still takes + # producer_execution_id + index positionally. + producer_execution_id, index = parse_stream_id(stream_id) + filter = _compose_filter(filters) protocol.send_stream_subscribe( execution_id, subscription_id, producer_execution_id, - sequence, + index, 0, filter, ) diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index eb1903c1..ae3b8473 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -254,18 +254,20 @@ type RegisterGroupParams struct { } // StreamRegisterParams for stream_register notification. -// Sequence is worker-assigned, monotonic per execution. +// Index is worker-assigned, monotonic per execution — it identifies the +// stream within its producer execution. type StreamRegisterParams struct { ExecutionID string `json:"execution_id"` - Sequence int `json:"sequence"` + Index int `json:"index"` } // StreamAppendParams for stream_append notification. -// Position is worker-assigned, monotonic per stream. +// Sequence is worker-assigned, monotonic per stream — it identifies the +// item within its stream. type StreamAppendParams struct { ExecutionID string `json:"execution_id"` + Index int `json:"index"` Sequence int `json:"sequence"` - Position int `json:"position"` Value *Value `json:"value"` } @@ -273,7 +275,7 @@ type StreamAppendParams struct { // when the producer's generator raised an exception. type StreamCloseParams struct { ExecutionID string `json:"execution_id"` - Sequence int `json:"sequence"` + Index int `json:"index"` Error *StreamCloseError `json:"error,omitempty"` } @@ -291,8 +293,8 @@ type StreamSubscribeParams struct { ExecutionID string `json:"execution_id"` // consumer SubscriptionID int `json:"subscription_id"` ProducerExecutionID string `json:"producer_execution_id"` - Sequence int `json:"sequence"` - FromPosition int `json:"from_position"` + Index int `json:"index"` + FromSequence int `json:"from_sequence"` Filter map[string]any `json:"filter,omitempty"` } @@ -303,7 +305,7 @@ type StreamUnsubscribeParams struct { } // StreamItemsParams for stream_items notification pushed CLI → adapter. -// Items are [[position, value], ...] where value is a wire Value. +// Items are [[sequence, value], ...] where value is a wire Value. type StreamItemsParams struct { ExecutionID string `json:"execution_id"` SubscriptionID int `json:"subscription_id"` diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index e722ed4b..54c77692 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -50,16 +50,18 @@ type ExecutionHandler interface { // NotifyTerminated notifies the server that an execution's process has exited NotifyTerminated(ctx context.Context, executionID string) error // StreamRegister declares a new stream owned by an execution. - // Sequence is worker-assigned, monotonic per execution. - StreamRegister(ctx context.Context, executionID string, sequence int) error - // StreamAppend appends an item to a stream at the given (worker-assigned) position. - StreamAppend(ctx context.Context, executionID string, sequence int, position int, value *adapter.Value) error + // Index is worker-assigned, monotonic per execution — it identifies + // the stream within its producer execution. + StreamRegister(ctx context.Context, executionID string, index int) error + // StreamAppend appends an item to a stream. Sequence is worker-assigned, + // monotonic per stream — it identifies the item within its stream. + StreamAppend(ctx context.Context, executionID string, index int, sequence int, value *adapter.Value) error // StreamClose closes a stream. Error is nil for a clean close, or a (type, message, traceback) // triple when the producer's generator raised. - StreamClose(ctx context.Context, executionID string, sequence int, err *adapter.StreamCloseError) error + StreamClose(ctx context.Context, executionID string, index int, err *adapter.StreamCloseError) error // StreamSubscribe opens a consumer subscription to a stream owned by another execution. // Filter is nil or a {"type": "slice", ...}/{"type": "partition", ...} map. - StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, sequence int, fromPosition int, filter map[string]any) error + StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, filter map[string]any) error // StreamUnsubscribe drops a consumer subscription. StreamUnsubscribe(ctx context.Context, executionID string, subscriptionID int) error } @@ -484,7 +486,7 @@ func (p *Pool) handleStreamRegister(ctx context.Context, executionID string, par return } - if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Sequence); err != nil { + if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Index); err != nil { logger.Error("failed to register stream", "error", err) } } @@ -496,7 +498,7 @@ func (p *Pool) handleStreamAppend(ctx context.Context, executionID string, param return } - if err := p.handler.StreamAppend(ctx, req.ExecutionID, req.Sequence, req.Position, req.Value); err != nil { + if err := p.handler.StreamAppend(ctx, req.ExecutionID, req.Index, req.Sequence, req.Value); err != nil { logger.Error("failed to append stream item", "error", err) } } @@ -508,7 +510,7 @@ func (p *Pool) handleStreamClose(ctx context.Context, executionID string, params return } - if err := p.handler.StreamClose(ctx, req.ExecutionID, req.Sequence, req.Error); err != nil { + if err := p.handler.StreamClose(ctx, req.ExecutionID, req.Index, req.Error); err != nil { logger.Error("failed to close stream", "error", err) } } @@ -525,8 +527,8 @@ func (p *Pool) handleStreamSubscribe(ctx context.Context, executionID string, pa req.ExecutionID, req.SubscriptionID, req.ProducerExecutionID, - req.Sequence, - req.FromPosition, + req.Index, + req.FromSequence, req.Filter, ); err != nil { logger.Error("failed to subscribe to stream", "error", err) diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index 90d03ec4..0e48420e 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -651,8 +651,8 @@ func (w *Worker) handleAbort(params []any) error { // handleStreamItems forwards a server-pushed batch of stream items to the // adapter process owning the target execution. Params: [execution_id, -// subscription_id, items]. Each item arrives as [position, value_array] -// and is converted to [position, adapter.Value dict] so the Python side +// subscription_id, items]. Each item arrives as [sequence, value_array] +// and is converted to [sequence, adapter.Value dict] so the Python side // can deserialize_value it directly. func (w *Worker) handleStreamItems(params []any) error { if len(params) < 3 { @@ -1181,15 +1181,15 @@ func (w *Worker) RegisterGroup(ctx context.Context, executionID string, groupID return conn.Notify("register_group", executionID, groupID, name) } -func (w *Worker) StreamRegister(ctx context.Context, executionID string, sequence int) error { +func (w *Worker) StreamRegister(ctx context.Context, executionID string, index int) error { conn, err := w.requireConn() if err != nil { return err } - return conn.Notify("stream_register", executionID, sequence) + return conn.Notify("stream_register", executionID, index) } -func (w *Worker) StreamAppend(ctx context.Context, executionID string, sequence int, position int, value *adapter.Value) error { +func (w *Worker) StreamAppend(ctx context.Context, executionID string, index int, sequence int, value *adapter.Value) error { conn, err := w.requireConn() if err != nil { return err @@ -1199,10 +1199,10 @@ func (w *Worker) StreamAppend(ctx context.Context, executionID string, sequence if err != nil { return err } - return conn.Notify("stream_append", executionID, sequence, position, serverValue) + return conn.Notify("stream_append", executionID, index, sequence, serverValue) } -func (w *Worker) StreamClose(ctx context.Context, executionID string, sequence int, streamErr *adapter.StreamCloseError) error { +func (w *Worker) StreamClose(ctx context.Context, executionID string, index int, streamErr *adapter.StreamCloseError) error { conn, err := w.requireConn() if err != nil { return err @@ -1215,16 +1215,16 @@ func (w *Worker) StreamClose(ctx context.Context, executionID string, sequence i frames := parseTraceback(streamErr.Traceback) errTuple = []any{streamErr.Type, streamErr.Message, frames} } - return conn.Notify("stream_close", executionID, sequence, errTuple) + return conn.Notify("stream_close", executionID, index, errTuple) } -func (w *Worker) StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, sequence int, fromPosition int, filter map[string]any) error { +func (w *Worker) StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, filter map[string]any) error { conn, err := w.requireConn() if err != nil { return err } - // Params: [subscription_id, consumer_execution_id, producer_execution_id, sequence, from_position, filter] - return conn.Notify("stream_subscribe", subscriptionID, executionID, producerExecutionID, sequence, fromPosition, filter) + // Params: [subscription_id, consumer_execution_id, producer_execution_id, index, from_sequence, filter] + return conn.Notify("stream_subscribe", subscriptionID, executionID, producerExecutionID, index, fromSequence, filter) } func (w *Worker) StreamUnsubscribe(ctx context.Context, executionID string, subscriptionID int) error { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index 177daedd..50702d32 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -254,10 +254,10 @@ defmodule Coflux.Handlers.Worker do end "stream_register" -> - [execution_id, sequence] = message["params"] + [execution_id, index] = message["params"] if is_recognised_execution?(execution_id, state) do - case Orchestration.register_stream(state.project_id, execution_id, sequence) do + case Orchestration.register_stream(state.project_id, execution_id, index) do :ok -> {[], state} # Idempotent — a duplicate register is harmless. {:error, :already_registered} -> {[], state} @@ -268,14 +268,14 @@ defmodule Coflux.Handlers.Worker do end "stream_append" -> - [execution_id, sequence, position, value] = message["params"] + [execution_id, index, sequence, value] = message["params"] if is_recognised_execution?(execution_id, state) do case Orchestration.append_stream_item( state.project_id, execution_id, + index, sequence, - position, parse_value(value) ) do :ok -> @@ -303,7 +303,7 @@ defmodule Coflux.Handlers.Worker do end "stream_close" -> - [execution_id, sequence, error] = message["params"] + [execution_id, index, error] = message["params"] if is_recognised_execution?(execution_id, state) do parsed_error = @@ -315,7 +315,7 @@ defmodule Coflux.Handlers.Worker do case Orchestration.close_stream( state.project_id, execution_id, - sequence, + index, parsed_error ) do :ok -> {[], state} @@ -332,8 +332,8 @@ defmodule Coflux.Handlers.Worker do subscription_id, consumer_execution_id, producer_execution_id, - sequence, - from_position, + index, + from_sequence, filter ] = message["params"] @@ -344,8 +344,8 @@ defmodule Coflux.Handlers.Worker do subscription_id, consumer_execution_id, producer_execution_id, - sequence, - from_position, + index, + from_sequence, filter ) do :ok -> @@ -629,11 +629,11 @@ defmodule Coflux.Handlers.Worker do end def websocket_info({:stream_items, execution_external_id, subscription_id, items}, state) do - # Items arrive in resolved form ([[position, value_tuple], ...]); compose + # Items arrive in resolved form ([[sequence, value_tuple], ...]); compose # each value tuple to wire JSON here. encoded = - Enum.map(items, fn [position, value] -> - [position, compose_value(value)] + Enum.map(items, fn [sequence, value] -> + [sequence, compose_value(value)] end) {[command_message("stream_items", [execution_external_id, subscription_id, encoded])], state} diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index 8ac13dc6..1053dac0 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -182,19 +182,20 @@ defmodule Coflux.Orchestration do end # Stream producer messages — worker registers a stream, appends items, - # and closes the stream. Sequence and position are worker-assigned and - # monotonic per-execution / per-stream. + # and closes the stream. `index` identifies the stream within its + # producer execution; `sequence` identifies an item within the stream. + # Both are worker-assigned and monotonic from 0. - def register_stream(project_id, execution_id, sequence) do - call_server(project_id, {:register_stream, execution_id, sequence}) + def register_stream(project_id, execution_id, index) do + call_server(project_id, {:register_stream, execution_id, index}) end - def append_stream_item(project_id, execution_id, sequence, position, value) do - call_server(project_id, {:append_stream_item, execution_id, sequence, position, value}) + def append_stream_item(project_id, execution_id, index, sequence, value) do + call_server(project_id, {:append_stream_item, execution_id, index, sequence, value}) end - def close_stream(project_id, execution_id, sequence, error) do - call_server(project_id, {:close_stream, execution_id, sequence, error}) + def close_stream(project_id, execution_id, index, error) do + call_server(project_id, {:close_stream, execution_id, index, error}) end # Stream consumer messages — consumer opens a subscription to receive @@ -207,14 +208,14 @@ defmodule Coflux.Orchestration do subscription_id, consumer_execution_id, producer_execution_id, - sequence, - from_position, + index, + from_sequence, filter ) do call_server( project_id, {:subscribe_stream, session_id, subscription_id, consumer_execution_id, - producer_execution_id, sequence, from_position, filter} + producer_execution_id, index, from_sequence, filter} ) end @@ -280,10 +281,10 @@ defmodule Coflux.Orchestration do call_server(project_id, {:subscribe_run, run_id, pid}) end - def subscribe_stream_topic(project_id, execution_external_id, sequence, pid) do + def subscribe_stream_topic(project_id, execution_external_id, index, pid) do call_server( project_id, - {:subscribe_stream_topic, execution_external_id, sequence, pid} + {:subscribe_stream_topic, execution_external_id, index, pid} ) end diff --git a/server/lib/coflux/orchestration/epoch.ex b/server/lib/coflux/orchestration/epoch.ex index e7edc0f1..099cb66c 100644 --- a/server/lib/coflux/orchestration/epoch.ex +++ b/server/lib/coflux/orchestration/epoch.ex @@ -339,15 +339,15 @@ defmodule Coflux.Orchestration.Epoch do {:ok, streams} = query( source_db, - "SELECT sequence, created_at FROM streams WHERE execution_id = ?1", + "SELECT `index`, created_at FROM streams WHERE execution_id = ?1", {old_exec_id} ) - Enum.each(streams, fn {sequence, stream_created_at} -> + Enum.each(streams, fn {index, stream_created_at} -> {:ok, _} = insert_one(target_db, :streams, %{ execution_id: new_exec_id, - sequence: sequence, + index: index, created_at: stream_created_at }) @@ -355,21 +355,21 @@ defmodule Coflux.Orchestration.Epoch do query( source_db, """ - SELECT position, value_id, created_at + SELECT sequence, value_id, created_at FROM stream_items - WHERE execution_id = ?1 AND sequence = ?2 + WHERE execution_id = ?1 AND `index` = ?2 """, - {old_exec_id, sequence} + {old_exec_id, index} ) - Enum.each(items, fn {position, value_id, item_created_at} -> + Enum.each(items, fn {sequence, value_id, item_created_at} -> new_value_id = ensure_value(source_db, target_db, value_id) {:ok, _} = insert_one(target_db, :stream_items, %{ execution_id: new_exec_id, + index: index, sequence: sequence, - position: position, value_id: new_value_id, created_at: item_created_at }) @@ -377,17 +377,18 @@ defmodule Coflux.Orchestration.Epoch do case query_one( source_db, - "SELECT error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", - {old_exec_id, sequence} + "SELECT reason, error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND `index` = ?2", + {old_exec_id, index} ) do - {:ok, {error_id, closure_created_at}} -> + {:ok, {reason, error_id, closure_created_at}} -> new_error_id = if error_id, do: ensure_error(source_db, target_db, error_id) {:ok, _} = insert_one(target_db, :stream_closures, %{ execution_id: new_exec_id, - sequence: sequence, + index: index, + reason: reason, error_id: new_error_id, created_at: closure_created_at }) diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index e31864ce..5eac0732 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -132,8 +132,8 @@ defmodule Coflux.Orchestration.Server do # # stream_subscriptions: {consumer_execution_id, subscription_id} -> # %{consumer_execution_external_id, producer_execution_id, - # sequence, cursor, filter} - # stream_subscribers: {producer_execution_id, sequence} -> MapSet of + # index, cursor, filter} + # stream_subscribers: {producer_execution_id, index} -> MapSet of # {consumer_execution_id, subscription_id} stream_subscriptions: %{}, stream_subscribers: %{} @@ -1830,13 +1830,13 @@ defmodule Coflux.Orchestration.Server do end end - def handle_call({:register_stream, execution_external_id, sequence}, _from, state) do + def handle_call({:register_stream, execution_external_id, index}, _from, state) do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> - case Streams.register_stream(state.db, execution_id, sequence) do + case Streams.register_stream(state.db, execution_id, index) do {:ok, created_at} -> state = - notify_stream_opened(state, execution_id, sequence, created_at) + notify_stream_opened(state, execution_id, index, created_at) |> flush_notifications() {:reply, :ok, state} @@ -1851,7 +1851,7 @@ defmodule Coflux.Orchestration.Server do end def handle_call( - {:append_stream_item, execution_external_id, sequence, position, value}, + {:append_stream_item, execution_external_id, index, sequence, value}, _from, state ) do @@ -1860,18 +1860,18 @@ defmodule Coflux.Orchestration.Server do case Streams.append_item( state.db, execution_id, + index, sequence, - position, normalize_value(value) ) do {:ok, created_at} -> state = state - |> push_stream_item(execution_id, sequence, position, value) + |> push_stream_item(execution_id, index, sequence, value) |> notify_stream_item_appended( execution_id, + index, sequence, - position, value, created_at ) @@ -1888,21 +1888,24 @@ defmodule Coflux.Orchestration.Server do end end - def handle_call({:close_stream, execution_external_id, sequence, error}, _from, state) do + def handle_call({:close_stream, execution_external_id, index, error}, _from, state) do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> - spec = + {spec, reason} = case error do - nil -> :complete - {type, message, frames} -> {:errored, type, message, frames} + nil -> + {:complete, :complete} + + {type, message, frames} -> + {{:errored, type, message, frames}, :errored} end - case Streams.close_stream(state.db, execution_id, sequence, spec) do + case Streams.close_stream(state.db, execution_id, index, spec) do {:ok, closed_at} -> state = state - |> push_stream_closed(execution_id, sequence, error) - |> notify_stream_closed(execution_id, sequence, error, closed_at) + |> push_stream_closed(execution_id, index, error) + |> notify_stream_closed(execution_id, index, reason, error, closed_at) |> flush_notifications() {:reply, :ok, state} @@ -1918,7 +1921,7 @@ defmodule Coflux.Orchestration.Server do def handle_call( {:subscribe_stream, session_external_id, subscription_id, consumer_execution_external_id, - producer_execution_external_id, sequence, from_position, filter}, + producer_execution_external_id, index, from_sequence, filter}, _from, state ) do @@ -1937,14 +1940,14 @@ defmodule Coflux.Orchestration.Server do Map.fetch(state.execution_ids, consumer_execution_external_id) |> ok_or(:consumer_not_found), {:ok, producer_execution_id} <- producer_result, - {:ok, true} <- Streams.exists?(state.db, producer_execution_id, sequence), + {:ok, true} <- Streams.exists?(state.db, producer_execution_id, index), key = {consumer_execution_id, subscription_id}, false <- Map.has_key?(state.stream_subscriptions, key) do subscription = %{ consumer_execution_external_id: consumer_execution_external_id, producer_execution_id: producer_execution_id, - sequence: sequence, - cursor: from_position, + index: index, + cursor: from_sequence, filter: filter } @@ -1954,7 +1957,7 @@ defmodule Coflux.Orchestration.Server do |> Map.update!(:stream_subscribers, fn m -> Map.update( m, - {producer_execution_id, sequence}, + {producer_execution_id, index}, MapSet.new([key]), &MapSet.put(&1, key) ) @@ -2685,14 +2688,14 @@ defmodule Coflux.Orchestration.Server do end def handle_call( - {:subscribe_stream_topic, execution_external_id, sequence, pid}, + {:subscribe_stream_topic, execution_external_id, index, pid}, _from, state ) do - case build_stream_topic_initial(state, execution_external_id, sequence) do + case build_stream_topic_initial(state, execution_external_id, index) do {:ok, initial} -> {:ok, ref, state} = - add_listener(state, {:stream, execution_external_id, sequence}, pid) + add_listener(state, {:stream, execution_external_id, index}, pid) {:reply, {:ok, initial, ref}, state} @@ -5904,15 +5907,15 @@ defmodule Coflux.Orchestration.Server do # on the closure row. Consumers that need to surface an error derive # it from the execution's recorded result (see derive_lifecycle_error). defp close_open_streams(state, execution_id) do - {:ok, sequences} = Streams.get_open_streams_for_execution(state.db, execution_id) + {:ok, indexes} = Streams.get_open_streams_for_execution(state.db, execution_id) push_error = derive_lifecycle_error(state.db, execution_id) - Enum.reduce(sequences, state, fn sequence, state -> - case Streams.close_stream(state.db, execution_id, sequence, :lifecycle) do + Enum.reduce(indexes, state, fn index, state -> + case Streams.close_stream(state.db, execution_id, index, :lifecycle) do {:ok, closed_at} -> state - |> push_stream_closed(execution_id, sequence, push_error) - |> notify_stream_closed(execution_id, sequence, push_error, closed_at) + |> push_stream_closed(execution_id, index, push_error) + |> notify_stream_closed(execution_id, index, :lifecycle, push_error, closed_at) {:error, :already_closed} -> state @@ -5922,21 +5925,22 @@ defmodule Coflux.Orchestration.Server do # Returns the streams list for `execution_id` with :lifecycle closures' # errors resolved from the execution's recorded result. Shape: - # `{sequence, opened_at, closed_at | nil, error | nil}` — the same - # shape the topic module expects (reason is collapsed into error/nil - # once we've derived it). + # `{index, opened_at, closed_at | nil, reason | nil, error | nil}` — + # reason is retained so the topic can colour open vs complete vs + # errored vs lifecycle distinctly. defp streams_with_resolved_errors(db, execution_id) do {:ok, rows} = Streams.get_streams_with_closures_for_execution(db, execution_id) Enum.map(rows, fn - {sequence, opened_at, nil, nil, nil} -> - {sequence, opened_at, nil, nil} + {index, opened_at, nil, nil, nil} -> + {index, opened_at, nil, nil, nil} - {sequence, opened_at, closed_at, :lifecycle, _} -> - {sequence, opened_at, closed_at, derive_lifecycle_error(db, execution_id)} + {index, opened_at, closed_at, :lifecycle, _} -> + {index, opened_at, closed_at, :lifecycle, + derive_lifecycle_error(db, execution_id)} - {sequence, opened_at, closed_at, _reason, error} -> - {sequence, opened_at, closed_at, error} + {index, opened_at, closed_at, reason, error} -> + {index, opened_at, closed_at, reason, error} end) end @@ -7245,14 +7249,14 @@ defmodule Coflux.Orchestration.Server do # explicit pause/resume (e.g. an infinite producer with no subscribers # filling disk), the hooks go here: # * On first subscriber for a stream: send_session(producer_session, - # {:stream_resume, producer_exec_ext, sequence}). + # {:stream_resume, producer_exec_ext, index}). # * On last subscriber dropping: {:stream_pause, ...}. # Dispatcher on the adapter side already routes notifications by method; # the StreamDriver gets a per-stream Event to gate its next() calls. # --- Stream topic notifications (for Studio subscribers) --- # These flow through `notify_listeners` → the run topic and the - # per-stream `{:stream, execution_ext_id, sequence}` inspection topic, + # per-stream `{:stream, execution_ext_id, index}` inspection topic, # distinct from the session-directed `push_stream_*` helpers which # target subscribed consumer sessions' WebSockets. @@ -7261,27 +7265,27 @@ defmodule Coflux.Orchestration.Server do # items on demand. @stream_topic_tail_size 200 - defp notify_stream_opened(state, execution_id, sequence, created_at) do + defp notify_stream_opened(state, execution_id, index, created_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) notify_listeners( state, {:run, r}, - {:stream_opened, execution_ext_id, sequence, created_at} + {:stream_opened, execution_ext_id, index, created_at} ) end defp notify_stream_item_appended( state, execution_id, + index, sequence, - position, value, created_at ) do {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) - topic = {:stream, execution_ext_id, sequence} + topic = {:stream, execution_ext_id, index} # Skip the build_value (which hits the DB to resolve refs) when the # inspection topic has no active subscribers. @@ -7291,14 +7295,14 @@ defmodule Coflux.Orchestration.Server do notify_listeners( state, topic, - {:item_appended, position, resolved, created_at} + {:item_appended, sequence, resolved, created_at} ) else state end end - defp notify_stream_closed(state, execution_id, sequence, error, closed_at) do + defp notify_stream_closed(state, execution_id, index, reason, error, closed_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) @@ -7308,17 +7312,19 @@ defmodule Coflux.Orchestration.Server do {type, message, _frames} -> %{type: type, message: message} end + reason_str = Atom.to_string(reason) + state = notify_listeners( state, {:run, r}, - {:stream_closed, execution_ext_id, sequence, encoded_error, closed_at} + {:stream_closed, execution_ext_id, index, reason_str, encoded_error, closed_at} ) notify_listeners( state, - {:stream, execution_ext_id, sequence}, - {:closed, encoded_error, closed_at} + {:stream, execution_ext_id, index}, + {:closed, reason_str, encoded_error, closed_at} ) end @@ -7326,36 +7332,28 @@ defmodule Coflux.Orchestration.Server do # Returns {:ok, state} with producer metadata, opened/closed timestamps, # closure info (with lifecycle errors already derived), bounded tail of # items, and the total item count. - defp build_stream_topic_initial(state, execution_ext_id, sequence) do + defp build_stream_topic_initial(state, execution_ext_id, index) do with {:ok, execution_id} <- resolve_internal_execution_id(state, execution_ext_id), - {:ok, true} <- Streams.exists?(state.db, execution_id, sequence), - {:ok, opened_at} <- Streams.get_opened_at(state.db, execution_id, sequence), + {:ok, true} <- Streams.exists?(state.db, execution_id, index), + {:ok, opened_at} <- Streams.get_opened_at(state.db, execution_id, index), {:ok, {items, total_count}} <- - Streams.get_stream_tail(state.db, execution_id, sequence, @stream_topic_tail_size) do + Streams.get_stream_tail(state.db, execution_id, index, @stream_topic_tail_size) do # Keep the tuple shape here — the topic module runs TopicUtils.build_value # on each item's value to produce the JSON-encodable form, matching # how live :item_appended notifications are handled. resolved_items = - Enum.map(items, fn {position, value, created_at} -> - {position, build_value(value, state.db), created_at} + Enum.map(items, fn {sequence, value, created_at} -> + {sequence, build_value(value, state.db), created_at} end) - first_position = - case resolved_items do - [] -> nil - [{pos, _, _} | _] -> pos - end - - closure = build_stream_topic_closure(state, execution_id, sequence) + closure = build_stream_topic_closure(state, execution_id, index) {:ok, %{ producer: build_stream_producer(state.db, execution_ext_id, execution_id), openedAt: opened_at, - closedAt: closure && closure.closedAt, closure: closure, items: resolved_items, - firstPosition: first_position, totalCount: total_count, tailSize: @stream_topic_tail_size }} @@ -7365,8 +7363,8 @@ defmodule Coflux.Orchestration.Server do end end - defp build_stream_topic_closure(state, execution_id, sequence) do - case Streams.get_stream_closure(state.db, execution_id, sequence) do + defp build_stream_topic_closure(state, execution_id, index) do + case Streams.get_stream_closure(state.db, execution_id, index) do {:ok, nil} -> nil @@ -7392,16 +7390,11 @@ defmodule Coflux.Orchestration.Server do end defp build_stream_producer(db, execution_ext_id, execution_id) do + # The external execution_id already encodes run + step + attempt, so + # the producer reference only carries identifier + module/target — + # matching ExecutionReference on the wire. {:ok, step} = Runs.get_step_for_execution(db, execution_id) - {:ok, {run_ext_id}} = Runs.get_external_run_id_for_execution(db, execution_id) - - %{ - executionId: execution_ext_id, - runId: run_ext_id, - stepId: "#{run_ext_id}:#{step.number}", - module: step.module, - target: step.target - } + Coflux.TopicUtils.build_execution({execution_ext_id, step.module, step.target}) end defp execution_external_id_for(db, execution_id) do @@ -7416,24 +7409,24 @@ defmodule Coflux.Orchestration.Server do defp ok_or({:ok, val}, _reason), do: {:ok, val} defp ok_or(:error, reason), do: {:error, reason} - # Does a `position` pass a subscription's filter? - defp filter_matches?(nil, _position), do: true + # Does a `sequence` pass a subscription's filter? + defp filter_matches?(nil, _sequence), do: true - defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => nil}, position), - do: position >= s + defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => nil}, sequence), + do: sequence >= s - defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => e}, position), - do: position >= s and position < e + defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => e}, sequence), + do: sequence >= s and sequence < e - defp filter_matches?(%{"type" => "partition", "n" => n, "i" => i}, position), - do: rem(position, n) == i + defp filter_matches?(%{"type" => "partition", "n" => n, "i" => i}, sequence), + do: rem(sequence, n) == i - defp filter_matches?(%{"type" => "chain", "filters" => fs}, position), - do: Enum.all?(fs, &filter_matches?(&1, position)) + defp filter_matches?(%{"type" => "chain", "filters" => fs}, sequence), + do: Enum.all?(fs, &filter_matches?(&1, sequence)) - defp filter_matches?(_filter, _position), do: true + defp filter_matches?(_filter, _sequence), do: true - # Is `position` past the end of the filter's effective range? + # Is `sequence` past the end of the filter's effective range? # Lets us close streams early once a slice's stop is reached. defp filter_exhausted?(%{"type" => "slice", "stop" => stop}, cursor) when is_integer(stop), do: cursor >= stop @@ -7462,7 +7455,7 @@ defmodule Coflux.Orchestration.Server do Streams.get_stream_items( state.db, sub.producer_execution_id, - sub.sequence, + sub.index, sub.cursor, @backlog_page_size ) @@ -7494,11 +7487,11 @@ defmodule Coflux.Orchestration.Server do filtered = items - |> Enum.filter(fn {position, _value, _at} -> filter_matches?(sub.filter, position) end) - |> Enum.take_while(fn {position, _, _} -> not filter_exhausted?(sub.filter, position) end) + |> Enum.filter(fn {sequence, _value, _at} -> filter_matches?(sub.filter, sequence) end) + |> Enum.take_while(fn {sequence, _, _} -> not filter_exhausted?(sub.filter, sequence) end) # Advance cursor past the page even if no items matched this filter — - # otherwise we'd re-fetch the same positions forever. + # otherwise we'd re-fetch the same sequences forever. advance_to = if filtered == [] do elem(List.last(items), 0) + 1 @@ -7511,8 +7504,8 @@ defmodule Coflux.Orchestration.Server do state else resolved_items = - Enum.map(filtered, fn {position, value, _at} -> - [position, build_value(value, state.db)] + Enum.map(filtered, fn {sequence, value, _at} -> + [sequence, build_value(value, state.db)] end) send_to_consumer( @@ -7554,20 +7547,20 @@ defmodule Coflux.Orchestration.Server do end # Push a freshly-appended item to every subscriber of this stream. - defp push_stream_item(state, producer_execution_id, sequence, position, value) do + defp push_stream_item(state, producer_execution_id, index, sequence, value) do subscribers = - Map.get(state.stream_subscribers, {producer_execution_id, sequence}, MapSet.new()) + Map.get(state.stream_subscribers, {producer_execution_id, index}, MapSet.new()) Enum.reduce(subscribers, state, fn key, state -> {_consumer_execution_id, subscription_id} = key sub = Map.fetch!(state.stream_subscriptions, key) cond do - position < sub.cursor -> - # Consumer already has this position via backlog; skip. + sequence < sub.cursor -> + # Consumer already has this sequence via backlog; skip. state - not filter_matches?(sub.filter, position) -> + not filter_matches?(sub.filter, sequence) -> state true -> @@ -7575,7 +7568,7 @@ defmodule Coflux.Orchestration.Server do # Normalise + resolve to match the form push_backlog sends; the WS # handler composes to wire JSON. resolved = build_value(normalize_value(value), state.db) - item = [position, resolved] + item = [sequence, resolved] state = send_to_consumer( @@ -7587,12 +7580,12 @@ defmodule Coflux.Orchestration.Server do state = update_in( state.stream_subscriptions[key], - &Map.put(&1, :cursor, position + 1) + &Map.put(&1, :cursor, sequence + 1) ) # If the filter is exhausted (e.g. slice reached its stop), close # the subscription early — no more items will match. - if filter_exhausted?(sub.filter, position + 1) do + if filter_exhausted?(sub.filter, sequence + 1) do state |> send_to_consumer( sub, @@ -7608,9 +7601,9 @@ defmodule Coflux.Orchestration.Server do # On close, tell every subscriber. Error is either nil (clean close) or a # {type, message, frames} triple — same shape as Streams.close_stream takes. - defp push_stream_closed(state, producer_execution_id, sequence, error) do + defp push_stream_closed(state, producer_execution_id, index, error) do subscribers = - Map.get(state.stream_subscribers, {producer_execution_id, sequence}, MapSet.new()) + Map.get(state.stream_subscribers, {producer_execution_id, index}, MapSet.new()) encoded_error = encode_stream_error(error) @@ -7658,7 +7651,7 @@ defmodule Coflux.Orchestration.Server do defp do_maybe_push_closure(state, sub, key) do {_consumer_execution_id, subscription_id} = key - case Streams.get_stream_closure(state.db, sub.producer_execution_id, sub.sequence) do + case Streams.get_stream_closure(state.db, sub.producer_execution_id, sub.index) do {:ok, nil} -> state @@ -7685,7 +7678,7 @@ defmodule Coflux.Orchestration.Server do state {:ok, sub} -> - stream_key = {sub.producer_execution_id, sub.sequence} + stream_key = {sub.producer_execution_id, sub.index} state |> Map.update!(:stream_subscriptions, &Map.delete(&1, key)) diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 1458edd5..9d5434de 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -3,14 +3,22 @@ defmodule Coflux.Orchestration.Streams do Storage for execution-produced streams. A stream is an ordered, append-only sequence of values produced by an - execution. Each stream is identified by `(execution_id, sequence)` where - sequence is assigned monotonically by the worker during return-value + execution. Each stream is identified by `(execution_id, index)` where + `index` is assigned monotonically by the worker during return-value serialisation — the worker mints ids locally, no server round-trip. + Items within a stream are identified by `sequence` — a 0-based, + monotonically increasing per-item counter. + + The SQL column is quoted with backticks (``` `index` ```) throughout + queries because `INDEX` is a SQLite keyword; at the Elixir level we + just pass `:index` as a map key — the Store helper handles quoting + for inserts. + Invariants enforced here (and by schema FKs): * A stream is owned by exactly one execution (its producer). - * Items are append-only with monotonic `position` starting at 0. + * Items are append-only with monotonic `sequence` starting at 0. * A closure is terminal — no items may be appended after one is recorded. * On execution completion / cancel / crash, every owned stream that lacks a closure receives one (clean, cancelled, or crashed). Enforced by the @@ -18,22 +26,22 @@ defmodule Coflux.Orchestration.Streams do * Re-running a producer execution creates fresh streams (new attempt ⇒ new execution_id ⇒ new rows). Consumer refs pin to the original streams. * Consumer cursors are kept in-memory only; re-run consumers subscribe - fresh from position 0. + fresh from sequence 0. """ import Coflux.Store alias Coflux.Orchestration.{Errors, Values} - # Registers a new stream owned by `execution_id` with the given `sequence` - # (monotonic per-execution, worker-assigned). Returns `{:error, :already_registered}` - # if the sequence was already used. - def register_stream(db, execution_id, sequence) do + # Registers a new stream owned by `execution_id` at `index` (monotonic + # per-execution, worker-assigned). Returns `{:error, :already_registered}` + # if the index was already used. + def register_stream(db, execution_id, index) do now = current_timestamp() case insert_one(db, :streams, %{ execution_id: execution_id, - sequence: sequence, + index: index, created_at: now }) do {:ok, _} -> {:ok, now} @@ -41,19 +49,19 @@ defmodule Coflux.Orchestration.Streams do end end - # Appends an item at `position` to the stream. Caller supplies the position + # Appends an item at `sequence` to the stream. Caller supplies the sequence # (worker-assigned, monotonic). Returns: # * `{:error, :not_registered}` if the stream doesn't exist # * `{:error, :closed}` if the stream has a closure row - # * `{:error, :already_appended}` if position collides with an existing item - def append_item(db, execution_id, sequence, position, value) do + # * `{:error, :already_appended}` if sequence collides with an existing item + def append_item(db, execution_id, index, sequence, value) do with_transaction(db, fn -> - case has_closure?(db, execution_id, sequence) do + case has_closure?(db, execution_id, index) do {:ok, true} -> {:error, :closed} {:ok, false} -> - case exists?(db, execution_id, sequence) do + case exists?(db, execution_id, index) do {:ok, false} -> {:error, :not_registered} @@ -63,8 +71,8 @@ defmodule Coflux.Orchestration.Streams do case insert_one(db, :stream_items, %{ execution_id: execution_id, + index: index, sequence: sequence, - position: position, value_id: value_id, created_at: now }) do @@ -85,9 +93,9 @@ defmodule Coflux.Orchestration.Streams do # ended (cancel/crash/abandon/error). No error is recorded here — # callers that need to surface an error derive it from the # execution's recorded result at read time. - def close_stream(db, execution_id, sequence, spec \\ :complete) do + def close_stream(db, execution_id, index, spec \\ :complete) do with_transaction(db, fn -> - case exists?(db, execution_id, sequence) do + case exists?(db, execution_id, index) do {:ok, false} -> {:error, :not_registered} @@ -97,7 +105,7 @@ defmodule Coflux.Orchestration.Streams do case insert_one(db, :stream_closures, %{ execution_id: execution_id, - sequence: sequence, + index: index, reason: reason, error_id: error_id, created_at: now @@ -129,11 +137,11 @@ defmodule Coflux.Orchestration.Streams do def reason_from_int(@reason_errored), do: :errored def reason_from_int(@reason_lifecycle), do: :lifecycle - def exists?(db, execution_id, sequence) do + def exists?(db, execution_id, index) do case query_one( db, - "SELECT 1 FROM streams WHERE execution_id = ?1 AND sequence = ?2", - {execution_id, sequence} + "SELECT 1 FROM streams WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} ) do {:ok, nil} -> {:ok, false} {:ok, {1}} -> {:ok, true} @@ -141,59 +149,59 @@ defmodule Coflux.Orchestration.Streams do end # Returns the stream's registration timestamp, or `{:error, :not_found}`. - def get_opened_at(db, execution_id, sequence) do + def get_opened_at(db, execution_id, index) do case query_one( db, - "SELECT created_at FROM streams WHERE execution_id = ?1 AND sequence = ?2", - {execution_id, sequence} + "SELECT created_at FROM streams WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} ) do {:ok, nil} -> {:error, :not_found} {:ok, {created_at}} -> {:ok, created_at} end end - def has_closure?(db, execution_id, sequence) do + def has_closure?(db, execution_id, index) do case query_one( db, - "SELECT 1 FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", - {execution_id, sequence} + "SELECT 1 FROM stream_closures WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} ) do {:ok, nil} -> {:ok, false} {:ok, {1}} -> {:ok, true} end end - # Returns `{:ok, [sequence, ...]}` for every stream owned by `execution_id`, - # in sequence order. + # Returns `{:ok, [index, ...]}` for every stream owned by `execution_id`, + # in index order. def get_streams_for_execution(db, execution_id) do case query( db, - "SELECT sequence FROM streams WHERE execution_id = ?1 ORDER BY sequence", + "SELECT `index` FROM streams WHERE execution_id = ?1 ORDER BY `index`", {execution_id} ) do {:ok, rows} -> - {:ok, Enum.map(rows, fn {sequence} -> sequence end)} + {:ok, Enum.map(rows, fn {index} -> index end)} end end - # Returns sequences of streams owned by `execution_id` that don't yet have + # Returns indexes of streams owned by `execution_id` that don't yet have # a closure row. Used by the lifecycle code to discover which streams to # close on completion / cancel / crash. def get_open_streams_for_execution(db, execution_id) do case query( db, """ - SELECT s.sequence + SELECT s.`index` FROM streams AS s LEFT JOIN stream_closures AS c - ON c.execution_id = s.execution_id AND c.sequence = s.sequence + ON c.execution_id = s.execution_id AND c.`index` = s.`index` WHERE s.execution_id = ?1 AND c.execution_id IS NULL - ORDER BY s.sequence + ORDER BY s.`index` """, {execution_id} ) do {:ok, rows} -> - {:ok, Enum.map(rows, fn {sequence} -> sequence end)} + {:ok, Enum.map(rows, fn {index} -> index end)} end end @@ -203,11 +211,11 @@ defmodule Coflux.Orchestration.Streams do # * error is the `{type, message, frames}` triple for :errored, nil # otherwise (callers derive it from the execution's result on # :lifecycle) - def get_stream_closure(db, execution_id, sequence) do + def get_stream_closure(db, execution_id, index) do case query_one( db, - "SELECT reason, error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND sequence = ?2", - {execution_id, sequence} + "SELECT reason, error_id, created_at FROM stream_closures WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} ) do {:ok, nil} -> {:ok, nil} @@ -221,26 +229,26 @@ defmodule Coflux.Orchestration.Streams do end end - # Fetches up to `max_items` items from the stream starting at `from_position`. - # Returns `{:ok, [{position, value, created_at}, ...]}` in position order. + # Fetches up to `max_items` items from the stream starting at `from_sequence`. + # Returns `{:ok, [{sequence, value, created_at}, ...]}` in sequence order. # The caller (Server) layers filter logic (slice / partition) on top of this. - def get_stream_items(db, execution_id, sequence, from_position, max_items) do + def get_stream_items(db, execution_id, index, from_sequence, max_items) do case query( db, """ - SELECT position, value_id, created_at + SELECT sequence, value_id, created_at FROM stream_items - WHERE execution_id = ?1 AND sequence = ?2 AND position >= ?3 - ORDER BY position + WHERE execution_id = ?1 AND `index` = ?2 AND sequence >= ?3 + ORDER BY sequence LIMIT ?4 """, - {execution_id, sequence, from_position, max_items} + {execution_id, index, from_sequence, max_items} ) do {:ok, rows} -> items = - Enum.map(rows, fn {position, value_id, created_at} -> + Enum.map(rows, fn {sequence, value_id, created_at} -> {:ok, value} = Values.get_value_by_id(db, value_id) - {position, value, created_at} + {sequence, value, created_at} end) {:ok, items} @@ -248,7 +256,7 @@ defmodule Coflux.Orchestration.Streams do end # Returns one row per stream owned by `execution_id`: - # `{sequence, created_at, closed_at | nil, reason | nil, error | nil}`. + # `{index, created_at, closed_at | nil, reason | nil, error | nil}`. # * reason is :complete | :errored | :lifecycle when closed, nil when open # * error is the stored `{type, message, frames}` triple for :errored # closures only — callers that need to surface an error for a @@ -258,76 +266,76 @@ defmodule Coflux.Orchestration.Streams do case query( db, """ - SELECT s.sequence, s.created_at, c.created_at, c.reason, c.error_id + SELECT s.`index`, s.created_at, c.created_at, c.reason, c.error_id FROM streams AS s LEFT JOIN stream_closures AS c - ON c.execution_id = s.execution_id AND c.sequence = s.sequence + ON c.execution_id = s.execution_id AND c.`index` = s.`index` WHERE s.execution_id = ?1 - ORDER BY s.sequence + ORDER BY s.`index` """, {execution_id} ) do {:ok, rows} -> streams = Enum.map(rows, fn - {sequence, created_at, nil, nil, nil} -> - {sequence, created_at, nil, nil, nil} + {index, created_at, nil, nil, nil} -> + {index, created_at, nil, nil, nil} - {sequence, created_at, closed_at, reason_int, nil} -> - {sequence, created_at, closed_at, reason_from_int(reason_int), nil} + {index, created_at, closed_at, reason_int, nil} -> + {index, created_at, closed_at, reason_from_int(reason_int), nil} - {sequence, created_at, closed_at, reason_int, error_id} -> + {index, created_at, closed_at, reason_int, error_id} -> {:ok, error} = Errors.get_by_id(db, error_id) - {sequence, created_at, closed_at, reason_from_int(reason_int), error} + {index, created_at, closed_at, reason_from_int(reason_int), error} end) {:ok, streams} end end - # Returns the highest position recorded for the stream, or `-1` if empty. + # Returns the highest sequence recorded for the stream, or `-1` if empty. # Used by the worker protocol to report "head" for flow control without # requiring the caller to scan all items. - def get_stream_head(db, execution_id, sequence) do + def get_stream_head(db, execution_id, index) do case query_one( db, - "SELECT MAX(position) FROM stream_items WHERE execution_id = ?1 AND sequence = ?2", - {execution_id, sequence} + "SELECT MAX(sequence) FROM stream_items WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} ) do {:ok, {nil}} -> {:ok, -1} - {:ok, {position}} -> {:ok, position} + {:ok, {sequence}} -> {:ok, sequence} end end - # Returns the last `max_items` items of the stream, in position order, + # Returns the last `max_items` items of the stream, in sequence order, # alongside the total item count. Used by the inspection topic to # bootstrap its bounded tail buffer without materialising the full log. - def get_stream_tail(db, execution_id, sequence, max_items) do + def get_stream_tail(db, execution_id, index, max_items) do {:ok, {total_count}} = query_one( db, - "SELECT COUNT(*) FROM stream_items WHERE execution_id = ?1 AND sequence = ?2", - {execution_id, sequence} + "SELECT COUNT(*) FROM stream_items WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} ) case query( db, """ - SELECT position, value_id, created_at + SELECT sequence, value_id, created_at FROM stream_items - WHERE execution_id = ?1 AND sequence = ?2 - ORDER BY position DESC + WHERE execution_id = ?1 AND `index` = ?2 + ORDER BY sequence DESC LIMIT ?3 """, - {execution_id, sequence, max_items} + {execution_id, index, max_items} ) do {:ok, rows} -> items = rows |> Enum.reverse() - |> Enum.map(fn {position, value_id, created_at} -> + |> Enum.map(fn {sequence, value_id, created_at} -> {:ok, value} = Values.get_value_by_id(db, value_id) - {position, value, created_at} + {sequence, value, created_at} end) {:ok, {items, total_count}} diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index c8a3dea4..88fc45e3 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -209,12 +209,13 @@ defmodule Coflux.Topics.Run do defp process_notification( topic, - {:stream_opened, execution_external_id, sequence, created_at} + {:stream_opened, execution_external_id, index, created_at} ) do update_execution(topic, execution_external_id, fn topic, base_path -> - Topic.set(topic, base_path ++ [:streams, Integer.to_string(sequence)], %{ + Topic.set(topic, base_path ++ [:streams, Integer.to_string(index)], %{ openedAt: created_at, closedAt: nil, + reason: nil, error: nil }) end) @@ -222,14 +223,15 @@ defmodule Coflux.Topics.Run do defp process_notification( topic, - {:stream_closed, execution_external_id, sequence, error, closed_at} + {:stream_closed, execution_external_id, index, reason, error, closed_at} ) do - seq_key = Integer.to_string(sequence) + index_key = Integer.to_string(index) update_execution(topic, execution_external_id, fn topic, base_path -> topic - |> Topic.set(base_path ++ [:streams, seq_key, :closedAt], closed_at) - |> Topic.set(base_path ++ [:streams, seq_key, :error], error) + |> Topic.set(base_path ++ [:streams, index_key, :closedAt], closed_at) + |> Topic.set(base_path ++ [:streams, index_key, :reason], reason) + |> Topic.set(base_path ++ [:streams, index_key, :error], error) end) end @@ -549,17 +551,25 @@ defmodule Coflux.Topics.Run do defp build_streams(streams) do Map.new(streams, fn - {sequence, opened_at, nil, nil} -> - {Integer.to_string(sequence), %{openedAt: opened_at, closedAt: nil, error: nil}} + {index, opened_at, nil, nil, nil} -> + {Integer.to_string(index), + %{openedAt: opened_at, closedAt: nil, reason: nil, error: nil}} - {sequence, opened_at, closed_at, nil} -> - {Integer.to_string(sequence), %{openedAt: opened_at, closedAt: closed_at, error: nil}} + {index, opened_at, closed_at, reason, nil} -> + {Integer.to_string(index), + %{ + openedAt: opened_at, + closedAt: closed_at, + reason: Atom.to_string(reason), + error: nil + }} - {sequence, opened_at, closed_at, {type, message, _frames}} -> - {Integer.to_string(sequence), + {index, opened_at, closed_at, reason, {type, message, _frames}} -> + {Integer.to_string(index), %{ openedAt: opened_at, closedAt: closed_at, + reason: Atom.to_string(reason), error: %{type: type, message: message} }} end) diff --git a/server/lib/coflux/topics/stream.ex b/server/lib/coflux/topics/stream.ex index 5b63bf0f..d97f34d4 100644 --- a/server/lib/coflux/topics/stream.ex +++ b/server/lib/coflux/topics/stream.ex @@ -1,12 +1,12 @@ defmodule Coflux.Topics.Stream do @moduledoc """ - Inspection topic for a single stream, keyed by the producer's external - execution id and sequence. Used by the Studio UI when a user opens a - stream dialog — the topic keeps a bounded tail of items (with resolved - values) plus closure state, and receives live updates as items are - appended or the stream is closed. + Inspection topic for a single stream, keyed by the stream's opaque id + (``_``). Used by the Studio UI when a + user opens a stream dialog — the topic keeps a bounded tail of items + (with resolved values) plus closure state, and receives live updates + as items are appended or the stream is closed. """ - use Topical.Topic, route: ["streams", :execution_id, :sequence] + use Topical.Topic, route: ["streams", :id] alias Coflux.Orchestration alias Coflux.TopicUtils @@ -17,13 +17,21 @@ defmodule Coflux.Topics.Stream do def init(params) do project_id = Map.fetch!(params, :project) - execution_id = Map.fetch!(params, :execution_id) - sequence = parse_sequence(Map.fetch!(params, :sequence)) + case parse_id(Map.fetch!(params, :id)) do + {:ok, execution_id, index} -> + do_init(project_id, execution_id, index) + + :error -> + {:error, :not_found} + end + end + + defp do_init(project_id, execution_id, index) do case Orchestration.subscribe_stream_topic( project_id, execution_id, - sequence, + index, self() ) do {:ok, initial, ref} -> @@ -32,10 +40,8 @@ defmodule Coflux.Topics.Stream do %{ producer: initial.producer, openedAt: initial.openedAt, - closedAt: initial.closedAt, closure: build_closure(initial.closure), items: Enum.map(initial.items, &build_item/1), - firstPosition: initial.firstPosition, totalCount: initial.totalCount, tailSize: initial.tailSize }, @@ -52,48 +58,34 @@ defmodule Coflux.Topics.Stream do {:ok, topic} end - defp process_notification({:item_appended, position, value, created_at}, topic) do + defp process_notification({:item_appended, sequence, value, created_at}, topic) do tail_size = topic.state.tail_size || 200 - item = build_item({position, value, created_at}) - + item = build_item({sequence, value, created_at}) existing = topic.value.items - total = topic.value.totalCount + 1 - - # Keep items bounded: if we're already at capacity, drop the head. - # Otherwise append and leave firstPosition alone (or set it if empty). - {new_items, new_first_position} = - cond do - existing == [] -> - {[item], position} - - length(existing) >= tail_size -> - [_dropped | rest] = existing - new_items = rest ++ [item] - [first_item | _] = new_items - {new_items, first_item.position} - - true -> - {existing ++ [item], topic.value.firstPosition} + + # Keep items bounded: drop the head once we're at capacity. + new_items = + if length(existing) >= tail_size do + [_dropped | rest] = existing + rest ++ [item] + else + existing ++ [item] end topic |> Topic.set([:items], new_items) - |> Topic.set([:firstPosition], new_first_position) - |> Topic.set([:totalCount], total) + |> Topic.set([:totalCount], topic.value.totalCount + 1) end - defp process_notification({:closed, error, closed_at}, topic) do - closure = build_closure_from_notification(error, closed_at) - - topic - |> Topic.set([:closedAt], closed_at) - |> Topic.set([:closure], closure) + defp process_notification({:closed, reason, error, closed_at}, topic) do + closure = %{reason: reason, error: error, closedAt: closed_at} + Topic.set(topic, [:closure], closure) end - defp build_item({position, value, created_at}) do + defp build_item({sequence, value, created_at}) do %{ - position: position, + sequence: sequence, value: TopicUtils.build_value(value), createdAt: created_at } @@ -109,27 +101,26 @@ defmodule Coflux.Topics.Stream do } end - # The live-close notification carries the already-encoded error summary - # (type/message) and closedAt — it doesn't include the reason because - # the notification is the same shape used by the run topic. Default to - # "errored" when there's an error, "complete" otherwise; the distinction - # between errored/lifecycle is resolved server-side before we get here. - defp build_closure_from_notification(error, closed_at) do - reason = if error, do: "errored", else: "complete" - - %{ - reason: reason, - error: error, - closedAt: closed_at - } - end - - defp parse_sequence(s) when is_integer(s), do: s - - defp parse_sequence(s) when is_binary(s) do - case Integer.parse(s) do - {n, ""} -> n - _ -> raise ArgumentError, "invalid stream sequence: #{inspect(s)}" + # Split an opaque stream id back into (execution_id, index). The + # separator is `_` — execution ids use alphanumerics + `:`, so the last + # `_` unambiguously marks the index suffix. + defp parse_id(id) when is_binary(id) do + case String.split(id, "_") do + parts when length(parts) >= 2 -> + {index_str, execution_parts} = List.pop_at(parts, -1) + + with {index, ""} when index >= 0 <- Integer.parse(index_str), + execution_id when execution_id != "" <- + Enum.join(execution_parts, "_") do + {:ok, execution_id, index} + else + _ -> :error + end + + _ -> + :error end end + + defp parse_id(_), do: :error end diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index 8bac925a..7910e69f 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -20,14 +20,15 @@ INSERT INTO completions (execution_id, created_at) SELECT execution_id, created_at FROM results; -- Streams — ordered, append-only sequences of values produced by an --- execution. Each stream is identified by (execution_id, sequence), where --- sequence is assigned monotonically by the worker when serialising the --- execution's return value. The worker manages allocation locally, so no --- server round-trip is needed to mint an id. +-- execution. Each stream is identified by (execution_id, index), where +-- `index` is assigned monotonically by the worker when serialising the +-- execution's return value. The worker manages allocation locally, so +-- no server round-trip is needed to mint an id. The column is quoted +-- with backticks throughout because INDEX is a SQLite keyword. -- -- Invariants: -- • A stream is owned by exactly one execution (its producer). --- • stream_items are append-only with monotonic position starting at 0. +-- • stream_items are append-only with monotonic sequence starting at 0. -- • stream_closures are terminal — no items may be appended after closure. -- • On execution completion / cancellation / crash, every owned stream -- that lacks a closure receives one (clean, cancelled, or crashed). @@ -35,24 +36,24 @@ INSERT INTO completions (execution_id, created_at) -- new execution_id ⇒ new rows). Consumer references are concrete to -- the original streams. -- • Consumer cursors are kept in-memory only; re-run consumers subscribe --- fresh from position 0. +-- fresh from sequence 0. CREATE TABLE streams ( execution_id INTEGER NOT NULL, - sequence INTEGER NOT NULL, + `index` INTEGER NOT NULL, created_at INTEGER NOT NULL, - PRIMARY KEY (execution_id, sequence), + PRIMARY KEY (execution_id, `index`), FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE ) STRICT; CREATE TABLE stream_items ( execution_id INTEGER NOT NULL, + `index` INTEGER NOT NULL, sequence INTEGER NOT NULL, - position INTEGER NOT NULL, value_id INTEGER NOT NULL, created_at INTEGER NOT NULL, - PRIMARY KEY (execution_id, sequence, position), - FOREIGN KEY (execution_id, sequence) REFERENCES streams (execution_id, sequence) ON DELETE CASCADE, + PRIMARY KEY (execution_id, `index`, sequence), + FOREIGN KEY (execution_id, `index`) REFERENCES streams (execution_id, `index`) ON DELETE CASCADE, FOREIGN KEY (value_id) REFERENCES values_ ON DELETE RESTRICT ) STRICT; @@ -65,12 +66,12 @@ CREATE TABLE stream_items ( -- so we don't duplicate that state here. CREATE TABLE stream_closures ( execution_id INTEGER NOT NULL, - sequence INTEGER NOT NULL, + `index` INTEGER NOT NULL, reason INTEGER NOT NULL, error_id INTEGER, created_at INTEGER NOT NULL, - PRIMARY KEY (execution_id, sequence), - FOREIGN KEY (execution_id, sequence) REFERENCES streams (execution_id, sequence) ON DELETE CASCADE, + PRIMARY KEY (execution_id, `index`), + FOREIGN KEY (execution_id, `index`) REFERENCES streams (execution_id, `index`) ON DELETE CASCADE, FOREIGN KEY (error_id) REFERENCES errors ON DELETE RESTRICT, CHECK ((reason = 1) = (error_id IS NOT NULL)) ) STRICT; diff --git a/tests/support/executor.py b/tests/support/executor.py index e282d8ed..08069279 100644 --- a/tests/support/executor.py +++ b/tests/support/executor.py @@ -258,17 +258,17 @@ def resolve_input( # --- Stream producer helpers --- - def stream_register(self, execution_id, sequence): + def stream_register(self, execution_id, index): """Notify that a new stream exists.""" - self.send(protocol.stream_register(execution_id, sequence)) + self.send(protocol.stream_register(execution_id, index)) - def stream_append(self, execution_id, sequence, position, value, format="json"): + def stream_append(self, execution_id, index, sequence, value, format="json"): """Append an item (raw JSON value) to a stream.""" - self.send(protocol.stream_append(execution_id, sequence, position, value, format=format)) + self.send(protocol.stream_append(execution_id, index, sequence, value, format=format)) - def stream_close(self, execution_id, sequence, error=None): + def stream_close(self, execution_id, index, error=None): """Close a stream (optionally with an error {type, message, traceback}).""" - self.send(protocol.stream_close(execution_id, sequence, error=error)) + self.send(protocol.stream_close(execution_id, index, error=error)) # --- Stream consumer helpers --- @@ -277,8 +277,8 @@ def stream_subscribe( execution_id, subscription_id, producer_execution_id, - sequence, - from_position=0, + index, + from_sequence=0, filter=None, ): """Subscribe to a stream. ``filter`` is an optional dict built via @@ -288,8 +288,8 @@ def stream_subscribe( execution_id, subscription_id, producer_execution_id, - sequence, - from_position=from_position, + index, + from_sequence=from_sequence, filter=filter, ) ) diff --git a/tests/support/protocol.py b/tests/support/protocol.py index d0512e7a..62c25384 100644 --- a/tests/support/protocol.py +++ b/tests/support/protocol.py @@ -234,25 +234,27 @@ def register_group_notification(execution_id, group_id, name=None): # --- Stream messages (producer side: adapter → server) --- -def stream_register(execution_id, sequence): +def stream_register(execution_id, index): return { "method": "stream_register", - "params": {"execution_id": execution_id, "sequence": sequence}, + "params": {"execution_id": execution_id, "index": index}, } -def stream_append(execution_id, sequence, position, value, format="json"): +def stream_append(execution_id, index, sequence, value, format="json"): """Append an item to a stream. ``value`` is the raw JSON value. - Builds a Value wire-form message with an empty references list. Tests - that need references should build the Value dict manually. + ``index`` identifies the stream within its execution; ``sequence`` + identifies the item within the stream. Builds a Value wire-form + message with an empty references list. Tests that need references + should build the Value dict manually. """ return { "method": "stream_append", "params": { "execution_id": execution_id, + "index": index, "sequence": sequence, - "position": position, "value": { "type": "inline", "format": format, @@ -263,9 +265,9 @@ def stream_append(execution_id, sequence, position, value, format="json"): } -def stream_close(execution_id, sequence, error=None): +def stream_close(execution_id, index, error=None): """Close a stream. ``error`` is optional {type, message, traceback}.""" - params = {"execution_id": execution_id, "sequence": sequence} + params = {"execution_id": execution_id, "index": index} if error is not None: params["error"] = error return {"method": "stream_close", "params": params} @@ -278,16 +280,16 @@ def stream_subscribe( execution_id, subscription_id, producer_execution_id, - sequence, - from_position=0, + index, + from_sequence=0, filter=None, ): params = { "execution_id": execution_id, "subscription_id": subscription_id, "producer_execution_id": producer_execution_id, - "sequence": sequence, - "from_position": from_position, + "index": index, + "from_sequence": from_sequence, } if filter is not None: params["filter"] = filter diff --git a/tests/test_streams.py b/tests/test_streams.py index 923af984..1e6f3dbe 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -68,7 +68,7 @@ def test_producer_writes_and_consumer_reads_backlog(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) items, closed = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) @@ -98,7 +98,7 @@ def test_consumer_sees_live_push(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) # Now producer appends + closes. @@ -135,7 +135,7 @@ def test_slice_filter_restricts_items(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, filter=slice_filter(1, 3), ) items, _ = cons_ex.conn.drain_stream(subscription_id=1) @@ -165,7 +165,7 @@ def test_partition_filter_round_robin(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, filter=partition_filter(n=3, i=1), ) items, _ = cons_ex.conn.drain_stream(subscription_id=1) @@ -199,7 +199,7 @@ def test_producer_error_closes_with_error_info(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) items, closed = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) @@ -224,7 +224,7 @@ def test_subscribe_to_unknown_producer_closes_immediately(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id="00000000:0:0", - sequence=0, + index=0, ) _items, closed = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) @@ -261,8 +261,10 @@ def test_topic_exposes_stream_state(worker): assert "0" in streams and "1" in streams assert streams["0"]["openedAt"] is not None assert streams["0"]["closedAt"] is not None + assert streams["0"]["reason"] == "complete" assert streams["0"]["error"] is None assert streams["1"]["closedAt"] is not None + assert streams["1"]["reason"] == "errored" assert streams["1"]["error"] == {"type": "RuntimeError", "message": "bad"} @@ -284,7 +286,7 @@ def test_cancellation_closes_streams_with_cancelled_error(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) ctx.cancel(prod_ex.execution_id) @@ -323,7 +325,7 @@ def test_multiple_subscribers_get_independent_delivery(worker): a_ex.execution_id, subscription_id=7, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) b_resp = ctx.submit("test", "consumer") @@ -332,7 +334,7 @@ def test_multiple_subscribers_get_independent_delivery(worker): b_ex.execution_id, subscription_id=42, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) a_items, _ = a_ex.conn.drain_stream(subscription_id=7) @@ -369,7 +371,7 @@ def test_subscription_ids_can_collide_across_consumers(worker): a_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) b_resp = ctx.submit("test", "consumer") @@ -378,7 +380,7 @@ def test_subscription_ids_can_collide_across_consumers(worker): b_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) a_items, a_closed = a_ex.conn.drain_stream(subscription_id=1) @@ -410,7 +412,7 @@ def test_consumer_termination_drops_subscription(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "before") @@ -452,7 +454,7 @@ def test_slice_with_stop_closes_early(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, filter=slice_filter(0, 2), ) @@ -494,7 +496,7 @@ def test_unsubscribe_prevents_receiving_full_stream(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, 0) @@ -511,22 +513,22 @@ def test_unsubscribe_prevents_receiving_full_stream(worker): # Collect anything still in flight. It might include the next # one or two items (racing with unsubscribe) but MUST NOT include # the tail — some positions drop out between unsubscribe and close. - received_positions = [0] + received_sequences = [0] try: while True: msg = cons_ex.conn.recv_push( "stream_items", subscription_id=1, timeout=0.5 ) for item in msg["items"]: - received_positions.append(item[0]) + received_sequences.append(item[0]) except TimeoutError: pass cons_ex.conn.complete(cons_ex.execution_id) # The consumer should have received strictly fewer than all 10 items. - assert len(received_positions) < 10, ( - f"unsubscribe should stop further delivery; got {received_positions}" + assert len(received_sequences) < 10, ( + f"unsubscribe should stop further delivery; got {received_sequences}" ) @@ -547,7 +549,7 @@ def test_close_while_subscribed_delivers_closure(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "only") @@ -580,7 +582,7 @@ def test_lifecycle_close_on_completion_delivers_to_subscriber(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, ) # Producer completes *without* closing the stream. @@ -613,7 +615,7 @@ def test_filter_chain_combines_slice_and_partition(worker): cons_ex.execution_id, subscription_id=1, producer_execution_id=prod_ex.execution_id, - sequence=0, + index=0, filter=chain_filter(slice_filter(0, 6), partition_filter(n=2, i=0)), ) items, _ = cons_ex.conn.drain_stream(subscription_id=1) From 589612bbaa2b83c67722fdd2d6d6ec1a1fff48e2 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sun, 19 Apr 2026 15:18:34 +0100 Subject: [PATCH 14/25] Support configuring buffer/backpressure --- adapters/python/coflux/__init__.py | 3 + adapters/python/coflux/context.py | 11 +- adapters/python/coflux/decorators.py | 15 +- adapters/python/coflux/executor.py | 24 +- adapters/python/coflux/protocol.py | 20 +- adapters/python/coflux/serialization.py | 57 ++- adapters/python/coflux/streams.py | 147 +++++++- adapters/python/coflux/target.py | 49 ++- cli/internal/adapter/protocol.go | 12 +- cli/internal/pool/pool.go | 7 +- cli/internal/worker/worker.go | 37 +- server/lib/coflux/handlers/worker.ex | 15 +- server/lib/coflux/orchestration.ex | 7 +- server/lib/coflux/orchestration/server.ex | 387 +++++++++++++++++---- server/lib/coflux/orchestration/streams.ex | 42 ++- server/lib/coflux/topics/run.ex | 13 +- server/lib/coflux/topics/stream.ex | 1 + server/priv/migrations/orchestration/4.sql | 6 + tests/support/executor.py | 6 +- tests/support/protocol.py | 10 +- tests/test_streams.py | 117 +++++++ 21 files changed, 821 insertions(+), 165 deletions(-) diff --git a/adapters/python/coflux/__init__.py b/adapters/python/coflux/__init__.py index ea7e4e0a..f181b42e 100644 --- a/adapters/python/coflux/__init__.py +++ b/adapters/python/coflux/__init__.py @@ -34,6 +34,7 @@ ) from .prompt import Prompt from .state import get_context +from .streams import stream from .target import Cache, Defer, Retries __all__ = [ @@ -65,6 +66,8 @@ "AssetEntry", "AssetMetadata", "Stream", + # Producer-side stream helper + "stream", # Context functions "group", "suspense", diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index ca316088..b25c68a7 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -98,13 +98,14 @@ def __init__(self, execution_id: str, working_dir: Path | None = None): # are registered here and driven in background threads. self._stream_driver = StreamDriver(execution_id) - def register_stream(self, generator: Any) -> str: - """Callback for ``serialize_value(on_generator=...)``. + def register_stream(self, generator: Any, buffer: int | None) -> str: + """Register a generator with this execution's stream driver and + return the resulting opaque stream id. - Registers a generator with this execution's driver and returns the - opaque stream ``id`` to embed in the serialized value. + Called from ``cf.stream(...)``; also from the executor when the + task body itself is a generator. """ - return self._stream_driver.register(generator) + return self._stream_driver.register(generator, buffer) def wait_streams(self) -> None: """Block until every stream produced by this execution has drained.""" diff --git a/adapters/python/coflux/decorators.py b/adapters/python/coflux/decorators.py index 272c5f39..a50aab72 100644 --- a/adapters/python/coflux/decorators.py +++ b/adapters/python/coflux/decorators.py @@ -5,7 +5,7 @@ import datetime as dt import typing as t -from .target import Cache, Defer, Retries, Target +from .target import _BUFFER_UNSET, Cache, Defer, Retries, Target if t.TYPE_CHECKING: from .models import Stream @@ -52,12 +52,20 @@ def task( memo: bool | t.Iterable[str] = False, requires: dict[str, str | bool | list[str]] | None = None, timeout: float | dt.timedelta = 0, + buffer: int | None = _BUFFER_UNSET, # type: ignore[assignment] ) -> _TargetDecorator: """Decorator for defining a task. For ``async def`` functions, the coroutine is run to completion by the executor; the task's return type is the coroutine's resolved value (not the coroutine itself). + + ``buffer`` only applies to generator-bodied tasks. ``0`` (default) + gives strict lockstep: the producer emits an item, waits for a + consumer to ack, then emits the next. ``N`` lets the producer stay + up to N items ahead of the fastest consumer. ``None`` disables + backpressure entirely. Passing ``buffer`` on a non-generator task + raises ``TypeError`` at decoration time. """ def decorator(fn): @@ -74,6 +82,7 @@ def decorator(fn): memo=memo, requires=requires, timeout=timeout, + buffer=buffer, ) return decorator # type: ignore[return-value] @@ -91,12 +100,15 @@ def workflow( memo: bool = False, requires: dict[str, str | bool | list[str]] | None = None, timeout: float | dt.timedelta = 0, + buffer: int | None = _BUFFER_UNSET, # type: ignore[assignment] ) -> _TargetDecorator: """Decorator for defining a workflow. For ``async def`` functions, the coroutine is run to completion by the executor; the workflow's return type is the coroutine's resolved value (not the coroutine itself). + + See ``@cf.task`` for ``buffer=`` semantics. """ def decorator(fn): @@ -113,6 +125,7 @@ def decorator(fn): memo=memo, requires=requires, timeout=timeout, + buffer=buffer, ) return decorator # type: ignore[return-value] diff --git a/adapters/python/coflux/executor.py b/adapters/python/coflux/executor.py index 5bc51df9..01ee4ad4 100644 --- a/adapters/python/coflux/executor.py +++ b/adapters/python/coflux/executor.py @@ -112,10 +112,26 @@ def execute_target( else: result = fn(*deserialized_args) - # Serialize result. Generators anywhere in the return value (or that - # were passed to submitted child executions as args) have already - # been registered with the context's stream driver. - result_value = serialize_value(result, on_generator=ctx.register_stream) + # If the task body was itself a generator (``def`` + ``yield`` + # or ``async def`` + ``yield``), the call above returned an + # unstarted generator object. Register it with the task's + # configured buffer so callers don't have to wrap explicitly. + # Streams created via cf.stream(...) are already registered; + # they appear here as Stream handles, not generators. + if (inspect.isgenerator(result) or inspect.isasyncgen(result)) and hasattr( + target_obj, "definition" + ): + from .streams import stream as _register_stream + + result = _register_stream( + result, buffer=target_obj.definition.buffer + ) + + # Serialize result. Any streams returned (directly, or embedded + # in the result structure) were already registered via + # cf.stream() or the auto-wrap above; they're Stream handles + # here, not raw generators. + result_value = serialize_value(result) protocol.send_execution_result(execution_id, result_value) # Hold the process open until every stream has drained. Thread diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index dc78bdc4..a3788e2c 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -460,16 +460,26 @@ def send_metric( get_protocol().send_message("metric", params) -def send_stream_register(execution_id: str, index: int) -> None: +def send_stream_register( + execution_id: str, + index: int, + buffer: int | None = None, +) -> None: """Register a stream owned by this execution. ``index`` is worker-assigned and monotonic per execution (0, 1, 2, ...); it identifies the stream within its producer execution. + + ``buffer`` is the producer-side backpressure budget. ``None`` opts out + of backpressure entirely; the server won't issue demand grants and + the producer emits freely. Any integer value tells the server to + pace the producer — it'll send ``stream_demand`` notifications as + credits become available. """ - get_protocol().send_message( - "stream_register", - {"execution_id": execution_id, "index": index}, - ) + params: dict[str, Any] = {"execution_id": execution_id, "index": index} + if buffer is not None: + params["buffer"] = buffer + get_protocol().send_message("stream_register", params) def send_stream_append( diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index 2f83be2e..ce962b6f 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -43,7 +43,6 @@ def _write_temp_file(data: bytes) -> str: def _encode_value( value: Any, write_temp_file: Callable[[bytes], str] = _write_temp_file, - on_generator: Callable[[Any], tuple[str, int]] | None = None, ) -> tuple[Any, list[list[Any]]]: """Encode a Python value using the custom JSON value format. @@ -55,36 +54,20 @@ def _encode_value( value: The Python value to encode. write_temp_file: Callable that writes bytes to a temp file and returns the path. Used for pickle fragment references. - on_generator: Callback invoked for each generator encountered. Should - register the generator (spawn its driver) and return the - stream's opaque `id`. If None, encountering a generator - raises TypeError. Returns: Tuple of (data, references) where data is JSON-serializable and references is a list of reference arrays. + + Streams must be registered explicitly via ``cf.stream(...)`` (which + returns a ``Stream`` handle) before serialisation. A bare generator + encountered here is an error — the user probably meant to wrap it. """ references: list[list[Any]] = [] def _encode(v: Any) -> Any: if v is None or isinstance(v, (str, bool, int, float)): return v - elif inspect.isgenerator(v) or inspect.isasyncgen(v): - if on_generator is None: - raise TypeError( - "Cannot serialize a generator: no stream driver is active." - ) - stream_id = on_generator(v) - return {"type": "stream", "id": stream_id} - elif isinstance(v, Stream): - # Pass-through: a Stream handle received from another execution - # (possibly with partition/slice filters layered on top) is - # being forwarded as an argument. Preserve the filter chain so - # the downstream consumer subscribes with the same filters. - encoded: dict[str, Any] = {"type": "stream", "id": v.id} - if v._filters: - encoded["filters"] = list(v._filters) - return encoded elif isinstance(v, list): return [_encode(x) for x in v] elif isinstance(v, dict): @@ -138,6 +121,22 @@ def _encode(v: Any) -> Any: ] ) return {"type": "ref", "index": len(references) - 1} + elif isinstance(v, Stream): + # Pass-through: a Stream handle received from another execution + # (possibly with partition/slice filters layered on top) is + # being forwarded as an argument. Preserve the filter chain so + # the downstream consumer subscribes with the same filters. + encoded: dict[str, Any] = {"type": "stream", "id": v.id} + if v._filters: + encoded["filters"] = list(v._filters) + return encoded + elif inspect.isgenerator(v) or inspect.isasyncgen(v): + raise TypeError( + "Bare generators aren't serialisable — wrap with " + "cf.stream(generator, buffer=...) to register a stream " + "first. Tasks whose body yields directly are handled " + "automatically via @cf.task(buffer=...)." + ) elif HAS_PYDANTIC and isinstance(v, pydantic.BaseModel): model_class = v.__class__ model_fqn = f"{model_class.__module__}.{model_class.__name__}" @@ -177,24 +176,18 @@ def _encode(v: Any) -> Any: return data, references -def serialize_value( - value: Any, - on_generator: Callable[[Any], tuple[str, int]] | None = None, -) -> dict[str, Any]: +def serialize_value(value: Any) -> dict[str, Any]: """Serialize a result value to the protocol format. Uses the custom JSON value encoding (dict/set/tuple types, fragment refs - for unsupported types). The result is always JSON-format data. - - Args: - value: The Python value to serialize. - on_generator: Optional callback for generator objects. See - `_encode_value` for the contract. Without it, generators raise. + for unsupported types). The result is always JSON-format data. Bare + generators raise — streams must be registered explicitly via + ``cf.stream(...)`` first. Returns: Serialized value dict. """ - data, references = _encode_value(value, on_generator=on_generator) + data, references = _encode_value(value) encoded = json.dumps(data, separators=(",", ":")).encode() if len(encoded) > TRANSFER_THRESHOLD: diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index b09fa170..feaf3d35 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -37,6 +37,51 @@ # --- Producer side --- +def stream(generator: Any, *, buffer: int | None = 0) -> Any: + """Register a generator as a Coflux stream and return a handle. + + Use this when a task returns multiple streams or needs a buffer size + different from the task default. For the common case where a task + body is itself a generator, ``@cf.task(buffer=N)`` handles the + registration automatically — you don't need to call ``cf.stream`` + explicitly. + + Registration happens at call time: the driver thread starts, the + server is told about the stream, and any later serialisation sees a + regular ``Stream`` handle. That means ``cf.stream`` must be called + inside a task or workflow body (where an execution context is + active); calling it from module scope or outside a task raises. + + Args: + generator: A sync or async generator. Other iterables aren't + accepted — wrapping a list in ``cf.stream`` doesn't make + sense; pass it as a value directly. + buffer: Backpressure budget. ``0`` (the default) means strict + lockstep — the producer emits an item, waits for a consumer + to acknowledge it, then emits the next. ``N`` allows the + producer to stay up to ``N`` items ahead of the fastest + consumer. ``None`` disables backpressure entirely. + + Returns: + A ``Stream`` handle referencing the newly registered stream. + It serialises as ``{"type": "stream", "id": ...}`` and is + iterable by downstream tasks. + """ + if not (inspect.isgenerator(generator) or inspect.isasyncgen(generator)): + raise TypeError( + f"cf.stream expects a generator, got {type(generator).__name__}" + ) + if buffer is not None and buffer < 0: + raise ValueError(f"buffer must be non-negative or None, got {buffer}") + ctx = get_context() + stream_id = ctx.register_stream(generator, buffer) + # Local import to avoid a top-level cycle — models imports nothing + # from streams but streams already imports from models at top. + from .models import Stream as StreamHandle + + return StreamHandle(stream_id) + + class StreamDriver: """Manages streams produced by a single execution.""" @@ -46,8 +91,17 @@ def __init__(self, execution_id: str) -> None: self._threads: list[threading.Thread] = [] self._generators: list[Any] = [] self._lock = threading.Lock() - - def register(self, generator: Any) -> str: + # Demand tracking: each registered stream gets a per-index slot in + # `_demand`. Drivers wait on `_demand_cv` until credit is granted + # by the server (via stream_demand notifications) or the driver is + # asked to close. ``None`` means unbounded demand (buffer=None at + # registration time); the driver never waits. + self._demand_cv = threading.Condition() + self._demand: dict[int, int | None] = {} + self._closing = False + self._demand_handler_registered = False + + def register(self, generator: Any, buffer: int | None) -> str: """Register a generator and start running it in a worker thread. Accepts both sync generators (``def`` + ``yield``) and async @@ -55,14 +109,28 @@ def register(self, generator: Any) -> str: async generators run inside a fresh event loop confined to that thread. + ``buffer`` is the producer-side backpressure budget. ``None`` + means unbounded (no flow control); ``0`` means strict lockstep + (producer waits for a consumer to ack each item before emitting + the next); ``N>0`` allows the producer to stay up to N items + ahead of the fastest consumer. + Returns the stream's opaque ``id`` (``_``) for embedding in the serialized value as a stream reference. """ + self._ensure_demand_handler_registered() + with self._lock: index = self._next_index self._next_index += 1 - protocol.send_stream_register(self._execution_id, index) + with self._demand_cv: + # Unbounded ⇒ driver never waits. Bounded ⇒ starts at 0; the + # server issues a credit grant once demand calculation warrants + # it (or on first consumer subscribing). + self._demand[index] = None if buffer is None else 0 + + protocol.send_stream_register(self._execution_id, index, buffer=buffer) is_async = inspect.isasyncgen(generator) target = self._run_async if is_async else self._run @@ -80,11 +148,60 @@ def register(self, generator: Any) -> str: return compose_stream_id(self._execution_id, index) + def _ensure_demand_handler_registered(self) -> None: + if self._demand_handler_registered: + return + get_dispatcher().register_notification("stream_demand", self._on_stream_demand) + self._demand_handler_registered = True + + def _on_stream_demand(self, params: dict[str, Any]) -> None: + """Server granted additional demand for one of our streams. + + The notification carries the delta (``n`` extra credits). We add + to the per-stream counter and wake any waiter. + """ + index = params.get("index") + n = params.get("n", 0) + if index is None or n <= 0: + return + with self._demand_cv: + current = self._demand.get(index) + if current is None: + # Unbounded — nothing to account for. + return + self._demand[index] = current + n + self._demand_cv.notify_all() + + def _acquire_demand(self, index: int) -> bool: + """Wait for a credit and consume it. Returns False if closed mid-wait.""" + with self._demand_cv: + while True: + if self._closing: + return False + current = self._demand.get(index) + if current is None: + # Unbounded stream — never waits. + return True + if current > 0: + self._demand[index] = current - 1 + return True + self._demand_cv.wait() + def _run(self, index: int, generator: Any) -> None: """Run one sync generator to exhaustion (or error).""" sequence = 0 try: - for item in generator: + iterator = iter(generator) + while True: + # Block until the server grants a credit (or the driver is + # asked to close). For unbounded streams this returns + # immediately without consuming any credit. + if not self._acquire_demand(index): + return + try: + item = next(iterator) + except StopIteration: + break serialized = serialize_value(item) protocol.send_stream_append( self._execution_id, @@ -124,7 +241,19 @@ def _run_async(self, index: int, generator: Any) -> None: async def iterate() -> None: sequence = 0 - async for item in generator: + iterator = generator.__aiter__() + while True: + # The demand wait uses a threading.Condition, which would + # block the event loop. This loop is dedicated to one + # generator though — nothing else scheduled — so blocking + # in-thread is harmless and simpler than bridging to an + # asyncio primitive. + if not self._acquire_demand(index): + return + try: + item = await iterator.__anext__() + except StopAsyncIteration: + break serialized = serialize_value(item) protocol.send_stream_append( self._execution_id, @@ -183,7 +312,15 @@ def close_all(self) -> None: ``GeneratorExit`` at the current yield point. For async generators, we schedule ``aclose()`` onto the generator's own event loop so the awaiting coroutine is cancelled cleanly. + + We also flip a closing flag and broadcast on the demand condition + so drivers parked in ``_acquire_demand`` (blocked for credits that + will never arrive) wake and exit. """ + with self._demand_cv: + self._closing = True + self._demand_cv.notify_all() + with self._lock: entries = list(self._generators) for entry in entries: diff --git a/adapters/python/coflux/target.py b/adapters/python/coflux/target.py index e65a54cc..798d213a 100644 --- a/adapters/python/coflux/target.py +++ b/adapters/python/coflux/target.py @@ -70,6 +70,10 @@ class TargetDefinition(t.NamedTuple): timeout: float | dt.timedelta instruction: str | None is_stub: bool + # Backpressure for generator-bodied tasks. 0 = strict lockstep (default), + # N = up to N items ahead of the fastest consumer, None = unbounded. + # Only meaningful when ``fn`` is a generator function. + buffer: int | None def _json_dumps(obj: t.Any) -> str: @@ -195,6 +199,36 @@ def _parse_requires( return {k: _parse_require(v) for k, v in requires.items()} if requires else None +_BUFFER_UNSET = object() + + +def _resolve_buffer( + buffer: t.Any, + fn: t.Callable, +) -> int | None: + """Validate the decorator's ``buffer=`` and return the resolved value. + + Default is 0 (strict lockstep) for generator-bodied tasks. ``None`` + disables backpressure. ``buffer`` on a non-generator task is an + error — it wouldn't apply to anything. + """ + is_generator = inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn) + if buffer is _BUFFER_UNSET: + return 0 if is_generator else None + if not is_generator: + raise TypeError( + f"@cf.task/@cf.workflow(buffer=...) only applies to generator functions " + f"(def + yield or async def + yield); {fn.__name__} is not." + ) + if buffer is None: + return None + if not isinstance(buffer, int) or isinstance(buffer, bool) or buffer < 0: + raise ValueError( + f"buffer must be a non-negative integer or None, got {buffer!r}" + ) + return buffer + + def _build_definition( type: TargetType, fn: t.Callable, @@ -208,6 +242,7 @@ def _build_definition( requires: dict[str, str | bool | list[str]] | None, timeout: float | dt.timedelta, is_stub: bool, + buffer: t.Any = _BUFFER_UNSET, ) -> TargetDefinition: parameters = inspect.signature(fn).parameters.values() for p in parameters: @@ -228,6 +263,7 @@ def _build_definition( timeout, inspect.getdoc(fn), is_stub, + _resolve_buffer(buffer, fn), ) @@ -320,6 +356,7 @@ def __init__( requires: dict[str, str | bool | list[str]] | None = None, timeout: float | dt.timedelta = 0, is_stub: bool = False, + buffer: t.Any = _BUFFER_UNSET, ): self._fn = fn self._name = name or fn.__name__ @@ -337,6 +374,7 @@ def __init__( requires, timeout, is_stub, + buffer, ) functools.update_wrapper(self, fn) @@ -409,12 +447,11 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Execution[T]: ctx = get_context() - # Serialize arguments. Generators passed as args are registered with - # the current execution's stream driver — the caller becomes the - # producer, the callee gets a Stream handle. - serialized_args = [ - serialize_value(arg, on_generator=ctx.register_stream) for arg in args - ] + # Serialize arguments. Streams passed as args must already have + # been registered via cf.stream(...) — the caller becomes the + # producer, the callee gets a Stream handle. Bare generators + # raise; the user should wrap them explicitly. + serialized_args = [serialize_value(arg) for arg in args] # Use only the declared wait_for from the decorator wait_for_val = ( diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index ae3b8473..1679ac60 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -255,10 +255,20 @@ type RegisterGroupParams struct { // StreamRegisterParams for stream_register notification. // Index is worker-assigned, monotonic per execution — it identifies the -// stream within its producer execution. +// stream within its producer execution. Buffer is the optional +// backpressure budget; nil means unbounded (no flow control). type StreamRegisterParams struct { ExecutionID string `json:"execution_id"` Index int `json:"index"` + Buffer *int `json:"buffer,omitempty"` +} + +// StreamDemandParams for stream_demand notification pushed CLI → adapter. +// Grants the producer ``n`` more credits for the given stream. +type StreamDemandParams struct { + ExecutionID string `json:"execution_id"` + Index int `json:"index"` + N int `json:"n"` } // StreamAppendParams for stream_append notification. diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index 54c77692..9b2ced58 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -51,8 +51,9 @@ type ExecutionHandler interface { NotifyTerminated(ctx context.Context, executionID string) error // StreamRegister declares a new stream owned by an execution. // Index is worker-assigned, monotonic per execution — it identifies - // the stream within its producer execution. - StreamRegister(ctx context.Context, executionID string, index int) error + // the stream within its producer execution. Buffer is the optional + // backpressure budget; nil means unbounded (no flow control). + StreamRegister(ctx context.Context, executionID string, index int, buffer *int) error // StreamAppend appends an item to a stream. Sequence is worker-assigned, // monotonic per stream — it identifies the item within its stream. StreamAppend(ctx context.Context, executionID string, index int, sequence int, value *adapter.Value) error @@ -486,7 +487,7 @@ func (p *Pool) handleStreamRegister(ctx context.Context, executionID string, par return } - if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Index); err != nil { + if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Index, req.Buffer); err != nil { logger.Error("failed to register stream", "error", err) } } diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index 0e48420e..19b1ee4c 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -362,6 +362,7 @@ func (w *Worker) runConnection(ctx context.Context, targets map[string]map[strin conn.RegisterHandler("abort", w.handleAbort) conn.RegisterHandler("stream_items", w.handleStreamItems) conn.RegisterHandler("stream_closed", w.handleStreamClosed) + conn.RegisterHandler("stream_demand", w.handleStreamDemand) conn.SetOnSession(w.handleSession) if err := conn.Connect(ctx); err != nil { @@ -717,6 +718,33 @@ func (w *Worker) handleStreamClosed(params []any) error { return w.pool.PushToExecutor(executionID, "stream_closed", forwarded) } +// handleStreamDemand forwards a server-pushed demand grant to the producer +// adapter. Params: [execution_id, index, n]. The producer's StreamDriver +// adds ``n`` to its per-stream credit counter and wakes any waiting +// worker thread. +func (w *Worker) handleStreamDemand(params []any) error { + if len(params) < 3 { + return fmt.Errorf("stream_demand: insufficient params") + } + executionID, ok := params[0].(string) + if !ok { + return fmt.Errorf("stream_demand: execution_id is not a string (got %T)", params[0]) + } + index, ok := params[1].(float64) + if !ok { + return fmt.Errorf("stream_demand: index is not a number (got %T)", params[1]) + } + n, ok := params[2].(float64) + if !ok { + return fmt.Errorf("stream_demand: n is not a number (got %T)", params[2]) + } + return w.pool.PushToExecutor(executionID, "stream_demand", map[string]any{ + "execution_id": executionID, + "index": int(index), + "n": int(n), + }) +} + func (w *Worker) heartbeatLoop(ctx context.Context) { ticker := time.NewTicker(heartbeatInterval) defer ticker.Stop() @@ -1181,12 +1209,17 @@ func (w *Worker) RegisterGroup(ctx context.Context, executionID string, groupID return conn.Notify("register_group", executionID, groupID, name) } -func (w *Worker) StreamRegister(ctx context.Context, executionID string, index int) error { +func (w *Worker) StreamRegister(ctx context.Context, executionID string, index int, buffer *int) error { conn, err := w.requireConn() if err != nil { return err } - return conn.Notify("stream_register", executionID, index) + // The wire protocol takes buffer positionally; nil encodes to JSON null, + // which the server interprets as "no backpressure" (unbounded). + if buffer == nil { + return conn.Notify("stream_register", executionID, index, nil) + } + return conn.Notify("stream_register", executionID, index, *buffer) } func (w *Worker) StreamAppend(ctx context.Context, executionID string, index int, sequence int, value *adapter.Value) error { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index 50702d32..6af16c36 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -254,10 +254,17 @@ defmodule Coflux.Handlers.Worker do end "stream_register" -> - [execution_id, index] = message["params"] + [execution_id, index | rest] = message["params"] + buffer = List.first(rest) if is_recognised_execution?(execution_id, state) do - case Orchestration.register_stream(state.project_id, execution_id, index) do + case Orchestration.register_stream( + state.project_id, + execution_id, + index, + buffer, + state.session_id + ) do :ok -> {[], state} # Idempotent — a duplicate register is harmless. {:error, :already_registered} -> {[], state} @@ -639,6 +646,10 @@ defmodule Coflux.Handlers.Worker do {[command_message("stream_items", [execution_external_id, subscription_id, encoded])], state} end + def websocket_info({:stream_demand, execution_external_id, index, n}, state) do + {[command_message("stream_demand", [execution_external_id, index, n])], state} + end + def websocket_info({:stream_closed, execution_external_id, subscription_id, error}, state) do {[command_message("stream_closed", [execution_external_id, subscription_id, error])], state} end diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index 1053dac0..626e28b1 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -186,8 +186,11 @@ defmodule Coflux.Orchestration do # producer execution; `sequence` identifies an item within the stream. # Both are worker-assigned and monotonic from 0. - def register_stream(project_id, execution_id, index) do - call_server(project_id, {:register_stream, execution_id, index}) + def register_stream(project_id, execution_id, index, buffer, session_id) do + call_server( + project_id, + {:register_stream, execution_id, index, buffer, session_id} + ) end def append_stream_item(project_id, execution_id, index, sequence, value) do diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 5eac0732..1858f647 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -136,7 +136,24 @@ defmodule Coflux.Orchestration.Server do # stream_subscribers: {producer_execution_id, index} -> MapSet of # {consumer_execution_id, subscription_id} stream_subscriptions: %{}, - stream_subscribers: %{} + stream_subscribers: %{}, + + # Per-stream producer state for backpressure. Only present + # when the producer opted in by registering with a non-nil + # buffer. Keyed by {producer_execution_id, index}. + # + # %{buffer, demand_granted, session_id, execution_external_id} + # + # * buffer — configured backpressure budget + # * demand_granted — cumulative credits sent so far + # * session_id — where to route stream_demand + # * execution_external_id — external id for the command wire + # + # The current max_cursor across adapter subscribers is + # recomputed from stream_subscribers on demand rather than + # cached here, since it changes often and is cheap to + # derive. + stream_producers: %{} end def start_link(opts) do @@ -1830,13 +1847,32 @@ defmodule Coflux.Orchestration.Server do end end - def handle_call({:register_stream, execution_external_id, index}, _from, state) do + def handle_call( + {:register_stream, execution_external_id, index, buffer, session_external_id}, + _from, + state + ) do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> - case Streams.register_stream(state.db, execution_id, index) do + case Streams.register_stream(state.db, execution_id, index, buffer) do {:ok, created_at} -> + # Resolve the session's external id to the internal one — + # send_session (which delivers stream_demand) indexes by the + # internal id. + internal_session_id = + Map.get(state.session_ids, session_external_id) + state = - notify_stream_opened(state, execution_id, index, created_at) + state + |> maybe_init_stream_producer( + execution_id, + execution_external_id, + index, + buffer, + internal_session_id + ) + |> notify_stream_opened(execution_id, index, buffer, created_at) + |> maybe_send_initial_demand(execution_id, index) |> flush_notifications() {:reply, :ok, state} @@ -1850,6 +1886,151 @@ defmodule Coflux.Orchestration.Server do end end + defp maybe_init_stream_producer( + state, + _execution_id, + _execution_external_id, + _index, + nil, + _session_id + ) do + # buffer=nil means the producer has opted out of backpressure — no + # tracking required on the server side. It'll emit freely and the + # adapter's driver never waits. + state + end + + defp maybe_init_stream_producer( + state, + execution_id, + execution_external_id, + index, + buffer, + session_id + ) + when is_integer(buffer) and buffer >= 0 do + put_in(state.stream_producers[{execution_id, index}], %{ + buffer: buffer, + demand_granted: 0, + session_id: session_id, + execution_external_id: execution_external_id + }) + end + + defp maybe_send_initial_demand(state, execution_id, index) do + # At registration time there are no subscribers yet. Allow the + # producer to pre-warm up to `buffer` items; lockstep (buffer=0) + # stays paused until a consumer attaches. + refresh_stream_demand(state, {execution_id, index}) + end + + # Recompute the target demand for one stream and, if it's grown, + # send a delta grant to the producer's session. + # + # Formula: + # target = max_cursor + buffer + (1 if has_subscribers else 0) + # The +1 on subscriber presence unblocks lockstep streams — a + # consumer's cursor at position N means "ready for item N", which is + # one item beyond what they've acked. + # + # demand_granted is monotonic; if target drops (e.g. the fastest + # consumer left) we don't claw back, future grants just wait until + # the remaining subscribers catch up past the old max. + defp refresh_stream_demand(state, {_execution_id, index} = key) do + case Map.fetch(state.stream_producers, key) do + :error -> + state + + {:ok, producer} -> + has_subscribers = has_stream_subscribers?(state, key) + max_cursor = current_max_cursor(state, key) + bump = if has_subscribers, do: 1, else: 0 + target = max_cursor + producer.buffer + bump + delta = target - producer.demand_granted + + if delta > 0 do + state + |> put_in([Access.key(:stream_producers), key, :demand_granted], target) + |> send_session( + producer.session_id, + {:stream_demand, producer.execution_external_id, index, delta} + ) + else + state + end + end + end + + defp has_stream_subscribers?(state, key) do + case Map.get(state.stream_subscribers, key) do + nil -> false + set -> MapSet.size(set) > 0 + end + end + + defp current_max_cursor(state, key) do + state.stream_subscribers + |> Map.get(key, MapSet.new()) + |> Enum.reduce(0, fn sub_key, acc -> + case Map.get(state.stream_subscriptions, sub_key) do + nil -> acc + sub -> max(acc, sub.cursor) + end + end) + end + + defp drop_stream_producer(state, key) do + Map.update!(state, :stream_producers, &Map.delete(&1, key)) + end + + # Lazily rebuild stream_producer state from the DB if it's missing. + # Used after server restart — in-memory producer state is gone but + # the ``streams`` table still has the buffer. We rebuild on first + # append or subscribe for a given stream, recovering flow control. + # + # ``session_id`` is the internal id of the producer's current session; + # supply ``nil`` if not known, in which case demand grants will be + # deferred until the session is resolvable. + defp ensure_stream_producer( + state, + execution_id, + execution_external_id, + index, + session_id + ) do + key = {execution_id, index} + + cond do + Map.has_key?(state.stream_producers, key) -> + state + + true -> + case Streams.get_buffer(state.db, execution_id, index) do + {:ok, nil} -> + # Stream opted out of backpressure; nothing to track. + state + + {:ok, buffer} when is_integer(buffer) -> + # Reconstruct state. demand_granted starts at items already + # produced — we assume earlier-us granted enough for those, + # and rely on the producer having kept its local credit + # counter consistent. + {:ok, head} = Streams.get_stream_head(state.db, execution_id, index) + items_produced = if head < 0, do: 0, else: head + 1 + + put_in(state.stream_producers[key], %{ + buffer: buffer, + demand_granted: items_produced, + session_id: session_id, + execution_external_id: execution_external_id + }) + + {:error, :not_found} -> + state + end + end + end + def handle_call( {:append_stream_item, execution_external_id, index, sequence, value}, _from, @@ -1865,8 +2046,24 @@ defmodule Coflux.Orchestration.Server do normalize_value(value) ) do {:ok, created_at} -> + # If we came out of a server restart with no in-memory + # producer state for this stream, rebuild it now from the + # persisted buffer so subsequent consumer advances can + # refresh demand. The appending session is the producer. + producer_session_id = + case find_session_for_execution(state, execution_external_id) do + {:ok, sid} -> sid + :error -> nil + end + state = state + |> ensure_stream_producer( + execution_id, + execution_external_id, + index, + producer_session_id + ) |> push_stream_item(execution_id, index, sequence, value) |> notify_stream_item_appended( execution_id, @@ -1906,6 +2103,7 @@ defmodule Coflux.Orchestration.Server do state |> push_stream_closed(execution_id, index, error) |> notify_stream_closed(execution_id, index, reason, error, closed_at) + |> drop_stream_producer({execution_id, index}) |> flush_notifications() {:reply, :ok, state} @@ -1963,6 +2161,31 @@ defmodule Coflux.Orchestration.Server do ) end) + # Post-restart recovery: producer state may be missing. The + # producer's session isn't necessarily the one the subscribe + # came from — look it up across sessions by external execution + # id. + producer_session_id = + case find_session_for_execution(state, producer_execution_external_id) do + {:ok, sid} -> sid + :error -> nil + end + + state = + ensure_stream_producer( + state, + producer_execution_id, + producer_execution_external_id, + index, + producer_session_id + ) + + # First subscriber (or a later one whose cursor exceeds the prior + # max) may unblock the producer — recompute demand before pushing + # backlog so any delivered items keep the credit maths honest. + state = + refresh_stream_demand(state, {producer_execution_id, index}) + # Push any items already in the log that match the filter, then (if # the stream has already closed) the terminal close record. state = push_backlog(state, key) @@ -3921,11 +4144,13 @@ defmodule Coflux.Orchestration.Server do defp remove_session(state, session_id) do {:ok, _} = Sessions.expire_session(state.db, session_id) - {session, state} = pop_in(state.sessions[session_id]) - state = Map.update!(state, :session_expiries, &Map.delete(&1, session_id)) # Drop any stream subscriptions this session held — consumer has gone - # away, so there's no one to push to. + # away, so there's no one to push to. Do this before popping the + # session from state.sessions since drop_session_subscriptions reads + # the session's live execution set. state = drop_session_subscriptions(state, session_id) + {session, state} = pop_in(state.sessions[session_id]) + state = Map.update!(state, :session_expiries, &Map.delete(&1, session_id)) # starting/executing now contain external IDs - resolve to internal for process_result. # Session removal means no more notify_terminated for these executions, so we @@ -5916,6 +6141,7 @@ defmodule Coflux.Orchestration.Server do state |> push_stream_closed(execution_id, index, push_error) |> notify_stream_closed(execution_id, index, :lifecycle, push_error, closed_at) + |> drop_stream_producer({execution_id, index}) {:error, :already_closed} -> state @@ -5925,22 +6151,23 @@ defmodule Coflux.Orchestration.Server do # Returns the streams list for `execution_id` with :lifecycle closures' # errors resolved from the execution's recorded result. Shape: - # `{index, opened_at, closed_at | nil, reason | nil, error | nil}` — + # `{index, buffer, opened_at, closed_at | nil, reason | nil, error | nil}` — # reason is retained so the topic can colour open vs complete vs - # errored vs lifecycle distinctly. + # errored vs lifecycle distinctly; buffer is passed through for the + # topic to display. defp streams_with_resolved_errors(db, execution_id) do {:ok, rows} = Streams.get_streams_with_closures_for_execution(db, execution_id) Enum.map(rows, fn - {index, opened_at, nil, nil, nil} -> - {index, opened_at, nil, nil, nil} + {index, buffer, opened_at, nil, nil, nil} -> + {index, buffer, opened_at, nil, nil, nil} - {index, opened_at, closed_at, :lifecycle, _} -> - {index, opened_at, closed_at, :lifecycle, + {index, buffer, opened_at, closed_at, :lifecycle, _} -> + {index, buffer, opened_at, closed_at, :lifecycle, derive_lifecycle_error(db, execution_id)} - {index, opened_at, closed_at, reason, error} -> - {index, opened_at, closed_at, reason, error} + {index, buffer, opened_at, closed_at, reason, error} -> + {index, buffer, opened_at, closed_at, reason, error} end) end @@ -7265,14 +7492,14 @@ defmodule Coflux.Orchestration.Server do # items on demand. @stream_topic_tail_size 200 - defp notify_stream_opened(state, execution_id, index, created_at) do + defp notify_stream_opened(state, execution_id, index, buffer, created_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) notify_listeners( state, {:run, r}, - {:stream_opened, execution_ext_id, index, created_at} + {:stream_opened, execution_ext_id, index, buffer, created_at} ) end @@ -7336,6 +7563,7 @@ defmodule Coflux.Orchestration.Server do with {:ok, execution_id} <- resolve_internal_execution_id(state, execution_ext_id), {:ok, true} <- Streams.exists?(state.db, execution_id, index), {:ok, opened_at} <- Streams.get_opened_at(state.db, execution_id, index), + {:ok, buffer} <- Streams.get_buffer(state.db, execution_id, index), {:ok, {items, total_count}} <- Streams.get_stream_tail(state.db, execution_id, index, @stream_topic_tail_size) do # Keep the tuple shape here — the topic module runs TopicUtils.build_value @@ -7351,6 +7579,7 @@ defmodule Coflux.Orchestration.Server do {:ok, %{ producer: build_stream_producer(state.db, execution_ext_id, execution_id), + buffer: buffer, openedAt: opened_at, closure: closure, items: resolved_items, @@ -7521,20 +7750,26 @@ defmodule Coflux.Orchestration.Server do &Map.put(&1, :cursor, advance_to) ) - # If the filter is now exhausted (slice's stop reached), close the - # subscription synchronously — matches push_stream_item's behaviour. - # Without this, a consumer that subscribed after appends with a - # bounded filter would wait forever for a close that never comes. - if filter_exhausted?(sub.filter, advance_to) do - state - |> send_to_consumer( - sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} - ) - |> drop_subscription(key) - else - state - end + state = + if filter_exhausted?(sub.filter, advance_to) do + # If the filter is now exhausted (slice's stop reached), close the + # subscription synchronously — matches push_stream_item's + # behaviour. Without this, a consumer that subscribed after + # appends with a bounded filter would wait forever for a close + # that never comes. + state + |> send_to_consumer( + sub, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + ) + |> drop_subscription(key) + else + state + end + + # The backlog push moved this consumer's cursor forward — may + # unblock the producer if their buffer has room now. + refresh_stream_demand(state, {sub.producer_execution_id, sub.index}) end # Resolve the consumer's current session and send, skipping if the @@ -7548,55 +7783,61 @@ defmodule Coflux.Orchestration.Server do # Push a freshly-appended item to every subscriber of this stream. defp push_stream_item(state, producer_execution_id, index, sequence, value) do - subscribers = - Map.get(state.stream_subscribers, {producer_execution_id, index}, MapSet.new()) + stream_key = {producer_execution_id, index} - Enum.reduce(subscribers, state, fn key, state -> - {_consumer_execution_id, subscription_id} = key - sub = Map.fetch!(state.stream_subscriptions, key) + subscribers = Map.get(state.stream_subscribers, stream_key, MapSet.new()) - cond do - sequence < sub.cursor -> - # Consumer already has this sequence via backlog; skip. - state + state = + Enum.reduce(subscribers, state, fn key, state -> + {_consumer_execution_id, subscription_id} = key + sub = Map.fetch!(state.stream_subscriptions, key) - not filter_matches?(sub.filter, sequence) -> - state + cond do + sequence < sub.cursor -> + # Consumer already has this sequence via backlog; skip. + state - true -> - # Value came off the wire in parse form (ext-id refs, no metadata). - # Normalise + resolve to match the form push_backlog sends; the WS - # handler composes to wire JSON. - resolved = build_value(normalize_value(value), state.db) - item = [sequence, resolved] + not filter_matches?(sub.filter, sequence) -> + state - state = - send_to_consumer( - state, - sub, - {:stream_items, sub.consumer_execution_external_id, subscription_id, [item]} - ) + true -> + # Value came off the wire in parse form (ext-id refs, no metadata). + # Normalise + resolve to match the form push_backlog sends; the WS + # handler composes to wire JSON. + resolved = build_value(normalize_value(value), state.db) + item = [sequence, resolved] - state = - update_in( - state.stream_subscriptions[key], - &Map.put(&1, :cursor, sequence + 1) - ) + state = + send_to_consumer( + state, + sub, + {:stream_items, sub.consumer_execution_external_id, subscription_id, [item]} + ) - # If the filter is exhausted (e.g. slice reached its stop), close - # the subscription early — no more items will match. - if filter_exhausted?(sub.filter, sequence + 1) do - state - |> send_to_consumer( - sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} - ) - |> drop_subscription(key) - else - state - end - end - end) + state = + update_in( + state.stream_subscriptions[key], + &Map.put(&1, :cursor, sequence + 1) + ) + + # If the filter is exhausted (e.g. slice reached its stop), close + # the subscription early — no more items will match. + if filter_exhausted?(sub.filter, sequence + 1) do + state + |> send_to_consumer( + sub, + {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + ) + |> drop_subscription(key) + else + state + end + end + end) + + # Subscriber cursors may have advanced — recompute demand once per + # stream (cheaper than once per subscriber, same result). + refresh_stream_demand(state, stream_key) end # On close, tell every subscriber. Error is either nil (clean close) or a diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 9d5434de..943bf173 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -34,14 +34,17 @@ defmodule Coflux.Orchestration.Streams do alias Coflux.Orchestration.{Errors, Values} # Registers a new stream owned by `execution_id` at `index` (monotonic - # per-execution, worker-assigned). Returns `{:error, :already_registered}` - # if the index was already used. - def register_stream(db, execution_id, index) do + # per-execution, worker-assigned). ``buffer`` is the persisted flow- + # control budget — ``nil`` means no backpressure, integer N means the + # producer may be up to N items ahead of the fastest consumer. Returns + # ``{:error, :already_registered}`` if the index was already used. + def register_stream(db, execution_id, index, buffer \\ nil) do now = current_timestamp() case insert_one(db, :streams, %{ execution_id: execution_id, index: index, + buffer: buffer, created_at: now }) do {:ok, _} -> {:ok, now} @@ -49,6 +52,20 @@ defmodule Coflux.Orchestration.Streams do end end + # Returns the persisted buffer for a stream. Result is ``{:ok, buffer}`` + # where ``buffer`` is either an integer or ``nil`` (no backpressure). + # ``{:error, :not_found}`` if the stream doesn't exist. + def get_buffer(db, execution_id, index) do + case query_one( + db, + "SELECT buffer FROM streams WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} + ) do + {:ok, nil} -> {:error, :not_found} + {:ok, {buffer}} -> {:ok, buffer} + end + end + # Appends an item at `sequence` to the stream. Caller supplies the sequence # (worker-assigned, monotonic). Returns: # * `{:error, :not_registered}` if the stream doesn't exist @@ -256,7 +273,8 @@ defmodule Coflux.Orchestration.Streams do end # Returns one row per stream owned by `execution_id`: - # `{index, created_at, closed_at | nil, reason | nil, error | nil}`. + # `{index, buffer, created_at, closed_at | nil, reason | nil, error | nil}`. + # * buffer is the persisted backpressure budget (integer or nil) # * reason is :complete | :errored | :lifecycle when closed, nil when open # * error is the stored `{type, message, frames}` triple for :errored # closures only — callers that need to surface an error for a @@ -266,7 +284,7 @@ defmodule Coflux.Orchestration.Streams do case query( db, """ - SELECT s.`index`, s.created_at, c.created_at, c.reason, c.error_id + SELECT s.`index`, s.buffer, s.created_at, c.created_at, c.reason, c.error_id FROM streams AS s LEFT JOIN stream_closures AS c ON c.execution_id = s.execution_id AND c.`index` = s.`index` @@ -278,15 +296,17 @@ defmodule Coflux.Orchestration.Streams do {:ok, rows} -> streams = Enum.map(rows, fn - {index, created_at, nil, nil, nil} -> - {index, created_at, nil, nil, nil} + {index, buffer, created_at, nil, nil, nil} -> + {index, buffer, created_at, nil, nil, nil} - {index, created_at, closed_at, reason_int, nil} -> - {index, created_at, closed_at, reason_from_int(reason_int), nil} + {index, buffer, created_at, closed_at, reason_int, nil} -> + {index, buffer, created_at, closed_at, reason_from_int(reason_int), nil} - {index, created_at, closed_at, reason_int, error_id} -> + {index, buffer, created_at, closed_at, reason_int, error_id} -> {:ok, error} = Errors.get_by_id(db, error_id) - {index, created_at, closed_at, reason_from_int(reason_int), error} + + {index, buffer, created_at, closed_at, reason_from_int(reason_int), + error} end) {:ok, streams} diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index 88fc45e3..11519c98 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -209,10 +209,11 @@ defmodule Coflux.Topics.Run do defp process_notification( topic, - {:stream_opened, execution_external_id, index, created_at} + {:stream_opened, execution_external_id, index, buffer, created_at} ) do update_execution(topic, execution_external_id, fn topic, base_path -> Topic.set(topic, base_path ++ [:streams, Integer.to_string(index)], %{ + buffer: buffer, openedAt: created_at, closedAt: nil, reason: nil, @@ -551,22 +552,24 @@ defmodule Coflux.Topics.Run do defp build_streams(streams) do Map.new(streams, fn - {index, opened_at, nil, nil, nil} -> + {index, buffer, opened_at, nil, nil, nil} -> {Integer.to_string(index), - %{openedAt: opened_at, closedAt: nil, reason: nil, error: nil}} + %{buffer: buffer, openedAt: opened_at, closedAt: nil, reason: nil, error: nil}} - {index, opened_at, closed_at, reason, nil} -> + {index, buffer, opened_at, closed_at, reason, nil} -> {Integer.to_string(index), %{ + buffer: buffer, openedAt: opened_at, closedAt: closed_at, reason: Atom.to_string(reason), error: nil }} - {index, opened_at, closed_at, reason, {type, message, _frames}} -> + {index, buffer, opened_at, closed_at, reason, {type, message, _frames}} -> {Integer.to_string(index), %{ + buffer: buffer, openedAt: opened_at, closedAt: closed_at, reason: Atom.to_string(reason), diff --git a/server/lib/coflux/topics/stream.ex b/server/lib/coflux/topics/stream.ex index d97f34d4..05b150f6 100644 --- a/server/lib/coflux/topics/stream.ex +++ b/server/lib/coflux/topics/stream.ex @@ -39,6 +39,7 @@ defmodule Coflux.Topics.Stream do Topic.new( %{ producer: initial.producer, + buffer: initial.buffer, openedAt: initial.openedAt, closure: build_closure(initial.closure), items: Enum.map(initial.items, &build_item/1), diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index 7910e69f..022e44f9 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -41,6 +41,12 @@ INSERT INTO completions (execution_id, created_at) CREATE TABLE streams ( execution_id INTEGER NOT NULL, `index` INTEGER NOT NULL, + -- Producer-side backpressure budget. NULL opts out of flow control + -- (producer emits freely). Integer N means the producer may run up + -- to N items ahead of the fastest consumer; N=0 is strict lockstep. + -- Persisted so the server can reconstruct per-stream flow-control + -- state on restart and so Studio can display the configuration. + buffer INTEGER, created_at INTEGER NOT NULL, PRIMARY KEY (execution_id, `index`), FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE diff --git a/tests/support/executor.py b/tests/support/executor.py index 08069279..07a1cd81 100644 --- a/tests/support/executor.py +++ b/tests/support/executor.py @@ -258,9 +258,9 @@ def resolve_input( # --- Stream producer helpers --- - def stream_register(self, execution_id, index): - """Notify that a new stream exists.""" - self.send(protocol.stream_register(execution_id, index)) + def stream_register(self, execution_id, index, buffer=None): + """Notify that a new stream exists. ``buffer`` enables backpressure.""" + self.send(protocol.stream_register(execution_id, index, buffer=buffer)) def stream_append(self, execution_id, index, sequence, value, format="json"): """Append an item (raw JSON value) to a stream.""" diff --git a/tests/support/protocol.py b/tests/support/protocol.py index 62c25384..cf7166a9 100644 --- a/tests/support/protocol.py +++ b/tests/support/protocol.py @@ -234,11 +234,11 @@ def register_group_notification(execution_id, group_id, name=None): # --- Stream messages (producer side: adapter → server) --- -def stream_register(execution_id, index): - return { - "method": "stream_register", - "params": {"execution_id": execution_id, "index": index}, - } +def stream_register(execution_id, index, buffer=None): + params = {"execution_id": execution_id, "index": index} + if buffer is not None: + params["buffer"] = buffer + return {"method": "stream_register", "params": params} def stream_append(execution_id, index, sequence, value, format="json"): diff --git a/tests/test_streams.py b/tests/test_streams.py index 1e6f3dbe..f64cc9fc 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -14,6 +14,8 @@ taking turns over different connections. """ +import pytest + from support.manifest import workflow from support.protocol import ( execution_result, @@ -622,3 +624,118 @@ def test_filter_chain_combines_slice_and_partition(worker): cons_ex.conn.complete(cons_ex.execution_id) assert [item[0] for item in items] == [0, 2, 4] + + +# --- Backpressure ------------------------------------------------------- + + +def test_backpressure_no_buffer_no_initial_demand(worker): + """Registering with buffer=0 and no subscribers: the server sends no + demand grants — the producer stays paused until a consumer attaches. + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0, buffer=0) + + # Producer shouldn't get any demand grants yet. + with pytest.raises(TimeoutError): + prod_ex.conn.recv_push("stream_demand", timeout=0.5) + + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + +def test_backpressure_prewarms_up_to_buffer(worker): + """buffer=N without any subscribers: producer is granted N credits + up front so it can run ahead and pre-warm. + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0, buffer=5) + + params = prod_ex.conn.recv_push("stream_demand", timeout=2) + assert params["index"] == 0 + assert params["n"] == 5 + + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + +def test_backpressure_subscribe_unblocks_producer(worker): + """buffer=0 + consumer subscribes → server grants 1 credit. Producer + emits. Consumer reads → cursor advances → server grants 1 more. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0, buffer=0) + + # No consumer yet → no demand. + with pytest.raises(TimeoutError): + prod_ex.conn.recv_push("stream_demand", timeout=0.3) + + # Attach consumer. First grant arrives. + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + index=0, + ) + first = prod_ex.conn.recv_push("stream_demand", timeout=2) + assert first["n"] == 1 + + # Producer emits item 0 (it's the adapter's responsibility to + # decrement credits; we just emulate that here by appending). + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "hi") + + # Consumer receives item → cursor advances → server grants again. + cons_ex.conn.recv_push("stream_items", subscription_id=1, timeout=2) + second = prod_ex.conn.recv_push("stream_demand", timeout=2) + assert second["n"] == 1 + + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + cons_ex.conn.complete(cons_ex.execution_id) + ctx.result(prod_resp["runId"]) + + +def test_backpressure_unbounded_sends_no_demand(worker): + """Registering without a buffer (wire buffer=null) opts out of + backpressure — the server never sends demand grants for this stream. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) # buffer omitted + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + index=0, + ) + + # Even with a consumer attached, no demand grant should fire. + with pytest.raises(TimeoutError): + prod_ex.conn.recv_push("stream_demand", timeout=0.5) + + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + cons_ex.conn.complete(cons_ex.execution_id) + ctx.result(prod_resp["runId"]) From 4932cb34d2a08529887495337256865f5e86cc31 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sun, 19 Apr 2026 23:14:16 +0100 Subject: [PATCH 15/25] Use completions to determine execution state --- adapters/python/coflux/errors.py | 79 ++- adapters/python/coflux/streams.py | 35 +- cli/internal/adapter/protocol.go | 10 +- cli/internal/worker/worker.go | 28 +- server/lib/coflux/handlers/worker.ex | 26 +- server/lib/coflux/orchestration/epoch.ex | 264 ++++---- server/lib/coflux/orchestration/inputs.ex | 16 +- server/lib/coflux/orchestration/results.ex | 693 ++++++++++++++------- server/lib/coflux/orchestration/runs.ex | 97 ++- server/lib/coflux/orchestration/server.ex | 658 +++++++++++++------ server/lib/coflux/topics/run.ex | 11 +- server/priv/migrations/orchestration/4.sql | 115 +++- tests/test_streams.py | 15 +- 13 files changed, 1394 insertions(+), 653 deletions(-) diff --git a/adapters/python/coflux/errors.py b/adapters/python/coflux/errors.py index e03f73f7..dd571ce2 100644 --- a/adapters/python/coflux/errors.py +++ b/adapters/python/coflux/errors.py @@ -163,46 +163,45 @@ def create_execution_error(error_type: str, error_message: str) -> ExecutionErro ) -# --- Stream error helpers --- - -# Server synthesises these types when closing streams due to the producer's -# disposition (cancel, crash, etc.). The wire carries them as regular -# {type, message, frames} errors; we route them here to specific -# ExecutionTerminated subclasses rather than generic ExecutionError. -_STREAM_SYNTHETIC_ERRORS: dict[str, type[Exception]] = { - "Coflux.ExecutionCancelled": ExecutionCancelled, - "Coflux.ExecutionAbandoned": ExecutionAbandoned, - "Coflux.ExecutionCrashed": ExecutionCrashed, - "Coflux.ExecutionErrored": ExecutionError, -} - - -def create_stream_error(error: dict) -> Exception: - """Build an exception for a stream closure. - - Server-synthesised types (``Coflux.Execution*``) map to - ``ExecutionTerminated`` subclasses. Real user exceptions (raised by the - producer's generator) go through ``create_execution_error`` and get - the producer's frames attached as ``.frames`` for debuggability. +# --- Stream close handling --- + + +def raise_for_close(reason: str, error: dict | None) -> None: + """Raise the appropriate exception for a stream-closed event. + + The server carries a semantic reason atom rather than a fabricated + exception type; we map it to the Python exception idiomatic for each + case. Only ``"errored"`` carries an error dict — the producer's + actual exception; other reasons raise a corresponding + ``ExecutionTerminated`` subclass with no further payload. """ - error_type = error.get("type", "") - error_message = error.get("message", "") - frames = error.get("frames") or [] + if reason == "complete": + return - synthetic = _STREAM_SYNTHETIC_ERRORS.get(error_type) - if synthetic is ExecutionError: - exc = ExecutionError( - error_message, - error_type=error_type, - error_message=error_message, - ) - elif synthetic is not None: - exc = synthetic() - else: + if reason == "errored" and error is not None: + error_type = error.get("type", "") + error_message = error.get("message", "") + frames = error.get("frames") or [] exc = create_execution_error(error_type, error_message) - - if frames: - # Frames are [file, line, name, code] lists — the same wire shape - # used by execution errors. Expose on the exception for inspection. - exc.frames = frames - return exc + if frames: + exc.frames = frames + raise exc + + if reason == "cancelled": + raise ExecutionCancelled() + if reason == "abandoned": + raise ExecutionAbandoned() + if reason == "crashed": + raise ExecutionCrashed() + if reason == "timeout": + raise ExecutionTimeout() + + # Anything else (e.g. "not_found", "already_subscribed", or an + # unknown future reason) is a subscription problem rather than a + # producer-terminal event — surface as a generic ExecutionError so + # the consumer is at least aware the stream ended abnormally. + raise ExecutionError( + reason, + error_type="Coflux.StreamClosed", + error_message=reason, + ) diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index feaf3d35..defc37ba 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -29,7 +29,7 @@ from . import protocol from .dispatcher import get_dispatcher -from .errors import create_stream_error +from .errors import raise_for_close from .serialization import deserialize_value, serialize_value from .state import get_context @@ -346,12 +346,16 @@ async def _close(g=gen) -> None: # --- Consumer side --- -# Sentinel pushed onto a subscriber's queue to signal close. Carries the -# optional error dict ({"type": str, "message": str} or None). +# Sentinel pushed onto a subscriber's queue to signal close. `reason` +# is the semantic close reason (``"complete"`` / ``"errored"`` / +# ``"cancelled"`` / ``"abandoned"`` / ``"crashed"`` / ``"timeout"`` / +# ``"not_found"``). ``error`` is only populated when ``reason == +# "errored"`` — it's the producer's actual ``{type, message, frames}``. class _Closed: - __slots__ = ("error",) + __slots__ = ("reason", "error") - def __init__(self, error: dict[str, Any] | None) -> None: + def __init__(self, reason: str, error: dict[str, Any] | None) -> None: + self.reason = reason self.error = error @@ -376,9 +380,9 @@ def on_items(self, items: list[list[Any]]) -> None: for _sequence, value in items: self._queue.put(value) - def on_closed(self, error: dict[str, Any] | None) -> None: + def on_closed(self, reason: str, error: dict[str, Any] | None) -> None: """Called by the registry when the stream closes.""" - self._queue.put(_Closed(error)) + self._queue.put(_Closed(reason, error)) def __iter__(self) -> "_StreamIterator": return self @@ -400,8 +404,7 @@ def __next__(self) -> Any: ) except Exception: pass - if item.error is not None: - raise create_stream_error(item.error) + raise_for_close(item.reason, item.error) raise StopIteration return deserialize_value(item) @@ -430,16 +433,13 @@ def _ensure_installed(self) -> None: self._installed = True def _on_dispatcher_closed(self) -> None: - """Wake all active iterators with a connection-closed error.""" - error = { - "type": "Coflux.ExecutionAbandoned", - "message": "connection closed", - "frames": [], - } + """Wake all active iterators — connection to the server is gone + so no close message is going to arrive. Treat as ``abandoned`` + (we don't know anything more specific from this side).""" with self._lock: iterators = list(self._iterators.values()) for it in iterators: - it.on_closed(error) + it.on_closed("abandoned", None) def allocate(self, execution_id: str) -> tuple[int, _StreamIterator]: """Claim a subscription id and iterator.""" @@ -465,11 +465,12 @@ def _on_items(self, params: dict[str, Any]) -> None: def _on_closed(self, params: dict[str, Any]) -> None: subscription_id = params.get("subscription_id") + reason = params.get("reason") or "complete" error = params.get("error") with self._lock: it = self._iterators.get(subscription_id) if it is not None: - it.on_closed(error) + it.on_closed(reason, error) _registry_instance: StreamRegistry | None = None diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index 1679ac60..72ba0db6 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -323,10 +323,18 @@ type StreamItemsParams struct { } // StreamClosedParams for stream_closed notification pushed CLI → adapter. -// Error is nil for clean close or a {type, message} dict for errored close. +// +// `Reason` is a semantic string ("complete" / "errored" / "cancelled" / +// "abandoned" / "crashed" / "timeout" / "not_found" / ...). The adapter +// maps it to whatever exception/return value is idiomatic in the target +// language — the CLI doesn't fabricate types. +// +// `Error` is non-nil only when `Reason == "errored"`, carrying the +// producer's actual `{type, message, frames}`. type StreamClosedParams struct { ExecutionID string `json:"execution_id"` SubscriptionID int `json:"subscription_id"` + Reason string `json:"reason"` Error map[string]any `json:"error,omitempty"` } diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index 19b1ee4c..aa8036d8 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -693,9 +693,14 @@ func (w *Worker) handleStreamItems(params []any) error { } // handleStreamClosed forwards a server-pushed stream-closed notification. -// Params: [execution_id, subscription_id, error_or_null]. +// Params: [execution_id, subscription_id, reason, error_or_null]. +// +// `reason` is a string like "complete" / "errored" / "cancelled" / +// "abandoned" / "crashed" / "timeout" / "not_found" — the adapter +// decides how to represent each in its language idiom rather than the +// server fabricating exception types. func (w *Worker) handleStreamClosed(params []any) error { - if len(params) < 3 { + if len(params) < 4 { return fmt.Errorf("stream_closed: insufficient params") } executionID, ok := params[0].(string) @@ -706,11 +711,13 @@ func (w *Worker) handleStreamClosed(params []any) error { if !ok { return fmt.Errorf("stream_closed: subscription_id is not a number (got %T)", params[1]) } - errField := params[2] + reason, _ := params[2].(string) + errField := params[3] forwarded := map[string]any{ "execution_id": executionID, "subscription_id": int(subscriptionID), + "reason": reason, } if errField != nil { forwarded["error"] = errField @@ -1728,6 +1735,15 @@ func (w *Worker) trySendResult(executionID string) { // Should only be called after the result has been queued to sendCh (either // via the write callback chain or from flushPending), so that FIFO ordering // ensures the result message precedes notify_terminated. +// +// The execution entry stays in w.executions (with pendingTerminated = true) +// even after a successful send — "successful" here means the local write +// didn't error, which doesn't guarantee delivery when the underlying TCP +// connection is failing silently. The authoritative signal that the +// server received the termination is the next session message: if the +// execution isn't in the server's known set, handleSession drops it. +// Until then, a reconnect triggers flushPending which re-sends both the +// buffered result (via trySendResult) and this termination. func (w *Worker) trySendTerminated(executionID string) { conn := w.getConn() if conn == nil || !conn.IsConnected() { @@ -1746,10 +1762,6 @@ func (w *Worker) trySendTerminated(executionID string) { w.logger.Warn("failed to send terminated, will retry on reconnect", "execution_id", executionID, "error", err) return } - - w.mu.Lock() - delete(w.executions, executionID) - w.mu.Unlock() } // NotifyTerminated is called by the pool after an execution's process has exited. @@ -1809,8 +1821,6 @@ func (w *Worker) handleSession(executionIDs []string) { known[id] = struct{}{} } - // Prune executions not in the server's list (result was delivered, or - // server no longer cares about them). w.mu.Lock() for id := range w.executions { if _, ok := known[id]; !ok { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index 6af16c36..ee1f3a5b 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -359,14 +359,17 @@ defmodule Coflux.Handlers.Worker do {[], state} # If the stream doesn't exist yet (or producer vanished), push an - # immediate close so the consumer doesn't wait forever. + # immediate close so the consumer doesn't wait forever. Carry + # the reason atom verbatim — consumers decide how to surface + # "not_found" / "already_subscribed" in their own idiom. {:error, reason} when reason in [:stream_not_found, :producer_not_found, :already_subscribed] -> {[ command_message("stream_closed", [ consumer_execution_id, subscription_id, - %{"type" => "Coflux.StreamNotFound", "message" => Atom.to_string(reason)} + Atom.to_string(reason), + nil ]) ], state} @@ -650,8 +653,23 @@ defmodule Coflux.Handlers.Worker do {[command_message("stream_demand", [execution_external_id, index, n])], state} end - def websocket_info({:stream_closed, execution_external_id, subscription_id, error}, state) do - {[command_message("stream_closed", [execution_external_id, subscription_id, error])], state} + def websocket_info( + {:stream_closed, execution_external_id, subscription_id, reason, error}, + state + ) do + # `reason` is a string ("complete" / "errored" / "cancelled" / + # "abandoned" / "crashed" / "timeout"); `error` is non-nil only for + # "errored" — the producer's actual `{type, message, frames}`. Other + # reasons travel as the string alone; the consumer adapter decides + # how to represent them. + {[ + command_message("stream_closed", [ + execution_external_id, + subscription_id, + reason, + error + ]) + ], state} end def websocket_info(:stop, state) do diff --git a/server/lib/coflux/orchestration/epoch.ex b/server/lib/coflux/orchestration/epoch.ex index 099cb66c..aa9cf899 100644 --- a/server/lib/coflux/orchestration/epoch.ex +++ b/server/lib/coflux/orchestration/epoch.ex @@ -255,54 +255,84 @@ defmodule Coflux.Orchestration.Epoch do end end) - # Copy results + # Copy result payloads. A results row may be present without a + # completion yet (mid-termination) — both are copied + # independently below. + Enum.each(execution_ids, fn {old_exec_id, new_exec_id} -> + case query_one( + source_db, + """ + SELECT value_id, error_id, retryable, created_at + FROM results + WHERE execution_id = ?1 + """, + {old_exec_id} + ) do + {:ok, {value_id, error_id, retryable, result_created_at}} -> + new_value_id = + if value_id, do: ensure_value(source_db, target_db, value_id) + + new_error_id = + if error_id, do: ensure_error(source_db, target_db, error_id) + + {:ok, _} = + insert_one(target_db, :results, %{ + execution_id: new_exec_id, + value_id: new_value_id, + error_id: new_error_id, + retryable: retryable, + created_at: result_created_at + }) + + {:ok, nil} -> + :ok + end + end) + + # Copy completions. Successors (same-epoch and cross-epoch) are + # remapped here — for in-flight deferred/cached/spawned that + # cross into another run, we may also need to inline a resolved + # value into the just-copied results row. visited = Enum.reduce(execution_ids, visited, fn {old_exec_id, new_exec_id}, visited -> case query_one( source_db, """ - SELECT type, error_id, value_id, successor_id, successor_ref_id, - retryable, created_at, created_by - FROM results + SELECT kind, successor_id, successor_ref_id, created_at, created_by + FROM completions WHERE execution_id = ?1 """, {old_exec_id} ) do {:ok, - {type, error_id, value_id, successor_id, successor_ref_id, retryable, - result_created_at, result_created_by}} -> - new_error_id = if error_id, do: ensure_error(source_db, target_db, error_id) - - {new_value_id, new_successor_id, new_successor_ref_id, visited} = - copy_result_successor( + {kind, successor_id, successor_ref_id, completion_created_at, created_by}} -> + {new_successor_id, new_successor_ref_id, inline_value_id, visited} = + copy_completion_successor( source_db, target_db, - type, - value_id, + kind, successor_id, successor_ref_id, execution_ids, visited ) - new_value_id = - cond do - new_value_id -> new_value_id - value_id && type == 1 -> ensure_value(source_db, target_db, value_id) - true -> nil - end + if inline_value_id do + # Cross-run resolved target — inline the value onto the + # results row (which may or may not already exist). + upsert_inline_value(target_db, new_exec_id, inline_value_id, + created_at: completion_created_at + ) + end {:ok, _} = - insert_one(target_db, :results, %{ + insert_one(target_db, :completions, %{ execution_id: new_exec_id, - type: type, - error_id: new_error_id, - value_id: new_value_id, + kind: kind, successor_id: new_successor_id, successor_ref_id: new_successor_ref_id, - retryable: retryable, - created_at: result_created_at, - created_by: ensure_principal(source_db, target_db, result_created_by) + created_at: completion_created_at, + created_by: ensure_principal(source_db, target_db, created_by) }) visited @@ -312,26 +342,6 @@ defmodule Coflux.Orchestration.Epoch do end end) - # Copy completions (where present — an execution may have results - # but no completion yet if it's mid-termination). - Enum.each(execution_ids, fn {old_exec_id, new_exec_id} -> - case query_one( - source_db, - "SELECT created_at FROM completions WHERE execution_id = ?1", - {old_exec_id} - ) do - {:ok, {completion_created_at}} -> - {:ok, _} = - insert_one(target_db, :completions, %{ - execution_id: new_exec_id, - created_at: completion_created_at - }) - - {:ok, nil} -> - :ok - end - end) - # Copy streams, their items, and any closure rows. An execution's # streams may be mid-production (items appended, no closure) — # carry them forward so consumers can keep reading after rotation. @@ -1820,84 +1830,94 @@ defmodule Coflux.Orchestration.Epoch do ref_id end - # Handle successor copying for result types. - # Returns {new_value_id, new_successor_id, new_successor_ref_id}. - defp copy_result_successor( + # Copy the successor reference on a completion from source to target + # epoch. Returns `{new_successor_id, new_successor_ref_id, inline_value_id, + # visited}`. `inline_value_id` is non-nil when the target has resolved to a + # value; the caller inlines it onto the results row. + # + # Completion kinds (see Results.kind_atom): + # * 0 succeeded, 1 errored — may have successor_id (retry). + # * 2 abandoned, 3 crashed, 4 timeout — may have successor_id (retry). + # * 5 cancelled — no successor. + # * 6 suspended, 7 recurred — successor_id (same-run handoff). + # * 8 deferred, 9 cached, 10 spawned — successor_id in-flight, or + # successor_ref_id once resolved (value inlined on results). + defp copy_completion_successor( _source_db, _target_db, - type, - _value_id, + 5, _successor_id, _successor_ref_id, _execution_ids, visited - ) - when type in [1, 3] do - # Type 1 (value) and type 3 (cancelled) have no successor + ) do + # Cancelled — never has a successor. {nil, nil, nil, visited} end - defp copy_result_successor( + defp copy_completion_successor( _source_db, _target_db, - type, - _value_id, + kind, successor_id, _successor_ref_id, execution_ids, visited ) - when type in [0, 2, 6] do - # Types 0 (error retry), 2 (abandoned retry), 6 (suspended) — same-run successor - new_successor_id = if successor_id, do: Map.fetch!(execution_ids, successor_id) - {nil, new_successor_id, nil, visited} + when kind in [0, 1, 2, 3, 4, 6, 7] do + # Succeeded / errored / abandoned / crashed / timeout / suspended / + # recurred — retry or same-run successor, if any. + new_successor_id = + if successor_id, do: Map.fetch!(execution_ids, successor_id) + + {new_successor_id, nil, nil, visited} end - defp copy_result_successor( + defp copy_completion_successor( source_db, target_db, - _type, - value_id, + kind, successor_id, successor_ref_id, execution_ids, visited - ) do - # Types 4 (deferred), 5 (cached), 7 (spawned) + ) + when kind in [8, 9, 10] do + # Deferred / cached / spawned. cond do - # Already resolved: successor_ref_id + value_id both set successor_ref_id != nil -> + # Already resolved to an execution_ref — copy the ref. Any inlined + # value on the source results row has already been copied by the + # results pass; no further inlining needed here. new_ref_id = ensure_execution_ref(source_db, target_db, successor_ref_id) - new_value_id = ensure_value(source_db, target_db, value_id) - {new_value_id, nil, new_ref_id, visited} + {nil, nil, new_ref_id, visited} - # In-flight, same run successor_id != nil and is_map_key(execution_ids, successor_id) -> - {nil, Map.fetch!(execution_ids, successor_id), nil, visited} + # Same-run in-flight — remap integer id. + {Map.fetch!(execution_ids, successor_id), nil, nil, visited} - # In-flight, cross-run — try to resolve the chain in source_db successor_id != nil -> + # Cross-run in-flight — try to resolve the chain in source_db; if + # resolved, inline the value and swap to an execution_ref. resolve_cross_run_successor(source_db, target_db, successor_id, visited) - # No successor (shouldn't happen for these types) true -> {nil, nil, nil, visited} end end - # For cross-run successors (types 4, 5, 7): try to resolve the result chain. - # If resolved to a value, copy the value and create an execution_ref. - # If still pending, copy the target run and remap the successor_id. + # For cross-run successors on deferred/cached/spawned: try to resolve the + # completion chain in source_db. If resolved to a value, create an + # execution_ref and return the value_id for inlining. Otherwise copy the + # target run and remap the successor_id. defp resolve_cross_run_successor(source_db, target_db, successor_id, visited) do - case resolve_result_chain(source_db, successor_id, MapSet.new()) do + case resolve_completion_chain(source_db, successor_id, MapSet.new()) do {:ok, chain_value_id} -> - # Resolved to a value — copy value and create execution_ref for the successor new_value_id = ensure_value(source_db, target_db, chain_value_id) new_ref_id = create_execution_ref_for_id(source_db, target_db, successor_id) - {new_value_id, nil, new_ref_id, visited} + {nil, new_ref_id, new_value_id, visited} _ -> - # Pending or cancelled — copy the target run, then remap successor_id {:ok, {run_ext_id, step_num, attempt}} = query_one!( source_db, @@ -1925,54 +1945,88 @@ defmodule Coflux.Orchestration.Epoch do """, {run_ext_id, step_num, attempt} ) do - {:ok, {new_id}} -> {nil, new_id, nil, visited} + {:ok, {new_id}} -> {new_id, nil, nil, visited} {:ok, nil} -> {nil, nil, nil, visited} end end end - # Follow the result successor chain in a single DB to find a terminal value. - defp resolve_result_chain(db, execution_id, visited) do + # Follow the completion-successor chain in a single DB to find a terminal + # value. Returns `{:ok, value_id}` if resolved, `:pending`/`:cancelled`/ + # `:timeout` otherwise. The caller uses this to decide whether to inline a + # value onto the deferred/cached/spawned target's results row. + defp resolve_completion_chain(db, execution_id, visited) do if MapSet.member?(visited, execution_id) do :pending else visited = MapSet.put(visited, execution_id) - case query_one( - db, - "SELECT type, value_id, successor_id, successor_ref_id FROM results WHERE execution_id = ?1", - {execution_id} - ) do - {:ok, nil} -> + # Value comes from results; kind + successor from completions. + result_row = + query_one( + db, + "SELECT value_id, error_id FROM results WHERE execution_id = ?1", + {execution_id} + ) + + completion_row = + query_one( + db, + "SELECT kind, successor_id, successor_ref_id FROM completions WHERE execution_id = ?1", + {execution_id} + ) + + case {result_row, completion_row} do + {_, {:ok, nil}} -> + # No completion yet. :pending - # Plain value - {:ok, {1, value_id, nil, nil}} -> + {{:ok, {value_id, nil}}, _} when not is_nil(value_id) -> + # Value payload recorded — treat as resolved regardless of kind + # (covers :succeeded and resolved :deferred/:cached/:spawned). {:ok, value_id} - # Cancelled - {:ok, {3, nil, nil, nil}} -> + {_, {:ok, {5, _, _}}} -> :cancelled - # Timeout (follow retry chain if successor_id set) - {:ok, {8, nil, successor_id, nil}} when not is_nil(successor_id) -> - resolve_result_chain(db, successor_id, visited) - - {:ok, {8, nil, nil, nil}} -> + {_, {:ok, {4, nil, _}}} -> :timeout - # Types 4, 5, 7 already resolved (successor_ref_id + value_id) - {:ok, {type, value_id, nil, successor_ref_id}} - when type in [4, 5, 7] and not is_nil(successor_ref_id) and not is_nil(value_id) -> - {:ok, value_id} - - # Any type with successor_id — follow the chain - {:ok, {_type, nil, successor_id, nil}} when not is_nil(successor_id) -> - resolve_result_chain(db, successor_id, visited) + {_, {:ok, {_kind, successor_id, nil}}} when not is_nil(successor_id) -> + resolve_completion_chain(db, successor_id, visited) _ -> :pending end end end + + # Write or update a results row to inline a resolved value. Used during + # cross-run rotation when we resolve a deferred/cached/spawned target to + # a concrete value: the completion records the kind, the results row + # records the inlined payload. + defp upsert_inline_value(target_db, execution_id, value_id, opts) do + created_at = Keyword.fetch!(opts, :created_at) + + case query_one( + target_db, + "SELECT 1 FROM results WHERE execution_id = ?1", + {execution_id} + ) do + {:ok, nil} -> + {:ok, _} = + insert_one(target_db, :results, %{ + execution_id: execution_id, + value_id: value_id, + error_id: nil, + retryable: nil, + created_at: created_at + }) + + {:ok, {1}} -> + # Already has a payload (shouldn't happen for fresh cross-run + # resolution, but be idempotent). + :ok + end + end end diff --git a/server/lib/coflux/orchestration/inputs.ex b/server/lib/coflux/orchestration/inputs.ex index 892dbac9..857e3de0 100644 --- a/server/lib/coflux/orchestration/inputs.ex +++ b/server/lib/coflux/orchestration/inputs.ex @@ -375,10 +375,13 @@ defmodule Coflux.Orchestration.Inputs do INNER JOIN runs AS r ON r.id = i.run_id LEFT JOIN input_responses AS ir ON ir.input_id = i.id INNER JOIN input_dependencies AS id ON id.input_id = i.id - LEFT JOIN results AS dr ON dr.execution_id = id.execution_id + LEFT JOIN completions AS dc ON dc.execution_id = id.execution_id WHERE i.workspace_id = ?1 AND ir.input_id IS NULL - AND (dr.execution_id IS NULL OR dr.successor_id IS NOT NULL) + -- An execution is still "active" if it hasn't completed yet, or if + -- its completion carries a successor (retry / suspended / deferred + -- chain that will resume somewhere). + AND (dc.execution_id IS NULL OR dc.successor_id IS NOT NULL) ORDER BY i.created_at DESC """, {workspace_id} @@ -386,16 +389,17 @@ defmodule Coflux.Orchestration.Inputs do end def has_active_dependency?(db, input_id) do - # An execution is "active" if it has no result, or if its result has a - # successor (suspended/retried — the chain is still alive). + # An execution is "active" if it has no completion yet, or if its + # completion carries a successor (suspended/retried/deferred — the + # chain is still alive). case query_one( db, """ SELECT 1 FROM input_dependencies AS id - LEFT JOIN results AS r ON r.execution_id = id.execution_id + LEFT JOIN completions AS c ON c.execution_id = id.execution_id WHERE id.input_id = ?1 - AND (r.execution_id IS NULL OR r.successor_id IS NOT NULL) + AND (c.execution_id IS NULL OR c.successor_id IS NOT NULL) LIMIT 1 """, {input_id} diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index 05d97b0d..9c635709 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -3,115 +3,231 @@ defmodule Coflux.Orchestration.Results do alias Coflux.Orchestration.{Errors, Values} - # Writes the results row capturing the disposition (value/error/retryable) - # and any server-decided successor. Written at the time the disposition is - # known — for worker-reported results that's put_result/put_error/etc.; - # for server-initiated dispositions (abandonment, defer/cache/spawn) it's - # when the server makes the decision. + # --- Completion kinds --- # - # A matching completion row is written separately via record_completion, - # typically when the worker's process confirms it has terminated (via - # notify_terminated). For server-initiated cases that never involve a - # worker, the caller writes both in sequence. - def record_result(db, execution_id, result, created_by \\ nil) do - with_snapshot(db, fn -> - now = current_timestamp() - - {type, error_id, value_id, successor_id, successor_ref_id, retryable} = - case result do - {:error, type, message, frames, retry_id, retryable} -> - error_id = Errors.get_or_create(db, type, message, frames) - {0, error_id, nil, retry_id, nil, retryable} - - {:error, type, message, frames, retry_id} -> - error_id = Errors.get_or_create(db, type, message, frames) - {0, error_id, nil, retry_id, nil, nil} - - {:value, value} -> - {:ok, value_id} = Values.get_or_create_value(db, value) - {1, nil, value_id, nil, nil, nil} - - {:abandoned, retry_id} -> - {2, nil, nil, retry_id, nil, nil} - - :cancelled -> - {3, nil, nil, nil, nil, nil} - - {:timeout, retry_id} -> - {8, nil, nil, retry_id, nil, nil} - - # In-flight deferred (successor still executing) - {:deferred, defer_id} -> - {4, nil, nil, defer_id, nil, nil} + # Kept in sync with the `completions.kind` column. See migration 4 for the + # authoritative list and descriptions. + + @kind_succeeded 0 + @kind_errored 1 + @kind_abandoned 2 + @kind_crashed 3 + @kind_timeout 4 + @kind_cancelled 5 + @kind_suspended 6 + @kind_recurred 7 + @kind_deferred 8 + @kind_cached 9 + @kind_spawned 10 + + @failure_kinds [@kind_errored, @kind_abandoned, @kind_crashed, @kind_timeout] + + def kind_atom(0), do: :succeeded + def kind_atom(1), do: :errored + def kind_atom(2), do: :abandoned + def kind_atom(3), do: :crashed + def kind_atom(4), do: :timeout + def kind_atom(5), do: :cancelled + def kind_atom(6), do: :suspended + def kind_atom(7), do: :recurred + def kind_atom(8), do: :deferred + def kind_atom(9), do: :cached + def kind_atom(10), do: :spawned + + def atom_kind(:succeeded), do: @kind_succeeded + def atom_kind(:errored), do: @kind_errored + def atom_kind(:abandoned), do: @kind_abandoned + def atom_kind(:crashed), do: @kind_crashed + def atom_kind(:timeout), do: @kind_timeout + def atom_kind(:cancelled), do: @kind_cancelled + def atom_kind(:suspended), do: @kind_suspended + def atom_kind(:recurred), do: @kind_recurred + def atom_kind(:deferred), do: @kind_deferred + def atom_kind(:cached), do: @kind_cached + def atom_kind(:spawned), do: @kind_spawned + + def failure_kinds, do: @failure_kinds + + # --- Writing results (payload only) --- + # + # Written when a worker's process reports a value or an error for the task + # body. Does not write a completion row — the completion is written later + # (via notify_terminated for worker-involved cases, or directly by the + # server for server-initiated dispositions). - # Resolved deferred (successor resolved to a value — from epoch copy - # or runtime cache hit) - {:deferred, ref_id, value} -> - {:ok, value_id} = Values.get_or_create_value(db, value) - {4, nil, value_id, nil, ref_id, nil} + # Writes a value payload. Returns the created timestamp on success. + def record_value_result(db, execution_id, value) do + with_snapshot(db, fn -> + {:ok, value_id} = Values.get_or_create_value(db, value) + insert_result_row(db, execution_id, value_id: value_id) + end) + end - # In-flight cached - {:cached, cached_id} -> - {5, nil, nil, cached_id, nil, nil} + # Writes an error payload. `retryable` is the optional `when`-callback + # result from the worker: `nil` = no callback configured, `true` = callback + # allows retry, `false` = callback blocks retry. + def record_error_result(db, execution_id, type, message, frames, retryable \\ nil) do + with_snapshot(db, fn -> + error_id = Errors.get_or_create(db, type, message, frames) + insert_result_row(db, execution_id, error_id: error_id, retryable: retryable) + end) + end - # Resolved cached - {:cached, ref_id, value} -> - {:ok, value_id} = Values.get_or_create_value(db, value) - {5, nil, value_id, nil, ref_id, nil} + defp insert_result_row(db, execution_id, fields) do + now = current_timestamp() - {:suspended, successor_id} -> - {6, nil, nil, successor_id, nil, nil} + row = + Map.merge( + %{execution_id: execution_id, created_at: now}, + Map.new(fields, fn + {:retryable, nil} -> {:retryable, nil} + {:retryable, true} -> {:retryable, 1} + {:retryable, false} -> {:retryable, 0} + {k, v} -> {k, v} + end) + ) - {:recurred, successor_id} -> - {9, nil, nil, successor_id, nil, nil} + case insert_one(db, :results, row) do + {:ok, _} -> {:ok, now} + {:error, "UNIQUE constraint failed: " <> _} -> {:error, :already_recorded} + end + end - # In-flight spawned - {:spawned, execution_id} -> - {7, nil, nil, execution_id, nil, nil} + # Compatibility shim. Dispatches legacy-shaped result tuples to the + # appropriate split-API writes (result payload and/or completion). For + # value/error the caller must still invoke record_completion later (via + # complete_execution). For all other tagged tuples this writes the + # completion directly — no separate record_completion call needed. + # + # Returns `{:ok, timestamp}` on success. The timestamp is the results row + # for value/error (result arrival time) and the completions row + # otherwise (terminal state time). + def record_result(db, execution_id, result, created_by \\ nil) do + case result do + {:value, value} -> + record_value_result(db, execution_id, value) + + {:error, type, message, frames, _retry_id, retryable} -> + # Retry successor is recorded on the completion now — so we drop it + # from the results row. The caller passes it to record_completion. + record_error_result(db, execution_id, type, message, frames, retryable) + + {:error, type, message, frames, _retry_id} -> + record_error_result(db, execution_id, type, message, frames, nil) + + :cancelled -> + record_completion(db, execution_id, :cancelled, created_by: created_by) + + {:abandoned, retry_id} -> + record_completion(db, execution_id, :abandoned, + successor_id: retry_id, + created_by: created_by + ) + + {:crashed, retry_id} -> + record_completion(db, execution_id, :crashed, + successor_id: retry_id, + created_by: created_by + ) + + {:timeout, retry_id} -> + record_completion(db, execution_id, :timeout, + successor_id: retry_id, + created_by: created_by + ) + + {:suspended, successor_id} -> + record_completion(db, execution_id, :suspended, + successor_id: successor_id, + created_by: created_by + ) + + {:recurred, successor_id} -> + record_completion(db, execution_id, :recurred, + successor_id: successor_id, + created_by: created_by + ) + + {:deferred, successor_id} -> + record_completion(db, execution_id, :deferred, + successor_id: successor_id, + created_by: created_by + ) + + {:deferred, ref_id, value} -> + with {:ok, _} <- record_value_result(db, execution_id, value) do + record_completion(db, execution_id, :deferred, + successor_ref_id: ref_id, + created_by: created_by + ) + end - # Resolved spawned - {:spawned, ref_id, value} -> - {:ok, value_id} = Values.get_or_create_value(db, value) - {7, nil, value_id, nil, ref_id, nil} + {:cached, successor_id} -> + record_completion(db, execution_id, :cached, + successor_id: successor_id, + created_by: created_by + ) + + {:cached, ref_id, value} -> + with {:ok, _} <- record_value_result(db, execution_id, value) do + record_completion(db, execution_id, :cached, + successor_ref_id: ref_id, + created_by: created_by + ) end - case insert_result( - db, - execution_id, - type, - error_id, - value_id, - successor_id, - successor_ref_id, - retryable, - now, - created_by - ) do - {:ok, _} -> {:ok, now} - {:error, "UNIQUE constraint failed: " <> _field} -> {:error, :already_recorded} - end - end) + {:spawned, successor_id} -> + record_completion(db, execution_id, :spawned, + successor_id: successor_id, + created_by: created_by + ) + + {:spawned, ref_id, value} -> + with {:ok, _} <- record_value_result(db, execution_id, value) do + record_completion(db, execution_id, :spawned, + successor_ref_id: ref_id, + created_by: created_by + ) + end + end end - # Writes the completion row — a simple timestamp marker recording that the - # execution's process has fully terminated. For worker-involved cases this - # is triggered by notify_terminated; for server-initiated dispositions - # (abandonment, cache-hit scheduling) the caller writes this right after - # record_result. - def record_completion(db, execution_id) do + # --- Writing completions --- + # + # Written at the point the execution's terminal state becomes known: + # * Worker-involved cases: notify_terminated arrives. Kind is derived + # from whether a result row exists (succeeded / errored) or not + # (crashed). Caller supplies the retry successor, if any. + # * Server-initiated cases (abandon / cancel / cache-hit / defer / + # spawn / suspend / recur / timeout): the server calls this directly + # with the appropriate kind and successor. + + # `kind` is an atom from the enum above. `opts` accepts: + # * `:successor_id` — integer FK into executions (same-epoch pointer). + # * `:successor_ref_id` — integer FK into execution_refs (post-rotation). + # * `:created_by` — principal id. + def record_completion(db, execution_id, kind, opts \\ []) when is_atom(kind) do with_transaction(db, fn -> now = current_timestamp() - case insert_one(db, :completions, %{ - execution_id: execution_id, - created_at: now - }) do + row = %{ + execution_id: execution_id, + kind: atom_kind(kind), + successor_id: Keyword.get(opts, :successor_id), + successor_ref_id: Keyword.get(opts, :successor_ref_id), + created_at: now, + created_by: Keyword.get(opts, :created_by) + } + + case insert_one(db, :completions, row) do {:ok, _} -> {:ok, now} - {:error, "UNIQUE constraint failed: " <> _field} -> {:error, :already_completed} + {:error, "UNIQUE constraint failed: " <> _} -> {:error, :already_completed} end end) end + # --- Existence checks --- + def has_result?(db, execution_id) do case query_one(db, "SELECT count(*) FROM results WHERE execution_id = ?1", {execution_id}) do {:ok, {0}} -> {:ok, false} @@ -130,111 +246,286 @@ defmodule Coflux.Orchestration.Results do end end + # --- Reading --- + + # Returns the logical result for a consumer, derived by joining results and + # completions. Shape: + # `{:ok, {logical_result, result_at, completion_at, created_by}}` + # `{:ok, nil}` — execution has no result and no completion yet + # + # `logical_result` is a tagged tuple in the same shape the old + # single-table version returned — kept compatible so existing callers + # don't have to change their pattern matching: + # * `{:value, value}` + # * `{:error, type, message, frames, retry_id, retryable}` + # * `:cancelled` + # * `{:abandoned, retry_id}` + # * `{:crashed, retry_id}` + # * `{:timeout, retry_id}` + # * `{:suspended, successor_id}` + # * `{:recurred, successor_id}` + # * `{:deferred, successor_id}` — in-flight + # * `{:deferred, successor_ref_id, value}` — resolved + # * `{:cached, successor_id}` — in-flight + # * `{:cached, successor_ref_id, value}` — resolved + # * `{:spawned, successor_id}` — in-flight + # * `{:spawned, successor_ref_id, value}` — resolved + # + # `result_at` is the results row's created_at (nil if none). + # `completion_at` is the completions row's created_at (nil if none). + # `created_by` is the completion's creator principal (nil if none). def get_result(db, execution_id) do + result_row = + query_one( + db, + """ + SELECT value_id, error_id, retryable, created_at + FROM results + WHERE execution_id = ?1 + """, + {execution_id} + ) + + completion_row = + query_one( + db, + """ + SELECT c.kind, c.successor_id, c.successor_ref_id, c.created_at, + p.user_external_id, t.external_id + FROM completions AS c + LEFT JOIN principals AS p ON c.created_by = p.id + LEFT JOIN tokens AS t ON p.token_id = t.id + WHERE c.execution_id = ?1 + """, + {execution_id} + ) + + {:ok, result_at, value_id, error_id, retryable} = decode_result_row(result_row) + + {:ok, completion_at, kind, successor_id, successor_ref_id, created_by} = + decode_completion_row(completion_row) + + case resolve_logical( + db, + kind, + value_id, + error_id, + retryable, + successor_id, + successor_ref_id + ) do + nil -> {:ok, nil} + logical -> {:ok, {logical, result_at, completion_at, created_by}} + end + end + + # Returns the raw result payload (not joined with completion). Used by + # the completion-writing path to decide kind + retry from the worker's + # recorded outcome without needing any transient in-memory state. + # Returns: + # `{:ok, {:value, value}}` + # `{:ok, {:error, type, message, frames, retryable}}` + # `{:ok, nil}` — no result payload recorded + def get_result_payload(db, execution_id) do case query_one( db, """ - SELECT r.type, r.error_id, r.value_id, r.successor_id, r.successor_ref_id, - r.retryable, r.created_at, c.created_at AS completion_created_at, - p.user_external_id AS created_by_user_external_id, - t.external_id AS created_by_token_external_id - FROM results AS r - LEFT JOIN completions AS c ON c.execution_id = r.execution_id - LEFT JOIN principals AS p ON r.created_by = p.id + SELECT value_id, error_id, retryable + FROM results + WHERE execution_id = ?1 + """, + {execution_id} + ) do + {:ok, nil} -> + {:ok, nil} + + {:ok, {value_id, nil, _}} when not is_nil(value_id) -> + {:ok, value} = Values.get_value_by_id(db, value_id) + {:ok, {:value, value}} + + {:ok, {nil, error_id, retryable}} when not is_nil(error_id) -> + {:ok, {type, message, frames}} = Errors.get_by_id(db, error_id) + {:ok, {:error, type, message, frames, decode_retryable(retryable)}} + end + end + + # Returns the raw completion row for an execution. Shape: + # `{:ok, {kind_atom, successor_id, successor_ref_id, created_at, created_by}}` + # `{:ok, nil}` — no completion yet + def get_completion(db, execution_id) do + case query_one( + db, + """ + SELECT c.kind, c.successor_id, c.successor_ref_id, c.created_at, + p.user_external_id, t.external_id + FROM completions AS c + LEFT JOIN principals AS p ON c.created_by = p.id LEFT JOIN tokens AS t ON p.token_id = t.id - WHERE r.execution_id = ?1 + WHERE c.execution_id = ?1 """, {execution_id} ) do {:ok, nil} -> - # No results row. If the execution has a completion row anyway, - # the worker terminated without ever reporting — treat as crashed. - case query_one( - db, - "SELECT created_at FROM completions WHERE execution_id = ?1", - {execution_id} - ) do - {:ok, {completion_created_at}} -> - {:ok, {{:crashed, nil}, nil, completion_created_at, nil}} - - {:ok, nil} -> - {:ok, nil} - end + {:ok, nil} + + {:ok, {kind, successor_id, successor_ref_id, created_at, user_ext, token_ext}} -> + created_by = decode_principal(user_ext, token_ext) + {:ok, {kind_atom(kind), successor_id, successor_ref_id, created_at, created_by}} + end + end + + # Returns the execution's status as a single atom, derived from the + # split result/completion tables. Used by UI and lifecycle checks that + # used to key off `result != nil` — in the new model the execution's + # lifecycle is completion-driven. + # + # Status values: + # * `:pending` — nothing recorded yet (still queued/running) + # * `:draining` — result recorded, completion not yet (streams etc.) + # * a completion kind atom (`:succeeded` / `:errored` / ...) + def execution_status(db, execution_id) do + has_completion = has_completion?(db, execution_id) + has_result = has_result?(db, execution_id) + + case {has_completion, has_result} do + {{:ok, true}, _} -> + {:ok, {kind, _, _, _, _}} = get_completion(db, execution_id) + kind + + {{:ok, false}, {:ok, true}} -> + :draining + + {{:ok, false}, {:ok, false}} -> + :pending + end + end - {:ok, - {type, error_id, value_id, successor_id, successor_ref_id, retryable, created_at, - completion_created_at, created_by_user_ext_id, created_by_token_ext_id}} -> - created_by = - case {created_by_user_ext_id, created_by_token_ext_id} do - {nil, nil} -> nil - {user_ext_id, nil} -> %{type: "user", external_id: user_ext_id} - {nil, token_ext_id} -> %{type: "token", external_id: token_ext_id} - end - - retryable = - case retryable do - 1 -> true - 0 -> false - nil -> nil - end - - result = - case {type, error_id, value_id, successor_id, successor_ref_id} do - {0, error_id, nil, retry_id, nil} -> - case Errors.get_by_id(db, error_id) do - {:ok, {type, message, frames}} -> - {:error, type, message, frames, retry_id, retryable} - end - - {1, nil, value_id, nil, nil} -> - case Values.get_value_by_id(db, value_id) do - {:ok, value} -> {:value, value} - end - - {2, nil, nil, retry_id, nil} -> - {:abandoned, retry_id} - - {3, nil, nil, nil, nil} -> - :cancelled - - {8, nil, nil, retry_id, nil} -> - {:timeout, retry_id} - - {4, nil, nil, defer_id, nil} -> - {:deferred, defer_id} - - {4, nil, value_id, nil, ref_id} when not is_nil(ref_id) -> - case Values.get_value_by_id(db, value_id) do - {:ok, value} -> {:deferred, ref_id, value} - end - - {5, nil, nil, cached_id, nil} -> - {:cached, cached_id} - - {5, nil, value_id, nil, ref_id} when not is_nil(ref_id) -> - case Values.get_value_by_id(db, value_id) do - {:ok, value} -> {:cached, ref_id, value} - end - - {6, nil, nil, successor_id, nil} -> - {:suspended, successor_id} - - {7, nil, nil, execution_id, nil} -> - {:spawned, execution_id} - - {7, nil, value_id, nil, ref_id} when not is_nil(ref_id) -> - case Values.get_value_by_id(db, value_id) do - {:ok, value} -> {:spawned, ref_id, value} - end - - {9, nil, nil, successor_id, nil} -> - {:recurred, successor_id} - end - - {:ok, {result, created_at, completion_created_at, created_by}} + # True when the execution is in a state that could resolve to a useful + # value — either still running (pending/draining) or cleanly completed. + # Used for cache candidacy and memoisation: once a negative signal + # appears (error / abandoned / crashed / timeout) this flips to false. + def cache_candidate?(db, execution_id) do + case execution_status(db, execution_id) do + :pending -> true + :draining -> true + :succeeded -> true + :suspended -> true + :recurred -> true + :deferred -> true + :cached -> true + :spawned -> true + _ -> false end end + # --- Helpers --- + + defp decode_result_row({:ok, nil}), do: {:ok, nil, nil, nil, nil} + + defp decode_result_row({:ok, {value_id, error_id, retryable, created_at}}), + do: {:ok, created_at, value_id, error_id, retryable} + + defp decode_completion_row({:ok, nil}), do: {:ok, nil, nil, nil, nil, nil} + + defp decode_completion_row( + {:ok, {kind, successor_id, successor_ref_id, created_at, user_ext, token_ext}} + ) do + {:ok, created_at, kind, successor_id, successor_ref_id, + decode_principal(user_ext, token_ext)} + end + + defp decode_principal(nil, nil), do: nil + defp decode_principal(user_ext, nil), do: %{type: "user", external_id: user_ext} + defp decode_principal(nil, token_ext), do: %{type: "token", external_id: token_ext} + + defp decode_retryable(nil), do: nil + defp decode_retryable(1), do: true + defp decode_retryable(0), do: false + + # Builds the legacy "logical result" tuple from the split tables. Used + # by most callers (UI, topic state, consumer resolution). Returns `nil` + # only when nothing has been recorded yet. + # + # `resolve_result` in server.ex is responsible for deciding when it's + # safe to follow a successor — error results without a completion carry + # `nil` as their successor here, and the server treats that as "still + # pending". + defp resolve_logical(_db, nil, nil, nil, _, _, _), do: nil + + # Value payload present. Returns the appropriate tagged tuple, picking + # the successor-flavoured form when the completion says this was a + # deferred/cached/spawned resolution. + defp resolve_logical(db, kind, value_id, nil, _retryable, _successor_id, successor_ref_id) + when not is_nil(value_id) do + {:ok, value} = Values.get_value_by_id(db, value_id) + + case kind && kind_atom(kind) do + :deferred when not is_nil(successor_ref_id) -> {:deferred, successor_ref_id, value} + :cached when not is_nil(successor_ref_id) -> {:cached, successor_ref_id, value} + :spawned when not is_nil(successor_ref_id) -> {:spawned, successor_ref_id, value} + _ -> {:value, value} + end + end + + # Error payload without a completion yet. We return the error so UI can + # display it; the successor slot is nil so consumer resolution treats + # it as still pending (the retry decision happens at completion time). + defp resolve_logical(db, nil, nil, error_id, retryable, _successor_id, _successor_ref_id) + when not is_nil(error_id) do + {:ok, {type, message, frames}} = Errors.get_by_id(db, error_id) + {:error, type, message, frames, nil, decode_retryable(retryable)} + end + + # Completion present (possibly with no results row). + defp resolve_logical(db, kind, _value_id, error_id, retryable, successor_id, successor_ref_id) do + case kind_atom(kind) do + :succeeded -> + nil + + :errored when not is_nil(error_id) -> + {:ok, {type, message, frames}} = Errors.get_by_id(db, error_id) + {:error, type, message, frames, successor_id, decode_retryable(retryable)} + + :abandoned -> + {:abandoned, successor_id} + + :crashed -> + {:crashed, successor_id} + + :timeout -> + {:timeout, successor_id} + + :cancelled -> + :cancelled + + :suspended -> + {:suspended, successor_id} + + :recurred -> + {:recurred, successor_id} + + :deferred -> + build_successor_tuple(db, :deferred, nil, successor_id, successor_ref_id) + + :cached -> + build_successor_tuple(db, :cached, nil, successor_id, successor_ref_id) + + :spawned -> + build_successor_tuple(db, :spawned, nil, successor_id, successor_ref_id) + end + end + + defp build_successor_tuple(_db, tag, nil, successor_id, nil) when not is_nil(successor_id), + do: {tag, successor_id} + + defp build_successor_tuple(db, tag, value_id, nil, successor_ref_id) + when not is_nil(value_id) and not is_nil(successor_ref_id) do + {:ok, value} = Values.get_value_by_id(db, value_id) + {tag, successor_ref_id, value} + end + + # --- Assets (unchanged) --- + def put_execution_asset(db, execution_id, asset_id) do now = current_timestamp() @@ -242,59 +533,23 @@ defmodule Coflux.Orchestration.Results do insert_one( db, :execution_assets, - %{ - execution_id: execution_id, - asset_id: asset_id, - created_at: now - }, + %{execution_id: execution_id, asset_id: asset_id, created_at: now}, on_conflict: "DO NOTHING" ) :ok end - # TODO: get all assets for run? def get_assets_for_execution(db, execution_id) do case query( db, "SELECT asset_id FROM execution_assets WHERE execution_id = ?1", {execution_id} ) do - {:ok, rows} -> - {:ok, Enum.map(rows, fn {asset_id} -> asset_id end)} + {:ok, rows} -> {:ok, Enum.map(rows, fn {asset_id} -> asset_id end)} end end - defp insert_result( - db, - execution_id, - type, - error_id, - value_id, - successor_id, - successor_ref_id, - retryable, - created_at, - created_by - ) do - insert_one(db, :results, %{ - execution_id: execution_id, - type: type, - error_id: error_id, - value_id: value_id, - successor_id: successor_id, - successor_ref_id: successor_ref_id, - retryable: - case retryable do - nil -> nil - true -> 1 - false -> 0 - end, - created_at: created_at, - created_by: created_by - }) - end - defp current_timestamp() do System.os_time(:millisecond) end diff --git a/server/lib/coflux/orchestration/runs.ex b/server/lib/coflux/orchestration/runs.ex index 9643e364..dcae92c5 100644 --- a/server/lib/coflux/orchestration/runs.ex +++ b/server/lib/coflux/orchestration/runs.ex @@ -1,5 +1,5 @@ defmodule Coflux.Orchestration.Runs do - alias Coflux.Orchestration.{Models, Values, TagSets, CacheConfigs, Utils} + alias Coflux.Orchestration.{Models, Results, Values, TagSets, CacheConfigs, Utils} import Coflux.Store @@ -507,10 +507,10 @@ defmodule Coflux.Orchestration.Runs do """ SELECT e.id FROM executions AS e - LEFT JOIN results AS r ON r.execution_id = e.id + LEFT JOIN completions AS c ON c.execution_id = e.id WHERE e.step_id = ?1 AND e.workspace_id = ?2 - AND r.execution_id IS NULL + AND c.execution_id IS NULL """, {step_id, workspace_id} ) do @@ -650,8 +650,13 @@ defmodule Coflux.Orchestration.Runs do INNER JOIN steps AS s ON s.id = e.step_id INNER JOIN runs AS run ON run.id = s.run_id LEFT JOIN assignments AS a ON a.execution_id = e.id - LEFT JOIN results AS r ON r.execution_id = e.id - WHERE a.created_at IS NULL AND r.created_at IS NULL + LEFT JOIN completions AS c ON c.execution_id = e.id + -- Unassigned = no worker has picked it up AND no terminal state has + -- been recorded. Server-initiated resolutions (deferred / cached / + -- cancelled) write only a completions row, so filtering on + -- completions (rather than results) is what stops the scheduler from + -- re-picking them. + WHERE a.created_at IS NULL AND c.created_at IS NULL ORDER BY e.execute_after, e.created_at, s.priority DESC """, {}, @@ -660,6 +665,9 @@ defmodule Coflux.Orchestration.Runs do end def get_queue_executions(db, workspace_id) do + # "Still in the queue" = no completion yet. An execution with a value + # result but no completion (streams draining) is still running from + # the lifecycle's point of view, so it stays visible on the queue. case query( db, """ @@ -678,8 +686,8 @@ defmodule Coflux.Orchestration.Runs do INNER JOIN steps AS s ON s.id = e.step_id INNER JOIN runs AS r ON r.id = s.run_id LEFT JOIN assignments AS a ON a.execution_id = e.id - LEFT JOIN results AS re ON re.execution_id = e.id - WHERE e.workspace_id = ?1 AND re.created_at IS NULL + LEFT JOIN completions AS c ON c.execution_id = e.id + WHERE e.workspace_id = ?1 AND c.created_at IS NULL """, {workspace_id} ) do @@ -695,8 +703,8 @@ defmodule Coflux.Orchestration.Runs do SELECT e.id, s.run_id, s.module FROM executions AS e INNER JOIN steps AS s ON s.id = e.step_id - LEFT JOIN results AS r ON r.execution_id = e.id - WHERE e.workspace_id = ?1 AND r.created_at IS NULL + LEFT JOIN completions AS c ON c.execution_id = e.id + WHERE e.workspace_id = ?1 AND c.created_at IS NULL """, {workspace_id} ) @@ -708,8 +716,8 @@ defmodule Coflux.Orchestration.Runs do """ SELECT a.session_id, a.execution_id FROM assignments AS a - LEFT JOIN results AS r ON r.execution_id = a.execution_id - WHERE r.created_at IS NULL + LEFT JOIN completions AS c ON c.execution_id = a.execution_id + WHERE c.created_at IS NULL """ ) end @@ -729,17 +737,19 @@ defmodule Coflux.Orchestration.Runs do INNER JOIN executions AS e ON e.step_id = s.id WHERE s.parent_id IS NOT NULL UNION ALL - SELECT r.execution_id AS parent_id, r.successor_id AS child_id - FROM results AS r - WHERE r.type = 7 AND r.successor_id IS NOT NULL + -- Spawned completions (kind = 10) reference a child execution via + -- successor_id. Traversing these captures the full spawn tree. + SELECT c.execution_id AS parent_id, c.successor_id AS child_id + FROM completions AS c + WHERE c.kind = 10 AND c.successor_id IS NOT NULL ) AS edges ON edges.parent_id = d.execution_id ) - SELECT e.id, s.module, a.created_at, r.created_at, e.workspace_id + SELECT e.id, s.module, a.created_at, c.created_at, e.workspace_id FROM descendants AS d INNER JOIN executions AS e ON e.id = d.execution_id INNER JOIN steps AS s ON s.id = e.step_id LEFT JOIN assignments AS a ON a.execution_id = e.id - LEFT JOIN results AS r ON r.execution_id = e.id + LEFT JOIN completions AS c ON c.execution_id = e.id """, {execution_id} ) @@ -840,11 +850,13 @@ defmodule Coflux.Orchestration.Runs do end def get_active_run_workflows(db, workspace_id \\ nil) do + # Active = no completion yet. Matches the queue/lifecycle semantics: + # streams-draining executions still count as running. {where, params} = if workspace_id do - {"WHERE res.created_at IS NULL AND e.workspace_id = ?1", {workspace_id}} + {"WHERE c.created_at IS NULL AND e.workspace_id = ?1", {workspace_id}} else - {"WHERE res.created_at IS NULL", {}} + {"WHERE c.created_at IS NULL", {}} end query( @@ -855,7 +867,7 @@ defmodule Coflux.Orchestration.Runs do INNER JOIN steps AS s ON s.id = e.step_id INNER JOIN runs AS r ON r.id = s.run_id INNER JOIN steps AS root_s ON root_s.run_id = r.id AND root_s.parent_id IS NULL - LEFT JOIN results AS res ON res.execution_id = e.id + LEFT JOIN completions AS c ON c.execution_id = e.id #{where} """, params @@ -955,16 +967,18 @@ defmodule Coflux.Orchestration.Runs do # if it has either a results row or a completions row (the latter without # the former indicates the worker crashed without reporting). # For crashed executions, type is NULL. - def get_step_result_types(db, step_id, limit) do + # Returns `{execution_id, completion_kind}` pairs for recent attempts of a + # step, ordered newest first. `completion_kind` is the raw integer from + # the completions table (see Results.kind_atom/1). Used by the retry + # scheduler to count consecutive failures. + def get_step_completion_kinds(db, step_id, limit) do case query( db, """ - SELECT e.id, r.type + SELECT e.id, c.kind FROM executions AS e - LEFT JOIN results AS r ON r.execution_id = e.id - LEFT JOIN completions AS c ON c.execution_id = e.id + INNER JOIN completions AS c ON c.execution_id = e.id WHERE e.step_id = ?1 - AND (r.execution_id IS NOT NULL OR c.execution_id IS NOT NULL) ORDER BY e.created_at DESC LIMIT ?2 """, @@ -1107,11 +1121,18 @@ defmodule Coflux.Orchestration.Runs do FROM steps AS s INNER JOIN executions AS e ON e.step_id = s.id LEFT JOIN results AS r ON r.execution_id = e.id + LEFT JOIN completions AS c ON c.execution_id = e.id WHERE s.run_id = ?1 AND e.workspace_id IN (#{build_placeholders(length(workspace_ids), 1)}) AND s.memo_key = ?#{length(workspace_ids) + 2} - AND (r.type IS NULL OR r.type = 1) + -- Either no result yet (in-flight candidate) or a value + -- result. Errors disqualify. Cancelled-with-value also + -- disqualifies: a user-cancelled execution's work shouldn't + -- be reused, even though its value is still valid for + -- already-resolved consumers. + AND (r.execution_id IS NULL OR r.value_id IS NOT NULL) + AND (c.kind IS NULL OR c.kind != #{Results.atom_kind(:cancelled)}) ORDER BY e.created_at DESC LIMIT 1 """, @@ -1147,10 +1168,19 @@ defmodule Coflux.Orchestration.Runs do FROM steps AS s INNER JOIN executions AS e ON e.step_id = s.id LEFT JOIN results AS r ON r.execution_id = e.id + LEFT JOIN completions AS c ON c.execution_id = e.id WHERE e.workspace_id IN (#{build_placeholders(length(workspace_ids))}) AND s.cache_key = ?#{length(workspace_ids) + 1} - AND (r.type IS NULL OR (r.type = 1 AND r.created_at >= ?#{length(workspace_ids) + 2})) + -- Either no result yet (in-flight candidate) or a value result + -- recorded within the cache age window. Errors disqualify. + -- Cancelled-with-value also disqualifies: the value stays valid + -- for already-resolved consumers but shouldn't seed cache hits. + AND ( + r.execution_id IS NULL + OR (r.value_id IS NOT NULL AND r.created_at >= ?#{length(workspace_ids) + 2}) + ) + AND (c.kind IS NULL OR c.kind != #{Results.atom_kind(:cancelled)}) #{step_clause} ORDER BY e.created_at DESC LIMIT 1 @@ -1166,7 +1196,8 @@ defmodule Coflux.Orchestration.Runs do end def get_result_successors(db, execution_id) do - # First, find successors via the successor_id chain (same-run, internal) + # First, find successors via the successor_id chain (same-run, internal). + # Successors now live on the completions table. {:ok, rows1} = query( db, @@ -1174,9 +1205,9 @@ defmodule Coflux.Orchestration.Runs do WITH RECURSIVE successors AS ( SELECT ?1 AS execution_id UNION - SELECT r.execution_id + SELECT c.execution_id FROM successors AS ss - INNER JOIN results AS r ON r.successor_id = ss.execution_id + INNER JOIN completions AS c ON c.successor_id = ss.execution_id ) SELECT run.external_id, ss.execution_id FROM successors AS ss @@ -1198,10 +1229,10 @@ defmodule Coflux.Orchestration.Runs do query( db, """ - SELECT run2.external_id, r.execution_id - FROM results AS r - INNER JOIN execution_refs AS ref ON r.successor_ref_id = ref.id - INNER JOIN executions AS e ON e.id = r.execution_id + SELECT run2.external_id, c.execution_id + FROM completions AS c + INNER JOIN execution_refs AS ref ON c.successor_ref_id = ref.id + INNER JOIN executions AS e ON e.id = c.execution_id INNER JOIN steps AS s ON s.id = e.step_id INNER JOIN runs AS run2 ON run2.id = s.run_id WHERE ref.run_external_id = ?1 AND ref.step_number = ?2 AND ref.attempt = ?3 diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 1858f647..bc6fa451 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1694,17 +1694,21 @@ defmodule Coflux.Orchestration.Server do update_in(state.sessions[session_id].starting, &MapSet.delete(&1, ext_id)) end) - # Abandon executions that were executing but are no longer reported + # Abandon executions that were executing but are no longer reported. + # Keyed off `has_completion?` rather than `has_result?`: a value + # result with streams still draining is still running from the + # lifecycle's perspective, and we shouldn't synthesise a completion + # until the worker actually stops reporting it. state = session.executing |> MapSet.difference(reported_ext_ids) |> Enum.reduce(state, fn ext_id, state -> execution_id = Map.fetch!(state.execution_ids, ext_id) - case Results.has_result?(state.db, execution_id) do + case Results.has_completion?(state.db, execution_id) do {:ok, false} -> - # Server-detected abandonment. Write both tables — no worker - # will send notify_terminated for this execution. + # Server-detected abandonment. No worker will send + # notify_terminated for this execution. {:ok, state} = process_result(state, execution_id, :abandoned) complete_execution(state, execution_id) @@ -1713,7 +1717,10 @@ defmodule Coflux.Orchestration.Server do end end) - # Check for executions reported but not in starting/executing - abort if already have result + # Executions the worker reported but the server has already + # finalised (completion recorded) — tell the worker to abort + # its local process. Mid-drain executions (result but no + # completion) are legitimate and left alone. state = reported_ext_ids |> MapSet.difference(session.starting) @@ -1721,7 +1728,7 @@ defmodule Coflux.Orchestration.Server do |> Enum.reduce(state, fn ext_id, state -> case Map.fetch(state.execution_ids, ext_id) do {:ok, execution_id} -> - case Results.has_result?(state.db, execution_id) do + case Results.has_completion?(state.db, execution_id) do {:ok, false} -> state @@ -2101,7 +2108,7 @@ defmodule Coflux.Orchestration.Server do {:ok, closed_at} -> state = state - |> push_stream_closed(execution_id, index, error) + |> push_stream_closed(execution_id, index, reason, error) |> notify_stream_closed(execution_id, index, reason, error, closed_at) |> drop_stream_producer({execution_id, index}) |> flush_notifications() @@ -3886,15 +3893,16 @@ defmodule Coflux.Orchestration.Server do # Cancel a single execution: record :cancelled, abort if assigned, cancel descendants. defp do_cancel_execution(state, execution_id, workspace_id) do - # Write the results row and fire result-time notifications to mark the - # execution as cancelled. The completion row isn't written until the - # worker confirms termination (via notify_terminated), so consumers can - # distinguish "cancelling" (results present, completion absent) from - # "cancelled" (both present). + # Write the completion row (kind = cancelled) and fire notifications. + # The result row is left untouched: if the worker already produced a + # value, it stays; otherwise nothing is recorded. UI shows "cancelled" + # via the completion kind, with any prior result visible in the + # sidebar. state = case record_and_notify_result(state, execution_id, :cancelled, nil) do {:ok, state} -> state {:error, :already_recorded} -> state + {:error, :already_completed} -> state end # Close any open streams so iterating consumers stop waiting. Any @@ -5152,7 +5160,23 @@ defmodule Coflux.Orchestration.Server do {nil, nil, nil, nil} end - Map.put(results, execution_id, {result, result_at, completed_at, result_created_by}) + completion = + case Results.get_completion(db, execution_id) do + {:ok, {kind, successor_id, successor_ref_id, _, _}} -> + %{ + kind: Atom.to_string(kind), + successor: build_completion_successor(db, successor_id, successor_ref_id) + } + + {:ok, nil} -> + nil + end + + Map.put( + results, + execution_id, + {result, result_at, completed_at, result_created_by, completion} + ) end) steps = @@ -5212,7 +5236,7 @@ defmodule Coflux.Orchestration.Server do {:ok, workspace_external_id} = Workspaces.get_workspace_external_id(db, workspace_id) - {result, result_at, completed_at, result_created_by} = + {result, result_at, completed_at, result_created_by, completion} = Map.fetch!(results, execution_id) execution_groups = @@ -5251,7 +5275,7 @@ defmodule Coflux.Orchestration.Server do ) ) - streams = streams_with_resolved_errors(db, execution_id) + streams = streams_with_resolved_reasons(db, execution_id) {attempt, %{ @@ -5263,6 +5287,7 @@ defmodule Coflux.Orchestration.Server do assigned_at: assigned_at, result_at: result_at, completed_at: completed_at, + completion: completion, groups: execution_groups, assets: assets, dependencies: dependencies, @@ -5943,8 +5968,22 @@ defmodule Coflux.Orchestration.Server do end case Results.record_result(state.db, execution_id, result, created_by) do - {:ok, result_at} -> - state = fire_result_notifications(state, execution_id, result, result_at, created_by) + {:ok, timestamp} -> + state = fire_result_notifications(state, execution_id, result, timestamp, created_by) + + # For result shapes that write the completion synchronously + # (cancelled / abandoned / timeout / deferred / cached / spawned / + # suspended / recurred), fire the completion-time notifications + # now — queue removal, run-topic `completion` update, waiter + # wake-ups. For value/error the completion is written later via + # complete_execution, which fires these itself. + state = + if writes_completion_immediately?(result) do + fire_completion_notification(state, execution_id, timestamp) + else + state + end + {:ok, state} {:error, reason} -> @@ -5952,8 +5991,21 @@ defmodule Coflux.Orchestration.Server do end end + defp writes_completion_immediately?(:cancelled), do: true + defp writes_completion_immediately?({:abandoned, _}), do: true + defp writes_completion_immediately?({:crashed, _}), do: true + defp writes_completion_immediately?({:timeout, _}), do: true + defp writes_completion_immediately?({:suspended, _}), do: true + defp writes_completion_immediately?({:recurred, _}), do: true + defp writes_completion_immediately?({:deferred, _}), do: true + defp writes_completion_immediately?({:deferred, _, _}), do: true + defp writes_completion_immediately?({:cached, _}), do: true + defp writes_completion_immediately?({:cached, _, _}), do: true + defp writes_completion_immediately?({:spawned, _}), do: true + defp writes_completion_immediately?({:spawned, _, _}), do: true + defp writes_completion_immediately?(_), do: false + defp fire_result_notifications(state, execution_id, result, result_at, created_by) do - {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) {:ok, successors} = Runs.get_result_successors(state.db, execution_id) {:ok, {r, s, a}} = Runs.get_execution_key(state.db, execution_id) execution_external_id = execution_external_id(r, s, a) @@ -5964,93 +6016,55 @@ defmodule Coflux.Orchestration.Server do |> update_dependencies_on_result(execution_id) |> unregister_pending_dependencies(execution_id) - final = is_result_final?(result) - built_result = build_result(result, state.db) - - principal = - case Principals.get_principal(state.db, created_by) do - {:ok, {type, external_id}} -> %{type: type, external_id: external_id} - {:ok, nil} -> nil - end - - ws_ext_id = workspace_external_id(state, workspace_id) + # Cancellation after a value result was already recorded: the value + # stays authoritative for consumer resolution (a consumer that saw + # the value before the cancel must keep seeing it). The run-topic + # `:result` notification was already fired at value-record time, so + # re-firing it with `:cancelled` would clobber the value in the UI + # and in any dependent executions' nested result state. The + # `:completion` notification carries the cancelled status via its + # kind field, which is what the UI reads for the badge. + skip_topic_notifications = + result == :cancelled and has_value_result?(state.db, execution_id) state = - successors - |> Enum.reduce(state, fn {run_external_id, successor_id}, state -> - cond do - successor_id == execution_id -> - notify_listeners( - state, - {:run, run_external_id}, - {:result, execution_external_id, built_result, result_at, principal} - ) - - final -> - {:ok, {r2, s2, a2}} = Runs.get_execution_key(state.db, successor_id) - successor_external_id = execution_external_id(r2, s2, a2) - - notify_listeners( - state, - {:run, run_external_id}, - # TODO: better name? - {:result_result, successor_external_id, built_result, result_at, principal} - ) - - true -> - state - end - end) - |> then(fn state -> - case untrack_run_execution(state, r, execution_id) do - {{root_module, root_target}, state} -> - notify_listeners( - state, - {:modules, ws_ext_id}, - {:completed, {root_module, root_target}, r, execution_external_id} - ) + if skip_topic_notifications do + state + else + final = is_result_final?(result) + built_result = build_result(result, state.db) - {nil, state} -> - state - end - end) - |> notify_listeners( - {:queue, ws_ext_id}, - {:completed, execution_external_id} - ) + principal = + case Principals.get_principal(state.db, created_by) do + {:ok, {type, external_id}} -> %{type: type, external_id: external_id} + {:ok, nil} -> nil + end - # Check if any input dependencies became inactive. Route the - # :inputs topic notification to the INPUT's workspace (matching - # :input_dependency_active in the resolve_input handler), not the - # completing execution's workspace — these differ when an execution - # in a child workspace resolved an input created in a parent. - state = - case Inputs.get_input_dependencies_for_execution(state.db, execution_id) do - {:ok, deps} -> - Enum.reduce(deps, state, fn {input_id, input_ws_id}, state -> - if Inputs.has_active_dependency?(state.db, input_id) do - state - else - {:ok, run_ext_id, input_number} = - Inputs.get_input_run_and_number(state.db, input_id) + successors + |> Enum.reduce(state, fn {run_external_id, successor_id}, state -> + cond do + successor_id == execution_id -> + notify_listeners( + state, + {:run, run_external_id}, + {:result, execution_external_id, built_result, result_at, principal} + ) - input_ext_id = input_external_id(run_ext_id, input_number) - input_ws_ext_id = workspace_external_id(state, input_ws_id) + final -> + {:ok, {r2, s2, a2}} = Runs.get_execution_key(state.db, successor_id) + successor_external_id = execution_external_id(r2, s2, a2) - state - |> notify_listeners( - {:inputs, input_ws_ext_id}, - {:input_dependency_inactive, input_ext_id} - ) - |> notify_listeners( - {:input, input_ext_id}, - {:active, false} + notify_listeners( + state, + {:run, run_external_id}, + # TODO: better name? + {:result_result, successor_external_id, built_result, result_at, principal} ) - end - end) - _ -> - state + true -> + state + end + end) end # TODO: only if there's an execution waiting for this result? @@ -6059,39 +6073,138 @@ defmodule Coflux.Orchestration.Server do state end - # Write the completion row and fire a completion-time notification. For - # "crashed" cases (notify_terminated with no prior results row), also - # decides retry and fires result-time notifications with a synthesised - # :crashed shape. + defp has_value_result?(db, execution_id) do + case Results.get_result_payload(db, execution_id) do + {:ok, {:value, _}} -> true + _ -> false + end + end + + # Notify input-topic subscribers when any input this execution depended on + # has now become inactive. `has_active_dependency?` keys off the completion + # row, so this must run at completion time rather than result time — + # calling it any earlier would always see "still active". + defp notify_input_deactivations(state, execution_id) do + case Inputs.get_input_dependencies_for_execution(state.db, execution_id) do + {:ok, deps} -> + Enum.reduce(deps, state, fn {input_id, input_ws_id}, state -> + if Inputs.has_active_dependency?(state.db, input_id) do + state + else + {:ok, run_ext_id, input_number} = + Inputs.get_input_run_and_number(state.db, input_id) + + input_ext_id = input_external_id(run_ext_id, input_number) + # Route :inputs topic notification to the INPUT's workspace + # (matching :input_dependency_active in the resolve_input + # handler) — these differ when an execution in a child + # workspace resolved an input created in a parent. + input_ws_ext_id = workspace_external_id(state, input_ws_id) + + state + |> notify_listeners( + {:inputs, input_ws_ext_id}, + {:input_dependency_inactive, input_ext_id} + ) + |> notify_listeners( + {:input, input_ext_id}, + {:active, false} + ) + end + end) + + _ -> + state + end + end + + # Write the completion row and fire completion-time notifications. Called + # from notify_terminated (or the abandonment/crash paths). Decides any + # retry/successor from the persisted result row at this point rather than + # carrying a decision forward from result-record time — so the decision + # survives server restarts and epoch rotation. defp complete_execution(state, execution_id) do case Results.has_completion?(state.db, execution_id) do {:ok, true} -> state {:ok, false} -> - case Results.has_result?(state.db, execution_id) do - {:ok, true} -> - # Close any streams left open by the producer. Generator tasks - # normally close their streams explicitly; this is the backstop. - # We record a :lifecycle closure — the error surfaced to - # consumers is derived from the execution's recorded result - # rather than stored separately. - state = close_open_streams(state, execution_id) - - case Results.record_completion(state.db, execution_id) do - {:ok, completion_at} -> - fire_completion_notification(state, execution_id, completion_at) - - {:error, :already_completed} -> - state - end + case Results.get_result_payload(state.db, execution_id) do + {:ok, {:value, _}} -> + finalize_success_completion(state, execution_id) + + {:ok, {:error, type, message, frames, retryable}} -> + finalize_error_completion( + state, + execution_id, + {type, message, frames, retryable} + ) - {:ok, false} -> + {:ok, nil} -> handle_crashed(state, execution_id) end end end + # Value result + clean drain: no retry, kind=:succeeded. + defp finalize_success_completion(state, execution_id) do + state = close_open_streams(state, execution_id) + + case Results.record_completion(state.db, execution_id, :succeeded) do + {:ok, completion_at} -> + fire_completion_notification(state, execution_id, completion_at) + + {:error, :already_completed} -> + state + end + end + + # Error result: decide retry now (so the successor decision lands on + # the persisted completion row, not in transient in-memory state) and + # re-fire the :result notification with the retry link filled in. + defp finalize_error_completion(state, execution_id, {type, message, frames, retryable}) do + {:ok, step} = Runs.get_step_for_execution(state.db, execution_id) + {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) + + {retry_id, _recurred?, state} = + decide_and_create_successor( + state, + execution_id, + step, + workspace_id, + {:error, type, message, frames, retryable} + ) + + state = close_open_streams(state, execution_id) + + case Results.record_completion(state.db, execution_id, :errored, + successor_id: retry_id + ) do + {:ok, completion_at} -> + # Re-fire :result on the run topic so the error entry in the UI + # picks up the newly-created retry successor. We only need to do + # this when retry_id changed from nil (there was no successor at + # initial :result time) to something. + state = + if retry_id do + fire_result_notifications( + state, + execution_id, + {:error, type, message, frames, retry_id, retryable}, + nil, + nil + ) + else + state + end + + fire_completion_notification(state, execution_id, completion_at) + + {:error, :already_completed} -> + state + end + end + # No results row exists for this execution but notify_terminated has # arrived — the worker terminated without reporting. Decide retry, write # completion (no results row), fire notifications. @@ -6109,7 +6222,9 @@ defmodule Coflux.Orchestration.Server do # consumers derive the specific error from the execution's outcome. state = close_open_streams(state, execution_id) - case Results.record_completion(state.db, execution_id) do + case Results.record_completion(state.db, execution_id, :crashed, + successor_id: retry_id + ) do {:ok, completion_at} -> # Result-time notifications weren't fired (no results row was ever # written), so fire them now alongside the completion notification. @@ -6128,19 +6243,27 @@ defmodule Coflux.Orchestration.Server do # subscriber. Streams already closed by the producer (clean or errored) # are left untouched. # - # The closure is recorded with reason :lifecycle — no error is stored - # on the closure row. Consumers that need to surface an error derive - # it from the execution's recorded result (see derive_lifecycle_error). + # The closure is recorded on disk with reason :lifecycle (no error on + # the closure row). On the wire, we resolve that to a specific reason + # (:cancelled / :abandoned / :crashed / :timeout / :errored) by looking + # at the execution's completion — consumers then decide how to handle + # each case. defp close_open_streams(state, execution_id) do {:ok, indexes} = Streams.get_open_streams_for_execution(state.db, execution_id) - push_error = derive_lifecycle_error(state.db, execution_id) + {push_reason, push_error} = derive_lifecycle_info(state.db, execution_id) Enum.reduce(indexes, state, fn index, state -> case Streams.close_stream(state.db, execution_id, index, :lifecycle) do {:ok, closed_at} -> state - |> push_stream_closed(execution_id, index, push_error) - |> notify_stream_closed(execution_id, index, :lifecycle, push_error, closed_at) + |> push_stream_closed(execution_id, index, push_reason, push_error) + |> notify_stream_closed( + execution_id, + index, + push_reason, + push_error, + closed_at + ) |> drop_stream_producer({execution_id, index}) {:error, :already_closed} -> @@ -6149,13 +6272,17 @@ defmodule Coflux.Orchestration.Server do end) end - # Returns the streams list for `execution_id` with :lifecycle closures' - # errors resolved from the execution's recorded result. Shape: - # `{index, buffer, opened_at, closed_at | nil, reason | nil, error | nil}` — - # reason is retained so the topic can colour open vs complete vs - # errored vs lifecycle distinctly; buffer is passed through for the - # topic to display. - defp streams_with_resolved_errors(db, execution_id) do + # Returns the streams list for `execution_id` for the run topic's + # initial state. Shape: + # `{index, buffer, opened_at, closed_at | nil, reason | nil, error | nil}` + # + # DB closures are recorded with reason :complete / :errored / :lifecycle. + # For :lifecycle we resolve the actual cause (from the execution's + # completion kind) and surface that directly as the reason — + # :cancelled / :abandoned / :crashed / :timeout / :errored — so Studio + # doesn't have to deal with a generic "lifecycle" bucket. `error` is + # non-nil only when the reason is :errored. + defp streams_with_resolved_reasons(db, execution_id) do {:ok, rows} = Streams.get_streams_with_closures_for_execution(db, execution_id) Enum.map(rows, fn @@ -6163,52 +6290,127 @@ defmodule Coflux.Orchestration.Server do {index, buffer, opened_at, nil, nil, nil} {index, buffer, opened_at, closed_at, :lifecycle, _} -> - {index, buffer, opened_at, closed_at, :lifecycle, - derive_lifecycle_error(db, execution_id)} + {resolved_reason, resolved_error} = derive_lifecycle_info(db, execution_id) + {index, buffer, opened_at, closed_at, resolved_reason, resolved_error} {index, buffer, opened_at, closed_at, reason, error} -> {index, buffer, opened_at, closed_at, reason, error} end) end - # Build a {type, message, frames} triple describing why a lifecycle - # closure happened, by looking at the execution's recorded result. Used - # both when pushing live closures to subscribers and when late - # subscribers attach to an already-closed stream. Returns nil if the - # execution has no result yet (shouldn't happen in practice — lifecycle - # closures are driven by complete_execution which only runs after a - # result is recorded). - defp derive_lifecycle_error(db, execution_id) do - case Results.get_result(db, execution_id) do - {:ok, {{:error, type, message, frames, _, _}, _, _, _}} -> - {type, message, frames} - - {:ok, {:cancelled, _, _, _}} -> - {"Coflux.ExecutionCancelled", "execution cancelled", []} - - {:ok, {{:abandoned, _}, _, _, _}} -> - {"Coflux.ExecutionAbandoned", "execution abandoned", []} - - {:ok, {{:crashed, _}, _, _, _}} -> - {"Coflux.ExecutionCrashed", "worker terminated", []} - - {:ok, {{:timeout, _}, _, _, _}} -> - {"Coflux.ExecutionTimeout", "execution timed out", []} + # Derive a semantic reason + optional error for a lifecycle stream + # closure, from the execution's completion kind. Used when pushing + # closures to live consumers and when late subscribers attach to + # already-closed streams. + # + # Returns `{reason, error}` where: + # * `reason` is `:cancelled | :abandoned | :crashed | :timeout | + # :errored | nil` — the shape of the ending, not a fabricated + # exception string. Clients (Python adapter, Studio) decide how + # to represent each reason in their own idioms. + # * `error` is non-nil only when `reason == :errored` — then it's + # the producer's actual `{type, message, frames}`, propagated + # so consumers see the same exception the producer raised. + # + # Keys off the completion kind directly (rather than the logical-result + # tuple) so cancelled-with-value — where `get_result` returns + # `{:value, _}` but the completion kind is `:cancelled` — still + # propagates the cancellation signal to stream consumers. + defp derive_lifecycle_info(db, execution_id) do + case Results.get_completion(db, execution_id) do + {:ok, {:cancelled, _, _, _, _}} -> + {:cancelled, nil} + + {:ok, {:abandoned, _, _, _, _}} -> + {:abandoned, nil} + + {:ok, {:crashed, _, _, _, _}} -> + {:crashed, nil} + + {:ok, {:timeout, _, _, _, _}} -> + {:timeout, nil} + + {:ok, {:errored, _, _, _, _}} -> + # Error payload lives on the results row — pull it so consumers + # see the producer's actual exception. + case Results.get_result_payload(db, execution_id) do + {:ok, {:error, type, message, frames, _}} -> {:errored, {type, message, frames}} + _ -> {:errored, nil} + end _ -> - nil + {nil, nil} end end defp fire_completion_notification(state, execution_id, completion_at) do {:ok, {r, s, a}} = Runs.get_execution_key(state.db, execution_id) execution_external_id = execution_external_id(r, s, a) + {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) + ws_ext_id = workspace_external_id(state, workspace_id) - notify_listeners( - state, - {:run, r}, - {:completion, execution_external_id, completion_at} - ) + # Error results become resolvable only once the completion lands (the + # retry decision is on the completion). Any waiters parked at result + # time because the result was still pending need to be re-evaluated now. + state = notify_waiting(state, execution_id) + + {kind, successor} = + case Results.get_completion(state.db, execution_id) do + {:ok, {kind_atom, successor_id, successor_ref_id, _, _}} -> + {kind_atom, build_completion_successor(state.db, successor_id, successor_ref_id)} + + {:ok, nil} -> + {nil, nil} + end + + state = + state + |> notify_listeners( + {:run, r}, + {:completion, execution_external_id, kind, successor, completion_at} + ) + # Queue / workflow-list bookkeeping is completion-driven: the + # execution is considered "done" for the queue and the module-level + # running-workflow tracker only once the completion is recorded. + # An execution with a value result but no completion (streams still + # draining) continues to show up as running. + |> then(fn state -> + case untrack_run_execution(state, r, execution_id) do + {{root_module, root_target}, state} -> + notify_listeners( + state, + {:modules, ws_ext_id}, + {:completed, {root_module, root_target}, r, execution_external_id} + ) + + {nil, state} -> + state + end + end) + |> notify_listeners( + {:queue, ws_ext_id}, + {:completed, execution_external_id} + ) + |> notify_input_deactivations(execution_id) + + state + end + + # Shape the successor on a completion for the run topic. Same-epoch + # integer ids get resolved to their external form; cross-epoch refs go + # out as their resolved run/step/attempt triple. + defp build_completion_successor(_db, nil, nil), do: nil + + defp build_completion_successor(db, successor_id, nil) when is_integer(successor_id) do + case Runs.get_execution_key(db, successor_id) do + {:ok, {r, s, a}} -> %{type: "execution", id: execution_external_id(r, s, a)} + _ -> nil + end + end + + defp build_completion_successor(db, nil, successor_ref_id) when is_integer(successor_ref_id) do + {ext_id, _module, _target} = resolve_execution_ref(db, successor_ref_id) + %{type: "execution", id: ext_id} end defp process_result(state, execution_id, result, created_by \\ nil) do @@ -6220,8 +6422,19 @@ defmodule Coflux.Orchestration.Server do {:ok, step} = Runs.get_step_for_execution(state.db, execution_id) {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) + # Retry decisions for error results are deferred to + # complete_execution — they need to survive server restart, and + # the error payload is already persisted on the results row so + # the decision can be reconstructed from there. Every other shape + # (value with recurrent → :recurred, suspended, abandoned, crashed, + # timeout) still needs its successor decided here so the + # compat-shim-written completion carries the correct link. {retry_id, recurred?, state} = - decide_and_create_successor(state, execution_id, step, workspace_id, result) + if match?({:error, _, _, _, _}, result) do + {nil, false, state} + else + decide_and_create_successor(state, execution_id, step, workspace_id, result) + end result = transform_result_with_successor(result, retry_id, recurred?) @@ -6292,17 +6505,18 @@ defmodule Coflux.Orchestration.Server do result_retryable?(result) && step.retry_limit > 0 -> # Limited retries - check consecutive failures. Exclude the current - # execution's row so this works regardless of whether its results row - # has already been written (deferred path) or not (immediate path). - # A nil type indicates a crashed execution (completion without a - # results row) — counted as a failure. + # execution so this works whether or not its completion has been + # written yet. Failure kinds are errored/abandoned/crashed/timeout — + # the same set the retry predicate uses. {:ok, rows} = - Runs.get_step_result_types(state.db, step.id, step.retry_limit + 2) + Runs.get_step_completion_kinds(state.db, step.id, step.retry_limit + 2) + + failure_kinds = Results.failure_kinds() consecutive_failures = rows - |> Enum.reject(fn {id, _type} -> id == execution_id end) - |> Enum.take_while(fn {_id, type} -> type in [0, 2, 8] or is_nil(type) end) + |> Enum.reject(fn {id, _kind} -> id == execution_id end) + |> Enum.take_while(fn {_id, kind} -> kind in failure_kinds end) |> Enum.count() if consecutive_failures < step.retry_limit do @@ -6374,8 +6588,13 @@ defmodule Coflux.Orchestration.Server do {:ok, nil} -> {:pending, execution_id} - {:ok, {result, _created_at, _completion_created_at, _created_by}} -> + {:ok, {result, _created_at, completion_at, _created_by}} -> case result do + # Error payload but no completion yet — retry decision hasn't been + # made. Treat as pending so the caller waits for the completion. + {:error, _, _, _, nil, _retryable} when is_nil(completion_at) -> + {:pending, execution_id} + {:error, _, _, _, execution_id, _retryable} when not is_nil(execution_id) -> resolve_result(db, execution_id) @@ -7529,6 +7748,11 @@ defmodule Coflux.Orchestration.Server do end end + # Fire the stream-closed notification on run + stream topics. `reason` + # is a semantic atom from the full set (:complete / :errored / + # :cancelled / :abandoned / :crashed / :timeout) — Studio renders each + # directly in UI-appropriate language, rather than displaying a + # fabricated exception type. `error` is non-nil only for :errored. defp notify_stream_closed(state, execution_id, index, reason, error, closed_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) @@ -7539,7 +7763,7 @@ defmodule Coflux.Orchestration.Server do {type, message, _frames} -> %{type: type, message: message} end - reason_str = Atom.to_string(reason) + reason_str = if reason, do: Atom.to_string(reason) state = notify_listeners( @@ -7598,15 +7822,20 @@ defmodule Coflux.Orchestration.Server do nil {:ok, {reason, stored_error, closed_at}} -> - error = + # DB stores `:lifecycle` for closures driven by the producer + # execution ending; on read we resolve that to the specific + # cause (:cancelled / :abandoned / :crashed / :timeout / + # :errored) so clients don't need to know about the internal + # bucket. `error` only accompanies a genuine :errored close. + {effective_reason, effective_error} = case reason do - :lifecycle -> derive_lifecycle_error(state.db, execution_id) - _ -> stored_error + :lifecycle -> derive_lifecycle_info(state.db, execution_id) + _ -> {reason, stored_error} end %{ - reason: Atom.to_string(reason), - error: encode_stream_error_summary(error), + reason: if(effective_reason, do: Atom.to_string(effective_reason)), + error: encode_stream_error_summary(effective_error), closedAt: closed_at } end @@ -7756,11 +7985,14 @@ defmodule Coflux.Orchestration.Server do # subscription synchronously — matches push_stream_item's # behaviour. Without this, a consumer that subscribed after # appends with a bounded filter would wait forever for a close - # that never comes. + # that never comes. Filter-exhaustion is a "complete" outcome + # from the consumer's perspective: they've received everything + # that was addressed to them. state |> send_to_consumer( sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + {:stream_closed, sub.consumer_execution_external_id, subscription_id, "complete", + nil} ) |> drop_subscription(key) else @@ -7821,12 +8053,15 @@ defmodule Coflux.Orchestration.Server do ) # If the filter is exhausted (e.g. slice reached its stop), close - # the subscription early — no more items will match. + # the subscription early — no more items will match. Treated + # as a "complete" close for the consumer (they got everything + # that was addressed to them). if filter_exhausted?(sub.filter, sequence + 1) do state |> send_to_consumer( sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, nil} + {:stream_closed, sub.consumer_execution_external_id, subscription_id, "complete", + nil} ) |> drop_subscription(key) else @@ -7840,12 +8075,16 @@ defmodule Coflux.Orchestration.Server do refresh_stream_demand(state, stream_key) end - # On close, tell every subscriber. Error is either nil (clean close) or a - # {type, message, frames} triple — same shape as Streams.close_stream takes. - defp push_stream_closed(state, producer_execution_id, index, error) do + # On close, tell every subscriber. `reason` is a semantic atom + # (`:complete | :errored | :cancelled | :abandoned | :crashed | + # :timeout`) — the client chooses how to represent each in its own + # idiom. `error` is non-nil only when `reason == :errored`, carrying + # the producer's actual `{type, message, frames}`. + defp push_stream_closed(state, producer_execution_id, index, reason, error) do subscribers = Map.get(state.stream_subscribers, {producer_execution_id, index}, MapSet.new()) + reason_str = if reason, do: Atom.to_string(reason) encoded_error = encode_stream_error(error) Enum.reduce(subscribers, state, fn key, state -> @@ -7855,14 +8094,17 @@ defmodule Coflux.Orchestration.Server do state |> send_to_consumer( sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, encoded_error} + {:stream_closed, sub.consumer_execution_external_id, subscription_id, reason_str, + encoded_error} ) |> drop_subscription(key) end) end - # Common wire encoding for stream_closed errors. Frames are included so - # consumers can reconstruct tracebacks for debuggability. + # Wire encoding for the producer's actual error on an `:errored` close. + # Frames are included so consumers can reconstruct tracebacks for + # debuggability. Lifecycle reasons (:cancelled/:abandoned/...) don't + # go through this — they're conveyed as the reason atom alone. defp encode_stream_error(nil), do: nil defp encode_stream_error({type, message, frames}) do @@ -7896,18 +8138,23 @@ defmodule Coflux.Orchestration.Server do {:ok, nil} -> state - {:ok, {reason, error, _closed_at}} -> - resolved_error = + {:ok, {reason, stored_error, _closed_at}} -> + # Resolve :lifecycle to the specific cause for the wire — same + # treatment as live closures so late subscribers don't get a + # less-informative signal than those attached at close time. + {effective_reason, effective_error} = case reason do - :lifecycle -> derive_lifecycle_error(state.db, sub.producer_execution_id) - _ -> error + :lifecycle -> derive_lifecycle_info(state.db, sub.producer_execution_id) + _ -> {reason, stored_error} end + reason_str = if effective_reason, do: Atom.to_string(effective_reason) + state |> send_to_consumer( sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, - encode_stream_error(resolved_error)} + {:stream_closed, sub.consumer_execution_external_id, subscription_id, reason_str, + encode_stream_error(effective_error)} ) |> drop_subscription(key) end @@ -7971,6 +8218,9 @@ defmodule Coflux.Orchestration.Server do end # Clean up an execution's state and send an abort message to the worker. + # If the execution has already terminated (completion recorded), there's + # nothing to abort — skip silently. Only warn when an actively-running + # execution unexpectedly has no session. defp abort_execution(state, execution_ext_id) do state = cleanup_execution(state, execution_ext_id) @@ -7979,7 +8229,25 @@ defmodule Coflux.Orchestration.Server do send_session(state, session_id, {:abort, execution_ext_id}) :error -> - Logger.warning("Couldn't locate session for execution #{execution_ext_id}. Ignoring.") + already_completed? = + case Map.fetch(state.execution_ids, execution_ext_id) do + {:ok, execution_id} -> + case Results.has_completion?(state.db, execution_id) do + {:ok, done?} -> done? + end + + :error -> + # No internal id mapped — execution is long gone (e.g., + # cache rotation cleared the cache). Treat as terminated. + true + end + + unless already_completed? do + Logger.warning( + "Couldn't locate session for execution #{execution_ext_id}. Ignoring." + ) + end + state end end diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index 11519c98..476773b3 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -94,6 +94,7 @@ defmodule Coflux.Topics.Run do executeAfter: execute_after, assignedAt: nil, completedAt: nil, + completion: nil, groups: %{}, assets: %{}, dependencies: @@ -200,10 +201,15 @@ defmodule Coflux.Topics.Run do defp process_notification( topic, - {:completion, execution_external_id, completion_at} + {:completion, execution_external_id, kind, successor, completion_at} ) do update_execution(topic, execution_external_id, fn topic, base_path -> - Topic.set(topic, base_path ++ [:completedAt], completion_at) + topic + |> Topic.set(base_path ++ [:completedAt], completion_at) + |> Topic.set(base_path ++ [:completion], %{ + kind: Atom.to_string(kind), + successor: successor + }) end) end @@ -387,6 +393,7 @@ defmodule Coflux.Topics.Run do assignedAt: execution.assigned_at, resultAt: execution.result_at, completedAt: execution.completed_at, + completion: execution.completion, groups: execution.groups, assets: Map.new(execution.assets, fn {external_asset_id, asset} -> diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index 022e44f9..f54bf464 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -1,23 +1,110 @@ --- Add completions table — a pure termination marker. The existing results --- table continues to hold the disposition (including any successor), written --- at result-arrival time. A completions row is written separately at --- notify_terminated time, so its timestamp reflects when the worker's --- process actually finished shutting down. +-- Split the old `results` table into two: -- --- This enables streaming support: a results row can be written with stream --- handles while the process keeps running, with completions written later --- when streams have drained. +-- * `results` — pure payload: value_id XOR error_id (+ retryable flag +-- for errors). Written by the worker when the task body +-- produces a value/error. +-- * `completions` — terminal-state marker for the execution: kind (with +-- a broader vocabulary than results' old `type`), +-- optional successor (retry / suspended / recurred / +-- deferred / cached / spawned), and created_by (which +-- moved from results). +-- +-- The lifecycle primary is the completion: an execution is "running" iff no +-- completions row exists. A value result is available for downstream +-- resolution as soon as it's written (before completion, so consumers don't +-- block on stream drain); an error result can't resolve until the completion +-- tells us whether there's a retry successor. + +-- The new results table. Payload only. `value_id` and `error_id` are +-- mutually exclusive, enforced by CHECK. `retryable` is the `when`-callback +-- result from the worker and only meaningful for errors. +CREATE TABLE results_new ( + execution_id INTEGER PRIMARY KEY, + value_id INTEGER, + error_id INTEGER, + retryable INTEGER, + created_at INTEGER NOT NULL, + FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE, + FOREIGN KEY (value_id) REFERENCES values_ ON DELETE RESTRICT, + FOREIGN KEY (error_id) REFERENCES errors ON DELETE RESTRICT, + CHECK ((value_id IS NULL) != (error_id IS NULL)) +) STRICT; +-- The completions table. Terminal state for the execution. `kind` values: +-- 0 = succeeded — value result recorded, process ended cleanly +-- 1 = errored — error result recorded, process ended cleanly +-- 2 = abandoned — session expired before notify_terminated +-- 3 = crashed — notify_terminated without prior result +-- 4 = timeout — execution hit its timeout +-- 5 = cancelled — user cancelled (may or may not have a result row) +-- 6 = suspended — body called suspend; successor resumes later +-- 7 = recurred — recurrent execution scheduled its next run +-- 8 = deferred — execution deferred to another (memoisation / defer) +-- 9 = cached — execution resolved to an existing cache hit +-- 10 = spawned — execution spawned a continuation +-- +-- `successor_id` points at an execution in the same epoch; used for retry +-- chains and in-flight handoffs. `successor_ref_id` points at an +-- `execution_refs` row and is used post-epoch-rotation, when the target +-- integer id is no longer resolvable in the active DB. At most one is set. CREATE TABLE completions ( execution_id INTEGER PRIMARY KEY, + kind INTEGER NOT NULL, + successor_id INTEGER, + successor_ref_id INTEGER, created_at INTEGER NOT NULL, - FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE + created_by INTEGER REFERENCES principals ON DELETE SET NULL, + FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE, + FOREIGN KEY (successor_id) REFERENCES executions ON DELETE RESTRICT, + FOREIGN KEY (successor_ref_id) REFERENCES execution_refs ON DELETE RESTRICT ) STRICT; --- Every existing results row represents a terminated execution, so each --- produces a completions row with the same timestamp. -INSERT INTO completions (execution_id, created_at) - SELECT execution_id, created_at FROM results; +CREATE INDEX idx_completions_successor_id ON completions(successor_id); +CREATE INDEX idx_completions_successor_ref_id ON completions(successor_ref_id); + +-- Migrate completions first (one row per existing results row). Map the old +-- `type` enum onto the new `kind` enum. created_by moves from results to +-- completions. +INSERT INTO completions (execution_id, kind, successor_id, successor_ref_id, created_at, created_by) + SELECT + execution_id, + CASE type + WHEN 0 THEN 1 -- errored + WHEN 1 THEN 0 -- succeeded + WHEN 2 THEN 2 -- abandoned + WHEN 3 THEN 5 -- cancelled + WHEN 4 THEN 8 -- deferred + WHEN 5 THEN 9 -- cached + WHEN 6 THEN 6 -- suspended + WHEN 7 THEN 10 -- spawned + WHEN 8 THEN 4 -- timeout + WHEN 9 THEN 7 -- recurred + END, + successor_id, + successor_ref_id, + created_at, + created_by + FROM results; + +-- Migrate payloads into new results. Only rows that actually carry a value +-- or error get copied — in-flight deferred/cached/spawned (types 4/5/7 +-- without a value_id) don't get a results row. `retryable` is cleared for +-- non-error rows to keep the new CHECK/intent tight. +INSERT INTO results_new (execution_id, value_id, error_id, retryable, created_at) + SELECT + execution_id, + value_id, + error_id, + CASE WHEN error_id IS NOT NULL THEN retryable ELSE NULL END, + created_at + FROM results + WHERE error_id IS NOT NULL OR value_id IS NOT NULL; + +-- Replace the old results table. +DROP INDEX idx_results_successor_id; +DROP INDEX idx_results_successor_ref_id; +DROP TABLE results; +ALTER TABLE results_new RENAME TO results; -- Streams — ordered, append-only sequences of values produced by an -- execution. Each stream is identified by (execution_id, index), where @@ -68,7 +155,7 @@ CREATE TABLE stream_items ( -- 1 = errored — producer raised an error (stored in errors via error_id) -- 2 = lifecycle — closed implicitly because the producer execution ended -- (cancel/crash/abandon/error). The specific error is --- derived on read by looking up the execution's result, +-- derived on read by looking up the execution's completion, -- so we don't duplicate that state here. CREATE TABLE stream_closures ( execution_id INTEGER NOT NULL, diff --git a/tests/test_streams.py b/tests/test_streams.py index f64cc9fc..d0571289 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -231,9 +231,8 @@ def test_subscribe_to_unknown_producer_closes_immediately(worker): _items, closed = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) - err = closed.get("error") - assert err is not None - assert err["type"] == "Coflux.StreamNotFound" + assert closed.get("reason") == "producer_not_found" + assert closed.get("error") is None def test_topic_exposes_stream_state(worker): @@ -270,9 +269,10 @@ def test_topic_exposes_stream_state(worker): assert streams["1"]["error"] == {"type": "RuntimeError", "message": "bad"} -def test_cancellation_closes_streams_with_cancelled_error(worker): +def test_cancellation_closes_streams_with_cancelled_reason(worker): """Cancel an execution mid-stream: the subscriber receives a closure - carrying the ExecutionCancelled error synthesised by close_open_streams. + carrying reason="cancelled" — no fabricated exception type, the + adapter maps the reason to its own idiom. """ targets = [workflow("test", "producer"), workflow("test", "consumer")] @@ -297,9 +297,8 @@ def test_cancellation_closes_streams_with_cancelled_error(worker): cons_ex.conn.complete(cons_ex.execution_id) assert [item[1]["value"] for item in items] == ["before"] - err = closed.get("error") - assert err is not None - assert err["type"] == "Coflux.ExecutionCancelled" + assert closed.get("reason") == "cancelled" + assert closed.get("error") is None def test_multiple_subscribers_get_independent_delivery(worker): From aecdb713b2ac04a106581a87f10d4decde57cfec Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Sun, 19 Apr 2026 23:15:21 +0100 Subject: [PATCH 16/25] Consolidate stream filters into a computed stride --- adapters/python/coflux/context.py | 31 ++++++-- adapters/python/coflux/metric.py | 71 +++++++++--------- adapters/python/coflux/models.py | 90 +++++++++++++++++------ adapters/python/coflux/protocol.py | 38 ++++++---- adapters/python/coflux/serialization.py | 31 +++++--- adapters/python/coflux/streams.py | 62 ++++++++-------- cli/internal/adapter/protocol.go | 10 ++- cli/internal/pool/pool.go | 10 ++- cli/internal/worker/worker.go | 6 +- server/lib/coflux/handlers/worker.ex | 4 +- server/lib/coflux/orchestration/server.ex | 75 +++++++++---------- tests/support/executor.py | 11 ++- tests/support/protocol.py | 20 ++--- tests/test_streams.py | 19 ++--- 14 files changed, 287 insertions(+), 191 deletions(-) diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index b25c68a7..141d9c79 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -7,6 +7,7 @@ import fnmatch as fnmatch import hashlib import json +import threading from contextlib import contextmanager from pathlib import Path from typing import Any, Callable, Iterator @@ -93,6 +94,12 @@ def __init__(self, execution_id: str, working_dir: Path | None = None): # poll_handle to avoid a round-trip for handles that have already # been seen in this context's lifetime. self._resolved: dict[tuple[str, str], dict[str, Any]] = {} + # Guards the mutable collections above. Stream driver threads, the + # main task thread, and any user-spawned threads may call methods + # on this context concurrently; the lock protects check-then-set + # patterns (metric definition, group registration, resolve cache) + # from racing. + self._lock = threading.Lock() # Owns generator streams for this execution. Generators encountered # during serialization (of the return value OR of submit arguments) # are registered here and driven in background threads. @@ -211,7 +218,8 @@ def select( if winner is None: raise RuntimeError(f"Unexpected select response: {response}") - self._resolved[_handle_key(handles[winner])] = response + with self._lock: + self._resolved[_handle_key(handles[winner])] = response return winner def resolve_handle(self, handle: Any) -> Any: @@ -223,13 +231,17 @@ def resolve_handle(self, handle: Any) -> Any: to the deserialized value. """ key = _handle_key(handle) - if key not in self._resolved: + with self._lock: + cached = self._resolved.get(key) + if cached is None: if self.select([handle]) is None: # The wait expired before the handle resolved. Only reachable # from inside a `cf.suspense(timeout=...)` scope; otherwise the # server either resolves or kills the process. raise TimeoutError("timed out waiting for handle to resolve") - return _unwrap_response(self._resolved[key], handle._parser) + with self._lock: + cached = self._resolved[key] + return _unwrap_response(cached, handle._parser) def poll_handle( self, @@ -244,11 +256,15 @@ def poll_handle( is applied to it. """ key = _handle_key(handle) - if key not in self._resolved: + with self._lock: + cached = self._resolved.get(key) + if cached is None: timeout_ms = int(timeout * 1000) if timeout else 0 if self.select([handle], suspend=False, timeout_ms=timeout_ms) is None: return default - return _unwrap_response(self._resolved[key], handle._parser) + with self._lock: + cached = self._resolved[key] + return _unwrap_response(cached, handle._parser) def get_asset_entries(self, asset_id: str) -> list[AssetEntry]: """Get all entries for an asset by ID.""" @@ -511,8 +527,9 @@ def log_error(self, message: str) -> None: @contextmanager def group(self, name: str | None = None) -> Iterator[None]: """Context manager for grouping child executions.""" - group_id = len(self._groups) - self._groups.append(name) + with self._lock: + group_id = len(self._groups) + self._groups.append(name) protocol.send_register_group(self.execution_id, group_id, name) token = _group_id.set(group_id) try: diff --git a/adapters/python/coflux/metric.py b/adapters/python/coflux/metric.py index bacf76db..4eb54a3d 100644 --- a/adapters/python/coflux/metric.py +++ b/adapters/python/coflux/metric.py @@ -162,40 +162,45 @@ def _build_config(self) -> dict: } def _ensure_defined(self, ctx: Any) -> None: + # Hold the context lock for the whole check-then-send-then-record + # cycle so two threads racing to define the same metric don't both + # send a `define_metric` notification (and so the config-equality + # checks are evaluated against a consistent snapshot). config = self._build_config() - existing = ctx._defined_metrics.get(self._key) - if existing is not None: - if existing != config: - raise ValueError( - f"Metric '{self._key}' already defined in this execution " - f"with different configuration" - ) - return - - if self._group is not None: - group_config = self._group._config() - existing_group = ctx._defined_groups.get(self._group.name) - if existing_group is not None and existing_group != group_config: - raise ValueError( - f"Group '{self._group.name}' already defined in this execution " - f"with inconsistent configuration" - ) - ctx._defined_groups[self._group.name] = group_config - - if self._scale is not None and self._scale.name is not None: - scale_config = self._scale._config() - existing_scale = ctx._defined_scales.get(self._scale.name) - if existing_scale is not None and existing_scale != scale_config: - raise ValueError( - f"Scale '{self._scale.name}' already defined in this execution " - f"with inconsistent configuration" - ) - ctx._defined_scales[self._scale.name] = scale_config - - protocol.send_define_metric( - ctx.execution_id, self._key, self._build_definition() - ) - ctx._defined_metrics[self._key] = config + with ctx._lock: + existing = ctx._defined_metrics.get(self._key) + if existing is not None: + if existing != config: + raise ValueError( + f"Metric '{self._key}' already defined in this execution " + f"with different configuration" + ) + return + + if self._group is not None: + group_config = self._group._config() + existing_group = ctx._defined_groups.get(self._group.name) + if existing_group is not None and existing_group != group_config: + raise ValueError( + f"Group '{self._group.name}' already defined in this " + f"execution with inconsistent configuration" + ) + ctx._defined_groups[self._group.name] = group_config + + if self._scale is not None and self._scale.name is not None: + scale_config = self._scale._config() + existing_scale = ctx._defined_scales.get(self._scale.name) + if existing_scale is not None and existing_scale != scale_config: + raise ValueError( + f"Scale '{self._scale.name}' already defined in this " + f"execution with inconsistent configuration" + ) + ctx._defined_scales[self._scale.name] = scale_config + + protocol.send_define_metric( + ctx.execution_id, self._key, self._build_definition() + ) + ctx._defined_metrics[self._key] = config def progress( diff --git a/adapters/python/coflux/models.py b/adapters/python/coflux/models.py index 193d76d8..10fbd877 100644 --- a/adapters/python/coflux/models.py +++ b/adapters/python/coflux/models.py @@ -210,6 +210,40 @@ def target(self) -> str: return self._target +Stride = t.Tuple[int, t.Optional[int], int] +"""A stride over the stream's sequence numbers: ``(start, stop, step)``. + +Matches positions ``start, start+step, start+2·step, …`` up to but not +including ``stop`` (or unbounded when ``stop is None``). Composes with +itself: any chain of ``slice``/``partition``/``stride`` calls reduces to +a single stride, so the wire never sees more than one filter. +""" + + +def _compose_stride(outer: Stride, inner: Stride) -> Stride: + """Stride of a stride. Given that we already have stride ``outer`` + selecting positions ``s₁ + k·step₁`` (``k < (e₁-s₁)/step₁``), apply + ``inner = (s₂, e₂, step₂)`` to those outputs: index ``k`` of the inner + result maps to index ``s₂ + k·step₂`` of the outer, which is original + position ``s₁ + (s₂ + k·step₂)·step₁``. + + Combined: ``start = s₁ + s₂·step₁``, ``step = step₁·step₂``, ``stop`` + is the tighter of the two constraints mapped back to original positions. + """ + s1, e1, step1 = outer + s2, e2, step2 = inner + new_start = s1 + s2 * step1 + new_step = step1 * step2 + inner_stop_mapped = s1 + e2 * step1 if e2 is not None else None + if e1 is None: + new_stop = inner_stop_mapped + elif inner_stop_mapped is None: + new_stop = e1 + else: + new_stop = min(e1, inner_stop_mapped) + return (new_start, new_stop, new_step) + + class Stream(t.Iterable[T]): """A handle to a stream produced by another execution. @@ -218,52 +252,60 @@ class Stream(t.Iterable[T]): starts a fresh subscription from sequence 0, so a stream can be iterated multiple times and each iteration sees the whole sequence. - ``partition`` and ``slice`` return new ``Stream`` views with an additional - filter; no server round-trip happens until iteration begins. + ``partition``, ``slice``, and ``stride`` return new ``Stream`` views + with the stride adjusted. Chained calls compose into a single stride + on the wire — no server-side pipelining logic needed. """ def __init__( self, id: str, - filters: tuple[dict[str, t.Any], ...] = (), + stride: Stride = (0, None, 1), ): # Opaque identifier of the form ``_``. # Users may see this in the CLI/Studio but shouldn't need to parse it. self._id = id - self._filters = filters + self._stride = stride @property def id(self) -> str: return self._id - def partition(self, n: int, i: int) -> "Stream[T]": - """Return a view of this stream where only sequences ``s`` with - ``s % n == i`` are delivered. Round-robin partitioning for parallel - consumers. + def stride( + self, + start: int = 0, + stop: int | None = None, + step: int = 1, + ) -> "Stream[T]": + """Return a view of this stream restricted to the positions + ``start, start+step, …`` up to (but not including) ``stop``. + Composes with any existing stride on this view. """ - if n < 1 or i < 0 or i >= n: - raise ValueError(f"invalid partition args: n={n}, i={i}") - return Stream( - self._id, - self._filters + ({"type": "partition", "n": n, "i": i},), - ) + if start < 0 or step < 1 or (stop is not None and stop < start): + raise ValueError( + f"invalid stride args: start={start}, stop={stop}, step={step}" + ) + return Stream(self._id, _compose_stride(self._stride, (start, stop, step))) def slice(self, start: int, stop: int | None = None) -> "Stream[T]": - """Return a view of this stream restricted to sequences ``[start, stop)``. + """Return a view restricted to sequences ``[start, stop)`` — + shorthand for ``stride(start, stop, 1)``. Equivalent to + ``itertools.islice`` on the source stream's items. + """ + return self.stride(start, stop, 1) - ``stop=None`` means unbounded. Equivalent to ``itertools.islice`` on - the source stream's items. + def partition(self, n: int, i: int) -> "Stream[T]": + """Return a view where only sequences ``s`` with ``s % n == i`` + are delivered — round-robin partitioning for parallel consumers. + Shorthand for ``stride(i, None, n)``. """ - if start < 0 or (stop is not None and stop < start): - raise ValueError(f"invalid slice args: start={start}, stop={stop}") - return Stream( - self._id, - self._filters + ({"type": "slice", "start": start, "stop": stop},), - ) + if n < 1 or i < 0 or i >= n: + raise ValueError(f"invalid partition args: n={n}, i={i}") + return self.stride(i, None, n) def __iter__(self) -> t.Iterator[T]: # Deferred import to avoid a cycle (streams.py imports serialization # which imports models for Execution/Input/Asset). from .streams import open_subscription - return open_subscription(self._id, self._filters) + return open_subscription(self._id, self._stride) diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index a3788e2c..d326aded 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -23,6 +23,11 @@ def __init__(self) -> None: # handlers) can emit messages concurrently; serialize writes so JSON # lines don't interleave. self._write_lock = threading.Lock() + # Guards `_next_id`. Separate from `_write_lock` so a slow write + # doesn't block other threads from minting their own ids — the + # Dispatcher routes responses by id so the order of wire writes + # doesn't matter to correctness. + self._id_lock = threading.Lock() def send_message(self, method: str, params: dict[str, Any] | None = None) -> None: """Send a notification message (no response expected).""" @@ -33,14 +38,17 @@ def send_message(self, method: str, params: dict[str, Any] | None = None) -> Non def send_request(self, method: str, params: dict[str, Any]) -> int: """Send a request and return the request ID.""" - self._next_id += 1 - req = { - "id": self._next_id, - "method": method, - "params": params, - } - self._write(req) - return self._next_id + with self._id_lock: + self._next_id += 1 + request_id = self._next_id + self._write( + { + "id": request_id, + "method": method, + "params": params, + } + ) + return request_id def send_response( self, @@ -537,12 +545,16 @@ def send_stream_subscribe( producer_execution_id: str, index: int, from_sequence: int, - filter: dict[str, Any] | None = None, + stride: dict[str, Any] | None = None, ) -> None: """Open a consumer subscription to a stream owned by another execution. - ``execution_id`` is the consumer's own execution — the server uses it to - track who's subscribed and where to push items. + ``execution_id`` is the consumer's own execution — the server uses it + to track who's subscribed and where to push items. ``stride`` is an + optional ``{"start": int, "stop": int|None, "step": int}`` dict + restricting which sequence positions are delivered; any chain of + slice/partition/stride calls on the handle composes into a single + stride before reaching here. """ params: dict[str, Any] = { "execution_id": execution_id, @@ -551,8 +563,8 @@ def send_stream_subscribe( "index": index, "from_sequence": from_sequence, } - if filter is not None: - params["filter"] = filter + if stride is not None: + params["stride"] = stride get_protocol().send_message("stream_subscribe", params) diff --git a/adapters/python/coflux/serialization.py b/adapters/python/coflux/serialization.py index ce962b6f..5912047a 100644 --- a/adapters/python/coflux/serialization.py +++ b/adapters/python/coflux/serialization.py @@ -123,12 +123,15 @@ def _encode(v: Any) -> Any: return {"type": "ref", "index": len(references) - 1} elif isinstance(v, Stream): # Pass-through: a Stream handle received from another execution - # (possibly with partition/slice filters layered on top) is - # being forwarded as an argument. Preserve the filter chain so - # the downstream consumer subscribes with the same filters. + # (possibly with slice/partition/stride layered on top) is + # being forwarded as an argument. Preserve the composed stride + # so the downstream consumer subscribes with the same view. encoded: dict[str, Any] = {"type": "stream", "id": v.id} - if v._filters: - encoded["filters"] = list(v._filters) + start, stop, step = v._stride + # Only include a stride entry if it's doing something — the + # trivial (0, None, 1) is the identity. + if (start, stop, step) != (0, None, 1): + encoded["stride"] = {"start": start, "stop": stop, "step": step} return encoded elif inspect.isgenerator(v) or inspect.isasyncgen(v): raise TypeError( @@ -255,11 +258,19 @@ def _decode(v: Any) -> Any: elif t == "stream": # Producer-owned stream reference. Self-contained — the # opaque `id` encodes the producer's execution id + the - # stream's index, and any filter chain (when the Stream - # was forwarded with partition/slice filters already - # applied) rides alongside. - filters = tuple(v.get("filters") or ()) - return Stream(v["id"], filters) + # stream's index. An optional ``stride`` describes the + # composed slice/partition/stride view (identity when + # absent). + stride_raw = v.get("stride") + if stride_raw: + stride = ( + stride_raw.get("start", 0), + stride_raw.get("stop"), + stride_raw.get("step", 1), + ) + else: + stride = (0, None, 1) + return Stream(v["id"], stride) else: return v else: diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index defc37ba..a54a9c37 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio +import contextvars import inspect import queue import threading @@ -134,9 +135,15 @@ def register(self, generator: Any, buffer: int | None) -> str: is_async = inspect.isasyncgen(generator) target = self._run_async if is_async else self._run + # Capture the context of the registering thread (usually the main + # executor thread) and run the generator body inside it, so any + # `cf.group` / `cf.suspense` scope active at registration time + # flows through to `cf.submit_task` and friends called from the + # generator body. Without this the driver thread sees a fresh + # context and would lose those settings. + parent_context = contextvars.copy_context() thread = threading.Thread( - target=target, - args=(index, generator), + target=lambda: parent_context.run(target, index, generator), name=f"stream-{self._execution_id}-{index}", daemon=False, ) @@ -421,16 +428,21 @@ def __init__(self) -> None: def _ensure_installed(self) -> None: # Register dispatcher handlers on first use. Deferred so importing # this module is free until a task actually iterates a stream. - if self._installed: - return - d = get_dispatcher() - d.register_notification("stream_items", self._on_items) - d.register_notification("stream_closed", self._on_closed) - # If stdin goes away before the server sends close messages, - # blocked iterators would hang on their queues forever. Push a - # synthetic closed sentinel into each so ``__next__`` raises. - d.add_close_callback(self._on_dispatcher_closed) - self._installed = True + # Locked so two consumer threads first-iterating a stream at the + # same time don't both register handlers — the dispatcher would + # silently replace the first, but registering `add_close_callback` + # twice would fire the close-handling twice on EOF. + with self._lock: + if self._installed: + return + d = get_dispatcher() + d.register_notification("stream_items", self._on_items) + d.register_notification("stream_closed", self._on_closed) + # If stdin goes away before the server sends close messages, + # blocked iterators would hang on their queues forever. Push + # a synthetic closed sentinel into each so ``__next__`` raises. + d.add_close_callback(self._on_dispatcher_closed) + self._installed = True def _on_dispatcher_closed(self) -> None: """Wake all active iterators — connection to the server is gone @@ -504,12 +516,14 @@ def parse_stream_id(id: str) -> tuple[str, int]: def open_subscription( stream_id: str, - filters: tuple[dict[str, Any], ...], + stride: tuple[int, int | None, int], ) -> Iterator[Any]: """Begin iterating a stream. Called by ``Stream.__iter__``. Allocates a subscription id, sends the subscribe message, and returns - an iterator that yields as items arrive. + an iterator that yields as items arrive. ``stride`` is a + ``(start, stop, step)`` tuple — any chain of slice/partition/stride + calls on the handle collapses to a single stride before this point. """ ctx = get_context() execution_id = ctx.execution_id @@ -519,27 +533,15 @@ def open_subscription( # producer_execution_id + index positionally. producer_execution_id, index = parse_stream_id(stream_id) - filter = _compose_filter(filters) + start, stop, step = stride + wire_stride = {"start": start, "stop": stop, "step": step} + protocol.send_stream_subscribe( execution_id, subscription_id, producer_execution_id, index, 0, - filter, + stride=wire_stride, ) return iterator - - -def _compose_filter( - filters: tuple[dict[str, Any], ...], -) -> dict[str, Any] | None: - """Collapse a list of filters for the wire. - - Empty → null. Single → pass through. Many → wrap in {"type": "chain"}. - """ - if not filters: - return None - if len(filters) == 1: - return filters[0] - return {"type": "chain", "filters": list(filters)} diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index 72ba0db6..81ba76b2 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -296,16 +296,18 @@ type StreamCloseError struct { Traceback string `json:"traceback"` } -// StreamSubscribeParams for stream_subscribe notification. -// Filter is one of nil, {"type": "slice", "start", "stop"}, -// or {"type": "partition", "n", "i"}. +// StreamSubscribeParams for stream_subscribe notification. `Stride` +// (when present) restricts which sequence positions are delivered: the +// positions `start, start+step, start+2·step, …` up to (but not +// including) `stop`. Any chain of slice/partition/stride calls on the +// consumer side composes into a single stride before the wire. type StreamSubscribeParams struct { ExecutionID string `json:"execution_id"` // consumer SubscriptionID int `json:"subscription_id"` ProducerExecutionID string `json:"producer_execution_id"` Index int `json:"index"` FromSequence int `json:"from_sequence"` - Filter map[string]any `json:"filter,omitempty"` + Stride map[string]any `json:"stride,omitempty"` } // StreamUnsubscribeParams for stream_unsubscribe notification. diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index 9b2ced58..a0fc98d1 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -60,9 +60,11 @@ type ExecutionHandler interface { // StreamClose closes a stream. Error is nil for a clean close, or a (type, message, traceback) // triple when the producer's generator raised. StreamClose(ctx context.Context, executionID string, index int, err *adapter.StreamCloseError) error - // StreamSubscribe opens a consumer subscription to a stream owned by another execution. - // Filter is nil or a {"type": "slice", ...}/{"type": "partition", ...} map. - StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, filter map[string]any) error + // StreamSubscribe opens a consumer subscription to a stream owned + // by another execution. `stride` is an optional + // {"start", "stop", "step"} map restricting which positions are + // delivered; nil means no filtering. + StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, stride map[string]any) error // StreamUnsubscribe drops a consumer subscription. StreamUnsubscribe(ctx context.Context, executionID string, subscriptionID int) error } @@ -530,7 +532,7 @@ func (p *Pool) handleStreamSubscribe(ctx context.Context, executionID string, pa req.ProducerExecutionID, req.Index, req.FromSequence, - req.Filter, + req.Stride, ); err != nil { logger.Error("failed to subscribe to stream", "error", err) } diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index aa8036d8..f4e4fbd7 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -1258,13 +1258,13 @@ func (w *Worker) StreamClose(ctx context.Context, executionID string, index int, return conn.Notify("stream_close", executionID, index, errTuple) } -func (w *Worker) StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, filter map[string]any) error { +func (w *Worker) StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, stride map[string]any) error { conn, err := w.requireConn() if err != nil { return err } - // Params: [subscription_id, consumer_execution_id, producer_execution_id, index, from_sequence, filter] - return conn.Notify("stream_subscribe", subscriptionID, executionID, producerExecutionID, index, fromSequence, filter) + // Params: [subscription_id, consumer_execution_id, producer_execution_id, index, from_sequence, stride] + return conn.Notify("stream_subscribe", subscriptionID, executionID, producerExecutionID, index, fromSequence, stride) } func (w *Worker) StreamUnsubscribe(ctx context.Context, executionID string, subscriptionID int) error { diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index ee1f3a5b..a9fae0f4 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -341,7 +341,7 @@ defmodule Coflux.Handlers.Worker do producer_execution_id, index, from_sequence, - filter + stride ] = message["params"] if is_recognised_execution?(consumer_execution_id, state) do @@ -353,7 +353,7 @@ defmodule Coflux.Handlers.Worker do producer_execution_id, index, from_sequence, - filter + stride ) do :ok -> {[], state} diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index bc6fa451..59feb53d 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -2126,7 +2126,7 @@ defmodule Coflux.Orchestration.Server do def handle_call( {:subscribe_stream, session_external_id, subscription_id, consumer_execution_external_id, - producer_execution_external_id, index, from_sequence, filter}, + producer_execution_external_id, index, from_sequence, stride}, _from, state ) do @@ -2153,7 +2153,7 @@ defmodule Coflux.Orchestration.Server do producer_execution_id: producer_execution_id, index: index, cursor: from_sequence, - filter: filter + stride: stride } state = @@ -7867,32 +7867,27 @@ defmodule Coflux.Orchestration.Server do defp ok_or({:ok, val}, _reason), do: {:ok, val} defp ok_or(:error, reason), do: {:error, reason} - # Does a `sequence` pass a subscription's filter? - defp filter_matches?(nil, _sequence), do: true + # Does a `sequence` pass a subscription's stride? A stride is + # ``%{"start" => int, "stop" => int | nil, "step" => int}`` — the + # client composes any chain of slice/partition/stride calls into one + # before sending. `nil` means the trivial identity stride (everything). + defp stride_matches?(nil, _sequence), do: true - defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => nil}, sequence), - do: sequence >= s - - defp filter_matches?(%{"type" => "slice", "start" => s, "stop" => e}, sequence), - do: sequence >= s and sequence < e - - defp filter_matches?(%{"type" => "partition", "n" => n, "i" => i}, sequence), - do: rem(sequence, n) == i - - defp filter_matches?(%{"type" => "chain", "filters" => fs}, sequence), - do: Enum.all?(fs, &filter_matches?(&1, sequence)) + defp stride_matches?(%{"start" => start, "stop" => stop, "step" => step}, sequence) do + sequence >= start and + (stop == nil or sequence < stop) and + rem(sequence - start, step) == 0 + end - defp filter_matches?(_filter, _sequence), do: true + # Is `cursor` past the stride's stop? Lets us close the subscription + # early once nothing more can match (i.e. the upper bound is finite + # and we've reached it). + defp stride_exhausted?(nil, _cursor), do: false - # Is `sequence` past the end of the filter's effective range? - # Lets us close streams early once a slice's stop is reached. - defp filter_exhausted?(%{"type" => "slice", "stop" => stop}, cursor) when is_integer(stop), + defp stride_exhausted?(%{"stop" => stop}, cursor) when is_integer(stop), do: cursor >= stop - defp filter_exhausted?(%{"type" => "chain", "filters" => fs}, cursor), - do: Enum.any?(fs, &filter_exhausted?(&1, cursor)) - - defp filter_exhausted?(_filter, _cursor), do: false + defp stride_exhausted?(_stride, _cursor), do: false # Per-fetch page size when draining backlog for a newly subscribed # consumer. Keeps any single DB read bounded and lets us push each page @@ -7945,8 +7940,10 @@ defmodule Coflux.Orchestration.Server do filtered = items - |> Enum.filter(fn {sequence, _value, _at} -> filter_matches?(sub.filter, sequence) end) - |> Enum.take_while(fn {sequence, _, _} -> not filter_exhausted?(sub.filter, sequence) end) + |> Enum.filter(fn {sequence, _value, _at} -> stride_matches?(sub.stride, sequence) end) + |> Enum.take_while(fn {sequence, _, _} -> + not stride_exhausted?(sub.stride, sequence) + end) # Advance cursor past the page even if no items matched this filter — # otherwise we'd re-fetch the same sequences forever. @@ -7980,14 +7977,14 @@ defmodule Coflux.Orchestration.Server do ) state = - if filter_exhausted?(sub.filter, advance_to) do - # If the filter is now exhausted (slice's stop reached), close the - # subscription synchronously — matches push_stream_item's - # behaviour. Without this, a consumer that subscribed after - # appends with a bounded filter would wait forever for a close - # that never comes. Filter-exhaustion is a "complete" outcome - # from the consumer's perspective: they've received everything - # that was addressed to them. + if stride_exhausted?(sub.stride, advance_to) do + # If the stride has reached its stop, close the subscription + # synchronously — matches push_stream_item's behaviour. Without + # this, a consumer that subscribed after appends with a bounded + # stride would wait forever for a close that never comes. + # Stride-exhaustion is a "complete" outcome from the consumer's + # perspective: they've received everything that was addressed + # to them. state |> send_to_consumer( sub, @@ -8029,7 +8026,7 @@ defmodule Coflux.Orchestration.Server do # Consumer already has this sequence via backlog; skip. state - not filter_matches?(sub.filter, sequence) -> + not stride_matches?(sub.stride, sequence) -> state true -> @@ -8052,11 +8049,11 @@ defmodule Coflux.Orchestration.Server do &Map.put(&1, :cursor, sequence + 1) ) - # If the filter is exhausted (e.g. slice reached its stop), close - # the subscription early — no more items will match. Treated - # as a "complete" close for the consumer (they got everything - # that was addressed to them). - if filter_exhausted?(sub.filter, sequence + 1) do + # If the stride has reached its stop, close the subscription + # early — no more items will match. Treated as a "complete" + # close for the consumer (they got everything that was + # addressed to them). + if stride_exhausted?(sub.stride, sequence + 1) do state |> send_to_consumer( sub, diff --git a/tests/support/executor.py b/tests/support/executor.py index 07a1cd81..86515822 100644 --- a/tests/support/executor.py +++ b/tests/support/executor.py @@ -279,10 +279,13 @@ def stream_subscribe( producer_execution_id, index, from_sequence=0, - filter=None, + stride=None, ): - """Subscribe to a stream. ``filter`` is an optional dict built via - protocol.slice_filter / partition_filter / chain_filter.""" + """Subscribe to a stream. ``stride`` is an optional + ``{"start", "stop", "step"}`` dict restricting which positions + are delivered — built via ``protocol.stride`` / + ``slice_stride`` / ``partition_stride``. ``None`` means no + filtering (identity stride).""" self.send( protocol.stream_subscribe( execution_id, @@ -290,7 +293,7 @@ def stream_subscribe( producer_execution_id, index, from_sequence=from_sequence, - filter=filter, + stride=stride, ) ) diff --git a/tests/support/protocol.py b/tests/support/protocol.py index cf7166a9..25bfa37b 100644 --- a/tests/support/protocol.py +++ b/tests/support/protocol.py @@ -282,7 +282,7 @@ def stream_subscribe( producer_execution_id, index, from_sequence=0, - filter=None, + stride=None, ): params = { "execution_id": execution_id, @@ -291,8 +291,8 @@ def stream_subscribe( "index": index, "from_sequence": from_sequence, } - if filter is not None: - params["filter"] = filter + if stride is not None: + params["stride"] = stride return {"method": "stream_subscribe", "params": params} @@ -309,16 +309,18 @@ def stream_unsubscribe(execution_id, subscription_id): # --- Filter builders --- -def slice_filter(start, stop=None): - return {"type": "slice", "start": start, "stop": stop} +def stride(start=0, stop=None, step=1): + return {"start": start, "stop": stop, "step": step} -def partition_filter(n, i): - return {"type": "partition", "n": n, "i": i} +def slice_stride(start, stop=None): + """Stride equivalent to the old ``slice(start, stop)`` filter.""" + return stride(start=start, stop=stop, step=1) -def chain_filter(*filters): - return {"type": "chain", "filters": list(filters)} +def partition_stride(n, i): + """Stride equivalent to the old ``partition(n, i)`` filter.""" + return stride(start=i, stop=None, step=n) def submit_input_request( diff --git a/tests/test_streams.py b/tests/test_streams.py index d0571289..11e0f06b 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -20,9 +20,9 @@ from support.protocol import ( execution_result, json_args, - partition_filter, - slice_filter, - chain_filter, + partition_stride, + slice_stride, + stride, ) @@ -138,7 +138,7 @@ def test_slice_filter_restricts_items(worker): subscription_id=1, producer_execution_id=prod_ex.execution_id, index=0, - filter=slice_filter(1, 3), + stride=slice_stride(1, 3), ) items, _ = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) @@ -168,7 +168,7 @@ def test_partition_filter_round_robin(worker): subscription_id=1, producer_execution_id=prod_ex.execution_id, index=0, - filter=partition_filter(n=3, i=1), + stride=partition_stride(n=3, i=1), ) items, _ = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) @@ -456,7 +456,7 @@ def test_slice_with_stop_closes_early(worker): subscription_id=1, producer_execution_id=prod_ex.execution_id, index=0, - filter=slice_filter(0, 2), + stride=slice_stride(0, 2), ) prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "a") @@ -596,8 +596,9 @@ def test_lifecycle_close_on_completion_delivers_to_subscriber(worker): assert closed.get("error") is None # clean close — execution completed normally -def test_filter_chain_combines_slice_and_partition(worker): - """``chain(slice(0, 6), partition(2, 0))`` → positions 0, 2, 4.""" +def test_stride_combines_slice_and_partition(worker): + """The client composes ``slice(0, 6)`` then ``partition(2, 0)`` into + a single stride ``[0:6:2]``, which selects positions 0, 2, 4.""" targets = [workflow("test", "producer"), workflow("test", "consumer")] with worker(targets) as ctx: @@ -617,7 +618,7 @@ def test_filter_chain_combines_slice_and_partition(worker): subscription_id=1, producer_execution_id=prod_ex.execution_id, index=0, - filter=chain_filter(slice_filter(0, 6), partition_filter(n=2, i=0)), + stride=stride(start=0, stop=6, step=2), ) items, _ = cons_ex.conn.drain_stream(subscription_id=1) cons_ex.conn.complete(cons_ex.execution_id) From ab62a2f6eb944602cea00c946ffa77250b1e7035 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Mon, 20 Apr 2026 13:31:17 +0100 Subject: [PATCH 17/25] Support configuring timeouts on streams --- adapters/python/coflux/__init__.py | 3 +- adapters/python/coflux/context.py | 34 +++- adapters/python/coflux/decorators.py | 25 +-- adapters/python/coflux/discovery.py | 14 +- adapters/python/coflux/executor.py | 57 ++++++- adapters/python/coflux/protocol.py | 11 ++ adapters/python/coflux/streams.py | 157 +++++++++++++++--- adapters/python/coflux/target.py | 158 ++++++++++++++++--- adapters/python/uv.lock | 2 +- cli/cmd/coflux/submit.go | 3 + cli/internal/adapter/adapter.go | 3 +- cli/internal/adapter/protocol.go | 39 ++++- cli/internal/pool/pool.go | 79 ++++++++-- cli/internal/pool/stream_timers.go | 136 ++++++++++++++++ cli/internal/worker/worker.go | 89 +++++++++-- server/lib/coflux/handlers/api.ex | 53 +++++++ server/lib/coflux/handlers/worker.ex | 51 ++++-- server/lib/coflux/orchestration.ex | 8 +- server/lib/coflux/orchestration/manifests.ex | 52 ++++-- server/lib/coflux/orchestration/models.ex | 2 + server/lib/coflux/orchestration/results.ex | 3 +- server/lib/coflux/orchestration/runs.ex | 12 ++ server/lib/coflux/orchestration/server.ex | 70 ++++---- server/lib/coflux/orchestration/streams.ex | 48 ++++-- server/lib/coflux/topics/run.ex | 20 ++- server/lib/coflux/topics/stream.ex | 1 + server/lib/coflux/topics/workflow.ex | 12 +- server/priv/migrations/orchestration/4.sql | 20 +++ tests/support/executor.py | 32 +++- tests/support/manifest.py | 3 + tests/support/protocol.py | 4 +- tests/test_streams.py | 126 +++++++++++++++ 32 files changed, 1146 insertions(+), 181 deletions(-) create mode 100644 cli/internal/pool/stream_timers.go diff --git a/adapters/python/coflux/__init__.py b/adapters/python/coflux/__init__.py index f181b42e..14214c9f 100644 --- a/adapters/python/coflux/__init__.py +++ b/adapters/python/coflux/__init__.py @@ -35,7 +35,7 @@ from .prompt import Prompt from .state import get_context from .streams import stream -from .target import Cache, Defer, Retries +from .target import Cache, Defer, Retries, Streams __all__ = [ # Version @@ -62,6 +62,7 @@ "Cache", "Defer", "Retries", + "Streams", "Asset", "AssetEntry", "AssetMetadata", diff --git a/adapters/python/coflux/context.py b/adapters/python/coflux/context.py index 141d9c79..c8793c42 100644 --- a/adapters/python/coflux/context.py +++ b/adapters/python/coflux/context.py @@ -25,6 +25,7 @@ from .models import Asset, AssetEntry, AssetMetadata, Execution, Input from .serialization import deserialize_value, serialize_value from .streams import StreamDriver +from .target import Streams def _handle_key(handle: Any) -> tuple[str, str]: @@ -69,6 +70,14 @@ def _unwrap_response( raise RuntimeError(f"Unexpected select status: {status}") +def _timeout_to_ms(timeout: float | dt.timedelta | None) -> int | None: + if timeout is None: + return None + if isinstance(timeout, dt.timedelta): + return int(timeout.total_seconds() * 1000) + return int(timeout * 1000) + + # Context variable for group tracking _group_id: contextvars.ContextVar[int | None] = contextvars.ContextVar( "_group_id", default=None @@ -104,15 +113,34 @@ def __init__(self, execution_id: str, working_dir: Path | None = None): # during serialization (of the return value OR of submit arguments) # are registered here and driven in background threads. self._stream_driver = StreamDriver(execution_id) + # Default stream config for this execution, populated by the + # executor from the target's ``@cf.task(streams=...)`` setting. + # Used by ``cf.stream(...)`` to fill in unspecified options. + self._default_streams: Streams | None = None + + def set_default_streams(self, streams: Streams | None) -> None: + """Record the decorator's stream config so ``cf.stream(...)`` can + inherit from it. Called once by the executor before running the + target function.""" + self._default_streams = streams - def register_stream(self, generator: Any, buffer: int | None) -> str: + def get_default_streams(self) -> Streams | None: + return self._default_streams + + def register_stream( + self, + generator: Any, + buffer: int | None, + timeout: float | dt.timedelta | None = None, + ) -> str: """Register a generator with this execution's stream driver and return the resulting opaque stream id. Called from ``cf.stream(...)``; also from the executor when the task body itself is a generator. """ - return self._stream_driver.register(generator, buffer) + timeout_ms = _timeout_to_ms(timeout) + return self._stream_driver.register(generator, buffer, timeout_ms) def wait_streams(self) -> None: """Block until every stream produced by this execution has drained.""" @@ -141,6 +169,7 @@ def submit_execution( recurrent: bool = False, requires: dict[str, list[str]] | None = None, timeout: int = 0, + streams: dict[str, Any] | None = None, ) -> dict[str, Any]: """Submit a child execution and return its details. @@ -165,6 +194,7 @@ def submit_execution( recurrent=recurrent, requires=requires, timeout=timeout, + streams=streams, ) return self._wait_response(request_id) diff --git a/adapters/python/coflux/decorators.py b/adapters/python/coflux/decorators.py index a50aab72..27df32ad 100644 --- a/adapters/python/coflux/decorators.py +++ b/adapters/python/coflux/decorators.py @@ -5,7 +5,7 @@ import datetime as dt import typing as t -from .target import _BUFFER_UNSET, Cache, Defer, Retries, Target +from .target import _STREAMS_UNSET, Cache, Defer, Retries, Streams, Target if t.TYPE_CHECKING: from .models import Stream @@ -52,7 +52,7 @@ def task( memo: bool | t.Iterable[str] = False, requires: dict[str, str | bool | list[str]] | None = None, timeout: float | dt.timedelta = 0, - buffer: int | None = _BUFFER_UNSET, # type: ignore[assignment] + streams: Streams | None = _STREAMS_UNSET, # type: ignore[assignment] ) -> _TargetDecorator: """Decorator for defining a task. @@ -60,12 +60,13 @@ def task( the executor; the task's return type is the coroutine's resolved value (not the coroutine itself). - ``buffer`` only applies to generator-bodied tasks. ``0`` (default) - gives strict lockstep: the producer emits an item, waits for a - consumer to ack, then emits the next. ``N`` lets the producer stay - up to N items ahead of the fastest consumer. ``None`` disables - backpressure entirely. Passing ``buffer`` on a non-generator task - raises ``TypeError`` at decoration time. + ``streams`` only applies to tasks that produce streams — either + generator-bodied tasks (``def`` + ``yield`` / ``async def`` + + ``yield``) or tasks that call ``cf.stream(...)`` internally. It + configures the default ``buffer`` and ``timeout`` for those + streams; per-call overrides on ``cf.stream(...)`` win. Passing + ``streams=`` on a non-generator task raises ``TypeError`` at + decoration time. """ def decorator(fn): @@ -82,7 +83,7 @@ def decorator(fn): memo=memo, requires=requires, timeout=timeout, - buffer=buffer, + streams=streams, ) return decorator # type: ignore[return-value] @@ -100,7 +101,7 @@ def workflow( memo: bool = False, requires: dict[str, str | bool | list[str]] | None = None, timeout: float | dt.timedelta = 0, - buffer: int | None = _BUFFER_UNSET, # type: ignore[assignment] + streams: Streams | None = _STREAMS_UNSET, # type: ignore[assignment] ) -> _TargetDecorator: """Decorator for defining a workflow. @@ -108,7 +109,7 @@ def workflow( the executor; the workflow's return type is the coroutine's resolved value (not the coroutine itself). - See ``@cf.task`` for ``buffer=`` semantics. + See ``@cf.task`` for ``streams=`` semantics. """ def decorator(fn): @@ -125,7 +126,7 @@ def decorator(fn): memo=memo, requires=requires, timeout=timeout, - buffer=buffer, + streams=streams, ) return decorator # type: ignore[return-value] diff --git a/adapters/python/coflux/discovery.py b/adapters/python/coflux/discovery.py index 21ce4c00..894be9c5 100644 --- a/adapters/python/coflux/discovery.py +++ b/adapters/python/coflux/discovery.py @@ -8,7 +8,14 @@ import sys from typing import Any -from .target import Target, _to_ms, serialize_cache, serialize_defer, serialize_retries +from .target import ( + Target, + _to_ms, + serialize_cache, + serialize_defer, + serialize_retries, + serialize_streams, +) def _expand_modules(module_names: list[str]) -> list[str]: @@ -139,6 +146,11 @@ def _build_target_definition(target: Any, module_name: str) -> dict[str, Any]: if definition.instruction: result["instruction"] = definition.instruction + if definition.streams is not None: + streams_dict = serialize_streams(definition.streams) + if streams_dict is not None: + result["streams"] = streams_dict + return result diff --git a/adapters/python/coflux/executor.py b/adapters/python/coflux/executor.py index 01ee4ad4..12fc0195 100644 --- a/adapters/python/coflux/executor.py +++ b/adapters/python/coflux/executor.py @@ -36,6 +36,31 @@ def _format_filtered_traceback(exc: Exception) -> str: return "".join(lines) +def _resolve_execute_streams(target_obj: Any, streams_from_wire: dict[str, Any] | None): + """Decide the effective stream config for this execution. + + Precedence: + 1. ``streams_from_wire`` — the execute message's ``streams`` param. + Carries either the caller's ``with_streams(...)`` override or + the manifest default copied at submit time. + 2. ``target_obj.definition.streams`` — the decorator-level default. + + Returns a ``Streams`` instance or ``None`` (no stream config). + """ + from .target import Streams # local import to avoid cycles + + if streams_from_wire is not None: + buffer = streams_from_wire.get("buffer", 0) + timeout_ms = streams_from_wire.get("timeout_ms") + timeout = timeout_ms / 1000 if timeout_ms is not None else None + return Streams(buffer=buffer, timeout=timeout) + + if hasattr(target_obj, "definition"): + return target_obj.definition.streams + + return None + + def _apply_type_hints(fn: Any, args: list[Any]) -> list[Any]: """Upgrade deserialized args using the function's type hints. @@ -69,8 +94,17 @@ def execute_target( target_name: str, arguments: list[dict[str, Any]], working_dir: str | None = None, + streams: dict[str, Any] | None = None, ) -> None: - """Execute a target with the given arguments.""" + """Execute a target with the given arguments. + + ``streams`` is the streams config passed in the execute message — + either the call-site override from ``with_streams(...)`` or the + workflow manifest default. When present it wins over the + decorator's static config; it's applied both to the auto-registered + stream for generator-bodied tasks and to ``cf.stream(...)`` calls + inside the body. + """ original_dir = os.getcwd() # Start the stdin dispatcher. From here on, all incoming messages flow # through it — individual threads block on the dispatcher rather than @@ -101,6 +135,13 @@ def execute_target( ctx = ExecutorContext( execution_id, working_dir=Path(working_dir) if working_dir else None ) + # Resolve the effective stream config. The execute message's + # ``streams`` carries either the caller's ``with_streams(...)`` + # override or the workflow manifest default; when absent we + # fall back to the decorator's static config. + effective_streams = _resolve_execute_streams(target_obj, streams) + if effective_streams is not None or hasattr(target_obj, "definition"): + ctx.set_default_streams(effective_streams) set_context(ctx) with capture_output(execution_id): @@ -114,8 +155,8 @@ def execute_target( # If the task body was itself a generator (``def`` + ``yield`` # or ``async def`` + ``yield``), the call above returned an - # unstarted generator object. Register it with the task's - # configured buffer so callers don't have to wrap explicitly. + # unstarted generator object. Register it with the effective + # stream defaults so callers don't have to wrap explicitly. # Streams created via cf.stream(...) are already registered; # they appear here as Stream handles, not generators. if (inspect.isgenerator(result) or inspect.isasyncgen(result)) and hasattr( @@ -123,9 +164,12 @@ def execute_target( ): from .streams import stream as _register_stream - result = _register_stream( - result, buffer=target_obj.definition.buffer - ) + kwargs: dict[str, Any] = {} + if effective_streams is not None: + kwargs["buffer"] = effective_streams.buffer + if effective_streams.timeout is not None: + kwargs["timeout"] = effective_streams.timeout + result = _register_stream(result, **kwargs) # Serialize result. Any streams returned (directly, or embedded # in the result structure) were already registered via @@ -201,6 +245,7 @@ def run_executor() -> int: target_name=params["target"], arguments=params.get("arguments", []), working_dir=params.get("working_dir"), + streams=params.get("streams"), ) return 0 diff --git a/adapters/python/coflux/protocol.py b/adapters/python/coflux/protocol.py index d326aded..37e5ed99 100644 --- a/adapters/python/coflux/protocol.py +++ b/adapters/python/coflux/protocol.py @@ -198,6 +198,7 @@ def request_submit_execution( recurrent: bool = False, requires: dict[str, list[str]] | None = None, timeout: int = 0, + streams: dict[str, Any] | None = None, ) -> int: """Request to submit a child execution.""" params: dict[str, Any] = { @@ -228,6 +229,8 @@ def request_submit_execution( params["requires"] = requires if timeout: params["timeout"] = timeout + if streams is not None: + params["streams"] = streams return get_protocol().send_request("submit_execution", params) @@ -472,6 +475,7 @@ def send_stream_register( execution_id: str, index: int, buffer: int | None = None, + timeout_ms: int | None = None, ) -> None: """Register a stream owned by this execution. @@ -483,10 +487,17 @@ def send_stream_register( the producer emits freely. Any integer value tells the server to pace the producer — it'll send ``stream_demand`` notifications as credits become available. + + ``timeout_ms`` is the idle-timeout budget (milliseconds). If set, the + worker (CLI) force-closes the stream with reason "timeout" when no + item has been appended for that long. Purely informational for the + server; enforcement happens in the worker. """ params: dict[str, Any] = {"execution_id": execution_id, "index": index} if buffer is not None: params["buffer"] = buffer + if timeout_ms is not None: + params["timeout_ms"] = timeout_ms get_protocol().send_message("stream_register", params) diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index a54a9c37..fefc067c 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -38,14 +38,22 @@ # --- Producer side --- -def stream(generator: Any, *, buffer: int | None = 0) -> Any: +_STREAM_OPT_UNSET: Any = object() + + +def stream( + generator: Any, + *, + buffer: Any = _STREAM_OPT_UNSET, + timeout: Any = _STREAM_OPT_UNSET, +) -> Any: """Register a generator as a Coflux stream and return a handle. - Use this when a task returns multiple streams or needs a buffer size - different from the task default. For the common case where a task - body is itself a generator, ``@cf.task(buffer=N)`` handles the - registration automatically — you don't need to call ``cf.stream`` - explicitly. + Use this when a task returns multiple streams or needs to override + the task-level stream configuration. For the common case where a + task body is itself a generator, ``@cf.task(streams=cf.Streams(...))`` + handles the registration automatically — you don't need to call + ``cf.stream`` explicitly. Registration happens at call time: the driver thread starts, the server is told about the stream, and any later serialisation sees a @@ -53,15 +61,25 @@ def stream(generator: Any, *, buffer: int | None = 0) -> Any: inside a task or workflow body (where an execution context is active); calling it from module scope or outside a task raises. + Unspecified options inherit from the enclosing task's + ``streams=cf.Streams(...)``. Explicit options override per-call. + Args: generator: A sync or async generator. Other iterables aren't accepted — wrapping a list in ``cf.stream`` doesn't make sense; pass it as a value directly. - buffer: Backpressure budget. ``0`` (the default) means strict - lockstep — the producer emits an item, waits for a consumer - to acknowledge it, then emits the next. ``N`` allows the - producer to stay up to ``N`` items ahead of the fastest - consumer. ``None`` disables backpressure entirely. + buffer: Backpressure budget. ``0`` (the default if neither + ``cf.stream(buffer=...)`` nor the task-level default sets + it) means strict lockstep — the producer emits an item, + waits for a consumer to acknowledge it, then emits the + next. ``N`` allows the producer to stay up to ``N`` items + ahead of the fastest consumer. ``None`` disables + backpressure entirely. + timeout: Idle-timeout budget. If the producer doesn't append a + new item within this window (including when blocked on + consumer demand), the stream is force-closed with reason + ``"timeout"``. Accepts a positive number of seconds, a + ``timedelta``, or ``None`` to disable. Returns: A ``Stream`` handle referencing the newly registered stream. @@ -72,10 +90,20 @@ def stream(generator: Any, *, buffer: int | None = 0) -> Any: raise TypeError( f"cf.stream expects a generator, got {type(generator).__name__}" ) - if buffer is not None and buffer < 0: - raise ValueError(f"buffer must be non-negative or None, got {buffer}") + + from .target import Streams, _validate_buffer, _validate_timeout + ctx = get_context() - stream_id = ctx.register_stream(generator, buffer) + default = ctx.get_default_streams() or Streams() + resolved_buffer = ( + _validate_buffer(buffer) if buffer is not _STREAM_OPT_UNSET else default.buffer + ) + resolved_timeout = ( + _validate_timeout(timeout) + if timeout is not _STREAM_OPT_UNSET + else default.timeout + ) + stream_id = ctx.register_stream(generator, resolved_buffer, resolved_timeout) # Local import to avoid a top-level cycle — models imports nothing # from streams but streams already imports from models at top. from .models import Stream as StreamHandle @@ -101,8 +129,22 @@ def __init__(self, execution_id: str) -> None: self._demand: dict[int, int | None] = {} self._closing = False self._demand_handler_registered = False - - def register(self, generator: Any, buffer: int | None) -> str: + self._force_close_handler_registered = False + # Indexes of streams the worker (CLI) has force-closed — typically + # because their idle timeout elapsed. Read by ``_acquire_demand`` + # and by the producer loop so the driver thread exits promptly + # and skips sending its own stream_close (the server already + # recorded the closure). + self._force_closed: dict[int, str] = {} + # Per-index generator entry, for clean close on force-close. + self._by_index: dict[int, dict[str, Any]] = {} + + def register( + self, + generator: Any, + buffer: int | None, + timeout_ms: int | None = None, + ) -> str: """Register a generator and start running it in a worker thread. Accepts both sync generators (``def`` + ``yield``) and async @@ -116,10 +158,15 @@ def register(self, generator: Any, buffer: int | None) -> str: the next); ``N>0`` allows the producer to stay up to N items ahead of the fastest consumer. + ``timeout_ms`` is the idle-timeout budget (milliseconds). The + worker (CLI) closes the stream with reason "timeout" if no item + is appended within that window. ``None`` disables the timeout. + Returns the stream's opaque ``id`` (``_``) for embedding in the serialized value as a stream reference. """ self._ensure_demand_handler_registered() + self._ensure_force_close_handler_registered() with self._lock: index = self._next_index @@ -131,7 +178,9 @@ def register(self, generator: Any, buffer: int | None) -> str: # it (or on first consumer subscribing). self._demand[index] = None if buffer is None else 0 - protocol.send_stream_register(self._execution_id, index, buffer=buffer) + protocol.send_stream_register( + self._execution_id, index, buffer=buffer, timeout_ms=timeout_ms + ) is_async = inspect.isasyncgen(generator) target = self._run_async if is_async else self._run @@ -151,6 +200,7 @@ def register(self, generator: Any, buffer: int | None) -> str: with self._lock: self._generators.append(entry) self._threads.append(thread) + self._by_index[index] = entry thread.start() return compose_stream_id(self._execution_id, index) @@ -161,6 +211,14 @@ def _ensure_demand_handler_registered(self) -> None: get_dispatcher().register_notification("stream_demand", self._on_stream_demand) self._demand_handler_registered = True + def _ensure_force_close_handler_registered(self) -> None: + if self._force_close_handler_registered: + return + get_dispatcher().register_notification( + "stream_force_close", self._on_stream_force_close + ) + self._force_close_handler_registered = True + def _on_stream_demand(self, params: dict[str, Any]) -> None: """Server granted additional demand for one of our streams. @@ -179,11 +237,54 @@ def _on_stream_demand(self, params: dict[str, Any]) -> None: self._demand[index] = current + n self._demand_cv.notify_all() + def _on_stream_force_close(self, params: dict[str, Any]) -> None: + """CLI is telling us to stop producing for a specific stream. + + Fires when the worker's stream-timer has elapsed and it has + already informed the server. We mark the stream force-closed so + ``_acquire_demand`` returns False and the producer thread exits + without sending its own ``stream_close`` (that would race the + closure the server already recorded). + + Also closes the generator so any work it's doing (e.g., a long + ``next()``) is interrupted at the next yield point. + """ + index = params.get("index") + reason = params.get("reason") or "timeout" + if index is None: + return + with self._demand_cv: + self._force_closed[index] = reason + self._demand_cv.notify_all() + # Close the generator off the dispatcher thread to avoid blocking + # on a long-running next() call there. + with self._lock: + entry = self._by_index.get(index) + if entry is None: + return + try: + if entry["is_async"]: + loop = entry["loop"] + if loop is not None and not loop.is_closed(): + gen = entry["generator"] + + async def _close(g=gen) -> None: + try: + await g.aclose() + except Exception: + pass + + asyncio.run_coroutine_threadsafe(_close(), loop) + else: + entry["generator"].close() + except Exception: + pass + def _acquire_demand(self, index: int) -> bool: """Wait for a credit and consume it. Returns False if closed mid-wait.""" with self._demand_cv: while True: - if self._closing: + if self._closing or index in self._force_closed: return False current = self._demand.get(index) if current is None: @@ -194,6 +295,10 @@ def _acquire_demand(self, index: int) -> bool: return True self._demand_cv.wait() + def _is_force_closed(self, index: int) -> bool: + with self._demand_cv: + return index in self._force_closed + def _run(self, index: int, generator: Any) -> None: """Run one sync generator to exhaustion (or error).""" sequence = 0 @@ -219,11 +324,15 @@ def _run(self, index: int, generator: Any) -> None: sequence += 1 except GeneratorExit: # Generator explicitly closed (via close_all on error path, or - # server-initiated cancel). Skip send_stream_close — the server - # records a lifecycle closure when the execution terminates and - # derives the error from the execution's outcome. + # by the force-close handler for a worker-initiated timeout). + # Skip send_stream_close — the server either records a + # lifecycle closure on execution-end, or has already recorded + # the force-close reason (e.g. "timeout"). return except BaseException as e: # noqa: BLE001 - we propagate all + if self._is_force_closed(index): + # Worker already recorded the close; don't overwrite. + return error_type = f"{type(e).__module__}.{type(e).__qualname__}" tb = traceback.format_exc() protocol.send_stream_close( @@ -234,6 +343,8 @@ def _run(self, index: int, generator: Any) -> None: traceback=tb, ) else: + if self._is_force_closed(index): + return protocol.send_stream_close(self._execution_id, index) def _run_async(self, index: int, generator: Any) -> None: @@ -275,6 +386,8 @@ async def iterate() -> None: except (GeneratorExit, asyncio.CancelledError): return except BaseException as e: # noqa: BLE001 - we propagate all + if self._is_force_closed(index): + return error_type = f"{type(e).__module__}.{type(e).__qualname__}" tb = traceback.format_exc() protocol.send_stream_close( @@ -285,6 +398,8 @@ async def iterate() -> None: traceback=tb, ) else: + if self._is_force_closed(index): + return protocol.send_stream_close(self._execution_id, index) finally: try: diff --git a/adapters/python/coflux/target.py b/adapters/python/coflux/target.py index 798d213a..dc28b987 100644 --- a/adapters/python/coflux/target.py +++ b/adapters/python/coflux/target.py @@ -50,6 +50,33 @@ class Retries: ) = None +@dataclasses.dataclass(frozen=True) +class Streams: + """Default stream configuration for a task or workflow. + + Applies to: + * Streams created explicitly with ``cf.stream(...)`` — each option + (``buffer``, ``timeout``) can be overridden per-call. + * Generator-bodied tasks, where the task itself produces the stream. + + ``buffer`` is the producer-side backpressure budget. ``0`` (the + default) means strict lockstep: the producer emits an item, waits + for a consumer to ack, then emits the next. ``N`` allows the + producer to run up to ``N`` items ahead of the fastest consumer. + ``None`` disables backpressure entirely. + + ``timeout`` is the idle-timeout budget — if the producer hasn't + appended a new item within this window (including when blocked + waiting for consumer demand), the stream is force-closed with + reason ``"timeout"``. Enforced at the worker level. ``None`` + disables the timeout. + """ + + _: dataclasses.KW_ONLY + buffer: int | None = 0 + timeout: float | dt.timedelta | None = None + + class Parameter(t.NamedTuple): name: str annotation: str | None @@ -70,10 +97,12 @@ class TargetDefinition(t.NamedTuple): timeout: float | dt.timedelta instruction: str | None is_stub: bool - # Backpressure for generator-bodied tasks. 0 = strict lockstep (default), - # N = up to N items ahead of the fastest consumer, None = unbounded. - # Only meaningful when ``fn`` is a generator function. - buffer: int | None + # Default stream configuration. Used for generator-bodied tasks + # (where the task itself produces a stream) and as the default for + # ``cf.stream(...)`` calls within the task body. Individual + # ``cf.stream`` kwargs override these per-call. ``None`` means the + # task never deals with streams — validated at decoration time. + streams: Streams | None def _json_dumps(obj: t.Any) -> str: @@ -199,34 +228,67 @@ def _parse_requires( return {k: _parse_require(v) for k, v in requires.items()} if requires else None -_BUFFER_UNSET = object() +_STREAMS_UNSET = object() + +def _validate_buffer(buffer: t.Any) -> int | None: + if buffer is None: + return None + if not isinstance(buffer, int) or isinstance(buffer, bool) or buffer < 0: + raise ValueError( + f"buffer must be a non-negative integer or None, got {buffer!r}" + ) + return buffer + + +def _validate_timeout( + timeout: t.Any, +) -> float | dt.timedelta | None: + if timeout is None: + return None + if isinstance(timeout, dt.timedelta): + if timeout.total_seconds() <= 0: + raise ValueError(f"timeout must be positive, got {timeout!r}") + return timeout + if isinstance(timeout, (int, float)) and not isinstance(timeout, bool): + if timeout <= 0: + raise ValueError(f"timeout must be positive, got {timeout!r}") + return timeout + raise TypeError( + f"timeout must be a positive number, timedelta, or None, got {timeout!r}" + ) -def _resolve_buffer( - buffer: t.Any, + +def _resolve_streams( + streams: t.Any, fn: t.Callable, -) -> int | None: - """Validate the decorator's ``buffer=`` and return the resolved value. +) -> Streams | None: + """Validate the decorator's ``streams=`` and return the resolved value. - Default is 0 (strict lockstep) for generator-bodied tasks. ``None`` - disables backpressure. ``buffer`` on a non-generator task is an - error — it wouldn't apply to anything. + A non-generator task gets ``None`` (no stream config makes sense). + A generator-bodied task with no explicit ``streams=`` gets a default + ``Streams()`` (buffer=0 strict lockstep, no timeout). Passing + ``streams=`` on a non-generator task raises. """ is_generator = inspect.isgeneratorfunction(fn) or inspect.isasyncgenfunction(fn) - if buffer is _BUFFER_UNSET: - return 0 if is_generator else None + if streams is _STREAMS_UNSET: + return Streams() if is_generator else None if not is_generator: raise TypeError( - f"@cf.task/@cf.workflow(buffer=...) only applies to generator functions " + f"@cf.task/@cf.workflow(streams=...) only applies to generator functions " f"(def + yield or async def + yield); {fn.__name__} is not." ) - if buffer is None: + if streams is None: return None - if not isinstance(buffer, int) or isinstance(buffer, bool) or buffer < 0: - raise ValueError( - f"buffer must be a non-negative integer or None, got {buffer!r}" + if not isinstance(streams, Streams): + raise TypeError( + f"streams= must be a cf.Streams instance or None, got {type(streams).__name__}" ) - return buffer + # Re-validate the options (defensive — Streams itself is a plain dataclass). + return Streams( + buffer=_validate_buffer(streams.buffer), + timeout=_validate_timeout(streams.timeout), + ) def _build_definition( @@ -242,7 +304,7 @@ def _build_definition( requires: dict[str, str | bool | list[str]] | None, timeout: float | dt.timedelta, is_stub: bool, - buffer: t.Any = _BUFFER_UNSET, + streams: t.Any = _STREAMS_UNSET, ) -> TargetDefinition: parameters = inspect.signature(fn).parameters.values() for p in parameters: @@ -263,7 +325,7 @@ def _build_definition( timeout, inspect.getdoc(fn), is_stub, - _resolve_buffer(buffer, fn), + _resolve_streams(streams, fn), ) @@ -315,6 +377,18 @@ def serialize_retries(retries: Retries) -> dict: return result +def serialize_streams(streams: Streams) -> dict | None: + """Serialise a Streams dataclass to the wire format used in the + manifest and in submit_execution requests. Returns ``None`` if + neither option is set (so the key is omitted from the wire).""" + result: dict[str, t.Any] = {} + if streams.buffer is not None: + result["buffer"] = streams.buffer + if streams.timeout is not None: + result["timeout_ms"] = _to_ms(streams.timeout) + return result if result else None + + class Target(t.Generic[P, T]): """Wrapper for a decorated task or workflow function. @@ -356,7 +430,7 @@ def __init__( requires: dict[str, str | bool | list[str]] | None = None, timeout: float | dt.timedelta = 0, is_stub: bool = False, - buffer: t.Any = _BUFFER_UNSET, + streams: t.Any = _STREAMS_UNSET, ): self._fn = fn self._name = name or fn.__name__ @@ -374,7 +448,7 @@ def __init__( requires, timeout, is_stub, - buffer, + streams, ) functools.update_wrapper(self, fn) @@ -424,6 +498,35 @@ def with_requires( """Return a new Target with routing tags overridden for this call site.""" return self._copy(requires=_parse_requires(requires)) + def with_streams(self, streams: Streams | None) -> Target[P, T]: + """Return a new Target with stream config overridden for this call site. + + Only meaningful for targets that produce streams (generator + functions, or bodies that call ``cf.stream(...)``). The new + config becomes the default for ``cf.stream(...)`` inside the + task; per-call ``cf.stream(buffer=..., timeout=...)`` overrides + still win. + """ + if self._definition.streams is None: + raise TypeError( + f"with_streams is only applicable to stream-producing targets; " + f"{self._name} was declared without a streams config." + ) + if streams is not None and not isinstance(streams, Streams): + raise TypeError( + f"with_streams expects a cf.Streams instance or None, got " + f"{type(streams).__name__}" + ) + resolved = ( + None + if streams is None + else Streams( + buffer=_validate_buffer(streams.buffer), + timeout=_validate_timeout(streams.timeout), + ) + ) + return self._copy(streams=resolved) + @property def name(self) -> str: return self._name @@ -477,6 +580,12 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Execution[T]: else None ) + streams_dict = ( + serialize_streams(self._definition.streams) + if self._definition.streams + else None + ) + # Get memo value (bool or list of indices) memo_val = self._definition.memo if self._definition.memo else None @@ -495,6 +604,7 @@ def submit(self, *args: P.args, **kwargs: P.kwargs) -> Execution[T]: recurrent=self._definition.recurrent, requires=self._definition.requires, timeout=_to_ms(self._definition.timeout) if self._definition.timeout else 0, + streams=streams_dict, ) return Execution(result["execution_id"], result["module"], result["target"]) diff --git a/adapters/python/uv.lock b/adapters/python/uv.lock index 592e344d..464bb997 100644 --- a/adapters/python/uv.lock +++ b/adapters/python/uv.lock @@ -4,5 +4,5 @@ requires-python = ">=3.10" [[package]] name = "coflux" -version = "0.9.0.dev0" +version = "0.12.0.dev0" source = { editable = "." } diff --git a/cli/cmd/coflux/submit.go b/cli/cmd/coflux/submit.go index b160d672..014cb8f2 100644 --- a/cli/cmd/coflux/submit.go +++ b/cli/cmd/coflux/submit.go @@ -124,6 +124,9 @@ func runSubmit(cmd *cobra.Command, args []string) error { if timeout, ok := workflow["timeout"].(float64); ok && timeout > 0 { options["timeout"] = int64(timeout) } + if streams, ok := workflow["streams"].(map[string]any); ok && streams != nil { + options["streams"] = streams + } // Apply per-run overrides from flags. if cmd.Flags().Changed("requires") { diff --git a/cli/internal/adapter/adapter.go b/cli/internal/adapter/adapter.go index 2aa06e15..0527c259 100644 --- a/cli/internal/adapter/adapter.go +++ b/cli/internal/adapter/adapter.go @@ -216,7 +216,7 @@ func (e *Executor) Send(msg any) error { } // SendExecute sends an execute command to the executor -func (e *Executor) SendExecute(executionID, module, target string, arguments []Argument, workingDir string) error { +func (e *Executor) SendExecute(executionID, module, target string, arguments []Argument, workingDir string, streams *StreamsConfig) error { req := ExecuteRequest{ Method: "execute", Params: ExecuteRequestParams{ @@ -225,6 +225,7 @@ func (e *Executor) SendExecute(executionID, module, target string, arguments []A Target: target, Arguments: arguments, WorkingDir: workingDir, + Streams: streams, }, } return e.Send(req) diff --git a/cli/internal/adapter/protocol.go b/cli/internal/adapter/protocol.go index 81ba76b2..50ca3958 100644 --- a/cli/internal/adapter/protocol.go +++ b/cli/internal/adapter/protocol.go @@ -24,10 +24,19 @@ type TargetDefinition struct { Requires map[string][]string `json:"requires,omitempty"` Recurrent bool `json:"recurrent,omitempty"` Timeout int64 `json:"timeout,omitempty"` // timeout in milliseconds + Streams *StreamsConfig `json:"streams,omitempty"` IsStub bool `json:"is_stub,omitempty"` Instruction *string `json:"instruction,omitempty"` } +// StreamsConfig is the wire form of cf.Streams: default buffer + idle +// timeout for streams produced by the target. Either field may be +// absent; the adapter falls back to the decorator default. +type StreamsConfig struct { + Buffer *int `json:"buffer,omitempty"` + TimeoutMs *int `json:"timeout_ms,omitempty"` +} + // Parameter describes a function parameter type Parameter struct { Name string `json:"name"` @@ -63,11 +72,12 @@ type ExecuteRequest struct { // ExecuteRequestParams contains execution parameters type ExecuteRequestParams struct { - ExecutionID string `json:"execution_id"` - Module string `json:"module"` - Target string `json:"target"` - Arguments []Argument `json:"arguments"` - WorkingDir string `json:"working_dir,omitempty"` + ExecutionID string `json:"execution_id"` + Module string `json:"module"` + Target string `json:"target"` + Arguments []Argument `json:"arguments"` + WorkingDir string `json:"working_dir,omitempty"` + Streams *StreamsConfig `json:"streams,omitempty"` } // Argument is the same structure as Value (used for arguments to distinguish context) @@ -194,6 +204,7 @@ type SubmitExecutionParams struct { Recurrent bool `json:"recurrent,omitempty"` Requires map[string][]string `json:"requires,omitempty"` Timeout int64 `json:"timeout,omitempty"` // timeout in milliseconds + Streams *StreamsConfig `json:"streams,omitempty"` } // SubmitExecutionResult is the response to submit_execution @@ -256,15 +267,17 @@ type RegisterGroupParams struct { // StreamRegisterParams for stream_register notification. // Index is worker-assigned, monotonic per execution — it identifies the // stream within its producer execution. Buffer is the optional -// backpressure budget; nil means unbounded (no flow control). +// backpressure budget; nil means unbounded (no flow control). TimeoutMs +// is the optional idle-timeout budget (milliseconds) — nil disables it. type StreamRegisterParams struct { ExecutionID string `json:"execution_id"` Index int `json:"index"` Buffer *int `json:"buffer,omitempty"` + TimeoutMs *int `json:"timeout_ms,omitempty"` } // StreamDemandParams for stream_demand notification pushed CLI → adapter. -// Grants the producer ``n`` more credits for the given stream. +// Grants the producer “n“ more credits for the given stream. type StreamDemandParams struct { ExecutionID string `json:"execution_id"` Index int `json:"index"` @@ -324,6 +337,18 @@ type StreamItemsParams struct { Items []any `json:"items"` } +// StreamForceCloseParams for stream_force_close notification pushed +// CLI → adapter. Tells the producer side that its stream has already +// been closed externally (typically by the worker's idle-timeout +// timer), so it should stop producing and skip sending its own +// stream_close. “Reason“ is a semantic string — today just +// “"timeout"“ but kept as a string for future extension. +type StreamForceCloseParams struct { + ExecutionID string `json:"execution_id"` + Index int `json:"index"` + Reason string `json:"reason"` +} + // StreamClosedParams for stream_closed notification pushed CLI → adapter. // // `Reason` is a semantic string ("complete" / "errored" / "cancelled" / diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index a0fc98d1..03d49660 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -53,13 +53,15 @@ type ExecutionHandler interface { // Index is worker-assigned, monotonic per execution — it identifies // the stream within its producer execution. Buffer is the optional // backpressure budget; nil means unbounded (no flow control). - StreamRegister(ctx context.Context, executionID string, index int, buffer *int) error + // TimeoutMs is the optional idle-timeout budget (milliseconds); + // purely informational for the server (enforced at the worker/CLI). + StreamRegister(ctx context.Context, executionID string, index int, buffer *int, timeoutMs *int) error // StreamAppend appends an item to a stream. Sequence is worker-assigned, // monotonic per stream — it identifies the item within its stream. StreamAppend(ctx context.Context, executionID string, index int, sequence int, value *adapter.Value) error - // StreamClose closes a stream. Error is nil for a clean close, or a (type, message, traceback) - // triple when the producer's generator raised. - StreamClose(ctx context.Context, executionID string, index int, err *adapter.StreamCloseError) error + // StreamClose closes a stream. ``reason`` is "complete" | "errored" | "timeout". + // When nil, inferred from ``err`` (nil→complete, non-nil→errored). + StreamClose(ctx context.Context, executionID string, index int, err *adapter.StreamCloseError, reason *string) error // StreamSubscribe opens a consumer subscription to a stream owned // by another execution. `stride` is an optional // {"start", "stop", "step"} map restricting which positions are @@ -88,6 +90,13 @@ type Pool struct { cancel context.CancelFunc ctx context.Context wg sync.WaitGroup // tracks runExecution goroutines + + // Idle-timeout timers for registered streams. Registered when the + // adapter sends stream_register with a timeout_ms, reset on each + // stream_append, cleared on stream_close or execution end. On fire, + // the pool reports a close with reason="timeout" to the server and + // pushes stream_force_close to the adapter. + streamTimers *streamTimers } // NewPool creates a new executor pool. @@ -101,7 +110,7 @@ func NewPool(adp adapter.Adapter, concurrency int, handler ExecutionHandler, log if warmTarget > 4 { warmTarget = 4 } - return &Pool{ + p := &Pool{ adapter: adp, concurrency: concurrency, warmTarget: warmTarget, @@ -110,6 +119,32 @@ func NewPool(adp adapter.Adapter, concurrency int, handler ExecutionHandler, log busy: make(map[string]*adapter.Executor), aborted: make(map[string]bool), } + p.streamTimers = newStreamTimers(p.onStreamTimeout) + return p +} + +// onStreamTimeout is invoked by the stream-timers registry when a +// stream's idle deadline elapses. Reports the close to the server with +// reason="timeout" and notifies the adapter so its producer thread +// stops trying to append. Runs on a goroutine owned by time.AfterFunc. +func (p *Pool) onStreamTimeout(key streamKey) { + logger := p.logger.With("execution_id", key.executionID, "stream_index", key.index) + logger.Info("stream idle timeout elapsed") + + reason := "timeout" + if err := p.handler.StreamClose(p.ctx, key.executionID, key.index, nil, &reason); err != nil { + logger.Error("failed to report stream timeout close", "error", err) + } + + // Tell the adapter so its producer thread stops. Best-effort — + // PushToExecutor is a no-op if the adapter has already exited. + if err := p.PushToExecutor(key.executionID, "stream_force_close", map[string]any{ + "execution_id": key.executionID, + "index": key.index, + "reason": "timeout", + }); err != nil { + logger.Warn("failed to push stream_force_close", "error", err) + } } // Start initializes the pool by spawning warm executors (best-effort). @@ -148,7 +183,9 @@ func (p *Pool) spawnExecutor(ctx context.Context) (*adapter.Executor, error) { // Execute runs a target. Uses a warm executor if available, otherwise spawns // one on demand. Returns an error if spawning fails (caller should report to server). // timeoutMs, if > 0, enforces a wall-clock timeout on the execution. -func (p *Pool) Execute(ctx context.Context, executionID, module, target string, arguments []adapter.Argument, timeoutMs int64) error { +// streams (if non-nil) is the default stream config — forwarded to the +// adapter so generator-bodied tasks and cf.stream(...) calls pick it up. +func (p *Pool) Execute(ctx context.Context, executionID, module, target string, arguments []adapter.Argument, timeoutMs int64, streams *adapter.StreamsConfig) error { p.mu.Lock() if p.shutdown { p.mu.Unlock() @@ -179,12 +216,12 @@ func (p *Pool) Execute(ctx context.Context, executionID, module, target string, p.wg.Add(1) p.mu.Unlock() - go p.runExecution(ctx, exec, executionID, module, target, arguments, timeoutMs) + go p.runExecution(ctx, exec, executionID, module, target, arguments, timeoutMs, streams) return nil } -func (p *Pool) runExecution(ctx context.Context, exec *adapter.Executor, executionID, module, target string, arguments []adapter.Argument, timeoutMs int64) { +func (p *Pool) runExecution(ctx context.Context, exec *adapter.Executor, executionID, module, target string, arguments []adapter.Argument, timeoutMs int64, streams *adapter.StreamsConfig) { defer p.wg.Done() // Create a temporary directory for this execution @@ -200,7 +237,7 @@ func (p *Pool) runExecution(ctx context.Context, exec *adapter.Executor, executi logger := p.logger.With("execution_id", executionID, "module", module, "target", target) // Send execute command - if err := exec.SendExecute(executionID, module, target, arguments, workingDir); err != nil { + if err := exec.SendExecute(executionID, module, target, arguments, workingDir, streams); err != nil { logger.Error("failed to send execute command", "error", err) p.handler.ReportError(ctx, executionID, "internal", err.Error(), "", nil) os.RemoveAll(workingDir) @@ -353,6 +390,11 @@ func (p *Pool) finishExecution(executionID string, execToClose *adapter.Executor delete(p.aborted, executionID) p.mu.Unlock() + // Drop any lingering stream timers for this execution. The server + // will synthesise :lifecycle closures for streams still open at + // this point — we don't want a timer fire racing that. + p.streamTimers.ClearExecution(executionID) + if execToClose != nil { _ = execToClose.Close() } @@ -489,8 +531,12 @@ func (p *Pool) handleStreamRegister(ctx context.Context, executionID string, par return } - if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Index, req.Buffer); err != nil { + if err := p.handler.StreamRegister(ctx, req.ExecutionID, req.Index, req.Buffer, req.TimeoutMs); err != nil { logger.Error("failed to register stream", "error", err) + return + } + if req.TimeoutMs != nil { + p.streamTimers.Register(streamKey{req.ExecutionID, req.Index}, *req.TimeoutMs) } } @@ -501,6 +547,11 @@ func (p *Pool) handleStreamAppend(ctx context.Context, executionID string, param return } + // Reset the idle-timeout timer first so an in-flight timer fire + // doesn't race a successful append. Harmless if no timer is + // registered for this stream. + p.streamTimers.Reset(streamKey{req.ExecutionID, req.Index}) + if err := p.handler.StreamAppend(ctx, req.ExecutionID, req.Index, req.Sequence, req.Value); err != nil { logger.Error("failed to append stream item", "error", err) } @@ -513,7 +564,13 @@ func (p *Pool) handleStreamClose(ctx context.Context, executionID string, params return } - if err := p.handler.StreamClose(ctx, req.ExecutionID, req.Index, req.Error); err != nil { + // Clear the timer; harmless if none was registered. If the timer + // had already fired (Clear returns false), the timeout path has + // already reported the close and the server will dedupe this + // forward via ``:already_closed``. + p.streamTimers.Clear(streamKey{req.ExecutionID, req.Index}) + + if err := p.handler.StreamClose(ctx, req.ExecutionID, req.Index, req.Error, nil); err != nil { logger.Error("failed to close stream", "error", err) } } diff --git a/cli/internal/pool/stream_timers.go b/cli/internal/pool/stream_timers.go new file mode 100644 index 00000000..5b45d7c8 --- /dev/null +++ b/cli/internal/pool/stream_timers.go @@ -0,0 +1,136 @@ +// Worker-side idle-timeout enforcement for producer streams. +// +// Each registered stream with a configured ``timeout_ms`` gets a +// ``time.Timer`` that resets on every ``stream_append`` and clears on +// ``stream_close`` (or when the producer execution ends). When the +// timer fires, the owning pool reports a close with reason="timeout" +// to the server and tells the adapter to stop producing via a +// ``stream_force_close`` notification. +// +// The timer lives in the CLI (rather than the server) so the check is +// local to the process producing the stream, without a server +// round-trip. Server-side we just record the outcome. + +package pool + +import ( + "sync" + "time" +) + +type streamKey struct { + executionID string + index int +} + +// streamTimer is one active idle-timeout for a single stream. +type streamTimer struct { + timeout time.Duration + timer *time.Timer +} + +// streamTimers is a concurrency-safe registry of active stream timers +// for a pool. All mutations take the lock; the underlying timer +// callback runs on its own goroutine (scheduled by time.AfterFunc), +// so it must not hold the lock while invoking fireFn. +type streamTimers struct { + mu sync.Mutex + timers map[streamKey]*streamTimer + fireFn func(key streamKey) +} + +func newStreamTimers(fireFn func(key streamKey)) *streamTimers { + return &streamTimers{ + timers: make(map[streamKey]*streamTimer), + fireFn: fireFn, + } +} + +// Register starts a new idle-timeout timer for the given stream. +// “timeoutMs“ <= 0 is a no-op (no timeout configured). Safe to call +// even if a timer already exists for the key — the existing one is +// stopped first (defensive; the adapter shouldn't double-register). +func (s *streamTimers) Register(key streamKey, timeoutMs int) { + if timeoutMs <= 0 { + return + } + d := time.Duration(timeoutMs) * time.Millisecond + + s.mu.Lock() + if existing, ok := s.timers[key]; ok { + existing.timer.Stop() + } + st := &streamTimer{timeout: d} + st.timer = time.AfterFunc(d, func() { s.fire(key) }) + s.timers[key] = st + s.mu.Unlock() +} + +// Reset restarts the countdown for a stream. No-op if no timer was +// registered (stream has no timeout configured, or was already +// cleared). +func (s *streamTimers) Reset(key streamKey) { + s.mu.Lock() + st, ok := s.timers[key] + s.mu.Unlock() + if !ok { + return + } + // time.Timer.Reset is safe to call on a timer that has already + // stopped; we pre-stop to avoid the rare race where the timer + // fires between Stop() and Reset() — Reset alone would schedule a + // second fire. + st.timer.Stop() + st.timer.Reset(st.timeout) +} + +// Clear stops and removes the timer for a stream. Returns true if a +// timer existed. Safe to call from the timer callback (fire already +// removed the entry before invoking fireFn). +func (s *streamTimers) Clear(key streamKey) bool { + s.mu.Lock() + st, ok := s.timers[key] + if ok { + delete(s.timers, key) + } + s.mu.Unlock() + if !ok { + return false + } + st.timer.Stop() + return true +} + +// ClearExecution removes every timer owned by the given execution. +// Used on execution end to drop any lingering timers in bulk. +func (s *streamTimers) ClearExecution(executionID string) { + s.mu.Lock() + toStop := make([]*streamTimer, 0) + for k, st := range s.timers { + if k.executionID == executionID { + toStop = append(toStop, st) + delete(s.timers, k) + } + } + s.mu.Unlock() + for _, st := range toStop { + st.timer.Stop() + } +} + +// fire is invoked by time.AfterFunc when the timeout elapses. Removes +// the entry before invoking the callback, so any Clear/Reset arriving +// from a concurrent close/append is harmless. +func (s *streamTimers) fire(key streamKey) { + s.mu.Lock() + _, ok := s.timers[key] + if ok { + delete(s.timers, key) + } + s.mu.Unlock() + if !ok { + // Cleared before we fired — nothing to do. + return + } + s.fireFn(key) +} diff --git a/cli/internal/worker/worker.go b/cli/internal/worker/worker.go index f4e4fbd7..2470e241 100644 --- a/cli/internal/worker/worker.go +++ b/cli/internal/worker/worker.go @@ -506,6 +506,25 @@ func (w *Worker) handleExecute(params []any) error { } } + // Optional streams config (8th param). A map with string keys + // "buffer" and/or "timeout_ms"; either may be absent. Passed to + // the adapter so ``cf.stream(...)`` / generator-bodied tasks + // inherit the caller's override (or the workflow manifest default). + var streams *adapter.StreamsConfig + if len(params) > 7 && params[7] != nil { + if m, ok := params[7].(map[string]any); ok { + streams = &adapter.StreamsConfig{} + if v, ok := m["buffer"].(float64); ok { + buf := int(v) + streams.Buffer = &buf + } + if v, ok := m["timeout_ms"].(float64); ok { + t := int(v) + streams.TimeoutMs = &t + } + } + } + w.logger.Debug("executing", "execution_id", executionID, "module", moduleName, "target", targetName, "run_id", runID, "timeout_ms", timeoutMs) // Track execution @@ -545,7 +564,7 @@ func (w *Worker) handleExecute(params []any) error { w.mu.Unlock() // Execute on pool - if err := w.pool.Execute(context.Background(), executionID, moduleName, targetName, args, timeoutMs); err != nil { + if err := w.pool.Execute(context.Background(), executionID, moduleName, targetName, args, timeoutMs, streams); err != nil { w.logger.Error("failed to execute", "error", err, "run_id", runID) w.ReportError(context.Background(), executionID, "internal", err.Error(), "", nil) } @@ -727,7 +746,7 @@ func (w *Worker) handleStreamClosed(params []any) error { // handleStreamDemand forwards a server-pushed demand grant to the producer // adapter. Params: [execution_id, index, n]. The producer's StreamDriver -// adds ``n`` to its per-stream credit counter and wakes any waiting +// adds “n“ to its per-stream credit counter and wakes any waiting // worker thread. func (w *Worker) handleStreamDemand(params []any) error { if len(params) < 3 { @@ -876,7 +895,24 @@ func (w *Worker) SubmitExecution(ctx context.Context, params *adapter.SubmitExec timeout = params.Timeout } - // Server expects: module, target, type, arguments, parent_id, group_id, wait_for, cache, defer, memo, delay, retries, recurrent, requires, timeout + // Streams config (buffer + idle timeout_ms defaults for streams + // produced by this execution). Encoded as a map with keys that the + // Elixir handler reads positionally; nil omits the option entirely. + var streams any + if params.Streams != nil { + s := map[string]any{} + if params.Streams.Buffer != nil { + s["buffer"] = *params.Streams.Buffer + } + if params.Streams.TimeoutMs != nil { + s["timeout_ms"] = *params.Streams.TimeoutMs + } + if len(s) > 0 { + streams = s + } + } + + // Server expects: module, target, type, arguments, parent_id, group_id, wait_for, cache, defer, memo, delay, retries, recurrent, requires, timeout, streams conn, err := w.requireConn() if err != nil { return nil, err @@ -897,6 +933,7 @@ func (w *Worker) SubmitExecution(ctx context.Context, params *adapter.SubmitExec params.Recurrent, // recurrent params.Requires, // requires timeout, // timeout + streams, // streams ) if err != nil { return nil, err @@ -1216,17 +1253,23 @@ func (w *Worker) RegisterGroup(ctx context.Context, executionID string, groupID return conn.Notify("register_group", executionID, groupID, name) } -func (w *Worker) StreamRegister(ctx context.Context, executionID string, index int, buffer *int) error { +func (w *Worker) StreamRegister(ctx context.Context, executionID string, index int, buffer *int, timeoutMs *int) error { conn, err := w.requireConn() if err != nil { return err } - // The wire protocol takes buffer positionally; nil encodes to JSON null, - // which the server interprets as "no backpressure" (unbounded). - if buffer == nil { - return conn.Notify("stream_register", executionID, index, nil) + // The wire protocol takes buffer and timeout_ms positionally; nil + // encodes to JSON null. Server reads [execution_id, index, buffer, + // timeout_ms?]; omitting the trailing timeout_ms keeps compat with + // older server builds that don't read it. + var bufferArg any + if buffer != nil { + bufferArg = *buffer + } + if timeoutMs == nil { + return conn.Notify("stream_register", executionID, index, bufferArg) } - return conn.Notify("stream_register", executionID, index, *buffer) + return conn.Notify("stream_register", executionID, index, bufferArg, *timeoutMs) } func (w *Worker) StreamAppend(ctx context.Context, executionID string, index int, sequence int, value *adapter.Value) error { @@ -1242,7 +1285,7 @@ func (w *Worker) StreamAppend(ctx context.Context, executionID string, index int return conn.Notify("stream_append", executionID, index, sequence, serverValue) } -func (w *Worker) StreamClose(ctx context.Context, executionID string, index int, streamErr *adapter.StreamCloseError) error { +func (w *Worker) StreamClose(ctx context.Context, executionID string, index int, streamErr *adapter.StreamCloseError, reason *string) error { conn, err := w.requireConn() if err != nil { return err @@ -1255,7 +1298,14 @@ func (w *Worker) StreamClose(ctx context.Context, executionID string, index int, frames := parseTraceback(streamErr.Traceback) errTuple = []any{streamErr.Type, streamErr.Message, frames} } - return conn.Notify("stream_close", executionID, index, errTuple) + // Wire: [execution_id, index, error, reason?]. When reason is nil + // the server infers close kind from error presence (nil→complete, + // object→errored). A non-nil reason (today only "timeout") is + // passed through explicitly. + if reason == nil { + return conn.Notify("stream_close", executionID, index, errTuple) + } + return conn.Notify("stream_close", executionID, index, errTuple, *reason) } func (w *Worker) StreamSubscribe(ctx context.Context, executionID string, subscriptionID int, producerExecutionID string, index int, fromSequence int, stride map[string]any) error { @@ -1938,6 +1988,22 @@ func (w *Worker) buildManifests(manifest *adapter.DiscoveryManifest) map[string] // Build timeout (0 = not set, same as delay) timeout := int(t.Timeout) + // Build streams (nil if not set) — keys snake_case to match the + // Python adapter's wire format for register_manifests. + var streams any + if t.Streams != nil { + m := map[string]any{} + if t.Streams.Buffer != nil { + m["buffer"] = *t.Streams.Buffer + } + if t.Streams.TimeoutMs != nil { + m["timeout_ms"] = *t.Streams.TimeoutMs + } + if len(m) > 0 { + streams = m + } + } + def := map[string]any{ "parameters": buildParameters(t.Parameters), "waitFor": waitFor, @@ -1950,6 +2016,7 @@ func (w *Worker) buildManifests(manifest *adapter.DiscoveryManifest) map[string] "requires": requires, "instruction": instruction, "memo": t.Memo, + "streams": streams, } manifests[t.Module][t.Name] = def diff --git a/server/lib/coflux/handlers/api.ex b/server/lib/coflux/handlers/api.ex index 1607b7f0..1a493dbd 100644 --- a/server/lib/coflux/handlers/api.ex +++ b/server/lib/coflux/handlers/api.ex @@ -509,6 +509,7 @@ defmodule Coflux.Handlers.Api do timeout: {"timeout", &parse_integer(&1, optional: true)}, requires: {"requires", &parse_tag_set/1}, memo: {"memo", &parse_boolean(&1, optional: true)}, + streams: {"streams", &parse_streams_config/1}, idempotency_key: {"idempotencyKey", &parse_string(&1, optional: true)} } ) do @@ -530,6 +531,7 @@ defmodule Coflux.Handlers.Api do timeout: arguments[:timeout] || 0, requires: arguments[:requires], memo: arguments[:memo], + streams: arguments[:streams], idempotency_key: arguments[:idempotency_key] ) do {:ok, run_id, step_number, execution_external_id} -> @@ -1744,6 +1746,30 @@ defmodule Coflux.Handlers.Api do end end + # Parse a ``streams`` config object from an HTTP request body. Both + # ``buffer`` and ``timeoutMs`` are optional; returns nil when the + # caller omits streams entirely. + defp parse_streams_config(value) do + cond do + is_nil(value) -> + {:ok, nil} + + is_map(value) -> + with {:ok, buffer} <- parse_integer(Map.get(value, "buffer"), optional: true), + {:ok, timeout_ms} <- + parse_integer(Map.get(value, "timeoutMs"), optional: true) do + if buffer == nil and timeout_ms == nil do + {:ok, nil} + else + {:ok, %{buffer: buffer, timeout_ms: timeout_ms}} + end + end + + true -> + {:error, :invalid} + end + end + defp parse_workflow(value) do if is_map(value) do with {:ok, parameters} <- parse_parameters(Map.get(value, "parameters")), @@ -1756,6 +1782,7 @@ defmodule Coflux.Handlers.Api do {:ok, timeout} <- parse_integer(Map.get(value, "timeout"), optional: true), {:ok, requires} <- parse_tag_set(Map.get(value, "requires")), {:ok, memo} <- parse_boolean(Map.get(value, "memo"), optional: true), + {:ok, streams} <- parse_manifest_streams(Map.get(value, "streams")), {:ok, instruction} <- parse_string( Map.get(value, "instruction"), @@ -1774,6 +1801,7 @@ defmodule Coflux.Handlers.Api do timeout: timeout || 0, requires: requires, memo: memo == true, + streams: streams, instruction: instruction }} else @@ -1785,6 +1813,31 @@ defmodule Coflux.Handlers.Api do end end + # Parse the ``streams`` field on a manifest workflow. The Python + # adapter serialises this as ``{"buffer": int?, "timeout_ms": int?}`` + # (snake_case, since it's the wire format shared with worker protocol + # not the HTTP-specific camelCase). + defp parse_manifest_streams(value) do + cond do + is_nil(value) -> + {:ok, nil} + + is_map(value) -> + with {:ok, buffer} <- parse_integer(Map.get(value, "buffer"), optional: true), + {:ok, timeout_ms} <- + parse_integer(Map.get(value, "timeout_ms"), optional: true) do + if buffer == nil and timeout_ms == nil do + {:ok, nil} + else + {:ok, %{buffer: buffer, timeout_ms: timeout_ms}} + end + end + + true -> + {:error, :invalid} + end + end + defp parse_workflows(value) do Enum.reduce_while(value, {:ok, %{}}, fn {workflow_name, workflow}, {:ok, result} -> if is_valid_target_name?(workflow_name) do diff --git a/server/lib/coflux/handlers/worker.ex b/server/lib/coflux/handlers/worker.ex index a9fae0f4..1e78217d 100644 --- a/server/lib/coflux/handlers/worker.ex +++ b/server/lib/coflux/handlers/worker.ex @@ -138,7 +138,8 @@ defmodule Coflux.Handlers.Worker do | rest ] = message["params"] - timeout = List.first(rest) || 0 + timeout = Enum.at(rest, 0) || 0 + streams = parse_streams(Enum.at(rest, 1)) if is_recognised_execution?(parent_id, state) do case Orchestration.schedule_step( @@ -157,7 +158,8 @@ defmodule Coflux.Handlers.Worker do retries: parse_retries(retries), recurrent: recurrent == true, requires: requires, - timeout: timeout + timeout: timeout, + streams: streams ) do {:ok, _run_id, _step_id, execution_external_id, metadata} -> result = [ @@ -255,7 +257,8 @@ defmodule Coflux.Handlers.Worker do "stream_register" -> [execution_id, index | rest] = message["params"] - buffer = List.first(rest) + buffer = Enum.at(rest, 0) + timeout_ms = Enum.at(rest, 1) if is_recognised_execution?(execution_id, state) do case Orchestration.register_stream( @@ -263,6 +266,7 @@ defmodule Coflux.Handlers.Worker do execution_id, index, buffer, + timeout_ms, state.session_id ) do :ok -> {[], state} @@ -310,20 +314,27 @@ defmodule Coflux.Handlers.Worker do end "stream_close" -> - [execution_id, index, error] = message["params"] + [execution_id, index, error | rest] = message["params"] + reason = Enum.at(rest, 0) if is_recognised_execution?(execution_id, state) do - parsed_error = - case parse_error(error) do - nil -> nil - {type, message, frames, _retryable} -> {type, message, frames} + close_spec = + case {reason, parse_error(error)} do + {"timeout", _} -> + :timeout + + {_, nil} -> + nil + + {_, {type, msg, frames, _retryable}} -> + {type, msg, frames} end case Orchestration.close_stream( state.project_id, execution_id, index, - parsed_error + close_spec ) do :ok -> {[], state} {:error, :already_closed} -> {[], state} @@ -610,7 +621,7 @@ defmodule Coflux.Handlers.Worker do def websocket_info( {:execute, execution_external_id, module, target, arguments, run_id, - workspace_external_id, timeout}, + workspace_external_id, timeout, streams}, state ) do arguments = Enum.map(arguments, &compose_value/1) @@ -625,7 +636,8 @@ defmodule Coflux.Handlers.Worker do arguments, run_id, workspace_external_id, - timeout + timeout, + compose_streams(streams) ]) ], state} end @@ -787,6 +799,23 @@ defmodule Coflux.Handlers.Worker do end end + def parse_streams(value) do + if value do + %{ + buffer: Map.get(value, "buffer"), + timeout_ms: Map.get(value, "timeout_ms") + } + end + end + + # Encode a streams config (as stored on the execution) for the wire + # format going CLI-ward. nil stays nil (compact). + defp compose_streams(nil), do: nil + + defp compose_streams(streams) do + Map.new(streams, fn {k, v} -> {Atom.to_string(k), v} end) + end + defp compose_references(references) do Enum.map(references, fn {:fragment, format, blob_key, size, metadata} -> diff --git a/server/lib/coflux/orchestration.ex b/server/lib/coflux/orchestration.ex index 626e28b1..f186fefa 100644 --- a/server/lib/coflux/orchestration.ex +++ b/server/lib/coflux/orchestration.ex @@ -186,10 +186,10 @@ defmodule Coflux.Orchestration do # producer execution; `sequence` identifies an item within the stream. # Both are worker-assigned and monotonic from 0. - def register_stream(project_id, execution_id, index, buffer, session_id) do + def register_stream(project_id, execution_id, index, buffer, timeout_ms, session_id) do call_server( project_id, - {:register_stream, execution_id, index, buffer, session_id} + {:register_stream, execution_id, index, buffer, timeout_ms, session_id} ) end @@ -197,8 +197,8 @@ defmodule Coflux.Orchestration do call_server(project_id, {:append_stream_item, execution_id, index, sequence, value}) end - def close_stream(project_id, execution_id, index, error) do - call_server(project_id, {:close_stream, execution_id, index, error}) + def close_stream(project_id, execution_id, index, spec) do + call_server(project_id, {:close_stream, execution_id, index, spec}) end # Stream consumer messages — consumer opens a subscription to receive diff --git a/server/lib/coflux/orchestration/manifests.ex b/server/lib/coflux/orchestration/manifests.ex index 588fc780..36f83af5 100644 --- a/server/lib/coflux/orchestration/manifests.ex +++ b/server/lib/coflux/orchestration/manifests.ex @@ -21,7 +21,8 @@ defmodule Coflux.Orchestration.Manifests do :workflows, {:manifest_id, :name, :instruction_id, :parameter_set_id, :wait_for, :cache_config_id, :defer_params, :delay, :retry_limit, :retry_backoff_min, - :retry_backoff_max, :recurrent, :timeout, :requires_tag_set_id, :memo}, + :retry_backoff_max, :recurrent, :timeout, :requires_tag_set_id, :memo, + :streams_buffer, :streams_timeout_ms}, Enum.map(workflows, fn {name, workflow} -> {:ok, instruction_id} = if workflow.instruction do @@ -45,6 +46,12 @@ defmodule Coflux.Orchestration.Manifests do {:ok, nil} end + {streams_buffer, streams_timeout_ms} = + case workflow[:streams] do + nil -> {nil, nil} + streams -> {streams[:buffer], streams[:timeout_ms]} + end + { manifest_id, name, @@ -62,7 +69,9 @@ defmodule Coflux.Orchestration.Manifests do if(workflow.recurrent, do: 1, else: 0), workflow[:timeout] || 0, requires_tag_set_id, - if(workflow[:memo], do: 1) + if(workflow[:memo], do: 1), + streams_buffer, + streams_timeout_ms } end) ) @@ -161,7 +170,7 @@ defmodule Coflux.Orchestration.Manifests do case query_one( db, """ - SELECT w.parameter_set_id, w.instruction_id, w.wait_for, w.cache_config_id, w.defer_params, w.delay, w.retry_limit, w.retry_backoff_min, w.retry_backoff_max, w.recurrent, w.timeout, w.requires_tag_set_id, w.memo + SELECT w.parameter_set_id, w.instruction_id, w.wait_for, w.cache_config_id, w.defer_params, w.delay, w.retry_limit, w.retry_backoff_min, w.retry_backoff_max, w.recurrent, w.timeout, w.requires_tag_set_id, w.memo, w.streams_buffer, w.streams_timeout_ms FROM workspace_manifests AS wm LEFT JOIN workflows AS w ON w.manifest_id = wm.manifest_id WHERE wm.workspace_id = ?1 AND wm.module = ?2 AND w.name = ?3 @@ -176,7 +185,7 @@ defmodule Coflux.Orchestration.Manifests do {:ok, {parameter_set_id, instruction_id, wait_for, cache_config_id, defer_params, delay, retry_limit, retry_backoff_min, retry_backoff_max, recurrent, timeout, - requires_tag_set_id, memo}} -> + requires_tag_set_id, memo, streams_buffer, streams_timeout_ms}} -> build_workflow( db, parameter_set_id, @@ -191,7 +200,9 @@ defmodule Coflux.Orchestration.Manifests do recurrent, timeout, requires_tag_set_id, - memo + memo, + streams_buffer, + streams_timeout_ms ) end end @@ -200,7 +211,7 @@ defmodule Coflux.Orchestration.Manifests do case query( db, """ - SELECT name, instruction_id, parameter_set_id, wait_for, cache_config_id, defer_params, delay, retry_limit, retry_backoff_min, retry_backoff_max, recurrent, timeout, requires_tag_set_id, memo + SELECT name, instruction_id, parameter_set_id, wait_for, cache_config_id, defer_params, delay, retry_limit, retry_backoff_min, retry_backoff_max, recurrent, timeout, requires_tag_set_id, memo, streams_buffer, streams_timeout_ms FROM workflows WHERE manifest_id = ?1 """, @@ -210,7 +221,8 @@ defmodule Coflux.Orchestration.Manifests do workflows = Map.new(rows, fn {name, instruction_id, parameter_set_id, wait_for, cache_config_id, defer_params, delay, retry_limit, retry_backoff_min, - retry_backoff_max, recurrent, timeout, requires_tag_set_id, memo} -> + retry_backoff_max, recurrent, timeout, requires_tag_set_id, memo, + streams_buffer, streams_timeout_ms} -> {:ok, workflow} = build_workflow( db, @@ -226,7 +238,9 @@ defmodule Coflux.Orchestration.Manifests do recurrent, timeout, requires_tag_set_id, - memo + memo, + streams_buffer, + streams_timeout_ms ) {name, workflow} @@ -291,7 +305,8 @@ defmodule Coflux.Orchestration.Manifests do Integer.to_string(workflow[:timeout] || 0), hash_requires(workflow.requires), if(workflow[:memo], do: "1", else: "0"), - workflow.instruction || "" + workflow.instruction || "", + hash_streams(workflow[:streams]) ] end) @@ -312,7 +327,9 @@ defmodule Coflux.Orchestration.Manifests do recurrent, timeout, requires_tag_set_id, - memo + memo, + streams_buffer, + streams_timeout_ms ) do {:ok, parameters} = get_parameter_set(db, parameter_set_id) @@ -360,6 +377,11 @@ defmodule Coflux.Orchestration.Manifests do } end + streams = + if streams_buffer != nil or streams_timeout_ms != nil do + %{buffer: streams_buffer, timeout_ms: streams_timeout_ms} + end + {:ok, %{ parameters: parameters, @@ -372,7 +394,8 @@ defmodule Coflux.Orchestration.Manifests do recurrent: recurrent == 1, timeout: timeout, requires: requires, - memo: memo == 1 + memo: memo == 1, + streams: streams }} end @@ -448,6 +471,13 @@ defmodule Coflux.Orchestration.Manifests do :crypto.hash(:sha256, data) end + defp hash_streams(nil), do: "-" + + defp hash_streams(streams) do + "#{if streams[:buffer] != nil, do: Integer.to_string(streams[:buffer]), else: ""}:" <> + "#{if streams[:timeout_ms] != nil, do: Integer.to_string(streams[:timeout_ms]), else: ""}" + end + defp hash_requires(requires) do requires |> Enum.sort() diff --git a/server/lib/coflux/orchestration/models.ex b/server/lib/coflux/orchestration/models.ex index fc0a9baa..17095532 100644 --- a/server/lib/coflux/orchestration/models.ex +++ b/server/lib/coflux/orchestration/models.ex @@ -85,6 +85,8 @@ defmodule Coflux.Orchestration.Models do :retry_backoff_min, :retry_backoff_max, :timeout, + :streams_buffer, + :streams_timeout_ms, :workspace_id, :execute_after, :attempt, diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index 9c635709..905195ba 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -430,8 +430,7 @@ defmodule Coflux.Orchestration.Results do defp decode_completion_row( {:ok, {kind, successor_id, successor_ref_id, created_at, user_ext, token_ext}} ) do - {:ok, created_at, kind, successor_id, successor_ref_id, - decode_principal(user_ext, token_ext)} + {:ok, created_at, kind, successor_id, successor_ref_id, decode_principal(user_ext, token_ext)} end defp decode_principal(nil, nil), do: nil diff --git a/server/lib/coflux/orchestration/runs.ex b/server/lib/coflux/orchestration/runs.ex index dcae92c5..2cc466e3 100644 --- a/server/lib/coflux/orchestration/runs.ex +++ b/server/lib/coflux/orchestration/runs.ex @@ -374,6 +374,7 @@ defmodule Coflux.Orchestration.Runs do recurrent = Keyword.get(opts, :recurrent, false) delay = Keyword.get(opts, :delay, 0) timeout = Keyword.get(opts, :timeout, 0) + streams = Keyword.get(opts, :streams) requires = Keyword.get(opts, :requires) || %{} # Calculate execute_after from delay @@ -423,6 +424,9 @@ defmodule Coflux.Orchestration.Runs do if defer, do: build_key(defer.params, arguments, "#{module}:#{target}") + streams_buffer = if streams, do: streams[:buffer] + streams_timeout_ms = if streams, do: streams[:timeout_ms] + # TODO: validate parent belongs to run? {:ok, step_id, step_number} = insert_step( @@ -445,6 +449,8 @@ defmodule Coflux.Orchestration.Runs do delay, timeout, requires_tag_set_id, + streams_buffer, + streams_timeout_ms, now ) @@ -642,6 +648,8 @@ defmodule Coflux.Orchestration.Runs do s.retry_backoff_min, s.retry_backoff_max, s.timeout, + s.streams_buffer, + s.streams_timeout_ms, e.workspace_id, e.execute_after, e.attempt, @@ -1289,6 +1297,8 @@ defmodule Coflux.Orchestration.Runs do delay, timeout, requires_tag_set_id, + streams_buffer, + streams_timeout_ms, now ) do {:ok, step_number} = get_next_step_number(db, run_id) @@ -1313,6 +1323,8 @@ defmodule Coflux.Orchestration.Runs do delay: delay, timeout: timeout, requires_tag_set_id: requires_tag_set_id, + streams_buffer: streams_buffer, + streams_timeout_ms: streams_timeout_ms, created_at: now }) do {:ok, step_id} -> diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 59feb53d..ee0bc416 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1855,13 +1855,13 @@ defmodule Coflux.Orchestration.Server do end def handle_call( - {:register_stream, execution_external_id, index, buffer, session_external_id}, + {:register_stream, execution_external_id, index, buffer, timeout_ms, session_external_id}, _from, state ) do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> - case Streams.register_stream(state.db, execution_id, index, buffer) do + case Streams.register_stream(state.db, execution_id, index, buffer, timeout_ms) do {:ok, created_at} -> # Resolve the session's external id to the internal one — # send_session (which delivers stream_demand) indexes by the @@ -1878,7 +1878,7 @@ defmodule Coflux.Orchestration.Server do buffer, internal_session_id ) - |> notify_stream_opened(execution_id, index, buffer, created_at) + |> notify_stream_opened(execution_id, index, buffer, timeout_ms, created_at) |> maybe_send_initial_demand(execution_id, index) |> flush_notifications() @@ -2092,16 +2092,19 @@ defmodule Coflux.Orchestration.Server do end end - def handle_call({:close_stream, execution_external_id, index, error}, _from, state) do + def handle_call({:close_stream, execution_external_id, index, close_spec}, _from, state) do case Map.fetch(state.execution_ids, execution_external_id) do {:ok, execution_id} -> - {spec, reason} = - case error do + {spec, reason, error} = + case close_spec do nil -> - {:complete, :complete} + {:complete, :complete, nil} + + :timeout -> + {:timeout, :timeout, nil} {type, message, frames} -> - {{:errored, type, message, frames}, :errored} + {{:errored, type, message, frames}, :errored, {type, message, frames}} end case Streams.close_stream(state.db, execution_id, index, spec) do @@ -3202,7 +3205,11 @@ defmodule Coflux.Orchestration.Server do session_id, {:execute, execution_external_id, execution.module, execution.target, enriched_arguments, execution.run_external_id, workspace_external_id, - execution.timeout} + execution.timeout, + build_streams_config( + execution.streams_buffer, + execution.streams_timeout_ms + )} ) # Notify sessions topic of updated total @@ -5718,6 +5725,18 @@ defmodule Coflux.Orchestration.Server do "#{run_external_id}:#{step_number}:#{attempt}" end + # Build the map passed to workers in the :execute message, describing + # the execution's default stream config. Returns nil when neither + # option is set — keeps the wire message compact for the common case. + defp build_streams_config(nil, nil), do: nil + + defp build_streams_config(buffer, timeout_ms) do + map = %{} + map = if buffer != nil, do: Map.put(map, :buffer, buffer), else: map + map = if timeout_ms != nil, do: Map.put(map, :timeout_ms, timeout_ms), else: map + map + end + defp input_external_id(run_external_id, input_number) do "#{run_external_id}/i#{input_number}" end @@ -6177,9 +6196,7 @@ defmodule Coflux.Orchestration.Server do state = close_open_streams(state, execution_id) - case Results.record_completion(state.db, execution_id, :errored, - successor_id: retry_id - ) do + case Results.record_completion(state.db, execution_id, :errored, successor_id: retry_id) do {:ok, completion_at} -> # Re-fire :result on the run topic so the error entry in the UI # picks up the newly-created retry successor. We only need to do @@ -6222,9 +6239,7 @@ defmodule Coflux.Orchestration.Server do # consumers derive the specific error from the execution's outcome. state = close_open_streams(state, execution_id) - case Results.record_completion(state.db, execution_id, :crashed, - successor_id: retry_id - ) do + case Results.record_completion(state.db, execution_id, :crashed, successor_id: retry_id) do {:ok, completion_at} -> # Result-time notifications weren't fired (no results row was ever # written), so fire them now alongside the completion notification. @@ -6286,15 +6301,15 @@ defmodule Coflux.Orchestration.Server do {:ok, rows} = Streams.get_streams_with_closures_for_execution(db, execution_id) Enum.map(rows, fn - {index, buffer, opened_at, nil, nil, nil} -> - {index, buffer, opened_at, nil, nil, nil} + {index, buffer, timeout_ms, opened_at, nil, nil, nil} -> + {index, buffer, timeout_ms, opened_at, nil, nil, nil} - {index, buffer, opened_at, closed_at, :lifecycle, _} -> + {index, buffer, timeout_ms, opened_at, closed_at, :lifecycle, _} -> {resolved_reason, resolved_error} = derive_lifecycle_info(db, execution_id) - {index, buffer, opened_at, closed_at, resolved_reason, resolved_error} + {index, buffer, timeout_ms, opened_at, closed_at, resolved_reason, resolved_error} - {index, buffer, opened_at, closed_at, reason, error} -> - {index, buffer, opened_at, closed_at, reason, error} + {index, buffer, timeout_ms, opened_at, closed_at, reason, error} -> + {index, buffer, timeout_ms, opened_at, closed_at, reason, error} end) end @@ -7711,14 +7726,14 @@ defmodule Coflux.Orchestration.Server do # items on demand. @stream_topic_tail_size 200 - defp notify_stream_opened(state, execution_id, index, buffer, created_at) do + defp notify_stream_opened(state, execution_id, index, buffer, timeout_ms, created_at) do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) notify_listeners( state, {:run, r}, - {:stream_opened, execution_ext_id, index, buffer, created_at} + {:stream_opened, execution_ext_id, index, buffer, timeout_ms, created_at} ) end @@ -7788,6 +7803,7 @@ defmodule Coflux.Orchestration.Server do {:ok, true} <- Streams.exists?(state.db, execution_id, index), {:ok, opened_at} <- Streams.get_opened_at(state.db, execution_id, index), {:ok, buffer} <- Streams.get_buffer(state.db, execution_id, index), + {:ok, timeout_ms} <- Streams.get_timeout_ms(state.db, execution_id, index), {:ok, {items, total_count}} <- Streams.get_stream_tail(state.db, execution_id, index, @stream_topic_tail_size) do # Keep the tuple shape here — the topic module runs TopicUtils.build_value @@ -7804,6 +7820,7 @@ defmodule Coflux.Orchestration.Server do %{ producer: build_stream_producer(state.db, execution_ext_id, execution_id), buffer: buffer, + timeoutMs: timeout_ms, openedAt: opened_at, closure: closure, items: resolved_items, @@ -7988,8 +8005,7 @@ defmodule Coflux.Orchestration.Server do state |> send_to_consumer( sub, - {:stream_closed, sub.consumer_execution_external_id, subscription_id, "complete", - nil} + {:stream_closed, sub.consumer_execution_external_id, subscription_id, "complete", nil} ) |> drop_subscription(key) else @@ -8240,9 +8256,7 @@ defmodule Coflux.Orchestration.Server do end unless already_completed? do - Logger.warning( - "Couldn't locate session for execution #{execution_ext_id}. Ignoring." - ) + Logger.warning("Couldn't locate session for execution #{execution_ext_id}. Ignoring.") end state diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 943bf173..974690da 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -36,15 +36,19 @@ defmodule Coflux.Orchestration.Streams do # Registers a new stream owned by `execution_id` at `index` (monotonic # per-execution, worker-assigned). ``buffer`` is the persisted flow- # control budget — ``nil`` means no backpressure, integer N means the - # producer may be up to N items ahead of the fastest consumer. Returns + # producer may be up to N items ahead of the fastest consumer. + # ``timeout_ms`` is the idle-timeout budget (milliseconds) — ``nil`` + # disables the timeout. The server only stores it (for display in + # Studio); enforcement happens at the worker (CLI). Returns # ``{:error, :already_registered}`` if the index was already used. - def register_stream(db, execution_id, index, buffer \\ nil) do + def register_stream(db, execution_id, index, buffer \\ nil, timeout_ms \\ nil) do now = current_timestamp() case insert_one(db, :streams, %{ execution_id: execution_id, index: index, buffer: buffer, + timeout_ms: timeout_ms, created_at: now }) do {:ok, _} -> {:ok, now} @@ -66,6 +70,20 @@ defmodule Coflux.Orchestration.Streams do end end + # Returns the persisted timeout (milliseconds) for a stream. + # ``{:ok, nil}`` means no timeout. ``{:error, :not_found}`` if the + # stream doesn't exist. + def get_timeout_ms(db, execution_id, index) do + case query_one( + db, + "SELECT timeout_ms FROM streams WHERE execution_id = ?1 AND `index` = ?2", + {execution_id, index} + ) do + {:ok, nil} -> {:error, :not_found} + {:ok, {timeout_ms}} -> {:ok, timeout_ms} + end + end + # Appends an item at `sequence` to the stream. Caller supplies the sequence # (worker-assigned, monotonic). Returns: # * `{:error, :not_registered}` if the stream doesn't exist @@ -110,6 +128,8 @@ defmodule Coflux.Orchestration.Streams do # ended (cancel/crash/abandon/error). No error is recorded here — # callers that need to surface an error derive it from the # execution's recorded result at read time. + # * `:timeout` — the worker closed the stream because its idle + # timeout elapsed without a new item being appended. def close_stream(db, execution_id, index, spec \\ :complete) do with_transaction(db, fn -> case exists?(db, execution_id, index) do @@ -138,9 +158,11 @@ defmodule Coflux.Orchestration.Streams do @reason_complete 0 @reason_errored 1 @reason_lifecycle 2 + @reason_timeout 3 defp resolve_close_spec(_db, :complete), do: {@reason_complete, nil} defp resolve_close_spec(_db, :lifecycle), do: {@reason_lifecycle, nil} + defp resolve_close_spec(_db, :timeout), do: {@reason_timeout, nil} defp resolve_close_spec(db, {:errored, type, message, frames}) do error_id = Errors.get_or_create(db, type, message, frames) @@ -149,10 +171,11 @@ defmodule Coflux.Orchestration.Streams do # Atom form of the reason integer — used by callers that want to decide # whether to derive an error from the execution's result (:lifecycle) - # or use the stored one (:errored / :complete). + # or use the stored one (:errored / :complete / :timeout). def reason_from_int(@reason_complete), do: :complete def reason_from_int(@reason_errored), do: :errored def reason_from_int(@reason_lifecycle), do: :lifecycle + def reason_from_int(@reason_timeout), do: :timeout def exists?(db, execution_id, index) do case query_one( @@ -273,9 +296,10 @@ defmodule Coflux.Orchestration.Streams do end # Returns one row per stream owned by `execution_id`: - # `{index, buffer, created_at, closed_at | nil, reason | nil, error | nil}`. + # `{index, buffer, timeout_ms, created_at, closed_at | nil, reason | nil, error | nil}`. # * buffer is the persisted backpressure budget (integer or nil) - # * reason is :complete | :errored | :lifecycle when closed, nil when open + # * timeout_ms is the persisted idle-timeout budget (integer or nil) + # * reason is :complete | :errored | :lifecycle | :timeout when closed, nil when open # * error is the stored `{type, message, frames}` triple for :errored # closures only — callers that need to surface an error for a # :lifecycle closure derive it from the execution's result. @@ -284,7 +308,7 @@ defmodule Coflux.Orchestration.Streams do case query( db, """ - SELECT s.`index`, s.buffer, s.created_at, c.created_at, c.reason, c.error_id + SELECT s.`index`, s.buffer, s.timeout_ms, s.created_at, c.created_at, c.reason, c.error_id FROM streams AS s LEFT JOIN stream_closures AS c ON c.execution_id = s.execution_id AND c.`index` = s.`index` @@ -296,16 +320,16 @@ defmodule Coflux.Orchestration.Streams do {:ok, rows} -> streams = Enum.map(rows, fn - {index, buffer, created_at, nil, nil, nil} -> - {index, buffer, created_at, nil, nil, nil} + {index, buffer, timeout_ms, created_at, nil, nil, nil} -> + {index, buffer, timeout_ms, created_at, nil, nil, nil} - {index, buffer, created_at, closed_at, reason_int, nil} -> - {index, buffer, created_at, closed_at, reason_from_int(reason_int), nil} + {index, buffer, timeout_ms, created_at, closed_at, reason_int, nil} -> + {index, buffer, timeout_ms, created_at, closed_at, reason_from_int(reason_int), nil} - {index, buffer, created_at, closed_at, reason_int, error_id} -> + {index, buffer, timeout_ms, created_at, closed_at, reason_int, error_id} -> {:ok, error} = Errors.get_by_id(db, error_id) - {index, buffer, created_at, closed_at, reason_from_int(reason_int), + {index, buffer, timeout_ms, created_at, closed_at, reason_from_int(reason_int), error} end) diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index 476773b3..f6fb77bf 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -215,11 +215,12 @@ defmodule Coflux.Topics.Run do defp process_notification( topic, - {:stream_opened, execution_external_id, index, buffer, created_at} + {:stream_opened, execution_external_id, index, buffer, timeout_ms, created_at} ) do update_execution(topic, execution_external_id, fn topic, base_path -> Topic.set(topic, base_path ++ [:streams, Integer.to_string(index)], %{ buffer: buffer, + timeoutMs: timeout_ms, openedAt: created_at, closedAt: nil, reason: nil, @@ -559,24 +560,33 @@ defmodule Coflux.Topics.Run do defp build_streams(streams) do Map.new(streams, fn - {index, buffer, opened_at, nil, nil, nil} -> + {index, buffer, timeout_ms, opened_at, nil, nil, nil} -> {Integer.to_string(index), - %{buffer: buffer, openedAt: opened_at, closedAt: nil, reason: nil, error: nil}} + %{ + buffer: buffer, + timeoutMs: timeout_ms, + openedAt: opened_at, + closedAt: nil, + reason: nil, + error: nil + }} - {index, buffer, opened_at, closed_at, reason, nil} -> + {index, buffer, timeout_ms, opened_at, closed_at, reason, nil} -> {Integer.to_string(index), %{ buffer: buffer, + timeoutMs: timeout_ms, openedAt: opened_at, closedAt: closed_at, reason: Atom.to_string(reason), error: nil }} - {index, buffer, opened_at, closed_at, reason, {type, message, _frames}} -> + {index, buffer, timeout_ms, opened_at, closed_at, reason, {type, message, _frames}} -> {Integer.to_string(index), %{ buffer: buffer, + timeoutMs: timeout_ms, openedAt: opened_at, closedAt: closed_at, reason: Atom.to_string(reason), diff --git a/server/lib/coflux/topics/stream.ex b/server/lib/coflux/topics/stream.ex index 05b150f6..331fc14e 100644 --- a/server/lib/coflux/topics/stream.ex +++ b/server/lib/coflux/topics/stream.ex @@ -40,6 +40,7 @@ defmodule Coflux.Topics.Stream do %{ producer: initial.producer, buffer: initial.buffer, + timeoutMs: initial.timeoutMs, openedAt: initial.openedAt, closure: build_closure(initial.closure), items: Enum.map(initial.items, &build_item/1), diff --git a/server/lib/coflux/topics/workflow.ex b/server/lib/coflux/topics/workflow.ex index b9a4b1a2..f8532196 100644 --- a/server/lib/coflux/topics/workflow.ex +++ b/server/lib/coflux/topics/workflow.ex @@ -119,11 +119,21 @@ defmodule Coflux.Topics.Workflow do recurrent: workflow.recurrent, timeout: workflow.timeout, requires: workflow.requires, - memo: workflow.memo + memo: workflow.memo, + streams: build_streams_configuration(workflow[:streams]) } end end + defp build_streams_configuration(nil), do: nil + + defp build_streams_configuration(streams) do + %{ + buffer: streams[:buffer], + timeoutMs: streams[:timeout_ms] + } + end + defp build_runs(runs) do Map.new(runs, fn {external_run_id, created_at, created_by_user_ext_id, created_by_token_ext_id} -> diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index f54bf464..a6dd71e2 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -106,6 +106,20 @@ DROP INDEX idx_results_successor_ref_id; DROP TABLE results; ALTER TABLE results_new RENAME TO results; +-- Stream config attached to a workflow definition (from the manifest) +-- and to each step (copied at submit time, optionally overridden per +-- call via cf.Target.with_streams). NULL means "unset" — the adapter +-- falls back to the decorator-level default. +-- +-- streams_buffer: backpressure budget in number of items (0 = strict +-- lockstep, N = allow N items ahead, NULL = unbounded). +-- streams_timeout_ms: idle-timeout budget in milliseconds; NULL means +-- no timeout. +ALTER TABLE workflows ADD COLUMN streams_buffer INTEGER; +ALTER TABLE workflows ADD COLUMN streams_timeout_ms INTEGER; +ALTER TABLE steps ADD COLUMN streams_buffer INTEGER; +ALTER TABLE steps ADD COLUMN streams_timeout_ms INTEGER; + -- Streams — ordered, append-only sequences of values produced by an -- execution. Each stream is identified by (execution_id, index), where -- `index` is assigned monotonically by the worker when serialising the @@ -134,6 +148,10 @@ CREATE TABLE streams ( -- Persisted so the server can reconstruct per-stream flow-control -- state on restart and so Studio can display the configuration. buffer INTEGER, + -- Idle-timeout budget in milliseconds. NULL disables the timeout. + -- Enforced at the worker (CLI) level; persisted here only so Studio + -- can display the configured value. + timeout_ms INTEGER, created_at INTEGER NOT NULL, PRIMARY KEY (execution_id, `index`), FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE @@ -157,6 +175,8 @@ CREATE TABLE stream_items ( -- (cancel/crash/abandon/error). The specific error is -- derived on read by looking up the execution's completion, -- so we don't duplicate that state here. +-- 3 = timeout — closed by the worker because the configured idle +-- timeout elapsed without a new item being appended. CREATE TABLE stream_closures ( execution_id INTEGER NOT NULL, `index` INTEGER NOT NULL, diff --git a/tests/support/executor.py b/tests/support/executor.py index 86515822..4ba6ae59 100644 --- a/tests/support/executor.py +++ b/tests/support/executor.py @@ -39,7 +39,11 @@ def _unwrap_select_result(result): return {"status": status} return result -Execution = namedtuple("Execution", ["conn", "execution_id", "module", "target", "arguments"]) +Execution = namedtuple( + "Execution", + ["conn", "execution_id", "module", "target", "arguments", "streams"], + defaults=[None], +) class ExecutorConnection: @@ -92,11 +96,17 @@ def run_one(self, handler): self.send(response) def recv_execute(self, **kwargs): - """Receive an execute message, return (execution_id, module, target, arguments).""" + """Receive an execute message, return (execution_id, module, target, arguments, streams).""" msg = self.recv(**kwargs) assert msg["method"] == "execute", f"expected execute, got {msg['method']}" p = msg["params"] - return p["execution_id"], p.get("module", ""), p["target"], p.get("arguments", []) + return ( + p["execution_id"], + p.get("module", ""), + p["target"], + p.get("arguments", []), + p.get("streams"), + ) def _request(self, msg): """Send a request message (with auto-assigned ID) and return the response. @@ -258,9 +268,15 @@ def resolve_input( # --- Stream producer helpers --- - def stream_register(self, execution_id, index, buffer=None): - """Notify that a new stream exists. ``buffer`` enables backpressure.""" - self.send(protocol.stream_register(execution_id, index, buffer=buffer)) + def stream_register(self, execution_id, index, buffer=None, timeout_ms=None): + """Notify that a new stream exists. ``buffer`` enables + backpressure; ``timeout_ms`` enables idle-timeout enforcement + at the worker.""" + self.send( + protocol.stream_register( + execution_id, index, buffer=buffer, timeout_ms=timeout_ms + ) + ) def stream_append(self, execution_id, index, sequence, value, format="json"): """Append an item (raw JSON value) to a stream.""" @@ -461,9 +477,9 @@ def next_execute(self, timeout=10): if idx in self._consumed: continue try: - eid, module, target, args = conn.recv_execute(timeout=0.1) + eid, module, target, args, streams = conn.recv_execute(timeout=0.1) self._consumed.add(idx) - return Execution(conn, eid, module, target, args) + return Execution(conn, eid, module, target, args, streams) except TimeoutError: continue except (ConnectionError, OSError): diff --git a/tests/support/manifest.py b/tests/support/manifest.py index d3fbf61d..1f3aa50c 100644 --- a/tests/support/manifest.py +++ b/tests/support/manifest.py @@ -12,6 +12,7 @@ def _target( wait_for=None, requires=None, timeout=None, + streams=None, ): target = { "module": module, @@ -35,6 +36,8 @@ def _target( target["requires"] = requires if timeout is not None: target["timeout"] = timeout + if streams is not None: + target["streams"] = streams return target diff --git a/tests/support/protocol.py b/tests/support/protocol.py index 25bfa37b..428cddfc 100644 --- a/tests/support/protocol.py +++ b/tests/support/protocol.py @@ -234,10 +234,12 @@ def register_group_notification(execution_id, group_id, name=None): # --- Stream messages (producer side: adapter → server) --- -def stream_register(execution_id, index, buffer=None): +def stream_register(execution_id, index, buffer=None, timeout_ms=None): params = {"execution_id": execution_id, "index": index} if buffer is not None: params["buffer"] = buffer + if timeout_ms is not None: + params["timeout_ms"] = timeout_ms return {"method": "stream_register", "params": params} diff --git a/tests/test_streams.py b/tests/test_streams.py index 11e0f06b..224addf6 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -739,3 +739,129 @@ def test_backpressure_unbounded_sends_no_demand(worker): prod_ex.conn.complete(prod_ex.execution_id) cons_ex.conn.complete(cons_ex.execution_id) ctx.result(prod_resp["runId"]) + + +# --- Idle timeout ------------------------------------------------------- + + +def test_timeout_fires_when_producer_idle(worker): + """A stream registered with ``timeout_ms`` is force-closed by the + worker after that many milliseconds without an append. The + adapter receives a ``stream_force_close`` push and any consumer + sees a ``stream_closed`` push with reason=``"timeout"``. + """ + targets = [workflow("test", "producer"), workflow("test", "consumer")] + + with worker(targets, concurrency=2) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register( + prod_ex.execution_id, 0, buffer=None, timeout_ms=150 + ) + + # Consumer subscribes so it'll see the timeout close. + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + index=0, + ) + + # Producer is idle. Within a second (well past 150ms), the CLI + # should push force-close to the producer and the consumer + # should see the stream closed with reason="timeout". + force = prod_ex.conn.recv_push("stream_force_close", timeout=2) + assert force["index"] == 0 + assert force["reason"] == "timeout" + + closed = cons_ex.conn.recv_push("stream_closed", subscription_id=1, timeout=2) + assert closed["reason"] == "timeout" + assert closed.get("error") is None + + # Producer should skip its own stream_close now (server already + # recorded it); we simulate the real adapter by just completing. + prod_ex.conn.complete(prod_ex.execution_id) + cons_ex.conn.complete(cons_ex.execution_id) + ctx.result(prod_resp["runId"]) + + +def test_timeout_resets_on_append(worker): + """Each append resets the idle deadline — a producer that emits + items at a steady pace faster than the timeout does not fire. + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register( + prod_ex.execution_id, 0, buffer=None, timeout_ms=250 + ) + + # Emit 3 items with 100ms gaps — each append resets the + # deadline, so the total 300ms elapsed doesn't trigger a fire. + import time as _t + + for seq in range(3): + _t.sleep(0.1) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, seq, f"v{seq}") + + # No force-close should have been pushed. + with pytest.raises(TimeoutError): + prod_ex.conn.recv_push("stream_force_close", timeout=0.1) + + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + +def test_manifest_streams_propagates_to_execute(worker): + """Workflow registered with ``streams`` in the manifest: when + submitted (mimicking Studio/CLI), the execute message delivered to + the worker carries the same ``streams`` config. This is the full + propagation path — adapter manifest → server → execute dispatch. + """ + targets = [ + workflow("test", "producer", streams={"buffer": 5, "timeout_ms": 250}), + ] + + with worker(targets) as ctx: + ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + + assert prod_ex.streams is not None + assert prod_ex.streams.get("buffer") == 5 + assert prod_ex.streams.get("timeout_ms") == 250 + + prod_ex.conn.complete(prod_ex.execution_id) + + +def test_timeout_visible_in_topic(worker): + """Studio's run topic surfaces ``timeoutMs`` on the stream state, + and a timeout closure shows ``reason: "timeout"``. + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register( + prod_ex.execution_id, 0, buffer=None, timeout_ms=120 + ) + + # Wait for the timeout to fire. + force = prod_ex.conn.recv_push("stream_force_close", timeout=2) + assert force["reason"] == "timeout" + + prod_ex.conn.complete(prod_ex.execution_id) + ctx.result(prod_resp["runId"]) + + snapshot = ctx.inspect(prod_resp["runId"]) + step = next(iter(snapshot["steps"].values())) + execution = next(iter(step["executions"].values())) + stream = execution["streams"]["0"] + assert stream["timeoutMs"] == 120 + assert stream["reason"] == "timeout" + assert stream["error"] is None From f870911f084a76b1a7713c8190adb186c83e443a Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Mon, 20 Apr 2026 21:37:06 +0100 Subject: [PATCH 18/25] Tidy imports --- adapters/python/coflux/__init__.py | 11 +---------- adapters/python/coflux/executor.py | 6 ++---- adapters/python/coflux/models.py | 27 +++++++++++---------------- adapters/python/coflux/streams.py | 10 +++------- 4 files changed, 17 insertions(+), 37 deletions(-) diff --git a/adapters/python/coflux/__init__.py b/adapters/python/coflux/__init__.py index 14214c9f..74a25571 100644 --- a/adapters/python/coflux/__init__.py +++ b/adapters/python/coflux/__init__.py @@ -23,15 +23,7 @@ InputDismissed, ) from .metric import Metric, MetricGroup, MetricScale, progress -from .models import ( - Asset, - AssetEntry, - AssetMetadata, - Execution, - Input, - ModelSchema, - Stream, -) +from .models import Asset, AssetEntry, AssetMetadata, Execution, Input, Stream from .prompt import Prompt from .state import get_context from .streams import stream @@ -54,7 +46,6 @@ "ExecutionCrashed", "InputDismissed", "Input", - "ModelSchema", "Metric", "MetricGroup", "MetricScale", diff --git a/adapters/python/coflux/executor.py b/adapters/python/coflux/executor.py index 12fc0195..4823fd1b 100644 --- a/adapters/python/coflux/executor.py +++ b/adapters/python/coflux/executor.py @@ -18,6 +18,8 @@ from .output import capture_output from .models import Input from .serialization import deserialize_value, serialize_value +from .streams import stream as _register_stream +from .target import Streams _COFLUX_PKG_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -47,8 +49,6 @@ def _resolve_execute_streams(target_obj: Any, streams_from_wire: dict[str, Any] Returns a ``Streams`` instance or ``None`` (no stream config). """ - from .target import Streams # local import to avoid cycles - if streams_from_wire is not None: buffer = streams_from_wire.get("buffer", 0) timeout_ms = streams_from_wire.get("timeout_ms") @@ -162,8 +162,6 @@ def execute_target( if (inspect.isgenerator(result) or inspect.isasyncgen(result)) and hasattr( target_obj, "definition" ): - from .streams import stream as _register_stream - kwargs: dict[str, Any] = {} if effective_streams is not None: kwargs["buffer"] = effective_streams.buffer diff --git a/adapters/python/coflux/models.py b/adapters/python/coflux/models.py index 10fbd877..a19bcee9 100644 --- a/adapters/python/coflux/models.py +++ b/adapters/python/coflux/models.py @@ -1,4 +1,4 @@ -"""Models for the Coflux Python SDK.""" +"""Reference types returned to user code: assets, handles, and streams.""" from __future__ import annotations @@ -9,23 +9,11 @@ from .state import get_context - T = t.TypeVar("T") D = t.TypeVar("D") -class ModelSchema(t.Protocol): - """Protocol for schema classes that can validate JSON data. - - Compatible with Pydantic BaseModel and any class providing - model_json_schema() and model_validate() classmethods. - """ - - @classmethod - def model_json_schema(cls) -> dict[str, t.Any]: ... - - @classmethod - def model_validate(cls, obj: t.Any) -> t.Any: ... +# --- Assets --- class AssetEntry(t.NamedTuple): @@ -118,6 +106,9 @@ def restore( return {e.path: e.restore(at=at) for e in entries} +# --- Handles (resolve via cf.select) --- + + class _Handle(t.Generic[T]): """Base for handles that resolve via ``cf.select``. @@ -210,6 +201,9 @@ def target(self) -> str: return self._target +# --- Streams --- + + Stride = t.Tuple[int, t.Optional[int], int] """A stride over the stream's sequence numbers: ``(start, stop, step)``. @@ -304,8 +298,9 @@ def partition(self, n: int, i: int) -> "Stream[T]": return self.stride(i, None, n) def __iter__(self) -> t.Iterator[T]: - # Deferred import to avoid a cycle (streams.py imports serialization - # which imports models for Execution/Input/Asset). + # Local import: ``streams`` imports ``serialization`` at top, and + # ``serialization`` imports ``Stream`` from here — a top-level + # ``from .streams import ...`` would cycle. from .streams import open_subscription return open_subscription(self._id, self._stride) diff --git a/adapters/python/coflux/streams.py b/adapters/python/coflux/streams.py index fefc067c..33781eae 100644 --- a/adapters/python/coflux/streams.py +++ b/adapters/python/coflux/streams.py @@ -33,6 +33,8 @@ from .errors import raise_for_close from .serialization import deserialize_value, serialize_value from .state import get_context +from .models import Stream +from .target import Streams, _validate_buffer, _validate_timeout # --- Producer side --- @@ -91,8 +93,6 @@ def stream( f"cf.stream expects a generator, got {type(generator).__name__}" ) - from .target import Streams, _validate_buffer, _validate_timeout - ctx = get_context() default = ctx.get_default_streams() or Streams() resolved_buffer = ( @@ -104,11 +104,7 @@ def stream( else default.timeout ) stream_id = ctx.register_stream(generator, resolved_buffer, resolved_timeout) - # Local import to avoid a top-level cycle — models imports nothing - # from streams but streams already imports from models at top. - from .models import Stream as StreamHandle - - return StreamHandle(stream_id) + return Stream(stream_id) class StreamDriver: From d0611ed9cd62621bffbaf8aa0d840d55e6240e86 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Mon, 20 Apr 2026 21:37:47 +0100 Subject: [PATCH 19/25] Fix replay after producer terminated --- server/lib/coflux/orchestration/server.ex | 8 ++++ tests/test_streams.py | 53 +++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index ee0bc416..a08fb1ae 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1948,6 +1948,14 @@ defmodule Coflux.Orchestration.Server do :error -> state + {:ok, %{session_id: nil}} -> + # Producer's session is gone — typically because the producer + # execution has long since terminated and we rebuilt its in-memory + # state for a late subscriber. There's nothing to grant demand to; + # the stream is durable in the DB and backlog reads don't consume + # credits. + state + {:ok, producer} -> has_subscribers = has_stream_subscribers?(state, key) max_cursor = current_max_cursor(state, key) diff --git a/tests/test_streams.py b/tests/test_streams.py index 224addf6..a84d9df3 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -80,6 +80,59 @@ def test_producer_writes_and_consumer_reads_backlog(worker): assert closed.get("error") is None +def test_backlog_replay_after_producer_terminated_with_buffer(worker): + """Regression: subscribing to a closed bounded-buffer stream whose + producer execution has terminated must replay the backlog and close. + + Previously, ``ensure_stream_producer`` rebuilt the producer's + in-memory state with ``session_id=nil`` for the vanished execution. + After ``push_backlog_items`` advanced the consumer's cursor past + ``demand_granted``, ``refresh_stream_demand`` would compute a + positive delta and call ``send_session(state, nil, ...)``, crashing + the project GenServer with ``KeyError :connection`` and leaving the + consumer hung. The unbounded (``buffer=null``) path didn't trigger + it because ``ensure_stream_producer`` no-ops in that case. + """ + targets = [ + workflow("test", "producer"), + workflow("test", "consumer"), + ] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0, buffer=0) + # First subscriber to a buffer=0 stream needs a credit before the + # producer can append; the demand grant arrives once the consumer + # subscribes. To keep the test focused on the post-termination + # replay, append items here without waiting on demand — the test + # adapter doesn't enforce credit accounting on appends. + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "a") + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 1, "b") + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 2, "c") + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id, value=42) + + # Producer execution must be fully gone — its session entry for + # this execution removed — before the consumer subscribes. + ctx.result(prod_resp["runId"]) + + cons_resp = ctx.submit("test", "consumer") + cons_ex = ctx.executor.next_execute() + cons_ex.conn.stream_subscribe( + cons_ex.execution_id, + subscription_id=1, + producer_execution_id=prod_ex.execution_id, + index=0, + ) + items, closed = cons_ex.conn.drain_stream(subscription_id=1) + cons_ex.conn.complete(cons_ex.execution_id) + + assert [item[0] for item in items] == [0, 1, 2] + assert [item[1]["value"] for item in items] == ["a", "b", "c"] + assert closed.get("error") is None + + def test_consumer_sees_live_push(worker): """Consumer subscribes *before* the producer appends. Items arrive live.""" targets = [ From a77d391f2f3378040763333e7902b04b921f9302 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Mon, 20 Apr 2026 22:25:12 +0100 Subject: [PATCH 20/25] Handle stream error/timeout as dedicated completion kinds --- server/lib/coflux/orchestration/results.ex | 98 ++++++++- server/lib/coflux/orchestration/runs.ex | 22 +- server/lib/coflux/orchestration/server.ex | 79 ++++++- server/lib/coflux/orchestration/streams.ex | 36 ++++ server/priv/migrations/orchestration/4.sql | 27 ++- tests/test_streams.py | 230 ++++++++++++++++++++- 6 files changed, 463 insertions(+), 29 deletions(-) diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index 905195ba..870644bf 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -1,7 +1,7 @@ defmodule Coflux.Orchestration.Results do import Coflux.Store - alias Coflux.Orchestration.{Errors, Values} + alias Coflux.Orchestration.{Errors, Streams, Values} # --- Completion kinds --- # @@ -19,8 +19,26 @@ defmodule Coflux.Orchestration.Results do @kind_deferred 8 @kind_cached 9 @kind_spawned 10 - - @failure_kinds [@kind_errored, @kind_abandoned, @kind_crashed, @kind_timeout] + # `:stream_errored` — execution returned a value but at least one of its + # streams closed with an error. Treated as a failure (retried via the + # step's retry policy, ineligible for cache lookup) but the result row + # still carries the original value so consumers that already resolved + # against it keep working. + @kind_stream_errored 11 + # `:stream_timeout` — execution returned a value but at least one of its + # streams closed via idle timeout. Logically a success (resolves to the + # value) but ineligible for cache lookup, since the cached stream + # contents would be shaped by the original consumer's demand pattern. + # Not retried. Distinct from execution-level `:timeout`. + @kind_stream_timeout 12 + + @failure_kinds [ + @kind_errored, + @kind_abandoned, + @kind_crashed, + @kind_timeout, + @kind_stream_errored + ] def kind_atom(0), do: :succeeded def kind_atom(1), do: :errored @@ -33,6 +51,8 @@ defmodule Coflux.Orchestration.Results do def kind_atom(8), do: :deferred def kind_atom(9), do: :cached def kind_atom(10), do: :spawned + def kind_atom(11), do: :stream_errored + def kind_atom(12), do: :stream_timeout def atom_kind(:succeeded), do: @kind_succeeded def atom_kind(:errored), do: @kind_errored @@ -45,6 +65,8 @@ defmodule Coflux.Orchestration.Results do def atom_kind(:deferred), do: @kind_deferred def atom_kind(:cached), do: @kind_cached def atom_kind(:spawned), do: @kind_spawned + def atom_kind(:stream_errored), do: @kind_stream_errored + def atom_kind(:stream_timeout), do: @kind_stream_timeout def failure_kinds, do: @failure_kinds @@ -307,6 +329,7 @@ defmodule Coflux.Orchestration.Results do case resolve_logical( db, + execution_id, kind, value_id, error_id, @@ -441,6 +464,20 @@ defmodule Coflux.Orchestration.Results do defp decode_retryable(1), do: true defp decode_retryable(0), do: false + # Look up the first errored stream closure's error triple for a + # :stream_errored execution. Returns `nil` if no errored closure exists + # (shouldn't happen if the kind is :stream_errored, but defensive). + defp fetch_stream_error(db, execution_id) do + case Streams.get_closure_summary_for_execution(db, execution_id) do + {:ok, %{errored: nil}} -> + nil + + {:ok, %{errored: error_id}} when not is_nil(error_id) -> + {:ok, triple} = Errors.get_by_id(db, error_id) + triple + end + end + # Builds the legacy "logical result" tuple from the split tables. Used # by most callers (UI, topic state, consumer resolution). Returns `nil` # only when nothing has been recorded yet. @@ -449,34 +486,73 @@ defmodule Coflux.Orchestration.Results do # safe to follow a successor — error results without a completion carry # `nil` as their successor here, and the server treats that as "still # pending". - defp resolve_logical(_db, nil, nil, nil, _, _, _), do: nil + defp resolve_logical(_db, _exec_id, nil, nil, nil, _, _, _), do: nil # Value payload present. Returns the appropriate tagged tuple, picking # the successor-flavoured form when the completion says this was a # deferred/cached/spawned resolution. - defp resolve_logical(db, kind, value_id, nil, _retryable, _successor_id, successor_ref_id) + defp resolve_logical( + db, + execution_id, + kind, + value_id, + nil, + retryable, + successor_id, + successor_ref_id + ) when not is_nil(value_id) do {:ok, value} = Values.get_value_by_id(db, value_id) case kind && kind_atom(kind) do - :deferred when not is_nil(successor_ref_id) -> {:deferred, successor_ref_id, value} - :cached when not is_nil(successor_ref_id) -> {:cached, successor_ref_id, value} - :spawned when not is_nil(successor_ref_id) -> {:spawned, successor_ref_id, value} - _ -> {:value, value} + :deferred when not is_nil(successor_ref_id) -> + {:deferred, successor_ref_id, value} + + :cached when not is_nil(successor_ref_id) -> + {:cached, successor_ref_id, value} + + :spawned when not is_nil(successor_ref_id) -> + {:spawned, successor_ref_id, value} + + :stream_errored -> + # Value result + a stream errored mid-flight. Surface as an error + # for consumer resolution, using the first errored stream closure's + # stored error triple. (The cancellation precedent leaves the value + # in place — consumers that already resolved against it keep their + # reference; this branch governs only late lookups.) + case fetch_stream_error(db, execution_id) do + nil -> + {:value, value} + + {type, message, frames} -> + {:error, type, message, frames, successor_id, decode_retryable(retryable)} + end + + _ -> + {:value, value} end end # Error payload without a completion yet. We return the error so UI can # display it; the successor slot is nil so consumer resolution treats # it as still pending (the retry decision happens at completion time). - defp resolve_logical(db, nil, nil, error_id, retryable, _successor_id, _successor_ref_id) + defp resolve_logical(db, _exec_id, nil, nil, error_id, retryable, _successor_id, _successor_ref_id) when not is_nil(error_id) do {:ok, {type, message, frames}} = Errors.get_by_id(db, error_id) {:error, type, message, frames, nil, decode_retryable(retryable)} end # Completion present (possibly with no results row). - defp resolve_logical(db, kind, _value_id, error_id, retryable, successor_id, successor_ref_id) do + defp resolve_logical( + db, + _exec_id, + kind, + _value_id, + error_id, + retryable, + successor_id, + successor_ref_id + ) do case kind_atom(kind) do :succeeded -> nil diff --git a/server/lib/coflux/orchestration/runs.ex b/server/lib/coflux/orchestration/runs.ex index 2cc466e3..f43c065f 100644 --- a/server/lib/coflux/orchestration/runs.ex +++ b/server/lib/coflux/orchestration/runs.ex @@ -1169,6 +1169,20 @@ defmodule Coflux.Orchestration.Runs do workspace_ids ++ [{:blob, cache_key}, recorded_after] end + # Disqualifying completion kinds: cancellation, plus the two + # value-result-but-stream-broke kinds. These all have a value row + # (so the value-id check above wouldn't reject them) but the cached + # value is unsafe to reuse — cancelled = explicit user override, + # stream_errored = producer failure, stream_timeout = consumer-shaped + # output. Each is excluded as soon as its completion is recorded; + # in-flight executions (no completion row yet) remain candidates. + disqualified_kinds = + Enum.map_join( + [:cancelled, :stream_errored, :stream_timeout], + ", ", + &Integer.to_string(Results.atom_kind(&1)) + ) + case query( db, """ @@ -1181,14 +1195,14 @@ defmodule Coflux.Orchestration.Runs do e.workspace_id IN (#{build_placeholders(length(workspace_ids))}) AND s.cache_key = ?#{length(workspace_ids) + 1} -- Either no result yet (in-flight candidate) or a value result - -- recorded within the cache age window. Errors disqualify. - -- Cancelled-with-value also disqualifies: the value stays valid - -- for already-resolved consumers but shouldn't seed cache hits. + -- recorded within the cache age window. Errors disqualify + -- (no value_id). See `disqualified_kinds` for completion-kind + -- exclusions that survive having a value. AND ( r.execution_id IS NULL OR (r.value_id IS NOT NULL AND r.created_at >= ?#{length(workspace_ids) + 2}) ) - AND (c.kind IS NULL OR c.kind != #{Results.atom_kind(:cancelled)}) + AND (c.kind IS NULL OR c.kind NOT IN (#{disqualified_kinds})) #{step_clause} ORDER BY e.created_at DESC LIMIT 1 diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index a08fb1ae..41e468dd 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -19,6 +19,7 @@ defmodule Coflux.Orchestration.Server do Workers, Manifests, Principals, + Errors, Epoch } @@ -6173,11 +6174,85 @@ defmodule Coflux.Orchestration.Server do end end - # Value result + clean drain: no retry, kind=:succeeded. + # Value result + drain: dispatch on stream closure outcomes. + # * any owned stream closed `:errored` → `:stream_errored` (retried) + # * else any owned stream closed `:timeout` → `:partial` (not retried, + # not cacheable) + # * else `:succeeded` + # `close_open_streams` runs first so any still-open streams get a + # `:lifecycle` row (which doesn't influence the dispatch — only the + # explicit `:errored`/`:timeout` reasons do). defp finalize_success_completion(state, execution_id) do state = close_open_streams(state, execution_id) - case Results.record_completion(state.db, execution_id, :succeeded) do + {:ok, summary} = Streams.get_closure_summary_for_execution(state.db, execution_id) + + cond do + not is_nil(summary.errored) -> + finalize_stream_errored_completion(state, execution_id, summary.errored) + + summary.timed_out -> + finalize_stream_timeout_completion(state, execution_id) + + true -> + case Results.record_completion(state.db, execution_id, :succeeded) do + {:ok, completion_at} -> + fire_completion_notification(state, execution_id, completion_at) + + {:error, :already_completed} -> + state + end + end + end + + # A stream owned by this execution closed with an error, but the function + # body returned a value. Promote to `:stream_errored`: drives the retry + # policy and excludes the execution from cache lookups, while leaving the + # value in `results` so any consumer that already resolved against it + # keeps its reference (mirrors the cancellation precedent at do_cancel_execution). + defp finalize_stream_errored_completion(state, execution_id, error_id) do + {:ok, step} = Runs.get_step_for_execution(state.db, execution_id) + {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) + {:ok, {type, message, frames}} = Errors.get_by_id(state.db, error_id) + + {retry_id, _recurred?, state} = + decide_and_create_successor( + state, + execution_id, + step, + workspace_id, + {:error, type, message, frames, nil} + ) + + case Results.record_completion(state.db, execution_id, :stream_errored, + successor_id: retry_id + ) do + {:ok, completion_at} -> + # Re-fire :result so the UI's value-result entry picks up the new + # error overlay (the resolve_logical path returns :error for + # :stream_errored kinds with a value present). + state = + fire_result_notifications( + state, + execution_id, + {:error, type, message, frames, retry_id, nil}, + nil, + nil + ) + + fire_completion_notification(state, execution_id, completion_at) + + {:error, :already_completed} -> + state + end + end + + # A stream owned by this execution closed via idle timeout. The execution + # itself succeeded; promote to `:stream_timeout` to exclude it from cache + # lookups (consumer-shaped cache contents would be wrong) without + # surfacing as a failure or triggering a retry. + defp finalize_stream_timeout_completion(state, execution_id) do + case Results.record_completion(state.db, execution_id, :stream_timeout) do {:ok, completion_at} -> fire_completion_notification(state, execution_id, completion_at) diff --git a/server/lib/coflux/orchestration/streams.ex b/server/lib/coflux/orchestration/streams.ex index 974690da..6e36968a 100644 --- a/server/lib/coflux/orchestration/streams.ex +++ b/server/lib/coflux/orchestration/streams.ex @@ -224,6 +224,42 @@ defmodule Coflux.Orchestration.Streams do end end + # Returns a summary of how the streams owned by `execution_id` closed. + # Used by `complete_execution` to decide whether to promote a value-result + # to `:stream_errored` / `:partial`. + # + # Shape: `{:ok, %{errored: integer | nil, timed_out: boolean}}` + # * `errored` — the `errors.id` for the *first* errored stream closure + # (in stream-index order), or `nil` if none errored + # * `timed_out` — true if any stream closed via idle timeout + # + # Lifecycle / complete closures are ignored: the former inherit the + # execution's eventual outcome, the latter are the success case. + def get_closure_summary_for_execution(db, execution_id) do + case query( + db, + """ + SELECT reason, error_id + FROM stream_closures + WHERE execution_id = ?1 + ORDER BY `index` + """, + {execution_id} + ) do + {:ok, rows} -> + summary = + Enum.reduce(rows, %{errored: nil, timed_out: false}, fn {reason, error_id}, acc -> + case reason_from_int(reason) do + :errored -> if acc.errored, do: acc, else: %{acc | errored: error_id} + :timeout -> %{acc | timed_out: true} + _ -> acc + end + end) + + {:ok, summary} + end + end + # Returns indexes of streams owned by `execution_id` that don't yet have # a closure row. Used by the lifecycle code to discover which streams to # close on completion / cancel / crash. diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index a6dd71e2..545475e8 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -31,17 +31,22 @@ CREATE TABLE results_new ( ) STRICT; -- The completions table. Terminal state for the execution. `kind` values: --- 0 = succeeded — value result recorded, process ended cleanly --- 1 = errored — error result recorded, process ended cleanly --- 2 = abandoned — session expired before notify_terminated --- 3 = crashed — notify_terminated without prior result --- 4 = timeout — execution hit its timeout --- 5 = cancelled — user cancelled (may or may not have a result row) --- 6 = suspended — body called suspend; successor resumes later --- 7 = recurred — recurrent execution scheduled its next run --- 8 = deferred — execution deferred to another (memoisation / defer) --- 9 = cached — execution resolved to an existing cache hit --- 10 = spawned — execution spawned a continuation +-- 0 = succeeded — value result recorded, process ended cleanly +-- 1 = errored — error result recorded, process ended cleanly +-- 2 = abandoned — session expired before notify_terminated +-- 3 = crashed — notify_terminated without prior result +-- 4 = timeout — execution hit its timeout +-- 5 = cancelled — user cancelled (may or may not have a result row) +-- 6 = suspended — body called suspend; successor resumes later +-- 7 = recurred — recurrent execution scheduled its next run +-- 8 = deferred — execution deferred to another (memoisation / defer) +-- 9 = cached — execution resolved to an existing cache hit +-- 10 = spawned — execution spawned a continuation +-- 11 = stream_errored — value result recorded but a stream errored mid-flight; +-- counted as a failure for retry / cache eligibility +-- 12 = stream_timeout — value result recorded but a stream timed out; +-- logically a success but ineligible for cache. +-- Distinct from (execution-level) `timeout` = 4. -- -- `successor_id` points at an execution in the same epoch; used for retry -- chains and in-flight handoffs. `successor_ref_id` points at an diff --git a/tests/test_streams.py b/tests/test_streams.py index a84d9df3..dbe8f6f3 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -16,7 +16,7 @@ import pytest -from support.manifest import workflow +from support.manifest import task, workflow from support.protocol import ( execution_result, json_args, @@ -918,3 +918,231 @@ def test_timeout_visible_in_topic(worker): assert stream["timeoutMs"] == 120 assert stream["reason"] == "timeout" assert stream["error"] is None + + +# --- Stream outcomes promoted to execution completion -------------------- + + +def test_clean_stream_keeps_completion_succeeded(worker): + """Sanity: a value-result + cleanly-closed streams still completes as + `:succeeded`. (Guards against the new dispatch in + `finalize_success_completion` regressing the happy path.) + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "ok") + prod_ex.conn.stream_close(prod_ex.execution_id, 0) + prod_ex.conn.complete(prod_ex.execution_id, value=1) + ctx.result(prod_resp["runId"]) + + snapshot = ctx.inspect(prod_resp["runId"]) + execution = next(iter(snapshot["steps"][f"{prod_resp['runId']}:1"]["executions"].values())) + assert execution["completion"]["kind"] == "succeeded" + + +def test_stream_error_promotes_completion_to_stream_errored(worker): + """When a stream owned by an execution closes with an error but the + function body returned a value, the execution's completion is + promoted to `:stream_errored` (so the result is not cacheable and + the step's retry policy applies). + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register(prod_ex.execution_id, 0) + prod_ex.conn.stream_append(prod_ex.execution_id, 0, 0, "ok") + prod_ex.conn.stream_close( + prod_ex.execution_id, + 0, + error={"type": "ValueError", "message": "boom", "traceback": ""}, + ) + prod_ex.conn.complete(prod_ex.execution_id, value=1) + ctx.result(prod_resp["runId"]) + + snapshot = ctx.inspect(prod_resp["runId"]) + execution = next(iter(snapshot["steps"][f"{prod_resp['runId']}:1"]["executions"].values())) + assert execution["completion"]["kind"] == "stream_errored" + + +def test_stream_timeout_promotes_completion_to_stream_timeout(worker): + """When a stream owned by an execution closes via idle timeout but the + function body returned a value, the completion is promoted to + `:stream_timeout` (logically a success, but ineligible for cache). + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + prod_ex.conn.stream_register( + prod_ex.execution_id, 0, buffer=None, timeout_ms=100 + ) + + # Wait for the timeout to fire on the producer side. + force = prod_ex.conn.recv_push("stream_force_close", timeout=2) + assert force["reason"] == "timeout" + + prod_ex.conn.complete(prod_ex.execution_id, value=1) + ctx.result(prod_resp["runId"]) + + snapshot = ctx.inspect(prod_resp["runId"]) + execution = next(iter(snapshot["steps"][f"{prod_resp['runId']}:1"]["executions"].values())) + assert execution["completion"]["kind"] == "stream_timeout" + + +def test_stream_error_outranks_timeout(worker): + """If one stream errored and another timed out on the same execution, + `:stream_errored` takes precedence — error is the stronger signal. + """ + targets = [workflow("test", "producer")] + + with worker(targets) as ctx: + prod_resp = ctx.submit("test", "producer") + prod_ex = ctx.executor.next_execute() + # Stream 0: will time out (no appends). + prod_ex.conn.stream_register( + prod_ex.execution_id, 0, buffer=None, timeout_ms=100 + ) + # Stream 1: will be closed with an error. + prod_ex.conn.stream_register(prod_ex.execution_id, 1) + + # Drain the timeout signal so we're past it. + force = prod_ex.conn.recv_push("stream_force_close", timeout=2) + assert force["reason"] == "timeout" + + prod_ex.conn.stream_close( + prod_ex.execution_id, + 1, + error={"type": "RuntimeError", "message": "bad", "traceback": ""}, + ) + prod_ex.conn.complete(prod_ex.execution_id, value=1) + ctx.result(prod_resp["runId"]) + + snapshot = ctx.inspect(prod_resp["runId"]) + execution = next(iter(snapshot["steps"][f"{prod_resp['runId']}:1"]["executions"].values())) + assert execution["completion"]["kind"] == "stream_errored" + + +def _wait_for_completion(ctx, run_id, step_num, timeout=5): + """Poll the run topic until the given step's first execution has a + completion row recorded, then return the completion kind. Used by + cache-lookup tests that need to be sure the prior execution has + transitioned past `:draining` (otherwise the in-flight cache lookup + matches the value-result before the new completion-kind dispatch + runs).""" + import time as _t + + deadline = _t.time() + timeout + while _t.time() < deadline: + snapshot = ctx.inspect(run_id) + executions = snapshot["steps"][f"{run_id}:{step_num}"]["executions"] + execution = next(iter(executions.values())) + completion = execution.get("completion") + if completion is not None: + return completion["kind"] + _t.sleep(0.05) + raise AssertionError(f"step {run_id}:{step_num} not completed within {timeout}s") + + +def test_stream_errored_execution_not_used_as_cache_hit(worker): + """A `:stream_errored` execution is not eligible for cache lookup — a + second submission with the same args re-executes the task. + """ + targets = [ + workflow("test", "main"), + task("test", "produce", parameters=["x"]), + ] + + with worker(targets, concurrency=2) as ctx: + resp = ctx.submit("test", "main") + run_id = resp["runId"] + wf = ctx.executor.next_execute() + + ref1 = wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + cache={"params": True}, + ) + + # First execution: returns a value, but its stream closes with an error. + prod = ctx.executor.next_execute() + prod.conn.stream_register(prod.execution_id, 0) + prod.conn.stream_close( + prod.execution_id, + 0, + error={"type": "ValueError", "message": "oops", "traceback": ""}, + ) + prod.conn.complete(prod.execution_id, value="v") + assert wf.conn.resolve(wf.execution_id, ref1)["value"] == "v" + + # Wait for the completion to be promoted before the second submit — + # otherwise the in-flight (`:draining`) cache lookup matches the + # value-result first and our promotion logic never gets to run. + assert _wait_for_completion(ctx, run_id, 2) == "stream_errored" + + # Second submission with same args: should NOT cache-hit — the + # previous execution's completion is :stream_errored. + wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + cache={"params": True}, + ) + # A cache hit would skip the execute; we expect a fresh dispatch. + prod2 = ctx.executor.next_execute(timeout=3) + prod2.conn.complete(prod2.execution_id, value="v2") + + wf.conn.complete(wf.execution_id) + + +def test_stream_timeout_execution_not_used_as_cache_hit(worker): + """A `:stream_timeout` execution (stream timed out) is not eligible + for cache lookup — a second submission with the same args re-executes. + """ + targets = [ + workflow("test", "main"), + task("test", "produce", parameters=["x"]), + ] + + with worker(targets, concurrency=2) as ctx: + resp = ctx.submit("test", "main") + run_id = resp["runId"] + wf = ctx.executor.next_execute() + + ref1 = wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + cache={"params": True}, + ) + + prod = ctx.executor.next_execute() + prod.conn.stream_register( + prod.execution_id, 0, buffer=None, timeout_ms=100 + ) + force = prod.conn.recv_push("stream_force_close", timeout=2) + assert force["reason"] == "timeout" + prod.conn.complete(prod.execution_id, value="v") + assert wf.conn.resolve(wf.execution_id, ref1)["value"] == "v" + + assert _wait_for_completion(ctx, run_id, 2) == "stream_timeout" + + # Second submission with same args: previous completion is + # :stream_timeout, not cacheable — expect a fresh execution. + wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + cache={"params": True}, + ) + prod2 = ctx.executor.next_execute(timeout=3) + prod2.conn.complete(prod2.execution_id, value="v2") + + wf.conn.complete(wf.execution_id) From ef18a4ac00eb8ef78df799a5c54c232213164012 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Mon, 20 Apr 2026 22:27:21 +0100 Subject: [PATCH 21/25] Capture full stream error --- server/lib/coflux/orchestration/server.ex | 17 ++++++++++------- server/lib/coflux/topics/run.ex | 4 ++-- tests/test_streams.py | 4 +++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 41e468dd..9f390b62 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -7855,11 +7855,7 @@ defmodule Coflux.Orchestration.Server do {:ok, {r, _s, _a}} = Runs.get_execution_key(state.db, execution_id) {:ok, execution_ext_id} = execution_external_id_for(state.db, execution_id) - encoded_error = - case error do - nil -> nil - {type, message, _frames} -> %{type: type, message: message} - end + encoded_error = encode_stream_error_summary(error) reason_str = if reason, do: Atom.to_string(reason) @@ -7943,8 +7939,15 @@ defmodule Coflux.Orchestration.Server do defp encode_stream_error_summary(nil), do: nil - defp encode_stream_error_summary({type, message, _frames}) do - %{type: type, message: message} + defp encode_stream_error_summary({type, message, frames}) do + %{ + type: type, + message: message, + frames: + Enum.map(frames, fn {file, line, name, code} -> + %{file: file, line: line, name: name, code: code} + end) + } end defp build_stream_producer(db, execution_ext_id, execution_id) do diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index f6fb77bf..ae9ac973 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -582,7 +582,7 @@ defmodule Coflux.Topics.Run do error: nil }} - {index, buffer, timeout_ms, opened_at, closed_at, reason, {type, message, _frames}} -> + {index, buffer, timeout_ms, opened_at, closed_at, reason, {type, message, frames}} -> {Integer.to_string(index), %{ buffer: buffer, @@ -590,7 +590,7 @@ defmodule Coflux.Topics.Run do openedAt: opened_at, closedAt: closed_at, reason: Atom.to_string(reason), - error: %{type: type, message: message} + error: %{type: type, message: message, frames: build_frames(frames)} }} end) end diff --git a/tests/test_streams.py b/tests/test_streams.py index dbe8f6f3..b4ccbdea 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -319,7 +319,9 @@ def test_topic_exposes_stream_state(worker): assert streams["0"]["error"] is None assert streams["1"]["closedAt"] is not None assert streams["1"]["reason"] == "errored" - assert streams["1"]["error"] == {"type": "RuntimeError", "message": "bad"} + assert streams["1"]["error"]["type"] == "RuntimeError" + assert streams["1"]["error"]["message"] == "bad" + assert isinstance(streams["1"]["error"]["frames"], list) def test_cancellation_closes_streams_with_cancelled_reason(worker): From 71ef39b17ec00de946304c0ff08862796732958a Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Mon, 20 Apr 2026 22:31:37 +0100 Subject: [PATCH 22/25] Fix resolving result for execution with stream error --- server/lib/coflux/orchestration/results.ex | 58 ++++++---------------- server/lib/coflux/orchestration/server.ex | 20 ++------ 2 files changed, 19 insertions(+), 59 deletions(-) diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index 870644bf..3dc4a88a 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -1,7 +1,7 @@ defmodule Coflux.Orchestration.Results do import Coflux.Store - alias Coflux.Orchestration.{Errors, Streams, Values} + alias Coflux.Orchestration.{Errors, Values} # --- Completion kinds --- # @@ -464,20 +464,6 @@ defmodule Coflux.Orchestration.Results do defp decode_retryable(1), do: true defp decode_retryable(0), do: false - # Look up the first errored stream closure's error triple for a - # :stream_errored execution. Returns `nil` if no errored closure exists - # (shouldn't happen if the kind is :stream_errored, but defensive). - defp fetch_stream_error(db, execution_id) do - case Streams.get_closure_summary_for_execution(db, execution_id) do - {:ok, %{errored: nil}} -> - nil - - {:ok, %{errored: error_id}} when not is_nil(error_id) -> - {:ok, triple} = Errors.get_by_id(db, error_id) - triple - end - end - # Builds the legacy "logical result" tuple from the split tables. Used # by most callers (UI, topic state, consumer resolution). Returns `nil` # only when nothing has been recorded yet. @@ -490,46 +476,30 @@ defmodule Coflux.Orchestration.Results do # Value payload present. Returns the appropriate tagged tuple, picking # the successor-flavoured form when the completion says this was a - # deferred/cached/spawned resolution. + # deferred/cached/spawned resolution. For `:stream_errored` and + # `:stream_timeout`, the execution's *result* is still the value (the + # stream reference) — the stream's error/timeout state is surfaced + # through the streams panel rather than the result. The completion kind + # alone carries the "this is broken/incomplete" signal for the UI badge + # and for cache eligibility (handled in `find_cached_execution`). defp resolve_logical( db, - execution_id, + _execution_id, kind, value_id, nil, - retryable, - successor_id, + _retryable, + _successor_id, successor_ref_id ) when not is_nil(value_id) do {:ok, value} = Values.get_value_by_id(db, value_id) case kind && kind_atom(kind) do - :deferred when not is_nil(successor_ref_id) -> - {:deferred, successor_ref_id, value} - - :cached when not is_nil(successor_ref_id) -> - {:cached, successor_ref_id, value} - - :spawned when not is_nil(successor_ref_id) -> - {:spawned, successor_ref_id, value} - - :stream_errored -> - # Value result + a stream errored mid-flight. Surface as an error - # for consumer resolution, using the first errored stream closure's - # stored error triple. (The cancellation precedent leaves the value - # in place — consumers that already resolved against it keep their - # reference; this branch governs only late lookups.) - case fetch_stream_error(db, execution_id) do - nil -> - {:value, value} - - {type, message, frames} -> - {:error, type, message, frames, successor_id, decode_retryable(retryable)} - end - - _ -> - {:value, value} + :deferred when not is_nil(successor_ref_id) -> {:deferred, successor_ref_id, value} + :cached when not is_nil(successor_ref_id) -> {:cached, successor_ref_id, value} + :spawned when not is_nil(successor_ref_id) -> {:spawned, successor_ref_id, value} + _ -> {:value, value} end end diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 9f390b62..701fff1f 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -6207,9 +6207,11 @@ defmodule Coflux.Orchestration.Server do # A stream owned by this execution closed with an error, but the function # body returned a value. Promote to `:stream_errored`: drives the retry - # policy and excludes the execution from cache lookups, while leaving the - # value in `results` so any consumer that already resolved against it - # keeps its reference (mirrors the cancellation precedent at do_cancel_execution). + # policy and excludes the execution from cache lookups. The value result + # stays untouched in `results` — the execution's "result" remains the + # value (the stream reference). The stream's error info is surfaced via + # the streams panel; the completion kind alone tells the UI to render + # this as a failure-with-value (mirrors `do_cancel_execution`). defp finalize_stream_errored_completion(state, execution_id, error_id) do {:ok, step} = Runs.get_step_for_execution(state.db, execution_id) {:ok, workspace_id} = Runs.get_workspace_id_for_execution(state.db, execution_id) @@ -6228,18 +6230,6 @@ defmodule Coflux.Orchestration.Server do successor_id: retry_id ) do {:ok, completion_at} -> - # Re-fire :result so the UI's value-result entry picks up the new - # error overlay (the resolve_logical path returns :error for - # :stream_errored kinds with a value present). - state = - fire_result_notifications( - state, - execution_id, - {:error, type, message, frames, retry_id, nil}, - nil, - nil - ) - fire_completion_notification(state, execution_id, completion_at) {:error, :already_completed} -> From 982d75cc30a32a2a0f25cf2a9b2b3c550d1c142d Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Tue, 21 Apr 2026 23:50:09 +0100 Subject: [PATCH 23/25] Dispatch worker requests async --- cli/internal/pool/pool.go | 9 +++- tests/test_streams.py | 104 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/cli/internal/pool/pool.go b/cli/internal/pool/pool.go index 03d49660..78b457ed 100644 --- a/cli/internal/pool/pool.go +++ b/cli/internal/pool/pool.go @@ -328,7 +328,14 @@ loop: p.handleMetric(execCtx, executionID, params, logger) case "submit_execution", "select", "persist_asset", "get_asset", "suspend", "cancel", "download_blob", "upload_blob", "submit_input": - p.handleRequest(execCtx, exec, method, *id, params, logger) + // Dispatch async: these can block on the server (e.g. a + // `select` that waits for a child execution). Blocking the + // message loop here would stop us reading the adapter's + // subsequent messages — including stream_append from a + // stream-producing task that hasn't finished yet — creating + // a deadlock between the waiting select and the consumer + // that's holding it up. + go p.handleRequest(execCtx, exec, method, *id, params, logger) case "register_group": p.handleRegisterGroup(execCtx, executionID, params, logger) diff --git a/tests/test_streams.py b/tests/test_streams.py index b4ccbdea..0e865c10 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -14,10 +14,13 @@ taking turns over different connections. """ +import time + import pytest from support.manifest import task, workflow from support.protocol import ( + execution_handle, execution_result, json_args, partition_stride, @@ -766,6 +769,107 @@ def test_backpressure_subscribe_unblocks_producer(worker): ctx.result(prod_resp["runId"]) +def test_workflow_produces_stream_while_awaiting_consumer(worker): + """Regression: a workflow that produces a stream *and* synchronously + awaits a consumer of that stream must not deadlock. + + Reproduces the ``inline_producer`` pattern from + ``examples/python/examples/streams.py``: the workflow registers an + inline stream, submits a consumer task with the stream as its + argument, then calls ``select`` (blocking) to wait for the + consumer's result. The producer driver keeps emitting items on the + same session while the select is in flight. + + Before the fix, the CLI's per-executor message loop handled + ``select`` synchronously: while waiting on the server's response, + the loop couldn't read the adapter's subsequent ``stream_append`` + notifications from stdout. The appends stayed in the readLoop's + channel, the consumer never received any items, and the select + never resolved — a full deadlock. + """ + targets = [ + workflow("test", "producer_workflow"), + task("test", "consumer"), + ] + + with worker(targets, concurrency=2) as ctx: + resp = ctx.submit("test", "producer_workflow") + wf = ctx.executor.next_execute() + + # Workflow registers an inline stream and submits the consumer + # task with that stream handle as its argument. + wf.conn.stream_register(wf.execution_id, 0, buffer=0) + stream_arg = { + "type": "inline", + "format": "json", + "value": {"type": "stream", "id": f"{wf.execution_id}_0"}, + "references": [], + } + consumer_id = wf.conn.submit_task( + wf.execution_id, "test", "consumer", [stream_arg] + ) + + # Send the select ourselves so we can interleave stream_appends + # behind it without racing a background thread. The real adapter + # does exactly this: its main thread blocks in _wait_response + # while the stream driver thread keeps calling send_stream_append + # on the same stdout. + select_id = wf.conn._next_request_id + wf.conn._next_request_id += 1 + wf.conn.send( + { + "id": select_id, + "method": "select", + "params": { + "execution_id": wf.execution_id, + "handles": [execution_handle(consumer_id)], + "suspend": False, + }, + } + ) + + # Consumer picks up and subscribes before any items are emitted. + cons = ctx.executor.next_execute() + cons.conn.stream_subscribe( + cons.execution_id, + subscription_id=1, + producer_execution_id=wf.execution_id, + index=0, + ) + + # Emit items on the workflow's connection while its select is + # pending. Each of these requires the CLI to read and forward + # the notification *while* the select request is still in + # flight — the exact path that used to deadlock. + for i in range(3): + wf.conn.stream_append(wf.execution_id, 0, i, i * 10) + wf.conn.stream_close(wf.execution_id, 0) + + # Consumer drains — reaching the close means every append made + # it through the CLI while select was holding the loop. + items, closed = cons.conn.drain_stream(subscription_id=1, timeout=5) + cons.conn.complete(cons.execution_id, value=len(items)) + + # Read directly from the socket (not via _buffer, which would + # re-pop anything we stored) until the select response arrives. + # If the bug is present this never happens and _recv_raw times + # out. + deadline = time.time() + 5 + while True: + remaining = max(0.01, deadline - time.time()) + msg = wf.conn._recv_raw(remaining) + if msg.get("id") == select_id: + assert "error" not in msg, f"select errored: {msg['error']}" + break + + wf.conn.complete(wf.execution_id, value=len(items)) + ctx.result(resp["runId"]) + + assert [item[1]["value"] for item in items] == [0, 10, 20] + assert closed.get("error") is None + assert closed.get("reason") != "timeout" + + def test_backpressure_unbounded_sends_no_demand(worker): """Registering without a buffer (wire buffer=null) opts out of backpressure — the server never sends demand grants for this stream. From 9dd66dbe61a2e8bb9535c124be20787c2be8e1ea Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Tue, 21 Apr 2026 23:55:37 +0100 Subject: [PATCH 24/25] Don't use memoised execution with errored/timed-out stream --- server/lib/coflux/orchestration/runs.ex | 44 ++++++------ tests/test_streams.py | 94 +++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 22 deletions(-) diff --git a/server/lib/coflux/orchestration/runs.ex b/server/lib/coflux/orchestration/runs.ex index f43c065f..66177919 100644 --- a/server/lib/coflux/orchestration/runs.ex +++ b/server/lib/coflux/orchestration/runs.ex @@ -1120,6 +1120,21 @@ defmodule Coflux.Orchestration.Runs do |> Enum.join() end + # Completion kinds that disqualify an execution from being reused as + # a memo/cache hit. All three have a value row (so the value-id check + # alone wouldn't reject them) but their value is unsafe to reuse: + # `:cancelled` is an explicit user override, `:stream_errored` is a + # producer failure, `:stream_timeout` is consumer-shaped output. + # In-flight executions (no completion row yet) remain candidates so + # concurrent callers still deduplicate onto a running attempt. + defp reuse_disqualified_kinds do + Enum.map_join( + [:cancelled, :stream_errored, :stream_timeout], + ", ", + &Integer.to_string(Results.atom_kind(&1)) + ) + end + # TODO: consider changed 'requires'? defp find_memoised_execution(db, run_id, workspace_ids, memo_key) do case query( @@ -1135,12 +1150,11 @@ defmodule Coflux.Orchestration.Runs do AND e.workspace_id IN (#{build_placeholders(length(workspace_ids), 1)}) AND s.memo_key = ?#{length(workspace_ids) + 2} -- Either no result yet (in-flight candidate) or a value - -- result. Errors disqualify. Cancelled-with-value also - -- disqualifies: a user-cancelled execution's work shouldn't - -- be reused, even though its value is still valid for - -- already-resolved consumers. + -- result. Errors disqualify (no value_id). See + -- `reuse_disqualified_kinds` for completion-kind exclusions + -- that survive having a value. AND (r.execution_id IS NULL OR r.value_id IS NOT NULL) - AND (c.kind IS NULL OR c.kind != #{Results.atom_kind(:cancelled)}) + AND (c.kind IS NULL OR c.kind NOT IN (#{reuse_disqualified_kinds()})) ORDER BY e.created_at DESC LIMIT 1 """, @@ -1169,20 +1183,6 @@ defmodule Coflux.Orchestration.Runs do workspace_ids ++ [{:blob, cache_key}, recorded_after] end - # Disqualifying completion kinds: cancellation, plus the two - # value-result-but-stream-broke kinds. These all have a value row - # (so the value-id check above wouldn't reject them) but the cached - # value is unsafe to reuse — cancelled = explicit user override, - # stream_errored = producer failure, stream_timeout = consumer-shaped - # output. Each is excluded as soon as its completion is recorded; - # in-flight executions (no completion row yet) remain candidates. - disqualified_kinds = - Enum.map_join( - [:cancelled, :stream_errored, :stream_timeout], - ", ", - &Integer.to_string(Results.atom_kind(&1)) - ) - case query( db, """ @@ -1196,13 +1196,13 @@ defmodule Coflux.Orchestration.Runs do AND s.cache_key = ?#{length(workspace_ids) + 1} -- Either no result yet (in-flight candidate) or a value result -- recorded within the cache age window. Errors disqualify - -- (no value_id). See `disqualified_kinds` for completion-kind - -- exclusions that survive having a value. + -- (no value_id). See `reuse_disqualified_kinds` for + -- completion-kind exclusions that survive having a value. AND ( r.execution_id IS NULL OR (r.value_id IS NOT NULL AND r.created_at >= ?#{length(workspace_ids) + 2}) ) - AND (c.kind IS NULL OR c.kind NOT IN (#{disqualified_kinds})) + AND (c.kind IS NULL OR c.kind NOT IN (#{reuse_disqualified_kinds()})) #{step_clause} ORDER BY e.created_at DESC LIMIT 1 diff --git a/tests/test_streams.py b/tests/test_streams.py index 0e865c10..36a9bd4b 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1252,3 +1252,97 @@ def test_stream_timeout_execution_not_used_as_cache_hit(worker): prod2.conn.complete(prod2.execution_id, value="v2") wf.conn.complete(wf.execution_id) + + +def test_stream_errored_execution_not_used_as_memo_hit(worker): + """A `:stream_errored` execution is not eligible for memo lookup + either — a second memoised call within the same run re-executes + once the first attempt's completion records the stream error. + """ + targets = [ + workflow("test", "main"), + task("test", "produce", parameters=["x"]), + ] + + with worker(targets, concurrency=2) as ctx: + resp = ctx.submit("test", "main") + run_id = resp["runId"] + wf = ctx.executor.next_execute() + + ref1 = wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + memo=True, + ) + + prod = ctx.executor.next_execute() + prod.conn.stream_register(prod.execution_id, 0) + prod.conn.stream_close( + prod.execution_id, + 0, + error={"type": "ValueError", "message": "oops", "traceback": ""}, + ) + prod.conn.complete(prod.execution_id, value="v") + assert wf.conn.resolve(wf.execution_id, ref1)["value"] == "v" + + # Wait for the completion so the memo lookup sees :stream_errored, + # not the in-flight :draining state (which would still match). + assert _wait_for_completion(ctx, run_id, 2) == "stream_errored" + + # Same memo key, but the prior attempt is :stream_errored — expect + # a fresh execution rather than a memo hit on the broken stream. + wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + memo=True, + ) + prod2 = ctx.executor.next_execute(timeout=3) + prod2.conn.complete(prod2.execution_id, value="v2") + + wf.conn.complete(wf.execution_id) + + +def test_stream_timeout_execution_not_used_as_memo_hit(worker): + """A `:stream_timeout` execution is not eligible for memo lookup — + a second memoised call within the same run re-executes. + """ + targets = [ + workflow("test", "main"), + task("test", "produce", parameters=["x"]), + ] + + with worker(targets, concurrency=2) as ctx: + resp = ctx.submit("test", "main") + run_id = resp["runId"] + wf = ctx.executor.next_execute() + + ref1 = wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + memo=True, + ) + + prod = ctx.executor.next_execute() + prod.conn.stream_register( + prod.execution_id, 0, buffer=None, timeout_ms=100 + ) + force = prod.conn.recv_push("stream_force_close", timeout=2) + assert force["reason"] == "timeout" + prod.conn.complete(prod.execution_id, value="v") + assert wf.conn.resolve(wf.execution_id, ref1)["value"] == "v" + + assert _wait_for_completion(ctx, run_id, 2) == "stream_timeout" + + wf.conn.submit_task( + wf.execution_id, + "test", "produce", + json_args(1), + memo=True, + ) + prod2 = ctx.executor.next_execute(timeout=3) + prod2.conn.complete(prod2.execution_id, value="v2") + + wf.conn.complete(wf.execution_id) From d007aa0026987971cc5e05f64a27ed22a913cbe4 Mon Sep 17 00:00:00 2001 From: Joe Freeman Date: Wed, 22 Apr 2026 19:58:08 +0100 Subject: [PATCH 25/25] Track stream dependencies --- server/lib/coflux/orchestration/epoch.ex | 27 ++ server/lib/coflux/orchestration/results.ex | 11 +- server/lib/coflux/orchestration/runs.ex | 38 +++ server/lib/coflux/orchestration/server.ex | 354 ++++++++++++--------- server/lib/coflux/topics/run.ex | 32 ++ server/priv/migrations/orchestration/4.sql | 19 ++ 6 files changed, 325 insertions(+), 156 deletions(-) diff --git a/server/lib/coflux/orchestration/epoch.ex b/server/lib/coflux/orchestration/epoch.ex index aa9cf899..eed85a83 100644 --- a/server/lib/coflux/orchestration/epoch.ex +++ b/server/lib/coflux/orchestration/epoch.ex @@ -448,6 +448,7 @@ defmodule Coflux.Orchestration.Epoch do copy_execution_result_deps(source_db, target_db, old_exec_id, new_exec_id) copy_execution_assets(source_db, target_db, old_exec_id, new_exec_id) copy_execution_asset_deps(source_db, target_db, old_exec_id, new_exec_id) + copy_execution_stream_deps(source_db, target_db, old_exec_id, new_exec_id) copy_execution_inputs(source_db, target_db, old_exec_id, new_exec_id, new_run_id) copy_execution_input_deps(source_db, target_db, old_exec_id, new_exec_id) end) @@ -561,6 +562,32 @@ defmodule Coflux.Orchestration.Epoch do end) end + defp copy_execution_stream_deps(source_db, target_db, old_exec_id, new_exec_id) do + {:ok, stream_deps} = + query( + source_db, + "SELECT stream_ref_id, stream_index, created_at FROM stream_dependencies WHERE execution_id = ?1", + {old_exec_id} + ) + + Enum.each(stream_deps, fn {old_ref_id, stream_index, created_at} -> + new_ref_id = ensure_execution_ref(source_db, target_db, old_ref_id) + + {:ok, _} = + insert_one( + target_db, + :stream_dependencies, + %{ + execution_id: new_exec_id, + stream_ref_id: new_ref_id, + stream_index: stream_index, + created_at: created_at + }, + on_conflict: "DO NOTHING" + ) + end) + end + defp copy_execution_inputs(source_db, target_db, old_exec_id, new_exec_id, new_run_id) do # Copy execution_inputs records (which executions submitted which inputs) {:ok, exec_inputs} = diff --git a/server/lib/coflux/orchestration/results.ex b/server/lib/coflux/orchestration/results.ex index 3dc4a88a..abea7576 100644 --- a/server/lib/coflux/orchestration/results.ex +++ b/server/lib/coflux/orchestration/results.ex @@ -506,7 +506,16 @@ defmodule Coflux.Orchestration.Results do # Error payload without a completion yet. We return the error so UI can # display it; the successor slot is nil so consumer resolution treats # it as still pending (the retry decision happens at completion time). - defp resolve_logical(db, _exec_id, nil, nil, error_id, retryable, _successor_id, _successor_ref_id) + defp resolve_logical( + db, + _exec_id, + nil, + nil, + error_id, + retryable, + _successor_id, + _successor_ref_id + ) when not is_nil(error_id) do {:ok, {type, message, frames}} = Errors.get_by_id(db, error_id) {:error, type, message, frames, nil, decode_retryable(retryable)} diff --git a/server/lib/coflux/orchestration/runs.ex b/server/lib/coflux/orchestration/runs.ex index 66177919..469a2db4 100644 --- a/server/lib/coflux/orchestration/runs.ex +++ b/server/lib/coflux/orchestration/runs.ex @@ -624,6 +624,22 @@ defmodule Coflux.Orchestration.Runs do end) end + def record_stream_dependency(db, execution_id, stream_ref_id, stream_index) do + with_transaction(db, fn -> + insert_one( + db, + :stream_dependencies, + %{ + execution_id: execution_id, + stream_ref_id: stream_ref_id, + stream_index: stream_index, + created_at: current_timestamp() + }, + on_conflict: "DO NOTHING" + ) + end) + end + def get_unassigned_executions(db) do query( db, @@ -951,6 +967,28 @@ defmodule Coflux.Orchestration.Runs do end end + def get_run_stream_dependencies(db, run_id) do + case query( + db, + """ + SELECT d.execution_id, d.stream_ref_id, d.stream_index + FROM stream_dependencies AS d + INNER JOIN executions AS e ON e.id = d.execution_id + INNER JOIN steps AS s ON s.id = e.step_id + WHERE s.run_id = ?1 + """, + {run_id} + ) do + {:ok, rows} -> + {:ok, + Enum.group_by( + rows, + fn {execution_id, _ref_id, _index} -> execution_id end, + fn {_execution_id, ref_id, index} -> {ref_id, index} end + )} + end + end + def get_step_assignments(db, step_id) do case query( db, diff --git a/server/lib/coflux/orchestration/server.ex b/server/lib/coflux/orchestration/server.ex index 701fff1f..286020b3 100644 --- a/server/lib/coflux/orchestration/server.ex +++ b/server/lib/coflux/orchestration/server.ex @@ -1894,159 +1894,6 @@ defmodule Coflux.Orchestration.Server do end end - defp maybe_init_stream_producer( - state, - _execution_id, - _execution_external_id, - _index, - nil, - _session_id - ) do - # buffer=nil means the producer has opted out of backpressure — no - # tracking required on the server side. It'll emit freely and the - # adapter's driver never waits. - state - end - - defp maybe_init_stream_producer( - state, - execution_id, - execution_external_id, - index, - buffer, - session_id - ) - when is_integer(buffer) and buffer >= 0 do - put_in(state.stream_producers[{execution_id, index}], %{ - buffer: buffer, - demand_granted: 0, - session_id: session_id, - execution_external_id: execution_external_id - }) - end - - defp maybe_send_initial_demand(state, execution_id, index) do - # At registration time there are no subscribers yet. Allow the - # producer to pre-warm up to `buffer` items; lockstep (buffer=0) - # stays paused until a consumer attaches. - refresh_stream_demand(state, {execution_id, index}) - end - - # Recompute the target demand for one stream and, if it's grown, - # send a delta grant to the producer's session. - # - # Formula: - # target = max_cursor + buffer + (1 if has_subscribers else 0) - # The +1 on subscriber presence unblocks lockstep streams — a - # consumer's cursor at position N means "ready for item N", which is - # one item beyond what they've acked. - # - # demand_granted is monotonic; if target drops (e.g. the fastest - # consumer left) we don't claw back, future grants just wait until - # the remaining subscribers catch up past the old max. - defp refresh_stream_demand(state, {_execution_id, index} = key) do - case Map.fetch(state.stream_producers, key) do - :error -> - state - - {:ok, %{session_id: nil}} -> - # Producer's session is gone — typically because the producer - # execution has long since terminated and we rebuilt its in-memory - # state for a late subscriber. There's nothing to grant demand to; - # the stream is durable in the DB and backlog reads don't consume - # credits. - state - - {:ok, producer} -> - has_subscribers = has_stream_subscribers?(state, key) - max_cursor = current_max_cursor(state, key) - bump = if has_subscribers, do: 1, else: 0 - target = max_cursor + producer.buffer + bump - delta = target - producer.demand_granted - - if delta > 0 do - state - |> put_in([Access.key(:stream_producers), key, :demand_granted], target) - |> send_session( - producer.session_id, - {:stream_demand, producer.execution_external_id, index, delta} - ) - else - state - end - end - end - - defp has_stream_subscribers?(state, key) do - case Map.get(state.stream_subscribers, key) do - nil -> false - set -> MapSet.size(set) > 0 - end - end - - defp current_max_cursor(state, key) do - state.stream_subscribers - |> Map.get(key, MapSet.new()) - |> Enum.reduce(0, fn sub_key, acc -> - case Map.get(state.stream_subscriptions, sub_key) do - nil -> acc - sub -> max(acc, sub.cursor) - end - end) - end - - defp drop_stream_producer(state, key) do - Map.update!(state, :stream_producers, &Map.delete(&1, key)) - end - - # Lazily rebuild stream_producer state from the DB if it's missing. - # Used after server restart — in-memory producer state is gone but - # the ``streams`` table still has the buffer. We rebuild on first - # append or subscribe for a given stream, recovering flow control. - # - # ``session_id`` is the internal id of the producer's current session; - # supply ``nil`` if not known, in which case demand grants will be - # deferred until the session is resolvable. - defp ensure_stream_producer( - state, - execution_id, - execution_external_id, - index, - session_id - ) do - key = {execution_id, index} - - cond do - Map.has_key?(state.stream_producers, key) -> - state - - true -> - case Streams.get_buffer(state.db, execution_id, index) do - {:ok, nil} -> - # Stream opted out of backpressure; nothing to track. - state - - {:ok, buffer} when is_integer(buffer) -> - # Reconstruct state. demand_granted starts at items already - # produced — we assume earlier-us granted enough for those, - # and rely on the producer having kept its local credit - # counter consistent. - {:ok, head} = Streams.get_stream_head(state.db, execution_id, index) - items_produced = if head < 0, do: 0, else: head + 1 - - put_in(state.stream_producers[key], %{ - buffer: buffer, - demand_granted: items_produced, - session_id: session_id, - execution_external_id: execution_external_id - }) - - {:error, :not_found} -> - state - end - end - end - def handle_call( {:append_stream_item, execution_external_id, index, sequence, value}, _from, @@ -2210,6 +2057,36 @@ defmodule Coflux.Orchestration.Server do state = push_backlog(state, key) state = maybe_push_closure_if_closed(state, key) + # Record the subscribe as a lineage edge (consumer -> producer stream). + # Done unconditionally on subscribe, independent of whether items end + # up being read. Uses execution_refs so the edge survives epoch + # rotation. + {:ok, stream_ref_id} = + Runs.create_execution_ref_for(state.db, producer_execution_id) + + {:ok, inserted_id} = + Runs.record_stream_dependency(state.db, consumer_execution_id, stream_ref_id, index) + + state = + if inserted_id do + {:ok, {run_external_id}} = + Runs.get_external_run_id_for_execution(state.db, consumer_execution_id) + + {producer_ext_id, _module, _target} = + producer_metadata = resolve_execution_ref(state.db, stream_ref_id) + + notify_listeners( + state, + {:run, run_external_id}, + {:stream_dependency, consumer_execution_external_id, producer_ext_id, index, + producer_metadata} + ) + else + state + end + + state = flush_notifications(state) + {:reply, :ok, state} else {:ok, false} -> {:reply, {:error, :stream_not_found}, state} @@ -2992,6 +2869,159 @@ defmodule Coflux.Orchestration.Server do {:reply, :ok, state} end + defp maybe_init_stream_producer( + state, + _execution_id, + _execution_external_id, + _index, + nil, + _session_id + ) do + # buffer=nil means the producer has opted out of backpressure — no + # tracking required on the server side. It'll emit freely and the + # adapter's driver never waits. + state + end + + defp maybe_init_stream_producer( + state, + execution_id, + execution_external_id, + index, + buffer, + session_id + ) + when is_integer(buffer) and buffer >= 0 do + put_in(state.stream_producers[{execution_id, index}], %{ + buffer: buffer, + demand_granted: 0, + session_id: session_id, + execution_external_id: execution_external_id + }) + end + + defp maybe_send_initial_demand(state, execution_id, index) do + # At registration time there are no subscribers yet. Allow the + # producer to pre-warm up to `buffer` items; lockstep (buffer=0) + # stays paused until a consumer attaches. + refresh_stream_demand(state, {execution_id, index}) + end + + # Recompute the target demand for one stream and, if it's grown, + # send a delta grant to the producer's session. + # + # Formula: + # target = max_cursor + buffer + (1 if has_subscribers else 0) + # The +1 on subscriber presence unblocks lockstep streams — a + # consumer's cursor at position N means "ready for item N", which is + # one item beyond what they've acked. + # + # demand_granted is monotonic; if target drops (e.g. the fastest + # consumer left) we don't claw back, future grants just wait until + # the remaining subscribers catch up past the old max. + defp refresh_stream_demand(state, {_execution_id, index} = key) do + case Map.fetch(state.stream_producers, key) do + :error -> + state + + {:ok, %{session_id: nil}} -> + # Producer's session is gone — typically because the producer + # execution has long since terminated and we rebuilt its in-memory + # state for a late subscriber. There's nothing to grant demand to; + # the stream is durable in the DB and backlog reads don't consume + # credits. + state + + {:ok, producer} -> + has_subscribers = has_stream_subscribers?(state, key) + max_cursor = current_max_cursor(state, key) + bump = if has_subscribers, do: 1, else: 0 + target = max_cursor + producer.buffer + bump + delta = target - producer.demand_granted + + if delta > 0 do + state + |> put_in([Access.key(:stream_producers), key, :demand_granted], target) + |> send_session( + producer.session_id, + {:stream_demand, producer.execution_external_id, index, delta} + ) + else + state + end + end + end + + defp has_stream_subscribers?(state, key) do + case Map.get(state.stream_subscribers, key) do + nil -> false + set -> MapSet.size(set) > 0 + end + end + + defp current_max_cursor(state, key) do + state.stream_subscribers + |> Map.get(key, MapSet.new()) + |> Enum.reduce(0, fn sub_key, acc -> + case Map.get(state.stream_subscriptions, sub_key) do + nil -> acc + sub -> max(acc, sub.cursor) + end + end) + end + + defp drop_stream_producer(state, key) do + Map.update!(state, :stream_producers, &Map.delete(&1, key)) + end + + # Lazily rebuild stream_producer state from the DB if it's missing. + # Used after server restart — in-memory producer state is gone but + # the ``streams`` table still has the buffer. We rebuild on first + # append or subscribe for a given stream, recovering flow control. + # + # ``session_id`` is the internal id of the producer's current session; + # supply ``nil`` if not known, in which case demand grants will be + # deferred until the session is resolvable. + defp ensure_stream_producer( + state, + execution_id, + execution_external_id, + index, + session_id + ) do + key = {execution_id, index} + + cond do + Map.has_key?(state.stream_producers, key) -> + state + + true -> + case Streams.get_buffer(state.db, execution_id, index) do + {:ok, nil} -> + # Stream opted out of backpressure; nothing to track. + state + + {:ok, buffer} when is_integer(buffer) -> + # Reconstruct state. demand_granted starts at items already + # produced — we assume earlier-us granted enough for those, + # and rely on the producer having kept its local credit + # counter consistent. + {:ok, head} = Streams.get_stream_head(state.db, execution_id, index) + items_produced = if head < 0, do: 0, else: head + 1 + + put_in(state.stream_producers[key], %{ + buffer: buffer, + demand_granted: items_produced, + session_id: session_id, + execution_external_id: execution_external_id + }) + + {:error, :not_found} -> + state + end + end + end + def handle_cast({:unsubscribe, ref}, state) do Process.demonitor(ref, [:flush]) @@ -5082,6 +5112,7 @@ defmodule Coflux.Orchestration.Server do {:ok, steps} = Runs.get_run_steps(db, run.id) {:ok, run_executions} = Runs.get_run_executions(db, run.id) {:ok, run_dependencies} = Runs.get_run_dependencies(db, run.id) + {:ok, run_stream_dependencies} = Runs.get_run_stream_dependencies(db, run.id) {:ok, run_children} = Runs.get_run_children(db, run.id) {:ok, groups} = Runs.get_groups_for_run(db, run.id) {:ok, run_metric_defs} = Runs.get_run_metric_definitions(db, run.id) @@ -5282,12 +5313,25 @@ defmodule Coflux.Orchestration.Server do {ext_id, {:result, execution}} end) + stream_deps = + run_stream_dependencies + |> Map.get(execution_id, []) + |> Map.new(fn {stream_ref_id, stream_index} -> + {producer_ext_id, _module, _target} = + execution = resolve_execution_ref(db, stream_ref_id) + + {"#{producer_ext_id}:#{stream_index}", {:stream, stream_index, execution}} + end) + dependencies = Map.merge( result_deps, Map.merge( - Map.get(input_deps_by_execution, execution_id, %{}), - Map.get(asset_deps_by_execution, execution_id, %{}) + stream_deps, + Map.merge( + Map.get(input_deps_by_execution, execution_id, %{}), + Map.get(asset_deps_by_execution, execution_id, %{}) + ) ) ) diff --git a/server/lib/coflux/topics/run.ex b/server/lib/coflux/topics/run.ex index ae9ac973..2051f486 100644 --- a/server/lib/coflux/topics/run.ex +++ b/server/lib/coflux/topics/run.ex @@ -178,6 +178,30 @@ defmodule Coflux.Topics.Run do ) end + defp process_notification( + topic, + {:stream_dependency, execution_external_id, producer_execution_id, index, + producer_metadata} + ) do + dependency = %{ + type: "stream", + execution: build_execution(producer_metadata), + index: index + } + + update_execution( + topic, + execution_external_id, + fn topic, base_path -> + Topic.merge( + topic, + base_path ++ [:dependencies, "#{producer_execution_id}:#{index}"], + dependency + ) + end + ) + end + defp process_notification(topic, {:child, parent_execution_external_id, child}) do child = build_child(child, topic.state.external_run_id) @@ -437,6 +461,14 @@ defmodule Coflux.Topics.Run do {id, {:asset, asset}} -> {id, %{type: "asset", assetId: id, asset: build_asset(asset)}} + + {id, {:stream, index, execution}} -> + {id, + %{ + type: "stream", + execution: build_execution(execution), + index: index + }} end) end diff --git a/server/priv/migrations/orchestration/4.sql b/server/priv/migrations/orchestration/4.sql index 545475e8..df983580 100644 --- a/server/priv/migrations/orchestration/4.sql +++ b/server/priv/migrations/orchestration/4.sql @@ -193,3 +193,22 @@ CREATE TABLE stream_closures ( FOREIGN KEY (error_id) REFERENCES errors ON DELETE RESTRICT, CHECK ((reason = 1) = (error_id IS NOT NULL)) ) STRICT; + +-- Track stream subscriptions as a lineage edge between executions, mirroring +-- result_dependencies / asset_dependencies. A row is written when a consumer +-- subscribes to a producer's stream (regardless of whether items are read), +-- so data lineage is preserved even for subscriptions that yield no values. +-- +-- The producer side is referenced via `execution_refs` (not the live +-- `executions` row) so the edge survives epoch rotation, and by `stream_index` +-- so we can distinguish between multiple streams produced by the same +-- execution. +CREATE TABLE stream_dependencies ( + execution_id INTEGER NOT NULL, + stream_ref_id INTEGER NOT NULL, + stream_index INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (execution_id, stream_ref_id, stream_index), + FOREIGN KEY (execution_id) REFERENCES executions ON DELETE CASCADE, + FOREIGN KEY (stream_ref_id) REFERENCES execution_refs ON DELETE RESTRICT +) STRICT;