From 374cb185b4ec4147fd70771f874d900fd7f9116a Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Feb 2026 16:36:56 +0800 Subject: [PATCH 1/9] to test --- .../Tinker-Compatible-Client.md | 2 +- ...\271\345\256\242\346\210\267\347\253\257.md" | 2 +- src/twinkle/server/tinker/model.py | 15 ++++++++++++++- src/twinkle/server/tinker/server.py | 8 ++------ src/twinkle/server/twinkle/model.py | 17 +++++++++++++++-- src/twinkle/server/twinkle/processor.py | 17 ++++++++++++++++- src/twinkle/server/utils/adapter_manager.py | 4 +++- src/twinkle/server/utils/validation.py | 6 +++--- src/twinkle_client/__init__.py | 2 +- src/twinkle_client/http/http_utils.py | 2 +- 10 files changed, 57 insertions(+), 18 deletions(-) diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md index 67a6b30f..d3cf4a8f 100644 --- a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md @@ -25,7 +25,7 @@ for item in service_client.get_server_capabilities().supported_models: When calling `init_tinker_compat_client`, the following operations are automatically executed: 1. **Patch Tinker SDK**: Bypass Tinker's `tinker://` prefix validation, allowing it to connect to standard HTTP addresses -2. **Set Request Headers**: Inject necessary authentication headers such as `X-Ray-Serve-Request-Id` and `Authorization` +2. **Set Request Headers**: Inject necessary authentication headers such as `serve_multiplexed_model_id` and `Authorization` 3. **Return `ServiceClient`**: Returns a standard Tinker `ServiceClient` object, subsequent operations are completely identical to native Tinker This means that after initialization, **all existing Tinker training code can be used directly** without any modifications. diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" index ef1c7e26..35b39536 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" @@ -25,7 +25,7 @@ for item in service_client.get_server_capabilities().supported_models: 调用 `init_tinker_compat_client` 时,会自动执行以下操作: 1. **Patch Tinker SDK**:绕过 Tinker 的 `tinker://` 前缀校验,使其可以连接到标准 HTTP 地址 -2. **设置请求头**:注入 `X-Ray-Serve-Request-Id` 和 `Authorization` 等必要的认证头 +2. **设置请求头**:注入 `serve_multiplexed_model_id` 和 `Authorization` 等必要的认证头 3. 
**返回 `ServiceClient`**:返回一个标准的 Tinker `ServiceClient` 对象,后续操作与原生 Tinker 完全一致 这意味着在初始化之后,**所有已有的 Tinker 训练代码都可以直接使用**,无需任何修改。 diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 2a119162..787a1902 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -118,6 +118,17 @@ def __init__(self, self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) + async def get_multiplexed_adapter(self, request_id: str): + return request_id + + def remove_multiplexed_adapter(self, adapter_name: str): + adapter_info = self.get_adapter_info(adapter_name) + if adapter_info is None or adapter_info.get('request_id') is None: + return + if hasattr(self, '_serve_multiplexed_models'): + self._serve_multiplexed_models.pop(adapter_info['request_id'], None) + def _cleanup_adapter(self, adapter_name: str) -> None: """Common adapter cleanup logic used by both manual unload and automatic expiration. @@ -132,6 +143,7 @@ def _cleanup_adapter(self, adapter_name: str) -> None: """ # Remove from model if it exists if self.get_adapter_info(adapter_name): + self.remove_multiplexed_adapter(adapter_name) # Clear adapter state self.clear_adapter_state(adapter_name) @@ -178,7 +190,7 @@ async def _create_adapter(): adapter_name = self.get_adapter_name(adapter_name=model_id) # Register adapter FIRST (limit check happens inside register_adapter) - self.register_adapter(adapter_name, request.state.token, session_id=body.session_id) + self.register_adapter(adapter_name, request.state.token, session_id=body.session_id, request_id=request.state.request_id) # Create adapter AFTER successful registration self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) @@ -189,6 +201,7 @@ async def _create_adapter(): # Fresh adapter has no accumulated gradients. 
self.set_adapter_state(adapter_name, 'grad_ready', False) + await self.get_multiplexed_adapter(request.state.request_id) training_run_manager = create_training_run_manager(request.state.token) training_run_manager.save(model_id, body) diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index 2e669f56..e0278e68 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -82,10 +82,6 @@ def __init__(self, self.client = httpx.AsyncClient(timeout=None, trust_env=False) self.route_prefix = kwargs.get('route_prefix', '/api/v1') self.supported_models = self.normalize_models(supported_models) or [ - types.SupportedModel(model_name='Qwen/Qwen2.5-0.5B-Instruct'), - types.SupportedModel(model_name='Qwen/Qwen2.5-3B-Instruct'), - types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'), - types.SupportedModel(model_name='Qwen/Qwen2.5-72B-Instruct'), types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), ] # Lock for ModelScope config file operations (login writes, get_user_info reads) @@ -165,8 +161,8 @@ async def _proxy_request(self, request: Request, endpoint: str, base_model: str, try: if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': - logger.info('proxy_to_model endpoint=%s target_url=%s x-ray-serve-request-id=%s', endpoint, - target_url, headers.get('x-ray-serve-request-id')) + logger.info('proxy_to_model endpoint=%s target_url=%s serve_multiplexed_model_id=%s', endpoint, + target_url, headers.get('serve_multiplexed_model_id')) rp_ = await self.client.request( method=request.method, url=target_url, diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 1660cd10..e166f29c 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -187,6 +187,17 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes self._init_adapter_manager(**adapter_config) self.start_adapter_countdown() + @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) + async def get_multiplexed_adapter(self, request_id: str): + return request_id + + def remove_multiplexed_adapter(self, adapter_name: str): + adapter_info = self.get_adapter_info(adapter_name) + if adapter_info is None or adapter_info.get('request_id') is None: + return + if hasattr(self, '_serve_multiplexed_models'): + self._serve_multiplexed_models.pop(adapter_info['request_id'], None) + def _on_adapter_expired(self, adapter_name: str) -> None: """Handle adapter expiration by removing it from the model. @@ -198,6 +209,7 @@ def _on_adapter_expired(self, adapter_name: str) -> None: """ # Remove from model if it exists if self.get_adapter_info(adapter_name): + self.remove_multiplexed_adapter(adapter_name) # Clear adapter state self.clear_adapter_state(adapter_name) # Unregister from adapter manager @@ -481,7 +493,7 @@ def upload_to_hub(self, request: Request, body: UploadToHubRequest): return {'result': body.hub_model_id} @app.post('/add_adapter_to_model') - def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): + async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): """ Add a new adapter to the model. 
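The `@serve.multiplexed` methods added to both model apps above plug into Ray Serve's model-multiplexing API: each replica lazily loads and LRU-caches up to `max_num_models_per_replica` models, and the Serve router prefers replicas that already hold the requested model ID. A minimal sketch of that pattern, using an illustrative deployment rather than code from this repository:

```python
from ray import serve
from starlette.requests import Request


@serve.deployment
class AdapterServer:
    @serve.multiplexed(max_num_models_per_replica=5)
    async def get_adapter(self, adapter_id: str):
        # Called at most once per adapter_id per replica; Ray caches the result and
        # evicts least-recently-used entries beyond max_num_models_per_replica.
        return {"adapter_id": adapter_id}  # stand-in for an expensive adapter load

    async def __call__(self, request: Request) -> dict:
        # Populated from the `serve_multiplexed_model_id` request header, which the
        # Serve proxy also uses to route to a replica that already has this ID cached.
        adapter_id = serve.get_multiplexed_model_id()
        return await self.get_adapter(adapter_id)


app = AdapterServer.bind()
```

Clients steer this routing either by sending the `serve_multiplexed_model_id` HTTP header or, when calling through a deployment handle, via `handle.options(multiplexed_model_id=...)`; that is why the client-side header is renamed from `X-Ray-Serve-Request-Id` to `serve_multiplexed_model_id` elsewhere in this patch.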
@@ -507,10 +519,11 @@ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): training_run_manager = create_training_run_manager(token) # Register adapter FIRST (limit check happens inside register_adapter) - self.register_adapter(adapter_name, token) + self.register_adapter(adapter_name, token, request_id=request.state.request_id) # Create adapter AFTER successful registration self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) + await self.get_multiplexed_adapter(request.state.request_id) # Save training run metadata (similar to tinker's create_model) # Create a training run config from the adapter configuration diff --git a/src/twinkle/server/twinkle/processor.py b/src/twinkle/server/twinkle/processor.py index cbead9b7..ed2e27ef 100644 --- a/src/twinkle/server/twinkle/processor.py +++ b/src/twinkle/server/twinkle/processor.py @@ -68,6 +68,7 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, device_group: D self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.resource_dict = {} self.resource_records: Dict[str, int] = {} + self.resource_client_ids = Dict[str, str] = {} self.hb_thread = threading.Thread(target=self.countdown, daemon=True) self.hb_thread.start() self.state: ServerStateProxy = get_server_state() @@ -83,9 +84,21 @@ def countdown(self): if self.resource_records[key] > self.COUNT_DOWN: self.resource_records.pop(key, None) self.resource_dict.pop(key, None) + self.remove_multiplexed_adapter(key) if key in self.key_token_dict: self.handle_processor_count(self.key_token_dict.pop(key), False) + @serve.multiplexed(max_num_models_per_replica=100) + async def get_multiplexed_adapter(self, request_id: str): + return request_id + + def remove_multiplexed_adapter(self, processor_id: str): + request_id = self.resource_client_ids.pop(processor_id, None) + if request_id is None: + return + if hasattr(self, '_serve_multiplexed_models'): + self._serve_multiplexed_models.pop(request_id, None) + def assert_processor_exists(self, processor_id: str): assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found' @@ -105,7 +118,7 @@ def handle_processor_count(self, token: str, add: bool): self.state.pop_config(user_key) @app.post('/create') - def create(self, request: Request, body: CreateRequest): + async def create(self, request: Request, body: CreateRequest): processor_type_name = body.processor_type class_type = body.class_type @@ -134,6 +147,8 @@ def create(self, request: Request, body: CreateRequest): remote_group=self.device_group.name, device_mesh=self.device_mesh, instance_id=processor_id, **_kwargs) self.resource_dict[processor_id] = processor self.resource_records[processor_id] = 0 + self.resource_client_ids[processor_id] = request.state.request_id + await self.get_multiplexed_adapter(request.state.request_id) return {'processor_id': 'pid:' + processor_id} @app.post('/heartbeat') diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 04e56922..1e93756c 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -74,7 +74,7 @@ def _init_adapter_manager( self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None, request_id: str | None = None) -> None: 
"""Register a new adapter for lifecycle tracking. Args: @@ -82,6 +82,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str | None token: User token that owns this adapter. session_id: Optional session ID to associate with this adapter. If provided, adapter will expire when the session expires. + request_id: The client request_id from `serve_multiplexed_model_id` Raises: RuntimeError: If adapter limit is exceeded for this token. @@ -100,6 +101,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str | None 'inactivity_counter': 0, 'state': {}, 'expiring': False, + 'request_id': request_id, } logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...' + (f' (session: {session_id})' if session_id else '')) diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index 23539ed8..0e5a71d1 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -11,7 +11,7 @@ async def verify_request_token(request: Request, call_next): This middleware: 1. Extracts the Bearer token from Authorization header 2. Validates the token - 3. Extracts X-Ray-Serve-Request-Id for sticky sessions + 3. Extracts serve_multiplexed_model_id for sticky sessions 4. Stores token and request_id in request.state for later use Args: @@ -26,10 +26,10 @@ async def verify_request_token(request: Request, call_next): if not is_token_valid(token): return JSONResponse(status_code=403, content={'detail': 'Invalid token'}) - request_id = request.headers.get('X-Ray-Serve-Request-Id') + request_id = request.headers.get('serve_multiplexed_model_id') if not request_id: return JSONResponse( - status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'}) + status_code=400, content={'detail': 'Missing serve_multiplexed_model_id header, required for sticky session'}) request.state.request_id = request_id request.state.token = token response = await call_next(request) diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index 5a6928e9..f236f734 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -28,7 +28,7 @@ def init_tinker_compat_client(base_url: str | None = None, api_key: str | None = base_url = f'http://{base_url}' default_headers = { - 'X-Ray-Serve-Request-Id': get_request_id(), + 'serve_multiplexed_model_id': get_request_id(), 'Authorization': 'Bearer ' + api_key, 'Twinkle-Authorization': 'Bearer ' + api_key, # For server compatibility } | kwargs.pop('default_headers', {}) diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 522b46af..f9cafa1c 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -16,7 +16,7 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ Dictionary of headers """ headers = { - 'X-Ray-Serve-Request-Id': get_request_id(), + 'serve_multiplexed_model_id': get_request_id(), 'Authorization': 'Bearer ' + get_api_key(), 'Twinkle-Authorization': 'Bearer ' + get_api_key(), # For server compatibility } From 154872966ef1246bc11ac51a4553b8b49c485d7b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Feb 2026 16:58:59 +0800 Subject: [PATCH 2/9] fix --- src/twinkle/server/tinker/model.py | 10 +--------- src/twinkle/server/twinkle/model.py | 10 +--------- src/twinkle/server/twinkle/processor.py | 10 ---------- src/twinkle/server/utils/adapter_manager.py | 4 
+--- src/twinkle/server/utils/validation.py | 3 ++- 5 files changed, 5 insertions(+), 32 deletions(-) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index 787a1902..e0075c6a 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -122,13 +122,6 @@ def __init__(self, async def get_multiplexed_adapter(self, request_id: str): return request_id - def remove_multiplexed_adapter(self, adapter_name: str): - adapter_info = self.get_adapter_info(adapter_name) - if adapter_info is None or adapter_info.get('request_id') is None: - return - if hasattr(self, '_serve_multiplexed_models'): - self._serve_multiplexed_models.pop(adapter_info['request_id'], None) - def _cleanup_adapter(self, adapter_name: str) -> None: """Common adapter cleanup logic used by both manual unload and automatic expiration. @@ -143,7 +136,6 @@ def _cleanup_adapter(self, adapter_name: str) -> None: """ # Remove from model if it exists if self.get_adapter_info(adapter_name): - self.remove_multiplexed_adapter(adapter_name) # Clear adapter state self.clear_adapter_state(adapter_name) @@ -190,7 +182,7 @@ async def _create_adapter(): adapter_name = self.get_adapter_name(adapter_name=model_id) # Register adapter FIRST (limit check happens inside register_adapter) - self.register_adapter(adapter_name, request.state.token, session_id=body.session_id, request_id=request.state.request_id) + self.register_adapter(adapter_name, request.state.token, session_id=body.session_id) # Create adapter AFTER successful registration self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index e166f29c..07b4bb0e 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -191,13 +191,6 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes async def get_multiplexed_adapter(self, request_id: str): return request_id - def remove_multiplexed_adapter(self, adapter_name: str): - adapter_info = self.get_adapter_info(adapter_name) - if adapter_info is None or adapter_info.get('request_id') is None: - return - if hasattr(self, '_serve_multiplexed_models'): - self._serve_multiplexed_models.pop(adapter_info['request_id'], None) - def _on_adapter_expired(self, adapter_name: str) -> None: """Handle adapter expiration by removing it from the model. 
@@ -209,7 +202,6 @@ def _on_adapter_expired(self, adapter_name: str) -> None: """ # Remove from model if it exists if self.get_adapter_info(adapter_name): - self.remove_multiplexed_adapter(adapter_name) # Clear adapter state self.clear_adapter_state(adapter_name) # Unregister from adapter manager @@ -519,7 +511,7 @@ async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): training_run_manager = create_training_run_manager(token) # Register adapter FIRST (limit check happens inside register_adapter) - self.register_adapter(adapter_name, token, request_id=request.state.request_id) + self.register_adapter(adapter_name, token) # Create adapter AFTER successful registration self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) diff --git a/src/twinkle/server/twinkle/processor.py b/src/twinkle/server/twinkle/processor.py index ed2e27ef..912bcc29 100644 --- a/src/twinkle/server/twinkle/processor.py +++ b/src/twinkle/server/twinkle/processor.py @@ -68,7 +68,6 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, device_group: D self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.resource_dict = {} self.resource_records: Dict[str, int] = {} - self.resource_client_ids = Dict[str, str] = {} self.hb_thread = threading.Thread(target=self.countdown, daemon=True) self.hb_thread.start() self.state: ServerStateProxy = get_server_state() @@ -84,7 +83,6 @@ def countdown(self): if self.resource_records[key] > self.COUNT_DOWN: self.resource_records.pop(key, None) self.resource_dict.pop(key, None) - self.remove_multiplexed_adapter(key) if key in self.key_token_dict: self.handle_processor_count(self.key_token_dict.pop(key), False) @@ -92,13 +90,6 @@ def countdown(self): async def get_multiplexed_adapter(self, request_id: str): return request_id - def remove_multiplexed_adapter(self, processor_id: str): - request_id = self.resource_client_ids.pop(processor_id, None) - if request_id is None: - return - if hasattr(self, '_serve_multiplexed_models'): - self._serve_multiplexed_models.pop(request_id, None) - def assert_processor_exists(self, processor_id: str): assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found' @@ -147,7 +138,6 @@ async def create(self, request: Request, body: CreateRequest): remote_group=self.device_group.name, device_mesh=self.device_mesh, instance_id=processor_id, **_kwargs) self.resource_dict[processor_id] = processor self.resource_records[processor_id] = 0 - self.resource_client_ids[processor_id] = request.state.request_id await self.get_multiplexed_adapter(request.state.request_id) return {'processor_id': 'pid:' + processor_id} diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 1e93756c..04e56922 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -74,7 +74,7 @@ def _init_adapter_manager( self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None, request_id: str | None = None) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None: """Register a new adapter for lifecycle tracking. Args: @@ -82,7 +82,6 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str | None token: User token that owns this adapter. session_id: Optional session ID to associate with this adapter. 
If provided, adapter will expire when the session expires. - request_id: The client request_id from `serve_multiplexed_model_id` Raises: RuntimeError: If adapter limit is exceeded for this token. @@ -101,7 +100,6 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str | None 'inactivity_counter': 0, 'state': {}, 'expiring': False, - 'request_id': request_id, } logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...' + (f' (session: {session_id})' if session_id else '')) diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index 0e5a71d1..d419818a 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -29,7 +29,8 @@ async def verify_request_token(request: Request, call_next): request_id = request.headers.get('serve_multiplexed_model_id') if not request_id: return JSONResponse( - status_code=400, content={'detail': 'Missing serve_multiplexed_model_id header, required for sticky session'}) + status_code=400, + content={'detail': 'Missing serve_multiplexed_model_id header, required for sticky session'}) request.state.request_id = request_id request.state.token = token response = await call_next(request) From 96112a8da0ff2e4820ef034b5cc47feab929b311 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Feb 2026 17:13:33 +0800 Subject: [PATCH 3/9] add code doc --- src/twinkle/server/twinkle/model.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 07b4bb0e..6eb107da 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -189,6 +189,17 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) async def get_multiplexed_adapter(self, request_id: str): + """ + Reference docs: + 1. https://docs.ray.io/en/latest/serve/model-multiplexing.html + 2. https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html + 3. https://github.com/ray-project/ray/pull/56855/changes + Args: + request_id: + + Returns: + + """ return request_id def _on_adapter_expired(self, adapter_name: str) -> None: From 946810ad2ef167ef8d2b7add13b4f49d183a89c9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Feb 2026 19:03:44 +0800 Subject: [PATCH 4/9] wip --- .../client/tinker/megatron/server_config.yaml | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index fe9ea0d6..522f9f55 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -33,26 +33,25 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
- name: sampler-Qwen3-30B-A3B-Instruct-2507 - route_prefix: /api/v1/sampler/Qwen/Qwen3-30B-A3B-Instruct-2507 + route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct import_path: sampler args: - model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" # ModelScope model identifier + model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier nproc_per_node: 4 # Number of GPU processes per node sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) engine_args: # vLLM engine-specific settings max_model_len: 16000 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.85 # Fraction of GPU memory to use (0.0-1.0) + gpu_memory_utilization: 0.7 # Fraction of GPU memory to use (0.0-1.0) enable_lora: true # Allow loading LoRA adapters during inference max_loras: 5 # Max allowed loras working on vLLM at the same time device_group: # Logical device group for the sampler name: sampler gpus_per_worker: 1 - ranks: [0,1,2,3] # GPU rank indices to use + ranks: 4 # GPU rank indices to use device_type: cuda device_mesh: device_type: cuda @@ -63,30 +62,30 @@ applications: deployments: - name: SamplerManagement autoscaling_config: - min_replicas: 1 - max_replicas: 1 + min_replicas: 2 + max_replicas: 2 target_ongoing_requests: 16 ray_actor_options: num_cpus: 0.1 runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" + DEVICE_COUNT_PER_PHYSICAL_NODE: "16" # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. - name: models-Qwen3-30B-A3B-Instruct-2507 - route_prefix: /api/v1/model/Qwen/Qwen3-30B-A3B-Instruct-2507 + route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct import_path: model args: use_megatron: true # Use HuggingFace Transformers backend - model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" # ModelScope model identifier + model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier max_length: 16000 # model max length max_loras: 5 # model max loras nproc_per_node: 4 # Number of GPU processes per node device_group: name: model - ranks: [4,5,6,7] # GPU rank indices + ranks: 4 # GPU rank indices device_type: cuda device_mesh: device_type: cuda @@ -103,12 +102,12 @@ applications: deployments: - name: ModelManagement autoscaling_config: - min_replicas: 1 - max_replicas: 1 + min_replicas: 2 + max_replicas: 2 target_ongoing_requests: 8 ray_actor_options: num_cpus: 0.1 runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" + DEVICE_COUNT_PER_PHYSICAL_NODE: "16" From 95d474e96b8300d424007bd7d0e815ccdbb7c2b2 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Feb 2026 19:04:53 +0800 Subject: [PATCH 5/9] wip --- src/twinkle/infra/_ray/resource_manager.py | 13 +++++++++++++ src/twinkle/server/tinker/model.py | 6 ++++-- src/twinkle/server/tinker/sampler.py | 3 +++ src/twinkle/server/twinkle/model.py | 6 ++++-- src/twinkle/server/twinkle/sampler.py | 6 ++++-- 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py index 817cd793..fc657fc3 100644 --- a/src/twinkle/infra/_ray/resource_manager.py +++ b/src/twinkle/infra/_ray/resource_manager.py @@ -137,6 +137,19 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De if self.node_ranks.count(0) > 1: self.node_ranks = list(range(len(self.placement_groups))) + self.visible_devices = [] + + @ray.remote + def 
get_visible_devices(): + return os.environ.get(Platform.get_platform(group.device_type).visible_device_env()) + + if self.placement_groups: + self.visible_devices = ray.get([ + get_visible_devices.options(placement_group=pg).remote() for pg in self.placement_groups + ]) + + breakpoint() + self.node2pg: Dict[int, PlacementGroup] = {} # Map actual node indices to placement groups # For GPU/NPU groups, node indices start from self.min_node_idx diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index e0075c6a..c25c9ab0 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -100,15 +100,17 @@ def __init__(self, else: self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.use_megatron = use_megatron + replica_context = serve.get_replica_context() + replica_id = replica_context.replica_id.unique_id # Initialize model immediately - choose backend based on use_megatron if use_megatron: from .common.megatron_model import TwinkleCompatMegatronModel self.model = TwinkleCompatMegatronModel( - model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs) + model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs) else: from .common.transformers_model import TwinkleCompatTransformersModel self.model = TwinkleCompatTransformersModel( - model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs) + model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs) self.base_model = model_id self.state: ServerStateProxy = get_server_state() diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py index bf4108c9..8ab6fd91 100644 --- a/src/twinkle/server/tinker/sampler.py +++ b/src/twinkle/server/tinker/sampler.py @@ -102,6 +102,8 @@ def __init__(self, else: self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.sampler_type = sampler_type + replica_context = serve.get_replica_context() + replica_id = replica_context.replica_id.unique_id # Initialize sampler based on type if sampler_type == 'vllm': @@ -112,6 +114,7 @@ def __init__(self, engine_args=sampler_kwargs, device_mesh=self.device_mesh, remote_group=self.device_group.name, + instance_id=replica_id, **{ k: v for k, v in kwargs.items() if k not in ['engine_args'] diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 6eb107da..d510569b 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -171,14 +171,16 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes self.device_mesh = DeviceMesh(**device_mesh) else: self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + replica_context = serve.get_replica_context() + replica_id = replica_context.replica_id.unique_id if use_megatron: from twinkle.model import MultiLoraMegatronModel self.model = MultiLoraMegatronModel( - model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs) + model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs) else: from twinkle.model import MultiLoraTransformersModel self.model = MultiLoraTransformersModel( - model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs) + model_id=model_id, device_mesh=self.device_mesh, 
remote_group=self.device_group.name, instance_id=replica_id, **kwargs) # Initialize state before adapter manager (mixin needs self.state) self.state: ServerStateProxy = get_server_state() diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py index 857c53f6..bf755ee5 100644 --- a/src/twinkle/server/twinkle/sampler.py +++ b/src/twinkle/server/twinkle/sampler.py @@ -152,7 +152,8 @@ def __init__(self, else: self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.sampler_type = sampler_type - + replica_context = serve.get_replica_context() + replica_id = replica_context.replica_id.unique_id # Initialize sampler based on type if sampler_type == 'vllm': from twinkle.sampler import vLLMSampler @@ -162,6 +163,7 @@ def __init__(self, engine_args=sampler_kwargs, device_mesh=self.device_mesh, remote_group=self.device_group.name, + instance_id=replica_id, **{ k: v for k, v in kwargs.items() if k not in ['engine_args'] @@ -169,7 +171,7 @@ def __init__(self, else: from twinkle.sampler import TorchSampler self.sampler = TorchSampler( - model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs) + model_id=model_id, device_mesh=self.device_mesh, instance_id=replica_id, remote_group=self.device_group.name, **kwargs) # Initialize state and adapter manager self.state: ServerStateProxy = get_server_state() From 6fee00d8c37335f84e4ad5e6be79cfd024eae6ab Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 15 Feb 2026 19:17:45 +0800 Subject: [PATCH 6/9] fix --- .../client/tinker/megatron/server_config.yaml | 2 - .../tinker/megatron/server_config_7b.yaml | 2 - .../tinker/transformer/server_config.yaml | 2 - .../twinkle/transformer/server_config.yaml | 3 -- .../Usage Guide/Server and Client/Server.md | 7 ---- .../\346\234\215\345\212\241\347\253\257.md" | 7 ---- src/twinkle/infra/_ray/ray_helper.py | 13 ------- src/twinkle/infra/_ray/resource_manager.py | 38 ++++++++++++++----- 8 files changed, 28 insertions(+), 46 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 522f9f55..d13967f2 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -70,7 +70,6 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "16" # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. @@ -110,4 +109,3 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "16" diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml index cad014c9..cdac55f7 100644 --- a/cookbook/client/tinker/megatron/server_config_7b.yaml +++ b/cookbook/client/tinker/megatron/server_config_7b.yaml @@ -67,7 +67,6 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
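The `instance_id` plumbing introduced above reads the replica's identity from Ray Serve's replica context so that each autoscaled replica can namespace its own workers. A minimal sketch of that lookup, assuming a Ray version in which `ReplicaContext.replica_id` exposes `unique_id` (as used in the patch); older releases expose `replica_tag` instead:

```python
from ray import serve


@serve.deployment
class WorkerGroup:
    def __init__(self) -> None:
        # Each replica sees its own context, so unique_id can be used to namespace
        # per-replica resources such as process groups, cache dirs, or adapters.
        ctx = serve.get_replica_context()
        self.instance_id = ctx.replica_id.unique_id
        self.app_name = ctx.app_name

    def describe(self) -> dict:
        return {"app": self.app_name, "replica": self.instance_id}


app = WorkerGroup.bind()
```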
@@ -104,4 +103,3 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml index 00e57387..20d25f52 100644 --- a/cookbook/client/tinker/transformer/server_config.yaml +++ b/cookbook/client/tinker/transformer/server_config.yaml @@ -65,7 +65,6 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -102,4 +101,3 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml index 93fe8592..787f0a0b 100644 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ b/cookbook/client/twinkle/transformer/server_config.yaml @@ -61,7 +61,6 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 3. Processor Service - Handles data preprocessing on CPU # Runs tokenization, template application, and other CPU-bound tasks. @@ -90,7 +89,6 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 4. Sampler Service - Handles text generation inference # Uses vLLM for efficient batched generation with optional LoRA adapters. @@ -125,4 +123,3 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md index ec7b4b42..302a5875 100644 --- a/docs/source_en/Usage Guide/Server and Client/Server.md +++ b/docs/source_en/Usage Guide/Server and Client/Server.md @@ -55,12 +55,9 @@ This configuration starts 3 nodes: Before starting the Server, you need to set the following environment variables: ```bash -export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Specify the total number of GPUs on each physical machine export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code (security consideration) ``` -> **Important Note**: `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to the actual number of physical GPUs on the machine, which is crucial for correctly parsing the `ranks` configuration. - ### Node Rank in YAML Configuration In the YAML configuration file, **each component needs to occupy a separate Node**. @@ -117,7 +114,6 @@ applications: **Important notes:** - The `ranks` configuration uses **physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine - The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` instead of the original `mesh` and `mesh_dim_names` -- The environment variable `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to inform the system of the total number of physical GPUs on each machine - Different components will be automatically assigned to different Nodes - Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`) @@ -393,7 +389,6 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine # 3. 
Sampler service (optional, for inference sampling) - name: sampler-Qwen2.5-0.5B-Instruct @@ -425,7 +420,6 @@ applications: num_gpus: 1 # Sampler needs independent GPU runtime_env: env_vars: - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine ``` ## Configuration Item Description @@ -471,6 +465,5 @@ device_mesh: **Environment variables:** ```bash -export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Total number of GPUs on each physical machine (must be set) export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code ``` diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" index ab7a2436..a09b81e2 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" @@ -55,12 +55,9 @@ ray start --address=10.28.252.9:6379 --num-gpus=0 在启动 Server 之前,需要设置以下环境变量: ```bash -export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 指定每台物理机上的 GPU 总数 export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码(安全考虑) ``` -> **重要提示**:`DEVICE_COUNT_PER_PHYSICAL_NODE` 必须设置为机器上实际的物理 GPU 数量,这对于正确解析 `ranks` 配置至关重要。 - ### YAML 配置中的 Node Rank 在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**。 @@ -117,7 +114,6 @@ applications: **重要提示:** - `ranks` 配置使用**物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 - `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数替代原来的 `mesh` 和 `mesh_dim_names` -- 必须设置环境变量 `DEVICE_COUNT_PER_PHYSICAL_NODE` 来告知系统每台机器的物理 GPU 总数 - 不同组件会自动分配到不同的 Node 上 - Ray 会根据资源需求(`ray_actor_options` 中的 `num_gpus`、`num_cpus`)自动调度到合适的 Node @@ -336,7 +332,6 @@ applications: num_cpus: 0.1 runtime_env: env_vars: - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数 # 3. 
Sampler 服务(可选,用于推理采样) - name: sampler-Qwen2.5-0.5B-Instruct @@ -368,7 +363,6 @@ applications: num_gpus: 1 # Sampler 需要独立 GPU runtime_env: env_vars: - DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数 ``` ## 配置项说明 @@ -414,6 +408,5 @@ device_mesh: **环境变量:** ```bash -export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 每台物理机上的 GPU 总数(必须设置) export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码 ``` diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py index 0a03442c..d82fb7bb 100644 --- a/src/twinkle/infra/_ray/ray_helper.py +++ b/src/twinkle/infra/_ray/ray_helper.py @@ -229,19 +229,6 @@ def has_ref(args, kwargs) -> bool: return True return False - @staticmethod - def _noset_env(): - return { - 'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1', - 'RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES': '1', - 'RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES': '1', - 'RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES': '1', - 'RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES': '1', - 'RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES': '1', - 'RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS': '1', - 'RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR': '1', - } - @staticmethod def create_workers(worker_cls: Type[T], group: str, diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py index fc657fc3..2d4c1dec 100644 --- a/src/twinkle/infra/_ray/resource_manager.py +++ b/src/twinkle/infra/_ray/resource_manager.py @@ -145,10 +145,19 @@ def get_visible_devices(): if self.placement_groups: self.visible_devices = ray.get([ - get_visible_devices.options(placement_group=pg).remote() for pg in self.placement_groups + get_visible_devices.options(placement_group=pg, runtime_env={ + "env_vars": self.noset_env() + }).remote() for pg in self.placement_groups ]) - breakpoint() + visible_devices = [] + for visible_device in self.visible_devices: + if visible_device: + visible_device = int(visible_device.split(',')[0]) + else: + visible_device = 0 + visible_devices.append(visible_device) + self.visible_devices = visible_devices self.node2pg: Dict[int, PlacementGroup] = {} # Map actual node indices to placement groups @@ -164,12 +173,8 @@ def get_visible_devices(): self.device_groups = {} ray_address = str(ray.get_runtime_context().gcs_address) - if 'DEVICE_COUNT_PER_PHYSICAL_NODE' in os.environ: - # Sometimes, multiply nodes are in one physical node, there may be error in `gpu_rank` - device_per_node = int(os.environ['DEVICE_COUNT_PER_PHYSICAL_NODE']) - else: - device_per_node = nproc_per_node - for group in groups: + assert len(groups) == len(visible_devices) + for group, visible_start_device in zip(groups, self.visible_devices): if group.device_type != 'CPU': ranks = group.ranks gpus_per_worker = getattr(group, 'gpus_per_worker', 1) @@ -191,7 +196,7 @@ def get_visible_devices(): # All GPUs for a worker should be on the same node node_ranks = [r // nproc_per_node for r in worker_ranks] - gpu_ranks_local = [r % device_per_node for r in worker_ranks] + gpu_ranks_local = [r % nproc_per_node + visible_start_device for r in worker_ranks] if len(set(node_ranks)) > 1: raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. 
" @@ -206,7 +211,7 @@ def get_visible_devices(): else: for alloc_rank in normalized_ranks: node_rank = alloc_rank // nproc_per_node - gpu_rank = alloc_rank % device_per_node + gpu_rank = alloc_rank % nproc_per_node + visible_start_device local_device_groups.append( dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address)) @@ -234,6 +239,19 @@ def get_visible_devices(): logger.info(f'node_ranks: {self.node_ranks}') logger.info(f'node2pg keys: {list(self.node2pg.keys())}') + @staticmethod + def noset_env(): + return { + 'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1', + 'RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES': '1', + 'RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES': '1', + 'RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES': '1', + 'RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES': '1', + 'RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES': '1', + 'RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS': '1', + 'RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR': '1', + } + def get_config(self, group: str): for config in self.group_configs: if config.name == group: From 663ac6743a5e9a6fc7ee615d2e10311467ed54e1 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 16 Feb 2026 00:53:17 +0800 Subject: [PATCH 7/9] wip --- .../client/tinker/megatron/server_config.yaml | 1 - cookbook/client/tinker/self_congnition.py | 16 +++-- src/twinkle/infra/_ray/ray_helper.py | 2 +- src/twinkle/infra/_ray/resource_manager.py | 10 +-- src/twinkle/server/tinker/model.py | 15 +++-- src/twinkle/server/tinker/server.py | 61 +++++++++++++++---- 6 files changed, 74 insertions(+), 31 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index d13967f2..0d87f0a4 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -89,7 +89,6 @@ applications: device_mesh: device_type: cuda dp_size: 4 - ep_size: 2 queue_config: rps_limit: 20 # Max requests per second diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py index 9f0fba9b..bae64558 100644 --- a/cookbook/client/tinker/self_congnition.py +++ b/cookbook/client/tinker/self_congnition.py @@ -19,7 +19,7 @@ from twinkle.server.tinker.common import input_feature_to_datum # The base model to fine-tune / evaluate -base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507' +base_model = 'Qwen/Qwen2.5-7B-Instruct' def train(): @@ -44,7 +44,7 @@ def train(): # Connect to the Twinkle server running locally service_client = init_tinker_compat_client( - base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_TOKEN')) + base_url='localhost:9000', api_key=os.environ.get('MODELSCOPE_TOKEN')) # Create a LoRA training client for the base model (rank=16 for the LoRA adapter) training_client = service_client.create_lora_training_client(base_model=base_model, rank=16) @@ -68,14 +68,12 @@ def train(): optim_result = optim_future.result() # Compute weighted average log-loss per token for monitoring - logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) - weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum]) - print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}') + print(f'Metric: {optim_result}') # Save a checkpoint after each epoch - save_future = training_client.save_state(f'twinkle-lora-{epoch}') - save_result = save_future.result() - print(f'Saved checkpoint to 
{save_result.path}') + #save_future = training_client.save_state(f'twinkle-lora-{epoch}') + #save_result = save_future.result() + #print(f'Saved checkpoint to {save_result.path}') def eval(): @@ -85,7 +83,7 @@ def eval(): weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2' # Connect to the server and create a sampling client with the trained weights - service_client = init_tinker_compat_client(base_url='http://localhost:8000') + service_client = init_tinker_compat_client(base_url='http://localhost:9000') sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model) # Step 2: Prepare the chat prompt diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py index d82fb7bb..f0a4011b 100644 --- a/src/twinkle/infra/_ray/ray_helper.py +++ b/src/twinkle/infra/_ray/ray_helper.py @@ -307,7 +307,7 @@ def create_workers(worker_cls: Type[T], # Prevent Ray from overriding CUDA_VISIBLE_DEVICES set in runtime_env # This is critical for multi-GPU workers (gpus_per_worker > 1) - env_vars.update(RayHelper._noset_env()) + env_vars.update(ResourceManager.noset_env()) runtime_env = RuntimeEnv(env_vars=env_vars) diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py index 2d4c1dec..cf001f98 100644 --- a/src/twinkle/infra/_ray/resource_manager.py +++ b/src/twinkle/infra/_ray/resource_manager.py @@ -153,9 +153,9 @@ def get_visible_devices(): visible_devices = [] for visible_device in self.visible_devices: if visible_device: - visible_device = int(visible_device.split(',')[0]) + visible_device = [int(device) for device in visible_device.split(',')] else: - visible_device = 0 + visible_device = list(range(nproc_per_node)) visible_devices.append(visible_device) self.visible_devices = visible_devices @@ -174,7 +174,7 @@ def get_visible_devices(): self.device_groups = {} ray_address = str(ray.get_runtime_context().gcs_address) assert len(groups) == len(visible_devices) - for group, visible_start_device in zip(groups, self.visible_devices): + for group, visible_device_list in zip(groups, self.visible_devices): if group.device_type != 'CPU': ranks = group.ranks gpus_per_worker = getattr(group, 'gpus_per_worker', 1) @@ -196,7 +196,7 @@ def get_visible_devices(): # All GPUs for a worker should be on the same node node_ranks = [r // nproc_per_node for r in worker_ranks] - gpu_ranks_local = [r % nproc_per_node + visible_start_device for r in worker_ranks] + gpu_ranks_local = [visible_device_list[r % nproc_per_node] for r in worker_ranks] if len(set(node_ranks)) > 1: raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. " @@ -211,7 +211,7 @@ def get_visible_devices(): else: for alloc_rank in normalized_ranks: node_rank = alloc_rank // nproc_per_node - gpu_rank = alloc_rank % nproc_per_node + visible_start_device + gpu_rank = visible_device_list[alloc_rank % nproc_per_node] local_device_groups.append( dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address)) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py index c25c9ab0..e1f5108c 100644 --- a/src/twinkle/server/tinker/model.py +++ b/src/twinkle/server/tinker/model.py @@ -9,7 +9,6 @@ 3. Checkpoint management (save/load weights) 4. 
Multi-user support with token-based isolation """ -import os import traceback from fastapi import FastAPI, Request from peft import LoraConfig @@ -56,6 +55,7 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ + import ray app = FastAPI() @app.middleware('http') @@ -172,7 +172,10 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) - Returns: UntypedAPIFuture wrapping CreateModelResponse with model_id """ + if isinstance(body, dict): + body = types.CreateModelRequest(**body) # Register a new model_id for each create_model call + await self.get_multiplexed_adapter(request.state.request_id) model_id = self.state.register_model(body.model_dump(), token=request.state.token) async def _create_adapter(): @@ -195,7 +198,7 @@ async def _create_adapter(): # Fresh adapter has no accumulated gradients. self.set_adapter_state(adapter_name, 'grad_ready', False) - await self.get_multiplexed_adapter(request.state.request_id) + logger.info(f'Create adapter: {adapter_name}, request_id: {request.state.request_id}, ray node: {ray.get_runtime_context().get_node_id()}') training_run_manager = create_training_run_manager(request.state.token) training_run_manager.save(model_id, body) @@ -346,10 +349,12 @@ async def forward_backward(self, request: Request, Returns: UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics """ - + if isinstance(body, dict): + body = types.ForwardBackwardRequest(**body) async def _do_forward_backward(): try: adapter_name = self.get_adapter_name(adapter_name=body.model_id) + logger.info(f'forward_backward: {adapter_name}, request_id: {request.state.request_id}, ray: {ray.get_runtime_context().get_node_id()}') self.assert_adapter_exists(adapter_name=adapter_name) # Touch adapter to reset inactivity counter @@ -407,10 +412,12 @@ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> ty Returns: UntypedAPIFuture wrapping OptimStepResponse """ - + if isinstance(body, dict): + body = types.OptimStepRequest(**body) async def _do_optim(): try: adapter_name = self.get_adapter_name(adapter_name=body.model_id) + logger.info(f'optim_step: {adapter_name}, request_id: {request.state.request_id}, ray: {ray.get_runtime_context().node_id}') self.assert_adapter_exists(adapter_name=adapter_name) # Disallow empty step (must have at least one forward_backward since last step) diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py index e0278e68..603d742a 100644 --- a/src/twinkle/server/tinker/server.py +++ b/src/twinkle/server/tinker/server.py @@ -12,6 +12,8 @@ from __future__ import annotations import asyncio +import dataclasses + import httpx import logging import os @@ -82,6 +84,7 @@ def __init__(self, self.client = httpx.AsyncClient(timeout=None, trust_env=False) self.route_prefix = kwargs.get('route_prefix', '/api/v1') self.supported_models = self.normalize_models(supported_models) or [ + types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'), types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), ] # Lock for ModelScope config file operations (login writes, get_user_info reads) @@ -160,23 +163,59 @@ async def _proxy_request(self, request: Request, endpoint: str, base_model: str, headers.pop('content-length', None) try: - if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': + if os.environ.get('TWINKLE_DEBUG_PROXY', '1') == '1': logger.info('proxy_to_model endpoint=%s target_url=%s serve_multiplexed_model_id=%s', endpoint, 
target_url, headers.get('serve_multiplexed_model_id')) - rp_ = await self.client.request( - method=request.method, - url=target_url, - content=body_bytes, - headers=headers, - params=request.query_params, + handle = serve.get_deployment_handle( + deployment_name="ModelManagement", + app_name="models-Qwen3-30B-A3B-Instruct-2507" ) + + def make_fake_request(original_request: Request): + """用 SimpleNamespace 模拟 Request""" + from types import SimpleNamespace + fake = SimpleNamespace() + fake.headers = dict(original_request.headers) + + fake.state = SimpleNamespace() + fake.state.request_id = headers.get('serve_multiplexed_model_id') + fake.state.token = getattr(original_request.state, 'token', None) + return fake + + fake_request = make_fake_request(request) + import json + result = await getattr(handle.options( + multiplexed_model_id=headers.get('serve_multiplexed_model_id') + ), endpoint).remote( + body=json.loads(body_bytes), + request=fake_request, + ) if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1': logger.info('proxy_to_model response status=%s body=%s', rp_.status_code, rp_.text[:200]) + + # 处理返回值 + if hasattr(result, 'model_dump'): + # Pydantic v2 + content = json.dumps(result.model_dump()) + elif hasattr(result, 'dict'): + # Pydantic v1 + content = json.dumps(result.dict()) + elif isinstance(result, dict): + content = json.dumps(result) + elif isinstance(result, (str, bytes)): + content = result + else: + content = json.dumps(result) + + # 判断是否是错误响应 + if isinstance(result, types.RequestFailedResponse): + status_code = 500 + else: + status_code = 200 return Response( - content=rp_.content, - status_code=rp_.status_code, - headers=dict(rp_.headers), - media_type=rp_.headers.get('content-type'), + content=content, + status_code=status_code, + media_type='application/json', ) except Exception as e: return Response(content=f'Proxy Error: {str(e)}', status_code=502) From a5e2aefeab4cb6bc2e7641db408ac7de2a1a3ba1 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Mon, 16 Feb 2026 09:41:03 +0800 Subject: [PATCH 8/9] wip --- .../client/tinker/megatron/server_config.yaml | 19 ++++----- cookbook/client/tinker/self_congnition.py | 13 ++++--- src/twinkle/infra/_ray/resource_manager.py | 2 +- src/twinkle/server/tinker/model.py | 39 ++++++++++++------- src/twinkle/server/tinker/server.py | 17 +++----- src/twinkle/server/twinkle/model.py | 28 +++++-------- src/twinkle/server/twinkle/processor.py | 7 +--- src/twinkle/server/twinkle/sampler.py | 6 ++- src/twinkle/server/twinkle/server.py | 4 +- src/twinkle/server/utils/adapter_manager.py | 2 +- 10 files changed, 68 insertions(+), 69 deletions(-) diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml index 0d87f0a4..74c0e717 100644 --- a/cookbook/client/tinker/megatron/server_config.yaml +++ b/cookbook/client/tinker/megatron/server_config.yaml @@ -37,15 +37,15 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
   - name: sampler-Qwen3-30B-A3B-Instruct-2507
-    route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
+    route_prefix: /api/v1/sampler/Qwen/Qwen3-30B-A3B-Instruct-2507
     import_path: sampler
     args:
-      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
+      model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507"  # ModelScope model identifier
       nproc_per_node: 4  # Number of GPU processes per node
       sampler_type: vllm  # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
       engine_args:  # vLLM engine-specific settings
         max_model_len: 16000  # Maximum sequence length the engine supports
-        gpu_memory_utilization: 0.7  # Fraction of GPU memory to use (0.0-1.0)
+        gpu_memory_utilization: 0.85  # Fraction of GPU memory to use (0.0-1.0)
         enable_lora: true  # Allow loading LoRA adapters during inference
         max_loras: 5  # Max allowed loras working on vLLM at the same time
       device_group:  # Logical device group for the sampler
@@ -62,8 +62,8 @@ applications:
     deployments:
       - name: SamplerManagement
         autoscaling_config:
-          min_replicas: 2
-          max_replicas: 2
+          min_replicas: 1
+          max_replicas: 1
           target_ongoing_requests: 16
         ray_actor_options:
           num_cpus: 0.1
@@ -74,11 +74,11 @@ applications:
   # 2. Model Service (commented out) - Would host the base model for training.
   #    Uncomment and configure if you need a training model worker.
   - name: models-Qwen3-30B-A3B-Instruct-2507
-    route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
+    route_prefix: /api/v1/model/Qwen/Qwen3-30B-A3B-Instruct-2507
     import_path: model
     args:
       use_megatron: true  # Use the Megatron backend (instead of HuggingFace Transformers)
-      model_id: "ms://Qwen/Qwen2.5-7B-Instruct"  # ModelScope model identifier
+      model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507"  # ModelScope model identifier
       max_length: 16000  # model max length
       max_loras: 5  # model max loras
       nproc_per_node: 4  # Number of GPU processes per node
@@ -89,6 +89,7 @@ applications:
       device_mesh:
        device_type: cuda
        dp_size: 4
+        ep_size: 2

       queue_config:
         rps_limit: 20  # Max requests per second
@@ -100,8 +101,8 @@ applications:
     deployments:
       - name: ModelManagement
         autoscaling_config:
-          min_replicas: 2
-          max_replicas: 2
+          min_replicas: 1
+          max_replicas: 1
           target_ongoing_requests: 8
         ray_actor_options:
           num_cpus: 0.1
diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py
index bae64558..13a462b4 100644
--- a/cookbook/client/tinker/self_congnition.py
+++ b/cookbook/client/tinker/self_congnition.py
@@ -19,7 +19,7 @@
 from twinkle.server.tinker.common import input_feature_to_datum

 # The base model to fine-tune / evaluate
-base_model = 'Qwen/Qwen2.5-7B-Instruct'
+base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'


 def train():
@@ -68,12 +68,15 @@ def train():
         optim_result = optim_future.result()

         # Compute weighted average log-loss per token for monitoring
-        print(f'Metric: {optim_result}')
+        # logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
+        # weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
+        # print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
+        print(f'Training Metrics: {optim_result}')

         # Save a checkpoint after each epoch
-        #save_future = training_client.save_state(f'twinkle-lora-{epoch}')
-        #save_result = save_future.result()
-        #print(f'Saved checkpoint to {save_result.path}')
+        save_future = training_client.save_state(f'twinkle-lora-{epoch}')
+        save_result = save_future.result()
+        print(f'Saved checkpoint to {save_result.path}')


 def eval():
diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py
index cf001f98..7a45aa8e 100644
--- a/src/twinkle/infra/_ray/resource_manager.py
+++ b/src/twinkle/infra/_ray/resource_manager.py
@@ -146,7 +146,7 @@ def get_visible_devices():
         if self.placement_groups:
             self.visible_devices = ray.get([
                 get_visible_devices.options(placement_group=pg, runtime_env={
-                    "env_vars": self.noset_env()
+                    'env_vars': self.noset_env()
                 }).remote() for pg in self.placement_groups
             ])

diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py
index e1f5108c..7b280f0f 100644
--- a/src/twinkle/server/tinker/model.py
+++ b/src/twinkle/server/tinker/model.py
@@ -106,11 +106,19 @@ def __init__(self,
         if use_megatron:
             from .common.megatron_model import TwinkleCompatMegatronModel
             self.model = TwinkleCompatMegatronModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)
         else:
             from .common.transformers_model import TwinkleCompatTransformersModel
             self.model = TwinkleCompatTransformersModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)
         self.base_model = model_id

         self.state: ServerStateProxy = get_server_state()
@@ -120,9 +128,18 @@ def __init__(self,
         self._init_adapter_manager(**adapter_config)
         self.start_adapter_countdown()

-    @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
-    async def get_multiplexed_adapter(self, request_id: str):
-        return request_id
+    """
+    TODO: This is a caching mechanism; we must switch to sticky routing.
+    Reference docs:
+    1. [Now] https://docs.ray.io/en/latest/serve/model-multiplexing.html
+    2. https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html
+    3. https://github.com/ray-project/ray/pull/56855/changes
+    4. Call the actor directly instead of going through HTTP or the handle in server.py
+    """
+
+    # @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
+    # async def get_multiplexed_adapter(self, request_id: str):
+    #     return request_id

     def _cleanup_adapter(self, adapter_name: str) -> None:
         """Common adapter cleanup logic used by both manual unload and automatic expiration.
@@ -172,10 +189,7 @@ async def create_model(self, request: Request, body: types.CreateModelRequest) -
         Returns:
             UntypedAPIFuture wrapping CreateModelResponse with model_id
         """
-        if isinstance(body, dict):
-            body = types.CreateModelRequest(**body)
         # Register a new model_id for each create_model call
-        await self.get_multiplexed_adapter(request.state.request_id)
         model_id = self.state.register_model(body.model_dump(), token=request.state.token)

         async def _create_adapter():
@@ -198,7 +212,6 @@ async def _create_adapter():
                 # Fresh adapter has no accumulated gradients.
                 self.set_adapter_state(adapter_name, 'grad_ready', False)
-                logger.info(f'Create adapter: {adapter_name}, request_id: {request.state.request_id}, ray node: {ray.get_runtime_context().get_node_id()}')
                 training_run_manager = create_training_run_manager(request.state.token)
                 training_run_manager.save(model_id, body)

@@ -349,12 +362,10 @@ async def forward_backward(self, request: Request,
         Returns:
             UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics
         """
-        if isinstance(body, dict):
-            body = types.ForwardBackwardRequest(**body)
+
         async def _do_forward_backward():
             try:
                 adapter_name = self.get_adapter_name(adapter_name=body.model_id)
-                logger.info(f'forward_backward: {adapter_name}, request_id: {request.state.request_id}, ray: {ray.get_runtime_context().get_node_id()}')
                 self.assert_adapter_exists(adapter_name=adapter_name)

                 # Touch adapter to reset inactivity counter
@@ -412,12 +423,10 @@ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> ty
         Returns:
             UntypedAPIFuture wrapping OptimStepResponse
         """
-        if isinstance(body, dict):
-            body = types.OptimStepRequest(**body)
+
         async def _do_optim():
             try:
                 adapter_name = self.get_adapter_name(adapter_name=body.model_id)
-                logger.info(f'optim_step: {adapter_name}, request_id: {request.state.request_id}, ray: {ray.get_runtime_context().node_id}')
                 self.assert_adapter_exists(adapter_name=adapter_name)

                 # Disallow empty step (must have at least one forward_backward since last step)
diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py
index 603d742a..3722753b 100644
--- a/src/twinkle/server/tinker/server.py
+++ b/src/twinkle/server/tinker/server.py
@@ -13,7 +13,6 @@
 import asyncio
 import dataclasses
-
 import httpx
 import logging
 import os
@@ -84,7 +83,6 @@ def __init__(self,
         self.client = httpx.AsyncClient(timeout=None, trust_env=False)
         self.route_prefix = kwargs.get('route_prefix', '/api/v1')
         self.supported_models = self.normalize_models(supported_models) or [
-            types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'),
             types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'),
         ]
         # Lock for ModelScope config file operations (login writes, get_user_info reads)
@@ -167,9 +165,7 @@ async def _proxy_request(self, request: Request, endpoint: str, base_model: str,
                 logger.info('proxy_to_model endpoint=%s target_url=%s serve_multiplexed_model_id=%s', endpoint,
                             target_url, headers.get('serve_multiplexed_model_id'))
             handle = serve.get_deployment_handle(
-                deployment_name="ModelManagement",
-                app_name="models-Qwen3-30B-A3B-Instruct-2507"
-            )
+                deployment_name='ModelManagement', app_name='models-Qwen3-30B-A3B-Instruct-2507')

             def make_fake_request(original_request: Request):
                 """Mock the Request object using SimpleNamespace."""
@@ -184,12 +180,11 @@ def make_fake_request(original_request: Request):

             fake_request = make_fake_request(request)
             import json
-            result = await getattr(handle.options(
-                multiplexed_model_id=headers.get('serve_multiplexed_model_id')
-            ), endpoint).remote(
-                body=json.loads(body_bytes),
-                request=fake_request,
-            )
+            result = await getattr(
+                handle.options(multiplexed_model_id=headers.get('serve_multiplexed_model_id')), endpoint).remote(
+                    body=json.loads(body_bytes),
+                    request=fake_request,
+                )
             if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1':
                 logger.info('proxy_to_model response status=%s body=%s', rp_.status_code, rp_.text[:200])

diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py
index d510569b..50bbc3eb 100644
--- a/src/twinkle/server/twinkle/model.py
+++ b/src/twinkle/server/twinkle/model.py
@@ -176,11 +176,19 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes
         if use_megatron:
             from twinkle.model import MultiLoraMegatronModel
             self.model = MultiLoraMegatronModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)
         else:
             from twinkle.model import MultiLoraTransformersModel
             self.model = MultiLoraTransformersModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)

         # Initialize state before adapter manager (mixin needs self.state)
         self.state: ServerStateProxy = get_server_state()
@@ -189,21 +197,6 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes
         self._init_adapter_manager(**adapter_config)
         self.start_adapter_countdown()

-    @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
-    async def get_multiplexed_adapter(self, request_id: str):
-        """
-        Reference docs:
-        1. https://docs.ray.io/en/latest/serve/model-multiplexing.html
-        2. https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html
-        3. https://github.com/ray-project/ray/pull/56855/changes
-        Args:
-            request_id:
-
-        Returns:
-
-        """
-        return request_id
-
     def _on_adapter_expired(self, adapter_name: str) -> None:
         """Handle adapter expiration by removing it from the model.

@@ -528,7 +521,6 @@ async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):

         # Create adapter AFTER successful registration
         self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
-        await self.get_multiplexed_adapter(request.state.request_id)

         # Save training run metadata (similar to tinker's create_model)
         # Create a training run config from the adapter configuration
diff --git a/src/twinkle/server/twinkle/processor.py b/src/twinkle/server/twinkle/processor.py
index 912bcc29..cbead9b7 100644
--- a/src/twinkle/server/twinkle/processor.py
+++ b/src/twinkle/server/twinkle/processor.py
@@ -86,10 +86,6 @@ def countdown(self):
                 if key in self.key_token_dict:
                     self.handle_processor_count(self.key_token_dict.pop(key), False)

-    @serve.multiplexed(max_num_models_per_replica=100)
-    async def get_multiplexed_adapter(self, request_id: str):
-        return request_id
-
     def assert_processor_exists(self, processor_id: str):
         assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found'

@@ -109,7 +105,7 @@ def handle_processor_count(self, token: str, add: bool):
                 self.state.pop_config(user_key)

     @app.post('/create')
-    async def create(self, request: Request, body: CreateRequest):
+    def create(self, request: Request, body: CreateRequest):
         processor_type_name = body.processor_type
         class_type = body.class_type

@@ -138,7 +134,6 @@ async def create(self, request: Request, body: CreateRequest):
             remote_group=self.device_group.name, device_mesh=self.device_mesh, instance_id=processor_id, **_kwargs)
         self.resource_dict[processor_id] = processor
         self.resource_records[processor_id] = 0
-        await self.get_multiplexed_adapter(request.state.request_id)
         return {'processor_id': 'pid:' + processor_id}

     @app.post('/heartbeat')
diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py
index bf755ee5..62cb6a72 100644
--- a/src/twinkle/server/twinkle/sampler.py
+++ b/src/twinkle/server/twinkle/sampler.py
@@ -171,7 +171,11 @@ def __init__(self,
         else:
             from twinkle.sampler import TorchSampler
             self.sampler = TorchSampler(
-                model_id=model_id, device_mesh=self.device_mesh, instance_id=replica_id, remote_group=self.device_group.name, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                instance_id=replica_id,
+                remote_group=self.device_group.name,
+                **kwargs)

         # Initialize state and adapter manager
         self.state: ServerStateProxy = get_server_state()
diff --git a/src/twinkle/server/twinkle/server.py b/src/twinkle/server/twinkle/server.py
index 42d2b4b2..86857647 100644
--- a/src/twinkle/server/twinkle/server.py
+++ b/src/twinkle/server/twinkle/server.py
@@ -12,10 +12,10 @@
 """
 from __future__ import annotations

-from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi import FastAPI, HTTPException, Request
 from pydantic import BaseModel
 from ray import serve
-from typing import Any, Dict, List, Optional
+from typing import Any

 from twinkle.server.utils.state import ServerStateProxy, get_server_state
 from twinkle.server.utils.validation import get_token_from_request, verify_request_token
diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py
index 04e56922..c24ce466 100644
--- a/src/twinkle/server/utils/adapter_manager.py
+++ b/src/twinkle/server/utils/adapter_manager.py
@@ -13,7 +13,7 @@
 import threading
 import time

-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any

 if TYPE_CHECKING:
     from twinkle.server.utils.state import ServerStateProxy

From 8f7e086be9085f430848fb892ff0aab99ed9d1d4 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Mon, 16 Feb 2026 09:47:15 +0800
Subject: [PATCH 9/9] wip

---
 ROADMAP.md                          |  2 +
 src/twinkle/server/tinker/model.py  |  1 -
 src/twinkle/server/tinker/server.py | 61 +++++++----------------------
 src/twinkle/server/twinkle/model.py |  2 +-
 4 files changed, 17 insertions(+), 49 deletions(-)

diff --git a/ROADMAP.md b/ROADMAP.md
index 9294fc5d..8fa0c69e 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -65,6 +65,7 @@
 - [ ] 支持DPO对齐训练
 - [ ] 支持colocate RL训练
 - [ ] Preprocess支持batched
+- [ ] 对多replica的支持和粘滞路由

 ### 网络能力

@@ -84,5 +85,6 @@
 - [ ] Support for DPO alignment training
 - [ ] Support for colocate RL training
 - [ ] Support for batched preprocessing
+- [ ] Support for multiple replicas and sticky routing

 ### Networking Capabilities
diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py
index 7b280f0f..64bfacab 100644
--- a/src/twinkle/server/tinker/model.py
+++ b/src/twinkle/server/tinker/model.py
@@ -55,7 +55,6 @@ def build_model_app(model_id: str,
     Returns:
         Configured Ray Serve deployment bound with parameters
     """
-    import ray
     app = FastAPI()

     @app.middleware('http')
diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py
index 3722753b..1a706b45 100644
--- a/src/twinkle/server/tinker/server.py
+++ b/src/twinkle/server/tinker/server.py
@@ -161,56 +161,23 @@ async def _proxy_request(self, request: Request, endpoint: str, base_model: str,
         headers.pop('content-length', None)

         try:
-            if os.environ.get('TWINKLE_DEBUG_PROXY', '1') == '1':
-                logger.info('proxy_to_model endpoint=%s target_url=%s serve_multiplexed_model_id=%s', endpoint,
-                            target_url, headers.get('serve_multiplexed_model_id'))
-            handle = serve.get_deployment_handle(
-                deployment_name='ModelManagement', app_name='models-Qwen3-30B-A3B-Instruct-2507')
-
-            def make_fake_request(original_request: Request):
-                """Mock the Request object using SimpleNamespace."""
-                from types import SimpleNamespace
-                fake = SimpleNamespace()
-                fake.headers = dict(original_request.headers)
-
-                fake.state = SimpleNamespace()
-                fake.state.request_id = headers.get('serve_multiplexed_model_id')
-                fake.state.token = getattr(original_request.state, 'token', None)
-                return fake
-
-            fake_request = make_fake_request(request)
-            import json
-            result = await getattr(
-                handle.options(multiplexed_model_id=headers.get('serve_multiplexed_model_id')), endpoint).remote(
-                    body=json.loads(body_bytes),
-                    request=fake_request,
-                )
+            if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1':
+                logger.info('proxy_to_model endpoint=%s target_url=%s x-ray-serve-request-id=%s', endpoint,
+                            target_url, headers.get('x-ray-serve-request-id'))
+            rp_ = await self.client.request(
+                method=request.method,
+                url=target_url,
+                content=body_bytes,
+                headers=headers,
+                params=request.query_params,
+            )
             if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1':
                 logger.info('proxy_to_model response status=%s body=%s', rp_.status_code, rp_.text[:200])
-
-            # Handle the returned value
-            if hasattr(result, 'model_dump'):
-                # Pydantic v2
-                content = json.dumps(result.model_dump())
-            elif hasattr(result, 'dict'):
-                # Pydantic v1
-                content = json.dumps(result.dict())
-            elif isinstance(result, dict):
-                content = json.dumps(result)
-            elif isinstance(result, (str, bytes)):
-                content = result
-            else:
-                content = json.dumps(result)
-
-            # Determine whether the response is an error
-            if isinstance(result, types.RequestFailedResponse):
-                status_code = 500
-            else:
-                status_code = 200
             return Response(
-                content=content,
-                status_code=status_code,
-                media_type='application/json',
+                content=rp_.content,
+                status_code=rp_.status_code,
+                headers=dict(rp_.headers),
+                media_type=rp_.headers.get('content-type'),
             )
         except Exception as e:
             return Response(content=f'Proxy Error: {str(e)}', status_code=502)
diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py
index 50bbc3eb..1fcf6f8a 100644
--- a/src/twinkle/server/twinkle/model.py
+++ b/src/twinkle/server/twinkle/model.py
@@ -491,7 +491,7 @@ def upload_to_hub(self, request: Request, body: UploadToHubRequest):
         return {'result': body.hub_model_id}

     @app.post('/add_adapter_to_model')
-    async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):
+    def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):
         """
         Add a new adapter to the model.
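
Background on the sticky-routing TODO in src/twinkle/server/tinker/model.py: the mechanism that PATCH 8 comments out and PATCH 9 reverts is Ray Serve model multiplexing, where a request carrying the serve_multiplexed_model_id header (or a handle call with multiplexed_model_id set) is preferentially routed to a replica that already has that model -- here, a LoRA adapter -- loaded. The snippet below is a minimal, hypothetical sketch of that pattern, not code from this repository; it assumes only documented Ray Serve APIs (@serve.multiplexed, serve.get_multiplexed_model_id, serve.get_deployment_handle, handle.options), and the names AdapterReplica, load_adapter, call_with_sticky_adapter, and 'adapter-demo' are illustrative placeholders.

from ray import serve


@serve.deployment
class AdapterReplica:
    """Illustrative replica that keeps a bounded set of adapters resident."""

    @serve.multiplexed(max_num_models_per_replica=5)
    async def load_adapter(self, adapter_id: str):
        # Invoked at most once per adapter per replica; Ray Serve evicts the
        # least-recently-used adapter once the per-replica limit is reached.
        return {'adapter_id': adapter_id}

    async def __call__(self, payload: dict) -> dict:
        # Ray Serve fills this in from the serve_multiplexed_model_id header
        # (HTTP) or from handle.options(multiplexed_model_id=...) (handle call)
        # and routes to a replica that already holds the adapter when possible.
        adapter_id = serve.get_multiplexed_model_id()
        adapter = await self.load_adapter(adapter_id)
        return {'routed_to': adapter['adapter_id'], 'payload': payload}


async def call_with_sticky_adapter(adapter_id: str, payload: dict) -> dict:
    # Caller side, mirroring what _proxy_request experimented with before the
    # revert: pin the call to a replica owning adapter_id instead of plain HTTP.
    handle = serve.get_deployment_handle('AdapterReplica', app_name='adapter-demo')
    return await handle.options(multiplexed_model_id=adapter_id).remote(payload)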