diff --git a/datashare-python/datashare_python/cli/task.py b/datashare-python/datashare_python/cli/task.py index 75878cc9..a9d8c6b5 100644 --- a/datashare-python/datashare_python/cli/task.py +++ b/datashare-python/datashare_python/cli/task.py @@ -10,8 +10,7 @@ from alive_progress import alive_bar from datashare_python.cli.utils import AsyncTyper, eprint -from datashare_python.constants import PYTHON_TASK_GROUP -from datashare_python.objects import READY_STATES, Task, TaskError, TaskState +from datashare_python.objects import READY_STATES, Task, TaskError, TaskGroup, TaskState from datashare_python.task_client import DatashareTaskClient logger = logging.getLogger(__name__) @@ -41,7 +40,7 @@ async def start( group: Annotated[ str | None, typer.Option("--group", "-g", help=_GROUP_HELP), - ] = PYTHON_TASK_GROUP.name, + ] = TaskGroup.python, # noqa: F821 ds_address: Annotated[ str, typer.Option("--ds-address", "-a", help=_DS_URL_HELP) ] = DEFAULT_DS_ADDRESS, diff --git a/datashare-python/datashare_python/config.py b/datashare-python/datashare_python/config.py index 2977184b..15cc2224 100644 --- a/datashare-python/datashare_python/config.py +++ b/datashare-python/datashare_python/config.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import ClassVar from icij_common.es import ESClient @@ -87,6 +88,10 @@ class WorkerConfig(ICIJSettings, LogWithWorkerIDMixin, BaseModel): max_concurrent_io_activities: int = 5 + docs_root: Path | None = None + artifacts_root: Path | None = None + workdir: Path | None = None + def to_es_client(self) -> ESClient: return self.elasticsearch.to_es_client(self.datashare.api_key) diff --git a/datashare-python/datashare_python/conftest.py b/datashare-python/datashare_python/conftest.py index 00c30cf5..6aabfbeb 100644 --- a/datashare-python/datashare_python/conftest.py +++ b/datashare-python/datashare_python/conftest.py @@ -262,6 +262,7 @@ def doc_2() -> Document: def doc_3() -> Document: return Document( id="doc-3", + index=TEST_PROJECT, root_document="root-3", language="SPANISH", content="traduce este texto al inglés", diff --git a/datashare-python/datashare_python/constants.py b/datashare-python/datashare_python/constants.py index 0efca13f..5e6ee583 100644 --- a/datashare-python/datashare_python/constants.py +++ b/datashare-python/datashare_python/constants.py @@ -1,11 +1,8 @@ from pathlib import Path -from .objects import TaskGroup - PACKAGE_DIR = Path(__file__).parent PACKAGE_ROOT = PACKAGE_DIR.parent -PYTHON_TASK_GROUP = TaskGroup(name="PYTHON") DEFAULT_TEMPORAL_ADDRESS = "temporal:7233" @@ -14,3 +11,5 @@ DEFAULT_NAMESPACE = "datashare-default" METADATA_JSON = "metadata.json" + +TIKA_METADATA_RESOURCENAME = "tika_metadata_resourcename" diff --git a/datashare-python/datashare_python/objects.py b/datashare-python/datashare_python/objects.py index 11d82ee7..7ee798dc 100644 --- a/datashare-python/datashare_python/objects.py +++ b/datashare-python/datashare_python/objects.py @@ -5,12 +5,24 @@ from enum import StrEnum, unique from io import BytesIO from pathlib import Path -from typing import Any, Literal, Self, TypeVar +from typing import Any, Literal, Self, TypeVar, cast from temporalio import workflow +from .constants import TIKA_METADATA_RESOURCENAME + with workflow.unsafe.imports_passed_through(): - from icij_common.es import DOC_CONTENT, DOC_LANGUAGE, DOC_ROOT_ID, ID_, SOURCE + from icij_common.es import ( + DOC_CONTENT, + DOC_CONTENT_TRANSLATED, + DOC_LANGUAGE, + DOC_METADATA, + DOC_PATH, + DOC_ROOT_ID, + ID_, + INDEX_, + SOURCE, + ) from icij_common.pydantic_utils import ( icij_config, @@ -137,11 +149,48 @@ class Task(Message): class TaskGroup: name: str + @property + @classmethod + def python(cls) -> Self: + return cls(name="PYTHON") + + +@unique +class DocumentLocation(StrEnum): + ORIGINAL = "original" + ARTIFACTS = "artifacts" + WORKDIR = "workdir" + + +class FilesystemDocument(DatashareModel): + id: str + path: Path + index: str + location: DocumentLocation + resource_name: str + + def locate( + self, original_root: Path, *, artifacts_root: Path, workdir: Path + ) -> Path: + from datashare_python.utils import artifacts_dir # noqa: PLC0415 + + project = self.index + match self.location: + case DocumentLocation.ORIGINAL: + return original_root / self.path + case DocumentLocation.ARTIFACTS: + return artifacts_root / artifacts_dir(self.id, project=project) / "raw" + case DocumentLocation.WORKDIR: + return workdir / self.path + case _: + raise ValueError(f"invalid location: {self.path}") + class Document(DatashareModel): id: str - root_document: str language: str + index: str | None = None + root_document: str | None = None content: str | None = None content_type: str | None = None path: Path | None = None @@ -149,6 +198,7 @@ class Document(DatashareModel): content_translated: dict[str, str] = Field( default_factory=dict, alias="content_translated" ) + metadata: dict[str, Any] | None = None type: str = Field(default="Document", frozen=True) @classmethod @@ -156,11 +206,42 @@ def from_es(cls, es_doc: dict) -> Self: sources = es_doc[SOURCE] return cls( id=es_doc[ID_], - content=sources[DOC_CONTENT], - content_translated=sources.get("content_translated", dict()), + index=es_doc.get(INDEX_), + content=sources.get(DOC_CONTENT), + content_translated=sources.get(DOC_CONTENT_TRANSLATED, dict()), language=sources[DOC_LANGUAGE], root_document=sources[DOC_ROOT_ID], tags=sources.get("tags", []), + path=sources.get(DOC_PATH), + metadata=sources.get(DOC_METADATA), + ) + + def to_filesystem(self) -> FilesystemDocument: + from .utils import artifacts_dir # noqa: PLC0415 + + if self.metadata is None: + raise ValueError( + "can't compute filesyste path for document withtout metadata" + ) + resource_name = cast(str, self.metadata[TIKA_METADATA_RESOURCENAME]) + if self.root_document is None: + path = self.path + location = DocumentLocation.ORIGINAL + else: + if self.index is None: + msg = ( + f"can't compute filesystem path for embedded doc {self.id} without" + f" index" + ) + raise ValueError(msg) + path = artifacts_dir(doc_id=self.id, project=self.index) / "raw" + location = DocumentLocation.ARTIFACTS + return FilesystemDocument( + id=self.id, + path=path, + index=self.index, + location=location, + resource_name=resource_name, ) diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index 403100e7..b3466762 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -440,12 +440,12 @@ def safe_dir(doc_id: str) -> Path: return Path(*parts) -def _artifacts_dir(doc_id: str, *, project: str) -> Path: +def artifacts_dir(doc_id: str, *, project: str) -> Path: return Path(project, safe_dir(doc_id), doc_id) def _metadata_path(doc_id: str, *, project: str) -> Path: - metadata_path = _artifacts_dir(doc_id, project=project) / METADATA_JSON + metadata_path = artifacts_dir(doc_id, project=project) / METADATA_JSON return metadata_path @@ -455,7 +455,7 @@ def _read_artifact_metadata(root: Path, artifact: DocArtifact) -> dict: def write_artifact(root: Path, artifact: DocArtifact) -> Path: - artif_dir = root / _artifacts_dir(artifact.doc_id, project=artifact.project) + artif_dir = root / artifacts_dir(artifact.doc_id, project=artifact.project) artif_dir.mkdir(exist_ok=True, parents=True) # TODO: if transcriptions are too large we could also serialize them # as jsonl @@ -479,19 +479,23 @@ def debuggable_name( displayable_file_name = [c[:component_size_limit] for c in path.parts] uuid = sha256(str(path).encode()).hexdigest() if deterministic else uuid4().hex uuid = uuid[:20] - return f"{uuid}-{'__'.join(displayable_file_name)}" + return f"{uuid}-{'--'.join(displayable_file_name)}" -def activity_contextual_id(*, wf_context: bool = True) -> str: - contextual_id = "" +def activity_contextual_id( + *, wf_context: bool = True, act_context: bool = False, run_context: bool = False +) -> str: + contextual_id = [] act_info = activity.info() + if not wf_context and not act_context: + raise ValueError("at least one of wf_context and act_context must be True") if wf_context: - wf_id = act_info.workflow_id - wf_run_id = act_info.workflow_run_id - wf_type = act_info.workflow_type - contextual_id += f"{wf_type}-{wf_id}-{wf_run_id}-" - act_id = act_info.activity_id - act_run_id = act_info.activity_id - act_type = act_info.activity_type - contextual_id = f"{act_type}-{act_id}-{act_run_id}" - return contextual_id + contextual_id.append(act_info.workflow_id) + if run_context: + contextual_id.append(act_info.workflow_run_id) + if act_context: + contextual_id.append(act_info.activity_type) + contextual_id.append(act_info.activity_id) + if run_context: + contextual_id.append(act_info.activity_run_id) + return "-".join(contextual_id) diff --git a/workers/asr-worker/asr_worker/activities.py b/workers/asr-worker/asr_worker/activities.py index f3a775b3..83c9a0f7 100644 --- a/workers/asr-worker/asr_worker/activities.py +++ b/workers/asr-worker/asr_worker/activities.py @@ -1,10 +1,9 @@ -import json +import os from asyncio import AbstractEventLoop -from collections.abc import AsyncGenerator, AsyncIterable, Generator, Iterable -from io import TextIOWrapper +from collections.abc import AsyncGenerator, AsyncIterable, Iterable from itertools import tee from pathlib import Path -from typing import Any, TextIO, cast +from typing import Any, cast from caul.objects import ASRResult, PreprocessedInput from caul.tasks import ( @@ -15,7 +14,12 @@ Preprocessor, ) from datashare_python.dependencies import lifespan_worker_config -from datashare_python.objects import DocArtifact +from datashare_python.objects import ( + DocArtifact, + Document, + DocumentLocation, + FilesystemDocument, +) from datashare_python.types_ import ProgressRateHandler, RawProgressHandler from datashare_python.utils import ( ActivityWithProgress, @@ -28,12 +32,13 @@ ) from icij_common.es import ( DOC_CONTENT_TYPE, + DOC_LANGUAGE, + DOC_METADATA, DOC_PATH, + DOC_ROOT_ID, ES_DOCUMENT_TYPE, HITS, - ID_, QUERY, - SOURCE, ESClient, and_query, has_type, @@ -56,7 +61,7 @@ TRANSCRIPTION_METADATA_VALUE, ) from .dependencies import lifespan_es_client -from .objects import DocId, InferenceRunnerConfig, Transcription +from .objects import InferenceRunnerConfig, Transcription _BASE_WEIGHT = 1.0 _SEARCH_AUDIOS_WEIGHT = _BASE_WEIGHT * 2 @@ -81,7 +86,12 @@ async def search_audio_paths( batch_paths = [ p.relative_to(workdir) async for p in search_audio_paths_act( - project, es_client, query, output_dir, batch_size + project, + es_client, + query, + config=worker_config, + output_dir=output_dir, + batch_size=batch_size, ) ] return batch_paths @@ -92,7 +102,6 @@ def preprocess( ) -> list[Path]: # TODO: this shouldn't be necessary, fix this bug worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) - audio_root = worker_config.audios_root workdir = worker_config.workdir # TODO: implement caching preprocessor = Preprocessor.from_config(config) @@ -104,7 +113,7 @@ def preprocess( batch_paths = preprocess_act( preprocessor, audio_batch, - audio_root=audio_root, + worker_config=worker_config, output_dir=output_dir, ) batches = [p.relative_to(workdir) for p in batch_paths] @@ -114,6 +123,7 @@ def preprocess( def infer( self, preprocessed_inputs: list[Path], + project: str, config: InferenceRunnerConfig, *, progress: ProgressRateHandler | None = None, @@ -122,6 +132,9 @@ def infer( config = _INFERENCE_CONFIG_TYPE_ADAPTER.validate_python(config) worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) workdir = worker_config.workdir + contextual_id = activity_contextual_id() + output_dir = workdir / project / contextual_id + output_dir.mkdir(parents=True, exist_ok=True) preprocessed_inputs = _LIST_OF_PATH_ADAPTER.validate_python(preprocessed_inputs) preprocessed_inputs = [workdir / p for p in preprocessed_inputs] if progress is not None: @@ -131,7 +144,7 @@ def infer( paths = infer_act( inference_runner, preprocessed_inputs, - workdir, + output_dir=output_dir, event_loop=self._event_loop, progress=progress, ) @@ -161,8 +174,11 @@ def postprocess( # TODO: implement caching postprocessor = Postprocessor.from_config(config) with postprocessor: - with audio_batch.open() as f: - doc_ids = [doc_id for doc_id, _ in read_batch(f)] + docs = ( + FilesystemDocument.model_validate(fs_doc) + for fs_doc in read_jsonl(audio_batch) + ) + doc_ids = [doc.id for doc in docs] if progress is not None: progress = to_raw_progress(progress, max_progress=len(doc_ids)) return postprocess_act( @@ -180,14 +196,17 @@ async def search_audio_paths_act( project: str, es_client: ESClient, query: dict[str, Any], + *, + config: ASRWorkerConfig, output_dir: Path, batch_size: int, ) -> AsyncIterable[Path]: # TODO: supported content types should be args - query = _search_audio_paths( + docs = _search_audio_paths( es_client, project, query, supported_content_types=SUPPORTED_CONTENT_TYPES ) - async for p in write_audio_batches(query, output_dir, batch_size): + docs = create_symlinks_for_embedded_audios(docs, config) + async for p in write_audio_batches(docs, output_dir, batch_size): yield p @@ -195,15 +214,28 @@ def preprocess_act( preprocessor: Preprocessor, audio_batch: Path, *, - audio_root: Path, + worker_config: ASRWorkerConfig, output_dir: Path, ) -> list[Path]: - with audio_batch.open() as f: - audios = read_batch(f) - audios = (str(audio_root / p) for _, p in audios) - # TODO: implement a caching strategy here, we could avoid processing files - # which have already been preprocessed - return list(_preprocess(preprocessor, audios, output_dir)) + audios_root = worker_config.audios_root + artifacts_root = worker_config.artifacts_root + workdir = worker_config.workdir + audios = ( + FilesystemDocument.model_validate(fs_doc) for fs_doc in read_jsonl(audio_batch) + ) + audios = ( + str( + fs_doc.locate( + original_root=audios_root, + artifacts_root=artifacts_root, + workdir=workdir, + ) + ) + for fs_doc in audios + ) + # TODO: implement a caching strategy here, we could avoid processing files + # which have already been preprocessed + return list(_preprocess(preprocessor, audios, output_dir)) def infer_act( @@ -305,25 +337,20 @@ def _relative_input( async def write_audio_batches( - ids_and_paths: AsyncIterable[tuple[DocId, str]], root: Path, batch_size: int + docs: AsyncIterable[FilesystemDocument], root: Path, batch_size: int ) -> AsyncIterable[Path]: batch_id = 0 - async for batch in async_batches(ids_and_paths, batch_size): + async for batch in async_batches(docs, batch_size): batch_path = root / f"{batch_id}.txt" with batch_path.open("w") as f: - write_audio_batch(batch, f) + for fs_doc in batch: + f.write(f"{fs_doc.model_dump_json()}\n") yield batch_path batch_id += 1 -def write_audio_batch(batch: Iterable[tuple[DocId, Path]], f: TextIOWrapper) -> None: - for doc_id, path in batch: - data = {"doc_id": doc_id, "path": str(path)} - f.write(f"{json.dumps(data)}\n") - - _DOC_TYPE_QUERY = has_type(type_field="type", type_value=ES_DOCUMENT_TYPE) -_DOC_CONTENT_SOURCES = [DOC_PATH] +_DOC_CONTENT_SOURCES = [DOC_PATH, DOC_ROOT_ID, DOC_LANGUAGE, DOC_METADATA, DOC_PATH] async def _search_audio_paths( @@ -331,13 +358,13 @@ async def _search_audio_paths( project: str, query: dict[str, Any], supported_content_types: set[str], -) -> AsyncGenerator[tuple[DocId, str], None]: +) -> AsyncGenerator[FilesystemDocument, None]: body = _with_audio_content(query, supported_content_types) async for page in es_client.poll_search_pages( index=project, body=body, sort="_doc:asc", _source_includes=_DOC_CONTENT_SOURCES ): for hit in page[HITS][HITS]: - yield hit[ID_], hit[SOURCE][DOC_PATH] + yield Document.from_es(hit).to_filesystem() def _content_type_query(supported_content_types: set[str]) -> dict[str, Any]: @@ -355,10 +382,33 @@ def _with_audio_content( return and_query(query, type_query[QUERY]) -def read_batch(f: TextIO) -> Generator[tuple[DocId, Path], None, None]: - for line in f: - data = json.loads(line) - yield data["doc_id"], Path(data["path"]) +async def create_symlinks_for_embedded_audios( + docs: AsyncIterable[FilesystemDocument], config: ASRWorkerConfig +) -> AsyncIterable[FilesystemDocument]: + workdir = config.workdir + artifacts_root = config.artifacts_root + symlinks_dir = None + async for d in docs: + if d.location == DocumentLocation.ARTIFACTS: + if symlinks_dir is None: + symlinks_dir = workdir / d.index / "symlinks" + symlinks_dir.mkdir(parents=True, exist_ok=True) + artifact_path = artifacts_root / d.path + audio_ext = Path(d.resource_name).suffix + symlink_path = d.path.relative_to(Path(d.index)) + symlink_path = symlinks_dir / f"{symlink_path}{audio_ext}" + symlink_path.parent.mkdir(parents=True, exist_ok=True) + os.symlink(artifact_path, symlink_path) + symlink = FilesystemDocument( + path=symlink_path.relative_to(workdir), + id=d.id, + location=DocumentLocation.WORKDIR, + index=d.index, + resource_name=d.resource_name, + ) + yield symlink + else: + yield d REGISTRY = [ diff --git a/workers/asr-worker/asr_worker/config.py b/workers/asr-worker/asr_worker/config.py index 99511360..0747c5e1 100644 --- a/workers/asr-worker/asr_worker/config.py +++ b/workers/asr-worker/asr_worker/config.py @@ -11,9 +11,13 @@ class ASRWorkerConfig(WorkerConfig): loggers: ClassVar[list[str]] = Field(_ALL_LOGGERS, frozen=True) - audios_root: Path + docs_root: Path = Field(alias="audios_root") artifacts_root: Path workdir: Path + @property + def audios_root(self) -> Path: + return self.docs_root + WORKER_CONFIG_CLS = ASRWorkerConfig diff --git a/workers/asr-worker/asr_worker/workflows.py b/workers/asr-worker/asr_worker/workflows.py index 638b47a7..6963591d 100644 --- a/workers/asr-worker/asr_worker/workflows.py +++ b/workers/asr-worker/asr_worker/workflows.py @@ -53,7 +53,10 @@ async def run(self, args: ASRArgs) -> ASRResponse: logger.info("preprocessing files...") preprocessed_batches = await gather(*preprocessing_acts) inference_args = zip( - preprocessed_batches, repeat(config.inference), strict=False + preprocessed_batches, + repeat(args.project), + repeat(config.inference), + strict=False, ) logger.info("preprocessing complete !") # Inference diff --git a/workers/asr-worker/tests/conftest.py b/workers/asr-worker/tests/conftest.py index 2edf094a..94302d81 100644 --- a/workers/asr-worker/tests/conftest.py +++ b/workers/asr-worker/tests/conftest.py @@ -1,8 +1,10 @@ +import shutil from pathlib import Path import pytest from _pytest.tmpdir import TempPathFactory from asr_worker.config import ASRWorkerConfig +from asr_worker.constants import SUPPORTED_CONTENT_TYPES from asr_worker.dependencies import set_multiprocessing_start_method from datashare_python.config import DatashareClientConfig, TemporalClientConfig from datashare_python.conftest import ( # noqa: F401 @@ -19,10 +21,13 @@ worker_lifetime_deps, ) from datashare_python.dependencies import set_es_client, set_temporal_client -from datashare_python.objects import Document +from datashare_python.objects import Document, FilesystemDocument from datashare_python.types_ import ContextManagerFactory +from datashare_python.utils import artifacts_dir from icij_common.es import ESClient +from tests import AUDIOS_PATH + @pytest.fixture(scope="session") def test_deps() -> list[ContextManagerFactory]: @@ -53,9 +58,11 @@ def doc_0() -> Document: return Document( id="doc-0", root_document="root-0", + index=TEST_PROJECT, language="ENGLISH", content_type="audio/wav", - path=Path("doc-0.wav"), + path=Path("root-0.eml"), + metadata={"tika_metadata_resourcename": "doc-0.wav"}, ) @@ -64,9 +71,11 @@ def doc_1() -> Document: return Document( id="doc-1", root_document="root-1", + index=TEST_PROJECT, language="ENGLISH", content_type="application/json", path=Path("doc-1.json"), + metadata={"tika_metadata_resourcename": "doc-1.json"}, ) @@ -74,15 +83,16 @@ def doc_1() -> Document: def doc_2() -> Document: return Document( id="doc-2", - root_document="root-2", + index=TEST_PROJECT, language="FRENCH", content_type="audio/mpeg", path=Path("doc-2.mp3"), + metadata={"tika_metadata_resourcename": "doc-2.mp3"}, ) @pytest.fixture -async def populate_es_with_audio( +async def populate_es_with_audios( test_es_client: ESClient, # noqa: F811 doc_0: Document, doc_1: Document, @@ -93,3 +103,36 @@ async def populate_es_with_audio( async for _ in index_docs(test_es_client, docs=docs, index_name=TEST_PROJECT): pass return docs + + +def _clear_dirs(config: ASRWorkerConfig) -> None: + shutil.rmtree(config.artifacts_root) + config.artifacts_root.mkdir(parents=True, exist_ok=True) + shutil.rmtree(config.workdir) + config.workdir.mkdir(parents=True, exist_ok=True) + + +@pytest.fixture +def with_audio_docs( + populate_es_with_audios: list[Document], test_worker_config: ASRWorkerConfig +) -> list[FilesystemDocument]: + config = test_worker_config + _clear_dirs(test_worker_config) + docs = [ + d for d in populate_es_with_audios if d.content_type in SUPPORTED_CONTENT_TYPES + ] + paths = [] + audio_path = AUDIOS_PATH / "asr_test.wav" + for doc in docs: + if doc.root_document is None: + config.audios_root.mkdir(parents=True, exist_ok=True) + shutil.copy(audio_path, config.audios_root / doc.path) + else: + artifact_path = ( + config.artifacts_root / artifacts_dir(doc.id, project=doc.index) / "raw" + ) + artifact_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(audio_path, artifact_path) + fs_doc = doc.to_filesystem() + paths.append(fs_doc) + return paths diff --git a/workers/asr-worker/tests/test_activities.py b/workers/asr-worker/tests/test_activities.py index 30b095ae..b8266376 100644 --- a/workers/asr-worker/tests/test_activities.py +++ b/workers/asr-worker/tests/test_activities.py @@ -10,17 +10,16 @@ infer_act, postprocess_act, preprocess_act, - read_batch, search_audio_paths_act, - write_audio_batch, write_audio_batches, ) +from asr_worker.config import ASRWorkerConfig from asr_worker.objects import DocId, Transcription from asr_worker.utils import read_jsonl from caul.objects import ASRResult, InputMetadata, PreprocessedInput, PreprocessorOutput from caul.tasks import InferenceRunner, Postprocessor, Preprocessor from datashare_python.conftest import TEST_PROJECT -from datashare_python.objects import Document +from datashare_python.objects import DocumentLocation, FilesystemDocument from icij_common.es import ESClient, ids_query, match_all from icij_common.iter_utils import batches from icij_common.registrable import RegistrableConfig @@ -65,8 +64,23 @@ ), ] +FS_DOCUMENT_0 = FilesystemDocument( + id="doc-0", + path=Path(TEST_PROJECT, "symlinks", "do", "c-", "doc-0", "raw.wav"), + index=TEST_PROJECT, + location=DocumentLocation.WORKDIR, + resource_name="doc-0.wav", +) +DS_DOCUMENT_2 = FilesystemDocument( + id="doc-2", + path=Path("doc-2.mp3"), + index=TEST_PROJECT, + location=DocumentLocation.ORIGINAL, + resource_name="doc-2.mp3", +) + -class MockPeprocessor(Preprocessor): +class MockPreprocessor(Preprocessor): def __init__(self, batch_size: int) -> None: self._batch_size = batch_size @@ -131,25 +145,26 @@ def process( ("query", "expected_batches"), [ # Supports empty query - ({}, [[("doc-0", Path("doc-0.wav")), ("doc-2", Path("doc-2.mp3"))]]), + ({}, [[FS_DOCUMENT_0, DS_DOCUMENT_2]]), # Return all audio/video docs - (match_all(), [[("doc-0", Path("doc-0.wav")), ("doc-2", Path("doc-2.mp3"))]]), - (ids_query(["doc-0"]), [[("doc-0", Path("doc-0.wav"))]]), + (match_all(), [[FS_DOCUMENT_0, DS_DOCUMENT_2]]), + (ids_query(["doc-0"]), [[FS_DOCUMENT_0]]), # Should filter non supported content type (ids_query(["doc-1"]), []), ], ) async def test_search_audio_paths_act( - populate_es_with_audio: list[Document], + with_audio_docs: list[FilesystemDocument], test_es_client: ESClient, query: dict, expected_batches: list[tuple[DocId, Path]], + test_worker_config: ASRWorkerConfig, tmpdir: Path, ) -> None: # Given tmpdir = Path(tmpdir) - batch_size = len(populate_es_with_audio) - assert batch_size == 4 + worker_config = test_worker_config + batch_size = len(with_audio_docs) client = test_es_client # When batch_paths = [ @@ -160,35 +175,44 @@ async def test_search_audio_paths_act( query=query, batch_size=batch_size, output_dir=tmpdir, + config=worker_config, ) ] # Then results = [] for b in batch_paths: - with b.open() as f: - results.append(list(read_batch(f))) + results.append( + [FilesystemDocument.model_validate(fs_doc) for fs_doc in read_jsonl(b)] + ) assert results == expected_batches -def test_preprocess_act(tmpdir: Path) -> None: +def test_preprocess_act(test_worker_config: ASRWorkerConfig, tmpdir: Path) -> None: # Given - tmpdir = Path(tmpdir) + output_dir = Path(tmpdir) n_audios = 3 batch_size = n_audios - 1 - output_dir = tmpdir.joinpath("artifacts") - output_dir.mkdir() - audio_root = tmpdir.joinpath("audio_root") audio_batch = tmpdir / "audio_batch.txt" - batch = [(f"doc-{i}", Path(str(i))) for i in range(n_audios)] + batch = [ + FilesystemDocument( + id=f"doc-{i}", + path=Path(str(i)), + location=DocumentLocation.ARTIFACTS, + index=TEST_PROJECT, + resource_name=f"doc-{i}.wav", + ) + for i in range(n_audios) + ] with audio_batch.open("w") as f: - write_audio_batch(batch, f) - preprocessor = MockPeprocessor(batch_size=batch_size) + for fs_doc in batch: + f.write(fs_doc.model_dump_json() + "\n") + preprocessor = MockPreprocessor(batch_size=batch_size) # When batch_files = preprocess_act( preprocessor, audio_batch=audio_batch, - audio_root=audio_root, + worker_config=test_worker_config, output_dir=output_dir, ) @@ -273,10 +297,17 @@ async def test_write_audio_search_results(tmpdir: Path) -> None: root = Path(tmpdir) batch_size = 2 - async def results() -> AsyncGenerator[tuple[DocId, str], None]: + async def results() -> AsyncGenerator[FilesystemDocument, None]: res = ["doc-0", "doc-1", "doc-2"] for r in res: - yield r, f"{r}.wav" + fs_doc = FilesystemDocument( + id=r, + path=Path(f"{r}.wav"), + index=TEST_PROJECT, + location=DocumentLocation.WORKDIR, + resource_name=f"{r}.wav", + ) + yield fs_doc # When results = write_audio_batches(results(), root=root, batch_size=batch_size) @@ -285,10 +316,12 @@ async def results() -> AsyncGenerator[tuple[DocId, str], None]: async def expected_content() -> AsyncGenerator[str, None]: contents = [ [ - '{"doc_id": "doc-0", "path": "doc-0.wav"}', - '{"doc_id": "doc-1", "path": "doc-1.wav"}', + '{"id":"doc-0","path":"doc-0.wav","index":"test-project","location":"workdir","resource_name":"doc-0.wav"}', + '{"id":"doc-1","path":"doc-1.wav","index":"test-project","location":"workdir","resource_name":"doc-1.wav"}', + ], + [ + '{"id":"doc-2","path":"doc-2.wav","index":"test-project","location":"workdir","resource_name":"doc-2.wav"}' ], - ['{"doc_id": "doc-2", "path": "doc-2.wav"}'], ] for line in contents: yield "\n".join(line) + "\n" diff --git a/workers/asr-worker/tests/test_workflows.py b/workers/asr-worker/tests/test_workflows.py index d523426d..51440984 100644 --- a/workers/asr-worker/tests/test_workflows.py +++ b/workers/asr-worker/tests/test_workflows.py @@ -1,6 +1,5 @@ import json import math -import shutil import uuid from asyncio import AbstractEventLoop from collections.abc import AsyncGenerator @@ -9,12 +8,10 @@ import pytest from asr_worker.activities import ASRActivities from asr_worker.config import ASRWorkerConfig -from asr_worker.constants import SUPPORTED_CONTENT_TYPES from asr_worker.dependencies import REGISTRY from asr_worker.objects import ( ASRArgs, ASRPipelineConfig, - DocId, Timestamp, Transcript, Transcription, @@ -22,14 +19,12 @@ from asr_worker.workflows import ASRWorkflow, TaskQueues from caul.objects import ASRResult from datashare_python.conftest import TEST_PROJECT -from datashare_python.objects import Document +from datashare_python.objects import FilesystemDocument from datashare_python.types_ import TemporalClient from datashare_python.worker import worker_context from pydantic import TypeAdapter from temporalio.worker import Worker -from . import AUDIOS_PATH - _LIST_OF_PATH_ADAPTER = TypeAdapter(list[Path]) _MODEL_RESULT_0 = ASRResult( @@ -135,23 +130,6 @@ async def gpu_inference_worker( ) -@pytest.fixture -def with_audio_docs( - populate_es_with_audio: list[Document], test_worker_config: ASRWorkerConfig -) -> list[tuple[DocId, Path]]: - config = test_worker_config - docs = [ - d for d in populate_es_with_audio if d.content_type in SUPPORTED_CONTENT_TYPES - ] - paths = [] - config.audios_root.mkdir(parents=True, exist_ok=True) - audio_path = AUDIOS_PATH / "asr_test.wav" - for doc in docs: - shutil.copy(audio_path, config.audios_root / doc.path) - paths.append((doc.id, doc.path)) - return paths - - @pytest.mark.e2e async def test_asr_workflow_e2e( test_temporal_client_session: TemporalClient, @@ -159,7 +137,7 @@ async def test_asr_workflow_e2e( gpu_inference_worker: Worker, # noqa: ARG001 io_bound_worker: Worker, # noqa: ARG001 test_worker_config: ASRWorkerConfig, - with_audio_docs: list[tuple[DocId, Path]], # noqa: ARG001 + with_audio_docs: list[FilesystemDocument], # noqa: ARG001 ) -> None: # Given config = test_worker_config @@ -168,8 +146,7 @@ async def test_asr_workflow_e2e( n_audios = len(with_audio_docs) batch_size = n_audios - 1 project = TEST_PROJECT - doc_ids, _ = zip(*with_audio_docs, strict=True) - doc_ids = list(doc_ids) + doc_ids = [d.id for d in with_audio_docs] args = ASRArgs( project=project, docs=doc_ids, config=ASRPipelineConfig(), batch_size=batch_size )