[draft] Multi-LoRA #60
Conversation
Code Review
This pull request introduces a Multi-LoRA registry and routing mechanism to support per-sequence and per-token adapter dispatching in Megatron-Core models, including MoE layers. My review highlights potential issues with registry initialization, the safety of direct dictionary manipulation for cleanup, and performance considerations regarding per-token linear operations and buffer allocation in the routing path.
```python
        num_slots: Number of concurrent adapters to support.
        lora_config: ``peft.LoraConfig`` describing rank, target modules, etc.
        """
        self._adapter_registry: dict = {}  # logical_name → slot_name
```
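To make the registry-initialization concern from the review concrete, here is a minimal, hypothetical sketch of a slot-based registry (the class name `MultiLoraRegistry`, the `slot_i` naming scheme, and the `register`/`unregister` methods are illustrative assumptions, not the PR's actual API):

```python
class MultiLoraRegistry:
    """Maps logical adapter names onto a fixed pool of slot names."""

    def __init__(self, num_slots: int):
        # Pool of physical slots; each can hold one adapter at a time.
        self._free_slots = [f"slot_{i}" for i in range(num_slots)]
        self._adapter_registry: dict = {}  # logical_name → slot_name

    def register(self, logical_name: str) -> str:
        # Re-registering an existing adapter returns its current slot.
        if logical_name in self._adapter_registry:
            return self._adapter_registry[logical_name]
        if not self._free_slots:
            raise RuntimeError("no free LoRA slots")
        slot = self._free_slots.pop(0)
        self._adapter_registry[logical_name] = slot
        return slot

    def unregister(self, logical_name: str) -> None:
        # Releasing an adapter returns its slot to the free pool.
        slot = self._adapter_registry.pop(logical_name, None)
        if slot is not None:
            self._free_slots.append(slot)
```

A design like this makes the failure mode explicit when more adapters are registered than `num_slots` allows, which is one of the initialization concerns the review raises.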
```python
        if root is not None:
            root.__dict__.pop('_expert_lora_indices', None)
        if post is not None:
            self.__dict__.pop('_lora_post_dispatch', None)
        return result
```
The cleanup logic in `routed_experts_compute` uses `root.__dict__.pop` and `self.__dict__.pop`. While this works, it is safer to use `getattr(root, '_expert_lora_indices', None)` followed by `delattr`, or `setattr(root, '_expert_lora_indices', None)`, because manipulating `__dict__` directly bypasses `__delattr__` hooks and descriptors, and fails entirely on objects that store attributes outside the instance dict (e.g. classes using `__slots__`).
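A minimal sketch of the suggested pattern, using a stand-in `Root` class (hypothetical; the real object is the module carrying the cached indices):

```python
class Root:
    pass

root = Root()
root._expert_lora_indices = [0, 1, 1]  # cached routing state to clean up

# Instead of root.__dict__.pop('_expert_lora_indices', None), go through
# the normal attribute protocol so hooks and descriptors are respected:
if getattr(root, '_expert_lora_indices', None) is not None:
    delattr(root, '_expert_lora_indices')
```

After the cleanup, `hasattr(root, '_expert_lora_indices')` is `False`, matching the intent of the original `pop` with a default.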
```python
        for idx, slot in _lpl._idx_to_name.items():
            if slot not in _lpl.lora_A:
                continue
            mask = token_idx == idx
            if not mask.any():
                continue
            rows = mask.nonzero(as_tuple=True)[0]
            x_sub = x_flat.index_select(0, rows).to(result.dtype)
            delta = F.linear(_lpl.lora_dropout[slot](x_sub),
                             _lpl.lora_A[slot].weight.to(result.dtype))
            delta = F.linear(delta, _lpl.lora_B[slot].weight.to(result.dtype))
            result_flat.index_add_(0, rows, (delta * _lpl.scaling[slot]).to(result.dtype))
```
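The loop above groups tokens by adapter slot so each slot pays for only two small matmuls over its own tokens, rather than a per-token linear. A self-contained numpy sketch of the same gather → low-rank matmul → scatter-add pattern (shapes and weights are illustrative assumptions; dropout is omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, in_feat, out_feat, rank = 6, 4, 4, 2

x_flat = rng.standard_normal((num_tokens, in_feat))
result_flat = np.zeros((num_tokens, out_feat))
token_idx = np.array([0, 1, 0, 1, 1, 0])  # which adapter each token routes to

# Per-slot LoRA weights: A is (rank, in_feat), B is (out_feat, rank),
# mirroring the lora_A / lora_B layout used with F.linear above.
loras = {
    0: (rng.standard_normal((rank, in_feat)),
        rng.standard_normal((out_feat, rank)), 0.5),
    1: (rng.standard_normal((rank, in_feat)),
        rng.standard_normal((out_feat, rank)), 0.5),
}

for idx, (A, B, scaling) in loras.items():
    rows = np.nonzero(token_idx == idx)[0]
    if rows.size == 0:
        continue
    x_sub = x_flat[rows]                  # gather tokens routed to this slot
    delta = (x_sub @ A.T) @ B.T           # two F.linear calls: through A, then B
    result_flat[rows] += delta * scaling  # index_add_ equivalent
```

Each token's contribution equals `x @ A.T @ B.T * scaling` for its assigned slot, so the grouped computation is mathematically identical to a per-token dispatch while keeping the number of matmuls proportional to the number of active slots, not tokens.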
```python
        delta = torch.zeros(total_tokens, out_feat, dtype=dtype, device=result.device)
        apply_routed_lora(delta, x_flat, token_indices,
                          lora_A_by_idx, lora_B_by_idx, scaling_by_idx, dropout_by_idx)
        result = result + delta.view_as(result).to(result.dtype)
```