Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
358 changes: 187 additions & 171 deletions lib/crewai-files/src/crewai_files/cache/upload_cache.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions lib/crewai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ file-processing = [
qdrant-edge = [
"qdrant-edge-py>=0.6.0",
]
valkey = [
"valkey-glide>=1.3.0",
]


[project.scripts]
Expand Down
22 changes: 20 additions & 2 deletions lib/crewai/src/crewai/a2a/utils/agent_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
from typing import TYPE_CHECKING

from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
)
from aiocache import cached, caches # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx

Expand All @@ -32,6 +36,7 @@
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
)
from crewai.utilities.cache_config import get_aiocache_config


if TYPE_CHECKING:
Expand All @@ -40,6 +45,18 @@
from crewai.task import Task


# Tracks whether the global aiocache configuration has been applied.
_cache_configured = False


def _ensure_cache_configured() -> None:
    """Apply the aiocache configuration lazily, at most once per process.

    The first caller installs the config returned by ``get_aiocache_config``
    via ``caches.set_config``; every later call is a cheap no-op.
    """
    global _cache_configured
    if not _cache_configured:
        caches.set_config(get_aiocache_config())
        _cache_configured = True


def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.

Expand Down Expand Up @@ -191,6 +208,7 @@ async def afetch_agent_card(
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
_ensure_cache_configured()
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
Expand Down
127 changes: 74 additions & 53 deletions lib/crewai/src/crewai/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from functools import wraps
import json
import logging
import os
import threading
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse

from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
Expand All @@ -38,7 +37,6 @@
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from typing_extensions import TypedDict

from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
Expand All @@ -50,12 +48,18 @@
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.cache_config import (
get_aiocache_config,
parse_cache_url,
use_valkey_cache,
)
from crewai.utilities.pydantic_schema_utils import create_model_from_schema


if TYPE_CHECKING:
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
from crewai.agent import Agent
from crewai.memory.storage.valkey_cache import ValkeyCache


logger = logging.getLogger(__name__)
Expand All @@ -64,52 +68,40 @@
T = TypeVar("T")


class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""

cache: str
endpoint: str
port: int
db: int
password: str
# ---------------------------------------------------------------------------
# Lazy cache initialisation
# ---------------------------------------------------------------------------

# Shared ValkeyCache instance when the Valkey backend is selected; None means
# the aiocache "default" cache is used instead.
_task_cache: ValkeyCache | None = None
_cache_initialized = False
_cache_init_lock = threading.Lock()


def _ensure_task_cache() -> None:
    """Initialise the task cache on first use (thread-safe).

    Uses double-checked locking so concurrent first callers configure the
    backend exactly once. When ``use_valkey_cache()`` is true, a shared
    ``ValkeyCache`` is built from the parsed cache URL (falling back to
    localhost defaults); otherwise the aiocache global config is installed.
    """
    global _task_cache, _cache_initialized
    if _cache_initialized:  # fast path: already configured
        return

    with _cache_init_lock:
        if _cache_initialized:  # re-check under the lock (double-checked locking)
            return

        if use_valkey_cache():
            # Deferred import: the valkey backend is an optional dependency.
            from crewai.memory.storage.valkey_cache import ValkeyCache

            conn = parse_cache_url() or {}
            _task_cache = ValkeyCache(
                host=conn.get("host", "localhost"),
                port=conn.get("port", 6379),
                db=conn.get("db", 0),
                password=conn.get("password"),
                default_ttl=3600,
            )
        else:
            caches.set_config(get_aiocache_config())

        _cache_initialized = True


def cancellable(
Expand All @@ -130,6 +122,8 @@ def cancellable(
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
_ensure_task_cache()

context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
Expand All @@ -142,19 +136,34 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return await fn(*args, **kwargs)

task_id = context.task_id
cache = caches.get("default")

async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
async def poll_for_cancel_valkey() -> bool:
"""Poll ValkeyCache for cancellation flag."""
while True:
if _task_cache is not None and await _task_cache.get(
f"cancel:{task_id}"
):
return True
await asyncio.sleep(0.1)

async def poll_for_cancel_aiocache() -> bool:
"""Poll aiocache for cancellation flag."""
cache = caches.get("default")
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)

async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if _task_cache is not None:
# ValkeyCache: use polling (pub/sub not implemented yet)
return await poll_for_cancel_valkey()

# aiocache: use pub/sub if Redis, otherwise poll
cache = caches.get("default")
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
return await poll_for_cancel_aiocache()

try:
client = cache.client
Expand All @@ -168,7 +177,7 @@ async def watch_for_cancel() -> bool:
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return await poll_for_cancel_aiocache()
return False

execute_task = asyncio.create_task(fn(*args, **kwargs))
Expand All @@ -190,7 +199,12 @@ async def watch_for_cancel() -> bool:
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
# Clean up cancellation flag
if _task_cache is not None:
await _task_cache.delete(f"cancel:{task_id}")
else:
cache = caches.get("default")
await cache.delete(f"cancel:{task_id}")

return wrapper

Expand Down Expand Up @@ -475,18 +489,25 @@ async def cancel(
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))

_ensure_task_cache()

if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
):
return context.current_task

cache = caches.get("default")

await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
if _task_cache is not None:
# Use ValkeyCache
await _task_cache.set(f"cancel:{task_id}", True, ttl=3600)
# Note: pub/sub not implemented for ValkeyCache yet, relies on polling
else:
# Use aiocache
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")

await event_queue.enqueue_event(
TaskStatusUpdateEvent(
Expand Down
27 changes: 26 additions & 1 deletion lib/crewai/src/crewai/memory/encoding_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from crewai.flow.flow import Flow, listen, start
from crewai.memory.analyze import (
Expand Down Expand Up @@ -68,6 +68,31 @@ class ItemState(BaseModel):
plan: ConsolidationPlan | None = None
result_record: MemoryRecord | None = None

@field_validator("similar_records", "result_record", mode="before")
@classmethod
def ensure_embedding_is_list(cls, v: Any) -> Any:
"""Ensure MemoryRecord embeddings are list[float], not bytes."""
if v is None:
return None
if isinstance(v, list):
# Process list of MemoryRecords
for record in v:
if isinstance(record, MemoryRecord) and isinstance(
record.embedding, bytes
):
import numpy as np

arr = np.frombuffer(record.embedding, dtype=np.float32)
record.embedding = [float(x) for x in arr]
return v
if isinstance(v, MemoryRecord) and isinstance(v.embedding, bytes):
# Process single MemoryRecord
import numpy as np

arr = np.frombuffer(v.embedding, dtype=np.float32)
v.embedding = [float(x) for x in arr]
return v


class EncodingState(BaseModel):
"""Batch-level state for the encoding flow."""
Expand Down
Loading