diff --git a/README.md b/README.md
index d73b1c6..369728d 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,17 @@
 python demo.py --model_path /path/to/checkpoint.pt \
     --image_folder /path/to/images/ --use_sdpa
 ```
+### Apple Silicon / MPS
+
+On Apple Silicon Macs, `--device auto` uses MPS when CUDA is unavailable.
+FlashInfer is CUDA-only, so the demo automatically switches to PyTorch SDPA
+when running on MPS or CPU.
+
+```bash
+python demo.py --model_path /path/to/checkpoint.pt \
+    --image_folder /path/to/images/ --device mps
+```
+
 ### Running on Limited GPU Memory
 
 If you run into out-of-memory issues, try one (or both) of the following:
diff --git a/demo.py b/demo.py
index a8c8636..0653ca6 100644
--- a/demo.py
+++ b/demo.py
@@ -21,12 +21,18 @@
 import argparse
 import glob
 import os
+import platform
 import time
+from contextlib import nullcontext
 
 # Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated
 # memory gap by letting the caching allocator grow segments on demand instead of
 # pre-reserving fixed-size blocks.
 os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+if platform.system() == "Darwin" and platform.machine() in {"arm64", "aarch64"}:
+    # Apple Silicon runs most inference on MPS. This lets unsupported kernels
+    # fall back to CPU instead of failing mid-run.
+    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
 
 import cv2
 import numpy as np
@@ -38,6 +44,10 @@
 from lingbot_map.utils.load_fn import load_and_preprocess_images
 
 
+def mps_is_available():
+    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
+
 # =============================================================================
 # Image loading
 # =============================================================================
@@ -104,8 +114,29 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
 # Model loading
 # =============================================================================
 
+def select_device(device_name="auto"):
+    """Select an inference device, preferring CUDA, then Apple Silicon MPS."""
+    if device_name == "auto":
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        if mps_is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+
+    device = torch.device(device_name)
+    if device.type == "cuda" and not torch.cuda.is_available():
+        raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is false.")
+    if device.type == "mps" and not mps_is_available():
+        raise RuntimeError("MPS was requested, but torch.backends.mps.is_available() is false.")
+    return device
+
+
 def load_model(args, device):
     """Load GCTStream model from checkpoint."""
+    if device.type != "cuda" and not args.use_sdpa:
+        print(f"Device {device.type} does not support FlashInfer; enabling SDPA attention.")
+        args.use_sdpa = True
+
     if getattr(args, "mode", "streaming") == "windowed":
         from lingbot_map.models.gct_stream_window import GCTStream
     else:
@@ -243,6 +274,8 @@
 
     # Model
     parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"],
+                        help="Inference device. auto prefers CUDA, then Apple Silicon MPS, then CPU.")
     parser.add_argument("--image_size", type=int, default=518)
     parser.add_argument("--patch_size", type=int, default=14)
 
@@ -295,7 +328,8 @@
     assert args.image_folder or args.video_path, \
         "Provide --image_folder or --video_path"
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = select_device(args.device)
+    print(f"Using device: {device}")
 
     # ── Load images & model ──────────────────────────────────────────────────
     t0 = time.time()
@@ -321,7 +355,7 @@
     print(f"Total load time: {time.time() - t0:.1f}s")
 
     # Pick inference dtype; autocast still runs for the ops that need fp32 (e.g. LayerNorm).
-    if torch.cuda.is_available():
+    if device.type == "cuda":
         dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
     else:
         dtype = torch.float32
@@ -361,8 +395,13 @@
     t0 = time.time()
 
     output_device = torch.device("cpu") if args.offload_to_cpu else None
+    # torch.amp.autocast("cuda", ...) is invalid on MPS/CPU; use a no-op context there.
+    autocast_context = (
+        torch.amp.autocast("cuda", dtype=dtype)
+        if device.type == "cuda"
+        else nullcontext()
+    )
 
-    with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
+    with torch.no_grad(), autocast_context:
         if args.mode == "streaming":
             predictions = model.inference_streaming(
                 images,
diff --git a/lingbot_map/layers/rope.py b/lingbot_map/layers/rope.py
index 7f44e31..48c2a8c 100644
--- a/lingbot_map/layers/rope.py
+++ b/lingbot_map/layers/rope.py
@@ -354,7 +354,13 @@ def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start:
         """
         # 步骤1:将预计算的频率移到目标设备,并分割成三个维度
-        self.freqs = self.freqs.to(device)
+        if device.type == "mps":
+            # MPS has no complex128 kernels; downcast to complex64 once, and
+            # cache on-device (as the original did) so later calls skip the copy.
+            self.freqs = self.freqs.to(dtype=torch.complex64).to(device)
+        else:
+            self.freqs = self.freqs.to(device)
+        freqs_tensor = self.freqs
 
         # 获取实际的维度分配
         if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
             t_dim, h_dim, w_dim = self.fhw_dim
@@ -364,7 +370,7 @@
             t_dim = self.attention_head_dim - h_dim - w_dim
 
         # 使用正确的split sizes(每个维度的一半)
-        freqs = self.freqs.split_with_sizes(
+        freqs = freqs_tensor.split_with_sizes(
             [
                 t_dim // 2,  # 时间维度
                 h_dim // 2,  # 高度维度
@@ -453,7 +459,8 @@
     """
     # 步骤1:reshape成 [..., head_dim//2, 2] 形式,最后一维表示(real, imag)
     # 例如:[b, h, seq, 64] -> [b, h, seq, 32, 2]
-    x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
+    rotary_dtype = torch.float32 if x.device.type == "mps" else torch.float64
+    x_reshaped = x.to(rotary_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
 
     # 步骤2:转换为复数表示 [b, h, seq, 32]
     # 每个元素是 real + imag*i