feat: auto-download Torch checkpoint (Model) from HuggingFace on first run #148
blackandredbot wants to merge 6 commits into nikopueringer:main
Conversation
Add _ensure_torch_checkpoint() to backend.py that lazily downloads the CorridorKey .pth checkpoint via hf_hub_download() when no local file is found. The download is triggered transparently from _discover_checkpoint(), so all existing CLI commands (wizard, run-inference, generate-alphas) benefit without modification. MLX checkpoints remain manual.

- Add HF_REPO_ID and HF_CHECKPOINT_FILENAME constants
- Wrap network errors in RuntimeError with HF URL and connectivity hint
- Catch ENOSPC from shutil.copy2 with ~300 MB size hint
- Update README to reflect automatic download (no manual step needed)
- Add 6 unit tests covering happy path, skip-when-present, MLX exclusion, error wrapping, disk space, and logging
- Add 4 Hypothesis property-based tests validating correctness properties
```python
# Update HF_REPO_ID and HF_CHECKPOINT_FILENAME if a new model version is released.
HF_REPO_ID = "nikopueringer/CorridorKey_v1.0"
HF_CHECKPOINT_FILENAME = "CorridorKey.pth"
```
📌 Note for future maintainers: These two constants are the single source of truth for which model gets auto-downloaded. If a new model version is released (e.g. CorridorKey_v2.0), update HF_REPO_ID and HF_CHECKPOINT_FILENAME here — everything else (download logic, error messages, tests) derives from them.
Replace duplicated glob + validation logic in _get_engine() with a call to _discover_checkpoint(TORCH_EXT) from backend.py. This ensures the service layer also benefits from auto-download and keeps checkpoint discovery in a single place. Removes now-unused glob import.
There are quite a few new lines, plus new tests. I would love to get a second person to review this before committing.
I am reviewing this PR.
Thank you @HYP3R00T -- I'll push a fix for the conflicts soon.
Please tag me @blackandredbot once the conflicts are addressed, I will continue my review after that. Thanks. |
- pyproject.toml: adopt upstream dev dependency ordering
- uv.lock: take upstream lockfile (macOS drift should not be committed)
@HYP3R00T -- conflicts resolved
Code Review — Community Contribution

Hey @blackandredbot, this is a really nice quality-of-life improvement — automatically downloading the checkpoint on first run removes a common friction point for new users. The implementation is clean. A few observations:
Move 'from huggingface_hub import hf_hub_download' from module-level in backend.py into _ensure_torch_checkpoint(), so importing the backend no longer requires the package when the checkpoint already exists. Update mock patch targets in test_backend.py and test_pbt_auto_download.py to match the new import location. Addresses review feedback from shezmic on PR nikopueringer#148.
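The general pattern behind that change, shown here with a stdlib module standing in for huggingface_hub: when an import moves inside a function, tests must patch the name on the *source* module, not on the importing module, because the name is resolved at call time.

```python
from unittest import mock

def load_config():
    # Deferred import: users of this module only need the dependency when
    # this function actually runs. Here, json stands in for
    # huggingface_hub.hf_hub_download in backend.py.
    import json
    return json.loads('{"auto_download": true}')

# Because the function looks the name up on the source module at call
# time, tests patch "json.loads" rather than "<this_module>.loads".
with mock.patch("json.loads", return_value={"auto_download": False}):
    assert load_config() == {"auto_download": False}

assert load_config() == {"auto_download": True}
```

This is why the mock patch targets in test_backend.py and test_pbt_auto_download.py had to be updated along with the import move.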
Thanks for the thorough review @shezmic — really appreciate you taking the time. Great catch on the top-level import. Addressed in the latest commit; 28 tests pass, ruff clean.
What does this change?
Adds automatic downloading of the CorridorKey Torch checkpoint (~300 MB) from HuggingFace on first run. When no `.pth` file exists in `CorridorKeyModule/checkpoints/`, the engine fetches it via `hf_hub_download()` and copies it to the expected location. This is triggered transparently inside `_discover_checkpoint()`, so all existing CLI commands (wizard, run-inference, generate-alphas) benefit without modification.

- New `_ensure_torch_checkpoint()` function in `backend.py`
- Network errors wrapped in `RuntimeError` with HF URL and connectivity hint
- `.safetensors` checkpoints remain manual (Torch-only gating)

How was it tested?
- `tests/test_backend.py`: 6 unit tests covering happy path, skip-when-present, MLX exclusion, network error wrapping, disk space error, and logging
- `tests/test_pbt_auto_download.py`: 4 Hypothesis property-based tests validating correctness properties across randomized inputs (100 examples each)

Checklist

- [x] `uv run pytest` passes
- [x] `uv run ruff check` passes
- [x] `uv run ruff format --check` passes
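For reference, the happy-path and skip-when-present cases can be exercised without any network access by injecting a fake downloader. This is a self-contained illustrative sketch, not the PR's actual test code; `ensure_checkpoint` is a hypothetical stand-in for `backend._ensure_torch_checkpoint`:

```python
import shutil
import tempfile
from pathlib import Path
from unittest import mock

def ensure_checkpoint(checkpoints_dir: Path, downloader) -> Path:
    """Stand-in for the real helper: download to cache, copy into place."""
    target = checkpoints_dir / "CorridorKey.pth"
    if target.exists():          # skip-when-present path
        return target
    cached = Path(downloader())  # would be hf_hub_download(...) for real
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(cached, target)
    return target

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    cached = tmp / "cache" / "CorridorKey.pth"
    cached.parent.mkdir()
    cached.write_bytes(b"fake-weights")
    dl = mock.Mock(return_value=str(cached))

    # Happy path: downloader invoked once, file lands in place.
    out = ensure_checkpoint(tmp / "checkpoints", dl)
    assert out.read_bytes() == b"fake-weights"
    assert dl.call_count == 1

    # Skip-when-present: a second call must not download again.
    ensure_checkpoint(tmp / "checkpoints", dl)
    assert dl.call_count == 1
```

The same injection trick covers the error cases: a downloader that raises simulates a network failure, and a `copy2` patched to raise `OSError(errno.ENOSPC, ...)` simulates a full disk.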