diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py
index b363ebc714..439a40524e 100644
--- a/lib/crewai/src/crewai/flow/flow.py
+++ b/lib/crewai/src/crewai/flow/flow.py
@@ -45,6 +45,7 @@
     BeforeValidator,
     ConfigDict,
     Field,
+    PlainSerializer,
     PrivateAttr,
     SerializeAsAny,
     ValidationError,
@@ -157,6 +158,63 @@ def _resolve_persistence(value: Any) -> Any:
     return value
 
 
+_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
+
+
+def _serialize_initial_state(value: Any) -> Any:
+    """Make ``initial_state`` safe for JSON checkpoint serialization.
+
+    ``BaseModel`` class refs are emitted as their JSON schema under a sentinel
+    marker key so deserialization can round-trip them back to a class.
+    ``BaseModel`` instances are dumped to JSON (round-trip as plain dicts,
+    which ``_create_initial_state`` accepts). Bare ``type`` values that are
+    not ``BaseModel`` subclasses (e.g. ``dict``) are dropped since they
+    can't be represented in JSON.
+
+    Args:
+        value: The ``initial_state`` value being serialized.
+
+    Returns:
+        A single-key ``{_INITIAL_STATE_CLASS_MARKER: schema}`` dict for a
+        ``BaseModel`` subclass, the ``mode="json"`` dump for a ``BaseModel``
+        instance, ``None`` for any other ``type``, and ``value`` itself
+        otherwise.
+    """
+    if isinstance(value, type):
+        if issubclass(value, BaseModel):
+            return {_INITIAL_STATE_CLASS_MARKER: value.model_json_schema()}
+        return None
+    if isinstance(value, BaseModel):
+        return value.model_dump(mode="json")
+    return value
+
+
+def _deserialize_initial_state(value: Any) -> Any:
+    """Rehydrate a class ref serialized by :func:`_serialize_initial_state`.
+
+    Only the exact shape the serializer emits is rehydrated: a dict whose sole
+    key is ``_INITIAL_STATE_CLASS_MARKER`` and whose value is a schema dict.
+    A state dict that merely contains the marker key next to other data is
+    user data, not a class ref, and is returned untouched.
+
+    Args:
+        value: Raw ``initial_state`` input, before field validation.
+
+    Returns:
+        The result of ``create_model_from_schema`` for a serialized class
+        ref; otherwise ``value`` unchanged.
+    """
+    if not isinstance(value, dict) or set(value) != {_INITIAL_STATE_CLASS_MARKER}:
+        return value
+    schema = value[_INITIAL_STATE_CLASS_MARKER]
+    if not isinstance(schema, dict):
+        return value
+
+    from crewai.utilities.pydantic_schema_utils import create_model_from_schema
+
+    return create_model_from_schema(schema)
+
+
 class FlowState(BaseModel):
     """Base model for all flow states, ensuring each state has a unique ID."""
 
@@ -908,7 +966,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
 
     entity_type: Literal["flow"] = "flow"
 
-    initial_state: Any = Field(default=None)
+    # JSON dumps go through ``_serialize_initial_state``; before validation,
+    # ``_deserialize_initial_state`` rehydrates serialized class refs.
+    initial_state: Annotated[  # type: ignore[type-arg]
+        type[BaseModel] | type[dict] | dict[str, Any] | BaseModel | None,
+        BeforeValidator(_deserialize_initial_state),
+        PlainSerializer(_serialize_initial_state, return_type=Any, when_used="json"),
+    ] = Field(default=None)
     name: str | None = Field(default=None)
     tracing: bool | None = Field(default=None)
     stream: bool = Field(default=False)
diff --git a/lib/crewai/tests/test_checkpoint.py b/lib/crewai/tests/test_checkpoint.py
index d92a24803c..525e3ca3bc 100644
--- a/lib/crewai/tests/test_checkpoint.py
+++ b/lib/crewai/tests/test_checkpoint.py
@@ -11,11 +11,12 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from pydantic import BaseModel
 
 from crewai.agent.core import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.crew import Crew
-from crewai.flow.flow import Flow, start
+from crewai.flow.flow import _INITIAL_STATE_CLASS_MARKER, Flow, start
 from crewai.state.checkpoint_config import CheckpointConfig
 from crewai.state.checkpoint_listener import (
     _find_checkpoint,
@@ -310,6 +311,70 @@ def test_fork_no_checkpoint_id_unique(self) -> None:
         assert state._branch != first
 
 
+class TestFlowInitialStateSerialization:
+    """Regression tests for checkpoint serialization of ``Flow.initial_state``."""
+
+    def test_class_ref_serializes_as_schema(self) -> None:
+        class MyState(BaseModel):
+            id: str = "x"
+            foo: str = "bar"
+
+        flow = Flow(initial_state=MyState)
+        state = RuntimeState(root=[flow])
+        dumped = json.loads(state.model_dump_json())
+        entity = dumped["entities"][0]
+        wrapped = entity["initial_state"]
+        assert isinstance(wrapped, dict)
+        assert _INITIAL_STATE_CLASS_MARKER in wrapped
+        assert wrapped[_INITIAL_STATE_CLASS_MARKER].get("title") == "MyState"
+
+    def test_class_ref_round_trips_to_basemodel_subclass(self) -> None:
+        class MyState(BaseModel):
+            id: str = "x"
+            foo: str = "bar"
+
+        flow = Flow(initial_state=MyState)
+        raw = RuntimeState(root=[flow]).model_dump_json()
+        restored = RuntimeState.model_validate_json(
+            raw, context={"from_checkpoint": True}
+        )
+        rehydrated = restored.root[0].initial_state
+        assert isinstance(rehydrated, type)
+        assert issubclass(rehydrated, BaseModel)
+        assert set(rehydrated.model_fields.keys()) == {"id", "foo"}
+
+    def test_instance_serializes_as_values(self) -> None:
+        class MyState(BaseModel):
+            id: str = "x"
+            foo: str = "bar"
+
+        flow = Flow(initial_state=MyState(foo="baz"))
+        state = RuntimeState(root=[flow])
+        dumped = json.loads(state.model_dump_json())
+        entity = dumped["entities"][0]
+        assert entity["initial_state"] == {"id": "x", "foo": "baz"}
+
+    def test_dict_passthrough(self) -> None:
+        flow = Flow(initial_state={"id": "x", "foo": "bar"})
+        state = RuntimeState(root=[flow])
+        dumped = json.loads(state.model_dump_json())
+        entity = dumped["entities"][0]
+        assert entity["initial_state"] == {"id": "x", "foo": "bar"}
+
+    def test_dict_round_trips_as_dict(self) -> None:
+        flow = Flow(initial_state={"id": "x", "foo": "bar"})
+        raw = RuntimeState(root=[flow]).model_dump_json()
+        restored = RuntimeState.model_validate_json(
+            raw, context={"from_checkpoint": True}
+        )
+        assert restored.root[0].initial_state == {"id": "x", "foo": "bar"}
+
+    def test_dict_with_marker_and_extra_keys_stays_dict(self) -> None:
+        payload = {_INITIAL_STATE_CLASS_MARKER: {"title": "X"}, "id": "x"}
+        flow = Flow(initial_state=payload)
+        assert flow.initial_state == payload
+
+
 # ---------- JsonProvider forking ----------
 
 