diff --git a/lmcache/v1/cache_engine.py b/lmcache/v1/cache_engine.py index 61231b2aff..3900eded69 100644 --- a/lmcache/v1/cache_engine.py +++ b/lmcache/v1/cache_engine.py @@ -1794,8 +1794,17 @@ def _Create_memory_allocator( buffer.data_ptr(), config.nixl_buffer_size, 0 ) else: - logger.info(f"Setting cuda device to {corrected_device} ") - torch.cuda.set_device(corrected_device) + logger.info(f"Setting device to {corrected_device}") + # Set device based on device type + if corrected_device.startswith("cuda"): + torch.cuda.set_device(corrected_device) + elif corrected_device.startswith("xpu"): + if not hasattr(torch, "xpu"): + raise RuntimeError( + "XPU device is not available. Please ensure PyTorch " + "is built with XPU support." + ) + torch.xpu.set_device(corrected_device) return PagedTensorMemoryAllocator( buffer, diff --git a/lmcache/v1/storage_backend/nixl_storage_backend.py b/lmcache/v1/storage_backend/nixl_storage_backend.py index 6bbc3abc49..6f7521838e 100644 --- a/lmcache/v1/storage_backend/nixl_storage_backend.py +++ b/lmcache/v1/storage_backend/nixl_storage_backend.py @@ -532,7 +532,16 @@ def initialize_allocator( base_buffer, self.buffer = _allocate_gpu_memory( config.nixl_buffer_size, corrected_device ) - torch.cuda.set_device(corrected_device) + # Set device based on device type + if corrected_device.startswith("cuda"): + torch.cuda.set_device(corrected_device) + elif corrected_device.startswith("xpu"): + if not hasattr(torch, "xpu"): + raise RuntimeError( + "XPU device is not available. Please ensure PyTorch is " + "built with XPU support." + ) + torch.xpu.set_device(corrected_device) self.base_buffer = base_buffer # Prevents early GC of the aligned tensor. self.free_pinned_buffer = False diff --git a/lmcache/v1/storage_backend/p2p_backend.py b/lmcache/v1/storage_backend/p2p_backend.py index 89a1818a8b..3139b5774e 100644 --- a/lmcache/v1/storage_backend/p2p_backend.py +++ b/lmcache/v1/storage_backend/p2p_backend.py @@ -231,6 +231,7 @@ def __init__( peer_lookup_url=self.peer_lookup_url, backends=config.nixl_backends, event_loop=loop, + device="cpu", ) self.running = asyncio.Event() diff --git a/lmcache/v1/storage_backend/pd_backend.py b/lmcache/v1/storage_backend/pd_backend.py index 3fbdec40b8..5c4517c432 100644 --- a/lmcache/v1/storage_backend/pd_backend.py +++ b/lmcache/v1/storage_backend/pd_backend.py @@ -189,6 +189,7 @@ def __init__( tp_rank=self.tp_rank, peer_init_url=peer_init_url, backends=config.nixl_backends, + device=self.pd_config.buffer_device, ) if self.pd_config.role == "sender": @@ -217,8 +218,18 @@ def initialize_allocator( config.pd_buffer_device, metadata.worker_id, ) - logger.info(f"Setting cuda device to {corrected_device} ") - torch.cuda.set_device(corrected_device) + logger.info(f"Setting device to {corrected_device}") + + # Set device based on device type + if corrected_device.startswith("cuda"): + torch.cuda.set_device(corrected_device) + elif corrected_device.startswith("xpu"): + if not hasattr(torch, "xpu"): + raise RuntimeError( + "XPU device is not available. Please ensure PyTorch is built " + "with XPU support." + ) + torch.xpu.set_device(corrected_device) paged_mem_allocator = PagedCpuGpuMemoryAllocator() paged_mem_allocator.init_gpu_memory_allocator( diff --git a/lmcache/v1/transfer_channel/nixl_channel.py b/lmcache/v1/transfer_channel/nixl_channel.py index 85744b4d67..8c0ef70e21 100644 --- a/lmcache/v1/transfer_channel/nixl_channel.py +++ b/lmcache/v1/transfer_channel/nixl_channel.py @@ -81,6 +81,10 @@ def __init__( else: backends = ["UCX"] + # Extract device from kwargs (optional, defaults to "cuda" for + # backwards compatibility) + device = kwargs.get("device", "cuda") + self.role = kwargs["role"] self.nixl_wrapper = NixlAgentWrapper( @@ -89,6 +93,7 @@ def __init__( page_size=kwargs["align_bytes"], tp_rank=kwargs["tp_rank"], backends=backends, + device=device, ) self.nixl_agent = self.nixl_wrapper.agent @@ -579,6 +584,7 @@ def __init__( page_size: int, tp_rank: int, backends: list[str], + device: str = "cuda", ): """ Initialize the NIXL agent. @@ -590,6 +596,8 @@ def __init__( the lmcache memory allocator. tp_rank (int): The tensor parallel rank. backends (list[str]): The list of backends to use. + device (str): The device type string (e.g., "cuda:0", "xpu:0"). + Defaults to "cuda" for backward compatibility. Returns: NixlWrapper: The NIXL agent. @@ -608,6 +616,21 @@ def __init__( if backends is None: backends = ["UCX"] + # Determine memory type based on device string + # device can be "cuda", "cuda:0", "xpu", "xpu:0", "cpu", etc. + if device.startswith("cuda"): + mem_type = "cuda" + elif device.startswith("xpu"): + mem_type = "xpu" + elif device.startswith("cpu"): + mem_type = "cpu" + else: + # Raise error for unsupported device types + raise ValueError( + f"Unsupported device type: {device}. " + "Supported device types are: cuda, xpu, cpu" + ) + # Create a NIXL agent nixl_agent = NixlAgent( str(uuid.uuid4()), @@ -618,8 +641,7 @@ def __init__( # The four fields are (base_addr, length, dev_id, meta_info) # https://github.com/ai-dynamo/nixl/blob/main/src/api/cpp/nixl_descriptors.h#L152 memory_desc = [(buffer_ptr, buffer_size, tp_rank, "")] - # TODO(Jiayi): remove hardcode `mem_type` - reg_descs = nixl_agent.get_reg_descs(memory_desc, mem_type="cuda") + reg_descs = nixl_agent.get_reg_descs(memory_desc, mem_type=mem_type) nixl_agent.register_memory(reg_descs) # Create xfer handlers @@ -627,8 +649,8 @@ def __init__( for base_addr in range(buffer_ptr, buffer_ptr + buffer_size, page_size): xfer_desc.append((base_addr, page_size, tp_rank)) - xfer_descs = nixl_agent.get_xfer_descs(xfer_desc, mem_type="cuda") - xfer_handler = nixl_agent.prep_xfer_dlist("", xfer_descs, mem_type="cuda") + xfer_descs = nixl_agent.get_xfer_descs(xfer_desc, mem_type=mem_type) + xfer_handler = nixl_agent.prep_xfer_dlist("", xfer_descs, mem_type=mem_type) self.agent = nixl_agent self.reg_descs = reg_descs diff --git a/lmcache/v1/transfer_channel/transfer_utils.py b/lmcache/v1/transfer_channel/transfer_utils.py index 0e8e5b9b1f..f13a9c272e 100644 --- a/lmcache/v1/transfer_channel/transfer_utils.py +++ b/lmcache/v1/transfer_channel/transfer_utils.py @@ -11,8 +11,8 @@ def get_correct_device(device: str, worker_id: int) -> str: Get the correct device based on the given device string. Args: - device (str): The device string, could be cpu or cuda. - worker_id (int): The worker id to determine the cuda device. + device (str): The device string, could be cpu, cuda, or xpu. + worker_id (int): The worker id to determine the device. Returns: str: The correct device string with device id. @@ -21,6 +21,8 @@ def get_correct_device(device: str, worker_id: int) -> str: return "cpu" elif device.startswith("cuda"): return f"cuda:{worker_id}" + elif device.startswith("xpu"): + return f"xpu:{worker_id}" else: raise ValueError(f"Invalid device: {device}")