diff --git a/.gitignore b/.gitignore index bea920f..4cb7485 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,6 @@ __pycache__ .env config.debug.yaml -data/ \ No newline at end of file +data/ + +tests/ \ No newline at end of file diff --git a/README.md b/README.md index f5ee964..8d084c8 100644 --- a/README.md +++ b/README.md @@ -248,6 +248,35 @@ gemini: - model_name: "gemini-3.0-pro" model_header: x-goog-ext-525001261-jspb: '[1,null,null,null,"9d8ca3786ebdfbea",null,null,0,[4],null,null,1]' + gems: + # Disabled by default to avoid accidental creation/update/deletion of gems. + enabled: false + # Policy mode: + # - off: disabled + # - fetch_only: load existing server-managed gems only + # - create_on_demand: create missing managed gems when needed (rate-limited) + # - privacy: reserved for ephemeral request-time flow (startup sync skipped) + policy: "off" + create_rate_limit_per_minute: 4 + managed_gems_max_total: 200 + cleanup: + enabled: false + unused_days: 7 + touch_interval_minutes: 60 + dry_run: false + max_deletes_per_run: 5 + require_managed_marker: true + fetch_on_init: true + include_hidden_on_fetch: false + policies: + enabled: false + prefix: "fastapi_policy_" + default_policy: + enabled: false + key: "general_capability_guardrail" + # If `prompt` is null (or omitted), the implementation's built-in + # base system prompt will be used instead. + prompt: null ``` #### Environment Variables diff --git a/app/server/chat.py b/app/server/chat.py index dad65f5..c6ea207 100644 --- a/app/server/chat.py +++ b/app/server/chat.py @@ -661,6 +661,28 @@ def _prepare_messages_for_model( return prepared +def _extract_leading_system_prompt(messages: list[Message]) -> tuple[str | None, list[Message]]: + """Extract and remove leading system messages, returning joined system text. + + Only leading system messages are extracted to preserve regular conversation flow. + """ + if not messages: + return None, messages + + idx = 0 + system_parts: list[str] = [] + while idx < len(messages) and messages[idx].role == "system": + text = text_from_message(messages[idx]).strip() + if text: + system_parts.append(text) + idx += 1 + + if not system_parts: + return None, messages + + return "\n\n".join(system_parts), messages[idx:] + + def _response_items_to_messages( items: str | list[ResponseInputItem], ) -> tuple[list[Message], str | list[ResponseInputItem]]: @@ -1773,18 +1795,70 @@ async def create_chat_completion( structured_requirement = _build_structured_requirement(request.response_format) extra_instr = [structured_requirement.instruction] if structured_requirement else None - # This ensures that server-injected system instructions are part of the history - msgs = _prepare_messages_for_model( + # Split leading user-provided system prompt so we can attach it as a managed gem + # when create_on_demand is enabled. + system_prompt_text, non_system_messages = _extract_leading_system_prompt(request.messages) + system_only_request = bool(system_prompt_text) and not non_system_messages + + if not system_prompt_text: + non_system_messages = request.messages + + # Prepared messages with system prompt removed (candidate gem path). + msgs_without_system = _prepare_messages_for_model( + [] if system_only_request else non_system_messages, + request.tools, + request.tool_choice, + extra_instr, + ) + + # Prepared messages with full system prompt retained (fallback path). + msgs_with_system = _prepare_messages_for_model( request.messages, request.tools, request.tool_choice, extra_instr, ) + # Prefer searching reusable sessions against system-stripped history because + # gem-based sessions persist that history shape. + msgs = msgs_without_system if (system_prompt_text and not system_only_request) else msgs_with_system + session, client, remain = await _find_reusable_session(db, pool, model, msgs) reused_session = session is not None use_google_temporary_mode = g_config.gemini.chat_mode == ChatMode.TEMPORARY + # Fallback search for legacy sessions that still contain explicit system messages. + if ( + session is None + and system_prompt_text + and not system_only_request + and msgs_with_system != msgs_without_system + ): + session, client, remain = await _find_reusable_session(db, pool, model, msgs_with_system) + if session is not None: + msgs = msgs_with_system + + managed_system_gem_id: str | None = None + if system_prompt_text and not system_only_request: + target_client = client + if target_client is None: + target_client = await pool.acquire() + client = target_client + + managed_system_gem_id = await target_client.system_prompt_gem_id_or_create(system_prompt_text) + if managed_system_gem_id: + # When gem is available, keep system text out of the prompt payload. + msgs = msgs_without_system + if session is not None: + session.gem = managed_system_gem_id + else: + # Fall back to explicit system-text path. + msgs = msgs_with_system + + # If we changed message mode after initial reuse lookup, re-check reuse quickly. + if session is None and msgs in (msgs_without_system, msgs_with_system): + session, client, remain = await _find_reusable_session(db, pool, model, msgs) + if session: if not remain: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No new messages.") @@ -1810,8 +1884,9 @@ async def create_chat_completion( ) else: try: - client = await pool.acquire() - session = client.start_chat(model=model) + if client is None: + client = await pool.acquire() + session = client.start_chat(model=model, gem=managed_system_gem_id) # Use the already prepared 'msgs' for a fresh session m_input, files = await _process_conversation_with_compaction( msgs, @@ -1972,12 +2047,31 @@ async def create_response( request.tool_choice if isinstance(request.tool_choice, (str, ToolChoiceFunction)) else None ) - messages = _prepare_messages_for_model( + # Split leading system/instruction content so it can be mapped to a managed + # gem when create_on_demand is enabled. + system_prompt_text, conv_without_system = _extract_leading_system_prompt(conv_messages) + system_only_conversation = bool(system_prompt_text) and not conv_without_system + if not system_prompt_text: + conv_without_system = conv_messages + + messages_without_system = _prepare_messages_for_model( + [] if system_only_conversation else conv_without_system, + standard_tools or None, + model_tool_choice, + extra_instr or None, + ) + messages_with_system = _prepare_messages_for_model( conv_messages, standard_tools or None, model_tool_choice, extra_instr or None, ) + messages = ( + messages_without_system + if (system_prompt_text and not system_only_conversation) + else messages_with_system + ) + pool, db = GeminiClientPool(), LMDBConversationStore() try: model = _get_model_by_name(request.model) @@ -1987,6 +2081,36 @@ async def create_response( session, client, remain = await _find_reusable_session(db, pool, model, messages) reused_session = session is not None use_google_temporary_mode = g_config.gemini.chat_mode == ChatMode.TEMPORARY + + # Fallback reuse search for legacy sessions that still included explicit system text. + if ( + session is None + and system_prompt_text + and not system_only_conversation + and messages_with_system != messages_without_system + ): + session, client, remain = await _find_reusable_session(db, pool, model, messages_with_system) + if session is not None: + messages = messages_with_system + + managed_system_gem_id: str | None = None + if system_prompt_text and not system_only_conversation: + target_client = client + if target_client is None: + target_client = await pool.acquire() + client = target_client + + managed_system_gem_id = await target_client.system_prompt_gem_id_or_create(system_prompt_text) + if managed_system_gem_id: + messages = messages_without_system + if session is not None: + session.gem = managed_system_gem_id + else: + messages = messages_with_system + + # If message shape changed after gem resolution, search reusable session again. + if session is None and messages in (messages_without_system, messages_with_system): + session, client, remain = await _find_reusable_session(db, pool, model, messages) if session: msgs = _prepare_messages_for_model( remain, @@ -2008,8 +2132,9 @@ async def create_response( ) else: try: - client = await pool.acquire() - session = client.start_chat(model=model) + if client is None: + client = await pool.acquire() + session = client.start_chat(model=model, gem=managed_system_gem_id) m_input, files = await _process_conversation_with_compaction( messages, tmp_dir, diff --git a/app/services/client.py b/app/services/client.py index b8f976b..fa746b8 100644 --- a/app/services/client.py +++ b/app/services/client.py @@ -1,8 +1,14 @@ +import asyncio +import hashlib +import time +from collections import deque +from dataclasses import dataclass from pathlib import Path from typing import Any, cast import orjson from gemini_webapi import GeminiClient, ModelOutput +from gemini_webapi.types import Gem from loguru import logger from app.models import Message @@ -14,6 +20,8 @@ save_url_to_tempfile, ) +from .policy_gems import PolicySyncResult, sync_policy_gems, touch_managed_description + _UNSET = object() @@ -24,9 +32,199 @@ def _resolve(value: Any, fallback: Any): class GeminiClientWrapper(GeminiClient): """Gemini client with helper methods.""" + @dataclass + class _ManagedGemRetry: + op: str + gem_id: str + attempt: int + next_retry_at: float + def __init__(self, client_id: str, **kwargs): super().__init__(**kwargs) self.id = client_id + self._gem_lock = asyncio.Lock() + self._policy_gem_ids: dict[str, str] = {} + self._system_prompt_gem_ids: dict[str, str] = {} + self._managed_gem_create_timestamps: deque[float] = deque() + self._managed_gem_last_touch_timestamps: dict[str, float] = {} + self._managed_gem_pending_touch_ids: set[str] = set() + self._managed_gem_touch_worker_task: asyncio.Task[None] | None = None + self._managed_gem_retry_queue: list[GeminiClientWrapper._ManagedGemRetry] = [] + self._managed_gem_metrics: dict[str, int] = { + "managed_gems_created": 0, + "managed_gems_updated": 0, + "managed_gems_deleted": 0, + "managed_gems_delete_dry_run": 0, + "managed_gems_skipped_missing_marker": 0, + "managed_gems_skipped_cap": 0, + "managed_gems_rate_limit_skips": 0, + "managed_gems_retry_enqueued": 0, + "managed_gems_retry_success": 0, + "managed_gems_retry_failed": 0, + "managed_gems_touch_updated": 0, + } + + def _acquire_managed_gem_create_budget(self, per_minute: int) -> int: + """Return remaining create budget in the current 60-second window.""" + now = time.monotonic() + window_start = now - 60.0 + while self._managed_gem_create_timestamps and self._managed_gem_create_timestamps[0] < window_start: + self._managed_gem_create_timestamps.popleft() + + used = len(self._managed_gem_create_timestamps) + return max(0, per_minute - used) + + def _consume_managed_gem_create_budget(self, count: int) -> None: + """Record managed gem create usage for rate limiting.""" + if count <= 0: + return + now = time.monotonic() + for _ in range(count): + self._managed_gem_create_timestamps.append(now) + + def _enqueue_retry(self, op: str, gem_id: str, attempt: int = 1) -> None: + delay_sec = min(300.0, float(2**attempt)) + self._managed_gem_retry_queue.append( + self._ManagedGemRetry( + op=op, + gem_id=gem_id, + attempt=attempt, + next_retry_at=time.time() + delay_sec, + ) + ) + self._managed_gem_metrics["managed_gems_retry_enqueued"] += 1 + + async def _process_managed_retry_queue(self) -> None: + """Process due retry operations (delete/touch) with backoff.""" + if not self._managed_gem_retry_queue: + return + + now = time.time() + due = [op for op in self._managed_gem_retry_queue if op.next_retry_at <= now] + self._managed_gem_retry_queue = [op for op in self._managed_gem_retry_queue if op.next_retry_at > now] + if not due: + return + + async with self._gem_lock: + gems = list(await self.fetch_gems(include_hidden=True)) + by_id = {gem.id: gem for gem in gems} + for retry in due: + try: + target = by_id.get(retry.gem_id) + if retry.op == "delete": + if target is not None and not target.predefined and target.name.startswith(g_config.gemini.gems.policies.prefix): + await self.delete_gem(target) + self._managed_gem_metrics["managed_gems_retry_success"] += 1 + elif retry.op == "touch": + if target is None: + self._managed_gem_metrics["managed_gems_retry_success"] += 1 + continue + if target.predefined or not target.name.startswith(g_config.gemini.gems.policies.prefix) or target.prompt is None: + self._managed_gem_metrics["managed_gems_retry_success"] += 1 + continue + + updated_description = touch_managed_description(target.description, now_ts=time.time()) + await self.update_gem( + gem=target, + name=target.name, + description=updated_description, + prompt=target.prompt, + ) + self._managed_gem_last_touch_timestamps[target.id] = time.time() + self._managed_gem_metrics["managed_gems_touch_updated"] += 1 + self._managed_gem_metrics["managed_gems_retry_success"] += 1 + except Exception: + self._managed_gem_metrics["managed_gems_retry_failed"] += 1 + self._enqueue_retry(retry.op, retry.gem_id, retry.attempt + 1) + + def _schedule_managed_policy_touch(self, gem_id: str) -> None: + """Queue gem usage touch updates and ensure a background worker exists.""" + gem_cfg = g_config.gemini.gems + if not gem_cfg.cleanup.enabled: + return + + self._managed_gem_pending_touch_ids.add(gem_id) + if self._managed_gem_touch_worker_task is None or self._managed_gem_touch_worker_task.done(): + self._managed_gem_touch_worker_task = asyncio.create_task( + self._managed_policy_touch_worker() + ) + + async def _managed_policy_touch_worker(self) -> None: + """Batch and flush pending managed gem touches in the background.""" + try: + while self._managed_gem_pending_touch_ids: + await asyncio.sleep(0.5) + await self._flush_managed_policy_touches() + finally: + self._managed_gem_touch_worker_task = None + + async def _flush_managed_policy_touches(self) -> None: + """Flush queued managed gem touch updates in a single fetch/update pass.""" + gem_cfg = g_config.gemini.gems + if not gem_cfg.cleanup.enabled: + self._managed_gem_pending_touch_ids.clear() + return + + pending = list(self._managed_gem_pending_touch_ids) + if not pending: + return + + self._managed_gem_pending_touch_ids.clear() + + now_ts = time.time() + min_interval_sec = gem_cfg.cleanup.touch_interval_minutes * 60 + + async with self._gem_lock: + gems = list(await self.fetch_gems(include_hidden=True)) + by_id = {gem.id: gem for gem in gems} + + for gem_id in pending: + last_touch = self._managed_gem_last_touch_timestamps.get(gem_id) + if last_touch is not None and now_ts - last_touch < min_interval_sec: + continue + + target = by_id.get(gem_id) + if target is None: + continue + if target.predefined: + continue + if not target.name.startswith(gem_cfg.policies.prefix): + continue + if target.prompt is None: + continue + + updated_description = touch_managed_description(target.description, now_ts=now_ts) + if (target.description or "") == updated_description: + self._managed_gem_last_touch_timestamps[gem_id] = now_ts + continue + + try: + await self.update_gem( + gem=target, + name=target.name, + description=updated_description, + prompt=target.prompt, + ) + self._managed_gem_last_touch_timestamps[gem_id] = now_ts + self._managed_gem_metrics["managed_gems_touch_updated"] += 1 + except Exception: + self._enqueue_retry("touch", gem_id) + + def _apply_policy_sync_result(self, sync_result: PolicySyncResult) -> None: + """Apply sync result into cache, metrics, and retry queue.""" + self._policy_gem_ids = sync_result.gem_ids + self._managed_gem_metrics["managed_gems_created"] += sync_result.created_count + self._managed_gem_metrics["managed_gems_updated"] += sync_result.updated_count + self._managed_gem_metrics["managed_gems_deleted"] += sync_result.deleted_count + self._managed_gem_metrics["managed_gems_delete_dry_run"] += sync_result.dry_run_delete_count + self._managed_gem_metrics["managed_gems_skipped_missing_marker"] += ( + sync_result.skipped_missing_marker_count + ) + self._managed_gem_metrics["managed_gems_skipped_cap"] += ( + sync_result.skipped_due_to_cap_count + ) + for failed_id in sync_result.failed_delete_ids: + self._enqueue_retry("delete", failed_id) async def init( self, @@ -59,6 +257,9 @@ async def init( refresh_interval=refresh_interval, verbose=verbose, ) + + # Keep gem cache and server-managed policy gems in a known-good state. + await self._initialize_gems() except Exception: logger.exception(f"Failed to initialize GeminiClient {self.id}") raise @@ -66,6 +267,313 @@ async def init( def running(self) -> bool: return self._running + async def _initialize_gems(self) -> None: + """Initialize gem cache and built-in policy gems based on server config.""" + gem_cfg = g_config.gemini.gems + if not gem_cfg.enabled: + return + + async with self._gem_lock: + include_hidden = gem_cfg.include_hidden_on_fetch + + if gem_cfg.fetch_on_init: + await self.fetch_gems(include_hidden=include_hidden) + + policy_mode = gem_cfg.policy + if policy_mode == "off": + return + + if policy_mode == "privacy": + logger.warning( + "gemini.gems.policy='privacy' is intended for request-time ephemeral flow; " + "startup policy sync is skipped" + ) + return + + # Force include_hidden=True during managed-policy sync so hidden + # server-managed gems are discovered. + if policy_mode in ("fetch_only", "create_on_demand"): + default_prompt = None + policy_dp = getattr(gem_cfg.policies, "default_policy", None) + if policy_dp and getattr(policy_dp, "enabled", False): + default_prompt = getattr(policy_dp, "prompt", None) + + create_budget = None + if policy_mode == "create_on_demand": + create_budget = self._acquire_managed_gem_create_budget( + gem_cfg.create_rate_limit_per_minute + ) + if create_budget <= 0: + self._managed_gem_metrics["managed_gems_rate_limit_skips"] += 1 + + cleanup_days = None + if gem_cfg.cleanup.enabled: + cleanup_days = gem_cfg.cleanup.unused_days + + sync_result: PolicySyncResult = await sync_policy_gems( + self, + prefix=gem_cfg.policies.prefix, + include_hidden=True, + default_prompt=default_prompt, + mode=policy_mode, + create_budget=create_budget, + cleanup_unused_days=cleanup_days, + cleanup_dry_run=gem_cfg.cleanup.dry_run, + cleanup_max_deletes_per_run=gem_cfg.cleanup.max_deletes_per_run, + cleanup_require_managed_marker=gem_cfg.cleanup.require_managed_marker, + managed_max_total=gem_cfg.managed_gems_max_total, + ) + self._apply_policy_sync_result(sync_result) + + if policy_mode == "create_on_demand": + self._consume_managed_gem_create_budget(sync_result.created_count) + + logger.info( + "Managed gem sync stats client='{}': created={}, updated={}, deleted={}, " + "dry_run_deletes={}, retries_queued={}, managed_total={}", + self.id, + sync_result.created_count, + sync_result.updated_count, + sync_result.deleted_count, + sync_result.dry_run_delete_count, + len(sync_result.failed_delete_ids), + sync_result.managed_total_count, + ) + + # Refresh once more so callers can immediately read the final state. + await self.fetch_gems(include_hidden=include_hidden) + + await self._process_managed_retry_queue() + + def policy_gem_id(self, key: str) -> str | None: + """Return a synced policy gem id for a logical key, or None when unavailable.""" + gem_id = self._policy_gem_ids.get(key) + if gem_id: + try: + asyncio.get_running_loop() + self._schedule_managed_policy_touch(gem_id) + except RuntimeError: + # No running loop in this context. + pass + return gem_id + + async def policy_gem_id_or_create(self, key: str) -> str | None: + """Return policy gem id, creating a missing managed gem on-demand when allowed. + + On-demand creation is only attempted when: + - gem management is enabled, + - `gemini.gems.policy` is `create_on_demand`, and + - create-rate/cap limits allow creating more managed gems. + """ + existing = self.policy_gem_id(key) + if existing is not None: + return existing + + gem_cfg = g_config.gemini.gems + if not gem_cfg.enabled: + return None + if gem_cfg.policy != "create_on_demand": + return None + + async with self._gem_lock: + # Re-check after acquiring lock in case another coroutine just synced. + existing = self._policy_gem_ids.get(key) + if existing is not None: + return existing + + create_budget = self._acquire_managed_gem_create_budget( + gem_cfg.create_rate_limit_per_minute + ) + if create_budget <= 0: + self._managed_gem_metrics["managed_gems_rate_limit_skips"] += 1 + return None + + default_prompt = None + policy_dp = getattr(gem_cfg.policies, "default_policy", None) + if policy_dp and getattr(policy_dp, "enabled", False): + default_prompt = getattr(policy_dp, "prompt", None) + + cleanup_days = gem_cfg.cleanup.unused_days if gem_cfg.cleanup.enabled else None + sync_result = await sync_policy_gems( + self, + prefix=gem_cfg.policies.prefix, + include_hidden=True, + default_prompt=default_prompt, + mode="create_on_demand", + create_budget=create_budget, + cleanup_unused_days=cleanup_days, + cleanup_dry_run=gem_cfg.cleanup.dry_run, + cleanup_max_deletes_per_run=gem_cfg.cleanup.max_deletes_per_run, + cleanup_require_managed_marker=gem_cfg.cleanup.require_managed_marker, + managed_max_total=gem_cfg.managed_gems_max_total, + ) + self._apply_policy_sync_result(sync_result) + self._consume_managed_gem_create_budget(sync_result.created_count) + + created_or_found = self._policy_gem_ids.get(key) + if created_or_found is not None: + try: + asyncio.get_running_loop() + self._schedule_managed_policy_touch(created_or_found) + except RuntimeError: + pass + return created_or_found + + async def system_prompt_gem_id_or_create(self, system_prompt: str) -> str | None: + """Return/create a managed gem id for a raw system prompt text. + + This supports request-time prompt de-duplication: same system prompt will + map to the same managed gem name (hash-based) and be cached in memory. + """ + prompt = (system_prompt or "").strip() + if not prompt: + return None + + gem_cfg = g_config.gemini.gems + if not gem_cfg.enabled: + return None + + policy_mode = gem_cfg.policy + if policy_mode not in {"fetch_only", "create_on_demand"}: + return None + + prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest() + cache_key = f"sys:{prompt_hash}" + cached = self._system_prompt_gem_ids.get(cache_key) + if cached is not None: + return cached + + name = f"{gem_cfg.policies.prefix}sys_{prompt_hash[:24]}" + + async with self._gem_lock: + cached = self._system_prompt_gem_ids.get(cache_key) + if cached is not None: + return cached + + gems = list(await self.fetch_gems(include_hidden=True)) + custom_gems = [gem for gem in gems if not gem.predefined] + + existing = next((gem for gem in custom_gems if gem.name == name), None) + if existing is not None: + self._system_prompt_gem_ids[cache_key] = existing.id + return existing.id + + if policy_mode != "create_on_demand": + return None + + create_budget = self._acquire_managed_gem_create_budget( + gem_cfg.create_rate_limit_per_minute + ) + if create_budget <= 0: + self._managed_gem_metrics["managed_gems_rate_limit_skips"] += 1 + return None + + managed_total = len( + [gem for gem in custom_gems if gem.name.startswith(gem_cfg.policies.prefix)] + ) + if managed_total >= gem_cfg.managed_gems_max_total: + self._managed_gem_metrics["managed_gems_skipped_cap"] += 1 + return None + + description = "Managed system prompt gem created on-demand from API system message." + created = await self.create_gem(name=name, prompt=prompt, description=description) + self._consume_managed_gem_create_budget(1) + self._managed_gem_metrics["managed_gems_created"] += 1 + self._system_prompt_gem_ids[cache_key] = created.id + return created.id + + def managed_gem_metrics(self) -> dict[str, int]: + """Return a copy of managed gem lifecycle counters.""" + return dict(self._managed_gem_metrics) + + async def refresh_gems(self, include_hidden: bool | None = None) -> list[Gem]: + """Fetch gems from Gemini and return a plain list for API responses.""" + gem_cfg = g_config.gemini.gems + use_hidden = gem_cfg.include_hidden_on_fetch if include_hidden is None else include_hidden + + async with self._gem_lock: + gem_jar = await self.fetch_gems(include_hidden=use_hidden) + return list(gem_jar) + + def list_cached_gems(self) -> list[Gem]: + """Return cached gems, or an empty list when cache is not initialized yet.""" + try: + return list(self.gems) + except RuntimeError: + return [] + + @staticmethod + def _find_gem_in_list(gems: list[Gem], gem_ref: str) -> Gem | None: + """Find a gem in a list by id or case-insensitive name.""" + ref_stripped = (gem_ref or "").strip() + normalized = ref_stripped.lower() + for gem in gems: + gem_id = (gem.id or "").strip() + if gem_id == ref_stripped or gem.name.lower() == normalized: + return gem + return None + + async def get_gem(self, gem_ref: str, include_hidden: bool | None = None) -> Gem: + """Find a gem by id or name. Name matching is case-insensitive.""" + gems = self.list_cached_gems() + if not gems: + gems = await self.refresh_gems(include_hidden=include_hidden) + + found = self._find_gem_in_list(gems, gem_ref) + if found is not None: + return found + + raise ValueError(f"Gem '{gem_ref}' not found") + + async def create_custom_gem(self, name: str, prompt: str, description: str = "") -> Gem: + """Create a custom gem and refresh local cache.""" + async with self._gem_lock: + created = await self.create_gem(name=name, prompt=prompt, description=description) + await self.fetch_gems(include_hidden=g_config.gemini.gems.include_hidden_on_fetch) + return created + + async def update_custom_gem( + self, gem_ref: str, name: str, prompt: str, description: str = "" + ) -> Gem: + """Update a custom gem identified by id or name and refresh local cache.""" + async with self._gem_lock: + gems = self.list_cached_gems() + if not gems: + gems = list( + await self.fetch_gems( + include_hidden=g_config.gemini.gems.include_hidden_on_fetch, + ) + ) + target = self._find_gem_in_list(gems, gem_ref) + if target is None: + raise ValueError(f"Gem '{gem_ref}' not found") + + updated = await self.update_gem( + gem=target, + name=name, + prompt=prompt, + description=description, + ) + await self.fetch_gems(include_hidden=g_config.gemini.gems.include_hidden_on_fetch) + return updated + + async def delete_custom_gem(self, gem_ref: str) -> None: + """Delete a custom gem identified by id or name and refresh local cache.""" + async with self._gem_lock: + gems = self.list_cached_gems() + if not gems: + gems = list( + await self.fetch_gems( + include_hidden=g_config.gemini.gems.include_hidden_on_fetch, + ) + ) + target = self._find_gem_in_list(gems, gem_ref) + if target is None: + raise ValueError(f"Gem '{gem_ref}' not found") + + await self.delete_gem(target) + await self.fetch_gems(include_hidden=g_config.gemini.gems.include_hidden_on_fetch) + @staticmethod async def process_message( message: Message, tempdir: Path | None = None, tagged: bool = True, wrap_tool: bool = True diff --git a/app/services/policy_gems.py b/app/services/policy_gems.py new file mode 100644 index 0000000..f55b05d --- /dev/null +++ b/app/services/policy_gems.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import json +import time +from dataclasses import dataclass +from typing import Any, Literal + +from gemini_webapi import GeminiClient +from gemini_webapi.types import Gem + +from app.utils import g_config + + +@dataclass(frozen=True) +class PolicyGemSpec: + """Declarative definition of a server-managed policy gem.""" + + key: str + name: str + description: str + prompt: str + + +@dataclass(frozen=True) +class PolicySyncResult: + """Result payload for managed policy gem synchronization.""" + + gem_ids: dict[str, str] + created_count: int + updated_count: int + deleted_count: int + dry_run_delete_count: int + failed_delete_ids: list[str] + skipped_missing_marker_count: int + skipped_due_to_cap_count: int + managed_total_count: int + + +_META_MARKER = "\n\n[gemini_fastapi_meta]" + + +def _split_description_meta(description: str | None) -> tuple[str, dict[str, Any]]: + """Split managed metadata suffix from gem description.""" + text = description or "" + marker_index = text.rfind(_META_MARKER) + if marker_index == -1: + return text, {} + + base = text[:marker_index] + raw_meta = text[marker_index + len(_META_MARKER) :].strip() + if not raw_meta: + return base, {} + + try: + parsed = json.loads(raw_meta) + if isinstance(parsed, dict): + return base, parsed + except json.JSONDecodeError: + pass + return base, {} + + +def _compose_description_with_meta(base_description: str, last_used_at: float) -> str: + """Compose description with stable managed metadata suffix.""" + meta = { + "managed_by": "gemini_fastapi", + "last_used_at": int(last_used_at), + } + return f"{base_description}{_META_MARKER}{json.dumps(meta, separators=(',', ':'))}" + + +def touch_managed_description(description: str | None, now_ts: float) -> str: + """Return description with refreshed managed last_used timestamp.""" + base_description, _meta = _split_description_meta(description) + return _compose_description_with_meta(base_description, last_used_at=now_ts) + + +def extract_managed_last_used_at(description: str | None) -> int | None: + """Return managed `last_used_at` unix timestamp from description metadata.""" + _base, meta = _split_description_meta(description) + value = meta.get("last_used_at") + if isinstance(value, int) and value > 0: + return value + return None + + +def has_managed_marker(description: str | None) -> bool: + """Return whether a description contains Gemini-FastAPI managed metadata.""" + _base, meta = _split_description_meta(description) + return meta.get("managed_by") == "gemini_fastapi" + + +def _build_specs(prefix: str, default_prompt: str | None = None) -> list[PolicyGemSpec]: + """Return built-in policy gems that should exist for every configured client. + + `default_prompt` may be supplied (from config) to override the built-in prompt. + """ + # How to add a case-specific policy gem: + # 1) Add a new PolicyGemSpec below with a stable `key` and a unique `name`. + # 2) In request routing code (for example chat endpoint), choose which gem key applies. + # 3) Resolve the gem id via `await client.policy_gem_id_or_create("your_key")` + # when using `create_on_demand`, or `client.policy_gem_id("your_key")` + # for fetch-only behavior, then pass that id only + # when the request matches your condition. + # Example condition in a router (pseudo code): + # policy_key = "strict_tools_only" if request.tools else "general_capability_guardrail" + # policy_id = await client.policy_gem_id_or_create(policy_key) + # if policy_id: + # await session.send_message(..., gemini_options={"gem_id": policy_id}) + if default_prompt is None: + general_guardrail_prompt = ( + "You are operating behind an OpenAI-compatible Gemini wrapper.\n" + ) + else: + general_guardrail_prompt = default_prompt + + return [ + PolicyGemSpec( + key="general_capability_guardrail", + name=f"{prefix}general_capability_guardrail", + description="General capability policy for unsupported video/audio generation paths.", + prompt=general_guardrail_prompt, + ) + ] + + +async def _upsert_gem( + client: GeminiClient, + spec: PolicyGemSpec, + existing: Gem | None, +) -> tuple[Gem, bool, bool]: + """Create the policy gem if missing, or update it when the content changed.""" + now_ts = time.time() + desired_description = _compose_description_with_meta(spec.description, last_used_at=now_ts) + + if existing is None: + created = await client.create_gem( + name=spec.name, + description=desired_description, + prompt=spec.prompt, + ) + + return created, True, False + + existing_base_description, _existing_meta = _split_description_meta(existing.description) + if existing_base_description != spec.description or (existing.prompt or "") != spec.prompt: + updated = await client.update_gem( + gem=existing, + name=spec.name, + description=desired_description, + prompt=spec.prompt, + ) + return updated, False, True + + # Backfill metadata on older managed gems that predate managed suffix. + if extract_managed_last_used_at(existing.description) is None: + updated = await client.update_gem( + gem=existing, + name=spec.name, + description=desired_description, + prompt=spec.prompt, + ) + return updated, False, True + + return existing, False, False + + +async def sync_policy_gems( + client: GeminiClient, + prefix: str = "fastapi_policy_", + include_hidden: bool | None = None, + default_prompt: str | None = None, + mode: Literal["fetch_only", "create_on_demand"] = "fetch_only", + create_budget: int | None = None, + cleanup_unused_days: int | None = None, + cleanup_dry_run: bool = False, + cleanup_max_deletes_per_run: int | None = None, + cleanup_require_managed_marker: bool = True, + managed_max_total: int | None = None, +) -> PolicySyncResult: + """Synchronize built-in policy gems and return a map from policy key to gem id. + + By default the runtime config `g_config.gemini.gems.include_hidden_on_fetch` is used + unless `include_hidden` is explicitly provided. Callers may pass `include_hidden=True` + to ensure hidden gems are included during the sync (recommended when reconciling + hidden policy gems). + + Modes: + - `fetch_only`: read existing managed gems and build id mapping only. + - `create_on_demand`: create/update only the managed gem specs, without deleting extras. + + `create_budget` limits how many new managed gems can be created during this run. + `cleanup_unused_days` removes managed prefixed gems whose last-used metadata + is older than the configured threshold. + `cleanup_dry_run` logs stale candidates without deleting. + `cleanup_max_deletes_per_run` caps deletions for each sync pass. + `cleanup_require_managed_marker` restricts deletion to managed-marker gems. + `managed_max_total` caps total server-managed gems with this prefix. + """ + + prefix = (prefix or "fastapi_policy_").strip() or "fastapi_policy_" + # Default include_hidden to the runtime config when not explicitly provided. + use_hidden = include_hidden if include_hidden is not None else g_config.gemini.gems.include_hidden_on_fetch + specs = _build_specs(prefix, default_prompt=default_prompt) + await client.fetch_gems(include_hidden=use_hidden) + custom_gems = [gem for gem in client.gems if not gem.predefined] + + deleted_count = 0 + dry_run_delete_count = 0 + failed_delete_ids: list[str] = [] + skipped_missing_marker_count = 0 + + max_deletes_left = cleanup_max_deletes_per_run + + if cleanup_unused_days is not None and cleanup_unused_days > 0: + cutoff_ts = int(time.time() - cleanup_unused_days * 24 * 60 * 60) + for gem in custom_gems: + if not gem.name.startswith(prefix): + continue + if cleanup_require_managed_marker and not has_managed_marker(gem.description): + skipped_missing_marker_count += 1 + continue + last_used_at = extract_managed_last_used_at(gem.description) + if last_used_at is None: + continue + if last_used_at < cutoff_ts: + if max_deletes_left is not None and max_deletes_left <= 0: + continue + if cleanup_dry_run: + dry_run_delete_count += 1 + continue + try: + await client.delete_gem(gem) + deleted_count += 1 + if max_deletes_left is not None: + max_deletes_left -= 1 + except Exception: + failed_delete_ids.append(gem.id) + + # Refresh inventory after cleanup deletions. + if deleted_count > 0: + await client.fetch_gems(include_hidden=use_hidden) + custom_gems = [gem for gem in client.gems if not gem.predefined] + + managed_gems = [gem for gem in custom_gems if gem.name.startswith(prefix)] + single_by_name = {gem.name: gem for gem in managed_gems} + managed_total_count = len(managed_gems) + + result: dict[str, str] = {} + created_count = 0 + updated_count = 0 + skipped_due_to_cap_count = 0 + for spec in specs: + existing = single_by_name.get(spec.name) + if mode == "fetch_only": + if existing is None: + continue + gem = existing + else: + if existing is None and managed_max_total is not None and managed_total_count >= managed_max_total: + skipped_due_to_cap_count += 1 + continue + if existing is None and create_budget is not None and create_budget <= 0: + continue + gem, created, updated = await _upsert_gem(client, spec=spec, existing=existing) + if created: + managed_total_count += 1 + created_count += 1 + if create_budget is not None: + create_budget -= 1 + if updated: + updated_count += 1 + result[spec.key] = gem.id + + return PolicySyncResult( + gem_ids=result, + created_count=created_count, + updated_count=updated_count, + deleted_count=deleted_count, + dry_run_delete_count=dry_run_delete_count, + failed_delete_ids=failed_delete_ids, + skipped_missing_marker_count=skipped_missing_marker_count, + skipped_due_to_cap_count=skipped_due_to_cap_count, + managed_total_count=managed_total_count, + ) diff --git a/app/utils/config.py b/app/utils/config.py index fe806c9..71090ea 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -72,6 +72,112 @@ def _parse_json_string(cls, v: Any) -> Any: return v +class GeminiGemDefaultPolicyConfig(BaseModel): + """Configuration for the optional default managed policy gem.""" + + enabled: bool = Field( + default=False, + description="Create or update the default managed policy gem", + ) + key: str = Field( + default="general_capability_guardrail", + description="Logical policy key to map for default policy gem", + ) + prompt: str | None = Field( + default=None, + description=( + "Prompt override for the default policy gem; null uses the built-in prompt" + ), + ) + + +class GeminiGemPoliciesConfig(BaseModel): + """Configuration for built-in policy gems managed by the server.""" + + enabled: bool = Field( + default=False, + description="Deprecated flag. Prefer `gemini.gems.policy` mode selection", + ) + prefix: str = Field( + default="fastapi_policy_", + description="Name prefix used to identify policy gems created by this server", + ) + default_policy: GeminiGemDefaultPolicyConfig = Field( + default=GeminiGemDefaultPolicyConfig(), + description="Optional default policy gem bootstrap settings", + ) + + +class GeminiGemCleanupConfig(BaseModel): + """Cleanup policy for server-managed gems.""" + + enabled: bool = Field( + default=False, + description="Enable deletion of managed gems that have not been used recently", + ) + unused_days: int = Field( + default=7, + ge=1, + description="Delete managed gems that were not used for this many days", + ) + touch_interval_minutes: int = Field( + default=60, + ge=1, + description="Minimum minutes between usage-touch metadata updates for the same gem", + ) + dry_run: bool = Field( + default=False, + description="Log cleanup candidates without deleting them", + ) + max_deletes_per_run: int = Field( + default=5, + ge=1, + description="Maximum managed-gem deletions per synchronization run", + ) + require_managed_marker: bool = Field( + default=True, + description="Delete only gems that contain Gemini-FastAPI managed marker metadata", + ) + + +class GeminiGemsConfig(BaseModel): + """Configuration for gem behaviors exposed by the API.""" + + enabled: bool = Field(default=False, description="Enable gem API endpoints") + policy: Literal["off", "fetch_only", "create_on_demand", "privacy"] = Field( + default="off", + description=( + "Policy gem mode: off=disabled, fetch_only=read existing prefixed gems only, " + "create_on_demand=create missing managed gems, privacy=ephemeral mode" + ), + ) + create_rate_limit_per_minute: int = Field( + default=12, + ge=1, + description="Maximum server-managed gem creations per minute per client", + ) + managed_gems_max_total: int = Field( + default=200, + ge=1, + description="Maximum number of managed gems (by prefix) allowed per client", + ) + fetch_on_init: bool = Field( + default=True, + description="Fetch and cache gem inventory during client initialization", + ) + include_hidden_on_fetch: bool = Field( + default=False, + description="Include hidden gems when fetching gem inventory", + ) + policies: GeminiGemPoliciesConfig = Field( + default=GeminiGemPoliciesConfig(), + description="Built-in policy gem synchronization settings", + ) + cleanup: GeminiGemCleanupConfig = Field( + default=GeminiGemCleanupConfig(), + description="Cleanup policy for managed gems", + ) + class OversizedContextStrategy(str, Enum): """Strategy for handling oversized context.""" @@ -93,6 +199,10 @@ class GeminiConfig(BaseModel): ..., description="List of Gemini client credential pairs" ) models: list[GeminiModelConfig] = Field(default=[], description="List of custom Gemini models") + gems: GeminiGemsConfig = Field( + default=GeminiGemsConfig(), + description="Gem endpoint and synchronization settings", + ) model_strategy: Literal["append", "overwrite"] = Field( default="append", description="Strategy for loading models: 'append' merges custom with default, 'overwrite' uses only custom", diff --git a/config/config.yaml b/config/config.yaml index 462167f..0906afc 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -32,6 +32,39 @@ gemini: chat_mode: "normal" # "normal" reuses Google chat metadata; "temporary" sends with Google's temporary mode (not saved to account) and uses a 90% effective input limit model_strategy: "append" # Strategy: 'append' (default + custom) or 'overwrite' (custom only) models: [] + # Gem management (custom 'gems' managed in your Gemini account). + # WARNING: Enabling policy sync may create, update, or delete gems in the + # associated Gemini account if it has our identifiers. + gems: + enabled: false # Enable gem API endpoints (defaults to OFF) + # Policy mode: + # - off: disabled + # - fetch_only: load existing server-managed gems (prefix-filtered), never create/update/delete + # - create_on_demand: create missing managed gems up to rate-limit budget + # - privacy: reserved for ephemeral request-time flow (startup sync is skipped) + policy: "off" + create_rate_limit_per_minute: 4 # Per-client budget for server-managed gem creations + managed_gems_max_total: 200 # Per-client cap for managed gems with our prefix + cleanup: + enabled: false # Delete managed gems when unused for `unused_days` + unused_days: 7 # Delete if not used for this many days + touch_interval_minutes: 60 # Min minutes between metadata touch updates per gem + dry_run: false # If true, logs stale candidates without deleting + max_deletes_per_run: 5 # Safety cap for deletions in a single sync pass + require_managed_marker: true # Delete only gems carrying Gemini-FastAPI marker + fetch_on_init: true # Fetch and cache gems when each client starts + include_hidden_on_fetch: false # Include hidden gems when fetching inventory + policies: + enabled: false # Keep built-in policy gems synced for every client + prefix: "fastapi_policy_" # Prefix used for server-managed policy gems + # Optional: control a single default policy gem and its prompt text. + default_policy: + enabled: false # Create/update the default policy gem when true + key: "general_capability_guardrail" + # Base prompt used when creating the default policy gem. This can be + # overridden per-deployment. If `prompt` is set to null (or omitted), + # the module's built-in base system prompt will be used instead. + prompt: null storage: path: "data/lmdb" # Database storage path