Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def assert_do_wrapped_once(obj):

class BaseDurableObject(DurableObject):
def __init__(self, ctx, env):
assert isinstance(env, _EnvWrapper)
assert isinstance(ctx, DurableObjectContext)
super().__init__(ctx, env)
assert_do_wrapped_once(self)
self.ctx.storage.sql.exec("SELECT NULL")
Expand All @@ -43,6 +45,8 @@ async def verify_wrapping(self):

class LeafDurableObjectWithInit(BaseDurableObject):
def __init__(self, ctx, env):
assert isinstance(env, _EnvWrapper)
assert isinstance(ctx, DurableObjectContext)
super().__init__(ctx, env)
assert_do_wrapped_once(self)
self.custom_attr = "custom"
Expand Down Expand Up @@ -79,6 +83,17 @@ def get_name(self):

class Default(RedundantBaseEntrypoint):
async def test(self, ctrl):
class Env:
pass

x = _EnvWrapper(Env)
assert _EnvWrapper(x) is x
assert x._env is Env

y = DurableObjectContext(Env)
assert DurableObjectContext(y) is y
assert y._ctx is Env

assert_wrapped_once(self)
assert self.get_name() == "redundant"

Expand Down
74 changes: 38 additions & 36 deletions packages/runtime-sdk/src/workers/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,19 @@ class DurableObjectAbort(BaseException):


class DurableObjectContext:
# __new__ and __init__ set up to make sure that the following passes:
#
# a = DurableObjectContext(x)
# assert DurableObjectContext(a) is a
# assert a._ctx is x
def __new__(cls, ctx):
if isinstance(ctx, DurableObjectContext):
return ctx
return object.__new__(cls)

def __init__(self, ctx: "DurableObjectState"):
if "_ctx" in self.__dict__:
return
self._ctx = ctx

def __getattr__(self, name: str):
Expand Down Expand Up @@ -1315,7 +1327,19 @@ class _EnvWrapper:
"WorkerQueue",
}

# __new__ and __init__ set up to make sure that the following passes:
#
# a = _EnvWrapper(x)
# assert _EnvWrapper(a) is a
# assert a._env is x
def __new__(cls, env):
if isinstance(env, cls):
return env
return object.__new__(cls)

def __init__(self, env: Any):
if "_env" in self.__dict__:
return
self._env = env

def _getattr_helper(self, name):
Expand Down Expand Up @@ -1532,29 +1556,11 @@ async def _closure():
return result


def _is_direct_binding_subclass(cls: type, binding_cls: type) -> bool:
"""
Checks if the class is a direct subclass of the binding class.

In order to prevent applying the wrapper multiple times,
we only want to apply the wrapper if the class is directly inheriting
from the binding class, not if it's inheriting from another class that
inherits from the binding class.

Examples:
- `class A(DurableObject)` -> True
- `class B(A)` -> False
- `class C(B)` -> False
- `class D(C, DurableObject)` -> False
"""
return not any(
issubclass(b, binding_cls) for b in cls.__bases__ if b is not binding_cls
)


def _wrap_subclass(cls):
def _wrap_class(cls):
# Override the class __init__ so that we can wrap the `env` in the constructor.
original_init = cls.__init__
original_init = cls.__dict__.get("__init__")
if original_init is None:
return cls

def wrapped_init(self, *args, **kwargs):
args = list(args)
Expand All @@ -1568,18 +1574,14 @@ def wrapped_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)

cls.__init__ = wrapped_init
return cls


def _wrap_workflow_step(cls):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess calling _wrap_workflow_step twice would still wrap run function twice

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah these need to be adjusted too that is right.

run_fn = getattr(cls, "run", None)
run_fn = cls.__dict__.get("run")
if run_fn is None:
return

# Only patch `on_run` for subclasses of WorkflowEntrypoint.
if not issubclass(cls, WorkflowEntrypoint):
# Not a workflow subclass, so don't wrap `on_run`.
return

@functools.wraps(run_fn)
async def wrapped_run(self, event=None, step=None, /, *args, **kwargs):
if event is not None:
Expand All @@ -1597,6 +1599,7 @@ async def wrapped_run(self, event=None, step=None, /, *args, **kwargs):
cls.run = wrapped_run


@_wrap_class
class DurableObject:
"""
Base class used to define a Durable Object.
Expand All @@ -1610,10 +1613,10 @@ def __init__(self, ctx: "DurableObjectState", env: "Env"):
self.env = env

def __init_subclass__(cls, **_kwargs):
if _is_direct_binding_subclass(cls, DurableObject):
_wrap_subclass(cls)
_wrap_class(cls)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need both decorator + __init_subclass__?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

@ryanking13 ryanking13 Jun 19, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, this class wrapping thing is really making my head confused



@_wrap_class
class WorkerEntrypoint:
"""
Base class used to define a Worker Entrypoint.
Expand All @@ -1627,11 +1630,11 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"):
self.env = env

def __init_subclass__(cls, **_kwargs: Any):
if _is_direct_binding_subclass(cls, WorkerEntrypoint):
_wrap_subclass(cls)
_wrap_queue_handler(cls)
_wrap_class(cls)
_wrap_queue_handler(cls)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for queue. the queue function would be wrapped twice I guess?



@_wrap_class
class WorkflowEntrypoint:
"""
Base class used to define a Workflow Entrypoint.
Expand All @@ -1645,9 +1648,8 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"):
self.env = env

def __init_subclass__(cls, **_kwargs: Any):
if _is_direct_binding_subclass(cls, WorkflowEntrypoint):
_wrap_subclass(cls)
_wrap_workflow_step(cls)
_wrap_class(cls)
_wrap_workflow_step(cls)


def _wrap_queue_handler(cls):
Expand Down
Loading