diff --git a/src/vercel/workflow/worlds/__init__.py b/src/vercel/_internal/workflow/__init__.py similarity index 100% rename from src/vercel/workflow/worlds/__init__.py rename to src/vercel/_internal/workflow/__init__.py diff --git a/src/vercel/workflow/core.py b/src/vercel/_internal/workflow/core.py similarity index 63% rename from src/vercel/workflow/core.py rename to src/vercel/_internal/workflow/core.py index 5998f3f..2de7e97 100644 --- a/src/vercel/workflow/core.py +++ b/src/vercel/_internal/workflow/core.py @@ -1,78 +1,31 @@ from __future__ import annotations -import contextlib -import contextvars import dataclasses import datetime import json -import sys from collections.abc import AsyncIterator, Callable, Coroutine, Generator from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self - import pydantic +from vercel._internal.polyfills import Self + +from . import py_sandbox + if TYPE_CHECKING: from . import world as w + P = ParamSpec("P") T = TypeVar("T") -# Global (default) registries — used when no sandbox is active. -_global_workflows: dict[str, Workflow[Any, Any]] = {} -_global_steps: dict[str, Step[Any, Any]] = {} - -# When a sandbox sets these, decorators and lookups use the -# sandbox-local dicts instead of the globals above. -_cv_workflows: contextvars.ContextVar[dict[str, Workflow[Any, Any]] | None] = ( - contextvars.ContextVar("_cv_workflows", default=None) -) -_cv_steps: contextvars.ContextVar[dict[str, Step[Any, Any]] | None] = contextvars.ContextVar( - "_cv_steps", default=None -) - - -def _get_workflows() -> dict[str, Workflow[Any, Any]]: - rv = _cv_workflows.get() - return _global_workflows if rv is None else rv - - -def _get_steps() -> dict[str, Step[Any, Any]]: - rv = _cv_steps.get() - return _global_steps if rv is None else rv - - -@contextlib.contextmanager -def clean_registry(): - wf_token = _cv_workflows.set({}) - st_token = _cv_steps.set({}) - try: - yield - finally: - _cv_steps.reset(st_token) - _cv_workflows.reset(wf_token) - class Workflow(Generic[P, T]): def __init__(self, func: Callable[P, Coroutine[Any, Any, T]]): self.func = func - self.module = getattr(func, "__module__", "") - self.workflow_id = f"workflow//{self.module}//{func.__qualname__}" - registry = _get_workflows() - assert self.workflow_id not in registry, f"Duplicate workflow ID: {self.workflow_id}" - registry[self.workflow_id] = self - - -def workflow(func: Callable[P, Coroutine[Any, Any, T]]) -> Workflow[P, T]: - return Workflow(func) - - -def get_workflow(workflow_id: str) -> Workflow[Any, Any]: - return _get_workflows()[workflow_id] + self.module = func.__module__ + self.qualname = func.__qualname__ + self.workflow_id = f"workflow//{self.module}.{self.qualname}" class Step(Generic[P, T]): @@ -80,11 +33,7 @@ class Step(Generic[P, T]): def __init__(self, func: Callable[P, Coroutine[Any, Any, T]]): self.func = func - module = getattr(func, "__module__", "") - self.name = f"step//{module}//{func.__qualname__}" - registry = _get_steps() - assert self.name not in registry, f"Duplicate step name: {self.name}" - registry[self.name] = self + self.name = f"step//{func.__module__}.{func.__qualname__}" async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: from . import runtime @@ -99,14 +48,6 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return await ctx.run_step(self, *args, **kwargs) -def step(func: Callable[P, Coroutine[Any, Any, T]]) -> Step[P, T]: - return Step(func) - - -def get_step(step_name: str) -> Step[Any, Any]: - return _get_steps()[step_name] - - async def sleep(param: int | float | datetime.datetime | str) -> None: from . import runtime @@ -161,7 +102,7 @@ def dispose(self) -> None: ctx.dispose_hook(correlation_id=self._correlation_id) -class HookMixin: +class BaseHook: @classmethod def wait(cls, *, token: str | None = None) -> HookEvent[Self]: from . import runtime @@ -192,3 +133,32 @@ async def resume(self, token_or_hook: str | w.Hook, **kwargs) -> w.Hook: raise TypeError("resume only supports pydantic models or dataclasses") return await runtime.resume_hook(token_or_hook, json_str) + + +class Workflows: + def __init__(self, *, as_vercel_job: bool = True): + self._workflows: dict[str, Workflow] = {} + self._steps: dict[str, Step] = {} + if as_vercel_job and not py_sandbox.in_sandbox(): + from . import runtime + + runtime.workflow_entrypoint(self) + runtime.step_entrypoint(self) + + def workflow(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Workflow[P, T]: + rv = Workflow(func) + assert rv.workflow_id not in self._workflows, f"Duplicate workflow ID: {rv.workflow_id}" + self._workflows[rv.workflow_id] = rv + return rv + + def _get_workflow(self, workflow_id: str) -> Workflow[Any, Any]: + return self._workflows[workflow_id] + + def step(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Step[P, T]: + rv = Step(func) + assert rv.name not in self._steps, f"Duplicate step name: {rv.name}" + self._steps[rv.name] = rv + return rv + + def _get_step(self, step_name: str) -> Step[Any, Any]: + return self._steps[step_name] diff --git a/src/vercel/workflow/nanoid.py b/src/vercel/_internal/workflow/nanoid.py similarity index 100% rename from src/vercel/workflow/nanoid.py rename to src/vercel/_internal/workflow/nanoid.py diff --git a/src/vercel/workflow/py_sandbox.py b/src/vercel/_internal/workflow/py_sandbox.py similarity index 100% rename from src/vercel/workflow/py_sandbox.py rename to src/vercel/_internal/workflow/py_sandbox.py diff --git a/src/vercel/workflow/runtime.py b/src/vercel/_internal/workflow/runtime.py similarity index 95% rename from src/vercel/workflow/runtime.py rename to src/vercel/_internal/workflow/runtime.py index c73dd8b..eeffa0a 100644 --- a/src/vercel/workflow/runtime.py +++ b/src/vercel/_internal/workflow/runtime.py @@ -63,7 +63,9 @@ def set_result(self, raw_data: Any) -> None: class WorkflowOrchestratorContext: _ctx: contextvars.ContextVar[Self] = contextvars.ContextVar("WorkflowContext") - def __init__(self, events: list[w.Event], *, seed: str, started_at: int): + def __init__( + self, events: list[w.Event], *, seed: str, started_at: int, registry: core.Workflows + ): self.events = events self.replay_index = 0 prng = random.Random(seed) @@ -73,28 +75,32 @@ def __init__(self, events: list[w.Event], *, seed: str, started_at: int): self.suspensions: dict[str, BaseSuspension] = {} self.hooks: dict[str, Hook] = {} self.resume_handle: asyncio.Handle | None = None + self.registry = registry @classmethod def current(cls) -> Self: return cls._ctx.get() async def run_workflow(self: Self, workflow_run: w.WorkflowRun) -> Any: - mod_name = core.get_workflow(workflow_run.workflow_name).module + wf = self.registry._get_workflow(workflow_run.workflow_name) if not workflow_run.input or not isinstance(workflow_run.input, list): raise RuntimeError(f"Invalid workflow input for run {workflow_run.run_id}") if not workflow_run.input[0].startswith(b"json"): raise RuntimeError(f"Unsupported workflow input encoding for run {workflow_run.run_id}") args, kwargs = json.loads(workflow_run.input[0][len(b"json") :].decode()) - with core.clean_registry(), workflow_sandbox(random_seed=workflow_run.run_id): - # Re-import the user module inside the sandbox so @workflow - # registers into the sandbox-local _cv_workflows dict. - importlib.import_module(mod_name) + with workflow_sandbox(random_seed=workflow_run.run_id): + mod = importlib.import_module(wf.module) + + # Resolve the sandboxed Workflow by qualname from the + # re-imported module. + obj: Any = mod + for attr in wf.qualname.split("."): + obj = getattr(obj, attr) - workflow = core.get_workflow(workflow_run.workflow_name) token = self._ctx.set(self) try: - self._fut = asyncio.ensure_future(workflow.func(*args, **kwargs)) + self._fut = asyncio.ensure_future(obj.func(*args, **kwargs)) finally: self._ctx.reset(token) return await self._fut @@ -220,6 +226,7 @@ async def workflow_handler( attempt: int, queue_name: str, message_id: str, + registry: core.Workflows, ) -> float | None: world = w.get_world() run_id = w.WorkflowInvokePayload.model_validate(message).run_id @@ -269,7 +276,9 @@ async def workflow_handler( assert result.event is not None events.append(result.event) - context = WorkflowOrchestratorContext(events, seed=run_id, started_at=workflow_started_at) + context = WorkflowOrchestratorContext( + events, seed=run_id, started_at=workflow_started_at, registry=registry + ) try: result = await context.run_workflow(workflow_run) output = b"json" + json.dumps(result).encode() @@ -343,13 +352,14 @@ async def step_handler( attempt: int, queue_name: str, message_id: str, + registry: core.Workflows, ) -> float | None: world = w.get_world() req = w.StepInvokePayload.model_validate(message) # Get the step entity step_run = await world.steps_get(req.workflow_run_id, req.step_id) - step = core.get_step(step_run.step_name) + step = registry._get_step(step_run.step_name) # Check if retry_after timestamp hasn't been reached yet now = datetime.now(UTC) @@ -498,17 +508,17 @@ async def step_handler( return None -def workflow_entrypoint() -> w.HTTPHandler: +def workflow_entrypoint(registry: core.Workflows) -> w.HTTPHandler: return w.get_world().create_queue_handler( "__wkf_workflow_", - workflow_handler, + functools.partial(workflow_handler, registry=registry), ) -def step_entrypoint() -> w.HTTPHandler: +def step_entrypoint(registry: core.Workflows) -> w.HTTPHandler: return w.get_world().create_queue_handler( "__wkf_step_", - step_handler, + functools.partial(step_handler, registry=registry), ) diff --git a/src/vercel/workflow/ulid.py b/src/vercel/_internal/workflow/ulid.py similarity index 100% rename from src/vercel/workflow/ulid.py rename to src/vercel/_internal/workflow/ulid.py diff --git a/src/vercel/workflow/world.py b/src/vercel/_internal/workflow/world.py similarity index 100% rename from src/vercel/workflow/world.py rename to src/vercel/_internal/workflow/world.py diff --git a/src/vercel/_internal/workflow/worlds/__init__.py b/src/vercel/_internal/workflow/worlds/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/vercel/workflow/worlds/local.py b/src/vercel/_internal/workflow/worlds/local.py similarity index 99% rename from src/vercel/workflow/worlds/local.py rename to src/vercel/_internal/workflow/worlds/local.py index 8fb485d..505a363 100644 --- a/src/vercel/workflow/worlds/local.py +++ b/src/vercel/_internal/workflow/worlds/local.py @@ -141,7 +141,7 @@ async def async_handler(body: Any, meta: vqs_client.MessageMetadata) -> None: # we may get a duplicate invocation but won't lose the scheduled wakeup await self.queue( queue_name, - payload, + w.WorkflowInvokePayload.model_validate(payload), deployment_id=body.get("deploymentId"), delay_seconds=delay_seconds, ) diff --git a/src/vercel/workflow/worlds/vercel.py b/src/vercel/_internal/workflow/worlds/vercel.py similarity index 99% rename from src/vercel/workflow/worlds/vercel.py rename to src/vercel/_internal/workflow/worlds/vercel.py index b2bd423..3882a54 100644 --- a/src/vercel/workflow/worlds/vercel.py +++ b/src/vercel/_internal/workflow/worlds/vercel.py @@ -236,7 +236,7 @@ async def async_handler(body: Any, meta: vqs_client.MessageMetadata) -> None: # we may get a duplicate invocation but won't lose the scheduled wakeup await self.queue( queue_name, - payload, + w.WorkflowInvokePayload.model_validate(payload), deployment_id=body.get("deploymentId"), delay_seconds=delay_seconds, ) diff --git a/src/vercel/workflow/__init__.py b/src/vercel/workflow/__init__.py index e487277..104219b 100644 --- a/src/vercel/workflow/__init__.py +++ b/src/vercel/workflow/__init__.py @@ -1,4 +1,4 @@ -from .core import HookEvent, HookMixin, sleep, step, workflow -from .runtime import Run, start +from vercel._internal.workflow.core import BaseHook, HookEvent, Workflows, sleep +from vercel._internal.workflow.runtime import Run, start -__all__ = ["step", "workflow", "sleep", "start", "Run", "HookMixin", "HookEvent"] +__all__ = ["Workflows", "sleep", "start", "Run", "BaseHook", "HookEvent"] diff --git a/tests/test_nanoid.py b/tests/test_nanoid.py index 5293c8c..2af069e 100644 --- a/tests/test_nanoid.py +++ b/tests/test_nanoid.py @@ -1,6 +1,6 @@ """Test nanoid implementation.""" -from vercel.workflow import nanoid +from vercel._internal.workflow import nanoid def test_generate_default(): diff --git a/tests/test_ulid.py b/tests/test_ulid.py index 468d6d6..42a944c 100644 --- a/tests/test_ulid.py +++ b/tests/test_ulid.py @@ -2,7 +2,7 @@ import pytest -from vercel.workflow.ulid import monotonic_factory +from vercel._internal.workflow.ulid import monotonic_factory class TestMonotonicFactory: diff --git a/tests/unit/test_py_sandbox.py b/tests/unit/test_py_sandbox.py index 3d1a911..91d7114 100644 --- a/tests/unit/test_py_sandbox.py +++ b/tests/unit/test_py_sandbox.py @@ -15,7 +15,7 @@ import pytest -from vercel.workflow.py_sandbox import SandboxRestrictionError, workflow_sandbox +from vercel._internal.workflow.py_sandbox import SandboxRestrictionError, workflow_sandbox SEED = "test-seed-42"