Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,17 @@ python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --use_sdpa
```

### Apple Silicon / MPS

On Apple Silicon Macs, `--device auto` uses MPS when CUDA is unavailable.
FlashInfer is CUDA-only, so the demo automatically switches to PyTorch SDPA
when running on MPS or CPU.

```bash
python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --device mps
```

### Running on Limited GPU Memory

If you run into out-of-memory issues, try one (or both) of the following:
Expand Down
45 changes: 42 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@
import argparse
import glob
import os
import platform
import time
from contextlib import nullcontext

# Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated
# memory gap by letting the caching allocator grow segments on demand instead of
# pre-reserving fixed-size blocks.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
if platform.system() == "Darwin" and platform.machine() in {"arm64", "aarch64"}:
# Apple Silicon runs most inference on MPS. This lets unsupported kernels
# fall back to CPU instead of failing mid-run.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import cv2
import numpy as np
Expand All @@ -38,6 +44,10 @@
from lingbot_map.utils.load_fn import load_and_preprocess_images


def mps_is_available():
    """Return True when the PyTorch MPS backend exists and reports available."""
    backend = getattr(torch.backends, "mps", None)
    return backend is not None and backend.is_available()


# =============================================================================
# Image loading
# =============================================================================
Expand Down Expand Up @@ -104,8 +114,29 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
# Model loading
# =============================================================================

def select_device(device_name="auto"):
    """Select an inference device, preferring CUDA, then Apple Silicon MPS.

    With ``device_name="auto"`` the first available backend wins in the order
    CUDA > MPS > CPU. An explicit device name is validated against backend
    availability and a ``RuntimeError`` is raised when it cannot be honored.
    """
    def _mps_ok():
        # Inline availability probe; older torch builds lack torch.backends.mps.
        return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    if device_name != "auto":
        # Explicit request: validate that the backend actually exists.
        device = torch.device(device_name)
        if device.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is false.")
        if device.type == "mps" and not _mps_ok():
            raise RuntimeError("MPS was requested, but torch.backends.mps.is_available() is false.")
        return device

    # Automatic selection, best backend first.
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("mps") if _mps_ok() else torch.device("cpu")


def load_model(args, device):
"""Load GCTStream model from checkpoint."""
if device.type != "cuda" and not args.use_sdpa:
print(f"Device {device.type} does not support FlashInfer; enabling SDPA attention.")
args.use_sdpa = True

if getattr(args, "mode", "streaming") == "windowed":
from lingbot_map.models.gct_stream_window import GCTStream
else:
Expand Down Expand Up @@ -243,6 +274,8 @@ def main():

# Model
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"],
help="Inference device. auto prefers CUDA, then Apple Silicon MPS, then CPU.")
parser.add_argument("--image_size", type=int, default=518)
parser.add_argument("--patch_size", type=int, default=14)

Expand Down Expand Up @@ -295,7 +328,8 @@ def main():
assert args.image_folder or args.video_path, \
"Provide --image_folder or --video_path"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = select_device(args.device)
print(f"Using device: {device}")

# ── Load images & model ──────────────────────────────────────────────────
t0 = time.time()
Expand All @@ -321,7 +355,7 @@ def main():
print(f"Total load time: {time.time() - t0:.1f}s")

# Pick inference dtype; autocast still runs for the ops that need fp32 (e.g. LayerNorm).
if torch.cuda.is_available():
if device.type == "cuda":
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
dtype = torch.float32
Expand Down Expand Up @@ -361,8 +395,13 @@ def main():
t0 = time.time()

output_device = torch.device("cpu") if args.offload_to_cpu else None
autocast_context = (
torch.amp.autocast("cuda", dtype=dtype)
if device.type == "cuda"
else nullcontext()
)

with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
with torch.no_grad(), autocast_context:
if args.mode == "streaming":
predictions = model.inference_streaming(
images,
Expand Down
10 changes: 7 additions & 3 deletions lingbot_map/layers/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,10 @@ def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start:
"""

# 步骤1:将预计算的频率移到目标设备,并分割成三个维度
self.freqs = self.freqs.to(device)
if device.type == "mps":
freqs_tensor = self.freqs.to(dtype=torch.complex64).to(device)
else:
freqs_tensor = self.freqs.to(device)
# 获取实际的维度分配
if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
t_dim, h_dim, w_dim = self.fhw_dim
Expand All @@ -364,7 +367,7 @@ def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start:
t_dim = self.attention_head_dim - h_dim - w_dim

# 使用正确的split sizes(每个维度的一半)
freqs = self.freqs.split_with_sizes(
freqs = freqs_tensor.split_with_sizes(
[
t_dim // 2, # 时间维度
h_dim // 2, # 高度维度
Expand Down Expand Up @@ -453,7 +456,8 @@ def apply_rotary_emb(x, freqs):
"""
# 步骤1:reshape成 [..., head_dim//2, 2] 形式,最后一维表示(real, imag)
# 例如:[b, h, seq, 64] -> [b, h, seq, 32, 2]
x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
rotary_dtype = torch.float32 if x.device.type == "mps" else torch.float64
x_reshaped = x.to(rotary_dtype).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)

# 步骤2:转换为复数表示 [b, h, seq, 32]
# 每个元素是 real + imag*i
Expand Down