diff --git a/examples/03_data_flow/README.md b/examples/03_data_flow/README.md index b092adc..c43585d 100644 --- a/examples/03_data_flow/README.md +++ b/examples/03_data_flow/README.md @@ -11,7 +11,8 @@ This section demonstrates **data flow and inter-task communication** in Graflow ## What You'll Learn - 📡 Using channels for inter-task communication -- 🔒 Type-safe channels with TypedDict +- 🔒 Thread-safe channel operations (`atomic_add`, `lock`) +- 🏷️ Type-safe channels with TypedDict - 💾 Storing and retrieving task results - 🔄 Data flow patterns in workflows - 📊 Sharing state across task boundaries @@ -37,7 +38,25 @@ uv run python examples/03_data_flow/channels_basic.py --- -### 2. typed_channels.py +### 2. channel_concurrency.py + +**Concept**: Thread-safe channel operations + +Learn how to safely update shared channel data when tasks run in parallel using `ParallelGroup`. + +```bash +uv run python examples/03_data_flow/channel_concurrency.py +``` + +**Key Concepts**: +- Race conditions with naive `get`/`set` under concurrency +- Atomic counter updates with `channel.atomic_add()` +- Advisory locking with `channel.lock()` for compound operations +- Threshold-based reset pattern with multi-key updates + +--- + +### 3. typed_channels.py **Concept**: Type-safe channels @@ -56,7 +75,7 @@ uv run python examples/03_data_flow/typed_channels.py --- -### 3. results_storage.py +### 4. results_storage.py **Concept**: Task results and dependency data @@ -107,7 +126,32 @@ def process_batch(ctx: TaskExecutionContext): channel.set("total_processed", total) ``` -### Pattern 3: Type-Safe Messages +### Pattern 3: Thread-Safe Counter + +```python +@task(inject_context=True) +def count_items(ctx: TaskExecutionContext): + channel = ctx.get_channel() + # Atomic — safe for parallel tasks + channel.atomic_add("processed_count", 1) +``` + +### Pattern 4: Compound Update with Lock + +```python +@task(inject_context=True) +def check_and_reset(ctx: TaskExecutionContext): + channel = ctx.get_channel() + with channel.lock("counter"): + val = channel.get("counter") + if val >= threshold: + channel.set("counter", 0) + channel.atomic_add("overflow_count", 1) + else: + channel.set("counter", val + 1) +``` + +### Pattern 5: Type-Safe Messages ```python from typing import TypedDict @@ -129,7 +173,7 @@ def collect_metrics(ctx: TaskExecutionContext): typed_channel.set("metrics", metrics) ``` -### Pattern 4: Result Dependencies +### Pattern 6: Result Dependencies ```python with workflow("pipeline") as ctx: @@ -310,9 +354,8 @@ def use_config(ctx: TaskExecutionContext): @task(inject_context=True) def track_metrics(ctx: TaskExecutionContext): channel = ctx.get_channel() - metrics = channel.get("metrics", {}) - metrics["processed"] = metrics.get("processed", 0) + 1 - channel.set("metrics", metrics) + # Thread-safe counter — works in parallel tasks + channel.atomic_add("processed", 1) ``` ### Error Accumulation @@ -374,6 +417,8 @@ After mastering data flow: - `channel.set(key, value)` - Store value - `channel.get(key, default=None)` - Retrieve value - `channel.keys()` - List all keys +- `channel.atomic_add(key, amount=1)` - Atomic numeric add/subtract (thread-safe for `MemoryChannel`; single-command atomic for Redis) +- `channel.lock(key, timeout=10.0)` - Advisory per-key lock for compound read-modify-write operations (thread-safe via `threading.RLock`) **TypedChannel**: - `typed_channel = ctx.get_typed_channel(SchemaClass)` - Create typed channel diff --git a/examples/03_data_flow/channel_concurrency.py b/examples/03_data_flow/channel_concurrency.py new file mode 100644 index 0000000..29ccc73 --- /dev/null +++ b/examples/03_data_flow/channel_concurrency.py @@ -0,0 +1,209 @@ +""" +Channel Concurrency Example +============================ + +This example demonstrates how to safely share and update channel data +when tasks run in parallel using ``ParallelGroup``. + +Problem +------- +A naive read-modify-write (``get`` -> compute -> ``set``) is **not** atomic. +When multiple tasks execute in parallel threads, updates can be lost because +two tasks may read the same value and overwrite each other's writes. + +Solutions +--------- +1. **``channel.atomic_add(key, amount)``** — Atomic numeric add (inc/dec). + Backed by a per-key lock in MemoryChannel and ``INCRBYFLOAT`` in Redis. + +2. **``channel.lock(key)``** — Advisory lock for arbitrary compound + operations. ``MemoryChannel`` uses a per-key ``threading.RLock`` for + in-process coordination; ``RedisChannel`` uses ``redis.lock.Lock`` + (SET NX + Lua release) for cross-client distributed coordination. + +Expected Output +--------------- +=== Channel Concurrency Demo === + +--- Unsafe parallel increment (race condition) --- +Expected counter: 500, Actual: +Updates lost! + +--- Safe parallel increment with atomic_add() --- +Expected counter: 500, Actual: 500 + +--- Safe compound update with lock() --- +Overflow events: 5 +Counter after resets: 0 + +Done! +""" + +from graflow.core.context import TaskExecutionContext +from graflow.core.decorators import task +from graflow.core.task import ParallelGroup +from graflow.core.workflow import workflow + + +def demo_unsafe_increment() -> None: + """Show that naive get/set loses updates in parallel execution.""" + import time + + print("--- Unsafe parallel increment (race condition) ---") + + num_workers = 5 + increments_per_worker = 100 + expected = num_workers * increments_per_worker + + with workflow("unsafe_demo") as ctx: + + @task(inject_context=True) + def init_counter(context: TaskExecutionContext): + context.get_channel().set("counter", 0) + + # Create worker tasks that use naive get/set + workers = [] + for i in range(num_workers): + + @task(inject_context=True, id=f"unsafe_worker_{i}") + def unsafe_worker(context: TaskExecutionContext): + channel = context.get_channel() + for _ in range(increments_per_worker): + val = channel.get("counter") + time.sleep(0) # yield to trigger interleaving + channel.set("counter", val + 1) + + workers.append(unsafe_worker) + + @task(inject_context=True) + def report(context: TaskExecutionContext): + actual = context.get_channel().get("counter") + print(f" Expected counter: {expected}, Actual: {actual}") + if actual < expected: + print(" Updates lost!\n") + else: + print(" (Got lucky — no interleaving this run)\n") + + parallel = ParallelGroup(workers, name="unsafe_group") + _ = init_counter >> parallel >> report + ctx.execute("init_counter") + + +def demo_atomic_add() -> None: + """Show that atomic_add() is safe for parallel numeric updates.""" + print("--- Safe parallel increment with atomic_add() ---") + + num_workers = 5 + increments_per_worker = 100 + expected = num_workers * increments_per_worker + + with workflow("add_demo") as ctx: + + @task(inject_context=True) + def init_counter(context: TaskExecutionContext): + context.get_channel().set("counter", 0) + + workers = [] + for i in range(num_workers): + + @task(inject_context=True, id=f"add_worker_{i}") + def add_worker(context: TaskExecutionContext): + channel = context.get_channel() + for _ in range(increments_per_worker): + channel.atomic_add("counter", 1) + + workers.append(add_worker) + + @task(inject_context=True) + def report(context: TaskExecutionContext): + actual = context.get_channel().get("counter") + print(f" Expected counter: {expected}, Actual: {actual}\n") + + parallel = ParallelGroup(workers, name="add_group") + _ = init_counter >> parallel >> report + ctx.execute("init_counter") + + +def demo_advisory_lock() -> None: + """Show lock() for compound read-modify-write that atomic_add() can't express.""" + print("--- Safe compound update with lock() ---") + + threshold = 10 + num_workers = 5 + increments_per_worker = 10 + + with workflow("lock_demo") as ctx: + + @task(inject_context=True) + def init(context: TaskExecutionContext): + channel = context.get_channel() + channel.set("counter", 0) + channel.set("overflow_count", 0) + + workers = [] + for i in range(num_workers): + + @task(inject_context=True, id=f"lock_worker_{i}") + def lock_worker(context: TaskExecutionContext): + channel = context.get_channel() + for _ in range(increments_per_worker): + # Advisory lock protects the entire read-modify-write block + with channel.lock("counter"): + val = channel.get("counter") + if val >= threshold: + channel.set("counter", 0) + channel.atomic_add("overflow_count", 1) + else: + channel.set("counter", val + 1) + + workers.append(lock_worker) + + @task(inject_context=True) + def report(context: TaskExecutionContext): + channel = context.get_channel() + overflows = channel.get("overflow_count") + counter = channel.get("counter") + print(f" Overflow events: {overflows}") + print(f" Counter after resets: {counter}\n") + + parallel = ParallelGroup(workers, name="lock_group") + _ = init >> parallel >> report + ctx.execute("init") + + +def main(): + print("=== Channel Concurrency Demo ===\n") + demo_unsafe_increment() + demo_atomic_add() + demo_advisory_lock() + print("Done!") + + +if __name__ == "__main__": + main() + + +# ============================================================================ +# Key Takeaways: +# ============================================================================ +# +# 1. **channel.atomic_add(key, amount)** +# - Atomic numeric add/subtract — no lost updates +# - Initialises missing keys to 0 automatically +# - MemoryChannel: per-key RLock; Redis: INCRBYFLOAT (server-side atomic) +# - Use for counters, metrics, scores +# +# 2. **channel.lock(key)** +# - Advisory lock for compound operations that atomic_add() can't express +# - Wrap with ``with channel.lock(key):`` context manager +# - MemoryChannel: per-key RLock; Redis: distributed lock for the same key +# - Use for conditional updates and other compound read-modify-write logic +# +# 3. **When to use which** +# - Simple counter? → channel.atomic_add("counter", 1) +# - Decrement? → channel.atomic_add("counter", -1) +# - Conditional update? → with channel.lock("key"): ... +# - Multi-key update? → with channel.lock("key"): ... +# - No concurrency concern? → channel.get() / channel.set() is fine +# +# ============================================================================ diff --git a/graflow/channels/base.py b/graflow/channels/base.py index 27dfabd..9fcbe3e 100644 --- a/graflow/channels/base.py +++ b/graflow/channels/base.py @@ -3,7 +3,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Optional +from contextlib import contextmanager +from typing import Any, Iterator, List, Optional, Union class Channel(ABC): @@ -69,3 +70,61 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: Length of the list after prepend """ pass + + @abstractmethod + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* to the numeric value stored at *key*. + + If *key* does not exist, it is initialised to 0 before the addition. + Negative *amount* is allowed (decrement). + + Args: + key: The key identifying the numeric value. + amount: The value to add (default 1). May be negative. + + Returns: + The value after the addition. + + Raises: + TypeError: If the existing value is not numeric. + """ + pass + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire an advisory lock scoped to *key* for compound operations. + + Usage:: + + with channel.lock("counter"): + val = channel.get("counter") + channel.set("counter", val * 2 if val > 0 else 0) + + The lock is *advisory* — regular ``get``/``set`` calls do **not** + acquire it automatically. It exists for task authors who need to + protect read-modify-write sequences that cannot be expressed with + ``atomic_add()``. + + Both ``MemoryChannel`` (threading lock) and ``RedisChannel`` + (``redis.lock.Lock`` — distributed SET NX + Lua release) provide + real mutual exclusion. + + .. warning:: + + The **base class** default is a **no-op** (yields immediately) + and provides **no mutual exclusion**. Custom subclasses that + need compound-operation safety **must** override this method. + + Args: + key: Logical key to lock on (does not need to correspond to a + stored key). + timeout: Maximum seconds to wait for the lock. + + Raises: + TimeoutError: If the lock cannot be acquired within *timeout* + (raised by ``MemoryChannel`` and ``RedisChannel``). + + Yields: + None — the lock is held for the duration of the ``with`` block. + """ + yield diff --git a/graflow/channels/memory_channel.py b/graflow/channels/memory_channel.py index 225ff86..ddb12d9 100644 --- a/graflow/channels/memory_channel.py +++ b/graflow/channels/memory_channel.py @@ -2,8 +2,10 @@ from __future__ import annotations +import threading import time -from typing import Any, Dict, List, Optional +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Union from graflow.channels.base import Channel @@ -16,6 +18,21 @@ def __init__(self, name: str, **kwargs): super().__init__(name) self.data: Dict[str, Any] = {} self.ttl_data: Dict[str, float] = {} + self._key_locks: Dict[str, threading.RLock] = {} + self._key_locks_guard = threading.Lock() # protects _key_locks dict itself + + def __getstate__(self) -> Dict[str, Any]: + """Exclude unpicklable lock objects during serialization.""" + state = self.__dict__.copy() + state.pop("_key_locks", None) + state.pop("_key_locks_guard", None) + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Recreate lock objects after deserialization.""" + self.__dict__.update(state) # type: ignore[assignment] + self._key_locks = {} + self._key_locks_guard = threading.Lock() def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: """Store data in the channel.""" @@ -122,3 +139,59 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: self.ttl_data[key] = time.time() + ttl return len(self.data[key]) + + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* to the numeric value stored at *key*. + + Thread-safe: acquires a per-key lock so concurrent calls do not lose + updates. + + Args: + key: The key identifying the numeric value. + amount: The value to add (default 1). May be negative. + + Returns: + The value after the addition. + + Raises: + TypeError: If the existing value is not numeric. + """ + with self._get_key_lock(key): + self._cleanup_expired(key) + current = self.data.get(key, 0) + if not isinstance(current, (int, float)): + raise TypeError(f"Key '{key}' holds {type(current).__name__}, expected int or float") + new_value = current + amount + self.data[key] = new_value + return new_value + + # -- advisory locking for compound operations -- + + def _get_key_lock(self, key: str) -> threading.RLock: + """Return (or create) the RLock associated with *key*.""" + if key not in self._key_locks: + with self._key_locks_guard: + # Double-checked locking + if key not in self._key_locks: + self._key_locks[key] = threading.RLock() + return self._key_locks[key] + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire an advisory per-key lock for compound read-modify-write. + + Args: + key: Logical key to lock on. + timeout: Maximum seconds to wait for the lock. + + Raises: + TimeoutError: If the lock cannot be acquired within *timeout*. + """ + rlock = self._get_key_lock(key) + acquired = rlock.acquire(timeout=timeout) + if not acquired: + raise TimeoutError(f"Could not acquire lock for key '{key}' within {timeout}s") + try: + yield + finally: + rlock.release() diff --git a/graflow/channels/redis_channel.py b/graflow/channels/redis_channel.py index 508ce7d..ef8f4ac 100644 --- a/graflow/channels/redis_channel.py +++ b/graflow/channels/redis_channel.py @@ -3,13 +3,16 @@ from __future__ import annotations import json -from typing import Any, List, Optional, cast +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union, cast from graflow.channels.base import Channel +if TYPE_CHECKING: + from redis import Redis + try: import redis - from redis import Redis except ImportError: redis = None @@ -217,6 +220,56 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: return cast(int, length) + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* using Redis INCRBYFLOAT. + + This is a server-side atomic operation — no client-side locking needed. + + Args: + key: The key identifying the numeric value. + amount: The value to add (default 1). May be negative. + + Returns: + The value after the addition. + """ + redis_key = self._get_key(key) + # INCRBYFLOAT returns the new value as a string; auto-creates key at 0. + result = self.redis_client.incrbyfloat(redis_key, amount) + new_value = float(cast(str, result)) + # Return int when possible for ergonomics + if new_value == int(new_value): + return int(new_value) + return new_value + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire a distributed advisory lock scoped to *key*. + + Uses ``redis.lock.Lock`` (SET NX + Lua release) under the hood. + + Args: + key: Logical key to lock on. + timeout: Maximum seconds to wait for the lock. + + Raises: + TimeoutError: If the lock cannot be acquired within *timeout*. + """ + from redis.exceptions import LockNotOwnedError + from redis.lock import Lock + + lock_name = f"{self.key_prefix}lock:{key}" + lock = Lock(self.redis_client, lock_name, timeout=timeout) + acquired = lock.acquire(blocking=True, blocking_timeout=timeout) + if not acquired: + raise TimeoutError(f"Could not acquire lock for key '{key}' within {timeout}s") + try: + yield + finally: + try: + lock.release() + except LockNotOwnedError: + pass + def __getstate__(self): """Support for pickle serialization.""" state = self.__dict__.copy() @@ -226,7 +279,7 @@ def __getstate__(self): def __setstate__(self, state): """Support for pickle deserialization.""" - self.__dict__.update(state) + self.__dict__.update(state) # type: ignore # Recreate the Redis client assert redis is not None, "redis package is required for RedisChannel" self.redis_client = redis.Redis( diff --git a/graflow/channels/typed.py b/graflow/channels/typed.py index 679e482..e0c3727 100644 --- a/graflow/channels/typed.py +++ b/graflow/channels/typed.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, ClassVar, Generic, Type, TypeVar, get_type_hints +from contextlib import contextmanager +from typing import Any, ClassVar, Generic, Iterator, Type, TypeVar, Union, get_type_hints from graflow.channels.base import Channel from graflow.exceptions import ConfigError @@ -148,6 +149,16 @@ def prepend(self, key: str, value: Any, ttl: int | None = None) -> int: """ return self._channel.prepend(key, value, ttl) + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* to the numeric value stored at *key*.""" + return self._channel.atomic_add(key, amount) + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire advisory lock — delegates to underlying channel.""" + with self._channel.lock(key, timeout): + yield + def send(self, key: str, message: T, ttl: int | None = None) -> None: """Send a typed message. diff --git a/tests/channels/test_memory_channel_thread_safety.py b/tests/channels/test_memory_channel_thread_safety.py new file mode 100644 index 0000000..5c37b58 --- /dev/null +++ b/tests/channels/test_memory_channel_thread_safety.py @@ -0,0 +1,440 @@ +"""Tests for MemoryChannel thread safety and race conditions. + +Demonstrates that concurrent read-modify-write on MemoryChannel +without locks leads to lost updates (race condition), and that +the ``atomic_add()`` / ``lock()`` APIs solve the problem. + +CPython's GIL makes individual dict operations (get/set) effectively atomic, +but the *compound* read-modify-write pattern is NOT atomic. The race window +is between get() and set(). We insert time.sleep(0) to yield the GIL and +force context switches, reliably reproducing the lost-update problem. +""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from graflow.channels.memory_channel import MemoryChannel + + +class TestRaceConditionReproduction: + """Prove that naive get/set is unsafe under concurrency.""" + + def test_concurrent_get_set_loses_updates(self) -> None: + """Naive get → sleep(0) → set loses updates due to interleaving.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + + num_threads = 10 + increments_per_thread = 100 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def unsafe_increment() -> None: + barrier.wait() + for _ in range(increments_per_thread): + val = channel.get("counter") + time.sleep(0) # Yield GIL → forces interleaving + channel.set("counter", val + 1) + + threads = [threading.Thread(target=unsafe_increment) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + actual = channel.get("counter") + assert actual < expected, ( + f"Expected lost updates but got exact count {actual}. " + f"Extremely unlikely with {num_threads} threads and sleep(0)." + ) + + +class TestAtomicAdd: + """Tests for the atomic ``atomic_add()`` method.""" + + def test_concurrent_add_no_lost_updates(self) -> None: + """Concurrent atomic_add() must not lose any updates.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + + num_threads = 10 + increments_per_thread = 1000 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def atomic_add() -> None: + barrier.wait() + for _ in range(increments_per_thread): + channel.atomic_add("counter") + + threads = [threading.Thread(target=atomic_add) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + actual = channel.get("counter") + assert actual == expected, f"Lost updates: {actual} != {expected}" + + def test_add_positive(self) -> None: + channel = MemoryChannel("test") + channel.set("counter", 10) + result = channel.atomic_add("counter", 5) + assert result == 15 + assert channel.get("counter") == 15 + + def test_add_negative(self) -> None: + """Negative amount works as decrement.""" + channel = MemoryChannel("test") + channel.set("counter", 10) + result = channel.atomic_add("counter", -3) + assert result == 7 + + def test_add_initializes_missing_key(self) -> None: + """Missing key is initialised to 0 before adding.""" + channel = MemoryChannel("test") + result = channel.atomic_add("new_key") + assert result == 1 + assert channel.get("new_key") == 1 + + def test_add_float(self) -> None: + channel = MemoryChannel("test") + channel.set("metric", 1.5) + result = channel.atomic_add("metric", 0.25) + assert result == pytest.approx(1.75) + + def test_add_raises_on_non_numeric(self) -> None: + channel = MemoryChannel("test") + channel.set("data", "hello") + with pytest.raises(TypeError, match="expected int or float"): + channel.atomic_add("data", 1) + + def test_set_and_atomic_add_share_same_key(self) -> None: + """set() then atomic_add() then get() must all operate on the same value.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + channel.atomic_add("counter", 5) + assert channel.get("counter") == 5 + + channel.set("counter", 100) + channel.atomic_add("counter", -10) + assert channel.get("counter") == 90 + + def test_atomic_add_then_set_overwrites(self) -> None: + """set() after atomic_add() overwrites the value.""" + channel = MemoryChannel("test") + channel.atomic_add("counter", 42) + channel.set("counter", 0) + assert channel.get("counter") == 0 + + def test_atomic_add_visible_in_keys_and_exists(self) -> None: + """Keys created by atomic_add() appear in keys() and exists().""" + channel = MemoryChannel("test") + channel.atomic_add("auto_created", 1) + assert channel.exists("auto_created") + assert "auto_created" in channel.keys() + + +class TestAdvisoryLock: + """Tests for the explicit ``lock()`` context manager.""" + + def test_lock_prevents_lost_updates(self) -> None: + """Compound get-modify-set inside lock() must not lose updates.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + + num_threads = 10 + increments_per_thread = 100 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def safe_increment() -> None: + barrier.wait() + for _ in range(increments_per_thread): + with channel.lock("counter"): + val = channel.get("counter") + time.sleep(0) # same yield as the race test + channel.set("counter", val + 1) + + threads = [threading.Thread(target=safe_increment) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + actual = channel.get("counter") + assert actual == expected, f"Lost updates even with lock: {actual} != {expected}" + + def test_lock_is_reentrant(self) -> None: + """Same thread can acquire the same lock twice (RLock).""" + channel = MemoryChannel("test") + channel.set("x", 0) + with channel.lock("x"): + with channel.lock("x"): # must not deadlock + channel.atomic_add("x", 1) + assert channel.get("x") == 1 + + def test_lock_timeout_raises(self) -> None: + """If lock is held by another thread, timeout triggers TimeoutError.""" + channel = MemoryChannel("test") + held = threading.Event() + done = threading.Event() + + def holder() -> None: + with channel.lock("key"): + held.set() + done.wait(timeout=5) + + t = threading.Thread(target=holder) + t.start() + assert held.wait(timeout=5), "holder thread failed to acquire lock" + + with pytest.raises(TimeoutError): + with channel.lock("key", timeout=0.1): + pass # pragma: no cover + + done.set() + t.join() + + def test_lock_different_keys_independent(self) -> None: + """Locks on different keys do not block each other.""" + channel = MemoryChannel("test") + order: list[str] = [] + barrier = threading.Barrier(2) + + def lock_a() -> None: + barrier.wait() + with channel.lock("a"): + order.append("a-acquired") + time.sleep(0.05) + + def lock_b() -> None: + barrier.wait() + with channel.lock("b"): + order.append("b-acquired") + time.sleep(0.05) + + ta = threading.Thread(target=lock_a) + tb = threading.Thread(target=lock_b) + ta.start() + tb.start() + ta.join() + tb.join() + + assert "a-acquired" in order + assert "b-acquired" in order + + def test_lock_compound_counter_with_overflow(self) -> None: + """Concurrent counter with threshold-based reset using lock(). + + Each thread increments a counter; when it reaches a threshold the + counter is reset to 0 and an overflow counter is bumped. + Without the lock, the conditional reset is racy. + """ + channel = MemoryChannel("test") + channel.set("counter", 0) + channel.set("overflow_count", 0) + + threshold = 10 + num_threads = 5 + increments_per_thread = 100 + total_increments = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def worker() -> None: + barrier.wait() + for _ in range(increments_per_thread): + with channel.lock("counter"): + val = channel.get("counter") + if val >= threshold: + channel.set("counter", 0) + channel.atomic_add("overflow_count", 1) + else: + channel.set("counter", val + 1) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + counter = channel.get("counter") + overflows = channel.get("overflow_count") + + # Invariant: every increment either bumped the counter or triggered a reset + # total_increments = overflows * (threshold + 1) + counter + assert overflows * (threshold + 1) + counter == total_increments, ( + f"Inconsistent state: overflows={overflows}, counter={counter}, expected total={total_increments}" + ) + assert 0 <= counter <= threshold + + def test_lock_multi_key_update(self) -> None: + """Demonstrate lock() protecting a multi-key read-modify-write.""" + channel = MemoryChannel("test") + channel.set("balance_a", 1000) + channel.set("balance_b", 1000) + + num_threads = 10 + transfers_per_thread = 50 + + barrier = threading.Barrier(num_threads) + + def transfer() -> None: + barrier.wait() + for _ in range(transfers_per_thread): + # Lock on a logical "transfer" key to serialize transfers + with channel.lock("transfer"): + a = channel.get("balance_a") + b = channel.get("balance_b") + time.sleep(0) # yield to stress-test + # Move 1 unit from a to b + channel.set("balance_a", a - 1) + channel.set("balance_b", b + 1) + + threads = [threading.Thread(target=transfer) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + a = channel.get("balance_a") + b = channel.get("balance_b") + total_transfers = num_threads * transfers_per_thread + + # Conservation: total should always be 2000 + assert a + b == 2000, f"Balance inconsistency: a={a}, b={b}, sum={a + b}" + assert a == 1000 - total_transfers + assert b == 1000 + total_transfers + + +class TestBaseChannelLockNoop: + """Verify that the base Channel.lock() default is a no-op.""" + + def test_noop_lock(self) -> None: + """Any subclass that doesn't override lock() gets a no-op.""" + from graflow.channels.base import Channel + + # Create a minimal concrete subclass that does NOT override lock() + class MinimalChannel(Channel): + def set(self, key, value, ttl=None): # type: ignore[override] + pass + + def get(self, key, default=None): # type: ignore[override] + pass + + def delete(self, key): # type: ignore[override] + return False + + def exists(self, key): # type: ignore[override] + return False + + def keys(self): # type: ignore[override] + return [] + + def clear(self): + pass + + def append(self, key, value, ttl=None): # type: ignore[override] + return 0 + + def prepend(self, key, value, ttl=None): # type: ignore[override] + return 0 + + def atomic_add(self, key, amount=1): # type: ignore[override] + return 0 + + ch = MinimalChannel("noop") + # Should not raise, just pass through + with ch.lock("anything"): + pass + + +class TestSerialization: + """Verify MemoryChannel survives pickle round-trip (checkpoint/resume).""" + + def test_pickle_round_trip_preserves_data(self) -> None: + """Data and TTL survive pickle; locks are recreated.""" + import pickle + + channel = MemoryChannel("test") + channel.set("counter", 42) + channel.set("config", {"batch": 100}) + + # Force lock creation so it exists in __dict__ + channel.atomic_add("counter", 1) + + restored: MemoryChannel = pickle.loads(pickle.dumps(channel)) + + assert restored.get("counter") == 43 + assert restored.get("config") == {"batch": 100} + assert restored.name == "test" + + def test_pickle_round_trip_locks_functional(self) -> None: + """Locks work correctly after deserialization.""" + import pickle + + channel = MemoryChannel("test") + channel.set("x", 0) + + restored: MemoryChannel = pickle.loads(pickle.dumps(channel)) + + # atomic_add must work (uses per-key lock internally) + restored.atomic_add("x", 5) + assert restored.get("x") == 5 + + # Advisory lock must work + with restored.lock("x"): + val = restored.get("x") + restored.set("x", val + 1) + assert restored.get("x") == 6 + + def test_cloudpickle_round_trip(self) -> None: + """Checkpoint uses cloudpickle — verify it works too.""" + import cloudpickle + + channel = MemoryChannel("test") + channel.set("data", [1, 2, 3]) + channel.atomic_add("counter", 10) + + restored: MemoryChannel = cloudpickle.loads(cloudpickle.dumps(channel)) + + assert restored.get("data") == [1, 2, 3] + assert restored.get("counter") == 10 + restored.atomic_add("counter", 1) + assert restored.get("counter") == 11 + + def test_pickle_round_trip_concurrent_after_restore(self) -> None: + """Restored channel handles concurrent access correctly.""" + import pickle + + channel = MemoryChannel("test") + channel.set("counter", 0) + + restored: MemoryChannel = pickle.loads(pickle.dumps(channel)) + + num_threads = 5 + increments_per_thread = 200 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def worker() -> None: + barrier.wait() + for _ in range(increments_per_thread): + restored.atomic_add("counter") + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert restored.get("counter") == expected