diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 69c5bdf6ab2c..a5833f26bdaa 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -164,7 +164,7 @@ def __init__(self, self.mpi_session = self.args.mpi_session if self.args.parallel_config.is_multi_gpu: - if get_device_count( + if os.getenv("RAY_LOCAL_RANK") is None and get_device_count( ) < self.args.parallel_config.world_size_per_node: raise RuntimeError( f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required." diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index cf28ecd326d0..6d5b09cc12e8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1194,6 +1194,8 @@ def validate_quant_config(cls, v, info): @field_validator("gpus_per_node", mode='before') @classmethod def validate_gpus_per_node(cls, v, info): + if os.getenv("RAY_LOCAL_RANK") is not None: + return info.data.get("tensor_parallel_size") if v is None: logger.warning( f"Using default gpus_per_node: {torch.cuda.device_count()}")