diff --git a/src/worker/executors/mixins/data.py b/src/worker/executors/mixins/data.py index f70231aa..ee45aab7 100644 --- a/src/worker/executors/mixins/data.py +++ b/src/worker/executors/mixins/data.py @@ -403,6 +403,8 @@ def _collect_prompts_for_spec( metadata_raw.append(entry_meta) elif dtype == "list": items = data.get("items") + context: dict[str, Any] | None = None + root_node: str | None = None if items is None: expr = data.get("expr") if not expr: @@ -415,11 +417,6 @@ def _collect_prompts_for_spec( resolved_expr = expr.strip() items = _evaluate_expr(resolved_expr, context) root_node = resolved_expr.split(".", 1)[0] or None - if isinstance(items, list): - items = [ - maybe_resolve_artifact_ref(item, context, root_node) - for item in items - ] if not isinstance(items, list): raise ExecutionError( "spec.data.items must be a list or resolve to a list " @@ -427,6 +424,10 @@ def _collect_prompts_for_spec( ) if fetch_images: items, image_group_sizes = self._flatten_grouped_image_items(items) + items = [ + maybe_resolve_artifact_ref(item, context, root_node) + for item in items + ] s3_entries: list[tuple[int, str]] = [] @@ -503,6 +504,10 @@ def _collect_prompts_for_spec( raise ExecutionError("Missing image data for one or more items.") prompts = [x if isinstance(x, str) else "" for x in items] else: + items = [ + maybe_resolve_artifact_ref(item, context, root_node) + for item in items + ] prompts, apply_chat_template, found_system_prompt = ( normalize_prompt_payload(items) ) diff --git a/tests/worker/test_data_mixin_lineage.py b/tests/worker/test_data_mixin_lineage.py index 9a92dc3b..48892ee9 100644 --- a/tests/worker/test_data_mixin_lineage.py +++ b/tests/worker/test_data_mixin_lineage.py @@ -2,7 +2,10 @@ import json from pathlib import Path -from typing import Any +from types import SimpleNamespace +from typing import Any, cast + +from PIL import Image from worker.executors.mixins.data import DataMixin @@ -128,3 +131,42 @@ def test_dump_to_governance_with_merged_children(tmp_path: Path) -> None: ("tsk-c1", "tsk-up-b"), ("tsk-c2", "tsk-up-c"), } + + +def test_collect_prompts_resolves_grouped_image_artifact_refs_after_flatten( + tmp_path: Path, +) -> None: + mixin = _Mixin() + upstream_dir = tmp_path / "upstream-task" + artifacts_dir = upstream_dir / "artifacts" / "images" + artifacts_dir.mkdir(parents=True) + for name, color in (("a.png", "red"), ("b.png", "green"), ("c.png", "blue")): + Image.new("RGB", (2, 2), color=color).save(artifacts_dir / name) + + spec = cast( + Any, + SimpleNamespace( + data={"type": "list", "expr": "vision.images"}, + inference={}, + upstreamResults={ + "vision": { + "images": [ + [{"path": "images/a.png"}, {"path": "images/b.png"}], + [{"path": "images/c.png"}], + ], + "_artifacts": {"base_dir": upstream_dir.as_posix()}, + } + }, + ), + ) + + entry = mixin._collect_prompts_for_spec(spec, "tsk-vision", fetch_images=True) + + assert entry.image_group_sizes == [2, 1] + assert len(entry.images) == 3 + assert all(image is not None for image in entry.images) + assert [image.size for image in entry.images if image is not None] == [ + (2, 2), + (2, 2), + (2, 2), + ]