diff --git a/pyro-predictor/pyro_predictor/vision.py b/pyro-predictor/pyro_predictor/vision.py index 33166cfb..0e30bcb4 100644 --- a/pyro-predictor/pyro_predictor/vision.py +++ b/pyro-predictor/pyro_predictor/vision.py @@ -6,6 +6,7 @@ import logging import pathlib import platform +import shutil import tarfile from typing import Tuple @@ -21,6 +22,8 @@ MODEL_REPO_ID = "pyronear/yolo11s_plucky-pelican_v7.1.0" MODEL_NAME = "ncnn_cpu.tar.gz" +MODEL_SLUG = MODEL_REPO_ID.split("/", 1)[1] # e.g. "yolo11s_plucky-pelican_v7.1.0" +MODEL_CACHE_SUBDIR = "models" logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True) logger = logging.getLogger(__name__) @@ -70,18 +73,37 @@ def __init__( else: raise ValueError("Unsupported format: should be 'ncnn' or 'onnx'") - model_path = str(pathlib.Path(model_folder) / model) + # Namespace cached weights by model slug so a MODEL_REPO_ID bump lands in a + # fresh path and old weights can be purged. + cache_root = pathlib.Path(model_folder) / MODEL_CACHE_SUBDIR + model_cache = cache_root / MODEL_SLUG + model_path = str(model_cache / model) if not pathlib.Path(model_path).is_file(): + # Drop previous slugs and legacy flat layout to reclaim disk on edge devices. + if cache_root.is_dir(): + for entry in cache_root.iterdir(): + if entry.name != MODEL_SLUG: + shutil.rmtree(entry, ignore_errors=True) + logger.info(f"Removed stale model cache: {entry}") + legacy_archive = pathlib.Path(model_folder) / model + legacy_extract = pathlib.Path(model_folder) / model.replace(".tar.gz", "") + if legacy_archive.is_file(): + legacy_archive.unlink() + logger.info(f"Removed legacy model archive: {legacy_archive}") + if legacy_extract.is_dir(): + shutil.rmtree(legacy_extract, ignore_errors=True) + logger.info(f"Removed legacy model extract dir: {legacy_extract}") + logger.info(f"Downloading model from {MODEL_REPO_ID}/{model} ...") - pathlib.Path(model_folder).mkdir(exist_ok=True, parents=True) - hf_hub_download(repo_id=MODEL_REPO_ID, filename=model, local_dir=model_folder) + model_cache.mkdir(exist_ok=True, parents=True) + hf_hub_download(repo_id=MODEL_REPO_ID, filename=model, local_dir=str(model_cache)) logger.info("Model downloaded!") # Extract archive if model_path.endswith(".tar.gz"): base_name = pathlib.Path(model_path).name.replace(".tar.gz", "") - extract_path = str(pathlib.Path(model_folder) / base_name) + extract_path = str(model_cache / base_name) if not pathlib.Path(extract_path).is_dir(): pathlib.Path(extract_path).mkdir(parents=True, exist_ok=True) with tarfile.open(model_path, "r:gz") as tar: diff --git a/tests/test_vision.py b/tests/test_vision.py index 2b9fc02b..5ac10e76 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -6,6 +6,7 @@ # Canonical import — Classifier lives in pyro_predictor from pyro_predictor import Classifier +from pyro_predictor.vision import MODEL_CACHE_SUBDIR, MODEL_SLUG # pyroengine.vision shim must re-export the same class from pyroengine.vision import Classifier as ClassifierShim @@ -29,7 +30,7 @@ def test_classifier(tmpdir_factory, mock_wildfire_image): # Test onnx model model = Classifier(model_folder=folder, format="onnx") - model_path = str(pathlib.Path(folder) / "onnx_cpu" / "best.onnx") + model_path = str(pathlib.Path(folder) / MODEL_CACHE_SUBDIR / MODEL_SLUG / "onnx_cpu" / "best.onnx") assert pathlib.Path(model_path).is_file() # Test occlusion mask @@ -49,12 +50,36 @@ def sha256sum(path): return hashlib.sha256(pathlib.Path(path).read_bytes()).hexdigest() +def test_stale_cache_is_purged(tmpdir_factory): + folder = pathlib.Path(tmpdir_factory.mktemp("engine_cache")) + models_dir = folder / MODEL_CACHE_SUBDIR + + # Seed a previous-version slug alongside the current one. + stale_slug = models_dir / "yolo11s_stale-slug_v0.0.0" + stale_slug.mkdir(parents=True) + (stale_slug / "marker.txt").write_text("stale") + + # Seed the pre-slug flat layout (onnx variant, matching the format used below). + legacy_archive = folder / "onnx_cpu.tar.gz" + legacy_archive.write_bytes(b"legacy") + legacy_extract = folder / "onnx_cpu" + legacy_extract.mkdir() + (legacy_extract / "marker.txt").write_text("legacy") + + _ = Classifier(model_folder=str(folder), format="onnx") + + assert not stale_slug.exists(), "stale slug should have been purged" + assert not legacy_archive.exists(), "legacy archive should have been purged" + assert not legacy_extract.exists(), "legacy extract dir should have been purged" + assert (models_dir / MODEL_SLUG / "onnx_cpu" / "best.onnx").is_file() + + def test_download(tmpdir_factory): folder = str(tmpdir_factory.mktemp("engine_cache")) # First download _ = Classifier(model_folder=folder, format="onnx") - model_path = str(pathlib.Path(folder) / "onnx_cpu" / "best.onnx") + model_path = str(pathlib.Path(folder) / MODEL_CACHE_SUBDIR / MODEL_SLUG / "onnx_cpu" / "best.onnx") assert pathlib.Path(model_path).is_file() hash1 = sha256sum(model_path)