From 078f655bba3a2ad72e78f57d3fa56cdafb02394f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 13 May 2026 11:16:25 +0200 Subject: [PATCH] feature(translator-worker): initialize translation pipeline from config --- datashare-python/datashare_python/conftest.py | 6 +- datashare-python/datashare_python/objects.py | 266 +++-- datashare-python/datashare_python/utils.py | 1 - datashare-python/pyproject.toml | 1 + .../tests/{test_object.py => test_objects.py} | 29 +- datashare-python/uv.lock | 20 + worker-template/uv.dist.lock | 6 +- worker-template/uv.lock | 20 + workers/asr-worker/uv.lock | 2 + workers/translation-worker/pyproject.toml | 1 + workers/translation-worker/tests/conftest.py | 20 +- .../tests/test_activities.py | 1044 +++++++---------- .../translation-worker/tests/test_objects.py | 22 + ...{test_translation.py => test_workflows.py} | 10 +- .../translation_worker/__init__.py | 8 + .../translation_worker/activities.py | 551 +++++---- .../translation_worker/config.py | 20 +- .../translation_worker/constants.py | 5 +- .../translation_worker/core.py | 253 ---- .../translation_worker/objects.py | 152 ++- .../translation_worker/processors.py | 90 ++ .../translation_worker/sentence_splitters.py | 42 + .../translators/__init__.py | 4 + .../translation_worker/translators/argos.py | 194 +++ .../translation_worker/workflows.py | 22 +- workers/translation-worker/uv.dist.lock | 17 +- workers/translation-worker/uv.lock | 13 + 27 files changed, 1467 insertions(+), 1352 deletions(-) rename datashare-python/tests/{test_object.py => test_objects.py} (70%) create mode 100644 workers/translation-worker/tests/test_objects.py rename workers/translation-worker/tests/{test_translation.py => test_workflows.py} (83%) delete mode 100644 workers/translation-worker/translation_worker/core.py create mode 100644 workers/translation-worker/translation_worker/processors.py create mode 100644 workers/translation-worker/translation_worker/sentence_splitters.py create mode 100644 workers/translation-worker/translation_worker/translators/__init__.py create mode 100644 workers/translation-worker/translation_worker/translators/argos.py diff --git a/datashare-python/datashare_python/conftest.py b/datashare-python/datashare_python/conftest.py index 33d74240..0d1d5555 100644 --- a/datashare-python/datashare_python/conftest.py +++ b/datashare-python/datashare_python/conftest.py @@ -53,7 +53,6 @@ "join": {"type": "join", "relations": {"Document": "NamedEntity"}}, "contentType": {"type": "keyword"}, "content": {"type": "text"}, - "contentTranslated": {"type": "text"}, } } } @@ -197,10 +196,7 @@ def index_docs_ops( docs: list[Document], index_name: str ) -> Generator[dict, None, None]: for doc in docs: - op = { - "_op_type": "index", - "_index": index_name, - } + op = {"_op_type": "index", "_index": index_name} doc = doc.model_dump(by_alias=True) # noqa: PLW2901 op.update(doc) if "path" in op: diff --git a/datashare-python/datashare_python/objects.py b/datashare-python/datashare_python/objects.py index e488f7a4..839aee12 100644 --- a/datashare-python/datashare_python/objects.py +++ b/datashare-python/datashare_python/objects.py @@ -6,8 +6,10 @@ from enum import StrEnum, unique from io import BytesIO from pathlib import Path -from typing import Annotated, Any, Literal, Self, TypeVar, cast +from typing import Annotated, Any, ClassVar, Literal, Self, TypeVar, cast +from pydantic_core import PydanticCustomError, ValidationError, core_schema +from pydantic_extra_types.language_code import LanguageName from temporalio import workflow from .constants import TIKA_METADATA_RESOURCENAME @@ -31,7 +33,13 @@ merge_configs, no_enum_values_config, ) -from pydantic import AfterValidator, Field +from pydantic import ( + AfterValidator, + Field, + GetCoreSchemaHandler, + TypeAdapter, + model_validator, +) from pydantic import BaseModel as _BaseModel from pydantic.main import IncEx @@ -50,110 +58,45 @@ class DatashareModel(BaseModel): model_config = merge_configs(BaseModel.model_config, lowercamel_case_config()) -@unique -class TaskState(StrEnum): - CREATED = "CREATED" - QUEUED = "QUEUED" - RUNNING = "RUNNING" - ERROR = "ERROR" - DONE = "DONE" - CANCELLED = "CANCELLED" - - -READY_STATES = frozenset({TaskState.DONE, TaskState.ERROR, TaskState.CANCELLED}) - +class DatashareLanguage(str): + _language_type_adapter: ClassVar[TypeAdapter] = TypeAdapter(LanguageName) -class StacktraceItem(DatashareModel): - name: str - file: str - lineno: int - - -class Message(DatashareModel): - type: str = Field(frozen=True, alias="@type") + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> Self: + if __input_value != __input_value.upper(): + raise PydanticCustomError( + "datashare_language", "Invalid Datashare language, expected uppercase" + ) + try: + # Use pydantic provided validation + cls._language_type_adapter.validate_python(__input_value.title()) + except ValidationError as e: + raise PydanticCustomError( + "datashare_language", "Unknown Datashare language" + ) from e + return cls(__input_value) - def model_dump( - self, - *, - mode: Literal["json", "python"] | str = "python", - include: IncEx | None = None, - exclude: IncEx | None = None, - context: Any | None = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: bool | Literal["none", "warn", "error"] = True, - fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - ) -> dict[str, Any]: - return super().model_dump( - by_alias=True, - mode=mode, - include=include, - exclude=exclude, - context=context, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - round_trip=round_trip, - warnings=warnings, - fallback=fallback, - serialize_as_any=serialize_as_any, + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(), + serialization=core_schema.to_string_ser_schema(), ) + @property + def as_language_name(self) -> LanguageName: + return LanguageName(self.title()) -class TaskResult(Message): - type: str = Field(frozen=True, alias="@type", default="TaskResult") - value: object - - -class TaskError(Message): - type: str = Field(frozen=True, alias="@type", default="TaskError") - name: str - message: str - cause: str | None = None - stacktrace: list[StacktraceItem] = Field(default_factory=list) - - -def _datetime_now() -> datetime: - return datetime.now(UTC) - - -class User(Message): - type: str = Field( - frozen=True, alias="@type", default="org.icij.datashare.user.User" - ) - id: str - name: str | None = None - email: str | None = None - provider: str | None = None - details: dict = dict() - - -class Task(Message): - type: str = Field(frozen=True, alias="@type", default="Task") - id: str - name: str - args: dict[str, object] | None = None - state: TaskState = TaskState.CREATED - result: TaskResult | None = None - error: TaskError | None = None - progress: float | None = None - created_at: datetime = Field(default_factory=_datetime_now) - completed_at: datetime | None = None - retries_left: int | None = None - max_retries: int | None = None - - -@dataclass(frozen=True) -class TaskGroup: - name: str + @property + def alpha2(self) -> str | None: + return self.as_language_name.alpha2 @property - @classmethod - def python(cls) -> Self: - return cls(name="PYTHON") + def alpha3(self) -> str: + return self.as_language_name.alpha3 @unique @@ -197,19 +140,29 @@ def locate( class Document(DatashareModel): id: str - language: str + language: DatashareLanguage index: str | None = None root_document: str | None = None content: str | None = None + content_text_length: int | None = None content_type: str | None = None path: Path | None = None tags: list[str] = Field(default_factory=list) - content_translated: dict[str, str] = Field( - default_factory=dict, alias="content_translated" + content_translated: list[dict[str, str]] = Field( + default_factory=list, alias="content_translated" ) metadata: dict[str, Any] | None = None type: str = Field(default="Document", frozen=True) + @model_validator(mode="before") + @classmethod + def _initialize_content_length_from_content(cls, data: Any) -> Any: + if isinstance(data, dict): + content_length = data.get("content_text_length") + if content_length is None and (content := data.get(DOC_CONTENT)): + data["content_text_length"] = len(content) + return data + @classmethod def from_es(cls, es_doc: dict) -> Self: sources = es_doc[SOURCE] @@ -217,7 +170,8 @@ def from_es(cls, es_doc: dict) -> Self: id=es_doc[ID_], index=es_doc.get(INDEX_), content=sources.get(DOC_CONTENT), - content_translated=sources.get(DOC_CONTENT_TRANSLATED, dict()), + content_translated=sources.get(DOC_CONTENT_TRANSLATED, []), + content_text_length=sources.get("content_text_length"), language=sources[DOC_LANGUAGE], root_document=sources.get(DOC_ROOT_ID), tags=sources.get("tags", []), @@ -265,3 +219,109 @@ class DocArtifact: artifact: bytes | BytesIO filename: str metadata_key: str + + +@unique +class TaskState(StrEnum): + CREATED = "CREATED" + QUEUED = "QUEUED" + RUNNING = "RUNNING" + ERROR = "ERROR" + DONE = "DONE" + CANCELLED = "CANCELLED" + + +READY_STATES = frozenset({TaskState.DONE, TaskState.ERROR, TaskState.CANCELLED}) + + +class StacktraceItem(DatashareModel): + name: str + file: str + lineno: int + + +class Message(DatashareModel): + type: str = Field(frozen=True, alias="@type") + + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: IncEx | None = None, + exclude: IncEx | None = None, + context: Any | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, + ) -> dict[str, Any]: + return super().model_dump( + by_alias=True, + mode=mode, + include=include, + exclude=exclude, + context=context, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + fallback=fallback, + serialize_as_any=serialize_as_any, + ) + + +class TaskResult(Message): + type: str = Field(frozen=True, alias="@type", default="TaskResult") + value: object + + +class TaskError(Message): + type: str = Field(frozen=True, alias="@type", default="TaskError") + name: str + message: str + cause: str | None = None + stacktrace: list[StacktraceItem] = Field(default_factory=list) + + +def _datetime_now() -> datetime: + return datetime.now(UTC) + + +class User(Message): + type: str = Field( + frozen=True, alias="@type", default="org.icij.datashare.user.User" + ) + id: str + name: str | None = None + email: str | None = None + provider: str | None = None + details: dict = dict() + + +class Task(Message): + type: str = Field(frozen=True, alias="@type", default="Task") + id: str + name: str + args: dict[str, object] | None = None + state: TaskState = TaskState.CREATED + result: TaskResult | None = None + error: TaskError | None = None + progress: float | None = None + created_at: datetime = Field(default_factory=_datetime_now) + completed_at: datetime | None = None + retries_left: int | None = None + max_retries: int | None = None + + +@dataclass(frozen=True) +class TaskGroup: + name: str + + @property + @classmethod + def python(cls) -> Self: + return cls(name="PYTHON") diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index 6de9f2ec..87ff36a7 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -83,7 +83,6 @@ def __init__(self): async def update_progress(self, signal: ProgressSignal) -> None: async with self._update_lock: # TODO: remove this log - workflow.logger.debug("recording progress signal %s", signal) key = (signal.run_id, signal.activity_id) self._progress[key] = signal.to_progress() progress = sum(p.current for p in self._progress.values()) diff --git a/datashare-python/pyproject.toml b/datashare-python/pyproject.toml index a33cfab0..eef8e0bd 100644 --- a/datashare-python/pyproject.toml +++ b/datashare-python/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "hatchling~=1.27", "pyyaml~=6.0", "orjson~=3.11", + "pydantic-extra-types[pycountry]>=2.11.1", ] [project.urls] diff --git a/datashare-python/tests/test_object.py b/datashare-python/tests/test_objects.py similarity index 70% rename from datashare-python/tests/test_object.py rename to datashare-python/tests/test_objects.py index cc63c522..c42f8322 100644 --- a/datashare-python/tests/test_object.py +++ b/datashare-python/tests/test_objects.py @@ -6,13 +6,14 @@ from datashare_python.conftest import TEST_PROJECT from datashare_python.constants import TIKA_METADATA_RESOURCENAME from datashare_python.objects import ( + DatashareLanguage, Document, DocumentLocation, FilesystemDocument, Task, TaskState, ) -from pydantic import ValidationError +from pydantic import TypeAdapter, ValidationError def test_task_ser() -> None: @@ -67,3 +68,29 @@ def test_document_to_filesystem_document_use_relative_path() -> None: fs_doc = doc.to_filesystem() relative_path = Path("some/absolute/path/resource.file") assert fs_doc.path == relative_path + + +def test_datashare_language() -> None: + # Given + language = "ENGLISH" + type_adapter = TypeAdapter(DatashareLanguage) + # When + ds_language = type_adapter.validate_python(language) + # Then + assert isinstance(ds_language, DatashareLanguage) + assert ds_language == language + + +@pytest.mark.parametrize( + ("language", "expected_msg"), + [("English", "expected uppercase"), ("AAAA", "Unknown")], +) +def test_invalid_datashare_language_should_raise( + language: str, expected_msg: str +) -> None: + # Given + type_adapter = TypeAdapter(DatashareLanguage) + + # When/Then + with pytest.raises(ValidationError, match=expected_msg): + type_adapter.validate_python(language) diff --git a/datashare-python/uv.lock b/datashare-python/uv.lock index a757f322..46fafb03 100644 --- a/datashare-python/uv.lock +++ b/datashare-python/uv.lock @@ -449,6 +449,7 @@ dependencies = [ { name = "icij-common", extra = ["elasticsearch"] }, { name = "nest-asyncio" }, { name = "orjson" }, + { name = "pydantic-extra-types", extra = ["pycountry"] }, { name = "python-json-logger" }, { name = "pyyaml" }, { name = "temporalio" }, @@ -479,6 +480,7 @@ requires-dist = [ { name = "icij-common", extras = ["elasticsearch"], specifier = "~=0.8.2" }, { name = "nest-asyncio", specifier = "~=1.6" }, { name = "orjson", specifier = "~=3.11" }, + { name = "pydantic-extra-types", extras = ["pycountry"], specifier = ">=2.11.1" }, { name = "python-json-logger", specifier = "~=4.0" }, { name = "pyyaml", specifier = "~=6.0" }, { name = "temporalio", specifier = "~=1.23" }, @@ -2051,6 +2053,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4b/2d/69abac8f838090bbecd5df894befb2c2619e7996a98ddb949db9f3b93225/pydantic_core-2.46.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:d51026d73fcfd93610abc7b27789c26b313920fcfb20e27462d74a7f8b06e983", size = 2193071, upload-time = "2026-05-06T13:38:08.682Z" }, ] +[[package]] +name = "pydantic-extra-types" +version = "2.11.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, +] + +[package.optional-dependencies] +pycountry = [ + { name = "pycountry" }, +] + [[package]] name = "pydantic-settings" version = "2.14.1" diff --git a/worker-template/uv.dist.lock b/worker-template/uv.dist.lock index a09ba3ad..be0ed201 100644 --- a/worker-template/uv.dist.lock +++ b/worker-template/uv.dist.lock @@ -316,7 +316,7 @@ wheels = [ [[package]] name = "datashare-python" -version = "0.8.18" +version = "0.8.19" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -331,9 +331,9 @@ dependencies = [ { name = "tomlkit" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3c/42/f086eb8d611d42c5162befd36515f987541f7e103bff8e7e490cedb2f86a/datashare_python-0.8.18.tar.gz", hash = "sha256:052d85140977176268190b20c7468925c7fabb078ce3421266b3b682141fb202", size = 315672, upload-time = "2026-05-22T11:06:57.677Z" } +sdist = { url = "https://files.pythonhosted.org/packages/88/59/23fb6fb2d40a70a83fb7793e1541f34c6a40ad5769751552f4d768b15c34/datashare_python-0.8.19.tar.gz", hash = "sha256:7ec122672d9fd9ae4191ca1e26a3d0213d350514a2384ac20dbfaa98b371680d", size = 315621, upload-time = "2026-05-22T13:56:23.953Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/75/8d96d987f35aecc2177972a8e605a65cc25adcb8b4c949a7c3ab4600d4ae/datashare_python-0.8.18-py3-none-any.whl", hash = "sha256:29eed8086677e2672364a9191f4cfc30d710d66427404e56a9513278def3acb1", size = 321671, upload-time = "2026-05-22T11:06:56.118Z" }, + { url = "https://files.pythonhosted.org/packages/8e/c8/67568e37bbc7bc87aa026599bab523466f3838096b0d5e567e48a99be59a/datashare_python-0.8.19-py3-none-any.whl", hash = "sha256:858ea1d11321b67e73998c03b0aa6dd8b83f433dac4f2b67d60f6b3a0dad66f0", size = 321655, upload-time = "2026-05-22T13:56:22.706Z" }, ] [[package]] diff --git a/worker-template/uv.lock b/worker-template/uv.lock index b5cce445..e82f1d1b 100644 --- a/worker-template/uv.lock +++ b/worker-template/uv.lock @@ -325,6 +325,7 @@ dependencies = [ { name = "icij-common", extra = ["elasticsearch"] }, { name = "nest-asyncio" }, { name = "orjson" }, + { name = "pydantic-extra-types", extra = ["pycountry"] }, { name = "python-json-logger" }, { name = "pyyaml" }, { name = "temporalio" }, @@ -340,6 +341,7 @@ requires-dist = [ { name = "icij-common", extras = ["elasticsearch"], specifier = "~=0.8.2" }, { name = "nest-asyncio", specifier = "~=1.6" }, { name = "orjson", specifier = "~=3.11" }, + { name = "pydantic-extra-types", extras = ["pycountry"], specifier = ">=2.11.1" }, { name = "python-json-logger", specifier = "~=4.0" }, { name = "pyyaml", specifier = "~=6.0" }, { name = "temporalio", specifier = "~=1.23" }, @@ -1498,6 +1500,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/56/8a7ca5d2cd2cda1d245d34b1c9a942920a718082ae8e54e5f3e5a58b7add/pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1", size = 2066757, upload-time = "2025-04-23T18:33:30.645Z" }, ] +[[package]] +name = "pydantic-extra-types" +version = "2.11.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, +] + +[package.optional-dependencies] +pycountry = [ + { name = "pycountry" }, +] + [[package]] name = "pydantic-settings" version = "2.13.1" diff --git a/workers/asr-worker/uv.lock b/workers/asr-worker/uv.lock index d8947e23..992b139c 100644 --- a/workers/asr-worker/uv.lock +++ b/workers/asr-worker/uv.lock @@ -572,6 +572,7 @@ dependencies = [ { name = "icij-common", extra = ["elasticsearch"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, { name = "nest-asyncio", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, { name = "orjson", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, + { name = "pydantic-extra-types", extra = ["pycountry"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, { name = "python-json-logger", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, { name = "pyyaml", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, { name = "temporalio", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-20-datashare-asr-worker-cpu' and extra == 'extra-20-datashare-asr-worker-gpu')" }, @@ -587,6 +588,7 @@ requires-dist = [ { name = "icij-common", extras = ["elasticsearch"], specifier = "~=0.8.2" }, { name = "nest-asyncio", specifier = "~=1.6" }, { name = "orjson", specifier = "~=3.11" }, + { name = "pydantic-extra-types", extras = ["pycountry"], specifier = ">=2.11.1" }, { name = "python-json-logger", specifier = "~=4.0" }, { name = "pyyaml", specifier = "~=6.0" }, { name = "temporalio", specifier = "~=1.23" }, diff --git a/workers/translation-worker/pyproject.toml b/workers/translation-worker/pyproject.toml index 78aa9c29..e89212e8 100644 --- a/workers/translation-worker/pyproject.toml +++ b/workers/translation-worker/pyproject.toml @@ -12,6 +12,7 @@ requires-python = ">=3.11.0, <3.13" dependencies = [ "datashare-python~=0.8.6", + "langcodes==3.5.1", "pydantic-extra-types[pycountry]==2.11.1", ] diff --git a/workers/translation-worker/tests/conftest.py b/workers/translation-worker/tests/conftest.py index 7b752d7e..a7fa9216 100644 --- a/workers/translation-worker/tests/conftest.py +++ b/workers/translation-worker/tests/conftest.py @@ -24,7 +24,7 @@ test_worker_config, worker_lifetime_deps, ) -from datashare_python.objects import Document +from datashare_python.objects import DatashareLanguage, Document from datashare_python.types_ import ContextManagerFactory, TemporalClient from datashare_python.worker import worker_context from icij_common.es import ESClient @@ -64,21 +64,17 @@ def test_worker_config(tmp_path_factory: TempPathFactory) -> TranslationWorkerCo ) -EN = "en" -FR = "fr" -ES = "es" -ENGLISH = "english" -FRENCH = "french" -SPANISH = "spanish" +ENGLISH = "ENGLISH" +FRENCH = "FRENCH" +SPANISH = "SPANISH" DOC_ID_1 = "doc_id_1" DOC_ID_2 = "doc_id_2" ROOT_DOCUMENT_1 = "root_document_1" ROOT_DOCUMENT_2 = "root_document_2" +DS_ENGLISH = DatashareLanguage(ENGLISH) +DS_FRENCH = DatashareLanguage(FRENCH) +DS_SPANISH = DatashareLanguage(SPANISH) -MOCK_TRANSLATIONS = [ - (DOC_ID_1, ROOT_DOCUMENT_1, "1"), - (DOC_ID_2, ROOT_DOCUMENT_2, "2"), -] FRENCH_TEXT = ( "Dans le port d'Amsterdam, il y a des marins qui chantent les rêves " @@ -94,7 +90,7 @@ def test_worker_config(tmp_path_factory: TempPathFactory) -> TranslationWorkerCo def _create_doc( - doc_id: str, root_doc: str, text: str, language: str = "ENGLISH" + doc_id: str, root_doc: str, text: str, language: DatashareLanguage = DS_ENGLISH ) -> Document: return Document(id=doc_id, root_document=root_doc, language=language, content=text) diff --git a/workers/translation-worker/tests/test_activities.py b/workers/translation-worker/tests/test_activities.py index 4a9ad03b..1804b09e 100644 --- a/workers/translation-worker/tests/test_activities.py +++ b/workers/translation-worker/tests/test_activities.py @@ -1,44 +1,50 @@ # ruff: noqa: ARG001, ANN001, ANN202, FBT001, FBT002, ARG005 -from collections.abc import AsyncGenerator -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from collections.abc import AsyncGenerator, Iterable +from functools import partial +from typing import Any, Self, TypeVar +from unittest.mock import patch +import pytest +from datashare_python.objects import Document from icij_common.es import ( - BOOL, DOC_CONTENT, DOC_LANGUAGE, DOC_ROOT_ID, HITS, ID_, - MUST_NOT, SOURCE, - TERM, + ESClient, + ESSort, ) +from icij_common.registrable import RegistrableConfig +from translation_worker import activities from translation_worker.activities import ( - _add_translation, - _get_doc_contents_and_split_on_sentences, _get_es_docs, - _iter_sentences, - _translate_batch, + _split_sentences, _untranslated_query, - create_translation_batches, - translate_docs, + _update_docs_translation, + create_translation_batches_act, + translate_docs_act, ) from translation_worker.config import TranslationWorkerConfig -from translation_worker.constants import CONTENT_LENGTH -from translation_worker.core import get_translation_ensemble -from translation_worker.objects import BatchSentence +from translation_worker.constants import DOC_CONTENT_TEXT_LENGTH +from translation_worker.objects import ( + DatashareLanguage, + ESTranslation, + Language, + TranslationModel, +) +from translation_worker.processors import SentenceSplitter, Translator from tests.conftest import ( DOC_ID_1, DOC_ID_2, - EN, + DS_ENGLISH, + DS_FRENCH, + DS_SPANISH, ENGLISH, - ES, - FR, FRENCH, - MOCK_TRANSLATIONS, ROOT_DOCUMENT_1, ROOT_DOCUMENT_2, SPANISH, @@ -46,186 +52,213 @@ ) -def _make_es_doc(doc_id: str, content: str, language: str, root_document: str) -> dict: - return { - ID_: doc_id, - SOURCE: { - DOC_CONTENT: content, - DOC_LANGUAGE: language, - DOC_ROOT_ID: root_document, - }, - } - +class MockESClient(ESClient): + def __init__(self, docs: list[dict[str, Any]]): + super().__init__(pagination=10) + self._docs = docs -async def _collect_async(gen: AsyncGenerator) -> list: - return [item async for item in gen] + async def poll_search_pages( + self, + body: dict, # noqa: ARG002 + sort: ESSort = None, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> AsyncGenerator[dict[str, Any], None]: + if sort is not None and any("language" in s for s in sort): + docs = sorted(self._docs, key=lambda x: x[SOURCE][DOC_LANGUAGE]) + else: + docs = self._docs + for i in range(0, len(docs), self._pagination_size): + yield {HITS: {HITS: docs[i : i + self._pagination_size]}} -async def _empty_docs_iter(*args, **kwargs) -> AsyncGenerator[None, Any]: - return - yield +class MockSentenceSplitter(SentenceSplitter): + def __init__(self, splits: list[list[str]]): + self._splits = iter(splits) + def split_sentences(self, text: str) -> list[str]: # noqa: ARG002 + return next(self._splits) -async def _single_doc_iter(es_doc: dict) -> AsyncGenerator: - yield es_doc + @classmethod + def _from_config(cls, config: RegistrableConfig, **extras) -> Self: ... -EN_DOC_1_TEXT = "Hello" -EN_DOC_2_TEXT = "Goodbye" -FR_DOC_1_TEXT = "Bonjour" -FR_DOC_2_TEXT = "Au revoir" -ES_DOC_1_TEXT = "Hola" -ES_DOC_2_TEXT = "Adios" -EN_DOC_1 = _make_es_doc(DOC_ID_1, EN_DOC_1_TEXT, ENGLISH, ROOT_DOCUMENT_1) -EN_DOC_2 = _make_es_doc(DOC_ID_2, EN_DOC_2_TEXT, ENGLISH, ROOT_DOCUMENT_2) -FR_DOC_1 = _make_es_doc(DOC_ID_1, FR_DOC_1_TEXT, FRENCH, ROOT_DOCUMENT_1) -FR_DOC_2 = _make_es_doc(DOC_ID_2, FR_DOC_2_TEXT, FRENCH, ROOT_DOCUMENT_2) -ES_DOC_1 = _make_es_doc(DOC_ID_1, ES_DOC_1_TEXT, SPANISH, ROOT_DOCUMENT_1) -ES_DOC_2 = _make_es_doc(DOC_ID_2, ES_DOC_2_TEXT, SPANISH, ROOT_DOCUMENT_2) - - -# _untranslated_query +class MockTranslator(Translator): + registered_name = TranslationModel.ARGOS + def __init__(self, translations: list[str]): + super().__init__() + self._translations = translations -def test__untranslated_query__fields() -> None: - query = _untranslated_query(EN) + def translate(self, texts: Iterable[str]) -> list[str]: # noqa: ARG002 + return self._translations - assert "query" in query - assert BOOL in query["query"] - assert "content_translated.en" in str(query) + @classmethod + def _from_config(cls, config: RegistrableConfig, **extras) -> Self: ... -def test__untranslated_query__comparison() -> None: - assert _untranslated_query(EN) != _untranslated_query(FR) +def _make_es_doc( + doc_id: str, *, content: str, language: str, root_document: str +) -> dict: + sources = {DOC_CONTENT: content, DOC_LANGUAGE: language, DOC_ROOT_ID: root_document} + return {ID_: doc_id, SOURCE: sources} -def test__untranslated_query__doc_lang_other_than_english() -> None: - assert _untranslated_query(FR)["query"][BOOL][MUST_NOT][1][TERM][DOC_LANGUAGE] == FR - - -# _iter_sentences - - -async def test__iter_sentences__yields_sentences_with_correct_indices() -> None: - es_doc = _make_es_doc( - DOC_ID_1, "Hello world. How are you? I'm fine.", EN, ROOT_DOCUMENT_1 - ) +async def _collect_async(gen: AsyncGenerator) -> list: + return [item async for item in gen] - sentences = ["Hello world.", "How are you?", "I'm fine."] - sentencizer = MagicMock(return_value=sentences) - batches = await _collect_async( - _iter_sentences(_single_doc_iter(es_doc), sentencizer) - ) +T = TypeVar("T") - non_empty = [b for b in batches if b] - assert len(non_empty) == 1 - assert len(non_empty[0]) == 3 - assert [s.sentence_index for s in non_empty[0]] == [0, 1, 2] +async def _aiter(it: Iterable[T]) -> AsyncGenerator[T, None]: + for item in it: + yield item -async def test__iter_sentences__doc_id_and_root_document_are_preserved() -> None: - sentences = [FR_DOC_1_TEXT] - sentencizer = MagicMock(return_value=sentences) - batches = await _collect_async( - _iter_sentences(_single_doc_iter(FR_DOC_1), sentencizer) - ) +EN_DOC_1_TEXT = "Hello" +EN_DOC_2_TEXT = "Goodbye" +FR_DOC_1_TEXT = "Bonjour" +FR_DOC_2_TEXT = "Au revoir" +ES_DOC_1_TEXT = "Hola" +ES_DOC_2_TEXT = "Adios" +EN_DOC_1 = _make_es_doc( + DOC_ID_1, content=EN_DOC_1_TEXT, language=ENGLISH, root_document=ROOT_DOCUMENT_1 +) +EN_DOC_2 = _make_es_doc( + DOC_ID_2, content=EN_DOC_2_TEXT, language=ENGLISH, root_document=ROOT_DOCUMENT_2 +) +FR_DOC_1 = _make_es_doc( + DOC_ID_1, content=FR_DOC_1_TEXT, language=FRENCH, root_document=ROOT_DOCUMENT_1 +) +FR_DOC_2 = _make_es_doc( + DOC_ID_2, content=FR_DOC_2_TEXT, language=FRENCH, root_document=ROOT_DOCUMENT_2 +) +ES_DOC_1 = _make_es_doc( + DOC_ID_1, content=ES_DOC_1_TEXT, language=SPANISH, root_document=ROOT_DOCUMENT_1 +) +ES_DOC_2 = _make_es_doc( + DOC_ID_2, content=ES_DOC_2_TEXT, language=SPANISH, root_document=ROOT_DOCUMENT_2 +) - non_empty = [b for b in batches if b] - sentence = non_empty[0][0] - assert sentence.doc_id == DOC_ID_1 - assert sentence.root_document == ROOT_DOCUMENT_1 +# _untranslated_query -async def test__iter_sentences__multiple_docs_sentences_collected_into_one_batch() -> ( - None -): - async def docs_iter(): - yield EN_DOC_1 - yield EN_DOC_2 - sentences = [[EN_DOC_1_TEXT], [EN_DOC_2_TEXT]] - sentencizer = MagicMock(side_effect=sentences) +@pytest.mark.parametrize("language", [DS_ENGLISH, DS_FRENCH]) +def test__untranslated_query(language: Language) -> None: + query = _untranslated_query(language) + expected_query = { + "query": { + "bool": { + "must_not": [{"term": {"content_translated.target_language": language}}] + } + } + } + assert query == expected_query - batches = await _collect_async(_iter_sentences(docs_iter(), sentencizer)) - non_empty = [b for b in batches if b] - assert len(non_empty) == 1 - assert len(non_empty[0]) == 2 +# _iter_sentences +_EN_DOC_TO_SPLIT = _make_es_doc( + DOC_ID_1, + content="Hello world. How are you? I'm fine.", + language=ENGLISH, + root_document=ROOT_DOCUMENT_1, +) -async def test__iter_sentences__empty_docs_iter_yields_no_sentences() -> None: - batches = await _collect_async(_iter_sentences(_empty_docs_iter(), MagicMock())) - assert not any(batches) +@pytest.mark.parametrize( + ("docs", "sentences", "expected_doc_sents"), + [ + # sentences_with_correct_indices + ( + [_EN_DOC_TO_SPLIT], + [["Hello world.", "How are you?", "I'm fine."]], + [ + (Document.from_es(_EN_DOC_TO_SPLIT), "Hello world."), + (Document.from_es(_EN_DOC_TO_SPLIT), "How are you?"), + (Document.from_es(_EN_DOC_TO_SPLIT), "I'm fine."), + ], + ), + # multiple docs sentences + ( + [EN_DOC_1, EN_DOC_2], + [[EN_DOC_1_TEXT], [EN_DOC_2_TEXT]], + [ + (Document.from_es(EN_DOC_1), EN_DOC_1_TEXT), + (Document.from_es(EN_DOC_2), EN_DOC_2_TEXT), + ], + ), + # no doc, no sentences + ([], [], []), + ], +) +async def test__iter_sentences__yields_sentences_with_correct_indices( + docs: list[dict[str, Any]], + sentences: list[list[str]], + expected_doc_sents: list[tuple[Document, str]], +) -> None: + # Given + sentence_splitter = MockSentenceSplitter(sentences) + docs = _aiter(docs) + # When + doc_sents = await _collect_async(_split_sentences(docs, sentence_splitter)) + # Then + assert doc_sents == expected_doc_sents # create_translation_batches - - -def _make_batching_doc(doc_id: str, language: str, content_length: int = 0) -> dict: - return { - ID_: doc_id, - SOURCE: {DOC_LANGUAGE: language, CONTENT_LENGTH: content_length}, - } - - -async def _make_group(*docs): - for doc in docs: - yield doc +def _make_batching_doc( + doc_id: str, language: DatashareLanguage, content_text_length: int = 0 +) -> dict: + source = {DOC_LANGUAGE: language, DOC_CONTENT_TEXT_LENGTH: content_text_length} + return {ID_: doc_id, SOURCE: source} async def test__create_translation_batches__returns_empty_list_when_no_docs() -> None: - async def mock_get_es_docs(*args, **kwargs): - return - yield - - with patch( - "translation_worker.activities._get_es_docs", side_effect=mock_get_es_docs - ): - result = await create_translation_batches( - project=TEST_PROJECT, - target_language=EN, + # Given + client = MockESClient([]) + # When + result = [ + b + async for b in create_translation_batches_act( + project=TEST_PROJECT, target=DS_ENGLISH, es_client=client ) - + ] + # Then assert result == [] async def test__create_translation_batches__single_doc_creates_one_batch() -> None: + # Given doc = _make_batching_doc(DOC_ID_1, FRENCH) - - async def mock_get_es_docs(*args, **kwargs): - yield _make_group(doc) - - with patch( - "translation_worker.activities._get_es_docs", side_effect=mock_get_es_docs - ): - result = await create_translation_batches( - project=TEST_PROJECT, - target_language=EN, + client = MockESClient([doc]) + # When + result = [ + b + async for b in create_translation_batches_act( + project=TEST_PROJECT, target=DS_ENGLISH, es_client=client ) - + ] + # When assert len(result) == 1 lang, batches = result[0] - assert lang == FR + assert lang == DS_FRENCH assert batches == [[DOC_ID_1]] async def test__create_translation_batches__multiple_docs_same_lang_one_batch() -> None: + # Given docs = [_make_batching_doc(DOC_ID_1, FRENCH), _make_batching_doc(DOC_ID_2, FRENCH)] - - async def mock_get_es_docs(*args, **kwargs): - yield _make_group(*docs) - - with patch( - "translation_worker.activities._get_es_docs", side_effect=mock_get_es_docs - ): - result = await create_translation_batches( - project=TEST_PROJECT, - target_language=EN, + client = MockESClient(docs) + # When + result = [ + b + async for b in create_translation_batches_act( + project=TEST_PROJECT, target=DS_ENGLISH, es_client=client ) - + ] + # Then assert len(result) == 1 _, batches = result[0] assert batches == [[DOC_ID_1, DOC_ID_2]] @@ -234,544 +267,261 @@ async def mock_get_es_docs(*args, **kwargs): async def test__create_translation_batches__multiple_langs_yield_separate_entries() -> ( None ): - fr_doc = _make_batching_doc(DOC_ID_1, FRENCH) - es_doc = _make_batching_doc(DOC_ID_2, SPANISH) - - async def mock_get_es_docs(*args, **kwargs): - yield _make_group(fr_doc) - yield _make_group(es_doc) - - with patch( - "translation_worker.activities._get_es_docs", side_effect=mock_get_es_docs - ): - result = await create_translation_batches( - project=TEST_PROJECT, - target_language=EN, + # Given + fr_doc = _make_batching_doc(DOC_ID_1, DS_FRENCH) + es_doc = _make_batching_doc(DOC_ID_2, DS_SPANISH) + docs = [fr_doc, es_doc] + # When + client = MockESClient(docs) + result = [ + b + async for b in create_translation_batches_act( + project=TEST_PROJECT, target=DS_ENGLISH, es_client=client ) - - langs = [lang for lang, _ in result] + ] + # Then + langs = {lang for lang, _ in result} assert len(result) == 2 - assert FR in langs - assert ES in langs + assert DS_FRENCH in langs + assert DS_SPANISH in langs -async def test__create_translation_batches__splits_batch_if_max_byte_len_exceeded() -> ( +async def test__create_translation_batches__splits_batch_if_max_text_len_exceeded() -> ( None ): - # doc1 is the first doc (byte len not yet tracked); doc2 adds 600 bytes; - # doc3 would bring total to 1200 > 1000, so it triggers a flush before being added + # Given + batch_text_length = 1400 doc_id_3 = "doc_id_3" docs = [ - _make_batching_doc(DOC_ID_1, FRENCH, content_length=600), - _make_batching_doc(DOC_ID_2, FRENCH, content_length=600), - _make_batching_doc(doc_id_3, FRENCH, content_length=600), + _make_batching_doc(DOC_ID_1, FRENCH, content_text_length=600), + _make_batching_doc(DOC_ID_2, FRENCH, content_text_length=600), + _make_batching_doc(doc_id_3, FRENCH, content_text_length=600), ] - - async def mock_get_es_docs(*args, **kwargs): - yield _make_group(*docs) - - with patch( - "translation_worker.activities._get_es_docs", side_effect=mock_get_es_docs - ): - result = await create_translation_batches( - project=TEST_PROJECT, target_language=EN, max_batch_byte_len=1000 + client = MockESClient(docs) + # When + result = [ + b + async for b in create_translation_batches_act( + project=TEST_PROJECT, + target=DS_ENGLISH, + batch_text_length=batch_text_length, + es_client=client, ) - + ] + # Then _, batches = result[0] assert len(batches) == 2 assert batches[0] == [DOC_ID_1, DOC_ID_2] assert batches[1] == [doc_id_3] -# translate_docs +# translate_docs_act -async def test_translate_docs__returns_zero_for_empty_batch() -> None: - empty_batch = (FR, []) - result = await translate_docs( - empty_batch, - EN, +async def test_translate_docs_act__returns_zero_for_empty_batch(monkeypatch) -> None: + # Given + batches = [[]] + sentences = [] + translations = [] + translator = MockTranslator(translations) + sentence_splitter = MockSentenceSplitter(sentences) + es_client = MockESClient([FR_DOC_1, FR_DOC_2]) + # When + n_translated = await translate_docs_act( + batches, project=TEST_PROJECT, - es_client=MagicMock(), + es_client=es_client, worker_config=TranslationWorkerConfig(), + translator=translator, + sentence_splitter=sentence_splitter, ) - assert result == 0 - - -async def test_translate_docs__accepts_none_config_and_defaults() -> None: - empty_batch = (FR, []) - result = await translate_docs( - empty_batch, - EN, - project=TEST_PROJECT, - es_client=MagicMock(), - worker_config=None, - ) - assert result == 0 - + # Then + assert n_translated == 0 -async def test_translate_docs__returns_count_of_unique_docs_translated() -> None: - sentences = [ - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=0, - sentence=FR_DOC_1_TEXT, - ), - BatchSentence( - doc_id=DOC_ID_2, - root_document=ROOT_DOCUMENT_2, - sentence_index=0, - sentence=FR_DOC_2_TEXT, - ), - ] - async def mock_split(*args, **kwargs): - yield sentences +async def _do_nothing_es_update( + es_client: ESClient, + translated_docs: Iterable[tuple[Document, ESTranslation]], + project: str, +): + pass - async def mock_translate(batch, ensemble, config): - return [EN_DOC_1_TEXT, EN_DOC_2_TEXT] - with ( - patch( - "translation_worker.activities.get_translation_ensemble", - return_value=MagicMock(), - ), - patch( - "translation_worker.activities._get_doc_contents_and_split_on_sentences", - side_effect=mock_split, - ), - patch( - "translation_worker.activities._translate_batch", side_effect=mock_translate - ), - patch("translation_worker.activities._add_translation", new_callable=AsyncMock), - ): - result = await translate_docs( - (FR, [[DOC_ID_1, DOC_ID_2]]), - EN, +async def _capturing_es_update( + es_client: ESClient, + translated_docs: Iterable[tuple[Document, ESTranslation]], + project: str, + captured: list[tuple[Document, ESTranslation]], +): + captured.extend(translated_docs) + + +async def test_translate_docs_act__returns_count_of_unique_docs_translated( + monkeypatch, +) -> None: + # Given + batches = [[DOC_ID_1, DOC_ID_2]] + sentences = [[FR_DOC_1_TEXT], [FR_DOC_2_TEXT]] + translations = [EN_DOC_1_TEXT, EN_DOC_2_TEXT] + translator = MockTranslator(translations) + sentence_splitter = MockSentenceSplitter(sentences) + monkeypatch.setattr(activities, "_update_docs_translation", _do_nothing_es_update) + es_client = MockESClient([FR_DOC_1, FR_DOC_2]) + # When + with translator.load(source=DS_ENGLISH, target=DS_ENGLISH): + n_translated = await translate_docs_act( + batches, project=TEST_PROJECT, - es_client=MagicMock(), + es_client=es_client, worker_config=TranslationWorkerConfig(), + translator=translator, + sentence_splitter=sentence_splitter, ) - - assert result == 2 - - -async def test_translate_docs__sentences_from_same_doc_count_as_one() -> None: - sentences = [ - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=0, - sentence=FR_DOC_1_TEXT, - ), - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=1, - sentence=FR_DOC_2_TEXT, - ), - ] - - async def mock_split(*args, **kwargs): - yield sentences - - async def mock_translate(batch, ensemble, config): - return [EN_DOC_1_TEXT, EN_DOC_2_TEXT] - - with ( - patch( - "translation_worker.activities.get_translation_ensemble", - return_value=MagicMock(), - ), - patch( - "translation_worker.activities._get_doc_contents_and_split_on_sentences", - side_effect=mock_split, - ), - patch( - "translation_worker.activities._translate_batch", side_effect=mock_translate - ), - patch("translation_worker.activities._add_translation", new_callable=AsyncMock), - ): - result = await translate_docs( - (FR, [[DOC_ID_1]]), - EN, + # Then + assert n_translated == 2 + + +async def test_translate_docs_act__sentences_from_same_doc_count_as_one( + monkeypatch, +) -> None: + # Given + batches = [[DOC_ID_1]] + sentences = [[FR_DOC_1_TEXT, FR_DOC_2_TEXT]] + translations = [EN_DOC_1_TEXT, EN_DOC_2_TEXT] + translator = MockTranslator(translations) + sentence_splitter = MockSentenceSplitter(sentences) + monkeypatch.setattr(activities, "_update_docs_translation", _do_nothing_es_update) + es_client = MockESClient([FR_DOC_1]) + # When + with translator.load(source=DS_ENGLISH, target=DS_ENGLISH): + n_translated = await translate_docs_act( + batches, project=TEST_PROJECT, - es_client=MagicMock(), + es_client=es_client, worker_config=TranslationWorkerConfig(), + translator=translator, + sentence_splitter=sentence_splitter, ) - - assert result == 1 - - -async def test_translate_docs__reconstructs_translation_in_sentence_order() -> None: - sentences = [ - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=0, - sentence=FR_DOC_1_TEXT, - ), - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=1, - sentence=FR_DOC_2_TEXT, - ), - ] - captured_translations = [] - - async def mock_split(*args, **kwargs): - yield sentences - - async def mock_translate(batch, ensemble, config): - return [EN_DOC_1_TEXT, EN_DOC_2_TEXT] - - async def mock_add_translation( - es_client, translations, project, *, target_language_alpha_code - ): - captured_translations.extend(translations) - - with ( - patch( - "translation_worker.activities.get_translation_ensemble", - return_value=MagicMock(), - ), - patch( - "translation_worker.activities._get_doc_contents_and_split_on_sentences", - side_effect=mock_split, - ), - patch( - "translation_worker.activities._translate_batch", side_effect=mock_translate - ), - patch( - "translation_worker.activities._add_translation", - side_effect=mock_add_translation, - ), - ): - await translate_docs( - (FR, [[DOC_ID_1]]), - EN, + # Then + assert n_translated == 1 + + +async def test_translate_docs_act__es_update(monkeypatch) -> None: + # Given + batches = [[DOC_ID_1]] + sentences = [[FR_DOC_1_TEXT, FR_DOC_2_TEXT]] + translations = [EN_DOC_1_TEXT, EN_DOC_2_TEXT] + translator = MockTranslator(translations) + sentence_splitter = MockSentenceSplitter(sentences) + captured = [] + update_doc = partial(_capturing_es_update, captured=captured) + monkeypatch.setattr(activities, "_update_docs_translation", update_doc) + es_client = MockESClient([FR_DOC_1]) + # When + with translator.load(source=DS_FRENCH, target=DS_ENGLISH): + n_translated = await translate_docs_act( + batches, project=TEST_PROJECT, - es_client=MagicMock(), + es_client=es_client, worker_config=TranslationWorkerConfig(), + translator=translator, + sentence_splitter=sentence_splitter, ) - - assert len(captured_translations) == 1 - _, _, combined = captured_translations[0] - assert EN_DOC_1_TEXT in combined - assert EN_DOC_2_TEXT in combined - assert combined.index(EN_DOC_1_TEXT) < combined.index(EN_DOC_2_TEXT) - - -async def test_translate_docs__calls_add_translation_with_target_language() -> None: - sentences = [ - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=0, - sentence=FR_DOC_1_TEXT, - ), - ] - captured_kwargs = {} - - async def mock_split(*args, **kwargs): - yield sentences - - async def mock_add_translation( - es_client, translations, project, *, target_language_alpha_code - ): - captured_kwargs["target_language_alpha_code"] = target_language_alpha_code - - with ( - patch( - "translation_worker.activities.get_translation_ensemble", - return_value=MagicMock(), - ), - patch( - "translation_worker.activities._get_doc_contents_and_split_on_sentences", - side_effect=mock_split, - ), - patch( - "translation_worker.activities._translate_batch", - side_effect=lambda b, e, c: [EN_DOC_1_TEXT], - ), - patch( - "translation_worker.activities._add_translation", - side_effect=mock_add_translation, - ), - ): - await translate_docs( - (FR, [[DOC_ID_1]]), - EN, - project=TEST_PROJECT, - es_client=MagicMock(), - worker_config=TranslationWorkerConfig(), - ) - - assert captured_kwargs["target_language_alpha_code"] == EN - - -# _translate_batch - - -async def test__translate_batch_returns_translations_from_translate_as_list() -> None: - sentences = ( - BatchSentence( - doc_id=DOC_ID_1, - root_document=ROOT_DOCUMENT_1, - sentence_index=0, - sentence=EN_DOC_1_TEXT, - ), + # Then + assert n_translated == 1 + assert len(captured) == 1 + captured = captured[0] + expected_translation = ESTranslation( + source_language=DS_FRENCH, + target_language=DS_ENGLISH, + translator=TranslationModel.ARGOS, + content="Hello Goodbye", ) - translation_ensemble = get_translation_ensemble( - source_language_alpha_code=EN, - target_language_alpha_code=FR, + expected = (Document.from_es(FR_DOC_1), expected_translation) + assert captured == expected + + +async def test__update_docs() -> None: + # Given + doc_1 = Document(id="doc_1", language=DS_ENGLISH, root_document=ROOT_DOCUMENT_1) + t_1 = ESTranslation( + source_language=DS_ENGLISH, + target_language=DS_FRENCH, + translator=TranslationModel.ARGOS, + content="1", ) - with patch( - "translation_worker.activities.translate_as_list", - return_value=[FR_DOC_1_TEXT], - ): - result = await _translate_batch(sentences, translation_ensemble) - - assert result == [FR_DOC_1_TEXT] - - -# _add_translation - - -async def test__add_translation__calls_async_bulk_once_per_invocation() -> None: - bulk_call_count = 0 - - async def capture_bulk(client, actions, **kwargs): - nonlocal bulk_call_count - list(actions) # consume generator - bulk_call_count += 1 - - with patch("translation_worker.activities.async_bulk", side_effect=capture_bulk): - await _add_translation( - MagicMock(), - MOCK_TRANSLATIONS, - TEST_PROJECT, - target_language_alpha_code=EN, - ) - - assert bulk_call_count == 1 - - -async def test__add_translation__generates_update_action_per_translation() -> None: - bulk_actions = [] - - async def capture_bulk(client, actions, **kwargs): - bulk_actions.extend(list(actions)) - - with patch("translation_worker.activities.async_bulk", side_effect=capture_bulk): - await _add_translation( - MagicMock(), - MOCK_TRANSLATIONS, - TEST_PROJECT, - target_language_alpha_code=EN, - ) - - assert len(bulk_actions) == 2 - assert all(a["_op_type"] == "update" for a in bulk_actions) - assert all(a["_index"] == TEST_PROJECT for a in bulk_actions) - - -async def test__add_translation__sets_correct_doc_id_routing_and_params() -> None: - bulk_actions = [] - - async def capture_bulk(client, actions, **kwargs): - bulk_actions.extend(list(actions)) - - with patch("translation_worker.activities.async_bulk", side_effect=capture_bulk): - await _add_translation( - MagicMock(), - MOCK_TRANSLATIONS[:1], - TEST_PROJECT, - target_language_alpha_code=EN, + doc_2 = Document(id="doc_2", language=DS_ENGLISH, root_document=ROOT_DOCUMENT_2) + t_2 = ESTranslation( + source_language=DS_ENGLISH, + target_language=DatashareLanguage(FRENCH), + translator=TranslationModel.ARGOS, + content="2", + ) + translated_docs = [(doc_1, t_1), (doc_2, t_2)] + + # When + with patch("translation_worker.activities.async_bulk") as mocked: + await _update_docs_translation(MockESClient([]), translated_docs, TEST_PROJECT) + + # Then + calls = mocked.mock_calls + assert len(calls) == 1 + actions = list(calls[0].args[1]) + scripts = [a.pop("script") for a in actions] + source_and_targets = [ + ( + s["params"]["translation"]["source_language"], + s["params"]["translation"]["target_language"], ) - - assert bulk_actions[0][ID_] == DOC_ID_1 - assert bulk_actions[0]["_routing"] == ROOT_DOCUMENT_1 - - params = bulk_actions[0]["script"]["params"] - - assert params["language"] == EN - assert params["translation"] == MOCK_TRANSLATIONS[0][2] + for s in scripts + ] + assert source_and_targets == [("ENGLISH", "FRENCH"), ("ENGLISH", "FRENCH")] + expected_actions = [ + { + "_id": "doc_1", + "_index": "test-project", + "_op_type": "update", + "_routing": "root_document_1", + }, + { + "_id": "doc_2", + "_index": "test-project", + "_op_type": "update", + "_routing": "root_document_2", + }, + ] + assert actions == expected_actions # _get_es_docs -async def test__get_es_docs__empty_docs_yields_no_groups() -> None: - mock_es_client = MagicMock() - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - groups = [ - group async for group in _get_es_docs(mock_es_client, TEST_PROJECT, EN, []) +@pytest.mark.parametrize( + ("docs", "expected_groups"), + [ + # empty_docs_yields_no_doc + ([], []), + # single_doc + ([FR_DOC_1], [[FR_DOC_1[ID_]]]), + # groups_docs_of_same_language_together + ([FR_DOC_1, FR_DOC_2], [[FR_DOC_1[ID_], FR_DOC_2[ID_]]]), + # separate_group_per_language + ( + [FR_DOC_1, ES_DOC_1, ES_DOC_2], + [[FR_DOC_1[ID_]], [ES_DOC_1[ID_], ES_DOC_2[ID_]]], + ), + ], +) +async def test__get_es_docs( + docs: list[dict[str, Any]], expected_groups: list[list[str]] +) -> None: + # Given + es_client = MockESClient(docs) + # When + docs = [ + [d async for d in group] + async for group in _get_es_docs(es_client, TEST_PROJECT, DS_ENGLISH, []) ] - - assert groups == [] - - -async def test__get_es_docs__single_doc_yields_one_group() -> None: - mock_es_client = MagicMock() - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - groups = [] - async for group in _get_es_docs(mock_es_client, TEST_PROJECT, EN, []): - groups.append([doc async for doc in group]) - - assert len(groups) == 1 - assert groups[0][0][ID_] == DOC_ID_1 - - -async def test__get_es_docs__groups_docs_of_same_language_together() -> None: - mock_es_client = MagicMock() - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1, FR_DOC_2]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - groups = [] - async for group in _get_es_docs(mock_es_client, TEST_PROJECT, EN, []): - groups.append([doc async for doc in group]) - - assert len(groups) == 1 - assert len(groups[0]) == 2 - - -async def test__get_es_docs__yields_separate_group_per_language() -> None: - mock_es_client = MagicMock() - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1, ES_DOC_1]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - groups = [] - async for group in _get_es_docs(mock_es_client, TEST_PROJECT, EN, []): - groups.append([doc async for doc in group]) - - assert len(groups) == 2 - assert groups[0][0][SOURCE][DOC_LANGUAGE] == FRENCH - assert groups[1][0][SOURCE][DOC_LANGUAGE] == SPANISH - - -async def test__get_es_docs__all_docs_in_second_language_group_are_included() -> None: - mock_es_client = MagicMock() - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1, ES_DOC_1, ES_DOC_2]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - groups = [] - async for group in _get_es_docs(mock_es_client, TEST_PROJECT, EN, []): - groups.append([doc async for doc in group]) - - assert len(groups) == 2 - assert len(groups[1]) == 2 - - -# _get_doc_contents_and_split_on_sentences - - -async def test__get_doc_contents_and_split_on_sentences__empty_iter_yields_empty() -> ( - None -): - result = await _collect_async( - _get_doc_contents_and_split_on_sentences( - MagicMock(), TEST_PROJECT, [], MagicMock() - ) - ) - - assert result == [] - - -async def test__get_doc_contents_and_split_on_sentences__yields_sents_from_doc() -> ( - None -): - mock_es_client = MagicMock() - sentencizer = MagicMock(return_value=[FR_DOC_1_TEXT]) - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - batches = await _collect_async( - _get_doc_contents_and_split_on_sentences( - mock_es_client, TEST_PROJECT, [DOC_ID_1], sentencizer - ) - ) - - assert len(batches) == 1 - assert batches[0][0].doc_id == DOC_ID_1 - assert batches[0][0].sentence == FR_DOC_1_TEXT - - -async def test__get_doc_contents_and_split_on_sentences__yields_last_batch() -> None: - # batch_size is 16 but the doc has only 2 sentences — partial batch must - # still be yielded - mock_es_client = MagicMock() - sentences = ["Sentence one.", "Sentence two."] - sentencizer = MagicMock(return_value=sentences) - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - batches = await _collect_async( - _get_doc_contents_and_split_on_sentences( - mock_es_client, TEST_PROJECT, [DOC_ID_1], sentencizer - ) - ) - - assert len(batches) == 1 - assert len(batches[0]) == 2 - - -async def test__get_doc_contents_and_split_on_sentences__many_batches_batch_size() -> ( - None -): - mock_es_client = MagicMock() - sentences = ["One.", "Two.", "Three."] - sentencizer = MagicMock(return_value=sentences) - - async def mock_poll_search_pages(*args, **kwargs): - yield {HITS: {HITS: [FR_DOC_1]}} - yield {HITS: {HITS: []}} - - mock_es_client.poll_search_pages = mock_poll_search_pages - - batches = await _collect_async( - _get_doc_contents_and_split_on_sentences( - mock_es_client, - TEST_PROJECT, - [DOC_ID_1], - sentencizer, - sentence_batch_size=2, - ) - ) - - assert len(batches) == 2 - assert len(batches[0]) == 2 - assert len(batches[1]) == 1 + docs = [[d[ID_] for d in g] for g in docs] + # Then + assert docs == expected_groups diff --git a/workers/translation-worker/tests/test_objects.py b/workers/translation-worker/tests/test_objects.py new file mode 100644 index 00000000..2803839c --- /dev/null +++ b/workers/translation-worker/tests/test_objects.py @@ -0,0 +1,22 @@ +from translation_worker.objects import ( + ArgosSentenceSplitterConfig, + ArgosSentencizer, + ArgosTranslatorConfig, + TranslationConfig, +) + + +def test_config_deser() -> None: + # Given + config = {"sentence_splitter": {"model": "ARGOS"}, "translator": {"model": "ARGOS"}} + + # When + deser = TranslationConfig.model_validate(config) + # Then + expected = TranslationConfig( + sentence_splitter=ArgosSentenceSplitterConfig( + sentencizer=ArgosSentencizer.MINI_SBD + ), + translator=ArgosTranslatorConfig(), + ) + assert deser == expected diff --git a/workers/translation-worker/tests/test_translation.py b/workers/translation-worker/tests/test_workflows.py similarity index 83% rename from workers/translation-worker/tests/test_translation.py rename to workers/translation-worker/tests/test_workflows.py index 6bfcc4a6..de97edc9 100644 --- a/workers/translation-worker/tests/test_translation.py +++ b/workers/translation-worker/tests/test_workflows.py @@ -4,11 +4,10 @@ from datashare_python.conftest import TEST_PROJECT from datashare_python.objects import Document from icij_common.es import HITS, ESClient, has_type -from pydantic_extra_types.language_code import LanguageName from temporalio.client import Client as TemporalClient from temporalio.worker import Worker from translation_worker.constants import TaskQueue -from translation_worker.objects import TranslationArgs +from translation_worker.objects import DatashareLanguage, TranslationArgs from translation_worker.workflows import TranslationWorkflow @@ -23,7 +22,7 @@ async def test_translation_workflow( ) -> None: # Given args = TranslationArgs( - project=TEST_PROJECT, target_language=LanguageName("ENGLISH") + project=TEST_PROJECT, target_language=DatashareLanguage("ENGLISH") ) workflow_id = f"translation-{uuid.uuid4().hex}" @@ -43,4 +42,7 @@ async def test_translation_workflow( index_docs += hits[HITS][HITS] assert len(index_docs) == 2 index_docs = [Document.from_es(doc) for doc in index_docs] - assert all("en" in doc.content_translated for doc in index_docs) + assert all( + any(ct["target_language"] == "ENGLISH" for ct in doc.content_translated) + for doc in index_docs + ) diff --git a/workers/translation-worker/translation_worker/__init__.py b/workers/translation-worker/translation_worker/__init__.py index e69de29b..0519c120 100644 --- a/workers/translation-worker/translation_worker/__init__.py +++ b/workers/translation-worker/translation_worker/__init__.py @@ -0,0 +1,8 @@ +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + try: + from .sentence_splitters import ArgosSentenceSplitter + except ImportError: + ArgosSentenceSplitter = None + from .translators import ArgosTranslator # noqa: F401 diff --git a/workers/translation-worker/translation_worker/activities.py b/workers/translation-worker/translation_worker/activities.py index 7f91d928..8b690b2d 100644 --- a/workers/translation-worker/translation_worker/activities.py +++ b/workers/translation-worker/translation_worker/activities.py @@ -1,19 +1,17 @@ import asyncio import logging -from collections import defaultdict -from collections.abc import AsyncGenerator, AsyncIterator, Iterable -from copy import deepcopy +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Iterable from functools import partial from typing import Any, cast from aiostream.stream import chain from datashare_python.dependencies import lifespan_es_client, lifespan_worker_config -from datashare_python.objects import Document +from datashare_python.objects import DatashareLanguage, Document from datashare_python.types_ import ProgressRateHandler from datashare_python.utils import ActivityWithProgress, activity_defn, to_raw_progress from elasticsearch._async.helpers import async_bulk from icij_common.es import ( - BOOL, + DOC_CONTENT_TRANSLATED, DOC_LANGUAGE, HITS, ID_, @@ -21,328 +19,275 @@ SOURCE, TERM, ESClient, + bool_query, has_id, must_not, ) -from icij_common.iter_utils import before_and_after, once -from pydantic_extra_types.language_code import LanguageAlpha2, LanguageName +from icij_common.iter_utils import async_batches, before_and_after, once +from pydantic import TypeAdapter + +from translation_worker.constants import DOC_CONTENT_TEXT_LENGTH from .config import TranslationWorkerConfig -from .constants import BATCHING_DOC_SOURCES, CONTENT_LENGTH, TRANSLATION_DOC_SOURCES -from .core import ( - Sentencizer, - TranslationEnsemble, - get_translation_ensemble, - has_language, - translate_as_list, -) -from .objects import BatchSentence +from .constants import BATCHING_DOC_SOURCES, TRANSLATION_DOC_SOURCES +from .objects import ESTranslation, Language, TranslationConfig +from .processors import SentenceSplitter, Translator logger = logging.getLogger(__name__) +DocId = str +Batch = list[DocId] + +_LANGUAGE_TYPE_ADAPTER = TypeAdapter(Language) + class TranslationActivities(ActivityWithProgress): @activity_defn(name="translation.worker_config") async def translation_worker_config(self) -> TranslationWorkerConfig: + logger.info("loading worker configuration...") worker_config = cast(TranslationWorkerConfig, lifespan_worker_config()) return worker_config @activity_defn(name="translation.create_translation_batches") async def create_translation_batches( - self, project: str, target_language: LanguageAlpha2 - ) -> list[tuple[str, list[list[str]]]]: + self, project: str, target: Language + ) -> list[tuple[Language, list[Batch]]]: es_client = lifespan_es_client() worker_config = cast(TranslationWorkerConfig, lifespan_worker_config()) - max_batch_byte_len = worker_config.max_batch_byte_len - batches = await create_translation_batches( - project=project, - target_language=target_language, - max_batch_byte_len=max_batch_byte_len, - es_client=es_client, - ) + batch_text_length = worker_config.batch_text_length + logger.info("creating translation batches...") + batches = [ + b + async for b in create_translation_batches_act( + project=project, + target=target, + batch_text_length=batch_text_length, + es_client=es_client, + ) + ] + logger.info("translation batches created !") return batches @activity_defn(name="translation.translate_docs") async def translate_docs( self, - doc_id_batch_with_lang: tuple[str, list[list[str]]], - target_language: LanguageAlpha2, + batches: list[Batch], *, + source: DatashareLanguage, + target: Language, + config: TranslationConfig, project: str, progress: ProgressRateHandler | None = None, ) -> int: es_client = lifespan_es_client() worker_config = cast(TranslationWorkerConfig, lifespan_worker_config()) - n_translated = await translate_docs( - doc_id_batch_with_lang, - target_language=target_language, - project=project, - es_client=es_client, - progress=progress, - worker_config=worker_config, - ) + # TODO: make a generic fix using interceptors, + # see https://github.com/temporalio/sdk-python/issues/360 + source = _LANGUAGE_TYPE_ADAPTER.validate_python(source) + target = _LANGUAGE_TYPE_ADAPTER.validate_python(target) + if isinstance(config, dict): + config = TranslationConfig.model_validate(config) + # TODO: perform some caching here to avoid reloading + translator = config.to_translator() + sentence_splitter = config.to_sentence_splitter() + # Load the translator first to install the SBD, then load the splitter + logger.debug("loading %s -> %s translator...", source, target) + with translator.load(source, target=target, worker_config=worker_config): + logger.debug("loading %s sentence splitter...", source) + with sentence_splitter.load(source): + logger.info("translating %s batches...", len(batches)) + n_translated = await translate_docs_act( + batches, + project=project, + es_client=es_client, + progress=progress, + worker_config=worker_config, + translator=translator, + sentence_splitter=sentence_splitter, + ) + logger.info("done translating !") return n_translated -async def create_translation_batches( +async def create_translation_batches_act( *, project: str, - target_language: LanguageAlpha2, - max_batch_byte_len: int = 1000000, + target: Language, + batch_text_length: int = 1000000, es_client: ESClient | None = None, -) -> list[tuple[str, list[list[str]]]]: - """Batch doc ids by language and/or total batch byte length - - :param project: Project name - :param target_language: Target language - :param max_batch_byte_len: Maximum batch byte length - :param es_client: ES client - :return: list of batches keyed by language - """ +) -> AsyncGenerator[tuple[DatashareLanguage, list[Batch]], None]: # Retrieve unprocessed docs. es_docs = _get_es_docs( - es_client, - project, - target_language_alpha_code=target_language, - source_includes=BATCHING_DOC_SOURCES, + es_client, project, target=target, source_includes=BATCHING_DOC_SOURCES ) - all_results = {} - current_batch = [] - current_batch_byte_len = 0 - - async for es_doc_id_batch in es_docs: - first_doc = await anext(es_doc_id_batch, None) - if first_doc is None: - continue - - source_alpha_2 = LanguageName(first_doc[SOURCE][DOC_LANGUAGE].title()).alpha2 - all_results[source_alpha_2] = [] - current_batch.append(first_doc[ID_]) - - async for item in es_doc_id_batch: - doc_id = item[ID_] - doc_byte_len = item[SOURCE][CONTENT_LENGTH] - - if 0 < max_batch_byte_len < current_batch_byte_len + doc_byte_len: - all_results[source_alpha_2].append(deepcopy(current_batch)) + async for language_docs in es_docs: + language_batches: list[Batch] = [] + current_batch = [] + current_length = 0 + current_language = None + + async for doc in language_docs: + doc_id: str = doc[ID_] + doc_length = doc[SOURCE][DOC_CONTENT_TEXT_LENGTH] + if current_language is None: + current_language = DatashareLanguage(doc[SOURCE][DOC_LANGUAGE]) + logger.debug("creating batches for %s docs...", current_language) + + next_length = current_length + doc_length + if next_length > batch_text_length: + language_batches.append(list(current_batch)) current_batch = [] - current_batch_byte_len = 0 current_batch.append(doc_id) - current_batch_byte_len += doc_byte_len - - if len(current_batch) > 0: - all_results[source_alpha_2].append(deepcopy(current_batch)) - current_batch = [] - current_batch_byte_len = 0 + current_length = next_length - return list(all_results.items()) + if current_batch: + language_batches.append(list(current_batch)) + if current_language is None: + continue + yield current_language, language_batches -async def translate_docs( - doc_id_batch_with_lang: tuple[str, list[list[str]]], - target_language: LanguageAlpha2, +async def translate_docs_act( + batches: list[Batch], *, project: str, - es_client: ESClient | None = None, - worker_config: TranslationWorkerConfig | None = None, + translator: Translator, + sentence_splitter: SentenceSplitter, + worker_config: TranslationWorkerConfig, + es_client: ESClient, progress: ProgressRateHandler | None = None, # noqa: F821 ) -> int: - """Translate sentence batches and reconstruct translations from original - sentence ordering, inserting them into ES - - :param doc_id_batch_with_lang: doc_ids keyed by document language alpha code - :param target_language: Target language alpha2 code - :param project: Project name - :param es_client: ES client - :param progress: ProgressRateHandler - :param worker_config: worker config - :return: number of documents translated - """ - if worker_config is None: - worker_config = TranslationWorkerConfig() - # TODO: this should not happen if not isinstance(worker_config, TranslationWorkerConfig): worker_config = TranslationWorkerConfig.model_validate(worker_config) - - source_language_alpha_code, doc_id_batches = doc_id_batch_with_lang - - # Get documents - translation_ensemble = get_translation_ensemble( - source_language_alpha_code=source_language_alpha_code, - target_language_alpha_code=target_language, - device=worker_config.device, - inter_threads=worker_config.inter_threads, - intra_threads=worker_config.intra_threads, - compute_type=worker_config.compute_type, + es_queue = asyncio.Queue() + publisher = _translate_and_queue( + batches, + es_queue, + project, + translator, + sentence_splitter, + worker_config, + es_client, + progress, ) + publisher = asyncio.create_task(publisher) + publisher_callback = lambda: es_queue.put_nowait(None) # noqa: E731 + consumer = asyncio.create_task( + _write_translations_to_es(es_client, queue=es_queue, project=project) + ) + n_docs, _ = await _publish_and_consume( + publisher, publisher_callback, consumer=consumer + ) + return n_docs - all_sentences = [] - all_translations = [] - translation_tasks = [] - # unit here is a sentence +async def _translate_and_queue( + batches: list[Batch], + queue: asyncio.Queue, + project: str, + translator: Translator, + sentence_splitter: SentenceSplitter, + worker_config: TranslationWorkerConfig, + es_client: ESClient, + progress: ProgressRateHandler | None = None, # noqa: F821 +) -> int: + n_docs = sum(len(b) for b in batches) + if not n_docs: + return n_docs + if progress is not None: + progress = to_raw_progress(progress, max_progress=n_docs) + source = translator.source + target = translator.target + model = translator.registered_name + translation_factory = partial( + ESTranslation, source_language=source, target_language=target, translator=model + ) seen = 0 - total = 0 - - for doc_id_batch in doc_id_batches: - sentences_batches = _get_doc_contents_and_split_on_sentences( + buffer = [] + current_doc = None + current_doc_translation = [] + n_batches = len(batches) + for batch_i, doc_ids in enumerate(batches): + logger.debug("translating batch %s / %s", batch_i, n_batches) + docs = _poll_from_es( es_client, project, - doc_id_batch, - translation_ensemble.sentencizer, - worker_config.batch_size, + body={QUERY: has_id(doc_ids)}, + source_includes=TRANSLATION_DOC_SOURCES, ) + doc_sents = _split_sentences(docs, sentence_splitter) + # TODO: ideally we should aim at having almost constant size batches, + # by using some sort of binarization / binning. That will also improve add + # context to translate short sentences. A split strategy + # like adding min and max batch_item length should help + doc_sent_batches = async_batches(doc_sents, batch_size=worker_config.batch_size) + async for batch in doc_sent_batches: + batch_docs, sents = zip(*batch, strict=False) + # Run translation 1 batch at the time, parallelization is controlled + # via the batch_size + translated_sents = await asyncio.to_thread(translator.translate, sents) + for doc, translated_sent in zip(batch_docs, translated_sents, strict=True): + if current_doc is not None and doc.id != current_doc.id: + translation = translation_factory(content=current_doc_translation) + buffer.append((current_doc, translation)) + if len(buffer) >= worker_config.es_buffer_size: + queue.put_nowait(buffer) + buffer = [] + seen += 1 + if progress is not None: + await progress(seen) + current_doc_translation = [] + current_doc = doc + current_doc_translation.append(translated_sent) + logger.debug("batch %s / %s translated !", batch_i, n_batches) + # Empty the buffer + if current_doc_translation: + translation = translation_factory(content=current_doc_translation) + buffer.append((current_doc, translation)) + queue.put_nowait(buffer) + return n_docs + + +async def _write_translations_to_es( + es_client: ESClient, queue: asyncio.Queue, project: str +) -> None: + while True: + translated_docs = await queue.get() + if translated_docs is None: + logger.debug("popped poison pill from the queue, exiting !") + queue.task_done() + return + logger.debug("writing translations to the index..") + await _update_docs_translation(es_client, translated_docs, project=project) + logger.debug("translation written !") + queue.task_done() - # Create translation tasks - async for sentences_batch in sentences_batches: - n_sentences = len(sentences_batch) - if not n_sentences: - continue - - # Convert the progress to a "raw" progress to update the progress - # incrementally rather than setting the progress rate - if progress is not None: - progress = to_raw_progress(progress, max_progress=n_sentences) - total += n_sentences - - # Translate - translation_tasks.append( - asyncio.create_task( - _translate_batch( - sentences_batch, translation_ensemble, worker_config.beam_size - ) - ) - ) - - all_sentences += sentences_batch - - # Run translation tasks - for task in asyncio.as_completed(translation_tasks): - translation_batch = await task - all_translations.extend(translation_batch) - - seen += len(translation_batch) - - if progress is not None: - await progress(int(seen / total)) - - all_translations = await asyncio.gather(*translation_tasks) - all_translations = [ - translation for batch in all_translations for translation in batch - ] - - # Reconstruct documents from sentences - # TODO: separate into a function for testing - reconstructed_docs = defaultdict(dict) - - for batch_sentence, translation in zip( - all_sentences, all_translations, strict=False - ): - key = batch_sentence.doc_id, batch_sentence.root_document - reconstructed_docs[key][batch_sentence.sentence_index] = translation - - # Combine sentences into translations and key with doc_id and root_document - # for insertion - translations_with_doc_ids_and_root_doc = [] - - for (doc_id, root_document), sentence_idx_mapping in reconstructed_docs.items(): - ordered_translation = " ".join( - [translation for (_, translation) in sorted(sentence_idx_mapping.items())] - ) - seen += len(ordered_translation) - - translations_with_doc_ids_and_root_doc.append( - (doc_id, root_document, ordered_translation) - ) - - if progress is not None: - await progress(int(seen / total)) - - await _add_translation( - es_client, - translations_with_doc_ids_and_root_doc, - project, - target_language_alpha_code=target_language, - ) - # Return the number of translated documents - return len(reconstructed_docs) - - -# async -async def _get_doc_contents_and_split_on_sentences( - es_client: ESClient, - project: str, - doc_ids: list[str], - sentencizer: Sentencizer, - sentence_batch_size: int = 16, -) -> AsyncGenerator[list[BatchSentence] | None, Any]: - if len(doc_ids) == 0: - return - - batch_gen = _async_query_es( - es_client, - project, - body={QUERY: has_id(doc_ids)}, - source_includes=TRANSLATION_DOC_SOURCES, - ) - - async for batch in _iter_sentences(batch_gen, sentencizer, sentence_batch_size): - yield batch - - -async def _iter_sentences( - doc_iter: AsyncGenerator[dict, None], - sentencizer: Sentencizer, - sentence_batch_size: int = 16, -) -> AsyncGenerator[list[BatchSentence], None]: - sentence_batch = [] +async def _split_sentences( + doc_iter: AsyncGenerator[dict, None], sentence_splitter: SentenceSplitter +) -> AsyncGenerator[tuple[Document, str], None]: async for doc in doc_iter: es_doc = Document.from_es(doc) - sentences = await asyncio.to_thread(sentencizer, es_doc.content) - for idx, sentence in enumerate(sentences): - sentence_batch.append( - BatchSentence( - doc_id=es_doc.id, - root_document=es_doc.root_document, - sentence_index=idx, - sentence=sentence, - ) - ) - - if len(sentence_batch) >= sentence_batch_size: - yield sentence_batch - sentence_batch = [] - - if len(sentence_batch) > 0: - yield sentence_batch - - -async def _translate_batch( - sentence_batch: list[BatchSentence], - translation_ensemble: TranslationEnsemble, - max_parallel_batches: int = 8, - beam_size: int = 4, -) -> list[str]: - async with asyncio.Semaphore(max_parallel_batches): - return await asyncio.to_thread( - translate_as_list, sentence_batch, translation_ensemble, beam_size + sentences = await asyncio.to_thread( + sentence_splitter.split_sentences, es_doc.content ) + for sentence in sentences: + yield es_doc, sentence async def _get_es_docs( es_client: ESClient, project: str, - target_language_alpha_code: str, + target: Language, source_includes: list[str], ) -> AsyncGenerator[AsyncIterator[dict], None]: # Get all documents that are not in the target language sorted by language - docs = _async_query_es( + docs = _poll_from_es( es_client, project, - body=_untranslated_query(target_language_alpha_code), + body=_untranslated_query(target), source_includes=source_includes, sort=[f"{DOC_LANGUAGE}:asc", "_doc:asc"], ) @@ -352,69 +297,66 @@ async def _get_es_docs( except StopAsyncIteration: return current_language = next_doc[SOURCE][DOC_LANGUAGE] - # Consume the iterator until we find a doc with a different language language_docs, docs = before_and_after( - docs, predicate=partial(has_language, language=current_language) + docs, predicate=partial(_has_language, language=current_language) ) # Group all docs of same language grouped_docs = chain(once(next_doc), language_docs) - yield aiter(grouped_docs) -_SCRIPT_SOURCES = """ -if( !ctx._source.containsKey("content_translated") ) { - ctx._source.content_translated = new HashMap(); -} -ctx._source.content_translated[params.language] = params.translation; +_SCRIPT_SOURCES = f""" +if (ctx._source.{DOC_CONTENT_TRANSLATED} == null) {{ + ctx._source.{DOC_CONTENT_TRANSLATED} = [params.translation]; +}} else {{ + def existing = ctx._source.{DOC_CONTENT_TRANSLATED}; + for (int i = 0; i < existing.size(); i++) {{ + if (existing[i].source_language == params.translation.source_language + && existing[i].target_language == params.translation.target_language) {{ + if (existing[i].content == params.translation.content) {{ + ctx.op = 'none'; // skip write if identical + return; + }} + existing[i] = params.translation; + return; + }} + }} + existing.add(params.translation); +}} """ -async def _add_translation( +async def _update_docs_translation( es_client: ESClient, - translations: Iterable[tuple[Document, str]], + translated_docs: Iterable[tuple[Document, ESTranslation]], project: str, - *, - target_language_alpha_code: str, ) -> None: actions = ( { "_op_type": "update", "_index": project, - "_routing": root_document, - ID_: doc_id, + "_routing": doc.root_document, + ID_: doc.id, "script": { "source": _SCRIPT_SOURCES, "lang": "painless", - "params": { - "language": target_language_alpha_code, - "translation": translation, - }, + "params": {"translation": translation.to_es()}, }, } - for doc_id, root_document, translation in translations + for doc, translation in translated_docs ) await async_bulk(es_client, actions, raise_on_error=True, refresh="wait_for") -def _untranslated_query(target_language_alpha_code: str) -> dict: - query = { - "query": { - BOOL: must_not( - { - "exists": { - "field": f"content_translated.{target_language_alpha_code}" - } - }, - {TERM: {DOC_LANGUAGE: target_language_alpha_code}}, - ) - } - } +def _untranslated_query(target: Language) -> dict: + query = bool_query( + must_not({TERM: {f"{DOC_CONTENT_TRANSLATED}.target_language": target.upper()}}) + ) return query -async def _async_query_es( +async def _poll_from_es( es_client: ESClient, project: str, *, @@ -423,15 +365,44 @@ async def _async_query_es( sort: list[str] = None, ) -> AsyncGenerator[dict, None]: async for res in es_client.poll_search_pages( - index=project, - body=body, - _source_includes=source_includes, - sort=sort, + index=project, body=body, _source_includes=source_includes, sort=sort ): for hit in res[HITS][HITS]: yield hit +def _has_language(doc: dict, language: str) -> bool: + return doc[SOURCE][DOC_LANGUAGE] == language + + +async def _publish_and_consume( + publisher: asyncio.Task, + publisher_completion_callback: Callable[[], None], + *, + consumer: asyncio.Task, +) -> tuple[Any, Any]: + # Publish and consume concurrently + logger.debug("starting publish and subscribe") + done, pending = await asyncio.wait( + [publisher, consumer], return_when=asyncio.FIRST_COMPLETED + ) + for d in done: + # Stop everything case of exception + exc = d.exception() + if exc: + for p in pending: + p.cancel() + raise exc + # Wait for publish to be done and push the poison pill to stop consuming + p_res = await publisher + publisher_completion_callback() + logger.debug("done publishing, waiting for consumer to complete...") + # Wait for consumption to be done + c_res = await consumer + logger.debug("done consuming !") + return p_res, c_res + + ACTIVITIES = [ TranslationActivities.translation_worker_config, TranslationActivities.create_translation_batches, diff --git a/workers/translation-worker/translation_worker/config.py b/workers/translation-worker/translation_worker/config.py index 1469bb72..1de76b46 100644 --- a/workers/translation-worker/translation_worker/config.py +++ b/workers/translation-worker/translation_worker/config.py @@ -1,20 +1,26 @@ from datashare_python.config import WorkerConfig +from datashare_python.objects import DatashareModel from pydantic import Field from .constants import TorchDevice -class TranslationWorkerConfig(WorkerConfig): - device: TorchDevice = Field(default=TorchDevice.CPU, frozen=True) - - batch_size: int = 16 - max_parallel_batches: int = 8 - max_batch_byte_len: int = 1000000 - # ctranslate2 params +class C2TranslateConfig(DatashareModel): beam_size: int = 4 inter_threads: int = 1 intra_threads: int = 0 compute_type: str = "auto" # quantization +class TranslationWorkerConfig(WorkerConfig): + device: TorchDevice = Field(default=TorchDevice.CPU, frozen=True) + + batch_size: int = 16 + batch_text_length: int = 10000 + batches_per_worker: int = 10 + es_buffer_size: int = 10 + + c2_translate: C2TranslateConfig = Field(default_factory=C2TranslateConfig) + + WORKER_CONFIG_CLS = TranslationWorkerConfig diff --git a/workers/translation-worker/translation_worker/constants.py b/workers/translation-worker/translation_worker/constants.py index 19de699e..79e88738 100644 --- a/workers/translation-worker/translation_worker/constants.py +++ b/workers/translation-worker/translation_worker/constants.py @@ -19,7 +19,6 @@ class TorchDevice(StrEnum): TRANSLATION_WORKER_NAME = "translation-worker" TRANSLATION_WORKFLOW_NAME = "translation" -CONTENT_LENGTH = "content_length" - +DOC_CONTENT_TEXT_LENGTH = "contentTextLength" TRANSLATION_DOC_SOURCES = [DOC_CONTENT, DOC_ROOT_ID, DOC_LANGUAGE] -BATCHING_DOC_SOURCES = TRANSLATION_DOC_SOURCES[1:] + [CONTENT_LENGTH] +BATCHING_DOC_SOURCES = TRANSLATION_DOC_SOURCES[1:] + [DOC_CONTENT_TEXT_LENGTH] diff --git a/workers/translation-worker/translation_worker/core.py b/workers/translation-worker/translation_worker/core.py deleted file mode 100644 index 1c4cfbcc..00000000 --- a/workers/translation-worker/translation_worker/core.py +++ /dev/null @@ -1,253 +0,0 @@ -import logging -from collections.abc import Generator, Iterable -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol - -from icij_common.es import DOC_LANGUAGE, SOURCE - -from .constants import CONTENT_LENGTH -from .objects import BatchSentence -from .utils import find_device - -if TYPE_CHECKING: - from argostranslate.package import Package - from argostranslate.tokenizer import BPETokenizer, SentencePieceTokenizer - from argostranslate.translate import PackageTranslation - from ctranslate2 import Translator - from spacy import Language - -logging.basicConfig(level=logging.INFO) - -logger = logging.getLogger(__name__) - - -class Sentencizer(Protocol): - def __call__(self, text: str) -> Iterable[str]: ... - - -@dataclass(frozen=True) -class TranslationEnsemble: - tokenizer: "SentencePieceTokenizer | BPETokenizer" - sentencizer: "Sentencizer" - translator: "Translator" - target_prefix: str = "" - - -def translate_as_list( - sentence_batch: list[BatchSentence], - translation_ensemble: TranslationEnsemble, - beam_size: int, -) -> list[str]: - sentence_batch = [s.sentence for s in sentence_batch] - return list(_translate(sentence_batch, translation_ensemble, beam_size)) - - -def _translate( - sentence_batch: list[str], - translation_ensemble: "TranslationEnsemble", - beam_size: int, -) -> Generator[str, None, None]: - tokenized_sentences = [ - translation_ensemble.tokenizer.encode(sentence) for sentence in sentence_batch - ] - - target_prefix = None - - if translation_ensemble.target_prefix != "": - target_prefix = [[translation_ensemble.target_prefix]] * len( - tokenized_sentences - ) - - for translation_result in translation_ensemble.translator.translate_batch( - tokenized_sentences, - target_prefix=target_prefix, - replace_unknowns=True, - batch_type="tokens", - beam_size=beam_size, - num_hypotheses=1, - length_penalty=0.2, - return_scores=True, - ): - hypothesis = translation_result.hypotheses[0] - decoded_translation = translation_ensemble.tokenizer.decode(hypothesis) - - if translation_ensemble.target_prefix != "" and decoded_translation.startswith( - translation_ensemble.target_prefix - ): - # Remove target prefix - decoded_translation = decoded_translation[ - len(translation_ensemble.target_prefix) : - ] - - yield decoded_translation - - -def has_language(doc: dict, language: str) -> bool: - return doc[SOURCE][DOC_LANGUAGE] == language - - -def _has_language_or_exceeds_max_len( - doc: dict, language: str, current_batch_byte_len: int, max_batch_byte_len: int -) -> bool: - return ( - doc[SOURCE][DOC_LANGUAGE] == language - or doc[SOURCE][CONTENT_LENGTH] + current_batch_byte_len > max_batch_byte_len - ) - - -def _get_argos_package( - source_language_alpha_code: str, target_language_alpha_code: str -) -> "Package | None": - from argostranslate.package import get_installed_packages # noqa: PLC0415 - - available_packages = get_installed_packages() - return next( - filter( - lambda x: ( - x.from_code == source_language_alpha_code - and x.to_code == target_language_alpha_code - ), - available_packages, - ), - None, - ) - - -def _get_argos_languages( - *languages_to_find: str, -) -> tuple["Language", ...]: - from argostranslate.translate import get_installed_languages # noqa: PLC0415 - - if not isinstance(languages_to_find, (list, tuple)): - languages_to_find = [languages_to_find] - - languages = [] - available_languages = get_installed_languages() - - for language_to_find in languages_to_find: - language_result = next( - filter(lambda x: x.code == language_to_find, available_languages), None - ) - - if language_result is None: - continue - - languages.append(language_result) - - return tuple(languages) - - -def _get_or_download_argos_languages( - source_language_alpha_code: str, target_language_alpha_code: str -) -> tuple["Language", ...]: - from argostranslate.package import ( # noqa: PLC0415 - get_available_packages, - install_from_path, - update_package_index, - ) - - package = _get_argos_package(source_language_alpha_code, target_language_alpha_code) - - if package is None: - logger.info( - "Package %s -> %s not found locally. Checking index.", - source_language_alpha_code, - target_language_alpha_code, - ) - update_package_index() - available_packages = get_available_packages() - package_to_install = next( - filter( - lambda x: ( - x.from_code == source_language_alpha_code - and x.to_code == target_language_alpha_code - ), - available_packages, - ), - None, - ) - - if package_to_install is not None: - logger.info("Downloading argos package %s", package_to_install) - install_from_path(package_to_install.download()) - - _get_argos_package(source_language_alpha_code, target_language_alpha_code) - - return _get_argos_languages(source_language_alpha_code, target_language_alpha_code) - - -def get_translation_ensemble( - source_language_alpha_code: str, - target_language_alpha_code: str, - device: str = "cpu", - inter_threads: int = 1, - intra_threads: int = 0, - compute_type: str = "auto", -) -> "TranslationEnsemble | None": - from argostranslate.translate import CachedTranslation # noqa: PLC0415 - - # Create batches per language - language_packages = _get_or_download_argos_languages( - source_language_alpha_code, target_language_alpha_code - ) - - if len(language_packages) < 2: - logger.exception( - "Language model for %s and/or %s not available. Skipping translation.", - source_language_alpha_code, - target_language_alpha_code, - ) - return None - - source_language_pkg, target_language_pkg = language_packages - # This is one of the weirder things about argos; it thinks of a translation - # from language to another as a functional mapping and so treats it as an object - argos_translation_package: PackageTranslation | None = ( - source_language_pkg.get_translation(target_language_pkg) - ) - - if argos_translation_package is None: - logger.exception( - "No translation model exists from %s to %s. Skipping translation.", - source_language_alpha_code, - target_language_alpha_code, - ) - return None - - # Another clumsy and non-transparent implementation by Argos; underlying is also - # mistyped for returns (should be PackageTranslation, is marked as ITranslation) - if isinstance(argos_translation_package, CachedTranslation): - argos_translation_package: PackageTranslation = ( - argos_translation_package.underlying - ) - - return _get_translation_ensemble_from_argos_package( - argos_translation_package, device, inter_threads, intra_threads, compute_type - ) - - -def _get_translation_ensemble_from_argos_package( - argos_package: "PackageTranslation", - device: str, - inter_threads: int, - intra_threads: int, - compute_type: str, -) -> "TranslationEnsemble": - import ctranslate2 # noqa: PLC0415 - - model_path = str(argos_package.pkg.package_path / "model") - device = find_device(device) - translator = ctranslate2.Translator( - model_path, - device=device, - inter_threads=inter_threads, - intra_threads=intra_threads, - compute_type=compute_type, - ) - - return TranslationEnsemble( - sentencizer=argos_package.sentencizer.split_sentences, - tokenizer=argos_package.pkg.tokenizer, - translator=translator, - target_prefix=argos_package.pkg.target_prefix, - ) diff --git a/workers/translation-worker/translation_worker/objects.py b/workers/translation-worker/translation_worker/objects.py index bfb59003..45007cff 100644 --- a/workers/translation-worker/translation_worker/objects.py +++ b/workers/translation-worker/translation_worker/objects.py @@ -1,27 +1,151 @@ -from typing import Annotated, Any +from abc import ABC +from enum import StrEnum +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Self -from datashare_python.objects import DatashareModel -from pydantic import BaseModel, BeforeValidator -from pydantic_extra_types.language_code import LanguageName +import langcodes +from datashare_python.objects import BaseModel, DatashareLanguage, DatashareModel +from icij_common.registrable import RegistrableConfig +from pydantic import BeforeValidator, Field, GetCoreSchemaHandler +from pydantic_core import core_schema +from pydantic_core.core_schema import PlainValidatorFunctionSchema +from .processors import SentenceSplitter, Translator -def _to_language_name(value: Any) -> Any: - if isinstance(value, str): - return value.title() - return value +if TYPE_CHECKING: + from argostranslate.sbd import ISentenceBoundaryDetectionModel + + +class _BaseProcessorConfig(BaseModel, RegistrableConfig, ABC): ... + + +class SentenceSplitterModel(StrEnum): + ARGOS = "ARGOS" + + +class TranslationModel(StrEnum): + ARGOS = "ARGOS" + + +class SentenceSplitterConfig(_BaseProcessorConfig): + registry_key: ClassVar[str] = Field(frozen=True, default="model") + model: ClassVar[SentenceSplitterModel] + + +class ArgosSentencizer(StrEnum): + SPACY_SMALL = "spacy_small" + MINI_SBD = "mini_sbd" + + @property + def sentencizer_cls(self) -> type["ISentenceBoundaryDetectionModel"]: + from argostranslate.sbd import ( # noqa: PLC0415 + MiniSBDSentencizer, + SpacySentencizerSmall, + ) + + match self: + case ArgosSentencizer.SPACY_SMALL: + return SpacySentencizerSmall + case ArgosSentencizer.MINI_SBD: + return MiniSBDSentencizer + case _: + raise NotImplementedError() + + +class ArgosSentenceSplitterConfig(SentenceSplitterConfig): + model: ClassVar[SentenceSplitterModel] = SentenceSplitterModel.ARGOS + + sentencizer: ArgosSentencizer = ArgosSentencizer.MINI_SBD + + +class TranslatorConfig(_BaseProcessorConfig): + registry_key: ClassVar[str] = Field(frozen=True, default="model") + model: ClassVar[TranslationModel] + + +class ArgosTranslatorConfig(TranslatorConfig): + model: ClassVar[TranslationModel] = TranslationModel.ARGOS + + beam_size: int = 2 + length_penalty: float = 0.2 + + +# TODO: uncomment when adding more implems +# _SentenceSplitterConfig = tagged_union( +# SentenceSplitterConfig.__subclasses__(), lambda t: t.model.default.value +# ) +# splitter_discriminator = make_enum_discriminator("model", SentenceSplitterModel) + +# TODO: uncomment when adding more implems +# _TranslatorConfig = tagged_union( +# TranslatorConfig.__subclasses__(), lambda t: t.model.default.value +# ) +# translator_discriminator = make_enum_discriminator("model", TranslationModel) + + +class TranslationConfig(DatashareModel): + sentence_splitter: ArgosSentenceSplitterConfig = Field( + # TODO: uncomment when adding more implem + # discriminator=Discriminator(model_discriminator=splitter_discriminator), + default_factory=ArgosSentenceSplitterConfig, + ) + translator: ArgosTranslatorConfig = Field( + # discriminator=Discriminator(model_discriminator=splitter_discriminator), + default_factory=ArgosTranslatorConfig, + ) + + def to_sentence_splitter(self) -> "SentenceSplitter": + from .processors import SentenceSplitter # noqa: PLC0415 + + return SentenceSplitter.from_config(self.sentence_splitter) + + def to_translator(self) -> "Translator": + from .processors import Translator # noqa: PLC0415 + + return Translator.from_config(self.translator) + + +class IETFLanguage(str): + @classmethod + def __get_pydantic_core_schema__( + cls, source: Any, handler: GetCoreSchemaHandler + ) -> PlainValidatorFunctionSchema: + return core_schema.no_info_plain_validator_function(cls.validate) + + @classmethod + def validate(cls, v: Any) -> Self: + tag = langcodes.get(str(v)) + if not tag.is_valid(): + raise ValueError(f"Invalid IETF language: {v}") + return cls(v) + + +Language = DatashareLanguage | IETFLanguage class TranslationArgs(DatashareModel): project: str - target_language: Annotated[LanguageName, BeforeValidator(_to_language_name)] + config: TranslationConfig = Field(default_factory=TranslationConfig) + target_language: Language class TranslationResponse(DatashareModel): n_translations: int = 0 -class BatchSentence(BaseModel): - doc_id: str - root_document: str - sentence_index: int - sentence: str +def _from_sentences(value: Any) -> Any: + if isinstance(value, list): + return " ".join(value) + return value + + +class ESTranslation(BaseModel): # No camelcase here we don't know why + source_language: Language + target_language: Language + translator: TranslationModel + content: Annotated[str, BeforeValidator(_from_sentences)] + + def to_es(self) -> dict[str, Any]: + as_dict = self.model_dump() + as_dict["source_language"] = self.source_language.upper() + as_dict["target_language"] = self.target_language.upper() + return as_dict diff --git a/workers/translation-worker/translation_worker/processors.py b/workers/translation-worker/translation_worker/processors.py new file mode 100644 index 00000000..941176f4 --- /dev/null +++ b/workers/translation-worker/translation_worker/processors.py @@ -0,0 +1,90 @@ +import logging +from abc import abstractmethod +from collections.abc import Iterable +from contextlib import contextmanager +from typing import TYPE_CHECKING, Self, final + +from icij_common.registrable import RegistrableFromConfig + +from .config import TranslationWorkerConfig + +if TYPE_CHECKING: + from .objects import Language +logger = logging.getLogger(__name__) + + +class SentenceSplitter(RegistrableFromConfig): + @abstractmethod + def split_sentences(self, text: str) -> list[str]: ... # noqa: F821 + + def split_sentences_batch(self, batch: list[str]) -> list[list[str]]: + return [self.split_sentences(t) for t in batch] + + @final + @contextmanager + def load(self, language: "Language") -> Self: + with self: + self._load(language) + yield self + + def _load(self, language: "Language") -> Self: ... + + @final + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): ... # noqa: ANN001 + + +class Translator(RegistrableFromConfig): + def __init__(self): + self._source: Language | None = None + self._target: Language | None = None + + @abstractmethod + def translate(self, texts: Iterable[str]) -> list[str]: ... # noqa: F821 + + @property + def source(self) -> "Language": + if self._source is None: + raise ValueError("translator has no source language as it was not loaded") + return self._source + + @property + def target(self) -> "Language": + if self._target is None: + raise ValueError("translator has no target language as it was not loaded") + return self._target + + @contextmanager + @final + def load( + self, + source: "Language", + *, + target: "Language", + worker_config: "TranslationWorkerConfig | None" = None, + ) -> Self: + if worker_config is not None: + worker_config = TranslationWorkerConfig() + with self: + self._load(source, target=target, worker_config=worker_config) + yield self + + def _load( + self, + source: "Language", + *, + target: "Language", + worker_config: "TranslationWorkerConfig", # noqa: ARG002 + ) -> None: + self._source = source + self._target = target + + @final + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 + self._source = None + self._target = None diff --git a/workers/translation-worker/translation_worker/sentence_splitters.py b/workers/translation-worker/translation_worker/sentence_splitters.py new file mode 100644 index 00000000..2078d642 --- /dev/null +++ b/workers/translation-worker/translation_worker/sentence_splitters.py @@ -0,0 +1,42 @@ +import gc +from typing import Self + +from datashare_python.objects import DatashareLanguage +from pydantic_extra_types.language_code import LanguageAlpha2 + +from .objects import ArgosSentenceSplitterConfig, Language, SentenceSplitterModel +from .processors import SentenceSplitter + + +@SentenceSplitter.register(SentenceSplitterModel.ARGOS) +class ArgosSentenceSplitter(SentenceSplitter): + def __init__(self, config: ArgosSentenceSplitterConfig): + self._config = config + self._inner = None + + def _load(self, language: Language) -> Self: + from argostranslate.package import get_installed_packages # noqa: PLC0415 + + if isinstance(language, DatashareLanguage): + language = LanguageAlpha2(language.alpha2) + + installed = get_installed_packages() + p = next((p for p in installed if p.from_code == language), None) + if p is None: + msg = ( + f"unknown language: {language}, install the translation package first" + f" in order to make the sbd package available" + ) + raise ValueError(msg) # noqa: PLC0415 + self._inner = self._config.sentencizer.sentencizer_cls(p) + + def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 + del self._inner + gc.collect() + + def split_sentences(self, text: str) -> list[str]: + return self._inner.split_sentences(text) + + @classmethod + def _from_config(cls, config: ArgosSentenceSplitterConfig) -> Self: + return cls(config) diff --git a/workers/translation-worker/translation_worker/translators/__init__.py b/workers/translation-worker/translation_worker/translators/__init__.py new file mode 100644 index 00000000..cb403819 --- /dev/null +++ b/workers/translation-worker/translation_worker/translators/__init__.py @@ -0,0 +1,4 @@ +try: + from .argos import ArgosTranslator +except ImportError: + ArgosTranslator = None diff --git a/workers/translation-worker/translation_worker/translators/argos.py b/workers/translation-worker/translation_worker/translators/argos.py new file mode 100644 index 00000000..b2d9f8ed --- /dev/null +++ b/workers/translation-worker/translation_worker/translators/argos.py @@ -0,0 +1,194 @@ +import gc +import logging +from collections.abc import Iterable +from typing import TYPE_CHECKING, Self + +from datashare_python.objects import DatashareLanguage +from pydantic_extra_types.language_code import LanguageAlpha2 + +from ..config import TranslationWorkerConfig +from ..objects import Language, TranslationModel +from ..processors import Translator +from ..utils import find_device + +if TYPE_CHECKING: + from argostranslate.package import Package + from argostranslate.translate import Language as ArgosLanguage + from argostranslate.translate import PackageTranslation + + from ..objects import ArgosTranslatorConfig + + +logger = logging.getLogger(__name__) + + +@Translator.register(TranslationModel.ARGOS) +class ArgosTranslator(Translator): + def __init__(self, config: "ArgosTranslatorConfig"): + super().__init__() + self._translator = None + self._tokenizer = None + self._target_prefix = None + self._config = config + + @classmethod + def _from_config(cls, config: "ArgosTranslatorConfig", **extras) -> Self: # noqa: ARG003 + return cls(config) + + def _load( + self, + source: Language, + *, + target: Language, + worker_config: TranslationWorkerConfig, + ) -> None: + import ctranslate2 # noqa: PLC0415 + + super()._load(source, target=target, worker_config=worker_config) + + if isinstance(source, DatashareLanguage): + source = LanguageAlpha2(source.alpha2) + if isinstance(target, DatashareLanguage): + target = LanguageAlpha2(target.alpha2) + translation_pkg = _load_translation_package(source=source, target=target) + model_path = str(translation_pkg.pkg.package_path / "model") + device = find_device(worker_config.device) + self._tokenizer = translation_pkg.pkg.tokenizer + self._target_prefix = translation_pkg.pkg.target_prefix + self._translator = ctranslate2.Translator( + model_path, + device=device, + inter_threads=worker_config.c2_translate.inter_threads, + intra_threads=worker_config.c2_translate.intra_threads, + compute_type=worker_config.c2_translate.compute_type, + ) + + def translate(self, texts: Iterable[str]) -> list[str]: + tokenized = [self._tokenizer.encode(t) for t in texts] + target_prefix = None + if self._target_prefix: + target_prefix = [[self._target_prefix]] * len(tokenized) + + translation_results = self._translator.translate_batch( + tokenized, + target_prefix=target_prefix, + replace_unknowns=True, + batch_type="tokens", + beam_size=self._config.beam_size, + num_hypotheses=1, + length_penalty=self._config.length_penalty, + return_scores=True, + ) + best_hyps = (res.hypotheses[0] for res in translation_results) + decoded = (self._tokenizer.decode(hyp) for hyp in best_hyps) + translated = [] + if self._target_prefix: + for d in decoded: + if d.startswith(self._target_prefix): + translated.append(d[len(self._target_prefix) :]) + else: + translated.append(d) + else: + translated = list(decoded) + return list(translated) + + def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 + self._tokenizer = None + self._target_prefix = None + del self._translator + self._translator = None + gc.collect() + + +def _get_argos_package( + source: LanguageAlpha2, *, target: LanguageAlpha2 +) -> "Package | None": + from argostranslate.package import get_installed_packages # noqa: PLC0415 + + for p in get_installed_packages(): + if p.from_code == source and p.to_code == target: + return p + return None + + +def get_argos_languages( + *languages_to_find: LanguageAlpha2, +) -> tuple["ArgosLanguage", ...]: + from argostranslate.translate import get_installed_languages # noqa: PLC0415 + + if not isinstance(languages_to_find, (list, tuple)): + languages_to_find = [languages_to_find] + + languages = [] + available_languages = get_installed_languages() + + for language_to_find in languages_to_find: + language_result = next( + filter(lambda x: x.code == language_to_find, available_languages), None + ) + + if language_result is None: + continue + + languages.append(language_result) + + return tuple(languages) + + +def _get_or_download_argos_languages( + source: LanguageAlpha2, *, target: LanguageAlpha2 +) -> tuple["ArgosLanguage", ...]: + from argostranslate.package import ( # noqa: PLC0415 + get_available_packages, + install_from_path, + update_package_index, + ) + + package = _get_argos_package(source, target=target) + if package is None: + logger.info( + "package %s -> %s not found locally. Checking index...", + source, + target, + ) + update_package_index() + available_packages = get_available_packages() + package_to_install = next( + filter( + lambda x: x.from_code == source and x.to_code == target, + available_packages, + ), + None, + ) + if package_to_install is not None: + logger.info("downloading argos package %s", package_to_install) + install_from_path(package_to_install.download()) + + return get_argos_languages(source, target) + + +def _load_translation_package( + source: LanguageAlpha2, *, target: LanguageAlpha2 +) -> "PackageTranslation": + from argostranslate.translate import CachedTranslation # noqa: PLC0415 + + # TODO: we should pre-download and cache instead + language_packages = _get_or_download_argos_languages(source, target=target) + + if len(language_packages) < 2: + msg = f"Language model for {source} and/or {target} not available" + raise ValueError(msg) + + source_pkg, target_pkg = language_packages + # This is one of the weirder things about argos; it thinks of a translation + # from language to another as a functional mapping and so treats it as an object + argos_translation_package = source_pkg.get_translation(target_pkg) + if argos_translation_package is None: + msg = f"No translation model exists from {source} to {target}." + raise ValueError(msg) + # Another clumsy and non-transparent implementation by Argos; underlying is also + # mistyped for returns (should be PackageTranslation, is marked as ITranslation) + if isinstance(argos_translation_package, CachedTranslation): + argos_translation_package = argos_translation_package.underlying + + return argos_translation_package diff --git a/workers/translation-worker/translation_worker/workflows.py b/workers/translation-worker/translation_worker/workflows.py index 71c4000b..3f72b0fc 100644 --- a/workers/translation-worker/translation_worker/workflows.py +++ b/workers/translation-worker/translation_worker/workflows.py @@ -1,12 +1,14 @@ import asyncio from datetime import timedelta +from icij_common.iter_utils import batches from temporalio import workflow with workflow.unsafe.imports_passed_through(): from datashare_python.utils import WorkflowWithProgress, execute_activity from .activities import TranslationActivities + from .config import TranslationWorkerConfig from .constants import TRANSLATION_WORKFLOW_NAME, TaskQueue from .objects import TranslationArgs, TranslationResponse @@ -15,11 +17,18 @@ class TranslationWorkflow(WorkflowWithProgress): @workflow.run async def run(self, args: TranslationArgs) -> TranslationResponse: + # Get the config from the worker + worker_config: TranslationWorkerConfig = await execute_activity( + TranslationActivities.translation_worker_config, + task_queue=TaskQueue.IO, + start_to_close_timeout=timedelta(minutes=1), + ) + batches_per_worker = worker_config.batches_per_worker # Create translation batches - target_language = args.target_language.alpha2 - translation_batch_args = [args.project, target_language] - translation_batches: list[tuple[str, list[list[str]]]] - translation_batches = await execute_activity( + target = args.target_language + translation_batch_args = [args.project, target] + per_language_batches: list[tuple[str, list[list[str]]]] + per_language_batches = await execute_activity( TranslationActivities.create_translation_batches, args=translation_batch_args, task_queue=TaskQueue.IO, @@ -28,8 +37,9 @@ async def run(self, args: TranslationArgs) -> TranslationResponse: # Translate translation_args = [ - (id_batch, target_language, args.project) - for id_batch in translation_batches + (b, source, target, args.config, args.project) + for source, languages_batches in per_language_batches + for b in batches(languages_batches, batch_size=batches_per_worker) ] translations_activities = ( execute_activity( diff --git a/workers/translation-worker/uv.dist.lock b/workers/translation-worker/uv.dist.lock index 5364af15..18463b33 100644 --- a/workers/translation-worker/uv.dist.lock +++ b/workers/translation-worker/uv.dist.lock @@ -410,7 +410,7 @@ wheels = [ [[package]] name = "datashare-python" -version = "0.8.18" +version = "0.8.19" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -425,9 +425,9 @@ dependencies = [ { name = "tomlkit" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3c/42/f086eb8d611d42c5162befd36515f987541f7e103bff8e7e490cedb2f86a/datashare_python-0.8.18.tar.gz", hash = "sha256:052d85140977176268190b20c7468925c7fabb078ce3421266b3b682141fb202", size = 315672, upload-time = "2026-05-22T11:06:57.677Z" } +sdist = { url = "https://files.pythonhosted.org/packages/88/59/23fb6fb2d40a70a83fb7793e1541f34c6a40ad5769751552f4d768b15c34/datashare_python-0.8.19.tar.gz", hash = "sha256:7ec122672d9fd9ae4191ca1e26a3d0213d350514a2384ac20dbfaa98b371680d", size = 315621, upload-time = "2026-05-22T13:56:23.953Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/75/8d96d987f35aecc2177972a8e605a65cc25adcb8b4c949a7c3ab4600d4ae/datashare_python-0.8.18-py3-none-any.whl", hash = "sha256:29eed8086677e2672364a9191f4cfc30d710d66427404e56a9513278def3acb1", size = 321671, upload-time = "2026-05-22T11:06:56.118Z" }, + { url = "https://files.pythonhosted.org/packages/8e/c8/67568e37bbc7bc87aa026599bab523466f3838096b0d5e567e48a99be59a/datashare_python-0.8.19-py3-none-any.whl", hash = "sha256:858ea1d11321b67e73998c03b0aa6dd8b83f433dac4f2b67d60f6b3a0dad66f0", size = 321655, upload-time = "2026-05-22T13:56:22.706Z" }, ] [[package]] @@ -435,6 +435,7 @@ name = "datashare-translation-worker" source = { editable = "." } dependencies = [ { name = "datashare-python" }, + { name = "langcodes" }, { name = "pydantic-extra-types", extra = ["pycountry"] }, ] @@ -461,6 +462,7 @@ dev = [ requires-dist = [ { name = "argostranslate", marker = "extra == 'inference'", specifier = "==1.11.0" }, { name = "datashare-python", specifier = "~=0.8.6" }, + { name = "langcodes", specifier = "==3.5.1" }, { name = "pydantic-extra-types", extras = ["pycountry"], specifier = "==2.11.1" }, ] provides-extras = ["inference"] @@ -742,6 +744,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, ] +[[package]] +name = "langcodes" +version = "3.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/75/f9edc5d72945019312f359e69ded9f82392a81d49c5051ed3209b100c0d2/langcodes-3.5.1.tar.gz", hash = "sha256:40bff315e01b01d11c2ae3928dd4f5cbd74dd38f9bd912c12b9a3606c143f731", size = 191084, upload-time = "2025-12-02T16:22:01.627Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/c1/d10b371bcba7abce05e2b33910e39c33cfa496a53f13640b7b8e10bb4d2b/langcodes-3.5.1-py3-none-any.whl", hash = "sha256:b6a9c25c603804e2d169165091d0cdb23934610524a21d226e4f463e8e958a72", size = 183050, upload-time = "2025-12-02T16:21:59.954Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" diff --git a/workers/translation-worker/uv.lock b/workers/translation-worker/uv.lock index 68c9cd11..b47a06f3 100644 --- a/workers/translation-worker/uv.lock +++ b/workers/translation-worker/uv.lock @@ -419,6 +419,7 @@ dependencies = [ { name = "icij-common", extra = ["elasticsearch"] }, { name = "nest-asyncio" }, { name = "orjson" }, + { name = "pydantic-extra-types", extra = ["pycountry"] }, { name = "python-json-logger" }, { name = "pyyaml" }, { name = "temporalio" }, @@ -434,6 +435,7 @@ requires-dist = [ { name = "icij-common", extras = ["elasticsearch"], specifier = "~=0.8.2" }, { name = "nest-asyncio", specifier = "~=1.6" }, { name = "orjson", specifier = "~=3.11" }, + { name = "pydantic-extra-types", extras = ["pycountry"], specifier = ">=2.11.1" }, { name = "python-json-logger", specifier = "~=4.0" }, { name = "pyyaml", specifier = "~=6.0" }, { name = "temporalio", specifier = "~=1.23" }, @@ -461,6 +463,7 @@ name = "datashare-translation-worker" source = { editable = "." } dependencies = [ { name = "datashare-python" }, + { name = "langcodes" }, { name = "pydantic-extra-types", extra = ["pycountry"] }, ] @@ -487,6 +490,7 @@ dev = [ requires-dist = [ { name = "argostranslate", marker = "extra == 'inference'", specifier = "==1.11.0" }, { name = "datashare-python", editable = "../../datashare-python" }, + { name = "langcodes", specifier = "==3.5.1" }, { name = "pydantic-extra-types", extras = ["pycountry"], specifier = "==2.11.1" }, ] provides-extras = ["inference"] @@ -768,6 +772,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, ] +[[package]] +name = "langcodes" +version = "3.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/75/f9edc5d72945019312f359e69ded9f82392a81d49c5051ed3209b100c0d2/langcodes-3.5.1.tar.gz", hash = "sha256:40bff315e01b01d11c2ae3928dd4f5cbd74dd38f9bd912c12b9a3606c143f731", size = 191084, upload-time = "2025-12-02T16:22:01.627Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/c1/d10b371bcba7abce05e2b33910e39c33cfa496a53f13640b7b8e10bb4d2b/langcodes-3.5.1-py3-none-any.whl", hash = "sha256:b6a9c25c603804e2d169165091d0cdb23934610524a21d226e4f463e8e958a72", size = 183050, upload-time = "2025-12-02T16:21:59.954Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0"