diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0041f4..31e7023 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,17 +35,23 @@ jobs: - name: Cargo clippy run: cargo clippy --all-targets --all-features -- -D warnings + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + version: "0.10.12" + cache-dependency-glob: uv.lock + - name: Install Python dev deps - run: pip install -e ".[dev]" + run: uv sync --extra dev - name: Ruff check - run: ruff check py_src/ tests/ + run: uv run ruff check py_src/ tests/ - name: Ruff format check - run: ruff format --check py_src/ tests/ + run: uv run ruff format --check py_src/ tests/ - name: Mypy - run: mypy py_src/taskito/ + run: uv run mypy py_src/taskito/ tests/python/ --no-incremental rust-test: runs-on: ubuntu-latest @@ -70,6 +76,15 @@ jobs: env: LD_LIBRARY_PATH: ${{ env.pythonLocation }}/lib + - name: Check Rust (postgres features) + run: cargo check --workspace --features postgres + + - name: Check Rust (redis features) + run: cargo check --workspace --features redis + + - name: Check Rust (native-async features) + run: cargo check --workspace --features native-async + test: needs: lint runs-on: ${{ matrix.os }} @@ -102,32 +117,25 @@ jobs: with: save-if: ${{ matrix.os != 'ubuntu-latest' }} - - name: Create virtualenv (Unix) - if: runner.os != 'Windows' - run: | - python -m venv .venv - echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH - echo "VIRTUAL_ENV=${{ github.workspace }}/.venv" >> $GITHUB_ENV - - - name: Create virtualenv (Windows) - if: runner.os == 'Windows' - run: | - python -m venv .venv - echo "${{ github.workspace }}\.venv\Scripts" >> $env:GITHUB_PATH - echo "VIRTUAL_ENV=${{ github.workspace }}\.venv" >> $env:GITHUB_ENV + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + version: "0.10.12" + cache-dependency-glob: uv.lock - - name: Install maturin - run: pip install maturin + - name: Install dependencies + run: uv sync --extra dev - - name: 
Build and install - run: | - pip install -e ".[dev]" - maturin develop --release + - name: Build native extension + uses: PyO3/maturin-action@v1 + with: + command: develop + args: --release --features native-async - name: Run Python tests run: | set +e - pytest tests/python/ -v --junitxml=test-results.xml + uv run python -m pytest tests/python/ -v --junitxml=test-results.xml PYTEST_EXIT=$? if [ $PYTEST_EXIT -eq 0 ]; then exit 0; fi # SIGABRT (134) during interpreter shutdown is a known PyO3 issue; @@ -139,7 +147,7 @@ jobs: print(int(r.get('failures',0)) + int(r.get('errors',0))) ") if [ "$FAILURES" = "0" ]; then - echo "::warning::Tests passed but process crashed during cleanup (known PyO3 issue on this Python version)" + echo "::warning::Tests passed but process crashed during cleanup (known PyO3 issue)" exit 0 fi fi diff --git a/docs/guide/testing.md b/docs/guide/testing.md index be1fc8f..55dd671 100644 --- a/docs/guide/testing.md +++ b/docs/guide/testing.md @@ -307,6 +307,51 @@ assert results[0].succeeded | `wraps` | `Any` | Wrap a real object — returned as-is when accessed. | | `track_calls` | `bool` | Increment `call_count` each access. 
| +#### `return_value` vs `wraps` + +Use `return_value` when you want a simple stub: + +```python +mock_cache = MockResource("cache", return_value={"key": "value"}) +``` + +Use `wraps` when you need the real object but want call tracking: + +```python +real_db = create_test_database() +spy_db = MockResource("db", wraps=real_db, track_calls=True) +``` + +#### Multiple resources + +Pass multiple resources to `test_mode`: + +```python +with queue.test_mode(resources={ + "db": MockResource("db", return_value=mock_db), + "cache": MockResource("cache", return_value={}), + "mailer": MockResource("mailer", return_value=mock_smtp), +}) as results: + process_order.delay(order_id=123) +``` + +#### Testing with `inject` + +Tasks that use `@queue.task(inject=["db"])` receive the mock resource automatically: + +```python +@queue.task(inject=["db"]) +def create_user(name, db=None): + db.execute("INSERT INTO users (name) VALUES (?)", (name,)) + +mock_db = MagicMock() +with queue.test_mode(resources={"db": mock_db}) as results: + create_user.delay("Alice") + +assert results[0].succeeded +mock_db.execute.assert_called_once() +``` + !!! note When `resources=` is provided, proxy reconstruction is bypassed automatically. Proxy markers in arguments are passed through as-is so tests don't fail due to missing files or network connections. @@ -344,3 +389,26 @@ def test_e2e(): !!! info "Middleware in test mode" Per-task and queue-level `TaskMiddleware` hooks (`before`, `after`, `on_retry`) **do fire** in test mode, since they run in the Python wrapper around your task function. This lets you verify middleware behavior in tests without running a real worker. 
+ +## Running Tests Locally + +```bash +# Rust tests +cargo test --workspace + +# Rebuild the Python extension after Rust changes +uv run maturin develop + +# Python tests +uv run python -m pytest tests/python/ -v + +# Linting +uv run ruff check py_src/ tests/ +uv run mypy py_src/taskito/ --no-incremental +``` + +To build with native async support: + +```bash +uv run maturin develop --features native-async +``` diff --git a/tests/python/conftest.py b/tests/python/conftest.py index 3bfa4df..5042703 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -1,6 +1,8 @@ """Shared fixtures for taskito tests.""" import threading +from collections.abc import Generator +from pathlib import Path import pytest @@ -8,14 +10,14 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: """Create a fresh queue with a temp database.""" db_path = str(tmp_path / "test.db") return Queue(db_path=db_path, workers=2) @pytest.fixture -def run_worker(queue): +def run_worker(queue: Queue) -> Generator[threading.Thread]: """Start a worker thread for the given queue. 
Stops automatically at teardown.""" thread = threading.Thread(target=queue.run_worker, daemon=True) thread.start() diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index ca7acc3..9a42053 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -2,23 +2,25 @@ import threading +from taskito import Queue -def test_task_registration(queue): + +def test_task_registration(queue: Queue) -> None: """Tasks can be registered with the decorator.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b assert add.name.endswith("add") assert add.name in queue._task_registry -def test_enqueue_returns_job_result(queue): +def test_enqueue_returns_job_result(queue: Queue) -> None: """Enqueueing a task returns a JobResult handle.""" @queue.task() - def noop(): + def noop() -> None: pass job = noop.delay() @@ -26,32 +28,32 @@ def noop(): assert len(job.id) > 0 -def test_task_direct_call(queue): +def test_task_direct_call(queue: Queue) -> None: """Decorated tasks can still be called directly.""" @queue.task() - def multiply(a, b): + def multiply(a: int, b: int) -> int: return a * b assert multiply(3, 4) == 12 -def test_apply_async_with_delay(queue): +def test_apply_async_with_delay(queue: Queue) -> None: """apply_async accepts a delay parameter.""" @queue.task() - def slow_task(): + def slow_task() -> None: pass job = slow_task.apply_async(delay=60) assert job.id is not None -def test_apply_async_with_overrides(queue): +def test_apply_async_with_overrides(queue: Queue) -> None: """apply_async can override default task settings.""" @queue.task(priority=1, queue="default") - def configurable_task(x): + def configurable_task(x: int) -> int: return x job = configurable_task.apply_async( @@ -64,11 +66,11 @@ def configurable_task(x): assert job.id is not None -def test_queue_stats(queue): +def test_queue_stats(queue: Queue) -> None: """stats() returns counts by status.""" @queue.task() - def sample_task(): + def sample_task() -> 
None: pass sample_task.delay() @@ -79,11 +81,11 @@ def sample_task(): assert stats["running"] == 0 -def test_worker_executes_task(queue): +def test_worker_executes_task(queue: Queue) -> None: """Worker processes tasks and stores results.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b job = add.delay(2, 3) @@ -100,11 +102,11 @@ def add(a, b): assert result == 5 -def test_worker_handles_kwargs(queue): +def test_worker_handles_kwargs(queue: Queue) -> None: """Worker correctly passes keyword arguments.""" @queue.task() - def greet(name, greeting="Hello"): + def greet(name: str, greeting: str = "Hello") -> str: return f"{greeting}, {name}!" job = greet.delay("World", greeting="Hi") @@ -119,11 +121,11 @@ def greet(name, greeting="Hello"): assert result == "Hi, World!" -def test_worker_none_result(queue): +def test_worker_none_result(queue: Queue) -> None: """Tasks returning None work correctly.""" @queue.task() - def void_task(): + def void_task() -> None: pass job = void_task.delay() diff --git a/tests/python/test_batch.py b/tests/python/test_batch.py index f13e3eb..d38b8fb 100644 --- a/tests/python/test_batch.py +++ b/tests/python/test_batch.py @@ -2,12 +2,14 @@ import threading +from taskito import Queue -def test_enqueue_many(queue): + +def test_enqueue_many(queue: Queue) -> None: """enqueue_many enqueues all items in a single batch.""" @queue.task() - def double(x): + def double(x: int) -> int: return x * 2 jobs = queue.enqueue_many( @@ -20,11 +22,11 @@ def double(x): assert stats["pending"] == 10 -def test_task_map(queue): +def test_task_map(queue: Queue) -> None: """TaskWrapper.map() enqueues and returns results.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b jobs = add.map([(1, 2), (3, 4), (5, 6)]) @@ -37,11 +39,11 @@ def add(a, b): assert sorted(results) == [3, 7, 11] -def test_batch_stats(queue): +def test_batch_stats(queue: Queue) -> None: """Batch enqueue of 50 items shows correct pending 
count.""" @queue.task() - def noop(): + def noop() -> None: pass queue.enqueue_many( diff --git a/tests/python/test_cancel.py b/tests/python/test_cancel.py index 5fe5048..42c09e3 100644 --- a/tests/python/test_cancel.py +++ b/tests/python/test_cancel.py @@ -4,36 +4,39 @@ import threading +from taskito import Queue -def test_cancel_pending_job(queue): + +def test_cancel_pending_job(queue: Queue) -> None: """A pending job can be cancelled.""" @queue.task() - def slow_task(): + def slow_task() -> str: return "done" job = slow_task.delay() assert queue.cancel_job(job.id) is True refreshed = queue.get_job(job.id) + assert refreshed is not None assert refreshed.status == "cancelled" -def test_cancel_nonexistent_job(queue): +def test_cancel_nonexistent_job(queue: Queue) -> None: """Cancelling a nonexistent job returns False.""" @queue.task() - def dummy(): + def dummy() -> None: pass assert queue.cancel_job("nonexistent-id") is False -def test_cancel_completed_job(queue): +def test_cancel_completed_job(queue: Queue) -> None: """Cancelling a completed job returns False (only pending can be cancelled).""" @queue.task() - def quick_task(): + def quick_task() -> int: return 42 job = quick_task.delay() @@ -45,11 +48,11 @@ def quick_task(): assert queue.cancel_job(job.id) is False -def test_cancelled_in_stats(queue): +def test_cancelled_in_stats(queue: Queue) -> None: """Cancelled jobs show up in stats.""" @queue.task() - def task_a(): + def task_a() -> None: pass job = task_a.delay() diff --git a/tests/python/test_chain.py b/tests/python/test_chain.py index d93e66f..a99a862 100644 --- a/tests/python/test_chain.py +++ b/tests/python/test_chain.py @@ -3,6 +3,7 @@ from __future__ import annotations import threading +from pathlib import Path import pytest @@ -10,7 +11,7 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: db_path = str(tmp_path / "test_chain.db") q = Queue(db_path=db_path, workers=4) @@ -21,15 +22,15 @@ def queue(tmp_path): return q -def 
test_chain_executes_in_order(queue): +def test_chain_executes_in_order(queue: Queue) -> None: """chain pipes results through signatures in order.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b @queue.task() - def double(x): + def double(x: int) -> int: return x * 2 result = chain(add.s(2, 3), double.s()) @@ -37,15 +38,15 @@ def double(x): assert last_job.result(timeout=30) == 10 # (2+3) * 2 = 10 -def test_chain_with_immutable(queue): +def test_chain_with_immutable(queue: Queue) -> None: """si() signatures ignore previous results.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b @queue.task() - def constant(): + def constant() -> int: return 99 result = chain(add.s(1, 2), constant.si()) @@ -53,11 +54,11 @@ def constant(): assert last_job.result(timeout=30) == 99 -def test_group_parallel(queue): +def test_group_parallel(queue: Queue) -> None: """group enqueues tasks in parallel.""" @queue.task() - def square(x): + def square(x: int) -> int: return x * x jobs = group(square.s(2), square.s(3), square.s(4)).apply(queue) @@ -65,15 +66,15 @@ def square(x): assert sorted(results) == [4, 9, 16] -def test_chord_callback(queue): +def test_chord_callback(queue: Queue) -> None: """chord runs group, then callback with collected results.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b @queue.task() - def total(results): + def total(results: list[int]) -> int: return sum(results) grp = group(add.s(1, 2), add.s(3, 4), add.s(5, 6)) diff --git a/tests/python/test_cli.py b/tests/python/test_cli.py index 690ab32..174e079 100644 --- a/tests/python/test_cli.py +++ b/tests/python/test_cli.py @@ -1,23 +1,25 @@ """Tests for CLI info command.""" +from pathlib import Path + import pytest from taskito.cli import _load_queue, _print_stats -def test_load_queue_invalid_format(): +def test_load_queue_invalid_format() -> None: """_load_queue rejects paths without a colon.""" with 
pytest.raises(SystemExit): _load_queue("no_colon_here") -def test_load_queue_missing_module(): +def test_load_queue_missing_module() -> None: """_load_queue exits on missing module.""" with pytest.raises(SystemExit): _load_queue("nonexistent.module:queue") -def test_print_stats_format(capsys, tmp_path): +def test_print_stats_format(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None: """_print_stats prints a formatted stats table.""" from taskito import Queue @@ -25,7 +27,7 @@ def test_print_stats_format(capsys, tmp_path): queue = Queue(db_path=db_path) @queue.task() - def noop(): + def noop() -> None: pass noop.delay() diff --git a/tests/python/test_context.py b/tests/python/test_context.py index 08f497c..b76f284 100644 --- a/tests/python/test_context.py +++ b/tests/python/test_context.py @@ -1,6 +1,8 @@ """Tests for job context — current_job inside running tasks.""" import threading +from pathlib import Path +from typing import Any import pytest @@ -9,23 +11,23 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: db_path = str(tmp_path / "test_context.db") return Queue(db_path=db_path, workers=1) -def test_current_job_raises_outside_task(): +def test_current_job_raises_outside_task() -> None: """current_job properties raise RuntimeError outside a task.""" with pytest.raises(RuntimeError, match="No active job context"): _ = current_job.id -def test_current_job_id_available_in_task(queue): +def test_current_job_id_available_in_task(queue: Queue) -> None: """current_job.id is accessible inside a running task.""" - captured = {} + captured: dict[str, Any] = {} @queue.task() - def capture_context(): + def capture_context() -> str: captured["id"] = current_job.id captured["task_name"] = current_job.task_name captured["retry_count"] = current_job.retry_count @@ -45,11 +47,11 @@ def capture_context(): assert captured["queue_name"] == "default" -def test_current_job_update_progress(queue): +def test_current_job_update_progress(queue: 
Queue) -> None: """current_job.update_progress() works inside a running task.""" @queue.task() - def task_with_progress(): + def task_with_progress() -> str: current_job.update_progress(50) current_job.update_progress(100) return "done" diff --git a/tests/python/test_contrib.py b/tests/python/test_contrib.py index bbc607f..95e25e7 100644 --- a/tests/python/test_contrib.py +++ b/tests/python/test_contrib.py @@ -2,6 +2,8 @@ from __future__ import annotations +import types +from typing import Any from unittest.mock import MagicMock, patch # ── Helpers ────────────────────────────────────────────────────────── @@ -100,7 +102,7 @@ def test_after_records_exception_on_error(self) -> None: mock_span.end.assert_called_once() -def _try_import_otel(): # type: ignore[no-untyped-def] +def _try_import_otel() -> types.ModuleType | None: """Import otel module with mocked opentelemetry if not installed.""" try: import sys @@ -186,7 +188,7 @@ def test_after_captures_exception_on_error(self) -> None: mock_sdk.capture_exception.assert_called_once_with(exc) -def _try_import_sentry(): # type: ignore[no-untyped-def] +def _try_import_sentry() -> types.ModuleType | None: try: import sys @@ -205,7 +207,7 @@ def _try_import_sentry(): # type: ignore[no-untyped-def] # ── Prometheus ─────────────────────────────────────────────────────── -def _make_mock_metrics() -> dict: +def _make_mock_metrics() -> dict[str, Any]: """Create a mock metrics dict matching the instance-based store format.""" return { "jobs_total": MagicMock(), @@ -292,7 +294,7 @@ def test_after_tracks_failure(self) -> None: metrics["jobs_total"].labels.assert_called_with(task="my_task", status="failed") -def _try_import_prometheus(): # type: ignore[no-untyped-def] +def _try_import_prometheus() -> types.ModuleType | None: try: import sys diff --git a/tests/python/test_dashboard.py b/tests/python/test_dashboard.py index 3704011..81c1754 100644 --- a/tests/python/test_dashboard.py +++ b/tests/python/test_dashboard.py @@ -4,6 +4,9 @@ 
import threading import urllib.error import urllib.request +from collections.abc import Generator +from pathlib import Path +from typing import Any import pytest @@ -11,52 +14,52 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: """Create a fresh queue with some test data pre-registered.""" db_path = str(tmp_path / "test_dashboard.db") q = Queue(db_path=db_path, workers=2) @q.task(queue="default") - def task_a(x): + def task_a(x: int) -> int: return x * 2 @q.task(queue="email") - def task_b(x): + def task_b(x: int) -> int: return x + 1 return q @pytest.fixture -def populated_queue(queue): +def populated_queue(queue: Queue) -> tuple[Queue, list[Any]]: """Queue with several jobs enqueued.""" - task_a = None - task_b = None + task_a_name: str = "" + task_b_name: str = "" for name, _fn in queue._task_registry.items(): if "task_a" in name: - task_a = name + task_a_name = name elif "task_b" in name: - task_b = name + task_b_name = name - jobs = [] + jobs: list[Any] = [] for i in range(5): - jobs.append(queue.enqueue(task_a, args=(i,))) + jobs.append(queue.enqueue(task_a_name, args=(i,))) for i in range(3): - jobs.append(queue.enqueue(task_b, args=(i,), queue="email")) + jobs.append(queue.enqueue(task_b_name, args=(i,), queue="email")) return queue, jobs # ── list_jobs tests ────────────────────────────────────── -def test_list_jobs_returns_all(populated_queue): +def test_list_jobs_returns_all(populated_queue: tuple[Queue, list[Any]]) -> None: """list_jobs() with no filters returns all jobs.""" queue, _ = populated_queue result = queue.list_jobs() assert len(result) == 8 -def test_list_jobs_filter_by_queue(populated_queue): +def test_list_jobs_filter_by_queue(populated_queue: tuple[Queue, list[Any]]) -> None: """list_jobs() can filter by queue name.""" queue, _ = populated_queue result = queue.list_jobs(queue="email") @@ -66,7 +69,7 @@ def test_list_jobs_filter_by_queue(populated_queue): assert d["queue"] == "email" -def 
test_list_jobs_filter_by_status(populated_queue): +def test_list_jobs_filter_by_status(populated_queue: tuple[Queue, list[Any]]) -> None: """list_jobs() can filter by status.""" queue, _ = populated_queue result = queue.list_jobs(status="pending") @@ -76,7 +79,7 @@ def test_list_jobs_filter_by_status(populated_queue): assert len(result) == 0 -def test_list_jobs_filter_by_task_name(populated_queue): +def test_list_jobs_filter_by_task_name(populated_queue: tuple[Queue, list[Any]]) -> None: """list_jobs() can filter by task name.""" queue, _ = populated_queue # Find the task_a name @@ -90,7 +93,7 @@ def test_list_jobs_filter_by_task_name(populated_queue): assert len(result) == 5 -def test_list_jobs_pagination(populated_queue): +def test_list_jobs_pagination(populated_queue: tuple[Queue, list[Any]]) -> None: """list_jobs() respects limit and offset.""" queue, _ = populated_queue page1 = queue.list_jobs(limit=3, offset=0) @@ -104,7 +107,7 @@ def test_list_jobs_pagination(populated_queue): assert ids1.isdisjoint(ids2) -def test_list_jobs_invalid_status(queue): +def test_list_jobs_invalid_status(queue: Queue) -> None: """list_jobs() raises on invalid status string.""" with pytest.raises(ValueError): queue.list_jobs(status="bogus") @@ -113,11 +116,11 @@ def test_list_jobs_invalid_status(queue): # ── to_dict tests ──────────────────────────────────────── -def test_to_dict_fields(queue): +def test_to_dict_fields(queue: Queue) -> None: """to_dict() returns all expected fields.""" @queue.task() - def dummy(): + def dummy() -> None: pass job = dummy.delay() @@ -146,11 +149,11 @@ def dummy(): assert d["id"] == job.id -def test_to_dict_is_json_serializable(queue): +def test_to_dict_is_json_serializable(queue: Queue) -> None: """to_dict() output can be serialized to JSON.""" @queue.task() - def dummy(): + def dummy() -> None: pass job = dummy.delay() @@ -163,7 +166,9 @@ def dummy(): @pytest.fixture -def dashboard_server(populated_queue): +def dashboard_server( + populated_queue: 
tuple[Queue, list[Any]], +) -> Generator[tuple[str, Queue, list[Any]]]: """Start a dashboard server on a random port.""" queue, jobs = populated_queue from http.server import ThreadingHTTPServer @@ -182,20 +187,20 @@ def dashboard_server(populated_queue): server.shutdown() -def _get(url): +def _get(url: str) -> Any: """GET request and parse JSON.""" with urllib.request.urlopen(url) as resp: return json.loads(resp.read()) -def _post(url): +def _post(url: str) -> Any: """POST request and parse JSON.""" req = urllib.request.Request(url, method="POST", data=b"") with urllib.request.urlopen(req) as resp: return json.loads(resp.read()) -def test_api_stats(dashboard_server): +def test_api_stats(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/stats returns valid stats dict.""" base, _, __ = dashboard_server data = _get(f"{base}/api/stats") @@ -203,7 +208,7 @@ def test_api_stats(dashboard_server): assert data["pending"] == 8 -def test_api_jobs_list(dashboard_server): +def test_api_jobs_list(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/jobs returns job list.""" base, _, __ = dashboard_server data = _get(f"{base}/api/jobs") @@ -211,7 +216,7 @@ def test_api_jobs_list(dashboard_server): assert len(data) == 8 -def test_api_jobs_filter_status(dashboard_server): +def test_api_jobs_filter_status(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/jobs?status=pending filters correctly.""" base, _, __ = dashboard_server data = _get(f"{base}/api/jobs?status=pending") @@ -221,21 +226,21 @@ def test_api_jobs_filter_status(dashboard_server): assert len(data) == 0 -def test_api_jobs_filter_queue(dashboard_server): +def test_api_jobs_filter_queue(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/jobs?queue=email filters correctly.""" base, _, __ = dashboard_server data = _get(f"{base}/api/jobs?queue=email") assert len(data) == 3 -def test_api_jobs_pagination(dashboard_server): +def 
test_api_jobs_pagination(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/jobs?limit=3&offset=0 paginates.""" base, _, __ = dashboard_server data = _get(f"{base}/api/jobs?limit=3&offset=0") assert len(data) == 3 -def test_api_job_detail(dashboard_server): +def test_api_job_detail(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/jobs/{id} returns job dict.""" base, _, jobs = dashboard_server job_id = jobs[0].id @@ -244,7 +249,7 @@ def test_api_job_detail(dashboard_server): assert "status" in data -def test_api_job_not_found(dashboard_server): +def test_api_job_not_found(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/jobs/nonexistent returns 404.""" base, _, __ = dashboard_server try: @@ -254,7 +259,7 @@ def test_api_job_not_found(dashboard_server): assert e.code == 404 -def test_api_cancel_job(dashboard_server): +def test_api_cancel_job(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """POST /api/jobs/{id}/cancel cancels a pending job.""" base, _, jobs = dashboard_server job_id = jobs[0].id @@ -262,14 +267,14 @@ def test_api_cancel_job(dashboard_server): assert data["cancelled"] is True -def test_api_dead_letters_empty(dashboard_server): +def test_api_dead_letters_empty(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET /api/dead-letters returns empty list initially.""" base, _, __ = dashboard_server data = _get(f"{base}/api/dead-letters") assert data == [] -def test_spa_html_served(dashboard_server): +def test_spa_html_served(dashboard_server: tuple[str, Queue, list[Any]]) -> None: """GET / returns the SPA HTML.""" base, _, __ = dashboard_server with urllib.request.urlopen(base) as resp: diff --git a/tests/python/test_dependencies.py b/tests/python/test_dependencies.py index 8bf8c37..d5fa75d 100644 --- a/tests/python/test_dependencies.py +++ b/tests/python/test_dependencies.py @@ -4,12 +4,14 @@ import pytest +from taskito import Queue -def test_enqueue_with_depends_on(queue): + +def 
test_enqueue_with_depends_on(queue: Queue) -> None: """Jobs can declare dependencies on other jobs.""" @queue.task() - def step(x): + def step(x: int) -> int: return x job_a = step.delay(1) @@ -19,11 +21,11 @@ def step(x): assert job_a.dependents == [job_b.id] -def test_enqueue_with_multiple_deps(queue): +def test_enqueue_with_multiple_deps(queue: Queue) -> None: """Jobs can depend on multiple other jobs.""" @queue.task() - def step(x): + def step(x: int) -> int: return x job_a = step.delay(1) @@ -34,11 +36,11 @@ def step(x): assert deps == {job_a.id, job_b.id} -def test_depends_on_string_coercion(queue): +def test_depends_on_string_coercion(queue: Queue) -> None: """depends_on accepts a single string ID.""" @queue.task() - def step(x): + def step(x: int) -> int: return x job_a = step.delay(1) @@ -52,11 +54,11 @@ def step(x): assert job_b.dependencies == [job_a.id] -def test_dependency_blocks_execution(queue): +def test_dependency_blocks_execution(queue: Queue) -> None: """Dependent job waits until dependency completes.""" @queue.task() - def step(x): + def step(x: int) -> int: return x * 10 job_a = step.delay(1) @@ -73,11 +75,11 @@ def step(x): assert result_b == 20 -def test_cascade_cancel_on_job_cancel(queue): +def test_cascade_cancel_on_job_cancel(queue: Queue) -> None: """Cancelling a job cascades to its dependents.""" @queue.task() - def step(x): + def step(x: int) -> int: return x job_a = step.delay(1) @@ -92,11 +94,11 @@ def step(x): assert job_c.status == "cancelled" -def test_no_dependencies_property_when_none(queue): +def test_no_dependencies_property_when_none(queue: Queue) -> None: """Jobs without dependencies return empty list.""" @queue.task() - def step(x): + def step(x: int) -> int: return x job = step.delay(1) @@ -104,11 +106,11 @@ def step(x): assert job.dependents == [] -def test_enqueue_rejects_missing_dependency(queue): +def test_enqueue_rejects_missing_dependency(queue: Queue) -> None: """Enqueuing with a nonexistent dependency raises an 
error.""" @queue.task() - def step(x): + def step(x: int) -> int: return x with pytest.raises(RuntimeError): diff --git a/tests/python/test_dlq.py b/tests/python/test_dlq.py index 1ff7318..4ce14c0 100644 --- a/tests/python/test_dlq.py +++ b/tests/python/test_dlq.py @@ -3,18 +3,20 @@ import threading import time +from taskito import Queue -def test_dead_letters_empty(queue): + +def test_dead_letters_empty(queue: Queue) -> None: """Empty DLQ returns empty list.""" dead = queue.dead_letters() assert dead == [] -def test_purge_dead(queue): +def test_purge_dead(queue: Queue) -> None: """Purging removes old dead letter entries.""" @queue.task(max_retries=0, retry_backoff=0.1) - def instant_fail(): + def instant_fail() -> None: raise RuntimeError("fail") instant_fail.delay() diff --git a/tests/python/test_events.py b/tests/python/test_events.py index f1ab61e..e8277fb 100644 --- a/tests/python/test_events.py +++ b/tests/python/test_events.py @@ -1,13 +1,14 @@ """Tests for EventBus event dispatch.""" import time +from typing import Any from taskito.events import EventBus, EventType -def test_callback_receives_event(): +def test_callback_receives_event() -> None: """Registered callbacks receive emitted events.""" - received = [] + received: list[tuple[EventType, dict[str, Any]]] = [] bus = EventBus() bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append((et, p))) @@ -19,7 +20,7 @@ def test_callback_receives_event(): assert received[0][1]["job_id"] == "123" -def test_multiple_callbacks(): +def test_multiple_callbacks() -> None: """Multiple callbacks for the same event type all fire.""" counts = {"a": 0, "b": 0} bus = EventBus() @@ -33,9 +34,9 @@ def test_multiple_callbacks(): assert counts["b"] == 1 -def test_event_filtering(): +def test_event_filtering() -> None: """Callbacks only fire for their registered event type.""" - received = [] + received: list[str] = [] bus = EventBus() bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append("completed")) @@ -45,15 
+46,15 @@ def test_event_filtering(): assert received == [] -def test_exception_in_callback_does_not_crash(): +def test_exception_in_callback_does_not_crash() -> None: """A raising callback doesn't prevent other events from processing.""" - results = [] + results: list[str] = [] bus = EventBus() - def bad_callback(et, p): + def bad_callback(et: EventType, p: dict[str, Any]) -> None: raise RuntimeError("callback error") - def good_callback(et, p): + def good_callback(et: EventType, p: dict[str, Any]) -> None: results.append("ok") bus.on(EventType.JOB_ENQUEUED, bad_callback) @@ -65,13 +66,13 @@ def good_callback(et, p): assert results == ["ok"] -def test_emit_with_no_listeners(): +def test_emit_with_no_listeners() -> None: """Emitting an event with no listeners doesn't raise.""" bus = EventBus() bus.emit(EventType.JOB_DEAD, {"job_id": "456"}) -def test_all_event_types_exist(): +def test_all_event_types_exist() -> None: """All expected event types are defined.""" expected = { "job.enqueued", diff --git a/tests/python/test_fastapi.py b/tests/python/test_fastapi.py index 49fefc5..fdbcc1f 100644 --- a/tests/python/test_fastapi.py +++ b/tests/python/test_fastapi.py @@ -1,6 +1,7 @@ """Tests for FastAPI integration.""" import threading +from typing import Any import pytest @@ -11,11 +12,12 @@ from fastapi import FastAPI # noqa: E402 from fastapi.testclient import TestClient # noqa: E402 +from taskito import Queue # noqa: E402 from taskito.contrib.fastapi import TaskitoRouter # noqa: E402 @pytest.fixture -def app(queue): +def app(queue: Queue) -> FastAPI: """Create a FastAPI app with TaskitoRouter.""" app = FastAPI() app.include_router(TaskitoRouter(queue), prefix="/tasks") @@ -23,17 +25,17 @@ def app(queue): @pytest.fixture -def client(app): +def client(app: FastAPI) -> TestClient: """Create a TestClient.""" return TestClient(app) @pytest.fixture -def populated(queue, client): +def populated(queue: Queue, client: TestClient) -> tuple[Queue, TestClient, list[Any], Any]: 
"""Queue with a task and some jobs.""" @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b jobs = [add.delay(i, i + 1) for i in range(5)] @@ -43,7 +45,7 @@ def add(a, b): # ── Stats ──────────────────────────────────────────────── -def test_stats(populated): +def test_stats(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: _queue, client, _jobs, _add = populated resp = client.get("/tasks/stats") assert resp.status_code == 200 @@ -55,7 +57,7 @@ def test_stats(populated): # ── Job detail ─────────────────────────────────────────── -def test_get_job(populated): +def test_get_job(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: _queue, client, jobs, _add = populated job_id = jobs[0].id resp = client.get(f"/tasks/jobs/{job_id}") @@ -65,7 +67,7 @@ def test_get_job(populated): assert data["status"] == "pending" -def test_get_job_not_found(client): +def test_get_job_not_found(client: TestClient) -> None: resp = client.get("/tasks/jobs/nonexistent") assert resp.status_code == 404 @@ -73,7 +75,7 @@ def test_get_job_not_found(client): # ── Job errors ─────────────────────────────────────────── -def test_get_job_errors_empty(populated): +def test_get_job_errors_empty(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: _queue, client, jobs, _add = populated job_id = jobs[0].id resp = client.get(f"/tasks/jobs/{job_id}/errors") @@ -84,7 +86,7 @@ def test_get_job_errors_empty(populated): # ── Job result ─────────────────────────────────────────── -def test_get_job_result_pending(populated): +def test_get_job_result_pending(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: _queue, client, jobs, _add = populated job_id = jobs[0].id resp = client.get(f"/tasks/jobs/{job_id}/result") @@ -94,7 +96,7 @@ def test_get_job_result_pending(populated): assert data["result"] is None -def test_get_job_result_completed(populated): +def test_get_job_result_completed(populated: tuple[Queue, TestClient, list[Any], Any]) -> 
None: queue, client, jobs, _add = populated worker = threading.Thread(target=queue.run_worker, daemon=True) @@ -113,7 +115,7 @@ def test_get_job_result_completed(populated): # ── Cancel ─────────────────────────────────────────────── -def test_cancel_job(populated): +def test_cancel_job(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: _queue, client, jobs, _add = populated job_id = jobs[0].id resp = client.post(f"/tasks/jobs/{job_id}/cancel") @@ -128,7 +130,7 @@ def test_cancel_job(populated): # ── Dead letters ───────────────────────────────────────── -def test_dead_letters_empty(client): +def test_dead_letters_empty(client: TestClient) -> None: resp = client.get("/tasks/dead-letters") assert resp.status_code == 200 assert resp.json() == [] @@ -137,7 +139,7 @@ def test_dead_letters_empty(client): # ── Progress SSE ───────────────────────────────────────── -def test_progress_stream(populated): +def test_progress_stream(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: queue, client, jobs, _add = populated # Start worker so the job completes @@ -150,7 +152,7 @@ def test_progress_stream(populated): with client.stream("GET", f"/tasks/jobs/{job_id}/progress") as resp: assert resp.status_code == 200 - lines = [] + lines: list[str] = [] for line in resp.iter_lines(): if line.startswith("data:"): lines.append(line) @@ -164,7 +166,7 @@ def test_progress_stream(populated): assert data["status"] == "complete" -def test_progress_stream_not_found(client): +def test_progress_stream_not_found(client: TestClient) -> None: resp = client.get("/tasks/jobs/nonexistent/progress") assert resp.status_code == 404 @@ -172,13 +174,13 @@ def test_progress_stream_not_found(client): # ── Router config ──────────────────────────────────────── -def test_router_custom_tags(queue): +def test_router_custom_tags(queue: Queue) -> None: """TaskitoRouter accepts standard APIRouter kwargs.""" router = TaskitoRouter(queue, tags=["my-tasks"]) assert "my-tasks" in router.tags -def 
test_router_custom_prefix(queue): +def test_router_custom_prefix(queue: Queue) -> None: """Router can be mounted with a custom prefix.""" app = FastAPI() app.include_router(TaskitoRouter(queue), prefix="/api/v1/queue") diff --git a/tests/python/test_hooks.py b/tests/python/test_hooks.py index 042113e..e96d4f5 100644 --- a/tests/python/test_hooks.py +++ b/tests/python/test_hooks.py @@ -3,22 +3,25 @@ from __future__ import annotations import threading +from typing import Any +from taskito import Queue -def test_before_and_after_hooks(queue): + +def test_before_and_after_hooks(queue: Queue) -> None: """before_task and after_task hooks fire around task execution.""" - events = [] + events: list[tuple[Any, ...]] = [] @queue.before_task - def on_before(task_name, args, kwargs): + def on_before(task_name: str, args: tuple, kwargs: dict) -> None: events.append(("before", task_name)) @queue.after_task - def on_after(task_name, args, kwargs, result, error): + def on_after(task_name: str, args: tuple, kwargs: dict, result: Any, error: Any) -> None: events.append(("after", task_name, result, error)) @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b job = add.delay(1, 2) @@ -31,19 +34,19 @@ def add(a, b): # Verify hooks fired assert any(e[0] == "before" for e in events) - assert any(e[0] == "after" and e[2] == 3 and e[3] is None for e in events) + assert any(e[0] == "after" and len(e) > 2 and e[2] == 3 and e[3] is None for e in events) -def test_on_success_hook(queue): +def test_on_success_hook(queue: Queue) -> None: """on_success hook fires when task succeeds.""" - success_results = [] + success_results: list[Any] = [] @queue.on_success - def on_success(task_name, args, kwargs, result): + def on_success(task_name: str, args: tuple, kwargs: dict, result: Any) -> None: success_results.append(result) @queue.task() - def multiply(a, b): + def multiply(a: int, b: int) -> int: return a * b job = multiply.delay(3, 4) @@ -56,16 +59,16 @@ def multiply(a, b): 
assert 12 in success_results -def test_on_failure_hook(queue): +def test_on_failure_hook(queue: Queue) -> None: """on_failure hook fires when task raises.""" - failure_errors = [] + failure_errors: list[str] = [] @queue.on_failure - def on_failure(task_name, args, kwargs, error): + def on_failure(task_name: str, args: tuple, kwargs: dict, error: Exception) -> None: failure_errors.append(str(error)) @queue.task(max_retries=1, retry_backoff=0.1) - def always_fails(): + def always_fails() -> None: raise ValueError("boom") always_fails.delay() diff --git a/tests/python/test_interception.py b/tests/python/test_interception.py index dec1555..4f7db15 100644 --- a/tests/python/test_interception.py +++ b/tests/python/test_interception.py @@ -11,6 +11,7 @@ import threading import uuid from dataclasses import dataclass +from typing import Any import pytest @@ -29,17 +30,17 @@ @pytest.fixture -def registry(): +def registry() -> Any: return build_default_registry() @pytest.fixture -def strict(registry): +def strict(registry: Any) -> ArgumentInterceptor: return ArgumentInterceptor(registry, mode="strict") @pytest.fixture -def lenient(registry): +def lenient(registry: Any) -> ArgumentInterceptor: return ArgumentInterceptor(registry, mode="lenient") @@ -47,31 +48,31 @@ def lenient(registry): class TestPassStrategy: - def test_int_passes_through(self, strict): + def test_int_passes_through(self, strict: ArgumentInterceptor) -> None: args, _kw = strict.intercept((42,), {}) assert args == (42,) - def test_str_passes_through(self, strict): + def test_str_passes_through(self, strict: ArgumentInterceptor) -> None: args, _kw = strict.intercept(("hello",), {}) assert args == ("hello",) - def test_float_passes_through(self, strict): + def test_float_passes_through(self, strict: ArgumentInterceptor) -> None: args, _kw = strict.intercept((3.14,), {}) assert args == (3.14,) - def test_bool_passes_through(self, strict): + def test_bool_passes_through(self, strict: ArgumentInterceptor) -> None: 
args, _kw = strict.intercept((True, False), {}) assert args == (True, False) - def test_none_passes_through(self, strict): + def test_none_passes_through(self, strict: ArgumentInterceptor) -> None: args, _kw = strict.intercept((None,), {}) assert args == (None,) - def test_bytes_passes_through(self, strict): + def test_bytes_passes_through(self, strict: ArgumentInterceptor) -> None: args, _kw = strict.intercept((b"data",), {}) assert args == (b"data",) - def test_mixed_primitives(self, strict): + def test_mixed_primitives(self, strict: ArgumentInterceptor) -> None: args, kwargs = strict.intercept( (1, "two", 3.0, True, None, b"six"), {"key": "val"}, @@ -84,7 +85,7 @@ def test_mixed_primitives(self, strict): class TestConvertStrategy: - def test_uuid_round_trip(self, strict): + def test_uuid_round_trip(self, strict: ArgumentInterceptor) -> None: original = uuid.UUID("12345678-1234-5678-1234-567812345678") args, _kw = strict.intercept((original,), {}) assert args[0]["__taskito_convert__"] is True @@ -93,49 +94,49 @@ def test_uuid_round_trip(self, strict): restored = reconstruct_converted(args[0]) assert restored == original - def test_datetime_round_trip(self, strict): + def test_datetime_round_trip(self, strict: ArgumentInterceptor) -> None: original = datetime.datetime(2025, 3, 10, 12, 0, 0) args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "datetime" restored = reconstruct_converted(args[0]) assert restored == original - def test_date_round_trip(self, strict): + def test_date_round_trip(self, strict: ArgumentInterceptor) -> None: original = datetime.date(2025, 3, 10) args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "date" restored = reconstruct_converted(args[0]) assert restored == original - def test_time_round_trip(self, strict): + def test_time_round_trip(self, strict: ArgumentInterceptor) -> None: original = datetime.time(14, 30, 0) args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "time" 
restored = reconstruct_converted(args[0]) assert restored == original - def test_timedelta_round_trip(self, strict): + def test_timedelta_round_trip(self, strict: ArgumentInterceptor) -> None: original = datetime.timedelta(hours=1, minutes=30) args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "timedelta" restored = reconstruct_converted(args[0]) assert restored == original - def test_decimal_round_trip(self, strict): + def test_decimal_round_trip(self, strict: ArgumentInterceptor) -> None: original = decimal.Decimal("3.14159") args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "decimal" restored = reconstruct_converted(args[0]) assert restored == original - def test_path_round_trip(self, strict): + def test_path_round_trip(self, strict: ArgumentInterceptor) -> None: original = pathlib.Path("/tmp/data.csv") args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "path" restored = reconstruct_converted(args[0]) assert restored == original - def test_pattern_round_trip(self, strict): + def test_pattern_round_trip(self, strict: ArgumentInterceptor) -> None: original = re.compile(r"\d+", re.IGNORECASE) args, _ = strict.intercept((original,), {}) assert args[0]["type_key"] == "pattern" @@ -143,7 +144,7 @@ def test_pattern_round_trip(self, strict): assert restored.pattern == original.pattern assert restored.flags == original.flags - def test_enum_converts(self, strict): + def test_enum_converts(self, strict: ArgumentInterceptor) -> None: class Color(enum.Enum): RED = "red" GREEN = "green" @@ -152,7 +153,7 @@ class Color(enum.Enum): assert args[0]["type_key"] == "enum" assert args[0]["value"] == "red" - def test_dataclass_converts(self, strict): + def test_dataclass_converts(self, strict: ArgumentInterceptor) -> None: @dataclass class Point: x: int @@ -164,7 +165,7 @@ class Point: assert args[0]["type_key"] == "dataclass" assert args[0]["value"] == {"x": 1, "y": 2} - def test_datetime_before_date(self, 
strict): + def test_datetime_before_date(self, strict: ArgumentInterceptor) -> None: """datetime is a subclass of date — datetime must match first.""" dt = datetime.datetime(2025, 1, 1, 12, 0, 0) args, _ = strict.intercept((dt,), {}) @@ -175,10 +176,10 @@ def test_datetime_before_date(self, strict): class TestRedirectStrategy: - def test_redirect_produces_marker(self, strict): + def test_redirect_produces_marker(self, strict: ArgumentInterceptor) -> None: """Test that redirect types produce markers (if sqlalchemy is installed).""" try: - from sqlalchemy.orm import Session # noqa: F401 + from sqlalchemy.orm import Session # type: ignore[import-not-found] # noqa: F401 sqlalchemy_available = True except ImportError: @@ -192,7 +193,7 @@ def test_redirect_produces_marker(self, strict): class TestRejectStrategy: - def test_threading_lock_rejected(self, strict): + def test_threading_lock_rejected(self, strict: ArgumentInterceptor) -> None: lock = threading.Lock() with pytest.raises(InterceptionError) as exc_info: strict.intercept((lock,), {}) @@ -200,7 +201,7 @@ def test_threading_lock_rejected(self, strict): assert "args[0]" in exc_info.value.failures[0].path assert "lock" in exc_info.value.failures[0].type_name.lower() - def test_socket_rejected(self, strict): + def test_socket_rejected(self, strict: ArgumentInterceptor) -> None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: with pytest.raises(InterceptionError): @@ -208,24 +209,24 @@ def test_socket_rejected(self, strict): finally: sock.close() - def test_generator_rejected(self, strict): + def test_generator_rejected(self, strict: ArgumentInterceptor) -> None: gen = (x for x in range(10)) with pytest.raises(InterceptionError): strict.intercept((gen,), {}) - def test_reject_error_includes_path(self, strict): + def test_reject_error_includes_path(self, strict: ArgumentInterceptor) -> None: lock = threading.Lock() with pytest.raises(InterceptionError) as exc_info: strict.intercept((), {"session": lock}) 
assert "kwargs.session" in exc_info.value.failures[0].path - def test_reject_error_has_suggestions(self, strict): + def test_reject_error_has_suggestions(self, strict: ArgumentInterceptor) -> None: lock = threading.Lock() with pytest.raises(InterceptionError) as exc_info: strict.intercept((lock,), {}) assert len(exc_info.value.failures[0].suggestions) > 0 - def test_multiple_rejections_collected(self, strict): + def test_multiple_rejections_collected(self, strict: ArgumentInterceptor) -> None: lock = threading.Lock() event = threading.Event() with pytest.raises(InterceptionError) as exc_info: @@ -237,13 +238,13 @@ def test_multiple_rejections_collected(self, strict): class TestLenientMode: - def test_rejected_arg_dropped(self, lenient): + def test_rejected_arg_dropped(self, lenient: ArgumentInterceptor) -> None: lock = threading.Lock() args, _kw = lenient.intercept((42, lock), {}) assert args[0] == 42 assert args[1] is None # dropped to None - def test_rejected_kwarg_dropped(self, lenient): + def test_rejected_kwarg_dropped(self, lenient: ArgumentInterceptor) -> None: lock = threading.Lock() _args, kwargs = lenient.intercept((), {"x": 1, "lock": lock}) assert kwargs == {"x": 1} @@ -253,7 +254,7 @@ def test_rejected_kwarg_dropped(self, lenient): class TestOffMode: - def test_passthrough_no_interception(self, registry): + def test_passthrough_no_interception(self, registry: Any) -> None: interceptor = ArgumentInterceptor(registry, mode="off") lock = threading.Lock() args, _kw = interceptor.intercept((lock,), {}) @@ -264,17 +265,17 @@ def test_passthrough_no_interception(self, registry): class TestRecursiveWalking: - def test_nested_uuid_in_dict(self, strict): + def test_nested_uuid_in_dict(self, strict: ArgumentInterceptor) -> None: uid = uuid.uuid4() _, kwargs = strict.intercept((), {"config": {"user_id": uid}}) assert kwargs["config"]["user_id"]["__taskito_convert__"] is True - def test_uuid_in_list(self, strict): + def test_uuid_in_list(self, strict: 
ArgumentInterceptor) -> None: uid = uuid.uuid4() args, _ = strict.intercept(([uid],), {}) assert args[0][0]["__taskito_convert__"] is True - def test_depth_limit(self, registry): + def test_depth_limit(self, registry: Any) -> None: interceptor = ArgumentInterceptor(registry, mode="strict", max_depth=2) uid = uuid.uuid4() # Depth 3 — beyond limit, should pass through @@ -283,8 +284,8 @@ def test_depth_limit(self, registry): # uid at depth 3 should pass through as-is (beyond max_depth=2) assert kwargs["x"]["a"]["b"]["c"] is uid - def test_circular_reference_handled(self, strict): - d: dict = {"value": 42} + def test_circular_reference_handled(self, strict: ArgumentInterceptor) -> None: + d: dict[str, Any] = {"value": 42} d["self"] = d # circular! args, _ = strict.intercept((d,), {}) # Should not infinite loop — circular ref is detected and passed through @@ -295,14 +296,14 @@ def test_circular_reference_handled(self, strict): class TestRoundTrip: - def test_uuid_full_round_trip(self, strict): + def test_uuid_full_round_trip(self, strict: ArgumentInterceptor) -> None: uid = uuid.UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") intercepted_args, intercepted_kwargs = strict.intercept((uid,), {"id": uid}) args, kwargs, _redirects = reconstruct_args(intercepted_args, intercepted_kwargs) assert args[0] == uid assert kwargs["id"] == uid - def test_mixed_types_round_trip(self, strict): + def test_mixed_types_round_trip(self, strict: ArgumentInterceptor) -> None: dt = datetime.datetime(2025, 6, 15, 10, 30) path = pathlib.Path("/data/file.txt") intercepted_args, intercepted_kwargs = strict.intercept((42, "hello", dt), {"path": path}) @@ -312,7 +313,7 @@ def test_mixed_types_round_trip(self, strict): assert args[2] == dt assert kwargs["path"] == path - def test_nested_convert_round_trip(self, strict): + def test_nested_convert_round_trip(self, strict: ArgumentInterceptor) -> None: uid = uuid.uuid4() intercepted_args, _ = strict.intercept(({"ids": [uid]},), {}) args, _, _ = 
reconstruct_args(intercepted_args, {}) @@ -323,24 +324,24 @@ def test_nested_convert_round_trip(self, strict): class TestAnalyze: - def test_analyze_returns_report(self, strict): + def test_analyze_returns_report(self, strict: ArgumentInterceptor) -> None: report = strict.analyze((42, "hello"), {"uid": uuid.uuid4()}) assert isinstance(report, InterceptionReport) assert len(report.entries) == 3 - def test_analyze_shows_strategies(self, strict): + def test_analyze_shows_strategies(self, strict: ArgumentInterceptor) -> None: uid = uuid.uuid4() report = strict.analyze((42, uid), {}) strategies = [e.strategy for e in report.entries] assert Strategy.PASS in strategies assert Strategy.CONVERT in strategies - def test_analyze_on_off_mode_returns_empty(self, registry): + def test_analyze_on_off_mode_returns_empty(self, registry: Any) -> None: interceptor = ArgumentInterceptor(registry, mode="off") report = interceptor.analyze((42,), {}) assert len(report.entries) == 0 - def test_report_str_format(self, strict): + def test_report_str_format(self, strict: ArgumentInterceptor) -> None: report = strict.analyze((42,), {}) text = str(report) assert "Argument Analysis:" in text @@ -350,51 +351,51 @@ def test_report_str_format(self, strict): class TestQueueIntegration: - def test_queue_default_interception_off(self, tmp_path): + def test_queue_default_interception_off(self, tmp_path: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db")) assert q._interceptor is None - def test_queue_strict_mode(self, tmp_path): + def test_queue_strict_mode(self, tmp_path: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") assert q._interceptor is not None assert q._interceptor.mode == "strict" - def test_queue_enqueue_with_interception(self, tmp_path): + def test_queue_enqueue_with_interception(self, tmp_path: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") @q.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b 
# Simple args should work fine result = add.delay(1, 2) assert result.id is not None - def test_queue_enqueue_rejects_lock(self, tmp_path): + def test_queue_enqueue_rejects_lock(self, tmp_path: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") @q.task() - def bad_task(lock): + def bad_task(lock: Any) -> None: pass with pytest.raises(InterceptionError): bad_task.delay(threading.Lock()) - def test_task_analyze(self, tmp_path): + def test_task_analyze(self, tmp_path: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") @q.task() - def my_task(user_id, created_at): + def my_task(user_id: int, created_at: datetime.datetime) -> None: pass report = my_task.analyze(42, datetime.datetime.now()) assert len(report.entries) == 2 - def test_task_analyze_off_mode(self, tmp_path): + def test_task_analyze_off_mode(self, tmp_path: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db")) @q.task() - def my_task(x): + def my_task(x: int) -> None: pass report = my_task.analyze(42) @@ -405,7 +406,7 @@ def my_task(x): class TestCustomRegistration: - def test_register_custom_reject(self, registry, strict): + def test_register_custom_reject(self, registry: Any, strict: ArgumentInterceptor) -> None: class MyLock: pass @@ -420,13 +421,13 @@ class MyLock: strict.intercept((MyLock(),), {}) assert "distributed locking" in str(exc_info.value) - def test_register_custom_convert(self, registry, strict): + def test_register_custom_convert(self, registry: Any, strict: ArgumentInterceptor) -> None: class Money: - def __init__(self, amount: int, currency: str): + def __init__(self, amount: int, currency: str) -> None: self.amount = amount self.currency = currency - def convert_money(obj): + def convert_money(obj: Money) -> dict[str, Any]: return { "__taskito_convert__": True, "type_key": "money", diff --git a/tests/python/test_keda.py b/tests/python/test_keda.py index fec1f72..a298e34 100644 --- a/tests/python/test_keda.py +++ 
b/tests/python/test_keda.py @@ -3,7 +3,9 @@ import json import threading import urllib.request +from collections.abc import Generator from http.server import ThreadingHTTPServer +from pathlib import Path import pytest @@ -12,13 +14,13 @@ @pytest.fixture() -def empty_queue(tmp_path): +def empty_queue(tmp_path: Path) -> Queue: """Queue with no pending jobs.""" return Queue(db_path=str(tmp_path / "keda.db"), workers=4) @pytest.fixture() -def populated_queue(tmp_path): +def populated_queue(tmp_path: Path) -> Queue: """Queue with pending jobs across two queues.""" q = Queue(db_path=str(tmp_path / "keda.db"), workers=4) @@ -39,7 +41,7 @@ def generate_report(name: str) -> None: @pytest.fixture() -def scaler_server(empty_queue): +def scaler_server(empty_queue: Queue) -> Generator[str]: """Start a scaler HTTP server on a random port, yield the base URL.""" from taskito.scaler import _make_scaler_handler diff --git a/tests/python/test_namespace.py b/tests/python/test_namespace.py index 8408aa1..704b343 100644 --- a/tests/python/test_namespace.py +++ b/tests/python/test_namespace.py @@ -1,16 +1,17 @@ """Tests for namespace-based routing and isolation.""" import threading +from pathlib import Path from taskito import Queue -def test_namespace_enqueue_sets_namespace(tmp_path): +def test_namespace_enqueue_sets_namespace(tmp_path: Path) -> None: """Jobs enqueued on a namespaced Queue carry the namespace.""" queue = Queue(db_path=str(tmp_path / "test.db"), namespace="team-a") @queue.task() - def add(x, y): + def add(x: int, y: int) -> int: return x + y job = add.delay(1, 2) @@ -19,12 +20,12 @@ def add(x, y): assert py_job.namespace == "team-a" -def test_no_namespace_jobs_have_none(tmp_path): +def test_no_namespace_jobs_have_none(tmp_path: Path) -> None: """Jobs enqueued without a namespace have namespace=None.""" queue = Queue(db_path=str(tmp_path / "test.db")) @queue.task() - def noop(): + def noop() -> None: pass job = noop.delay() @@ -33,7 +34,7 @@ def noop(): assert 
py_job.namespace is None -def test_namespace_isolation_worker(tmp_path): +def test_namespace_isolation_worker(tmp_path: Path) -> None: """A namespaced worker only processes jobs from its namespace.""" db = str(tmp_path / "test.db") @@ -41,15 +42,15 @@ def test_namespace_isolation_worker(tmp_path): q_a = Queue(db_path=db, namespace="team-a") q_b = Queue(db_path=db, namespace="team-b") - results = [] + results: list[str] = [] @q_a.task() - def task_a(): + def task_a() -> str: results.append("a") return "a" @q_b.task() - def task_b(): + def task_b() -> str: results.append("b") return "b" @@ -72,18 +73,18 @@ def task_b(): q_a._inner.request_shutdown() -def test_namespace_list_jobs_scoped(tmp_path): +def test_namespace_list_jobs_scoped(tmp_path: Path) -> None: """list_jobs defaults to the queue's namespace.""" db = str(tmp_path / "test.db") q_a = Queue(db_path=db, namespace="ns-a") q_b = Queue(db_path=db, namespace="ns-b") @q_a.task() - def task_x(): + def task_x() -> None: pass @q_b.task() - def task_y(): + def task_y() -> None: pass task_x.delay() @@ -98,12 +99,12 @@ def task_y(): assert len(q_a.list_jobs(namespace=None)) == 3 -def test_namespace_preserved_in_job_result(tmp_path): +def test_namespace_preserved_in_job_result(tmp_path: Path) -> None: """JobResult.to_dict() includes the namespace.""" queue = Queue(db_path=str(tmp_path / "test.db"), namespace="my-ns") @queue.task() - def greet(name): + def greet(name: str) -> str: return f"hi {name}" job = greet.delay("world") diff --git a/tests/python/test_native_async.py b/tests/python/test_native_async.py index f42d0b3..de047ac 100644 --- a/tests/python/test_native_async.py +++ b/tests/python/test_native_async.py @@ -5,6 +5,8 @@ import asyncio import threading import time +from pathlib import Path +from typing import Any from unittest.mock import MagicMock from taskito import Queue, TaskCancelledError, current_job @@ -18,24 +20,24 @@ # ── Async detection ────────────────────────────────────────────── -def 
test_async_task_detected(tmp_path): +def test_async_task_detected(tmp_path: Path) -> None: """_taskito_is_async is True for async functions.""" queue = Queue(db_path=str(tmp_path / "test.db")) @queue.task() - async def my_async_task(): + async def my_async_task() -> None: pass assert my_async_task._taskito_is_async is True assert hasattr(my_async_task, "_taskito_async_fn") -def test_sync_task_not_async(tmp_path): +def test_sync_task_not_async(tmp_path: Path) -> None: """_taskito_is_async is False for sync functions.""" queue = Queue(db_path=str(tmp_path / "test.db")) @queue.task() - def my_sync_task(): + def my_sync_task() -> None: pass assert my_sync_task._taskito_is_async is False @@ -45,7 +47,7 @@ def my_sync_task(): # ── Async context (contextvars) ────────────────────────────────── -def test_async_context_var(): +def test_async_context_var() -> None: """set/get/clear async context via contextvars.""" token = set_async_context("job-1", "my_task", 0, "default") ctx = get_async_context() @@ -58,26 +60,26 @@ def test_async_context_var(): assert get_async_context() is None -def test_async_context_isolated_between_tasks(): +def test_async_context_isolated_between_tasks() -> None: """Each async task gets its own contextvar context (no cross-contamination).""" - results = [] + results: list[str | None] = [] - async def coro(job_id): + async def coro(job_id: str) -> None: token = set_async_context(job_id, "task", 0, "q") await asyncio.sleep(0.01) ctx = get_async_context() results.append(ctx.job_id if ctx else None) clear_async_context(token) - async def run_both(): + async def run_both() -> None: await asyncio.gather(coro("a"), coro("b")) asyncio.run(run_both()) - assert sorted(results) == ["a", "b"] + assert sorted(r for r in results if r is not None) == ["a", "b"] -def test_sync_context_unchanged(tmp_path): +def test_sync_context_unchanged(tmp_path: Path) -> None: """current_job still works via threading.local for sync tasks.""" from taskito.context import 
_clear_context, _set_context @@ -89,7 +91,7 @@ def test_sync_context_unchanged(tmp_path): _clear_context() -def test_async_context_fallback_to_sync(): +def test_async_context_fallback_to_sync() -> None: """_require_context falls back to threading.local when no async context.""" from taskito.context import _clear_context, _set_context @@ -101,7 +103,7 @@ def test_async_context_fallback_to_sync(): _clear_context() -def test_async_context_preferred_over_sync(): +def test_async_context_preferred_over_sync() -> None: """When both async and sync contexts exist, async wins.""" from taskito.context import _clear_context, _set_context @@ -116,12 +118,12 @@ def test_async_context_preferred_over_sync(): # ── AsyncTaskExecutor unit tests ───────────────────────────────── -def test_async_executor_lifecycle(): +def test_async_executor_lifecycle() -> None: """Start/stop executor without errors.""" from taskito.async_support.executor import AsyncTaskExecutor sender = MagicMock() - registry = {} + registry: dict[str, Any] = {} queue_ref = MagicMock() executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) executor.start() @@ -130,7 +132,7 @@ def test_async_executor_lifecycle(): executor.stop() -def test_async_executor_submit_and_execute(): +def test_async_executor_submit_and_execute() -> None: """Basic async task produces correct result via executor.""" import cloudpickle @@ -138,14 +140,14 @@ def test_async_executor_submit_and_execute(): sender = MagicMock() - async def my_task(x, y): + async def my_task(x: int, y: int) -> int: return x + y # Build a minimal wrapper that the executor expects class FakeWrapper: _taskito_async_fn = staticmethod(my_task) - registry = {"test_mod.my_task": FakeWrapper()} + registry: dict[str, Any] = {"test_mod.my_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -173,7 +175,7 @@ class FakeWrapper: assert result == 5 -def test_async_exception_reported(): +def test_async_exception_reported() -> None: 
"""Exception in async task → failure result with traceback.""" import cloudpickle @@ -181,13 +183,13 @@ def test_async_exception_reported(): sender = MagicMock() - async def failing_task(): + async def failing_task() -> None: raise ValueError("boom") class FakeWrapper: _taskito_async_fn = staticmethod(failing_task) - registry = {"mod.failing_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.failing_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -214,7 +216,7 @@ class FakeWrapper: assert call_args[0][6] is True # should_retry -def test_async_cancellation(): +def test_async_cancellation() -> None: """TaskCancelledError → cancelled result.""" import cloudpickle @@ -222,13 +224,13 @@ def test_async_cancellation(): sender = MagicMock() - async def cancelling_task(): + async def cancelling_task() -> None: raise TaskCancelledError("cancelled") class FakeWrapper: _taskito_async_fn = staticmethod(cancelling_task) - registry = {"mod.cancelling_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.cancelling_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -252,7 +254,7 @@ class FakeWrapper: assert sender.report_cancelled.call_args[0][0] == "job-3" -def test_async_retry_filter(): +def test_async_retry_filter() -> None: """Failed async task respects retry_on filter.""" import cloudpickle @@ -260,13 +262,13 @@ def test_async_retry_filter(): sender = MagicMock() - async def flaky_task(): + async def flaky_task() -> None: raise TypeError("wrong type") class FakeWrapper: _taskito_async_fn = staticmethod(flaky_task) - registry = {"mod.flaky_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.flaky_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -293,7 +295,7 @@ class FakeWrapper: assert sender.report_failure.call_args[0][6] is False # should_retry = False -def test_async_concurrency_limit(): +def test_async_concurrency_limit() -> None: """Semaphore bounds concurrent async 
tasks.""" import cloudpickle @@ -304,7 +306,7 @@ def test_async_concurrency_limit(): current = 0 lock = threading.Lock() - async def slow_task(): + async def slow_task() -> None: nonlocal max_concurrent, current with lock: current += 1 @@ -316,7 +318,7 @@ async def slow_task(): class FakeWrapper: _taskito_async_fn = staticmethod(slow_task) - registry = {"mod.slow_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.slow_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -343,31 +345,31 @@ class FakeWrapper: assert sender.report_success.call_count == 5 -def test_async_middleware_hooks(): +def test_async_middleware_hooks() -> None: """Middleware before/after called for async tasks.""" import cloudpickle from taskito.async_support.executor import AsyncTaskExecutor - before_called = [] - after_called = [] + before_called: list[str] = [] + after_called: list[str] = [] class TestMiddleware(TaskMiddleware): - def before(self, job_context): + def before(self, job_context: Any) -> None: before_called.append(job_context.id) - def after(self, job_context, result, error): + def after(self, job_context: Any, result: Any, error: Any) -> None: after_called.append(job_context.id) sender = MagicMock() - async def simple_task(): + async def simple_task() -> int: return 42 class FakeWrapper: _taskito_async_fn = staticmethod(simple_task) - registry = {"mod.simple_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.simple_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -391,7 +393,7 @@ class FakeWrapper: assert "mw-job" in after_called -def test_async_task_with_injection(): +def test_async_task_with_injection() -> None: """inject=["db"] works for async tasks via executor.""" import cloudpickle @@ -399,13 +401,13 @@ def test_async_task_with_injection(): sender = MagicMock() - async def db_task(db=None): + async def db_task(db: Any = None) -> str: return f"got-{db}" class FakeWrapper: _taskito_async_fn = 
staticmethod(db_task) - registry = {"mod.db_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.db_task": FakeWrapper()} fake_db = "fake-conn" @@ -436,23 +438,23 @@ class FakeWrapper: assert result == "got-fake-conn" -def test_async_context_available_inside_task(): +def test_async_context_available_inside_task() -> None: """current_job.id works inside an async task via contextvars.""" import cloudpickle from taskito.async_support.executor import AsyncTaskExecutor sender = MagicMock() - captured_id = [] + captured_id: list[str] = [] - async def ctx_task(): + async def ctx_task() -> str: captured_id.append(current_job.id) return "ok" class FakeWrapper: _taskito_async_fn = staticmethod(ctx_task) - registry = {"mod.ctx_task": FakeWrapper()} + registry: dict[str, Any] = {"mod.ctx_task": FakeWrapper()} queue_ref = MagicMock() queue_ref._interceptor = None @@ -475,13 +477,13 @@ class FakeWrapper: assert captured_id == ["ctx-job"] -def test_async_concurrency_parameter(tmp_path): +def test_async_concurrency_parameter(tmp_path: Path) -> None: """Queue accepts async_concurrency parameter.""" queue = Queue(db_path=str(tmp_path / "test.db"), async_concurrency=50) assert queue._async_concurrency == 50 -def test_async_concurrency_default(tmp_path): +def test_async_concurrency_default(tmp_path: Path) -> None: """Default async_concurrency is 100.""" queue = Queue(db_path=str(tmp_path / "test.db")) assert queue._async_concurrency == 100 diff --git a/tests/python/test_observability.py b/tests/python/test_observability.py index 5659712..7e0f17a 100644 --- a/tests/python/test_observability.py +++ b/tests/python/test_observability.py @@ -2,6 +2,9 @@ from __future__ import annotations +import time +from typing import Any + import pytest from taskito import Queue @@ -59,9 +62,8 @@ def test_status_unhealthy_resource(self) -> None: def test_status_tracks_init_duration(self) -> None: """init_duration_ms is populated after initialize.""" - import time - def slow_factory(): + def 
slow_factory() -> str: time.sleep(0.05) return "result" @@ -77,7 +79,7 @@ def test_status_tracks_recreations(self) -> None: """recreation count is incremented on successful recreate.""" call_count = 0 - def make_svc(): + def make_svc() -> str: nonlocal call_count call_count += 1 return f"v{call_count}" @@ -133,12 +135,12 @@ def test_from_test_overrides_status(self) -> None: class TestQueueResourceStatus: - def test_resource_status_with_runtime(self, tmp_path) -> None: + def test_resource_status_with_runtime(self, tmp_path: Any) -> None: """resource_status() delegates to runtime when initialized.""" queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "conn" # Manually initialize runtime @@ -152,12 +154,12 @@ def create_db(): assert status[0]["health"] == "healthy" rt.teardown() - def test_resource_status_without_runtime(self, tmp_path) -> None: + def test_resource_status_without_runtime(self, tmp_path: Any) -> None: """resource_status() returns definitions with not_initialized health.""" queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "conn" status = queue.resource_status() @@ -165,7 +167,7 @@ def create_db(): assert status[0]["name"] == "db" assert status[0]["health"] == "not_initialized" - def test_resource_status_empty(self, tmp_path) -> None: + def test_resource_status_empty(self, tmp_path: Any) -> None: """resource_status() returns empty list with no resources.""" queue = Queue(db_path=str(tmp_path / "q.db")) assert queue.resource_status() == [] @@ -177,14 +179,14 @@ def test_resource_status_empty(self, tmp_path) -> None: class TestHealthCheckIntegration: - def test_readiness_reports_healthy_resources(self, tmp_path) -> None: + def test_readiness_reports_healthy_resources(self, tmp_path: Any) -> None: """check_readiness includes resource status when all healthy.""" from taskito.health import check_readiness 
queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "conn" rt = ResourceRuntime(queue._resource_definitions) @@ -199,14 +201,14 @@ def create_db(): assert res_check["unhealthy"] == [] rt.teardown() - def test_readiness_reports_unhealthy_resources(self, tmp_path) -> None: + def test_readiness_reports_unhealthy_resources(self, tmp_path: Any) -> None: """check_readiness marks status as degraded for unhealthy resources.""" from taskito.health import check_readiness queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "conn" rt = ResourceRuntime(queue._resource_definitions) @@ -221,7 +223,7 @@ def create_db(): assert "db" in res_check["unhealthy"] rt.teardown() - def test_readiness_no_resources(self, tmp_path) -> None: + def test_readiness_no_resources(self, tmp_path: Any) -> None: """check_readiness works fine without any resources.""" from taskito.health import check_readiness @@ -246,21 +248,23 @@ def test_health_check_always_ok(self) -> None: class TestCLIResources: - def test_run_resources_no_resources(self, tmp_path) -> None: + def test_run_resources_no_resources(self, tmp_path: Any) -> None: """resource_status returns empty list when no resources registered.""" queue = Queue(db_path=str(tmp_path / "q.db")) assert queue.resource_status() == [] - def test_resource_status_table_format(self, tmp_path, capsys) -> None: + def test_resource_status_table_format( + self, tmp_path: Any, capsys: pytest.CaptureFixture[str] + ) -> None: """Verify table output format from CLI helper.""" queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("config") - def create_config(): + def create_config() -> dict[str, str]: return {} @queue.worker_resource("db", depends_on=["config"]) - def create_db(config): + def create_db(config: Any) -> str: return "conn" rt = ResourceRuntime(queue._resource_definitions) @@ -294,12 +298,12 
@@ class TestPrometheusResourceMetrics: def test_prometheus_middleware_has_resource_metrics(self) -> None: """Verify resource metric singletons are initialized.""" pytest.importorskip("prometheus_client") - from taskito.contrib.prometheus import _init_metrics + from taskito.contrib.prometheus import _init_metrics # type: ignore[attr-defined] _init_metrics() from taskito.contrib import prometheus as pmod - assert pmod._resource_health is not None - assert pmod._resource_recreations is not None - assert pmod._resource_init_duration is not None + assert pmod._resource_health is not None # type: ignore[attr-defined] + assert pmod._resource_recreations is not None # type: ignore[attr-defined] + assert pmod._resource_init_duration is not None # type: ignore[attr-defined] diff --git a/tests/python/test_periodic.py b/tests/python/test_periodic.py index 854aff1..c8be23d 100644 --- a/tests/python/test_periodic.py +++ b/tests/python/test_periodic.py @@ -2,6 +2,7 @@ import threading import time +from pathlib import Path import pytest @@ -9,16 +10,16 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: db_path = str(tmp_path / "test_periodic.db") return Queue(db_path=db_path, workers=1) -def test_periodic_task_registration(queue): +def test_periodic_task_registration(queue: Queue) -> None: """Periodic tasks are registered as both regular tasks and periodic configs.""" @queue.periodic(cron="0 * * * * *") - def every_minute(): + def every_minute() -> str: return "tick" assert every_minute.name.endswith("every_minute") @@ -27,22 +28,22 @@ def every_minute(): assert queue._periodic_configs[0]["cron_expr"] == "0 * * * * *" -def test_periodic_task_direct_call(queue): +def test_periodic_task_direct_call(queue: Queue) -> None: """Periodic tasks can still be called directly like regular tasks.""" @queue.periodic(cron="0 * * * * *") - def add(a, b): + def add(a: int, b: int) -> int: return a + b assert add(3, 4) == 7 -def test_periodic_task_triggers(queue): +def 
test_periodic_task_triggers(queue: Queue) -> None: """Periodic task gets enqueued by the scheduler when due.""" - results = [] + results: list[int] = [] @queue.periodic(cron="* * * * * *") # every second - def frequent_task(): + def frequent_task() -> str: results.append(1) return "done" diff --git a/tests/python/test_priority.py b/tests/python/test_priority.py index 4191581..1b3d6b2 100644 --- a/tests/python/test_priority.py +++ b/tests/python/test_priority.py @@ -2,6 +2,7 @@ import threading import time +from pathlib import Path import pytest @@ -9,17 +10,17 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: db_path = str(tmp_path / "test_priority.db") return Queue(db_path=db_path, workers=1) # 1 worker for ordering -def test_priority_ordering(queue): +def test_priority_ordering(queue: Queue) -> None: """Higher priority jobs should be processed first.""" - results = [] + results: list[str] = [] @queue.task() - def record_task(label): + def record_task(label: str) -> str: results.append(label) return label diff --git a/tests/python/test_progress.py b/tests/python/test_progress.py index bb86c95..58cd7c6 100644 --- a/tests/python/test_progress.py +++ b/tests/python/test_progress.py @@ -4,12 +4,14 @@ import time +from taskito import Queue -def test_update_progress(queue): + +def test_update_progress(queue: Queue) -> None: """Progress can be updated and read back.""" @queue.task() - def slow_task(): + def slow_task() -> str: time.sleep(0.5) return "done" @@ -20,20 +22,23 @@ def slow_task(): queue.update_progress(job.id, 50) refreshed = queue.get_job(job.id) + assert refreshed is not None assert refreshed.progress == 50 queue.update_progress(job.id, 100) refreshed = queue.get_job(job.id) + assert refreshed is not None assert refreshed.progress == 100 -def test_progress_starts_none(queue): +def test_progress_starts_none(queue: Queue) -> None: """Progress is None by default.""" @queue.task() - def task_a(): + def task_a() -> int: return 1 job = 
task_a.delay() refreshed = queue.get_job(job.id) + assert refreshed is not None assert refreshed.progress is None diff --git a/tests/python/test_proxies.py b/tests/python/test_proxies.py index 4eada31..b6c960f 100644 --- a/tests/python/test_proxies.py +++ b/tests/python/test_proxies.py @@ -19,7 +19,7 @@ class TestFileHandler: - def test_detect_open_file(self, tmp_path) -> None: + def test_detect_open_file(self, tmp_path: Any) -> None: f = open(tmp_path / "test.txt", "w") # noqa: SIM115 try: handler = FileHandler() @@ -27,7 +27,7 @@ def test_detect_open_file(self, tmp_path) -> None: finally: f.close() - def test_detect_closed_file(self, tmp_path) -> None: + def test_detect_closed_file(self, tmp_path: Any) -> None: f = open(tmp_path / "test.txt", "w") # noqa: SIM115 f.close() handler = FileHandler() @@ -39,7 +39,7 @@ def test_detect_stdin(self) -> None: handler = FileHandler() assert handler.detect(sys.stdin) is False - def test_deconstruct_text_file(self, tmp_path) -> None: + def test_deconstruct_text_file(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("hello world") with open(path) as f: @@ -50,7 +50,7 @@ def test_deconstruct_text_file(self, tmp_path) -> None: assert recipe["encoding"] is not None assert recipe["position"] == 0 - def test_deconstruct_binary_file(self, tmp_path) -> None: + def test_deconstruct_binary_file(self, tmp_path: Any) -> None: path = tmp_path / "data.bin" path.write_bytes(b"\x00\x01\x02") with open(path, "rb") as f: @@ -59,7 +59,7 @@ def test_deconstruct_binary_file(self, tmp_path) -> None: assert recipe["mode"] == "rb" assert recipe["encoding"] is None - def test_reconstruct_text_file(self, tmp_path) -> None: + def test_reconstruct_text_file(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("hello world") handler = FileHandler() @@ -70,7 +70,7 @@ def test_reconstruct_text_file(self, tmp_path) -> None: finally: f.close() - def test_reconstruct_at_position(self, tmp_path) -> None: + def 
test_reconstruct_at_position(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("hello world") handler = FileHandler() @@ -81,7 +81,7 @@ def test_reconstruct_at_position(self, tmp_path) -> None: finally: f.close() - def test_cleanup_closes_file(self, tmp_path) -> None: + def test_cleanup_closes_file(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("test") f = open(path) # noqa: SIM115 @@ -89,7 +89,7 @@ def test_cleanup_closes_file(self, tmp_path) -> None: handler.cleanup(f) assert f.closed - def test_cleanup_already_closed(self, tmp_path) -> None: + def test_cleanup_already_closed(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("test") f = open(path) # noqa: SIM115 @@ -144,7 +144,7 @@ def test_get_missing(self) -> None: reg = ProxyRegistry() assert reg.get("nonexistent") is None - def test_find_handler(self, tmp_path) -> None: + def test_find_handler(self, tmp_path: Any) -> None: reg = ProxyRegistry() register_builtin_handlers(reg) f = open(tmp_path / "test.txt", "w") # noqa: SIM115 @@ -167,7 +167,7 @@ def test_find_handler_no_match(self) -> None: class TestReconstruct: - def test_proxy_marker_reconstructed(self, tmp_path) -> None: + def test_proxy_marker_reconstructed(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("content") reg = ProxyRegistry() @@ -193,7 +193,7 @@ def test_proxy_marker_reconstructed(self, tmp_path) -> None: finally: cleanup_proxies(cleanup_list) - def test_cleanup_list_populated(self, tmp_path) -> None: + def test_cleanup_list_populated(self, tmp_path: Any) -> None: path = tmp_path / "data.txt" path.write_text("test") reg = ProxyRegistry() @@ -219,7 +219,7 @@ def test_cleanup_list_populated(self, tmp_path) -> None: cleanup_proxies(cleanup_list) assert obj.closed - def test_cleanup_runs_lifo(self, tmp_path) -> None: + def test_cleanup_runs_lifo(self, tmp_path: Any) -> None: """Cleanup runs in reverse reconstruction order.""" p1 = tmp_path / 
"a.txt" p1.write_text("a") @@ -270,7 +270,7 @@ def reconstruct(self, recipe: dict[str, Any], version: int) -> Any: def cleanup(self, obj: Any) -> None: raise RuntimeError("cleanup boom") - cleanup_list = [(BadHandler(), "obj")] # type: ignore[list-item] + cleanup_list: list[tuple[Any, Any]] = [(BadHandler(), "obj")] cleanup_proxies(cleanup_list) # should not raise def test_missing_handler_raises(self) -> None: @@ -300,7 +300,7 @@ def test_no_markers_passthrough(self) -> None: class TestIdentityTracking: - def test_same_object_deduped(self, tmp_path) -> None: + def test_same_object_deduped(self, tmp_path: Any) -> None: """Same file handle passed twice produces one marker and one ref.""" path = tmp_path / "data.txt" path.write_text("test") @@ -308,6 +308,7 @@ def test_same_object_deduped(self, tmp_path) -> None: queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") f = open(path) # noqa: SIM115 try: + assert queue._interceptor is not None walker = queue._interceptor._walker args, _kw, _res = walker.walk((f, f), {}) @@ -319,7 +320,7 @@ def test_same_object_deduped(self, tmp_path) -> None: finally: f.close() - def test_identity_reconstructed_once(self, tmp_path) -> None: + def test_identity_reconstructed_once(self, tmp_path: Any) -> None: """Reconstruction creates one object; both positions share it.""" path = tmp_path / "data.txt" path.write_text("shared") @@ -348,7 +349,7 @@ def test_identity_reconstructed_once(self, tmp_path) -> None: finally: cleanup_proxies(cleanup) - def test_different_objects_separate(self, tmp_path) -> None: + def test_different_objects_separate(self, tmp_path: Any) -> None: """Two different file handles get separate recipes.""" p1 = tmp_path / "a.txt" p1.write_text("a") @@ -359,6 +360,7 @@ def test_different_objects_separate(self, tmp_path) -> None: f1 = open(p1) # noqa: SIM115 f2 = open(p2) # noqa: SIM115 try: + assert queue._interceptor is not None walker = queue._interceptor._walker args, _, _ = walker.walk((f1, f2), {}) 
assert args[0].get("__taskito_proxy__") is True @@ -374,7 +376,7 @@ def test_different_objects_separate(self, tmp_path) -> None: # --------------------------------------------------------------------------- -def test_proxy_in_nested_dict(tmp_path) -> None: +def test_proxy_in_nested_dict(tmp_path: Any) -> None: """File inside a dict is proxied.""" path = tmp_path / "nested.txt" path.write_text("nested content") @@ -382,6 +384,7 @@ def test_proxy_in_nested_dict(tmp_path) -> None: queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") f = open(path) # noqa: SIM115 try: + assert queue._interceptor is not None walker = queue._interceptor._walker _, kwargs, _ = walker.walk((), {"config": {"file": f}}) inner = kwargs["config"]["file"] @@ -398,10 +401,10 @@ def test_proxy_in_nested_dict(tmp_path) -> None: def test_proxy_roundtrip_in_test_mode(queue: Queue) -> None: """In test mode, original objects pass through (no serialization).""" - captured: list = [] + captured: list[Any] = [] @queue.task() - def use_logger(lgr): + def use_logger(lgr: Any) -> None: captured.append(lgr) lgr = logging.getLogger("test.proxy.e2e") @@ -413,11 +416,12 @@ def use_logger(lgr): assert captured[0] is lgr -def test_logger_proxy_marker_production(tmp_path) -> None: +def test_logger_proxy_marker_production(tmp_path: Any) -> None: """Logger produces a proxy marker when interception is on.""" queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") lgr = logging.getLogger("test.proxy.marker") + assert queue._interceptor is not None walker = queue._interceptor._walker args, _, _ = walker.walk((lgr,), {}) assert args[0].get("__taskito_proxy__") is True diff --git a/tests/python/test_rate_limit.py b/tests/python/test_rate_limit.py index 9310d9e..6a03dd3 100644 --- a/tests/python/test_rate_limit.py +++ b/tests/python/test_rate_limit.py @@ -3,13 +3,15 @@ import threading import time +from taskito import Queue -def test_rate_limit_throttles(queue): + +def 
test_rate_limit_throttles(queue: Queue) -> None: """Rate-limited tasks should be throttled.""" - timestamps = [] + timestamps: list[float] = [] @queue.task(rate_limit="2/s") - def rate_limited_task(n): + def rate_limited_task(n: int) -> int: timestamps.append(time.monotonic()) return n diff --git a/tests/python/test_resource_system_full.py b/tests/python/test_resource_system_full.py index fb964a6..61fb9aa 100644 --- a/tests/python/test_resource_system_full.py +++ b/tests/python/test_resource_system_full.py @@ -128,6 +128,7 @@ class MyDB: q.register_type(MyDB, "redirect", resource="db") # Verify it's registered + assert q._interceptor is not None entry = q._interceptor._registry.resolve(MyDB()) assert entry is not None @@ -209,10 +210,10 @@ def test_pool_config(self) -> None: assert cfg.pool_min == 2 def test_resource_pool_acquire_release(self) -> None: - created = [] + created: list[dict[str, int]] = [] - def factory() -> dict: - d: dict = {"id": len(created)} + def factory() -> dict[str, int]: + d: dict[str, int] = {"id": len(created)} created.append(d) return d @@ -428,7 +429,7 @@ def test_status_includes_pool(self) -> None: class TestInjectAnnotation: def test_inject_alias_created(self) -> None: - alias = Inject["db"] + alias = Inject["db"] # type: ignore[type-arg,name-defined] assert isinstance(alias, _InjectAlias) assert alias.resource_name == "db" @@ -440,7 +441,7 @@ def test_inject_annotation_detected_in_task(self) -> None: q = Queue(db_path=":memory:") @q.task() - def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[valid-type] # noqa: F821 + def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[type-arg,name-defined] # noqa: F821 pass assert "db" in q._task_inject_map.get(my_task.name, []) @@ -449,7 +450,7 @@ def test_inject_annotation_merged_with_explicit(self) -> None: q = Queue(db_path=":memory:") @q.task(inject=["redis"]) - def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[valid-type] # noqa: F821 + def my_task(x: int, db: 
Inject["db"]) -> None: # type: ignore[type-arg,name-defined] # noqa: F821 pass injects = q._task_inject_map.get(my_task.name, []) @@ -624,7 +625,7 @@ def test_inject_annotation_in_test_mode(self) -> None: q = Queue(db_path=":memory:") @q.task() - def process(order_id: int, db: Inject["db"] = None) -> str: # type: ignore[valid-type,assignment] # noqa: F821 + def process(order_id: int, db: Inject["db"] = None) -> str: # type: ignore[type-arg,name-defined,assignment] # noqa: F821 return f"{order_id}:{db}" with q.test_mode(resources={"db": "injected"}) as results: diff --git a/tests/python/test_resources.py b/tests/python/test_resources.py index 28d150a..c94cc2b 100644 --- a/tests/python/test_resources.py +++ b/tests/python/test_resources.py @@ -3,6 +3,7 @@ from __future__ import annotations import time +from typing import Any import pytest @@ -28,7 +29,7 @@ def test_worker_resource_decorator_registers(queue: Queue) -> None: """@queue.worker_resource stores a ResourceDefinition.""" @queue.worker_resource("cache") - def create_cache(): + def create_cache() -> dict[str, int]: return {"hits": 0} assert "cache" in queue._resource_definitions @@ -40,7 +41,7 @@ def create_cache(): def test_register_resource_programmatic(queue: Queue) -> None: """register_resource() stores a definition without the decorator.""" - def factory(): + def factory() -> str: return "hello" queue.register_resource(ResourceDefinition(name="greeter", factory=factory)) @@ -57,13 +58,13 @@ def test_circular_dependency_detected(queue: Queue) -> None: """Circular deps raise CircularDependencyError at registration time.""" @queue.worker_resource("a", depends_on=["b"]) - def make_a(b): + def make_a(b: Any) -> str: return "a" with pytest.raises(CircularDependencyError): @queue.worker_resource("b", depends_on=["a"]) - def make_b(a): + def make_b(a: Any) -> str: return "b" @@ -129,10 +130,10 @@ def test_teardown_reverse_order() -> None: """Resources are torn down in reverse initialization order.""" 
teardown_log: list[str] = [] - def td_config(inst): + def td_config(inst: Any) -> None: teardown_log.append("config") - def td_db(inst): + def td_db(inst: Any) -> None: teardown_log.append("db") defs = { @@ -178,7 +179,7 @@ def test_from_test_overrides() -> None: def test_async_factory() -> None: """Async factories are awaited during initialize.""" - async def make_client(): + async def make_client() -> str: return "async_client" defs = { @@ -199,13 +200,13 @@ def test_resource_injected_into_task(queue: Queue) -> None: """Task with inject=["db"] receives the resource as a kwarg.""" @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "live_db" - results_holder: list = [] + results_holder: list[Any] = [] @queue.task(inject=["db"]) - def my_task(x: int, db): + def my_task(x: int, db: Any = None) -> None: results_holder.append((x, db)) with queue.test_mode(resources={"db": "test_db"}) as results: @@ -220,13 +221,13 @@ def test_explicit_kwarg_wins_over_inject(queue: Queue) -> None: """Caller-provided kwargs are not overridden by injection.""" @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "injected_db" - results_holder: list = [] + results_holder: list[Any] = [] @queue.task(inject=["db"]) - def my_task(db): + def my_task(db: Any = None) -> None: results_holder.append(db) with queue.test_mode(resources={"db": "injected_db"}): @@ -239,13 +240,13 @@ def test_test_mode_with_resources(queue: Queue) -> None: """test_mode(resources=...) 
injects mock resources.""" @queue.worker_resource("cache") - def create_cache(): + def create_cache() -> dict[str, str]: return {} - captured: list = [] + captured: list[Any] = [] @queue.task(inject=["cache"]) - def use_cache(cache): + def use_cache(cache: Any = None) -> None: captured.append(cache) mock_cache = {"key": "value"} @@ -280,14 +281,14 @@ def test_health_check_recreation() -> None: call_count = 0 - def make_svc(): + def make_svc() -> str: nonlocal call_count call_count += 1 if call_count > 1: raise RuntimeError("factory broken") return f"svc_v{call_count}" - def check_health(inst): + def check_health(inst: Any) -> bool: # Always fail after initial creation return False @@ -321,15 +322,15 @@ def check_health(inst): # --------------------------------------------------------------------------- -def test_banner_shows_resources(queue: Queue, capsys) -> None: +def test_banner_shows_resources(queue: Queue, capsys: pytest.CaptureFixture[str]) -> None: """Resources section appears in the startup banner.""" @queue.worker_resource("db", depends_on=["config"]) - def create_db(config): + def create_db(config: Any) -> str: return "db" @queue.worker_resource("config") - def create_config(): + def create_config() -> dict[str, str]: return {} queue._print_banner(["default"]) @@ -349,7 +350,7 @@ def test_task_wrapper_inject_property(queue: Queue) -> None: """TaskWrapper exposes the inject list.""" @queue.task(inject=["db", "cache"]) - def my_task(db, cache): + def my_task(db: Any = None, cache: Any = None) -> None: pass assert my_task.inject == ["db", "cache"] @@ -359,7 +360,7 @@ def test_task_wrapper_inject_default(queue: Queue) -> None: """TaskWrapper.inject defaults to empty list.""" @queue.task() - def my_task(): + def my_task() -> None: pass assert my_task.inject == [] diff --git a/tests/python/test_retry.py b/tests/python/test_retry.py index 523c8cc..fe55909 100644 --- a/tests/python/test_retry.py +++ b/tests/python/test_retry.py @@ -3,13 +3,15 @@ import threading 
import time +from taskito import Queue -def test_failing_task_retries(queue): + +def test_failing_task_retries(queue: Queue) -> None: """A failing task should be retried up to max_retries times.""" call_count = 0 @queue.task(max_retries=3, retry_backoff=0.1) - def flaky_task(): + def flaky_task() -> str: nonlocal call_count call_count += 1 if call_count < 3: @@ -29,11 +31,11 @@ def flaky_task(): assert call_count == 3 -def test_exhausted_retries_goes_to_dlq(queue): +def test_exhausted_retries_goes_to_dlq(queue: Queue) -> None: """A task that always fails should end up in the dead letter queue.""" @queue.task(max_retries=2, retry_backoff=0.1) - def always_fails(): + def always_fails() -> None: raise RuntimeError("permanent failure") always_fails.delay() @@ -53,11 +55,11 @@ def always_fails(): assert dead[0]["task_name"].endswith("always_fails") -def test_retry_dead_letter(queue): +def test_retry_dead_letter(queue: Queue) -> None: """A dead letter job can be re-enqueued.""" @queue.task(max_retries=1, retry_backoff=0.1) - def fail_once(): + def fail_once() -> None: raise RuntimeError("fail") fail_once.delay() diff --git a/tests/python/test_retry_history.py b/tests/python/test_retry_history.py index c6307b8..5ab7fed 100644 --- a/tests/python/test_retry_history.py +++ b/tests/python/test_retry_history.py @@ -1,6 +1,7 @@ """Tests for retry history (job_errors tracking).""" import threading +from pathlib import Path import pytest @@ -8,17 +9,17 @@ @pytest.fixture -def queue(tmp_path): +def queue(tmp_path: Path) -> Queue: db_path = str(tmp_path / "test_retry_history.db") return Queue(db_path=db_path, workers=1) -def test_retry_errors_recorded(queue): +def test_retry_errors_recorded(queue: Queue) -> None: """Failed attempts are recorded in job.errors.""" call_count = {"n": 0} @queue.task(max_retries=3, retry_backoff=0.01) - def flaky(): + def flaky() -> str: call_count["n"] += 1 if call_count["n"] <= 3: raise ValueError(f"attempt {call_count['n']}") @@ -40,11 +41,11 @@ def 
flaky(): assert errors[2]["attempt"] == 2 -def test_errors_empty_on_success(queue): +def test_errors_empty_on_success(queue: Queue) -> None: """Successful jobs have an empty errors list.""" @queue.task() - def ok_task(): + def ok_task() -> int: return 42 job = ok_task.delay() diff --git a/tests/python/test_serializers.py b/tests/python/test_serializers.py index 5c11b3a..977a718 100644 --- a/tests/python/test_serializers.py +++ b/tests/python/test_serializers.py @@ -9,69 +9,69 @@ class TestJsonSerializer: - def test_roundtrip_dict(self): + def test_roundtrip_dict(self) -> None: s = JsonSerializer() data = {"key": "value", "num": 42, "nested": [1, 2, 3]} assert s.loads(s.dumps(data)) == data - def test_roundtrip_list(self): + def test_roundtrip_list(self) -> None: s = JsonSerializer() data = [1, "two", None, True] assert s.loads(s.dumps(data)) == data - def test_roundtrip_primitives(self): + def test_roundtrip_primitives(self) -> None: s = JsonSerializer() for val in [42, 3.14, "hello", True, None]: assert s.loads(s.dumps(val)) == val - def test_dumps_returns_bytes(self): + def test_dumps_returns_bytes(self) -> None: s = JsonSerializer() result = s.dumps({"a": 1}) assert isinstance(result, bytes) - def test_non_serializable_raises(self): + def test_non_serializable_raises(self) -> None: s = JsonSerializer() with pytest.raises(TypeError): s.dumps(object()) - def test_invalid_bytes_raises(self): + def test_invalid_bytes_raises(self) -> None: s = JsonSerializer() with pytest.raises((json.JSONDecodeError, UnicodeDecodeError, ValueError)): s.loads(b"\xff\xfe") class TestCloudpickleSerializer: - def test_roundtrip_dict(self): + def test_roundtrip_dict(self) -> None: s = CloudpickleSerializer() data = {"key": "value", "num": 42} assert s.loads(s.dumps(data)) == data - def test_roundtrip_lambda(self): + def test_roundtrip_lambda(self) -> None: s = CloudpickleSerializer() fn = lambda x: x * 2 # noqa: E731 restored = s.loads(s.dumps(fn)) assert restored(5) == 10 - def 
test_dumps_returns_bytes(self): + def test_dumps_returns_bytes(self) -> None: s = CloudpickleSerializer() assert isinstance(s.dumps(42), bytes) - def test_invalid_bytes_raises(self): + def test_invalid_bytes_raises(self) -> None: s = CloudpickleSerializer() with pytest.raises((pickle.UnpicklingError, EOFError)): s.loads(b"not-valid-pickle") class TestSerializerProtocol: - def test_json_is_serializer(self): + def test_json_is_serializer(self) -> None: assert isinstance(JsonSerializer(), Serializer) - def test_cloudpickle_is_serializer(self): + def test_cloudpickle_is_serializer(self) -> None: assert isinstance(CloudpickleSerializer(), Serializer) class TestMsgPackSerializer: - def test_roundtrip(self): + def test_roundtrip(self) -> None: pytest.importorskip("msgpack") from taskito.serializers import MsgPackSerializer @@ -79,7 +79,7 @@ def test_roundtrip(self): data = {"key": "value", "num": 42} assert s.loads(s.dumps(data)) == data - def test_dumps_returns_bytes(self): + def test_dumps_returns_bytes(self) -> None: pytest.importorskip("msgpack") from taskito.serializers import MsgPackSerializer @@ -88,7 +88,7 @@ def test_dumps_returns_bytes(self): class TestEncryptedSerializer: - def test_roundtrip(self): + def test_roundtrip(self) -> None: pytest.importorskip("cryptography") import os @@ -99,7 +99,7 @@ def test_roundtrip(self): data = {"secret": "payload"} assert s.loads(s.dumps(data)) == data - def test_wrong_key_fails(self): + def test_wrong_key_fails(self) -> None: pytest.importorskip("cryptography") import os @@ -113,7 +113,7 @@ def test_wrong_key_fails(self): with pytest.raises(InvalidTag): s2.loads(encrypted) - def test_tampered_ciphertext_fails(self): + def test_tampered_ciphertext_fails(self) -> None: pytest.importorskip("cryptography") import os diff --git a/tests/python/test_shutdown.py b/tests/python/test_shutdown.py index 515abcf..7c063b8 100644 --- a/tests/python/test_shutdown.py +++ b/tests/python/test_shutdown.py @@ -3,13 +3,15 @@ import threading 
import time +from taskito import Queue -def test_graceful_shutdown_completes_inflight(queue): + +def test_graceful_shutdown_completes_inflight(queue: Queue) -> None: """Graceful shutdown waits for in-flight tasks to complete.""" completed = threading.Event() @queue.task() - def slow_task(): + def slow_task() -> str: time.sleep(1) completed.set() return "done" @@ -30,14 +32,15 @@ def slow_task(): assert completed.is_set() fetched = queue.get_job(job.id) + assert fetched is not None assert fetched.status == "complete" -def test_shutdown_stops_worker(queue): +def test_shutdown_stops_worker(queue: Queue) -> None: """request_shutdown causes run_worker to return.""" @queue.task() - def noop(): + def noop() -> None: pass worker_thread = threading.Thread(target=queue.run_worker, daemon=True) diff --git a/tests/python/test_unique.py b/tests/python/test_unique.py index 0cb1e36..c428bfc 100644 --- a/tests/python/test_unique.py +++ b/tests/python/test_unique.py @@ -4,12 +4,14 @@ import threading +from taskito import Queue -def test_unique_key_dedup(queue): + +def test_unique_key_dedup(queue: Queue) -> None: """Two jobs with the same unique_key should return the same job ID.""" @queue.task() - def process(data): + def process(data: str) -> str: return data job1 = process.apply_async(args=("a",), unique_key="dedup-1") @@ -18,11 +20,11 @@ def process(data): assert job1.id == job2.id -def test_different_unique_keys(queue): +def test_different_unique_keys(queue: Queue) -> None: """Different unique keys should create separate jobs.""" @queue.task() - def process(data): + def process(data: str) -> str: return data job1 = process.apply_async(args=("a",), unique_key="key-a") @@ -31,11 +33,11 @@ def process(data): assert job1.id != job2.id -def test_unique_key_allows_after_complete(queue): +def test_unique_key_allows_after_complete(queue: Queue) -> None: """After a unique job completes, a new one with the same key can be created.""" @queue.task() - def fast_task(): + def fast_task() -> 
str: return "done" job1 = fast_task.apply_async(unique_key="once") @@ -50,11 +52,11 @@ def fast_task(): assert job2.id != job1.id -def test_no_unique_key_allows_duplicates(queue): +def test_no_unique_key_allows_duplicates(queue: Queue) -> None: """Without unique_key, duplicate jobs are allowed.""" @queue.task() - def process(data): + def process(data: str) -> str: return data job1 = process.delay("a") diff --git a/tests/python/test_webhooks.py b/tests/python/test_webhooks.py index 11c4451..39a47f6 100644 --- a/tests/python/test_webhooks.py +++ b/tests/python/test_webhooks.py @@ -5,7 +5,9 @@ import json import threading import time +from collections.abc import Generator from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any import pytest @@ -14,12 +16,12 @@ @pytest.fixture -def webhook_server(): +def webhook_server() -> Generator[tuple[str, list[dict[str, Any]]]]: """Start a local HTTP server that records webhook deliveries.""" - received = [] + received: list[dict[str, Any]] = [] class Handler(BaseHTTPRequestHandler): - def do_POST(self): + def do_POST(self) -> None: length = int(self.headers.get("Content-Length", 0)) body = self.rfile.read(length) received.append( @@ -31,7 +33,7 @@ def do_POST(self): self.send_response(200) self.end_headers() - def log_message(self, *args): + def log_message(self, *args: Any) -> None: pass server = HTTPServer(("127.0.0.1", 0), Handler) @@ -44,7 +46,7 @@ def log_message(self, *args): server.shutdown() -def test_webhook_delivery(webhook_server): +def test_webhook_delivery(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: """Webhooks are delivered to registered URLs.""" url, received = webhook_server mgr = WebhookManager() @@ -58,7 +60,7 @@ def test_webhook_delivery(webhook_server): assert received[0]["body"]["job_id"] == "abc" -def test_webhook_event_filtering(webhook_server): +def test_webhook_event_filtering(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: """Webhooks with event 
filters only receive matching events.""" url, received = webhook_server mgr = WebhookManager() @@ -72,7 +74,7 @@ def test_webhook_event_filtering(webhook_server): assert received[0]["body"]["event"] == "job.failed" -def test_webhook_hmac_signing(webhook_server): +def test_webhook_hmac_signing(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: """Webhooks with a secret include a valid HMAC signature.""" url, received = webhook_server secret = "my-secret-key" @@ -93,7 +95,7 @@ def test_webhook_hmac_signing(webhook_server): assert sig_header == f"sha256={expected_sig}" -def test_webhook_url_validation(): +def test_webhook_url_validation() -> None: """Only http:// and https:// URLs are accepted.""" mgr = WebhookManager() @@ -108,7 +110,7 @@ def test_webhook_url_validation(): mgr.add_webhook("https://example.com/hook") -def test_webhook_custom_headers(webhook_server): +def test_webhook_custom_headers(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: """Custom headers are included in webhook requests.""" url, received = webhook_server mgr = WebhookManager() @@ -121,7 +123,7 @@ def test_webhook_custom_headers(webhook_server): assert received[0]["headers"].get("X-Custom") == "test-value" -def test_webhook_no_subscribers(): +def test_webhook_no_subscribers() -> None: """Notifying with no matching webhooks doesn't raise.""" mgr = WebhookManager() mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) diff --git a/tests/python/test_worker.py b/tests/python/test_worker.py index ade8448..f2ae258 100644 --- a/tests/python/test_worker.py +++ b/tests/python/test_worker.py @@ -2,21 +2,22 @@ import threading import time +from pathlib import Path import pytest from taskito import Queue -def test_multiple_tasks(queue): +def test_multiple_tasks(queue: Queue) -> None: """Worker handles multiple different task types.""" @queue.task() - def task_a(x): + def task_a(x: int) -> int: return x * 2 @queue.task() - def task_b(x): + def task_b(x: int) -> int: return x + 10 job_a = 
task_a.delay(5) @@ -32,11 +33,11 @@ def task_b(x): assert job_b.result(timeout=10) == 15 -def test_get_job(queue): +def test_get_job(queue: Queue) -> None: """Can retrieve a job by ID.""" @queue.task() - def simple(): + def simple() -> int: return 42 job = simple.delay() @@ -45,11 +46,12 @@ def simple(): + assert fetched is not None assert fetched.id == job.id -def test_job_status_progression(queue): +def test_job_status_progression(queue: Queue) -> None: """Job status progresses from pending through complete.""" @queue.task() - def slow(): + def slow() -> str: time.sleep(0.5) return "done" @@ -57,6 +59,7 @@ def slow(): # Initially pending fetched = queue.get_job(job.id) + assert fetched is not None assert fetched.status == "pending" worker_thread = threading.Thread( @@ -70,17 +73,18 @@ def slow(): # After completion fetched = queue.get_job(job.id) + assert fetched is not None assert fetched.status == "complete" @pytest.mark.asyncio -async def test_async_result(tmp_path): +async def test_async_result(tmp_path: Path) -> None: """Async result retrieval works.""" db_path = str(tmp_path / "test_async.db") queue = Queue(db_path=db_path, workers=2) @queue.task() - def add(a, b): + def add(a: int, b: int) -> int: return a + b job = add.delay(10, 20) @@ -96,13 +100,13 @@ def add(a, b): @pytest.mark.asyncio -async def test_async_stats(tmp_path): +async def test_async_stats(tmp_path: Path) -> None: """Async stats work.""" db_path = str(tmp_path / "test_async_stats.db") queue = Queue(db_path=db_path, workers=2) @queue.task() - def noop(): + def noop() -> None: pass noop.delay() diff --git a/tests/python/test_worker_resources.py b/tests/python/test_worker_resources.py index 460a510..dac4301 100644 --- a/tests/python/test_worker_resources.py +++ b/tests/python/test_worker_resources.py @@ -5,6 +5,7 @@ import json import threading import time +from typing import Any from taskito import Queue @@ -14,21 +15,21 @@ class TestWorkerAdvertisement: - def test_no_resources_returns_none(self, tmp_path) -> None: + def
test_no_resources_returns_none(self, tmp_path: Any) -> None: """_build_resource_health_json returns None when no resources.""" queue = Queue(db_path=str(tmp_path / "q.db")) assert queue._build_resource_health_json() is None - def test_build_resource_health_json_with_resources(self, tmp_path) -> None: + def test_build_resource_health_json_with_resources(self, tmp_path: Any) -> None: """_build_resource_health_json returns correct JSON.""" queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "db_instance" @queue.worker_resource("cache") - def create_cache(): + def create_cache() -> str: return "cache_instance" health_json = queue._build_resource_health_json() @@ -36,14 +37,14 @@ def create_cache(): health = json.loads(health_json) assert health == {"db": "healthy", "cache": "healthy"} - def test_build_resource_health_reflects_unhealthy(self, tmp_path) -> None: + def test_build_resource_health_reflects_unhealthy(self, tmp_path: Any) -> None: """_build_resource_health_json marks unhealthy resources.""" from taskito.resources.runtime import ResourceRuntime queue = Queue(db_path=str(tmp_path / "q.db")) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "db_instance" # Simulate an initialized runtime with an unhealthy resource @@ -58,19 +59,19 @@ def create_db(): health = json.loads(health_json) assert health["db"] == "unhealthy" - def test_worker_heartbeat_method(self, tmp_path) -> None: + def test_worker_heartbeat_method(self, tmp_path: Any) -> None: """worker_heartbeat can be called without error.""" queue = Queue(db_path=str(tmp_path / "q.db")) # Heartbeat for a non-existent worker is a no-op (updates 0 rows) queue._inner.worker_heartbeat("nonexistent-worker") - def test_worker_heartbeat_with_health(self, tmp_path) -> None: + def test_worker_heartbeat_with_health(self, tmp_path: Any) -> None: """worker_heartbeat accepts resource_health JSON.""" queue = 
Queue(db_path=str(tmp_path / "q.db")) health = json.dumps({"db": "healthy"}) queue._inner.worker_heartbeat("w-test", health) - def test_list_workers_empty(self, tmp_path) -> None: + def test_list_workers_empty(self, tmp_path: Any) -> None: """list_workers returns empty list when no workers registered.""" queue = Queue(db_path=str(tmp_path / "q.db")) workers = queue.workers() @@ -83,16 +84,16 @@ def test_list_workers_empty(self, tmp_path) -> None: class TestWorkerResourceIntegration: - def test_worker_advertises_resources_and_threads(self, tmp_path) -> None: + def test_worker_advertises_resources_and_threads(self, tmp_path: Any) -> None: """A running worker stores resources and threads in storage.""" queue = Queue(db_path=str(tmp_path / "q.db"), workers=2) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "db_instance" @queue.task() - def noop(): + def noop() -> None: pass # Start worker in thread, wait for it to register @@ -103,7 +104,7 @@ def noop(): # Poll until a worker appears with resource_health populated # (initial registration has None; first heartbeat sets it) deadline = time.monotonic() + 15 - workers = [] + workers: list[dict[str, Any]] = [] while time.monotonic() < deadline: workers = queue.workers() if workers and workers[0].get("resource_health") is not None: @@ -128,12 +129,12 @@ def noop(): queue._inner.request_shutdown() thread.join(timeout=10) - def test_worker_no_resources(self, tmp_path) -> None: + def test_worker_no_resources(self, tmp_path: Any) -> None: """Worker without resources stores None for resource fields.""" queue = Queue(db_path=str(tmp_path / "q.db"), workers=1) @queue.task() - def noop(): + def noop() -> None: pass thread = threading.Thread(target=queue.run_worker, daemon=True) @@ -141,7 +142,7 @@ def noop(): try: deadline = time.monotonic() + 10 - workers = [] + workers: list[dict[str, Any]] = [] while time.monotonic() < deadline: workers = queue.workers() if workers: @@ -156,16 +157,16 @@ def 
noop(): queue._inner.request_shutdown() thread.join(timeout=10) - def test_heartbeat_updates_health(self, tmp_path) -> None: + def test_heartbeat_updates_health(self, tmp_path: Any) -> None: """Heartbeat thread updates resource_health in storage.""" queue = Queue(db_path=str(tmp_path / "q.db"), workers=1) @queue.worker_resource("db") - def create_db(): + def create_db() -> str: return "db_instance" @queue.task() - def noop(): + def noop() -> None: pass thread = threading.Thread(target=queue.run_worker, daemon=True) @@ -174,7 +175,7 @@ def noop(): try: # Wait for worker + first heartbeat deadline = time.monotonic() + 10 - workers = [] + workers: list[dict[str, Any]] = [] while time.monotonic() < deadline: workers = queue.workers() if workers and workers[0].get("resource_health"):