From 21cec378f58196be703c8e3f1752508ae562ae0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Mon, 1 Jun 2026 15:19:46 +0200 Subject: [PATCH] feature(translation-worker): translate documents from search --- .../tests/test_activities.py | 43 +++++++------------ .../translation-worker/tests/test_objects.py | 40 +++++++++++++++++ .../translation_worker/activities.py | 38 ++++++++-------- .../translation_worker/objects.py | 33 +++++++++++++- .../translation_worker/workflows.py | 2 +- 5 files changed, 105 insertions(+), 51 deletions(-) diff --git a/workers/translation-worker/tests/test_activities.py b/workers/translation-worker/tests/test_activities.py index face527..98106be 100644 --- a/workers/translation-worker/tests/test_activities.py +++ b/workers/translation-worker/tests/test_activities.py @@ -20,16 +20,15 @@ from icij_common.registrable import RegistrableConfig from translation_worker import activities from translation_worker.activities import ( - _get_es_docs, + _get_es_docs_by_language, _split_sentences, - _untranslated_query, _update_docs_translation, create_translation_batches_act, translate_docs_act, ) from translation_worker.config import TranslationWorkerConfig from translation_worker.constants import DOC_CONTENT_TEXT_LENGTH -from translation_worker.objects import Language, TranslationModel +from translation_worker.objects import TranslationModel, untranslated_query from translation_worker.processors import SentenceSplitter, Translator from tests.conftest import ( @@ -135,25 +134,6 @@ async def _aiter(it: Iterable[T]) -> AsyncGenerator[T, None]: DOC_ID_2, content=ES_DOC_2_TEXT, language=SPANISH, root_document=ROOT_DOCUMENT_2 ) - -# _untranslated_query - - -@pytest.mark.parametrize("language", [DS_ENGLISH, DS_FRENCH]) -def test__untranslated_query(language: Language) -> None: - query = _untranslated_query(language) - expected_query = { - "query": { - "bool": { - "must_not": [ - {"term": {"content_translated.target_language.keyword": language}} - ] - } - } - } - assert query == expected_query - - # _iter_sentences _EN_DOC_TO_SPLIT = _make_es_doc( @@ -215,11 +195,12 @@ def _make_batching_doc( async def test__create_translation_batches__returns_empty_list_when_no_docs() -> None: # Given client = MockESClient([]) + query = untranslated_query(DS_ENGLISH) # When result = [ b async for b in create_translation_batches_act( - project=TEST_PROJECT, target=DS_ENGLISH, es_client=client + project=TEST_PROJECT, query=query, es_client=client ) ] # Then @@ -230,11 +211,12 @@ async def test__create_translation_batches__single_doc_creates_one_batch() -> No # Given doc = _make_batching_doc(DOC_ID_1, FRENCH) client = MockESClient([doc]) + query = untranslated_query(DS_ENGLISH) # When result = [ b async for b in create_translation_batches_act( - project=TEST_PROJECT, target=DS_ENGLISH, es_client=client + project=TEST_PROJECT, query=query, es_client=client ) ] # When @@ -246,13 +228,14 @@ async def test__create_translation_batches__single_doc_creates_one_batch() -> No async def test__create_translation_batches__multiple_docs_same_lang_one_batch() -> None: # Given + query = untranslated_query(DS_ENGLISH) docs = [_make_batching_doc(DOC_ID_1, FRENCH), _make_batching_doc(DOC_ID_2, FRENCH)] client = MockESClient(docs) # When result = [ b async for b in create_translation_batches_act( - project=TEST_PROJECT, target=DS_ENGLISH, es_client=client + project=TEST_PROJECT, query=query, es_client=client ) ] # Then @@ -265,6 +248,7 @@ async def test__create_translation_batches__multiple_langs_yield_separate_entrie None ): # Given + query = untranslated_query(DS_ENGLISH) fr_doc = _make_batching_doc(DOC_ID_1, DS_FRENCH) es_doc = _make_batching_doc(DOC_ID_2, DS_SPANISH) docs = [fr_doc, es_doc] @@ -273,7 +257,7 @@ async def test__create_translation_batches__multiple_langs_yield_separate_entrie result = [ b async for b in create_translation_batches_act( - project=TEST_PROJECT, target=DS_ENGLISH, es_client=client + project=TEST_PROJECT, query=query, es_client=client ) ] # Then @@ -289,6 +273,7 @@ async def test__create_translation_batches__splits_batch_if_max_text_len_exceede # Given batch_text_length = 1400 doc_id_3 = "doc_id_3" + query = untranslated_query(DS_ENGLISH) docs = [ _make_batching_doc(DOC_ID_1, FRENCH, content_text_length=600), _make_batching_doc(DOC_ID_2, FRENCH, content_text_length=600), @@ -300,7 +285,7 @@ async def test__create_translation_batches__splits_batch_if_max_text_len_exceede b async for b in create_translation_batches_act( project=TEST_PROJECT, - target=DS_ENGLISH, + query=query, batch_text_length=batch_text_length, es_client=client, ) @@ -517,7 +502,9 @@ async def test__get_es_docs( # When docs = [ [d async for d in group] - async for group in _get_es_docs(es_client, TEST_PROJECT, DS_ENGLISH, []) + async for group in _get_es_docs_by_language( + es_client, TEST_PROJECT, DS_ENGLISH, [] + ) ] docs = [[d[ID_] for d in g] for g in docs] # Then diff --git a/workers/translation-worker/tests/test_objects.py b/workers/translation-worker/tests/test_objects.py index 2803839..f30631e 100644 --- a/workers/translation-worker/tests/test_objects.py +++ b/workers/translation-worker/tests/test_objects.py @@ -1,10 +1,19 @@ +from typing import Any + +import pytest +from datashare_python.conftest import TEST_PROJECT from translation_worker.objects import ( ArgosSentenceSplitterConfig, ArgosSentencizer, ArgosTranslatorConfig, + DocId, + DocumentSearchQuery, + TranslationArgs, TranslationConfig, ) +from .conftest import DS_ENGLISH + def test_config_deser() -> None: # Given @@ -20,3 +29,34 @@ def test_config_deser() -> None: translator=ArgosTranslatorConfig(), ) assert deser == expected + + +_NO_ES_TRANSLATED_CONTENT = { + "bool": { + "must_not": [ + {"term": {"content_translated.target_language.keyword": "ENGLISH"}} + ] + } +} + + +@pytest.mark.parametrize( + ("docs", "expected_query"), + [ + (None, _NO_ES_TRANSLATED_CONTENT), + (["some-doc-id"], {"ids": {"values": ["some-doc-id"]}}), + ( + {"term": {"contentType": "some-content-type"}}, + {"term": {"contentType": "some-content-type"}}, + ), + ], +) +def test_translation_args_as_query( + docs: list[DocId] | DocumentSearchQuery | None, expected_query: dict[str, Any] +) -> None: + # Given + args = TranslationArgs(project=TEST_PROJECT, docs=docs, target_language=DS_ENGLISH) + # When + query = args.as_query() + # Then + assert query == expected_query diff --git a/workers/translation-worker/translation_worker/activities.py b/workers/translation-worker/translation_worker/activities.py index 966002d..fe6b13c 100644 --- a/workers/translation-worker/translation_worker/activities.py +++ b/workers/translation-worker/translation_worker/activities.py @@ -13,15 +13,15 @@ from icij_common.es import ( DOC_CONTENT_TRANSLATED, DOC_LANGUAGE, + ES_DOCUMENT_TYPE, HITS, ID_, QUERY, SOURCE, - TERM, ESClient, - bool_query, + and_query, has_id, - must_not, + has_type, ) from icij_common.iter_utils import async_batches, before_and_after, once from pydantic import TypeAdapter @@ -50,7 +50,7 @@ async def translation_worker_config(self) -> TranslationWorkerConfig: @activity_defn(name="translation.create_translation_batches") async def create_translation_batches( - self, project: str, target: Language + self, project: str, query: dict[str, Any] ) -> list[tuple[Language, list[Batch]]]: es_client = lifespan_es_client() worker_config = cast(TranslationWorkerConfig, lifespan_worker_config()) @@ -59,8 +59,8 @@ async def create_translation_batches( batches = [ b async for b in create_translation_batches_act( - project=project, - target=target, + project, + query, batch_text_length=batch_text_length, es_client=es_client, ) @@ -110,15 +110,15 @@ async def translate_docs( async def create_translation_batches_act( - *, project: str, - target: Language, + query: dict[str, Any], batch_text_length: int = 1000000, es_client: ESClient | None = None, ) -> AsyncGenerator[tuple[DatashareLanguage, list[Batch]], None]: # Retrieve unprocessed docs. - es_docs = _get_es_docs( - es_client, project, target=target, source_includes=BATCHING_DOC_SOURCES + query = _with_doc_type(query) + es_docs = _get_es_docs_by_language( + es_client, project, query, source_includes=BATCHING_DOC_SOURCES ) async for language_docs in es_docs: language_batches: list[Batch] = [] @@ -277,17 +277,16 @@ async def _split_sentences( yield es_doc, sentence -async def _get_es_docs( +async def _get_es_docs_by_language( es_client: ESClient, project: str, - target: Language, + query: dict[str, Any], source_includes: list[str], ) -> AsyncGenerator[AsyncIterator[dict], None]: - # Get all documents that are not in the target language sorted by language docs = _poll_from_es( es_client, project, - body=_untranslated_query(target), + body=query, source_includes=source_includes, sort=[f"{DOC_LANGUAGE}:asc", "_doc:asc"], ) @@ -349,13 +348,6 @@ async def _update_docs_translation( await async_bulk(es_client, actions, raise_on_error=True, refresh="wait_for") -def _untranslated_query(target: Language) -> dict: - query = bool_query( - must_not({TERM: {f"{DOC_CONTENT_TRANSLATED}.target_language.keyword": target}}) - ) - return query - - async def _poll_from_es( es_client: ESClient, project: str, @@ -375,6 +367,10 @@ def _has_language(doc: dict, language: str) -> bool: return doc[SOURCE][DOC_LANGUAGE] == language +def _with_doc_type(query: dict[str, Any]) -> dict[str, Any]: + return and_query(query, has_type(type_field="type", type_value=ES_DOCUMENT_TYPE)) + + async def _publish_and_consume( publisher: asyncio.Task, publisher_completion_callback: Callable[[], None], diff --git a/workers/translation-worker/translation_worker/objects.py b/workers/translation-worker/translation_worker/objects.py index f8e317d..4f2f65e 100644 --- a/workers/translation-worker/translation_worker/objects.py +++ b/workers/translation-worker/translation_worker/objects.py @@ -1,12 +1,20 @@ from abc import ABC from enum import StrEnum -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from datashare_python.objects import ( BaseModel, DatashareModel, Language, ) +from icij_common.es import ( + DOC_CONTENT_TRANSLATED, + QUERY, + TERM, + bool_query, + has_id, + must_not, +) from icij_common.registrable import RegistrableConfig from pydantic import Field @@ -105,11 +113,34 @@ def to_translator(self) -> "Translator": return Translator.from_config(self.translator) +DocumentSearchQuery = dict[str, Any] +DocId = str + + +def untranslated_query(target: Language) -> dict: + query = bool_query( + must_not({TERM: {f"{DOC_CONTENT_TRANSLATED}.target_language.keyword": target}}) + ) + return query[QUERY] + + class TranslationArgs(DatashareModel): project: str + docs: list[DocId] | DocumentSearchQuery | None = None config: TranslationConfig = Field(default_factory=TranslationConfig) target_language: Language + def as_query(self) -> dict[str, Any]: + match self.docs: + case None: + return untranslated_query(self.target_language) + case list(): + return has_id(self.docs) + case dict(): + return self.docs + case _: + raise ValueError(f"unsupported docs {self.docs}") + class TranslationResponse(DatashareModel): n_translations: int = 0 diff --git a/workers/translation-worker/translation_worker/workflows.py b/workers/translation-worker/translation_worker/workflows.py index 3f72b0f..4973fad 100644 --- a/workers/translation-worker/translation_worker/workflows.py +++ b/workers/translation-worker/translation_worker/workflows.py @@ -26,7 +26,7 @@ async def run(self, args: TranslationArgs) -> TranslationResponse: batches_per_worker = worker_config.batches_per_worker # Create translation batches target = args.target_language - translation_batch_args = [args.project, target] + translation_batch_args = [args.project, args.as_query()] per_language_batches: list[tuple[str, list[list[str]]]] per_language_batches = await execute_activity( TranslationActivities.create_translation_batches,