diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml
index 0859dfd0..d2540862 100644
--- a/.github/workflows/security.yml
+++ b/.github/workflows/security.yml
@@ -105,6 +105,28 @@ jobs:
--ignore-vuln GHSA-vfmq-68hx-4jfw \
--ignore-vuln GHSA-j7w6-vpvq-j3gm \
--ignore-vuln GHSA-98h9-4798-4q5v \
+ --ignore-vuln PYSEC-2025-189 \
+ --ignore-vuln PYSEC-2025-190 \
+ --ignore-vuln PYSEC-2025-191 \
+ --ignore-vuln PYSEC-2025-192 \
+ --ignore-vuln PYSEC-2025-193 \
+ --ignore-vuln PYSEC-2025-194 \
+ --ignore-vuln PYSEC-2025-195 \
+ --ignore-vuln PYSEC-2025-196 \
+ --ignore-vuln PYSEC-2025-197 \
+ --ignore-vuln PYSEC-2025-210 \
+ --ignore-vuln PYSEC-2026-139 \
+ --ignore-vuln PYSEC-2025-211 \
+ --ignore-vuln PYSEC-2025-212 \
+ --ignore-vuln PYSEC-2025-213 \
+ --ignore-vuln PYSEC-2025-214 \
+ --ignore-vuln PYSEC-2025-215 \
+ --ignore-vuln PYSEC-2025-216 \
+ --ignore-vuln PYSEC-2025-217 \
+ --ignore-vuln PYSEC-2025-218 \
+ --ignore-vuln PYSEC-2026-97 \
+ --ignore-vuln PYSEC-2025-183 \
+ --ignore-vuln PYSEC-2024-277 \
-r /tmp/requirements-worker-cpu-audit.txt
- name: Run pip-audit (worker GPU delta) # no --strict: flashinfer-jit-cache is unauditable on PyPI
run: |
@@ -130,4 +152,28 @@ jobs:
--ignore-vuln GHSA-83vm-p52w-f9pw \
--ignore-vuln GHSA-j7w6-vpvq-j3gm \
--ignore-vuln GHSA-98h9-4798-4q5v \
+ --ignore-vuln PYSEC-2025-189 \
+ --ignore-vuln PYSEC-2025-190 \
+ --ignore-vuln PYSEC-2025-191 \
+ --ignore-vuln PYSEC-2025-192 \
+ --ignore-vuln PYSEC-2025-193 \
+ --ignore-vuln PYSEC-2025-194 \
+ --ignore-vuln PYSEC-2025-195 \
+ --ignore-vuln PYSEC-2025-196 \
+ --ignore-vuln PYSEC-2025-197 \
+ --ignore-vuln PYSEC-2025-210 \
+ --ignore-vuln PYSEC-2026-139 \
+ --ignore-vuln PYSEC-2025-211 \
+ --ignore-vuln PYSEC-2025-212 \
+ --ignore-vuln PYSEC-2025-213 \
+ --ignore-vuln PYSEC-2025-214 \
+ --ignore-vuln PYSEC-2025-215 \
+ --ignore-vuln PYSEC-2025-216 \
+ --ignore-vuln PYSEC-2025-217 \
+ --ignore-vuln PYSEC-2025-218 \
+ --ignore-vuln PYSEC-2026-97 \
+ --ignore-vuln PYSEC-2025-183 \
+ --ignore-vuln PYSEC-2024-277 \
+ --ignore-vuln PYSEC-2025-222 \
+ --ignore-vuln PYSEC-2024-274 \
-r src/worker/requirements/requirements.gpu.txt
diff --git a/docs/CODE_STYLE.md b/docs/CODE_STYLE.md
index 33ed7383..394628a9 100644
--- a/docs/CODE_STYLE.md
+++ b/docs/CODE_STYLE.md
@@ -89,13 +89,13 @@ CI runs `pip-audit` against each generated requirements file
When pip-audit reports a new CVE, the only real fix is to bump the
offending dep in `pyproject.toml`, then `uv lock` and `uv run
scripts/dev/sync_requirements.py --write`. Silencing via `--ignore-vuln`
-is a last resort; every silenced GHSA needs a written upgrade-blocker.
-The currently-ignored advisories and the upgrade blocker that justifies
-each are listed below; the same list is encoded as `--ignore-vuln`
-flags in `.github/workflows/security.yml`.
+is a last resort; every silenced advisory needs a written
+upgrade-blocker. The currently-ignored advisories and the upgrade
+blocker that justifies each are listed below; the same list is encoded
+as `--ignore-vuln` flags in `.github/workflows/security.yml`.
-| GHSA | Package | Fix version | Why ignored |
-|------|---------|-------------|-------------|
+| Advisory | Package | Fix version | Why ignored |
+|----------|---------|-------------|-------------|
| GHSA-69w3-r845-3855 | transformers | 5.0.0rc3 | held by vllm/vllm-omni 0.18 compatibility |
| GHSA-pf3h-qjgv-vcpr | vllm | 0.19.0 | held by transformers 4.57 + adjacent inference deps |
| GHSA-pq5c-rjhq-qp7p | vllm | 0.19.0 | same |
@@ -117,6 +117,30 @@ flags in `.github/workflows/security.yml`.
| GHSA-w8v5-vhqr-4h9v | diskcache | (none) | upstream unmaintained, no fixed version published |
| GHSA-j7w6-vpvq-j3gm | diffusers | 0.38.0 | fix requires safetensors>=0.8.0rc0 pre-release; uv lock won't pick up pre-releases without explicit opt-in |
| GHSA-98h9-4798-4q5v | diffusers | 0.38.0 | same blocker as GHSA-j7w6-vpvq-j3gm — both fixed in 0.38.0 |
+| PYSEC-2025-189 | torch | (none) | no fix version published |
+| PYSEC-2025-190 | torch | (none) | same |
+| PYSEC-2025-191 | torch | (none) | same |
+| PYSEC-2025-192 | torch | (none) | same |
+| PYSEC-2025-193 | torch | (none) | same |
+| PYSEC-2025-194 | torch | (none) | same |
+| PYSEC-2025-195 | torch | (none) | same |
+| PYSEC-2025-196 | torch | (none) | same |
+| PYSEC-2025-197 | torch | (none) | same |
+| PYSEC-2025-210 | torch | (none) | same |
+| PYSEC-2026-139 | torch | (none) | same |
+| PYSEC-2025-211 | transformers | (none) | no fix version published; transformers also held by vllm-omni 0.18 |
+| PYSEC-2025-212 | transformers | (none) | same |
+| PYSEC-2025-213 | transformers | (none) | same |
+| PYSEC-2025-214 | transformers | (none) | same |
+| PYSEC-2025-215 | transformers | (none) | same |
+| PYSEC-2025-216 | transformers | (none) | same |
+| PYSEC-2025-217 | transformers | (none) | same |
+| PYSEC-2025-218 | transformers | (none) | same |
+| PYSEC-2026-97 | nltk | (none) | no fix version published |
+| PYSEC-2025-183 | pyjwt | (none) | no fix version published |
+| PYSEC-2024-277 | joblib | (none) | no fix version published |
+| PYSEC-2025-222 | vllm | (none) | no fix version published; held by vllm-omni 0.18 pin |
+| PYSEC-2024-274 | gradio | (none) | no fix version published; vllm-omni 0.18 pins gradio==5.50 |
When a blocker lifts (e.g. transformers 5 ↔ vllm 0.19 line stabilizes),
drop the corresponding `--ignore-vuln` flag from the workflow and the
diff --git a/docs/EXECUTORS.md b/docs/EXECUTORS.md
index ed02e903..155efb08 100644
--- a/docs/EXECUTORS.md
+++ b/docs/EXECUTORS.md
@@ -22,6 +22,26 @@ Helper utilities live in `src/worker/executors/utils/` (`artifacts`,
`src/worker/executors/mixins/` (`data`, `governance`, `inference`,
`training`).
+## Result schema
+
+Every executor's `run()` returns a subclass of `BaseExecutorResult`
+(`src/shared/schemas/result.py`). Besides `ok: bool` (default `True`),
+the base class carries two cross-cutting fields:
+
+- `children: dict[str, BaseExecutorResult]` — per-child results when
+ merged tasks share a dispatch.
+- `artifacts_: ArtifactContext | None` (wire key `_artifacts`) —
+ resolution context for relative artifact refs.
+
+Per-executor subclasses live next to the executor they describe — e.g.
+`VLLMResult` in `src/worker/executors/vllm_executor.py`, `LoRAResult` in
+`src/worker/executors/lora_sft_executor.py`. They add executor-specific
+fields (`items`, `usage`, `final_lora`, `command`, …).
+
+Artifact-bearing fields use `ArtifactRef` (`{"path": rel_path}`);
+relative paths resolve against the producer's `_artifacts` context via
+`artifact_to_source` / `_render_artifact_ref`.
+
## Agent executor (utu / youtu-agent)
`AgentExecutor` requires the following env vars to run; the executor
diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml
index 881e7689..97ca5abc 100644
--- a/sdk/pyproject.toml
+++ b/sdk/pyproject.toml
@@ -13,7 +13,7 @@ license-files = ["LICENSE"]
dependencies = [
"httpx>=0.27.0",
"pandas>=2.3.3",
- "pydantic>=2.0.0",
+ "pydantic>=2.12.3",
"pyyaml>=6.0.0",
]
diff --git a/sdk/src/flowmesh/models/__init__.py b/sdk/src/flowmesh/models/__init__.py
index e2043805..ea08aa5a 100644
--- a/sdk/src/flowmesh/models/__init__.py
+++ b/sdk/src/flowmesh/models/__init__.py
@@ -1,5 +1,6 @@
"""FlowMesh SDK data models."""
+from .artifacts import ArtifactContext, ArtifactRef
from .common import (
LogEntry,
LogEvent,
@@ -19,7 +20,7 @@
NodeWorkerInfo,
WorkerRegisterResponse,
)
-from .results import PathResponse, ResultEnvelope
+from .results import BaseExecutorResult, PathResponse, ResultEnvelope
from .tasks import HardwareUsage, TaskInfo, TaskUsage
from .traces import (
ActiveWaitBreakdown,
@@ -54,7 +55,10 @@
__all__ = [
"ActiveWaitBreakdown",
+ "ArtifactContext",
+ "ArtifactRef",
"AssetSummary",
+ "BaseExecutorResult",
"CPUInfo",
"CriticalPathSummary",
"E2EBreakdown",
diff --git a/sdk/src/flowmesh/models/artifacts.py b/sdk/src/flowmesh/models/artifacts.py
new file mode 100644
index 00000000..fdbb3edb
--- /dev/null
+++ b/sdk/src/flowmesh/models/artifacts.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel, Field
+
+
+class ArtifactRef(BaseModel):
+ path: str
+
+
+class ArtifactContext(BaseModel):
+ base_dir: str
+ base_url: str | None = Field(default=None, exclude_if=lambda v: v is None)
diff --git a/sdk/src/flowmesh/models/results.py b/sdk/src/flowmesh/models/results.py
index 3368f1bc..399ea7b8 100644
--- a/sdk/src/flowmesh/models/results.py
+++ b/sdk/src/flowmesh/models/results.py
@@ -1,8 +1,14 @@
"""Result-related models."""
+# This is necessary to allow for the recursive type hint of `children` in
+# `BaseExecutorResult`.
+from __future__ import annotations
+
from typing import Any
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
+
+from .artifacts import ArtifactContext
class PathResponse(BaseModel):
@@ -10,11 +16,30 @@ class PathResponse(BaseModel):
path: str
+class BaseExecutorResult(BaseModel):
+ model_config = ConfigDict(extra="allow", serialize_by_alias=True)
+
+ ok: bool = True
+ children: dict[str, SerializeAsAny[BaseExecutorResult]] = Field(
+ default_factory=dict, exclude_if=lambda v: not v
+ )
+ artifacts_: ArtifactContext | None = Field(default=None, alias="_artifacts")
+
+ @classmethod
+ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
+ super().__pydantic_init_subclass__(**kwargs)
+ if "artifacts_" in cls.__annotations__:
+ raise TypeError(
+ f"{cls.__name__} may not redefine the internal "
+ "BaseExecutorResult.artifacts_ field"
+ )
+
+
class ResultEnvelope(BaseModel):
"""Canonical on-disk shape of ``results.json`` (mirrors the server)."""
task_id: str
- result: dict[str, Any]
+ result: SerializeAsAny[BaseExecutorResult]
worker_id: str | None = None
metadata: dict[str, Any] | None = None
received_at: str | None = Field(default=None)
diff --git a/sdk/src/flowmesh/resources/results.py b/sdk/src/flowmesh/resources/results.py
index 1f5e99cf..3742106b 100644
--- a/sdk/src/flowmesh/resources/results.py
+++ b/sdk/src/flowmesh/resources/results.py
@@ -198,10 +198,9 @@ def _finalize_materialize(
return {}, json_path, extracted
envelope = ResultEnvelope.model_validate_json(json_path.read_text())
- if _wants_artifacts(sections):
- ctx = envelope.result["_artifacts"]
- ctx["base_dir"] = (output_dir / task_id).resolve().as_posix()
- ctx.pop("base_url", None)
+ if _wants_artifacts(sections) and (ctx := envelope.result.artifacts_):
+ ctx.base_dir = (output_dir / task_id).resolve().as_posix()
+ ctx.base_url = None
payload = envelope.model_dump(mode="json")
json_path.write_text(json.dumps(payload, indent=2))
return payload, json_path, extracted
diff --git a/src/server/dispatcher/base.py b/src/server/dispatcher/base.py
index d3c15bc3..501a9b2d 100644
--- a/src/server/dispatcher/base.py
+++ b/src/server/dispatcher/base.py
@@ -9,12 +9,18 @@
from pydantic import BaseModel, ValidationError
from shared.schemas.event import TaskEvent
-from shared.schemas.result import ResultEnvelope, result_file_path, write_result
+from shared.schemas.result import (
+ BaseExecutorResult,
+ ResultEnvelope,
+ result_file_path,
+ write_result,
+)
from shared.tasks import (
MergedChildTaskStrict,
TaskEnvelope,
TaskEnvelopeStrict,
TaskEnvelopeTemplate,
+ TaskSpecStrict,
)
from shared.tasks.placeholders import PLACEHOLDER_PATTERN
from shared.tasks.specs import (
@@ -32,6 +38,8 @@
from ..utils.time import now_iso
from .worker_selector import DEFAULT_WORKER_SELECTION, select_worker
+_SENTINEL: Any = object()
+
class StageReferenceNotReady(Exception):
"""Raised when a task references a stage whose artifacts are not yet available."""
@@ -707,7 +715,7 @@ def _resolve_stage_references(
self, task_id: str, task: TaskEnvelopeTemplate, record: TaskRecord
) -> TaskEnvelopeStrict:
context = self._build_stage_context(record)
- resolved_task = task
+ resolved_task: TaskEnvelopeTemplate = task
if context and task.has_placeholder():
resolved_task = self._resolve_placeholders(task, context)
@@ -716,7 +724,7 @@ def _resolve_stage_references(
)
if upstream_results:
existing = resolved_task.spec.upstreamResults or {}
- merged: dict[str, Any] = {}
+ merged: dict[str, BaseExecutorResult] = {}
if isinstance(existing, dict):
merged.update(copy.deepcopy(existing))
merged.update(upstream_results)
@@ -730,7 +738,7 @@ def _resolve_stage_references(
return TaskEnvelopeStrict.model_validate(resolved_task)
- def _resolve_placeholders(self, value: Any, context: dict[str, Any]) -> Any:
+ def _resolve_placeholders(self, value: Any, context: dict[str, TaskRecord]) -> Any:
if isinstance(value, str):
exact = PLACEHOLDER_PATTERN.fullmatch(value)
if exact:
@@ -814,7 +822,7 @@ def _stage_context_keys(record: TaskRecord) -> tuple[str, ...]:
keys.append(record.graph_node_name)
return tuple(keys)
- def _resolve_reference(self, expr: str, context: dict[str, Any]) -> Any:
+ def _resolve_reference(self, expr: str, context: dict[str, TaskRecord]) -> Any:
expr = expr.strip()
if not expr:
raise ValueError("Empty stage reference")
@@ -831,49 +839,47 @@ def _resolve_reference(self, expr: str, context: dict[str, Any]) -> Any:
return stage_record.task_id
if stage_record.status != "DONE":
raise StageReferenceNotReady(f"Stage '{stage_name}' has not completed")
- data = self._load_stage_result(stage_record.task_id)
- value = self._dig_path(data, path.split("."))
+ envelope = self._load_stage_result(stage_record.task_id)
+ value = self._dig_path(envelope.result, path.split("."))
if value is None:
raise ValueError(f"Missing value for reference '{expr}'")
# If the referenced value is an artifact ref ({path: "..."}), render
# it as a full URL (when base_url is set) or an absolute filesystem
# path using the producing stage's top-level _artifacts context.
- if rendered := self._render_artifact_ref(value, data):
+ if rendered := self._render_artifact_ref(value, envelope):
return rendered
return value
@staticmethod
- def _render_artifact_ref(value: Any, stage_result: Any) -> str | None:
+ def _render_artifact_ref(value: Any, stage_result: ResultEnvelope) -> str | None:
if not isinstance(value, dict):
return None
path_value = value.get("path")
if not isinstance(path_value, str) or not path_value:
return None
- if not isinstance(stage_result, dict):
- return None
- ctx = stage_result.get("_artifacts")
- if not isinstance(ctx, dict):
+ ctx = stage_result.result.artifacts_
+ if ctx is None:
return None
- base_url = ctx.get("base_url")
- base_dir = ctx.get("base_dir")
- if isinstance(base_url, str) and base_url and isinstance(base_dir, str):
+ base_url = ctx.base_url
+ base_dir = ctx.base_dir
+ if base_url and base_dir:
task_id = Path(base_dir).name
return f"{base_url.rstrip('/')}/api/v1/results/{task_id}/files/{path_value}"
- if isinstance(base_dir, str) and base_dir:
+ if base_dir:
return (Path(base_dir) / "artifacts" / path_value).as_posix()
return None
def _collect_upstream_results(
self, context: dict[str, TaskRecord], current_task_id: str
- ) -> dict[str, Any]:
- results: dict[str, Any] = {}
+ ) -> dict[str, BaseExecutorResult]:
+ results: dict[str, BaseExecutorResult] = {}
for name, record in context.items():
if not record or record.task_id == current_task_id:
continue
if record.status != TaskStatus.DONE:
continue
try:
- data = self._load_stage_result(record.task_id)
+ envelope = self._load_stage_result(record.task_id)
except StageReferenceNotReady as exc:
raise exc
except Exception as exc:
@@ -884,11 +890,11 @@ def _collect_upstream_results(
exc,
)
continue
- results[name] = data
+ results[name] = envelope.result
return results
def _resolve_upstream_task_ids(
- self, record: TaskRecord, spec: Any
+ self, record: TaskRecord, spec: TaskSpecStrict
) -> dict[str, str] | None:
if not isinstance(spec, SSHSpecStrict) or not spec.inputs:
return None
@@ -915,14 +921,14 @@ def _resolve_upstream_task_ids(
resolved[stage_name] = upstream.task_id
return resolved or None
- def _load_stage_result(self, stage_task_id: str) -> dict[str, Any]:
+ def _load_stage_result(self, stage_task_id: str) -> ResultEnvelope:
path = result_file_path(self._results_dir, stage_task_id)
if not path.exists():
raise StageReferenceNotReady(
f"Result for task {stage_task_id} not found at {path}"
)
content = json.loads(path.read_text(encoding="utf-8"))
- return ResultEnvelope.model_validate(content).result
+ return ResultEnvelope.model_validate(content)
def _dig_path(self, data: Any, parts: list[str]) -> Any:
current = data
@@ -946,6 +952,11 @@ def _dig_path(self, data: Any, parts: list[str]) -> Any:
return None
current = current[idx]
continue
+ if isinstance(current, BaseModel):  # fields/extras only, no attrs
+ current = dict(current).get(part, _SENTINEL)
+ if current is _SENTINEL:
+ return None
+ continue
return None
return current
@@ -992,7 +1003,7 @@ def _evaluate_condition_skip(
)
skip_envelope = ResultEnvelope(
task_id=task_id,
- result={},
+ result=BaseExecutorResult(),
metadata={
"skipped": True,
"reason": "condition_not_met",
diff --git a/src/server/routers/v1/results.py b/src/server/routers/v1/results.py
index 9edceb13..e38e8d76 100644
--- a/src/server/routers/v1/results.py
+++ b/src/server/routers/v1/results.py
@@ -4,7 +4,6 @@
import tarfile
import tempfile
from pathlib import Path
-from typing import Any
from fastapi import (
APIRouter,
@@ -20,12 +19,12 @@
from pydantic import ValidationError
from shared.schemas.result import (
+ BaseExecutorResult,
ResultEnvelope,
read_result,
result_file_path,
write_result,
)
-from shared.utils.atomic import atomic_write_text
from shared.utils.manifest import ARTIFACTS_DIR, LOGS_DIR, RESULTS_NAME, sync_manifest
from ...app_state import (
@@ -116,7 +115,7 @@ async def get_result(
principal: PrincipalContext = Depends(authenticate_connection),
results_dir: Path = Depends(get_results_dir),
logger: logging.Logger = Depends(get_logger),
-) -> dict[str, Any]:
+) -> BaseExecutorResult:
task_id = (task_id or "").strip()
if not task_id:
raise HTTPException(
@@ -186,8 +185,6 @@ async def upload_result_file(
detail=f"Failed to store artifact: {exc}",
) from exc
- _rewrite_jsonl_export_paths(task_id, base_dir, target_path)
-
record = runtime.get_record(task_id)
expected_artifacts: list[str] = []
if record:
@@ -333,55 +330,6 @@ async def download_task_logs(
return FileResponse(target_path)
-def _rewrite_jsonl_export_paths(
- task_id: str, base_dir: Path, artifact_path: Path
-) -> None:
- results_path = base_dir / RESULTS_NAME
- if not results_path.exists():
- return
- try:
- payload = json.loads(results_path.read_text(encoding="utf-8"))
- except Exception:
- return
-
- filename = artifact_path.name
- new_abs = str(artifact_path)
- new_relative = f"{ARTIFACTS_DIR}/{filename}"
- updated = False
-
- def _update(entry: dict[str, Any]) -> None:
- nonlocal updated
- if not isinstance(entry, dict):
- return
- block = entry.get("jsonl_export")
- if isinstance(block, dict):
- path_value = str(block.get("path") or "")
- if path_value and Path(path_value).name == filename:
- if "worker_path" not in block and path_value != new_abs:
- block["worker_path"] = path_value
- block["path"] = new_abs
- block["relative_path"] = new_relative
- block.setdefault("url", f"/api/v1/results/{task_id}/files/{filename}")
- updated = True
- children = entry.get("children")
- if isinstance(children, dict):
- for child_entry in children.values():
- if isinstance(child_entry, dict):
- _update(child_entry)
-
- result_entry = payload.get("result") if isinstance(payload, dict) else None
- if isinstance(result_entry, dict):
- _update(result_entry)
-
- if updated:
- try:
- atomic_write_text(
- results_path, json.dumps(payload, ensure_ascii=False, indent=2)
- )
- except Exception:
- pass
-
-
def _resolve_bundle_sections(include: list[str]) -> tuple[str, ...]:
if not include:
return _BUNDLE_SECTIONS_DEFAULT
diff --git a/src/shared/schemas/artifact.py b/src/shared/schemas/artifact.py
new file mode 100644
index 00000000..724d8819
--- /dev/null
+++ b/src/shared/schemas/artifact.py
@@ -0,0 +1,14 @@
+from pydantic import BaseModel, Field
+
+
+class ArtifactRef(BaseModel):
+ path: str = Field(description="Path relative to the task's artifacts/ dir.")
+
+
+class ArtifactContext(BaseModel):
+ base_dir: str = Field(description="Producing task's output directory.")
+ base_url: str | None = Field(
+ default=None,
+ exclude_if=lambda v: v is None,
+ description="HTTP origin (scheme://host[:port]) for upload.",
+ )
diff --git a/src/shared/schemas/result.py b/src/shared/schemas/result.py
index 0eea3e5c..e4378f90 100644
--- a/src/shared/schemas/result.py
+++ b/src/shared/schemas/result.py
@@ -1,18 +1,57 @@
"""Result envelope schema shared by server and worker."""
+# This is necessary to allow for the recursive type hint of `children` in
+# `BaseExecutorResult`.
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from shared.utils.atomic import atomic_write_text
from shared.utils.manifest import prepare_output_dir
from shared.utils.time import now_iso
+from .artifact import ArtifactContext
+
+
+class BaseExecutorResult(BaseModel):
+ """Common shape for every executor's result payload.
+
+ ``extra="allow"`` lets the server round-trip subclass payloads through
+ this base class without losing executor-specific fields.
+ """
+
+ model_config = ConfigDict(extra="allow", serialize_by_alias=True)
+
+ ok: bool = Field(default=True, description="Whether task execution succeeded.")
+ children: dict[str, SerializeAsAny[BaseExecutorResult]] = Field(
+ default_factory=dict,
+ exclude_if=lambda v: not v,
+ description="Per-child result payloads for task merging.",
+ )
+ artifacts_: ArtifactContext | None = Field(
+ default=None,
+ alias="_artifacts",
+ description="Resolution context for relative artifact refs.",
+ )
+
+ @classmethod
+ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
+ super().__pydantic_init_subclass__(**kwargs)
+ if "artifacts_" in cls.__annotations__:
+ raise TypeError(
+ f"{cls.__name__} may not redefine the internal "
+ "BaseExecutorResult.artifacts_ field"
+ )
+
class ResultEnvelope(BaseModel):
task_id: str = Field(description="Task identifier.")
- result: dict[str, Any] = Field(description="Result payload data.")
+ result: SerializeAsAny[BaseExecutorResult] = Field(
+ description="Result payload data."
+ )
worker_id: str | None = Field(
default=None, description="Worker identifier submitting the result."
)
@@ -40,13 +79,6 @@ def write_result(base_dir: Path, envelope: ResultEnvelope) -> Path:
return path
-def write_result_in_envelope(path: Path, task_id: str, result: dict[str, Any]) -> None:
- """Wrap ``result`` in a ``ResultEnvelope`` and persist it at ``path``."""
- path.parent.mkdir(parents=True, exist_ok=True)
- envelope = ResultEnvelope(task_id=task_id, result=result)
- atomic_write_text(path, envelope.model_dump_json(indent=2))
-
-
def read_result(base_dir: Path, task_id: str) -> str:
path = result_file_path(base_dir, task_id)
return path.read_text(encoding="utf-8")
diff --git a/src/shared/tasks/specs/common.py b/src/shared/tasks/specs/common.py
index d6cb61db..e841e596 100644
--- a/src/shared/tasks/specs/common.py
+++ b/src/shared/tasks/specs/common.py
@@ -2,6 +2,7 @@
from pydantic import Field, model_validator
+from ...schemas.result import BaseExecutorResult
from .._base import StrictBaseModel, TemplateBaseModel
from ..components import (
AdapterConfig,
@@ -72,7 +73,7 @@ class TaskSpecStrictBase(StrictBaseModel):
shard: ShardSpec | None = None
# Server-injected stage context (reserve the user-facing key `_upstreamResults`)
- upstreamResults: dict[str, Any] | None = Field(
+ upstreamResults: dict[str, BaseExecutorResult] | None = Field(
default=None, alias="_upstreamResults"
)
@@ -98,7 +99,7 @@ class TaskSpecTemplateBase(TemplateBaseModel):
condition: ConditionSpec | None = None
shard: ShardSpecTemplate | None = None
- upstreamResults: dict[str, Any] | None = Field(
+ upstreamResults: dict[str, BaseExecutorResult] | None = Field(
default=None, alias="_upstreamResults"
)
diff --git a/src/worker/executors/agent_executor.py b/src/worker/executors/agent_executor.py
index c12d4003..db59687b 100644
--- a/src/worker/executors/agent_executor.py
+++ b/src/worker/executors/agent_executor.py
@@ -16,14 +16,12 @@
from datasets import load_dataset
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import AgentSpecStrict
from .base_executor import ExecutionError, Executor, ExecutorTask
-from .utils.checkpoints import (
- artifact_ref,
- maybe_upload_artifacts,
- write_executor_result,
-)
+from .utils.checkpoints import maybe_upload_artifacts, write_executor_result
from .utils.graph_templates import build_prompts_from_graph_template
# Add agent directory to sys.path for utu imports
@@ -54,6 +52,16 @@ def _resolve_task_timeout(agent: dict[str, Any] | None) -> int:
logger = logging.getLogger("worker.agent")
+class AgentResult(BaseExecutorResult):
+ ok: bool = True
+ model: str
+ items: list[dict[str, Any]] = []
+ usage: dict[str, Any] | None = None
+ metadata: dict[str, Any] | None = None
+ agent_output: ArtifactRef | None = None
+ batch_summary_file: ArtifactRef | None = None
+
+
class AgentExecutor(Executor):
"""Agent executor using youtu-agent (utu) framework"""
@@ -223,7 +231,7 @@ def prepare_data(self, spec: AgentSpecStrict) -> None:
else:
raise ExecutionError(f"Unsupported spec.data.type: {dtype!r}")
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> AgentResult:
"""Execute agent tasks using youtu-agent (utu) framework"""
self.ensure_dir(out_dir)
@@ -248,32 +256,34 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
agent_config_name, self._tasks[0], out_dir, task_timeout
)
)
+ agent_output_ref = result.get("_agent_output_ref")
- output: dict[str, Any] = {
- "ok": True,
- "model": agent_config_name,
- "items": [
+ output = AgentResult(
+ model=agent_config_name,
+ items=[
{
"index": 0,
"output": result.get("output", ""),
"finish_reason": "completed",
}
],
- "usage": {
+ usage={
"execution_time_sec": result.get("usage", {}).get(
"execution_time_sec", 0
),
"num_requests": 1,
"agent_config": agent_config_name,
},
- "metadata": {
+ metadata={
"task": self._tasks[0],
"execution_log": result.get("log", []),
},
- }
- agent_output_ref = result.get("_agent_output_ref")
- if isinstance(agent_output_ref, str):
- output["agent_output"] = artifact_ref(agent_output_ref)
+ agent_output=(
+ ArtifactRef(path=agent_output_ref)
+ if isinstance(agent_output_ref, str)  # skip non-string refs
+ else None
+ ),
+ )
else:
# Batch execution for multiple tasks
results = asyncio.run(
@@ -297,26 +307,28 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
}
)
- output = {
- "ok": True,
- "model": agent_config_name,
- "items": items,
- "usage": {
+ batch_summary_ref = results.get("_batch_summary_ref")
+ output = AgentResult(
+ model=agent_config_name,
+ items=items,
+ usage={
"execution_time_sec": results.get("usage", {}).get(
"execution_time_sec", 0
),
"num_requests": len(self._tasks),
"agent_config": agent_config_name,
},
- "metadata": {
+ metadata={
"tasks_count": len(self._tasks),
"execution_log": results.get("log", []),
"batch_summary": results.get("batch_summary", {}),
},
- }
- batch_summary_ref = results.get("_batch_summary_ref")
- if isinstance(batch_summary_ref, str):
- output["batch_summary_file"] = artifact_ref(batch_summary_ref)
+ batch_summary_file=(
+ ArtifactRef(path=batch_summary_ref)
+ if isinstance(batch_summary_ref, str)  # skip non-string refs
+ else None
+ ),
+ )
maybe_upload_artifacts(task, out_dir, logger=logger)
@@ -327,21 +339,21 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
raise
except Exception as e:
logger.exception(f"Agent execution failed: {e}")
- error_output = {
- "ok": False,
- "model": agent_config_name,
- "items": [],
- "usage": {
+ error_output = AgentResult(
+ ok=False,
+ model=agent_config_name,
+ items=[],
+ usage={
"execution_time_sec": 0,
"num_requests": len(self._tasks),
"agent_config": agent_config_name,
},
- "metadata": {
+ metadata={
"tasks_count": len(self._tasks),
"error": str(e),
"execution_log": [],
},
- }
+ )
write_executor_result(
out_dir / "results.json", task.task_id, task.spec, error_output
)
diff --git a/src/worker/executors/api_executor.py b/src/worker/executors/api_executor.py
index 4129ce0f..bc0e9bde 100644
--- a/src/worker/executors/api_executor.py
+++ b/src/worker/executors/api_executor.py
@@ -5,7 +5,9 @@
from typing import Any, ClassVar
import httpx
+from pydantic import Field
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import ApiSpecStrict
from .base_executor import ExecutionError, Executor, ExecutorTask
@@ -16,6 +18,18 @@
_ClientKey = tuple[str, float, bool, bool]
+class APIResult(BaseExecutorResult):
+ executor: str
+ method: str
+ url: str
+ status_code: int
+ truncated: bool = False
+ headers: dict[str, str] | None = None
+ response_json: Any = Field(default=None, alias="json")
+ usage: dict[str, Any] | None = None
+ text: str | None = None
+
+
class APIExecutor(Executor):
"""Executor that performs a single HTTP request defined by task YAML.
@@ -89,7 +103,7 @@ def cleanup_after_run(self) -> None:
"""Close the connection pool when the runner deactivates this executor."""
self.close_all_clients()
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> APIResult:
spec = self.require_spec(task, ApiSpecStrict)
api_cfg = spec.api or {}
if not isinstance(api_cfg, dict):
@@ -173,17 +187,17 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
body_bytes = body_bytes[:max_body_bytes]
truncated = True
- result: dict[str, Any] = {
- "ok": resp.is_success,
- "executor": self.name,
- "method": method,
- "url": str(resp.url),
- "status_code": resp.status_code,
- "truncated": truncated,
- }
+ result = APIResult(
+ ok=resp.is_success,
+ executor=self.name,
+ method=method,
+ url=str(resp.url),
+ status_code=resp.status_code,
+ truncated=truncated,
+ )
if include_headers:
- result["headers"] = dict(resp.headers)
+ result.headers = dict(resp.headers)
body_text: str | None = None
if return_body:
@@ -191,26 +205,25 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
body_text = body_bytes.decode(encoding, errors="replace")
if parse_json:
- result["json"] = resp.json()
- if not isinstance(result["json"], dict):
+ result.response_json = resp.json()
+ if not isinstance(result.response_json, dict):
raise ExecutionError("Response is not a valid JSON mapping")
- if isinstance(result["json"], dict):
- usage = result["json"].get("usage")
- if not isinstance(usage, dict):
- raise ExecutionError(
- "spec.api.response.parse_json is true but response JSON "
- f"does not contain usage info: {result['json']}"
- )
- result["usage"] = usage
+ usage = result.response_json.get("usage")
+ if not isinstance(usage, dict):
+ raise ExecutionError(
+ "spec.api.response.parse_json is true but response JSON "
+ f"does not contain usage info: {result.response_json}"
+ )
+ result.usage = usage
try:
- result["text"] = result["json"]["choices"][0]["message"]["content"]
+ result.text = result.response_json["choices"][0]["message"]["content"]
except Exception as exc:
raise ExecutionError(
"spec.api.response.parse_json is true but response JSON "
- f"does not contain message.content: {result['json']}"
+ f"does not contain message.content: {result.response_json}"
) from exc
elif return_body:
- result["text"] = body_text
+ result.text = body_text
if raise_for_status and resp.is_error:
if body_text:
diff --git a/src/worker/executors/base_executor.py b/src/worker/executors/base_executor.py
index f4f43f90..f5a20245 100644
--- a/src/worker/executors/base_executor.py
+++ b/src/worker/executors/base_executor.py
@@ -2,19 +2,23 @@
Executor base class and a minimal example implementation.
Usage:
- from executor_base import Executor, ExecutionError, EchoExecutor
+ from shared.schemas.result import BaseExecutorResult
+ from worker.executors.base_executor import Executor, ExecutionError
+
+ class MyResult(BaseExecutorResult):
+ echo: str
class MyExecutor(Executor):
name = "my-executor"
- def run(self, task: ExecutorTask, out_dir: Path) -> dict:
+ def run(self, task: ExecutorTask, out_dir: Path) -> MyResult:
# ... your logic ...
- return {"ok": True, "echo": task.task_id}
+ return MyResult(echo=task.task_id)
Contract:
-- Implement `run(task: ExecutorTask, out_dir: Path) -> dict`. The runner
- writes the returned dict to `out_dir/results.json` and injects the
- top-level `_artifacts` context — executors should not write that file
- themselves on the success path.
+- Implement `run(task: ExecutorTask, out_dir: Path) -> BaseExecutorResult`.
+ The runner writes the returned model to `out_dir/results.json` and
+ injects the top-level `_artifacts` context — executors should not write
+ that file themselves on the success path.
- Drop generated files under `out_dir/artifacts/` (uploaded to the server
when the task has an HTTP destination) or `scratch_dir(out_dir)` for
local-only scratch data.
@@ -27,6 +31,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict:
from pathlib import Path
from typing import Any, TypeVar
+from shared.schemas.result import BaseExecutorResult
from shared.tasks import MergedChildTaskStrict
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.worker_message import WorkerHardware, WorkerTaskMessage
@@ -84,7 +89,7 @@ def prepare(self) -> None:
return None
@abstractmethod
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> BaseExecutorResult:
"""Execute a single task.
Args:
@@ -93,7 +98,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
if needed.
Returns:
- A JSON-serializable dictionary summarizing the result.
+ A ``BaseExecutorResult`` subclass instance.
Raises:
ExecutionError: for expected, user-facing failures.
@@ -139,13 +144,19 @@ def load_json(path: Path) -> dict[str, Any]:
# -------- Minimal example implementation --------
+class EchoResult(BaseExecutorResult):
+ ok: bool = True
+ executor: str
+ task_id: str
+ task_type: str
+
+
class EchoExecutor(Executor):
name = "echo"
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
- return {
- "ok": True,
- "executor": self.name,
- "task_id": task.task_id,
- "task_type": task.spec.taskType,
- }
+ def run(self, task: ExecutorTask, out_dir: Path) -> EchoResult:
+ return EchoResult(
+ executor=self.name,
+ task_id=task.task_id,
+ task_type=task.spec.taskType,
+ )
diff --git a/src/worker/executors/data_profiling_executor.py b/src/worker/executors/data_profiling_executor.py
index bfb7678f..57f4249c 100644
--- a/src/worker/executors/data_profiling_executor.py
+++ b/src/worker/executors/data_profiling_executor.py
@@ -8,6 +8,7 @@
from pathlib import Path
from typing import Any
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import DataProfilingSpecStrict
from shared.utils.json import to_json_serializable, validate_keys
@@ -19,19 +20,25 @@
logger = logging.getLogger(__name__)
+class DataProfilingResult(BaseExecutorResult):
+ ok: bool = True
+ type: str = "sql"
+ template: str | None = None
+ cost_estimates: dict[str, Any] | None = None
+
+
class DataProfilingExecutor(DataMixin, Executor):
"""Executor that estimates SQL query costs by sampling SQL template params."""
name = "data_profiling"
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> DataProfilingResult:
spec = self.require_spec(task, DataProfilingSpecStrict)
task_id = task.task_id
merge_children = task.merged_children or []
result = self._run_single_profile(spec, task_id)
- child_results: dict[str, dict[str, Any]] = {}
for child in merge_children:
child_id = child.task_id
child_spec = child.spec
@@ -39,16 +46,13 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
raise ExecutionError(
"Merged child spec must be data_profiling for merged profiling"
)
- child_results[child_id] = self._run_single_profile(child_spec, child_id)
-
- if child_results:
- result["children"] = child_results
+ result.children[child_id] = self._run_single_profile(child_spec, child_id)
return result
def _run_single_profile(
self, spec: DataProfilingSpecStrict, task_id: str
- ) -> dict[str, Any]:
+ ) -> DataProfilingResult:
data_cfg = spec.data
if not isinstance(data_cfg, dict):
raise ExecutionError(
@@ -100,14 +104,12 @@ def _run_single_profile(
connection_string, queries, params_rows
)
- result: dict[str, Any] = {
- "ok": True,
- "type": "sql",
- "template": template_str,
- "cost_estimates": cost_estimates,
- }
-
- return result
+ return DataProfilingResult(
+ ok=True,
+ type="sql",
+ template=template_str,
+ cost_estimates=cost_estimates,
+ )
def _sample_template_queries(
self,
diff --git a/src/worker/executors/data_retrieval_executor.py b/src/worker/executors/data_retrieval_executor.py
index 0321f147..4347a262 100644
--- a/src/worker/executors/data_retrieval_executor.py
+++ b/src/worker/executors/data_retrieval_executor.py
@@ -10,6 +10,8 @@
import pandas as pd
from PIL import Image
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import DataRetrievalSpecStrict
from shared.utils.json import validate_keys
@@ -17,20 +19,23 @@
from ..utils.serialization import serialize_dataframe
from .base_executor import ExecutionError, Executor, ExecutorTask
from .mixins.data import DataMixin
-from .utils.checkpoints import (
- artifact_ref,
- maybe_upload_artifacts,
- maybe_upload_traces,
-)
+from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces
from .utils.graph_templates import _render_template, _resolve_columns
logger = logging.getLogger(__name__)
+class DataRetrievalResult(BaseExecutorResult):
+ type: str | None = None
+ items: list[dict[str, Any]] = []
+ count: int | None = None
+ metadata: dict[str, Any] | None = None
+
+
class DataRetrievalExecutor(DataMixin, Executor):
name = "data_retrieval"
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> DataRetrievalResult:
spec = self.require_spec(task, DataRetrievalSpecStrict)
task_id = task.task_id
with self._task_span(
@@ -71,15 +76,15 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
return result
def _run_sql(
- self, data_cfg: dict[str, Any], context: dict[str, Any]
- ) -> dict[str, Any]:
+ self, data_cfg: dict[str, Any], context: dict[str, BaseExecutorResult]
+ ) -> DataRetrievalResult:
"""
Execute SQL queries based on the provided data configuration and context.
:param data_cfg: Description
:type data_cfg: dict[str, Any]
:param context: Description
- :type context: dict[str, Any]
+ :type context: dict[str, BaseExecutorResult]
:return: Description
- :rtype: dict[str, Any]
+ :rtype: DataRetrievalResult
"""
@@ -144,18 +149,17 @@ def _run_sql(
}
)
- return {
- "ok": True,
- "items": items,
- "count": len(items),
- }
+ return DataRetrievalResult(
+ items=items,
+ count=len(items),
+ )
def _run_s3(
self,
data_cfg: dict[str, Any],
- context: dict[str, Any],
+ context: dict[str, BaseExecutorResult],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> DataRetrievalResult:
validate_keys(
data_cfg,
"DataRetrievalExecutor.spec.data",
@@ -207,20 +211,18 @@ def _run_s3(
item["params"] = params_rows
items.append(item)
- result = {
- "ok": True,
- "type": "s3",
- "items": items,
- "metadata": s3_result["metadata"], # type: ignore
- }
- return result
+ return DataRetrievalResult(
+ type="s3",
+ items=items,
+ metadata=s3_result["metadata"], # type: ignore
+ )
def _run_agent(
self,
data_cfg: dict[str, Any],
- context: dict[str, Any],
+ context: dict[str, BaseExecutorResult],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> DataRetrievalResult:
"""Drive lumid.data's data agent for NL-driven retrieval."""
validate_keys(
data_cfg,
@@ -314,12 +316,11 @@ def _run_agent(
}
)
- return {
- "ok": True,
- "type": "agent",
- "items": items,
- "count": len(items),
- }
+ return DataRetrievalResult(
+ type="agent",
+ items=items,
+ count=len(items),
+ )
def _load_table(self, path: Path, output_format: str) -> pd.DataFrame:
"""Load the materialized retrieval file into a DataFrame."""
@@ -338,5 +339,5 @@ def _serialize_s3_content(self, content: Any, out_dir: Path) -> Any:
filename = f"{uuid.uuid4().hex}.png"
file_path = images_dir / filename
content.save(file_path, format="PNG")
- return artifact_ref(f"s3_images/{filename}")
+ return ArtifactRef(path=f"s3_images/{filename}")
return content
diff --git a/src/worker/executors/diffusers_executor.py b/src/worker/executors/diffusers_executor.py
index f74fb8e4..0c782fce 100644
--- a/src/worker/executors/diffusers_executor.py
+++ b/src/worker/executors/diffusers_executor.py
@@ -14,16 +14,14 @@
from PIL import Image
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import DiffusionSpecStrict
from ..utils.logging import configure_hf_library_logging
from .base_executor import ExecutionError, Executor, ExecutorTask
from .mixins.data import DataMixin
-from .utils.checkpoints import (
- artifact_ref,
- maybe_upload_artifacts,
- maybe_upload_traces,
-)
+from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces
try:
import torch
@@ -45,6 +43,12 @@
logger = logging.getLogger(__name__)
+class DiffusersResult(BaseExecutorResult):
+ ok: bool = True
+ model: str | None = None
+ images: list[ArtifactRef] = []
+
+
class DiffusersExecutor(DataMixin, Executor):
"""Executor that runs text-to-image generation via Hugging Face Diffusers."""
@@ -240,24 +244,24 @@ def _encode_and_combine_prompts(
return combined_pos, combined_neg, user_pos_pooled, user_neg_pooled
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> DiffusersResult:
configure_hf_library_logging()
spec = self.require_spec(task, DiffusionSpecStrict)
task_id = task.task_id.strip()
with self._task_span(
task_id, task.workflow_id, out_dir, owner_id=task.owner_id
):
- response = self._run_inner(spec, task_id, out_dir)
+ result = self._run_inner(spec, task_id, out_dir)
maybe_upload_artifacts(task, out_dir, logger=logger)
maybe_upload_traces(task, out_dir, logger=logger)
- return response
+ return result
def _run_inner(
self,
spec: DiffusionSpecStrict,
task_id: str,
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> DiffusersResult:
self._ensure_pipeline(spec)
assert self._pipe is not None
@@ -336,24 +340,18 @@ def _run_inner(
image_dir = out_dir / "artifacts" / "images"
image_dir.mkdir(parents=True, exist_ok=True)
- generated_images: list[dict[str, str]] = []
+ generated_images: list[ArtifactRef] = []
for idx, img in enumerate(images):
img.save(image_dir / f"image_{idx}.png", format="PNG")
- generated_images.append(artifact_ref(f"images/image_{idx}.png"))
-
- response: dict[str, Any] = {
- "ok": True,
- "model": self._model_name,
- "images": generated_images,
- }
+ generated_images.append(ArtifactRef(path=f"images/image_{idx}.png"))
+ result = DiffusersResult(model=self._model_name, images=generated_images)
self._dump_to_governance(
task_id=task_id,
- result=response,
+ result=result,
dependencies_by_task=dependencies_by_task,
)
-
- return response
+ return result
def cleanup_after_run(self) -> None:
if self._pipe is not None:
diff --git a/src/worker/executors/dpo_executor.py b/src/worker/executors/dpo_executor.py
index 7dc1e8d2..b62256f2 100644
--- a/src/worker/executors/dpo_executor.py
+++ b/src/worker/executors/dpo_executor.py
@@ -25,6 +25,8 @@
from trl.trainer.dpo_config import DPOConfig
from trl.trainer.dpo_trainer import DPOTrainer
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import DPOSpecStrict
from shared.utils.manifest import scratch_dir
@@ -33,7 +35,6 @@
from .mixins.training import TrainingMixin
from .utils.checkpoints import (
archive_model_dir,
- artifact_ref,
get_http_destination,
maybe_upload_artifacts,
write_executor_result,
@@ -45,6 +46,18 @@
logger = logging.getLogger("worker.dpo")
+class DPOResult(BaseExecutorResult):
+ training_time_seconds: float | None = None
+ error_message: str | None = None
+ model_name: str | None = None
+ dataset_size: int = 0
+ output_dir: str | None = None
+ checkpoints_dir: ArtifactRef | None = None
+ final_model: ArtifactRef | None = None
+ final_model_archive: ArtifactRef | None = None
+ spawned_torchrun: bool = False
+
+
class DPOExecutor(TrainingMixin, Executor):
"""DPO training executor using TRL library."""
@@ -59,7 +72,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._current_trainer: DPOTrainer | None = None
self._task_out_dir: Path | None = None
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> DPOResult:
configure_hf_library_logging()
logger.info("Starting DPO training task")
spec = self.require_spec(task, DPOSpecStrict)
@@ -80,22 +93,21 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
)
ipc_path = scratch_dir(out_dir) / "distributed_result.json"
if ipc_path.exists():
- return self.load_json(ipc_path)
- return {
- "training_successful": True,
- "spawned_torchrun": True,
- "model_name": (
- spec.model
- and spec.model.source
- and spec.model.source.identifier
+ return DPOResult.model_validate(self.load_json(ipc_path))
+ return DPOResult(
+ spawned_torchrun=True,
+ model_name=(
+ spec.model.source.identifier
+ if spec.model and spec.model.source
+ else None
),
- "output_dir": out_dir.as_posix(),
- }
+ output_dir=out_dir.as_posix(),
+ )
result = self._execute_training(task, out_dir)
logger.info(
"DPO training task completed in %.2f seconds",
- result.get("training_time_seconds", 0.0),
+ result.training_time_seconds or 0.0,
)
return result
finally:
@@ -354,7 +366,7 @@ def _ensure_jsonl_local(jsonl_cfg: dict[str, Any]) -> Path:
return dataset # type: ignore[return-value]
- def _execute_training(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def _execute_training(self, task: ExecutorTask, out_dir: Path) -> DPOResult:
spec = self.require_spec(task, DPOSpecStrict)
training_config = spec.training or {}
artifacts_dir = out_dir / "artifacts"
@@ -462,23 +474,28 @@ def _execute_training(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]
logger.warning("Failed to save model: %s", exc)
training_time = time.time() - start_time
- results: dict[str, Any] = {
- "training_successful": True,
- "training_time_seconds": training_time,
- "error_message": None,
- "model_name": self._model_name,
- "dataset_size": len(dataset),
- "output_dir": out_dir.as_posix(),
- "checkpoints_dir": artifact_ref("checkpoints"),
- }
- if final_model_path is not None:
- results["final_model"] = artifact_ref(
- final_model_path.relative_to(artifacts_dir).as_posix()
- )
- if final_archive_path is not None:
- results["final_model_archive"] = artifact_ref(
- final_archive_path.relative_to(artifacts_dir).as_posix()
- )
+ result = DPOResult(
+ training_time_seconds=training_time,
+ error_message=None,
+ model_name=self._model_name,
+ dataset_size=len(dataset) if dataset is not None else 0,
+ output_dir=out_dir.as_posix(),
+ checkpoints_dir=ArtifactRef(path="checkpoints"),
+ final_model=(
+ ArtifactRef(
+ path=final_model_path.relative_to(artifacts_dir).as_posix()
+ )
+ if final_model_path
+ else None
+ ),
+ final_model_archive=(
+ ArtifactRef(
+ path=final_archive_path.relative_to(artifacts_dir).as_posix()
+ )
+ if final_archive_path
+ else None
+ ),
+ )
maybe_upload_artifacts(task, out_dir, logger=logger)
@@ -488,24 +505,22 @@ def _execute_training(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]
final_model_path,
final_archive_path,
)
- return results
+ return result
except Exception as exc:
training_time = time.time() - start_time
- results = {
- "training_successful": False,
- "training_time_seconds": training_time,
- "error_message": str(exc),
- "model_name": self._model_name,
- "dataset_size": len(dataset) if dataset is not None else 0,
- "output_dir": out_dir.as_posix(),
- }
+ result = DPOResult(
+ ok=False,
+ training_time_seconds=training_time,
+ error_message=str(exc),
+ model_name=self._model_name,
+ dataset_size=len(dataset) if dataset is not None else 0,
+ output_dir=out_dir.as_posix(),
+ )
write_executor_result(
- out_dir / "results.json", task.task_id, task.spec, results
+ out_dir / "results.json", task.task_id, task.spec, result
)
logger.exception("DPO training failed: %s", exc)
- raise ExecutionError(
- results["error_message"] or "DPO training failed"
- ) from exc
+ raise ExecutionError(result.error_message or "DPO training failed") from exc
def _spawn_distributed(
self,
diff --git a/src/worker/executors/echo_executor.py b/src/worker/executors/echo_executor.py
index 480938a4..1a03662f 100644
--- a/src/worker/executors/echo_executor.py
+++ b/src/worker/executors/echo_executor.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Any
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import EchoSpecStrict
from .base_executor import ExecutionError, Executor, ExecutorTask
@@ -14,6 +15,11 @@
type EchoItem = str | dict[str, str]
+class EchoResult(BaseExecutorResult):
+ items: list[dict[str, Any]] = []
+ count: int = 0
+
+
class EchoExecutor(DataMixin, Executor):
name = "echo"
@@ -25,7 +31,9 @@ def _append_outputs(self, out_items: list[dict[str, Any]], value: Any) -> None:
out_items.append({"output": value})
@staticmethod
- def _resolve_expr_item(item: dict[str, Any], context: dict[str, Any]) -> Any:
+ def _resolve_expr_item(
+ item: dict[str, Any], context: dict[str, BaseExecutorResult]
+ ) -> Any:
expr = item.get("expr")
if not expr:
node = item.get("node")
@@ -44,7 +52,9 @@ def _resolve_expr_item(item: dict[str, Any], context: dict[str, Any]) -> Any:
)
return resolved
- def _resolve_item(self, item: EchoItem, context: dict[str, Any]) -> Any:
+ def _resolve_item(
+ self, item: EchoItem, context: dict[str, BaseExecutorResult]
+ ) -> Any:
if isinstance(item, str):
return item
elif isinstance(item, dict):
@@ -55,7 +65,7 @@ def _resolve_item(self, item: EchoItem, context: dict[str, Any]) -> Any:
"a string literal or a mapping"
)
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> EchoResult:
spec = self.require_spec(task, EchoSpecStrict)
task_id = task.task_id.strip()
with self._task_span(
@@ -81,18 +91,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
resolved = self._resolve_item(item, context)
self._append_outputs(merged_items, resolved)
- payload: dict[str, Any] = {
- "ok": True,
- "items": merged_items,
- "count": len(merged_items),
- }
+ result = EchoResult(items=merged_items, count=len(merged_items))
deps = self._extract_source_data_ids(spec)
dependencies_by_task = {task_id: deps}
self._dump_to_governance(
task_id=task_id,
- result=payload,
+ result=result,
dependencies_by_task=dependencies_by_task,
)
maybe_upload_traces(task, out_dir, logger=logger)
- return payload
+ return result
diff --git a/src/worker/executors/lora_sft_executor.py b/src/worker/executors/lora_sft_executor.py
index f40c99ad..9e36e220 100644
--- a/src/worker/executors/lora_sft_executor.py
+++ b/src/worker/executors/lora_sft_executor.py
@@ -18,6 +18,8 @@
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import LoRASFTSpecStrict
from ..utils.logging import configure_hf_library_logging
@@ -26,7 +28,6 @@
from .sft_executor import SFTExecutor
from .utils.checkpoints import (
archive_model_dir,
- artifact_ref,
determine_resume_path,
maybe_upload_artifacts,
write_executor_result,
@@ -48,6 +49,18 @@
logger = logging.getLogger("worker.sft.lora")
+class LoRAResult(BaseExecutorResult):
+ training_time_seconds: float | None = None
+ error_message: str | None = None
+ model_name: str | None = None
+ dataset_size: int = 0
+ output_dir: str | None = None
+ checkpoints_dir: ArtifactRef | None = None
+ resume_from_path: str | None = None
+ final_lora: ArtifactRef | None = None
+ final_lora_archive: ArtifactRef | None = None
+
+
class LoRASFTExecutor(TrainingMixin, Executor):
"""Execute LoRA-based supervised fine-tuning using TRL's SFTTrainer."""
@@ -59,11 +72,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._current_model: Any | None = None
self._current_trainer: Any | None = None
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> LoRAResult:
configure_hf_library_logging()
spec = self.require_spec(task, LoRASFTSpecStrict)
start_time = time.time()
- training_successful = False
+ ok = False
error_msg: str | None = None
if (
@@ -99,6 +112,8 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
checkpoint_dir = artifacts_dir / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ resume_path: Path | None = None
+ train_dataset: Dataset | None = None
try:
model_name = spec.model_name or "gpt2"
self._model_name = model_name
@@ -121,7 +136,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
resume_path = determine_resume_path(
spec, training_cfg, out_dir, logger=logger
)
- resume_str = str(resume_path) if resume_path else None
+ resume_str = resume_path.as_posix() if resume_path else None
if bool(training_cfg.get("gradient_checkpointing", False)):
model.gradient_checkpointing_enable()
@@ -139,9 +154,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
"Resuming LoRA training from local checkpoint %s", resume_path
)
peft_model = PeftModel.from_pretrained(
- model,
- str(resume_path),
- is_trainable=True,
+ model, resume_path.as_posix(), is_trainable=True
)
logger.info("Loaded existing LoRA adapters; continuing fine-tuning")
else:
@@ -182,7 +195,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
logger.info("Initialized new LoRA adapters: %s", lora_target_modules)
sft_config = SFTConfig(
- output_dir=str(checkpoint_dir),
+ output_dir=checkpoint_dir.as_posix(),
num_train_epochs=float(training_cfg.get("num_train_epochs", 1.0)),
per_device_train_batch_size=int(training_cfg.get("batch_size", 2)),
gradient_accumulation_steps=int(
@@ -277,41 +290,42 @@ def _compute_loss_with_guard(
logger.info("Starting LoRA supervised fine-tuning run")
trainer.train()
- training_successful = True
+ ok = True
logger.info("LoRA SFT training completed")
final_adapter_path: Path | None = None
archive_path: Path | None = None
if bool(training_cfg.get("save_model", True)):
model_path = artifacts_dir / "final_lora"
- trainer.save_model(str(model_path))
+ trainer.save_model(model_path.as_posix())
tokenizer.save_pretrained(model_path)
final_adapter_path = model_path
logger.info("Saved LoRA-adapted weights to %s", model_path)
training_time = time.time() - start_time
- result_payload: dict[str, Any] = {
- "task_id": task.task_id,
- "training_successful": training_successful,
- "training_time_seconds": training_time,
- "error_message": error_msg,
- "model_name": self._model_name,
- "dataset_size": (
- len(train_dataset) if "train_dataset" in locals() else 0
- ),
- "output_dir": out_dir.as_posix(),
- "checkpoints_dir": artifact_ref("checkpoints"),
- "resume_from_path": resume_str,
- }
- if final_adapter_path is not None:
- result_payload["final_lora"] = artifact_ref(
- final_adapter_path.relative_to(artifacts_dir).as_posix()
+ final_lora: ArtifactRef | None = None
+ final_lora_archive: ArtifactRef | None = None
+ if final_adapter_path:
+ final_lora = ArtifactRef(
+ path=final_adapter_path.relative_to(artifacts_dir).as_posix()
)
archive_path = archive_model_dir(final_adapter_path)
- result_payload["final_lora_archive"] = artifact_ref(
- archive_path.relative_to(artifacts_dir).as_posix()
+ final_lora_archive = ArtifactRef(
+ path=archive_path.relative_to(artifacts_dir).as_posix()
)
logger.info("Prepared LoRA archive at %s", archive_path)
+ result = LoRAResult(
+ ok=ok,
+ training_time_seconds=training_time,
+ error_message=error_msg,
+ model_name=self._model_name,
+ dataset_size=len(train_dataset) if train_dataset is not None else 0,
+ output_dir=out_dir.as_posix(),
+ checkpoints_dir=ArtifactRef(path="checkpoints"),
+ resume_from_path=resume_str,
+ final_lora=final_lora,
+ final_lora_archive=final_lora_archive,
+ )
maybe_upload_artifacts(task, out_dir, logger=logger)
@@ -321,30 +335,27 @@ def _compute_loss_with_guard(
final_adapter_path,
archive_path,
)
- return result_payload
+ return result
except Exception as exc: # pylint: disable=broad-except
error_msg = str(exc)
- training_successful = False
+ ok = False
logger.exception("LoRA SFT training failed: %s", exc)
training_time = time.time() - start_time
- result: dict[str, Any] = {
- "task_id": task.task_id,
- "training_successful": training_successful,
- "training_time_seconds": training_time,
- "error_message": error_msg,
- "model_name": self._model_name,
- "dataset_size": len(train_dataset) if "train_dataset" in locals() else 0,
- "output_dir": out_dir.as_posix(),
- "checkpoints_dir": artifact_ref("checkpoints"),
- "resume_from_path": (
- str(resume_path) if "resume_path" in locals() and resume_path else None
- ),
- }
-
- if training_successful:
+ result = LoRAResult(
+ ok=ok,
+ training_time_seconds=training_time,
+ error_message=error_msg,
+ model_name=self._model_name,
+ dataset_size=len(train_dataset) if train_dataset is not None else 0,
+ output_dir=out_dir.as_posix(),
+ checkpoints_dir=ArtifactRef(path="checkpoints"),
+ resume_from_path=resume_path.as_posix() if resume_path else None,
+ )
+
+ if ok:
return result
write_executor_result(out_dir / "results.json", task.task_id, task.spec, result)
diff --git a/src/worker/executors/mixins/data.py b/src/worker/executors/mixins/data.py
index ee45aab7..d0ee4324 100644
--- a/src/worker/executors/mixins/data.py
+++ b/src/worker/executors/mixins/data.py
@@ -15,6 +15,7 @@
from datasets import Dataset, load_dataset
from PIL import Image
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import TaskSpecStrictBase
from shared.utils.json import safe_get
@@ -403,7 +404,7 @@ def _collect_prompts_for_spec(
metadata_raw.append(entry_meta)
elif dtype == "list":
items = data.get("items")
- context: dict[str, Any] | None = None
+ context: dict[str, BaseExecutorResult] | None = None
root_node: str | None = None
if items is None:
expr = data.get("expr")
@@ -702,14 +703,11 @@ def _collect_prompts_for_spec(
)
def _populate_table(
- self,
- payload: dict[str, Any],
- table_stores_list: list[pd.DataFrame],
- ):
+ self, items: list[dict[str, Any]], table_stores_list: list[pd.DataFrame]
+ ) -> list[dict[str, Any]]:
"""
Group row-level generation outputs back into per-table outputs.
"""
- items = payload["items"]
cur = 0
grouped_items: list[dict[str, Any]] = []
for df in table_stores_list:
@@ -724,8 +722,7 @@ def _populate_table(
f"Output length {len(items)} does not match "
f"the total number of rows {cur} in table stores."
)
- payload["items"] = grouped_items
- return payload
+ return grouped_items
def _maybe_apply_dataset_shard(self, dataset, spec: TaskSpecStrictBase):
shard_cfg = spec.shard
diff --git a/src/worker/executors/mixins/governance.py b/src/worker/executors/mixins/governance.py
index 30adfcc9..44617d25 100644
--- a/src/worker/executors/mixins/governance.py
+++ b/src/worker/executors/mixins/governance.py
@@ -18,6 +18,7 @@
TASK_SPAN_NAME,
SpanType,
)
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import TaskSpecStrictBase
from shared.utils.time import now_iso
@@ -209,7 +210,9 @@ def _record_output(
)
@staticmethod
- def _spec_upstream_results(spec: TaskSpecStrictBase) -> dict[str, Any]:
+ def _spec_upstream_results(
+ spec: TaskSpecStrictBase,
+ ) -> dict[str, BaseExecutorResult]:
"""Validated ``spec._upstreamResults`` (server-injected stage context)."""
context = spec.upstreamResults or {}
if not isinstance(context, dict):
@@ -236,27 +239,25 @@ def _extract_source_data_ids(self, spec: TaskSpecStrictBase) -> list[str]:
def _dump_to_governance(
self,
task_id: str,
- result: dict[str, Any],
+ result: BaseExecutorResult,
dependencies_by_task: dict[str, list[str]],
) -> None:
"""Write parent + merged-child results and emit asset/lineage rows."""
parent_deps = dependencies_by_task.get(task_id, [])
- children_payload = result.get("children", {})
-
collection_jobs: list[dict[str, Any]] = [
{
"task_id": task_id,
- "result": result,
+ "result": result.model_dump(),
"deps": parent_deps,
"is_parent": True,
}
]
- for child_id, child_result in children_payload.items():
+ for child_id, child_result in result.children.items():
child_deps = dependencies_by_task.get(child_id, [])
collection_jobs.append(
{
"task_id": child_id,
- "result": child_result,
+ "result": child_result.model_dump(),
"deps": child_deps,
"is_parent": False,
}
diff --git a/src/worker/executors/mixins/inference.py b/src/worker/executors/mixins/inference.py
index fb7ae7a2..4076e901 100644
--- a/src/worker/executors/mixins/inference.py
+++ b/src/worker/executors/mixins/inference.py
@@ -15,7 +15,6 @@
from shared.utils.json import to_json_serializable
from ..base_executor import ExecutionError
-from ..utils.checkpoints import artifact_ref
from .data import DataMixin, InferenceEntry, PromptInput
logger = logging.getLogger(__name__)
@@ -229,7 +228,7 @@ def _maybe_export_jsonl(
self,
spec: InferenceSpecStrict,
task_id: str,
- result: dict[str, Any],
+ items: list[dict[str, Any]],
out_dir: Path,
) -> None:
post_cfg = (postprocess := spec.postprocess) and postprocess.jsonl_export
@@ -261,7 +260,6 @@ def _maybe_export_jsonl(
) from exc
target_path.parent.mkdir(parents=True, exist_ok=True)
- items = result.get("items") or []
required_fields = post_cfg.required_fields or []
records: list[dict[str, Any]] = []
@@ -292,11 +290,6 @@ def _maybe_export_jsonl(
fh.write("\n")
rel_path = target_path.relative_to(artifacts_dir).as_posix()
- result["jsonl_export"] = {
- **artifact_ref(rel_path),
- "record_count": len(records),
- "fields": list(fields_cfg.keys()),
- }
logger.info(
"Task %s exported %d records to artifacts/%s",
task_id,
diff --git a/src/worker/executors/mp_executor.py b/src/worker/executors/mp_executor.py
index f2379fe2..f2b7c199 100644
--- a/src/worker/executors/mp_executor.py
+++ b/src/worker/executors/mp_executor.py
@@ -21,6 +21,7 @@
import psutil
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.worker_message import WorkerHardware
from worker.config import WorkerConfig
@@ -439,7 +440,7 @@ def _loop() -> None:
self._log_thread = t
t.start()
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> BaseExecutorResult:
with self._lock:
if self._shutdown:
logger.info("Starting worker subprocess for %s", self.name)
diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py
index 87e36f93..8fd9ec21 100644
--- a/src/worker/executors/omni_executor_base.py
+++ b/src/worker/executors/omni_executor_base.py
@@ -19,6 +19,7 @@
import yaml
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import TaskSpecStrictBase
from shared.utils.parsing import to_bool, to_int
@@ -40,6 +41,13 @@
logger = logging.getLogger(__name__)
+class OmniResult(BaseExecutorResult):
+ executor: str
+ mode: str
+ model: str | None
+ items: list[dict[str, Any]]
+
+
class OmniExecutorBase(InferenceMixin, Executor):
"""Shared base for Omni-family executors.
@@ -58,7 +66,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._omni_spec: tuple[Any, ...] | None = None
self._stage_configs_tmp: Path | None = None
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> OmniResult:
spec = self.require_spec(task, self._TASK_SPEC_TYPE)
spec_dict = spec.model_dump(by_alias=True)
out_dir = Path(out_dir).resolve()
@@ -77,7 +85,7 @@ def _run_inner(
spec: TaskSpecStrictBase,
spec_dict: dict[str, Any],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> OmniResult:
"""Run the executor-specific body. ``spec`` is the concrete strict
spec; subclasses ``assert isinstance(spec, ...)`` to narrow."""
raise NotImplementedError
diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py
index 952dfb28..ef4ff3c1 100644
--- a/src/worker/executors/omni_text2audio_executor.py
+++ b/src/worker/executors/omni_text2audio_executor.py
@@ -50,22 +50,33 @@
current_omni_platform = None
_HAS_OMNI_PLATFORM = False
+from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.specs.omni import OmniText2AudioSpecStrict
from shared.utils.parsing import to_float, to_int
from .base_executor import ExecutionError, ExecutorTask
-from .omni_executor_base import OmniExecutorBase, extract_multimodal_output
-from .utils.checkpoints import artifact_ref
+from .omni_executor_base import OmniExecutorBase, OmniResult, extract_multimodal_output
logger = logging.getLogger(__name__)
+EXECUTOR_NAME = "omni_text2audio"
+
+
+class OmniText2AudioResult(OmniResult):
+ executor: str = EXECUTOR_NAME
+ mode: str = "bgm"
+ audio: ArtifactRef | None
+ sample_rate: int
+ num_waveforms: int
+ audio_length: float
+ storyboard: dict[str, Any] | None = None
class OmniText2AudioExecutor(OmniExecutorBase):
"""Generate background music with Omni diffusion sampling."""
- name = "omni_text2audio"
+ name = EXECUTOR_NAME
_TASK_SPEC_TYPE = OmniText2AudioSpecStrict
def prepare(self) -> None:
@@ -84,7 +95,7 @@ def _run_inner(
spec: TaskSpecStrictBase,
spec_dict: dict[str, Any],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> OmniText2AudioResult:
assert isinstance(spec, OmniText2AudioSpecStrict)
prompts = self._collect_text_inputs(spec, task.task_id)
@@ -192,8 +203,8 @@ def _run_inner(
"prompt_index": prompt_idx,
"waveform_index": local_idx,
"prompt": prompt,
- "audio": artifact_ref(
- self.relative_to(save_path, artifacts_dir)
+ "audio": ArtifactRef(
+ path=self.relative_to(save_path, artifacts_dir)
),
}
)
@@ -202,22 +213,15 @@ def _run_inner(
if not items:
raise ExecutionError("omni_text2audio produced no savable waveforms.")
- first = items[0]["audio"] if items else {}
- result: dict[str, Any] = {
- "ok": True,
- "executor": self.name,
- "mode": "bgm",
- "model": self._model_name,
- "audio": first,
- "items": items,
- "sample_rate": sample_rate,
- "num_waveforms": len(items),
- "audio_length": audio_length,
- }
- storyboard = spec_dict.get("storyboard")
- if isinstance(storyboard, dict):
- result["storyboard"] = dict(storyboard)
- return result
+ return OmniText2AudioResult(
+ model=self._model_name,
+ audio=items[0]["audio"] if items else None,
+ items=items,
+ sample_rate=sample_rate,
+ num_waveforms=len(items),
+ audio_length=audio_length,
+ storyboard=spec_dict.get("storyboard"),
+ )
# ── model ────────────────────────────────────────────────────────────
diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py
index 5664af99..9c48763c 100644
--- a/src/worker/executors/omni_text2general_executor.py
+++ b/src/worker/executors/omni_text2general_executor.py
@@ -26,16 +26,22 @@
Omni = None
_HAS_OMNI = False
+from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.specs.omni import OmniText2GeneralSpecStrict
from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list
from .base_executor import ExecutionError, ExecutorTask
-from .omni_executor_base import OmniExecutorBase, extract_audio_from_mm, save_audio
-from .utils.checkpoints import artifact_ref
+from .omni_executor_base import (
+ OmniExecutorBase,
+ OmniResult,
+ extract_audio_from_mm,
+ save_audio,
+)
logger = logging.getLogger(__name__)
+EXECUTOR_NAME = "omni_text2general"
_DEFAULT_SYSTEM_PROMPT = (
"You are Qwen, a virtual human developed by the Qwen Team, "
@@ -44,10 +50,18 @@
)
+class OmniText2GeneralResult(OmniResult):
+ executor: str = EXECUTOR_NAME
+ mode: str = "narration"
+ audio: ArtifactRef | None
+ sample_rate: int
+ storyboard: dict[str, Any] | None = None
+
+
class OmniText2GeneralExecutor(OmniExecutorBase):
"""Generate narration/speech audio using Qwen3-Omni through vllm_omni.Omni."""
- name = "omni_text2general"
+ name = EXECUTOR_NAME
_TASK_SPEC_TYPE = OmniText2GeneralSpecStrict
def prepare(self) -> None:
@@ -67,7 +81,7 @@ def _run_inner(
spec: TaskSpecStrictBase,
spec_dict: dict[str, Any],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> OmniText2GeneralResult:
assert isinstance(spec, OmniText2GeneralSpecStrict)
texts = self._collect_text_inputs(spec, task.task_id)
@@ -170,27 +184,22 @@ def _run_inner(
"index": idx,
"request_id": rid,
"prompt": texts[idx] if idx < len(texts) else None,
- "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)),
+ "audio": ArtifactRef(
+ path=self.relative_to(save_path, artifacts_dir)
+ ),
}
text_out = text_results.get(rid)
if text_out:
item["text"] = text_out
items.append(item)
- first = items[0]["audio"] if items else {}
- result: dict[str, Any] = {
- "ok": True,
- "executor": self.name,
- "mode": "narration",
- "model": self._model_name,
- "audio": first,
- "items": items,
- "sample_rate": sample_rate,
- }
- storyboard = spec_dict.get("storyboard")
- if isinstance(storyboard, dict):
- result["storyboard"] = dict(storyboard)
- return result
+ return OmniText2GeneralResult(
+ model=self._model_name,
+ items=items,
+ audio=items[0]["audio"] if items else None,
+ sample_rate=sample_rate,
+ storyboard=spec_dict.get("storyboard"),
+ )
# ── model ────────────────────────────────────────────────────────────
diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py
index e3b1f234..f11db25b 100644
--- a/src/worker/executors/omni_text2image_executor.py
+++ b/src/worker/executors/omni_text2image_executor.py
@@ -15,22 +15,29 @@
Omni = None
_HAS_OMNI = False
+from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.specs.omni import OmniText2ImageSpecStrict
from shared.utils.parsing import as_list
from .base_executor import ExecutionError, ExecutorTask
-from .omni_executor_base import OmniExecutorBase
-from .utils.checkpoints import artifact_ref
+from .omni_executor_base import OmniExecutorBase, OmniResult
logger = logging.getLogger(__name__)
+EXECUTOR_NAME = "omni_text2image"
+
+
+class OmniText2ImageResult(OmniResult):
+ executor: str = EXECUTOR_NAME
+ mode: str = "image"
+ image: ArtifactRef | None
class OmniText2ImageExecutor(OmniExecutorBase):
"""Generate images using vllm_omni.Omni."""
- name = "omni_text2image"
+ name = EXECUTOR_NAME
_TASK_SPEC_TYPE = OmniText2ImageSpecStrict
def prepare(self) -> None:
@@ -45,7 +52,7 @@ def _run_inner(
spec: TaskSpecStrictBase,
spec_dict: dict[str, Any],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> OmniText2ImageResult:
assert isinstance(spec, OmniText2ImageSpecStrict)
prompts = self._collect_text_inputs(spec, task.task_id)
@@ -86,22 +93,17 @@ def _run_inner(
{
"index": idx,
"prompt": prompt,
- "image": artifact_ref(
- self.relative_to(save_path, artifacts_dir)
+ "image": ArtifactRef(
+ path=self.relative_to(save_path, artifacts_dir)
),
}
)
- first = items[0]["image"] if items else {}
- result: dict[str, Any] = {
- "ok": True,
- "executor": self.name,
- "mode": "image",
- "model": self._model_name,
- "image": first,
- "items": items,
- }
- return result
+ return OmniText2ImageResult(
+ model=self._model_name,
+ image=items[0]["image"] if items else None,
+ items=items,
+ )
# ── model ────────────────────────────────────────────────────────────
diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py
index 5d7be1b7..b37aa74a 100644
--- a/src/worker/executors/omni_text2speech_executor.py
+++ b/src/worker/executors/omni_text2speech_executor.py
@@ -16,6 +16,7 @@
Omni = None
_HAS_OMNI = False
+from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.specs.omni import OmniText2SpeechSpecStrict
@@ -24,19 +25,28 @@
from .base_executor import ExecutionError, ExecutorTask
from .omni_executor_base import (
OmniExecutorBase,
+ OmniResult,
extract_audio_from_mm,
extract_multimodal_output,
save_audio,
)
-from .utils.checkpoints import artifact_ref
logger = logging.getLogger(__name__)
+EXECUTOR_NAME = "omni_text2speech"
+
+
+class OmniText2SpeechResult(OmniResult):
+ executor: str = EXECUTOR_NAME
+ mode: str = "tts"
+ audio: ArtifactRef | None
+ sample_rate: int
+ storyboard: dict[str, Any] | None = None
class OmniText2SpeechExecutor(OmniExecutorBase):
"""Generate speech audio using vllm_omni.Omni."""
- name = "omni_text2speech"
+ name = EXECUTOR_NAME
_TASK_SPEC_TYPE = OmniText2SpeechSpecStrict
def prepare(self) -> None:
@@ -51,7 +61,7 @@ def _run_inner(
spec: TaskSpecStrictBase,
spec_dict: dict[str, Any],
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> OmniText2SpeechResult:
assert isinstance(spec, OmniText2SpeechSpecStrict)
texts = self._collect_text_inputs(spec, task.task_id)
@@ -95,26 +105,19 @@ def _run_inner(
{
"index": idx,
"text": text,
- "audio": artifact_ref(
- self.relative_to(save_path, artifacts_dir)
+ "audio": ArtifactRef(
+ path=self.relative_to(save_path, artifacts_dir)
),
}
)
- first = items[0]["audio"] if items else {}
- result: dict[str, Any] = {
- "ok": True,
- "executor": self.name,
- "mode": "tts",
- "model": self._model_name,
- "audio": first,
- "items": items,
- "sample_rate": sample_rate,
- }
- storyboard = spec_dict.get("storyboard")
- if isinstance(storyboard, dict):
- result["storyboard"] = dict(storyboard)
- return result
+ return OmniText2SpeechResult(
+ model=self._model_name,
+ items=items,
+ audio=items[0]["audio"] if items else None,
+ sample_rate=sample_rate,
+ storyboard=spec_dict.get("storyboard"),
+ )
# ── model ────────────────────────────────────────────────────────────
diff --git a/src/worker/executors/ppo_executor.py b/src/worker/executors/ppo_executor.py
index 1cc47183..4c859719 100644
--- a/src/worker/executors/ppo_executor.py
+++ b/src/worker/executors/ppo_executor.py
@@ -14,10 +14,7 @@
from contextlib import contextmanager, nullcontext
from pathlib import Path
from types import SimpleNamespace
-from typing import TYPE_CHECKING, Any, cast
-
-if TYPE_CHECKING:
- from deepspeed.runtime.engine import DeepSpeedEngine
+from typing import Any, cast
import torch
from datasets import Dataset
@@ -33,6 +30,8 @@
from trl.trainer.ppo_config import PPOConfig
from trl.trainer.ppo_trainer import PPOTrainer
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import PPOSpecStrict
from shared.utils.manifest import scratch_dir
from shared.utils.parsing import safe_float, safe_int, to_bool
@@ -42,7 +41,6 @@
from .mixins.training import TrainingMixin
from .utils.checkpoints import (
archive_model_dir,
- artifact_ref,
get_http_destination,
maybe_upload_artifacts,
write_executor_result,
@@ -54,6 +52,18 @@
logger = logging.getLogger("worker.ppo")
+class PPOResult(BaseExecutorResult):
+ training_time_seconds: float | None = None
+ error_message: str | None = None
+ model_name: str | None = None
+ dataset_size: int = 0
+ output_dir: str | None = None
+ checkpoints_dir: ArtifactRef | None = None
+ final_model: ArtifactRef | None = None
+ final_model_archive: ArtifactRef | None = None
+ spawned_torchrun: bool = False
+
+
class _ExternalRewardModel(torch.nn.Module):
"""Wraps a sequence classification model to score decoded PPO responses."""
@@ -392,7 +402,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._reward_module: _ExternalRewardModel | _RewardAdapter | None = None
self._task_out_dir: Path | None = None
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> PPOResult:
configure_hf_library_logging()
logger.info("Starting PPO training task")
spec = self.require_spec(task, PPOSpecStrict)
@@ -418,16 +428,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
)
ipc_path = scratch_dir(out_dir) / "distributed_result.json"
if ipc_path.exists():
- result = self.load_json(ipc_path)
self._task_out_dir = None
- return result
+ return PPOResult.model_validate(self.load_json(ipc_path))
self._task_out_dir = None
- return {
- "training_successful": True,
- "spawned_torchrun": True,
- "model_name": spec.model_name,
- "output_dir": out_dir.as_posix(),
- }
+ return PPOResult(
+ spawned_torchrun=True,
+ model_name=spec.model_name,
+ output_dir=out_dir.as_posix(),
+ )
start_time = time.time()
@@ -853,7 +861,7 @@ def build_trainer() -> _EarlyStopPPOTrainer:
pass
logger.info("PPO training completed")
- training_successful = True
+ ok = True
error_msg = None
# Save model if requested
@@ -881,44 +889,48 @@ def build_trainer() -> _EarlyStopPPOTrainer:
logger.warning("Failed to save model: %s", exc)
except Exception as exc:
- training_successful = False
+ ok = False
error_msg = str(exc)
logger.exception("PPO training failed: %s", exc)
training_time = time.time() - start_time
dataset_size = len(dataset) if "dataset" in locals() else 0 # type: ignore
- results = {
- "training_successful": training_successful,
- "training_time_seconds": training_time,
- "error_message": error_msg,
- "model_name": self._model_name,
- "dataset_size": dataset_size,
- "output_dir": out_dir.as_posix(),
- }
+ result = PPOResult(
+ ok=ok,
+ training_time_seconds=training_time,
+ error_message=error_msg,
+ model_name=self._model_name,
+ dataset_size=dataset_size,
+ output_dir=out_dir.as_posix(),
+ )
write_executor_result(
- out_dir / "results.json", task.task_id, task.spec, results
+ out_dir / "results.json", task.task_id, task.spec, result
)
self._task_out_dir = None
raise ExecutionError(error_msg or "PPO training failed") from exc
training_time = time.time() - start_time
- results = {
- "training_successful": training_successful,
- "training_time_seconds": training_time,
- "error_message": error_msg,
- "model_name": self._model_name,
- "dataset_size": len(dataset),
- "output_dir": out_dir.as_posix(),
- "checkpoints_dir": artifact_ref("checkpoints"),
- }
- if final_model_path is not None:
- results["final_model"] = artifact_ref(
- final_model_path.relative_to(artifacts_dir).as_posix()
- )
- if final_archive_path is not None:
- results["final_model_archive"] = artifact_ref(
- final_archive_path.relative_to(artifacts_dir).as_posix()
- )
+ result = PPOResult(
+ ok=ok,
+ training_time_seconds=training_time,
+ error_message=error_msg,
+ model_name=self._model_name,
+ dataset_size=len(dataset),
+ output_dir=out_dir.as_posix(),
+ checkpoints_dir=ArtifactRef(path="checkpoints"),
+ final_model=(
+ ArtifactRef(path=final_model_path.relative_to(artifacts_dir).as_posix())
+ if final_model_path
+ else None
+ ),
+ final_model_archive=(
+ ArtifactRef(
+ path=final_archive_path.relative_to(artifacts_dir).as_posix()
+ )
+ if final_archive_path
+ else None
+ ),
+ )
maybe_upload_artifacts(task, out_dir, logger=logger)
@@ -931,7 +943,7 @@ def build_trainer() -> _EarlyStopPPOTrainer:
logger.info("PPO training task completed in %.2f seconds", training_time)
self._task_out_dir = None
- return results
+ return result
def _ensure_jsonl_local(self, jsonl_cfg: dict[str, Any]) -> Path:
headers_cfg = (
@@ -1446,7 +1458,7 @@ def _wrapped_save_model(
output_dir: str | None = None, _internal_call: bool = False
) -> None:
backup_model = ppo_trainer.model
- backup_deepspeed: DeepSpeedEngine | None = None
+ backup_deepspeed: Any = None
ppo_trainer.model = self._resolve_model_for_save(backup_model)
if ppo_trainer.is_deepspeed_enabled:
backup_deepspeed = ppo_trainer.deepspeed
diff --git a/src/worker/executors/rag_executor.py b/src/worker/executors/rag_executor.py
index 839041bf..06b3c705 100644
--- a/src/worker/executors/rag_executor.py
+++ b/src/worker/executors/rag_executor.py
@@ -16,18 +16,29 @@
from datasets import load_dataset
from qdrant_client import QdrantClient, models
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import RagSpecStrict
from .base_executor import ExecutionError, Executor, ExecutorTask
from .utils.graph_templates import Message, build_prompts_from_graph_template
logger = logging.getLogger("worker.rag")
+EXECUTOR_NAME = "rag"
+
+
+class RAGResult(BaseExecutorResult):
+ executor: str = EXECUTOR_NAME
+ qdrant: dict[str, Any]
+ embedding: dict[str, Any]
+ search: dict[str, Any]
+ queries: list[dict[str, Any]] = []
+ usage: dict[str, Any] | None = None
class RAGExecutor(Executor):
- name = "rag"
+ name = EXECUTOR_NAME
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> RAGResult:
start_ts = time.time()
spec = self.require_spec(task, RagSpecStrict)
@@ -180,22 +191,17 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
}
)
- # Compose response
- out: dict[str, Any] = {
- "ok": True,
- "executor": self.name,
- "qdrant": {"collection": collection, "url": url},
- "embedding": {"model": model_name},
- "search": {"top_k": top_k},
- "queries": results_per_query,
- "usage": {
+ logger.info(
+ "RAG query completed queries=%d total_results=%d", len(queries), total_items
+ )
+ return RAGResult(
+ qdrant={"collection": collection, "url": url},
+ embedding={"model": model_name},
+ search={"top_k": top_k},
+ queries=results_per_query,
+ usage={
"latency_sec": round(time.time() - start_ts, 4),
"num_queries": len(queries),
"total_results": total_items,
},
- }
-
- logger.info(
- "RAG query completed queries=%d total_results=%d", len(queries), total_items
)
- return out
diff --git a/src/worker/executors/sft_executor.py b/src/worker/executors/sft_executor.py
index 4361c94d..030dcd81 100644
--- a/src/worker/executors/sft_executor.py
+++ b/src/worker/executors/sft_executor.py
@@ -22,6 +22,8 @@
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer
+from shared.schemas.artifact import ArtifactRef
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import SFTSpecStrict, TaskSpecStrictBase
from shared.utils.manifest import scratch_dir
@@ -30,7 +32,6 @@
from .mixins.training import TrainingMixin
from .utils.checkpoints import (
archive_model_dir,
- artifact_ref,
determine_resume_path,
get_http_destination,
maybe_upload_artifacts,
@@ -43,6 +44,19 @@
logger = logging.getLogger("worker.sft")
+class SFTResult(BaseExecutorResult):
+ training_time_seconds: float | None = None
+ error_message: str | None = None
+ model_name: str | None = None
+ dataset_size: int = 0
+ output_dir: str | None = None
+ checkpoints_dir: ArtifactRef | None = None
+ resume_from_path: str | None = None
+ final_model: ArtifactRef | None = None
+ final_model_archive: ArtifactRef | None = None
+ spawned_torchrun: bool = False
+
+
class SFTExecutor(TrainingMixin, Executor):
name = "sft_executor"
@@ -55,12 +69,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._final_model_dir: Path | None = None
self._task_out_dir: Path | None = None
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> SFTResult:
configure_hf_library_logging()
spec = self.require_spec(task, SFTSpecStrict)
requested_gpu_count = self._requested_gpu_count(spec)
start_time = time.time()
- training_successful = False
+ ok = False
error_msg: str | None = None
caught_exc: Exception | None = None
self._task_out_dir = out_dir
@@ -189,22 +203,22 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
)
ipc_path = scratch_dir(out_dir) / "distributed_result.json"
if ipc_path.exists():
- distributed_result = self.load_json(ipc_path)
self._task_out_dir = None
- return distributed_result
+ return SFTResult.model_validate(self.load_json(ipc_path))
self._task_out_dir = None
- return {
- "training_successful": True,
- "spawned_torchrun": True,
- "model_name": spec.model_name,
- "output_dir": out_dir.as_posix(),
- }
+ return SFTResult(
+ spawned_torchrun=True,
+ model_name=spec.model_name,
+ output_dir=out_dir.as_posix(),
+ )
except Exception as spawn_exc:
logger.exception("Failed to launch distributed SFT: %s", spawn_exc)
raise ExecutionError(
"Failed to launch distributed SFT subprocess"
) from spawn_exc
+ resume_path: Path | None = None
+ train_dataset: Any = None
try:
# Proceed with in-process training (single GPU or inside torchrun)
model_name = spec.model_name or "gpt2"
@@ -213,7 +227,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
resume_path = determine_resume_path(
spec, training_cfg, out_dir, logger=logger
)
- resume_str = str(resume_path) if resume_path else None
+ resume_str = resume_path.as_posix() if resume_path else None
if resume_path:
logger.info(
@@ -291,7 +305,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
)
sft_config = SFTConfig(
- output_dir=str(checkpoint_dir),
+ output_dir=checkpoint_dir.as_posix(),
num_train_epochs=float(training_cfg.get("num_train_epochs", 1.0)),
per_device_train_batch_size=int(training_cfg.get("batch_size", 2)),
gradient_accumulation_steps=int(
@@ -394,14 +408,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
logger.info("Starting supervised fine-tuning")
trainer.train()
- training_successful = True
+ ok = True
logger.info("Training finished")
final_model_path: Path | None = None
final_archive_path: Path | None = None
if bool(training_cfg.get("save_model", True)):
model_path = artifacts_dir / "final_model"
- trainer.save_model(str(model_path))
+ trainer.save_model(model_path.as_posix())
tokenizer.save_pretrained(model_path)
final_model_path = model_path
logger.info("Saved fine-tuned model to %s", model_path)
@@ -422,30 +436,35 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
)
training_time = time.time() - start_time
- result_payload: dict[str, Any] = {
- "task_id": task.task_id,
- "training_successful": training_successful,
- "training_time_seconds": training_time,
- "error_message": error_msg,
- "model_name": self._model_name,
- "dataset_size": len(train_dataset),
- "output_dir": out_dir.as_posix(),
- "checkpoints_dir": artifact_ref("checkpoints"),
- "resume_from_path": resume_str,
- }
-
- if final_model_path is not None:
+ final_model: ArtifactRef | None = None
+ final_model_archive: ArtifactRef | None = None
+ if final_model_path:
resolved_model_path = Path(final_model_path)
self._final_model_dir = (
resolved_model_path if resolved_model_path.exists() else None
)
- result_payload["final_model"] = artifact_ref(
- final_model_path.relative_to(artifacts_dir).as_posix()
+ final_model = ArtifactRef(
+ path=final_model_path.relative_to(artifacts_dir).as_posix()
)
- if final_archive_path is not None:
- result_payload["final_model_archive"] = artifact_ref(
- final_archive_path.relative_to(artifacts_dir).as_posix()
+ final_model_archive = (
+ ArtifactRef(
+ path=final_archive_path.relative_to(artifacts_dir).as_posix()
)
+ if final_archive_path
+ else None
+ )
+ result = SFTResult(
+ ok=ok,
+ training_time_seconds=training_time,
+ error_message=error_msg,
+ model_name=self._model_name,
+ dataset_size=len(train_dataset),
+ output_dir=out_dir.as_posix(),
+ checkpoints_dir=ArtifactRef(path="checkpoints"),
+ resume_from_path=resume_str,
+ final_model=final_model,
+ final_model_archive=final_model_archive,
+ )
maybe_upload_artifacts(task, out_dir, logger=logger)
@@ -467,11 +486,11 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
self._final_model_dir = None
self._task_out_dir = None
- return result_payload
+ return result
except Exception as exc:
error_msg = str(exc)
- training_successful = False
+ ok = False
caught_exc = exc
logger.exception("SFT training failed: %s", exc)
trainer = None
@@ -484,23 +503,20 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
self._final_model_dir = None
training_time = time.time() - start_time
- result: dict[str, Any] = {
- "task_id": task.task_id,
- "training_successful": training_successful,
- "training_time_seconds": training_time,
- "error_message": error_msg,
- "model_name": self._model_name,
- "dataset_size": (
+ result = SFTResult(
+ ok=ok,
+ training_time_seconds=training_time,
+ error_message=error_msg,
+ model_name=self._model_name,
+ dataset_size=(
len(train_dataset)
if "train_dataset" in locals() and train_dataset is not None
else 0
),
- "output_dir": out_dir.as_posix(),
- "checkpoints_dir": artifact_ref("checkpoints"),
- "resume_from_path": (
- str(resume_path) if "resume_path" in locals() and resume_path else None
- ),
- }
+ output_dir=out_dir.as_posix(),
+ checkpoints_dir=ArtifactRef(path="checkpoints"),
+ resume_from_path=resume_path.as_posix() if resume_path else None,
+ )
write_executor_result(out_dir / "results.json", task.task_id, task.spec, result)
if caught_exc is not None:
self._task_out_dir = None
@@ -602,7 +618,7 @@ def _ensure_jsonl_local(self, jsonl_cfg: dict[str, Any]) -> Path:
timeout=timeout,
logger=logger,
)
- jsonl_cfg["path"] = str(resolved)
+ jsonl_cfg["path"] = resolved.as_posix()
return resolved
except ExecutionError as exc:
last_error = exc
@@ -823,13 +839,13 @@ def _resolve_deepspeed_config(training_cfg: dict[str, Any], log) -> Any | None:
if isinstance(cfg, dict):
return cfg
if isinstance(cfg, (str, Path)):
- candidate = Path(str(cfg)).expanduser()
+ candidate = Path(cfg).expanduser()
if candidate.exists():
log.info("Using DeepSpeed config at %s", candidate)
- return str(candidate)
+ return candidate.as_posix()
# Allow literal identifiers (e.g., 'auto') without file presence.
log.info("Using DeepSpeed literal configuration '%s'", cfg)
- return str(cfg)
+ return cfg if isinstance(cfg, str) else cfg.as_posix()
raise ValueError("training.deepspeed must be a dict, path string, or falsy")
@staticmethod
diff --git a/src/worker/executors/ssh_executor.py b/src/worker/executors/ssh_executor.py
index ef6d7958..107e93df 100644
--- a/src/worker/executors/ssh_executor.py
+++ b/src/worker/executors/ssh_executor.py
@@ -29,6 +29,7 @@
import requests
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.components.resources import GPURequirements
from shared.tasks.specs.ssh import (
SSHInputSpec,
@@ -55,6 +56,14 @@
TaskCancelledError,
)
+
+class SSHResult(BaseExecutorResult):
+ session_id: str
+ exit_code: int
+ command: list[str] | None = None
+ entrypoint: list[str] | None = None
+
+
try:
import docker
from docker import DockerClient
@@ -425,7 +434,7 @@ def teardown(self) -> None:
# Main execution
# ------------------------------------------------------------------ #
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
+ def run(self, task: ExecutorTask, out_dir: Path) -> SSHResult:
spec = self.require_spec(task, SSHSpecStrict)
cfg = SSHConfig.from_spec(spec, self._config, self._hardware)
access_mode = cfg.access_mode
@@ -539,16 +548,17 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
cfg.output,
mount_plan,
)
- result: dict[str, Any] = {"session_id": session_id, "exit_code": exit_code}
+ result = SSHResult(session_id=session_id, exit_code=exit_code)
if interactive:
- result.update(session_info)
+ for key, value in session_info.items():
+ setattr(result, key, value)
else:
# Keep as fallback — captures any output the streaming thread missed.
self._save_container_logs(container, out_dir)
if cfg.command is not None:
- result["command"] = cfg.command
+ result.command = cfg.command
if cfg.entrypoint is not None:
- result["entrypoint"] = cfg.entrypoint
+ result.entrypoint = cfg.entrypoint
if mount_plan.copy_output_path:
self._copy_output_directory(
container,
diff --git a/src/worker/executors/transformers_executor.py b/src/worker/executors/transformers_executor.py
index c4a3b586..beaed479 100644
--- a/src/worker/executors/transformers_executor.py
+++ b/src/worker/executors/transformers_executor.py
@@ -56,7 +56,9 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any
+from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import (
EmbeddingSpecStrict,
InferenceSpecStrict,
@@ -66,11 +68,7 @@
from .base_executor import ExecutionError, Executor, ExecutorTask
from .mixins.data import InferenceEntry
from .mixins.inference import InferenceMixin
-from .utils.checkpoints import (
- artifact_ref,
- maybe_upload_artifacts,
- maybe_upload_traces,
-)
+from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces
try:
import torch
@@ -115,6 +113,16 @@
logger = logging.getLogger(__name__)
+class TransformersResult(BaseExecutorResult):
+ ok: bool = True
+ model: str | None = None
+ items: list[dict[str, Any]] = []
+ usage: dict[str, Any] | None = None
+ count: int | None = None
+ embedding_file: ArtifactRef | None = None
+ image_group_sizes: list[int] | None = None
+
+
class HFTransformersExecutor(InferenceMixin, Executor):
"""Executor that runs text generation via Hugging Face Transformers."""
@@ -384,7 +392,7 @@ def _detect_finish_reason(
return "length"
return None
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ignore[override]
+ def run(self, task: ExecutorTask, out_dir: Path) -> TransformersResult:
configure_hf_library_logging()
spec = task.spec
if not isinstance(spec, (InferenceSpecStrict, EmbeddingSpecStrict)):
@@ -406,7 +414,7 @@ def _run_inner(
spec: "InferenceSpecStrict | EmbeddingSpecStrict",
task_id: str,
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> TransformersResult:
with self._span("model load", span_type=SpanType.COMPUTE):
self._ensure_model(spec)
@@ -423,8 +431,6 @@ def _run_inner(
assert self._model is not None
- result: dict[str, Any] = {}
-
if self._mode == "visual-embedding":
assert self._image_processor is not None
device = self._device or ("cuda" if torch.cuda.is_available() else "cpu")
@@ -479,15 +485,13 @@ def _run_inner(
emb_path = artifacts_dir / "visual_embeddings.pt"
torch.save(grouped_visual_embeddings, emb_path)
- result = {
- "ok": True,
- "model": self._model_name,
- "items": [], # Embeddings are in file
- "count": len(grouped_visual_embeddings),
- "embedding_file": artifact_ref("visual_embeddings.pt"),
- }
- if image_group_sizes is not None:
- result["image_group_sizes"] = image_group_sizes
+ result = TransformersResult(
+ model=self._model_name,
+ items=[],
+ count=len(grouped_visual_embeddings),
+ embedding_file=ArtifactRef(path="visual_embeddings.pt"),
+ image_group_sizes=image_group_sizes,
+ )
self._dump_to_governance(
task_id=task_id,
@@ -593,21 +597,20 @@ def _run_inner(
prompt_tokens += int(input_len)
completion_tokens += int(gen_part.shape[0])
- result = {
- "ok": True,
- "model": self._model_name,
- "items": items,
- "usage": {
+ result = TransformersResult(
+ model=self._model_name,
+ items=items,
+ usage={
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
"total_tokens": int(prompt_tokens + completion_tokens),
"num_requests": len(self._prompts),
"latency_sec": latency,
},
- }
+ )
if isinstance(spec, InferenceSpecStrict):
- self._maybe_export_jsonl(spec, task_id, result, out_dir)
+ self._maybe_export_jsonl(spec, task_id, items, out_dir)
self._dump_to_governance(
task_id=task_id,
diff --git a/src/worker/executors/utils/artifacts.py b/src/worker/executors/utils/artifacts.py
index 7f4df0be..8605433b 100644
--- a/src/worker/executors/utils/artifacts.py
+++ b/src/worker/executors/utils/artifacts.py
@@ -6,48 +6,39 @@
import requests
+from shared.schemas.result import BaseExecutorResult
from shared.utils.http import auth_headers
from ..base_executor import ExecutionError
def artifact_to_source(
- ref: dict[str, Any], context: dict[str, Any] | None, node: str | None
+ ref: dict[str, Any], context: dict[str, BaseExecutorResult] | None, node: str | None
) -> str:
"""Translate a `{path: ...}` artifact ref into a URL or local path."""
rel_path = ref.get("path")
if not isinstance(rel_path, str) or not rel_path:
raise ExecutionError("Artifact ref must include a non-empty 'path' field")
- ctx: dict[str, Any] = {}
- if context and isinstance(node, str) and node:
- node_payload = context.get(node)
- if isinstance(node_payload, dict):
- raw_ctx = node_payload.get("_artifacts")
- if isinstance(raw_ctx, dict):
- ctx = raw_ctx
- if not ctx:
- # Some executors stuff their result under a "result" key.
- result = node_payload.get("result")
- if isinstance(result, dict):
- inner = result.get("_artifacts")
- if isinstance(inner, dict):
- ctx = inner
-
- if ctx:
- base_url = ctx.get("base_url")
- base_dir = ctx.get("base_dir")
+ if (
+ context
+ and node
+ and (node_result := context.get(node))
+ and (ctx := node_result.artifacts_)
+ ):
+ base_url = ctx.base_url
+ base_dir = ctx.base_dir
else:
base_url = base_dir = None
# Check local filesystem first
- base_dir_path = Path(base_dir) if isinstance(base_dir, str) and base_dir else None
+ base_dir_path = Path(base_dir) if base_dir else None
local_path = base_dir_path / "artifacts" / rel_path if base_dir_path else None
if local_path is not None and local_path.is_file():
return local_path.as_posix()
# Fallback to a URL if base_url is provided
- if isinstance(base_url, str) and base_url:
+ if base_url:
if not base_dir_path:
raise ExecutionError(
"Artifact ref with base_url requires upstream base_dir to "
@@ -66,7 +57,7 @@ def artifact_to_source(
def maybe_resolve_artifact_ref(
- value: Any, context: dict[str, Any] | None, node: str | None
+ value: Any, context: dict[str, BaseExecutorResult] | None, node: str | None
) -> Any:
"""Convert `{path: ...}` ref dicts to URL/path strings; pass others through."""
if isinstance(value, dict) and "path" in value:
diff --git a/src/worker/executors/utils/checkpoints.py b/src/worker/executors/utils/checkpoints.py
index 60ba6053..ed57f369 100644
--- a/src/worker/executors/utils/checkpoints.py
+++ b/src/worker/executors/utils/checkpoints.py
@@ -11,8 +11,10 @@
import requests
-from shared.schemas.result import write_result_in_envelope
+from shared.schemas.artifact import ArtifactContext
+from shared.schemas.result import BaseExecutorResult, ResultEnvelope
from shared.tasks.specs import TaskSpecStrictBase
+from shared.utils.atomic import atomic_write_text
from shared.utils.http import add_auth_headers
from shared.utils.parsing import parse_bool_env
@@ -394,15 +396,9 @@ def is_cleanup_enabled() -> bool:
return normalized not in {"0", "false", "no", "off"}
-def artifact_ref(rel_path: str) -> dict[str, str]:
- """Build an artifact reference dict. `rel_path` is the path relative to
- `out_dir/artifacts/`"""
- return {"path": rel_path}
-
-
-def build_artifact_context(spec: TaskSpecStrictBase, out_dir: Path) -> dict[str, Any]:
- """Top-level `_artifacts` context: {base_dir, base_url}. base_url is the
- destination origin (scheme://host[:port]) for HTTP, else None."""
+def build_artifact_context(spec: TaskSpecStrictBase, out_dir: Path) -> ArtifactContext:
+ """Top-level ``_artifacts`` context. ``base_url`` is the destination
+ origin (scheme://host[:port]) for HTTP, else ``None``."""
base_dir = Path(out_dir).resolve().as_posix()
base_url: str | None = None
destination = get_http_destination(spec)
@@ -410,18 +406,17 @@ def build_artifact_context(spec: TaskSpecStrictBase, out_dir: Path) -> dict[str,
parsed = urlparse(destination.url)
if parsed.scheme and parsed.netloc:
base_url = f"{parsed.scheme}://{parsed.netloc}"
- return {"base_dir": base_dir, "base_url": base_url}
+ return ArtifactContext(base_dir=base_dir, base_url=base_url)
def write_executor_result(
- path: Path,
- task_id: str,
- spec: TaskSpecStrictBase,
- result: dict[str, Any],
+ path: Path, task_id: str, spec: TaskSpecStrictBase, result: BaseExecutorResult
) -> None:
"""Stamp ``_artifacts`` onto ``result`` and persist the envelope."""
- result["_artifacts"] = build_artifact_context(spec, path.parent)
- write_result_in_envelope(path, task_id, result)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ result.artifacts_ = build_artifact_context(spec, path.parent)
+ envelope = ResultEnvelope(task_id=task_id, result=result)
+ atomic_write_text(path, envelope.model_dump_json(indent=2))
def maybe_upload_artifacts(
diff --git a/src/worker/executors/utils/graph_templates.py b/src/worker/executors/utils/graph_templates.py
index eb01ea1a..e51aad54 100644
--- a/src/worker/executors/utils/graph_templates.py
+++ b/src/worker/executors/utils/graph_templates.py
@@ -5,7 +5,9 @@
from typing import Any
import pandas as pd
+from pydantic import BaseModel
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import TaskSpecStrictBase
from shared.utils.json import validate_keys
@@ -13,6 +15,8 @@
from ..base_executor import ExecutionError
from .safe_eval import safe_execute_function, safe_materialize_function
+_SENTINEL: Any = object()
+
type MessageItem = dict[str, str]
type Message = Sequence[MessageItem]
type MaterializedMessage = Sequence[MessageItem]
@@ -81,7 +85,7 @@ def build_prompts_from_graph_template(
def _maybe_broadcast_image_prompts(
prompts: Sequence[str | Message],
data_cfg: dict[str, Any],
- context: dict[str, Any],
+ context: dict[str, BaseExecutorResult],
) -> list[str | Message]:
image_embedding = data_cfg.get("image_embedding")
if not isinstance(image_embedding, dict):
@@ -95,7 +99,7 @@ def _maybe_broadcast_image_prompts(
def _resolve_image_embedding_count(
- image_embedding: dict[str, Any], context: dict[str, Any]
+ image_embedding: dict[str, Any], context: dict[str, BaseExecutorResult]
) -> int | None:
node = image_embedding.get("node")
if not isinstance(node, str):
@@ -105,15 +109,17 @@ def _resolve_image_embedding_count(
if not isinstance(node, str) or not node:
return None
upstream = context.get(node)
- if not isinstance(upstream, dict):
+ if upstream is None:
return None
- count = upstream.get("count")
+ count = getattr(upstream, "count", _SENTINEL)
if isinstance(count, int) and count > 0:
return count
return None
-def _resolve_columns(columns_cfg: Any, context: dict[str, Any]) -> list[dict[str, Any]]:
+def _resolve_columns(
+ columns_cfg: Any, context: dict[str, BaseExecutorResult]
+) -> list[dict[str, Any]]:
if not isinstance(columns_cfg, list):
raise ExecutionError("graph_template.template.columns must be a list.")
@@ -580,17 +586,17 @@ def _format_column_line(label: str, value: str) -> str:
return f"• {label}: {indented}"
-def _evaluate_expr(expr: str, context: dict[str, Any]) -> Any:
+def _evaluate_expr(expr: str, context: dict[str, BaseExecutorResult]) -> Any:
if not expr:
return None
parts = expr.split(".")
root = parts[0]
- data = context.get(root)
- if data is None:
+ result = context.get(root)
+ if result is None:
return None
- value: Any = data
+ value: Any = result
for token in parts[1:]:
if not token:
continue
@@ -617,6 +623,14 @@ def _evaluate_expr(expr: str, context: dict[str, Any]) -> Any:
f"{attr} not a valid column in DataFrame for {token}."
)
value = value[attr].tolist()
+ elif isinstance(value, BaseModel):
+ resolved = getattr(value, attr, _SENTINEL)
+ if resolved is _SENTINEL:
+ raise ExecutionError(
+ f"{attr} not a valid attribute of {type(value).__name__} "
+ f"for {token}."
+ )
+ value = resolved
else:
raise ExecutionError(
f"{attr} in {parts} is not a valid key - "
diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py
index 6ad6abd1..bb94b21d 100644
--- a/src/worker/executors/vllm_executor.py
+++ b/src/worker/executors/vllm_executor.py
@@ -67,6 +67,7 @@
StructuredOutputsParams = None # type: ignore
from shared.schemas.governance import SpanType
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import InferenceSpecStrict
from .base_executor import ExecutionError, Executor, ExecutorTask
@@ -81,6 +82,13 @@
logger = logging.getLogger(__name__)
+class VLLMResult(BaseExecutorResult):
+ ok: bool = True
+ model: str | None = None
+ items: list[dict[str, Any]] = []
+ usage: dict[str, Any] | None = None
+
+
class _RawJsonSchema:
"""Tag wrapping a JSON schema so ``_build_sampling_params`` can
``isinstance``-dispatch raw schemas vs. named-fields pydantic kwargs."""
@@ -849,7 +857,7 @@ def _postprocess_prompts(self, parsed: InferenceEntry) -> PreparedInferenceEntry
# --------------------------------------------------------------------- #
# Execution
# --------------------------------------------------------------------- #
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ignore[override]
+ def run(self, task: ExecutorTask, out_dir: Path) -> VLLMResult:
spec = self.require_spec(task, InferenceSpecStrict)
task_id = task.task_id.strip()
if not task_id:
@@ -868,7 +876,7 @@ def _run_inner(
task: ExecutorTask,
spec: InferenceSpecStrict,
out_dir: Path,
- ) -> dict[str, Any]:
+ ) -> VLLMResult:
task_id = task.task_id.strip()
merge_children = task.merged_children or []
entries: list[PreparedInferenceEntry] = []
@@ -1149,51 +1157,45 @@ def _collect(
"num_requests": len(self._batched_inputs),
}
- result: dict[str, Any] = {
- "ok": True,
- "model": self._model_name,
- "items": per_task_items.get(task_id, []),
- "usage": parent_usage,
- }
-
- child_results: dict[str, Any] = {}
+ items = per_task_items.get(task_id, [])
+ child_results: dict[str, VLLMResult] = {}
for child in merge_children:
child_id = child.task_id.strip()
if not child_id:
continue
- child_payload: dict[str, Any] = {
- "items": per_task_items.get(child_id, []),
- }
+ child_items = per_task_items.get(child_id, [])
maybe_usage = usage_by_task.get(child_id)
- if maybe_usage:
- child_payload["usage"] = maybe_usage
- child_results[child_id] = child_payload
+ child_results[child_id] = VLLMResult(
+ model=self._model_name, items=child_items, usage=maybe_usage
+ )
if parent_tables := parent_entry.tables:
- result = self._populate_table(result, parent_tables)
+ items = self._populate_table(items, parent_tables)
if child_results:
for child_id, child_payload in child_results.items():
if (child_entry := entry_by_task_id.get(child_id)) and (
child_tables := child_entry.tables
):
- child_results[child_id] = self._populate_table(
- child_payload, child_tables
+ child_results[child_id].items = self._populate_table(
+ child_payload.items, child_tables
)
- if child_results:
- result["children"] = child_results
+ result = VLLMResult(
+ children=cast(dict[str, BaseExecutorResult], child_results),
+ model=self._model_name,
+ items=items,
+ usage=parent_usage,
+ )
with self._span(
"JSONL export",
span_type=SpanType.COMPUTE,
attributes={"task_ids": task_ids},
):
- self._maybe_export_jsonl(spec, task_id, result, out_dir)
+ self._maybe_export_jsonl(spec, task_id, result.items, out_dir)
self._dump_to_governance(
- task_id=task_id,
- result=result,
- dependencies_by_task=dependencies_by_task,
+ task_id=task_id, result=result, dependencies_by_task=dependencies_by_task
)
return result
diff --git a/src/worker/runner.py b/src/worker/runner.py
index eef7e292..7b7086de 100644
--- a/src/worker/runner.py
+++ b/src/worker/runner.py
@@ -10,6 +10,7 @@
import requests
+from shared.schemas.result import BaseExecutorResult
from shared.tasks import MergedChildTaskStrict
from shared.tasks.specs import InferenceSpecStrict, TaskSpecStrictBase
from shared.tasks.worker_message import (
@@ -127,17 +128,14 @@ def _write_results(
spec: TaskSpecStrictBase,
merged_children: list[MergedChildTaskStrict],
out_dir: Path,
- result: dict[str, Any] | None,
+ result: BaseExecutorResult | None,
):
if result is None:
return
self._write_single_result(task_id, spec, out_dir, result)
child_lookup = {entry.task_id: entry for entry in merged_children}
- children_payload = (
- result.get("children") if isinstance(result, dict) else {}
- ) or {}
- for child_id, child_result in children_payload.items():
+ for child_id, child_result in result.children.items():
child_info = child_lookup.get(child_id)
if child_info is None:
continue
@@ -151,7 +149,7 @@ def _write_single_result(
task_id: str,
spec: TaskSpecStrictBase,
out_dir: Path,
- payload: dict[str, Any] | None,
+ payload: BaseExecutorResult | None,
):
if payload is None:
return
@@ -179,7 +177,7 @@ def _simulate_bandwidth_delay(self, payload_bytes: int, destination: str) -> Non
time.sleep(delay)
def _maybe_emit_http(
- self, task_id: str, spec: TaskSpecStrictBase, result: dict[str, Any]
+ self, task_id: str, spec: TaskSpecStrictBase, result: BaseExecutorResult
) -> None:
"""Send task results to an HTTP endpoint when requested by the spec."""
destination = get_http_destination(spec)
@@ -190,7 +188,7 @@ def _maybe_emit_http(
ignore_error = destination.ignore_error
payload = {
"task_id": task_id,
- "result": result,
+ "result": result.model_dump(),
"worker_id": self.lifecycle.worker_id,
}
payload_size = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
diff --git a/tests/sdk/test_schema_compat.py b/tests/sdk/test_schema_compat.py
index b757060f..d4092a7e 100644
--- a/tests/sdk/test_schema_compat.py
+++ b/tests/sdk/test_schema_compat.py
@@ -8,6 +8,9 @@
# SDK-side imports
from flowmesh.models import (
+ ArtifactContext,
+ ArtifactRef,
+ BaseExecutorResult,
CPUInfo,
GpuInfo,
GpuPlatformInfo,
@@ -79,6 +82,9 @@
)
from server.task.models import TaskInfo as SrvTaskInfo
from server.task.models import TaskUsage as SrvTaskUsage
+from shared.schemas.artifact import ArtifactContext as SrvArtifactContext
+from shared.schemas.artifact import ArtifactRef as SrvArtifactRef
+from shared.schemas.result import BaseExecutorResult as SrvBaseExecutorResult
from shared.schemas.result import ResultEnvelope as SrvResultEnvelope
from shared.schemas.worker import SSHLimits as SrvSSHLimits
from shared.tasks.task_type import TaskType as SrvTaskType
@@ -123,7 +129,11 @@
# Common
(SrvOkResponse, OkResponse),
# Results
+ (SrvBaseExecutorResult, BaseExecutorResult),
(SrvResultEnvelope, ResultEnvelope),
+ # Artifacts
+ (SrvArtifactContext, ArtifactContext),
+ (SrvArtifactRef, ArtifactRef),
]
diff --git a/tests/shared/test_executor_result.py b/tests/shared/test_executor_result.py
new file mode 100644
index 00000000..4b5d79d3
--- /dev/null
+++ b/tests/shared/test_executor_result.py
@@ -0,0 +1,71 @@
+"""Round-trip tests for the shared executor-result schema."""
+
+import json
+from typing import Any
+
+from pydantic import Field
+
+from shared.schemas.artifact import ArtifactContext, ArtifactRef
+from shared.schemas.result import BaseExecutorResult, ResultEnvelope
+
+
+class _SampleResult(BaseExecutorResult):
+ ok: bool = True
+ items: list[dict[str, Any]] = Field(default_factory=list)
+ usage: dict[str, Any] | None = None
+
+
+def test_subclass_round_trip_through_base_preserves_extra_fields() -> None:
+ """A subclass's executor-specific fields survive a JSON trip through the
+ base class via ``extra='allow'``."""
+ original = _SampleResult(
+ ok=True,
+ items=[{"output": "hello"}],
+ usage={"latency_sec": 0.5},
+ _artifacts=ArtifactContext(base_dir="/tmp/t", base_url="http://h"),
+ )
+ payload = original.model_dump_json()
+
+ base = BaseExecutorResult.model_validate_json(payload)
+ redumped = json.loads(base.model_dump_json())
+
+ assert redumped["items"] == [{"output": "hello"}]
+ assert redumped["usage"] == {"latency_sec": 0.5}
+ assert redumped["_artifacts"] == {"base_dir": "/tmp/t", "base_url": "http://h"}
+
+
+def test_recursive_children_round_trip() -> None:
+ """Nested ``children`` deserialize as ``BaseExecutorResult`` and re-emit
+ their extra fields when serialized."""
+ parent = _SampleResult(
+ items=[{"output": "p"}],
+ children={
+ "c1": _SampleResult(items=[{"output": "c"}], usage={"total_tokens": 3}),
+ },
+ )
+ payload = parent.model_dump_json()
+ base = BaseExecutorResult.model_validate_json(payload)
+
+ assert "c1" in base.children
+ child = base.children["c1"]
+ redumped = json.loads(child.model_dump_json())
+ assert redumped["items"] == [{"output": "c"}]
+ assert redumped["usage"] == {"total_tokens": 3}
+
+
+def test_artifact_ref_is_a_typed_path_reference() -> None:
+ ref = ArtifactRef(path="lora/final")
+ assert ref.model_dump() == {"path": "lora/final"}
+
+
+def test_envelope_round_trip_preserves_subclass_payload() -> None:
+ """The envelope write→read path (``ResultEnvelope.model_dump_json`` →
+ ``ResultEnvelope.model_validate_json``) round-trips subclass fields."""
+ inner = _SampleResult(items=[{"output": "hello"}], usage={"total_tokens": 7})
+ env = ResultEnvelope(task_id="tsk-x", result=inner)
+ raw = env.model_dump_json()
+
+ parsed = ResultEnvelope.model_validate_json(raw)
+ dumped = parsed.result.model_dump()
+ assert dumped["items"] == [{"output": "hello"}]
+ assert dumped["usage"] == {"total_tokens": 7}
diff --git a/tests/worker/test_agent_connector.py b/tests/worker/test_agent_connector.py
index b4fd1eb0..f3b07f32 100644
--- a/tests/worker/test_agent_connector.py
+++ b/tests/worker/test_agent_connector.py
@@ -182,9 +182,9 @@ def test_renders_description_from_params_and_yields_table_item(
)
result = executor._run_agent(data_cfg, context, out_dir)
- assert result["ok"] is True
- assert result["count"] == 1
- item = result["items"][0]
+ assert result.ok is True
+ assert result.count == 1
+ item = result.items[0]
assert "fetch NVDA quotes for 2024-01-01" in item["description"]
assert item["rows"] == 1
assert item["run_id"] == "run-fake"
diff --git a/tests/worker/test_artifact_utils.py b/tests/worker/test_artifact_utils.py
index a30528a5..189733b2 100644
--- a/tests/worker/test_artifact_utils.py
+++ b/tests/worker/test_artifact_utils.py
@@ -8,6 +8,8 @@
import pytest
+from shared.schemas.artifact import ArtifactContext
+from shared.schemas.result import BaseExecutorResult
from worker.executors.base_executor import ExecutionError
from worker.executors.utils.artifacts import (
artifact_to_source,
@@ -16,42 +18,26 @@
class TestArtifactToSource:
- def test_unwrapped_upstream_yields_url(self, tmp_path: Path) -> None:
- upstream = {
- "_artifacts": {
- "base_url": "http://host:8010",
- "base_dir": (tmp_path / "producer-tid").as_posix(),
- },
- "result": {"images": [{"path": "a.png"}]},
- }
+ def test_url_resolution(self, tmp_path: Path) -> None:
+ upstream = BaseExecutorResult(
+ _artifacts=ArtifactContext(
+ base_dir=(tmp_path / "producer-tid").as_posix(),
+ base_url="http://host:8010",
+ ),
+ )
url = artifact_to_source({"path": "a.png"}, {"producer": upstream}, "producer")
assert url == "http://host:8010/api/v1/results/producer-tid/files/a.png"
- def test_envelope_wrapped_upstream_yields_url(self, tmp_path: Path) -> None:
- """Server stores results.json as `{task_id, ..., result: {...}}`; the
- helper must unwrap one level to find `_artifacts`."""
- upstream = {
- "task_id": "producer-tid",
- "result": {
- "_artifacts": {
- "base_url": "http://host:8010",
- "base_dir": (tmp_path / "producer-tid").as_posix(),
- },
- },
- }
- url = artifact_to_source({"path": "x.png"}, {"producer": upstream}, "producer")
- assert url == "http://host:8010/api/v1/results/producer-tid/files/x.png"
-
def test_local_file_takes_fast_path(self, tmp_path: Path) -> None:
task_root = tmp_path / "producer-tid"
(task_root / "artifacts").mkdir(parents=True)
(task_root / "artifacts" / "a.png").write_bytes(b"\x89PNG")
- upstream = {
- "_artifacts": {
- "base_url": "http://host:8010",
- "base_dir": task_root.as_posix(),
- }
- }
+ upstream = BaseExecutorResult(
+ _artifacts=ArtifactContext(
+ base_dir=task_root.as_posix(),
+ base_url="http://host:8010",
+ )
+ )
resolved = artifact_to_source(
{"path": "a.png"}, {"producer": upstream}, "producer"
)
@@ -59,7 +45,9 @@ def test_local_file_takes_fast_path(self, tmp_path: Path) -> None:
def test_local_only_upstream_returns_local_path(self, tmp_path: Path) -> None:
task_root = tmp_path / "producer-tid"
- upstream = {"_artifacts": {"base_url": None, "base_dir": task_root.as_posix()}}
+ upstream = BaseExecutorResult(
+ _artifacts=ArtifactContext(base_dir=task_root.as_posix())
+ )
resolved = artifact_to_source(
{"path": "a.png"}, {"producer": upstream}, "producer"
)
@@ -67,11 +55,13 @@ def test_local_only_upstream_returns_local_path(self, tmp_path: Path) -> None:
def test_missing_context_raises(self) -> None:
with pytest.raises(ExecutionError, match="_artifacts context is missing"):
- artifact_to_source({"path": "a.png"}, {"producer": {}}, "producer")
+ artifact_to_source(
+ {"path": "a.png"}, {"producer": BaseExecutorResult()}, "producer"
+ )
def test_missing_path_raises(self) -> None:
with pytest.raises(ExecutionError, match="non-empty 'path' field"):
- artifact_to_source({}, {"producer": {}}, "producer")
+ artifact_to_source({}, {"producer": BaseExecutorResult()}, "producer")
class TestMaybeResolveArtifactRef:
@@ -85,12 +75,12 @@ def test_passes_through_dict_without_path(self) -> None:
assert maybe_resolve_artifact_ref(value, None, None) is value
def test_resolves_path_dict(self, tmp_path: Path) -> None:
- upstream = {
- "_artifacts": {
- "base_url": "http://host:8010",
- "base_dir": (tmp_path / "producer-tid").as_posix(),
- }
- }
+ upstream = BaseExecutorResult(
+ _artifacts=ArtifactContext(
+ base_url="http://host:8010",
+ base_dir=(tmp_path / "producer-tid").as_posix(),
+ )
+ )
out = maybe_resolve_artifact_ref(
{"path": "a.png"}, {"producer": upstream}, "producer"
)
diff --git a/tests/worker/test_checkpoint_utils.py b/tests/worker/test_checkpoint_utils.py
index 27816fbe..7c671641 100644
--- a/tests/worker/test_checkpoint_utils.py
+++ b/tests/worker/test_checkpoint_utils.py
@@ -6,6 +6,7 @@
import pytest
+from shared.schemas.artifact import ArtifactContext
from worker.executors.base_executor import TaskReference
from worker.executors.utils import checkpoints
@@ -29,25 +30,22 @@ def _task(
return cast(TaskReference, SimpleNamespace(task_id="task-1", spec=spec))
-class TestArtifactRef:
- def test_returns_path_only(self) -> None:
- assert checkpoints.artifact_ref("images/foo.png") == {"path": "images/foo.png"}
-
-
class TestBuildArtifactContext:
def test_http_destination_strips_api_suffix(self, tmp_path: Path) -> None:
out_dir = tmp_path / "task-1"
out_dir.mkdir()
ctx = checkpoints.build_artifact_context(_task().spec, out_dir)
- assert ctx["base_dir"] == out_dir.resolve().as_posix()
- assert ctx["base_url"] == "http://host:8010"
+ assert ctx == ArtifactContext(
+ base_dir=out_dir.resolve().as_posix(), base_url="http://host:8010"
+ )
def test_local_destination_leaves_base_url_none(self, tmp_path: Path) -> None:
out_dir = tmp_path / "task-1"
out_dir.mkdir()
ctx = checkpoints.build_artifact_context(_task("local").spec, out_dir)
- assert ctx["base_dir"] == out_dir.resolve().as_posix()
- assert ctx["base_url"] is None
+ assert ctx == ArtifactContext(
+ base_dir=out_dir.resolve().as_posix(), base_url=None
+ )
class TestMaybeUploadArtifacts:
diff --git a/tests/worker/test_connector_logging.py b/tests/worker/test_connector_logging.py
index beda187e..417e8a98 100644
--- a/tests/worker/test_connector_logging.py
+++ b/tests/worker/test_connector_logging.py
@@ -1,16 +1,24 @@
"""Test that connector logs are properly redirected from worker subprocess to parent."""
+import logging
import tempfile
import time
import uuid
from pathlib import Path
+from typing import Any
+from shared.schemas.result import BaseExecutorResult
from shared.tasks.worker_message import WorkerTaskMessage
from tests.worker.factories import make_live_worker_config, make_worker_hardware
from worker.executors.base_executor import Executor
from worker.executors.mp_executor import MPExecutor
+class ConnectorLoggingResult(BaseExecutorResult):
+ ok: bool = True
+ result: dict[str, Any]
+
+
class ConnectorLoggingExecutor(Executor):
"""Simple test executor that uses a connector and logs messages."""
@@ -19,9 +27,8 @@ class ConnectorLoggingExecutor(Executor):
def prepare(self) -> None:
pass
- def run(self, task, out_dir: Path) -> dict:
+ def run(self, task, out_dir: Path) -> ConnectorLoggingResult:
"""Run a simple test that logs from different modules."""
- import logging
# Get loggers from different modules that would be used in real execution
executor_logger = logging.getLogger("executors.test_executor")
@@ -36,13 +43,12 @@ def run(self, task, out_dir: Path) -> dict:
root_logger.info("Message from root logger")
- return {
- "ok": True,
- "result": {
+ return ConnectorLoggingResult(
+ result={
"status": "completed",
"log_test": "Messages logged from different modules",
},
- }
+ )
def cleanup_after_run(self) -> None:
pass
@@ -89,8 +95,8 @@ def test_connector_logs_printed_to_stderr(tmp_path: Path) -> None:
mp.cleanup_after_run()
- # Verify the executor ran successfully
- assert result["ok"], f"Executor failed: {result}"
- assert (
- result.get("result", {}).get("status") == "completed"
- ), f"Unexpected result: {result}"
+ # Verify the executor ran successfully. The MP boundary pickles the
+ # subclass instance, so the parent receives a ``ConnectorLoggingResult``.
+ assert isinstance(result, ConnectorLoggingResult), f"Unexpected result: {result}"
+ assert result.ok is True
+ assert result.result.get("status") == "completed", f"Unexpected result: {result}"
diff --git a/tests/worker/test_data_mixin_lineage.py b/tests/worker/test_data_mixin_lineage.py
index 48892ee9..69e8856f 100644
--- a/tests/worker/test_data_mixin_lineage.py
+++ b/tests/worker/test_data_mixin_lineage.py
@@ -7,6 +7,7 @@
from PIL import Image
+from shared.schemas.result import BaseExecutorResult
from worker.executors.mixins.data import DataMixin
@@ -96,14 +97,16 @@ def test_dump_to_governance_with_merged_children(tmp_path: Path) -> None:
mixin = _Mixin()
out_dir = tmp_path / "task"
with mixin._task_span("tsk-parent", "wfl-1", out_dir, owner_id="alice"):
- result = {
- "ok": True,
- "items": [{"output": "p"}],
- "children": {
- "tsk-c1": {"items": [{"output": "c1"}]},
- "tsk-c2": {"items": [{"output": "c2"}]},
- },
- }
+ result = BaseExecutorResult.model_validate(
+ {
+ "ok": True,
+ "items": [{"output": "p"}],
+ "children": {
+ "tsk-c1": {"items": [{"output": "c1"}]},
+ "tsk-c2": {"items": [{"output": "c2"}]},
+ },
+ }
+ )
deps = {
"tsk-parent": ["tsk-up-a"],
"tsk-c1": ["tsk-up-b"],
@@ -143,20 +146,21 @@ def test_collect_prompts_resolves_grouped_image_artifact_refs_after_flatten(
for name, color in (("a.png", "red"), ("b.png", "green"), ("c.png", "blue")):
Image.new("RGB", (2, 2), color=color).save(artifacts_dir / name)
+ result = BaseExecutorResult.model_validate(
+ {
+ "images": [
+ [{"path": "images/a.png"}, {"path": "images/b.png"}],
+ [{"path": "images/c.png"}],
+ ],
+ "_artifacts": {"base_dir": upstream_dir.as_posix()},
+ }
+ )
spec = cast(
Any,
SimpleNamespace(
data={"type": "list", "expr": "vision.images"},
inference={},
- upstreamResults={
- "vision": {
- "images": [
- [{"path": "images/a.png"}, {"path": "images/b.png"}],
- [{"path": "images/c.png"}],
- ],
- "_artifacts": {"base_dir": upstream_dir.as_posix()},
- }
- },
+ upstreamResults={"vision": result},
),
)
diff --git a/tests/worker/test_executor_bootstrap.py b/tests/worker/test_executor_bootstrap.py
index 34f6596c..45277de9 100644
--- a/tests/worker/test_executor_bootstrap.py
+++ b/tests/worker/test_executor_bootstrap.py
@@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any
+from shared.schemas.result import BaseExecutorResult
from tests.worker.factories import make_live_worker_config, make_worker_hardware
from worker.executors.base_executor import Executor, ExecutorTask
from worker.main import initialize_executors
@@ -23,8 +24,8 @@ class _PassthroughExecutor(Executor):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
- def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # noqa: D401
- return {"ok": True}
+ def run(self, task: ExecutorTask, out_dir: Path) -> BaseExecutorResult:
+ return BaseExecutorResult.model_validate({"ok": True})
class TestInitializeExecutorsHardware:
diff --git a/tests/worker/test_mp_executor_lifecycle.py b/tests/worker/test_mp_executor_lifecycle.py
index 3fe73be8..6cea3af4 100644
--- a/tests/worker/test_mp_executor_lifecycle.py
+++ b/tests/worker/test_mp_executor_lifecycle.py
@@ -4,6 +4,7 @@
import uuid
from pathlib import Path
+from shared.schemas.result import BaseExecutorResult
from shared.tasks import TaskType
from shared.tasks.specs import EchoSpecStrict
from shared.tasks.worker_message import WorkerTaskMessage
@@ -16,11 +17,16 @@
from worker.executors.mp_executor import MPExecutor
+class _SimpleMPResult(BaseExecutorResult):
+ ok: bool = True
+ task_id: str
+
+
class _SimpleMPExecutor(Executor):
name = "simple_mp"
- def run(self, task, out_dir: Path) -> dict:
- return {"ok": True, "task_id": task.task_id}
+ def run(self, task, out_dir: Path) -> _SimpleMPResult:
+ return _SimpleMPResult(task_id=task.task_id)
def cleanup_after_run(self) -> None:
return None
@@ -54,7 +60,8 @@ def test_mp_executor_does_not_start_subprocess_until_first_run(tmp_path: Path) -
with tempfile.TemporaryDirectory() as out_dir:
result = mp.run(_simple_task_message(), Path(out_dir))
- assert result["ok"] is True
+ assert isinstance(result, _SimpleMPResult)
+ assert result.ok is True
assert mp._shutdown is False
assert mp._proc is not None
assert mp._proc.is_alive()
diff --git a/tests/worker/test_transformers_chat_inference.py b/tests/worker/test_transformers_chat_inference.py
index 12fa0f4a..5f67fa80 100644
--- a/tests/worker/test_transformers_chat_inference.py
+++ b/tests/worker/test_transformers_chat_inference.py
@@ -116,18 +116,18 @@ def test_transformers_executor_supports_chat_prompts_and_jsonl_export(
result = executor.run(task, tmp_path)
mock_ensure_model.assert_called_once_with(spec)
- assert [item["output"] for item in result["items"]] == [
+ assert [item["output"] for item in result.items] == [
"first answer",
"second answer",
]
- assert [item["prompt"] for item in result["items"]] == [
+ assert [item["prompt"] for item in result.items] == [
"hello",
"world",
]
- assert [item["metadata"]["row_id"] for item in result["items"]] == ["a", "b"]
- assert result["usage"]["num_requests"] == 2
- assert "latency_sec" in result["usage"]
- assert result["jsonl_export"]["record_count"] == 2
+ assert [item["metadata"]["row_id"] for item in result.items] == ["a", "b"]
+ assert result.usage is not None
+ assert result.usage["num_requests"] == 2
+ assert "latency_sec" in result.usage
exported = (tmp_path / "artifacts" / "rows.jsonl").read_text(encoding="utf-8")
assert '"row_id": "a"' in exported
diff --git a/uv.lock b/uv.lock
index ea21fb32..650e27a2 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2355,7 +2355,7 @@ dependencies = [
requires-dist = [
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "pandas", specifier = ">=2.3.3" },
- { name = "pydantic", specifier = ">=2.0.0" },
+ { name = "pydantic", specifier = ">=2.12.3" },
{ name = "pyyaml", specifier = ">=6.0.0" },
]