diff --git a/getstream/base.py b/getstream/base.py index e97bd7ba..9adb4e14 100644 --- a/getstream/base.py +++ b/getstream/base.py @@ -1,8 +1,10 @@ import json +import mimetypes +import os import time import uuid import asyncio -from typing import Any, Dict, Optional, Type, cast, get_origin +from typing import Any, Dict, List, Optional, Tuple, Type, cast, get_origin from getstream.models import APIError from getstream.rate_limit import extract_rate_limit @@ -25,6 +27,11 @@ import ijson +def _read_file_bytes(file_path: str) -> bytes: + with open(file_path, "rb") as f: + return f.read() + + def build_path(path: str, path_params: Optional[Dict[str, Any]]) -> str: if path_params is None: return path @@ -293,6 +300,39 @@ def delete( data_type=data_type, ) + def _upload_multipart( + self, + path: str, + data_type: Type[T], + file_path: str, + *, + path_params: Optional[Dict[str, str]] = None, + query_params: Optional[Dict[str, str]] = None, + form_fields: Optional[List[Tuple[str, str]]] = None, + ) -> StreamResponse[T]: + """Send a multipart/form-data upload request, matching Go/PHP SDK behavior.""" + file_name = os.path.basename(file_path) + content_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream" + with open(file_path, "rb") as f: + file_content = f.read() + + files = {"file": (file_name, file_content, content_type)} + data: Dict[str, str] = {} + for field_name, field_value in form_fields or []: + data[field_name] = field_value + + kwargs: Dict[str, Any] = {"files": files} + if data: + kwargs["data"] = data + + return self._request_sync( + "POST", + path, + query_params=query_params, + kwargs=kwargs | {"path_params": path_params}, + data_type=data_type, + ) + def close(self): """ Close HTTPX client. @@ -333,6 +373,39 @@ async def aclose(self): """Close HTTPX async client (closes pools/keep-alives).""" await self.client.aclose() + async def _upload_multipart( + self, + path: str, + data_type: Type[T], + file_path: str, + *, + path_params: Optional[Dict[str, str]] = None, + query_params: Optional[Dict[str, str]] = None, + form_fields: Optional[List[Tuple[str, str]]] = None, + ) -> StreamResponse[T]: + """Send a multipart/form-data upload request, matching Go/PHP SDK behavior.""" + file_name = os.path.basename(file_path) + content_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream" + + file_content = await asyncio.to_thread(_read_file_bytes, file_path) + + files = {"file": (file_name, file_content, content_type)} + data: Dict[str, str] = {} + for field_name, field_value in form_fields or []: + data[field_name] = field_value + + kwargs: Dict[str, Any] = {"files": files} + if data: + kwargs["data"] = data + + return await self._request_async( + "POST", + path, + query_params=query_params, + kwargs=kwargs | {"path_params": path_params}, + data_type=data_type, + ) + def _endpoint_name(self, path: str) -> str: op = getattr(self, "_operation_name", None) return op or current_operation(self._normalize_endpoint_from_path(path)) or "" diff --git a/getstream/chat/async_client.py b/getstream/chat/async_client.py index 828f0435..9c50bbaa 100644 --- a/getstream/chat/async_client.py +++ b/getstream/chat/async_client.py @@ -1,5 +1,16 @@ +import json +from typing import List, Optional + from getstream.chat.async_channel import Channel from getstream.chat.async_rest_client import ChatRestClient +from getstream.common import telemetry +from getstream.models import ( + ImageSize, + OnlyUserID, + UploadChannelFileResponse, + UploadChannelResponse, +) +from getstream.stream_response import StreamResponse class ChatClient(ChatRestClient): @@ -15,3 +26,46 @@ def __init__(self, api_key: str, base_url, token, timeout, stream, user_agent=No def channel(self, call_type: str, id: str) -> Channel: return Channel(self, call_type, id) + + @telemetry.operation_name("getstream.api.chat.upload_channel_file") + async def upload_channel_file( + self, + type: str, + id: str, + file: str, + user: Optional[OnlyUserID] = None, + ) -> StreamResponse[UploadChannelFileResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + return await self._upload_multipart( + "/api/v2/chat/channels/{type}/{id}/file", + UploadChannelFileResponse, + file, + path_params={"type": type, "id": id}, + form_fields=form_fields, + ) + + @telemetry.operation_name("getstream.api.chat.upload_channel_image") + async def upload_channel_image( + self, + type: str, + id: str, + file: str, + upload_sizes: Optional[List[ImageSize]] = None, + user: Optional[OnlyUserID] = None, + ) -> StreamResponse[UploadChannelResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + if upload_sizes is not None: + form_fields.append( + ("upload_sizes", json.dumps([s.to_dict() for s in upload_sizes])) + ) + return await self._upload_multipart( + "/api/v2/chat/channels/{type}/{id}/image", + UploadChannelResponse, + file, + path_params={"type": type, "id": id}, + form_fields=form_fields, + ) diff --git a/getstream/chat/client.py b/getstream/chat/client.py index 05ec7d26..de46d7d9 100644 --- a/getstream/chat/client.py +++ b/getstream/chat/client.py @@ -1,5 +1,16 @@ +import json +from typing import List, Optional + from getstream.chat.channel import Channel from getstream.chat.rest_client import ChatRestClient +from getstream.common import telemetry +from getstream.models import ( + ImageSize, + OnlyUserID, + UploadChannelFileResponse, + UploadChannelResponse, +) +from getstream.stream_response import StreamResponse class ChatClient(ChatRestClient): @@ -15,3 +26,46 @@ def __init__(self, api_key: str, base_url, token, timeout, stream, user_agent=No def channel(self, call_type: str, id: str) -> Channel: return Channel(self, call_type, id) + + @telemetry.operation_name("getstream.api.chat.upload_channel_file") + def upload_channel_file( + self, + type: str, + id: str, + file: str, + user: Optional[OnlyUserID] = None, + ) -> StreamResponse[UploadChannelFileResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + return self._upload_multipart( + "/api/v2/chat/channels/{type}/{id}/file", + UploadChannelFileResponse, + file, + path_params={"type": type, "id": id}, + form_fields=form_fields, + ) + + @telemetry.operation_name("getstream.api.chat.upload_channel_image") + def upload_channel_image( + self, + type: str, + id: str, + file: str, + upload_sizes: Optional[List[ImageSize]] = None, + user: Optional[OnlyUserID] = None, + ) -> StreamResponse[UploadChannelResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + if upload_sizes is not None: + form_fields.append( + ("upload_sizes", json.dumps([s.to_dict() for s in upload_sizes])) + ) + return self._upload_multipart( + "/api/v2/chat/channels/{type}/{id}/image", + UploadChannelResponse, + file, + path_params={"type": type, "id": id}, + form_fields=form_fields, + ) diff --git a/getstream/common/async_client.py b/getstream/common/async_client.py index 3a85c767..ef9294f4 100644 --- a/getstream/common/async_client.py +++ b/getstream/common/async_client.py @@ -1,4 +1,15 @@ +import json +from typing import List, Optional + +from getstream.common import telemetry from getstream.common.async_rest_client import CommonRestClient +from getstream.models import ( + FileUploadResponse, + ImageSize, + ImageUploadResponse, + OnlyUserID, +) +from getstream.stream_response import StreamResponse class CommonClient(CommonRestClient): @@ -10,3 +21,38 @@ def __init__(self, api_key: str, base_url, token, timeout, user_agent=None): timeout=timeout, user_agent=user_agent, ) + + @telemetry.operation_name("getstream.api.common.upload_file") + async def upload_file( + self, file: str, user: Optional[OnlyUserID] = None + ) -> StreamResponse[FileUploadResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + return await self._upload_multipart( + "/api/v2/uploads/file", + FileUploadResponse, + file, + form_fields=form_fields, + ) + + @telemetry.operation_name("getstream.api.common.upload_image") + async def upload_image( + self, + file: str, + upload_sizes: Optional[List[ImageSize]] = None, + user: Optional[OnlyUserID] = None, + ) -> StreamResponse[ImageUploadResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + if upload_sizes is not None: + form_fields.append( + ("upload_sizes", json.dumps([s.to_dict() for s in upload_sizes])) + ) + return await self._upload_multipart( + "/api/v2/uploads/image", + ImageUploadResponse, + file, + form_fields=form_fields, + ) diff --git a/getstream/common/client.py b/getstream/common/client.py index 1e10237c..748b0b17 100644 --- a/getstream/common/client.py +++ b/getstream/common/client.py @@ -1,4 +1,15 @@ +import json +from typing import List, Optional + +from getstream.common import telemetry from getstream.common.rest_client import CommonRestClient +from getstream.models import ( + FileUploadResponse, + ImageSize, + ImageUploadResponse, + OnlyUserID, +) +from getstream.stream_response import StreamResponse class CommonClient(CommonRestClient): @@ -10,3 +21,38 @@ def __init__(self, api_key: str, base_url, token, timeout, user_agent=None): timeout=timeout, user_agent=user_agent, ) + + @telemetry.operation_name("getstream.api.common.upload_file") + def upload_file( + self, file: str, user: Optional[OnlyUserID] = None + ) -> StreamResponse[FileUploadResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + return self._upload_multipart( + "/api/v2/uploads/file", + FileUploadResponse, + file, + form_fields=form_fields, + ) + + @telemetry.operation_name("getstream.api.common.upload_image") + def upload_image( + self, + file: str, + upload_sizes: Optional[List[ImageSize]] = None, + user: Optional[OnlyUserID] = None, + ) -> StreamResponse[ImageUploadResponse]: + form_fields = [] + if user is not None: + form_fields.append(("user", json.dumps(user.to_dict()))) + if upload_sizes is not None: + form_fields.append( + ("upload_sizes", json.dumps([s.to_dict() for s in upload_sizes])) + ) + return self._upload_multipart( + "/api/v2/uploads/image", + ImageUploadResponse, + file, + form_fields=form_fields, + ) diff --git a/tests/test_chat_channel.py b/tests/test_chat_channel.py index f95bfc7e..371508d3 100644 --- a/tests/test_chat_channel.py +++ b/tests/test_chat_channel.py @@ -703,44 +703,32 @@ def test_ban_user_in_channel( class TestChannelFileUpload: - def test_upload_and_delete_file(self, channel: Channel, random_user): - """Upload and delete a file.""" - file_path = str(ASSETS_DIR / "test_upload.txt") + def test_upload_and_delete_file(self, channel: Channel, random_user, tmp_path): + """Upload and delete a file via multipart/form-data.""" + file_path = tmp_path / "chat-test-upload.txt" + file_path.write_text("hello world test file content") - try: - upload_resp = channel.upload_channel_file( - file=file_path, - user=OnlyUserID(id=random_user.id), - ) - assert upload_resp.data.file is not None - file_url = upload_resp.data.file - assert "http" in file_url - - channel.delete_channel_file(url=file_url) - except Exception as e: - if "multipart" in str(e).lower(): - import pytest - - pytest.skip("File upload requires multipart/form-data support") - raise + upload_resp = channel.upload_channel_file( + file=str(file_path), + user=OnlyUserID(id=random_user.id), + ) + assert upload_resp.data.file is not None + file_url = upload_resp.data.file + assert "http" in file_url - def test_upload_and_delete_image(self, channel: Channel, random_user): - """Upload and delete an image.""" - file_path = str(ASSETS_DIR / "test_upload.jpg") + channel.delete_channel_file(url=file_url) - try: - upload_resp = channel.upload_channel_image( - file=file_path, - user=OnlyUserID(id=random_user.id), - ) - assert upload_resp.data.file is not None - image_url = upload_resp.data.file - assert "http" in image_url + def test_upload_and_delete_image(self, channel: Channel, random_user, tmp_path): + """Upload and delete an image via multipart/form-data.""" + file_path = tmp_path / "chat-test-upload.jpg" + file_path.write_bytes(b"fake-jpg-image-data-for-testing") - channel.delete_channel_image(url=image_url) - except Exception as e: - if "multipart" in str(e).lower(): - import pytest + upload_resp = channel.upload_channel_image( + file=str(file_path), + user=OnlyUserID(id=random_user.id), + ) + assert upload_resp.data.file is not None + image_url = upload_resp.data.file + assert "http" in image_url - pytest.skip("Image upload requires multipart/form-data support") - raise + channel.delete_channel_image(url=image_url) diff --git a/tests/test_chat_misc.py b/tests/test_chat_misc.py index 7342cfef..fe19d5df 100644 --- a/tests/test_chat_misc.py +++ b/tests/test_chat_misc.py @@ -13,6 +13,7 @@ EventHook, FileUploadConfig, MessageRequest, + OnlyUserID, QueryFutureChannelBansPayload, SortParamRequest, ) @@ -600,3 +601,35 @@ def test_event_hooks_sqs_sns(client: Stream): finally: # Restore original hooks client.update_app(event_hooks=original_hooks or []) + + +def test_upload_and_delete_file(client: Stream, random_user, tmp_path): + """Upload and delete a file via the common upload endpoint.""" + file_path = tmp_path / "common-test-upload.txt" + file_path.write_text("hello world test file content") + + upload_resp = client.upload_file( + file=str(file_path), + user=OnlyUserID(id=random_user.id), + ) + assert upload_resp.data.file is not None + file_url = upload_resp.data.file + assert "http" in file_url + + client.delete_file(url=file_url) + + +def test_upload_and_delete_image(client: Stream, random_user, tmp_path): + """Upload and delete an image via the common upload endpoint.""" + file_path = tmp_path / "common-test-upload.jpg" + file_path.write_bytes(b"fake-jpg-image-data-for-testing") + + upload_resp = client.upload_image( + file=str(file_path), + user=OnlyUserID(id=random_user.id), + ) + assert upload_resp.data.file is not None + image_url = upload_resp.data.file + assert "http" in image_url + + client.delete_image(url=image_url) diff --git a/uv.lock b/uv.lock index 08c32e71..3bc04b64 100644 --- a/uv.lock +++ b/uv.lock @@ -955,7 +955,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp", marker = "extra == 'webrtc'", specifier = ">=3.13.2,<4" }, - { name = "aiortc", marker = "extra == 'webrtc'", specifier = ">=1.14.0,<2" }, + { name = "aiortc", marker = "extra == 'webrtc'", specifier = ">=1.14.0,<1.15.0" }, { name = "av", marker = "extra == 'webrtc'", specifier = ">=14.2.0,<17" }, { name = "dataclasses-json", specifier = ">=0.6.0,<0.7" }, { name = "httpx", specifier = ">=0.28.1" },