Skip to content

Commit 3b5a404

Browse files
authored
Add files via upload
1 parent d7ee765 commit 3b5a404

25 files changed

Lines changed: 1569 additions & 214 deletions

chronos/mlx/expert_store.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
7272
self._warm_lru: collections.OrderedDict = collections.OrderedDict()
7373
self._cluster_manifest: ClusterManifest | None = None
7474
self._loaded_clusters: set[int] = set()
75+
self._protected_experts: set[int] = set()
7576
self.storage_format = "npy"
7677
self.attach_cluster_manifest(ssd_dir)
7778
self._stats = collections.Counter()
@@ -296,18 +297,25 @@ def _put_warm_locked(
296297
self._warm_lru.move_to_end(expert_id)
297298
self._evict_warm_locked(protected_expert_id=protected_expert_id)
298299

300+
def set_protected_experts(self, expert_ids: List[int] | set[int] | tuple[int, ...]):
301+
with self._lock:
302+
self._protected_experts = {
303+
int(eid) for eid in expert_ids
304+
if 0 <= int(eid) < int(self.num_experts)
305+
}
306+
299307
def _evict_warm_locked(self, protected_expert_id: int | None = None):
300308
while len(self._warm_lru) > self._warm_capacity:
301309
evict_id = None
302310
for candidate in self._warm_lru.keys():
303-
if candidate == protected_expert_id:
311+
if candidate == protected_expert_id or candidate in self._protected_experts:
304312
continue
305313
if candidate not in self._hot_lru:
306314
evict_id = candidate
307315
break
308316
if evict_id is None:
309317
for candidate in self._warm_lru.keys():
310-
if candidate != protected_expert_id:
318+
if candidate != protected_expert_id and candidate not in self._protected_experts:
311319
evict_id = candidate
312320
break
313321
if evict_id is None:
@@ -455,9 +463,17 @@ def promote_to_vram(self, expert_id: int) -> bool:
455463
self._hot_lru.move_to_end(expert_id)
456464
else:
457465
while len(self._hot_lru) >= self._capacity:
458-
evicted, _ = self._hot_lru.popitem(last=False)
459-
if evicted == expert_id:
460-
continue
466+
evicted = None
467+
for candidate in self._hot_lru.keys():
468+
if candidate != expert_id and candidate not in self._protected_experts:
469+
evicted = candidate
470+
break
471+
if evicted is None:
472+
evicted, _ = self._hot_lru.popitem(last=False)
473+
if evicted == expert_id:
474+
continue
475+
else:
476+
self._hot_lru.pop(evicted, None)
461477
# Keep warm weights in unified memory, but drop the
462478
# executable live module so masked MoE cannot call an
463479
# expert outside the hot execution budget.

chronos/mlx/inference.py

Lines changed: 128 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,11 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
4747
self.last_stats: dict = {}
4848
self._last_predicted: set[int] = set()
4949
self._stats_lock = threading.Lock()
50-
self._runtime_stats = {
50+
self._runtime_stats = self._new_runtime_stats()
51+
52+
@staticmethod
53+
def _new_runtime_stats(all_resident_fast_path: bool = False) -> dict:
54+
return {
5155
"resident_hits": 0,
5256
"resident_misses": 0,
5357
"resident_vram_hits": 0,
@@ -60,6 +64,9 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
6064
"prefetch_wait_time_s": 0.0,
6165
"sync_ssd_loads": 0,
6266
"on_demand_load_time_s": 0.0,
67+
"cold_prediction_hits": 0,
68+
"cold_prediction_total": 0,
69+
"all_resident_fast_path": bool(all_resident_fast_path),
6370
}
6471

6572
def setup(self, warm_expert_ids: Optional[List[int]] = None):
@@ -90,6 +97,9 @@ def loader(eid: int) -> bool:
9097
if not was_hot and not was_warm:
9198
with self._stats_lock:
9299
self._runtime_stats["sync_ssd_loads"] += 1
100+
self._runtime_stats["cold_prediction_total"] += 1
101+
if eid in self._last_predicted:
102+
self._runtime_stats["cold_prediction_hits"] += 1
93103
ok = self.store.promote_to_vram(eid)
94104
elapsed = time.monotonic() - t0
95105
with self._stats_lock:
@@ -129,6 +139,14 @@ def touch(eid: int) -> None:
129139
moe.runtime_on_demand_loader = loader
130140
moe.runtime_touch_expert = touch
131141

142+
def _clear_runtime_hooks(self):
    """Detach the runtime callbacks from every MoE layer of the model.

    For each layer whose ``mlp`` attribute is a ``ChronosMLXMOE``, the miss
    policy is set to ``"sync_on_demand"`` and both the on-demand loader and
    the touch callback are set to ``None``. Layers without such an ``mlp``
    are left untouched.
    """
    for block in self.model.layers:
        mlp = getattr(block, "mlp", None)
        if not isinstance(mlp, ChronosMLXMOE):
            continue
        mlp.runtime_miss_policy = "sync_on_demand"
        # Targets are assigned left to right, matching the original order.
        mlp.runtime_on_demand_loader = mlp.runtime_touch_expert = None
149+
132150
# ── Prefetch background thread ────────────────────────────────
133151

134152
def _prefetch_loop(self):
@@ -149,37 +167,63 @@ def _schedule_prefetch(self, expert_ids: List[int]):
149167
with self._stats_lock:
150168
self._runtime_stats["prefetch_queue_drops"] += 1
151169

152-
def _prefetch_and_promote_window(self, expert_ids: List[int], timeout_s: float = 0.012) -> None:
153-
if not expert_ids or self.store.storage_format == "full_dram":
170+
def _prefetch_and_promote_window(
171+
self,
172+
prefetch_ids: List[int],
173+
promote_ids: Optional[List[int]] = None,
174+
timeout_s: float = 0.012,
175+
) -> None:
176+
if not prefetch_ids or self.store.storage_format == "full_dram":
154177
return
178+
promote_ids = list(dict.fromkeys(int(eid) for eid in (promote_ids or prefetch_ids)))
155179
pending = []
156-
for eid in dict.fromkeys(int(eid) for eid in expert_ids):
180+
for eid in dict.fromkeys(int(eid) for eid in prefetch_ids):
157181
with self.store._lock:
158182
if eid in self.store._hot_lru:
159183
continue
160184
if eid in self.store._warm and self.store._layer_states_complete(self.store._warm.get(eid)):
161-
pending.append(eid)
162185
continue
163186
pending.append(eid)
164-
if not pending:
165-
return
166-
self._schedule_prefetch(pending)
187+
if pending:
188+
self._schedule_prefetch(pending)
167189
deadline = time.monotonic() + max(0.0, float(timeout_s))
168190
while time.monotonic() < deadline:
169-
self._promote_ready(pending)
191+
self._promote_ready(promote_ids)
170192
with self.store._lock:
171193
ready = all(
172194
eid in self.store._hot_lru
173195
or (eid in self.store._warm and self.store._layer_states_complete(self.store._warm.get(eid)))
174-
for eid in pending
196+
for eid in promote_ids
175197
)
176198
if ready:
177199
break
178200
time.sleep(0.001)
179-
self._promote_ready(pending)
201+
self._promote_ready(promote_ids)
180202
with self._stats_lock:
181203
self._runtime_stats["prefetch_wait_time_s"] += max(0.0, time.monotonic() - (deadline - max(0.0, float(timeout_s))))
182204

205+
def _all_resident_fast_path_enabled(self) -> bool:
    """Report whether every expert is already hot in the store.

    Returns ``True`` immediately when the store's ``storage_format`` is
    ``"full_dram"``. Otherwise returns ``True`` only if
    ``config.num_experts`` is positive, both ``_capacity`` and
    ``_warm_capacity`` on the store are at least that count, and
    ``hot_expert_ids()`` reports at least that many entries. Missing or
    falsy attributes are treated as 0.
    """
    store = self.store
    if store.storage_format == "full_dram":
        return True
    expert_count = int(getattr(self.config, "num_experts", 0) or 0)
    if expert_count <= 0:
        return False
    # Hot capacity is checked before warm capacity, as in the original.
    for attr in ("_capacity", "_warm_capacity"):
        if int(getattr(store, attr, 0) or 0) < expert_count:
            return False
    return len(store.hot_expert_ids()) >= expert_count
216+
217+
def _protect_prediction_window(self, expert_ids: List[int]) -> None:
    """Mark the predicted experts, topped up with hot experts, as protected.

    Builds an ordered, de-duplicated list: the given ids first (coerced with
    ``int()``, in the order given), then the store's hot expert ids in
    ascending order. The list is truncated to
    ``max(1, _warm_capacity, _capacity)`` entries and handed to
    ``store.set_protected_experts``. Missing or falsy capacities count as 1.
    """
    store = self.store
    warm_cap = int(getattr(store, "_warm_capacity", 1) or 1)
    hot_cap = int(getattr(store, "_capacity", 1) or 1)
    limit = max(1, warm_cap, hot_cap)
    resident = sorted(store.hot_expert_ids())
    # Dict keys keep insertion order, so predicted ids outrank resident ones.
    ranked = dict.fromkeys(int(eid) for eid in expert_ids)
    for eid in resident:
        ranked.setdefault(eid)
    store.set_protected_experts(list(ranked)[:limit])
226+
183227
def stop(self):
184228
self._stop.set()
185229
try:
@@ -210,23 +254,15 @@ def generate(
210254
setup_mem = self._memory_snapshot()
211255
self._last_predicted = set()
212256
with self._stats_lock:
213-
self._runtime_stats = {
214-
"resident_hits": 0,
215-
"resident_misses": 0,
216-
"resident_vram_hits": 0,
217-
"resident_ram_hits": 0,
218-
"selection_hits": 0,
219-
"selection_misses": 0,
220-
"prediction_hits": 0,
221-
"prediction_total": 0,
222-
"prefetch_queue_drops": 0,
223-
"prefetch_wait_time_s": 0.0,
224-
"sync_ssd_loads": 0,
225-
"on_demand_load_time_s": 0.0,
226-
}
227-
self._install_runtime_hooks()
257+
all_resident_fast_path = self._all_resident_fast_path_enabled()
258+
self._runtime_stats = self._new_runtime_stats(all_resident_fast_path)
259+
if all_resident_fast_path:
260+
self._clear_runtime_hooks()
261+
self.store.set_protected_experts([])
262+
else:
263+
self._install_runtime_hooks()
228264
# ── Prefill phase: front-load expert IO ───────────────────────────
229-
if scheduler is not None:
265+
if scheduler is not None and not all_resident_fast_path:
230266
import torch, numpy as np
231267
ids_np = np.array(input_ids.tolist(), dtype=np.int64)
232268
ids_pt = torch.from_numpy(ids_np)
@@ -235,7 +271,7 @@ def generate(
235271

236272
# Prefill forward
237273
prefill_t0 = time.monotonic()
238-
if self.store.storage_format == "full_dram":
274+
if all_resident_fast_path:
239275
prefill_masks = None
240276
else:
241277
prefill_masks = self._build_avail_masks(None)
@@ -248,21 +284,41 @@ def generate(
248284
next_token = self._sample(logits[:, -1, :], temperature, top_p)
249285
activated_ids: List[int] = []
250286
tokens = 1
251-
if scheduler is None and lookahead_probs is not None:
252-
future_ids = self._predict_future_experts(lookahead_probs)
287+
if not all_resident_fast_path and scheduler is None and lookahead_probs is not None:
288+
future_ids = self._predict_future_experts(
289+
lookahead_probs,
290+
capacity=int(getattr(self.store, "_warm_capacity", self.config.num_experts) or self.config.num_experts),
291+
)
292+
immediate_ids = self._predict_future_experts(
293+
lookahead_probs,
294+
capacity=int(getattr(self.store, "_capacity", 1) or 1),
295+
max_steps=1,
296+
)
253297
self._last_predicted = set(int(eid) for eid in future_ids)
254-
self._prefetch_and_promote_window(future_ids, timeout_s=0.025)
298+
self._protect_prediction_window(future_ids)
299+
self._prefetch_and_promote_window(future_ids, promote_ids=immediate_ids, timeout_s=0.025)
255300
yield int(next_token.item())
256301

257302
for _ in range(max_new_tokens - 1):
258-
if scheduler is not None:
303+
if all_resident_fast_path:
304+
avail_masks = None
305+
elif scheduler is not None:
259306
avail_masks = self._build_avail_masks(next_token)
260307
else:
261308
# LookaheadRouter-driven prefetch
262309
if lookahead_probs is not None:
263-
future_ids = self._predict_future_experts(lookahead_probs)
310+
future_ids = self._predict_future_experts(
311+
lookahead_probs,
312+
capacity=int(getattr(self.store, "_warm_capacity", self.config.num_experts) or self.config.num_experts),
313+
)
314+
immediate_ids = self._predict_future_experts(
315+
lookahead_probs,
316+
capacity=int(getattr(self.store, "_capacity", 1) or 1),
317+
max_steps=1,
318+
)
264319
self._last_predicted = set(int(eid) for eid in future_ids)
265-
self._prefetch_and_promote_window(future_ids)
320+
self._protect_prediction_window(future_ids)
321+
self._prefetch_and_promote_window(future_ids, promote_ids=immediate_ids)
266322
avail_masks = self._build_avail_masks(next_token)
267323

268324
token_in = next_token.reshape(1, 1)
@@ -298,22 +354,39 @@ def generate(
298354
**self.store.stats(),
299355
}
300356

301-
def _predict_future_experts(self, lookahead_probs: mx.array) -> List[int]:
357+
def _predict_future_experts(
358+
self,
359+
lookahead_probs: mx.array,
360+
capacity: Optional[int] = None,
361+
max_steps: Optional[int] = None,
362+
) -> List[int]:
302363
"""Extract bounded top-k expert IDs for future steps."""
303364
# lookahead_probs: [B, S, K+1, E] — take last token, steps 1..K
304365
future = lookahead_probs[0, -1, 1:, :] # [K, E]
366+
if max_steps is not None:
367+
future = future[:max(0, int(max_steps))]
368+
if future.shape[0] <= 0:
369+
return []
370+
num_experts = int(getattr(self.config, "num_experts", future.shape[-1]) or future.shape[-1])
305371
top_k = max(1, min(
306372
int(getattr(self.config, "num_experts_per_tok", 1) or 1),
307-
int(getattr(self.config, "num_experts", future.shape[-1]) or future.shape[-1]),
373+
num_experts,
308374
))
309-
ids: list[int] = []
310-
for k in range(future.shape[0]):
311-
step = future[k]
312-
if top_k == 1:
313-
ids.append(int(mx.argmax(step).item()))
314-
else:
315-
ids.extend(int(v) for v in mx.argpartition(-step, kth=top_k - 1, axis=-1)[:top_k].tolist())
316-
return list(dict.fromkeys(ids))
375+
budget = num_experts if capacity is None else max(1, min(num_experts, int(capacity)))
376+
scores = mx.zeros((num_experts,), dtype=mx.float32)
377+
for offset in range(future.shape[0]):
378+
step = future[offset].astype(mx.float32)
379+
if top_k >= num_experts:
380+
scores = scores + step / float(offset + 1)
381+
continue
382+
step_ids = mx.argpartition(-step, kth=top_k - 1, axis=-1)[:top_k]
383+
step_vals = mx.take(step, step_ids) / float(offset + 1)
384+
one_hot = (step_ids[:, None] == mx.arange(num_experts)).astype(mx.float32)
385+
scores = scores + (one_hot * step_vals[:, None]).sum(axis=0)
386+
mx.eval(scores)
387+
scored = [(idx, float(value)) for idx, value in enumerate(scores.tolist()) if float(value) > 0.0]
388+
scored.sort(key=lambda item: (-item[1], item[0]))
389+
return [int(idx) for idx, _value in scored[:budget]]
317390

318391
def _build_avail_masks(self, _token) -> List[set[int]] | None:
319392
"""Promote prefetched experts and return per-layer availability masks."""
@@ -339,16 +412,25 @@ def _runtime_stat_fields(self) -> dict:
339412
stats = dict(self._runtime_stats)
340413
total = int(stats.get("selection_hits", 0)) + int(stats.get("selection_misses", 0))
341414
pred_total = int(stats.get("prediction_total", 0))
415+
cold_pred_total = int(stats.get("cold_prediction_total", 0))
416+
all_resident_fast_path = bool(stats.get("all_resident_fast_path", False))
417+
resident_hit_rate = 1.0 if all_resident_fast_path else (
418+
float(stats.get("selection_hits", 0)) / max(total, 1)
419+
)
342420
return {
343-
"resident_hit_rate": round(float(stats.get("selection_hits", 0)) / max(total, 1), 4),
344-
"cache_hit_rate": round(float(stats.get("selection_hits", 0)) / max(total, 1), 4),
421+
"resident_hit_rate": round(resident_hit_rate, 4),
422+
"cache_hit_rate": round(resident_hit_rate, 4),
345423
"cache_hits": int(stats.get("selection_hits", 0)),
346424
"cache_misses": int(stats.get("selection_misses", 0)),
347425
"expert_selection_hits": int(stats.get("selection_hits", 0)),
348426
"expert_selection_misses": int(stats.get("selection_misses", 0)),
349427
"prediction_hit_rate": round(float(stats.get("prediction_hits", 0)) / max(pred_total, 1), 4),
350428
"prediction_hits": int(stats.get("prediction_hits", 0)),
351429
"prediction_total": pred_total,
430+
"cold_miss_predict_hit_rate": round(float(stats.get("cold_prediction_hits", 0)) / max(cold_pred_total, 1), 4),
431+
"cold_prediction_hits": int(stats.get("cold_prediction_hits", 0)),
432+
"cold_prediction_total": cold_pred_total,
433+
"all_resident_fast_path": all_resident_fast_path,
352434
"resident_vram_hits": int(stats.get("resident_vram_hits", 0)),
353435
"resident_ram_hits": int(stats.get("resident_ram_hits", 0)),
354436
"prefetch_queue_drops": int(stats.get("prefetch_queue_drops", 0)),

0 commit comments

Comments
 (0)