Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ chainweaver/
├── builder.py FlowBuilder: fluent API for constructing Flow objects
├── decorators.py @tool decorator for zero-boilerplate tool definition
├── tools.py Tool class: named callable with Pydantic I/O schemas
├── flow.py FlowStep + Flow: ordered step definitions (Pydantic models)
├── registry.py FlowRegistry: in-memory catalogue of named flows
├── executor.py FlowExecutor: sequential, LLM-free runner (main entry point)
├── flow.py FlowStep + Flow (linear) + DAGFlowStep + DAGFlow + validate_dag_topology
├── registry.py FlowRegistry: in-memory catalogue of Flow and DAGFlow
├── executor.py FlowExecutor: sequential/DAG runner (main entry point)
├── exceptions.py Typed exception hierarchy (all inherit ChainWeaverError)
├── log_utils.py Structured per-step logging utilities
└── py.typed PEP 561 marker
Expand Down Expand Up @@ -144,6 +144,7 @@ For the full prohibited-actions list and anti-patterns, see
| Add a new exception | `exceptions.py` | `__init__.py` + `__all__` + README error table — **same PR** |
| Modify flow execution | `executor.py` | Keep `StepRecord` + `ExecutionResult` consistent |
| Add a new Flow field | `flow.py` | Serialization tests if `model_dump()` changes |
| Add a new DAGFlow / DAGFlowStep field | `flow.py` | Update `validate_dag_topology` if needed; update tests |
| Change logging format | `log_utils.py` | Update tests (no re-export needed) |
| Add a new module | See [new-module checklist](docs/agent-context/workflows.md#new-module-checklist) |

Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ All errors are typed and traceable:
| `InputMappingError` | A mapping key is not present in the context |
| `FlowExecutionError` | The tool callable raises an unexpected exception |
| `ToolDefinitionError` | The `@tool` decorator cannot build a tool from a function |
| `DAGDefinitionError` | A `DAGFlow` has a cycle, duplicate `step_id`, or unknown dependency |
| `FlowBuilderError` | `FlowBuilder.build()` is called without a name or description |

All exceptions inherit from `ChainWeaverError`.
Expand All @@ -458,8 +459,10 @@ All exceptions inherit from `ChainWeaverError`.

### v0.2 — DAG & Branching

- [ ] DAG-based execution with dependency edges
- [ ] Parallel step groups
- [x] DAG-based execution with dependency edges (`DAGFlow`, `DAGFlowStep`)
- [x] Topological level-grouped execution with sibling key-conflict detection
- [x] `DAGDefinitionError` — cycle / duplicate ID / unknown dep detected at registration
- [ ] Actual parallel/async execution for independent levels
- [ ] Conditional branching inside flows

### v0.3 — Persistence & Learning
Expand Down
13 changes: 11 additions & 2 deletions chainweaver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

.. code-block:: python

from chainweaver import Tool, Flow, FlowStep, FlowRegistry, FlowExecutor, FlowBuilder
from chainweaver import (
Tool, Flow, FlowStep, DAGFlow, DAGFlowStep,
FlowBuilder, FlowRegistry, FlowExecutor, validate_dag_topology,
)
from chainweaver.exceptions import (
ChainWeaverError,
DAGDefinitionError,
ToolNotFoundError,
FlowNotFoundError,
FlowAlreadyExistsError,
Expand All @@ -26,6 +30,7 @@
from chainweaver.decorators import tool
from chainweaver.exceptions import (
ChainWeaverError,
DAGDefinitionError,
FlowAlreadyExistsError,
FlowExecutionError,
FlowNotFoundError,
Expand All @@ -35,7 +40,7 @@
ToolNotFoundError,
)
from chainweaver.executor import ExecutionResult, FlowExecutor, StepRecord
from chainweaver.flow import Flow, FlowStep
from chainweaver.flow import DAGFlow, DAGFlowStep, Flow, FlowStep, validate_dag_topology
from chainweaver.registry import FlowRegistry
from chainweaver.tools import Tool

Expand All @@ -47,6 +52,9 @@

__all__ = [
"ChainWeaverError",
"DAGDefinitionError",
"DAGFlow",
"DAGFlowStep",
"ExecutionResult",
"Flow",
"FlowAlreadyExistsError",
Expand All @@ -64,4 +72,5 @@
"ToolDefinitionError",
"ToolNotFoundError",
"tool",
"validate_dag_topology",
]
22 changes: 22 additions & 0 deletions chainweaver/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,25 @@ def __init__(self, function_name: str, detail: str) -> None:
self.function_name = function_name
self.detail = detail
super().__init__(f"Cannot define tool from function '{function_name}': {detail}")


class DAGDefinitionError(ChainWeaverError):
    """Signals an invalid :class:`~chainweaver.flow.DAGFlow` definition.

    Attributes:
        flow_name: Name of the flow whose topology failed validation.
        reason: Machine-readable reason code — one of ``"cycle"``,
            ``"duplicate_step_id"``, or ``"unknown_dependency"``.
        detail: Human-readable explanation of the failure.
    """

    def __init__(self, flow_name: str, reason: str, detail: str) -> None:
        self.flow_name = flow_name
        self.reason = reason
        self.detail = detail
        message = f"Invalid DAG flow '{flow_name}' ({reason}): {detail}"
        super().__init__(message)
240 changes: 237 additions & 3 deletions chainweaver/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from graphlib import TopologicalSorter
from typing import Any

from pydantic import ValidationError
Expand All @@ -18,7 +19,7 @@
SchemaValidationError,
ToolNotFoundError,
)
from chainweaver.flow import FlowStep
from chainweaver.flow import DAGFlow, DAGFlowStep, FlowStep, validate_dag_topology
from chainweaver.log_utils import get_logger, log_step_end, log_step_error, log_step_start
from chainweaver.registry import FlowRegistry
from chainweaver.tools import Tool
Expand Down Expand Up @@ -107,8 +108,8 @@ class FlowExecutor:
print(result.final_output) # {"result": "Final value: 20"}

# TODO (Phase 2): Add async execution mode for I/O-bound tool chains.
# TODO (Phase 2): Support DAG execution with dependency resolution and
# parallel step groups.
# TODO (Phase 2): Support parallel/async execution for independent DAG
# levels (currently steps within a level run sequentially).
# TODO (Phase 2): Add middleware hooks (before_step / after_step) for
# observability and tracing integrations.
"""
Expand Down Expand Up @@ -159,6 +160,9 @@ def execute_flow(
FlowNotFoundError: When *flow_name* is not registered.
"""
flow = self._registry.get_flow(flow_name)
if isinstance(flow, DAGFlow):
return self._execute_dag_flow(flow, initial_input)

_logger.info("Flow '%s' started | steps=%d", flow_name, len(flow.steps))

# -- Flow-level input validation ------------------------------------
Expand Down Expand Up @@ -362,3 +366,233 @@ def _execute_step(
outputs=outputs,
success=True,
)

# ------------------------------------------------------------------
# DAG execution
# ------------------------------------------------------------------

def _compute_dag_levels(self, flow: DAGFlow) -> list[list[DAGFlowStep]]:
    """Return steps grouped into topological execution levels.

    Within each level all steps are independent (no intra-level edges).
    Steps in the same level can conceptually run in parallel; today they
    run sequentially in list order.

    Topology is normally validated at registration time. This method
    still calls ``validate_dag_topology`` as a belt-and-suspenders guard
    for flows that are created and executed without going through
    :class:`~chainweaver.registry.FlowRegistry`, so invalid DAGs may
    raise :class:`~chainweaver.exceptions.DAGDefinitionError` here.

    Level computation uses :class:`graphlib.TopologicalSorter` to iterate
    steps in dependency order, so the result is correct regardless of the
    declaration order of steps in ``flow.steps``.

    Args:
        flow: A valid :class:`~chainweaver.flow.DAGFlow`.

    Returns:
        A list of levels, each level being a list of
        :class:`~chainweaver.flow.DAGFlowStep` objects.

    Raises:
        DAGDefinitionError: If *flow* has a cycle, a duplicate
            ``step_id``, or an unknown dependency.
    """
    validate_dag_topology(flow)
    step_by_id = {s.step_id: s for s in flow.steps}
    graph: dict[str, set[str]] = {s.step_id: set(s.depends_on) for s in flow.steps}
    sorter: TopologicalSorter[str] = TopologicalSorter(graph)
    topo_order = list(sorter.static_order())

    # level[step_id] = 0-based level index: a step sits one level deeper
    # than its deepest dependency; roots (no deps) sit at level 0.
    # Iterating in topological order guarantees every dependency's level
    # is already known when we compute a step's level.
    levels: dict[str, int] = {}
    for step_id in topo_order:
        step = step_by_id[step_id]
        if not step.depends_on:
            levels[step_id] = 0
        else:
            levels[step_id] = max(levels[dep] for dep in step.depends_on) + 1

    # default=-1 handles the empty flow: max_level + 1 == 0 buckets.
    max_level = max(levels.values(), default=-1)
    grouped: list[list[DAGFlowStep]] = [[] for _ in range(max_level + 1)]
    for step_id in topo_order:
        grouped[levels[step_id]].append(step_by_id[step_id])
    return grouped

def _execute_dag_flow(
    self,
    flow: DAGFlow,
    initial_input: dict[str, Any],
) -> ExecutionResult:
    """Execute a :class:`~chainweaver.flow.DAGFlow`.

    Steps are executed level-by-level in topological order. Within each
    level steps run sequentially. Outputs from all steps in a level are
    collected and merged into the shared context before the next level
    starts. If two sibling steps (same level) produce the same output
    key a :class:`~chainweaver.exceptions.FlowExecutionError` is raised
    immediately to preserve determinism.

    Args:
        flow: The :class:`~chainweaver.flow.DAGFlow` to execute.
        initial_input: Initial key/value context.

    Returns:
        An :class:`ExecutionResult` with the full execution log.
    """
    _logger.info("DAGFlow '%s' started | steps=%d", flow.name, len(flow.steps))

    # -- Flow-level input validation ------------------------------------
    if flow.input_schema is not None:
        try:
            flow.input_schema.model_validate(initial_input)
        except ValidationError as exc:
            # step_index=-1 marks a flow-level (pre-step) failure.
            wrapped = SchemaValidationError(flow.name, -1, str(exc), context="flow_input")
            _logger.error("DAGFlow '%s' input validation failed: %s", flow.name, wrapped)
            return ExecutionResult(
                flow_name=flow.name,
                success=False,
                final_output=None,
                execution_log=[
                    StepRecord(
                        step_index=-1,
                        tool_name=flow.name,
                        inputs=dict(initial_input),
                        error=wrapped,
                        success=False,
                    )
                ],
            )

    context: dict[str, Any] = dict(initial_input)
    log: list[StepRecord] = []
    levels = self._compute_dag_levels(flow)
    # Flat index for StepRecord.step_index (mirrors linear flow behaviour).
    flat_index = 0

    for level_steps in levels:
        level_outputs: dict[str, Any] = {}
        level_records: list[StepRecord] = []

        for step in level_steps:
            # Reject non-tool step types until KernelBackedExecutor exists.
            if step.step_type != "tool":
                err = FlowExecutionError(
                    step.tool_name,
                    flat_index,
                    f"Step '{step.step_id}' has step_type='{step.step_type}' "
                    f"which is not supported by FlowExecutor. "
                    f"Only step_type='tool' can be executed.",
                )
                log_step_error(_logger, flat_index, step.tool_name, err)
                # Flush this level's completed records before the failure
                # record so the log stays in execution order.
                log.extend(level_records)
                log.append(
                    StepRecord(
                        step_index=flat_index,
                        tool_name=step.tool_name,
                        inputs={},
                        error=err,
                        success=False,
                    )
                )
                return ExecutionResult(
                    flow_name=flow.name,
                    success=False,
                    final_output=None,
                    execution_log=log,
                )

            # Build a lightweight FlowStep-compatible view so _execute_step
            # can be reused without modification.
            proxy = FlowStep(
                tool_name=step.tool_name,
                input_mapping=step.input_mapping,
            )
            record = self._execute_step(flat_index, proxy, context)
            level_records.append(record)
            flat_index += 1

            if not record.success:
                log.extend(level_records)
                _logger.error(
                    "DAGFlow '%s' aborted at step %d (%s)",
                    flow.name,
                    record.step_index,
                    step.tool_name,
                )
                return ExecutionResult(
                    flow_name=flow.name,
                    success=False,
                    final_output=None,
                    execution_log=log,
                )

            assert record.outputs is not None  # success guarantees outputs
            # Detect sibling key conflicts to preserve determinism.
            for key, value in record.outputs.items():
                if key in level_outputs:
                    conflict_err = FlowExecutionError(
                        step.tool_name,
                        record.step_index,
                        f"Key '{key}' produced by both '{step.tool_name}' and a "
                        f"sibling step in the same DAG level. "
                        f"Use distinct output keys or sequential steps.",
                    )
                    # Replace the (successful) last record with a failure
                    # record carrying the conflict error.
                    record_conflict = StepRecord(
                        step_index=record.step_index,
                        tool_name=step.tool_name,
                        inputs=record.inputs,
                        error=conflict_err,
                        success=False,
                    )
                    log.extend(level_records[:-1])
                    log.append(record_conflict)
                    _logger.error("DAGFlow '%s': sibling key conflict on '%s'", flow.name, key)
                    return ExecutionResult(
                        flow_name=flow.name,
                        success=False,
                        final_output=None,
                        execution_log=log,
                    )
                level_outputs[key] = value

        log.extend(level_records)
        # Merge all level outputs into context after the level completes.
        for key in level_outputs:
            if key in context:
                _logger.debug(
                    "DAGFlow '%s': context key '%s' overwritten by level output",
                    flow.name,
                    key,
                )
        context.update(level_outputs)

    # -- Flow-level output validation -----------------------------------
    if flow.output_schema is not None:
        try:
            flow.output_schema.model_validate(context)
        except ValidationError as exc:
            wrapped = SchemaValidationError(
                flow.name, len(flow.steps), str(exc), context="flow_output"
            )
            _logger.error("DAGFlow '%s' output validation failed: %s", flow.name, wrapped)
            return ExecutionResult(
                flow_name=flow.name,
                success=False,
                final_output=None,
                execution_log=[
                    *log,
                    StepRecord(
                        step_index=len(flow.steps),
                        tool_name=flow.name,
                        inputs=dict(context),
                        error=wrapped,
                        success=False,
                    ),
                ],
            )

    _logger.info("DAGFlow '%s' completed successfully", flow.name)
    return ExecutionResult(
        flow_name=flow.name,
        success=True,
        final_output=context,
        execution_log=log,
    )
Loading