diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py index 5d5088182..143789974 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -17,7 +17,7 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize_model +from twinkle.kernel import kernelize, npu_builtin logger = get_logger() args = CLI.from_args() @@ -95,7 +95,7 @@ def train(): ) # npu patch if Torch.is_npu_available(): - model = kernelize_model(model, mode='train', device='npu') + model = kernelize(model, npu_builtin(model)) lora_cfg = _build_lora_config(ENABLE_EP) model.add_adapter_to_model(args.lora.adapter_name, lora_cfg, gradient_accumulation_steps=args.training.gradient_accumulation_steps) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index a3b4da645..43ff78e08 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -12,7 +12,7 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize_model +from twinkle.kernel import kernelize, npu_builtin logger = get_logger() args = CLI.from_args() @@ -59,7 +59,7 @@ def train(): model.model._no_split_modules = {'Qwen3_5DecoderLayer'} # npu patch if Torch.is_npu_available(): - model = kernelize_model(model, mode='train', device='npu') + model = kernelize(model, npu_builtin(model)) lora_config = LoraConfig(**args.get_lora_args()) model.add_adapter_to_model( diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 56f22c801..b41a0400f 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -9,7 +9,7 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize_model +from twinkle.kernel import kernelize, npu_builtin logger = get_logger() args = CLI.from_args() @@ -68,7 +68,7 @@ def train(): ) # npu patch if Torch.is_npu_available(): - model = kernelize_model(model, mode='train', device='npu') + model = kernelize(model, npu_builtin(model)) lora_config = LoraConfig(**args.get_lora_args()) model.add_adapter_to_model(args.lora.adapter_name, lora_config, gradient_accumulation_steps=args.training.gradient_accumulation_steps) diff --git a/docs/source_en/Components/Kernel/Kernel.md b/docs/source_en/Components/Kernel/Kernel.md index d587b5400..fe0a505a3 100644 --- a/docs/source_en/Components/Kernel/Kernel.md +++ b/docs/source_en/Components/Kernel/Kernel.md @@ -1,308 +1,139 @@ -# Twinkle Kernel Module +# Twinkle Kernel -The Twinkle Kernel Module provides two kernel replacement paths for accelerating models during training and inference: +`twinkle.kernel` exposes a mapping-driven kernel replacement API. Replacing one +implementation with another collapses to a single `kernelize(model, mapping)` +call. -* **Layer-level kernelize** - Replace entire `nn.Module` implementations with optimized kernels. -* **Function-level kernelize** - Monkey-patch specific functions inside a Python module. +The public surface is exactly three symbols: -These two approaches can be used independently or together via a unified registration and application entry point. +| Symbol | Purpose | +| --- | --- | +| `kernelize(model, mapping)` | Apply ``mapping`` to ``model`` (in place) and return it | +| `npu_builtin(model=None)` | Return the Ascend NPU built-in mapping (composes with user mappings) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | Build a ``HubRef`` for use as a mapping value; the actual Hub download is deferred to ``kernelize`` | ---- +## Mapping semantics -## Overview: Two Kernelization Paths +`mapping` keys describe the target to replace: -| Path | Granularity | Typical Use Cases | -| -------------- | -------------------- | -------------------------------- | -| Layer-level | Whole `nn.Module` | Linear / Conv / MLP / Attention | -| Function-level | Individual functions | Hot paths, math ops, activations | +- `type[nn.Module]` subclass — replace **every** instance whose exact type matches (`m.__class__ = impl`; subclasses are **not** touched) +- `str` of the form `'pkg.sub.attr'` or `'pkg.sub.ClassName.attr'` — `setattr(target, attr, impl)` ---- +`mapping` values describe the replacement: -## Layer-Level Kernel Replacement +- `type[nn.Module]` subclass — used as the impl class. The class' `__init__` is **never** invoked; its forward must work against the attributes the original instance already has +- `Callable` — assigned with `setattr` +- `dict[str, V]` — device → impl dispatch. Device is inferred from the model; entries without a matching key are **silently skipped** +- `HubRef` — built via `hub(...)`; resolved lazily -### When to Use +Device is inferred from `next(model.parameters()).device.type` (falling back to buffers, then `'cpu'`). -* You have a complete kernel implementation for a layer -* You want model-wide replacement of specific `nn.Module` types -* Suitable for both training and inference +## Examples ---- - -### Example 1: Local Kernel Repo - -Use this when: - -* Kernel implementations live in a local repository -* You want to replace layers in HuggingFace or custom models +### Enable the full NPU built-in bundle ```python -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) -from transformers import Qwen2Config, Qwen2ForCausalLM -from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP - -# 1) Register the layer kernel from a local repo -register_layer_kernel( - kernel_name="MyAwesomeMLP", - repo_path="/path/to/local/repo", - package_name="my_kernels", - layer_name="Qwen2MLPTrainingKernel", - device="cuda", - mode="train", -) - -# 2) Bind external layer to kernel name -register_external_layer(Qwen2MLP, "MyAwesomeMLP") - -# 3) Build the model and apply kernelization -config = Qwen2Config( - hidden_size=128, - num_hidden_layers=1, - num_attention_heads=4, - num_key_value_heads=4, - intermediate_size=256, - use_cache=False, -) -model = Qwen2ForCausalLM(config) -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) -``` - ---- - -### Example 2: Hub Kernel Repo +import torch +from twinkle.kernel import kernelize, npu_builtin -Use this when: +if torch.npu.is_available(): + model = kernelize(model, npu_builtin(model)) +``` -* The kernel is hosted on a Hub +### Custom class replacement ```python -import torch -import torch.nn as nn -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) - -# 1) Define the custom layer -class SiluAndMul(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(2, dim=-1) - return nn.functional.silu(x1) * x2 - -# 2) Register the Hub kernel and bind the layer -register_layer_kernel( - kernel_name="SiluAndMulKernel", - repo_id="kernels-community/activation", - layer_name="SiluAndMul", - device="cuda", - mode="train", -) -register_external_layer(SiluAndMul, "SiluAndMulKernel") - -# 3) Apply to a model -class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.activation = SiluAndMul() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.activation(x) - -model = SimpleModel() -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize + +model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) ``` ---- +### Built-in + custom override -## Local Kernel Repo (Minimal) +```python +from twinkle.kernel import kernelize, npu_builtin -A local kernel repository is a regular Python package. -At minimum, it only needs a `layers.py` file for layer-level kernels. +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` -```text -# Repo layout: -my_kernels/ # Local kernel repository (Python package) -├── __init__.py # Package entry -└── layers.py # Layer-level kernel implementations +Plain dict merge — later keys override earlier ones. -``` +### Hub kernel (HF Hub format) ```python -# my_kernels/__init__.py -from . import layers -__all__ = ["layers"] +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul -# my_kernels/layers.py -import torch -import torch.nn as nn - -class Qwen2MLPTrainingKernel(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate = self.gate_proj(x) - up = self.up_proj(x) - return self.down_proj(self.act_fn(gate) * up) +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) ``` ---- - -## Function-Level Kernel Replacement +Exactly one of `revision` / `version` must be passed. The `kernels` package is imported lazily; absence raises a clear "install kernels" error. -### When to Use +### Function-level replacement -* You only need to accelerate a small number of hot functions -* Replacing the entire layer is unnecessary or impractical -* Common for math ops, activations, or utility functions +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb ---- +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` -### Example 1: Batch Registration (Simple Case) +### Cross-device mapping (NPU enabled, CUDA skipped) ```python -from twinkle.kernel import register_kernels, kernelize_model - -# 1) Register function kernels -config = { - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - }, -} -register_kernels(config) +from twinkle.kernel import kernelize -# 2) Apply (model can be None when only functions are used) -kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True) +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) ``` ---- +Safe to run on CUDA — entries whose dict misses the current device just skip. -### Example 2: Advanced Function Sources (Full Control) +## NPU built-in coverage -Use this when: +`npu_builtin(model)` returns a dict that (as available transformers modules permit) covers: -* Use when different functions come from different sources (impl / repo / hub) or need compile/backward flags. +- RMSNorm class replacement for Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE families +- `apply_rotary_pos_emb` function replacement (fused RoPE) for the same families +- SwiGLU fused replacement for the MLP variants +- `Experts.forward` and `SparseMoeBlock.forward` for Qwen3-MoE / Qwen3.5-MoE +- GatedRMSNorm forward for Qwen3.5 / Qwen3.5-MoE +- `apply_multimodal_rotary_pos_emb` for Qwen2.5-VL +- Global SDPA replacement (one-shot side effect on `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention enablement (one-shot side effect + per-instance traversal, triggered inside `npu_builtin(model)`) + +**Not included by default:** the NPU replacement for `transformers.integrations.moe._grouped_mm`. Without Expert Parallelism the contiguous-copy overhead is ~8x. Opt in explicitly when EP is enabled: ```python -from twinkle.kernel.function import ( - register_function_kernel, - apply_function_kernel, -) -import torch.nn as nn -from twinkle.kernel import kernelize_model - -TARGET_MODULE = "my_pkg.math_ops" - -# 1) Direct implementation -def fast_add(x, y): - return x + y + 1 - -register_function_kernel( - func_name="add", - target_module=TARGET_MODULE, - func_impl=fast_add, - device="cuda", - mode="inference", -) - -# 2) Repo object (FuncRepositoryProtocol) -class MyFuncRepo: - def load(self): - return MyKernelFunc - -class MyKernelFunc(nn.Module): - def forward(self, x, y): - return x * y - -register_function_kernel( - func_name="mul", - target_module=TARGET_MODULE, - repo=MyFuncRepo(), - device="cuda", - mode="compile", -) - -# 3) Hub repo -register_function_kernel( - func_name="silu_and_mul", - target_module="my_pkg.activations", - repo_id="kernels-community/activation", - revision="main", # or version="0.1.0" - device="cuda", - mode="inference", -) - -# 4) Apply function kernels -applied = apply_function_kernel( - target_module=TARGET_MODULE, - device="cuda", - mode="inference", - strict=False, -) -print("patched:", applied) - -# 5) Optional: unified entry via kernelize_model -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True) +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} +model = kernelize(model, mapping) ``` ---- +## Environment variables -## Unified Layer + Function Batch Registration +Only two remain: -### When to Use +- `TWINKLE_NPU_FLA` — Qwen3.5 FLA switch (default on; `0`/`false` to disable) +- `TWINKLE_NPU_GATED_RMSNorm_FP32` — force FP32 in Gated RMSNorm forward (default off) -* Framework-level integration -* A single configuration entry point is preferred -* Managing both layer and function kernels together +The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` are gone — they're now "include the entry in the mapping or don't" decisions. -```python -from twinkle.kernel import register_kernels, kernelize_model -import torch.nn as nn - -# 1) Register layer + function kernels -config = { - "layers": { - "linear": { - "repo_id": "kernels-community/linear", - "layer_name": "Linear", - "version": "0.1.0", - "device": "cuda", - "mode": "train", - }, - "conv2d": { - "repo_path": "/path/to/local/repo", - "package_name": "my_kernels", - "layer_name": "Conv2d", - "device": "cuda", - }, - }, - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - "relu": { - "target_module": "my_pkg.activations", - "repo_id": "kernels-community/activation", - "revision": "main", - "device": "cuda", - }, - }, -} -register_kernels(config) +## Caveats -# 2) Apply via kernelize_model -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="train", device="cuda", use_fallback=True) -``` +- `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract +- Exact match: `type(m) is target_cls`. Subclasses of `target_cls` are not replaced — add them to the mapping yourself +- `kernelize` is idempotent under repeated calls +- There is no `unkernelize` — replacement is one-way diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" index 89ae37ca5..17e4d4a28 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" @@ -1,307 +1,137 @@ # Twinkle Kernel 模块 -Twinkle Kernel 模块提供了两条内核替换路径,用于加速训练和推理: +`twinkle.kernel` 提供一个 mapping 驱动的内核替换接口,把“用一种实现替换模型里的另一种实现”压缩为一次 `kernelize(model, mapping)` 调用。 -* **层级 Kernelize(Layer-level kernelize)** - 使用优化内核替换完整的 `nn.Module` 实现。 -* **函数级 Kernelize(Function-level kernelize)** - 对 Python 模块中的特定函数进行 monkey-patch。 +公开符号只有三个: -这两种方式可以独立使用,也可以通过统一入口组合使用。 +| 符号 | 作用 | +| --- | --- | +| `kernelize(model, mapping)` | 在 `model` 上应用 `mapping`,原地修改后返回 | +| `npu_builtin(model=None)` | 返回 Ascend NPU 内置替换的 mapping dict(可与用户 mapping 自由组合) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | 构造一个 `HubRef`,用作 mapping value;真实下载推迟到 `kernelize` 执行 | ---- +## Mapping 语义 -## 概览:两条 Kernelize 路径 +`mapping` 的 **key** 表示要替换的目标: -| 路径 | 粒度 | 典型场景 | -| --- | --- | --- | -| 层级替换 | 整个 `nn.Module` | Linear / Conv / MLP / Attention | -| 函数级替换 | 单个函数 | 热点路径、数学算子、激活函数 | +- `type[nn.Module]` 子类:替换模型里**所有**该精确类型的实例(`m.__class__ = impl_class`,**不包含**子类) +- `str` 形如 `'pkg.sub.attr'` 或 `'pkg.sub.ClassName.attr'`:`setattr(target, attr, impl)` ---- +**value** 表示用什么替换: -## 层级内核替换(Layer-Level) +- `type[nn.Module]` 子类:直接作为 impl 类。该类**不会被 `__init__` 调用**,必须只依赖原 instance 已经有的 attribute(weight / eps / ...)正确工作 +- `Callable`:直接 `setattr` 上去 +- `dict[str, V]`:device → impl 嵌套分派。从 `model` 推断当前 device,未匹配则**静默跳过** +- `HubRef`:通过 `hub(...)` 构造的 Hub 引用,延迟加载 -### 适用场景 +device 从 `next(model.parameters()).device.type` 推断(无参数则用 buffers,再无则为 `'cpu'`)。 -* 你已经有完整的层内核实现 -* 希望在模型中批量替换某类 `nn.Module` -* 同时适用于训练与推理 +## 场景示例 ---- - -### 示例 1:本地 Kernel 仓库 - -适用于: - -* 内核实现位于本地仓库 -* 希望替换 HuggingFace 或自定义模型中的层 +### 启用全部 NPU 内置优化 ```python -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) -from transformers import Qwen2Config, Qwen2ForCausalLM -from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP - -# 1) 从本地仓库注册层内核 -register_layer_kernel( - kernel_name="MyAwesomeMLP", - repo_path="/path/to/local/repo", - package_name="my_kernels", - layer_name="Qwen2MLPTrainingKernel", - device="cuda", - mode="train", -) - -# 2) 绑定外部层与内核名 -register_external_layer(Qwen2MLP, "MyAwesomeMLP") - -# 3) 构建模型并应用内核替换 -config = Qwen2Config( - hidden_size=128, - num_hidden_layers=1, - num_attention_heads=4, - num_key_value_heads=4, - intermediate_size=256, - use_cache=False, -) -model = Qwen2ForCausalLM(config) -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) +import torch +from twinkle.kernel import kernelize, npu_builtin + +if torch.npu.is_available(): + model = kernelize(model, npu_builtin(model)) ``` ---- +### 自定义类替换 -### 示例 2:Hub Kernel 仓库 +```python +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize -适用于: +model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) +``` -* 内核托管在 Hub 上 +### 内置 + 自定义混合 ```python -import torch -import torch.nn as nn -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) - -# 1) 定义自定义层 -class SiluAndMul(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(2, dim=-1) - return nn.functional.silu(x1) * x2 - -# 2) 注册 Hub 内核并绑定层 -register_layer_kernel( - kernel_name="SiluAndMulKernel", - repo_id="kernels-community/activation", - layer_name="SiluAndMul", - device="cuda", - mode="train", -) -register_external_layer(SiluAndMul, "SiluAndMulKernel") - -# 3) 应用到模型 -class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.activation = SiluAndMul() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.activation(x) - -model = SimpleModel() -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) -``` +from twinkle.kernel import kernelize, npu_builtin ---- - -## 本地 Kernel 仓库(最小结构) +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` -本地 kernel 仓库本质上是一个普通 Python 包。 -最少只需要一个 `layers.py` 来放层级内核实现。 +后写入的 key 会覆盖前面的,普通 dict 合并语义。 -```text -# 仓库结构: -my_kernels/ # 本地 kernel 仓库(Python 包) -├── __init__.py # 包入口 -└── layers.py # 层级 kernel 实现 -``` +### Hub Kernel(HF Hub 格式) ```python -# my_kernels/__init__.py -from . import layers -__all__ = ["layers"] +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul -# my_kernels/layers.py -import torch -import torch.nn as nn - -class Qwen2MLPTrainingKernel(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate = self.gate_proj(x) - up = self.up_proj(x) - return self.down_proj(self.act_fn(gate) * up) +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) ``` ---- - -## 函数级内核替换(Function-Level) +`revision` 与 `version` 二选一必传。`hub(...)` 触发 `kernels` 包的延迟 import,未安装时会提示 `pip install kernels`。 -### 适用场景 +### 函数级替换 -* 只需要加速少量热点函数 -* 不适合或不需要替换整个层 -* 常用于数学算子、激活函数、工具函数 +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb ---- +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` -### 示例 1:批量注册(简单场景) +### 跨设备 mapping(NPU 启用、CUDA 跳过) ```python -from twinkle.kernel import register_kernels, kernelize_model - -# 1) 注册函数内核 -config = { - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - }, -} -register_kernels(config) +from twinkle.kernel import kernelize -# 2) 应用(仅函数替换时 model 可为 None) -kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True) +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) ``` ---- +在 CUDA 模型上跑也安全:未匹配 device 的 entry 不会替换、不会报错。 -### 示例 2:高级函数来源(完整控制) +## 内置 NPU 优化 -适用于: +`npu_builtin(model)` 返回的 dict 至少包含以下覆盖(实际条目随 transformers 已安装的 modeling 模块动态收集): -* 不同函数来自不同来源(impl / repo / hub),或需要 compile/backward 等标志。 +- Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE 系列的 RMSNorm 类替换 +- 同上系列的 `apply_rotary_pos_emb` 函数替换(融合 RoPE) +- 同上系列 MLP 的 SwiGLU 融合替换 +- Qwen3-MoE / Qwen3.5-MoE 的 `Experts.forward` 与 `SparseMoeBlock.forward` 替换 +- Qwen3.5 / Qwen3.5-MoE 的 GatedRMSNorm forward 替换 +- Qwen2.5-VL 的 `apply_multimodal_rotary_pos_emb` 替换 +- 全局 SDPA 替换(一次性副作用,写入 `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention 启用(一次性副作用 + 实例遍历,由 `npu_builtin(model)` 内部触发) + +**未默认包含** `transformers.integrations.moe._grouped_mm` 的 NPU 替换(在没有 Expert Parallelism 时会带来约 8x 开销)。需要时手动加入: ```python -from twinkle.kernel.function import ( - register_function_kernel, - apply_function_kernel, -) -import torch.nn as nn -from twinkle.kernel import kernelize_model - -TARGET_MODULE = "my_pkg.math_ops" - -# 1) 直接传入实现 -def fast_add(x, y): - return x + y + 1 - -register_function_kernel( - func_name="add", - target_module=TARGET_MODULE, - func_impl=fast_add, - device="cuda", - mode="inference", -) - -# 2) Repo 对象(FuncRepositoryProtocol) -class MyFuncRepo: - def load(self): - return MyKernelFunc - -class MyKernelFunc(nn.Module): - def forward(self, x, y): - return x * y - -register_function_kernel( - func_name="mul", - target_module=TARGET_MODULE, - repo=MyFuncRepo(), - device="cuda", - mode="compile", -) - -# 3) Hub 仓库 -register_function_kernel( - func_name="silu_and_mul", - target_module="my_pkg.activations", - repo_id="kernels-community/activation", - revision="main", # 或 version="0.1.0" - device="cuda", - mode="inference", -) - -# 4) 应用函数内核 -applied = apply_function_kernel( - target_module=TARGET_MODULE, - device="cuda", - mode="inference", - strict=False, -) -print("patched:", applied) - -# 5) 可选:通过 kernelize_model 统一应用 -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True) +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} +model = kernelize(model, mapping) ``` ---- +## 环境变量 -## 层级 + 函数级统一批量注册 +只有两个保留: -### 适用场景 +- `TWINKLE_NPU_FLA`:Qwen3.5 FLA 开关(默认开,设为 `0`/`false` 关闭) +- `TWINKLE_NPU_GATED_RMSNorm_FP32`:将 Gated RMSNorm 强制升到 FP32 计算(默认关) -* 需要框架级统一集成 -* 希望通过单一配置入口管理 -* 同时管理层和函数两类内核 +旧的 `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` 已移除——这些都改成"是否把对应 entry 写进 mapping"的显式选择。 -```python -from twinkle.kernel import register_kernels, kernelize_model -import torch.nn as nn - -# 1) 注册层级 + 函数级内核 -config = { - "layers": { - "linear": { - "repo_id": "kernels-community/linear", - "layer_name": "Linear", - "version": "0.1.0", - "device": "cuda", - "mode": "train", - }, - "conv2d": { - "repo_path": "/path/to/local/repo", - "package_name": "my_kernels", - "layer_name": "Conv2d", - "device": "cuda", - }, - }, - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - "relu": { - "target_module": "my_pkg.activations", - "repo_id": "kernels-community/activation", - "revision": "main", - "device": "cuda", - }, - }, -} -register_kernels(config) +## 注意事项 -# 2) 通过 kernelize_model 应用 -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="train", device="cuda", use_fallback=True) -``` +- `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 +- 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping +- 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) +- 没有 `unkernelize`——替换是单向的 diff --git a/scripts/kernelize_demo.py b/scripts/kernelize_demo.py new file mode 100644 index 000000000..aaa194489 --- /dev/null +++ b/scripts/kernelize_demo.py @@ -0,0 +1,260 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""End-to-end demo for ``twinkle.kernel.core.kernelize``. + +Run with:: + + conda run -n twinkle python scripts/kernelize_demo.py + +The script exercises three replacement modes on CPU: + +1. Class replacement - rewrite ``__class__`` of matching ``nn.Module`` instances. +2. Attribute replacement - monkey-patch a module/function attribute via dotted path. +3. Hub replacement - lazy-load a kernel from a mocked ``kernels`` package. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Make the local ``src`` importable when running the script directly. +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(_PROJECT_ROOT / "src") not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT / "src")) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP + +from twinkle.kernel.core import hub, kernelize + + +def _assert(cond: bool, msg: str) -> None: + if not cond: + raise AssertionError(msg) + + +def _describe(obj) -> str: + """Best-effort name for a callable (plain function or kernels ``Func``).""" + qn = getattr(obj, "__qualname__", None) + if qn: + return qn + return f"<{type(obj).__module__}.{type(obj).__name__}>" + + +class FusedQwen3MLP(Qwen3MLP): + """Pretend fused kernel: same gated MLP + a constant +1.0 bias.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + print("[patched] FusedQwen3MLP.forward called") + return super().forward(x) + 1.0 + + +def _build_mlp() -> Qwen3MLP: + config = Qwen3Config( + hidden_size=8, + intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=1, + num_key_value_heads=1, + ) + return Qwen3MLP(config) + + +# --------------------------------------------------------------------------- # +# Demo 1: Class replacement +# --------------------------------------------------------------------------- # +def demo_class_replacement() -> None: + print("=" * 60) + print("Demo 1: class replacement (replace transformers Qwen3MLP)") + print("=" * 60) + + mlp = _build_mlp() + x = torch.randn(1, 4, mlp.config.hidden_size) + + out_before = mlp(x) + print(f"Before kernelize: type = {type(mlp).__name__}") + print(f" output[0,0,:3] = {out_before[0, 0, :3].tolist()}") + + # Pass the MLP itself as the model; ``model.modules()`` yields it. + kernelize(mlp, {Qwen3MLP: FusedQwen3MLP}) + + out_after = mlp(x) + print(f"After kernelize: type = {type(mlp).__name__}") + print(f" output[0,0,:3] = {out_after[0, 0, :3].tolist()}") + + _assert(type(mlp) is FusedQwen3MLP, "mlp should be FusedQwen3MLP after kernelize") + # Params (gate_proj/up_proj/down_proj) are preserved on the instance, so the + # only difference is the +1.0 added by the fused forward. + _assert( + torch.allclose(out_after, out_before + 1.0), + "FusedQwen3MLP should add +1.0 to the original output", + ) + print("✓ Class replacement passed\n") + + +# --------------------------------------------------------------------------- # +# Demo 2: Attribute replacement (patch transformers qwen3 apply_rotary_pos_emb) +# --------------------------------------------------------------------------- # +_QWEN3_MOD_PATH = "transformers.models.qwen3.modeling_qwen3" +_ROPE_ATTR = "apply_rotary_pos_emb" + + +def demo_attr_replacement() -> None: + print("=" * 60) + print("Demo 2: attribute replacement (two forms)") + print("=" * 60) + + import importlib + + mod = importlib.import_module(_QWEN3_MOD_PATH) + + # ---- Form A: module attribute (pkg.mod.attr) -------------------------- # + print("-" * 60) + print("Form A: replace module-level function `apply_rotary_pos_emb`") + print("-" * 60) + + original_rope = getattr(mod, _ROPE_ATTR) + + q = torch.ones(1, 2, 4, 8) + k = torch.ones(1, 2, 4, 8) + cos = torch.ones(1, 1, 4, 8) + sin = torch.ones(1, 1, 4, 8) + + q_out_before, k_out_before = original_rope(q, k, cos, sin) + print(f"Before kernelize: {_describe(mod.apply_rotary_pos_emb)}") + + def fused_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): # noqa: ANN001 + return original_rope(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) + + fused_apply_rotary_pos_emb._kernelize_marker = True # type: ignore[attr-defined] + + try: + kernelize(nn.Linear(1, 1), {f"{_QWEN3_MOD_PATH}.{_ROPE_ATTR}": fused_apply_rotary_pos_emb}) + + patched_fn = getattr(mod, _ROPE_ATTR) + print(f"After kernelize: {_describe(patched_fn)}") + _assert(patched_fn is fused_apply_rotary_pos_emb, "module attr should be the fused fn") + q_out_after, k_out_after = patched_fn(q, k, cos, sin) + _assert( + torch.allclose(q_out_after, q_out_before) + and torch.allclose(k_out_after, k_out_before), + "wrapped RoPE should preserve the original output", + ) + print("✓ Form A (module attribute) passed\n") + finally: + setattr(mod, _ROPE_ATTR, original_rope) + + # ---- Form B: class attribute / method (pkg.mod.ClassName.attr) ------- # + print("-" * 60) + print("Form B: replace class method `Qwen3MLP.forward`") + print("-" * 60) + + original_forward = Qwen3MLP.forward + + mlp = _build_mlp() + x = torch.randn(1, 4, mlp.config.hidden_size) + out_before = mlp(x) + print(f"Before kernelize: {_describe(Qwen3MLP.forward)}") + print(f" output[0,0,:3] = {out_before[0, 0, :3].tolist()}") + + def fused_forward(self, x): # noqa: ANN001 + return original_forward(self, x) + 1.0 + + try: + # Dotted path lands on the class, then setattr the method on it. + kernelize(nn.Linear(1, 1), {f"{_QWEN3_MOD_PATH}.Qwen3MLP.forward": fused_forward}) + + patched_forward = Qwen3MLP.forward + print(f"After kernelize: {_describe(patched_forward)}") + _assert(patched_forward is fused_forward, "class method should be the fused fn") + + out_after = mlp(x) + print(f" output[0,0,:3] = {out_after[0, 0, :3].tolist()}") + _assert( + torch.allclose(out_after, out_before + 1.0), + "fused forward should add +1.0 to the original output", + ) + print("✓ Form B (class method) passed\n") + finally: + setattr(Qwen3MLP, "forward", original_forward) + + +# --------------------------------------------------------------------------- # +# Demo 3: Hub replacement (real HuggingFace Hub kernel via ``kernels``) +# --------------------------------------------------------------------------- # +# We use the real ``kernels-community/activation`` repo on the HF Hub, which +# ships a ``SiluAndMul`` layer (the SwiGLU activation used by Qwen3MLP). +# +# Note: the Hub kernel's ``forward`` calls a CUDA op, so it cannot *execute* +# on CPU. This demo verifies the parts that DO work on CPU: the kernel is +# downloaded lazily via ``_load_hub_ref`` and the target module's class is +# swapped to the Hub-loaded class. Running the fused forward requires CUDA. +_HUB_REPO = "kernels-community/activation" +_HUB_LAYER = "SiluAndMul" + + +class LocalSiluAndMul(nn.Module): + """Pure-torch SwiGLU activation, same interface as the Hub ``SiluAndMul``.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def demo_hub_replacement() -> None: + print("=" * 60) + print("Demo 3: Hub replacement (real HF Hub kernel: kernels-community/activation)") + print("=" * 60) + + try: + from kernels import get_kernel # noqa: F401 + except ImportError: + print("Skipped: `kernels` package not installed (pip install kernels)") + return + + model = nn.Sequential(LocalSiluAndMul()) + x = torch.randn(1, 4, 16, dtype=torch.float32) + + out_before = model(x) + print(f"Before kernelize: type = {type(model[0]).__name__}") + print(f" output[0,0,:3] = {out_before[0, 0, :3].tolist()}") + + ref = hub(f"{_HUB_REPO}:{_HUB_LAYER}", version=1) + print(f"HubRef: repo_id={ref.repo_id!r}, layer_name={ref.layer_name!r}, version={ref.version}") + + try: + kernelize(model, {LocalSiluAndMul: ref}) + except Exception as e: + print(f"Skipped: could not load Hub kernel ({type(e).__name__}: {e})") + return + + hub_cls = type(model[0]) + print(f"After kernelize: type = {hub_cls.__name__}") + print(f" module = {hub_cls.__module__}") + + _assert(hub_cls.__name__ == _HUB_LAYER, "should be the Hub SiluAndMul class") + _assert( + "activation" in hub_cls.__module__, + "loaded class should come from the Hub activation kernel package", + ) + # The Hub forward is CUDA-only, so we do not execute it on CPU. + print("(Hub kernel forward is CUDA-only; verified download + class swap on CPU)") + print("✓ Hub replacement passed\n") + + +# --------------------------------------------------------------------------- # +# Main +# --------------------------------------------------------------------------- # +def main() -> None: + print("Running kernelize end-to-end demos on CPU...\n") + demo_class_replacement() + demo_attr_replacement() + demo_hub_replacement() + print("All demos passed.") + + +if __name__ == "__main__": + main() diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index c7262eb07..f1499680d 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -1,111 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Twinkle Kernel Module - Kernel orchestration layer.""" -import torch -from logging import getLogger -from typing import Any, Dict, Optional, Union +"""Mapping-driven kernel replacement. -from twinkle.utils.framework import Torch -from .base import DeviceType, ModeType, is_kernels_enabled -from .function import apply_function_kernel, register_function_kernel -from .layer import apply_layer_kernel, register_layer_batch, register_layer_kernel -from .monkey_patch_npu import apply_npu_patch, register_npu_fused_function_kernels -from .registry import register_external_layer as _register_external_layer +Three public symbols: -logger = getLogger(__name__) +- :func:`kernelize` apply ``mapping`` to a model +- :func:`hub` build a Hub kernel reference +- :func:`npu_builtin` the Ascend NPU built-in bundle +""" +from .builtin import npu_builtin +from .core import hub, kernelize -__all__ = [ - 'kernelize_model', - 'register_layer_kernel', - 'register_function_kernel', - 'register_external_layer', - 'register_kernels', - 'apply_npu_patch', - 'apply_npu_fused_ops', -] - - -def kernelize_model( - model, - mode: ModeType = 'inference', - device: Optional[DeviceType] = None, - use_fallback: bool = True, -) -> Any: - """Apply kernels to model (main entry point). - - For NPU devices, this also applies Ascend fused operators (RMSNorm, RoPE, - SwiGLU, SDPA Attention) unconditionally when running on NPU. - - Args: - model: The PyTorch model to kernelize. - mode: The mode for kernel selection ("inference" or "train"). - device: The device type (auto-detected if None). - use_fallback: Whether to use original forward when no compatible kernel found. - If False, raises ValueError when kernel is unavailable. - - Returns: - The kernelized model. - """ - # Step 0: NPU monkey-patches must be applied BEFORE layer kernel replacement - # so that patched module classes are used when new instances are created. - if device == 'npu' or (device is None and _is_npu_device(model)): - try: - apply_npu_patch(model) - except Exception: - logger.warning('NPU patch failed. Continuing without fused ops.', exc_info=True) - - model = apply_layer_kernel(model, mode=mode, device=device, use_fallback=use_fallback) - - apply_function_kernel(device=device, mode=mode) - - return model - - -def apply_npu_fused_ops(config) -> None: - """Apply NPU fused operators patch manually. - """ - logger.warning('apply_npu_fused_ops(config) is deprecated. ' - 'Use apply_npu_patch() instead, which enables all patches unconditionally.') - apply_npu_patch() - - -def register_external_layer(layer_class: type, kernel_name: str) -> None: - _register_external_layer(layer_class, kernel_name) - - -def register_kernels(config: Dict[str, Dict[str, Any]]) -> None: - """Batch register kernels (framework integration API).""" - if 'layers' in config: - for kernel_name, spec in config['layers'].items(): - device = spec.pop('device', 'cuda') - register_layer_kernel(kernel_name=kernel_name, device=device, **spec) - - if 'functions' in config: - from .function import register_function_batch - - functions = config['functions'] - if isinstance(functions, dict): - function_specs = [] - for func_name, spec in functions.items(): - if not isinstance(spec, dict): - raise TypeError(f'Function spec for {func_name} must be a dict.') - if 'func_name' not in spec: - spec['func_name'] = func_name - function_specs.append(spec) - register_function_batch(function_specs) - else: - register_function_batch(functions) - - -def _is_npu_device(model=None) -> bool: - """Check if the model (or current environment) is on NPU device.""" - # Priority 1: Check model's actual device (kernel-specific inference) - if model is not None: - try: - param_device = next(model.parameters()).device - if param_device.type == 'npu': - return True - except StopIteration: - pass - - # Priority 2: Fallback to global NPU availability - return Torch.is_npu_available() +__all__ = ['kernelize', 'hub', 'npu_builtin'] diff --git a/src/twinkle/kernel/base.py b/src/twinkle/kernel/base.py deleted file mode 100644 index 6da669d5c..000000000 --- a/src/twinkle/kernel/base.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Kernel module base - Base classes, env vars, device detection.""" -import os -from typing import Any, Literal, Optional - -from twinkle import exists - -ModeType = Literal['train', 'inference', 'compile'] -DeviceType = Literal['cuda', 'npu', 'mps', 'cpu', 'rocm', 'metal'] - - -def _kernels_enabled() -> bool: - """Check if kernels are enabled (default: enabled).""" - env_val = os.getenv('TWINKLE_USE_KERNELS', 'YES').upper() - return env_val in ('YES', 'TRUE', '1', 'ON') - - -def _trust_remote_code() -> bool: - """Check if remote code is trusted (default: not trusted).""" - env_val = os.getenv('TWINKLE_TRUST_REMOTE_CODE', 'NO').upper() - return env_val in ('YES', 'TRUE', '1', 'ON') - - -def detect_backend() -> Optional[str]: - """Detect training framework backend: "transformers" | "megatron" | None.""" - if exists('transformers'): - return 'transformers' - return None - - -def is_kernels_available() -> bool: - """Check if HF kernels package is available.""" - return exists('kernels') - - -def is_kernels_enabled() -> bool: - """Check if kernels are enabled by env var.""" - return _kernels_enabled() and is_kernels_available() - - -def to_kernels_mode(mode: ModeType) -> Any: - """Convert Twinkle mode to HF kernels mode.""" - if not is_kernels_available(): - return None - from kernels import Mode - if isinstance(mode, Mode): - return mode - mode_map = { - 'train': Mode.TRAINING, - 'inference': Mode.INFERENCE, - 'compile': Mode.TORCH_COMPILE, - } - return mode_map.get(mode, Mode.INFERENCE) - - -def validate_mode(mode: str) -> None: - from kernels.layer.mode import Mode - mode = to_kernels_mode(mode) - - if mode == Mode.FALLBACK: - raise ValueError('Mode.FALLBACK can only be used to register kernel mappings.') - if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator] - raise ValueError('kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.') - - -def supports_mode(target: object, mode: str) -> bool: - from kernels.layer.mode import Mode - mode = to_kernels_mode(mode) - if Mode.TORCH_COMPILE in mode and not getattr(target, 'can_torch_compile', False): - return False - if Mode.TRAINING in mode and not getattr(target, 'has_backward', True): - return False - return True - - -def validate_device_type(device_type: str) -> None: - supported_devices = {'cpu', 'cuda', 'mps', 'npu', 'rocm', 'xpu'} - if device_type not in supported_devices: - raise ValueError('Unsupported device type ' - f"'{device_type}'. Supported device types are: " - f"{', '.join(sorted(supported_devices))}") diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py new file mode 100644 index 000000000..cda8813bb --- /dev/null +++ b/src/twinkle/kernel/builtin.py @@ -0,0 +1,209 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""``npu_builtin()`` returns the bundle of Ascend NPU replacements. + +All values are wrapped in ``{'npu': impl}`` so the bundle composes safely on +CUDA/CPU systems — non-NPU devices silently skip every entry. + +GMM is **not** included by default (without EP it causes ~8x slowdown). Opt +in by merging: + + {**npu_builtin(model), 'transformers.integrations.moe._grouped_mm': + {'npu': npu_grouped_mm}} +""" +from __future__ import annotations + +import importlib +import torch.nn as nn +from typing import Any + +from twinkle import get_logger +from twinkle.utils.device_mesh import Platform + +logger = get_logger() + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: + """Return the NPU builtin mapping; optionally apply per-instance FLA.""" + from .npu_impls.attention import npu_sdpa_attention_forward + from .npu_impls.fla import apply_qwen3_5_fla + from .npu_impls.moe import npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward + from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + from .npu_impls.rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb + from .npu_impls.swiglu import npu_swiglu_forward + + bundle: dict[Any, dict[str, Any]] = {} + + is_npu_platform = Platform.device_prefix() == 'npu' + + # Apply SDPA install eagerly (one-shot module-level mutation) on NPU + # platforms. The NPU impl inverts boolean masks, which is wrong for + # CUDA/CPU execution, so non-NPU platforms must not mutate the global HF + # registry even if ``torch_npu`` is importable in the environment. + if is_npu_platform: + _install_sdpa(npu_sdpa_attention_forward) + + # === per-family class + function entries === + _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_moe_entries( + bundle, + NpuRMSNorm, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + _add_qwen2_5_vl_entries( + bundle, + NpuRMSNorm, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, + npu_apply_multimodal_rotary_pos_emb, + ) + _add_qwen3_5_entries( + bundle, + NpuRMSNorm, + npu_gated_rms_norm_forward, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, + ) + _add_qwen3_5_moe_entries( + bundle, + NpuRMSNorm, + npu_gated_rms_norm_forward, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + + # === FLA (side-effect; mapping-incompatible) === + if is_npu_platform: + apply_qwen3_5_fla(model) + + return bundle + + +def _install_sdpa(impl) -> None: + """One-shot install of SDPA attention forward (global modeling_utils dict). + + ``AttentionInterface._global_mapping`` is a private transformers attribute; + guard against its removal so an upstream change can't take down the rest + of ``npu_builtin()``. + """ + try: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface + except ImportError: + return + try: + AttentionInterface._global_mapping['sdpa'] = impl + except AttributeError: + logger.warning('[NPU] [SDPA] AttentionInterface._global_mapping unavailable; skipping') + ALL_ATTENTION_FUNCTIONS['sdpa'] = impl + + +# ---- helpers that conditionally add entries based on module availability ---- + + +def _add_class_if_present(bundle, module_path, class_name, impl_cls): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + bundle[cls] = {'npu': impl_cls} + + +def _add_swiglu_if_present(bundle, module_path, class_name, fn): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + # Function-level: wrap as string-keyed forward replacement. + # We override on the *class object*, not the module attribute, by + # using a class-key with a synthetic impl wrapping the forward. + # The simplest way is to subclass and reassign __class__, but here + # we follow the legacy approach of overwriting the class's forward: + bundle[f'{module_path}.{class_name}.forward'] = {'npu': fn} + + +def _add_attr_if_present(bundle, module_path, attr_name, impl): + mod = _import_optional(module_path) + if mod is None: + return + if '.' in attr_name: + # Dotted attr like 'Qwen3MoeExperts.forward': resolve the class on + # the module, then check the trailing member on the class. + head, _, tail = attr_name.partition('.') + owner = getattr(mod, head, None) + if owner is None or not hasattr(owner, tail): + return + else: + if not hasattr(mod, attr_name): + return + bundle[f'{module_path}.{attr_name}'] = {'npu': impl} + + +def _add_qwen2_entries(bundle, rms_cls, rope_fn, swiglu_fn): + # Qwen2 (used by Qwen2.5-VL etc. via inheritance) + _add_class_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2RMSNorm', rms_cls) + _add_attr_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2MLP', swiglu_fn) + + +def _add_qwen3_entries(bundle, rms_cls, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3.modeling_qwen3' + _add_class_if_present(bundle, base, 'Qwen3RMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MLP', swiglu_fn) + + +def _add_qwen3_moe_entries(bundle, rms_cls, rope_fn, swiglu_fn, experts_fn, sparse_fn): + base = 'transformers.models.qwen3_moe.modeling_qwen3_moe' + _add_class_if_present(bundle, base, 'Qwen3MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeSparseMoeBlock.forward', sparse_fn) + + +def _add_qwen2_5_vl_entries(bundle, rms_cls, rope_fn, swiglu_fn, multimodal_rope_fn): + base = 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl' + _add_class_if_present(bundle, base, 'Qwen2_5_VLRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_attr_if_present(bundle, base, 'apply_multimodal_rotary_pos_emb', multimodal_rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2_5_VLMLP', swiglu_fn) + + +def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3_5.modeling_qwen3_5' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5RMSNorm', rms_cls) + _add_class_if_present(bundle, base, 'Qwen3_5VisionRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5VisionMLP', swiglu_fn) + # Qwen3_5GatedRMSNorm: forward-level replacement + _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) + + +def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, experts_fn, sparse_fn): + base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) diff --git a/src/twinkle/kernel/chunk_gated_delta_rule.py b/src/twinkle/kernel/chunk_gated_delta_rule.py index 553fb1226..2d0beee77 100644 --- a/src/twinkle/kernel/chunk_gated_delta_rule.py +++ b/src/twinkle/kernel/chunk_gated_delta_rule.py @@ -1,7 +1,7 @@ '''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. -It is consumed by twinkle.kernel.monkey_patch_npu to enable the fast linear-attention +It is consumed by twinkle.kernel.npu_impls.fla to enable the fast linear-attention path of Qwen3.5 on Ascend hardware.''' import torch diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py new file mode 100644 index 000000000..bdb122709 --- /dev/null +++ b/src/twinkle/kernel/core.py @@ -0,0 +1,164 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Minimal mapping-driven kernel replacement. + +Public API: ``kernelize``, ``hub`` (re-exported from ``twinkle.kernel``). +""" +from __future__ import annotations + +import importlib +import torch.nn as nn +from dataclasses import dataclass +from typing import Any + +from twinkle.utils.device_mesh import Platform + + +@dataclass(frozen=True) +class HubRef: + """Lightweight reference to a HuggingFace Hub kernel layer. + + Resolved lazily by ``kernelize`` via the optional ``kernels`` package. + """ + repo_id: str + layer_name: str + revision: str | None = None + version: int | None = None + backend: str | None = None + + +def hub( + ref: str, + *, + revision: str | None = None, + version: int | None = None, + backend: str | None = None, +) -> HubRef: + """Build a ``HubRef`` for use as a ``kernelize`` mapping value. + + ``ref`` is ``':'`` (e.g. ``'org/repo:SiluAndMul'``). + Exactly one of ``revision`` or ``version`` must be supplied. + """ + if (revision is None) == (version is None): + raise ValueError('Exactly one of `revision` or `version` must be specified.') + if ':' not in ref: + raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") + repo_id, layer_name = ref.rsplit(':', 1) + return HubRef(repo_id, layer_name, revision, version, backend) + + +def _resolve_value(value: Any, device: str) -> Any | None: + """Resolve a mapping value against the selected device. + + - ``dict``: device-conditional; recurse into ``value[device]`` or return None. + - anything else (including ``HubRef``): pass through. + """ + if isinstance(value, dict): + if device not in value: + return None + return _resolve_value(value[device], device) + return value + + +def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: + """Rewrite ``__class__`` of every module whose exact type is ``target_cls``. + + Uses ``type(m) is target_cls`` (not ``isinstance``) so user-defined + subclasses of ``target_cls`` are deliberately left alone. + """ + for m in model.modules(): + if type(m) is target_cls: + m.__class__ = impl_cls + + +def _replace_attr(dotted_path: str, impl) -> None: + """``setattr`` ``impl`` onto the attribute identified by the dotted path. + + Supports two forms: + - ``pkg.mod.attr`` (set module attribute) + - ``pkg.mod.ClassName.attr`` (set class attribute / method) + + The split is found by walking the prefix from the longest importable + module backwards until ``importlib.import_module`` succeeds. + """ + parts = dotted_path.split('.') + if len(parts) < 2: + raise ValueError(f"Expected at least 'pkg.attr', got: {dotted_path!r}") + + # Find the longest prefix that imports as a module. + last_err: ImportError | None = None + module = None + module_depth = 0 + for i in range(len(parts) - 1, 0, -1): + candidate = '.'.join(parts[:i]) + try: + module = importlib.import_module(candidate) + module_depth = i + break + except ImportError as e: + last_err = e + continue + if module is None: + raise ImportError(f'Could not import any prefix of {dotted_path!r}') from last_err + + # Walk remaining attributes; the last one is the target. + obj = module + for attr in parts[module_depth:-1]: + obj = getattr(obj, attr) + setattr(obj, parts[-1], impl) + + +def _load_hub_ref(ref: HubRef): + """Lazy-load a Hub kernel layer via the optional ``kernels`` package.""" + try: + from kernels import get_kernel + except ImportError as e: + raise ImportError('Loading a Hub kernel requires the `kernels` package. ' + 'Install it with `pip install kernels`.') from e + + kernel = get_kernel( + ref.repo_id, + revision=ref.revision, + version=ref.version, + backend=ref.backend, + ) + layers = getattr(kernel, 'layers', None) + if layers is None: + raise ValueError(f'Hub repo {ref.repo_id!r} does not define any layers.') + impl = getattr(layers, ref.layer_name, None) + if impl is None: + raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') + return impl + + +def kernelize(model: nn.Module, mapping: dict) -> nn.Module: + """Apply ``mapping`` to ``model`` and return it (modified in place). + + Keys: + - ``type[nn.Module]``: replace ``m.__class__`` for every module of the + exact type (no subclass walking). + - ``str`` (dotted path ``pkg.mod.attr``): ``setattr`` the impl onto the + identified module attribute. + + Values: + - ``dict[str, V]``: device-conditional dispatch using the current + Twinkle platform device prefix; non-matching devices skip. + - ``HubRef``: lazy-resolved via the optional ``kernels`` package. + - anything else: used directly as the impl. + """ + if not mapping: + return model + + device = Platform.device_prefix() + for key, value in mapping.items(): + impl = _resolve_value(value, device) + if impl is None: + continue + if isinstance(impl, HubRef): + impl = _load_hub_ref(impl) + if isinstance(key, type) and issubclass(key, nn.Module): + _replace_class(model, key, impl) + elif isinstance(key, str): + _replace_attr(key, impl) + else: + raise TypeError(f'Unsupported mapping key: {key!r}') + return model diff --git a/src/twinkle/kernel/function.py b/src/twinkle/kernel/function.py deleted file mode 100644 index 94a2d817d..000000000 --- a/src/twinkle/kernel/function.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING, Callable, Iterable, List, Optional - -from twinkle import get_logger -from .base import ModeType, is_kernels_available, validate_device_type, validate_mode -from .registry import FunctionKernelSpec, get_global_function_registry - -if TYPE_CHECKING: - from kernels.layer.func import FuncRepositoryProtocol - -logger = get_logger() - - -def _load_from_hub( - *, - repo: FuncRepositoryProtocol | None, - repo_id: str | None, - revision: str | None, - version: str | None, - func_name: str, -) -> tuple[Callable, object]: - """Resolve function implementation from a repo or Hub repo_id.""" - if repo is not None: - module_cls = repo.load() - module_instance = module_cls() - - def impl(*args, **kwargs): - return module_instance(*args, **kwargs) - - return impl, module_instance - - from kernels._versions import select_revision_or_version - from kernels.utils import get_kernel - assert repo_id is not None - # kernels API changed across versions; use keyword args for modern API - # and fall back to repo_id-only for older variants. - try: - resolved = select_revision_or_version(repo_id, revision=revision, version=version) - except TypeError: - resolved = select_revision_or_version(repo_id) - try: - kernel = get_kernel(repo_id, revision=resolved) - except TypeError: - kernel = get_kernel(repo_id, resolved) - func = getattr(kernel, func_name, None) - if func is None: - raise AttributeError(f'Kernel repo {repo_id} does not export {func_name}.') - return func, func - - -def register_function_kernel( - *, - func_name: str, - target_module: str, - func_impl: Callable | None = None, - repo: FuncRepositoryProtocol | None = None, - repo_id: str | None = None, - revision: str | None = None, - version: str | None = None, - device: str | None = None, - mode: ModeType | None = None, -) -> None: - """Register a function kernel with the registry.""" - sources = [func_impl is not None, repo is not None, repo_id is not None] - if sum(sources) != 1: - raise ValueError('Provide exactly one of func_impl, repo, or repo_id.') - if revision is not None and version is not None: - raise ValueError('Either revision or version must be specified, not both.') - if mode is not None: - validate_mode(mode) - - get_global_function_registry().register( - FunctionKernelSpec( - func_name=func_name, - target_module=target_module, - func_impl=func_impl, - repo=repo, - repo_id=repo_id, - revision=revision, - version=version, - device=device, - mode=mode, - )) - - -def register_function_batch(function_registry: Iterable[dict]) -> None: - """Batch register function kernels from a list of spec dicts.""" - for spec in function_registry: - register_function_kernel( - func_name=spec['func_name'], - target_module=spec['target_module'], - func_impl=spec.get('func_impl'), - repo=spec.get('repo'), - repo_id=spec.get('repo_id'), - revision=spec.get('revision'), - version=spec.get('version'), - device=spec.get('device'), - mode=spec.get('mode'), - ) - - -def apply_function_kernel( - *, - target_module: str | None = None, - device: str | None = None, - mode: ModeType | None = None, - strict: bool = False, -) -> list[str]: - """Apply registered function kernels by monkey-patching target modules. - target_module: If specified, only apply kernels targeting this module. - device: If specified, only apply kernels matching this device or with no device. - mode: If specified, only apply kernels matching this mode or with no mode. - strict: If True, raise errors on failures; otherwise log warnings. - """ - applied = [] - if device is not None: - validate_device_type(device) - - for spec in get_global_function_registry().list_specs(): - # Filter by target module and device/mode constraints. - if target_module is not None and spec.target_module != target_module: - continue - if device is not None and spec.device is not None and spec.device != device: - continue - if spec.mode is not None and mode is None: - msg = ('Function kernel registered with mode but apply_function_kernel ' - 'was called without mode; skipping.') - if strict: - raise ValueError(msg) - logger.warning(msg) - continue - if spec.mode is not None and mode is not None and spec.mode != mode: - continue - - try: - # Import the module that will be monkey-patched. - module = importlib.import_module(spec.target_module) - except Exception as exc: - if strict: - raise - logger.warning( - 'Failed to import target module %s: %s', - spec.target_module, - exc, - ) - continue - - # Resolve implementation and capability target for mode checks. - if spec.func_impl is not None: - impl = spec.func_impl - else: - if not is_kernels_available(): - msg = ('HF kernels package not available. ' - f'Cannot load function kernel: {spec.func_name}. ' - 'Install it with `pip install kernels`.') - raise RuntimeError(msg) - impl, _ = _load_from_hub( - repo=spec.repo, - repo_id=spec.repo_id, - revision=spec.revision, - version=spec.version, - func_name=spec.func_name, - ) - # Final patch (or reapply when no mode gating is used). - setattr(module, spec.func_name, impl) - applied.append(f'{spec.target_module}.{spec.func_name}') - - if strict and not applied: - raise ValueError('No function kernels applied for the given filters.') - - return applied diff --git a/src/twinkle/kernel/layer.py b/src/twinkle/kernel/layer.py deleted file mode 100644 index e47f73924..000000000 --- a/src/twinkle/kernel/layer.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Kernel module layer - Layer-level replacement with HF kernels integration.""" -from pathlib import Path -from typing import Any, Optional, Union - -from twinkle import Platform, get_logger -from .base import DeviceType, ModeType, is_kernels_available, is_kernels_enabled, to_kernels_mode -from .registry import get_global_layer_registry, register_layer - -logger = get_logger() - - -def register_layer_kernel( - kernel_name: str, - repo_id: Optional[str] = None, - repo_path: Optional[Union[str, Path]] = None, - package_name: Optional[str] = None, - layer_name: Optional[str] = None, - version: Optional[str] = None, - device: DeviceType = 'cuda', - mode: Optional[ModeType] = None, -) -> None: - """Register a layer kernel with the registry. - - Args: - kernel_name: Unique kernel name (can register multiple modes with same name) - repo_id: Hub repository ID - repo_path: Local repository path - package_name: Package name (required when using repo_path) - layer_name: Layer name (defaults to kernel_name) - version: Version constraint - device: Device type - mode: Mode (train/inference/compile), None means FALLBACK - """ - if not is_kernels_available(): - logger.warning(f'HF kernels package not available. Skipping registration for kernel: {kernel_name}') - return - - from kernels import LayerRepository, LocalLayerRepository - - if repo_path is not None: - if package_name is None: - raise ValueError(f'package_name must be provided when using repo_path for kernel: {kernel_name}') - if isinstance(repo_path, str): - repo_path = Path(repo_path) - repo_spec = LocalLayerRepository( - repo_path=repo_path, - package_name=package_name, - layer_name=layer_name or kernel_name, - ) - else: - if repo_id is None: - raise ValueError(f'Either repo_id or repo_path must be provided for kernel: {kernel_name}') - repo_spec = LayerRepository( - repo_id=repo_id, - layer_name=layer_name or kernel_name, - version=version, - ) - - hf_mode = _to_hf_mode(mode) - register_layer(kernel_name, repo_spec, device, mode=hf_mode) - - mode_str = mode or 'FALLBACK' - logger.info(f'Registered layer kernel: {kernel_name} for device: {device}, mode: {mode_str}') - - -def _to_hf_mode(mode: Optional[ModeType]) -> Any: - """Convert Twinkle mode to HF kernels Mode.""" - if mode is None: - from kernels import Mode - return Mode.FALLBACK - return to_kernels_mode(mode) - - -def apply_layer_kernel( - model, - mode: ModeType = 'inference', - device: Optional[DeviceType] = None, - use_fallback: bool = True, -) -> Any: - """Apply layer kernels to model. - - Args: - model: The PyTorch model to kernelize. - mode: The mode for kernel selection ("inference" or "train"). - device: The device type (auto-detected if None). - use_fallback: Whether to use original forward when no compatible kernel found. - If False, raises ValueError when kernel is unavailable. - - Returns: - The kernelized model. - """ - if not is_kernels_enabled(): - logger.debug('Kernels not enabled, returning original model') - return model - - get_global_layer_registry().sync_to_hf_kernels() - - if device is None: - device = Platform.get_platform().device_prefix() or 'cuda' - - kernel_mode = to_kernels_mode(mode) - - try: - from kernels import kernelize - logger.debug(f'Applying kernels with mode: {mode}, device: {device}, use_fallback: {use_fallback}') - return kernelize(model, mode=kernel_mode, device=device, use_fallback=use_fallback) - except Exception as e: - if use_fallback: - logger.warning(f'Failed to apply kernels: {e}. Returning original model.') - return model - raise - - -def register_layer_batch(mapping: dict, default_device: DeviceType = 'cuda') -> None: - """Batch register layer kernels.""" - for kernel_name, spec in mapping.items(): - device = spec.pop('device', default_device) - register_layer_kernel(kernel_name=kernel_name, device=device, **spec) diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py deleted file mode 100644 index 01e51b06f..000000000 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ /dev/null @@ -1,1009 +0,0 @@ -"""NPU monkey patches for Ascend hardware acceleration. - -Unified entry point:: - - >>> from twinkle.kernel.monkey_patch_npu import apply_npu_patch - >>> if Torch.is_npu_available(): - ... apply_npu_patch(model) -""" - -import importlib -import os -import torch -import torch.nn.functional as F -from torch import nn -from transformers.utils import is_torch_npu_available - -from twinkle import get_logger -from .causal_conv1d import npu_causal_conv1d_fn - -logger = get_logger() - -_is_torch_npu_available = is_torch_npu_available() -_NPU_PATCH_APPLIED = False - -if _is_torch_npu_available: - import torch_npu - -# --------------------------------------------------------------------------- -# Utils -# --------------------------------------------------------------------------- - - -def import_optional_module(module_name: str): - """Import a module, returning None if unavailable.""" - try: - return importlib.import_module(module_name) - except ImportError as exc: - logger.debug('Failed to import optional module %s: %s', module_name, exc) - return None - - -def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): - if isinstance(position_ids, int) and unsqueeze_dim == 1: - return position_ids - return unsqueeze_dim - - -def _is_ep_enabled(model=None) -> bool: - r"""Check whether Expert Parallelism (EP) is enabled. - - EP is detected via ``device_mesh.ep_size > 1``. - When EP is active, each rank holds only a subset of expert weights, - making ``npu_grouped_matmul`` efficient (small contiguous weights). - """ - device_mesh = getattr(model, 'device_mesh', None) - if device_mesh is None: - return False - return (getattr(device_mesh, 'ep_size', None) or 0) > 1 - - -# ============================================================================= -# Section 1: MoE Grouped MatMul (GMM) -# ============================================================================= - - -class GmmFunction(torch.autograd.Function): - r"""Custom autograd function for NPU grouped matrix multiplication.""" - - @staticmethod - def forward(ctx, x: torch.tensor, group_list: torch.tensor, weight_ekn: torch.tensor): - group_list = group_list.to(torch.int64) - ctx.save_for_backward(x, group_list, weight_ekn) - outputs = torch_npu.npu_grouped_matmul( - [x], - [weight_ekn], - group_list=group_list, - group_type=0, - split_item=2, - group_list_type=1, - ) - return outputs[0] - - @staticmethod - def backward(ctx, grad_output: torch.tensor): - x, group_list, weight_ekn = ctx.saved_tensors - grad_input = torch_npu.npu_grouped_matmul( - [grad_output], - [weight_ekn.transpose(-2, -1).contiguous()], - bias=None, - group_list=group_list, - group_type=0, - split_item=2, - group_list_type=1, - )[0] - grad_weight = torch_npu.npu_grouped_matmul( - [x.transpose(0, 1)], - [grad_output], - bias=None, - group_list=group_list, - group_type=2, - split_item=3, - group_list_type=1, - )[0] - return grad_input, None, grad_weight.contiguous() - - -def _grouped_mm_npu(input: torch.tensor, weight_ekn: torch.tensor, offs: torch.tensor) -> torch.tensor: - counts = torch.empty_like(offs) - counts[0] = offs[0] - if offs.numel() > 1: - counts[1:] = offs[1:] - offs[:-1] - counts = counts.to(torch.int64) - return GmmFunction.apply(input, counts, weight_ekn) - - -def _apply_hf_moe_grouped_mm_patch(model=None) -> None: - r"""Patch HuggingFace MoE integration to use NPU grouped matmul. - - When Expert Parallelism (EP) is **not** enabled, each rank holds **all** - expert weights. ``weight.transpose(-2, -1)`` then produces a large - non-contiguous view that ``npu_grouped_matmul`` forces to ``.contiguous()`` - (~12.88 GB per MoE layer), creating a bandwidth bottleneck that makes the - NPU patch **slower** than the native per-expert fallback (~8x overhead). - - Detection logic: - - ``TWINKLE_NPU_GMM_PATCH`` not set → **skip** the patch by default. - - ``TWINKLE_NPU_GMM_PATCH=1`` → EP-aware: apply only if EP is enabled - (each rank has few experts, weights are small and contiguous); - skip if EP is **not** enabled (avoid ~8x overhead). - - ``TWINKLE_NPU_GMM_PATCH=0`` → **disable** the patch regardless. - """ - moe_enabled = _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=False) - - if not moe_enabled: - has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') - logger.info( - '[PATCH] TWINKLE_NPU_GMM_PATCH not set: MoE GMM patch skipped by default. ' - 'Set TWINKLE_NPU_GMM_PATCH=1 to enable (EP-aware). ' - 'Native grouped_mm available: %s.', - has_native_gmm, - ) - return - - if not _is_ep_enabled(model): - has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') - logger.info( - '[PATCH] TWINKLE_NPU_GMM_PATCH=1 but EP not enabled (all experts on each rank) — ' - 'skipping _grouped_mm_npu patch to avoid ~8x overhead from ' - 'contiguous copies on transposed weights. ' - 'Native grouped_mm available: %s.', - has_native_gmm, - ) - return - - import transformers.integrations.moe as hf_moe - hf_moe._grouped_mm = _grouped_mm_npu - logger.info('[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu') - - -# ============================================================================= -# Section 1b: MoE Packed Experts -# ============================================================================= - - -def _normalize_packed_expert_weights(module, input_dtype: torch.dtype, hidden_dim: int): - """Normalize packed expert weight shapes for NPU grouped matmul.""" - gate_up_proj = module.gate_up_proj.to(input_dtype) - down_proj = module.down_proj.to(input_dtype) - - if gate_up_proj.shape[1] == hidden_dim: - gate_up_weight = gate_up_proj - elif gate_up_proj.shape[2] == hidden_dim: - gate_up_weight = gate_up_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported gate_up_proj shape for NPU MoE patch: {tuple(gate_up_proj.shape)}.') - - if down_proj.shape[2] == hidden_dim: - down_weight = down_proj - elif down_proj.shape[1] == hidden_dim: - down_weight = down_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported down_proj shape for NPU MoE patch: {tuple(down_proj.shape)}.') - - return gate_up_weight, down_weight - - -def _get_cached_expert_weights(self, target_dtype: torch.dtype, hidden_dim: int): - """Return normalized expert weights with automatic cache invalidation. - - Cache key combines (dtype, gate_version, down_version). This correctly - handles: - - Full-parameter training: optimizer in-place updates bump _version - - LoRA training: frozen weights keep _version stable, cache persists - - Inference: cache is permanent - - AMP autocast: separate cache per dtype - - Safety: when weights require gradients, the cache is bypassed to avoid - breaking the PyTorch autograd graph (non-leaf tensors from .to() cannot - be reused across forward passes). - """ - requires_grad = ( - getattr(self.gate_up_proj, 'requires_grad', False) or getattr(self.down_proj, 'requires_grad', False)) - cache_attr = '_npu_expert_cache' - if not requires_grad and hasattr(self, cache_attr): - cached_dtype, cached_gate_ver, cached_down_ver, cached = getattr(self, cache_attr) - if (cached_dtype == target_dtype and cached_gate_ver == self.gate_up_proj._version - and cached_down_ver == self.down_proj._version): - return cached - - weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) - if not requires_grad: - setattr( - self, - cache_attr, - (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights), - ) - return weights - - -def npu_packed_moe_experts_forward( - self, - hidden_states: torch.Tensor, - router_indices_or_routing_weights: torch.Tensor, - routing_weights_or_router_indices: torch.Tensor, -) -> torch.Tensor: - """Packed MoE experts forward using NPU grouped matmul. - - Compatible with Qwen3-MoE, Qwen3.5-MoE, and any model using packed experts - with the standard ``(hidden_states, router_indices, routing_weights)`` call convention. - """ - if router_indices_or_routing_weights.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: - router_indices = router_indices_or_routing_weights - routing_weights = routing_weights_or_router_indices - else: - routing_weights = router_indices_or_routing_weights - router_indices = routing_weights_or_router_indices - - output_shape = hidden_states.shape - hidden_dim = output_shape[-1] - hidden_states = hidden_states.reshape(-1, hidden_dim) - - if routing_weights.shape != router_indices.shape: - routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) - routing_weights = routing_weights.to(hidden_states.dtype) - router_indices = router_indices.to(torch.int32) - - permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) - tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) - - # Cached normalized weights: auto-invalidates on weight updates (full-param) - # and persists when frozen (LoRA / inference). - gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) - - intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, tokens_per_expert, gate_up_weight) - intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1) - output = GmmFunction.apply(intermediate_activations, tokens_per_expert, down_weight) - next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) - return next_states.view(*output_shape) - - -# ============================================================================= -# Section 1c: MoE Sparse Block -# ============================================================================= - - -def _topk_from_router_logits(module, hidden_states: torch.Tensor, router_logits: torch.Tensor): - """Compute top-k routing from router logits (Transformers 4.x style).""" - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) - if getattr(module, 'norm_topk_prob', True): - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - return routing_weights, router_indices - - -def _add_shared_expert(self, hidden_states: torch.Tensor, expert_output: torch.Tensor) -> torch.Tensor: - """Add shared expert output with sigmoid gating. - - Automatically skips if the module lacks shared_expert / shared_expert_gate. - """ - if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): - return expert_output - - shared_expert_output = self.shared_expert(hidden_states) - shared_expert_output = (F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output) - return expert_output + shared_expert_output - - -def _qwen3_5_moe_forward_transformers_5(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, - selected_experts: torch.Tensor) -> torch.Tensor: - """Transformers 5.x path: gate returns (router_logits, routing_weights, selected_experts).""" - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - expert_output = self.experts(hidden_states, selected_experts, routing_weights) - expert_output = _add_shared_expert(self, hidden_states, expert_output) - return expert_output.reshape(batch_size, sequence_length, hidden_dim) - - -def _qwen3_5_moe_forward_linear_gate(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: - """Transformers 4.x path: gate is nn.Linear and returns router logits.""" - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - routing_weights, selected_experts = _topk_from_router_logits(self, hidden_states, router_logits) - expert_output = self.experts(hidden_states, selected_experts, routing_weights) - expert_output = _add_shared_expert(self, hidden_states, expert_output) - return expert_output.reshape(batch_size, sequence_length, hidden_dim) - - -def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """NPU-accelerated SparseMoeBlock forward with dual Transformers version support.""" - hidden_dim = hidden_states.shape[-1] - gate_output = self.gate(hidden_states.view(-1, hidden_dim)) - - if isinstance(gate_output, tuple): - _, routing_weights, selected_experts = gate_output - return _qwen3_5_moe_forward_transformers_5(self, hidden_states, routing_weights, selected_experts) - - return _qwen3_5_moe_forward_linear_gate(self, hidden_states, gate_output) - - -# ============================================================================= -# Section 2: Fused Operators (RMSNorm / RoPE / SwiGLU / SDPA) -# ============================================================================= - - -class NpuRMSNorm(nn.Module): - r"""Fused RMSNorm via ``torch_npu.npu_rms_norm``.""" - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - # Detect residual parameterization (e.g. Qwen3.5: scale = 1.0 + weight) - # once at initialization to avoid CPU-synchronizing Tensor.item() calls. - self._residual_param = abs(self.weight.data.mean().item()) < 0.3 - if self._residual_param: - logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') - - def _get_effective_weight(self, target_dtype: torch.dtype): - if self._residual_param: - return (1.0 + self.weight).to(dtype=target_dtype) - return self.weight.to(dtype=target_dtype) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - scale = self._get_effective_weight(hidden_states.dtype) - return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self.variance_epsilon)[0] - - def extra_repr(self) -> str: - return f'{tuple(self.weight.shape)}, eps={self.variance_epsilon}' - - -def npu_gated_rms_norm_forward(self, hidden_states, gate=None): - """NPU forward for Gated RMSNorm. - - The FP32 mode is controlled by ``TWINKLE_NPU_GATED_RMSNorm_FP32``, - resolved once during patching and stored in ``self._twinkle_force_fp32``. - """ - input_dtype = hidden_states.dtype - _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - - # Read the cached flag; no env lookup in the hot path. - force_fp32 = getattr(self, '_twinkle_force_fp32', False) - if force_fp32: - hidden_states = hidden_states.to(torch.float32) - weight = self.weight.float() - gate = gate.to(torch.float32) if gate is not None else None - else: - weight = self.weight - - hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] - - if gate is not None: - hidden_states = hidden_states * F.silu(gate) - - return hidden_states.to(input_dtype) - - -def _make_apply_npu_rotary_emb(): - _cached_partial = {} - - def _apply_npu_rotary_emb(q, k, cos, sin): - rotary_dim = cos.shape[-1] - query_dim = q.shape[-1] - shape_key = (rotary_dim, query_dim) - - use_partial = _cached_partial.get(shape_key) - if use_partial is None: - use_partial = rotary_dim < query_dim - _cached_partial[shape_key] = use_partial - - if use_partial: - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - else: - q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) - - return q_embed, k_embed - - return _apply_npu_rotary_emb - - -_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() - - -def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Fused RoPE via ``torch_npu.npu_rotary_mul`` with automatic Partial RoPE support.""" - unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) - - -def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Multimodal RoPE for Qwen2.5-VL with automatic Partial RoPE support.""" - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) - - -def npu_swiglu_forward(self, hidden_state): - """Fused SwiGLU (Qwen-style).""" - return self.down_proj( - torch_npu.npu_swiglu( - torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), - dim=-1, - )) - - -def npu_sdpa_attention_forward(module, - query, - key, - value, - attention_mask, - dropout=0.0, - scaling=None, - is_causal=None, - **kwargs): - r"""SDPA with NPU compatibility fixes.""" - from transformers.integrations.sdpa_attention import repeat_kv - if hasattr(module, 'num_key_value_groups'): - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, :key.shape[-2]] - - query, key, value = query.contiguous(), key.contiguous(), value.contiguous() - - if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None - - if causal_mask is not None and causal_mask.dtype != torch.bool: - causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=dropout, - scale=scaling, - is_causal=is_causal, - ) - return attn_output.transpose(1, 2).contiguous(), None - - -# ============================================================================= -# Section 2c: Flash Linear Attention (FLA) for Qwen3.5 -# ============================================================================= - - -def _patch_qwen3_5_fla(model=None) -> None: - """Enable Flash Linear Attention (FLA) fast path for Qwen3.5 on NPU. - - Controlled by environment variable ``TWINKLE_NPU_FLA`` (default: True). - """ - if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): - logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA environment variable') - return - - if not _is_torch_npu_available: - logger.info('[NPU] [FLA] Skip: NPU not available') - return - - # 1. Force FLA availability flag - def _is_fla_available() -> bool: - return True - - for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): - try: - utils_mod = importlib.import_module(utils_mod_name) - setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) - logger.info( - '[NPU] [FLA] Patched %s.is_flash_linear_attention_available', - utils_mod_name, - ) - except Exception as exc: - logger.debug('[NPU] [FLA] Failed to patch %s: %s', utils_mod_name, exc) - - # 2. Try MindSpeed Triton FLA backend - mindspeed_fla = None - try: - from .chunk_gated_delta_rule import chunk_gated_delta_rule as _ms_fla - mindspeed_fla = _ms_fla - logger.info('[NPU] [FLA] MindSpeed Triton chunk_gated_delta_rule loaded') - except ImportError as exc: - logger.warning('[NPU] [FLA] MindSpeed not available: %s', exc) - - # 3. Patch Qwen3.5 modeling modules - fla_target_modules = [ - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - ] - - for module_name in fla_target_modules: - module = import_optional_module(module_name) - if module is None: - logger.info('[NPU] [FLA] %s: module not found, skip', module_name) - continue - - # Only enable FLA flags if we actually have a backend to serve it - if mindspeed_fla is not None: - setattr(module, 'is_flash_linear_attention_available', _is_fla_available) - setattr(module, 'is_fast_path_available', True) - - # Disable CUDA-only fused op - if hasattr(module, 'FusedRMSNormGated'): - setattr(module, 'FusedRMSNormGated', None) - logger.info('[NPU] [FLA] %s: disabled FusedRMSNormGated', module_name) - - # Replace chunk_gated_delta_rule with MindSpeed implementation - setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) - logger.info( - '[NPU] [FLA] Patched %s.chunk_gated_delta_rule -> MindSpeed', - module_name, - ) - else: - logger.warning( - '[NPU] [FLA] %s: MindSpeed unavailable, FLA flags NOT set', - module_name, - ) - - # 4. Traverse instantiated model and replace per-layer chunk_gated_delta_rule - if model is not None and mindspeed_fla is not None: - # Resolve the underlying PyTorch model from TransformersModel wrapper - model = getattr(model, 'model', getattr(model, 'module', model)) - if not hasattr(model, 'named_modules'): - logger.warning('[NPU] [FLA] Model does not support named_modules, skipping instance patch') - return - patched_instances = 0 - patched_causal = 0 - for _name, _module in model.named_modules(): - if hasattr(_module, 'chunk_gated_delta_rule') and callable(getattr(_module, 'chunk_gated_delta_rule')): - if _module.chunk_gated_delta_rule is mindspeed_fla: - continue - - _module.chunk_gated_delta_rule = mindspeed_fla - # Mark as NPU-patched to prevent it from being overwritten by SP - _module._twinkle_npu_patched = True - patched_instances += 1 - logger.debug( - '[NPU] [FLA] Replaced %s(%s).chunk_gated_delta_rule -> MindSpeed', - _name, - type(_module).__name__, - ) - - if hasattr(_module, 'causal_conv1d_fn'): - current = getattr(_module, 'causal_conv1d_fn') - - if current is npu_causal_conv1d_fn: - continue - _module.causal_conv1d_fn = npu_causal_conv1d_fn - patched_causal += 1 - logger.debug( - '[NPU] [FLA] Replaced %s(%s).causal_conv1d_fn (was %s) -> MindSpeed', - _name, - type(_module).__name__, - current, - ) - - if patched_instances > 0: - logger.info( - '[NPU] [FLA] Patched %d linear attention instance(s)', - patched_instances, - ) - if patched_causal > 0: - logger.info( - '[NPU] [FLA] Patched %d causal_conv1d instance(s)', - patched_causal, - ) - else: - logger.info('[NPU] [FLA] No causal_conv1d_fn instances found in model') - - -# ============================================================================= -# Section 3: Patching Helpers -# ============================================================================= - - -def _patch_sdpa_forward() -> None: - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface - AttentionInterface._global_mapping['sdpa'] = npu_sdpa_attention_forward - ALL_ATTENTION_FUNCTIONS['sdpa'] = npu_sdpa_attention_forward - logger.debug('[NPU] [SDPA] Patched global SDPA attention forward') - - -def _patch_rmsnorm(module, class_name: str) -> None: - """Patch RMSNorm class with NPU-optimized implementation.""" - if 'Gated' in class_name: - orig_cls = getattr(module, class_name) - setattr(orig_cls, 'forward', npu_gated_rms_norm_forward) - - # Cache the FP32 env flag once at patch time to avoid per-forward overhead. - orig_cls._twinkle_force_fp32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', - '0').lower() in ('1', 'true', 'on', 'yes') - if orig_cls._twinkle_force_fp32: - logger.info( - '[NPU] [RMSNorm] %s.%s forced to FP32 mode', - module.__name__, - class_name, - ) - - logger.info( - '[NPU] [RMSNorm] Patched %s.%s.forward -> npu_gated_rms_norm_forward', - module.__name__, - class_name, - ) - else: - setattr(module, class_name, NpuRMSNorm) - logger.info( - '[NPU] [RMSNorm] Patched %s.%s -> NpuRMSNorm', - module.__name__, - class_name, - ) - - -def _patch_rope(module, func_name: str) -> None: - setattr(module, func_name, npu_apply_rotary_pos_emb) - logger.debug( - '[NPU] [RoPE] Patched %s.%s -> npu_apply_rotary_pos_emb', - module.__name__, - func_name, - ) - - -def _patch_swiglu(module, class_name: str) -> None: - setattr(getattr(module, class_name), 'forward', npu_swiglu_forward) - logger.debug( - '[NPU] [MLP] Patched %s.%s.forward -> npu_swiglu_forward', - module.__name__, - class_name, - ) - - -def _patch_moe_sparse_block(module, class_name: str) -> None: - """Patch SparseMoeBlock forward with NPU-optimized implementation.""" - setattr(getattr(module, class_name), 'forward', npu_qwen3_5_moe_sparse_block_forward) - logger.info( - '[NPU] [MoE] Patched %s.%s.forward -> npu_qwen3_5_moe_sparse_block_forward', - module.__name__, - class_name, - ) - - -def _patch_moe_experts(module, class_name: str) -> None: - """Patch packed Experts forward with NPU grouped matmul.""" - setattr(getattr(module, class_name), 'forward', npu_packed_moe_experts_forward) - logger.debug( - '[NPU] [MoE] Patched %s.%s.forward -> npu_packed_moe_experts_forward', - module.__name__, - class_name, - ) - - -# ============================================================================= -# Section 4: Environment Control -# ============================================================================= - - -def _is_env_enabled(var_name: str, default: bool = True) -> bool: - """Check whether an environment variable is enabled. - - Supports: ``1``/``true``/``on``/``yes`` (force on), - ``0``/``false``/``off``/``no`` (force off), - unset (use ``default``). - """ - env = os.environ.get(var_name, '').lower().strip() - if not env: - return default - if env in ('1', 'true', 'on', 'yes'): - return True - if env in ('0', 'false', 'off', 'no'): - logger.info('[NPU] %s=%s: disabled.', var_name, env) - return False - return default - - -# ============================================================================= -# Section 5: Unified Patching Logic (Fused Ops) -# ============================================================================= - - -def _apply_all_fused_ops(model=None) -> None: - """Apply fused ops to supported model families.""" - logger.info('[NPU] === _apply_all_fused_ops ENTERED ===') - if not _is_torch_npu_available: - return - - if not _is_env_enabled('TWINKLE_NPU_FUSED_OPS', default=True): - return - - target_archs = set() - if model is not None: - config = getattr(model, 'hf_config', getattr(model, 'config', None)) - archs = getattr(config, 'architectures', None) if config else None - if archs: - target_archs = set(archs) - logger.debug('[NPU] Detected architectures for fused ops: %s', archs) - - logger.info('[NPU] Auto-applying fused ops to supported model families') - - _patch_sdpa_forward() - - model_families = [ - ('transformers.models.qwen3.modeling_qwen3', 'Qwen3', 'Qwen3MLP', 'Qwen3ForCausalLM'), - ('transformers.models.qwen3_moe.modeling_qwen3_moe', 'Qwen3Moe', 'Qwen3MoeMLP', 'Qwen3MoeForCausalLM'), - ( - 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', - 'Qwen2_5_VL', - 'Qwen2MLP', - 'Qwen2_5_VLForConditionalGeneration', - ), - ( - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - 'Qwen3_5Moe', - 'Qwen3_5MoeMLP', - 'Qwen3MoeForCausalLM', - ), - ] - - modeling_qwen3_5 = import_optional_module('transformers.models.qwen3_5.modeling_qwen3_5') - if modeling_qwen3_5 is not None: - model_families.append(( - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'Qwen3_5', - 'Qwen3_5MLP', - 'Qwen3_5ForCausalLM', - )) - - modeling_qwen3_5_moe = import_optional_module('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe') - if modeling_qwen3_5_moe is not None: - model_families.append(( - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - 'Qwen3_5Moe', - 'Qwen3_5MoeMLP', - 'Qwen3_5MoeForCausalLM', - )) - - patched_count = 0 - for module_name, prefix, mlp_name, trigger_arch in model_families: - try: - module = importlib.import_module(module_name) - - # RMSNorm - rmsnorm_cls = f'{prefix}RMSNorm' - if hasattr(module, rmsnorm_cls): - _patch_rmsnorm(module, rmsnorm_cls) - patched_count += 1 - - # RoPE - if hasattr(module, 'apply_rotary_pos_emb'): - _patch_rope(module, 'apply_rotary_pos_emb') - patched_count += 1 - - # SwiGLU / MLP - if hasattr(module, mlp_name): - _patch_swiglu(module, mlp_name) - patched_count += 1 - - experts_cls = f'{prefix}Experts' - if hasattr(module, experts_cls): - _patch_moe_experts(module, experts_cls) - patched_count += 1 - - sparse_cls = f'{prefix}SparseMoeBlock' - if hasattr(module, sparse_cls): - _patch_moe_sparse_block(module, sparse_cls) - patched_count += 1 - - if prefix == 'Qwen2_5_VL': - if hasattr(module, 'Qwen2_5_VLMLP'): - _patch_swiglu(module, 'Qwen2_5_VLMLP') - patched_count += 1 - setattr(module, 'apply_multimodal_rotary_pos_emb', npu_apply_multimodal_rotary_pos_emb) - logger.debug('[NPU] Patched Qwen2_5_VL multimodal RoPE') - - if prefix == 'Qwen3_5': - gated_rmsnorm_cls = f'{prefix}GatedRMSNorm' - if hasattr(module, gated_rmsnorm_cls): - _patch_rmsnorm(module, gated_rmsnorm_cls) - patched_count += 1 - if hasattr(module, 'Qwen3_5VisionMLP'): - _patch_swiglu(module, 'Qwen3_5VisionMLP') - patched_count += 1 - if hasattr(module, 'Qwen3_5VisionRMSNorm'): - _patch_rmsnorm(module, 'Qwen3_5VisionRMSNorm') - patched_count += 1 - - if prefix == 'Qwen3_5Moe': - if hasattr(module, 'Qwen3_5MoeGatedRMSNorm'): - _patch_rmsnorm(module, 'Qwen3_5MoeGatedRMSNorm') - patched_count += 1 - - logger.debug('[NPU] Patched %s fused ops', prefix) - except ImportError: - pass - - if not target_archs: - patched_count += _discover_and_patch_unknown_models() - - _patch_qwen3_5_fla(model) - - logger.info('[NPU] Auto-patched %d components', patched_count) - - -# ============================================================================= -# Section 5b: Dynamic model discovery (no hard-coding) -# ============================================================================= - - -def _discover_and_patch_unknown_models() -> int: - """Dynamically discover and patch additional transformers model families.""" - patched = 0 - already_patched_modules = { - 'transformers.models.qwen3.modeling_qwen3', - 'transformers.models.qwen3_moe.modeling_qwen3_moe', - 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - } - - try: - import transformers.models as models_pkg - except ImportError: - return 0 - - candidate_modules = [] - for model_name in dir(models_pkg): - if model_name.startswith('_'): - continue - modeling_path = f'transformers.models.{model_name}.modeling_{model_name}' - if modeling_path not in already_patched_modules: - candidate_modules.append(modeling_path) - - for module_name in candidate_modules: - module = import_optional_module(module_name) - if module is None: - continue - - has_rmsnorm = any('RMSNorm' in attr_name and isinstance(getattr(module, attr_name, None), type) - for attr_name in dir(module)) - has_rope = hasattr(module, 'apply_rotary_pos_emb') - has_mlp = any( - attr_name.endswith('MLP') and isinstance(getattr(module, attr_name, None), type) - for attr_name in dir(module)) - - if not (has_rmsnorm or has_rope or has_mlp): - continue - - for attr_name in dir(module): - if attr_name.startswith('_'): - continue - obj = getattr(module, attr_name, None) - if not isinstance(obj, type): - continue - - if 'RMSNorm' in attr_name and issubclass(obj, nn.Module): - try: - _patch_rmsnorm(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if attr_name.endswith('MLP') and hasattr(obj, 'forward'): - try: - _patch_swiglu(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if attr_name.endswith('Experts') and hasattr(obj, 'forward'): - try: - _patch_moe_experts(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if attr_name.endswith('SparseMoeBlock') and hasattr(obj, 'forward'): - try: - _patch_moe_sparse_block(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if has_rope: - try: - _patch_rope(module, 'apply_rotary_pos_emb') - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.apply_rotary_pos_emb: %s', module_name, exc) - - if patched > 0: - logger.debug('[NPU] Dynamically patched %s', module_name) - - return patched - - -# ============================================================================= -# Section 6: Public API -# ============================================================================= - - -def apply_npu_patch(model=None) -> None: - """Apply all NPU patches. - - Ascend NPU optimizations applied: - - MoE grouped_matmul (GMM) - - RMSNorm fused kernel - - RoPE fused kernel - - SwiGLU fused kernel - - SDPA Attention compatibility fixes - - Flash Linear Attention (FLA) for Qwen3.5 - - Causal Conv1D Triton kernel for linear attention - - When ``model`` is **not** provided, the GMM patch is **skipped** by default - (EP cannot be detected without a model instance). - - When ``model`` is provided, the GMM patch is evaluated with EP detection: - - EP enabled → apply GMM patch (efficient on small sharded weights). - - EP not enabled → skip GMM patch (avoid ~8x contiguous-copy overhead). - - Environment variables: - - ``TWINKLE_NPU_PATCH``: overall switch (``1``/``0``) - - ``TWINKLE_NPU_FUSED_OPS``: fused ops switch (``1``/``0``) - - ``TWINKLE_NPU_GMM_PATCH``: MoE GMM switch (``1``/``0``/unset). - When unset: skip the patch by default. - When ``1``: EP-aware — patch is applied **only if EP is enabled**; - without EP the native grouped_mm or per-expert fallback is used - (avoiding ~8x overhead from contiguous copies). - When ``0``: disable the patch regardless. - - ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``) - - ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``) - - Args: - model: Optional model instance. If not provided, GMM patch is skipped. - If provided, GMM patch is evaluated with EP detection on the model. - """ - global _NPU_PATCH_APPLIED - - if not _is_env_enabled('TWINKLE_NPU_PATCH', default=True): - return - - if _NPU_PATCH_APPLIED: - logger.debug('[NPU] Patches already applied, skipping.') - return - - try: - import torch_npu - except ImportError: - logger.warning('torch_npu not available. Skipping NPU patches.') - return - - _apply_hf_moe_grouped_mm_patch(model) - - _apply_all_fused_ops(model) - - _NPU_PATCH_APPLIED = True - logger.info('[NPU] All patches applied successfully') - - -def register_npu_fused_function_kernels() -> None: - """Register NPU fused ops as Twinkle function kernels (optional).""" - if not _is_torch_npu_available: - return - - from .function import register_function_kernel - - register_function_kernel( - func_name='apply_rotary_pos_emb', - target_module='transformers.modeling_rope_utils', - func_impl=npu_apply_rotary_pos_emb, - device='npu', - mode='train', - ) - register_function_kernel( - func_name='sdpa_attention_forward', - target_module='transformers.integrations.sdpa_attention', - func_impl=npu_sdpa_attention_forward, - device='npu', - mode='train', - ) - logger.info('[NPU] Registered fused function kernels for training') diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py new file mode 100644 index 000000000..47d2a0bfa --- /dev/null +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Per-layer NPU implementations consumed by ``npu_builtin()``. + +Each impl is contracted to be applied via ``m.__class__ = ImplCls`` (class +replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl +here is meant to be instantiated directly. +""" +from .attention import npu_sdpa_attention_forward +from .fla import apply_qwen3_5_fla +from .moe import GmmFunction, npu_grouped_mm, npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward +from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward +from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb +from .swiglu import npu_swiglu_forward + +__all__ = [ + 'NpuRMSNorm', + 'npu_gated_rms_norm_forward', + 'npu_apply_rotary_pos_emb', + 'npu_apply_multimodal_rotary_pos_emb', + 'npu_swiglu_forward', + 'npu_sdpa_attention_forward', + 'GmmFunction', + 'npu_grouped_mm', + 'npu_packed_moe_experts_forward', + 'npu_qwen3_5_moe_sparse_block_forward', + 'apply_qwen3_5_fla', +] diff --git a/src/twinkle/kernel/npu_impls/attention.py b/src/twinkle/kernel/npu_impls/attention.py new file mode 100644 index 000000000..c63a858f1 --- /dev/null +++ b/src/twinkle/kernel/npu_impls/attention.py @@ -0,0 +1,54 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SDPA forward with Ascend NPU compatibility fixes.""" +from __future__ import annotations + +import torch + + +def npu_sdpa_attention_forward( + module, + query, + key, + value, + attention_mask, + dropout=0.0, + scaling=None, + is_causal=None, + **kwargs, +): + """Drop-in replacement for ``transformers.integrations.sdpa_attention.sdpa_attention_forward``. + + Fixes: + - Repeats KV heads (NPU SDPA does not auto-broadcast num_kv_groups). + - Truncates causal_mask to key length. + - Forces contiguous tensors (NPU SDPA requirement). + - Inverts boolean masks (NPU treats ``True`` as masked). + """ + from transformers.integrations.sdpa_attention import repeat_kv + + if hasattr(module, 'num_key_value_groups'): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, :key.shape[-2]] + + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() + + if is_causal is None: + is_causal = query.shape[2] > 1 and causal_mask is None + + if causal_mask is not None and causal_mask.dtype != torch.bool: + causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + return attn_output.transpose(1, 2).contiguous(), None diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py new file mode 100644 index 000000000..847832f3a --- /dev/null +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -0,0 +1,102 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Qwen3.5 Flash Linear Attention enablement for Ascend NPU.""" +from __future__ import annotations + +import importlib +import os + +from twinkle import get_logger + +logger = get_logger() + + +def _is_env_enabled(var: str, default: bool = True) -> bool: + env = os.environ.get(var, '').lower().strip() + if not env: + return default + if env in ('1', 'true', 'on', 'yes'): + return True + if env in ('0', 'false', 'off', 'no'): + return False + return default + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def apply_qwen3_5_fla(model=None) -> int: + """Enable Flash Linear Attention fast path for Qwen3.5 on NPU. + + Returns the count of patched per-layer instances (0 when disabled or when + prerequisites are missing). Safe to call multiple times. + """ + if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): + logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA') + return 0 + + if _import_optional('torch_npu') is None: + logger.info('[NPU] [FLA] Skip: torch_npu unavailable') + return 0 + + # 1. Confirm the MindSpeed Triton kernel is actually importable BEFORE + # flipping any global availability flags. If we flip the flag and then + # fail to install the kernel, HF transformers would route Qwen3.5 onto + # a FLA fast path whose kernel is missing -> runtime failure on NPU. + try: + from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn + from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla + except ImportError as exc: + logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) + return 0 + + # 2. Only now can we safely claim FLA is available: flip the global flags + # and install the kernel path on Qwen3.5 modeling modules. + def _is_fla_available() -> bool: + return True + + for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): + utils_mod = _import_optional(utils_mod_name) + if utils_mod is not None: + setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) + + # 3. Patch Qwen3.5 modeling modules + fla_target_modules = [ + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + ] + for module_name in fla_target_modules: + module = _import_optional(module_name) + if module is None: + continue + setattr(module, 'is_flash_linear_attention_available', _is_fla_available) + setattr(module, 'is_fast_path_available', True) + if hasattr(module, 'FusedRMSNormGated'): + setattr(module, 'FusedRMSNormGated', None) + setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) + + # 4. Traverse model and patch per-layer attributes + if model is None: + return 0 + + root = getattr(model, 'model', getattr(model, 'module', model)) + if not hasattr(root, 'named_modules'): + return 0 + + patched_instances = 0 + for _name, _module in root.named_modules(): + if hasattr(_module, 'chunk_gated_delta_rule') and callable(getattr(_module, 'chunk_gated_delta_rule')): + if _module.chunk_gated_delta_rule is not mindspeed_fla: + _module.chunk_gated_delta_rule = mindspeed_fla + _module._twinkle_npu_patched = True + patched_instances += 1 + if hasattr(_module, 'causal_conv1d_fn'): + if getattr(_module, 'causal_conv1d_fn') is not npu_causal_conv1d_fn: + _module.causal_conv1d_fn = npu_causal_conv1d_fn + + if patched_instances: + logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) + return patched_instances diff --git a/src/twinkle/kernel/npu_impls/moe.py b/src/twinkle/kernel/npu_impls/moe.py new file mode 100644 index 000000000..1f847669b --- /dev/null +++ b/src/twinkle/kernel/npu_impls/moe.py @@ -0,0 +1,159 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""MoE GMM + packed-experts + sparse-block impls for Ascend NPU.""" +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +class GmmFunction(torch.autograd.Function): + """Custom autograd function for NPU grouped matrix multiplication.""" + + @staticmethod + def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): + import torch_npu + group_list = group_list.to(torch.int64) + ctx.save_for_backward(x, group_list, weight_ekn) + outputs = torch_npu.npu_grouped_matmul( + [x], + [weight_ekn], + group_list=group_list, + group_type=0, + split_item=2, + group_list_type=1, + ) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + import torch_npu + x, group_list, weight_ekn = ctx.saved_tensors + grad_input = torch_npu.npu_grouped_matmul( + [grad_output], + [weight_ekn.transpose(-2, -1).contiguous()], + bias=None, + group_list=group_list, + group_type=0, + split_item=2, + group_list_type=1, + )[0] + grad_weight = torch_npu.npu_grouped_matmul( + [x.transpose(0, 1)], + [grad_output], + bias=None, + group_list=group_list, + group_type=2, + split_item=3, + group_list_type=1, + )[0] + return grad_input, None, grad_weight.contiguous() + + +def npu_grouped_mm(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: + """Drop-in replacement for ``transformers.integrations.moe._grouped_mm``.""" + counts = torch.empty_like(offs) + counts[0] = offs[0] + if offs.numel() > 1: + counts[1:] = offs[1:] - offs[:-1] + counts = counts.to(torch.int64) + return GmmFunction.apply(input, counts, weight_ekn) + + +def _normalize_packed_expert_weights(module, input_dtype, hidden_dim): + gate_up_proj = module.gate_up_proj.to(input_dtype) + down_proj = module.down_proj.to(input_dtype) + if gate_up_proj.shape[1] == hidden_dim: + gate_up_weight = gate_up_proj + elif gate_up_proj.shape[2] == hidden_dim: + gate_up_weight = gate_up_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported gate_up_proj shape: {tuple(gate_up_proj.shape)}.') + if down_proj.shape[2] == hidden_dim: + down_weight = down_proj + elif down_proj.shape[1] == hidden_dim: + down_weight = down_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported down_proj shape: {tuple(down_proj.shape)}.') + return gate_up_weight, down_weight + + +def _get_cached_expert_weights(self, target_dtype, hidden_dim): + requires_grad = ( + getattr(self.gate_up_proj, 'requires_grad', False) or getattr(self.down_proj, 'requires_grad', False)) + cache_attr = '_npu_expert_cache' + if not requires_grad and hasattr(self, cache_attr): + cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) + if (cached_dtype == target_dtype and cached_gv == self.gate_up_proj._version + and cached_dv == self.down_proj._version): + return cached + weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) + if not requires_grad: + setattr(self, cache_attr, (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) + return weights + + +def npu_packed_moe_experts_forward(self, hidden_states, a, b): + """Packed MoE Experts.forward using NPU grouped matmul. + + Accepts both call orderings: ``(hidden_states, routing_weights, router_indices)`` + and ``(hidden_states, router_indices, routing_weights)`` — distinguishes by dtype. + """ + import torch_npu + if a.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: + router_indices, routing_weights = a, b + else: + routing_weights, router_indices = a, b + + output_shape = hidden_states.shape + hidden_dim = output_shape[-1] + hidden_states = hidden_states.reshape(-1, hidden_dim) + + if routing_weights.shape != router_indices.shape: + routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) + routing_weights = routing_weights.to(hidden_states.dtype) + router_indices = router_indices.to(torch.int32) + + permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) + tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) + gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) + + intermediate = GmmFunction.apply(permuted, tokens_per_expert, gate_up_weight) + activated = torch_npu.npu_swiglu(intermediate, dim=-1) + output = GmmFunction.apply(activated, tokens_per_expert, down_weight) + next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) + return next_states.view(*output_shape) + + +def _topk_from_router_logits(module, hidden_states, router_logits): + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) + if getattr(module, 'norm_topk_prob', True): + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + return routing_weights, router_indices + + +def _add_shared_expert(self, hidden_states, expert_output): + if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): + return expert_output + shared = self.shared_expert(hidden_states) + shared = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared + return expert_output + shared + + +def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): + """SparseMoeBlock.forward replacement (Transformers 4.x and 5.x compatible).""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + gate_output = self.gate(hidden_states.view(-1, hidden_dim)) + + if isinstance(gate_output, tuple): + _, routing_weights, selected_experts = gate_output + flat = hidden_states.view(-1, hidden_dim) + expert_output = self.experts(flat, selected_experts, routing_weights) + else: + flat = hidden_states.view(-1, hidden_dim) + routing_weights, selected_experts = _topk_from_router_logits(self, flat, gate_output) + expert_output = self.experts(flat, selected_experts, routing_weights) + + expert_output = _add_shared_expert(self, flat, expert_output) + return expert_output.reshape(batch_size, sequence_length, hidden_dim) diff --git a/src/twinkle/kernel/npu_impls/rms_norm.py b/src/twinkle/kernel/npu_impls/rms_norm.py new file mode 100644 index 000000000..7fd7a5f70 --- /dev/null +++ b/src/twinkle/kernel/npu_impls/rms_norm.py @@ -0,0 +1,72 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused RMSNorm impls for Ascend NPU. + +Designed for class-replacement: do not define ``__init__``; rely on the +attributes already present on the original instance. +""" +from __future__ import annotations + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + +from twinkle import get_logger + +logger = get_logger() + + +class NpuRMSNorm(nn.Module): + """Class-replacement impl for HF RMSNorm variants. + + Required instance attributes (provided by the original class): + - ``weight``: ``nn.Parameter`` + - ``variance_epsilon`` *or* ``eps``: float + """ + + def _twinkle_residual_param(self) -> bool: + """Lazily detect residual parameterization (e.g. Qwen3.5: scale = 1 + weight).""" + cached = getattr(self, '_twinkle_residual_cached', None) + if cached is None: + cached = abs(self.weight.data.mean().item()) < 0.3 + self._twinkle_residual_cached = cached + if cached: + logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') + return cached + + def _twinkle_eps(self) -> float: + return getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + import torch_npu + target_dtype = hidden_states.dtype + if self._twinkle_residual_param(): + scale = (1.0 + self.weight).to(target_dtype) + else: + scale = self.weight.to(target_dtype) + return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] + + +# Resolved once at import: matches the legacy "patch-time, process-wide" invariant. +# Mid-process env mutation will not retroactively change behavior. +_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ('1', 'true', 'on', 'yes') + + +def npu_gated_rms_norm_forward(self, hidden_states, gate=None): + """Forward replacement for Gated RMSNorm variants (e.g. Qwen3.5-MoE).""" + import torch_npu + + input_dtype = hidden_states.dtype + _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + if _FORCE_FP32: + hidden_states = hidden_states.to(torch.float32) + weight = self.weight.float() + gate = gate.to(torch.float32) if gate is not None else None + else: + weight = self.weight + + hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] + if gate is not None: + hidden_states = hidden_states * F.silu(gate) + return hidden_states.to(input_dtype) diff --git a/src/twinkle/kernel/npu_impls/rotary.py b/src/twinkle/kernel/npu_impls/rotary.py new file mode 100644 index 000000000..aa70b1ffb --- /dev/null +++ b/src/twinkle/kernel/npu_impls/rotary.py @@ -0,0 +1,66 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused RoPE impls for Ascend NPU (lazy ``torch_npu`` import).""" +from __future__ import annotations + +import torch + + +def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): + if isinstance(position_ids, int) and unsqueeze_dim == 1: + return position_ids + return unsqueeze_dim + + +def _make_apply_npu_rotary_emb(): + """Closure with per-shape Partial-RoPE detection cache.""" + _cached_partial: dict[tuple[int, int], bool] = {} + + def _apply(q, k, cos, sin): + import torch_npu + rotary_dim = cos.shape[-1] + query_dim = q.shape[-1] + shape_key = (rotary_dim, query_dim) + + use_partial = _cached_partial.get(shape_key) + if use_partial is None: + use_partial = rotary_dim < query_dim + _cached_partial[shape_key] = use_partial + + if use_partial: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + else: + q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) + return q_embed, k_embed + + return _apply + + +_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() + + +def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Fused RoPE via ``torch_npu.npu_rotary_mul`` with Partial-RoPE support.""" + unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return _apply_npu_rotary_emb(q, k, cos, sin) + + +def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Multimodal RoPE for Qwen2.5-VL with Partial-RoPE support.""" + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + return _apply_npu_rotary_emb(q, k, cos, sin) diff --git a/src/twinkle/kernel/npu_impls/swiglu.py b/src/twinkle/kernel/npu_impls/swiglu.py new file mode 100644 index 000000000..782e16cc8 --- /dev/null +++ b/src/twinkle/kernel/npu_impls/swiglu.py @@ -0,0 +1,19 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused SwiGLU forward for Ascend NPU.""" +from __future__ import annotations + +import torch + + +def npu_swiglu_forward(self, hidden_state): + """Fused Qwen-style SwiGLU. + + Used as a class-attribute replacement on HF MLP classes. + Required instance attributes: ``gate_proj``, ``up_proj``, ``down_proj``. + """ + import torch_npu + return self.down_proj( + torch_npu.npu_swiglu( + torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), + dim=-1, + )) diff --git a/src/twinkle/kernel/registry.py b/src/twinkle/kernel/registry.py deleted file mode 100644 index d03f510f9..000000000 --- a/src/twinkle/kernel/registry.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type - -from twinkle import get_logger -from .base import DeviceType, ModeType, is_kernels_available - -if TYPE_CHECKING: - from kernels.layer.func import FuncRepositoryProtocol - -logger = get_logger() - - -class LayerRegistry: - """Manages kernel registrations and syncs to HF kernels.""" - - def __init__(self): - self._registry: Dict[str, Dict[DeviceType, Dict[Any, Any]]] = {} - self._synced = False - - def register(self, kernel_name: str, repo_spec: Any, device: DeviceType = 'cuda', mode: Any = None) -> None: - if kernel_name not in self._registry: - self._registry[kernel_name] = {} - if device not in self._registry[kernel_name]: - self._registry[kernel_name][device] = {} - self._registry[kernel_name][device][mode] = repo_spec - self._synced = False - - def get(self, kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> Optional[Any]: - if kernel_name not in self._registry: - return None - devices = self._registry[kernel_name] - if device is None: - device = next(iter(devices.keys()), None) - if device is None: - return None - modes = devices.get(device) - if modes is None: - return None - if mode is None: - return next(iter(modes.values()), None) - return modes.get(mode) - - def has(self, kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> bool: - if kernel_name not in self._registry: - return False - devices = self._registry[kernel_name] - if device is None: - return True - if device not in devices: - return False - if mode is None: - return True - return mode in devices[device] - - def list_kernel_names(self) -> List[str]: - return list(self._registry.keys()) - - def sync_to_hf_kernels(self) -> None: - if self._synced or not self._registry: - return - - if not is_kernels_available(): - return - - from kernels import register_kernel_mapping as hf_register_kernel_mapping - - hf_register_kernel_mapping({}, inherit_mapping=False) - for kernel_name, device_dict in self._registry.items(): - hf_mapping = {kernel_name: device_dict} - hf_register_kernel_mapping(hf_mapping, inherit_mapping=True) - - self._synced = True - - def _clear(self) -> None: - self._registry.clear() - self._synced = False - - -_global_layer_registry = LayerRegistry() - - -class ExternalLayerRegistry: - """Maps layer classes to kernel names.""" - - def __init__(self): - self._map: Dict[Type, str] = {} - - def register(self, layer_class: Type, kernel_name: str) -> None: - self._map[layer_class] = kernel_name - - def get(self, layer_class: Type) -> Optional[str]: - return self._map.get(layer_class) - - def has(self, layer_class: Type) -> bool: - return layer_class in self._map - - def list_mappings(self) -> List[Tuple[Type, str]]: - return list(self._map.items()) - - def _clear(self) -> None: - self._map.clear() - - -_global_external_layer_registry = ExternalLayerRegistry() - - -@dataclass(frozen=True) -class FunctionKernelSpec: - func_name: str - target_module: str - func_impl: Optional[Callable] - repo: Optional['FuncRepositoryProtocol'] - repo_id: Optional[str] - revision: Optional[str] - version: Optional[str] - device: Optional[str] - mode: Optional[ModeType] - - -class FunctionRegistry: - """Manages function-level kernel registrations.""" - - def __init__(self) -> None: - self._registry: List[FunctionKernelSpec] = [] - - def register(self, spec: FunctionKernelSpec) -> None: - if spec in self._registry: - return - self._registry.append(spec) - - def list_specs(self) -> List[FunctionKernelSpec]: - return list(self._registry) - - def _clear(self) -> None: - self._registry.clear() - - -_global_function_registry = FunctionRegistry() - - -def register_layer(kernel_name: str, repo_spec: Any, device: DeviceType = 'cuda', mode: Any = None) -> None: - _global_layer_registry.register(kernel_name, repo_spec, device, mode) - - -def get_layer_spec(kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> Optional[Any]: - return _global_layer_registry.get(kernel_name, device, mode) - - -def list_kernel_names() -> List[str]: - return _global_layer_registry.list_kernel_names() - - -def has_kernel(kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> bool: - return _global_layer_registry.has(kernel_name, device, mode) - - -def register_external_layer(layer_class: Type, kernel_name: str) -> None: - _global_external_layer_registry.register(layer_class, kernel_name) - - if is_kernels_available(): - from kernels import replace_kernel_forward_from_hub - replace_kernel_forward_from_hub(layer_class, kernel_name) - logger.info(f'Registered {layer_class.__name__} -> kernel: {kernel_name}') - else: - logger.warning(f'HF kernels not available. {layer_class.__name__} mapping registered ' - f'but kernel replacement will not work without kernels package.') - - -def get_external_kernel_name(layer_class: Type) -> Optional[str]: - return _global_external_layer_registry.get(layer_class) - - -def get_global_layer_registry() -> LayerRegistry: - return _global_layer_registry - - -def get_global_external_layer_registry() -> ExternalLayerRegistry: - return _global_external_layer_registry - - -def get_global_function_registry() -> FunctionRegistry: - return _global_function_registry diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 4ae580fa3..6608a2b88 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -107,7 +107,7 @@ def _torch_causal_conv1d_fn( return out.transpose(1, 2).contiguous() # NPU: MindSpeed Triton causal_conv1d and chunk_gated_delta_rule - # are both patched by monkey_patch_npu at model initialization. + # are both patched by twinkle.kernel.npu_impls.fla at model initialization. # No need to set them here - they are already bound on the module. if getattr(mod, '_twinkle_npu_patched', False): return False diff --git a/src/twinkle/kernel/csrc/placeholder b/tests/kernel/npu_impls/__init__.py similarity index 100% rename from src/twinkle/kernel/csrc/placeholder rename to tests/kernel/npu_impls/__init__.py diff --git a/tests/kernel/npu_impls/test_attention.py b/tests/kernel/npu_impls/test_attention.py new file mode 100644 index 000000000..ed916dba1 --- /dev/null +++ b/tests/kernel/npu_impls/test_attention.py @@ -0,0 +1,16 @@ +def test_attention_imports(): + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + assert callable(npu_sdpa_attention_forward) + + +def test_attention_signature(): + import inspect + + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + + sig = inspect.signature(npu_sdpa_attention_forward) + params = list(sig.parameters) + assert params[:5] == ['module', 'query', 'key', 'value', 'attention_mask'] + assert sig.parameters['dropout'].default == 0.0 + assert sig.parameters['scaling'].default is None + assert sig.parameters['is_causal'].default is None \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_fla.py b/tests/kernel/npu_impls/test_fla.py new file mode 100644 index 000000000..0cfeda1df --- /dev/null +++ b/tests/kernel/npu_impls/test_fla.py @@ -0,0 +1,55 @@ +def test_fla_imports(): + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert callable(apply_qwen3_5_fla) + + +def test_fla_disabled_by_env(monkeypatch): + monkeypatch.setenv('TWINKLE_NPU_FLA', '0') + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + # With env=0, function returns 0 (no-op) without raising + assert apply_qwen3_5_fla(None) == 0 + + +def test_fla_skips_when_no_torch_npu(monkeypatch): + import sys + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import + from twinkle.kernel.npu_impls import fla as fla_mod + # Reload-tolerant: should return 0 when torch_npu is missing. + assert fla_mod.apply_qwen3_5_fla(None) == 0 + + +def test_fla_does_not_flip_flag_when_mindspeed_missing(monkeypatch): + """On an NPU host where the MindSpeed FLA kernel cannot be imported, + ``apply_qwen3_5_fla`` must NOT flip the global ``is_flash_linear_attention_available`` + flag — otherwise HF transformers would route Qwen3.5 onto a FLA fast path + whose kernel is not installed (runtime failure).""" + import sys + import types + + import transformers.utils as tu + + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + # Fake torch_npu as importable (with a real __spec__ so find_spec doesn't trip) + import importlib.util + spec = importlib.util.spec_from_loader('torch_npu', loader=None) + fake_npu = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, 'torch_npu', fake_npu) + # Stub causal_conv1d so the heavy real import chain doesn't run + fake_conv = types.ModuleType('twinkle.kernel.causal_conv1d') + fake_conv.npu_causal_conv1d_fn = object() + monkeypatch.setitem(sys.modules, 'twinkle.kernel.causal_conv1d', fake_conv) + # Force the MindSpeed-backed module import to fail + monkeypatch.setitem(sys.modules, 'twinkle.kernel.chunk_gated_delta_rule', None) + + original_flag = tu.is_flash_linear_attention_available + try: + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert apply_qwen3_5_fla(None) == 0 + assert tu.is_flash_linear_attention_available is original_flag, ( + 'is_flash_linear_attention_available was flipped to True while the ' + 'MindSpeed kernel is unavailable — this would break Qwen3.5 at runtime.' + ) + finally: + # Defensive cleanup in case the buggy path ran. + tu.is_flash_linear_attention_available = original_flag \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_moe.py b/tests/kernel/npu_impls/test_moe.py new file mode 100644 index 000000000..34452b61c --- /dev/null +++ b/tests/kernel/npu_impls/test_moe.py @@ -0,0 +1,12 @@ +def test_moe_imports(): + from twinkle.kernel.npu_impls.moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + import torch + assert issubclass(GmmFunction, torch.autograd.Function) + assert callable(npu_grouped_mm) + assert callable(npu_packed_moe_experts_forward) + assert callable(npu_qwen3_5_moe_sparse_block_forward) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rms_norm.py b/tests/kernel/npu_impls/test_rms_norm.py new file mode 100644 index 000000000..184d7ef70 --- /dev/null +++ b/tests/kernel/npu_impls/test_rms_norm.py @@ -0,0 +1,40 @@ +import pytest +import torch +import torch.nn as nn + +try: + import torch_npu # noqa: F401 + _NPU_OK = True +except ImportError: + _NPU_OK = False + + +def test_imports(): + """NpuRMSNorm and npu_gated_rms_norm_forward import without torch_npu.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + assert NpuRMSNorm is not None + assert callable(npu_gated_rms_norm_forward) + + +def test_npu_rmsnorm_has_no_init(): + """Class-replacement contract: NpuRMSNorm must not define its own __init__.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + # If NpuRMSNorm defines __init__, it'd appear in NpuRMSNorm.__dict__ + assert '__init__' not in NpuRMSNorm.__dict__ + + +@pytest.mark.skipif(not _NPU_OK, reason='torch_npu unavailable') +def test_npu_rmsnorm_forward_runs_on_npu(): + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + + class _Orig(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(8)) + self.variance_epsilon = 1e-6 + + m = _Orig().to('npu') + m.__class__ = NpuRMSNorm + x = torch.randn(2, 8, device='npu') + y = m(x) + assert y.shape == (2, 8) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rotary.py b/tests/kernel/npu_impls/test_rotary.py new file mode 100644 index 000000000..460d0fc33 --- /dev/null +++ b/tests/kernel/npu_impls/test_rotary.py @@ -0,0 +1,21 @@ +def test_rotary_imports(): + from twinkle.kernel.npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + assert callable(npu_apply_rotary_pos_emb) + assert callable(npu_apply_multimodal_rotary_pos_emb) + + +def test_rotary_signature_compat(): + """Signature must match HF apply_rotary_pos_emb so setattr swap is safe.""" + import inspect + + from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + + sig = inspect.signature(npu_apply_rotary_pos_emb) + params = list(sig.parameters) + assert params[:4] == ['q', 'k', 'cos', 'sin'] + # position_ids and unsqueeze_dim must be optional + assert sig.parameters['position_ids'].default is None + assert sig.parameters['unsqueeze_dim'].default == 1 \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_swiglu.py b/tests/kernel/npu_impls/test_swiglu.py new file mode 100644 index 000000000..d4ec2da9a --- /dev/null +++ b/tests/kernel/npu_impls/test_swiglu.py @@ -0,0 +1,12 @@ +def test_swiglu_imports(): + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + assert callable(npu_swiglu_forward) + + +def test_swiglu_signature(): + import inspect + + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + + params = list(inspect.signature(npu_swiglu_forward).parameters) + assert params == ['self', 'hidden_state'] \ No newline at end of file diff --git a/tests/kernel/test_builtin.py b/tests/kernel/test_builtin.py new file mode 100644 index 000000000..38d83915d --- /dev/null +++ b/tests/kernel/test_builtin.py @@ -0,0 +1,90 @@ +import importlib.machinery +import sys +import types + +import torch +import torch.nn as nn + +import pytest + + +def _fake_module(name: str): + module = types.ModuleType(name) + module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) + return module + + +def test_npu_builtin_returns_dict(): + from twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() + assert isinstance(bundle, dict) + assert len(bundle) > 0 + + +def test_npu_builtin_values_are_npu_gated(): + """Every value in npu_builtin() must be wrapped in {'npu': ...} so it's + safely no-op on CUDA/CPU.""" + from twinkle.kernel.builtin import npu_builtin + for key, value in npu_builtin().items(): + assert isinstance(value, dict), f'value for {key!r} is not a device-dict' + assert 'npu' in value, f'value for {key!r} is missing npu entry' + + +def test_npu_builtin_compose_with_user_override(): + """User-supplied keys override the builtin (via plain dict merge).""" + from twinkle.kernel.builtin import npu_builtin + sentinel = object() + merged = {**npu_builtin(), 'fake.module.path.fn': sentinel} + assert merged['fake.module.path.fn'] is sentinel + + +def test_npu_builtin_safe_on_cpu_model(): + """kernelize(cpu_model, npu_builtin()) must not raise and not modify.""" + from twinkle.kernel import kernelize + from twinkle.kernel.builtin import npu_builtin + + m = nn.Sequential(nn.Linear(2, 2)) + pre_type = type(m[0]) + out = kernelize(m, npu_builtin()) + assert out is m + assert type(m[0]) is pre_type # no replacement happened (cpu device) + + +def test_npu_builtin_skips_missing_modeling_modules(): + """If transformers.models.qwen3_5 is not installed, the bundle must + still produce a dict (with whatever subset is available).""" + from twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() # must not raise + assert isinstance(bundle, dict) + + +def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(monkeypatch): + """Calling npu_builtin() on a CUDA/CPU host must not contaminate the + global HF SDPA registry. The NPU impl inverts boolean masks, which is + wrong for non-NPU execution.""" + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from twinkle.kernel.builtin import npu_builtin + from twinkle.utils.device_mesh import Platform + + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + original = ALL_ATTENTION_FUNCTIONS.get('sdpa') + npu_builtin() + assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original + + +def test_npu_builtin_skips_side_effects_on_non_npu_platform(monkeypatch): + from twinkle.kernel import builtin + from twinkle.kernel.npu_impls import fla + from twinkle.utils.device_mesh import Platform + + installs = [] + fla_calls = [] + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + monkeypatch.setitem(sys.modules, 'torch_npu', _fake_module('torch_npu')) + monkeypatch.setattr(builtin, '_install_sdpa', lambda impl: installs.append(impl)) + monkeypatch.setattr(fla, 'apply_qwen3_5_fla', lambda model: fla_calls.append(model)) + + builtin.npu_builtin(nn.Linear(1, 1)) + + assert installs == [] + assert fla_calls == [] diff --git a/tests/kernel/test_function_kernel.py b/tests/kernel/test_function_kernel.py deleted file mode 100644 index 02b35dd49..000000000 --- a/tests/kernel/test_function_kernel.py +++ /dev/null @@ -1,265 +0,0 @@ -import os -import pytest -import sys -import torch -import torch.nn as nn -import torch.nn.functional as F -import types - -try: - import requests -except ImportError: - requests = None - -from twinkle.kernel.base import is_kernels_available -from twinkle.kernel.function import apply_function_kernel, register_function_kernel -from twinkle.kernel.registry import get_global_function_registry - - -def _ensure_test_packages() -> None: - if 'tests' not in sys.modules: - tests_pkg = types.ModuleType('tests') - tests_pkg.__path__ = [] - sys.modules['tests'] = tests_pkg - if 'tests.kernel' not in sys.modules: - kernel_pkg = types.ModuleType('tests.kernel') - kernel_pkg.__path__ = [] - sys.modules['tests.kernel'] = kernel_pkg - - -def _reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - return F.silu(x[..., :d]) * x[..., d:] - - -class TestFunctionKernel: - - def setup_method(self): - if not is_kernels_available(): - pytest.skip('kernels package not available in this environment.') - get_global_function_registry()._clear() - - def teardown_method(self): - get_global_function_registry()._clear() - - def test_flattened_build_replaces_function(self): - if os.environ.get('TWINKLE_SKIP_SLOW_TESTS') == '1': - pytest.skip('TWINKLE_SKIP_SLOW_TESTS=1') - if not torch.cuda.is_available(): - pytest.skip('CUDA not available in this environment.') - try: - import urllib.request - urllib.request.urlopen('https://huggingface.co', timeout=5) - except Exception as e: - pytest.skip(f'HuggingFace unreachable: {e}') - try: - from kernels import has_kernel - from kernels._versions import select_revision_or_version - from kernels.utils import get_kernel - except Exception: - pytest.skip('kernels package missing has_kernel.') - if not has_kernel('kernels-test/flattened-build'): - pytest.skip('kernels-test/flattened-build not available.') - try: - revision = select_revision_or_version( - 'kernels-test/flattened-build', - revision=None, - version=None, - ) - get_kernel('kernels-test/flattened-build', revision=revision) - except Exception as exc: - pytest.skip(f'kernels-test/flattened-build cannot be loaded in this env: {exc}') - - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_module' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - try: - applied = apply_function_kernel( - target_module=module_name, - device='cuda', - mode='inference', - ) - except TypeError as e: - if 'select_revision_or_version' in str(e) or 'takes 1 positional argument' in str(e): - pytest.skip(f'kernels API incompatible: {e}') - raise - except Exception as e: - if requests and isinstance(e, (requests.exceptions.SSLError, requests.exceptions.RequestException)): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - raise - - assert applied == [f'{module_name}.silu_and_mul'] - assert temp_module.silu_and_mul is not original - - x = torch.randn(4, 16, device='cuda', dtype=torch.float16) - y_kernel = temp_module.silu_and_mul(x) - y_ref = _reference_silu_and_mul(x) - assert torch.allclose(y_kernel, y_ref, atol=1e-3, rtol=1e-3) - except Exception as e: - if requests and isinstance(e, (requests.exceptions.SSLError, requests.exceptions.RequestException)): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - raise - finally: - sys.modules.pop(module_name, None) - - def test_flattened_build_device_filter(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_device' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - applied = apply_function_kernel( - target_module=module_name, - device='cpu', - mode='inference', - ) - - assert applied == [] - assert temp_module.silu_and_mul is original - finally: - sys.modules.pop(module_name, None) - - def test_flattened_build_mode_filter(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_mode' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - applied = apply_function_kernel( - target_module=module_name, - device='cuda', - mode='train', - ) - - assert applied == [] - assert temp_module.silu_and_mul is original - finally: - sys.modules.pop(module_name, None) - - def test_flattened_build_strict_raises_on_no_match(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_strict' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - with pytest.raises(ValueError): - apply_function_kernel( - target_module=module_name, - device='cpu', - mode='inference', - strict=True, - ) - finally: - sys.modules.pop(module_name, None) - - def test_repo_object_loads_module_class(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_repo_object' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y - - temp_module.add = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - class MyKernelFunc(nn.Module): - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y + 2 - - class MyFuncRepo: - func_name = 'add' - - def load(self): - return MyKernelFunc - - try: - register_function_kernel( - func_name='add', - target_module=module_name, - repo=MyFuncRepo(), - device='cuda', - mode='inference', - ) - - applied = apply_function_kernel( - target_module=module_name, - device='cuda', - mode='inference', - ) - - assert applied == [f'{module_name}.add'] - assert temp_module.add is not original - x = torch.tensor([1.0]) - y = torch.tensor([2.0]) - assert torch.allclose(temp_module.add(x, y), x + y + 2) - finally: - sys.modules.pop(module_name, None) diff --git a/tests/kernel/test_hub.py b/tests/kernel/test_hub.py new file mode 100644 index 000000000..022f42e98 --- /dev/null +++ b/tests/kernel/test_hub.py @@ -0,0 +1,52 @@ +import pytest + +from twinkle.kernel.core import HubRef, hub + + +def test_hub_with_version(): + ref = hub('kernels-community/activation:SiluAndMul', version=1) + assert isinstance(ref, HubRef) + assert ref.repo_id == 'kernels-community/activation' + assert ref.layer_name == 'SiluAndMul' + assert ref.version == 1 + assert ref.revision is None + assert ref.backend is None + + +def test_hub_with_revision(): + ref = hub('org/repo:Layer', revision='main') + assert ref.revision == 'main' + assert ref.version is None + + +def test_hub_with_backend(): + ref = hub('org/repo:Layer', version=2, backend='cuda') + assert ref.backend == 'cuda' + + +def test_hub_rejects_both_revision_and_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer', revision='main', version=1) + + +def test_hub_rejects_neither_revision_nor_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer') + + +def test_hub_rejects_missing_colon(): + with pytest.raises(ValueError, match='repo_id:LayerName'): + hub('org/repo', version=1) + + +def test_hub_handles_colon_in_repo_id(): + # rsplit takes only the last colon + ref = hub('org:sub/repo:Layer', version=1) + assert ref.repo_id == 'org:sub/repo' + assert ref.layer_name == 'Layer' + + +def test_hubref_is_frozen(): + ref = hub('org/repo:Layer', version=1) + with pytest.raises(Exception): + ref.repo_id = 'other' \ No newline at end of file diff --git a/tests/kernel/test_kernel.py b/tests/kernel/test_kernel.py deleted file mode 100644 index 5b6a658b4..000000000 --- a/tests/kernel/test_kernel.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Kernel module unit tests -""" -import os -import pytest -from unittest.mock import MagicMock, Mock, patch - -from twinkle.kernel import kernelize_model, register_external_layer, register_kernels, register_layer_kernel -from twinkle.kernel.base import is_kernels_available, is_kernels_enabled, to_kernels_mode -from twinkle.kernel.registry import (ExternalLayerRegistry, LayerRegistry, get_global_external_layer_registry, - get_global_function_registry, get_global_layer_registry, get_layer_spec, - register_layer) - - -class TestBase: - """Test base helpers and env vars.""" - - def test_is_kernels_available(self): - """Test kernels availability check.""" - result = is_kernels_available() - assert isinstance(result, bool) - - def test_kernels_enabled_env_var(self): - """Test env var controls kernels enablement.""" - original = os.environ.get('TWINKLE_USE_KERNELS') - try: - os.environ['TWINKLE_USE_KERNELS'] = 'YES' - from twinkle.kernel.base import _kernels_enabled - assert _kernels_enabled() - - os.environ['TWINKLE_USE_KERNELS'] = 'NO' - import importlib - - import twinkle.kernel.base - importlib.reload(twinkle.kernel.base) - from twinkle.kernel.base import _kernels_enabled - assert not _kernels_enabled() - finally: - if original is not None: - os.environ['TWINKLE_USE_KERNELS'] = original - else: - os.environ.pop('TWINKLE_USE_KERNELS', None) - - def test_to_kernels_mode(self): - """Test mode conversion.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - assert to_kernels_mode('train').name == 'TRAINING' - assert to_kernels_mode('inference').name == 'INFERENCE' - assert to_kernels_mode('compile').name == 'TORCH_COMPILE' - - -class TestLayerRegistry: - """Test layer registry.""" - - def setup_method(self): - self.registry = LayerRegistry() - - def test_register_and_get(self): - """Test register and lookup.""" - mock_spec = Mock() - self.registry.register('TestLayer', mock_spec, 'cuda') - - result = self.registry.get('TestLayer', 'cuda') - assert result == mock_spec - - result = self.registry.get('NonExistent', 'cuda') - assert result is None - - def test_register_multiple_devices(self): - """Test registration for multiple devices.""" - mock_cuda = Mock() - mock_npu = Mock() - - self.registry.register('TestLayer', mock_cuda, 'cuda') - self.registry.register('TestLayer', mock_npu, 'npu') - - assert self.registry.get('TestLayer', 'cuda') == mock_cuda - assert self.registry.get('TestLayer', 'npu') == mock_npu - - def test_get_without_device(self): - """Test lookup without device.""" - mock_spec = Mock() - self.registry.register('TestLayer', mock_spec, 'cuda') - - result = self.registry.get('TestLayer') - assert result == mock_spec - - def test_has(self): - """Test has checks.""" - mock_spec = Mock() - assert not self.registry.has('TestLayer') - - self.registry.register('TestLayer', mock_spec, 'cuda') - assert self.registry.has('TestLayer') - assert self.registry.has('TestLayer', 'cuda') - assert not self.registry.has('TestLayer', 'npu') - - def test_list_kernel_names(self): - """Test listing kernel names.""" - mock_spec = Mock() - self.registry.register('Layer1', mock_spec, 'cuda') - self.registry.register('Layer2', mock_spec, 'cuda') - - names = self.registry.list_kernel_names() - assert sorted(names) == sorted(['Layer1', 'Layer2']) - - -class TestExternalLayerRegistry: - """Test external layer registry.""" - - def setup_method(self): - self.registry = ExternalLayerRegistry() - - def test_register_and_get(self): - """Test register and lookup.""" - mock_class = Mock - self.registry.register(mock_class, 'LlamaAttention') - - result = self.registry.get(mock_class) - assert result == 'LlamaAttention' - - def test_has(self): - """Test has checks.""" - mock_class = Mock - assert not self.registry.has(mock_class) - - self.registry.register(mock_class, 'LlamaAttention') - assert self.registry.has(mock_class) - - def test_list_mappings(self): - """Test list mappings.""" - - class MockClass1: - pass - - class MockClass2: - pass - - self.registry.register(MockClass1, 'LlamaAttention') - self.registry.register(MockClass2, 'LlamaMLP') - - mappings = self.registry.list_mappings() - assert len(mappings) == 2 - - -class TestRegisterLayer: - """Test global register helpers.""" - - def setup_method(self): - get_global_layer_registry()._clear() - get_global_function_registry()._clear() - - def test_register_and_get_spec(self): - """Test global register and lookup.""" - mock_spec = Mock() - register_layer('TestLayer', mock_spec, 'cuda') - - result = get_layer_spec('TestLayer', 'cuda') - assert result == mock_spec - - -class TestRegisterLayerKernel: - """Test register_layer_kernel.""" - - def setup_method(self): - get_global_layer_registry()._clear() - - def test_register_without_kernels_package(self): - """Test registration when kernels package missing.""" - with patch('twinkle.kernel.layer.is_kernels_available', return_value=False): - register_layer_kernel('TestLayer', repo_id='test/repo') - assert get_layer_spec('TestLayer') is None - - def test_register_with_kernels_package(self): - """Test registration when kernels package available.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - register_layer_kernel( - kernel_name='TestLayer', - repo_id='kernels-community/test', - ) - - assert get_layer_spec('TestLayer') is not None - - -class TestKernelizeModel: - """Test kernelize_model.""" - - def test_kernelize_without_kernels_enabled(self): - """Test returns original model when kernels disabled.""" - with patch('twinkle.kernel.layer.is_kernels_enabled', return_value=False): - mock_model = Mock() - result = kernelize_model(mock_model) - assert result == mock_model - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_kernelize_without_kernels_available(self, mock_available): - """Test returns original model when kernels unavailable.""" - mock_model = Mock() - result = kernelize_model(mock_model) - assert result == mock_model - - -class TestRegisterExternalLayer: - """Test register_external_layer.""" - - def setup_method(self): - get_global_external_layer_registry()._clear() - - def test_register_external_layer(self): - """Test registering external layer.""" - mock_class = Mock - - register_external_layer(mock_class, 'LlamaAttention') - - result = get_global_external_layer_registry().get(mock_class) - assert result == 'LlamaAttention' - - def test_register_external_qwen_layer(self): - """Test registering Qwen2 external layer mapping.""" - try: - from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention - except ImportError: - pytest.skip('transformers package not available') - - register_external_layer(Qwen2Attention, 'LlamaAttention') - - registry = get_global_external_layer_registry() - assert registry.has(Qwen2Attention) - assert registry.get(Qwen2Attention) == 'LlamaAttention' - - def test_register_external_layer_adds_kernel_layer_name(self): - """Test register_external_layer sets kernel_layer_name.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - class TestLayer: - pass - - register_external_layer(TestLayer, 'TestKernel') - - assert hasattr(TestLayer, 'kernel_layer_name') - assert TestLayer.kernel_layer_name == 'TestKernel' - - -class TestRegisterKernels: - """Test register_kernels batch registration.""" - - def setup_method(self): - get_global_layer_registry()._clear() - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_register_layers_without_kernels(self, mock_available): - """Test layer batch registration when kernels missing.""" - config = { - 'layers': { - 'LlamaAttention': { - 'repo_id': 'kernels-community/llama-attention' - }, - 'LlamaMLP': { - 'repo_id': 'kernels-community/llama-mlp' - }, - } - } - - register_kernels(config) - - assert get_layer_spec('LlamaAttention') is None - assert get_layer_spec('LlamaMLP') is None - - def test_register_functions(self): - """Test function batch registration.""" - config = { - 'functions': { - 'apply_rotary_pos_emb': { - 'func_impl': Mock, - 'target_module': 'test', - 'device': 'cpu', - 'mode': 'inference', - } - } - } - - register_kernels(config) - specs = get_global_function_registry().list_specs() - assert len(specs) == 1 - spec = specs[0] - assert spec.func_name == 'apply_rotary_pos_emb' - assert spec.target_module == 'test' - assert spec.func_impl == Mock - assert spec.device == 'cpu' - assert spec.mode == 'inference' - - -class TestModeSupport: - """Test mode support.""" - - def setup_method(self): - get_global_layer_registry()._clear() - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_register_with_mode_fallback(self, mock_available): - """Test fallback mode mapping when mode is None.""" - from kernels import Mode - - from twinkle.kernel.layer import _to_hf_mode, register_layer_kernel - - result = _to_hf_mode(None) - assert result == Mode.FALLBACK - - def test_to_hf_mode_conversion(self): - """Test Twinkle mode to HF kernels Mode conversion.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - from kernels import Mode - - from twinkle.kernel.layer import _to_hf_mode - - assert _to_hf_mode('train') == Mode.TRAINING - assert _to_hf_mode('inference') == Mode.INFERENCE - assert _to_hf_mode('compile') == Mode.TORCH_COMPILE - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_register_multiple_modes(self, mock_available): - """Test registering multiple modes for the same layer.""" - registry = get_global_layer_registry() - - class MockRepo: - pass - - repo_inference = MockRepo() - repo_training = MockRepo() - - from kernels import Mode - - registry.register('TestLayer', repo_inference, 'cuda', Mode.INFERENCE) - registry.register('TestLayer', repo_training, 'cuda', Mode.TRAINING) - - assert registry.has('TestLayer', 'cuda', Mode.INFERENCE) - assert registry.has('TestLayer', 'cuda', Mode.TRAINING) - - result = registry.get('TestLayer', 'cuda', Mode.INFERENCE) - assert result == repo_inference - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/kernel/test_kernelize.py b/tests/kernel/test_kernelize.py new file mode 100644 index 000000000..cdb98caee --- /dev/null +++ b/tests/kernel/test_kernelize.py @@ -0,0 +1,98 @@ +import sys +import types + +import pytest +import torch +import torch.nn as nn + +from twinkle.kernel.core import HubRef, kernelize + + +class _SrcLayer(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return x + + +class _DstLayer(nn.Module): + def forward(self, x): + return x + 100 + + +def test_kernelize_class_to_class_replacement(): + parent = nn.Sequential(_SrcLayer(), _SrcLayer()) + out = kernelize(parent, {_SrcLayer: _DstLayer}) + assert out is parent + assert type(parent[0]) is _DstLayer + assert type(parent[1]) is _DstLayer + + +def test_kernelize_empty_mapping_returns_model(): + m = _SrcLayer() + assert kernelize(m, {}) is m + assert type(m) is _SrcLayer + + +def test_kernelize_string_key_calls_setattr(): + mod_name = 'tests.kernel._tmp_kernelize_str' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 3 # noqa: E731 + kernelize(nn.Linear(1, 1), {f'{mod_name}.target_fn': new_fn}) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_kernelize_device_dict_match(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + + kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_uses_platform_device_prefix(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) # params may still be CPU before FSDP placement + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'npu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_device_dict_miss_skips_silently(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + + assert type(parent[0]) is _SrcLayer + + +def test_kernelize_rejects_unknown_key_type(): + with pytest.raises(TypeError, match='Unsupported mapping key'): + kernelize(nn.Linear(1, 1), {42: _DstLayer}) + + +def test_kernelize_loads_hub_ref(monkeypatch): + # Stand in for HF kernels: patch _load_hub_ref to return _DstLayer + from twinkle.kernel import core as _core + monkeypatch.setattr(_core, '_load_hub_ref', lambda ref: _DstLayer) + + parent = nn.Sequential(_SrcLayer()) + ref = HubRef('org/repo', 'X', revision='main') + kernelize(parent, {_SrcLayer: ref}) + assert type(parent[0]) is _DstLayer diff --git a/tests/kernel/test_load_hub_ref.py b/tests/kernel/test_load_hub_ref.py new file mode 100644 index 000000000..747e3fdcb --- /dev/null +++ b/tests/kernel/test_load_hub_ref.py @@ -0,0 +1,69 @@ +import sys +import types +from unittest.mock import patch + +import pytest + +from twinkle.kernel.core import HubRef, _load_hub_ref + + +def _install_fake_kernels(layer_obj=None, no_layers=False): + """Install a fake `kernels` module with a controllable `get_kernel`.""" + fake = types.ModuleType('kernels') + + def fake_get_kernel(repo_id, **kwargs): + m = types.ModuleType('fake_kernel') + if not no_layers: + layers_ns = types.SimpleNamespace() + if layer_obj is not None: + layers_ns.MyLayer = layer_obj + m.layers = layers_ns + return m + + fake.get_kernel = fake_get_kernel + sys.modules['kernels'] = fake + + +def _uninstall_fake_kernels(): + sys.modules.pop('kernels', None) + + +def test_load_hub_ref_returns_layer(): + sentinel = object() + _install_fake_kernels(layer_obj=sentinel) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + assert _load_hub_ref(ref) is sentinel + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layers_missing(): + _install_fake_kernels(no_layers=True) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ValueError, match='does not define any layers'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layer_name_missing(): + _install_fake_kernels(layer_obj=None) # MyLayer not present + try: + ref = HubRef('org/repo', 'Missing', revision='main') + with pytest.raises(ValueError, match='not found'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_install_hint_when_kernels_missing(): + # Force `import kernels` to fail + sys.modules['kernels'] = None # short-circuits import to ImportError + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ImportError, match='pip install kernels'): + _load_hub_ref(ref) + finally: + sys.modules.pop('kernels', None) \ No newline at end of file diff --git a/tests/kernel/test_public_api.py b/tests/kernel/test_public_api.py new file mode 100644 index 000000000..f9a17a2a5 --- /dev/null +++ b/tests/kernel/test_public_api.py @@ -0,0 +1,22 @@ +def test_public_exports_exactly_three_symbols(): + import twinkle.kernel as k + assert sorted(k.__all__) == ['hub', 'kernelize', 'npu_builtin'] + assert callable(k.kernelize) + assert callable(k.npu_builtin) + assert callable(k.hub) + + +def test_no_legacy_symbols(): + """Legacy registrar / patch helpers must be gone.""" + import twinkle.kernel as k + legacy = [ + 'kernelize_model', 'register_layer_kernel', 'register_function_kernel', + 'register_kernels', 'register_external_layer', 'apply_npu_patch', + 'apply_npu_fused_ops', 'apply_function_kernel', 'apply_layer_kernel', + 'register_layer_batch', 'register_npu_fused_function_kernels', + 'get_global_layer_registry', 'get_global_function_registry', + 'get_global_external_layer_registry', 'LayerRegistry', + 'ExternalLayerRegistry', 'FunctionRegistry', + ] + for name in legacy: + assert not hasattr(k, name), f'unexpected legacy symbol: {name}' \ No newline at end of file diff --git a/tests/kernel/test_replace.py b/tests/kernel/test_replace.py new file mode 100644 index 000000000..e649b2e3c --- /dev/null +++ b/tests/kernel/test_replace.py @@ -0,0 +1,74 @@ +import sys +import types + +import torch.nn as nn + +from twinkle.kernel.core import _replace_attr, _replace_class + + +class _Target(nn.Module): + def forward(self, x): + return x + + +class _Impl(nn.Module): + def forward(self, x): + return x + 1 + + +class _SubTarget(_Target): + pass + + +def test_replace_class_rewrites_exact_match(): + m = _Target() + parent = nn.Sequential(_Target(), nn.Linear(1, 1)) + _replace_class(parent, _Target, _Impl) + assert type(parent[0]) is _Impl + + +def test_replace_class_skips_subclass(): + parent = nn.Sequential(_SubTarget()) + _replace_class(parent, _Target, _Impl) + # exact match only - _SubTarget should NOT be rewritten + assert type(parent[0]) is _SubTarget + + +def test_replace_class_idempotent(): + m = nn.Sequential(_Target()) + _replace_class(m, _Target, _Impl) + _replace_class(m, _Target, _Impl) # second call must be safe + assert type(m[0]) is _Impl + + +def test_replace_attr_sets_module_attribute(): + mod_name = 'tests.kernel._tmp_replace_attr' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 2 # noqa: E731 + _replace_attr(f'{mod_name}.target_fn', new_fn) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_replace_attr_supports_class_attribute(): + import sys + import types + + mod_name = 'tests.kernel._tmp_class_attr' + mod = types.ModuleType(mod_name) + + class Foo: + def forward(self, x): + return x + mod.Foo = Foo + sys.modules[mod_name] = mod + try: + new_forward = lambda self, x: x + 7 # noqa: E731 + _replace_attr(f'{mod_name}.Foo.forward', new_forward) + assert Foo.forward is new_forward + finally: + sys.modules.pop(mod_name, None) \ No newline at end of file diff --git a/tests/kernel/test_resolve_value.py b/tests/kernel/test_resolve_value.py new file mode 100644 index 000000000..652783f5c --- /dev/null +++ b/tests/kernel/test_resolve_value.py @@ -0,0 +1,48 @@ +import torch.nn as nn + +from twinkle.kernel.core import HubRef, _resolve_value + + +class _ImplA(nn.Module): + pass + + +class _ImplB(nn.Module): + pass + + +def test_passthrough_class_value(): + assert _resolve_value(_ImplA, 'cuda') is _ImplA + + +def test_passthrough_callable_value(): + f = lambda x: x # noqa: E731 + assert _resolve_value(f, 'npu') is f + + +def test_passthrough_hubref(): + ref = HubRef('org/repo', 'Layer', revision='main') + assert _resolve_value(ref, 'cuda') is ref + + +def test_device_dict_match(): + val = {'npu': _ImplA, 'cuda': _ImplB} + assert _resolve_value(val, 'npu') is _ImplA + assert _resolve_value(val, 'cuda') is _ImplB + + +def test_device_dict_miss_returns_none(): + val = {'npu': _ImplA} + assert _resolve_value(val, 'cuda') is None + + +def test_device_dict_nested(): + # nested dict -> recursive resolve + val = {'npu': {'npu': _ImplA}} + assert _resolve_value(val, 'npu') is _ImplA + + +def test_device_dict_miss_then_passthrough(): + # nested dict whose inner is also a dict that misses -> None + val = {'npu': {'cuda': _ImplA}} + assert _resolve_value(val, 'npu') is None \ No newline at end of file