From 239adc5279c319a4f0c70e1427978a0689db46a7 Mon Sep 17 00:00:00 2001 From: Mina Huai <121143971+MinaHuai@users.noreply.github.com> Date: Mon, 21 Jul 2025 19:33:01 -0700 Subject: [PATCH] resolve ray conflict --- tensorrt_llm/llmapi/llm.py | 2 +- tensorrt_llm/llmapi/llm_args.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 69c5bdf6ab2c..a5833f26bdaa 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -164,7 +164,7 @@ def __init__(self, self.mpi_session = self.args.mpi_session if self.args.parallel_config.is_multi_gpu: - if get_device_count( + if os.getenv("RAY_LOCAL_RANK") is None and get_device_count( ) < self.args.parallel_config.world_size_per_node: raise RuntimeError( f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required." diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index cf28ecd326d0..6d5b09cc12e8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1194,6 +1194,8 @@ def validate_quant_config(cls, v, info): @field_validator("gpus_per_node", mode='before') @classmethod def validate_gpus_per_node(cls, v, info): + if os.getenv("RAY_LOCAL_RANK") is not None: + return info.data.get("tensor_parallel_size") if v is None: logger.warning( f"Using default gpus_per_node: {torch.cuda.device_count()}")