diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 0859dfd0..d2540862 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -105,6 +105,28 @@ jobs: --ignore-vuln GHSA-vfmq-68hx-4jfw \ --ignore-vuln GHSA-j7w6-vpvq-j3gm \ --ignore-vuln GHSA-98h9-4798-4q5v \ + --ignore-vuln PYSEC-2025-189 \ + --ignore-vuln PYSEC-2025-190 \ + --ignore-vuln PYSEC-2025-191 \ + --ignore-vuln PYSEC-2025-192 \ + --ignore-vuln PYSEC-2025-193 \ + --ignore-vuln PYSEC-2025-194 \ + --ignore-vuln PYSEC-2025-195 \ + --ignore-vuln PYSEC-2025-196 \ + --ignore-vuln PYSEC-2025-197 \ + --ignore-vuln PYSEC-2025-210 \ + --ignore-vuln PYSEC-2026-139 \ + --ignore-vuln PYSEC-2025-211 \ + --ignore-vuln PYSEC-2025-212 \ + --ignore-vuln PYSEC-2025-213 \ + --ignore-vuln PYSEC-2025-214 \ + --ignore-vuln PYSEC-2025-215 \ + --ignore-vuln PYSEC-2025-216 \ + --ignore-vuln PYSEC-2025-217 \ + --ignore-vuln PYSEC-2025-218 \ + --ignore-vuln PYSEC-2026-97 \ + --ignore-vuln PYSEC-2025-183 \ + --ignore-vuln PYSEC-2024-277 \ -r /tmp/requirements-worker-cpu-audit.txt - name: Run pip-audit (worker GPU delta) # no --strict: flashinfer-jit-cache is unauditable on PyPI run: | @@ -130,4 +152,28 @@ jobs: --ignore-vuln GHSA-83vm-p52w-f9pw \ --ignore-vuln GHSA-j7w6-vpvq-j3gm \ --ignore-vuln GHSA-98h9-4798-4q5v \ + --ignore-vuln PYSEC-2025-189 \ + --ignore-vuln PYSEC-2025-190 \ + --ignore-vuln PYSEC-2025-191 \ + --ignore-vuln PYSEC-2025-192 \ + --ignore-vuln PYSEC-2025-193 \ + --ignore-vuln PYSEC-2025-194 \ + --ignore-vuln PYSEC-2025-195 \ + --ignore-vuln PYSEC-2025-196 \ + --ignore-vuln PYSEC-2025-197 \ + --ignore-vuln PYSEC-2025-210 \ + --ignore-vuln PYSEC-2026-139 \ + --ignore-vuln PYSEC-2025-211 \ + --ignore-vuln PYSEC-2025-212 \ + --ignore-vuln PYSEC-2025-213 \ + --ignore-vuln PYSEC-2025-214 \ + --ignore-vuln PYSEC-2025-215 \ + --ignore-vuln PYSEC-2025-216 \ + --ignore-vuln PYSEC-2025-217 \ + --ignore-vuln PYSEC-2025-218 \ + --ignore-vuln PYSEC-2026-97 \ + --ignore-vuln PYSEC-2025-183 \ + --ignore-vuln PYSEC-2024-277 \ + --ignore-vuln PYSEC-2025-222 \ + --ignore-vuln PYSEC-2024-274 \ -r src/worker/requirements/requirements.gpu.txt diff --git a/docs/CODE_STYLE.md b/docs/CODE_STYLE.md index 33ed7383..394628a9 100644 --- a/docs/CODE_STYLE.md +++ b/docs/CODE_STYLE.md @@ -89,13 +89,13 @@ CI runs `pip-audit` against each generated requirements file When pip-audit reports a new CVE, the only real fix is to bump the offending dep in `pyproject.toml`, then `uv lock` and `uv run scripts/dev/sync_requirements.py --write`. Silencing via `--ignore-vuln` -is a last resort; every silenced GHSA needs a written upgrade-blocker. -The currently-ignored advisories and the upgrade blocker that justifies -each are listed below; the same list is encoded as `--ignore-vuln` -flags in `.github/workflows/security.yml`. +is a last resort; every silenced advisory needs a written +upgrade-blocker. The currently-ignored advisories and the upgrade +blocker that justifies each are listed below; the same list is encoded +as `--ignore-vuln` flags in `.github/workflows/security.yml`. -| GHSA | Package | Fix version | Why ignored | -|------|---------|-------------|-------------| +| Advisory | Package | Fix version | Why ignored | +|----------|---------|-------------|-------------| | GHSA-69w3-r845-3855 | transformers | 5.0.0rc3 | held by vllm/vllm-omni 0.18 compatibility | | GHSA-pf3h-qjgv-vcpr | vllm | 0.19.0 | held by transformers 4.57 + adjacent inference deps | | GHSA-pq5c-rjhq-qp7p | vllm | 0.19.0 | same | @@ -117,6 +117,30 @@ flags in `.github/workflows/security.yml`. | GHSA-w8v5-vhqr-4h9v | diskcache | (none) | upstream unmaintained, no fixed version published | | GHSA-j7w6-vpvq-j3gm | diffusers | 0.38.0 | fix requires safetensors>=0.8.0rc0 pre-release; uv lock won't pick up pre-releases without explicit opt-in | | GHSA-98h9-4798-4q5v | diffusers | 0.38.0 | same blocker as GHSA-j7w6-vpvq-j3gm — both fixed in 0.38.0 | +| PYSEC-2025-189 | torch | (none) | no fix version published | +| PYSEC-2025-190 | torch | (none) | same | +| PYSEC-2025-191 | torch | (none) | same | +| PYSEC-2025-192 | torch | (none) | same | +| PYSEC-2025-193 | torch | (none) | same | +| PYSEC-2025-194 | torch | (none) | same | +| PYSEC-2025-195 | torch | (none) | same | +| PYSEC-2025-196 | torch | (none) | same | +| PYSEC-2025-197 | torch | (none) | same | +| PYSEC-2025-210 | torch | (none) | same | +| PYSEC-2026-139 | torch | (none) | same | +| PYSEC-2025-211 | transformers | (none) | no fix version published; transformers also held by vllm-omni 0.18 | +| PYSEC-2025-212 | transformers | (none) | same | +| PYSEC-2025-213 | transformers | (none) | same | +| PYSEC-2025-214 | transformers | (none) | same | +| PYSEC-2025-215 | transformers | (none) | same | +| PYSEC-2025-216 | transformers | (none) | same | +| PYSEC-2025-217 | transformers | (none) | same | +| PYSEC-2025-218 | transformers | (none) | same | +| PYSEC-2026-97 | nltk | (none) | no fix version published | +| PYSEC-2025-183 | pyjwt | (none) | no fix version published | +| PYSEC-2024-277 | joblib | (none) | no fix version published | +| PYSEC-2025-222 | vllm | (none) | no fix version published; held by vllm-omni 0.18 pin | +| PYSEC-2024-274 | gradio | (none) | no fix version published; vllm-omni 0.18 pins gradio==5.50 | When a blocker lifts (e.g. transformers 5 ↔ vllm 0.19 line stabilizes), drop the corresponding `--ignore-vuln` flag from the workflow and the diff --git a/docs/EXECUTORS.md b/docs/EXECUTORS.md index ed02e903..155efb08 100644 --- a/docs/EXECUTORS.md +++ b/docs/EXECUTORS.md @@ -22,6 +22,26 @@ Helper utilities live in `src/worker/executors/utils/` (`artifacts`, `src/worker/executors/mixins/` (`data`, `governance`, `inference`, `training`). +## Result schema + +Every executor's `run()` returns a subclass of `BaseExecutorResult` +(`src/shared/schemas/result.py`). The base class carries two +cross-cutting fields: + +- `children: dict[str, BaseExecutorResult]` — per-child results when + merged tasks share a dispatch. +- `artifacts: ArtifactContext | None` (wire key `_artifacts`) — + resolution context for relative artifact refs. + +Per-executor subclasses live next to the executor they describe — e.g. +`VLLMResult` in `src/worker/executors/vllm_executor.py`, `LoRAResult` in +`src/worker/executors/lora_sft_executor.py`. They add executor-specific +fields (`items`, `usage`, `final_lora`, `command`, …). + +Artifact-bearing fields use `ArtifactRef` (`{"path": rel_path}`); +relative paths resolve against the producer's `_artifacts` context via +`artifact_to_source` / `_render_artifact_ref`. + ## Agent executor (utu / youtu-agent) `AgentExecutor` requires the following env vars to run; the executor diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 881e7689..97ca5abc 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -13,7 +13,7 @@ license-files = ["LICENSE"] dependencies = [ "httpx>=0.27.0", "pandas>=2.3.3", - "pydantic>=2.0.0", + "pydantic>=2.12.3", "pyyaml>=6.0.0", ] diff --git a/sdk/src/flowmesh/models/__init__.py b/sdk/src/flowmesh/models/__init__.py index e2043805..ea08aa5a 100644 --- a/sdk/src/flowmesh/models/__init__.py +++ b/sdk/src/flowmesh/models/__init__.py @@ -1,5 +1,6 @@ """FlowMesh SDK data models.""" +from .artifacts import ArtifactContext, ArtifactRef from .common import ( LogEntry, LogEvent, @@ -19,7 +20,7 @@ NodeWorkerInfo, WorkerRegisterResponse, ) -from .results import PathResponse, ResultEnvelope +from .results import BaseExecutorResult, PathResponse, ResultEnvelope from .tasks import HardwareUsage, TaskInfo, TaskUsage from .traces import ( ActiveWaitBreakdown, @@ -54,7 +55,10 @@ __all__ = [ "ActiveWaitBreakdown", + "ArtifactContext", + "ArtifactRef", "AssetSummary", + "BaseExecutorResult", "CPUInfo", "CriticalPathSummary", "E2EBreakdown", diff --git a/sdk/src/flowmesh/models/artifacts.py b/sdk/src/flowmesh/models/artifacts.py new file mode 100644 index 00000000..fdbb3edb --- /dev/null +++ b/sdk/src/flowmesh/models/artifacts.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + + +class ArtifactRef(BaseModel): + path: str + + +class ArtifactContext(BaseModel): + base_dir: str + base_url: str | None = Field(default=None, exclude_if=lambda v: v is None) diff --git a/sdk/src/flowmesh/models/results.py b/sdk/src/flowmesh/models/results.py index 3368f1bc..399ea7b8 100644 --- a/sdk/src/flowmesh/models/results.py +++ b/sdk/src/flowmesh/models/results.py @@ -1,8 +1,14 @@ """Result-related models.""" +# This is necessary to allow for the recursive type hint of `children` in +# `BaseExecutorResult`. +from __future__ import annotations + from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny + +from .artifacts import ArtifactContext class PathResponse(BaseModel): @@ -10,11 +16,30 @@ class PathResponse(BaseModel): path: str +class BaseExecutorResult(BaseModel): + model_config = ConfigDict(extra="allow", serialize_by_alias=True) + + ok: bool = True + children: dict[str, SerializeAsAny[BaseExecutorResult]] = Field( + default_factory=dict, exclude_if=lambda v: not v + ) + artifacts_: ArtifactContext | None = Field(default=None, alias="_artifacts") + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: + super().__pydantic_init_subclass__(**kwargs) + if "artifacts_" in cls.__annotations__: + raise TypeError( + f"{cls.__name__} may not redefine the internal " + "BaseExecutorResult.artifacts_ field" + ) + + class ResultEnvelope(BaseModel): """Canonical on-disk shape of ``results.json`` (mirrors the server).""" task_id: str - result: dict[str, Any] + result: SerializeAsAny[BaseExecutorResult] worker_id: str | None = None metadata: dict[str, Any] | None = None received_at: str | None = Field(default=None) diff --git a/sdk/src/flowmesh/resources/results.py b/sdk/src/flowmesh/resources/results.py index 1f5e99cf..3742106b 100644 --- a/sdk/src/flowmesh/resources/results.py +++ b/sdk/src/flowmesh/resources/results.py @@ -198,10 +198,9 @@ def _finalize_materialize( return {}, json_path, extracted envelope = ResultEnvelope.model_validate_json(json_path.read_text()) - if _wants_artifacts(sections): - ctx = envelope.result["_artifacts"] - ctx["base_dir"] = (output_dir / task_id).resolve().as_posix() - ctx.pop("base_url", None) + if _wants_artifacts(sections) and (ctx := envelope.result.artifacts_): + ctx.base_dir = (output_dir / task_id).resolve().as_posix() + ctx.base_url = None payload = envelope.model_dump(mode="json") json_path.write_text(json.dumps(payload, indent=2)) return payload, json_path, extracted diff --git a/src/server/dispatcher/base.py b/src/server/dispatcher/base.py index d3c15bc3..501a9b2d 100644 --- a/src/server/dispatcher/base.py +++ b/src/server/dispatcher/base.py @@ -9,12 +9,18 @@ from pydantic import BaseModel, ValidationError from shared.schemas.event import TaskEvent -from shared.schemas.result import ResultEnvelope, result_file_path, write_result +from shared.schemas.result import ( + BaseExecutorResult, + ResultEnvelope, + result_file_path, + write_result, +) from shared.tasks import ( MergedChildTaskStrict, TaskEnvelope, TaskEnvelopeStrict, TaskEnvelopeTemplate, + TaskSpecStrict, ) from shared.tasks.placeholders import PLACEHOLDER_PATTERN from shared.tasks.specs import ( @@ -32,6 +38,8 @@ from ..utils.time import now_iso from .worker_selector import DEFAULT_WORKER_SELECTION, select_worker +_SENTINEL: Any = object() + class StageReferenceNotReady(Exception): """Raised when a task references a stage whose artifacts are not yet available.""" @@ -707,7 +715,7 @@ def _resolve_stage_references( self, task_id: str, task: TaskEnvelopeTemplate, record: TaskRecord ) -> TaskEnvelopeStrict: context = self._build_stage_context(record) - resolved_task = task + resolved_task: TaskEnvelopeTemplate = task if context and task.has_placeholder(): resolved_task = self._resolve_placeholders(task, context) @@ -716,7 +724,7 @@ def _resolve_stage_references( ) if upstream_results: existing = resolved_task.spec.upstreamResults or {} - merged: dict[str, Any] = {} + merged: dict[str, BaseExecutorResult] = {} if isinstance(existing, dict): merged.update(copy.deepcopy(existing)) merged.update(upstream_results) @@ -730,7 +738,7 @@ def _resolve_stage_references( return TaskEnvelopeStrict.model_validate(resolved_task) - def _resolve_placeholders(self, value: Any, context: dict[str, Any]) -> Any: + def _resolve_placeholders(self, value: Any, context: dict[str, TaskRecord]) -> Any: if isinstance(value, str): exact = PLACEHOLDER_PATTERN.fullmatch(value) if exact: @@ -814,7 +822,7 @@ def _stage_context_keys(record: TaskRecord) -> tuple[str, ...]: keys.append(record.graph_node_name) return tuple(keys) - def _resolve_reference(self, expr: str, context: dict[str, Any]) -> Any: + def _resolve_reference(self, expr: str, context: dict[str, TaskRecord]) -> Any: expr = expr.strip() if not expr: raise ValueError("Empty stage reference") @@ -831,49 +839,47 @@ def _resolve_reference(self, expr: str, context: dict[str, Any]) -> Any: return stage_record.task_id if stage_record.status != "DONE": raise StageReferenceNotReady(f"Stage '{stage_name}' has not completed") - data = self._load_stage_result(stage_record.task_id) - value = self._dig_path(data, path.split(".")) + envelope = self._load_stage_result(stage_record.task_id) + value = self._dig_path(envelope.result, path.split(".")) if value is None: raise ValueError(f"Missing value for reference '{expr}'") # If the referenced value is an artifact ref ({path: "..."}), render # it as a full URL (when base_url is set) or an absolute filesystem # path using the producing stage's top-level _artifacts context. - if rendered := self._render_artifact_ref(value, data): + if rendered := self._render_artifact_ref(value, envelope): return rendered return value @staticmethod - def _render_artifact_ref(value: Any, stage_result: Any) -> str | None: + def _render_artifact_ref(value: Any, stage_result: ResultEnvelope) -> str | None: if not isinstance(value, dict): return None path_value = value.get("path") if not isinstance(path_value, str) or not path_value: return None - if not isinstance(stage_result, dict): - return None - ctx = stage_result.get("_artifacts") - if not isinstance(ctx, dict): + ctx = stage_result.result.artifacts_ + if ctx is None: return None - base_url = ctx.get("base_url") - base_dir = ctx.get("base_dir") - if isinstance(base_url, str) and base_url and isinstance(base_dir, str): + base_url = ctx.base_url + base_dir = ctx.base_dir + if base_url and base_dir: task_id = Path(base_dir).name return f"{base_url.rstrip('/')}/api/v1/results/{task_id}/files/{path_value}" - if isinstance(base_dir, str) and base_dir: + if base_dir: return (Path(base_dir) / "artifacts" / path_value).as_posix() return None def _collect_upstream_results( self, context: dict[str, TaskRecord], current_task_id: str - ) -> dict[str, Any]: - results: dict[str, Any] = {} + ) -> dict[str, BaseExecutorResult]: + results: dict[str, BaseExecutorResult] = {} for name, record in context.items(): if not record or record.task_id == current_task_id: continue if record.status != TaskStatus.DONE: continue try: - data = self._load_stage_result(record.task_id) + envelope = self._load_stage_result(record.task_id) except StageReferenceNotReady as exc: raise exc except Exception as exc: @@ -884,11 +890,11 @@ def _collect_upstream_results( exc, ) continue - results[name] = data + results[name] = envelope.result return results def _resolve_upstream_task_ids( - self, record: TaskRecord, spec: Any + self, record: TaskRecord, spec: TaskSpecStrict ) -> dict[str, str] | None: if not isinstance(spec, SSHSpecStrict) or not spec.inputs: return None @@ -915,14 +921,14 @@ def _resolve_upstream_task_ids( resolved[stage_name] = upstream.task_id return resolved or None - def _load_stage_result(self, stage_task_id: str) -> dict[str, Any]: + def _load_stage_result(self, stage_task_id: str) -> ResultEnvelope: path = result_file_path(self._results_dir, stage_task_id) if not path.exists(): raise StageReferenceNotReady( f"Result for task {stage_task_id} not found at {path}" ) content = json.loads(path.read_text(encoding="utf-8")) - return ResultEnvelope.model_validate(content).result + return ResultEnvelope.model_validate(content) def _dig_path(self, data: Any, parts: list[str]) -> Any: current = data @@ -946,6 +952,11 @@ def _dig_path(self, data: Any, parts: list[str]) -> Any: return None current = current[idx] continue + if isinstance(current, BaseModel): + current = getattr(current, part, _SENTINEL) + if current is _SENTINEL: + return None + continue return None return current @@ -992,7 +1003,7 @@ def _evaluate_condition_skip( ) skip_envelope = ResultEnvelope( task_id=task_id, - result={}, + result=BaseExecutorResult(), metadata={ "skipped": True, "reason": "condition_not_met", diff --git a/src/server/routers/v1/results.py b/src/server/routers/v1/results.py index 9edceb13..e38e8d76 100644 --- a/src/server/routers/v1/results.py +++ b/src/server/routers/v1/results.py @@ -4,7 +4,6 @@ import tarfile import tempfile from pathlib import Path -from typing import Any from fastapi import ( APIRouter, @@ -20,12 +19,12 @@ from pydantic import ValidationError from shared.schemas.result import ( + BaseExecutorResult, ResultEnvelope, read_result, result_file_path, write_result, ) -from shared.utils.atomic import atomic_write_text from shared.utils.manifest import ARTIFACTS_DIR, LOGS_DIR, RESULTS_NAME, sync_manifest from ...app_state import ( @@ -116,7 +115,7 @@ async def get_result( principal: PrincipalContext = Depends(authenticate_connection), results_dir: Path = Depends(get_results_dir), logger: logging.Logger = Depends(get_logger), -) -> dict[str, Any]: +) -> BaseExecutorResult: task_id = (task_id or "").strip() if not task_id: raise HTTPException( @@ -186,8 +185,6 @@ async def upload_result_file( detail=f"Failed to store artifact: {exc}", ) from exc - _rewrite_jsonl_export_paths(task_id, base_dir, target_path) - record = runtime.get_record(task_id) expected_artifacts: list[str] = [] if record: @@ -333,55 +330,6 @@ async def download_task_logs( return FileResponse(target_path) -def _rewrite_jsonl_export_paths( - task_id: str, base_dir: Path, artifact_path: Path -) -> None: - results_path = base_dir / RESULTS_NAME - if not results_path.exists(): - return - try: - payload = json.loads(results_path.read_text(encoding="utf-8")) - except Exception: - return - - filename = artifact_path.name - new_abs = str(artifact_path) - new_relative = f"{ARTIFACTS_DIR}/{filename}" - updated = False - - def _update(entry: dict[str, Any]) -> None: - nonlocal updated - if not isinstance(entry, dict): - return - block = entry.get("jsonl_export") - if isinstance(block, dict): - path_value = str(block.get("path") or "") - if path_value and Path(path_value).name == filename: - if "worker_path" not in block and path_value != new_abs: - block["worker_path"] = path_value - block["path"] = new_abs - block["relative_path"] = new_relative - block.setdefault("url", f"/api/v1/results/{task_id}/files/{filename}") - updated = True - children = entry.get("children") - if isinstance(children, dict): - for child_entry in children.values(): - if isinstance(child_entry, dict): - _update(child_entry) - - result_entry = payload.get("result") if isinstance(payload, dict) else None - if isinstance(result_entry, dict): - _update(result_entry) - - if updated: - try: - atomic_write_text( - results_path, json.dumps(payload, ensure_ascii=False, indent=2) - ) - except Exception: - pass - - def _resolve_bundle_sections(include: list[str]) -> tuple[str, ...]: if not include: return _BUNDLE_SECTIONS_DEFAULT diff --git a/src/shared/schemas/artifact.py b/src/shared/schemas/artifact.py new file mode 100644 index 00000000..724d8819 --- /dev/null +++ b/src/shared/schemas/artifact.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, Field + + +class ArtifactRef(BaseModel): + path: str = Field(description="Path relative to the task's artifacts/ dir.") + + +class ArtifactContext(BaseModel): + base_dir: str = Field(description="Producing task's output directory.") + base_url: str | None = Field( + default=None, + exclude_if=lambda v: v is None, + description="HTTP origin (scheme://host[:port]) for upload.", + ) diff --git a/src/shared/schemas/result.py b/src/shared/schemas/result.py index 0eea3e5c..e4378f90 100644 --- a/src/shared/schemas/result.py +++ b/src/shared/schemas/result.py @@ -1,18 +1,57 @@ """Result envelope schema shared by server and worker.""" +# This is necessary to allow for the recursive type hint of `children` in +# `BaseExecutorResult`. +from __future__ import annotations + from pathlib import Path from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny from shared.utils.atomic import atomic_write_text from shared.utils.manifest import prepare_output_dir from shared.utils.time import now_iso +from .artifact import ArtifactContext + + +class BaseExecutorResult(BaseModel): + """Common shape for every executor's result payload. + + ``extra="allow"`` lets the server round-trip subclass payloads through + this base class without losing executor-specific fields. + """ + + model_config = ConfigDict(extra="allow", serialize_by_alias=True) + + ok: bool = Field(default=True, description="Whether task execution succeeded.") + children: dict[str, SerializeAsAny[BaseExecutorResult]] = Field( + default_factory=dict, + exclude_if=lambda v: not v, + description="Per-child result payloads for task merging.", + ) + artifacts_: ArtifactContext | None = Field( + default=None, + alias="_artifacts", + description="Resolution context for relative artifact refs.", + ) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: + super().__pydantic_init_subclass__(**kwargs) + if "artifacts_" in cls.__annotations__: + raise TypeError( + f"{cls.__name__} may not redefine the internal " + "BaseExecutorResult.artifacts_ field" + ) + class ResultEnvelope(BaseModel): task_id: str = Field(description="Task identifier.") - result: dict[str, Any] = Field(description="Result payload data.") + result: SerializeAsAny[BaseExecutorResult] = Field( + description="Result payload data." + ) worker_id: str | None = Field( default=None, description="Worker identifier submitting the result." ) @@ -40,13 +79,6 @@ def write_result(base_dir: Path, envelope: ResultEnvelope) -> Path: return path -def write_result_in_envelope(path: Path, task_id: str, result: dict[str, Any]) -> None: - """Wrap ``result`` in a ``ResultEnvelope`` and persist it at ``path``.""" - path.parent.mkdir(parents=True, exist_ok=True) - envelope = ResultEnvelope(task_id=task_id, result=result) - atomic_write_text(path, envelope.model_dump_json(indent=2)) - - def read_result(base_dir: Path, task_id: str) -> str: path = result_file_path(base_dir, task_id) return path.read_text(encoding="utf-8") diff --git a/src/shared/tasks/specs/common.py b/src/shared/tasks/specs/common.py index d6cb61db..e841e596 100644 --- a/src/shared/tasks/specs/common.py +++ b/src/shared/tasks/specs/common.py @@ -2,6 +2,7 @@ from pydantic import Field, model_validator +from ...schemas.result import BaseExecutorResult from .._base import StrictBaseModel, TemplateBaseModel from ..components import ( AdapterConfig, @@ -72,7 +73,7 @@ class TaskSpecStrictBase(StrictBaseModel): shard: ShardSpec | None = None # Server-injected stage context (reserve the user-facing key `_upstreamResults`) - upstreamResults: dict[str, Any] | None = Field( + upstreamResults: dict[str, BaseExecutorResult] | None = Field( default=None, alias="_upstreamResults" ) @@ -98,7 +99,7 @@ class TaskSpecTemplateBase(TemplateBaseModel): condition: ConditionSpec | None = None shard: ShardSpecTemplate | None = None - upstreamResults: dict[str, Any] | None = Field( + upstreamResults: dict[str, BaseExecutorResult] | None = Field( default=None, alias="_upstreamResults" ) diff --git a/src/worker/executors/agent_executor.py b/src/worker/executors/agent_executor.py index c12d4003..db59687b 100644 --- a/src/worker/executors/agent_executor.py +++ b/src/worker/executors/agent_executor.py @@ -16,14 +16,12 @@ from datasets import load_dataset +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import AgentSpecStrict from .base_executor import ExecutionError, Executor, ExecutorTask -from .utils.checkpoints import ( - artifact_ref, - maybe_upload_artifacts, - write_executor_result, -) +from .utils.checkpoints import maybe_upload_artifacts, write_executor_result from .utils.graph_templates import build_prompts_from_graph_template # Add agent directory to sys.path for utu imports @@ -54,6 +52,16 @@ def _resolve_task_timeout(agent: dict[str, Any] | None) -> int: logger = logging.getLogger("worker.agent") +class AgentResult(BaseExecutorResult): + ok: bool = True + model: str + items: list[dict[str, Any]] = [] + usage: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None + agent_output: ArtifactRef | None = None + batch_summary_file: ArtifactRef | None = None + + class AgentExecutor(Executor): """Agent executor using youtu-agent (utu) framework""" @@ -223,7 +231,7 @@ def prepare_data(self, spec: AgentSpecStrict) -> None: else: raise ExecutionError(f"Unsupported spec.data.type: {dtype!r}") - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> AgentResult: """Execute agent tasks using youtu-agent (utu) framework""" self.ensure_dir(out_dir) @@ -248,32 +256,34 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: agent_config_name, self._tasks[0], out_dir, task_timeout ) ) + agent_output_ref = result.get("_agent_output_ref") - output: dict[str, Any] = { - "ok": True, - "model": agent_config_name, - "items": [ + output = AgentResult( + model=agent_config_name, + items=[ { "index": 0, "output": result.get("output", ""), "finish_reason": "completed", } ], - "usage": { + usage={ "execution_time_sec": result.get("usage", {}).get( "execution_time_sec", 0 ), "num_requests": 1, "agent_config": agent_config_name, }, - "metadata": { + metadata={ "task": self._tasks[0], "execution_log": result.get("log", []), }, - } - agent_output_ref = result.get("_agent_output_ref") - if isinstance(agent_output_ref, str): - output["agent_output"] = artifact_ref(agent_output_ref) + agent_output=( + None + if agent_output_ref is None + else ArtifactRef(path=agent_output_ref) + ), + ) else: # Batch execution for multiple tasks results = asyncio.run( @@ -297,26 +307,28 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: } ) - output = { - "ok": True, - "model": agent_config_name, - "items": items, - "usage": { + batch_summary_ref = results.get("_batch_summary_ref") + output = AgentResult( + model=agent_config_name, + items=items, + usage={ "execution_time_sec": results.get("usage", {}).get( "execution_time_sec", 0 ), "num_requests": len(self._tasks), "agent_config": agent_config_name, }, - "metadata": { + metadata={ "tasks_count": len(self._tasks), "execution_log": results.get("log", []), "batch_summary": results.get("batch_summary", {}), }, - } - batch_summary_ref = results.get("_batch_summary_ref") - if isinstance(batch_summary_ref, str): - output["batch_summary_file"] = artifact_ref(batch_summary_ref) + batch_summary_file=( + None + if batch_summary_ref is None + else ArtifactRef(path=batch_summary_ref) + ), + ) maybe_upload_artifacts(task, out_dir, logger=logger) @@ -327,21 +339,21 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: raise except Exception as e: logger.exception(f"Agent execution failed: {e}") - error_output = { - "ok": False, - "model": agent_config_name, - "items": [], - "usage": { + error_output = AgentResult( + ok=False, + model=agent_config_name, + items=[], + usage={ "execution_time_sec": 0, "num_requests": len(self._tasks), "agent_config": agent_config_name, }, - "metadata": { + metadata={ "tasks_count": len(self._tasks), "error": str(e), "execution_log": [], }, - } + ) write_executor_result( out_dir / "results.json", task.task_id, task.spec, error_output ) diff --git a/src/worker/executors/api_executor.py b/src/worker/executors/api_executor.py index 4129ce0f..bc0e9bde 100644 --- a/src/worker/executors/api_executor.py +++ b/src/worker/executors/api_executor.py @@ -5,7 +5,9 @@ from typing import Any, ClassVar import httpx +from pydantic import Field +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import ApiSpecStrict from .base_executor import ExecutionError, Executor, ExecutorTask @@ -16,6 +18,18 @@ _ClientKey = tuple[str, float, bool, bool] +class APIResult(BaseExecutorResult): + executor: str + method: str + url: str + status_code: int + truncated: bool = False + headers: dict[str, str] | None = None + response_json: Any = Field(default=None, alias="json") + usage: dict[str, Any] | None = None + text: str | None = None + + class APIExecutor(Executor): """Executor that performs a single HTTP request defined by task YAML. @@ -89,7 +103,7 @@ def cleanup_after_run(self) -> None: """Close the connection pool when the runner deactivates this executor.""" self.close_all_clients() - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> APIResult: spec = self.require_spec(task, ApiSpecStrict) api_cfg = spec.api or {} if not isinstance(api_cfg, dict): @@ -173,17 +187,17 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: body_bytes = body_bytes[:max_body_bytes] truncated = True - result: dict[str, Any] = { - "ok": resp.is_success, - "executor": self.name, - "method": method, - "url": str(resp.url), - "status_code": resp.status_code, - "truncated": truncated, - } + result = APIResult( + ok=resp.is_success, + executor=self.name, + method=method, + url=str(resp.url), + status_code=resp.status_code, + truncated=truncated, + ) if include_headers: - result["headers"] = dict(resp.headers) + result.headers = dict(resp.headers) body_text: str | None = None if return_body: @@ -191,26 +205,25 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: body_text = body_bytes.decode(encoding, errors="replace") if parse_json: - result["json"] = resp.json() - if not isinstance(result["json"], dict): + result.response_json = resp.json() + if not isinstance(result.response_json, dict): raise ExecutionError("Response is not a valid JSON mapping") - if isinstance(result["json"], dict): - usage = result["json"].get("usage") - if not isinstance(usage, dict): - raise ExecutionError( - "spec.api.response.parse_json is true but response JSON " - f"does not contain usage info: {result['json']}" - ) - result["usage"] = usage + usage = result.response_json.get("usage") + if not isinstance(usage, dict): + raise ExecutionError( + "spec.api.response.parse_json is true but response JSON " + f"does not contain usage info: {result.response_json}" + ) + result.usage = usage try: - result["text"] = result["json"]["choices"][0]["message"]["content"] + result.text = result.response_json["choices"][0]["message"]["content"] except Exception as exc: raise ExecutionError( "spec.api.response.parse_json is true but response JSON " - f"does not contain message.content: {result['json']}" + f"does not contain message.content: {result.response_json}" ) from exc elif return_body: - result["text"] = body_text + result.text = body_text if raise_for_status and resp.is_error: if body_text: diff --git a/src/worker/executors/base_executor.py b/src/worker/executors/base_executor.py index f4f43f90..f5a20245 100644 --- a/src/worker/executors/base_executor.py +++ b/src/worker/executors/base_executor.py @@ -2,19 +2,23 @@ Executor base class and a minimal example implementation. Usage: - from executor_base import Executor, ExecutionError, EchoExecutor + from shared.schemas.result import BaseExecutorResult + from worker.executors.base_executor import Executor, ExecutionError + + class MyResult(BaseExecutorResult): + echo: str class MyExecutor(Executor): name = "my-executor" - def run(self, task: ExecutorTask, out_dir: Path) -> dict: + def run(self, task: ExecutorTask, out_dir: Path) -> MyResult: # ... your logic ... - return {"ok": True, "echo": task.task_id} + return MyResult(echo=task.task_id) Contract: -- Implement `run(task: ExecutorTask, out_dir: Path) -> dict`. The runner - writes the returned dict to `out_dir/results.json` and injects the - top-level `_artifacts` context — executors should not write that file - themselves on the success path. +- Implement `run(task: ExecutorTask, out_dir: Path) -> BaseExecutorResult`. + The runner writes the returned model to `out_dir/results.json` and + injects the top-level `_artifacts` context — executors should not write + that file themselves on the success path. - Drop generated files under `out_dir/artifacts/` (uploaded to the server when the task has an HTTP destination) or `scratch_dir(out_dir)` for local-only scratch data. @@ -27,6 +31,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict: from pathlib import Path from typing import Any, TypeVar +from shared.schemas.result import BaseExecutorResult from shared.tasks import MergedChildTaskStrict from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.worker_message import WorkerHardware, WorkerTaskMessage @@ -84,7 +89,7 @@ def prepare(self) -> None: return None @abstractmethod - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> BaseExecutorResult: """Execute a single task. Args: @@ -93,7 +98,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: if needed. Returns: - A JSON-serializable dictionary summarizing the result. + A ``BaseExecutorResult`` subclass instance. Raises: ExecutionError: for expected, user-facing failures. @@ -139,13 +144,19 @@ def load_json(path: Path) -> dict[str, Any]: # -------- Minimal example implementation -------- +class EchoResult(BaseExecutorResult): + ok: bool = True + executor: str + task_id: str + task_type: str + + class EchoExecutor(Executor): name = "echo" - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - return { - "ok": True, - "executor": self.name, - "task_id": task.task_id, - "task_type": task.spec.taskType, - } + def run(self, task: ExecutorTask, out_dir: Path) -> EchoResult: + return EchoResult( + executor=self.name, + task_id=task.task_id, + task_type=task.spec.taskType, + ) diff --git a/src/worker/executors/data_profiling_executor.py b/src/worker/executors/data_profiling_executor.py index bfb7678f..57f4249c 100644 --- a/src/worker/executors/data_profiling_executor.py +++ b/src/worker/executors/data_profiling_executor.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Any +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import DataProfilingSpecStrict from shared.utils.json import to_json_serializable, validate_keys @@ -19,19 +20,25 @@ logger = logging.getLogger(__name__) +class DataProfilingResult(BaseExecutorResult): + ok: bool = True + type: str = "sql" + template: str | None = None + cost_estimates: dict[str, Any] | None = None + + class DataProfilingExecutor(DataMixin, Executor): """Executor that estimates SQL query costs by sampling SQL template params.""" name = "data_profiling" - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> DataProfilingResult: spec = self.require_spec(task, DataProfilingSpecStrict) task_id = task.task_id merge_children = task.merged_children or [] result = self._run_single_profile(spec, task_id) - child_results: dict[str, dict[str, Any]] = {} for child in merge_children: child_id = child.task_id child_spec = child.spec @@ -39,16 +46,13 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: raise ExecutionError( "Merged child spec must be data_profiling for merged profiling" ) - child_results[child_id] = self._run_single_profile(child_spec, child_id) - - if child_results: - result["children"] = child_results + result.children[child_id] = self._run_single_profile(child_spec, child_id) return result def _run_single_profile( self, spec: DataProfilingSpecStrict, task_id: str - ) -> dict[str, Any]: + ) -> DataProfilingResult: data_cfg = spec.data if not isinstance(data_cfg, dict): raise ExecutionError( @@ -100,14 +104,12 @@ def _run_single_profile( connection_string, queries, params_rows ) - result: dict[str, Any] = { - "ok": True, - "type": "sql", - "template": template_str, - "cost_estimates": cost_estimates, - } - - return result + return DataProfilingResult( + ok=True, + type="sql", + template=template_str, + cost_estimates=cost_estimates, + ) def _sample_template_queries( self, diff --git a/src/worker/executors/data_retrieval_executor.py b/src/worker/executors/data_retrieval_executor.py index 0321f147..4347a262 100644 --- a/src/worker/executors/data_retrieval_executor.py +++ b/src/worker/executors/data_retrieval_executor.py @@ -10,6 +10,8 @@ import pandas as pd from PIL import Image +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import DataRetrievalSpecStrict from shared.utils.json import validate_keys @@ -17,20 +19,23 @@ from ..utils.serialization import serialize_dataframe from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import DataMixin -from .utils.checkpoints import ( - artifact_ref, - maybe_upload_artifacts, - maybe_upload_traces, -) +from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces from .utils.graph_templates import _render_template, _resolve_columns logger = logging.getLogger(__name__) +class DataRetrievalResult(BaseExecutorResult): + type: str | None = None + items: list[dict[str, Any]] = [] + count: int | None = None + metadata: dict[str, Any] | None = None + + class DataRetrievalExecutor(DataMixin, Executor): name = "data_retrieval" - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> DataRetrievalResult: spec = self.require_spec(task, DataRetrievalSpecStrict) task_id = task.task_id with self._task_span( @@ -71,15 +76,15 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: return result def _run_sql( - self, data_cfg: dict[str, Any], context: dict[str, Any] - ) -> dict[str, Any]: + self, data_cfg: dict[str, Any], context: dict[str, BaseExecutorResult] + ) -> DataRetrievalResult: """ Execute SQL queries based on the provided data configuration and context. :param data_cfg: Description :type data_cfg: dict[str, Any] :param context: Description - :type context: dict[str, Any] + :type context: dict[str, BaseExecutorResult] :return: Description :rtype: dict[str, Any] """ @@ -144,18 +149,17 @@ def _run_sql( } ) - return { - "ok": True, - "items": items, - "count": len(items), - } + return DataRetrievalResult( + items=items, + count=len(items), + ) def _run_s3( self, data_cfg: dict[str, Any], - context: dict[str, Any], + context: dict[str, BaseExecutorResult], out_dir: Path, - ) -> dict[str, Any]: + ) -> DataRetrievalResult: validate_keys( data_cfg, "DataRetrievalExecutor.spec.data", @@ -207,20 +211,18 @@ def _run_s3( item["params"] = params_rows items.append(item) - result = { - "ok": True, - "type": "s3", - "items": items, - "metadata": s3_result["metadata"], # type: ignore - } - return result + return DataRetrievalResult( + type="s3", + items=items, + metadata=s3_result["metadata"], # type: ignore + ) def _run_agent( self, data_cfg: dict[str, Any], - context: dict[str, Any], + context: dict[str, BaseExecutorResult], out_dir: Path, - ) -> dict[str, Any]: + ) -> DataRetrievalResult: """Drive lumid.data's data agent for NL-driven retrieval.""" validate_keys( data_cfg, @@ -314,12 +316,11 @@ def _run_agent( } ) - return { - "ok": True, - "type": "agent", - "items": items, - "count": len(items), - } + return DataRetrievalResult( + type="agent", + items=items, + count=len(items), + ) def _load_table(self, path: Path, output_format: str) -> pd.DataFrame: """Load the materialized retrieval file into a DataFrame.""" @@ -338,5 +339,5 @@ def _serialize_s3_content(self, content: Any, out_dir: Path) -> Any: filename = f"{uuid.uuid4().hex}.png" file_path = images_dir / filename content.save(file_path, format="PNG") - return artifact_ref(f"s3_images/{filename}") + return ArtifactRef(path=f"s3_images/{filename}") return content diff --git a/src/worker/executors/diffusers_executor.py b/src/worker/executors/diffusers_executor.py index f74fb8e4..0c782fce 100644 --- a/src/worker/executors/diffusers_executor.py +++ b/src/worker/executors/diffusers_executor.py @@ -14,16 +14,14 @@ from PIL import Image +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import DiffusionSpecStrict from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import DataMixin -from .utils.checkpoints import ( - artifact_ref, - maybe_upload_artifacts, - maybe_upload_traces, -) +from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces try: import torch @@ -45,6 +43,12 @@ logger = logging.getLogger(__name__) +class DiffusersResult(BaseExecutorResult): + ok: bool = True + model: str | None = None + images: list[ArtifactRef] = [] + + class DiffusersExecutor(DataMixin, Executor): """Executor that runs text-to-image generation via Hugging Face Diffusers.""" @@ -240,24 +244,24 @@ def _encode_and_combine_prompts( return combined_pos, combined_neg, user_pos_pooled, user_neg_pooled - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> DiffusersResult: configure_hf_library_logging() spec = self.require_spec(task, DiffusionSpecStrict) task_id = task.task_id.strip() with self._task_span( task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - response = self._run_inner(spec, task_id, out_dir) + result = self._run_inner(spec, task_id, out_dir) maybe_upload_artifacts(task, out_dir, logger=logger) maybe_upload_traces(task, out_dir, logger=logger) - return response + return result def _run_inner( self, spec: DiffusionSpecStrict, task_id: str, out_dir: Path, - ) -> dict[str, Any]: + ) -> DiffusersResult: self._ensure_pipeline(spec) assert self._pipe is not None @@ -336,24 +340,18 @@ def _run_inner( image_dir = out_dir / "artifacts" / "images" image_dir.mkdir(parents=True, exist_ok=True) - generated_images: list[dict[str, str]] = [] + generated_images: list[ArtifactRef] = [] for idx, img in enumerate(images): img.save(image_dir / f"image_{idx}.png", format="PNG") - generated_images.append(artifact_ref(f"images/image_{idx}.png")) - - response: dict[str, Any] = { - "ok": True, - "model": self._model_name, - "images": generated_images, - } + generated_images.append(ArtifactRef(path=f"images/image_{idx}.png")) + result = DiffusersResult(model=self._model_name, images=generated_images) self._dump_to_governance( task_id=task_id, - result=response, + result=result, dependencies_by_task=dependencies_by_task, ) - - return response + return result def cleanup_after_run(self) -> None: if self._pipe is not None: diff --git a/src/worker/executors/dpo_executor.py b/src/worker/executors/dpo_executor.py index 7dc1e8d2..b62256f2 100644 --- a/src/worker/executors/dpo_executor.py +++ b/src/worker/executors/dpo_executor.py @@ -25,6 +25,8 @@ from trl.trainer.dpo_config import DPOConfig from trl.trainer.dpo_trainer import DPOTrainer +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import DPOSpecStrict from shared.utils.manifest import scratch_dir @@ -33,7 +35,6 @@ from .mixins.training import TrainingMixin from .utils.checkpoints import ( archive_model_dir, - artifact_ref, get_http_destination, maybe_upload_artifacts, write_executor_result, @@ -45,6 +46,18 @@ logger = logging.getLogger("worker.dpo") +class DPOResult(BaseExecutorResult): + training_time_seconds: float | None = None + error_message: str | None = None + model_name: str | None = None + dataset_size: int = 0 + output_dir: str | None = None + checkpoints_dir: ArtifactRef | None = None + final_model: ArtifactRef | None = None + final_model_archive: ArtifactRef | None = None + spawned_torchrun: bool = False + + class DPOExecutor(TrainingMixin, Executor): """DPO training executor using TRL library.""" @@ -59,7 +72,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._current_trainer: DPOTrainer | None = None self._task_out_dir: Path | None = None - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> DPOResult: configure_hf_library_logging() logger.info("Starting DPO training task") spec = self.require_spec(task, DPOSpecStrict) @@ -80,22 +93,21 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ) ipc_path = scratch_dir(out_dir) / "distributed_result.json" if ipc_path.exists(): - return self.load_json(ipc_path) - return { - "training_successful": True, - "spawned_torchrun": True, - "model_name": ( - spec.model - and spec.model.source - and spec.model.source.identifier + return DPOResult.model_validate(self.load_json(ipc_path)) + return DPOResult( + spawned_torchrun=True, + model_name=( + spec.model.source.identifier + if spec.model and spec.model.source + else None ), - "output_dir": out_dir.as_posix(), - } + output_dir=out_dir.as_posix(), + ) result = self._execute_training(task, out_dir) logger.info( "DPO training task completed in %.2f seconds", - result.get("training_time_seconds", 0.0), + result.training_time_seconds or 0.0, ) return result finally: @@ -354,7 +366,7 @@ def _ensure_jsonl_local(jsonl_cfg: dict[str, Any]) -> Path: return dataset # type: ignore[return-value] - def _execute_training(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def _execute_training(self, task: ExecutorTask, out_dir: Path) -> DPOResult: spec = self.require_spec(task, DPOSpecStrict) training_config = spec.training or {} artifacts_dir = out_dir / "artifacts" @@ -462,23 +474,28 @@ def _execute_training(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any] logger.warning("Failed to save model: %s", exc) training_time = time.time() - start_time - results: dict[str, Any] = { - "training_successful": True, - "training_time_seconds": training_time, - "error_message": None, - "model_name": self._model_name, - "dataset_size": len(dataset), - "output_dir": out_dir.as_posix(), - "checkpoints_dir": artifact_ref("checkpoints"), - } - if final_model_path is not None: - results["final_model"] = artifact_ref( - final_model_path.relative_to(artifacts_dir).as_posix() - ) - if final_archive_path is not None: - results["final_model_archive"] = artifact_ref( - final_archive_path.relative_to(artifacts_dir).as_posix() - ) + result = DPOResult( + training_time_seconds=training_time, + error_message=None, + model_name=self._model_name, + dataset_size=len(dataset) if dataset is not None else 0, + output_dir=out_dir.as_posix(), + checkpoints_dir=ArtifactRef(path="checkpoints"), + final_model=( + ArtifactRef( + path=final_model_path.relative_to(artifacts_dir).as_posix() + ) + if final_model_path + else None + ), + final_model_archive=( + ArtifactRef( + path=final_archive_path.relative_to(artifacts_dir).as_posix() + ) + if final_archive_path + else None + ), + ) maybe_upload_artifacts(task, out_dir, logger=logger) @@ -488,24 +505,22 @@ def _execute_training(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any] final_model_path, final_archive_path, ) - return results + return result except Exception as exc: training_time = time.time() - start_time - results = { - "training_successful": False, - "training_time_seconds": training_time, - "error_message": str(exc), - "model_name": self._model_name, - "dataset_size": len(dataset) if dataset is not None else 0, - "output_dir": out_dir.as_posix(), - } + result = DPOResult( + ok=False, + training_time_seconds=training_time, + error_message=str(exc), + model_name=self._model_name, + dataset_size=len(dataset) if dataset is not None else 0, + output_dir=out_dir.as_posix(), + ) write_executor_result( - out_dir / "results.json", task.task_id, task.spec, results + out_dir / "results.json", task.task_id, task.spec, result ) logger.exception("DPO training failed: %s", exc) - raise ExecutionError( - results["error_message"] or "DPO training failed" - ) from exc + raise ExecutionError(result.error_message or "DPO training failed") from exc def _spawn_distributed( self, diff --git a/src/worker/executors/echo_executor.py b/src/worker/executors/echo_executor.py index 480938a4..1a03662f 100644 --- a/src/worker/executors/echo_executor.py +++ b/src/worker/executors/echo_executor.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import EchoSpecStrict from .base_executor import ExecutionError, Executor, ExecutorTask @@ -14,6 +15,11 @@ type EchoItem = str | dict[str, str] +class EchoResult(BaseExecutorResult): + items: list[dict[str, Any]] = [] + count: int = 0 + + class EchoExecutor(DataMixin, Executor): name = "echo" @@ -25,7 +31,9 @@ def _append_outputs(self, out_items: list[dict[str, Any]], value: Any) -> None: out_items.append({"output": value}) @staticmethod - def _resolve_expr_item(item: dict[str, Any], context: dict[str, Any]) -> Any: + def _resolve_expr_item( + item: dict[str, Any], context: dict[str, BaseExecutorResult] + ) -> Any: expr = item.get("expr") if not expr: node = item.get("node") @@ -44,7 +52,9 @@ def _resolve_expr_item(item: dict[str, Any], context: dict[str, Any]) -> Any: ) return resolved - def _resolve_item(self, item: EchoItem, context: dict[str, Any]) -> Any: + def _resolve_item( + self, item: EchoItem, context: dict[str, BaseExecutorResult] + ) -> Any: if isinstance(item, str): return item elif isinstance(item, dict): @@ -55,7 +65,7 @@ def _resolve_item(self, item: EchoItem, context: dict[str, Any]) -> Any: "a string literal or a mapping" ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> EchoResult: spec = self.require_spec(task, EchoSpecStrict) task_id = task.task_id.strip() with self._task_span( @@ -81,18 +91,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: resolved = self._resolve_item(item, context) self._append_outputs(merged_items, resolved) - payload: dict[str, Any] = { - "ok": True, - "items": merged_items, - "count": len(merged_items), - } + result = EchoResult(items=merged_items, count=len(merged_items)) deps = self._extract_source_data_ids(spec) dependencies_by_task = {task_id: deps} self._dump_to_governance( task_id=task_id, - result=payload, + result=result, dependencies_by_task=dependencies_by_task, ) maybe_upload_traces(task, out_dir, logger=logger) - return payload + return result diff --git a/src/worker/executors/lora_sft_executor.py b/src/worker/executors/lora_sft_executor.py index f40c99ad..9e36e220 100644 --- a/src/worker/executors/lora_sft_executor.py +++ b/src/worker/executors/lora_sft_executor.py @@ -18,6 +18,8 @@ from trl.trainer.sft_config import SFTConfig from trl.trainer.sft_trainer import SFTTrainer +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import LoRASFTSpecStrict from ..utils.logging import configure_hf_library_logging @@ -26,7 +28,6 @@ from .sft_executor import SFTExecutor from .utils.checkpoints import ( archive_model_dir, - artifact_ref, determine_resume_path, maybe_upload_artifacts, write_executor_result, @@ -48,6 +49,18 @@ logger = logging.getLogger("worker.sft.lora") +class LoRAResult(BaseExecutorResult): + training_time_seconds: float | None = None + error_message: str | None = None + model_name: str | None = None + dataset_size: int = 0 + output_dir: str | None = None + checkpoints_dir: ArtifactRef | None = None + resume_from_path: str | None = None + final_lora: ArtifactRef | None = None + final_lora_archive: ArtifactRef | None = None + + class LoRASFTExecutor(TrainingMixin, Executor): """Execute LoRA-based supervised fine-tuning using TRL's SFTTrainer.""" @@ -59,11 +72,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._current_model: Any | None = None self._current_trainer: Any | None = None - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> LoRAResult: configure_hf_library_logging() spec = self.require_spec(task, LoRASFTSpecStrict) start_time = time.time() - training_successful = False + ok = False error_msg: str | None = None if ( @@ -99,6 +112,8 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: checkpoint_dir = artifacts_dir / "checkpoints" checkpoint_dir.mkdir(parents=True, exist_ok=True) + resume_path: Path | None = None + train_dataset: Dataset | None = None try: model_name = spec.model_name or "gpt2" self._model_name = model_name @@ -121,7 +136,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: resume_path = determine_resume_path( spec, training_cfg, out_dir, logger=logger ) - resume_str = str(resume_path) if resume_path else None + resume_str = resume_path.as_posix() if resume_path else None if bool(training_cfg.get("gradient_checkpointing", False)): model.gradient_checkpointing_enable() @@ -139,9 +154,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: "Resuming LoRA training from local checkpoint %s", resume_path ) peft_model = PeftModel.from_pretrained( - model, - str(resume_path), - is_trainable=True, + model, resume_path.as_posix(), is_trainable=True ) logger.info("Loaded existing LoRA adapters; continuing fine-tuning") else: @@ -182,7 +195,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: logger.info("Initialized new LoRA adapters: %s", lora_target_modules) sft_config = SFTConfig( - output_dir=str(checkpoint_dir), + output_dir=checkpoint_dir.as_posix(), num_train_epochs=float(training_cfg.get("num_train_epochs", 1.0)), per_device_train_batch_size=int(training_cfg.get("batch_size", 2)), gradient_accumulation_steps=int( @@ -277,41 +290,42 @@ def _compute_loss_with_guard( logger.info("Starting LoRA supervised fine-tuning run") trainer.train() - training_successful = True + ok = True logger.info("LoRA SFT training completed") final_adapter_path: Path | None = None archive_path: Path | None = None if bool(training_cfg.get("save_model", True)): model_path = artifacts_dir / "final_lora" - trainer.save_model(str(model_path)) + trainer.save_model(model_path.as_posix()) tokenizer.save_pretrained(model_path) final_adapter_path = model_path logger.info("Saved LoRA-adapted weights to %s", model_path) training_time = time.time() - start_time - result_payload: dict[str, Any] = { - "task_id": task.task_id, - "training_successful": training_successful, - "training_time_seconds": training_time, - "error_message": error_msg, - "model_name": self._model_name, - "dataset_size": ( - len(train_dataset) if "train_dataset" in locals() else 0 - ), - "output_dir": out_dir.as_posix(), - "checkpoints_dir": artifact_ref("checkpoints"), - "resume_from_path": resume_str, - } - if final_adapter_path is not None: - result_payload["final_lora"] = artifact_ref( - final_adapter_path.relative_to(artifacts_dir).as_posix() + final_lora: ArtifactRef | None = None + final_lora_archive: ArtifactRef | None = None + if final_adapter_path: + final_lora = ArtifactRef( + path=final_adapter_path.relative_to(artifacts_dir).as_posix() ) archive_path = archive_model_dir(final_adapter_path) - result_payload["final_lora_archive"] = artifact_ref( - archive_path.relative_to(artifacts_dir).as_posix() + final_lora_archive = ArtifactRef( + path=archive_path.relative_to(artifacts_dir).as_posix() ) logger.info("Prepared LoRA archive at %s", archive_path) + result = LoRAResult( + ok=ok, + training_time_seconds=training_time, + error_message=error_msg, + model_name=self._model_name, + dataset_size=len(train_dataset) if train_dataset is not None else 0, + output_dir=out_dir.as_posix(), + checkpoints_dir=ArtifactRef(path="checkpoints"), + resume_from_path=resume_str, + final_lora=final_lora, + final_lora_archive=final_lora_archive, + ) maybe_upload_artifacts(task, out_dir, logger=logger) @@ -321,30 +335,27 @@ def _compute_loss_with_guard( final_adapter_path, archive_path, ) - return result_payload + return result except Exception as exc: # pylint: disable=broad-except error_msg = str(exc) - training_successful = False + ok = False logger.exception("LoRA SFT training failed: %s", exc) training_time = time.time() - start_time - result: dict[str, Any] = { - "task_id": task.task_id, - "training_successful": training_successful, - "training_time_seconds": training_time, - "error_message": error_msg, - "model_name": self._model_name, - "dataset_size": len(train_dataset) if "train_dataset" in locals() else 0, - "output_dir": out_dir.as_posix(), - "checkpoints_dir": artifact_ref("checkpoints"), - "resume_from_path": ( - str(resume_path) if "resume_path" in locals() and resume_path else None - ), - } - - if training_successful: + result = LoRAResult( + ok=ok, + training_time_seconds=training_time, + error_message=error_msg, + model_name=self._model_name, + dataset_size=len(train_dataset) if train_dataset is not None else 0, + output_dir=out_dir.as_posix(), + checkpoints_dir=ArtifactRef(path="checkpoints"), + resume_from_path=resume_path.as_posix() if resume_path else None, + ) + + if ok: return result write_executor_result(out_dir / "results.json", task.task_id, task.spec, result) diff --git a/src/worker/executors/mixins/data.py b/src/worker/executors/mixins/data.py index ee45aab7..d0ee4324 100644 --- a/src/worker/executors/mixins/data.py +++ b/src/worker/executors/mixins/data.py @@ -15,6 +15,7 @@ from datasets import Dataset, load_dataset from PIL import Image +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import TaskSpecStrictBase from shared.utils.json import safe_get @@ -403,7 +404,7 @@ def _collect_prompts_for_spec( metadata_raw.append(entry_meta) elif dtype == "list": items = data.get("items") - context: dict[str, Any] | None = None + context: dict[str, BaseExecutorResult] | None = None root_node: str | None = None if items is None: expr = data.get("expr") @@ -702,14 +703,11 @@ def _collect_prompts_for_spec( ) def _populate_table( - self, - payload: dict[str, Any], - table_stores_list: list[pd.DataFrame], - ): + self, items: list[dict[str, Any]], table_stores_list: list[pd.DataFrame] + ) -> list[dict[str, Any]]: """ Group row-level generation outputs back into per-table outputs. """ - items = payload["items"] cur = 0 grouped_items: list[dict[str, Any]] = [] for df in table_stores_list: @@ -724,8 +722,7 @@ def _populate_table( f"Output length {len(items)} does not match " f"the total number of rows {cur} in table stores." ) - payload["items"] = grouped_items - return payload + return grouped_items def _maybe_apply_dataset_shard(self, dataset, spec: TaskSpecStrictBase): shard_cfg = spec.shard diff --git a/src/worker/executors/mixins/governance.py b/src/worker/executors/mixins/governance.py index 30adfcc9..44617d25 100644 --- a/src/worker/executors/mixins/governance.py +++ b/src/worker/executors/mixins/governance.py @@ -18,6 +18,7 @@ TASK_SPAN_NAME, SpanType, ) +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import TaskSpecStrictBase from shared.utils.time import now_iso @@ -209,7 +210,9 @@ def _record_output( ) @staticmethod - def _spec_upstream_results(spec: TaskSpecStrictBase) -> dict[str, Any]: + def _spec_upstream_results( + spec: TaskSpecStrictBase, + ) -> dict[str, BaseExecutorResult]: """Validated ``spec._upstreamResults`` (server-injected stage context).""" context = spec.upstreamResults or {} if not isinstance(context, dict): @@ -236,27 +239,25 @@ def _extract_source_data_ids(self, spec: TaskSpecStrictBase) -> list[str]: def _dump_to_governance( self, task_id: str, - result: dict[str, Any], + result: BaseExecutorResult, dependencies_by_task: dict[str, list[str]], ) -> None: """Write parent + merged-child results and emit asset/lineage rows.""" parent_deps = dependencies_by_task.get(task_id, []) - children_payload = result.get("children", {}) - collection_jobs: list[dict[str, Any]] = [ { "task_id": task_id, - "result": result, + "result": result.model_dump(), "deps": parent_deps, "is_parent": True, } ] - for child_id, child_result in children_payload.items(): + for child_id, child_result in result.children.items(): child_deps = dependencies_by_task.get(child_id, []) collection_jobs.append( { "task_id": child_id, - "result": child_result, + "result": child_result.model_dump(), "deps": child_deps, "is_parent": False, } diff --git a/src/worker/executors/mixins/inference.py b/src/worker/executors/mixins/inference.py index fb7ae7a2..4076e901 100644 --- a/src/worker/executors/mixins/inference.py +++ b/src/worker/executors/mixins/inference.py @@ -15,7 +15,6 @@ from shared.utils.json import to_json_serializable from ..base_executor import ExecutionError -from ..utils.checkpoints import artifact_ref from .data import DataMixin, InferenceEntry, PromptInput logger = logging.getLogger(__name__) @@ -229,7 +228,7 @@ def _maybe_export_jsonl( self, spec: InferenceSpecStrict, task_id: str, - result: dict[str, Any], + items: list[dict[str, Any]], out_dir: Path, ) -> None: post_cfg = (postprocess := spec.postprocess) and postprocess.jsonl_export @@ -261,7 +260,6 @@ def _maybe_export_jsonl( ) from exc target_path.parent.mkdir(parents=True, exist_ok=True) - items = result.get("items") or [] required_fields = post_cfg.required_fields or [] records: list[dict[str, Any]] = [] @@ -292,11 +290,6 @@ def _maybe_export_jsonl( fh.write("\n") rel_path = target_path.relative_to(artifacts_dir).as_posix() - result["jsonl_export"] = { - **artifact_ref(rel_path), - "record_count": len(records), - "fields": list(fields_cfg.keys()), - } logger.info( "Task %s exported %d records to artifacts/%s", task_id, diff --git a/src/worker/executors/mp_executor.py b/src/worker/executors/mp_executor.py index f2379fe2..f2b7c199 100644 --- a/src/worker/executors/mp_executor.py +++ b/src/worker/executors/mp_executor.py @@ -21,6 +21,7 @@ import psutil +from shared.schemas.result import BaseExecutorResult from shared.tasks.worker_message import WorkerHardware from worker.config import WorkerConfig @@ -439,7 +440,7 @@ def _loop() -> None: self._log_thread = t t.start() - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> BaseExecutorResult: with self._lock: if self._shutdown: logger.info("Starting worker subprocess for %s", self.name) diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 87e36f93..8fd9ec21 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -19,6 +19,7 @@ import yaml +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import TaskSpecStrictBase from shared.utils.parsing import to_bool, to_int @@ -40,6 +41,13 @@ logger = logging.getLogger(__name__) +class OmniResult(BaseExecutorResult): + executor: str + mode: str + model: str | None + items: list[dict[str, Any]] + + class OmniExecutorBase(InferenceMixin, Executor): """Shared base for Omni-family executors. @@ -58,7 +66,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._omni_spec: tuple[Any, ...] | None = None self._stage_configs_tmp: Path | None = None - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> OmniResult: spec = self.require_spec(task, self._TASK_SPEC_TYPE) spec_dict = spec.model_dump(by_alias=True) out_dir = Path(out_dir).resolve() @@ -77,7 +85,7 @@ def _run_inner( spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, - ) -> dict[str, Any]: + ) -> OmniResult: """Run the executor-specific body. ``spec`` is the concrete strict spec; subclasses ``assert isinstance(spec, ...)`` to narrow.""" raise NotImplementedError diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index 952dfb28..ef4ff3c1 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -50,22 +50,33 @@ current_omni_platform = None _HAS_OMNI_PLATFORM = False +from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2AudioSpecStrict from shared.utils.parsing import to_float, to_int from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase, extract_multimodal_output -from .utils.checkpoints import artifact_ref +from .omni_executor_base import OmniExecutorBase, OmniResult, extract_multimodal_output logger = logging.getLogger(__name__) +EXECUTOR_NAME = "omni_text2audio" + + +class OmniText2AudioResult(OmniResult): + executor: str = EXECUTOR_NAME + mode: str = "bgm" + audio: ArtifactRef | None + sample_rate: int + num_waveforms: int + audio_length: float + storyboard: dict[str, Any] | None = None class OmniText2AudioExecutor(OmniExecutorBase): """Generate background music with Omni diffusion sampling.""" - name = "omni_text2audio" + name = EXECUTOR_NAME _TASK_SPEC_TYPE = OmniText2AudioSpecStrict def prepare(self) -> None: @@ -84,7 +95,7 @@ def _run_inner( spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, - ) -> dict[str, Any]: + ) -> OmniText2AudioResult: assert isinstance(spec, OmniText2AudioSpecStrict) prompts = self._collect_text_inputs(spec, task.task_id) @@ -192,8 +203,8 @@ def _run_inner( "prompt_index": prompt_idx, "waveform_index": local_idx, "prompt": prompt, - "audio": artifact_ref( - self.relative_to(save_path, artifacts_dir) + "audio": ArtifactRef( + path=self.relative_to(save_path, artifacts_dir) ), } ) @@ -202,22 +213,15 @@ def _run_inner( if not items: raise ExecutionError("omni_text2audio produced no savable waveforms.") - first = items[0]["audio"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "bgm", - "model": self._model_name, - "audio": first, - "items": items, - "sample_rate": sample_rate, - "num_waveforms": len(items), - "audio_length": audio_length, - } - storyboard = spec_dict.get("storyboard") - if isinstance(storyboard, dict): - result["storyboard"] = dict(storyboard) - return result + return OmniText2AudioResult( + model=self._model_name, + audio=items[0]["audio"] if items else None, + items=items, + sample_rate=sample_rate, + num_waveforms=len(items), + audio_length=audio_length, + storyboard=spec_dict.get("storyboard"), + ) # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index 5664af99..9c48763c 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -26,16 +26,22 @@ Omni = None _HAS_OMNI = False +from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2GeneralSpecStrict from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase, extract_audio_from_mm, save_audio -from .utils.checkpoints import artifact_ref +from .omni_executor_base import ( + OmniExecutorBase, + OmniResult, + extract_audio_from_mm, + save_audio, +) logger = logging.getLogger(__name__) +EXECUTOR_NAME = "omni_text2general" _DEFAULT_SYSTEM_PROMPT = ( "You are Qwen, a virtual human developed by the Qwen Team, " @@ -44,10 +50,18 @@ ) +class OmniText2GeneralResult(OmniResult): + executor: str = EXECUTOR_NAME + mode: str = "narration" + audio: ArtifactRef | None + sample_rate: int + storyboard: dict[str, Any] | None = None + + class OmniText2GeneralExecutor(OmniExecutorBase): """Generate narration/speech audio using Qwen3-Omni through vllm_omni.Omni.""" - name = "omni_text2general" + name = EXECUTOR_NAME _TASK_SPEC_TYPE = OmniText2GeneralSpecStrict def prepare(self) -> None: @@ -67,7 +81,7 @@ def _run_inner( spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, - ) -> dict[str, Any]: + ) -> OmniText2GeneralResult: assert isinstance(spec, OmniText2GeneralSpecStrict) texts = self._collect_text_inputs(spec, task.task_id) @@ -170,27 +184,22 @@ def _run_inner( "index": idx, "request_id": rid, "prompt": texts[idx] if idx < len(texts) else None, - "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), + "audio": ArtifactRef( + path=self.relative_to(save_path, artifacts_dir) + ), } text_out = text_results.get(rid) if text_out: item["text"] = text_out items.append(item) - first = items[0]["audio"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "narration", - "model": self._model_name, - "audio": first, - "items": items, - "sample_rate": sample_rate, - } - storyboard = spec_dict.get("storyboard") - if isinstance(storyboard, dict): - result["storyboard"] = dict(storyboard) - return result + return OmniText2GeneralResult( + model=self._model_name, + items=items, + audio=items[0]["audio"] if items else None, + sample_rate=sample_rate, + storyboard=spec_dict.get("storyboard"), + ) # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index e3b1f234..f11db25b 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -15,22 +15,29 @@ Omni = None _HAS_OMNI = False +from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2ImageSpecStrict from shared.utils.parsing import as_list from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase -from .utils.checkpoints import artifact_ref +from .omni_executor_base import OmniExecutorBase, OmniResult logger = logging.getLogger(__name__) +EXECUTOR_NAME = "omni_text2image" + + +class OmniText2ImageResult(OmniResult): + executor: str = EXECUTOR_NAME + mode: str = "image" + image: ArtifactRef | None class OmniText2ImageExecutor(OmniExecutorBase): """Generate images using vllm_omni.Omni.""" - name = "omni_text2image" + name = EXECUTOR_NAME _TASK_SPEC_TYPE = OmniText2ImageSpecStrict def prepare(self) -> None: @@ -45,7 +52,7 @@ def _run_inner( spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, - ) -> dict[str, Any]: + ) -> OmniText2ImageResult: assert isinstance(spec, OmniText2ImageSpecStrict) prompts = self._collect_text_inputs(spec, task.task_id) @@ -86,22 +93,17 @@ def _run_inner( { "index": idx, "prompt": prompt, - "image": artifact_ref( - self.relative_to(save_path, artifacts_dir) + "image": ArtifactRef( + path=self.relative_to(save_path, artifacts_dir) ), } ) - first = items[0]["image"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "image", - "model": self._model_name, - "image": first, - "items": items, - } - return result + return OmniText2ImageResult( + model=self._model_name, + image=items[0]["image"] if items else None, + items=items, + ) # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 5d7be1b7..b37aa74a 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -16,6 +16,7 @@ Omni = None _HAS_OMNI = False +from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2SpeechSpecStrict @@ -24,19 +25,28 @@ from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import ( OmniExecutorBase, + OmniResult, extract_audio_from_mm, extract_multimodal_output, save_audio, ) -from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) +EXECUTOR_NAME = "omni_text2speech" + + +class OmniText2SpeechResult(OmniResult): + executor: str = EXECUTOR_NAME + mode: str = "tts" + audio: ArtifactRef | None + sample_rate: int + storyboard: dict[str, Any] | None = None class OmniText2SpeechExecutor(OmniExecutorBase): """Generate speech audio using vllm_omni.Omni.""" - name = "omni_text2speech" + name = EXECUTOR_NAME _TASK_SPEC_TYPE = OmniText2SpeechSpecStrict def prepare(self) -> None: @@ -51,7 +61,7 @@ def _run_inner( spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, - ) -> dict[str, Any]: + ) -> OmniText2SpeechResult: assert isinstance(spec, OmniText2SpeechSpecStrict) texts = self._collect_text_inputs(spec, task.task_id) @@ -95,26 +105,19 @@ def _run_inner( { "index": idx, "text": text, - "audio": artifact_ref( - self.relative_to(save_path, artifacts_dir) + "audio": ArtifactRef( + path=self.relative_to(save_path, artifacts_dir) ), } ) - first = items[0]["audio"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "tts", - "model": self._model_name, - "audio": first, - "items": items, - "sample_rate": sample_rate, - } - storyboard = spec_dict.get("storyboard") - if isinstance(storyboard, dict): - result["storyboard"] = dict(storyboard) - return result + return OmniText2SpeechResult( + model=self._model_name, + items=items, + audio=items[0]["audio"] if items else None, + sample_rate=sample_rate, + storyboard=spec_dict.get("storyboard"), + ) # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/ppo_executor.py b/src/worker/executors/ppo_executor.py index 1cc47183..4c859719 100644 --- a/src/worker/executors/ppo_executor.py +++ b/src/worker/executors/ppo_executor.py @@ -14,10 +14,7 @@ from contextlib import contextmanager, nullcontext from pathlib import Path from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, cast - -if TYPE_CHECKING: - from deepspeed.runtime.engine import DeepSpeedEngine +from typing import Any, cast import torch from datasets import Dataset @@ -33,6 +30,8 @@ from trl.trainer.ppo_config import PPOConfig from trl.trainer.ppo_trainer import PPOTrainer +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import PPOSpecStrict from shared.utils.manifest import scratch_dir from shared.utils.parsing import safe_float, safe_int, to_bool @@ -42,7 +41,6 @@ from .mixins.training import TrainingMixin from .utils.checkpoints import ( archive_model_dir, - artifact_ref, get_http_destination, maybe_upload_artifacts, write_executor_result, @@ -54,6 +52,18 @@ logger = logging.getLogger("worker.ppo") +class PPOResult(BaseExecutorResult): + training_time_seconds: float | None = None + error_message: str | None = None + model_name: str | None = None + dataset_size: int = 0 + output_dir: str | None = None + checkpoints_dir: ArtifactRef | None = None + final_model: ArtifactRef | None = None + final_model_archive: ArtifactRef | None = None + spawned_torchrun: bool = False + + class _ExternalRewardModel(torch.nn.Module): """Wraps a sequence classification model to score decoded PPO responses.""" @@ -392,7 +402,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._reward_module: _ExternalRewardModel | _RewardAdapter | None = None self._task_out_dir: Path | None = None - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> PPOResult: configure_hf_library_logging() logger.info("Starting PPO training task") spec = self.require_spec(task, PPOSpecStrict) @@ -418,16 +428,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ) ipc_path = scratch_dir(out_dir) / "distributed_result.json" if ipc_path.exists(): - result = self.load_json(ipc_path) self._task_out_dir = None - return result + return PPOResult.model_validate(self.load_json(ipc_path)) self._task_out_dir = None - return { - "training_successful": True, - "spawned_torchrun": True, - "model_name": spec.model_name, - "output_dir": out_dir.as_posix(), - } + return PPOResult( + spawned_torchrun=True, + model_name=spec.model_name, + output_dir=out_dir.as_posix(), + ) start_time = time.time() @@ -853,7 +861,7 @@ def build_trainer() -> _EarlyStopPPOTrainer: pass logger.info("PPO training completed") - training_successful = True + ok = True error_msg = None # Save model if requested @@ -881,44 +889,48 @@ def build_trainer() -> _EarlyStopPPOTrainer: logger.warning("Failed to save model: %s", exc) except Exception as exc: - training_successful = False + ok = False error_msg = str(exc) logger.exception("PPO training failed: %s", exc) training_time = time.time() - start_time dataset_size = len(dataset) if "dataset" in locals() else 0 # type: ignore - results = { - "training_successful": training_successful, - "training_time_seconds": training_time, - "error_message": error_msg, - "model_name": self._model_name, - "dataset_size": dataset_size, - "output_dir": out_dir.as_posix(), - } + result = PPOResult( + ok=ok, + training_time_seconds=training_time, + error_message=error_msg, + model_name=self._model_name, + dataset_size=dataset_size, + output_dir=out_dir.as_posix(), + ) write_executor_result( - out_dir / "results.json", task.task_id, task.spec, results + out_dir / "results.json", task.task_id, task.spec, result ) self._task_out_dir = None raise ExecutionError(error_msg or "PPO training failed") from exc training_time = time.time() - start_time - results = { - "training_successful": training_successful, - "training_time_seconds": training_time, - "error_message": error_msg, - "model_name": self._model_name, - "dataset_size": len(dataset), - "output_dir": out_dir.as_posix(), - "checkpoints_dir": artifact_ref("checkpoints"), - } - if final_model_path is not None: - results["final_model"] = artifact_ref( - final_model_path.relative_to(artifacts_dir).as_posix() - ) - if final_archive_path is not None: - results["final_model_archive"] = artifact_ref( - final_archive_path.relative_to(artifacts_dir).as_posix() - ) + result = PPOResult( + ok=ok, + training_time_seconds=training_time, + error_message=error_msg, + model_name=self._model_name, + dataset_size=len(dataset), + output_dir=out_dir.as_posix(), + checkpoints_dir=ArtifactRef(path="checkpoints"), + final_model=( + ArtifactRef(path=final_model_path.relative_to(artifacts_dir).as_posix()) + if final_model_path + else None + ), + final_model_archive=( + ArtifactRef( + path=final_archive_path.relative_to(artifacts_dir).as_posix() + ) + if final_archive_path + else None + ), + ) maybe_upload_artifacts(task, out_dir, logger=logger) @@ -931,7 +943,7 @@ def build_trainer() -> _EarlyStopPPOTrainer: logger.info("PPO training task completed in %.2f seconds", training_time) self._task_out_dir = None - return results + return result def _ensure_jsonl_local(self, jsonl_cfg: dict[str, Any]) -> Path: headers_cfg = ( @@ -1446,7 +1458,7 @@ def _wrapped_save_model( output_dir: str | None = None, _internal_call: bool = False ) -> None: backup_model = ppo_trainer.model - backup_deepspeed: DeepSpeedEngine | None = None + backup_deepspeed: Any = None ppo_trainer.model = self._resolve_model_for_save(backup_model) if ppo_trainer.is_deepspeed_enabled: backup_deepspeed = ppo_trainer.deepspeed diff --git a/src/worker/executors/rag_executor.py b/src/worker/executors/rag_executor.py index 839041bf..06b3c705 100644 --- a/src/worker/executors/rag_executor.py +++ b/src/worker/executors/rag_executor.py @@ -16,18 +16,29 @@ from datasets import load_dataset from qdrant_client import QdrantClient, models +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import RagSpecStrict from .base_executor import ExecutionError, Executor, ExecutorTask from .utils.graph_templates import Message, build_prompts_from_graph_template logger = logging.getLogger("worker.rag") +EXECUTOR_NAME = "rag" + + +class RAGResult(BaseExecutorResult): + executor: str = EXECUTOR_NAME + qdrant: dict[str, Any] + embedding: dict[str, Any] + search: dict[str, Any] + queries: list[dict[str, Any]] = [] + usage: dict[str, Any] | None = None class RAGExecutor(Executor): - name = "rag" + name = EXECUTOR_NAME - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> RAGResult: start_ts = time.time() spec = self.require_spec(task, RagSpecStrict) @@ -180,22 +191,17 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: } ) - # Compose response - out: dict[str, Any] = { - "ok": True, - "executor": self.name, - "qdrant": {"collection": collection, "url": url}, - "embedding": {"model": model_name}, - "search": {"top_k": top_k}, - "queries": results_per_query, - "usage": { + logger.info( + "RAG query completed queries=%d total_results=%d", len(queries), total_items + ) + return RAGResult( + qdrant={"collection": collection, "url": url}, + embedding={"model": model_name}, + search={"top_k": top_k}, + queries=results_per_query, + usage={ "latency_sec": round(time.time() - start_ts, 4), "num_queries": len(queries), "total_results": total_items, }, - } - - logger.info( - "RAG query completed queries=%d total_results=%d", len(queries), total_items ) - return out diff --git a/src/worker/executors/sft_executor.py b/src/worker/executors/sft_executor.py index 4361c94d..030dcd81 100644 --- a/src/worker/executors/sft_executor.py +++ b/src/worker/executors/sft_executor.py @@ -22,6 +22,8 @@ from trl.trainer.sft_config import SFTConfig from trl.trainer.sft_trainer import SFTTrainer +from shared.schemas.artifact import ArtifactRef +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import SFTSpecStrict, TaskSpecStrictBase from shared.utils.manifest import scratch_dir @@ -30,7 +32,6 @@ from .mixins.training import TrainingMixin from .utils.checkpoints import ( archive_model_dir, - artifact_ref, determine_resume_path, get_http_destination, maybe_upload_artifacts, @@ -43,6 +44,19 @@ logger = logging.getLogger("worker.sft") +class SFTResult(BaseExecutorResult): + training_time_seconds: float | None = None + error_message: str | None = None + model_name: str | None = None + dataset_size: int = 0 + output_dir: str | None = None + checkpoints_dir: ArtifactRef | None = None + resume_from_path: str | None = None + final_model: ArtifactRef | None = None + final_model_archive: ArtifactRef | None = None + spawned_torchrun: bool = False + + class SFTExecutor(TrainingMixin, Executor): name = "sft_executor" @@ -55,12 +69,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._final_model_dir: Path | None = None self._task_out_dir: Path | None = None - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> SFTResult: configure_hf_library_logging() spec = self.require_spec(task, SFTSpecStrict) requested_gpu_count = self._requested_gpu_count(spec) start_time = time.time() - training_successful = False + ok = False error_msg: str | None = None caught_exc: Exception | None = None self._task_out_dir = out_dir @@ -189,22 +203,22 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ) ipc_path = scratch_dir(out_dir) / "distributed_result.json" if ipc_path.exists(): - distributed_result = self.load_json(ipc_path) self._task_out_dir = None - return distributed_result + return SFTResult.model_validate(self.load_json(ipc_path)) self._task_out_dir = None - return { - "training_successful": True, - "spawned_torchrun": True, - "model_name": spec.model_name, - "output_dir": out_dir.as_posix(), - } + return SFTResult( + spawned_torchrun=True, + model_name=spec.model_name, + output_dir=out_dir.as_posix(), + ) except Exception as spawn_exc: logger.exception("Failed to launch distributed SFT: %s", spawn_exc) raise ExecutionError( "Failed to launch distributed SFT subprocess" ) from spawn_exc + resume_path: Path | None = None + train_dataset: Any = None try: # Proceed with in-process training (single GPU or inside torchrun) model_name = spec.model_name or "gpt2" @@ -213,7 +227,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: resume_path = determine_resume_path( spec, training_cfg, out_dir, logger=logger ) - resume_str = str(resume_path) if resume_path else None + resume_str = resume_path.as_posix() if resume_path else None if resume_path: logger.info( @@ -291,7 +305,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ) sft_config = SFTConfig( - output_dir=str(checkpoint_dir), + output_dir=checkpoint_dir.as_posix(), num_train_epochs=float(training_cfg.get("num_train_epochs", 1.0)), per_device_train_batch_size=int(training_cfg.get("batch_size", 2)), gradient_accumulation_steps=int( @@ -394,14 +408,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: logger.info("Starting supervised fine-tuning") trainer.train() - training_successful = True + ok = True logger.info("Training finished") final_model_path: Path | None = None final_archive_path: Path | None = None if bool(training_cfg.get("save_model", True)): model_path = artifacts_dir / "final_model" - trainer.save_model(str(model_path)) + trainer.save_model(model_path.as_posix()) tokenizer.save_pretrained(model_path) final_model_path = model_path logger.info("Saved fine-tuned model to %s", model_path) @@ -422,30 +436,35 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ) training_time = time.time() - start_time - result_payload: dict[str, Any] = { - "task_id": task.task_id, - "training_successful": training_successful, - "training_time_seconds": training_time, - "error_message": error_msg, - "model_name": self._model_name, - "dataset_size": len(train_dataset), - "output_dir": out_dir.as_posix(), - "checkpoints_dir": artifact_ref("checkpoints"), - "resume_from_path": resume_str, - } - - if final_model_path is not None: + final_model: ArtifactRef | None = None + final_model_archive: ArtifactRef | None = None + if final_model_path: resolved_model_path = Path(final_model_path) self._final_model_dir = ( resolved_model_path if resolved_model_path.exists() else None ) - result_payload["final_model"] = artifact_ref( - final_model_path.relative_to(artifacts_dir).as_posix() + final_model = ArtifactRef( + path=final_model_path.relative_to(artifacts_dir).as_posix() ) - if final_archive_path is not None: - result_payload["final_model_archive"] = artifact_ref( - final_archive_path.relative_to(artifacts_dir).as_posix() + final_model_archive = ( + ArtifactRef( + path=final_archive_path.relative_to(artifacts_dir).as_posix() ) + if final_archive_path + else None + ) + result = SFTResult( + ok=ok, + training_time_seconds=training_time, + error_message=error_msg, + model_name=self._model_name, + dataset_size=len(train_dataset), + output_dir=out_dir.as_posix(), + checkpoints_dir=ArtifactRef(path="checkpoints"), + resume_from_path=resume_str, + final_model=final_model, + final_model_archive=final_model_archive, + ) maybe_upload_artifacts(task, out_dir, logger=logger) @@ -467,11 +486,11 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: self._final_model_dir = None self._task_out_dir = None - return result_payload + return result except Exception as exc: error_msg = str(exc) - training_successful = False + ok = False caught_exc = exc logger.exception("SFT training failed: %s", exc) trainer = None @@ -484,23 +503,20 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: self._final_model_dir = None training_time = time.time() - start_time - result: dict[str, Any] = { - "task_id": task.task_id, - "training_successful": training_successful, - "training_time_seconds": training_time, - "error_message": error_msg, - "model_name": self._model_name, - "dataset_size": ( + result = SFTResult( + ok=ok, + training_time_seconds=training_time, + error_message=error_msg, + model_name=self._model_name, + dataset_size=( len(train_dataset) if "train_dataset" in locals() and train_dataset is not None else 0 ), - "output_dir": out_dir.as_posix(), - "checkpoints_dir": artifact_ref("checkpoints"), - "resume_from_path": ( - str(resume_path) if "resume_path" in locals() and resume_path else None - ), - } + output_dir=out_dir.as_posix(), + checkpoints_dir=ArtifactRef(path="checkpoints"), + resume_from_path=resume_path.as_posix() if resume_path else None, + ) write_executor_result(out_dir / "results.json", task.task_id, task.spec, result) if caught_exc is not None: self._task_out_dir = None @@ -602,7 +618,7 @@ def _ensure_jsonl_local(self, jsonl_cfg: dict[str, Any]) -> Path: timeout=timeout, logger=logger, ) - jsonl_cfg["path"] = str(resolved) + jsonl_cfg["path"] = resolved.as_posix() return resolved except ExecutionError as exc: last_error = exc @@ -823,13 +839,13 @@ def _resolve_deepspeed_config(training_cfg: dict[str, Any], log) -> Any | None: if isinstance(cfg, dict): return cfg if isinstance(cfg, (str, Path)): - candidate = Path(str(cfg)).expanduser() + candidate = Path(cfg).expanduser() if candidate.exists(): log.info("Using DeepSpeed config at %s", candidate) - return str(candidate) + return candidate.as_posix() # Allow literal identifiers (e.g., 'auto') without file presence. log.info("Using DeepSpeed literal configuration '%s'", cfg) - return str(cfg) + return cfg if isinstance(cfg, str) else cfg.as_posix() raise ValueError("training.deepspeed must be a dict, path string, or falsy") @staticmethod diff --git a/src/worker/executors/ssh_executor.py b/src/worker/executors/ssh_executor.py index ef6d7958..107e93df 100644 --- a/src/worker/executors/ssh_executor.py +++ b/src/worker/executors/ssh_executor.py @@ -29,6 +29,7 @@ import requests +from shared.schemas.result import BaseExecutorResult from shared.tasks.components.resources import GPURequirements from shared.tasks.specs.ssh import ( SSHInputSpec, @@ -55,6 +56,14 @@ TaskCancelledError, ) + +class SSHResult(BaseExecutorResult): + session_id: str + exit_code: int + command: list[str] | None = None + entrypoint: list[str] | None = None + + try: import docker from docker import DockerClient @@ -425,7 +434,7 @@ def teardown(self) -> None: # Main execution # ------------------------------------------------------------------ # - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + def run(self, task: ExecutorTask, out_dir: Path) -> SSHResult: spec = self.require_spec(task, SSHSpecStrict) cfg = SSHConfig.from_spec(spec, self._config, self._hardware) access_mode = cfg.access_mode @@ -539,16 +548,17 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: cfg.output, mount_plan, ) - result: dict[str, Any] = {"session_id": session_id, "exit_code": exit_code} + result = SSHResult(session_id=session_id, exit_code=exit_code) if interactive: - result.update(session_info) + for key, value in session_info.items(): + setattr(result, key, value) else: # Keep as fallback — captures any output the streaming thread missed. self._save_container_logs(container, out_dir) if cfg.command is not None: - result["command"] = cfg.command + result.command = cfg.command if cfg.entrypoint is not None: - result["entrypoint"] = cfg.entrypoint + result.entrypoint = cfg.entrypoint if mount_plan.copy_output_path: self._copy_output_directory( container, diff --git a/src/worker/executors/transformers_executor.py b/src/worker/executors/transformers_executor.py index c4a3b586..beaed479 100644 --- a/src/worker/executors/transformers_executor.py +++ b/src/worker/executors/transformers_executor.py @@ -56,7 +56,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import ( EmbeddingSpecStrict, InferenceSpecStrict, @@ -66,11 +68,7 @@ from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.data import InferenceEntry from .mixins.inference import InferenceMixin -from .utils.checkpoints import ( - artifact_ref, - maybe_upload_artifacts, - maybe_upload_traces, -) +from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces try: import torch @@ -115,6 +113,16 @@ logger = logging.getLogger(__name__) +class TransformersResult(BaseExecutorResult): + ok: bool = True + model: str | None = None + items: list[dict[str, Any]] = [] + usage: dict[str, Any] | None = None + count: int | None = None + embedding_file: ArtifactRef | None = None + image_group_sizes: list[int] | None = None + + class HFTransformersExecutor(InferenceMixin, Executor): """Executor that runs text generation via Hugging Face Transformers.""" @@ -384,7 +392,7 @@ def _detect_finish_reason( return "length" return None - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ignore[override] + def run(self, task: ExecutorTask, out_dir: Path) -> TransformersResult: configure_hf_library_logging() spec = task.spec if not isinstance(spec, (InferenceSpecStrict, EmbeddingSpecStrict)): @@ -406,7 +414,7 @@ def _run_inner( spec: "InferenceSpecStrict | EmbeddingSpecStrict", task_id: str, out_dir: Path, - ) -> dict[str, Any]: + ) -> TransformersResult: with self._span("model load", span_type=SpanType.COMPUTE): self._ensure_model(spec) @@ -423,8 +431,6 @@ def _run_inner( assert self._model is not None - result: dict[str, Any] = {} - if self._mode == "visual-embedding": assert self._image_processor is not None device = self._device or ("cuda" if torch.cuda.is_available() else "cpu") @@ -479,15 +485,13 @@ def _run_inner( emb_path = artifacts_dir / "visual_embeddings.pt" torch.save(grouped_visual_embeddings, emb_path) - result = { - "ok": True, - "model": self._model_name, - "items": [], # Embeddings are in file - "count": len(grouped_visual_embeddings), - "embedding_file": artifact_ref("visual_embeddings.pt"), - } - if image_group_sizes is not None: - result["image_group_sizes"] = image_group_sizes + result = TransformersResult( + model=self._model_name, + items=[], + count=len(grouped_visual_embeddings), + embedding_file=ArtifactRef(path="visual_embeddings.pt"), + image_group_sizes=image_group_sizes, + ) self._dump_to_governance( task_id=task_id, @@ -593,21 +597,20 @@ def _run_inner( prompt_tokens += int(input_len) completion_tokens += int(gen_part.shape[0]) - result = { - "ok": True, - "model": self._model_name, - "items": items, - "usage": { + result = TransformersResult( + model=self._model_name, + items=items, + usage={ "prompt_tokens": int(prompt_tokens), "completion_tokens": int(completion_tokens), "total_tokens": int(prompt_tokens + completion_tokens), "num_requests": len(self._prompts), "latency_sec": latency, }, - } + ) if isinstance(spec, InferenceSpecStrict): - self._maybe_export_jsonl(spec, task_id, result, out_dir) + self._maybe_export_jsonl(spec, task_id, items, out_dir) self._dump_to_governance( task_id=task_id, diff --git a/src/worker/executors/utils/artifacts.py b/src/worker/executors/utils/artifacts.py index 7f4df0be..8605433b 100644 --- a/src/worker/executors/utils/artifacts.py +++ b/src/worker/executors/utils/artifacts.py @@ -6,48 +6,39 @@ import requests +from shared.schemas.result import BaseExecutorResult from shared.utils.http import auth_headers from ..base_executor import ExecutionError def artifact_to_source( - ref: dict[str, Any], context: dict[str, Any] | None, node: str | None + ref: dict[str, Any], context: dict[str, BaseExecutorResult] | None, node: str | None ) -> str: """Translate a `{path: ...}` artifact ref into a URL or local path.""" rel_path = ref.get("path") if not isinstance(rel_path, str) or not rel_path: raise ExecutionError("Artifact ref must include a non-empty 'path' field") - ctx: dict[str, Any] = {} - if context and isinstance(node, str) and node: - node_payload = context.get(node) - if isinstance(node_payload, dict): - raw_ctx = node_payload.get("_artifacts") - if isinstance(raw_ctx, dict): - ctx = raw_ctx - if not ctx: - # Some executors stuff their result under a "result" key. - result = node_payload.get("result") - if isinstance(result, dict): - inner = result.get("_artifacts") - if isinstance(inner, dict): - ctx = inner - - if ctx: - base_url = ctx.get("base_url") - base_dir = ctx.get("base_dir") + if ( + context + and node + and (node_result := context.get(node)) + and (ctx := node_result.artifacts_) + ): + base_url = ctx.base_url + base_dir = ctx.base_dir else: base_url = base_dir = None # Check local filesystem first - base_dir_path = Path(base_dir) if isinstance(base_dir, str) and base_dir else None + base_dir_path = Path(base_dir) if base_dir else None local_path = base_dir_path / "artifacts" / rel_path if base_dir_path else None if local_path is not None and local_path.is_file(): return local_path.as_posix() # Fallback to a URL if base_url is provided - if isinstance(base_url, str) and base_url: + if base_url: if not base_dir_path: raise ExecutionError( "Artifact ref with base_url requires upstream base_dir to " @@ -66,7 +57,7 @@ def artifact_to_source( def maybe_resolve_artifact_ref( - value: Any, context: dict[str, Any] | None, node: str | None + value: Any, context: dict[str, BaseExecutorResult] | None, node: str | None ) -> Any: """Convert `{path: ...}` ref dicts to URL/path strings; pass others through.""" if isinstance(value, dict) and "path" in value: diff --git a/src/worker/executors/utils/checkpoints.py b/src/worker/executors/utils/checkpoints.py index 60ba6053..ed57f369 100644 --- a/src/worker/executors/utils/checkpoints.py +++ b/src/worker/executors/utils/checkpoints.py @@ -11,8 +11,10 @@ import requests -from shared.schemas.result import write_result_in_envelope +from shared.schemas.artifact import ArtifactContext +from shared.schemas.result import BaseExecutorResult, ResultEnvelope from shared.tasks.specs import TaskSpecStrictBase +from shared.utils.atomic import atomic_write_text from shared.utils.http import add_auth_headers from shared.utils.parsing import parse_bool_env @@ -394,15 +396,9 @@ def is_cleanup_enabled() -> bool: return normalized not in {"0", "false", "no", "off"} -def artifact_ref(rel_path: str) -> dict[str, str]: - """Build an artifact reference dict. `rel_path` is the path relative to - `out_dir/artifacts/`""" - return {"path": rel_path} - - -def build_artifact_context(spec: TaskSpecStrictBase, out_dir: Path) -> dict[str, Any]: - """Top-level `_artifacts` context: {base_dir, base_url}. base_url is the - destination origin (scheme://host[:port]) for HTTP, else None.""" +def build_artifact_context(spec: TaskSpecStrictBase, out_dir: Path) -> ArtifactContext: + """Top-level ``_artifacts`` context. ``base_url`` is the destination + origin (scheme://host[:port]) for HTTP, else ``None``.""" base_dir = Path(out_dir).resolve().as_posix() base_url: str | None = None destination = get_http_destination(spec) @@ -410,18 +406,17 @@ def build_artifact_context(spec: TaskSpecStrictBase, out_dir: Path) -> dict[str, parsed = urlparse(destination.url) if parsed.scheme and parsed.netloc: base_url = f"{parsed.scheme}://{parsed.netloc}" - return {"base_dir": base_dir, "base_url": base_url} + return ArtifactContext(base_dir=base_dir, base_url=base_url) def write_executor_result( - path: Path, - task_id: str, - spec: TaskSpecStrictBase, - result: dict[str, Any], + path: Path, task_id: str, spec: TaskSpecStrictBase, result: BaseExecutorResult ) -> None: """Stamp ``_artifacts`` onto ``result`` and persist the envelope.""" - result["_artifacts"] = build_artifact_context(spec, path.parent) - write_result_in_envelope(path, task_id, result) + path.parent.mkdir(parents=True, exist_ok=True) + result.artifacts_ = build_artifact_context(spec, path.parent) + envelope = ResultEnvelope(task_id=task_id, result=result) + atomic_write_text(path, envelope.model_dump_json(indent=2)) def maybe_upload_artifacts( diff --git a/src/worker/executors/utils/graph_templates.py b/src/worker/executors/utils/graph_templates.py index eb01ea1a..e51aad54 100644 --- a/src/worker/executors/utils/graph_templates.py +++ b/src/worker/executors/utils/graph_templates.py @@ -5,7 +5,9 @@ from typing import Any import pandas as pd +from pydantic import BaseModel +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import TaskSpecStrictBase from shared.utils.json import validate_keys @@ -13,6 +15,8 @@ from ..base_executor import ExecutionError from .safe_eval import safe_execute_function, safe_materialize_function +_SENTINEL: Any = object() + type MessageItem = dict[str, str] type Message = Sequence[MessageItem] type MaterializedMessage = Sequence[MessageItem] @@ -81,7 +85,7 @@ def build_prompts_from_graph_template( def _maybe_broadcast_image_prompts( prompts: Sequence[str | Message], data_cfg: dict[str, Any], - context: dict[str, Any], + context: dict[str, BaseExecutorResult], ) -> list[str | Message]: image_embedding = data_cfg.get("image_embedding") if not isinstance(image_embedding, dict): @@ -95,7 +99,7 @@ def _maybe_broadcast_image_prompts( def _resolve_image_embedding_count( - image_embedding: dict[str, Any], context: dict[str, Any] + image_embedding: dict[str, Any], context: dict[str, BaseExecutorResult] ) -> int | None: node = image_embedding.get("node") if not isinstance(node, str): @@ -105,15 +109,17 @@ def _resolve_image_embedding_count( if not isinstance(node, str) or not node: return None upstream = context.get(node) - if not isinstance(upstream, dict): + if upstream is None: return None - count = upstream.get("count") + count = getattr(upstream, "count", _SENTINEL) if isinstance(count, int) and count > 0: return count return None -def _resolve_columns(columns_cfg: Any, context: dict[str, Any]) -> list[dict[str, Any]]: +def _resolve_columns( + columns_cfg: Any, context: dict[str, BaseExecutorResult] +) -> list[dict[str, Any]]: if not isinstance(columns_cfg, list): raise ExecutionError("graph_template.template.columns must be a list.") @@ -580,17 +586,17 @@ def _format_column_line(label: str, value: str) -> str: return f"• {label}: {indented}" -def _evaluate_expr(expr: str, context: dict[str, Any]) -> Any: +def _evaluate_expr(expr: str, context: dict[str, BaseExecutorResult]) -> Any: if not expr: return None parts = expr.split(".") root = parts[0] - data = context.get(root) - if data is None: + result = context.get(root) + if result is None: return None - value: Any = data + value: Any = result for token in parts[1:]: if not token: continue @@ -617,6 +623,14 @@ def _evaluate_expr(expr: str, context: dict[str, Any]) -> Any: f"{attr} not a valid column in DataFrame for {token}." ) value = value[attr].tolist() + elif isinstance(value, BaseModel): + resolved = getattr(value, attr, _SENTINEL) + if resolved is _SENTINEL: + raise ExecutionError( + f"{attr} not a valid attribute of {type(value).__name__} " + f"for {token}." + ) + value = resolved else: raise ExecutionError( f"{attr} in {parts} is not a valid key - " diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py index 6ad6abd1..bb94b21d 100644 --- a/src/worker/executors/vllm_executor.py +++ b/src/worker/executors/vllm_executor.py @@ -67,6 +67,7 @@ StructuredOutputsParams = None # type: ignore from shared.schemas.governance import SpanType +from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import InferenceSpecStrict from .base_executor import ExecutionError, Executor, ExecutorTask @@ -81,6 +82,13 @@ logger = logging.getLogger(__name__) +class VLLMResult(BaseExecutorResult): + ok: bool = True + model: str | None = None + items: list[dict[str, Any]] = [] + usage: dict[str, Any] | None = None + + class _RawJsonSchema: """Tag wrapping a JSON schema so ``_build_sampling_params`` can ``isinstance``-dispatch raw schemas vs. named-fields pydantic kwargs.""" @@ -849,7 +857,7 @@ def _postprocess_prompts(self, parsed: InferenceEntry) -> PreparedInferenceEntry # --------------------------------------------------------------------- # # Execution # --------------------------------------------------------------------- # - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # type: ignore[override] + def run(self, task: ExecutorTask, out_dir: Path) -> VLLMResult: spec = self.require_spec(task, InferenceSpecStrict) task_id = task.task_id.strip() if not task_id: @@ -868,7 +876,7 @@ def _run_inner( task: ExecutorTask, spec: InferenceSpecStrict, out_dir: Path, - ) -> dict[str, Any]: + ) -> VLLMResult: task_id = task.task_id.strip() merge_children = task.merged_children or [] entries: list[PreparedInferenceEntry] = [] @@ -1149,51 +1157,45 @@ def _collect( "num_requests": len(self._batched_inputs), } - result: dict[str, Any] = { - "ok": True, - "model": self._model_name, - "items": per_task_items.get(task_id, []), - "usage": parent_usage, - } - - child_results: dict[str, Any] = {} + items = per_task_items.get(task_id, []) + child_results: dict[str, VLLMResult] = {} for child in merge_children: child_id = child.task_id.strip() if not child_id: continue - child_payload: dict[str, Any] = { - "items": per_task_items.get(child_id, []), - } + child_items = per_task_items.get(child_id, []) maybe_usage = usage_by_task.get(child_id) - if maybe_usage: - child_payload["usage"] = maybe_usage - child_results[child_id] = child_payload + child_results[child_id] = VLLMResult( + model=self._model_name, items=child_items, usage=maybe_usage + ) if parent_tables := parent_entry.tables: - result = self._populate_table(result, parent_tables) + items = self._populate_table(items, parent_tables) if child_results: for child_id, child_payload in child_results.items(): if (child_entry := entry_by_task_id.get(child_id)) and ( child_tables := child_entry.tables ): - child_results[child_id] = self._populate_table( - child_payload, child_tables + child_results[child_id].items = self._populate_table( + child_payload.items, child_tables ) - if child_results: - result["children"] = child_results + result = VLLMResult( + children=cast(dict[str, BaseExecutorResult], child_results), + model=self._model_name, + items=items, + usage=parent_usage, + ) with self._span( "JSONL export", span_type=SpanType.COMPUTE, attributes={"task_ids": task_ids}, ): - self._maybe_export_jsonl(spec, task_id, result, out_dir) + self._maybe_export_jsonl(spec, task_id, result.items, out_dir) self._dump_to_governance( - task_id=task_id, - result=result, - dependencies_by_task=dependencies_by_task, + task_id=task_id, result=result, dependencies_by_task=dependencies_by_task ) return result diff --git a/src/worker/runner.py b/src/worker/runner.py index eef7e292..7b7086de 100644 --- a/src/worker/runner.py +++ b/src/worker/runner.py @@ -10,6 +10,7 @@ import requests +from shared.schemas.result import BaseExecutorResult from shared.tasks import MergedChildTaskStrict from shared.tasks.specs import InferenceSpecStrict, TaskSpecStrictBase from shared.tasks.worker_message import ( @@ -127,17 +128,14 @@ def _write_results( spec: TaskSpecStrictBase, merged_children: list[MergedChildTaskStrict], out_dir: Path, - result: dict[str, Any] | None, + result: BaseExecutorResult | None, ): if result is None: return self._write_single_result(task_id, spec, out_dir, result) child_lookup = {entry.task_id: entry for entry in merged_children} - children_payload = ( - result.get("children") if isinstance(result, dict) else {} - ) or {} - for child_id, child_result in children_payload.items(): + for child_id, child_result in result.children.items(): child_info = child_lookup.get(child_id) if child_info is None: continue @@ -151,7 +149,7 @@ def _write_single_result( task_id: str, spec: TaskSpecStrictBase, out_dir: Path, - payload: dict[str, Any] | None, + payload: BaseExecutorResult | None, ): if payload is None: return @@ -179,7 +177,7 @@ def _simulate_bandwidth_delay(self, payload_bytes: int, destination: str) -> Non time.sleep(delay) def _maybe_emit_http( - self, task_id: str, spec: TaskSpecStrictBase, result: dict[str, Any] + self, task_id: str, spec: TaskSpecStrictBase, result: BaseExecutorResult ) -> None: """Send task results to an HTTP endpoint when requested by the spec.""" destination = get_http_destination(spec) @@ -190,7 +188,7 @@ def _maybe_emit_http( ignore_error = destination.ignore_error payload = { "task_id": task_id, - "result": result, + "result": result.model_dump(), "worker_id": self.lifecycle.worker_id, } payload_size = len(json.dumps(payload, ensure_ascii=False).encode("utf-8")) diff --git a/tests/sdk/test_schema_compat.py b/tests/sdk/test_schema_compat.py index b757060f..d4092a7e 100644 --- a/tests/sdk/test_schema_compat.py +++ b/tests/sdk/test_schema_compat.py @@ -8,6 +8,9 @@ # SDK-side imports from flowmesh.models import ( + ArtifactContext, + ArtifactRef, + BaseExecutorResult, CPUInfo, GpuInfo, GpuPlatformInfo, @@ -79,6 +82,9 @@ ) from server.task.models import TaskInfo as SrvTaskInfo from server.task.models import TaskUsage as SrvTaskUsage +from shared.schemas.artifact import ArtifactContext as SrvArtifactContext +from shared.schemas.artifact import ArtifactRef as SrvArtifactRef +from shared.schemas.result import BaseExecutorResult as SrvBaseExecutorResult from shared.schemas.result import ResultEnvelope as SrvResultEnvelope from shared.schemas.worker import SSHLimits as SrvSSHLimits from shared.tasks.task_type import TaskType as SrvTaskType @@ -123,7 +129,11 @@ # Common (SrvOkResponse, OkResponse), # Results + (SrvBaseExecutorResult, BaseExecutorResult), (SrvResultEnvelope, ResultEnvelope), + # Artifacts + (SrvArtifactContext, ArtifactContext), + (SrvArtifactRef, ArtifactRef), ] diff --git a/tests/shared/test_executor_result.py b/tests/shared/test_executor_result.py new file mode 100644 index 00000000..4b5d79d3 --- /dev/null +++ b/tests/shared/test_executor_result.py @@ -0,0 +1,71 @@ +"""Round-trip tests for the shared executor-result schema.""" + +import json +from typing import Any + +from pydantic import Field + +from shared.schemas.artifact import ArtifactContext, ArtifactRef +from shared.schemas.result import BaseExecutorResult, ResultEnvelope + + +class _SampleResult(BaseExecutorResult): + ok: bool = True + items: list[dict[str, Any]] = Field(default_factory=list) + usage: dict[str, Any] | None = None + + +def test_subclass_round_trip_through_base_preserves_extra_fields() -> None: + """A subclass's executor-specific fields survive a JSON trip through the + base class via ``extra='allow'``.""" + original = _SampleResult( + ok=True, + items=[{"output": "hello"}], + usage={"latency_sec": 0.5}, + _artifacts=ArtifactContext(base_dir="/tmp/t", base_url="http://h"), + ) + payload = original.model_dump_json() + + base = BaseExecutorResult.model_validate_json(payload) + redumped = json.loads(base.model_dump_json()) + + assert redumped["items"] == [{"output": "hello"}] + assert redumped["usage"] == {"latency_sec": 0.5} + assert redumped["_artifacts"] == {"base_dir": "/tmp/t", "base_url": "http://h"} + + +def test_recursive_children_round_trip() -> None: + """Nested ``children`` deserialize as ``BaseExecutorResult`` and re-emit + their extra fields when serialized.""" + parent = _SampleResult( + items=[{"output": "p"}], + children={ + "c1": _SampleResult(items=[{"output": "c"}], usage={"total_tokens": 3}), + }, + ) + payload = parent.model_dump_json() + base = BaseExecutorResult.model_validate_json(payload) + + assert "c1" in base.children + child = base.children["c1"] + redumped = json.loads(child.model_dump_json()) + assert redumped["items"] == [{"output": "c"}] + assert redumped["usage"] == {"total_tokens": 3} + + +def test_artifact_ref_is_a_typed_path_reference() -> None: + ref = ArtifactRef(path="lora/final") + assert ref.model_dump() == {"path": "lora/final"} + + +def test_envelope_round_trip_preserves_subclass_payload() -> None: + """The production write→read path (``write_result_in_envelope`` → + ``ResultEnvelope.model_validate``) round-trips subclass fields.""" + inner = _SampleResult(items=[{"output": "hello"}], usage={"total_tokens": 7}) + env = ResultEnvelope(task_id="tsk-x", result=inner) + raw = env.model_dump_json() + + parsed = ResultEnvelope.model_validate_json(raw) + dumped = parsed.result.model_dump() + assert dumped["items"] == [{"output": "hello"}] + assert dumped["usage"] == {"total_tokens": 7} diff --git a/tests/worker/test_agent_connector.py b/tests/worker/test_agent_connector.py index b4fd1eb0..f3b07f32 100644 --- a/tests/worker/test_agent_connector.py +++ b/tests/worker/test_agent_connector.py @@ -182,9 +182,9 @@ def test_renders_description_from_params_and_yields_table_item( ) result = executor._run_agent(data_cfg, context, out_dir) - assert result["ok"] is True - assert result["count"] == 1 - item = result["items"][0] + assert result.ok is True + assert result.count == 1 + item = result.items[0] assert "fetch NVDA quotes for 2024-01-01" in item["description"] assert item["rows"] == 1 assert item["run_id"] == "run-fake" diff --git a/tests/worker/test_artifact_utils.py b/tests/worker/test_artifact_utils.py index a30528a5..189733b2 100644 --- a/tests/worker/test_artifact_utils.py +++ b/tests/worker/test_artifact_utils.py @@ -8,6 +8,8 @@ import pytest +from shared.schemas.artifact import ArtifactContext +from shared.schemas.result import BaseExecutorResult from worker.executors.base_executor import ExecutionError from worker.executors.utils.artifacts import ( artifact_to_source, @@ -16,42 +18,26 @@ class TestArtifactToSource: - def test_unwrapped_upstream_yields_url(self, tmp_path: Path) -> None: - upstream = { - "_artifacts": { - "base_url": "http://host:8010", - "base_dir": (tmp_path / "producer-tid").as_posix(), - }, - "result": {"images": [{"path": "a.png"}]}, - } + def test_url_resolution(self, tmp_path: Path) -> None: + upstream = BaseExecutorResult( + _artifacts=ArtifactContext( + base_dir=(tmp_path / "producer-tid").as_posix(), + base_url="http://host:8010", + ), + ) url = artifact_to_source({"path": "a.png"}, {"producer": upstream}, "producer") assert url == "http://host:8010/api/v1/results/producer-tid/files/a.png" - def test_envelope_wrapped_upstream_yields_url(self, tmp_path: Path) -> None: - """Server stores results.json as `{task_id, ..., result: {...}}`; the - helper must unwrap one level to find `_artifacts`.""" - upstream = { - "task_id": "producer-tid", - "result": { - "_artifacts": { - "base_url": "http://host:8010", - "base_dir": (tmp_path / "producer-tid").as_posix(), - }, - }, - } - url = artifact_to_source({"path": "x.png"}, {"producer": upstream}, "producer") - assert url == "http://host:8010/api/v1/results/producer-tid/files/x.png" - def test_local_file_takes_fast_path(self, tmp_path: Path) -> None: task_root = tmp_path / "producer-tid" (task_root / "artifacts").mkdir(parents=True) (task_root / "artifacts" / "a.png").write_bytes(b"\x89PNG") - upstream = { - "_artifacts": { - "base_url": "http://host:8010", - "base_dir": task_root.as_posix(), - } - } + upstream = BaseExecutorResult( + _artifacts=ArtifactContext( + base_dir=task_root.as_posix(), + base_url="http://host:8010", + ) + ) resolved = artifact_to_source( {"path": "a.png"}, {"producer": upstream}, "producer" ) @@ -59,7 +45,9 @@ def test_local_file_takes_fast_path(self, tmp_path: Path) -> None: def test_local_only_upstream_returns_local_path(self, tmp_path: Path) -> None: task_root = tmp_path / "producer-tid" - upstream = {"_artifacts": {"base_url": None, "base_dir": task_root.as_posix()}} + upstream = BaseExecutorResult( + _artifacts=ArtifactContext(base_dir=task_root.as_posix()) + ) resolved = artifact_to_source( {"path": "a.png"}, {"producer": upstream}, "producer" ) @@ -67,11 +55,13 @@ def test_local_only_upstream_returns_local_path(self, tmp_path: Path) -> None: def test_missing_context_raises(self) -> None: with pytest.raises(ExecutionError, match="_artifacts context is missing"): - artifact_to_source({"path": "a.png"}, {"producer": {}}, "producer") + artifact_to_source( + {"path": "a.png"}, {"producer": BaseExecutorResult()}, "producer" + ) def test_missing_path_raises(self) -> None: with pytest.raises(ExecutionError, match="non-empty 'path' field"): - artifact_to_source({}, {"producer": {}}, "producer") + artifact_to_source({}, {"producer": BaseExecutorResult()}, "producer") class TestMaybeResolveArtifactRef: @@ -85,12 +75,12 @@ def test_passes_through_dict_without_path(self) -> None: assert maybe_resolve_artifact_ref(value, None, None) is value def test_resolves_path_dict(self, tmp_path: Path) -> None: - upstream = { - "_artifacts": { - "base_url": "http://host:8010", - "base_dir": (tmp_path / "producer-tid").as_posix(), - } - } + upstream = BaseExecutorResult( + _artifacts=ArtifactContext( + base_url="http://host:8010", + base_dir=(tmp_path / "producer-tid").as_posix(), + ) + ) out = maybe_resolve_artifact_ref( {"path": "a.png"}, {"producer": upstream}, "producer" ) diff --git a/tests/worker/test_checkpoint_utils.py b/tests/worker/test_checkpoint_utils.py index 27816fbe..7c671641 100644 --- a/tests/worker/test_checkpoint_utils.py +++ b/tests/worker/test_checkpoint_utils.py @@ -6,6 +6,7 @@ import pytest +from shared.schemas.artifact import ArtifactContext from worker.executors.base_executor import TaskReference from worker.executors.utils import checkpoints @@ -29,25 +30,22 @@ def _task( return cast(TaskReference, SimpleNamespace(task_id="task-1", spec=spec)) -class TestArtifactRef: - def test_returns_path_only(self) -> None: - assert checkpoints.artifact_ref("images/foo.png") == {"path": "images/foo.png"} - - class TestBuildArtifactContext: def test_http_destination_strips_api_suffix(self, tmp_path: Path) -> None: out_dir = tmp_path / "task-1" out_dir.mkdir() ctx = checkpoints.build_artifact_context(_task().spec, out_dir) - assert ctx["base_dir"] == out_dir.resolve().as_posix() - assert ctx["base_url"] == "http://host:8010" + assert ctx == ArtifactContext( + base_dir=out_dir.resolve().as_posix(), base_url="http://host:8010" + ) def test_local_destination_leaves_base_url_none(self, tmp_path: Path) -> None: out_dir = tmp_path / "task-1" out_dir.mkdir() ctx = checkpoints.build_artifact_context(_task("local").spec, out_dir) - assert ctx["base_dir"] == out_dir.resolve().as_posix() - assert ctx["base_url"] is None + assert ctx == ArtifactContext( + base_dir=out_dir.resolve().as_posix(), base_url=None + ) class TestMaybeUploadArtifacts: diff --git a/tests/worker/test_connector_logging.py b/tests/worker/test_connector_logging.py index beda187e..417e8a98 100644 --- a/tests/worker/test_connector_logging.py +++ b/tests/worker/test_connector_logging.py @@ -1,16 +1,24 @@ """Test that connector logs are properly redirected from worker subprocess to parent.""" +import logging import tempfile import time import uuid from pathlib import Path +from typing import Any +from shared.schemas.result import BaseExecutorResult from shared.tasks.worker_message import WorkerTaskMessage from tests.worker.factories import make_live_worker_config, make_worker_hardware from worker.executors.base_executor import Executor from worker.executors.mp_executor import MPExecutor +class ConnectorLoggingResult(BaseExecutorResult): + ok: bool = True + result: dict[str, Any] + + class ConnectorLoggingExecutor(Executor): """Simple test executor that uses a connector and logs messages.""" @@ -19,9 +27,8 @@ class ConnectorLoggingExecutor(Executor): def prepare(self) -> None: pass - def run(self, task, out_dir: Path) -> dict: + def run(self, task, out_dir: Path) -> ConnectorLoggingResult: """Run a simple test that logs from different modules.""" - import logging # Get loggers from different modules that would be used in real execution executor_logger = logging.getLogger("executors.test_executor") @@ -36,13 +43,12 @@ def run(self, task, out_dir: Path) -> dict: root_logger.info("Message from root logger") - return { - "ok": True, - "result": { + return ConnectorLoggingResult( + result={ "status": "completed", "log_test": "Messages logged from different modules", }, - } + ) def cleanup_after_run(self) -> None: pass @@ -89,8 +95,8 @@ def test_connector_logs_printed_to_stderr(tmp_path: Path) -> None: mp.cleanup_after_run() - # Verify the executor ran successfully - assert result["ok"], f"Executor failed: {result}" - assert ( - result.get("result", {}).get("status") == "completed" - ), f"Unexpected result: {result}" + # Verify the executor ran successfully. The MP boundary pickles the + # subclass instance, so the parent receives a ``ConnectorLoggingResult``. + assert isinstance(result, ConnectorLoggingResult), f"Unexpected result: {result}" + assert result.ok is True + assert result.result.get("status") == "completed", f"Unexpected result: {result}" diff --git a/tests/worker/test_data_mixin_lineage.py b/tests/worker/test_data_mixin_lineage.py index 48892ee9..69e8856f 100644 --- a/tests/worker/test_data_mixin_lineage.py +++ b/tests/worker/test_data_mixin_lineage.py @@ -7,6 +7,7 @@ from PIL import Image +from shared.schemas.result import BaseExecutorResult from worker.executors.mixins.data import DataMixin @@ -96,14 +97,16 @@ def test_dump_to_governance_with_merged_children(tmp_path: Path) -> None: mixin = _Mixin() out_dir = tmp_path / "task" with mixin._task_span("tsk-parent", "wfl-1", out_dir, owner_id="alice"): - result = { - "ok": True, - "items": [{"output": "p"}], - "children": { - "tsk-c1": {"items": [{"output": "c1"}]}, - "tsk-c2": {"items": [{"output": "c2"}]}, - }, - } + result = BaseExecutorResult.model_validate( + { + "ok": True, + "items": [{"output": "p"}], + "children": { + "tsk-c1": {"items": [{"output": "c1"}]}, + "tsk-c2": {"items": [{"output": "c2"}]}, + }, + } + ) deps = { "tsk-parent": ["tsk-up-a"], "tsk-c1": ["tsk-up-b"], @@ -143,20 +146,21 @@ def test_collect_prompts_resolves_grouped_image_artifact_refs_after_flatten( for name, color in (("a.png", "red"), ("b.png", "green"), ("c.png", "blue")): Image.new("RGB", (2, 2), color=color).save(artifacts_dir / name) + result = BaseExecutorResult.model_validate( + { + "images": [ + [{"path": "images/a.png"}, {"path": "images/b.png"}], + [{"path": "images/c.png"}], + ], + "_artifacts": {"base_dir": upstream_dir.as_posix()}, + } + ) spec = cast( Any, SimpleNamespace( data={"type": "list", "expr": "vision.images"}, inference={}, - upstreamResults={ - "vision": { - "images": [ - [{"path": "images/a.png"}, {"path": "images/b.png"}], - [{"path": "images/c.png"}], - ], - "_artifacts": {"base_dir": upstream_dir.as_posix()}, - } - }, + upstreamResults={"vision": result}, ), ) diff --git a/tests/worker/test_executor_bootstrap.py b/tests/worker/test_executor_bootstrap.py index 34f6596c..45277de9 100644 --- a/tests/worker/test_executor_bootstrap.py +++ b/tests/worker/test_executor_bootstrap.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any +from shared.schemas.result import BaseExecutorResult from tests.worker.factories import make_live_worker_config, make_worker_hardware from worker.executors.base_executor import Executor, ExecutorTask from worker.main import initialize_executors @@ -23,8 +24,8 @@ class _PassthroughExecutor(Executor): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # noqa: D401 - return {"ok": True} + def run(self, task: ExecutorTask, out_dir: Path) -> BaseExecutorResult: + return BaseExecutorResult.model_validate({"ok": True}) class TestInitializeExecutorsHardware: diff --git a/tests/worker/test_mp_executor_lifecycle.py b/tests/worker/test_mp_executor_lifecycle.py index 3fe73be8..6cea3af4 100644 --- a/tests/worker/test_mp_executor_lifecycle.py +++ b/tests/worker/test_mp_executor_lifecycle.py @@ -4,6 +4,7 @@ import uuid from pathlib import Path +from shared.schemas.result import BaseExecutorResult from shared.tasks import TaskType from shared.tasks.specs import EchoSpecStrict from shared.tasks.worker_message import WorkerTaskMessage @@ -16,11 +17,16 @@ from worker.executors.mp_executor import MPExecutor +class _SimpleMPResult(BaseExecutorResult): + ok: bool = True + task_id: str + + class _SimpleMPExecutor(Executor): name = "simple_mp" - def run(self, task, out_dir: Path) -> dict: - return {"ok": True, "task_id": task.task_id} + def run(self, task, out_dir: Path) -> _SimpleMPResult: + return _SimpleMPResult(task_id=task.task_id) def cleanup_after_run(self) -> None: return None @@ -54,7 +60,8 @@ def test_mp_executor_does_not_start_subprocess_until_first_run(tmp_path: Path) - with tempfile.TemporaryDirectory() as out_dir: result = mp.run(_simple_task_message(), Path(out_dir)) - assert result["ok"] is True + assert isinstance(result, _SimpleMPResult) + assert result.ok is True assert mp._shutdown is False assert mp._proc is not None assert mp._proc.is_alive() diff --git a/tests/worker/test_transformers_chat_inference.py b/tests/worker/test_transformers_chat_inference.py index 12fa0f4a..5f67fa80 100644 --- a/tests/worker/test_transformers_chat_inference.py +++ b/tests/worker/test_transformers_chat_inference.py @@ -116,18 +116,18 @@ def test_transformers_executor_supports_chat_prompts_and_jsonl_export( result = executor.run(task, tmp_path) mock_ensure_model.assert_called_once_with(spec) - assert [item["output"] for item in result["items"]] == [ + assert [item["output"] for item in result.items] == [ "first answer", "second answer", ] - assert [item["prompt"] for item in result["items"]] == [ + assert [item["prompt"] for item in result.items] == [ "hello", "world", ] - assert [item["metadata"]["row_id"] for item in result["items"]] == ["a", "b"] - assert result["usage"]["num_requests"] == 2 - assert "latency_sec" in result["usage"] - assert result["jsonl_export"]["record_count"] == 2 + assert [item["metadata"]["row_id"] for item in result.items] == ["a", "b"] + assert result.usage is not None + assert result.usage["num_requests"] == 2 + assert "latency_sec" in result.usage exported = (tmp_path / "artifacts" / "rows.jsonl").read_text(encoding="utf-8") assert '"row_id": "a"' in exported diff --git a/uv.lock b/uv.lock index ea21fb32..650e27a2 100644 --- a/uv.lock +++ b/uv.lock @@ -2355,7 +2355,7 @@ dependencies = [ requires-dist = [ { name = "httpx", specifier = ">=0.27.0" }, { name = "pandas", specifier = ">=2.3.3" }, - { name = "pydantic", specifier = ">=2.0.0" }, + { name = "pydantic", specifier = ">=2.12.3" }, { name = "pyyaml", specifier = ">=6.0.0" }, ]