Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions packages/cli/tests/bindings-test/src/test_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import asyncio

import pytest

pytestmark = pytest.mark.asyncio

_cache = None


def _find(messages, predicate):
return next(m for m in messages if predicate(m))


# Send everything at once to reduce the overall test time.
# Receiving a message from the queue takes ~2 seconds,
# so batching all sends into a single sleep is more efficient.
async def _send_all_messages(env):
global _cache
if _cache is not None:
return _cache

from worker import RECEIVED_MESSAGES

RECEIVED_MESSAGES.clear()

q = env.TEST_QUEUE
await asyncio.gather(
q.send("hello queue"),
q.send({"key": "value", "number": 42}),
q.send(123),
q.send("text message", contentType="text"),
q.send(None),
q.send(True),
q.send([1, 2, 3]),
q.send(""),
q.send({"outer": {"inner": "deep"}, "list": [1, 2]}),
q.sendBatch(
[
{"body": "batch 1"},
{"body": "batch 2"},
{"body": "batch 3"},
]
),
q.sendBatch(
[{"body": "text msg", "contentType": "text"}],
delaySeconds=0,
),
)

await asyncio.sleep(2)

assert len(RECEIVED_MESSAGES) >= 13
_cache = list(RECEIVED_MESSAGES)
return _cache


async def test_send_string(env):
msgs = await _send_all_messages(env)
msg = _find(msgs, lambda m: m["body"] == "hello queue")
assert isinstance(msg["id"], str)
assert msg["attempts"] >= 1


async def test_send_dict(env):
msgs = await _send_all_messages(env)
msg = _find(
msgs,
lambda m: isinstance(m["body"], dict) and m["body"].get("key") == "value",
)
assert msg["body"]["number"] == 42


async def test_send_number(env):
msgs = await _send_all_messages(env)
_find(msgs, lambda m: m["body"] == 123)


async def test_send_with_content_type(env):
msgs = await _send_all_messages(env)
_find(msgs, lambda m: m["body"] == "text message")


async def test_send_none(env):
msgs = await _send_all_messages(env)
_find(msgs, lambda m: m["body"] is None)


async def test_send_bool(env):
msgs = await _send_all_messages(env)
_find(msgs, lambda m: m["body"] is True)


async def test_send_list(env):
msgs = await _send_all_messages(env)
_find(msgs, lambda m: m["body"] == [1, 2, 3])


async def test_send_empty_string(env):
msgs = await _send_all_messages(env)
_find(msgs, lambda m: m["body"] == "")


async def test_send_nested_dict(env):
msgs = await _send_all_messages(env)
msg = _find(
msgs,
lambda m: isinstance(m["body"], dict)
and isinstance(m["body"].get("outer"), dict),
)
assert msg["body"]["outer"]["inner"] == "deep"
assert msg["body"]["list"] == [1, 2]


async def test_send_batch(env):
msgs = await _send_all_messages(env)
bodies = [m["body"] for m in msgs]
assert "batch 1" in bodies
assert "batch 2" in bodies
assert "batch 3" in bodies


async def test_send_batch_with_options(env):
msgs = await _send_all_messages(env)
bodies = [m["body"] for m in msgs]
assert "text msg" in bodies
54 changes: 50 additions & 4 deletions packages/cli/tests/bindings-test/src/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import asyncio
import importlib.util
import sys
from asyncio import InvalidStateError

import pytest
from pyodide.webloop import WebLoop
Expand All @@ -28,11 +29,37 @@ async def _noop(*args):
WebLoop.shutdown_asyncgens = _noop
WebLoop.shutdown_default_executor = _noop

# Pyodide 0.26.0a2's _cancel_all_tasks calls task.exception() on pending tasks,
# which raises InvalidStateError under Pyodide's WebLoop.
# Ignore this error to prevent pytest-asyncio from crashing.
# Pyodide 0.26.0a2's WebLoop causes InvalidStateError when the
# _cancel_all_tasks calls task.exception() on done-but-not-cancelled tasks.
# Replace with a version that cancels tasks but tolerates that error.
if sys.version_info < (3, 13):
asyncio.runners._cancel_all_tasks = lambda loop: None # type: ignore[attr-defined]

def _cancel_all_tasks(loop):
to_cancel = asyncio.tasks.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(
asyncio.tasks.gather(*to_cancel, return_exceptions=True)
)
for task in to_cancel:
if task.cancelled():
continue
try:
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
# Note: This exception catch is added from the original implementation
except (InvalidStateError, RuntimeError):
pass

asyncio.runners._cancel_all_tasks = _cancel_all_tasks # type: ignore[attr-defined]


class ResultCollector:
Expand Down Expand Up @@ -79,6 +106,11 @@ def pytest_runtest_makereport(self, item, call):
else "unknown error",
"traceback": report.longreprtext,
}
elif report.when in ("setup", "teardown") and report.skipped:
self.results[key] = {
"status": "skipped",
"reason": str(report.longrepr),
}
elif report.when in ("setup", "teardown") and report.failed:
self.results[key] = {
"status": "error",
Expand All @@ -96,6 +128,9 @@ def env(self):
return self._env


RECEIVED_MESSAGES = []


class Default(WorkerEntrypoint):
async def fetch(self, request):
from urllib.parse import urlparse
Expand All @@ -109,6 +144,17 @@ async def fetch(self, request):
return Response.json({"ok": True})
return Response.json({"error": "not found"}, status=404)

async def queue(self, batch, env, ctx):
for message in batch.messages:
RECEIVED_MESSAGES.append(
{
"id": message.id,
"body": message.body,
"attempts": message.attempts,
}
)
message.ack()

def _run_suite(self, suite_name):
module = f"test_{suite_name}"
if importlib.util.find_spec(module) is None:
Expand Down
8 changes: 8 additions & 0 deletions packages/cli/tests/bindings-test/wrangler.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
"database_name": "test-db"
}
],
"queues": {
"producers": [
{ "binding": "TEST_QUEUE", "queue": "test-queue" }
],
"consumers": [
{ "queue": "test-queue", "max_batch_size": 10, "max_batch_timeout": 1 }
]
},
"durable_objects": {
"bindings": [
{ "name": "TEST_DO", "class_name": "TestDurableObject" }
Expand Down
58 changes: 44 additions & 14 deletions packages/runtime-sdk/src/workers/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,13 @@ def _manage_pyproxies():
destroy_proxies(proxies)


def _is_js_instance(val, js_cls_name):
return hasattr(val, "constructor") and val.constructor.name == js_cls_name
def _is_js_instance(val, js_cls_names: str | set[str]):
if not hasattr(val, "constructor"):
return False
name = val.constructor.name
if isinstance(js_cls_names, set):
return name in js_cls_names
return name == js_cls_names
Comment thread
ryanking13 marked this conversation as resolved.


try:
Expand Down Expand Up @@ -1107,13 +1112,20 @@ def __init__(self, binding):

def _convert_result(self, result):
converted = python_from_rpc(result)

# After python_from_rpc, some objects may still be JsProxy objects.
# For now, we wrap all of them with the _BindingWrapper (or a subclass of it)
# so that accessing attributes on them will be properly converted.

# TODO: This is a bit of a hack. We should revisit when there are more
# bindings to support with different return types.
if isinstance(converted, JsProxy):
# If the RPC result is another JsProxy, we assume that
# it is another RPC-wrapped object and wrap it as well.
# for example, d1.bind() returns the same object as a result.
# TODO: This is a bit of a hack. We should revisit when there are more
# bindings to support with different patterns.
return self.__class__(converted)
if isinstance(converted, list):
return [
self.__class__(item) if isinstance(item, JsProxy) else item
Comment thread
ryanking13 marked this conversation as resolved.
for item in converted
]
return converted

def _getattr_helper(self, name):
Expand Down Expand Up @@ -1252,6 +1264,13 @@ async def create_batch(self, *args, **kwargs):


class _EnvWrapper:
_BINDING_TYPES = {
"KvNamespace",
"R2Bucket",
"D1Database",
"WorkerQueue",
}

def __init__(self, env: Any):
self._env = env

Expand All @@ -1266,13 +1285,7 @@ def _getattr_helper(self, name):
if _is_js_instance(binding, "WorkflowImpl"):
return _WorkflowBindingWrapper(binding)

if _is_js_instance(binding, "KvNamespace"):
return _BindingWrapper(binding)

if _is_js_instance(binding, "R2Bucket"):
return _BindingWrapper(binding)

if _is_js_instance(binding, "D1Database"):
if _is_js_instance(binding, self._BINDING_TYPES):
return _BindingWrapper(binding)

return binding
Expand Down Expand Up @@ -1550,6 +1563,7 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"):

def __init_subclass__(cls, **_kwargs: Any):
_wrap_subclass(cls)
_wrap_queue_handler(cls)


class WorkflowEntrypoint:
Expand All @@ -1567,3 +1581,19 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"):
def __init_subclass__(cls, **_kwargs: Any):
_wrap_subclass(cls)
_wrap_workflow_step(cls)


def _wrap_queue_handler(cls):
queue_fn = getattr(cls, "queue", None)
if queue_fn is None:
return

@functools.wraps(queue_fn)
async def wrapped_queue(self, batch, *args, **kwargs):
wrapped_batch = _BindingWrapper(batch)
result = queue_fn(self, wrapped_batch, *args, **kwargs)
if inspect.iscoroutine(result):
result = await result
return result

cls.queue = wrapped_queue
Loading