From 908fa590e944e2802231ca45563d96ed9751f780 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:09:14 +0900 Subject: [PATCH 01/21] feat: add atomic add method and advisory lock to Channel interface --- graflow/channels/base.py | 53 ++++++++++++++++++++++++++++++++++++++- graflow/channels/typed.py | 13 +++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/graflow/channels/base.py b/graflow/channels/base.py index 27dfabd..44d4304 100644 --- a/graflow/channels/base.py +++ b/graflow/channels/base.py @@ -3,7 +3,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Optional +from contextlib import contextmanager +from typing import Any, Iterator, List, Optional, Union class Channel(ABC): @@ -69,3 +70,53 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: Length of the list after prepend """ pass + + @abstractmethod + def add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* to the numeric value stored at *key*. + + If *key* does not exist, it is initialised to 0 before the addition. + Negative *amount* is allowed (decrement). + + Args: + key: The key identifying the numeric value. + amount: The value to add (default 1). May be negative. + + Returns: + The value after the addition. + + Raises: + TypeError: If the existing value is not numeric. + """ + pass + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire an advisory lock scoped to *key* for compound operations. + + Usage:: + + with channel.lock("counter"): + val = channel.get("counter") + channel.set("counter", val * 2 if val > 0 else 0) + + The lock is *advisory* — regular ``get``/``set`` calls do **not** + acquire it automatically. It exists for task authors who need to + protect read-modify-write sequences that cannot be expressed with + ``add``. 
+ + The default implementation is a **no-op** (yields immediately). + This is appropriate for backends where atomicity is guaranteed by + the server (e.g. Redis commands are serialised server-side). + Subclasses that need client-side locking (e.g. ``MemoryChannel`` + under multi-threading) should override this method. + + Args: + key: Logical key to lock on (does not need to correspond to a + stored key). + timeout: Maximum seconds to wait for the lock. + + Yields: + None — the lock is held for the duration of the ``with`` block. + """ + yield diff --git a/graflow/channels/typed.py b/graflow/channels/typed.py index 679e482..60c3216 100644 --- a/graflow/channels/typed.py +++ b/graflow/channels/typed.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, ClassVar, Generic, Type, TypeVar, get_type_hints +from contextlib import contextmanager +from typing import Any, ClassVar, Generic, Iterator, Type, TypeVar, Union, get_type_hints from graflow.channels.base import Channel from graflow.exceptions import ConfigError @@ -148,6 +149,16 @@ def prepend(self, key: str, value: Any, ttl: int | None = None) -> int: """ return self._channel.prepend(key, value, ttl) + def add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* to the numeric value stored at *key*.""" + return self._channel.add(key, amount) + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire advisory lock — delegates to underlying channel.""" + with self._channel.lock(key, timeout): + yield + def send(self, key: str, message: T, ttl: int | None = None) -> None: """Send a typed message. 
From 3569c8a6c9d9068feb08358ed580d1be22719525 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:12:46 +0900 Subject: [PATCH 02/21] feat: rename add method to atomic_add for clarity in Channel interface --- graflow/channels/base.py | 2 +- graflow/channels/typed.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/graflow/channels/base.py b/graflow/channels/base.py index 44d4304..c34d58a 100644 --- a/graflow/channels/base.py +++ b/graflow/channels/base.py @@ -72,7 +72,7 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: pass @abstractmethod - def add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: """Atomically add *amount* to the numeric value stored at *key*. If *key* does not exist, it is initialised to 0 before the addition. diff --git a/graflow/channels/typed.py b/graflow/channels/typed.py index 60c3216..e0c3727 100644 --- a/graflow/channels/typed.py +++ b/graflow/channels/typed.py @@ -149,9 +149,9 @@ def prepend(self, key: str, value: Any, ttl: int | None = None) -> int: """ return self._channel.prepend(key, value, ttl) - def add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: """Atomically add *amount* to the numeric value stored at *key*.""" - return self._channel.add(key, amount) + return self._channel.atomic_add(key, amount) @contextmanager def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: From 9c43073a72aa38b3e744f78ab5ea7481f895f29b Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:32:53 +0900 Subject: [PATCH 03/21] feat: implement atomic_add method and advisory locking in MemoryChannel --- graflow/channels/memory_channel.py | 75 +++++++++++++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git 
a/graflow/channels/memory_channel.py b/graflow/channels/memory_channel.py index 225ff86..b894136 100644 --- a/graflow/channels/memory_channel.py +++ b/graflow/channels/memory_channel.py @@ -2,8 +2,10 @@ from __future__ import annotations +import threading import time -from typing import Any, Dict, List, Optional +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Union from graflow.channels.base import Channel @@ -16,6 +18,21 @@ def __init__(self, name: str, **kwargs): super().__init__(name) self.data: Dict[str, Any] = {} self.ttl_data: Dict[str, float] = {} + self._key_locks: Dict[str, threading.RLock] = {} + self._key_locks_guard = threading.Lock() # protects _key_locks dict itself + + def __getstate__(self) -> Dict[str, Any]: + """Exclude unpicklable lock objects during serialization.""" + state = self.__dict__.copy() + del state["_key_locks"] + del state["_key_locks_guard"] + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Recreate lock objects after deserialization.""" + self.__dict__.update(state) # type: ignore[assignment] + self._key_locks = {} + self._key_locks_guard = threading.Lock() def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: """Store data in the channel.""" @@ -122,3 +139,59 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: self.ttl_data[key] = time.time() + ttl return len(self.data[key]) + + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* to the numeric value stored at *key*. + + Thread-safe: acquires a per-key lock so concurrent calls do not lose + updates. + + Args: + key: The key identifying the numeric value. + amount: The value to add (default 1). May be negative. + + Returns: + The value after the addition. + + Raises: + TypeError: If the existing value is not numeric. 
+ """ + with self._get_key_lock(key): + self._cleanup_expired(key) + current = self.data.get(key, 0) + if not isinstance(current, int | float): + raise TypeError(f"Key '{key}' holds {type(current).__name__}, expected int or float") + new_value = current + amount + self.data[key] = new_value + return new_value + + # -- advisory locking for compound operations -- + + def _get_key_lock(self, key: str) -> threading.RLock: + """Return (or create) the RLock associated with *key*.""" + if key not in self._key_locks: + with self._key_locks_guard: + # Double-checked locking + if key not in self._key_locks: + self._key_locks[key] = threading.RLock() + return self._key_locks[key] + + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire an advisory per-key lock for compound read-modify-write. + + Args: + key: Logical key to lock on. + timeout: Maximum seconds to wait for the lock. + + Raises: + TimeoutError: If the lock cannot be acquired within *timeout*. 
+ """ + rlock = self._get_key_lock(key) + acquired = rlock.acquire(timeout=timeout) + if not acquired: + raise TimeoutError(f"Could not acquire lock for key '{key}' within {timeout}s") + try: + yield + finally: + rlock.release() From 7026671985bdb2bfbbc2e1933d57fd1f283f95ef Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:33:07 +0900 Subject: [PATCH 04/21] feat: add atomic_add method for atomic arithmetic in RedisChannel --- graflow/channels/redis_channel.py | 33 ++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/graflow/channels/redis_channel.py b/graflow/channels/redis_channel.py index 508ce7d..9437d46 100644 --- a/graflow/channels/redis_channel.py +++ b/graflow/channels/redis_channel.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import Any, List, Optional, cast +from typing import Any, List, Optional, Union, cast from graflow.channels.base import Channel @@ -217,6 +217,37 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: return cast(int, length) + # Redis INCRBYFLOAT uses a separate key namespace to avoid JSON + # serialisation overhead and to leverage native atomic arithmetic. + _COUNTER_PREFIX = "counter:" + + def _get_counter_key(self, key: str) -> str: + """Return the Redis key used for numeric counters.""" + return f"{self.key_prefix}{self._COUNTER_PREFIX}{key}" + + def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: + """Atomically add *amount* using Redis INCRBYFLOAT. + + This is a server-side atomic operation — no client-side locking needed. + Counter values live in a dedicated key namespace so they don't collide + with JSON-serialised values stored via ``set()``. + + Args: + key: The key identifying the numeric value. + amount: The value to add (default 1). May be negative. + + Returns: + The value after the addition. 
+ """ + counter_key = self._get_counter_key(key) + # INCRBYFLOAT returns the new value as a string; auto-creates key at 0. + result = self.redis_client.incrbyfloat(counter_key, amount) + new_value = float(cast(str, result)) + # Return int when possible for ergonomics + if new_value == int(new_value): + return int(new_value) + return new_value + def __getstate__(self): """Support for pickle serialization.""" state = self.__dict__.copy() From 012662c2c30ed494290ecac7af1091724c2794d7 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:36:20 +0900 Subject: [PATCH 05/21] test: add thread safety tests for MemoryChannel and atomic_add method --- .../test_memory_channel_thread_safety.py | 415 ++++++++++++++++++ 1 file changed, 415 insertions(+) create mode 100644 tests/channels/test_memory_channel_thread_safety.py diff --git a/tests/channels/test_memory_channel_thread_safety.py b/tests/channels/test_memory_channel_thread_safety.py new file mode 100644 index 0000000..66f1bd4 --- /dev/null +++ b/tests/channels/test_memory_channel_thread_safety.py @@ -0,0 +1,415 @@ +"""Tests for MemoryChannel thread safety and race conditions. + +Demonstrates that concurrent read-modify-write on MemoryChannel +without locks leads to lost updates (race condition), and that +the ``atomic_add()`` / ``lock()`` APIs solve the problem. + +CPython's GIL makes individual dict operations (get/set) effectively atomic, +but the *compound* read-modify-write pattern is NOT atomic. The race window +is between get() and set(). We insert time.sleep(0) to yield the GIL and +force context switches, reliably reproducing the lost-update problem. 
+""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from graflow.channels.memory_channel import MemoryChannel + + +class TestRaceConditionReproduction: + """Prove that naive get/set is unsafe under concurrency.""" + + def test_concurrent_get_set_loses_updates(self) -> None: + """Naive get → sleep(0) → set loses updates due to interleaving.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + + num_threads = 10 + increments_per_thread = 100 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def unsafe_increment() -> None: + barrier.wait() + for _ in range(increments_per_thread): + val = channel.get("counter") + time.sleep(0) # Yield GIL → forces interleaving + channel.set("counter", val + 1) + + threads = [threading.Thread(target=unsafe_increment) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + actual = channel.get("counter") + assert actual < expected, ( + f"Expected lost updates but got exact count {actual}. " + f"Extremely unlikely with {num_threads} threads and sleep(0)." 
+ ) + + +class TestAtomicAdd: + """Tests for the atomic ``atomic_add()`` method.""" + + def test_concurrent_add_no_lost_updates(self) -> None: + """Concurrent atomic_add() must not lose any updates.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + + num_threads = 10 + increments_per_thread = 1000 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def atomic_add() -> None: + barrier.wait() + for _ in range(increments_per_thread): + channel.atomic_add("counter") + + threads = [threading.Thread(target=atomic_add) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + actual = channel.get("counter") + assert actual == expected, f"Lost updates: {actual} != {expected}" + + def test_add_positive(self) -> None: + channel = MemoryChannel("test") + channel.set("counter", 10) + result = channel.atomic_add("counter", 5) + assert result == 15 + assert channel.get("counter") == 15 + + def test_add_negative(self) -> None: + """Negative amount works as decrement.""" + channel = MemoryChannel("test") + channel.set("counter", 10) + result = channel.atomic_add("counter", -3) + assert result == 7 + + def test_add_initializes_missing_key(self) -> None: + """Missing key is initialised to 0 before adding.""" + channel = MemoryChannel("test") + result = channel.atomic_add("new_key") + assert result == 1 + assert channel.get("new_key") == 1 + + def test_add_float(self) -> None: + channel = MemoryChannel("test") + channel.set("metric", 1.5) + result = channel.atomic_add("metric", 0.25) + assert result == pytest.approx(1.75) + + def test_add_raises_on_non_numeric(self) -> None: + channel = MemoryChannel("test") + channel.set("data", "hello") + with pytest.raises(TypeError, match="expected int or float"): + channel.atomic_add("data", 1) + + +class TestAdvisoryLock: + """Tests for the explicit ``lock()`` context manager.""" + + def test_lock_prevents_lost_updates(self) -> None: + 
"""Compound get-modify-set inside lock() must not lose updates.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + + num_threads = 10 + increments_per_thread = 100 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def safe_increment() -> None: + barrier.wait() + for _ in range(increments_per_thread): + with channel.lock("counter"): + val = channel.get("counter") + time.sleep(0) # same yield as the race test + channel.set("counter", val + 1) + + threads = [threading.Thread(target=safe_increment) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + actual = channel.get("counter") + assert actual == expected, f"Lost updates even with lock: {actual} != {expected}" + + def test_lock_is_reentrant(self) -> None: + """Same thread can acquire the same lock twice (RLock).""" + channel = MemoryChannel("test") + channel.set("x", 0) + with channel.lock("x"): + with channel.lock("x"): # must not deadlock + channel.atomic_add("x", 1) + assert channel.get("x") == 1 + + def test_lock_timeout_raises(self) -> None: + """If lock is held by another thread, timeout triggers TimeoutError.""" + channel = MemoryChannel("test") + held = threading.Event() + done = threading.Event() + + def holder() -> None: + with channel.lock("key"): + held.set() + done.wait(timeout=5) + + t = threading.Thread(target=holder) + t.start() + held.wait() + + with pytest.raises(TimeoutError): + with channel.lock("key", timeout=0.1): + pass # pragma: no cover + + done.set() + t.join() + + def test_lock_different_keys_independent(self) -> None: + """Locks on different keys do not block each other.""" + channel = MemoryChannel("test") + order: list[str] = [] + barrier = threading.Barrier(2) + + def lock_a() -> None: + barrier.wait() + with channel.lock("a"): + order.append("a-acquired") + time.sleep(0.05) + + def lock_b() -> None: + barrier.wait() + with channel.lock("b"): + order.append("b-acquired") + 
time.sleep(0.05) + + ta = threading.Thread(target=lock_a) + tb = threading.Thread(target=lock_b) + ta.start() + tb.start() + ta.join() + tb.join() + + assert "a-acquired" in order + assert "b-acquired" in order + + def test_lock_compound_counter_with_overflow(self) -> None: + """Concurrent counter with threshold-based reset using lock(). + + Each thread increments a counter; when it reaches a threshold the + counter is reset to 0 and an overflow counter is bumped. + Without the lock, the conditional reset is racy. + """ + channel = MemoryChannel("test") + channel.set("counter", 0) + channel.set("overflow_count", 0) + + threshold = 10 + num_threads = 5 + increments_per_thread = 100 + total_increments = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def worker() -> None: + barrier.wait() + for _ in range(increments_per_thread): + with channel.lock("counter"): + val = channel.get("counter") + if val >= threshold: + channel.set("counter", 0) + channel.atomic_add("overflow_count", 1) + else: + channel.set("counter", val + 1) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + counter = channel.get("counter") + overflows = channel.get("overflow_count") + + # Invariant: every increment either bumped the counter or triggered a reset + # total_increments = overflows * (threshold + 1) + counter + assert overflows * (threshold + 1) + counter == total_increments, ( + f"Inconsistent state: overflows={overflows}, counter={counter}, expected total={total_increments}" + ) + assert 0 <= counter <= threshold + + def test_lock_multi_key_update(self) -> None: + """Demonstrate lock() protecting a multi-key read-modify-write.""" + channel = MemoryChannel("test") + channel.set("balance_a", 1000) + channel.set("balance_b", 1000) + + num_threads = 10 + transfers_per_thread = 50 + + barrier = threading.Barrier(num_threads) + + def transfer() -> None: + 
barrier.wait() + for _ in range(transfers_per_thread): + # Lock on a logical "transfer" key to serialize transfers + with channel.lock("transfer"): + a = channel.get("balance_a") + b = channel.get("balance_b") + time.sleep(0) # yield to stress-test + # Move 1 unit from a to b + channel.set("balance_a", a - 1) + channel.set("balance_b", b + 1) + + threads = [threading.Thread(target=transfer) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + a = channel.get("balance_a") + b = channel.get("balance_b") + total_transfers = num_threads * transfers_per_thread + + # Conservation: total should always be 2000 + assert a + b == 2000, f"Balance inconsistency: a={a}, b={b}, sum={a + b}" + assert a == 1000 - total_transfers + assert b == 1000 + total_transfers + + +class TestBaseChannelLockNoop: + """Verify that the base Channel.lock() default is a no-op.""" + + def test_noop_lock(self) -> None: + """RedisChannel (or any subclass) that doesn't override lock() gets a no-op.""" + from graflow.channels.base import Channel + + # Create a minimal concrete subclass that does NOT override lock() + class MinimalChannel(Channel): + def set(self, key, value, ttl=None): # type: ignore[override] + pass + + def get(self, key, default=None): # type: ignore[override] + pass + + def delete(self, key): # type: ignore[override] + return False + + def exists(self, key): # type: ignore[override] + return False + + def keys(self): # type: ignore[override] + return [] + + def clear(self): + pass + + def append(self, key, value, ttl=None): # type: ignore[override] + return 0 + + def prepend(self, key, value, ttl=None): # type: ignore[override] + return 0 + + def atomic_add(self, key, amount=1): # type: ignore[override] + return 0 + + ch = MinimalChannel("noop") + # Should not raise, just pass through + with ch.lock("anything"): + pass + + +class TestSerialization: + """Verify MemoryChannel survives pickle round-trip (checkpoint/resume).""" + + def 
test_pickle_round_trip_preserves_data(self) -> None: + """Data and TTL survive pickle; locks are recreated.""" + import pickle + + channel = MemoryChannel("test") + channel.set("counter", 42) + channel.set("config", {"batch": 100}) + + # Force lock creation so it exists in __dict__ + channel.atomic_add("counter", 1) + + restored: MemoryChannel = pickle.loads(pickle.dumps(channel)) + + assert restored.get("counter") == 43 + assert restored.get("config") == {"batch": 100} + assert restored.name == "test" + + def test_pickle_round_trip_locks_functional(self) -> None: + """Locks work correctly after deserialization.""" + import pickle + + channel = MemoryChannel("test") + channel.set("x", 0) + + restored: MemoryChannel = pickle.loads(pickle.dumps(channel)) + + # atomic_add must work (uses per-key lock internally) + restored.atomic_add("x", 5) + assert restored.get("x") == 5 + + # Advisory lock must work + with restored.lock("x"): + val = restored.get("x") + restored.set("x", val + 1) + assert restored.get("x") == 6 + + def test_cloudpickle_round_trip(self) -> None: + """Checkpoint uses cloudpickle — verify it works too.""" + import cloudpickle + + channel = MemoryChannel("test") + channel.set("data", [1, 2, 3]) + channel.atomic_add("counter", 10) + + restored: MemoryChannel = cloudpickle.loads(cloudpickle.dumps(channel)) + + assert restored.get("data") == [1, 2, 3] + assert restored.get("counter") == 10 + restored.atomic_add("counter", 1) + assert restored.get("counter") == 11 + + def test_pickle_round_trip_concurrent_after_restore(self) -> None: + """Restored channel handles concurrent access correctly.""" + import pickle + + channel = MemoryChannel("test") + channel.set("counter", 0) + + restored: MemoryChannel = pickle.loads(pickle.dumps(channel)) + + num_threads = 5 + increments_per_thread = 200 + expected = num_threads * increments_per_thread + + barrier = threading.Barrier(num_threads) + + def worker() -> None: + barrier.wait() + for _ in 
range(increments_per_thread): + restored.atomic_add("counter") + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert restored.get("counter") == expected From 850792cc7dc8e017e4220198a2dce537df077ab3 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:39:44 +0900 Subject: [PATCH 06/21] docs: update README and add channel concurrency example for thread-safe operations --- examples/03_data_flow/README.md | 61 +++++- examples/03_data_flow/channel_concurrency.py | 207 +++++++++++++++++++ 2 files changed, 260 insertions(+), 8 deletions(-) create mode 100644 examples/03_data_flow/channel_concurrency.py diff --git a/examples/03_data_flow/README.md b/examples/03_data_flow/README.md index b092adc..4008fd7 100644 --- a/examples/03_data_flow/README.md +++ b/examples/03_data_flow/README.md @@ -11,7 +11,8 @@ This section demonstrates **data flow and inter-task communication** in Graflow ## What You'll Learn - 📡 Using channels for inter-task communication -- 🔒 Type-safe channels with TypedDict +- 🔒 Thread-safe channel operations (`atomic_add`, `lock`) +- 🏷️ Type-safe channels with TypedDict - 💾 Storing and retrieving task results - 🔄 Data flow patterns in workflows - 📊 Sharing state across task boundaries @@ -37,7 +38,25 @@ uv run python examples/03_data_flow/channels_basic.py --- -### 2. typed_channels.py +### 2. channel_concurrency.py + +**Concept**: Thread-safe channel operations + +Learn how to safely update shared channel data when tasks run in parallel using `ParallelGroup`. + +```bash +uv run python examples/03_data_flow/channel_concurrency.py +``` + +**Key Concepts**: +- Race conditions with naive `get`/`set` under concurrency +- Atomic counter updates with `channel.atomic_add()` +- Advisory locking with `channel.lock()` for compound operations +- Threshold-based reset pattern with multi-key updates + +--- + +### 3. 
typed_channels.py **Concept**: Type-safe channels @@ -56,7 +75,7 @@ uv run python examples/03_data_flow/typed_channels.py --- -### 3. results_storage.py +### 4. results_storage.py **Concept**: Task results and dependency data @@ -107,7 +126,32 @@ def process_batch(ctx: TaskExecutionContext): channel.set("total_processed", total) ``` -### Pattern 3: Type-Safe Messages +### Pattern 3: Thread-Safe Counter + +```python +@task(inject_context=True) +def count_items(ctx: TaskExecutionContext): + channel = ctx.get_channel() + # Atomic — safe for parallel tasks + channel.atomic_add("processed_count", 1) +``` + +### Pattern 4: Compound Update with Lock + +```python +@task(inject_context=True) +def check_and_reset(ctx: TaskExecutionContext): + channel = ctx.get_channel() + with channel.lock("counter"): + val = channel.get("counter") + if val >= threshold: + channel.set("counter", 0) + channel.atomic_add("overflow_count", 1) + else: + channel.set("counter", val + 1) +``` + +### Pattern 5: Type-Safe Messages ```python from typing import TypedDict @@ -129,7 +173,7 @@ def collect_metrics(ctx: TaskExecutionContext): typed_channel.set("metrics", metrics) ``` -### Pattern 4: Result Dependencies +### Pattern 6: Result Dependencies ```python with workflow("pipeline") as ctx: @@ -310,9 +354,8 @@ def use_config(ctx: TaskExecutionContext): @task(inject_context=True) def track_metrics(ctx: TaskExecutionContext): channel = ctx.get_channel() - metrics = channel.get("metrics", {}) - metrics["processed"] = metrics.get("processed", 0) + 1 - channel.set("metrics", metrics) + # Thread-safe counter — works in parallel tasks + channel.atomic_add("processed", 1) ``` ### Error Accumulation @@ -374,6 +417,8 @@ After mastering data flow: - `channel.set(key, value)` - Store value - `channel.get(key, default=None)` - Retrieve value - `channel.keys()` - List all keys +- `channel.atomic_add(key, amount=1)` - Atomic numeric add/subtract (thread-safe) +- `channel.lock(key, timeout=10.0)` - Advisory lock for 
compound operations **TypedChannel**: - `typed_channel = ctx.get_typed_channel(SchemaClass)` - Create typed channel diff --git a/examples/03_data_flow/channel_concurrency.py b/examples/03_data_flow/channel_concurrency.py new file mode 100644 index 0000000..7d45016 --- /dev/null +++ b/examples/03_data_flow/channel_concurrency.py @@ -0,0 +1,207 @@ +""" +Channel Concurrency Example +============================ + +This example demonstrates how to safely share and update channel data +when tasks run in parallel using ``ParallelGroup``. + +Problem +------- +A naive read-modify-write (``get`` -> compute -> ``set``) is **not** atomic. +When multiple tasks execute in parallel threads, updates can be lost because +two tasks may read the same value and overwrite each other's writes. + +Solutions +--------- +1. **``channel.atomic_add(key, amount)``** — Atomic numeric add (inc/dec). + Backed by a per-key lock in MemoryChannel and ``INCRBYFLOAT`` in Redis. + +2. **``channel.lock(key)``** — Advisory lock for arbitrary compound + operations. No-op for Redis (server-side serialisation). + +Expected Output +--------------- +=== Channel Concurrency Demo === + +--- Unsafe parallel increment (race condition) --- +Expected counter: 500, Actual: (typically less than 500) +Updates lost! + +--- Safe parallel increment with atomic_add() --- +Expected counter: 500, Actual: 500 + +--- Safe compound update with lock() --- +Overflow events: 4 +Counter after resets: 6 + +Done! 
+""" + +from graflow.core.context import TaskExecutionContext +from graflow.core.decorators import task +from graflow.core.task import ParallelGroup +from graflow.core.workflow import workflow + + +def demo_unsafe_increment() -> None: + """Show that naive get/set loses updates in parallel execution.""" + import time + + print("--- Unsafe parallel increment (race condition) ---") + + num_workers = 5 + increments_per_worker = 100 + expected = num_workers * increments_per_worker + + with workflow("unsafe_demo") as ctx: + + @task(inject_context=True) + def init_counter(context: TaskExecutionContext): + context.get_channel().set("counter", 0) + + # Create worker tasks that use naive get/set + workers = [] + for i in range(num_workers): + + @task(inject_context=True, id=f"unsafe_worker_{i}") + def unsafe_worker(context: TaskExecutionContext): + channel = context.get_channel() + for _ in range(increments_per_worker): + val = channel.get("counter") + time.sleep(0) # yield to trigger interleaving + channel.set("counter", val + 1) + + workers.append(unsafe_worker) + + @task(inject_context=True) + def report(context: TaskExecutionContext): + actual = context.get_channel().get("counter") + print(f" Expected counter: {expected}, Actual: {actual}") + if actual < expected: + print(" Updates lost!\n") + else: + print(" (Got lucky — no interleaving this run)\n") + + parallel = ParallelGroup(workers, name="unsafe_group") + _ = init_counter >> parallel >> report + ctx.execute("init_counter") + + +def demo_atomic_add() -> None: + """Show that atomic_add() is safe for parallel numeric updates.""" + print("--- Safe parallel increment with atomic_add() ---") + + num_workers = 5 + increments_per_worker = 100 + expected = num_workers * increments_per_worker + + with workflow("add_demo") as ctx: + + @task(inject_context=True) + def init_counter(context: TaskExecutionContext): + context.get_channel().set("counter", 0) + + workers = [] + for i in range(num_workers): + + 
@task(inject_context=True, id=f"add_worker_{i}") + def add_worker(context: TaskExecutionContext): + channel = context.get_channel() + for _ in range(increments_per_worker): + channel.atomic_add("counter", 1) + + workers.append(add_worker) + + @task(inject_context=True) + def report(context: TaskExecutionContext): + actual = context.get_channel().get("counter") + print(f" Expected counter: {expected}, Actual: {actual}\n") + + parallel = ParallelGroup(workers, name="add_group") + _ = init_counter >> parallel >> report + ctx.execute("init_counter") + + +def demo_advisory_lock() -> None: + """Show lock() for compound read-modify-write that atomic_add() can't express.""" + print("--- Safe compound update with lock() ---") + + threshold = 10 + num_workers = 5 + increments_per_worker = 10 + + with workflow("lock_demo") as ctx: + + @task(inject_context=True) + def init(context: TaskExecutionContext): + channel = context.get_channel() + channel.set("counter", 0) + channel.set("overflow_count", 0) + + workers = [] + for i in range(num_workers): + + @task(inject_context=True, id=f"lock_worker_{i}") + def lock_worker(context: TaskExecutionContext): + channel = context.get_channel() + for _ in range(increments_per_worker): + # Advisory lock protects the entire read-modify-write block + with channel.lock("counter"): + val = channel.get("counter") + if val >= threshold: + channel.set("counter", 0) + channel.atomic_add("overflow_count", 1) + else: + channel.set("counter", val + 1) + + workers.append(lock_worker) + + @task(inject_context=True) + def report(context: TaskExecutionContext): + channel = context.get_channel() + overflows = channel.get("overflow_count") + counter = channel.get("counter") + print(f" Overflow events: {overflows}") + print(f" Counter after resets: {counter}\n") + + parallel = ParallelGroup(workers, name="lock_group") + _ = init >> parallel >> report + ctx.execute("init") + + +def main(): + print("=== Channel Concurrency Demo ===\n") + 
demo_unsafe_increment() + demo_atomic_add() + demo_advisory_lock() + print("Done!") + + +if __name__ == "__main__": + main() + + +# ============================================================================ +# Key Takeaways: +# ============================================================================ +# +# 1. **channel.atomic_add(key, amount)** +# - Atomic numeric add/subtract — no lost updates +# - Initialises missing keys to 0 automatically +# - MemoryChannel: per-key RLock; Redis: INCRBYFLOAT (server-side atomic) +# - Use for counters, metrics, scores +# +# 2. **channel.lock(key)** +# - Advisory lock for compound operations that atomic_add() can't express +# - Wrap with ``with channel.lock(key):`` context manager +# - MemoryChannel: per-key RLock; Redis: no-op (server-side serialisation) +# - Use for conditional updates, multi-key transactions +# +# 3. **When to use which** +# - Simple counter? → channel.atomic_add("counter", 1) +# - Decrement? → channel.atomic_add("counter", -1) +# - Conditional update? → with channel.lock("key"): ... +# - Multi-key update? → with channel.lock("key"): ... +# - No concurrency concern? 
→ channel.get() / channel.set() is fine +# +# ============================================================================ From 1fdc256267685f3941c37b01ca56bb29471423d3 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:54:48 +0900 Subject: [PATCH 07/21] refactor: simplify atomic_add method by removing counter key logic and using generic key retrieval --- graflow/channels/redis_channel.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/graflow/channels/redis_channel.py b/graflow/channels/redis_channel.py index 9437d46..b619062 100644 --- a/graflow/channels/redis_channel.py +++ b/graflow/channels/redis_channel.py @@ -217,20 +217,10 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int: return cast(int, length) - # Redis INCRBYFLOAT uses a separate key namespace to avoid JSON - # serialisation overhead and to leverage native atomic arithmetic. - _COUNTER_PREFIX = "counter:" - - def _get_counter_key(self, key: str) -> str: - """Return the Redis key used for numeric counters.""" - return f"{self.key_prefix}{self._COUNTER_PREFIX}{key}" - def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]: """Atomically add *amount* using Redis INCRBYFLOAT. This is a server-side atomic operation — no client-side locking needed. - Counter values live in a dedicated key namespace so they don't collide - with JSON-serialised values stored via ``set()``. Args: key: The key identifying the numeric value. @@ -239,9 +229,9 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa Returns: The value after the addition. """ - counter_key = self._get_counter_key(key) + redis_key = self._get_key(key) # INCRBYFLOAT returns the new value as a string; auto-creates key at 0. 
- result = self.redis_client.incrbyfloat(counter_key, amount) + result = self.redis_client.incrbyfloat(redis_key, amount) new_value = float(cast(str, result)) # Return int when possible for ergonomics if new_value == int(new_value): From 2b5f8557c75b6a0bf5247a1ace56ff9a76567111 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 18:55:24 +0900 Subject: [PATCH 08/21] test: add additional tests for atomic_add method in MemoryChannel --- .../test_memory_channel_thread_safety.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/channels/test_memory_channel_thread_safety.py b/tests/channels/test_memory_channel_thread_safety.py index 66f1bd4..853f9e5 100644 --- a/tests/channels/test_memory_channel_thread_safety.py +++ b/tests/channels/test_memory_channel_thread_safety.py @@ -115,6 +115,31 @@ def test_add_raises_on_non_numeric(self) -> None: with pytest.raises(TypeError, match="expected int or float"): channel.atomic_add("data", 1) + def test_set_and_atomic_add_share_same_key(self) -> None: + """set() then atomic_add() then get() must all operate on the same value.""" + channel = MemoryChannel("test") + channel.set("counter", 0) + channel.atomic_add("counter", 5) + assert channel.get("counter") == 5 + + channel.set("counter", 100) + channel.atomic_add("counter", -10) + assert channel.get("counter") == 90 + + def test_atomic_add_then_set_overwrites(self) -> None: + """set() after atomic_add() overwrites the value.""" + channel = MemoryChannel("test") + channel.atomic_add("counter", 42) + channel.set("counter", 0) + assert channel.get("counter") == 0 + + def test_atomic_add_visible_in_keys_and_exists(self) -> None: + """Keys created by atomic_add() appear in keys() and exists().""" + channel = MemoryChannel("test") + channel.atomic_add("auto_created", 1) + assert channel.exists("auto_created") + assert "auto_created" in channel.keys() + class TestAdvisoryLock: """Tests for the explicit ``lock()`` context manager.""" From 
ee0f048527e5eab9076716e40d7e370352b5d5c3 Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Sun, 5 Apr 2026 18:59:38 +0900 Subject: [PATCH 09/21] Update graflow/channels/memory_channel.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Makoto YUI --- graflow/channels/memory_channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graflow/channels/memory_channel.py b/graflow/channels/memory_channel.py index b894136..7bdc5f1 100644 --- a/graflow/channels/memory_channel.py +++ b/graflow/channels/memory_channel.py @@ -24,8 +24,8 @@ def __init__(self, name: str, **kwargs): def __getstate__(self) -> Dict[str, Any]: """Exclude unpicklable lock objects during serialization.""" state = self.__dict__.copy() - del state["_key_locks"] - del state["_key_locks_guard"] + state.pop("_key_locks", None) + state.pop("_key_locks_guard", None) return state def __setstate__(self, state: Dict[str, Any]) -> None: From 2b841d7773b674c371ccc14d071afd7a7130f4fc Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Sun, 5 Apr 2026 19:00:16 +0900 Subject: [PATCH 10/21] Update examples/03_data_flow/channel_concurrency.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Makoto YUI --- examples/03_data_flow/channel_concurrency.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/03_data_flow/channel_concurrency.py b/examples/03_data_flow/channel_concurrency.py index 7d45016..d9ac30a 100644 --- a/examples/03_data_flow/channel_concurrency.py +++ b/examples/03_data_flow/channel_concurrency.py @@ -17,7 +17,10 @@ Backed by a per-key lock in MemoryChannel and ``INCRBYFLOAT`` in Redis. 2. **``channel.lock(key)``** — Advisory lock for arbitrary compound - operations. No-op for Redis (server-side serialisation). + operations. 
Safe for in-process coordination only when the backend + provides a real lock; do **not** rely on a no-op Redis lock for + ``get`` -> compute -> ``set`` sequences, because those commands can + still interleave across clients. Expected Output --------------- From 8352daafab6a6f2d24ee1ef776bfc727185d5058 Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Sun, 5 Apr 2026 19:01:12 +0900 Subject: [PATCH 11/21] Update graflow/channels/base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Makoto YUI --- graflow/channels/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graflow/channels/base.py b/graflow/channels/base.py index c34d58a..f776545 100644 --- a/graflow/channels/base.py +++ b/graflow/channels/base.py @@ -103,7 +103,7 @@ def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: The lock is *advisory* — regular ``get``/``set`` calls do **not** acquire it automatically. It exists for task authors who need to protect read-modify-write sequences that cannot be expressed with - ``add``. + ``atomic_add()``. The default implementation is a **no-op** (yields immediately). This is appropriate for backends where atomicity is guaranteed by From bca8dff4e1329b7efb2b06cb26c62b5464aec479 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 19:56:09 +0900 Subject: [PATCH 12/21] docs: enhance warning for lock method in Channel class to clarify no mutual exclusion --- graflow/channels/base.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/graflow/channels/base.py b/graflow/channels/base.py index f776545..8933ea2 100644 --- a/graflow/channels/base.py +++ b/graflow/channels/base.py @@ -105,11 +105,15 @@ def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: protect read-modify-write sequences that cannot be expressed with ``atomic_add()``. - The default implementation is a **no-op** (yields immediately). 
- This is appropriate for backends where atomicity is guaranteed by - the server (e.g. Redis commands are serialised server-side). - Subclasses that need client-side locking (e.g. ``MemoryChannel`` - under multi-threading) should override this method. + .. warning:: + + The default implementation is a **no-op** (yields immediately) + and provides **no mutual exclusion**. Subclasses that need + compound-operation safety **must** override this method — this + includes both in-process backends under multi-threading + (e.g. ``MemoryChannel``) and distributed backends where + multi-client read-modify-write sequences are racy + (e.g. Redis without a distributed lock). Args: key: Logical key to lock on (does not need to correspond to a From 9cdbcc7d22bfc1ec1ce7a15a5b3e798cf3bc98c9 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 20:21:42 +0900 Subject: [PATCH 13/21] docs: clarify atomic_add and lock method descriptions in README --- examples/03_data_flow/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/03_data_flow/README.md b/examples/03_data_flow/README.md index 4008fd7..c43585d 100644 --- a/examples/03_data_flow/README.md +++ b/examples/03_data_flow/README.md @@ -417,8 +417,8 @@ After mastering data flow: - `channel.set(key, value)` - Store value - `channel.get(key, default=None)` - Retrieve value - `channel.keys()` - List all keys -- `channel.atomic_add(key, amount=1)` - Atomic numeric add/subtract (thread-safe) -- `channel.lock(key, timeout=10.0)` - Advisory lock for compound operations +- `channel.atomic_add(key, amount=1)` - Atomic numeric add/subtract (thread-safe for `MemoryChannel`; single-command atomic for Redis) +- `channel.lock(key, timeout=10.0)` - Advisory per-key lock for compound read-modify-write operations (thread-safe via `threading.RLock`) **TypedChannel**: - `typed_channel = ctx.get_typed_channel(SchemaClass)` - Create typed channel From 3e541ef6c0bdff6c4b7f1a7fcc322fe44ce1d530 Mon Sep 17 
00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 23:14:06 +0900 Subject: [PATCH 14/21] feat: add distributed advisory lock method to RedisChannel --- graflow/channels/redis_channel.py | 38 ++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/graflow/channels/redis_channel.py b/graflow/channels/redis_channel.py index b619062..975f9f5 100644 --- a/graflow/channels/redis_channel.py +++ b/graflow/channels/redis_channel.py @@ -3,13 +3,16 @@ from __future__ import annotations import json -from typing import Any, List, Optional, Union, cast +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union, cast from graflow.channels.base import Channel +if TYPE_CHECKING: + from redis import Redis + try: import redis - from redis import Redis except ImportError: redis = None @@ -238,6 +241,35 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa return int(new_value) return new_value + @contextmanager + def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: + """Acquire a distributed advisory lock scoped to *key*. + + Uses ``redis.lock.Lock`` (SET NX + Lua release) under the hood. + + Args: + key: Logical key to lock on. + timeout: Maximum seconds to wait for the lock. + + Raises: + TimeoutError: If the lock cannot be acquired within *timeout*. 
+ """ + from redis.exceptions import LockNotOwnedError + from redis.lock import Lock + + lock_name = f"{self.key_prefix}lock:{key}" + lock = Lock(self.redis_client, lock_name, timeout=timeout) + acquired = lock.acquire(blocking=True, blocking_timeout=timeout) + if not acquired: + raise TimeoutError(f"Could not acquire lock for key '{key}' within {timeout}s") + try: + yield + finally: + try: + lock.release() + except LockNotOwnedError: + pass + def __getstate__(self): """Support for pickle serialization.""" state = self.__dict__.copy() @@ -247,7 +279,7 @@ def __getstate__(self): def __setstate__(self, state): """Support for pickle deserialization.""" - self.__dict__.update(state) + self.__dict__.update(state) # type: ignore # Recreate the Redis client assert redis is not None, "redis package is required for RedisChannel" self.redis_client = redis.Redis( From ccef489926ae6852128da6e3fdabf03d3dbbe8ab Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 23:14:10 +0900 Subject: [PATCH 15/21] docs: enhance lock method documentation to clarify mutual exclusion and timeout behavior --- graflow/channels/base.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/graflow/channels/base.py b/graflow/channels/base.py index 8933ea2..9fcbe3e 100644 --- a/graflow/channels/base.py +++ b/graflow/channels/base.py @@ -105,21 +105,25 @@ def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]: protect read-modify-write sequences that cannot be expressed with ``atomic_add()``. + Both ``MemoryChannel`` (threading lock) and ``RedisChannel`` + (``redis.lock.Lock`` — distributed SET NX + Lua release) provide + real mutual exclusion. + .. warning:: - The default implementation is a **no-op** (yields immediately) - and provides **no mutual exclusion**. Subclasses that need - compound-operation safety **must** override this method — this - includes both in-process backends under multi-threading - (e.g. 
``MemoryChannel``) and distributed backends where - multi-client read-modify-write sequences are racy - (e.g. Redis without a distributed lock). + The **base class** default is a **no-op** (yields immediately) + and provides **no mutual exclusion**. Custom subclasses that + need compound-operation safety **must** override this method. Args: key: Logical key to lock on (does not need to correspond to a stored key). timeout: Maximum seconds to wait for the lock. + Raises: + TimeoutError: If the lock cannot be acquired within *timeout* + (raised by ``MemoryChannel`` and ``RedisChannel``). + Yields: None — the lock is held for the duration of the ``with`` block. """ From cd999e67b3a61697ab418cf76ffc1e64471f0eaa Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 23:23:24 +0900 Subject: [PATCH 16/21] adjust formatting in __setstate__ method for consistency --- graflow/channels/redis_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graflow/channels/redis_channel.py b/graflow/channels/redis_channel.py index 975f9f5..ef8f4ac 100644 --- a/graflow/channels/redis_channel.py +++ b/graflow/channels/redis_channel.py @@ -279,7 +279,7 @@ def __getstate__(self): def __setstate__(self, state): """Support for pickle deserialization.""" - self.__dict__.update(state) # type: ignore + self.__dict__.update(state) # type: ignore # Recreate the Redis client assert redis is not None, "redis package is required for RedisChannel" self.redis_client = redis.Redis( From 445614d300d1f5c920a1fa6fc3d968593793ce2d Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Sun, 5 Apr 2026 23:35:08 +0900 Subject: [PATCH 17/21] Update tests/channels/test_memory_channel_thread_safety.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Makoto YUI --- tests/channels/test_memory_channel_thread_safety.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/channels/test_memory_channel_thread_safety.py 
b/tests/channels/test_memory_channel_thread_safety.py index 853f9e5..47620f1 100644 --- a/tests/channels/test_memory_channel_thread_safety.py +++ b/tests/channels/test_memory_channel_thread_safety.py @@ -319,7 +319,7 @@ class TestBaseChannelLockNoop: """Verify that the base Channel.lock() default is a no-op.""" def test_noop_lock(self) -> None: - """RedisChannel (or any subclass) that doesn't override lock() gets a no-op.""" + """Any subclass that doesn't override lock() gets a no-op.""" from graflow.channels.base import Channel # Create a minimal concrete subclass that does NOT override lock() From 79d4cde4ab59545d82fd9db49b5cdfb6496c6fe8 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 23:33:55 +0900 Subject: [PATCH 18/21] test: add timeout assertion for lock acquisition in TestAdvisoryLock --- tests/channels/test_memory_channel_thread_safety.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/channels/test_memory_channel_thread_safety.py b/tests/channels/test_memory_channel_thread_safety.py index 47620f1..5c37b58 100644 --- a/tests/channels/test_memory_channel_thread_safety.py +++ b/tests/channels/test_memory_channel_thread_safety.py @@ -194,7 +194,7 @@ def holder() -> None: t = threading.Thread(target=holder) t.start() - held.wait() + assert held.wait(timeout=5), "holder thread failed to acquire lock" with pytest.raises(TimeoutError): with channel.lock("key", timeout=0.1): From 6621148bd4d33c21864b1a6c0c6bc8761ed6082f Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Sun, 5 Apr 2026 23:36:26 +0900 Subject: [PATCH 19/21] docs: clarify advisory lock documentation to emphasize safe usage and coordination --- examples/03_data_flow/channel_concurrency.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/03_data_flow/channel_concurrency.py b/examples/03_data_flow/channel_concurrency.py index d9ac30a..fe708b7 100644 --- a/examples/03_data_flow/channel_concurrency.py +++ 
b/examples/03_data_flow/channel_concurrency.py @@ -17,10 +17,9 @@ Backed by a per-key lock in MemoryChannel and ``INCRBYFLOAT`` in Redis. 2. **``channel.lock(key)``** — Advisory lock for arbitrary compound - operations. Safe for in-process coordination only when the backend - provides a real lock; do **not** rely on a no-op Redis lock for - ``get`` -> compute -> ``set`` sequences, because those commands can - still interleave across clients. + operations. ``MemoryChannel`` uses a per-key ``threading.RLock`` for + in-process coordination; ``RedisChannel`` uses ``redis.lock.Lock`` + (SET NX + Lua release) for cross-client distributed coordination. Expected Output --------------- From 30eb9c18a70ab5024ec5fc7eeade6d8a86fbf116 Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Sun, 5 Apr 2026 23:38:28 +0900 Subject: [PATCH 20/21] Update examples/03_data_flow/channel_concurrency.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Makoto YUI --- examples/03_data_flow/channel_concurrency.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/03_data_flow/channel_concurrency.py b/examples/03_data_flow/channel_concurrency.py index fe708b7..29ccc73 100644 --- a/examples/03_data_flow/channel_concurrency.py +++ b/examples/03_data_flow/channel_concurrency.py @@ -196,8 +196,8 @@ def main(): # 2. **channel.lock(key)** # - Advisory lock for compound operations that atomic_add() can't express # - Wrap with ``with channel.lock(key):`` context manager -# - MemoryChannel: per-key RLock; Redis: no-op (server-side serialisation) -# - Use for conditional updates, multi-key transactions +# - MemoryChannel: per-key RLock; Redis: distributed lock for the same key +# - Use for conditional updates and other compound read-modify-write logic # # 3. **When to use which** # - Simple counter? 
→ channel.atomic_add("counter", 1) From d7dee2fbd97daf57931794e02b7463b336992d36 Mon Sep 17 00:00:00 2001 From: Makoto YUI Date: Sun, 5 Apr 2026 23:38:47 +0900 Subject: [PATCH 21/21] Update graflow/channels/memory_channel.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Makoto YUI --- graflow/channels/memory_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graflow/channels/memory_channel.py b/graflow/channels/memory_channel.py index 7bdc5f1..ddb12d9 100644 --- a/graflow/channels/memory_channel.py +++ b/graflow/channels/memory_channel.py @@ -159,7 +159,7 @@ def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, floa with self._get_key_lock(key): self._cleanup_expired(key) current = self.data.get(key, 0) - if not isinstance(current, int | float): + if not isinstance(current, (int, float)): raise TypeError(f"Key '{key}' holds {type(current).__name__}, expected int or float") new_value = current + amount self.data[key] = new_value