Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions pyro-predictor/pyro_predictor/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import pathlib
import platform
import shutil
import tarfile
from typing import Tuple

Expand All @@ -21,6 +22,8 @@

MODEL_REPO_ID = "pyronear/yolo11s_plucky-pelican_v7.1.0"
MODEL_NAME = "ncnn_cpu.tar.gz"
MODEL_SLUG = MODEL_REPO_ID.split("/", 1)[1] # e.g. "yolo11s_plucky-pelican_v7.1.0"
MODEL_CACHE_SUBDIR = "models"

logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,18 +73,37 @@ def __init__(
else:
raise ValueError("Unsupported format: should be 'ncnn' or 'onnx'")

model_path = str(pathlib.Path(model_folder) / model)
# Namespace cached weights by model slug so a MODEL_REPO_ID bump lands in a
# fresh path and old weights can be purged.
cache_root = pathlib.Path(model_folder) / MODEL_CACHE_SUBDIR
model_cache = cache_root / MODEL_SLUG
model_path = str(model_cache / model)

if not pathlib.Path(model_path).is_file():
# Drop previous slugs and legacy flat layout to reclaim disk on edge devices.
if cache_root.is_dir():
for entry in cache_root.iterdir():
if entry.name != MODEL_SLUG:
shutil.rmtree(entry, ignore_errors=True)
logger.info(f"Removed stale model cache: {entry}")
legacy_archive = pathlib.Path(model_folder) / model
legacy_extract = pathlib.Path(model_folder) / model.replace(".tar.gz", "")
if legacy_archive.is_file():
legacy_archive.unlink()
logger.info(f"Removed legacy model archive: {legacy_archive}")
if legacy_extract.is_dir():
shutil.rmtree(legacy_extract, ignore_errors=True)
logger.info(f"Removed legacy model extract dir: {legacy_extract}")

logger.info(f"Downloading model from {MODEL_REPO_ID}/{model} ...")
pathlib.Path(model_folder).mkdir(exist_ok=True, parents=True)
hf_hub_download(repo_id=MODEL_REPO_ID, filename=model, local_dir=model_folder)
model_cache.mkdir(exist_ok=True, parents=True)
hf_hub_download(repo_id=MODEL_REPO_ID, filename=model, local_dir=str(model_cache))
logger.info("Model downloaded!")

# Extract archive
if model_path.endswith(".tar.gz"):
base_name = pathlib.Path(model_path).name.replace(".tar.gz", "")
extract_path = str(pathlib.Path(model_folder) / base_name)
extract_path = str(model_cache / base_name)
if not pathlib.Path(extract_path).is_dir():
pathlib.Path(extract_path).mkdir(parents=True, exist_ok=True)
with tarfile.open(model_path, "r:gz") as tar:
Expand Down
29 changes: 27 additions & 2 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# Canonical import — Classifier lives in pyro_predictor
from pyro_predictor import Classifier
from pyro_predictor.vision import MODEL_CACHE_SUBDIR, MODEL_SLUG

# pyroengine.vision shim must re-export the same class
from pyroengine.vision import Classifier as ClassifierShim
Expand All @@ -29,7 +30,7 @@

# Test onnx model
model = Classifier(model_folder=folder, format="onnx")
model_path = str(pathlib.Path(folder) / "onnx_cpu" / "best.onnx")
model_path = str(pathlib.Path(folder) / MODEL_CACHE_SUBDIR / MODEL_SLUG / "onnx_cpu" / "best.onnx")
assert pathlib.Path(model_path).is_file()

# Test occlusion mask
Expand All @@ -49,12 +50,36 @@
return hashlib.sha256(pathlib.Path(path).read_bytes()).hexdigest()


def test_stale_cache_is_purged(tmpdir_factory):
folder = pathlib.Path(tmpdir_factory.mktemp("engine_cache"))
models_dir = folder / MODEL_CACHE_SUBDIR

# Seed a previous-version slug alongside the current one.
stale_slug = models_dir / "yolo11s_stale-slug_v0.0.0"
stale_slug.mkdir(parents=True)
(stale_slug / "marker.txt").write_text("stale")

# Seed the pre-slug flat layout (onnx variant, matching the format used below).
legacy_archive = folder / "onnx_cpu.tar.gz"
legacy_archive.write_bytes(b"legacy")
legacy_extract = folder / "onnx_cpu"
legacy_extract.mkdir()
(legacy_extract / "marker.txt").write_text("legacy")

_ = Classifier(model_folder=str(folder), format="onnx")

assert not stale_slug.exists(), "stale slug should have been purged"

Check warning on line 71 in tests/test_vision.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

tests/test_vision.py#L71

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert not legacy_archive.exists(), "legacy archive should have been purged"

Check warning on line 72 in tests/test_vision.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

tests/test_vision.py#L72

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert not legacy_extract.exists(), "legacy extract dir should have been purged"

Check warning on line 73 in tests/test_vision.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

tests/test_vision.py#L73

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert (models_dir / MODEL_SLUG / "onnx_cpu" / "best.onnx").is_file()

Check warning on line 74 in tests/test_vision.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

tests/test_vision.py#L74

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.


def test_download(tmpdir_factory):
folder = str(tmpdir_factory.mktemp("engine_cache"))

# First download
_ = Classifier(model_folder=folder, format="onnx")
model_path = str(pathlib.Path(folder) / "onnx_cpu" / "best.onnx")
model_path = str(pathlib.Path(folder) / MODEL_CACHE_SUBDIR / MODEL_SLUG / "onnx_cpu" / "best.onnx")
assert pathlib.Path(model_path).is_file()

hash1 = sha256sum(model_path)
Expand Down
Loading