diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts index 5b28b3f020..1c06031311 100644 --- a/dashboard/src/lib/stores/app.svelte.ts +++ b/dashboard/src/lib/stores/app.svelte.ts @@ -219,6 +219,29 @@ export interface TraceListResponse { traces: TraceListItem[]; } +export type ModelSourceKind = + | "exo" + | "huggingface" + | "lmstudio" + | "ollama" + | "llamacpp"; + +export type ModelFileFormat = "safetensors" | "mlx" | "gguf"; + +// Mirrors LocalModelEntry on the backend; kept loose-typed where the dashboard +// just passes values through. +export interface RawLocalModelEntry { + nodeId: string; + source: ModelSourceKind; + externalId: string; + displayName: string; + path: string; + format: ModelFileFormat; + sizeBytes: { inBytes: number }; + loadableWithMlx: boolean; + matchedModelId?: string | null; +} + interface RawStateResponse { topology?: RawTopology; instances?: Record< @@ -231,6 +254,7 @@ interface RawStateResponse { runners?: Record; instanceLinks?: Record; downloads?: Record; + localModels?: Record; // New granular node state fields nodeIdentities?: Record; nodeMemory?: Record; @@ -551,6 +575,7 @@ class AppStore { instanceLinks = $state>({}); featureFlags = $state>({}); downloads = $state>({}); + localModels = $state>({}); nodeDisk = $state< Record< string, @@ -1338,6 +1363,11 @@ class AppStore { if (data.downloads) { this.downloads = data.downloads; } + if (data.localModels) { + this.localModels = data.localModels; + } else { + this.localModels = {}; + } if (data.nodeDisk) { this.nodeDisk = data.nodeDisk; } @@ -3498,6 +3528,7 @@ export const updateInstanceLink = ( export const deleteInstanceLink = (linkId: string) => appStore.deleteInstanceLink(linkId); export const downloads = () => appStore.downloads; +export const localModels = () => appStore.localModels; export const nodeDisk = () => appStore.nodeDisk; export const placementPreviews = () => appStore.placementPreviews; export const selectedPreviewModelId = () 
=> appStore.selectedPreviewModelId; diff --git a/dashboard/src/routes/downloads/+page.svelte b/dashboard/src/routes/downloads/+page.svelte index 4bf69cb6cb..acdf120408 100644 --- a/dashboard/src/routes/downloads/+page.svelte +++ b/dashboard/src/routes/downloads/+page.svelte @@ -5,12 +5,15 @@ import { topologyData, downloads, + localModels, nodeDisk, refreshState, lastUpdate as lastUpdateStore, startDownload, cancelDownload, deleteDownload, + type ModelSourceKind, + type RawLocalModelEntry, } from "$lib/stores/app.svelte"; import { getDownloadTag, @@ -20,7 +23,12 @@ import HeaderNav from "$lib/components/HeaderNav.svelte"; type CellStatus = - | { kind: "completed"; totalBytes: number; modelDirectory?: string } + | { + kind: "completed"; + totalBytes: number; + modelDirectory?: string; + sources: Set; + } | { kind: "downloading"; percentage: number; @@ -29,15 +37,43 @@ speed: number; etaMs: number; modelDirectory?: string; + sources: Set; } | { kind: "pending"; downloaded: number; total: number; modelDirectory?: string; + sources: Set; + } + | { kind: "failed"; modelDirectory?: string; sources: Set } + | { + // External: discovered by a non-exo source (LM Studio, Ollama, …) and not + // currently being managed by exo's downloader. 
+ kind: "external"; + totalBytes: number; + modelDirectory?: string; + sources: Set; + format: "safetensors" | "mlx" | "gguf"; + loadable: boolean; } - | { kind: "failed"; modelDirectory?: string } - | { kind: "not_present" }; + | { kind: "not_present"; sources: Set }; + + const SOURCE_LABELS: Record = { + exo: "exo", + huggingface: "HF", + lmstudio: "LM Studio", + ollama: "Ollama", + llamacpp: "llama.cpp", + }; + + const SOURCE_COLORS: Record = { + exo: "bg-exo-yellow/20 text-exo-yellow", + huggingface: "bg-orange-500/20 text-orange-300", + lmstudio: "bg-purple-500/20 text-purple-300", + ollama: "bg-blue-500/20 text-blue-300", + llamacpp: "bg-pink-500/20 text-pink-300", + }; type ModelCardInfo = { family: string; @@ -124,8 +160,9 @@ } const CELL_PRIORITY: Record = { - completed: 4, - downloading: 3, + completed: 5, + downloading: 4, + external: 3, pending: 2, failed: 1, not_present: 0, @@ -175,19 +212,31 @@ return { prettyName, card }; } + const localModelsData = $derived(localModels()); + let modelRows = $state([]); let nodeColumns = $state([]); let infoRow = $state(null); + let sourceFilter = $state("all"); $effect(() => { try { - if (!downloadsData || Object.keys(downloadsData).length === 0) { + const downloadsEmpty = + !downloadsData || Object.keys(downloadsData).length === 0; + const localModelsEmpty = + !localModelsData || Object.keys(localModelsData).length === 0; + if (downloadsEmpty && localModelsEmpty) { modelRows = []; nodeColumns = []; return; } - const allNodeIds = Object.keys(downloadsData); + const allNodeIds = Array.from( + new Set([ + ...Object.keys(downloadsData ?? {}), + ...Object.keys(localModelsData ?? {}), + ]), + ); const columns: NodeColumn[] = allNodeIds.map((nodeId) => { const diskInfo = nodeDiskData?.[nodeId]; return { @@ -200,7 +249,9 @@ const rowMap = new Map(); - for (const [nodeId, nodeDownloads] of Object.entries(downloadsData)) { + for (const [nodeId, nodeDownloads] of Object.entries( + downloadsData ?? 
{}, + )) { const entries = Array.isArray(nodeDownloads) ? nodeDownloads : nodeDownloads && typeof nodeDownloads === "object" @@ -234,10 +285,14 @@ const modelDirectory = ((payload.model_directory ?? payload.modelDirectory) as string) || undefined; + // Active exo-managed downloads are always source=exo; the local-model + // merge below extends `sources` with any other sources hosting the same + // model on this node. + const sources = new Set(["exo"]); let cell: CellStatus; if (tag === "DownloadCompleted") { const totalBytes = getBytes(payload.total); - cell = { kind: "completed", totalBytes, modelDirectory }; + cell = { kind: "completed", totalBytes, modelDirectory, sources }; } else if (tag === "DownloadOngoing") { const rawProgress = payload.download_progress ?? payload.downloadProgress ?? {}; @@ -257,9 +312,10 @@ speed, etaMs, modelDirectory, + sources, }; } else if (tag === "DownloadFailed") { - cell = { kind: "failed", modelDirectory }; + cell = { kind: "failed", modelDirectory, sources }; } else { const downloaded = getBytes( payload.downloaded ?? @@ -274,6 +330,7 @@ downloaded, total, modelDirectory, + sources, }; } @@ -284,13 +341,49 @@ } } + // Merge local_models. Each entry either augments an existing exo-managed + // cell (adding to `sources`) or creates an "external" cell when this is the + // only place the model lives on this node. + for (const [nodeId, entries] of Object.entries(localModelsData ?? {})) { + if (!Array.isArray(entries)) continue; + for (const entry of entries as RawLocalModelEntry[]) { + const modelId = entry.matchedModelId ?? entry.externalId; + let row = rowMap.get(modelId); + if (!row) { + row = { + modelId, + prettyName: null, + cells: {}, + shardMetadata: null, + modelCard: null, + }; + rowMap.set(modelId, row); + } + const existing = row.cells[nodeId]; + const totalBytes = entry.sizeBytes?.inBytes ?? 
0; + if (existing && existing.kind !== "not_present") { + existing.sources.add(entry.source); + continue; + } + row.cells[nodeId] = { + kind: "external", + totalBytes, + modelDirectory: entry.path, + sources: new Set([entry.source]), + format: entry.format, + loadable: entry.loadableWithMlx, + }; + } + } + function rowSortKey(row: ModelRow): number { - // in progress (4) -> completed (3) -> paused (2) -> not started (1) -> not present (0) + // downloading (5) -> completed (4) -> external (3) -> paused (2) -> not started (1) -> not present (0) let best = 0; for (const cell of Object.values(row.cells)) { let score = 0; - if (cell.kind === "downloading") score = 4; - else if (cell.kind === "completed") score = 3; + if (cell.kind === "downloading") score = 5; + else if (cell.kind === "completed") score = 4; + else if (cell.kind === "external") score = 3; else if (cell.kind === "pending" && cell.downloaded > 0) score = 2; // paused else if (cell.kind === "pending" || cell.kind === "failed") score = 1; // not started @@ -302,35 +395,51 @@ function totalCompletedBytes(row: ModelRow): number { let total = 0; for (const cell of Object.values(row.cells)) { - if (cell.kind === "completed") total += cell.totalBytes; + if (cell.kind === "completed" || cell.kind === "external") + total += cell.totalBytes; } return total; } - const rows = Array.from(rowMap.values()).sort((a, b) => { - const aPriority = rowSortKey(a); - const bPriority = rowSortKey(b); - if (aPriority !== bPriority) return bPriority - aPriority; - // Within completed or paused, sort by biggest size first - if (aPriority === 3 && bPriority === 3) { - const sizeDiff = totalCompletedBytes(b) - totalCompletedBytes(a); - if (sizeDiff !== 0) return sizeDiff; - } - if (aPriority === 2 && bPriority === 2) { - const aSize = Math.max( - ...Object.values(a.cells).map((c) => - c.kind === "pending" ? c.total : 0, - ), - ); - const bSize = Math.max( - ...Object.values(b.cells).map((c) => - c.kind === "pending" ? 
c.total : 0, - ), - ); - if (aSize !== bSize) return bSize - aSize; + function rowMatchesSourceFilter(row: ModelRow): boolean { + if (sourceFilter === "all") return true; + for (const cell of Object.values(row.cells)) { + if ( + "sources" in cell && + (cell.sources as Set).has(sourceFilter) + ) { + return true; + } } - return a.modelId.localeCompare(b.modelId); - }); + return false; + } + + const rows = Array.from(rowMap.values()) + .filter(rowMatchesSourceFilter) + .sort((a, b) => { + const aPriority = rowSortKey(a); + const bPriority = rowSortKey(b); + if (aPriority !== bPriority) return bPriority - aPriority; + // Within completed or paused, sort by biggest size first + if (aPriority === 3 && bPriority === 3) { + const sizeDiff = totalCompletedBytes(b) - totalCompletedBytes(a); + if (sizeDiff !== 0) return sizeDiff; + } + if (aPriority === 2 && bPriority === 2) { + const aSize = Math.max( + ...Object.values(a.cells).map((c) => + c.kind === "pending" ? c.total : 0, + ), + ); + const bSize = Math.max( + ...Object.values(b.cells).map((c) => + c.kind === "pending" ? c.total : 0, + ), + ); + if (aSize !== bSize) return bSize - aSize; + } + return a.modelId.localeCompare(b.modelId); + }); modelRows = rows; nodeColumns = columns; @@ -403,6 +512,20 @@ {/snippet} +{#snippet sourceBadges(sources: Set)} +
+ {#each Array.from(sources) as src} + + {SOURCE_LABELS[src]} + + {/each} +
+{/snippet} +
@@ -411,10 +534,11 @@

- Downloads + Models

- Overview of models on each node + Models available on each node — including those installed via LM + Studio, Ollama, llama.cpp and HuggingFace.

@@ -434,6 +558,27 @@
+ +
+ Source: + {#each ["all", "exo", "huggingface", "lmstudio", "ollama", "llamacpp"] as filter} + {@const active = sourceFilter === filter} + + {/each} +
+ {#if !hasDownloads}
(), }} + {@const exoManaged = + "sources" in cell && cell.sources.has("exo")} {#if cell.kind === "completed"}
{formatBytes(cell.totalBytes)} - {@render deleteButton(col.nodeId, row.modelId)} + {@render sourceBadges(cell.sources)} + {#if exoManaged} + {@render deleteButton(col.nodeId, row.modelId)} + {/if} +
+ {:else if cell.kind === "external"} +
+ + + + {formatBytes(cell.totalBytes)} + {@render sourceBadges(cell.sources)} + {#if !cell.loadable} + {cell.format} · n/a + {/if}
{:else if cell.kind === "downloading"}
{@render deleteButton(col.nodeId, row.modelId)}
+ {@render sourceBadges(cell.sources)}
{:else if cell.kind === "pending"}
None: self.app.post("/download/start")(self.start_download) self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download) self.app.post("/download/cancel")(self.cancel_download) + self.app.get("/sources")(self.list_sources) + self.app.post("/sources/rescan")(self.rescan_sources) self.app.get("/v1/traces")(self.list_traces) self.app.post("/v1/traces/delete")(self.delete_traces) self.app.get("/v1/traces/{task_id}")(self.get_trace) @@ -1975,6 +1980,31 @@ async def start_download( async def delete_download( self, node_id: NodeId, model_id: ModelId ) -> DeleteDownloadResponse: + # Guard: only allow deletion of models that exo's own downloader manages. + # Models surfaced from external sources (LM Studio, Ollama, …) live outside + # exo's writable cache and shouldn't be touched here — direct the user at the + # owning tool instead. + node_downloads = self.state.downloads.get(node_id, ()) + is_exo_managed = any( + isinstance(dl, DownloadCompleted) + and dl.shard_metadata.model_card.model_id == model_id + for dl in node_downloads + ) + if not is_exo_managed: + external_sources: set[str] = { + entry.source + for entry in self.state.local_models.get(node_id, ()) + if entry.external_id == str(model_id) and entry.source != "exo" + } + if external_sources: + tools = ", ".join(sorted(external_sources)) + raise HTTPException( + status_code=409, + detail=( + f"'{model_id}' is managed by {tools}. " + "Remove it from the owning tool to free disk space." + ), + ) command = DeleteDownload( target_node_id=node_id, model_id=ModelId(model_id), @@ -1982,6 +2012,36 @@ async def delete_download( await self._send_download(command) return DeleteDownloadResponse(command_id=command.command_id) + async def list_sources(self) -> SourcesResponse: + """Enumerate all model sources known to this node and whether each is configured. + + Reads :func:`exo.sources.default_sources` on the API process. 
Availability is a + cheap directory existence check; entries themselves are surfaced via ``/state`` + under ``local_models``. + """ + from exo.sources import default_sources + + infos: list[SourceInfo] = [] + for source in default_sources(): + try: + available = source.is_available() + except Exception: # noqa: BLE001 + available = False + infos.append( + SourceInfo( + kind=source.kind, + display_name=source.display_name, + available=available, + ) + ) + return SourcesResponse(sources=infos) + + async def rescan_sources(self) -> RescanSourcesResponse: + """Best-effort: returns immediately. Workers rescan on a 60s timer; this is a + documentation hook for a future broadcast-rescan command. Today it always + returns ``triggered=false`` so callers don't depend on a no-op.""" + return RescanSourcesResponse(triggered=False) + async def cancel_download( self, payload: CancelDownloadParams, diff --git a/src/exo/api/types/__init__.py b/src/exo/api/types/__init__.py index 9cb2f834fa..8c6bc362dc 100644 --- a/src/exo/api/types/__init__.py +++ b/src/exo/api/types/__init__.py @@ -46,6 +46,9 @@ from .api import PlacementPreviewResponse as PlacementPreviewResponse from .api import PowerUsage as PowerUsage from .api import PromptTokensDetails as PromptTokensDetails +from .api import RescanSourcesResponse as RescanSourcesResponse +from .api import SourceInfo as SourceInfo +from .api import SourcesResponse as SourcesResponse from .api import StartDownloadParams as StartDownloadParams from .api import StartDownloadResponse as StartDownloadResponse from .api import StreamingChoiceResponse as StreamingChoiceResponse diff --git a/src/exo/api/types/api.py b/src/exo/api/types/api.py index 8cfa10dd1a..043a062ded 100644 --- a/src/exo/api/types/api.py +++ b/src/exo/api/types/api.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator from exo.shared.models.model_cards import ModelCard, ModelId -from exo.shared.types.common import CommandId, NodeId +from 
exo.shared.types.common import CommandId, ModelSourceKind, NodeId from exo.shared.types.memory import Memory from exo.shared.types.text_generation import ReasoningDialect, ReasoningEffort from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta @@ -451,6 +451,20 @@ class CancelDownloadResponse(FrozenModel): command_id: CommandId +class SourceInfo(FrozenModel): + kind: ModelSourceKind + display_name: str + available: bool + + +class SourcesResponse(FrozenModel): + sources: list[SourceInfo] + + +class RescanSourcesResponse(FrozenModel): + triggered: bool + + class TraceEventResponse(FrozenModel): name: str start_us: int diff --git a/src/exo/download/download_utils.py b/src/exo/download/download_utils.py index bab70a1c5d..e81904021d 100644 --- a/src/exo/download/download_utils.py +++ b/src/exo/download/download_utils.py @@ -23,6 +23,7 @@ TypeAdapter, ) +from exo.download.fingerprint import fingerprint_directory from exo.download.huggingface_utils import ( filter_repo_objects, get_allow_patterns, @@ -121,17 +122,76 @@ class InsufficientDiskSpaceError(Exception): def resolve_existing_model( model_id: ModelId, card: ModelCard | None = None ) -> Path | None: - """Search all model directories for a complete, pre-existing model. - - Checks read-only directories first, then writable directories. - A candidate is only returned if ``is_model_directory_complete`` confirms - all weight files are present. + """Find a complete local copy of a model — by folder convention first, then by content. + + 1. Convention pass: look for ``{search_dir}/{model_id.normalize()}``. Read-only dirs + take precedence over writable dirs (matches existing semantics). + 2. Content pass: if the convention pass misses, hash the architecture-defining + fields of the canonical ``config.json`` and scan every other dir under + ``EXO_MODELS_*_DIRS`` for one with a matching fingerprint. 
Lets exo recognise + a model the user has on disk under a non-canonical folder name without forcing + a redundant re-download. """ normalized = model_id.normalize() for search_dir in (*EXO_MODELS_READ_ONLY_DIRS, *EXO_MODELS_DIRS): candidate = search_dir / normalized if candidate.is_dir() and is_model_directory_complete(candidate, card): return candidate + + target_fp = _target_fingerprint(normalized) + if target_fp is None: + return None + return _find_by_content(target_fp, normalized, card) + + +def _target_fingerprint(normalized_id: str) -> str | None: + """Fingerprint the canonical config.json for ``normalized_id`` if it exists. + + ``ModelCard.fetch_from_hf`` writes ``config.json`` into the canonical dir during + ``/models/add``, so in normal flow this returns a real fingerprint. When the user + has never added the model card, returns ``None`` and the second pass is skipped. + """ + for search_dir in (*EXO_MODELS_DIRS, *EXO_MODELS_READ_ONLY_DIRS): + fp = fingerprint_directory(search_dir / normalized_id) + if fp is not None: + return fp + return None + + +# Process-wide dedup for the "resolved by content" log line. Without this, the same +# (target_fp, found_path) gets logged at INFO every time the coordinator or worker +# checks completeness for an active download — once per file event, ~10× per second. +_logged_content_resolutions: set[tuple[str, str]] = set() + + +def _find_by_content( + target_fp: str, skip_dir_name: str, card: ModelCard | None +) -> Path | None: + """Walk ``EXO_MODELS_*_DIRS`` for a complete model dir whose fingerprint matches.""" + for search_dir in (*EXO_MODELS_READ_ONLY_DIRS, *EXO_MODELS_DIRS): + if not search_dir.exists(): + continue + try: + children = list(search_dir.iterdir()) + except OSError: + continue + for child in children: + if not child.is_dir() or child.name == skip_dir_name: + continue + # Skip exo's own metadata cache. 
+ if child.name == "caches": + continue + if fingerprint_directory(child) != target_fp: + continue + if is_model_directory_complete(child, card): + log_key = (target_fp, str(child)) + if log_key not in _logged_content_resolutions: + _logged_content_resolutions.add(log_key) + logger.info( + f"Resolved {target_fp[:12]}… to {child} via content fingerprint " + f"(folder name '{child.name}' != convention '{skip_dir_name}')" + ) + return child return None @@ -144,9 +204,42 @@ def build_model_path(model_id: ModelId) -> Path: found = resolve_existing_model(model_id) if found is not None: return found + external = _resolve_from_external_sources(str(model_id)) + if external is not None: + return external return EXO_DEFAULT_MODELS_DIR / model_id.normalize() +def _resolve_from_external_sources(external_id: str) -> Path | None: + """Ask non-exo sources whether they have a usable path for this model. + + Used when exo's own cache doesn't have the model — lets MLX inference load + safetensors models that already live in the HF cache or LM Studio's library + without forcing a re-download. GGUF entries from Ollama/llama.cpp are skipped: + they exist in the catalog but no current exo engine can load them. + """ + # Lazy import: ``exo.sources`` pulls in optional cache-walking modules; keeping + # this out of module load-time avoids paying that cost in unrelated code paths. + from exo.sources import default_sources + + for source in default_sources(): + if source.kind == "exo": + continue # already covered by resolve_existing_model above + try: + if not source.is_available(): + continue + resolved = source.resolve_path(external_id) + except Exception as exc: # noqa: BLE001 — best-effort across third-party caches + logger.debug(f"source {source.kind} resolve_path raised: {exc!r}") + continue + if resolved is None: + continue + # Only return the path if the engine can actually load it. 
+ if resolved.is_dir() and (resolved / "config.json").exists(): + return resolved + return None + + def select_download_dir(required_bytes: int) -> Path: """Pick the first writable model directory with enough free space. @@ -191,12 +284,17 @@ async def select_download_dir_for_shard( async def resolve_model_dir(model_id: ModelId) -> Path: - """Return the directory for a model's files, creating it if needed. - - Checks all model directories for an existing complete model first, - then falls back to the default models directory. + """Return the directory exo should *write* a model's files into. + + Prefers an existing exo-managed copy (canonical or mispathed-but-content-matched) + so partial downloads resume in place; otherwise falls back to the canonical + ``EXO_DEFAULT_MODELS_DIR/{normalized_id}``. Deliberately does NOT consult external + sources (HF cache, LM Studio, …) — those are read-only as far as exo is concerned; + redirecting writes there would confuse the owning tool and would scatter exo's + metadata across third-party caches. """ - target = await asyncio.to_thread(build_model_path, model_id) + found = await asyncio.to_thread(resolve_existing_model, model_id) + target = found if found is not None else EXO_DEFAULT_MODELS_DIR / model_id.normalize() await aios.makedirs(target, exist_ok=True) return target @@ -249,11 +347,20 @@ def _scan_model_directory( ) -> list[FileListEntry] | None: """Scan a local model directory and build a file list. - Requires at least one ``*.safetensors.index.json``. Every weight file - referenced by the index that is missing on disk gets ``size=None``. + Two layouts are recognised: + + * **Sharded** — at least one ``*.safetensors.index.json`` exists. Every weight file + named in the index but missing on disk is added with ``size=None`` so completeness + checks can detect partial downloads. + * **Single-file** — no index, but a non-partial ``model.safetensors`` is on disk + (e.g. ``Qwen/Qwen3-0.6B``). 
We trust local files: a present, fully-downloaded + ``model.safetensors`` means the dir is complete. + + Returns ``None`` only when neither layout is present (truly empty / not a model dir). """ index_files = list(model_dir.glob("**/*.safetensors.index.json")) - if not index_files: + has_single_file_weights = (model_dir / "model.safetensors").is_file() + if not index_files and not has_single_file_weights: return None entries_by_path: dict[str, FileListEntry] = {} @@ -279,6 +386,12 @@ def _scan_model_directory( size=item.stat().st_size, ) + # Single-file layout: nothing more to enumerate. The presence of a non-partial + # ``model.safetensors`` is the completeness signal — partial downloads live under + # ``model.safetensors.partial`` until the rename at end of download. + if not index_files: + return list(entries_by_path.values()) + # Add expected weight files from index that haven't been downloaded yet for index_file in index_files: try: diff --git a/src/exo/download/fingerprint.py b/src/exo/download/fingerprint.py new file mode 100644 index 0000000000..4301f98fdb --- /dev/null +++ b/src/exo/download/fingerprint.py @@ -0,0 +1,69 @@ +"""Content-based model fingerprinting. + +A fingerprint is a stable hash over the architecture-defining subset of a model's +``config.json``. Two model directories with the same fingerprint represent the same +model variant — same architecture, same hyperparameters, same quantization — even if +they live under different folder names on disk. + +Used by :func:`exo.download.download_utils.resolve_existing_model` to find a locally +cached model regardless of how its folder was named, after the convention-based +``{models_dir}/{normalized_id}`` lookup misses. + +Pure functions, no side effects. ``config.json`` is the only input. +""" + +import hashlib +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Final + +# Architecture-defining fields. 
Quantization is included so a 4-bit MLX quant doesn't +# match its bf16 sibling (they share most other fields). Cosmetic config keys — +# ``_name_or_path``, ``transformers_version``, ``torch_dtype`` overrides at save time, +# ``architectures`` capitalisation drift — are deliberately excluded so two saves of +# the same model from different transformers versions still match. +_FINGERPRINT_KEYS: Final = ( + "architectures", + "model_type", + "hidden_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "intermediate_size", + "vocab_size", + "max_position_embeddings", + "head_dim", + "quantization", +) + + +def fingerprint_config(config: Mapping[str, object]) -> str: + """Stable SHA-256 over a canonical projection of architecture-defining fields. + + Cosmetic differences (key order, whitespace, unrelated keys) do not affect the + result. Two configs that share every value listed in ``_FINGERPRINT_KEYS`` produce + the same fingerprint. + """ + selected = {k: config[k] for k in _FINGERPRINT_KEYS if k in config} + canonical = json.dumps(selected, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + + +def fingerprint_directory(model_dir: Path) -> str | None: + """Return the fingerprint of ``{model_dir}/config.json``, or ``None`` if absent or invalid. + + Tolerates malformed JSON and missing files by returning ``None`` — callers should + treat that as "this isn't a recognisable model directory" rather than as an error. 
+ """ + config_path = model_dir / "config.json" + if not config_path.is_file(): + return None + try: + with config_path.open("r", encoding="utf-8") as f: + raw: object = json.load(f) # pyright: ignore[reportAny] + except (OSError, ValueError): + return None + if not isinstance(raw, dict): + return None + return fingerprint_config(raw) # pyright: ignore[reportUnknownArgumentType] diff --git a/src/exo/download/tests/test_content_resolution.py b/src/exo/download/tests/test_content_resolution.py new file mode 100644 index 0000000000..7661965760 --- /dev/null +++ b/src/exo/download/tests/test_content_resolution.py @@ -0,0 +1,150 @@ +"""Content-based resolution: find a model regardless of folder name.""" + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from exo.download.download_utils import resolve_existing_model +from exo.shared.types.common import ModelId + +MODEL_ID = ModelId("test-org/test-model") +NORMALIZED = MODEL_ID.normalize() + +_BASE_CONFIG: dict[str, object] = { + "model_type": "qwen3", + "architectures": ["Qwen3ForCausalLM"], + "hidden_size": 1024, + "num_hidden_layers": 28, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 3072, + "vocab_size": 151936, + "max_position_embeddings": 40960, +} + + +def _write_config(model_dir: Path, config: dict[str, object] | None = None) -> None: + model_dir.mkdir(parents=True, exist_ok=True) + (model_dir / "config.json").write_text(json.dumps(config or _BASE_CONFIG)) + + +def _write_complete_model( + model_dir: Path, config: dict[str, object] | None = None +) -> None: + """Config + safetensors + index — passes ``is_model_directory_complete``.""" + _write_config(model_dir, config) + weight_map = {"layer.weight": "model.safetensors"} + (model_dir / "model.safetensors.index.json").write_text( + json.dumps({"metadata": {"total_size": 1024}, "weight_map": weight_map}) + ) + (model_dir / "model.safetensors").write_bytes(b"weights" * 100) + + 
def _resolve_with_dirs(writable_dir: Path, *read_only_dirs: Path) -> Path | None:
    """Call ``resolve_existing_model(MODEL_ID)`` with the model search roots patched.

    For the duration of the call ``EXO_MODELS_DIRS`` is ``(writable_dir,)`` and
    ``EXO_MODELS_READ_ONLY_DIRS`` is the tuple of ``read_only_dirs`` (empty when none
    are given), so every test only sees its own temporary directories.
    """
    with (
        patch("exo.download.download_utils.EXO_MODELS_READ_ONLY_DIRS", read_only_dirs),
        patch("exo.download.download_utils.EXO_MODELS_DIRS", (writable_dir,)),
    ):
        return resolve_existing_model(MODEL_ID)


@pytest.fixture
def writable(tmp_path: Path) -> Path:
    """A fresh directory standing in for the writable models root."""
    out = tmp_path / "writable"
    out.mkdir()
    return out


@pytest.fixture
def readonly(tmp_path: Path) -> Path:
    """A fresh directory standing in for a read-only models root."""
    out = tmp_path / "readonly"
    out.mkdir()
    return out


def test_finds_model_under_wrong_folder_name(writable: Path) -> None:
    """The user's reported scenario: model lives under an unconventional folder name."""
    _write_complete_model(writable / "wrong-typo")
    # Canonical dir has the config but no weights — simulates ``/models/add`` having
    # written the card but no download yet.
    _write_config(writable / NORMALIZED)
    assert _resolve_with_dirs(writable) == writable / "wrong-typo"


def test_convention_path_takes_precedence_over_content(writable: Path) -> None:
    """When both folders are complete, the canonical one wins (fast path)."""
    _write_complete_model(writable / NORMALIZED)
    _write_complete_model(writable / "wrong-typo")
    assert _resolve_with_dirs(writable) == writable / NORMALIZED


def test_partial_mispath_does_not_match(writable: Path) -> None:
    """Mispath has matching fingerprint but only a ``.partial`` weight — must not be selected."""
    target = writable / "wrong-typo"
    target.mkdir(parents=True)
    (target / "config.json").write_text(json.dumps(_BASE_CONFIG))
    (target / "model.safetensors.partial").write_bytes(b"half")
    _write_config(writable / NORMALIZED)
    assert _resolve_with_dirs(writable) is None


def test_returns_none_when_canonical_config_missing(writable: Path) -> None:
    """If the canonical config.json hasn't been fetched yet, we have nothing to match
    against, so the content pass cannot resolve."""
    _write_complete_model(writable / "wrong-typo")
    assert _resolve_with_dirs(writable) is None


def test_different_architecture_not_matched(writable: Path) -> None:
    """A complete model whose ``hidden_size`` differs from the canonical config is ignored."""
    _write_complete_model(
        writable / "wrong-typo", {**_BASE_CONFIG, "hidden_size": 2048}
    )
    _write_config(writable / NORMALIZED)
    assert _resolve_with_dirs(writable) is None


def test_read_only_dir_wins_in_content_pass(writable: Path, readonly: Path) -> None:
    """Content pass preserves the existing read-only-first precedence."""
    _write_complete_model(readonly / "ro-mispath")
    _write_complete_model(writable / "rw-mispath")
    _write_config(writable / NORMALIZED)
    assert _resolve_with_dirs(writable, readonly) == readonly / "ro-mispath"


def test_quantized_variant_not_matched_by_unquantized_canonical(writable: Path) -> None:
    """A 4-bit quant on disk must not be resolved when the user asked for the bf16 model."""
    _write_complete_model(
        writable / "wrong-typo",
        {**_BASE_CONFIG, "quantization": {"bits": 4, "group_size": 64}},
    )
    _write_config(writable / NORMALIZED)  # unquantized canonical config
    assert _resolve_with_dirs(writable) is None
fingerprint_config(enriched) + + +def test_fingerprint_directory_returns_none_for_empty_dir(tmp_path: Path) -> None: + assert fingerprint_directory(tmp_path) is None + + +def test_fingerprint_directory_returns_none_for_invalid_json(tmp_path: Path) -> None: + (tmp_path / "config.json").write_text("not actually json") + assert fingerprint_directory(tmp_path) is None + + +def test_fingerprint_directory_matches_fingerprint_config(tmp_path: Path) -> None: + """End-to-end: the directory variant must match the in-memory variant for the same config.""" + config = { + "model_type": "qwen3", + "hidden_size": 1024, + "num_hidden_layers": 28, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "intermediate_size": 3072, + "vocab_size": 151936, + "quantization": {"bits": 4, "group_size": 64}, + } + (tmp_path / "config.json").write_text(json.dumps(config)) + assert fingerprint_directory(tmp_path) == fingerprint_config(config) + + +def test_fingerprint_is_deterministic_across_calls() -> None: + """Whatever opaque hash representation we use must be stable run-to-run.""" + config = {"model_type": "x", "hidden_size": 64, "num_hidden_layers": 1} + assert fingerprint_config(config) == fingerprint_config(config) diff --git a/src/exo/download/tests/test_single_file_model.py b/src/exo/download/tests/test_single_file_model.py new file mode 100644 index 0000000000..310b8cfbe0 --- /dev/null +++ b/src/exo/download/tests/test_single_file_model.py @@ -0,0 +1,169 @@ +"""Single-file safetensors model support. + +Models that fit in one ``model.safetensors`` (e.g. ``Qwen/Qwen3-0.6B``, smaller MLX +quants) ship without ``model.safetensors.index.json``. exo's downloader and +"is this complete?" check used to require the index unconditionally — these tests +lock in the relaxation. 
+""" + +from collections.abc import Awaitable, Callable +from pathlib import Path +from unittest.mock import patch + +import pytest + +from exo.download.download_utils import ( + _scan_model_directory, # pyright: ignore[reportPrivateUsage] — needed for direct unit tests + is_model_directory_complete, +) +from exo.shared.models.model_cards import fetch_safetensors_size +from exo.shared.types.common import ModelId +from exo.shared.types.memory import Memory + + +def _make_single_file_model(model_dir: Path) -> None: + model_dir.mkdir(parents=True, exist_ok=True) + (model_dir / "config.json").write_text('{"model_type": "qwen3"}') + (model_dir / "model.safetensors").write_bytes(b"weights" * 100) + (model_dir / "tokenizer.json").write_text("{}") + + +def test_scan_returns_entries_for_single_file_layout(tmp_path: Path) -> None: + _make_single_file_model(tmp_path) + entries = _scan_model_directory(tmp_path, recursive=True) + assert entries is not None + names = {e.path for e in entries} + assert "model.safetensors" in names + assert "config.json" in names + # All on-disk entries should have a real size — completeness check needs this. 
+ assert all(e.size is not None for e in entries) + + +def test_is_complete_true_for_single_file(tmp_path: Path) -> None: + _make_single_file_model(tmp_path) + assert is_model_directory_complete(tmp_path) is True + + +def test_is_complete_false_for_partial_only(tmp_path: Path) -> None: + """A partial download lives under ``model.safetensors.partial`` until rename — must not be marked complete.""" + tmp_path.mkdir(exist_ok=True) + (tmp_path / "config.json").write_text('{"model_type": "qwen3"}') + (tmp_path / "model.safetensors.partial").write_bytes(b"half") + assert is_model_directory_complete(tmp_path) is False + + +def test_scan_returns_none_for_truly_empty_dir(tmp_path: Path) -> None: + """Regression: an empty dir (no index, no weights) must still return None so callers + know there's nothing to load.""" + assert _scan_model_directory(tmp_path, recursive=True) is None + + +@pytest.mark.asyncio +async def test_fetch_safetensors_size_falls_back_on_missing_index( + tmp_path: Path, +) -> None: + """When the remote has no index, ``fetch_safetensors_size`` consults + ``huggingface_hub.model_info`` for the canonical size.""" + + async def fake_download_raises_not_found( + model_id: ModelId, + revision: str, + path: str, + target_dir: Path, + on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None, + ) -> Path: + raise FileNotFoundError(f"File not found: fake://{model_id}/{path}") + + async def fake_resolve_model_dir(model_id: ModelId) -> Path: + return tmp_path + + class _FakeSafetensorsInfo: + total: int = 1234 + + class _FakeModelInfo: + safetensors: _FakeSafetensorsInfo = _FakeSafetensorsInfo() + + def fake_model_info(_id: object, *_args: object, **_kw: object) -> _FakeModelInfo: + return _FakeModelInfo() + + with ( + patch( + "exo.download.download_utils.download_file_with_retry", + fake_download_raises_not_found, + ), + patch( + "exo.download.download_utils.resolve_model_dir", + fake_resolve_model_dir, + ), + 
patch("exo.shared.models.model_cards.model_info", fake_model_info), + ): + size = await fetch_safetensors_size(ModelId("Qwen/Qwen3-0.6B")) + + assert size == Memory.from_bytes(1234) + + +@pytest.mark.asyncio +async def test_fetch_safetensors_size_raises_when_no_safetensors_info( + tmp_path: Path, +) -> None: + """If the index is absent AND HF reports no safetensors metadata, surface a real error.""" + + async def fake_download_raises_not_found( + model_id: ModelId, + revision: str, + path: str, + target_dir: Path, + on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None, + ) -> Path: + raise FileNotFoundError("missing") + + async def fake_resolve_model_dir(model_id: ModelId) -> Path: + return tmp_path + + class _FakeNoSafetensors: + safetensors = None + + def fake_model_info( + _id: object, *_args: object, **_kw: object + ) -> _FakeNoSafetensors: + return _FakeNoSafetensors() + + with ( + patch( + "exo.download.download_utils.download_file_with_retry", + fake_download_raises_not_found, + ), + patch( + "exo.download.download_utils.resolve_model_dir", + fake_resolve_model_dir, + ), + patch("exo.shared.models.model_cards.model_info", fake_model_info), + pytest.raises(ValueError, match="No safetensors info"), + ): + await fetch_safetensors_size(ModelId("nonexistent/repo")) + + +# Sanity: existing multi-file path keeps working. 
def test_scan_still_handles_index_layout(tmp_path: Path) -> None:
    """Sharded layout: two weight shards plus an index must be scanned and judged complete."""
    import json

    shard_names = [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ]
    (tmp_path / "config.json").write_text('{"model_type": "x"}')
    for shard_name, payload in zip(shard_names, (b"a", b"b")):
        (tmp_path / shard_name).write_bytes(payload)

    index_document = {
        "metadata": {"total_size": 2},
        "weight_map": {
            "layer.0.weight": shard_names[0],
            "layer.1.weight": shard_names[1],
        },
    }
    (tmp_path / "model.safetensors.index.json").write_text(json.dumps(index_document))

    scanned = _scan_model_directory(tmp_path, recursive=True)
    assert scanned is not None
    scanned_paths = {entry.path for entry in scanned}
    for shard_name in shard_names:
        assert shard_name in scanned_paths
    assert is_model_directory_complete(tmp_path) is True
def apply_local_models_scanned(event: LocalModelsScanned, state: State) -> State:
    """Swap in the freshly-scanned entries for one (node, source) pair.

    Entries already recorded for ``event.node_id`` under a *different* source are kept
    (in their existing order) and the new ``event.entries`` are appended after them.
    Other nodes' lists are carried over untouched. The input ``state`` is not mutated;
    a copy with the updated ``local_models`` mapping is returned.
    """
    previous = state.local_models.get(event.node_id, ())
    kept_from_other_sources = [
        item for item in previous if item.source != event.source
    ]
    merged = dict(state.local_models)
    merged[event.node_id] = [*kept_from_other_sources, *event.entries]
    return state.model_copy(update={"local_models": merged})
TaskCreated, state: State) -> State: new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task} return state.model_copy(update={"tasks": new_tasks}) @@ -278,6 +297,9 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State: downloads = { key: value for key, value in state.downloads.items() if key != event.node_id } + local_models = { + key: value for key, value in state.local_models.items() if key != event.node_id + } # Clean up all granular node mappings node_memory = { key: value for key, value in state.node_memory.items() if key != event.node_id @@ -317,6 +339,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State: return state.model_copy( update={ "downloads": downloads, + "local_models": local_models, "topology": topology, "last_seen": last_seen, "node_memory": node_memory, diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 0d9acc1d02..e7640678cd 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -349,7 +349,12 @@ async def fetch_config_data(model_id: ModelId) -> ConfigData: async def fetch_safetensors_size(model_id: ModelId) -> Memory: - """Gets model size from safetensors index or falls back to HF API.""" + """Gets model size from the safetensors index or, for single-file models, the HF API. + + Single-file safetensors models (e.g. ``Qwen/Qwen3-0.6B``) ship just ``model.safetensors`` + with no sharding index. We treat the missing index as a signal to ask HF directly via + ``model_info().safetensors.total`` rather than as an error. 
+ """ from exo.download.download_utils import ( download_file_with_retry, resolve_model_dir, @@ -357,15 +362,23 @@ async def fetch_safetensors_size(model_id: ModelId) -> Memory: from exo.shared.types.worker.downloads import ModelSafetensorsIndex target_dir = await resolve_model_dir(model_id) - index_path = await download_file_with_retry( - model_id, - "main", - "model.safetensors.index.json", - target_dir, - lambda curr_bytes, total_bytes, is_renamed: logger.debug( - f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})" - ), - ) + try: + index_path = await download_file_with_retry( + model_id, + "main", + "model.safetensors.index.json", + target_dir, + lambda curr_bytes, total_bytes, is_renamed: logger.debug( + f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})" + ), + ) + except FileNotFoundError: + # Single-file model — no index on the remote. + info = model_info(model_id) + if info.safetensors is None: + raise ValueError(f"No safetensors info found for {model_id}") from None + return Memory.from_bytes(info.safetensors.total) + async with aiofiles.open(index_path, "r") as f: index_data = ModelSafetensorsIndex.model_validate_json(await f.read()) diff --git a/src/exo/shared/types/common.py b/src/exo/shared/types/common.py index 097803d307..7a9c16e59e 100644 --- a/src/exo/shared/types/common.py +++ b/src/exo/shared/types/common.py @@ -1,4 +1,4 @@ -from typing import Any, Self +from typing import Any, Literal, Self from uuid import uuid4 from pydantic import GetCoreSchemaHandler, field_validator @@ -6,6 +6,25 @@ from exo.utils.pydantic_ext import FrozenModel +ModelSourceKind = Literal[ + "exo", + "huggingface", + "lmstudio", + "ollama", + "llamacpp", +] +"""Identifier for where a locally-available model came from. + +- ``exo``: model lives in one of ``EXO_MODELS_DIRS`` and is managed by exo's own downloader. 
+- ``huggingface``: standard HF cache (``~/.cache/huggingface/hub/``), shared with mlx-lm and modern llama.cpp ``-hf``. +- ``lmstudio``: LM Studio's local library (``~/.lmstudio/models/{publisher}/{model}/``). +- ``ollama``: Ollama's content-addressed store (``~/.ollama/models/manifests/`` + ``blobs/``). +- ``llamacpp``: llama.cpp's standalone GGUF cache. +""" + +ModelFileFormat = Literal["safetensors", "mlx", "gguf"] +"""On-disk weight format for a discovered model.""" + class Id(str): def __new__(cls, value: str | None = None) -> Self: diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index 01aa0ce5dc..138985e742 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -6,11 +6,20 @@ from exo.shared.models.model_cards import ModelCard from exo.shared.topology import Connection from exo.shared.types.chunks import Chunk, InputImageChunk -from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId +from exo.shared.types.common import ( + CommandId, + Id, + ModelId, + ModelSourceKind, + NodeId, + SessionId, + SystemId, +) from exo.shared.types.instance_link import InstanceLink, InstanceLinkId from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.instances import Instance, InstanceId +from exo.shared.types.worker.local_models import LocalModelEntry from exo.shared.types.worker.runners import RunnerId, RunnerStatus from exo.utils.info_gatherer.info_gatherer import GatheredInfo from exo.utils.pydantic_ext import FrozenModel, TaggedModel @@ -90,6 +99,19 @@ class NodeDownloadProgress(BaseEvent): download_progress: DownloadProgress +class LocalModelsScanned(BaseEvent): + """Replaces the catalog of locally-discovered models for a single (node, source) pair. 
+ + The scanner emits one of these per source on every scan; the apply handler atomically + swaps the existing entries for that (node, source). Sources that fail to scan emit + nothing (last-known-good catalog wins) — see ``source_scanner.py``. + """ + + node_id: NodeId + source: ModelSourceKind + entries: list[LocalModelEntry] + + class ChunkGenerated(BaseEvent): command_id: CommandId chunk: Chunk @@ -159,6 +181,7 @@ class InstanceLinkDeleted(BaseEvent): | NodeTimedOut | NodeGatheredInfo | NodeDownloadProgress + | LocalModelsScanned | ChunkGenerated | InputChunkReceived | TopologyEdgeCreated diff --git a/src/exo/shared/types/state.py b/src/exo/shared/types/state.py index 6c976984c8..eb1708e497 100644 --- a/src/exo/shared/types/state.py +++ b/src/exo/shared/types/state.py @@ -21,6 +21,7 @@ from exo.shared.types.tasks import Task, TaskId from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.instances import Instance, InstanceId +from exo.shared.types.worker.local_models import LocalModelEntry from exo.shared.types.worker.runners import RunnerId, RunnerStatus from exo.utils.pydantic_ext import FrozenModel @@ -44,6 +45,7 @@ class State(FrozenModel): instances: Mapping[InstanceId, Instance] = {} runners: Mapping[RunnerId, RunnerStatus] = {} downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {} + local_models: Mapping[NodeId, Sequence[LocalModelEntry]] = {} tasks: Mapping[TaskId, Task] = {} last_seen: Mapping[NodeId, datetime] = {} topology: Topology = Field(default_factory=Topology) diff --git a/src/exo/shared/types/worker/local_models.py b/src/exo/shared/types/worker/local_models.py new file mode 100644 index 0000000000..a1d1d3dc31 --- /dev/null +++ b/src/exo/shared/types/worker/local_models.py @@ -0,0 +1,26 @@ +from exo.shared.types.common import ModelFileFormat, ModelId, ModelSourceKind, NodeId +from exo.shared.types.memory import Memory +from exo.utils.pydantic_ext import FrozenModel + + +class 
LocalModelEntry(FrozenModel): + """A model that exists on a worker's local disk, regardless of which tool put it there.""" + + node_id: NodeId + source: ModelSourceKind + external_id: str + """Natural identifier in the source's namespace (e.g. ``mlx-community/Llama-3.1-8B-Instruct-4bit``, + ``llama3:8b``). Not necessarily unique across sources — pair with ``source`` for a stable key.""" + + display_name: str + path: str + """Absolute path to the model directory (HF/MLX/LMStudio MLX) or to the weight file (GGUF).""" + + format: ModelFileFormat + size_bytes: Memory + + loadable_with_mlx: bool = False + """True if exo's MLX engine can load this entry as-is (directory of safetensors + config.json).""" + + matched_model_id: ModelId | None = None + """If this entry corresponds to an exo-known model card, the canonical exo ``ModelId``.""" diff --git a/src/exo/sources/__init__.py b/src/exo/sources/__init__.py new file mode 100644 index 0000000000..60b1f871b6 --- /dev/null +++ b/src/exo/sources/__init__.py @@ -0,0 +1,44 @@ +"""Pluggable model-source detection. + +Each :class:`~exo.sources.base.ModelSource` knows how to enumerate locally-installed +models in one external tool (HuggingFace cache, LM Studio, Ollama, llama.cpp) plus +exo's own writable cache. The registry below picks the set of sources to enable for +the current process; the per-worker scanner service walks them on a periodic interval +and emits :class:`~exo.shared.types.events.LocalModelsScanned` events for the UI. + +Adding a new source = write one class implementing :class:`ModelSource` and append it +to :func:`default_sources`. 
def classify_directory_format(model_dir: Path) -> ModelFileFormat | None:
    """Identify the weight format of an HF-style model directory.

    Shared by the exo_native, HuggingFace, and LMStudio scanners so all three apply
    identical heuristics.

    Returns:
        ``"mlx"`` when ``config.json`` and at least one ``*.safetensors`` file are
        present and the config has a top-level ``quantization`` key; ``"safetensors"``
        when both are present without that key; ``"gguf"`` when the directory holds a
        ``*.gguf`` file; ``None`` when ``model_dir`` is not a directory, matches none
        of these layouts, or cannot be inspected.
    """
    try:
        if not model_dir.is_dir():
            return None
        config_path = model_dir / "config.json"
        if config_path.exists() and any(model_dir.glob("*.safetensors")):
            return "mlx" if _is_mlx_config(config_path) else "safetensors"
        if any(model_dir.glob("*.gguf")):
            return "gguf"
    except OSError:
        # Callers are periodic scanners that must not raise; an unreadable
        # directory is simply "not a usable model".
        return None
    return None


def _is_mlx_config(config_path: Path) -> bool:
    """Return ``True`` when the JSON object at ``config_path`` has a ``quantization`` key.

    Unreadable or malformed files yield ``False``.
    """
    try:
        # JSON is UTF-8 by specification; relying on the locale default would raise
        # UnicodeDecodeError (a ValueError) on some platforms and silently turn an
        # MLX checkpoint into "safetensors".
        with config_path.open(encoding="utf-8") as f:
            config_raw: object = json.load(f)  # pyright: ignore[reportAny]
    except (OSError, ValueError):
        return False
    return isinstance(config_raw, dict) and "quantization" in config_raw


def directory_size_bytes(path: Path) -> int:
    """Best-effort recursive byte count; tolerates missing files mid-scan.

    Files that disappear or cannot be stat'ed are skipped. If the walk itself aborts,
    the bytes counted so far are returned rather than discarded.
    """
    total = 0
    try:
        for child in path.rglob("*"):
            try:
                if child.is_file():
                    total += child.stat().st_size
            except OSError:
                continue
    except OSError:
        # Walk aborted part-way (e.g. a directory vanished); keep the partial total.
        pass
    return total
def _denormalize(dir_name: str) -> str:
    """Reverse ``ModelId.normalize`` (``foo--bar`` → ``foo/bar``).

    Only the first ``--`` is replaced, so ``a--b--c`` becomes ``a/b--c``.
    """
    return dir_name.replace("--", "/", 1)


@final
class ExoNativeSource:
    """Scanner over ``EXO_MODELS_DIRS`` followed by ``EXO_MODELS_READ_ONLY_DIRS``."""

    kind: ModelSourceKind = "exo"
    display_name: str = "exo"

    def is_available(self) -> bool:
        """Return ``True`` if at least one configured models root exists on disk."""
        return any(d.exists() for d in (*EXO_MODELS_DIRS, *EXO_MODELS_READ_ONLY_DIRS))

    def scan(self, node_id: NodeId) -> Iterable[LocalModelEntry]:
        """List one entry per usable ``{org}--{name}`` directory across all roots.

        When the same directory name exists in several roots, the first root holding a
        *usable* copy wins. A directory that does not classify as a model (for example
        a config-only placeholder) no longer hides a complete copy in a later root.
        A root that cannot be listed is logged and skipped.
        """
        entries: list[LocalModelEntry] = []
        seen: set[str] = set()
        for root in (*EXO_MODELS_DIRS, *EXO_MODELS_READ_ONLY_DIRS):
            if not root.exists():
                continue
            try:
                for child in root.iterdir():
                    if "--" not in child.name or child.name in seen:
                        continue
                    if not child.is_dir():
                        continue
                    fmt = classify_directory_format(child)
                    if fmt is None:
                        continue
                    # Claim the name only once a usable copy has been found.
                    seen.add(child.name)
                    external_id = _denormalize(child.name)
                    entries.append(
                        LocalModelEntry(
                            node_id=node_id,
                            source=self.kind,
                            external_id=external_id,
                            display_name=external_id,
                            path=str(child),
                            format=fmt,
                            size_bytes=Memory(in_bytes=directory_size_bytes(child)),
                            loadable_with_mlx=fmt in ("mlx", "safetensors"),
                            matched_model_id=ModelId(external_id),
                        )
                    )
            except OSError as exc:
                logger.warning(f"Failed to scan exo models dir {root}: {exc!r}")
        return entries

    def resolve_path(self, external_id: str) -> Path | None:
        """Return the directory for ``external_id``, or ``None`` if no root has one.

        The first root whose directory classifies as a model is preferred, matching
        what ``scan`` reports. If no root holds a usable copy, the first existing
        directory is returned, as before.
        """
        normalized = ModelId(external_id).normalize()
        first_existing: Path | None = None
        for root in (*EXO_MODELS_DIRS, *EXO_MODELS_READ_ONLY_DIRS):
            candidate = root / normalized
            if not candidate.is_dir():
                continue
            if classify_directory_format(candidate) is not None:
                return candidate
            if first_existing is None:
                first_existing = candidate
        return first_existing
def _hf_cache_dir() -> Path:
    """Return the hub cache directory configured in ``huggingface_hub.constants``."""
    from huggingface_hub import constants

    return Path(constants.HF_HUB_CACHE)


def _cached_model_snapshots() -> dict[str, tuple[Path, int]]:
    """Map each cached *model* repo id to ``(snapshot path, size on disk)``.

    For every repo of type ``"model"`` with at least one revision, the revision with
    the largest ``size_on_disk`` is chosen — usually "main"; only one in practice for
    users. Dataset/space repos are skipped. A missing cache yields an empty mapping;
    any other scan failure is logged and also yields an empty mapping, so callers
    never see an exception.
    """
    try:
        info = scan_cache_dir(_hf_cache_dir())
    except CacheNotFound:
        return {}
    except Exception as exc:
        logger.warning(f"HF cache scan failed: {exc!r}")
        return {}

    snapshots: dict[str, tuple[Path, int]] = {}
    # Sorted so the resulting entry order is stable between scans.
    for repo in sorted(info.repos, key=lambda r: r.repo_id):
        if repo.repo_type != "model" or not repo.revisions:
            continue
        largest = max(repo.revisions, key=lambda rev: rev.size_on_disk)
        snapshots.setdefault(
            repo.repo_id, (Path(largest.snapshot_path), int(largest.size_on_disk))
        )
    return snapshots


@final
class HuggingFaceSource:
    """Scanner over the HuggingFace hub cache."""

    kind: ModelSourceKind = "huggingface"
    display_name: str = "HuggingFace"

    def is_available(self) -> bool:
        """Return ``True`` if the hub cache directory exists."""
        return _hf_cache_dir().exists()

    def scan(self, node_id: NodeId) -> Iterable[LocalModelEntry]:
        """List every cached model repo whose chosen snapshot classifies as a model."""
        if not self.is_available():
            return ()

        entries: list[LocalModelEntry] = []
        for repo_id, (snapshot_path, size_on_disk) in _cached_model_snapshots().items():
            fmt = classify_directory_format(snapshot_path)
            # Fallback: mlx-community org override even when config.json lacks quantization.
            if fmt == "safetensors" and repo_id.startswith("mlx-community/"):
                fmt = "mlx"
            if fmt is None:
                continue
            entries.append(
                LocalModelEntry(
                    node_id=node_id,
                    source=self.kind,
                    external_id=repo_id,
                    display_name=repo_id,
                    path=str(snapshot_path),
                    format=fmt,
                    size_bytes=Memory(in_bytes=size_on_disk),
                    loadable_with_mlx=fmt in ("mlx", "safetensors"),
                )
            )
        return entries

    def resolve_path(self, external_id: str) -> Path | None:
        """Return the snapshot directory for model repo ``external_id``, else ``None``.

        Uses the same repo filter and revision choice as ``scan``, so a dataset repo
        that shares the id is never returned, and a failing cache scan yields ``None``
        instead of raising.
        """
        match = _cached_model_snapshots().get(external_id)
        return match[0] if match is not None else None
def _llamacpp_root() -> Path:
    """Return the directory searched for standalone llama.cpp GGUF files.

    ``EXO_LLAMACPP_DIR`` takes precedence, then ``LLAMA_CACHE``; otherwise a per-OS
    default is used (``~/Library/Caches/llama.cpp`` on macOS,
    ``%LOCALAPPDATA%/llama.cpp/cache`` on Windows, ``~/.cache/llama.cpp`` elsewhere).
    """
    override = os.environ.get("EXO_LLAMACPP_DIR") or os.environ.get("LLAMA_CACHE")
    if override:
        return Path(override).expanduser()
    system = platform.system()
    if system == "Darwin":
        return Path.home() / "Library" / "Caches" / "llama.cpp"
    if system == "Windows":
        local = os.environ.get("LOCALAPPDATA")
        base = Path(local) if local else Path.home() / "AppData" / "Local"
        return base / "llama.cpp" / "cache"
    return Path.home() / ".cache" / "llama.cpp"


def _gguf_files(root: Path) -> list[Path]:
    """Return every regular ``*.gguf`` file under ``root``, sorted by path.

    Sorting makes ``scan`` output and ``resolve_path`` results independent of
    filesystem enumeration order. A failing walk is logged and yields ``[]``.
    """
    try:
        return sorted(p for p in root.rglob("*.gguf") if p.is_file())
    except OSError as exc:
        logger.warning(f"Failed to walk llama.cpp cache {root}: {exc!r}")
        return []


@final
class LlamaCppSource:
    """Scanner over llama.cpp's standalone GGUF cache directory."""

    kind: ModelSourceKind = "llamacpp"
    display_name: str = "llama.cpp"

    def is_available(self) -> bool:
        """Return ``True`` if the llama.cpp cache root exists."""
        return _llamacpp_root().exists()

    def scan(self, node_id: NodeId) -> Iterable[LocalModelEntry]:
        """List one entry per ``*.gguf`` file; the file stem is the external id."""
        root = _llamacpp_root()
        if not root.exists():
            return ()
        entries: list[LocalModelEntry] = []
        for gguf in _gguf_files(root):
            try:
                size = gguf.stat().st_size
            except OSError:
                # File vanished between the walk and the stat; skip it.
                continue
            external_id = gguf.stem
            entries.append(
                LocalModelEntry(
                    node_id=node_id,
                    source=self.kind,
                    external_id=external_id,
                    display_name=external_id,
                    path=str(gguf),
                    format="gguf",
                    size_bytes=Memory(in_bytes=size),
                    loadable_with_mlx=False,
                )
            )
        return entries

    def resolve_path(self, external_id: str) -> Path | None:
        """Return the GGUF file whose stem equals ``external_id``, or ``None``.

        When several files share the stem, the one with the smallest path wins. A
        failing directory walk yields ``None`` rather than raising.
        """
        root = _llamacpp_root()
        if not root.exists():
            return None
        return next(
            (gguf for gguf in _gguf_files(root) if gguf.stem == external_id), None
        )
def _lmstudio_root() -> Path:
    """Return the LM Studio models directory.

    ``EXO_LMSTUDIO_DIR`` overrides the default ``~/.lmstudio/models``. The path
    is returned whether or not it exists.
    """
    override = os.environ.get("EXO_LMSTUDIO_DIR")
    if override:
        return Path(override).expanduser()
    return Path.home() / ".lmstudio" / "models"


@final
class LMStudioSource:
    """Model source over LM Studio's ``{publisher}/{model}/`` directory layout."""

    kind: ModelSourceKind = "lmstudio"
    display_name: str = "LM Studio"

    def is_available(self) -> bool:
        """Return ``True`` when the LM Studio models directory exists."""
        return _lmstudio_root().exists()

    def scan(self, node_id: NodeId) -> Iterable[LocalModelEntry]:
        """Return one entry per ``{publisher}/{model}`` directory with a known format.

        Directories for which ``classify_directory_format`` returns ``None`` are
        skipped. An unreadable root yields an empty tuple; an unreadable
        publisher directory is skipped.
        """
        root = _lmstudio_root()
        if not root.exists():
            return ()
        entries: list[LocalModelEntry] = []
        try:
            publishers = [p for p in root.iterdir() if p.is_dir()]
        except OSError as exc:
            logger.warning(f"Failed to read LM Studio root {root}: {exc!r}")
            return ()

        for publisher_dir in publishers:
            try:
                model_dirs = [m for m in publisher_dir.iterdir() if m.is_dir()]
            except OSError:
                continue
            for model_dir in model_dirs:
                fmt = classify_directory_format(model_dir)
                if fmt is None:
                    continue
                external_id = f"{publisher_dir.name}/{model_dir.name}"
                entries.append(
                    LocalModelEntry(
                        node_id=node_id,
                        source=self.kind,
                        external_id=external_id,
                        display_name=external_id,
                        path=str(model_dir),
                        format=fmt,
                        size_bytes=Memory(in_bytes=directory_size_bytes(model_dir)),
                        loadable_with_mlx=fmt in ("mlx", "safetensors"),
                    )
                )
        return entries

    def resolve_path(self, external_id: str) -> Path | None:
        """Map a ``"{publisher}/{model}"`` id back to its model directory.

        Only ids of the shape produced by ``scan`` are accepted: exactly two
        plain directory names. Empty segments, ``.``/``..``, nested paths and
        absolute segments return ``None`` so the result can never be the
        publisher directory or a location outside the LM Studio root.
        """
        publisher, separator, model_name = external_id.partition("/")
        if not separator:
            return None
        for segment in (publisher, model_name):
            # Path(segment).name differs from the segment whenever it contains a
            # separator, a drive or a root, i.e. it is not a single path component.
            if segment in ("", ".", "..") or Path(segment).name != segment:
                return None
        candidate = _lmstudio_root() / publisher / model_name
        return candidate if candidate.is_dir() else None
_MODEL_LAYER_MEDIA_TYPE = "application/vnd.ollama.image.model"
# Host and namespace that are left out of the short model name.
_DEFAULT_REGISTRY_HOST = "registry.ollama.ai"
_DEFAULT_NAMESPACE = "library"


def _ollama_root() -> Path:
    """Return the Ollama models directory.

    ``EXO_OLLAMA_DIR`` takes precedence over ``OLLAMA_MODELS``; the default is
    ``~/.ollama/models``. The path is returned whether or not it exists.
    """
    override = os.environ.get("EXO_OLLAMA_DIR") or os.environ.get("OLLAMA_MODELS")
    if override:
        return Path(override).expanduser()
    return Path.home() / ".ollama" / "models"


def _digest_to_path(root: Path, digest: str) -> Path | None:
    """Map a ``sha256:<hex>`` digest to ``<root>/blobs/sha256-<hex>``.

    Returns ``None`` when the digest has another prefix or the blob file is
    not present.
    """
    if not digest.startswith("sha256:"):
        return None
    blob = root / "blobs" / f"sha256-{digest.removeprefix('sha256:')}"
    return blob if blob.is_file() else None


def _parse_manifest(manifest_path: Path) -> tuple[str, int] | None:
    """Return ``(blob_digest, size_bytes)`` of the first model layer, or ``None``.

    ``None`` is returned when the file cannot be read, is not valid UTF-8 JSON,
    is not a JSON object with a ``layers`` list, or has no layer whose media
    type is the Ollama model type with a string digest and an integer size.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with manifest_path.open(encoding="utf-8") as f:
            manifest_raw: object = json.load(f)  # pyright: ignore[reportAny]
    except (OSError, ValueError):
        return None
    if not isinstance(manifest_raw, dict):
        return None
    manifest = cast(dict[str, Any], manifest_raw)
    layers_obj: object = manifest.get("layers")
    if not isinstance(layers_obj, list):
        return None
    layers: list[Any] = cast(list[Any], layers_obj)
    for layer_obj in layers:  # pyright: ignore[reportAny]
        if not isinstance(layer_obj, dict):
            continue
        layer = cast(dict[str, Any], layer_obj)
        media_type_obj: object = layer.get("mediaType")
        digest_obj: object = layer.get("digest")
        size_obj: object = layer.get("size")
        if (
            media_type_obj == _MODEL_LAYER_MEDIA_TYPE
            and isinstance(digest_obj, str)
            and isinstance(size_obj, int)
            # bool is a subclass of int; a JSON true/false is not a byte count.
            and not isinstance(size_obj, bool)
        ):
            return digest_obj, size_obj
    return None


def _manifest_files(manifests_dir: Path) -> list[Path]:
    """List manifest files under ``manifests_dir`` in a stable order.

    Hidden files (names starting with ``.``) are ignored. May raise ``OSError``
    if the tree cannot be walked.
    """
    return sorted(
        p
        for p in manifests_dir.rglob("*")
        if p.is_file() and not p.name.startswith(".")
    )


@final
class OllamaSource:
    """Model source that reads Ollama's manifest/blob store from disk."""

    kind: ModelSourceKind = "ollama"
    display_name: str = "Ollama"

    def is_available(self) -> bool:
        """Return ``True`` when the ``manifests`` directory exists."""
        return (_ollama_root() / "manifests").exists()

    def scan(self, node_id: NodeId) -> Iterable[LocalModelEntry]:
        """Return one GGUF entry per manifest whose model blob exists on disk.

        Manifests that cannot be parsed, or whose blob is missing, are skipped.
        The reported size is the one declared in the manifest layer.
        """
        root = _ollama_root()
        manifests_dir = root / "manifests"
        if not manifests_dir.is_dir():
            return ()

        entries: list[LocalModelEntry] = []
        try:
            manifest_files = _manifest_files(manifests_dir)
        except OSError as exc:
            logger.warning(
                f"Failed to walk Ollama manifests at {manifests_dir}: {exc!r}"
            )
            return ()

        for manifest_path in manifest_files:
            parsed = _parse_manifest(manifest_path)
            if parsed is None:
                continue
            digest, size = parsed
            blob_path = _digest_to_path(root, digest)
            if blob_path is None:
                continue
            external_id = _manifest_to_id(manifest_path, manifests_dir)
            entries.append(
                LocalModelEntry(
                    node_id=node_id,
                    source=self.kind,
                    external_id=external_id,
                    display_name=external_id,
                    path=str(blob_path),
                    format="gguf",
                    size_bytes=Memory(in_bytes=int(size)),
                    loadable_with_mlx=False,
                )
            )
        return entries

    def resolve_path(self, external_id: str) -> Path | None:
        """Return the GGUF blob for ``external_id``, or ``None`` if unknown.

        Walks the same manifest set as ``scan`` (hidden files excluded) and
        returns the blob of the first manifest whose id matches and parses.
        A tree that cannot be walked is reported as "not found".
        """
        root = _ollama_root()
        manifests_dir = root / "manifests"
        if not manifests_dir.is_dir():
            return None
        try:
            manifest_files = _manifest_files(manifests_dir)
        except OSError as exc:
            logger.warning(
                f"Failed to walk Ollama manifests at {manifests_dir}: {exc!r}"
            )
            return None
        for manifest_path in manifest_files:
            if _manifest_to_id(manifest_path, manifests_dir) != external_id:
                continue
            parsed = _parse_manifest(manifest_path)
            if parsed is None:
                continue
            return _digest_to_path(root, parsed[0])
        return None


def _manifest_to_id(manifest_path: Path, manifests_dir: Path) -> str:
    """Convert a manifest path into a short model id.

    Examples (paths relative to ``manifests/``)::

        registry.ollama.ai/library/llama3/8b   -> llama3:8b
        registry.ollama.ai/someuser/tuned/v1   -> someuser/tuned:v1
        hf.co/bartowski/model/latest           -> hf.co/bartowski/model:latest

    The last path part is the tag and the one before it the model name. The
    default registry host is dropped, and so is the default ``library``
    namespace; any other host is kept as a prefix. Paths with fewer than three
    parts fall back to the bare file name.
    """
    rel_parts = manifest_path.relative_to(manifests_dir).parts
    if len(rel_parts) < 3:
        return manifest_path.name
    *prefix_parts, name, tag = rel_parts
    host, *namespace_parts = prefix_parts
    if host == _DEFAULT_REGISTRY_HOST and namespace_parts:
        if namespace_parts == [_DEFAULT_NAMESPACE]:
            return f"{name}:{tag}"
        # Default registry, user namespace: keep the namespace, drop the host.
        qualifier = "/".join(namespace_parts)
    else:
        qualifier = "/".join(prefix_parts)
    return f"{qualifier}/{name}:{tag}"
def _make_repo(
    cache_root: Path,
    repo_id: str,
    *,
    files: dict[str, bytes],
    revision: str = "abc123",
) -> Path:
    """Build a minimal HF cache layout: ``models--{org}--{name}/snapshots/{rev}/...``.

    Each payload is written under ``blobs/`` and symlinked into the snapshot
    directory; ``refs/main`` is written with the revision. Returns the
    snapshot directory.
    """
    repo_dir = cache_root / f"models--{repo_id.replace('/', '--')}"
    snapshot = repo_dir / "snapshots" / revision
    blobs = repo_dir / "blobs"
    refs = repo_dir / "refs"
    snapshot.mkdir(parents=True)
    blobs.mkdir(parents=True)
    refs.mkdir(parents=True)
    (refs / "main").write_text(revision)
    for fname, payload in files.items():
        blob_path = blobs / f"blob_{fname}"
        blob_path.write_bytes(payload)
        (snapshot / fname).symlink_to(blob_path)
    return snapshot


@pytest.fixture
def hf_cache(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    request: pytest.FixtureRequest,
) -> Path:
    """Point ``HF_HUB_CACHE`` at an empty temp directory for one test.

    ``huggingface_hub.constants`` is reloaded so it picks up the patched
    environment, and reloaded again on teardown (after the environment has
    been restored) so the temp path does not leak into later tests.
    """
    cache = tmp_path / "hf_hub"
    cache.mkdir()
    monkeypatch.setenv("HF_HUB_CACHE", str(cache))
    # huggingface_hub.constants is read at import time; reload to pick up the env.
    import importlib

    import huggingface_hub.constants as constants

    importlib.reload(constants)

    def _restore_constants() -> None:
        # This finalizer runs before monkeypatch's own teardown, so undo the
        # env patch first; otherwise the reload would re-read the temp path.
        monkeypatch.undo()
        importlib.reload(constants)

    request.addfinalizer(_restore_constants)
    return cache


def test_scan_finds_safetensors_repo(hf_cache: Path) -> None:
    """A repo with config + safetensors is reported as MLX-loadable safetensors."""
    _make_repo(
        hf_cache,
        "meta-llama/Llama-3.2-1B-Instruct",
        files={
            "config.json": b'{"model_type": "llama"}',
            "model.safetensors": b"x" * 4096,
        },
    )
    src = HuggingFaceSource()
    entries = list(src.scan(NodeId("node-1")))
    assert len(entries) == 1
    e = entries[0]
    assert e.source == "huggingface"
    assert e.external_id == "meta-llama/Llama-3.2-1B-Instruct"
    assert e.format == "safetensors"
    assert e.loadable_with_mlx is True


def test_scan_classifies_mlx_community_as_mlx(hf_cache: Path) -> None:
    """Safetensors repos under ``mlx-community/`` are reported with format ``mlx``."""
    _make_repo(
        hf_cache,
        "mlx-community/Llama-3.2-1B-Instruct-4bit",
        files={
            "config.json": json.dumps(
                {"model_type": "llama", "quantization": {"group_size": 64, "bits": 4}}
            ).encode(),
            "model.safetensors": b"x" * 4096,
        },
    )
    src = HuggingFaceSource()
    entries = list(src.scan(NodeId("node-1")))
    assert [e.format for e in entries] == ["mlx"]
    assert entries[0].loadable_with_mlx is True


def test_resolve_path_returns_snapshot(hf_cache: Path) -> None:
    """``resolve_path`` returns the snapshot directory of a cached repo."""
    snapshot = _make_repo(
        hf_cache,
        "test-org/test-model",
        files={
            "config.json": b'{"model_type": "test"}',
            "model.safetensors": b"weights",
        },
    )
    src = HuggingFaceSource()
    resolved = src.resolve_path("test-org/test-model")
    assert resolved == snapshot


def test_scan_skips_repos_without_weights(hf_cache: Path) -> None:
    """A repo that only contains a README produces no entries."""
    _make_repo(
        hf_cache,
        "empty/repo",
        files={"README.md": b"# nothing here"},
    )
    src = HuggingFaceSource()
    assert list(src.scan(NodeId("node-1"))) == []
def _make_mlx_model(parent: Path, publisher: str, name: str) -> Path:
    """Create ``parent/publisher/name`` holding a config and one safetensors file."""
    target = parent.joinpath(publisher, name)
    target.mkdir(parents=True)
    target.joinpath("config.json").write_text('{"model_type": "llama"}')
    target.joinpath("model.safetensors").write_bytes(b"x" * 1024)
    return target


def _make_gguf_model(parent: Path, publisher: str, name: str) -> Path:
    """Create ``parent/publisher/name`` holding a single ``{name}.gguf`` file."""
    target = parent.joinpath(publisher, name)
    target.mkdir(parents=True)
    gguf_bytes = b"GGUF" + b"\0" * 1020
    target.joinpath(f"{name}.gguf").write_bytes(gguf_bytes)
    return target


@pytest.fixture
def lm_studio_root(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Empty temp directory exported through ``EXO_LMSTUDIO_DIR``."""
    models_root = tmp_path / "lmstudio_models"
    models_root.mkdir()
    monkeypatch.setenv("EXO_LMSTUDIO_DIR", str(models_root))
    return models_root


def test_scan_finds_mlx_and_gguf(lm_studio_root: Path) -> None:
    """Both layouts are listed; only the safetensors one is MLX-loadable."""
    mlx_id = "MLX/Llama-3.2-1B-Instruct-4bit-MLX"
    gguf_id = "Bartowski/Llama-3.2-1B-Instruct-GGUF"
    _make_mlx_model(lm_studio_root, *mlx_id.split("/"))
    _make_gguf_model(lm_studio_root, *gguf_id.split("/"))

    found = sorted(
        LMStudioSource().scan(NodeId("node-1")), key=lambda e: e.external_id
    )
    assert len(found) == 2
    lookup = {entry.external_id: entry for entry in found}

    mlx_entry = lookup[mlx_id]
    assert mlx_entry.format == "safetensors"
    assert mlx_entry.loadable_with_mlx is True

    gguf_entry = lookup[gguf_id]
    assert gguf_entry.format == "gguf"
    assert gguf_entry.loadable_with_mlx is False


def test_resolve_path(lm_studio_root: Path) -> None:
    """Known ids resolve to their directory; unknown or malformed ids give ``None``."""
    model_dir = _make_mlx_model(lm_studio_root, "Pub", "Model-X")
    source = LMStudioSource()
    assert source.resolve_path("Pub/Model-X") == model_dir
    assert source.resolve_path("Other/Model-X") is None
    assert source.resolve_path("malformed-id") is None


def test_skips_directories_without_models(lm_studio_root: Path) -> None:
    """A model directory that only holds a README is not listed."""
    empty_model_dir = lm_studio_root / "Pub" / "EmptyModel"
    empty_model_dir.mkdir(parents=True)
    (empty_model_dir / "README.md").write_text("nothing")
    assert list(LMStudioSource().scan(NodeId("node-1"))) == []


def test_is_unavailable_when_dir_missing(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A missing root means unavailable and an empty scan."""
    missing = tmp_path / "does-not-exist"
    monkeypatch.setenv("EXO_LMSTUDIO_DIR", str(missing))
    source = LMStudioSource()
    assert source.is_available() is False
    assert list(source.scan(NodeId("node-1"))) == []
"size": model_size, + }, + { + "mediaType": "application/vnd.ollama.image.template", + "digest": "sha256:1", + "size": 1, + }, + ], + } + manifest_path = manifest_dir / tag + manifest_path.write_text(json.dumps(manifest)) + return manifest_path + + +@pytest.fixture +def ollama_root(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + root = tmp_path / "ollama" + (root / "manifests").mkdir(parents=True) + (root / "blobs").mkdir(parents=True) + monkeypatch.setenv("EXO_OLLAMA_DIR", str(root)) + return root + + +def test_scan_resolves_default_registry_models(ollama_root: Path) -> None: + blob_payload = b"GGUF" + b"\0" * 8000 + digest = _write_blob(ollama_root / "blobs", blob_payload) + _write_manifest( + ollama_root / "manifests", + host="registry.ollama.ai", + namespace="library", + name="llama3", + tag="8b", + model_digest=digest, + model_size=len(blob_payload), + ) + src = OllamaSource() + entries = list(src.scan(NodeId("node-1"))) + assert len(entries) == 1 + e = entries[0] + assert e.external_id == "llama3:8b" + assert e.format == "gguf" + assert e.loadable_with_mlx is False + assert e.size_bytes.in_bytes == len(blob_payload) + assert Path(e.path).read_bytes() == blob_payload + + +def test_scan_keeps_namespace_for_non_default_registry(ollama_root: Path) -> None: + digest = _write_blob(ollama_root / "blobs", b"GGUF" + b"\0" * 200) + _write_manifest( + ollama_root / "manifests", + host="hf.co", + namespace="bartowski", + name="model", + tag="latest", + model_digest=digest, + model_size=204, + ) + src = OllamaSource() + entries = list(src.scan(NodeId("node-1"))) + assert [e.external_id for e in entries] == ["hf.co/bartowski/model:latest"] + + +def test_resolve_path_finds_blob(ollama_root: Path) -> None: + payload = b"GGUF" + b"\0" * 100 + digest = _write_blob(ollama_root / "blobs", payload) + _write_manifest( + ollama_root / "manifests", + host="registry.ollama.ai", + namespace="library", + name="llama3", + tag="8b", + model_digest=digest, + 
model_size=len(payload), + ) + src = OllamaSource() + resolved = src.resolve_path("llama3:8b") + assert resolved is not None + assert resolved.read_bytes() == payload + + +def test_skips_when_blob_missing(ollama_root: Path) -> None: + _write_manifest( + ollama_root / "manifests", + host="registry.ollama.ai", + namespace="library", + name="llama3", + tag="8b", + model_digest="sha256:" + "ab" * 32, + model_size=10, + ) + # The referenced blob does not exist on disk. + src = OllamaSource() + assert list(src.scan(NodeId("node-1"))) == [] + + +def test_is_unavailable_when_root_missing( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setenv("EXO_OLLAMA_DIR", str(tmp_path / "missing")) + src = OllamaSource() + assert src.is_available() is False diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index b35f946aac..a57a9456ed 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -58,6 +58,7 @@ from exo.utils.task_group import TaskGroup from exo.worker.plan import plan from exo.worker.runner.supervisor import RunnerSupervisor +from exo.worker.source_scanner import SourceScanner class Worker: @@ -103,6 +104,8 @@ async def run(self): info_send, info_recv = channel[GatheredInfo]() info_gatherer: InfoGatherer = InfoGatherer(info_send) + source_scanner = SourceScanner(self.node_id, self.event_sender) + try: async with self._tg as tg: tg.start_soon(info_gatherer.run) @@ -110,6 +113,7 @@ async def run(self): tg.start_soon(self.plan_step) tg.start_soon(self._event_applier) tg.start_soon(self._poll_connection_updates) + tg.start_soon(source_scanner.run) finally: # Actual shutdown code - waits for all tasks to complete before executing. 
class SourceScanner:
    """Periodically scans every model source and emits one event per source.

    Each tick sends a ``LocalModelsScanned`` event per source. An unavailable
    source is reported with an empty entry list. A source whose probe or scan
    raises is logged and emits nothing on that tick.
    """

    def __init__(
        self,
        node_id: NodeId,
        event_sender: Sender[Event],
        *,
        sources: Sequence[ModelSource] | None = None,
        interval_seconds: float = 60.0,
    ):
        """Store the configuration; ``sources`` defaults to ``default_sources()``."""
        self.node_id: NodeId = node_id
        self.event_sender: Sender[Event] = event_sender
        self.sources: Sequence[ModelSource] = (
            sources if sources is not None else default_sources()
        )
        self.interval_seconds: float = interval_seconds

    async def run(self) -> None:
        """Scan immediately, then once every ``interval_seconds``, until cancelled."""
        # Scan once on boot so the dashboard has data before the first interval tick.
        await self.scan_once()
        while True:
            await anyio.sleep(self.interval_seconds)
            await self.scan_once()

    async def scan_once(self) -> None:
        """Run every source and emit one event per source that did not fail.

        The availability probe and the scan both touch the filesystem, so both
        run in a worker thread. Only that work is guarded by the ``except``:
        a failure to send the event is not a scan failure and propagates to
        the caller for available and unavailable sources alike.
        """
        for source in self.sources:
            try:
                entries = await to_thread.run_sync(
                    _collect_sync, source, self.node_id
                )
            except Exception as exc:  # noqa: BLE001 — third-party caches can throw anything
                logger.warning(
                    f"Source {source.kind} ({source.display_name}) scan failed: {exc!r}"
                )
                continue
            await self._emit(source.kind, entries)

    async def _emit(
        self, source_kind: ModelSourceKind, entries: list[LocalModelEntry]
    ) -> None:
        """Send a ``LocalModelsScanned`` event for this node and source."""
        await self.event_sender.send(
            LocalModelsScanned(
                node_id=self.node_id,
                source=source_kind,
                entries=entries,
            )
        )


def _scan_sync(source: ModelSource, node_id: NodeId) -> list[LocalModelEntry]:
    """Module-level helper so ``to_thread.run_sync`` gets a plain function reference."""
    return list(source.scan(node_id))


def _collect_sync(source: ModelSource, node_id: NodeId) -> list[LocalModelEntry]:
    """Return the source's entries, or an empty list when it is unavailable.

    The empty list lets the caller flush any stale catalog left over from a
    time when the source was still available.
    """
    if not source.is_available():
        return []
    return _scan_sync(source, node_id)
available: bool = True, + raise_on_scan: bool = False, + ) -> None: + self.kind: ModelSourceKind = kind + self.display_name: str = kind + self._available: bool = available + self._entries: list[LocalModelEntry] = entries + self._raise: bool = raise_on_scan + self.scan_count: int = 0 + + def is_available(self) -> bool: + return self._available + + def scan(self, node_id: NodeId) -> Iterable[LocalModelEntry]: + self.scan_count += 1 + if self._raise: + raise RuntimeError("boom") + # Return the configured entries with node_id stamped on. + return [ + entry.model_copy(update={"node_id": node_id}) for entry in self._entries + ] + + def resolve_path( + self, external_id: str + ) -> Path | None: # pragma: no cover - unused + return None + + +def _entry(source: ModelSourceKind, external_id: str) -> LocalModelEntry: + return LocalModelEntry( + node_id=NodeId("placeholder"), + source=source, + external_id=external_id, + display_name=external_id, + path=f"/fake/{external_id}", + format="safetensors", + size_bytes=Memory(in_bytes=4096), + loadable_with_mlx=True, + ) + + +@pytest.mark.asyncio +async def test_scan_once_emits_one_event_per_source() -> None: + sender, receiver = channel[Event]() + node = NodeId("worker-A") + sources = [ + FakeSource("huggingface", [_entry("huggingface", "org/model-1")]), + FakeSource("lmstudio", [_entry("lmstudio", "Pub/Model-2")]), + ] + scanner = SourceScanner(node, sender, sources=sources, interval_seconds=999) + await scanner.scan_once() + + events = receiver.collect() + assert len(events) == 2 + by_kind = {e.source: e for e in events if isinstance(e, LocalModelsScanned)} + assert by_kind["huggingface"].entries[0].external_id == "org/model-1" + assert by_kind["lmstudio"].entries[0].external_id == "Pub/Model-2" + + +@pytest.mark.asyncio +async def test_unavailable_source_emits_empty_list() -> None: + """Stale catalog must be flushed when a source disappears.""" + sender, receiver = channel[Event]() + sources = [ + FakeSource("ollama", 
[_entry("ollama", "llama3:8b")], available=False), + ] + scanner = SourceScanner(NodeId("n"), sender, sources=sources, interval_seconds=999) + await scanner.scan_once() + + events = [e for e in receiver.collect() if isinstance(e, LocalModelsScanned)] + assert len(events) == 1 + assert events[0].source == "ollama" + assert events[0].entries == [] + assert sources[0].scan_count == 0 # is_available() short-circuited the scan + + +@pytest.mark.asyncio +async def test_failing_source_does_not_block_others() -> None: + sender, receiver = channel[Event]() + sources = [ + FakeSource("huggingface", [_entry("huggingface", "ok/1")], raise_on_scan=True), + FakeSource("lmstudio", [_entry("lmstudio", "ok/2")]), + ] + scanner = SourceScanner(NodeId("n"), sender, sources=sources, interval_seconds=999) + await scanner.scan_once() + + events = [e for e in receiver.collect() if isinstance(e, LocalModelsScanned)] + # Only the working source emits; the failing one is silenced. + assert [e.source for e in events] == ["lmstudio"] + + +@pytest.mark.asyncio +async def test_run_emits_immediately_then_loops() -> None: + sender, receiver = channel[Event]() + sources = [FakeSource("exo", [_entry("exo", "x/y")])] + scanner = SourceScanner(NodeId("n"), sender, sources=sources, interval_seconds=0.05) + async with anyio.create_task_group() as tg: + tg.start_soon(scanner.run) + await anyio.sleep(0.12) # 1 immediate + at least 1 interval-driven scan + tg.cancel_scope.cancel() + events = [e for e in receiver.collect() if isinstance(e, LocalModelsScanned)] + assert len(events) >= 2 + assert sources[0].scan_count >= 2