Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion lib/crewai/src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
BeforeValidator,
ConfigDict,
Field,
PlainSerializer,
PrivateAttr,
SerializeAsAny,
ValidationError,
Expand Down Expand Up @@ -157,6 +158,37 @@ def _resolve_persistence(value: Any) -> Any:
return value


# Sentinel dict key: serialization stores a ``BaseModel`` class's JSON schema
# under this key, and deserialization looks for it to rebuild the class.
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"


def _serialize_initial_state(value: Any) -> Any:
    """Convert ``initial_state`` into a form a JSON checkpoint can hold.

    Handling by kind of ``value``:

    * ``BaseModel`` subclass (a class reference): wrapped as a one-key dict
      mapping the sentinel marker to the class's JSON schema, so that
      deserialization can rebuild a class from it.
    * Any other bare ``type`` (e.g. ``dict``): replaced by ``None`` because
      it has no JSON representation.
    * ``BaseModel`` instance: dumped in JSON mode, so it round-trips as a
      plain dict (which ``_create_initial_state`` accepts).
    * Everything else: returned unchanged.
    """
    if not isinstance(value, type):
        # Not a class reference: flatten model instances, pass the rest through.
        if isinstance(value, BaseModel):
            return value.model_dump(mode="json")
        return value

    # From here on ``value`` is a class object.
    if not issubclass(value, BaseModel):
        return None

    schema = value.model_json_schema()
    return {_INITIAL_STATE_CLASS_MARKER: schema}


def _deserialize_initial_state(value: Any) -> Any:
"""Rehydrate a class ref serialized by :func:`_serialize_initial_state`."""
if isinstance(value, dict) and _INITIAL_STATE_CLASS_MARKER in value:
from crewai.utilities.pydantic_schema_utils import create_model_from_schema

return create_model_from_schema(value[_INITIAL_STATE_CLASS_MARKER])
return value


class FlowState(BaseModel):
"""Base model for all flow states, ensuring each state has a unique ID."""

Expand Down Expand Up @@ -908,7 +940,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):

entity_type: Literal["flow"] = "flow"

initial_state: Any = Field(default=None)
initial_state: Annotated[ # type: ignore[type-arg]
type[BaseModel] | type[dict] | dict[str, Any] | BaseModel | None,
BeforeValidator(_deserialize_initial_state),
PlainSerializer(_serialize_initial_state, return_type=Any, when_used="json"),
] = Field(default=None)
name: str | None = Field(default=None)
tracing: bool | None = Field(default=None)
stream: bool = Field(default=False)
Expand Down
62 changes: 61 additions & 1 deletion lib/crewai/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from unittest.mock import MagicMock, patch

import pytest
from pydantic import BaseModel

from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.flow.flow import Flow, start
from crewai.flow.flow import _INITIAL_STATE_CLASS_MARKER, Flow, start
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_listener import (
_find_checkpoint,
Expand Down Expand Up @@ -310,6 +311,65 @@ def test_fork_no_checkpoint_id_unique(self) -> None:
assert state._branch != first


class TestFlowInitialStateSerialization:
    """Checkpoint (de)serialization regression coverage for ``Flow.initial_state``."""

    def test_class_ref_serializes_as_schema(self) -> None:
        """A model class is emitted as its JSON schema under the marker key."""

        class MyState(BaseModel):
            id: str = "x"
            foo: str = "bar"

        runtime = RuntimeState(root=[Flow(initial_state=MyState)])
        payload = json.loads(runtime.model_dump_json())
        wrapped = payload["entities"][0]["initial_state"]

        assert isinstance(wrapped, dict)
        assert _INITIAL_STATE_CLASS_MARKER in wrapped
        schema = wrapped[_INITIAL_STATE_CLASS_MARKER]
        assert schema.get("title") == "MyState"

    def test_class_ref_round_trips_to_basemodel_subclass(self) -> None:
        """A serialized model class is restored as a ``BaseModel`` subclass."""

        class MyState(BaseModel):
            id: str = "x"
            foo: str = "bar"

        serialized = RuntimeState(root=[Flow(initial_state=MyState)]).model_dump_json()
        reloaded = RuntimeState.model_validate_json(
            serialized,
            context={"from_checkpoint": True},
        )
        rehydrated = reloaded.root[0].initial_state

        assert isinstance(rehydrated, type)
        assert issubclass(rehydrated, BaseModel)
        field_names = set(rehydrated.model_fields.keys())
        assert field_names == {"id", "foo"}

    def test_instance_serializes_as_values(self) -> None:
        """A model instance is emitted as its field values."""

        class MyState(BaseModel):
            id: str = "x"
            foo: str = "bar"

        runtime = RuntimeState(root=[Flow(initial_state=MyState(foo="baz"))])
        payload = json.loads(runtime.model_dump_json())
        first_entity = payload["entities"][0]

        assert first_entity["initial_state"] == {"id": "x", "foo": "baz"}

    def test_dict_passthrough(self) -> None:
        """A plain dict is emitted unchanged."""
        runtime = RuntimeState(root=[Flow(initial_state={"id": "x", "foo": "bar"})])
        payload = json.loads(runtime.model_dump_json())
        first_entity = payload["entities"][0]

        assert first_entity["initial_state"] == {"id": "x", "foo": "bar"}

    def test_dict_round_trips_as_dict(self) -> None:
        """A plain dict survives a dump/validate cycle as the same dict."""
        source_flow = Flow(initial_state={"id": "x", "foo": "bar"})
        serialized = RuntimeState(root=[source_flow]).model_dump_json()
        reloaded = RuntimeState.model_validate_json(
            serialized,
            context={"from_checkpoint": True},
        )

        assert reloaded.root[0].initial_state == {"id": "x", "foo": "bar"}


# ---------- JsonProvider forking ----------


Expand Down