Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Workerd = import "/workerd/workerd.capnp";

const config :Workerd.Config = (
services = [
(name = "main", worker = .mainWorker),
(name = "TEST_TMPDIR", disk = (writable = true)),
],
);

const mainWorker :Workerd.Worker = (
modules = [
(name = "worker.py", pythonModule = embed "worker.py"),
%PYTHON_MODULES
],
durableObjectNamespaces = [
(
className = "LeafDurableObject",
uniqueKey = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4",
enableSql = true,
),
(
className = "LeafDurableObjectWithInit",
uniqueKey = "b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5",
enableSql = true,
),
(
className = "RedundantBaseDO",
uniqueKey = "c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6",
enableSql = true,
),
],
durableObjectStorage = (localDisk = "TEST_TMPDIR"),
bindings = [
(name = "DO_LEAF", durableObjectNamespace = "LeafDurableObject"),
(name = "DO_LEAF_INIT", durableObjectNamespace = "LeafDurableObjectWithInit"),
(name = "DO_REDUNDANT", durableObjectNamespace = "RedundantBaseDO"),
],
compatibilityDate = "%COMPAT_DATE",
compatibilityFlags = ["python_workers", "service_binding_extra_handlers", "enable_python_external_sdk"],
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[project]
name = "test"
version = "0.0.0"
requires-python = ">=3.12"
dependencies = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Regression test: multi-level inheritance from DurableObject / WorkerEntrypoint.
# https://github.com/cloudflare/workers-py/issues/125
# _wrap_subclass must not double-wrap ctx and env when the hierarchy is >1 deep.

from workers import DurableObject, WorkerEntrypoint
from workers._workers import DurableObjectContext, _EnvWrapper


def assert_wrapped_once(obj):
assert isinstance(obj.env, _EnvWrapper), "env should be an _EnvWrapper"
assert not isinstance(obj.env._env, _EnvWrapper), "env should not be double-wrapped"


def assert_do_wrapped_once(obj):
assert_wrapped_once(obj)
assert isinstance(obj.ctx, DurableObjectContext), (
"ctx should be a DurableObjectContext"
)
assert not isinstance(obj.ctx._ctx, DurableObjectContext), (
"ctx should not be double-wrapped"
)


class BaseDurableObject(DurableObject):
def __init__(self, ctx, env):
super().__init__(ctx, env)
assert_do_wrapped_once(self)
self.ctx.storage.sql.exec("SELECT NULL")

async def shared_method(self):
return "from base"


class LeafDurableObject(BaseDurableObject):
async def hello(self):
return "hello from leaf"

async def verify_wrapping(self):
assert_do_wrapped_once(self)
self.ctx.storage.sql.exec("SELECT NULL")
return True


class LeafDurableObjectWithInit(BaseDurableObject):
def __init__(self, ctx, env):
super().__init__(ctx, env)
assert_do_wrapped_once(self)
self.custom_attr = "custom"

async def hello(self):
return "hello with init"

async def verify_wrapping(self):
assert_do_wrapped_once(self)
assert self.custom_attr == "custom"
self.ctx.storage.sql.exec("SELECT NULL")
return True


class RedundantBaseDO(BaseDurableObject, DurableObject):
async def hello(self):
return "hello from redundant"

async def verify_wrapping(self):
assert_do_wrapped_once(self)
self.ctx.storage.sql.exec("SELECT NULL")
return True


class BaseEntrypoint(WorkerEntrypoint):
def get_name(self):
return "base"


class RedundantBaseEntrypoint(BaseEntrypoint, WorkerEntrypoint):
def get_name(self):
return "redundant"


class Default(RedundantBaseEntrypoint):
async def test(self, ctrl):
assert_wrapped_once(self)
assert self.get_name() == "redundant"

id1 = self.env.DO_LEAF.idFromName("leaf-test")
obj1 = self.env.DO_LEAF.get(id1)
assert await obj1.hello() == "hello from leaf"
assert await obj1.shared_method() == "from base"
assert await obj1.verify_wrapping()

id2 = self.env.DO_LEAF_INIT.idFromName("leaf-init-test")
obj2 = self.env.DO_LEAF_INIT.get(id2)
assert await obj2.hello() == "hello with init"
assert await obj2.verify_wrapping()

id3 = self.env.DO_REDUNDANT.idFromName("redundant-test")
obj3 = self.env.DO_REDUNDANT.get(id3)
assert await obj3.hello() == "hello from redundant"
assert await obj3.verify_wrapping()
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"name": "test-worker",
"compatibility_date": "%COMPAT_DATE",
"compatibility_flags": ["python_workers"]
}
33 changes: 28 additions & 5 deletions packages/runtime-sdk/src/workers/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,6 +1532,26 @@ async def _closure():
return result


def _is_direct_binding_subclass(cls: type, binding_cls: type) -> bool:
"""
Checks if the class is a direct subclass of the binding class.

In order to prevent applying the wrapper multiple times,
we only want to apply the wrapper if the class is directly inheriting
from the binding class, not if it's inheriting from another class that
inherits from the binding class.

Examples:
- `class A(DurableObject)` -> True
- `class B(A)` -> False
- `class C(B)` -> False
- `class D(C, DurableObject)` -> False
"""
return not any(
issubclass(b, binding_cls) for b in cls.__bases__ if b is not binding_cls
)


def _wrap_subclass(cls):
# Override the class __init__ so that we can wrap the `env` in the constructor.
original_init = cls.__init__
Expand Down Expand Up @@ -1590,7 +1610,8 @@ def __init__(self, ctx: "DurableObjectState", env: "Env"):
self.env = env

def __init_subclass__(cls, **_kwargs):
_wrap_subclass(cls)
if _is_direct_binding_subclass(cls, DurableObject):
_wrap_subclass(cls)


class WorkerEntrypoint:
Expand All @@ -1606,8 +1627,9 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"):
self.env = env

def __init_subclass__(cls, **_kwargs: Any):
_wrap_subclass(cls)
_wrap_queue_handler(cls)
if _is_direct_binding_subclass(cls, WorkerEntrypoint):
_wrap_subclass(cls)
_wrap_queue_handler(cls)


class WorkflowEntrypoint:
Expand All @@ -1623,8 +1645,9 @@ def __init__(self, ctx: "ExecutionContext", env: "Env"):
self.env = env

def __init_subclass__(cls, **_kwargs: Any):
_wrap_subclass(cls)
_wrap_workflow_step(cls)
if _is_direct_binding_subclass(cls, WorkflowEntrypoint):
_wrap_subclass(cls)
_wrap_workflow_step(cls)


def _wrap_queue_handler(cls):
Expand Down
Loading