-
Notifications
You must be signed in to change notification settings - Fork 7
add support of doubao image #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,392 @@ | ||
| """ | ||
| title: Doubao Image | ||
| description: Image generation with Doubao Seedream 5.0 | ||
| author: yuzukumo | ||
| git_url: https://github.com/OVINC-CN/OpenWebUIPlugin.git | ||
| version: 0.1.0 | ||
| licence: MIT | ||
| """ | ||
|
|
||
| import base64 | ||
| import binascii | ||
| import io | ||
| import json | ||
| import logging | ||
| import mimetypes | ||
| import time | ||
| import uuid | ||
| from typing import Any, AsyncIterable, List, Literal, Optional, Tuple | ||
|
|
||
| import httpx | ||
| from fastapi import BackgroundTasks, Request, UploadFile | ||
| from httpx import Response | ||
| from open_webui.env import GLOBAL_LOG_LEVEL | ||
| from open_webui.models.users import UserModel, Users | ||
| from open_webui.routers.files import get_file_content_by_id, upload_file | ||
| from pydantic import BaseModel, Field | ||
| from starlette.datastructures import Headers | ||
| from starlette.responses import StreamingResponse | ||
|
|
||
# Module-level logger; follow OpenWebUI's globally configured log level.
logger = logging.getLogger(__name__)
logger.setLevel(GLOBAL_LOG_LEVEL)
|
|
||
|
|
||
class APIException(Exception):
    """Raised when the upstream image API answers with a non-200 status."""

    def __init__(self, status: int, content: str, response: Response):
        self._status = status
        self._content = content
        self._response = response

    def __str__(self) -> str:
        """Prefer the API's structured error message, then the HTTP error text."""
        # First attempt: structured body of the form {"error": {"message": ...}}.
        try:
            parsed = json.loads(self._content)
            return parsed.get("error", {}).get("message", self._content)
        except Exception:
            pass
        # Second attempt: let the HTTP client phrase the status error.
        try:
            self._response.raise_for_status()
        except Exception as err:
            return str(err)
        return "Unknown API error"
|
|
||
|
|
||
class Pipe:
    """OpenWebUI pipe that generates images through the Doubao (Volcano Engine Ark) API."""

    class Valves(BaseModel):
        # Admin-level settings shared by all users of this pipe.
        base_url: str = Field(
            default="https://ark.cn-beijing.volces.com/api/v3",
            title="Base URL",
        )
        api_key: str = Field(default="", title="API Key")
        # Request timeout in seconds.
        timeout: int = Field(default=600, title="请求超时(秒)")
        # Optional proxy URL; empty string means no proxy.
        proxy: str = Field(default="", title="代理地址")
        # Comma-separated list of selectable model ids.
        models: str = Field(
            default="doubao-seedream-5-0-260128,doubao-seedream-5-0-lite-260128",
            title="支持模型列表",
            description="多个模型用逗号分隔",
        )

    class UserValves(BaseModel):
        # Per-user settings. n = number of images to generate (1-10).
        n: int = Field(default=1, title="输出张数", ge=1, le=10)
        size: str = Field(
            default="2K",
            title="尺寸",
            description="支持 2K、3K 或像素尺寸,例如 2048x2048",
        )
        output_format: Literal["png", "jpeg"] = Field(default="png", title="输出格式")
        # Whether to stamp an AI watermark onto generated images.
        watermark: bool = Field(default=False, title="AI 水印")
        # Whether to let the model use web search while generating.
        enable_web_search: bool = Field(default=False, title="启用联网搜索")
|
|
||
    def __init__(self):
        # Instantiate admin valves with their declared defaults.
        self.valves = self.Valves()
|
|
||
| def pipes(self) -> List[dict]: | ||
| return [ | ||
| {"id": model.strip(), "name": model.strip()} for model in self.valves.models.split(",") if model.strip() | ||
| ] | ||
|
|
||
| async def pipe(self, body: dict, __user__: dict, __request__: Request) -> StreamingResponse: | ||
| return StreamingResponse( | ||
| self._pipe(body=body, __user__=__user__, __request__=__request__), | ||
| media_type="text/event-stream", | ||
| ) | ||
|
|
||
    async def _pipe(self, body: dict, __user__: dict, __request__: Request) -> AsyncIterable[str]:
        """Core generator: call the Doubao image API and yield SSE `data:` lines.

        Raises:
            ValueError: when the OpenWebUI user id cannot be resolved.
            APIException: when the upstream API answers with a non-200 status.
        """
        user = Users.get_user_by_id(__user__["id"])
        if not user:
            raise ValueError("user not found")

        model, payload = await self._build_payload(user=user, body=body, user_valves=__user__.get("valves", {}))

        async with httpx.AsyncClient(
            base_url=self.valves.base_url,
            headers={"Authorization": f"Bearer {self.valves.api_key}"},
            # "" is falsy -> pass None so no explicit proxy is configured.
            proxy=self.valves.proxy or None,
            trust_env=True,
            timeout=self.valves.timeout,
        ) as client:
            response = await client.post(**payload)
            if response.status_code != 200:
                raise APIException(
                    status=response.status_code,
                    content=response.text,
                    response=response,
                )

        response_json = response.json()
        content, usage = self._parse_response_content(
            response_json=response_json,
            user=user,
            __request__=__request__,
            output_format=payload["json"]["output_format"],
        )
        if body.get("stream"):
            # Streaming: emit the content chunk first, then a separate usage chunk.
            yield self._format_data(is_stream=True, model=model, content=content, usage=None)
            if usage:
                yield self._format_data(is_stream=True, model=model, content=None, usage=usage)
        else:
            # Non-streaming: one payload carrying both content and usage.
            yield self._format_data(is_stream=False, model=model, content=content, usage=usage)
|
|
||
| async def _build_payload(self, user: UserModel, body: dict, user_valves: Any) -> Tuple[str, dict]: | ||
| user_valves = self._normalize_user_valves(user_valves) | ||
| model = body["model"].split(".", 1)[1] if "." in body["model"] else body["model"] | ||
|
|
||
| prompt, images = await self._parse_messages(user=user, body=body) | ||
|
|
||
| if len(images) > 14: | ||
| raise ValueError("豆包图片生成最多支持 14 张参考图。") | ||
| if len(images) + user_valves.n > 15: | ||
| raise ValueError("参考图数量 + 输出张数不能超过 15。") | ||
|
|
||
| data = { | ||
| "model": model, | ||
| "prompt": prompt, | ||
| "size": user_valves.size, | ||
| "output_format": user_valves.output_format, | ||
| "watermark": user_valves.watermark, | ||
| "response_format": "b64_json", | ||
| } | ||
|
|
||
| if images: | ||
| data["image"] = images[0] if len(images) == 1 else images | ||
|
|
||
| if user_valves.n > 1: | ||
| data["sequential_image_generation"] = "auto" | ||
| data["sequential_image_generation_options"] = {"max_images": user_valves.n} | ||
|
|
||
| if user_valves.enable_web_search: | ||
| data["tools"] = [{"type": "web_search"}] | ||
|
|
||
| payload = { | ||
| "url": "/images/generations", | ||
| "json": data, | ||
| } | ||
| return model, payload | ||
|
|
||
    async def _parse_messages(self, user: UserModel, body: dict) -> Tuple[str, List[str]]:
        """Split the chat history into a text prompt and reference image URLs.

        Only the last 6 messages are considered. System messages are skipped,
        and assistant turns contribute images only (their text is ignored).

        Raises:
            TypeError: when a message content is neither str nor list, or a
                list item carries an unrecognised type.
        """
        prompt_parts: List[str] = []
        images: List[str] = []

        messages = body.get("messages", [])
        if len(messages) > 6:
            messages = messages[-6:]

        for message in messages:
            role = message.get("role")
            if role == "system":
                continue

            # Assistant turns never add prompt text, only reference images.
            allow_text = role != "assistant"
            content = message.get("content")

            if isinstance(content, str):
                for line in content.split("\n"):
                    line = line.strip()
                    if not line:
                        continue
                    # Markdown image whose alt text embeds an OpenWebUI file id.
                    if line.startswith("![") and "-image-" in line:
                        data_url = await self._get_image_data_url_from_markdown(user=user, markdown_string=line)
                        if data_url:
                            images.append(data_url)
                        continue
                    if allow_text:
                        prompt_parts.append(line)
                continue

            if isinstance(content, list):
                for item in content:
                    item_type = item.get("type")
                    if item_type == "text":
                        text = item.get("text", "").strip()
                        if text and allow_text:
                            prompt_parts.append(text)
                        continue

                    if item_type in {"image_url", "input_image"}:
                        image_url = self._extract_image_url(item)
                        if image_url:
                            images.append(image_url)
                        continue

                    raise TypeError("message content invalid")
                continue

            raise TypeError("message content invalid")

        prompt = "\n".join(prompt_parts).strip()
        if not prompt:
            # Fall back to an explicit prompt field, else a generic instruction.
            prompt = body.get("prompt", "Please generate an image.")
        return prompt, images
|
|
||
| def _parse_response_content( | ||
| self, | ||
| response_json: dict, | ||
| user: UserModel, | ||
| __request__: Request, | ||
| output_format: str, | ||
| ) -> Tuple[str, Optional[dict]]: | ||
| results: List[str] = [] | ||
| mime_type = "image/png" if output_format == "png" else "image/jpeg" | ||
| for item in response_json.get("data", []): | ||
| image_markdown = self._render_response_item( | ||
| item=item, | ||
| user=user, | ||
| __request__=__request__, | ||
| fallback_mime_type=mime_type, | ||
| ) | ||
| if image_markdown: | ||
| results.append(image_markdown) | ||
|
|
||
| if not results and response_json.get("error"): | ||
| raise ValueError(self._extract_error_message(response_json["error"]) or "Unknown API error") | ||
|
|
||
| if not results: | ||
| raise ValueError("未解析到豆包图片响应") | ||
|
|
||
| return "\n\n".join(results), response_json.get("usage") | ||
|
|
||
| def _render_response_item( | ||
| self, | ||
| item: dict, | ||
| user: UserModel, | ||
| __request__: Request, | ||
| fallback_mime_type: str, | ||
| ) -> str: | ||
| if not isinstance(item, dict): | ||
| return "" | ||
|
|
||
| b64_json = item.get("b64_json") | ||
| image_url = item.get("url") | ||
| mime_type = item.get("mime_type") or fallback_mime_type | ||
|
|
||
| if b64_json: | ||
| return self._upload_image( | ||
| __request__=__request__, | ||
| user=user, | ||
| image_data=b64_json, | ||
| mime_type=mime_type, | ||
| ) | ||
|
|
||
| if image_url: | ||
| return f"" | ||
|
|
||
| return "" | ||
|
|
||
| def _upload_image( | ||
| self, | ||
| __request__: Request, | ||
| user: UserModel, | ||
| image_data: str, | ||
| mime_type: str, | ||
| ) -> str: | ||
| image_bytes = self._decode_base64_image(image_data) | ||
| file_ext = mimetypes.guess_extension(mime_type) or ".png" | ||
| if file_ext == ".jpe": | ||
| file_ext = ".jpg" | ||
|
|
||
| file_item = upload_file( | ||
| request=__request__, | ||
| background_tasks=BackgroundTasks(), | ||
| file=UploadFile( | ||
| file=io.BytesIO(image_bytes), | ||
| filename=f"generated-image-{uuid.uuid4().hex}{file_ext}", | ||
| headers=Headers({"content-type": mime_type}), | ||
| ), | ||
| process=False, | ||
| user=user, | ||
| metadata={"mime_type": mime_type}, | ||
| ) | ||
| image_url = __request__.app.url_path_for("get_file_content_by_id", id=file_item.id) | ||
| return f"" | ||
|
|
||
| async def _get_image_data_url_from_markdown(self, user: UserModel, markdown_string: str) -> str: | ||
| file_id = self._extract_file_id_from_markdown(markdown_string) | ||
| if not file_id: | ||
| return "" | ||
|
|
||
| file_response = await get_file_content_by_id(id=file_id, user=user) | ||
| with open(file_response.path, "rb") as file_content: | ||
| image_bytes = file_content.read() | ||
|
|
||
| mime_type = mimetypes.guess_type(file_response.path)[0] or "image/png" | ||
| encoded = base64.b64encode(image_bytes).decode() | ||
| return f"data:{mime_type};base64,{encoded}" | ||
|
OrenZhang marked this conversation as resolved.
|
||
|
|
||
| @staticmethod | ||
| def _extract_image_url(item: dict) -> str: | ||
| image_url = item.get("image_url", "") | ||
| if isinstance(image_url, dict): | ||
| image_url = image_url.get("url", "") | ||
| return image_url if isinstance(image_url, str) else "" | ||
|
|
||
| @staticmethod | ||
| def _extract_file_id_from_markdown(markdown_string: str) -> str: | ||
| try: | ||
| alt_text = markdown_string.split("![", 1)[1].split("]", 1)[0] | ||
| if "-image-" not in alt_text: | ||
| return "" | ||
| return alt_text.rsplit("-image-", 1)[1] | ||
| except Exception: | ||
| return "" | ||
|
|
||
| @staticmethod | ||
| def _extract_error_message(error: Any) -> str: | ||
| if isinstance(error, dict): | ||
| return error.get("message") or error.get("code") or "" | ||
| if isinstance(error, str): | ||
| return error | ||
| return "" | ||
|
|
||
| @staticmethod | ||
| def _decode_base64_image(image_data: str) -> bytes: | ||
| data = image_data.strip() | ||
| if data.startswith("data:") and "," in data: | ||
| data = data.split(",", 1)[1] | ||
|
|
||
| data = "".join(data.split()) | ||
|
|
||
| try: | ||
| decoded = base64.b64decode(data, validate=True) | ||
| except (binascii.Error, ValueError): | ||
| padding = len(data) % 4 | ||
| if padding: | ||
| data = f"{data}{'=' * (4 - padding)}" | ||
| decoded = base64.b64decode(data) | ||
|
|
||
| if not decoded: | ||
| raise ValueError("decoded image bytes is empty") | ||
| return decoded | ||
|
|
||
| @staticmethod | ||
| def _normalize_user_valves(user_valves: Any) -> "Pipe.UserValves": | ||
| if isinstance(user_valves, Pipe.UserValves): | ||
| return user_valves | ||
| if isinstance(user_valves, BaseModel): | ||
| if hasattr(user_valves, "model_dump"): | ||
| return Pipe.UserValves(**user_valves.model_dump()) | ||
| return Pipe.UserValves(**user_valves.dict()) | ||
| return Pipe.UserValves(**(user_valves or {})) | ||
|
|
||
| @staticmethod | ||
| def _format_data( | ||
| is_stream: bool, | ||
| model: Optional[str] = "", | ||
| content: Optional[str] = "", | ||
| usage: Optional[dict] = None, | ||
| ) -> str: | ||
| data = { | ||
| "id": f"chat.{uuid.uuid4().hex}", | ||
| "object": "chat.completion.chunk", | ||
| "choices": [], | ||
| "created": int(time.time()), | ||
| "model": model, | ||
| } | ||
| if content: | ||
| data["choices"] = [ | ||
| { | ||
| "finish_reason": "stop", | ||
| "index": 0, | ||
| "delta" if is_stream else "message": {"content": content}, | ||
| } | ||
| ] | ||
| if usage: | ||
| data["usage"] = usage | ||
| return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.