diff --git a/asr-worker/asr_worker/activities.py b/asr-worker/asr_worker/activities.py index 91e4e058..f303382b 100644 --- a/asr-worker/asr_worker/activities.py +++ b/asr-worker/asr_worker/activities.py @@ -1,8 +1,10 @@ -from collections.abc import AsyncGenerator, AsyncIterable, Callable, Generator, Iterable -from contextlib import AbstractContextManager, contextmanager +import json +from asyncio import AbstractEventLoop +from collections.abc import AsyncGenerator, AsyncIterable, Generator, Iterable +from io import TextIOWrapper from itertools import tee from pathlib import Path -from typing import Any, cast +from typing import Any, TextIO, cast from caul.objects import ASRResult, PreprocessedInput from caul.tasks import ( @@ -13,9 +15,10 @@ Preprocessor, ) from datashare_python.dependencies import lifespan_worker_config -from datashare_python.types_ import ProgressRateHandler +from datashare_python.types_ import ProgressRateHandler, RawProgressHandler from datashare_python.utils import ( ActivityWithProgress, + DocArtifact, activity_contextual_id, activity_defn, debuggable_name, @@ -25,10 +28,12 @@ ) from icij_common.es import ( DOC_CONTENT_TYPE, + DOC_PATH, ES_DOCUMENT_TYPE, HITS, ID_, QUERY, + SOURCE, ESClient, and_query, has_type, @@ -51,7 +56,7 @@ TRANSCRIPTION_METADATA_VALUE, ) from .dependencies import lifespan_es_client -from .models import InferenceRunnerConfig, Transcription +from .models import DocId, InferenceRunnerConfig, Transcription _BASE_WEIGHT = 1.0 _SEARCH_AUDIOS_WEIGHT = _BASE_WEIGHT * 2 @@ -64,28 +69,26 @@ class ASRActivities(ActivityWithProgress): @activity_defn(name=SEARCH_AUDIOS_ACTIVITY, progress_weight=_SEARCH_AUDIOS_WEIGHT) - async def search_audios( + async def search_audio_paths( self, project: str, query: dict[str, Any], batch_size: int ) -> list[Path]: es_client = lifespan_es_client() worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) batch_dir_name = activity_contextual_id() workdir = worker_config.workdir - batch_root = workdir / batch_dir_name - batch_root.mkdir(parents=True, exist_ok=True) - # TODO: supported content types should be args - query = search_audios( - es_client, project, query, supported_content_types=SUPPORTED_CONTENT_TYPES - ) + output_dir = workdir / batch_dir_name + output_dir.mkdir(parents=True, exist_ok=True) batch_paths = [ p.relative_to(workdir) - async for p in write_audio_search_results(query, batch_root, batch_size) + async for p in search_audio_paths_act( + project, es_client, query, output_dir, batch_size + ) ] return batch_paths @activity_defn(name=PREPROCESS_ACTIVITY, progress_weight=_PREPROCESS_WEIGHT) def preprocess( - self, paths: list[Path] | Path, config: ParakeetPreprocessorConfig + self, audio_batch: Path, config: ParakeetPreprocessorConfig ) -> list[Path]: # TODO: this shouldn't be necessary, fix this bug worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) @@ -96,19 +99,15 @@ def preprocess( contextual_id = activity_contextual_id() output_dir = workdir / contextual_id output_dir.mkdir(parents=True, exist_ok=True) - if isinstance(paths, Path): - audio_cm = _read_audio_ids(paths) - else: - paths = _LIST_OF_PATH_ADAPTER.validate_python(paths) - audio_cm = _read_audios_cm(paths) - with audio_cm() as audios, preprocessor: - audios = (str(audio_root / p) for p in audios) # noqa: PLW2901 - # TODO: implement a caching strategy here, we could avoid processing files - # which have already been preprocessed - batches = [ - f.relative_to(workdir) - for f in preprocess(preprocessor, audios, output_dir) - ] + audio_batch = workdir / audio_batch + with preprocessor: + batch_paths = preprocess_act( + preprocessor, + audio_batch, + audio_root=audio_root, + output_dir=output_dir, + ) + batches = [p.relative_to(workdir) for p in batch_paths] return batches @activity_defn(name=RUN_INFERENCE_ACTIVITY, progress_weight=_INFERENCE_WEIGHT) @@ -124,53 +123,35 @@ def infer( worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) workdir = worker_config.workdir preprocessed_inputs = _LIST_OF_PATH_ADAPTER.validate_python(preprocessed_inputs) + preprocessed_inputs = [workdir / p for p in preprocessed_inputs] if progress is not None: progress = to_raw_progress(progress, max_progress=len(preprocessed_inputs)) - batch_files = (workdir / batch_file for batch_file in preprocessed_inputs) - # Audios paths in the input are relative to the batch file directory - inputs = ( - [ - _relative_input(PreprocessedInput.model_validate(i), f.parent) - for i in read_jsonl(f) - ] - for f in batch_files - ) - audio_paths, inputs = tee(inputs) - audio_paths = ( - i.metadata.preprocessed_file_path for b in audio_paths for i in b - ) - # TODO: implement caching inference_runner = InferenceRunner.from_config(config) with inference_runner: - # TODO: extract this into a function to improve testability - paths = [] - for res_i, (path, asr_res) in enumerate( - zip(audio_paths, inference_runner.process(inputs), strict=True) - ): - filename = f"{debuggable_name(path.name)}-transcript.json" - transcript_path = workdir / safe_dir(filename) / filename - transcript_path.parent.mkdir(parents=True, exist_ok=True) - transcript_path.write_text(asr_res.model_dump_json()) - paths.append(transcript_path.relative_to(workdir)) - if progress is not None: - self._event_loop.run_until_complete(progress(res_i)) - return paths + paths = infer_act( + inference_runner, + preprocessed_inputs, + workdir, + event_loop=self._event_loop, + progress=progress, + ) + return [p.relative_to(workdir) for p in paths] @activity_defn(name=POSTPROCESS_ACTIVITY, progress_weight=_BASE_WEIGHT) def postprocess( self, inference_results: list[Path], - input_paths: list[Path], + audio_batch: Path, config: ParakeetPostprocessorConfig, project: str, *, progress: ProgressRateHandler | None = None, - ) -> None: + ) -> int: # TODO: this shouldn't be necessary, fix this bug - input_paths = _LIST_OF_PATH_ADAPTER.validate_python(input_paths) config = ParakeetPostprocessorConfig.model_validate(config) worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) workdir = worker_config.workdir + audio_batch = workdir / audio_batch artifacts_root = worker_config.artifacts_root inference_results = _LIST_OF_PATH_ADAPTER.validate_python(inference_results) inference_results = ( @@ -180,25 +161,108 @@ def postprocess( # TODO: implement caching postprocessor = Postprocessor.from_config(config) with postprocessor: + with audio_batch.open() as f: + doc_ids = [doc_id for doc_id, _ in read_batch(f)] if progress is not None: - progress = to_raw_progress(progress, max_progress=len(input_paths)) - transcriptions = postprocessor.process(inference_results) - # Strict is important here ! - for i, (original, asr_result) in enumerate( - zip(input_paths, transcriptions, strict=True) - ): - t_path = write_transcription( - asr_result, - original.name, - artifacts_root=artifacts_root, - project=project, - ) - activity.logger.debug("wrote transcription for %s", t_path) - if progress is not None: - self._event_loop.run_until_complete(progress(i)) - - -def preprocess( + progress = to_raw_progress(progress, max_progress=len(doc_ids)) + return postprocess_act( + postprocessor, + inference_results, + doc_ids, + project=project, + artifacts_root=artifacts_root, + event_loop=self._event_loop, + progress=progress, + ) + + +async def search_audio_paths_act( + project: str, + es_client: ESClient, + query: dict[str, Any], + output_dir: Path, + batch_size: int, +) -> AsyncIterable[Path]: + # TODO: supported content types should be args + query = _search_audio_paths( + es_client, project, query, supported_content_types=SUPPORTED_CONTENT_TYPES + ) + async for p in write_audio_batches(query, output_dir, batch_size): + yield p + + +def preprocess_act( + preprocessor: Preprocessor, + audio_batch: Path, + *, + audio_root: Path, + output_dir: Path, +) -> list[Path]: + with audio_batch.open() as f: + audios = read_batch(f) + audios = (str(audio_root / p) for _, p in audios) + # TODO: implement a caching strategy here, we could avoid processing files + # which have already been preprocessed + return list(_preprocess(preprocessor, audios, output_dir)) + + +def infer_act( + inference_runner: InferenceRunner, + preprocessed_inputs: list[Path], + output_dir: Path, + event_loop: AbstractEventLoop | None = None, + progress: RawProgressHandler | None = None, +) -> list[Path]: + # Audios paths in the input are relative to the batch file directory + inputs = ( + [ + _relative_input(PreprocessedInput.model_validate(i), f.parent) + for i in read_jsonl(f) + ] + for f in preprocessed_inputs + ) + audio_paths, inputs = tee(inputs) + audio_paths = (i.metadata.preprocessed_file_path for b in audio_paths for i in b) + # TODO: implement caching + paths = [] + for res_i, (path, asr_res) in enumerate( + zip(audio_paths, inference_runner.process(inputs), strict=True) + ): + filename = f"{debuggable_name(path.name)}-transcript.json" + transcript_path = output_dir / safe_dir(filename) / filename + transcript_path.parent.mkdir(parents=True, exist_ok=True) + transcript_path.write_text(asr_res.model_dump_json()) + paths.append(transcript_path) + if progress is not None and event_loop is not None: + event_loop.run_until_complete(progress(res_i)) + return paths + + +def postprocess_act( + postprocessor: Postprocessor, + inference_results: Iterable[ASRResult], + doc_ids: Iterable[str], + *, + artifacts_root: Path, + project: str, + event_loop: AbstractEventLoop | None = None, + progress: ProgressRateHandler | None = None, +) -> int: + transcriptions = postprocessor.process(inference_results) + # Strict is important here ! + n_docs = 0 + for i, (doc_id, asr_result) in enumerate(zip(doc_ids, transcriptions, strict=True)): + n_docs += 1 + t_path = write_transcription( + doc_id, asr_result, artifacts_root=artifacts_root, project=project + ) + activity.logger.debug("wrote transcription for %s", t_path) + if progress is not None and event_loop is not None: + event_loop.run_until_complete(progress(i)) + return n_docs + + +def _preprocess( preprocessor: Preprocessor, audios: Iterable[str], output_dir: Path ) -> Iterable[Path]: for batch_i, batch in enumerate( @@ -214,24 +278,20 @@ def preprocess( def write_transcription( - asr_result: ASRResult, - transcribed_filename: str, - *, - artifacts_root: Path, - project: str, + doc_id: str, asr_result: ASRResult, *, artifacts_root: Path, project: str ) -> Path: result = Transcription.from_asr_handler_result(asr_result) - artifact = result.model_dump_json().encode() - # TODO: if transcriptions are too large we could also serialize them - # as jsonl - rel_path = write_artifact( - artifact, - artifacts_root, + artifact_bytes = result.model_dump_json().encode() + artifact = DocArtifact( project=project, - filename=transcribed_filename, + doc_id=doc_id, + filename=TRANSCRIPTION_METADATA_VALUE, metadata_key=TRANSCRIPTION_METADATA_KEY, - metadata_value=TRANSCRIPTION_METADATA_VALUE, + artifact=artifact_bytes, ) + # TODO: if transcriptions are too large we could also serialize them + # as jsonl + rel_path = write_artifact(artifacts_root, artifact) return rel_path @@ -244,34 +304,40 @@ def _relative_input( return PreprocessedInput(metadata=metadata) # noqa: F821 -async def write_audio_search_results( - results: AsyncIterable[str], root: Path, batch_size: int +async def write_audio_batches( + ids_and_paths: AsyncIterable[tuple[DocId, str]], root: Path, batch_size: int ) -> AsyncIterable[Path]: batch_id = 0 - async for batch in async_batches(results, batch_size): + async for batch in async_batches(ids_and_paths, batch_size): batch_path = root / f"{batch_id}.txt" with batch_path.open("w") as f: - for doc_id in batch: - f.write(f"{doc_id}\n") + write_audio_batch(batch, f) yield batch_path batch_id += 1 +def write_audio_batch(batch: Iterable[tuple[DocId, Path]], f: TextIOWrapper) -> None: + for doc_id, path in batch: + data = {"doc_id": doc_id, "path": str(path)} + f.write(f"{json.dumps(data)}\n") + + _DOC_TYPE_QUERY = has_type(type_field="type", type_value=ES_DOCUMENT_TYPE) +_DOC_CONTENT_SOURCES = [DOC_PATH] -async def search_audios( +async def _search_audio_paths( es_client: ESClient, project: str, query: dict[str, Any], supported_content_types: set[str], -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[tuple[DocId, str], None]: body = _with_audio_content(query, supported_content_types) async for page in es_client.poll_search_pages( - index=project, body=body, sort="_doc:asc", _source=False + index=project, body=body, sort="_doc:asc", _source_includes=_DOC_CONTENT_SOURCES ): for hit in page[HITS][HITS]: - yield hit[ID_] + yield hit[ID_], hit[SOURCE][DOC_PATH] def _content_type_query(supported_content_types: set[str]) -> dict[str, Any]: @@ -289,25 +355,14 @@ def _with_audio_content( return and_query(query, type_query[QUERY]) -def _read_audio_ids(path: Path) -> AbstractContextManager: - @contextmanager - def cm() -> Generator[Iterable[Path], None, None]: - with open(path) as f: - yield (Path(line.strip()) for line in f) - - return cm - - -def _read_audios_cm(paths: list[Path]) -> Callable[[], AbstractContextManager]: - @contextmanager - def cm() -> Generator[Iterable[Path], None, None]: - yield iter(paths) - - return cm +def read_batch(f: TextIO) -> Generator[tuple[DocId, Path], None, None]: + for line in f: + data = json.loads(line) + yield data["doc_id"], Path(data["path"]) REGISTRY = [ - ASRActivities.search_audios, + ASRActivities.search_audio_paths, ASRActivities.preprocess, ASRActivities.infer, ASRActivities.postprocess, diff --git a/asr-worker/asr_worker/models.py b/asr-worker/asr_worker/models.py index 745e6d8d..e0089786 100644 --- a/asr-worker/asr_worker/models.py +++ b/asr-worker/asr_worker/models.py @@ -1,5 +1,4 @@ import math -from pathlib import Path from typing import Annotated, Any, Self from caul.asr_pipeline import ASRPipelineConfig @@ -32,11 +31,12 @@ DocumentSearchQuery = dict[str, Any] +DocId = str class ASRArgs(DatashareModel): project: str - docs: list[Path] | DocumentSearchQuery + docs: list[DocId] | DocumentSearchQuery config: ASRPipelineConfig = Field(default=_DEFAULT_PIPELINE_CONFIG) batch_size: int diff --git a/asr-worker/asr_worker/workflows.py b/asr-worker/asr_worker/workflows.py index d406900a..b9811918 100644 --- a/asr-worker/asr_worker/workflows.py +++ b/asr-worker/asr_worker/workflows.py @@ -4,6 +4,7 @@ from itertools import repeat from datashare_python.utils import WorkflowWithProgress, execute_activity +from icij_common.es import has_id from pydantic import TypeAdapter from temporalio import workflow @@ -30,24 +31,16 @@ async def run(self, args: ASRArgs) -> ASRResponse: logger = workflow.logger config = args.config batch_size = args.batch_size - docs = args.docs - if isinstance(docs, dict): - args = [args.project, docs, batch_size] - batched_input_paths = workflow.execute_activity( - ASRActivities.search_audios, - args=args, - start_to_close_timeout=timedelta(seconds=TEN_MINUTES), - task_queue=TaskQueues.IO, - ) - else: - batched_input_paths = [ - docs[batch_start : batch_start + batch_size] - for batch_start in range(0, len(docs), batch_size) - ] - # Preprocessing - preprocess_args = zip( - batched_input_paths, repeat(config.preprocessing), strict=False + doc_query = has_id(args.docs) if isinstance(args.docs, list) else args.docs + search_args = [args.project, doc_query, batch_size] + batch_paths = await workflow.execute_activity( + ASRActivities.search_audio_paths, + args=search_args, + start_to_close_timeout=timedelta(seconds=TEN_MINUTES), + task_queue=TaskQueues.IO, ) + # Preprocessing + preprocess_args = zip(batch_paths, repeat(config.preprocessing), strict=False) preprocessing_acts = ( execute_activity( ASRActivities.preprocess, @@ -81,7 +74,7 @@ async def run(self, args: ASRArgs) -> ASRResponse: postprocessing_ins = list( zip( inference_results, - batched_input_paths, + batch_paths, repeat(config.postprocessing), repeat(args.project), strict=False, @@ -97,9 +90,10 @@ async def run(self, args: ASRArgs) -> ASRResponse: for i in postprocessing_ins ] logger.info("running postprocessing...") - await gather(*postprocessing_acts) + n_transcribed = await gather(*postprocessing_acts) + n_transcribed = sum(n_transcribed) logger.info("postprocessing complete !") - return ASRResponse(n_transcribed=len(args.docs)) + return ASRResponse(n_transcribed=n_transcribed) REGISTRY = [ASRWorkflow] diff --git a/asr-worker/tests/conftest.py b/asr-worker/tests/conftest.py index fae2d2d0..dfaeb63c 100644 --- a/asr-worker/tests/conftest.py +++ b/asr-worker/tests/conftest.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest from _pytest.tmpdir import TempPathFactory from asr_worker.config import ASRWorkerConfig @@ -48,7 +50,11 @@ def test_worker_config(tmp_path_factory: TempPathFactory) -> ASRWorkerConfig: # @pytest.fixture(scope="session") def doc_0() -> Document: return Document( - id="doc-0", root_document="root-0", language="ENGLISH", content_type="audio/wav" + id="doc-0", + root_document="root-0", + language="ENGLISH", + content_type="audio/wav", + path=Path("doc-0.wav"), ) @@ -59,13 +65,18 @@ def doc_1() -> Document: root_document="root-1", language="ENGLISH", content_type="application/json", + path=Path("doc-1.json"), ) @pytest.fixture(scope="session") def doc_2() -> Document: return Document( - id="doc-2", root_document="root-2", language="FRENCH", content_type="audio/mpeg" + id="doc-2", + root_document="root-2", + language="FRENCH", + content_type="audio/mpeg", + path=Path("doc-2.mp3"), ) diff --git a/asr-worker/tests/test_activities.py b/asr-worker/tests/test_activities.py index 9ebee97d..16e77625 100644 --- a/asr-worker/tests/test_activities.py +++ b/asr-worker/tests/test_activities.py @@ -1,125 +1,271 @@ import json -import math from collections.abc import AsyncGenerator, Iterable +from itertools import cycle from pathlib import Path from typing import Self import pytest from aiostream import stream from asr_worker.activities import ( - preprocess, - search_audios, - write_audio_search_results, - write_transcription, + infer_act, + postprocess_act, + preprocess_act, + read_batch, + search_audio_paths_act, + write_audio_batch, + write_audio_batches, ) -from asr_worker.constants import SUPPORTED_CONTENT_TYPES -from asr_worker.models import Timestamp, Transcript, Transcription +from asr_worker.models import DocId, Transcription from asr_worker.utils import read_jsonl -from caul.objects import ASRResult, InputMetadata, PreprocessedInput -from caul.tasks import Preprocessor +from caul.objects import ASRResult, InputMetadata, PreprocessedInput, PreprocessorOutput +from caul.tasks import InferenceRunner, Postprocessor, Preprocessor from datashare_python.conftest import TEST_PROJECT from datashare_python.objects import Document from icij_common.es import ESClient, ids_query, match_all +from icij_common.iter_utils import batches from icij_common.registrable import RegistrableConfig -PREPROCESSED_INPUT_0 = PreprocessedInput(metadata=InputMetadata(duration_s=0.0)) -PREPROCESSED_INPUT_1 = PreprocessedInput(metadata=InputMetadata(duration_s=1.0)) -PREPROCESSED_INPUT_2 = PreprocessedInput(metadata=InputMetadata(duration_s=2.0)) +PREPROCESSED_INPUT_0 = PreprocessedInput( + metadata=InputMetadata( + input_ordering=0, + duration_s=0.0, + preprocessed_file_path=Path("preprocessed_0.wav"), + ) +) +PREPROCESSED_INPUT_1 = PreprocessedInput( + metadata=InputMetadata( + input_ordering=1, + duration_s=1.0, + preprocessed_file_path=Path("preprocessed_1.wav"), + ) +) +PREPROCESSED_INPUT_2 = PreprocessedInput( + metadata=InputMetadata( + input_ordering=2, + duration_s=2.0, + preprocessed_file_path=Path("preprocessed_2.wav"), + ) +) + +INFERENCE_RESULTS = [ + ASRResult( + input_ordering=0, + transcription=[(0.0, 0.0, "preprocessed_0")], + score=1.0, + ), + ASRResult( + input_ordering=1, + transcription=[(0.0, 1.0, "preprocessed_1")], + score=1.0, + ), + ASRResult( + input_ordering=2, + transcription=[(0.0, 2.0, "preprocessed_2")], + score=1.0, + ), +] -class MockProcessor(Preprocessor): +class MockPeprocessor(Preprocessor): + def __init__(self, batch_size: int) -> None: + self._batch_size = batch_size + @classmethod def _from_config(cls, config: RegistrableConfig, **kwargs) -> Self: # noqa: ARG003 return cls(**kwargs) - def __init__(self, batches: list[list[PreprocessedInput]]) -> None: - self._batches = batches - def process( self, audios: Iterable[Path], # noqa: ARG002 **kwargs, # noqa: ARG002 ) -> Iterable[list[PreprocessedInput]]: - yield from self._batches + outputs = cycle( + [PREPROCESSED_INPUT_0, PREPROCESSED_INPUT_1, PREPROCESSED_INPUT_2] + ) + outputs = [next(outputs) for _ in audios] + for b in batches(outputs, self._batch_size): + yield list(b) -def test_preprocess(tmpdir: Path) -> None: - # Given - output_dir = Path(tmpdir) - batches = [[PREPROCESSED_INPUT_0, PREPROCESSED_INPUT_1], [PREPROCESSED_INPUT_2]] - preprocessor = MockProcessor(batches) - audios = [] - # When - batch_files = list(preprocess(preprocessor, audios, output_dir=output_dir)) - # Then - assert len(batch_files) == 2 - written_batches = [ - [PreprocessedInput.model_validate(d) for d in read_jsonl(f)] - for f in batch_files - ] - assert written_batches == batches +class MockInferenceRunner(InferenceRunner): + @classmethod + def _from_config(cls, config: RegistrableConfig, **kwargs) -> Self: # noqa: ARG003 + return cls() + def process( + self, + inputs: Iterable[list[PreprocessorOutput]], + *args, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> Iterable[ASRResult]: + i = 0 + for batch in inputs: + for preprocessed in batch: + transcription = ( + preprocessed.metadata.preprocessed_file_path.name.replace( + ".wav", "" + ) + ) + transcription = [(0.0, float(i), transcription)] + yield ASRResult( + input_ordering=i, transcription=transcription, score=1.0 + ) + i += 1 -def test_write_transcription(tmpdir: Path) -> None: - # Given - asr_result = ASRResult(transcription=[(0.0, 1.0, "text")], score=math.log(0.5)) - transcribed_filename = "0011someid" - artifacts_root = Path(tmpdir) - project = TEST_PROJECT - # When - write_transcription( - asr_result, transcribed_filename, artifacts_root=artifacts_root, project=project - ) - # Then - expected_artifact_dir = artifacts_root / project / "00" / "11" / "0011someid" - assert expected_artifact_dir.exists() - metadata_path = expected_artifact_dir / "metadata.json" - assert metadata_path.exists() - metadata = json.loads(metadata_path.read_text()) - assert metadata["transcription"] == "transcription.json" - transcription_path = expected_artifact_dir / metadata["transcription"] - assert transcription_path.exists() - transcription = Transcription.model_validate_json(transcription_path.read_text()) - expected_transcription = Transcription( - transcripts=[ - Transcript(text="text", timestamp=Timestamp(start_s=0.0, end_s=1.0)) - ], - confidence=0.5, - ) - assert transcription == expected_transcription + +class MockPostprocessor(Postprocessor): + @classmethod + def _from_config(cls, config: RegistrableConfig, **kwargs) -> Self: # noqa: ARG003 + return cls() + + def process( + self, + inputs: Iterable[ASRResult], + *args, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> Iterable[ASRResult]: + yield from inputs @pytest.mark.parametrize( - ("query", "expected_docs"), + ("query", "expected_batches"), [ # Supports empty query - ({}, ["doc-0", "doc-2"]), + ({}, [[("doc-0", Path("doc-0.wav")), ("doc-2", Path("doc-2.mp3"))]]), # Return all audio/video docs - (match_all(), ["doc-0", "doc-2"]), - (ids_query(["doc-0"]), ["doc-0"]), + (match_all(), [[("doc-0", Path("doc-0.wav")), ("doc-2", Path("doc-2.mp3"))]]), + (ids_query(["doc-0"]), [[("doc-0", Path("doc-0.wav"))]]), # Should filter non supported content type (ids_query(["doc-1"]), []), ], ) -async def test_search_audios( +async def test_search_audio_paths_act( populate_es_with_audio: list[Document], test_es_client: ESClient, query: dict, - expected_docs: list[str], + expected_batches: list[tuple[DocId, Path]], + tmpdir: Path, ) -> None: # Given - assert len(populate_es_with_audio) == 4 + tmpdir = Path(tmpdir) + batch_size = len(populate_es_with_audio) + assert batch_size == 4 client = test_es_client # When - results = search_audios( - es_client=client, - project=TEST_PROJECT, - query=query, - supported_content_types=SUPPORTED_CONTENT_TYPES, + batch_paths = [ + batch + async for batch in search_audio_paths_act( + es_client=client, + project=TEST_PROJECT, + query=query, + batch_size=batch_size, + output_dir=tmpdir, + ) + ] + # Then + results = [] + for b in batch_paths: + with b.open() as f: + results.append(list(read_batch(f))) + assert results == expected_batches + + +def test_preprocess_act(tmpdir: Path) -> None: + # Given + tmpdir = Path(tmpdir) + n_audios = 3 + batch_size = n_audios - 1 + output_dir = tmpdir.joinpath("artifacts") + output_dir.mkdir() + audio_root = tmpdir.joinpath("audio_root") + audio_batch = tmpdir / "audio_batch.txt" + batch = [(f"doc-{i}", Path(str(i))) for i in range(n_audios)] + with audio_batch.open("w") as f: + write_audio_batch(batch, f) + preprocessor = MockPeprocessor(batch_size=batch_size) + + # When + batch_files = preprocess_act( + preprocessor, + audio_batch=audio_batch, + audio_root=audio_root, + output_dir=output_dir, ) + # Then - docs_ids = [i async for i in results] - assert docs_ids == expected_docs + assert len(batch_files) == 2 + expected_batches = [ + [PREPROCESSED_INPUT_0, PREPROCESSED_INPUT_1], + [PREPROCESSED_INPUT_2], + ] + written_batches = [ + [PreprocessedInput.model_validate(d) for d in read_jsonl(output_dir / f)] + for f in batch_files + ] + assert written_batches == expected_batches + + +def test_infer_act(tmpdir: Path) -> None: + # Given + inference_runner = MockInferenceRunner() + workdir = Path(tmpdir) / "workdir" + workdir.mkdir() + output_dir = Path(tmpdir) + preprocessed_inputs = [ + PREPROCESSED_INPUT_0, + PREPROCESSED_INPUT_1, + PREPROCESSED_INPUT_2, + ] + paths = [] + for p_i, p in enumerate(preprocessed_inputs): + input_path = workdir / f"{p_i}.json" + input_path.write_text(p.model_dump_json()) + paths.append(input_path) + # When + asr_result_paths = infer_act( + inference_runner, preprocessed_inputs=paths, output_dir=output_dir + ) + # Then + asr_results = [ + ASRResult.model_validate_json((output_dir / p).read_text()) + for p in asr_result_paths + ] + assert asr_results == INFERENCE_RESULTS + + +def test_postprocess_act(tmpdir: Path) -> None: + # Given + postprocessor = MockPostprocessor() + project = TEST_PROJECT + artifacts_root = Path(tmpdir) + doc_ids = [f"{str(i) * 4}-doc-{i}" for i in range(3)] + # When + postprocess_act( + postprocessor, + INFERENCE_RESULTS, + doc_ids=doc_ids, + artifacts_root=artifacts_root, + project=project, + ) + # Then + expected_artifact_dirs = [ + artifacts_root / project / "00" / "00" / "0000-doc-0", + artifacts_root / project / "11" / "11" / "1111-doc-1", + artifacts_root / project / "22" / "22" / "2222-doc-2", + ] + for res, d in zip(INFERENCE_RESULTS, expected_artifact_dirs, strict=True): + assert d.exists() + metadata_path = d / "metadata.json" + assert metadata_path.exists() + metadata = json.loads(metadata_path.read_text()) + assert metadata["transcription"] == "transcription.json" + transcription_path = d / metadata["transcription"] + assert transcription_path.exists() + transcription = Transcription.model_validate_json( + transcription_path.read_text() + ) + expected_transcription = Transcription.from_asr_handler_result(res) + assert transcription == expected_transcription async def test_write_audio_search_results(tmpdir: Path) -> None: @@ -127,21 +273,27 @@ async def test_write_audio_search_results(tmpdir: Path) -> None: root = Path(tmpdir) batch_size = 2 - async def results() -> AsyncGenerator[str, None]: + async def results() -> AsyncGenerator[tuple[DocId, str], None]: res = ["doc-0", "doc-1", "doc-2"] for r in res: - yield r + yield r, f"{r}.wav" # When - batches = write_audio_search_results(results(), root=root, batch_size=batch_size) + results = write_audio_batches(results(), root=root, batch_size=batch_size) # Then async def expected_content() -> AsyncGenerator[str, None]: - contents = ["doc-0\ndoc-1\n", "doc-2\n"] - for e in contents: - yield e + contents = [ + [ + '{"doc_id": "doc-0", "path": "doc-0.wav"}', + '{"doc_id": "doc-1", "path": "doc-1.wav"}', + ], + ['{"doc_id": "doc-2", "path": "doc-2.wav"}'], + ] + for line in contents: + yield "\n".join(line) + "\n" - batches_and_expected_content = stream.zip(batches, expected_content()) + batches_and_expected_content = stream.zip(results, expected_content()) async with batches_and_expected_content.stream() as streamed: async for p, expected_content in streamed: diff --git a/asr-worker/tests/test_workflows.py b/asr-worker/tests/test_workflows.py index 28625b2b..55636f7c 100644 --- a/asr-worker/tests/test_workflows.py +++ b/asr-worker/tests/test_workflows.py @@ -5,35 +5,25 @@ from asyncio import AbstractEventLoop from collections.abc import AsyncGenerator from pathlib import Path -from typing import cast import pytest -from asr_worker.activities import ASRActivities, write_transcription +from asr_worker.activities import ASRActivities from asr_worker.config import ASRWorkerConfig -from asr_worker.constants import ( - POSTPROCESS_ACTIVITY, - PREPROCESS_ACTIVITY, - RUN_INFERENCE_ACTIVITY, -) -from asr_worker.dependencies import REGISTRY, lifespan_worker_config +from asr_worker.constants import SUPPORTED_CONTENT_TYPES +from asr_worker.dependencies import REGISTRY from asr_worker.models import ( ASRArgs, ASRPipelineConfig, - ASRResponse, + DocId, Timestamp, Transcript, Transcription, ) from asr_worker.workflows import ASRWorkflow, TaskQueues -from caul.config import InferenceRunnerConfig, PostprocessorConfig -from caul.objects import ASRResult, InputMetadata, PreprocessedInput -from datashare_python.config import WorkerConfig +from caul.objects import ASRResult from datashare_python.conftest import TEST_PROJECT -from datashare_python.types_ import ( - ProgressRateHandler, - TemporalClient, -) -from datashare_python.utils import ActivityWithProgress, activity_defn +from datashare_python.objects import Document +from datashare_python.types_ import TemporalClient from datashare_python.worker import worker_context from pydantic import TypeAdapter from temporalio.worker import Worker @@ -54,102 +44,6 @@ _TRANSCRIPTIONS = [Transcription.from_asr_handler_result(res) for res in _MODEL_RESULTS] -class MockedASRActivities(ActivityWithProgress): - @activity_defn(name=PREPROCESS_ACTIVITY) - def preprocess(self, paths: list[Path]) -> list[Path]: - # TODO: this shouldn't be necessary, fix this bug - paths = _LIST_OF_PATH_ADAPTER.validate_python(paths) - worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) - workdir = worker_config.workdir - workdir.mkdir(parents=True, exist_ok=True) - batches = [] - for path_i, path in enumerate(paths): - n_segments = len(path.name) - for part_i in range(n_segments): - seg_path = f"file_{path_i}_part_{part_i}.wav" - metadata = InputMetadata( - duration_s=1.0, - input_ordering=path_i, - preprocessed_file_path=Path(seg_path), - ) - preprocessed_input = PreprocessedInput(metadata=metadata) - (workdir / seg_path).write_text(preprocessed_input.model_dump_json()) - batches.append(seg_path) - return batches - - @activity_defn(name=RUN_INFERENCE_ACTIVITY) - def infer( - self, - preprocessed_inputs: list[Path], - config: InferenceRunnerConfig, # noqa: ARG002 - *, - progress: ProgressRateHandler | None = None, # noqa: ARG002 - ) -> list[Path]: # noqa: ANN001, ARG001 - # TODO: this shouldn't be necessary, fix this bug - preprocessed_inputs = _LIST_OF_PATH_ADAPTER.validate_python(preprocessed_inputs) - worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) - workdir = worker_config.workdir - paths = [] - preprocessed_inputs = [ - PreprocessedInput.model_validate_json((workdir / p).read_text()) - for p in preprocessed_inputs - ] - for preprocessed_i, i in enumerate(preprocessed_inputs): - res = _MODEL_RESULTS[preprocessed_i % len(_MODEL_RESULTS)] - res = ASRResult( - input_ordering=i.metadata.input_ordering, - transcription=res.transcription, - score=res.score, - ) - filename = f"{uuid.uuid4().hex[:20]}-transcript.json" - (workdir / filename).write_text(res.model_dump_json()) - paths.append(filename) - return paths - - @activity_defn(name=POSTPROCESS_ACTIVITY) - def postprocess( - self, - inference_results: list[Path], - input_paths: list[Path], - config: PostprocessorConfig, # noqa: ARG002 - project: str, - *, - progress: ProgressRateHandler | None = None, # noqa: ARG002 - ) -> None: - # TODO: this shouldn't be necessary, fix this bug - inference_results = _LIST_OF_PATH_ADAPTER.validate_python(inference_results) - input_paths = _LIST_OF_PATH_ADAPTER.validate_python(input_paths) - worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) - workdir = worker_config.workdir - artifact_root = worker_config.artifacts_root - artifact_root.mkdir(parents=True, exist_ok=True) - inference_results = [ - ASRResult.model_validate_json((workdir / f).read_text()) - for f in inference_results - ] - current_res = None - asr_results, transcription, scores = [], [], [] - for res in inference_results: - if res.input_ordering != current_res and current_res is not None: - score = (sum(scores) / len(scores)) if scores else 0 - asr_results.append( - ASRResult(transcription=sum(transcription, []), score=score) - ) - asr_results, transcription, scores = [], [], [] - current_res = res.input_ordering - transcription.append(res.transcription) - scores.append(res.score) - asr_results.append( - ASRResult( - transcription=sum(transcription, []), score=sum(scores) / len(scores) - ) - ) - for original, asr_result in zip(input_paths, asr_results, strict=True): - write_transcription( - asr_result, original.name, artifacts_root=artifact_root, project=project - ) - - @pytest.fixture async def io_bound_worker( test_temporal_client_session: TemporalClient, @@ -160,6 +54,7 @@ async def io_bound_worker( worker_id = f"worker-{uuid.uuid4()}" task_queue = TaskQueues.IO dependencies = REGISTRY["io"] + activities = ASRActivities(client, event_loop) worker_ctx = worker_context( worker_id, worker_config=test_worker_config, @@ -167,54 +62,7 @@ async def io_bound_worker( event_loop=event_loop, task_queue=task_queue, workflows=[ASRWorkflow], - dependencies=dependencies, - ) - async with worker_ctx: - yield - - -@pytest.fixture -async def mock_cpu_bound_worker( - test_temporal_client_session: TemporalClient, - test_worker_config: WorkerConfig, - event_loop: AbstractEventLoop, # noqa: F811 -) -> AsyncGenerator[None, None]: - client = test_temporal_client_session - activities = MockedASRActivities(client, event_loop) - worker_id = f"worker-{uuid.uuid4()}" - task_queue = TaskQueues.CPU - dependencies = REGISTRY["preprocessing"] - worker_ctx = worker_context( - worker_id, - worker_config=test_worker_config, - client=client, - event_loop=event_loop, - task_queue=task_queue, - activities=[activities.preprocess, activities.postprocess], - dependencies=dependencies, - ) - async with worker_ctx: - yield - - -@pytest.fixture -async def mock_gpu_inference_worker( - test_temporal_client_session: TemporalClient, - test_worker_config: WorkerConfig, - event_loop: AbstractEventLoop, # noqa: F811 -) -> AsyncGenerator[None, None]: - client = test_temporal_client_session - activities = MockedASRActivities(client, event_loop) - task_queue = TaskQueues.INFERENCE_GPU - worker_id = f"worker-{uuid.uuid4()}" - dependencies = REGISTRY["inference"] - worker_ctx = worker_context( - worker_id, - worker_config=test_worker_config, - client=client, - event_loop=event_loop, - task_queue=task_queue, - activities=[activities.infer], + activities=[activities.search_audio_paths], dependencies=dependencies, ) async with worker_ctx: @@ -287,63 +135,20 @@ async def gpu_inference_worker( ) -async def test_asr_workflow( - test_temporal_client_session: TemporalClient, - mock_cpu_bound_worker: Worker, # noqa: ARG001 - mock_gpu_inference_worker: Worker, # noqa: ARG001 - io_bound_worker: Worker, # noqa: ARG001 - test_worker_config: ASRWorkerConfig, -) -> None: - # Given - worker_config = test_worker_config - client = test_temporal_client_session - path = [Path("aabb"), Path("cc")] - batch_size = 1 - config = ASRPipelineConfig() - workflow_id = f"asr-{uuid.uuid4().hex}" - project = TEST_PROJECT - inputs = ASRArgs(project=project, docs=path, config=config, batch_size=batch_size) - # When - result = await client.execute_workflow( - ASRWorkflow.run, inputs, id=workflow_id, task_queue=TaskQueues.IO - ) - # Then - expected_response = ASRResponse(n_transcribed=2) - assert result == expected_response - artifacts_root = worker_config.artifacts_root - expected_transcription_dirs = [ - artifacts_root / project / "aa" / "bb" / "aabb", - artifacts_root / project / "cc" / "cc", - ] - expected_transcriptions = [_EXPECTED_TRANSCRIPTION_0, _EXPECTED_TRANSCRIPTION_1] - for expected_t, d in zip( - expected_transcriptions, expected_transcription_dirs, strict=True - ): - assert d.exists() - assert d.is_dir() - meta_path = d / "metadata.json" - assert meta_path.exists() - meta = json.loads(meta_path.read_text()) - transcription_name = meta.get("transcription") - assert transcription_name is not None - transcription_path = d / transcription_name - assert transcription_path.exists() - transcription = Transcription.model_validate_json( - transcription_path.read_text() - ) - assert transcription == expected_t - - @pytest.fixture -def with_audios(test_worker_config: ASRWorkerConfig) -> list[Path]: +def with_audio_docs( + populate_es_with_audio: list[Document], test_worker_config: ASRWorkerConfig +) -> list[tuple[DocId, Path]]: config = test_worker_config - audios = [f for f in AUDIOS_PATH.iterdir() if f.suffix == ".wav"] + docs = [ + d for d in populate_es_with_audio if d.content_type in SUPPORTED_CONTENT_TYPES + ] paths = [] config.audios_root.mkdir(parents=True, exist_ok=True) - for audio in audios: - rel_path = audio.relative_to(AUDIOS_PATH) - shutil.copy(audio, config.audios_root / rel_path) - paths.append(rel_path) + audio_path = AUDIOS_PATH / "asr_test.wav" + for doc in docs: + shutil.copy(audio_path, config.audios_root / doc.path) + paths.append((doc.id, doc.path)) return paths @@ -354,17 +159,19 @@ async def test_asr_workflow_e2e( gpu_inference_worker: Worker, # noqa: ARG001 io_bound_worker: Worker, # noqa: ARG001 test_worker_config: ASRWorkerConfig, - with_audios: list[Path], + with_audio_docs: list[tuple[DocId, Path]], # noqa: ARG001 ) -> None: # Given config = test_worker_config + artifacts_root = config.artifacts_root client = test_temporal_client_session - n_audios = 3 + n_audios = len(with_audio_docs) batch_size = n_audios - 1 - audios = with_audios * n_audios project = TEST_PROJECT + doc_ids, _ = zip(*with_audio_docs, strict=True) + doc_ids = list(doc_ids) args = ASRArgs( - project=project, docs=audios, config=ASRPipelineConfig(), batch_size=batch_size + project=project, docs=doc_ids, config=ASRPipelineConfig(), batch_size=batch_size ) workflow_id = f"asr-{uuid.uuid4().hex}" @@ -375,26 +182,30 @@ async def test_asr_workflow_e2e( # Then assert response.n_transcribed == n_audios - expected_transcription_path = ( - config.artifacts_root / project / "as" / "r_" / "asr_test.wav" - ) - assert expected_transcription_path.exists() - assert expected_transcription_path.is_dir() - meta_path = expected_transcription_path / "metadata.json" - assert meta_path.exists() - meta = json.loads(meta_path.read_text()) - transcription_name = meta.get("transcription") - assert transcription_name is not None - transcription_path = expected_transcription_path / transcription_name - assert transcription_path.exists() - transcription = Transcription.model_validate_json(transcription_path.read_text()) - expcted_transcription = Transcription( - transcripts=[ - Transcript( - text="To embrace the chaos that they fought in this battle.", - timestamp=Timestamp.from_floats(0.08, 2.56), - ) - ], - confidence=math.exp(-248.3), - ) - assert transcription == expcted_transcription + expected_artifact_dirs = [ + artifacts_root / project / "do" / "c-" / "doc-0", + artifacts_root / project / "do" / "c-" / "doc-2", + ] + for d in expected_artifact_dirs: + assert d.exists() + assert d.is_dir() + meta_path = d / "metadata.json" + assert meta_path.exists() + meta = json.loads(meta_path.read_text()) + transcription_name = meta.get("transcription") + assert transcription_name is not None + transcription_path = d / transcription_name + assert transcription_path.exists() + transcription = Transcription.model_validate_json( + transcription_path.read_text() + ) + expcted_transcription = Transcription( + transcripts=[ + Transcript( + text="To embrace the chaos that they fought in this battle.", + timestamp=Timestamp.from_floats(0.08, 2.56), + ) + ], + confidence=math.exp(-248.3), + ) + assert transcription == expcted_transcription diff --git a/datashare-python/datashare_python/conftest.py b/datashare-python/datashare_python/conftest.py index c95b99a3..2f8062f5 100644 --- a/datashare-python/datashare_python/conftest.py +++ b/datashare-python/datashare_python/conftest.py @@ -200,6 +200,8 @@ def index_docs_ops( } doc = doc.model_dump(by_alias=True) # noqa: PLW2901 op.update(doc) + if "path" in op: + op["path"] = str(op["path"]) op["_id"] = doc[ID] op["routing"] = doc[DOC_ROOT_ID] op["type"] = ES_DOCUMENT_TYPE diff --git a/datashare-python/datashare_python/objects.py b/datashare-python/datashare_python/objects.py index e8a1373c..05a8af08 100644 --- a/datashare-python/datashare_python/objects.py +++ b/datashare-python/datashare_python/objects.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from enum import StrEnum, unique +from pathlib import Path from typing import Any, Literal, Self, TypeVar from temporalio import workflow @@ -142,6 +143,7 @@ class Document(DatashareModel): language: str content: str | None = None content_type: str | None = None + path: Path | None = None tags: list[str] = Field(default_factory=list) content_translated: dict[str, str] = Field( default_factory=dict, alias="content_translated" diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index 390aa3a6..6ad646c1 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -13,6 +13,7 @@ from functools import partial, wraps from hashlib import sha256 from inspect import signature +from io import BytesIO from pathlib import Path from typing import Any, ParamSpec, TypeVar from uuid import uuid4 @@ -76,6 +77,15 @@ def to_progress(self) -> Progress: return Progress(current=self.progress * self.weight, max_progress=self.weight) +@dataclass(frozen=True) +class DocArtifact: + project: str + doc_id: str + artifact: bytes | BytesIO + filename: str + metadata_key: str + + class ActivityWithProgress: def __init__(self, temporal_client: Client, event_loop: asyncio.AbstractEventLoop): self._temporal_client = temporal_client @@ -431,49 +441,43 @@ def _handlers( return handlers -def safe_dir(filename: str) -> Path: - filename = filename.split(".", maxsplit=1)[0] - parts = (p for p in (filename[:2], filename[2:4]) if p) +def safe_dir(doc_id: str) -> Path: + if len(doc_id) < 4: + raise ValueError(f"expected doc_id to be at least 4, found {doc_id}") + parts = (p for p in (doc_id[:2], doc_id[2:4]) if p) return Path(*parts) -def artifacts_dir(project: str, *, filename: str) -> Path: - return Path(project, safe_dir(filename), filename) +def _artifacts_dir(doc_id: str, *, project: str) -> Path: + return Path(project, safe_dir(doc_id), doc_id) -def metadata_path(filename: str, *, project: str) -> Path: - metadata_path = artifacts_dir(project, filename=filename) / METADATA_JSON +def _metadata_path(doc_id: str, *, project: str) -> Path: + metadata_path = _artifacts_dir(doc_id, project=project) / METADATA_JSON return metadata_path -def _read_artifact_metadata(root: Path, project: str, *, filename: str) -> dict: - m_path = root / metadata_path(filename, project=project) +def _read_artifact_metadata(root: Path, artifact: DocArtifact) -> dict: + m_path = root / _metadata_path(artifact.filename, project=artifact.project) return json.loads(m_path.read_text()) -def write_artifact( - artifact: bytes, - root: Path, - *, - project: str, - filename: str, - metadata_key: str, - metadata_value: str, -) -> Path: - artif_dir = root / artifacts_dir(project, filename=filename) +def write_artifact(root: Path, artifact: DocArtifact) -> Path: + artif_dir = root / _artifacts_dir(artifact.doc_id, project=artifact.project) artif_dir.mkdir(exist_ok=True, parents=True) # TODO: if transcriptions are too large we could also serialize them # as jsonl - transcription_path = artif_dir / metadata_value - transcription_path.write_bytes(artifact) - try: - meta = _read_artifact_metadata(root, project, filename=filename) - except FileNotFoundError: - meta = dict() - meta[metadata_key] = metadata_value - meta_path = root / artifacts_dir(project, filename=filename) / METADATA_JSON + artifact_path: Path = artif_dir / artifact.filename + if isinstance(artifact.artifact, bytes): + artifact_path.write_bytes(artifact.artifact) + elif isinstance(artifact_path, BytesIO): + with artifact_path.open("wb") as f: + f.write(artifact.artifact.read()) + meta_path = root / _metadata_path(artifact.doc_id, project=artifact.project) + meta = _read_artifact_metadata(root, artifact) if meta_path.exists() else dict() + meta[artifact.metadata_key] = artifact.filename meta_path.write_text(json.dumps(meta)) - return transcription_path.relative_to(artif_dir) + return artifact_path.relative_to(artif_dir) def debuggable_name( diff --git a/worker-template/worker_template/classify.py b/worker-template/worker_template/classify.py index f6b11780..bda643e8 100644 --- a/worker-template/worker_template/classify.py +++ b/worker-template/worker_template/classify.py @@ -7,7 +7,6 @@ from datashare_python.utils import ( ActivityWithProgress, activity_defn, - batches, to_raw_progress, to_scaled_progress, ) @@ -30,6 +29,7 @@ bulk_action, has_id, ) +from icij_common.iter_utils import batches from temporalio import activity from temporalio.client import Client from transformers import Pipeline, pipeline diff --git a/worker-template/worker_template/translate.py b/worker-template/worker_template/translate.py index 19b03290..bcf9b859 100644 --- a/worker-template/worker_template/translate.py +++ b/worker-template/worker_template/translate.py @@ -5,15 +5,7 @@ from aiostream.stream import chain from datashare_python.objects import Document from datashare_python.types_ import ProgressRateHandler -from datashare_python.utils import ( - ActivityWithProgress, - activity_defn, - async_batches, - batches, - before_and_after, - once, - to_raw_progress, -) +from datashare_python.utils import ActivityWithProgress, activity_defn, to_raw_progress from elasticsearch._async.helpers import async_bulk from icij_common.es import ( BOOL, @@ -30,6 +22,7 @@ has_id, must_not, ) +from icij_common.iter_utils import async_batches, batches, before_and_after, once from temporalio.client import Client from transformers import Pipeline, pipeline