Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/
```

You can also load a checkpoint directly from Hugging Face. The checkpoint is
downloaded once and reused from the local Hugging Face cache on later runs.

```bash
python demo.py -hf robbyant/lingbot-map \
--image_folder /path/to/images/
```

If a repository contains multiple checkpoint files, pass `--hf_file <filename>`
to select the one to load.

### Streaming Inference from Video

```bash
Expand Down
100 changes: 98 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import argparse
import fnmatch
import glob
import os
import time
Expand Down Expand Up @@ -104,6 +105,81 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
# Model loading
# =============================================================================

_CHECKPOINT_PATTERNS = (
"*.pt",
"*.pth",
"*.bin",
"*.safetensors",
)


def _is_checkpoint_file(path):
return any(fnmatch.fnmatch(os.path.basename(path), pattern) for pattern in _CHECKPOINT_PATTERNS)


def _choose_hf_checkpoint(repo_id, files, requested_file=None):
candidates = [path for path in files if _is_checkpoint_file(path)]
if requested_file is not None:
if requested_file in files:
return requested_file
matches = [path for path in candidates if os.path.basename(path) == requested_file]
if len(matches) == 1:
return matches[0]
raise ValueError(
f"Could not find checkpoint file '{requested_file}' in {repo_id}. "
f"Available checkpoints: {', '.join(candidates) or '(none)'}"
)

if len(candidates) == 1:
return candidates[0]

repo_name = repo_id.rstrip("/").split("/")[-1]
preferred_names = (
f"{repo_name}.pt",
f"{repo_name}.pth",
"model.pt",
"model.pth",
"checkpoint.pt",
"checkpoint.pth",
"pytorch_model.bin",
"model.safetensors",
)
for name in preferred_names:
matches = [path for path in candidates if os.path.basename(path) == name]
if len(matches) == 1:
return matches[0]

raise ValueError(
f"Found {len(candidates)} checkpoint files in {repo_id}; pass --hf_file "
f"to choose one. Available checkpoints: {', '.join(candidates) or '(none)'}"
)


def resolve_model_path(args):
    """Return a local checkpoint path, downloading from Hugging Face if needed.

    If ``args.model_path`` is set, it is returned unchanged. Otherwise the
    repo named by ``args.hf`` is listed, a checkpoint file is selected via
    ``_choose_hf_checkpoint`` (honoring ``args.hf_file``), and the file is
    fetched into the Hugging Face cache (re-used on later runs).

    Args:
        args: Parsed CLI namespace with ``model_path``, ``hf``, ``hf_file``,
            ``hf_revision``, and ``hf_cache_dir`` attributes.

    Returns:
        A filesystem path to the checkpoint file.
    """
    if args.model_path is not None:
        return args.model_path

    # Imported lazily so local --model_path runs don't need huggingface_hub.
    from huggingface_hub import HfApi, hf_hub_download

    repo_id = args.hf
    print(f"Resolving Hugging Face model: {repo_id}")
    repo_files = HfApi().list_repo_files(
        repo_id=repo_id,
        repo_type="model",
        revision=args.hf_revision,
    )
    checkpoint_name = _choose_hf_checkpoint(
        repo_id, repo_files, requested_file=args.hf_file
    )
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=checkpoint_name,
        repo_type="model",
        revision=args.hf_revision,
        cache_dir=args.hf_cache_dir,
    )
    print(f"Using cached Hugging Face checkpoint: {local_path}")
    return local_path


def load_model(args, device):
"""Load GCTStream model from checkpoint."""
if getattr(args, "mode", "streaming") == "windowed":
Expand All @@ -127,7 +203,11 @@ def load_model(args, device):

if args.model_path:
print(f"Loading checkpoint: {args.model_path}")
ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
if args.model_path.endswith(".safetensors"):
from safetensors.torch import load_file
ckpt = load_file(args.model_path, device="cpu")
else:
ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
state_dict = ckpt.get("model", ckpt)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
Expand Down Expand Up @@ -242,7 +322,16 @@ def main():
parser.add_argument("--stride", type=int, default=1)

# Model
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--model_path", type=str, default=None,
help="Local checkpoint path. Use this or -hf/--hf.")
parser.add_argument("-hf", "--hf", type=str, default=None,
help="Hugging Face model repo id, for example robbyant/lingbot-map")
parser.add_argument("--hf_file", type=str, default=None,
help="Checkpoint filename inside the Hugging Face repo if there are multiple")
parser.add_argument("--hf_revision", type=str, default=None,
help="Optional Hugging Face branch, tag, or commit")
parser.add_argument("--hf_cache_dir", type=str, default=None,
help="Optional Hugging Face cache directory")
parser.add_argument("--image_size", type=int, default=518)
parser.add_argument("--patch_size", type=int, default=14)

Expand Down Expand Up @@ -294,8 +383,15 @@ def main():
args = parser.parse_args()
assert args.image_folder or args.video_path, \
"Provide --image_folder or --video_path"
model_source_count = sum([
args.model_path is not None,
args.hf is not None,
])
assert model_source_count == 1, \
"Provide exactly one model source: --model_path or -hf/--hf"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.model_path = resolve_model_path(args)

# ── Load images & model ──────────────────────────────────────────────────
t0 = time.time()
Expand Down