Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/publish_asr_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,33 @@ on:
- 'asr-worker-*'

jobs:
create-release:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
- name: Create GH release
run: gh release create "$tag" --generate-notes
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
tag: ${{ github.ref_name }}
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: ${{ env.ASTRAL_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
enable-cache: true
working-directory: asr-worker
- name: Upload available models
run: |
uv run --frozen asr-worker models list-available > available-models.json
gh release upload "$tag" available-models.json
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
tag: ${{ github.ref_name }}

publish-io-worker:
runs-on: ubuntu-latest
steps:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/test_asr_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ jobs:

test:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
- name: Setup Python project
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/test_datashare_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:

test:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
- name: Setup Python project
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/test_translation_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:

test:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
- name: Setup Python project
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/test_worker_template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ jobs:

test:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
- name: Setup Python project
Expand Down
9 changes: 9 additions & 0 deletions asr-worker/asr_worker/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .cli import cli_app


def main() -> None:
cli_app()


if __name__ == "__main__":
main()
35 changes: 35 additions & 0 deletions asr-worker/asr_worker/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Annotated

import datashare_python
import typer
from datashare_python.cli import pretty_exc_callback, version_callback
from datashare_python.cli.utils import AsyncTyper
from icij_common.logging_utils import setup_loggers

import asr_worker

from .models import models_app

cli_app = AsyncTyper(
context_settings={"help_option_names": ["-h", "--help"]},
pretty_exceptions_enable=False,
)
cli_app.add_typer(models_app)


@cli_app.callback()
def main(
version: Annotated[ # noqa: ARG001
bool | None,
typer.Option("--version", callback=version_callback, is_eager=True),
] = None,
*,
pretty_exceptions: Annotated[ # noqa: ARG001
bool,
typer.Option(
"--pretty-exceptions", callback=pretty_exc_callback, is_eager=True
),
] = False,
) -> None:
"""Datashare Python CLI."""
setup_loggers(["__main__", datashare_python.__name__, asr_worker.__name__])
13 changes: 13 additions & 0 deletions asr-worker/asr_worker/cli/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ..models import available_models
from .utils import AsyncTyper

_MODELS = "models"

_LIST_AVAILABLE_MODELS_HELP = "list available ASR models by language"

models_app = AsyncTyper(name=_MODELS)


@models_app.async_command(help=_LIST_AVAILABLE_MODELS_HELP)
async def list_available() -> None:
print(available_models().model_dump_json(indent=2))
35 changes: 35 additions & 0 deletions asr-worker/asr_worker/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
import concurrent.futures
import sys
from collections.abc import Callable
from functools import wraps
from typing import Any

import typer


class AsyncTyper(typer.Typer):
def async_command(self, *args, **kwargs) -> Callable[[Callable], Callable]:
def decorator(async_func: Callable) -> Callable:
@wraps(async_func)
def sync_func(*_args, **_kwargs) -> Any:
res = asyncio.run(async_func(*_args, **_kwargs))
return res

self.command(*args, **kwargs)(sync_func)
return async_func

return decorator


def eprint(*args, **kwargs) -> None:
print(*args, file=sys.stderr, **kwargs)


def _to_concurrent(
fut: asyncio.Future, loop: asyncio.AbstractEventLoop
) -> concurrent.futures.Future:
async def wait() -> None:
await fut

return asyncio.run_coroutine_threadsafe(wait(), loop)
20 changes: 17 additions & 3 deletions asr-worker/asr_worker/models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import math
from collections import defaultdict
from functools import cache
from typing import Annotated, Any, Self

from caul.asr_pipeline import ASRPipelineConfig
from caul.config import InferenceRunnerConfig as CaulInferenceRunnerConfig
from caul.constant import ASRModel
from caul.objects import ASRResult
from caul.objects import ASRLanguage, ASRModel, ASRResult
from caul.tasks import (
ParakeetInferenceRunnerConfig,
ParakeetPostprocessorConfig,
ParakeetPreprocessorConfig,
)
from datashare_python.objects import DatashareModel
from icij_common.pydantic_utils import make_enum_discriminator, tagged_union
from pydantic import Discriminator, Field
from pydantic import Discriminator, Field, RootModel

model_discriminator = make_enum_discriminator("model", ASRModel)
InferenceRunnerConfig = Annotated[
Expand Down Expand Up @@ -75,3 +76,16 @@ def from_asr_handler_result(cls, asr_handler_result: ASRResult) -> Self:
if confidence is not None:
confidence = math.exp(asr_handler_result.score)
return Transcription(confidence=confidence, transcripts=transcripts)


AvailableModels = RootModel[dict[ASRLanguage, list[ASRModel]]]


@cache
def available_models() -> AvailableModels:
models = defaultdict(list)
for m in ASRModel:
for language in m.supported_languages():
models[language].append(m)
models = {k: sorted(v) for k, v in sorted(models.items())}
return AvailableModels.model_validate(models)
10 changes: 7 additions & 3 deletions asr-worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ dependencies = [
"datashare-python~=0.5.0",
]

[project.scripts]
asr-worker = "asr_worker.__main__:main"

[project.optional-dependencies]
cpu = [
"caul==0.2.16",
"caul==0.3.1",
"torch==2.10.0",
"torchaudio==2.10.0",
"torchcodec==0.10.0",
Expand All @@ -29,14 +32,14 @@ gpu = [
"torchcodec==0.10.0+cu128; sys_platform == 'linux'",
]
inference = [
"caul[nemo]==0.2.16",
"caul[nemo]==0.3.1",
"kaldialign==0.9.3",
"ml-dtypes==0.5.4",
"numpy==2.3.0",
"pyarrow==20.0.0",
]
preprocessing = [
"caul==0.2.16",
"caul==0.3.1",
]

[project.entry-points."datashare.workflows"]
Expand Down Expand Up @@ -91,6 +94,7 @@ torchcodec = [
]
datashare-python = { path = "../datashare-python", editable = true }


[tool.uv]
package = true
override-dependencies = [
Expand Down
17 changes: 17 additions & 0 deletions asr-worker/tests/cli/test_asr_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from asr_worker.cli import cli_app
from asr_worker.models import AvailableModels
from typer.testing import CliRunner


async def test_list_models(
typer_asyncio_patch, # noqa: ANN001, ARG001
) -> None:
# Given
runner = CliRunner(mix_stderr=False)
cmd = ["models", "list-available"]
# When
result = runner.invoke(cli_app, cmd, catch_exceptions=False)
# Then
assert int(result.exit_code) == 0
available_models = AvailableModels.model_validate_json(result.output)
assert isinstance(available_models.root, dict)
1 change: 1 addition & 0 deletions asr-worker/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
test_es_client_session,
test_temporal_client_session,
test_worker_config,
typer_asyncio_patch,
worker_lifetime_deps,
)
from datashare_python.dependencies import set_es_client, set_temporal_client
Expand Down
50 changes: 44 additions & 6 deletions asr-worker/uv.dist.lock
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ dev = [

[package.metadata]
requires-dist = [
{ name = "caul", marker = "extra == 'cpu'", specifier = "==0.2.16" },
{ name = "caul", marker = "extra == 'preprocessing'", specifier = "==0.2.16" },
{ name = "caul", extras = ["nemo"], marker = "extra == 'inference'", specifier = "==0.2.16" },
{ name = "caul", marker = "extra == 'cpu'", specifier = "==0.3.1" },
{ name = "caul", marker = "extra == 'preprocessing'", specifier = "==0.3.1" },
{ name = "caul", extras = ["nemo"], marker = "extra == 'inference'", specifier = "==0.3.1" },
{ name = "datashare-python", specifier = "~=0.5.0" },
{ name = "kaldialign", marker = "extra == 'inference'", specifier = "==0.9.3" },
{ name = "ml-dtypes", marker = "extra == 'inference'", specifier = "==0.5.4" },
Expand Down Expand Up @@ -329,21 +329,23 @@ wheels = [

[[package]]
name = "caul"
version = "0.2.16"
version = "0.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "icij-common", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "kaldialign", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "langcodes", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "numpy", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "pydantic-extra-types", extra = ["pycountry"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "torchaudio", version = "2.10.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(sys_platform == 'darwin' and extra == 'extra-10-asr-worker-cpu') or (extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "torchaudio", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra != 'extra-10-asr-worker-cpu' and extra != 'extra-10-asr-worker-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-asr-worker-gpu') or (sys_platform == 'darwin' and extra != 'extra-10-asr-worker-cpu') or (extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "torchaudio", version = "2.10.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-10-asr-worker-cpu') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "torchaudio", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'extra-10-asr-worker-gpu') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "torchcodec", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b6/79/1ef53340a2c1fcc6b7f9e28182d8142b9e2806a9d54aaf2fa399f602701d/caul-0.2.16.tar.gz", hash = "sha256:9f5ac14b2cd721d3ac5d2d7e08fe495c851cb61acc51d634414efae2f68995a7", size = 319415, upload-time = "2026-04-10T08:08:23.211Z" }
sdist = { url = "https://files.pythonhosted.org/packages/3d/5a/16c3b0cd208086d8ae6d2f607c1d0006031a5ffb604f3491c666bd9cecb0/caul-0.3.1.tar.gz", hash = "sha256:cf1570736e6fc562fffe83d5b0bef5063db45a3c6305b9b3992f30202bf4fe84", size = 326458, upload-time = "2026-04-17T09:54:56.658Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/42/8d6a1a6c42af5181f4ff12abf100052d01100f1be176808aab9e7930b0d6/caul-0.2.16-py3-none-any.whl", hash = "sha256:95f7d291e72dea39b56ce7c45d076dd8be0e80ee462f43840d296addea56a81e", size = 20621, upload-time = "2026-04-10T08:08:22.3Z" },
{ url = "https://files.pythonhosted.org/packages/41/7e/67082ea344cd09b814034271f66792d27e94b4a1882abe4cd29a60fe8922/caul-0.3.1-py3-none-any.whl", hash = "sha256:dcb76032c0b6358010b64fdfeba12eeab77606ee8ef048b47d321b59e959b69f", size = 21794, upload-time = "2026-04-17T09:54:55.258Z" },
]

[package.optional-dependencies]
Expand Down Expand Up @@ -1120,6 +1122,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/48/44/2b5b95b7aa39fb2d8d9d956e0f3d5d45aef2ae1d942d4c3ffac2f9cfed1a/kiwisolver-1.5.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be4a51a55833dc29ab5d7503e7bcb3b3af3402d266018137127450005cdfe737", size = 79892, upload-time = "2026-03-09T13:15:49.694Z" },
]

[[package]]
name = "langcodes"
version = "3.5.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a9/75/f9edc5d72945019312f359e69ded9f82392a81d49c5051ed3209b100c0d2/langcodes-3.5.1.tar.gz", hash = "sha256:40bff315e01b01d11c2ae3928dd4f5cbd74dd38f9bd912c12b9a3606c143f731", size = 191084, upload-time = "2025-12-02T16:22:01.627Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/dd/c1/d10b371bcba7abce05e2b33910e39c33cfa496a53f13640b7b8e10bb4d2b/langcodes-3.5.1-py3-none-any.whl", hash = "sha256:b6a9c25c603804e2d169165091d0cdb23934610524a21d226e4f463e8e958a72", size = 183050, upload-time = "2025-12-02T16:21:59.954Z" },
]

[[package]]
name = "lazy-loader"
version = "0.5"
Expand Down Expand Up @@ -2223,6 +2234,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/31/a9/dfb999c2fc6911201dcbf348247f9cc382a8990f9ab45c12eabfd7243a38/pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a", size = 44557216, upload-time = "2025-04-27T12:30:36.977Z" },
]

[[package]]
name = "pycountry"
version = "26.2.16"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/de/1d/061b9e7a48b85cfd69f33c33d2ef784a531c359399ad764243399673c8f5/pycountry-26.2.16.tar.gz", hash = "sha256:5b6027d453fcd6060112b951dd010f01f168b51b4bf8a1f1fc8c95c8d94a0801", size = 7711342, upload-time = "2026-02-17T03:42:52.367Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9c/42/7703bd45b62fecd44cd7d3495423097e2f7d28bc2e99e7c1af68892ab157/pycountry-26.2.16-py3-none-any.whl", hash = "sha256:115c4baf7cceaa30f59a4694d79483c9167dbce7a9de4d3d571c5f3ea77c305a", size = 8044600, upload-time = "2026-02-17T03:42:49.777Z" },
]

[[package]]
name = "pycparser"
version = "3.0"
Expand Down Expand Up @@ -2276,6 +2296,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/11/aa5089b941e85294b1d5d526840b18f0d4464f842d43d8999ce50ef881c1/pydantic_core-2.46.0-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2f7e6a3752378a69fadf3f5ee8bc5fa082f623703eec0f4e854b12c548322de0", size = 2365925, upload-time = "2026-04-13T09:05:38.338Z" },
]

[[package]]
name = "pydantic-extra-types"
version = "2.11.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
{ name = "typing-extensions", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" },
]

[package.optional-dependencies]
pycountry = [
{ name = "pycountry", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu') or sys_platform == 'darwin' or (sys_platform != 'linux' and extra == 'extra-10-asr-worker-cpu' and extra == 'extra-10-asr-worker-gpu')" },
]

[[package]]
name = "pydantic-settings"
version = "2.13.1"
Expand Down
Loading
Loading