diff --git a/inference/core/cache/air_gapped.py b/inference/core/cache/air_gapped.py index 4ddb280996..3c167b79f3 100644 --- a/inference/core/cache/air_gapped.py +++ b/inference/core/cache/air_gapped.py @@ -4,11 +4,9 @@ offline workflow construction. """ -import hashlib import json import logging import os -import re from typing import Any, Dict, List, Optional from inference.core.env import MODEL_CACHE_DIR, USE_INFERENCE_MODELS @@ -20,22 +18,6 @@ _SKIP_TOP_LEVEL = {"workflow", "_file_locks"} -def _slugify_model_id(model_id: str) -> str: - """Reproduce the slug used by inference-models for cache directory names. - - Must stay in sync with - ``inference_models.models.auto_loaders.core.slugify_model_id_to_os_safe_format``. - """ - slug = re.sub(r"[^A-Za-z0-9_-]+", "-", model_id) - slug = re.sub(r"[_-]{2,}", "-", slug) - if not slug: - slug = "special-char-only-model-id" - if len(slug) > 48: - slug = slug[:48] - digest = hashlib.blake2s(model_id.encode("utf-8"), digest_size=4).hexdigest() - return f"{slug}-{digest}" - - def _has_non_hidden_children(path: str) -> bool: """Return True if *path* contains at least one non-hidden entry. @@ -76,12 +58,14 @@ def is_model_cached(model_id: str) -> bool: if os.path.isdir(traditional_path) and _has_non_hidden_children(traditional_path): return True - slug = _slugify_model_id(model_id) - models_cache_path = os.path.join(MODEL_CACHE_DIR, "models-cache", slug) - if os.path.isdir(models_cache_path) and _has_non_hidden_children(models_cache_path): - return True + try: + from inference_models.models.auto_loaders.core import ( + find_cached_model_package_dir, + ) - return False + return find_cached_model_package_dir(model_id) is not None + except ImportError: + return False def has_cached_model_variant(model_variants: Optional[List[str]]) -> bool: @@ -107,23 +91,29 @@ def _load_blocks() -> list: def scan_cached_models(cache_dir: str) -> List[Dict[str, Any]]: - """Walk *cache_dir* looking for ``model_type.json`` marker files. + """Walk *cache_dir* looking for cached model metadata files. - Each marker is written by the model registry when a model is first - downloaded. The file contains at least ``project_task_type`` and - ``model_type`` keys. + Scans two cache layouts: + + 1. **Traditional** — ``model_type.json`` marker files written by the model + registry. The model ID is derived from the directory path. + 2. **inference-models** — ``model_config.json`` files written by + ``dump_model_config_for_offline_use``. The canonical ``model_id`` is + read from the file, which ensures alias resolution works correctly + (the directory name is an opaque slug in this layout). Returns a list of dicts with the following shape:: { - "model_id": "workspace/project/3", - "name": "workspace/project/3", + "model_id": "coco/22", + "name": "coco/22", "task_type": "object-detection", "model_architecture": "yolov8n", "is_foundation": False, } """ results: List[Dict[str, Any]] = [] + seen_ids: set = set() if not os.path.isdir(cache_dir): return results @@ -134,36 +124,72 @@ def scan_cached_models(cache_dir: str) -> List[Dict[str, Any]]: dirs[:] = [d for d in dirs if d not in _SKIP_TOP_LEVEL] continue - if "model_type.json" not in files: - continue + has_model_type = "model_type.json" in files + has_model_config = "model_config.json" in files - model_type_path = os.path.join(root, "model_type.json") - try: - with open(model_type_path, "r") as fh: - metadata = json.load(fh) - except (json.JSONDecodeError, OSError) as exc: - logger.warning( - "Skipping unreadable model_type.json at %s: %s", - model_type_path, - exc, - ) + if not has_model_type and not has_model_config: continue + metadata: Optional[dict] = None + use_stored_model_id = False + + # Prefer model_config.json when present — it contains the canonical + # model_id that matches REGISTERED_ALIASES. + if has_model_config: + config_path = os.path.join(root, "model_config.json") + try: + with open(config_path, "r") as fh: + cfg = json.load(fh) + if ( + isinstance(cfg, dict) + and cfg.get("task_type") + and cfg.get("model_id") + ): + metadata = cfg + use_stored_model_id = True + except (json.JSONDecodeError, OSError): + pass + + # Fall back to model_type.json for the traditional layout. + if metadata is None and has_model_type: + model_type_path = os.path.join(root, "model_type.json") + try: + with open(model_type_path, "r") as fh: + metadata = json.load(fh) + except (json.JSONDecodeError, OSError) as exc: + logger.warning( + "Skipping unreadable model_type.json at %s: %s", + model_type_path, + exc, + ) + continue + if not isinstance(metadata, dict): continue - # Support both traditional keys and inference-models metadata keys. - task_type = metadata.get(PROJECT_TASK_TYPE_KEY) or metadata.get("taskType", "") - model_architecture = metadata.get(MODEL_TYPE_KEY) or metadata.get( - "modelArchitecture", "" + task_type = ( + metadata.get("task_type") + or metadata.get(PROJECT_TASK_TYPE_KEY) + or metadata.get("taskType", "") + ) + model_architecture = ( + metadata.get("model_architecture") + or metadata.get(MODEL_TYPE_KEY) + or metadata.get("modelArchitecture", "") ) if not task_type: continue - model_id = os.path.relpath(root, cache_dir) - # Normalise path separators on Windows. - model_id = model_id.replace(os.sep, "/") + if use_stored_model_id: + model_id = metadata["model_id"] + else: + model_id = os.path.relpath(root, cache_dir) + model_id = model_id.replace(os.sep, "/") + + if model_id in seen_ids: + continue + seen_ids.add(model_id) results.append( { diff --git a/inference/core/interfaces/http/builder/routes.py b/inference/core/interfaces/http/builder/routes.py index 175688e32b..865795036b 100644 --- a/inference/core/interfaces/http/builder/routes.py +++ b/inference/core/interfaces/http/builder/routes.py @@ -35,7 +35,7 @@ # ---------------------------------------------------------------- csrf_file = workflow_local_dir / ".csrf" if csrf_file.exists(): - csrf = csrf_file.read_text() + csrf = csrf_file.read_text().strip() else: csrf = os.urandom(16).hex() csrf_file.write_text(csrf) @@ -112,6 +112,12 @@ async def builder_edit(workflow_id: str): # ---------------------- +@router.get("/api/csrf") +@with_route_exceptions_async +async def get_csrf_token(): + return {"csrf": csrf} + + @router.get("/api", dependencies=[Depends(verify_csrf_token)]) @with_route_exceptions_async async def get_all_workflows(): diff --git a/inference/core/interfaces/http/handlers/workflows.py b/inference/core/interfaces/http/handlers/workflows.py index d174bbf835..066eda0fea 100644 --- a/inference/core/interfaces/http/handlers/workflows.py +++ b/inference/core/interfaces/http/handlers/workflows.py @@ -150,9 +150,11 @@ def _get_air_gapped_info_for_block( Compatible task types from ``get_compatible_task_types()`` are always attached when present. """ - task_types = manifest_cls.get_compatible_task_types() + task_types = manifest_cls.get_compatible_task_types() if hasattr(manifest_cls, "get_compatible_task_types") else [] # 1. Explicit cloud/internet declaration + if not hasattr(manifest_cls, "get_air_gapped_availability"): + return BlockAirGappedInfo(available=True, compatible_task_types=task_types) availability = manifest_cls.get_air_gapped_availability() if not availability.available: return BlockAirGappedInfo( @@ -162,7 +164,7 @@ def _get_air_gapped_info_for_block( ) # 2. Foundation models with locally-cacheable weights - model_variants = manifest_cls.get_supported_model_variants() + model_variants = manifest_cls.get_supported_model_variants() if hasattr(manifest_cls, "get_supported_model_variants") else None if model_variants is not None: cached = has_cached_model_variant(model_variants) # Use the first variant as a representative identifier for the UI. diff --git a/inference/core/registries/roboflow.py b/inference/core/registries/roboflow.py index 7c03ca6bc0..a34f2e5940 100644 --- a/inference/core/registries/roboflow.py +++ b/inference/core/registries/roboflow.py @@ -299,21 +299,59 @@ def get_model_metadata_from_cache( def _get_model_metadata_from_cache( dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID] ) -> Optional[Tuple[TaskType, ModelType]]: + # Layout 1: traditional model_type.json model_type_cache_path = construct_model_type_cache_path( dataset_id=dataset_id, version_id=version_id ) - if not os.path.isfile(model_type_cache_path): - return None + if os.path.isfile(model_type_cache_path): + try: + model_metadata = read_json(path=model_type_cache_path) + if not model_metadata_content_is_invalid(content=model_metadata): + return ( + model_metadata[PROJECT_TASK_TYPE_KEY], + model_metadata[MODEL_TYPE_KEY], + ) + except ValueError as e: + logger.warning( + f"Could not load model description from cache under path: " + f"{model_type_cache_path} - decoding issue: {e}." + ) + + # Layout 2: inference-models model_config.json + model_id = f"{dataset_id}/{version_id}" if version_id else dataset_id + result = _get_model_metadata_from_inference_models_cache(model_id) + if result is not None: + return result + + return None + + +def _get_model_metadata_from_inference_models_cache( + model_id: str, +) -> Optional[Tuple[TaskType, ModelType]]: + """Check the inference-models cache layout for model metadata.""" try: - model_metadata = read_json(path=model_type_cache_path) - if model_metadata_content_is_invalid(content=model_metadata): - return None - return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY] - except ValueError as e: - logger.warning( - f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}." + from inference_models.models.auto_loaders.core import ( + find_cached_model_package_dir, ) + except ImportError: + return None + + cached_dir = find_cached_model_package_dir(model_id) + if cached_dir is None: + return None + config_path = os.path.join(cached_dir, "model_config.json") + try: + metadata = read_json(path=config_path) + except ValueError: + return None + if not isinstance(metadata, dict): return None + task_type = metadata.get("task_type", "") + model_arch = metadata.get("model_architecture", "") + if task_type and model_arch: + return task_type, model_arch + return None def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool: diff --git a/inference_models/inference_models/models/auto_loaders/core.py b/inference_models/inference_models/models/auto_loaders/core.py index 394d16543d..b21e3d1514 100644 --- a/inference_models/inference_models/models/auto_loaders/core.py +++ b/inference_models/inference_models/models/auto_loaders/core.py @@ -27,6 +27,7 @@ MissingModelInitParameterError, ModelPackageAlternativesExhaustedError, NoModelPackagesAvailableError, + RetryError, UnauthorizedModelAccessError, ) from inference_models.logger import LOGGER, verbose_info @@ -827,6 +828,21 @@ def register_file_created_for_model_package( model_id=model_id_or_path, api_key=api_key ) raise error + except RetryError: + cached_package_dir = find_cached_model_package_dir( + model_id=model_id_or_path + ) + if cached_package_dir is None: + raise + LOGGER.info( + f"Network unavailable for model {model_id_or_path}, " + f"loading from cached package at {cached_package_dir}" + ) + return attempt_loading_model_from_local_storage( + model_dir_or_weights_path=cached_package_dir, + allow_local_code_packages=allow_local_code_packages, + model_init_kwargs=model_init_kwargs, + ) # here we verify if de-aliasing or access confirmation from auth master changed something model_from_access_manager = model_access_manager.retrieve_model_instance( model_id=model_id_or_path, @@ -1292,6 +1308,7 @@ def initialize_model( task_type=task_type, backend_type=model_package.backend, file_lock_acquire_timeout=model_download_file_lock_acquire_timeout, + model_id=model_id, on_file_created=on_file_created, ) resolved_files = set(shared_files_mapping.values()) @@ -1392,6 +1409,7 @@ def dump_model_config_for_offline_use( task_type: TaskType, backend_type: Optional[BackendType], file_lock_acquire_timeout: int, + model_id: Optional[str] = None, on_file_created: Optional[Callable[[str], None]] = None, ) -> None: if os.path.exists(config_path): @@ -1400,14 +1418,17 @@ def dump_model_config_for_offline_use( return None target_file_dir, target_file_name = os.path.split(config_path) lock_path = os.path.join(target_file_dir, f".{target_file_name}.lock") + content = { + "model_architecture": model_architecture, + "task_type": task_type, + "backend_type": backend_type, + } + if model_id is not None: + content["model_id"] = model_id with FileLock(lock_path, timeout=file_lock_acquire_timeout): dump_json( path=config_path, - content={ - "model_architecture": model_architecture, - "task_type": task_type, - "backend_type": backend_type, - }, + content=content, ) if on_file_created: on_file_created(config_path) @@ -1515,6 +1536,26 @@ def generate_model_package_cache_path(model_id: str, package_id: str) -> str: ) +def find_cached_model_package_dir(model_id: str) -> Optional[str]: + """Return the path to a locally-cached model package for *model_id*, or ``None``. + + Scans ``{INFERENCE_HOME}/models-cache/{slug}/`` for any package directory + that contains a valid ``model_config.json``. This is used as a fallback + when the weights-provider API is unreachable (offline / air-gapped). + """ + slug = slugify_model_id_to_os_safe_format(model_id=model_id) + slug_dir = os.path.abspath(os.path.join(INFERENCE_HOME, "models-cache", slug)) + if not os.path.isdir(slug_dir): + return None + for entry in os.listdir(slug_dir): + package_dir = os.path.join(slug_dir, entry) + if not os.path.isdir(package_dir): + continue + if os.path.isfile(os.path.join(package_dir, MODEL_CONFIG_FILE_NAME)): + return package_dir + return None + + def ensure_package_id_is_os_safe(model_id: str, package_id: str) -> None: if re.search(r"[^A-Za-z0-9]", package_id): raise InsecureModelIdentifierError( diff --git a/inference_models/tests/unit_tests/models/auto_loaders/test_core.py b/inference_models/tests/unit_tests/models/auto_loaders/test_core.py index 609ccbbae7..5cca7735c1 100644 --- a/inference_models/tests/unit_tests/models/auto_loaders/test_core.py +++ b/inference_models/tests/unit_tests/models/auto_loaders/test_core.py @@ -431,3 +431,132 @@ def _create_file(path: str, content: str) -> None: def _read_file(path: str) -> str: with open(path, "r") as f: return f.read() + + +# --------------------------------------------------------------------------- +# find_cached_model_package_dir +# --------------------------------------------------------------------------- + + +class TestFindCachedModelPackageDir: + + def test_returns_package_dir_when_model_config_exists(self, tmp_path): + from inference_models.models.auto_loaders.core import ( + find_cached_model_package_dir, + slugify_model_id_to_os_safe_format, + ) + + model_id = "coco/22" + slug = slugify_model_id_to_os_safe_format(model_id=model_id) + pkg_dir = tmp_path / "models-cache" / slug / "pkg001" + pkg_dir.mkdir(parents=True) + (pkg_dir / "model_config.json").write_text( + json.dumps({"task_type": "object-detection"}) + ) + + with mock.patch( + "inference_models.models.auto_loaders.core.INFERENCE_HOME", str(tmp_path) + ): + result = find_cached_model_package_dir(model_id) + + assert result == str(pkg_dir) + + def test_returns_none_when_no_cache(self, tmp_path): + from inference_models.models.auto_loaders.core import ( + find_cached_model_package_dir, + ) + + with mock.patch( + "inference_models.models.auto_loaders.core.INFERENCE_HOME", str(tmp_path) + ): + result = find_cached_model_package_dir("nonexistent/model") + + assert result is None + + def test_returns_none_when_no_model_config(self, tmp_path): + from inference_models.models.auto_loaders.core import ( + find_cached_model_package_dir, + slugify_model_id_to_os_safe_format, + ) + + model_id = "my/model" + slug = slugify_model_id_to_os_safe_format(model_id=model_id) + pkg_dir = tmp_path / "models-cache" / slug / "pkg001" + pkg_dir.mkdir(parents=True) + (pkg_dir / "weights.onnx").write_text("fake") + + with mock.patch( + "inference_models.models.auto_loaders.core.INFERENCE_HOME", str(tmp_path) + ): + result = find_cached_model_package_dir(model_id) + + assert result is None + + +# --------------------------------------------------------------------------- +# RetryError offline fallback in from_pretrained +# --------------------------------------------------------------------------- + + +class TestFromPretrainedOfflineFallback: + + def test_falls_back_to_cached_package_on_retry_error(self, tmp_path): + from inference_models.errors import RetryError + from inference_models.models.auto_loaders.core import ( + find_cached_model_package_dir, + slugify_model_id_to_os_safe_format, + ) + + model_id = "test/model" + slug = slugify_model_id_to_os_safe_format(model_id=model_id) + pkg_dir = tmp_path / "models-cache" / slug / "pkg001" + pkg_dir.mkdir(parents=True) + (pkg_dir / "model_config.json").write_text( + json.dumps( + { + "model_id": model_id, + "task_type": "object-detection", + "model_architecture": "yolov8n", + "backend_type": "onnxruntime", + } + ) + ) + + fake_model = MagicMock() + + with mock.patch( + "inference_models.models.auto_loaders.core.INFERENCE_HOME", str(tmp_path) + ), mock.patch( + "inference_models.models.auto_loaders.core.get_model_from_provider", + side_effect=RetryError("network down"), + ), mock.patch( + "inference_models.models.auto_loaders.core.attempt_loading_model_from_local_storage", + return_value=fake_model, + ) as mock_load: + from inference_models.models.auto_loaders.core import AutoModel + + result = AutoModel.from_pretrained( + model_id_or_path=model_id, + api_key="test-key", + ) + + assert result is fake_model + mock_load.assert_called_once() + assert mock_load.call_args[1]["model_dir_or_weights_path"] == str(pkg_dir) + + def test_reraises_retry_error_when_no_cache(self, tmp_path): + from inference_models.errors import RetryError + + with mock.patch( + "inference_models.models.auto_loaders.core.INFERENCE_HOME", str(tmp_path) + ), mock.patch( + "inference_models.models.auto_loaders.core.get_model_from_provider", + side_effect=RetryError("network down"), + ): + from inference_models.models.auto_loaders.core import AutoModel + + with pytest.raises(RetryError): + AutoModel.from_pretrained( + model_id_or_path="nonexistent/model", + api_key="test-key", + ) diff --git a/tests/unit/core/cache/test_air_gapped.py b/tests/unit/core/cache/test_air_gapped.py index 0c0e136f1c..a1de44e3d6 100644 --- a/tests/unit/core/cache/test_air_gapped.py +++ b/tests/unit/core/cache/test_air_gapped.py @@ -8,7 +8,6 @@ import pytest - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -300,32 +299,156 @@ def model_json_schema(cls) -> dict: result = get_cached_foundation_models(blocks=[block]) -# ── Cross-validation: _slugify_model_id must match inference_models ────────── - -_SLUGIFY_TEST_IDS = [ - "clip/ViT-B-16", - "coco/40", - "rfdetr-medium", - "sam3/sam3_final", - "florence-pretrains/3", - "depth-anything-v3/small", - "smolvlm2/smolvlm-2.2b-instruct", - "qwen-pretrains/1", - "a" * 100, # long model id - "special!!!chars###here", -] - - -@pytest.mark.parametrize("model_id", _SLUGIFY_TEST_IDS) -def test_slugify_matches_inference_models(model_id: str): - """Ensure _slugify_model_id stays in sync with the canonical implementation.""" - try: - from inference_models.models.auto_loaders.core import ( - slugify_model_id_to_os_safe_format, +# --------------------------------------------------------------------------- +# scan_cached_models — model_config.json (inference-models cache layout) +# --------------------------------------------------------------------------- + + +def _write_model_config_json( + cache_dir: str, + slug_dir: str, + package_id: str, + config: dict, +) -> None: + """Write a ``model_config.json`` inside the inference-models cache layout.""" + package_dir = os.path.join(cache_dir, "models-cache", slug_dir, package_id) + os.makedirs(package_dir, exist_ok=True) + with open(os.path.join(package_dir, "model_config.json"), "w") as fh: + json.dump(config, fh) + + +class TestScanModelConfigJson: + """model_config.json written by dump_model_config_for_offline_use.""" + + def test_uses_canonical_model_id_from_config(self, tmp_path): + """When model_config.json has model_id, use it instead of directory path.""" + from inference.core.cache.air_gapped import scan_cached_models + + cache = str(tmp_path) + _write_model_config_json( + cache, + slug_dir="coco-22-abcd1234", + package_id="pkg-001", + config={ + "model_id": "coco/22", + "task_type": "object-detection", + "model_architecture": "yolov10b", + "backend_type": "onnxruntime", + }, + ) + + result = scan_cached_models(cache) + + assert len(result) == 1 + m = result[0] + assert m["model_id"] == "coco/22" + assert m["task_type"] == "object-detection" + assert m["model_architecture"] == "yolov10b" + assert m["is_foundation"] is False + + def test_deduplicates_by_model_id(self, tmp_path): + """Two cache entries with the same canonical model_id produce one result.""" + from inference.core.cache.air_gapped import scan_cached_models + + cache = str(tmp_path) + # Same model in inference-models layout + _write_model_config_json( + cache, + slug_dir="coco-22-abcd1234", + package_id="pkg-001", + config={ + "model_id": "coco/22", + "task_type": "object-detection", + "model_architecture": "yolov10b", + "backend_type": "onnxruntime", + }, + ) + # Same model also present in traditional layout + _write_model_type_json( + cache, + "coco/22", + {"project_task_type": "object-detection", "model_type": "yolov10b"}, + ) + + result = scan_cached_models(cache) + + assert len(result) == 1 + assert result[0]["model_id"] == "coco/22" + + def test_skips_config_without_model_id(self, tmp_path): + """model_config.json missing model_id falls back to model_type.json.""" + from inference.core.cache.air_gapped import scan_cached_models + + cache = str(tmp_path) + # model_config.json without model_id — should not be picked up + _write_model_config_json( + cache, + slug_dir="some-slug-abcd1234", + package_id="pkg-001", + config={ + "task_type": "object-detection", + "model_architecture": "yolov8n", + "backend_type": "onnxruntime", + }, ) - except ImportError: - pytest.skip("inference_models not installed") - from inference.core.cache.air_gapped import _slugify_model_id + result = scan_cached_models(cache) + + assert len(result) == 0 + + +# --------------------------------------------------------------------------- +# is_model_cached — inference-models layout delegation +# --------------------------------------------------------------------------- + - assert _slugify_model_id(model_id) == slugify_model_id_to_os_safe_format(model_id) +class TestIsModelCachedInferenceModels: + """is_model_cached delegates to find_cached_model_package_dir for the + inference-models cache layout.""" + + def test_returns_true_when_package_dir_found(self, tmp_path): + from inference.core.cache.air_gapped import is_model_cached + + cache = str(tmp_path) + fake_find = MagicMock(return_value="/some/cached/dir") + fake_module = MagicMock() + fake_module.find_cached_model_package_dir = fake_find + + with patch("inference.core.cache.air_gapped.MODEL_CACHE_DIR", cache), patch( + "inference.core.cache.air_gapped.USE_INFERENCE_MODELS", True + ), patch.dict( + "sys.modules", + {"inference_models.models.auto_loaders.core": fake_module}, + ): + assert is_model_cached("my-model") is True + fake_find.assert_called_once_with("my-model") + + def test_returns_false_when_no_cache_hit(self): + from inference.core.cache.air_gapped import is_model_cached + + fake_find = MagicMock(return_value=None) + fake_module = MagicMock() + fake_module.find_cached_model_package_dir = fake_find + + with patch( + "inference.core.cache.air_gapped.MODEL_CACHE_DIR", "/nonexistent" + ), patch( + "inference.core.cache.air_gapped.USE_INFERENCE_MODELS", True + ), patch.dict( + "sys.modules", + {"inference_models.models.auto_loaders.core": fake_module}, + ): + assert is_model_cached("no-such-model") is False + + def test_returns_false_when_inference_models_not_installed(self): + from inference.core.cache.air_gapped import is_model_cached + + with patch( + "inference.core.cache.air_gapped.MODEL_CACHE_DIR", "/nonexistent" + ), patch( + "inference.core.cache.air_gapped.USE_INFERENCE_MODELS", True + ), patch.dict( + "sys.modules", + {"inference_models.models.auto_loaders.core": None}, + ): + assert is_model_cached("some-model") is False