From 80550c288e70aade2a6a0458995c3f0412f67a17 Mon Sep 17 00:00:00 2001 From: Mina Huai <121143971+MinaHuai@users.noreply.github.com> Date: Thu, 9 Oct 2025 19:30:39 -0700 Subject: [PATCH] rebase to the updated verl --- tensorrt_llm/llmapi/llm.py | 2 +- tensorrt_llm/llmapi/llm_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 2f16ed1fa646..75cb7e5ddc59 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -165,7 +165,7 @@ def __init__(self, self.mpi_session = self.args.mpi_session if self.args.parallel_config.is_multi_gpu: - if os.getenv("RAY_LOCAL_RANK") is None and get_device_count( + if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count( ) < self.args.parallel_config.world_size_per_node: raise RuntimeError( f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required." diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 8fbcbe5075c7..001f88f9724b 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1381,7 +1381,7 @@ def validate_dtype(cls, v, info): @field_validator("gpus_per_node", mode='before') @classmethod def validate_gpus_per_node(cls, v, info): - if os.getenv("RAY_LOCAL_RANK") is not None: + if os.getenv("RAY_LOCAL_WORLD_SIZE") is not None: return info.data.get("tensor_parallel_size") if v is None: logger.warning(