diff --git a/ROADMAP.md b/ROADMAP.md
index 9294fc5d..8fa0c69e 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -65,6 +65,7 @@
 - [ ] 支持DPO对齐训练
 - [ ] 支持colocate RL训练
 - [ ] Preprocess支持batched
+- [ ] 对多replica的支持和粘滞路由

 ### 网络能力

@@ -84,5 +85,6 @@
 - [ ] Support for DPO alignment training
 - [ ] Support for colocate RL training
 - [ ] Support for batched preprocessing
+- [ ] Support for multiple replicas and sticky routing

 ### Networking Capabilities
diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml
index fe9ea0d6..74c0e717 100644
--- a/cookbook/client/tinker/megatron/server_config.yaml
+++ b/cookbook/client/tinker/megatron/server_config.yaml
@@ -33,7 +33,6 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   #    Used for generating text from the model (e.g., evaluating LoRA results).
@@ -52,7 +51,7 @@ applications:
       device_group:  # Logical device group for the sampler
         name: sampler
         gpus_per_worker: 1
-        ranks: [0,1,2,3]  # GPU rank indices to use
+        ranks: 4  # Number of GPUs to use
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -71,7 +70,6 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 2. Model Service (commented out) - Would host the base model for training.
   #    Uncomment and configure if you need a training model worker.
@@ -86,7 +84,7 @@ applications:
       nproc_per_node: 4  # Number of GPU processes per node
       device_group:
         name: model
-        ranks: [4,5,6,7]  # GPU rank indices
+        ranks: 4  # Number of GPUs
         device_type: cuda
       device_mesh:
         device_type: cuda
@@ -111,4 +109,3 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml
index cad014c9..cdac55f7 100644
--- a/cookbook/client/tinker/megatron/server_config_7b.yaml
+++ b/cookbook/client/tinker/megatron/server_config_7b.yaml
@@ -67,7 +67,6 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   #    Used for generating text from the model (e.g., evaluating LoRA results).
@@ -104,4 +103,3 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py
index 9f0fba9b..13a462b4 100644
--- a/cookbook/client/tinker/self_congnition.py
+++ b/cookbook/client/tinker/self_congnition.py
@@ -44,7 +44,7 @@ def train():
     # Connect to the Twinkle server running locally
     service_client = init_tinker_compat_client(
-        base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_TOKEN'))
+        base_url='localhost:9000', api_key=os.environ.get('MODELSCOPE_TOKEN'))

     # Create a LoRA training client for the base model (rank=16 for the LoRA adapter)
     training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)
@@ -68,9 +68,10 @@ def train():
         optim_result = optim_future.result()

         # Compute weighted average log-loss per token for monitoring
-        logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
-        weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
-        print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
+        # logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
+        # weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
+        # print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
+        print(f'Training Metrics: {optim_result}')

         # Save a checkpoint after each epoch
         save_future = training_client.save_state(f'twinkle-lora-{epoch}')
@@ -85,7 +86,7 @@ def eval():
     weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2'

     # Connect to the server and create a sampling client with the trained weights
-    service_client = init_tinker_compat_client(base_url='http://localhost:8000')
+    service_client = init_tinker_compat_client(base_url='http://localhost:9000')
     sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model)

     # Step 2: Prepare the chat prompt
diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml
index 00e57387..20d25f52 100644
--- a/cookbook/client/tinker/transformer/server_config.yaml
+++ b/cookbook/client/tinker/transformer/server_config.yaml
@@ -65,7 +65,6 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 3. Sampler Service - Runs inference / sampling using vLLM engine
   #    Used for generating text from the model (e.g., evaluating LoRA results).
@@ -102,4 +101,3 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml
index 93fe8592..787f0a0b 100644
--- a/cookbook/client/twinkle/transformer/server_config.yaml
+++ b/cookbook/client/twinkle/transformer/server_config.yaml
@@ -61,7 +61,6 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"

   # 3. Processor Service - Handles data preprocessing on CPU
   #    Runs tokenization, template application, and other CPU-bound tasks.
@@ -90,7 +89,6 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
   # 4. Sampler Service - Handles text generation inference
   #    Uses vLLM for efficient batched generation with optional LoRA adapters.
@@ -125,4 +123,3 @@ applications:
       runtime_env:
         env_vars:
           TWINKLE_TRUST_REMOTE_CODE: "0"
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md
index ec7b4b42..302a5875 100644
--- a/docs/source_en/Usage Guide/Server and Client/Server.md
+++ b/docs/source_en/Usage Guide/Server and Client/Server.md
@@ -55,12 +55,9 @@ This configuration starts 3 nodes:
 Before starting the Server, you need to set the following environment variables:

 ```bash
-export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Specify the total number of GPUs on each physical machine
 export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code (security consideration)
 ```

-> **Important Note**: `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to the actual number of physical GPUs on the machine, which is crucial for correctly parsing the `ranks` configuration.
-
 ### Node Rank in YAML Configuration

 In the YAML configuration file, **each component needs to occupy a separate Node**.
@@ -117,7 +114,6 @@ applications:
 **Important notes:**
 - The `ranks` configuration uses **physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine
 - The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` instead of the original `mesh` and `mesh_dim_names`
-- The environment variable `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to inform the system of the total number of physical GPUs on each machine
 - Different components will be automatically assigned to different Nodes
 - Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`)
@@ -393,7 +389,6 @@ applications:
         num_cpus: 0.1
       runtime_env:
         env_vars:
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"  # Total number of physical GPUs on each machine

   # 3. Sampler service (optional, for inference sampling)
   - name: sampler-Qwen2.5-0.5B-Instruct
@@ -425,7 +420,6 @@ applications:
         num_gpus: 1  # Sampler needs independent GPU
       runtime_env:
         env_vars:
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"  # Total number of physical GPUs on each machine
 ```

 ## Configuration Item Description
@@ -471,6 +465,5 @@ device_mesh:
 **Environment variables:**

 ```bash
-export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Total number of GPUs on each physical machine (must be set)
 export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code
 ```
diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
index 67a6b30f..d3cf4a8f 100644
--- a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
+++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
@@ -25,7 +25,7 @@ for item in service_client.get_server_capabilities().supported_models:
 When calling `init_tinker_compat_client`, the following operations are automatically executed:

 1. **Patch Tinker SDK**: Bypass Tinker's `tinker://` prefix validation, allowing it to connect to standard HTTP addresses
-2. **Set Request Headers**: Inject necessary authentication headers such as `X-Ray-Serve-Request-Id` and `Authorization`
+2. **Set Request Headers**: Inject necessary authentication headers such as `serve_multiplexed_model_id` and `Authorization`
 3. **Return `ServiceClient`**: Returns a standard Tinker `ServiceClient` object, subsequent operations are completely identical to native Tinker

 This means that after initialization, **all existing Tinker training code can be used directly** without any modifications.
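
The documentation hunk above lists what `init_tinker_compat_client` does; a minimal usage sketch, mirroring `cookbook/client/tinker/self_congnition.py` from this change set (the port and model name are illustrative):

```python
import os

from twinkle_client import init_tinker_compat_client

# The helper patches the Tinker SDK, injects the serve_multiplexed_model_id and
# Authorization headers, and returns a standard Tinker ServiceClient.
service_client = init_tinker_compat_client(
    base_url='localhost:9000',  # 'http://' is prepended automatically when missing
    api_key=os.environ.get('MODELSCOPE_TOKEN'))

# From here on, ordinary Tinker code works unchanged.
training_client = service_client.create_lora_training_client(
    base_model='Qwen/Qwen3-30B-A3B-Instruct-2507', rank=16)
```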
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
index ef1c7e26..35b39536 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
@@ -25,7 +25,7 @@ for item in service_client.get_server_capabilities().supported_models:
 调用 `init_tinker_compat_client` 时,会自动执行以下操作:

 1. **Patch Tinker SDK**:绕过 Tinker 的 `tinker://` 前缀校验,使其可以连接到标准 HTTP 地址
-2. **设置请求头**:注入 `X-Ray-Serve-Request-Id` 和 `Authorization` 等必要的认证头
+2. **设置请求头**:注入 `serve_multiplexed_model_id` 和 `Authorization` 等必要的认证头
 3. **返回 `ServiceClient`**:返回一个标准的 Tinker `ServiceClient` 对象,后续操作与原生 Tinker 完全一致

 这意味着在初始化之后,**所有已有的 Tinker 训练代码都可以直接使用**,无需任何修改。
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
index ab7a2436..a09b81e2 100644
--- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
@@ -55,12 +55,9 @@ ray start --address=10.28.252.9:6379 --num-gpus=0
 在启动 Server 之前,需要设置以下环境变量:

 ```bash
-export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 指定每台物理机上的 GPU 总数
 export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码(安全考虑)
 ```

-> **重要提示**:`DEVICE_COUNT_PER_PHYSICAL_NODE` 必须设置为机器上实际的物理 GPU 数量,这对于正确解析 `ranks` 配置至关重要。
-
 ### YAML 配置中的 Node Rank

 在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**。
@@ -117,7 +114,6 @@ applications:
 **重要提示:**
 - `ranks` 配置使用**物理 GPU 卡号**,直接对应机器上的实际 GPU 设备
 - `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数替代原来的 `mesh` 和 `mesh_dim_names`
-- 必须设置环境变量 `DEVICE_COUNT_PER_PHYSICAL_NODE` 来告知系统每台机器的物理 GPU 总数
 - 不同组件会自动分配到不同的 Node 上
 - Ray 会根据资源需求(`ray_actor_options` 中的 `num_gpus`、`num_cpus`)自动调度到合适的 Node
@@ -336,7 +332,6 @@ applications:
         num_cpus: 0.1
       runtime_env:
         env_vars:
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"  # 每台机器的物理 GPU 总数
   # 3. Sampler 服务(可选,用于推理采样)
   - name: sampler-Qwen2.5-0.5B-Instruct
@@ -368,7 +363,6 @@ applications:
         num_gpus: 1  # Sampler 需要独立 GPU
       runtime_env:
         env_vars:
-          DEVICE_COUNT_PER_PHYSICAL_NODE: "8"  # 每台机器的物理 GPU 总数
 ```

 ## 配置项说明
@@ -414,6 +408,5 @@ device_mesh:
 **环境变量:**

 ```bash
-export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 每台物理机上的 GPU 总数(必须设置)
 export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码
 ```
diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py
index 0a03442c..f0a4011b 100644
--- a/src/twinkle/infra/_ray/ray_helper.py
+++ b/src/twinkle/infra/_ray/ray_helper.py
@@ -229,19 +229,6 @@ def has_ref(args, kwargs) -> bool:
                 return True
         return False

-    @staticmethod
-    def _noset_env():
-        return {
-            'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1',
-            'RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES': '1',
-            'RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES': '1',
-            'RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES': '1',
-            'RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES': '1',
-            'RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES': '1',
-            'RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS': '1',
-            'RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR': '1',
-        }
-
     @staticmethod
     def create_workers(worker_cls: Type[T],
                        group: str,
@@ -320,7 +307,7 @@ def create_workers(worker_cls: Type[T],
         # Prevent Ray from overriding CUDA_VISIBLE_DEVICES set in runtime_env
         # This is critical for multi-GPU workers (gpus_per_worker > 1)
-        env_vars.update(RayHelper._noset_env())
+        env_vars.update(ResourceManager.noset_env())

         runtime_env = RuntimeEnv(env_vars=env_vars)
diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py
index 817cd793..7a45aa8e 100644
--- a/src/twinkle/infra/_ray/resource_manager.py
+++ b/src/twinkle/infra/_ray/resource_manager.py
@@ -137,6 +137,28 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
         if self.node_ranks.count(0) > 1:
             self.node_ranks = list(range(len(self.placement_groups)))

+        self.visible_devices = []
+
+        @ray.remote
+        def get_visible_devices():
+            return os.environ.get(Platform.get_platform(group.device_type).visible_device_env())
+
+        if self.placement_groups:
+            self.visible_devices = ray.get([
+                get_visible_devices.options(placement_group=pg, runtime_env={
+                    'env_vars': self.noset_env()
+                }).remote() for pg in self.placement_groups
+            ])
+
+        visible_devices = []
+        for visible_device in self.visible_devices:
+            if visible_device:
+                visible_device = [int(device) for device in visible_device.split(',')]
+            else:
+                visible_device = list(range(nproc_per_node))
+            visible_devices.append(visible_device)
+        self.visible_devices = visible_devices
+
         self.node2pg: Dict[int, PlacementGroup] = {}  # Map actual node indices to placement groups
         # For GPU/NPU groups, node indices start from self.min_node_idx
@@ -151,12 +173,8 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
         self.device_groups = {}

         ray_address = str(ray.get_runtime_context().gcs_address)
-        if 'DEVICE_COUNT_PER_PHYSICAL_NODE' in os.environ:
-            # Sometimes, multiply nodes are in one physical node, there may be error in `gpu_rank`
-            device_per_node = int(os.environ['DEVICE_COUNT_PER_PHYSICAL_NODE'])
-        else:
-            device_per_node = nproc_per_node
-        for group in groups:
+        assert len(groups) == len(visible_devices)
+        for group, visible_device_list in zip(groups, self.visible_devices):
             if group.device_type != 'CPU':
                 ranks = group.ranks
                 gpus_per_worker = getattr(group, 'gpus_per_worker', 1)
@@ -178,7 +196,7 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
                     # All GPUs for a worker should be on the same node
                     node_ranks = [r // nproc_per_node for r in worker_ranks]
-                    gpu_ranks_local = [r % device_per_node for r in worker_ranks]
+                    gpu_ranks_local = [visible_device_list[r % nproc_per_node] for r in worker_ranks]

                     if len(set(node_ranks)) > 1:
                         raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. "
@@ -193,7 +211,7 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
             else:
                 for alloc_rank in normalized_ranks:
                     node_rank = alloc_rank // nproc_per_node
-                    gpu_rank = alloc_rank % device_per_node
+                    gpu_rank = visible_device_list[alloc_rank % nproc_per_node]
                     local_device_groups.append(
                         dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address))
@@ -221,6 +239,19 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
         logger.info(f'node_ranks: {self.node_ranks}')
         logger.info(f'node2pg keys: {list(self.node2pg.keys())}')

+    @staticmethod
+    def noset_env():
+        return {
+            'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1',
+            'RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES': '1',
+            'RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES': '1',
+            'RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES': '1',
+            'RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES': '1',
+            'RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES': '1',
+            'RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS': '1',
+            'RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR': '1',
+        }
+
     def get_config(self, group: str):
         for config in self.group_configs:
             if config.name == group:
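
The new `visible_devices` logic above replaces the `DEVICE_COUNT_PER_PHYSICAL_NODE` environment variable by asking each placement group which devices it can actually see. A minimal, standalone sketch of that detection pattern (CUDA only; the fallback of 8 devices is an assumption, the real code uses `nproc_per_node`):

```python
import os

import ray

# Keep Ray from rewriting the visible-device variable so the task reports
# what the node really exposes (same intent as ResourceManager.noset_env()).
NOSET_ENV = {'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1'}


@ray.remote
def get_visible_devices():
    # None means the variable is unset and every local GPU is visible.
    return os.environ.get('CUDA_VISIBLE_DEVICES')


ray.init()
raw = ray.get(get_visible_devices.options(runtime_env={'env_vars': NOSET_ENV}).remote())
devices = [int(d) for d in raw.split(',')] if raw else list(range(8))
print(devices)
```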
diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py
index 2a119162..64bfacab 100644
--- a/src/twinkle/server/tinker/model.py
+++ b/src/twinkle/server/tinker/model.py
@@ -9,7 +9,6 @@
 3. Checkpoint management (save/load weights)
 4. Multi-user support with token-based isolation
 """
-import os
 import traceback
 from fastapi import FastAPI, Request
 from peft import LoraConfig
@@ -100,15 +99,25 @@ def __init__(self,
         else:
             self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
         self.use_megatron = use_megatron
+        replica_context = serve.get_replica_context()
+        replica_id = replica_context.replica_id.unique_id
         # Initialize model immediately - choose backend based on use_megatron
         if use_megatron:
             from .common.megatron_model import TwinkleCompatMegatronModel
             self.model = TwinkleCompatMegatronModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)
         else:
             from .common.transformers_model import TwinkleCompatTransformersModel
             self.model = TwinkleCompatTransformersModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)
         self.base_model = model_id

         self.state: ServerStateProxy = get_server_state()
@@ -118,6 +127,19 @@ def __init__(self,
         self._init_adapter_manager(**adapter_config)
         self.start_adapter_countdown()

+        """
+        TODO: This is currently a cache system; it should be replaced with sticky routing.
+        Reference docs:
+        1. [Now] https://docs.ray.io/en/latest/serve/model-multiplexing.html
+        2. https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html
+        3. https://github.com/ray-project/ray/pull/56855/changes
+        4. Call the actor directly instead of going through HTTP or a handler in server.py
+        """
+
+        # @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5))
+        # async def get_multiplexed_adapter(self, request_id: str):
+        #     return request_id
+
     def _cleanup_adapter(self, adapter_name: str) -> None:
         """Common adapter cleanup logic used by both manual unload and automatic expiration.
diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py
index bf4108c9..8ab6fd91 100644
--- a/src/twinkle/server/tinker/sampler.py
+++ b/src/twinkle/server/tinker/sampler.py
@@ -102,6 +102,8 @@ def __init__(self,
         else:
             self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
         self.sampler_type = sampler_type
+        replica_context = serve.get_replica_context()
+        replica_id = replica_context.replica_id.unique_id

         # Initialize sampler based on type
         if sampler_type == 'vllm':
@@ -112,6 +114,7 @@ def __init__(self,
                 engine_args=sampler_kwargs,
                 device_mesh=self.device_mesh,
                 remote_group=self.device_group.name,
+                instance_id=replica_id,
                 **{
                     k: v
                     for k, v in kwargs.items() if k not in ['engine_args']
diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py
index 2e669f56..1a706b45 100644
--- a/src/twinkle/server/tinker/server.py
+++ b/src/twinkle/server/tinker/server.py
@@ -12,6 +12,7 @@
 from __future__ import annotations

 import asyncio
+import dataclasses
 import httpx
 import logging
 import os
@@ -82,10 +83,6 @@ def __init__(self,
         self.client = httpx.AsyncClient(timeout=None, trust_env=False)
         self.route_prefix = kwargs.get('route_prefix', '/api/v1')
         self.supported_models = self.normalize_models(supported_models) or [
-            types.SupportedModel(model_name='Qwen/Qwen2.5-0.5B-Instruct'),
-            types.SupportedModel(model_name='Qwen/Qwen2.5-3B-Instruct'),
-            types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'),
-            types.SupportedModel(model_name='Qwen/Qwen2.5-72B-Instruct'),
             types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'),
         ]
         # Lock for ModelScope config file operations (login writes, get_user_info reads)
diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py
index 1660cd10..1fcf6f8a 100644
--- a/src/twinkle/server/twinkle/model.py
+++ b/src/twinkle/server/twinkle/model.py
@@ -171,14 +171,24 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes
             self.device_mesh = DeviceMesh(**device_mesh)
         else:
             self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+        replica_context = serve.get_replica_context()
+        replica_id = replica_context.replica_id.unique_id
         if use_megatron:
             from twinkle.model import MultiLoraMegatronModel
             self.model = MultiLoraMegatronModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)
         else:
             from twinkle.model import MultiLoraTransformersModel
             self.model = MultiLoraTransformersModel(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                remote_group=self.device_group.name,
+                instance_id=replica_id,
+                **kwargs)

         # Initialize state before adapter manager (mixin needs self.state)
         self.state: ServerStateProxy = get_server_state()
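
The TODO left in `tinker/model.py` points at Ray Serve model multiplexing as the eventual sticky-routing mechanism. A minimal sketch of how that API is typically wired up (independent of the Twinkle deployments above; the deployment name and adapter payload are placeholders), assuming the client sends the `serve_multiplexed_model_id` header that this change starts injecting:

```python
from ray import serve
from starlette.requests import Request


@serve.deployment(num_replicas=2)
class AdapterHost:

    @serve.multiplexed(max_num_models_per_replica=5)
    async def get_adapter(self, adapter_id: str):
        # Load or look up per-session state here. Serve caches up to five of
        # these per replica and routes repeat ids back to the same replica.
        return {'adapter_id': adapter_id}

    async def __call__(self, request: Request) -> dict:
        adapter_id = serve.get_multiplexed_model_id()  # taken from the request header
        return await self.get_adapter(adapter_id)


app = AdapterHost.bind()
# serve.run(app) would deploy it; requests then carry
# headers={'serve_multiplexed_model_id': '<session id>'}.
```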
diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py
index 857c53f6..62cb6a72 100644
--- a/src/twinkle/server/twinkle/sampler.py
+++ b/src/twinkle/server/twinkle/sampler.py
@@ -152,7 +152,8 @@ def __init__(self,
         else:
             self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
         self.sampler_type = sampler_type
-
+        replica_context = serve.get_replica_context()
+        replica_id = replica_context.replica_id.unique_id
         # Initialize sampler based on type
         if sampler_type == 'vllm':
             from twinkle.sampler import vLLMSampler
@@ -162,6 +163,7 @@ def __init__(self,
                 engine_args=sampler_kwargs,
                 device_mesh=self.device_mesh,
                 remote_group=self.device_group.name,
+                instance_id=replica_id,
                 **{
                     k: v
                     for k, v in kwargs.items() if k not in ['engine_args']
@@ -169,7 +171,11 @@ def __init__(self,
         else:
             from twinkle.sampler import TorchSampler
             self.sampler = TorchSampler(
-                model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+                model_id=model_id,
+                device_mesh=self.device_mesh,
+                instance_id=replica_id,
+                remote_group=self.device_group.name,
+                **kwargs)

         # Initialize state and adapter manager
         self.state: ServerStateProxy = get_server_state()
diff --git a/src/twinkle/server/twinkle/server.py b/src/twinkle/server/twinkle/server.py
index 42d2b4b2..86857647 100644
--- a/src/twinkle/server/twinkle/server.py
+++ b/src/twinkle/server/twinkle/server.py
@@ -12,10 +12,10 @@
 """
 from __future__ import annotations

-from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi import FastAPI, HTTPException, Request
 from pydantic import BaseModel
 from ray import serve
-from typing import Any, Dict, List, Optional
+from typing import Any

 from twinkle.server.utils.state import ServerStateProxy, get_server_state
 from twinkle.server.utils.validation import get_token_from_request, verify_request_token
diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py
index 04e56922..c24ce466 100644
--- a/src/twinkle/server/utils/adapter_manager.py
+++ b/src/twinkle/server/utils/adapter_manager.py
@@ -13,7 +13,7 @@
 import threading
 import time

-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any

 if TYPE_CHECKING:
     from twinkle.server.utils.state import ServerStateProxy
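
The model and sampler deployments above all derive an `instance_id` from the Serve replica context; a stripped-down sketch of that pattern, with the attribute path taken verbatim from this change set:

```python
from ray import serve


@serve.deployment(num_replicas=2)
class ReplicaAware:

    def __init__(self):
        # Each replica gets a stable unique id, which the Twinkle deployments
        # forward as instance_id so per-replica worker groups do not collide.
        replica_context = serve.get_replica_context()
        self.instance_id = replica_context.replica_id.unique_id

    def __call__(self, request) -> str:
        return self.instance_id
```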
diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py
index 23539ed8..d419818a 100644
--- a/src/twinkle/server/utils/validation.py
+++ b/src/twinkle/server/utils/validation.py
@@ -11,7 +11,7 @@ async def verify_request_token(request: Request, call_next):
     This middleware:
     1. Extracts the Bearer token from Authorization header
     2. Validates the token
-    3. Extracts X-Ray-Serve-Request-Id for sticky sessions
+    3. Extracts serve_multiplexed_model_id for sticky sessions
     4. Stores token and request_id in request.state for later use

     Args:
@@ -26,10 +26,11 @@ async def verify_request_token(request: Request, call_next):
     if not is_token_valid(token):
         return JSONResponse(status_code=403, content={'detail': 'Invalid token'})

-    request_id = request.headers.get('X-Ray-Serve-Request-Id')
+    request_id = request.headers.get('serve_multiplexed_model_id')
     if not request_id:
         return JSONResponse(
-            status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'})
+            status_code=400,
+            content={'detail': 'Missing serve_multiplexed_model_id header, required for sticky session'})
     request.state.request_id = request_id
     request.state.token = token
     response = await call_next(request)
diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py
index 5a6928e9..f236f734 100644
--- a/src/twinkle_client/__init__.py
+++ b/src/twinkle_client/__init__.py
@@ -28,7 +28,7 @@ def init_tinker_compat_client(base_url: str | None = None, api_key: str | None =
         base_url = f'http://{base_url}'

     default_headers = {
-        'X-Ray-Serve-Request-Id': get_request_id(),
+        'serve_multiplexed_model_id': get_request_id(),
         'Authorization': 'Bearer ' + api_key,
         'Twinkle-Authorization': 'Bearer ' + api_key,  # For server compatibility
     } | kwargs.pop('default_headers', {})
diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py
index 522b46af..f9cafa1c 100644
--- a/src/twinkle_client/http/http_utils.py
+++ b/src/twinkle_client/http/http_utils.py
@@ -16,7 +16,7 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[
         Dictionary of headers
     """
     headers = {
-        'X-Ray-Serve-Request-Id': get_request_id(),
+        'serve_multiplexed_model_id': get_request_id(),
         'Authorization': 'Bearer ' + get_api_key(),
         'Twinkle-Authorization': 'Bearer ' + get_api_key(),  # For server compatibility
     }
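
After this change both client entry points send the same header set; a small sketch of what a raw request looks like (the endpoint path is an assumption for illustration, only the `/api/v1` route prefix appears in this change set):

```python
import os

import httpx

api_key = os.environ.get('MODELSCOPE_TOKEN', '')
headers = {
    # Doubles as Ray Serve's multiplexed model id, so every request from one
    # session can be routed to the same replica.
    'serve_multiplexed_model_id': 'example-session-id',  # normally get_request_id()
    'Authorization': 'Bearer ' + api_key,
    'Twinkle-Authorization': 'Bearer ' + api_key,  # for server compatibility
}

# Hypothetical endpoint, shown only to illustrate how the headers are attached.
response = httpx.get('http://localhost:9000/api/v1/models', headers=headers)
print(response.status_code)
```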