Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 15 additions & 28 deletions workers/translation-worker/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@
from icij_common.registrable import RegistrableConfig
from translation_worker import activities
from translation_worker.activities import (
_get_es_docs,
_get_es_docs_by_language,
_split_sentences,
_untranslated_query,
_update_docs_translation,
create_translation_batches_act,
translate_docs_act,
)
from translation_worker.config import TranslationWorkerConfig
from translation_worker.constants import DOC_CONTENT_TEXT_LENGTH
from translation_worker.objects import Language, TranslationModel
from translation_worker.objects import TranslationModel, untranslated_query
from translation_worker.processors import SentenceSplitter, Translator

from tests.conftest import (
Expand Down Expand Up @@ -135,25 +134,6 @@ async def _aiter(it: Iterable[T]) -> AsyncGenerator[T, None]:
DOC_ID_2, content=ES_DOC_2_TEXT, language=SPANISH, root_document=ROOT_DOCUMENT_2
)


# _untranslated_query


@pytest.mark.parametrize("language", [DS_ENGLISH, DS_FRENCH])
def test__untranslated_query(language: Language) -> None:
query = _untranslated_query(language)
expected_query = {
"query": {
"bool": {
"must_not": [
{"term": {"content_translated.target_language.keyword": language}}
]
}
}
}
assert query == expected_query


# _iter_sentences

_EN_DOC_TO_SPLIT = _make_es_doc(
Expand Down Expand Up @@ -215,11 +195,12 @@ def _make_batching_doc(
async def test__create_translation_batches__returns_empty_list_when_no_docs() -> None:
# Given
client = MockESClient([])
query = untranslated_query(DS_ENGLISH)
# When
result = [
b
async for b in create_translation_batches_act(
project=TEST_PROJECT, target=DS_ENGLISH, es_client=client
project=TEST_PROJECT, query=query, es_client=client
)
]
# Then
Expand All @@ -230,11 +211,12 @@ async def test__create_translation_batches__single_doc_creates_one_batch() -> No
# Given
doc = _make_batching_doc(DOC_ID_1, FRENCH)
client = MockESClient([doc])
query = untranslated_query(DS_ENGLISH)
# When
result = [
b
async for b in create_translation_batches_act(
project=TEST_PROJECT, target=DS_ENGLISH, es_client=client
project=TEST_PROJECT, query=query, es_client=client
)
]
# When
Expand All @@ -246,13 +228,14 @@ async def test__create_translation_batches__single_doc_creates_one_batch() -> No

async def test__create_translation_batches__multiple_docs_same_lang_one_batch() -> None:
# Given
query = untranslated_query(DS_ENGLISH)
docs = [_make_batching_doc(DOC_ID_1, FRENCH), _make_batching_doc(DOC_ID_2, FRENCH)]
client = MockESClient(docs)
# When
result = [
b
async for b in create_translation_batches_act(
project=TEST_PROJECT, target=DS_ENGLISH, es_client=client
project=TEST_PROJECT, query=query, es_client=client
)
]
# Then
Expand All @@ -265,6 +248,7 @@ async def test__create_translation_batches__multiple_langs_yield_separate_entrie
None
):
# Given
query = untranslated_query(DS_ENGLISH)
fr_doc = _make_batching_doc(DOC_ID_1, DS_FRENCH)
es_doc = _make_batching_doc(DOC_ID_2, DS_SPANISH)
docs = [fr_doc, es_doc]
Expand All @@ -273,7 +257,7 @@ async def test__create_translation_batches__multiple_langs_yield_separate_entrie
result = [
b
async for b in create_translation_batches_act(
project=TEST_PROJECT, target=DS_ENGLISH, es_client=client
project=TEST_PROJECT, query=query, es_client=client
)
]
# Then
Expand All @@ -289,6 +273,7 @@ async def test__create_translation_batches__splits_batch_if_max_text_len_exceede
# Given
batch_text_length = 1400
doc_id_3 = "doc_id_3"
query = untranslated_query(DS_ENGLISH)
docs = [
_make_batching_doc(DOC_ID_1, FRENCH, content_text_length=600),
_make_batching_doc(DOC_ID_2, FRENCH, content_text_length=600),
Expand All @@ -300,7 +285,7 @@ async def test__create_translation_batches__splits_batch_if_max_text_len_exceede
b
async for b in create_translation_batches_act(
project=TEST_PROJECT,
target=DS_ENGLISH,
query=query,
batch_text_length=batch_text_length,
es_client=client,
)
Expand Down Expand Up @@ -517,7 +502,9 @@ async def test__get_es_docs(
# When
docs = [
[d async for d in group]
async for group in _get_es_docs(es_client, TEST_PROJECT, DS_ENGLISH, [])
async for group in _get_es_docs_by_language(
es_client, TEST_PROJECT, DS_ENGLISH, []
)
]
docs = [[d[ID_] for d in g] for g in docs]
# Then
Expand Down
40 changes: 40 additions & 0 deletions workers/translation-worker/tests/test_objects.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Any

import pytest
from datashare_python.conftest import TEST_PROJECT
from translation_worker.objects import (
ArgosSentenceSplitterConfig,
ArgosSentencizer,
ArgosTranslatorConfig,
DocId,
DocumentSearchQuery,
TranslationArgs,
TranslationConfig,
)

from .conftest import DS_ENGLISH


def test_config_deser() -> None:
# Given
Expand All @@ -20,3 +29,34 @@ def test_config_deser() -> None:
translator=ArgosTranslatorConfig(),
)
assert deser == expected


_NO_ES_TRANSLATED_CONTENT = {
"bool": {
"must_not": [
{"term": {"content_translated.target_language.keyword": "ENGLISH"}}
]
}
}


@pytest.mark.parametrize(
("docs", "expected_query"),
[
(None, _NO_ES_TRANSLATED_CONTENT),
(["some-doc-id"], {"ids": {"values": ["some-doc-id"]}}),
(
{"term": {"contentType": "some-content-type"}},
{"term": {"contentType": "some-content-type"}},
),
],
)
def test_translation_args_as_query(
docs: list[DocId] | DocumentSearchQuery | None, expected_query: dict[str, Any]
) -> None:
# Given
args = TranslationArgs(project=TEST_PROJECT, docs=docs, target_language=DS_ENGLISH)
# When
query = args.as_query()
# Then
assert query == expected_query
38 changes: 17 additions & 21 deletions workers/translation-worker/translation_worker/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from icij_common.es import (
DOC_CONTENT_TRANSLATED,
DOC_LANGUAGE,
ES_DOCUMENT_TYPE,
HITS,
ID_,
QUERY,
SOURCE,
TERM,
ESClient,
bool_query,
and_query,
has_id,
must_not,
has_type,
)
from icij_common.iter_utils import async_batches, before_and_after, once
from pydantic import TypeAdapter
Expand Down Expand Up @@ -50,7 +50,7 @@ async def translation_worker_config(self) -> TranslationWorkerConfig:

@activity_defn(name="translation.create_translation_batches")
async def create_translation_batches(
self, project: str, target: Language
self, project: str, query: dict[str, Any]
) -> list[tuple[Language, list[Batch]]]:
es_client = lifespan_es_client()
worker_config = cast(TranslationWorkerConfig, lifespan_worker_config())
Expand All @@ -59,8 +59,8 @@ async def create_translation_batches(
batches = [
b
async for b in create_translation_batches_act(
project=project,
target=target,
project,
query,
batch_text_length=batch_text_length,
es_client=es_client,
)
Expand Down Expand Up @@ -110,15 +110,15 @@ async def translate_docs(


async def create_translation_batches_act(
*,
project: str,
target: Language,
query: dict[str, Any],
batch_text_length: int = 1000000,
es_client: ESClient | None = None,
) -> AsyncGenerator[tuple[DatashareLanguage, list[Batch]], None]:
# Retrieve unprocessed docs.
es_docs = _get_es_docs(
es_client, project, target=target, source_includes=BATCHING_DOC_SOURCES
query = _with_doc_type(query)
es_docs = _get_es_docs_by_language(
es_client, project, query, source_includes=BATCHING_DOC_SOURCES
)
async for language_docs in es_docs:
language_batches: list[Batch] = []
Expand Down Expand Up @@ -277,17 +277,16 @@ async def _split_sentences(
yield es_doc, sentence


async def _get_es_docs(
async def _get_es_docs_by_language(
es_client: ESClient,
project: str,
target: Language,
query: dict[str, Any],
source_includes: list[str],
) -> AsyncGenerator[AsyncIterator[dict], None]:
# Get all documents that are not in the target language sorted by language
docs = _poll_from_es(
es_client,
project,
body=_untranslated_query(target),
body=query,
source_includes=source_includes,
sort=[f"{DOC_LANGUAGE}:asc", "_doc:asc"],
)
Expand Down Expand Up @@ -349,13 +348,6 @@ async def _update_docs_translation(
await async_bulk(es_client, actions, raise_on_error=True, refresh="wait_for")


def _untranslated_query(target: Language) -> dict:
query = bool_query(
must_not({TERM: {f"{DOC_CONTENT_TRANSLATED}.target_language.keyword": target}})
)
return query


async def _poll_from_es(
es_client: ESClient,
project: str,
Expand All @@ -375,6 +367,10 @@ def _has_language(doc: dict, language: str) -> bool:
return doc[SOURCE][DOC_LANGUAGE] == language


def _with_doc_type(query: dict[str, Any]) -> dict[str, Any]:
return and_query(query, has_type(type_field="type", type_value=ES_DOCUMENT_TYPE))


async def _publish_and_consume(
publisher: asyncio.Task,
publisher_completion_callback: Callable[[], None],
Expand Down
33 changes: 32 additions & 1 deletion workers/translation-worker/translation_worker/objects.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from abc import ABC
from enum import StrEnum
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar

from datashare_python.objects import (
BaseModel,
DatashareModel,
Language,
)
from icij_common.es import (
DOC_CONTENT_TRANSLATED,
QUERY,
TERM,
bool_query,
has_id,
must_not,
)
from icij_common.registrable import RegistrableConfig
from pydantic import Field

Expand Down Expand Up @@ -105,11 +113,34 @@ def to_translator(self) -> "Translator":
return Translator.from_config(self.translator)


DocumentSearchQuery = dict[str, Any]
DocId = str


def untranslated_query(target: Language) -> dict:
query = bool_query(
must_not({TERM: {f"{DOC_CONTENT_TRANSLATED}.target_language.keyword": target}})
)
return query[QUERY]


class TranslationArgs(DatashareModel):
project: str
docs: list[DocId] | DocumentSearchQuery | None = None
config: TranslationConfig = Field(default_factory=TranslationConfig)
target_language: Language

def as_query(self) -> dict[str, Any]:
match self.docs:
case None:
return untranslated_query(self.target_language)
case list():
return has_id(self.docs)
case dict():
return self.docs
case _:
raise ValueError(f"unsupported docs {self.docs}")


class TranslationResponse(DatashareModel):
n_translations: int = 0
2 changes: 1 addition & 1 deletion workers/translation-worker/translation_worker/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def run(self, args: TranslationArgs) -> TranslationResponse:
batches_per_worker = worker_config.batches_per_worker
# Create translation batches
target = args.target_language
translation_batch_args = [args.project, target]
translation_batch_args = [args.project, args.as_query()]
per_language_batches: list[tuple[str, list[list[str]]]]
per_language_batches = await execute_activity(
TranslationActivities.create_translation_batches,
Expand Down
Loading