diff --git a/README.md b/README.md
index d73b1c6..a2fd875 100644
--- a/README.md
+++ b/README.md
@@ -118,6 +118,17 @@
 python demo.py --model_path /path/to/checkpoint.pt \
     --image_folder /path/to/images/
 ```
+You can also load a checkpoint directly from Hugging Face. The checkpoint is
+downloaded once and reused from the local Hugging Face cache on later runs.
+
+```bash
+python demo.py -hf robbyant/lingbot-map \
+    --image_folder /path/to/images/
+```
+
+If a repository contains multiple checkpoint files, pass `--hf_file <filename>`
+to select the one to load.
+
 ### Streaming Inference from Video
 
 ```bash
diff --git a/demo.py b/demo.py
index a8c8636..5552b9e 100644
--- a/demo.py
+++ b/demo.py
@@ -19,6 +19,7 @@
 """
 
 import argparse
+import fnmatch
 import glob
 import os
 import time
@@ -104,6 +105,81 @@
 # Model loading
 # =============================================================================
 
+_CHECKPOINT_PATTERNS = (
+    "*.pt",
+    "*.pth",
+    "*.bin",
+    "*.safetensors",
+)
+
+
+def _is_checkpoint_file(path):
+    return any(fnmatch.fnmatch(os.path.basename(path), pattern) for pattern in _CHECKPOINT_PATTERNS)
+
+
+def _choose_hf_checkpoint(repo_id, files, requested_file=None):
+    candidates = [path for path in files if _is_checkpoint_file(path)]
+    if requested_file is not None:
+        if requested_file in files:
+            return requested_file
+        matches = [path for path in candidates if os.path.basename(path) == requested_file]
+        if len(matches) == 1:
+            return matches[0]
+        raise ValueError(
+            f"Could not find checkpoint file '{requested_file}' in {repo_id}. "
+            f"Available checkpoints: {', '.join(candidates) or '(none)'}"
+        )
+
+    if len(candidates) == 1:
+        return candidates[0]
+
+    repo_name = repo_id.rstrip("/").split("/")[-1]
+    preferred_names = (
+        f"{repo_name}.pt",
+        f"{repo_name}.pth",
+        "model.pt",
+        "model.pth",
+        "checkpoint.pt",
+        "checkpoint.pth",
+        "pytorch_model.bin",
+        "model.safetensors",
+    )
+    for name in preferred_names:
+        matches = [path for path in candidates if os.path.basename(path) == name]
+        if len(matches) == 1:
+            return matches[0]
+
+    raise ValueError(
+        f"Found {len(candidates)} checkpoint files in {repo_id}; pass --hf_file "
+        f"to choose one. Available checkpoints: {', '.join(candidates) or '(none)'}"
+    )
+
+
+def resolve_model_path(args):
+    if args.model_path is not None:
+        return args.model_path
+
+    from huggingface_hub import HfApi, hf_hub_download
+
+    print(f"Resolving Hugging Face model: {args.hf}")
+    api = HfApi()
+    files = api.list_repo_files(
+        repo_id=args.hf,
+        repo_type="model",
+        revision=args.hf_revision,
+    )
+    filename = _choose_hf_checkpoint(args.hf, files, requested_file=args.hf_file)
+    path = hf_hub_download(
+        repo_id=args.hf,
+        filename=filename,
+        repo_type="model",
+        revision=args.hf_revision,
+        cache_dir=args.hf_cache_dir,
+    )
+    print(f"Using cached Hugging Face checkpoint: {path}")
+    return path
+
+
 def load_model(args, device):
     """Load GCTStream model from checkpoint."""
     if getattr(args, "mode", "streaming") == "windowed":
@@ -127,7 +203,11 @@
 
     if args.model_path:
         print(f"Loading checkpoint: {args.model_path}")
-        ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
+        if args.model_path.endswith(".safetensors"):
+            from safetensors.torch import load_file
+            ckpt = load_file(args.model_path, device="cpu")
+        else:
+            ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
         state_dict = ckpt.get("model", ckpt)
         missing, unexpected = model.load_state_dict(state_dict, strict=False)
         if missing:
@@ -242,7 +322,16 @@ def main():
     parser.add_argument("--stride", type=int, default=1)
 
     # Model
-    parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--model_path", type=str, default=None,
+                        help="Local checkpoint path. Use this or -hf/--hf.")
+    parser.add_argument("-hf", "--hf", type=str, default=None,
+                        help="Hugging Face model repo id, for example robbyant/lingbot-map")
+    parser.add_argument("--hf_file", type=str, default=None,
+                        help="Checkpoint filename inside the Hugging Face repo if there are multiple")
+    parser.add_argument("--hf_revision", type=str, default=None,
+                        help="Optional Hugging Face branch, tag, or commit")
+    parser.add_argument("--hf_cache_dir", type=str, default=None,
+                        help="Optional Hugging Face cache directory")
     parser.add_argument("--image_size", type=int, default=518)
     parser.add_argument("--patch_size", type=int, default=14)
 
@@ -294,8 +383,15 @@
     args = parser.parse_args()
     assert args.image_folder or args.video_path, \
         "Provide --image_folder or --video_path"
+    model_source_count = sum([
+        args.model_path is not None,
+        args.hf is not None,
+    ])
+    assert model_source_count == 1, \
+        "Provide exactly one model source: --model_path or -hf/--hf"
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args.model_path = resolve_model_path(args)
 
     # ── Load images & model ──────────────────────────────────────────────────
     t0 = time.time()