@@ -47,7 +47,11 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
4747 self .last_stats : dict = {}
4848 self ._last_predicted : set [int ] = set ()
4949 self ._stats_lock = threading .Lock ()
50- self ._runtime_stats = {
50+ self ._runtime_stats = self ._new_runtime_stats ()
51+
52+ @staticmethod
53+ def _new_runtime_stats (all_resident_fast_path : bool = False ) -> dict :
54+ return {
5155 "resident_hits" : 0 ,
5256 "resident_misses" : 0 ,
5357 "resident_vram_hits" : 0 ,
@@ -60,6 +64,9 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
6064 "prefetch_wait_time_s" : 0.0 ,
6165 "sync_ssd_loads" : 0 ,
6266 "on_demand_load_time_s" : 0.0 ,
67+ "cold_prediction_hits" : 0 ,
68+ "cold_prediction_total" : 0 ,
69+ "all_resident_fast_path" : bool (all_resident_fast_path ),
6370 }
6471
6572 def setup (self , warm_expert_ids : Optional [List [int ]] = None ):
@@ -90,6 +97,9 @@ def loader(eid: int) -> bool:
9097 if not was_hot and not was_warm :
9198 with self ._stats_lock :
9299 self ._runtime_stats ["sync_ssd_loads" ] += 1
100+ self ._runtime_stats ["cold_prediction_total" ] += 1
101+ if eid in self ._last_predicted :
102+ self ._runtime_stats ["cold_prediction_hits" ] += 1
93103 ok = self .store .promote_to_vram (eid )
94104 elapsed = time .monotonic () - t0
95105 with self ._stats_lock :
@@ -129,6 +139,14 @@ def touch(eid: int) -> None:
129139 moe .runtime_on_demand_loader = loader
130140 moe .runtime_touch_expert = touch
131141
def _clear_runtime_hooks(self):
    """Detach runtime callbacks from every MoE layer, restoring plain sync loading."""
    for layer in self.model.layers:
        moe = getattr(layer, "mlp", None)
        if not isinstance(moe, ChronosMLXMOE):
            continue
        # Revert the layer to the default policy with no loader/touch callbacks.
        moe.runtime_miss_policy = "sync_on_demand"
        moe.runtime_on_demand_loader = None
        moe.runtime_touch_expert = None
149+
132150 # ── Prefetch background thread ────────────────────────────────
133151
134152 def _prefetch_loop (self ):
@@ -149,37 +167,63 @@ def _schedule_prefetch(self, expert_ids: List[int]):
149167 with self ._stats_lock :
150168 self ._runtime_stats ["prefetch_queue_drops" ] += 1
151169
def _prefetch_and_promote_window(
    self,
    prefetch_ids: List[int],
    promote_ids: Optional[List[int]] = None,
    timeout_s: float = 0.012,
) -> None:
    """Schedule background prefetch for ``prefetch_ids`` and briefly wait for ``promote_ids``.

    Args:
        prefetch_ids: Expert IDs to enqueue for background loading (deduplicated here).
        promote_ids: Expert IDs that should be resident before returning; defaults to
            ``prefetch_ids`` when omitted.
        timeout_s: Upper bound on the polling wait; the method never blocks longer.
    """
    # Nothing to fetch, or every expert already lives in DRAM — no work to do.
    if not prefetch_ids or self.store.storage_format == "full_dram":
        return
    # dict.fromkeys dedupes while preserving first-seen order.
    promote_ids = list(dict.fromkeys(int(eid) for eid in (promote_ids or prefetch_ids)))
    pending = []
    for eid in dict.fromkeys(int(eid) for eid in prefetch_ids):
        # NOTE(review): reaches into store internals (_lock/_hot_lru/_warm);
        # assumes _layer_states_complete() means "fully staged" — confirm in store.
        with self.store._lock:
            if eid in self.store._hot_lru:
                continue  # already in the hot tier
            if eid in self.store._warm and self.store._layer_states_complete(self.store._warm.get(eid)):
                continue  # fully staged in the warm tier; no prefetch needed
            pending.append(eid)
    if pending:
        self._schedule_prefetch(pending)
    # Poll at ~1 ms granularity until the promote set is resident or the deadline passes.
    deadline = time.monotonic() + max(0.0, float(timeout_s))
    while time.monotonic() < deadline:
        self._promote_ready(promote_ids)
        with self.store._lock:
            ready = all(
                eid in self.store._hot_lru
                or (eid in self.store._warm and self.store._layer_states_complete(self.store._warm.get(eid)))
                for eid in promote_ids
            )
        if ready:
            break
        time.sleep(0.001)
    # One final promotion attempt even if the deadline expired.
    self._promote_ready(promote_ids)
    with self._stats_lock:
        # (deadline - timeout_s) reconstructs the wait start time; the delta is wall time spent here.
        self._runtime_stats["prefetch_wait_time_s"] += max(0.0, time.monotonic() - (deadline - max(0.0, float(timeout_s))))
182204
205+ def _all_resident_fast_path_enabled (self ) -> bool :
206+ if self .store .storage_format == "full_dram" :
207+ return True
208+ num_experts = int (getattr (self .config , "num_experts" , 0 ) or 0 )
209+ if num_experts <= 0 :
210+ return False
211+ if int (getattr (self .store , "_capacity" , 0 ) or 0 ) < num_experts :
212+ return False
213+ if int (getattr (self .store , "_warm_capacity" , 0 ) or 0 ) < num_experts :
214+ return False
215+ return len (self .store .hot_expert_ids ()) >= num_experts
216+
217+ def _protect_prediction_window (self , expert_ids : List [int ]) -> None :
218+ capacity = max (
219+ 1 ,
220+ int (getattr (self .store , "_warm_capacity" , 1 ) or 1 ),
221+ int (getattr (self .store , "_capacity" , 1 ) or 1 ),
222+ )
223+ hot_ids = sorted (self .store .hot_expert_ids ())
224+ protected = list (dict .fromkeys ([int (eid ) for eid in expert_ids ] + hot_ids ))[:capacity ]
225+ self .store .set_protected_experts (protected )
226+
183227 def stop (self ):
184228 self ._stop .set ()
185229 try :
@@ -210,23 +254,15 @@ def generate(
210254 setup_mem = self ._memory_snapshot ()
211255 self ._last_predicted = set ()
212256 with self ._stats_lock :
213- self ._runtime_stats = {
214- "resident_hits" : 0 ,
215- "resident_misses" : 0 ,
216- "resident_vram_hits" : 0 ,
217- "resident_ram_hits" : 0 ,
218- "selection_hits" : 0 ,
219- "selection_misses" : 0 ,
220- "prediction_hits" : 0 ,
221- "prediction_total" : 0 ,
222- "prefetch_queue_drops" : 0 ,
223- "prefetch_wait_time_s" : 0.0 ,
224- "sync_ssd_loads" : 0 ,
225- "on_demand_load_time_s" : 0.0 ,
226- }
227- self ._install_runtime_hooks ()
257+ all_resident_fast_path = self ._all_resident_fast_path_enabled ()
258+ self ._runtime_stats = self ._new_runtime_stats (all_resident_fast_path )
259+ if all_resident_fast_path :
260+ self ._clear_runtime_hooks ()
261+ self .store .set_protected_experts ([])
262+ else :
263+ self ._install_runtime_hooks ()
228264 # ── Prefill phase: front-load expert IO ───────────────────────────
229- if scheduler is not None :
265+ if scheduler is not None and not all_resident_fast_path :
230266 import torch , numpy as np
231267 ids_np = np .array (input_ids .tolist (), dtype = np .int64 )
232268 ids_pt = torch .from_numpy (ids_np )
@@ -235,7 +271,7 @@ def generate(
235271
236272 # Prefill forward
237273 prefill_t0 = time .monotonic ()
238- if self . store . storage_format == "full_dram" :
274+ if all_resident_fast_path :
239275 prefill_masks = None
240276 else :
241277 prefill_masks = self ._build_avail_masks (None )
@@ -248,21 +284,41 @@ def generate(
248284 next_token = self ._sample (logits [:, - 1 , :], temperature , top_p )
249285 activated_ids : List [int ] = []
250286 tokens = 1
251- if scheduler is None and lookahead_probs is not None :
252- future_ids = self ._predict_future_experts (lookahead_probs )
287+ if not all_resident_fast_path and scheduler is None and lookahead_probs is not None :
288+ future_ids = self ._predict_future_experts (
289+ lookahead_probs ,
290+ capacity = int (getattr (self .store , "_warm_capacity" , self .config .num_experts ) or self .config .num_experts ),
291+ )
292+ immediate_ids = self ._predict_future_experts (
293+ lookahead_probs ,
294+ capacity = int (getattr (self .store , "_capacity" , 1 ) or 1 ),
295+ max_steps = 1 ,
296+ )
253297 self ._last_predicted = set (int (eid ) for eid in future_ids )
254- self ._prefetch_and_promote_window (future_ids , timeout_s = 0.025 )
298+ self ._protect_prediction_window (future_ids )
299+ self ._prefetch_and_promote_window (future_ids , promote_ids = immediate_ids , timeout_s = 0.025 )
255300 yield int (next_token .item ())
256301
257302 for _ in range (max_new_tokens - 1 ):
258- if scheduler is not None :
303+ if all_resident_fast_path :
304+ avail_masks = None
305+ elif scheduler is not None :
259306 avail_masks = self ._build_avail_masks (next_token )
260307 else :
261308 # LookaheadRouter-driven prefetch
262309 if lookahead_probs is not None :
263- future_ids = self ._predict_future_experts (lookahead_probs )
310+ future_ids = self ._predict_future_experts (
311+ lookahead_probs ,
312+ capacity = int (getattr (self .store , "_warm_capacity" , self .config .num_experts ) or self .config .num_experts ),
313+ )
314+ immediate_ids = self ._predict_future_experts (
315+ lookahead_probs ,
316+ capacity = int (getattr (self .store , "_capacity" , 1 ) or 1 ),
317+ max_steps = 1 ,
318+ )
264319 self ._last_predicted = set (int (eid ) for eid in future_ids )
265- self ._prefetch_and_promote_window (future_ids )
320+ self ._protect_prediction_window (future_ids )
321+ self ._prefetch_and_promote_window (future_ids , promote_ids = immediate_ids )
266322 avail_masks = self ._build_avail_masks (next_token )
267323
268324 token_in = next_token .reshape (1 , 1 )
@@ -298,22 +354,39 @@ def generate(
298354 ** self .store .stats (),
299355 }
300356
def _predict_future_experts(
    self,
    lookahead_probs: mx.array,
    capacity: Optional[int] = None,
    max_steps: Optional[int] = None,
) -> List[int]:
    """Extract bounded top-k expert IDs for future steps.

    Args:
        lookahead_probs: Router probabilities shaped [B, S, K+1, E]; only batch 0
            and the last sequence position are used (assumes B==1 — TODO confirm).
        capacity: Maximum number of expert IDs to return; defaults to all experts.
        max_steps: If given, only the first ``max_steps`` lookahead steps are scored.

    Returns:
        Expert IDs ranked by recency-discounted score (ties broken by lower ID),
        truncated to the capacity budget.
    """
    # lookahead_probs: [B, S, K+1, E] — take last token, steps 1..K
    future = lookahead_probs[0, -1, 1:, :]  # [K, E]
    if max_steps is not None:
        future = future[:max(0, int(max_steps))]
    if future.shape[0] <= 0:
        return []  # no lookahead steps left to score
    num_experts = int(getattr(self.config, "num_experts", future.shape[-1]) or future.shape[-1])
    top_k = max(1, min(
        int(getattr(self.config, "num_experts_per_tok", 1) or 1),
        num_experts,
    ))
    budget = num_experts if capacity is None else max(1, min(num_experts, int(capacity)))
    # Accumulate per-expert scores; nearer steps count more via the 1/(offset+1) discount.
    scores = mx.zeros((num_experts,), dtype=mx.float32)
    for offset in range(future.shape[0]):
        step = future[offset].astype(mx.float32)
        if top_k >= num_experts:
            # Every expert is within top-k: add the whole discounted row directly.
            scores = scores + step / float(offset + 1)
            continue
        # Select this step's top-k expert indices and their discounted probabilities.
        step_ids = mx.argpartition(-step, kth=top_k - 1, axis=-1)[:top_k]
        step_vals = mx.take(step, step_ids) / float(offset + 1)
        # Scatter the selected values back into a dense [E] vector via a one-hot matrix.
        one_hot = (step_ids[:, None] == mx.arange(num_experts)).astype(mx.float32)
        scores = scores + (one_hot * step_vals[:, None]).sum(axis=0)
    mx.eval(scores)  # force the lazy MLX graph before reading values on the host
    # Keep strictly-positive scores; sort by score desc, then ID asc for determinism.
    scored = [(idx, float(value)) for idx, value in enumerate(scores.tolist()) if float(value) > 0.0]
    scored.sort(key=lambda item: (-item[1], item[0]))
    return [int(idx) for idx, _value in scored[:budget]]
317390
318391 def _build_avail_masks (self , _token ) -> List [set [int ]] | None :
319392 """Promote prefetched experts and return per-layer availability masks."""
@@ -339,16 +412,25 @@ def _runtime_stat_fields(self) -> dict:
339412 stats = dict (self ._runtime_stats )
340413 total = int (stats .get ("selection_hits" , 0 )) + int (stats .get ("selection_misses" , 0 ))
341414 pred_total = int (stats .get ("prediction_total" , 0 ))
415+ cold_pred_total = int (stats .get ("cold_prediction_total" , 0 ))
416+ all_resident_fast_path = bool (stats .get ("all_resident_fast_path" , False ))
417+ resident_hit_rate = 1.0 if all_resident_fast_path else (
418+ float (stats .get ("selection_hits" , 0 )) / max (total , 1 )
419+ )
342420 return {
343- "resident_hit_rate" : round (float ( stats . get ( "selection_hits" , 0 )) / max ( total , 1 ) , 4 ),
344- "cache_hit_rate" : round (float ( stats . get ( "selection_hits" , 0 )) / max ( total , 1 ) , 4 ),
421+ "resident_hit_rate" : round (resident_hit_rate , 4 ),
422+ "cache_hit_rate" : round (resident_hit_rate , 4 ),
345423 "cache_hits" : int (stats .get ("selection_hits" , 0 )),
346424 "cache_misses" : int (stats .get ("selection_misses" , 0 )),
347425 "expert_selection_hits" : int (stats .get ("selection_hits" , 0 )),
348426 "expert_selection_misses" : int (stats .get ("selection_misses" , 0 )),
349427 "prediction_hit_rate" : round (float (stats .get ("prediction_hits" , 0 )) / max (pred_total , 1 ), 4 ),
350428 "prediction_hits" : int (stats .get ("prediction_hits" , 0 )),
351429 "prediction_total" : pred_total ,
430+ "cold_miss_predict_hit_rate" : round (float (stats .get ("cold_prediction_hits" , 0 )) / max (cold_pred_total , 1 ), 4 ),
431+ "cold_prediction_hits" : int (stats .get ("cold_prediction_hits" , 0 )),
432+ "cold_prediction_total" : cold_pred_total ,
433+ "all_resident_fast_path" : all_resident_fast_path ,
352434 "resident_vram_hits" : int (stats .get ("resident_vram_hits" , 0 )),
353435 "resident_ram_hits" : int (stats .get ("resident_ram_hits" , 0 )),
354436 "prefetch_queue_drops" : int (stats .get ("prefetch_queue_drops" , 0 )),
0 commit comments