diff --git a/CHANGELOG.md b/CHANGELOG.md index 572ac40257..955c1f066e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,156 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.8.0] - 2026-02-12 + +### Added + +- 📊 **Analytics dashboard.** Administrators now have access to an Analytics dashboard showing model usage statistics, token consumption by model and user, user activity rankings, and time-series charts with hourly or daily granularity; clicking any model opens a detail view with feedback history, associated tags, and chat browser, and results can be filtered by user group. [#21106](https://github.com/open-webui/open-webui/pull/21106), [Commit](https://github.com/open-webui/open-webui/commit/68a1e87b66a7ec8831d5ed52940c4ef110e3e264), [Commit](https://github.com/open-webui/open-webui/commit/e62649f94044abfed4d7d60647a2050383a67e3d) +- 🎯 **Experimental support for Skills.** Open WebUI now supports the Skill standard — allowing users to create and manage reusable AI skills with detailed instructions, reference them in chats using the "$" command, or attach them to specific models for automatic context in conversations. [#21312](https://github.com/open-webui/open-webui/pull/21312) +- 🧪 **Experimental support for Open Responses protocol.** Connections can now be configured to use the experimental Open Responses protocol instead of Chat Completions, enabling native support for extended thinking, streaming reasoning tokens, and richer tool call handling for compatible providers. 
[Commit](https://github.com/open-webui/open-webui/commit/d2c695eb11ddca9fc93499bb0c3fcafcff7099b5), [Commit](https://github.com/open-webui/open-webui/commit/90a057f4005c000bda6ff8703e13e529190af73a), [Commit](https://github.com/open-webui/open-webui/commit/0dc74a8a2e7adb76fb503ef0cd3c02daddd2f4bb), [Commit](https://github.com/open-webui/open-webui/commit/ea9c58ea80646cef05e06d0beaf5e81cc2f78cb1), [Commit](https://github.com/open-webui/open-webui/commit/6ffce4bccdc13b8b61a8b286e34094c981932eda), [Commit](https://github.com/open-webui/open-webui/commit/6719558150920f570d8febe021da65903e53c976), [Commit](https://github.com/open-webui/open-webui/commit/117c091b95a1b1a76a31c31b97304bac289d6f18), [Commit](https://github.com/open-webui/open-webui/commit/aa8c2959ca8476f269786e1317fb6d2938abd3f9), [Commit](https://github.com/open-webui/open-webui/commit/e2d09ac36174de48a7d85bafc8d3291c9ffe44cd) +- 👥 **Redesigned access control UI.** The access control UI was redesigned with a more intuitive interface that makes it easier to add multiple groups at once. [#21277](https://github.com/open-webui/open-webui/pull/21277) +- 👤 **Per-user resource sharing.** Resources including knowledge bases, prompts, models, tools, channels, and base models can now be shared directly to individual users alongside the existing per-group sharing capability. [#21277](https://github.com/open-webui/open-webui/pull/21277) +- 📨 **Message queuing.** Messages can now be queued while a response is generating rather than being blocked, allowing you to continue your train of thought; queued messages are automatically combined and sent when generation completes, and can be edited, deleted, or sent immediately from the input area. 
[Commit](https://github.com/open-webui/open-webui/commit/62750b8980ef0a3f2da7bc64b5416706a7495686), [Commit](https://github.com/open-webui/open-webui/commit/d3f2cf74748db42311ca04a56ccd1ea15399eca0) +- 💡 **Active task sidebar indicator.** Users can now see which chats have active tasks running directly in the sidebar. [Commit](https://github.com/open-webui/open-webui/commit/48522271586a5bf24b649610f03b4ffd8afb2782) +- 📝 **Prompt version control.** Prompts now include version control with full history tracking, allowing users to commit changes with messages, view past versions, compare differences between versions, and roll back to previous versions when needed. [#20945](https://github.com/open-webui/open-webui/pull/20945) +- 🏷️ **Prompt tags.** Prompts can now be organized with tags, and users can filter the prompt workspace by tag to quickly find related prompts across large collections. [#20945](https://github.com/open-webui/open-webui/pull/20945) +- 🐍 **Native function calling code execution.** Code execution now works with Native function calling mode, allowing models to autonomously run Python code for calculations, data analysis, and visualizations without requiring Default mode. [#20592](https://github.com/open-webui/open-webui/pull/20592), [Docs:#998](https://github.com/open-webui/docs/pull/998) +- 🚀 **Async web search.** Web search operations now run asynchronously in the background, allowing users to continue interacting with the application while searches complete. [#20630](https://github.com/open-webui/open-webui/pull/20630) +- ⚡ **Search debouncing.** Search operations across the application now respond more efficiently with debouncing that reduces unnecessary server requests while typing, improving responsiveness when searching users, groups, functions, tools, prompts, knowledge bases, notes, and when using the knowledge and prompts commands in chat. 
[#20982](https://github.com/open-webui/open-webui/pull/20982), [Commit](https://github.com/open-webui/open-webui/commit/36766f157d46102fd76c526b42579400ca70de50), [Commit](https://github.com/open-webui/open-webui/commit/fa859de460376782bd0fa35512c8426c9cd0462c), [Commit](https://github.com/open-webui/open-webui/commit/57ec2aa088ffd5a8c3553c53d39799497ff70479) +- 🤝 **Shared chats management.** Users can now view and manage all their shared chats from Settings, with options to copy share links or unshare conversations they no longer want public. [Commit](https://github.com/open-webui/open-webui/commit/a10ac774ab5d47b505e840b029c0c0340002508b) +- 📁 **User file management.** Users can now view, search, and delete all their uploaded files from Settings, providing centralized file management in one place. [Commit](https://github.com/open-webui/open-webui/commit/93ed4ae2cda2f4311143e51f586aaa73b83a37a7), [#21047](https://github.com/open-webui/open-webui/pull/21047) +- 🗑️ **Shift-click quick delete.** Files in the File Manager can now be quickly deleted by holding Shift and clicking the delete button, bypassing the confirmation dialog for faster bulk cleanup. [#21044](https://github.com/open-webui/open-webui/pull/21044) +- ⌨️ **Model selector shortcut.** The model selector can now be opened with Ctrl+Shift+M keyboard shortcut. [#21130](https://github.com/open-webui/open-webui/pull/21130) +- 🧠 **Smarter knowledge vs web search.** Models now choose more intelligently between knowledge base search and web search rather than always trying knowledge first. [#21115](https://github.com/open-webui/open-webui/pull/21115) +- 🌍 **Community model reviews.** Users can now access community reviews for models directly from the model selector menu and are prompted to leave reviews after rating responses, with administrators able to disable this via the "Community Sharing" setting. 
[Commit](https://github.com/open-webui/open-webui/commit/bc90463ea60c9a66accb1fd242cf1853910ca838) +- 📄 **Prompts workspace pagination.** The prompts workspace now includes pagination for large prompt collections, loading 30 prompts at a time with search, filtering, and sorting capabilities for improved performance and navigation. [Commit](https://github.com/open-webui/open-webui/commit/36766f157d46102fd76c526b42579400ca70de50) +- 🎨 **Action function HTML rendering.** Action functions can now render rich HTML content directly in chat as embedded iframes, matching the capabilities that tools already had and eliminating the need for action authors to inject codeblocks. [#21294](https://github.com/open-webui/open-webui/pull/21294), [Commit](https://github.com/open-webui/open-webui/commit/60ada21c152ed642971429fdbe88dcbf478cf83a) +- 🔒 **Password-masked valve fields.** Tool and function developers can now mark sensitive fields as passwords, which are automatically masked in the settings UI to prevent shoulder surfing and accidental exposure. [#20852](https://github.com/open-webui/open-webui/issues/20852), [Commit](https://github.com/open-webui/open-webui/commit/8c70453b2e3a6958437d951751e84acbbaafd9aa) +- 📋 **Prompt quick copy.** Prompts in the workspace now include a quick copy button for easily copying prompt content to the clipboard. [Commit](https://github.com/open-webui/open-webui/commit/78f856e2049991441a3469230ae52799cb86954e) +- 🔔 **Dismissible notification toasts.** Notification toasts for new messages and other events now include a close button that appears on hover, allowing users to dismiss them immediately instead of waiting for auto-dismissal. 
[#21056](https://github.com/open-webui/open-webui/issues/21056), [Commit](https://github.com/open-webui/open-webui/commit/73bb600034c8532e30726129743a5ffe9002c5fb) +- 🔔 **Temporary chat notification privacy.** Notifications from temporary chats now only appear on the device where the chat is running, preventing privacy leaks across logged-in sessions. [#21292](https://github.com/open-webui/open-webui/pull/21292) +- 💡 **Null chat title fallback.** Notifications without chat titles now display "New Chat" instead of showing null. [#21292](https://github.com/open-webui/open-webui/pull/21292) +- 🖼️ **Concurrent image editing.** Image editing operations with multiple images now complete faster by loading all images concurrently instead of sequentially. [#20911](https://github.com/open-webui/open-webui/pull/20911) +- 📧 **USER_EMAIL template variable.** Users can now reference their email address in prompts and system messages using the "{{USER_EMAIL}}" template variable. [#20881](https://github.com/open-webui/open-webui/pull/20881) +- 🔤 **Alphabetical tool ordering.** Tools and Functions in the Chat Controls sidebar now appear in alphabetical order, making it easier to locate specific tools when working with multiple integrations. [#20871](https://github.com/open-webui/open-webui/pull/20871) +- 👁️ **Model list status filtering.** Administrators can now filter the model list by status (enabled, disabled, visible, hidden) and bulk enable or disable all filtered models at once. [#20553](https://github.com/open-webui/open-webui/issues/20553), [#20774](https://github.com/open-webui/open-webui/issues/20774), [Commit](https://github.com/open-webui/open-webui/commit/96a9696383d450dad2cbb230f3756ebfa258e029) +- ⚙️ **Per-model built-in tool toggles.** Administrators can now enable or disable individual built-in tools for each model, including time utilities, memory, chat history, notes, knowledge base, and channels. 
[#20641](https://github.com/open-webui/open-webui/issues/20641), [Commit](https://github.com/open-webui/open-webui/commit/c46ef3b63bcc1e2e9adbdd18fab82c4bbe33ff6c) +- 📑 **PDF loading modes.** Administrators can now choose between "page" and "single" PDF loading modes, allowing documents to be processed as individual pages or as complete documents for better chunking across page boundaries. [Commit](https://github.com/open-webui/open-webui/commit/ecbdef732bc71a07c21bbb679edb420f26eac181) +- 📑 **Model Settings pagination.** Administrators can now navigate large model lists more efficiently in Model Settings, with pagination displaying 30 models per page for smoother navigation. [Commit](https://github.com/open-webui/open-webui/commit/2f584c9f88aeb34ece07b10d05794020d1d656b8) +- 📌 **Pin read-only models.** Users can now pin read-only models from the workspace. [#21308](https://github.com/open-webui/open-webui/issues/21308), [Commit](https://github.com/open-webui/open-webui/commit/97331bf11d41ca54e47f86777fb8dbd73988c631) +- 🔍 **Yandex search provider.** Administrators can now configure Yandex as a web search provider, expanding search engine options for retrieval-augmented generation. [#20922](https://github.com/open-webui/open-webui/pull/20922) +- 🔐 **Custom password hints.** Administrators can now provide custom password requirement hints to users via the "PASSWORD_VALIDATION_HINT" environment variable, making it clearer what password criteria must be met during signup or password changes. [#20647](https://github.com/open-webui/open-webui/issues/20647), [#20650](https://github.com/open-webui/open-webui/pull/20650) +- 🔑 **OAuth token exchange.** Administrators can now enable OAuth token exchange via "ENABLE_OAUTH_TOKEN_EXCHANGE", allowing external applications to authenticate users by exchanging OAuth provider tokens for Open WebUI session tokens. 
[Commit](https://github.com/open-webui/open-webui/commit/655420fd25ed0ea872954baa485030079c00c10e) +- 🗄️ **Weaviate custom endpoints.** Administrators can now connect to self-hosted Weaviate deployments with separate HTTP and gRPC endpoints via new environment variables. [#20620](https://github.com/open-webui/open-webui/pull/20620) +- 🛡️ **MCP custom SSL certificates.** Administrators can now connect to MCP servers with self-signed or custom SSL certificates via the "AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL" environment variable. [#20875](https://github.com/open-webui/open-webui/issues/20875), [Commit](https://github.com/open-webui/open-webui/commit/c7f996d593e4bb48103b91316204fe7e50e25b35) +- 🗃️ **Redis Sentinel reconnection delay.** Administrators using Redis Sentinel can now configure a reconnection delay via "REDIS_RECONNECT_DELAY" to prevent retry exhaustion during failover elections. [#21021](https://github.com/open-webui/open-webui/pull/21021) +- 📡 **Custom user info headers.** Administrators can now customize the header names used when forwarding user information to external services, enabling compatibility with services like AWS Bedrock AgentCore that require specific header prefixes. [Commit](https://github.com/open-webui/open-webui/commit/6c0f886cdf4b4249dca29e9340b3b998a7262d61) +- 🔗 **Forward user info to tool servers.** User identity and chat context can now be forwarded to MCP servers and external tool servers when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled, allowing tool providers to implement per-user authorization, auditing, and rate limiting. [#21092](https://github.com/open-webui/open-webui/pull/21092), [Commit](https://github.com/open-webui/open-webui/commit/2c37daef86a058e370151ecead17f10078102307) +- 📬 **External tool event emitters.** External tools (OpenAPI/MCP) can now send tool events back to Open WebUI using the event emitter endpoint, as message ID is now forwarded alongside chat ID when "ENABLE_FORWARD_USER_INFO_HEADERS" is enabled. 
[#21214](https://github.com/open-webui/open-webui/pull/21214) +- 📥 **Playground chat export.** Administrators can now export playground chats as JSON or plain text files, allowing them to save their conversations for backup or sharing outside the platform. [Commit](https://github.com/open-webui/open-webui/commit/8e2b0b6fd2ac99c833a110e2bc6aa655f1682669) +- 🖼️ **Images playground.** Administrators can now test image generation and editing directly in a new Images playground, with support for uploading source images for edits and downloading results. [Commit](https://github.com/open-webui/open-webui/commit/94302de49b27bdf1df86b5c26f2cafb98f964e52) +- 🛠️ **Dynamic dropdown valve fields.** Tool and function developers can now create dropdown fields with dynamically-generated options that update based on runtime context, such as available models or user permissions. [Commit](https://github.com/open-webui/open-webui/commit/474427c67e953bb9f7d122757a756a639214e0b2) +- 🏎️ **Faster profile updates.** User profile updates and role changes are now faster by eliminating redundant database queries. [#21011](https://github.com/open-webui/open-webui/pull/21011) +- 🔑 **Faster authentication.** User authentication is now 34% faster by combining database lookups into a single query. [#21010](https://github.com/open-webui/open-webui/pull/21010) +- 🔋 **Faster chat completions.** Chat completions and embeddings now respond much faster by checking the model cache before fetching model lists, reducing Time To First Token from several seconds to subsecond for most requests. [#20886](https://github.com/open-webui/open-webui/pull/20886), [#20069](https://github.com/open-webui/open-webui/discussions/20069) +- 🏎️ **Faster Redis model list loading.** Model list loading is now significantly faster when using Redis with many models, reducing API response latency by caching configuration values locally instead of making repeated Redis lookups on every model iteration. 
[#21306](https://github.com/open-webui/open-webui/pull/21306) +- 💨 **Faster knowledge base file batch-add.** Batch-adding files to knowledge bases is now faster with a single database query instead of one query per file. [#21006](https://github.com/open-webui/open-webui/pull/21006) +- ⚡ **Smoother model selector dropdown.** The model selector dropdown now renders smoothly even with hundreds of models, eliminating the lag and freezing that occurred when opening the dropdown with large model lists. [Commit](https://github.com/open-webui/open-webui/commit/4331029926245b7b74fa8e254610c91400b239b0) +- 🚗 **Faster model visibility toggling.** Toggling model visibility in the admin panel is now faster with optimized database access. [#21009](https://github.com/open-webui/open-webui/pull/21009) +- 💾 **Faster model access control checks.** Model access control checks are now faster by batch-fetching model info and group memberships upfront instead of querying for each model. [#21008](https://github.com/open-webui/open-webui/pull/21008) +- ⚙️ **Faster model list and imports.** Model list loading and model imports are now faster by eliminating redundant database queries. [#21004](https://github.com/open-webui/open-webui/pull/21004) +- 🏃 **Faster SCIM group member lookups.** SCIM group member lookups are now up to 13x faster by batching user queries instead of fetching each member individually. [#21005](https://github.com/open-webui/open-webui/pull/21005) +- 💨 **Batched group member counts.** Group member counts are now fetched in a single batch query when loading group lists, eliminating redundant database lookups. [Commit](https://github.com/open-webui/open-webui/commit/96c07f44a8f5e6346b2ea6ac529ff4ec3c47e90a) +- 💨 **Faster bulk operations.** Bulk feedback deletion and group member removal are now 4-5x faster with optimized batch operations. 
[#21019](https://github.com/open-webui/open-webui/pull/21019) +- 🧠 **Faster memory updates.** Memory updates are now up to 39% faster by eliminating redundant database queries. [#21013](https://github.com/open-webui/open-webui/pull/21013) +- ⚙️ **Faster filter function loading.** Filter function loading is now faster by batching database queries instead of fetching each function individually. [#21018](https://github.com/open-webui/open-webui/pull/21018) +- 🖼️ **Image model regex configuration.** Administrators can now configure which image generation models support auto-sizing and URL responses via new regex environment variables, improving compatibility with LiteLLM and other proxies that use prefixed model names. [#21126](https://github.com/open-webui/open-webui/pull/21126), [Commit](https://github.com/open-webui/open-webui/commit/ecf3fa2feb28e74ff6c17ca97d94581f316da56a) +- 🎁 **Easter eggs toggle.** Administrators can now control the visibility of easter egg features via the "ENABLE_EASTER_EGGS" environment variable. [Commit](https://github.com/open-webui/open-webui/commit/907dba4517903e5646e40223a0edca26a7107bc8) +- 🔌 **Independent access control updates.** API endpoints now support independent access control updates for models, tools, knowledge bases, and notes, enabling finer-grained permission management. [Commit](https://github.com/open-webui/open-webui/commit/0044902c082f8475336cc7d5c57fe3f35ab0555d), [Commit](https://github.com/open-webui/open-webui/commit/c259c878060af1b03b702c943e8813d7b4fc3199), [Commit](https://github.com/open-webui/open-webui/commit/e3a825769063cee486650cc2eb9a032676e630c5) +- ♿ **Screen reader accessibility.** Screen reader users now hear the password field label only once on the login page, improving form navigation for assistive technology users. 
[Commit](https://github.com/open-webui/open-webui/commit/1441d0d735c7a1470070b33327e1dd4dc5ca1131) +- 🔄 **General improvements.** Various improvements were implemented across the application to enhance performance, stability, and security. +- 🌐 **Translation updates.** Translations for Catalan, Finnish, Irish, French, German, Japanese, Latvian, Polish, Portuguese (Brazil), Simplified Chinese, Slovak, Spanish, and Traditional Chinese were enhanced and expanded. + +### Fixed + +- ⚡ **Connection pool exhaustion fix.** Database connection pool exhaustion and timeout errors during concurrent usage have been resolved by releasing connections before chat completion requests and embedding operations for memory and knowledge base processing. [#20569](https://github.com/open-webui/open-webui/pull/20569), [#20570](https://github.com/open-webui/open-webui/pull/20570), [#20571](https://github.com/open-webui/open-webui/pull/20571), [#20572](https://github.com/open-webui/open-webui/pull/20572), [#20573](https://github.com/open-webui/open-webui/pull/20573), [#20574](https://github.com/open-webui/open-webui/pull/20574), [#20575](https://github.com/open-webui/open-webui/pull/20575), [#20576](https://github.com/open-webui/open-webui/pull/20576), [#20577](https://github.com/open-webui/open-webui/pull/20577), [#20578](https://github.com/open-webui/open-webui/pull/20578), [#20579](https://github.com/open-webui/open-webui/pull/20579), [#20580](https://github.com/open-webui/open-webui/pull/20580), [#20581](https://github.com/open-webui/open-webui/pull/20581), [Commit](https://github.com/open-webui/open-webui/commit/7da37b4f66b9b2e821796b06b75e03cb0237e0a9), [Commit](https://github.com/open-webui/open-webui/commit/9af40624c5f0f8f7f640a11356e167543b07b2bb) +- 🚫 **LDAP authentication hang fix.** LDAP authentication no longer freezes the entire service when logging in with non-existent accounts, preventing application hangs. 
[Commit](https://github.com/open-webui/open-webui/commit/a4281f6a7fbc9764b57830e4ef81bb780aa34af9), [#21300](https://github.com/open-webui/open-webui/issues/21300) +- ✅ **Trusted Header auto-registration fix.** Trusted Header Authentication now properly auto-registers new users on their first login, assigning the configured default role instead of failing for users not yet in the database. [Commit](https://github.com/open-webui/open-webui/commit/9b30e8f6894c8c8bad0a9ce4693eab810962adc9) +- 🛡️ **SSRF protection for image loading.** External image loading now validates URLs before fetching to prevent SSRF attacks against local and private network addresses. [Commit](https://github.com/open-webui/open-webui/commit/ce50d9bac4f30b054b09a2fbda52569b73ea591c) +- 🛡️ **Malformed Authorization header fix.** Malformed Authorization headers no longer cause server crashes; requests are now handled gracefully instead of returning HTTP 500 errors. [#20938](https://github.com/open-webui/open-webui/issues/20938), [Commit](https://github.com/open-webui/open-webui/commit/7e79f8d1c6b5a02f1a46e792540c6bbf7bed8edc) +- 🚪 **Channel notification access control.** Users without channel permissions can no longer access channels through notifications, properly enforcing access controls across all channel entry points. [#20883](https://github.com/open-webui/open-webui/pull/20883), [#20789](https://github.com/open-webui/open-webui/discussions/20789) +- 🐛 **Ollama model name suffix fix.** Ollama-compatible providers that do not use ":latest" in model names can now successfully chat, fixing errors where model names were incorrectly appended with ":latest" suffixes. 
[#21331](https://github.com/open-webui/open-webui/issues/21331), [Commit](https://github.com/open-webui/open-webui/commit/05ae44b98dc279ee12cc8eab17278ccbfec60301) +- ♻️ **Streaming connection cleanup.** Streaming responses now properly clean up network connections when interrupted, preventing "Unclosed client session" errors from accumulating over time. [#20889](https://github.com/open-webui/open-webui/pull/20889), [#17058](https://github.com/open-webui/open-webui/issues/17058) +- 💾 **Inline image context exhaustion fix.** Inline images no longer exhaust the model's context window by including their full base64 data in chat metadata, preventing premature context exhaustion with image-heavy conversations. [#20916](https://github.com/open-webui/open-webui/pull/20916) +- 🚀 **Status indicator GPU usage fix.** High GPU usage caused by the user online status indicator animation has been resolved, reducing consumption from 35-40% to near-zero in browsers with hardware acceleration. [#21062](https://github.com/open-webui/open-webui/issues/21062), [Commit](https://github.com/open-webui/open-webui/commit/938d1b0743c64f0ce513d68e57dfbb86987cb06b) +- 🔧 **Async pipeline operations.** Pipeline operations now run asynchronously instead of blocking the FastAPI event loop, allowing the server to handle other requests while waiting for external pipeline API calls. [#20910](https://github.com/open-webui/open-webui/pull/20910) +- 🔌 **MCP tools regression fix.** MCP tools now work reliably again after a regression in v0.7.2 that caused "cannot pickle '\_asyncio.Future' object" errors when attempting to use MCP servers in chat. 
[#20629](https://github.com/open-webui/open-webui/issues/20629), [#20500](https://github.com/open-webui/open-webui/issues/20500), [Commit](https://github.com/open-webui/open-webui/commit/886c12c5664bc2dd73313330f61c2257169da6d1) +- 🔗 **Function chat ID propagation fix.** Functions now reliably receive the chat identifier during internal task invocations like web search query generation, RAG query generation, and image prompt generation, enabling stateful functions to maintain consistent per-chat state without fragmentation. [#20563](https://github.com/open-webui/open-webui/issues/20563), [#20585](https://github.com/open-webui/open-webui/pull/20585) +- 💻 **Markdown fence code execution fix.** Code execution now works reliably when models wrap code in markdown fences, automatically stripping the backticks before execution to prevent syntax errors that affected most non-GPT models. [#20941](https://github.com/open-webui/open-webui/issues/20941), [Commit](https://github.com/open-webui/open-webui/commit/4a5516775927aaf002212f2e09c55a17c699bc46), [Commit](https://github.com/open-webui/open-webui/commit/683438b418fb3b453a8ad88c1ba1a9944eac3593) +- 💻 **ANSI code execution fix.** Code execution is now reliable when LLMs include ANSI terminal color codes in their output, preventing random failures that previously caused syntax errors. [#21091](https://github.com/open-webui/open-webui/issues/21091), [Commit](https://github.com/open-webui/open-webui/commit/b1737040a7d3bb5efcfe0f1432e89d7e82e51d2d) +- 🗨️ **Incomplete model metadata crash fix.** Starting chats with models that have incomplete metadata information no longer crashes the application. 
[#20565](https://github.com/open-webui/open-webui/issues/20565), [Commit](https://github.com/open-webui/open-webui/commit/14f6747dfc66fb7e942b930650286012121e5262) +- 💬 **Unavailable model crash fix.** Adding message pairs with Ctrl+Shift+Enter no longer crashes when the chat's model is unavailable, showing a helpful error message instead. [#20663](https://github.com/open-webui/open-webui/pull/20663) +- 📚 **Knowledge base file upload fix.** Uploading files to knowledge bases now works correctly, fixing database mapping errors that prevented file uploads. [#20925](https://github.com/open-webui/open-webui/issues/20925), [#20931](https://github.com/open-webui/open-webui/pull/20931) +- 🧠 **Knowledge base query type fix.** Knowledge base queries no longer fail intermittently when models send tool call parameters as strings instead of their expected types. [#20705](https://github.com/open-webui/open-webui/pull/20705) +- 📚 **Knowledge base reindex fix.** Reindexing knowledge base files now works correctly instead of failing with duplicate content errors. [#20854](https://github.com/open-webui/open-webui/issues/20854), [#20857](https://github.com/open-webui/open-webui/pull/20857) +- 🔧 **Multi-worker knowledge base timeout fix.** In multi-worker deployments, uploading very large documents to knowledge bases no longer causes workers to be killed by health check timeouts, and administrators can now configure a custom embedding timeout via "RAG_EMBEDDING_TIMEOUT". [#21158](https://github.com/open-webui/open-webui/pull/21158), [Discussion](https://github.com/open-webui/open-webui/discussions/21151), [Commit](https://github.com/open-webui/open-webui/commit/c653e4ec54d070aee5e9568d016daebb61f06632) +- 🌅 **Dark mode icon inversion fix.** Icons in chat and action menus are now displayed correctly in dark mode, fixing an issue where PNG icons with "svg" in their base64 encoding were randomly inverted. 
[#21272](https://github.com/open-webui/open-webui/pull/21272), [Commit](https://github.com/open-webui/open-webui/commit/0a44d80252afae73de4098ab1c3eb6cf54157fd6) +- 🛠️ **Admin model write permission fix.** The admin panel no longer allows assigning write permission to models, since only admins can modify models and granting write permission to non-admin users had no effect. [Commit](https://github.com/open-webui/open-webui/commit/4aedfdc5471a1f13c1084b34b48ea3ed6311cd42) +- 🛠️ **Prompt access control save fix.** Prompt access control settings are now saved correctly when modifying resource permissions. [Commit](https://github.com/open-webui/open-webui/commit/30f72672fac2579c267a076e6ba89dfe1812137b) +- ✏️ **Knowledge base file edit fix.** Editing files within knowledge bases now saves correctly and can be used for retrieval, fixing a silent failure where the save appeared successful but the file could not be searched. [Commit](https://github.com/open-webui/open-webui/commit/f9ab66f51a52388a4eb084c8f69044e79bf5cb04) +- 🖼️ **Reasoning section artifact rendering fix.** Code blocks within model reasoning sections no longer incorrectly render as interactive artifacts, ensuring only intended output displays as previews. [#20801](https://github.com/open-webui/open-webui/issues/20801), [#20877](https://github.com/open-webui/open-webui/pull/20877), [Commit](https://github.com/open-webui/open-webui/commit/4c6f100b5fe2145a3d676b70b5f7c0e7f07cee20) +- 🔐 **Group resource sharing fix.** Sharing resources with groups now works correctly, fixing database errors and an issue where models shared with read-only access were not visible to group members. 
[#20666](https://github.com/open-webui/open-webui/issues/20666), [#21043](https://github.com/open-webui/open-webui/issues/21043), [Commit](https://github.com/open-webui/open-webui/commit/5a075a2c836e46b83f8710285f09aff1f6125072) +- 🔑 **Docling API key fix.** Docling API key authentication now works correctly by using the proper "X-Api-Key" header format instead of the incorrect "Bearer" authorization prefix. [#20652](https://github.com/open-webui/open-webui/pull/20652) +- 🔌 **MCP OAuth 2.1 fix.** MCP OAuth 2.1 authentication now works correctly, resolving connection verification failures and 401 errors during the authorization callback. [#20808](https://github.com/open-webui/open-webui/issues/20808), [#20828](https://github.com/open-webui/open-webui/issues/20828), [Commit](https://github.com/open-webui/open-webui/commit/8eebc2aea63b7045e61c9689a65a2dfa9c797bcb) +- 💻 **MATLAB syntax highlighting.** MATLAB code blocks now display with proper syntax highlighting in chat messages. [#20719](https://github.com/open-webui/open-webui/issues/20719), [#20773](https://github.com/open-webui/open-webui/pull/20773) +- 📊 **CSV export HTML entity decoding.** Exporting tables to CSV now properly decodes HTML entities, ensuring special characters display correctly in the exported file. [#20688](https://github.com/open-webui/open-webui/pull/20688) +- 📄 **Markdown Header Text Splitter persistence.** The "Markdown Header Text Splitter" document setting now persists correctly when disabled, preventing it from reverting to enabled after page refresh. [#20929](https://github.com/open-webui/open-webui/issues/20929), [#20930](https://github.com/open-webui/open-webui/pull/20930) +- 🔌 **Audio service timeout handling.** Audio transcription and text-to-speech requests now have proper timeouts, preventing the UI from freezing when external services don't respond. 
[#21055](https://github.com/open-webui/open-webui/pull/21055) +- 💬 **Reference Chats visibility fix.** The "Reference Chats" option now appears in the message input menu even when the sidebar is collapsed, fixing the issue where it was hidden on mobile devices and on first load. [#20827](https://github.com/open-webui/open-webui/issues/20827), [Commit](https://github.com/open-webui/open-webui/commit/a3600e8b219fc4c019b95258d16bd3e2827490c6) +- 🔍 **Chat search self-exclusion.** The "search_chats" builtin tool now excludes the current conversation from search results, preventing redundant matches. [#20718](https://github.com/open-webui/open-webui/issues/20718), [Commit](https://github.com/open-webui/open-webui/commit/1a4bdd2b30017d901b9cac1e2e10684ec1edd062) +- 📚 **Knowledge base pagination fix.** Paginating through knowledge base files no longer shows duplicates or skips files when multiple documents share the same update timestamp. [#20846](https://github.com/open-webui/open-webui/issues/20846), [Commit](https://github.com/open-webui/open-webui/commit/a9a0ce6beaa286cc18eff24b518a6f3d7a560e2f) +- 📋 **Batch file error reporting.** Batch file processing operations now return properly structured error information when failures occur, making it clearer what went wrong during multi-file operations. [#20795](https://github.com/open-webui/open-webui/issues/20795), [Commit](https://github.com/open-webui/open-webui/commit/68b2872ed645cffb641fa5a21a784d6e9ea0d72b) +- ⚙️ **Persistent config with Redis fix.** Configuration values now respect the "ENABLE_PERSISTENT_CONFIG" setting when Redis is used, ensuring environment variables are reloaded on restart when persistent config is disabled. 
[#20830](https://github.com/open-webui/open-webui/issues/20830), [Commit](https://github.com/open-webui/open-webui/commit/5d48e48e15b003874cc821d896998a01e87580a0) +- 🔧 **Engine.IO logging fix.** The "WEBSOCKET_SERVER_ENGINEIO_LOGGING" environment variable now works correctly, allowing administrators to configure Engine.IO logging independently from general websocket logging. [#20727](https://github.com/open-webui/open-webui/pull/20727), [Commit](https://github.com/open-webui/open-webui/commit/5cfb7a08cbde5d39aaf4097b849a80da87c30d66) +- 🌐 **French language default fix.** Browsers requesting French language now default to French (France) instead of French (Canada), matching standard language preference expectations. [#20603](https://github.com/open-webui/open-webui/pull/20603), [Commit](https://github.com/open-webui/open-webui/commit/4d9a7cc6c0adea54b58046c576250a0c3ae7b512) +- 🔘 **Firefox delete button fix.** Pressing Enter after clicking delete buttons no longer incorrectly retriggers confirmation modals in Firefox. [Commit](https://github.com/open-webui/open-webui/commit/57a2024c58b9c674f2ae08eeb552994ef1796888) +- 🌍 **RTL table rendering fix.** Chat markdown tables now correctly display right-to-left when containing RTL language content (Arabic, Hebrew, Farsi, etc.), matching the "Auto" direction setting behavior. [#21160](https://github.com/open-webui/open-webui/issues/21160), [Commit](https://github.com/open-webui/open-webui/commit/284b97bd84c824013ad00ea07621192ec69a5e93) +- 🔒 **Write permission enforcement for tools.** Users without write permissions are now properly prevented from editing tools, with a clear error message displayed when attempting unauthorized edits. 
[Commit](https://github.com/open-webui/open-webui/commit/85e92fe3b062ae669985c09495f6ff1baf8176ab), [Commit](https://github.com/open-webui/open-webui/commit/91faa9fd5a1cfc5d3ab531d2d91d28db52bcc702) +- 🛡️ **Chat Valves permission enforcement.** The "Allow Chat Valves" permission is now properly enforced in the integrations menu, preventing users from bypassing access restrictions. [#20691](https://github.com/open-webui/open-webui/pull/20691) +- 📝 **Audit log browser session fix.** Audit logs now properly capture all user activity including browser-based sessions, not just API key requests. [#20651](https://github.com/open-webui/open-webui/issues/20651), [Commit](https://github.com/open-webui/open-webui/commit/86e6b2b68b85e958188881785495030de1a30402), [Commit](https://github.com/open-webui/open-webui/commit/ee5fd1246cb3f8f16ca5cbb24feeea43b7800dcb) +- 🎨 **Long model name truncation.** Long model names and IDs in the admin panel now truncate properly to prevent visual overflow, with full names visible on hover. [#20696](https://github.com/open-webui/open-webui/pull/20696) +- 👥 **Admin user filter pagination fix.** Filtering users in the admin panel now automatically resets to page 1, preventing empty results when searching from pages beyond the first. [#20723](https://github.com/open-webui/open-webui/pull/20723), [Commit](https://github.com/open-webui/open-webui/commit/be75bc506adb048ef11b1612c0e3662511c920d0) +- 🔎 **Username search on workspace pages.** Searching for users by username now works correctly on Models, Knowledge, and Functions workspace pages, making it easier to find resources owned by specific users. [#20780](https://github.com/open-webui/open-webui/pull/20780) +- 🗑️ **File deletion orphaned embeddings fix.** Deleting files now properly removes associated knowledge base embeddings, preventing orphaned data from accumulating. 
[Commit](https://github.com/open-webui/open-webui/commit/93ed4ae2cda2f4311143e51f586aaa73b83a37a7) +- 🧹 **Event listener memory leak fix.** Memory leaks caused by event listeners not being cleaned up during navigation have been resolved. [#20913](https://github.com/open-webui/open-webui/pull/20913) +- 🐳 **Docker Ollama update fix.** Ollama can now be updated within Docker containers after adding a missing zstd dependency. [#20994](https://github.com/open-webui/open-webui/issues/20994), [#21052](https://github.com/open-webui/open-webui/pull/21052) +- 📝 **Workspace duplicate API request fix.** The prompts, knowledge, and models workspaces no longer make duplicate API requests when loading. [Commit](https://github.com/open-webui/open-webui/commit/ab5dfbda54664c9278b0d807ba06cad94edd798f), [Commit](https://github.com/open-webui/open-webui/commit/e5dbfc420dd3e7f6ba047a3e11584449ff0742b4) +- 📡 **OpenTelemetry Redis cluster fix.** OpenTelemetry instrumentation now works correctly with Redis cluster mode deployments. [#21129](https://github.com/open-webui/open-webui/pull/21129) +- 🐳 **Airgapped NLTK tokenizer fix.** Document extraction now works reliably in airgapped environments after container restarts by bundling NLTK tokenizer data in the Docker image. [#21165](https://github.com/open-webui/open-webui/pull/21165), [#21150](https://github.com/open-webui/open-webui/issues/21150) +- 💬 **Channel model mention crash fix.** Mentioning a model in channels no longer crashes when older thread messages have missing data. [#21112](https://github.com/open-webui/open-webui/pull/21112) +- 🔧 **OpenAPI tool import fix.** Importing OpenAPI tool specifications no longer crashes when parameters lack explicit name fields, fixing compatibility with complex request body definitions. 
[#21121](https://github.com/open-webui/open-webui/pull/21121), [Commit](https://github.com/open-webui/open-webui/commit/8e79b3d0bc4903f30e747b663ac818976618c83c) +- 🌐 **Webpage attachment content fix.** Attaching webpages to chats now retrieves full content instead of only metadata, fixing an unawaited coroutine in SSL certificate verification. [#21166](https://github.com/open-webui/open-webui/issues/21166), [Commit](https://github.com/open-webui/open-webui/commit/a214ec40ea00eebcba49570647ca6ab8f61765d5) +- 💾 **File upload settings persistence.** File upload settings (Max Upload Size, Max File Count, Image Compression dimensions) now persist correctly and are no longer erased when updating other RAG configuration settings. [#21057](https://github.com/open-webui/open-webui/issues/21057), [Commit](https://github.com/open-webui/open-webui/commit/258454276e1ef8ded24968515f7bf5e1833ca011) +- 📦 **Tool call expand/collapse fix.** Tool call results in chat can now be expanded and collapsed again after a recent refactor disabled this behavior. [#21205](https://github.com/open-webui/open-webui/pull/21205) +- 🪛 **Disabled API endpoint bypass fix.** Fixed Ollama/OpenAI API endpoints bypassing 'ENABLE_OLLAMA_API' and 'ENABLE_OPENAI_API' flags when the 'url_idx' parameter was provided. Endpoints now properly return a 503 error with a clear "API is disabled" message instead of attempting to connect and logging confusing connection errors. +- 🛠️ **OpenSearch 3.0 compatibility fix.** Document uploads to knowledge bases now work correctly when using OpenSearch backend with opensearch-py >= 3.0.0, fixing a TypeError that previously caused failures. 
[#21248](https://github.com/open-webui/open-webui/pull/21248), [#20649](https://github.com/open-webui/open-webui/issues/20649) +- 📱 **Gboard multi-line paste fix.** Multi-line text pasted from Gboard on Android now inserts correctly instead of being replaced with a single newline, fixing a bug where the keyboard's clipboard suggestion strip sent text via 'insertText' events instead of standard paste events. [#21265](https://github.com/open-webui/open-webui/pull/21265) +- 🔧 **Batch embeddings endpoint fix.** The '/api/embeddings' endpoint now correctly returns separate embeddings for each input string when processing batch requests to Ollama providers. [Commit](https://github.com/open-webui/open-webui/commit/8fd5c06e5bf7e0ccbda15d83338912ea17f66783), [#21279](https://github.com/open-webui/open-webui/issues/21279) +- 🗝️ **SSL verification for embeddings.** SSL certificate verification now respects the "AIOHTTP_CLIENT_SESSION_SSL" setting for OpenAI and Azure OpenAI embedding requests, allowing connections to self-signed certificate endpoints when disabled. [Commit](https://github.com/open-webui/open-webui/commit/cd31b8301b38bfa86872608cfbd022ff74e3ae52) +- 🔧 **Tool call HTML entity fix.** Models now receive properly formatted tool call results in multi-turn conversations, fixing an issue where HTML entities caused malformed content that was hard to parse. [#20755](https://github.com/open-webui/open-webui/pull/20755) +- 💾 **Duplicate inline image context fix.** Inline images no longer exhaust the model's context window by including their full base64 data in chat metadata, preventing premature context exhaustion with image-heavy conversations. [#20916](https://github.com/open-webui/open-webui/pull/20916) +- 🐛 **OpenAI model cache lookup fix.** The OpenAI API router model lookup was corrected to use the proper model identifier when checking the cache, ensuring consistent and correct model retrieval during chat completions. 
[#21327](https://github.com/open-webui/open-webui/pull/21327) +- 🐛 **Ollama latest suffix fix.** Ollama-compatible providers that don't use ":latest" in model names can now successfully chat, fixing errors where model names were incorrectly appended with ":latest" suffixes. [#21331](https://github.com/open-webui/open-webui/issues/21331), [Commit](https://github.com/open-webui/open-webui/commit/05ae44b98dc279ee12cc8eab17278ccbfec60301) +- ⛔ **OpenAI endpoint detection fix.** OpenAI API endpoint detection was corrected to use exact hostname matching instead of substring matching, preventing third-party providers with similar URL patterns from being incorrectly filtered. [Commit](https://github.com/open-webui/open-webui/commit/423d8b18170a0b92b582aba6ef7bb9ba173e876e) +- 🛠️ **RedisCluster task stopping fix.** Task stopping now works correctly in RedisCluster deployments, fixing an issue where tasks would remain active after cancellation attempts. [#20803](https://github.com/open-webui/open-webui/pull/20803), [Commit](https://github.com/open-webui/open-webui/commit/0dcbd05e2436929ae9d2c559a204844ae0239b57) +- 📎 **Citation parsing error fix.** Citation parsing no longer crashes when builtin tools return error responses, fixing AttributeError issues when tools like search_web fail. [#21071](https://github.com/open-webui/open-webui/pull/21071) + +### Changed + +- ‼️ **Database Migration Required** — This release includes database schema changes; multi-worker, multi-server, or load-balanced deployments must update all instances simultaneously rather than performing rolling updates, as running mixed versions will cause application failures due to schema incompatibility between old and new instances. +- ⚠️ **Chat Message Table Migration** — This release includes a new chat message table migration that can take a significant amount of time to complete in larger deployments with extensive chat histories. 
Administrators should plan for adequate maintenance windows and allow the migration to complete fully without interruption. Running the migration with insufficient time or resources may result in data integrity issues. +- 🔗 **Prompt ID-based URLs.** Prompts now use unique ID-based URLs instead of command-based URLs, allowing more flexible command renaming without breaking saved links or integrations. [#20945](https://github.com/open-webui/open-webui/pull/20945) + ## [0.7.2] - 2026-01-10 ### Fixed diff --git a/CHANGELOG_EXTRA.md b/CHANGELOG_EXTRA.md index 081243b4b8..60f1fa2ad3 100644 --- a/CHANGELOG_EXTRA.md +++ b/CHANGELOG_EXTRA.md @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.8.0.2] - 2026.02.13 + +### Fixed + +- 修复模型价格设置不能为小数的问题 + +## [0.8.0.1] - 2026.02.13 + +### Changed + +- 合并官方 0.8.0 改动 + ## [0.7.2.5] - 2026.01.28 ### Changed diff --git a/Dockerfile b/Dockerfile index ca9fe71e73..b3a2b3e1c3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -129,7 +129,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ git build-essential pandoc gcc netcat-openbsd curl jq \ python3-dev \ - ffmpeg libsm6 libxext6 \ + ffmpeg libsm6 libxext6 zstd \ && rm -rf /var/lib/apt/lists/* # install python dependencies @@ -144,6 +144,7 @@ RUN pip3 install --no-cache-dir uv && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ.get('AUXILIARY_EMBEDDING_MODEL', 'TaylorAI/bge-micro-v2'), device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ + python -c 
"import nltk; nltk.download('punkt_tab')"; \ else \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ @@ -152,6 +153,7 @@ RUN pip3 install --no-cache-dir uv && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ.get('AUXILIARY_EMBEDDING_MODEL', 'TaylorAI/bge-micro-v2'), device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ + python -c "import nltk; nltk.download('punkt_tab')"; \ fi; \ fi; \ mkdir -p /app/backend/data && chown -R $UID:$GID /app/backend/data/ && \ diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index e19644da17..8bf052a763 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -2,7 +2,9 @@ import logging import os import shutil +import socket import base64 +from concurrent.futures import ThreadPoolExecutor import redis from datetime import datetime @@ -237,7 +239,7 @@ def __setattr__(self, key, value): self._state[key].value = value self._state[key].save() - if self._redis: + if self._redis and ENABLE_PERSISTENT_CONFIG: redis_key = f"{self._redis_key_prefix}:config:{key}" self._redis.set(redis_key, json.dumps(self._state[key].value)) @@ -245,8 +247,8 @@ def __getattr__(self, key): if key not in self._state: raise AttributeError(f"Config key '{key}' not found") - # If Redis is available, check for an updated value - if self._redis: + # If Redis is available and persistent config is enabled, check for an updated value + if self._redis and ENABLE_PERSISTENT_CONFIG: redis_key = f"{self._redis_key_prefix}:config:{key}" redis_value = self._redis.get(redis_key) @@ -950,6 
+952,40 @@ def feishu_oauth_register(oauth: OAuth): elif K8S_FLAG: OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434" + +def _resolve_ollama_base_url(url: str) -> str: + """If the default Ollama port (11434) is unreachable, try the fallback port (12434).""" + + def reachable(host: str, port: int) -> bool: + try: + with socket.create_connection((host, port), timeout=1.0): + return True + except (OSError, TimeoutError): + return False + + host = urlparse(url).hostname or "localhost" + + with ThreadPoolExecutor(max_workers=2) as pool: + default = pool.submit(reachable, host, 11434) + fallback = pool.submit(reachable, host, 12434) + + if not default.result() and fallback.result(): + url = url.replace(":11434", ":12434") + log.info(f"Ollama port 11434 unreachable on {host}, falling back to 12434") + elif not default.result(): + log.info(f"Ollama ports 11434 and 12434 both unreachable on {host}") + + return url + + +# Auto-resolve Ollama port when no explicit URL was provided by the user. +# The Dockerfile default is "/ollama" which the block above rewrites to :11434. 
+if os.environ.get("OLLAMA_BASE_URL", "") in ("", "/ollama") and not os.environ.get( + "OLLAMA_BASE_URLS", "" +): + OLLAMA_BASE_URL = _resolve_ollama_base_url(OLLAMA_BASE_URL) + + OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "") OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL @@ -1234,6 +1270,11 @@ def feishu_oauth_register(oauth: OAuth): os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true" ) +USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS = ( + os.environ.get("USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS", "False").lower() + == "true" +) + USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT = ( os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT", "False").lower() == "true" @@ -1448,6 +1489,7 @@ def feishu_oauth_register(oauth: OAuth): "knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS, "prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS, "tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS, + "skills": USER_PERMISSIONS_WORKSPACE_SKILLS_ACCESS, "models_import": USER_PERMISSIONS_WORKSPACE_MODELS_IMPORT, "models_export": USER_PERMISSIONS_WORKSPACE_MODELS_EXPORT, "prompts_import": USER_PERMISSIONS_WORKSPACE_PROMPTS_IMPORT, @@ -2195,9 +2237,15 @@ class BannerModel(BaseModel): QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui") WEAVIATE_HTTP_HOST = os.environ.get("WEAVIATE_HTTP_HOST", "") +WEAVIATE_GRPC_HOST = os.environ.get("WEAVIATE_GRPC_HOST", "") WEAVIATE_HTTP_PORT = int(os.environ.get("WEAVIATE_HTTP_PORT", "8080")) WEAVIATE_GRPC_PORT = int(os.environ.get("WEAVIATE_GRPC_PORT", "50051")) WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY") +WEAVIATE_HTTP_SECURE = os.environ.get("WEAVIATE_HTTP_SECURE", "false").lower() == "true" +WEAVIATE_GRPC_SECURE = os.environ.get("WEAVIATE_GRPC_SECURE", "false").lower() == "true" +WEAVIATE_SKIP_INIT_CHECKS = ( + os.environ.get("WEAVIATE_SKIP_INIT_CHECKS", "false").lower() == "true" +) # OpenSearch OPENSEARCH_URI = 
os.environ.get("OPENSEARCH_URI", "https://localhost:9200") @@ -2749,6 +2797,12 @@ class BannerModel(BaseModel): os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true", ) +PDF_LOADER_MODE = PersistentConfig( + "PDF_LOADER_MODE", + "rag.pdf_loader_mode", + os.environ.get("PDF_LOADER_MODE", "page"), +) + RAG_EMBEDDING_MODEL = PersistentConfig( "RAG_EMBEDDING_MODEL", "rag.embedding_model", @@ -3330,6 +3384,24 @@ class BannerModel(BaseModel): os.environ.get("EXTERNAL_WEB_LOADER_API_KEY", ""), ) +YANDEX_WEB_SEARCH_URL = PersistentConfig( + "YANDEX_WEB_SEARCH_URL", + "rag.web.search.yandex_web_search_url", + os.environ.get("YANDEX_WEB_SEARCH_URL", ""), +) + +YANDEX_WEB_SEARCH_API_KEY = PersistentConfig( + "YANDEX_WEB_SEARCH_API_KEY", + "rag.web.search.yandex_web_search_api_key", + os.environ.get("YANDEX_WEB_SEARCH_API_KEY", ""), +) + +YANDEX_WEB_SEARCH_CONFIG = PersistentConfig( + "YANDEX_WEB_SEARCH_CONFIG", + "rag.web.search.yandex_web_search_config", + os.environ.get("YANDEX_WEB_SEARCH_CONFIG", ""), +) + #################################### # Images #################################### @@ -3352,6 +3424,16 @@ class BannerModel(BaseModel): os.getenv("IMAGE_GENERATION_MODEL", ""), ) +# Regex pattern for models that support IMAGE_SIZE = "auto". +IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN = os.getenv( + "IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN", "^gpt-image" +) + +# Regex pattern for models that return URLs instead of base64 data. 
+IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN = os.getenv( + "IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN", "^gpt-image" +) + IMAGE_SIZE = PersistentConfig( "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512") ) diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 91be977371..f589073e8a 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -196,11 +196,35 @@ def parse_section(section): os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" ) +# Header names for user info forwarding (customizable via environment variables) +FORWARD_USER_INFO_HEADER_USER_NAME = os.environ.get( + "FORWARD_USER_INFO_HEADER_USER_NAME", "X-OpenWebUI-User-Name" +) +FORWARD_USER_INFO_HEADER_USER_ID = os.environ.get( + "FORWARD_USER_INFO_HEADER_USER_ID", "X-OpenWebUI-User-Id" +) +FORWARD_USER_INFO_HEADER_USER_EMAIL = os.environ.get( + "FORWARD_USER_INFO_HEADER_USER_EMAIL", "X-OpenWebUI-User-Email" +) +FORWARD_USER_INFO_HEADER_USER_ROLE = os.environ.get( + "FORWARD_USER_INFO_HEADER_USER_ROLE", "X-OpenWebUI-User-Role" +) + +# Header name for chat ID forwarding (customizable via environment variable) +FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.environ.get( + "FORWARD_SESSION_INFO_HEADER_MESSAGE_ID", "X-OpenWebUI-Message-Id" +) +FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.environ.get( + "FORWARD_SESSION_INFO_HEADER_CHAT_ID", "X-OpenWebUI-Chat-Id" +) + # Experimental feature, may be removed in future ENABLE_STAR_SESSIONS_MIDDLEWARE = ( os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true" ) +ENABLE_EASTER_EGGS = os.environ.get("ENABLE_EASTER_EGGS", "True").lower() == "true" + #################################### # WEBUI_BUILD_HASH #################################### @@ -393,6 +417,18 @@ def parse_section(section): except ValueError: REDIS_SOCKET_CONNECT_TIMEOUT = None +REDIS_RECONNECT_DELAY = os.environ.get("REDIS_RECONNECT_DELAY", "") + +if REDIS_RECONNECT_DELAY == "": + REDIS_RECONNECT_DELAY = None 
+else: + try: + REDIS_RECONNECT_DELAY = float(REDIS_RECONNECT_DELAY) + if REDIS_RECONNECT_DELAY < 0: + REDIS_RECONNECT_DELAY = None + except Exception: + REDIS_RECONNECT_DELAY = None + #################################### # UVICORN WORKERS #################################### @@ -457,6 +493,8 @@ def parse_section(section): r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$" ) +PASSWORD_VALIDATION_HINT = os.environ.get("PASSWORD_VALIDATION_HINT", "") + BYPASS_MODEL_ACCESS_CONTROL = ( os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" @@ -521,6 +559,12 @@ def parse_section(section): "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY ) +# Token Exchange Configuration +# Allows external apps to exchange OAuth tokens for OpenWebUI tokens +ENABLE_OAUTH_TOKEN_EXCHANGE = ( + os.environ.get("ENABLE_OAUTH_TOKEN_EXCHANGE", "False").lower() == "true" +) + #################################### # SCIM Configuration #################################### @@ -547,15 +591,11 @@ def parse_section(section): pk = None if LICENSE_PUBLIC_KEY: - pk = serialization.load_pem_public_key( - f""" + pk = serialization.load_pem_public_key(f""" -----BEGIN PUBLIC KEY----- {LICENSE_PUBLIC_KEY} -----END PUBLIC KEY----- -""".encode( - "utf-8" - ) - ) +""".encode("utf-8")) #################################### @@ -675,7 +715,11 @@ def parse_section(section): os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true" ) WEBSOCKET_SERVER_ENGINEIO_LOGGING = ( - os.environ.get("WEBSOCKET_SERVER_LOGGING", "False").lower() == "true" + os.environ.get( + "WEBSOCKET_SERVER_ENGINEIO_LOGGING", + os.environ.get("WEBSOCKET_SERVER_LOGGING", "False"), + ).lower() + == "true" ) WEBSOCKET_SERVER_PING_TIMEOUT = os.environ.get("WEBSOCKET_SERVER_PING_TIMEOUT", "20") try: @@ -745,6 +789,17 @@ def parse_section(section): ) +RAG_EMBEDDING_TIMEOUT = os.environ.get("RAG_EMBEDDING_TIMEOUT", "") + +if RAG_EMBEDDING_TIMEOUT == "": + RAG_EMBEDDING_TIMEOUT = None +else: + try: + 
RAG_EMBEDDING_TIMEOUT = int(RAG_EMBEDDING_TIMEOUT) + except Exception: + RAG_EMBEDDING_TIMEOUT = None + + #################################### # SENTENCE TRANSFORMERS #################################### diff --git a/backend/open_webui/internal/migrations/001_initial_schema.py b/backend/open_webui/internal/migrations/001_initial_schema.py index 93f278f15b..0df2249b21 100644 --- a/backend/open_webui/internal/migrations/001_initial_schema.py +++ b/backend/open_webui/internal/migrations/001_initial_schema.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/002_add_local_sharing.py b/backend/open_webui/internal/migrations/002_add_local_sharing.py index e93501aeec..a01862d103 100644 --- a/backend/open_webui/internal/migrations/002_add_local_sharing.py +++ b/backend/open_webui/internal/migrations/002_add_local_sharing.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/003_add_auth_api_key.py b/backend/open_webui/internal/migrations/003_add_auth_api_key.py index 07144f3aca..23cba26383 100644 --- a/backend/open_webui/internal/migrations/003_add_auth_api_key.py +++ b/backend/open_webui/internal/migrations/003_add_auth_api_key.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/004_add_archived.py b/backend/open_webui/internal/migrations/004_add_archived.py index d01c06b4e6..11108a3e0b 100644 --- a/backend/open_webui/internal/migrations/004_add_archived.py +++ b/backend/open_webui/internal/migrations/004_add_archived.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import 
playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/005_add_updated_at.py b/backend/open_webui/internal/migrations/005_add_updated_at.py index 950866ef02..f7fc69a5db 100644 --- a/backend/open_webui/internal/migrations/005_add_updated_at.py +++ b/backend/open_webui/internal/migrations/005_add_updated_at.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py index caca14d323..abe7016c57 100644 --- a/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py +++ b/backend/open_webui/internal/migrations/006_migrate_timestamps_and_charfields.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/007_add_user_last_active_at.py b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py index dd176ba73e..3f89a5f59f 100644 --- a/backend/open_webui/internal/migrations/007_add_user_last_active_at.py +++ b/backend/open_webui/internal/migrations/007_add_user_last_active_at.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/008_add_memory.py b/backend/open_webui/internal/migrations/008_add_memory.py index 9307aa4d5c..96be907eba 100644 --- a/backend/open_webui/internal/migrations/008_add_memory.py +++ b/backend/open_webui/internal/migrations/008_add_memory.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git 
a/backend/open_webui/internal/migrations/009_add_models.py b/backend/open_webui/internal/migrations/009_add_models.py index 548ec7cdca..0a8d73bd3b 100644 --- a/backend/open_webui/internal/migrations/009_add_models.py +++ b/backend/open_webui/internal/migrations/009_add_models.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/011_add_user_settings.py b/backend/open_webui/internal/migrations/011_add_user_settings.py index a1620dcada..c3b9ab6edc 100644 --- a/backend/open_webui/internal/migrations/011_add_user_settings.py +++ b/backend/open_webui/internal/migrations/011_add_user_settings.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/012_add_tools.py b/backend/open_webui/internal/migrations/012_add_tools.py index 4a68eea552..ac3cd8bfec 100644 --- a/backend/open_webui/internal/migrations/012_add_tools.py +++ b/backend/open_webui/internal/migrations/012_add_tools.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/013_add_user_info.py b/backend/open_webui/internal/migrations/013_add_user_info.py index 0f68669cca..6fafa951f0 100644 --- a/backend/open_webui/internal/migrations/013_add_user_info.py +++ b/backend/open_webui/internal/migrations/013_add_user_info.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/014_add_files.py b/backend/open_webui/internal/migrations/014_add_files.py index 5e1acf0ad8..655b00d238 100644 --- a/backend/open_webui/internal/migrations/014_add_files.py +++ 
b/backend/open_webui/internal/migrations/014_add_files.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/015_add_functions.py b/backend/open_webui/internal/migrations/015_add_functions.py index 8316a9333b..84d2843839 100644 --- a/backend/open_webui/internal/migrations/015_add_functions.py +++ b/backend/open_webui/internal/migrations/015_add_functions.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py index e3af521b7e..fadf964e46 100644 --- a/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py +++ b/backend/open_webui/internal/migrations/016_add_valves_and_is_active.py @@ -29,7 +29,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py index eaa3fa5fe5..67a36b4889 100644 --- a/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py +++ b/backend/open_webui/internal/migrations/017_add_user_oauth_sub.py @@ -25,7 +25,6 @@ import peewee as pw from peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/internal/migrations/018_add_function_is_global.py b/backend/open_webui/internal/migrations/018_add_function_is_global.py index 04cdab7059..1e932ed710 100644 --- a/backend/open_webui/internal/migrations/018_add_function_is_global.py +++ b/backend/open_webui/internal/migrations/018_add_function_is_global.py @@ -29,7 +29,6 @@ import peewee as pw from 
peewee_migrate import Migrator - with suppress(ImportError): import playhouse.postgres_ext as pw_pext diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a33e765ff3..e6523b3db9 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -61,6 +61,7 @@ get_models_in_use, ) from open_webui.routers import ( + analytics, audio, images, ollama, @@ -82,6 +83,7 @@ knowledge, prompts, evaluations, + skills, tools, users, utils, @@ -278,6 +280,7 @@ ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, TIKTOKEN_ENCODING_NAME, PDF_EXTRACT_IMAGES, + PDF_LOADER_MODE, YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_PROXY_URL, # Retrieval (Web Search) @@ -342,6 +345,9 @@ EXTERNAL_WEB_SEARCH_API_KEY, EXTERNAL_WEB_LOADER_URL, EXTERNAL_WEB_LOADER_API_KEY, + YANDEX_WEB_SEARCH_URL, + YANDEX_WEB_SEARCH_API_KEY, + YANDEX_WEB_SEARCH_CONFIG, # WebUI WEBUI_AUTH, WEBUI_NAME, @@ -514,6 +520,7 @@ WEBUI_ADMIN_EMAIL, WEBUI_ADMIN_PASSWORD, WEBUI_ADMIN_NAME, + ENABLE_EASTER_EGGS, ) from open_webui.utils.models import ( @@ -525,11 +532,14 @@ from open_webui.utils.chat import ( generate_chat_completion as chat_completion_handler, chat_completed as chat_completed_handler, - chat_action as chat_action_handler, ) +from open_webui.utils.actions import chat_action as chat_action_handler from open_webui.utils.embeddings import generate_embeddings -from open_webui.utils.middleware import process_chat_payload, process_chat_response -from open_webui.utils.access_control import has_access +from open_webui.utils.middleware import ( + build_chat_response_context, + process_chat_payload, + process_chat_response, +) from open_webui.utils.auth import ( get_license_data, @@ -564,7 +574,6 @@ from open_webui.constants import ERROR_MESSAGES - if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() @@ -588,8 +597,7 @@ async def get_response(self, path: str, scope): raise ex -print( - rf""" +print(rf""" ██████╗ ██████╗ ███████╗███╗ ██╗ ██╗ ██╗███████╗██████╗ ██╗ ██╗██╗ 
██╔═══██╗██╔══██╗██╔════╝████╗ ██║ ██║ ██║██╔════╝██╔══██╗██║ ██║██║ ██║ ██║██████╔╝█████╗ ██╔██╗ ██║ ██║ █╗ ██║█████╗ ██████╔╝██║ ██║██║ @@ -601,12 +609,15 @@ async def get_response(self, path: str, scope): v{VERSION} - building the best AI user interface. {f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""} https://github.com/open-webui/open-webui -""" -) +""") @asynccontextmanager async def lifespan(app: FastAPI): + # Store reference to main event loop for sync->async calls (e.g., embedding generation) + # This allows sync functions to schedule work on the main loop without blocking health checks + app.state.main_loop = asyncio.get_running_loop() + app.state.instance_id = INSTANCE_ID start_logger() @@ -834,6 +845,21 @@ async def lifespan(app: FastAPI): app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS +# Migrate legacy access_control → access_grants on boot +from open_webui.utils.access_control import migrate_access_control + +connections = app.state.config.TOOL_SERVER_CONNECTIONS +if any("access_control" in c.get("config", {}) for c in connections): + for connection in connections: + migrate_access_control(connection.get("config", {})) + app.state.config.TOOL_SERVER_CONNECTIONS = connections + +arena_models = app.state.config.EVALUATION_ARENA_MODELS +if any("access_control" in m.get("meta", {}) for m in arena_models): + for model in arena_models: + migrate_access_control(model.get("meta", {})) + app.state.config.EVALUATION_ARENA_MODELS = arena_models + app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM @@ -971,6 +997,7 @@ async def lifespan(app: FastAPI): app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES +app.state.config.PDF_LOADER_MODE = PDF_LOADER_MODE 
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL @@ -1031,6 +1058,9 @@ async def lifespan(app: FastAPI): app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY +app.state.config.YANDEX_WEB_SEARCH_URL = YANDEX_WEB_SEARCH_URL +app.state.config.YANDEX_WEB_SEARCH_API_KEY = YANDEX_WEB_SEARCH_API_KEY +app.state.config.YANDEX_WEB_SEARCH_CONFIG = YANDEX_WEB_SEARCH_CONFIG app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT @@ -1374,7 +1404,9 @@ async def dispatch(self, request: Request, call_next): token = None if auth_header: - scheme, token = auth_header.split(" ") + parts = auth_header.split(" ", 1) + if len(parts) == 2: + token = parts[1] # Only apply restrictions if an sk- API key is used if token and token.startswith("sk-"): @@ -1415,7 +1447,13 @@ async def dispatch(self, request: Request, call_next): async def commit_session_after_request(request: Request, call_next): response = await call_next(request) # log.debug("Commit session after request") - ScopedSession.commit() + try: + ScopedSession.commit() + finally: + # CRITICAL: remove() returns the connection to the pool. + # Without this, connections remain "checked out" and accumulate + # as "idle in transaction" in PostgreSQL. 
+ ScopedSession.remove() return response @@ -1425,6 +1463,13 @@ async def check_url(request: Request, call_next): request.state.token = get_http_authorization_cred( request.headers.get("Authorization") ) + # Fallback to cookie token for browser sessions + if request.state.token is None and request.cookies.get("token"): + from fastapi.security import HTTPAuthorizationCredentials + + request.state.token = HTTPAuthorizationCredentials( + scheme="Bearer", credentials=request.cookies.get("token") + ) request.state.enable_api_keys = app.state.config.ENABLE_API_KEYS response = await call_next(request) @@ -1487,6 +1532,7 @@ async def inspect_websocket(request: Request, call_next): app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"]) app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"]) +app.include_router(skills.router, prefix="/api/v1/skills", tags=["skills"]) app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"]) app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"]) @@ -1496,6 +1542,7 @@ async def inspect_websocket(request: Request, call_next): app.include_router( evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] ) +app.include_router(analytics.router, prefix="/api/v1/analytics", tags=["analytics"]) app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) # SCIM 2.0 API for identity management @@ -1802,9 +1849,11 @@ async def process_chat(request, form_data, user, metadata, model): except: pass - return await process_chat_response( - request, response, form_data, user, metadata, model, events, tasks + ctx = build_chat_response_context( + request, form_data, user, model, metadata, tasks, events ) + + return await process_chat_response(response, ctx) except asyncio.CancelledError: log.info("Chat processing was cancelled") try: @@ -1854,6 +1903,16 @@ async 
def process_chat(request, form_data, user, metadata, model): except Exception as e: log.debug(f"Error cleaning up: {e}") pass + # Emit chat:active=false when task completes + try: + if metadata.get("chat_id"): + event_emitter = get_event_emitter(metadata, update_db=False) + if event_emitter: + await event_emitter( + {"type": "chat:active", "data": {"active": False}} + ) + except Exception as e: + log.debug(f"Error emitting chat:active: {e}") if ( metadata.get("session_id") @@ -1866,6 +1925,10 @@ async def process_chat(request, form_data, user, metadata, model): process_chat(request, form_data, user, metadata, model), id=metadata["chat_id"], ) + # Emit chat:active=true when task starts + event_emitter = get_event_emitter(metadata, update_db=False) + if event_emitter: + await event_emitter({"type": "chat:active", "data": {"active": True}}) return {"status": True, "task_id": task_id} else: return await process_chat(request, form_data, user, metadata, model) @@ -2007,6 +2070,7 @@ async def get_app_config(request: Request): "enable_websocket": ENABLE_WEBSOCKET_SUPPORT, "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK, "enable_public_active_users_count": ENABLE_PUBLIC_ACTIVE_USERS_COUNT, + "enable_easter_eggs": ENABLE_EASTER_EGGS, **( { "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, diff --git a/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py b/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py index 2d72583ebe..1a4ae73180 100644 --- a/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py +++ b/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py @@ -12,7 +12,6 @@ import sqlalchemy as sa import open_webui.internal.db - # revision identifiers, used by Alembic. 
revision: str = "2f1211949ecc" down_revision: Union[str, None] = "37f288994c47" diff --git a/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py b/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py new file mode 100644 index 0000000000..57bc8748e3 --- /dev/null +++ b/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py @@ -0,0 +1,247 @@ +"""Add prompt history table + +Revision ID: 374d2f66af06 +Revises: c440947495f3 +Create Date: 2026-01-23 17:15:00.000000 + +""" + +from typing import Sequence, Union +import uuid + +from alembic import op +import sqlalchemy as sa + +revision: str = "374d2f66af06" +down_revision: Union[str, None] = "c440947495f3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + + # Step 1: Read existing data from OLD table (schema likely command as PK) + # We use batch_alter previously, but we want to move to new table. + # We need to assume the OLD structure. 
+ + old_prompt_table = sa.table( + "prompt", + sa.column("command", sa.Text()), + sa.column("user_id", sa.Text()), + sa.column("title", sa.Text()), + sa.column("content", sa.Text()), + sa.column("timestamp", sa.BigInteger()), + sa.column("access_control", sa.JSON()), + ) + + # Check if table exists/read data + try: + existing_prompts = conn.execute( + sa.select( + old_prompt_table.c.command, + old_prompt_table.c.user_id, + old_prompt_table.c.title, + old_prompt_table.c.content, + old_prompt_table.c.timestamp, + old_prompt_table.c.access_control, + ) + ).fetchall() + except Exception: + # Fallback if table doesn't exist (new install) + existing_prompts = [] + + # Step 2: Create new prompt table with 'id' as PRIMARY KEY + op.create_table( + "prompt_new", + sa.Column("id", sa.Text(), primary_key=True), + sa.Column("command", sa.String(), unique=True, index=True), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("access_control", sa.JSON(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"), + sa.Column("version_id", sa.Text(), nullable=True), + sa.Column("tags", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + ) + + # Step 3: Create prompt_history table + op.create_table( + "prompt_history", + sa.Column("id", sa.Text(), primary_key=True), + sa.Column("prompt_id", sa.Text(), nullable=False, index=True), + sa.Column("parent_id", sa.Text(), nullable=True), + sa.Column("snapshot", sa.JSON(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("commit_message", sa.Text(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=False), + ) + + # Step 4: Migrate data + prompt_new_table = 
sa.table( + "prompt_new", + sa.column("id", sa.Text()), + sa.column("command", sa.String()), + sa.column("user_id", sa.String()), + sa.column("name", sa.Text()), + sa.column("content", sa.Text()), + sa.column("data", sa.JSON()), + sa.column("meta", sa.JSON()), + sa.column("access_control", sa.JSON()), + sa.column("is_active", sa.Boolean()), + sa.column("version_id", sa.Text()), + sa.column("tags", sa.JSON()), + sa.column("created_at", sa.BigInteger()), + sa.column("updated_at", sa.BigInteger()), + ) + + prompt_history_table = sa.table( + "prompt_history", + sa.column("id", sa.Text()), + sa.column("prompt_id", sa.Text()), + sa.column("parent_id", sa.Text()), + sa.column("snapshot", sa.JSON()), + sa.column("user_id", sa.Text()), + sa.column("commit_message", sa.Text()), + sa.column("created_at", sa.BigInteger()), + ) + + for row in existing_prompts: + command = row[0] + user_id = row[1] + title = row[2] + content = row[3] + timestamp = row[4] + access_control = row[5] + + new_uuid = str(uuid.uuid4()) + history_uuid = str(uuid.uuid4()) + clean_command = command[1:] if command and command.startswith("/") else command + + # Insert into prompt_new + conn.execute( + sa.insert(prompt_new_table).values( + id=new_uuid, + command=clean_command, + user_id=user_id, + name=title, + content=content, + data={}, + meta={}, + access_control=access_control, + is_active=True, + version_id=history_uuid, + tags=[], + created_at=timestamp, + updated_at=timestamp, + ) + ) + + # Create initial history entry + conn.execute( + sa.insert(prompt_history_table).values( + id=history_uuid, + prompt_id=new_uuid, + parent_id=None, + snapshot={ + "name": title, + "content": content, + "command": clean_command, + "data": {}, + "meta": {}, + "access_control": access_control, + }, + user_id=user_id, + commit_message=None, + created_at=timestamp, + ) + ) + + # Step 5: Replace old table with new one + op.drop_table("prompt") + op.rename_table("prompt_new", "prompt") + + +def downgrade() -> None: + conn = 
op.get_bind() + + # Step 1: Read new data + prompt_table = sa.table( + "prompt", + sa.column("command", sa.String()), + sa.column("name", sa.Text()), + sa.column("created_at", sa.BigInteger()), + sa.column("user_id", sa.Text()), + sa.column("content", sa.Text()), + sa.column("access_control", sa.JSON()), + ) + + try: + current_data = conn.execute( + sa.select( + prompt_table.c.command, + prompt_table.c.name, + prompt_table.c.created_at, + prompt_table.c.user_id, + prompt_table.c.content, + prompt_table.c.access_control, + ) + ).fetchall() + except Exception: + current_data = [] + + # Step 2: Drop history and table + op.drop_table("prompt_history") + op.drop_table("prompt") + + # Step 3: Recreate old table (command as PK?) + # Assuming old schema: + op.create_table( + "prompt", + sa.Column("command", sa.String(), primary_key=True), + sa.Column("user_id", sa.String()), + sa.Column("title", sa.Text()), + sa.Column("content", sa.Text()), + sa.Column("timestamp", sa.BigInteger()), + sa.Column("access_control", sa.JSON()), + sa.Column("id", sa.Integer(), nullable=True), + ) + + # Step 4: Restore data + old_prompt_table = sa.table( + "prompt", + sa.column("command", sa.String()), + sa.column("user_id", sa.String()), + sa.column("title", sa.Text()), + sa.column("content", sa.Text()), + sa.column("timestamp", sa.BigInteger()), + sa.column("access_control", sa.JSON()), + ) + + for row in current_data: + command = row[0] + name = row[1] + created_at = row[2] + user_id = row[3] + content = row[4] + access_control = row[5] + + # Restore leading / + old_command = ( + "/" + command if command and not command.startswith("/") else command + ) + + conn.execute( + sa.insert(old_prompt_table).values( + command=old_command, + user_id=user_id, + title=name, + content=content, + timestamp=created_at, + access_control=access_control, + ) + ) diff --git a/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py 
b/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py index 0c5cec1941..229bb8cffb 100644 --- a/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py +++ b/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py @@ -14,7 +14,6 @@ from alembic import op import sqlalchemy as sa - # revision identifiers, used by Alembic. revision: str = "37f288994c47" down_revision: Union[str, None] = "a5c220713937" diff --git a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py index 264ce13b41..af8340a3cb 100644 --- a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py +++ b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py @@ -11,7 +11,6 @@ from alembic import op import sqlalchemy as sa - # revision identifiers, used by Alembic. revision: str = "38d63c18f30f" down_revision: Union[str, None] = "3af16a1c9fb6" diff --git a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py index 59fe57a421..f3ef62fd64 100644 --- a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py +++ b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py @@ -12,7 +12,6 @@ import sqlalchemy as sa import open_webui.internal.db - # revision identifiers, used by Alembic. 
revision: str = "6283dc0e4d8d" down_revision: Union[str, None] = "3e0e00844bb0" diff --git a/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py b/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py index 881e6ae641..d6083d7177 100644 --- a/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py +++ b/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py @@ -11,7 +11,6 @@ from sqlalchemy.sql import table, column, select import json - revision = "6a39f3d8e55c" down_revision = "c0fbf31ca0db" branch_labels = None diff --git a/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py b/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py index 181b280666..3853ec50d9 100644 --- a/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py +++ b/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py @@ -12,7 +12,6 @@ import sqlalchemy as sa import open_webui.internal.db - # revision identifiers, used by Alembic. 
revision: str = "81cc2ce44d79" down_revision: Union[str, None] = "6283dc0e4d8d" diff --git a/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py b/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py new file mode 100644 index 0000000000..c8a3647aec --- /dev/null +++ b/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py @@ -0,0 +1,177 @@ +"""Add chat_message table + +Revision ID: 8452d01d26d7 +Revises: 374d2f66af06 +Create Date: 2026-02-01 04:00:00.000000 + +""" + +import time +import json +import logging +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +log = logging.getLogger(__name__) + +revision: str = "8452d01d26d7" +down_revision: Union[str, None] = "374d2f66af06" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Step 1: Create table + op.create_table( + "chat_message", + sa.Column("id", sa.Text(), primary_key=True), + sa.Column("chat_id", sa.Text(), nullable=False, index=True), + sa.Column("user_id", sa.Text(), index=True), + sa.Column("role", sa.Text(), nullable=False), + sa.Column("parent_id", sa.Text(), nullable=True), + sa.Column("content", sa.JSON(), nullable=True), + sa.Column("output", sa.JSON(), nullable=True), + sa.Column("model_id", sa.Text(), nullable=True, index=True), + sa.Column("files", sa.JSON(), nullable=True), + sa.Column("sources", sa.JSON(), nullable=True), + sa.Column("embeds", sa.JSON(), nullable=True), + sa.Column("done", sa.Boolean(), default=True), + sa.Column("status_history", sa.JSON(), nullable=True), + sa.Column("error", sa.JSON(), nullable=True), + sa.Column("usage", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), index=True), + sa.Column("updated_at", sa.BigInteger()), + sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"), + ) + + # Create composite indexes + op.create_index( + 
"chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"] + ) + op.create_index( + "chat_message_model_created_idx", "chat_message", ["model_id", "created_at"] + ) + op.create_index( + "chat_message_user_created_idx", "chat_message", ["user_id", "created_at"] + ) + + # Step 2: Backfill from existing chats + conn = op.get_bind() + + chat_table = sa.table( + "chat", + sa.column("id", sa.Text()), + sa.column("user_id", sa.Text()), + sa.column("chat", sa.JSON()), + ) + + chat_message_table = sa.table( + "chat_message", + sa.column("id", sa.Text()), + sa.column("chat_id", sa.Text()), + sa.column("user_id", sa.Text()), + sa.column("role", sa.Text()), + sa.column("parent_id", sa.Text()), + sa.column("content", sa.JSON()), + sa.column("output", sa.JSON()), + sa.column("model_id", sa.Text()), + sa.column("files", sa.JSON()), + sa.column("sources", sa.JSON()), + sa.column("embeds", sa.JSON()), + sa.column("done", sa.Boolean()), + sa.column("status_history", sa.JSON()), + sa.column("error", sa.JSON()), + sa.column("usage", sa.JSON()), + sa.column("created_at", sa.BigInteger()), + sa.column("updated_at", sa.BigInteger()), + ) + + # Fetch all chats (excluding shared chats which have user_id starting with 'shared-') + chats = conn.execute( + sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat).where( + ~chat_table.c.user_id.like("shared-%") + ) + ).fetchall() + + now = int(time.time()) + messages_inserted = 0 + messages_failed = 0 + + for chat_row in chats: + chat_id = chat_row[0] + user_id = chat_row[1] + chat_data = chat_row[2] + + if not chat_data: + continue + + # Handle both string and dict chat data + if isinstance(chat_data, str): + try: + chat_data = json.loads(chat_data) + except Exception: + continue + + history = chat_data.get("history", {}) + messages = history.get("messages", {}) + + for message_id, message in messages.items(): + if not isinstance(message, dict): + continue + + role = message.get("role") + if not role: + continue + + 
timestamp = message.get("timestamp", now) + + # Normalize timestamp: convert ms to seconds, validate range + if timestamp > 10_000_000_000: + timestamp = timestamp // 1000 + # Must be after 2020 and not too far in the future + if timestamp < 1577836800 or timestamp > now + 86400: + timestamp = now + + # Use savepoint to allow individual insert failures without aborting transaction + savepoint = conn.begin_nested() + try: + conn.execute( + sa.insert(chat_message_table).values( + id=f"{chat_id}-{message_id}", + chat_id=chat_id, + user_id=user_id, + role=role, + parent_id=message.get("parentId"), + content=message.get("content"), + output=message.get("output"), + model_id=message.get("model"), + files=message.get("files"), + sources=message.get("sources"), + embeds=message.get("embeds"), + done=message.get("done", True), + status_history=message.get("statusHistory"), + error=message.get("error"), + created_at=timestamp, + updated_at=timestamp, + ) + ) + savepoint.commit() + messages_inserted += 1 + except Exception as e: + savepoint.rollback() + messages_failed += 1 + log.warning(f"Failed to insert message {message_id}: {e}") + continue + + log.info( + f"Backfilled {messages_inserted} messages into chat_message table ({messages_failed} failed)" + ) + + +def downgrade() -> None: + op.drop_index("chat_message_user_created_idx", table_name="chat_message") + op.drop_index("chat_message_model_created_idx", table_name="chat_message") + op.drop_index("chat_message_chat_parent_idx", table_name="chat_message") + op.drop_table("chat_message") diff --git a/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py b/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py index 8c52a4b22a..8b9e338309 100644 --- a/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py +++ 
b/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py @@ -12,7 +12,6 @@ import sqlalchemy as sa import open_webui.internal.db - # revision identifiers, used by Alembic. revision: str = "90ef40d4714e" down_revision: Union[str, None] = "b10670c03dd5" diff --git a/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py b/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py new file mode 100644 index 0000000000..26e9e66240 --- /dev/null +++ b/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py @@ -0,0 +1,45 @@ +"""Add skill table + +Revision ID: a1b2c3d4e5f6 +Revises: f1e2d3c4b5a6 +Create Date: 2026-02-11 09:30:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from open_webui.migrations.util import get_existing_tables + +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, None] = "f1e2d3c4b5a6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + existing_tables = set(get_existing_tables()) + + if "skill" not in existing_tables: + op.create_table( + "skill", + sa.Column("id", sa.String(), nullable=False, primary_key=True), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("name", sa.Text(), nullable=False, unique=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.Column("created_at", sa.BigInteger(), nullable=False), + ) + op.create_index("idx_skill_user_id", "skill", ["user_id"]) + op.create_index("idx_skill_updated_at", "skill", ["updated_at"]) + + +def downgrade() -> None: + op.drop_index("idx_skill_updated_at", table_name="skill") + op.drop_index("idx_skill_user_id", 
table_name="skill") + op.drop_table("skill") diff --git a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py index f35a382645..0472c08616 100644 --- a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py +++ b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py @@ -173,12 +173,10 @@ def upgrade() -> None: for uid, api_key in users_with_keys: if api_key: conn.execute( - sa.text( - """ + sa.text(""" INSERT INTO api_key (id, user_id, key, created_at, updated_at) VALUES (:id, :user_id, :key, :created_at, :updated_at) - """ - ), + """), { "id": f"key_{uid}", "user_id": uid, diff --git a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py index de82854b88..7786de425f 100644 --- a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py +++ b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py @@ -12,7 +12,6 @@ from sqlalchemy.sql import table, column from sqlalchemy import String, Text, JSON, and_ - revision = "c29facfe716b" down_revision = "c69f45358db4" branch_labels = None diff --git a/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py b/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py index 20f4a6d7b6..fa818e1f08 100644 --- a/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py +++ b/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py @@ -11,7 +11,6 @@ from alembic import op import sqlalchemy as sa - # revision identifiers, used by Alembic. 
"""Add access_grant table

Revision ID: f1e2d3c4b5a6
Revises: 8452d01d26d7
Create Date: 2026-02-05 10:00:00.000000

Migrates from JSON access_control columns to normalized access_grant table.
Access control semantics:
- NULL: Public access (all users can read) -> insert user:* for read
- {}: Private/owner-only (no grants) -> insert nothing
- {read: {...}, write: {...}}: Custom permissions -> insert specific grants
"""

from typing import Sequence, Union
import json
import time
import uuid

from alembic import op
import sqlalchemy as sa

from open_webui.migrations.util import get_existing_tables

revision: str = "f1e2d3c4b5a6"
down_revision: Union[str, None] = "8452d01d26d7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _insert_grant(
    conn,
    inserted: set,
    resource_type: str,
    resource_id: str,
    principal_type: str,
    principal_id: str,
    permission: str,
    created_at: int,
) -> None:
    """Best-effort insert of one access_grant row, de-duplicated via `inserted`.

    The key is only recorded in `inserted` after a successful INSERT, matching
    the original backfill behavior. Failures (e.g. duplicates racing the unique
    constraint) are swallowed so the migration never aborts mid-backfill.
    """
    key = (resource_type, resource_id, principal_type, principal_id, permission)
    if key in inserted:
        return
    try:
        conn.execute(
            sa.text(
                """
                INSERT INTO access_grant (id, resource_type, resource_id, principal_type, principal_id, permission, created_at)
                VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
                """
            ),
            {
                "id": str(uuid.uuid4()),
                "resource_type": resource_type,
                "resource_id": resource_id,
                "principal_type": principal_type,
                "principal_id": principal_id,
                "permission": permission,
                "created_at": created_at,
            },
        )
        inserted.add(key)
    except Exception:
        # Best-effort backfill: skip rows that cannot be inserted.
        pass


def upgrade() -> None:
    """Create access_grant, backfill it from the JSON columns, drop the columns."""
    existing_tables = set(get_existing_tables())

    # Create access_grant table (idempotent: skipped if it already exists).
    if "access_grant" not in existing_tables:
        op.create_table(
            "access_grant",
            sa.Column("id", sa.Text(), nullable=False, primary_key=True),
            sa.Column("resource_type", sa.Text(), nullable=False),
            sa.Column("resource_id", sa.Text(), nullable=False),
            sa.Column("principal_type", sa.Text(), nullable=False),
            sa.Column("principal_id", sa.Text(), nullable=False),
            sa.Column("permission", sa.Text(), nullable=False),
            sa.Column("created_at", sa.BigInteger(), nullable=False),
            sa.UniqueConstraint(
                "resource_type",
                "resource_id",
                "principal_type",
                "principal_id",
                "permission",
                name="uq_access_grant_grant",
            ),
        )
        op.create_index(
            "idx_access_grant_resource",
            "access_grant",
            ["resource_type", "resource_id"],
        )
        op.create_index(
            "idx_access_grant_principal",
            "access_grant",
            ["principal_type", "principal_id"],
        )

    # Backfill existing access_control JSON data
    conn = op.get_bind()

    # Tables with access_control JSON columns: (table_name, resource_type)
    resource_tables = [
        ("knowledge", "knowledge"),
        ("prompt", "prompt"),
        ("tool", "tool"),
        ("model", "model"),
        ("note", "note"),
        ("channel", "channel"),
        ("file", "file"),
    ]

    now = int(time.time())
    inserted: set = set()

    for table_name, resource_type in resource_tables:
        if table_name not in existing_tables:
            continue

        # Query all rows; skip the table entirely if it cannot be read.
        try:
            result = conn.execute(
                sa.text(f'SELECT id, access_control FROM "{table_name}"')
            )
            rows = result.fetchall()
        except Exception:
            continue

        for row in rows:
            resource_id = row[0]
            access_control_json = row[1]

            # Handle NULL or JSON "null" = public access (user:* for read).
            # Could be Python None (SQL NULL) or string "null" (JSON null).
            # EXCEPTION: files with NULL are PRIVATE (owner-only), not public.
            is_null = (
                access_control_json is None
                or access_control_json == "null"
                or (
                    isinstance(access_control_json, str)
                    and access_control_json.strip().lower() == "null"
                )
            )
            if is_null:
                if resource_type == "file":
                    continue  # Private - no entry needed
                _insert_grant(
                    conn, inserted, resource_type, resource_id, "user", "*", "read", now
                )
                continue

            # SQLite stores JSON as text; Postgres may hand back a dict already.
            if isinstance(access_control_json, str):
                try:
                    access_control_json = json.loads(access_control_json)
                except Exception:
                    continue

            # {} (or any non-dict) = private/owner-only - NO entries needed.
            if not access_control_json or not isinstance(access_control_json, dict):
                continue

            # Check if it's effectively empty (no read/write keys with content).
            read_data = access_control_json.get("read", {})
            write_data = access_control_json.get("write", {})

            has_read_grants = read_data.get("group_ids", []) or read_data.get(
                "user_ids", []
            )
            has_write_grants = write_data.get("group_ids", []) or write_data.get(
                "user_ids", []
            )

            if not has_read_grants and not has_write_grants:
                # Empty permissions = private, no grants needed
                continue

            # Extract permissions and insert into access_grant table.
            for permission in ["read", "write"]:
                perm_data = access_control_json.get(permission, {})
                if not perm_data:
                    continue

                for group_id in perm_data.get("group_ids", []):
                    _insert_grant(
                        conn,
                        inserted,
                        resource_type,
                        resource_id,
                        "group",
                        group_id,
                        permission,
                        now,
                    )

                for user_id in perm_data.get("user_ids", []):
                    _insert_grant(
                        conn,
                        inserted,
                        resource_type,
                        resource_id,
                        "user",
                        user_id,
                        permission,
                        now,
                    )

    # Drop access_control columns from resource tables now that grants
    # are normalized. batch_alter_table keeps SQLite happy.
    for table_name, _ in resource_tables:
        if table_name not in existing_tables:
            continue
        try:
            with op.batch_alter_table(table_name) as batch:
                batch.drop_column("access_control")
        except Exception:
            pass
def downgrade() -> None:
    """Rebuild legacy JSON access_control columns from access_grant rows,
    then drop the access_grant table and its indexes.

    Reverse semantics:
    - user:*:read grant  -> access_control = NULL (public)
    - specific grants    -> {read: {...}, write: {...}}
    - no grants          -> {} for most resources; files stay NULL (private)
    """
    import json

    conn = op.get_bind()

    # Resource tables mapping: (table_name, resource_type)
    resource_tables = [
        ("knowledge", "knowledge"),
        ("prompt", "prompt"),
        ("tool", "tool"),
        ("model", "model"),
        ("note", "note"),
        ("channel", "channel"),
        ("file", "file"),
    ]

    # Step 1: Re-add access_control columns to resource tables.
    for table_name, _ in resource_tables:
        try:
            with op.batch_alter_table(table_name) as batch:
                batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True))
        except Exception:
            pass

    # Step 2: Fold grants back into one JSON dict per resource.
    for table_name, resource_type in resource_tables:
        try:
            rows = conn.execute(
                sa.text(
                    """
                    SELECT resource_id, principal_type, principal_id, permission
                    FROM access_grant
                    WHERE resource_type = :resource_type
                    """
                ),
                {"resource_type": resource_type},
            ).fetchall()
        except Exception:
            continue

        acl_by_resource: dict = {}
        for resource_id, principal_type, principal_id, permission in rows:
            entry = acl_by_resource.setdefault(
                resource_id,
                {
                    "is_public": False,
                    "read": {"group_ids": [], "user_ids": []},
                    "write": {"group_ids": [], "user_ids": []},
                },
            )

            # Wildcard read marks the resource public; it never lands in a list.
            if (
                principal_type == "user"
                and principal_id == "*"
                and permission == "read"
            ):
                entry["is_public"] = True
                continue

            if permission not in ("read", "write"):
                continue

            if principal_type == "group":
                bucket = entry[permission]["group_ids"]
            elif principal_type == "user":
                bucket = entry[permission]["user_ids"]
            else:
                continue

            if principal_id not in bucket:
                bucket.append(principal_id)

        # Step 3: Write the reconstructed JSON back onto each resource row.
        for resource_id, entry in acl_by_resource.items():
            if entry["is_public"]:
                # Public = NULL
                access_control_value = None
            elif not (
                entry["read"]["group_ids"]
                or entry["read"]["user_ids"]
                or entry["write"]["group_ids"]
                or entry["write"]["user_ids"]
            ):
                # No grants = should not happen (would mean no entries), default to {}
                access_control_value = json.dumps({})
            else:
                access_control_value = json.dumps(
                    {
                        "read": entry["read"],
                        "write": entry["write"],
                    }
                )

            try:
                conn.execute(
                    sa.text(
                        f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'
                    ),
                    {"access_control": access_control_value, "id": resource_id},
                )
            except Exception:
                pass

        # Step 4: Resources with no grant rows are private.
        # Files: NULL already means private (owner-only), so leave them as NULL.
        # Everything else: private is spelled {}.
        if resource_type != "file":
            try:
                conn.execute(
                    sa.text(f"""
                        UPDATE "{table_name}"
                        SET access_control = :private_value
                        WHERE id NOT IN (
                            SELECT DISTINCT resource_id FROM access_grant WHERE resource_type = :resource_type
                        )
                        AND access_control IS NULL
                    """),
                    {"private_value": json.dumps({}), "resource_type": resource_type},
                )
            except Exception:
                pass

    # Step 5: Drop the access_grant table and its indexes.
    op.drop_index("idx_access_grant_principal", table_name="access_grant")
    op.drop_index("idx_access_grant_resource", table_name="access_grant")
    op.drop_table("access_grant")
import logging
import time
import uuid
from typing import Optional

from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db_context

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, or_, and_
from sqlalchemy.dialects.postgresql import JSONB

log = logging.getLogger(__name__)


####################
# AccessGrant DB Schema
####################


class AccessGrant(Base):
    """One normalized access-control grant: (resource, principal, permission).

    Replaces the per-table JSON access_control columns with a single
    relational table that can be indexed and JOINed.
    """

    __tablename__ = "access_grant"

    id = Column(Text, primary_key=True)
    # Kind of resource the grant applies to:
    # "knowledge", "model", "prompt", "tool", "note", "channel" or "file".
    resource_type = Column(Text, nullable=False)
    resource_id = Column(Text, nullable=False)
    # Who receives the grant: "user" or "group".
    principal_type = Column(Text, nullable=False)
    # user_id, group_id, or "*" (wildcard — public access for user principals).
    principal_id = Column(Text, nullable=False)
    # Granted capability: "read" or "write".
    permission = Column(Text, nullable=False)
    created_at = Column(BigInteger, nullable=False)

    # A given (resource, principal, permission) combination may exist only once.
    __table_args__ = (
        UniqueConstraint(
            "resource_type",
            "resource_id",
            "principal_type",
            "principal_id",
            "permission",
            name="uq_access_grant_grant",
        ),
    )


class AccessGrantModel(BaseModel):
    """Pydantic mirror of an access_grant row (validated from ORM objects)."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    resource_type: str
    resource_id: str
    principal_type: str
    principal_id: str
    permission: str
    created_at: int
class AccessGrantResponse(BaseModel):
    """Slim grant model for API responses — resource context is implicit from the parent."""

    id: str
    principal_type: str
    principal_id: str
    permission: str

    @classmethod
    def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse":
        """Drop the resource_*/created_at fields from a full grant model."""
        return cls(
            id=grant.id,
            principal_type=grant.principal_type,
            principal_id=grant.principal_id,
            permission=grant.permission,
        )


####################
# Conversion utilities
####################


def access_control_to_grants(
    resource_type: str,
    resource_id: str,
    access_control: Optional[dict],
) -> list[dict]:
    """
    Convert an old-style access_control JSON dict to a flat list of grant dicts.

    Semantics:
    - None → public read (user:* read) — except files which are private
    - {} → private/owner-only (no grants)
    - {read: {group_ids, user_ids}, write: {group_ids, user_ids}} → specific grants

    Returns a list of dicts with keys: resource_type, resource_id,
    principal_type, principal_id, permission.
    """
    grants: list[dict] = []

    if access_control is None:
        # NULL → public read (user:* for read)
        # Exception: files with NULL are private (owner-only), no grants needed
        if resource_type != "file":
            grants.append(
                {
                    "resource_type": resource_type,
                    "resource_id": resource_id,
                    "principal_type": "user",
                    "principal_id": "*",
                    "permission": "read",
                }
            )
        return grants

    # {} → private/owner-only, no grants. Non-dict payloads (e.g. a stray
    # string that survived migration) are treated the same way instead of
    # raising AttributeError on .get() — mirrors the backfill's guard.
    if not access_control or not isinstance(access_control, dict):
        return grants

    # Parse structured permissions
    for permission in ["read", "write"]:
        perm_data = access_control.get(permission, {})
        if not perm_data:
            continue

        for group_id in perm_data.get("group_ids", []):
            grants.append(
                {
                    "resource_type": resource_type,
                    "resource_id": resource_id,
                    "principal_type": "group",
                    "principal_id": group_id,
                    "permission": permission,
                }
            )

        for user_id in perm_data.get("user_ids", []):
            grants.append(
                {
                    "resource_type": resource_type,
                    "resource_id": resource_id,
                    "principal_type": "user",
                    "principal_id": user_id,
                    "permission": permission,
                }
            )

    return grants
def normalize_access_grants(access_grants: Optional[list]) -> list[dict]:
    """
    Normalize direct access_grants payloads from API forms.

    Keeps only valid grants and removes duplicates by
    (principal_type, principal_id, permission). Later duplicates win.
    """
    if not access_grants:
        return []

    deduped: dict = {}
    for raw in access_grants:
        # Accept both pydantic models and plain dicts; drop anything else.
        if isinstance(raw, BaseModel):
            raw = raw.model_dump()
        if not isinstance(raw, dict):
            continue

        ptype = raw.get("principal_type")
        pid = raw.get("principal_id")
        perm = raw.get("permission")

        if ptype not in ("user", "group"):
            continue
        if perm not in ("read", "write"):
            continue
        if not isinstance(pid, str) or not pid:
            continue

        # Reuse a caller-supplied id when it is a non-empty string,
        # otherwise mint a fresh one.
        grant_id = raw.get("id")
        if not (isinstance(grant_id, str) and grant_id):
            grant_id = str(uuid.uuid4())

        deduped[(ptype, pid, perm)] = {
            "id": grant_id,
            "principal_type": ptype,
            "principal_id": pid,
            "permission": perm,
        }

    return list(deduped.values())


def has_public_read_access_grant(access_grants: Optional[list]) -> bool:
    """Return True when a direct grant list includes wildcard public-read."""
    return any(
        grant["principal_type"] == "user"
        and grant["principal_id"] == "*"
        and grant["permission"] == "read"
        for grant in normalize_access_grants(access_grants)
    )


def grants_to_access_control(grants: list) -> Optional[dict]:
    """
    Convert a list of grant objects (AccessGrantModel or AccessGrantResponse)
    back to the old-style access_control JSON dict for backward compatibility.

    Semantics:
    - [] (empty) → {} (private/owner-only)
    - Contains user:*:read → None (public), but write grants are preserved
    - Otherwise → {read: {group_ids, user_ids}, write: {group_ids, user_ids}}

    Note: "public" (user:*:read) still allows additional write permissions
    to coexist. When the wildcard read is present the function returns None
    for the legacy dict, so callers that need write info should inspect the
    grants list directly.
    """
    if not grants:
        # No grants stored == private / owner-only.
        return {}

    acl = {
        "read": {"group_ids": [], "user_ids": []},
        "write": {"group_ids": [], "user_ids": []},
    }
    is_public = False

    for grant in grants:
        wildcard_read = (
            grant.principal_type == "user"
            and grant.principal_id == "*"
            and grant.permission == "read"
        )
        if wildcard_read:
            is_public = True  # the wildcard itself never lands in user_ids
            continue

        if grant.permission not in ("read", "write"):
            continue

        if grant.principal_type == "group":
            bucket = acl[grant.permission]["group_ids"]
        elif grant.principal_type == "user":
            bucket = acl[grant.permission]["user_ids"]
        else:
            continue

        if grant.principal_id not in bucket:
            bucket.append(grant.principal_id)

    return None if is_public else acl
####################
# Table Operations
####################


class AccessGrantsTable:
    """CRUD and query helpers for the normalized access_grant table."""

    def grant_access(
        self,
        resource_type: str,
        resource_id: str,
        principal_type: str,
        principal_id: str,
        permission: str,
        db: Optional[Session] = None,
    ) -> Optional[AccessGrantModel]:
        """Add a single access grant. Idempotent (ignores duplicates)."""
        with get_db_context(db) as db:
            # Check for an existing identical grant first.
            existing = (
                db.query(AccessGrant)
                .filter_by(
                    resource_type=resource_type,
                    resource_id=resource_id,
                    principal_type=principal_type,
                    principal_id=principal_id,
                    permission=permission,
                )
                .first()
            )
            if existing:
                return AccessGrantModel.model_validate(existing)

            grant = AccessGrant(
                id=str(uuid.uuid4()),
                resource_type=resource_type,
                resource_id=resource_id,
                principal_type=principal_type,
                principal_id=principal_id,
                permission=permission,
                created_at=int(time.time()),
            )
            db.add(grant)
            db.commit()
            db.refresh(grant)
            return AccessGrantModel.model_validate(grant)

    def revoke_access(
        self,
        resource_type: str,
        resource_id: str,
        principal_type: str,
        principal_id: str,
        permission: str,
        db: Optional[Session] = None,
    ) -> bool:
        """Remove a single access grant. Returns True when a row was deleted."""
        with get_db_context(db) as db:
            deleted = (
                db.query(AccessGrant)
                .filter_by(
                    resource_type=resource_type,
                    resource_id=resource_id,
                    principal_type=principal_type,
                    principal_id=principal_id,
                    permission=permission,
                )
                .delete()
            )
            db.commit()
            return deleted > 0

    def revoke_all_access(
        self,
        resource_type: str,
        resource_id: str,
        db: Optional[Session] = None,
    ) -> int:
        """Remove all access grants for a resource. Returns the delete count."""
        with get_db_context(db) as db:
            deleted = (
                db.query(AccessGrant)
                .filter_by(
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
                .delete()
            )
            db.commit()
            return deleted

    def set_access_control(
        self,
        resource_type: str,
        resource_id: str,
        access_control: Optional[dict],
        db: Optional[Session] = None,
    ) -> list[AccessGrantModel]:
        """
        Replace all grants for a resource from an access_control JSON dict.
        This is the primary bridge for backward compat with the frontend.
        """
        with get_db_context(db) as db:
            # Delete all existing grants for this resource, then recreate
            # them from the legacy JSON representation.
            db.query(AccessGrant).filter_by(
                resource_type=resource_type,
                resource_id=resource_id,
            ).delete()

            grant_dicts = access_control_to_grants(
                resource_type, resource_id, access_control
            )

            results = []
            for grant_dict in grant_dicts:
                grant = AccessGrant(
                    id=str(uuid.uuid4()),
                    **grant_dict,
                    created_at=int(time.time()),
                )
                db.add(grant)
                results.append(grant)

            db.commit()

            return [AccessGrantModel.model_validate(g) for g in results]

    def set_access_grants(
        self,
        resource_type: str,
        resource_id: str,
        access_grants: Optional[list],
        db: Optional[Session] = None,
    ) -> list[AccessGrantModel]:
        """
        Replace all grants for a resource from a direct access_grants list.
        Invalid and duplicate entries are dropped via normalize_access_grants.
        """
        with get_db_context(db) as db:
            db.query(AccessGrant).filter_by(
                resource_type=resource_type,
                resource_id=resource_id,
            ).delete()

            normalized_grants = normalize_access_grants(access_grants)

            results = []
            for grant_dict in normalized_grants:
                grant = AccessGrant(
                    id=grant_dict["id"],
                    resource_type=resource_type,
                    resource_id=resource_id,
                    principal_type=grant_dict["principal_type"],
                    principal_id=grant_dict["principal_id"],
                    permission=grant_dict["permission"],
                    created_at=int(time.time()),
                )
                db.add(grant)
                results.append(grant)

            db.commit()
            return [AccessGrantModel.model_validate(g) for g in results]

    def get_access_control(
        self,
        resource_type: str,
        resource_id: str,
        db: Optional[Session] = None,
    ) -> Optional[dict]:
        """
        Reconstruct the old-style access_control JSON dict from grants.
        For backward compat with the frontend.
        """
        with get_db_context(db) as db:
            grant_models = self.get_grants_by_resource(
                resource_type, resource_id, db=db
            )
            return grants_to_access_control(grant_models)

    def get_grants_by_resource(
        self,
        resource_type: str,
        resource_id: str,
        db: Optional[Session] = None,
    ) -> list[AccessGrantModel]:
        """Get all grants for a specific resource."""
        with get_db_context(db) as db:
            grants = (
                db.query(AccessGrant)
                .filter_by(
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
                .all()
            )
            return [AccessGrantModel.model_validate(g) for g in grants]

    def _principal_match_conditions(
        self,
        user_id: Optional[str],
        group_ids,
        include_public: bool = True,
    ) -> list:
        """Build the OR-able principal-match conditions for grant lookups.

        Covers (optionally) the public wildcard, the direct user grant, and
        any of the user's group grants. Shared by has_access and the query
        filter builders so the matching rules stay in one place.
        """
        conditions = []
        if include_public:
            conditions.append(
                and_(
                    AccessGrant.principal_type == "user",
                    AccessGrant.principal_id == "*",
                )
            )
        if user_id:
            conditions.append(
                and_(
                    AccessGrant.principal_type == "user",
                    AccessGrant.principal_id == user_id,
                )
            )
        if group_ids:
            conditions.append(
                and_(
                    AccessGrant.principal_type == "group",
                    AccessGrant.principal_id.in_(group_ids),
                )
            )
        return conditions

    def _grant_exists_clause(
        self,
        DocumentModel,
        resource_type: str,
        permission: str,
        user_id: Optional[str],
        group_ids,
        include_public: bool,
    ):
        """Correlated EXISTS clause matching any grant for the given principals."""
        from sqlalchemy import select

        return (
            select(AccessGrant.id)
            .where(
                AccessGrant.resource_type == resource_type,
                AccessGrant.resource_id == DocumentModel.id,
                AccessGrant.permission == permission,
                or_(
                    *self._principal_match_conditions(
                        user_id, group_ids, include_public
                    )
                ),
            )
            .correlate(DocumentModel)
            .exists()
        )

    def has_access(
        self,
        user_id: str,
        resource_type: str,
        resource_id: str,
        permission: str = "read",
        user_group_ids: Optional[set[str]] = None,
        db: Optional[Session] = None,
    ) -> bool:
        """
        Check if a user has the specified permission on a resource.

        Access is granted if any of the following is true:
        - There's a grant for user:* (public) with the requested permission
        - There's a grant for the specific user with the requested permission
        - There's a grant for any of the user's groups with the requested permission
        """
        with get_db_context(db) as db:
            # Resolve the user's group memberships unless the caller provided them.
            if user_group_ids is None:
                from open_webui.models.groups import Groups

                user_groups = Groups.get_groups_by_member_id(user_id, db=db)
                user_group_ids = {group.id for group in user_groups}

            conditions = self._principal_match_conditions(
                user_id, user_group_ids, include_public=True
            )

            exists = (
                db.query(AccessGrant)
                .filter(
                    AccessGrant.resource_type == resource_type,
                    AccessGrant.resource_id == resource_id,
                    AccessGrant.permission == permission,
                    or_(*conditions),
                )
                .first()
            )
            return exists is not None

    def get_users_with_access(
        self,
        resource_type: str,
        resource_id: str,
        permission: str = "read",
        db: Optional[Session] = None,
    ) -> list:
        """
        Get all users who have the specified permission on a resource.
        Returns a list of UserModel instances.
        """
        from open_webui.models.users import Users, UserModel
        from open_webui.models.groups import Groups

        with get_db_context(db) as db:
            grants = (
                db.query(AccessGrant)
                .filter_by(
                    resource_type=resource_type,
                    resource_id=resource_id,
                    permission=permission,
                )
                .all()
            )

            # Public wildcard: every non-pending user has access.
            for grant in grants:
                if grant.principal_type == "user" and grant.principal_id == "*":
                    result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
                    return result.get("users", [])

            user_ids_with_access = set()

            for grant in grants:
                if grant.principal_type == "user":
                    user_ids_with_access.add(grant.principal_id)
                elif grant.principal_type == "group":
                    # Expand group grants into their member user ids.
                    group_user_ids = Groups.get_group_user_ids_by_id(
                        grant.principal_id, db=db
                    )
                    if group_user_ids:
                        user_ids_with_access.update(group_user_ids)

            if not user_ids_with_access:
                return []

            return Users.get_users_by_user_ids(list(user_ids_with_access), db=db)

    def has_permission_filter(
        self,
        db,
        query,
        DocumentModel,
        filter: dict,
        resource_type: str,
        permission: str = "read",
    ):
        """
        Apply access control filtering to a SQLAlchemy query via a correlated
        EXISTS against access_grant.

        This replaces the old JSON-column-based filtering with a proper
        relational lookup. `filter` supports "user_id" (owner + direct grants)
        and "group_ids" (group grants). permission == "read_only" delegates to
        the read-but-not-write filter.
        """
        group_ids = filter.get("group_ids", [])
        user_id = filter.get("user_id")

        if permission == "read_only":
            return self._has_read_only_permission_filter(
                db, query, DocumentModel, filter, resource_type
            )

        if not (group_ids or user_id):
            # No principals to match — leave the query unfiltered.
            return query

        # EXISTS avoids duplicate rows from multiple matching grants.
        grant_exists = self._grant_exists_clause(
            DocumentModel,
            resource_type,
            permission,
            user_id,
            group_ids,
            include_public=True,
        )

        # Owner OR has a matching grant
        owner_or_grant = [grant_exists]
        if user_id:
            owner_or_grant.append(DocumentModel.user_id == user_id)

        return query.filter(or_(*owner_or_grant))

    def _has_read_only_permission_filter(
        self,
        db,
        query,
        DocumentModel,
        filter: dict,
        resource_type: str,
    ):
        """
        Filter for items where user has read BUT NOT write access.
        Public items are NOT considered read_only, and the owner's own
        items are excluded.
        """
        group_ids = filter.get("group_ids", [])
        user_id = filter.get("user_id")

        from sqlalchemy import select

        # Has a non-public read grant for this user or their groups.
        read_grant_exists = self._grant_exists_clause(
            DocumentModel,
            resource_type,
            "read",
            user_id,
            group_ids,
            include_public=False,
        )

        # Does NOT have a write grant.
        write_grant_exists = self._grant_exists_clause(
            DocumentModel,
            resource_type,
            "write",
            user_id,
            group_ids,
            include_public=False,
        )

        # Is NOT public (no user:* read grant at all).
        public_grant_exists = (
            select(AccessGrant.id)
            .where(
                AccessGrant.resource_type == resource_type,
                AccessGrant.resource_id == DocumentModel.id,
                AccessGrant.permission == "read",
                AccessGrant.principal_type == "user",
                AccessGrant.principal_id == "*",
            )
            .correlate(DocumentModel)
            .exists()
        )

        conditions = [read_grant_exists, ~write_grant_exists, ~public_grant_exists]

        # Not owner
        if user_id:
            conditions.append(DocumentModel.user_id != user_id)

        return query.filter(and_(*conditions))


AccessGrants = AccessGrantsTable()
AccessGrantsTable() diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 93f17dff11..0b4639fc62 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -4,8 +4,9 @@ from sqlalchemy.orm import Session from open_webui.internal.db import Base, JSONField, get_db, get_db_context -from open_webui.models.users import UserModel, UserProfileImageResponse, Users -from pydantic import BaseModel +from open_webui.models.users import User, UserModel, UserProfileImageResponse, Users +from open_webui.utils.validate import validate_profile_image_url +from pydantic import BaseModel, field_validator from sqlalchemy import Boolean, Column, String, Text log = logging.getLogger(__name__) @@ -74,6 +75,13 @@ class SignupForm(BaseModel): password: str profile_image_url: Optional[str] = "/user.png" + @field_validator("profile_image_url") + @classmethod + def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]: + if v is not None: + return validate_profile_image_url(v) + return v + class AddUserForm(SignupForm): role: Optional[str] = "pending" @@ -155,10 +163,17 @@ def authenticate_user_by_email( log.info(f"authenticate_user_by_email: {email}") try: with get_db_context(db) as db: - auth = db.query(Auth).filter_by(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id, db=db) - return user + # Single JOIN query instead of two separate queries + result = ( + db.query(Auth, User) + .join(User, Auth.id == User.id) + .filter(Auth.email == email, Auth.active == True) + .first() + ) + if result: + _, user = result + return UserModel.model_validate(user) + return None except Exception: return None diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 8e70918e1a..8a55da9345 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -7,8 +7,12 @@ from sqlalchemy.orm import Session from open_webui.internal.db 
import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups +from open_webui.models.access_grants import ( + AccessGrantModel, + AccessGrants, +) -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.dialects.postgresql import JSONB @@ -47,7 +51,6 @@ class Channel(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) created_at = Column(BigInteger) @@ -76,7 +79,7 @@ class ChannelModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) created_at: int # timestamp in epoch (time_ns) @@ -237,7 +240,7 @@ class ChannelForm(BaseModel): is_private: Optional[bool] = None data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None group_ids: Optional[list[str]] = None user_ids: Optional[list[str]] = None @@ -252,6 +255,20 @@ class ChannelWebhookForm(BaseModel): class ChannelTable: + def _get_access_grants( + self, channel_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("channel", channel_id, db=db) + + def _to_channel_model( + self, channel: Channel, db: Optional[Session] = None + ) -> ChannelModel: + channel_data = ChannelModel.model_validate(channel).model_dump( + exclude={"access_grants"} + ) + access_grants = self._get_access_grants(channel_data["id"], db=db) + channel_data["access_grants"] = access_grants + return ChannelModel.model_validate(channel_data) def _collect_unique_user_ids( self, @@ -316,16 +333,17 @@ def insert_new_channel( with get_db_context(db) as db: channel = ChannelModel( **{ - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "type": form_data.type if form_data.type else None, "name": 
form_data.name.lower(), "id": str(uuid.uuid4()), "user_id": user_id, "created_at": int(time.time_ns()), "updated_at": int(time.time_ns()), + "access_grants": [], } ) - new_channel = Channel(**channel.model_dump()) + new_channel = Channel(**channel.model_dump(exclude={"access_grants"})) if form_data.type in ["group", "dm"]: users = self._collect_unique_user_ids( @@ -342,54 +360,25 @@ def insert_new_channel( db.add_all(memberships) db.add(new_channel) db.commit() - return channel + AccessGrants.set_access_grants( + "channel", new_channel.id, form_data.access_grants, db=db + ) + return self._to_channel_model(new_channel, db=db) def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]: with get_db_context(db) as db: channels = db.query(Channel).all() - return [ChannelModel.model_validate(channel) for channel in channels] + return [self._to_channel_model(channel, db=db) for channel in channels] def _has_permission(self, db, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - - dialect_name = db.bind.dialect.name - - # Public access - conditions = [] - if group_ids or user_id: - conditions.extend( - [ - Channel.access_control.is_(None), - cast(Channel.access_control, String) == "null", - ] - ) - - # User-level permission - if user_id: - conditions.append(Channel.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - Channel.access_control[permission]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - Channel.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query + return AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Channel, + filter=filter, + 
resource_type="channel", + permission=permission, + ) def get_channels_by_user_id( self, user_id: str, db: Optional[Session] = None @@ -428,7 +417,7 @@ def get_channels_by_user_id( standard_channels = query.all() all_channels = membership_channels + standard_channels - return [ChannelModel.model_validate(c) for c in all_channels] + return [self._to_channel_model(c, db=db) for c in all_channels] def get_dm_channel_by_user_ids( self, user_ids: list[str], db: Optional[Session] = None @@ -463,7 +452,7 @@ def get_dm_channel_by_user_ids( .first() ) - return ChannelModel.model_validate(channel) if channel else None + return self._to_channel_model(channel, db=db) if channel else None def add_members_to_channel( self, @@ -722,7 +711,7 @@ def get_channel_by_id( try: with get_db_context(db) as db: channel = db.query(Channel).filter(Channel.id == id).first() - return ChannelModel.model_validate(channel) if channel else None + return self._to_channel_model(channel, db=db) if channel else None except Exception: return None @@ -735,7 +724,7 @@ def get_channels_by_file_id( ) channel_ids = [cf.channel_id for cf in channel_files] channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all() - return [ChannelModel.model_validate(channel) for channel in channels] + return [self._to_channel_model(channel, db=db) for channel in channels] def get_channels_by_file_id_and_user_id( self, file_id: str, user_id: str, db: Optional[Session] = None @@ -783,7 +772,7 @@ def get_channels_by_file_id_and_user_id( .first() ) if membership: - allowed_channels.append(ChannelModel.model_validate(channel)) + allowed_channels.append(self._to_channel_model(channel, db=db)) continue # --- Case B: standard channel => rely on ACL permissions --- @@ -798,7 +787,7 @@ def get_channels_by_file_id_and_user_id( allowed = query.first() if allowed: - allowed_channels.append(ChannelModel.model_validate(allowed)) + allowed_channels.append(self._to_channel_model(allowed, db=db)) return allowed_channels @@ -832,7 
+821,7 @@ def get_channel_by_id_and_user_id( .first() ) if membership: - return ChannelModel.model_validate(channel) + return self._to_channel_model(channel, db=db) else: return None @@ -854,7 +843,7 @@ def get_channel_by_id_and_user_id( channel_allowed = query.first() return ( - ChannelModel.model_validate(channel_allowed) + self._to_channel_model(channel_allowed, db=db) if channel_allowed else None ) @@ -874,11 +863,14 @@ def update_channel_by_id( channel.data = form_data.data channel.meta = form_data.meta - channel.access_control = form_data.access_control + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "channel", id, form_data.access_grants, db=db + ) channel.updated_at = int(time.time_ns()) db.commit() - return ChannelModel.model_validate(channel) if channel else None + return self._to_channel_model(channel, db=db) if channel else None def add_file_to_channel_by_id( self, channel_id: str, file_id: str, user_id: str, db: Optional[Session] = None @@ -947,6 +939,7 @@ def remove_file_from_channel_by_id( def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: + AccessGrants.revoke_all_access("channel", id, db=db) db.query(Channel).filter(Channel.id == id).delete() db.commit() return True diff --git a/backend/open_webui/models/chat_messages.py b/backend/open_webui/models/chat_messages.py new file mode 100644 index 0000000000..fe3539f9cd --- /dev/null +++ b/backend/open_webui/models/chat_messages.py @@ -0,0 +1,642 @@ +import json +import time +import uuid +from typing import Any, Optional + +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db_context + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + ForeignKey, + Text, + JSON, + Index, +) + +#################### +# Helpers +#################### + + +def _normalize_timestamp(timestamp: int) -> float: + """Normalize and validate timestamp. 
Returns current time if invalid.""" + now = time.time() + + # Convert milliseconds to seconds if needed + if timestamp > 10_000_000_000: + timestamp = timestamp / 1000 + + # Validate: must be after 2020 and not in the future (with 1 day tolerance) + min_valid = 1577836800 # 2020-01-01 00:00:00 UTC + max_valid = now + 86400 # 1 day in the future (clock skew tolerance) + + if timestamp < min_valid or timestamp > max_valid: + return now + + return timestamp + + +#################### +# ChatMessage DB Schema +#################### + + +class ChatMessage(Base): + __tablename__ = "chat_message" + + # Identity + id = Column(Text, primary_key=True) + chat_id = Column( + Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True + ) + user_id = Column(Text, index=True) + + # Structure + role = Column(Text, nullable=False) # user, assistant, system + parent_id = Column(Text, nullable=True) + + # Content + content = Column(JSON, nullable=True) # Can be str or list of blocks + output = Column(JSON, nullable=True) + + # Model (for assistant messages) + model_id = Column(Text, nullable=True, index=True) + + # Attachments + files = Column(JSON, nullable=True) + sources = Column(JSON, nullable=True) + embeds = Column(JSON, nullable=True) + + # Status + done = Column(Boolean, default=True) + status_history = Column(JSON, nullable=True) + error = Column(JSON, nullable=True) + + # Usage (tokens, timing, etc.) 
+ usage = Column(JSON, nullable=True) + + # Timestamps + created_at = Column(BigInteger, index=True) + updated_at = Column(BigInteger) + + __table_args__ = ( + Index("chat_message_chat_parent_idx", "chat_id", "parent_id"), + Index("chat_message_model_created_idx", "model_id", "created_at"), + Index("chat_message_user_created_idx", "user_id", "created_at"), + ) + + +#################### +# Pydantic Models +#################### + + +class ChatMessageModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + chat_id: str + user_id: str + role: str + parent_id: Optional[str] = None + content: Optional[Any] = None # str or list of blocks + output: Optional[list] = None + model_id: Optional[str] = None + files: Optional[list] = None + sources: Optional[list] = None + embeds: Optional[list] = None + done: bool = True + status_history: Optional[list] = None + error: Optional[dict | str] = None + usage: Optional[dict] = None + created_at: int + updated_at: int + + +#################### +# Table Operations +#################### + + +class ChatMessageTable: + def upsert_message( + self, + message_id: str, + chat_id: str, + user_id: str, + data: dict, + db: Optional[Session] = None, + ) -> Optional[ChatMessageModel]: + """Insert or update a chat message.""" + with get_db_context(db) as db: + now = int(time.time()) + timestamp = data.get("timestamp", now) + + # Use composite ID: {chat_id}-{message_id} + composite_id = f"{chat_id}-{message_id}" + + existing = db.get(ChatMessage, composite_id) + if existing: + # Update existing + if "role" in data: + existing.role = data["role"] + if "parent_id" in data: + existing.parent_id = data.get("parent_id") or data.get("parentId") + if "content" in data: + existing.content = data.get("content") + if "output" in data: + existing.output = data.get("output") + if "model_id" in data or "model" in data: + existing.model_id = data.get("model_id") or data.get("model") + if "files" in data: + existing.files = 
data.get("files") + if "sources" in data: + existing.sources = data.get("sources") + if "embeds" in data: + existing.embeds = data.get("embeds") + if "done" in data: + existing.done = data.get("done", True) + if "status_history" in data or "statusHistory" in data: + existing.status_history = data.get("status_history") or data.get( + "statusHistory" + ) + if "error" in data: + existing.error = data.get("error") + # Extract usage - check direct field first, then info.usage + usage = data.get("usage") + if not usage: + info = data.get("info", {}) + usage = info.get("usage") if info else None + if usage: + existing.usage = usage + existing.updated_at = now + db.commit() + db.refresh(existing) + return ChatMessageModel.model_validate(existing) + else: + # Insert new + # Extract usage - check direct field first, then info.usage + usage = data.get("usage") + if not usage: + info = data.get("info", {}) + usage = info.get("usage") if info else None + message = ChatMessage( + id=composite_id, + chat_id=chat_id, + user_id=user_id, + role=data.get("role", "user"), + parent_id=data.get("parent_id") or data.get("parentId"), + content=data.get("content"), + output=data.get("output"), + model_id=data.get("model_id") or data.get("model"), + files=data.get("files"), + sources=data.get("sources"), + embeds=data.get("embeds"), + done=data.get("done", True), + status_history=data.get("status_history") + or data.get("statusHistory"), + error=data.get("error"), + usage=usage, + created_at=timestamp, + updated_at=now, + ) + db.add(message) + db.commit() + db.refresh(message) + return ChatMessageModel.model_validate(message) + + def get_message_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChatMessageModel]: + with get_db_context(db) as db: + message = db.get(ChatMessage, id) + return ChatMessageModel.model_validate(message) if message else None + + def get_messages_by_chat_id( + self, chat_id: str, db: Optional[Session] = None + ) -> list[ChatMessageModel]: + with 
get_db_context(db) as db: + messages = ( + db.query(ChatMessage) + .filter_by(chat_id=chat_id) + .order_by(ChatMessage.created_at.asc()) + .all() + ) + return [ChatMessageModel.model_validate(message) for message in messages] + + def get_messages_by_user_id( + self, + user_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, + ) -> list[ChatMessageModel]: + with get_db_context(db) as db: + messages = ( + db.query(ChatMessage) + .filter_by(user_id=user_id) + .order_by(ChatMessage.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return [ChatMessageModel.model_validate(message) for message in messages] + + def get_messages_by_model_id( + self, + model_id: str, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + skip: int = 0, + limit: int = 100, + db: Optional[Session] = None, + ) -> list[ChatMessageModel]: + with get_db_context(db) as db: + query = db.query(ChatMessage).filter_by(model_id=model_id) + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + messages = ( + query.order_by(ChatMessage.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return [ChatMessageModel.model_validate(message) for message in messages] + + def get_chat_ids_by_model_id( + self, + model_id: str, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, + ) -> list[str]: + """Get distinct chat_ids that used a specific model.""" + from sqlalchemy import distinct + + with get_db_context(db) as db: + query = db.query(distinct(ChatMessage.chat_id)).filter( + ChatMessage.model_id == model_id + ) + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + + # Order by most recent message in each chat + chat_ids = ( + 
query.order_by(ChatMessage.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return [chat_id for (chat_id,) in chat_ids] + + def delete_messages_by_chat_id( + self, chat_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: + db.query(ChatMessage).filter_by(chat_id=chat_id).delete() + db.commit() + return True + + # Analytics methods + def get_message_count_by_model( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + group_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> dict[str, int]: + with get_db_context(db) as db: + from sqlalchemy import func + from open_webui.models.groups import GroupMember + + query = db.query( + ChatMessage.model_id, func.count(ChatMessage.id).label("count") + ).filter( + ChatMessage.role == "assistant", + ChatMessage.model_id.isnot(None), + ~ChatMessage.user_id.like("shared-%"), + ) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + if group_id: + group_users = ( + db.query(GroupMember.user_id) + .filter(GroupMember.group_id == group_id) + .subquery() + ) + query = query.filter(ChatMessage.user_id.in_(group_users)) + + results = query.group_by(ChatMessage.model_id).all() + return {row.model_id: row.count for row in results} + + def get_token_usage_by_model( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + group_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> dict[str, dict]: + """Aggregate token usage by model using database-level aggregation.""" + with get_db_context(db) as db: + from sqlalchemy import func, cast, Integer + from open_webui.models.groups import GroupMember + + dialect = db.bind.dialect.name + + if dialect == "sqlite": + input_tokens = cast( + func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer + ) + output_tokens = cast( + func.json_extract(ChatMessage.usage, 
"$.output_tokens"), Integer + ) + elif dialect == "postgresql": + # Use json_extract_path_text for PostgreSQL JSON columns + input_tokens = cast( + func.json_extract_path_text(ChatMessage.usage, "input_tokens"), + Integer, + ) + output_tokens = cast( + func.json_extract_path_text(ChatMessage.usage, "output_tokens"), + Integer, + ) + else: + raise NotImplementedError(f"Unsupported dialect: {dialect}") + + query = db.query( + ChatMessage.model_id, + func.coalesce(func.sum(input_tokens), 0).label("input_tokens"), + func.coalesce(func.sum(output_tokens), 0).label("output_tokens"), + func.count(ChatMessage.id).label("message_count"), + ).filter( + ChatMessage.role == "assistant", + ChatMessage.model_id.isnot(None), + ChatMessage.usage.isnot(None), + ~ChatMessage.user_id.like("shared-%"), + ) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + if group_id: + group_users = ( + db.query(GroupMember.user_id) + .filter(GroupMember.group_id == group_id) + .subquery() + ) + query = query.filter(ChatMessage.user_id.in_(group_users)) + + results = query.group_by(ChatMessage.model_id).all() + + return { + row.model_id: { + "input_tokens": row.input_tokens, + "output_tokens": row.output_tokens, + "total_tokens": row.input_tokens + row.output_tokens, + "message_count": row.message_count, + } + for row in results + } + + def get_token_usage_by_user( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + db: Optional[Session] = None, + ) -> dict[str, dict]: + """Aggregate token usage by user using database-level aggregation.""" + with get_db_context(db) as db: + from sqlalchemy import func, cast, Integer + + dialect = db.bind.dialect.name + + if dialect == "sqlite": + input_tokens = cast( + func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer + ) + output_tokens = cast( + func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer + ) + elif 
dialect == "postgresql": + # Use json_extract_path_text for PostgreSQL JSON columns + input_tokens = cast( + func.json_extract_path_text(ChatMessage.usage, "input_tokens"), + Integer, + ) + output_tokens = cast( + func.json_extract_path_text(ChatMessage.usage, "output_tokens"), + Integer, + ) + else: + raise NotImplementedError(f"Unsupported dialect: {dialect}") + + query = db.query( + ChatMessage.user_id, + func.coalesce(func.sum(input_tokens), 0).label("input_tokens"), + func.coalesce(func.sum(output_tokens), 0).label("output_tokens"), + func.count(ChatMessage.id).label("message_count"), + ).filter( + ChatMessage.role == "assistant", + ChatMessage.user_id.isnot(None), + ChatMessage.usage.isnot(None), + ~ChatMessage.user_id.like("shared-%"), + ) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + + results = query.group_by(ChatMessage.user_id).all() + + return { + row.user_id: { + "input_tokens": row.input_tokens, + "output_tokens": row.output_tokens, + "total_tokens": row.input_tokens + row.output_tokens, + "message_count": row.message_count, + } + for row in results + } + + def get_message_count_by_user( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + group_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> dict[str, int]: + with get_db_context(db) as db: + from sqlalchemy import func + from open_webui.models.groups import GroupMember + + query = db.query( + ChatMessage.user_id, func.count(ChatMessage.id).label("count") + ).filter(~ChatMessage.user_id.like("shared-%")) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + if group_id: + group_users = ( + db.query(GroupMember.user_id) + .filter(GroupMember.group_id == group_id) + .subquery() + ) + query = query.filter(ChatMessage.user_id.in_(group_users)) + + 
results = query.group_by(ChatMessage.user_id).all() + return {row.user_id: row.count for row in results} + + def get_message_count_by_chat( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + group_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> dict[str, int]: + with get_db_context(db) as db: + from sqlalchemy import func + from open_webui.models.groups import GroupMember + + query = db.query( + ChatMessage.chat_id, func.count(ChatMessage.id).label("count") + ).filter(~ChatMessage.user_id.like("shared-%")) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + if group_id: + group_users = ( + db.query(GroupMember.user_id) + .filter(GroupMember.group_id == group_id) + .subquery() + ) + query = query.filter(ChatMessage.user_id.in_(group_users)) + + results = query.group_by(ChatMessage.chat_id).all() + return {row.chat_id: row.count for row in results} + + def get_daily_message_counts_by_model( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + group_id: Optional[str] = None, + db: Optional[Session] = None, + ) -> dict[str, dict[str, int]]: + """Get message counts grouped by day and model.""" + with get_db_context(db) as db: + from datetime import datetime, timedelta + from open_webui.models.groups import GroupMember + + query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter( + ChatMessage.role == "assistant", + ChatMessage.model_id.isnot(None), + ~ChatMessage.user_id.like("shared-%"), + ) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + if group_id: + group_users = ( + db.query(GroupMember.user_id) + .filter(GroupMember.group_id == group_id) + .subquery() + ) + query = query.filter(ChatMessage.user_id.in_(group_users)) + + results = query.all() + + # Group by date -> model 
-> count + daily_counts: dict[str, dict[str, int]] = {} + for timestamp, model_id in results: + date_str = datetime.fromtimestamp( + _normalize_timestamp(timestamp) + ).strftime("%Y-%m-%d") + if date_str not in daily_counts: + daily_counts[date_str] = {} + daily_counts[date_str][model_id] = ( + daily_counts[date_str].get(model_id, 0) + 1 + ) + + # Fill in missing days + if start_date and end_date: + current = datetime.fromtimestamp(_normalize_timestamp(start_date)) + end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date)) + while current <= end_dt: + date_str = current.strftime("%Y-%m-%d") + if date_str not in daily_counts: + daily_counts[date_str] = {} + current += timedelta(days=1) + + return daily_counts + + def get_hourly_message_counts_by_model( + self, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + db: Optional[Session] = None, + ) -> dict[str, dict[str, int]]: + """Get message counts grouped by hour and model.""" + with get_db_context(db) as db: + from datetime import datetime, timedelta + + query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter( + ChatMessage.role == "assistant", + ChatMessage.model_id.isnot(None), + ~ChatMessage.user_id.like("shared-%"), + ) + + if start_date: + query = query.filter(ChatMessage.created_at >= start_date) + if end_date: + query = query.filter(ChatMessage.created_at <= end_date) + + results = query.all() + + # Group by hour -> model -> count + hourly_counts: dict[str, dict[str, int]] = {} + for timestamp, model_id in results: + hour_str = datetime.fromtimestamp( + _normalize_timestamp(timestamp) + ).strftime("%Y-%m-%d %H:00") + if hour_str not in hourly_counts: + hourly_counts[hour_str] = {} + hourly_counts[hour_str][model_id] = ( + hourly_counts[hour_str].get(model_id, 0) + 1 + ) + + # Fill in missing hours + if start_date and end_date: + current = datetime.fromtimestamp( + _normalize_timestamp(start_date) + ).replace(minute=0, second=0, microsecond=0) + end_dt = 
datetime.fromtimestamp(_normalize_timestamp(end_date)) + while current <= end_dt: + hour_str = current.strftime("%Y-%m-%d %H:00") + if hour_str not in hourly_counts: + hourly_counts[hour_str] = {} + current += timedelta(hours=1) + + return hourly_counts + + +ChatMessages = ChatMessageTable() diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 12359eec9f..6040050fc3 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -8,6 +8,7 @@ from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.folders import Folders +from open_webui.models.chat_messages import ChatMessages from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db from pydantic import BaseModel, ConfigDict @@ -168,6 +169,14 @@ class ChatTitleIdResponse(BaseModel): created_at: int +class SharedChatResponse(BaseModel): + id: str + title: str + share_id: Optional[str] = None + updated_at: int + created_at: int + + class ChatListResponse(BaseModel): items: list[ChatModel] total: int @@ -306,6 +315,24 @@ def insert_new_chat( db.add(chat_item) db.commit() db.refresh(chat_item) + + # Dual-write initial messages to chat_message table + try: + history = form_data.chat.get("history", {}) + messages = history.get("messages", {}) + for message_id, message in messages.items(): + if isinstance(message, dict) and message.get("role"): + ChatMessages.upsert_message( + message_id=message_id, + chat_id=id, + user_id=user_id, + data=message, + ) + except Exception as e: + log.warning( + f"Failed to write initial messages to chat_message table: {e}" + ) + return ChatModel.model_validate(chat_item) if chat_item else None def _chat_import_form_to_chat_model( @@ -348,6 +375,25 @@ def import_chats( db.add_all(chats) db.commit() + + # Dual-write messages to chat_message table + try: + for form_data, chat_obj in zip(chat_import_forms, chats): + 
history = form_data.chat.get("history", {}) + messages = history.get("messages", {}) + for message_id, message in messages.items(): + if isinstance(message, dict) and message.get("role"): + ChatMessages.upsert_message( + message_id=message_id, + chat_id=chat_obj.id, + user_id=user_id, + data=message, + ) + except Exception as e: + log.warning( + f"Failed to write imported messages to chat_message table: {e}" + ) + return [ChatModel.model_validate(chat) for chat in chats] def update_chat_by_id( @@ -450,6 +496,18 @@ def upsert_message_to_chat_by_id_and_message_id( history["currentId"] = message_id chat["history"] = history + + # Dual-write to chat_message table + try: + ChatMessages.upsert_message( + message_id=message_id, + chat_id=id, + user_id=self.get_chat_by_id(id).user_id, + data=history["messages"][message_id], + ) + except Exception as e: + log.warning(f"Failed to write to chat_message table: {e}") + return self.update_chat_by_id(id, chat) def add_message_status_to_chat_by_id_and_message_id( @@ -675,6 +733,51 @@ def get_archived_chat_list_by_user_id( all_chats = query.all() return [ChatModel.model_validate(chat) for chat in all_chats] + def get_shared_chat_list_by_user_id( + self, + user_id: str, + filter: Optional[dict] = None, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, + ) -> list[ChatModel]: + + with get_db_context(db) as db: + query = ( + db.query(Chat) + .filter_by(user_id=user_id) + .filter(Chat.share_id.isnot(None)) + ) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter(Chat.title.ilike(f"%{query_key}%")) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by and direction: + if not getattr(Chat, order_by, None): + raise ValueError("Invalid order_by field") + + if direction.lower() == "asc": + query = query.order_by(getattr(Chat, order_by).asc()) + elif direction.lower() == "desc": + query = query.order_by(getattr(Chat, order_by).desc()) + else: + raise 
ValueError("Invalid direction for ordering") + else: + query = query.order_by(Chat.updated_at.desc()) + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + all_chats = query.all() + return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chat_list_by_user_id( self, user_id: str, @@ -1013,29 +1116,23 @@ def get_chats_by_user_id_and_search_text( # Check if there are any tags to filter, it should have all the tags if "none" in tag_ids: - query = query.filter( - text( - """ + query = query.filter(text(""" NOT EXISTS ( SELECT 1 FROM json_each(Chat.meta, '$.tags') AS tag ) - """ - ) - ) + """)) elif tag_ids: query = query.filter( and_( *[ - text( - f""" + text(f""" EXISTS ( SELECT 1 FROM json_each(Chat.meta, '$.tags') AS tag WHERE tag.value = :tag_id_{tag_idx} ) - """ - ).params(**{f"tag_id_{tag_idx}": tag_id}) + """).params(**{f"tag_id_{tag_idx}": tag_id}) for tag_idx, tag_id in enumerate(tag_ids) ] ) @@ -1071,29 +1168,23 @@ def get_chats_by_user_id_and_search_text( # Check if there are any tags to filter, it should have all the tags if "none" in tag_ids: - query = query.filter( - text( - """ + query = query.filter(text(""" NOT EXISTS ( SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') AS tag ) - """ - ) - ) + """)) elif tag_ids: query = query.filter( and_( *[ - text( - f""" + text(f""" EXISTS ( SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') AS tag WHERE tag = :tag_id_{tag_idx} ) - """ - ).params(**{f"tag_id_{tag_idx}": tag_id}) + """).params(**{f"tag_id_{tag_idx}": tag_id}) for tag_idx, tag_id in enumerate(tag_ids) ] ) diff --git a/backend/open_webui/models/credits.py b/backend/open_webui/models/credits.py index 12ed4030d8..b7a0e5cea6 100644 --- a/backend/open_webui/models/credits.py +++ b/backend/open_webui/models/credits.py @@ -16,7 +16,6 @@ from open_webui.internal.db import Base, get_db from open_webui.utils.redis import get_redis_connection, get_sentinels_from_env - #################### # User 
Credit DB Schema #################### diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 048c10f85c..406adb2559 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -191,6 +191,23 @@ def get_feedback_by_id_and_user_id( except Exception: return None + def get_feedbacks_by_chat_id( + self, chat_id: str, db: Optional[Session] = None + ) -> list[FeedbackModel]: + """Get all feedbacks for a specific chat.""" + try: + with get_db_context(db) as db: + # meta.chat_id stores the chat reference + feedbacks = ( + db.query(Feedback) + .filter(Feedback.meta["chat_id"].as_string() == chat_id) + .order_by(Feedback.created_at.desc()) + .all() + ) + return [FeedbackModel.model_validate(fb) for fb in feedbacks] + except Exception: + return [] + def get_feedback_items( self, filter: dict = {}, @@ -460,23 +477,15 @@ def delete_feedbacks_by_user_id( self, user_id: str, db: Optional[Session] = None ) -> bool: with get_db_context(db) as db: - feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() - if not feedbacks: - return False - for feedback in feedbacks: - db.delete(feedback) + result = db.query(Feedback).filter_by(user_id=user_id).delete() db.commit() - return True + return result > 0 def delete_all_feedbacks(self, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: - feedbacks = db.query(Feedback).all() - if not feedbacks: - return False - for feedback in feedbacks: - db.delete(feedback) + result = db.query(Feedback).delete() db.commit() - return True + return result > 0 Feedbacks = FeedbackTable() diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 4097ae08e1..09060f9bde 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from open_webui.internal.db import Base, JSONField, get_db, get_db_context -from pydantic import BaseModel, 
ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from sqlalchemy import BigInteger, Column, String, Text, JSON log = logging.getLogger(__name__) @@ -26,8 +26,6 @@ class File(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) - created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -45,8 +43,6 @@ class FileModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None - created_at: Optional[int] # timestamp in epoch updated_at: Optional[int] # timestamp in epoch @@ -63,6 +59,25 @@ class FileMeta(BaseModel): model_config = ConfigDict(extra="allow") + @model_validator(mode="before") + @classmethod + def sanitize_meta(cls, data): + """Sanitize metadata fields to handle malformed legacy data.""" + if not isinstance(data, dict): + return data + + # Handle content_type that may be a list like ['application/pdf', None] + content_type = data.get("content_type") + if isinstance(content_type, list): + # Extract first non-None string value + data["content_type"] = next( + (item for item in content_type if isinstance(item, str)), None + ) + elif content_type is not None and not isinstance(content_type, str): + data["content_type"] = None + + return data + class FileModelResponse(BaseModel): id: str @@ -74,7 +89,7 @@ class FileModelResponse(BaseModel): meta: FileMeta created_at: int # timestamp in epoch - updated_at: int # timestamp in epoch + updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files model_config = ConfigDict(extra="allow") @@ -94,7 +109,6 @@ class FileForm(BaseModel): path: str data: dict = {} meta: dict = {} - access_control: Optional[dict] = None class FileUpdateForm(BaseModel): diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 3455208944..24a872bcc4 100644 --- a/backend/open_webui/models/folders.py +++ 
b/backend/open_webui/models/folders.py @@ -11,7 +11,6 @@ from open_webui.internal.db import Base, JSONField, get_db, get_db_context - log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 8e23bac093..fdbfac5e7c 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -195,6 +195,25 @@ def get_function_by_id( except Exception: return None + def get_functions_by_ids( + self, ids: list[str], db: Optional[Session] = None + ) -> list[FunctionModel]: + """ + Batch fetch multiple functions by their IDs in a single query. + Returns functions in the same order as the input IDs (None entries filtered out). + """ + if not ids: + return [] + try: + with get_db_context(db) as db: + functions = db.query(Function).filter(Function.id.in_(ids)).all() + # Create a dict for O(1) lookup + func_dict = {f.id: FunctionModel.model_validate(f) for f in functions} + # Return in original order, filtering out any not found + return [func_dict[id] for id in ids if id in func_dict] + except Exception: + return [] + def get_functions( self, active_only=False, include_valves=False, db: Optional[Session] = None ) -> list[FunctionModel | FunctionWithValvesModel]: @@ -299,7 +318,7 @@ def update_function_valves_by_id( function.updated_at = int(time.time()) db.commit() db.refresh(function) - return self.get_function_by_id(id, db=db) + return FunctionModel.model_validate(function) except Exception: return None @@ -319,7 +338,7 @@ def update_function_metadata_by_id( function.updated_at = int(time.time()) db.commit() db.refresh(function) - return self.get_function_by_id(id, db=db) + return FunctionModel.model_validate(function) else: return None except Exception as e: @@ -381,7 +400,8 @@ def update_function_by_id( } ) db.commit() - return self.get_function_by_id(id, db=db) + function = db.get(Function, id) + return FunctionModel.model_validate(function) if function else None except 
Exception: return None diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index ae557f4daf..8fe720ecc6 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -22,9 +22,9 @@ ForeignKey, cast, or_, + select, ) - log = logging.getLogger(__name__) #################### @@ -99,6 +99,16 @@ class GroupResponse(GroupModel): member_count: Optional[int] = None +class GroupInfoResponse(BaseModel): + id: str + user_id: str + name: str + description: str + member_count: Optional[int] = None + created_at: int + updated_at: int + + class GroupForm(BaseModel): name: str description: str @@ -165,27 +175,28 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse share_value = filter["share"] member_id = filter.get("member_id") json_share = Group.data["config"]["share"] - json_share_bool = json_share.as_boolean() json_share_str = json_share.as_string() + json_share_lower = func.lower(json_share_str) if share_value: - # Groups open to anyone: data is null, share is null, or share is true + # Groups open to anyone: data is null, config.share is null, or share is true + # Use case-insensitive string comparison to handle variations like "True", "TRUE" + # Handle potential JSON boolean to string casting issues by checking for both string 'true' and boolean equivalence if possible, anyone_can_share = or_( Group.data.is_(None), - json_share_bool.is_(None), - json_share_bool == True, + json_share_str.is_(None), + json_share_lower == "true", + json_share_lower == "1", # Handle SQLite boolean true ) if member_id: # Also include member-only groups where user is a member - member_groups_subq = ( - db.query(GroupMember.group_id) - .filter(GroupMember.user_id == member_id) - .subquery() + member_groups_select = select(GroupMember.group_id).where( + GroupMember.user_id == member_id ) members_only_and_is_member = and_( - json_share_str == "members", - Group.id.in_(member_groups_subq), + 
json_share_lower == "members", + Group.id.in_(member_groups_select), ) query = query.filter( or_(anyone_can_share, members_only_and_is_member) @@ -194,7 +205,7 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse query = query.filter(anyone_can_share) else: query = query.filter( - and_(Group.data.isnot(None), json_share_bool == False) + and_(Group.data.isnot(None), json_share_lower == "false") ) else: @@ -205,13 +216,13 @@ def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse ).filter(GroupMember.user_id == filter["member_id"]) groups = query.order_by(Group.updated_at.desc()).all() + group_ids = [group.id for group in groups] + member_counts = self.get_group_member_counts_by_ids(group_ids, db=db) return [ GroupResponse.model_validate( { **GroupModel.model_validate(group).model_dump(), - "member_count": self.get_group_member_count_by_id( - group.id, db=db - ), + "member_count": member_counts.get(group.id, 0), } ) for group in groups @@ -246,12 +257,14 @@ def search_groups( total = query.count() query = query.order_by(Group.updated_at.desc()) groups = query.offset(skip).limit(limit).all() + group_ids = [group.id for group in groups] + member_counts = self.get_group_member_counts_by_ids(group_ids, db=db) return { "items": [ GroupResponse.model_validate( **GroupModel.model_validate(group).model_dump(), - member_count=self.get_group_member_count_by_id(group.id, db=db), + member_count=member_counts.get(group.id, 0), ) for group in groups ], @@ -304,14 +317,14 @@ def get_group_by_id( def get_group_user_ids_by_id( self, id: str, db: Optional[Session] = None - ) -> Optional[list[str]]: + ) -> list[str]: with get_db_context(db) as db: members = ( db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() ) if not members: - return None + return [] return [m[0] for m in members] @@ -368,6 +381,20 @@ def get_group_member_count_by_id( ) return count if count else 0 + def get_group_member_counts_by_ids( + 
self, ids: list[str], db: Optional[Session] = None + ) -> dict[str, int]: + if not ids: + return {} + with get_db_context(db) as db: + rows = ( + db.query(GroupMember.group_id, func.count(GroupMember.user_id)) + .filter(GroupMember.group_id.in_(ids)) + .group_by(GroupMember.group_id) + .all() + ) + return {group_id: count for group_id, count in rows} + def update_group_by_id( self, id: str, @@ -588,11 +615,10 @@ def remove_users_from_group( if not user_ids: return GroupModel.model_validate(group) - # Remove each user from group_member - for user_id in user_ids: - db.query(GroupMember).filter( - GroupMember.group_id == id, GroupMember.user_id == user_id - ).delete() + # Remove users from group_member in batch + db.query(GroupMember).filter( + GroupMember.group_id == id, GroupMember.user_id.in_(user_ids) + ).delete(synchronize_session=False) # Update group timestamp group.updated_at = int(time.time()) diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index 7f99f828c7..1d21d5d910 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -15,9 +15,10 @@ ) from open_webui.models.groups import Groups from open_webui.models.users import User, UserModel, Users, UserResponse +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import ( BigInteger, Column, @@ -29,10 +30,6 @@ or_, ) -from open_webui.utils.access_control import has_access -from open_webui.utils.db.access_control import has_permission - - log = logging.getLogger(__name__) #################### @@ -50,22 +47,6 @@ class Knowledge(Base): description = Column(Text) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. 
- # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -82,7 +63,7 @@ class KnowledgeModel(BaseModel): meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -139,7 +120,7 @@ class KnowledgeUserResponse(KnowledgeUserModel): class KnowledgeForm(BaseModel): name: str description: str - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class FileUserResponse(FileModelResponse): @@ -157,27 +138,47 @@ class KnowledgeFileListResponse(BaseModel): class KnowledgeTable: + def _get_access_grants( + self, knowledge_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db) + + def _to_knowledge_model( + self, knowledge: Knowledge, db: Optional[Session] = None + ) -> KnowledgeModel: + knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump( + exclude={"access_grants"} + ) + knowledge_data["access_grants"] = self._get_access_grants( + knowledge_data["id"], db=db + ) + return KnowledgeModel.model_validate(knowledge_data) + def insert_new_knowledge( self, user_id: str, form_data: KnowledgeForm, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: with get_db_context(db) as db: knowledge = KnowledgeModel( **{ - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "id": str(uuid.uuid4()), "user_id": user_id, "created_at": int(time.time()), 
"updated_at": int(time.time()), + "access_grants": [], } ) try: - result = Knowledge(**knowledge.model_dump()) + result = Knowledge(**knowledge.model_dump(exclude={"access_grants"})) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "knowledge", result.id, form_data.access_grants, db=db + ) if result: - return KnowledgeModel.model_validate(result) + return self._to_knowledge_model(result, db=db) else: return None except Exception: @@ -201,7 +202,7 @@ def get_knowledge_bases( knowledge_bases.append( KnowledgeUserModel.model_validate( { - **KnowledgeModel.model_validate(knowledge).model_dump(), + **self._to_knowledge_model(knowledge, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -229,6 +230,9 @@ def search_knowledge_bases( or_( Knowledge.name.ilike(f"%{query_key}%"), Knowledge.description.ilike(f"%{query_key}%"), + User.name.ilike(f"%{query_key}%"), + User.email.ilike(f"%{query_key}%"), + User.username.ilike(f"%{query_key}%"), ) ) @@ -238,9 +242,16 @@ def search_knowledge_bases( elif view_option == "shared": query = query.filter(Knowledge.user_id != user_id) - query = has_permission(db, Knowledge, query, filter) + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Knowledge, + filter=filter, + resource_type="knowledge", + permission="read", + ) - query = query.order_by(Knowledge.updated_at.desc()) + query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc()) total = query.count() if skip: @@ -255,8 +266,8 @@ def search_knowledge_bases( knowledge_bases.append( KnowledgeUserModel.model_validate( { - **KnowledgeModel.model_validate( - knowledge_base + **self._to_knowledge_model( + knowledge_base, db=db ).model_dump(), "user": ( UserModel.model_validate(user).model_dump() @@ -291,7 +302,14 @@ def search_knowledge_files( # Apply access-control directly to the joined query # This makes the database handle filtering, even with 10k+ KBs - query = has_permission(db, 
Knowledge, query, filter) + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Knowledge, + filter=filter, + resource_type="knowledge", + permission="read", + ) # Apply filename search if filter: @@ -300,7 +318,7 @@ def search_knowledge_files( query = query.filter(File.filename.ilike(f"%{q}%")) # Order by file changes - query = query.order_by(File.updated_at.desc()) + query = query.order_by(File.updated_at.desc(), File.id.asc()) # Count before pagination total = query.count() @@ -324,8 +342,8 @@ def search_knowledge_files( if user else None ), - collection=KnowledgeModel.model_validate( - knowledge + collection=self._to_knowledge_model( + knowledge, db=db ).model_dump(), ) ) @@ -347,7 +365,14 @@ def check_access_by_user_id( user_group_ids = { group.id for group in Groups.get_groups_by_member_id(user_id, db=db) } - return has_access(user_id, permission, knowledge.access_control, user_group_ids) + return AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) def get_knowledge_bases_by_user_id( self, user_id: str, permission: str = "write", db: Optional[Session] = None @@ -360,8 +385,13 @@ def get_knowledge_bases_by_user_id( knowledge_base for knowledge_base in knowledge_bases if knowledge_base.user_id == user_id - or has_access( - user_id, permission, knowledge_base.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, ) ] @@ -371,7 +401,7 @@ def get_knowledge_by_id( try: with get_db_context(db) as db: knowledge = db.query(Knowledge).filter_by(id=id).first() - return KnowledgeModel.model_validate(knowledge) if knowledge else None + return self._to_knowledge_model(knowledge, db=db) if knowledge else None except Exception: return None @@ -388,7 +418,14 @@ def 
get_knowledge_by_id_and_user_id( user_group_ids = { group.id for group in Groups.get_groups_by_member_id(user_id, db=db) } - if has_access(user_id, "write", knowledge.access_control, user_group_ids): + if AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + user_group_ids=user_group_ids, + db=db, + ): return knowledge return None @@ -404,7 +441,8 @@ def get_knowledges_by_file_id( .all() ) return [ - KnowledgeModel.model_validate(knowledge) for knowledge in knowledges + self._to_knowledge_model(knowledge, db=db) + for knowledge in knowledges ] except Exception: return [] @@ -427,6 +465,9 @@ def search_files_by_id( .filter(KnowledgeFile.knowledge_id == knowledge_id) ) + # Default sort: updated_at descending + primary_sort = File.updated_at.desc() + if filter: query_key = filter.get("query") if query_key: @@ -440,27 +481,23 @@ def search_files_by_id( order_by = filter.get("order_by") direction = filter.get("direction") + is_asc = direction == "asc" if order_by == "name": - if direction == "asc": - query = query.order_by(File.filename.asc()) - else: - query = query.order_by(File.filename.desc()) + primary_sort = ( + File.filename.asc() if is_asc else File.filename.desc() + ) elif order_by == "created_at": - if direction == "asc": - query = query.order_by(File.created_at.asc()) - else: - query = query.order_by(File.created_at.desc()) + primary_sort = ( + File.created_at.asc() if is_asc else File.created_at.desc() + ) elif order_by == "updated_at": - if direction == "asc": - query = query.order_by(File.updated_at.asc()) - else: - query = query.order_by(File.updated_at.desc()) - else: - query = query.order_by(File.updated_at.desc()) + primary_sort = ( + File.updated_at.asc() if is_asc else File.updated_at.desc() + ) - else: - query = query.order_by(File.updated_at.desc()) + # Apply sort with secondary key for deterministic pagination + query = query.order_by(primary_sort, File.id.asc()) # Count BEFORE 
pagination total = query.count() @@ -595,11 +632,15 @@ def update_knowledge_by_id( knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "updated_at": int(time.time()), } ) db.commit() + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "knowledge", id, form_data.access_grants, db=db + ) return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) @@ -626,6 +667,7 @@ def update_knowledge_data_by_id( def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("knowledge", id, db=db) db.query(Knowledge).filter_by(id=id).delete() db.commit() return True @@ -635,6 +677,9 @@ def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_all_knowledge(self, db: Optional[Session] = None) -> bool: with get_db_context(db) as db: try: + knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()] + for knowledge_id in knowledge_ids: + AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db) db.query(Knowledge).delete() db.commit() diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index 2dc9656856..e6b70a3020 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -82,7 +82,8 @@ def update_memory_by_id_and_user_id( memory.updated_at = int(time.time()) db.commit() - return self.get_memory_by_id(id) + db.refresh(memory) + return MemoryModel.model_validate(memory) except Exception: return None diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index f7b540d685..113e0d9fe7 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -7,6 +7,8 @@ from open_webui.models.groups import Groups from open_webui.models.users import User, 
UserModel, Users, UserResponse +from open_webui.models.access_grants import AccessGrantModel, AccessGrants + from pydantic import BaseModel, ConfigDict, Field @@ -16,8 +18,6 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy import BigInteger, Column, Text, JSON, Boolean -from open_webui.utils.access_control import has_access - log = logging.getLogger(__name__) @@ -77,23 +77,6 @@ class Model(Base): Holds a JSON encoded blob of metadata, see `ModelMeta`. """ - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. - # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } - price = Column(JSON, nullable=True) is_active = Column(Boolean, default=True) @@ -111,7 +94,7 @@ class ModelModel(BaseModel): params: ModelParams meta: ModelMeta - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) price: Optional[dict] = None @@ -188,32 +171,46 @@ class ModelForm(BaseModel): name: str meta: ModelMeta params: ModelParams - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None price: Optional[ModelPriceForm] = None is_active: bool = True class ModelsTable: + def _get_access_grants( + self, model_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("model", model_id, db=db) + + def _to_model_model(self, model: Model, db: Optional[Session] = None) -> ModelModel: + model_data = ModelModel.model_validate(model).model_dump( + 
exclude={"access_grants"} + ) + model_data["access_grants"] = self._get_access_grants(model_data["id"], db=db) + return ModelModel.model_validate(model_data) + def insert_new_model( self, form_data: ModelForm, user_id: str, db: Optional[Session] = None ) -> Optional[ModelModel]: - model = ModelModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) try: with get_db_context(db) as db: - result = Model(**model.model_dump()) + result = Model( + **{ + **form_data.model_dump(exclude={"access_grants"}), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "model", result.id, form_data.access_grants, db=db + ) if result: - return ModelModel.model_validate(result) + return self._to_model_model(result, db=db) else: return None except Exception as e: @@ -222,7 +219,9 @@ def insert_new_model( def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] + return [ + self._to_model_model(model, db=db) for model in db.query(Model).all() + ] def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: with get_db_context(db) as db: @@ -239,7 +238,7 @@ def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: models.append( ModelUserResponse.model_validate( { - **ModelModel.model_validate(model).model_dump(), + **self._to_model_model(model, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -249,7 +248,7 @@ def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]: with get_db_context(db) as db: return [ - ModelModel.model_validate(model) + self._to_model_model(model, db=db) for model in 
db.query(Model).filter(Model.base_model_id == None).all() ] @@ -264,50 +263,25 @@ def get_models_by_user_id( model for model in models if model.user_id == user_id - or has_access(user_id, permission, model.access_control, user_group_ids) + or AccessGrants.has_access( + user_id=user_id, + resource_type="model", + resource_id=model.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) ] def _has_permission(self, db, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - - dialect_name = db.bind.dialect.name - - # Public access - conditions = [] - if group_ids or user_id: - conditions.extend( - [ - Model.access_control.is_(None), - cast(Model.access_control, String) == "null", - ] - ) - - # User-level permission - if user_id: - conditions.append(Model.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - Model.access_control[permission]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - Model.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query + return AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Model, + filter=filter, + resource_type="model", + permission=permission, + ) def search_models( self, @@ -329,6 +303,9 @@ def search_models( or_( Model.name.ilike(f"%{query_key}%"), Model.base_model_id.ilike(f"%{query_key}%"), + User.name.ilike(f"%{query_key}%"), + User.email.ilike(f"%{query_key}%"), + User.username.ilike(f"%{query_key}%"), ) ) @@ -390,7 +367,7 @@ def search_models( for model, user in items: models.append( ModelUserResponse( - **ModelModel.model_validate(model).model_dump(), + **self._to_model_model(model, 
db=db).model_dump(), user=( UserResponse(**UserModel.model_validate(user).model_dump()) if user @@ -407,7 +384,7 @@ def get_model_by_id( try: with get_db_context(db) as db: model = db.get(Model, id) - return ModelModel.model_validate(model) + return self._to_model_model(model, db=db) if model else None except Exception: return None @@ -417,7 +394,7 @@ def get_models_by_ids( try: with get_db_context(db) as db: models = db.query(Model).filter(Model.id.in_(ids)).all() - return [ModelModel.model_validate(model) for model in models] + return [self._to_model_model(model, db=db) for model in models] except Exception: return [] @@ -426,17 +403,16 @@ def toggle_model_by_id( ) -> Optional[ModelModel]: with get_db_context(db) as db: try: - is_active = db.query(Model).filter_by(id=id).first().is_active + model = db.query(Model).filter_by(id=id).first() + if not model: + return None - db.query(Model).filter_by(id=id).update( - { - "is_active": not is_active, - "updated_at": int(time.time()), - } - ) + model.is_active = not model.is_active + model.updated_at = int(time.time()) db.commit() + db.refresh(model) - return self.get_model_by_id(id, db=db) + return self._to_model_model(model, db=db) except Exception: return None @@ -446,14 +422,16 @@ def update_model_by_id( try: with get_db_context(db) as db: # update only the fields that are present in the model - data = model.model_dump(exclude={"id"}) + data = model.model_dump(exclude={"id", "access_grants"}) result = db.query(Model).filter_by(id=id).update(data) db.commit() + if model.access_grants is not None: + AccessGrants.set_access_grants( + "model", id, model.access_grants, db=db + ) - model = db.get(Model, id) - db.refresh(model) - return ModelModel.model_validate(model) + return self.get_model_by_id(id, db=db) except Exception as e: log.exception(f"Failed to update the model by id {id}: {e}") return None @@ -461,6 +439,7 @@ def update_model_by_id( def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: 
try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("model", id, db=db) db.query(Model).filter_by(id=id).delete() db.commit() @@ -471,6 +450,9 @@ def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_all_models(self, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + model_ids = [row[0] for row in db.query(Model.id).all()] + for model_id in model_ids: + AccessGrants.revoke_all_access("model", model_id, db=db) db.query(Model).delete() db.commit() @@ -495,7 +477,7 @@ def sync_models( if model.id in existing_ids: db.query(Model).filter_by(id=model.id).update( { - **model.model_dump(), + **model.model_dump(exclude={"access_grants"}), "user_id": user_id, "updated_at": int(time.time()), } @@ -503,22 +485,27 @@ def sync_models( else: new_model = Model( **{ - **model.model_dump(), + **model.model_dump(exclude={"access_grants"}), "user_id": user_id, "updated_at": int(time.time()), } ) db.add(new_model) + AccessGrants.set_access_grants( + "model", model.id, model.access_grants, db=db + ) # Remove models that are no longer present for model in existing_models: if model.id not in new_model_ids: + AccessGrants.revoke_all_access("model", model.id, db=db) db.delete(model) db.commit() return [ - ModelModel.model_validate(model) for model in db.query(Model).all() + self._to_model_model(model, db=db) + for model in db.query(Model).all() ] except Exception as e: log.exception(f"Error syncing models for user {user_id}: {e}") diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index bd23530785..d17c749d1c 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -7,17 +7,13 @@ from sqlalchemy.orm import Session from open_webui.internal.db import Base, get_db, get_db_context from open_webui.models.groups import Groups -from open_webui.utils.access_control import has_access from open_webui.models.users import User, UserModel, Users, UserResponse 
+from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON -from sqlalchemy.dialects.postgresql import JSONB - - -from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func -from sqlalchemy.sql import exists +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import BigInteger, Column, Text, JSON +from sqlalchemy import or_, func, cast #################### # Note DB Schema @@ -34,8 +30,6 @@ class Note(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - access_control = Column(JSON, nullable=True) - created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -50,7 +44,7 @@ class NoteModel(BaseModel): data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -65,14 +59,14 @@ class NoteForm(BaseModel): title: str data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class NoteUpdateForm(BaseModel): title: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class NoteUserResponse(NoteModel): @@ -94,122 +88,25 @@ class NoteListResponse(BaseModel): class NoteTable: - def _has_permission(self, db, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - dialect_name = db.bind.dialect.name - - conditions = [] - - # Handle read_only permission separately - if permission == "read_only": - # For read_only, we want items where: - # 1. User has explicit read permission (via groups or user-level) - # 2. 
BUT does NOT have write permission - # 3. Public items are NOT considered read_only - - read_conditions = [] - - # Group-level read permission - if group_ids: - group_read_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_read_conditions.append( - Note.access_control["read"]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_read_conditions.append( - cast( - Note.access_control["read"]["group_ids"], - JSONB, - ).contains([gid]) - ) - - if group_read_conditions: - read_conditions.append(or_(*group_read_conditions)) - - # Combine read conditions - if read_conditions: - has_read = or_(*read_conditions) - else: - # If no read conditions, return empty result - return query.filter(False) - - # Now exclude items where user has write permission - write_exclusions = [] - - # Exclude items owned by user (they have implicit write) - if user_id: - write_exclusions.append(Note.user_id != user_id) - - # Exclude items where user has explicit write permission via groups - if group_ids: - group_write_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_write_conditions.append( - Note.access_control["write"]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_write_conditions.append( - cast( - Note.access_control["write"]["group_ids"], - JSONB, - ).contains([gid]) - ) + def _get_access_grants( + self, note_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("note", note_id, db=db) - if group_write_conditions: - # User should NOT have write permission - write_exclusions.append(~or_(*group_write_conditions)) + def _to_note_model(self, note: Note, db: Optional[Session] = None) -> NoteModel: + note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"}) + note_data["access_grants"] = self._get_access_grants(note_data["id"], db=db) + return NoteModel.model_validate(note_data) - # Exclude public items 
(items without access_control) - write_exclusions.append(Note.access_control.isnot(None)) - write_exclusions.append(cast(Note.access_control, String) != "null") - - # Combine: has read AND does not have write AND not public - if write_exclusions: - query = query.filter(and_(has_read, *write_exclusions)) - else: - query = query.filter(has_read) - - return query - - # Original logic for other permissions (read, write, etc.) - # Public access conditions - if group_ids or user_id: - conditions.extend( - [ - Note.access_control.is_(None), - cast(Note.access_control, String) == "null", - ] - ) - - # User-level permission (owner has all permissions) - if user_id: - conditions.append(Note.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - Note.access_control[permission]["group_ids"].contains([gid]) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - Note.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query + def _has_permission(self, db, query, filter: dict, permission: str = "read"): + return AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Note, + filter=filter, + resource_type="note", + permission=permission, + ) def insert_new_note( self, user_id: str, form_data: NoteForm, db: Optional[Session] = None @@ -219,17 +116,21 @@ def insert_new_note( **{ "id": str(uuid.uuid4()), "user_id": user_id, - **form_data.model_dump(), + **form_data.model_dump(exclude={"access_grants"}), "created_at": int(time.time_ns()), "updated_at": int(time.time_ns()), + "access_grants": [], } ) - new_note = Note(**note.model_dump()) + new_note = Note(**note.model_dump(exclude={"access_grants"})) db.add(new_note) db.commit() - return note + AccessGrants.set_access_grants( + "note", 
note.id, form_data.access_grants, db=db + ) + return self._to_note_model(new_note, db=db) def get_notes( self, skip: int = 0, limit: int = 50, db: Optional[Session] = None @@ -241,7 +142,7 @@ def get_notes( if limit is not None: query = query.limit(limit) notes = query.all() - return [NoteModel.model_validate(note) for note in notes] + return [self._to_note_model(note, db=db) for note in notes] def search_notes( self, @@ -330,7 +231,7 @@ def search_notes( for note, user in items: notes.append( NoteUserResponse( - **NoteModel.model_validate(note).model_dump(), + **self._to_note_model(note, db=db).model_dump(), user=( UserResponse(**UserModel.model_validate(user).model_dump()) if user @@ -365,14 +266,14 @@ def get_notes_by_user_id( query = query.limit(limit) notes = query.all() - return [NoteModel.model_validate(note) for note in notes] + return [self._to_note_model(note, db=db) for note in notes] def get_note_by_id( self, id: str, db: Optional[Session] = None ) -> Optional[NoteModel]: with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() - return NoteModel.model_validate(note) if note else None + return self._to_note_model(note, db=db) if note else None def update_note_by_id( self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None @@ -391,17 +292,20 @@ def update_note_by_id( if "meta" in form_data: note.meta = {**note.meta, **form_data["meta"]} - if "access_control" in form_data: - note.access_control = form_data["access_control"] + if "access_grants" in form_data: + AccessGrants.set_access_grants( + "note", id, form_data["access_grants"], db=db + ) note.updated_at = int(time.time_ns()) db.commit() - return NoteModel.model_validate(note) if note else None + return self._to_note_model(note, db=db) if note else None def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("note", id, db=db) db.query(Note).filter(Note.id == id).delete() 
"""Prompt history model for version tracking.

Each row is an immutable snapshot ("commit") of a prompt taken at save
time. Rows form a simple parent chain via ``parent_id`` so versions can
be listed, diffed, and restored.
"""

import time
import uuid
from typing import Optional
import json
import difflib

from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db_context
from open_webui.models.users import Users, UserResponse

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Index

####################
# PromptHistory DB Schema
####################


class PromptHistory(Base):
    __tablename__ = "prompt_history"

    id = Column(Text, primary_key=True)
    prompt_id = Column(Text, nullable=False, index=True)
    parent_id = Column(Text, nullable=True)  # Reference to parent commit
    snapshot = Column(JSON, nullable=False)  # Full prompt state at commit time
    user_id = Column(Text, nullable=False)  # Author of the commit
    commit_message = Column(Text, nullable=True)
    created_at = Column(BigInteger, nullable=False)  # epoch seconds


class PromptHistoryModel(BaseModel):
    id: str
    prompt_id: str
    parent_id: Optional[str] = None
    snapshot: dict
    user_id: str
    commit_message: Optional[str] = None
    created_at: int

    model_config = ConfigDict(from_attributes=True)


class PromptHistoryResponse(PromptHistoryModel):
    """Response model with user info."""

    user: Optional[UserResponse] = None


class PromptHistoryTable:
    def create_history_entry(
        self,
        prompt_id: str,
        snapshot: dict,
        user_id: str,
        parent_id: Optional[str] = None,
        commit_message: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Optional[PromptHistoryModel]:
        """Create a new history entry (commit) for a prompt."""
        with get_db_context(db) as db:
            history = PromptHistory(
                id=str(uuid.uuid4()),
                prompt_id=prompt_id,
                parent_id=parent_id,
                snapshot=snapshot,
                user_id=user_id,
                commit_message=commit_message,
                created_at=int(time.time()),
            )
            db.add(history)
            db.commit()
            db.refresh(history)
            return PromptHistoryModel.model_validate(history)

    def get_history_by_prompt_id(
        self,
        prompt_id: str,
        limit: int = 50,
        offset: int = 0,
        db: Optional[Session] = None,
    ) -> list[PromptHistoryResponse]:
        """Get history entries for a prompt, newest first, with author info."""
        with get_db_context(db) as db:
            entries = (
                db.query(PromptHistory)
                .filter(PromptHistory.prompt_id == prompt_id)
                .order_by(PromptHistory.created_at.desc())
                .offset(offset)
                .limit(limit)
                .all()
            )

            # Resolve authors in one batch instead of one query per entry.
            user_ids = list(set(e.user_id for e in entries))
            users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
            users_dict = {user.id: user for user in users}

            responses = []
            for entry in entries:
                user = users_dict.get(entry.user_id)
                responses.append(
                    PromptHistoryResponse(
                        **PromptHistoryModel.model_validate(entry).model_dump(),
                        user=user.model_dump() if user else None,
                    )
                )
            return responses

    def get_history_entry_by_id(
        self,
        history_id: str,
        db: Optional[Session] = None,
    ) -> Optional[PromptHistoryModel]:
        """Get a specific history entry by ID, or None if missing."""
        with get_db_context(db) as db:
            entry = (
                db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
            )
            if entry:
                return PromptHistoryModel.model_validate(entry)
            return None

    def get_latest_history_entry(
        self,
        prompt_id: str,
        db: Optional[Session] = None,
    ) -> Optional[PromptHistoryModel]:
        """Get the most recent history entry for a prompt, or None."""
        with get_db_context(db) as db:
            entry = (
                db.query(PromptHistory)
                .filter(PromptHistory.prompt_id == prompt_id)
                .order_by(PromptHistory.created_at.desc())
                .first()
            )
            if entry:
                return PromptHistoryModel.model_validate(entry)
            return None

    def get_history_count(
        self,
        prompt_id: str,
        db: Optional[Session] = None,
    ) -> int:
        """Get the number of history entries for a prompt."""
        with get_db_context(db) as db:
            return (
                db.query(PromptHistory)
                .filter(PromptHistory.prompt_id == prompt_id)
                .count()
            )

    def compute_diff(
        self,
        from_id: str,
        to_id: str,
        db: Optional[Session] = None,
    ) -> Optional[dict]:
        """Compute a unified diff between two history entries.

        Returns None when either entry is missing. ``content_diff`` is a
        list of diff lines without trailing newlines.
        """
        with get_db_context(db) as db:
            from_entry = (
                db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
            )
            to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()

            if not from_entry or not to_entry:
                return None

            from_snapshot = from_entry.snapshot
            to_snapshot = to_entry.snapshot

            # Diff only the content field; other fields are reported via flags.
            from_content = from_snapshot.get("content", "")
            to_content = to_snapshot.get("content", "")

            # Fix: feed newline-free lines to match lineterm="". The previous
            # splitlines(keepends=True) + lineterm="" combination produced
            # header lines without newlines but body lines with them,
            # yielding inconsistent output.
            diff_lines = list(
                difflib.unified_diff(
                    from_content.splitlines(),
                    to_content.splitlines(),
                    fromfile=f"v{from_id[:8]}",
                    tofile=f"v{to_id[:8]}",
                    lineterm="",
                )
            )

            return {
                "from_id": from_id,
                "to_id": to_id,
                "from_snapshot": from_snapshot,
                "to_snapshot": to_snapshot,
                "content_diff": diff_lines,
                "name_changed": from_snapshot.get("name") != to_snapshot.get("name"),
            }

    def delete_history_by_prompt_id(
        self,
        prompt_id: str,
        db: Optional[Session] = None,
    ) -> bool:
        """Delete all history entries for a prompt."""
        with get_db_context(db) as db:
            db.query(PromptHistory).filter(
                PromptHistory.prompt_id == prompt_id
            ).delete()
            db.commit()
            return True

    def delete_history_entry(
        self,
        history_id: str,
        db: Optional[Session] = None,
    ) -> bool:
        """Delete a history entry and reparent its children to the grandparent.

        Keeps the commit chain intact when a middle entry is removed.
        """
        with get_db_context(db) as db:
            entry = db.query(PromptHistory).filter_by(id=history_id).first()
            if not entry:
                return False

            # Find children that reference this entry as parent.
            children = db.query(PromptHistory).filter_by(parent_id=history_id).all()

            # Reparent children to the deleted entry's parent (grandparent).
            for child in children:
                child.parent_id = entry.parent_id

            db.delete(entry)
            db.commit()
            return True


PromptHistories = PromptHistoryTable()
class Prompt(Base):
    """ORM row for a saved prompt.

    Prompts carry a stable UUID primary key (``id``) alongside the unique
    slash ``command``, plus soft-delete (``is_active``) and a pointer to
    the currently-active history entry (``version_id``). Access control
    is handled by the shared access-grants table, not a JSON column.
    """

    __tablename__ = "prompt"

    id = Column(Text, primary_key=True)
    command = Column(String, unique=True, index=True)
    user_id = Column(String)
    name = Column(Text)
    content = Column(Text)
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)
    tags = Column(JSON, nullable=True)
    is_active = Column(Boolean, default=True)
    version_id = Column(Text, nullable=True)  # Points to active history entry
    created_at = Column(BigInteger, nullable=True)  # epoch seconds
    updated_at = Column(BigInteger, nullable=True)  # epoch seconds


class PromptModel(BaseModel):
    """Pydantic view of a prompt row, including its access grants."""

    id: Optional[str] = None
    command: str
    user_id: str
    name: str
    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    tags: Optional[list[str]] = None
    is_active: Optional[bool] = True
    version_id: Optional[str] = None
    created_at: Optional[int] = None
    updated_at: Optional[int] = None
    access_grants: list[AccessGrantModel] = Field(default_factory=list)

    model_config = ConfigDict(from_attributes=True)


class PromptListResponse(BaseModel):
    """Paginated prompt search result."""

    items: list[PromptUserResponse]
    total: int


class PromptAccessListResponse(BaseModel):
    """Paginated prompt search result with per-item write-access flags."""

    items: list[PromptAccessResponse]
    total: int


class PromptForm(BaseModel):
    """Create/update payload for a prompt."""

    command: str
    name: str  # Renamed from `title`
    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    tags: Optional[list[str]] = None
    access_grants: Optional[list[dict]] = None
    version_id: Optional[str] = None  # Active version
    commit_message: Optional[str] = None  # For history tracking
    is_production: Optional[bool] = True  # Promote the new version on save


class PromptsTable:
    def _get_access_grants(
        self, prompt_id: str, db: Optional[Session] = None
    ) -> list[AccessGrantModel]:
        """Return the access-grant rows attached to the given prompt."""
        return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db)

    def _to_prompt_model(
        self, prompt: Prompt, db: Optional[Session] = None
    ) -> PromptModel:
        """Convert an ORM ``Prompt`` row into a ``PromptModel`` with grants."""
        payload = PromptModel.model_validate(prompt).model_dump(
            exclude={"access_grants"}
        )
        payload["access_grants"] = self._get_access_grants(payload["id"], db=db)
        return PromptModel.model_validate(payload)

    def insert_new_prompt(
        self, user_id: str, form_data: PromptForm, db: Optional[Session] = None
    ) -> Optional[PromptModel]:
        """Create a prompt, persist its grants, and record the initial version.

        Returns the stored prompt (with grants attached) or None on failure.
        """
        now = int(time.time())
        prompt_id = str(uuid.uuid4())

        prompt = PromptModel(
            id=prompt_id,
            user_id=user_id,
            command=form_data.command,
            name=form_data.name,
            content=form_data.content,
            data=form_data.data or {},
            meta=form_data.meta or {},
            tags=form_data.tags or [],
            access_grants=[],
            is_active=True,
            created_at=now,
            updated_at=now,
        )

        try:
            with get_db_context(db) as db:
                row = Prompt(**prompt.model_dump(exclude={"access_grants"}))
                db.add(row)
                db.commit()
                db.refresh(row)

                AccessGrants.set_access_grants(
                    "prompt", prompt_id, form_data.access_grants, db=db
                )

                if not row:
                    return None

                # Record the initial commit in the prompt's history.
                grants = self._get_access_grants(prompt_id, db=db)
                snapshot = {
                    "name": form_data.name,
                    "content": form_data.content,
                    "command": form_data.command,
                    "data": form_data.data or {},
                    "meta": form_data.meta or {},
                    "tags": form_data.tags or [],
                    "access_grants": [grant.model_dump() for grant in grants],
                }
                history_entry = PromptHistories.create_history_entry(
                    prompt_id=prompt_id,
                    snapshot=snapshot,
                    user_id=user_id,
                    parent_id=None,  # Initial commit has no parent
                    commit_message=form_data.commit_message or "Initial version",
                    db=db,
                )

                # The first commit becomes the production version.
                if history_entry:
                    row.version_id = history_entry.id
                    db.commit()
                    db.refresh(row)

                return self._to_prompt_model(row, db=db)
        except Exception:
            # Matches the table-wide convention of returning None on DB errors.
            return None

    def get_prompt_by_id(
        self, prompt_id: str, db: Optional[Session] = None
    ) -> Optional[PromptModel]:
        """Get a prompt by UUID, or None if missing / on error."""
        try:
            with get_db_context(db) as db:
                prompt = db.query(Prompt).filter_by(id=prompt_id).first()
                return self._to_prompt_model(prompt, db=db) if prompt else None
        except Exception:
            return None

    def get_prompt_by_command(
        self, command: str, db: Optional[Session] = None
    ) -> Optional[PromptModel]:
        """Get a prompt by its slash command, or None if missing / on error."""
        try:
            with get_db_context(db) as db:
                prompt = db.query(Prompt).filter_by(command=command).first()
                return self._to_prompt_model(prompt, db=db) if prompt else None
        except Exception:
            return None

    def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
        """Return all active prompts, newest first, with author info attached."""
        with get_db_context(db) as db:
            all_prompts = (
                db.query(Prompt)
                .filter(Prompt.is_active == True)
                .order_by(Prompt.updated_at.desc())
                .all()
            )

            # Batch-resolve authors for the whole result set.
            user_ids = list(set(prompt.user_id for prompt in all_prompts))
            users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
            users_dict = {user.id: user for user in users}

            prompts = []
            for prompt in all_prompts:
                user = users_dict.get(prompt.user_id)
                prompts.append(
                    PromptUserResponse.model_validate(
                        {
                            **self._to_prompt_model(prompt, db=db).model_dump(),
                            "user": user.model_dump() if user else None,
                        }
                    )
                )
            return prompts

    def get_prompts_by_user_id(
        self, user_id: str, permission: str = "write", db: Optional[Session] = None
    ) -> list[PromptUserResponse]:
        """Return prompts the user owns or holds ``permission`` on via grants."""
        prompts = self.get_prompts(db=db)
        user_group_ids = {
            group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
        }
        return [
            prompt
            for prompt in prompts
            if prompt.user_id == user_id
            or AccessGrants.has_access(
                user_id=user_id,
                resource_type="prompt",
                resource_id=prompt.id,
                permission=permission,
                user_group_ids=user_group_ids,
                db=db,
            )
        ]
permission=permission, + user_group_ids=user_group_ids, + db=db, + ) ] + def search_prompts( + self, + user_id: str, + filter: dict = {}, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, + ) -> PromptListResponse: + with get_db_context(db) as db: + from open_webui.models.users import User, UserModel + + # Join with User table for user filtering and sorting + query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id) + query = query.filter(Prompt.is_active == True) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter( + or_( + Prompt.name.ilike(f"%{query_key}%"), + Prompt.command.ilike(f"%{query_key}%"), + Prompt.content.ilike(f"%{query_key}%"), + User.name.ilike(f"%{query_key}%"), + User.email.ilike(f"%{query_key}%"), + ) + ) + + view_option = filter.get("view_option") + if view_option == "created": + query = query.filter(Prompt.user_id == user_id) + elif view_option == "shared": + query = query.filter(Prompt.user_id != user_id) + + # Apply access grant filtering + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Prompt, + filter=filter, + resource_type="prompt", + permission="read", + ) + + tag = filter.get("tag") + if tag: + # Search for tag in JSON array field + like_pattern = f'%"{tag.lower()}"%' + tags_text = func.lower(cast(Prompt.tags, String)) + query = query.filter(tags_text.like(like_pattern)) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by == "name": + if direction == "asc": + query = query.order_by(Prompt.name.asc()) + else: + query = query.order_by(Prompt.name.desc()) + elif order_by == "created_at": + if direction == "asc": + query = query.order_by(Prompt.created_at.asc()) + else: + query = query.order_by(Prompt.created_at.desc()) + elif order_by == "updated_at": + if direction == "asc": + query = query.order_by(Prompt.updated_at.asc()) + else: + query = query.order_by(Prompt.updated_at.desc()) + else: + 
query = query.order_by(Prompt.updated_at.desc()) + else: + query = query.order_by(Prompt.updated_at.desc()) + + # Count BEFORE pagination + total = query.count() + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + items = query.all() + + prompts = [] + for prompt, user in items: + prompts.append( + PromptUserResponse( + **self._to_prompt_model(prompt, db=db).model_dump(), + user=( + UserResponse(**UserModel.model_validate(user).model_dump()) + if user + else None + ), + ) + ) + + return PromptListResponse(items=prompts, total=total) + def update_prompt_by_command( - self, command: str, form_data: PromptForm, db: Optional[Session] = None + self, + command: str, + form_data: PromptForm, + user_id: str, + db: Optional[Session] = None, ) -> Optional[PromptModel]: try: with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title + if not prompt: + return None + + latest_history = PromptHistories.get_latest_history_entry( + prompt.id, db=db + ) + parent_id = latest_history.id if latest_history else None + current_access_grants = self._get_access_grants(prompt.id, db=db) + + # Check if content changed to decide on history creation + content_changed = ( + prompt.name != form_data.name + or prompt.content != form_data.content + or form_data.access_grants is not None + ) + + # Update prompt fields + prompt.name = form_data.name + prompt.content = form_data.content + prompt.data = form_data.data or prompt.data + prompt.meta = form_data.meta or prompt.meta + prompt.updated_at = int(time.time()) + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "prompt", prompt.id, form_data.access_grants, db=db + ) + current_access_grants = self._get_access_grants(prompt.id, db=db) + + db.commit() + + # Create history entry only if content changed + if content_changed: + snapshot = { + "name": form_data.name, + "content": form_data.content, + "command": command, + 
"data": form_data.data or {}, + "meta": form_data.meta or {}, + "access_grants": [ + grant.model_dump() for grant in current_access_grants + ], + } + + history_entry = PromptHistories.create_history_entry( + prompt_id=prompt.id, + snapshot=snapshot, + user_id=user_id, + parent_id=parent_id, + commit_message=form_data.commit_message, + db=db, + ) + + # Set as production if flag is True (default) + if form_data.is_production and history_entry: + prompt.version_id = history_entry.id + db.commit() + + return self._to_prompt_model(prompt, db=db) + except Exception: + return None + + def update_prompt_by_id( + self, + prompt_id: str, + form_data: PromptForm, + user_id: str, + db: Optional[Session] = None, + ) -> Optional[PromptModel]: + try: + with get_db_context(db) as db: + prompt = db.query(Prompt).filter_by(id=prompt_id).first() + if not prompt: + return None + + latest_history = PromptHistories.get_latest_history_entry( + prompt.id, db=db + ) + parent_id = latest_history.id if latest_history else None + current_access_grants = self._get_access_grants(prompt.id, db=db) + + # Check if content changed to decide on history creation + content_changed = ( + prompt.name != form_data.name + or prompt.command != form_data.command + or prompt.content != form_data.content + or form_data.access_grants is not None + or (form_data.tags is not None and prompt.tags != form_data.tags) + ) + + # Update prompt fields + prompt.name = form_data.name + prompt.command = form_data.command prompt.content = form_data.content - prompt.access_control = form_data.access_control - prompt.timestamp = int(time.time()) + prompt.data = form_data.data or prompt.data + prompt.meta = form_data.meta or prompt.meta + + if form_data.tags is not None: + prompt.tags = form_data.tags + + if form_data.access_grants is not None: + AccessGrants.set_access_grants( + "prompt", prompt.id, form_data.access_grants, db=db + ) + current_access_grants = self._get_access_grants(prompt.id, db=db) + + prompt.updated_at = 
int(time.time()) + + db.commit() + + # Create history entry only if content changed + if content_changed: + snapshot = { + "name": form_data.name, + "content": form_data.content, + "command": prompt.command, + "data": form_data.data or {}, + "meta": form_data.meta or {}, + "tags": prompt.tags or [], + "access_grants": [ + grant.model_dump() for grant in current_access_grants + ], + } + + history_entry = PromptHistories.create_history_entry( + prompt_id=prompt.id, + snapshot=snapshot, + user_id=user_id, + parent_id=parent_id, + commit_message=form_data.commit_message, + db=db, + ) + + # Set as production if flag is True (default) + if form_data.is_production and history_entry: + prompt.version_id = history_entry.id + db.commit() + + return self._to_prompt_model(prompt, db=db) + except Exception: + return None + + def update_prompt_metadata( + self, + prompt_id: str, + name: str, + command: str, + tags: Optional[list[str]] = None, + db: Optional[Session] = None, + ) -> Optional[PromptModel]: + """Update only name, command, and tags (no history created).""" + try: + with get_db_context(db) as db: + prompt = db.query(Prompt).filter_by(id=prompt_id).first() + if not prompt: + return None + + prompt.name = name + prompt.command = command + + if tags is not None: + prompt.tags = tags + + prompt.updated_at = int(time.time()) + db.commit() + + return self._to_prompt_model(prompt, db=db) + except Exception: + return None + + def update_prompt_version( + self, + prompt_id: str, + version_id: str, + db: Optional[Session] = None, + ) -> Optional[PromptModel]: + """Set the active version of a prompt and restore content from that version's snapshot.""" + try: + with get_db_context(db) as db: + prompt = db.query(Prompt).filter_by(id=prompt_id).first() + if not prompt: + return None + + history_entry = PromptHistories.get_history_entry_by_id( + version_id, db=db + ) + + if not history_entry: + return None + + # Restore prompt content from the snapshot + snapshot = 
history_entry.snapshot + if snapshot: + prompt.name = snapshot.get("name", prompt.name) + prompt.content = snapshot.get("content", prompt.content) + prompt.data = snapshot.get("data", prompt.data) + prompt.meta = snapshot.get("meta", prompt.meta) + prompt.tags = snapshot.get("tags", prompt.tags) + # Note: command and access_grants are not restored from snapshot + + prompt.version_id = version_id + prompt.updated_at = int(time.time()) db.commit() - return PromptModel.model_validate(prompt) + + return self._to_prompt_model(prompt, db=db) except Exception: return None def delete_prompt_by_command( self, command: str, db: Optional[Session] = None ) -> bool: + """Soft delete a prompt by setting is_active to False.""" try: with get_db_context(db) as db: - db.query(Prompt).filter_by(command=command).delete() - db.commit() + prompt = db.query(Prompt).filter_by(command=command).first() + if prompt: + PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) + AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + + prompt.is_active = False + prompt.updated_at = int(time.time()) + db.commit() + return True + return False + except Exception: + return False - return True + def delete_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> bool: + """Soft delete a prompt by setting is_active to False.""" + try: + with get_db_context(db) as db: + prompt = db.query(Prompt).filter_by(id=prompt_id).first() + if prompt: + PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) + AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + + prompt.is_active = False + prompt.updated_at = int(time.time()) + db.commit() + return True + return False + except Exception: + return False + + def hard_delete_prompt_by_command( + self, command: str, db: Optional[Session] = None + ) -> bool: + """Permanently delete a prompt and its history.""" + try: + with get_db_context(db) as db: + prompt = db.query(Prompt).filter_by(command=command).first() + if prompt: + 
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) + AccessGrants.revoke_all_access("prompt", prompt.id, db=db) + + # Delete prompt + db.query(Prompt).filter_by(command=command).delete() + db.commit() + return True + return False except Exception: return False + def get_tags(self, db: Optional[Session] = None) -> list[str]: + try: + with get_db_context(db) as db: + prompts = db.query(Prompt).filter_by(is_active=True).all() + tags = set() + for prompt in prompts: + if prompt.tags: + for tag in prompt.tags: + if tag: + tags.add(tag) + return sorted(list(tags)) + except Exception: + return [] + Prompts = PromptsTable() diff --git a/backend/open_webui/models/skills.py b/backend/open_webui/models/skills.py new file mode 100644 index 0000000000..71e8f97b31 --- /dev/null +++ b/backend/open_webui/models/skills.py @@ -0,0 +1,339 @@ +import logging +import time +from typing import Optional + +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context +from open_webui.models.users import Users, UserResponse +from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrantModel, AccessGrants + +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import BigInteger, Boolean, Column, String, Text, or_ + +log = logging.getLogger(__name__) + +#################### +# Skills DB Schema +#################### + + +class Skill(Base): + __tablename__ = "skill" + + id = Column(String, primary_key=True, unique=True) + user_id = Column(String) + name = Column(Text, unique=True) + description = Column(Text, nullable=True) + content = Column(Text) + meta = Column(JSONField) + is_active = Column(Boolean, default=True) + + updated_at = Column(BigInteger) + created_at = Column(BigInteger) + + +class SkillMeta(BaseModel): + tags: Optional[list[str]] = [] + + +class SkillModel(BaseModel): + id: str + user_id: str + name: str + description: Optional[str] = None + content: str + meta: 
SkillMeta + is_active: bool = True + access_grants: list[AccessGrantModel] = Field(default_factory=list) + + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + model_config = ConfigDict(from_attributes=True) + + +#################### +# Forms +#################### + + +class SkillUserModel(SkillModel): + user: Optional[UserResponse] = None + + +class SkillResponse(BaseModel): + id: str + user_id: str + name: str + description: Optional[str] = None + meta: SkillMeta + is_active: bool = True + access_grants: list[AccessGrantModel] = Field(default_factory=list) + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class SkillUserResponse(SkillResponse): + user: Optional[UserResponse] = None + + model_config = ConfigDict(extra="allow") + + +class SkillAccessResponse(SkillUserResponse): + write_access: Optional[bool] = False + + +class SkillForm(BaseModel): + id: str + name: str + description: Optional[str] = None + content: str + meta: SkillMeta = SkillMeta() + is_active: bool = True + access_grants: Optional[list[dict]] = None + + +class SkillListResponse(BaseModel): + items: list[SkillUserResponse] = [] + total: int = 0 + + +class SkillAccessListResponse(BaseModel): + items: list[SkillAccessResponse] = [] + total: int = 0 + + +class SkillsTable: + def _get_access_grants( + self, skill_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("skill", skill_id, db=db) + + def _to_skill_model(self, skill: Skill, db: Optional[Session] = None) -> SkillModel: + skill_data = SkillModel.model_validate(skill).model_dump( + exclude={"access_grants"} + ) + skill_data["access_grants"] = self._get_access_grants(skill_data["id"], db=db) + return SkillModel.model_validate(skill_data) + + def insert_new_skill( + self, + user_id: str, + form_data: SkillForm, + db: Optional[Session] = None, + ) -> Optional[SkillModel]: + with get_db_context(db) as db: + try: + result = 
Skill( + **{ + **form_data.model_dump(exclude={"access_grants"}), + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + db.add(result) + db.commit() + db.refresh(result) + AccessGrants.set_access_grants( + "skill", result.id, form_data.access_grants, db=db + ) + if result: + return self._to_skill_model(result, db=db) + else: + return None + except Exception as e: + log.exception(f"Error creating a new skill: {e}") + return None + + def get_skill_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[SkillModel]: + try: + with get_db_context(db) as db: + skill = db.get(Skill, id) + return self._to_skill_model(skill, db=db) if skill else None + except Exception: + return None + + def get_skill_by_name( + self, name: str, db: Optional[Session] = None + ) -> Optional[SkillModel]: + try: + with get_db_context(db) as db: + skill = db.query(Skill).filter_by(name=name).first() + return self._to_skill_model(skill, db=db) if skill else None + except Exception: + return None + + def get_skills(self, db: Optional[Session] = None) -> list[SkillUserModel]: + with get_db_context(db) as db: + all_skills = db.query(Skill).order_by(Skill.updated_at.desc()).all() + + user_ids = list(set(skill.user_id for skill in all_skills)) + + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] + users_dict = {user.id: user for user in users} + + skills = [] + for skill in all_skills: + user = users_dict.get(skill.user_id) + skills.append( + SkillUserModel.model_validate( + { + **self._to_skill_model(skill, db=db).model_dump(), + "user": user.model_dump() if user else None, + } + ) + ) + return skills + + def get_skills_by_user_id( + self, user_id: str, permission: str = "write", db: Optional[Session] = None + ) -> list[SkillUserModel]: + skills = self.get_skills(db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } + + return [ + skill + for skill in skills + if 
skill.user_id == user_id + or AccessGrants.has_access( + user_id=user_id, + resource_type="skill", + resource_id=skill.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) + ] + + def search_skills( + self, + user_id: str, + filter: dict = {}, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, + ) -> SkillListResponse: + try: + with get_db_context(db) as db: + from open_webui.models.users import User, UserModel + + # Join with User table for user filtering + query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id) + + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter( + or_( + Skill.name.ilike(f"%{query_key}%"), + Skill.description.ilike(f"%{query_key}%"), + Skill.id.ilike(f"%{query_key}%"), + User.name.ilike(f"%{query_key}%"), + User.email.ilike(f"%{query_key}%"), + ) + ) + + view_option = filter.get("view_option") + if view_option == "created": + query = query.filter(Skill.user_id == user_id) + elif view_option == "shared": + query = query.filter(Skill.user_id != user_id) + + # Apply access grant filtering + query = AccessGrants.has_permission_filter( + db=db, + query=query, + DocumentModel=Skill, + filter=filter, + resource_type="skill", + permission="read", + ) + + query = query.order_by(Skill.updated_at.desc()) + + # Count BEFORE pagination + total = query.count() + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + items = query.all() + + skills = [] + for skill, user in items: + skills.append( + SkillUserResponse( + **self._to_skill_model(skill, db=db).model_dump(), + user=( + UserResponse( + **UserModel.model_validate(user).model_dump() + ) + if user + else None + ), + ) + ) + + return SkillListResponse(items=skills, total=total) + except Exception as e: + log.exception(f"Error searching skills: {e}") + return SkillListResponse(items=[], total=0) + + def update_skill_by_id( + self, id: str, updated: dict, db: Optional[Session] = None + ) 
-> Optional[SkillModel]: + try: + with get_db_context(db) as db: + access_grants = updated.pop("access_grants", None) + db.query(Skill).filter_by(id=id).update( + {**updated, "updated_at": int(time.time())} + ) + db.commit() + if access_grants is not None: + AccessGrants.set_access_grants("skill", id, access_grants, db=db) + + skill = db.query(Skill).get(id) + db.refresh(skill) + return self._to_skill_model(skill, db=db) + except Exception: + return None + + def toggle_skill_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[SkillModel]: + with get_db_context(db) as db: + try: + skill = db.query(Skill).filter_by(id=id).first() + if not skill: + return None + + skill.is_active = not skill.is_active + skill.updated_at = int(time.time()) + db.commit() + db.refresh(skill) + + return self._to_skill_model(skill, db=db) + except Exception: + return None + + def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool: + try: + with get_db_context(db) as db: + AccessGrants.revoke_all_access("skill", id, db=db) + db.query(Skill).filter_by(id=id).delete() + db.commit() + + return True + except Exception: + return False + + +Skills = SkillsTable() diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index cd7d0bd1a0..eaac4c385d 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -6,12 +6,10 @@ from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import Users, UserResponse from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrantModel, AccessGrants -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text, JSON - -from open_webui.utils.access_control import has_access - +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import BigInteger, Column, String, Text log = logging.getLogger(__name__) @@ -31,23 +29,6 @@ class 
Tool(Base): meta = Column(JSONField) valves = Column(JSONField) - access_control = Column(JSON, nullable=True) # Controls data access levels. - # Defines access control rules for this entry. - # - `None`: Public access, available to all users with the "user" role. - # - `{}`: Private access, restricted exclusively to the owner. - # - Custom permissions: Specific access control for reading and writing; - # Can specify group or user-level restrictions: - # { - # "read": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # }, - # "write": { - # "group_ids": ["group_id1", "group_id2"], - # "user_ids": ["user_id1", "user_id2"] - # } - # } - updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -64,7 +45,7 @@ class ToolModel(BaseModel): content: str specs: list[dict] meta: ToolMeta - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -86,7 +67,7 @@ class ToolResponse(BaseModel): user_id: str name: str meta: ToolMeta - access_control: Optional[dict] = None + access_grants: list[AccessGrantModel] = Field(default_factory=list) updated_at: int # timestamp in epoch created_at: int # timestamp in epoch @@ -106,7 +87,7 @@ class ToolForm(BaseModel): name: str content: str meta: ToolMeta - access_control: Optional[dict] = None + access_grants: Optional[list[dict]] = None class ToolValves(BaseModel): @@ -114,6 +95,16 @@ class ToolValves(BaseModel): class ToolsTable: + def _get_access_grants( + self, tool_id: str, db: Optional[Session] = None + ) -> list[AccessGrantModel]: + return AccessGrants.get_grants_by_resource("tool", tool_id, db=db) + + def _to_tool_model(self, tool: Tool, db: Optional[Session] = None) -> ToolModel: + tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"}) + tool_data["access_grants"] = self._get_access_grants(tool_data["id"], db=db) + return 
ToolModel.model_validate(tool_data) + def insert_new_tool( self, user_id: str, @@ -122,23 +113,24 @@ def insert_new_tool( db: Optional[Session] = None, ) -> Optional[ToolModel]: with get_db_context(db) as db: - tool = ToolModel( - **{ - **form_data.model_dump(), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), - } - ) - try: - result = Tool(**tool.model_dump()) + result = Tool( + **{ + **form_data.model_dump(exclude={"access_grants"}), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) db.add(result) db.commit() db.refresh(result) + AccessGrants.set_access_grants( + "tool", result.id, form_data.access_grants, db=db + ) if result: - return ToolModel.model_validate(result) + return self._to_tool_model(result, db=db) else: return None except Exception as e: @@ -151,7 +143,7 @@ def get_tool_by_id( try: with get_db_context(db) as db: tool = db.get(Tool, id) - return ToolModel.model_validate(tool) + return self._to_tool_model(tool, db=db) if tool else None except Exception: return None @@ -170,7 +162,7 @@ def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]: tools.append( ToolUserModel.model_validate( { - **ToolModel.model_validate(tool).model_dump(), + **self._to_tool_model(tool, db=db).model_dump(), "user": user.model_dump() if user else None, } ) @@ -189,7 +181,14 @@ def get_tools_by_user_id( tool for tool in tools if tool.user_id == user_id - or has_access(user_id, permission, tool.access_control, user_group_ids) + or AccessGrants.has_access( + user_id=user_id, + resource_type="tool", + resource_id=tool.id, + permission=permission, + user_group_ids=user_group_ids, + db=db, + ) ] def get_tool_valves_by_id( @@ -266,20 +265,24 @@ def update_tool_by_id( ) -> Optional[ToolModel]: try: with get_db_context(db) as db: + access_grants = updated.pop("access_grants", None) db.query(Tool).filter_by(id=id).update( {**updated, 
"updated_at": int(time.time())} ) db.commit() + if access_grants is not None: + AccessGrants.set_access_grants("tool", id, access_grants, db=db) tool = db.query(Tool).get(id) db.refresh(tool) - return ToolModel.model_validate(tool) + return self._to_tool_model(tool, db=db) except Exception: return None def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: with get_db_context(db) as db: + AccessGrants.revoke_all_access("tool", id, db=db) db.query(Tool).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 294c11aa74..cb784850f8 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import Session from open_webui.internal.db import Base, JSONField, get_db, get_db_context - from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.models.chats import Chats @@ -12,9 +11,9 @@ from open_webui.models.channels import ChannelMember from open_webui.utils.misc import throttle +from open_webui.utils.validate import validate_profile_image_url - -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, field_validator from sqlalchemy import ( BigInteger, JSON, @@ -154,6 +153,11 @@ class UpdateProfileForm(BaseModel): gender: Optional[str] = None date_of_birth: Optional[datetime.date] = None + @field_validator("profile_image_url") + @classmethod + def check_profile_image_url(cls, v: str) -> str: + return validate_profile_image_url(v) + class UserGroupIdsModel(UserModel): group_ids: list[str] = [] @@ -184,6 +188,9 @@ class UserInfoResponse(UserStatus): name: str email: str role: str + bio: Optional[str] = None + groups: Optional[list] = [] + is_active: bool = False class UserIdNameResponse(BaseModel): @@ -235,6 +242,11 @@ class UserUpdateForm(BaseModel): password: Optional[str] = None credit: Optional[float] = None + @field_validator("profile_image_url") 
+ @classmethod + def check_profile_image_url(cls, v: str) -> str: + return validate_profile_image_url(v) + class UserCreditUpdateForm(BaseModel): amount: Optional[float] = None @@ -249,6 +261,7 @@ def insert_new_user( email: str, profile_image_url: str = "/user.png", role: str = "pending", + username: Optional[str] = None, oauth: Optional[dict] = None, db: Optional[Session] = None, ) -> Optional[UserModel]: @@ -263,6 +276,7 @@ def insert_new_user( "last_active_at": int(time.time()), "created_at": int(time.time()), "updated_at": int(time.time()), + "username": username, "oauth": oauth, } ) @@ -536,9 +550,12 @@ def update_user_role_by_id( ) -> Optional[UserModel]: try: with get_db_context(db) as db: - db.query(User).filter_by(id=id).update({"role": role}) - db.commit() user = db.query(User).filter_by(id=id).first() + if not user: + return None + user.role = role + db.commit() + db.refresh(user) return UserModel.model_validate(user) except Exception: return None @@ -548,12 +565,13 @@ def update_user_status_by_id( ) -> Optional[UserModel]: try: with get_db_context(db) as db: - db.query(User).filter_by(id=id).update( - {**form_data.model_dump(exclude_none=True)} - ) - db.commit() - user = db.query(User).filter_by(id=id).first() + if not user: + return None + for key, value in form_data.model_dump(exclude_none=True).items(): + setattr(user, key, value) + db.commit() + db.refresh(user) return UserModel.model_validate(user) except Exception: return None @@ -563,12 +581,12 @@ def update_user_profile_image_url_by_id( ) -> Optional[UserModel]: try: with get_db_context(db) as db: - db.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - db.commit() - user = db.query(User).filter_by(id=id).first() + if not user: + return None + user.profile_image_url = profile_image_url + db.commit() + db.refresh(user) return UserModel.model_validate(user) except Exception: return None @@ -579,12 +597,12 @@ def update_last_active_by_id( ) -> Optional[UserModel]: 
try: with get_db_context(db) as db: - db.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) - db.commit() - user = db.query(User).filter_by(id=id).first() + if not user: + return None + user.last_active_at = int(time.time()) + db.commit() + db.refresh(user) return UserModel.model_validate(user) except Exception: return None @@ -626,12 +644,14 @@ def update_user_by_id( ) -> Optional[UserModel]: try: with get_db_context(db) as db: - db.query(User).filter_by(id=id).update(updated) - db.commit() - user = db.query(User).filter_by(id=id).first() + if not user: + return None + for key, value in updated.items(): + setattr(user, key, value) + db.commit() + db.refresh(user) return UserModel.model_validate(user) - # return UserModel(**user.dict()) except Exception as e: print(e) return None diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index 2b83e44283..83adb8823f 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -143,7 +143,7 @@ def load(self) -> list[Document]: with open(self.file_path, "rb") as f: headers = {} if self.api_key: - headers["X-Api-Key"] = f"Bearer {self.api_key}" + headers["X-Api-Key"] = f"{self.api_key}" r = requests.post( f"{self.url}/v1/convert/file", @@ -361,7 +361,9 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): else: if file_ext == "pdf": loader = PyPDFLoader( - file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES") + file_path, + extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), + mode=self.kwargs.get("PDF_LOADER_MODE", "page"), ) elif file_ext == "csv": loader = CSVLoader(file_path, autodetect_encoding=True) diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index 095143d20d..cd24dc6af2 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py 
@@ -8,7 +8,6 @@ from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.utils.headers import include_user_info_headers - log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 61b1a947a6..40988070bc 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -28,9 +28,9 @@ from open_webui.models.chats import Chats from open_webui.models.notes import Notes +from open_webui.models.access_grants import AccessGrants from open_webui.retrieval.vector.main import GetResult -from open_webui.utils.access_control import has_access from open_webui.utils.headers import include_user_info_headers from open_webui.utils.misc import get_message_list @@ -565,7 +565,10 @@ async def agenerate_openai_batch_embeddings( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: async with session.post( - f"{url}/embeddings", headers=headers, json=form_data + f"{url}/embeddings", + headers=headers, + json=form_data, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: r.raise_for_status() data = await r.json() @@ -632,7 +635,12 @@ async def agenerate_azure_openai_batch_embeddings( async with aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) as session: - async with session.post(full_url, headers=headers, json=form_data) as r: + async with session.post( + full_url, + headers=headers, + json=form_data, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: r.raise_for_status() data = await r.json() @@ -932,7 +940,12 @@ async def get_sources_from_items( if note and ( user.role == "admin" or note.user_id == user.id - or has_access(user.id, "read", note.access_control) + or AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): # User has access to the note query_result = { @@ -1024,7 +1037,12 @@ async def get_sources_from_items( if 
knowledge_base and ( user.role == "admin" or knowledge_base.user_id == user.id - or has_access(user.id, "read", knowledge_base.access_control) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + ) ): if ( item.get("context") == "full" @@ -1033,7 +1051,12 @@ async def get_sources_from_items( if knowledge_base and ( user.role == "admin" or knowledge_base.user_id == user.id - or has_access(user.id, "read", knowledge_base.access_control) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + ) ): files = Knowledges.get_files_by_id(knowledge_base.id) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index dc9c35805e..ed5a931c68 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -211,7 +211,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): for item in batch ] bulk(self.client, actions) - self.client.indices.refresh(self._get_index_name(collection_name)) + self.client.indices.refresh(index=self._get_index_name(collection_name)) def upsert(self, collection_name: str, items: list[VectorItem]): self._create_index_if_not_exists( @@ -234,7 +234,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): for item in batch ] bulk(self.client, actions) - self.client.indices.refresh(self._get_index_name(collection_name)) + self.client.indices.refresh(index=self._get_index_name(collection_name)) def delete( self, @@ -263,7 +263,7 @@ def delete( self.client.delete_by_query( index=self._get_index_name(collection_name), body=query_body ) - self.client.indices.refresh(self._get_index_name(collection_name)) + self.client.indices.refresh(index=self._get_index_name(collection_name)) def reset(self): indices = 
self.client.indices.get(index=f"{self.index_prefix}_*") diff --git a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py index 9f16f82bc9..f4258c9eff 100644 --- a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py +++ b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py @@ -256,8 +256,7 @@ def _initialize_database(self, connection) -> None: with connection.cursor() as cursor: try: log.info("Creating Table document_chunk") - cursor.execute( - """ + cursor.execute(""" BEGIN EXECUTE IMMEDIATE ' CREATE TABLE IF NOT EXISTS document_chunk ( @@ -274,12 +273,10 @@ def _initialize_database(self, connection) -> None: RAISE; END IF; END; - """ - ) + """) log.info("Creating Index document_chunk_collection_name_idx") - cursor.execute( - """ + cursor.execute(""" BEGIN EXECUTE IMMEDIATE ' CREATE INDEX IF NOT EXISTS document_chunk_collection_name_idx @@ -291,12 +288,10 @@ def _initialize_database(self, connection) -> None: RAISE; END IF; END; - """ - ) + """) log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx") - cursor.execute( - """ + cursor.execute(""" BEGIN EXECUTE IMMEDIATE ' CREATE VECTOR INDEX IF NOT EXISTS document_chunk_vector_ivf_idx @@ -312,8 +307,7 @@ def _initialize_database(self, connection) -> None: RAISE; END IF; END; - """ - ) + """) connection.commit() log.info("Database initialization completed successfully.") diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 15430db114..481f9d92fc 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -51,7 +51,6 @@ PGVECTOR_USE_HALFVEC, ) - VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH USE_HALFVEC = PGVECTOR_USE_HALFVEC @@ -121,34 +120,26 @@ def __init__(self) -> None: # Ensure the pgvector extension is available # Use a conditional check to avoid permission issues on Azure PostgreSQL if 
PGVECTOR_CREATE_EXTENSION: - self.session.execute( - text( - """ + self.session.execute(text(""" DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN CREATE EXTENSION IF NOT EXISTS vector; END IF; END $$; - """ - ) - ) + """)) if PGVECTOR_PGCRYPTO: # Ensure the pgcrypto extension is available for encryption # Use a conditional check to avoid permission issues on Azure PostgreSQL - self.session.execute( - text( - """ + self.session.execute(text(""" DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN CREATE EXTENSION IF NOT EXISTS pgcrypto; END IF; END $$; - """ - ) - ) + """)) if not PGVECTOR_PGCRYPTO_KEY: raise ValueError( @@ -216,15 +207,13 @@ def _vector_index_configuration(self) -> Tuple[str, str]: def _ensure_vector_index(self, index_method: str, index_options: str) -> None: index_name = "idx_document_chunk_vector" existing_index_def = self.session.execute( - text( - """ + text(""" SELECT indexdef FROM pg_indexes WHERE schemaname = current_schema() AND tablename = 'document_chunk' AND indexname = :index_name - """ - ), + """), {"index_name": index_name}, ).scalar() @@ -310,8 +299,7 @@ def insert(self, collection_name: str, items: List[VectorItem]) -> None: # Ensure metadata is converted to its JSON text representation json_metadata = json.dumps(item["metadata"]) self.session.execute( - text( - """ + text(""" INSERT INTO document_chunk (id, vector, collection_name, text, vmetadata) VALUES ( @@ -320,8 +308,7 @@ def insert(self, collection_name: str, items: List[VectorItem]) -> None: pgp_sym_encrypt(:metadata_text, :key) ) ON CONFLICT (id) DO NOTHING - """ - ), + """), { "id": item["id"], "vector": vector, @@ -363,8 +350,7 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: vector = self.adjust_vector_length(item["vector"]) json_metadata = json.dumps(item["metadata"]) self.session.execute( - text( - """ + text(""" INSERT INTO document_chunk (id, vector, collection_name, 
text, vmetadata) VALUES ( @@ -377,8 +363,7 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: collection_name = EXCLUDED.collection_name, text = EXCLUDED.text, vmetadata = EXCLUDED.vmetadata - """ - ), + """), { "id": item["id"], "vector": vector, diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index fc3c98f8cf..156894bc9e 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -33,7 +33,6 @@ ) from open_webui.retrieval.vector.utils import process_metadata - NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system BATCH_SIZE = 100 # Recommended batch size for Pinecone operations diff --git a/backend/open_webui/retrieval/vector/dbs/weaviate.py b/backend/open_webui/retrieval/vector/dbs/weaviate.py index d204e8293a..dcc648c788 100644 --- a/backend/open_webui/retrieval/vector/dbs/weaviate.py +++ b/backend/open_webui/retrieval/vector/dbs/weaviate.py @@ -12,9 +12,13 @@ from open_webui.retrieval.vector.utils import process_metadata from open_webui.config import ( WEAVIATE_HTTP_HOST, + WEAVIATE_GRPC_HOST, WEAVIATE_HTTP_PORT, WEAVIATE_GRPC_PORT, WEAVIATE_API_KEY, + WEAVIATE_HTTP_SECURE, + WEAVIATE_GRPC_SECURE, + WEAVIATE_SKIP_INIT_CHECKS, ) @@ -52,9 +56,13 @@ def __init__(self): try: # Build connection parameters connection_params = { - "host": WEAVIATE_HTTP_HOST, - "port": WEAVIATE_HTTP_PORT, + "http_host": WEAVIATE_HTTP_HOST, + "http_port": WEAVIATE_HTTP_PORT, + "http_secure": WEAVIATE_HTTP_SECURE, + "grpc_host": WEAVIATE_GRPC_HOST, "grpc_port": WEAVIATE_GRPC_PORT, + "grpc_secure": WEAVIATE_GRPC_SECURE, + "skip_init_checks": WEAVIATE_SKIP_INIT_CHECKS, } # Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty @@ -63,7 +71,7 @@ def __init__(self): weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY) ) - self.client = weaviate.connect_to_local(**connection_params) + self.client = 
weaviate.connect_to_custom(**connection_params) self.client.connect() except Exception as e: raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e diff --git a/backend/open_webui/retrieval/web/external.py b/backend/open_webui/retrieval/web/external.py index 527c918a47..e8cf72e9f0 100644 --- a/backend/open_webui/retrieval/web/external.py +++ b/backend/open_webui/retrieval/web/external.py @@ -8,7 +8,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.utils.headers import include_user_info_headers - +from open_webui.env import FORWARD_SESSION_INFO_HEADER_CHAT_ID log = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def search_external( chat_id = getattr(request.state, "chat_id", None) if chat_id: - headers["X-OpenWebUI-Chat-Id"] = str(chat_id) + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id) response = requests.post( external_url, diff --git a/backend/open_webui/retrieval/web/firecrawl.py b/backend/open_webui/retrieval/web/firecrawl.py index 82635aa8ca..e6e96992a1 100644 --- a/backend/open_webui/retrieval/web/firecrawl.py +++ b/backend/open_webui/retrieval/web/firecrawl.py @@ -3,7 +3,6 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results - log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/perplexity_search.py b/backend/open_webui/retrieval/web/perplexity_search.py index 5c591ff64f..744a505c05 100644 --- a/backend/open_webui/retrieval/web/perplexity_search.py +++ b/backend/open_webui/retrieval/web/perplexity_search.py @@ -5,7 +5,6 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.utils.headers import include_user_info_headers - log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 6c1ea4b1bf..45787cb4bd 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -174,7 
+174,7 @@ async def _safe_process_url(self, url: str) -> bool: def _safe_process_url_sync(self, url: str) -> bool: """Synchronous version of safety checks.""" - if self.verify_ssl and not self._verify_ssl_cert(url): + if self.verify_ssl and not verify_ssl_cert(url): raise ValueError(f"SSL certificate verification failed for {url}") self._sync_wait_for_rate_limit() return True diff --git a/backend/open_webui/retrieval/web/yandex.py b/backend/open_webui/retrieval/web/yandex.py new file mode 100644 index 0000000000..fba4ee482e --- /dev/null +++ b/backend/open_webui/retrieval/web/yandex.py @@ -0,0 +1,164 @@ +import base64 +import io +import json +import logging +import os +from typing import Optional, List + +import requests + +from fastapi import Request + +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.utils.headers import include_user_info_headers +from open_webui.env import FORWARD_SESSION_INFO_HEADER_CHAT_ID + +from xml.etree import ElementTree as ET +from xml.etree.ElementTree import Element + +log = logging.getLogger(__name__) + + +def xml_element_contents_to_string(element: Element) -> str: + buffer = [element.text if element.text else ""] + + for child in element: + buffer.append(xml_element_contents_to_string(child)) + + buffer.append(element.tail if element.tail else "") + + return "".join(buffer) + + +def search_yandex( + request: Request, + yandex_search_url: str, + yandex_search_api_key: str, + yandex_search_config: str, + query: str, + count: int, + filter_list: Optional[List[str]] = None, + user=None, +) -> List[SearchResult]: + try: + headers = { + "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot", + "Authorization": f"Api-Key {yandex_search_api_key}", + } + + if user is not None: + headers = include_user_info_headers(headers, user) + + chat_id = getattr(request.state, "chat_id", None) + if chat_id: + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id) + + payload = {} 
if yandex_search_config == "" else json.loads(yandex_search_config) + + if type(payload.get("query", None)) != dict: + payload["query"] = {} + + if "searchType" not in payload["query"]: + payload["query"]["searchType"] = "SEARCH_TYPE_RU" + + payload["query"]["queryText"] = query + + if type(payload.get("groupSpec", None)) != dict: + payload["groupSpec"] = {} + + if "groupMode" not in payload["groupSpec"]: + payload["groupSpec"]["groupMode"] = "GROUP_MODE_DEEP" + + payload["groupSpec"]["groupsOnPage"] = count + payload["groupSpec"]["docsInGroup"] = 1 + + response = requests.post( + ( + "https://searchapi.api.cloud.yandex.net/v2/web/search" + if yandex_search_url == "" + else yandex_search_url + ), + headers=headers, + json=payload, + ) + + response.raise_for_status() + + response_body = response.json() + if "rawData" not in response_body: + raise Exception(f"No `rawData` in response body: {response_body}") + + search_result_body_bytes = base64.decodebytes( + bytes(response_body["rawData"], "utf-8") + ) + + doc_root = ET.parse(io.BytesIO(search_result_body_bytes)) + + results = [] + + for group in doc_root.findall("response/results/grouping/group"): + results.append( + { + "url": xml_element_contents_to_string(group.find("doc/url")).strip( + "\n" + ), + "title": xml_element_contents_to_string( + group.find("doc/title") + ).strip("\n"), + "snippet": xml_element_contents_to_string( + group.find("doc/passages/passage") + ), + } + ) + + results = get_filtered_results(results, filter_list) + + results = [ + SearchResult( + link=result.get("url"), + title=result.get("title"), + snippet=result.get("snippet"), + ) + for result in results[:count] + ] + + log.info(f"Yandex search results: {results}") + + return results + except Exception as e: + log.error(f"Error in search: {e}") + + return [] + + +if __name__ == "__main__": + from starlette.datastructures import Headers + from fastapi import FastAPI + + result = search_yandex( + Request( + { + "type": "http", + 
"asgi.version": "3.0", + "asgi.spec_version": "2.0", + "method": "GET", + "path": "/internal", + "query_string": b"", + "headers": Headers({}).raw, + "client": ("127.0.0.1", 12345), + "server": ("127.0.0.1", 80), + "scheme": "http", + "app": FastAPI(), + }, + None, + ), + os.environ.get("YANDEX_WEB_SEARCH_URL", ""), + os.environ.get("YANDEX_WEB_SEARCH_API_KEY", ""), + os.environ.get( + "YANDEX_WEB_SEARCH_CONFIG", '{"query": {"searchType": "SEARCH_TYPE_COM"}}' + ), + "TOP movies of the past year", + 3, + ) + + print(result) diff --git a/backend/open_webui/routers/analytics.py b/backend/open_webui/routers/analytics.py new file mode 100644 index 0000000000..61aec66332 --- /dev/null +++ b/backend/open_webui/routers/analytics.py @@ -0,0 +1,454 @@ +from typing import Optional +from datetime import datetime, timedelta +from collections import defaultdict +import logging +from fastapi import APIRouter, Depends, Query +from pydantic import BaseModel + +from open_webui.models.chat_messages import ChatMessages, ChatMessageModel +from open_webui.models.chats import Chats +from open_webui.models.groups import Groups +from open_webui.models.users import Users +from open_webui.models.feedbacks import Feedbacks +from open_webui.utils.auth import get_admin_user +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session + +log = logging.getLogger(__name__) + + +router = APIRouter() + + +#################### +# Response Models +#################### + + +class ModelAnalyticsEntry(BaseModel): + model_id: str + count: int + + +class ModelAnalyticsResponse(BaseModel): + models: list[ModelAnalyticsEntry] + + +class UserAnalyticsEntry(BaseModel): + user_id: str + name: Optional[str] = None + email: Optional[str] = None + count: int + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + + +class UserAnalyticsResponse(BaseModel): + users: list[UserAnalyticsEntry] + + +#################### +# Endpoints +#################### + + 
+@router.get("/models", response_model=ModelAnalyticsResponse) +async def get_model_analytics( + start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), + end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), + group_id: Optional[str] = Query(None, description="Filter by user group ID"), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get message counts per model.""" + counts = ChatMessages.get_message_count_by_model( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + models = [ + ModelAnalyticsEntry(model_id=model_id, count=count) + for model_id, count in sorted(counts.items(), key=lambda x: -x[1]) + ] + return ModelAnalyticsResponse(models=models) + + +@router.get("/users", response_model=UserAnalyticsResponse) +async def get_user_analytics( + start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), + end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), + group_id: Optional[str] = Query(None, description="Filter by user group ID"), + limit: int = Query(50, description="Max users to return"), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get message counts and token usage per user with user info.""" + counts = ChatMessages.get_message_count_by_user( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + token_usage = ChatMessages.get_token_usage_by_user( + start_date=start_date, end_date=end_date, db=db + ) + + # Get user info for top users + top_user_ids = [ + uid for uid, _ in sorted(counts.items(), key=lambda x: -x[1])[:limit] + ] + user_info = {u.id: u for u in Users.get_users_by_user_ids(top_user_ids, db=db)} + + users = [] + for user_id in top_user_ids: + u = user_info.get(user_id) + tokens = token_usage.get(user_id, {}) + users.append( + UserAnalyticsEntry( + user_id=user_id, + name=u.name if u else None, + email=u.email if u else None, + count=counts[user_id], 
+ input_tokens=tokens.get("input_tokens", 0), + output_tokens=tokens.get("output_tokens", 0), + total_tokens=tokens.get("total_tokens", 0), + ) + ) + + return UserAnalyticsResponse(users=users) + + +@router.get("/messages", response_model=list[ChatMessageModel]) +async def get_messages( + model_id: Optional[str] = Query(None, description="Filter by model ID"), + user_id: Optional[str] = Query(None, description="Filter by user ID"), + chat_id: Optional[str] = Query(None, description="Filter by chat ID"), + start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), + end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), + skip: int = Query(0), + limit: int = Query(50, le=100), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Query messages with filters.""" + if chat_id: + return ChatMessages.get_messages_by_chat_id(chat_id=chat_id, db=db) + elif model_id: + return ChatMessages.get_messages_by_model_id( + model_id=model_id, + start_date=start_date, + end_date=end_date, + skip=skip, + limit=limit, + db=db, + ) + elif user_id: + return ChatMessages.get_messages_by_user_id( + user_id=user_id, skip=skip, limit=limit, db=db + ) + else: + # Return empty if no filter specified + return [] + + +class SummaryResponse(BaseModel): + total_messages: int + total_chats: int + total_models: int + total_users: int + + +@router.get("/summary", response_model=SummaryResponse) +async def get_summary( + start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), + end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), + group_id: Optional[str] = Query(None, description="Filter by user group ID"), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get summary statistics for the dashboard.""" + model_counts = ChatMessages.get_message_count_by_model( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + user_counts = 
ChatMessages.get_message_count_by_user( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + chat_counts = ChatMessages.get_message_count_by_chat( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + + return SummaryResponse( + total_messages=sum(model_counts.values()), + total_chats=len(chat_counts), + total_models=len(model_counts), + total_users=len(user_counts), + ) + + +class DailyStatsEntry(BaseModel): + date: str + models: dict[str, int] + + +class DailyStatsResponse(BaseModel): + data: list[DailyStatsEntry] + + +@router.get("/daily", response_model=DailyStatsResponse) +async def get_daily_stats( + start_date: Optional[int] = Query(None, description="Start timestamp (epoch)"), + end_date: Optional[int] = Query(None, description="End timestamp (epoch)"), + group_id: Optional[str] = Query(None, description="Filter by user group ID"), + granularity: str = Query("daily", description="Granularity: 'hourly' or 'daily'"), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get message counts grouped by model for time-series chart.""" + if granularity == "hourly": + counts = ChatMessages.get_hourly_message_counts_by_model( + start_date=start_date, end_date=end_date, db=db + ) + else: + counts = ChatMessages.get_daily_message_counts_by_model( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + return DailyStatsResponse( + data=[ + DailyStatsEntry(date=date, models=models) + for date, models in sorted(counts.items()) + ] + ) + + +class TokenUsageEntry(BaseModel): + model_id: str + input_tokens: int + output_tokens: int + total_tokens: int + message_count: int + + +class TokenUsageResponse(BaseModel): + models: list[TokenUsageEntry] + total_input_tokens: int + total_output_tokens: int + total_tokens: int + + +@router.get("/tokens", response_model=TokenUsageResponse) +async def get_token_usage( + start_date: Optional[int] = Query(None), + end_date: Optional[int] = Query(None), 
+ group_id: Optional[str] = Query(None, description="Filter by user group ID"), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get token usage aggregated by model.""" + usage = ChatMessages.get_token_usage_by_model( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) + + models = [ + TokenUsageEntry(model_id=model_id, **data) + for model_id, data in sorted(usage.items(), key=lambda x: -x[1]["total_tokens"]) + ] + + total_input = sum(m.input_tokens for m in models) + total_output = sum(m.output_tokens for m in models) + + return TokenUsageResponse( + models=models, + total_input_tokens=total_input, + total_output_tokens=total_output, + total_tokens=total_input + total_output, + ) + + +#################### +# Model Chats Browser +#################### + + +class ModelChatEntry(BaseModel): + chat_id: str + user_id: Optional[str] = None + user_name: Optional[str] = None + first_message: Optional[str] = None + updated_at: int + + +class ModelChatsResponse(BaseModel): + chats: list[ModelChatEntry] + total: int + + +@router.get("/models/{model_id}/chats", response_model=ModelChatsResponse) +async def get_model_chats( + model_id: str, + start_date: Optional[int] = Query(None), + end_date: Optional[int] = Query(None), + skip: int = Query(0), + limit: int = Query(50, le=100), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get chats that used a specific model, with preview and feedback info.""" + + # Get chat IDs that used this model + chat_ids = ChatMessages.get_chat_ids_by_model_id( + model_id=model_id, + start_date=start_date, + end_date=end_date, + skip=skip, + limit=limit, + db=db, + ) + + if not chat_ids: + return ModelChatsResponse(chats=[], total=0) + + # Get chat details from messages only + chats_data = [] + for chat_id in chat_ids: + messages = ChatMessages.get_messages_by_chat_id(chat_id, db=db) + if not messages: + continue + + # Get user_id from first user message + 
first_user_msg = next((m for m in messages if m.role == "user"), None) + user_id = first_user_msg.user_id if first_user_msg else None + + # Extract first message content as preview + first_message = None + if first_user_msg and first_user_msg.content: + content = first_user_msg.content + if isinstance(content, str): + first_message = content[:200] + elif isinstance(content, list): + text_parts = [b.get("text", "") for b in content if isinstance(b, dict)] + first_message = " ".join(text_parts)[:200] + + # Get user info + user_name = None + if user_id: + user_info = Users.get_user_by_id(user_id, db=db) + user_name = user_info.name if user_info else None + + # Timestamps from messages + updated_at = max(m.created_at for m in messages) if messages else 0 + + chats_data.append( + ModelChatEntry( + chat_id=chat_id, + user_id=user_id, + user_name=user_name, + first_message=first_message, + updated_at=updated_at, + ) + ) + + return ModelChatsResponse(chats=chats_data, total=len(chats_data)) + + +#################### +# Model Overview +#################### + + +class HistoryEntry(BaseModel): + date: str + won: int = 0 + lost: int = 0 + + +class TagEntry(BaseModel): + tag: str + count: int + + +class ModelOverviewResponse(BaseModel): + history: list[HistoryEntry] + tags: list[TagEntry] + + +@router.get("/models/{model_id}/overview", response_model=ModelOverviewResponse) +async def get_model_overview( + model_id: str, + days: int = Query(30, description="Number of days of history (0 for all)"), + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get model overview with feedback history and chat tags.""" + + # Get chat IDs that used this model + chat_ids = ChatMessages.get_chat_ids_by_model_id( + model_id=model_id, + start_date=None, + end_date=None, + skip=0, + limit=10000, # Get all chats + db=db, + ) + + # Get feedback history per day + history_counts: dict[str, dict] = defaultdict(lambda: {"won": 0, "lost": 0}) + + # Calculate start date for 
history + now = datetime.now() + start_dt = None + if days > 0: + start_dt = now - timedelta(days=days) + + for chat_id in chat_ids: + feedbacks = Feedbacks.get_feedbacks_by_chat_id(chat_id, db=db) + for fb in feedbacks: + if fb.data and "rating" in fb.data: + rating = fb.data["rating"] + fb_date = datetime.fromtimestamp(fb.created_at) + + # Filter by date range + if start_dt and fb_date < start_dt: + continue + + date_str = fb_date.strftime("%Y-%m-%d") + if rating == 1: + history_counts[date_str]["won"] += 1 + elif rating == -1: + history_counts[date_str]["lost"] += 1 + + # Fill in missing days + history = [] + if history_counts or days > 0: + end_dt = now + if days > 0: + current = start_dt + elif history_counts: + # Find earliest date + min_date = min(history_counts.keys()) + current = datetime.strptime(min_date, "%Y-%m-%d") + else: + current = now + + while current <= end_dt: + date_str = current.strftime("%Y-%m-%d") + counts = history_counts.get(date_str, {"won": 0, "lost": 0}) + history.append( + HistoryEntry( + date=date_str, + won=counts["won"], + lost=counts["lost"], + ) + ) + current += timedelta(days=1) + + # Get chat tags + tag_counts: dict[str, int] = defaultdict(int) + for chat_id in chat_ids: + chat = Chats.get_chat_by_id(chat_id, db=db) + if chat and chat.meta: + for tag in chat.meta.get("tags", []): + tag_counts[tag] += 1 + + # Sort by count and take top 10 + tags = [ + TagEntry(tag=tag, count=count) + for tag, count in sorted(tag_counts.items(), key=lambda x: -x[1])[:10] + ] + + return ModelOverviewResponse(history=history, tags=tags) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 52e0182cad..139b64f7cf 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -53,11 +53,11 @@ ENV, AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, DEVICE_TYPE, ENABLE_FORWARD_USER_INFO_HEADERS, ) - router = APIRouter() # Constants @@ -644,6 
+644,7 @@ def transcription_handler(request, file_path, metadata, user=None): headers=headers, files={"file": (filename, open(file_path, "rb"))}, data=payload, + timeout=AIOHTTP_CLIENT_TIMEOUT, ) if r.status_code == 200: @@ -704,6 +705,7 @@ def transcription_handler(request, file_path, metadata, user=None): headers=headers, params=params, data=file_data, + timeout=AIOHTTP_CLIENT_TIMEOUT, ) if r.status_code == 200: @@ -815,6 +817,7 @@ def transcription_handler(request, file_path, metadata, user=None): headers={ "Ocp-Apim-Subscription-Key": api_key, }, + timeout=AIOHTTP_CLIENT_TIMEOUT, ) r.raise_for_status() @@ -954,6 +957,7 @@ def transcription_handler(request, file_path, metadata, user=None): "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, + timeout=AIOHTTP_CLIENT_TIMEOUT, ) r.raise_for_status() @@ -997,6 +1001,7 @@ def transcription_handler(request, file_path, metadata, user=None): headers={ "Authorization": f"Bearer {api_key}", }, + timeout=AIOHTTP_CLIENT_TIMEOUT, ) r.raise_for_status() @@ -1240,7 +1245,8 @@ def get_available_models(request: Request) -> list[dict]: ): try: response = requests.get( - f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models" + f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models", + timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ) response.raise_for_status() data = response.json() @@ -1286,7 +1292,8 @@ def get_available_voices(request) -> dict: ): try: response = requests.get( - f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices" + f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices", + timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ) response.raise_for_status() data = response.json() @@ -1330,7 +1337,9 @@ def get_available_voices(request) -> dict: "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY } - response = requests.get(url, headers=headers) + response = requests.get( + url, headers=headers, timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST + ) 
response.raise_for_status() voices = response.json() @@ -1362,6 +1371,7 @@ def get_elevenlabs_voices(api_key: str) -> dict: "xi-api-key": api_key, "Content-Type": "application/json", }, + timeout=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ) response.raise_for_status() voices_data = response.json() diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 29ed41a74d..8851a829f8 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -1,3 +1,4 @@ +import asyncio import re import uuid import time @@ -8,7 +9,6 @@ from aiohttp import ClientSession import urllib - from open_webui.models.auths import ( AddUserForm, ApiKey, @@ -21,6 +21,7 @@ UpdatePasswordForm, ) from open_webui.models.users import ( + UserModel, UserProfileImageResponse, Users, UserModel, @@ -41,6 +42,8 @@ WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_SIGNOUT_REDIRECT_URL, ENABLE_INITIAL_ADMIN_SIGNUP, + ENABLE_OAUTH_TOKEN_EXCHANGE, + AIOHTTP_CLIENT_SESSION_SSL, REDIS_URL, REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT, @@ -53,6 +56,8 @@ ENABLE_OAUTH_SIGNUP, ENABLE_LDAP, ENABLE_PASSWORD_AUTH, + OAUTH_PROVIDERS, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, ) from pydantic import BaseModel, Field @@ -85,7 +90,6 @@ ) from open_webui.utils.rate_limit import RateLimiter - from typing import Optional, List from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS @@ -101,6 +105,67 @@ redis_client=get_redis_client(), limit=5 * 3, window=60 * 3 ) + +def create_session_response( + request: Request, user, db, response: Response = None, set_cookie: bool = False +) -> dict: + """ + Create JWT token and build session response for a user. + Shared helper for signin, signup, ldap_auth, add_user, and token_exchange endpoints. 
+ + Args: + request: FastAPI request object + user: User object + db: Database session + response: FastAPI response object (required if set_cookie is True) + set_cookie: Whether to set the auth cookie on the response + """ + + credit = Credits.init_credit_by_user_id(user.id) + + expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) + expires_at = None + if expires_delta: + expires_at = int(time.time()) + int(expires_delta.total_seconds()) + + token = create_token( + data={"id": user.id}, + expires_delta=expires_delta, + ) + + if set_cookie and response: + datetime_expires_at = ( + datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) + if expires_at + else None + ) + response.set_cookie( + key="token", + value=token, + expires=datetime_expires_at, + httponly=True, + samesite=WEBUI_AUTH_COOKIE_SAME_SITE, + secure=WEBUI_AUTH_COOKIE_SECURE, + ) + + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS, db=db + ) + + return { + "token": token, + "token_type": "Bearer", + "expires_at": expires_at, + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + "profile_image_url": user.profile_image_url, + "permissions": user_permissions, + "credit": credit.credit, + } + + ############################ # GetSessionUser ############################ @@ -332,7 +397,7 @@ async def ldap_auth( auto_bind="NONE", authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS", ) - if not connection_app.bind(): + if not await asyncio.to_thread(connection_app.bind): raise HTTPException(400, detail="Application account bind failed") ENABLE_LDAP_GROUP_MANAGEMENT = ( @@ -353,7 +418,8 @@ async def ldap_auth( ) log.info(f"LDAP search attributes: {search_attributes}") - search_success = connection_app.search( + search_success = await asyncio.to_thread( + connection_app.search, search_base=LDAP_SEARCH_BASE, 
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", attributes=search_attributes, @@ -457,7 +523,7 @@ async def ldap_auth( auto_bind="NONE", authentication="SIMPLE", ) - if not connection_user.bind(): + if not await asyncio.to_thread(connection_user.bind): raise HTTPException(400, "Authentication failed.") user = Users.get_user_by_email(email, db=db) @@ -499,38 +565,6 @@ async def ldap_auth( user = Auths.authenticate_user_by_email(email, db=db) if user: - expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) - expires_at = None - if expires_delta: - expires_at = int(time.time()) + int(expires_delta.total_seconds()) - - token = create_token( - data={"id": user.id}, - expires_delta=expires_delta, - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=token, - expires=( - datetime.datetime.fromtimestamp( - expires_at, datetime.timezone.utc - ) - if expires_at - else None - ), - httponly=True, # Ensures the cookie is not accessible via JavaScript - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) - - user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db - ) - - credit = Credits.init_credit_by_user_id(user.id) - if ( user.role != "admin" and ENABLE_LDAP_GROUP_MANAGEMENT @@ -546,18 +580,9 @@ async def ldap_auth( except Exception as e: log.error(f"Failed to sync groups for user {user.id}: {e}") - return { - "token": token, - "token_type": "Bearer", - "expires_at": expires_at, - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": user.profile_image_url, - "permissions": user_permissions, - "credit": credit.credit, - } + return create_session_response( + request, user, db, response, set_cookie=True + ) else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: @@ -600,10 +625,11 @@ async def signin( pass if not 
Users.get_user_by_email(email.lower(), db=db): - await signup( + await signup_handler( request, - response, - SignupForm(email=email, password=str(uuid.uuid4()), name=name), + email, + str(uuid.uuid4()), + name, db=db, ) @@ -631,10 +657,11 @@ async def signin( if Users.has_users(db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) - await signup( + await signup_handler( request, - response, - SignupForm(email=admin_email, password=admin_password, name="User"), + admin_email, + admin_password, + "User", db=db, ) @@ -666,58 +693,76 @@ async def signin( ) if user: + return create_session_response(request, user, db, response, set_cookie=True) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) - expires_at = None - if expires_delta: - expires_at = int(time.time()) + int(expires_delta.total_seconds()) - - token = create_token( - data={"id": user.id}, - expires_delta=expires_delta, - ) - datetime_expires_at = ( - datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) - if expires_at - else None - ) +############################ +# SignUp +############################ - # Set the cookie token - response.set_cookie( - key="token", - value=token, - expires=datetime_expires_at, - httponly=True, # Ensures the cookie is not accessible via JavaScript - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) - user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db +async def signup_handler( + request: Request, + email: str, + password: str, + name: str, + profile_image_url: str = "/user.png", + *, + db: Session, +) -> UserModel: + """ + Core user-creation logic shared by the signup endpoint and + trusted-header / no-auth auto-registration flows. + + Returns the newly created UserModel. + Raises HTTPException on failure. 
+ """ + has_users = Users.has_users(db=db) + if not has_users: + role = "admin" + elif request.app.state.config.ENABLE_SIGNUP_VERIFY: + role = "pending" + send_verify_email(email=email.lower()) + else: + role = request.app.state.config.DEFAULT_USER_ROLE + hashed = get_password_hash(password) + + user = Auths.insert_new_auth( + email=email.lower(), + password=hashed, + name=name, + profile_image_url=profile_image_url, + role=role, + db=db, + ) + if not user: + raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) + + if request.app.state.config.WEBHOOK_URL: + await post_webhook( + request.app.state.WEBUI_NAME, + request.app.state.config.WEBHOOK_URL, + WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, ) - credit = Credits.init_credit_by_user_id(user.id) - - return { - "token": token, - "token_type": "Bearer", - "expires_at": expires_at, - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": user.profile_image_url, - "permissions": user_permissions, - "credit": credit.credit, - } - else: - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + if not has_users: + # Disable signup after the first user is created + request.app.state.config.ENABLE_SIGNUP = False + apply_default_group_assignment( + request.app.state.config.DEFAULT_GROUP_ID, + user.id, + db=db, + ) -############################ -# SignUp -############################ + return user @router.post("/signup", response_model=SessionUserResponse) @@ -772,94 +817,17 @@ async def signup( except Exception as e: raise HTTPException(400, detail=str(e)) - hashed = get_password_hash(form_data.password) - - if not has_users: - role = "admin" - elif request.app.state.config.ENABLE_SIGNUP_VERIFY: - role = "pending" - send_verify_email(email=form_data.email.lower()) - else: - role = request.app.state.config.DEFAULT_USER_ROLE - - 
user = Auths.insert_new_auth( - form_data.email.lower(), - hashed, + user = await signup_handler( + request, + form_data.email, + form_data.password, form_data.name, form_data.profile_image_url, - role, db=db, ) - - if user: - expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) - expires_at = None - if expires_delta: - expires_at = int(time.time()) + int(expires_delta.total_seconds()) - - token = create_token( - data={"id": user.id}, - expires_delta=expires_delta, - ) - - datetime_expires_at = ( - datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc) - if expires_at - else None - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=token, - expires=datetime_expires_at, - httponly=True, # Ensures the cookie is not accessible via JavaScript - samesite=WEBUI_AUTH_COOKIE_SAME_SITE, - secure=WEBUI_AUTH_COOKIE_SECURE, - ) - - if request.app.state.config.WEBHOOK_URL: - await post_webhook( - request.app.state.WEBUI_NAME, - request.app.state.config.WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, - ) - - user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db - ) - - if not has_users: - # Disable signup after the first user is created - request.app.state.config.ENABLE_SIGNUP = False - - apply_default_group_assignment( - request.app.state.config.DEFAULT_GROUP_ID, - user.id, - db=db, - ) - - credit = Credits.init_credit_by_user_id(user.id) - - return { - "token": token, - "token_type": "Bearer", - "expires_at": expires_at, - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - "profile_image_url": user.profile_image_url, - "permissions": user_permissions, - "credit": credit.credit, - } - else: - raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) + return create_session_response(request, user, db, 
response, set_cookie=True) + except HTTPException: + raise except Exception as err: log.error(f"Signup error: {str(err)}") raise HTTPException(500, detail="An internal error occurred during signup.") @@ -883,7 +851,6 @@ async def signup_verify(request: Request, code: str): async def signout( request: Request, response: Response, db: Session = Depends(get_session) ): - # get auth token from headers or cookies token = None auth_header = request.headers.get("Authorization") @@ -1015,6 +982,8 @@ async def add_user( } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) + except HTTPException: + raise except Exception as err: log.error(f"Add user error: {str(err)}") raise HTTPException( @@ -1387,3 +1356,108 @@ async def get_api_key( } else: raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + +############################ +# Token Exchange +############################ + + +class TokenExchangeForm(BaseModel): + token: str # OAuth access token from external provider + + +@router.post("/oauth/{provider}/token/exchange", response_model=SessionUserResponse) +async def token_exchange( + request: Request, + response: Response, + provider: str, + form_data: TokenExchangeForm, + db: Session = Depends(get_session), +): + """ + Exchange an external OAuth provider token for an OpenWebUI JWT. + This endpoint is disabled by default. Set ENABLE_OAUTH_TOKEN_EXCHANGE=True to enable. 
+ """ + if not ENABLE_OAUTH_TOKEN_EXCHANGE: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Token exchange is disabled", + ) + + provider = provider.lower() + + # Check if provider is configured + if provider not in OAUTH_PROVIDERS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Provider '{provider}' is not configured", + ) + # Get the OAuth client for this provider + oauth_manager = request.app.state.oauth_manager + client = oauth_manager.get_client(provider) + if not client: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"OAuth client for '{provider}' not found", + ) + + # Validate the token by calling the userinfo endpoint + try: + token_data = {"access_token": form_data.token, "token_type": "Bearer"} + user_data = await client.userinfo(token=token_data) + + if not user_data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid token or unable to fetch user info", + ) + except Exception as e: + log.warning(f"Token exchange failed for provider {provider}: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid token or unable to validate with provider", + ) + + # Extract user information from the token claims + email_claim = request.app.state.config.OAUTH_EMAIL_CLAIM + username_claim = request.app.state.config.OAUTH_USERNAME_CLAIM + + # Get sub claim + sub = user_data.get( + request.app.state.config.OAUTH_SUB_CLAIM + or OAUTH_PROVIDERS[provider].get("sub_claim", "sub") + ) + if not sub: + log.warning(f"Token exchange failed: sub claim missing from user data") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Token missing required 'sub' claim", + ) + + email = user_data.get(email_claim, "") + if not email: + log.warning(f"Token exchange failed: email claim missing from user data") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Token missing required email claim", + 
) + email = email.lower() + + # Try to find the user by OAuth sub + user = Users.get_user_by_oauth_sub(provider, sub, db=db) + + if not user and OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: + # Try to find by email if merge is enabled + user = Users.get_user_by_email(email, db=db) + if user: + # Link the OAuth sub to this user + Users.update_user_oauth_by_id(user.id, provider, sub, db=db) + + if not user: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User not found. Please sign in via the web interface first.", + ) + + return create_session_response(request, user, db) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 4e697142bf..1748eaf7ea 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -22,6 +22,7 @@ UserListResponse, UserModelResponse, Users, + UserModel, UserNameResponse, ) @@ -35,6 +36,7 @@ ChannelWebhookModel, ChannelWebhookForm, ) +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.models.messages import ( Messages, MessageModel, @@ -59,12 +61,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import ( - has_access, - get_users_with_access, - get_permitted_group_and_user_ids, - has_permission, -) +from open_webui.utils.access_control import has_permission from open_webui.utils.webhook import post_webhook from open_webui.utils.channels import extract_mentions, replace_mentions from open_webui.internal.db import get_session @@ -75,12 +72,72 @@ router = APIRouter() +def channel_has_access( + user_id: str, + channel: ChannelModel, + permission: str = "read", + strict: bool = True, + db: Optional[Session] = None, +) -> bool: + if AccessGrants.has_access( + user_id=user_id, + resource_type="channel", + resource_id=channel.id, + permission=permission, + db=db, + ): + return True + + if ( + not strict + and permission == "write" + and 
has_public_read_access_grant(channel.access_grants) + ): + return True + + return False + + +def get_channel_users_with_access( + channel: ChannelModel, permission: str = "read", db: Optional[Session] = None +): + return AccessGrants.get_users_with_access( + resource_type="channel", + resource_id=channel.id, + permission=permission, + db=db, + ) + + +def get_channel_permitted_group_and_user_ids( + channel: ChannelModel, permission: str = "read" +) -> Optional[dict[str, list[str]]]: + if permission == "read" and has_public_read_access_grant(channel.access_grants): + return None + + user_ids = [] + group_ids = [] + + for grant in channel.access_grants: + if grant.permission != permission: + continue + if grant.principal_type == "group": + group_ids.append(grant.principal_id) + elif grant.principal_type == "user" and grant.principal_id != "*": + user_ids.append(grant.principal_id) + + return { + "user_ids": list(dict.fromkeys(user_ids)), + "group_ids": list(dict.fromkeys(group_ids)), + } + + ############################ # Channels Enabled Dependency ############################ -def check_channels_access(request: Request): +def check_channels_access(request: Request, user: Optional[UserModel] = None): """Dependency to ensure channels are globally enabled.""" if not request.app.state.config.ENABLE_CHANNELS: raise HTTPException( @@ -88,6 +145,15 @@ def check_channels_access(request: Request): detail="Channels are not enabled", ) + if user: + if user.role != "admin" and not has_permission( + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + ############################ # GetChatList @@ -108,14 +174,7 @@ async def get_channels( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", 
request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) + check_channels_access(request, user) channels = Channels.get_channels_by_user_id(user.id, db=db) channel_list = [] @@ -188,15 +247,7 @@ async def get_dm_channel_by_user_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) - + check_channels_access(request, user) try: existing_channel = Channels.get_dm_channel_by_user_ids( [user.id, user_id], db=db @@ -268,14 +319,7 @@ async def create_new_channel( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) + check_channels_access(request, user) if form_data.type not in ["group", "dm"] and user.role != "admin": # Only admins can create standard channels (joined by default) @@ -355,7 +399,7 @@ async def get_channel_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( @@ -408,22 +452,22 @@ async def get_channel_by_id( } ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( 
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - write_access = has_access( + write_access = channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ) - user_count = len(get_users_with_access("read", channel.access_control)) + user_count = len(get_channel_users_with_access(channel, "read", db=db)) channel_member = Channels.get_member_by_channel_and_user_id( channel.id, user.id, db=db @@ -467,7 +511,7 @@ async def get_channel_members_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: @@ -517,8 +561,8 @@ async def get_channel_members_by_id( filter["channel_id"] = channel.id else: filter["roles"] = ["!pending"] - permitted_ids = get_permitted_group_and_user_ids( - "read", channel.access_control + permitted_ids = get_channel_permitted_group_and_user_ids( + channel, permission="read" ) if permitted_ids: filter["user_ids"] = permitted_ids.get("user_ids") @@ -593,15 +637,7 @@ async def add_members_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) - + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( @@ -643,14 +679,7 @@ async def remove_members_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - 
status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: @@ -689,14 +718,7 @@ async def update_channel_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: @@ -731,14 +753,7 @@ async def delete_channel_by_id( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) - if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db - ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.UNAUTHORIZED, - ) + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: @@ -788,7 +803,7 @@ async def get_channel_messages( user=Depends(get_verified_user), db: Session = Depends(get_session), ): - check_channels_access(request) + check_channels_access(request, user) channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( @@ -801,8 +816,8 @@ async def get_channel_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -878,8 +893,8 @@ async def get_pinned_channel_messages( 
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -936,7 +951,7 @@ async def get_pinned_channel_messages( async def send_notification( name, webui_url, channel, message, active_user_ids, db=None ): - users = get_users_with_access("read", channel.access_control) + users = get_channel_users_with_access(channel, "read", db=db) for user in users: if (user.id not in active_user_ids) and Channels.is_user_channel_member( @@ -1055,7 +1070,7 @@ async def model_response_handler(request, channel, message, user, db=None): f"{username}: {replace_mentions(thread_message.content)}" ) - thread_message_files = thread_message.data.get("files", []) + thread_message_files = (thread_message.data or {}).get("files", []) for file in thread_message_files: if file.get("type", "") == "image": images.append(file.get("url", "")) @@ -1163,10 +1178,10 @@ async def new_message_handler( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1308,8 +1323,8 @@ async def get_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1362,8 +1377,8 @@ async def 
get_channel_message_data( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1416,8 +1431,8 @@ async def pin_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1482,8 +1497,8 @@ async def get_channel_thread_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control, db=db + if user.role != "admin" and not channel_has_access( + user.id, channel, permission="read", db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1567,9 +1582,7 @@ async def update_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access( - user.id, type="read", access_control=channel.access_control, db=db - ) + and not channel_has_access(user.id, channel, permission="read", db=db) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1634,10 +1647,10 @@ async def add_reaction_to_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + 
permission="write", strict=False, db=db, ): @@ -1713,10 +1726,10 @@ async def remove_reaction_by_id_and_user_id_and_name( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: - if user.role != "admin" and not has_access( + if user.role != "admin" and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ): @@ -1808,10 +1821,10 @@ async def delete_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access( + and not channel_has_access( user.id, - type="write", - access_control=channel.access_control, + channel, + permission="write", strict=False, db=db, ) @@ -1874,13 +1887,9 @@ async def delete_message_by_id( @router.get("/webhooks/{webhook_id}/profile/image") -async def get_webhook_profile_image( - webhook_id: str, - user=Depends(get_verified_user), - db: Session = Depends(get_session), -): +def get_webhook_profile_image(webhook_id: str, user=Depends(get_verified_user)): """Get webhook profile image by webhook ID.""" - webhook = Channels.get_webhook_by_id(webhook_id, db=db) + webhook = Channels.get_webhook_by_id(webhook_id) if not webhook: # Return default favicon if webhook not found return FileResponse(f"{STATIC_DIR}/favicon.png") diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 9a43234aa6..e03cdc7ba9 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -16,6 +16,7 @@ ChatResponse, Chats, ChatTitleIdResponse, + SharedChatResponse, ChatStatsExport, AggregateChatStats, ChatBody, @@ -357,9 +358,7 @@ def get_message_content_length(message): return None -def calculate_chat_stats( - user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None -): +def calculate_chat_stats(user_id, skip=0, limit=10, filter=None): if filter is None: filter = {} @@ -368,7 +367,6 @@ def calculate_chat_stats( skip=skip, limit=limit, filter=filter, - 
db=db, ) chat_stats_export_list = [] @@ -424,7 +422,6 @@ async def export_chat_stats( page: Optional[int] = 1, stream: bool = False, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): # Check if the user has permission to share/export chats if (user.role != "admin") and ( @@ -455,7 +452,7 @@ async def export_chat_stats( skip = (page - 1) * limit chat_stats_export_list, total = await asyncio.to_thread( - calculate_chat_stats, user.id, skip, limit, filter, db=db + calculate_chat_stats, user.id, skip, limit, filter ) return ChatStatsExportList( @@ -862,6 +859,48 @@ async def unarchive_all_chats( return Chats.unarchive_all_chats_by_user_id(user.id, db=db) +############################ +# GetSharedChats +############################ + + +@router.get("/shared", response_model=list[SharedChatResponse]) +async def get_shared_session_user_chat_list( + page: Optional[int] = None, + query: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + chat_list = [ + SharedChatResponse(**chat.model_dump()) + for chat in Chats.get_shared_chat_list_by_user_id( + user.id, + filter=filter, + skip=skip, + limit=limit, + db=db, + ) + ] + + return chat_list + + ############################ # GetSharedChatById ############################ diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 83a01c6dc4..4bee92c87b 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -224,7 +224,7 @@ async def verify_tool_servers_config( try: if form_data.type == "mcp": if form_data.auth_type == "oauth_2.1": - discovery_urls = get_discovery_urls(form_data.url) + 
discovery_urls = await get_discovery_urls(form_data.url) for discovery_url in discovery_urls: log.debug( f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}" diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index e3dd63525a..4a12db5cd9 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -38,6 +38,7 @@ from open_webui.models.chats import Chats from open_webui.models.knowledge import Knowledges from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrants from open_webui.routers.retrieval import ProcessFileForm, process_file @@ -47,7 +48,6 @@ from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access from open_webui.utils.misc import strict_match_mime_type from pydantic import BaseModel @@ -82,8 +82,13 @@ def has_access_to_file( group.id for group in Groups.get_groups_by_member_id(user.id, db=db) } for knowledge_base in knowledge_bases: - if knowledge_base.user_id == user.id or has_access( - user.id, access_type, knowledge_base.access_control, user_group_ids, db=db + if knowledge_base.user_id == user.id or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission=access_type, + user_group_ids=user_group_ids, + db=db, ): return True @@ -282,7 +287,11 @@ def upload_file_handler( }, "meta": { "name": name, - "content_type": file.content_type, + "content_type": ( + file.content_type + if isinstance(file.content_type, str) + else None + ), "size": len(contents), "data": file_metadata, }, @@ -332,6 +341,8 @@ def upload_file_handler( detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), ) + except HTTPException as e: + raise e except Exception as e: log.exception(e) raise HTTPException( @@ -575,7 +586,7 @@ class ContentForm(BaseModel): @router.post("/{id}/data/content/update") -async def 
update_file_data_content_by_id( +def update_file_data_content_by_id( request: Request, id: str, form_data: ContentForm, @@ -825,6 +836,23 @@ async def delete_file_by_id( or has_access_to_file(id, "write", user, db=db) ): + # Clean up KB associations and embeddings before deleting + knowledges = Knowledges.get_knowledges_by_file_id(id, db=db) + for knowledge in knowledges: + # Remove KB-file relationship + Knowledges.remove_file_from_knowledge_by_id(knowledge.id, id, db=db) + # Clean KB embeddings (same logic as /knowledge/{id}/file/remove) + try: + VECTOR_DB_CLIENT.delete( + collection_name=knowledge.id, filter={"file_id": id} + ) + if file.hash: + VECTOR_DB_CLIENT.delete( + collection_name=knowledge.id, filter={"hash": file.hash} + ) + except Exception as e: + log.debug(f"KB embedding cleanup for {knowledge.id}: {e}") + result = Files.delete_file_by_id(id, db=db) if result: try: diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 1c9b2229cf..ab0326c882 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -33,7 +33,6 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_permission - log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index ad47318911..3af3b1664a 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -19,6 +19,7 @@ load_function_module_by_id, replace_imports, get_function_module_from_cache, + resolve_valves_schema_options, ) from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES @@ -28,7 +29,6 @@ from open_webui.internal.db import get_session from sqlalchemy.orm import Session - log = logging.getLogger(__name__) @@ -446,7 +446,10 @@ async def get_function_valves_spec_by_id( if hasattr(function_module, "Valves"): Valves = function_module.Valves 
- return Valves.schema() + schema = Valves.schema() + # Resolve dynamic options for select dropdowns + schema = resolve_valves_schema_options(Valves, schema, user) + return schema return None else: raise HTTPException( @@ -546,7 +549,10 @@ async def get_function_user_valves_spec_by_id( if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves - return UserValves.schema() + schema = UserValves.schema() + # Resolve dynamic options for select dropdowns + schema = resolve_valves_schema_options(UserValves, schema, user) + return schema return None else: raise HTTPException( diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index cc0cb8f5a3..3711a52ab4 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -7,6 +7,7 @@ from open_webui.models.groups import ( Groups, GroupForm, + GroupInfoResponse, GroupUpdateForm, GroupResponse, UserIdsForm, @@ -21,7 +22,6 @@ from open_webui.utils.auth import get_admin_user, get_verified_user - log = logging.getLogger(__name__) router = APIRouter() @@ -104,6 +104,23 @@ async def get_group_by_id( ) +@router.get("/id/{id}/info", response_model=Optional[GroupInfoResponse]) +async def get_group_info_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) + if group: + return GroupInfoResponse( + **group.model_dump(), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # ExportGroupById ############################ diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 0fc6930b81..942848bc2f 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -14,8 +14,13 @@ from fastapi import APIRouter, Depends, HTTPException, 
Request, UploadFile from fastapi.responses import FileResponse -from open_webui.config import CACHE_DIR +from open_webui.config import ( + CACHE_DIR, + IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, + IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN, +) from open_webui.constants import ERROR_MESSAGES +from open_webui.retrieval.web.utils import validate_url from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS from open_webui.models.chats import Chats @@ -198,14 +203,13 @@ async def update_config( request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE set_image_model(request, form_data.IMAGE_GENERATION_MODEL) - if ( - form_data.IMAGE_SIZE == "auto" - and not form_data.IMAGE_GENERATION_MODEL.startswith("gpt-image") + if form_data.IMAGE_SIZE == "auto" and not re.match( + IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, form_data.IMAGE_GENERATION_MODEL ): raise HTTPException( status_code=400, detail=ERROR_MESSAGES.INCORRECT_FORMAT( - " (auto is only allowed with gpt-image models)." + f" (auto is only allowed with models matching {IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN})." 
), ) @@ -609,8 +613,9 @@ async def image_generations( ), **( {} - if request.app.state.config.IMAGE_GENERATION_MODEL.startswith( - "gpt-image" + if re.match( + IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN, + request.app.state.config.IMAGE_GENERATION_MODEL, ) else {"response_format": "b64_json"} ), @@ -881,6 +886,8 @@ async def load_url_image(data): return data if data.startswith("http://") or data.startswith("https://"): + # Validate URL to prevent SSRF attacks against local/private networks + validate_url(data) r = await asyncio.to_thread(requests.get, data) r.raise_for_status() @@ -910,7 +917,10 @@ async def load_url_image(data): if isinstance(form_data.image, str): form_data.image = await load_url_image(form_data.image) elif isinstance(form_data.image, list): - form_data.image = [await load_url_image(img) for img in form_data.image] + # Load all images in parallel for better performance + form_data.image = list( + await asyncio.gather(*[load_url_image(img) for img in form_data.image]) + ) except Exception as e: raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) @@ -945,7 +955,10 @@ def get_image_file_item(base64_string, param_name="image"): **({"size": size} if size else {}), **( {} - if request.app.state.config.IMAGE_EDIT_MODEL.startswith("gpt-image") + if re.match( + IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN, + request.app.state.config.IMAGE_EDIT_MODEL, + ) else {"response_format": "b64_json"} ), } diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 9fc30424ca..eab00aa19b 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -29,13 +29,13 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_verified_user, get_admin_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants, 
has_public_read_access_grant from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.models.models import Models, ModelForm - log = logging.getLogger(__name__) router = APIRouter() @@ -133,8 +133,12 @@ async def get_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access( - user.id, "write", knowledge_base.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="write", + db=db, ) ), ) @@ -180,8 +184,12 @@ async def search_knowledge_bases( write_access=( user.id == knowledge_base.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access( - user.id, "write", knowledge_base.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="write", + db=db, ) ), ) @@ -227,10 +235,13 @@ async def create_new_knowledge( request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (has_permission, insert_new_knowledge) manage their own sessions. + # This prevents holding a connection during embed_knowledge_base_metadata() + # which makes external embedding API calls (1-5+ seconds). 
if user.role != "admin" and not has_permission( - user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -240,17 +251,16 @@ async def create_new_knowledge( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, - db=db, ) ): - form_data.access_control = {} + form_data.access_grants = [] - knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db) + knowledge = Knowledges.insert_new_knowledge(user.id, form_data) if knowledge: # Embed knowledge base for semantic search @@ -345,10 +355,15 @@ async def reindex_knowledge_files( async def reindex_knowledge_base_metadata_embeddings( request: Request, user=Depends(get_admin_user), - db: Session = Depends(get_session), ): - """Batch embed all existing knowledge bases. Admin only.""" - knowledge_bases = Knowledges.get_knowledge_bases(db=db) + """Batch embed all existing knowledge bases. Admin only. + + NOTE: We intentionally do NOT use Depends(get_session) here. + This endpoint loops through ALL knowledge bases and calls embed_knowledge_base_metadata() + for each one, making N external embedding API calls. Holding a session during + this entire operation would exhaust the connection pool. 
+ """ + knowledge_bases = Knowledges.get_knowledge_bases() log.info(f"Reindexing embeddings for {len(knowledge_bases)} knowledge bases") success_count = 0 @@ -380,7 +395,13 @@ async def get_knowledge_by_id( if ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + db=db, + ) ): return KnowledgeFilesResponse( @@ -388,7 +409,13 @@ async def get_knowledge_by_id( write_access=( user.id == knowledge.user_id or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or has_access(user.id, "write", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) ), ) else: @@ -414,9 +441,12 @@ async def update_knowledge_by_id( id: str, form_data: KnowledgeForm, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations manage their own short-lived sessions internally. + # This prevents holding a connection during embed_knowledge_base_metadata() + # which makes external embedding API calls (1-5+ seconds). 
+ knowledge = Knowledges.get_knowledge_by_id(id=id) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -425,7 +455,12 @@ async def update_knowledge_by_id( # Is the user the original creator, in a group with write access, or an admin if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + ) and user.role != "admin" ): raise HTTPException( @@ -436,17 +471,16 @@ async def update_knowledge_by_id( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, - db=db, ) ): - form_data.access_control = {} + form_data.access_grants = [] - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db) + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: # Re-embed knowledge base for semantic search await embed_knowledge_base_metadata( @@ -457,7 +491,7 @@ async def update_knowledge_by_id( ) return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), + files=Knowledges.get_file_metadatas_by_id(knowledge.id), ) else: raise HTTPException( @@ -466,6 +500,53 @@ async def update_knowledge_by_id( ) +############################ +# UpdateKnowledgeAccessById +############################ + + +class KnowledgeAccessGrantsForm(BaseModel): + access_grants: list[dict] + + +@router.post("/{id}/access/update", response_model=Optional[KnowledgeFilesResponse]) +async def update_knowledge_access_by_id( + id: str, + form_data: KnowledgeAccessGrantsForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + knowledge = 
Knowledges.get_knowledge_by_id(id=id, db=db) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + knowledge.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + AccessGrants.set_access_grants("knowledge", id, form_data.access_grants, db=db) + + return KnowledgeFilesResponse( + **Knowledges.get_knowledge_by_id(id=id, db=db).model_dump(), + files=Knowledges.get_file_metadatas_by_id(id, db=db), + ) + + ############################ # GetKnowledgeFilesById ############################ @@ -493,7 +574,13 @@ async def get_knowledge_files_by_id( if not ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -546,7 +633,13 @@ def add_file_to_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -615,7 +708,13 @@ def update_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): @@ -684,7 +783,13 @@ def remove_file_from_knowledge_by_id( if ( knowledge.user_id != 
user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -761,7 +866,13 @@ async def delete_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -793,7 +904,7 @@ async def delete_knowledge_by_id( base_model_id=model.base_model_id, meta=model.meta, params=model.params, - access_control=model.access_control, + access_grants=model.access_grants, is_active=model.is_active, ) Models.update_model_by_id(model.id, model_form, db=db) @@ -830,7 +941,13 @@ async def reset_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -873,7 +990,13 @@ async def add_files_to_knowledge_batch( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -881,17 +1004,19 @@ async def add_files_to_knowledge_batch( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - # Get files content + # Batch-fetch all files to avoid N+1 queries log.info(f"files/batch/add - {len(form_data)} files") - files: List[FileModel] = [] - for form in form_data: - file = Files.get_file_by_id(form.file_id, db=db) - if not file: - raise 
HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"File {form.file_id} not found", - ) - files.append(file) + file_ids = [form.file_id for form in form_data] + files = Files.get_files_by_ids(file_ids, db=db) + + # Verify all requested files were found + found_ids = {file.id for file in files} + missing_ids = [fid for fid in file_ids if fid not in found_ids] + if missing_ids: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"File {missing_ids[0]} not found", + ) # Process files try: diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index e0ba36c76f..db2bf5e238 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -18,11 +18,6 @@ router = APIRouter() -@router.get("/ef") -async def get_embeddings(request: Request): - return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")} - - ############################ # GetMemories ############################ @@ -69,8 +64,11 @@ async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (insert_new_memory) manage their own short-lived sessions. + # This prevents holding a connection during EMBEDDING_FUNCTION() + # which makes external embedding API calls (1-5+ seconds). 
if not request.app.state.config.ENABLE_MEMORIES: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -85,7 +83,7 @@ async def add_memory( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - memory = Memories.insert_new_memory(user.id, form_data.content, db=db) + memory = Memories.insert_new_memory(user.id, form_data.content) vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) @@ -119,8 +117,11 @@ async def query_memory( request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (get_memories_by_user_id) manage their own short-lived sessions. + # This prevents holding a connection during EMBEDDING_FUNCTION() + # which makes external embedding API calls (1-5+ seconds). if not request.app.state.config.ENABLE_MEMORIES: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -135,7 +136,7 @@ async def query_memory( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - memories = Memories.get_memories_by_user_id(user.id, db=db) + memories = Memories.get_memories_by_user_id(user.id) if not memories: raise HTTPException(status_code=404, detail="No memories found for user") @@ -157,8 +158,15 @@ async def query_memory( async def reset_memory_from_vector_db( request: Request, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + """Reset user's memory vector embeddings. + + CRITICAL: We intentionally do NOT use Depends(get_session) here. + This endpoint generates embeddings for ALL user memories in parallel using + asyncio.gather(). A user with 100 memories would trigger 100 embedding API + calls simultaneously. With a session held, this could block a connection + for MINUTES, completely exhausting the connection pool. 
+ """ if not request.app.state.config.ENABLE_MEMORIES: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -175,7 +183,7 @@ async def reset_memory_from_vector_db( VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") - memories = Memories.get_memories_by_user_id(user.id, db=db) + memories = Memories.get_memories_by_user_id(user.id) # Generate vectors in parallel vectors = await asyncio.gather( @@ -252,8 +260,11 @@ async def update_memory_by_id( request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (update_memory_by_id_and_user_id) manage their own + # short-lived sessions. This prevents holding a connection during + # EMBEDDING_FUNCTION() which makes external API calls (1-5+ seconds). if not request.app.state.config.ENABLE_MEMORIES: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -269,7 +280,7 @@ async def update_memory_by_id( ) memory = Memories.update_memory_by_id_and_user_id( - memory_id, user.id, form_data.content, db=db + memory_id, user.id, form_data.content ) if memory is None: raise HTTPException(status_code=404, detail="Memory not found") diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index a1f642bbce..7202262bbd 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -15,6 +15,7 @@ ModelAccessResponse, Models, ) +from open_webui.models.access_grants import AccessGrants from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES @@ -30,7 +31,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR from open_webui.internal.db import get_session from 
sqlalchemy.orm import Session @@ -98,7 +99,13 @@ async def get_models( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ), ) for model in result.items @@ -246,12 +253,25 @@ async def import_models( try: data = form_data.models if isinstance(data, list): + # Batch-fetch all existing models in one query to avoid N+1 + model_ids = [ + model_data.get("id") + for model_data in data + if model_data.get("id") and is_valid_model_id(model_data.get("id")) + ] + existing_models = { + model.id: model + for model in ( + Models.get_models_by_ids(model_ids, db=db) if model_ids else [] + ) + } + for model_data in data: # Here, you can add logic to validate model_data if needed model_id = model_data.get("id") if model_id and is_valid_model_id(model_id): - existing_model = Models.get_model_by_id(model_id, db=db) + existing_model = existing_models.get(model_id) if existing_model: # Update existing model model_data["meta"] = model_data.get("meta", {}) @@ -315,14 +335,26 @@ async def get_model_by_id( if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id - or has_access(user.id, "read", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="read", + db=db, + ) ): return ModelAccessResponse( **model.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model.user_id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ), ) else: @@ -343,10 +375,8 @@ async def get_model_by_id( @router.get("/model/profile/image") -def get_model_profile_image( - id: 
str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - model = Models.get_model_by_id(id, db=db) +def get_model_profile_image(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) if model: etag = f'"{model.updated_at}"' if model.updated_at else None @@ -395,7 +425,13 @@ async def toggle_model_by_id( if ( user.role == "admin" or model.user_id == user.id - or has_access(user.id, "write", model.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ): model = Models.toggle_model_by_id(id, db=db) @@ -438,7 +474,13 @@ async def update_model_by_id( if ( model.user_id != user.id - and not has_access(user.id, "write", model.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -452,6 +494,52 @@ async def update_model_by_id( return model +############################ +# UpdateModelAccessById +############################ + + +class ModelAccessGrantsForm(BaseModel): + id: str + access_grants: list[dict] + + +@router.post("/model/access/update", response_model=Optional[ModelModel]) +async def update_model_access_by_id( + form_data: ModelAccessGrantsForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + model = Models.get_model_by_id(form_data.id, db=db) + if not model: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + model.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + AccessGrants.set_access_grants( + "model", 
form_data.id, form_data.access_grants, db=db + ) + + return Models.get_model_by_id(form_data.id, db=db) + + ############################ # DeleteModelById ############################ @@ -473,7 +561,13 @@ async def delete_model_by_id( if ( user.role != "admin" and model.user_id != user.id - and not has_access(user.id, "write", model.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model.id, + permission="write", + db=db, + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 56730e2b6a..04841e87cc 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -27,7 +27,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant from open_webui.internal.db import get_session from sqlalchemy.orm import Session @@ -200,8 +201,12 @@ async def get_note_by_id( if user.role != "admin" and ( user.id != note.user_id and ( - not has_access( - user.id, type="read", access_control=note.access_control, db=db + not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + db=db, ) ) ): @@ -212,13 +217,14 @@ async def get_note_by_id( write_access = ( user.role == "admin" or (user.id == note.user_id) - or has_access( - user.id, - type="write", - access_control=note.access_control, - strict=False, + or AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", db=db, ) + or has_public_read_access_grant(note.access_grants) ) return NoteResponse(**note.model_dump(), write_access=write_access) @@ -253,8 +259,12 @@ async def update_note_by_id( if 
user.role != "admin" and ( user.id != note.user_id - and not has_access( - user.id, type="write", access_control=note.access_control, db=db + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, ) ): raise HTTPException( @@ -264,7 +274,7 @@ async def update_note_by_id( # Check if user can share publicly if ( user.role != "admin" - and form_data.access_control == None + and has_public_read_access_grant(form_data.access_grants) and not has_permission( user.id, "sharing.public_notes", @@ -272,7 +282,7 @@ async def update_note_by_id( db=db, ) ): - form_data.access_control = {} + form_data.access_grants = [] try: note = Notes.update_note_by_id(id, form_data, db=db) @@ -290,6 +300,56 @@ async def update_note_by_id( ) +############################ +# UpdateNoteAccessById +############################ + + +class NoteAccessGrantsForm(BaseModel): + access_grants: list[dict] + + +@router.post("/{id}/access/update", response_model=Optional[NoteModel]) +async def update_note_access_by_id( + request: Request, + id: str, + form_data: NoteAccessGrantsForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if user.role != "admin" and not has_permission( + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + note = Notes.get_note_by_id(id, db=db) + if not note: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and ( + user.id != note.user_id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, + ) + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + AccessGrants.set_access_grants("note", id, form_data.access_grants, db=db) + + 
return Notes.get_note_by_id(id, db=db) + + ############################ # DeleteNoteById ############################ @@ -318,8 +378,12 @@ async def delete_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access( - user.id, type="write", access_control=note.access_control, db=db + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="write", + db=db, ) ): raise HTTPException( diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index afaca2feb5..1d72d7efdf 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -23,6 +23,7 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, ) from fastapi import ( @@ -37,17 +38,21 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict, validator -from starlette.background import BackgroundTask + from sqlalchemy.orm import Session from open_webui.internal.db import get_session from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups from open_webui.utils.credit.usage import CreditDeduct from open_webui.utils.credit.utils import check_credit_by_user_id from open_webui.utils.misc import ( calculate_sha256, + cleanup_response, + stream_wrapper, ) from open_webui.utils.payload import ( apply_model_params_to_body_ollama, @@ -55,9 +60,6 @@ apply_system_prompt_to_body, ) from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access - - from open_webui.config import ( UPLOAD_DIR, ) @@ -105,16 +107,6 @@ async def send_get_request(url, key=None, user: UserModel = None): return None -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: 
Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - async def send_post_request( url: str, payload: Union[str, bytes], @@ -126,6 +118,7 @@ async def send_post_request( ): r = None + streaming = False try: session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) @@ -139,7 +132,7 @@ async def send_post_request( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) if metadata and metadata.get("chat_id"): - headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id") + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get("chat_id") r = await session.post( url, @@ -170,13 +163,11 @@ async def send_post_request( if content_type: response_headers["Content-Type"] = content_type + streaming = True return StreamingResponse( - r.content, + stream_wrapper(user, payload["model"], payload, r, session), status_code=r.status, headers=response_headers, - background=BackgroundTask( - cleanup_response, response=r, session=session - ), ) else: res = await r.json() @@ -192,7 +183,7 @@ async def send_post_request( detail=detail if e else "Open WebUI: Server Connection Error", ) finally: - if not stream: + if not streaming: await cleanup_response(r, session) @@ -428,12 +419,21 @@ async def get_all_models(request: Request, user: UserModel = None): async def get_filtered_models(models, user, db=None): # Filter models based on user access control + model_ids = [model["model"] for model in models.get("models", [])] + model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} + filtered_models = [] for model in models.get("models", []): - model_info = Models.get_model_by_id(model["model"], db=db) + model_info = model_infos.get(model["model"]) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", 
access_control=model_info.access_control, db=db + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + user_group_ids=user_group_ids, + db=db, ): filtered_models.append(model) return filtered_models @@ -444,6 +444,9 @@ async def get_filtered_models(models, user, db=None): async def get_ollama_tags( request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + models = [] if url_idx is None: @@ -708,6 +711,9 @@ async def pull_model( url_idx: int = 0, user=Depends(get_admin_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + form_data = form_data.model_dump(exclude_none=True) form_data["model"] = form_data.get("model", form_data.get("name")) @@ -739,6 +745,9 @@ async def push_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS @@ -778,6 +787,9 @@ async def create_model( url_idx: int = 0, user=Depends(get_admin_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + log.debug(f"form_data: {form_data}") url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -802,6 +814,9 @@ async def copy_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS @@ -862,6 +877,9 @@ async 
def delete_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + form_data = form_data.model_dump(exclude_none=True) form_data["model"] = form_data.get("model", form_data.get("name")) @@ -924,6 +942,9 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + form_data = form_data.model_dump(exclude_none=True) form_data["model"] = form_data.get("model", form_data.get("name")) @@ -996,16 +1017,19 @@ async def embed( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models(request, user=user) - models = request.app.state.OLLAMA_MODELS - model = form_data.model - if ":" not in model: - model = f"{model}:latest" + # Check if model is already in app state cache to avoid expensive get_all_models() call + models = request.app.state.OLLAMA_MODELS + if not models or model not in models: + await get_all_models(request, user=user) + models = request.app.state.OLLAMA_MODELS if model in models: url_idx = random.choice(models[model]["urls"]) @@ -1078,6 +1102,9 @@ async def embeddings( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + log.info(f"generate_ollama_embeddings {form_data}") # check credit @@ -1085,13 +1112,13 @@ async def embeddings( check_credit_by_user_id(user_id=user.id, form_data={}, is_embedding=True) if url_idx is None: - await 
get_all_models(request, user=user) - models = request.app.state.OLLAMA_MODELS - model = form_data.model - if ":" not in model: - model = f"{model}:latest" + # Check if model is already in app state cache to avoid expensive get_all_models() call + models = request.app.state.OLLAMA_MODELS + if not models or model not in models: + await get_all_models(request, user=user) + models = request.app.state.OLLAMA_MODELS if model in models: url_idx = random.choice(models[model]["urls"]) @@ -1185,15 +1212,14 @@ async def generate_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model - - if ":" not in model: - model = f"{model}:latest" - if model in models: url_idx = random.choice(models[model]["urls"]) else: @@ -1276,8 +1302,14 @@ async def generate_chat_completion( user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, bypass_system_prompt: bool = False, - db: Session = Depends(get_session), ): + if not request.app.state.config.ENABLE_OLLAMA_API: + raise HTTPException(status_code=503, detail="Ollama API is disabled") + + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. + # This prevents holding a connection during the entire LLM call (30-60+ seconds), + # which would exhaust the connection pool under concurrent load. 
if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True @@ -1298,7 +1330,7 @@ async def generate_chat_completion( del payload["metadata"] model_id = payload["model"] - model_info = Models.get_model_by_id(model_id, db=db) + model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: @@ -1322,11 +1354,11 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, - db=db, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1398,8 +1430,11 @@ async def generate_openai_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. + # This prevents holding a connection during the entire LLM call (30-60+ seconds), + # which would exhaust the connection pool under concurrent load. 
metadata = form_data.pop("metadata", None) try: @@ -1419,7 +1454,7 @@ async def generate_openai_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id, db=db) + model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1432,11 +1467,11 @@ async def generate_openai_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, - db=db, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1481,8 +1516,12 @@ async def generate_openai_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. + # This prevents holding a connection during the entire LLM call (30-60+ seconds), + # which would exhaust the connection pool under concurrent load. 
+ check_credit_by_user_id(user_id=user.id, form_data=form_data) metadata = form_data.pop("metadata", None) @@ -1504,7 +1543,7 @@ async def generate_openai_chat_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id, db=db) + model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1521,11 +1560,11 @@ async def generate_openai_chat_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, - db=db, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -1619,14 +1658,20 @@ async def get_openai_models( if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control + model_ids = [model["id"] for model in models] + model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} + filtered_models = [] for model in models: - model_info = Models.get_model_by_id(model["id"], db=db) + model_info = model_infos.get(model["id"]) if model_info: - if user.id == model_info.user_id or has_access( - user.id, - type="read", - access_control=model_info.access_control, + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + user_group_ids=user_group_ids, db=db, ): filtered_models.append(model) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index f4b87c372d..587509800c 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -3,6 +3,7 @@ import json import logging from typing import Optional +from urllib.parse import urlparse import 
aiohttp from aiocache import cached @@ -18,12 +19,14 @@ PlainTextResponse, ) from pydantic import BaseModel -from starlette.background import BackgroundTask + from sqlalchemy.orm import Session from open_webui.internal.db import get_session from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups from open_webui.config import ( CACHE_DIR, ) @@ -33,6 +36,7 @@ AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, ENABLE_FORWARD_USER_INFO_HEADERS, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, BYPASS_MODEL_ACCESS_CONTROL, AIOHTTP_CLIENT_READ_BUFFER_SIZE, ) @@ -46,16 +50,16 @@ apply_system_prompt_to_body, ) from open_webui.utils.misc import ( + cleanup_response, convert_logit_bias_input_to_json, stream_chunks_handler, + stream_wrapper, ) from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access from open_webui.utils.credit.usage import CreditDeduct from open_webui.utils.headers import include_user_info_headers - log = logging.getLogger(__name__) @@ -89,16 +93,6 @@ async def send_get_request(url, key=None, user: UserModel = None): return None -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - def openai_reasoning_model_handler(payload): """ Handle reasoning model specific parameters @@ -144,7 +138,7 @@ async def get_headers_and_cookies( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) if metadata and metadata.get("chat_id"): - headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id") + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get("chat_id") token = None auth_type = config.get("auth_type") @@ -349,37 +343,41 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: if not 
request.app.state.config.ENABLE_OPENAI_API: return [] + # Cache config values locally to avoid repeated Redis lookups. + # Each access to request.app.state.config. triggers a Redis GET; + # caching here avoids hundreds of redundant round-trips. + api_base_urls = request.app.state.config.OPENAI_API_BASE_URLS + api_keys = list(request.app.state.config.OPENAI_API_KEYS) + api_configs = request.app.state.config.OPENAI_API_CONFIGS + # Check if API KEYS length is same than API URLS length - num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS) - num_keys = len(request.app.state.config.OPENAI_API_KEYS) + num_urls = len(api_base_urls) + num_keys = len(api_keys) if num_keys != num_urls: # if there are more keys than urls, remove the extra keys if num_keys > num_urls: - new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls] - request.app.state.config.OPENAI_API_KEYS = new_keys + api_keys = api_keys[:num_urls] + request.app.state.config.OPENAI_API_KEYS = api_keys # if there are more urls than keys, add empty keys else: - request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) + api_keys += [""] * (num_urls - num_keys) + request.app.state.config.OPENAI_API_KEYS = api_keys request_tasks = [] - for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): - if (str(idx) not in request.app.state.config.OPENAI_API_CONFIGS) and ( - url not in request.app.state.config.OPENAI_API_CONFIGS # Legacy support - ): + for idx, url in enumerate(api_base_urls): + if (str(idx) not in api_configs) and (url not in api_configs): # Legacy support request_tasks.append( send_get_request( f"{url}/models", - request.app.state.config.OPENAI_API_KEYS[idx], + api_keys[idx], user=user, ) ) else: - api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + api_config = api_configs.get( str(idx), - request.app.state.config.OPENAI_API_CONFIGS.get( - url, {} - ), # Legacy support + api_configs.get(url, {}), # Legacy support ) enable = api_config.get("enable", 
True) @@ -390,7 +388,7 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: request_tasks.append( send_get_request( f"{url}/models", - request.app.state.config.OPENAI_API_KEYS[idx], + api_keys[idx], user=user, ) ) @@ -419,12 +417,10 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: for idx, response in enumerate(responses): if response: - url = request.app.state.config.OPENAI_API_BASE_URLS[idx] - api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + url = api_base_urls[idx] + api_config = api_configs.get( str(idx), - request.app.state.config.OPENAI_API_CONFIGS.get( - url, {} - ), # Legacy support + api_configs.get(url, {}), # Legacy support ) connection_type = api_config.get("connection_type", "external") @@ -460,12 +456,21 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: async def get_filtered_models(models, user, db=None): # Filter models based on user access control + model_ids = [model["id"] for model in models.get("data", [])] + model_infos = {m.id: m for m in Models.get_models_by_ids(model_ids, db=db)} + user_group_ids = {g.id for g in Groups.get_groups_by_member_id(user.id, db=db)} + filtered_models = [] for model in models.get("data", []): - model_info = Models.get_model_by_id(model["id"], db=db) + model_info = model_infos.get(model["id"]) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + if user.id == model_info.user_id or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + user_group_ids=user_group_ids, + db=db, ): filtered_models.append(model) return filtered_models @@ -481,6 +486,10 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} + # Cache config value locally to avoid repeated Redis lookups 
inside + # the nested loop in get_merged_models (one GET per model otherwise). + api_base_urls = request.app.state.config.OPENAI_API_BASE_URLS + responses = await get_all_models_responses(request, user=user) def extract_data(response): @@ -514,10 +523,10 @@ def get_merged_models(model_lists): for model in model_list: model_id = model.get("id") or model.get("name") - if ( - "api.openai.com" - in request.app.state.config.OPENAI_API_BASE_URLS[idx] - and not is_supported_openai_models(model_id) + base_url = api_base_urls[idx] + hostname = urlparse(base_url).hostname if base_url else None + if hostname == "api.openai.com" and not is_supported_openai_models( + model_id ): # Skip unwanted OpenAI models continue @@ -546,6 +555,9 @@ def get_merged_models(model_lists): async def get_models( request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) ): + if not request.app.state.config.ENABLE_OPENAI_API: + raise HTTPException(status_code=503, detail="OpenAI API is disabled") + models = { "data": [], } @@ -801,6 +813,115 @@ def convert_to_azure_payload(url, payload: dict, api_version: str): return url, payload +def convert_to_responses_payload(payload: dict) -> dict: + """ + Convert Chat Completions payload to Responses API format. + + Chat Completions: { messages: [{role, content}], ... 
} + Responses API: { input: [{type: "message", role, content: [...]}], instructions: "system" } + """ + messages = payload.pop("messages", []) + + system_content = "" + input_items = [] + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + # Check for stored output items (from previous Responses API turn) + stored_output = msg.get("output") + if stored_output and isinstance(stored_output, list): + input_items.extend(stored_output) + continue + + if role == "system": + if isinstance(content, str): + system_content = content + elif isinstance(content, list): + system_content = "\n".join( + p.get("text", "") for p in content if p.get("type") == "text" + ) + continue + + # Convert content format + text_type = "output_text" if role == "assistant" else "input_text" + + if isinstance(content, str): + content_parts = [{"type": text_type, "text": content}] + elif isinstance(content, list): + content_parts = [] + for part in content: + if part.get("type") == "text": + content_parts.append( + {"type": text_type, "text": part.get("text", "")} + ) + elif part.get("type") == "image_url": + url_data = part.get("image_url", {}) + url = ( + url_data.get("url", "") + if isinstance(url_data, dict) + else url_data + ) + content_parts.append({"type": "input_image", "image_url": url}) + else: + content_parts = [{"type": text_type, "text": str(content)}] + + input_items.append({"type": "message", "role": role, "content": content_parts}) + + responses_payload = {**payload, "input": input_items} + + if system_content: + responses_payload["instructions"] = system_content + + if "max_tokens" in responses_payload: + responses_payload["max_output_tokens"] = responses_payload.pop("max_tokens") + + # Remove Chat Completions-only parameters not supported by the Responses API + for unsupported_key in ( + "stream_options", + "logit_bias", + "frequency_penalty", + "presence_penalty", + "stop", + ): + responses_payload.pop(unsupported_key, None) + + # 
Convert Chat Completions tools format to Responses API format + # Chat Completions: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}} + # Responses API: {"type": "function", "name": ..., "description": ..., "parameters": ...} + if "tools" in responses_payload and isinstance(responses_payload["tools"], list): + converted_tools = [] + for tool in responses_payload["tools"]: + if isinstance(tool, dict) and "function" in tool: + func = tool["function"] + converted_tool = {"type": tool.get("type", "function")} + if isinstance(func, dict): + converted_tool["name"] = func.get("name", "") + if "description" in func: + converted_tool["description"] = func["description"] + if "parameters" in func: + converted_tool["parameters"] = func["parameters"] + if "strict" in func: + converted_tool["strict"] = func["strict"] + converted_tools.append(converted_tool) + else: + # Already in correct format or unknown format, pass through + converted_tools.append(tool) + responses_payload["tools"] = converted_tools + + return responses_payload + + +def convert_responses_result(response: dict) -> dict: + """ + Convert non-streaming Responses API result. + Just add done flag - pass through raw response, frontend handles output. + """ + response["done"] = True + return response + + @router.post("/chat/completions") async def generate_chat_completion( request: Request, @@ -808,8 +929,12 @@ async def generate_chat_completion( user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, bypass_system_prompt: bool = False, - db: Session = Depends(get_session), ): + # NOTE: We intentionally do NOT use Depends(get_session) here. + # Database operations (get_model_by_id, AccessGrants.has_access) manage their own short-lived sessions. + # This prevents holding a connection during the entire LLM call (30-60+ seconds), + # which would exhaust the connection pool under concurrent load. 
+ check_credit_by_user_id(user_id=user.id, form_data=form_data) if BYPASS_MODEL_ACCESS_CONTROL: @@ -821,7 +946,7 @@ async def generate_chat_completion( metadata = payload.pop("metadata", None) model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id, db=db) + model_info = Models.get_model_by_id(model_id) # Check model info and override the payload if model_info: @@ -847,11 +972,11 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, - db=db, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", ) ): raise HTTPException( @@ -865,8 +990,13 @@ async def generate_chat_completion( detail="Model not found", ) - await get_all_models(request, user=user) - model = request.app.state.OPENAI_MODELS.get(model_id) + # Check if model is already in app state cache to avoid expensive get_all_models() call + models = request.app.state.OPENAI_MODELS + if not models or model_id not in models: + await get_all_models(request, user=user) + models = request.app.state.OPENAI_MODELS + model = models.get(model_id) + if model: idx = model["urlIdx"] else: @@ -922,6 +1052,8 @@ async def generate_chat_completion( request, url, key, api_config, metadata, user=user ) + is_responses = api_config.get("api_type") == "responses" + if api_config.get("azure", False): api_version = api_config.get("api_version", "2023-03-15-preview") request_url, payload = convert_to_azure_payload(url, payload, api_version) @@ -932,9 +1064,18 @@ async def generate_chat_completion( headers["api-key"] = key headers["api-version"] = api_version - request_url = f"{request_url}/chat/completions?api-version={api_version}" + + if is_responses: + payload = convert_to_responses_payload(payload) + request_url = f"{request_url}/responses?api-version={api_version}" + else: + request_url 
= f"{request_url}/chat/completions?api-version={api_version}" else: - request_url = f"{url}/chat/completions" + if is_responses: + payload = convert_to_responses_payload(payload) + request_url = f"{url}/responses" + else: + request_url = f"{url}/chat/completions" payload = json.dumps(payload) @@ -964,12 +1105,11 @@ async def generate_chat_completion( streaming = True return StreamingResponse( - stream_chunks_handler(user, model_id, form_data, r.content), + stream_wrapper( + user, model_id, form_data, r, session, stream_chunks_handler + ), status_code=r.status, headers=dict(r.headers), - background=BackgroundTask( - cleanup_response, response=r, session=session - ), ) else: try: @@ -984,6 +1124,10 @@ async def generate_chat_completion( else: return PlainTextResponse(status_code=r.status, content=response) + # Convert Responses API result to simple format + if is_responses and isinstance(response, dict): + response = convert_responses_result(response) + with CreditDeduct( user=user, model_id=model_id, @@ -1025,9 +1169,12 @@ async def embeddings(request: Request, form_data: dict, user): # Prepare payload/body body = json.dumps(form_data) # Find correct backend url/key based on model - await get_all_models(request, user=user) model_id = form_data.get("model") + # Check if model is already in app state cache to avoid expensive get_all_models() call models = request.app.state.OPENAI_MODELS + if not models or model_id not in models: + await get_all_models(request, user=user) + models = request.app.state.OPENAI_MODELS if model_id in models: idx = models[model_id]["urlIdx"] @@ -1069,12 +1216,9 @@ async def embeddings(request: Request, form_data: dict, user): credit_deduct.run(form_data["input"]) streaming = True return StreamingResponse( - r.content, + stream_wrapper(user, model_id, form_data, r, session), status_code=r.status, headers=dict(r.headers), - background=BackgroundTask( - cleanup_response, response=r, session=session - ), ) else: try: @@ -1118,96 +1262,3 @@ async 
def embeddings(request: Request, form_data: dict, user): finally: if not streaming: await cleanup_response(r, session) - - -@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - """ - Deprecated: proxy all requests to OpenAI API - """ - - body = await request.body() - - idx = 0 - url = request.app.state.config.OPENAI_API_BASE_URLS[idx] - key = request.app.state.config.OPENAI_API_KEYS[idx] - api_config = request.app.state.config.OPENAI_API_CONFIGS.get( - str(idx), - request.app.state.config.OPENAI_API_CONFIGS.get( - request.app.state.config.OPENAI_API_BASE_URLS[idx], {} - ), # Legacy support - ) - - r = None - session = None - streaming = False - - try: - headers, cookies = await get_headers_and_cookies( - request, url, key, api_config, user=user - ) - - if api_config.get("azure", False): - api_version = api_config.get("api_version", "2023-03-15-preview") - - # Only set api-key header if not using Azure Entra ID authentication - auth_type = api_config.get("auth_type", "bearer") - if auth_type not in ("azure_ad", "microsoft_entra_id"): - headers["api-key"] = key - - headers["api-version"] = api_version - - payload = json.loads(body) - url, payload = convert_to_azure_payload(url, payload, api_version) - body = json.dumps(payload).encode() - - request_url = f"{url}/{path}?api-version={api_version}" - else: - request_url = f"{url}/{path}" - - session = aiohttp.ClientSession(trust_env=True) - r = await session.request( - method=request.method, - url=request_url, - data=body, - headers=headers, - cookies=cookies, - ssl=AIOHTTP_CLIENT_SESSION_SSL, - ) - - # Check if response is SSE - if "text/event-stream" in r.headers.get("Content-Type", ""): - streaming = True - return StreamingResponse( - r.content, - status_code=r.status, - headers=dict(r.headers), - background=BackgroundTask( - cleanup_response, response=r, session=session - ), - ) - else: - try: - response_data = 
await r.json() - except Exception: - response_data = await r.text() - - if r.status >= 400: - if isinstance(response_data, (dict, list)): - return JSONResponse(status_code=r.status, content=response_data) - else: - return PlainTextResponse( - status_code=r.status, content=response_data - ) - - return response_data - - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=r.status if r else 500, - detail="Open WebUI: Server Connection Error", - ) - finally: - if not streaming: - await cleanup_response(r, session) diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 7a42acffc1..20fcd75eec 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -13,7 +13,6 @@ import os import logging import shutil -import requests from pydantic import BaseModel from starlette.responses import FileResponse from typing import Optional @@ -217,7 +216,7 @@ async def upload_pipeline( os.makedirs(upload_folder, exist_ok=True) file_path = os.path.join(upload_folder, filename) - r = None + response = None try: # Save the uploaded file with open(file_path, "wb") as buffer: @@ -226,16 +225,25 @@ async def upload_pipeline( url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - with open(file_path, "rb") as f: - files = {"file": f} - r = requests.post( - f"{url}/pipelines/upload", - headers={"Authorization": f"Bearer {key}"}, - files=files, + headers = {"Authorization": f"Bearer {key}"} + + async with aiohttp.ClientSession(trust_env=True) as session: + form_data = aiohttp.FormData() + form_data.add_field( + "file", + open(file_path, "rb"), + filename=filename, + content_type="application/octet-stream", ) - r.raise_for_status() - data = r.json() + async with session.post( + f"{url}/pipelines/upload", + headers=headers, + data=form_data, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + 
data = await response.json() return {**data} except Exception as e: @@ -244,10 +252,10 @@ async def upload_pipeline( detail = None status_code = status.HTTP_404_NOT_FOUND - if r is not None: - status_code = r.status_code + if response is not None: + status_code = response.status try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: @@ -272,21 +280,22 @@ class AddPipelineForm(BaseModel): async def add_pipeline( request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user) ): - r = None + response = None try: urlIdx = form_data.urlIdx url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - r = requests.post( - f"{url}/pipelines/add", - headers={"Authorization": f"Bearer {key}"}, - json={"url": form_data.url}, - ) - - r.raise_for_status() - data = r.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + f"{url}/pipelines/add", + headers={"Authorization": f"Bearer {key}"}, + json={"url": form_data.url}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + data = await response.json() return {**data} except Exception as e: @@ -294,16 +303,18 @@ async def add_pipeline( log.exception(f"Connection error: {e}") detail = None - if r is not None: + if response is not None: try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: pass raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + status_code=( + response.status if response is not None else status.HTTP_404_NOT_FOUND + ), detail=detail if detail else "Pipeline not found", ) @@ -317,21 +328,22 @@ class DeletePipelineForm(BaseModel): async def delete_pipeline( request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user) ): - r = None + response = None try: urlIdx = form_data.urlIdx url = 
request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - r = requests.delete( - f"{url}/pipelines/delete", - headers={"Authorization": f"Bearer {key}"}, - json={"id": form_data.id}, - ) - - r.raise_for_status() - data = r.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.delete( + f"{url}/pipelines/delete", + headers={"Authorization": f"Bearer {key}"}, + json={"id": form_data.id}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + data = await response.json() return {**data} except Exception as e: @@ -339,16 +351,18 @@ async def delete_pipeline( log.exception(f"Connection error: {e}") detail = None - if r is not None: + if response is not None: try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: pass raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + status_code=( + response.status if response is not None else status.HTTP_404_NOT_FOUND + ), detail=detail if detail else "Pipeline not found", ) @@ -357,15 +371,19 @@ async def delete_pipeline( async def get_pipelines( request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user) ): - r = None + response = None try: url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"}) - - r.raise_for_status() - data = r.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + f"{url}/pipelines", + headers={"Authorization": f"Bearer {key}"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + data = await response.json() return {**data} except Exception as e: @@ -373,16 +391,18 @@ async def get_pipelines( log.exception(f"Connection error: {e}") detail = None - if r is not 
None: + if response is not None: try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: pass raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + status_code=( + response.status if response is not None else status.HTTP_404_NOT_FOUND + ), detail=detail if detail else "Pipeline not found", ) @@ -394,17 +414,19 @@ async def get_pipeline_valves( pipeline_id: str, user=Depends(get_admin_user), ): - r = None + response = None try: url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - r = requests.get( - f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"} - ) - - r.raise_for_status() - data = r.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + f"{url}/{pipeline_id}/valves", + headers={"Authorization": f"Bearer {key}"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + data = await response.json() return {**data} except Exception as e: @@ -412,16 +434,18 @@ async def get_pipeline_valves( log.exception(f"Connection error: {e}") detail = None - if r is not None: + if response is not None: try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: pass raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + status_code=( + response.status if response is not None else status.HTTP_404_NOT_FOUND + ), detail=detail if detail else "Pipeline not found", ) @@ -433,18 +457,19 @@ async def get_pipeline_valves_spec( pipeline_id: str, user=Depends(get_admin_user), ): - r = None + response = None try: url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - r = requests.get( - f"{url}/{pipeline_id}/valves/spec", - headers={"Authorization": f"Bearer {key}"}, - ) - 
- r.raise_for_status() - data = r.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + f"{url}/{pipeline_id}/valves/spec", + headers={"Authorization": f"Bearer {key}"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + data = await response.json() return {**data} except Exception as e: @@ -452,16 +477,18 @@ async def get_pipeline_valves_spec( log.exception(f"Connection error: {e}") detail = None - if r is not None: + if response is not None: try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: pass raise HTTPException( - status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), + status_code=( + response.status if response is not None else status.HTTP_404_NOT_FOUND + ), detail=detail if detail else "Pipeline not found", ) @@ -474,19 +501,20 @@ async def update_pipeline_valves( form_data: dict, user=Depends(get_admin_user), ): - r = None + response = None try: url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx] key = request.app.state.config.OPENAI_API_KEYS[urlIdx] - r = requests.post( - f"{url}/{pipeline_id}/valves/update", - headers={"Authorization": f"Bearer {key}"}, - json={**form_data}, - ) - - r.raise_for_status() - data = r.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + f"{url}/{pipeline_id}/valves/update", + headers={"Authorization": f"Bearer {key}"}, + json={**form_data}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + response.raise_for_status() + data = await response.json() return {**data} except Exception as e: @@ -495,15 +523,17 @@ async def update_pipeline_valves( detail = None - if r is not None: + if response is not None: try: - res = r.json() + res = await response.json() if "detail" in res: detail = res["detail"] except Exception: pass raise HTTPException( - status_code=(r.status_code if r is not None else 
status.HTTP_404_NOT_FOUND), + status_code=( + response.status if response is not None else status.HTTP_404_NOT_FOUND + ), detail=detail if detail else "Pipeline not found", ) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 19d25685ad..e8d4660f03 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -5,18 +5,41 @@ PromptForm, PromptUserResponse, PromptAccessResponse, + PromptAccessListResponse, PromptModel, Prompts, ) +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups +from open_webui.models.prompt_history import ( + PromptHistories, + PromptHistoryModel, + PromptHistoryResponse, +) from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.access_control import has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.internal.db import get_session from sqlalchemy.orm import Session +from pydantic import BaseModel + + +class PromptVersionUpdateForm(BaseModel): + version_id: str + + +class PromptMetadataForm(BaseModel): + name: str + command: str + tags: Optional[list[str]] = None + router = APIRouter() +PAGE_ITEM_COUNT = 30 + + ############################ # GetPrompts ############################ @@ -34,26 +57,80 @@ async def get_prompts( return prompts -@router.get("/list", response_model=list[PromptAccessResponse]) -async def get_prompt_list( +@router.get("/tags", response_model=list[str]) +async def get_prompt_tags( user=Depends(get_verified_user), db: Session = Depends(get_session) ): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - prompts = Prompts.get_prompts(db=db) + return Prompts.get_tags(db=db) else: prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) + tags = set() + for prompt in prompts: + if prompt.tags: 
+ tags.update(prompt.tags) + return sorted(list(tags)) - return [ - PromptAccessResponse( - **prompt.model_dump(), - write_access=( - (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) - or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) - ), - ) - for prompt in prompts - ] + +@router.get("/list", response_model=PromptAccessListResponse) +async def get_prompt_list( + query: Optional[str] = None, + view_option: Optional[str] = None, + tag: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + limit = PAGE_ITEM_COUNT + + page = max(1, page) + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if view_option: + filter["view_option"] = view_option + if tag: + filter["tag"] = tag + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + groups = Groups.get_groups_by_member_id(user.id, db=db) + if groups: + filter["group_ids"] = [group.id for group in groups] + + filter["user_id"] = user.id + + result = Prompts.search_prompts( + user.id, filter=filter, skip=skip, limit=limit, db=db + ) + + return PromptAccessListResponse( + items=[ + PromptAccessResponse( + **prompt.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == prompt.user_id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) + ), + ) + for prompt in result.items + ], + total=result.total, + ) ############################ @@ -112,52 +189,172 @@ async def create_new_prompt( async def get_prompt_by_command( command: str, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) + prompt 
= Prompts.get_prompt_by_command(command, db=db) if prompt: if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) ): return PromptAccessResponse( **prompt.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == prompt.user_id - or has_access(user.id, "write", prompt.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) ), ) - else: + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# GetPromptById +############################ + + +@router.get("/id/{prompt_id}", response_model=Optional[PromptAccessResponse]) +async def get_prompt_by_id( + prompt_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if prompt: + if ( + user.role == "admin" + or prompt.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) + ): + return PromptAccessResponse( + **prompt.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == prompt.user_id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) + ), + ) + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdatePromptById +############################ + + +@router.post("/id/{prompt_id}/update", response_model=Optional[PromptModel]) +async def update_prompt_by_id( + prompt_id: str, + form_data: PromptForm, + 
user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if not prompt: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) + # Is the user the original creator, in a group with write access, or an admin + if ( + prompt.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + # Check for command collision if command is being changed + if form_data.command != prompt.command: + existing_prompt = Prompts.get_prompt_by_command(form_data.command, db=db) + if existing_prompt and existing_prompt.id != prompt.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Command '/{form_data.command}' is already in use by another prompt", + ) + + # Use the ID from the found prompt + updated_prompt = Prompts.update_prompt_by_id(prompt.id, form_data, user.id, db=db) + if updated_prompt: + return updated_prompt + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(), + ) + ############################ -# UpdatePromptByCommand +# UpdatePromptMetadata ############################ -@router.post("/command/{command}/update", response_model=Optional[PromptModel]) -async def update_prompt_by_command( - command: str, - form_data: PromptForm, +@router.post("/id/{prompt_id}/update/meta", response_model=Optional[PromptModel]) +async def update_prompt_metadata( + prompt_id: str, + form_data: PromptMetadataForm, user=Depends(get_verified_user), db: Session = Depends(get_session), ): - prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) + """Update prompt name and command only (no history created).""" + 
prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + if not prompt: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) - # Is the user the original creator, in a group with write access, or an admin if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -165,35 +362,139 @@ async def update_prompt_by_command( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db) - if prompt: - return prompt + # Check for command collision if command is being changed + if form_data.command != prompt.command: + existing_prompt = Prompts.get_prompt_by_command(form_data.command, db=db) + if existing_prompt and existing_prompt.id != prompt.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Command '/{form_data.command}' is already in use", + ) + + updated_prompt = Prompts.update_prompt_metadata( + prompt.id, form_data.name, form_data.command, form_data.tags, db=db + ) + if updated_prompt: + return updated_prompt else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +@router.post("/id/{prompt_id}/update/version", response_model=Optional[PromptModel]) +async def set_prompt_version( + prompt_id: str, + form_data: PromptVersionUpdateForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + prompt.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + 
resource_id=prompt.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + updated_prompt = Prompts.update_prompt_version( + prompt.id, form_data.version_id, db=db + ) + if updated_prompt: + return updated_prompt + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(), + ) + ############################ -# DeletePromptByCommand +# UpdatePromptAccessById ############################ -@router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command( - command: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +class PromptAccessGrantsForm(BaseModel): + access_grants: list[dict] + + +@router.post("/id/{prompt_id}/access/update", response_model=Optional[PromptModel]) +async def update_prompt_access_by_id( + prompt_id: str, + form_data: PromptAccessGrantsForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + prompt.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + AccessGrants.set_access_grants("prompt", prompt_id, form_data.access_grants, db=db) + + return Prompts.get_prompt_by_id(prompt_id, db=db) + + +############################ +# DeletePromptById +############################ + + +@router.delete("/id/{prompt_id}/delete", response_model=bool) +async def delete_prompt_by_id( + prompt_id: str, 
user=Depends(get_verified_user), db: Session = Depends(get_session) +): + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -201,5 +502,188 @@ async def delete_prompt_by_command( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Prompts.delete_prompt_by_command(f"/{command}", db=db) + result = Prompts.delete_prompt_by_id(prompt.id, db=db) return result + + +############################ +# Prompt History Endpoints +############################ + + +@router.get("/id/{prompt_id}/history", response_model=list[PromptHistoryResponse]) +async def get_prompt_history( + prompt_id: str, + page: int = 0, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + """Get version history for a prompt.""" + PAGE_SIZE = 20 + + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + # Check read access + if not ( + user.role == "admin" + or prompt.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + history = PromptHistories.get_history_by_prompt_id( + prompt.id, limit=PAGE_SIZE, offset=page * PAGE_SIZE, db=db + ) + return history + + +@router.get("/id/{prompt_id}/history/{history_id}", response_model=PromptHistoryModel) +async def get_prompt_history_entry( + prompt_id: str, + history_id: str, + 
user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + """Get a specific version from history.""" + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + # Check read access + if not ( + user.role == "admin" + or prompt.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + history_entry = PromptHistories.get_history_entry_by_id(history_id, db=db) + if not history_entry or history_entry.prompt_id != prompt.id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + return history_entry + + +@router.delete("/id/{prompt_id}/history/{history_id}", response_model=bool) +async def delete_prompt_history_entry( + prompt_id: str, + history_id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + """Delete a history entry. 
Cannot delete the active production version.""" + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + # Check write access + if not ( + user.role == "admin" + or prompt.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="write", + db=db, + ) + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + # Cannot delete active production version + if prompt.version_id == history_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot delete the active production version", + ) + + success = PromptHistories.delete_history_entry(history_id, db=db) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + return success + + +@router.get("/id/{prompt_id}/history/diff") +async def get_prompt_diff( + prompt_id: str, + from_id: str, + to_id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + """Get diff between two versions.""" + prompt = Prompts.get_prompt_by_id(prompt_id, db=db) + + if not prompt: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + # Check read access + if not ( + user.role == "admin" + or prompt.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="prompt", + resource_id=prompt.id, + permission="read", + db=db, + ) + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + diff = PromptHistories.compute_diff(from_id, to_id, db=db) + if not diff: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="One or both history entries not found", + ) + + return diff diff --git 
a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index db3f80f149..0d513c9ae1 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -39,7 +39,7 @@ from open_webui.models.files import FileModel, FileUpdateForm, Files from open_webui.models.knowledge import Knowledges from open_webui.storage.provider import Storage -from open_webui.internal.db import get_session +from open_webui.internal.db import get_session, get_db from sqlalchemy.orm import Session @@ -76,6 +76,7 @@ from open_webui.retrieval.web.sougou import search_sougou from open_webui.retrieval.web.firecrawl import search_firecrawl from open_webui.retrieval.web.external import search_external +from open_webui.retrieval.web.yandex import search_yandex from open_webui.retrieval.utils import ( get_content_from_url, @@ -109,6 +110,7 @@ from open_webui.env import ( DEVICE_TYPE, DOCKER, + RAG_EMBEDDING_TIMEOUT, SENTENCE_TRANSFORMERS_BACKEND, SENTENCE_TRANSFORMERS_MODEL_KWARGS, SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, @@ -468,6 +470,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): # Content extraction settings "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, + "PDF_LOADER_MODE": request.app.state.config.PDF_LOADER_MODE, "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL, "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, @@ -577,6 +580,9 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, + 
"YANDEX_WEB_SEARCH_URL": request.app.state.config.YANDEX_WEB_SEARCH_URL, + "YANDEX_WEB_SEARCH_API_KEY": request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, + "YANDEX_WEB_SEARCH_CONFIG": request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, }, } @@ -640,6 +646,9 @@ class WebConfig(BaseModel): YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None YOUTUBE_LOADER_PROXY_URL: Optional[str] = None YOUTUBE_LOADER_TRANSLATION: Optional[str] = None + YANDEX_WEB_SEARCH_URL: Optional[str] = None + YANDEX_WEB_SEARCH_API_KEY: Optional[str] = None + YANDEX_WEB_SEARCH_CONFIG: Optional[str] = None class ConfigForm(BaseModel): @@ -659,6 +668,7 @@ class ConfigForm(BaseModel): # Content extraction settings CONTENT_EXTRACTION_ENGINE: Optional[str] = None PDF_EXTRACT_IMAGES: Optional[bool] = None + PDF_LOADER_MODE: Optional[str] = None DATALAB_MARKER_API_KEY: Optional[str] = None DATALAB_MARKER_API_BASE_URL: Optional[str] = None @@ -786,6 +796,11 @@ async def update_rag_config( if form_data.PDF_EXTRACT_IMAGES is not None else request.app.state.config.PDF_EXTRACT_IMAGES ) + request.app.state.config.PDF_LOADER_MODE = ( + form_data.PDF_LOADER_MODE + if form_data.PDF_LOADER_MODE is not None + else request.app.state.config.PDF_LOADER_MODE + ) request.app.state.config.DATALAB_MARKER_API_KEY = ( form_data.DATALAB_MARKER_API_KEY if form_data.DATALAB_MARKER_API_KEY is not None @@ -1006,6 +1021,11 @@ async def update_rag_config( if form_data.TEXT_SPLITTER is not None else request.app.state.config.TEXT_SPLITTER ) + request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = ( + form_data.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER + if form_data.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER is not None + else request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER + ) request.app.state.config.CHUNK_SIZE = ( form_data.CHUNK_SIZE if form_data.CHUNK_SIZE is not None @@ -1023,13 +1043,25 @@ async def update_rag_config( ) # File upload settings - request.app.state.config.FILE_MAX_SIZE = form_data.FILE_MAX_SIZE - 
request.app.state.config.FILE_MAX_COUNT = form_data.FILE_MAX_COUNT + request.app.state.config.FILE_MAX_SIZE = ( + form_data.FILE_MAX_SIZE + if form_data.FILE_MAX_SIZE is not None + else request.app.state.config.FILE_MAX_SIZE + ) + request.app.state.config.FILE_MAX_COUNT = ( + form_data.FILE_MAX_COUNT + if form_data.FILE_MAX_COUNT is not None + else request.app.state.config.FILE_MAX_COUNT + ) request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH = ( form_data.FILE_IMAGE_COMPRESSION_WIDTH + if form_data.FILE_IMAGE_COMPRESSION_WIDTH is not None + else request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH ) request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = ( form_data.FILE_IMAGE_COMPRESSION_HEIGHT + if form_data.FILE_IMAGE_COMPRESSION_HEIGHT is not None + else request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT ) request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( form_data.ALLOWED_FILE_EXTENSIONS @@ -1164,6 +1196,15 @@ async def update_rag_config( request.app.state.YOUTUBE_LOADER_TRANSLATION = ( form_data.web.YOUTUBE_LOADER_TRANSLATION ) + request.app.state.config.YANDEX_WEB_SEARCH_URL = ( + form_data.web.YANDEX_WEB_SEARCH_URL + ) + request.app.state.config.YANDEX_WEB_SEARCH_API_KEY = ( + form_data.web.YANDEX_WEB_SEARCH_API_KEY + ) + request.app.state.config.YANDEX_WEB_SEARCH_CONFIG = ( + form_data.web.YANDEX_WEB_SEARCH_CONFIG + ) return { "status": True, @@ -1180,6 +1221,7 @@ async def update_rag_config( # Content extraction settings "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, + "PDF_LOADER_MODE": request.app.state.config.PDF_LOADER_MODE, "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL, "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, @@ -1287,6 +1329,9 @@ async def update_rag_config( 
"YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE, "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL, "YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION, + "YANDEX_WEB_SEARCH_URL": request.app.state.config.YANDEX_WEB_SEARCH_URL, + "YANDEX_WEB_SEARCH_API_KEY": request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, + "YANDEX_WEB_SEARCH_CONFIG": request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, }, } @@ -1417,8 +1462,16 @@ def _get_docs_info(docs: list[Document]) -> str: if result is not None and result.ids and len(result.ids) > 0: existing_doc_ids = result.ids[0] if existing_doc_ids: - log.info(f"Document with hash {metadata['hash']} already exists") - raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) + # Check if the existing document belongs to the same file + # If same file_id, this is a re-add/reindex - allow it + # If different file_id, this is a duplicate - block it + existing_file_id = None + if result.metadatas and result.metadatas[0]: + existing_file_id = result.metadatas[0][0].get("file_id") + + if existing_file_id != metadata.get("file_id"): + log.info(f"Document with hash {metadata['hash']} already exists") + raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: if request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER: @@ -1538,14 +1591,19 @@ def _get_docs_info(docs: list[Document]) -> str: enable_async=request.app.state.config.ENABLE_ASYNC_EMBEDDING, ) - # Run async embedding in sync context - embeddings = asyncio.run( + # Run async embedding in sync context using the main event loop + # This allows the main loop to stay responsive to health checks during long operations + embedding_timeout = RAG_EMBEDDING_TIMEOUT + + future = asyncio.run_coroutine_threadsafe( embedding_function( list(map(lambda x: x.replace("\n", " "), texts)), prefix=RAG_EMBEDDING_CONTENT_PREFIX, user=user, - ) + ), + request.app.state.main_loop, ) + embeddings = 
future.result(timeout=embedding_timeout) log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items") items = [ @@ -1586,6 +1644,9 @@ def process_file( ): """ Process a file and save its content to the vector database. + Process a file and save its content to the vector database. + Note: granular session management is used to prevent connection pool exhaustion. + The session is committed before external API calls, and updates use a fresh session. """ if user.role == "admin": file = Files.get_file_by_id(form_data.file_id, db=db) @@ -1747,6 +1808,12 @@ def process_file( } else: try: + # Commit any pending changes before the slow embedding step. + # Note: file is already a Pydantic model (not ORM), so no expunge needed. + db.commit() + + # External embedding API takes time (5-60s+). + # Subsequent updates use fresh sessions via get_db(). result = save_docs_to_vector_db( request, docs=docs, @@ -1762,27 +1829,29 @@ def process_file( log.info(f"added {len(docs)} items to collection {collection_name}") if result: - Files.update_file_metadata_by_id( - file.id, - { + # Fresh session for the final update. 
+ with get_db() as session: + Files.update_file_metadata_by_id( + file.id, + { + "collection_name": collection_name, + }, + db=session, + ) + + Files.update_file_data_by_id( + file.id, + {"status": "completed"}, + db=session, + ) + Files.update_file_hash_by_id(file.id, hash, db=session) + + return { + "status": True, "collection_name": collection_name, - }, - db=db, - ) - - Files.update_file_data_by_id( - file.id, - {"status": "completed"}, - db=db, - ) - Files.update_file_hash_by_id(file.id, hash, db=db) - - return { - "status": True, - "collection_name": collection_name, - "filename": file.filename, - "content": text_content, - } + "filename": file.filename, + "content": text_content, + } else: raise Exception("Error saving document to vector database") except Exception as e: @@ -1790,13 +1859,15 @@ def process_file( except Exception as e: log.exception(e) - Files.update_file_data_by_id( - file.id, - {"status": "failed"}, - db=db, - ) - # Clear the hash so the file can be re-uploaded after fixing the issue - Files.update_file_hash_by_id(file.id, None, db=db) + # Fresh session for error status update. 
+ with get_db() as session: + Files.update_file_data_by_id( + file.id, + {"status": "failed"}, + db=session, + ) + # Clear the hash so the file can be re-uploaded after fixing the issue + Files.update_file_hash_by_id(file.id, None, db=session) if "No pandoc was found" in str(e): raise HTTPException( @@ -2206,6 +2277,17 @@ def search_web( request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, user=user, ) + elif engine == "yandex": + return search_yandex( + request, + request.app.state.config.YANDEX_WEB_SEARCH_URL, + request.app.state.config.YANDEX_WEB_SEARCH_API_KEY, + request.app.state.config.YANDEX_WEB_SEARCH_CONFIG, + query, + request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + user=user, + ) else: raise Exception("No search engine API key found in environment variables") @@ -2626,10 +2708,14 @@ async def process_files_batch( request: Request, form_data: BatchProcessFilesForm, user=Depends(get_verified_user), - db: Session = Depends(get_session), ) -> BatchProcessFilesResponse: """ Process a batch of files and save them to the vector database. + + NOTE: We intentionally do NOT use Depends(get_session) here. + The save_docs_to_vector_db() call makes external embedding API calls which + can take 5-60+ seconds for batch operations. Database operations after + embedding (Files.update_file_by_id) manage their own short-lived sessions. 
""" collection_name = form_data.collection_name @@ -2689,9 +2775,7 @@ async def process_files_batch( # Update all files with collection name for file_update, file_result in zip(file_updates, file_results): - Files.update_file_by_id( - id=file_result.file_id, form_data=file_update, db=db - ) + Files.update_file_by_id(id=file_result.file_id, form_data=file_update) file_result.status = "completed" except Exception as e: @@ -2701,7 +2785,9 @@ async def process_files_batch( for file_result in file_results: file_result.status = "failed" file_errors.append( - BatchProcessFilesResult(file_id=file_result.file_id, error=str(e)) + BatchProcessFilesResult( + file_id=file_result.file_id, status="failed", error=str(e) + ) ) return BatchProcessFilesResponse(results=file_results, errors=file_errors) diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index 9070256770..681be3c7d2 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -352,18 +352,17 @@ def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup: """Convert internal Group model to SCIM Group""" member_ids = Groups.get_group_user_ids_by_id(group.id, db) or [] - members = [] - - for user_id in member_ids: - user = Users.get_user_by_id(user_id, db=db) - if user: - members.append( - SCIMGroupMember( - value=user.id, - ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}", - display=user.name, - ) - ) + + # Batch-fetch all users to avoid N+1 queries + users = Users.get_users_by_user_ids(member_ids, db=db) if member_ids else [] + members = [ + SCIMGroupMember( + value=user.id, + ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}", + display=user.name, + ) + for user in users + ] return SCIMGroup( id=group.id, diff --git a/backend/open_webui/routers/skills.py b/backend/open_webui/routers/skills.py new file mode 100644 index 0000000000..367768e61b --- 
/dev/null +++ b/backend/open_webui/routers/skills.py @@ -0,0 +1,427 @@ +import logging +from typing import Optional + +from open_webui.models.groups import Groups +from pydantic import BaseModel + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.orm import Session + +from open_webui.internal.db import get_session +from open_webui.models.skills import ( + SkillForm, + SkillModel, + SkillResponse, + SkillUserResponse, + SkillAccessResponse, + SkillAccessListResponse, + Skills, +) +from open_webui.models.access_grants import AccessGrants +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access, has_permission + +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.constants import ERROR_MESSAGES + +log = logging.getLogger(__name__) + +PAGE_ITEM_COUNT = 30 + +router = APIRouter() + + +############################ +# GetSkills +############################ + + +@router.get("/", response_model=list[SkillUserResponse]) +async def get_skills( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + skills = Skills.get_skills(db=db) + else: + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } + all_skills = Skills.get_skills(db=db) + skills = [ + skill + for skill in all_skills + if skill.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="read", + user_group_ids=user_group_ids, + db=db, + ) + ] + + return skills + + +############################ +# GetSkillList +############################ + + +@router.get("/list", response_model=SkillAccessListResponse) +async def get_skill_list( + query: Optional[str] = None, + view_option: Optional[str] = None, + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = 
Depends(get_session), +): + limit = PAGE_ITEM_COUNT + + page = max(1, page) + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if view_option: + filter["view_option"] = view_option + + if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + groups = Groups.get_groups_by_member_id(user.id, db=db) + if groups: + filter["group_ids"] = [group.id for group in groups] + + filter["user_id"] = user.id + + result = Skills.search_skills(user.id, filter=filter, skip=skip, limit=limit, db=db) + + return SkillAccessListResponse( + items=[ + SkillAccessResponse( + **skill.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == skill.user_id + or AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + ), + ) + for skill in result.items + ], + total=result.total, + ) + + +############################ +# ExportSkills +############################ + + +@router.get("/export", response_model=list[SkillModel]) +async def export_skills( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if user.role != "admin" and not has_permission( + user.id, + "workspace.skills", + request.app.state.config.USER_PERMISSIONS, + db=db, + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + return Skills.get_skills(db=db) + else: + return Skills.get_skills_by_user_id(user.id, "read", db=db) + + +############################ +# CreateNewSkill +############################ + + +@router.post("/create", response_model=Optional[SkillResponse]) +async def create_new_skill( + request: Request, + form_data: SkillForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if user.role != "admin" and not has_permission( + user.id, "workspace.skills", 
request.app.state.config.USER_PERMISSIONS, db=db + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + form_data.id = form_data.id.lower().replace(" ", "-") + + existing = Skills.get_skill_by_id(form_data.id, db=db) + if existing is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ID_TAKEN, + ) + + try: + skill = Skills.insert_new_skill(user.id, form_data, db=db) + if skill: + return skill + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating skill"), + ) + except Exception as e: + log.exception(f"Failed to create skill: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(str(e)), + ) + + +############################ +# GetSkillById +############################ + + +@router.get("/id/{id}", response_model=Optional[SkillAccessResponse]) +async def get_skill_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + skill = Skills.get_skill_by_id(id, db=db) + + if skill: + if ( + user.role == "admin" + or skill.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="read", + db=db, + ) + ): + return SkillAccessResponse( + **skill.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == skill.user_id + or AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + ), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateSkillById +############################ + + 
+@router.post("/id/{id}/update", response_model=Optional[SkillModel]) +async def update_skill_by_id( + request: Request, + id: str, + form_data: SkillForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + skill = Skills.get_skill_by_id(id, db=db) + if not skill: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + skill.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + try: + updated = { + **form_data.model_dump(exclude={"id"}), + } + + skill = Skills.update_skill_by_id(id, updated, db=db) + + if skill: + return skill + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating skill"), + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(str(e)), + ) + + +############################ +# UpdateSkillAccessById +############################ + + +class SkillAccessGrantsForm(BaseModel): + access_grants: list[dict] + + +@router.post("/id/{id}/access/update", response_model=Optional[SkillModel]) +async def update_skill_access_by_id( + id: str, + form_data: SkillAccessGrantsForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + skill = Skills.get_skill_by_id(id, db=db) + if not skill: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + skill.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + 
detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + AccessGrants.set_access_grants("skill", id, form_data.access_grants, db=db) + + return Skills.get_skill_by_id(id, db=db) + + +############################ +# ToggleSkillById +############################ + + +@router.post("/id/{id}/toggle", response_model=Optional[SkillModel]) +async def toggle_skill_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + skill = Skills.get_skill_by_id(id, db=db) + if skill: + if ( + user.role == "admin" + or skill.user_id == user.id + or AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + ): + skill = Skills.toggle_skill_by_id(id, db=db) + + if skill: + return skill + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error toggling skill"), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# DeleteSkillById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_skill_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + skill = Skills.get_skill_by_id(id, db=db) + if not skill: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + skill.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="skill", + resource_id=skill.id, + permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + result = Skills.delete_skill_by_id(id, db=db) + return result diff --git 
a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 4d4059da19..d322fca0b6 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -38,7 +38,6 @@ CREDIT_NO_CREDIT_MSG, ) - log = logging.getLogger(__name__) router = APIRouter() @@ -51,6 +50,21 @@ ################################## +class ActiveChatsForm(BaseModel): + chat_ids: list[str] + + +@router.post("/active/chats") +async def check_active_chats( + request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user) +): + """Check which chat IDs have active tasks.""" + from open_webui.tasks import get_active_chat_ids + + active = await get_active_chat_ids(request.app.state.redis, form_data.chat_ids) + return {"active_chat_ids": active} + + @router.get("/config") async def get_task_config(request: Request, user=Depends(get_verified_user)): return { diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 03018d24a1..057eb509a1 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -21,10 +21,12 @@ ToolAccessResponse, Tools, ) +from open_webui.models.access_grants import AccessGrants from open_webui.utils.plugin import ( load_tool_module_by_id, replace_imports, get_tool_module_from_cache, + resolve_valves_schema_options, ) from open_webui.utils.tools import get_tool_specs from open_webui.utils.auth import get_admin_user, get_verified_user @@ -34,7 +36,6 @@ from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL from open_webui.constants import ERROR_MESSAGES - log = logging.getLogger(__name__) @@ -75,12 +76,21 @@ async def get_tools( ) # OpenAPI Tool Servers + server_access_grants = {} for server in await get_tool_servers(request): + connection = request.app.state.config.TOOL_SERVER_CONNECTIONS[ + server.get("idx", 0) + ] + server_config = connection.get("config", {}) + + server_id = f"server:{server.get('id')}" + server_access_grants[server_id] = 
server_config.get("access_grants", []) + tools.append( ToolUserResponse( **{ - "id": f"server:{server.get('id')}", - "user_id": f"server:{server.get('id')}", + "id": server_id, + "user_id": server_id, "name": server.get("openapi", {}) .get("info", {}) .get("title", "Tool Server"), @@ -89,11 +99,6 @@ async def get_tools( .get("info", {}) .get("description", ""), }, - "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ - server.get("idx", 0) - ] - .get("config", {}) - .get("access_control", None), "updated_at": int(time.time()), "created_at": int(time.time()), } @@ -117,20 +122,22 @@ async def get_tools( ) ) + server_config = server.get("config", {}) + + tool_id = f"server:mcp:{server.get('info', {}).get('id')}" + server_access_grants[tool_id] = server_config.get("access_grants", []) + tools.append( ToolUserResponse( **{ - "id": f"server:mcp:{server.get('info', {}).get('id')}", - "user_id": f"server:mcp:{server.get('info', {}).get('id')}", + "id": tool_id, + "user_id": tool_id, "name": server.get("info", {}).get("name", "MCP Tool Server"), "meta": { "description": server.get("info", {}).get( "description", "" ), }, - "access_control": server.get("config", {}).get( - "access_control", None - ), "updated_at": int(time.time()), "created_at": int(time.time()), **( @@ -155,7 +162,24 @@ async def get_tools( tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control, user_group_ids, db=db) + or ( + has_access( + user.id, + "read", + server_access_grants.get(str(tool.id), []), + user_group_ids, + db=db, + ) + if str(tool.id).startswith("server:") + else AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="read", + user_group_ids=user_group_ids, + db=db, + ) + ) ] return tools @@ -180,7 +204,13 @@ async def get_tool_list( write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tool.user_id - or has_access(user.id, "write", 
tool.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="write", + db=db, + ) ), ) for tool in tools @@ -381,14 +411,26 @@ async def get_tools_by_id( if ( user.role == "admin" or tools.user_id == user.id - or has_access(user.id, "read", tools.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="read", + db=db, + ) ): return ToolAccessResponse( **tools.model_dump(), write_access=( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == tools.user_id - or has_access(user.id, "write", tools.access_control, db=db) + or AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) ), ) else: @@ -426,7 +468,13 @@ async def update_tools_by_id( # Is the user the original creator, in a group with write access, or an admin if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -467,6 +515,50 @@ async def update_tools_by_id( ) +############################ +# UpdateToolAccessById +############################ + + +class ToolAccessGrantsForm(BaseModel): + access_grants: list[dict] + + +@router.post("/id/{id}/access/update", response_model=Optional[ToolModel]) +async def update_tool_access_by_id( + id: str, + form_data: ToolAccessGrantsForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + tools = Tools.get_tool_by_id(id, db=db) + if not tools: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + tools.user_id != user.id + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + 
permission="write", + db=db, + ) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + AccessGrants.set_access_grants("tool", id, form_data.access_grants, db=db) + + return Tools.get_tool_by_id(id, db=db) + + ############################ # DeleteToolsById ############################ @@ -488,7 +580,13 @@ async def delete_tools_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -553,7 +651,10 @@ async def get_tools_valves_spec_by_id( if hasattr(tools_module, "Valves"): Valves = tools_module.Valves - return Valves.schema() + schema = Valves.schema() + # Resolve dynamic options for select dropdowns + schema = resolve_valves_schema_options(Valves, schema, user) + return schema return None else: raise HTTPException( @@ -584,7 +685,13 @@ async def update_tools_valves_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control, db=db) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tools.id, + permission="write", + db=db, + ) and user.role != "admin" ): raise HTTPException( @@ -662,7 +769,10 @@ async def get_tools_user_valves_spec_by_id( if hasattr(tools_module, "UserValves"): UserValves = tools_module.UserValves - return UserValves.schema() + schema = UserValves.schema() + # Resolve dynamic options for select dropdowns + schema = resolve_valves_schema_options(UserValves, schema, user) + return schema return None else: raise HTTPException( diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 828de4a429..1c297d909b 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -26,7 +26,7 @@ 
UserModel, UserGroupIdsModel, UserGroupIdsListResponse, - UserInfoListResponse, + UserInfoResponse, UserInfoListResponse, UserRoleUpdateForm, UserStatus, @@ -49,7 +49,6 @@ ) from open_webui.utils.access_control import get_permissions, has_permission - log = logging.getLogger(__name__) router = APIRouter() @@ -200,6 +199,7 @@ class WorkspacePermissions(BaseModel): knowledge: bool = False prompts: bool = False tools: bool = False + skills: bool = False models_import: bool = False models_export: bool = False prompts_import: bool = False @@ -473,7 +473,7 @@ class UserActiveResponse(UserStatus): @router.get("/{user_id}", response_model=UserActiveResponse) async def get_user_by_id( - user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) ): # Check if user_id is a shared chat # If it is, get the user_id from the chat @@ -505,6 +505,27 @@ async def get_user_by_id( ) +@router.get("/{user_id}/info", response_model=UserInfoResponse) +async def get_user_info_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user_id, db=db) + if user: + groups = Groups.get_groups_by_member_id(user_id, db=db) + return UserInfoResponse( + **{ + **user.model_dump(), + "groups": [{"id": group.id, "name": group.name} for group in groups], + "is_active": Users.is_user_active(user_id, db=db), + } + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + @router.get("/{user_id}/oauth/sessions") async def get_user_oauth_sessions_by_id( user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) @@ -525,10 +546,8 @@ async def get_user_oauth_sessions_by_id( @router.get("/{user_id}/profile/image") -async def get_user_profile_image_by_id( - user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - user = 
Users.get_user_by_id(user_id, db=db) +def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): + user = Users.get_user_by_id(user_id) if user: if user.profile_image_url: # check if it's url or base64 diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index 22529ab1b9..49f3a5ca55 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -15,7 +15,6 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.code_interpreter import execute_code_jupyter - log = logging.getLogger(__name__) router = APIRouter() diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 67e04e69c3..0d762ee5b2 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -42,14 +42,14 @@ from open_webui.socket.utils import RedisDict, RedisLock, YdocManager from open_webui.tasks import create_task, stop_item_tasks from open_webui.utils.redis import get_redis_connection -from open_webui.utils.access_control import has_access, get_users_with_access +from open_webui.utils.access_control import has_permission +from open_webui.models.access_grants import AccessGrants from open_webui.env import ( GLOBAL_LOG_LEVEL, ) - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) @@ -346,11 +346,12 @@ async def user_join(sid, data): await sio.enter_room(sid, f"user:{user.id}") - # Join all the channels - channels = Channels.get_channels_by_user_id(user.id) - log.debug(f"{channels=}") - for channel in channels: - await sio.enter_room(sid, f"channel:{channel.id}") + # Join all the channels only if user has channels permission + if user.role == "admin" or has_permission(user.id, "features.channels"): + channels = Channels.get_channels_by_user_id(user.id) + log.debug(f"{channels=}") + for channel in channels: + await sio.enter_room(sid, f"channel:{channel.id}") return {"id": 
user.id, "name": user.name} @@ -376,11 +377,12 @@ async def join_channel(sid, data): if not user: return - # Join all the channels - channels = Channels.get_channels_by_user_id(user.id) - log.debug(f"{channels=}") - for channel in channels: - await sio.enter_room(sid, f"channel:{channel.id}") + # Join all the channels only if user has channels permission + if user.role == "admin" or has_permission(user.id, "features.channels"): + channels = Channels.get_channels_by_user_id(user.id) + log.debug(f"{channels=}") + for channel in channels: + await sio.enter_room(sid, f"channel:{channel.id}") @sio.on("join-note") @@ -405,7 +407,12 @@ async def join_note(sid, data): if ( user.role != "admin" and user.id != note.user_id - and not has_access(user.id, type="read", access_control=note.access_control) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): log.error(f"User {user.id} does not have access to note {data['note_id']}") return @@ -467,8 +474,11 @@ async def ydoc_document_join(sid, data): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( - user.get("id"), type="read", access_control=note.access_control + and not AccessGrants.has_access( + user_id=user.get("id"), + resource_type="note", + resource_id=note.id, + permission="read", ) ): log.error( @@ -537,8 +547,11 @@ async def document_save_handler(document_id, data, user): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( - user.get("id"), type="read", access_control=note.access_control + and not AccessGrants.has_access( + user_id=user.get("id"), + resource_type="note", + resource_id=note.id, + permission="read", ) ): log.error(f"User {user.get('id')} does not have access to note {note_id}") diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index ce02105bfa..425d10c812 100644 --- a/backend/open_webui/storage/provider.py +++ 
b/backend/open_webui/storage/provider.py @@ -34,7 +34,6 @@ from azure.storage.blob import BlobServiceClient from azure.core.exceptions import ResourceNotFoundError - log = logging.getLogger(__name__) diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py index d83226ffb7..04dfb5a556 100644 --- a/backend/open_webui/tasks.py +++ b/backend/open_webui/tasks.py @@ -10,7 +10,6 @@ from open_webui.env import REDIS_KEY_PREFIX - log = logging.getLogger(__name__) # A dictionary to keep track of active tasks @@ -74,7 +73,13 @@ async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]: async def redis_send_command(redis: Redis, command: dict): - await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command)) + command_json = json.dumps(command) + # RedisCluster doesn't expose publish() directly, but the + # PUBLISH command broadcasts across all cluster nodes server-side. + if hasattr(redis, "nodes_manager"): + await redis.execute_command("PUBLISH", REDIS_PUBSUB_CHANNEL, command_json) + else: + await redis.publish(REDIS_PUBSUB_CHANNEL, command_json) async def cleanup_task(redis, task_id: str, id=None): @@ -183,3 +188,18 @@ async def stop_item_tasks(redis: Redis, item_id: str): return result # Return the first failure return {"status": True, "message": f"All tasks for item {item_id} stopped."} + + +async def has_active_tasks(redis, chat_id: str) -> bool: + """Check if a chat has any active tasks.""" + task_ids = await list_task_ids_by_item_id(redis, chat_id) + return len(task_ids) > 0 + + +async def get_active_chat_ids(redis, chat_ids: List[str]) -> List[str]: + """Filter a list of chat_ids to only those with active tasks.""" + active = [] + for chat_id in chat_ids: + if await has_active_tasks(redis, chat_id): + active.append(chat_id) + return active diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py index eb3b7cfc9f..0175ba5583 100644 --- a/backend/open_webui/tools/builtin.py +++ 
b/backend/open_webui/tools/builtin.py @@ -36,6 +36,7 @@ from open_webui.models.channels import Channels, ChannelMember, Channel from open_webui.models.messages import Messages, Message from open_webui.models.groups import Groups +from open_webui.utils.sanitize import sanitize_code log = logging.getLogger(__name__) @@ -166,7 +167,7 @@ async def search_web( engine = __request__.app.state.config.WEB_SEARCH_ENGINE user = UserModel(**__user__) if __user__ else None - results = _search_web(__request__, engine, query, user) + results = await asyncio.to_thread(_search_web, __request__, engine, query, user) # Limit results results = results[:count] if results else [] @@ -341,6 +342,178 @@ async def edit_image( return json.dumps({"error": str(e)}) +# ============================================================================= +# CODE INTERPRETER TOOLS +# ============================================================================= + + +async def execute_code( + code: str, + __request__: Request = None, + __user__: dict = None, + __event_emitter__: callable = None, + __event_call__: callable = None, + __chat_id__: str = None, + __message_id__: str = None, + __metadata__: dict = None, +) -> str: + """ + Execute Python code in a sandboxed environment and return the output. + Use this to perform calculations, data analysis, generate visualizations, + or run any Python code that would help answer the user's question. 
+ + :param code: The Python code to execute + :return: JSON with stdout, stderr, and result from execution + """ + from uuid import uuid4 + + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + # Sanitize code (strips ANSI codes and markdown fences) + code = sanitize_code(code) + + # Import blocked modules from config (same as middleware) + from open_webui.config import CODE_INTERPRETER_BLOCKED_MODULES + + # Add import blocking code if there are blocked modules + if CODE_INTERPRETER_BLOCKED_MODULES: + import textwrap + + blocking_code = textwrap.dedent(f""" + import builtins + + BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES} + + _real_import = builtins.__import__ + def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): + if name.split('.')[0] in BLOCKED_MODULES: + importer_name = globals.get('__name__') if globals else None + if importer_name == '__main__': + raise ImportError( + f"Direct import of module {{name}} is restricted." + ) + return _real_import(name, globals, locals, fromlist, level) + + builtins.__import__ = restricted_import + """) + code = blocking_code + "\n" + code + + engine = getattr( + __request__.app.state.config, "CODE_INTERPRETER_ENGINE", "pyodide" + ) + if engine == "pyodide": + # Execute via frontend pyodide using bidirectional event call + if __event_call__ is None: + return json.dumps( + { + "error": "Event call not available. WebSocket connection required for pyodide execution." 
+ } + ) + + output = await __event_call__( + { + "type": "execute:python", + "data": { + "id": str(uuid4()), + "code": code, + "session_id": ( + __metadata__.get("session_id") if __metadata__ else None + ), + }, + } + ) + + # Parse the output - pyodide returns dict with stdout, stderr, result + if isinstance(output, dict): + stdout = output.get("stdout", "") + stderr = output.get("stderr", "") + result = output.get("result", "") + else: + stdout = "" + stderr = "" + result = str(output) if output else "" + + elif engine == "jupyter": + from open_webui.utils.code_interpreter import execute_code_jupyter + + output = await execute_code_jupyter( + __request__.app.state.config.CODE_INTERPRETER_JUPYTER_URL, + code, + ( + __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN + if __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH + == "token" + else None + ), + ( + __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD + if __request__.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH + == "password" + else None + ), + __request__.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, + ) + + stdout = output.get("stdout", "") + stderr = output.get("stderr", "") + result = output.get("result", "") + + else: + return json.dumps({"error": f"Unknown code interpreter engine: {engine}"}) + + # Handle image outputs (base64 encoded) - replace with uploaded URLs + # Get actual user object for image upload (upload_image requires user.id attribute) + if __user__ and __user__.get("id"): + from open_webui.models.users import Users + from open_webui.utils.files import get_image_url_from_base64 + + user = Users.get_user_by_id(__user__["id"]) + + # Extract and upload images from stdout + if stdout and isinstance(stdout, str): + stdout_lines = stdout.split("\n") + for idx, line in enumerate(stdout_lines): + if "data:image/png;base64" in line: + image_url = get_image_url_from_base64( + __request__, + line, + __metadata__ or {}, + user, + ) + if image_url: + 
stdout_lines[idx] = f"![Output Image]({image_url})" + stdout = "\n".join(stdout_lines) + + # Extract and upload images from result + if result and isinstance(result, str): + result_lines = result.split("\n") + for idx, line in enumerate(result_lines): + if "data:image/png;base64" in line: + image_url = get_image_url_from_base64( + __request__, + line, + __metadata__ or {}, + user, + ) + if image_url: + result_lines[idx] = f"![Output Image]({image_url})" + result = "\n".join(result_lines) + + response = { + "status": "success", + "stdout": stdout, + "stderr": stderr, + "result": result, + } + + return json.dumps(response, ensure_ascii=False) + except Exception as e: + log.exception(f"execute_code error: {e}") + return json.dumps({"error": str(e)}) + + # ============================================================================= # MEMORY TOOLS # ============================================================================= @@ -579,10 +752,14 @@ async def view_note( user_id = __user__.get("id") user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants - if note.user_id != user_id and not has_access( - user_id, "read", note.access_control, user_group_ids + if note.user_id != user_id and not AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="read", + user_group_ids=set(user_group_ids), ): return json.dumps({"error": "Access denied"}) @@ -633,7 +810,7 @@ async def write_note( form = NoteForm( title=title, data={"content": {"md": content}}, - access_control={}, # Private by default - only owner can access + access_grants=[], # Private by default - only owner can access ) new_note = Notes.insert_new_note(user_id, form) @@ -688,10 +865,14 @@ async def replace_note_content( user_id = __user__.get("id") user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] - 
from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants - if note.user_id != user_id and not has_access( - user_id, "write", note.access_control, user_group_ids + if note.user_id != user_id and not AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="write", + user_group_ids=set(user_group_ids), ): return json.dumps({"error": "Write access denied"}) @@ -732,6 +913,7 @@ async def search_chats( end_timestamp: Optional[int] = None, __request__: Request = None, __user__: dict = None, + __chat_id__: str = None, ) -> str: """ Search the user's previous chat conversations by title and message content. @@ -761,6 +943,10 @@ async def search_chats( results = [] for chat in chats: + # Skip the current chat to avoid showing it in search results + if __chat_id__ and chat.id == __chat_id__: + continue + # Apply date filters (updated_at is in seconds) if start_timestamp and chat.updated_at < start_timestamp: continue @@ -1363,7 +1549,7 @@ async def view_knowledge_file( try: from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants user_id = __user__.get("id") user_role = __user__.get("role", "user") @@ -1382,8 +1568,12 @@ async def view_knowledge_file( if ( user_role == "admin" or knowledge_base.user_id == user_id - or has_access( - user_id, "read", knowledge_base.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge_base.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): has_knowledge_access = True @@ -1424,8 +1614,7 @@ async def query_knowledge_files( __model_knowledge__: list[dict] = None, ) -> str: """ - Search knowledge base files using semantic/vector search. 
This should be your first - choice for finding information before searching the web. Searches across collections (KBs), + Search knowledge base files using semantic/vector search. Searches across collections (KBs), individual files, and notes that the user has access to. :param query: The search query to find semantically relevant content @@ -1439,12 +1628,31 @@ async def query_knowledge_files( if not __user__: return json.dumps({"error": "User context not available"}) + # Coerce parameters from LLM tool calls (may come as strings) + if isinstance(count, str): + try: + count = int(count) + except ValueError: + count = 5 # Default fallback + + # Handle knowledge_ids being string "None", "null", or empty + if isinstance(knowledge_ids, str): + if knowledge_ids.lower() in ("none", "null", ""): + knowledge_ids = None + else: + # Try to parse as JSON array if it looks like one + try: + knowledge_ids = json.loads(knowledge_ids) + except json.JSONDecodeError: + # Treat as single ID + knowledge_ids = [knowledge_ids] + try: from open_webui.models.knowledge import Knowledges from open_webui.models.files import Files from open_webui.models.notes import Notes from open_webui.retrieval.utils import query_collection - from open_webui.utils.access_control import has_access + from open_webui.models.access_grants import AccessGrants user_id = __user__.get("id") user_role = __user__.get("role", "user") @@ -1469,8 +1677,12 @@ async def query_knowledge_files( if knowledge and ( user_role == "admin" or knowledge.user_id == user_id - or has_access( - user_id, "read", knowledge.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): collection_names.append(item_id) @@ -1487,7 +1699,12 @@ async def query_knowledge_files( if note and ( user_role == "admin" or note.user_id == user_id - or has_access(user_id, "read", note.access_control) + or 
AccessGrants.has_access( + user_id=user_id, + resource_type="note", + resource_id=note.id, + permission="read", + ) ): content = note.data.get("content", {}).get("md", "") note_results.append( @@ -1506,8 +1723,12 @@ async def query_knowledge_files( if knowledge and ( user_role == "admin" or knowledge.user_id == user_id - or has_access( - user_id, "read", knowledge.access_control, user_group_ids + or AccessGrants.has_access( + user_id=user_id, + resource_type="knowledge", + resource_id=knowledge.id, + permission="read", + user_group_ids=set(user_group_ids), ) ): collection_names.append(knowledge_id) @@ -1669,3 +1890,65 @@ async def query_knowledge_bases( except Exception as e: log.exception(f"query_knowledge_bases error: {e}") return json.dumps({"error": str(e)}) + + +# ============================================================================= +# SKILLS TOOLS +# ============================================================================= + + +async def view_skill( + name: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Load the full instructions of a skill by its name from the available skills manifest. + Use this when you need detailed instructions for a skill listed in . 
+ + :param name: The name of the skill to load (as shown in the manifest) + :return: The full skill instructions as markdown content + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.skills import Skills + from open_webui.models.access_grants import AccessGrants + + user_id = __user__.get("id") + + # Direct DB lookup by unique name + skill = Skills.get_skill_by_name(name) + + if not skill or not skill.is_active: + return json.dumps({"error": f"Skill '{name}' not found"}) + + # Check user access + user_role = __user__.get("role", "user") + if user_role != "admin" and skill.user_id != user_id: + user_group_ids = [ + group.id for group in Groups.get_groups_by_member_id(user_id) + ] + if not AccessGrants.has_access( + user_id=user_id, + resource_type="skill", + resource_id=skill.id, + permission="read", + user_group_ids=set(user_group_ids), + ): + return json.dumps({"error": "Access denied"}) + + return json.dumps( + { + "name": skill.name, + "content": skill.content, + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"view_skill error: {e}") + return json.dumps({"error": str(e)}) diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 7784f6efd7..b7ea9830db 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -107,71 +107,90 @@ def get_permission(permissions: Dict[str, Any], keys: List[str]) -> bool: return get_permission(default_permissions, permission_hierarchy) -def get_permitted_group_and_user_ids( - type: str = "write", access_control: Optional[dict] = None -) -> Union[Dict[str, List[str]], None]: - if access_control is None: - return None - - permission_access = access_control.get(type, {}) - permitted_group_ids = permission_access.get("group_ids", []) - permitted_user_ids = 
permission_access.get("user_ids", []) - - return { - "group_ids": permitted_group_ids, - "user_ids": permitted_user_ids, - } - - def has_access( user_id: str, - type: str = "write", - access_control: Optional[dict] = None, + permission: str = "read", + access_grants: Optional[list] = None, user_group_ids: Optional[Set[str]] = None, - strict: bool = True, db: Optional[Any] = None, ) -> bool: - if access_control is None: - if strict: - return type == "read" - else: - return True + """ + Check if a user has the specified permission using an in-memory access_grants list. - if user_group_ids is None: - user_groups = Groups.get_groups_by_member_id(user_id, db=db) - user_group_ids = {group.id for group in user_groups} + Used for config-driven resources (arena models, tool servers) that store + access control as JSON in PersistentConfig rather than in the access_grant DB table. - permitted_ids = get_permitted_group_and_user_ids(type, access_control) - if permitted_ids is None: + Semantics: + - None or [] → private (owner-only, deny all) + - [{"principal_type": "user", "principal_id": "*", "permission": "read"}] → public read + - Specific grants → check user/group membership + """ + if not access_grants: return False - permitted_group_ids = permitted_ids.get("group_ids", []) - permitted_user_ids = permitted_ids.get("user_ids", []) - - return user_id in permitted_user_ids or any( - group_id in permitted_group_ids for group_id in user_group_ids - ) - - -# Get all users with access to a resource -def get_users_with_access( - type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None -) -> list[UserModel]: - if access_control is None: - result = Users.get_users(filter={"roles": ["!pending"]}, db=db) - return result.get("users", []) + if user_group_ids is None: + user_groups = Groups.get_groups_by_member_id(user_id, db=db) + user_group_ids = {group.id for group in user_groups} - permitted_ids = get_permitted_group_and_user_ids(type, access_control) - if 
permitted_ids is None: - return [] + for grant in access_grants: + if not isinstance(grant, dict): + continue + if grant.get("permission") != permission: + continue + principal_type = grant.get("principal_type") + principal_id = grant.get("principal_id") + if principal_type == "user" and ( + principal_id == "*" or principal_id == user_id + ): + return True + if ( + principal_type == "group" + and user_group_ids + and principal_id in user_group_ids + ): + return True - permitted_group_ids = permitted_ids.get("group_ids", []) - permitted_user_ids = permitted_ids.get("user_ids", []) + return False - user_ids_with_access = set(permitted_user_ids) - group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db) - for user_ids in group_user_ids_map.values(): - user_ids_with_access.update(user_ids) +def migrate_access_control( + data: dict, ac_key: str = "access_control", grants_key: str = "access_grants" +) -> None: + """ + Auto-migrate a config dict in-place from legacy access_control dict to access_grants list. - return Users.get_users_by_user_ids(list(user_ids_with_access), db=db) + If `grants_key` already exists, does nothing. + If `ac_key` exists (old format), converts it and stores as `grants_key`, then removes `ac_key`. 
+ """ + if grants_key in data: + return + + access_control = data.get(ac_key) + if access_control is None and ac_key not in data: + return + + grants: List[Dict[str, str]] = [] + if access_control and isinstance(access_control, dict): + for perm in ["read", "write"]: + perm_data = access_control.get(perm, {}) + if not perm_data: + continue + for group_id in perm_data.get("group_ids", []): + grants.append( + { + "principal_type": "group", + "principal_id": group_id, + "permission": perm, + } + ) + for uid in perm_data.get("user_ids", []): + grants.append( + { + "principal_type": "user", + "principal_id": uid, + "permission": perm, + } + ) + + data[grants_key] = grants + data.pop(ac_key, None) diff --git a/backend/open_webui/utils/actions.py b/backend/open_webui/utils/actions.py new file mode 100644 index 0000000000..0b4b817f0a --- /dev/null +++ b/backend/open_webui/utils/actions.py @@ -0,0 +1,139 @@ +import logging +import sys +import inspect + +from typing import Any + +from fastapi import Request + +from open_webui.models.users import UserModel +from open_webui.models.functions import Functions + +from open_webui.socket.main import get_event_call, get_event_emitter +from open_webui.utils.plugin import get_function_module_from_cache +from open_webui.utils.models import get_all_models +from open_webui.utils.middleware import process_tool_result + +from open_webui.env import GLOBAL_LOG_LEVEL + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) + + +async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): + if "." 
in action_id: + action_id, sub_action_id = action_id.split(".") + else: + sub_action_id = None + + action = Functions.get_function_by_id(action_id) + if not action: + raise Exception(f"Action not found: {action_id}") + + if not request.app.state.MODELS: + await get_all_models(request, user=user) + + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + + data = form_data + model_id = data["model"] + + if model_id not in models: + raise Exception("Model not found") + model = models[model_id] + + __event_emitter__ = get_event_emitter( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + "user_id": user.id, + } + ) + __event_call__ = get_event_call( + { + "chat_id": data["chat_id"], + "message_id": data["id"], + "session_id": data["session_id"], + "user_id": user.id, + } + ) + + function_module, _, _ = get_function_module_from_cache(request, action_id) + + if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + valves = Functions.get_function_valves_by_id(action_id) + function_module.valves = function_module.Valves(**(valves if valves else {})) + + if hasattr(function_module, "action"): + try: + action = function_module.action + + # Get the signature of the function + sig = inspect.signature(action) + params = {"body": data} + + # Extra parameters to be passed to the function + extra_params = { + "__model__": model, + "__id__": sub_action_id if sub_action_id is not None else action_id, + "__event_emitter__": __event_emitter__, + "__event_call__": __event_call__, + "__request__": request, + } + + # Add extra params in contained in function signature + for key, value in extra_params.items(): + if key in sig.parameters: + params[key] = value + + if "__user__" in sig.parameters: + __user__ = user.model_dump() if isinstance(user, UserModel) else {} + + try: + if 
hasattr(function_module, "UserValves"): + __user__["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) + ) + except Exception as e: + log.exception(f"Failed to get user values: {e}") + + params = {**params, "__user__": __user__} + + if inspect.iscoroutinefunction(action): + data = await action(**params) + else: + data = action(**params) + + # Process action result for Rich UI embeds (HTMLResponse, tuple with headers) + processed_result, _, action_embeds = process_tool_result( + request, + action_id, + data, + "action", + ) + + if action_embeds: + await __event_emitter__( + { + "type": "embeds", + "data": { + "embeds": action_embeds, + }, + } + ) + # Replace data with the processed status dict so we don't + # try to serialize the raw HTMLResponse / tuple back to the client + data = processed_result + + except Exception as e: + raise Exception(f"Error: {e}") + + return data diff --git a/backend/open_webui/utils/audit.py b/backend/open_webui/utils/audit.py index dc1226a080..c4abb445b9 100644 --- a/backend/open_webui/utils/audit.py +++ b/backend/open_webui/utils/audit.py @@ -28,7 +28,6 @@ from open_webui.utils.auth import get_current_user, get_http_authorization_cred from open_webui.models.users import UserModel - if TYPE_CHECKING: from loguru import Logger @@ -221,7 +220,10 @@ def _should_skip_auditing(self, request: Request) -> bool: return False # Do NOT skip logging for auth endpoints # Skip logging if the request is not authenticated - if not request.headers.get("authorization"): + # Check both Authorization header (API keys) and token cookie (browser sessions) + if not request.headers.get("authorization") and not request.cookies.get( + "token" + ): return True # match either /api//...(for the endpoint /api/chat case) or /api/v1//... 
diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 87ebb19b63..542f469dfd 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -36,6 +36,7 @@ ENABLE_PASSWORD_VALIDATION, OFFLINE_MODE, LICENSE_BLOB, + PASSWORD_VALIDATION_HINT, PASSWORD_VALIDATION_REGEX_PATTERN, REDIS_KEY_PREFIX, pk, @@ -191,7 +192,7 @@ def validate_password(password: str) -> bool: if ENABLE_PASSWORD_VALIDATION: if not PASSWORD_VALIDATION_REGEX_PATTERN.match(password): - raise Exception(ERROR_MESSAGES.INVALID_PASSWORD()) + raise Exception(ERROR_MESSAGES.INVALID_PASSWORD(PASSWORD_VALIDATION_HINT)) return True diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 1424bdcf16..71efd7423d 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -7,7 +7,7 @@ from typing import Any, Optional import random import json -import inspect + import uuid import asyncio @@ -59,7 +59,6 @@ from open_webui.env import GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) @@ -356,12 +355,10 @@ async def chat_completed(request: Request, form_data: dict, user: Any): } try: - filter_functions = [ - Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids( - request, model, metadata.get("filter_ids", []) - ) - ] + filter_ids = get_sorted_filter_ids( + request, model, metadata.get("filter_ids", []) + ) + filter_functions = Functions.get_functions_by_ids(filter_ids) result, _ = await process_filter_functions( request=request, @@ -373,101 +370,3 @@ async def chat_completed(request: Request, form_data: dict, user: Any): return result except Exception as e: raise Exception(f"Error: {e}") - - -async def chat_action(request: Request, action_id: str, form_data: dict, user: Any): - if "." 
in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise Exception(f"Action not found: {action_id}") - - if not request.app.state.MODELS: - await get_all_models(request, user=user) - - if getattr(request.state, "direct", False) and hasattr(request.state, "model"): - models = { - request.state.model["id"]: request.state.model, - } - else: - models = request.app.state.MODELS - - data = form_data - model_id = data["model"] - - if model_id not in models: - raise Exception("Model not found") - model = models[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - "user_id": user.id, - } - ) - - function_module, _, _ = get_function_module_from_cache(request, action_id) - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__request__": request, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = user.model_dump() if isinstance(user, UserModel) else {} - - try: - if 
hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - log.exception(f"Failed to get user values: {e}") - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - raise Exception(f"Error: {e}") - - return data diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py index e89b970cb6..a5de56a6c1 100644 --- a/backend/open_webui/utils/code_interpreter.py +++ b/backend/open_webui/utils/code_interpreter.py @@ -8,7 +8,6 @@ import websockets from pydantic import BaseModel - logger = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/db/access_control.py b/backend/open_webui/utils/db/access_control.py deleted file mode 100644 index 75bd337f8c..0000000000 --- a/backend/open_webui/utils/db/access_control.py +++ /dev/null @@ -1,124 +0,0 @@ -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON -from sqlalchemy.dialects.postgresql import JSONB - - -from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func - - -def has_permission(db, DocumentModel, query, filter: dict, permission: str = "read"): - group_ids = filter.get("group_ids", []) - user_id = filter.get("user_id") - dialect_name = db.bind.dialect.name - - conditions = [] - - # Handle read_only permission separately - if permission == "read_only": - # For read_only, we want items where: - # 1. User has explicit read permission (via groups or user-level) - # 2. BUT does NOT have write permission - # 3. 
Public items are NOT considered read_only - - read_conditions = [] - - # Group-level read permission - if group_ids: - group_read_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_read_conditions.append( - DocumentModel.access_control["read"]["group_ids"].contains(gid) - ) - elif dialect_name == "postgresql": - group_read_conditions.append( - cast( - DocumentModel.access_control["read"]["group_ids"], - JSONB, - ).contains([gid]) - ) - - if group_read_conditions: - read_conditions.append(or_(*group_read_conditions)) - - # Combine read conditions - if read_conditions: - has_read = or_(*read_conditions) - else: - # If no read conditions, return empty result - return query.filter(False) - - # Now exclude items where user has write permission - write_exclusions = [] - - # Exclude items owned by user (they have implicit write) - if user_id: - write_exclusions.append(DocumentModel.user_id != user_id) - - # Exclude items where user has explicit write permission via groups - if group_ids: - group_write_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_write_conditions.append( - DocumentModel.access_control["write"]["group_ids"].contains(gid) - ) - elif dialect_name == "postgresql": - group_write_conditions.append( - cast( - DocumentModel.access_control["write"]["group_ids"], - JSONB, - ).contains([gid]) - ) - - if group_write_conditions: - # User should NOT have write permission - write_exclusions.append(~or_(*group_write_conditions)) - - # Exclude public items (items without access_control) - write_exclusions.append(DocumentModel.access_control.isnot(None)) - write_exclusions.append(cast(DocumentModel.access_control, String) != "null") - - # Combine: has read AND does not have write AND not public - if write_exclusions: - query = query.filter(and_(has_read, *write_exclusions)) - else: - query = query.filter(has_read) - - return query - - # Original logic for other permissions (read, write, etc.) 
- # Public access conditions - if group_ids or user_id: - conditions.extend( - [ - DocumentModel.access_control.is_(None), - cast(DocumentModel.access_control, String) == "null", - ] - ) - - # User-level permission (owner has all permissions) - if user_id: - conditions.append(DocumentModel.user_id == user_id) - - # Group-level permission - if group_ids: - group_conditions = [] - for gid in group_ids: - if dialect_name == "sqlite": - group_conditions.append( - DocumentModel.access_control[permission]["group_ids"].contains(gid) - ) - elif dialect_name == "postgresql": - group_conditions.append( - cast( - DocumentModel.access_control[permission]["group_ids"], - JSONB, - ).contains([gid]) - ) - conditions.append(or_(*group_conditions)) - - if conditions: - query = query.filter(or_(*conditions)) - - return query diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py index 43cbc56e5f..a2dc080cb5 100644 --- a/backend/open_webui/utils/embeddings.py +++ b/backend/open_webui/utils/embeddings.py @@ -10,12 +10,11 @@ from open_webui.routers.openai import embeddings as openai_embeddings from open_webui.routers.ollama import ( - embeddings as ollama_embeddings, - GenerateEmbeddingsForm, + embed as ollama_embed, + GenerateEmbedForm, ) - -from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama +from open_webui.utils.payload import convert_embed_payload_openai_to_ollama from open_webui.utils.response import convert_embedding_response_ollama_to_openai logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -71,12 +70,12 @@ async def generate_embeddings( if not bypass_filter and user.role == "user": check_model_access(user, model) - # Ollama backend + # Ollama backend — use /api/embed which supports batch input natively if model.get("owned_by") == "ollama": - ollama_payload = convert_embedding_payload_openai_to_ollama(form_data) - response = await ollama_embeddings( + ollama_payload = 
convert_embed_payload_openai_to_ollama(form_data) + response = await ollama_embed( request=request, - form_data=GenerateEmbeddingsForm(**ollama_payload), + form_data=GenerateEmbedForm(**ollama_payload), user=user, ) return convert_embedding_response_ollama_to_openai(response) diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py index a37ecf31c6..af8818d59b 100644 --- a/backend/open_webui/utils/files.py +++ b/backend/open_webui/utils/files.py @@ -18,6 +18,7 @@ from open_webui.models.chats import Chats from open_webui.models.files import Files from open_webui.routers.files import upload_file_handler +from open_webui.retrieval.web.utils import validate_url import mimetypes import base64 @@ -33,6 +34,8 @@ def get_image_base64_from_url(url: str) -> Optional[str]: try: if url.startswith("http"): + # Validate URL to prevent SSRF attacks against local/private networks + validate_url(url) # Download the image from the URL response = requests.get(url) response.raise_for_status() diff --git a/backend/open_webui/utils/headers.py b/backend/open_webui/utils/headers.py index 3caee50334..f0b13c00d3 100644 --- a/backend/open_webui/utils/headers.py +++ b/backend/open_webui/utils/headers.py @@ -1,11 +1,18 @@ from urllib.parse import quote +from open_webui.env import ( + FORWARD_USER_INFO_HEADER_USER_NAME, + FORWARD_USER_INFO_HEADER_USER_ID, + FORWARD_USER_INFO_HEADER_USER_EMAIL, + FORWARD_USER_INFO_HEADER_USER_ROLE, +) + def include_user_info_headers(headers, user): return { **headers, - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, + FORWARD_USER_INFO_HEADER_USER_NAME: quote(user.name, safe=" "), + FORWARD_USER_INFO_HEADER_USER_ID: user.id, + FORWARD_USER_INFO_HEADER_USER_EMAIL: user.email, + FORWARD_USER_INFO_HEADER_USER_ROLE: user.role, } diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py index 
4af3064235..63d5fbb3ce 100644 --- a/backend/open_webui/utils/logger.py +++ b/backend/open_webui/utils/logger.py @@ -17,7 +17,6 @@ ENABLE_OTEL_LOGS, ) - if TYPE_CHECKING: from loguru import Record diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index 6edfca4f6c..33803648be 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -8,6 +8,15 @@ from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +import httpx +from mcp.shared._httpx_utils import create_mcp_http_client +from open_webui.env import AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL + + +def create_insecure_httpx_client(headers=None, timeout=None, auth=None): + client = create_mcp_http_client(headers=headers, timeout=timeout, auth=auth) + client.verify = False + return client class MCPClient: @@ -18,7 +27,14 @@ def __init__(self): async def connect(self, url: str, headers: Optional[dict] = None): async with AsyncExitStack() as exit_stack: try: - self._streams_context = streamablehttp_client(url, headers=headers) + if AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL: + self._streams_context = streamablehttp_client(url, headers=headers) + else: + self._streams_context = streamablehttp_client( + url, + headers=headers, + httpx_client_factory=create_insecure_httpx_client, + ) transport = await exit_stack.enter_async_context(self._streams_context) read_stream, write_stream, _ = transport diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index fe2d7e5dc1..e39787f715 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -73,6 +73,7 @@ from open_webui.retrieval.utils import get_sources_from_items +from open_webui.utils.sanitize import sanitize_code from open_webui.utils.chat import 
generate_chat_completion from open_webui.utils.task import ( get_task_model_id, @@ -92,6 +93,7 @@ prepend_to_first_user_message_content, convert_logit_bias_input_to_json, get_content_from_message, + convert_output_to_messages, ) from open_webui.utils.tools import ( get_tools, @@ -105,6 +107,7 @@ ) from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.utils.payload import apply_system_prompt_to_body +from open_webui.utils.response import normalize_usage from open_webui.utils.mcp.client import MCPClient @@ -124,10 +127,13 @@ ENABLE_REALTIME_CHAT_SAVE, ENABLE_QUERIES_CACHE, RAG_SYSTEM_CONTEXT, + ENABLE_FORWARD_USER_INFO_HEADERS, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, + FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, ) +from open_webui.utils.headers import include_user_info_headers from open_webui.constants import TASKS - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) @@ -146,6 +152,11 @@ DEFAULT_CODE_INTERPRETER_TAGS = [("", "")] +def output_id(prefix: str) -> str: + """Generate OR-style ID: prefix + 24-char hex UUID.""" + return f"{prefix}_{uuid4().hex[:24]}" + + def get_citation_source_from_tool_result( tool_name: str, tool_params: dict, tool_result: str, tool_id: str = "" ) -> list[dict]: @@ -160,9 +171,13 @@ def get_citation_source_from_tool_result( Returns a list of sources (usually one, but query_knowledge_files may return multiple). 
""" try: + tool_result = json.loads(tool_result) + if isinstance(tool_result, dict) and "error" in tool_result: + return [] + if tool_name == "search_web": # Parse JSON array: [{"title": "...", "link": "...", "snippet": "..."}] - results = json.loads(tool_result) + results = tool_result documents = [] metadata = [] @@ -189,7 +204,7 @@ def get_citation_source_from_tool_result( ] elif tool_name == "view_knowledge_file": - file_data = json.loads(tool_result) + file_data = tool_result filename = file_data.get("filename", "Unknown File") file_id = file_data.get("id", "") knowledge_name = file_data.get("knowledge_name", "") @@ -218,7 +233,7 @@ def get_citation_source_from_tool_result( ] elif tool_name == "query_knowledge_files": - chunks = json.loads(tool_result) + chunks = tool_result # Group chunks by source for better citation display # Each unique source becomes a separate source entry @@ -286,6 +301,475 @@ def get_citation_source_from_tool_result( ] +def split_content_and_whitespace(content): + content_stripped = content.rstrip() + original_whitespace = ( + content[len(content_stripped) :] if len(content) > len(content_stripped) else "" + ) + return content_stripped, original_whitespace + + +def is_opening_code_block(content): + backtick_segments = content.split("```") + # Even number of segments means the last backticks are opening a new block + return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 + + +def serialize_output(output: list) -> str: + """ + Convert OR-aligned output items to HTML for display. + For LLM consumption, use convert_output_to_messages() instead. 
+ """ + content = "" + + # First pass: collect function_call_output items by call_id for lookup + tool_outputs = {} + for item in output: + if item.get("type") == "function_call_output": + tool_outputs[item.get("call_id")] = item + + # Second pass: render items in order + for idx, item in enumerate(output): + item_type = item.get("type", "") + + if item_type == "message": + for content_part in item.get("content", []): + if "text" in content_part: + text = content_part.get("text", "").strip() + if text: + content = f"{content}{text}\n" + + elif item_type == "function_call": + # Render tool call inline with its result (if available) + if content and not content.endswith("\n"): + content += "\n" + + call_id = item.get("call_id", "") + name = item.get("name", "") + arguments = item.get("arguments", "") + + result_item = tool_outputs.get(call_id) + if result_item: + result_text = "" + for out in result_item.get("output", []): + if "text" in out: + result_text += out.get("text", "") + files = result_item.get("files") + embeds = result_item.get("embeds", "") + + content += f'
\nTool Executed\n
\n' + else: + content += f'
\nExecuting...\n
\n' + + elif item_type == "function_call_output": + # Already handled inline with function_call above + pass + + elif item_type == "reasoning": + reasoning_content = "" + # Check for 'summary' (new structure) or 'content' (legacy/fallback) + source_list = item.get("summary", []) or item.get("content", []) + for content_part in source_list: + if "text" in content_part: + reasoning_content += content_part.get("text", "") + elif "summary" in content_part: # Handle potential nested logic if any + pass + + reasoning_content = reasoning_content.strip() + + duration = item.get("duration") + status = item.get("status", "in_progress") + + # Infer completion: if this reasoning item is NOT the last item, + # render as done (a subsequent item means reasoning is complete) + is_last_item = idx == len(output) - 1 + + if content and not content.endswith("\n"): + content += "\n" + + display = html.escape( + "\n".join( + (f"> {line}" if not line.startswith(">") else line) + for line in reasoning_content.splitlines() + ) + ) + + if status == "completed" or duration is not None or not is_last_item: + content = f'{content}
\nThought for {duration or 0} seconds\n{display}\n
\n' + else: + content = f'{content}
\nThinking…\n{display}\n
\n' + + elif item_type == "open_webui:code_interpreter": + content_stripped, original_whitespace = split_content_and_whitespace( + content + ) + if is_opening_code_block(content_stripped): + content = content_stripped.rstrip("`").rstrip() + original_whitespace + else: + content = content_stripped + original_whitespace + + if content and not content.endswith("\n"): + content += "\n" + + return content.strip() + + +def deep_merge(target, source): + """ + Merge source into target recursively (returning new structure). + - Dicts: Recursive merge. + - Strings: Concatenation. + - Others: Overwrite. + """ + if isinstance(target, dict) and isinstance(source, dict): + new_target = target.copy() + for k, v in source.items(): + if k in new_target: + new_target[k] = deep_merge(new_target[k], v) + else: + new_target[k] = v + return new_target + elif isinstance(target, str) and isinstance(source, str): + return target + source + else: + return source + + +def handle_responses_streaming_event( + data: dict, + current_output: list, +) -> tuple[list, dict | None]: + """ + Handle Responses API streaming events in a pure functional way. + + Args: + data: The event data + current_output: List of output items (treated as immutable) + + Returns: + tuple[list, dict | None]: (new_output, metadata) + - new_output: The updated output list. + - metadata: Metadata to emit (e.g. usage), {} if update occurred, None if skip. + """ + # Default: no change + # Note: treating current_output as immutable, but avoiding full deepcopy for perf. + # We will shallow copy only if we need to modify the list structure or items. 
+ + event_type = data.get("type", "") + + if event_type == "response.output_item.added": + item = data.get("item", {}) + if item: + new_output = list(current_output) + new_output.append(item) + return new_output, None + return current_output, None + + elif event_type == "response.content_part.added": + part = data.get("part", {}) + output_index = data.get("output_index", len(current_output) - 1) + + if current_output and 0 <= output_index < len(current_output): + new_output = list(current_output) + # Copy the item to mutate it + item = new_output[output_index].copy() + new_output[output_index] = item + + if "content" not in item: + item["content"] = [] + else: + # Copy content list + item["content"] = list(item["content"]) + + if item.get("type") == "reasoning": + # Reasoning items should not have content parts + pass + else: + item["content"].append(part) + return new_output, None + return current_output, None + + elif event_type == "response.reasoning_summary_part.added": + part = data.get("part", {}) + output_index = data.get("output_index", len(current_output) - 1) + + if current_output and 0 <= output_index < len(current_output): + new_output = list(current_output) + item = new_output[output_index].copy() + new_output[output_index] = item + + if "summary" not in item: + item["summary"] = [] + else: + item["summary"] = list(item["summary"]) + + item["summary"].append(part) + return new_output, None + return current_output, None + + elif event_type.startswith("response.") and event_type.endswith(".delta"): + # Generic Delta Handling + parts = event_type.split(".") + if len(parts) >= 3: + delta_type = parts[1] + delta = data.get("delta", "") + + output_index = data.get("output_index", len(current_output) - 1) + + if current_output and 0 <= output_index < len(current_output): + new_output = list(current_output) + item = new_output[output_index].copy() + new_output[output_index] = item + item_type = item.get("type", "") + + # Determine target field and object based 
on delta_type and item_type + if delta_type == "function_call_arguments": + key = "arguments" + if item_type == "function_call": + # Function call args are usually strings + item[key] = item.get(key, "") + str(delta) + else: + # Generic handling, refined by item type below + pass + + if item_type == "message": + # Message items: "text"/"output_text" -> "text" + # "reasoning_text" -> Skipped (should use reasoning item) + if delta_type in ["text", "output_text"]: + key = "text" + elif delta_type in ["reasoning_text", "reasoning_summary_text"]: + # Skip reasoning updates for message items + return new_output, None + else: + key = delta_type + + content_index = data.get("content_index", 0) + if "content" not in item: + item["content"] = [] + else: + item["content"] = list(item["content"]) + content_list = item["content"] + + while len(content_list) <= content_index: + content_list.append({"type": "text", "text": ""}) + + # Copy the part to mutate it + part = content_list[content_index].copy() + content_list[content_index] = part + + current_val = part.get(key) + if current_val is None: + # Initialize based on delta type + current_val = {} if isinstance(delta, dict) else "" + + part[key] = deep_merge(current_val, delta) + + elif item_type == "reasoning": + # Reasoning items: "reasoning_text"/"reasoning_summary_text" -> "text" + # "text"/"output_text" -> Skipped (should use message item) + if delta_type == "reasoning_summary_text": + # Summary updates -> item['summary'] + key = "text" + summary_index = data.get("summary_index", 0) + if "summary" not in item: + item["summary"] = [] + else: + item["summary"] = list(item["summary"]) + summary_list = item["summary"] + + while len(summary_list) <= summary_index: + summary_list.append( + {"type": "summary_text", "text": ""} + ) + + part = summary_list[summary_index].copy() + summary_list[summary_index] = part + + target_val = part.get(key, "") + part[key] = deep_merge(target_val, delta) + + elif delta_type == "reasoning_text": 
+ # Reasoning body updates -> item['content'] + key = "text" + content_index = data.get("content_index", 0) + if "content" not in item: + item["content"] = [] + else: + item["content"] = list(item["content"]) + content_list = item["content"] + + while len(content_list) <= content_index: + # Reasoning content parts default to text + content_list.append({"type": "text", "text": ""}) + + part = content_list[content_index].copy() + content_list[content_index] = part + + target_val = part.get(key, "") + part[key] = deep_merge(target_val, delta) + + elif delta_type in ["text", "output_text"]: + return new_output, None + else: + # Fallback just in case other deltas target reasoning? + pass + + else: + # Fallback for other item types + if delta_type in ["text", "output_text"]: + key = "text" + else: + key = delta_type + + current_val = item.get(key) + if current_val is None: + current_val = {} if isinstance(delta, dict) else "" + item[key] = deep_merge(current_val, delta) + + return new_output, None + + elif event_type.startswith("response.") and event_type.endswith(".done"): + # Delta Events: response.content_part.done, response.text.done, etc. + parts = event_type.split(".") + if len(parts) >= 3: + type_name = parts[1] + + # 1. Handle specific Delta "done" signals + if type_name == "content_part": + # "Signaling that no further changes will occur to a content part" + # If payloads contains the full part, we could update it. + # Usually purely signaling in standard implementation, but we check payload. 
+ part = data.get("part") + output_index = data.get("output_index", len(current_output) - 1) + + if part and current_output and 0 <= output_index < len(current_output): + new_output = list(current_output) + item = new_output[output_index].copy() + new_output[output_index] = item + + if "content" in item: + item["content"] = list(item["content"]) + content_index = data.get( + "content_index", len(item["content"]) - 1 + ) + if 0 <= content_index < len(item["content"]): + item["content"][content_index] = part + return new_output, {} + return current_output, None + + elif type_name == "reasoning_summary_part": + part = data.get("part") + output_index = data.get("output_index", len(current_output) - 1) + + if part and current_output and 0 <= output_index < len(current_output): + new_output = list(current_output) + item = new_output[output_index].copy() + new_output[output_index] = item + + if "summary" in item: + item["summary"] = list(item["summary"]) + summary_index = data.get( + "summary_index", len(item["summary"]) - 1 + ) + if 0 <= summary_index < len(item["summary"]): + item["summary"][summary_index] = part + return new_output, {} + return current_output, None + + # 2. Skip Output Item done (handled specifically below) + if type_name == "output_item": + pass + + # 3. 
Generic Field Done (text.done, audio.done) + elif type_name not in ["completed", "failed"]: + output_index = data.get("output_index", len(current_output) - 1) + if current_output and 0 <= output_index < len(current_output): + + key = ( + "text" + if type_name + in [ + "text", + "output_text", + "reasoning_text", + "reasoning_summary_text", + ] + else type_name + ) + if type_name == "function_call_arguments": + key = "arguments" + + if key in data: + final_value = data[key] + new_output = list(current_output) + item = new_output[output_index].copy() + new_output[output_index] = item + item_type = item.get("type", "") + + if type_name == "function_call_arguments": + if item_type == "function_call": + item["arguments"] = final_value + elif item_type == "message": + content_index = data.get("content_index", 0) + if "content" in item: + item["content"] = list(item["content"]) + if len(item["content"]) > content_index: + part = item["content"][content_index].copy() + item["content"][content_index] = part + part[key] = final_value + elif item_type == "reasoning": + item["status"] = "completed" + else: + item[key] = final_value + + return new_output, {} + + return current_output, None + + elif event_type == "response.output_item.done": + # Delta Event: Output item complete + item = data.get("item") + output_index = data.get("output_index", len(current_output) - 1) + + new_output = list(current_output) + if item and 0 <= output_index < len(current_output): + new_output[output_index] = item + elif item: + new_output.append(item) + return new_output, {} + + elif event_type == "response.completed": + # State Machine Event: Completed + response_data = data.get("response", {}) + final_output = response_data.get("output") + + new_output = final_output if final_output is not None else current_output + + # Ensure reasoning items are marked as completed in the final output + if new_output: + for item in new_output: + if ( + item.get("type") == "reasoning" + and item.get("status") != 
"completed" + ): + item["status"] = "completed" + + return new_output, {"usage": response_data.get("usage"), "done": True} + + elif event_type == "response.in_progress": + # State Machine Event: In Progress + # We could extract metadata if needed, but for now just acknowledge iteration + return current_output, None + + elif event_type == "response.failed": + # State Machine Event: Failed + error = data.get("response", {}).get("error", {}) + return current_output, {"error": error} + + else: + return current_output, None + + def apply_source_context_to_messages( request: Request, messages: list, @@ -380,7 +864,7 @@ def process_tool_result( else: tool_result = tool_result.body.decode("utf-8", "replace") - elif (tool_type == "external" and isinstance(tool_result, tuple)) or ( + elif (tool_type in ("external", "action") and isinstance(tool_result, tuple)) or ( direct_tool and isinstance(tool_result, list) and len(tool_result) == 2 ): tool_result, tool_response_headers = tool_result @@ -775,6 +1259,7 @@ async def chat_web_search_handler( "messages": messages, "prompt": user_message, "type": "web_search", + "chat_id": extra_params.get("__chat_id__"), }, user, ) @@ -976,7 +1461,9 @@ def format_file_tag(file): for message, stored_message in zip(messages, stored_messages): files_with_urls = [ - file for file in stored_message.get("files", []) if file.get("url") + file + for file in stored_message.get("files", []) + if file.get("url") and not file.get("url").startswith("data:") ] if not files_with_urls: continue @@ -1103,6 +1590,7 @@ async def chat_image_generation_handler( { "model": form_data["model"], "messages": form_data["messages"], + "chat_id": metadata.get("chat_id"), }, user, ) @@ -1209,6 +1697,7 @@ async def chat_completion_files_handler( "model": body["model"], "messages": body["messages"], "type": "retrieval", + "chat_id": body.get("metadata", {}).get("chat_id"), }, user, ) @@ -1399,6 +1888,30 @@ async def convert_url_images_to_base64(form_data): return form_data 
+def process_messages_with_output(messages: list[dict]) -> list[dict]: + """ + Process messages with OR-aligned output items for LLM consumption. + + For assistant messages with 'output' field, produces properly formatted + OpenAI-style messages (tool_calls + tool results). Strips 'output' before LLM. + """ + processed = [] + + for message in messages: + if message.get("role") == "assistant" and message.get("output"): + # Use output items for clean OpenAI-format messages + output_messages = convert_output_to_messages(message["output"]) + if output_messages: + processed.extend(output_messages) + continue + + # Strip 'output' field before adding (LLM shouldn't see it) + clean_message = {k: v for k, v in message.items() if k != "output"} + processed.append(clean_message) + + return processed + + async def process_chat_payload(request, form_data, user, metadata, model): # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling @@ -1407,6 +1920,9 @@ async def process_chat_payload(request, form_data, user, metadata, model): form_data = apply_params_to_form_data(form_data, model) log.debug(f"form_data: {form_data}") + # Process messages with OR-aligned output items for clean LLM messages + form_data["messages"] = process_messages_with_output(form_data.get("messages", [])) + system_message = get_system_message(form_data.get("messages", [])) if system_message: # Chat Controls/User Settings try: @@ -1421,22 +1937,12 @@ async def process_chat_payload(request, form_data, user, metadata, model): event_emitter = get_event_emitter(metadata) event_caller = get_event_call(metadata) - oauth_token = None - try: - if request.cookies.get("oauth_session_id", None): - oauth_token = await request.app.state.oauth_manager.get_oauth_token( - user.id, - request.cookies.get("oauth_session_id", None), - ) - except Exception as e: - log.error(f"Error getting OAuth token: {e}") 
- extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_caller, "__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__metadata__": metadata, - "__oauth_token__": oauth_token, + "__oauth_token__": await get_system_oauth_token(request, user), "__request__": request, "__model__": model, "__chat_id__": metadata.get("chat_id"), @@ -1536,12 +2042,10 @@ async def process_chat_payload(request, form_data, user, metadata, model): raise e try: - filter_functions = [ - Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids( - request, model, metadata.get("filter_ids", []) - ) - ] + filter_ids = get_sorted_filter_ids( + request, model, metadata.get("filter_ids", []) + ) + filter_functions = Functions.get_functions_by_ids(filter_ids) form_data, flags = await process_filter_functions( request=request, @@ -1602,6 +2106,35 @@ async def process_chat_payload(request, form_data, user, metadata, model): tool_ids = form_data.pop("tool_ids", None) files = form_data.pop("files", None) + # Skills: inject manifest only — model uses view_skill tool to load full content on-demand + user_skill_ids = form_data.pop("skill_ids", None) or [] + model_skill_ids = model.get("info", {}).get("meta", {}).get("skillIds", []) + + all_skill_ids = list(set(user_skill_ids + model_skill_ids)) + available_skills = [] + if all_skill_ids: + from open_webui.models.skills import Skills as SkillsModel + + accessible_skill_ids = { + s.id for s in SkillsModel.get_skills_by_user_id(user.id, "read") + } + available_skills = [ + s + for sid in all_skill_ids + if sid in accessible_skill_ids + and (s := SkillsModel.get_skill_by_id(sid)) + and s.is_active + ] + + if available_skills: + manifest = "\n" + for skill in available_skills: + manifest += f"\n{skill.name}\n{skill.description or ''}\n\n" + manifest += "" + form_data["messages"] = add_or_update_system_message( + manifest, form_data["messages"], append=True + ) + prompt = 
get_last_user_message(form_data["messages"]) # TODO: re-enable URL extraction from prompt # urls = [] @@ -1715,6 +2248,18 @@ async def process_chat_payload(request, form_data, user, metadata, model): for key, value in connection_headers.items(): headers[key] = value + # Add user info headers if enabled + if ENABLE_FORWARD_USER_INFO_HEADERS and user: + headers = include_user_info_headers(headers, user) + if metadata and metadata.get("chat_id"): + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get( + "chat_id" + ) + if metadata and metadata.get("message_id"): + headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = ( + metadata.get("message_id") + ) + mcp_clients[server_id] = MCPClient() await mcp_clients[server_id].connect( url=mcp_server_connection.get("url", ""), @@ -1808,11 +2353,8 @@ async def tool_function(**kwargs): # Inject builtin tools for native function calling based on enabled features and model capability # Check if builtin_tools capability is enabled for this model (defaults to True if not specified) builtin_tools_enabled = ( - model.get("info", {}) - .get("meta", {}) - .get("capabilities", {}) - .get("builtin_tools", True) - ) + model.get("info", {}).get("meta", {}).get("capabilities") or {} + ).get("builtin_tools", True) if ( metadata.get("params", {}).get("function_calling") == "native" and builtin_tools_enabled @@ -1827,6 +2369,7 @@ async def tool_function(**kwargs): { **extra_params, "__event_emitter__": event_emitter, + "__skill_ids__": [s.id for s in available_skills], }, features, model, @@ -1856,11 +2399,8 @@ async def tool_function(**kwargs): # Check if file context extraction is enabled for this model (default True) file_context_enabled = ( - model.get("info", {}) - .get("meta", {}) - .get("capabilities", {}) - .get("file_context", True) - ) + model.get("info", {}).get("meta", {}).get("capabilities") or {} + ).get("file_context", True) if file_context_enabled: try: @@ -1904,190 +2444,220 @@ async def tool_function(**kwargs): return 
form_data, metadata, events -async def process_chat_response( - request, response, form_data, user, metadata, model, events, tasks +def get_event_emitter_and_caller(metadata): + event_emitter = None + event_caller = None + if ( + "session_id" in metadata + and metadata["session_id"] + and "chat_id" in metadata + and metadata["chat_id"] + and "message_id" in metadata + and metadata["message_id"] + ): + event_emitter = get_event_emitter(metadata) + event_caller = get_event_call(metadata) + return event_emitter, event_caller + + +def build_chat_response_context( + request, form_data, user, model, metadata, tasks, events ): - async def background_tasks_handler(): - message = None - messages = [] + event_emitter, event_caller = get_event_emitter_and_caller(metadata) + return { + "request": request, + "form_data": form_data, + "user": user, + "model": model, + "metadata": metadata, + "tasks": tasks, + "events": events, + "event_emitter": event_emitter, + "event_caller": event_caller, + } - if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"): - messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"]) - message = messages_map.get(metadata["message_id"]) if messages_map else None - message_list = get_message_list(messages_map, metadata["message_id"]) +def get_response_data(response): + if isinstance(response, list) and len(response) == 1: + # If the response is a single-item list, unwrap it #17213 + response = response[0] - # Remove details tags and files from the messages. 
- # as get_message_list creates a new list, it does not affect - # the original messages outside of this handler + if isinstance(response, JSONResponse): + if isinstance(response.body, bytes): + try: + response_data = json.loads(response.body.decode("utf-8", "replace")) + except json.JSONDecodeError: + response_data = {"error": {"detail": "Invalid JSON response"}} + else: + response_data = response + elif isinstance(response, dict): + response_data = response + else: + response_data = None - messages = [] - for message in message_list: - content = message.get("content", "") - if isinstance(content, list): - for item in content: - if item.get("type") == "text": - content = item["text"] - break + return response, response_data - if isinstance(content, str): - content = re.sub( - r"]*>.*?<\/details>|!\[.*?\]\(.*?\)", - "", - content, - flags=re.S | re.I, - ).strip() - messages.append( - { - **message, - "role": message.get( - "role", "assistant" - ), # Safe fallback for missing role - "content": content, - } - ) - else: - # Local temp chat, get the model and message from the form_data - message = get_last_user_message_item(form_data.get("messages", [])) - messages = form_data.get("messages", []) - if message: - message["model"] = form_data.get("model") - - if message and "model" in message: - if tasks and messages: - if ( - TASKS.FOLLOW_UP_GENERATION in tasks - and tasks[TASKS.FOLLOW_UP_GENERATION] - ): - res = await generate_follow_ups( - request, - { - "model": message["model"], - "messages": messages, - "message_id": metadata["message_id"], - "chat_id": metadata["chat_id"], - }, - user, - ) +def merge_events_into_response(response_data, events): + if events and isinstance(events, list): + extra_response = {} + for event in events: + if isinstance(event, dict): + extra_response.update(event) + else: + extra_response[event] = True - if res and isinstance(res, dict): - if len(res.get("choices", [])) == 1: - response_message = res.get("choices", [])[0].get( - 
"message", {} - ) + return { + **extra_response, + **response_data, + } + return response_data - follow_ups_string = response_message.get( - "content" - ) or response_message.get("reasoning_content", "") - else: - follow_ups_string = "" - follow_ups_string = follow_ups_string[ - follow_ups_string.find("{") : follow_ups_string.rfind("}") - + 1 - ] +def build_response_object(response, response_data): + if isinstance(response, dict): + return response_data + if isinstance(response, JSONResponse): + return JSONResponse( + content=response_data, + headers=response.headers, + status_code=response.status_code, + ) + return response - try: - follow_ups = json.loads(follow_ups_string).get( - "follow_ups", [] - ) - await event_emitter( - { - "type": "chat:message:follow_ups", - "data": { - "follow_ups": follow_ups, - }, - } - ) - if not metadata.get("chat_id", "").startswith("local:"): - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "followUps": follow_ups, - }, - ) +async def get_system_oauth_token(request, user): + oauth_token = None + try: + if request.cookies.get("oauth_session_id", None): + oauth_token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + return oauth_token - except Exception as e: - pass - if not metadata.get("chat_id", "").startswith( - "local:" - ): # Only update titles and tags for non-temp chats - if TASKS.TITLE_GENERATION in tasks: - user_message = get_last_user_message(messages) - if user_message and len(user_message) > 100: - user_message = user_message[:100] + "..." 
- - title = None - if tasks[TASKS.TITLE_GENERATION]: - res = await generate_title( - request, - { - "model": message["model"], - "messages": messages, - "chat_id": metadata["chat_id"], - }, - user, - ) +async def background_tasks_handler(ctx): + request = ctx["request"] + form_data = ctx["form_data"] + user = ctx["user"] + metadata = ctx["metadata"] + tasks = ctx["tasks"] + event_emitter = ctx["event_emitter"] - if res and isinstance(res, dict): - if len(res.get("choices", [])) == 1: - response_message = res.get("choices", [])[0].get( - "message", {} - ) + message = None + messages = [] - title_string = ( - response_message.get("content") - or response_message.get( - "reasoning_content", - ) - or message.get("content", user_message) - ) - else: - title_string = "" + if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"): + messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"]) + message = messages_map.get(metadata["message_id"]) if messages_map else None - title_string = title_string[ - title_string.find("{") : title_string.rfind("}") + 1 - ] + message_list = get_message_list(messages_map, metadata["message_id"]) - try: - title = json.loads(title_string).get( - "title", user_message - ) - except Exception as e: - title = "" + # Remove details tags and files from the messages. 
+ # as get_message_list creates a new list, it does not affect + # the original messages outside of this handler - if not title: - title = messages[0].get("content", user_message) + messages = [] + for message in message_list: + content = message.get("content", "") + if isinstance(content, list): + for item in content: + if item.get("type") == "text": + content = item["text"] + break - Chats.update_chat_title_by_id( - metadata["chat_id"], title - ) + if isinstance(content, str): + content = re.sub( + r"]*>.*?<\/details>|!\[.*?\]\(.*?\)", + "", + content, + flags=re.S | re.I, + ).strip() - await event_emitter( - { - "type": "chat:title", - "data": title, - } - ) + messages.append( + { + **message, + "role": message.get( + "role", "assistant" + ), # Safe fallback for missing role + "content": content, + } + ) + else: + # Local temp chat, get the model and message from the form_data + message = get_last_user_message_item(form_data.get("messages", [])) + messages = form_data.get("messages", []) + if message: + message["model"] = form_data.get("model") + + if message and "model" in message: + if tasks and messages: + if ( + TASKS.FOLLOW_UP_GENERATION in tasks + and tasks[TASKS.FOLLOW_UP_GENERATION] + ): + res = await generate_follow_ups( + request, + { + "model": message["model"], + "messages": messages, + "message_id": metadata["message_id"], + "chat_id": metadata["chat_id"], + }, + user, + ) - if title == None and len(messages) == 2: - title = messages[0].get("content", user_message) + if res and isinstance(res, dict): + if len(res.get("choices", [])) == 1: + response_message = res.get("choices", [])[0].get("message", {}) - Chats.update_chat_title_by_id(metadata["chat_id"], title) + follow_ups_string = response_message.get( + "content" + ) or response_message.get("reasoning_content", "") + else: + follow_ups_string = "" - await event_emitter( + follow_ups_string = follow_ups_string[ + follow_ups_string.find("{") : follow_ups_string.rfind("}") + 1 + ] + + try: + 
follow_ups = json.loads(follow_ups_string).get("follow_ups", []) + await event_emitter( + { + "type": "chat:message:follow_ups", + "data": { + "follow_ups": follow_ups, + }, + } + ) + + if not metadata.get("chat_id", "").startswith("local:"): + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], { - "type": "chat:title", - "data": message.get("content", user_message), - } + "followUps": follow_ups, + }, ) - if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: - res = await generate_chat_tags( + except Exception as e: + pass + + if not metadata.get("chat_id", "").startswith( + "local:" + ): # Only update titles and tags for non-temp chats + if TASKS.TITLE_GENERATION in tasks: + user_message = get_last_user_message(messages) + if user_message and len(user_message) > 100: + user_message = user_message[:100] + "..." + + title = None + if tasks[TASKS.TITLE_GENERATION]: + res = await generate_title( request, { "model": message["model"], @@ -2103,221 +2673,249 @@ async def background_tasks_handler(): "message", {} ) - tags_string = response_message.get( - "content" - ) or response_message.get("reasoning_content", "") + title_string = ( + response_message.get("content") + or response_message.get( + "reasoning_content", + ) + or message.get("content", user_message) + ) else: - tags_string = "" + title_string = "" - tags_string = tags_string[ - tags_string.find("{") : tags_string.rfind("}") + 1 + title_string = title_string[ + title_string.find("{") : title_string.rfind("}") + 1 ] try: - tags = json.loads(tags_string).get("tags", []) - Chats.update_chat_tags_by_id( - metadata["chat_id"], tags, user - ) - - await event_emitter( - { - "type": "chat:tags", - "data": tags, - } + title = json.loads(title_string).get( + "title", user_message ) except Exception as e: - pass + title = "" - event_emitter = None - event_caller = None - if ( - "session_id" in metadata - and metadata["session_id"] - and "chat_id" in metadata 
- and metadata["chat_id"] - and "message_id" in metadata - and metadata["message_id"] - ): - event_emitter = get_event_emitter(metadata) - event_caller = get_event_call(metadata) + if not title: + title = messages[0].get("content", user_message) - # Non-streaming response - if not isinstance(response, StreamingResponse): - if event_emitter: - try: - if isinstance(response, dict) or isinstance(response, JSONResponse): - if isinstance(response, list) and len(response) == 1: - # If the response is a single-item list, unwrap it #17213 - response = response[0] - - if isinstance(response, JSONResponse) and isinstance( - response.body, bytes - ): - try: - response_data = json.loads( - response.body.decode("utf-8", "replace") - ) - except json.JSONDecodeError: - response_data = { - "error": {"detail": "Invalid JSON response"} - } - else: - response_data = response - - if "error" in response_data: - error = response_data.get("error") - - if isinstance(error, dict): - error = error.get("detail", error) - else: - error = str(error) + Chats.update_chat_title_by_id(metadata["chat_id"], title) - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "error": {"content": error}, - }, - ) - if isinstance(error, str) or isinstance(error, dict): await event_emitter( { - "type": "chat:message:error", - "data": {"error": {"content": error}}, + "type": "chat:title", + "data": title, } ) - if "selected_model_id" in response_data: - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + if title == None and len(messages) == 2: + title = messages[0].get("content", user_message) + + Chats.update_chat_title_by_id(metadata["chat_id"], title) + + await event_emitter( { - "selectedModelId": response_data["selected_model_id"], - }, + "type": "chat:title", + "data": message.get("content", user_message), + } ) - choices = response_data.get("choices", []) - if choices and choices[0].get("message", 
{}).get("content"): - content = response_data["choices"][0]["message"]["content"] + if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: + res = await generate_chat_tags( + request, + { + "model": message["model"], + "messages": messages, + "chat_id": metadata["chat_id"], + }, + user, + ) - if content: - await event_emitter( - { - "type": "chat:completion", - "data": response_data, - } + if res and isinstance(res, dict): + if len(res.get("choices", [])) == 1: + response_message = res.get("choices", [])[0].get( + "message", {} ) - title = Chats.get_chat_title_by_id(metadata["chat_id"]) + tags_string = response_message.get( + "content" + ) or response_message.get("reasoning_content", "") + else: + tags_string = "" + + tags_string = tags_string[ + tags_string.find("{") : tags_string.rfind("}") + 1 + ] + + try: + tags = json.loads(tags_string).get("tags", []) + Chats.update_chat_tags_by_id( + metadata["chat_id"], tags, user + ) await event_emitter( { - "type": "chat:completion", - "data": { - "done": True, - "content": content, - "title": title, - }, + "type": "chat:tags", + "data": tags, } ) + except Exception as e: + pass - # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], + +async def non_streaming_chat_response_handler(response, ctx): + request = ctx["request"] + + user = ctx["user"] + metadata = ctx["metadata"] + events = ctx["events"] + + event_emitter = ctx["event_emitter"] + + response, response_data = get_response_data(response) + if response_data is None: + return response + + if event_emitter: + try: + if "error" in response_data: + error = response_data.get("error") + + if isinstance(error, dict): + error = error.get("detail", error) + else: + error = str(error) + + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "error": {"content": error}, + }, + ) + if isinstance(error, str) or isinstance(error, 
dict): + await event_emitter( + { + "type": "chat:message:error", + "data": {"error": {"content": error}}, + } + ) + + if "selected_model_id" in response_data: + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "selectedModelId": response_data["selected_model_id"], + }, + ) + + choices = response_data.get("choices", []) + if choices and choices[0].get("message", {}).get("content"): + content = response_data["choices"][0]["message"]["content"] + + if content: + await event_emitter( + { + "type": "chat:completion", + "data": response_data, + } + ) + + title = Chats.get_chat_title_by_id(metadata["chat_id"]) + + # Use output from backend if provided (OR-compliant backends), + # otherwise generate from response content + response_output = response_data.get("output") + if not response_output: + response_output = [ + { + "type": "message", + "id": output_id("msg"), + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": content}], + } + ] + + await event_emitter( + { + "type": "chat:completion", + "data": { + "done": True, + "content": content, + "output": response_output, + "title": title, + }, + } + ) + + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "role": "assistant", + "content": content, + "output": response_output, + }, + ) + + # Send a webhook notification if the user is not active + if not Users.is_user_active(user.id): + webhook_url = Users.get_user_webhook_url_by_id(user.id) + if webhook_url: + await post_webhook( + request.app.state.WEBUI_NAME, + webhook_url, + f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", { - "role": "assistant", - "content": content, + "action": "chat", + "message": content, + "title": title, + "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", }, ) - # Send a webhook notification if the user 
is not active - if not Users.is_user_active(user.id): - webhook_url = Users.get_user_webhook_url_by_id(user.id) - if webhook_url: - await post_webhook( - request.app.state.WEBUI_NAME, - webhook_url, - f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", - { - "action": "chat", - "message": content, - "title": title, - "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", - }, - ) + await background_tasks_handler(ctx) - await background_tasks_handler() + response = build_response_object( + response, merge_events_into_response(response_data, events) + ) + except Exception as e: + log.debug(f"Error occurred while processing request: {e}") + pass - if events and isinstance(events, list): - extra_response = {} - for event in events: - if isinstance(event, dict): - extra_response.update(event) - else: - extra_response[event] = True + return response - response_data = { - **extra_response, - **response_data, - } + if isinstance(response, dict): + response = merge_events_into_response(response_data, events) - if isinstance(response, dict): - response = response_data - if isinstance(response, JSONResponse): - response = JSONResponse( - content=response_data, - headers=response.headers, - status_code=response.status_code, - ) + return response - except Exception as e: - log.debug(f"Error occurred while processing request: {e}") - pass - return response - else: - if events and isinstance(events, list) and isinstance(response, dict): - extra_response = {} - for event in events: - if isinstance(event, dict): - extra_response.update(event) - else: - extra_response[event] = True +async def streaming_chat_response_handler(response, ctx): + request = ctx["request"] - response = { - **extra_response, - **response, - } + form_data = ctx["form_data"] - return response + user = ctx["user"] + model = ctx["model"] - # Non standard response - if not any( - content_type in response.headers["Content-Type"] - for content_type in 
["text/event-stream", "application/x-ndjson"] - ): - return response + metadata = ctx["metadata"] + events = ctx["events"] - oauth_token = None - try: - if request.cookies.get("oauth_session_id", None): - oauth_token = await request.app.state.oauth_manager.get_oauth_token( - user.id, - request.cookies.get("oauth_session_id", None), - ) - except Exception as e: - log.error(f"Error getting OAuth token: {e}") + event_emitter = ctx["event_emitter"] + event_caller = ctx["event_caller"] extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_caller, "__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__metadata__": metadata, - "__oauth_token__": oauth_token, + "__oauth_token__": await get_system_oauth_token(request, user), "__request__": request, "__model__": model, } + filter_functions = [ Functions.get_function_by_id(filter_id) for filter_id in get_sorted_filter_ids( @@ -2325,224 +2923,61 @@ async def background_tasks_handler(): ) ] - # Streaming response + # Standard streaming response handler if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. 
model_id = form_data.get("model", "") - def split_content_and_whitespace(content): - content_stripped = content.rstrip() - original_whitespace = ( - content[len(content_stripped) :] - if len(content) > len(content_stripped) - else "" - ) - return content_stripped, original_whitespace - - def is_opening_code_block(content): - backtick_segments = content.split("```") - # Even number of segments means the last backticks are opening a new block - return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 - # Handle as a background task async def response_handler(response, events): - def serialize_content_blocks(content_blocks, raw=False): - content = "" - - for block in content_blocks: - if block["type"] == "text": - block_content = block["content"].strip() - if block_content: - content = f"{content}{block_content}\n" - elif block["type"] == "tool_calls": - attributes = block.get("attributes", {}) - - tool_calls = block.get("content", []) - results = block.get("results", []) - - if content and not content.endswith("\n"): - content += "\n" - - if results: - - tool_calls_display_content = "" - for tool_call in tool_calls: - - tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("function", {}).get( - "name", "" - ) - tool_arguments = tool_call.get("function", {}).get( - "arguments", "" - ) - - tool_result = None - tool_result_files = None - for result in results: - if tool_call_id == result.get("tool_call_id", ""): - tool_result = result.get("content", None) - tool_result_files = result.get("files", None) - break - - if tool_result is not None: - tool_result_embeds = result.get("embeds", "") - tool_calls_display_content = f'{tool_calls_display_content}
\nTool Executed\n
\n' - else: - tool_calls_display_content = f'{tool_calls_display_content}
\nExecuting...\n
\n' - - if not raw: - content = f"{content}{tool_calls_display_content}" - else: - tool_calls_display_content = "" - - for tool_call in tool_calls: - tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("function", {}).get( - "name", "" - ) - tool_arguments = tool_call.get("function", {}).get( - "arguments", "" - ) - - tool_calls_display_content = f'{tool_calls_display_content}\n
\nExecuting...\n
\n' - - if not raw: - content = f"{content}{tool_calls_display_content}" - - elif block["type"] == "reasoning": - reasoning_display_content = html.escape( - "\n".join( - (f"> {line}" if not line.startswith(">") else line) - for line in block["content"].splitlines() - ) - ) - - reasoning_duration = block.get("duration", None) - - start_tag = block.get("start_tag", "") - end_tag = block.get("end_tag", "") - - if content and not content.endswith("\n"): - content += "\n" - - if reasoning_duration is not None: - if raw: - content = ( - f'{content}{start_tag}{block["content"]}{end_tag}\n' - ) - else: - content = f'{content}
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n' - else: - if raw: - content = ( - f'{content}{start_tag}{block["content"]}{end_tag}\n' - ) - else: - content = f'{content}
\nThinking…\n{reasoning_display_content}\n
\n' - - elif block["type"] == "code_interpreter": - attributes = block.get("attributes", {}) - output = block.get("output", None) - lang = attributes.get("lang", "") - - content_stripped, original_whitespace = ( - split_content_and_whitespace(content) - ) - if is_opening_code_block(content_stripped): - # Remove trailing backticks that would open a new block - content = ( - content_stripped.rstrip("`").rstrip() - + original_whitespace - ) - else: - # Keep content as is - either closing backticks or no backticks - content = content_stripped + original_whitespace - - if content and not content.endswith("\n"): - content += "\n" - - if output: - output = html.escape(json.dumps(output)) - - if raw: - content = f'{content}\n{block["content"]}\n\n```output\n{output}\n```\n' - else: - content = f'{content}
\nAnalyzed\n```{lang}\n{block["content"]}\n```\n
\n' - else: - if raw: - content = f'{content}\n{block["content"]}\n\n' - else: - content = f'{content}
\nAnalyzing...\n```{lang}\n{block["content"]}\n```\n
\n' - - else: - block_content = str(block["content"]).strip() - if block_content: - content = f"{content}{block['type']}: {block_content}\n" - - return content.strip() - - def convert_content_blocks_to_messages(content_blocks, raw=False): - messages = [] - - temp_blocks = [] - for idx, block in enumerate(content_blocks): - if block["type"] == "tool_calls": - messages.append( - { - "role": "assistant", - "content": serialize_content_blocks(temp_blocks, raw), - "tool_calls": block.get("content"), - } - ) - - results = block.get("results", []) - - for result in results: - messages.append( - { - "role": "tool", - "tool_call_id": result["tool_call_id"], - "content": result.get("content", "") or "", - } - ) - temp_blocks = [] - else: - temp_blocks.append(block) - - if temp_blocks: - content = serialize_content_blocks(temp_blocks, raw) - if content: - messages.append( - { - "role": "assistant", - "content": content, - } - ) - - return messages - - def tag_content_handler(content_type, tags, content, content_blocks): + def tag_output_handler(content_type, tags, content, output): + """ + Detect special tags (reasoning, solution, code_interpreter) in streaming + content and create corresponding OR-aligned output items directly. + Operates on output items instead of content_blocks. 
+ """ end_flag = False def extract_attributes(tag_content): """Extract attributes from a tag if they exist.""" attributes = {} - if not tag_content: # Ensure tag_content is not None + if not tag_content: return attributes - # Match attributes in the format: key="value" (ignores single quotes for simplicity) matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content) for key, value in matches: attributes[key] = value return attributes - if content_blocks[-1]["type"] == "text": + def get_last_text(out): + """Get text from last message item, or empty string.""" + if out and out[-1].get("type") == "message": + parts = out[-1].get("content", []) + if parts and parts[-1].get("type") == "output_text": + return parts[-1].get("text", "") + return "" + + def set_last_text(out, text): + """Set text on last message item's output_text.""" + if out and out[-1].get("type") == "message": + parts = out[-1].get("content", []) + if parts and parts[-1].get("type") == "output_text": + parts[-1]["text"] = text + + # Map content_type to output item type + output_type_map = { + "reasoning": "reasoning", + "solution": "message", # solution tags just produce text + "code_interpreter": "open_webui:code_interpreter", + } + output_item_type = output_type_map.get(content_type, content_type) + + last_type = output[-1].get("type", "") if output else "" + + if last_type == "message": for start_tag, end_tag in tags: start_tag_pattern = rf"{re.escape(start_tag)}" if start_tag.startswith("<") and start_tag.endswith(">"): - # Match start tag e.g., or - # remove both '<' and '>' from start_tag - # Match start tag with attributes start_tag_pattern = ( rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>" ) @@ -2550,72 +2985,136 @@ def extract_attributes(tag_content): match = re.search(start_tag_pattern, content) if match: try: - attr_content = ( - match.group(1) if match.group(1) else "" - ) # Ensure it's not None + attr_content = match.group(1) if match.group(1) else "" except: attr_content = "" - attributes = 
extract_attributes( - attr_content - ) # Extract attributes safely + attributes = extract_attributes(attr_content) - # Capture everything before and after the matched tag - before_tag = content[ - : match.start() - ] # Content before opening tag - after_tag = content[ - match.end() : - ] # Content after opening tag + before_tag = content[: match.start()] + after_tag = content[match.end() :] - # Remove the start tag and after from the currently handling text block - content_blocks[-1]["content"] = content_blocks[-1][ - "content" - ].replace(match.group(0) + after_tag, "") + # Remove the start tag and everything after from last message + current_text = get_last_text(output) + set_last_text( + output, + current_text.replace(match.group(0) + after_tag, ""), + ) if before_tag: - content_blocks[-1]["content"] = before_tag + set_last_text(output, before_tag) - if not content_blocks[-1]["content"]: - content_blocks.pop() + if not get_last_text(output).strip(): + # Remove empty message item + if output and output[-1].get("type") == "message": + output.pop() - # Append the new block - content_blocks.append( - { - "type": content_type, - "start_tag": start_tag, - "end_tag": end_tag, - "attributes": attributes, - "content": "", - "started_at": time.time(), - } - ) + # Append the new output item + if output_item_type == "reasoning": + output.append( + { + "type": "reasoning", + "id": output_id("r"), + "status": "in_progress", + "start_tag": start_tag, + "end_tag": end_tag, + "attributes": attributes, + "content": [], + "summary": None, + "started_at": time.time(), + } + ) + elif output_item_type == "open_webui:code_interpreter": + output.append( + { + "type": "open_webui:code_interpreter", + "id": output_id("ci"), + "status": "in_progress", + "start_tag": start_tag, + "end_tag": end_tag, + "attributes": attributes, + "lang": attributes.get("lang", "python"), + "code": "", + "output": None, + "started_at": time.time(), + } + ) + else: + # solution or other text-producing tag + 
output.append( + { + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + {"type": "output_text", "text": ""} + ], + "_tag_type": content_type, + "start_tag": start_tag, + "end_tag": end_tag, + "attributes": attributes, + "started_at": time.time(), + } + ) if after_tag: - content_blocks[-1]["content"] = after_tag - tag_content_handler( - content_type, tags, after_tag, content_blocks + # Set the after_tag content on the new item + if output_item_type == "reasoning": + output[-1]["content"] = [ + {"type": "output_text", "text": after_tag} + ] + elif output_item_type == "open_webui:code_interpreter": + output[-1]["code"] = after_tag + else: + set_last_text(output, after_tag) + + tag_output_handler( + content_type, tags, after_tag, output ) break - elif content_blocks[-1]["type"] == content_type: - start_tag = content_blocks[-1]["start_tag"] - end_tag = content_blocks[-1]["end_tag"] + + elif ( + (last_type == "reasoning" and content_type == "reasoning") + or ( + last_type == "open_webui:code_interpreter" + and content_type == "code_interpreter" + ) + or ( + last_type == "message" + and output[-1].get("_tag_type") == content_type + ) + ): + item = output[-1] + start_tag = item.get("start_tag", "") + end_tag = item.get("end_tag", "") if end_tag.startswith("<") and end_tag.endswith(">"): - # Match end tag e.g., end_tag_pattern = rf"{re.escape(end_tag)}" else: - # Handle cases where end_tag is just a tag name end_tag_pattern = rf"{re.escape(end_tag)}" - # Check if the content has the end tag if re.search(end_tag_pattern, content): end_flag = True - block_content = content_blocks[-1]["content"] - # Strip start and end tags from the content - start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>" + # Get the block content + if last_type == "reasoning": + parts = item.get("content", []) + block_content = "" + if parts and parts[-1].get("type") == "output_text": + block_content = parts[-1].get("text", "") + elif 
last_type == "open_webui:code_interpreter": + block_content = item.get("code", "") + else: + block_content = get_last_text(output) + + # Strip start and end tags from content + start_tag_pattern = rf"{re.escape(start_tag)}" + if start_tag.startswith("<") and start_tag.endswith(">"): + start_tag_pattern = ( + rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>" + ) block_content = re.sub( start_tag_pattern, "", block_content ).strip() @@ -2623,79 +3122,96 @@ def extract_attributes(tag_content): end_tag_regex = re.compile(end_tag_pattern, re.DOTALL) split_content = end_tag_regex.split(block_content, maxsplit=1) - # Content inside the tag block_content = ( split_content[0].strip() if split_content else "" ) - - # Leftover content (everything after ``) leftover_content = ( split_content[1].strip() if len(split_content) > 1 else "" ) if block_content: - content_blocks[-1]["content"] = block_content - content_blocks[-1]["ended_at"] = time.time() - content_blocks[-1]["duration"] = int( - content_blocks[-1]["ended_at"] - - content_blocks[-1]["started_at"] - ) + # Update the item with final content + if last_type == "reasoning": + item["content"] = [ + {"type": "output_text", "text": block_content} + ] + item["ended_at"] = time.time() + item["duration"] = int( + item["ended_at"] - item["started_at"] + ) + item["status"] = "completed" + elif last_type == "open_webui:code_interpreter": + item["code"] = block_content + item["ended_at"] = time.time() + item["duration"] = int( + item["ended_at"] - item["started_at"] + ) + else: + set_last_text(output, block_content) + item["ended_at"] = time.time() - # Reset the content_blocks by appending a new text block + # Reset by appending a new message item for leftover if content_type != "code_interpreter": - if leftover_content: - - content_blocks.append( - { - "type": "text", - "content": leftover_content, - } - ) - else: - content_blocks.append( - { - "type": "text", - "content": "", - } - ) - - else: - # Remove the block if content is empty - 
content_blocks.pop() - - if leftover_content: - content_blocks.append( + output.append( { - "type": "text", - "content": leftover_content, + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": leftover_content, + } + ], } ) else: - content_blocks.append( + output.append( { - "type": "text", - "content": "", + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": leftover_content, + } + ], } ) + else: + # Remove the block if content is empty + output.pop() + output.append( + { + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": leftover_content, + } + ], + } + ) # Clean processed content - start_tag_pattern = rf"{re.escape(start_tag)}" + start_tag_clean = rf"{re.escape(start_tag)}" if start_tag.startswith("<") and start_tag.endswith(">"): - # Match start tag e.g., or - # remove both '<' and '>' from start_tag - # Match start tag with attributes - start_tag_pattern = ( - rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>" - ) + start_tag_clean = rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>" content = re.sub( - rf"{start_tag_pattern}(.|\n)*?{re.escape(end_tag)}", + rf"{start_tag_clean}(.|\n)*?{re.escape(end_tag)}", "", content, flags=re.DOTALL, ) - return content, content_blocks, end_flag + return content, output, end_flag message = Chats.get_message_by_id_and_message_id( metadata["chat_id"], metadata["message_id"] @@ -2718,12 +3234,26 @@ def extract_attributes(tag_content): else last_assistant_message if last_assistant_message else "" ) - content_blocks = [ - { - "type": "text", - "content": content, - } - ] + # Initialize output: use existing from message if continuing, else create new + existing_output = message.get("output") if message else None + if existing_output: + output = 
existing_output + else: + # Only create an initial message item if there is content to initialize with + if content: + output = [ + { + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [{"type": "output_text", "text": content}], + } + ] + else: + output = [] + + usage = None reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags") DETECT_REASONING_TAGS = reasoning_tags_param is not False @@ -2763,7 +3293,8 @@ def extract_attributes(tag_content): async def stream_body_handler(response, form_data): nonlocal content - nonlocal content_blocks + nonlocal usage + nonlocal output response_tool_calls = [] @@ -2842,13 +3373,41 @@ async def flush_pending_delta_data(threshold: int = 0): "data": data, } ) + # Check for Responses API events (type field starts with "response.") + elif data.get("type", "").startswith("response."): + output, response_metadata = ( + handle_responses_streaming_event(data, output) + ) + + processed_data = { + "output": output, + "content": serialize_output(output), + } + + # print(data) + # print(processed_data) + + # Merge any metadata (usage, done, etc.) 
+ if response_metadata: + processed_data.update(response_metadata) + + await event_emitter( + { + "type": "chat:completion", + "data": processed_data, + } + ) + continue else: choices = data.get("choices", []) - # 17421 - usage = data.get("usage", {}) or {} - usage.update(data.get("timings", {})) # llama.cpp - if usage: + # Normalize usage data to standard format + raw_usage = data.get("usage", {}) or {} + raw_usage.update( + data.get("timings", {}) + ) # llama.cpp + if raw_usage: + usage = normalize_usage(raw_usage) await event_emitter( { "type": "chat:completion", @@ -2971,19 +3530,31 @@ async def flush_pending_delta_data(threshold: int = 0): # Flush any pending text first await flush_pending_delta_data() - pending_content_blocks = content_blocks + [ - { - "type": "tool_calls", - "content": response_tool_calls, - "pending": True, - } - ] + # Build pending function_call output items for display + pending_fc_items = [] + for tc in response_tool_calls: + call_id = tc.get("id", "") + func = tc.get("function", {}) + pending_fc_items.append( + { + "type": "function_call", + "id": call_id + or output_id("fc"), + "call_id": call_id, + "name": func.get("name", ""), + "arguments": func.get( + "arguments", "{}" + ), + "status": "in_progress", + } + ) + pending_output = output + pending_fc_items await event_emitter( { "type": "chat:completion", "data": { - "content": serialize_content_blocks( - pending_content_blocks + "content": serialize_output( + pending_output ), }, } @@ -3018,52 +3589,72 @@ async def flush_pending_delta_data(threshold: int = 0): ) if reasoning_content: if ( - not content_blocks - or content_blocks[-1]["type"] != "reasoning" + not output + or output[-1].get("type") != "reasoning" ): - reasoning_block = { + reasoning_item = { "type": "reasoning", + "id": output_id("r"), + "status": "in_progress", "start_tag": "", "end_tag": "", "attributes": { "type": "reasoning_content" }, - "content": "", + "content": [], + "summary": None, "started_at": time.time(), } 
- content_blocks.append(reasoning_block) + output.append(reasoning_item) else: - reasoning_block = content_blocks[-1] + reasoning_item = output[-1] - reasoning_block["content"] += reasoning_content + # Append to reasoning content + parts = reasoning_item.get("content", []) + if ( + parts + and parts[-1].get("type") == "output_text" + ): + parts[-1]["text"] += reasoning_content + else: + reasoning_item["content"] = [ + { + "type": "output_text", + "text": reasoning_content, + } + ] - data = { - "content": serialize_content_blocks( - content_blocks - ) - } + data = {"content": serialize_output(output)} if value: if ( - content_blocks - and content_blocks[-1]["type"] - == "reasoning" - and content_blocks[-1] + output + and output[-1].get("type") == "reasoning" + and output[-1] .get("attributes", {}) .get("type") == "reasoning_content" ): - reasoning_block = content_blocks[-1] - reasoning_block["ended_at"] = time.time() - reasoning_block["duration"] = int( - reasoning_block["ended_at"] - - reasoning_block["started_at"] + reasoning_item = output[-1] + reasoning_item["ended_at"] = time.time() + reasoning_item["duration"] = int( + reasoning_item["ended_at"] + - reasoning_item["started_at"] ) + reasoning_item["status"] = "completed" - content_blocks.append( + output.append( { - "type": "text", - "content": "", + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "", + } + ], } ) @@ -3083,45 +3674,59 @@ async def flush_pending_delta_data(threshold: int = 0): ) content = f"{content}{value}" - if not content_blocks: - content_blocks.append( + if ( + not output + or output[-1].get("type") != "message" + ): + output.append( { - "type": "text", - "content": "", + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "", + } + ], } ) - content_blocks[-1]["content"] = ( - 
content_blocks[-1]["content"] + value - ) + # Append value to last message item's text + msg_parts = output[-1].get("content", []) + if ( + msg_parts + and msg_parts[-1].get("type") + == "output_text" + ): + msg_parts[-1]["text"] += value + else: + output[-1]["content"] = [ + {"type": "output_text", "text": value} + ] if DETECT_REASONING_TAGS: - content, content_blocks, _ = ( - tag_content_handler( - "reasoning", - reasoning_tags, - content, - content_blocks, - ) + content, output, _ = tag_output_handler( + "reasoning", + reasoning_tags, + content, + output, ) - content, content_blocks, _ = ( - tag_content_handler( - "solution", - DEFAULT_SOLUTION_TAGS, - content, - content_blocks, - ) + content, output, _ = tag_output_handler( + "solution", + DEFAULT_SOLUTION_TAGS, + content, + output, ) if DETECT_CODE_INTERPRETER: - content, content_blocks, end = ( - tag_content_handler( - "code_interpreter", - DEFAULT_CODE_INTERPRETER_TAGS, - content, - content_blocks, - ) + content, output, end = tag_output_handler( + "code_interpreter", + DEFAULT_CODE_INTERPRETER_TAGS, + content, + output, ) if end: @@ -3133,16 +3738,13 @@ async def flush_pending_delta_data(threshold: int = 0): metadata["chat_id"], metadata["message_id"], { - "content": serialize_content_blocks( - content_blocks - ), + "content": serialize_output(output), + "output": output, }, ) else: data = { - "content": serialize_content_blocks( - content_blocks - ), + "content": serialize_output(output), } if delta: @@ -3166,32 +3768,38 @@ async def flush_pending_delta_data(threshold: int = 0): continue await flush_pending_delta_data() - if content_blocks: - # Clean up the last text block - if content_blocks[-1]["type"] == "text": - content_blocks[-1]["content"] = content_blocks[-1][ - "content" - ].strip() + if output: + # Clean up the last message item + if output[-1].get("type") == "message": + parts = output[-1].get("content", []) + if parts and parts[-1].get("type") == "output_text": + parts[-1]["text"] = 
parts[-1]["text"].strip() - if not content_blocks[-1]["content"]: - content_blocks.pop() + if not parts[-1]["text"]: + output.pop() - if not content_blocks: - content_blocks.append( - { - "type": "text", - "content": "", - } - ) + if not output: + output.append( + { + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [ + {"type": "output_text", "text": ""} + ], + } + ) - if content_blocks[-1]["type"] == "reasoning": - reasoning_block = content_blocks[-1] - if reasoning_block.get("ended_at") is None: - reasoning_block["ended_at"] = time.time() - reasoning_block["duration"] = int( - reasoning_block["ended_at"] - - reasoning_block["started_at"] + if output[-1].get("type") == "reasoning": + reasoning_item = output[-1] + if reasoning_item.get("ended_at") is None: + reasoning_item["ended_at"] = time.time() + reasoning_item["duration"] = int( + reasoning_item["ended_at"] + - reasoning_item["started_at"] ) + reasoning_item["status"] = "completed" if response_tool_calls: tool_calls.append(response_tool_calls) @@ -3213,18 +3821,27 @@ async def flush_pending_delta_data(threshold: int = 0): response_tool_calls = tool_calls.pop(0) - content_blocks.append( - { - "type": "tool_calls", - "content": response_tool_calls, - } - ) + # Append function_call items for each tool call + for tc in response_tool_calls: + call_id = tc.get("id", "") + func = tc.get("function", {}) + output.append( + { + "type": "function_call", + "id": call_id or output_id("fc"), + "call_id": call_id, + "name": func.get("name", ""), + "arguments": func.get("arguments", "{}"), + "status": "in_progress", + } + ) await event_emitter( { "type": "chat:completion", "data": { - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, }, } ) @@ -3254,11 +3871,7 @@ async def flush_pending_delta_data(threshold: int = 0): f"Error parsing tool call arguments: {tool_args}" ) - # Mutate the original tool call 
response params as they are passed back to the passed - # back to the LLM via the content blocks. If they are in a json block and are invalid json, - # this can cause downstream LLM integrations to fail (e.g. bedrock gateway) where response - # params are not valid json. - # Main case so far is no args = "" = invalid json. + # Ensure arguments are valid JSON for downstream LLM integrations log.debug( f"Parsed args from {tool_args} to {tool_function_params}" ) @@ -3375,11 +3988,56 @@ async def flush_pending_delta_data(threshold: int = 0): } ) - content_blocks[-1]["results"] = results - content_blocks.append( + # Update function_call statuses and append function_call_output items + for tc in response_tool_calls: + call_id = tc.get("id", "") + # Mark function_call as completed + for item in output: + if ( + item.get("type") == "function_call" + and item.get("call_id") == call_id + ): + item["status"] = "completed" + # Update arguments with parsed/sanitized version + item["arguments"] = tc.get("function", {}).get( + "arguments", "{}" + ) + break + + for result in results: + output.append( + { + "type": "function_call_output", + "id": output_id("fco"), + "call_id": result.get("tool_call_id", ""), + "output": [ + { + "type": "input_text", + "text": result.get("content", ""), + } + ], + "status": "completed", + **( + {"files": result.get("files")} + if result.get("files") + else {} + ), + **( + {"embeds": result.get("embeds")} + if result.get("embeds") + else {} + ), + } + ) + + # Append a new empty message item for the next response + output.append( { - "type": "text", - "content": "", + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [{"type": "output_text", "text": ""}], } ) @@ -3403,7 +4061,8 @@ async def flush_pending_delta_data(threshold: int = 0): { "type": "chat:completion", "data": { - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, }, } ) 
@@ -3415,9 +4074,7 @@ async def flush_pending_delta_data(threshold: int = 0): "stream": True, "messages": [ *form_data["messages"], - *convert_content_blocks_to_messages( - content_blocks, True - ), + *convert_output_to_messages(output, raw=True), ], } @@ -3441,7 +4098,8 @@ async def flush_pending_delta_data(threshold: int = 0): retries = 0 while ( - content_blocks[-1]["type"] == "code_interpreter" + output + and output[-1].get("type") == "open_webui:code_interpreter" and retries < MAX_RETRIES ): @@ -3449,7 +4107,8 @@ async def flush_pending_delta_data(threshold: int = 0): { "type": "chat:completion", "data": { - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, }, } ) @@ -3457,17 +4116,20 @@ async def flush_pending_delta_data(threshold: int = 0): retries += 1 log.debug(f"Attempt count: {retries}") - output = "" + ci_item = output[-1] + ci_output = "" try: - if content_blocks[-1]["attributes"].get("type") == "code": - code = content_blocks[-1]["content"] + if ci_item.get("attributes", {}).get("type") == "code": + code = ci_item.get("code", "") + # Sanitize code (strips ANSI codes and markdown fences) + code = sanitize_code(code) + if CODE_INTERPRETER_BLOCKED_MODULES: - blocking_code = textwrap.dedent( - f""" + blocking_code = textwrap.dedent(f""" import builtins - + BLOCKED_MODULES = {CODE_INTERPRETER_BLOCKED_MODULES} - + _real_import = builtins.__import__ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): if name.split('.')[0] in BLOCKED_MODULES: @@ -3477,17 +4139,16 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): f"Direct import of module {{name}} is restricted." 
) return _real_import(name, globals, locals, fromlist, level) - + builtins.__import__ = restricted_import - """ - ) + """) code = blocking_code + "\n" + code if ( request.app.state.config.CODE_INTERPRETER_ENGINE == "pyodide" ): - output = await event_caller( + ci_output = await event_caller( { "type": "execute:python", "data": { @@ -3503,7 +4164,7 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): request.app.state.config.CODE_INTERPRETER_ENGINE == "jupyter" ): - output = await execute_code_jupyter( + ci_output = await execute_code_jupyter( request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, code, ( @@ -3521,14 +4182,14 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, ) else: - output = { + ci_output = { "stdout": "Code interpreter engine not configured." } - log.debug(f"Code interpreter output: {output}") + log.debug(f"Code interpreter output: {ci_output}") - if isinstance(output, dict): - stdout = output.get("stdout", "") + if isinstance(ci_output, dict): + stdout = ci_output.get("stdout", "") if isinstance(stdout, str): stdoutLines = stdout.split("\n") @@ -3546,9 +4207,9 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): f"![Output Image]({image_url})" ) - output["stdout"] = "\n".join(stdoutLines) + ci_output["stdout"] = "\n".join(stdoutLines) - result = output.get("result", "") + result = ci_output.get("result", "") if isinstance(result, str): resultLines = result.split("\n") @@ -3563,16 +4224,20 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): resultLines[idx] = ( f"![Output Image]({image_url})" ) - output["result"] = "\n".join(resultLines) + ci_output["result"] = "\n".join(resultLines) except Exception as e: - output = str(e) + ci_output = str(e) - content_blocks[-1]["output"] = output + ci_item["output"] = ci_output + ci_item["status"] = "completed" - content_blocks.append( + 
output.append( { - "type": "text", - "content": "", + "type": "message", + "id": output_id("msg"), + "status": "in_progress", + "role": "assistant", + "content": [{"type": "output_text", "text": ""}], } ) @@ -3580,7 +4245,8 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): { "type": "chat:completion", "data": { - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, }, } ) @@ -3592,12 +4258,7 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): "stream": True, "messages": [ *form_data["messages"], - { - "role": "assistant", - "content": serialize_content_blocks( - content_blocks, raw=True - ), - }, + *convert_output_to_messages(output, raw=True), ], } @@ -3616,10 +4277,16 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): log.debug(e) break + # Mark all in-progress items as completed + for item in output: + if item.get("status") == "in_progress": + item["status"] = "completed" + title = Chats.get_chat_title_by_id(metadata["chat_id"]) data = { "done": True, - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, "title": title, } @@ -3629,9 +4296,17 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): metadata["chat_id"], metadata["message_id"], { - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, + **({"usage": usage} if usage else {}), }, ) + elif usage: + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + {"usage": usage}, + ) # Send a webhook notification if the user is not active if not Users.is_user_active(user.id): @@ -3656,7 +4331,7 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): } ) - await background_tasks_handler() + await background_tasks_handler(ctx) except asyncio.CancelledError: 
log.warning("Task was cancelled!") await event_emitter({"type": "chat:tasks:cancel"}) @@ -3667,7 +4342,8 @@ def restricted_import(name, globals=None, locals=None, fromlist=(), level=0): metadata["chat_id"], metadata["message_id"], { - "content": serialize_content_blocks(content_blocks), + "content": serialize_output(output), + "output": output, }, ) @@ -3711,3 +4387,19 @@ def wrap_item(item): headers=dict(response.headers), background=response.background, ) + + +async def process_chat_response(response, ctx): + # Non-streaming response + if not isinstance(response, StreamingResponse): + return await non_streaming_chat_response_handler(response, ctx) + + # Non standard response + if not any( + content_type in response.headers["Content-Type"] + for content_type in ["text/event-stream", "application/x-ndjson"] + ): + return response + + # Streaming response + return await streaming_chat_response_handler(response, ctx) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 7b98209734..fbed41420c 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -11,7 +11,6 @@ import aiohttp import mimeparse - import collections.abc from open_webui.env import CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE @@ -128,6 +127,138 @@ def get_content_from_message(message: dict) -> Optional[str]: return None +def convert_output_to_messages(output: list, raw: bool = False) -> list[dict]: + """ + Convert OR-aligned output items to OpenAI Chat Completion-format messages. + + This reconstructs the full conversation from the stored Responses API-native + output items, including assistant messages with tool_calls arrays and tool + role messages. + + Args: + output: List of OR-aligned output items (Responses API format). + raw: If True, include reasoning blocks (with original tags) and code + interpreter blocks for LLM re-processing follow-ups. 
+ """ + if not output or not isinstance(output, list): + return [] + + messages = [] + pending_tool_calls = [] + pending_content = [] + + def flush_pending(): + nonlocal pending_content, pending_tool_calls + if pending_content or pending_tool_calls: + messages.append( + { + "role": "assistant", + "content": "\n".join(pending_content) if pending_content else "", + **( + {"tool_calls": pending_tool_calls} if pending_tool_calls else {} + ), + } + ) + pending_content = [] + pending_tool_calls = [] + + for item in output: + item_type = item.get("type", "") + + if item_type == "message": + # Extract text from output_text content parts + content_parts = item.get("content", []) + text = "" + for part in content_parts: + if part.get("type") == "output_text": + text += part.get("text", "") + if text: + pending_content.append(text) + + elif item_type == "function_call": + # Collect tool calls to batch into assistant message + arguments = item.get("arguments", "{}") + # Ensure arguments is always a JSON string + if not isinstance(arguments, str): + arguments = json.dumps(arguments) + pending_tool_calls.append( + { + "id": item.get("call_id", ""), + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": arguments, + }, + } + ) + + elif item_type == "function_call_output": + # Flush any pending content/tool_calls before adding tool result + flush_pending() + + # Extract text from output content parts + output_parts = item.get("output", []) + content = "" + for part in output_parts: + if part.get("type") == "input_text": + content += part.get("text", "") + + messages.append( + { + "role": "tool", + "tool_call_id": item.get("call_id", ""), + "content": content, + } + ) + + elif item_type == "reasoning": + if raw: + # Include reasoning with original tags for LLM re-processing + reasoning_text = "" + source_list = item.get("summary", []) or item.get("content", []) + for part in source_list: + if part.get("type") == "output_text": + reasoning_text += 
part.get("text", "") + elif "text" in part: + reasoning_text += part.get("text", "") + + if reasoning_text: + start_tag = item.get("start_tag", "") + end_tag = item.get("end_tag", "") + pending_content.append(f"{start_tag}{reasoning_text}{end_tag}") + # else: skip reasoning blocks for normal LLM messages + + elif item_type == "open_webui:code_interpreter": + if raw: + # Include code interpreter content for LLM re-processing + code = item.get("code", "") + code_output = item.get("output", "") + + if code: + lang = item.get("lang", "python") + pending_content.append(f"```{lang}\n{code}\n```") + + if code_output: + if isinstance(code_output, dict): + stdout = code_output.get("stdout", "") + result = code_output.get("result", "") + output_text = stdout or result + else: + output_text = str(code_output) + if output_text: + pending_content.append(f"Output:\n{output_text}") + # else: skip extension types + + elif item_type.startswith("open_webui:"): + # Skip other extension types + pass + + # Flush remaining content/tool_calls + flush_pending() + + return messages + + def get_last_user_message(messages: list[dict]) -> Optional[str]: message = get_last_user_message_item(messages) if message is None: @@ -650,105 +781,109 @@ def extract_urls(text: str) -> list[str]: return url_pattern.findall(text) -def stream_chunks_handler( - user: "UserModel", model_id: str, form_data: dict, stream: aiohttp.StreamReader +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + +async def stream_wrapper( + user, model_id, form_data, response, session, content_handler=None ): + """ + Wrap a stream to ensure cleanup happens even if streaming is interrupted. + This is more reliable than BackgroundTask which may not run if client disconnects. 
+ """ + from open_webui.utils.credit.usage import CreditDeduct + + try: + stream = ( + content_handler(response.content) if content_handler else response.content + ) + with CreditDeduct( + user=user, + model_id=model_id, + body=form_data, + is_stream=True, + ) as credit_deduct: + # change to avoid multi \n\n cause message lose + async for chunk in stream: + credit_deduct.run(response=chunk) + yield chunk + + yield credit_deduct.usage_message + finally: + await cleanup_response(response, session) + + +def stream_chunks_handler(stream: aiohttp.StreamReader): """ Handle stream response chunks, supporting large data chunks that exceed the original 16kb limit. When a single line exceeds max_buffer_size, returns an empty JSON string {} and skips subsequent data until encountering normally sized data. - :param user: The user making the request. - :param model_id: The ID of the model being used. - :param form_data: The form data associated with the request. :param stream: The stream reader to handle. :return: An async generator that yields the stream data. 
""" - from open_webui.utils.credit.usage import CreditDeduct - max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE if max_buffer_size is None or max_buffer_size <= 0: - - async def consumer_content(stream: aiohttp.StreamReader): - with CreditDeduct( - user=user, - model_id=model_id, - body=form_data, - is_stream=True, - ) as credit_deduct: - # change to avoid multi \n\n cause message lose - async for chunk in stream: - credit_deduct.run(response=chunk) - yield chunk - - yield credit_deduct.usage_message - - return consumer_content(stream) + return stream async def yield_safe_stream_chunks(): buffer = b"" skip_mode = False - with CreditDeduct( - user=user, - model_id=model_id, - body=form_data, - is_stream=True, - ) as credit_deduct: - # change to avoid multi \n\n cause message lose - async for data in stream: - - if not data: - continue - - credit_deduct.run(response=data) + async for data, _ in stream.iter_chunks(): + if not data: + continue - # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line) - if skip_mode and len(buffer) > max_buffer_size: - buffer = b"" + # In skip_mode, if buffer already exceeds the limit, clear it (it's part of an oversized line) + if skip_mode and len(buffer) > max_buffer_size: + buffer = b"" - lines = (buffer + data).split(b"\n") + lines = (buffer + data).split(b"\n") - # Process complete lines (except the last possibly incomplete fragment) - for i in range(len(lines) - 1): - line = lines[i] + # Process complete lines (except the last possibly incomplete fragment) + for i in range(len(lines) - 1): + line = lines[i] - if skip_mode: - # Skip mode: check if current line is small enough to exit skip mode - if len(line) <= max_buffer_size: - skip_mode = False - yield line - else: - yield b"data: {}" - yield b"\n" + if skip_mode: + # Skip mode: check if current line is small enough to exit skip mode + if len(line) <= max_buffer_size: + skip_mode = False + yield line else: - # Normal mode: 
check if line exceeds limit - if len(line) > max_buffer_size: - skip_mode = True - yield b"data: {}" - yield b"\n" - log.info(f"Skip mode triggered, line size: {len(line)}") - else: - yield line - yield b"\n" - - # Save the last incomplete fragment - buffer = lines[-1] - - # Check if buffer exceeds limit - if not skip_mode and len(buffer) > max_buffer_size: - skip_mode = True - log.info(f"Skip mode triggered, buffer size: {len(buffer)}") - # Clear oversized buffer to prevent unlimited growth - buffer = b"" - - # Process remaining buffer data - if buffer and not skip_mode: - credit_deduct.run(response=buffer) - yield buffer - yield b"\n" - - yield credit_deduct.usage_message + yield b"data: {}" + yield b"\n" + else: + # Normal mode: check if line exceeds limit + if len(line) > max_buffer_size: + skip_mode = True + yield b"data: {}" + yield b"\n" + log.info(f"Skip mode triggered, line size: {len(line)}") + else: + yield line + yield b"\n" + + # Save the last incomplete fragment + buffer = lines[-1] + + # Check if buffer exceeds limit + if not skip_mode and len(buffer) > max_buffer_size: + skip_mode = True + log.info(f"Skip mode triggered, buffer size: {len(buffer)}") + # Clear oversized buffer to prevent unlimited growth + buffer = b"" + + # Process remaining buffer data + if buffer and not skip_mode: + yield buffer + yield b"\n" return yield_safe_stream_chunks() diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index b3a332adee..ff3a6e0caf 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -13,6 +13,7 @@ from open_webui.models.functions import Functions from open_webui.models.models import Models +from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups @@ -31,7 +32,6 @@ from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL from open_webui.models.users import UserModel - logging.basicConfig(stream=sys.stdout, 
level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) @@ -339,12 +339,12 @@ def get_function_module_by_id(function_id): def check_model_access(user, model, db=None): if model.get("arena"): + meta = model.get("info", {}).get("meta", {}) + access_grants = meta.get("access_grants", []) if not has_access( user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), + permission="read", + access_grants=access_grants, db=db, ): raise Exception("Model not found") @@ -354,8 +354,12 @@ def check_model_access(user, model, db=None): raise Exception("Model not found") elif not ( user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control, db=db + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", + db=db, ) ): raise Exception("Model not found") @@ -379,12 +383,12 @@ def get_filtered_models(models, user, db=None): } for model in models: if model.get("arena"): + meta = model.get("info", {}).get("meta", {}) + access_grants = meta.get("access_grants", []) if has_access( user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), + permission="read", + access_grants=access_grants, user_group_ids=user_group_ids, ): filtered_models.append(model) @@ -395,11 +399,13 @@ def get_filtered_models(models, user, db=None): if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or user.id == model_info.user_id - or has_access( - user.id, - type="read", - access_control=model_info.access_control, + or AccessGrants.has_access( + user_id=user.id, + resource_type="model", + resource_id=model_info.id, + permission="read", user_group_ids=user_group_ids, + db=db, ) ): filtered_models.append(model) diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 458687b371..318b8f8f88 100644 --- a/backend/open_webui/utils/payload.py +++ 
b/backend/open_webui/utils/payload.py @@ -287,7 +287,13 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict: Returns: dict: A modified payload compatible with the Ollama API. """ - openai_payload = copy.deepcopy(openai_payload) + # Shallow copy metadata separately (may contain non-picklable objects) + metadata = openai_payload.get("metadata") + openai_payload = copy.deepcopy( + {k: v for k, v in openai_payload.items() if k != "metadata"} + ) + if metadata is not None: + openai_payload["metadata"] = dict(metadata) ollama_payload = {} # Mapping basic model and message details @@ -391,3 +397,29 @@ def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict: ollama_payload[optional_key] = openai_payload[optional_key] return ollama_payload + + +def convert_embed_payload_openai_to_ollama(openai_payload: dict) -> dict: + """ + Convert an embeddings request payload from OpenAI format to Ollama's + /api/embed format, which supports batch input natively. + + Args: + openai_payload (dict): The original payload designed for OpenAI API usage. + Expected keys: "model", "input" (str or list[str]). + + Returns: + dict: A payload compatible with the Ollama /api/embed endpoint. 
+ """ + ollama_payload = {"model": openai_payload.get("model")} + input_value = openai_payload.get("input") + + # /api/embed accepts 'input' as a string or list of strings directly + ollama_payload["input"] = input_value + + # Optionally forward other fields if present + for optional_key in ("truncate", "options", "keep_alive"): + if optional_key in openai_payload: + ollama_payload[optional_key] = openai_payload[optional_key] + + return ollama_payload diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 965b0a688f..2dd49fb8ff 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -6,6 +6,7 @@ import types import tempfile import logging +from typing import Any from open_webui.env import PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS, OFFLINE_MODE from open_webui.models.functions import Functions @@ -14,6 +15,142 @@ log = logging.getLogger(__name__) +def resolve_valves_schema_options( + valves_class: type, schema: dict, user: Any = None +) -> dict: + """ + Resolve dynamic options in a Valves schema. 
+ + For properties with `input.options`, this function handles two cases: + - List: Used directly as dropdown options + - String: Treated as method name, called to get options dynamically + + Usage in Valves: + class UserValves(BaseModel): + # Static options + priority: str = Field( + default="medium", + json_schema_extra={ + "input": { + "type": "select", + "options": ["low", "medium", "high"] + } + } + ) + + # Dynamic options (method name) + model: str = Field( + default="", + json_schema_extra={ + "input": { + "type": "select", + "options": "get_model_options" + } + } + ) + + @classmethod + def get_model_options(cls, __user__=None) -> list[dict]: + return [{"value": "gpt-4", "label": "GPT-4"}] + + Args: + valves_class: The Valves or UserValves Pydantic model class + schema: The JSON schema dict from valves_class.schema() + user: Optional user object passed to methods that accept __user__ + + Returns: + Modified schema dict with resolved options + """ + if not schema or "properties" not in schema: + return schema + + # Make a copy to avoid mutating the original + schema = dict(schema) + schema["properties"] = dict(schema.get("properties", {})) + + for prop_name, prop_schema in list(schema["properties"].items()): + # Get the original field info from the Pydantic model + if not hasattr(valves_class, "model_fields"): + continue + + field_info = valves_class.model_fields.get(prop_name) + if not field_info: + continue + + # Check json_schema_extra for options + json_schema_extra = field_info.json_schema_extra + if not json_schema_extra or not isinstance(json_schema_extra, dict): + continue + + input_config = json_schema_extra.get("input") + if not input_config or not isinstance(input_config, dict): + continue + + options = input_config.get("options") + if options is None: + continue + + resolved_options = None + + # Case 1: options is already a list - use directly + if isinstance(options, list): + resolved_options = options + + # Case 2: options is a string - treat as 
method name + elif isinstance(options, str) and options: + method = getattr(valves_class, options, None) + if method is None or not callable(method): + log.warning( + f"options '{options}' not found or not callable on {valves_class.__name__}" + ) + continue + + try: + import inspect + + sig = inspect.signature(method) + params = sig.parameters + + # Prepare kwargs based on what the method accepts + kwargs = {} + if "__user__" in params and user is not None: + kwargs["__user__"] = ( + user.model_dump() if hasattr(user, "model_dump") else user + ) + if "user" in params and user is not None: + kwargs["user"] = ( + user.model_dump() if hasattr(user, "model_dump") else user + ) + + resolved_options = method(**kwargs) if kwargs else method() + + # Validate return type + if not isinstance(resolved_options, list): + log.warning( + f"Method '{options}' did not return a list for {prop_name}" + ) + continue + + except Exception as e: + log.warning(f"Failed to resolve options for {prop_name}: {e}") + continue + else: + # Invalid options type - skip + continue + + # Update the schema with resolved options + schema["properties"][prop_name] = dict(prop_schema) + if "input" not in schema["properties"][prop_name]: + schema["properties"][prop_name]["input"] = {"type": "select"} + else: + schema["properties"][prop_name]["input"] = dict( + schema["properties"][prop_name].get("input", {}) + ) + schema["properties"][prop_name]["input"]["options"] = resolved_options + + return schema + + def extract_frontmatter(content): """ Extract frontmatter as a dictionary from the provided content string. 
diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index 2040633a7b..fcc4879ba3 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -1,5 +1,7 @@ import inspect from urllib.parse import urlparse +import asyncio +import time import logging @@ -12,6 +14,7 @@ REDIS_SENTINEL_MAX_RETRY_COUNT, REDIS_SENTINEL_PORT, REDIS_URL, + REDIS_RECONNECT_DELAY, ) log = logging.getLogger(__name__) @@ -63,6 +66,8 @@ async def _iter(): i + 1, REDIS_SENTINEL_MAX_RETRY_COUNT, ) + if REDIS_RECONNECT_DELAY: + time.sleep(REDIS_RECONNECT_DELAY / 1000) continue log.error( "Redis operation failed after %s retries: %s", @@ -94,6 +99,8 @@ async def _wrapped(*args, **kwargs): i + 1, REDIS_SENTINEL_MAX_RETRY_COUNT, ) + if REDIS_RECONNECT_DELAY: + await asyncio.sleep(REDIS_RECONNECT_DELAY / 1000) continue log.error( "Redis operation failed after %s retries: %s", @@ -122,6 +129,8 @@ def _wrapped(*args, **kwargs): i + 1, REDIS_SENTINEL_MAX_RETRY_COUNT, ) + if REDIS_RECONNECT_DELAY: + time.sleep(REDIS_RECONNECT_DELAY / 1000) continue log.error( "Redis operation failed after %s retries: %s", diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index 3a3e1f84c2..b7174307a9 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -8,6 +8,47 @@ ) +def normalize_usage(usage: dict) -> dict: + """ + Normalize usage statistics to standard format. + Handles OpenAI, Ollama, and llama.cpp formats. 
+ + Adds standardized token fields to the original data: + - input_tokens: Number of tokens in the prompt + - output_tokens: Number of tokens generated + - total_tokens: Sum of input and output tokens + """ + if not usage: + return {} + + # Map various field names to standard names + input_tokens = ( + usage.get("input_tokens") # Already standard + or usage.get("prompt_tokens") # OpenAI + or usage.get("prompt_eval_count") # Ollama + or usage.get("prompt_n") # llama.cpp + or 0 + ) + + output_tokens = ( + usage.get("output_tokens") # Already standard + or usage.get("completion_tokens") # OpenAI + or usage.get("eval_count") # Ollama + or usage.get("predicted_n") # llama.cpp + or 0 + ) + + total_tokens = usage.get("total_tokens") or (input_tokens + output_tokens) + + # Add standardized fields to original data + result = dict(usage) + result["input_tokens"] = int(input_tokens) + result["output_tokens"] = int(output_tokens) + result["total_tokens"] = int(total_tokens) + + return result + + def convert_ollama_tool_call_to_openai(tool_calls: list) -> list: openai_tool_calls = [] for tool_call in tool_calls: @@ -26,7 +67,19 @@ def convert_ollama_tool_call_to_openai(tool_calls: list) -> list: def convert_ollama_usage_to_openai(data: dict) -> dict: + input_tokens = int(data.get("prompt_eval_count", 0)) + output_tokens = int(data.get("eval_count", 0)) + total_tokens = input_tokens + output_tokens + return { + # Standardized fields + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + # OpenAI-compatible fields (for backward compatibility) + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + # Ollama-specific metrics "response_token/s": ( round( ( @@ -58,22 +111,13 @@ def convert_ollama_usage_to_openai(data: dict) -> dict: "total_duration": data.get("total_duration", 0), "load_duration": data.get("load_duration", 0), "prompt_eval_count": data.get("prompt_eval_count", 0), - "prompt_tokens": int( - 
data.get("prompt_eval_count", 0) - ), # This is the OpenAI compatible key "prompt_eval_duration": data.get("prompt_eval_duration", 0), "eval_count": data.get("eval_count", 0), - "completion_tokens": int( - data.get("eval_count", 0) - ), # This is the OpenAI compatible key "eval_duration": data.get("eval_duration", 0), "approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")( (data.get("total_duration", 0) or 0) // 1_000_000_000 ), - "total_tokens": int( # This is the OpenAI compatible key - data.get("prompt_eval_count", 0) + data.get("eval_count", 0) - ), - "completion_tokens_details": { # This is the OpenAI compatible key + "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0, @@ -161,17 +205,29 @@ def convert_embedding_response_ollama_to_openai(response) -> dict: "model": "...", } """ - # Ollama batch-style output + # Ollama batch-style output from /api/embed + # Response format: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]], "model": "..."} if isinstance(response, dict) and "embeddings" in response: openai_data = [] for i, emb in enumerate(response["embeddings"]): - openai_data.append( - { - "object": "embedding", - "embedding": emb.get("embedding"), - "index": emb.get("index", i), - } - ) + # /api/embed returns embeddings as plain float lists + if isinstance(emb, list): + openai_data.append( + { + "object": "embedding", + "embedding": emb, + "index": i, + } + ) + # Also handle dict format for robustness + elif isinstance(emb, dict): + openai_data.append( + { + "object": "embedding", + "embedding": emb.get("embedding"), + "index": emb.get("index", i), + } + ) return { "object": "list", "data": openai_data, diff --git a/backend/open_webui/utils/sanitize.py b/backend/open_webui/utils/sanitize.py new file mode 100644 index 0000000000..258b6d78fb --- /dev/null +++ b/backend/open_webui/utils/sanitize.py @@ -0,0 +1,59 @@ +import re + +# ANSI escape code pattern - matches all 
common ANSI sequences +# This includes color codes, cursor movement, and other terminal control sequences +ANSI_ESCAPE_PATTERN = re.compile( + r"\x1b\[[0-9;]*[A-Za-z]|\x1b\([AB]|\x1b[PX^_].*?\x1b\\|\x1b\].*?(?:\x07|\x1b\\)" +) + + +def strip_ansi_codes(text: str) -> str: + """ + Strip ANSI escape codes from text. + + ANSI escape codes can be introduced by LLMs that include terminal + color codes in their output. These codes cause syntax errors when + the code is sent to Jupyter for execution. + + Common ANSI codes include: + - Color codes: \x1b[31m (red), \x1b[32m (green), etc. + - Reset codes: \x1b[0m, \x1b[39m + - Cursor movement: \x1b[1A, \x1b[2J, etc. + """ + return ANSI_ESCAPE_PATTERN.sub("", text) + + +def strip_markdown_code_fences(code: str) -> str: + """ + Strip markdown code fences if present. + + This is a defensive, non-breaking change — if the code doesn't + contain fences, it passes through unchanged. + + Handles patterns like: + - ```python + - ```py + - ``` + """ + code = code.strip() + # Remove opening fence (```python, ```py, ``` etc.) + code = re.sub(r"^```\w*\n?", "", code) + # Remove closing fence + code = re.sub(r"\n?```\s*$", "", code) + return code.strip() + + +def sanitize_code(code: str) -> str: + """ + Sanitize code for execution by applying all necessary cleanup steps. + + This is the recommended function to use before sending code to + interpreters like Jupyter or Pyodide. + + Steps applied: + 1. Strip ANSI escape codes (from LLM output) + 2. 
Strip markdown code fences (if model included them) + """ + code = strip_ansi_codes(code) + code = strip_markdown_code_fences(code) + return code diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index ecedd595a7..abc8920884 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -10,7 +10,6 @@ from open_webui.config import DEFAULT_RAG_TEMPLATE - log = logging.getLogger(__name__) @@ -69,6 +68,7 @@ def prompt_template(template: str, user: Optional[Any] = None) -> str: USER_VARIABLES = { "name": str(user.get("name")), + "email": str(user.get("email")), "location": str(user_info.get("location")), "bio": str(user.get("bio")), "gender": str(user.get("gender")), @@ -92,6 +92,9 @@ def prompt_template(template: str, user: Optional[Any] = None) -> str: template = template.replace("{{CURRENT_WEEKDAY}}", formatted_weekday) template = template.replace("{{USER_NAME}}", USER_VARIABLES.get("name", "Unknown")) + template = template.replace( + "{{USER_EMAIL}}", USER_VARIABLES.get("email", "Unknown") + ) template = template.replace("{{USER_BIO}}", USER_VARIABLES.get("bio", "Unknown")) template = template.replace( "{{USER_GENDER}}", USER_VARIABLES.get("gender", "Unknown") diff --git a/backend/open_webui/utils/telemetry/instrumentors.py b/backend/open_webui/utils/telemetry/instrumentors.py index dbc4ebb2cb..17536b9adb 100644 --- a/backend/open_webui/utils/telemetry/instrumentors.py +++ b/backend/open_webui/utils/telemetry/instrumentors.py @@ -22,13 +22,13 @@ from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor from opentelemetry.trace import Span, StatusCode from redis import Redis +from redis.cluster import RedisCluster from requests import PreparedRequest, Response from sqlalchemy import Engine from fastapi import status from open_webui.utils.telemetry.constants import SPAN_REDIS_TYPE, SpanAttributes - logger = logging.getLogger(__name__) @@ -59,16 +59,28 @@ def response_hook(span: Span, 
request: PreparedRequest, response: Response): span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK) -def redis_request_hook(span: Span, instance: Redis, args, kwargs): +def redis_request_hook(span: Span, instance: Union[Redis | RedisCluster], args, kwargs): """ Redis Request Hook """ + # In cluster mode, the instance can be of two types: + # - redis.asyncio.cluster.RedisCluster + # - redis.cluster.RedisCluster + # Instead of checking the type, we check if the instance has a nodes_manager attribute. try: - connection_kwargs: dict = instance.connection_pool.connection_kwargs - host = connection_kwargs.get("host") - port = connection_kwargs.get("port") - db = connection_kwargs.get("db") + db = "" + if hasattr(instance, "nodes_manager"): + default_node = instance.nodes_manager.default_node + if not default_node: + return + host = default_node.host + port = default_node.port + else: + connection_kwargs: dict = instance.connection_pool.connection_kwargs + host = connection_kwargs.get("host") + port = connection_kwargs.get("port") + db = connection_kwargs.get("db") span.set_attributes( { SpanAttributes.DB_INSTANCE: f"{host}/{db}", diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 6cb6c4b856..cb43cf4ef4 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -38,6 +38,7 @@ from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.models.groups import Groups +from open_webui.models.access_grants import AccessGrants from open_webui.utils.plugin import load_tool_module_by_id from open_webui.utils.access_control import has_access from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL @@ -45,12 +46,17 @@ AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, + ENABLE_FORWARD_USER_INFO_HEADERS, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, + FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, ) 
+from open_webui.utils.headers import include_user_info_headers from open_webui.tools.builtin import ( search_web, fetch_url, generate_image, edit_image, + execute_code, search_memories, add_memory, replace_memory_content, @@ -72,6 +78,7 @@ search_knowledge_files, query_knowledge_files, view_knowledge_file, + view_skill, ) import copy @@ -144,8 +151,9 @@ def has_tool_server_access( if user_group_ids is None: user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} - access_control = server_connection.get("config", {}).get("access_control", None) - return has_access(user.id, "read", access_control, user_group_ids) + server_config = server_connection.get("config", {}) + access_grants = server_config.get("access_grants", []) + return has_access(user.id, "read", access_grants, user_group_ids) async def get_tools( @@ -164,7 +172,13 @@ async def get_tools( if ( not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) and tool.user_id != user.id - and not has_access(user.id, "read", tool.access_control, user_group_ids) + and not AccessGrants.has_access( + user_id=user.id, + resource_type="tool", + resource_id=tool.id, + permission="read", + user_group_ids=user_group_ids, + ) ): log.warning(f"Access denied to tool {tool_id} for user {user.id}") continue @@ -334,6 +348,19 @@ async def get_tools( for key, value in connection_headers.items(): headers[key] = value + # Add user info headers if enabled + if ENABLE_FORWARD_USER_INFO_HEADERS and user: + headers = include_user_info_headers(headers, user) + metadata = extra_params.get("__metadata__", {}) + if metadata and metadata.get("chat_id"): + headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = ( + metadata.get("chat_id") + ) + if metadata and metadata.get("message_id"): + headers[FORWARD_SESSION_INFO_HEADER_MESSAGE_ID] = ( + metadata.get("message_id") + ) + def make_tool_function( function_name, tool_server_data, headers ): @@ -396,67 +423,91 @@ def get_builtin_tools( # Helper to get model capabilities 
(defaults to True if not specified) def get_model_capability(name: str, default: bool = True) -> bool: - return ( - model.get("info", {}) - .get("meta", {}) - .get("capabilities", {}) - .get(name, default) + return (model.get("info", {}).get("meta", {}).get("capabilities") or {}).get( + name, default ) - # Time utilities - always available for date calculations - builtin_functions.extend([get_current_timestamp, calculate_timestamp]) + # Helper to check if a builtin tool category is enabled via meta.builtinTools + # Defaults to True if not specified (backward compatible) + def is_builtin_tool_enabled(category: str) -> bool: + builtin_tools = model.get("info", {}).get("meta", {}).get("builtinTools", {}) + return builtin_tools.get(category, True) + + # Time utilities - available for date calculations + if is_builtin_tool_enabled("time"): + builtin_functions.extend([get_current_timestamp, calculate_timestamp]) # Knowledge base tools - conditional injection based on model knowledge # If model has attached knowledge (any type), only provide query_knowledge_files # Otherwise, provide all KB browsing tools model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", []) - if model_knowledge: - # Model has attached knowledge - only allow semantic search within it - builtin_functions.append(query_knowledge_files) - else: - # No model knowledge - allow full KB browsing - builtin_functions.extend( - [ - list_knowledge_bases, - search_knowledge_bases, - query_knowledge_bases, - search_knowledge_files, - query_knowledge_files, - view_knowledge_file, - ] - ) + if is_builtin_tool_enabled("knowledge"): + if model_knowledge: + # Model has attached knowledge - only allow semantic search within it + builtin_functions.append(query_knowledge_files) + else: + # No model knowledge - allow full KB browsing + builtin_functions.extend( + [ + list_knowledge_bases, + search_knowledge_bases, + query_knowledge_bases, + search_knowledge_files, + query_knowledge_files, + 
view_knowledge_file, + ] + ) # Chats tools - search and fetch user's chat history - builtin_functions.extend([search_chats, view_chat]) + if is_builtin_tool_enabled("chats"): + builtin_functions.extend([search_chats, view_chat]) - # Add memory tools if enabled for this chat - if features.get("memory"): + # Add memory tools if builtin category enabled AND enabled for this chat + if is_builtin_tool_enabled("memory") and features.get("memory"): builtin_functions.extend([search_memories, add_memory, replace_memory_content]) - # Add web search tools if enabled globally AND model has web_search capability - if getattr( - request.app.state.config, "ENABLE_WEB_SEARCH", False - ) and get_model_capability("web_search"): + # Add web search tools if builtin category enabled AND enabled globally AND model has web_search capability + if ( + is_builtin_tool_enabled("web_search") + and getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) + and get_model_capability("web_search") + ): builtin_functions.extend([search_web, fetch_url]) - # Add image generation/edit tools if enabled globally AND model has image_generation capability - if getattr( - request.app.state.config, "ENABLE_IMAGE_GENERATION", False - ) and get_model_capability("image_generation"): + # Add image generation/edit tools if builtin category enabled AND enabled globally AND model has image_generation capability + if ( + is_builtin_tool_enabled("image_generation") + and getattr(request.app.state.config, "ENABLE_IMAGE_GENERATION", False) + and get_model_capability("image_generation") + ): builtin_functions.append(generate_image) - if getattr( - request.app.state.config, "ENABLE_IMAGE_EDIT", False - ) and get_model_capability("image_generation"): + if ( + is_builtin_tool_enabled("image_generation") + and getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) + and get_model_capability("image_generation") + ): builtin_functions.append(edit_image) - # Notes tools - search, view, create, and update user's 
notes (if notes enabled globally) - if getattr(request.app.state.config, "ENABLE_NOTES", False): + # Add code interpreter tool if builtin category enabled AND enabled globally AND model has code_interpreter capability + if ( + is_builtin_tool_enabled("code_interpreter") + and getattr(request.app.state.config, "ENABLE_CODE_INTERPRETER", True) + and get_model_capability("code_interpreter") + ): + builtin_functions.append(execute_code) + + # Notes tools - search, view, create, and update user's notes (if builtin category enabled AND notes enabled globally) + if is_builtin_tool_enabled("notes") and getattr( + request.app.state.config, "ENABLE_NOTES", False + ): builtin_functions.extend( [search_notes, view_note, write_note, replace_note_content] ) - # Channels tools - search channels and messages (if channels enabled globally) - if getattr(request.app.state.config, "ENABLE_CHANNELS", False): + # Channels tools - search channels and messages (if builtin category enabled AND channels enabled globally) + if is_builtin_tool_enabled("channels") and getattr( + request.app.state.config, "ENABLE_CHANNELS", False + ): builtin_functions.extend( [ search_channels, @@ -466,6 +517,10 @@ def get_model_capability(name: str, default: bool = True) -> bool: ] ) + # Skills tools - view_skill allows model to load full skill instructions on demand + if extra_params.get("__skill_ids__"): + builtin_functions.append(view_skill) + for func in builtin_functions: callable = get_async_tool_function_and_apply_extra_params( func, @@ -473,6 +528,8 @@ def get_model_capability(name: str, default: bool = True) -> bool: "__request__": request, "__user__": extra_params.get("__user__", {}), "__event_emitter__": extra_params.get("__event_emitter__"), + "__event_call__": extra_params.get("__event_call__"), + "__metadata__": extra_params.get("__metadata__"), "__chat_id__": extra_params.get("__chat_id__"), "__message_id__": extra_params.get("__message_id__"), "__model_knowledge__": model_knowledge, @@ -673,9 
+730,10 @@ def convert_openapi_to_tool_payload(openapi_spec): "parameters": {"type": "object", "properties": {}, "required": []}, } - # Extract path and query parameters for param in operation.get("parameters", []): - param_name = param["name"] + param_name = param.get("name") + if not param_name: + continue param_schema = param.get("schema", {}) description = param_schema.get("description", "") if not description: @@ -947,8 +1005,10 @@ async def execute_tool_server( body_params = {} for param in operation.get("parameters", []): - param_name = param["name"] - param_in = param["in"] + param_name = param.get("name") + if not param_name: + continue + param_in = param.get("in") if param_name in params: if param_in == "path": path_params[param_name] = params[param_name] diff --git a/backend/open_webui/utils/validate.py b/backend/open_webui/utils/validate.py new file mode 100644 index 0000000000..6e62dd5416 --- /dev/null +++ b/backend/open_webui/utils/validate.py @@ -0,0 +1,38 @@ +"""Validation utilities for user-supplied input.""" + +# Known static asset paths used as default profile images +_ALLOWED_STATIC_PATHS = ( + "/user.png", + "/static/favicon.png", +) + + +def validate_profile_image_url(url: str) -> str: + """ + Pydantic-compatible validator for profile image URLs. + + Allowed formats: + - Empty string (falls back to default avatar) + - data:image/* URIs (base64-encoded uploads from the frontend) + - Known static asset paths (/user.png, /static/favicon.png) + + Returns the url unchanged if valid, raises ValueError otherwise. + """ + if not url: + return url + + _ALLOWED_DATA_PREFIXES = ( + "data:image/png", + "data:image/jpeg", + "data:image/gif", + "data:image/webp", + ) + if any(url.startswith(prefix) for prefix in _ALLOWED_DATA_PREFIXES): + return url + + if url in _ALLOWED_STATIC_PATHS: + return url + + raise ValueError( + "Invalid profile image URL: only data URIs and default avatars are allowed." 
+ ) diff --git a/backend/requirements-min.txt b/backend/requirements-min.txt index 115d1bc92f..532a6cd714 100644 --- a/backend/requirements-min.txt +++ b/backend/requirements-min.txt @@ -1,54 +1,55 @@ # Minimal requirements for backend to run # WIP: use this as a reference to build a minimal docker image -fastapi==0.128.0 +fastapi==0.128.5 uvicorn[standard]==0.40.0 pydantic==2.12.5 -python-multipart==0.0.21 +python-multipart==0.0.22 itsdangerous==2.2.0 -python-socketio==5.16.0 +python-socketio==5.16.1 python-jose==3.5.0 cryptography bcrypt==5.0.0 argon2-cffi==25.1.0 -PyJWT[crypto]==2.10.1 -authlib==1.6.6 +PyJWT[crypto]==2.11.0 +authlib==1.6.7 requests==2.32.5 -aiohttp==3.13.2 +aiohttp==3.13.2 # do not update to 3.13.3 - broken async-timeout aiocache aiofiles -starlette-compress==1.6.1 +starlette-compress==1.7.0 +Brotli==1.1.0 httpx[socks,http2,zstd,cli,brotli]==0.28.1 starsessions[redis]==2.2.1 -sqlalchemy==2.0.45 -alembic==1.17.2 -peewee==3.18.3 +sqlalchemy==2.0.46 +alembic==1.18.3 +peewee==3.19.0 peewee-migrate==1.14.3 -pycrdt==0.12.44 +pycrdt==0.12.46 redis APScheduler==3.11.2 RestrictedPython==8.1 loguru==0.7.3 -asgiref==3.11.0 +asgiref==3.11.1 -mcp==1.25.0 +mcp==1.26.0 openai -langchain==1.2.0 +langchain==1.2.9 langchain-community==0.4.1 langchain-classic==1.0.1 langchain-text-splitters==1.1.0 fake-useragent==2.2.0 -chromadb==1.4.0 -black==25.12.0 +chromadb==1.4.1 +black==26.1.0 pydub chardet==5.2.0 diff --git a/backend/requirements.txt b/backend/requirements.txt index c9817968df..ea4966069e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,78 +1,80 @@ -fastapi==0.128.0 +fastapi==0.128.5 uvicorn[standard]==0.40.0 pydantic==2.12.5 -python-multipart==0.0.21 +python-multipart==0.0.22 itsdangerous==2.2.0 -python-socketio==5.16.0 +python-socketio==5.16.1 python-jose==3.5.0 cryptography bcrypt==5.0.0 argon2-cffi==25.1.0 -PyJWT[crypto]==2.10.1 -authlib==1.6.6 +PyJWT[crypto]==2.11.0 +authlib==1.6.7 requests==2.32.5 -aiohttp==3.13.2 
+aiohttp==3.13.2 # do not update to 3.13.3 - broken async-timeout aiocache aiofiles -starlette-compress==1.6.1 +starlette-compress==1.7.0 +Brotli==1.1.0 httpx[socks,http2,zstd,cli,brotli]==0.28.1 starsessions[redis]==2.2.1 python-mimeparse==2.0.0 -sqlalchemy==2.0.45 -alembic==1.17.2 -peewee==3.18.3 +sqlalchemy==2.0.46 +alembic==1.18.3 +peewee==3.19.0 peewee-migrate==1.14.3 -pycrdt==0.12.44 +pycrdt==0.12.46 redis APScheduler==3.11.2 RestrictedPython==8.1 +pytz==2025.2 loguru==0.7.3 -asgiref==3.11.0 +asgiref==3.11.1 # AI libraries tiktoken -mcp==1.25.0 +mcp==1.26.0 openai anthropic -google-genai==1.56.0 +google-genai==1.62.0 -langchain==1.2.0 +langchain==1.2.9 langchain-community==0.4.1 langchain-classic==1.0.1 langchain-text-splitters==1.1.0 fake-useragent==2.2.0 -chromadb==1.4.0 +chromadb==1.4.1 weaviate-client==4.19.2 opensearch-py==3.1.0 -transformers==4.57.3 -sentence-transformers==5.2.0 +transformers==5.1.0 +sentence-transformers==5.2.2 accelerate pyarrow==20.0.0 # fix: pin pyarrow version to 20 for rpi compatibility #15897 -einops==0.8.1 +einops==0.8.2 ftfy==6.3.1 chardet==5.2.0 -pypdf==6.5.0 +pypdf==6.7.0 fpdf2==2.8.5 -pymdown-extensions==10.20 +pymdown-extensions==10.20.1 docx2txt==0.9 python-pptx==1.0.2 -unstructured==0.18.24 -msoffcrypto-tool==5.4.2 +unstructured==0.18.31 +msoffcrypto-tool==6.0.0 nltk==3.9.2 -Markdown==3.10 +Markdown==3.10.1 pypandoc==1.16.2 -pandas==2.3.3 +pandas==3.0.0 openpyxl==3.1.5 pyxlsb==1.0.10 xlrd==2.0.2 @@ -83,15 +85,15 @@ jsonpath-ng soundfile==0.13.1 pillow==12.1.0 -opencv-python-headless==4.12.0.88 +opencv-python-headless==4.13.0.92 rapidocr-onnxruntime==1.4.4 rank-bm25==0.2.2 -onnxruntime==1.23.2 +onnxruntime==1.24.1 faster-whisper==1.2.1 -black==25.12.0 -youtube-transcript-api==1.2.3 +black==26.1.0 +youtube-transcript-api==1.2.4 pytube==15.0.0 pydub @@ -99,7 +101,7 @@ ddgs==9.10.0 azure-ai-documentintelligence==1.0.2 azure-identity==1.25.1 -azure-storage-blob==12.27.1 +azure-storage-blob==12.28.0 
azure-search-documents==11.6.0 ## Google Drive @@ -108,7 +110,7 @@ google-auth-httplib2 google-auth-oauthlib googleapis-common-protos==1.72.0 -google-cloud-storage==3.7.0 +google-cloud-storage==3.9.0 ## Databases pymongo @@ -116,14 +118,14 @@ psycopg2-binary==2.9.11 pgvector==0.4.2 PyMySQL==1.1.2 -boto3==1.42.21 +boto3==1.42.44 -pymilvus==2.6.6 +pymilvus==2.6.8 qdrant-client==1.16.2 -playwright==1.57.0 # Caution: version must match docker-compose.playwright.yaml - Update the docker-compose.yaml if necessary -elasticsearch==9.2.1 +playwright==1.58.0 # Caution: version must match docker-compose.playwright.yaml - Update the docker-compose.yaml if necessary +elasticsearch==9.3.0 pinecone==6.0.2 -oracledb==3.4.1 +oracledb==3.4.2 av==14.0.1 # Caution: Set due to FATAL FIPS SELFTEST FAILURE, see discussion https://github.com/open-webui/open-webui/discussions/15720 @@ -139,7 +141,7 @@ pytest-docker~=3.2.5 ldap3==2.9.1 ## Firecrawl -firecrawl-py==4.12.0 +firecrawl-py==4.14.0 ## Trace opentelemetry-api==1.39.1 diff --git a/docker-compose.playwright.yaml b/docker-compose.playwright.yaml index e00a28df58..167c2501d6 100644 --- a/docker-compose.playwright.yaml +++ b/docker-compose.playwright.yaml @@ -1,8 +1,8 @@ services: playwright: - image: mcr.microsoft.com/playwright:v1.57.0-noble # Version must match requirements.txt + image: mcr.microsoft.com/playwright:v1.58.0-noble # Version must match requirements.txt container_name: playwright - command: npx -y playwright@1.57.0 run-server --port 3000 --host 0.0.0.0 + command: npx -y playwright@1.58.0 run-server --port 3000 --host 0.0.0.0 open-webui: environment: diff --git a/package-lock.json b/package-lock.json index 9a4b2d39ca..cc10773454 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.7.2.5", + "version": "0.8.0.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.7.2.5", + "version": "0.8.0.2", "dependencies": { 
"@azure/msal-browser": "^4.5.0", "@codemirror/lang-javascript": "^6.2.2", diff --git a/package.json b/package.json index b7766f922d..c1b01d2e66 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.7.2.5", + "version": "0.8.0.2", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index 893e1d018a..d240449dff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,81 +6,83 @@ authors = [ ] license = { file = "LICENSE" } dependencies = [ - "fastapi==0.128.0", + "fastapi==0.128.5", "uvicorn[standard]==0.40.0", "pydantic==2.12.5", - "python-multipart==0.0.21", + "python-multipart==0.0.22", "itsdangerous==2.2.0", - "python-socketio==5.16.0", + "python-socketio==5.16.1", "python-jose==3.5.0", "cryptography", "bcrypt==5.0.0", "argon2-cffi==25.1.0", - "PyJWT[crypto]==2.10.1", - "authlib==1.6.6", + "PyJWT[crypto]==2.11.0", + "authlib==1.6.7", "requests==2.32.5", - "aiohttp==3.13.2", + "aiohttp==3.13.2", # do not update to 3.13.3 - broken "async-timeout", "aiocache", "aiofiles", - "starlette-compress==1.6.1", + "starlette-compress==1.7.0", + "Brotli==1.1.0", "httpx[socks,http2,zstd,cli,brotli]==0.28.1", "starsessions[redis]==2.2.1", "python-mimeparse==2.0.0", - "sqlalchemy==2.0.45", - "alembic==1.17.2", - "peewee==3.18.3", + "sqlalchemy==2.0.46", + "alembic==1.18.3", + "peewee==3.19.0", "peewee-migrate==1.14.3", - "pycrdt==0.12.44", + "pycrdt==0.12.46", "redis", + "pytz==2025.2", "APScheduler==3.11.2", "RestrictedPython==8.1", "loguru==0.7.3", - "asgiref==3.11.0", + "asgiref==3.11.1", "tiktoken", - "mcp==1.25.0", + "mcp==1.26.0", "openai", "anthropic", - "google-genai==1.56.0", + "google-genai==1.62.0", - "langchain==1.2.0", + "langchain==1.2.9", "langchain-community==0.4.1", "langchain-classic==1.0.1", "langchain-text-splitters==1.1.0", "fake-useragent==2.2.0", - "chromadb==1.4.0", + "chromadb==1.4.1", "opensearch-py==3.1.0", "PyMySQL==1.1.2", - 
"boto3==1.42.21", + "boto3==1.42.44", - "transformers==4.57.3", - "sentence-transformers==5.2.0", + "transformers==5.1.0", + "sentence-transformers==5.2.2", "accelerate", "pyarrow==20.0.0", # fix: pin pyarrow version to 20 for rpi compatibility #15897 - "einops==0.8.1", + "einops==0.8.2", "ftfy==6.3.1", "chardet==5.2.0", - "pypdf==6.5.0", + "pypdf==6.7.0", "fpdf2==2.8.5", - "pymdown-extensions==10.20", + "pymdown-extensions==10.20.1", "docx2txt==0.9", "python-pptx==1.0.2", - "unstructured==0.18.24", - "msoffcrypto-tool==5.4.2", + "unstructured==0.18.31", + "msoffcrypto-tool==6.0.0", "nltk==3.9.2", - "Markdown==3.10", + "Markdown==3.10.1", "pypandoc==1.16.2", - "pandas==2.3.3", + "pandas==3.0.0", "openpyxl==3.1.5", "pyxlsb==1.0.10", "xlrd==2.0.2", @@ -91,15 +93,15 @@ dependencies = [ "azure-ai-documentintelligence==1.0.2", "pillow==12.1.0", - "opencv-python-headless==4.12.0.88", + "opencv-python-headless==4.13.0.92", "rapidocr-onnxruntime==1.4.4", "rank-bm25==0.2.2", - "onnxruntime==1.23.2", + "onnxruntime==1.24.1", "faster-whisper==1.2.1", - "black==25.12.0", - "youtube-transcript-api==1.2.3", + "black==26.1.0", + "youtube-transcript-api==1.2.4", "pytube==15.0.0", "pydub", @@ -110,10 +112,10 @@ dependencies = [ "google-auth-oauthlib", "googleapis-common-protos==1.72.0", - "google-cloud-storage==3.7.0", + "google-cloud-storage==3.9.0", "azure-identity==1.25.1", - "azure-storage-blob==12.27.1", + "azure-storage-blob==12.28.0", "ldap3==2.9.1", ] @@ -145,18 +147,18 @@ all = [ "docker~=7.1.0", "pytest~=8.3.2", "pytest-docker~=3.2.5", - "playwright==1.57.0", # Caution: version must match docker-compose.playwright.yaml - Update the docker-compose.yaml if necessary - "elasticsearch==9.2.1", + "playwright==1.58.0", # Caution: version must match docker-compose.playwright.yaml - Update the docker-compose.yaml if necessary + "elasticsearch==9.3.0", "qdrant-client==1.16.2", "weaviate-client==4.19.2", - "pymilvus==2.6.6", + "pymilvus==2.6.8", "pinecone==6.0.2", - 
"oracledb==3.4.1", + "oracledb==3.4.2", "colbert-ai==0.2.22", - "firecrawl-py==4.12.0", + "firecrawl-py==4.14.0", "azure-search-documents==11.6.0", ] diff --git a/src/lib/apis/analytics/index.ts b/src/lib/apis/analytics/index.ts new file mode 100644 index 0000000000..6bab2cbf81 --- /dev/null +++ b/src/lib/apis/analytics/index.ts @@ -0,0 +1,319 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const getModelAnalytics = async ( + token: string = '', + startDate: number | null = null, + endDate: number | null = null, + groupId: string | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + if (groupId) searchParams.append('group_id', groupId); + + const res = await fetch(`${WEBUI_API_BASE_URL}/analytics/models?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getUserAnalytics = async ( + token: string = '', + startDate: number | null = null, + endDate: number | null = null, + limit: number = 50, + groupId: string | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + if (limit) searchParams.append('limit', limit.toString()); + if (groupId) searchParams.append('group_id', groupId); + + const res = await fetch(`${WEBUI_API_BASE_URL}/analytics/users?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 
'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getMessages = async ( + token: string = '', + modelId: string | null = null, + userId: string | null = null, + chatId: string | null = null, + startDate: number | null = null, + endDate: number | null = null, + skip: number = 0, + limit: number = 50 +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (modelId) searchParams.append('model_id', modelId); + if (userId) searchParams.append('user_id', userId); + if (chatId) searchParams.append('chat_id', chatId); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + if (skip) searchParams.append('skip', skip.toString()); + if (limit) searchParams.append('limit', limit.toString()); + + const res = await fetch(`${WEBUI_API_BASE_URL}/analytics/messages?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getSummary = async ( + token: string = '', + startDate: number | null = null, + endDate: number | null = null, + groupId: string | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + if (groupId) searchParams.append('group_id', groupId); + + const 
res = await fetch(`${WEBUI_API_BASE_URL}/analytics/summary?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getDailyStats = async ( + token: string = '', + startDate: number | null = null, + endDate: number | null = null, + granularity: 'hourly' | 'daily' = 'daily', + groupId: string | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + searchParams.append('granularity', granularity); + if (groupId) searchParams.append('group_id', groupId); + + const res = await fetch(`${WEBUI_API_BASE_URL}/analytics/daily?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getTokenUsage = async ( + token: string = '', + startDate: number | null = null, + endDate: number | null = null, + groupId: string | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + if (groupId) searchParams.append('group_id', groupId); + + const res = await 
fetch(`${WEBUI_API_BASE_URL}/analytics/tokens?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getModelChats = async ( + token: string = '', + modelId: string, + startDate: number | null = null, + endDate: number | null = null, + skip: number = 0, + limit: number = 50 +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (startDate) searchParams.append('start_date', startDate.toString()); + if (endDate) searchParams.append('end_date', endDate.toString()); + if (skip) searchParams.append('skip', skip.toString()); + if (limit) searchParams.append('limit', limit.toString()); + + const res = await fetch( + `${WEBUI_API_BASE_URL}/analytics/models/${encodeURIComponent(modelId)}/chats?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getModelOverview = async (token: string = '', modelId: string, days: number = 30) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('days', days.toString()); + + const res = await fetch( + `${WEBUI_API_BASE_URL}/analytics/models/${encodeURIComponent(modelId)}/overview?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + 
) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 225d8cd7cf..5715c64e89 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -3,10 +3,11 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; type ChannelForm = { type?: string; name: string; - is_private?: boolean; + is_private?: boolean | null; data?: object; meta?: object; - access_control?: object; + access_grants?: object[]; + group_ids?: string[]; user_ids?: string[]; }; diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index b33072e890..e16746707a 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -255,6 +255,51 @@ export const getArchivedChatList = async ( })); }; +export const getSharedChatList = async (token: string = '', page: number = 1, filter?: object) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('page', `${page}`); + + if (filter) { + Object.entries(filter).forEach(([key, value]) => { + if (value !== undefined && value !== null) { + searchParams.append(key, value.toString()); + } + }); + } + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/shared?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res.map((chat) => ({ + ...chat, + time_range: getTimeRange(chat.updated_at) + })); +}; + export const getAllChats = async 
(token: string) => { let error = null; diff --git a/src/lib/apis/files/index.ts b/src/lib/apis/files/index.ts index 44af669fa1..15785b354d 100644 --- a/src/lib/apis/files/index.ts +++ b/src/lib/apis/files/index.ts @@ -175,6 +175,44 @@ export const getFiles = async (token: string = '') => { return res; }; +export const searchFiles = async ( + token: string, + filename: string = '*', + skip: number = 0, + limit: number = 50 +) => { + let error = null; + + const searchParams = new URLSearchParams(); + searchParams.append('filename', filename); + searchParams.append('skip', String(skip)); + searchParams.append('limit', String(limit)); + + const res = await fetch(`${WEBUI_API_BASE_URL}/files/search?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return []; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getFileById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts index a74c61b83d..6089a6023f 100644 --- a/src/lib/apis/groups/index.ts +++ b/src/lib/apis/groups/index.ts @@ -99,6 +99,38 @@ export const getGroupById = async (token: string, id: string) => { return res; }; +export const getGroupInfoById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/info`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.error(err); + return null; + }); + 
+ if (error) { + throw error; + } + + return res; +}; + export const updateGroupById = async (token: string, id: string, group: object) => { let error = null; diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index a58d16085f..86b8a90ed1 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -217,7 +217,63 @@ export const imageGenerations = async (token: string = '', prompt: string) => { .catch((err) => { console.error(err); if ('detail' in err) { - error = err.detail; + if (Array.isArray(err.detail)) { + error = err.detail.map((e: { msg?: string }) => e.msg || JSON.stringify(e)).join(', '); + } else { + error = err.detail; + } + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const imageEdits = async ( + token: string = '', + images: string | string[], + prompt: string, + model?: string, + size?: string, + n?: number +) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/edit`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + form_data: { + image: images, + prompt, + ...(model && { model }), + ...(size && { size }), + ...(n && { n }) + } + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.error(err); + if ('detail' in err) { + if (Array.isArray(err.detail)) { + error = err.detail.map((e: { msg?: string }) => e.msg || JSON.stringify(e)).join(', '); + } else { + error = err.detail; + } } else { error = 'Server connection failed'; } diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 58827c8980..e180c73500 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -473,8 +473,9 @@ export const executeToolServer = async ( if (operation.parameters) { 
operation.parameters.forEach((param: any) => { - const paramName = param.name; - const paramIn = param.in; + const paramName = param?.name; + if (!paramName) return; + const paramIn = param?.in; if (params.hasOwnProperty(paramName)) { if (paramIn === 'path') { pathParams[paramName] = params[paramName]; @@ -887,7 +888,8 @@ export const generateQueries = async ( model: string, messages: object[], prompt: string, - type: string = 'web_search' + type: string = 'web_search', + chat_id?: string ) => { let error = null; @@ -902,7 +904,8 @@ export const generateQueries = async ( model: model, messages: messages, prompt: prompt, - type: type + type: type, + ...(chat_id && { chat_id: chat_id }) }) }) .then(async (res) => { @@ -956,7 +959,8 @@ export const generateAutoCompletion = async ( model: string, prompt: string, messages?: object[], - type: string = 'search query' + type: string = 'search query', + chat_id?: string ) => { const controller = new AbortController(); let error = null; @@ -974,7 +978,8 @@ export const generateAutoCompletion = async ( prompt: prompt, ...(messages && { messages: messages }), type: type, - stream: false + stream: false, + ...(chat_id && { chat_id: chat_id }) }) }) .then(async (res) => { diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index dc9dd8b88a..f314bae634 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -4,7 +4,7 @@ export const createNewKnowledge = async ( token: string, name: string, description: string, - accessControl: null | object + accessGrants: object[] ) => { let error = null; @@ -18,7 +18,7 @@ export const createNewKnowledge = async ( body: JSON.stringify({ name: name, description: description, - access_control: accessControl + access_grants: accessGrants }) }) .then(async (res) => { @@ -248,7 +248,7 @@ type KnowledgeUpdateForm = { name?: string; description?: string; data?: object; - access_control?: null | object; + access_grants?: object[]; }; export const 
updateKnowledgeById = async (token: string, id: string, form: KnowledgeUpdateForm) => { @@ -265,7 +265,7 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl name: form?.name ? form.name : undefined, description: form?.description ? form.description : undefined, data: form?.data ? form.data : undefined, - access_control: form.access_control + access_grants: form.access_grants }) }) .then(async (res) => { @@ -289,6 +289,39 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl return res; }; +export const updateKnowledgeAccessGrants = async ( + token: string, + id: string, + accessGrants: any[] +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/${id}/access/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ access_grants: accessGrants }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const addFileToKnowledgeById = async (token: string, id: string, fileId: string) => { let error = null; diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index d03a83e9ca..42e77c0afa 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -281,6 +281,35 @@ export const updateModelById = async (token: string, id: string, model: object) return res; }; +export const updateModelAccessGrants = async (token: string, id: string, accessGrants: any[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/model/access/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ id, access_grants: 
accessGrants }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteModelById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/apis/notes/index.ts b/src/lib/apis/notes/index.ts index 55f9427e0d..07e249a889 100644 --- a/src/lib/apis/notes/index.ts +++ b/src/lib/apis/notes/index.ts @@ -5,7 +5,7 @@ type NoteItem = { title: string; data: object; meta?: null | object; - access_control?: null | object; + access_grants?: object[]; }; export const createNewNote = async (token: string, note: NoteItem) => { @@ -253,6 +253,35 @@ export const updateNoteById = async (token: string, id: string, note: NoteItem) return res; }; +export const updateNoteAccessGrants = async (token: string, id: string, accessGrants: any[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/notes/${id}/access/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ access_grants: accessGrants }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteNoteById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index 4129ea62aa..1fd311c76f 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -1,10 +1,48 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; type PromptItem = { + id?: string; // Prompt ID command: string; - title: string; + name: string; // Changed from title content: string; - access_control?: null | object; + 
data?: object | null; + meta?: object | null; + access_grants?: object[]; + version_id?: string | null; // Active version + commit_message?: string | null; // For history tracking + is_production?: boolean; // Whether to set new version as production +}; + +type PromptHistoryItem = { + id: string; + prompt_id: string; + parent_id: string | null; + snapshot: { + name: string; + content: string; + command: string; + data: object; + meta: object; + access_grants: object[]; + }; + user_id: string; + commit_message: string | null; + created_at: number; + user?: { + id: string; + name: string; + email: string; + }; +}; + +type PromptDiff = { + from_id: string; + to_id: string; + from_snapshot: object; + to_snapshot: object; + content_diff: string[]; + name_changed: boolean; + access_grants_changed: boolean; }; export const createNewPrompt = async (token: string, prompt: PromptItem) => { @@ -19,7 +57,7 @@ export const createNewPrompt = async (token: string, prompt: PromptItem) => { }, body: JSON.stringify({ ...prompt, - command: `/${prompt.command}` + command: prompt.command.startsWith('/') ? 
prompt.command.slice(1) : prompt.command }) }) .then(async (res) => { @@ -70,6 +108,93 @@ export const getPrompts = async (token: string = '') => { return res; }; +export const getPromptTags = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/tags`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPromptItems = async ( + token: string = '', + query: string | null, + viewOption: string | null, + selectedTag: string | null, + orderBy: string | null, + direction: string | null, + page: number +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (query) { + searchParams.append('query', query); + } + if (viewOption) { + searchParams.append('view_option', viewOption); + } + if (selectedTag) { + searchParams.append('tag', selectedTag); + } + if (orderBy) { + searchParams.append('order_by', orderBy); + } + if (direction) { + searchParams.append('direction', direction); + } + if (page) { + searchParams.append('page', page.toString()); + } + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getPromptList = async (token: string = '') => { let error = null; @@ -104,6 +229,8 @@ export const 
getPromptList = async (token: string = '') => { export const getPromptByCommand = async (token: string, command: string) => { let error = null; + command = command.charAt(0) === '/' ? command.slice(1) : command; + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${command}`, { method: 'GET', headers: { @@ -133,20 +260,49 @@ export const getPromptByCommand = async (token: string, command: string) => { return res; }; -export const updatePromptByCommand = async (token: string, prompt: PromptItem) => { +export const getPromptById = async (token: string, promptId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updatePromptById = async (token: string, prompt: PromptItem) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${prompt.command}/update`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${prompt.id}/update`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` }, - body: JSON.stringify({ - ...prompt, - command: `/${prompt.command}` - }) + body: JSON.stringify(prompt) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -169,12 +325,80 @@ export const updatePromptByCommand = async (token: string, prompt: PromptItem) = return res; }; -export const deletePromptByCommand = async (token: string, command: string) => { +export const updatePromptMetadata = async ( + token: string, + promptId: string, + name: string, + command: string, + 
tags: string[] = [] +) => { let error = null; - command = command.charAt(0) === '/' ? command.slice(1) : command; + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/update/meta`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ name, command, tags }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setProductionPromptVersion = async ( + token: string, + promptId: string, + version_id: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/update/version`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + version_id: version_id + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; - const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/command/${command}/delete`, { +export const deletePromptById = async (token: string, promptId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/delete`, { method: 'DELETE', headers: { Accept: 'application/json', @@ -202,3 +426,211 @@ export const deletePromptByCommand = async (token: string, command: string) => { return res; }; + +export const updatePromptAccessGrants = async ( + token: string, + promptId: string, + accessGrants: any[] +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/access/update`, { + method: 'POST', + headers: { + 
Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ access_grants: accessGrants }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +//////////////////////////// +// Prompt History APIs +//////////////////////////// + +export const getPromptHistory = async ( + token: string, + promptId: string, + page: number = 0 +): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/history?page=${page}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deletePromptHistoryVersion = async ( + token: string, + promptId: string, + historyId: string +): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/history/${historyId}`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return false; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPromptHistoryEntry = async ( + token: string, + promptId: string, + historyId: string +): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/id/${promptId}/history/${historyId}`, { + method: 'GET', + headers: { + Accept: 
'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const restorePromptFromHistory = async ( + token: string, + promptId: string, + historyId: string, + commitMessage?: string +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/prompts/id/${promptId}/history/${historyId}/restore`, + { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + commit_message: commitMessage + }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPromptDiff = async ( + token: string, + promptId: string, + fromId: string, + toId: string +): Promise => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/prompts/id/${promptId}/history/diff?from_id=${fromId}&to_id=${toId}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/skills/index.ts b/src/lib/apis/skills/index.ts new file mode 100644 index 0000000000..24139fa042 --- /dev/null +++ b/src/lib/apis/skills/index.ts @@ -0,0 +1,321 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewSkill = async (token: string, 
skill: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...skill + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getSkills = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getSkillList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getSkillItems = async ( + token: string = '', + query: string | null = null, + viewOption: string | null = null, + page: number | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (query) searchParams.append('query', query); + if (viewOption) searchParams.append('view_option', viewOption); + if 
(page) searchParams.append('page', page.toString()); + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/list?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const exportSkills = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/export`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getSkillById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateSkillById = async (token: string, id: string, skill: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: 
`Bearer ${token}` + }, + body: JSON.stringify({ + ...skill + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateSkillAccessGrants = async (token: string, id: string, accessGrants: any[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/id/${id}/access/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + access_grants: accessGrants + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const toggleSkillById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/id/${id}/toggle`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteSkillById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + 
error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/tasks/index.ts b/src/lib/apis/tasks/index.ts new file mode 100644 index 0000000000..dab6090fde --- /dev/null +++ b/src/lib/apis/tasks/index.ts @@ -0,0 +1,14 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const checkActiveChats = async (token: string, chatIds: string[]) => { + const res = await fetch(`${WEBUI_API_BASE_URL}/tasks/active/chats`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ chat_ids: chatIds }) + }); + if (!res.ok) throw await res.json(); + return res.json(); +}; diff --git a/src/lib/apis/tools/index.ts b/src/lib/apis/tools/index.ts index 2038e46ac6..5d26e50fee 100644 --- a/src/lib/apis/tools/index.ts +++ b/src/lib/apis/tools/index.ts @@ -225,6 +225,35 @@ export const updateToolById = async (token: string, id: string, tool: object) => return res; }; +export const updateToolAccessGrants = async (token: string, id: string, accessGrants: any[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/access/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ access_grants: accessGrants }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteToolById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/apis/users/index.ts b/src/lib/apis/users/index.ts index d23b98da65..05477e13b4 100644 --- a/src/lib/apis/users/index.ts +++ b/src/lib/apis/users/index.ts @@ -300,10 +300,10 @@ export const updateUserSettings = async (token: 
string, settings: object) => { return res; }; -export const getUserById = async (token: string, userId: string) => { +export const getUserInfoById = async (token: string, userId: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/users/${userId}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/users/${userId}/info`, { method: 'GET', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index 557549098c..a455627e11 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -42,6 +42,7 @@ let prefixId = ''; let enable = true; let apiVersion = ''; + let apiType = ''; // '' = chat completions (default), 'responses' = Responses API let headers = ''; @@ -183,7 +184,8 @@ connection_type: connectionType, auth_type, headers: headers ? JSON.parse(headers) : undefined, - ...(!ollama && azure ? { azure: true, api_version: apiVersion } : {}) + ...(!ollama && azure ? { azure: true, api_version: apiVersion } : {}), + ...(apiType ? { api_type: apiType } : {}) } }; @@ -221,6 +223,7 @@ connectionType = connection.config?.connection_type ?? 'external'; azure = connection.config?.azure ?? false; apiVersion = connection.config?.api_version ?? ''; + apiType = connection.config?.api_type ?? ''; } } }; @@ -506,7 +509,7 @@
{/if} + {#if !ollama && !direct} +
+ + +
+ +
+
+ {/if} +
{ @@ -363,7 +363,7 @@ enable = connection.config?.enable ?? true; functionNameFilterList = connection.config?.function_name_filter_list ?? ''; - accessControl = connection.config?.access_control ?? null; + accessGrants = connection.config?.access_grants ?? []; } }; @@ -819,7 +819,7 @@
- +
{/if}
diff --git a/src/lib/components/NotificationToast.svelte b/src/lib/components/NotificationToast.svelte index 1b8d9fae8b..5a40b85b70 100644 --- a/src/lib/components/NotificationToast.svelte +++ b/src/lib/components/NotificationToast.svelte @@ -5,6 +5,7 @@ import { marked } from 'marked'; import { createEventDispatcher, onMount } from 'svelte'; + import XMark from '$lib/components/icons/XMark.svelte'; const dispatch = createEventDispatcher(); @@ -15,6 +16,7 @@ let startX = 0, startY = 0; let moved = false; + let closeButtonElement: HTMLButtonElement; const DRAG_THRESHOLD_PX = 6; const clickHandler = () => { @@ -22,6 +24,10 @@ dispatch('closeToast'); }; + const closeHandler = () => { + dispatch('closeToast'); + }; + function onPointerDown(e: PointerEvent) { startX = e.clientX; startY = e.clientY; @@ -43,6 +49,14 @@ // Release capture if taken (e.currentTarget as HTMLElement).releasePointerCapture?.(e.pointerId); + // Skip if clicking the close button + if ( + closeButtonElement && + (e.target === closeButtonElement || closeButtonElement.contains(e.target as Node)) + ) { + return; + } + // Only treat as a click if there wasn't a drag if (!moved) { clickHandler(); @@ -71,7 +85,7 @@
+ + +
favicon
diff --git a/src/lib/components/admin/Analytics.svelte b/src/lib/components/admin/Analytics.svelte new file mode 100644 index 0000000000..4aa6ae86f6 --- /dev/null +++ b/src/lib/components/admin/Analytics.svelte @@ -0,0 +1,24 @@ + + +{#if loaded} +
+ +
+{/if} diff --git a/src/lib/components/admin/Analytics/AnalyticsModelModal.svelte b/src/lib/components/admin/Analytics/AnalyticsModelModal.svelte new file mode 100644 index 0000000000..c9eb9cf3d7 --- /dev/null +++ b/src/lib/components/admin/Analytics/AnalyticsModelModal.svelte @@ -0,0 +1,268 @@ + + + + {#if model} +
+ +
+ {model.name} +
+
+ +
+ + +
+
+ + {#if $config?.features?.enable_admin_chat_access} + + {/if} +
+
+ +
+ {#if selectedTab === 'overview'} + +
+
+ +
+ {$i18n.t('Feedback Activity')} +
+
+
+ {#each TIME_RANGES as range} + + {/each} +
+
+ +
+ + +
+
+ {$i18n.t('Tags')} +
+ {#if tags.length} +
+ {#each tags as tagInfo} + + {tagInfo.tag} {tagInfo.count} + + {/each} +
+ {:else} + - + {/if} +
+ {:else if selectedTab === 'chats'} +
+ (show = false)} + /> +
+ {/if} + +
+ +
+
+ {/if} +
diff --git a/src/lib/components/admin/Analytics/ChartLine.svelte b/src/lib/components/admin/Analytics/ChartLine.svelte new file mode 100644 index 0000000000..6f2f68f39a --- /dev/null +++ b/src/lib/components/admin/Analytics/ChartLine.svelte @@ -0,0 +1,133 @@ + + +
+ (hoveredIdx = null)} + > + {#each models as m} + + {/each} + {#if hoveredIdx !== null} + + {#each models as m} + {@const v = hovered?.models?.[m] || 0} + {#if v > 0} + + {/if} + {/each} + {/if} + + + {#if data.length > 1} + {@const labelCount = Math.min(7, data.length)} + {@const step = labelCount > 1 ? Math.floor((data.length - 1) / (labelCount - 1)) || 1 : 1} + {@const isHourly = data[0]?.date?.includes(':')} + {@const dateFormat = isHourly + ? 'h A' + : period === 'year' || period === 'all' + ? 'M/D/YY' + : 'M/D'} +
+ {#each Array(labelCount) as _, i} + {@const idx = i === labelCount - 1 ? data.length - 1 : Math.min(i * step, data.length - 1)} + {#if data[idx]} + {dayjs(data[idx].date).format(dateFormat)} + {/if} + {/each} +
+ {/if} + {#if hovered} + {@const total = Object.values(hovered.models || {}).reduce((a, b) => a + b, 0)} +
+
+
+ {#if hovered.date?.includes(':')} + {dayjs(hovered.date).format('MMM D, h A')} + {:else} + {dayjs(hovered.date).format('MMM D, YYYY')} + {/if} +
+ {#each Object.entries(hovered.models || {}) + .sort(([, a], [, b]) => b - a) + .slice(0, 5) as [n, c]} +
+ {n} + {c.toLocaleString()} + ({total > 0 ? ((c / total) * 100).toFixed(0) : 0}%) +
+ {/each} +
+
+ {/if} +
diff --git a/src/lib/components/admin/Analytics/Dashboard.svelte b/src/lib/components/admin/Analytics/Dashboard.svelte new file mode 100644 index 0000000000..98877d1d43 --- /dev/null +++ b/src/lib/components/admin/Analytics/Dashboard.svelte @@ -0,0 +1,464 @@ + + + +
+
+ {$i18n.t('Analytics')} +
+
+ {#if groups.length > 0} + + {/if} + +
+
+ + + + + +{#if !loading} +
+ {summary.total_messages.toLocaleString()} + {$i18n.t('messages')} + + {formatNumber(totalTokens.total)} + {$i18n.t('tokens')} + + {summary.total_chats.toLocaleString()} + {$i18n.t('chats')} + {summary.total_users} + {$i18n.t('users')} +
+ + + {#if dailyStats.length > 1} + {@const allModels = [...new Set(dailyStats.flatMap((d) => Object.keys(d.models || {})))]} + {@const topModels = allModels.slice(0, 8)} + {@const chartColors = [ + '#3b82f6', + '#10b981', + '#f59e0b', + '#ef4444', + '#8b5cf6', + '#ec4899', + '#06b6d4', + '#84cc16' + ]} + {@const periodMap = { '24h': 'hour', '7d': 'week', '30d': 'month', '90d': 'year', all: 'all' }} +
+
+ {$i18n.t(selectedPeriod === '24h' ? 'Hourly Messages' : 'Daily Messages')} +
+ +
+ {/if} +{/if} + +{#if loading} +
+ +
+{:else} +
+ +
+
+ {$i18n.t('Model Usage')} +
+
+ + + + + + + + + + + + {#each sortedModels as model, idx (model.model_id)} + { + selectedModel = { id: model.model_id, name: model.name }; + showModelModal = true; + }} + > + + + + + + + {/each} + {#if sortedModels.length === 0} + + {/if} + +
# toggleModelSort('name')} + > +
+ {$i18n.t('Model')} + {#if modelOrderBy === 'name'} + + {#if modelDirection === 'asc'}{:else}{/if} + + {:else} + + {/if} +
+
toggleModelSort('count')} + > +
+ {$i18n.t('Messages')} + {#if modelOrderBy === 'count'} + + {#if modelDirection === 'asc'}{:else}{/if} + + {:else} + + {/if} +
+
{$i18n.t('Tokens')}%
{idx + 1} +
+ {model.name} + {model.name} +
+
{model.count.toLocaleString()}{formatNumber(tokenStats[model.model_id]?.total_tokens ?? 0)} + {totalModelMessages > 0 + ? ((model.count / totalModelMessages) * 100).toFixed(1) + : 0}% +
{$i18n.t('No data')}
+
+
+ + +
+
+ {$i18n.t('User Activity')} +
+
+ + + + + + + + + + + {#each sortedUsers as user, idx (user.user_id)} + + + + + + + {/each} + {#if sortedUsers.length === 0} + + {/if} + +
# toggleUserSort('name')} + > +
+ {$i18n.t('User')} + {#if userOrderBy === 'name'} + + {#if userDirection === 'asc'}{:else}{/if} + + {:else} + + {/if} +
+
toggleUserSort('count')} + > +
+ {$i18n.t('Messages')} + {#if userOrderBy === 'count'} + + {#if userDirection === 'asc'}{:else}{/if} + + {:else} + + {/if} +
+
{$i18n.t('Tokens')}
{idx + 1} +
+ {user.name + {user.name || user.email || user.user_id.substring(0, 8)} +
+
{user.count.toLocaleString()}{formatNumber(user.total_tokens ?? 0)}
{$i18n.t('No data')}
+
+
+
+ +
+ ⓘ {$i18n.t('Message counts are based on assistant responses.')} +
+{/if} diff --git a/src/lib/components/admin/Analytics/ModelUsage.svelte b/src/lib/components/admin/Analytics/ModelUsage.svelte new file mode 100644 index 0000000000..d24e798c88 --- /dev/null +++ b/src/lib/components/admin/Analytics/ModelUsage.svelte @@ -0,0 +1,155 @@ + + +
+
+ {$i18n.t('Model Usage')} + {totalMessages} {$i18n.t('messages')} +
+
+ +
+ {#if loading} +
+ +
+ {/if} + + {#if !modelStats.length && !loading} +
{$i18n.t('No data found')}
+ {:else if modelStats.length} + + + + + + + + + + + {#each sortedModels as model, idx (model.model_id)} + + + + + + + {/each} + +
# toggleSort('name')} + > +
+ {$i18n.t('Model')} + {#if orderBy === 'name'} + {#if direction === 'asc'}{:else}{/if} + {:else} + + {/if} +
+
toggleSort('count')} + > +
+ {$i18n.t('Messages')} + {#if orderBy === 'count'} + {#if direction === 'asc'}{:else}{/if} + {:else} + + {/if} +
+
{$i18n.t('Share')}
+ {idx + 1} + +
+ {model.name} + {model.name} +
+
+ {model.count.toLocaleString()} + + {((model.count / totalMessages) * 100).toFixed(1)}% +
+ {/if} +
+ +
+
+ ⓘ {$i18n.t('Message counts are based on assistant responses.')} +
+
diff --git a/src/lib/components/admin/Analytics/UserUsage.svelte b/src/lib/components/admin/Analytics/UserUsage.svelte new file mode 100644 index 0000000000..09be24e426 --- /dev/null +++ b/src/lib/components/admin/Analytics/UserUsage.svelte @@ -0,0 +1,145 @@ + + +
+
+ {$i18n.t('User Activity')} + {userStats.length} {$i18n.t('users')} +
+
+ +
+ {#if loading} +
+ +
+ {/if} + + {#if !userStats.length && !loading} +
{$i18n.t('No data found')}
+ {:else if userStats.length} + + + + + + + + + + + {#each sortedUsers as user, idx (user.user_id)} + + + + + + + {/each} + +
# toggleSort('user_id')} + > +
+ {$i18n.t('User')} + {#if orderBy === 'user_id'} + {#if direction === 'asc'}{:else}{/if} + {:else} + + {/if} +
+
toggleSort('count')} + > +
+ {$i18n.t('Messages')} + {#if orderBy === 'count'} + {#if direction === 'asc'}{:else}{/if} + {:else} + + {/if} +
+
{$i18n.t('Share')}
+ {idx + 1} + + + {user.user_id.substring(0, 8)}... + + + {user.count.toLocaleString()} + + {((user.count / totalMessages) * 100).toFixed(1)}% +
+ {/if} +
+ +
+
+ ⓘ {$i18n.t('Showing all messages (user + assistant) per user.')} +
+
diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte index 0ec5678f0e..1dedb94f05 100644 --- a/src/lib/components/admin/Evaluations/Feedbacks.svelte +++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte @@ -302,9 +302,13 @@
{#if feedback.data?.sibling_model_ids} -
- {feedback.data?.model_id} -
+ +
+ {feedback.data?.model_id} +
+
@@ -320,11 +324,13 @@
{:else} -
- {feedback.data?.model_id} -
+ +
+ {feedback.data?.model_id} +
+
{/if}
diff --git a/src/lib/components/admin/Evaluations/Leaderboard.svelte b/src/lib/components/admin/Evaluations/Leaderboard.svelte index e16e62c98c..abe0f952e4 100644 --- a/src/lib/components/admin/Evaluations/Leaderboard.svelte +++ b/src/lib/components/admin/Evaluations/Leaderboard.svelte @@ -182,7 +182,11 @@ alt={model.name} class="size-5 rounded-full object-cover" /> - {model.name} + + {model.name} +
diff --git a/src/lib/components/admin/Evaluations/LeaderboardModal.svelte b/src/lib/components/admin/Evaluations/LeaderboardModal.svelte index fc3ec6eb10..6730e739d7 100644 --- a/src/lib/components/admin/Evaluations/LeaderboardModal.svelte +++ b/src/lib/components/admin/Evaluations/LeaderboardModal.svelte @@ -4,6 +4,7 @@ import { getModelHistory } from '$lib/apis/evaluations'; import ModelActivityChart from './ModelActivityChart.svelte'; import XMark from '$lib/components/icons/XMark.svelte'; + import Tooltip from '$lib/components/common/Tooltip.svelte'; export let show = false; export let model = null; @@ -60,9 +61,11 @@ {#if model}
-
- {model.name} -
+ +
+ {model.name} +
+
diff --git a/src/lib/components/admin/Functions.svelte b/src/lib/components/admin/Functions.svelte index 67a1fbbdfd..48c1863e74 100644 --- a/src/lib/components/admin/Functions.svelte +++ b/src/lib/components/admin/Functions.svelte @@ -4,7 +4,7 @@ const { saveAs } = fileSaver; import { WEBUI_NAME, config, functions as _functions, models, settings, user } from '$lib/stores'; - import { onMount, getContext, tick } from 'svelte'; + import { onMount, getContext, tick, onDestroy } from 'svelte'; import { goto } from '$app/navigation'; import { @@ -53,6 +53,7 @@ let viewOption = ''; let query = ''; + let searchDebounceTimer: ReturnType; let selectedTag = ''; let selectedType = ''; @@ -70,12 +71,14 @@ let functions = null; let filteredItems = []; - $: if ( - functions && - query !== undefined && - selectedType !== undefined && - viewOption !== undefined - ) { + $: if (query !== undefined) { + clearTimeout(searchDebounceTimer); + searchDebounceTimer = setTimeout(() => { + setFilteredItems(); + }, 300); + } + + $: if (functions && selectedType !== undefined && viewOption !== undefined) { setFilteredItems(); } @@ -86,7 +89,10 @@ (selectedType !== '' ? 
f.type === selectedType : true) && (query === '' || f.name.toLowerCase().includes(query.toLowerCase()) || - f.id.toLowerCase().includes(query.toLowerCase())) && + f.id.toLowerCase().includes(query.toLowerCase()) || + (f.user?.name || '').toLowerCase().includes(query.toLowerCase()) || + (f.user?.email || '').toLowerCase().includes(query.toLowerCase()) || + (f.user?.username || '').toLowerCase().includes(query.toLowerCase())) && (viewOption === '' || (viewOption === 'created' && f.user_id === $user?.id) || (viewOption === 'shared' && f.user_id !== $user?.id)) @@ -236,6 +242,10 @@ window.removeEventListener('blur-sm', onBlur); }; }); + + onDestroy(() => { + clearTimeout(searchDebounceTimer); + }); diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 732d824692..3f4434084b 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -362,6 +362,30 @@
+ +
+
+
+ + {$i18n.t('PDF Loader Mode')} + +
+
+ +
+
+
{:else if RAGConfig.CONTENT_EXTRACTION_ENGINE === 'datalab_marker'}
0 ? modelIds : null, filter_mode: modelIds.length > 0 ? (filterMode ? filterMode : null) : null, - access_control: accessControl + access_grants: accessGrants } }; @@ -107,7 +107,7 @@ description = model.meta.description; modelIds = model.meta.model_ids || []; filterMode = model.meta?.filter_mode ?? 'include'; - accessControl = 'access_control' in model.meta ? model.meta.access_control : {}; + accessGrants = model.meta.access_grants ?? []; } }; @@ -293,7 +293,7 @@
- +

diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index d107e09f70..2762d3a112 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -40,6 +40,11 @@ import Eye from '$lib/components/icons/Eye.svelte'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { goto } from '$app/navigation'; + import { DropdownMenu } from 'bits-ui'; + import { flyAndScale } from '$lib/utils/transitions'; + import Dropdown from '$lib/components/common/Dropdown.svelte'; + import AdminViewSelector from './Models/AdminViewSelector.svelte'; + import Pagination from '$lib/components/common/Pagination.svelte'; let shiftKey = false; @@ -58,21 +63,52 @@ let showConfigModal = false; let showManageModal = false; + let viewOption = ''; // '' = All, 'enabled', 'disabled', 'visible', 'hidden' + + const perPage = 30; + let currentPage = 1; + $: if (models) { filteredModels = models .filter((m) => searchValue === '' || m.name.toLowerCase().includes(searchValue.toLowerCase())) + .filter((m) => { + if (viewOption === 'enabled') return m?.is_active ?? true; + if (viewOption === 'disabled') return !(m?.is_active ?? true); + if (viewOption === 'visible') return !(m?.meta?.hidden ?? false); + if (viewOption === 'hidden') return m?.meta?.hidden === true; + return true; // All + }) .sort((a, b) => { - // // Check if either model is inactive and push them to the bottom - // if ((a.is_active ?? true) !== (b.is_active ?? true)) { - // return (b.is_active ?? true) - (a.is_active ?? true); - // } - // If both models' active states are the same, sort alphabetically return (a?.name ?? a?.id ?? '').localeCompare(b?.name ?? b?.id ?? ''); }); } let searchValue = ''; + $: if (searchValue || viewOption !== undefined) { + currentPage = 1; + } + + const enableAllHandler = async () => { + const modelsToEnable = filteredModels.filter((m) => !(m.is_active ?? 
true)); + // Optimistic UI update + modelsToEnable.forEach((m) => (m.is_active = true)); + models = models; + // Sync with server + await Promise.all(modelsToEnable.map((model) => toggleModelById(localStorage.token, model.id))); + }; + + const disableAllHandler = async () => { + const modelsToDisable = filteredModels.filter((m) => m.is_active ?? true); + // Optimistic UI update + modelsToDisable.forEach((m) => (m.is_active = false)); + models = models; + // Sync with server + await Promise.all( + modelsToDisable.map((model) => toggleModelById(localStorage.token, model.id)) + ); + }; + const downloadModels = async (models) => { let blob = new Blob([JSON.stringify(models)], { type: 'application/json' @@ -275,41 +311,48 @@ {#if selectedModelId === null}
-
- {$i18n.t('Models')} - {filteredModels.length} +
+
+ {$i18n.t('Models')} +
+ +
+ {filteredModels.length} +
-
- - - +
+ - - - +
+
-
+
+
@@ -333,155 +376,218 @@ {/if}
-
-
- {#if models.length > 0} - {#each filteredModels as model, modelIdx (`${model.id}-${modelIdx}`)} -
+
+
+ +
+ +
+ + + + + +
+ + { + enableAllHandler(); + }} + > + +
{$i18n.t('Enable All')}
+
+ + { + disableAllHandler(); + }} + > + +
{$i18n.t('Disable All')}
+
+
+
+
+
+ +
+ {#if filteredModels.length > 0} + {#each filteredModels.slice((currentPage - 1) * perPage, currentPage * perPage) as model, modelIdx (`${model.id}-${modelIdx}`)} +
+
-
- -
{model.name}
-
-
- - {!!model?.meta?.description - ? model?.meta?.description - : model?.ollama?.digest - ? `${model.id} (${model?.ollama?.digest})` - : model.id} - + +
{model.name}
+
+
+ + {!!model?.meta?.description + ? model?.meta?.description + : model?.ollama?.digest + ? `${model.id} (${model?.ollama?.digest})` + : model.id} + +
-
- -
- {#if shiftKey} - + +
+ {#if shiftKey} + + + + {:else} - - {:else} - - - { - exportModelHandler(model); - }} - hideHandler={() => { - hideModelHandler(model); - }} - pinModelHandler={() => { - pinModelHandler(model.id); - }} - copyLinkHandler={() => { - copyLinkHandler(model); - }} - cloneHandler={() => { - cloneHandler(model); - }} - onClose={() => {}} - > - - -
- { + exportModelHandler(model); + }} + hideHandler={() => { + hideModelHandler(model); + }} + pinModelHandler={() => { + pinModelHandler(model.id); + }} + copyLinkHandler={() => { + copyLinkHandler(model); + }} + cloneHandler={() => { + cloneHandler(model); + }} + onClose={() => {}} > - { - toggleModelHandler(model); - }} - /> - -
- {/if} + + + +
+ + { + toggleModelHandler(model); + }} + /> + +
+ {/if} +
+
+ {/each} + {:else} +
+
+
😕
+
{$i18n.t('No models found')}
+
+ {$i18n.t('Try adjusting your search or filter to find what you are looking for.')} +
- {/each} - {:else} -
-
- {$i18n.t('No models found')} -
-
+ {/if} +
+ + {#if filteredModels.length > perPage} + {/if}
diff --git a/src/lib/components/admin/Settings/Models/AdminViewSelector.svelte b/src/lib/components/admin/Settings/Models/AdminViewSelector.svelte new file mode 100644 index 0000000000..5778ba21ab --- /dev/null +++ b/src/lib/components/admin/Settings/Models/AdminViewSelector.svelte @@ -0,0 +1,63 @@ + + + item.value === value)} + {items} + onSelectedChange={(selectedItem) => { + value = selectedItem.value; + onChange(value); + }} +> + + + + + + + {#each items as item} + + {item.label} + + {#if value === item.value} +
+ +
+ {/if} +
+ {/each} +
+
diff --git a/src/lib/components/admin/Settings/Models/Manage/ManageMultipleOllama.svelte b/src/lib/components/admin/Settings/Models/Manage/ManageMultipleOllama.svelte index 54439d16a0..90f1669368 100644 --- a/src/lib/components/admin/Settings/Models/Manage/ManageMultipleOllama.svelte +++ b/src/lib/components/admin/Settings/Models/Manage/ManageMultipleOllama.svelte @@ -22,5 +22,7 @@
- +
+ +
{/if} diff --git a/src/lib/components/admin/Settings/Models/ModelList.svelte b/src/lib/components/admin/Settings/Models/ModelList.svelte index cc86e52e5f..d501a485d4 100644 --- a/src/lib/components/admin/Settings/Models/ModelList.svelte +++ b/src/lib/components/admin/Settings/Models/ModelList.svelte @@ -50,7 +50,7 @@
-
+
{#if $models.find((model) => model.id === modelId)} {$models.find((model) => model.id === modelId).name} {:else} diff --git a/src/lib/components/admin/Settings/WebSearch.svelte b/src/lib/components/admin/Settings/WebSearch.svelte index e91a110f81..4778b69ec5 100644 --- a/src/lib/components/admin/Settings/WebSearch.svelte +++ b/src/lib/components/admin/Settings/WebSearch.svelte @@ -7,6 +7,7 @@ import { toast } from 'svelte-sonner'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; + import Textarea from '$lib/components/common/Textarea.svelte'; const i18n = getContext('i18n'); @@ -35,7 +36,8 @@ 'perplexity', 'sougou', 'firecrawl', - 'external' + 'external', + 'yandex' ]; let webLoaderEngines = ['playwright', 'firecrawl', 'tavily', 'external']; @@ -735,6 +737,55 @@ />
+ {:else if webConfig.WEB_SEARCH_ENGINE === 'yandex'} +
+
+
+ {$i18n.t('Yandex Web Search URL')} +
+ +
+
+ +
+
+
+ +
+
+ {$i18n.t('Yandex Web Search API Key')} +
+ + +
+ +
+
{$i18n.t('Yandex Web Search config')}
+ + +