Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions lmcache/v1/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1794,8 +1794,17 @@ def _Create_memory_allocator(
buffer.data_ptr(), config.nixl_buffer_size, 0
)
else:
logger.info(f"Setting cuda device to {corrected_device} ")
torch.cuda.set_device(corrected_device)
logger.info(f"Setting device to {corrected_device}")
# Set device based on device type
if corrected_device.startswith("cuda"):
torch.cuda.set_device(corrected_device)
elif corrected_device.startswith("xpu"):
if not hasattr(torch, "xpu"):
raise RuntimeError(
"XPU device is not available. Please ensure PyTorch "
"is built with XPU support."
)
torch.xpu.set_device(corrected_device)

return PagedTensorMemoryAllocator(
buffer,
Expand Down
11 changes: 10 additions & 1 deletion lmcache/v1/storage_backend/nixl_storage_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,16 @@ def initialize_allocator(
base_buffer, self.buffer = _allocate_gpu_memory(
config.nixl_buffer_size, corrected_device
)
torch.cuda.set_device(corrected_device)
# Set device based on device type
if corrected_device.startswith("cuda"):
torch.cuda.set_device(corrected_device)
elif corrected_device.startswith("xpu"):
if not hasattr(torch, "xpu"):
raise RuntimeError(
"XPU device is not available. Please ensure PyTorch is "
"built with XPU support."
)
torch.xpu.set_device(corrected_device)
self.base_buffer = base_buffer # Prevents early GC of the aligned tensor.
self.free_pinned_buffer = False

Expand Down
1 change: 1 addition & 0 deletions lmcache/v1/storage_backend/p2p_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __init__(
peer_lookup_url=self.peer_lookup_url,
backends=config.nixl_backends,
event_loop=loop,
device="cpu",
)

self.running = asyncio.Event()
Expand Down
15 changes: 13 additions & 2 deletions lmcache/v1/storage_backend/pd_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __init__(
tp_rank=self.tp_rank,
peer_init_url=peer_init_url,
backends=config.nixl_backends,
device=self.pd_config.buffer_device,
)

if self.pd_config.role == "sender":
Expand Down Expand Up @@ -217,8 +218,18 @@ def initialize_allocator(
config.pd_buffer_device,
metadata.worker_id,
)
logger.info(f"Setting cuda device to {corrected_device} ")
torch.cuda.set_device(corrected_device)
logger.info(f"Setting device to {corrected_device}")

# Set device based on device type
if corrected_device.startswith("cuda"):
torch.cuda.set_device(corrected_device)
elif corrected_device.startswith("xpu"):
if not hasattr(torch, "xpu"):
raise RuntimeError(
"XPU device is not available. Please ensure PyTorch is built "
"with XPU support."
)
torch.xpu.set_device(corrected_device)

paged_mem_allocator = PagedCpuGpuMemoryAllocator()
paged_mem_allocator.init_gpu_memory_allocator(
Expand Down
30 changes: 26 additions & 4 deletions lmcache/v1/transfer_channel/nixl_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(
else:
backends = ["UCX"]

# Extract device from kwargs (optional, defaults to "cuda" for
# backwards compatibility)
device = kwargs.get("device", "cuda")

self.role = kwargs["role"]

self.nixl_wrapper = NixlAgentWrapper(
Expand All @@ -89,6 +93,7 @@ def __init__(
page_size=kwargs["align_bytes"],
tp_rank=kwargs["tp_rank"],
backends=backends,
device=device,
)
self.nixl_agent = self.nixl_wrapper.agent

Expand Down Expand Up @@ -579,6 +584,7 @@ def __init__(
page_size: int,
tp_rank: int,
backends: list[str],
device: str = "cuda",
):
"""
Initialize the NIXL agent.
Expand All @@ -590,6 +596,8 @@ def __init__(
the lmcache memory allocator.
tp_rank (int): The tensor parallel rank.
backends (list[str]): The list of backends to use.
device (str): The device type string (e.g., "cuda:0", "xpu:0").
Defaults to "cuda" for backward compatibility.

Returns:
NixlWrapper: The NIXL agent.
Expand All @@ -608,6 +616,21 @@ def __init__(
if backends is None:
backends = ["UCX"]

# Determine memory type based on device string
# device can be "cuda", "cuda:0", "xpu", "xpu:0", "cpu", etc.
if device.startswith("cuda"):
mem_type = "cuda"
elif device.startswith("xpu"):
mem_type = "xpu"
elif device.startswith("cpu"):
mem_type = "cpu"
else:
# Raise error for unsupported device types
raise ValueError(
f"Unsupported device type: {device}. "
"Supported device types are: cuda, xpu, cpu"
)

# Create a NIXL agent
nixl_agent = NixlAgent(
str(uuid.uuid4()),
Expand All @@ -618,17 +641,16 @@ def __init__(
# The four fields are (base_addr, length, dev_id, meta_info)
# https://github.com/ai-dynamo/nixl/blob/main/src/api/cpp/nixl_descriptors.h#L152
memory_desc = [(buffer_ptr, buffer_size, tp_rank, "")]
# TODO(Jiayi): remove hardcode `mem_type`
reg_descs = nixl_agent.get_reg_descs(memory_desc, mem_type="cuda")
reg_descs = nixl_agent.get_reg_descs(memory_desc, mem_type=mem_type)
nixl_agent.register_memory(reg_descs)

# Create xfer handlers
xfer_desc = []
for base_addr in range(buffer_ptr, buffer_ptr + buffer_size, page_size):
xfer_desc.append((base_addr, page_size, tp_rank))

xfer_descs = nixl_agent.get_xfer_descs(xfer_desc, mem_type="cuda")
xfer_handler = nixl_agent.prep_xfer_dlist("", xfer_descs, mem_type="cuda")
xfer_descs = nixl_agent.get_xfer_descs(xfer_desc, mem_type=mem_type)
xfer_handler = nixl_agent.prep_xfer_dlist("", xfer_descs, mem_type=mem_type)

self.agent = nixl_agent
self.reg_descs = reg_descs
Expand Down
6 changes: 4 additions & 2 deletions lmcache/v1/transfer_channel/transfer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def get_correct_device(device: str, worker_id: int) -> str:
Get the correct device based on the given device string.

Args:
device (str): The device string, could be cpu or cuda.
worker_id (int): The worker id to determine the cuda device.
device (str): The device string, could be cpu, cuda, or xpu.
worker_id (int): The worker id to determine the device.

Returns:
str: The correct device string with device id.
Expand All @@ -21,6 +21,8 @@ def get_correct_device(device: str, worker_id: int) -> str:
return "cpu"
elif device.startswith("cuda"):
return f"cuda:{worker_id}"
elif device.startswith("xpu"):
return f"xpu:{worker_id}"
else:
raise ValueError(f"Invalid device: {device}")

Expand Down