Skip to content

Commit 68ec024

Browse files
committed
refactor: worker configuration dependency
1 parent 82624d6 commit 68ec024

24 files changed

Lines changed: 1877 additions & 499 deletions

.github/workflows/test_asr_worker.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
- name: Run tests
5151
run: |
5252
cd asr-worker
53-
uv sync --frozen --extra preprocessing --extra inference --dev
53+
uv sync --frozen --extra preprocessing --extra inference --extra cpu --dev
5454
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A
5555
5656
concurrency:

asr-worker/asr_worker/activities.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from contextlib import AbstractContextManager, contextmanager
33
from itertools import tee
44
from pathlib import Path
5-
from typing import Any
5+
from typing import Any, cast
66

77
from caul.objects import ASRResult, PreprocessedInput
88
from caul.tasks import (
@@ -12,6 +12,7 @@
1212
Postprocessor,
1313
Preprocessor,
1414
)
15+
from datashare_python.dependencies import lifespan_worker_config
1516
from datashare_python.types_ import ProgressRateHandler
1617
from datashare_python.utils import (
1718
ActivityWithProgress,
@@ -67,7 +68,7 @@ async def search_audios(
6768
self, project: str, query: dict[str, Any], batch_size: int
6869
) -> list[Path]:
6970
es_client = lifespan_es_client()
70-
worker_config = ASRWorkerConfig()
71+
worker_config = cast(ASRWorkerConfig, lifespan_worker_config())
7172
batch_dir_name = activity_contextual_id()
7273
workdir = worker_config.workdir
7374
batch_root = workdir / batch_dir_name
@@ -87,7 +88,7 @@ def preprocess(
8788
self, paths: list[Path] | Path, config: ParakeetPreprocessorConfig
8889
) -> list[Path]:
8990
# TODO: this shouldn't be necessary, fix this bug
90-
worker_config = ASRWorkerConfig()
91+
worker_config = cast(ASRWorkerConfig, lifespan_worker_config())
9192
audio_root = worker_config.audios_root
9293
workdir = worker_config.workdir
9394
# TODO: implement caching
@@ -120,7 +121,7 @@ def infer(
120121
) -> list[Path]:
121122
# TODO: fix this temporal by, we shouldn't have to reload
122123
config = _INFERENCE_CONFIG_TYPE_ADAPTER.validate_python(config)
123-
worker_config = ASRWorkerConfig()
124+
worker_config = cast(ASRWorkerConfig, lifespan_worker_config())
124125
workdir = worker_config.workdir
125126
preprocessed_inputs = _LIST_OF_PATH_ADAPTER.validate_python(preprocessed_inputs)
126127
if progress is not None:
@@ -168,7 +169,7 @@ def postprocess(
168169
# TODO: this shouldn't be necessary, fix this bug
169170
input_paths = _LIST_OF_PATH_ADAPTER.validate_python(input_paths)
170171
config = ParakeetPostprocessorConfig.model_validate(config)
171-
worker_config = ASRWorkerConfig()
172+
worker_config = cast(ASRWorkerConfig, lifespan_worker_config())
172173
workdir = worker_config.workdir
173174
artifacts_root = worker_config.artifacts_root
174175
inference_results = _LIST_OF_PATH_ADAPTER.validate_python(inference_results)

asr-worker/asr_worker/dependencies.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from datashare_python.dependencies import ( # noqa: F401
77
lifespan_es_client,
8+
lifespan_worker_config,
89
set_es_client,
10+
set_worker_config,
911
)
1012

1113

@@ -23,6 +25,7 @@ def set_multiprocessing_start_method(**_) -> Generator[None, None, None]:
2325

2426

2527
REGISTRY = {
26-
"inference": [set_multiprocessing_start_method],
27-
"io": [set_es_client],
28+
"inference": [set_worker_config, set_multiprocessing_start_method],
29+
"io": [set_worker_config, set_es_client],
30+
"preprocessing": [set_worker_config],
2831
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/bin/bash
22

33
uv run --no-sync datashare-python worker start \
4+
--dependencies inference \
45
--queue worker-template.classify-gpu \
56
--activities asr.transcription.infer

asr-worker/entrypoints/io_worker.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
set -e
33

44
uv run --no-sync datashare-python worker start \
5+
--dependencies io \
56
--queue asr.io \
67
--activities asr.transcription.search-audios \
7-
--workflows asr.transcription \
8-
--dependencies io
8+
--workflows asr.transcription

asr-worker/entrypoints/preprocessing_worker.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
set -e
33

44
uv run --no-sync datashare-python worker start \
5+
--dependencies preprocessing \
56
--queue asr.preprocessing.cpu \
67
--activities asr.transcription.preprocess

asr-worker/pyproject.toml

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@ dependencies = [
1616
]
1717

1818
[project.optional-dependencies]
19-
preprocessing = [
20-
"torchaudio==2.11.0",
21-
"torchcodec==0.11.0",
19+
cpu = [ # just a placeholder to create an extra marker used to switch to torch gpu build
20+
"torch==2.11.0",
21+
"torchaudio==2.11.0",
22+
"torchcodec==0.11.0",
23+
]
24+
gpu = [ # just a placeholder to create an extra marker used to switch to torch gpu build
25+
"torch==2.11.0+cu128",
26+
"torchaudio==2.11.0",
27+
"torchcodec==0.11.0",
2228
]
2329
inference = [
2430
"caul[nemo]==0.2.16",
2531
"kaldialign==0.9.3",
2632
"ml-dtypes==0.5.4",
2733
"numpy==2.3.0",
2834
"pyarrow==20.0.0",
29-
"torchaudio==2.11.0",
30-
"torch==2.11.0",
31-
"torchcodec==0.11.0",
3235
]
33-
gpu = [ # just a placeholder to create an extra marker used to switch to torch gpu build
36+
preprocessing = [
3437
]
3538

3639
[project.entry-points."datashare.workflows"]
@@ -69,22 +72,28 @@ explicit = true
6972

7073
[tool.uv.sources]
7174
torch = [
72-
{ index = "pytorch-gpu", marker = "extra == 'gpu'" },
73-
{ index = "pytorch-cpu", marker = "extra != 'gpu'"},
75+
{ index = "pytorch-gpu", extra = "gpu" },
76+
{ index = "pytorch-cpu", extra = "cpu"},
7477
]
7578
torchaudio = [
76-
{ index = "pytorch-gpu", marker = "extra == 'gpu'" },
77-
{ index = "pytorch-cpu", marker = "extra != 'gpu'"},
79+
{ index = "pytorch-gpu", extra = "gpu" },
80+
{ index = "pytorch-cpu", extra = "cpu"},
7881
]
7982
torchcodec = [
80-
{ index = "pytorch-gpu", marker = "extra == 'gpu'" },
81-
{ index = "pytorch-cpu", marker = "extra != 'gpu'"},
83+
{ index = "pytorch-gpu", extra = "gpu" },
84+
{ index = "pytorch-cpu", extra = "cpu"},
8285
]
8386
datashare-python = { path = "../datashare-python", editable = true }
8487

8588
[tool.uv]
8689
package = true
8790
override-dependencies = ["kaldialign>=0.9.3"]
91+
conflicts = [
92+
[
93+
{ extra = "cpu" },
94+
{ extra = "gpu" },
95+
],
96+
]
8897

8998
[tool.uv.pip]
9099
prerelease = "if-necessary"

asr-worker/tests/conftest.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import os
2-
from pathlib import Path
3-
41
import pytest
2+
from _pytest.tmpdir import TempPathFactory
53
from asr_worker.config import ASRWorkerConfig
64
from asr_worker.dependencies import set_multiprocessing_start_method
5+
from datashare_python.config import DatashareClientConfig, TemporalClientConfig
76
from datashare_python.conftest import ( # noqa: F401
87
TEST_PROJECT,
98
doc_3,
@@ -20,26 +19,30 @@
2019
from datashare_python.objects import Document
2120
from datashare_python.types_ import ContextManagerFactory
2221
from icij_common.es import ESClient
23-
from icij_common.test_utils import reset_env # noqa: F401
2422

2523

2624
@pytest.fixture(scope="session")
2725
def test_deps() -> list[ContextManagerFactory]:
2826
return [set_temporal_client, set_es_client, set_multiprocessing_start_method]
2927

3028

31-
@pytest.fixture
32-
def mocked_worker_config_in_env(reset_env, tmp_path: Path) -> ASRWorkerConfig: # noqa: ANN001, ARG001, F811
33-
audios_path = tmp_path / "audios"
34-
audios_path.mkdir()
35-
os.environ["DS_WORKER_AUDIOS_ROOT"] = str(audios_path)
36-
artifacts_path = tmp_path / "artifacts"
37-
artifacts_path.mkdir()
38-
os.environ["DS_WORKER_ARTIFACTS_ROOT"] = str(artifacts_path)
29+
@pytest.fixture(scope="session")
30+
def test_worker_config(tmp_path_factory: TempPathFactory) -> ASRWorkerConfig: # noqa: ANN001, ARG001, F811
31+
tmp_path = tmp_path_factory.mktemp("test-")
32+
audios_root = tmp_path / "audios"
33+
audios_root.mkdir()
34+
artifacts_root = tmp_path / "artifacts"
35+
artifacts_root.mkdir()
3936
workdir = tmp_path / "workdir"
4037
workdir.mkdir()
41-
os.environ["DS_WORKER_WORKDIR"] = str(workdir)
42-
return ASRWorkerConfig()
38+
return ASRWorkerConfig(
39+
log_level="DEBUG",
40+
datashare=DatashareClientConfig(url="http://localhost:8080"),
41+
temporal=TemporalClientConfig(host="localhost:7233"),
42+
audios_root=audios_root,
43+
artifacts_root=artifacts_root,
44+
workdir=workdir,
45+
)
4346

4447

4548
@pytest.fixture(scope="session")

0 commit comments

Comments
 (0)