Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
94c9904
Add Introspection of Locally Cached Models
yeldarby Mar 26, 2026
835af61
Make Style
yeldarby Mar 26, 2026
8b77ac4
Respond to PR Review Comments
yeldarby Mar 26, 2026
d00a018
Merge branch 'main' into cache-introspection
yeldarby Mar 26, 2026
82e5d4a
Remove csrf
probicheaux Mar 27, 2026
115b2bb
Fix air-gapped model discovery for inference-models cache layout
probicheaux Mar 27, 2026
2a52609
Fall back to inference-models model_config.json for model metadata cache
probicheaux Mar 27, 2026
459f2b0
Resolve inference-models cache path before calling from_pretrained
probicheaux Mar 27, 2026
30f77f8
Merge remote-tracking branch 'origin/main' into offline-support
sberan Mar 31, 2026
7bbba72
Restore CSRF token verification in builder routes
sberan Mar 31, 2026
e3e29ac
Format inference_models_adapters.py and roboflow.py with black
sberan Mar 31, 2026
e8b16fd
Merge remote-tracking branch 'origin/main' into offline-support
sberan Apr 11, 2026
6f63b23
Add missing logging import in workflows handler
sberan Apr 11, 2026
e1919c9
scan_cached_models: read canonical model_id from model_config.json
sberan Apr 11, 2026
1681768
Remove _resolve_cached_model_path from inference_models_adapters
sberan Apr 11, 2026
bb44c87
Add offline fallback in from_pretrained and consolidate cache logic
sberan Apr 11, 2026
a8509d0
Remove unused logging import from workflows handler
sberan Apr 11, 2026
9c41607
Add /build/api/csrf endpoint to expose CSRF token
sberan Apr 11, 2026
1ad0a45
Strip whitespace from CSRF token read from disk
sberan Apr 11, 2026
9c7fcd2
Guard air-gapped block introspection against missing methods
sberan Apr 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 74 additions & 48 deletions inference/core/cache/air_gapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
offline workflow construction.
"""

import hashlib
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional

from inference.core.env import MODEL_CACHE_DIR, USE_INFERENCE_MODELS
Expand All @@ -20,22 +18,6 @@
_SKIP_TOP_LEVEL = {"workflow", "_file_locks"}


def _slugify_model_id(model_id: str) -> str:
"""Reproduce the slug used by inference-models for cache directory names.

Must stay in sync with
``inference_models.models.auto_loaders.core.slugify_model_id_to_os_safe_format``.
"""
slug = re.sub(r"[^A-Za-z0-9_-]+", "-", model_id)
slug = re.sub(r"[_-]{2,}", "-", slug)
if not slug:
slug = "special-char-only-model-id"
if len(slug) > 48:
slug = slug[:48]
digest = hashlib.blake2s(model_id.encode("utf-8"), digest_size=4).hexdigest()
return f"{slug}-{digest}"


def _has_non_hidden_children(path: str) -> bool:
"""Return True if *path* contains at least one non-hidden entry.

Expand Down Expand Up @@ -76,12 +58,14 @@ def is_model_cached(model_id: str) -> bool:
if os.path.isdir(traditional_path) and _has_non_hidden_children(traditional_path):
return True

slug = _slugify_model_id(model_id)
models_cache_path = os.path.join(MODEL_CACHE_DIR, "models-cache", slug)
if os.path.isdir(models_cache_path) and _has_non_hidden_children(models_cache_path):
return True
try:
from inference_models.models.auto_loaders.core import (
find_cached_model_package_dir,
)

return False
return find_cached_model_package_dir(model_id) is not None
except ImportError:
return False


def has_cached_model_variant(model_variants: Optional[List[str]]) -> bool:
Expand All @@ -107,23 +91,29 @@ def _load_blocks() -> list:


def scan_cached_models(cache_dir: str) -> List[Dict[str, Any]]:
"""Walk *cache_dir* looking for ``model_type.json`` marker files.
"""Walk *cache_dir* looking for cached model metadata files.

Each marker is written by the model registry when a model is first
downloaded. The file contains at least ``project_task_type`` and
``model_type`` keys.
Scans two cache layouts:

1. **Traditional** — ``model_type.json`` marker files written by the model
registry. The model ID is derived from the directory path.
2. **inference-models** — ``model_config.json`` files written by
``dump_model_config_for_offline_use``. The canonical ``model_id`` is
read from the file, which ensures alias resolution works correctly
(the directory name is an opaque slug in this layout).

Returns a list of dicts with the following shape::

{
"model_id": "workspace/project/3",
"name": "workspace/project/3",
"model_id": "coco/22",
"name": "coco/22",
"task_type": "object-detection",
"model_architecture": "yolov8n",
"is_foundation": False,
}
"""
results: List[Dict[str, Any]] = []
seen_ids: set = set()
if not os.path.isdir(cache_dir):
return results

Expand All @@ -134,36 +124,72 @@ def scan_cached_models(cache_dir: str) -> List[Dict[str, Any]]:
dirs[:] = [d for d in dirs if d not in _SKIP_TOP_LEVEL]
continue

if "model_type.json" not in files:
continue
has_model_type = "model_type.json" in files
has_model_config = "model_config.json" in files

model_type_path = os.path.join(root, "model_type.json")
try:
with open(model_type_path, "r") as fh:
metadata = json.load(fh)
except (json.JSONDecodeError, OSError) as exc:
logger.warning(
"Skipping unreadable model_type.json at %s: %s",
model_type_path,
exc,
)
if not has_model_type and not has_model_config:
continue

metadata: Optional[dict] = None
use_stored_model_id = False

# Prefer model_config.json when present — it contains the canonical
# model_id that matches REGISTERED_ALIASES.
if has_model_config:
config_path = os.path.join(root, "model_config.json")
try:
with open(config_path, "r") as fh:
cfg = json.load(fh)
if (
isinstance(cfg, dict)
and cfg.get("task_type")
and cfg.get("model_id")
):
metadata = cfg
use_stored_model_id = True
except (json.JSONDecodeError, OSError):
pass

# Fall back to model_type.json for the traditional layout.
if metadata is None and has_model_type:
model_type_path = os.path.join(root, "model_type.json")
try:
with open(model_type_path, "r") as fh:
metadata = json.load(fh)
except (json.JSONDecodeError, OSError) as exc:
logger.warning(
"Skipping unreadable model_type.json at %s: %s",
model_type_path,
exc,
)
continue

if not isinstance(metadata, dict):
continue

# Support both traditional keys and inference-models metadata keys.
task_type = metadata.get(PROJECT_TASK_TYPE_KEY) or metadata.get("taskType", "")
model_architecture = metadata.get(MODEL_TYPE_KEY) or metadata.get(
"modelArchitecture", ""
task_type = (
metadata.get("task_type")
or metadata.get(PROJECT_TASK_TYPE_KEY)
or metadata.get("taskType", "")
)
model_architecture = (
metadata.get("model_architecture")
or metadata.get(MODEL_TYPE_KEY)
or metadata.get("modelArchitecture", "")
)

if not task_type:
continue

model_id = os.path.relpath(root, cache_dir)
# Normalise path separators on Windows.
model_id = model_id.replace(os.sep, "/")
if use_stored_model_id:
model_id = metadata["model_id"]
else:
model_id = os.path.relpath(root, cache_dir)
model_id = model_id.replace(os.sep, "/")

if model_id in seen_ids:
continue
seen_ids.add(model_id)

results.append(
{
Expand Down
8 changes: 7 additions & 1 deletion inference/core/interfaces/http/builder/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# ----------------------------------------------------------------
csrf_file = workflow_local_dir / ".csrf"
if csrf_file.exists():
csrf = csrf_file.read_text()
csrf = csrf_file.read_text().strip()
else:
csrf = os.urandom(16).hex()
csrf_file.write_text(csrf)
Expand Down Expand Up @@ -112,6 +112,12 @@ async def builder_edit(workflow_id: str):
# ----------------------


@router.get("/api/csrf")
@with_route_exceptions_async
async def get_csrf_token():
    """Return the builder's CSRF token as JSON: ``{"csrf": "<hex token>"}``.

    The token is the module-level ``csrf`` value read from (or persisted to)
    the ``.csrf`` file at startup, so clients can fetch it once and attach it
    to subsequent ``/api`` requests guarded by ``verify_csrf_token``.

    NOTE(review): this route itself has no CSRF/auth dependency — confirm it
    is reachable only same-origin, otherwise handing the token to any caller
    would defeat the CSRF protection on the other builder routes.
    """
    return {"csrf": csrf}


@router.get("/api", dependencies=[Depends(verify_csrf_token)])
@with_route_exceptions_async
async def get_all_workflows():
Expand Down
6 changes: 4 additions & 2 deletions inference/core/interfaces/http/handlers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,11 @@ def _get_air_gapped_info_for_block(
Compatible task types from ``get_compatible_task_types()`` are always
attached when present.
"""
task_types = manifest_cls.get_compatible_task_types()
task_types = manifest_cls.get_compatible_task_types() if hasattr(manifest_cls, "get_compatible_task_types") else []

# 1. Explicit cloud/internet declaration
if not hasattr(manifest_cls, "get_air_gapped_availability"):
return BlockAirGappedInfo(available=True, compatible_task_types=task_types)
availability = manifest_cls.get_air_gapped_availability()
if not availability.available:
return BlockAirGappedInfo(
Expand All @@ -162,7 +164,7 @@ def _get_air_gapped_info_for_block(
)

# 2. Foundation models with locally-cacheable weights
model_variants = manifest_cls.get_supported_model_variants()
model_variants = manifest_cls.get_supported_model_variants() if hasattr(manifest_cls, "get_supported_model_variants") else None
if model_variants is not None:
cached = has_cached_model_variant(model_variants)
# Use the first variant as a representative identifier for the UI.
Expand Down
56 changes: 47 additions & 9 deletions inference/core/registries/roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,21 +299,59 @@ def get_model_metadata_from_cache(
def _get_model_metadata_from_cache(
    dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID]
) -> Optional[Tuple[TaskType, ModelType]]:
    """Resolve cached ``(task_type, model_type)`` metadata for a model.

    Two on-disk cache layouts are checked, in order:

    1. **Traditional** — a ``model_type.json`` file at the path built by
       ``construct_model_type_cache_path``.
    2. **inference-models** — a ``model_config.json`` inside the package
       directory located via ``find_cached_model_package_dir`` (delegated to
       ``_get_model_metadata_from_inference_models_cache``).

    Returns ``None`` when neither layout yields valid metadata.
    """
    # Layout 1: traditional model_type.json
    model_type_cache_path = construct_model_type_cache_path(
        dataset_id=dataset_id, version_id=version_id
    )
    if os.path.isfile(model_type_cache_path):
        try:
            model_metadata = read_json(path=model_type_cache_path)
            if not model_metadata_content_is_invalid(content=model_metadata):
                return (
                    model_metadata[PROJECT_TASK_TYPE_KEY],
                    model_metadata[MODEL_TYPE_KEY],
                )
        except ValueError as e:
            # A corrupt marker file is not fatal — fall through to layout 2.
            logger.warning(
                f"Could not load model description from cache under path: "
                f"{model_type_cache_path} - decoding issue: {e}."
            )

    # Layout 2: inference-models model_config.json. The helper returns None
    # when the package is absent or its config is unusable, which is exactly
    # this function's "not cached" result.
    model_id = f"{dataset_id}/{version_id}" if version_id else dataset_id
    return _get_model_metadata_from_inference_models_cache(model_id)


def _get_model_metadata_from_inference_models_cache(
    model_id: str,
) -> Optional[Tuple[TaskType, ModelType]]:
    """Check the inference-models cache layout for model metadata.

    Looks up the cached package directory for *model_id* and reads its
    ``model_config.json``. Returns ``(task_type, model_architecture)`` when
    both keys are present and non-empty, otherwise ``None``. Also returns
    ``None`` when the ``inference_models`` package is not installed.
    """
    try:
        # Imported lazily: inference_models is an optional dependency.
        from inference_models.models.auto_loaders.core import (
            find_cached_model_package_dir,
        )
    except ImportError:
        return None

    cached_dir = find_cached_model_package_dir(model_id)
    if cached_dir is None:
        return None
    config_path = os.path.join(cached_dir, "model_config.json")
    try:
        metadata = read_json(path=config_path)
    except ValueError:
        # Unreadable / malformed config — treat as not cached.
        return None
    if not isinstance(metadata, dict):
        return None
    task_type = metadata.get("task_type", "")
    model_arch = metadata.get("model_architecture", "")
    if task_type and model_arch:
        return task_type, model_arch
    return None


def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool:
Expand Down
51 changes: 46 additions & 5 deletions inference_models/inference_models/models/auto_loaders/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MissingModelInitParameterError,
ModelPackageAlternativesExhaustedError,
NoModelPackagesAvailableError,
RetryError,
UnauthorizedModelAccessError,
)
from inference_models.logger import LOGGER, verbose_info
Expand Down Expand Up @@ -827,6 +828,21 @@ def register_file_created_for_model_package(
model_id=model_id_or_path, api_key=api_key
)
raise error
except RetryError:
cached_package_dir = find_cached_model_package_dir(
model_id=model_id_or_path
)
if cached_package_dir is None:
raise
LOGGER.info(
f"Network unavailable for model {model_id_or_path}, "
f"loading from cached package at {cached_package_dir}"
)
return attempt_loading_model_from_local_storage(
model_dir_or_weights_path=cached_package_dir,
allow_local_code_packages=allow_local_code_packages,
model_init_kwargs=model_init_kwargs,
)
# here we verify if de-aliasing or access confirmation from auth master changed something
model_from_access_manager = model_access_manager.retrieve_model_instance(
model_id=model_id_or_path,
Expand Down Expand Up @@ -1292,6 +1308,7 @@ def initialize_model(
task_type=task_type,
backend_type=model_package.backend,
file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
model_id=model_id,
on_file_created=on_file_created,
)
resolved_files = set(shared_files_mapping.values())
Expand Down Expand Up @@ -1392,6 +1409,7 @@ def dump_model_config_for_offline_use(
task_type: TaskType,
backend_type: Optional[BackendType],
file_lock_acquire_timeout: int,
model_id: Optional[str] = None,
on_file_created: Optional[Callable[[str], None]] = None,
) -> None:
if os.path.exists(config_path):
Expand All @@ -1400,14 +1418,17 @@ def dump_model_config_for_offline_use(
return None
target_file_dir, target_file_name = os.path.split(config_path)
lock_path = os.path.join(target_file_dir, f".{target_file_name}.lock")
content = {
"model_architecture": model_architecture,
"task_type": task_type,
"backend_type": backend_type,
}
if model_id is not None:
content["model_id"] = model_id
with FileLock(lock_path, timeout=file_lock_acquire_timeout):
dump_json(
path=config_path,
content={
"model_architecture": model_architecture,
"task_type": task_type,
"backend_type": backend_type,
},
content=content,
)
if on_file_created:
on_file_created(config_path)
Expand Down Expand Up @@ -1515,6 +1536,26 @@ def generate_model_package_cache_path(model_id: str, package_id: str) -> str:
)


def find_cached_model_package_dir(model_id: str) -> Optional[str]:
    """Return the path to a locally-cached model package for *model_id*, or ``None``.

    Scans ``{INFERENCE_HOME}/models-cache/{slug}/`` for any package directory
    that contains a valid ``model_config.json``. This is used as a fallback
    when the weights-provider API is unreachable (offline / air-gapped).
    """
    slug = slugify_model_id_to_os_safe_format(model_id=model_id)
    slug_dir = os.path.abspath(os.path.join(INFERENCE_HOME, "models-cache", slug))
    if not os.path.isdir(slug_dir):
        return None
    # Sort entries so that, when several packages are cached for the same
    # model, every call falls back to the same one — os.listdir() ordering is
    # arbitrary and filesystem-dependent.
    for entry in sorted(os.listdir(slug_dir)):
        package_dir = os.path.join(slug_dir, entry)
        if not os.path.isdir(package_dir):
            continue
        if os.path.isfile(os.path.join(package_dir, MODEL_CONFIG_FILE_NAME)):
            return package_dir
    return None


def ensure_package_id_is_os_safe(model_id: str, package_id: str) -> None:
if re.search(r"[^A-Za-z0-9]", package_id):
raise InsecureModelIdentifierError(
Expand Down
Loading
Loading