Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions code-interpreter/app/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from app.app_configs import get_settings
from app.models.schemas import (
CreateSessionRequest,
CreateSessionResponse,
ExecuteFile,
ExecuteRequest,
ExecuteResponse,
FileMetadataResponse,
Expand All @@ -19,7 +22,7 @@
WorkspaceFile,
)
from app.services.executor_base import EntryKind, StreamChunk, StreamResult, WorkspaceEntry
from app.services.executor_factory import execute_python, execute_python_streaming
from app.services.executor_factory import execute_python, execute_python_streaming, get_executor
from app.services.file_storage import FileStorageService

# Router object that the endpoint handlers in this module attach themselves to
# through the ``@router.<method>(...)`` decorators.
router = APIRouter()
Expand All @@ -46,8 +49,8 @@ def _validate_timeout(req: ExecuteRequest) -> None:
)


def _stage_request_files(
req: ExecuteRequest,
def _resolve_uploaded_files(
files: list[ExecuteFile],
storage: FileStorageService,
) -> tuple[list[tuple[str, bytes]], dict[str, bytes]]:
"""Resolve uploaded file IDs into content for the executor.
Expand All @@ -56,7 +59,7 @@ def _stage_request_files(
"""
staged_files: list[tuple[str, bytes]] = []
input_files_map: dict[str, bytes] = {}
for file in req.files:
for file in files:
try:
content, _ = storage.get_file(file.file_id)
except FileNotFoundError as exc:
Expand All @@ -69,6 +72,17 @@ def _stage_request_files(
return staged_files, input_files_map


def _stage_request_files(
    req: ExecuteRequest,
    storage: FileStorageService,
) -> tuple[list[tuple[str, bytes]], dict[str, bytes]]:
    """Stage the files referenced by an execute request.

    Thin adapter kept for callers that hold a whole ``ExecuteRequest``: it
    forwards ``req.files`` and ``storage`` to ``_resolve_uploaded_files``.

    Returns:
        The ``(staged_files, input_files_map)`` pair produced by
        ``_resolve_uploaded_files``.
    """
    requested_files = req.files
    return _resolve_uploaded_files(requested_files, storage)


def _save_workspace_files(
entries: tuple[WorkspaceEntry, ...],
input_files_map: dict[str, bytes],
Expand Down Expand Up @@ -248,3 +262,62 @@ def delete_file(file_id: str) -> Response:
)

return Response(status_code=status.HTTP_204_NO_CONTENT)


@router.post(
    "/sessions",
    response_model=CreateSessionResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_session(req: CreateSessionRequest) -> CreateSessionResponse:
    """Start a long-lived code-executor session and report its ID and expiry.

    The files named in the request are resolved from file storage and handed
    to the configured executor together with the requested TTL and the CPU
    and memory limits taken from the application settings. The executor's
    documented contract is that the session is torn down at or before its
    TTL expires, even if the API service crashes and restarts.

    Raises:
        HTTPException: 501 when the executor raises ``NotImplementedError``,
            422 when it raises ``ValueError``. Anything raised while
            resolving the uploaded files propagates unchanged.
    """
    settings = get_settings()
    file_storage = get_file_storage()
    # Only the staged file tuples are needed; the input-file map that comes
    # back alongside them is discarded.
    workspace_files, _ = _resolve_uploaded_files(req.files, file_storage)

    try:
        executor = get_executor()
        session = executor.create_session(
            ttl_seconds=req.ttl_seconds,
            files=workspace_files,
            cpu_time_limit_sec=settings.cpu_time_limit_sec,
            memory_limit_mb=settings.memory_limit_mb,
        )
    except (NotImplementedError, ValueError) as exc:
        # NotImplementedError is checked first, mirroring the precedence of
        # two separate except clauses.
        if isinstance(exc, NotImplementedError):
            failure_status = status.HTTP_501_NOT_IMPLEMENTED
        else:
            failure_status = status.HTTP_422_UNPROCESSABLE_ENTITY
        raise HTTPException(status_code=failure_status, detail=str(exc)) from exc

    return CreateSessionResponse(
        session_id=session.session_id,
        expires_at=session.expires_at,
    )


@router.delete("/sessions/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_session(session_id: str) -> Response:
    """Tear down the session identified by ``session_id``.

    Returns:
        An empty 204 response when the executor reports the session as
        deleted.

    Raises:
        HTTPException: 501 when the executor raises ``NotImplementedError``,
            404 when the executor reports that nothing was deleted.
    """
    try:
        was_deleted = get_executor().delete_session(session_id)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail=str(exc),
        ) from exc

    if was_deleted:
        return Response(status_code=status.HTTP_204_NO_CONTENT)

    # A falsy result from the executor means no such session was removed.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Session '{session_id}' not found",
    )
34 changes: 31 additions & 3 deletions code-interpreter/app/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import asyncio
import logging
import subprocess
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from shutil import which
from typing import Final

Expand All @@ -14,6 +15,8 @@
from app.models.schemas import HealthResponse
from app.services.executor_factory import get_executor

# Number of seconds the background session reaper sleeps between reap passes.
SESSION_REAPER_INTERVAL_SEC = 30

# Configure logging
logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -78,6 +81,24 @@ def _ensure_docker_image_available() -> None:
) from e


async def _reap_expired_sessions_once() -> None:
    """Perform one reap pass through the configured executor.

    The executor's ``reap_expired_sessions`` is run in a worker thread via
    ``asyncio.to_thread`` so that it does not block the event loop. Any
    ``Exception`` raised while obtaining the executor or reaping is logged
    as a warning (with traceback) and swallowed, so a failing pass never
    propagates to the caller. A pass that reaps at least one session is
    logged at info level.
    """
    try:
        executor = get_executor()
        reaped = await asyncio.to_thread(executor.reap_expired_sessions)
    except Exception:
        logger.warning("Session reaper pass failed", exc_info=True)
    else:
        if reaped > 0:
            logger.info("Reaped %d expired session(s)", reaped)


async def _session_reaper_loop() -> None:
    """Run reap passes forever, pausing between consecutive passes.

    Every iteration first sleeps for ``SESSION_REAPER_INTERVAL_SEC`` seconds
    and then performs a single reap pass, so the first pass of the loop only
    happens after one full interval. The loop has no exit condition of its
    own; it stops when the task running it is cancelled or when a pass lets
    an exception escape.
    """
    while True:
        # Sleep before reaping rather than after.
        pause = asyncio.sleep(SESSION_REAPER_INTERVAL_SEC)
        await pause
        await _reap_expired_sessions_once()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Manage application lifespan events."""
Expand All @@ -87,9 +108,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
_ensure_docker_image_available()
logger.info("Docker executor image is ready")

yield
# Reap any sessions whose TTL elapsed while the service was down.
await _reap_expired_sessions_once()
reaper_task = asyncio.create_task(_session_reaper_loop())

# Shutdown: Add any cleanup logic here if needed in the future
try:
yield
finally:
reaper_task.cancel()
with suppress(asyncio.CancelledError):
await reaper_task


def create_app() -> FastAPI:
Expand Down
27 changes: 27 additions & 0 deletions code-interpreter/app/models/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,30 @@ class ListFilesResponse(BaseModel):
# Response body reporting service health.
class HealthResponse(BaseModel):
    # Health state; restricted to the two literal values "ok" and "error".
    status: Literal["ok", "error"]
    # Optional free-text detail; None when no message is supplied.
    message: StrictStr | None = None


# Session lifetime used when a request does not specify one: 15 minutes, in seconds.
DEFAULT_SESSION_TTL_SEC = 15 * 60
# Upper bound on the session lifetime a request may ask for: 24 hours, in seconds.
MAX_SESSION_TTL_SEC = 24 * 60 * 60


# Request body for creating a session.
class CreateSessionRequest(BaseModel):
    # Files to place in the session workspace; defaults to a fresh empty list
    # for each instance (default_factory), so no list is shared between requests.
    files: list[ExecuteFile] = Field(
        default_factory=list,
        description="Files to stage in the session workspace at create time.",
    )
    # Requested lifetime in seconds. Must be a strict integer in the range
    # 1..MAX_SESSION_TTL_SEC; DEFAULT_SESSION_TTL_SEC is used when omitted.
    ttl_seconds: StrictInt = Field(
        DEFAULT_SESSION_TTL_SEC,
        ge=1,
        le=MAX_SESSION_TTL_SEC,
        description=(
            "Session lifetime in seconds. The session pod is automatically "
            "destroyed after this duration even if the API service crashes."
        ),
    )


# Response body returned after a session has been created.
class CreateSessionResponse(BaseModel):
    # Required: identifier of the created session.
    session_id: StrictStr = Field(..., description="Identifier for the session pod/container.")
    # Required: scheduled expiry, expressed as a Unix timestamp (float).
    expires_at: float = Field(
        ..., description="Unix timestamp when the session is scheduled to expire."
    )
38 changes: 38 additions & 0 deletions code-interpreter/app/services/executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ class HealthCheck:
message: str | None = None


@dataclass(frozen=True, slots=True)
class SessionInfo:
    """Identifying information for a long-lived session.

    Instances are immutable (``frozen=True``) and use ``__slots__`` instead
    of a per-instance ``__dict__`` (``slots=True``).
    """

    # Identifier of the session.
    session_id: str
    # Time at which the session expires.
    # NOTE(review): presumably a Unix timestamp in seconds — confirm against the
    # executor implementations that construct this value.
    expires_at: float


# Shared string constants for session resources.
# NOTE(review): their consumers are not visible here. Presumably the executor
# implementations use these as a resource-name prefix, label values, and the
# key under which a session's expiry is recorded — confirm against the callers.
SESSION_NAME_PREFIX = "code-session-"
SESSION_APP_LABEL = "code-interpreter"
SESSION_COMPONENT_LABEL = "session"
SESSION_EXPIRES_AT_KEY = "code-interpreter.expires-at"


class ExecutorProtocol(Protocol):
def execute_python(
self,
Expand Down Expand Up @@ -168,6 +182,30 @@ def execute_python_streaming(
"""
raise NotImplementedError(f"{type(self).__name__} does not support streaming execution")

def create_session(
    self,
    *,
    ttl_seconds: int,
    files: Sequence[tuple[str, bytes]] | None = None,
    cpu_time_limit_sec: int | None = None,
    memory_limit_mb: int | None = None,
) -> SessionInfo:
    """Start a long-lived execution environment.

    Contract for implementations: return identifying information for the
    new session, and ensure the session is torn down at or before
    ``expires_at`` even if this process crashes.

    This default implementation provides no session support; all arguments
    are ignored.

    Raises:
        NotImplementedError: Always, naming the concrete executor class.
    """
    unsupported = f"{type(self).__name__} does not support sessions"
    raise NotImplementedError(unsupported)

def delete_session(self, session_id: str) -> bool:
    """Tear down the session with the given ID.

    Contract for implementations: return True if the session was found and
    deleted.

    This default implementation provides no session support; ``session_id``
    is ignored.

    Raises:
        NotImplementedError: Always, naming the concrete executor class.
    """
    unsupported = f"{type(self).__name__} does not support sessions"
    raise NotImplementedError(unsupported)

def reap_expired_sessions(self) -> int:
    """Delete sessions whose TTL has elapsed. Returns number reaped.

    This default implementation deletes nothing and always returns 0. Unlike
    the other session methods it does not raise ``NotImplementedError``, so
    calling it on an executor without an override is a harmless no-op.
    """
    return 0

@staticmethod
def truncate_output(stream: bytes, max_bytes: int) -> str:
if len(stream) <= max_bytes:
Expand Down
Loading
Loading