Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 39 additions & 69 deletions src/vercel/workflow/core.py → src/vercel/_internal/workflow/core.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,39 @@
from __future__ import annotations

import contextlib
import contextvars
import dataclasses
import datetime
import json
import sys
from collections.abc import AsyncIterator, Callable, Coroutine, Generator
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

import pydantic

from vercel._internal.polyfills import Self

from . import py_sandbox

if TYPE_CHECKING:
from . import world as w


P = ParamSpec("P")
T = TypeVar("T")

# Global (default) registries — used when no sandbox is active.
_global_workflows: dict[str, Workflow[Any, Any]] = {}
_global_steps: dict[str, Step[Any, Any]] = {}

# When a sandbox sets these, decorators and lookups use the
# sandbox-local dicts instead of the globals above.
_cv_workflows: contextvars.ContextVar[dict[str, Workflow[Any, Any]] | None] = (
contextvars.ContextVar("_cv_workflows", default=None)
)
_cv_steps: contextvars.ContextVar[dict[str, Step[Any, Any]] | None] = contextvars.ContextVar(
"_cv_steps", default=None
)


def _get_workflows() -> dict[str, Workflow[Any, Any]]:
    """Return the workflow registry that is currently in effect.

    The dict stored in the ``_cv_workflows`` context variable wins when one
    has been set; otherwise the module-level ``_global_workflows`` dict is
    returned.
    """
    scoped = _cv_workflows.get()
    if scoped is not None:
        return scoped
    return _global_workflows


def _get_steps() -> dict[str, Step[Any, Any]]:
    """Return the step registry that is currently in effect.

    The dict stored in the ``_cv_steps`` context variable wins when one has
    been set; otherwise the module-level ``_global_steps`` dict is returned.
    """
    scoped = _cv_steps.get()
    if scoped is not None:
        return scoped
    return _global_steps


@contextlib.contextmanager
def clean_registry():
    """Run the ``with`` body against fresh, empty workflow/step registries.

    Both registry context variables are pointed at new empty dicts on entry
    and restored to their previous values on exit, even if the body raises.
    """
    workflows_token = _cv_workflows.set({})
    steps_token = _cv_steps.set({})
    try:
        yield
    finally:
        # Restore in the reverse order of installation.
        _cv_steps.reset(steps_token)
        _cv_workflows.reset(workflows_token)


class Workflow(Generic[P, T]):
def __init__(self, func: Callable[P, Coroutine[Any, Any, T]]):
self.func = func
self.module = getattr(func, "__module__", "<unknown module>")
self.workflow_id = f"workflow//{self.module}//{func.__qualname__}"
registry = _get_workflows()
assert self.workflow_id not in registry, f"Duplicate workflow ID: {self.workflow_id}"
registry[self.workflow_id] = self


def workflow(func: Callable[P, Coroutine[Any, Any, T]]) -> Workflow[P, T]:
    """Wrap the coroutine function *func* in a ``Workflow`` and return it."""
    wrapped = Workflow(func)
    return wrapped


def get_workflow(workflow_id: str) -> Workflow[Any, Any]:
    """Look up a workflow by ID in the registry currently in effect.

    Raises:
        KeyError: if no workflow is registered under *workflow_id*.
    """
    active_registry = _get_workflows()
    return active_registry[workflow_id]
self.module = func.__module__
self.qualname = func.__qualname__
self.workflow_id = f"workflow//{self.module}.{self.qualname}"


class Step(Generic[P, T]):
max_retries: int = 3

def __init__(self, func: Callable[P, Coroutine[Any, Any, T]]):
self.func = func
module = getattr(func, "__module__", "<unknown module>")
self.name = f"step//{module}//{func.__qualname__}"
registry = _get_steps()
assert self.name not in registry, f"Duplicate step name: {self.name}"
registry[self.name] = self
self.name = f"step//{func.__module__}.{func.__qualname__}"

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
from . import runtime
Expand All @@ -99,14 +48,6 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await ctx.run_step(self, *args, **kwargs)


def step(func: Callable[P, Coroutine[Any, Any, T]]) -> Step[P, T]:
    """Wrap the coroutine function *func* in a ``Step`` and return it."""
    wrapped = Step(func)
    return wrapped


def get_step(step_name: str) -> Step[Any, Any]:
    """Look up a step by name in the registry currently in effect.

    Raises:
        KeyError: if no step is registered under *step_name*.
    """
    active_registry = _get_steps()
    return active_registry[step_name]


async def sleep(param: int | float | datetime.datetime | str) -> None:
from . import runtime

Expand Down Expand Up @@ -161,7 +102,7 @@ def dispose(self) -> None:
ctx.dispose_hook(correlation_id=self._correlation_id)


class HookMixin:
class BaseHook:
@classmethod
def wait(cls, *, token: str | None = None) -> HookEvent[Self]:
from . import runtime
Expand Down Expand Up @@ -192,3 +133,32 @@ async def resume(self, token_or_hook: str | w.Hook, **kwargs) -> w.Hook:
raise TypeError("resume only supports pydantic models or dataclasses")

return await runtime.resume_hook(token_or_hook, json_str)


class Workflows:
    """Per-instance registry of ``Workflow`` and ``Step`` objects.

    ``workflow`` and ``step`` wrap a coroutine function, record the wrapper
    under its ID/name, and return it; ``_get_workflow`` and ``_get_step``
    retrieve the recorded wrappers.
    """

    def __init__(self, *, as_vercel_job: bool = True):
        """Create an empty registry.

        Args:
            as_vercel_job: when true and ``py_sandbox.in_sandbox()`` is
                false, ``runtime.workflow_entrypoint`` and
                ``runtime.step_entrypoint`` are called with this registry.
        """
        self._workflows: dict[str, Workflow] = {}
        self._steps: dict[str, Step] = {}

        # Skip entrypoint wiring when opted out or when running in the sandbox.
        if not as_vercel_job or py_sandbox.in_sandbox():
            return

        # Imported lazily, as in the rest of this module.
        from . import runtime

        runtime.workflow_entrypoint(self)
        runtime.step_entrypoint(self)

    def workflow(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Workflow[P, T]:
        """Wrap *func* in a ``Workflow``, register it, and return the wrapper.

        A workflow ID that is already registered trips an assertion.
        """
        registered = Workflow(func)
        key = registered.workflow_id
        assert key not in self._workflows, f"Duplicate workflow ID: {key}"
        self._workflows[key] = registered
        return registered

    def _get_workflow(self, workflow_id: str) -> Workflow[Any, Any]:
        """Return the workflow registered under *workflow_id* (``KeyError`` if absent)."""
        return self._workflows[workflow_id]

    def step(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Step[P, T]:
        """Wrap *func* in a ``Step``, register it, and return the wrapper.

        A step name that is already registered trips an assertion.
        """
        registered = Step(func)
        key = registered.name
        assert key not in self._steps, f"Duplicate step name: {key}"
        self._steps[key] = registered
        return registered

    def _get_step(self, step_name: str) -> Step[Any, Any]:
        """Return the step registered under *step_name* (``KeyError`` if absent)."""
        return self._steps[step_name]
Comment thread
fantix marked this conversation as resolved.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def set_result(self, raw_data: Any) -> None:
class WorkflowOrchestratorContext:
_ctx: contextvars.ContextVar[Self] = contextvars.ContextVar("WorkflowContext")

def __init__(self, events: list[w.Event], *, seed: str, started_at: int):
def __init__(
self, events: list[w.Event], *, seed: str, started_at: int, registry: core.Workflows
):
self.events = events
self.replay_index = 0
prng = random.Random(seed)
Expand All @@ -73,28 +75,32 @@ def __init__(self, events: list[w.Event], *, seed: str, started_at: int):
self.suspensions: dict[str, BaseSuspension] = {}
self.hooks: dict[str, Hook] = {}
self.resume_handle: asyncio.Handle | None = None
self.registry = registry

@classmethod
def current(cls) -> Self:
return cls._ctx.get()

async def run_workflow(self: Self, workflow_run: w.WorkflowRun) -> Any:
mod_name = core.get_workflow(workflow_run.workflow_name).module
wf = self.registry._get_workflow(workflow_run.workflow_name)
if not workflow_run.input or not isinstance(workflow_run.input, list):
raise RuntimeError(f"Invalid workflow input for run {workflow_run.run_id}")
if not workflow_run.input[0].startswith(b"json"):
raise RuntimeError(f"Unsupported workflow input encoding for run {workflow_run.run_id}")
args, kwargs = json.loads(workflow_run.input[0][len(b"json") :].decode())

with core.clean_registry(), workflow_sandbox(random_seed=workflow_run.run_id):
# Re-import the user module inside the sandbox so @workflow
# registers into the sandbox-local _cv_workflows dict.
importlib.import_module(mod_name)
with workflow_sandbox(random_seed=workflow_run.run_id):
mod = importlib.import_module(wf.module)

# Resolve the sandboxed Workflow by qualname from the
# re-imported module.
obj: Any = mod
for attr in wf.qualname.split("."):
obj = getattr(obj, attr)

workflow = core.get_workflow(workflow_run.workflow_name)
token = self._ctx.set(self)
try:
self._fut = asyncio.ensure_future(workflow.func(*args, **kwargs))
self._fut = asyncio.ensure_future(obj.func(*args, **kwargs))
finally:
self._ctx.reset(token)
return await self._fut
Expand Down Expand Up @@ -220,6 +226,7 @@ async def workflow_handler(
attempt: int,
queue_name: str,
message_id: str,
registry: core.Workflows,
) -> float | None:
world = w.get_world()
run_id = w.WorkflowInvokePayload.model_validate(message).run_id
Expand Down Expand Up @@ -269,7 +276,9 @@ async def workflow_handler(
assert result.event is not None
events.append(result.event)

context = WorkflowOrchestratorContext(events, seed=run_id, started_at=workflow_started_at)
context = WorkflowOrchestratorContext(
events, seed=run_id, started_at=workflow_started_at, registry=registry
)
try:
result = await context.run_workflow(workflow_run)
output = b"json" + json.dumps(result).encode()
Expand Down Expand Up @@ -343,13 +352,14 @@ async def step_handler(
attempt: int,
queue_name: str,
message_id: str,
registry: core.Workflows,
) -> float | None:
world = w.get_world()
req = w.StepInvokePayload.model_validate(message)

# Get the step entity
step_run = await world.steps_get(req.workflow_run_id, req.step_id)
step = core.get_step(step_run.step_name)
step = registry._get_step(step_run.step_name)

# Check if retry_after timestamp hasn't been reached yet
now = datetime.now(UTC)
Expand Down Expand Up @@ -498,17 +508,17 @@ async def step_handler(
return None


def workflow_entrypoint() -> w.HTTPHandler:
def workflow_entrypoint(registry: core.Workflows) -> w.HTTPHandler:
return w.get_world().create_queue_handler(
"__wkf_workflow_",
workflow_handler,
functools.partial(workflow_handler, registry=registry),
)


def step_entrypoint() -> w.HTTPHandler:
def step_entrypoint(registry: core.Workflows) -> w.HTTPHandler:
return w.get_world().create_queue_handler(
"__wkf_step_",
step_handler,
functools.partial(step_handler, registry=registry),
)


Expand Down
File renamed without changes.
File renamed without changes.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def async_handler(body: Any, meta: vqs_client.MessageMetadata) -> None:
# we may get a duplicate invocation but won't lose the scheduled wakeup
await self.queue(
queue_name,
payload,
w.WorkflowInvokePayload.model_validate(payload),
deployment_id=body.get("deploymentId"),
delay_seconds=delay_seconds,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def async_handler(body: Any, meta: vqs_client.MessageMetadata) -> None:
# we may get a duplicate invocation but won't lose the scheduled wakeup
await self.queue(
queue_name,
payload,
w.WorkflowInvokePayload.model_validate(payload),
deployment_id=body.get("deploymentId"),
delay_seconds=delay_seconds,
)
Expand Down
6 changes: 3 additions & 3 deletions src/vercel/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .core import HookEvent, HookMixin, sleep, step, workflow
from .runtime import Run, start
from vercel._internal.workflow.core import BaseHook, HookEvent, Workflows, sleep
from vercel._internal.workflow.runtime import Run, start

__all__ = ["step", "workflow", "sleep", "start", "Run", "HookMixin", "HookEvent"]
__all__ = ["Workflows", "sleep", "start", "Run", "BaseHook", "HookEvent"]
2 changes: 1 addition & 1 deletion tests/test_nanoid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test nanoid implementation."""

from vercel.workflow import nanoid
from vercel._internal.workflow import nanoid


def test_generate_default():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ulid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from vercel.workflow.ulid import monotonic_factory
from vercel._internal.workflow.ulid import monotonic_factory


class TestMonotonicFactory:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_py_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import pytest

from vercel.workflow.py_sandbox import SandboxRestrictionError, workflow_sandbox
from vercel._internal.workflow.py_sandbox import SandboxRestrictionError, workflow_sandbox

SEED = "test-seed-42"

Expand Down
Loading