
feat: auto-download Torch checkpoint (Model) from HuggingFace on first run#148

Open
blackandredbot wants to merge 6 commits into nikopueringer:main from blackandredbot:feat/auto-model-download

Conversation

@blackandredbot
Contributor

What does this change?

Adds automatic downloading of the CorridorKey Torch checkpoint (~300 MB) from HuggingFace on first run. When no .pth file exists in CorridorKeyModule/checkpoints/, the engine fetches it via hf_hub_download() and copies it to the expected location. This is triggered transparently inside _discover_checkpoint(), so all existing CLI commands (wizard, run-inference, generate-alphas) benefit without modification.

  • New _ensure_torch_checkpoint() function in backend.py
  • Network errors wrapped in RuntimeError with HF URL and connectivity hint
  • Disk space errors caught with ~300 MB size hint
  • MLX .safetensors checkpoints remain manual (Torch-only gating)
  • README updated to reflect automatic download

How was it tested?

  • 6 new unit tests in tests/test_backend.py covering happy path, skip-when-present, MLX exclusion, network error wrapping, disk space error, and logging
  • 4 Hypothesis property-based tests in tests/test_pbt_auto_download.py validating correctness properties across randomized inputs (100 examples each)
  • All 28 tests pass locally
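
One of the Hypothesis properties could look like the following sketch (the helper and test names here are illustrative stand-ins, not the PR's actual code in test_pbt_auto_download.py):

```python
import tempfile
from pathlib import Path

from hypothesis import given, strategies as st


def discover_or_download(checkpoints_dir: Path, downloader):
    # Stand-in for the PR's _discover_checkpoint()/_ensure_torch_checkpoint() pair.
    found = sorted(checkpoints_dir.glob("*.pth"))
    return found[0] if found else downloader()


@given(st.binary(min_size=1, max_size=64))
def test_existing_checkpoint_is_never_redownloaded(payload):
    # Property: whatever bytes are already on disk, the downloader is not called
    # and the existing file is returned unchanged.
    with tempfile.TemporaryDirectory() as d:
        ckpts = Path(d)
        (ckpts / "CorridorKey.pth").write_bytes(payload)

        def never():
            raise AssertionError("downloader must not be called")

        assert discover_or_download(ckpts, never).read_bytes() == payload
```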

Checklist

  • uv run pytest passes
  • uv run ruff check passes
  • uv run ruff format --check passes

Michael Foley added 3 commits March 12, 2026 12:01
Add _ensure_torch_checkpoint() to backend.py that lazily downloads the
CorridorKey .pth checkpoint via hf_hub_download() when no local file is
found. The download is triggered transparently from _discover_checkpoint()
so all existing CLI commands (wizard, run-inference, generate-alphas)
benefit without modification. MLX checkpoints remain manual.

- Add HF_REPO_ID and HF_CHECKPOINT_FILENAME constants
- Wrap network errors in RuntimeError with HF URL and connectivity hint
- Catch ENOSPC from shutil.copy2 with ~300 MB size hint
- Update README to reflect automatic download (no manual step needed)
- Add 6 unit tests covering happy path, skip-when-present, MLX exclusion,
  error wrapping, disk space, and logging
- Add 4 Hypothesis property-based tests validating correctness properties
Comment on lines 26 to +29

# Update HF_REPO_ID and HF_CHECKPOINT_FILENAME if a new model version is released.
HF_REPO_ID = "nikopueringer/CorridorKey_v1.0"
HF_CHECKPOINT_FILENAME = "CorridorKey.pth"
Contributor Author


📌 Note for future maintainers: These two constants are the single source of truth for which model gets auto-downloaded. If a new model version is released (e.g. CorridorKey_v2.0), update HF_REPO_ID and HF_CHECKPOINT_FILENAME here — everything else (download logic, error messages, tests) derives from them.

Replace duplicated glob + validation logic in _get_engine() with a call
to _discover_checkpoint(TORCH_EXT) from backend.py. This ensures the
service layer also benefits from auto-download and keeps checkpoint
discovery in a single place.

Removes now-unused glob import.
@blackandredbot blackandredbot changed the title feat: auto-download Torch checkpoint from HuggingFace on first run feat: auto-download Torch checkpoint (Model) from HuggingFace on first run Mar 12, 2026
@nikopueringer
Owner

There are quite a few new lines, plus new tests. I would love to get a second person to review this before committing.

@HYP3R00T
Contributor

I am reviewing this PR

@blackandredbot
Contributor Author

Thank you @HYP3R00T -- I'll push a fix for the conflicts soon

@HYP3R00T
Contributor

Please tag me @blackandredbot once the conflicts are addressed, I will continue my review after that. Thanks.

- pyproject.toml: adopt upstream dev dependency ordering
- uv.lock: take upstream lockfile (macOS drift should not be committed)
@blackandredbot
Contributor Author

@HYP3R00T -- conflicts resolved

@shezmic
Contributor

shezmic commented Mar 14, 2026

Code Review — Community Contribution

Hey @blackandredbot, this is a really nice quality-of-life improvement — automatically downloading the checkpoint on first run removes a common friction point for new users. The implementation is clean. A few observations:


1. huggingface_hub import at module level

from huggingface_hub import hf_hub_download

This is imported at the top of backend.py, which means it becomes a hard requirement at import time. If a user hasn't installed the huggingface_hub package (it's pulled in transitively by diffusers/transformers, but not everyone installs those extras), importing CorridorKeyModule.backend will fail with ModuleNotFoundError — even if they already have the checkpoint file and don't need the download.

Consider a lazy import inside _ensure_torch_checkpoint():

def _ensure_torch_checkpoint() -> Path:
    from huggingface_hub import hf_hub_download
    ...

This way the import only triggers when the download is actually needed.

2. Service.py cleanup is a nice bonus

Replacing the duplicated glob logic in _get_engine() with a call to the centralized _discover_checkpoint(TORCH_EXT) is a good deduplication. Removes the glob_module import too. This is the kind of cleanup that prevents drift between the two code paths.

3. shutil.copy2 from HF cache

The download flow is: hf_hub_download() → cached file → shutil.copy2() → CorridorKeyModule/checkpoints/. This means the checkpoint exists in two places (HF cache + checkpoints dir). For a ~300 MB file this is acceptable, but worth noting in case disk-constrained users wonder where their space went. A symlink could save space but would be less portable (Windows compat). Current approach is the safer choice.

4. Test coverage is solid

6 unit tests + 4 Hypothesis property-based tests covering happy path, skip-when-present, MLX exclusion, network errors, disk space errors, and logging. Good coverage of failure modes.

5. README update is clear

The installation step is appropriately simplified — makes it clear the download is automatic while still linking to the HuggingFace source.


Overall this is clean and well-tested. The top-level import is the main thing I'd address to avoid breaking users who don't have huggingface_hub installed but already have the checkpoint. Everything else looks good.

Move 'from huggingface_hub import hf_hub_download' from module-level in
backend.py into _ensure_torch_checkpoint(), so importing the backend no
longer requires the package when the checkpoint already exists.

Update mock patch targets in test_backend.py and
test_pbt_auto_download.py to match the new import location.

Addresses review feedback from shezmic on PR nikopueringer#148.
@blackandredbot
Contributor Author

Thanks for the thorough review @shezmic — really appreciate you taking the time.

Great catch on the top-level import. Addressed in e09e133: moved from huggingface_hub import hf_hub_download into _ensure_torch_checkpoint() so it only triggers when a download is actually needed. Users with the checkpoint already in place won't hit a ModuleNotFoundError if huggingface_hub isn't installed. Updated the mock patch targets in both test files to match.

28 tests pass, ruff clean.
