From 6527a7553c8df8baaf9688a28020f57f75c28c90 Mon Sep 17 00:00:00 2001 From: Hood Chatham Date: Thu, 18 Jun 2026 12:48:58 -0700 Subject: [PATCH] fix: Ensure that ctx and env __init__ arguments are always wrapped We fixed the double-wrapping problem in #126 but added a new problem that ctx and env args in subclasses won't be wrapped at all. This switches to wrapping all __init__ functions but making DurableObjectContext and _EnvWrapper idempotent so that double wrapping them is not possible. --- .../durable-object-inheritance/worker.py | 15 ++++ packages/runtime-sdk/src/workers/_workers.py | 74 ++++++++++--------- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/packages/cli/tests/workerd-test/durable-object-inheritance/worker.py b/packages/cli/tests/workerd-test/durable-object-inheritance/worker.py index 43f9523..afe9e43 100644 --- a/packages/cli/tests/workerd-test/durable-object-inheritance/worker.py +++ b/packages/cli/tests/workerd-test/durable-object-inheritance/worker.py @@ -23,6 +23,8 @@ def assert_do_wrapped_once(obj): class BaseDurableObject(DurableObject): def __init__(self, ctx, env): + assert isinstance(env, _EnvWrapper) + assert isinstance(ctx, DurableObjectContext) super().__init__(ctx, env) assert_do_wrapped_once(self) self.ctx.storage.sql.exec("SELECT NULL") @@ -43,6 +45,8 @@ async def verify_wrapping(self): class LeafDurableObjectWithInit(BaseDurableObject): def __init__(self, ctx, env): + assert isinstance(env, _EnvWrapper) + assert isinstance(ctx, DurableObjectContext) super().__init__(ctx, env) assert_do_wrapped_once(self) self.custom_attr = "custom" @@ -79,6 +83,17 @@ def get_name(self): class Default(RedundantBaseEntrypoint): async def test(self, ctrl): + class Env: + pass + + x = _EnvWrapper(Env) + assert _EnvWrapper(x) is x + assert x._env is Env + + y = DurableObjectContext(Env) + assert DurableObjectContext(y) is y + assert y._ctx is Env + assert_wrapped_once(self) assert self.get_name() == "redundant" diff --git a/packages/runtime-sdk/src/workers/_workers.py b/packages/runtime-sdk/src/workers/_workers.py index 3709bca..4b59316 100644 --- a/packages/runtime-sdk/src/workers/_workers.py +++ b/packages/runtime-sdk/src/workers/_workers.py @@ -1230,7 +1230,19 @@ class DurableObjectAbort(BaseException): class DurableObjectContext: + # __new__ and __init__ set up to make sure that the following passes: + # + # a = DurableObjectContext(x) + # assert DurableObjectContext(a) is a + # assert a._ctx is x + def __new__(cls, ctx): + if isinstance(ctx, DurableObjectContext): + return ctx + return object.__new__(cls) + def __init__(self, ctx: "DurableObjectState"): + if "_ctx" in self.__dict__: + return self._ctx = ctx def __getattr__(self, name: str): @@ -1315,7 +1327,19 @@ class _EnvWrapper: "WorkerQueue", } + # __new__ and __init__ set up to make sure that the following passes: + # + # a = _EnvWrapper(x) + # assert _EnvWrapper(a) is a + # assert a._env is x + def __new__(cls, env): + if isinstance(env, cls): + return env + return object.__new__(cls) + def __init__(self, env: Any): + if "_env" in self.__dict__: + return self._env = env def _getattr_helper(self, name): @@ -1532,29 +1556,11 @@ async def _closure(): return result -def _is_direct_binding_subclass(cls: type, binding_cls: type) -> bool: - """ - Checks if the class is a direct subclass of the binding class. - - In order to prevent applying the wrapper multiple times, - we only want to apply the wrapper if the class is directly inheriting - from the binding class, not if it's inheriting from another class that - inherits from the binding class. - - Examples: - - `class A(DurableObject)` -> True - - `class B(A)` -> False - - `class C(B)` -> False - - `class D(C, DurableObject)` -> False - """ - return not any( - issubclass(b, binding_cls) for b in cls.__bases__ if b is not binding_cls - ) - - -def _wrap_subclass(cls): +def _wrap_class(cls): # Override the class __init__ so that we can wrap the `env` in the constructor. - original_init = cls.__init__ + original_init = cls.__dict__.get("__init__") + if original_init is None: + return cls def wrapped_init(self, *args, **kwargs): args = list(args) @@ -1568,18 +1574,14 @@ def wrapped_init(self, *args, **kwargs): original_init(self, *args, **kwargs) cls.__init__ = wrapped_init + return cls def _wrap_workflow_step(cls): - run_fn = getattr(cls, "run", None) + run_fn = cls.__dict__.get("run") if run_fn is None: return - # Only patch `on_run` for subclasses of WorkflowEntrypoint. - if not issubclass(cls, WorkflowEntrypoint): - # Not a workflow subclass, so don't wrap `on_run`. - return - @functools.wraps(run_fn) async def wrapped_run(self, event=None, step=None, /, *args, **kwargs): if event is not None: @@ -1597,6 +1599,7 @@ async def wrapped_run(self, event=None, step=None, /, *args, **kwargs): cls.run = wrapped_run +@_wrap_class class DurableObject: """ Base class used to define a Durable Object. @@ -1610,10 +1613,10 @@ def __init__(self, ctx: "DurableObjectState", env: "Env"): self.env = env def __init_subclass__(cls, **_kwargs): - if _is_direct_binding_subclass(cls, DurableObject): - _wrap_subclass(cls) + _wrap_class(cls) +@_wrap_class class WorkerEntrypoint: """ Base class used to define a Worker Entrypoint. @@ -1627,11 +1630,11 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"): self.env = env def __init_subclass__(cls, **_kwargs: Any): - if _is_direct_binding_subclass(cls, WorkerEntrypoint): - _wrap_subclass(cls) - _wrap_queue_handler(cls) + _wrap_class(cls) + _wrap_queue_handler(cls) +@_wrap_class class WorkflowEntrypoint: """ Base class used to define a Workflow Entrypoint. @@ -1645,9 +1648,8 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"): self.env = env def __init_subclass__(cls, **_kwargs: Any): - if _is_direct_binding_subclass(cls, WorkflowEntrypoint): - _wrap_subclass(cls) - _wrap_workflow_step(cls) + _wrap_class(cls) + _wrap_workflow_step(cls) def _wrap_queue_handler(cls):