diff --git a/packages/cli/tests/bindings-test/src/test_queue.py b/packages/cli/tests/bindings-test/src/test_queue.py new file mode 100644 index 0000000..1d56c10 --- /dev/null +++ b/packages/cli/tests/bindings-test/src/test_queue.py @@ -0,0 +1,125 @@ +import asyncio + +import pytest + +pytestmark = pytest.mark.asyncio + +_cache = None + + +def _find(messages, predicate): + return next(m for m in messages if predicate(m)) + + +# Send everything at once to reduce the overall test time. +# Receiving a message from the queue takes ~2 seconds, +# so batching all sends into a single sleep is more efficient. +async def _send_all_messages(env): + global _cache + if _cache is not None: + return _cache + + from worker import RECEIVED_MESSAGES + + RECEIVED_MESSAGES.clear() + + q = env.TEST_QUEUE + await asyncio.gather( + q.send("hello queue"), + q.send({"key": "value", "number": 42}), + q.send(123), + q.send("text message", contentType="text"), + q.send(None), + q.send(True), + q.send([1, 2, 3]), + q.send(""), + q.send({"outer": {"inner": "deep"}, "list": [1, 2]}), + q.sendBatch( + [ + {"body": "batch 1"}, + {"body": "batch 2"}, + {"body": "batch 3"}, + ] + ), + q.sendBatch( + [{"body": "text msg", "contentType": "text"}], + delaySeconds=0, + ), + ) + + await asyncio.sleep(2) + + assert len(RECEIVED_MESSAGES) >= 13 + _cache = list(RECEIVED_MESSAGES) + return _cache + + +async def test_send_string(env): + msgs = await _send_all_messages(env) + msg = _find(msgs, lambda m: m["body"] == "hello queue") + assert isinstance(msg["id"], str) + assert msg["attempts"] >= 1 + + +async def test_send_dict(env): + msgs = await _send_all_messages(env) + msg = _find( + msgs, + lambda m: isinstance(m["body"], dict) and m["body"].get("key") == "value", + ) + assert msg["body"]["number"] == 42 + + +async def test_send_number(env): + msgs = await _send_all_messages(env) + _find(msgs, lambda m: m["body"] == 123) + + +async def test_send_with_content_type(env): + msgs = await _send_all_messages(env) + _find(msgs, lambda m: m["body"] == "text message") + + +async def test_send_none(env): + msgs = await _send_all_messages(env) + _find(msgs, lambda m: m["body"] is None) + + +async def test_send_bool(env): + msgs = await _send_all_messages(env) + _find(msgs, lambda m: m["body"] is True) + + +async def test_send_list(env): + msgs = await _send_all_messages(env) + _find(msgs, lambda m: m["body"] == [1, 2, 3]) + + +async def test_send_empty_string(env): + msgs = await _send_all_messages(env) + _find(msgs, lambda m: m["body"] == "") + + +async def test_send_nested_dict(env): + msgs = await _send_all_messages(env) + msg = _find( + msgs, + lambda m: isinstance(m["body"], dict) + and isinstance(m["body"].get("outer"), dict), + ) + assert msg["body"]["outer"]["inner"] == "deep" + assert msg["body"]["list"] == [1, 2] + + +async def test_send_batch(env): + msgs = await _send_all_messages(env) + bodies = [m["body"] for m in msgs] + assert "batch 1" in bodies + assert "batch 2" in bodies + assert "batch 3" in bodies + + +async def test_send_batch_with_options(env): + msgs = await _send_all_messages(env) + bodies = [m["body"] for m in msgs] + assert "text msg" in bodies diff --git a/packages/cli/tests/bindings-test/src/worker.py b/packages/cli/tests/bindings-test/src/worker.py index 0adb546..4de3468 100644 --- a/packages/cli/tests/bindings-test/src/worker.py +++ b/packages/cli/tests/bindings-test/src/worker.py @@ -11,6 +11,7 @@ import asyncio import importlib.util import sys +from asyncio import InvalidStateError import pytest from pyodide.webloop import WebLoop @@ -28,11 +29,37 @@ async def _noop(*args): WebLoop.shutdown_asyncgens = _noop WebLoop.shutdown_default_executor = _noop -# Pyodide 0.26.0a2's _cancel_all_tasks calls task.exception() on pending tasks, -# which raises InvalidStateError under Pyodide's WebLoop. -# Ignore this error to prevent pytest-asyncio from crashing. +# Pyodide 0.26.0a2's WebLoop causes InvalidStateError when the +# _cancel_all_tasks calls task.exception() on done-but-not-cancelled tasks. +# Replace with a version that cancels tasks but tolerates that error. if sys.version_info < (3, 13): - asyncio.runners._cancel_all_tasks = lambda loop: None # type: ignore[attr-defined] + + def _cancel_all_tasks(loop): + to_cancel = asyncio.tasks.all_tasks(loop) + if not to_cancel: + return + for task in to_cancel: + task.cancel() + loop.run_until_complete( + asyncio.tasks.gather(*to_cancel, return_exceptions=True) + ) + for task in to_cancel: + if task.cancelled(): + continue + try: + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + # Note: This exception catch is added from the original implementation + except (InvalidStateError, RuntimeError): + pass + + asyncio.runners._cancel_all_tasks = _cancel_all_tasks # type: ignore[attr-defined] class ResultCollector: @@ -79,6 +106,11 @@ def pytest_runtest_makereport(self, item, call): else "unknown error", "traceback": report.longreprtext, } + elif report.when in ("setup", "teardown") and report.skipped: + self.results[key] = { + "status": "skipped", + "reason": str(report.longrepr), + } elif report.when in ("setup", "teardown") and report.failed: self.results[key] = { "status": "error", @@ -96,6 +128,9 @@ def env(self): return self._env +RECEIVED_MESSAGES = [] + + class Default(WorkerEntrypoint): async def fetch(self, request): from urllib.parse import urlparse @@ -109,6 +144,17 @@ async def fetch(self, request): return Response.json({"ok": True}) return Response.json({"error": "not found"}, status=404) + async def queue(self, batch, env, ctx): + for message in batch.messages: + RECEIVED_MESSAGES.append( + { + "id": message.id, + "body": message.body, + "attempts": message.attempts, + } + ) + message.ack() + def _run_suite(self, suite_name): module = f"test_{suite_name}" if importlib.util.find_spec(module) is None: diff --git a/packages/cli/tests/bindings-test/wrangler.jsonc b/packages/cli/tests/bindings-test/wrangler.jsonc index bcbe05a..2b486b7 100644 --- a/packages/cli/tests/bindings-test/wrangler.jsonc +++ b/packages/cli/tests/bindings-test/wrangler.jsonc @@ -16,6 +16,14 @@ "database_name": "test-db" } ], + "queues": { + "producers": [ + { "binding": "TEST_QUEUE", "queue": "test-queue" } + ], + "consumers": [ + { "queue": "test-queue", "max_batch_size": 10, "max_batch_timeout": 1 } + ] + }, "durable_objects": { "bindings": [ { "name": "TEST_DO", "class_name": "TestDurableObject" } diff --git a/packages/runtime-sdk/src/workers/_workers.py b/packages/runtime-sdk/src/workers/_workers.py index 7a6d007..684c853 100644 --- a/packages/runtime-sdk/src/workers/_workers.py +++ b/packages/runtime-sdk/src/workers/_workers.py @@ -319,8 +319,13 @@ def _manage_pyproxies(): destroy_proxies(proxies) -def _is_js_instance(val, js_cls_name): - return hasattr(val, "constructor") and val.constructor.name == js_cls_name +def _is_js_instance(val, js_cls_names: str | set[str]): + if not hasattr(val, "constructor"): + return False + name = val.constructor.name + if isinstance(js_cls_names, set): + return name in js_cls_names + return name == js_cls_names try: @@ -1107,13 +1112,20 @@ def __init__(self, binding): def _convert_result(self, result): converted = python_from_rpc(result) + + # After python_from_rpc, some objects may still be JsProxy objects. + # For now, we wrap all of them with the _BindingWrapper (or a subclass of it) + # so that accessing attributes on them will be properly converted. + + # TODO: This is a bit of a hack. We should revisit when there are more + # bindings to support with different return types. if isinstance(converted, JsProxy): - # If the RPC result is another JsProxy, we assume that - # it is another RPC-wrapped object and wrap it as well. - # for example, d1.bind() returns the same object as a result. - # TODO: This is a bit of a hack. We should revisit when there are more - # bindings to support with different patterns. return self.__class__(converted) + if isinstance(converted, list): + return [ + self.__class__(item) if isinstance(item, JsProxy) else item + for item in converted + ] return converted def _getattr_helper(self, name): @@ -1252,6 +1264,13 @@ async def create_batch(self, *args, **kwargs): class _EnvWrapper: + _BINDING_TYPES = { + "KvNamespace", + "R2Bucket", + "D1Database", + "WorkerQueue", + } + def __init__(self, env: Any): self._env = env @@ -1266,13 +1285,7 @@ def _getattr_helper(self, name): if _is_js_instance(binding, "WorkflowImpl"): return _WorkflowBindingWrapper(binding) - if _is_js_instance(binding, "KvNamespace"): - return _BindingWrapper(binding) - - if _is_js_instance(binding, "R2Bucket"): - return _BindingWrapper(binding) - - if _is_js_instance(binding, "D1Database"): + if _is_js_instance(binding, self._BINDING_TYPES): return _BindingWrapper(binding) return binding @@ -1550,6 +1563,7 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"): def __init_subclass__(cls, **_kwargs: Any): _wrap_subclass(cls) + _wrap_queue_handler(cls) class WorkflowEntrypoint: @@ -1567,3 +1581,19 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"): def __init_subclass__(cls, **_kwargs: Any): _wrap_subclass(cls) _wrap_workflow_step(cls) + + +def _wrap_queue_handler(cls): + queue_fn = getattr(cls, "queue", None) + if queue_fn is None: + return + + @functools.wraps(queue_fn) + async def wrapped_queue(self, batch, *args, **kwargs): + wrapped_batch = _BindingWrapper(batch) + result = queue_fn(self, wrapped_batch, *args, **kwargs) + if inspect.iscoroutine(result): + result = await result + return result + + cls.queue = wrapped_queue