Replace device-specific CUDA/XPU graph APIs with unified torch.accelerator.Graph #9

Draft
Copilot wants to merge 2 commits into main from copilot/replace-graph-apis-with-unified-api

Copilot AI commented Jun 21, 2026

PyTorch now provides torch.accelerator.Graph as a device-agnostic replacement for torch.cuda.CUDAGraph/torch.xpu.XPUGraph. This PR migrates all graph capture/replay code to the new API, eliminating the XPU monkey-patching shim.

Key API difference

# Old pattern (CUDA-specific)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool, stream=stream):
    output = model(inputs)

# New pattern (device-agnostic)
graph = torch.accelerator.Graph(pool=pool, capture_error_mode="global")
with torch.cuda.stream(stream):
    graph.capture_begin()
    output = model(inputs)
    graph.capture_end()

Note: pool and capture_error_mode move from the capture_begin()/context-manager arguments to the constructor. capture_error_mode="global" is passed explicitly to preserve the existing CUDA behavior (the old default was "global"; the new default is "default").
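Because the new pattern calls capture_begin()/capture_end() by hand, an exception inside the captured region would skip capture_end(). A minimal sketch of a wrapper that restores the context-manager bracketing is shown below. `FakeGraph` and `capturing` are hypothetical names used only for illustration; `FakeGraph` stands in for a graph object so the sketch runs without PyTorch, and whether the real API expects capture_end() after a failed capture is an assumption.

```python
from contextlib import contextmanager


class FakeGraph:
    """Hypothetical stand-in that only records calls; not a PyTorch class."""

    def __init__(self, pool=None, capture_error_mode="default"):
        # Mirrors the constructor arguments described in this PR.
        self.pool = pool
        self.capture_error_mode = capture_error_mode
        self.calls = []

    def capture_begin(self):
        self.calls.append("capture_begin")

    def capture_end(self):
        self.calls.append("capture_end")


@contextmanager
def capturing(graph):
    """Bracket a block with graph.capture_begin() / graph.capture_end()."""
    graph.capture_begin()
    try:
        yield graph
    finally:
        # Assumption: the capture should be closed even if the body raises.
        graph.capture_end()


if __name__ == "__main__":
    graph = FakeGraph(pool=None, capture_error_mode="global")
    with capturing(graph):
        graph.calls.append("forward")  # placeholder for model(inputs)
    print(graph.calls)
```

In real code the stream context from the snippet above would still wrap the `with capturing(graph):` block.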

Changes

  • vllm/compilation/cuda_graph.py — replace CUDAGraph() + torch.cuda.graph() CM with accelerator.Graph; update CUDAGraphEntry.cudagraph type annotation to Any
  • vllm/compilation/breakable_cudagraph.py — replace CUDAGraph() + conditional capture_begin(pool=...) with accelerator.Graph(pool=...) + unconditional capture_begin(); update type annotation
  • vllm/v1/worker/gpu/cudagraph_utils.py — same refactoring; update graphs dict type annotation
  • vllm/v1/worker/encoder_cudagraph.py — same refactoring; update BudgetGraphMetadata.graph type annotation
  • vllm/v1/worker/gpu_ubatch_wrapper.py — same refactoring; update CUDAGraphMetaData.cudagraph type annotation
  • vllm/v1/worker/xpu_model_runner.py — remove the torch.cuda.graph/CUDAGraph/graph_pool_handle monkey-patching lines from _torch_cuda_wrapper() (and the now-unused supports_xpu_graph import)
  • tests/v1/cudagraph/test_cudagraph_dispatch.py — update mock targets from torch.cuda.graph to torch.accelerator.Graph

Replace torch.cuda.CUDAGraph/torch.cuda.graph and XPU monkey-patching
with the unified torch.accelerator.Graph API:

- vllm/compilation/cuda_graph.py: Use torch.accelerator.Graph with
  pool+capture_error_mode at construction, manual capture_begin/end
- vllm/compilation/breakable_cudagraph.py: Same, simplify _begin_segment
- vllm/v1/worker/gpu/cudagraph_utils.py: Same refactoring
- vllm/v1/worker/encoder_cudagraph.py: Same refactoring
- vllm/v1/worker/gpu_ubatch_wrapper.py: Same refactoring
- vllm/v1/worker/xpu_model_runner.py: Remove graph monkey-patching
- tests/v1/cudagraph/test_cudagraph_dispatch.py: Update mocks

Co-authored-by: GitHub Copilot
Copilot AI changed the title from "[WIP] Replace device-specific graph APIs with torch.accelerator.Graph" to "Replace device-specific CUDA/XPU graph APIs with unified torch.accelerator.Graph" on Jun 21, 2026
Copilot AI requested a review from zhenwei-intel June 21, 2026 12:58