diff --git a/.github/workflows/format-build-frontend.yaml b/.github/workflows/format-build-frontend.yaml index 0f8b1322b9..822bca7bf4 100644 --- a/.github/workflows/format-build-frontend.yaml +++ b/.github/workflows/format-build-frontend.yaml @@ -34,7 +34,9 @@ jobs: node-version: '22' - name: Install Dependencies - run: npm install --force + run: | + npm install -g npm@latest + npm install --force - name: Format Frontend run: npm run format @@ -61,7 +63,9 @@ jobs: node-version: '22' - name: Install Dependencies - run: npm ci --force + run: | + npm install -g npm@latest + npm ci --force - name: Run vitest run: npm run test:frontend diff --git a/CHANGELOG.md b/CHANGELOG.md index 4915f28376..b63e1e4e21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,130 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.7.1] - 2026-01-09 + +### Fixed + +- ⚡ **Improved reliability for low-spec and SQLite deployments.** Fixed page timeouts by disabling database session sharing by default, improving stability for resource-constrained environments. Users can re-enable via 'DATABASE_ENABLE_SESSION_SHARING=true' if needed. [#20520](https://github.com/open-webui/open-webui/issues/20520) + +## [0.7.0] - 2026-01-09 + +### Added + +- 🤖 **Native Function Calling with Built-in Tools.** Users can now ask models to perform multi-step tasks that combine web research, knowledge base queries, note-taking, and image generation in a single conversation—for example, "research the latest on X, save key findings to a note, and generate an infographic." Requires models with native function calling support and function calling mode set to "Native" in Chat Controls. 
[#19397](https://github.com/open-webui/open-webui/issues/19397), [Commit](https://github.com/open-webui/open-webui/commit/5c1d52231a3997a17381c48639bd7e339262cf7c) +- 🧠 Users can now ask the model to find relevant context from their notes, past chats, and channel messages—for example, "what did I discuss about project X last week?" or "find the conversation where I brainstormed ideas for Y." [Commit](https://github.com/open-webui/open-webui/commit/646835d76744ad9b2e67ede0407a61d62e969aab) +- 📚 Users can now ask the model to search their knowledge bases and retrieve documents without manually attaching files—for example, "find the section about authentication in our API docs" or "what do our internal guidelines say about X?" [Commit](https://github.com/open-webui/open-webui/commit/c8622adcb01f3091b17ca50f8c8e2f20c7b9cd2a) +- 💭 Users with models that support interleaved thinking now get more refined results from multi-step workflows, as the model can analyze each tool's output before deciding what to do next. +- 🔍 When models invoke web search, search results appear as clickable citations in real-time for full source verification. [Commit](https://github.com/open-webui/open-webui/commit/2789f6a24d8405c30cd48ae460071f6a4f2c35f9) +- 🎚️ Users can selectively disable specific built-in tools (timestamps, memory, chat history, notes, web search, knowledge bases) per model via the model editor's capabilities settings. [Commit](https://github.com/open-webui/open-webui/commit/60e916d6c0c5f7db9e6d670e12be2d1d4abc2dd6) +- 👁️ Pending tool calls are now displayed during response generation, so users know which tools are being invoked. [Commit](https://github.com/open-webui/open-webui/commit/1d08376860e775049abd1dd5f568ac0c6466c944) +- 📁 Administrators can now limit the number of files that can be uploaded to folders using the "FOLDER_MAX_FILE_COUNT" setting, preventing resource exhaustion from bulk uploads. 
[#19810](https://github.com/open-webui/open-webui/issues/19810), [Commit](https://github.com/open-webui/open-webui/commit/a1036e544d573e3d35e05c1c2472ba762c32431b), [Commit](https://github.com/open-webui/open-webui/commit/d3ee3fd23e762c9d83fe1da5636d03259e186e57) +- ⚡ Users experience transformative speed improvements across the entire application through completely reengineered database connection handling, delivering noticeably faster page loads, butter-smooth interactions, and rock-solid stability during intensive operations like user management and bulk data processing. [Commit](https://github.com/open-webui/open-webui/commit/2041ab483e21b3a757baa25c47dc2fa29018674f), [Commit](https://github.com/open-webui/open-webui/commit/145c7516f227ce56fd52373cee86217fadf16181), [Commit](https://github.com/open-webui/open-webui/commit/475dd91ed798f2efcdf27799d5c5cae3f0e6e847), [Commit](https://github.com/open-webui/open-webui/commit/5d1459df166cce8445eb1556cc26abbd65a3f9f4), [Commit](https://github.com/open-webui/open-webui/commit/2453b75ff0fb2dc75b929d96f417d84e332922d9), [Commit](https://github.com/open-webui/open-webui/commit/5649a668fad15393a52c27a2f188841af8b66989) +- 🚀 Users experience significantly faster initial page load times through dynamic loading of document processing libraries, reducing the initial bundle size. [#20200](https://github.com/open-webui/open-webui/pull/20200), [#20202](https://github.com/open-webui/open-webui/pull/20202), [#20203](https://github.com/open-webui/open-webui/pull/20203), [#20204](https://github.com/open-webui/open-webui/pull/20204) +- 💨 Administrators experience dramatically faster user list loading through optimized database queries that eliminate N+1 query patterns, reducing query count from 1+N to just 2 total queries regardless of user count. 
[#20427](https://github.com/open-webui/open-webui/pull/20427) +- 📋 Notes now load faster through optimized database queries that batch user lookups instead of fetching each note's author individually. [Commit](https://github.com/open-webui/open-webui/commit/084f0ef6a5491e186bf6b71c6386973ba18ef2fa) +- 💬 Channel messages, pinned messages, and thread replies now load faster through batched user lookups instead of individual queries per message. [#20458](https://github.com/open-webui/open-webui/pull/20458), [#20459](https://github.com/open-webui/open-webui/pull/20459), [#20460](https://github.com/open-webui/open-webui/pull/20460) +- 🔗 Users can now click citation content links to jump directly to the relevant portion of source documents with automatic text highlighting, making it easier to verify AI responses against their original sources. [#20116](https://github.com/open-webui/open-webui/pull/20116), [Commit](https://github.com/open-webui/open-webui/commit/40c45ffe1f9b45538d32c8ecba8cac62c6eca503) +- 📌 Users can now pin or hide models directly from the Workspace Models page and Admin Settings Models page, making it easier to manage which models appear in the sidebar without switching to the chat interface. [#20176](https://github.com/open-webui/open-webui/pull/20176) +- 🔎 Administrators can now quickly find settings using the new search bar in the Admin Settings sidebar, which supports fuzzy filtering by category names and related keywords like "whisper" for Audio or "rag" for Documents. [#20434](https://github.com/open-webui/open-webui/pull/20434) +- 🎛️ Users can now view read-only models in the workspace models list, with clear "Read Only" badges indicating when editing is restricted. [#20243](https://github.com/open-webui/open-webui/issues/20243), [#20369](https://github.com/open-webui/open-webui/pull/20369) +- 📝 Users can now view read-only prompts in the workspace prompts list, with clear "Read Only" badges indicating when editing is restricted. 
[#20368](https://github.com/open-webui/open-webui/pull/20368) +- 🔧 Users can now view read-only tools in the workspace tools list, with clear "Read Only" badges indicating when editing is restricted. [#20243](https://github.com/open-webui/open-webui/issues/20243), [#20370](https://github.com/open-webui/open-webui/pull/20370) +- 📂 Searching for files is now significantly faster, especially for users with large file collections. [Commit](https://github.com/open-webui/open-webui/commit/a9a979fb3db1743553ca0705f571c0b9c252841f) +- 🏆 The Evaluations leaderboard now calculates Elo ratings on the backend instead of in the browser, improving performance and enabling topic-based model ranking through semantic search. [#15392](https://github.com/open-webui/open-webui/pull/15392), [#20476](https://github.com/open-webui/open-webui/issues/20476), [Commit](https://github.com/open-webui/open-webui/commit/10838b3654bf6fdef02d57311f7f1c01df4cd033) +- 📊 The Evaluations leaderboard now includes a per-model activity chart displaying daily wins and losses as a diverging bar chart, with 30-day, 1-year, and all-time views using weekly aggregation for longer timeframes. +- 🎞️ Users can now upload animated GIF and WebP formats as model profile images, with animation preserved by skipping resize processing for these file types. [Commit](https://github.com/open-webui/open-webui/commit/00af37bb4ed1ea0957c7c84a6a8def3a7998b8ca) +- 📸 Users uploading profile images for users, models, and arena models now benefit from WebP compression at 80% quality instead of JPEG, resulting in significantly smaller file sizes and faster uploads while maintaining visual quality. [Commit](https://github.com/open-webui/open-webui/commit/b1d30673b69571e081abb881d34944cf33cdc67e) +- ⭐ Action Function developers can now update message favorite status using the new "chat:message:favorite" event, enabling the development of pin/unpin message actions without race conditions from frontend auto-save. 
[#20375](https://github.com/open-webui/open-webui/pull/20375) +- 🌐 Users with OpenAI-compatible models that have web search capabilities now see URL citations displayed as sources in the interface. [#20172](https://github.com/open-webui/open-webui/pull/20172), [Commit](https://github.com/open-webui/open-webui/commit/fe84afd09a2bc8a89311f30186ec2608a4edda3a) +- 📰 Users can now dismiss the "What's New" changelog modal permanently using the X button, matching the behavior of the "Okay, Let's Go!" button. [#20258](https://github.com/open-webui/open-webui/pull/20258) +- 📧 Administrators can now configure the admin contact email displayed in the Account Pending overlay directly from the Admin Panel instead of only through environment variables. [#12500](https://github.com/open-webui/open-webui/issues/12500), [#20260](https://github.com/open-webui/open-webui/pull/20260) +- 📄 Administrators can now enable markdown header text splitting as a preprocessing step that works with either character or token splitting, through the new "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER" setting. [Commit](https://github.com/open-webui/open-webui/commit/d3a682759f375c6cb0bc8c170f62863a070f712e), [Commit](https://github.com/open-webui/open-webui/commit/18a33a079bf07487edffc421a721c86194cc90c9), [Commit](https://github.com/open-webui/open-webui/commit/08bf4670ec862018f9dc57296cb19fd5eab14ef6) +- 🧩 Administrators can now set a minimum chunk size target using the "CHUNK_MIN_SIZE_TARGET" setting to merge small markdown header chunks with neighbors, which improves retrieval quality by eliminating tiny meaningless fragments, significantly speeds up document processing and embedding, reduces storage costs, and lowers embedding API costs or local compute requirements. 
[#19595](https://github.com/open-webui/open-webui/issues/19595), [#20314](https://github.com/open-webui/open-webui/pull/20314), [Commit](https://github.com/open-webui/open-webui/commit/c32435958073cf002d87e78544baa88bc4e15d7f) +- 💨 Administrators can now enable KV prefix caching optimization by setting "RAG_SYSTEM_CONTEXT" to true, which injects RAG context into the system message instead of user messages, enabling models to reuse cached tokens for follow-up questions instead of reprocessing the entire context on each turn, significantly improving response times and reducing costs for cloud-based models. [#20301](https://github.com/open-webui/open-webui/discussions/20301), [#20317](https://github.com/open-webui/open-webui/pull/20317) +- 🖼️ Administrators and Action developers can now control image generation denoising steps per-request using a steps parameter, allowing Actions and API calls to override the global IMAGE_STEPS configuration for both ComfyUI and Automatic1111 engines. [#20337](https://github.com/open-webui/open-webui/pull/20337) +- 🗄️ Administrators running multi-pod deployments can now designate a master pod to handle database migrations using the "ENABLE_DB_MIGRATIONS" environment variable. [Commit](https://github.com/open-webui/open-webui/commit/9824f0e33359a917ac07b60bf1f972074d5c8203) +- 🎙️ Administrators can now configure Whisper's compute type using the "WHISPER_COMPUTE_TYPE" environment variable to fix compatibility issues with CUDA/GPU deployments. [Commit](https://github.com/open-webui/open-webui/commit/26af1f92e21ddfd08348570bf54a6f345ac69648) +- 🔍 Administrators can now control sigmoid normalization for CrossEncoder reranking models using the "SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION" environment variable, enabled by default for proper relevance threshold behavior with MS MARCO models. 
[#20228](https://github.com/open-webui/open-webui/pull/20228) +- 🔒 Administrators can now disable SSL certificate verification for external tools using the "REQUESTS_VERIFY" environment variable, enabling integration with self-signed certificates for Tika, Ollama embeddings, and external rerankers. [#19968](https://github.com/open-webui/open-webui/issues/19968), [Commit](https://github.com/open-webui/open-webui/commit/dfc5dad63167eabb7fb027e63c324675b23f2e9d) +- 📈 Administrators can now control audit log output destinations using "ENABLE_AUDIT_STDOUT" and "ENABLE_AUDIT_LOGS_FILE" environment variables, allowing audit logs to be sent to container logs for centralized logging systems. [#20114](https://github.com/open-webui/open-webui/pull/20114), [Commit](https://github.com/open-webui/open-webui/commit/fdae5644e36972384b3e2513e3074f95f9f7381f) +- 🛡️ Administrators can now restrict non-admin user access to Interface Settings through per-user or per-group permissions. [#20424](https://github.com/open-webui/open-webui/pull/20424) +- 🧠 Administrators can now globally enable or disable the Memories feature and control access through per-user or per-group permissions, with the Personalization tab automatically hidden when the feature is disabled. [#20462](https://github.com/open-webui/open-webui/pull/20462) +- 🟢 Administrators can now globally enable or disable user status visibility through the "ENABLE_USER_STATUS" setting in Admin Settings. [#20488](https://github.com/open-webui/open-webui/pull/20488) +- 🪝 Channel managers can now create webhooks to allow external services to post messages to channels without authentication. 
[Commit](https://github.com/open-webui/open-webui/commit/cd296fcf0d79cecd1a6a3ee4e492c6b5246ca7ae) +- 📄 In the model editor users can now disable the "File Context" capability to skip automatic file content extraction and injection, forwarding raw messages with file attachment metadata instead for use with custom tools or future built-in file access tools. [Commit](https://github.com/open-webui/open-webui/commit/daccf0713e3ecd6d24f003a87b5f8b3c61852958), [Docs:Commit](https://github.com/open-webui/docs/commit/18ec6eaefc071a278ec57d4d1b8d66d686af0870) +- 🔊 In the model editor users can now configure a specific TTS voice for each model, overriding user preferences and global defaults to give different AI personas distinct voices. [#3097](https://github.com/open-webui/open-webui/issues/3097), [Commit](https://github.com/open-webui/open-webui/commit/bb6188abf04302f79d80b0d6cc42c232624b5757) +- 👥 Administrators now have three granular group sharing permission options instead of a simple on/off toggle, allowing them to choose between "No one", "Members", or "Anyone" for who can share content to each group. [Commit](https://github.com/open-webui/open-webui/commit/ca514cd3eda2524b8da472ef17c0ccb216bac2e8) +- 📦 Administrators can now export knowledge bases as zip files containing text files for backup and archival purposes. [#20120](https://github.com/open-webui/open-webui/issues/20120), [Commit](https://github.com/open-webui/open-webui/commit/c1147578c073a8c7fa7e7f836149e1cdfec8f18d) +- 🚀 Administrators can now create an admin account automatically at startup via "WEBUI_ADMIN_EMAIL", "WEBUI_ADMIN_PASSWORD", and "WEBUI_ADMIN_NAME" environment variables, enabling headless and automated deployments without exposing the signup page. 
[#17654](https://github.com/open-webui/open-webui/issues/17654), [Commit](https://github.com/open-webui/open-webui/commit/1138929f4d083931305f1f925899971b190562ae) +- 🦆 Administrators can now select a specific search backend for DDGS instead of random selection, with options including Bing, Brave, DuckDuckGo, Google, Wikipedia, Yahoo, and others. [#20330](https://github.com/open-webui/open-webui/issues/20330), [#20366](https://github.com/open-webui/open-webui/pull/20366) +- 🧭 Administrators can now configure custom Jina Search API endpoints using the "JINA_API_BASE_URL" environment variable, enabling region-specific deployments such as EU data processing. [#19718](https://github.com/open-webui/open-webui/pull/19718), [Commit](https://github.com/open-webui/open-webui/commit/f7f8a263b92289df8d4f8dbc3bae09bd009a5699) +- 🔥 Administrators can now configure Firecrawl timeout values using the "FIRECRAWL_TIMEOUT" environment variable to control web scraping wait times. [#19973](https://github.com/open-webui/open-webui/pull/19973), [Commit](https://github.com/open-webui/open-webui/commit/89ad1c68d1aadf849960b5e202aa4651096b05f5) +- 💾 Administrators can now use openGauss as the vector database backend for knowledge base document storage and retrieval. [#20179](https://github.com/open-webui/open-webui/pull/20179) +- 🔄 Various improvements were implemented across the application to enhance performance, stability, and security. +- 📊 Users can now sync their anonymous usage statistics to the Open WebUI Community platform to power community leaderboards, drive model evaluations, and contribute to open-source AI research that benefits everyone, all while keeping conversations completely private (only metadata like model names, message counts, and ratings are shared). By sharing your stats, you're helping the community identify which models perform best, contributing to transparent AI benchmarking, and supporting the collective effort to make AI better for all. 
You can also download your stats as JSON for personal analysis. +- 🌐 Translations for German, Portuguese (Brazil), Spanish, Simplified Chinese, Traditional Chinese, and Polish were enhanced and expanded. + +### Fixed + +- 🔊 Text-to-speech now correctly splits on newlines in addition to punctuation, so markdown bullet points and lists are spoken as separate sentences instead of being merged together. [#5924](https://github.com/open-webui/open-webui/issues/5924), [Commit](https://github.com/open-webui/open-webui/commit/869108a3e1ce2b8110084113c1b392072e98fd5f) +- 🔒 Users are now protected from stored XSS vulnerabilities in iFrame embeds for citations and response messages through configurable same-origin sandbox settings instead of hardcoded values. [#20209](https://github.com/open-webui/open-webui/pull/20209), [#20210](https://github.com/open-webui/open-webui/pull/20210) +- 🔑 Image Generation, Web Search, and Audio (TTS/STT) API endpoints now enforce permission checks on the backend, closing a security gap where disabled features could previously be accessed via direct API calls. [#20471](https://github.com/open-webui/open-webui/pull/20471) +- 🛠️ Tools and Tool Servers (MCP and OpenAPI) now enforce access control checks on the backend, ensuring users can only access tools they have permission to use even via direct API calls. [#20443](https://github.com/open-webui/open-webui/issues/20443), [Commit](https://github.com/open-webui/open-webui/commit/9b06fdc8fe1c933071610336be05f11e77e6c8eb) +- 🔁 System prompts are no longer duplicated when using native function calling, fixing an issue where the prompt would be applied twice during tool-calling workflows. [Commit](https://github.com/open-webui/open-webui/commit/9223efaff0db6e56bfa157ef214d9590005156d2) +- 🗂️ Knowledge base uploads to folders no longer fail when "FOLDER_MAX_FILE_COUNT" is unset, fixing an issue where the default null value caused all uploads to error. 
[Commit](https://github.com/open-webui/open-webui/commit/ef9cd0e0ad6e45b8a3efec6f3858b3d69d42f619) +- 📝 The "Create Note" button in the chat input now correctly hides for users without Notes permissions instead of showing and returning a 401 error when clicked. [#20486](https://github.com/open-webui/open-webui/issues/20486), [Commit](https://github.com/open-webui/open-webui/commit/9e9616b670c1c4389193b18500a7d80d86d7e280) +- 📊 The Evaluations page no longer crashes when administrators have large amounts of feedback data, as the leaderboard now fetches only the minimal required fields instead of loading entire conversation snapshots. [#20476](https://github.com/open-webui/open-webui/issues/20476), [#20489](https://github.com/open-webui/open-webui/pull/20489), [Commit](https://github.com/open-webui/open-webui/commit/b2a1f71d920e55b143f1c02e61104938d2588762) +- 💬 Users can now export chats, use the Ask/Explain popup, and view chat lists correctly again after these features were broken by recent refactoring changes that caused 500 and 400 server errors. [#20146](https://github.com/open-webui/open-webui/issues/20146), [#20205](https://github.com/open-webui/open-webui/issues/20205), [#20206](https://github.com/open-webui/open-webui/issues/20206), [#20212](https://github.com/open-webui/open-webui/pull/20212) +- 💭 Users no longer experience data corruption when switching between chats during background operations like image generation, where messages from one chat would appear in another chat's history. [#20266](https://github.com/open-webui/open-webui/pull/20266) +- 🛡️ Users no longer encounter critical chat stability errors, including duplicate key errors from circular message dependencies, null message access during chat loading, and errors in the chat overview visualization. 
[#20268](https://github.com/open-webui/open-webui/pull/20268) +- 📡 Users with Channels no longer experience infinite recursion and connection pool exhaustion when fetching threaded replies, preventing RecursionError crashes during chat history loading. [#20299](https://github.com/open-webui/open-webui/pull/20299), [Commit](https://github.com/open-webui/open-webui/commit/c144122f608759c2b79472e1f6948a7c1600a3d1) +- 📎 Users no longer encounter TypeError crashes when viewing messages with file attachments that have undefined URL properties. [#20343](https://github.com/open-webui/open-webui/pull/20343) +- 🔐 Users with MCP integrations now experience reliable OAuth 2.1 token refresh after access token expiration through proper Protected Resource discovery, preventing integration failures that caused sessions to be deleted. [#19794](https://github.com/open-webui/open-webui/issues/19794), [#20138](https://github.com/open-webui/open-webui/pull/20138), [#20291](https://github.com/open-webui/open-webui/issues/20291), [Commit](https://github.com/open-webui/open-webui/commit/bf2b2962399e341926bdbf9e0a82101f31a90b23), [Commit](https://github.com/open-webui/open-webui/commit/89565c58c6ae6b5b129559ef68b5a0c18c110765) +- 📚 Users who belong to multiple groups can now see Knowledge Bases shared with those groups, fixing an issue where they would disappear when shared with more than one group. [#20124](https://github.com/open-webui/open-webui/issues/20124), [#20229](https://github.com/open-webui/open-webui/issues/20229), [Commit](https://github.com/open-webui/open-webui/commit/61e25dc2dce9c12dcb5b88a6b814060c4338e67b) +- 📂 Users now see the correct Knowledge Base name when hovering over # file references in chat input instead of "undefined". 
[#20329](https://github.com/open-webui/open-webui/issues/20329), [#20333](https://github.com/open-webui/open-webui/pull/20333) +- 📋 Users now see notes displayed in correct chronological order within their time range groupings, fixing an issue where insertion order was not preserved. [Commit](https://github.com/open-webui/open-webui/commit/3f577c0c3fbfd9f09c02940e4ae474f987149277) +- 📑 Users collaborating on notes now experience proper content sync when initializing from both HTML and JSON formats, fixing sync failures in collaborative editing sessions. [Commit](https://github.com/open-webui/open-webui/commit/e27fb3e291a735c715a089a80e7a49d2c2209096) +- 🔎 Users searching notes can now find hyphenated words and variations with spaces, so searching "todo" now finds "to-do" and "to do". [Commit](https://github.com/open-webui/open-webui/commit/a3270648d8b8535443d8ce2ea719f8e678e4e358) +- 📥 Users no longer experience false duplicate file warnings when reuploading files after initial processing failed, as the file hash is now only stored after successful processing completion. [#19264](https://github.com/open-webui/open-webui/issues/19264), [#20282](https://github.com/open-webui/open-webui/pull/20282), [Commit](https://github.com/open-webui/open-webui/commit/d3ab9f4b96eee7f91c9b1355cee055fdabca9730) +- 💾 Users experience significantly improved page load performance as model profile images now cache properly in browsers, avoiding unnecessary image refetches. [Commit](https://github.com/open-webui/open-webui/commit/bb821ab654e93908a3b4632c359753eeff053264) +- 🎨 Users can now successfully edit uploaded images instead of having new images generated, fixing an issue introduced by the file storage refactor where images with type "file" and content_type starting with "image/" weren't being recognized as editable images. 
[#20237](https://github.com/open-webui/open-webui/issues/20237), [#20169](https://github.com/open-webui/open-webui/pull/20169), [#20239](https://github.com/open-webui/open-webui/pull/20239), [Commit](https://github.com/open-webui/open-webui/commit/1148d1c927d096e14917b6d762789fca3188f281) +- 🌐 Users writing in Persian and Arabic now see properly displayed right-to-left text in the notes section through automatic text direction detection. [#19743](https://github.com/open-webui/open-webui/issues/19743), [#20102](https://github.com/open-webui/open-webui/pull/20102), [Commit](https://github.com/open-webui/open-webui/commit/b619a157bc54c5bc44d223d2ae3acb9ce4ac6a6c) +- 🤖 Users can now successfully @ mention models in Channels instead of experiencing silent failures. [Commit](https://github.com/open-webui/open-webui/commit/59957715836acb635f4b1c4ddbfb4ba7b82b3281) +- 📋 Users on Windows now see correctly preserved line breaks when using the {{CLIPBOARD}} variable through CRLF to LF normalization. [#19370](https://github.com/open-webui/open-webui/issues/19370), [#20283](https://github.com/open-webui/open-webui/pull/20283) +- 📁 Users now see the Knowledge Selector dropdown correctly displayed above the Create Folder modal instead of being hidden behind it. [#20219](https://github.com/open-webui/open-webui/issues/20219), [#20213](https://github.com/open-webui/open-webui/pull/20213) +- 🌅 Users now see profile images in non-PNG formats like SVG, JPEG, and GIF displayed correctly instead of appearing broken. [#20171](https://github.com/open-webui/open-webui/pull/20171) +- 🆕 Non-admin users with disabled temporary chat permissions can now successfully create new chats and use pinned models from the sidebar. 
[#20336](https://github.com/open-webui/open-webui/issues/20336), [#20367](https://github.com/open-webui/open-webui/pull/20367), [Commit](https://github.com/open-webui/open-webui/commit/e754940c031f9689fb4f6edb3625aa06aeb53377) +- 🎛️ Users can now successfully use workspace models in chat, fixing "Model not found" errors that occurred when using custom model presets. [#20340](https://github.com/open-webui/open-webui/issues/20340), [#20344](https://github.com/open-webui/open-webui/pull/20344), [Commit](https://github.com/open-webui/open-webui/commit/b55a46ae99c32068ed306a5ecdaafa9f75504cd7), [Commit](https://github.com/open-webui/open-webui/commit/2bb13d5dbc6e233856e8aa26143222ceda8f6c11) +- 🔁 Users can now regenerate messages without crashes when the parent message is missing or corrupted in the chat history. [#20264](https://github.com/open-webui/open-webui/pull/20264) +- ✏️ Users no longer experience TipTap rich text editor crashes with "editor view is not available" errors when plugins or async methods try to access the editor after it has been destroyed. [#20266](https://github.com/open-webui/open-webui/pull/20266) +- 📗 Administrators with bypass access control enabled now correctly have write access to all knowledge bases. [#20371](https://github.com/open-webui/open-webui/pull/20371) +- 🔍 Administrators using local CrossEncoder reranking models now see proper relevance threshold behavior through MS MARCO model score normalization to the 0-1 range via sigmoid activation. [#19999](https://github.com/open-webui/open-webui/issues/19999), [#20228](https://github.com/open-webui/open-webui/pull/20228) +- 🎯 Administrators using local SentenceTransformers embedding engine now benefit from proper batch size settings, preventing excessive memory usage from the default batch size of 32. 
[#20053](https://github.com/open-webui/open-webui/issues/20053), [#20054](https://github.com/open-webui/open-webui/pull/20054), [Commit](https://github.com/open-webui/open-webui/commit/e4a5b06ca68303512678b4d2dc296bc78b9f983f) +- 🔧 Administrators and users in offline mode or restricted environments like uv, poetry, and NixOS no longer experience crashes when Tools and Functions have frontmatter requirements, as pip installation is now skipped when offline mode is enabled. [#20320](https://github.com/open-webui/open-webui/issues/20320), [#20321](https://github.com/open-webui/open-webui/pull/20321), [Commit](https://github.com/open-webui/open-webui/commit/bd07ef8) +- 📄 Administrators can now properly configure the MinerU document parsing service as the MinerU Cloud API key field is now available in the Admin Panel Documents settings. [#20319](https://github.com/open-webui/open-webui/issues/20319), [#20328](https://github.com/open-webui/open-webui/pull/20328) +- ⚠️ Administrators no longer see SyntaxWarnings for invalid escape sequences in password validation regex patterns. [#20298](https://github.com/open-webui/open-webui/pull/20298), [Commit](https://github.com/open-webui/open-webui/commit/e55bf2c2ac391caed871d41f0484820091081908) +- 🎨 Users with ComfyUI workflows now see only the intended final output images in chat instead of duplicate images from intermediate processing nodes like masks, crops, or segmentation previews. [#20158](https://github.com/open-webui/open-webui/issues/20158), [#20182](https://github.com/open-webui/open-webui/pull/20182) +- 🖼️ Users with image generation enabled no longer see false vision capability warnings, allowing them to send follow-up messages after generating images and to send images to non-vision models for image editing. 
[#20129](https://github.com/open-webui/open-webui/issues/20129), [#20256](https://github.com/open-webui/open-webui/pull/20256) +- 🔌 Administrators no longer experience infinite loading screens when invalid or MCP-style configurations are used with OpenAPI connection types for external tools. [#20207](https://github.com/open-webui/open-webui/issues/20207), [#20257](https://github.com/open-webui/open-webui/pull/20257) +- 📥 Administrators no longer encounter TypeError crashes during SHA256 verification when uploading GGUF models via URL, fixing 500 Internal Server Error crashes. [#20263](https://github.com/open-webui/open-webui/issues/20263) +- 🚦 Users with Brave Search now experience automatic retry with a 1-second delay when hitting rate limits, preventing failures when sequential requests exceed the 1 request per second limit, though this only works reliably when web search concurrency is set to a maximum of 1. [#15134](https://github.com/open-webui/open-webui/issues/15134), [#20255](https://github.com/open-webui/open-webui/pull/20255) +- 🗄️ Administrators with Redis Sentinel deployments no longer experience crashes during websocket disconnections due to improper async-generator handling in the YDocManager. [#20142](https://github.com/open-webui/open-webui/issues/20142), [#20145](https://github.com/open-webui/open-webui/pull/20145) +- 🔐 Administrators using SCIM group management no longer encounter 500 errors when working with groups that have no members. [#20187](https://github.com/open-webui/open-webui/pull/20187) +- 🔗 Users now experience more reliable citations from AI models, especially when using smaller or weaker models that may not format citation references perfectly. [Commit](https://github.com/open-webui/open-webui/commit/c0ec04935b4eea3d334bfdec2fc41278f1085a49) +- 🕸️ Administrators can now successfully save WebSearch settings without encountering validation errors for domain filter lists, YouTube language settings, or timeout values. 
[#20422](https://github.com/open-webui/open-webui/pull/20422) +- 📦 Administrators installing with the uv package manager now experience successful installation after deprecated dependencies that were causing conflicts were removed. [#20177](https://github.com/open-webui/open-webui/issues/20177), [#20192](https://github.com/open-webui/open-webui/pull/20192) +- ⏱️ Administrators using custom "AIOHTTP_CLIENT_TIMEOUT" settings now see the configured timeout correctly applied to embedding generation, OAuth discovery, webhook calls, and tool/function loading instead of falling back to the default 300-second timeout. [Commit](https://github.com/open-webui/open-webui/commit/e67891a374625d9888ec391da561f0b4ed79ed5d) + +### Changed + +- ⚠️ This release includes a major overhaul of database connection handling in the backend that requires all instances in multi-worker, multi-server, or load-balanced deployments to be updated simultaneously; running mixed versions will cause failures due to incompatible database connection management between old and new instances. +- 📝 Administrators who previously used the standalone "Markdown (Header)" text splitter must now switch to "character" or "token" mode with the new "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER" toggle enabled, as document chunking now applies markdown header splitting as a preprocessing step before character or token splitting. [Commit](https://github.com/open-webui/open-webui/commit/d3a682759f375c6cb0bc8c170f62863a070f712e), [Commit](https://github.com/open-webui/open-webui/commit/18a33a079bf07487edffc421a721c86194cc90c9), [Commit](https://github.com/open-webui/open-webui/commit/08bf4670ec862018f9dc57296cb19fd5eab14ef6) +- 🖼️ Users no longer see the "Generate Image" action button in chat message interfaces; custom function should be used. 
[Commit](https://github.com/open-webui/open-webui/commit/f0829ba6e6fd200702fb76efc43dd785cf87fec9) +- 🔗 Administrators will find the Admin Evaluations page at the new URL "/admin/evaluations/feedback" instead of "/admin/evaluations/feedbacks" to use the correct uncountable form of the word. [#20296](https://github.com/open-webui/open-webui/pull/20296) +- 🔐 Scripts or integrations that directly called Image Generation, Web Search, or Audio APIs while those features were disabled in the Admin UI will now receive 403 Forbidden errors, as backend permission enforcement has been added to match frontend restrictions. [#20471](https://github.com/open-webui/open-webui/pull/20471) +- 👥 The default group sharing permission changed from "Members" to "Anyone", meaning users can now share content to any group configured with "Anyone" permission regardless of their membership in that group. [Commit](https://github.com/open-webui/open-webui/commit/ca514cd3eda2524b8da472ef17c0ccb216bac2e8) + ## [0.6.43] - 2025-12-22 ### Fixed diff --git a/CHANGELOG_EXTRA.md b/CHANGELOG_EXTRA.md index 7fb8788b30..a0c1f25a93 100644 --- a/CHANGELOG_EXTRA.md +++ b/CHANGELOG_EXTRA.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.7.1.1] - 2026.01.10 + +### Changed + +- 合并官方 0.7.1 改动 + ## [0.6.43.2] - 2025.12.22 ### Fixed diff --git a/Dockerfile b/Dockerfile index 52a910906e..ca9fe71e73 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,6 +13,7 @@ ARG USE_CUDA_VER=cu128 # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. 
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_RERANKING_MODEL="" +ARG USE_AUXILIARY_EMBEDDING_MODEL=TaylorAI/bge-micro-v2 # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base" @@ -35,6 +36,7 @@ WORKDIR /app RUN apk add --no-cache git COPY package.json package-lock.json ./ +RUN npm install -g npm@latest RUN npm ci --force COPY . . @@ -42,7 +44,7 @@ ENV APP_BUILD_HASH=${BUILD_HASH} RUN npm run build ######## WebUI backend ######## -FROM python:3.11-slim-bookworm AS base +FROM python:3.11.14-slim-bookworm AS base # Use args ARG USE_CUDA @@ -52,6 +54,7 @@ ARG USE_SLIM ARG USE_PERMISSION_HARDENING ARG USE_EMBEDDING_MODEL ARG USE_RERANKING_MODEL +ARG USE_AUXILIARY_EMBEDDING_MODEL ARG UID ARG GID @@ -67,7 +70,8 @@ ENV ENV=prod \ USE_SLIM_DOCKER=${USE_SLIM} \ USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \ USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \ - USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL} + USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL} \ + USE_AUXILIARY_EMBEDDING_MODEL_DOCKER=${USE_AUXILIARY_EMBEDDING_MODEL} ## Basis URL Config ## ENV OLLAMA_BASE_URL="/ollama" \ @@ -88,6 +92,7 @@ ENV WHISPER_MODEL="base" \ ## RAG Embedding model settings ## ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ + AUXILIARY_EMBEDDING_MODEL="$USE_AUXILIARY_EMBEDDING_MODEL_DOCKER" \ SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" ## Tiktoken model settings ## @@ -136,6 +141,7 @@ RUN pip3 install --no-cache-dir uv && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ + python -c "import os; from 
sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ.get('AUXILIARY_EMBEDDING_MODEL', 'TaylorAI/bge-micro-v2'), device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ else \ @@ -143,6 +149,7 @@ RUN pip3 install --no-cache-dir uv && \ uv pip install --system -r requirements.txt --no-cache-dir && \ if [ "$USE_SLIM" != "true" ]; then \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ + python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ.get('AUXILIARY_EMBEDDING_MODEL', 'TaylorAI/bge-micro-v2'), device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ fi; \ diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 763cad1616..e19644da17 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -20,6 +20,7 @@ from open_webui.env import ( DATA_DIR, DATABASE_URL, + ENABLE_DB_MIGRATIONS, ENV, REDIS_URL, REDIS_KEY_PREFIX, @@ -33,7 +34,7 @@ WEBUI_NAME, log, ) -from open_webui.internal.db import Base, get_db, Session +from open_webui.internal.db import Base, get_db from open_webui.utils.redis import get_redis_connection @@ -1433,6 +1434,14 @@ def feishu_oauth_register(oauth: OAuth): os.environ.get("USER_PERMISSIONS_FEATURES_API_KEYS", "False").lower() == "true" ) +USER_PERMISSIONS_FEATURES_MEMORIES = ( + 
os.environ.get("USER_PERMISSIONS_FEATURES_MEMORIES", "True").lower() == "true" +) + +USER_PERMISSIONS_SETTINGS_INTERFACE = ( + os.environ.get("USER_PERMISSIONS_SETTINGS_INTERFACE", "True").lower() == "true" +) + DEFAULT_USER_PERMISSIONS = { "workspace": { "models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS, @@ -1490,6 +1499,10 @@ def feishu_oauth_register(oauth: OAuth): "web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH, "image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION, "code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER, + "memories": USER_PERMISSIONS_FEATURES_MEMORIES, + }, + "settings": { + "interface": USER_PERMISSIONS_SETTINGS_INTERFACE, }, } @@ -1505,6 +1518,12 @@ def feishu_oauth_register(oauth: OAuth): os.environ.get("ENABLE_FOLDERS", "True").lower() == "true", ) +FOLDER_MAX_FILE_COUNT = PersistentConfig( + "FOLDER_MAX_FILE_COUNT", + "folders.max_file_count", + os.environ.get("FOLDER_MAX_FILE_COUNT", ""), +) + ENABLE_CHANNELS = PersistentConfig( "ENABLE_CHANNELS", "channels.enable", @@ -1517,6 +1536,12 @@ def feishu_oauth_register(oauth: OAuth): os.environ.get("ENABLE_NOTES", "True").lower() == "true", ) +ENABLE_USER_STATUS = PersistentConfig( + "ENABLE_USER_STATUS", + "users.enable_status", + os.environ.get("ENABLE_USER_STATUS", "True").lower() == "true", +) + ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig( "ENABLE_EVALUATION_ARENA_MODELS", "evaluation.arena.enable", @@ -2020,6 +2045,12 @@ class BannerModel(BaseModel): os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true", ) +ENABLE_MEMORIES = PersistentConfig( + "ENABLE_MEMORIES", + "memories.enable", + os.environ.get("ENABLE_MEMORIES", "True").lower() == "true", +) + CODE_INTERPRETER_ENGINE = PersistentConfig( "CODE_INTERPRETER_ENGINE", "code_interpreter.engine", @@ -2290,6 +2321,51 @@ class BannerModel(BaseModel): except Exception: PGVECTOR_IVFFLAT_LISTS = 100 +# openGauss +OPENGAUSS_DB_URL = os.environ.get("OPENGAUSS_DB_URL", DATABASE_URL) + 
+OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int( + os.environ.get("OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH", "1536") +) + +OPENGAUSS_POOL_SIZE = os.environ.get("OPENGAUSS_POOL_SIZE", None) + +if OPENGAUSS_POOL_SIZE != None: + try: + OPENGAUSS_POOL_SIZE = int(OPENGAUSS_POOL_SIZE) + except Exception: + OPENGAUSS_POOL_SIZE = None + +OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get("OPENGAUSS_POOL_MAX_OVERFLOW", 0) + +if OPENGAUSS_POOL_MAX_OVERFLOW == "": + OPENGAUSS_POOL_MAX_OVERFLOW = 0 +else: + try: + OPENGAUSS_POOL_MAX_OVERFLOW = int(OPENGAUSS_POOL_MAX_OVERFLOW) + except Exception: + OPENGAUSS_POOL_MAX_OVERFLOW = 0 + +OPENGAUSS_POOL_TIMEOUT = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30) + +if OPENGAUSS_POOL_TIMEOUT == "": + OPENGAUSS_POOL_TIMEOUT = 30 +else: + try: + OPENGAUSS_POOL_TIMEOUT = int(OPENGAUSS_POOL_TIMEOUT) + except Exception: + OPENGAUSS_POOL_TIMEOUT = 30 + +OPENGAUSS_POOL_RECYCLE = os.environ.get("OPENGAUSS_POOL_RECYCLE", 3600) + +if OPENGAUSS_POOL_RECYCLE == "": + OPENGAUSS_POOL_RECYCLE = 3600 +else: + try: + OPENGAUSS_POOL_RECYCLE = int(OPENGAUSS_POOL_RECYCLE) + except Exception: + OPENGAUSS_POOL_RECYCLE = 3600 + # Pinecone PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) @@ -2753,13 +2829,18 @@ class BannerModel(BaseModel): os.environ.get("RAG_EXTERNAL_RERANKER_TIMEOUT", ""), ) - RAG_TEXT_SPLITTER = PersistentConfig( "RAG_TEXT_SPLITTER", "rag.text_splitter", os.environ.get("RAG_TEXT_SPLITTER", ""), ) +ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = PersistentConfig( + "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER", + "rag.enable_markdown_header_text_splitter", + os.environ.get("ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER", "True").lower() == "true", +) + TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken") TIKTOKEN_ENCODING_NAME = PersistentConfig( "TIKTOKEN_ENCODING_NAME", @@ -2770,6 +2851,13 @@ class BannerModel(BaseModel): CHUNK_SIZE = PersistentConfig( "CHUNK_SIZE", 
"rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000")) ) + +CHUNK_MIN_SIZE_TARGET = PersistentConfig( + "CHUNK_MIN_SIZE_TARGET", + "rag.chunk_min_size_target", + int(os.environ.get("CHUNK_MIN_SIZE_TARGET", "0")), +) + CHUNK_OVERLAP = PersistentConfig( "CHUNK_OVERLAP", "rag.chunk_overlap", @@ -2952,7 +3040,6 @@ class BannerModel(BaseModel): os.getenv("WEB_LOADER_TIMEOUT", ""), ) - ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig( "ENABLE_WEB_LOADER_SSL_VERIFICATION", "rag.web.loader.ssl_verification", @@ -3061,12 +3148,24 @@ class BannerModel(BaseModel): os.getenv("SERPLY_API_KEY", ""), ) +DDGS_BACKEND = PersistentConfig( + "DDGS_BACKEND", + "rag.web.search.ddgs_backend", + os.getenv("DDGS_BACKEND", "auto"), +) + JINA_API_KEY = PersistentConfig( "JINA_API_KEY", "rag.web.search.jina_api_key", os.getenv("JINA_API_KEY", ""), ) +JINA_API_BASE_URL = PersistentConfig( + "JINA_API_BASE_URL", + "rag.web.search.jina_api_base_url", + os.getenv("JINA_API_BASE_URL", ""), +) + SEARCHAPI_API_KEY = PersistentConfig( "SEARCHAPI_API_KEY", "rag.web.search.searchapi_api_key", @@ -3201,6 +3300,12 @@ class BannerModel(BaseModel): os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"), ) +FIRECRAWL_TIMEOUT = PersistentConfig( + "FIRECRAWL_TIMEOUT", + "rag.web.loader.firecrawl_timeout", + os.environ.get("FIRECRAWL_TIMEOUT", ""), +) + EXTERNAL_WEB_SEARCH_URL = PersistentConfig( "EXTERNAL_WEB_SEARCH_URL", "rag.web.search.external_web_search_url", @@ -3558,17 +3663,16 @@ class BannerModel(BaseModel): os.getenv("WHISPER_MODEL", "base"), ) +WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8") WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models") WHISPER_MODEL_AUTO_UPDATE = ( not OFFLINE_MODE and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true" ) -WHISPER_VAD_FILTER = PersistentConfig( - "WHISPER_VAD_FILTER", - "audio.stt.whisper_vad_filter", - os.getenv("WHISPER_VAD_FILTER", "False").lower() == "true", -) 
+WHISPER_VAD_FILTER = os.getenv("WHISPER_VAD_FILTER", "False").lower() == "true" + +WHISPER_MULTILINGUAL = os.getenv("WHISPER_MULTILINGUAL", "False").lower() == "true" WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "").lower() or None diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index ca3f8b28aa..91be977371 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -117,6 +117,8 @@ DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "") INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4())) +ENABLE_DB_MIGRATIONS = os.environ.get("ENABLE_DB_MIGRATIONS", "True").lower() == "true" + # Function to parse each section def parse_section(section): @@ -341,6 +343,11 @@ def parse_section(section): except Exception: DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0 +# When enabled, get_db_context reuses existing sessions; set to False to always create new sessions +DATABASE_ENABLE_SESSION_SHARING = ( + os.environ.get("DATABASE_ENABLE_SESSION_SHARING", "False").lower() == "true" +) + # Enable public visibility of active user count (when disabled, only admins can see it) ENABLE_PUBLIC_ACTIVE_USERS_COUNT = ( os.environ.get("ENABLE_PUBLIC_ACTIVE_USERS_COUNT", "True").lower() == "true" @@ -356,6 +363,8 @@ def parse_section(section): ENABLE_QUERIES_CACHE = os.environ.get("ENABLE_QUERIES_CACHE", "False").lower() == "true" +RAG_SYSTEM_CONTEXT = os.environ.get("RAG_SYSTEM_CONTEXT", "False").lower() == "true" + #################################### # REDIS #################################### @@ -411,6 +420,16 @@ def parse_section(section): os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true" ) +#################################### +# Admin Account Runtime Creation +#################################### + +# Optional env vars for creating an admin account on startup +# Useful for headless/automated deployments +WEBUI_ADMIN_EMAIL = os.environ.get("WEBUI_ADMIN_EMAIL", "") +WEBUI_ADMIN_PASSWORD = 
os.environ.get("WEBUI_ADMIN_PASSWORD", "") +WEBUI_ADMIN_NAME = os.environ.get("WEBUI_ADMIN_NAME", "Admin") + WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) @@ -428,12 +447,14 @@ def parse_section(section): "^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$", ) + try: + PASSWORD_VALIDATION_REGEX_PATTERN = rf"{PASSWORD_VALIDATION_REGEX_PATTERN}" PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN) except Exception as e: log.error(f"Invalid PASSWORD_VALIDATION_REGEX_PATTERN: {e}") PASSWORD_VALIDATION_REGEX_PATTERN = re.compile( - "^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$" + r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$" ) @@ -669,6 +690,8 @@ def parse_section(section): WEBSOCKET_SERVER_PING_INTERVAL = 25 +REQUESTS_VERIFY = os.environ.get("REQUESTS_VERIFY", "True").lower() == "true" + AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") if AIOHTTP_CLIENT_TIMEOUT == "": @@ -766,6 +789,16 @@ def parse_section(section): except Exception: SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None +# Whether to apply sigmoid normalization to CrossEncoder reranking scores. +# When enabled (default), scores are normalized to 0-1 range for proper +# relevance threshold behavior with MS MARCO models. +SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION = ( + os.environ.get( + "SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION", "True" + ).lower() + == "true" +) + #################################### # OFFLINE_MODE #################################### @@ -782,6 +815,11 @@ def parse_section(section): #################################### # AUDIT LOGGING #################################### + + +ENABLE_AUDIT_STDOUT = os.getenv("ENABLE_AUDIT_STDOUT", "False").lower() == "true" +ENABLE_AUDIT_LOGS_FILE = os.getenv("ENABLE_AUDIT_LOGS_FILE", "True").lower() == "true" + # Where to store log file # Defaults to the DATA_DIR/audit.log. 
To set AUDIT_LOGS_FILE_PATH you need to # provide the whole path, like: /app/audit.log diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index a5eecd6605..6050e37fa3 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -14,11 +14,13 @@ DATABASE_POOL_SIZE, DATABASE_POOL_TIMEOUT, DATABASE_ENABLE_SQLITE_WAL, + DATABASE_ENABLE_SESSION_SHARING, + ENABLE_DB_MIGRATIONS, ) from peewee_migrate import Router from sqlalchemy import Dialect, create_engine, MetaData, event, types from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import scoped_session, sessionmaker, Session from sqlalchemy.pool import QueuePool, NullPool from sqlalchemy.sql.type_api import _T from typing_extensions import Self @@ -75,7 +77,8 @@ def handle_peewee_migration(DATABASE_URL): assert db.is_closed(), "Database connection is still open." -handle_peewee_migration(DATABASE_URL) +if ENABLE_DB_MIGRATIONS: + handle_peewee_migration(DATABASE_URL) SQLALCHEMY_DATABASE_URL = DATABASE_URL @@ -146,7 +149,7 @@ def on_connect(dbapi_connection, connection_record): ) metadata_obj = MetaData(schema=DATABASE_SCHEMA) Base = declarative_base(metadata=metadata_obj) -Session = scoped_session(SessionLocal) +ScopedSession = scoped_session(SessionLocal) def get_session(): @@ -158,3 +161,12 @@ def get_session(): get_db = contextmanager(get_session) + + +@contextmanager +def get_db_context(db: Optional[Session] = None): + if isinstance(db, Session) and DATABASE_ENABLE_SESSION_SHARING: + yield db + else: + with get_db() as session: + yield session diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e00f2595fb..a33e765ff3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -96,7 +96,9 @@ get_rf, ) -from open_webui.internal.db import Session, engine + +from sqlalchemy.orm import Session +from open_webui.internal.db import 
ScopedSession, engine, get_session from open_webui.models.functions import Functions from open_webui.models.models import Models @@ -137,6 +139,7 @@ CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, CODE_INTERPRETER_JUPYTER_TIMEOUT, + ENABLE_MEMORIES, # Image AUTOMATIC1111_API_AUTH, AUTOMATIC1111_BASE_URL, @@ -200,6 +203,7 @@ PLAYWRIGHT_TIMEOUT, FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY, + FIRECRAWL_TIMEOUT, WEB_LOADER_ENGINE, WEB_LOADER_CONCURRENT_REQUESTS, WEB_LOADER_TIMEOUT, @@ -240,6 +244,7 @@ RAG_OLLAMA_BASE_URL, RAG_OLLAMA_API_KEY, CHUNK_OVERLAP, + CHUNK_MIN_SIZE_TARGET, CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, DATALAB_MARKER_API_KEY, @@ -270,6 +275,7 @@ MISTRAL_OCR_API_BASE_URL, MISTRAL_OCR_API_KEY, RAG_TEXT_SPLITTER, + ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, TIKTOKEN_ENCODING_NAME, PDF_EXTRACT_IMAGES, YOUTUBE_LOADER_LANGUAGE, @@ -285,6 +291,7 @@ WEB_SEARCH_DOMAIN_FILTER_LIST, OLLAMA_CLOUD_WEB_SEARCH_API_KEY, JINA_API_KEY, + JINA_API_BASE_URL, SEARCHAPI_API_KEY, SEARCHAPI_ENGINE, SERPAPI_API_KEY, @@ -296,6 +303,7 @@ YACY_PASSWORD, SERPER_API_KEY, SERPLY_API_KEY, + DDGS_BACKEND, SERPSTACK_API_KEY, SERPSTACK_HTTPS, TAVILY_API_KEY, @@ -348,8 +356,10 @@ ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, API_KEYS_ALLOWED_ENDPOINTS, ENABLE_FOLDERS, + FOLDER_MAX_FILE_COUNT, ENABLE_CHANNELS, ENABLE_NOTES, + ENABLE_USER_STATUS, ENABLE_COMMUNITY_SHARING, ENABLE_MESSAGE_RATING, ENABLE_USER_WEBHOOKS, @@ -500,6 +510,10 @@ ENABLE_STAR_SESSIONS_MIDDLEWARE, BASE_DIR, ENABLE_PUBLIC_ACTIVE_USERS_COUNT, + # Admin Account Runtime Creation + WEBUI_ADMIN_EMAIL, + WEBUI_ADMIN_PASSWORD, + WEBUI_ADMIN_NAME, ) from open_webui.utils.models import ( @@ -523,6 +537,7 @@ decode_token, get_admin_user, get_verified_user, + create_admin_user, ) from open_webui.utils.plugin import install_tool_and_function_dependencies from open_webui.utils.oauth import ( @@ -601,6 +616,12 @@ async def lifespan(app: FastAPI): if LICENSE_KEY: get_license_data(app, LICENSE_KEY) + # Create admin 
account from env vars if specified and no users exist + if WEBUI_ADMIN_EMAIL and WEBUI_ADMIN_PASSWORD: + if create_admin_user(WEBUI_ADMIN_EMAIL, WEBUI_ADMIN_PASSWORD, WEBUI_ADMIN_NAME): + # Disable signup since we now have an admin + app.state.config.ENABLE_SIGNUP = False + # This should be blocking (sync) so functions are not deactivated on first /get_models calls # when the first user lands on the / route. log.info("Installing external dependencies of functions and tools...") @@ -802,11 +823,13 @@ async def lifespan(app: FastAPI): app.state.config.ENABLE_FOLDERS = ENABLE_FOLDERS +app.state.config.FOLDER_MAX_FILE_COUNT = FOLDER_MAX_FILE_COUNT app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS app.state.config.ENABLE_NOTES = ENABLE_NOTES app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING app.state.config.ENABLE_USER_WEBHOOKS = ENABLE_USER_WEBHOOKS +app.state.config.ENABLE_USER_STATUS = ENABLE_USER_STATUS app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS @@ -913,11 +936,17 @@ async def lifespan(app: FastAPI): app.state.config.MINERU_PARAMS = MINERU_PARAMS app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER = ( + ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER +) + app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_MIN_SIZE_TARGET = CHUNK_MIN_SIZE_TARGET app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP + app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE @@ -981,12 +1010,14 @@ async def lifespan(app: FastAPI): app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPER_API_KEY = SERPER_API_KEY 
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.DDGS_BACKEND = DDGS_BACKEND app.state.config.TAVILY_API_KEY = TAVILY_API_KEY app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE app.state.config.SERPAPI_API_KEY = SERPAPI_API_KEY app.state.config.SERPAPI_ENGINE = SERPAPI_ENGINE app.state.config.JINA_API_KEY = JINA_API_KEY +app.state.config.JINA_API_BASE_URL = JINA_API_BASE_URL app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY app.state.config.EXA_API_KEY = EXA_API_KEY @@ -1005,6 +1036,7 @@ async def lifespan(app: FastAPI): app.state.config.PLAYWRIGHT_TIMEOUT = PLAYWRIGHT_TIMEOUT app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY +app.state.config.FIRECRAWL_TIMEOUT = FIRECRAWL_TIMEOUT app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH app.state.EMBEDDING_FUNCTION = None @@ -1111,6 +1143,7 @@ async def lifespan(app: FastAPI): app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION +app.state.config.ENABLE_MEMORIES = ENABLE_MEMORIES app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL app.state.config.IMAGE_SIZE = IMAGE_SIZE @@ -1163,7 +1196,6 @@ async def lifespan(app: FastAPI): app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY app.state.config.WHISPER_MODEL = WHISPER_MODEL -app.state.config.WHISPER_VAD_FILTER = WHISPER_VAD_FILTER app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY @@ -1383,7 +1415,7 @@ async def dispatch(self, request: Request, call_next): async def commit_session_after_request(request: Request, call_next): response = await call_next(request) # 
log.debug("Commit session after request") - Session.commit() + ScopedSession.commit() return response @@ -1717,7 +1749,7 @@ async def chat_completion( ) # Insert chat files from parent message if any - parent_message = metadata.get("parent_message", {}) + parent_message = metadata.get("parent_message") or {} parent_message_files = parent_message.get("files", []) if parent_message_files: try: @@ -1979,6 +2011,7 @@ async def get_app_config(request: Request): { "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, "enable_folders": app.state.config.ENABLE_FOLDERS, + "folder_max_file_count": app.state.config.FOLDER_MAX_FILE_COUNT, "enable_channels": app.state.config.ENABLE_CHANNELS, "enable_notes": app.state.config.ENABLE_NOTES, "enable_web_search": app.state.config.ENABLE_WEB_SEARCH, @@ -1989,10 +2022,12 @@ async def get_app_config(request: Request): "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING, "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING, "enable_user_webhooks": app.state.config.ENABLE_USER_WEBHOOKS, + "enable_user_status": app.state.config.ENABLE_USER_STATUS, "enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, "enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + "enable_memories": app.state.config.ENABLE_MEMORIES, **( { "enable_onedrive_personal": ENABLE_ONEDRIVE_PERSONAL, @@ -2357,8 +2392,13 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is already taken @app.get("/oauth/{provider}/login/callback") @app.get("/oauth/{provider}/callback") # Legacy endpoint -async def oauth_login_callback(provider: str, request: Request, response: Response): - return await oauth_manager.handle_callback(request, provider, response) +async def oauth_login_callback( + 
provider: str, + request: Request, + response: Response, + db: Session = Depends(get_session), +): + return await oauth_manager.handle_callback(request, provider, response, db=db) @app.get("/manifest.json") @@ -2417,7 +2457,7 @@ async def healthcheck(): @app.get("/health/db") async def healthcheck_with_db(): - Session.execute(text("SELECT 1;")).all() + ScopedSession.execute(text("SELECT 1;")).all() return {"status": True} diff --git a/backend/open_webui/migrate.py b/backend/open_webui/migrate.py index 241d49b13c..99fb39dfdb 100644 --- a/backend/open_webui/migrate.py +++ b/backend/open_webui/migrate.py @@ -1,5 +1,5 @@ -from open_webui.env import OPEN_WEBUI_DIR, log -from open_webui.internal.db import Session +from open_webui.env import OPEN_WEBUI_DIR, log, ENABLE_DB_MIGRATIONS +from open_webui.internal.db import ScopedSession from sqlalchemy import text @@ -33,7 +33,7 @@ def run_extra_migrations(): # do migrations try: # load version from db - current_version = Session.execute( + current_version = ScopedSession.execute( text("SELECT version_num FROM alembic_version") ).scalar_one() @@ -73,5 +73,6 @@ def run_extra_migrations(): if __name__ == "__main__": - run_migrations() - run_extra_migrations() + if ENABLE_DB_MIGRATIONS: + run_migrations() + run_extra_migrations() diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index cb6c057c88..93f17dff11 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -2,7 +2,8 @@ import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import UserModel, UserProfileImageResponse, Users from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text @@ -87,8 +88,9 @@ def insert_new_auth( profile_image_url: str = "/user.png", role: str = "pending", oauth: Optional[dict] = None, 
+ db: Optional[Session] = None, ) -> Optional[UserModel]: - with get_db() as db: + with get_db_context(db) as db: log.info("insert_new_auth") id = str(uuid.uuid4()) @@ -100,7 +102,7 @@ def insert_new_auth( db.add(result) user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth=oauth + id, name, email, profile_image_url, role, oauth=oauth, db=db ) db.commit() @@ -112,16 +114,16 @@ def insert_new_auth( return None def authenticate_user( - self, email: str, verify_password: callable + self, email: str, verify_password: callable, db: Optional[Session] = None ) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) if not user: return None try: - with get_db() as db: + with get_db_context(db) as db: auth = db.query(Auth).filter_by(id=user.id, active=True).first() if auth: if verify_password(auth.password): @@ -133,32 +135,38 @@ def authenticate_user( except Exception: return None - def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + def authenticate_user_by_api_key( + self, api_key: str, db: Optional[Session] = None + ) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") # if no api_key, return None if not api_key: return None try: - user = Users.get_user_by_api_key(api_key) + user = Users.get_user_by_api_key(api_key, db=db) return user if user else None except Exception: return False - def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: + def authenticate_user_by_email( + self, email: str, db: Optional[Session] = None + ) -> Optional[UserModel]: log.info(f"authenticate_user_by_email: {email}") try: - with get_db() as db: + with get_db_context(db) as db: auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: - user = Users.get_user_by_id(auth.id) + user = Users.get_user_by_id(auth.id, db=db) return user except Exception: return None - def 
update_user_password_by_id(self, id: str, new_password: str) -> bool: + def update_user_password_by_id( + self, id: str, new_password: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: result = ( db.query(Auth).filter_by(id=id).update({"password": new_password}) ) @@ -167,20 +175,22 @@ def update_user_password_by_id(self, id: str, new_password: str) -> bool: except Exception: return False - def update_email_by_id(self, id: str, email: str) -> bool: + def update_email_by_id( + self, id: str, email: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: result = db.query(Auth).filter_by(id=id).update({"email": email}) db.commit() return True if result == 1 else False except Exception: return False - def delete_auth_by_id(self, id: str) -> bool: + def delete_auth_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: # Delete User - result = Users.delete_user_by_id(id) + result = Users.delete_user_by_id(id, db=db) if result: db.query(Auth).filter_by(id=id).delete() diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 362222a284..8e70918e1a 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -1,9 +1,11 @@ import json +import secrets import time import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from pydantic import BaseModel, ConfigDict @@ -244,6 +246,11 @@ class CreateChannelForm(ChannelForm): type: Optional[str] = None +class ChannelWebhookForm(BaseModel): + name: str + profile_image_url: Optional[str] = None + + class ChannelTable: def _collect_unique_user_ids( @@ -304,9 +311,9 @@ def _create_membership_models( 
return memberships def insert_new_channel( - self, form_data: CreateChannelForm, user_id: str + self, form_data: CreateChannelForm, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: channel = ChannelModel( **{ **form_data.model_dump(), @@ -337,8 +344,8 @@ def insert_new_channel( db.commit() return channel - def get_channels(self) -> list[ChannelModel]: - with get_db() as db: + def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]: + with get_db_context(db) as db: channels = db.query(Channel).all() return [ChannelModel.model_validate(channel) for channel in channels] @@ -384,10 +391,12 @@ def _has_permission(self, db, query, filter: dict, permission: str = "read"): return query - def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]: - with get_db() as db: + def get_channels_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[ChannelModel]: + with get_db_context(db) as db: user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id) + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) ] membership_channels = ( @@ -421,8 +430,10 @@ def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]: all_channels = membership_channels + standard_channels return [ChannelModel.model_validate(c) for c in all_channels] - def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]: - with get_db() as db: + def get_dm_channel_by_user_ids( + self, user_ids: list[str], db: Optional[Session] = None + ) -> Optional[ChannelModel]: + with get_db_context(db) as db: # Ensure uniqueness in case a list with duplicates is passed unique_user_ids = list(set(user_ids)) @@ -460,8 +471,9 @@ def add_members_to_channel( invited_by: str, user_ids: Optional[list[str]] = None, group_ids: Optional[list[str]] = None, + db: Optional[Session] = None, ) -> list[ChannelMemberModel]: - with get_db() 
as db: + with get_db_context(db) as db: # 1. Collect all user_ids including groups + inviter requested_users = self._collect_unique_user_ids( invited_by, user_ids, group_ids @@ -494,8 +506,9 @@ def remove_members_from_channel( self, channel_id: str, user_ids: list[str], + db: Optional[Session] = None, ) -> int: - with get_db() as db: + with get_db_context(db) as db: result = ( db.query(ChannelMember) .filter( @@ -507,8 +520,10 @@ def remove_members_from_channel( db.commit() return result # number of rows deleted - def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def is_user_channel_manager( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: # Check if the user is the creator of the channel # or has a 'manager' role in ChannelMember channel = db.query(Channel).filter(Channel.id == channel_id).first() @@ -527,9 +542,9 @@ def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool: return membership is not None def join_channel( - self, channel_id: str, user_id: str + self, channel_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelMemberModel]: - with get_db() as db: + with get_db_context(db) as db: # Check if the membership already exists existing_membership = ( db.query(ChannelMember) @@ -565,8 +580,10 @@ def join_channel( db.commit() return channel_member - def leave_channel(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def leave_channel( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -587,9 +604,9 @@ def leave_channel(self, channel_id: str, user_id: str) -> bool: return True def get_member_by_channel_and_user_id( - self, channel_id: str, user_id: str + self, channel_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelMemberModel]: - with get_db() as db: + 
with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -600,8 +617,10 @@ def get_member_by_channel_and_user_id( ) return ChannelMemberModel.model_validate(membership) if membership else None - def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]: - with get_db() as db: + def get_members_by_channel_id( + self, channel_id: str, db: Optional[Session] = None + ) -> list[ChannelMemberModel]: + with get_db_context(db) as db: memberships = ( db.query(ChannelMember) .filter(ChannelMember.channel_id == channel_id) @@ -612,8 +631,14 @@ def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel] for membership in memberships ] - def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool: - with get_db() as db: + def pin_channel( + self, + channel_id: str, + user_id: str, + is_pinned: bool, + db: Optional[Session] = None, + ) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -631,8 +656,10 @@ def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool: db.commit() return True - def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def update_member_last_read_at( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -651,9 +678,13 @@ def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool: return True def update_member_active_status( - self, channel_id: str, user_id: str, is_active: bool + self, + channel_id: str, + user_id: str, + is_active: bool, + db: Optional[Session] = None, ) -> bool: - with get_db() as db: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -671,8 +702,10 @@ def update_member_active_status( db.commit() return True - def is_user_channel_member(self, channel_id: str, user_id: str) -> bool: - with 
get_db() as db: + def is_user_channel_member( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -683,13 +716,20 @@ def is_user_channel_member(self, channel_id: str, user_id: str) -> bool: ) return membership is not None - def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: - with get_db() as db: - channel = db.query(Channel).filter(Channel.id == id).first() - return ChannelModel.model_validate(channel) if channel else None + def get_channel_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChannelModel]: + try: + with get_db_context(db) as db: + channel = db.query(Channel).filter(Channel.id == id).first() + return ChannelModel.model_validate(channel) if channel else None + except Exception: + return None - def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]: - with get_db() as db: + def get_channels_by_file_id( + self, file_id: str, db: Optional[Session] = None + ) -> list[ChannelModel]: + with get_db_context(db) as db: channel_files = ( db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() ) @@ -698,9 +738,9 @@ def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]: return [ChannelModel.model_validate(channel) for channel in channels] def get_channels_by_file_id_and_user_id( - self, file_id: str, user_id: str + self, file_id: str, user_id: str, db: Optional[Session] = None ) -> list[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: # 1. 
Determine which channels have this file channel_file_rows = ( db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() @@ -724,7 +764,9 @@ def get_channels_by_file_id_and_user_id( return [] # Preload user's group membership - user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)] + user_group_ids = [ + g.id for g in Groups.get_groups_by_member_id(user_id, db=db) + ] allowed_channels = [] @@ -761,9 +803,9 @@ def get_channels_by_file_id_and_user_id( return allowed_channels def get_channel_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: # Fetch the channel channel: Channel = ( db.query(Channel) @@ -799,7 +841,7 @@ def get_channel_by_id_and_user_id( # Determine user groups user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id) + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) ] # Apply ACL rules @@ -818,9 +860,9 @@ def get_channel_by_id_and_user_id( ) def update_channel_by_id( - self, id: str, form_data: ChannelForm + self, id: str, form_data: ChannelForm, db: Optional[Session] = None ) -> Optional[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: channel = db.query(Channel).filter(Channel.id == id).first() if not channel: return None @@ -839,9 +881,9 @@ def update_channel_by_id( return ChannelModel.model_validate(channel) if channel else None def add_file_to_channel_by_id( - self, channel_id: str, file_id: str, user_id: str + self, channel_id: str, file_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelFileModel]: - with get_db() as db: + with get_db_context(db) as db: channel_file = ChannelFileModel( **{ "id": str(uuid.uuid4()), @@ -866,10 +908,14 @@ def add_file_to_channel_by_id( return None def set_file_message_id_in_channel_by_id( - self, channel_id: str, file_id: str, message_id: str + self, + 
channel_id: str, + file_id: str, + message_id: str, + db: Optional[Session] = None, ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: channel_file = ( db.query(ChannelFile) .filter_by(channel_id=channel_id, file_id=file_id) @@ -886,9 +932,11 @@ def set_file_message_id_in_channel_by_id( except Exception: return False - def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool: + def remove_file_from_channel_by_id( + self, channel_id: str, file_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ChannelFile).filter_by( channel_id=channel_id, file_id=file_id ).delete() @@ -897,11 +945,115 @@ def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool: except Exception: return False - def delete_channel_by_id(self, id: str): - with get_db() as db: + def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(Channel).filter(Channel.id == id).delete() db.commit() return True + #################### + # Webhook Methods + #################### + + def insert_webhook( + self, + channel_id: str, + user_id: str, + form_data: ChannelWebhookForm, + db: Optional[Session] = None, + ) -> Optional[ChannelWebhookModel]: + with get_db_context(db) as db: + webhook = ChannelWebhookModel( + id=str(uuid.uuid4()), + channel_id=channel_id, + user_id=user_id, + name=form_data.name, + profile_image_url=form_data.profile_image_url, + token=secrets.token_urlsafe(32), + last_used_at=None, + created_at=int(time.time_ns()), + updated_at=int(time.time_ns()), + ) + db.add(ChannelWebhook(**webhook.model_dump())) + db.commit() + return webhook + + def get_webhooks_by_channel_id( + self, channel_id: str, db: Optional[Session] = None + ) -> list[ChannelWebhookModel]: + with get_db_context(db) as db: + webhooks = ( + db.query(ChannelWebhook) + .filter(ChannelWebhook.channel_id == channel_id) + .all() + ) + return 
[ChannelWebhookModel.model_validate(w) for w in webhooks] + + def get_webhook_by_id( + self, webhook_id: str, db: Optional[Session] = None + ) -> Optional[ChannelWebhookModel]: + with get_db_context(db) as db: + webhook = ( + db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() + ) + return ChannelWebhookModel.model_validate(webhook) if webhook else None + + def get_webhook_by_id_and_token( + self, webhook_id: str, token: str, db: Optional[Session] = None + ) -> Optional[ChannelWebhookModel]: + with get_db_context(db) as db: + webhook = ( + db.query(ChannelWebhook) + .filter( + ChannelWebhook.id == webhook_id, + ChannelWebhook.token == token, + ) + .first() + ) + return ChannelWebhookModel.model_validate(webhook) if webhook else None + + def update_webhook_by_id( + self, + webhook_id: str, + form_data: ChannelWebhookForm, + db: Optional[Session] = None, + ) -> Optional[ChannelWebhookModel]: + with get_db_context(db) as db: + webhook = ( + db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() + ) + if not webhook: + return None + webhook.name = form_data.name + webhook.profile_image_url = form_data.profile_image_url + webhook.updated_at = int(time.time_ns()) + db.commit() + return ChannelWebhookModel.model_validate(webhook) + + def update_webhook_last_used_at( + self, webhook_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: + webhook = ( + db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first() + ) + if not webhook: + return False + webhook.last_used_at = int(time.time_ns()) + db.commit() + return True + + def delete_webhook_by_id( + self, webhook_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: + result = ( + db.query(ChannelWebhook) + .filter(ChannelWebhook.id == webhook_id) + .delete() + ) + db.commit() + return result > 0 + Channels = ChannelTable() diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 
d821985a4e..12359eec9f 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -4,7 +4,8 @@ import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.folders import Folders from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db @@ -210,6 +211,48 @@ class ChatUsageStatsListResponse(BaseModel): model_config = ConfigDict(extra="allow") +class MessageStats(BaseModel): + id: str + role: str + model: Optional[str] = None + content_length: int + token_count: Optional[int] = None + timestamp: Optional[int] = None + rating: Optional[int] = None # Derived from message.annotation.rating + tags: Optional[list[str]] = None # Derived from message.annotation.tags + + +class ChatHistoryStats(BaseModel): + messages: dict[str, MessageStats] + currentId: Optional[str] = None + + +class ChatBody(BaseModel): + history: ChatHistoryStats + + +class AggregateChatStats(BaseModel): + average_response_time: float + average_user_message_content_length: float + average_assistant_message_content_length: float + models: dict[str, int] + message_count: int + history_models: dict[str, int] + history_message_count: int + history_user_message_count: int + history_assistant_message_count: int + + +class ChatStatsExport(BaseModel): + id: str + user_id: str + created_at: int + updated_at: int + tags: list[str] = [] + stats: AggregateChatStats + chat: ChatBody + + class ChatTable: def _clean_null_bytes(self, obj): """Recursively remove null bytes from strings in dict/list structures.""" @@ -238,8 +281,10 @@ def _sanitize_chat_row(self, chat_item): return changed - def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - with get_db() as db: + def insert_new_chat( + self, user_id: str, form_data: 
ChatForm, db: Optional[Session] = None + ) -> Optional[ChatModel]: + with get_db_context(db) as db: id = str(uuid.uuid4()) chat = ChatModel( **{ @@ -289,9 +334,12 @@ def _chat_import_form_to_chat_model( return chat def import_chats( - self, user_id: str, chat_import_forms: list[ChatImportForm] + self, + user_id: str, + chat_import_forms: list[ChatImportForm], + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: chats = [] for form_data in chat_import_forms: @@ -302,9 +350,11 @@ def import_chats( db.commit() return [ChatModel.model_validate(chat) for chat in chats] - def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: + def update_chat_by_id( + self, id: str, chat: dict, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat_item = db.get(Chat, id) chat_item.chat = self._clean_null_bytes(chat) chat_item.title = ( @@ -423,31 +473,37 @@ def add_message_status_to_chat_by_id_and_message_id( def add_message_files_by_id_and_message_id( self, id: str, message_id: str, files: list[dict] ) -> list[dict]: - chat = self.get_chat_by_id(id) - if chat is None: - return None + with get_db_context() as db: + chat = self.get_chat_by_id(id, db=db) + if chat is None: + return None - chat = chat.chat - history = chat.get("history", {}) + chat = chat.chat + history = chat.get("history", {}) - message_files = [] + message_files = [] - if message_id in history.get("messages", {}): - message_files = history["messages"][message_id].get("files", []) - message_files = message_files + files - history["messages"][message_id]["files"] = message_files + if message_id in history.get("messages", {}): + message_files = history["messages"][message_id].get("files", []) + message_files = message_files + files + history["messages"][message_id]["files"] = message_files - chat["history"] = history - self.update_chat_by_id(id, chat) - return message_files + 
chat["history"] = history + self.update_chat_by_id(id, chat, db=db) + return message_files - def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_db() as db: + def insert_shared_chat_by_chat_id( + self, chat_id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: + with get_db_context(db) as db: # Get the existing chat to share chat = db.get(Chat, chat_id) + # Check if chat exists + if not chat: + return None # Check if the chat is already shared if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + return self.get_chat_by_id_and_user_id(chat.share_id, "shared", db=db) # Create a new chat with the same data, but with a new ID shared_chat = ChatModel( **{ @@ -476,16 +532,18 @@ def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: db.commit() return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + def update_shared_chat_by_chat_id( + self, chat_id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, chat_id) shared_chat = ( db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first() ) if shared_chat is None: - return self.insert_shared_chat_by_chat_id(chat_id) + return self.insert_shared_chat_by_chat_id(chat_id, db=db) shared_chat.title = chat.title shared_chat.chat = chat.chat @@ -500,9 +558,11 @@ def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: except Exception: return None - def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: + def delete_shared_chat_by_chat_id( + self, chat_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() db.commit() @@ -510,9 +570,11 @@ def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: except 
Exception: return False - def unarchive_all_chats_by_user_id(self, user_id: str) -> bool: + def unarchive_all_chats_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id).update({"archived": False}) db.commit() return True @@ -520,10 +582,10 @@ def unarchive_all_chats_by_user_id(self, user_id: str) -> bool: return False def update_chat_share_id_by_id( - self, id: str, share_id: Optional[str] + self, id: str, share_id: Optional[str], db: Optional[Session] = None ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.share_id = share_id db.commit() @@ -532,9 +594,11 @@ def update_chat_share_id_by_id( except Exception: return None - def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: + def toggle_chat_pinned_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.pinned = not chat.pinned chat.updated_at = int(time.time()) @@ -544,9 +608,11 @@ def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]: except Exception: return None - def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: + def toggle_chat_archive_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.archived = not chat.archived chat.folder_id = None @@ -557,9 +623,11 @@ def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: except Exception: return None - def archive_all_chats_by_user_id(self, user_id: str) -> bool: + def archive_all_chats_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) 
db.commit() return True @@ -572,9 +640,10 @@ def get_archived_chat_list_by_user_id( filter: Optional[dict] = None, skip: int = 0, limit: int = 50, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id, archived=True) if filter: @@ -613,8 +682,9 @@ def get_chat_list_by_user_id( filter: Optional[dict] = None, skip: int = 0, limit: int = 50, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: query = query.filter_by(archived=False) @@ -653,8 +723,9 @@ def get_chat_title_id_list_by_user_id( include_pinned: bool = False, skip: Optional[int] = None, limit: Optional[int] = None, + db: Optional[Session] = None, ) -> list[ChatTitleIdResponse]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_folders: @@ -691,9 +762,13 @@ def get_chat_title_id_list_by_user_id( ] def get_chat_list_by_chat_ids( - self, chat_ids: list[str], skip: int = 0, limit: int = 50 + self, + chat_ids: list[str], + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: all_chats = ( db.query(Chat) .filter(Chat.id.in_(chat_ids)) @@ -703,9 +778,11 @@ def get_chat_list_by_chat_ids( ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id(self, id: str) -> Optional[ChatModel]: + def get_chat_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat_item = db.get(Chat, id) if chat_item is None: return None @@ -718,30 +795,36 @@ def get_chat_by_id(self, id: str) -> Optional[ChatModel]: except Exception: return None - def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: + def 
get_chat_by_share_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: # it is possible that the shared link was deleted. hence, # we check if the chat is still shared by checking if a chat with the share_id exists chat = db.query(Chat).filter_by(share_id=id).first() if chat: - return self.get_chat_by_id(id) + return self.get_chat_by_id(id, db=db) else: return None except Exception: return None - def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: + def get_chat_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() return ChatModel.model_validate(chat) except Exception: return None - def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: - with get_db() as db: + def get_chats( + self, skip: int = 0, limit: int = 50, db: Optional[Session] = None + ) -> list[ChatModel]: + with get_db_context(db) as db: all_chats = ( db.query(Chat) # .limit(limit).offset(skip) @@ -750,14 +833,34 @@ def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]: return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id( - self, user_id: str, skip: Optional[int] = None, limit: Optional[int] = None + self, + user_id: str, + filter: Optional[dict] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + db: Optional[Session] = None, ) -> ChatListResponse: - with get_db() as db: - query = ( - db.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) + with get_db_context(db) as db: + query = db.query(Chat).filter_by(user_id=user_id) + + if filter: + if filter.get("updated_at"): + query = query.filter(Chat.updated_at > filter.get("updated_at")) + + order_by = filter.get("order_by") + direction = 
filter.get("direction") + + if order_by and direction: + if hasattr(Chat, order_by): + if direction.lower() == "asc": + query = query.order_by(getattr(Chat, order_by).asc()) + elif direction.lower() == "desc": + query = query.order_by(getattr(Chat, order_by).desc()) + else: + query = query.order_by(Chat.updated_at.desc()) + + else: + query = query.order_by(Chat.updated_at.desc()) total = query.count() @@ -775,8 +878,10 @@ def get_chats_by_user_id( } ) - def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + def get_pinned_chats_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[ChatModel]: + with get_db_context(db) as db: all_chats = ( db.query(Chat) .filter_by(user_id=user_id, pinned=True, archived=False) @@ -784,8 +889,10 @@ def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]: ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]: - with get_db() as db: + def get_archived_chats_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[ChatModel]: + with get_db_context(db) as db: all_chats = ( db.query(Chat) .filter_by(user_id=user_id, archived=True) @@ -800,6 +907,7 @@ def get_chats_by_user_id_and_search_text( include_archived: bool = False, skip: int = 0, limit: int = 60, + db: Optional[Session] = None, ) -> list[ChatModel]: """ Filters chats based on a search query using Python, allowing pagination using skip and limit. 
@@ -808,7 +916,7 @@ def get_chats_by_user_id_and_search_text( if not search_text: return self.get_chat_list_by_user_id( - user_id, include_archived, filter={}, skip=skip, limit=limit + user_id, include_archived, filter={}, skip=skip, limit=limit, db=db ) search_text_words = search_text.split(" ") @@ -863,7 +971,7 @@ def get_chats_by_user_id_and_search_text( search_text = " ".join(search_text_words) - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter(Chat.user_id == user_id) if is_archived is not None: @@ -1004,9 +1112,14 @@ def get_chats_by_user_id_and_search_text( return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_folder_id_and_user_id( - self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60 + self, + folder_id: str, + user_id: str, + skip: int = 0, + limit: int = 60, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) query = query.filter_by(archived=False) @@ -1022,9 +1135,9 @@ def get_chats_by_folder_id_and_user_id( return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_folder_ids_and_user_id( - self, folder_ids: list[str], user_id: str + self, folder_ids: list[str], user_id: str, db: Optional[Session] = None ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter( Chat.folder_id.in_(folder_ids), Chat.user_id == user_id ) @@ -1037,10 +1150,10 @@ def get_chats_by_folder_ids_and_user_id( return [ChatModel.model_validate(chat) for chat in all_chats] def update_chat_folder_id_by_id_and_user_id( - self, id: str, user_id: str, folder_id: str + self, id: str, user_id: str, folder_id: str, db: Optional[Session] = None ) -> Optional[ChatModel]: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) 
chat.folder_id = folder_id chat.updated_at = int(time.time()) @@ -1051,16 +1164,23 @@ def update_chat_folder_id_by_id_and_user_id( except Exception: return None - def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]: - with get_db() as db: + def get_chat_tags_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> list[TagModel]: + with get_db_context(db) as db: chat = db.get(Chat, id) tags = chat.meta.get("tags", []) return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags] def get_chat_list_by_user_id_and_tag_name( - self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50 + self, + user_id: str, + tag_name: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[ChatModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) tag_id = tag_name.replace(" ", "_").lower() @@ -1089,13 +1209,13 @@ def get_chat_list_by_user_id_and_tag_name( return [ChatModel.model_validate(chat) for chat in all_chats] def add_chat_tag_by_id_and_user_id_and_tag_name( - self, id: str, user_id: str, tag_name: str + self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None ) -> Optional[ChatModel]: tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id) if tag is None: tag = Tags.insert_new_tag(tag_name, user_id) try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) tag_id = tag.id @@ -1111,8 +1231,10 @@ def add_chat_tag_by_id_and_user_id_and_tag_name( except Exception: return None - def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int: - with get_db() as db: # Assuming `get_db()` returns a session object + def count_chats_by_tag_name_and_user_id( + self, tag_name: str, user_id: str, db: Optional[Session] = None + ) -> int: + with get_db_context(db) as db: # Assuming `get_db()` returns a session object query = db.query(Chat).filter_by(user_id=user_id, 
archived=False) # Normalize the tag_name for consistency @@ -1147,8 +1269,10 @@ def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> in return count - def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int: - with get_db() as db: + def count_chats_by_folder_id_and_user_id( + self, folder_id: str, user_id: str, db: Optional[Session] = None + ) -> int: + with get_db_context(db) as db: query = db.query(Chat).filter_by(user_id=user_id) query = query.filter_by(folder_id=folder_id) @@ -1158,10 +1282,10 @@ def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> return count def delete_tag_by_id_and_user_id_and_tag_name( - self, id: str, user_id: str, tag_name: str + self, id: str, user_id: str, tag_name: str, db: Optional[Session] = None ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) tags = chat.meta.get("tags", []) tag_id = tag_name.replace(" ", "_").lower() @@ -1176,9 +1300,11 @@ def delete_tag_by_id_and_user_id_and_tag_name( except Exception: return False - def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: + def delete_all_tags_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: chat = db.get(Chat, id) chat.meta = { **chat.meta, @@ -1190,30 +1316,34 @@ def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool: except Exception: return False - def delete_chat_by_id(self, id: str) -> bool: + def delete_chat_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(id=id).delete() db.commit() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id, db=db) except Exception: return False - def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: + def 
delete_chat_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(id=id, user_id=user_id).delete() db.commit() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id, db=db) except Exception: return False - def delete_chats_by_user_id(self, user_id: str) -> bool: + def delete_chats_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: - self.delete_shared_chats_by_user_id(user_id) + with get_db_context(db) as db: + self.delete_shared_chats_by_user_id(user_id, db=db) db.query(Chat).filter_by(user_id=user_id).delete() db.commit() @@ -1223,10 +1353,10 @@ def delete_chats_by_user_id(self, user_id: str) -> bool: return False def delete_chats_by_user_id_and_folder_id( - self, user_id: str, folder_id: str + self, user_id: str, folder_id: str, db: Optional[Session] = None ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete() db.commit() @@ -1235,10 +1365,14 @@ def delete_chats_by_user_id_and_folder_id( return False def move_chats_by_user_id_and_folder_id( - self, user_id: str, folder_id: str, new_folder_id: Optional[str] + self, + user_id: str, + folder_id: str, + new_folder_id: Optional[str], + db: Optional[Session] = None, ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).update( {"folder_id": new_folder_id} ) @@ -1248,9 +1382,11 @@ def move_chats_by_user_id_and_folder_id( except Exception: return False - def delete_shared_chats_by_user_id(self, user_id: str) -> bool: + def delete_shared_chats_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: chats_by_user = 
db.query(Chat).filter_by(user_id=user_id).all() shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] @@ -1262,7 +1398,12 @@ def delete_shared_chats_by_user_id(self, user_id: str) -> bool: return False def insert_chat_files( - self, chat_id: str, message_id: str, file_ids: list[str], user_id: str + self, + chat_id: str, + message_id: str, + file_ids: list[str], + user_id: str, + db: Optional[Session] = None, ) -> Optional[list[ChatFileModel]]: if not file_ids: return None @@ -1270,7 +1411,7 @@ def insert_chat_files( chat_message_file_ids = [ item.id for item in self.get_chat_files_by_chat_id_and_message_id( - chat_id, message_id + chat_id, message_id, db=db ) ] # Remove duplicates and existing file_ids @@ -1287,7 +1428,7 @@ def insert_chat_files( return None try: - with get_db() as db: + with get_db_context(db) as db: now = int(time.time()) chat_files = [ @@ -1315,9 +1456,9 @@ def insert_chat_files( return None def get_chat_files_by_chat_id_and_message_id( - self, chat_id: str, message_id: str + self, chat_id: str, message_id: str, db: Optional[Session] = None ) -> list[ChatFileModel]: - with get_db() as db: + with get_db_context(db) as db: all_chat_files = ( db.query(ChatFile) .filter_by(chat_id=chat_id, message_id=message_id) @@ -1328,17 +1469,21 @@ def get_chat_files_by_chat_id_and_message_id( ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files ] - def delete_chat_file(self, chat_id: str, file_id: str) -> bool: + def delete_chat_file( + self, chat_id: str, file_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ChatFile).filter_by(chat_id=chat_id, file_id=file_id).delete() db.commit() return True except Exception: return False - def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]: - with get_db() as db: + def get_shared_chats_by_file_id( + self, file_id: str, db: Optional[Session] = None + ) -> list[ChatModel]: + with get_db_context(db) as db: 
# Join Chat and ChatFile tables to get shared chats associated with the file_id all_chats = ( db.query(Chat) diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 39e22ff2d9..048c10f85c 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -3,7 +3,8 @@ import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import User from pydantic import BaseModel, ConfigDict @@ -67,6 +68,13 @@ class FeedbackIdResponse(BaseModel): updated_at: int +class LeaderboardFeedbackData(BaseModel): + """Minimal feedback data for leaderboard computation (excludes snapshot/meta).""" + + id: str + data: Optional[dict] = None + + class RatingData(BaseModel): rating: Optional[str | int] = None model_id: Optional[str] = None @@ -119,11 +127,22 @@ class FeedbackListResponse(BaseModel): total: int +class ModelHistoryEntry(BaseModel): + date: str + won: int + lost: int + + +class ModelHistoryResponse(BaseModel): + model_id: str + history: list[ModelHistoryEntry] + + class FeedbackTable: def insert_new_feedback( - self, user_id: str, form_data: FeedbackForm + self, user_id: str, form_data: FeedbackForm, db: Optional[Session] = None ) -> Optional[FeedbackModel]: - with get_db() as db: + with get_db_context(db) as db: id = str(uuid.uuid4()) feedback = FeedbackModel( **{ @@ -148,9 +167,11 @@ def insert_new_feedback( log.exception(f"Error creating a new feedback: {e}") return None - def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]: + def get_feedback_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[FeedbackModel]: try: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() if not feedback: return None @@ -159,10 +180,10 @@ def get_feedback_by_id(self, 
id: str) -> Optional[FeedbackModel]: return None def get_feedback_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[FeedbackModel]: try: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: return None @@ -171,9 +192,13 @@ def get_feedback_by_id_and_user_id( return None def get_feedback_items( - self, filter: dict = {}, skip: int = 0, limit: int = 30 + self, + filter: dict = {}, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, ) -> FeedbackListResponse: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Feedback, User).join(User, Feedback.user_id == User.id) if filter: @@ -234,8 +259,8 @@ def get_feedback_items( return FeedbackListResponse(items=feedbacks, total=total) - def get_all_feedbacks(self) -> list[FeedbackModel]: - with get_db() as db: + def get_all_feedbacks(self, db: Optional[Session] = None) -> list[FeedbackModel]: + with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) for feedback in db.query(Feedback) @@ -243,8 +268,110 @@ def get_all_feedbacks(self) -> list[FeedbackModel]: .all() ] - def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: - with get_db() as db: + def get_all_feedback_ids( + self, db: Optional[Session] = None + ) -> list[FeedbackIdResponse]: + with get_db_context(db) as db: + return [ + FeedbackIdResponse( + id=row.id, + user_id=row.user_id, + created_at=row.created_at, + updated_at=row.updated_at, + ) + for row in db.query( + Feedback.id, + Feedback.user_id, + Feedback.created_at, + Feedback.updated_at, + ) + .order_by(Feedback.updated_at.desc()) + .all() + ] + + def get_feedbacks_for_leaderboard( + self, db: Optional[Session] = None + ) -> list[LeaderboardFeedbackData]: + """Fetch only id and data for leaderboard computation (excludes snapshot/meta).""" + with get_db_context(db) as db: + return 
[ + LeaderboardFeedbackData(id=row.id, data=row.data) + for row in db.query(Feedback.id, Feedback.data).all() + ] + + def get_model_evaluation_history( + self, model_id: str, days: int = 30, db: Optional[Session] = None + ) -> list[ModelHistoryEntry]: + """ + Get daily wins/losses for a specific model over the past N days. + If days=0, returns all time data starting from first feedback. + Returns: [{"date": "2026-01-08", "won": 5, "lost": 2}, ...] + """ + from datetime import datetime, timedelta + from collections import defaultdict + + with get_db_context(db) as db: + if days == 0: + # All time - no cutoff + rows = db.query(Feedback.created_at, Feedback.data).all() + else: + cutoff = int(time.time()) - (days * 86400) + rows = ( + db.query(Feedback.created_at, Feedback.data) + .filter(Feedback.created_at >= cutoff) + .all() + ) + + daily_counts = defaultdict(lambda: {"won": 0, "lost": 0}) + first_date = None + + for created_at, data in rows: + if not data: + continue + if data.get("model_id") != model_id: + continue + + rating_str = str(data.get("rating", "")) + if rating_str not in ("1", "-1"): + continue + + date_str = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") + if rating_str == "1": + daily_counts[date_str]["won"] += 1 + else: + daily_counts[date_str]["lost"] += 1 + + # Track first date for this model + if first_date is None or date_str < first_date: + first_date = date_str + + # Generate date range + result = [] + today = datetime.now().date() + + if days == 0 and first_date: + # All time: start from first feedback date + start_date = datetime.strptime(first_date, "%Y-%m-%d").date() + num_days = (today - start_date).days + 1 + else: + # Fixed range + num_days = days + start_date = today - timedelta(days=days - 1) + + for i in range(num_days): + d = start_date + timedelta(days=i) + date_str = d.strftime("%Y-%m-%d") + counts = daily_counts.get(date_str, {"won": 0, "lost": 0}) + result.append( + ModelHistoryEntry(date=date_str, won=counts["won"], 
lost=counts["lost"]) + ) + + return result + + def get_feedbacks_by_type( + self, type: str, db: Optional[Session] = None + ) -> list[FeedbackModel]: + with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) for feedback in db.query(Feedback) @@ -253,8 +380,10 @@ def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: .all() ] - def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: - with get_db() as db: + def get_feedbacks_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[FeedbackModel]: + with get_db_context(db) as db: return [ FeedbackModel.model_validate(feedback) for feedback in db.query(Feedback) @@ -264,9 +393,9 @@ def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]: ] def update_feedback_by_id( - self, id: str, form_data: FeedbackForm + self, id: str, form_data: FeedbackForm, db: Optional[Session] = None ) -> Optional[FeedbackModel]: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() if not feedback: return None @@ -284,9 +413,13 @@ def update_feedback_by_id( return FeedbackModel.model_validate(feedback) def update_feedback_by_id_and_user_id( - self, id: str, user_id: str, form_data: FeedbackForm + self, + id: str, + user_id: str, + form_data: FeedbackForm, + db: Optional[Session] = None, ) -> Optional[FeedbackModel]: - with get_db() as db: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: return None @@ -303,8 +436,8 @@ def update_feedback_by_id_and_user_id( db.commit() return FeedbackModel.model_validate(feedback) - def delete_feedback_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_feedback_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id).first() if not feedback: return False @@ -312,8 +445,10 @@ def 
delete_feedback_by_id(self, id: str) -> bool: db.commit() return True - def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: - with get_db() as db: + def delete_feedback_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() if not feedback: return False @@ -321,8 +456,10 @@ def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: db.commit() return True - def delete_feedbacks_by_user_id(self, user_id: str) -> bool: - with get_db() as db: + def delete_feedbacks_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() if not feedbacks: return False @@ -331,8 +468,8 @@ def delete_feedbacks_by_user_id(self, user_id: str) -> bool: db.commit() return True - def delete_all_feedbacks(self) -> bool: - with get_db() as db: + def delete_all_feedbacks(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: feedbacks = db.query(Feedback).all() if not feedbacks: return False diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 9d4e8fb054..4097ae08e1 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -2,7 +2,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -108,8 +109,10 @@ class FileListResponse(BaseModel): class FilesTable: - def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - with get_db() as db: + def insert_new_file( + self, user_id: str, form_data: FileForm, db: 
Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: file = FileModel( **{ **form_data.model_dump(), @@ -132,16 +135,23 @@ def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileMod log.exception(f"Error inserting a new file: {e}") return None - def get_file_by_id(self, id: str) -> Optional[FileModel]: - with get_db() as db: - try: - file = db.get(File, id) - return FileModel.model_validate(file) - except Exception: - return None + def get_file_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[FileModel]: + try: + with get_db_context(db) as db: + try: + file = db.get(File, id) + return FileModel.model_validate(file) + except Exception: + return None + except Exception: + return None - def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: - with get_db() as db: + def get_file_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id, user_id=user_id).first() if file: @@ -151,8 +161,10 @@ def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileMode except Exception: return None - def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: - with get_db() as db: + def get_file_metadata_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[FileMetadataResponse]: + with get_db_context(db) as db: try: file = db.get(File, id) return FileMetadataResponse( @@ -165,12 +177,14 @@ def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: except Exception: return None - def get_files(self) -> list[FileModel]: - with get_db() as db: + def get_files(self, db: Optional[Session] = None) -> list[FileModel]: + with get_db_context(db) as db: return [FileModel.model_validate(file) for file in db.query(File).all()] - def check_access_by_user_id(self, id, user_id, permission="write") -> 
bool: - file = self.get_file_by_id(id) + def check_access_by_user_id( + self, id, user_id, permission="write", db: Optional[Session] = None + ) -> bool: + file = self.get_file_by_id(id, db=db) if not file: return False if file.user_id == user_id: @@ -178,8 +192,10 @@ def check_access_by_user_id(self, id, user_id, permission="write") -> bool: # Implement additional access control logic here as needed return False - def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: - with get_db() as db: + def get_files_by_ids( + self, ids: list[str], db: Optional[Session] = None + ) -> list[FileModel]: + with get_db_context(db) as db: return [ FileModel.model_validate(file) for file in db.query(File) @@ -188,8 +204,10 @@ def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: .all() ] - def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: - with get_db() as db: + def get_file_metadatas_by_ids( + self, ids: list[str], db: Optional[Session] = None + ) -> list[FileMetadataResponse]: + with get_db_context(db) as db: return [ FileMetadataResponse( id=file.id, @@ -206,17 +224,81 @@ def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse .all() ] - def get_files_by_user_id(self, user_id: str) -> list[FileModel]: - with get_db() as db: + def get_files_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[FileModel]: + with get_db_context(db) as db: return [ FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all() ] + @staticmethod + def _glob_to_like_pattern(glob: str) -> str: + """ + Convert a glob/fnmatch pattern to a SQL LIKE pattern. + + Escapes SQL special characters and converts glob wildcards: + - `*` becomes `%` (match any sequence of characters) + - `?` becomes `_` (match exactly one character) + + Args: + glob: A glob pattern (e.g., "*.txt", "file?.doc") + + Returns: + A SQL LIKE compatible pattern with proper escaping. 
+ """ + # Escape SQL special characters first, then convert glob wildcards + pattern = glob.replace("\\", "\\\\") + pattern = pattern.replace("%", "\\%") + pattern = pattern.replace("_", "\\_") + pattern = pattern.replace("*", "%") + pattern = pattern.replace("?", "_") + return pattern + + def search_files( + self, + user_id: Optional[str] = None, + filename: str = "*", + skip: int = 0, + limit: int = 100, + db: Optional[Session] = None, + ) -> list[FileModel]: + """ + Search files with glob pattern matching, optional user filter, and pagination. + + Args: + user_id: Filter by user ID. If None, returns files for all users. + filename: Glob pattern to match filenames (e.g., "*.txt"). Default "*" matches all. + skip: Number of results to skip for pagination. + limit: Maximum number of results to return. + db: Optional database session. + + Returns: + List of matching FileModel objects, ordered by updated_at descending. + """ + with get_db_context(db) as db: + query = db.query(File) + + if user_id: + query = query.filter_by(user_id=user_id) + + pattern = self._glob_to_like_pattern(filename) + if pattern != "%": + query = query.filter(File.filename.ilike(pattern, escape="\\")) + + return [ + FileModel.model_validate(file) + for file in query.order_by(File.updated_at.desc()) + .offset(skip) + .limit(limit) + .all() + ] + def update_file_by_id( - self, id: str, form_data: FileUpdateForm + self, id: str, form_data: FileUpdateForm, db: Optional[Session] = None ) -> Optional[FileModel]: - with get_db() as db: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() @@ -236,8 +318,10 @@ def update_file_by_id( log.exception(f"Error updating file completely by id: {e}") return None - def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: - with get_db() as db: + def update_file_hash_by_id( + self, id: str, hash: Optional[str], db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = 
db.query(File).filter_by(id=id).first() file.hash = hash @@ -248,8 +332,10 @@ def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: except Exception: return None - def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: - with get_db() as db: + def update_file_data_by_id( + self, id: str, data: dict, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() file.data = {**(file.data if file.data else {}), **data} @@ -260,8 +346,10 @@ def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: return None - def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]: - with get_db() as db: + def update_file_metadata_by_id( + self, id: str, meta: dict, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() file.meta = {**(file.meta if file.meta else {}), **meta} @@ -271,8 +359,10 @@ def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel] except Exception: return None - def delete_file_by_id(self, id: str) -> bool: - with get_db() as db: + return False + + def delete_file_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(File).filter_by(id=id).delete() db.commit() @@ -281,8 +371,8 @@ def delete_file_by_id(self, id: str) -> bool: except Exception: return False - def delete_all_files(self) -> bool: - with get_db() as db: + def delete_all_files(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(File).delete() db.commit() diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 0043dd3644..3455208944 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -7,8 +7,9 @@ from pydantic import BaseModel, ConfigDict from 
sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func +from sqlalchemy.orm import Session -from open_webui.internal.db import Base, get_db +from open_webui.internal.db import Base, JSONField, get_db, get_db_context log = logging.getLogger(__name__) @@ -83,9 +84,13 @@ class FolderUpdateForm(BaseModel): class FolderTable: def insert_new_folder( - self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None + self, + user_id: str, + form_data: FolderForm, + parent_id: Optional[str] = None, + db: Optional[Session] = None, ) -> Optional[FolderModel]: - with get_db() as db: + with get_db_context(db) as db: id = str(uuid.uuid4()) folder = FolderModel( **{ @@ -111,10 +116,10 @@ def insert_new_folder( return None def get_folder_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -125,15 +130,15 @@ def get_folder_by_id_and_user_id( return None def get_children_folders_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[list[FolderModel]]: try: - with get_db() as db: + with get_db_context(db) as db: folders = [] def get_children(folder): children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id + folder.id, user_id, db=db ) for child in children: get_children(child) @@ -148,18 +153,24 @@ def get_children(folder): except Exception: return None - def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]: - with get_db() as db: + def get_folders_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[FolderModel]: + with get_db_context(db) as db: return [ FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all() ] def get_folder_by_parent_id_and_user_id_and_name( - self, 
parent_id: Optional[str], user_id: str, name: str + self, + parent_id: Optional[str], + user_id: str, + name: str, + db: Optional[Session] = None, ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: # Check if folder exists folder = ( db.query(Folder) @@ -177,9 +188,9 @@ def get_folder_by_parent_id_and_user_id_and_name( return None def get_folders_by_parent_id_and_user_id( - self, parent_id: Optional[str], user_id: str + self, parent_id: Optional[str], user_id: str, db: Optional[Session] = None ) -> list[FolderModel]: - with get_db() as db: + with get_db_context(db) as db: return [ FolderModel.model_validate(folder) for folder in db.query(Folder) @@ -192,9 +203,10 @@ def update_folder_parent_id_by_id_and_user_id( id: str, user_id: str, parent_id: str, + db: Optional[Session] = None, ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -211,10 +223,14 @@ def update_folder_parent_id_by_id_and_user_id( return def update_folder_by_id_and_user_id( - self, id: str, user_id: str, form_data: FolderUpdateForm + self, + id: str, + user_id: str, + form_data: FolderUpdateForm, + db: Optional[Session] = None, ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -257,10 +273,10 @@ def update_folder_by_id_and_user_id( return def update_folder_is_expanded_by_id_and_user_id( - self, id: str, user_id: str, is_expanded: bool + self, id: str, user_id: str, is_expanded: bool, db: Optional[Session] = None ) -> Optional[FolderModel]: try: - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: @@ -276,10 +292,12 @@ def update_folder_is_expanded_by_id_and_user_id( log.error(f"update_folder: {e}") return - def 
delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: + def delete_folder_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> list[str]: try: folder_ids = [] - with get_db() as db: + with get_db_context(db) as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: return folder_ids @@ -289,7 +307,7 @@ def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: # Delete all children folders def delete_children(folder): folder_children = self.get_folders_by_parent_id_and_user_id( - folder.id, user_id + folder.id, user_id, db=db ) for folder_child in folder_children: @@ -314,7 +332,7 @@ def normalize_folder_name(self, name: str) -> str: return name.strip().lower() def search_folders_by_names( - self, user_id: str, queries: list[str] + self, user_id: str, queries: list[str], db: Optional[Session] = None ) -> list[FolderModel]: """ Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive. @@ -324,7 +342,7 @@ def search_folders_by_names( return [] results = {} - with get_db() as db: + with get_db_context(db) as db: folders = db.query(Folder).filter_by(user_id=user_id).all() for folder in folders: if self.normalize_folder_name(folder.name) in normalized_queries: @@ -332,7 +350,7 @@ def search_folders_by_names( # get children folders children = self.get_children_folders_by_id_and_user_id( - folder.id, user_id + folder.id, user_id, db=db ) for child in children: results[child.id] = child @@ -345,14 +363,14 @@ def search_folders_by_names( return results def search_folders_by_name_contains( - self, user_id: str, query: str + self, user_id: str, query: str, db: Optional[Session] = None ) -> list[FolderModel]: """ Partial match: normalized name contains (as substring) the normalized query. 
""" normalized_query = self.normalize_folder_name(query) results = [] - with get_db() as db: + with get_db_context(db) as db: folders = db.query(Folder).filter_by(user_id=user_id).all() for folder in folders: norm_name = self.normalize_folder_name(folder.name) diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 19ad985d0c..8e23bac093 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -2,7 +2,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import Users, UserModel from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index @@ -103,7 +104,11 @@ class FunctionValves(BaseModel): class FunctionsTable: def insert_new_function( - self, user_id: str, type: str, form_data: FunctionForm + self, + user_id: str, + type: str, + form_data: FunctionForm, + db: Optional[Session] = None, ) -> Optional[FunctionModel]: function = FunctionModel( **{ @@ -116,7 +121,7 @@ def insert_new_function( ) try: - with get_db() as db: + with get_db_context(db) as db: result = Function(**function.model_dump()) db.add(result) db.commit() @@ -130,11 +135,14 @@ def insert_new_function( return None def sync_functions( - self, user_id: str, functions: list[FunctionWithValvesModel] + self, + user_id: str, + functions: list[FunctionWithValvesModel], + db: Optional[Session] = None, ) -> list[FunctionWithValvesModel]: # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. 
try: - with get_db() as db: + with get_db_context(db) as db: # Get existing functions existing_functions = db.query(Function).all() existing_ids = {func.id for func in existing_functions} @@ -177,18 +185,20 @@ def sync_functions( log.exception(f"Error syncing functions for user {user_id}: {e}") return [] - def get_function_by_id(self, id: str) -> Optional[FunctionModel]: + def get_function_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[FunctionModel]: try: - with get_db() as db: + with get_db_context(db) as db: function = db.get(Function, id) return FunctionModel.model_validate(function) except Exception: return None def get_functions( - self, active_only=False, include_valves=False + self, active_only=False, include_valves=False, db: Optional[Session] = None ) -> list[FunctionModel | FunctionWithValvesModel]: - with get_db() as db: + with get_db_context(db) as db: if active_only: functions = db.query(Function).filter_by(is_active=True).all() @@ -205,12 +215,14 @@ def get_functions( FunctionModel.model_validate(function) for function in functions ] - def get_function_list(self) -> list[FunctionUserResponse]: - with get_db() as db: + def get_function_list( + self, db: Optional[Session] = None + ) -> list[FunctionUserResponse]: + with get_db_context(db) as db: functions = db.query(Function).order_by(Function.updated_at.desc()).all() user_ids = list(set(func.user_id for func in functions)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} return [ @@ -228,9 +240,9 @@ def get_function_list(self) -> list[FunctionUserResponse]: ] def get_functions_by_type( - self, type: str, active_only=False + self, type: str, active_only=False, db: Optional[Session] = None ) -> list[FunctionModel]: - with get_db() as db: + with get_db_context(db) as db: if active_only: return [ FunctionModel.model_validate(function) @@ -244,8 
+256,10 @@ def get_functions_by_type( for function in db.query(Function).filter_by(type=type).all() ] - def get_global_filter_functions(self) -> list[FunctionModel]: - with get_db() as db: + def get_global_filter_functions( + self, db: Optional[Session] = None + ) -> list[FunctionModel]: + with get_db_context(db) as db: return [ FunctionModel.model_validate(function) for function in db.query(Function) @@ -253,8 +267,10 @@ def get_global_filter_functions(self) -> list[FunctionModel]: .all() ] - def get_global_action_functions(self) -> list[FunctionModel]: - with get_db() as db: + def get_global_action_functions( + self, db: Optional[Session] = None + ) -> list[FunctionModel]: + with get_db_context(db) as db: return [ FunctionModel.model_validate(function) for function in db.query(Function) @@ -262,8 +278,10 @@ def get_global_action_functions(self) -> list[FunctionModel]: .all() ] - def get_function_valves_by_id(self, id: str) -> Optional[dict]: - with get_db() as db: + def get_function_valves_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[dict]: + with get_db_context(db) as db: try: function = db.get(Function, id) return function.valves if function.valves else {} @@ -272,23 +290,23 @@ def get_function_valves_by_id(self, id: str) -> Optional[dict]: return None def update_function_valves_by_id( - self, id: str, valves: dict + self, id: str, valves: dict, db: Optional[Session] = None ) -> Optional[FunctionValves]: - with get_db() as db: + with get_db_context(db) as db: try: function = db.get(Function, id) function.valves = valves function.updated_at = int(time.time()) db.commit() db.refresh(function) - return self.get_function_by_id(id) + return self.get_function_by_id(id, db=db) except Exception: return None def update_function_metadata_by_id( - self, id: str, metadata: dict + self, id: str, metadata: dict, db: Optional[Session] = None ) -> Optional[FunctionModel]: - with get_db() as db: + with get_db_context(db) as db: try: function = 
db.get(Function, id) @@ -301,7 +319,7 @@ def update_function_metadata_by_id( function.updated_at = int(time.time()) db.commit() db.refresh(function) - return self.get_function_by_id(id) + return self.get_function_by_id(id, db=db) else: return None except Exception as e: @@ -309,10 +327,10 @@ def update_function_metadata_by_id( return None def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings @@ -327,10 +345,10 @@ def get_user_valves_by_id_and_user_id( return None def update_user_valves_by_id_and_user_id( - self, id: str, user_id: str, valves: dict + self, id: str, user_id: str, valves: dict, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings @@ -342,7 +360,7 @@ def update_user_valves_by_id_and_user_id( user_settings["functions"]["valves"][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}) + Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) return user_settings["functions"]["valves"][id] except Exception as e: @@ -351,8 +369,10 @@ def update_user_valves_by_id_and_user_id( ) return None - def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - with get_db() as db: + def update_function_by_id( + self, id: str, updated: dict, db: Optional[Session] = None + ) -> Optional[FunctionModel]: + with get_db_context(db) as db: try: db.query(Function).filter_by(id=id).update( { @@ -361,12 +381,12 @@ def update_function_by_id(self, id: str, updated: dict) 
-> Optional[FunctionMode } ) db.commit() - return self.get_function_by_id(id) + return self.get_function_by_id(id, db=db) except Exception: return None - def deactivate_all_functions(self) -> Optional[bool]: - with get_db() as db: + def deactivate_all_functions(self, db: Optional[Session] = None) -> Optional[bool]: + with get_db_context(db) as db: try: db.query(Function).update( { @@ -379,8 +399,8 @@ def deactivate_all_functions(self) -> Optional[bool]: except Exception: return None - def delete_function_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_function_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Function).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index da94287111..ae557f4daf 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -4,7 +4,8 @@ from typing import Optional import uuid -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.files import FileMetadataResponse @@ -120,9 +121,9 @@ class GroupListResponse(BaseModel): class GroupTable: def insert_new_group( - self, user_id: str, form_data: GroupForm + self, user_id: str, form_data: GroupForm, db: Optional[Session] = None ) -> Optional[GroupModel]: - with get_db() as db: + with get_db_context(db) as db: group = GroupModel( **{ **form_data.model_dump(exclude_none=True), @@ -146,54 +147,84 @@ def insert_new_group( except Exception: return None - def get_all_groups(self) -> list[GroupModel]: - with get_db() as db: + def get_all_groups(self, db: Optional[Session] = None) -> list[GroupModel]: + with get_db_context(db) as db: groups = db.query(Group).order_by(Group.updated_at.desc()).all() return [GroupModel.model_validate(group) for group in groups] - def get_groups(self, filter) -> 
list[GroupResponse]: - with get_db() as db: + def get_groups(self, filter, db: Optional[Session] = None) -> list[GroupResponse]: + with get_db_context(db) as db: query = db.query(Group) if filter: if "query" in filter: query = query.filter(Group.name.ilike(f"%{filter['query']}%")) - if "member_id" in filter: - query = query.join( - GroupMember, GroupMember.group_id == Group.id - ).filter(GroupMember.user_id == filter["member_id"]) + # When share filter is present, member check is handled in the share logic if "share" in filter: share_value = filter["share"] - json_share = Group.data["config"]["share"].as_boolean() + member_id = filter.get("member_id") + json_share = Group.data["config"]["share"] + json_share_bool = json_share.as_boolean() + json_share_str = json_share.as_string() if share_value: - query = query.filter( - or_( - Group.data.is_(None), - json_share.is_(None), - json_share == True, - ) + # Groups open to anyone: data is null, share is null, or share is true + anyone_can_share = or_( + Group.data.is_(None), + json_share_bool.is_(None), + json_share_bool == True, ) + + if member_id: + # Also include member-only groups where user is a member + member_groups_subq = ( + db.query(GroupMember.group_id) + .filter(GroupMember.user_id == member_id) + .subquery() + ) + members_only_and_is_member = and_( + json_share_str == "members", + Group.id.in_(member_groups_subq), + ) + query = query.filter( + or_(anyone_can_share, members_only_and_is_member) + ) + else: + query = query.filter(anyone_can_share) else: query = query.filter( - and_(Group.data.isnot(None), json_share == False) + and_(Group.data.isnot(None), json_share_bool == False) ) + + else: + # Only apply member_id filter when share filter is NOT present + if "member_id" in filter: + query = query.join( + GroupMember, GroupMember.group_id == Group.id + ).filter(GroupMember.user_id == filter["member_id"]) + groups = query.order_by(Group.updated_at.desc()).all() return [ GroupResponse.model_validate( { 
**GroupModel.model_validate(group).model_dump(), - "member_count": self.get_group_member_count_by_id(group.id), + "member_count": self.get_group_member_count_by_id( + group.id, db=db + ), } ) for group in groups ] def search_groups( - self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30 + self, + filter: Optional[dict] = None, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, ) -> GroupListResponse: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Group) if filter: @@ -220,15 +251,17 @@ def search_groups( "items": [ GroupResponse.model_validate( **GroupModel.model_validate(group).model_dump(), - member_count=self.get_group_member_count_by_id(group.id), + member_count=self.get_group_member_count_by_id(group.id, db=db), ) for group in groups ], "total": total, } - def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: - with get_db() as db: + def get_groups_by_member_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[GroupModel]: + with get_db_context(db) as db: return [ GroupModel.model_validate(group) for group in db.query(Group) @@ -238,16 +271,41 @@ def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: .all() ] - def get_group_by_id(self, id: str) -> Optional[GroupModel]: + def get_groups_by_member_ids( + self, user_ids: list[str], db: Optional[Session] = None + ) -> dict[str, list[GroupModel]]: + """Fetch groups for multiple users in a single query to avoid N+1.""" + with get_db_context(db) as db: + # Query GroupMember joined with Group, filtering by user_ids + results = ( + db.query(GroupMember.user_id, Group) + .join(Group, Group.id == GroupMember.group_id) + .filter(GroupMember.user_id.in_(user_ids)) + .order_by(Group.updated_at.desc()) + .all() + ) + + # Group groups by user_id + user_groups: dict[str, list[GroupModel]] = {uid: [] for uid in user_ids} + for user_id, group in results: + user_groups[user_id].append(GroupModel.model_validate(group)) + + return 
user_groups + + def get_group_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() return GroupModel.model_validate(group) if group else None except Exception: return None - def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]: - with get_db() as db: + def get_group_user_ids_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[list[str]]: + with get_db_context(db) as db: members = ( db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all() ) @@ -257,8 +315,10 @@ def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]: return [m[0] for m in members] - def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]: - with get_db() as db: + def get_group_user_ids_by_ids( + self, group_ids: list[str], db: Optional[Session] = None + ) -> dict[str, list[str]]: + with get_db_context(db) as db: members = ( db.query(GroupMember.group_id, GroupMember.user_id) .filter(GroupMember.group_id.in_(group_ids)) @@ -274,8 +334,10 @@ def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str] return group_user_ids - def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None: - with get_db() as db: + def set_group_user_ids_by_id( + self, group_id: str, user_ids: list[str], db: Optional[Session] = None + ) -> None: + with get_db_context(db) as db: # Delete existing members db.query(GroupMember).filter(GroupMember.group_id == group_id).delete() @@ -295,8 +357,10 @@ def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None: db.add_all(new_members) db.commit() - def get_group_member_count_by_id(self, id: str) -> int: - with get_db() as db: + def get_group_member_count_by_id( + self, id: str, db: Optional[Session] = None + ) -> int: + with get_db_context(db) as db: count = ( db.query(func.count(GroupMember.user_id)) 
.filter(GroupMember.group_id == id) @@ -305,10 +369,14 @@ def get_group_member_count_by_id(self, id: str) -> int: return count if count else 0 def update_group_by_id( - self, id: str, form_data: GroupUpdateForm, overwrite: bool = False + self, + id: str, + form_data: GroupUpdateForm, + overwrite: bool = False, + db: Optional[Session] = None, ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Group).filter_by(id=id).update( { **form_data.model_dump(exclude_none=True), @@ -316,22 +384,22 @@ def update_group_by_id( } ) db.commit() - return self.get_group_by_id(id=id) + return self.get_group_by_id(id=id, db=db) except Exception as e: log.exception(e) return None - def delete_group_by_id(self, id: str) -> bool: + def delete_group_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Group).filter_by(id=id).delete() db.commit() return True except Exception: return False - def delete_all_groups(self) -> bool: - with get_db() as db: + def delete_all_groups(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Group).delete() db.commit() @@ -340,8 +408,10 @@ def delete_all_groups(self) -> bool: except Exception: return False - def remove_user_from_all_groups(self, user_id: str) -> bool: - with get_db() as db: + def remove_user_from_all_groups( + self, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: try: # Find all groups the user belongs to groups = ( @@ -369,16 +439,16 @@ def remove_user_from_all_groups(self, user_id: str) -> bool: return False def create_groups_by_group_names( - self, user_id: str, group_names: list[str] + self, user_id: str, group_names: list[str], db: Optional[Session] = None ) -> list[GroupModel]: # check for existing groups - existing_groups = self.get_all_groups() + existing_groups = self.get_all_groups(db=db) existing_group_names = {group.name for 
group in existing_groups} new_groups = [] - with get_db() as db: + with get_db_context(db) as db: for group_name in group_names: if group_name not in existing_group_names: new_group = GroupModel( @@ -400,8 +470,10 @@ def create_groups_by_group_names( continue return new_groups - def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: - with get_db() as db: + def sync_groups_by_group_names( + self, user_id: str, group_names: list[str], db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: try: now = int(time.time()) @@ -461,10 +533,13 @@ def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bo return False def add_users_to_group( - self, id: str, user_ids: Optional[list[str]] = None + self, + id: str, + user_ids: Optional[list[str]] = None, + db: Optional[Session] = None, ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() if not group: return None @@ -499,10 +574,13 @@ def add_users_to_group( return None def remove_users_from_group( - self, id: str, user_ids: Optional[list[str]] = None + self, + id: str, + user_ids: Optional[list[str]] = None, + db: Optional[Session] = None, ) -> Optional[GroupModel]: try: - with get_db() as db: + with get_db_context(db) as db: group = db.query(Group).filter_by(id=id).first() if not group: return None diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index d8af004338..7f99f828c7 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -4,7 +4,8 @@ from typing import Optional import uuid -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.files import ( File, @@ -157,9 +158,9 @@ class KnowledgeFileListResponse(BaseModel): class KnowledgeTable: def 
insert_new_knowledge( - self, user_id: str, form_data: KnowledgeForm + self, user_id: str, form_data: KnowledgeForm, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: - with get_db() as db: + with get_db_context(db) as db: knowledge = KnowledgeModel( **{ **form_data.model_dump(), @@ -183,15 +184,15 @@ def insert_new_knowledge( return None def get_knowledge_bases( - self, skip: int = 0, limit: int = 30 + self, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> list[KnowledgeUserModel]: - with get_db() as db: + with get_db_context(db) as db: all_knowledge = ( db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() ) user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} knowledge_bases = [] @@ -208,10 +209,15 @@ def get_knowledge_bases( return knowledge_bases def search_knowledge_bases( - self, user_id: str, filter: dict, skip: int = 0, limit: int = 30 + self, + user_id: str, + filter: dict, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, ) -> KnowledgeListResponse: try: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Knowledge, User).outerjoin( User, User.id == Knowledge.user_id ) @@ -267,17 +273,17 @@ def search_knowledge_bases( return KnowledgeListResponse(items=[], total=0) def search_knowledge_files( - self, filter: dict, skip: int = 0, limit: int = 30 + self, filter: dict, skip: int = 0, limit: int = 30, db: Optional[Session] = None ) -> KnowledgeFileListResponse: """ Scalable version: search files across all knowledge bases the user has READ access to, without loading all KBs or using large IN() lists. 
""" try: - with get_db() as db: + with get_db_context(db) as db: # Base query: join Knowledge → KnowledgeFile → File query = ( - db.query(File, User) + db.query(File, User, Knowledge) .join(KnowledgeFile, File.id == KnowledgeFile.file_id) .join(Knowledge, KnowledgeFile.knowledge_id == Knowledge.id) .outerjoin(User, User.id == KnowledgeFile.user_id) @@ -307,7 +313,7 @@ def search_knowledge_files( rows = query.all() items = [] - for file, user in rows: + for file, user, knowledge in rows: items.append( FileUserResponse( **FileModel.model_validate(file).model_dump(), @@ -318,6 +324,9 @@ def search_knowledge_files( if user else None ), + collection=KnowledgeModel.model_validate( + knowledge + ).model_dump(), ) ) @@ -327,20 +336,26 @@ def search_knowledge_files( print("search_knowledge_files error:", e) return KnowledgeFileListResponse(items=[], total=0) - def check_access_by_user_id(self, id, user_id, permission="write") -> bool: - knowledge = self.get_knowledge_by_id(id) + def check_access_by_user_id( + self, id, user_id, permission="write", db: Optional[Session] = None + ) -> bool: + knowledge = self.get_knowledge_by_id(id, db=db) if not knowledge: return False if knowledge.user_id == user_id: return True - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } return has_access(user_id, permission, knowledge.access_control, user_group_ids) def get_knowledge_bases_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[KnowledgeUserModel]: - knowledge_bases = self.get_knowledge_bases() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + knowledge_bases = self.get_knowledge_bases(db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } return [ knowledge_base for knowledge_base in 
knowledge_bases @@ -350,32 +365,38 @@ def get_knowledge_bases_by_user_id( ) ] - def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: + def get_knowledge_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[KnowledgeModel]: try: - with get_db() as db: + with get_db_context(db) as db: knowledge = db.query(Knowledge).filter_by(id=id).first() return KnowledgeModel.model_validate(knowledge) if knowledge else None except Exception: return None def get_knowledge_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: - knowledge = self.get_knowledge_by_id(id) + knowledge = self.get_knowledge_by_id(id, db=db) if not knowledge: return None if knowledge.user_id == user_id: return knowledge - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } if has_access(user_id, "write", knowledge.access_control, user_group_ids): return knowledge return None - def get_knowledges_by_file_id(self, file_id: str) -> list[KnowledgeModel]: + def get_knowledges_by_file_id( + self, file_id: str, db: Optional[Session] = None + ) -> list[KnowledgeModel]: try: - with get_db() as db: + with get_db_context(db) as db: knowledges = ( db.query(Knowledge) .join(KnowledgeFile, Knowledge.id == KnowledgeFile.knowledge_id) @@ -395,9 +416,10 @@ def search_files_by_id( filter: dict, skip: int = 0, limit: int = 30, + db: Optional[Session] = None, ) -> KnowledgeFileListResponse: try: - with get_db() as db: + with get_db_context(db) as db: query = ( db.query(File, User) .join(KnowledgeFile, File.id == KnowledgeFile.file_id) @@ -470,9 +492,11 @@ def search_files_by_id( print(e) return KnowledgeFileListResponse(items=[], total=0) - def get_files_by_id(self, knowledge_id: str) -> list[FileModel]: + def get_files_by_id( + self, knowledge_id: str, db: Optional[Session] = None + ) -> 
list[FileModel]: try: - with get_db() as db: + with get_db_context(db) as db: files = ( db.query(File) .join(KnowledgeFile, File.id == KnowledgeFile.file_id) @@ -483,18 +507,24 @@ def get_files_by_id(self, knowledge_id: str) -> list[FileModel]: except Exception: return [] - def get_file_metadatas_by_id(self, knowledge_id: str) -> list[FileMetadataResponse]: + def get_file_metadatas_by_id( + self, knowledge_id: str, db: Optional[Session] = None + ) -> list[FileMetadataResponse]: try: - with get_db() as db: - files = self.get_files_by_id(knowledge_id) + with get_db_context(db) as db: + files = self.get_files_by_id(knowledge_id, db=db) return [FileMetadataResponse(**file.model_dump()) for file in files] except Exception: return [] def add_file_to_knowledge_by_id( - self, knowledge_id: str, file_id: str, user_id: str + self, + knowledge_id: str, + file_id: str, + user_id: str, + db: Optional[Session] = None, ) -> Optional[KnowledgeFileModel]: - with get_db() as db: + with get_db_context(db) as db: knowledge_file = KnowledgeFileModel( **{ "id": str(uuid.uuid4()), @@ -518,9 +548,11 @@ def add_file_to_knowledge_by_id( except Exception: return None - def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> bool: + def remove_file_from_knowledge_by_id( + self, knowledge_id: str, file_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(KnowledgeFile).filter_by( knowledge_id=knowledge_id, file_id=file_id ).delete() @@ -529,9 +561,11 @@ def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str) -> b except Exception: return False - def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: + def reset_knowledge_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[KnowledgeModel]: try: - with get_db() as db: + with get_db_context(db) as db: # Delete all knowledge_file entries for this knowledge_id 
db.query(KnowledgeFile).filter_by(knowledge_id=id).delete() db.commit() @@ -544,17 +578,21 @@ def reset_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: ) db.commit() - return self.get_knowledge_by_id(id=id) + return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) return None def update_knowledge_by_id( - self, id: str, form_data: KnowledgeForm, overwrite: bool = False + self, + id: str, + form_data: KnowledgeForm, + overwrite: bool = False, + db: Optional[Session] = None, ) -> Optional[KnowledgeModel]: try: - with get_db() as db: - knowledge = self.get_knowledge_by_id(id=id) + with get_db_context(db) as db: + knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { **form_data.model_dump(), @@ -562,17 +600,17 @@ def update_knowledge_by_id( } ) db.commit() - return self.get_knowledge_by_id(id=id) + return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) return None def update_knowledge_data_by_id( - self, id: str, data: dict + self, id: str, data: dict, db: Optional[Session] = None ) -> Optional[KnowledgeModel]: try: - with get_db() as db: - knowledge = self.get_knowledge_by_id(id=id) + with get_db_context(db) as db: + knowledge = self.get_knowledge_by_id(id=id, db=db) db.query(Knowledge).filter_by(id=id).update( { "data": data, @@ -580,22 +618,22 @@ def update_knowledge_data_by_id( } ) db.commit() - return self.get_knowledge_by_id(id=id) + return self.get_knowledge_by_id(id=id, db=db) except Exception as e: log.exception(e) return None - def delete_knowledge_by_id(self, id: str) -> bool: + def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Knowledge).filter_by(id=id).delete() db.commit() return True except Exception: return False - def delete_all_knowledge(self) -> bool: - with get_db() as db: + def delete_all_knowledge(self, db: Optional[Session] = None) 
-> bool: + with get_db_context(db) as db: try: db.query(Knowledge).delete() db.commit() diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index f5f2492b99..2dc9656856 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -2,7 +2,8 @@ import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db, get_db_context from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text @@ -41,8 +42,9 @@ def insert_new_memory( self, user_id: str, content: str, + db: Optional[Session] = None, ) -> Optional[MemoryModel]: - with get_db() as db: + with get_db_context(db) as db: id = str(uuid.uuid4()) memory = MemoryModel( @@ -68,8 +70,9 @@ def update_memory_by_id_and_user_id( id: str, user_id: str, content: str, + db: Optional[Session] = None, ) -> Optional[MemoryModel]: - with get_db() as db: + with get_db_context(db) as db: try: memory = db.get(Memory, id) if not memory or memory.user_id != user_id: @@ -83,32 +86,36 @@ def update_memory_by_id_and_user_id( except Exception: return None - def get_memories(self) -> list[MemoryModel]: - with get_db() as db: + def get_memories(self, db: Optional[Session] = None) -> list[MemoryModel]: + with get_db_context(db) as db: try: memories = db.query(Memory).all() return [MemoryModel.model_validate(memory) for memory in memories] except Exception: return None - def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: - with get_db() as db: + def get_memories_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[MemoryModel]: + with get_db_context(db) as db: try: memories = db.query(Memory).filter_by(user_id=user_id).all() return [MemoryModel.model_validate(memory) for memory in memories] except Exception: return None - def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - 
with get_db() as db: + def get_memory_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[MemoryModel]: + with get_db_context(db) as db: try: memory = db.get(Memory, id) return MemoryModel.model_validate(memory) except Exception: return None - def delete_memory_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_memory_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: try: db.query(Memory).filter_by(id=id).delete() db.commit() @@ -118,8 +125,10 @@ def delete_memory_by_id(self, id: str) -> bool: except Exception: return False - def delete_memories_by_user_id(self, user_id: str) -> bool: - with get_db() as db: + def delete_memories_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: try: db.query(Memory).filter_by(user_id=user_id).delete() db.commit() @@ -128,8 +137,10 @@ def delete_memories_by_user_id(self, user_id: str) -> bool: except Exception: return False - def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - with get_db() as db: + def delete_memory_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> bool: + with get_db_context(db) as db: try: memory = db.get(Memory, id) if not memory or memory.user_id != user_id: diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 5b068b6449..0851107b0b 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -3,7 +3,8 @@ import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.users import Users, User, UserNameResponse from open_webui.models.channels import Channels, ChannelMember @@ -137,9 +138,13 @@ class 
MessageResponse(MessageReplyToResponse): class MessageTable: def insert_new_message( - self, form_data: MessageForm, channel_id: str, user_id: str + self, + form_data: MessageForm, + channel_id: str, + user_id: str, + db: Optional[Session] = None, ) -> Optional[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: channel_member = Channels.join_channel(channel_id, user_id) id = str(uuid.uuid4()) @@ -169,26 +174,57 @@ def insert_new_message( db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str) -> Optional[MessageResponse]: - with get_db() as db: + def get_message_by_id( + self, + id: str, + include_thread_replies: Optional[bool] = True, + db: Optional[Session] = None, + ) -> Optional[MessageResponse]: + with get_db_context(db) as db: message = db.get(Message, id) if not message: return None reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) - reactions = self.get_reactions_by_message_id(id) - thread_replies = self.get_thread_replies_by_message_id(id) + reactions = self.get_reactions_by_message_id(id, db=db) + + thread_replies = [] + if include_thread_replies: + thread_replies = self.get_thread_replies_by_message_id(id, db=db) + + # Check if message was sent by webhook (webhook info in meta takes precedence) + webhook_info = message.meta.get("webhook") if message.meta else None + if webhook_info and webhook_info.get("id"): + # Look up webhook by ID to get current name + webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook: + user_info = { + "id": webhook.id, + "name": webhook.name, + "role": "webhook", + } + else: + # Webhook was deleted, use placeholder + user_info = { + "id": webhook_info.get("id"), + "name": "Deleted Webhook", + "role": "webhook", + } + else: + user = Users.get_user_by_id(message.user_id, db=db) + 
user_info = user.model_dump() if user else None - user = Users.get_user_by_id(message.user_id) return MessageResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), - "user": user.model_dump() if user else None, + "user": user_info, "reply_to_message": ( reply_to_message.model_dump() if reply_to_message else None ), @@ -200,8 +236,10 @@ def get_message_by_id(self, id: str) -> Optional[MessageResponse]: } ) - def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: - with get_db() as db: + def get_thread_replies_by_message_id( + self, id: str, db: Optional[Session] = None + ) -> list[MessageReplyToResponse]: + with get_db_context(db) as db: all_messages = ( db.query(Message) .filter_by(parent_id=id) @@ -212,14 +250,35 @@ def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToRespon messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) + + webhook_info = message.meta.get("webhook") if message.meta else None + user_info = None + if webhook_info and webhook_info.get("id"): + webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook: + user_info = { + "id": webhook.id, + "name": webhook.name, + "role": "webhook", + } + else: + user_info = { + "id": webhook_info.get("id"), + "name": "Deleted Webhook", + "role": "webhook", + } + messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), + "user": user_info, "reply_to_message": ( reply_to_message.model_dump() if reply_to_message @@ -230,17 +289,23 @@ def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToRespon ) return messages - def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: - with get_db() as db: + def get_reply_user_ids_by_message_id( + self, id: str, db: 
Optional[Session] = None + ) -> list[str]: + with get_db_context(db) as db: return [ message.user_id for message in db.query(Message).filter_by(parent_id=id).all() ] def get_messages_by_channel_id( - self, channel_id: str, skip: int = 0, limit: int = 50 + self, + channel_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[MessageReplyToResponse]: - with get_db() as db: + with get_db_context(db) as db: all_messages = ( db.query(Message) .filter_by(channel_id=channel_id, parent_id=None) @@ -253,14 +318,35 @@ def get_messages_by_channel_id( messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) + + webhook_info = message.meta.get("webhook") if message.meta else None + user_info = None + if webhook_info and webhook_info.get("id"): + webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook: + user_info = { + "id": webhook.id, + "name": webhook.name, + "role": "webhook", + } + else: + user_info = { + "id": webhook_info.get("id"), + "name": "Deleted Webhook", + "role": "webhook", + } + messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), + "user": user_info, "reply_to_message": ( reply_to_message.model_dump() if reply_to_message @@ -272,9 +358,14 @@ def get_messages_by_channel_id( return messages def get_messages_by_parent_id( - self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 + self, + channel_id: str, + parent_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[MessageReplyToResponse]: - with get_db() as db: + with get_db_context(db) as db: message = db.get(Message, parent_id) if not message: @@ -296,14 +387,35 @@ def get_messages_by_parent_id( messages = [] for message in all_messages: reply_to_message = ( - 
self.get_message_by_id(message.reply_to_id) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) + + webhook_info = message.meta.get("webhook") if message.meta else None + user_info = None + if webhook_info and webhook_info.get("id"): + webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) + if webhook: + user_info = { + "id": webhook.id, + "name": webhook.name, + "role": "webhook", + } + else: + user_info = { + "id": webhook_info.get("id"), + "name": "Deleted Webhook", + "role": "webhook", + } + messages.append( MessageReplyToResponse.model_validate( { **MessageModel.model_validate(message).model_dump(), + "user": user_info, "reply_to_message": ( reply_to_message.model_dump() if reply_to_message @@ -314,8 +426,10 @@ def get_messages_by_parent_id( ) return messages - def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageModel]: - with get_db() as db: + def get_last_message_by_channel_id( + self, channel_id: str, db: Optional[Session] = None + ) -> Optional[MessageModel]: + with get_db_context(db) as db: message = ( db.query(Message) .filter_by(channel_id=channel_id) @@ -325,9 +439,13 @@ def get_last_message_by_channel_id(self, channel_id: str) -> Optional[MessageMod return MessageModel.model_validate(message) if message else None def get_pinned_messages_by_channel_id( - self, channel_id: str, skip: int = 0, limit: int = 50 + self, + channel_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: all_messages = ( db.query(Message) .filter_by(channel_id=channel_id, is_pinned=True) @@ -339,9 +457,9 @@ def get_pinned_messages_by_channel_id( return [MessageModel.model_validate(message) for message in all_messages] def update_message_by_id( - self, id: str, form_data: MessageForm + self, id: str, form_data: MessageForm, db: Optional[Session] = None ) -> 
Optional[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: message = db.get(Message, id) message.content = form_data.content message.data = { @@ -358,9 +476,13 @@ def update_message_by_id( return MessageModel.model_validate(message) if message else None def update_is_pinned_by_id( - self, id: str, is_pinned: bool, pinned_by: Optional[str] = None + self, + id: str, + is_pinned: bool, + pinned_by: Optional[str] = None, + db: Optional[Session] = None, ) -> Optional[MessageModel]: - with get_db() as db: + with get_db_context(db) as db: message = db.get(Message, id) message.is_pinned = is_pinned message.pinned_at = int(time.time_ns()) if is_pinned else None @@ -370,9 +492,13 @@ def update_is_pinned_by_id( return MessageModel.model_validate(message) if message else None def get_unread_message_count( - self, channel_id: str, user_id: str, last_read_at: Optional[int] = None + self, + channel_id: str, + user_id: str, + last_read_at: Optional[int] = None, + db: Optional[Session] = None, ) -> int: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Message).filter( Message.channel_id == channel_id, Message.parent_id == None, # only count top-level messages @@ -383,9 +509,9 @@ def get_unread_message_count( return query.count() def add_reaction_to_message( - self, id: str, user_id: str, name: str + self, id: str, user_id: str, name: str, db: Optional[Session] = None ) -> Optional[MessageReactionModel]: - with get_db() as db: + with get_db_context(db) as db: # check for existing reaction existing_reaction = ( db.query(MessageReaction) @@ -409,8 +535,10 @@ def add_reaction_to_message( db.refresh(result) return MessageReactionModel.model_validate(result) if result else None - def get_reactions_by_message_id(self, id: str) -> list[Reactions]: - with get_db() as db: + def get_reactions_by_message_id( + self, id: str, db: Optional[Session] = None + ) -> list[Reactions]: + with get_db_context(db) as db: # JOIN User so all user info is fetched 
in one query results = ( db.query(MessageReaction, User) @@ -440,29 +568,29 @@ def get_reactions_by_message_id(self, id: str) -> list[Reactions]: return [Reactions(**reaction) for reaction in reactions.values()] def remove_reaction_by_id_and_user_id_and_name( - self, id: str, user_id: str, name: str + self, id: str, user_id: str, name: str, db: Optional[Session] = None ) -> bool: - with get_db() as db: + with get_db_context(db) as db: db.query(MessageReaction).filter_by( message_id=id, user_id=user_id, name=name ).delete() db.commit() return True - def delete_reactions_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(MessageReaction).filter_by(message_id=id).delete() db.commit() return True - def delete_replies_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(Message).filter_by(parent_id=id).delete() db.commit() return True - def delete_message_by_id(self, id: str) -> bool: - with get_db() as db: + def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(Message).filter_by(id=id).delete() # Delete all reactions to this message @@ -471,5 +599,35 @@ def delete_message_by_id(self, id: str) -> bool: db.commit() return True + def search_messages_by_channel_ids( + self, + channel_ids: list[str], + query: str, + start_timestamp: Optional[int] = None, + end_timestamp: Optional[int] = None, + limit: int = 10, + db: Optional[Session] = None, + ) -> list[MessageModel]: + """Search messages in specified channels by content.""" + with get_db_context(db) as db: + query_builder = db.query(Message).filter( + Message.channel_id.in_(channel_ids), + Message.content.ilike(f"%{query}%"), + ) + + if start_timestamp: + query_builder = query_builder.filter( + 
Message.created_at >= start_timestamp + ) + if end_timestamp: + query_builder = query_builder.filter( + Message.created_at <= end_timestamp + ) + + messages = ( + query_builder.order_by(Message.created_at.desc()).limit(limit).all() + ) + return [MessageModel.model_validate(msg) for msg in messages] + Messages = MessageTable() diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 5feb044cba..f7b540d685 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -2,7 +2,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.models.users import User, UserModel, Users, UserResponse @@ -130,6 +131,10 @@ class ModelUserResponse(ModelModel): user: Optional[UserResponse] = None +class ModelAccessResponse(ModelUserResponse): + write_access: Optional[bool] = False + + class ModelResponse(ModelModel): pass @@ -139,6 +144,11 @@ class ModelListResponse(BaseModel): total: int +class ModelAccessListResponse(BaseModel): + items: list[ModelAccessResponse] + total: int + + class ModelPriceForm(BaseModel): prompt_price: float = Field( default=0, description="prompt token price for 1m tokens", ge=0 @@ -185,7 +195,7 @@ class ModelForm(BaseModel): class ModelsTable: def insert_new_model( - self, form_data: ModelForm, user_id: str + self, form_data: ModelForm, user_id: str, db: Optional[Session] = None ) -> Optional[ModelModel]: model = ModelModel( **{ @@ -196,7 +206,7 @@ def insert_new_model( } ) try: - with get_db() as db: + with get_db_context(db) as db: result = Model(**model.model_dump()) db.add(result) db.commit() @@ -210,17 +220,17 @@ def insert_new_model( log.exception(f"Failed to insert a new model: {e}") return None - def get_all_models(self) -> list[ModelModel]: - with get_db() as db: + 
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: + with get_db_context(db) as db: return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_models(self) -> list[ModelUserResponse]: - with get_db() as db: + def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: + with get_db_context(db) as db: all_models = db.query(Model).filter(Model.base_model_id != None).all() user_ids = list(set(model.user_id for model in all_models)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} models = [] @@ -236,18 +246,20 @@ def get_models(self) -> list[ModelUserResponse]: ) return models - def get_base_models(self) -> list[ModelModel]: - with get_db() as db: + def get_base_models(self, db: Optional[Session] = None) -> list[ModelModel]: + with get_db_context(db) as db: return [ ModelModel.model_validate(model) for model in db.query(Model).filter(Model.base_model_id == None).all() ] def get_models_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[ModelUserResponse]: - models = self.get_models() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + models = self.get_models(db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } return [ model for model in models @@ -298,9 +310,14 @@ def _has_permission(self, db, query, filter: dict, permission: str = "read"): return query def search_models( - self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30 + self, + user_id: str, + filter: dict = {}, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, ) -> ModelListResponse: - with get_db() as db: + with get_db_context(db) as db: # Join GroupMember so we can order by group_id 
when requested query = db.query(Model, User).outerjoin(User, User.id == Model.user_id) query = query.filter(Model.base_model_id != None) @@ -326,7 +343,7 @@ def search_models( db, query, filter, - permission="write", + permission="read", ) tag = filter.get("tag") @@ -384,24 +401,30 @@ def search_models( return ModelListResponse(items=models, total=total) - def get_model_by_id(self, id: str) -> Optional[ModelModel]: + def get_model_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: model = db.get(Model, id) return ModelModel.model_validate(model) except Exception: return None - def get_models_by_ids(self, ids: list[str]) -> list[ModelModel]: + def get_models_by_ids( + self, ids: list[str], db: Optional[Session] = None + ) -> list[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: models = db.query(Model).filter(Model.id.in_(ids)).all() return [ModelModel.model_validate(model) for model in models] except Exception: return [] - def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: - with get_db() as db: + def toggle_model_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ModelModel]: + with get_db_context(db) as db: try: is_active = db.query(Model).filter_by(id=id).first().is_active @@ -413,13 +436,15 @@ def toggle_model_by_id(self, id: str) -> Optional[ModelModel]: ) db.commit() - return self.get_model_by_id(id) + return self.get_model_by_id(id, db=db) except Exception: return None - def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + def update_model_by_id( + self, id: str, model: ModelForm, db: Optional[Session] = None + ) -> Optional[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: # update only the fields that are present in the model data = model.model_dump(exclude={"id"}) result = db.query(Model).filter_by(id=id).update(data) @@ -433,9 +458,9 @@ def 
update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: log.exception(f"Failed to update the model by id {id}: {e}") return None - def delete_model_by_id(self, id: str) -> bool: + def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Model).filter_by(id=id).delete() db.commit() @@ -443,9 +468,9 @@ def delete_model_by_id(self, id: str) -> bool: except Exception: return False - def delete_all_models(self) -> bool: + def delete_all_models(self, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Model).delete() db.commit() @@ -453,9 +478,11 @@ def delete_all_models(self) -> bool: except Exception: return False - def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: + def sync_models( + self, user_id: str, models: list[ModelModel], db: Optional[Session] = None + ) -> list[ModelModel]: try: - with get_db() as db: + with get_db_context(db) as db: # Get existing models existing_models = db.query(Model).all() existing_ids = {model.id for model in existing_models} diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index cfeddf4a8c..bd23530785 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -4,7 +4,8 @@ from typing import Optional from functools import lru_cache -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.utils.access_control import has_access from open_webui.models.users import User, UserModel, Users, UserResponse @@ -211,11 +212,9 @@ def _has_permission(self, db, query, filter: dict, permission: str = "read"): return query def insert_new_note( - self, - form_data: NoteForm, - user_id: str, + self, user_id: str, form_data: NoteForm, db: 
Optional[Session] = None ) -> Optional[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: note = NoteModel( **{ "id": str(uuid.uuid4()), @@ -233,9 +232,9 @@ def insert_new_note( return note def get_notes( - self, skip: Optional[int] = None, limit: Optional[int] = None + self, skip: int = 0, limit: int = 50, db: Optional[Session] = None ) -> list[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Note).order_by(Note.updated_at.desc()) if skip is not None: query = query.offset(skip) @@ -245,19 +244,32 @@ def get_notes( return [NoteModel.model_validate(note) for note in notes] def search_notes( - self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30 + self, + user_id: str, + filter: dict = {}, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, ) -> NoteListResponse: - with get_db() as db: + with get_db_context(db) as db: query = db.query(Note, User).outerjoin(User, User.id == Note.user_id) if filter: query_key = filter.get("query") if query_key: + # Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do") + normalized_query = query_key.replace("-", "").replace(" ", "") query = query.filter( or_( - Note.title.ilike(f"%{query_key}%"), - cast(Note.data["content"]["md"], Text).ilike( - f"%{query_key}%" - ), + func.replace( + func.replace(Note.title, "-", ""), " ", "" + ).ilike(f"%{normalized_query}%"), + func.replace( + func.replace( + cast(Note.data["content"]["md"], Text), "-", "" + ), + " ", + "", + ).ilike(f"%{normalized_query}%"), ) ) @@ -333,12 +345,13 @@ def get_notes_by_user_id( self, user_id: str, permission: str = "read", - skip: Optional[int] = None, - limit: Optional[int] = None, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id) + group.id for group in 
Groups.get_groups_by_member_id(user_id, db=db) ] query = db.query(Note).order_by(Note.updated_at.desc()) @@ -354,15 +367,17 @@ def get_notes_by_user_id( notes = query.all() return [NoteModel.model_validate(note) for note in notes] - def get_note_by_id(self, id: str) -> Optional[NoteModel]: - with get_db() as db: + def get_note_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[NoteModel]: + with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() return NoteModel.model_validate(note) if note else None def update_note_by_id( - self, id: str, form_data: NoteUpdateForm + self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None ) -> Optional[NoteModel]: - with get_db() as db: + with get_db_context(db) as db: note = db.query(Note).filter(Note.id == id).first() if not note: return None @@ -384,11 +399,14 @@ def update_note_by_id( db.commit() return NoteModel.model_validate(note) if note else None - def delete_note_by_id(self, id: str): - with get_db() as db: - db.query(Note).filter(Note.id == id).delete() - db.commit() - return True + def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: + try: + with get_db_context(db) as db: + db.query(Note).filter(Note.id == id).delete() + db.commit() + return True + except Exception: + return False Notes = NoteTable() diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index 8b0334ed19..f7ee5cceb8 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -8,7 +8,8 @@ from cryptography.fernet import Fernet -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, get_db, get_db_context from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY from pydantic import BaseModel, ConfigDict @@ -109,10 +110,11 @@ def create_session( user_id: str, provider: str, token: dict, + db: 
Optional[Session] = None, ) -> Optional[OAuthSessionModel]: """Create a new OAuth session""" try: - with get_db() as db: + with get_db_context(db) as db: current_time = int(time.time()) id = str(uuid.uuid4()) @@ -141,10 +143,12 @@ def create_session( log.error(f"Error creating OAuth session: {e}") return None - def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]: + def get_session_by_id( + self, session_id: str, db: Optional[Session] = None + ) -> Optional[OAuthSessionModel]: """Get OAuth session by ID""" try: - with get_db() as db: + with get_db_context(db) as db: session = db.query(OAuthSession).filter_by(id=session_id).first() if session: session.token = self._decrypt_token(session.token) @@ -156,11 +160,11 @@ def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]: return None def get_session_by_id_and_user_id( - self, session_id: str, user_id: str + self, session_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[OAuthSessionModel]: """Get OAuth session by ID and user ID""" try: - with get_db() as db: + with get_db_context(db) as db: session = ( db.query(OAuthSession) .filter_by(id=session_id, user_id=user_id) @@ -176,11 +180,11 @@ def get_session_by_id_and_user_id( return None def get_session_by_provider_and_user_id( - self, provider: str, user_id: str + self, provider: str, user_id: str, db: Optional[Session] = None ) -> Optional[OAuthSessionModel]: """Get OAuth session by provider and user ID""" try: - with get_db() as db: + with get_db_context(db) as db: session = ( db.query(OAuthSession) .filter_by(provider=provider, user_id=user_id) @@ -195,10 +199,12 @@ def get_session_by_provider_and_user_id( log.error(f"Error getting OAuth session by provider and user ID: {e}") return None - def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: + def get_sessions_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> List[OAuthSessionModel]: """Get all OAuth sessions for a 
user""" try: - with get_db() as db: + with get_db_context(db) as db: sessions = db.query(OAuthSession).filter_by(user_id=user_id).all() results = [] @@ -213,11 +219,11 @@ def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: return [] def update_session_by_id( - self, session_id: str, token: dict + self, session_id: str, token: dict, db: Optional[Session] = None ) -> Optional[OAuthSessionModel]: """Update OAuth session tokens""" try: - with get_db() as db: + with get_db_context(db) as db: current_time = int(time.time()) db.query(OAuthSession).filter_by(id=session_id).update( @@ -239,10 +245,12 @@ def update_session_by_id( log.error(f"Error updating OAuth session tokens: {e}") return None - def delete_session_by_id(self, session_id: str) -> bool: + def delete_session_by_id( + self, session_id: str, db: Optional[Session] = None + ) -> bool: """Delete an OAuth session""" try: - with get_db() as db: + with get_db_context(db) as db: result = db.query(OAuthSession).filter_by(id=session_id).delete() db.commit() return result > 0 @@ -250,10 +258,12 @@ def delete_session_by_id(self, session_id: str) -> bool: log.error(f"Error deleting OAuth session: {e}") return False - def delete_sessions_by_user_id(self, user_id: str) -> bool: + def delete_sessions_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> bool: """Delete all OAuth sessions for a user""" try: - with get_db() as db: + with get_db_context(db) as db: result = db.query(OAuthSession).filter_by(user_id=user_id).delete() db.commit() return True @@ -261,10 +271,12 @@ def delete_sessions_by_user_id(self, user_id: str) -> bool: log.error(f"Error deleting OAuth sessions by user ID: {e}") return False - def delete_sessions_by_provider(self, provider: str) -> bool: + def delete_sessions_by_provider( + self, provider: str, db: Optional[Session] = None + ) -> bool: """Delete all OAuth sessions for a provider""" try: - with get_db() as db: + with get_db_context(db) as db: 
db.query(OAuthSession).filter_by(provider=provider).delete() db.commit() return True diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 7502f34ccd..847597bc65 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,7 +1,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse @@ -62,6 +63,10 @@ class PromptUserResponse(PromptModel): user: Optional[UserResponse] = None +class PromptAccessResponse(PromptUserResponse): + write_access: Optional[bool] = False + + class PromptForm(BaseModel): command: str title: str @@ -71,7 +76,7 @@ class PromptForm(BaseModel): class PromptsTable: def insert_new_prompt( - self, user_id: str, form_data: PromptForm + self, user_id: str, form_data: PromptForm, db: Optional[Session] = None ) -> Optional[PromptModel]: prompt = PromptModel( **{ @@ -82,7 +87,7 @@ def insert_new_prompt( ) try: - with get_db() as db: + with get_db_context(db) as db: result = Prompt(**prompt.model_dump()) db.add(result) db.commit() @@ -94,21 +99,23 @@ def insert_new_prompt( except Exception: return None - def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + def get_prompt_by_command( + self, command: str, db: Optional[Session] = None + ) -> Optional[PromptModel]: try: - with get_db() as db: + with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt) except Exception: return None - def get_prompts(self) -> list[PromptUserResponse]: - with get_db() as db: + def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]: + with get_db_context(db) as db: all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all() 
user_ids = list(set(prompt.user_id for prompt in all_prompts)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict = {user.id: user for user in users} prompts = [] @@ -126,10 +133,12 @@ def get_prompts(self) -> list[PromptUserResponse]: return prompts def get_prompts_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[PromptUserResponse]: - prompts = self.get_prompts() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + prompts = self.get_prompts(db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } return [ prompt @@ -139,10 +148,10 @@ def get_prompts_by_user_id( ] def update_prompt_by_command( - self, command: str, form_data: PromptForm + self, command: str, form_data: PromptForm, db: Optional[Session] = None ) -> Optional[PromptModel]: try: - with get_db() as db: + with get_db_context(db) as db: prompt = db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content @@ -153,9 +162,11 @@ def update_prompt_by_command( except Exception: return None - def delete_prompt_by_command(self, command: str) -> bool: + def delete_prompt_by_command( + self, command: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Prompt).filter_by(command=command).delete() db.commit() diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 499f3859dc..64cb559547 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -3,7 +3,8 @@ import uuid from typing import Optional -from open_webui.internal.db import Base, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context 
from pydantic import BaseModel, ConfigDict @@ -50,8 +51,10 @@ class TagChatIdForm(BaseModel): class TagTable: - def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - with get_db() as db: + def insert_new_tag( + self, name: str, user_id: str, db: Optional[Session] = None + ) -> Optional[TagModel]: + with get_db_context(db) as db: id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: @@ -68,27 +71,29 @@ def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: return None def get_tag_by_name_and_user_id( - self, name: str, user_id: str + self, name: str, user_id: str, db: Optional[Session] = None ) -> Optional[TagModel]: try: id = name.replace(" ", "_").lower() - with get_db() as db: + with get_db_context(db) as db: tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: - with get_db() as db: + def get_tags_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[TagModel]: + with get_db_context(db) as db: return [ TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all()) ] def get_tags_by_ids_and_user_id( - self, ids: list[str], user_id: str + self, ids: list[str], user_id: str, db: Optional[Session] = None ) -> list[TagModel]: - with get_db() as db: + with get_db_context(db) as db: return [ TagModel.model_validate(tag) for tag in ( @@ -96,9 +101,11 @@ def get_tags_by_ids_and_user_id( ) ] - def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: + def delete_tag_by_name_and_user_id( + self, name: str, user_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: id = name.replace(" ", "_").lower() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") diff --git 
a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index fff53a7e94..cd7d0bd1a0 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -2,7 +2,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.models.users import Users, UserResponse from open_webui.models.groups import Groups @@ -96,6 +97,10 @@ class ToolUserResponse(ToolResponse): model_config = ConfigDict(extra="allow") +class ToolAccessResponse(ToolUserResponse): + write_access: Optional[bool] = False + + class ToolForm(BaseModel): id: str name: str @@ -110,9 +115,13 @@ class ToolValves(BaseModel): class ToolsTable: def insert_new_tool( - self, user_id: str, form_data: ToolForm, specs: list[dict] + self, + user_id: str, + form_data: ToolForm, + specs: list[dict], + db: Optional[Session] = None, ) -> Optional[ToolModel]: - with get_db() as db: + with get_db_context(db) as db: tool = ToolModel( **{ **form_data.model_dump(), @@ -136,21 +145,23 @@ def insert_new_tool( log.exception(f"Error creating a new tool: {e}") return None - def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + def get_tool_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ToolModel]: try: - with get_db() as db: + with get_db_context(db) as db: tool = db.get(Tool, id) return ToolModel.model_validate(tool) except Exception: return None - def get_tools(self) -> list[ToolUserModel]: - with get_db() as db: + def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]: + with get_db_context(db) as db: all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all() user_ids = list(set(tool.user_id for tool in all_tools)) - users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users_dict 
= {user.id: user for user in users} tools = [] @@ -167,10 +178,12 @@ def get_tools(self) -> list[ToolUserModel]: return tools def get_tools_by_user_id( - self, user_id: str, permission: str = "write" + self, user_id: str, permission: str = "write", db: Optional[Session] = None ) -> list[ToolUserModel]: - tools = self.get_tools() - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + tools = self.get_tools(db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + } return [ tool @@ -179,31 +192,35 @@ def get_tools_by_user_id( or has_access(user_id, permission, tool.access_control, user_group_ids) ] - def get_tool_valves_by_id(self, id: str) -> Optional[dict]: + def get_tool_valves_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[dict]: try: - with get_db() as db: + with get_db_context(db) as db: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: log.exception(f"Error getting tool valves by id {id}") return None - def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: + def update_tool_valves_by_id( + self, id: str, valves: dict, db: Optional[Session] = None + ) -> Optional[ToolValves]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Tool).filter_by(id=id).update( {"valves": valves, "updated_at": int(time.time())} ) db.commit() - return self.get_tool_by_id(id) + return self.get_tool_by_id(id, db=db) except Exception: return None def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings @@ -220,10 +237,10 @@ def get_user_valves_by_id_and_user_id( return None def 
update_user_valves_by_id_and_user_id( - self, id: str, user_id: str, valves: dict + self, id: str, user_id: str, valves: dict, db: Optional[Session] = None ) -> Optional[dict]: try: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings @@ -235,7 +252,7 @@ def update_user_valves_by_id_and_user_id( user_settings["tools"]["valves"][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}) + Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) return user_settings["tools"]["valves"][id] except Exception as e: @@ -244,9 +261,11 @@ def update_user_valves_by_id_and_user_id( ) return None - def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + def update_tool_by_id( + self, id: str, updated: dict, db: Optional[Session] = None + ) -> Optional[ToolModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Tool).filter_by(id=id).update( {**updated, "updated_at": int(time.time())} ) @@ -258,9 +277,9 @@ def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: except Exception: return None - def delete_tool_by_id(self, id: str) -> bool: + def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(Tool).filter_by(id=id).delete() db.commit() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 6be042f7b1..294c11aa74 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,7 +1,8 @@ import time from typing import Optional -from open_webui.internal.db import Base, JSONField, get_db +from sqlalchemy.orm import Session +from open_webui.internal.db import Base, JSONField, get_db, get_db_context from open_webui.env import 
DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL @@ -249,8 +250,9 @@ def insert_new_user( profile_image_url: str = "/user.png", role: str = "pending", oauth: Optional[dict] = None, + db: Optional[Session] = None, ) -> Optional[UserModel]: - with get_db() as db: + with get_db_context(db) as db: user = UserModel( **{ "id": id, @@ -273,17 +275,21 @@ def insert_new_user( else: return None - def get_user_by_id(self, id: str) -> Optional[UserModel]: + def get_user_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + def get_user_by_api_key( + self, api_key: str, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user = ( db.query(User) .join(ApiKey, User.id == ApiKey.user_id) @@ -294,17 +300,21 @@ def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: except Exception: return None - def get_user_by_email(self, email: str) -> Optional[UserModel]: + def get_user_by_email( + self, email: str, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: return None - def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]: + def get_user_by_oauth_sub( + self, provider: str, sub: str, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: # type: Session + with get_db_context(db) as db: # type: Session dialect_name = db.bind.dialect.name query = db.query(User) @@ -326,8 +336,9 @@ def get_users( filter: Optional[dict] = None, skip: Optional[int] = None, limit: Optional[int] = None, + db: Optional[Session] = None, ) 
-> dict: - with get_db() as db: + with get_db_context(db) as db: # Join GroupMember so we can order by group_id when requested query = db.query(User) @@ -458,8 +469,10 @@ def get_users( "total": total, } - def get_users_by_group_id(self, group_id: str) -> list[UserModel]: - with get_db() as db: + def get_users_by_group_id( + self, group_id: str, db: Optional[Session] = None + ) -> list[UserModel]: + with get_db_context(db) as db: users = ( db.query(User) .join(GroupMember, User.id == GroupMember.user_id) @@ -468,30 +481,34 @@ def get_users_by_group_id(self, group_id: str) -> list[UserModel]: ) return [UserModel.model_validate(user) for user in users] - def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserStatusModel]: - with get_db() as db: + def get_users_by_user_ids( + self, user_ids: list[str], db: Optional[Session] = None + ) -> list[UserStatusModel]: + with get_db_context(db) as db: users = db.query(User).filter(User.id.in_(user_ids)).all() return [UserModel.model_validate(user) for user in users] - def get_num_users(self) -> Optional[int]: - with get_db() as db: + def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: + with get_db_context(db) as db: return db.query(User).count() - def has_users(self) -> bool: - with get_db() as db: + def has_users(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: return db.query(db.query(User).exists()).scalar() - def get_first_user(self) -> UserModel: + def get_first_user(self, db: Optional[Session] = None) -> UserModel: try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).order_by(User.created_at).first() return UserModel.model_validate(user) except Exception: return None - def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: + def get_user_webhook_url_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[str]: try: - with get_db() as db: + with get_db_context(db) as db: user = 
db.query(User).filter_by(id=id).first() if user.settings is None: @@ -505,8 +522,8 @@ def get_user_webhook_url_by_id(self, id: str) -> Optional[str]: except Exception: return None - def get_num_users_active_today(self) -> Optional[int]: - with get_db() as db: + def get_num_users_active_today(self, db: Optional[Session] = None) -> Optional[int]: + with get_db_context(db) as db: current_timestamp = int(datetime.datetime.now().timestamp()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) query = db.query(User).filter( @@ -514,9 +531,11 @@ def get_num_users_active_today(self) -> Optional[int]: ) return query.count() - def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: + def update_user_role_by_id( + self, id: str, role: str, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update({"role": role}) db.commit() user = db.query(User).filter_by(id=id).first() @@ -525,10 +544,10 @@ def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: return None def update_user_status_by_id( - self, id: str, form_data: UserStatus + self, id: str, form_data: UserStatus, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update( {**form_data.model_dump(exclude_none=True)} ) @@ -540,10 +559,10 @@ def update_user_status_by_id( return None def update_user_profile_image_url_by_id( - self, id: str, profile_image_url: str + self, id: str, profile_image_url: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update( {"profile_image_url": profile_image_url} ) @@ -555,9 +574,11 @@ def update_user_profile_image_url_by_id( return None @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) - def update_last_active_by_id(self, id: str) -> 
Optional[UserModel]: + def update_last_active_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) @@ -569,7 +590,7 @@ def update_last_active_by_id(self, id: str) -> Optional[UserModel]: return None def update_user_oauth_by_id( - self, id: str, provider: str, sub: str + self, id: str, provider: str, sub: str, db: Optional[Session] = None ) -> Optional[UserModel]: """ Update or insert an OAuth provider/sub pair into the user's oauth JSON field. @@ -580,7 +601,7 @@ def update_user_oauth_by_id( } """ try: - with get_db() as db: + with get_db_context(db) as db: user = db.query(User).filter_by(id=id).first() if not user: return None @@ -600,9 +621,11 @@ def update_user_oauth_by_id( except Exception: return None - def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + def update_user_by_id( + self, id: str, updated: dict, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: + with get_db_context(db) as db: db.query(User).filter_by(id=id).update(updated) db.commit() @@ -613,10 +636,16 @@ def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: print(e) return None - def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + def update_user_settings_by_id( + self, id: str, updated: dict, db: Optional[Session] = None + ) -> Optional[UserModel]: try: - with get_db() as db: - user_settings = db.query(User).filter_by(id=id).first().settings + with get_db_context(db) as db: + user = db.query(User).filter_by(id=id).first() + if not user: + return None + + user_settings = user.settings if user_settings is None: user_settings = {} @@ -631,15 +660,15 @@ def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserMod except Exception: return None - def delete_user_by_id(self, id: str) -> bool: + def 
delete_user_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: # Remove User from Groups Groups.remove_user_from_all_groups(id) # Delete User Chats - result = Chats.delete_chats_by_user_id(id) + result = Chats.delete_chats_by_user_id(id, db=db) if result: - with get_db() as db: + with get_db_context(db) as db: # Delete User db.query(User).filter_by(id=id).delete() db.commit() @@ -650,17 +679,21 @@ def delete_user_by_id(self, id: str) -> bool: except Exception: return False - def get_user_api_key_by_id(self, id: str) -> Optional[str]: + def get_user_api_key_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[str]: try: - with get_db() as db: + with get_db_context(db) as db: api_key = db.query(ApiKey).filter_by(user_id=id).first() return api_key.key if api_key else None except Exception: return None - def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: + def update_user_api_key_by_id( + self, id: str, api_key: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ApiKey).filter_by(user_id=id).delete() db.commit() @@ -680,30 +713,32 @@ def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: except Exception: return False - def delete_user_api_key_by_id(self, id: str) -> bool: + def delete_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ApiKey).filter_by(user_id=id).delete() db.commit() return True except Exception: return False - def get_valid_user_ids(self, user_ids: list[str]) -> list[str]: - with get_db() as db: + def get_valid_user_ids( + self, user_ids: list[str], db: Optional[Session] = None + ) -> list[str]: + with get_db_context(db) as db: users = db.query(User).filter(User.id.in_(user_ids)).all() return [user.id for user in users] - def get_super_admin_user(self) -> Optional[UserModel]: - with get_db() as db: + def get_super_admin_user(self, db: 
Optional[Session] = None) -> Optional[UserModel]: + with get_db_context(db) as db: user = db.query(User).filter_by(role="admin").first() if user: return UserModel.model_validate(user) else: return None - def get_active_user_count(self) -> int: - with get_db() as db: + def get_active_user_count(self, db: Optional[Session] = None) -> int: + with get_db_context(db) as db: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 count = ( @@ -711,8 +746,8 @@ def get_active_user_count(self) -> int: ) return count - def is_user_active(self, user_id: str) -> bool: - with get_db() as db: + def is_user_active(self, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: user = db.query(User).filter_by(id=user_id).first() if user and user.last_active_at: # Consider user active if last_active_at within the last 3 minutes diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index b0bc1dd068..2b83e44283 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -30,7 +30,7 @@ from open_webui.retrieval.loaders.mineru import MinerULoader -from open_webui.env import GLOBAL_LOG_LEVEL +from open_webui.env import GLOBAL_LOG_LEVEL, REQUESTS_VERIFY logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) @@ -114,7 +114,7 @@ def load(self) -> list[Document]: endpoint += "/" endpoint += "tika/text" - r = requests.put(endpoint, data=data, headers=headers) + r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY) if r.ok: raw_metadata = r.json() diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index d567bf4fe5..095143d20d 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -4,7 +4,7 @@ from urllib.parse import quote 
-from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS +from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, REQUESTS_VERIFY from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.utils.headers import include_user_info_headers @@ -55,6 +55,7 @@ def predict( headers=headers, json=payload, timeout=self.timeout, + verify=REQUESTS_VERIFY, ) r.raise_for_status() diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 5cfa659f79..61b1a947a6 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -38,8 +38,10 @@ from open_webui.retrieval.loaders.youtube import YoutubeLoader from open_webui.env import ( + AIOHTTP_CLIENT_TIMEOUT, OFFLINE_MODE, ENABLE_FORWARD_USER_INFO_HEADERS, + AIOHTTP_CLIENT_SESSION_SSL, ) from open_webui.config import ( RAG_EMBEDDING_QUERY_PREFIX, @@ -559,7 +561,9 @@ async def agenerate_openai_batch_embeddings( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post( f"{url}/embeddings", headers=headers, json=form_data ) as r: @@ -625,7 +629,9 @@ async def agenerate_azure_openai_batch_embeddings( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post(full_url, headers=headers, json=form_data) as r: r.raise_for_status() data = await r.json() @@ -682,9 +688,14 @@ async def agenerate_ollama_batch_embeddings( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with 
aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.post( - f"{url}/api/embed", headers=headers, json=form_data + f"{url}/api/embed", + headers=headers, + json=form_data, + ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: r.raise_for_status() data = await r.json() @@ -724,7 +735,9 @@ async def async_embedding_function(query, prefix=None, user=None): return await asyncio.to_thread( ( lambda query, prefix=None: embedding_function.encode( - query, **({"prompt": prefix} if prefix else {}) + query, + batch_size=int(embedding_batch_size), + **({"prompt": prefix} if prefix else {}), ).tolist() ), query, diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index 69d894afde..b7ea5244b4 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -69,7 +69,11 @@ def delete_collection(self, collection_name: str): return self.client.delete_collection(name=collection_name) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, + collection_name: str, + vectors: list[list[float | int]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. try: @@ -78,6 +82,7 @@ def search( result = collection.query( query_embeddings=vectors, n_results=limit, + where=filter, ) # chromadb has cosine distance, 2 (worst) -> 0 (best). 
Re-odering to 0 -> 1 diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index 6de0d859f8..e209453f5c 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -153,7 +153,11 @@ def delete_collection(self, collection_name: str): # Status: works def search( - self, collection_name: str, vectors: list[list[float]], limit: int + self, + collection_name: str, + vectors: list[list[float]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: query = { "size": limit, diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 23e4bbd03e..35cf6b3829 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -179,7 +179,11 @@ def delete_collection(self, collection_name: str): ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, + collection_name: str, + vectors: list[list[float | int]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. 
collection_name = collection_name.replace("-", "_") diff --git a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py index 203a36141e..c58189b2a3 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py @@ -157,7 +157,11 @@ def upsert(self, collection_name: str, items: List[VectorItem]): collection.insert(entities) def search( - self, collection_name: str, vectors: List[List[float]], limit: int + self, + collection_name: str, + vectors: List[List[float]], + filter: Optional[Dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: if not vectors: return None diff --git a/backend/open_webui/retrieval/vector/dbs/opengauss.py b/backend/open_webui/retrieval/vector/dbs/opengauss.py new file mode 100644 index 0000000000..7d4f9ea092 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/opengauss.py @@ -0,0 +1,427 @@ +from typing import Optional, List, Dict, Any +import logging +import re +import json +from sqlalchemy import ( + func, + literal, + cast, + column, + create_engine, + Column, + Integer, + MetaData, + LargeBinary, + select, + text, + Text, + Table, + values, +) +from sqlalchemy.sql import true +from sqlalchemy.pool import NullPool, QueuePool + +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker +from sqlalchemy.dialects.postgresql import JSONB, array +from pgvector.sqlalchemy import Vector +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.exc import NoSuchTableError + +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.dialects import registry + + +class OpenGaussDialect(PGDialect_psycopg2): + name = "opengauss" + + def _get_server_version_info(self, connection): + try: + version = connection.exec_driver_sql("SELECT version()").scalar() + if not version: + return (9, 0, 0) + + match = re.search( + 
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE + ) + if match: + return (int(match.group(1)), int(match.group(2)), int(match.group(3))) + + return super()._get_server_version_info(connection) + except Exception: + return (9, 0, 0) + + +# Register dialect +registry.register("opengauss", __name__, "OpenGaussDialect") + +from open_webui.retrieval.vector.utils import process_metadata +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) +from open_webui.config import ( + OPENGAUSS_DB_URL, + OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH, + OPENGAUSS_POOL_SIZE, + OPENGAUSS_POOL_MAX_OVERFLOW, + OPENGAUSS_POOL_TIMEOUT, + OPENGAUSS_POOL_RECYCLE, +) + +from open_webui.env import SRC_LOG_LEVELS + +VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH +Base = declarative_base() + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +class DocumentChunk(Base): + __tablename__ = "document_chunk" + + id = Column(Text, primary_key=True) + vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) + collection_name = Column(Text, nullable=False) + text = Column(Text, nullable=True) + vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + + +class OpenGaussClient(VectorDBBase): + def __init__(self) -> None: + if not OPENGAUSS_DB_URL: + from open_webui.internal.db import ScopedSession + + self.session = ScopedSession + else: + engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()} + + if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0: + engine_kwargs.update( + { + "pool_size": OPENGAUSS_POOL_SIZE, + "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW, + "pool_timeout": OPENGAUSS_POOL_TIMEOUT, + "pool_recycle": OPENGAUSS_POOL_RECYCLE, + "poolclass": QueuePool, + } + ) + else: + engine_kwargs["poolclass"] = NullPool + + engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs) + + SessionLocal = sessionmaker( + autocommit=False, autoflush=False, 
bind=engine, expire_on_commit=False + ) + self.session = scoped_session(SessionLocal) + + try: + connection = self.session.connection() + Base.metadata.create_all(bind=connection) + + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector " + "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);" + ) + ) + self.session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " + "ON document_chunk (collection_name);" + ) + ) + self.session.commit() + log.info("OpenGauss vector database initialization completed.") + except Exception as e: + self.session.rollback() + log.exception(f"OpenGauss Initialization failed.: {e}") + raise + + def check_vector_length(self) -> None: + metadata = MetaData() + try: + document_chunk_table = Table( + "document_chunk", metadata, autoload_with=self.session.bind + ) + except NoSuchTableError: + return + + if "vector" in document_chunk_table.columns: + vector_column = document_chunk_table.columns["vector"] + vector_type = vector_column.type + if isinstance(vector_type, Vector): + db_vector_length = vector_type.dim + if db_vector_length != VECTOR_LENGTH: + raise Exception( + f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database." + ) + else: + raise Exception("The 'vector' column type is not Vector.") + else: + raise Exception( + "The 'vector' column does not exist in the 'document_chunk' table." 
+ ) + + def adjust_vector_length(self, vector: List[float]) -> List[float]: + current_length = len(vector) + if current_length < VECTOR_LENGTH: + vector += [0.0] * (VECTOR_LENGTH - current_length) + elif current_length > VECTOR_LENGTH: + vector = vector[:VECTOR_LENGTH] + return vector + + def insert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + new_items = [] + for item in items: + vector = self.adjust_vector_length(item["vector"]) + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=process_metadata(item["metadata"]), + ) + new_items.append(new_chunk) + self.session.bulk_save_objects(new_items) + self.session.commit() + log.info( + f"Inserting {len(new_items)} items into collection '{collection_name}'." + ) + except Exception as e: + self.session.rollback() + log.exception(f"Failed to insert data: {e}") + raise + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + existing = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.id == item["id"]) + .first() + ) + if existing: + existing.vector = vector + existing.text = item["text"] + existing.vmetadata = process_metadata(item["metadata"]) + existing.collection_name = collection_name + else: + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=process_metadata(item["metadata"]), + ) + self.session.add(new_chunk) + self.session.commit() + log.info( + f"Inserting/updating {len(items)} items in collection '{collection_name}'." 
+ ) + except Exception as e: + self.session.rollback() + log.exception(f"Failed to insert or update data.: {e}") + raise + + def search( + self, + collection_name: str, + vectors: List[List[float]], + filter: Optional[Dict[str, Any]] = None, + limit: int = 10, + ) -> Optional[SearchResult]: + try: + if not vectors: + return None + + vectors = [self.adjust_vector_length(vector) for vector in vectors] + num_queries = len(vectors) + + def vector_expr(vector): + return cast(array(vector), Vector(VECTOR_LENGTH)) + + qid_col = column("qid", Integer) + q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) + query_vectors = ( + values(qid_col, q_vector_col) + .data( + [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)] + ) + .alias("query_vectors") + ) + + result_fields = [ + DocumentChunk.id, + DocumentChunk.text, + DocumentChunk.vmetadata, + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label( + "distance" + ), + ] + + subq = ( + select(*result_fields) + .where(DocumentChunk.collection_name == collection_name) + .order_by( + DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) + ) + ) + if limit is not None: + subq = subq.limit(limit) + subq = subq.lateral("result") + + stmt = ( + select( + query_vectors.c.qid, + subq.c.id, + subq.c.text, + subq.c.vmetadata, + subq.c.distance, + ) + .select_from(query_vectors) + .join(subq, true()) + .order_by(query_vectors.c.qid, subq.c.distance) + ) + + result_proxy = self.session.execute(stmt) + results = result_proxy.all() + + ids = [[] for _ in range(num_queries)] + distances = [[] for _ in range(num_queries)] + documents = [[] for _ in range(num_queries)] + metadatas = [[] for _ in range(num_queries)] + + for row in results: + qid = int(row.qid) + ids[qid].append(row.id) + distances[qid].append((2.0 - row.distance) / 2.0) + documents[qid].append(row.text) + metadatas[qid].append(row.vmetadata) + + self.session.rollback() + return SearchResult( + ids=ids, distances=distances, 
documents=documents, metadatas=metadatas + ) + except Exception as e: + self.session.rollback() + log.exception(f"Vector search failed: {e}") + return None + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + + for key, value in filter.items(): + query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + self.session.rollback() + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + except Exception as e: + self.session.rollback() + log.exception(f"Conditional query failed: {e}") + return None + + def get( + self, collection_name: str, limit: Optional[int] = None + ) -> Optional[GetResult]: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if limit is not None: + query = query.limit(limit) + + results = query.all() + + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + self.session.rollback() + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + except Exception as e: + self.session.rollback() + log.exception(f"Failed to retrieve data: {e}") + return None + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, + ) -> None: + try: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if ids: + query = query.filter(DocumentChunk.id.in_(ids)) + 
if filter: + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) + deleted = query.delete(synchronize_session=False) + self.session.commit() + log.info(f"Deleted {deleted} items from collection '{collection_name}'") + except Exception as e: + self.session.rollback() + log.exception(f"Failed to delete data: {e}") + raise + + def reset(self) -> None: + try: + deleted = self.session.query(DocumentChunk).delete() + self.session.commit() + log.info(f"Reset completed. Deleted {deleted} items") + except Exception as e: + self.session.rollback() + log.exception(f"Reset failed: {e}") + raise + + def close(self) -> None: + pass + + def has_collection(self, collection_name: str) -> bool: + try: + exists = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.collection_name == collection_name) + .first() + is not None + ) + self.session.rollback() + return exists + except Exception as e: + self.session.rollback() + log.exception(f"Failed to check collection existence: {e}") + return False + + def delete_collection(self, collection_name: str) -> None: + self.delete(collection_name) + log.info(f"Collection '{collection_name}' has been deleted") diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 2e946710e2..dc9c35805e 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -113,7 +113,11 @@ def delete_collection(self, collection_name: str): self.client.indices.delete(index=self._get_index_name(collection_name)) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, + collection_name: str, + vectors: list[list[float | int]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: try: if not self.has_collection(collection_name): diff --git a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py 
b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py index 3f5c3463f0..9f16f82bc9 100644 --- a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py +++ b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py @@ -521,7 +521,11 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: raise def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, + collection_name: str, + vectors: List[List[Union[float, int]]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: """ Search for similar vectors in the database. diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 2f4677995a..15430db114 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -90,9 +90,9 @@ def __init__(self) -> None: # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: - from open_webui.internal.db import Session + from open_webui.internal.db import ScopedSession - self.session = Session + self.session = ScopedSession else: if isinstance(PGVECTOR_POOL_SIZE, int): if PGVECTOR_POOL_SIZE > 0: @@ -427,7 +427,8 @@ def search( self, collection_name: str, vectors: List[List[float]], - limit: Optional[int] = None, + filter: Optional[Dict[str, Any]] = None, + limit: int = 10, ) -> Optional[SearchResult]: try: if not vectors: @@ -475,9 +476,47 @@ def vector_expr(vector): ) # Build the lateral subquery for each query vector + where_clauses = [DocumentChunk.collection_name == collection_name] + + # Apply metadata filter if provided + if filter: + for key, value in filter.items(): + if isinstance(value, dict) and "$in" in value: + # Handle $in operator: {"field": {"$in": [values]}} + in_values = value["$in"] + if PGVECTOR_PGCRYPTO: + where_clauses.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, + PGVECTOR_PGCRYPTO_KEY, + JSONB, + 
)[key].astext.in_([str(v) for v in in_values]) + ) + else: + where_clauses.append( + DocumentChunk.vmetadata[key].astext.in_( + [str(v) for v in in_values] + ) + ) + else: + # Handle simple equality: {"field": "value"} + if PGVECTOR_PGCRYPTO: + where_clauses.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, + PGVECTOR_PGCRYPTO_KEY, + JSONB, + )[key].astext + == str(value) + ) + else: + where_clauses.append( + DocumentChunk.vmetadata[key].astext == str(value) + ) + subq = ( select(*result_fields) - .where(DocumentChunk.collection_name == collection_name) + .where(*where_clauses) .order_by( (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) ) diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index 94d09dabf5..fc3c98f8cf 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -391,7 +391,11 @@ async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> N ) def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, + collection_name: str, + vectors: List[List[Union[float, int]]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: """Search for similar vectors in a collection.""" if not vectors or not vectors[0]: diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index ce7095bea2..d42984e1d6 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -145,7 +145,11 @@ def delete_collection(self, collection_name: str): ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, + collection_name: str, + vectors: list[list[float | int]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on 
the vectors and return 'limit' number of results. if limit is None: diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index fdc8f9d897..f87f85a23b 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -254,7 +254,11 @@ def delete( ) def search( - self, collection_name: str, vectors: List[List[float | int]], limit: int + self, + collection_name: str, + vectors: List[List[float | int]], + filter: Optional[Dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: """ Search for the nearest neighbor items based on the vectors with tenant isolation. diff --git a/backend/open_webui/retrieval/vector/dbs/s3vector.py b/backend/open_webui/retrieval/vector/dbs/s3vector.py index 95fc5d3f24..96e487f111 100644 --- a/backend/open_webui/retrieval/vector/dbs/s3vector.py +++ b/backend/open_webui/retrieval/vector/dbs/s3vector.py @@ -295,7 +295,11 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: raise def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, + collection_name: str, + vectors: List[List[Union[float, int]]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: """ Search for similar vectors in a collection using multiple query vectors. 
diff --git a/backend/open_webui/retrieval/vector/dbs/weaviate.py b/backend/open_webui/retrieval/vector/dbs/weaviate.py index 6bb8a1ecb4..d204e8293a 100644 --- a/backend/open_webui/retrieval/vector/dbs/weaviate.py +++ b/backend/open_webui/retrieval/vector/dbs/weaviate.py @@ -159,7 +159,11 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: ) def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, + collection_name: str, + vectors: List[List[Union[float, int]]], + filter: Optional[dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: sane_collection_name = self._sanitize_collection_name(collection_name) if not self.client.collections.exists(sane_collection_name): diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py index b843e0926d..68595fb595 100644 --- a/backend/open_webui/retrieval/vector/factory.py +++ b/backend/open_webui/retrieval/vector/factory.py @@ -53,6 +53,10 @@ def get_vector(vector_type: str) -> VectorDBBase: from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient return PgvectorClient() + case VectorType.OPENGAUSS: + from open_webui.retrieval.vector.dbs.opengauss import OpenGaussClient + + return OpenGaussClient() case VectorType.ELASTICSEARCH: from open_webui.retrieval.vector.dbs.elasticsearch import ( ElasticsearchClient, diff --git a/backend/open_webui/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py index 53f752f579..a76fec9956 100644 --- a/backend/open_webui/retrieval/vector/main.py +++ b/backend/open_webui/retrieval/vector/main.py @@ -53,7 +53,11 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @abstractmethod def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, + collection_name: str, + vectors: List[List[Union[float, int]]], + filter: Optional[Dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: 
"""Search for similar vectors in a collection.""" pass diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py index 292cad1e78..de20133fce 100644 --- a/backend/open_webui/retrieval/vector/type.py +++ b/backend/open_webui/retrieval/vector/type.py @@ -12,3 +12,4 @@ class VectorType(StrEnum): ORACLE23AI = "oracle23ai" S3VECTOR = "s3vector" WEAVIATE = "weaviate" + OPENGAUSS = "opengauss" diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py index e047602b36..49c8a88e81 100644 --- a/backend/open_webui/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -1,4 +1,5 @@ import logging +import time from typing import Optional import requests @@ -25,6 +26,14 @@ def search_brave( params = {"q": query, "count": count} response = requests.get(url, headers=headers, params=params) + + # Handle 429 rate limiting - Brave free tier allows 1 request/second + # If rate limited, wait 1 second and retry once before failing + if response.status_code == 429: + log.info("Brave Search API rate limited (429), retrying after 1 second...") + time.sleep(1) + response = requests.get(url, headers=headers, params=params) + response.raise_for_status() json_response = response.json() diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py index 0303b4e303..7528418cdb 100644 --- a/backend/open_webui/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -13,12 +13,14 @@ def search_duckduckgo( count: int, filter_list: Optional[list[str]] = None, concurrent_requests: Optional[int] = None, + backend: Optional[str] = "auto", ) -> list[SearchResult]: """ Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. 
Args: query (str): The query to search for count (int): The number of results to return + backend (str): The search backend to use (auto, duckduckgo, google, brave, etc.) Returns: list[SearchResult]: A list of search results @@ -32,7 +34,7 @@ def search_duckduckgo( # Use the ddgs.text() method to perform the search try: search_results = ddgs.text( - query, safesearch="moderate", max_results=count, backend="lite" + query, safesearch="moderate", max_results=count, backend=backend ) except RatelimitException as e: log.error(f"RatelimitException: {e}") diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py index bcc5794027..d1168bb36f 100644 --- a/backend/open_webui/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -7,17 +7,21 @@ log = logging.getLogger(__name__) -def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]: +def search_jina( + api_key: str, query: str, count: int, base_url: str = "" +) -> list[SearchResult]: """ Search using Jina's Search API and return the results as a list of SearchResult objects. 
Args: + api_key (str): The Jina API key query (str): The query to search for count (int): The number of results to return + base_url (str): Optional custom base URL for the Jina API Returns: list[SearchResult]: A list of search results """ - jina_search_endpoint = "https://s.jina.ai/" + jina_search_endpoint = base_url if base_url else "https://s.jina.ai/" headers = { "Accept": "application/json", diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index eb01976e95..6c1ea4b1bf 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -36,6 +36,7 @@ WEB_LOADER_TIMEOUT, FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY, + FIRECRAWL_TIMEOUT, TAVILY_API_KEY, TAVILY_EXTRACT_DEPTH, EXTERNAL_WEB_LOADER_URL, @@ -189,6 +190,7 @@ def __init__( continue_on_failure: bool = True, api_key: Optional[str] = None, api_url: Optional[str] = None, + timeout: Optional[int] = None, mode: Literal["crawl", "scrape", "map"] = "scrape", proxy: Optional[Dict[str, str]] = None, params: Optional[Dict] = None, @@ -231,6 +233,7 @@ def __init__( self.continue_on_failure = continue_on_failure self.api_key = api_key self.api_url = api_url + self.timeout = timeout self.mode = mode self.params = params or {} @@ -253,7 +256,7 @@ def lazy_load(self) -> Iterator[Document]: ignore_invalid_urls=True, remove_base64_images=True, max_age=300000, # 5 minutes https://docs.firecrawl.dev/features/fast-scraping#common-maxage-values - wait_timeout=len(self.web_paths) * 3, + wait_timeout=self.timeout if self.timeout else len(self.web_paths) * 3, **self.params, ) @@ -294,7 +297,7 @@ async def alazy_load(self): ignore_invalid_urls=True, remove_base64_images=True, max_age=300000, # 5 minutes https://docs.firecrawl.dev/features/fast-scraping#common-maxage-values - wait_timeout=len(self.web_paths) * 3, + wait_timeout=self.timeout if self.timeout else len(self.web_paths) * 3, **self.params, ) @@ -697,6 +700,11 @@ def get_web_loader( 
WebLoaderClass = SafeFireCrawlLoader web_loader_args["api_key"] = FIRECRAWL_API_KEY.value web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value + if FIRECRAWL_TIMEOUT.value: + try: + web_loader_args["timeout"] = int(FIRECRAWL_TIMEOUT.value) + except ValueError: + pass if WEB_LOADER_ENGINE.value == "tavily": WebLoaderClass = SafeTavilyLoader diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 120300d49e..2a64a3598d 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -35,12 +35,16 @@ from open_webui.utils.misc import strict_match_mime_type from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission from open_webui.utils.headers import include_user_info_headers from open_webui.config import ( WHISPER_MODEL_AUTO_UPDATE, + WHISPER_COMPUTE_TYPE, WHISPER_MODEL_DIR, + WHISPER_VAD_FILTER, CACHE_DIR, WHISPER_LANGUAGE, + WHISPER_MULTILINGUAL, ELEVENLABS_API_BASE_URL, ) @@ -129,7 +133,7 @@ def set_faster_whisper_model(model: str, auto_update: bool = False): faster_whisper_kwargs = { "model_size_or_path": model, "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", - "compute_type": "int8", + "compute_type": WHISPER_COMPUTE_TYPE, "download_root": WHISPER_MODEL_DIR, "local_files_only": not auto_update, } @@ -328,6 +332,20 @@ def load_speech_pipeline(request): @router.post("/speech") async def speech(request: Request, user=Depends(get_verified_user)): + if request.app.state.config.TTS_ENGINE == "": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + if user.role != "admin" and not has_permission( + user.id, "chat.tts", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + body = await request.body() name = hashlib.sha256( body @@ -585,8 
+603,9 @@ def transcription_handler(request, file_path, metadata, user=None): segments, info = model.transcribe( file_path, beam_size=5, - vad_filter=request.app.state.config.WHISPER_VAD_FILTER, + vad_filter=WHISPER_VAD_FILTER, language=languages[0], + multilingual=WHISPER_MULTILINGUAL, ) log.info( "Detected language '%s' with probability %f" @@ -1150,6 +1169,19 @@ def transcription( language: Optional[str] = Form(None), user=Depends(get_verified_user), ): + if request.app.state.config.STT_ENGINE == "": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + if user.role != "admin" and not has_permission( + user.id, "chat.stt", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) log.info(f"file.content_type: {file.content_type}") stt_supported_content_types = getattr( request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 83e390fa6c..29ed41a74d 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -72,6 +72,8 @@ send_verify_email, verify_email_by_code, ) +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from open_webui.utils.webhook import post_webhook from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.groups import apply_default_group_assignment @@ -118,7 +120,10 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus): @router.get("/", response_model=SessionUserInfoResponse) async def get_session_user( - request: Request, response: Response, user: UserModel = Depends(get_current_user) + request: Request, + response: Response, + user: UserModel = Depends(get_current_user), + db: Session = Depends(get_session), ): auth_header = request.headers.get("Authorization") auth_token = 
get_http_authorization_cred(auth_header) @@ -151,7 +156,7 @@ async def get_session_user( ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) credit = Credits.init_credit_by_user_id(user.id) @@ -183,12 +188,15 @@ async def get_session_user( @router.post("/update/profile", response_model=UserProfileImageResponse) async def update_profile( - form_data: UpdateProfileForm, session_user=Depends(get_verified_user) + form_data: UpdateProfileForm, + session_user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if session_user: user = Users.update_user_by_id( session_user.id, form_data.model_dump(), + db=db, ) if user: return user @@ -198,6 +206,32 @@ async def update_profile( raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) +############################ +# Update Timezone +############################ + + +class UpdateTimezoneForm(BaseModel): + timezone: str + + +@router.post("/update/timezone") +async def update_timezone( + form_data: UpdateTimezoneForm, + session_user=Depends(get_current_user), + db: Session = Depends(get_session), +): + if session_user: + Users.update_user_by_id( + session_user.id, + {"timezone": form_data.timezone}, + db=db, + ) + return {"status": True} + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + ############################ # Update Password ############################ @@ -205,13 +239,17 @@ async def update_profile( @router.post("/update/password", response_model=bool) async def update_password( - form_data: UpdatePasswordForm, session_user=Depends(get_current_user) + form_data: UpdatePasswordForm, + session_user=Depends(get_current_user), + db: Session = Depends(get_session), ): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: user = Auths.authenticate_user( - session_user.email, lambda pw: 
verify_password(form_data.password, pw) + session_user.email, + lambda pw: verify_password(form_data.password, pw), + db=db, ) if user: @@ -220,7 +258,7 @@ async def update_password( except Exception as e: raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.new_password) - return Auths.update_user_password_by_id(user.id, hashed) + return Auths.update_user_password_by_id(user.id, hashed, db=db) else: raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD) else: @@ -231,7 +269,12 @@ async def update_password( # LDAP Authentication ############################ @router.post("/ldap", response_model=SessionUserResponse) -async def ldap_auth(request: Request, response: Response, form_data: LdapForm): +async def ldap_auth( + request: Request, + response: Response, + form_data: LdapForm, + db: Session = Depends(get_session), +): # Security checks FIRST - before loading any config if not request.app.state.config.ENABLE_LDAP: raise HTTPException(400, detail="LDAP authentication is not enabled") @@ -417,12 +460,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if not connection_user.bind(): raise HTTPException(400, "Authentication failed.") - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) if not user: try: role = ( "admin" - if not Users.has_users() + if not Users.has_users(db=db) else request.app.state.config.DEFAULT_USER_ROLE ) @@ -431,6 +474,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): password=str(uuid.uuid4()), name=cn, role=role, + db=db, ) if not user: @@ -441,6 +485,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, + db=db, ) except HTTPException: @@ -451,7 +496,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): 500, detail="Internal error occurred during LDAP user 
creation." ) - user = Auths.authenticate_user_by_email(email) + user = Auths.authenticate_user_by_email(email, db=db) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -481,7 +526,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) credit = Credits.init_credit_by_user_id(user.id) @@ -492,9 +537,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): and user_groups ): if ENABLE_LDAP_GROUP_CREATION: - Groups.create_groups_by_group_names(user.id, user_groups) + Groups.create_groups_by_group_names(user.id, user_groups, db=db) try: - Groups.sync_groups_by_group_names(user.id, user_groups) + Groups.sync_groups_by_group_names(user.id, user_groups, db=db) log.info( f"Successfully synced groups for user {user.id}: {user_groups}" ) @@ -528,7 +573,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): @router.post("/signin", response_model=SessionUserResponse) -async def signin(request: Request, response: Response, form_data: SigninForm): +async def signin( + request: Request, + response: Response, + form_data: SigninForm, + db: Session = Depends(get_session), +): if not ENABLE_PASSWORD_AUTH: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -549,14 +599,15 @@ async def signin(request: Request, response: Response, form_data: SigninForm): except Exception as e: pass - if not Users.get_user_by_email(email.lower()): + if not Users.get_user_by_email(email.lower(), db=db): await signup( request, response, SignupForm(email=email, password=str(uuid.uuid4()), name=name), + db=db, ) - user = Auths.authenticate_user_by_email(email) + user = Auths.authenticate_user_by_email(email, db=db) if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": group_names = request.headers.get( 
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" @@ -564,28 +615,33 @@ async def signin(request: Request, response: Response, form_data: SigninForm): group_names = [name.strip() for name in group_names if name.strip()] if group_names: - Groups.sync_groups_by_group_names(user.id, group_names) + Groups.sync_groups_by_group_names(user.id, group_names, db=db) elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" - if Users.get_user_by_email(admin_email.lower()): + if Users.get_user_by_email(admin_email.lower(), db=db): user = Auths.authenticate_user( - admin_email.lower(), lambda pw: verify_password(admin_password, pw) + admin_email.lower(), + lambda pw: verify_password(admin_password, pw), + db=db, ) else: - if Users.has_users(): + if Users.has_users(db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( request, response, SignupForm(email=admin_email, password=admin_password, name="User"), + db=db, ) user = Auths.authenticate_user( - admin_email.lower(), lambda pw: verify_password(admin_password, pw) + admin_email.lower(), + lambda pw: verify_password(admin_password, pw), + db=db, ) else: if signin_rate_limiter.is_limited(form_data.email.lower()): @@ -604,7 +660,9 @@ async def signin(request: Request, response: Response, form_data: SigninForm): form_data.password = password_bytes.decode("utf-8", errors="ignore") user = Auths.authenticate_user( - form_data.email.lower(), lambda pw: verify_password(form_data.password, pw) + form_data.email.lower(), + lambda pw: verify_password(form_data.password, pw), + db=db, ) if user: @@ -636,7 +694,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) credit = Credits.init_credit_by_user_id(user.id) @@ -663,8 +721,13 @@ async def signin(request: Request, response: Response, form_data: SigninForm): 
@router.post("/signup", response_model=SessionUserResponse) -async def signup(request: Request, response: Response, form_data: SignupForm): - has_users = Users.has_users() +async def signup( + request: Request, + response: Response, + form_data: SignupForm, + db: Session = Depends(get_session), +): + has_users = Users.has_users(db=db) if WEBUI_AUTH: if ( @@ -700,7 +763,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if Users.get_user_by_email(form_data.email.lower(), db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -725,6 +788,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): form_data.name, form_data.profile_image_url, role, + db=db, ) if user: @@ -767,7 +831,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) if not has_users: @@ -777,6 +841,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, + db=db, ) credit = Credits.init_credit_by_user_id(user.id) @@ -815,7 +880,9 @@ async def signup_verify(request: Request, code: str): @router.get("/signout") -async def signout(request: Request, response: Response): +async def signout( + request: Request, response: Response, db: Session = Depends(get_session) +): # get auth token from headers or cookies token = None @@ -837,7 +904,7 @@ async def signout(request: Request, response: Response): if oauth_session_id: response.delete_cookie("oauth_session_id") - session = OAuthSessions.get_session_by_id(oauth_session_id) + session = OAuthSessions.get_session_by_id(oauth_session_id, db=db) 
oauth_server_metadata_url = ( request.app.state.oauth_manager.get_server_metadata_url(session.provider) if session @@ -900,14 +967,17 @@ async def signout(request: Request, response: Response): @router.post("/add", response_model=SigninResponse) async def add_user( - request: Request, form_data: AddUserForm, user=Depends(get_admin_user) + request: Request, + form_data: AddUserForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if Users.get_user_by_email(form_data.email.lower(), db=db): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -923,12 +993,14 @@ async def add_user( form_data.name, form_data.profile_image_url, form_data.role, + db=db, ) if user: apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, + db=db, ) token = create_token(data={"id": user.id}) @@ -956,7 +1028,9 @@ async def add_user( @router.get("/admin/details") -async def get_admin_details(request: Request, user=Depends(get_current_user)): +async def get_admin_details( + request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) +): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None @@ -964,11 +1038,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") if admin_email: - admin = Users.get_user_by_email(admin_email) + admin = Users.get_user_by_email(admin_email, db=db) if admin: admin_name = admin.name else: - admin = Users.get_first_user() + admin = Users.get_first_user(db=db) if admin: admin_email = admin.email admin_name = admin.name @@ -990,6 +1064,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): 
async def get_admin_config(request: Request, user=Depends(get_admin_user)): return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, + "ADMIN_EMAIL": request.app.state.config.ADMIN_EMAIL, "WEBUI_URL": request.app.state.config.WEBUI_URL, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, "ENABLE_SIGNUP_VERIFY": request.app.state.config.ENABLE_SIGNUP_VERIFY, @@ -1008,9 +1083,12 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, "ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS, + "FOLDER_MAX_FILE_COUNT": request.app.state.config.FOLDER_MAX_FILE_COUNT, "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS, + "ENABLE_MEMORIES": request.app.state.config.ENABLE_MEMORIES, "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES, "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS, + "ENABLE_USER_STATUS": request.app.state.config.ENABLE_USER_STATUS, "PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE, "PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT, "RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK, @@ -1019,6 +1097,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): class AdminConfig(BaseModel): SHOW_ADMIN_DETAILS: bool + ADMIN_EMAIL: Optional[str] = None WEBUI_URL: str ENABLE_SIGNUP: bool ENABLE_SIGNUP_VERIFY: bool = Field(default=False) @@ -1037,9 +1116,12 @@ class AdminConfig(BaseModel): ENABLE_COMMUNITY_SHARING: bool ENABLE_MESSAGE_RATING: bool ENABLE_FOLDERS: bool + FOLDER_MAX_FILE_COUNT: Optional[int | str] = None ENABLE_CHANNELS: bool + ENABLE_MEMORIES: bool ENABLE_NOTES: bool ENABLE_USER_WEBHOOKS: bool + ENABLE_USER_STATUS: bool PENDING_USER_OVERLAY_TITLE: Optional[str] = None PENDING_USER_OVERLAY_CONTENT: Optional[str] = 
None RESPONSE_WATERMARK: Optional[str] = None @@ -1063,6 +1145,7 @@ async def update_admin_config( raise HTTPException(status_code=400, detail="Redis is not configured.") request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS + request.app.state.config.ADMIN_EMAIL = form_data.ADMIN_EMAIL request.app.state.config.WEBUI_URL = form_data.WEBUI_URL request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP request.app.state.config.ENABLE_SIGNUP_VERIFY = form_data.ENABLE_SIGNUP_VERIFY @@ -1084,7 +1167,11 @@ async def update_admin_config( ) request.app.state.config.ENABLE_FOLDERS = form_data.ENABLE_FOLDERS + request.app.state.config.FOLDER_MAX_FILE_COUNT = ( + int(form_data.FOLDER_MAX_FILE_COUNT) if form_data.FOLDER_MAX_FILE_COUNT else "" + ) request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS + request.app.state.config.ENABLE_MEMORIES = form_data.ENABLE_MEMORIES request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]: @@ -1104,6 +1191,7 @@ async def update_admin_config( request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS + request.app.state.config.ENABLE_USER_STATUS = form_data.ENABLE_USER_STATUS request.app.state.config.PENDING_USER_OVERLAY_TITLE = ( form_data.PENDING_USER_OVERLAY_TITLE @@ -1116,6 +1204,7 @@ async def update_admin_config( return { "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS, + "ADMIN_EMAIL": request.app.state.config.ADMIN_EMAIL, "WEBUI_URL": request.app.state.config.WEBUI_URL, "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP, "ENABLE_SIGNUP_VERIFY": request.app.state.config.ENABLE_SIGNUP_VERIFY, @@ -1134,9 +1223,12 @@ async def update_admin_config( "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING, "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING, 
"ENABLE_FOLDERS": request.app.state.config.ENABLE_FOLDERS, + "FOLDER_MAX_FILE_COUNT": request.app.state.config.FOLDER_MAX_FILE_COUNT, "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS, + "ENABLE_MEMORIES": request.app.state.config.ENABLE_MEMORIES, "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES, "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS, + "ENABLE_USER_STATUS": request.app.state.config.ENABLE_USER_STATUS, "PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE, "PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT, "RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK, @@ -1253,7 +1345,9 @@ async def update_ldap_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def generate_api_key(request: Request, user=Depends(get_current_user)): +async def generate_api_key( + request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) +): if not request.app.state.config.ENABLE_API_KEYS or not has_permission( user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS ): @@ -1263,7 +1357,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): ) api_key = create_api_key() - success = Users.update_user_api_key_by_id(user.id, api_key) + success = Users.update_user_api_key_by_id(user.id, api_key, db=db) if success: return { @@ -1275,14 +1369,18 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): # delete api key @router.delete("/api_key", response_model=bool) -async def delete_api_key(user=Depends(get_current_user)): - return Users.delete_user_api_key_by_id(user.id) +async def delete_api_key( + user=Depends(get_current_user), db: Session = Depends(get_session) +): + return Users.delete_user_api_key_by_id(user.id, db=db) # get api key @router.get("/api_key", response_model=ApiKey) -async def get_api_key(user=Depends(get_current_user)): - api_key = 
Users.get_user_api_key_by_id(user.id) +async def get_api_key( + user=Depends(get_current_user), db: Session = Depends(get_session) +): + api_key = Users.get_user_api_key_by_id(user.id, db=db) if api_key: return { "api_key": api_key, diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 4cd116424f..777f5e74ea 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -1,9 +1,12 @@ import json import logging +import base64 +import io from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks +from fastapi.responses import Response, StreamingResponse, FileResponse from pydantic import BaseModel from pydantic import field_validator @@ -29,6 +32,8 @@ ChannelForm, ChannelResponse, CreateChannelForm, + ChannelWebhookModel, + ChannelWebhookForm, ) from open_webui.models.messages import ( Messages, @@ -43,6 +48,7 @@ from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES +from open_webui.env import STATIC_DIR from open_webui.utils.models import ( @@ -61,6 +67,8 @@ ) from open_webui.utils.webhook import post_webhook from open_webui.utils.channels import extract_mentions, replace_mentions +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -98,26 +106,29 @@ class ChannelListItemResponse(ChannelModel): async def get_channels( request: Request, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channels = Channels.get_channels_by_user_id(user.id) 
+ channels = Channels.get_channels_by_user_id(user.id, db=db) channel_list = [] for channel in channels: - last_message = Messages.get_last_message_by_channel_id(channel.id) + last_message = Messages.get_last_message_by_channel_id(channel.id, db=db) last_message_at = last_message.created_at if last_message else None - channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id) + channel_member = Channels.get_member_by_channel_and_user_id( + channel.id, user.id, db=db + ) unread_count = ( Messages.get_unread_message_count( - channel.id, user.id, channel_member.last_read_at + channel.id, user.id, channel_member.last_read_at, db=db ) if channel_member else 0 @@ -128,13 +139,16 @@ async def get_channels( if channel.type == "dm": user_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(channel.id) + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] users = [ UserIdNameStatusResponse( - **{**user.model_dump(), "is_active": Users.is_user_active(user.id)} + **{ + **user.model_dump(), + "is_active": Users.is_user_active(user.id, db=db), + } ) - for user in Users.get_users_by_user_ids(user_ids) + for user in Users.get_users_by_user_ids(user_ids, db=db) ] channel_list.append( @@ -154,11 +168,12 @@ async def get_channels( async def get_all_channels( request: Request, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role == "admin": - return Channels.get_channels() - return Channels.get_channels_by_user_id(user.id) + return Channels.get_channels(db=db) + return Channels.get_channels_by_user_id(user.id, db=db) ############################ @@ -171,10 +186,11 @@ async def get_dm_channel_by_user_id( request: Request, user_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + 
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -182,11 +198,15 @@ async def get_dm_channel_by_user_id( ) try: - existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id]) + existing_channel = Channels.get_dm_channel_by_user_ids( + [user.id, user_id], db=db + ) if existing_channel: participant_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(existing_channel.id) + for member in Channels.get_members_by_channel_id( + existing_channel.id, db=db + ) ] await emit_to_users( @@ -198,7 +218,9 @@ async def get_dm_channel_by_user_id( f"channel:{existing_channel.id}", participant_ids ) - Channels.update_member_active_status(existing_channel.id, user.id, True) + Channels.update_member_active_status( + existing_channel.id, user.id, True, db=db + ) return ChannelModel(**existing_channel.model_dump()) channel = Channels.insert_new_channel( @@ -208,12 +230,13 @@ async def get_dm_channel_by_user_id( user_ids=[user_id], ), user.id, + db=db, ) if channel: participant_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(channel.id) + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] await emit_to_users( @@ -243,10 +266,11 @@ async def create_new_channel( request: Request, form_data: CreateChannelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -263,13 +287,13 @@ async def create_new_channel( try: if form_data.type == "dm": existing_channel = Channels.get_dm_channel_by_user_ids( - [user.id, *form_data.user_ids] + [user.id, *form_data.user_ids], db=db ) if existing_channel: 
participant_ids = [ member.user_id for member in Channels.get_members_by_channel_id( - existing_channel.id + existing_channel.id, db=db ) ] await emit_to_users( @@ -281,15 +305,17 @@ async def create_new_channel( f"channel:{existing_channel.id}", participant_ids ) - Channels.update_member_active_status(existing_channel.id, user.id, True) + Channels.update_member_active_status( + existing_channel.id, user.id, True, db=db + ) return ChannelModel(**existing_channel.model_dump()) - channel = Channels.insert_new_channel(form_data, user.id) + channel = Channels.insert_new_channel(form_data, user.id, db=db) if channel: participant_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(channel.id) + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] await emit_to_users( @@ -327,9 +353,10 @@ async def get_channel_by_id( request: Request, id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -339,23 +366,29 @@ async def get_channel_by_id( users = None if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) user_ids = [ - member.user_id for member in Channels.get_members_by_channel_id(channel.id) + member.user_id + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] users = [ UserIdNameStatusResponse( - **{**user.model_dump(), "is_active": Users.is_user_active(user.id)} + **{ + **user.model_dump(), + "is_active": Users.is_user_active(user.id, db=db), + } ) - for user in Users.get_users_by_user_ids(user_ids) + for user in Users.get_users_by_user_ids(user_ids, 
db=db) ] - channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id) + channel_member = Channels.get_member_by_channel_and_user_id( + channel.id, user.id, db=db + ) unread_count = Messages.get_unread_message_count( channel.id, user.id, channel_member.last_read_at if channel_member else None ) @@ -365,7 +398,9 @@ async def get_channel_by_id( **channel.model_dump(), "user_ids": user_ids, "users": users, - "is_manager": Channels.is_user_channel_manager(channel.id, user.id), + "is_manager": Channels.is_user_channel_manager( + channel.id, user.id, db=db + ), "write_access": True, "user_count": len(user_ids), "last_read_at": channel_member.last_read_at if channel_member else None, @@ -374,19 +409,25 @@ async def get_channel_by_id( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) write_access = has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, + type="write", + access_control=channel.access_control, + strict=False, + db=db, ) user_count = len(get_users_with_access("read", channel.access_control)) - channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id) + channel_member = Channels.get_member_by_channel_and_user_id( + channel.id, user.id, db=db + ) unread_count = Messages.get_unread_message_count( channel.id, user.id, channel_member.last_read_at if channel_member else None ) @@ -396,7 +437,9 @@ async def get_channel_by_id( **channel.model_dump(), "user_ids": user_ids, "users": users, - "is_manager": Channels.is_user_channel_manager(channel.id, user.id), + "is_manager": Channels.is_user_channel_manager( + channel.id, user.id, db=db + ), "write_access": write_access or user.role == "admin", "user_count": user_count, "last_read_at": 
channel_member.last_read_at if channel_member else None, @@ -422,10 +465,11 @@ async def get_channel_members_by_id( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -437,22 +481,23 @@ async def get_channel_members_by_id( skip = (page - 1) * limit if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) if channel.type == "dm": user_ids = [ - member.user_id for member in Channels.get_members_by_channel_id(channel.id) + member.user_id + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] - users = Users.get_users_by_user_ids(user_ids) + users = Users.get_users_by_user_ids(user_ids, db=db) total = len(users) return { "users": [ UserModelResponse( - **user.model_dump(), is_active=Users.is_user_active(user.id) + **user.model_dump(), is_active=Users.is_user_active(user.id, db=db) ) for user in users ], @@ -479,7 +524,7 @@ async def get_channel_members_by_id( filter["user_ids"] = permitted_ids.get("user_ids") filter["group_ids"] = permitted_ids.get("group_ids") - result = Users.get_users(filter=filter, skip=skip, limit=limit) + result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db) users = result["users"] total = result["total"] @@ -487,7 +532,7 @@ async def get_channel_members_by_id( return { "users": [ UserModelResponse( - **user.model_dump(), is_active=Users.is_user_active(user.id) + **user.model_dump(), is_active=Users.is_user_active(user.id, db=db) ) for user in users ], @@ -510,20 +555,23 @@ async def 
update_is_active_member_by_id_and_user_id( id: str, form_data: UpdateActiveMemberForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - Channels.update_member_active_status(channel.id, user.id, form_data.is_active) + Channels.update_member_active_status( + channel.id, user.id, form_data.is_active, db=db + ) return True @@ -543,17 +591,18 @@ async def add_members_by_id( id: str, form_data: UpdateMembersForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -566,7 +615,7 @@ async def add_members_by_id( try: memberships = Channels.add_members_to_channel( - channel.id, user.id, form_data.user_ids, form_data.group_ids + channel.id, user.id, form_data.user_ids, form_data.group_ids, db=db ) return memberships @@ -592,17 +641,18 @@ async def remove_members_by_id( id: str, form_data: RemoveMembersForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not 
has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -614,7 +664,9 @@ async def remove_members_by_id( ) try: - deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids) + deleted = Channels.remove_members_from_channel( + channel.id, form_data.user_ids, db=db + ) return deleted except Exception as e: @@ -635,17 +687,18 @@ async def update_channel_by_id( id: str, form_data: ChannelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -657,7 +710,7 @@ async def update_channel_by_id( ) try: - channel = Channels.update_channel_by_id(id, form_data) + channel = Channels.update_channel_by_id(id, form_data, db=db) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) @@ -676,17 +729,18 @@ async def delete_channel_by_id( request: Request, id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", 
request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -698,7 +752,7 @@ async def delete_channel_by_id( ) try: - Channels.delete_channel_by_id(id) + Channels.delete_channel_by_id(id, db=db) return True except Exception as e: log.exception(e) @@ -732,57 +786,63 @@ async def get_channel_messages( skip: int = 0, limit: int = 50, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) channel_member = Channels.join_channel( - id, user.id + id, user.id, db=db ) # Ensure user is a member of the channel - message_list = Messages.get_messages_by_channel_id(id, skip, limit) - users = {} + message_list = Messages.get_messages_by_channel_id(id, skip, limit, db=db) + + if not message_list: + return [] + + # Batch fetch all users in a single query (fixes N+1 problem) + user_ids = list(set(m.user_id for m in message_list)) + users = 
{u.id: u for u in Users.get_users_by_user_ids(user_ids, db=db)} messages = [] for message in message_list: - if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) - users[message.user_id] = user - - thread_replies = Messages.get_thread_replies_by_message_id(message.id) + thread_replies = Messages.get_thread_replies_by_message_id(message.id, db=db) latest_thread_reply_at = ( thread_replies[0].created_at if thread_replies else None ) - user = None - user_model = users.get(message.user_id) or None - if user_model: - user = UserNameResponse(**user_model.model_dump()) + # Use message.user if present (for webhooks), otherwise look up by user_id + user_info = message.user + if user_info is None and message.user_id in users: + user_info = UserNameResponse(**users[message.user_id].model_dump()) + messages.append( MessageUserResponse( **{ **message.model_dump(), "reply_count": len(thread_replies), "latest_reply_at": latest_thread_reply_at, - "reactions": Messages.get_reactions_by_message_id(message.id), - "user": user, + "reactions": Messages.get_reactions_by_message_id( + message.id, db=db + ), + "user": user_info, } ) ) @@ -803,22 +863,23 @@ async def get_pinned_channel_messages( id: str, page: int = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): 
raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -828,25 +889,38 @@ async def get_pinned_channel_messages( skip = (page - 1) * PAGE_ITEM_COUNT_PINNED limit = PAGE_ITEM_COUNT_PINNED - message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit) - users = {} + message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit, db=db) + + if not message_list: + return [] + + # Batch fetch all users in a single query (fixes N+1 problem) + user_ids = list(set(m.user_id for m in message_list)) + users = {u.id: u for u in Users.get_users_by_user_ids(user_ids, db=db)} messages = [] for message in message_list: - if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) - users[message.user_id] = user - - user = None - user_model = users.get(message.user_id) or None - if user_model: - user = UserNameResponse(**user_model.model_dump()) + # Check for webhook identity in meta + webhook_info = message.meta.get("webhook") if message.meta else None + if webhook_info: + user_info = UserNameResponse( + id=webhook_info.get("id"), + name=webhook_info.get("name"), + role="webhook", + ) + elif message.user_id in users: + user_info = UserNameResponse(**users[message.user_id].model_dump()) + else: + user_info = None + messages.append( MessageWithReactionsResponse( **{ **message.model_dump(), - "reactions": Messages.get_reactions_by_message_id(message.id), - "user": user, + "reactions": Messages.get_reactions_by_message_id( + message.id, db=db + ), + "user": user_info, } ) ) @@ -859,12 +933,14 @@ async def get_pinned_channel_messages( ############################ -async def send_notification(name, webui_url, channel, message, active_user_ids): +async def send_notification( + name, webui_url, channel, message, active_user_ids, db=None +): users = get_users_with_access("read", channel.access_control) for user in users: if (user.id not in active_user_ids) and Channels.is_user_channel_member( - 
channel.id, user.id + channel.id, user.id, db=db ): if user.settings: webhook_url = user.settings.ui.get("notifications", {}).get( @@ -886,7 +962,7 @@ async def send_notification(name, webui_url, channel, message, active_user_ids): return True -async def model_response_handler(request, channel, message, user): +async def model_response_handler(request, channel, message, user, db=None): MODELS = { model["id"]: model for model in get_filtered_models(await get_all_models(request, user=user), user) @@ -924,6 +1000,7 @@ async def model_response_handler(request, channel, message, user): thread_messages = Messages.get_messages_by_parent_id( channel.id, message.parent_id if message.parent_id else message.id, + db=db, )[::-1] response_message, channel = await new_message_handler( @@ -943,6 +1020,7 @@ async def model_response_handler(request, channel, message, user): } ), user, + db, ) thread_history = [] @@ -952,7 +1030,9 @@ async def model_response_handler(request, channel, message, user): for thread_message in thread_messages: message_user = None if thread_message.user_id not in message_users: - message_user = Users.get_user_by_id(thread_message.user_id) + message_user = Users.get_user_by_id( + thread_message.user_id, db=db + ) message_users[thread_message.user_id] = message_user else: message_user = message_users[thread_message.user_id] @@ -1031,6 +1111,7 @@ async def model_response_handler(request, channel, message, user): if res: if res.get("choices", []) and len(res["choices"]) > 0: await update_message_by_id( + request, channel.id, response_message.id, MessageForm( @@ -1042,9 +1123,11 @@ async def model_response_handler(request, channel, message, user): } ), user, + db, ) elif res.get("error", None): await update_message_by_id( + request, channel.id, response_message.id, MessageForm( @@ -1056,6 +1139,7 @@ async def model_response_handler(request, channel, message, user): } ), user, + db, ) except Exception as e: log.info(e) @@ -1065,39 +1149,43 @@ async def 
model_response_handler(request, channel, message, user): async def new_message_handler( - request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user) + request: Request, id: str, form_data: MessageForm, user, db ): - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, + type="write", + access_control=channel.access_control, + strict=False, + db=db, ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) try: - message = Messages.insert_new_message(form_data, channel.id, user.id) + message = Messages.insert_new_message(form_data, channel.id, user.id, db=db) if message: if channel.type in ["group", "dm"]: - members = Channels.get_members_by_channel_id(channel.id) + members = Channels.get_members_by_channel_id(channel.id, db=db) for member in members: if not member.is_active: Channels.update_member_active_status( - channel.id, member.user_id, True + channel.id, member.user_id, True, db=db ) - message = Messages.get_message_by_id(message.id) + message = Messages.get_message_by_id(message.id, db=db) event_data = { "channel_id": channel.id, "message_id": message.id, @@ -1117,7 +1205,7 @@ async def new_message_handler( if message.parent_id: # If this message is a reply, emit to the parent message as well - parent_message = Messages.get_message_by_id(message.parent_id) + parent_message = Messages.get_message_by_id(message.parent_id, db=db) if parent_message: await 
sio.emit( @@ -1151,16 +1239,17 @@ async def post_new_message( form_data: MessageForm, background_tasks: BackgroundTasks, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) try: - message, channel = await new_message_handler(request, id, form_data, user) + message, channel = await new_message_handler(request, id, form_data, user, db) try: if files := message.data.get("files", []): for file in files: Channels.set_file_message_id_in_channel_by_id( - channel.id, file.get("id", ""), message.id + channel.id, file.get("id", ""), message.id, db=db ) except Exception as e: log.debug(e) @@ -1168,13 +1257,14 @@ async def post_new_message( active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") async def background_handler(): - await model_response_handler(request, channel, message, user) + await model_response_handler(request, channel, message, user, db) await send_notification( request.app.state.WEBUI_NAME, request.app.state.config.WEBUI_URL, channel, message, active_user_ids, + db=db, ) background_tasks.add_task(background_handler) @@ -1201,28 +1291,29 @@ async def get_channel_message( id: str, message_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, 
detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1237,7 +1328,7 @@ async def get_channel_message( **{ **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **Users.get_user_by_id(message.user_id, db=db).model_dump() ), } ) @@ -1254,28 +1345,29 @@ async def get_channel_message_data( id: str, message_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1307,28 +1399,29 @@ async def pin_channel_message( message_id: str, form_data: PinMessageForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( 
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1340,13 +1433,13 @@ async def pin_channel_message( ) try: - Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id) - message = Messages.get_message_by_id(message_id) + Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id, db=db) + message = Messages.get_message_by_id(message_id, db=db) return MessageUserResponse( **{ **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **Users.get_user_by_id(message.user_id, db=db).model_dump() ), } ) @@ -1372,48 +1465,56 @@ async def get_channel_thread_messages( skip: int = 0, limit: int = 50, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, 
detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit) - users = {} + message_list = Messages.get_messages_by_parent_id( + id, message_id, skip, limit, db=db + ) + + if not message_list: + return [] + + # Batch fetch all users in a single query (fixes N+1 problem) + user_ids = list(set(m.user_id for m in message_list)) + users = {u.id: u for u in Users.get_users_by_user_ids(user_ids, db=db)} messages = [] for message in message_list: - if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) - users[message.user_id] = user - - user = None - user_model = users.get(message.user_id) or None - if user_model: - user = UserNameResponse(**user_model.model_dump()) + # Use message.user if present (for webhooks), otherwise look up by user_id + user_info = message.user + if user_info is None and message.user_id in users: + user_info = UserNameResponse(**users[message.user_id].model_dump()) + messages.append( MessageUserResponse( **{ **message.model_dump(), "reply_count": 0, "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id(message.id), - "user": user, + "reactions": Messages.get_reactions_by_message_id( + message.id, db=db + ), + "user": user_info, } ) ) @@ -1435,15 +1536,16 @@ async def update_message_by_id( message_id: str, form_data: MessageForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - message = 
Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1455,7 +1557,7 @@ async def update_message_by_id( ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1464,7 +1566,7 @@ async def update_message_by_id( user.role != "admin" and message.user_id != user.id and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ) ): raise HTTPException( @@ -1472,8 +1574,8 @@ async def update_message_by_id( ) try: - message = Messages.update_message_by_id(message_id, form_data) - message = Messages.get_message_by_id(message_id) + message = Messages.update_message_by_id(message_id, form_data, db=db) + message = Messages.get_message_by_id(message_id, db=db) if message: await sio.emit( @@ -1515,28 +1617,33 @@ async def add_reaction_to_message( message_id: str, form_data: ReactionForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, + type="write", + 
access_control=channel.access_control, + strict=False, + db=db, ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1548,8 +1655,8 @@ async def add_reaction_to_message( ) try: - Messages.add_reaction_to_message(message_id, user.id, form_data.name) - message = Messages.get_message_by_id(message_id) + Messages.add_reaction_to_message(message_id, user.id, form_data.name, db=db) + message = Messages.get_message_by_id(message_id, db=db) await sio.emit( "events:channel", @@ -1589,28 +1696,33 @@ async def remove_reaction_by_id_and_user_id_and_name( message_id: str, form_data: ReactionForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, + type="write", + access_control=channel.access_control, + strict=False, + db=db, ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1623,10 +1735,10 @@ async def 
remove_reaction_by_id_and_user_id_and_name( try: Messages.remove_reaction_by_id_and_user_id_and_name( - message_id, user.id, form_data.name + message_id, user.id, form_data.name, db=db ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) await sio.emit( "events:channel", @@ -1665,15 +1777,16 @@ async def delete_message_by_id( id: str, message_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1685,7 +1798,7 @@ async def delete_message_by_id( ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1698,6 +1811,7 @@ async def delete_message_by_id( type="write", access_control=channel.access_control, strict=False, + db=db, ) ): raise HTTPException( @@ -1705,7 +1819,7 @@ async def delete_message_by_id( ) try: - Messages.delete_message_by_id(message_id) + Messages.delete_message_by_id(message_id, db=db) await sio.emit( "events:channel", { @@ -1726,7 +1840,7 @@ async def delete_message_by_id( if message.parent_id: # If this message is a reply, emit to the parent message as well - parent_message = Messages.get_message_by_id(message.parent_id) + parent_message = Messages.get_message_by_id(message.parent_id, db=db) if parent_message: await sio.emit( @@ -1750,3 +1864,263 @@ async def delete_message_by_id( raise HTTPException( 
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() ) + + +############################ +# Webhooks +############################ + + +@router.get("/webhooks/{webhook_id}/profile/image") +async def get_webhook_profile_image( + webhook_id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + """Get webhook profile image by webhook ID.""" + webhook = Channels.get_webhook_by_id(webhook_id, db=db) + if not webhook: + # Return default favicon if webhook not found + return FileResponse(f"{STATIC_DIR}/favicon.png") + + if webhook.profile_image_url: + # Check if it's url or base64 + if webhook.profile_image_url.startswith("http"): + return Response( + status_code=status.HTTP_302_FOUND, + headers={"Location": webhook.profile_image_url}, + ) + elif webhook.profile_image_url.startswith("data:image"): + try: + header, base64_data = webhook.profile_image_url.split(",", 1) + image_data = base64.b64decode(base64_data) + image_buffer = io.BytesIO(image_data) + media_type = header.split(";")[0].lstrip("data:") + + return StreamingResponse( + image_buffer, + media_type=media_type, + headers={"Content-Disposition": "inline"}, + ) + except Exception as e: + pass + + # Return default favicon if no profile image + return FileResponse(f"{STATIC_DIR}/favicon.png") + + +@router.get("/{id}/webhooks", response_model=list[ChannelWebhookModel]) +async def get_channel_webhooks( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + check_channels_access(request) + channel = Channels.get_channel_by_id(id, db=db) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + # Only channel managers can view webhooks + if ( + not Channels.is_user_channel_manager(channel.id, user.id, db=db) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED + ) + + return 
Channels.get_webhooks_by_channel_id(id, db=db) + + +@router.post("/{id}/webhooks/create", response_model=ChannelWebhookModel) +async def create_channel_webhook( + request: Request, + id: str, + form_data: ChannelWebhookForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + check_channels_access(request) + channel = Channels.get_channel_by_id(id, db=db) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + # Only channel managers can create webhooks + if ( + not Channels.is_user_channel_manager(channel.id, user.id, db=db) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED + ) + + webhook = Channels.insert_webhook(id, user.id, form_data, db=db) + if not webhook: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + return webhook + + +@router.post("/{id}/webhooks/{webhook_id}/update", response_model=ChannelWebhookModel) +async def update_channel_webhook( + request: Request, + id: str, + webhook_id: str, + form_data: ChannelWebhookForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + check_channels_access(request) + channel = Channels.get_channel_by_id(id, db=db) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + # Only channel managers can update webhooks + if ( + not Channels.is_user_channel_manager(channel.id, user.id, db=db) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED + ) + + webhook = Channels.get_webhook_by_id(webhook_id, db=db) + if not webhook or webhook.channel_id != id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + updated = Channels.update_webhook_by_id(webhook_id, form_data, db=db) + if 
not updated: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + return updated + + +@router.delete("/{id}/webhooks/{webhook_id}/delete", response_model=bool) +async def delete_channel_webhook( + request: Request, + id: str, + webhook_id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + check_channels_access(request) + channel = Channels.get_channel_by_id(id, db=db) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + # Only channel managers can delete webhooks + if ( + not Channels.is_user_channel_manager(channel.id, user.id, db=db) + and user.role != "admin" + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.UNAUTHORIZED + ) + + webhook = Channels.get_webhook_by_id(webhook_id, db=db) + if not webhook or webhook.channel_id != id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return Channels.delete_webhook_by_id(webhook_id, db=db) + + +############################ +# Public Webhook Endpoint +############################ + + +class WebhookMessageForm(BaseModel): + content: str + + +@router.post("/webhooks/{webhook_id}/{token}") +async def post_webhook_message( + request: Request, + webhook_id: str, + token: str, + form_data: WebhookMessageForm, + db: Session = Depends(get_session), +): + """Public endpoint to post messages via webhook. 
No authentication required.""" + check_channels_access(request) + + # Validate webhook + webhook = Channels.get_webhook_by_id_and_token(webhook_id, token, db=db) + if not webhook: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid webhook URL", + ) + + channel = Channels.get_channel_by_id(webhook.channel_id, db=db) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + # Create message with webhook identity stored in meta + message = Messages.insert_new_message( + MessageForm(content=form_data.content, meta={"webhook": {"id": webhook.id}}), + webhook.channel_id, + webhook.user_id, # Required for DB but webhook info in meta takes precedence + db=db, + ) + + if not message: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to create message", + ) + + # Update last_used_at + Channels.update_webhook_last_used_at(webhook_id, db=db) + + # Get full message and emit event + message = Messages.get_message_by_id(message.id, db=db) + + event_data = { + "channel_id": channel.id, + "message_id": message.id, + "data": { + "type": "message", + "data": { + **message.model_dump(), + "user": { + "id": webhook.id, + "name": webhook.name, + "role": "webhook", + }, + }, + }, + "user": { + "id": webhook.id, + "name": webhook.name, + "role": "webhook", + }, + "channel": channel.model_dump(), + } + + await sio.emit( + "events:channel", + event_data, + to=f"channel:{channel.id}", + ) + + return {"success": True, "message_id": message.id} diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 8dde946a4d..9a43234aa6 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -1,6 +1,9 @@ import json import logging from typing import Optional +from sqlalchemy.orm import Session +import asyncio +from fastapi.responses import StreamingResponse from open_webui.utils.misc import 
get_message_list @@ -13,9 +16,15 @@ ChatResponse, Chats, ChatTitleIdResponse, + ChatStatsExport, + AggregateChatStats, + ChatBody, + ChatHistoryStats, + MessageStats, ) from open_webui.models.tags import TagModel, Tags from open_webui.models.folders import Folders +from open_webui.internal.db import get_session from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES @@ -42,6 +51,7 @@ def get_session_user_chat_list( page: Optional[int] = None, include_pinned: Optional[bool] = False, include_folders: Optional[bool] = False, + db: Session = Depends(get_session), ): try: if page is not None: @@ -54,10 +64,14 @@ def get_session_user_chat_list( include_pinned=include_pinned, skip=skip, limit=limit, + db=db, ) else: return Chats.get_chat_title_id_list_by_user_id( - user.id, include_folders=include_folders, include_pinned=include_pinned + user.id, + include_folders=include_folders, + include_pinned=include_pinned, + db=db, ) except Exception as e: log.exception(e) @@ -77,12 +91,13 @@ def get_session_user_chat_usage_stats( items_per_page: Optional[int] = 50, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): try: limit = items_per_page skip = (page - 1) * limit - result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit) + result = Chats.get_chats_by_user_id(user.id, skip=skip, limit=limit, db=db) chats = result.items total = result.total @@ -193,12 +208,334 @@ def get_session_user_chat_usage_stats( ############################ -# DeleteAllChats +# GetChatStatsExport +############################ + + +CHAT_EXPORT_PAGE_ITEM_COUNT = 10 + + +class ChatStatsExportList(BaseModel): + type: str = "chats" + items: list[ChatStatsExport] + total: int + page: int + + +def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: + try: + + def get_message_content_length(message): + content = message.get("content", "") + if isinstance(content, str): + return 
len(content) + elif isinstance(content, list): + return sum( + len(item.get("text", "")) + for item in content + if item.get("type") == "text" + ) + return 0 + + messages_map = chat.chat.get("history", {}).get("messages", {}) + message_id = chat.chat.get("history", {}).get("currentId") + + history_models = {} + history_message_count = len(messages_map) + history_user_messages = [] + history_assistant_messages = [] + + export_messages = {} + for key, message in messages_map.items(): + try: + content_length = get_message_content_length(message) + + # Extract rating safely + rating = message.get("annotation", {}).get("rating") + tags = message.get("annotation", {}).get("tags") + + message_stat = MessageStats( + id=message.get("id"), + role=message.get("role"), + model=message.get("model"), + timestamp=message.get("timestamp"), + content_length=content_length, + token_count=None, # Populate if available, e.g. message.get("info", {}).get("token_count") + rating=rating, + tags=tags, + ) + + export_messages[key] = message_stat + + # --- Aggregation Logic (copied/adapted from usage stats) --- + role = message.get("role", "") + if role == "user": + history_user_messages.append(message) + elif role == "assistant": + history_assistant_messages.append(message) + model = message.get("model") + if model: + if model not in history_models: + history_models[model] = 0 + history_models[model] += 1 + except Exception as e: + log.debug(f"Error processing message {key}: {e}") + continue + + # Calculate Averages + average_user_message_content_length = ( + sum(get_message_content_length(m) for m in history_user_messages) + / len(history_user_messages) + if history_user_messages + else 0 + ) + + average_assistant_message_content_length = ( + sum(get_message_content_length(m) for m in history_assistant_messages) + / len(history_assistant_messages) + if history_assistant_messages + else 0 + ) + + # Response Times + response_times = [] + for message in history_assistant_messages: + 
user_message_id = message.get("parentId", None) + if user_message_id and user_message_id in messages_map: + user_message = messages_map[user_message_id] + # Ensure timestamps exist + t1 = message.get("timestamp") + t0 = user_message.get("timestamp") + if t1 and t0: + response_times.append(t1 - t0) + + average_response_time = ( + sum(response_times) / len(response_times) if response_times else 0 + ) + + # Current Message List Logic (Main path) + message_list = get_message_list(messages_map, message_id) + message_count = len(message_list) + models = {} + for message in reversed(message_list): + if message.get("role") == "assistant": + model = message.get("model") + if model: + if model not in models: + models[model] = 0 + models[model] += 1 + + # Construct Aggregate Stats + stats = AggregateChatStats( + average_response_time=average_response_time, + average_user_message_content_length=average_user_message_content_length, + average_assistant_message_content_length=average_assistant_message_content_length, + models=models, + message_count=message_count, + history_models=history_models, + history_message_count=history_message_count, + history_user_message_count=len(history_user_messages), + history_assistant_message_count=len(history_assistant_messages), + ) + + # Construct Chat Body + chat_body = ChatBody( + history=ChatHistoryStats(messages=export_messages, currentId=message_id) + ) + + return ChatStatsExport( + id=chat.id, + user_id=chat.user_id, + created_at=chat.created_at, + updated_at=chat.updated_at, + tags=chat.meta.get("tags", []), + stats=stats, + chat=chat_body, + ) + except Exception as e: + log.exception(f"Error exporting stats for chat {chat.id}: {e}") + return None + + +def calculate_chat_stats( + user_id, skip=0, limit=10, filter=None, db: Optional[Session] = None +): + if filter is None: + filter = {} + + result = Chats.get_chats_by_user_id( + user_id, + skip=skip, + limit=limit, + filter=filter, + db=db, + ) + + chat_stats_export_list = [] + for chat 
in result.items: + chat_stat = _process_chat_for_export(chat) + if chat_stat: + chat_stats_export_list.append(chat_stat) + + return chat_stats_export_list, result.total + + +def generate_chat_stats_jsonl_generator(user_id, filter): + """ + Synchronous generator for streaming chat stats export. + + NOTE: We intentionally do NOT pass a shared db session here. Instead, we let + each batch create its own short-lived session via get_db_context(None). + This is critical for SQLite in low-resource environments because: + 1. SQLite uses file-level locking + 2. Holding a session open for the entire streaming duration blocks other requests + 3. Short-lived sessions release locks between batches, allowing other operations + """ + skip = 0 + limit = CHAT_EXPORT_PAGE_ITEM_COUNT + + while True: + # Each batch gets its own session that closes after the query + result = Chats.get_chats_by_user_id( + user_id, + filter=filter, + skip=skip, + limit=limit, + db=None, # Let get_db_context create a fresh session per batch + ) + if not result.items: + break + + for chat in result.items: + try: + chat_stat = _process_chat_for_export(chat) + if chat_stat: + yield chat_stat.model_dump_json() + "\n" + except Exception as e: + log.exception(f"Error processing chat {chat.id}: {e}") + + skip += limit + + +@router.get("/stats/export", response_model=ChatStatsExportList) +async def export_chat_stats( + request: Request, + updated_at: Optional[int] = None, + page: Optional[int] = 1, + stream: bool = False, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + # Check if the user has permission to share/export chats + if (user.role != "admin") and ( + not request.app.state.config.ENABLE_COMMUNITY_SHARING + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + try: + # Fetch chats with date filtering + filter = {"order_by": "updated_at", "direction": "asc"} + + if updated_at: + filter["updated_at"] = 
updated_at + + if stream: + return StreamingResponse( + generate_chat_stats_jsonl_generator(user.id, filter), + media_type="application/x-ndjson", + headers={ + "Content-Disposition": f"attachment; filename=chat-stats-export-{user.id}.jsonl" + }, + ) + else: + limit = CHAT_EXPORT_PAGE_ITEM_COUNT + skip = (page - 1) * limit + + chat_stats_export_list, total = await asyncio.to_thread( + calculate_chat_stats, user.id, skip, limit, filter, db=db + ) + + return ChatStatsExportList( + items=chat_stats_export_list, total=total, page=page + ) + + except Exception as e: + log.debug(f"Error exporting chat stats: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# GetSingleChatStatsExport ############################ +@router.get("/stats/export/{chat_id}", response_model=Optional[ChatStatsExport]) +async def export_single_chat_stats( + request: Request, + chat_id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + """ + Export stats for exactly one chat by ID. + Returns ChatStatsExport for the specified chat. 
+ """ + # Check if the user has permission to share/export chats + if (user.role != "admin") and ( + not request.app.state.config.ENABLE_COMMUNITY_SHARING + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + try: + chat = Chats.get_chat_by_id(chat_id, db=db) + + if not chat: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + # Verify the chat belongs to the user (unless admin) + if chat.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + # Process the chat for export + chat_stats = await asyncio.to_thread(_process_chat_for_export, chat) + + if not chat_stats: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to process chat stats", + ) + + return chat_stats + + except HTTPException: + raise + except Exception as e: + log.debug(f"Error exporting single chat stats: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + @router.delete("/", response_model=bool) -async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)): +async def delete_all_user_chats( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role == "user" and not has_permission( user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS @@ -208,7 +545,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chats_by_user_id(user.id) + result = Chats.delete_chats_by_user_id(user.id, db=db) return result @@ -225,6 +562,7 @@ async def get_user_chat_list_by_user_id( order_by: Optional[str] = None, direction: Optional[str] = None, user=Depends(get_admin_user), + db: Session = Depends(get_session), ): if not 
ENABLE_ADMIN_CHAT_ACCESS: raise HTTPException( @@ -247,7 +585,7 @@ async def get_user_chat_list_by_user_id( filter["direction"] = direction return Chats.get_chat_list_by_user_id( - user_id, include_archived=True, filter=filter, skip=skip, limit=limit + user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db ) @@ -257,9 +595,13 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): +async def create_new_chat( + form_data: ChatForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): try: - chat = Chats.insert_new_chat(user.id, form_data) + chat = Chats.insert_new_chat(user.id, form_data, db=db) return ChatResponse(**chat.model_dump()) except Exception as e: log.exception(e) @@ -274,9 +616,13 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)): @router.post("/import", response_model=list[ChatResponse]) -async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_user)): +async def import_chats( + form_data: ChatsImportForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): try: - chats = Chats.import_chats(user.id, form_data.chats) + chats = Chats.import_chats(user.id, form_data.chats, db=db) return chats except Exception as e: log.exception(e) @@ -292,7 +638,10 @@ async def import_chats(form_data: ChatsImportForm, user=Depends(get_verified_use @router.get("/search", response_model=list[ChatTitleIdResponse]) def search_user_chats( - text: str, page: Optional[int] = None, user=Depends(get_verified_user) + text: str, + page: Optional[int] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if page is None: page = 1 @@ -303,7 +652,7 @@ def search_user_chats( chat_list = [ ChatTitleIdResponse(**chat.model_dump()) for chat in Chats.get_chats_by_user_id_and_search_text( - user.id, text, 
skip=skip, limit=limit + user.id, text, skip=skip, limit=limit, db=db ) ] @@ -312,9 +661,9 @@ def search_user_chats( if page == 1 and len(words) == 1 and words[0].startswith("tag:"): tag_id = words[0].replace("tag:", "") if len(chat_list) == 0: - if Tags.get_tag_by_name_and_user_id(tag_id, user.id): + if Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db): log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db) return chat_list @@ -325,23 +674,30 @@ def search_user_chats( @router.get("/folder/{folder_id}", response_model=list[ChatResponse]) -async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)): +async def get_chats_by_folder_id( + folder_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): folder_ids = [folder_id] children_folders = Folders.get_children_folders_by_id_and_user_id( - folder_id, user.id + folder_id, user.id, db=db ) if children_folders: folder_ids.extend([folder.id for folder in children_folders]) return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id) + for chat in Chats.get_chats_by_folder_ids_and_user_id( + folder_ids, user.id, db=db + ) ] @router.get("/folder/{folder_id}/list") async def get_chat_list_by_folder_id( - folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user) + folder_id: str, + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): try: limit = 10 @@ -350,7 +706,7 @@ async def get_chat_list_by_folder_id( return [ {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} for chat in Chats.get_chats_by_folder_id_and_user_id( - folder_id, user.id, skip=skip, limit=limit + folder_id, user.id, skip=skip, limit=limit, db=db ) ] @@ -367,10 +723,12 @@ async def get_chat_list_by_folder_id( @router.get("/pinned", 
response_model=list[ChatTitleIdResponse]) -async def get_user_pinned_chats(user=Depends(get_verified_user)): +async def get_user_pinned_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): return [ ChatTitleIdResponse(**chat.model_dump()) - for chat in Chats.get_pinned_chats_by_user_id(user.id) + for chat in Chats.get_pinned_chats_by_user_id(user.id, db=db) ] @@ -380,11 +738,11 @@ async def get_user_pinned_chats(user=Depends(get_verified_user)): @router.get("/all", response_model=list[ChatResponse]) -async def get_user_chats(user=Depends(get_verified_user)): - return [ - ChatResponse(**chat.model_dump()) - for chat in Chats.get_chats_by_user_id(user.id) - ] +async def get_user_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + result = Chats.get_chats_by_user_id(user.id, db=db) + return [ChatResponse(**chat.model_dump()) for chat in result.items] ############################ @@ -393,10 +751,12 @@ async def get_user_chats(user=Depends(get_verified_user)): @router.get("/all/archived", response_model=list[ChatResponse]) -async def get_user_archived_chats(user=Depends(get_verified_user)): +async def get_user_archived_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): return [ ChatResponse(**chat.model_dump()) - for chat in Chats.get_archived_chats_by_user_id(user.id) + for chat in Chats.get_archived_chats_by_user_id(user.id, db=db) ] @@ -406,9 +766,11 @@ async def get_user_archived_chats(user=Depends(get_verified_user)): @router.get("/all/tags", response_model=list[TagModel]) -async def get_all_user_tags(user=Depends(get_verified_user)): +async def get_all_user_tags( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): try: - tags = Tags.get_tags_by_user_id(user.id) + tags = Tags.get_tags_by_user_id(user.id, db=db) return tags except Exception as e: log.exception(e) @@ -423,13 +785,15 @@ async def get_all_user_tags(user=Depends(get_verified_user)): 
@router.get("/all/db", response_model=list[ChatResponse]) -async def get_all_user_chats_in_db(user=Depends(get_admin_user)): +async def get_all_user_chats_in_db( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): if not ENABLE_ADMIN_EXPORT: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()] + return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats(db=db)] ############################ @@ -444,6 +808,7 @@ async def get_archived_session_user_chat_list( order_by: Optional[str] = None, direction: Optional[str] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if page is None: page = 1 @@ -466,6 +831,7 @@ async def get_archived_session_user_chat_list( filter=filter, skip=skip, limit=limit, + db=db, ) ] @@ -478,8 +844,10 @@ async def get_archived_session_user_chat_list( @router.post("/archive/all", response_model=bool) -async def archive_all_chats(user=Depends(get_verified_user)): - return Chats.archive_all_chats_by_user_id(user.id) +async def archive_all_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Chats.archive_all_chats_by_user_id(user.id, db=db) ############################ @@ -488,8 +856,10 @@ async def archive_all_chats(user=Depends(get_verified_user)): @router.post("/unarchive/all", response_model=bool) -async def unarchive_all_chats(user=Depends(get_verified_user)): - return Chats.unarchive_all_chats_by_user_id(user.id) +async def unarchive_all_chats( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Chats.unarchive_all_chats_by_user_id(user.id, db=db) ############################ @@ -498,16 +868,18 @@ async def unarchive_all_chats(user=Depends(get_verified_user)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, 
user=Depends(get_verified_user)): +async def get_shared_chat_by_id( + share_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "pending": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND ) if user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_CHAT_ACCESS): - chat = Chats.get_chat_by_share_id(share_id) + chat = Chats.get_chat_by_share_id(share_id, db=db) elif user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS: - chat = Chats.get_chat_by_id(share_id) + chat = Chats.get_chat_by_id(share_id, db=db) if chat: return ChatResponse(**chat.model_dump()) @@ -534,13 +906,15 @@ class TagFilterForm(TagForm): @router.post("/tags", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagFilterForm, user=Depends(get_verified_user) + form_data: TagFilterForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): chats = Chats.get_chat_list_by_user_id_and_tag_name( - user.id, form_data.name, form_data.skip, form_data.limit + user.id, form_data.name, form_data.skip, form_data.limit, db=db ) if len(chats) == 0: - Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) return chats @@ -551,8 +925,10 @@ async def get_user_chat_list_by_tag_name( @router.get("/{id}", response_model=Optional[ChatResponse]) -async def get_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def get_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return ChatResponse(**chat.model_dump()) @@ -570,12 +946,15 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}", response_model=Optional[ChatResponse]) async def update_chat_by_id( - id: str, form_data: 
ChatForm, user=Depends(get_verified_user) + id: str, + form_data: ChatForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: updated_chat = {**chat.chat, **form_data.chat} - chat = Chats.update_chat_by_id(id, updated_chat) + chat = Chats.update_chat_by_id(id, updated_chat, db=db) return ChatResponse(**chat.model_dump()) else: raise HTTPException( @@ -593,9 +972,13 @@ class MessageForm(BaseModel): @router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse]) async def update_chat_message_by_id( - id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user) + id: str, + message_id: str, + form_data: MessageForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) if not chat: raise HTTPException( @@ -615,6 +998,7 @@ async def update_chat_message_by_id( { "content": form_data.content, }, + db=db, ) event_emitter = get_event_emitter( @@ -651,9 +1035,13 @@ class EventForm(BaseModel): @router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool]) async def send_chat_message_event_by_id( - id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user) + id: str, + message_id: str, + form_data: EventForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) if not chat: raise HTTPException( @@ -691,14 +1079,24 @@ async def send_chat_message_event_by_id( @router.delete("/{id}", response_model=bool) -async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def delete_chat_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role == "admin": - chat = 
Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) + if not chat: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1: + Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) - result = Chats.delete_chat_by_id(id) + result = Chats.delete_chat_by_id(id, db=db) return result else: @@ -710,12 +1108,17 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) + if not chat: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 1: + Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) - result = Chats.delete_chat_by_id_and_user_id(id, user.id) + result = Chats.delete_chat_by_id_and_user_id(id, user.id, db=db) return result @@ -725,8 +1128,10 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified @router.get("/{id}/pinned", response_model=Optional[bool]) -async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def get_pinned_status_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return chat.pinned else: @@ -741,10 +1146,12 @@ async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/pin", 
response_model=Optional[ChatResponse]) -async def pin_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def pin_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - chat = Chats.toggle_chat_pinned_by_id(id) + chat = Chats.toggle_chat_pinned_by_id(id, db=db) return chat else: raise HTTPException( @@ -763,9 +1170,12 @@ class CloneForm(BaseModel): @router.post("/{id}/clone", response_model=Optional[ChatResponse]) async def clone_chat_by_id( - form_data: CloneForm, id: str, user=Depends(get_verified_user) + form_data: CloneForm, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: updated_chat = { **chat.chat, @@ -786,6 +1196,7 @@ async def clone_chat_by_id( } ) ], + db=db, ) if chats: @@ -808,12 +1219,14 @@ async def clone_chat_by_id( @router.post("/{id}/clone/shared", response_model=Optional[ChatResponse]) -async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): +async def clone_shared_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin": - chat = Chats.get_chat_by_id(id) + chat = Chats.get_chat_by_id(id, db=db) else: - chat = Chats.get_chat_by_share_id(id) + chat = Chats.get_chat_by_share_id(id, db=db) if chat: updated_chat = { @@ -835,6 +1248,7 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): } ) ], + db=db, ) if chats: @@ -857,23 +1271,28 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) 
+async def archive_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - chat = Chats.toggle_chat_archive_by_id(id) + chat = Chats.toggle_chat_archive_by_id(id, db=db) # Delete tags if chat is archived if chat.archived: for tag_id in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id) == 0: + if ( + Chats.count_chats_by_tag_name_and_user_id(tag_id, user.id, db=db) + == 0 + ): log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + Tags.delete_tag_by_name_and_user_id(tag_id, user.id, db=db) else: for tag_id in chat.meta.get("tags", []): - tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id) + tag = Tags.get_tag_by_name_and_user_id(tag_id, user.id, db=db) if tag is None: log.debug(f"inserting tag: {tag_id}") - tag = Tags.insert_new_tag(tag_id, user.id) + tag = Tags.insert_new_tag(tag_id, user.id, db=db) return ChatResponse(**chat.model_dump()) else: @@ -888,7 +1307,12 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/share", response_model=Optional[ChatResponse]) -async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def share_chat_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if (user.role != "admin") and ( not has_permission( user.id, "chat.share", request.app.state.config.USER_PERMISSIONS @@ -899,14 +1323,14 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_ detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: if chat.share_id: - shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) + shared_chat = Chats.update_shared_chat_by_chat_id(chat.id, db=db) return 
ChatResponse(**shared_chat.model_dump()) - shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) + shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id, db=db) if not shared_chat: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -927,14 +1351,16 @@ async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_ @router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def delete_shared_chat_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: if not chat.share_id: return False - result = Chats.delete_shared_chat_by_chat_id(id) - update_result = Chats.update_chat_share_id_by_id(id, None) + result = Chats.delete_shared_chat_by_chat_id(id, db=db) + update_result = Chats.update_chat_share_id_by_id(id, None, db=db) return result and update_result != None else: @@ -955,12 +1381,15 @@ class ChatFolderIdForm(BaseModel): @router.post("/{id}/folder", response_model=Optional[ChatResponse]) async def update_chat_folder_id_by_id( - id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user) + id: str, + form_data: ChatFolderIdForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: chat = Chats.update_chat_folder_id_by_id_and_user_id( - id, user.id, form_data.folder_id + id, user.id, form_data.folder_id, db=db ) return ChatResponse(**chat.model_dump()) else: @@ -975,11 +1404,13 @@ async def update_chat_folder_id_by_id( @router.get("/{id}/tags", response_model=list[TagModel]) -async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def 
get_chat_tags_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -993,9 +1424,12 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/tags", response_model=list[TagModel]) async def add_tag_by_id_and_tag_name( - id: str, form_data: TagForm, user=Depends(get_verified_user) + id: str, + form_data: TagForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: tags = chat.meta.get("tags", []) tag_id = form_data.name.replace(" ", "_").lower() @@ -1008,12 +1442,12 @@ async def add_tag_by_id_and_tag_name( if tag_id not in tags: Chats.add_chat_tag_by_id_and_user_id_and_tag_name( - id, user.id, form_data.name + id, user.id, form_data.name, db=db ) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() @@ -1027,18 +1461,26 @@ async def add_tag_by_id_and_tag_name( @router.delete("/{id}/tags", response_model=list[TagModel]) async def delete_tag_by_id_and_tag_name( - id: str, form_data: TagForm, user=Depends(get_verified_user) + id: str, + form_data: TagForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = 
Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name) + Chats.delete_tag_by_id_and_user_id_and_tag_name( + id, user.id, form_data.name, db=db + ) - if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(form_data.name, user.id) + if ( + Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id, db=db) + == 0 + ): + Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) tags = chat.meta.get("tags", []) - return Tags.get_tags_by_ids_and_user_id(tags, user.id) + return Tags.get_tags_by_ids_and_user_id(tags, user.id, db=db) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -1051,14 +1493,16 @@ async def delete_tag_by_id_and_tag_name( @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def delete_all_tags_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + chat = Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: - Chats.delete_all_tags_by_id_and_user_id(id, user.id) + Chats.delete_all_tags_by_id_and_user_id(id, user.id, db=db) for tag in chat.meta.get("tags", []): - if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + if Chats.count_chats_by_tag_name_and_user_id(tag, user.id, db=db) == 0: + Tags.delete_tag_by_name_and_user_id(tag, user.id, db=db) return True else: diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 5ba0313975..83a01c6dc4 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py 
@@ -6,6 +6,7 @@ from typing import Optional, Literal +from open_webui.env import AIOHTTP_CLIENT_TIMEOUT from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.config import get_config, save_config from open_webui.config import BannerModel @@ -228,7 +229,10 @@ async def verify_tool_servers_config( log.debug( f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}" ) - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), + ) as session: async with session.get( discovery_url ) as oauth_server_metadata_response: diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index cdcefe6ba7..22bb20df9b 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -1,5 +1,7 @@ from typing import Optional +import logging from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi.concurrency import run_in_threadpool from pydantic import BaseModel from open_webui.models.users import Users, UserModel @@ -10,15 +12,268 @@ FeedbackForm, FeedbackUserResponse, FeedbackListResponse, + LeaderboardFeedbackData, + ModelHistoryEntry, + ModelHistoryResponse, Feedbacks, ) from open_webui.constants import ERROR_MESSAGES from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session + +log = logging.getLogger(__name__) + router = APIRouter() +# Leaderboard Elo Rating Computation +# +# How it works: +# 1. Each model starts with a rating of 1000 +# 2. When a user picks a winner between two models, ratings are adjusted: +# - Winner gains points, loser loses points +# - The amount depends on expected outcome (upset = bigger change) +# 3. 
The Elo formula: new_rating = old_rating + K * (actual - expected) +# - K=32 controls how much ratings can change per match +# - expected = probability of winning based on current ratings +# +# Query-based re-ranking (optional): +# When a user searches for a topic (e.g., "coding"), we want to show +# which models perform best FOR THAT TOPIC. We do this by: +# 1. Computing semantic similarity between the query and each feedback's tags +# 2. Using that similarity as a weight in the Elo calculation +# 3. Feedbacks about "coding" contribute more to the final ranking +# 4. Feedbacks about unrelated topics (e.g., "cooking") contribute less +# This gives topic-specific leaderboards without needing separate data. + +import os + +EMBEDDING_MODEL_NAME = os.environ.get( + "AUXILIARY_EMBEDDING_MODEL", "TaylorAI/bge-micro-v2" +) +_embedding_model = None + + +def _get_embedding_model(): + global _embedding_model + if _embedding_model is None: + try: + from sentence_transformers import SentenceTransformer + + _embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME) + except Exception as e: + log.error(f"Embedding model load failed: {e}") + return _embedding_model + + +def _calculate_elo( + feedbacks: list[LeaderboardFeedbackData], similarities: dict = None +) -> dict: + """ + Calculate Elo ratings for models based on user feedback. + + Each feedback represents a comparison where a user rated one model + against its opponents (sibling_model_ids). Rating=1 means the model won, + rating=-1 means it lost. 
+ + The Elo system adjusts ratings based on: + - Current rating difference (upsets cause bigger swings) + - Optional similarity weights (for query-based filtering) + + Returns: {model_id: {"rating": float, "won": int, "lost": int}} + """ + K_FACTOR = 32 # Standard Elo K-factor for rating volatility + model_stats = {} + + def get_or_create_stats(model_id): + if model_id not in model_stats: + model_stats[model_id] = {"rating": 1000.0, "won": 0, "lost": 0} + return model_stats[model_id] + + for feedback in feedbacks: + data = feedback.data or {} + winner_id = data.get("model_id") + rating_value = str(data.get("rating", "")) + if not winner_id or rating_value not in ("1", "-1"): + continue + + won = rating_value == "1" + weight = similarities.get(feedback.id, 1.0) if similarities else 1.0 + + for opponent_id in data.get("sibling_model_ids") or []: + winner = get_or_create_stats(winner_id) + opponent = get_or_create_stats(opponent_id) + expected = 1 / (1 + 10 ** ((opponent["rating"] - winner["rating"]) / 400)) + + winner["rating"] += K_FACTOR * ((1 if won else 0) - expected) * weight + opponent["rating"] += ( + K_FACTOR * ((0 if won else 1) - (1 - expected)) * weight + ) + + if won: + winner["won"] += 1 + opponent["lost"] += 1 + else: + winner["lost"] += 1 + opponent["won"] += 1 + + return model_stats + + +def _get_top_tags(feedbacks: list[LeaderboardFeedbackData], limit: int = 5) -> dict: + """ + Count tag occurrences per model and return the most frequent ones. + + Each feedback can have tags describing the conversation topic. + This aggregates those tags per model to show what topics each model + is commonly used for. 
+ + Returns: {model_id: [{"tag": str, "count": int}, ...]} + """ + from collections import defaultdict + + tag_counts = defaultdict(lambda: defaultdict(int)) + + for feedback in feedbacks: + data = feedback.data or {} + model_id = data.get("model_id") + if model_id: + for tag in data.get("tags", []): + tag_counts[model_id][tag] += 1 + + return { + model_id: [ + {"tag": tag, "count": count} + for tag, count in sorted(tags.items(), key=lambda x: -x[1])[:limit] + ] + for model_id, tags in tag_counts.items() + } + + +def _compute_similarities(feedbacks: list[LeaderboardFeedbackData], query: str) -> dict: + """ + Compute how relevant each feedback is to a search query. + + Uses embeddings to find semantic similarity between the query and + each feedback's tags. Higher similarity means the feedback is more + relevant to what the user searched for. + + This is used to weight Elo calculations - feedbacks matching the + query have more influence on the final rankings. + + Returns: {feedback_id: similarity_score (0-1)} + """ + import numpy as np + + embedding_model = _get_embedding_model() + if not embedding_model: + return {} + + all_tags = list( + { + tag + for feedback in feedbacks + if feedback.data + for tag in feedback.data.get("tags", []) + } + ) + if not all_tags: + return {} + + try: + tag_embeddings = embedding_model.encode(all_tags) + query_embedding = embedding_model.encode([query])[0] + except Exception as e: + log.error(f"Embedding error: {e}") + return {} + + # Vectorized cosine similarity + tag_norms = np.linalg.norm(tag_embeddings, axis=1) + query_norm = np.linalg.norm(query_embedding) + similarities = np.dot(tag_embeddings, query_embedding) / ( + tag_norms * query_norm + 1e-9 + ) + tag_similarity_map = dict(zip(all_tags, similarities.tolist())) + + return { + feedback.id: max( + ( + tag_similarity_map.get(tag, 0) + for tag in (feedback.data or {}).get("tags", []) + ), + default=0, + ) + for feedback in feedbacks + } + + +class LeaderboardEntry(BaseModel): + 
model_id: str + rating: int + won: int + lost: int + count: int + top_tags: list[dict] + + +class LeaderboardResponse(BaseModel): + entries: list[LeaderboardEntry] + + +@router.get("/leaderboard", response_model=LeaderboardResponse) +async def get_leaderboard( + query: Optional[str] = None, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get model leaderboard with Elo ratings. Query filters by tag similarity.""" + feedbacks = Feedbacks.get_feedbacks_for_leaderboard(db=db) + + similarities = None + if query and query.strip(): + similarities = await run_in_threadpool( + _compute_similarities, feedbacks, query.strip() + ) + + elo_stats = _calculate_elo(feedbacks, similarities) + tags_by_model = _get_top_tags(feedbacks) + + entries = sorted( + [ + LeaderboardEntry( + model_id=mid, + rating=round(s["rating"]), + won=s["won"], + lost=s["lost"], + count=s["won"] + s["lost"], + top_tags=tags_by_model.get(mid, []), + ) + for mid, s in elo_stats.items() + ], + key=lambda e: e.rating, + reverse=True, + ) + + return LeaderboardResponse(entries=entries) + + +@router.get("/leaderboard/{model_id}/history", response_model=ModelHistoryResponse) +async def get_model_history( + model_id: str, + days: int = 30, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Get daily win/loss history for a specific model.""" + history = Feedbacks.get_model_evaluation_history( + model_id=model_id, days=days, db=db + ) + return ModelHistoryResponse(model_id=model_id, history=history) + + ############################ # GetConfig ############################ @@ -60,38 +315,49 @@ async def update_config( @router.get("/feedbacks/all", response_model=list[FeedbackResponse]) -async def get_all_feedbacks(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() +async def get_all_feedbacks( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks 
@router.get("/feedbacks/all/ids", response_model=list[FeedbackIdResponse]) -async def get_all_feedback_ids(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() - return feedbacks +async def get_all_feedback_ids( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + return Feedbacks.get_all_feedback_ids(db=db) @router.delete("/feedbacks/all") -async def delete_all_feedbacks(user=Depends(get_admin_user)): - success = Feedbacks.delete_all_feedbacks() +async def delete_all_feedbacks( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + success = Feedbacks.delete_all_feedbacks(db=db) return success @router.get("/feedbacks/all/export", response_model=list[FeedbackModel]) -async def export_all_feedbacks(user=Depends(get_admin_user)): - feedbacks = Feedbacks.get_all_feedbacks() +async def export_all_feedbacks( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_all_feedbacks(db=db) return feedbacks @router.get("/feedbacks/user", response_model=list[FeedbackUserResponse]) -async def get_feedbacks(user=Depends(get_verified_user)): - feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) +async def get_feedbacks( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id, db=db) return feedbacks @router.delete("/feedbacks", response_model=bool) -async def delete_feedbacks(user=Depends(get_verified_user)): - success = Feedbacks.delete_feedbacks_by_user_id(user.id) +async def delete_feedbacks( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + success = Feedbacks.delete_feedbacks_by_user_id(user.id, db=db) return success @@ -104,6 +370,7 @@ async def get_feedbacks( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_admin_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT @@ -116,7 +383,7 @@ async def get_feedbacks( if direction: 
filter["direction"] = direction - result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit) + result = Feedbacks.get_feedback_items(filter=filter, skip=skip, limit=limit, db=db) return result @@ -125,8 +392,11 @@ async def create_feedback( request: Request, form_data: FeedbackForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) + feedback = Feedbacks.insert_new_feedback( + user_id=user.id, form_data=form_data, db=db + ) if not feedback: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -137,11 +407,15 @@ async def create_feedback( @router.get("/feedback/{id}", response_model=FeedbackModel) -async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): +async def get_feedback_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin": - feedback = Feedbacks.get_feedback_by_id(id=id) + feedback = Feedbacks.get_feedback_by_id(id=id, db=db) else: - feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) + feedback = Feedbacks.get_feedback_by_id_and_user_id( + id=id, user_id=user.id, db=db + ) if not feedback: raise HTTPException( @@ -153,13 +427,16 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): @router.post("/feedback/{id}", response_model=FeedbackModel) async def update_feedback_by_id( - id: str, form_data: FeedbackForm, user=Depends(get_verified_user) + id: str, + form_data: FeedbackForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role == "admin": - feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data) + feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data, db=db) else: feedback = Feedbacks.update_feedback_by_id_and_user_id( - id=id, user_id=user.id, form_data=form_data + id=id, user_id=user.id, form_data=form_data, db=db ) if not feedback: @@ 
-171,11 +448,15 @@ async def update_feedback_by_id( @router.delete("/feedback/{id}") -async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): +async def delete_feedback_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin": - success = Feedbacks.delete_feedback_by_id(id=id) + success = Feedbacks.delete_feedback_by_id(id=id, db=db) else: - success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) + success = Feedbacks.delete_feedback_by_id_and_user_id( + id=id, user_id=user.id, db=db + ) if not success: raise HTTPException( diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 7f01537e66..d0e56075eb 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -2,7 +2,6 @@ import os import uuid import json -from fnmatch import fnmatch from pathlib import Path from typing import Optional from urllib.parse import quote @@ -22,6 +21,8 @@ ) from fastapi.responses import FileResponse, StreamingResponse +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session, SessionLocal from open_webui.constants import ERROR_MESSAGES from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT @@ -62,9 +63,12 @@ # TODO: Optimize this function to use the knowledge_file table for faster lookups. 
def has_access_to_file( - file_id: Optional[str], access_type: str, user=Depends(get_verified_user) + file_id: Optional[str], + access_type: str, + user=Depends(get_verified_user), + db: Optional[Session] = None, ) -> bool: - file = Files.get_file_by_id(file_id) + file = Files.get_file_by_id(file_id, db=db) log.debug(f"Checking if user has {access_type} access to file") if not file: raise HTTPException( @@ -73,31 +77,33 @@ def has_access_to_file( ) # Check if the file is associated with any knowledge bases the user has access to - knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id) - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db) + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } for knowledge_base in knowledge_bases: if knowledge_base.user_id == user.id or has_access( - user.id, access_type, knowledge_base.access_control, user_group_ids + user.id, access_type, knowledge_base.access_control, user_group_ids, db=db ): return True knowledge_base_id = file.meta.get("collection_name") if file.meta else None if knowledge_base_id: knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( - user.id, access_type + user.id, access_type, db=db ) for knowledge_base in knowledge_bases: if knowledge_base.id == knowledge_base_id: return True # Check if the file is associated with any channels the user has access to - channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id) + channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db) if access_type == "read" and channels: return True # Check if the file is associated with any chats the user has access to # TODO: Granular access control for chats - chats = Chats.get_shared_chats_by_file_id(file_id) + chats = Chats.get_shared_chats_by_file_id(file_id, db=db) if chats: return True @@ -109,47 +115,78 @@ def has_access_to_file( 
############################ -def process_uploaded_file(request, file, file_path, file_item, file_metadata, user): - try: - if file.content_type: - stt_supported_content_types = getattr( - request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] - ) +def process_uploaded_file( + request, + file, + file_path, + file_item, + file_metadata, + user, + db: Optional[Session] = None, +): + def _process_handler(db_session): + try: + if file.content_type: + stt_supported_content_types = getattr( + request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] + ) - if strict_match_mime_type(stt_supported_content_types, file.content_type): - file_path = Storage.get_file(file_path) - result = transcribe(request, file_path, file_metadata, user) + if strict_match_mime_type( + stt_supported_content_types, file.content_type + ): + file_path_processed = Storage.get_file(file_path) + result = transcribe( + request, file_path_processed, file_metadata, user + ) + process_file( + request, + ProcessFileForm( + file_id=file_item.id, content=result.get("text", "") + ), + user=user, + db=db_session, + ) + elif (not file.content_type.startswith(("image/", "video/"))) or ( + request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" + ): + process_file( + request, + ProcessFileForm(file_id=file_item.id), + user=user, + db=db_session, + ) + else: + raise Exception( + f"File type {file.content_type} is not supported for processing" + ) + else: + log.info( + f"File type {file.content_type} is not provided, but trying to process anyway" + ) process_file( request, - ProcessFileForm( - file_id=file_item.id, content=result.get("text", "") - ), + ProcessFileForm(file_id=file_item.id), user=user, + db=db_session, ) - elif (not file.content_type.startswith(("image/", "video/"))) or ( - request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" - ): - process_file(request, ProcessFileForm(file_id=file_item.id), user=user) - else: - raise Exception( - f"File type {file.content_type} is not 
supported for processing" - ) - else: - log.info( - f"File type {file.content_type} is not provided, but trying to process anyway" + + except Exception as e: + log.error(f"Error processing file: {file_item.id}") + Files.update_file_data_by_id( + file_item.id, + { + "status": "failed", + "error": str(e.detail) if hasattr(e, "detail") else str(e), + }, + db=db_session, ) - process_file(request, ProcessFileForm(file_id=file_item.id), user=user) - except Exception as e: - log.error(f"Error processing file: {file_item.id}") - Files.update_file_data_by_id( - file_item.id, - { - "status": "failed", - "error": str(e.detail) if hasattr(e, "detail") else str(e), - }, - ) + if db: + _process_handler(db) + else: + with SessionLocal() as db_session: + _process_handler(db_session) @router.post("/", response_model=FileModelResponse) @@ -161,6 +198,7 @@ def upload_file( process: bool = Query(True), process_in_background: bool = Query(True), user=Depends(get_verified_user), + db: Session = Depends(get_session), ): return upload_file_handler( request, @@ -170,6 +208,7 @@ def upload_file( process_in_background=process_in_background, user=user, background_tasks=background_tasks, + db=db, ) @@ -181,6 +220,7 @@ def upload_file_handler( process_in_background: bool = Query(True), user=Depends(get_verified_user), background_tasks: Optional[BackgroundTasks] = None, + db: Optional[Session] = None, ): log.info(f"file.content_type: {file.content_type} {process}") @@ -248,14 +288,17 @@ def upload_file_handler( }, } ), + db=db, ) if "channel_id" in file_metadata: channel = Channels.get_channel_by_id_and_user_id( - file_metadata["channel_id"], user.id + file_metadata["channel_id"], user.id, db=db ) if channel: - Channels.add_file_to_channel_by_id(channel.id, file_item.id, user.id) + Channels.add_file_to_channel_by_id( + channel.id, file_item.id, user.id, db=db + ) if process: if background_tasks and process_in_background: @@ -277,6 +320,7 @@ def upload_file_handler( file_item, file_metadata, 
user, + db=db, ) return {"status": True, **file_item.model_dump()} else: @@ -302,11 +346,15 @@ def upload_file_handler( @router.get("/", response_model=list[FileModelResponse]) -async def list_files(user=Depends(get_verified_user), content: bool = Query(True)): +async def list_files( + user=Depends(get_verified_user), + content: bool = Query(True), + db: Session = Depends(get_session), +): if user.role == "admin": - files = Files.get_files() + files = Files.get_files(db=db) else: - files = Files.get_files_by_user_id(user.id) + files = Files.get_files_by_user_id(user.id, db=db) if not content: for file in files: @@ -328,34 +376,41 @@ async def search_files( description="Filename pattern to search for. Supports wildcards such as '*.txt'", ), content: bool = Query(True), + skip: int = Query(0, ge=0, description="Number of files to skip"), + limit: int = Query( + 100, ge=1, le=1000, description="Maximum number of files to return" + ), user=Depends(get_verified_user), + db: Session = Depends(get_session), ): """ Search for files by filename with support for wildcard patterns. + Uses SQL-based filtering with pagination for better performance. 
""" - # Get files according to user role - if user.role == "admin": - files = Files.get_files() - else: - files = Files.get_files_by_user_id(user.id) - - # Get matching files - matching_files = [ - file for file in files if fnmatch(file.filename.lower(), filename.lower()) - ] + # Determine user_id: null for admin (search all), user.id for regular users + user_id = None if user.role == "admin" else user.id + + # Use optimized database query with pagination + files = Files.search_files( + user_id=user_id, + filename=filename, + skip=skip, + limit=limit, + db=db, + ) - if not matching_files: + if not files: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No files found matching the pattern.", ) if not content: - for file in matching_files: - if "content" in file.data: + for file in files: + if file.data and "content" in file.data: del file.data["content"] - return matching_files + return files ############################ @@ -364,8 +419,10 @@ async def search_files( @router.delete("/all") -async def delete_all_files(user=Depends(get_admin_user)): - result = Files.delete_all_files() +async def delete_all_files( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + result = Files.delete_all_files(db=db) if result: try: Storage.delete_all_files() @@ -391,8 +448,10 @@ async def delete_all_files(user=Depends(get_admin_user)): @router.get("/{id}", response_model=Optional[FileModel]) -async def get_file_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -403,7 +462,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): return file else: @@ -415,9 +474,12 @@ 
async def get_file_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/process/status") async def get_file_process_status( - id: str, stream: bool = Query(False), user=Depends(get_verified_user) + id: str, + stream: bool = Query(False), + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - file = Files.get_file_by_id(id) + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -428,7 +490,7 @@ async def get_file_process_status( if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): if stream: MAX_FILE_PROCESSING_DURATION = 3600 * 2 @@ -436,7 +498,7 @@ async def get_file_process_status( async def event_stream(file_item): if file_item: for _ in range(MAX_FILE_PROCESSING_DURATION): - file_item = Files.get_file_by_id(file_item.id) + file_item = Files.get_file_by_id(file_item.id, db=db) if file_item: data = file_item.model_dump().get("data", {}) status = data.get("status") @@ -476,8 +538,10 @@ async def event_stream(file_item): @router.get("/{id}/data/content") -async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_data_content_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -488,7 +552,7 @@ async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): return {"content": file.data.get("content", "")} else: @@ -509,9 +573,13 @@ class ContentForm(BaseModel): @router.post("/{id}/data/content/update") async def update_file_data_content_by_id( - request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user) + request: Request, + id: str, + 
form_data: ContentForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - file = Files.get_file_by_id(id) + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -522,7 +590,7 @@ async def update_file_data_content_by_id( if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "write", user) + or has_access_to_file(id, "write", user, db=db) ): try: process_file( @@ -530,7 +598,7 @@ async def update_file_data_content_by_id( ProcessFileForm(file_id=id, content=form_data.content), user=user, ) - file = Files.get_file_by_id(id=id) + file = Files.get_file_by_id(id=id, db=db) except Exception as e: log.exception(e) log.error(f"Error processing file: {file.id}") @@ -550,9 +618,12 @@ async def update_file_data_content_by_id( @router.get("/{id}/content") async def get_file_content_by_id( - id: str, user=Depends(get_verified_user), attachment: bool = Query(False) + id: str, + user=Depends(get_verified_user), + attachment: bool = Query(False), + db: Session = Depends(get_session), ): - file = Files.get_file_by_id(id) + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -563,7 +634,7 @@ async def get_file_content_by_id( if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): try: file_path = Storage.get_file(file.path) @@ -619,8 +690,10 @@ async def get_file_content_by_id( @router.get("/{id}/content/html") -async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_html_file_content_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -628,7 +701,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.NOT_FOUND, ) - file_user = 
Users.get_user_by_id(file.user_id) + file_user = Users.get_user_by_id(file.user_id, db=db) if not file_user.role == "admin": raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -638,7 +711,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): try: file_path = Storage.get_file(file.path) @@ -668,8 +741,10 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/content/{file_name}") -async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_content_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -680,7 +755,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "read", user) + or has_access_to_file(id, "read", user, db=db) ): file_path = file.path @@ -730,8 +805,10 @@ def generator(): @router.delete("/{id}") -async def delete_file_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def delete_file_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + file = Files.get_file_by_id(id, db=db) if not file: raise HTTPException( @@ -742,10 +819,10 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): if ( file.user_id == user.id or user.role == "admin" - or has_access_to_file(id, "write", user) + or has_access_to_file(id, "write", user, db=db) ): - result = Files.delete_file_by_id(id) + result = Files.delete_file_by_id(id, db=db) if result: try: Storage.delete_file(file.path) diff --git a/backend/open_webui/routers/folders.py 
b/backend/open_webui/routers/folders.py index 32911fa509..1c9b2229cf 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -22,6 +22,8 @@ from open_webui.config import UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request @@ -44,7 +46,11 @@ @router.get("/", response_model=list[FolderNameIdResponse]) -async def get_folders(request: Request, user=Depends(get_verified_user)): +async def get_folders( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if request.app.state.config.ENABLE_FOLDERS is False: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -55,22 +61,23 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): user.id, "features.folders", request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - folders = Folders.get_folders_by_user_id(user.id) + folders = Folders.get_folders_by_user_id(user.id, db=db) # Verify folder data integrity folder_list = [] for folder in folders: if folder.parent_id and not Folders.get_folder_by_id_and_user_id( - folder.parent_id, user.id + folder.parent_id, user.id, db=db ): folder = Folders.update_folder_parent_id_by_id_and_user_id( - folder.id, user.id, None + folder.id, user.id, None, db=db ) if folder.data: @@ -80,12 +87,12 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): if file.get("type") == "file": if Files.check_access_by_user_id( - file.get("id"), user.id, "read" + file.get("id"), user.id, "read", db=db ): valid_files.append(file) elif file.get("type") == "collection": if Knowledges.check_access_by_user_id( - file.get("id"), user.id, "read" + file.get("id"), user.id, "read", db=db ): 
valid_files.append(file) else: @@ -93,7 +100,7 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): folder.data["files"] = valid_files Folders.update_folder_by_id_and_user_id( - folder.id, user.id, FolderUpdateForm(data=folder.data) + folder.id, user.id, FolderUpdateForm(data=folder.data), db=db ) folder_list.append(FolderNameIdResponse(**folder.model_dump())) @@ -107,9 +114,13 @@ async def get_folders(request: Request, user=Depends(get_verified_user)): @router.post("/") -def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): +def create_folder( + form_data: FolderForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - None, user.id, form_data.name + None, user.id, form_data.name, db=db ) if folder: @@ -119,7 +130,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): ) try: - folder = Folders.insert_new_folder(user.id, form_data) + folder = Folders.insert_new_folder(user.id, form_data, db=db) return folder except Exception as e: log.exception(e) @@ -136,8 +147,10 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): @router.get("/{id}", response_model=Optional[FolderModel]) -async def get_folder_by_id(id: str, user=Depends(get_verified_user)): - folder = Folders.get_folder_by_id_and_user_id(id, user.id) +async def get_folder_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: return folder else: @@ -154,15 +167,18 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update") async def update_folder_name_by_id( - id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user) + id: str, + form_data: FolderUpdateForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - folder = 
Folders.get_folder_by_id_and_user_id(id, user.id) + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: if form_data.name is not None: # Check if folder with same name exists existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - folder.parent_id, user.id, form_data.name + folder.parent_id, user.id, form_data.name, db=db ) if existing_folder and existing_folder.id != id: raise HTTPException( @@ -171,7 +187,9 @@ async def update_folder_name_by_id( ) try: - folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data) + folder = Folders.update_folder_by_id_and_user_id( + id, user.id, form_data, db=db + ) return folder except Exception as e: log.exception(e) @@ -198,12 +216,15 @@ class FolderParentIdForm(BaseModel): @router.post("/{id}/update/parent") async def update_folder_parent_id_by_id( - id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user) + id: str, + form_data: FolderParentIdForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id) + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - form_data.parent_id, user.id, folder.name + form_data.parent_id, user.id, folder.name, db=db ) if existing_folder: @@ -214,7 +235,7 @@ async def update_folder_parent_id_by_id( try: folder = Folders.update_folder_parent_id_by_id_and_user_id( - id, user.id, form_data.parent_id + id, user.id, form_data.parent_id, db=db ) return folder except Exception as e: @@ -242,13 +263,16 @@ class FolderIsExpandedForm(BaseModel): @router.post("/{id}/update/expanded") async def update_folder_is_expanded_by_id( - id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user) + id: str, + form_data: FolderIsExpandedForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - folder = 
Folders.get_folder_by_id_and_user_id(id, user.id) + folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: try: folder = Folders.update_folder_is_expanded_by_id_and_user_id( - id, user.id, form_data.is_expanded + id, user.id, form_data.is_expanded, db=db ) return folder except Exception as e: @@ -276,10 +300,11 @@ async def delete_folder_by_id( id: str, delete_contents: Optional[bool] = True, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - if Chats.count_chats_by_folder_id_and_user_id(id, user.id): + if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db): chat_delete_permission = has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db ) if user.role != "admin" and not chat_delete_permission: raise HTTPException( @@ -288,19 +313,21 @@ async def delete_folder_by_id( ) folders = [] - folders.append(Folders.get_folder_by_id_and_user_id(id, user.id)) + folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db)) while folders: folder = folders.pop() if folder: try: - folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) + folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db) for folder_id in folder_ids: if delete_contents: - Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) + Chats.delete_chats_by_user_id_and_folder_id( + user.id, folder_id, db=db + ) else: Chats.move_chats_by_user_id_and_folder_id( - user.id, folder_id, None + user.id, folder_id, None, db=db ) return True @@ -314,7 +341,7 @@ async def delete_folder_by_id( finally: # Get all subfolders subfolders = Folders.get_folders_by_parent_id_and_user_id( - folder.id, user.id + folder.id, user.id, db=db ) folders.extend(subfolders) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 82321cb546..ad47318911 100644 --- 
a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Optional +from open_webui.env import AIOHTTP_CLIENT_TIMEOUT from open_webui.models.functions import ( FunctionForm, FunctionModel, @@ -24,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.auth import get_admin_user, get_verified_user from pydantic import BaseModel, HttpUrl +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -37,13 +40,17 @@ @router.get("/", response_model=list[FunctionResponse]) -async def get_functions(user=Depends(get_verified_user)): - return Functions.get_functions() +async def get_functions( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Functions.get_functions(db=db) @router.get("/list", response_model=list[FunctionUserResponse]) -async def get_function_list(user=Depends(get_admin_user)): - return Functions.get_function_list() +async def get_function_list( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + return Functions.get_function_list(db=db) ############################ @@ -52,8 +59,12 @@ async def get_function_list(user=Depends(get_admin_user)): @router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel]) -async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)): - return Functions.get_functions(include_valves=include_valves) +async def get_functions( + include_valves: bool = False, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + return Functions.get_functions(include_valves=include_valves, db=db) ############################ @@ -110,7 +121,9 @@ async def load_function_from_url( ) try: - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, 
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.get( url, headers={"Content-Type": "application/json"} ) as resp: @@ -142,7 +155,10 @@ class SyncFunctionsForm(BaseModel): @router.post("/sync", response_model=list[FunctionWithValvesModel]) async def sync_functions( - request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user) + request: Request, + form_data: SyncFunctionsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: for function in form_data.functions: @@ -164,7 +180,7 @@ async def sync_functions( ) raise e - return Functions.sync_functions(user.id, form_data.functions) + return Functions.sync_functions(user.id, form_data.functions, db=db) except Exception as e: log.exception(f"Failed to load a function: {e}") raise HTTPException( @@ -180,7 +196,10 @@ async def sync_functions( @router.post("/create", response_model=Optional[FunctionResponse]) async def create_new_function( - request: Request, form_data: FunctionForm, user=Depends(get_admin_user) + request: Request, + form_data: FunctionForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): if not form_data.id.isidentifier(): raise HTTPException( @@ -190,7 +209,7 @@ async def create_new_function( form_data.id = form_data.id.lower() - function = Functions.get_function_by_id(form_data.id) + function = Functions.get_function_by_id(form_data.id, db=db) if function is None: try: form_data.content = replace_imports(form_data.content) @@ -203,13 +222,17 @@ async def create_new_function( FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(user.id, function_type, form_data) + function = Functions.insert_new_function( + user.id, function_type, form_data, db=db + ) function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) if function_type == "filter" and 
getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id(id, {"toggle": True}) + Functions.update_function_metadata_by_id( + form_data.id, {"toggle": True}, db=db + ) if function: return function @@ -237,8 +260,10 @@ async def create_new_function( @router.get("/id/{id}", response_model=Optional[FunctionModel]) -async def get_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def get_function_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + function = Functions.get_function_by_id(id, db=db) if function: return function @@ -255,11 +280,13 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) -async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def toggle_function_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + function = Functions.get_function_by_id(id, db=db) if function: function = Functions.update_function_by_id( - id, {"is_active": not function.is_active} + id, {"is_active": not function.is_active}, db=db ) if function: @@ -282,11 +309,13 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) -async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def toggle_global_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + function = Functions.get_function_by_id(id, db=db) if function: function = Functions.update_function_by_id( - id, {"is_global": not function.is_global} + id, {"is_global": not function.is_global}, db=db ) if function: @@ -310,7 +339,11 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): 
@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) async def update_function_by_id( - request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) + request: Request, + id: str, + form_data: FunctionForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: form_data.content = replace_imports(form_data.content) @@ -325,10 +358,10 @@ async def update_function_by_id( updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} log.debug(updated) - function = Functions.update_function_by_id(id, updated) + function = Functions.update_function_by_id(id, updated, db=db) if function_type == "filter" and getattr(function_module, "toggle", None): - Functions.update_function_metadata_by_id(id, {"toggle": True}) + Functions.update_function_metadata_by_id(id, {"toggle": True}, db=db) if function: return function @@ -352,9 +385,12 @@ async def update_function_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_function_by_id( - request: Request, id: str, user=Depends(get_admin_user) + request: Request, + id: str, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): - result = Functions.delete_function_by_id(id) + result = Functions.delete_function_by_id(id, db=db) if result: FUNCTIONS = request.app.state.FUNCTIONS @@ -370,11 +406,13 @@ async def delete_function_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def get_function_valves_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + function = Functions.get_function_by_id(id, db=db) if function: try: - valves = Functions.get_function_valves_by_id(id) + valves = Functions.get_function_valves_by_id(id, db=db) return valves except Exception as e: raise HTTPException( @@ -395,9 +433,12 @@ async def 
get_function_valves_by_id(id: str, user=Depends(get_admin_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) async def get_function_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user) + request: Request, + id: str, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( request, id @@ -421,9 +462,13 @@ async def get_function_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_function_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user) + request: Request, + id: str, + form_data: dict, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( request, id @@ -437,7 +482,7 @@ async def update_function_valves_by_id( valves = Valves(**form_data) valves_dict = valves.model_dump(exclude_unset=True) - Functions.update_function_valves_by_id(id, valves_dict) + Functions.update_function_valves_by_id(id, valves_dict, db=db) return valves_dict except Exception as e: log.exception(f"Error updating function values by id {id}: {e}") @@ -464,11 +509,15 @@ async def update_function_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)): - function = Functions.get_function_by_id(id) +async def get_function_user_valves_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + function = Functions.get_function_by_id(id, db=db) if function: try: - user_valves = 
Functions.get_user_valves_by_id_and_user_id(id, user.id) + user_valves = Functions.get_user_valves_by_id_and_user_id( + id, user.id, db=db + ) return user_valves except Exception as e: raise HTTPException( @@ -484,9 +533,12 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) async def get_function_user_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( request, id @@ -505,9 +557,13 @@ async def get_function_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) async def update_function_user_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user) + request: Request, + id: str, + form_data: dict, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - function = Functions.get_function_by_id(id) + function = Functions.get_function_by_id(id, db=db) if function: function_module, function_type, frontmatter = get_function_module_from_cache( @@ -522,7 +578,7 @@ async def update_function_user_valves_by_id( user_valves = UserValves(**form_data) user_valves_dict = user_valves.model_dump(exclude_unset=True) Functions.update_user_valves_by_id_and_user_id( - id, user.id, user_valves_dict + id, user.id, user_valves_dict, db=db ) return user_valves_dict except Exception as e: diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index 423f6b1c67..cc0cb8f5a3 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -16,6 +16,9 @@ from open_webui.constants import ERROR_MESSAGES from 
fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session + from open_webui.utils.auth import get_admin_user, get_verified_user @@ -29,16 +32,21 @@ @router.get("/", response_model=list[GroupResponse]) -async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): +async def get_groups( + share: Optional[bool] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): filter = {} + + # Admins can share to all groups regardless of share setting if user.role != "admin": filter["member_id"] = user.id + if share is not None: + filter["share"] = share - if share is not None: - filter["share"] = share - - groups = Groups.get_groups(filter=filter) + groups = Groups.get_groups(filter=filter, db=db) return groups @@ -49,13 +57,17 @@ async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_use @router.post("/create", response_model=Optional[GroupResponse]) -async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): +async def create_new_group( + form_data: GroupForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): try: - group = Groups.insert_new_group(user.id, form_data) + group = Groups.insert_new_group(user.id, form_data, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -76,12 +88,14 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): @router.get("/id/{id}", response_model=Optional[GroupResponse]) -async def get_group_by_id(id: str, user=Depends(get_admin_user)): - group = Groups.get_group_by_id(id) +async def get_group_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) if group: return 
GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -101,13 +115,15 @@ class GroupExportResponse(GroupResponse): @router.get("/id/{id}/export", response_model=Optional[GroupExportResponse]) -async def export_group_by_id(id: str, user=Depends(get_admin_user)): - group = Groups.get_group_by_id(id) +async def export_group_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + group = Groups.get_group_by_id(id, db=db) if group: return GroupExportResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), - user_ids=Groups.get_group_user_ids_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), + user_ids=Groups.get_group_user_ids_by_id(group.id, db=db), ) else: raise HTTPException( @@ -122,9 +138,11 @@ async def export_group_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/users", response_model=list[UserInfoResponse]) -async def get_users_in_group(id: str, user=Depends(get_admin_user)): +async def get_users_in_group( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): try: - users = Users.get_users_by_group_id(id) + users = Users.get_users_by_group_id(id, db=db) return users except Exception as e: log.exception(f"Error adding users to group {id}: {e}") @@ -141,14 +159,17 @@ async def get_users_in_group(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[GroupResponse]) async def update_group_by_id( - id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) + id: str, + form_data: GroupUpdateForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: - group = Groups.update_group_by_id(id, form_data) + group = Groups.update_group_by_id(id, form_data, db=db) if group: return GroupResponse( 
**group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -170,17 +191,20 @@ async def update_group_by_id( @router.post("/id/{id}/users/add", response_model=Optional[GroupResponse]) async def add_user_to_group( - id: str, form_data: UserIdsForm, user=Depends(get_admin_user) + id: str, + form_data: UserIdsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: if form_data.user_ids: - form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) + form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db) - group = Groups.add_users_to_group(id, form_data.user_ids) + group = Groups.add_users_to_group(id, form_data.user_ids, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -197,14 +221,17 @@ async def add_user_to_group( @router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse]) async def remove_users_from_group( - id: str, form_data: UserIdsForm, user=Depends(get_admin_user) + id: str, + form_data: UserIdsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): try: - group = Groups.remove_users_from_group(id, form_data.user_ids) + group = Groups.remove_users_from_group(id, form_data.user_ids, db=db) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), + member_count=Groups.get_group_member_count_by_id(group.id, db=db), ) else: raise HTTPException( @@ -225,9 +252,11 @@ async def remove_users_from_group( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_group_by_id(id: str, user=Depends(get_admin_user)): +async def delete_group_by_id( + id: str, user=Depends(get_admin_user), db: Session = 
Depends(get_session) +): try: - result = Groups.delete_group_by_id(id) + result = Groups.delete_group_by_id(id, db=db) if result: return result else: diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 8037d2077d..0fc6930b81 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -21,7 +21,10 @@ from open_webui.models.chats import Chats from open_webui.routers.files import upload_file_handler, get_file_content_by_id from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission from open_webui.utils.headers import include_user_info_headers +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from open_webui.utils.images.comfyui import ( ComfyUICreateImageForm, ComfyUIEditImageForm, @@ -461,6 +464,7 @@ class CreateImageForm(BaseModel): prompt: str size: Optional[str] = None n: int = 1 + steps: Optional[int] = None negative_prompt: Optional[str] = None @@ -496,7 +500,7 @@ def get_image_data(data: str, headers=None): return None, None -def upload_image(request, image_data, content_type, metadata, user): +def upload_image(request, image_data, content_type, metadata, user, db=None): image_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(image_data), @@ -524,6 +528,7 @@ def upload_image(request, image_data, content_type, metadata, user): message_id=message_id, file_ids=[file_item.id], user_id=user.id, + db=db, ) url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) @@ -534,6 +539,20 @@ def upload_image(request, image_data, content_type, metadata, user): async def generate_images( request: Request, form_data: CreateImageForm, user=Depends(get_verified_user) ): + if not request.app.state.config.ENABLE_IMAGE_GENERATION: + raise HTTPException( + status_code=403, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + if user.role != "admin" and not 
has_permission( + user.id, "features.image_generation", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=403, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + return await image_generations(request, form_data, user=user) @@ -703,8 +722,15 @@ async def image_generations( "n": form_data.n, } - if request.app.state.config.IMAGE_STEPS is not None: - data["steps"] = request.app.state.config.IMAGE_STEPS + if ( + request.app.state.config.IMAGE_STEPS is not None + or form_data.steps is not None + ): + data["steps"] = ( + form_data.steps + if form_data.steps is not None + else request.app.state.config.IMAGE_STEPS + ) if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -762,8 +788,15 @@ async def image_generations( "height": height, } - if request.app.state.config.IMAGE_STEPS is not None: - data["steps"] = request.app.state.config.IMAGE_STEPS + if ( + request.app.state.config.IMAGE_STEPS is not None + or form_data.steps is not None + ): + data["steps"] = ( + form_data.steps + if form_data.steps is not None + else request.app.state.config.IMAGE_STEPS + ) if form_data.negative_prompt is not None: data["negative_prompt"] = form_data.negative_prompt @@ -844,6 +877,9 @@ async def image_edits( try: async def load_url_image(data): + if data.startswith("data:"): + return data + if data.startswith("http://") or data.startswith("https://"): r = await asyncio.to_thread(requests.get, data) r.raise_for_status() @@ -851,10 +887,14 @@ async def load_url_image(data): image_data = base64.b64encode(r.content).decode("utf-8") return f"data:{r.headers['content-type']};base64,{image_data}" - elif data.startswith("/api/v1/files"): - file_id = data.split("/api/v1/files/")[1].split("/content")[0] - file_response = await get_file_content_by_id(file_id, user) + else: + file_id = None + if data.startswith("/api/v1/files"): + file_id = data.split("/api/v1/files/")[1].split("/content")[0] + else: + file_id = data + 
file_response = await get_file_content_by_id(file_id, user) if isinstance(file_response, FileResponse): file_path = file_response.path @@ -864,7 +904,6 @@ async def load_url_image(data): mime_type, _ = mimetypes.guess_type(file_path) return f"data:{mime_type};base64,{image_data}" - return data # Load image(s) from URL(s) if necessary diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 467f6f3896..9fc30424ca 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -1,9 +1,14 @@ from typing import List, Optional from pydantic import BaseModel from fastapi import APIRouter, Depends, HTTPException, status, Request, Query +from fastapi.responses import StreamingResponse from fastapi.concurrency import run_in_threadpool import logging +import io +import zipfile +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session from open_webui.models.groups import Groups from open_webui.models.knowledge import ( KnowledgeFileListResponse, @@ -23,7 +28,7 @@ from open_webui.storage.provider import Storage from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.auth import get_verified_user +from open_webui.utils.auth import get_verified_user, get_admin_user from open_webui.utils.access_control import has_access, has_permission @@ -41,6 +46,54 @@ PAGE_ITEM_COUNT = 30 +############################ +# Knowledge Base Embedding +############################ + +KNOWLEDGE_BASES_COLLECTION = "knowledge-bases" + + +async def embed_knowledge_base_metadata( + request: Request, + knowledge_base_id: str, + name: str, + description: str, +) -> bool: + """Generate and store embedding for knowledge base.""" + try: + content = f"{name}\n\n{description}" if description else name + embedding = await request.app.state.EMBEDDING_FUNCTION(content) + VECTOR_DB_CLIENT.upsert( + collection_name=KNOWLEDGE_BASES_COLLECTION, + items=[ + { + "id": knowledge_base_id, + "text": 
content, + "vector": embedding, + "metadata": { + "knowledge_base_id": knowledge_base_id, + }, + } + ], + ) + return True + except Exception as e: + log.error(f"Failed to embed knowledge base {knowledge_base_id}: {e}") + return False + + +def remove_knowledge_base_metadata_embedding(knowledge_base_id: str) -> bool: + """Remove knowledge base embedding.""" + try: + VECTOR_DB_CLIENT.delete( + collection_name=KNOWLEDGE_BASES_COLLECTION, + ids=[knowledge_base_id], + ) + return True + except Exception as e: + log.debug(f"Failed to remove embedding for {knowledge_base_id}: {e}") + return False + class KnowledgeAccessResponse(KnowledgeUserResponse): write_access: Optional[bool] = False @@ -52,21 +105,25 @@ class KnowledgeAccessListResponse(BaseModel): @router.get("/", response_model=KnowledgeAccessListResponse) -async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)): +async def get_knowledge_bases( + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): page = max(page, 1) limit = PAGE_ITEM_COUNT skip = (page - 1) * limit filter = {} if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id result = Knowledges.search_knowledge_bases( - user.id, filter=filter, skip=skip, limit=limit + user.id, filter=filter, skip=skip, limit=limit, db=db ) return KnowledgeAccessListResponse( @@ -75,7 +132,10 @@ async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified **knowledge_base.model_dump(), write_access=( user.id == knowledge_base.user_id - or has_access(user.id, "write", knowledge_base.access_control) + or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or has_access( + user.id, "write", knowledge_base.access_control, db=db + ) ), ) for knowledge_base in result.items 
@@ -90,6 +150,7 @@ async def search_knowledge_bases( view_option: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): page = max(page, 1) limit = PAGE_ITEM_COUNT @@ -102,14 +163,14 @@ async def search_knowledge_bases( filter["view_option"] = view_option if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id result = Knowledges.search_knowledge_bases( - user.id, filter=filter, skip=skip, limit=limit + user.id, filter=filter, skip=skip, limit=limit, db=db ) return KnowledgeAccessListResponse( @@ -118,7 +179,10 @@ async def search_knowledge_bases( **knowledge_base.model_dump(), write_access=( user.id == knowledge_base.user_id - or has_access(user.id, "write", knowledge_base.access_control) + or (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or has_access( + user.id, "write", knowledge_base.access_control, db=db + ) ), ) for knowledge_base in result.items @@ -132,6 +196,7 @@ async def search_knowledge_files( query: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): page = max(page, 1) limit = PAGE_ITEM_COUNT @@ -141,13 +206,15 @@ async def search_knowledge_files( if query: filter["query"] = query - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id - return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit) + return Knowledges.search_knowledge_files( + filter=filter, skip=skip, limit=limit, db=db + ) ############################ @@ -157,10 +224,13 @@ async def search_knowledge_files( @router.post("/create", response_model=Optional[KnowledgeResponse]) 
async def create_new_knowledge( - request: Request, form_data: KnowledgeForm, user=Depends(get_verified_user) + request: Request, + form_data: KnowledgeForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -175,13 +245,21 @@ async def create_new_knowledge( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): form_data.access_control = {} - knowledge = Knowledges.insert_new_knowledge(user.id, form_data) + knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db) if knowledge: + # Embed knowledge base for semantic search + await embed_knowledge_base_metadata( + request, + knowledge.id, + knowledge.name, + knowledge.description, + ) return knowledge else: raise HTTPException( @@ -196,20 +274,24 @@ async def create_new_knowledge( @router.post("/reindex", response_model=bool) -async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)): +async def reindex_knowledge_files( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - knowledge_bases = Knowledges.get_knowledge_bases() + knowledge_bases = Knowledges.get_knowledge_bases(db=db) log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases") for knowledge_base in knowledge_bases: try: - files = Knowledges.get_files_by_id(knowledge_base.id) + files = Knowledges.get_files_by_id(knowledge_base.id, db=db) try: if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id): VECTOR_DB_CLIENT.delete_collection( @@ -229,6 +311,7 @@ async def 
reindex_knowledge_files(request: Request, user=Depends(get_verified_us file_id=file.id, collection_name=knowledge_base.id ), user=user, + db=db, ) except Exception as e: log.error( @@ -253,6 +336,30 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us return True +############################ +# ReindexKnowledgeBases +############################ + + +@router.post("/metadata/reindex", response_model=dict) +async def reindex_knowledge_base_metadata_embeddings( + request: Request, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Batch embed all existing knowledge bases. Admin only.""" + knowledge_bases = Knowledges.get_knowledge_bases(db=db) + log.info(f"Reindexing embeddings for {len(knowledge_bases)} knowledge bases") + + success_count = 0 + for kb in knowledge_bases: + if await embed_knowledge_base_metadata(request, kb.id, kb.name, kb.description): + success_count += 1 + + log.info(f"Embedding reindex complete: {success_count}/{len(knowledge_bases)}") + return {"total": len(knowledge_bases), "success": success_count} + + ############################ # GetKnowledgeById ############################ @@ -264,26 +371,34 @@ class KnowledgeFilesResponse(KnowledgeResponse): @router.get("/{id}", response_model=Optional[KnowledgeFilesResponse]) -async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): - knowledge = Knowledges.get_knowledge_by_id(id=id) +async def get_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if knowledge: if ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control) + or has_access(user.id, "read", knowledge.access_control, db=db) ): return KnowledgeFilesResponse( **knowledge.model_dump(), write_access=( user.id == knowledge.user_id - or has_access(user.id, "write", knowledge.access_control) + or (user.role == 
"admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or has_access(user.id, "write", knowledge.access_control, db=db) ), ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) else: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) @@ -299,8 +414,9 @@ async def update_knowledge_by_id( id: str, form_data: KnowledgeForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -309,7 +425,7 @@ async def update_knowledge_by_id( # Is the user the original creator, in a group with write access, or an admin if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -325,15 +441,23 @@ async def update_knowledge_by_id( user.id, "sharing.public_knowledge", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): form_data.access_control = {} - knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db) if knowledge: + # Re-embed knowledge base for semantic search + await embed_knowledge_base_metadata( + request, + knowledge.id, + knowledge.name, + knowledge.description, + ) return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -356,9 +480,10 @@ async def get_knowledge_files_by_id( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), 
): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -368,7 +493,7 @@ async def get_knowledge_files_by_id( if not ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control) + or has_access(user.id, "read", knowledge.access_control, db=db) ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -391,7 +516,7 @@ async def get_knowledge_files_by_id( filter["direction"] = direction return Knowledges.search_files_by_id( - id, user.id, filter=filter, skip=skip, limit=limit + id, user.id, filter=filter, skip=skip, limit=limit, db=db ) @@ -410,8 +535,9 @@ def add_file_to_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -420,7 +546,7 @@ def add_file_to_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -428,7 +554,7 @@ def add_file_to_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -446,11 +572,12 @@ def add_file_to_knowledge_by_id( request, ProcessFileForm(file_id=form_data.file_id, collection_name=id), user=user, + db=db, ) # Add file to knowledge base Knowledges.add_file_to_knowledge_by_id( - knowledge_id=id, file_id=form_data.file_id, user_id=user.id + knowledge_id=id, file_id=form_data.file_id, user_id=user.id, 
db=db ) except Exception as e: log.debug(e) @@ -462,7 +589,7 @@ def add_file_to_knowledge_by_id( if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -477,8 +604,9 @@ def update_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -487,7 +615,7 @@ def update_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): @@ -496,7 +624,7 @@ def update_file_from_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -514,6 +642,7 @@ def update_file_from_knowledge_by_id( request, ProcessFileForm(file_id=form_data.file_id, collection_name=id), user=user, + db=db, ) except Exception as e: raise HTTPException( @@ -524,7 +653,7 @@ def update_file_from_knowledge_by_id( if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -544,8 +673,9 @@ def remove_file_from_knowledge_by_id( form_data: KnowledgeFileIdForm, delete_file: bool = Query(True), user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = 
Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -554,7 +684,7 @@ def remove_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -562,7 +692,7 @@ def remove_file_from_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -570,7 +700,7 @@ def remove_file_from_knowledge_by_id( ) Knowledges.remove_file_from_knowledge_by_id( - knowledge_id=id, file_id=form_data.file_id + knowledge_id=id, file_id=form_data.file_id, db=db ) # Remove content from the vector database @@ -599,12 +729,12 @@ def remove_file_from_knowledge_by_id( pass # Delete file from database - Files.delete_file_by_id(form_data.file_id) + Files.delete_file_by_id(form_data.file_id, db=db) if knowledge: return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), ) else: raise HTTPException( @@ -619,8 +749,10 @@ def remove_file_from_knowledge_by_id( @router.delete("/{id}/delete", response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): - knowledge = Knowledges.get_knowledge_by_id(id=id) +async def delete_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -629,7 +761,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( knowledge.user_id != user.id - and not 
has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -640,7 +772,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): log.info(f"Deleting knowledge base: {id} (name: {knowledge.name})") # Get all models - models = Models.get_all_models() + models = Models.get_all_models(db=db) log.info(f"Found {len(models)} models to check for knowledge base {id}") # Update models that reference this knowledge base @@ -664,7 +796,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): access_control=model.access_control, is_active=model.is_active, ) - Models.update_model_by_id(model.id, model_form) + Models.update_model_by_id(model.id, model_form, db=db) # Clean up vector DB try: @@ -672,7 +804,11 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): except Exception as e: log.debug(e) pass - result = Knowledges.delete_knowledge_by_id(id=id) + + # Remove knowledge base embedding + remove_knowledge_base_metadata_embedding(id) + + result = Knowledges.delete_knowledge_by_id(id=id, db=db) return result @@ -682,8 +818,10 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/reset", response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): - knowledge = Knowledges.get_knowledge_by_id(id=id) +async def reset_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -692,7 +830,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", 
knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -706,7 +844,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): log.debug(e) pass - knowledge = Knowledges.reset_knowledge_by_id(id=id) + knowledge = Knowledges.reset_knowledge_by_id(id=id, db=db) return knowledge @@ -721,11 +859,12 @@ async def add_files_to_knowledge_batch( id: str, form_data: list[KnowledgeFileIdForm], user=Depends(get_verified_user), + db: Session = Depends(get_session), ): """ Add multiple files to a knowledge base """ - knowledge = Knowledges.get_knowledge_by_id(id=id) + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -734,7 +873,7 @@ async def add_files_to_knowledge_batch( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not has_access(user.id, "write", knowledge.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -746,7 +885,7 @@ async def add_files_to_knowledge_batch( log.info(f"files/batch/add - {len(form_data)} files") files: List[FileModel] = [] for form in form_data: - file = Files.get_file_by_id(form.file_id) + file = Files.get_file_by_id(form.file_id, db=db) if not file: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -760,6 +899,7 @@ async def add_files_to_knowledge_batch( request=request, form_data=BatchProcessFilesForm(files=files, collection_name=id), user=user, + db=db, ) except Exception as e: log.error( @@ -771,7 +911,7 @@ async def add_files_to_knowledge_batch( successful_file_ids = [r.file_id for r in result.results if r.status == "completed"] for file_id in successful_file_ids: Knowledges.add_file_to_knowledge_by_id( - knowledge_id=id, file_id=file_id, user_id=user.id + knowledge_id=id, file_id=file_id, user_id=user.id, db=db ) # If there were any errors, include them in the response @@ -779,7 +919,7 @@ async def 
add_files_to_knowledge_batch( error_details = [f"{err.file_id}: {err.error}" for err in result.errors] return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), warnings={ "message": "Some files failed to process", "errors": error_details, @@ -788,5 +928,53 @@ return KnowledgeFilesResponse( **knowledge.model_dump(), - files=Knowledges.get_file_metadatas_by_id(knowledge.id), + files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), + ) + + +############################ +# ExportKnowledgeById +############################ + + +@router.get("/{id}/export") +async def export_knowledge_by_id( + id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + """ + Export a knowledge base as a zip file containing .txt files. + Admin only. + """ + + knowledge = Knowledges.get_knowledge_by_id(id=id, db=db) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + files = Knowledges.get_files_by_id(id, db=db) + + # Create zip file in memory + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: + for file in files: + content = file.data.get("content", "") if file.data else "" + if content: + # Use original filename with .txt extension + filename = file.filename + if not filename.endswith(".txt"): + filename = f"{filename}.txt" + zf.writestr(filename, content) + + zip_buffer.seek(0) + + # Sanitize knowledge name for filename + safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in knowledge.name) + zip_filename = f"{safe_name}.zip" + + return StreamingResponse( + zip_buffer, + media_type="application/zip", + headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'}, ) diff --git a/backend/open_webui/routers/memories.py
b/backend/open_webui/routers/memories.py index 9bb1ef518d..e0ba36c76f 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel import logging import asyncio @@ -7,7 +7,11 @@ from open_webui.models.memories import Memories, MemoryModel from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.utils.auth import get_verified_user +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session +from open_webui.utils.access_control import has_permission +from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) @@ -25,8 +29,26 @@ async def get_embeddings(request: Request): @router.get("/", response_model=list[MemoryModel]) -async def get_memories(user=Depends(get_verified_user)): - return Memories.get_memories_by_user_id(user.id) +async def get_memories( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + return Memories.get_memories_by_user_id(user.id, db=db) ############################ @@ -47,8 +69,23 @@ async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - memory = Memories.insert_new_memory(user.id, form_data.content) + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not 
has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + memory = Memories.insert_new_memory(user.id, form_data.content, db=db) vector = await request.app.state.EMBEDDING_FUNCTION(memory.content, user=user) @@ -79,9 +116,26 @@ class QueryMemoryForm(BaseModel): @router.post("/query") async def query_memory( - request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) + request: Request, + form_data: QueryMemoryForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - memories = Memories.get_memories_by_user_id(user.id) + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + memories = Memories.get_memories_by_user_id(user.id, db=db) if not memories: raise HTTPException(status_code=404, detail="No memories found for user") @@ -101,11 +155,27 @@ async def query_memory( ############################ @router.post("/reset", response_model=bool) async def reset_memory_from_vector_db( - request: Request, user=Depends(get_verified_user) + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") - memories 
= Memories.get_memories_by_user_id(user.id) + memories = Memories.get_memories_by_user_id(user.id, db=db) # Generate vectors in parallel vectors = await asyncio.gather( @@ -140,8 +210,26 @@ async def reset_memory_from_vector_db( @router.delete("/delete/user", response_model=bool) -async def delete_memory_by_user_id(user=Depends(get_verified_user)): - result = Memories.delete_memories_by_user_id(user.id) +async def delete_memory_by_user_id( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + result = Memories.delete_memories_by_user_id(user.id, db=db) if result: try: @@ -164,9 +252,24 @@ async def update_memory_by_id( request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + memory = Memories.update_memory_by_id_and_user_id( - memory_id, user.id, form_data.content + memory_id, user.id, form_data.content, db=db ) if memory is None: raise HTTPException(status_code=404, detail="Memory not found") @@ -198,8 +301,27 @@ async def update_memory_by_id( @router.delete("/{memory_id}", response_model=bool) -async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): - result = 
Memories.delete_memory_by_id_and_user_id(memory_id, user.id) +async def delete_memory_by_id( + memory_id: str, + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if not request.app.state.config.ENABLE_MEMORIES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if not has_permission( + user.id, "features.memories", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id, db=db) if result: VECTOR_DB_CLIENT.delete( diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 4475b2d78e..a1f642bbce 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -11,6 +11,8 @@ ModelModel, ModelResponse, ModelListResponse, + ModelAccessListResponse, + ModelAccessResponse, Models, ) @@ -30,6 +32,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -49,7 +53,7 @@ def is_valid_model_id(model_id: str) -> bool: @router.get( - "/list", response_model=ModelListResponse + "/list", response_model=ModelAccessListResponse ) # do NOT use "/" as path, conflicts with main.py async def get_models( query: Optional[str] = None, @@ -59,6 +63,7 @@ async def get_models( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT @@ -79,13 +84,27 @@ async def get_models( filter["direction"] = direction if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = 
Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id - return Models.search_models(user.id, filter=filter, skip=skip, limit=limit) + result = Models.search_models(user.id, filter=filter, skip=skip, limit=limit, db=db) + return ModelAccessListResponse( + items=[ + ModelAccessResponse( + **model.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == model.user_id + or has_access(user.id, "write", model.access_control, db=db) + ), + ) + for model in result.items + ], + total=result.total, + ) ########################### @@ -94,8 +113,10 @@ async def get_models( @router.get("/base", response_model=list[ModelResponse]) -async def get_base_models(user=Depends(get_admin_user)): - return Models.get_base_models() +async def get_base_models( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + return Models.get_base_models(db=db) ########################### @@ -104,11 +125,13 @@ async def get_base_models(user=Depends(get_admin_user)): @router.get("/tags", response_model=list[str]) -async def get_model_tags(user=Depends(get_verified_user)): +async def get_model_tags( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - models = Models.get_models() + models = Models.get_models(db=db) else: - models = Models.get_models_by_user_id(user.id) + models = Models.get_models_by_user_id(user.id, db=db) tags_set = set() for model in models: @@ -132,16 +155,17 @@ async def create_new_model( request: Request, form_data: ModelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db ): raise 
HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - model = Models.get_model_by_id(form_data.id) + model = Models.get_model_by_id(form_data.id, db=db) if model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -155,7 +179,7 @@ async def create_new_model( ) else: - model = Models.insert_new_model(form_data, user.id) + model = Models.insert_new_model(form_data, user.id, db=db) if model: return model else: @@ -171,9 +195,16 @@ async def create_new_model( @router.get("/export", response_model=list[ModelModel]) -async def export_models(request: Request, user=Depends(get_verified_user)): +async def export_models( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "workspace.models_export", request.app.state.config.USER_PERMISSIONS + user.id, + "workspace.models_export", + request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -181,9 +212,9 @@ async def export_models(request: Request, user=Depends(get_verified_user)): ) if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - return Models.get_models() + return Models.get_models(db=db) else: - return Models.get_models_by_user_id(user.id) + return Models.get_models_by_user_id(user.id, db=db) ############################ @@ -200,9 +231,13 @@ async def import_models( request: Request, user=Depends(get_verified_user), form_data: ModelsImportForm = (...), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "workspace.models_import", request.app.state.config.USER_PERMISSIONS + user.id, + "workspace.models_import", + request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -216,7 +251,7 @@ async def import_models( model_id = model_data.get("id") if model_id and 
is_valid_model_id(model_id): - existing_model = Models.get_model_by_id(model_id) + existing_model = Models.get_model_by_id(model_id, db=db) if existing_model: # Update existing model model_data["meta"] = model_data.get("meta", {}) @@ -225,13 +260,15 @@ async def import_models( updated_model = ModelForm( **{**existing_model.model_dump(), **model_data} ) - Models.update_model_by_id(model_id, updated_model) + Models.update_model_by_id(model_id, updated_model, db=db) else: # Insert new model model_data["meta"] = model_data.get("meta", {}) model_data["params"] = model_data.get("params", {}) new_model = ModelForm(**model_data) - Models.insert_new_model(user_id=user.id, form_data=new_model) + Models.insert_new_model( + user_id=user.id, form_data=new_model, db=db + ) return True else: raise HTTPException(status_code=400, detail="Invalid JSON format") @@ -251,9 +288,12 @@ class SyncModelsForm(BaseModel): @router.post("/sync", response_model=list[ModelModel]) async def sync_models( - request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) + request: Request, + form_data: SyncModelsForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), ): - return Models.sync_models(user.id, form_data.models) + return Models.sync_models(user.id, form_data.models, db=db) ########################### @@ -266,19 +306,33 @@ class ModelIdForm(BaseModel): # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id -@router.get("/model", response_model=Optional[ModelResponse]) -async def get_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) +@router.get("/model", response_model=Optional[ModelAccessResponse]) +async def get_model_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + model = Models.get_model_by_id(id, db=db) if model: if ( (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id - or 
has_access(user.id, "read", model.access_control) + or has_access(user.id, "read", model.access_control, db=db) ): - return model + return ModelAccessResponse( + **model.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == model.user_id + or has_access(user.id, "write", model.access_control, db=db) + ), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) else: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) @@ -289,38 +343,42 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): @router.get("/model/profile/image") -async def get_model_profile_image(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) - # Cache-control headers to prevent stale cached images - cache_headers = {"Cache-Control": "no-cache, must-revalidate"} +def get_model_profile_image( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + model = Models.get_model_by_id(id, db=db) if model: + etag = f'"{model.updated_at}"' if model.updated_at else None + if model.meta.profile_image_url: if model.meta.profile_image_url.startswith("http"): return Response( status_code=status.HTTP_302_FOUND, - headers={"Location": model.meta.profile_image_url, **cache_headers}, + headers={"Location": model.meta.profile_image_url}, ) elif model.meta.profile_image_url.startswith("data:image"): try: header, base64_data = model.meta.profile_image_url.split(",", 1) image_data = base64.b64decode(base64_data) image_buffer = io.BytesIO(image_data) + media_type = header.split(";")[0].lstrip("data:") + + headers = {"Content-Disposition": "inline"} + if etag: + headers["ETag"] = etag return StreamingResponse( image_buffer, - media_type="image/png", - headers={ - "Content-Disposition": "inline; filename=image.png", - **cache_headers, - }, + 
media_type=media_type, + headers=headers, ) except Exception as e: pass - return FileResponse(f"{STATIC_DIR}/favicon.png", headers=cache_headers) + return FileResponse(f"{STATIC_DIR}/favicon.png") else: - return FileResponse(f"{STATIC_DIR}/favicon.png", headers=cache_headers) + return FileResponse(f"{STATIC_DIR}/favicon.png") ############################ @@ -329,15 +387,17 @@ async def get_model_profile_image(id: str, user=Depends(get_verified_user)): @router.post("/model/toggle", response_model=Optional[ModelResponse]) -async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) +async def toggle_model_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + model = Models.get_model_by_id(id, db=db) if model: if ( user.role == "admin" or model.user_id == user.id - or has_access(user.id, "write", model.access_control) + or has_access(user.id, "write", model.access_control, db=db) ): - model = Models.toggle_model_by_id(id) + model = Models.toggle_model_by_id(id, db=db) if model: return model @@ -367,8 +427,9 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): async def update_model_by_id( form_data: ModelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - model = Models.get_model_by_id(form_data.id) + model = Models.get_model_by_id(form_data.id, db=db) if not model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -377,7 +438,7 @@ async def update_model_by_id( if ( model.user_id != user.id - and not has_access(user.id, "write", model.access_control) + and not has_access(user.id, "write", model.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -385,7 +446,9 @@ async def update_model_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - model = Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump())) + model = Models.update_model_by_id( + form_data.id, 
ModelForm(**form_data.model_dump()), db=db + ) return model @@ -395,8 +458,12 @@ async def update_model_by_id( @router.post("/model/delete", response_model=bool) -async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_user)): - model = Models.get_model_by_id(form_data.id) +async def delete_model_by_id( + form_data: ModelIdForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + model = Models.get_model_by_id(form_data.id, db=db) if not model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -406,18 +473,20 @@ async def delete_model_by_id(form_data: ModelIdForm, user=Depends(get_verified_u if ( user.role != "admin" and model.user_id != user.id - and not has_access(user.id, "write", model.access_control) + and not has_access(user.id, "write", model.access_control, db=db) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - result = Models.delete_model_by_id(form_data.id) + result = Models.delete_model_by_id(form_data.id, db=db) return result @router.delete("/delete/all", response_model=bool) -async def delete_all_models(user=Depends(get_admin_user)): - result = Models.delete_all_models() +async def delete_all_models( + user=Depends(get_admin_user), db: Session = Depends(get_session) +): + result = Models.delete_all_models(db=db) return result diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index ee0e46da29..56730e2b6a 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -28,6 +28,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -49,10 +51,13 @@ class NoteItemResponse(BaseModel): @router.get("/", response_model=list[NoteItemResponse]) async def get_notes( - 
request: Request, page: Optional[int] = None, user=Depends(get_verified_user) + request: Request, + page: Optional[int] = None, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -65,16 +70,23 @@ async def get_notes( limit = 60 skip = (page - 1) * limit - notes = [ + notes = Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit, db=db) + if not notes: + return [] + + user_ids = list(set(note.user_id for note in notes)) + users = {user.id: user for user in Users.get_users_by_user_ids(user_ids, db=db)} + + return [ NoteUserResponse( **{ **note.model_dump(), - "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), + "user": UserResponse(**users[note.user_id].model_dump()), } ) - for note in Notes.get_notes_by_user_id(user.id, "read", skip=skip, limit=limit) + for note in notes + if note.user_id in users ] - return notes @router.get("/search", response_model=NoteListResponse) @@ -87,9 +99,10 @@ async def search_notes( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -115,13 +128,13 @@ async def search_notes( filter["direction"] = direction if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL: - groups = Groups.get_groups_by_member_id(user.id) + groups = Groups.get_groups_by_member_id(user.id, db=db) if groups: filter["group_ids"] = [group.id for group in groups] filter["user_id"] = user.id - return 
Notes.search_notes(user.id, filter, skip=skip, limit=limit) + return Notes.search_notes(user.id, filter, skip=skip, limit=limit, db=db) ############################ @@ -131,10 +144,13 @@ async def search_notes( @router.post("/create", response_model=Optional[NoteModel]) async def create_new_note( - request: Request, form_data: NoteForm, user=Depends(get_verified_user) + request: Request, + form_data: NoteForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -142,7 +158,7 @@ async def create_new_note( ) try: - note = Notes.insert_new_note(form_data, user.id) + note = Notes.insert_new_note(user.id, form_data, db=db) return note except Exception as e: log.exception(e) @@ -161,16 +177,21 @@ class NoteResponse(NoteModel): @router.get("/{id}", response_model=Optional[NoteResponse]) -async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def get_note_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - note = Notes.get_note_by_id(id) + note = Notes.get_note_by_id(id, db=db) if not note: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -178,7 +199,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us if user.role != "admin" and ( user.id != note.user_id - and (not has_access(user.id, type="read", 
access_control=note.access_control)) + and ( + not has_access( + user.id, type="read", access_control=note.access_control, db=db + ) + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -188,7 +213,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us user.role == "admin" or (user.id == note.user_id) or has_access( - user.id, type="write", access_control=note.access_control, strict=False + user.id, + type="write", + access_control=note.access_control, + strict=False, + db=db, ) ) @@ -202,17 +231,21 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us @router.post("/{id}/update", response_model=Optional[NoteModel]) async def update_note_by_id( - request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user) + request: Request, + id: str, + form_data: NoteForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - note = Notes.get_note_by_id(id) + note = Notes.get_note_by_id(id, db=db) if not note: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -220,7 +253,9 @@ async def update_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access(user.id, type="write", access_control=note.access_control) + and not has_access( + user.id, type="write", access_control=note.access_control, db=db + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -234,12 +269,13 @@ async def update_note_by_id( user.id, "sharing.public_notes", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): form_data.access_control = {} 
try: - note = Notes.update_note_by_id(id, form_data) + note = Notes.update_note_by_id(id, form_data, db=db) await sio.emit( "note-events", note.model_dump(), @@ -260,16 +296,21 @@ async def update_note_by_id( @router.delete("/{id}/delete", response_model=bool) -async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)): +async def delete_note_by_id( + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "features.notes", request.app.state.config.USER_PERMISSIONS + user.id, "features.notes", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - note = Notes.get_note_by_id(id) + note = Notes.get_note_by_id(id, db=db) if not note: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -277,14 +318,16 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified if user.role != "admin" and ( user.id != note.user_id - and not has_access(user.id, type="write", access_control=note.access_control) + and not has_access( + user.id, type="write", access_control=note.access_control, db=db + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) try: - note = Notes.delete_note_by_id(id) + note = Notes.delete_note_by_id(id, db=db) return True except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 2a04210a62..afaca2feb5 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -38,6 +38,9 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict, validator from starlette.background import BackgroundTask +from sqlalchemy.orm import Session + +from open_webui.internal.db import get_session 
from open_webui.models.models import Models @@ -423,14 +426,14 @@ async def get_all_models(request: Request, user: UserModel = None): return models -async def get_filtered_models(models, user): +async def get_filtered_models(models, user, db=None): # Filter models based on user access control filtered_models = [] for model in models.get("models", []): - model_info = Models.get_model_by_id(model["model"]) + model_info = Models.get_model_by_id(model["model"], db=db) if model_info: if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ): filtered_models.append(model) return filtered_models @@ -1272,6 +1275,8 @@ async def generate_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, + bypass_system_prompt: bool = False, + db: Session = Depends(get_session), ): if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True @@ -1293,7 +1298,7 @@ async def generate_chat_completion( del payload["metadata"] model_id = payload["model"] - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) if model_info: if model_info.base_model_id: @@ -1310,14 +1315,18 @@ async def generate_chat_completion( system = params.pop("system", None) payload = apply_model_params_to_body_ollama(params, payload) - payload = apply_system_prompt_to_body(system, payload, metadata, user) + if not bypass_system_prompt: + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, + type="read", + access_control=model_info.access_control, + db=db, ) ): raise HTTPException( @@ -1389,6 +1398,7 @@ async def generate_openai_completion( form_data: dict, 
url_idx: Optional[int] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): metadata = form_data.pop("metadata", None) @@ -1409,7 +1419,7 @@ async def generate_openai_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1423,7 +1433,10 @@ async def generate_openai_completion( if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, + type="read", + access_control=model_info.access_control, + db=db, ) ): raise HTTPException( @@ -1468,6 +1481,7 @@ async def generate_openai_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_credit_by_user_id(user_id=user.id, form_data=form_data) @@ -1490,7 +1504,7 @@ async def generate_openai_chat_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1508,7 +1522,10 @@ async def generate_openai_chat_completion( if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, + type="read", + access_control=model_info.access_control, + db=db, ) ): raise HTTPException( @@ -1551,6 +1568,7 @@ async def get_openai_models( request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): models = [] @@ -1603,10 +1621,13 @@ async def get_openai_models( # Filter models based on user access control filtered_models = [] for model in models: - model_info = Models.get_model_by_id(model["id"]) + model_info = 
Models.get_model_by_id(model["id"], db=db) if model_info: if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, + type="read", + access_control=model_info.access_control, + db=db, ): filtered_models.append(model) models = filtered_models @@ -1673,11 +1694,10 @@ async def download_file_stream( if done: file.close() + hashed = calculate_sha256(file_path, chunk_size) with open(file_path, "rb") as file: chunk_size = 1024 * 1024 * 2 - hashed = calculate_sha256(file, chunk_size) - url = f"{ollama_url}/api/blobs/sha256:{hashed}" with requests.Session() as session: response = session.post(url, data=file, timeout=30) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 4428ebfba2..f4b87c372d 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -19,6 +19,9 @@ ) from pydantic import BaseModel from starlette.background import BackgroundTask +from sqlalchemy.orm import Session + +from open_webui.internal.db import get_session from open_webui.models.models import Models from open_webui.config import ( @@ -455,14 +458,14 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: return responses -async def get_filtered_models(models, user): +async def get_filtered_models(models, user, db=None): # Filter models based on user access control filtered_models = [] for model in models.get("data", []): - model_info = Models.get_model_by_id(model["id"]) + model_info = Models.get_model_by_id(model["id"], db=db) if model_info: if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ): filtered_models.append(model) return filtered_models @@ -804,6 +807,8 @@ async def generate_chat_completion( form_data: dict, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, + 
bypass_system_prompt: bool = False, + db: Session = Depends(get_session), ): check_credit_by_user_id(user_id=user.id, form_data=form_data) @@ -816,7 +821,7 @@ async def generate_chat_completion( metadata = payload.pop("metadata", None) model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(model_id, db=db) # Check model info and override the payload if model_info: @@ -835,14 +840,18 @@ async def generate_chat_completion( system = params.pop("system", None) payload = apply_model_params_to_body_openai(params, payload) - payload = apply_system_prompt_to_body(system, payload, metadata, user) + if not bypass_system_prompt: + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, + type="read", + access_control=model_info.access_control, + db=db, ) ): raise HTTPException( diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 6a957f2547..19d25685ad 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -4,6 +4,7 @@ from open_webui.models.prompts import ( PromptForm, PromptUserResponse, + PromptAccessResponse, PromptModel, Prompts, ) @@ -11,6 +12,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session router = APIRouter() @@ -20,23 +23,37 @@ @router.get("/", response_model=list[PromptModel]) -async def get_prompts(user=Depends(get_verified_user)): +async def get_prompts( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin" and 
BYPASS_ADMIN_ACCESS_CONTROL: - prompts = Prompts.get_prompts() + prompts = Prompts.get_prompts(db=db) else: - prompts = Prompts.get_prompts_by_user_id(user.id, "read") + prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) return prompts -@router.get("/list", response_model=list[PromptUserResponse]) -async def get_prompt_list(user=Depends(get_verified_user)): +@router.get("/list", response_model=list[PromptAccessResponse]) +async def get_prompt_list( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - prompts = Prompts.get_prompts() + prompts = Prompts.get_prompts(db=db) else: - prompts = Prompts.get_prompts_by_user_id(user.id, "write") - - return prompts + prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db) + + return [ + PromptAccessResponse( + **prompt.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == prompt.user_id + or has_access(user.id, "write", prompt.access_control, db=db) + ), + ) + for prompt in prompts + ] ############################ @@ -46,16 +63,23 @@ async def get_prompt_list(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[PromptModel]) async def create_new_prompt( - request: Request, form_data: PromptForm, user=Depends(get_verified_user) + request: Request, + form_data: PromptForm, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not ( has_permission( - user.id, "workspace.prompts", request.app.state.config.USER_PERMISSIONS + user.id, + "workspace.prompts", + request.app.state.config.USER_PERMISSIONS, + db=db, ) or has_permission( user.id, "workspace.prompts_import", request.app.state.config.USER_PERMISSIONS, + db=db, ) ): raise HTTPException( @@ -63,9 +87,9 @@ async def create_new_prompt( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - prompt = Prompts.get_prompt_by_command(form_data.command) + prompt = 
Prompts.get_prompt_by_command(form_data.command, db=db) if prompt is None: - prompt = Prompts.insert_new_prompt(user.id, form_data) + prompt = Prompts.insert_new_prompt(user.id, form_data, db=db) if prompt: return prompt @@ -84,17 +108,26 @@ async def create_new_prompt( ############################ -@router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") +@router.get("/command/{command}", response_model=Optional[PromptAccessResponse]) +async def get_prompt_by_command( + command: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) if prompt: if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control) + or has_access(user.id, "read", prompt.access_control, db=db) ): - return prompt + return PromptAccessResponse( + **prompt.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == prompt.user_id + or has_access(user.id, "write", prompt.access_control, db=db) + ), + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -112,8 +145,9 @@ async def update_prompt_by_command( command: str, form_data: PromptForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - prompt = Prompts.get_prompt_by_command(f"/{command}") + prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) if not prompt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -123,7 +157,7 @@ async def update_prompt_by_command( # Is the user the original creator, in a group with write access, or an admin if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control) + and not has_access(user.id, "write", prompt.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ 
-131,7 +165,7 @@ async def update_prompt_by_command( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data, db=db) if prompt: return prompt else: @@ -147,8 +181,10 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") +async def delete_prompt_by_command( + command: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + prompt = Prompts.get_prompt_by_command(f"/{command}", db=db) if not prompt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -157,7 +193,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user) if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control) + and not has_access(user.id, "write", prompt.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -165,5 +201,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user) detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Prompts.delete_prompt_by_command(f"/{command}") + result = Prompts.delete_prompt_by_command(f"/{command}", db=db) return result diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index a2c1cc80d5..db3f80f149 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -39,6 +39,8 @@ from open_webui.models.files import FileModel, FileUpdateForm, Files from open_webui.models.knowledge import Knowledges from open_webui.storage.provider import Storage +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT @@ -91,6 +93,7 @@ sanitize_text_for_db, ) 
from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission from open_webui.config import ( ENV, @@ -110,6 +113,7 @@ SENTENCE_TRANSFORMERS_MODEL_KWARGS, SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION, ) from open_webui.constants import ERROR_MESSAGES @@ -188,6 +192,7 @@ def get_rf( raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers + import torch try: rf = sentence_transformers.CrossEncoder( @@ -196,6 +201,11 @@ def get_rf( trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, + activation_fn=( + torch.nn.Sigmoid() + if SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION + else None + ), ) except Exception as e: log.error(f"CrossEncoder: {e}") @@ -494,7 +504,9 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "RAG_EXTERNAL_RERANKER_TIMEOUT": request.app.state.config.RAG_EXTERNAL_RERANKER_TIMEOUT, # Chunking settings "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, + "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER": request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, + "CHUNK_MIN_SIZE_TARGET": request.app.state.config.CHUNK_MIN_SIZE_TARGET, "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, # File upload settings "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, @@ -532,12 +544,14 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "SERPSTACK_HTTPS": request.app.state.config.SERPSTACK_HTTPS, "SERPER_API_KEY": request.app.state.config.SERPER_API_KEY, "SERPLY_API_KEY": request.app.state.config.SERPLY_API_KEY, + "DDGS_BACKEND": request.app.state.config.DDGS_BACKEND, "TAVILY_API_KEY": 
request.app.state.config.TAVILY_API_KEY, "SEARCHAPI_API_KEY": request.app.state.config.SEARCHAPI_API_KEY, "SEARCHAPI_ENGINE": request.app.state.config.SEARCHAPI_ENGINE, "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, "JINA_API_KEY": request.app.state.config.JINA_API_KEY, + "JINA_API_BASE_URL": request.app.state.config.JINA_API_BASE_URL, "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "EXA_API_KEY": request.app.state.config.EXA_API_KEY, @@ -554,6 +568,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, + "FIRECRAWL_TIMEOUT": request.app.state.config.FIRECRAWL_TIMEOUT, "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, @@ -592,12 +607,14 @@ class WebConfig(BaseModel): SERPSTACK_HTTPS: Optional[bool] = None SERPER_API_KEY: Optional[str] = None SERPLY_API_KEY: Optional[str] = None + DDGS_BACKEND: Optional[str] = None TAVILY_API_KEY: Optional[str] = None SEARCHAPI_API_KEY: Optional[str] = None SEARCHAPI_ENGINE: Optional[str] = None SERPAPI_API_KEY: Optional[str] = None SERPAPI_ENGINE: Optional[str] = None JINA_API_KEY: Optional[str] = None + JINA_API_BASE_URL: Optional[str] = None BING_SEARCH_V7_ENDPOINT: Optional[str] = None BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None EXA_API_KEY: Optional[str] = None @@ -614,6 +631,7 @@ class WebConfig(BaseModel): PLAYWRIGHT_TIMEOUT: Optional[int] = None FIRECRAWL_API_KEY: Optional[str] = None FIRECRAWL_API_BASE_URL: 
Optional[str] = None + FIRECRAWL_TIMEOUT: Optional[str] = None TAVILY_EXTRACT_DEPTH: Optional[str] = None EXTERNAL_WEB_SEARCH_URL: Optional[str] = None EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None @@ -683,7 +701,9 @@ class ConfigForm(BaseModel): # Chunking settings TEXT_SPLITTER: Optional[str] = None + ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER: Optional[bool] = None CHUNK_SIZE: Optional[int] = None + CHUNK_MIN_SIZE_TARGET: Optional[int] = None CHUNK_OVERLAP: Optional[int] = None # File upload settings @@ -991,6 +1011,11 @@ async def update_rag_config( if form_data.CHUNK_SIZE is not None else request.app.state.config.CHUNK_SIZE ) + request.app.state.config.CHUNK_MIN_SIZE_TARGET = ( + form_data.CHUNK_MIN_SIZE_TARGET + if form_data.CHUNK_MIN_SIZE_TARGET is not None + else request.app.state.config.CHUNK_MIN_SIZE_TARGET + ) request.app.state.config.CHUNK_OVERLAP = ( form_data.CHUNK_OVERLAP if form_data.CHUNK_OVERLAP is not None @@ -1075,12 +1100,14 @@ async def update_rag_config( request.app.state.config.SERPSTACK_HTTPS = form_data.web.SERPSTACK_HTTPS request.app.state.config.SERPER_API_KEY = form_data.web.SERPER_API_KEY request.app.state.config.SERPLY_API_KEY = form_data.web.SERPLY_API_KEY + request.app.state.config.DDGS_BACKEND = form_data.web.DDGS_BACKEND request.app.state.config.TAVILY_API_KEY = form_data.web.TAVILY_API_KEY request.app.state.config.SEARCHAPI_API_KEY = form_data.web.SEARCHAPI_API_KEY request.app.state.config.SEARCHAPI_ENGINE = form_data.web.SEARCHAPI_ENGINE request.app.state.config.SERPAPI_API_KEY = form_data.web.SERPAPI_API_KEY request.app.state.config.SERPAPI_ENGINE = form_data.web.SERPAPI_ENGINE request.app.state.config.JINA_API_KEY = form_data.web.JINA_API_KEY + request.app.state.config.JINA_API_BASE_URL = form_data.web.JINA_API_BASE_URL request.app.state.config.BING_SEARCH_V7_ENDPOINT = ( form_data.web.BING_SEARCH_V7_ENDPOINT ) @@ -1112,6 +1139,7 @@ async def update_rag_config( request.app.state.config.FIRECRAWL_API_BASE_URL = ( 
form_data.web.FIRECRAWL_API_BASE_URL ) + request.app.state.config.FIRECRAWL_TIMEOUT = form_data.web.FIRECRAWL_TIMEOUT request.app.state.config.EXTERNAL_WEB_SEARCH_URL = ( form_data.web.EXTERNAL_WEB_SEARCH_URL ) @@ -1188,6 +1216,8 @@ async def update_rag_config( # Chunking settings "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER, "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE, + "CHUNK_MIN_SIZE_TARGET": request.app.state.config.CHUNK_MIN_SIZE_TARGET, + "ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER": request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, "CHUNK_OVERLAP": request.app.state.config.CHUNK_OVERLAP, # File upload settings "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, @@ -1231,6 +1261,7 @@ async def update_rag_config( "SERPAPI_API_KEY": request.app.state.config.SERPAPI_API_KEY, "SERPAPI_ENGINE": request.app.state.config.SERPAPI_ENGINE, "JINA_API_KEY": request.app.state.config.JINA_API_KEY, + "JINA_API_BASE_URL": request.app.state.config.JINA_API_BASE_URL, "BING_SEARCH_V7_ENDPOINT": request.app.state.config.BING_SEARCH_V7_ENDPOINT, "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "EXA_API_KEY": request.app.state.config.EXA_API_KEY, @@ -1247,6 +1278,7 @@ async def update_rag_config( "PLAYWRIGHT_TIMEOUT": request.app.state.config.PLAYWRIGHT_TIMEOUT, "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY, "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL, + "FIRECRAWL_TIMEOUT": request.app.state.config.FIRECRAWL_TIMEOUT, "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH, "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL, "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY, @@ -1266,6 +1298,85 @@ async def update_rag_config( #################################### +def can_merge_chunks(a: Document, b: Document) -> bool: + if a.metadata.get("source") != b.metadata.get("source"): + return 
False + + a_file_id = a.metadata.get("file_id") + b_file_id = b.metadata.get("file_id") + + if a_file_id is not None and b_file_id is not None: + return a_file_id == b_file_id + + return True + + +def merge_docs_to_target_size( + request: Request, + chunks: list[Document], +) -> list[Document]: + """ + Best-effort normalization of chunk sizes. + + Attempts to grow small chunks up to a desired minimum size, + without exceeding the maximum size or crossing source/file + boundaries. + """ + min_chunk_size_target = request.app.state.config.CHUNK_MIN_SIZE_TARGET + max_chunk_size = request.app.state.config.CHUNK_SIZE + + if min_chunk_size_target <= 0: + return chunks + + measure_chunk_size = len + if request.app.state.config.TEXT_SPLITTER == "token": + encoding = tiktoken.get_encoding( + str(request.app.state.config.TIKTOKEN_ENCODING_NAME) + ) + measure_chunk_size = lambda text: len(encoding.encode(text)) + + processed_chunks: list[Document] = [] + + current_chunk: Document | None = None + current_content: str = "" + + for next_chunk in chunks: + if current_chunk is None: + current_chunk = next_chunk + current_content = next_chunk.page_content + continue # First chunk initialization + + proposed_content = f"{current_content}\n\n{next_chunk.page_content}" + + can_merge = ( + can_merge_chunks(current_chunk, next_chunk) + and measure_chunk_size(current_content) < min_chunk_size_target + and measure_chunk_size(proposed_content) <= max_chunk_size + ) + + if can_merge: + current_content = proposed_content + else: + processed_chunks.append( + Document( + page_content=current_content, + metadata={**current_chunk.metadata}, + ) + ) + current_chunk = next_chunk + current_content = next_chunk.page_content + + if current_chunk is not None: + processed_chunks.append( + Document( + page_content=current_content, + metadata={**current_chunk.metadata}, + ) + ) + + return processed_chunks + + def save_docs_to_vector_db( request: Request, docs, @@ -1303,13 +1414,46 @@ def 
_get_docs_info(docs: list[Document]) -> str: filter={"hash": metadata["hash"]}, ) - if result is not None: + if result is not None and result.ids and len(result.ids) > 0: existing_doc_ids = result.ids[0] if existing_doc_ids: log.info(f"Document with hash {metadata['hash']} already exists") raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT) if split: + if request.app.state.config.ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER: + log.info("Using markdown header text splitter") + # Define headers to split on - covering most common markdown header levels + markdown_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=[ + ("#", "Header 1"), + ("##", "Header 2"), + ("###", "Header 3"), + ("####", "Header 4"), + ("#####", "Header 5"), + ("######", "Header 6"), + ], + strip_headers=False, # Keep headers in content for context + ) + + split_docs = [] + for doc in docs: + split_docs.extend( + [ + Document( + page_content=split_chunk.page_content, + metadata={**doc.metadata}, + ) + for split_chunk in markdown_splitter.split_text( + doc.page_content + ) + ] + ) + + docs = split_docs + if request.app.state.config.CHUNK_MIN_SIZE_TARGET > 0: + docs = merge_docs_to_target_size(request, docs) + if request.app.state.config.TEXT_SPLITTER in ["", "character"]: text_splitter = RecursiveCharacterTextSplitter( chunk_size=request.app.state.config.CHUNK_SIZE, @@ -1330,52 +1474,6 @@ def _get_docs_info(docs: list[Document]) -> str: add_start_index=True, ) docs = text_splitter.split_documents(docs) - elif request.app.state.config.TEXT_SPLITTER == "markdown_header": - log.info("Using markdown header text splitter") - - # Define headers to split on - covering most common markdown header levels - headers_to_split_on = [ - ("#", "Header 1"), - ("##", "Header 2"), - ("###", "Header 3"), - ("####", "Header 4"), - ("#####", "Header 5"), - ("######", "Header 6"), - ] - - markdown_splitter = MarkdownHeaderTextSplitter( - headers_to_split_on=headers_to_split_on, - strip_headers=False, # Keep headers in 
content for context - ) - - md_split_docs = [] - for doc in docs: - md_header_splits = markdown_splitter.split_text(doc.page_content) - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=request.app.state.config.CHUNK_SIZE, - chunk_overlap=request.app.state.config.CHUNK_OVERLAP, - add_start_index=True, - ) - md_header_splits = text_splitter.split_documents(md_header_splits) - - # Convert back to Document objects, preserving original metadata - for split_chunk in md_header_splits: - headings_list = [] - # Extract header values in order based on headers_to_split_on - for _, header_meta_key_name in headers_to_split_on: - if header_meta_key_name in split_chunk.metadata: - headings_list.append( - split_chunk.metadata[header_meta_key_name] - ) - - md_split_docs.append( - Document( - page_content=split_chunk.page_content, - metadata={**doc.metadata, "headings": headings_list}, - ) - ) - - docs = md_split_docs else: raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) @@ -1484,14 +1582,15 @@ def process_file( request: Request, form_data: ProcessFileForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): """ Process a file and save its content to the vector database. 
""" if user.role == "admin": - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) else: - file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id) + file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id, db=db) if file: try: @@ -1633,12 +1732,13 @@ def process_file( Files.update_file_data_by_id( file.id, {"content": text_content}, + db=db, ) hash = calculate_sha256_string(text_content) - Files.update_file_hash_by_id(file.id, hash) if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: - Files.update_file_data_by_id(file.id, {"status": "completed"}) + Files.update_file_data_by_id(file.id, {"status": "completed"}, db=db) + Files.update_file_hash_by_id(file.id, hash, db=db) return { "status": True, "collection_name": None, @@ -1667,12 +1767,15 @@ def process_file( { "collection_name": collection_name, }, + db=db, ) Files.update_file_data_by_id( file.id, {"status": "completed"}, + db=db, ) + Files.update_file_hash_by_id(file.id, hash, db=db) return { "status": True, @@ -1690,7 +1793,10 @@ def process_file( Files.update_file_data_by_id( file.id, {"status": "failed"}, + db=db, ) + # Clear the hash so the file can be re-uploaded after fixing the issue + Files.update_file_hash_by_id(file.id, None, db=db) if "No pandoc was found" in str(e): raise HTTPException( @@ -1972,6 +2078,7 @@ def search_web( request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, concurrent_requests=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + backend=request.app.state.config.DDGS_BACKEND, ) elif engine == "tavily": if request.app.state.config.TAVILY_API_KEY: @@ -2020,6 +2127,7 @@ def search_web( request.app.state.config.JINA_API_KEY, query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.JINA_API_BASE_URL, ) elif engine == "bing": return search_bing( @@ -2106,6 +2214,19 @@ def search_web( async def process_web_search( request: 
Request, form_data: SearchForm, user=Depends(get_verified_user) ): + if not request.app.state.config.ENABLE_WEB_SEARCH: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + if user.role != "admin" and not has_permission( + user.id, "features.web_search", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) urls = [] result_items = [] @@ -2417,10 +2538,19 @@ class DeleteForm(BaseModel): @router.post("/delete") -def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)): +def delete_entries_from_collection( + form_data: DeleteForm, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): try: if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name): - file = Files.get_file_by_id(form_data.file_id) + file = Files.get_file_by_id(form_data.file_id, db=db) + if not file: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) hash = file.hash VECTOR_DB_CLIENT.delete( @@ -2436,9 +2566,9 @@ def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin @router.post("/reset/db") -def reset_vector_db(user=Depends(get_admin_user)): +def reset_vector_db(user=Depends(get_admin_user), db: Session = Depends(get_session)): VECTOR_DB_CLIENT.reset() - Knowledges.delete_all_knowledge() + Knowledges.delete_all_knowledge(db=db) @router.post("/reset/uploads") @@ -2496,6 +2626,7 @@ async def process_files_batch( request: Request, form_data: BatchProcessFilesForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ) -> BatchProcessFilesResponse: """ Process a batch of files and save them to the vector database. 
@@ -2558,7 +2689,9 @@ async def process_files_batch( # Update all files with collection name for file_update, file_result in zip(file_updates, file_results): - Files.update_file_by_id(id=file_result.file_id, form_data=file_update) + Files.update_file_by_id( + id=file_result.file_id, form_data=file_update, db=db + ) file_result.status = "completed" except Exception as e: diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index bd2fd3d4f7..9070256770 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -25,6 +25,10 @@ ) from open_webui.constants import ERROR_MESSAGES + +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session + log = logging.getLogger(__name__) router = APIRouter() @@ -296,7 +300,7 @@ def get_scim_auth( ) -def user_to_scim(user: UserModel, request: Request) -> SCIMUser: +def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: """Convert internal User model to SCIM User""" # Parse display name into name components name_parts = user.name.split(" ", 1) if user.name else ["", ""] @@ -304,7 +308,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser: family_name = name_parts[1] if len(name_parts) > 1 else "" # Get user's groups - user_groups = Groups.get_groups_by_member_id(user.id) + user_groups = Groups.get_groups_by_member_id(user.id, db=db) groups = [ { "value": group.id, @@ -345,13 +349,13 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser: ) -def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: +def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup: """Convert internal Group model to SCIM Group""" - member_ids = Groups.get_group_user_ids_by_id(group.id) + member_ids = Groups.get_group_user_ids_by_id(group.id, db) or [] members = [] for user_id in member_ids: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if user: members.append( 
SCIMGroupMember( @@ -483,6 +487,7 @@ async def get_users( count: int = Query(20, ge=1, le=100), filter: Optional[str] = None, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """List SCIM Users""" skip = startIndex - 1 @@ -494,20 +499,20 @@ async def get_users( # In production, you'd want a more robust filter parser if "userName eq" in filter: email = filter.split('"')[1] - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) users_list = [user] if user else [] total = 1 if user else 0 else: - response = Users.get_users(skip=skip, limit=limit) + response = Users.get_users(skip=skip, limit=limit, db=db) users_list = response["users"] total = response["total"] else: - response = Users.get_users(skip=skip, limit=limit) + response = Users.get_users(skip=skip, limit=limit, db=db) users_list = response["users"] total = response["total"] # Convert to SCIM format - scim_users = [user_to_scim(user, request) for user in users_list] + scim_users = [user_to_scim(user, request, db=db) for user in users_list] return SCIMListResponse( totalResults=total, @@ -522,15 +527,16 @@ async def get_user( user_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Get SCIM User by ID""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: return scim_error( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" ) - return user_to_scim(user, request) + return user_to_scim(user, request, db=db) @router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) @@ -538,10 +544,11 @@ async def create_user( request: Request, user_data: SCIMUserCreateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Create SCIM User""" # Check if user already exists - existing_user = Users.get_user_by_email(user_data.userName) + existing_user = Users.get_user_by_email(user_data.userName, 
db=db) if existing_user: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -572,6 +579,7 @@ async def create_user( email=email, profile_image_url=profile_image, role="user" if user_data.active else "pending", + db=db, ) if not new_user: @@ -580,7 +588,7 @@ async def create_user( detail="Failed to create user", ) - return user_to_scim(new_user, request) + return user_to_scim(new_user, request, db=db) @router.put("/Users/{user_id}", response_model=SCIMUser) @@ -589,9 +597,10 @@ async def update_user( request: Request, user_data: SCIMUserUpdateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM User (full update)""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -624,14 +633,14 @@ async def update_user( update_data["profile_image_url"] = user_data.photos[0].value # Update user - updated_user = Users.update_user_by_id(user_id, update_data) + updated_user = Users.update_user_by_id(user_id, update_data, db=db) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update user", ) - return user_to_scim(updated_user, request) + return user_to_scim(updated_user, request, db=db) @router.patch("/Users/{user_id}", response_model=SCIMUser) @@ -640,9 +649,10 @@ async def patch_user( request: Request, patch_data: SCIMPatchRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM User (partial update)""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -670,7 +680,7 @@ async def patch_user( # Update user if update_data: - updated_user = Users.update_user_by_id(user_id, update_data) + updated_user = Users.update_user_by_id(user_id, update_data, db=db) if not updated_user: raise HTTPException( 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -679,7 +689,7 @@ async def patch_user( else: updated_user = user - return user_to_scim(updated_user, request) + return user_to_scim(updated_user, request, db=db) @router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -687,16 +697,17 @@ async def delete_user( user_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Delete SCIM User""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found", ) - success = Users.delete_user_by_id(user_id) + success = Users.delete_user_by_id(user_id, db=db) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -714,10 +725,11 @@ async def get_groups( count: int = Query(20, ge=1, le=100), filter: Optional[str] = None, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """List SCIM Groups""" # Get all groups - groups_list = Groups.get_all_groups() + groups_list = Groups.get_all_groups(db=db) # Apply pagination total = len(groups_list) @@ -726,7 +738,7 @@ async def get_groups( paginated_groups = groups_list[start:end] # Convert to SCIM format - scim_groups = [group_to_scim(group, request) for group in paginated_groups] + scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups] return SCIMListResponse( totalResults=total, @@ -741,16 +753,17 @@ async def get_group( group_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Get SCIM Group by ID""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - return group_to_scim(group, request) + return group_to_scim(group, request, db=db) 
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) @@ -758,6 +771,7 @@ async def create_group( request: Request, group_data: SCIMGroupCreateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Create SCIM Group""" # Extract member IDs @@ -775,14 +789,14 @@ async def create_group( ) # Need to get the creating user's ID - we'll use the first admin - admin_user = Users.get_super_admin_user() + admin_user = Users.get_super_admin_user(db=db) if not admin_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No admin user found", ) - new_group = Groups.insert_new_group(admin_user.id, form) + new_group = Groups.insert_new_group(admin_user.id, form, db=db) if not new_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -798,12 +812,12 @@ async def create_group( description=new_group.description, ) - Groups.update_group_by_id(new_group.id, update_form) - Groups.set_group_user_ids_by_id(new_group.id, member_ids) + Groups.update_group_by_id(new_group.id, update_form, db=db) + Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db) - new_group = Groups.get_group_by_id(new_group.id) + new_group = Groups.get_group_by_id(new_group.id, db=db) - return group_to_scim(new_group, request) + return group_to_scim(new_group, request, db=db) @router.put("/Groups/{group_id}", response_model=SCIMGroup) @@ -812,9 +826,10 @@ async def update_group( request: Request, group_data: SCIMGroupUpdateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM Group (full update)""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -832,17 +847,17 @@ async def update_group( # Handle members if provided if group_data.members is not None: member_ids = [member.value for member in group_data.members] - 
Groups.set_group_user_ids_by_id(group_id, member_ids) + Groups.set_group_user_ids_by_id(group_id, member_ids, db=db) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form) + updated_group = Groups.update_group_by_id(group_id, update_form, db=db) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update group", ) - return group_to_scim(updated_group, request) + return group_to_scim(updated_group, request, db=db) @router.patch("/Groups/{group_id}", response_model=SCIMGroup) @@ -851,9 +866,10 @@ async def patch_group( request: Request, patch_data: SCIMPatchRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM Group (partial update)""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -878,7 +894,7 @@ async def patch_group( elif path == "members": # Replace all members Groups.set_group_user_ids_by_id( - group_id, [member["value"] for member in value] + group_id, [member["value"] for member in value], db=db ) elif op == "add": @@ -887,22 +903,24 @@ async def patch_group( if isinstance(value, list): for member in value: if isinstance(member, dict) and "value" in member: - Groups.add_users_to_group(group_id, [member["value"]]) + Groups.add_users_to_group( + group_id, [member["value"]], db=db + ) elif op == "remove": if path and path.startswith("members[value eq"): # Remove specific member member_id = path.split('"')[1] - Groups.remove_users_from_group(group_id, [member_id]) + Groups.remove_users_from_group(group_id, [member_id], db=db) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form) + updated_group = Groups.update_group_by_id(group_id, update_form, db=db) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update group", ) - return 
group_to_scim(updated_group, request) + return group_to_scim(updated_group, request, db=db) @router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -910,16 +928,17 @@ async def delete_group( group_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Delete SCIM Group""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - success = Groups.delete_group_by_id(group_id) + success = Groups.delete_group_by_id(group_id, db=db) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index fdcaf266fa..03018d24a1 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -4,9 +4,12 @@ import time import re import aiohttp +from open_webui.env import AIOHTTP_CLIENT_TIMEOUT from open_webui.models.groups import Groups from pydantic import BaseModel, HttpUrl from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session from open_webui.models.oauth_sessions import OAuthSessions @@ -15,6 +18,7 @@ ToolModel, ToolResponse, ToolUserResponse, + ToolAccessResponse, Tools, ) from open_webui.utils.plugin import ( @@ -51,11 +55,15 @@ def get_tool_module(request, tool_id, load_from_db=True): @router.get("/", response_model=list[ToolUserResponse]) -async def get_tools(request: Request, user=Depends(get_verified_user)): +async def get_tools( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): tools = [] # Local Tools - for tool in Tools.get_tools(): + for tool in Tools.get_tools(db=db): tool_module = get_tool_module(request, tool.id) tools.append( ToolUserResponse( @@ -140,12 
+148,14 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): # Admin can see all tools return tools else: - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } tools = [ tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control, user_group_ids) + or has_access(user.id, "read", tool.access_control, user_group_ids, db=db) ] return tools @@ -155,13 +165,26 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): ############################ -@router.get("/list", response_model=list[ToolUserResponse]) -async def get_tool_list(user=Depends(get_verified_user)): +@router.get("/list", response_model=list[ToolAccessResponse]) +async def get_tool_list( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - tools = Tools.get_tools() + tools = Tools.get_tools(db=db) else: - tools = Tools.get_tools_by_user_id(user.id, "write") - return tools + tools = Tools.get_tools_by_user_id(user.id, "read", db=db) + + return [ + ToolAccessResponse( + **tool.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == tool.user_id + or has_access(user.id, "write", tool.access_control, db=db) + ), + ) + for tool in tools + ] ############################ @@ -218,7 +241,9 @@ async def load_tool_from_url( ) try: - async with aiohttp.ClientSession(trust_env=True) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) + ) as session: async with session.get( url, headers={"Content-Type": "application/json"} ) as resp: @@ -245,9 +270,16 @@ async def load_tool_from_url( @router.get("/export", response_model=list[ToolModel]) -async def export_tools(request: Request, user=Depends(get_verified_user)): +async 
def export_tools( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): if user.role != "admin" and not has_permission( - user.id, "workspace.tools_export", request.app.state.config.USER_PERMISSIONS + user.id, + "workspace.tools_export", + request.app.state.config.USER_PERMISSIONS, + db=db, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -255,9 +287,9 @@ async def export_tools(request: Request, user=Depends(get_verified_user)): ) if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: - return Tools.get_tools() + return Tools.get_tools(db=db) else: - return Tools.get_tools_by_user_id(user.id, "read") + return Tools.get_tools_by_user_id(user.id, "read", db=db) ############################ @@ -270,13 +302,17 @@ async def create_new_tools( request: Request, form_data: ToolForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): if user.role != "admin" and not ( has_permission( - user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS + user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS, db=db ) or has_permission( - user.id, "workspace.tools_import", request.app.state.config.USER_PERMISSIONS + user.id, + "workspace.tools_import", + request.app.state.config.USER_PERMISSIONS, + db=db, ) ): raise HTTPException( @@ -292,7 +328,7 @@ async def create_new_tools( form_data.id = form_data.id.lower() - tools = Tools.get_tool_by_id(form_data.id) + tools = Tools.get_tool_by_id(form_data.id, db=db) if tools is None: try: form_data.content = replace_imports(form_data.content) @@ -305,7 +341,7 @@ async def create_new_tools( TOOLS[form_data.id] = tool_module specs = get_tool_specs(TOOLS[form_data.id]) - tools = Tools.insert_new_tool(user.id, form_data, specs) + tools = Tools.insert_new_tool(user.id, form_data, specs, db=db) tool_cache_dir = CACHE_DIR / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) @@ -335,20 +371,34 @@ async def 
create_new_tools( ############################ -@router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_tools_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) +@router.get("/id/{id}", response_model=Optional[ToolAccessResponse]) +async def get_tools_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + tools = Tools.get_tool_by_id(id, db=db) if tools: if ( user.role == "admin" or tools.user_id == user.id - or has_access(user.id, "read", tools.access_control) + or has_access(user.id, "read", tools.access_control, db=db) ): - return tools + return ToolAccessResponse( + **tools.model_dump(), + write_access=( + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) + or user.id == tools.user_id + or has_access(user.id, "write", tools.access_control, db=db) + ), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) else: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND, ) @@ -364,8 +414,9 @@ async def update_tools_by_id( id: str, form_data: ToolForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -375,7 +426,7 @@ async def update_tools_by_id( # Is the user the original creator, in a group with write access, or an admin if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not has_access(user.id, "write", tools.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -399,7 +450,7 @@ async def update_tools_by_id( } log.debug(updated) - tools = Tools.update_tool_by_id(id, updated) + tools = Tools.update_tool_by_id(id, updated, db=db) if tools: return tools @@ -423,9 +474,12 @@ async 
def update_tools_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_tools_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -434,7 +488,7 @@ async def delete_tools_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not has_access(user.id, "write", tools.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -442,7 +496,7 @@ async def delete_tools_by_id( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - result = Tools.delete_tool_by_id(id) + result = Tools.delete_tool_by_id(id, db=db) if result: TOOLS = request.app.state.TOOLS if id in TOOLS: @@ -457,11 +511,13 @@ async def delete_tools_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) +async def get_tools_valves_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + tools = Tools.get_tool_by_id(id, db=db) if tools: try: - valves = Tools.get_tool_valves_by_id(id) + valves = Tools.get_tool_valves_by_id(id, db=db) return valves except Exception as e: raise HTTPException( @@ -482,9 +538,12 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) async def get_tools_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if tools: if id in request.app.state.TOOLS: tools_module = 
request.app.state.TOOLS[id] @@ -510,9 +569,13 @@ async def get_tools_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) async def update_tools_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user) + request: Request, + id: str, + form_data: dict, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -521,7 +584,7 @@ async def update_tools_valves_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not has_access(user.id, "write", tools.access_control, db=db) and user.role != "admin" ): raise HTTPException( @@ -546,7 +609,7 @@ async def update_tools_valves_by_id( form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) valves_dict = valves.model_dump(exclude_unset=True) - Tools.update_tool_valves_by_id(id, valves_dict) + Tools.update_tool_valves_by_id(id, valves_dict, db=db) return valves_dict except Exception as e: log.exception(f"Failed to update tool valves by id {id}: {e}") @@ -562,11 +625,13 @@ async def update_tools_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) +async def get_tools_user_valves_by_id( + id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + tools = Tools.get_tool_by_id(id, db=db) if tools: try: - user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) + user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id, db=db) return user_valves except Exception as e: raise HTTPException( @@ -582,9 +647,12 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): 
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) async def get_tools_user_valves_spec_by_id( - request: Request, id: str, user=Depends(get_verified_user) + request: Request, + id: str, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if tools: if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] @@ -605,9 +673,13 @@ async def get_tools_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) async def update_tools_user_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_verified_user) + request: Request, + id: str, + form_data: dict, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - tools = Tools.get_tool_by_id(id) + tools = Tools.get_tool_by_id(id, db=db) if tools: if id in request.app.state.TOOLS: @@ -624,7 +696,7 @@ async def update_tools_user_valves_by_id( user_valves = UserValves(**form_data) user_valves_dict = user_valves.model_dump(exclude_unset=True) Tools.update_user_valves_by_id_and_user_id( - id, user.id, user_valves_dict + id, user.id, user_valves_dict, db=db ) return user_valves_dict except Exception as e: diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 75a5f53a48..828de4a429 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,6 +1,7 @@ import logging from decimal import Decimal from typing import Optional, List +from sqlalchemy.orm import Session import base64 import io @@ -37,6 +38,7 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import STATIC_DIR +from open_webui.internal.db import get_session from open_webui.utils.auth import ( @@ -68,6 +70,7 @@ async def get_users( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_admin_user), + db: Session = Depends(get_session), ): 
limit = PAGE_ITEM_COUNT @@ -82,11 +85,17 @@ async def get_users( if direction: filter["direction"] = direction - result = Users.get_users(filter=filter, skip=skip, limit=limit) + filter["direction"] = direction + + result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db) users = result["users"] total = result["total"] + # Fetch groups for all users in a single query to avoid N+1 + user_ids = [user.id for user in users] + user_groups = Groups.get_groups_by_member_ids(user_ids, db=db) + credit_map = { credit.user_id: {"credit": "%.4f" % credit.credit} for credit in Credits.list_credits_by_user_id( @@ -101,9 +110,7 @@ async def get_users( UserGroupIdsModel( **{ **user.model_dump(), - "group_ids": [ - group.id for group in Groups.get_groups_by_member_id(user.id) - ], + "group_ids": [group.id for group in user_groups.get(user.id, [])], } ) for user in users @@ -115,8 +122,9 @@ async def get_users( @router.get("/all", response_model=UserInfoListResponse) async def get_all_users( user=Depends(get_admin_user), + db: Session = Depends(get_session), ): - user_data = Users.get_users() + user_data = Users.get_users(db=db) users = user_data["users"] credit_map = { credit.user_id: {"credit": "%.4f" % credit.credit} @@ -136,16 +144,13 @@ async def search_users( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): limit = PAGE_ITEM_COUNT page = max(1, page) skip = (page - 1) * limit - filter = {} - if query: - filter["query"] = query - filter = {} if query: filter["query"] = query @@ -154,7 +159,7 @@ async def search_users( if direction: filter["direction"] = direction - return Users.get_users(filter=filter, skip=skip, limit=limit) + return Users.get_users(filter=filter, skip=skip, limit=limit, db=db) ############################ @@ -163,8 +168,10 @@ async def search_users( @router.get("/groups") -async def get_user_groups(user=Depends(get_verified_user)): - return 
Groups.get_groups_by_member_id(user.id) +async def get_user_groups( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + return Groups.get_groups_by_member_id(user.id, db=db) ############################ @@ -173,9 +180,13 @@ async def get_user_groups(user=Depends(get_verified_user)): @router.get("/permissions") -async def get_user_permissisions(request: Request, user=Depends(get_verified_user)): +async def get_user_permissisions( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS + user.id, request.app.state.config.USER_PERMISSIONS, db=db ) return user_permissions @@ -242,6 +253,11 @@ class FeaturesPermissions(BaseModel): web_search: bool = True image_generation: bool = True code_interpreter: bool = True + memories: bool = True + + +class SettingsPermissions(BaseModel): + interface: bool = True class UserPermissions(BaseModel): @@ -249,6 +265,7 @@ class UserPermissions(BaseModel): sharing: SharingPermissions chat: ChatPermissions features: FeaturesPermissions + settings: SettingsPermissions @router.get("/default/permissions", response_model=UserPermissions) @@ -266,6 +283,9 @@ async def get_default_user_permissions(request: Request, user=Depends(get_admin_ "features": FeaturesPermissions( **request.app.state.config.USER_PERMISSIONS.get("features", {}) ), + "settings": SettingsPermissions( + **request.app.state.config.USER_PERMISSIONS.get("settings", {}) + ), } @@ -283,8 +303,10 @@ async def update_default_user_permissions( @router.get("/user/settings", response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_settings_by_session_user( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user.id, db=db) if user: return user.settings else: @@ 
-301,12 +323,17 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)): @router.post("/user/settings/update", response_model=UserSettings) async def update_user_settings_by_session_user( - request: Request, form_data: UserSettings, user=Depends(get_verified_user) + request: Request, + form_data: UserSettings, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): updated_user_settings = form_data.model_dump() + ui_settings = updated_user_settings.get("ui") if ( user.role != "admin" - and "toolServers" in updated_user_settings.get("ui").keys() + and ui_settings is not None + and "toolServers" in ui_settings.keys() and not has_permission( user.id, "features.direct_tool_servers", @@ -316,7 +343,7 @@ async def update_user_settings_by_session_user( # If the user is not an admin and does not have permission to use tool servers, remove the key updated_user_settings["ui"].pop("toolServers", None) - user = Users.update_user_settings_by_id(user.id, updated_user_settings) + user = Users.update_user_settings_by_id(user.id, updated_user_settings, db=db) if user: return user.settings else: @@ -332,8 +359,17 @@ async def update_user_settings_by_session_user( @router.get("/user/status") -async def get_user_status_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_status_by_session_user( + request: Request, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + if not request.app.state.config.ENABLE_USER_STATUS: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + user = Users.get_user_by_id(user.id, db=db) if user: return user else: @@ -350,11 +386,19 @@ async def get_user_status_by_session_user(user=Depends(get_verified_user)): @router.post("/user/status/update") async def update_user_status_by_session_user( - form_data: UserStatus, user=Depends(get_verified_user) + request: Request, + 
form_data: UserStatus, + user=Depends(get_verified_user), + db: Session = Depends(get_session), ): - user = Users.get_user_by_id(user.id) + if not request.app.state.config.ENABLE_USER_STATUS: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + user = Users.get_user_by_id(user.id, db=db) if user: - user = Users.update_user_status_by_id(user.id, form_data) + user = Users.update_user_status_by_id(user.id, form_data, db=db) return user else: raise HTTPException( @@ -369,8 +413,10 @@ async def update_user_status_by_session_user( @router.get("/user/info", response_model=Optional[dict]) -async def get_user_info_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_info_by_session_user( + user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user.id, db=db) if user: return user.info else: @@ -387,14 +433,16 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): @router.post("/user/info/update", response_model=Optional[dict]) async def update_user_info_by_session_user( - form_data: dict, user=Depends(get_verified_user) + form_data: dict, user=Depends(get_verified_user), db: Session = Depends(get_session) ): - user = Users.get_user_by_id(user.id) + user = Users.get_user_by_id(user.id, db=db) if user: if user.info is None: user.info = {} - user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) + user = Users.update_user_by_id( + user.id, {"info": {**user.info, **form_data}}, db=db + ) if user: return user.info else: @@ -424,7 +472,9 @@ class UserActiveResponse(UserStatus): @router.get("/{user_id}", response_model=UserActiveResponse) -async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): +async def get_user_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): # Check if user_id is a shared chat # If it is, get 
the user_id from the chat if user_id.startswith("shared-"): @@ -438,14 +488,14 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.USER_NOT_FOUND, ) - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if user: - groups = Groups.get_groups_by_member_id(user_id) + groups = Groups.get_groups_by_member_id(user_id, db=db) return UserActiveResponse( **{ **user.model_dump(), "groups": [{"id": group.id, "name": group.name} for group in groups], - "is_active": Users.is_user_active(user_id), + "is_active": Users.is_user_active(user_id, db=db), } ) else: @@ -456,8 +506,10 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): @router.get("/{user_id}/oauth/sessions") -async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)): - sessions = OAuthSessions.get_sessions_by_user_id(user_id) +async def get_user_oauth_sessions_by_id( + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + sessions = OAuthSessions.get_sessions_by_user_id(user_id, db=db) if sessions and len(sessions) > 0: return sessions else: @@ -473,8 +525,10 @@ async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_use @router.get("/{user_id}/profile/image") -async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): - user = Users.get_user_by_id(user_id) +async def get_user_profile_image_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): + user = Users.get_user_by_id(user_id, db=db) if user: if user.profile_image_url: # check if it's url or base64 @@ -488,11 +542,12 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u header, base64_data = user.profile_image_url.split(",", 1) image_data = base64.b64decode(base64_data) image_buffer = io.BytesIO(image_data) + media_type = header.split(";")[0].removeprefix("data:") return StreamingResponse(
image_buffer, - media_type="image/png", - headers={"Content-Disposition": "inline; filename=image.png"}, + media_type=media_type, + headers={"Content-Disposition": "inline"}, ) except Exception as e: pass @@ -510,9 +565,11 @@ async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_u @router.get("/{user_id}/active", response_model=dict) -async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)): +async def get_user_active_status_by_id( + user_id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) +): return { - "active": Users.is_user_active(user_id), + "active": Users.is_user_active(user_id, db=db), } @@ -527,10 +584,11 @@ async def update_user_by_id( user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user), + db: Session = Depends(get_session), ): # Prevent modification of the primary admin user by other admins try: - first_user = Users.get_first_user() + first_user = Users.get_first_user(db=db) if first_user: if user_id == first_user.id: if session_user.id != user_id: @@ -554,11 +612,11 @@ async def update_user_by_id( detail="Could not verify primary admin status.", ) - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if user: if form_data.email.lower() != user.email: - email_user = Users.get_user_by_email(form_data.email.lower()) + email_user = Users.get_user_by_email(form_data.email.lower(), db=db) if email_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -572,9 +630,9 @@ async def update_user_by_id( raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.password) - Auths.update_user_password_by_id(user_id, hashed) + Auths.update_user_password_by_id(user_id, hashed, db=db) - Auths.update_email_by_id(user_id, form_data.email.lower()) + Auths.update_email_by_id(user_id, form_data.email.lower(), db=db) updated_user = Users.update_user_by_id( user_id, { @@ -583,6 +641,7 @@ async def 
update_user_by_id( "email": form_data.email.lower(), "profile_image_url": form_data.profile_image_url, }, + db=db, ) if form_data.credit is not None: @@ -663,10 +722,12 @@ async def update_credit_by_user_id( @router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): +async def delete_user_by_id( + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): # Prevent deletion of the primary admin user try: - first_user = Users.get_first_user() + first_user = Users.get_first_user(db=db) if first_user and user_id == first_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -680,7 +741,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): ) if user.id != user_id: - result = Auths.delete_auth_by_id(user_id) + result = Auths.delete_auth_by_id(user_id, db=db) if result: return True @@ -703,5 +764,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): @router.get("/{user_id}/groups") -async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)): - return Groups.get_groups_by_member_id(user_id) +async def get_user_groups_by_id( + user_id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) +): + return Groups.get_groups_by_member_id(user_id, db=db) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 72b2761c64..67e04e69c3 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -246,7 +246,13 @@ def get_user_ids_from_room(room): active_session_ids = get_session_ids_from_room(room) active_user_ids = list( - set([SESSION_POOL.get(session_id)["id"] for session_id in active_session_ids]) + set( + [ + SESSION_POOL.get(session_id)["id"] + for session_id in active_session_ids + if SESSION_POOL.get(session_id) is not None + ] + ) ) return active_user_ids diff --git a/backend/open_webui/test/apps/webui/routers/test_chats.py 
b/backend/open_webui/test/apps/webui/routers/test_chats.py deleted file mode 100644 index a36a01fb14..0000000000 --- a/backend/open_webui/test/apps/webui/routers/test_chats.py +++ /dev/null @@ -1,236 +0,0 @@ -import uuid - -from test.util.abstract_integration_test import AbstractPostgresTest -from test.util.mock_user import mock_webui_user - - -class TestChats(AbstractPostgresTest): - BASE_PATH = "/api/v1/chats" - - def setup_class(cls): - super().setup_class() - - def setup_method(self): - super().setup_method() - from open_webui.models.chats import ChatForm, Chats - - self.chats = Chats - self.chats.insert_new_chat( - "2", - ChatForm( - **{ - "chat": { - "name": "chat1", - "description": "chat1 description", - "tags": ["tag1", "tag2"], - "history": {"currentId": "1", "messages": []}, - } - } - ), - ) - - def test_get_session_user_chat_list(self): - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/")) - assert response.status_code == 200 - first_chat = response.json()[0] - assert first_chat["id"] is not None - assert first_chat["title"] == "New Chat" - assert first_chat["created_at"] is not None - assert first_chat["updated_at"] is not None - - def test_delete_all_user_chats(self): - with mock_webui_user(id="2"): - response = self.fast_api_client.delete(self.create_url("/")) - assert response.status_code == 200 - assert len(self.chats.get_chats()) == 0 - - def test_get_user_chat_list_by_user_id(self): - with mock_webui_user(id="3"): - response = self.fast_api_client.get(self.create_url("/list/user/2")) - assert response.status_code == 200 - first_chat = response.json()[0] - assert first_chat["id"] is not None - assert first_chat["title"] == "New Chat" - assert first_chat["created_at"] is not None - assert first_chat["updated_at"] is not None - - def test_create_new_chat(self): - with mock_webui_user(id="2"): - response = self.fast_api_client.post( - self.create_url("/new"), - json={ - "chat": { - "name": "chat2", - 
"description": "chat2 description", - "tags": ["tag1", "tag2"], - } - }, - ) - assert response.status_code == 200 - data = response.json() - assert data["archived"] is False - assert data["chat"] == { - "name": "chat2", - "description": "chat2 description", - "tags": ["tag1", "tag2"], - } - assert data["user_id"] == "2" - assert data["id"] is not None - assert data["share_id"] is None - assert data["title"] == "New Chat" - assert data["updated_at"] is not None - assert data["created_at"] is not None - assert len(self.chats.get_chats()) == 2 - - def test_get_user_chats(self): - self.test_get_session_user_chat_list() - - def test_get_user_archived_chats(self): - self.chats.archive_all_chats_by_user_id("2") - from open_webui.internal.db import Session - - Session.commit() - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/all/archived")) - assert response.status_code == 200 - first_chat = response.json()[0] - assert first_chat["id"] is not None - assert first_chat["title"] == "New Chat" - assert first_chat["created_at"] is not None - assert first_chat["updated_at"] is not None - - def test_get_all_user_chats_in_db(self): - with mock_webui_user(id="4"): - response = self.fast_api_client.get(self.create_url("/all/db")) - assert response.status_code == 200 - assert len(response.json()) == 1 - - def test_get_archived_session_user_chat_list(self): - self.test_get_user_archived_chats() - - def test_archive_all_chats(self): - with mock_webui_user(id="2"): - response = self.fast_api_client.post(self.create_url("/archive/all")) - assert response.status_code == 200 - assert len(self.chats.get_archived_chats_by_user_id("2")) == 1 - - def test_get_shared_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - self.chats.update_chat_share_id_by_id(chat_id, chat_id) - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}")) - assert response.status_code == 200 - data = response.json() - 
assert data["id"] == chat_id - assert data["chat"] == { - "name": "chat1", - "description": "chat1 description", - "tags": ["tag1", "tag2"], - "history": {"currentId": "1", "messages": []}, - } - assert data["id"] == chat_id - assert data["share_id"] == chat_id - assert data["title"] == "New Chat" - - def test_get_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url(f"/{chat_id}")) - assert response.status_code == 200 - data = response.json() - assert data["id"] == chat_id - assert data["chat"] == { - "name": "chat1", - "description": "chat1 description", - "tags": ["tag1", "tag2"], - "history": {"currentId": "1", "messages": []}, - } - assert data["share_id"] is None - assert data["title"] == "New Chat" - assert data["user_id"] == "2" - - def test_update_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - with mock_webui_user(id="2"): - response = self.fast_api_client.post( - self.create_url(f"/{chat_id}"), - json={ - "chat": { - "name": "chat2", - "description": "chat2 description", - "tags": ["tag2", "tag4"], - "title": "Just another title", - } - }, - ) - assert response.status_code == 200 - data = response.json() - assert data["id"] == chat_id - assert data["chat"] == { - "name": "chat2", - "title": "Just another title", - "description": "chat2 description", - "tags": ["tag2", "tag4"], - "history": {"currentId": "1", "messages": []}, - } - assert data["share_id"] is None - assert data["title"] == "Just another title" - assert data["user_id"] == "2" - - def test_delete_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - with mock_webui_user(id="2"): - response = self.fast_api_client.delete(self.create_url(f"/{chat_id}")) - assert response.status_code == 200 - assert response.json() is True - - def test_clone_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - with mock_webui_user(id="2"): - response = 
self.fast_api_client.get(self.create_url(f"/{chat_id}/clone")) - - assert response.status_code == 200 - data = response.json() - assert data["id"] != chat_id - assert data["chat"] == { - "branchPointMessageId": "1", - "description": "chat1 description", - "history": {"currentId": "1", "messages": []}, - "name": "chat1", - "originalChatId": chat_id, - "tags": ["tag1", "tag2"], - "title": "Clone of New Chat", - } - assert data["share_id"] is None - assert data["title"] == "Clone of New Chat" - assert data["user_id"] == "2" - - def test_archive_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive")) - assert response.status_code == 200 - - chat = self.chats.get_chat_by_id(chat_id) - assert chat.archived is True - - def test_share_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - with mock_webui_user(id="2"): - response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share")) - assert response.status_code == 200 - - chat = self.chats.get_chat_by_id(chat_id) - assert chat.share_id is not None - - def test_delete_shared_chat_by_id(self): - chat_id = self.chats.get_chats()[0].id - share_id = str(uuid.uuid4()) - self.chats.update_chat_share_id_by_id(chat_id, share_id) - with mock_webui_user(id="2"): - response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share")) - assert response.status_code - - chat = self.chats.get_chat_by_id(chat_id) - assert chat.share_id is None diff --git a/backend/open_webui/test/util/abstract_integration_test.py b/backend/open_webui/test/util/abstract_integration_test.py deleted file mode 100644 index e8492befb6..0000000000 --- a/backend/open_webui/test/util/abstract_integration_test.py +++ /dev/null @@ -1,161 +0,0 @@ -import logging -import os -import time - -import docker -import pytest -from docker import DockerClient -from pytest_docker.plugin import get_docker_ip -from fastapi.testclient import 
TestClient -from sqlalchemy import text, create_engine - - -log = logging.getLogger(__name__) - - -def get_fast_api_client(): - from main import app - - with TestClient(app) as c: - return c - - -class AbstractIntegrationTest: - BASE_PATH = None - - def create_url(self, path="", query_params=None): - if self.BASE_PATH is None: - raise Exception("BASE_PATH is not set") - parts = self.BASE_PATH.split("/") - parts = [part.strip() for part in parts if part.strip() != ""] - path_parts = path.split("/") - path_parts = [part.strip() for part in path_parts if part.strip() != ""] - query_parts = "" - if query_params: - query_parts = "&".join( - [f"{key}={value}" for key, value in query_params.items()] - ) - query_parts = f"?{query_parts}" - return "/".join(parts + path_parts) + query_parts - - @classmethod - def setup_class(cls): - pass - - def setup_method(self): - pass - - @classmethod - def teardown_class(cls): - pass - - def teardown_method(self): - pass - - -class AbstractPostgresTest(AbstractIntegrationTest): - DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" - docker_client: DockerClient - - @classmethod - def _create_db_url(cls, env_vars_postgres: dict) -> str: - host = get_docker_ip() - user = env_vars_postgres["POSTGRES_USER"] - pw = env_vars_postgres["POSTGRES_PASSWORD"] - port = 8081 - db = env_vars_postgres["POSTGRES_DB"] - return f"postgresql://{user}:{pw}@{host}:{port}/{db}" - - @classmethod - def setup_class(cls): - super().setup_class() - try: - env_vars_postgres = { - "POSTGRES_USER": "user", - "POSTGRES_PASSWORD": "example", - "POSTGRES_DB": "openwebui", - } - cls.docker_client = docker.from_env() - cls.docker_client.containers.run( - "postgres:16.2", - detach=True, - environment=env_vars_postgres, - name=cls.DOCKER_CONTAINER_NAME, - ports={5432: ("0.0.0.0", 8081)}, - command="postgres -c log_statement=all", - ) - time.sleep(0.5) - - database_url = cls._create_db_url(env_vars_postgres) - os.environ["DATABASE_URL"] = database_url - 
retries = 10 - db = None - while retries > 0: - try: - from open_webui.config import OPEN_WEBUI_DIR - - db = create_engine(database_url, pool_pre_ping=True) - db = db.connect() - log.info("postgres is ready!") - break - except Exception as e: - log.warning(e) - time.sleep(3) - retries -= 1 - - if db: - # import must be after setting env! - cls.fast_api_client = get_fast_api_client() - db.close() - else: - raise Exception("Could not connect to Postgres") - except Exception as ex: - log.error(ex) - cls.teardown_class() - pytest.fail(f"Could not setup test environment: {ex}") - - def _check_db_connection(self): - from open_webui.internal.db import Session - - retries = 10 - while retries > 0: - try: - Session.execute(text("SELECT 1")) - Session.commit() - break - except Exception as e: - Session.rollback() - log.warning(e) - time.sleep(3) - retries -= 1 - - def setup_method(self): - super().setup_method() - self._check_db_connection() - - @classmethod - def teardown_class(cls) -> None: - super().teardown_class() - cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) - - def teardown_method(self): - from open_webui.internal.db import Session - - # rollback everything not yet committed - Session.commit() - - # truncate all tables - tables = [ - "auth", - "chat", - "chatidtag", - "document", - "memory", - "model", - "prompt", - "tag", - '"user"', - ] - for table in tables: - Session.execute(text(f"TRUNCATE TABLE {table}")) - Session.commit() diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py deleted file mode 100644 index 7ce64dffa9..0000000000 --- a/backend/open_webui/test/util/mock_user.py +++ /dev/null @@ -1,45 +0,0 @@ -from contextlib import contextmanager - -from fastapi import FastAPI - - -@contextmanager -def mock_webui_user(**kwargs): - from open_webui.routers.webui import app - - with mock_user(app, **kwargs): - yield - - -@contextmanager -def mock_user(app: FastAPI, **kwargs): - from 
open_webui.utils.auth import ( - get_current_user, - get_verified_user, - get_admin_user, - get_current_user_by_api_key, - ) - from open_webui.models.users import User - - def create_user(): - user_parameters = { - "id": "1", - "name": "John Doe", - "email": "john.doe@openwebui.com", - "role": "user", - "profile_image_url": "/user.png", - "last_active_at": 1627351200, - "updated_at": 1627351200, - "created_at": 162735120, - **kwargs, - } - return User(**user_parameters) - - app.dependency_overrides = { - get_current_user: create_user, - get_verified_user: create_user, - get_admin_user: create_user, - get_current_user_by_api_key: create_user, - } - yield - app.dependency_overrides = {} diff --git a/backend/open_webui/tools/__init__.py b/backend/open_webui/tools/__init__.py new file mode 100644 index 0000000000..112324b569 --- /dev/null +++ b/backend/open_webui/tools/__init__.py @@ -0,0 +1,6 @@ +""" +Open WebUI Tools Package. + +This package contains built-in tools that are automatically available +when native function calling is enabled. +""" diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py new file mode 100644 index 0000000000..eb3b7cfc9f --- /dev/null +++ b/backend/open_webui/tools/builtin.py @@ -0,0 +1,1671 @@ +""" +Built-in tools for Open WebUI. + +These tools are automatically available when native function calling is enabled. + +IMPORTANT: DO NOT IMPORT THIS MODULE DIRECTLY IN OTHER PARTS OF THE CODEBASE. 
+""" + +import json +import logging +import time +import asyncio +from typing import Optional + +from fastapi import Request + +from open_webui.models.users import UserModel +from open_webui.routers.retrieval import search_web as _search_web +from open_webui.retrieval.utils import get_content_from_url +from open_webui.routers.images import ( + image_generations, + image_edits, + CreateImageForm, + EditImageForm, +) +from open_webui.routers.memories import ( + query_memory, + add_memory as _add_memory, + update_memory_by_id, + QueryMemoryForm, + AddMemoryForm, + MemoryUpdateModel, +) +from open_webui.models.notes import Notes +from open_webui.models.chats import Chats +from open_webui.models.channels import Channels, ChannelMember, Channel +from open_webui.models.messages import Messages, Message +from open_webui.models.groups import Groups + +log = logging.getLogger(__name__) + +MAX_KNOWLEDGE_BASE_SEARCH_ITEMS = 10_000 + +# ============================================================================= +# TIME UTILITIES +# ============================================================================= + + +async def get_current_timestamp( + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get the current Unix timestamp in seconds. + + :return: JSON with current_timestamp (seconds) and current_iso (ISO format) + """ + try: + import datetime + + now = datetime.datetime.now(datetime.timezone.utc) + return json.dumps( + { + "current_timestamp": int(now.timestamp()), + "current_iso": now.isoformat(), + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"get_current_timestamp error: {e}") + return json.dumps({"error": str(e)}) + + +async def calculate_timestamp( + days_ago: int = 0, + weeks_ago: int = 0, + months_ago: int = 0, + years_ago: int = 0, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get the current Unix timestamp, optionally adjusted by days, weeks, months, or years. 
+ Use this to calculate timestamps for date filtering in search functions. + Examples: "last week" = weeks_ago=1, "3 days ago" = days_ago=3, "a year ago" = years_ago=1 + + :param days_ago: Number of days to subtract from current time (default: 0) + :param weeks_ago: Number of weeks to subtract from current time (default: 0) + :param months_ago: Number of months to subtract from current time (default: 0) + :param years_ago: Number of years to subtract from current time (default: 0) + :return: JSON with current_timestamp and calculated_timestamp (both in seconds) + """ + try: + import datetime + from dateutil.relativedelta import relativedelta + + now = datetime.datetime.now(datetime.timezone.utc) + current_ts = int(now.timestamp()) + + # Calculate the adjusted time + total_days = days_ago + (weeks_ago * 7) + adjusted = now - datetime.timedelta(days=total_days) + + # Handle months and years separately (variable length) + if months_ago > 0 or years_ago > 0: + adjusted = adjusted - relativedelta(months=months_ago, years=years_ago) + + adjusted_ts = int(adjusted.timestamp()) + + return json.dumps( + { + "current_timestamp": current_ts, + "current_iso": now.isoformat(), + "calculated_timestamp": adjusted_ts, + "calculated_iso": adjusted.isoformat(), + }, + ensure_ascii=False, + ) + except ImportError: + # Fallback without dateutil + import datetime + + now = datetime.datetime.now(datetime.timezone.utc) + current_ts = int(now.timestamp()) + total_days = days_ago + (weeks_ago * 7) + (months_ago * 30) + (years_ago * 365) + adjusted = now - datetime.timedelta(days=total_days) + adjusted_ts = int(adjusted.timestamp()) + return json.dumps( + { + "current_timestamp": current_ts, + "current_iso": now.isoformat(), + "calculated_timestamp": adjusted_ts, + "calculated_iso": adjusted.isoformat(), + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"calculate_timestamp error: {e}") + return json.dumps({"error": str(e)}) + + +# 
============================================================================= +# WEB SEARCH TOOLS +# ============================================================================= + + +async def search_web( + query: str, + count: int = 5, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search the public web for information. Best for current events, external references, + or topics not covered in internal documents. If knowledge base tools are available, + consider checking those first for internal information. + + :param query: The search query to look up + :param count: Number of results to return (default: 5) + :return: JSON with search results containing title, link, and snippet for each result + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + engine = __request__.app.state.config.WEB_SEARCH_ENGINE + user = UserModel(**__user__) if __user__ else None + + results = _search_web(__request__, engine, query, user) + + # Limit results + results = results[:count] if results else [] + + return json.dumps( + [{"title": r.title, "link": r.link, "snippet": r.snippet} for r in results], + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"search_web error: {e}") + return json.dumps({"error": str(e)}) + + +async def fetch_url( + url: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Fetch and extract the main text content from a web page URL. 
+ + :param url: The URL to fetch content from + :return: The extracted text content from the page + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + content, _ = await asyncio.to_thread(get_content_from_url, __request__, url) + + # Truncate if too long (avoid overwhelming context) + max_length = 50000 + if len(content) > max_length: + content = content[:max_length] + "\n\n[Content truncated...]" + + return content + except Exception as e: + log.exception(f"fetch_url error: {e}") + return json.dumps({"error": str(e)}) + + +# ============================================================================= +# IMAGE GENERATION TOOLS +# ============================================================================= + + +async def generate_image( + prompt: str, + __request__: Request = None, + __user__: dict = None, + __event_emitter__: callable = None, + __chat_id__: str = None, + __message_id__: str = None, +) -> str: + """ + Generate an image based on a text prompt. 
+ + :param prompt: A detailed description of the image to generate + :return: Confirmation that the image was generated, or an error message + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + user = UserModel(**__user__) if __user__ else None + + images = await image_generations( + request=__request__, + form_data=CreateImageForm(prompt=prompt), + user=user, + ) + + # Prepare file entries for the images + image_files = [{"type": "image", "url": img["url"]} for img in images] + + # Persist files to DB if chat context is available + if __chat_id__ and __message_id__ and images: + image_files = Chats.add_message_files_by_id_and_message_id( + __chat_id__, + __message_id__, + image_files, + ) + + # Emit the images to the UI if event emitter is available + if __event_emitter__ and image_files: + await __event_emitter__( + { + "type": "chat:message:files", + "data": { + "files": image_files, + }, + } + ) + # Return a message indicating the image is already displayed + return json.dumps( + { + "status": "success", + "message": "The image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.", + "images": images, + }, + ensure_ascii=False, + ) + + return json.dumps({"status": "success", "images": images}, ensure_ascii=False) + except Exception as e: + log.exception(f"generate_image error: {e}") + return json.dumps({"error": str(e)}) + + +async def edit_image( + prompt: str, + image_urls: list[str], + __request__: Request = None, + __user__: dict = None, + __event_emitter__: callable = None, + __chat_id__: str = None, + __message_id__: str = None, +) -> str: + """ + Edit existing images based on a text prompt. 
+ + :param prompt: A description of the changes to make to the images + :param image_urls: A list of URLs of the images to edit + :return: Confirmation that the images were edited, or an error message + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + user = UserModel(**__user__) if __user__ else None + + images = await image_edits( + request=__request__, + form_data=EditImageForm(prompt=prompt, image=image_urls), + user=user, + ) + + # Prepare file entries for the images + image_files = [{"type": "image", "url": img["url"]} for img in images] + + # Persist files to DB if chat context is available + if __chat_id__ and __message_id__ and images: + image_files = Chats.add_message_files_by_id_and_message_id( + __chat_id__, + __message_id__, + image_files, + ) + + # Emit the images to the UI if event emitter is available + if __event_emitter__ and image_files: + await __event_emitter__( + { + "type": "chat:message:files", + "data": { + "files": image_files, + }, + } + ) + # Return a message indicating the image is already displayed + return json.dumps( + { + "status": "success", + "message": "The edited image has been successfully generated and is already visible to the user in the chat. You do not need to display or embed the image again - just acknowledge that it has been created.", + "images": images, + }, + ensure_ascii=False, + ) + + return json.dumps({"status": "success", "images": images}, ensure_ascii=False) + except Exception as e: + log.exception(f"edit_image error: {e}") + return json.dumps({"error": str(e)}) + + +# ============================================================================= +# MEMORY TOOLS +# ============================================================================= + + +async def search_memories( + query: str, + count: int = 5, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search the user's stored memories for relevant information. 
+ + :param query: The search query to find relevant memories + :param count: Number of memories to return (default 5) + :return: JSON with matching memories and their dates + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + user = UserModel(**__user__) if __user__ else None + + results = await query_memory( + __request__, + QueryMemoryForm(content=query, k=count), + user, + ) + + if results and hasattr(results, "documents") and results.documents: + memories = [] + for doc_idx, doc in enumerate(results.documents[0]): + memory_id = None + if results.ids and results.ids[0]: + memory_id = results.ids[0][doc_idx] + created_at = "Unknown" + if results.metadatas and results.metadatas[0][doc_idx].get( + "created_at" + ): + created_at = time.strftime( + "%Y-%m-%d", + time.localtime(results.metadatas[0][doc_idx]["created_at"]), + ) + memories.append({"id": memory_id, "date": created_at, "content": doc}) + return json.dumps(memories, ensure_ascii=False) + else: + return json.dumps([]) + except Exception as e: + log.exception(f"search_memories error: {e}") + return json.dumps({"error": str(e)}) + + +async def add_memory( + content: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Store a new memory for the user. 
+ + :param content: The memory content to store + :return: Confirmation that the memory was stored + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + user = UserModel(**__user__) if __user__ else None + + memory = await _add_memory( + __request__, + AddMemoryForm(content=content), + user, + ) + + return json.dumps({"status": "success", "id": memory.id}, ensure_ascii=False) + except Exception as e: + log.exception(f"add_memory error: {e}") + return json.dumps({"error": str(e)}) + + +async def replace_memory_content( + memory_id: str, + content: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Update the content of an existing memory by its ID. + + :param memory_id: The ID of the memory to update + :param content: The new content for the memory + :return: Confirmation that the memory was updated + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + try: + user = UserModel(**__user__) if __user__ else None + + memory = await update_memory_by_id( + memory_id=memory_id, + request=__request__, + form_data=MemoryUpdateModel(content=content), + user=user, + ) + + return json.dumps( + {"status": "success", "id": memory.id, "content": memory.content}, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"replace_memory_content error: {e}") + return json.dumps({"error": str(e)}) + + +# ============================================================================= +# NOTES TOOLS +# ============================================================================= + + +async def search_notes( + query: str, + count: int = 5, + start_timestamp: Optional[int] = None, + end_timestamp: Optional[int] = None, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search the user's notes by title and content. 
+ + :param query: The search query to find matching notes + :param count: Maximum number of results to return (default: 5) + :param start_timestamp: Only include notes updated after this Unix timestamp (seconds) + :param end_timestamp: Only include notes updated before this Unix timestamp (seconds) + :return: JSON with matching notes containing id, title, and content snippet + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + result = Notes.search_notes( + user_id=user_id, + filter={ + "query": query, + "user_id": user_id, + "group_ids": user_group_ids, + "permission": "read", + }, + skip=0, + limit=count * 3, # Fetch more for filtering + ) + + # Convert timestamps to nanoseconds for comparison + start_ts = start_timestamp * 1_000_000_000 if start_timestamp else None + end_ts = end_timestamp * 1_000_000_000 if end_timestamp else None + + notes = [] + for note in result.items: + # Apply date filters (updated_at is in nanoseconds) + if start_ts and note.updated_at < start_ts: + continue + if end_ts and note.updated_at > end_ts: + continue + + # Extract a snippet from the markdown content + content_snippet = "" + if note.data and note.data.get("content", {}).get("md"): + md_content = note.data["content"]["md"] + lower_content = md_content.lower() + lower_query = query.lower() + idx = lower_content.find(lower_query) + if idx != -1: + start = max(0, idx - 50) + end = min(len(md_content), idx + len(query) + 100) + content_snippet = ( + ("..." if start > 0 else "") + + md_content[start:end] + + ("..." if end < len(md_content) else "") + ) + else: + content_snippet = md_content[:150] + ( + "..." 
if len(md_content) > 150 else "" + ) + + notes.append( + { + "id": note.id, + "title": note.title, + "snippet": content_snippet, + "updated_at": note.updated_at, + } + ) + + if len(notes) >= count: + break + + return json.dumps(notes, ensure_ascii=False) + except Exception as e: + log.exception(f"search_notes error: {e}") + return json.dumps({"error": str(e)}) + + +async def view_note( + note_id: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get the full content of a note by its ID. + + :param note_id: The ID of the note to retrieve + :return: JSON with the note's id, title, and full markdown content + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + note = Notes.get_note_by_id(note_id) + + if not note: + return json.dumps({"error": "Note not found"}) + + # Check access permission + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + from open_webui.utils.access_control import has_access + + if note.user_id != user_id and not has_access( + user_id, "read", note.access_control, user_group_ids + ): + return json.dumps({"error": "Access denied"}) + + # Extract markdown content + content = "" + if note.data and note.data.get("content", {}).get("md"): + content = note.data["content"]["md"] + + return json.dumps( + { + "id": note.id, + "title": note.title, + "content": content, + "updated_at": note.updated_at, + "created_at": note.created_at, + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"view_note error: {e}") + return json.dumps({"error": str(e)}) + + +async def write_note( + title: str, + content: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Create a new note with the given title and content. 
+ + :param title: The title of the new note + :param content: The markdown content for the note + :return: JSON with success status and new note id + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.notes import NoteForm + + user_id = __user__.get("id") + + form = NoteForm( + title=title, + data={"content": {"md": content}}, + access_control={}, # Private by default - only owner can access + ) + + new_note = Notes.insert_new_note(user_id, form) + + if not new_note: + return json.dumps({"error": "Failed to create note"}) + + return json.dumps( + { + "status": "success", + "id": new_note.id, + "title": new_note.title, + "created_at": new_note.created_at, + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"write_note error: {e}") + return json.dumps({"error": str(e)}) + + +async def replace_note_content( + note_id: str, + content: str, + title: Optional[str] = None, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Update the content of a note. Use this to modify task lists, add notes, or update content. 
+ + :param note_id: The ID of the note to update + :param content: The new markdown content for the note + :param title: Optional new title for the note + :return: JSON with success status and updated note info + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.notes import NoteUpdateForm + + note = Notes.get_note_by_id(note_id) + + if not note: + return json.dumps({"error": "Note not found"}) + + # Check write permission + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + from open_webui.utils.access_control import has_access + + if note.user_id != user_id and not has_access( + user_id, "write", note.access_control, user_group_ids + ): + return json.dumps({"error": "Write access denied"}) + + # Build update form + update_data = {"data": {"content": {"md": content}}} + if title: + update_data["title"] = title + + form = NoteUpdateForm(**update_data) + updated_note = Notes.update_note_by_id(note_id, form) + + if not updated_note: + return json.dumps({"error": "Failed to update note"}) + + return json.dumps( + { + "status": "success", + "id": updated_note.id, + "title": updated_note.title, + "updated_at": updated_note.updated_at, + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"replace_note_content error: {e}") + return json.dumps({"error": str(e)}) + + +# ============================================================================= +# CHATS TOOLS +# ============================================================================= + + +async def search_chats( + query: str, + count: int = 5, + start_timestamp: Optional[int] = None, + end_timestamp: Optional[int] = None, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search the user's previous chat conversations by title and message content. 
+ + :param query: The search query to find matching chats + :param count: Maximum number of results to return (default: 5) + :param start_timestamp: Only include chats updated after this Unix timestamp (seconds) + :param end_timestamp: Only include chats updated before this Unix timestamp (seconds) + :return: JSON with matching chats containing id, title, updated_at, and content snippet + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + + chats = Chats.get_chats_by_user_id_and_search_text( + user_id=user_id, + search_text=query, + include_archived=False, + skip=0, + limit=count * 3, # Fetch more for filtering + ) + + results = [] + for chat in chats: + # Apply date filters (updated_at is in seconds) + if start_timestamp and chat.updated_at < start_timestamp: + continue + if end_timestamp and chat.updated_at > end_timestamp: + continue + + # Find a matching message snippet + snippet = "" + messages = chat.chat.get("history", {}).get("messages", {}) + lower_query = query.lower() + + for msg_id, msg in messages.items(): + content = msg.get("content", "") + if isinstance(content, str) and lower_query in content.lower(): + idx = content.lower().find(lower_query) + start = max(0, idx - 50) + end = min(len(content), idx + len(query) + 100) + snippet = ( + ("..." if start > 0 else "") + + content[start:end] + + ("..." 
if end < len(content) else "") + ) + break + + if not snippet and lower_query in chat.title.lower(): + snippet = f"Title match: {chat.title}" + + results.append( + { + "id": chat.id, + "title": chat.title, + "snippet": snippet, + "updated_at": chat.updated_at, + } + ) + + if len(results) >= count: + break + + return json.dumps(results, ensure_ascii=False) + except Exception as e: + log.exception(f"search_chats error: {e}") + return json.dumps({"error": str(e)}) + + +async def view_chat( + chat_id: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get the full conversation history of a chat by its ID. + + :param chat_id: The ID of the chat to retrieve + :return: JSON with the chat's id, title, and messages + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + + chat = Chats.get_chat_by_id_and_user_id(chat_id, user_id) + + if not chat: + return json.dumps({"error": "Chat not found or access denied"}) + + # Extract messages from history + messages = [] + history = chat.chat.get("history", {}) + msg_dict = history.get("messages", {}) + + # Build message chain from currentId + current_id = history.get("currentId") + visited = set() + + while current_id and current_id not in visited: + visited.add(current_id) + msg = msg_dict.get(current_id) + if msg: + messages.append( + { + "role": msg.get("role", ""), + "content": msg.get("content", ""), + } + ) + current_id = msg.get("parentId") if msg else None + + # Reverse to get chronological order + messages.reverse() + + return json.dumps( + { + "id": chat.id, + "title": chat.title, + "messages": messages, + "updated_at": chat.updated_at, + "created_at": chat.created_at, + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"view_chat error: {e}") + return json.dumps({"error": str(e)}) + + +# 
============================================================================= +# CHANNELS TOOLS +# ============================================================================= + + +async def search_channels( + query: str, + count: int = 5, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search for channels by name and description that the user has access to. + + :param query: The search query to find matching channels + :param count: Maximum number of results to return (default: 5) + :return: JSON with matching channels containing id, name, description, and type + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + + # Get all channels the user has access to + all_channels = Channels.get_channels_by_user_id(user_id) + + # Filter by query + lower_query = query.lower() + matching_channels = [] + + for channel in all_channels: + name_match = lower_query in channel.name.lower() if channel.name else False + desc_match = lower_query in (channel.description or "").lower() + + if name_match or desc_match: + matching_channels.append( + { + "id": channel.id, + "name": channel.name, + "description": channel.description or "", + "type": channel.type or "public", + } + ) + + if len(matching_channels) >= count: + break + + return json.dumps(matching_channels, ensure_ascii=False) + except Exception as e: + log.exception(f"search_channels error: {e}") + return json.dumps({"error": str(e)}) + + +async def search_channel_messages( + query: str, + count: int = 10, + start_timestamp: Optional[int] = None, + end_timestamp: Optional[int] = None, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search for messages in channels the user is a member of, including thread replies. 
+ + :param query: The search query to find matching messages + :param count: Maximum number of results to return (default: 10) + :param start_timestamp: Only include messages created after this Unix timestamp (seconds) + :param end_timestamp: Only include messages created before this Unix timestamp (seconds) + :return: JSON with matching messages containing channel info, message content, and thread context + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + + # Get all channels the user has access to + user_channels = Channels.get_channels_by_user_id(user_id) + channel_ids = [c.id for c in user_channels] + channel_map = {c.id: c for c in user_channels} + + if not channel_ids: + return json.dumps([]) + + # Convert timestamps to nanoseconds (Message.created_at is in nanoseconds) + start_ts = start_timestamp * 1_000_000_000 if start_timestamp else None + end_ts = end_timestamp * 1_000_000_000 if end_timestamp else None + + # Search messages using the model method + matching_messages = Messages.search_messages_by_channel_ids( + channel_ids=channel_ids, + query=query, + start_timestamp=start_ts, + end_timestamp=end_ts, + limit=count, + ) + + results = [] + for msg in matching_messages: + channel = channel_map.get(msg.channel_id) + + # Extract snippet around the match + content = msg.content or "" + lower_query = query.lower() + idx = content.lower().find(lower_query) + if idx != -1: + start = max(0, idx - 50) + end = min(len(content), idx + len(query) + 100) + snippet = ( + ("..." if start > 0 else "") + + content[start:end] + + ("..." if end < len(content) else "") + ) + else: + snippet = content[:150] + ("..." 
if len(content) > 150 else "") + + results.append( + { + "channel_id": msg.channel_id, + "channel_name": channel.name if channel else "Unknown", + "message_id": msg.id, + "content_snippet": snippet, + "is_thread_reply": msg.parent_id is not None, + "parent_id": msg.parent_id, + "created_at": msg.created_at, + } + ) + + return json.dumps(results, ensure_ascii=False) + except Exception as e: + log.exception(f"search_channel_messages error: {e}") + return json.dumps({"error": str(e)}) + + +async def view_channel_message( + message_id: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get the full content of a channel message by its ID, including thread replies. + + :param message_id: The ID of the message to retrieve + :return: JSON with the message content, channel info, and thread replies if any + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + + message = Messages.get_message_by_id(message_id) + + if not message: + return json.dumps({"error": "Message not found"}) + + # Verify user has access to the channel + channel = Channels.get_channel_by_id(message.channel_id) + if not channel: + return json.dumps({"error": "Channel not found"}) + + # Check if user has access to the channel + user_channels = Channels.get_channels_by_user_id(user_id) + channel_ids = [c.id for c in user_channels] + + if message.channel_id not in channel_ids: + return json.dumps({"error": "Access denied"}) + + # Build response with thread information + result = { + "id": message.id, + "channel_id": message.channel_id, + "channel_name": channel.name, + "content": message.content, + "user_id": message.user_id, + "is_thread_reply": message.parent_id is not None, + "parent_id": message.parent_id, + "reply_count": message.reply_count, + "created_at": message.created_at, + "updated_at": message.updated_at, 
+ } + + # Include user info if available + if message.user: + result["user_name"] = message.user.name + + return json.dumps(result, ensure_ascii=False) + except Exception as e: + log.exception(f"view_channel_message error: {e}") + return json.dumps({"error": str(e)}) + + +async def view_channel_thread( + parent_message_id: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get all messages in a channel thread, including the parent message and all replies. + + :param parent_message_id: The ID of the parent message that started the thread + :return: JSON with the parent message and all thread replies in chronological order + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + user_id = __user__.get("id") + + # Get the parent message + parent_message = Messages.get_message_by_id(parent_message_id) + + if not parent_message: + return json.dumps({"error": "Message not found"}) + + # Verify user has access to the channel + channel = Channels.get_channel_by_id(parent_message.channel_id) + if not channel: + return json.dumps({"error": "Channel not found"}) + + user_channels = Channels.get_channels_by_user_id(user_id) + channel_ids = [c.id for c in user_channels] + + if parent_message.channel_id not in channel_ids: + return json.dumps({"error": "Access denied"}) + + # Get all thread replies + thread_replies = Messages.get_thread_replies_by_message_id(parent_message_id) + + # Build the response + messages = [] + + # Add parent message first + messages.append( + { + "id": parent_message.id, + "content": parent_message.content, + "user_id": parent_message.user_id, + "user_name": parent_message.user.name if parent_message.user else None, + "is_parent": True, + "created_at": parent_message.created_at, + } + ) + + # Add thread replies (reverse to get chronological order) + for reply in reversed(thread_replies): + 
messages.append( + { + "id": reply.id, + "content": reply.content, + "user_id": reply.user_id, + "user_name": reply.user.name if reply.user else None, + "is_parent": False, + "reply_to_id": reply.reply_to_id, + "created_at": reply.created_at, + } + ) + + return json.dumps( + { + "channel_id": parent_message.channel_id, + "channel_name": channel.name, + "thread_id": parent_message_id, + "message_count": len(messages), + "messages": messages, + }, + ensure_ascii=False, + ) + except Exception as e: + log.exception(f"view_channel_thread error: {e}") + return json.dumps({"error": str(e)}) + + +# ============================================================================= +# KNOWLEDGE BASE TOOLS +# ============================================================================= + + +async def list_knowledge_bases( + count: int = 10, + skip: int = 0, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + List the user's accessible knowledge bases. + + :param count: Maximum number of KBs to return (default: 10) + :param skip: Number of results to skip for pagination (default: 0) + :return: JSON with KBs containing id, name, description, and file_count + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.knowledge import Knowledges + + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + result = Knowledges.search_knowledge_bases( + user_id, + filter={ + "query": "", + "user_id": user_id, + "group_ids": user_group_ids, + }, + skip=skip, + limit=count, + ) + + knowledge_bases = [] + for knowledge_base in result.items: + files = Knowledges.get_files_by_id(knowledge_base.id) + file_count = len(files) if files else 0 + + knowledge_bases.append( + { + "id": knowledge_base.id, + "name": knowledge_base.name, + "description": 
knowledge_base.description or "", + "file_count": file_count, + "updated_at": knowledge_base.updated_at, + } + ) + + return json.dumps(knowledge_bases, ensure_ascii=False) + except Exception as e: + log.exception(f"list_knowledge_bases error: {e}") + return json.dumps({"error": str(e)}) + + +async def search_knowledge_bases( + query: str, + count: int = 5, + skip: int = 0, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search the user's accessible knowledge bases by name and description. + + :param query: The search query to find matching knowledge bases + :param count: Maximum number of results to return (default: 5) + :param skip: Number of results to skip for pagination (default: 0) + :return: JSON with matching KBs containing id, name, description, and file_count + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.knowledge import Knowledges + + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + result = Knowledges.search_knowledge_bases( + user_id, + filter={ + "query": query, + "user_id": user_id, + "group_ids": user_group_ids, + }, + skip=skip, + limit=count, + ) + + knowledge_bases = [] + for knowledge_base in result.items: + files = Knowledges.get_files_by_id(knowledge_base.id) + file_count = len(files) if files else 0 + + knowledge_bases.append( + { + "id": knowledge_base.id, + "name": knowledge_base.name, + "description": knowledge_base.description or "", + "file_count": file_count, + "updated_at": knowledge_base.updated_at, + } + ) + + return json.dumps(knowledge_bases, ensure_ascii=False) + except Exception as e: + log.exception(f"search_knowledge_bases error: {e}") + return json.dumps({"error": str(e)}) + + +async def search_knowledge_files( + query: str, + knowledge_id: Optional[str] = None, + count: 
int = 5, + skip: int = 0, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search files across knowledge bases the user has access to. + + :param query: The search query to find matching files by filename + :param knowledge_id: Optional KB id to limit search to a specific knowledge base + :param count: Maximum number of results to return (default: 5) + :param skip: Number of results to skip for pagination (default: 0) + :return: JSON with matching files containing id, filename, and updated_at + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.knowledge import Knowledges + + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + if knowledge_id: + result = Knowledges.search_files_by_id( + knowledge_id=knowledge_id, + user_id=user_id, + filter={"query": query}, + skip=skip, + limit=count, + ) + else: + result = Knowledges.search_knowledge_files( + filter={ + "query": query, + "user_id": user_id, + "group_ids": user_group_ids, + }, + skip=skip, + limit=count, + ) + + files = [] + for file in result.items: + file_info = { + "id": file.id, + "filename": file.filename, + "updated_at": file.updated_at, + } + if hasattr(file, "collection") and file.collection: + file_info["knowledge_id"] = file.collection.get("id", "") + file_info["knowledge_name"] = file.collection.get("name", "") + files.append(file_info) + + return json.dumps(files, ensure_ascii=False) + except Exception as e: + log.exception(f"search_knowledge_files error: {e}") + return json.dumps({"error": str(e)}) + + +async def view_knowledge_file( + file_id: str, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Get the full content of a file from a knowledge base. 
+ + :param file_id: The ID of the file to retrieve + :return: JSON with the file's id, filename, and full text content + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.files import Files + from open_webui.models.knowledge import Knowledges + from open_webui.utils.access_control import has_access + + user_id = __user__.get("id") + user_role = __user__.get("role", "user") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + file = Files.get_file_by_id(file_id) + if not file: + return json.dumps({"error": "File not found"}) + + # Check access via any KB containing this file + knowledges = Knowledges.get_knowledges_by_file_id(file_id) + has_knowledge_access = False + knowledge_info = None + + for knowledge_base in knowledges: + if ( + user_role == "admin" + or knowledge_base.user_id == user_id + or has_access( + user_id, "read", knowledge_base.access_control, user_group_ids + ) + ): + has_knowledge_access = True + knowledge_info = {"id": knowledge_base.id, "name": knowledge_base.name} + break + + if not has_knowledge_access: + if file.user_id != user_id and user_role != "admin": + return json.dumps({"error": "Access denied"}) + + content = "" + if file.data: + content = file.data.get("content", "") + + result = { + "id": file.id, + "filename": file.filename, + "content": content, + "updated_at": file.updated_at, + "created_at": file.created_at, + } + if knowledge_info: + result["knowledge_id"] = knowledge_info["id"] + result["knowledge_name"] = knowledge_info["name"] + + return json.dumps(result, ensure_ascii=False) + except Exception as e: + log.exception(f"view_knowledge_file error: {e}") + return json.dumps({"error": str(e)}) + + +async def query_knowledge_files( + query: str, + knowledge_ids: Optional[list[str]] = None, + count: int = 5, + __request__: Request = None, + 
__user__: dict = None, + __model_knowledge__: list[dict] = None, +) -> str: + """ + Search knowledge base files using semantic/vector search. This should be your first + choice for finding information before searching the web. Searches across collections (KBs), + individual files, and notes that the user has access to. + + :param query: The search query to find semantically relevant content + :param knowledge_ids: Optional list of KB ids to limit search to specific knowledge bases + :param count: Maximum number of results to return (default: 5) + :return: JSON with relevant chunks containing content, source filename, and relevance score + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + from open_webui.models.knowledge import Knowledges + from open_webui.models.files import Files + from open_webui.models.notes import Notes + from open_webui.retrieval.utils import query_collection + from open_webui.utils.access_control import has_access + + user_id = __user__.get("id") + user_role = __user__.get("role", "user") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + + embedding_function = __request__.app.state.EMBEDDING_FUNCTION + if not embedding_function: + return json.dumps({"error": "Embedding function not configured"}) + + collection_names = [] + note_results = [] # Notes aren't vectorized, handle separately + + # If model has attached knowledge, use those + if __model_knowledge__: + for item in __model_knowledge__: + item_type = item.get("type") + item_id = item.get("id") + + if item_type == "collection": + # Knowledge base - use KB ID as collection name + knowledge = Knowledges.get_knowledge_by_id(item_id) + if knowledge and ( + user_role == "admin" + or knowledge.user_id == user_id + or has_access( + user_id, "read", knowledge.access_control, user_group_ids + ) + ): + 
collection_names.append(item_id) + + elif item_type == "file": + # Individual file - use file-{id} as collection name + file = Files.get_file_by_id(item_id) + if file and (user_role == "admin" or file.user_id == user_id): + collection_names.append(f"file-{item_id}") + + elif item_type == "note": + # Note - always return full content as context + note = Notes.get_note_by_id(item_id) + if note and ( + user_role == "admin" + or note.user_id == user_id + or has_access(user_id, "read", note.access_control) + ): + content = note.data.get("content", {}).get("md", "") + note_results.append( + { + "content": content, + "source": note.title, + "note_id": note.id, + "type": "note", + } + ) + + elif knowledge_ids: + # User specified specific KBs + for knowledge_id in knowledge_ids: + knowledge = Knowledges.get_knowledge_by_id(knowledge_id) + if knowledge and ( + user_role == "admin" + or knowledge.user_id == user_id + or has_access( + user_id, "read", knowledge.access_control, user_group_ids + ) + ): + collection_names.append(knowledge_id) + else: + # No model knowledge and no specific IDs - search all accessible KBs + result = Knowledges.search_knowledge_bases( + user_id, + filter={ + "query": "", + "user_id": user_id, + "group_ids": user_group_ids, + }, + skip=0, + limit=50, + ) + collection_names = [knowledge_base.id for knowledge_base in result.items] + + chunks = [] + + # Add note results first + chunks.extend(note_results) + + # Query vector collections if any + if collection_names: + query_results = await query_collection( + collection_names=collection_names, + queries=[query], + embedding_function=embedding_function, + k=count, + ) + + if query_results and "documents" in query_results: + documents = query_results.get("documents", [[]])[0] + metadatas = query_results.get("metadatas", [[]])[0] + distances = query_results.get("distances", [[]])[0] + + for idx, doc in enumerate(documents): + chunk_info = { + "content": doc, + "source": metadatas[idx].get( + "source", 
metadatas[idx].get("name", "Unknown") + ), + "file_id": metadatas[idx].get("file_id", ""), + } + if idx < len(distances): + chunk_info["distance"] = distances[idx] + chunks.append(chunk_info) + + # Limit to requested count + chunks = chunks[:count] + + return json.dumps(chunks, ensure_ascii=False) + except Exception as e: + log.exception(f"query_knowledge_files error: {e}") + return json.dumps({"error": str(e)}) + + +async def query_knowledge_bases( + query: str, + count: int = 5, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search knowledge bases by semantic similarity to query. + Finds KBs whose name/description match the meaning of your query. + Use this to discover relevant knowledge bases before querying their files. + + :param query: Natural language query describing what you're looking for + :param count: Maximum results (default: 5) + :return: JSON with matching KBs (id, name, description, similarity) + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + import heapq + from open_webui.models.knowledge import Knowledges + from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION + from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT + + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + query_embedding = await __request__.app.state.EMBEDDING_FUNCTION(query) + + # Min-heap of (distance, knowledge_base_id) - only holds top `count` results + top_results_heap = [] + seen_ids = set() + page_offset = 0 + page_size = 100 + + while True: + accessible_knowledge_bases = Knowledges.search_knowledge_bases( + user_id, + filter={"user_id": user_id, "group_ids": user_group_ids}, + skip=page_offset, + limit=page_size, + ) + + if not accessible_knowledge_bases.items: + break + + accessible_ids = [kb.id for kb in 
accessible_knowledge_bases.items] + + search_results = VECTOR_DB_CLIENT.search( + collection_name=KNOWLEDGE_BASES_COLLECTION, + vectors=[query_embedding], + filter={"knowledge_base_id": {"$in": accessible_ids}}, + limit=count, + ) + + if search_results and search_results.ids and search_results.ids[0]: + result_ids = search_results.ids[0] + result_distances = ( + search_results.distances[0] + if search_results.distances + else [0] * len(result_ids) + ) + + for knowledge_base_id, distance in zip(result_ids, result_distances): + if knowledge_base_id in seen_ids: + continue + seen_ids.add(knowledge_base_id) + + if len(top_results_heap) < count: + heapq.heappush(top_results_heap, (distance, knowledge_base_id)) + elif distance > top_results_heap[0][0]: + heapq.heapreplace( + top_results_heap, (distance, knowledge_base_id) + ) + + page_offset += page_size + if len(accessible_knowledge_bases.items) < page_size: + break + if page_offset >= MAX_KNOWLEDGE_BASE_SEARCH_ITEMS: + break + + # Sort by distance descending (best first) and fetch KB details + sorted_results = sorted(top_results_heap, key=lambda x: x[0], reverse=True) + + matching_knowledge_bases = [] + for distance, knowledge_base_id in sorted_results: + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_base_id) + if knowledge_base: + matching_knowledge_bases.append( + { + "id": knowledge_base.id, + "name": knowledge_base.name, + "description": knowledge_base.description or "", + "similarity": round(distance, 4), + } + ) + + return json.dumps(matching_knowledge_bases, ensure_ascii=False) + + except Exception as e: + log.exception(f"query_knowledge_bases error: {e}") + return json.dumps({"error": str(e)}) diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 97d0b41491..7784f6efd7 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -28,6 +28,7 @@ def fill_missing_permissions( def get_permissions( 
user_id: str, default_permissions: Dict[str, Any], + db: Optional[Any] = None, ) -> Dict[str, Any]: """ Get all permissions for a user by combining the permissions of all groups the user is a member of. @@ -53,7 +54,7 @@ def combine_permissions( ) # Use the most permissive value (True > False) return permissions - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = Groups.get_groups_by_member_id(user_id, db=db) # Deep copy default permissions to avoid modifying the original dict permissions = json.loads(json.dumps(default_permissions)) @@ -72,6 +73,7 @@ def has_permission( user_id: str, permission_key: str, default_permissions: Dict[str, Any] = {}, + db: Optional[Any] = None, ) -> bool: """ Check if a user has a specific permission by checking the group permissions @@ -92,7 +94,7 @@ def get_permission(permissions: Dict[str, Any], keys: List[str]) -> bool: permission_hierarchy = permission_key.split(".") # Retrieve user group permissions - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = Groups.get_groups_by_member_id(user_id, db=db) for group in user_groups: if get_permission(group.permissions or {}, permission_hierarchy): @@ -127,6 +129,7 @@ def has_access( access_control: Optional[dict] = None, user_group_ids: Optional[Set[str]] = None, strict: bool = True, + db: Optional[Any] = None, ) -> bool: if access_control is None: if strict: @@ -135,7 +138,7 @@ def has_access( return True if user_group_ids is None: - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = Groups.get_groups_by_member_id(user_id, db=db) user_group_ids = {group.id for group in user_groups} permitted_ids = get_permitted_group_and_user_ids(type, access_control) @@ -152,10 +155,10 @@ def has_access( # Get all users with access to a resource def get_users_with_access( - type: str = "write", access_control: Optional[dict] = None + type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None ) -> list[UserModel]: if 
access_control is None: - result = Users.get_users(filter={"roles": ["!pending"]}) + result = Users.get_users(filter={"roles": ["!pending"]}, db=db) return result.get("users", []) permitted_ids = get_permitted_group_and_user_ids(type, access_control) @@ -167,8 +170,8 @@ def get_users_with_access( user_ids_with_access = set(permitted_user_ids) - group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids) + group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db) for user_ids in group_user_ids_map.values(): user_ids_with_access.update(user_ids) - return Users.get_users_by_user_ids(list(user_ids_with_access)) + return Users.get_users_by_user_ids(list(user_ids_with_access), db=db) diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 3252b162ea..7c683c3706 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -27,6 +27,8 @@ from open_webui.utils.access_control import has_permission from open_webui.models.users import Users +from open_webui.models.auths import Auths + from open_webui.constants import ERROR_MESSAGES @@ -51,6 +53,8 @@ from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session from open_webui.utils.redis import get_redis_connection, get_sentinels_from_env @@ -246,6 +250,10 @@ async def is_valid_token(request, decoded) -> bool: async def invalidate_token(request, token): decoded = decode_token(token) + # If token is invalid/expired, nothing to revoke + if not decoded: + return + # Require Redis to store revoked tokens if request.app.state.redis: jti = decoded.get("jti") @@ -289,6 +297,7 @@ async def get_current_user( response: Response, background_tasks: BackgroundTasks, auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), + db: Session = 
Depends(get_session), ): token = None @@ -303,7 +312,7 @@ async def get_current_user( # auth by api key if token.startswith("sk-"): - user = get_current_user_by_api_key(request, token) + user = get_current_user_by_api_key(request, token, db=db) # Add user info to current span current_span = trace.get_current_span() @@ -332,7 +341,7 @@ async def get_current_user( detail="Invalid token", ) - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(data["id"], db=db) if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -382,8 +391,8 @@ async def get_current_user( raise e -def get_current_user_by_api_key(request, api_key: str): - user = Users.get_user_by_api_key(api_key) +def get_current_user_by_api_key(request, api_key: str, db: Session = None): + user = Users.get_user_by_api_key(api_key, db=db) if user is None: raise HTTPException( @@ -411,7 +420,7 @@ def get_current_user_by_api_key(request, api_key: str): current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.auth.type", "api_key") - Users.update_last_active_by_id(user.id) + Users.update_last_active_by_id(user.id, db=db) return user @@ -433,6 +442,40 @@ def get_admin_user(user=Depends(get_current_user)): return user +def create_admin_user(email: str, password: str, name: str = "Admin"): + """ + Create an admin user from environment variables. + Used for headless/automated deployments. + Returns the created user or None if creation failed. 
+ """ + + if not email or not password: + return None + + if Users.has_users(): + log.debug("Users already exist, skipping admin creation") + return None + + log.info(f"Creating admin account from environment variables: {email}") + try: + hashed = get_password_hash(password) + user = Auths.insert_new_auth( + email=email.lower(), + password=hashed, + name=name, + role="admin", + ) + if user: + log.info(f"Admin account created successfully: {email}") + return user + else: + log.error("Failed to create admin account from environment variables") + return None + except Exception as e: + log.error(f"Error creating admin account: {e}") + return None + + verify_email_template = """
diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 15876a3fd3..1424bdcf16 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -168,6 +168,7 @@ async def generate_chat_completion( form_data: dict, user: Any, bypass_filter: bool = False, + bypass_system_prompt: bool = False, ): check_credit_by_user_id(user_id=user.id, form_data=form_data) @@ -241,7 +242,11 @@ async def stream_wrapper(stream): yield chunk response = await generate_chat_completion( - request, form_data, user, bypass_filter=True + request, + form_data, + user, + bypass_filter=True, + bypass_system_prompt=bypass_system_prompt, ) return StreamingResponse( stream_wrapper(response.body_iterator), @@ -252,7 +257,11 @@ async def stream_wrapper(stream): return { **( await generate_chat_completion( - request, form_data, user, bypass_filter=True + request, + form_data, + user, + bypass_filter=True, + bypass_system_prompt=bypass_system_prompt, ) ), "selected_model_id": selected_model_id, @@ -272,6 +281,7 @@ async def stream_wrapper(stream): form_data=form_data, user=user, bypass_filter=bypass_filter, + bypass_system_prompt=bypass_system_prompt, ) if form_data.get("stream"): response.headers["content-type"] = "text/event-stream" @@ -298,6 +308,7 @@ async def stream_wrapper(stream): form_data=form_data, user=user, bypass_filter=bypass_filter, + bypass_system_prompt=bypass_system_prompt, ) @@ -325,7 +336,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): try: data = await process_pipeline_outlet_filter(request, data, user, models) except Exception as e: - return Exception(f"Error: {e}") + raise Exception(f"Error: {e}") metadata = { "chat_id": data["chat_id"], @@ -361,7 +372,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): ) return result except Exception as e: - return Exception(f"Error: {e}") + raise Exception(f"Error: {e}") async def chat_action(request: Request, action_id: str, form_data: 
dict, user: Any): @@ -457,6 +468,6 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A data = action(**params) except Exception as e: - return Exception(f"Error: {e}") + raise Exception(f"Error: {e}") return data diff --git a/backend/open_webui/utils/db/access_control.py b/backend/open_webui/utils/db/access_control.py index d2e6151e5b..75bd337f8c 100644 --- a/backend/open_webui/utils/db/access_control.py +++ b/backend/open_webui/utils/db/access_control.py @@ -28,9 +28,7 @@ def has_permission(db, DocumentModel, query, filter: dict, permission: str = "re for gid in group_ids: if dialect_name == "sqlite": group_read_conditions.append( - DocumentModel.access_control["read"]["group_ids"].contains( - [gid] - ) + DocumentModel.access_control["read"]["group_ids"].contains(gid) ) elif dialect_name == "postgresql": group_read_conditions.append( @@ -63,9 +61,7 @@ def has_permission(db, DocumentModel, query, filter: dict, permission: str = "re for gid in group_ids: if dialect_name == "sqlite": group_write_conditions.append( - DocumentModel.access_control["write"]["group_ids"].contains( - [gid] - ) + DocumentModel.access_control["write"]["group_ids"].contains(gid) ) elif dialect_name == "postgresql": group_write_conditions.append( @@ -111,9 +107,7 @@ def has_permission(db, DocumentModel, query, filter: dict, permission: str = "re for gid in group_ids: if dialect_name == "sqlite": group_conditions.append( - DocumentModel.access_control[permission]["group_ids"].contains( - [gid] - ) + DocumentModel.access_control[permission]["group_ids"].contains(gid) ) elif dialect_name == "postgresql": group_conditions.append( diff --git a/backend/open_webui/utils/groups.py b/backend/open_webui/utils/groups.py index 0f15f27e2c..26fc5d8434 100644 --- a/backend/open_webui/utils/groups.py +++ b/backend/open_webui/utils/groups.py @@ -7,6 +7,7 @@ def apply_default_group_assignment( default_group_id: str, user_id: str, + db=None, ) -> None: """ Apply default group 
assignment to a user if default_group_id is provided. @@ -17,7 +18,7 @@ def apply_default_group_assignment( """ if default_group_id: try: - Groups.add_users_to_group(default_group_id, [user_id]) + Groups.add_users_to_group(default_group_id, [user_id], db=db) except Exception as e: log.error( f"Failed to add user {user_id} to default group {default_group_id}: {e}" diff --git a/backend/open_webui/utils/images/comfyui.py b/backend/open_webui/utils/images/comfyui.py index c1293a0fc6..3c402cbc17 100644 --- a/backend/open_webui/utils/images/comfyui.py +++ b/backend/open_webui/utils/images/comfyui.py @@ -64,8 +64,8 @@ def get_history(prompt_id, base_url, api_key): return json.loads(response.read()) -def get_images(ws, prompt, client_id, base_url, api_key): - prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"] +def get_images(ws, workflow, client_id, base_url, api_key): + prompt_id = queue_prompt(workflow, client_id, base_url, api_key)["prompt_id"] output_images = [] while True: out = ws.recv() @@ -79,9 +79,12 @@ def get_images(ws, prompt, client_id, base_url, api_key): continue # previews are binary data history = get_history(prompt_id, base_url, api_key)[prompt_id] - for o in history["outputs"]: - for node_id in history["outputs"]: - node_output = history["outputs"][node_id] + for node_id in history["outputs"]: + node_output = history["outputs"][node_id] + if node_id in workflow and workflow[node_id].get("class_type") in [ + "SaveImage", + "PreviewImage", + ]: if "images" in node_output: for image in node_output["images"]: url = get_image_url( diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py index 540527bf82..4af3064235 100644 --- a/backend/open_webui/utils/logger.py +++ b/backend/open_webui/utils/logger.py @@ -6,11 +6,13 @@ from loguru import logger from opentelemetry import trace from open_webui.env import ( - AUDIT_UVICORN_LOGGER_NAMES, + ENABLE_AUDIT_STDOUT, + ENABLE_AUDIT_LOGS_FILE, + AUDIT_LOGS_FILE_PATH, 
AUDIT_LOG_FILE_ROTATION_SIZE, AUDIT_LOG_LEVEL, - AUDIT_LOGS_FILE_PATH, GLOBAL_LOG_LEVEL, + AUDIT_UVICORN_LOGGER_NAMES, ENABLE_OTEL, ENABLE_OTEL_LOGS, ) @@ -130,9 +132,11 @@ def start_logger(): sys.stdout, level=GLOBAL_LOG_LEVEL, format=stdout_format, - filter=lambda record: "auditable" not in record["extra"], + filter=lambda record: ( + "auditable" not in record["extra"] if ENABLE_AUDIT_STDOUT else True + ), ) - if AUDIT_LOG_LEVEL != "NONE": + if AUDIT_LOG_LEVEL != "NONE" and ENABLE_AUDIT_LOGS_FILE: try: logger.add( AUDIT_LOGS_FILE_PATH, diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 05f76b1724..fe2d7e5dc1 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -44,6 +44,7 @@ process_web_search, SearchForm, ) +from open_webui.utils.tools import get_builtin_tools from open_webui.routers.images import ( image_generations, CreateImageForm, @@ -92,7 +93,11 @@ convert_logit_bias_input_to_json, get_content_from_message, ) -from open_webui.utils.tools import get_tools, get_updated_tool_function +from open_webui.utils.tools import ( + get_tools, + get_updated_tool_function, + has_tool_server_access, +) from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.filter import ( get_sorted_filter_ids, @@ -118,6 +123,7 @@ BYPASS_MODEL_ACCESS_CONTROL, ENABLE_REALTIME_CHAT_SAVE, ENABLE_QUERIES_CACHE, + RAG_SYSTEM_CONTEXT, ) from open_webui.constants import TASKS @@ -140,6 +146,196 @@ DEFAULT_CODE_INTERPRETER_TAGS = [("