From 63a1c2b3d9faa4568d5452917de3cf22d6c73794 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 08:49:17 +0800 Subject: [PATCH 01/27] feat(kernel): add HubRef dataclass and hub() factory --- src/twinkle/kernel/core.py | 43 ++++++++++++++++++++++++++++++ tests/kernel/test_hub.py | 54 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 src/twinkle/kernel/core.py create mode 100644 tests/kernel/test_hub.py diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py new file mode 100644 index 00000000..4e41f3c4 --- /dev/null +++ b/src/twinkle/kernel/core.py @@ -0,0 +1,43 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Minimal mapping-driven kernel replacement. + +Public API: ``kernelize``, ``hub`` (re-exported from ``twinkle.kernel``). +""" +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class HubRef: + """Lightweight reference to a HuggingFace Hub kernel layer. + + Resolved lazily by ``kernelize`` via the optional ``kernels`` package. + """ + repo_id: str + layer_name: str + revision: str | None = None + version: int | None = None + backend: str | None = None + trust_remote_code: bool = False + + +def hub( + ref: str, + *, + revision: str | None = None, + version: int | None = None, + backend: str | None = None, + trust_remote_code: bool = False, +) -> HubRef: + """Build a ``HubRef`` for use as a ``kernelize`` mapping value. + + ``ref`` is ``':'`` (e.g. ``'org/repo:SiluAndMul'``). + Exactly one of ``revision`` or ``version`` must be supplied. 
+ """ + if (revision is None) == (version is None): + raise ValueError('Exactly one of `revision` or `version` must be specified.') + if ':' not in ref: + raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") + repo_id, layer_name = ref.rsplit(':', 1) + return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) \ No newline at end of file diff --git a/tests/kernel/test_hub.py b/tests/kernel/test_hub.py new file mode 100644 index 00000000..e1e2644e --- /dev/null +++ b/tests/kernel/test_hub.py @@ -0,0 +1,54 @@ +import pytest + +from twinkle.kernel.core import HubRef, hub + + +def test_hub_with_version(): + ref = hub('kernels-community/activation:SiluAndMul', version=1) + assert isinstance(ref, HubRef) + assert ref.repo_id == 'kernels-community/activation' + assert ref.layer_name == 'SiluAndMul' + assert ref.version == 1 + assert ref.revision is None + assert ref.backend is None + assert ref.trust_remote_code is False + + +def test_hub_with_revision(): + ref = hub('org/repo:Layer', revision='main') + assert ref.revision == 'main' + assert ref.version is None + + +def test_hub_with_backend_and_trust(): + ref = hub('org/repo:Layer', version=2, backend='cuda', trust_remote_code=True) + assert ref.backend == 'cuda' + assert ref.trust_remote_code is True + + +def test_hub_rejects_both_revision_and_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer', revision='main', version=1) + + +def test_hub_rejects_neither_revision_nor_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer') + + +def test_hub_rejects_missing_colon(): + with pytest.raises(ValueError, match='repo_id:LayerName'): + hub('org/repo', version=1) + + +def test_hub_handles_colon_in_repo_id(): + # rsplit takes only the last colon + ref = hub('org:sub/repo:Layer', version=1) + assert ref.repo_id == 'org:sub/repo' + assert ref.layer_name == 'Layer' + + +def test_hubref_is_frozen(): + ref = 
hub('org/repo:Layer', version=1) + with pytest.raises(Exception): + ref.repo_id = 'other' \ No newline at end of file From 7049f6f66a2a183d792cb912fc18ef5b538192f9 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 08:50:49 +0800 Subject: [PATCH 02/27] feat(kernel): add _infer_device helper --- src/twinkle/kernel/core.py | 13 ++++++++++++- tests/kernel/test_infer_device.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 tests/kernel/test_infer_device.py diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index 4e41f3c4..e96fc210 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -7,6 +7,8 @@ from dataclasses import dataclass +import torch.nn as nn + @dataclass(frozen=True) class HubRef: @@ -40,4 +42,13 @@ def hub( if ':' not in ref: raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") repo_id, layer_name = ref.rsplit(':', 1) - return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) \ No newline at end of file + return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) + + +def _infer_device(model: nn.Module) -> str: + """Infer the device type from the first parameter, then first buffer, else cpu.""" + for p in model.parameters(): + return p.device.type + for b in model.buffers(): + return b.device.type + return 'cpu' \ No newline at end of file diff --git a/tests/kernel/test_infer_device.py b/tests/kernel/test_infer_device.py new file mode 100644 index 00000000..7f9d5581 --- /dev/null +++ b/tests/kernel/test_infer_device.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +from twinkle.kernel.core import _infer_device + + +class _NoParamsNoBuffers(nn.Module): + pass + + +class _OnlyBuffer(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('b', torch.zeros(2)) + + +def test_infer_device_from_parameter(): + m = 
nn.Linear(2, 3) + assert _infer_device(m) == 'cpu' + + +def test_infer_device_from_buffer_when_no_params(): + m = _OnlyBuffer() + assert _infer_device(m) == 'cpu' + + +def test_infer_device_defaults_to_cpu_when_empty(): + m = _NoParamsNoBuffers() + assert _infer_device(m) == 'cpu' \ No newline at end of file From 1547d8cdee2c87834a3501d3ee8a5fe616a067d0 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 08:53:01 +0800 Subject: [PATCH 03/27] feat(kernel): add _resolve_value with device-conditional dispatch --- src/twinkle/kernel/core.py | 16 +++++++++- tests/kernel/test_resolve_value.py | 48 ++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/kernel/test_resolve_value.py diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index e96fc210..7531f9f6 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -6,6 +6,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Any import torch.nn as nn @@ -51,4 +52,17 @@ def _infer_device(model: nn.Module) -> str: return p.device.type for b in model.buffers(): return b.device.type - return 'cpu' \ No newline at end of file + return 'cpu' + + +def _resolve_value(value: Any, device: str) -> Any | None: + """Resolve a mapping value against the inferred device. + + - ``dict``: device-conditional; recurse into ``value[device]`` or return None. + - anything else (including ``HubRef``): pass through. 
+ """ + if isinstance(value, dict): + if device not in value: + return None + return _resolve_value(value[device], device) + return value \ No newline at end of file diff --git a/tests/kernel/test_resolve_value.py b/tests/kernel/test_resolve_value.py new file mode 100644 index 00000000..652783f5 --- /dev/null +++ b/tests/kernel/test_resolve_value.py @@ -0,0 +1,48 @@ +import torch.nn as nn + +from twinkle.kernel.core import HubRef, _resolve_value + + +class _ImplA(nn.Module): + pass + + +class _ImplB(nn.Module): + pass + + +def test_passthrough_class_value(): + assert _resolve_value(_ImplA, 'cuda') is _ImplA + + +def test_passthrough_callable_value(): + f = lambda x: x # noqa: E731 + assert _resolve_value(f, 'npu') is f + + +def test_passthrough_hubref(): + ref = HubRef('org/repo', 'Layer', revision='main') + assert _resolve_value(ref, 'cuda') is ref + + +def test_device_dict_match(): + val = {'npu': _ImplA, 'cuda': _ImplB} + assert _resolve_value(val, 'npu') is _ImplA + assert _resolve_value(val, 'cuda') is _ImplB + + +def test_device_dict_miss_returns_none(): + val = {'npu': _ImplA} + assert _resolve_value(val, 'cuda') is None + + +def test_device_dict_nested(): + # nested dict -> recursive resolve + val = {'npu': {'npu': _ImplA}} + assert _resolve_value(val, 'npu') is _ImplA + + +def test_device_dict_miss_then_passthrough(): + # nested dict whose inner is also a dict that misses -> None + val = {'npu': {'cuda': _ImplA}} + assert _resolve_value(val, 'npu') is None \ No newline at end of file From 1be64aa55ac3ea4922ea340159edf8dba3c34fc8 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 08:54:48 +0800 Subject: [PATCH 04/27] feat(kernel): add _replace_class and _replace_attr helpers --- src/twinkle/kernel/core.py | 23 ++++++++++++++- tests/kernel/test_replace.py | 54 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 tests/kernel/test_replace.py diff --git 
a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index 7531f9f6..48c11de0 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import importlib from dataclasses import dataclass from typing import Any @@ -65,4 +66,24 @@ def _resolve_value(value: Any, device: str) -> Any | None: if device not in value: return None return _resolve_value(value[device], device) - return value \ No newline at end of file + return value + + +def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: + """Rewrite ``__class__`` of every module whose exact type is ``target_cls``. + + Uses ``type(m) is target_cls`` (not ``isinstance``) so user-defined + subclasses of ``target_cls`` are deliberately left alone. + """ + for m in model.modules(): + if type(m) is target_cls: + m.__class__ = impl_cls + + +def _replace_attr(dotted_path: str, impl) -> None: + """``setattr`` ``impl`` onto the module identified by the dotted path's prefix.""" + module_path, _, attr = dotted_path.rpartition('.') + if not module_path or not attr: + raise ValueError(f"Expected 'pkg.module.attr', got: {dotted_path!r}") + module = importlib.import_module(module_path) + setattr(module, attr, impl) \ No newline at end of file diff --git a/tests/kernel/test_replace.py b/tests/kernel/test_replace.py new file mode 100644 index 00000000..5a2ba459 --- /dev/null +++ b/tests/kernel/test_replace.py @@ -0,0 +1,54 @@ +import sys +import types + +import torch.nn as nn + +from twinkle.kernel.core import _replace_attr, _replace_class + + +class _Target(nn.Module): + def forward(self, x): + return x + + +class _Impl(nn.Module): + def forward(self, x): + return x + 1 + + +class _SubTarget(_Target): + pass + + +def test_replace_class_rewrites_exact_match(): + m = _Target() + parent = nn.Sequential(_Target(), nn.Linear(1, 1)) + _replace_class(parent, _Target, _Impl) + assert type(parent[0]) is _Impl + + +def 
test_replace_class_skips_subclass(): + parent = nn.Sequential(_SubTarget()) + _replace_class(parent, _Target, _Impl) + # exact match only - _SubTarget should NOT be rewritten + assert type(parent[0]) is _SubTarget + + +def test_replace_class_idempotent(): + m = nn.Sequential(_Target()) + _replace_class(m, _Target, _Impl) + _replace_class(m, _Target, _Impl) # second call must be safe + assert type(m[0]) is _Impl + + +def test_replace_attr_sets_module_attribute(): + mod_name = 'tests.kernel._tmp_replace_attr' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 2 # noqa: E731 + _replace_attr(f'{mod_name}.target_fn', new_fn) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) \ No newline at end of file From c72883bbaae06b1f0d5d712a50b04ffae51e3d5f Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 08:56:15 +0800 Subject: [PATCH 05/27] feat(kernel): add _load_hub_ref with lazy kernels import --- src/twinkle/kernel/core.py | 28 ++++++++++++- tests/kernel/test_load_hub_ref.py | 69 +++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 tests/kernel/test_load_hub_ref.py diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index 48c11de0..d3fc3f2f 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -86,4 +86,30 @@ def _replace_attr(dotted_path: str, impl) -> None: if not module_path or not attr: raise ValueError(f"Expected 'pkg.module.attr', got: {dotted_path!r}") module = importlib.import_module(module_path) - setattr(module, attr, impl) \ No newline at end of file + setattr(module, attr, impl) + + +def _load_hub_ref(ref: HubRef): + """Lazy-load a Hub kernel layer via the optional ``kernels`` package.""" + try: + from kernels import get_kernel + except ImportError as e: + raise ImportError( + 'Loading a Hub kernel requires the 
`kernels` package. ' + 'Install it with `pip install kernels`.' + ) from e + + kernel = get_kernel( + ref.repo_id, + revision=ref.revision, + version=ref.version, + backend=ref.backend, + trust_remote_code=ref.trust_remote_code, + ) + layers = getattr(kernel, 'layers', None) + if layers is None: + raise ValueError(f'Hub repo {ref.repo_id!r} does not define any layers.') + impl = getattr(layers, ref.layer_name, None) + if impl is None: + raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') + return impl \ No newline at end of file diff --git a/tests/kernel/test_load_hub_ref.py b/tests/kernel/test_load_hub_ref.py new file mode 100644 index 00000000..747e3fdc --- /dev/null +++ b/tests/kernel/test_load_hub_ref.py @@ -0,0 +1,69 @@ +import sys +import types +from unittest.mock import patch + +import pytest + +from twinkle.kernel.core import HubRef, _load_hub_ref + + +def _install_fake_kernels(layer_obj=None, no_layers=False): + """Install a fake `kernels` module with a controllable `get_kernel`.""" + fake = types.ModuleType('kernels') + + def fake_get_kernel(repo_id, **kwargs): + m = types.ModuleType('fake_kernel') + if not no_layers: + layers_ns = types.SimpleNamespace() + if layer_obj is not None: + layers_ns.MyLayer = layer_obj + m.layers = layers_ns + return m + + fake.get_kernel = fake_get_kernel + sys.modules['kernels'] = fake + + +def _uninstall_fake_kernels(): + sys.modules.pop('kernels', None) + + +def test_load_hub_ref_returns_layer(): + sentinel = object() + _install_fake_kernels(layer_obj=sentinel) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + assert _load_hub_ref(ref) is sentinel + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layers_missing(): + _install_fake_kernels(no_layers=True) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ValueError, match='does not define any layers'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def 
test_load_hub_ref_raises_if_layer_name_missing(): + _install_fake_kernels(layer_obj=None) # MyLayer not present + try: + ref = HubRef('org/repo', 'Missing', revision='main') + with pytest.raises(ValueError, match='not found'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_install_hint_when_kernels_missing(): + # Force `import kernels` to fail + sys.modules['kernels'] = None # short-circuits import to ImportError + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ImportError, match='pip install kernels'): + _load_hub_ref(ref) + finally: + sys.modules.pop('kernels', None) \ No newline at end of file From d4318b1924fd6a8fbab73e25515a553e1b78b5c9 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 08:58:04 +0800 Subject: [PATCH 06/27] feat(kernel): add kernelize() dispatcher --- src/twinkle/kernel/core.py | 36 +++++++++++++++- tests/kernel/test_kernelize.py | 77 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 tests/kernel/test_kernelize.py diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index d3fc3f2f..e83c7a44 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -112,4 +112,38 @@ def _load_hub_ref(ref: HubRef): impl = getattr(layers, ref.layer_name, None) if impl is None: raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') - return impl \ No newline at end of file + return impl + + +def kernelize(model: nn.Module, mapping: dict) -> nn.Module: + """Apply ``mapping`` to ``model`` and return it (modified in place). + + Keys: + - ``type[nn.Module]``: replace ``m.__class__`` for every module of the + exact type (no subclass walking). + - ``str`` (dotted path ``pkg.mod.attr``): ``setattr`` the impl onto the + identified module attribute. 
+ + Values: + - ``dict[str, V]``: device-conditional dispatch using + ``next(model.parameters()).device.type``; non-matching devices skip. + - ``HubRef``: lazy-resolved via the optional ``kernels`` package. + - anything else: used directly as the impl. + """ + if not mapping: + return model + + device = _infer_device(model) + for key, value in mapping.items(): + impl = _resolve_value(value, device) + if impl is None: + continue + if isinstance(impl, HubRef): + impl = _load_hub_ref(impl) + if isinstance(key, type) and issubclass(key, nn.Module): + _replace_class(model, key, impl) + elif isinstance(key, str): + _replace_attr(key, impl) + else: + raise TypeError(f'Unsupported mapping key: {key!r}') + return model \ No newline at end of file diff --git a/tests/kernel/test_kernelize.py b/tests/kernel/test_kernelize.py new file mode 100644 index 00000000..dd159de6 --- /dev/null +++ b/tests/kernel/test_kernelize.py @@ -0,0 +1,77 @@ +import sys +import types + +import pytest +import torch +import torch.nn as nn + +from twinkle.kernel.core import HubRef, kernelize + + +class _SrcLayer(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return x + + +class _DstLayer(nn.Module): + def forward(self, x): + return x + 100 + + +def test_kernelize_class_to_class_replacement(): + parent = nn.Sequential(_SrcLayer(), _SrcLayer()) + out = kernelize(parent, {_SrcLayer: _DstLayer}) + assert out is parent + assert type(parent[0]) is _DstLayer + assert type(parent[1]) is _DstLayer + + +def test_kernelize_empty_mapping_returns_model(): + m = _SrcLayer() + assert kernelize(m, {}) is m + assert type(m) is _SrcLayer + + +def test_kernelize_string_key_calls_setattr(): + mod_name = 'tests.kernel._tmp_kernelize_str' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 3 # noqa: E731 + kernelize(nn.Linear(1, 1), {f'{mod_name}.target_fn': new_fn}) + 
assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_kernelize_device_dict_match(): + parent = nn.Sequential(_SrcLayer()) # cpu params + kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) + assert type(parent[0]) is _DstLayer + + +def test_kernelize_device_dict_miss_skips_silently(): + parent = nn.Sequential(_SrcLayer()) # cpu params + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + assert type(parent[0]) is _SrcLayer + + +def test_kernelize_rejects_unknown_key_type(): + with pytest.raises(TypeError, match='Unsupported mapping key'): + kernelize(nn.Linear(1, 1), {42: _DstLayer}) + + +def test_kernelize_loads_hub_ref(monkeypatch): + # Stand in for HF kernels: patch _load_hub_ref to return _DstLayer + from twinkle.kernel import core as _core + monkeypatch.setattr(_core, '_load_hub_ref', lambda ref: _DstLayer) + + parent = nn.Sequential(_SrcLayer()) + ref = HubRef('org/repo', 'X', revision='main') + kernelize(parent, {_SrcLayer: ref}) + assert type(parent[0]) is _DstLayer \ No newline at end of file From 77165c9fdd70be04c8a566e2326c5c554aed3a91 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:00:15 +0800 Subject: [PATCH 07/27] feat(kernel): add npu_impls/rms_norm module --- src/twinkle/kernel/npu_impls/__init__.py | 10 +++ src/twinkle/kernel/npu_impls/rms_norm.py | 79 ++++++++++++++++++++++++ tests/kernel/npu_impls/__init__.py | 0 tests/kernel/npu_impls/test_rms_norm.py | 40 ++++++++++++ 4 files changed, 129 insertions(+) create mode 100644 src/twinkle/kernel/npu_impls/__init__.py create mode 100644 src/twinkle/kernel/npu_impls/rms_norm.py create mode 100644 tests/kernel/npu_impls/__init__.py create mode 100644 tests/kernel/npu_impls/test_rms_norm.py diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py new file mode 100644 index 00000000..dc6a189b --- /dev/null +++ 
b/src/twinkle/kernel/npu_impls/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Per-layer NPU implementations consumed by ``npu_builtin()``. + +Each impl is contracted to be applied via ``m.__class__ = ImplCls`` (class +replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl +here is meant to be instantiated directly. +""" +from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + +__all__ = ['NpuRMSNorm', 'npu_gated_rms_norm_forward'] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/rms_norm.py b/src/twinkle/kernel/npu_impls/rms_norm.py new file mode 100644 index 00000000..281443bd --- /dev/null +++ b/src/twinkle/kernel/npu_impls/rms_norm.py @@ -0,0 +1,79 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused RMSNorm impls for Ascend NPU. + +Designed for class-replacement: do not define ``__init__``; rely on the +attributes already present on the original instance. +""" +from __future__ import annotations + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from twinkle import get_logger + +logger = get_logger() + + +class NpuRMSNorm(nn.Module): + """Class-replacement impl for HF RMSNorm variants. + + Required instance attributes (provided by the original class): + - ``weight``: ``nn.Parameter`` + - ``variance_epsilon`` *or* ``eps``: float + """ + + def _twinkle_residual_param(self) -> bool: + """Lazily detect residual parameterization (e.g. 
Qwen3.5: scale = 1 + weight).""" + cached = getattr(self, '_twinkle_residual_cached', None) + if cached is None: + cached = abs(self.weight.data.mean().item()) < 0.3 + self._twinkle_residual_cached = cached + if cached: + logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') + return cached + + def _twinkle_eps(self) -> float: + return getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + import torch_npu + target_dtype = hidden_states.dtype + if self._twinkle_residual_param(): + scale = (1.0 + self.weight).to(target_dtype) + else: + scale = self.weight.to(target_dtype) + return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] + + +def npu_gated_rms_norm_forward(self, hidden_states, gate=None): + """Forward replacement for Gated RMSNorm variants (e.g. Qwen3.5-MoE). + + Reads FP32-mode preference from env ``TWINKLE_NPU_GATED_RMSNorm_FP32`` once + and caches it on the instance. 
+ """ + import torch_npu + + input_dtype = hidden_states.dtype + _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + force_fp32 = getattr(self, '_twinkle_force_fp32', None) + if force_fp32 is None: + force_fp32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( + '1', 'true', 'on', 'yes' + ) + self._twinkle_force_fp32 = force_fp32 + + if force_fp32: + hidden_states = hidden_states.to(torch.float32) + weight = self.weight.float() + gate = gate.to(torch.float32) if gate is not None else None + else: + weight = self.weight + + hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] + if gate is not None: + hidden_states = hidden_states * F.silu(gate) + return hidden_states.to(input_dtype) \ No newline at end of file diff --git a/tests/kernel/npu_impls/__init__.py b/tests/kernel/npu_impls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/kernel/npu_impls/test_rms_norm.py b/tests/kernel/npu_impls/test_rms_norm.py new file mode 100644 index 00000000..184d7ef7 --- /dev/null +++ b/tests/kernel/npu_impls/test_rms_norm.py @@ -0,0 +1,40 @@ +import pytest +import torch +import torch.nn as nn + +try: + import torch_npu # noqa: F401 + _NPU_OK = True +except ImportError: + _NPU_OK = False + + +def test_imports(): + """NpuRMSNorm and npu_gated_rms_norm_forward import without torch_npu.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + assert NpuRMSNorm is not None + assert callable(npu_gated_rms_norm_forward) + + +def test_npu_rmsnorm_has_no_init(): + """Class-replacement contract: NpuRMSNorm must not define its own __init__.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + # If NpuRMSNorm defines __init__, it'd appear in NpuRMSNorm.__dict__ + assert '__init__' not in NpuRMSNorm.__dict__ + + +@pytest.mark.skipif(not _NPU_OK, reason='torch_npu unavailable') +def test_npu_rmsnorm_forward_runs_on_npu(): + from 
twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + + class _Orig(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(8)) + self.variance_epsilon = 1e-6 + + m = _Orig().to('npu') + m.__class__ = NpuRMSNorm + x = torch.randn(2, 8, device='npu') + y = m(x) + assert y.shape == (2, 8) \ No newline at end of file From ac47045a6ab2ec5a13392b25b1629981d69411ae Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:01:58 +0800 Subject: [PATCH 08/27] feat(kernel): add npu_impls/rotary module --- src/twinkle/kernel/npu_impls/__init__.py | 8 ++- src/twinkle/kernel/npu_impls/rotary.py | 66 ++++++++++++++++++++++++ tests/kernel/npu_impls/test_rotary.py | 21 ++++++++ 3 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 src/twinkle/kernel/npu_impls/rotary.py create mode 100644 tests/kernel/npu_impls/test_rotary.py diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index dc6a189b..bcef54d3 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -6,5 +6,11 @@ here is meant to be instantiated directly. """ from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward +from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb -__all__ = ['NpuRMSNorm', 'npu_gated_rms_norm_forward'] \ No newline at end of file +__all__ = [ + 'NpuRMSNorm', + 'npu_gated_rms_norm_forward', + 'npu_apply_rotary_pos_emb', + 'npu_apply_multimodal_rotary_pos_emb', +] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/rotary.py b/src/twinkle/kernel/npu_impls/rotary.py new file mode 100644 index 00000000..1ed437a3 --- /dev/null +++ b/src/twinkle/kernel/npu_impls/rotary.py @@ -0,0 +1,66 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Fused RoPE impls for Ascend NPU (lazy ``torch_npu`` import).""" +from __future__ import annotations + +import torch + + +def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): + if isinstance(position_ids, int) and unsqueeze_dim == 1: + return position_ids + return unsqueeze_dim + + +def _make_apply_npu_rotary_emb(): + """Closure with per-shape Partial-RoPE detection cache.""" + _cached_partial: dict[tuple[int, int], bool] = {} + + def _apply(q, k, cos, sin): + import torch_npu + rotary_dim = cos.shape[-1] + query_dim = q.shape[-1] + shape_key = (rotary_dim, query_dim) + + use_partial = _cached_partial.get(shape_key) + if use_partial is None: + use_partial = rotary_dim < query_dim + _cached_partial[shape_key] = use_partial + + if use_partial: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + else: + q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) + return q_embed, k_embed + + return _apply + + +_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() + + +def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Fused RoPE via ``torch_npu.npu_rotary_mul`` with Partial-RoPE support.""" + unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return _apply_npu_rotary_emb(q, k, cos, sin) + + +def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Multimodal RoPE for Qwen2.5-VL with Partial-RoPE support.""" + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, 
+ ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + return _apply_npu_rotary_emb(q, k, cos, sin) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rotary.py b/tests/kernel/npu_impls/test_rotary.py new file mode 100644 index 00000000..460d0fc3 --- /dev/null +++ b/tests/kernel/npu_impls/test_rotary.py @@ -0,0 +1,21 @@ +def test_rotary_imports(): + from twinkle.kernel.npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + assert callable(npu_apply_rotary_pos_emb) + assert callable(npu_apply_multimodal_rotary_pos_emb) + + +def test_rotary_signature_compat(): + """Signature must match HF apply_rotary_pos_emb so setattr swap is safe.""" + import inspect + + from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + + sig = inspect.signature(npu_apply_rotary_pos_emb) + params = list(sig.parameters) + assert params[:4] == ['q', 'k', 'cos', 'sin'] + # position_ids and unsqueeze_dim must be optional + assert sig.parameters['position_ids'].default is None + assert sig.parameters['unsqueeze_dim'].default == 1 \ No newline at end of file From f0d0a2374c84335f92c4cf2b82d3895d6562bed6 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:03:55 +0800 Subject: [PATCH 09/27] feat(kernel): add npu_impls/swiglu module --- src/twinkle/kernel/npu_impls/__init__.py | 2 ++ src/twinkle/kernel/npu_impls/swiglu.py | 20 ++++++++++++++++++++ tests/kernel/npu_impls/test_swiglu.py | 12 ++++++++++++ 3 files changed, 34 insertions(+) create mode 100644 src/twinkle/kernel/npu_impls/swiglu.py create mode 100644 tests/kernel/npu_impls/test_swiglu.py diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index bcef54d3..26f64c21 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ 
b/src/twinkle/kernel/npu_impls/__init__.py @@ -7,10 +7,12 @@ """ from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb +from .swiglu import npu_swiglu_forward __all__ = [ 'NpuRMSNorm', 'npu_gated_rms_norm_forward', 'npu_apply_rotary_pos_emb', 'npu_apply_multimodal_rotary_pos_emb', + 'npu_swiglu_forward', ] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/swiglu.py b/src/twinkle/kernel/npu_impls/swiglu.py new file mode 100644 index 00000000..c34a7bea --- /dev/null +++ b/src/twinkle/kernel/npu_impls/swiglu.py @@ -0,0 +1,20 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused SwiGLU forward for Ascend NPU.""" +from __future__ import annotations + +import torch + + +def npu_swiglu_forward(self, hidden_state): + """Fused Qwen-style SwiGLU. + + Used as a class-attribute replacement on HF MLP classes. + Required instance attributes: ``gate_proj``, ``up_proj``, ``down_proj``. 
+ """ + import torch_npu + return self.down_proj( + torch_npu.npu_swiglu( + torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), + dim=-1, + ) + ) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_swiglu.py b/tests/kernel/npu_impls/test_swiglu.py new file mode 100644 index 00000000..d4ec2da9 --- /dev/null +++ b/tests/kernel/npu_impls/test_swiglu.py @@ -0,0 +1,12 @@ +def test_swiglu_imports(): + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + assert callable(npu_swiglu_forward) + + +def test_swiglu_signature(): + import inspect + + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + + params = list(inspect.signature(npu_swiglu_forward).parameters) + assert params == ['self', 'hidden_state'] \ No newline at end of file From 1e9902eb73dec4819c476061b43eff48f0e7a8d7 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:05:38 +0800 Subject: [PATCH 10/27] feat(kernel): add npu_impls/attention module --- src/twinkle/kernel/npu_impls/__init__.py | 2 + src/twinkle/kernel/npu_impls/attention.py | 54 +++++++++++++++++++++++ tests/kernel/npu_impls/test_attention.py | 16 +++++++ 3 files changed, 72 insertions(+) create mode 100644 src/twinkle/kernel/npu_impls/attention.py create mode 100644 tests/kernel/npu_impls/test_attention.py diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index 26f64c21..6067e544 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -8,6 +8,7 @@ from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb from .swiglu import npu_swiglu_forward +from .attention import npu_sdpa_attention_forward __all__ = [ 'NpuRMSNorm', @@ -15,4 +16,5 @@ 'npu_apply_rotary_pos_emb', 'npu_apply_multimodal_rotary_pos_emb', 'npu_swiglu_forward', + 
'npu_sdpa_attention_forward', ] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/attention.py b/src/twinkle/kernel/npu_impls/attention.py new file mode 100644 index 00000000..f328b2d5 --- /dev/null +++ b/src/twinkle/kernel/npu_impls/attention.py @@ -0,0 +1,54 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SDPA forward with Ascend NPU compatibility fixes.""" +from __future__ import annotations + +import torch + + +def npu_sdpa_attention_forward( + module, + query, + key, + value, + attention_mask, + dropout=0.0, + scaling=None, + is_causal=None, + **kwargs, +): + """Drop-in replacement for ``transformers.integrations.sdpa_attention.sdpa_attention_forward``. + + Fixes: + - Repeats KV heads (NPU SDPA does not auto-broadcast num_kv_groups). + - Truncates causal_mask to key length. + - Forces contiguous tensors (NPU SDPA requirement). + - Inverts boolean masks (NPU treats ``True`` as masked). + """ + from transformers.integrations.sdpa_attention import repeat_kv + + if hasattr(module, 'num_key_value_groups'): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, :key.shape[-2]] + + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() + + if is_causal is None: + is_causal = query.shape[2] > 1 and causal_mask is None + + if causal_mask is not None and causal_mask.dtype != torch.bool: + causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + return attn_output.transpose(1, 2).contiguous(), None \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_attention.py b/tests/kernel/npu_impls/test_attention.py new file mode 
100644 index 00000000..ed916dba --- /dev/null +++ b/tests/kernel/npu_impls/test_attention.py @@ -0,0 +1,16 @@ +def test_attention_imports(): + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + assert callable(npu_sdpa_attention_forward) + + +def test_attention_signature(): + import inspect + + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + + sig = inspect.signature(npu_sdpa_attention_forward) + params = list(sig.parameters) + assert params[:5] == ['module', 'query', 'key', 'value', 'attention_mask'] + assert sig.parameters['dropout'].default == 0.0 + assert sig.parameters['scaling'].default is None + assert sig.parameters['is_causal'].default is None \ No newline at end of file From 4fc02c191e59fcb1b1f0297e4dbca2e377cb606e Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:09:13 +0800 Subject: [PATCH 11/27] feat(kernel): add npu_impls/moe module --- src/twinkle/kernel/npu_impls/__init__.py | 10 ++ src/twinkle/kernel/npu_impls/moe.py | 151 +++++++++++++++++++++++ tests/kernel/npu_impls/test_moe.py | 12 ++ 3 files changed, 173 insertions(+) create mode 100644 src/twinkle/kernel/npu_impls/moe.py create mode 100644 tests/kernel/npu_impls/test_moe.py diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index 6067e544..9194732c 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -9,6 +9,12 @@ from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb from .swiglu import npu_swiglu_forward from .attention import npu_sdpa_attention_forward +from .moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, +) __all__ = [ 'NpuRMSNorm', @@ -17,4 +23,8 @@ 'npu_apply_multimodal_rotary_pos_emb', 'npu_swiglu_forward', 'npu_sdpa_attention_forward', + 'GmmFunction', + 'npu_grouped_mm', + 
'npu_packed_moe_experts_forward', + 'npu_qwen3_5_moe_sparse_block_forward', ] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/moe.py b/src/twinkle/kernel/npu_impls/moe.py new file mode 100644 index 00000000..efa7f71a --- /dev/null +++ b/src/twinkle/kernel/npu_impls/moe.py @@ -0,0 +1,151 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""MoE GMM + packed-experts + sparse-block impls for Ascend NPU.""" +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +class GmmFunction(torch.autograd.Function): + """Custom autograd function for NPU grouped matrix multiplication.""" + + @staticmethod + def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): + import torch_npu + group_list = group_list.to(torch.int64) + ctx.save_for_backward(x, group_list, weight_ekn) + outputs = torch_npu.npu_grouped_matmul( + [x], [weight_ekn], group_list=group_list, + group_type=0, split_item=2, group_list_type=1, + ) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + import torch_npu + x, group_list, weight_ekn = ctx.saved_tensors + grad_input = torch_npu.npu_grouped_matmul( + [grad_output], [weight_ekn.transpose(-2, -1).contiguous()], + bias=None, group_list=group_list, + group_type=0, split_item=2, group_list_type=1, + )[0] + grad_weight = torch_npu.npu_grouped_matmul( + [x.transpose(0, 1)], [grad_output], + bias=None, group_list=group_list, + group_type=2, split_item=3, group_list_type=1, + )[0] + return grad_input, None, grad_weight.contiguous() + + +def npu_grouped_mm(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: + """Drop-in replacement for ``transformers.integrations.moe._grouped_mm``.""" + counts = torch.empty_like(offs) + counts[0] = offs[0] + if offs.numel() > 1: + counts[1:] = offs[1:] - offs[:-1] + counts = counts.to(torch.int64) + return GmmFunction.apply(input, counts, weight_ekn) + + +def 
_normalize_packed_expert_weights(module, input_dtype, hidden_dim): + gate_up_proj = module.gate_up_proj.to(input_dtype) + down_proj = module.down_proj.to(input_dtype) + if gate_up_proj.shape[1] == hidden_dim: + gate_up_weight = gate_up_proj + elif gate_up_proj.shape[2] == hidden_dim: + gate_up_weight = gate_up_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported gate_up_proj shape: {tuple(gate_up_proj.shape)}.') + if down_proj.shape[2] == hidden_dim: + down_weight = down_proj + elif down_proj.shape[1] == hidden_dim: + down_weight = down_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported down_proj shape: {tuple(down_proj.shape)}.') + return gate_up_weight, down_weight + + +def _get_cached_expert_weights(self, target_dtype, hidden_dim): + requires_grad = ( + getattr(self.gate_up_proj, 'requires_grad', False) + or getattr(self.down_proj, 'requires_grad', False) + ) + cache_attr = '_npu_expert_cache' + if not requires_grad and hasattr(self, cache_attr): + cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) + if (cached_dtype == target_dtype + and cached_gv == self.gate_up_proj._version + and cached_dv == self.down_proj._version): + return cached + weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) + if not requires_grad: + setattr(self, cache_attr, + (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) + return weights + + +def npu_packed_moe_experts_forward(self, hidden_states, a, b): + """Packed MoE Experts.forward using NPU grouped matmul. + + Accepts both call orderings: ``(hidden_states, routing_weights, router_indices)`` + and ``(hidden_states, router_indices, routing_weights)`` — distinguishes by dtype. 
+ """ + import torch_npu + if a.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: + router_indices, routing_weights = a, b + else: + routing_weights, router_indices = a, b + + output_shape = hidden_states.shape + hidden_dim = output_shape[-1] + hidden_states = hidden_states.reshape(-1, hidden_dim) + + if routing_weights.shape != router_indices.shape: + routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) + routing_weights = routing_weights.to(hidden_states.dtype) + router_indices = router_indices.to(torch.int32) + + permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) + tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) + gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) + + intermediate = GmmFunction.apply(permuted, tokens_per_expert, gate_up_weight) + activated = torch_npu.npu_swiglu(intermediate, dim=-1) + output = GmmFunction.apply(activated, tokens_per_expert, down_weight) + next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) + return next_states.view(*output_shape) + + +def _topk_from_router_logits(module, hidden_states, router_logits): + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) + if getattr(module, 'norm_topk_prob', True): + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + return routing_weights, router_indices + + +def _add_shared_expert(self, hidden_states, expert_output): + if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): + return expert_output + shared = self.shared_expert(hidden_states) + shared = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared + return expert_output + shared 
+ + +def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): + """SparseMoeBlock.forward replacement (Transformers 4.x and 5.x compatible).""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + gate_output = self.gate(hidden_states.view(-1, hidden_dim)) + + if isinstance(gate_output, tuple): + _, routing_weights, selected_experts = gate_output + flat = hidden_states.view(-1, hidden_dim) + expert_output = self.experts(flat, selected_experts, routing_weights) + else: + flat = hidden_states.view(-1, hidden_dim) + routing_weights, selected_experts = _topk_from_router_logits(self, flat, gate_output) + expert_output = self.experts(flat, selected_experts, routing_weights) + + expert_output = _add_shared_expert(self, flat, expert_output) + return expert_output.reshape(batch_size, sequence_length, hidden_dim) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_moe.py b/tests/kernel/npu_impls/test_moe.py new file mode 100644 index 00000000..34452b61 --- /dev/null +++ b/tests/kernel/npu_impls/test_moe.py @@ -0,0 +1,12 @@ +def test_moe_imports(): + from twinkle.kernel.npu_impls.moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + import torch + assert issubclass(GmmFunction, torch.autograd.Function) + assert callable(npu_grouped_mm) + assert callable(npu_packed_moe_experts_forward) + assert callable(npu_qwen3_5_moe_sparse_block_forward) \ No newline at end of file From a5421f9e44b1967fc4d870599d437cf4712448e5 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:12:10 +0800 Subject: [PATCH 12/27] feat(kernel): add npu_impls/fla module --- src/twinkle/kernel/npu_impls/__init__.py | 2 + src/twinkle/kernel/npu_impls/fla.py | 100 +++++++++++++++++++++++ tests/kernel/npu_impls/test_fla.py | 19 +++++ 3 files changed, 121 insertions(+) create mode 100644 src/twinkle/kernel/npu_impls/fla.py create mode 100644 
tests/kernel/npu_impls/test_fla.py diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index 9194732c..31d77581 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -15,6 +15,7 @@ npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, ) +from .fla import apply_qwen3_5_fla __all__ = [ 'NpuRMSNorm', @@ -27,4 +28,5 @@ 'npu_grouped_mm', 'npu_packed_moe_experts_forward', 'npu_qwen3_5_moe_sparse_block_forward', + 'apply_qwen3_5_fla', ] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py new file mode 100644 index 00000000..4ed91ae1 --- /dev/null +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -0,0 +1,100 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Qwen3.5 Flash Linear Attention enablement for Ascend NPU.""" +from __future__ import annotations + +import importlib +import os + +from twinkle import get_logger + +logger = get_logger() + + +def _is_env_enabled(var: str, default: bool = True) -> bool: + env = os.environ.get(var, '').lower().strip() + if not env: + return default + if env in ('1', 'true', 'on', 'yes'): + return True + if env in ('0', 'false', 'off', 'no'): + return False + return default + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def apply_qwen3_5_fla(model=None) -> int: + """Enable Flash Linear Attention fast path for Qwen3.5 on NPU. + + Returns the count of patched per-layer instances (0 when disabled or when + prerequisites are missing). Safe to call multiple times. + """ + if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): + logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA') + return 0 + + if _import_optional('torch_npu') is None: + logger.info('[NPU] [FLA] Skip: torch_npu unavailable') + return 0 + + # 1. 
Force FLA availability flags on transformers utility modules + def _is_fla_available() -> bool: + return True + + for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): + utils_mod = _import_optional(utils_mod_name) + if utils_mod is not None: + setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) + + # 2. Try to load MindSpeed Triton kernel + try: + from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla + except ImportError as exc: + logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) + mindspeed_fla = None + + from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn + + # 3. Patch Qwen3.5 modeling modules + fla_target_modules = [ + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + ] + for module_name in fla_target_modules: + module = _import_optional(module_name) + if module is None or mindspeed_fla is None: + continue + setattr(module, 'is_flash_linear_attention_available', _is_fla_available) + setattr(module, 'is_fast_path_available', True) + if hasattr(module, 'FusedRMSNormGated'): + setattr(module, 'FusedRMSNormGated', None) + setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) + + # 4. 
Traverse model and patch per-layer attributes + if model is None or mindspeed_fla is None: + return 0 + + root = getattr(model, 'model', getattr(model, 'module', model)) + if not hasattr(root, 'named_modules'): + return 0 + + patched_instances = 0 + for _name, _module in root.named_modules(): + if hasattr(_module, 'chunk_gated_delta_rule') and callable( + getattr(_module, 'chunk_gated_delta_rule')): + if _module.chunk_gated_delta_rule is not mindspeed_fla: + _module.chunk_gated_delta_rule = mindspeed_fla + _module._twinkle_npu_patched = True + patched_instances += 1 + if hasattr(_module, 'causal_conv1d_fn'): + if getattr(_module, 'causal_conv1d_fn') is not npu_causal_conv1d_fn: + _module.causal_conv1d_fn = npu_causal_conv1d_fn + + if patched_instances: + logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) + return patched_instances \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_fla.py b/tests/kernel/npu_impls/test_fla.py new file mode 100644 index 00000000..0e8b07bb --- /dev/null +++ b/tests/kernel/npu_impls/test_fla.py @@ -0,0 +1,19 @@ +def test_fla_imports(): + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert callable(apply_qwen3_5_fla) + + +def test_fla_disabled_by_env(monkeypatch): + monkeypatch.setenv('TWINKLE_NPU_FLA', '0') + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + # With env=0, function returns 0 (no-op) without raising + assert apply_qwen3_5_fla(None) == 0 + + +def test_fla_skips_when_no_torch_npu(monkeypatch): + import sys + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import + from twinkle.kernel.npu_impls import fla as fla_mod + # Reload-tolerant: should return 0 when torch_npu is missing. 
+ assert fla_mod.apply_qwen3_5_fla(None) == 0 \ No newline at end of file From 87e8477d39ef19f7464907d818f1d21ef7d902e1 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:27:09 +0800 Subject: [PATCH 13/27] feat(kernel): add npu_builtin() bundle and class-attr replacement --- src/twinkle/kernel/builtin.py | 200 ++++++++++++++++++++++++++++++++++ src/twinkle/kernel/core.py | 40 ++++++- tests/kernel/test_builtin.py | 48 ++++++++ tests/kernel/test_replace.py | 20 ++++ 4 files changed, 302 insertions(+), 6 deletions(-) create mode 100644 src/twinkle/kernel/builtin.py create mode 100644 tests/kernel/test_builtin.py diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py new file mode 100644 index 00000000..a11ec0e8 --- /dev/null +++ b/src/twinkle/kernel/builtin.py @@ -0,0 +1,200 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""``npu_builtin()`` returns the bundle of Ascend NPU replacements. + +All values are wrapped in ``{'npu': impl}`` so the bundle composes safely on +CUDA/CPU systems — non-NPU devices silently skip every entry. + +GMM is **not** included by default (without EP it causes ~8x slowdown). 
Opt +in by merging: + + {**npu_builtin(model), 'transformers.integrations.moe._grouped_mm': + {'npu': npu_grouped_mm}} +""" +from __future__ import annotations + +import importlib +from typing import Any + +import torch.nn as nn + +from twinkle import get_logger + +logger = get_logger() + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: + """Return the NPU builtin mapping; optionally apply per-instance FLA.""" + from .npu_impls.attention import npu_sdpa_attention_forward + from .npu_impls.fla import apply_qwen3_5_fla + from .npu_impls.moe import ( + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + from .npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + from .npu_impls.swiglu import npu_swiglu_forward + + bundle: dict[Any, dict[str, Any]] = {} + + # SDPA attention (global) + bundle['transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS'] = {'npu': _SdpaPatchSentinel()} + # NOTE: ALL_ATTENTION_FUNCTIONS is a dict, not a function. We can't setattr + # it. We instead install the sdpa entry by a small bootstrap below. + # Remove the sentinel approach in favor of explicit module-level entries: + bundle.pop('transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS', None) + + # Apply SDPA install eagerly (one-shot module-level mutation). 
+ _install_sdpa(npu_sdpa_attention_forward) + + # === per-family class + function entries === + _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_moe_entries( + bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, + ) + _add_qwen2_5_vl_entries( + bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + npu_apply_multimodal_rotary_pos_emb, + ) + _add_qwen3_5_entries( + bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + npu_swiglu_forward, + ) + _add_qwen3_5_moe_entries( + bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + npu_swiglu_forward, npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + + # === FLA (side-effect; mapping-incompatible) === + apply_qwen3_5_fla(model) + + return bundle + + +class _SdpaPatchSentinel: + pass # unused; placeholder retained for clarity in diffs + + +def _install_sdpa(impl) -> None: + """One-shot install of SDPA attention forward (global modeling_utils dict).""" + try: + from transformers.modeling_utils import ( + ALL_ATTENTION_FUNCTIONS, + AttentionInterface, + ) + except ImportError: + return + AttentionInterface._global_mapping['sdpa'] = impl + ALL_ATTENTION_FUNCTIONS['sdpa'] = impl + + +# ---- helpers that conditionally add entries based on module availability ---- + +def _add_class_if_present(bundle, module_path, class_name, impl_cls): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + bundle[cls] = {'npu': impl_cls} + + +def _add_swiglu_if_present(bundle, module_path, class_name, fn): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + # Function-level: wrap as 
string-keyed forward replacement. + # We override on the *class object*, not the module attribute, by + # using a class-key with a synthetic impl wrapping the forward. + # The simplest way is to subclass and reassign __class__, but here + # we follow the legacy approach of overwriting the class's forward: + bundle[f'{module_path}.{class_name}.forward'] = {'npu': fn} + + +def _add_attr_if_present(bundle, module_path, attr_name, impl): + mod = _import_optional(module_path) + if mod is None: + return + if '.' in attr_name: + # Dotted attr like 'Qwen3MoeExperts.forward': resolve the class on + # the module, then check the trailing member on the class. + head, _, tail = attr_name.partition('.') + owner = getattr(mod, head, None) + if owner is None or not hasattr(owner, tail): + return + else: + if not hasattr(mod, attr_name): + return + bundle[f'{module_path}.{attr_name}'] = {'npu': impl} + + +def _add_qwen2_entries(bundle, rms_cls, rope_fn, swiglu_fn): + # Qwen2 (used by Qwen2.5-VL etc. via inheritance) + _add_class_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2RMSNorm', rms_cls) + _add_attr_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2MLP', swiglu_fn) + + +def _add_qwen3_entries(bundle, rms_cls, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3.modeling_qwen3' + _add_class_if_present(bundle, base, 'Qwen3RMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MLP', swiglu_fn) + + +def _add_qwen3_moe_entries(bundle, rms_cls, rope_fn, swiglu_fn, experts_fn, sparse_fn): + base = 'transformers.models.qwen3_moe.modeling_qwen3_moe' + _add_class_if_present(bundle, base, 'Qwen3MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MoeMLP', swiglu_fn) + 
_add_attr_if_present(bundle, base, 'Qwen3MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeSparseMoeBlock.forward', sparse_fn) + + +def _add_qwen2_5_vl_entries(bundle, rms_cls, rope_fn, swiglu_fn, multimodal_rope_fn): + base = 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl' + _add_class_if_present(bundle, base, 'Qwen2_5_VLRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_attr_if_present(bundle, base, 'apply_multimodal_rotary_pos_emb', multimodal_rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2_5_VLMLP', swiglu_fn) + + +def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3_5.modeling_qwen3_5' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5RMSNorm', rms_cls) + _add_class_if_present(bundle, base, 'Qwen3_5VisionRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5VisionMLP', swiglu_fn) + # Qwen3_5GatedRMSNorm: forward-level replacement + _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) + + +def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, + experts_fn, sparse_fn): + base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) \ 
No newline at end of file diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index e83c7a44..362873aa 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -81,12 +81,40 @@ def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: def _replace_attr(dotted_path: str, impl) -> None: - """``setattr`` ``impl`` onto the module identified by the dotted path's prefix.""" - module_path, _, attr = dotted_path.rpartition('.') - if not module_path or not attr: - raise ValueError(f"Expected 'pkg.module.attr', got: {dotted_path!r}") - module = importlib.import_module(module_path) - setattr(module, attr, impl) + """``setattr`` ``impl`` onto the attribute identified by the dotted path. + + Supports two forms: + - ``pkg.mod.attr`` (set module attribute) + - ``pkg.mod.ClassName.attr`` (set class attribute / method) + + The split is found by walking the prefix from the longest importable + module backwards until ``importlib.import_module`` succeeds. + """ + parts = dotted_path.split('.') + if len(parts) < 2: + raise ValueError(f"Expected at least 'pkg.attr', got: {dotted_path!r}") + + # Find the longest prefix that imports as a module. + last_err: ImportError | None = None + module = None + module_depth = 0 + for i in range(len(parts) - 1, 0, -1): + candidate = '.'.join(parts[:i]) + try: + module = importlib.import_module(candidate) + module_depth = i + break + except ImportError as e: + last_err = e + continue + if module is None: + raise ImportError(f'Could not import any prefix of {dotted_path!r}') from last_err + + # Walk remaining attributes; the last one is the target. 
+ obj = module + for attr in parts[module_depth:-1]: + obj = getattr(obj, attr) + setattr(obj, parts[-1], impl) def _load_hub_ref(ref: HubRef): diff --git a/tests/kernel/test_builtin.py b/tests/kernel/test_builtin.py new file mode 100644 index 00000000..3b4e68d4 --- /dev/null +++ b/tests/kernel/test_builtin.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + +import pytest + + +def test_npu_builtin_returns_dict(): + from twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() + assert isinstance(bundle, dict) + assert len(bundle) > 0 + + +def test_npu_builtin_values_are_npu_gated(): + """Every value in npu_builtin() must be wrapped in {'npu': ...} so it's + safely no-op on CUDA/CPU.""" + from twinkle.kernel.builtin import npu_builtin + for key, value in npu_builtin().items(): + assert isinstance(value, dict), f'value for {key!r} is not a device-dict' + assert 'npu' in value, f'value for {key!r} is missing npu entry' + + +def test_npu_builtin_compose_with_user_override(): + """User-supplied keys override the builtin (via plain dict merge).""" + from twinkle.kernel.builtin import npu_builtin + sentinel = object() + merged = {**npu_builtin(), 'fake.module.path.fn': sentinel} + assert merged['fake.module.path.fn'] is sentinel + + +def test_npu_builtin_safe_on_cpu_model(): + """kernelize(cpu_model, npu_builtin()) must not raise and not modify.""" + from twinkle.kernel import kernelize + from twinkle.kernel.builtin import npu_builtin + + m = nn.Sequential(nn.Linear(2, 2)) + pre_type = type(m[0]) + out = kernelize(m, npu_builtin()) + assert out is m + assert type(m[0]) is pre_type # no replacement happened (cpu device) + + +def test_npu_builtin_skips_missing_modeling_modules(): + """If transformers.models.qwen3_5 is not installed, the bundle must + still produce a dict (with whatever subset is available).""" + from twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() # must not raise + assert isinstance(bundle, dict) \ No newline at end of 
file diff --git a/tests/kernel/test_replace.py b/tests/kernel/test_replace.py index 5a2ba459..e649b2e3 100644 --- a/tests/kernel/test_replace.py +++ b/tests/kernel/test_replace.py @@ -50,5 +50,25 @@ def test_replace_attr_sets_module_attribute(): new_fn = lambda x: x * 2 # noqa: E731 _replace_attr(f'{mod_name}.target_fn', new_fn) assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_replace_attr_supports_class_attribute(): + import sys + import types + + mod_name = 'tests.kernel._tmp_class_attr' + mod = types.ModuleType(mod_name) + + class Foo: + def forward(self, x): + return x + mod.Foo = Foo + sys.modules[mod_name] = mod + try: + new_forward = lambda self, x: x + 7 # noqa: E731 + _replace_attr(f'{mod_name}.Foo.forward', new_forward) + assert Foo.forward is new_forward finally: sys.modules.pop(mod_name, None) \ No newline at end of file From 39a9225e2cf162f0d9efe69e93edb75a4e623f76 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:27:11 +0800 Subject: [PATCH 14/27] refactor(kernel): expose only kernelize, hub, npu_builtin --- src/twinkle/kernel/__init__.py | 116 +++----------------------------- tests/kernel/test_public_api.py | 22 ++++++ 2 files changed, 31 insertions(+), 107 deletions(-) create mode 100644 tests/kernel/test_public_api.py diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index c7262eb0..5d435c0f 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -1,111 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -"""Twinkle Kernel Module - Kernel orchestration layer.""" -import torch -from logging import getLogger -from typing import Any, Dict, Optional, Union +"""Mapping-driven kernel replacement. 
-from twinkle.utils.framework import Torch -from .base import DeviceType, ModeType, is_kernels_enabled -from .function import apply_function_kernel, register_function_kernel -from .layer import apply_layer_kernel, register_layer_batch, register_layer_kernel -from .monkey_patch_npu import apply_npu_patch, register_npu_fused_function_kernels -from .registry import register_external_layer as _register_external_layer +Three public symbols: -logger = getLogger(__name__) +- :func:`kernelize` apply ``mapping`` to a model +- :func:`hub` build a Hub kernel reference +- :func:`npu_builtin` the Ascend NPU built-in bundle +""" +from .builtin import npu_builtin +from .core import hub, kernelize -__all__ = [ - 'kernelize_model', - 'register_layer_kernel', - 'register_function_kernel', - 'register_external_layer', - 'register_kernels', - 'apply_npu_patch', - 'apply_npu_fused_ops', -] - - -def kernelize_model( - model, - mode: ModeType = 'inference', - device: Optional[DeviceType] = None, - use_fallback: bool = True, -) -> Any: - """Apply kernels to model (main entry point). - - For NPU devices, this also applies Ascend fused operators (RMSNorm, RoPE, - SwiGLU, SDPA Attention) unconditionally when running on NPU. - - Args: - model: The PyTorch model to kernelize. - mode: The mode for kernel selection ("inference" or "train"). - device: The device type (auto-detected if None). - use_fallback: Whether to use original forward when no compatible kernel found. - If False, raises ValueError when kernel is unavailable. - - Returns: - The kernelized model. - """ - # Step 0: NPU monkey-patches must be applied BEFORE layer kernel replacement - # so that patched module classes are used when new instances are created. - if device == 'npu' or (device is None and _is_npu_device(model)): - try: - apply_npu_patch(model) - except Exception: - logger.warning('NPU patch failed. 
Continuing without fused ops.', exc_info=True) - - model = apply_layer_kernel(model, mode=mode, device=device, use_fallback=use_fallback) - - apply_function_kernel(device=device, mode=mode) - - return model - - -def apply_npu_fused_ops(config) -> None: - """Apply NPU fused operators patch manually. - """ - logger.warning('apply_npu_fused_ops(config) is deprecated. ' - 'Use apply_npu_patch() instead, which enables all patches unconditionally.') - apply_npu_patch() - - -def register_external_layer(layer_class: type, kernel_name: str) -> None: - _register_external_layer(layer_class, kernel_name) - - -def register_kernels(config: Dict[str, Dict[str, Any]]) -> None: - """Batch register kernels (framework integration API).""" - if 'layers' in config: - for kernel_name, spec in config['layers'].items(): - device = spec.pop('device', 'cuda') - register_layer_kernel(kernel_name=kernel_name, device=device, **spec) - - if 'functions' in config: - from .function import register_function_batch - - functions = config['functions'] - if isinstance(functions, dict): - function_specs = [] - for func_name, spec in functions.items(): - if not isinstance(spec, dict): - raise TypeError(f'Function spec for {func_name} must be a dict.') - if 'func_name' not in spec: - spec['func_name'] = func_name - function_specs.append(spec) - register_function_batch(function_specs) - else: - register_function_batch(functions) - - -def _is_npu_device(model=None) -> bool: - """Check if the model (or current environment) is on NPU device.""" - # Priority 1: Check model's actual device (kernel-specific inference) - if model is not None: - try: - param_device = next(model.parameters()).device - if param_device.type == 'npu': - return True - except StopIteration: - pass - - # Priority 2: Fallback to global NPU availability - return Torch.is_npu_available() +__all__ = ['kernelize', 'hub', 'npu_builtin'] \ No newline at end of file diff --git a/tests/kernel/test_public_api.py b/tests/kernel/test_public_api.py 
new file mode 100644 index 00000000..f9a17a2a --- /dev/null +++ b/tests/kernel/test_public_api.py @@ -0,0 +1,22 @@ +def test_public_exports_exactly_three_symbols(): + import twinkle.kernel as k + assert sorted(k.__all__) == ['hub', 'kernelize', 'npu_builtin'] + assert callable(k.kernelize) + assert callable(k.npu_builtin) + assert callable(k.hub) + + +def test_no_legacy_symbols(): + """Legacy registrar / patch helpers must be gone.""" + import twinkle.kernel as k + legacy = [ + 'kernelize_model', 'register_layer_kernel', 'register_function_kernel', + 'register_kernels', 'register_external_layer', 'apply_npu_patch', + 'apply_npu_fused_ops', 'apply_function_kernel', 'apply_layer_kernel', + 'register_layer_batch', 'register_npu_fused_function_kernels', + 'get_global_layer_registry', 'get_global_function_registry', + 'get_global_external_layer_registry', 'LayerRegistry', + 'ExternalLayerRegistry', 'FunctionRegistry', + ] + for name in legacy: + assert not hasattr(k, name), f'unexpected legacy symbol: {name}' \ No newline at end of file From f4c491f4d8a481b261dda5d8d1662bc61614af3c Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:30:00 +0800 Subject: [PATCH 15/27] refactor(kernel): remove legacy registry/function/layer/base/monkey_patch_npu --- .../\345\206\205\346\240\270/Kernel.md" | 12 +- src/twinkle/kernel/base.py | 81 -- src/twinkle/kernel/chunk_gated_delta_rule.py | 2 +- src/twinkle/kernel/function.py | 174 --- src/twinkle/kernel/layer.py | 119 -- src/twinkle/kernel/monkey_patch_npu.py | 1009 ----------------- src/twinkle/kernel/registry.py | 183 --- .../sequence_parallel/linear_attention_sp.py | 2 +- tests/kernel/test_function_kernel.py | 265 ----- tests/kernel/test_kernel.py | 352 ------ 10 files changed, 8 insertions(+), 2191 deletions(-) delete mode 100644 src/twinkle/kernel/base.py delete mode 100644 src/twinkle/kernel/function.py delete mode 100644 src/twinkle/kernel/layer.py delete mode 100644 
src/twinkle/kernel/monkey_patch_npu.py delete mode 100644 src/twinkle/kernel/registry.py delete mode 100644 tests/kernel/test_function_kernel.py delete mode 100644 tests/kernel/test_kernel.py diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" index 89ae37ca..7f5c9f3a 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" @@ -48,12 +48,12 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP # 1) 从本地仓库注册层内核 register_layer_kernel( - kernel_name="MyAwesomeMLP", - repo_path="/path/to/local/repo", - package_name="my_kernels", - layer_name="Qwen2MLPTrainingKernel", - device="cuda", - mode="train", + kernel_name="MyAwesomeMLP",  # 取的kernel名字,自定义 + repo_path="/path/to/local/repo",  # 本地kernel仓库路径 + package_name="my_kernels",  # 包名 + layer_name="Qwen2MLPTrainingKernel",  # 对应layer.py里面实现类的名字 + device="cuda",  # 适用的设备类型 + mode="train",  # 使用的场景:train or inference ) # 2) 绑定外部层与内核名 diff --git a/src/twinkle/kernel/base.py b/src/twinkle/kernel/base.py deleted file mode 100644 index 6da669d5..00000000 --- a/src/twinkle/kernel/base.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-"""Kernel module base - Base classes, env vars, device detection.""" -import os -from typing import Any, Literal, Optional - -from twinkle import exists - -ModeType = Literal['train', 'inference', 'compile'] -DeviceType = Literal['cuda', 'npu', 'mps', 'cpu', 'rocm', 'metal'] - - -def _kernels_enabled() -> bool: - """Check if kernels are enabled (default: enabled).""" - env_val = os.getenv('TWINKLE_USE_KERNELS', 'YES').upper() - return env_val in ('YES', 'TRUE', '1', 'ON') - - -def _trust_remote_code() -> bool: - """Check if remote code is trusted (default: not trusted).""" - env_val = os.getenv('TWINKLE_TRUST_REMOTE_CODE', 'NO').upper() - return env_val in ('YES', 'TRUE', '1', 'ON') - - -def detect_backend() -> Optional[str]: - """Detect training framework backend: "transformers" | "megatron" | None.""" - if exists('transformers'): - return 'transformers' - return None - - -def is_kernels_available() -> bool: - """Check if HF kernels package is available.""" - return exists('kernels') - - -def is_kernels_enabled() -> bool: - """Check if kernels are enabled by env var.""" - return _kernels_enabled() and is_kernels_available() - - -def to_kernels_mode(mode: ModeType) -> Any: - """Convert Twinkle mode to HF kernels mode.""" - if not is_kernels_available(): - return None - from kernels import Mode - if isinstance(mode, Mode): - return mode - mode_map = { - 'train': Mode.TRAINING, - 'inference': Mode.INFERENCE, - 'compile': Mode.TORCH_COMPILE, - } - return mode_map.get(mode, Mode.INFERENCE) - - -def validate_mode(mode: str) -> None: - from kernels.layer.mode import Mode - mode = to_kernels_mode(mode) - - if mode == Mode.FALLBACK: - raise ValueError('Mode.FALLBACK can only be used to register kernel mappings.') - if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator] - raise ValueError('kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.') - - -def supports_mode(target: object, mode: str) -> bool: - from kernels.layer.mode import 
Mode - mode = to_kernels_mode(mode) - if Mode.TORCH_COMPILE in mode and not getattr(target, 'can_torch_compile', False): - return False - if Mode.TRAINING in mode and not getattr(target, 'has_backward', True): - return False - return True - - -def validate_device_type(device_type: str) -> None: - supported_devices = {'cpu', 'cuda', 'mps', 'npu', 'rocm', 'xpu'} - if device_type not in supported_devices: - raise ValueError('Unsupported device type ' - f"'{device_type}'. Supported device types are: " - f"{', '.join(sorted(supported_devices))}") diff --git a/src/twinkle/kernel/chunk_gated_delta_rule.py b/src/twinkle/kernel/chunk_gated_delta_rule.py index 553fb122..2d0beee7 100644 --- a/src/twinkle/kernel/chunk_gated_delta_rule.py +++ b/src/twinkle/kernel/chunk_gated_delta_rule.py @@ -1,7 +1,7 @@ '''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. -It is consumed by twinkle.kernel.monkey_patch_npu to enable the fast linear-attention +It is consumed by twinkle.kernel.npu_impls.fla to enable the fast linear-attention path of Qwen3.5 on Ascend hardware.''' import torch diff --git a/src/twinkle/kernel/function.py b/src/twinkle/kernel/function.py deleted file mode 100644 index 94a2d817..00000000 --- a/src/twinkle/kernel/function.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-from __future__ import annotations - -import importlib -from typing import TYPE_CHECKING, Callable, Iterable, List, Optional - -from twinkle import get_logger -from .base import ModeType, is_kernels_available, validate_device_type, validate_mode -from .registry import FunctionKernelSpec, get_global_function_registry - -if TYPE_CHECKING: - from kernels.layer.func import FuncRepositoryProtocol - -logger = get_logger() - - -def _load_from_hub( - *, - repo: FuncRepositoryProtocol | None, - repo_id: str | None, - revision: str | None, - version: str | None, - func_name: str, -) -> tuple[Callable, object]: - """Resolve function implementation from a repo or Hub repo_id.""" - if repo is not None: - module_cls = repo.load() - module_instance = module_cls() - - def impl(*args, **kwargs): - return module_instance(*args, **kwargs) - - return impl, module_instance - - from kernels._versions import select_revision_or_version - from kernels.utils import get_kernel - assert repo_id is not None - # kernels API changed across versions; use keyword args for modern API - # and fall back to repo_id-only for older variants. 
- try: - resolved = select_revision_or_version(repo_id, revision=revision, version=version) - except TypeError: - resolved = select_revision_or_version(repo_id) - try: - kernel = get_kernel(repo_id, revision=resolved) - except TypeError: - kernel = get_kernel(repo_id, resolved) - func = getattr(kernel, func_name, None) - if func is None: - raise AttributeError(f'Kernel repo {repo_id} does not export {func_name}.') - return func, func - - -def register_function_kernel( - *, - func_name: str, - target_module: str, - func_impl: Callable | None = None, - repo: FuncRepositoryProtocol | None = None, - repo_id: str | None = None, - revision: str | None = None, - version: str | None = None, - device: str | None = None, - mode: ModeType | None = None, -) -> None: - """Register a function kernel with the registry.""" - sources = [func_impl is not None, repo is not None, repo_id is not None] - if sum(sources) != 1: - raise ValueError('Provide exactly one of func_impl, repo, or repo_id.') - if revision is not None and version is not None: - raise ValueError('Either revision or version must be specified, not both.') - if mode is not None: - validate_mode(mode) - - get_global_function_registry().register( - FunctionKernelSpec( - func_name=func_name, - target_module=target_module, - func_impl=func_impl, - repo=repo, - repo_id=repo_id, - revision=revision, - version=version, - device=device, - mode=mode, - )) - - -def register_function_batch(function_registry: Iterable[dict]) -> None: - """Batch register function kernels from a list of spec dicts.""" - for spec in function_registry: - register_function_kernel( - func_name=spec['func_name'], - target_module=spec['target_module'], - func_impl=spec.get('func_impl'), - repo=spec.get('repo'), - repo_id=spec.get('repo_id'), - revision=spec.get('revision'), - version=spec.get('version'), - device=spec.get('device'), - mode=spec.get('mode'), - ) - - -def apply_function_kernel( - *, - target_module: str | None = None, - device: str | None 
= None, - mode: ModeType | None = None, - strict: bool = False, -) -> list[str]: - """Apply registered function kernels by monkey-patching target modules. - target_module: If specified, only apply kernels targeting this module. - device: If specified, only apply kernels matching this device or with no device. - mode: If specified, only apply kernels matching this mode or with no mode. - strict: If True, raise errors on failures; otherwise log warnings. - """ - applied = [] - if device is not None: - validate_device_type(device) - - for spec in get_global_function_registry().list_specs(): - # Filter by target module and device/mode constraints. - if target_module is not None and spec.target_module != target_module: - continue - if device is not None and spec.device is not None and spec.device != device: - continue - if spec.mode is not None and mode is None: - msg = ('Function kernel registered with mode but apply_function_kernel ' - 'was called without mode; skipping.') - if strict: - raise ValueError(msg) - logger.warning(msg) - continue - if spec.mode is not None and mode is not None and spec.mode != mode: - continue - - try: - # Import the module that will be monkey-patched. - module = importlib.import_module(spec.target_module) - except Exception as exc: - if strict: - raise - logger.warning( - 'Failed to import target module %s: %s', - spec.target_module, - exc, - ) - continue - - # Resolve implementation and capability target for mode checks. - if spec.func_impl is not None: - impl = spec.func_impl - else: - if not is_kernels_available(): - msg = ('HF kernels package not available. ' - f'Cannot load function kernel: {spec.func_name}. ' - 'Install it with `pip install kernels`.') - raise RuntimeError(msg) - impl, _ = _load_from_hub( - repo=spec.repo, - repo_id=spec.repo_id, - revision=spec.revision, - version=spec.version, - func_name=spec.func_name, - ) - # Final patch (or reapply when no mode gating is used). 
- setattr(module, spec.func_name, impl) - applied.append(f'{spec.target_module}.{spec.func_name}') - - if strict and not applied: - raise ValueError('No function kernels applied for the given filters.') - - return applied diff --git a/src/twinkle/kernel/layer.py b/src/twinkle/kernel/layer.py deleted file mode 100644 index e47f7392..00000000 --- a/src/twinkle/kernel/layer.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Kernel module layer - Layer-level replacement with HF kernels integration.""" -from pathlib import Path -from typing import Any, Optional, Union - -from twinkle import Platform, get_logger -from .base import DeviceType, ModeType, is_kernels_available, is_kernels_enabled, to_kernels_mode -from .registry import get_global_layer_registry, register_layer - -logger = get_logger() - - -def register_layer_kernel( - kernel_name: str, - repo_id: Optional[str] = None, - repo_path: Optional[Union[str, Path]] = None, - package_name: Optional[str] = None, - layer_name: Optional[str] = None, - version: Optional[str] = None, - device: DeviceType = 'cuda', - mode: Optional[ModeType] = None, -) -> None: - """Register a layer kernel with the registry. - - Args: - kernel_name: Unique kernel name (can register multiple modes with same name) - repo_id: Hub repository ID - repo_path: Local repository path - package_name: Package name (required when using repo_path) - layer_name: Layer name (defaults to kernel_name) - version: Version constraint - device: Device type - mode: Mode (train/inference/compile), None means FALLBACK - """ - if not is_kernels_available(): - logger.warning(f'HF kernels package not available. 
Skipping registration for kernel: {kernel_name}') - return - - from kernels import LayerRepository, LocalLayerRepository - - if repo_path is not None: - if package_name is None: - raise ValueError(f'package_name must be provided when using repo_path for kernel: {kernel_name}') - if isinstance(repo_path, str): - repo_path = Path(repo_path) - repo_spec = LocalLayerRepository( - repo_path=repo_path, - package_name=package_name, - layer_name=layer_name or kernel_name, - ) - else: - if repo_id is None: - raise ValueError(f'Either repo_id or repo_path must be provided for kernel: {kernel_name}') - repo_spec = LayerRepository( - repo_id=repo_id, - layer_name=layer_name or kernel_name, - version=version, - ) - - hf_mode = _to_hf_mode(mode) - register_layer(kernel_name, repo_spec, device, mode=hf_mode) - - mode_str = mode or 'FALLBACK' - logger.info(f'Registered layer kernel: {kernel_name} for device: {device}, mode: {mode_str}') - - -def _to_hf_mode(mode: Optional[ModeType]) -> Any: - """Convert Twinkle mode to HF kernels Mode.""" - if mode is None: - from kernels import Mode - return Mode.FALLBACK - return to_kernels_mode(mode) - - -def apply_layer_kernel( - model, - mode: ModeType = 'inference', - device: Optional[DeviceType] = None, - use_fallback: bool = True, -) -> Any: - """Apply layer kernels to model. - - Args: - model: The PyTorch model to kernelize. - mode: The mode for kernel selection ("inference" or "train"). - device: The device type (auto-detected if None). - use_fallback: Whether to use original forward when no compatible kernel found. - If False, raises ValueError when kernel is unavailable. - - Returns: - The kernelized model. 
- """ - if not is_kernels_enabled(): - logger.debug('Kernels not enabled, returning original model') - return model - - get_global_layer_registry().sync_to_hf_kernels() - - if device is None: - device = Platform.get_platform().device_prefix() or 'cuda' - - kernel_mode = to_kernels_mode(mode) - - try: - from kernels import kernelize - logger.debug(f'Applying kernels with mode: {mode}, device: {device}, use_fallback: {use_fallback}') - return kernelize(model, mode=kernel_mode, device=device, use_fallback=use_fallback) - except Exception as e: - if use_fallback: - logger.warning(f'Failed to apply kernels: {e}. Returning original model.') - return model - raise - - -def register_layer_batch(mapping: dict, default_device: DeviceType = 'cuda') -> None: - """Batch register layer kernels.""" - for kernel_name, spec in mapping.items(): - device = spec.pop('device', default_device) - register_layer_kernel(kernel_name=kernel_name, device=device, **spec) diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py deleted file mode 100644 index 01e51b06..00000000 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ /dev/null @@ -1,1009 +0,0 @@ -"""NPU monkey patches for Ascend hardware acceleration. - -Unified entry point:: - - >>> from twinkle.kernel.monkey_patch_npu import apply_npu_patch - >>> if Torch.is_npu_available(): - ... 
apply_npu_patch(model) -""" - -import importlib -import os -import torch -import torch.nn.functional as F -from torch import nn -from transformers.utils import is_torch_npu_available - -from twinkle import get_logger -from .causal_conv1d import npu_causal_conv1d_fn - -logger = get_logger() - -_is_torch_npu_available = is_torch_npu_available() -_NPU_PATCH_APPLIED = False - -if _is_torch_npu_available: - import torch_npu - -# --------------------------------------------------------------------------- -# Utils -# --------------------------------------------------------------------------- - - -def import_optional_module(module_name: str): - """Import a module, returning None if unavailable.""" - try: - return importlib.import_module(module_name) - except ImportError as exc: - logger.debug('Failed to import optional module %s: %s', module_name, exc) - return None - - -def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): - if isinstance(position_ids, int) and unsqueeze_dim == 1: - return position_ids - return unsqueeze_dim - - -def _is_ep_enabled(model=None) -> bool: - r"""Check whether Expert Parallelism (EP) is enabled. - - EP is detected via ``device_mesh.ep_size > 1``. - When EP is active, each rank holds only a subset of expert weights, - making ``npu_grouped_matmul`` efficient (small contiguous weights). 
- """ - device_mesh = getattr(model, 'device_mesh', None) - if device_mesh is None: - return False - return (getattr(device_mesh, 'ep_size', None) or 0) > 1 - - -# ============================================================================= -# Section 1: MoE Grouped MatMul (GMM) -# ============================================================================= - - -class GmmFunction(torch.autograd.Function): - r"""Custom autograd function for NPU grouped matrix multiplication.""" - - @staticmethod - def forward(ctx, x: torch.tensor, group_list: torch.tensor, weight_ekn: torch.tensor): - group_list = group_list.to(torch.int64) - ctx.save_for_backward(x, group_list, weight_ekn) - outputs = torch_npu.npu_grouped_matmul( - [x], - [weight_ekn], - group_list=group_list, - group_type=0, - split_item=2, - group_list_type=1, - ) - return outputs[0] - - @staticmethod - def backward(ctx, grad_output: torch.tensor): - x, group_list, weight_ekn = ctx.saved_tensors - grad_input = torch_npu.npu_grouped_matmul( - [grad_output], - [weight_ekn.transpose(-2, -1).contiguous()], - bias=None, - group_list=group_list, - group_type=0, - split_item=2, - group_list_type=1, - )[0] - grad_weight = torch_npu.npu_grouped_matmul( - [x.transpose(0, 1)], - [grad_output], - bias=None, - group_list=group_list, - group_type=2, - split_item=3, - group_list_type=1, - )[0] - return grad_input, None, grad_weight.contiguous() - - -def _grouped_mm_npu(input: torch.tensor, weight_ekn: torch.tensor, offs: torch.tensor) -> torch.tensor: - counts = torch.empty_like(offs) - counts[0] = offs[0] - if offs.numel() > 1: - counts[1:] = offs[1:] - offs[:-1] - counts = counts.to(torch.int64) - return GmmFunction.apply(input, counts, weight_ekn) - - -def _apply_hf_moe_grouped_mm_patch(model=None) -> None: - r"""Patch HuggingFace MoE integration to use NPU grouped matmul. - - When Expert Parallelism (EP) is **not** enabled, each rank holds **all** - expert weights. 
``weight.transpose(-2, -1)`` then produces a large - non-contiguous view that ``npu_grouped_matmul`` forces to ``.contiguous()`` - (~12.88 GB per MoE layer), creating a bandwidth bottleneck that makes the - NPU patch **slower** than the native per-expert fallback (~8x overhead). - - Detection logic: - - ``TWINKLE_NPU_GMM_PATCH`` not set → **skip** the patch by default. - - ``TWINKLE_NPU_GMM_PATCH=1`` → EP-aware: apply only if EP is enabled - (each rank has few experts, weights are small and contiguous); - skip if EP is **not** enabled (avoid ~8x overhead). - - ``TWINKLE_NPU_GMM_PATCH=0`` → **disable** the patch regardless. - """ - moe_enabled = _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=False) - - if not moe_enabled: - has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') - logger.info( - '[PATCH] TWINKLE_NPU_GMM_PATCH not set: MoE GMM patch skipped by default. ' - 'Set TWINKLE_NPU_GMM_PATCH=1 to enable (EP-aware). ' - 'Native grouped_mm available: %s.', - has_native_gmm, - ) - return - - if not _is_ep_enabled(model): - has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') - logger.info( - '[PATCH] TWINKLE_NPU_GMM_PATCH=1 but EP not enabled (all experts on each rank) — ' - 'skipping _grouped_mm_npu patch to avoid ~8x overhead from ' - 'contiguous copies on transposed weights. 
' - 'Native grouped_mm available: %s.', - has_native_gmm, - ) - return - - import transformers.integrations.moe as hf_moe - hf_moe._grouped_mm = _grouped_mm_npu - logger.info('[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu') - - -# ============================================================================= -# Section 1b: MoE Packed Experts -# ============================================================================= - - -def _normalize_packed_expert_weights(module, input_dtype: torch.dtype, hidden_dim: int): - """Normalize packed expert weight shapes for NPU grouped matmul.""" - gate_up_proj = module.gate_up_proj.to(input_dtype) - down_proj = module.down_proj.to(input_dtype) - - if gate_up_proj.shape[1] == hidden_dim: - gate_up_weight = gate_up_proj - elif gate_up_proj.shape[2] == hidden_dim: - gate_up_weight = gate_up_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported gate_up_proj shape for NPU MoE patch: {tuple(gate_up_proj.shape)}.') - - if down_proj.shape[2] == hidden_dim: - down_weight = down_proj - elif down_proj.shape[1] == hidden_dim: - down_weight = down_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported down_proj shape for NPU MoE patch: {tuple(down_proj.shape)}.') - - return gate_up_weight, down_weight - - -def _get_cached_expert_weights(self, target_dtype: torch.dtype, hidden_dim: int): - """Return normalized expert weights with automatic cache invalidation. - - Cache key combines (dtype, gate_version, down_version). This correctly - handles: - - Full-parameter training: optimizer in-place updates bump _version - - LoRA training: frozen weights keep _version stable, cache persists - - Inference: cache is permanent - - AMP autocast: separate cache per dtype - - Safety: when weights require gradients, the cache is bypassed to avoid - breaking the PyTorch autograd graph (non-leaf tensors from .to() cannot - be reused across forward passes). 
- """ - requires_grad = ( - getattr(self.gate_up_proj, 'requires_grad', False) or getattr(self.down_proj, 'requires_grad', False)) - cache_attr = '_npu_expert_cache' - if not requires_grad and hasattr(self, cache_attr): - cached_dtype, cached_gate_ver, cached_down_ver, cached = getattr(self, cache_attr) - if (cached_dtype == target_dtype and cached_gate_ver == self.gate_up_proj._version - and cached_down_ver == self.down_proj._version): - return cached - - weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) - if not requires_grad: - setattr( - self, - cache_attr, - (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights), - ) - return weights - - -def npu_packed_moe_experts_forward( - self, - hidden_states: torch.Tensor, - router_indices_or_routing_weights: torch.Tensor, - routing_weights_or_router_indices: torch.Tensor, -) -> torch.Tensor: - """Packed MoE experts forward using NPU grouped matmul. - - Compatible with Qwen3-MoE, Qwen3.5-MoE, and any model using packed experts - with the standard ``(hidden_states, router_indices, routing_weights)`` call convention. 
- """ - if router_indices_or_routing_weights.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: - router_indices = router_indices_or_routing_weights - routing_weights = routing_weights_or_router_indices - else: - routing_weights = router_indices_or_routing_weights - router_indices = routing_weights_or_router_indices - - output_shape = hidden_states.shape - hidden_dim = output_shape[-1] - hidden_states = hidden_states.reshape(-1, hidden_dim) - - if routing_weights.shape != router_indices.shape: - routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) - routing_weights = routing_weights.to(hidden_states.dtype) - router_indices = router_indices.to(torch.int32) - - permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) - tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) - - # Cached normalized weights: auto-invalidates on weight updates (full-param) - # and persists when frozen (LoRA / inference). 
- gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) - - intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, tokens_per_expert, gate_up_weight) - intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1) - output = GmmFunction.apply(intermediate_activations, tokens_per_expert, down_weight) - next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) - return next_states.view(*output_shape) - - -# ============================================================================= -# Section 1c: MoE Sparse Block -# ============================================================================= - - -def _topk_from_router_logits(module, hidden_states: torch.Tensor, router_logits: torch.Tensor): - """Compute top-k routing from router logits (Transformers 4.x style).""" - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) - if getattr(module, 'norm_topk_prob', True): - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - return routing_weights, router_indices - - -def _add_shared_expert(self, hidden_states: torch.Tensor, expert_output: torch.Tensor) -> torch.Tensor: - """Add shared expert output with sigmoid gating. - - Automatically skips if the module lacks shared_expert / shared_expert_gate. 
- """ - if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): - return expert_output - - shared_expert_output = self.shared_expert(hidden_states) - shared_expert_output = (F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output) - return expert_output + shared_expert_output - - -def _qwen3_5_moe_forward_transformers_5(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, - selected_experts: torch.Tensor) -> torch.Tensor: - """Transformers 5.x path: gate returns (router_logits, routing_weights, selected_experts).""" - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - expert_output = self.experts(hidden_states, selected_experts, routing_weights) - expert_output = _add_shared_expert(self, hidden_states, expert_output) - return expert_output.reshape(batch_size, sequence_length, hidden_dim) - - -def _qwen3_5_moe_forward_linear_gate(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: - """Transformers 4.x path: gate is nn.Linear and returns router logits.""" - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - routing_weights, selected_experts = _topk_from_router_logits(self, hidden_states, router_logits) - expert_output = self.experts(hidden_states, selected_experts, routing_weights) - expert_output = _add_shared_expert(self, hidden_states, expert_output) - return expert_output.reshape(batch_size, sequence_length, hidden_dim) - - -def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """NPU-accelerated SparseMoeBlock forward with dual Transformers version support.""" - hidden_dim = hidden_states.shape[-1] - gate_output = self.gate(hidden_states.view(-1, hidden_dim)) - - if isinstance(gate_output, tuple): - _, routing_weights, selected_experts = gate_output - return _qwen3_5_moe_forward_transformers_5(self, 
hidden_states, routing_weights, selected_experts) - - return _qwen3_5_moe_forward_linear_gate(self, hidden_states, gate_output) - - -# ============================================================================= -# Section 2: Fused Operators (RMSNorm / RoPE / SwiGLU / SDPA) -# ============================================================================= - - -class NpuRMSNorm(nn.Module): - r"""Fused RMSNorm via ``torch_npu.npu_rms_norm``.""" - - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - # Detect residual parameterization (e.g. Qwen3.5: scale = 1.0 + weight) - # once at initialization to avoid CPU-synchronizing Tensor.item() calls. - self._residual_param = abs(self.weight.data.mean().item()) < 0.3 - if self._residual_param: - logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') - - def _get_effective_weight(self, target_dtype: torch.dtype): - if self._residual_param: - return (1.0 + self.weight).to(dtype=target_dtype) - return self.weight.to(dtype=target_dtype) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - scale = self._get_effective_weight(hidden_states.dtype) - return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self.variance_epsilon)[0] - - def extra_repr(self) -> str: - return f'{tuple(self.weight.shape)}, eps={self.variance_epsilon}' - - -def npu_gated_rms_norm_forward(self, hidden_states, gate=None): - """NPU forward for Gated RMSNorm. - - The FP32 mode is controlled by ``TWINKLE_NPU_GATED_RMSNorm_FP32``, - resolved once during patching and stored in ``self._twinkle_force_fp32``. - """ - input_dtype = hidden_states.dtype - _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - - # Read the cached flag; no env lookup in the hot path. 
- force_fp32 = getattr(self, '_twinkle_force_fp32', False) - if force_fp32: - hidden_states = hidden_states.to(torch.float32) - weight = self.weight.float() - gate = gate.to(torch.float32) if gate is not None else None - else: - weight = self.weight - - hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] - - if gate is not None: - hidden_states = hidden_states * F.silu(gate) - - return hidden_states.to(input_dtype) - - -def _make_apply_npu_rotary_emb(): - _cached_partial = {} - - def _apply_npu_rotary_emb(q, k, cos, sin): - rotary_dim = cos.shape[-1] - query_dim = q.shape[-1] - shape_key = (rotary_dim, query_dim) - - use_partial = _cached_partial.get(shape_key) - if use_partial is None: - use_partial = rotary_dim < query_dim - _cached_partial[shape_key] = use_partial - - if use_partial: - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - else: - q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) - - return q_embed, k_embed - - return _apply_npu_rotary_emb - - -_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() - - -def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Fused RoPE via ``torch_npu.npu_rotary_mul`` with automatic Partial RoPE support.""" - unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) - - -def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Multimodal RoPE for Qwen2.5-VL with automatic Partial RoPE support.""" - mrope_section = mrope_section * 2 - 
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) - - -def npu_swiglu_forward(self, hidden_state): - """Fused SwiGLU (Qwen-style).""" - return self.down_proj( - torch_npu.npu_swiglu( - torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), - dim=-1, - )) - - -def npu_sdpa_attention_forward(module, - query, - key, - value, - attention_mask, - dropout=0.0, - scaling=None, - is_causal=None, - **kwargs): - r"""SDPA with NPU compatibility fixes.""" - from transformers.integrations.sdpa_attention import repeat_kv - if hasattr(module, 'num_key_value_groups'): - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, :key.shape[-2]] - - query, key, value = query.contiguous(), key.contiguous(), value.contiguous() - - if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None - - if causal_mask is not None and causal_mask.dtype != torch.bool: - causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=dropout, - scale=scaling, - is_causal=is_causal, - ) - return attn_output.transpose(1, 2).contiguous(), None - - -# ============================================================================= -# Section 2c: Flash Linear Attention (FLA) for Qwen3.5 -# ============================================================================= - - -def _patch_qwen3_5_fla(model=None) -> None: - """Enable Flash Linear Attention (FLA) fast path for Qwen3.5 on NPU. 
- - Controlled by environment variable ``TWINKLE_NPU_FLA`` (default: True). - """ - if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): - logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA environment variable') - return - - if not _is_torch_npu_available: - logger.info('[NPU] [FLA] Skip: NPU not available') - return - - # 1. Force FLA availability flag - def _is_fla_available() -> bool: - return True - - for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): - try: - utils_mod = importlib.import_module(utils_mod_name) - setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) - logger.info( - '[NPU] [FLA] Patched %s.is_flash_linear_attention_available', - utils_mod_name, - ) - except Exception as exc: - logger.debug('[NPU] [FLA] Failed to patch %s: %s', utils_mod_name, exc) - - # 2. Try MindSpeed Triton FLA backend - mindspeed_fla = None - try: - from .chunk_gated_delta_rule import chunk_gated_delta_rule as _ms_fla - mindspeed_fla = _ms_fla - logger.info('[NPU] [FLA] MindSpeed Triton chunk_gated_delta_rule loaded') - except ImportError as exc: - logger.warning('[NPU] [FLA] MindSpeed not available: %s', exc) - - # 3. 
Patch Qwen3.5 modeling modules - fla_target_modules = [ - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - ] - - for module_name in fla_target_modules: - module = import_optional_module(module_name) - if module is None: - logger.info('[NPU] [FLA] %s: module not found, skip', module_name) - continue - - # Only enable FLA flags if we actually have a backend to serve it - if mindspeed_fla is not None: - setattr(module, 'is_flash_linear_attention_available', _is_fla_available) - setattr(module, 'is_fast_path_available', True) - - # Disable CUDA-only fused op - if hasattr(module, 'FusedRMSNormGated'): - setattr(module, 'FusedRMSNormGated', None) - logger.info('[NPU] [FLA] %s: disabled FusedRMSNormGated', module_name) - - # Replace chunk_gated_delta_rule with MindSpeed implementation - setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) - logger.info( - '[NPU] [FLA] Patched %s.chunk_gated_delta_rule -> MindSpeed', - module_name, - ) - else: - logger.warning( - '[NPU] [FLA] %s: MindSpeed unavailable, FLA flags NOT set', - module_name, - ) - - # 4. 
Traverse instantiated model and replace per-layer chunk_gated_delta_rule - if model is not None and mindspeed_fla is not None: - # Resolve the underlying PyTorch model from TransformersModel wrapper - model = getattr(model, 'model', getattr(model, 'module', model)) - if not hasattr(model, 'named_modules'): - logger.warning('[NPU] [FLA] Model does not support named_modules, skipping instance patch') - return - patched_instances = 0 - patched_causal = 0 - for _name, _module in model.named_modules(): - if hasattr(_module, 'chunk_gated_delta_rule') and callable(getattr(_module, 'chunk_gated_delta_rule')): - if _module.chunk_gated_delta_rule is mindspeed_fla: - continue - - _module.chunk_gated_delta_rule = mindspeed_fla - # Mark as NPU-patched to prevent it from being overwritten by SP - _module._twinkle_npu_patched = True - patched_instances += 1 - logger.debug( - '[NPU] [FLA] Replaced %s(%s).chunk_gated_delta_rule -> MindSpeed', - _name, - type(_module).__name__, - ) - - if hasattr(_module, 'causal_conv1d_fn'): - current = getattr(_module, 'causal_conv1d_fn') - - if current is npu_causal_conv1d_fn: - continue - _module.causal_conv1d_fn = npu_causal_conv1d_fn - patched_causal += 1 - logger.debug( - '[NPU] [FLA] Replaced %s(%s).causal_conv1d_fn (was %s) -> MindSpeed', - _name, - type(_module).__name__, - current, - ) - - if patched_instances > 0: - logger.info( - '[NPU] [FLA] Patched %d linear attention instance(s)', - patched_instances, - ) - if patched_causal > 0: - logger.info( - '[NPU] [FLA] Patched %d causal_conv1d instance(s)', - patched_causal, - ) - else: - logger.info('[NPU] [FLA] No causal_conv1d_fn instances found in model') - - -# ============================================================================= -# Section 3: Patching Helpers -# ============================================================================= - - -def _patch_sdpa_forward() -> None: - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface - 
AttentionInterface._global_mapping['sdpa'] = npu_sdpa_attention_forward - ALL_ATTENTION_FUNCTIONS['sdpa'] = npu_sdpa_attention_forward - logger.debug('[NPU] [SDPA] Patched global SDPA attention forward') - - -def _patch_rmsnorm(module, class_name: str) -> None: - """Patch RMSNorm class with NPU-optimized implementation.""" - if 'Gated' in class_name: - orig_cls = getattr(module, class_name) - setattr(orig_cls, 'forward', npu_gated_rms_norm_forward) - - # Cache the FP32 env flag once at patch time to avoid per-forward overhead. - orig_cls._twinkle_force_fp32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', - '0').lower() in ('1', 'true', 'on', 'yes') - if orig_cls._twinkle_force_fp32: - logger.info( - '[NPU] [RMSNorm] %s.%s forced to FP32 mode', - module.__name__, - class_name, - ) - - logger.info( - '[NPU] [RMSNorm] Patched %s.%s.forward -> npu_gated_rms_norm_forward', - module.__name__, - class_name, - ) - else: - setattr(module, class_name, NpuRMSNorm) - logger.info( - '[NPU] [RMSNorm] Patched %s.%s -> NpuRMSNorm', - module.__name__, - class_name, - ) - - -def _patch_rope(module, func_name: str) -> None: - setattr(module, func_name, npu_apply_rotary_pos_emb) - logger.debug( - '[NPU] [RoPE] Patched %s.%s -> npu_apply_rotary_pos_emb', - module.__name__, - func_name, - ) - - -def _patch_swiglu(module, class_name: str) -> None: - setattr(getattr(module, class_name), 'forward', npu_swiglu_forward) - logger.debug( - '[NPU] [MLP] Patched %s.%s.forward -> npu_swiglu_forward', - module.__name__, - class_name, - ) - - -def _patch_moe_sparse_block(module, class_name: str) -> None: - """Patch SparseMoeBlock forward with NPU-optimized implementation.""" - setattr(getattr(module, class_name), 'forward', npu_qwen3_5_moe_sparse_block_forward) - logger.info( - '[NPU] [MoE] Patched %s.%s.forward -> npu_qwen3_5_moe_sparse_block_forward', - module.__name__, - class_name, - ) - - -def _patch_moe_experts(module, class_name: str) -> None: - """Patch packed Experts forward with NPU 
grouped matmul.""" - setattr(getattr(module, class_name), 'forward', npu_packed_moe_experts_forward) - logger.debug( - '[NPU] [MoE] Patched %s.%s.forward -> npu_packed_moe_experts_forward', - module.__name__, - class_name, - ) - - -# ============================================================================= -# Section 4: Environment Control -# ============================================================================= - - -def _is_env_enabled(var_name: str, default: bool = True) -> bool: - """Check whether an environment variable is enabled. - - Supports: ``1``/``true``/``on``/``yes`` (force on), - ``0``/``false``/``off``/``no`` (force off), - unset (use ``default``). - """ - env = os.environ.get(var_name, '').lower().strip() - if not env: - return default - if env in ('1', 'true', 'on', 'yes'): - return True - if env in ('0', 'false', 'off', 'no'): - logger.info('[NPU] %s=%s: disabled.', var_name, env) - return False - return default - - -# ============================================================================= -# Section 5: Unified Patching Logic (Fused Ops) -# ============================================================================= - - -def _apply_all_fused_ops(model=None) -> None: - """Apply fused ops to supported model families.""" - logger.info('[NPU] === _apply_all_fused_ops ENTERED ===') - if not _is_torch_npu_available: - return - - if not _is_env_enabled('TWINKLE_NPU_FUSED_OPS', default=True): - return - - target_archs = set() - if model is not None: - config = getattr(model, 'hf_config', getattr(model, 'config', None)) - archs = getattr(config, 'architectures', None) if config else None - if archs: - target_archs = set(archs) - logger.debug('[NPU] Detected architectures for fused ops: %s', archs) - - logger.info('[NPU] Auto-applying fused ops to supported model families') - - _patch_sdpa_forward() - - model_families = [ - ('transformers.models.qwen3.modeling_qwen3', 'Qwen3', 'Qwen3MLP', 'Qwen3ForCausalLM'), - 
('transformers.models.qwen3_moe.modeling_qwen3_moe', 'Qwen3Moe', 'Qwen3MoeMLP', 'Qwen3MoeForCausalLM'), - ( - 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', - 'Qwen2_5_VL', - 'Qwen2MLP', - 'Qwen2_5_VLForConditionalGeneration', - ), - ( - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - 'Qwen3_5Moe', - 'Qwen3_5MoeMLP', - 'Qwen3MoeForCausalLM', - ), - ] - - modeling_qwen3_5 = import_optional_module('transformers.models.qwen3_5.modeling_qwen3_5') - if modeling_qwen3_5 is not None: - model_families.append(( - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'Qwen3_5', - 'Qwen3_5MLP', - 'Qwen3_5ForCausalLM', - )) - - modeling_qwen3_5_moe = import_optional_module('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe') - if modeling_qwen3_5_moe is not None: - model_families.append(( - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - 'Qwen3_5Moe', - 'Qwen3_5MoeMLP', - 'Qwen3_5MoeForCausalLM', - )) - - patched_count = 0 - for module_name, prefix, mlp_name, trigger_arch in model_families: - try: - module = importlib.import_module(module_name) - - # RMSNorm - rmsnorm_cls = f'{prefix}RMSNorm' - if hasattr(module, rmsnorm_cls): - _patch_rmsnorm(module, rmsnorm_cls) - patched_count += 1 - - # RoPE - if hasattr(module, 'apply_rotary_pos_emb'): - _patch_rope(module, 'apply_rotary_pos_emb') - patched_count += 1 - - # SwiGLU / MLP - if hasattr(module, mlp_name): - _patch_swiglu(module, mlp_name) - patched_count += 1 - - experts_cls = f'{prefix}Experts' - if hasattr(module, experts_cls): - _patch_moe_experts(module, experts_cls) - patched_count += 1 - - sparse_cls = f'{prefix}SparseMoeBlock' - if hasattr(module, sparse_cls): - _patch_moe_sparse_block(module, sparse_cls) - patched_count += 1 - - if prefix == 'Qwen2_5_VL': - if hasattr(module, 'Qwen2_5_VLMLP'): - _patch_swiglu(module, 'Qwen2_5_VLMLP') - patched_count += 1 - setattr(module, 'apply_multimodal_rotary_pos_emb', npu_apply_multimodal_rotary_pos_emb) - logger.debug('[NPU] Patched Qwen2_5_VL 
multimodal RoPE') - - if prefix == 'Qwen3_5': - gated_rmsnorm_cls = f'{prefix}GatedRMSNorm' - if hasattr(module, gated_rmsnorm_cls): - _patch_rmsnorm(module, gated_rmsnorm_cls) - patched_count += 1 - if hasattr(module, 'Qwen3_5VisionMLP'): - _patch_swiglu(module, 'Qwen3_5VisionMLP') - patched_count += 1 - if hasattr(module, 'Qwen3_5VisionRMSNorm'): - _patch_rmsnorm(module, 'Qwen3_5VisionRMSNorm') - patched_count += 1 - - if prefix == 'Qwen3_5Moe': - if hasattr(module, 'Qwen3_5MoeGatedRMSNorm'): - _patch_rmsnorm(module, 'Qwen3_5MoeGatedRMSNorm') - patched_count += 1 - - logger.debug('[NPU] Patched %s fused ops', prefix) - except ImportError: - pass - - if not target_archs: - patched_count += _discover_and_patch_unknown_models() - - _patch_qwen3_5_fla(model) - - logger.info('[NPU] Auto-patched %d components', patched_count) - - -# ============================================================================= -# Section 5b: Dynamic model discovery (no hard-coding) -# ============================================================================= - - -def _discover_and_patch_unknown_models() -> int: - """Dynamically discover and patch additional transformers model families.""" - patched = 0 - already_patched_modules = { - 'transformers.models.qwen3.modeling_qwen3', - 'transformers.models.qwen3_moe.modeling_qwen3_moe', - 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - } - - try: - import transformers.models as models_pkg - except ImportError: - return 0 - - candidate_modules = [] - for model_name in dir(models_pkg): - if model_name.startswith('_'): - continue - modeling_path = f'transformers.models.{model_name}.modeling_{model_name}' - if modeling_path not in already_patched_modules: - candidate_modules.append(modeling_path) - - for module_name in candidate_modules: - module = import_optional_module(module_name) - if module is None: - continue - - has_rmsnorm 
= any('RMSNorm' in attr_name and isinstance(getattr(module, attr_name, None), type) - for attr_name in dir(module)) - has_rope = hasattr(module, 'apply_rotary_pos_emb') - has_mlp = any( - attr_name.endswith('MLP') and isinstance(getattr(module, attr_name, None), type) - for attr_name in dir(module)) - - if not (has_rmsnorm or has_rope or has_mlp): - continue - - for attr_name in dir(module): - if attr_name.startswith('_'): - continue - obj = getattr(module, attr_name, None) - if not isinstance(obj, type): - continue - - if 'RMSNorm' in attr_name and issubclass(obj, nn.Module): - try: - _patch_rmsnorm(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if attr_name.endswith('MLP') and hasattr(obj, 'forward'): - try: - _patch_swiglu(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if attr_name.endswith('Experts') and hasattr(obj, 'forward'): - try: - _patch_moe_experts(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if attr_name.endswith('SparseMoeBlock') and hasattr(obj, 'forward'): - try: - _patch_moe_sparse_block(module, attr_name) - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.%s: %s', module_name, attr_name, exc) - - if has_rope: - try: - _patch_rope(module, 'apply_rotary_pos_emb') - patched += 1 - except Exception as exc: - logger.debug('[NPU] Failed to patch %s.apply_rotary_pos_emb: %s', module_name, exc) - - if patched > 0: - logger.debug('[NPU] Dynamically patched %s', module_name) - - return patched - - -# ============================================================================= -# Section 6: Public API -# ============================================================================= - - -def apply_npu_patch(model=None) -> None: - """Apply 
all NPU patches. - - Ascend NPU optimizations applied: - - MoE grouped_matmul (GMM) - - RMSNorm fused kernel - - RoPE fused kernel - - SwiGLU fused kernel - - SDPA Attention compatibility fixes - - Flash Linear Attention (FLA) for Qwen3.5 - - Causal Conv1D Triton kernel for linear attention - - When ``model`` is **not** provided, the GMM patch is **skipped** by default - (EP cannot be detected without a model instance). - - When ``model`` is provided, the GMM patch is evaluated with EP detection: - - EP enabled → apply GMM patch (efficient on small sharded weights). - - EP not enabled → skip GMM patch (avoid ~8x contiguous-copy overhead). - - Environment variables: - - ``TWINKLE_NPU_PATCH``: overall switch (``1``/``0``) - - ``TWINKLE_NPU_FUSED_OPS``: fused ops switch (``1``/``0``) - - ``TWINKLE_NPU_GMM_PATCH``: MoE GMM switch (``1``/``0``/unset). - When unset: skip the patch by default. - When ``1``: EP-aware — patch is applied **only if EP is enabled**; - without EP the native grouped_mm or per-expert fallback is used - (avoiding ~8x overhead from contiguous copies). - When ``0``: disable the patch regardless. - - ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``) - - ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``) - - Args: - model: Optional model instance. If not provided, GMM patch is skipped. - If provided, GMM patch is evaluated with EP detection on the model. - """ - global _NPU_PATCH_APPLIED - - if not _is_env_enabled('TWINKLE_NPU_PATCH', default=True): - return - - if _NPU_PATCH_APPLIED: - logger.debug('[NPU] Patches already applied, skipping.') - return - - try: - import torch_npu - except ImportError: - logger.warning('torch_npu not available. 
Skipping NPU patches.') - return - - _apply_hf_moe_grouped_mm_patch(model) - - _apply_all_fused_ops(model) - - _NPU_PATCH_APPLIED = True - logger.info('[NPU] All patches applied successfully') - - -def register_npu_fused_function_kernels() -> None: - """Register NPU fused ops as Twinkle function kernels (optional).""" - if not _is_torch_npu_available: - return - - from .function import register_function_kernel - - register_function_kernel( - func_name='apply_rotary_pos_emb', - target_module='transformers.modeling_rope_utils', - func_impl=npu_apply_rotary_pos_emb, - device='npu', - mode='train', - ) - register_function_kernel( - func_name='sdpa_attention_forward', - target_module='transformers.integrations.sdpa_attention', - func_impl=npu_sdpa_attention_forward, - device='npu', - mode='train', - ) - logger.info('[NPU] Registered fused function kernels for training') diff --git a/src/twinkle/kernel/registry.py b/src/twinkle/kernel/registry.py deleted file mode 100644 index d03f510f..00000000 --- a/src/twinkle/kernel/registry.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type - -from twinkle import get_logger -from .base import DeviceType, ModeType, is_kernels_available - -if TYPE_CHECKING: - from kernels.layer.func import FuncRepositoryProtocol - -logger = get_logger() - - -class LayerRegistry: - """Manages kernel registrations and syncs to HF kernels.""" - - def __init__(self): - self._registry: Dict[str, Dict[DeviceType, Dict[Any, Any]]] = {} - self._synced = False - - def register(self, kernel_name: str, repo_spec: Any, device: DeviceType = 'cuda', mode: Any = None) -> None: - if kernel_name not in self._registry: - self._registry[kernel_name] = {} - if device not in self._registry[kernel_name]: - self._registry[kernel_name][device] = {} - self._registry[kernel_name][device][mode] = repo_spec - self._synced = False - - def get(self, kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> Optional[Any]: - if kernel_name not in self._registry: - return None - devices = self._registry[kernel_name] - if device is None: - device = next(iter(devices.keys()), None) - if device is None: - return None - modes = devices.get(device) - if modes is None: - return None - if mode is None: - return next(iter(modes.values()), None) - return modes.get(mode) - - def has(self, kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> bool: - if kernel_name not in self._registry: - return False - devices = self._registry[kernel_name] - if device is None: - return True - if device not in devices: - return False - if mode is None: - return True - return mode in devices[device] - - def list_kernel_names(self) -> List[str]: - return list(self._registry.keys()) - - def sync_to_hf_kernels(self) -> None: - if self._synced or not self._registry: - return - - if not is_kernels_available(): - return - - from kernels import register_kernel_mapping as hf_register_kernel_mapping - - 
hf_register_kernel_mapping({}, inherit_mapping=False) - for kernel_name, device_dict in self._registry.items(): - hf_mapping = {kernel_name: device_dict} - hf_register_kernel_mapping(hf_mapping, inherit_mapping=True) - - self._synced = True - - def _clear(self) -> None: - self._registry.clear() - self._synced = False - - -_global_layer_registry = LayerRegistry() - - -class ExternalLayerRegistry: - """Maps layer classes to kernel names.""" - - def __init__(self): - self._map: Dict[Type, str] = {} - - def register(self, layer_class: Type, kernel_name: str) -> None: - self._map[layer_class] = kernel_name - - def get(self, layer_class: Type) -> Optional[str]: - return self._map.get(layer_class) - - def has(self, layer_class: Type) -> bool: - return layer_class in self._map - - def list_mappings(self) -> List[Tuple[Type, str]]: - return list(self._map.items()) - - def _clear(self) -> None: - self._map.clear() - - -_global_external_layer_registry = ExternalLayerRegistry() - - -@dataclass(frozen=True) -class FunctionKernelSpec: - func_name: str - target_module: str - func_impl: Optional[Callable] - repo: Optional['FuncRepositoryProtocol'] - repo_id: Optional[str] - revision: Optional[str] - version: Optional[str] - device: Optional[str] - mode: Optional[ModeType] - - -class FunctionRegistry: - """Manages function-level kernel registrations.""" - - def __init__(self) -> None: - self._registry: List[FunctionKernelSpec] = [] - - def register(self, spec: FunctionKernelSpec) -> None: - if spec in self._registry: - return - self._registry.append(spec) - - def list_specs(self) -> List[FunctionKernelSpec]: - return list(self._registry) - - def _clear(self) -> None: - self._registry.clear() - - -_global_function_registry = FunctionRegistry() - - -def register_layer(kernel_name: str, repo_spec: Any, device: DeviceType = 'cuda', mode: Any = None) -> None: - _global_layer_registry.register(kernel_name, repo_spec, device, mode) - - -def get_layer_spec(kernel_name: str, device: 
Optional[DeviceType] = None, mode: Any = None) -> Optional[Any]: - return _global_layer_registry.get(kernel_name, device, mode) - - -def list_kernel_names() -> List[str]: - return _global_layer_registry.list_kernel_names() - - -def has_kernel(kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> bool: - return _global_layer_registry.has(kernel_name, device, mode) - - -def register_external_layer(layer_class: Type, kernel_name: str) -> None: - _global_external_layer_registry.register(layer_class, kernel_name) - - if is_kernels_available(): - from kernels import replace_kernel_forward_from_hub - replace_kernel_forward_from_hub(layer_class, kernel_name) - logger.info(f'Registered {layer_class.__name__} -> kernel: {kernel_name}') - else: - logger.warning(f'HF kernels not available. {layer_class.__name__} mapping registered ' - f'but kernel replacement will not work without kernels package.') - - -def get_external_kernel_name(layer_class: Type) -> Optional[str]: - return _global_external_layer_registry.get(layer_class) - - -def get_global_layer_registry() -> LayerRegistry: - return _global_layer_registry - - -def get_global_external_layer_registry() -> ExternalLayerRegistry: - return _global_external_layer_registry - - -def get_global_function_registry() -> FunctionRegistry: - return _global_function_registry diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 4ae580fa..6608a2b8 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -107,7 +107,7 @@ def _torch_causal_conv1d_fn( return out.transpose(1, 2).contiguous() # NPU: MindSpeed Triton causal_conv1d and chunk_gated_delta_rule - # are both patched by monkey_patch_npu at model initialization. 
+ # are both patched by twinkle.kernel.npu_impls.fla at model initialization. # No need to set them here - they are already bound on the module. if getattr(mod, '_twinkle_npu_patched', False): return False diff --git a/tests/kernel/test_function_kernel.py b/tests/kernel/test_function_kernel.py deleted file mode 100644 index 02b35dd4..00000000 --- a/tests/kernel/test_function_kernel.py +++ /dev/null @@ -1,265 +0,0 @@ -import os -import pytest -import sys -import torch -import torch.nn as nn -import torch.nn.functional as F -import types - -try: - import requests -except ImportError: - requests = None - -from twinkle.kernel.base import is_kernels_available -from twinkle.kernel.function import apply_function_kernel, register_function_kernel -from twinkle.kernel.registry import get_global_function_registry - - -def _ensure_test_packages() -> None: - if 'tests' not in sys.modules: - tests_pkg = types.ModuleType('tests') - tests_pkg.__path__ = [] - sys.modules['tests'] = tests_pkg - if 'tests.kernel' not in sys.modules: - kernel_pkg = types.ModuleType('tests.kernel') - kernel_pkg.__path__ = [] - sys.modules['tests.kernel'] = kernel_pkg - - -def _reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - return F.silu(x[..., :d]) * x[..., d:] - - -class TestFunctionKernel: - - def setup_method(self): - if not is_kernels_available(): - pytest.skip('kernels package not available in this environment.') - get_global_function_registry()._clear() - - def teardown_method(self): - get_global_function_registry()._clear() - - def test_flattened_build_replaces_function(self): - if os.environ.get('TWINKLE_SKIP_SLOW_TESTS') == '1': - pytest.skip('TWINKLE_SKIP_SLOW_TESTS=1') - if not torch.cuda.is_available(): - pytest.skip('CUDA not available in this environment.') - try: - import urllib.request - urllib.request.urlopen('https://huggingface.co', timeout=5) - except Exception as e: - pytest.skip(f'HuggingFace unreachable: {e}') - try: - from kernels import 
has_kernel - from kernels._versions import select_revision_or_version - from kernels.utils import get_kernel - except Exception: - pytest.skip('kernels package missing has_kernel.') - if not has_kernel('kernels-test/flattened-build'): - pytest.skip('kernels-test/flattened-build not available.') - try: - revision = select_revision_or_version( - 'kernels-test/flattened-build', - revision=None, - version=None, - ) - get_kernel('kernels-test/flattened-build', revision=revision) - except Exception as exc: - pytest.skip(f'kernels-test/flattened-build cannot be loaded in this env: {exc}') - - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_module' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - try: - applied = apply_function_kernel( - target_module=module_name, - device='cuda', - mode='inference', - ) - except TypeError as e: - if 'select_revision_or_version' in str(e) or 'takes 1 positional argument' in str(e): - pytest.skip(f'kernels API incompatible: {e}') - raise - except Exception as e: - if requests and isinstance(e, (requests.exceptions.SSLError, requests.exceptions.RequestException)): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - raise - - assert applied == [f'{module_name}.silu_and_mul'] - assert temp_module.silu_and_mul is not original - - x = torch.randn(4, 16, device='cuda', dtype=torch.float16) - y_kernel = temp_module.silu_and_mul(x) - y_ref = _reference_silu_and_mul(x) - assert torch.allclose(y_kernel, y_ref, atol=1e-3, rtol=1e-3) - 
except Exception as e: - if requests and isinstance(e, (requests.exceptions.SSLError, requests.exceptions.RequestException)): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e): - pytest.skip(f'Network/HuggingFace unreachable: {e}') - raise - finally: - sys.modules.pop(module_name, None) - - def test_flattened_build_device_filter(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_device' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - applied = apply_function_kernel( - target_module=module_name, - device='cpu', - mode='inference', - ) - - assert applied == [] - assert temp_module.silu_and_mul is original - finally: - sys.modules.pop(module_name, None) - - def test_flattened_build_mode_filter(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_flattened_build_mode' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - applied = apply_function_kernel( - target_module=module_name, - device='cuda', - mode='train', - ) - - assert applied == [] - assert temp_module.silu_and_mul is original - finally: - sys.modules.pop(module_name, None) - - def test_flattened_build_strict_raises_on_no_match(self): - _ensure_test_packages() - 
module_name = 'tests.kernel._tmp_flattened_build_strict' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor) -> torch.Tensor: - return _reference_silu_and_mul(x) - - temp_module.silu_and_mul = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - try: - register_function_kernel( - func_name='silu_and_mul', - target_module=module_name, - repo_id='kernels-test/flattened-build', - device='cuda', - mode='inference', - ) - - with pytest.raises(ValueError): - apply_function_kernel( - target_module=module_name, - device='cpu', - mode='inference', - strict=True, - ) - finally: - sys.modules.pop(module_name, None) - - def test_repo_object_loads_module_class(self): - _ensure_test_packages() - module_name = 'tests.kernel._tmp_repo_object' - temp_module = types.ModuleType(module_name) - - def original(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y - - temp_module.add = original - temp_module.__path__ = [] - sys.modules[module_name] = temp_module - - class MyKernelFunc(nn.Module): - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y + 2 - - class MyFuncRepo: - func_name = 'add' - - def load(self): - return MyKernelFunc - - try: - register_function_kernel( - func_name='add', - target_module=module_name, - repo=MyFuncRepo(), - device='cuda', - mode='inference', - ) - - applied = apply_function_kernel( - target_module=module_name, - device='cuda', - mode='inference', - ) - - assert applied == [f'{module_name}.add'] - assert temp_module.add is not original - x = torch.tensor([1.0]) - y = torch.tensor([2.0]) - assert torch.allclose(temp_module.add(x, y), x + y + 2) - finally: - sys.modules.pop(module_name, None) diff --git a/tests/kernel/test_kernel.py b/tests/kernel/test_kernel.py deleted file mode 100644 index 5b6a658b..00000000 --- a/tests/kernel/test_kernel.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-""" -Kernel module unit tests -""" -import os -import pytest -from unittest.mock import MagicMock, Mock, patch - -from twinkle.kernel import kernelize_model, register_external_layer, register_kernels, register_layer_kernel -from twinkle.kernel.base import is_kernels_available, is_kernels_enabled, to_kernels_mode -from twinkle.kernel.registry import (ExternalLayerRegistry, LayerRegistry, get_global_external_layer_registry, - get_global_function_registry, get_global_layer_registry, get_layer_spec, - register_layer) - - -class TestBase: - """Test base helpers and env vars.""" - - def test_is_kernels_available(self): - """Test kernels availability check.""" - result = is_kernels_available() - assert isinstance(result, bool) - - def test_kernels_enabled_env_var(self): - """Test env var controls kernels enablement.""" - original = os.environ.get('TWINKLE_USE_KERNELS') - try: - os.environ['TWINKLE_USE_KERNELS'] = 'YES' - from twinkle.kernel.base import _kernels_enabled - assert _kernels_enabled() - - os.environ['TWINKLE_USE_KERNELS'] = 'NO' - import importlib - - import twinkle.kernel.base - importlib.reload(twinkle.kernel.base) - from twinkle.kernel.base import _kernels_enabled - assert not _kernels_enabled() - finally: - if original is not None: - os.environ['TWINKLE_USE_KERNELS'] = original - else: - os.environ.pop('TWINKLE_USE_KERNELS', None) - - def test_to_kernels_mode(self): - """Test mode conversion.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - assert to_kernels_mode('train').name == 'TRAINING' - assert to_kernels_mode('inference').name == 'INFERENCE' - assert to_kernels_mode('compile').name == 'TORCH_COMPILE' - - -class TestLayerRegistry: - """Test layer registry.""" - - def setup_method(self): - self.registry = LayerRegistry() - - def test_register_and_get(self): - """Test register and lookup.""" - mock_spec = Mock() - self.registry.register('TestLayer', mock_spec, 'cuda') - - result = self.registry.get('TestLayer', 
'cuda') - assert result == mock_spec - - result = self.registry.get('NonExistent', 'cuda') - assert result is None - - def test_register_multiple_devices(self): - """Test registration for multiple devices.""" - mock_cuda = Mock() - mock_npu = Mock() - - self.registry.register('TestLayer', mock_cuda, 'cuda') - self.registry.register('TestLayer', mock_npu, 'npu') - - assert self.registry.get('TestLayer', 'cuda') == mock_cuda - assert self.registry.get('TestLayer', 'npu') == mock_npu - - def test_get_without_device(self): - """Test lookup without device.""" - mock_spec = Mock() - self.registry.register('TestLayer', mock_spec, 'cuda') - - result = self.registry.get('TestLayer') - assert result == mock_spec - - def test_has(self): - """Test has checks.""" - mock_spec = Mock() - assert not self.registry.has('TestLayer') - - self.registry.register('TestLayer', mock_spec, 'cuda') - assert self.registry.has('TestLayer') - assert self.registry.has('TestLayer', 'cuda') - assert not self.registry.has('TestLayer', 'npu') - - def test_list_kernel_names(self): - """Test listing kernel names.""" - mock_spec = Mock() - self.registry.register('Layer1', mock_spec, 'cuda') - self.registry.register('Layer2', mock_spec, 'cuda') - - names = self.registry.list_kernel_names() - assert sorted(names) == sorted(['Layer1', 'Layer2']) - - -class TestExternalLayerRegistry: - """Test external layer registry.""" - - def setup_method(self): - self.registry = ExternalLayerRegistry() - - def test_register_and_get(self): - """Test register and lookup.""" - mock_class = Mock - self.registry.register(mock_class, 'LlamaAttention') - - result = self.registry.get(mock_class) - assert result == 'LlamaAttention' - - def test_has(self): - """Test has checks.""" - mock_class = Mock - assert not self.registry.has(mock_class) - - self.registry.register(mock_class, 'LlamaAttention') - assert self.registry.has(mock_class) - - def test_list_mappings(self): - """Test list mappings.""" - - class MockClass1: - pass - 
- class MockClass2: - pass - - self.registry.register(MockClass1, 'LlamaAttention') - self.registry.register(MockClass2, 'LlamaMLP') - - mappings = self.registry.list_mappings() - assert len(mappings) == 2 - - -class TestRegisterLayer: - """Test global register helpers.""" - - def setup_method(self): - get_global_layer_registry()._clear() - get_global_function_registry()._clear() - - def test_register_and_get_spec(self): - """Test global register and lookup.""" - mock_spec = Mock() - register_layer('TestLayer', mock_spec, 'cuda') - - result = get_layer_spec('TestLayer', 'cuda') - assert result == mock_spec - - -class TestRegisterLayerKernel: - """Test register_layer_kernel.""" - - def setup_method(self): - get_global_layer_registry()._clear() - - def test_register_without_kernels_package(self): - """Test registration when kernels package missing.""" - with patch('twinkle.kernel.layer.is_kernels_available', return_value=False): - register_layer_kernel('TestLayer', repo_id='test/repo') - assert get_layer_spec('TestLayer') is None - - def test_register_with_kernels_package(self): - """Test registration when kernels package available.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - register_layer_kernel( - kernel_name='TestLayer', - repo_id='kernels-community/test', - ) - - assert get_layer_spec('TestLayer') is not None - - -class TestKernelizeModel: - """Test kernelize_model.""" - - def test_kernelize_without_kernels_enabled(self): - """Test returns original model when kernels disabled.""" - with patch('twinkle.kernel.layer.is_kernels_enabled', return_value=False): - mock_model = Mock() - result = kernelize_model(mock_model) - assert result == mock_model - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_kernelize_without_kernels_available(self, mock_available): - """Test returns original model when kernels unavailable.""" - mock_model = Mock() - result = kernelize_model(mock_model) - assert 
result == mock_model - - -class TestRegisterExternalLayer: - """Test register_external_layer.""" - - def setup_method(self): - get_global_external_layer_registry()._clear() - - def test_register_external_layer(self): - """Test registering external layer.""" - mock_class = Mock - - register_external_layer(mock_class, 'LlamaAttention') - - result = get_global_external_layer_registry().get(mock_class) - assert result == 'LlamaAttention' - - def test_register_external_qwen_layer(self): - """Test registering Qwen2 external layer mapping.""" - try: - from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention - except ImportError: - pytest.skip('transformers package not available') - - register_external_layer(Qwen2Attention, 'LlamaAttention') - - registry = get_global_external_layer_registry() - assert registry.has(Qwen2Attention) - assert registry.get(Qwen2Attention) == 'LlamaAttention' - - def test_register_external_layer_adds_kernel_layer_name(self): - """Test register_external_layer sets kernel_layer_name.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - class TestLayer: - pass - - register_external_layer(TestLayer, 'TestKernel') - - assert hasattr(TestLayer, 'kernel_layer_name') - assert TestLayer.kernel_layer_name == 'TestKernel' - - -class TestRegisterKernels: - """Test register_kernels batch registration.""" - - def setup_method(self): - get_global_layer_registry()._clear() - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_register_layers_without_kernels(self, mock_available): - """Test layer batch registration when kernels missing.""" - config = { - 'layers': { - 'LlamaAttention': { - 'repo_id': 'kernels-community/llama-attention' - }, - 'LlamaMLP': { - 'repo_id': 'kernels-community/llama-mlp' - }, - } - } - - register_kernels(config) - - assert get_layer_spec('LlamaAttention') is None - assert get_layer_spec('LlamaMLP') is None - - def test_register_functions(self): - """Test 
function batch registration.""" - config = { - 'functions': { - 'apply_rotary_pos_emb': { - 'func_impl': Mock, - 'target_module': 'test', - 'device': 'cpu', - 'mode': 'inference', - } - } - } - - register_kernels(config) - specs = get_global_function_registry().list_specs() - assert len(specs) == 1 - spec = specs[0] - assert spec.func_name == 'apply_rotary_pos_emb' - assert spec.target_module == 'test' - assert spec.func_impl == Mock - assert spec.device == 'cpu' - assert spec.mode == 'inference' - - -class TestModeSupport: - """Test mode support.""" - - def setup_method(self): - get_global_layer_registry()._clear() - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_register_with_mode_fallback(self, mock_available): - """Test fallback mode mapping when mode is None.""" - from kernels import Mode - - from twinkle.kernel.layer import _to_hf_mode, register_layer_kernel - - result = _to_hf_mode(None) - assert result == Mode.FALLBACK - - def test_to_hf_mode_conversion(self): - """Test Twinkle mode to HF kernels Mode conversion.""" - if not is_kernels_available(): - pytest.skip('kernels package not available') - - from kernels import Mode - - from twinkle.kernel.layer import _to_hf_mode - - assert _to_hf_mode('train') == Mode.TRAINING - assert _to_hf_mode('inference') == Mode.INFERENCE - assert _to_hf_mode('compile') == Mode.TORCH_COMPILE - - @patch('twinkle.kernel.layer.is_kernels_available', return_value=False) - def test_register_multiple_modes(self, mock_available): - """Test registering multiple modes for the same layer.""" - registry = get_global_layer_registry() - - class MockRepo: - pass - - repo_inference = MockRepo() - repo_training = MockRepo() - - from kernels import Mode - - registry.register('TestLayer', repo_inference, 'cuda', Mode.INFERENCE) - registry.register('TestLayer', repo_training, 'cuda', Mode.TRAINING) - - assert registry.has('TestLayer', 'cuda', Mode.INFERENCE) - assert registry.has('TestLayer', 'cuda', 
Mode.TRAINING) - - result = registry.get('TestLayer', 'cuda', Mode.INFERENCE) - assert result == repo_inference - - -if __name__ == '__main__': - pytest.main([__file__]) From 3fb7071b2da24e54c1e8a0d0c52faa04153f9dce Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:33:20 +0800 Subject: [PATCH 16/27] refactor(cookbook): migrate to new twinkle.kernel API --- cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py | 4 ++-- cookbook/transformers/fsdp2.py | 4 ++-- cookbook/transformers/sp_fsdp_dense.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py index a9f90111..03e962e6 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -17,7 +17,7 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize_model +from twinkle.kernel import kernelize, npu_builtin logger = get_logger() @@ -106,7 +106,7 @@ def train(): ) # npu patch if Torch.is_npu_available(): - model = kernelize_model(model, mode='train', device='npu') + model = kernelize(model, npu_builtin(model)) lora_cfg = _build_lora_config(ENABLE_EP) model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) model.set_optimizer('AdamW', lr=LR, foreach=False) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ad4c917f..dd7c0cb1 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -11,7 +11,7 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize_model +from twinkle.kernel import kernelize, npu_builtin logger = get_logger() args = 
CLI.from_args() @@ -58,7 +58,7 @@ def train(): model.model._no_split_modules = {'Qwen3_5DecoderLayer'} # npu patch if Torch.is_npu_available(): - model = kernelize_model(model, mode='train', device='npu') + model = kernelize(model, npu_builtin(model)) lora_config = LoraConfig(**args.get_lora_args()) model.add_adapter_to_model( diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index a6fd0bdc..2fd4ecf1 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -9,7 +9,7 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize_model +from twinkle.kernel import kernelize, npu_builtin logger = get_logger() MODEL_ID = 'ms://Qwen/Qwen3.5-4B' @@ -72,7 +72,7 @@ def train(): ) # npu patch if Torch.is_npu_available(): - model = kernelize_model(model, mode='train', device='npu') + model = kernelize(model, npu_builtin(model)) lora_config = LoraConfig(target_modules='all-linear') model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') From 109cf247bcf17d075790c8b238e81002d9c46884 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:35:49 +0800 Subject: [PATCH 17/27] docs(kernel): rewrite Chinese doc for new mapping API --- .../\345\206\205\346\240\270/Kernel.md" | 344 +++++------------- 1 file changed, 87 insertions(+), 257 deletions(-) diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" index 7f5c9f3a..8ed5ad78 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" @@ -1,307 +1,137 @@ # Twinkle 
Kernel 模块 -Twinkle Kernel 模块提供了两条内核替换路径,用于加速训练和推理: +`twinkle.kernel` 提供一个 mapping 驱动的内核替换接口,把“用一种实现替换模型里的另一种实现”压缩为一次 `kernelize(model, mapping)` 调用。 -* **层级 Kernelize(Layer-level kernelize)** - 使用优化内核替换完整的 `nn.Module` 实现。 -* **函数级 Kernelize(Function-level kernelize)** - 对 Python 模块中的特定函数进行 monkey-patch。 +公开符号只有三个: -这两种方式可以独立使用,也可以通过统一入口组合使用。 +| 符号 | 作用 | +| --- | --- | +| `kernelize(model, mapping)` | 在 `model` 上应用 `mapping`,原地修改后返回 | +| `npu_builtin(model=None)` | 返回 Ascend NPU 内置替换的 mapping dict(可与用户 mapping 自由组合) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | 构造一个 `HubRef`,用作 mapping value;真实下载推迟到 `kernelize` 执行 | ---- +## Mapping 语义 -## 概览:两条 Kernelize 路径 +`mapping` 的 **key** 表示要替换的目标: -| 路径 | 粒度 | 典型场景 | -| --- | --- | --- | -| 层级替换 | 整个 `nn.Module` | Linear / Conv / MLP / Attention | -| 函数级替换 | 单个函数 | 热点路径、数学算子、激活函数 | +- `type[nn.Module]` 子类:替换模型里**所有**该精确类型的实例(`m.__class__ = impl_class`,**不包含**子类) +- `str` 形如 `'pkg.sub.attr'` 或 `'pkg.sub.ClassName.attr'`:`setattr(target, attr, impl)` ---- +**value** 表示用什么替换: -## 层级内核替换(Layer-Level) +- `type[nn.Module]` 子类:直接作为 impl 类。该类**不会被 `__init__` 调用**,必须只依赖原 instance 已经有的 attribute(weight / eps / ...)正确工作 +- `Callable`:直接 `setattr` 上去 +- `dict[str, V]`:device → impl 嵌套分派。从 `model` 推断当前 device,未匹配则**静默跳过** +- `HubRef`:通过 `hub(...)` 构造的 Hub 引用,延迟加载 -### 适用场景 +device 从 `next(model.parameters()).device.type` 推断(无参数则用 buffers,再无则为 `'cpu'`)。 -* 你已经有完整的层内核实现 -* 希望在模型中批量替换某类 `nn.Module` -* 同时适用于训练与推理 +## 场景示例 ---- - -### 示例 1:本地 Kernel 仓库 - -适用于: - -* 内核实现位于本地仓库 -* 希望替换 HuggingFace 或自定义模型中的层 +### 启用全部 NPU 内置优化 ```python -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) -from transformers import Qwen2Config, Qwen2ForCausalLM -from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP - -# 1) 从本地仓库注册层内核 -register_layer_kernel( - kernel_name="MyAwesomeMLP",/取的kernel名字,自定义 - repo_path="/path/to/local/repo",/本地kernel仓库路径 - 
package_name="my_kernels",/包名 - layer_name="Qwen2MLPTrainingKernel",/对应layer.py里面实现类的名字 - device="cuda",/适用的设备类型 - mode="train",/使用的场景:train or inference -) - -# 2) 绑定外部层与内核名 -register_external_layer(Qwen2MLP, "MyAwesomeMLP") - -# 3) 构建模型并应用内核替换 -config = Qwen2Config( - hidden_size=128, - num_hidden_layers=1, - num_attention_heads=4, - num_key_value_heads=4, - intermediate_size=256, - use_cache=False, -) -model = Qwen2ForCausalLM(config) -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize, npu_builtin + +if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) ``` ---- +### 自定义类替换 -### 示例 2:Hub Kernel 仓库 +```python +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize -适用于: +model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) +``` -* 内核托管在 Hub 上 +### 内置 + 自定义混合 ```python -import torch -import torch.nn as nn -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) - -# 1) 定义自定义层 -class SiluAndMul(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(2, dim=-1) - return nn.functional.silu(x1) * x2 - -# 2) 注册 Hub 内核并绑定层 -register_layer_kernel( - kernel_name="SiluAndMulKernel", - repo_id="kernels-community/activation", - layer_name="SiluAndMul", - device="cuda", - mode="train", -) -register_external_layer(SiluAndMul, "SiluAndMulKernel") - -# 3) 应用到模型 -class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.activation = SiluAndMul() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.activation(x) - -model = SimpleModel() -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) -``` +from twinkle.kernel import kernelize, npu_builtin ---- - -## 本地 Kernel 仓库(最小结构) +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` -本地 kernel 仓库本质上是一个普通 Python 包。 -最少只需要一个 
`layers.py` 来放层级内核实现。 +后写入的 key 会覆盖前面的,普通 dict 合并语义。 -```text -# 仓库结构: -my_kernels/ # 本地 kernel 仓库(Python 包) -├── __init__.py # 包入口 -└── layers.py # 层级 kernel 实现 -``` +### Hub Kernel(HF Hub 格式) ```python -# my_kernels/__init__.py -from . import layers -__all__ = ["layers"] +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul -# my_kernels/layers.py -import torch -import torch.nn as nn - -class Qwen2MLPTrainingKernel(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate = self.gate_proj(x) - up = self.up_proj(x) - return self.down_proj(self.act_fn(gate) * up) +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) ``` ---- - -## 函数级内核替换(Function-Level) +`revision` 与 `version` 二选一必传。`kernels` 包在 `kernelize` 解析 `HubRef` 时才会延迟 import,未安装时会提示 `pip install kernels`。 -### 适用场景 +### 函数级替换 -* 只需要加速少量热点函数 -* 不适合或不需要替换整个层 -* 常用于数学算子、激活函数、工具函数 +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb ---- +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` -### 示例 1:批量注册(简单场景) +### 跨设备 mapping(NPU 启用、CUDA 跳过) ```python -from twinkle.kernel import register_kernels, kernelize_model - -# 1) 注册函数内核 -config = { - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - }, -} -register_kernels(config) +from twinkle.kernel import kernelize -# 2) 应用(仅函数替换时 model 可为 None) -kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True) +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) ``` ---- +在 CUDA 模型上跑也安全:未匹配 device 的 entry 不会替换、不会报错。 -### 示例 2:高级函数来源(完整控制) +## 内置 NPU 优化 -适用于: +`npu_builtin(model)` 返回的 dict 至少包含以下覆盖(实际条目随 transformers 已安装的 modeling 模块动态收集): -* 不同函数来自不同来源(impl / repo / hub),或需要 compile/backward 等标志。 +- 
Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE 系列的 RMSNorm 类替换 +- 同上系列的 `apply_rotary_pos_emb` 函数替换(融合 RoPE) +- 同上系列 MLP 的 SwiGLU 融合替换 +- Qwen3-MoE / Qwen3.5-MoE 的 `Experts.forward` 与 `SparseMoeBlock.forward` 替换 +- Qwen3.5 / Qwen3.5-MoE 的 GatedRMSNorm forward 替换 +- Qwen2.5-VL 的 `apply_multimodal_rotary_pos_emb` 替换 +- 全局 SDPA 替换(一次性副作用,写入 `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention 启用(一次性副作用 + 实例遍历,由 `npu_builtin(model)` 内部触发) + +**未默认包含** `transformers.integrations.moe._grouped_mm` 的 NPU 替换(在没有 Expert Parallelism 时会带来约 8x 开销)。需要时手动加入: ```python -from twinkle.kernel.function import ( - register_function_kernel, - apply_function_kernel, -) -import torch.nn as nn -from twinkle.kernel import kernelize_model - -TARGET_MODULE = "my_pkg.math_ops" - -# 1) 直接传入实现 -def fast_add(x, y): - return x + y + 1 - -register_function_kernel( - func_name="add", - target_module=TARGET_MODULE, - func_impl=fast_add, - device="cuda", - mode="inference", -) - -# 2) Repo 对象(FuncRepositoryProtocol) -class MyFuncRepo: - def load(self): - return MyKernelFunc - -class MyKernelFunc(nn.Module): - def forward(self, x, y): - return x * y - -register_function_kernel( - func_name="mul", - target_module=TARGET_MODULE, - repo=MyFuncRepo(), - device="cuda", - mode="compile", -) - -# 3) Hub 仓库 -register_function_kernel( - func_name="silu_and_mul", - target_module="my_pkg.activations", - repo_id="kernels-community/activation", - revision="main", # 或 version="0.1.0" - device="cuda", - mode="inference", -) - -# 4) 应用函数内核 -applied = apply_function_kernel( - target_module=TARGET_MODULE, - device="cuda", - mode="inference", - strict=False, -) -print("patched:", applied) - -# 5) 可选:通过 kernelize_model 统一应用 -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True) +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + 
**npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} +model = kernelize(model, mapping) ``` ---- +## 环境变量 -## 层级 + 函数级统一批量注册 +只有两个保留: -### 适用场景 +- `TWINKLE_NPU_FLA`:Qwen3.5 FLA 开关(默认开,设为 `0`/`false` 关闭) +- `TWINKLE_NPU_GATED_RMSNorm_FP32`:将 Gated RMSNorm 强制升到 FP32 计算(默认关) -* 需要框架级统一集成 -* 希望通过单一配置入口管理 -* 同时管理层和函数两类内核 +旧的 `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` 已移除——这些都改成"是否把对应 entry 写进 mapping"的显式选择。 -```python -from twinkle.kernel import register_kernels, kernelize_model -import torch.nn as nn - -# 1) 注册层级 + 函数级内核 -config = { - "layers": { - "linear": { - "repo_id": "kernels-community/linear", - "layer_name": "Linear", - "version": "0.1.0", - "device": "cuda", - "mode": "train", - }, - "conv2d": { - "repo_path": "/path/to/local/repo", - "package_name": "my_kernels", - "layer_name": "Conv2d", - "device": "cuda", - }, - }, - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - "relu": { - "target_module": "my_pkg.activations", - "repo_id": "kernels-community/activation", - "revision": "main", - "device": "cuda", - }, - }, -} -register_kernels(config) +## 注意事项 -# 2) 通过 kernelize_model 应用 -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="train", device="cuda", use_fallback=True) -``` +- `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 +- 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping +- 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) +- 没有 `unkernelize`——替换是单向的 \ No newline at end of file From c7babac821512a150abeff1f46c9499a1a126677 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:58:01 +0800 Subject: [PATCH 18/27] docs(kernel): rewrite English doc for new 
mapping API --- docs/source_en/Components/Kernel/Kernel.md | 349 ++++++--------------- 1 file changed, 90 insertions(+), 259 deletions(-) diff --git a/docs/source_en/Components/Kernel/Kernel.md b/docs/source_en/Components/Kernel/Kernel.md index d587b540..f5ab78e9 100644 --- a/docs/source_en/Components/Kernel/Kernel.md +++ b/docs/source_en/Components/Kernel/Kernel.md @@ -1,308 +1,139 @@ -# Twinkle Kernel Module +# Twinkle Kernel -The Twinkle Kernel Module provides two kernel replacement paths for accelerating models during training and inference: +`twinkle.kernel` exposes a mapping-driven kernel replacement API. Replacing one +implementation with another collapses to a single `kernelize(model, mapping)` +call. -* **Layer-level kernelize** - Replace entire `nn.Module` implementations with optimized kernels. -* **Function-level kernelize** - Monkey-patch specific functions inside a Python module. +The public surface is exactly three symbols: -These two approaches can be used independently or together via a unified registration and application entry point. 
+| Symbol | Purpose | +| --- | --- | +| `kernelize(model, mapping)` | Apply ``mapping`` to ``model`` (in place) and return it | +| `npu_builtin(model=None)` | Return the Ascend NPU built-in mapping (composes with user mappings) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | Build a ``HubRef`` for use as a mapping value; the actual Hub download is deferred to ``kernelize`` | ---- +## Mapping semantics -## Overview: Two Kernelization Paths +`mapping` keys describe the target to replace: -| Path | Granularity | Typical Use Cases | -| -------------- | -------------------- | -------------------------------- | -| Layer-level | Whole `nn.Module` | Linear / Conv / MLP / Attention | -| Function-level | Individual functions | Hot paths, math ops, activations | +- `type[nn.Module]` subclass — replace **every** instance whose exact type matches (`m.__class__ = impl`; subclasses are **not** touched) +- `str` of the form `'pkg.sub.attr'` or `'pkg.sub.ClassName.attr'` — `setattr(target, attr, impl)` ---- +`mapping` values describe the replacement: -## Layer-Level Kernel Replacement +- `type[nn.Module]` subclass — used as the impl class. The class' `__init__` is **never** invoked; its forward must work against the attributes the original instance already has +- `Callable` — assigned with `setattr` +- `dict[str, V]` — device → impl dispatch. Device is inferred from the model; entries without a matching key are **silently skipped** +- `HubRef` — built via `hub(...)`; resolved lazily -### When to Use +Device is inferred from `next(model.parameters()).device.type` (falling back to buffers, then `'cpu'`). 
-* You have a complete kernel implementation for a layer -* You want model-wide replacement of specific `nn.Module` types -* Suitable for both training and inference +## Examples ---- - -### Example 1: Local Kernel Repo - -Use this when: - -* Kernel implementations live in a local repository -* You want to replace layers in HuggingFace or custom models +### Enable the full NPU built-in bundle ```python -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) -from transformers import Qwen2Config, Qwen2ForCausalLM -from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP - -# 1) Register the layer kernel from a local repo -register_layer_kernel( - kernel_name="MyAwesomeMLP", - repo_path="/path/to/local/repo", - package_name="my_kernels", - layer_name="Qwen2MLPTrainingKernel", - device="cuda", - mode="train", -) - -# 2) Bind external layer to kernel name -register_external_layer(Qwen2MLP, "MyAwesomeMLP") - -# 3) Build the model and apply kernelization -config = Qwen2Config( - hidden_size=128, - num_hidden_layers=1, - num_attention_heads=4, - num_key_value_heads=4, - intermediate_size=256, - use_cache=False, -) -model = Qwen2ForCausalLM(config) -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) -``` - ---- - -### Example 2: Hub Kernel Repo +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize, npu_builtin -Use this when: +if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) +``` -* The kernel is hosted on a Hub +### Custom class replacement ```python -import torch -import torch.nn as nn -from twinkle.kernel import ( - kernelize_model, - register_layer_kernel, - register_external_layer, -) - -# 1) Define the custom layer -class SiluAndMul(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(2, dim=-1) - return nn.functional.silu(x1) * x2 - -# 2) Register the Hub kernel and bind the layer -register_layer_kernel( - 
kernel_name="SiluAndMulKernel", - repo_id="kernels-community/activation", - layer_name="SiluAndMul", - device="cuda", - mode="train", -) -register_external_layer(SiluAndMul, "SiluAndMulKernel") - -# 3) Apply to a model -class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.activation = SiluAndMul() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.activation(x) - -model = SimpleModel() -model = kernelize_model(model, mode="train", device="cuda", use_fallback=True) +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize + +model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) ``` ---- +### Built-in + custom override -## Local Kernel Repo (Minimal) +```python +from twinkle.kernel import kernelize, npu_builtin -A local kernel repository is a regular Python package. -At minimum, it only needs a `layers.py` file for layer-level kernels. +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` -```text -# Repo layout: -my_kernels/ # Local kernel repository (Python package) -├── __init__.py # Package entry -└── layers.py # Layer-level kernel implementations +Plain dict merge — later keys override earlier ones. -``` +### Hub kernel (HF Hub format) ```python -# my_kernels/__init__.py -from . import layers -__all__ = ["layers"] +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul -# my_kernels/layers.py -import torch -import torch.nn as nn - -class Qwen2MLPTrainingKernel(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate = self.gate_proj(x) - up = self.up_proj(x) - return self.down_proj(self.act_fn(gate) * up) +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) ``` ---- - -## Function-Level Kernel Replacement +Exactly one of `revision` / `version` must be passed. The `kernels` package is imported lazily; absence raises a clear "install kernels" error. 
-### When to Use +### Function-level replacement -* You only need to accelerate a small number of hot functions -* Replacing the entire layer is unnecessary or impractical -* Common for math ops, activations, or utility functions +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb ---- +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` -### Example 1: Batch Registration (Simple Case) +### Cross-device mapping (NPU enabled, CUDA skipped) ```python -from twinkle.kernel import register_kernels, kernelize_model - -# 1) Register function kernels -config = { - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - }, -} -register_kernels(config) +from twinkle.kernel import kernelize -# 2) Apply (model can be None when only functions are used) -kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True) +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) ``` ---- +Safe to run on CUDA — entries whose dict misses the current device just skip. -### Example 2: Advanced Function Sources (Full Control) +## NPU built-in coverage -Use this when: +`npu_builtin(model)` returns a dict that (as available transformers modules permit) covers: -* Use when different functions come from different sources (impl / repo / hub) or need compile/backward flags. 
+- RMSNorm class replacement for Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE families +- `apply_rotary_pos_emb` function replacement (fused RoPE) for the same families +- SwiGLU fused replacement for the MLP variants +- `Experts.forward` and `SparseMoeBlock.forward` for Qwen3-MoE / Qwen3.5-MoE +- GatedRMSNorm forward for Qwen3.5 / Qwen3.5-MoE +- `apply_multimodal_rotary_pos_emb` for Qwen2.5-VL +- Global SDPA replacement (one-shot side effect on `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention enablement (one-shot side effect + per-instance traversal, triggered inside `npu_builtin(model)`) + +**Not included by default:** the NPU replacement for `transformers.integrations.moe._grouped_mm`. Without Expert Parallelism the contiguous-copy overhead is ~8x. Opt in explicitly when EP is enabled: ```python -from twinkle.kernel.function import ( - register_function_kernel, - apply_function_kernel, -) -import torch.nn as nn -from twinkle.kernel import kernelize_model - -TARGET_MODULE = "my_pkg.math_ops" - -# 1) Direct implementation -def fast_add(x, y): - return x + y + 1 - -register_function_kernel( - func_name="add", - target_module=TARGET_MODULE, - func_impl=fast_add, - device="cuda", - mode="inference", -) - -# 2) Repo object (FuncRepositoryProtocol) -class MyFuncRepo: - def load(self): - return MyKernelFunc - -class MyKernelFunc(nn.Module): - def forward(self, x, y): - return x * y - -register_function_kernel( - func_name="mul", - target_module=TARGET_MODULE, - repo=MyFuncRepo(), - device="cuda", - mode="compile", -) - -# 3) Hub repo -register_function_kernel( - func_name="silu_and_mul", - target_module="my_pkg.activations", - repo_id="kernels-community/activation", - revision="main", # or version="0.1.0" - device="cuda", - mode="inference", -) - -# 4) Apply function kernels -applied = apply_function_kernel( - target_module=TARGET_MODULE, - device="cuda", - mode="inference", - strict=False, -) -print("patched:", applied) - -# 5) 
Optional: unified entry via kernelize_model -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True) +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} +model = kernelize(model, mapping) ``` ---- +## Environment variables -## Unified Layer + Function Batch Registration +Only two remain: -### When to Use +- `TWINKLE_NPU_FLA` — Qwen3.5 FLA switch (default on; `0`/`false` to disable) +- `TWINKLE_NPU_GATED_RMSNorm_FP32` — force FP32 in Gated RMSNorm forward (default off) -* Framework-level integration -* A single configuration entry point is preferred -* Managing both layer and function kernels together +The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` are gone — they're now "include the entry in the mapping or don't" decisions. 
-```python -from twinkle.kernel import register_kernels, kernelize_model -import torch.nn as nn - -# 1) Register layer + function kernels -config = { - "layers": { - "linear": { - "repo_id": "kernels-community/linear", - "layer_name": "Linear", - "version": "0.1.0", - "device": "cuda", - "mode": "train", - }, - "conv2d": { - "repo_path": "/path/to/local/repo", - "package_name": "my_kernels", - "layer_name": "Conv2d", - "device": "cuda", - }, - }, - "functions": { - "add": { - "target_module": "my_pkg.math_ops", - "func_impl": lambda x, y: x + y + 1, - "device": "cuda", - "mode": "inference", - }, - "relu": { - "target_module": "my_pkg.activations", - "repo_id": "kernels-community/activation", - "revision": "main", - "device": "cuda", - }, - }, -} -register_kernels(config) +## Caveats -# 2) Apply via kernelize_model -model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) -kernelize_model(model=model, mode="train", device="cuda", use_fallback=True) -``` +- `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract +- Exact match: `type(m) is target_cls`. Subclasses of `target_cls` are not replaced — add them to the mapping yourself +- `kernelize` is idempotent under repeated calls +- There is no `unkernelize` — replacement is one-way \ No newline at end of file From 997697fd835e37b0bae7f1ed5fa66b0e4e940ab1 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 10:22:18 +0800 Subject: [PATCH 19/27] fix(kernel): gate SDPA install on NPU host and FLA flag on MindSpeed load - builtin.py: _install_sdpa() now only runs when torch_npu is importable, preventing the NPU (boolean-mask-inverting) SDPA impl from contaminating the global ALL_ATTENTION_FUNCTIONS['sdpa'] registry on CUDA/CPU hosts. - builtin.py: drop dead _SdpaPatchSentinel + add/pop scaffolding. 
- fla.py: flip is_flash_linear_attention_available only after the MindSpeed kernel imports successfully; previously a MindSpeed-missing NPU host would be left with FLA flagged available but no kernel installed -> Qwen3.5 runtime failure. --- src/twinkle/kernel/builtin.py | 23 ++++++++--------- src/twinkle/kernel/npu_impls/fla.py | 28 ++++++++++++--------- tests/kernel/npu_impls/test_fla.py | 38 ++++++++++++++++++++++++++++- tests/kernel/test_builtin.py | 14 ++++++++++- 4 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py index a11ec0e8..62efeac2 100644 --- a/src/twinkle/kernel/builtin.py +++ b/src/twinkle/kernel/builtin.py @@ -46,15 +46,16 @@ def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: bundle: dict[Any, dict[str, Any]] = {} - # SDPA attention (global) - bundle['transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS'] = {'npu': _SdpaPatchSentinel()} - # NOTE: ALL_ATTENTION_FUNCTIONS is a dict, not a function. We can't setattr - # it. We instead install the sdpa entry by a small bootstrap below. - # Remove the sentinel approach in favor of explicit module-level entries: - bundle.pop('transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS', None) - - # Apply SDPA install eagerly (one-shot module-level mutation). - _install_sdpa(npu_sdpa_attention_forward) + # Apply SDPA install eagerly (one-shot module-level mutation) — only on + # NPU hosts. The NPU impl inverts boolean masks, which is wrong for + # CUDA/CPU execution, so we must not contaminate the global HF registry + # when ``npu_builtin()`` is constructed on a non-NPU machine. 
+ try: + import torch_npu # noqa: F401 + except ImportError: + pass + else: + _install_sdpa(npu_sdpa_attention_forward) # === per-family class + function entries === _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) @@ -83,10 +84,6 @@ def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: return bundle -class _SdpaPatchSentinel: - pass # unused; placeholder retained for clarity in diffs - - def _install_sdpa(impl) -> None: """One-shot install of SDPA attention forward (global modeling_utils dict).""" try: diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py index 4ed91ae1..ca5aa2e1 100644 --- a/src/twinkle/kernel/npu_impls/fla.py +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -42,7 +42,20 @@ def apply_qwen3_5_fla(model=None) -> int: logger.info('[NPU] [FLA] Skip: torch_npu unavailable') return 0 - # 1. Force FLA availability flags on transformers utility modules + # 1. Confirm the MindSpeed Triton kernel is actually importable BEFORE + # flipping any global availability flags. If we flip the flag and then + # fail to install the kernel, HF transformers would route Qwen3.5 onto + # a FLA fast path whose kernel is missing -> runtime failure on NPU. + try: + from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla + except ImportError as exc: + logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) + return 0 + + from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn + + # 2. Only now can we safely claim FLA is available: flip the global flags + # and install the kernel path on Qwen3.5 modeling modules. def _is_fla_available() -> bool: return True @@ -51,15 +64,6 @@ def _is_fla_available() -> bool: if utils_mod is not None: setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) - # 2. 
Try to load MindSpeed Triton kernel - try: - from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla - except ImportError as exc: - logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) - mindspeed_fla = None - - from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn - # 3. Patch Qwen3.5 modeling modules fla_target_modules = [ 'transformers.models.qwen3_5.modeling_qwen3_5', @@ -67,7 +71,7 @@ def _is_fla_available() -> bool: ] for module_name in fla_target_modules: module = _import_optional(module_name) - if module is None or mindspeed_fla is None: + if module is None: continue setattr(module, 'is_flash_linear_attention_available', _is_fla_available) setattr(module, 'is_fast_path_available', True) @@ -76,7 +80,7 @@ def _is_fla_available() -> bool: setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) # 4. Traverse model and patch per-layer attributes - if model is None or mindspeed_fla is None: + if model is None: return 0 root = getattr(model, 'model', getattr(model, 'module', model)) diff --git a/tests/kernel/npu_impls/test_fla.py b/tests/kernel/npu_impls/test_fla.py index 0e8b07bb..0cfeda1d 100644 --- a/tests/kernel/npu_impls/test_fla.py +++ b/tests/kernel/npu_impls/test_fla.py @@ -16,4 +16,40 @@ def test_fla_skips_when_no_torch_npu(monkeypatch): monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import from twinkle.kernel.npu_impls import fla as fla_mod # Reload-tolerant: should return 0 when torch_npu is missing. 
- assert fla_mod.apply_qwen3_5_fla(None) == 0 \ No newline at end of file + assert fla_mod.apply_qwen3_5_fla(None) == 0 + + +def test_fla_does_not_flip_flag_when_mindspeed_missing(monkeypatch): + """On an NPU host where the MindSpeed FLA kernel cannot be imported, + ``apply_qwen3_5_fla`` must NOT flip the global ``is_flash_linear_attention_available`` + flag — otherwise HF transformers would route Qwen3.5 onto a FLA fast path + whose kernel is not installed (runtime failure).""" + import sys + import types + + import transformers.utils as tu + + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + # Fake torch_npu as importable (with a real __spec__ so find_spec doesn't trip) + import importlib.util + spec = importlib.util.spec_from_loader('torch_npu', loader=None) + fake_npu = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, 'torch_npu', fake_npu) + # Stub causal_conv1d so the heavy real import chain doesn't run + fake_conv = types.ModuleType('twinkle.kernel.causal_conv1d') + fake_conv.npu_causal_conv1d_fn = object() + monkeypatch.setitem(sys.modules, 'twinkle.kernel.causal_conv1d', fake_conv) + # Force the MindSpeed-backed module import to fail + monkeypatch.setitem(sys.modules, 'twinkle.kernel.chunk_gated_delta_rule', None) + + original_flag = tu.is_flash_linear_attention_available + try: + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert apply_qwen3_5_fla(None) == 0 + assert tu.is_flash_linear_attention_available is original_flag, ( + 'is_flash_linear_attention_available was flipped to True while the ' + 'MindSpeed kernel is unavailable — this would break Qwen3.5 at runtime.' + ) + finally: + # Defensive cleanup in case the buggy path ran. 
+ tu.is_flash_linear_attention_available = original_flag \ No newline at end of file diff --git a/tests/kernel/test_builtin.py b/tests/kernel/test_builtin.py index 3b4e68d4..44786175 100644 --- a/tests/kernel/test_builtin.py +++ b/tests/kernel/test_builtin.py @@ -45,4 +45,16 @@ def test_npu_builtin_skips_missing_modeling_modules(): still produce a dict (with whatever subset is available).""" from twinkle.kernel.builtin import npu_builtin bundle = npu_builtin() # must not raise - assert isinstance(bundle, dict) \ No newline at end of file + assert isinstance(bundle, dict) + + +def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(): + """Calling npu_builtin() on a CUDA/CPU host must not contaminate the + global HF SDPA registry. The NPU impl inverts boolean masks, which is + wrong for non-NPU execution.""" + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from twinkle.kernel.builtin import npu_builtin + + original = ALL_ATTENTION_FUNCTIONS.get('sdpa') + npu_builtin() + assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original \ No newline at end of file From a742827cf02c8079f4c2dceac1d9ba1a26f1e886 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 11:03:32 +0800 Subject: [PATCH 20/27] wip --- src/twinkle/kernel/builtin.py | 22 +++++++++++----------- src/twinkle/kernel/core.py | 20 +++++++------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py index 62efeac2..8b97d473 100644 --- a/src/twinkle/kernel/builtin.py +++ b/src/twinkle/kernel/builtin.py @@ -18,6 +18,7 @@ import torch.nn as nn from twinkle import get_logger +from twinkle.utils.device_mesh import Platform logger = get_logger() @@ -46,15 +47,13 @@ def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: bundle: dict[Any, dict[str, Any]] = {} - # Apply SDPA install eagerly (one-shot module-level mutation) — only on - # NPU 
hosts. The NPU impl inverts boolean masks, which is wrong for - # CUDA/CPU execution, so we must not contaminate the global HF registry - # when ``npu_builtin()`` is constructed on a non-NPU machine. - try: - import torch_npu # noqa: F401 - except ImportError: - pass - else: + is_npu_platform = Platform.device_prefix() == 'npu' + + # Apply SDPA install eagerly (one-shot module-level mutation) on NPU + # platforms. The NPU impl inverts boolean masks, which is wrong for + # CUDA/CPU execution, so non-NPU platforms must not mutate the global HF + # registry even if ``torch_npu`` is importable in the environment. + if is_npu_platform: _install_sdpa(npu_sdpa_attention_forward) # === per-family class + function entries === @@ -79,7 +78,8 @@ def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: ) # === FLA (side-effect; mapping-incompatible) === - apply_qwen3_5_fla(model) + if is_npu_platform: + apply_qwen3_5_fla(model) return bundle @@ -194,4 +194,4 @@ def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) \ No newline at end of file + _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index 362873aa..a3a12f18 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -11,6 +11,8 @@ import torch.nn as nn +from twinkle.utils.device_mesh import Platform + @dataclass(frozen=True) class HubRef: @@ -47,17 +49,9 @@ def hub( return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) -def _infer_device(model: nn.Module) -> str: - """Infer the device type from the first parameter, then first 
buffer, else cpu.""" - for p in model.parameters(): - return p.device.type - for b in model.buffers(): - return b.device.type - return 'cpu' - def _resolve_value(value: Any, device: str) -> Any | None: - """Resolve a mapping value against the inferred device. + """Resolve a mapping value against the selected device. - ``dict``: device-conditional; recurse into ``value[device]`` or return None. - anything else (including ``HubRef``): pass through. @@ -153,15 +147,15 @@ def kernelize(model: nn.Module, mapping: dict) -> nn.Module: identified module attribute. Values: - - ``dict[str, V]``: device-conditional dispatch using - ``next(model.parameters()).device.type``; non-matching devices skip. + - ``dict[str, V]``: device-conditional dispatch using the current + Twinkle platform device prefix; non-matching devices skip. - ``HubRef``: lazy-resolved via the optional ``kernels`` package. - anything else: used directly as the impl. """ if not mapping: return model - device = _infer_device(model) + device = Platform.device_prefix() for key, value in mapping.items(): impl = _resolve_value(value, device) if impl is None: @@ -174,4 +168,4 @@ def kernelize(model: nn.Module, mapping: dict) -> nn.Module: _replace_attr(key, impl) else: raise TypeError(f'Unsupported mapping key: {key!r}') - return model \ No newline at end of file + return model From dc7cb9387aeb2e21b7e6166c81599822ece7a6a3 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 11:40:13 +0800 Subject: [PATCH 21/27] wip --- src/twinkle/kernel/builtin.py | 12 ++++++++++-- src/twinkle/kernel/npu_impls/fla.py | 3 +-- src/twinkle/kernel/npu_impls/rms_norm.py | 22 +++++++++------------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py index 8b97d473..8d7f3b59 100644 --- a/src/twinkle/kernel/builtin.py +++ b/src/twinkle/kernel/builtin.py @@ -85,7 +85,12 @@ def npu_builtin(model: nn.Module | None 
= None) -> dict[Any, dict[str, Any]]: def _install_sdpa(impl) -> None: - """One-shot install of SDPA attention forward (global modeling_utils dict).""" + """One-shot install of SDPA attention forward (global modeling_utils dict). + + ``AttentionInterface._global_mapping`` is a private transformers attribute; + guard against its removal so an upstream change can't take down the rest + of ``npu_builtin()``. + """ try: from transformers.modeling_utils import ( ALL_ATTENTION_FUNCTIONS, @@ -93,7 +98,10 @@ def _install_sdpa(impl) -> None: ) except ImportError: return - AttentionInterface._global_mapping['sdpa'] = impl + try: + AttentionInterface._global_mapping['sdpa'] = impl + except AttributeError: + logger.warning('[NPU] [SDPA] AttentionInterface._global_mapping unavailable; skipping') ALL_ATTENTION_FUNCTIONS['sdpa'] = impl diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py index ca5aa2e1..d2fc43a9 100644 --- a/src/twinkle/kernel/npu_impls/fla.py +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -48,12 +48,11 @@ def apply_qwen3_5_fla(model=None) -> int: # a FLA fast path whose kernel is missing -> runtime failure on NPU. try: from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla + from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn except ImportError as exc: logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) return 0 - from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn - # 2. Only now can we safely claim FLA is available: flip the global flags # and install the kernel path on Qwen3.5 modeling modules. 
def _is_fla_available() -> bool: diff --git a/src/twinkle/kernel/npu_impls/rms_norm.py b/src/twinkle/kernel/npu_impls/rms_norm.py index 281443bd..ecebdc23 100644 --- a/src/twinkle/kernel/npu_impls/rms_norm.py +++ b/src/twinkle/kernel/npu_impls/rms_norm.py @@ -48,25 +48,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] -def npu_gated_rms_norm_forward(self, hidden_states, gate=None): - """Forward replacement for Gated RMSNorm variants (e.g. Qwen3.5-MoE). +# Resolved once at import: matches the legacy "patch-time, process-wide" invariant. +# Mid-process env mutation will not retroactively change behavior. +_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( + '1', 'true', 'on', 'yes' +) - Reads FP32-mode preference from env ``TWINKLE_NPU_GATED_RMSNorm_FP32`` once - and caches it on the instance. - """ + +def npu_gated_rms_norm_forward(self, hidden_states, gate=None): + """Forward replacement for Gated RMSNorm variants (e.g. 
Qwen3.5-MoE).""" import torch_npu input_dtype = hidden_states.dtype _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - force_fp32 = getattr(self, '_twinkle_force_fp32', None) - if force_fp32 is None: - force_fp32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( - '1', 'true', 'on', 'yes' - ) - self._twinkle_force_fp32 = force_fp32 - - if force_fp32: + if _FORCE_FP32: hidden_states = hidden_states.to(torch.float32) weight = self.weight.float() gate = gate.to(torch.float32) if gate is not None else None From 398dd378ce087d7a0fede8b7c7d4e79927e4f31c Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:42:52 +0800 Subject: [PATCH 22/27] wip --- tests/kernel/test_builtin.py | 34 ++++++++++++++++++++++++++++++++-- tests/kernel/test_kernelize.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/tests/kernel/test_builtin.py b/tests/kernel/test_builtin.py index 44786175..38d83915 100644 --- a/tests/kernel/test_builtin.py +++ b/tests/kernel/test_builtin.py @@ -1,9 +1,19 @@ +import importlib.machinery +import sys +import types + import torch import torch.nn as nn import pytest +def _fake_module(name: str): + module = types.ModuleType(name) + module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) + return module + + def test_npu_builtin_returns_dict(): from twinkle.kernel.builtin import npu_builtin bundle = npu_builtin() @@ -48,13 +58,33 @@ def test_npu_builtin_skips_missing_modeling_modules(): assert isinstance(bundle, dict) -def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(): +def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(monkeypatch): """Calling npu_builtin() on a CUDA/CPU host must not contaminate the global HF SDPA registry. 
The NPU impl inverts boolean masks, which is wrong for non-NPU execution.""" from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from twinkle.kernel.builtin import npu_builtin + from twinkle.utils.device_mesh import Platform + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) original = ALL_ATTENTION_FUNCTIONS.get('sdpa') npu_builtin() - assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original \ No newline at end of file + assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original + + +def test_npu_builtin_skips_side_effects_on_non_npu_platform(monkeypatch): + from twinkle.kernel import builtin + from twinkle.kernel.npu_impls import fla + from twinkle.utils.device_mesh import Platform + + installs = [] + fla_calls = [] + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + monkeypatch.setitem(sys.modules, 'torch_npu', _fake_module('torch_npu')) + monkeypatch.setattr(builtin, '_install_sdpa', lambda impl: installs.append(impl)) + monkeypatch.setattr(fla, 'apply_qwen3_5_fla', lambda model: fla_calls.append(model)) + + builtin.npu_builtin(nn.Linear(1, 1)) + + assert installs == [] + assert fla_calls == [] diff --git a/tests/kernel/test_kernelize.py b/tests/kernel/test_kernelize.py index dd159de6..cdb98cae 100644 --- a/tests/kernel/test_kernelize.py +++ b/tests/kernel/test_kernelize.py @@ -49,15 +49,36 @@ def test_kernelize_string_key_calls_setattr(): sys.modules.pop(mod_name, None) -def test_kernelize_device_dict_match(): - parent = nn.Sequential(_SrcLayer()) # cpu params +def test_kernelize_device_dict_match(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_uses_platform_device_prefix(monkeypatch): + from 
twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) # params may still be CPU before FSDP placement + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'npu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + assert type(parent[0]) is _DstLayer -def test_kernelize_device_dict_miss_skips_silently(): - parent = nn.Sequential(_SrcLayer()) # cpu params +def test_kernelize_device_dict_miss_skips_silently(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + assert type(parent[0]) is _SrcLayer @@ -74,4 +95,4 @@ def test_kernelize_loads_hub_ref(monkeypatch): parent = nn.Sequential(_SrcLayer()) ref = HubRef('org/repo', 'X', revision='main') kernelize(parent, {_SrcLayer: ref}) - assert type(parent[0]) is _DstLayer \ No newline at end of file + assert type(parent[0]) is _DstLayer From 126efc3da85fbc286f31efbcb0db2cbf7b9107f6 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:16:59 +0800 Subject: [PATCH 23/27] wip --- .../transformers/ep_fsdp2_lora_deepseek_v4.py | 298 ++--- .../transformers/ep_fsdp2_lora_qwen3_5_moe.py | 306 ++--- cookbook/transformers/fsdp2.py | 218 ++-- cookbook/transformers/sp_fsdp_dense.py | 198 +-- docs/source_en/Components/Kernel/Kernel.md | 276 ++--- .../\345\206\205\346\240\270/Kernel.md" | 272 ++--- src/twinkle/kernel/__init__.py | 24 +- src/twinkle/kernel/builtin.py | 410 +++---- src/twinkle/kernel/chunk_gated_delta_rule.py | 724 +++++------ src/twinkle/kernel/core.py | 342 +++--- src/twinkle/kernel/npu_impls/__init__.py | 62 +- src/twinkle/kernel/npu_impls/attention.py | 106 +- src/twinkle/kernel/npu_impls/fla.py | 204 ++-- src/twinkle/kernel/npu_impls/moe.py | 300 ++--- src/twinkle/kernel/npu_impls/rms_norm.py | 148 +-- 
src/twinkle/kernel/npu_impls/rotary.py | 130 +- src/twinkle/kernel/npu_impls/swiglu.py | 38 +- .../model/transformers/moe/expert_parallel.py | 1066 ++++++++--------- .../sequence_parallel/linear_attention_sp.py | 726 +++++------ tests/kernel/npu_impls/test_attention.py | 30 +- tests/kernel/npu_impls/test_fla.py | 108 +- tests/kernel/npu_impls/test_moe.py | 22 +- tests/kernel/npu_impls/test_rms_norm.py | 78 +- tests/kernel/npu_impls/test_rotary.py | 40 +- tests/kernel/npu_impls/test_swiglu.py | 22 +- tests/kernel/test_builtin.py | 180 +-- tests/kernel/test_hub.py | 106 +- tests/kernel/test_infer_device.py | 56 +- tests/kernel/test_kernelize.py | 196 +-- tests/kernel/test_load_hub_ref.py | 136 +-- tests/kernel/test_public_api.py | 42 +- tests/kernel/test_replace.py | 146 +-- tests/kernel/test_resolve_value.py | 94 +- 33 files changed, 3552 insertions(+), 3552 deletions(-) diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py index af72efa1..bb4582e0 100644 --- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py @@ -1,149 +1,149 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""EP + FSDP2 + LoRA SFT cookbook for DeepSeek-V4. 
- -Run on 8 GPUs: - torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py -""" -import os -from pathlib import Path - -from peft import LoraConfig -from transformers import AutoConfig - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() - -MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LOG_INTERVAL = GRAD_ACCUM_STEPS -LR = float(os.environ.get('LR', '1e-4')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -LORA_R = int(os.environ.get('LORA_R', '8')) -LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) -ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4') -RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None -RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' -IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' -ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') -NUM_GPUS = int(os.environ.get('NUM_GPUS', '8')) - -device_mesh = DeviceMesh.from_sizes( - fsdp_size=NUM_GPUS, - dp_size=1, - ep_size=NUM_GPUS, - device_type=Platform.get_platform().device_prefix(), -) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def _build_lora_config(enable_ep: bool): - if enable_ep: - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - target_modules='all-linear', - exclude_modules=['o_a_proj'], - target_parameters=['mlp.experts.gate_up_proj', 
'mlp.experts.down_proj'], - ) - # Expert weights are bare nn.Parameters. PEFT trains them through - # target_parameters/ParamWrapper, which dynamically parametrizes weights - # during forward. That is not stable with plain FSDP2, so non-EP mode uses - # regular module LoRA and does not train expert parameters. - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - exclude_modules=['o_a_proj'], - target_modules='all-linear', - ) - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - return model.save( - name=checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - text_config = getattr(config, 'text_config', config) - if hasattr(text_config, 'use_cache'): - text_config.use_cache = False - - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) - dataset.encode(batched=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - device_mesh=device_mesh, - strategy='native_fsdp', - memory_efficient_init=True, - fsdp_config={ - 'expert_parallel': { - 'enabled': ENABLE_EP, - 'router_dtype': 'fp32', - 'keep_router_logits': False, - } - }, - ) - lora_cfg = _build_lora_config(ENABLE_EP) - model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.set_optimizer('AdamW', lr=LR, foreach=False) - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - ) - - if RESUME_FROM_CHECKPOINT: - checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - kwargs = {} - if ADAPTER_NAME: 
- kwargs['adapter_name'] = ADAPTER_NAME - progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') - - optimizer_group = model.optimizer_group[ADAPTER_NAME] - for batch in dataloader: - if callable(batch): - batch = batch() - model.forward_backward(inputs=batch) - model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - cur_step = optimizer_group.cur_step - if cur_step > 0 and cur_step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - if callable(metric): - metric = metric() - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - - final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) - logger.info(f'Saved final adapter to {final_checkpoint}') - - -if __name__ == '__main__': - train() +# Copyright (c) ModelScope Contributors. All rights reserved. +"""EP + FSDP2 + LoRA SFT cookbook for DeepSeek-V4. 
+ +Run on 8 GPUs: + torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +""" +import os +from pathlib import Path + +from peft import LoraConfig +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-4')) +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +LORA_R = int(os.environ.get('LORA_R', '8')) +LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) +ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4') +RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None +RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' +IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') +NUM_GPUS = int(os.environ.get('NUM_GPUS', '8')) + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=NUM_GPUS, + dp_size=1, + ep_size=NUM_GPUS, + device_type=Platform.get_platform().device_prefix(), +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def _build_lora_config(enable_ep: bool): + if enable_ep: + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + exclude_modules=['o_a_proj'], + target_parameters=['mlp.experts.gate_up_proj', 
'mlp.experts.down_proj'], + ) + # Expert weights are bare nn.Parameters. PEFT trains them through + # target_parameters/ParamWrapper, which dynamically parametrizes weights + # during forward. That is not stable with plain FSDP2, so non-EP mode uses + # regular module LoRA and does not train expert parameters. + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + exclude_modules=['o_a_proj'], + target_modules='all-linear', + ) + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + text_config = getattr(config, 'text_config', config) + if hasattr(text_config, 'use_cache'): + text_config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.encode(batched=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + memory_efficient_init=True, + fsdp_config={ + 'expert_parallel': { + 'enabled': ENABLE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + lora_cfg = _build_lora_config(ENABLE_EP) + model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: 
+ kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py index 03e962e6..eb3efed6 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -1,153 +1,153 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""EP + FSDP2 + LoRA SFT cookbook for Qwen3.5-MoE. 
- -Run on 8 GPUs: - torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py -""" -import os -from pathlib import Path - -from peft import LoraConfig -from transformers import AutoConfig - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize, npu_builtin - -logger = get_logger() - -MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LOG_INTERVAL = GRAD_ACCUM_STEPS -LR = float(os.environ.get('LR', '1e-4')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -LORA_R = int(os.environ.get('LORA_R', '8')) -LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) -ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') -RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None -RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' -IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' -ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') - -device_mesh = DeviceMesh.from_sizes( - fsdp_size=8, - dp_size=1, - ep_size=8, - device_type=Platform.get_platform().device_prefix(), -) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def _build_lora_config(enable_ep: bool): - if enable_ep: - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - target_modules='all-linear', - target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], - ) - # 
Expert weights are bare nn.Parameters. PEFT trains them through - # target_parameters/ParamWrapper, which dynamically parametrizes weights - # during forward. That is not stable with plain FSDP2, so non-EP mode uses - # regular module LoRA and does not train expert parameters. - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - target_modules='all-linear', - ) - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - return model.save( - name=checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - text_config = getattr(config, 'text_config', config) - if hasattr(text_config, 'use_cache'): - text_config.use_cache = False - - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) - try: - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - except ValueError: - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) - dataset.encode(batched=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - device_mesh=device_mesh, - strategy='native_fsdp', - fsdp_config={ - 'expert_parallel': { - 'enabled': ENABLE_EP, - 'router_dtype': 'fp32', - 'keep_router_logits': False, - } - }, - ) - # npu patch - if Torch.is_npu_available(): - model = kernelize(model, npu_builtin(model)) - lora_cfg = _build_lora_config(ENABLE_EP) - model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.set_optimizer('AdamW', lr=LR, foreach=False) - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - ) - - if RESUME_FROM_CHECKPOINT: - checkpoint_path = 
Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME - progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') - - optimizer_group = model.optimizer_group[ADAPTER_NAME] - for batch in dataloader: - if callable(batch): - batch = batch() - model.forward_backward(inputs=batch) - model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - cur_step = optimizer_group.cur_step - if cur_step > 0 and cur_step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - if callable(metric): - metric = metric() - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - - final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) - logger.info(f'Saved final adapter to {final_checkpoint}') - - -if __name__ == '__main__': - train() +# Copyright (c) ModelScope Contributors. All rights reserved. +"""EP + FSDP2 + LoRA SFT cookbook for Qwen3.5-MoE. 
+ +Run on 8 GPUs: + torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +""" +import os +from pathlib import Path + +from peft import LoraConfig +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize, npu_builtin + +logger = get_logger() + +MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-4')) +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +LORA_R = int(os.environ.get('LORA_R', '8')) +LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) +ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') +RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None +RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' +IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=8, + dp_size=1, + ep_size=8, + device_type=Platform.get_platform().device_prefix(), +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def _build_lora_config(enable_ep: bool): + if enable_ep: + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], + ) + # 
Expert weights are bare nn.Parameters. PEFT trains them through + # target_parameters/ParamWrapper, which dynamically parametrizes weights + # during forward. That is not stable with plain FSDP2, so non-EP mode uses + # regular module LoRA and does not train expert parameters. + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + ) + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + text_config = getattr(config, 'text_config', config) + if hasattr(text_config, 'use_cache'): + text_config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + try: + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + except ValueError: + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.encode(batched=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + fsdp_config={ + 'expert_parallel': { + 'enabled': ENABLE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + # npu patch + if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) + lora_cfg = _build_lora_config(ENABLE_EP) + model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = 
Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index dd7c0cb1..0ccc6a32 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,109 +1,109 @@ -from pathlib import Path - -from peft import LoraConfig -from tqdm import tqdm - -import twinkle -from twinkle import DeviceMesh, get_device_placement, get_logger -from twinkle.cli import CLI -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize, npu_builtin - -logger = 
get_logger() -args = CLI.from_args() - -device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size) -twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) - - -def build_dataset(num_samples: int) -> Dataset: - dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples))) - dataset.set_template(args.template.template_cls, model_id=args.model.model_id) - dataset.map(SelfCognitionProcessor( - args.extra.get('model_name', 'twinkle大模型'), - args.extra.get('model_author', 'ModelScope社区'), - )) - dataset.encode() - return dataset - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - model.save( - checkpoint_name, - output_dir=args.training.output_dir, - adapter_name=args.lora.adapter_name, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def evaluate(model): - eval_samples = args.training.eval_samples or 100 - dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size) - for batch in tqdm(dataloader): - model.forward_only(inputs=batch) - model.calculate_loss() - return model.calculate_metric(is_training=False) - - -def train(): - train_samples = int(args.extra.get('train_samples', 1000)) - dataset = build_dataset(train_samples) - dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) - model = TransformersModel(model_id=args.model.model_id) - model.model._no_split_modules = {'Qwen3_5DecoderLayer'} - # npu patch - if Torch.is_npu_available(): - model = kernelize(model, npu_builtin(model)) - - lora_config = LoraConfig(**args.get_lora_args()) - model.add_adapter_to_model( - args.lora.adapter_name, lora_config, - gradient_accumulation_steps=args.training.gradient_accumulation_steps) - model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) - - # Add LRScheduler for lora `default` - 
model.set_lr_scheduler( - scheduler_cls=args.scheduler.scheduler_cls, - num_warmup_steps=args.scheduler.num_warmup_steps, - num_training_steps=len(dataloader)) - - if args.training.resume_from_checkpoint: - checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() - progress = model.resume_from_checkpoint( - str(checkpoint_path), - resume_only_model=args.training.resume_only_model, - adapter_name=args.lora.adapter_name) - if not args.training.ignore_data_skip: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info(f'Total steps: {len(dataloader)}') - optimizer_group = model.optimizer_group[args.lora.adapter_name] - best_loss = float('inf') - eval_interval = args.training.eval_interval or 40 - for batch in dataloader: - model.forward_backward(inputs=batch) - model.clip_grad_and_step() - cur_step = optimizer_group.cur_step - if cur_step % args.training.log_interval == 0: - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - if cur_step > 0 and cur_step % eval_interval == 0: - metrics = evaluate(model) - logger.info(f'Eval metric: {metrics}') - metrics['step'] = cur_step - current_loss = float(metrics['loss']) - if current_loss < best_loss: - save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) - best_loss = current_loss - save_checkpoint(model, 'last-checkpoint', dataloader) - - -if __name__ == '__main__': - train() +from pathlib import Path + +from peft import LoraConfig +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch 
+from twinkle.kernel import kernelize, npu_builtin + +logger = get_logger() +args = CLI.from_args() + +device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) + + +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle大模型'), + args.extra.get('model_author', 'ModelScope社区'), + )) + dataset.encode() + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=args.training.output_dir, + adapter_name=args.lora.adapter_name, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + eval_samples = args.training.eval_samples or 100 + dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size) + for batch in tqdm(dataloader): + model.forward_only(inputs=batch) + model.calculate_loss() + return model.calculate_metric(is_training=False) + + +def train(): + train_samples = int(args.extra.get('train_samples', 1000)) + dataset = build_dataset(train_samples) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) + model = TransformersModel(model_id=args.model.model_id) + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + # npu patch + if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) + + lora_config = LoraConfig(**args.get_lora_args()) + model.add_adapter_to_model( + args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, 
lr=args.optimizer.learning_rate) + + # Add LRScheduler for lora `default` + model.set_lr_scheduler( + scheduler_cls=args.scheduler.scheduler_cls, + num_warmup_steps=args.scheduler.num_warmup_steps, + num_training_steps=len(dataloader)) + + if args.training.resume_from_checkpoint: + checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() + progress = model.resume_from_checkpoint( + str(checkpoint_path), + resume_only_model=args.training.resume_only_model, + adapter_name=args.lora.adapter_name) + if not args.training.ignore_data_skip: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + optimizer_group = model.optimizer_group[args.lora.adapter_name] + best_loss = float('inf') + eval_interval = args.training.eval_interval or 40 + for batch in dataloader: + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + cur_step = optimizer_group.cur_step + if cur_step % args.training.log_interval == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % eval_interval == 0: + metrics = evaluate(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = cur_step + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 2fd4ecf1..8a8fb412 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,99 +1,99 @@ -import numpy as np -from functools import partial -from peft import LoraConfig - -import twinkle -from twinkle import 
DeviceGroup, DeviceMesh, Platform, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize, npu_builtin - -logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASETS = 'ms://swift/self-cognition' - -device_group = [DeviceGroup( - name='default', - ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), -)] - -# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2. -# In Transformers route, ulysses_size is the total sequence-parallel degree. -device_mesh = DeviceMesh( - device_type=Platform.get_platform().device_prefix(), - mesh=np.arange(4).reshape(2, 2), - mesh_dim_names=('dp', 'fsdp'), - ulysses_size=2, -) - -twinkle.initialize( - mode='local', - nproc_per_node=4, - global_device_mesh=device_mesh, - lazy_collect=False, -) - - -def eval(model): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=range(100)), - batch_size=4, - device_mesh=device_mesh, - ) - for _, batch in enumerate(dataloader): - model.forward_only(inputs=batch, adapter_name='default') - model.calculate_loss(adapter_name='default') - return model.calculate_metric(is_training=False, adapter_name='default') - - -def create_dataset(data_slice=None): - dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队')) - dataset.encode(batched=True) - return dataset - - -def train(): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=None), - batch_size=8, - device_mesh=device_mesh, - ) - - model = TransformersModel( - model_id=MODEL_ID, - device_mesh=device_mesh, - strategy='native_fsdp', - ) - # npu patch - if Torch.is_npu_available(): - model = kernelize(model, 
npu_builtin(model)) - lora_config = LoraConfig(target_modules='all-linear') - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - adapter_name='default', - ) - - logger.info(model.get_train_configs(adapter_name='default')) - logger.info(f'Total steps: {len(dataloader)}') - - for step, batch in enumerate(dataloader): - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - if step % 20 == 0: - metric = model.calculate_metric(is_training=True, adapter_name='default') - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - model.save('last-checkpoint', interval=1) - - -if __name__ == '__main__': - train() +import numpy as np +from functools import partial +from peft import LoraConfig + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize, npu_builtin + +logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASETS = 'ms://swift/self-cognition' + +device_group = [DeviceGroup( + name='default', + ranks=[0, 1, 2, 3], + device_type=Platform.get_platform().device_prefix(), +)] + +# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2. +# In Transformers route, ulysses_size is the total sequence-parallel degree. 
+device_mesh = DeviceMesh( + device_type=Platform.get_platform().device_prefix(), + mesh=np.arange(4).reshape(2, 2), + mesh_dim_names=('dp', 'fsdp'), + ulysses_size=2, +) + +twinkle.initialize( + mode='local', + nproc_per_node=4, + global_device_mesh=device_mesh, + lazy_collect=False, +) + + +def eval(model): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=range(100)), + batch_size=4, + device_mesh=device_mesh, + ) + for _, batch in enumerate(dataloader): + model.forward_only(inputs=batch, adapter_name='default') + model.calculate_loss(adapter_name='default') + return model.calculate_metric(is_training=False, adapter_name='default') + + +def create_dataset(data_slice=None): + dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队')) + dataset.encode(batched=True) + return dataset + + +def train(): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=None), + batch_size=8, + device_mesh=device_mesh, + ) + + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=device_mesh, + strategy='native_fsdp', + ) + # npu patch + if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) + lora_config = LoraConfig(target_modules='all-linear') + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) + model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + adapter_name='default', + ) + + logger.info(model.get_train_configs(adapter_name='default')) + logger.info(f'Total steps: {len(dataloader)}') + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch, adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + if step % 20 == 0: + metric = 
model.calculate_metric(is_training=True, adapter_name='default') + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + model.save('last-checkpoint', interval=1) + + +if __name__ == '__main__': + train() diff --git a/docs/source_en/Components/Kernel/Kernel.md b/docs/source_en/Components/Kernel/Kernel.md index f5ab78e9..f9f168c3 100644 --- a/docs/source_en/Components/Kernel/Kernel.md +++ b/docs/source_en/Components/Kernel/Kernel.md @@ -1,139 +1,139 @@ -# Twinkle Kernel - -`twinkle.kernel` exposes a mapping-driven kernel replacement API. Replacing one -implementation with another collapses to a single `kernelize(model, mapping)` -call. - -The public surface is exactly three symbols: - -| Symbol | Purpose | -| --- | --- | -| `kernelize(model, mapping)` | Apply ``mapping`` to ``model`` (in place) and return it | -| `npu_builtin(model=None)` | Return the Ascend NPU built-in mapping (composes with user mappings) | -| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | Build a ``HubRef`` for use as a mapping value; the actual Hub download is deferred to ``kernelize`` | - -## Mapping semantics - -`mapping` keys describe the target to replace: - -- `type[nn.Module]` subclass — replace **every** instance whose exact type matches (`m.__class__ = impl`; subclasses are **not** touched) -- `str` of the form `'pkg.sub.attr'` or `'pkg.sub.ClassName.attr'` — `setattr(target, attr, impl)` - -`mapping` values describe the replacement: - -- `type[nn.Module]` subclass — used as the impl class. The class' `__init__` is **never** invoked; its forward must work against the attributes the original instance already has -- `Callable` — assigned with `setattr` -- `dict[str, V]` — device → impl dispatch. 
Device is inferred from the model; entries without a matching key are **silently skipped** -- `HubRef` — built via `hub(...)`; resolved lazily - -Device is inferred from `next(model.parameters()).device.type` (falling back to buffers, then `'cpu'`). - -## Examples - -### Enable the full NPU built-in bundle - -```python -import torch -from twinkle.kernel import kernelize, npu_builtin - -if torch.npu.is_available(): - model = kernelize(model, npu_builtin(model)) -``` - -### Custom class replacement - -```python -from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm -from twinkle.kernel import kernelize - -model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) -``` - -### Built-in + custom override - -```python -from twinkle.kernel import kernelize, npu_builtin - -model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) -``` - -Plain dict merge — later keys override earlier ones. - -### Hub kernel (HF Hub format) - -```python -from twinkle.kernel import kernelize, hub -from my_pkg import SiluAndMul - -model = kernelize(model, { - SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), -}) -``` - -Exactly one of `revision` / `version` must be passed. The `kernels` package is imported lazily; absence raises a clear "install kernels" error. - -### Function-level replacement - -```python -from twinkle.kernel import kernelize -from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb - -model = kernelize(model, { - 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': - npu_apply_rotary_pos_emb, -}) -``` - -### Cross-device mapping (NPU enabled, CUDA skipped) - -```python -from twinkle.kernel import kernelize - -model = kernelize(model, { - Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, -}) -``` - -Safe to run on CUDA — entries whose dict misses the current device just skip. 
- -## NPU built-in coverage - -`npu_builtin(model)` returns a dict that (as available transformers modules permit) covers: - -- RMSNorm class replacement for Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE families -- `apply_rotary_pos_emb` function replacement (fused RoPE) for the same families -- SwiGLU fused replacement for the MLP variants -- `Experts.forward` and `SparseMoeBlock.forward` for Qwen3-MoE / Qwen3.5-MoE -- GatedRMSNorm forward for Qwen3.5 / Qwen3.5-MoE -- `apply_multimodal_rotary_pos_emb` for Qwen2.5-VL -- Global SDPA replacement (one-shot side effect on `ALL_ATTENTION_FUNCTIONS['sdpa']`) -- Qwen3.5 Flash Linear Attention enablement (one-shot side effect + per-instance traversal, triggered inside `npu_builtin(model)`) - -**Not included by default:** the NPU replacement for `transformers.integrations.moe._grouped_mm`. Without Expert Parallelism the contiguous-copy overhead is ~8x. Opt in explicitly when EP is enabled: - -```python -from twinkle.kernel import kernelize, npu_builtin -from twinkle.kernel.npu_impls.moe import npu_grouped_mm - -mapping = { - **npu_builtin(model), - 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, -} -model = kernelize(model, mapping) -``` - -## Environment variables - -Only two remain: - -- `TWINKLE_NPU_FLA` — Qwen3.5 FLA switch (default on; `0`/`false` to disable) -- `TWINKLE_NPU_GATED_RMSNorm_FP32` — force FP32 in Gated RMSNorm forward (default off) - -The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` are gone — they're now "include the entry in the mapping or don't" decisions. - -## Caveats - -- `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract -- Exact match: `type(m) is target_cls`. 
Subclasses of `target_cls` are not replaced — add them to the mapping yourself -- `kernelize` is idempotent under repeated calls +# Twinkle Kernel + +`twinkle.kernel` exposes a mapping-driven kernel replacement API. Replacing one +implementation with another collapses to a single `kernelize(model, mapping)` +call. + +The public surface is exactly three symbols: + +| Symbol | Purpose | +| --- | --- | +| `kernelize(model, mapping)` | Apply ``mapping`` to ``model`` (in place) and return it | +| `npu_builtin(model=None)` | Return the Ascend NPU built-in mapping (composes with user mappings) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | Build a ``HubRef`` for use as a mapping value; the actual Hub download is deferred to ``kernelize`` | + +## Mapping semantics + +`mapping` keys describe the target to replace: + +- `type[nn.Module]` subclass — replace **every** instance whose exact type matches (`m.__class__ = impl`; subclasses are **not** touched) +- `str` of the form `'pkg.sub.attr'` or `'pkg.sub.ClassName.attr'` — `setattr(target, attr, impl)` + +`mapping` values describe the replacement: + +- `type[nn.Module]` subclass — used as the impl class. The class' `__init__` is **never** invoked; its forward must work against the attributes the original instance already has +- `Callable` — assigned with `setattr` +- `dict[str, V]` — device → impl dispatch. Device is inferred from the model; entries without a matching key are **silently skipped** +- `HubRef` — built via `hub(...)`; resolved lazily + +Device is inferred from `next(model.parameters()).device.type` (falling back to buffers, then `'cpu'`). 
+ +## Examples + +### Enable the full NPU built-in bundle + +```python +import torch +from twinkle.kernel import kernelize, npu_builtin + +if torch.npu.is_available(): + model = kernelize(model, npu_builtin(model)) +``` + +### Custom class replacement + +```python +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize + +model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) +``` + +### Built-in + custom override + +```python +from twinkle.kernel import kernelize, npu_builtin + +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` + +Plain dict merge — later keys override earlier ones. + +### Hub kernel (HF Hub format) + +```python +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul + +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) +``` + +Exactly one of `revision` / `version` must be passed. The `kernels` package is imported lazily; absence raises a clear "install kernels" error. + +### Function-level replacement + +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` + +### Cross-device mapping (NPU enabled, CUDA skipped) + +```python +from twinkle.kernel import kernelize + +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) +``` + +Safe to run on CUDA — entries whose dict misses the current device just skip. 
+ +## NPU built-in coverage + +`npu_builtin(model)` returns a dict that (as available transformers modules permit) covers: + +- RMSNorm class replacement for Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE families +- `apply_rotary_pos_emb` function replacement (fused RoPE) for the same families +- SwiGLU fused replacement for the MLP variants +- `Experts.forward` and `SparseMoeBlock.forward` for Qwen3-MoE / Qwen3.5-MoE +- GatedRMSNorm forward for Qwen3.5 / Qwen3.5-MoE +- `apply_multimodal_rotary_pos_emb` for Qwen2.5-VL +- Global SDPA replacement (one-shot side effect on `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention enablement (one-shot side effect + per-instance traversal, triggered inside `npu_builtin(model)`) + +**Not included by default:** the NPU replacement for `transformers.integrations.moe._grouped_mm`. Without Expert Parallelism the contiguous-copy overhead is ~8x. Opt in explicitly when EP is enabled: + +```python +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} +model = kernelize(model, mapping) +``` + +## Environment variables + +Only two remain: + +- `TWINKLE_NPU_FLA` — Qwen3.5 FLA switch (default on; `0`/`false` to disable) +- `TWINKLE_NPU_GATED_RMSNorm_FP32` — force FP32 in Gated RMSNorm forward (default off) + +The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` are gone — they're now "include the entry in the mapping or don't" decisions. + +## Caveats + +- `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract +- Exact match: `type(m) is target_cls`. 
Subclasses of `target_cls` are not replaced — add them to the mapping yourself +- `kernelize` is idempotent under repeated calls - There is no `unkernelize` — replacement is one-way \ No newline at end of file diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" index 8ed5ad78..a7687020 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" @@ -1,137 +1,137 @@ -# Twinkle Kernel 模块 - -`twinkle.kernel` 提供一个 mapping 驱动的内核替换接口,把“用一种实现替换模型里的另一种实现”压缩为一次 `kernelize(model, mapping)` 调用。 - -公开符号只有三个: - -| 符号 | 作用 | -| --- | --- | -| `kernelize(model, mapping)` | 在 `model` 上应用 `mapping`,原地修改后返回 | -| `npu_builtin(model=None)` | 返回 Ascend NPU 内置替换的 mapping dict(可与用户 mapping 自由组合) | -| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | 构造一个 `HubRef`,用作 mapping value;真实下载推迟到 `kernelize` 执行 | - -## Mapping 语义 - -`mapping` 的 **key** 表示要替换的目标: - -- `type[nn.Module]` 子类:替换模型里**所有**该精确类型的实例(`m.__class__ = impl_class`,**不包含**子类) -- `str` 形如 `'pkg.sub.attr'` 或 `'pkg.sub.ClassName.attr'`:`setattr(target, attr, impl)` - -**value** 表示用什么替换: - -- `type[nn.Module]` 子类:直接作为 impl 类。该类**不会被 `__init__` 调用**,必须只依赖原 instance 已经有的 attribute(weight / eps / ...)正确工作 -- `Callable`:直接 `setattr` 上去 -- `dict[str, V]`:device → impl 嵌套分派。从 `model` 推断当前 device,未匹配则**静默跳过** -- `HubRef`:通过 `hub(...)` 构造的 Hub 引用,延迟加载 - -device 从 `next(model.parameters()).device.type` 推断(无参数则用 buffers,再无则为 `'cpu'`)。 - -## 场景示例 - -### 启用全部 NPU 内置优化 - -```python -import torch -from twinkle.kernel import kernelize, npu_builtin - -if torch.npu.is_available(): - model = kernelize(model, npu_builtin(model)) -``` - -### 自定义类替换 - -```python -from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm -from twinkle.kernel import kernelize - -model = kernelize(model, 
{Qwen2RMSNorm: MyRMSNorm}) -``` - -### 内置 + 自定义混合 - -```python -from twinkle.kernel import kernelize, npu_builtin - -model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) -``` - -后写入的 key 会覆盖前面的,普通 dict 合并语义。 - -### Hub Kernel(HF Hub 格式) - -```python -from twinkle.kernel import kernelize, hub -from my_pkg import SiluAndMul - -model = kernelize(model, { - SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), -}) -``` - -`revision` 与 `version` 二选一必传。`hub(...)` 触发 `kernels` 包的延迟 import,未安装时会提示 `pip install kernels`。 - -### 函数级替换 - -```python -from twinkle.kernel import kernelize -from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb - -model = kernelize(model, { - 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': - npu_apply_rotary_pos_emb, -}) -``` - -### 跨设备 mapping(NPU 启用、CUDA 跳过) - -```python -from twinkle.kernel import kernelize - -model = kernelize(model, { - Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, -}) -``` - -在 CUDA 模型上跑也安全:未匹配 device 的 entry 不会替换、不会报错。 - -## 内置 NPU 优化 - -`npu_builtin(model)` 返回的 dict 至少包含以下覆盖(实际条目随 transformers 已安装的 modeling 模块动态收集): - -- Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE 系列的 RMSNorm 类替换 -- 同上系列的 `apply_rotary_pos_emb` 函数替换(融合 RoPE) -- 同上系列 MLP 的 SwiGLU 融合替换 -- Qwen3-MoE / Qwen3.5-MoE 的 `Experts.forward` 与 `SparseMoeBlock.forward` 替换 -- Qwen3.5 / Qwen3.5-MoE 的 GatedRMSNorm forward 替换 -- Qwen2.5-VL 的 `apply_multimodal_rotary_pos_emb` 替换 -- 全局 SDPA 替换(一次性副作用,写入 `ALL_ATTENTION_FUNCTIONS['sdpa']`) -- Qwen3.5 Flash Linear Attention 启用(一次性副作用 + 实例遍历,由 `npu_builtin(model)` 内部触发) - -**未默认包含** `transformers.integrations.moe._grouped_mm` 的 NPU 替换(在没有 Expert Parallelism 时会带来约 8x 开销)。需要时手动加入: - -```python -from twinkle.kernel import kernelize, npu_builtin -from twinkle.kernel.npu_impls.moe import npu_grouped_mm - -mapping = { - **npu_builtin(model), - 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, -} -model = 
kernelize(model, mapping) -``` - -## 环境变量 - -只有两个保留: - -- `TWINKLE_NPU_FLA`:Qwen3.5 FLA 开关(默认开,设为 `0`/`false` 关闭) -- `TWINKLE_NPU_GATED_RMSNorm_FP32`:将 Gated RMSNorm 强制升到 FP32 计算(默认关) - -旧的 `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` 已移除——这些都改成"是否把对应 entry 写进 mapping"的显式选择。 - -## 注意事项 - -- `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 -- 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping -- 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) +# Twinkle Kernel 模块 + +`twinkle.kernel` 提供一个 mapping 驱动的内核替换接口,把“用一种实现替换模型里的另一种实现”压缩为一次 `kernelize(model, mapping)` 调用。 + +公开符号只有三个: + +| 符号 | 作用 | +| --- | --- | +| `kernelize(model, mapping)` | 在 `model` 上应用 `mapping`,原地修改后返回 | +| `npu_builtin(model=None)` | 返回 Ascend NPU 内置替换的 mapping dict(可与用户 mapping 自由组合) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | 构造一个 `HubRef`,用作 mapping value;真实下载推迟到 `kernelize` 执行 | + +## Mapping 语义 + +`mapping` 的 **key** 表示要替换的目标: + +- `type[nn.Module]` 子类:替换模型里**所有**该精确类型的实例(`m.__class__ = impl_class`,**不包含**子类) +- `str` 形如 `'pkg.sub.attr'` 或 `'pkg.sub.ClassName.attr'`:`setattr(target, attr, impl)` + +**value** 表示用什么替换: + +- `type[nn.Module]` 子类:直接作为 impl 类。该类**不会被 `__init__` 调用**,必须只依赖原 instance 已经有的 attribute(weight / eps / ...)正确工作 +- `Callable`:直接 `setattr` 上去 +- `dict[str, V]`:device → impl 嵌套分派。从 `model` 推断当前 device,未匹配则**静默跳过** +- `HubRef`:通过 `hub(...)` 构造的 Hub 引用,延迟加载 + +device 从 `next(model.parameters()).device.type` 推断(无参数则用 buffers,再无则为 `'cpu'`)。 + +## 场景示例 + +### 启用全部 NPU 内置优化 + +```python +import torch +from twinkle.kernel import kernelize, npu_builtin + +if torch.npu.is_available(): + model = kernelize(model, npu_builtin(model)) +``` + +### 自定义类替换 + +```python +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize + +model = 
kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) +``` + +### 内置 + 自定义混合 + +```python +from twinkle.kernel import kernelize, npu_builtin + +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` + +后写入的 key 会覆盖前面的,普通 dict 合并语义。 + +### Hub Kernel(HF Hub 格式) + +```python +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul + +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) +``` + +`revision` 与 `version` 二选一必传。`hub(...)` 触发 `kernels` 包的延迟 import,未安装时会提示 `pip install kernels`。 + +### 函数级替换 + +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` + +### 跨设备 mapping(NPU 启用、CUDA 跳过) + +```python +from twinkle.kernel import kernelize + +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) +``` + +在 CUDA 模型上跑也安全:未匹配 device 的 entry 不会替换、不会报错。 + +## 内置 NPU 优化 + +`npu_builtin(model)` 返回的 dict 至少包含以下覆盖(实际条目随 transformers 已安装的 modeling 模块动态收集): + +- Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE 系列的 RMSNorm 类替换 +- 同上系列的 `apply_rotary_pos_emb` 函数替换(融合 RoPE) +- 同上系列 MLP 的 SwiGLU 融合替换 +- Qwen3-MoE / Qwen3.5-MoE 的 `Experts.forward` 与 `SparseMoeBlock.forward` 替换 +- Qwen3.5 / Qwen3.5-MoE 的 GatedRMSNorm forward 替换 +- Qwen2.5-VL 的 `apply_multimodal_rotary_pos_emb` 替换 +- 全局 SDPA 替换(一次性副作用,写入 `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention 启用(一次性副作用 + 实例遍历,由 `npu_builtin(model)` 内部触发) + +**未默认包含** `transformers.integrations.moe._grouped_mm` 的 NPU 替换(在没有 Expert Parallelism 时会带来约 8x 开销)。需要时手动加入: + +```python +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} 
+model = kernelize(model, mapping) +``` + +## 环境变量 + +只有两个保留: + +- `TWINKLE_NPU_FLA`:Qwen3.5 FLA 开关(默认开,设为 `0`/`false` 关闭) +- `TWINKLE_NPU_GATED_RMSNorm_FP32`:将 Gated RMSNorm 强制升到 FP32 计算(默认关) + +旧的 `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` 已移除——这些都改成"是否把对应 entry 写进 mapping"的显式选择。 + +## 注意事项 + +- `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 +- 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping +- 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) - 没有 `unkernelize`——替换是单向的 \ No newline at end of file diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index 5d435c0f..f1de4b75 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -1,13 +1,13 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Mapping-driven kernel replacement. - -Three public symbols: - -- :func:`kernelize` apply ``mapping`` to a model -- :func:`hub` build a Hub kernel reference -- :func:`npu_builtin` the Ascend NPU built-in bundle -""" -from .builtin import npu_builtin -from .core import hub, kernelize - +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Mapping-driven kernel replacement. + +Three public symbols: + +- :func:`kernelize` apply ``mapping`` to a model +- :func:`hub` build a Hub kernel reference +- :func:`npu_builtin` the Ascend NPU built-in bundle +""" +from .builtin import npu_builtin +from .core import hub, kernelize + __all__ = ['kernelize', 'hub', 'npu_builtin'] \ No newline at end of file diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py index 8d7f3b59..b2376e77 100644 --- a/src/twinkle/kernel/builtin.py +++ b/src/twinkle/kernel/builtin.py @@ -1,205 +1,205 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""``npu_builtin()`` returns the bundle of Ascend NPU replacements. 
- -All values are wrapped in ``{'npu': impl}`` so the bundle composes safely on -CUDA/CPU systems — non-NPU devices silently skip every entry. - -GMM is **not** included by default (without EP it causes ~8x slowdown). Opt -in by merging: - - {**npu_builtin(model), 'transformers.integrations.moe._grouped_mm': - {'npu': npu_grouped_mm}} -""" -from __future__ import annotations - -import importlib -from typing import Any - -import torch.nn as nn - -from twinkle import get_logger -from twinkle.utils.device_mesh import Platform - -logger = get_logger() - - -def _import_optional(name: str): - try: - return importlib.import_module(name) - except ImportError: - return None - - -def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: - """Return the NPU builtin mapping; optionally apply per-instance FLA.""" - from .npu_impls.attention import npu_sdpa_attention_forward - from .npu_impls.fla import apply_qwen3_5_fla - from .npu_impls.moe import ( - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) - from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward - from .npu_impls.rotary import ( - npu_apply_multimodal_rotary_pos_emb, - npu_apply_rotary_pos_emb, - ) - from .npu_impls.swiglu import npu_swiglu_forward - - bundle: dict[Any, dict[str, Any]] = {} - - is_npu_platform = Platform.device_prefix() == 'npu' - - # Apply SDPA install eagerly (one-shot module-level mutation) on NPU - # platforms. The NPU impl inverts boolean masks, which is wrong for - # CUDA/CPU execution, so non-NPU platforms must not mutate the global HF - # registry even if ``torch_npu`` is importable in the environment. 
- if is_npu_platform: - _install_sdpa(npu_sdpa_attention_forward) - - # === per-family class + function entries === - _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) - _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) - _add_qwen3_moe_entries( - bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, - npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, - ) - _add_qwen2_5_vl_entries( - bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, - npu_apply_multimodal_rotary_pos_emb, - ) - _add_qwen3_5_entries( - bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, - npu_swiglu_forward, - ) - _add_qwen3_5_moe_entries( - bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, - npu_swiglu_forward, npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) - - # === FLA (side-effect; mapping-incompatible) === - if is_npu_platform: - apply_qwen3_5_fla(model) - - return bundle - - -def _install_sdpa(impl) -> None: - """One-shot install of SDPA attention forward (global modeling_utils dict). - - ``AttentionInterface._global_mapping`` is a private transformers attribute; - guard against its removal so an upstream change can't take down the rest - of ``npu_builtin()``. 
- """ - try: - from transformers.modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - AttentionInterface, - ) - except ImportError: - return - try: - AttentionInterface._global_mapping['sdpa'] = impl - except AttributeError: - logger.warning('[NPU] [SDPA] AttentionInterface._global_mapping unavailable; skipping') - ALL_ATTENTION_FUNCTIONS['sdpa'] = impl - - -# ---- helpers that conditionally add entries based on module availability ---- - -def _add_class_if_present(bundle, module_path, class_name, impl_cls): - mod = _import_optional(module_path) - if mod is None: - return - cls = getattr(mod, class_name, None) - if isinstance(cls, type): - bundle[cls] = {'npu': impl_cls} - - -def _add_swiglu_if_present(bundle, module_path, class_name, fn): - mod = _import_optional(module_path) - if mod is None: - return - cls = getattr(mod, class_name, None) - if isinstance(cls, type): - # Function-level: wrap as string-keyed forward replacement. - # We override on the *class object*, not the module attribute, by - # using a class-key with a synthetic impl wrapping the forward. - # The simplest way is to subclass and reassign __class__, but here - # we follow the legacy approach of overwriting the class's forward: - bundle[f'{module_path}.{class_name}.forward'] = {'npu': fn} - - -def _add_attr_if_present(bundle, module_path, attr_name, impl): - mod = _import_optional(module_path) - if mod is None: - return - if '.' in attr_name: - # Dotted attr like 'Qwen3MoeExperts.forward': resolve the class on - # the module, then check the trailing member on the class. - head, _, tail = attr_name.partition('.') - owner = getattr(mod, head, None) - if owner is None or not hasattr(owner, tail): - return - else: - if not hasattr(mod, attr_name): - return - bundle[f'{module_path}.{attr_name}'] = {'npu': impl} - - -def _add_qwen2_entries(bundle, rms_cls, rope_fn, swiglu_fn): - # Qwen2 (used by Qwen2.5-VL etc. 
via inheritance) - _add_class_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2RMSNorm', rms_cls) - _add_attr_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2MLP', swiglu_fn) - - -def _add_qwen3_entries(bundle, rms_cls, rope_fn, swiglu_fn): - base = 'transformers.models.qwen3.modeling_qwen3' - _add_class_if_present(bundle, base, 'Qwen3RMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3MLP', swiglu_fn) - - -def _add_qwen3_moe_entries(bundle, rms_cls, rope_fn, swiglu_fn, experts_fn, sparse_fn): - base = 'transformers.models.qwen3_moe.modeling_qwen3_moe' - _add_class_if_present(bundle, base, 'Qwen3MoeRMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3MoeMLP', swiglu_fn) - _add_attr_if_present(bundle, base, 'Qwen3MoeExperts.forward', experts_fn) - _add_attr_if_present(bundle, base, 'Qwen3MoeSparseMoeBlock.forward', sparse_fn) - - -def _add_qwen2_5_vl_entries(bundle, rms_cls, rope_fn, swiglu_fn, multimodal_rope_fn): - base = 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl' - _add_class_if_present(bundle, base, 'Qwen2_5_VLRMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_attr_if_present(bundle, base, 'apply_multimodal_rotary_pos_emb', multimodal_rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen2MLP', swiglu_fn) - _add_swiglu_if_present(bundle, base, 'Qwen2_5_VLMLP', swiglu_fn) - - -def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): - base = 'transformers.models.qwen3_5.modeling_qwen3_5' - if _import_optional(base) is None: - return - _add_class_if_present(bundle, base, 'Qwen3_5RMSNorm', rms_cls) - _add_class_if_present(bundle, base, 'Qwen3_5VisionRMSNorm', rms_cls) - 
_add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3_5MLP', swiglu_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3_5VisionMLP', swiglu_fn) - # Qwen3_5GatedRMSNorm: forward-level replacement - _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) - - -def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, - experts_fn, sparse_fn): - base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' - if _import_optional(base) is None: - return - _add_class_if_present(bundle, base, 'Qwen3_5MoeRMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""``npu_builtin()`` returns the bundle of Ascend NPU replacements. + +All values are wrapped in ``{'npu': impl}`` so the bundle composes safely on +CUDA/CPU systems — non-NPU devices silently skip every entry. + +GMM is **not** included by default (without EP it causes ~8x slowdown). 
Opt +in by merging: + + {**npu_builtin(model), 'transformers.integrations.moe._grouped_mm': + {'npu': npu_grouped_mm}} +""" +from __future__ import annotations + +import importlib +from typing import Any + +import torch.nn as nn + +from twinkle import get_logger +from twinkle.utils.device_mesh import Platform + +logger = get_logger() + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: + """Return the NPU builtin mapping; optionally apply per-instance FLA.""" + from .npu_impls.attention import npu_sdpa_attention_forward + from .npu_impls.fla import apply_qwen3_5_fla + from .npu_impls.moe import ( + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + from .npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + from .npu_impls.swiglu import npu_swiglu_forward + + bundle: dict[Any, dict[str, Any]] = {} + + is_npu_platform = Platform.device_prefix() == 'npu' + + # Apply SDPA install eagerly (one-shot module-level mutation) on NPU + # platforms. The NPU impl inverts boolean masks, which is wrong for + # CUDA/CPU execution, so non-NPU platforms must not mutate the global HF + # registry even if ``torch_npu`` is importable in the environment. 
+ if is_npu_platform: + _install_sdpa(npu_sdpa_attention_forward) + + # === per-family class + function entries === + _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_moe_entries( + bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, + ) + _add_qwen2_5_vl_entries( + bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + npu_apply_multimodal_rotary_pos_emb, + ) + _add_qwen3_5_entries( + bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + npu_swiglu_forward, + ) + _add_qwen3_5_moe_entries( + bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + npu_swiglu_forward, npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + + # === FLA (side-effect; mapping-incompatible) === + if is_npu_platform: + apply_qwen3_5_fla(model) + + return bundle + + +def _install_sdpa(impl) -> None: + """One-shot install of SDPA attention forward (global modeling_utils dict). + + ``AttentionInterface._global_mapping`` is a private transformers attribute; + guard against its removal so an upstream change can't take down the rest + of ``npu_builtin()``. 
+ """ + try: + from transformers.modeling_utils import ( + ALL_ATTENTION_FUNCTIONS, + AttentionInterface, + ) + except ImportError: + return + try: + AttentionInterface._global_mapping['sdpa'] = impl + except AttributeError: + logger.warning('[NPU] [SDPA] AttentionInterface._global_mapping unavailable; skipping') + ALL_ATTENTION_FUNCTIONS['sdpa'] = impl + + +# ---- helpers that conditionally add entries based on module availability ---- + +def _add_class_if_present(bundle, module_path, class_name, impl_cls): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + bundle[cls] = {'npu': impl_cls} + + +def _add_swiglu_if_present(bundle, module_path, class_name, fn): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + # Function-level: wrap as string-keyed forward replacement. + # We override on the *class object*, not the module attribute, by + # using a class-key with a synthetic impl wrapping the forward. + # The simplest way is to subclass and reassign __class__, but here + # we follow the legacy approach of overwriting the class's forward: + bundle[f'{module_path}.{class_name}.forward'] = {'npu': fn} + + +def _add_attr_if_present(bundle, module_path, attr_name, impl): + mod = _import_optional(module_path) + if mod is None: + return + if '.' in attr_name: + # Dotted attr like 'Qwen3MoeExperts.forward': resolve the class on + # the module, then check the trailing member on the class. + head, _, tail = attr_name.partition('.') + owner = getattr(mod, head, None) + if owner is None or not hasattr(owner, tail): + return + else: + if not hasattr(mod, attr_name): + return + bundle[f'{module_path}.{attr_name}'] = {'npu': impl} + + +def _add_qwen2_entries(bundle, rms_cls, rope_fn, swiglu_fn): + # Qwen2 (used by Qwen2.5-VL etc. 
via inheritance) + _add_class_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2RMSNorm', rms_cls) + _add_attr_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2MLP', swiglu_fn) + + +def _add_qwen3_entries(bundle, rms_cls, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3.modeling_qwen3' + _add_class_if_present(bundle, base, 'Qwen3RMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MLP', swiglu_fn) + + +def _add_qwen3_moe_entries(bundle, rms_cls, rope_fn, swiglu_fn, experts_fn, sparse_fn): + base = 'transformers.models.qwen3_moe.modeling_qwen3_moe' + _add_class_if_present(bundle, base, 'Qwen3MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeSparseMoeBlock.forward', sparse_fn) + + +def _add_qwen2_5_vl_entries(bundle, rms_cls, rope_fn, swiglu_fn, multimodal_rope_fn): + base = 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl' + _add_class_if_present(bundle, base, 'Qwen2_5_VLRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_attr_if_present(bundle, base, 'apply_multimodal_rotary_pos_emb', multimodal_rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2_5_VLMLP', swiglu_fn) + + +def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3_5.modeling_qwen3_5' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5RMSNorm', rms_cls) + _add_class_if_present(bundle, base, 'Qwen3_5VisionRMSNorm', rms_cls) + 
_add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5VisionMLP', swiglu_fn) + # Qwen3_5GatedRMSNorm: forward-level replacement + _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) + + +def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, + experts_fn, sparse_fn): + base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) diff --git a/src/twinkle/kernel/chunk_gated_delta_rule.py b/src/twinkle/kernel/chunk_gated_delta_rule.py index 2d0beee7..09defec6 100644 --- a/src/twinkle/kernel/chunk_gated_delta_rule.py +++ b/src/twinkle/kernel/chunk_gated_delta_rule.py @@ -1,362 +1,362 @@ -'''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). -This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, -redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. 
-It is consumed by twinkle.kernel.npu_impls.fla to enable the fast linear-attention -path of Qwen3.5 on Ascend hardware.''' - -import torch -import warnings -from mindspeed.lite.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h -from mindspeed.lite.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o -from mindspeed.lite.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from mindspeed.lite.ops.triton.cumsum import chunk_local_cumsum -from mindspeed.lite.ops.triton.solve_tril import solve_tril -from mindspeed.lite.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard -from mindspeed.lite.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd -from typing import Optional - - -def _torch_l2norm_fwd( - x: torch.Tensor, - eps: float = 1e-6, - output_dtype: Optional[torch.dtype] = None, -): - x_shape_og = x.shape - x = x.view(-1, x.shape[-1]) - x_float = x.float() - rstd = torch.rsqrt(torch.sum(x_float * x_float, dim=-1) + eps) - y = x_float * rstd.unsqueeze(-1) - y = y.to(output_dtype if output_dtype is not None else x.dtype) - return y.view(x_shape_og), rstd.view(x_shape_og[:-1]) - - -def _torch_l2norm_bwd( - y: torch.Tensor, - rstd: torch.Tensor, - dy: torch.Tensor, - eps: float = 1e-6, -): - y_shape_og = y.shape - y = y.view(-1, y.shape[-1]) - dy = dy.view(-1, dy.shape[-1]) - y_float = y.float() - dy_float = dy.float() - rstd = rstd.view(-1).float() - dx = dy_float * rstd.unsqueeze(-1) - dx = dx - torch.sum(dy_float * y_float, dim=-1, keepdim=True) * y_float * rstd.unsqueeze(-1) - return dx.to(y.dtype).view(y_shape_og) - - -def chunk_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, -): - g = chunk_local_cumsum(g, chunk_size=chunk_size, 
cu_seqlens=cu_seqlens, head_first=False) - # obtain WY representation. u is actually the new v. - A = chunk_scaled_dot_kkt_fwd( - k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) - w, u = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - g=g, - cu_seqlens=cu_seqlens, - ) - h, v_new, final_state = chunk_gated_delta_rule_fwd_h( - k=k, - w=w, - u=u, - g=g, - initial_state=initial_state, - output_final_state=output_final_state, - chunk_size=chunk_size, - cu_seqlens=cu_seqlens, - ) - o = chunk_fwd_o( - q=q, - k=k, - v=v_new, - h=h, - g=g, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - return g, o, A, final_state - - -def chunk_gated_delta_rule_bwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - A: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - do: torch.Tensor, - dht: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, -): - w, u = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - g=g, - cu_seqlens=cu_seqlens, - ) - h, v_new, _ = chunk_gated_delta_rule_fwd_h( - k=k, - w=w, - u=u, - g=g, - initial_state=initial_state, - output_final_state=False, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - dv = chunk_bwd_dv_local( - q=q, - k=k, - g=g, - do=do, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( - q=q, - k=k, - w=w, - g=g, - h0=initial_state, - dht=dht, - do=do, - dv=dv, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - dq, dk, dw, dg = chunk_bwd_dqkwg( - q=q, - k=k, - v=v_new, - w=w, - g=g, - h=h, - dv=dv, - do=do, - dh=dh, - chunk_size=chunk_size, - scale=scale, - cu_seqlens=cu_seqlens, - ) - dk2, dv, db, dg2 = prepare_wy_repr_bwd( - k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size) - 
dk.add_(dk2) - dg.add_(dg2) - if dg.dtype != torch.float32: - raise ValueError(f'dg current type is {dg.dtype} , should be float32') - dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False) - return dq, dk, dv, db, dg, dh0 - - -class ChunkGatedDeltaRuleFunction(torch.autograd.Function): - - @staticmethod - @input_guard - @autocast_custom_fwd - def forward( - ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - use_qk_l2norm_in_kernel: bool = False, - chunk_size: int = 64, - ): - if use_qk_l2norm_in_kernel: - q, q_rstd = _torch_l2norm_fwd(q) - k, k_rstd = _torch_l2norm_fwd(k) - else: - q_rstd, k_rstd = None, None - - g, o, A, final_state = chunk_gated_delta_rule_fwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - initial_state=initial_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size) - ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens) - ctx.scale = scale - ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel - ctx.chunk_size = chunk_size - return o.to(q.dtype), final_state - - @staticmethod - @input_guard - @autocast_custom_bwd - def backward(ctx, do: torch.Tensor, dht: torch.Tensor): - q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors - dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - A=A, - scale=ctx.scale, - initial_state=initial_state, - do=do, - dht=dht, - cu_seqlens=cu_seqlens, - chunk_size=ctx.chunk_size, - ) - if ctx.use_qk_l2norm_in_kernel: - dq = _torch_l2norm_bwd(q, q_rstd, dq) - dk = _torch_l2norm_bwd(k, k_rstd, dk) - return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None - - -@torch.compiler.disable -def 
chunk_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float = None, - initial_state: torch.Tensor = None, - output_final_state: bool = False, - use_qk_l2norm_in_kernel: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, - head_first: bool = False, -): - r""" - Args: - q (torch.Tensor): - queries of shape `[B, T, H, K]`. - k (torch.Tensor): - keys of shape `[B, T, H, K]`. - v (torch.Tensor): - values of shape `[B, T, H, V]`. - g (torch.Tensor): - (forget) gating tensor (in log space!) of shape `[B, T, H]`. - beta (torch.Tensor): - betas of shape `[B, T, H]`. - scale (Optional[float]): - Scale factor for the RetNet attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, H, K, V]` for `N` input sequences. - For equal-length input sequences, `N` equals the batch size `B`. - Default: `None`. - output_final_state (Optional[bool]): - Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. - use_qk_l2norm_in_kernel (bool): - Whether to apply L2norm to the q/k tensor internally. Default: `False`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. - head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `False`. - This argument has been deprecated. - - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, H, V]`. - final_state (torch.Tensor): - Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
- - Examples:: - >>> import torch - >>> import torch.nn.functional as F - >>> from einops import rearrange - >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule - # inputs with equal lengths - >>> B, T, H, K, V = 4, 2048, 4, 512, 512 - >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') - >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) - >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') - >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() - >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) - >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') - >>> o, ht = chunk_gated_delta_rule( - q, k, v, g, beta, - initial_state=h0, - output_final_state=True - ) - # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required - >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) - # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o, ht = chunk_gated_delta_rule( - q, k, v, g, beta, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens - ) - """ - if q.dtype != k.dtype or k.dtype != v.dtype: - raise ValueError( - f'q current type is {q.dtype}, k current type is {k.dtype}, v current type is {v.dtype}, should be equal') - if q.dtype == torch.float32: - raise ValueError('ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.') - if len(beta.shape) != 3: - raise ValueError(f'beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] ' - f'if head_first=False, or [B, H, T] otherwise.') - if head_first: - warnings.warn('head_first is deprecated and will be removed in a future version. 
' - 'Please use head_first=False for now instead.') - if not head_first and q.shape[1] < q.shape[2]: - warnings.warn( - f'Input tensor shape suggests format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). ' - 'This may indicate the inputs were passed in head-first format [B, H, T, ...] ' - 'when head_first=False was specified. ' - 'Please verify your input tensor format matches the expected shape [B, T, H, ...].') - if cu_seqlens is not None: - if q.shape[0] != 1: - raise ValueError(f'The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`.' - f'Please flatten variable-length inputs before processing.') - if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f'The number of initial states is expected to be equal to the number of input sequences, ' - f'i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.') - if scale is None: - scale = k.shape[-1]**-0.5 - o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, - k, - v, - g, - beta, - scale, - initial_state, - output_final_state, - cu_seqlens, - use_qk_l2norm_in_kernel, - chunk_size, - ) - return o, final_state +'''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). +This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, +redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. 
+It is consumed by twinkle.kernel.npu_impls.fla to enable the fast linear-attention +path of Qwen3.5 on Ascend hardware.''' + +import torch +import warnings +from mindspeed.lite.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from mindspeed.lite.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from mindspeed.lite.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from mindspeed.lite.ops.triton.cumsum import chunk_local_cumsum +from mindspeed.lite.ops.triton.solve_tril import solve_tril +from mindspeed.lite.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from mindspeed.lite.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from typing import Optional + + +def _torch_l2norm_fwd( + x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None, +): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + x_float = x.float() + rstd = torch.rsqrt(torch.sum(x_float * x_float, dim=-1) + eps) + y = x_float * rstd.unsqueeze(-1) + y = y.to(output_dtype if output_dtype is not None else x.dtype) + return y.view(x_shape_og), rstd.view(x_shape_og[:-1]) + + +def _torch_l2norm_bwd( + y: torch.Tensor, + rstd: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-6, +): + y_shape_og = y.shape + y = y.view(-1, y.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + y_float = y.float() + dy_float = dy.float() + rstd = rstd.view(-1).float() + dx = dy_float * rstd.unsqueeze(-1) + dx = dx - torch.sum(dy_float * y_float, dim=-1, keepdim=True) * y_float * rstd.unsqueeze(-1) + return dx.to(y.dtype).view(y_shape_og) + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + g = chunk_local_cumsum(g, chunk_size=chunk_size, 
cu_seqlens=cu_seqlens, head_first=False) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd( + k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + return g, o, A, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + chunk_size=chunk_size, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dk2, dv, db, dg2 = prepare_wy_repr_bwd( + k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + 
dk.add_(dk2) + dg.add_(dg2) + if dg.dtype != torch.float32: + raise ValueError(f'dg current type is {dg.dtype} , should be float32') + dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + chunk_size: int = 64, + ): + if use_qk_l2norm_in_kernel: + q, q_rstd = _torch_l2norm_fwd(q) + k, k_rstd = _torch_l2norm_fwd(k) + else: + q_rstd, k_rstd = None, None + + g, o, A, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size) + ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.chunk_size = chunk_size + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do: torch.Tensor, dht: torch.Tensor): + q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_size=ctx.chunk_size, + ) + if ctx.use_qk_l2norm_in_kernel: + dq = _torch_l2norm_bwd(q, q_rstd, dq) + dk = _torch_l2norm_bwd(k, k_rstd, dk) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def 
chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + head_first: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q/k tensor internally. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + This argument has been deprecated. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + if q.dtype != k.dtype or k.dtype != v.dtype: + raise ValueError( + f'q current type is {q.dtype}, k current type is {k.dtype}, v current type is {v.dtype}, should be equal') + if q.dtype == torch.float32: + raise ValueError('ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.') + if len(beta.shape) != 3: + raise ValueError(f'beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] ' + f'if head_first=False, or [B, H, T] otherwise.') + if head_first: + warnings.warn('head_first is deprecated and will be removed in a future version. 
' + 'Please use head_first=False for now instead.') + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f'Input tensor shape suggests format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). ' + 'This may indicate the inputs were passed in head-first format [B, H, T, ...] ' + 'when head_first=False was specified. ' + 'Please verify your input tensor format matches the expected shape [B, T, H, ...].') + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f'The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`.' + f'Please flatten variable-length inputs before processing.') + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f'The number of initial states is expected to be equal to the number of input sequences, ' + f'i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.') + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + chunk_size, + ) + return o, final_state diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index a3a12f18..03cace34 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -1,171 +1,171 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Minimal mapping-driven kernel replacement. - -Public API: ``kernelize``, ``hub`` (re-exported from ``twinkle.kernel``). -""" -from __future__ import annotations - -import importlib -from dataclasses import dataclass -from typing import Any - -import torch.nn as nn - -from twinkle.utils.device_mesh import Platform - - -@dataclass(frozen=True) -class HubRef: - """Lightweight reference to a HuggingFace Hub kernel layer. - - Resolved lazily by ``kernelize`` via the optional ``kernels`` package. 
- """ - repo_id: str - layer_name: str - revision: str | None = None - version: int | None = None - backend: str | None = None - trust_remote_code: bool = False - - -def hub( - ref: str, - *, - revision: str | None = None, - version: int | None = None, - backend: str | None = None, - trust_remote_code: bool = False, -) -> HubRef: - """Build a ``HubRef`` for use as a ``kernelize`` mapping value. - - ``ref`` is ``'<repo_id>:<LayerName>'`` (e.g. ``'org/repo:SiluAndMul'``). - Exactly one of ``revision`` or ``version`` must be supplied. - """ - if (revision is None) == (version is None): - raise ValueError('Exactly one of `revision` or `version` must be specified.') - if ':' not in ref: - raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") - repo_id, layer_name = ref.rsplit(':', 1) - return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) - - - -def _resolve_value(value: Any, device: str) -> Any | None: - """Resolve a mapping value against the selected device. - - - ``dict``: device-conditional; recurse into ``value[device]`` or return None. - - anything else (including ``HubRef``): pass through. - """ - if isinstance(value, dict): - if device not in value: - return None - return _resolve_value(value[device], device) - return value - - -def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: - """Rewrite ``__class__`` of every module whose exact type is ``target_cls``. - - Uses ``type(m) is target_cls`` (not ``isinstance``) so user-defined - subclasses of ``target_cls`` are deliberately left alone. - """ - for m in model.modules(): - if type(m) is target_cls: - m.__class__ = impl_cls - - -def _replace_attr(dotted_path: str, impl) -> None: - """``setattr`` ``impl`` onto the attribute identified by the dotted path.
- - Supports two forms: - - ``pkg.mod.attr`` (set module attribute) - - ``pkg.mod.ClassName.attr`` (set class attribute / method) - - The split is found by walking the prefix from the longest importable - module backwards until ``importlib.import_module`` succeeds. - """ - parts = dotted_path.split('.') - if len(parts) < 2: - raise ValueError(f"Expected at least 'pkg.attr', got: {dotted_path!r}") - - # Find the longest prefix that imports as a module. - last_err: ImportError | None = None - module = None - module_depth = 0 - for i in range(len(parts) - 1, 0, -1): - candidate = '.'.join(parts[:i]) - try: - module = importlib.import_module(candidate) - module_depth = i - break - except ImportError as e: - last_err = e - continue - if module is None: - raise ImportError(f'Could not import any prefix of {dotted_path!r}') from last_err - - # Walk remaining attributes; the last one is the target. - obj = module - for attr in parts[module_depth:-1]: - obj = getattr(obj, attr) - setattr(obj, parts[-1], impl) - - -def _load_hub_ref(ref: HubRef): - """Lazy-load a Hub kernel layer via the optional ``kernels`` package.""" - try: - from kernels import get_kernel - except ImportError as e: - raise ImportError( - 'Loading a Hub kernel requires the `kernels` package. ' - 'Install it with `pip install kernels`.' - ) from e - - kernel = get_kernel( - ref.repo_id, - revision=ref.revision, - version=ref.version, - backend=ref.backend, - trust_remote_code=ref.trust_remote_code, - ) - layers = getattr(kernel, 'layers', None) - if layers is None: - raise ValueError(f'Hub repo {ref.repo_id!r} does not define any layers.') - impl = getattr(layers, ref.layer_name, None) - if impl is None: - raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') - return impl - - -def kernelize(model: nn.Module, mapping: dict) -> nn.Module: - """Apply ``mapping`` to ``model`` and return it (modified in place). 
- - Keys: - - ``type[nn.Module]``: replace ``m.__class__`` for every module of the - exact type (no subclass walking). - - ``str`` (dotted path ``pkg.mod.attr``): ``setattr`` the impl onto the - identified module attribute. - - Values: - - ``dict[str, V]``: device-conditional dispatch using the current - Twinkle platform device prefix; non-matching devices skip. - - ``HubRef``: lazy-resolved via the optional ``kernels`` package. - - anything else: used directly as the impl. - """ - if not mapping: - return model - - device = Platform.device_prefix() - for key, value in mapping.items(): - impl = _resolve_value(value, device) - if impl is None: - continue - if isinstance(impl, HubRef): - impl = _load_hub_ref(impl) - if isinstance(key, type) and issubclass(key, nn.Module): - _replace_class(model, key, impl) - elif isinstance(key, str): - _replace_attr(key, impl) - else: - raise TypeError(f'Unsupported mapping key: {key!r}') - return model +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Minimal mapping-driven kernel replacement. + +Public API: ``kernelize``, ``hub`` (re-exported from ``twinkle.kernel``). +""" +from __future__ import annotations + +import importlib +from dataclasses import dataclass +from typing import Any + +import torch.nn as nn + +from twinkle.utils.device_mesh import Platform + + +@dataclass(frozen=True) +class HubRef: + """Lightweight reference to a HuggingFace Hub kernel layer. + + Resolved lazily by ``kernelize`` via the optional ``kernels`` package. + """ + repo_id: str + layer_name: str + revision: str | None = None + version: int | None = None + backend: str | None = None + trust_remote_code: bool = False + + +def hub( + ref: str, + *, + revision: str | None = None, + version: int | None = None, + backend: str | None = None, + trust_remote_code: bool = False, +) -> HubRef: + """Build a ``HubRef`` for use as a ``kernelize`` mapping value. + + ``ref`` is ``':'`` (e.g. ``'org/repo:SiluAndMul'``). 
+ Exactly one of ``revision`` or ``version`` must be supplied. + """ + if (revision is None) == (version is None): + raise ValueError('Exactly one of `revision` or `version` must be specified.') + if ':' not in ref: + raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") + repo_id, layer_name = ref.rsplit(':', 1) + return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) + + + +def _resolve_value(value: Any, device: str) -> Any | None: + """Resolve a mapping value against the selected device. + + - ``dict``: device-conditional; recurse into ``value[device]`` or return None. + - anything else (including ``HubRef``): pass through. + """ + if isinstance(value, dict): + if device not in value: + return None + return _resolve_value(value[device], device) + return value + + +def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: + """Rewrite ``__class__`` of every module whose exact type is ``target_cls``. + + Uses ``type(m) is target_cls`` (not ``isinstance``) so user-defined + subclasses of ``target_cls`` are deliberately left alone. + """ + for m in model.modules(): + if type(m) is target_cls: + m.__class__ = impl_cls + + +def _replace_attr(dotted_path: str, impl) -> None: + """``setattr`` ``impl`` onto the attribute identified by the dotted path. + + Supports two forms: + - ``pkg.mod.attr`` (set module attribute) + - ``pkg.mod.ClassName.attr`` (set class attribute / method) + + The split is found by walking the prefix from the longest importable + module backwards until ``importlib.import_module`` succeeds. + """ + parts = dotted_path.split('.') + if len(parts) < 2: + raise ValueError(f"Expected at least 'pkg.attr', got: {dotted_path!r}") + + # Find the longest prefix that imports as a module. 
+ last_err: ImportError | None = None + module = None + module_depth = 0 + for i in range(len(parts) - 1, 0, -1): + candidate = '.'.join(parts[:i]) + try: + module = importlib.import_module(candidate) + module_depth = i + break + except ImportError as e: + last_err = e + continue + if module is None: + raise ImportError(f'Could not import any prefix of {dotted_path!r}') from last_err + + # Walk remaining attributes; the last one is the target. + obj = module + for attr in parts[module_depth:-1]: + obj = getattr(obj, attr) + setattr(obj, parts[-1], impl) + + +def _load_hub_ref(ref: HubRef): + """Lazy-load a Hub kernel layer via the optional ``kernels`` package.""" + try: + from kernels import get_kernel + except ImportError as e: + raise ImportError( + 'Loading a Hub kernel requires the `kernels` package. ' + 'Install it with `pip install kernels`.' + ) from e + + kernel = get_kernel( + ref.repo_id, + revision=ref.revision, + version=ref.version, + backend=ref.backend, + trust_remote_code=ref.trust_remote_code, + ) + layers = getattr(kernel, 'layers', None) + if layers is None: + raise ValueError(f'Hub repo {ref.repo_id!r} does not define any layers.') + impl = getattr(layers, ref.layer_name, None) + if impl is None: + raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') + return impl + + +def kernelize(model: nn.Module, mapping: dict) -> nn.Module: + """Apply ``mapping`` to ``model`` and return it (modified in place). + + Keys: + - ``type[nn.Module]``: replace ``m.__class__`` for every module of the + exact type (no subclass walking). + - ``str`` (dotted path ``pkg.mod.attr``): ``setattr`` the impl onto the + identified module attribute. + + Values: + - ``dict[str, V]``: device-conditional dispatch using the current + Twinkle platform device prefix; non-matching devices skip. + - ``HubRef``: lazy-resolved via the optional ``kernels`` package. + - anything else: used directly as the impl. 
+ """ + if not mapping: + return model + + device = Platform.device_prefix() + for key, value in mapping.items(): + impl = _resolve_value(value, device) + if impl is None: + continue + if isinstance(impl, HubRef): + impl = _load_hub_ref(impl) + if isinstance(key, type) and issubclass(key, nn.Module): + _replace_class(model, key, impl) + elif isinstance(key, str): + _replace_attr(key, impl) + else: + raise TypeError(f'Unsupported mapping key: {key!r}') + return model diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index 31d77581..71606aaa 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -1,32 +1,32 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Per-layer NPU implementations consumed by ``npu_builtin()``. - -Each impl is contracted to be applied via ``m.__class__ = ImplCls`` (class -replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl -here is meant to be instantiated directly. -""" -from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward -from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb -from .swiglu import npu_swiglu_forward -from .attention import npu_sdpa_attention_forward -from .moe import ( - GmmFunction, - npu_grouped_mm, - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, -) -from .fla import apply_qwen3_5_fla - -__all__ = [ - 'NpuRMSNorm', - 'npu_gated_rms_norm_forward', - 'npu_apply_rotary_pos_emb', - 'npu_apply_multimodal_rotary_pos_emb', - 'npu_swiglu_forward', - 'npu_sdpa_attention_forward', - 'GmmFunction', - 'npu_grouped_mm', - 'npu_packed_moe_experts_forward', - 'npu_qwen3_5_moe_sparse_block_forward', - 'apply_qwen3_5_fla', +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Per-layer NPU implementations consumed by ``npu_builtin()``. 
+ +Each impl is contracted to be applied via ``m.__class__ = ImplCls`` (class +replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl +here is meant to be instantiated directly. +""" +from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward +from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb +from .swiglu import npu_swiglu_forward +from .attention import npu_sdpa_attention_forward +from .moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, +) +from .fla import apply_qwen3_5_fla + +__all__ = [ + 'NpuRMSNorm', + 'npu_gated_rms_norm_forward', + 'npu_apply_rotary_pos_emb', + 'npu_apply_multimodal_rotary_pos_emb', + 'npu_swiglu_forward', + 'npu_sdpa_attention_forward', + 'GmmFunction', + 'npu_grouped_mm', + 'npu_packed_moe_experts_forward', + 'npu_qwen3_5_moe_sparse_block_forward', + 'apply_qwen3_5_fla', ] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/attention.py b/src/twinkle/kernel/npu_impls/attention.py index f328b2d5..2bf4255b 100644 --- a/src/twinkle/kernel/npu_impls/attention.py +++ b/src/twinkle/kernel/npu_impls/attention.py @@ -1,54 +1,54 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""SDPA forward with Ascend NPU compatibility fixes.""" -from __future__ import annotations - -import torch - - -def npu_sdpa_attention_forward( - module, - query, - key, - value, - attention_mask, - dropout=0.0, - scaling=None, - is_causal=None, - **kwargs, -): - """Drop-in replacement for ``transformers.integrations.sdpa_attention.sdpa_attention_forward``. - - Fixes: - - Repeats KV heads (NPU SDPA does not auto-broadcast num_kv_groups). - - Truncates causal_mask to key length. - - Forces contiguous tensors (NPU SDPA requirement). - - Inverts boolean masks (NPU treats ``True`` as masked). 
- """ - from transformers.integrations.sdpa_attention import repeat_kv - - if hasattr(module, 'num_key_value_groups'): - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, :key.shape[-2]] - - query, key, value = query.contiguous(), key.contiguous(), value.contiguous() - - if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None - - if causal_mask is not None and causal_mask.dtype != torch.bool: - causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=dropout, - scale=scaling, - is_causal=is_causal, - ) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SDPA forward with Ascend NPU compatibility fixes.""" +from __future__ import annotations + +import torch + + +def npu_sdpa_attention_forward( + module, + query, + key, + value, + attention_mask, + dropout=0.0, + scaling=None, + is_causal=None, + **kwargs, +): + """Drop-in replacement for ``transformers.integrations.sdpa_attention.sdpa_attention_forward``. + + Fixes: + - Repeats KV heads (NPU SDPA does not auto-broadcast num_kv_groups). + - Truncates causal_mask to key length. + - Forces contiguous tensors (NPU SDPA requirement). + - Inverts boolean masks (NPU treats ``True`` as masked). 
+ """ + from transformers.integrations.sdpa_attention import repeat_kv + + if hasattr(module, 'num_key_value_groups'): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, :key.shape[-2]] + + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() + + if is_causal is None: + is_causal = query.shape[2] > 1 and causal_mask is None + + if causal_mask is not None and causal_mask.dtype != torch.bool: + causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) return attn_output.transpose(1, 2).contiguous(), None \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py index d2fc43a9..9a387f77 100644 --- a/src/twinkle/kernel/npu_impls/fla.py +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -1,103 +1,103 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Qwen3.5 Flash Linear Attention enablement for Ascend NPU.""" -from __future__ import annotations - -import importlib -import os - -from twinkle import get_logger - -logger = get_logger() - - -def _is_env_enabled(var: str, default: bool = True) -> bool: - env = os.environ.get(var, '').lower().strip() - if not env: - return default - if env in ('1', 'true', 'on', 'yes'): - return True - if env in ('0', 'false', 'off', 'no'): - return False - return default - - -def _import_optional(name: str): - try: - return importlib.import_module(name) - except ImportError: - return None - - -def apply_qwen3_5_fla(model=None) -> int: - """Enable Flash Linear Attention fast path for Qwen3.5 on NPU. 
- - Returns the count of patched per-layer instances (0 when disabled or when - prerequisites are missing). Safe to call multiple times. - """ - if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): - logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA') - return 0 - - if _import_optional('torch_npu') is None: - logger.info('[NPU] [FLA] Skip: torch_npu unavailable') - return 0 - - # 1. Confirm the MindSpeed Triton kernel is actually importable BEFORE - # flipping any global availability flags. If we flip the flag and then - # fail to install the kernel, HF transformers would route Qwen3.5 onto - # a FLA fast path whose kernel is missing -> runtime failure on NPU. - try: - from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla - from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn - except ImportError as exc: - logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) - return 0 - - # 2. Only now can we safely claim FLA is available: flip the global flags - # and install the kernel path on Qwen3.5 modeling modules. - def _is_fla_available() -> bool: - return True - - for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): - utils_mod = _import_optional(utils_mod_name) - if utils_mod is not None: - setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) - - # 3. Patch Qwen3.5 modeling modules - fla_target_modules = [ - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - ] - for module_name in fla_target_modules: - module = _import_optional(module_name) - if module is None: - continue - setattr(module, 'is_flash_linear_attention_available', _is_fla_available) - setattr(module, 'is_fast_path_available', True) - if hasattr(module, 'FusedRMSNormGated'): - setattr(module, 'FusedRMSNormGated', None) - setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) - - # 4. 
Traverse model and patch per-layer attributes - if model is None: - return 0 - - root = getattr(model, 'model', getattr(model, 'module', model)) - if not hasattr(root, 'named_modules'): - return 0 - - patched_instances = 0 - for _name, _module in root.named_modules(): - if hasattr(_module, 'chunk_gated_delta_rule') and callable( - getattr(_module, 'chunk_gated_delta_rule')): - if _module.chunk_gated_delta_rule is not mindspeed_fla: - _module.chunk_gated_delta_rule = mindspeed_fla - _module._twinkle_npu_patched = True - patched_instances += 1 - if hasattr(_module, 'causal_conv1d_fn'): - if getattr(_module, 'causal_conv1d_fn') is not npu_causal_conv1d_fn: - _module.causal_conv1d_fn = npu_causal_conv1d_fn - - if patched_instances: - logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Qwen3.5 Flash Linear Attention enablement for Ascend NPU.""" +from __future__ import annotations + +import importlib +import os + +from twinkle import get_logger + +logger = get_logger() + + +def _is_env_enabled(var: str, default: bool = True) -> bool: + env = os.environ.get(var, '').lower().strip() + if not env: + return default + if env in ('1', 'true', 'on', 'yes'): + return True + if env in ('0', 'false', 'off', 'no'): + return False + return default + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def apply_qwen3_5_fla(model=None) -> int: + """Enable Flash Linear Attention fast path for Qwen3.5 on NPU. + + Returns the count of patched per-layer instances (0 when disabled or when + prerequisites are missing). Safe to call multiple times. + """ + if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): + logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA') + return 0 + + if _import_optional('torch_npu') is None: + logger.info('[NPU] [FLA] Skip: torch_npu unavailable') + return 0 + + # 1. 
Confirm the MindSpeed Triton kernel is actually importable BEFORE + # flipping any global availability flags. If we flip the flag and then + # fail to install the kernel, HF transformers would route Qwen3.5 onto + # a FLA fast path whose kernel is missing -> runtime failure on NPU. + try: + from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla + from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn + except ImportError as exc: + logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) + return 0 + + # 2. Only now can we safely claim FLA is available: flip the global flags + # and install the kernel path on Qwen3.5 modeling modules. + def _is_fla_available() -> bool: + return True + + for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): + utils_mod = _import_optional(utils_mod_name) + if utils_mod is not None: + setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) + + # 3. Patch Qwen3.5 modeling modules + fla_target_modules = [ + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + ] + for module_name in fla_target_modules: + module = _import_optional(module_name) + if module is None: + continue + setattr(module, 'is_flash_linear_attention_available', _is_fla_available) + setattr(module, 'is_fast_path_available', True) + if hasattr(module, 'FusedRMSNormGated'): + setattr(module, 'FusedRMSNormGated', None) + setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) + + # 4. 
Traverse model and patch per-layer attributes + if model is None: + return 0 + + root = getattr(model, 'model', getattr(model, 'module', model)) + if not hasattr(root, 'named_modules'): + return 0 + + patched_instances = 0 + for _name, _module in root.named_modules(): + if hasattr(_module, 'chunk_gated_delta_rule') and callable( + getattr(_module, 'chunk_gated_delta_rule')): + if _module.chunk_gated_delta_rule is not mindspeed_fla: + _module.chunk_gated_delta_rule = mindspeed_fla + _module._twinkle_npu_patched = True + patched_instances += 1 + if hasattr(_module, 'causal_conv1d_fn'): + if getattr(_module, 'causal_conv1d_fn') is not npu_causal_conv1d_fn: + _module.causal_conv1d_fn = npu_causal_conv1d_fn + + if patched_instances: + logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) return patched_instances \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/moe.py b/src/twinkle/kernel/npu_impls/moe.py index efa7f71a..c576cd04 100644 --- a/src/twinkle/kernel/npu_impls/moe.py +++ b/src/twinkle/kernel/npu_impls/moe.py @@ -1,151 +1,151 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-"""MoE GMM + packed-experts + sparse-block impls for Ascend NPU.""" -from __future__ import annotations - -import torch -import torch.nn.functional as F - - -class GmmFunction(torch.autograd.Function): - """Custom autograd function for NPU grouped matrix multiplication.""" - - @staticmethod - def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): - import torch_npu - group_list = group_list.to(torch.int64) - ctx.save_for_backward(x, group_list, weight_ekn) - outputs = torch_npu.npu_grouped_matmul( - [x], [weight_ekn], group_list=group_list, - group_type=0, split_item=2, group_list_type=1, - ) - return outputs[0] - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - import torch_npu - x, group_list, weight_ekn = ctx.saved_tensors - grad_input = torch_npu.npu_grouped_matmul( - [grad_output], [weight_ekn.transpose(-2, -1).contiguous()], - bias=None, group_list=group_list, - group_type=0, split_item=2, group_list_type=1, - )[0] - grad_weight = torch_npu.npu_grouped_matmul( - [x.transpose(0, 1)], [grad_output], - bias=None, group_list=group_list, - group_type=2, split_item=3, group_list_type=1, - )[0] - return grad_input, None, grad_weight.contiguous() - - -def npu_grouped_mm(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: - """Drop-in replacement for ``transformers.integrations.moe._grouped_mm``.""" - counts = torch.empty_like(offs) - counts[0] = offs[0] - if offs.numel() > 1: - counts[1:] = offs[1:] - offs[:-1] - counts = counts.to(torch.int64) - return GmmFunction.apply(input, counts, weight_ekn) - - -def _normalize_packed_expert_weights(module, input_dtype, hidden_dim): - gate_up_proj = module.gate_up_proj.to(input_dtype) - down_proj = module.down_proj.to(input_dtype) - if gate_up_proj.shape[1] == hidden_dim: - gate_up_weight = gate_up_proj - elif gate_up_proj.shape[2] == hidden_dim: - gate_up_weight = gate_up_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported 
gate_up_proj shape: {tuple(gate_up_proj.shape)}.') - if down_proj.shape[2] == hidden_dim: - down_weight = down_proj - elif down_proj.shape[1] == hidden_dim: - down_weight = down_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported down_proj shape: {tuple(down_proj.shape)}.') - return gate_up_weight, down_weight - - -def _get_cached_expert_weights(self, target_dtype, hidden_dim): - requires_grad = ( - getattr(self.gate_up_proj, 'requires_grad', False) - or getattr(self.down_proj, 'requires_grad', False) - ) - cache_attr = '_npu_expert_cache' - if not requires_grad and hasattr(self, cache_attr): - cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) - if (cached_dtype == target_dtype - and cached_gv == self.gate_up_proj._version - and cached_dv == self.down_proj._version): - return cached - weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) - if not requires_grad: - setattr(self, cache_attr, - (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) - return weights - - -def npu_packed_moe_experts_forward(self, hidden_states, a, b): - """Packed MoE Experts.forward using NPU grouped matmul. - - Accepts both call orderings: ``(hidden_states, routing_weights, router_indices)`` - and ``(hidden_states, router_indices, routing_weights)`` — distinguishes by dtype. 
- """ - import torch_npu - if a.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: - router_indices, routing_weights = a, b - else: - routing_weights, router_indices = a, b - - output_shape = hidden_states.shape - hidden_dim = output_shape[-1] - hidden_states = hidden_states.reshape(-1, hidden_dim) - - if routing_weights.shape != router_indices.shape: - routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) - routing_weights = routing_weights.to(hidden_states.dtype) - router_indices = router_indices.to(torch.int32) - - permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) - tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) - gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) - - intermediate = GmmFunction.apply(permuted, tokens_per_expert, gate_up_weight) - activated = torch_npu.npu_swiglu(intermediate, dim=-1) - output = GmmFunction.apply(activated, tokens_per_expert, down_weight) - next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) - return next_states.view(*output_shape) - - -def _topk_from_router_logits(module, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) - if getattr(module, 'norm_topk_prob', True): - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - return routing_weights, router_indices - - -def _add_shared_expert(self, hidden_states, expert_output): - if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): - return expert_output - shared = self.shared_expert(hidden_states) - shared = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared - return expert_output + shared 
- - -def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): - """SparseMoeBlock.forward replacement (Transformers 4.x and 5.x compatible).""" - batch_size, sequence_length, hidden_dim = hidden_states.shape - gate_output = self.gate(hidden_states.view(-1, hidden_dim)) - - if isinstance(gate_output, tuple): - _, routing_weights, selected_experts = gate_output - flat = hidden_states.view(-1, hidden_dim) - expert_output = self.experts(flat, selected_experts, routing_weights) - else: - flat = hidden_states.view(-1, hidden_dim) - routing_weights, selected_experts = _topk_from_router_logits(self, flat, gate_output) - expert_output = self.experts(flat, selected_experts, routing_weights) - - expert_output = _add_shared_expert(self, flat, expert_output) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""MoE GMM + packed-experts + sparse-block impls for Ascend NPU.""" +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +class GmmFunction(torch.autograd.Function): + """Custom autograd function for NPU grouped matrix multiplication.""" + + @staticmethod + def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): + import torch_npu + group_list = group_list.to(torch.int64) + ctx.save_for_backward(x, group_list, weight_ekn) + outputs = torch_npu.npu_grouped_matmul( + [x], [weight_ekn], group_list=group_list, + group_type=0, split_item=2, group_list_type=1, + ) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + import torch_npu + x, group_list, weight_ekn = ctx.saved_tensors + grad_input = torch_npu.npu_grouped_matmul( + [grad_output], [weight_ekn.transpose(-2, -1).contiguous()], + bias=None, group_list=group_list, + group_type=0, split_item=2, group_list_type=1, + )[0] + grad_weight = torch_npu.npu_grouped_matmul( + [x.transpose(0, 1)], [grad_output], + bias=None, group_list=group_list, + group_type=2, split_item=3, group_list_type=1, + )[0] + return 
grad_input, None, grad_weight.contiguous() + + +def npu_grouped_mm(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: + """Drop-in replacement for ``transformers.integrations.moe._grouped_mm``.""" + counts = torch.empty_like(offs) + counts[0] = offs[0] + if offs.numel() > 1: + counts[1:] = offs[1:] - offs[:-1] + counts = counts.to(torch.int64) + return GmmFunction.apply(input, counts, weight_ekn) + + +def _normalize_packed_expert_weights(module, input_dtype, hidden_dim): + gate_up_proj = module.gate_up_proj.to(input_dtype) + down_proj = module.down_proj.to(input_dtype) + if gate_up_proj.shape[1] == hidden_dim: + gate_up_weight = gate_up_proj + elif gate_up_proj.shape[2] == hidden_dim: + gate_up_weight = gate_up_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported gate_up_proj shape: {tuple(gate_up_proj.shape)}.') + if down_proj.shape[2] == hidden_dim: + down_weight = down_proj + elif down_proj.shape[1] == hidden_dim: + down_weight = down_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported down_proj shape: {tuple(down_proj.shape)}.') + return gate_up_weight, down_weight + + +def _get_cached_expert_weights(self, target_dtype, hidden_dim): + requires_grad = ( + getattr(self.gate_up_proj, 'requires_grad', False) + or getattr(self.down_proj, 'requires_grad', False) + ) + cache_attr = '_npu_expert_cache' + if not requires_grad and hasattr(self, cache_attr): + cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) + if (cached_dtype == target_dtype + and cached_gv == self.gate_up_proj._version + and cached_dv == self.down_proj._version): + return cached + weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) + if not requires_grad: + setattr(self, cache_attr, + (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) + return weights + + +def npu_packed_moe_experts_forward(self, hidden_states, a, b): + """Packed MoE Experts.forward using NPU grouped matmul. 
+ + Accepts both call orderings: ``(hidden_states, routing_weights, router_indices)`` + and ``(hidden_states, router_indices, routing_weights)`` — distinguishes by dtype. + """ + import torch_npu + if a.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: + router_indices, routing_weights = a, b + else: + routing_weights, router_indices = a, b + + output_shape = hidden_states.shape + hidden_dim = output_shape[-1] + hidden_states = hidden_states.reshape(-1, hidden_dim) + + if routing_weights.shape != router_indices.shape: + routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) + routing_weights = routing_weights.to(hidden_states.dtype) + router_indices = router_indices.to(torch.int32) + + permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) + tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) + gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) + + intermediate = GmmFunction.apply(permuted, tokens_per_expert, gate_up_weight) + activated = torch_npu.npu_swiglu(intermediate, dim=-1) + output = GmmFunction.apply(activated, tokens_per_expert, down_weight) + next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) + return next_states.view(*output_shape) + + +def _topk_from_router_logits(module, hidden_states, router_logits): + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) + if getattr(module, 'norm_topk_prob', True): + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + return routing_weights, router_indices + + +def _add_shared_expert(self, hidden_states, expert_output): + if not (hasattr(self, 'shared_expert') and hasattr(self, 
'shared_expert_gate')): + return expert_output + shared = self.shared_expert(hidden_states) + shared = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared + return expert_output + shared + + +def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): + """SparseMoeBlock.forward replacement (Transformers 4.x and 5.x compatible).""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + gate_output = self.gate(hidden_states.view(-1, hidden_dim)) + + if isinstance(gate_output, tuple): + _, routing_weights, selected_experts = gate_output + flat = hidden_states.view(-1, hidden_dim) + expert_output = self.experts(flat, selected_experts, routing_weights) + else: + flat = hidden_states.view(-1, hidden_dim) + routing_weights, selected_experts = _topk_from_router_logits(self, flat, gate_output) + expert_output = self.experts(flat, selected_experts, routing_weights) + + expert_output = _add_shared_expert(self, flat, expert_output) return expert_output.reshape(batch_size, sequence_length, hidden_dim) \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/rms_norm.py b/src/twinkle/kernel/npu_impls/rms_norm.py index ecebdc23..98e95699 100644 --- a/src/twinkle/kernel/npu_impls/rms_norm.py +++ b/src/twinkle/kernel/npu_impls/rms_norm.py @@ -1,75 +1,75 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Fused RMSNorm impls for Ascend NPU. - -Designed for class-replacement: do not define ``__init__``; rely on the -attributes already present on the original instance. -""" -from __future__ import annotations - -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from twinkle import get_logger - -logger = get_logger() - - -class NpuRMSNorm(nn.Module): - """Class-replacement impl for HF RMSNorm variants. 
- - Required instance attributes (provided by the original class): - - ``weight``: ``nn.Parameter`` - - ``variance_epsilon`` *or* ``eps``: float - """ - - def _twinkle_residual_param(self) -> bool: - """Lazily detect residual parameterization (e.g. Qwen3.5: scale = 1 + weight).""" - cached = getattr(self, '_twinkle_residual_cached', None) - if cached is None: - cached = abs(self.weight.data.mean().item()) < 0.3 - self._twinkle_residual_cached = cached - if cached: - logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') - return cached - - def _twinkle_eps(self) -> float: - return getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - import torch_npu - target_dtype = hidden_states.dtype - if self._twinkle_residual_param(): - scale = (1.0 + self.weight).to(target_dtype) - else: - scale = self.weight.to(target_dtype) - return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] - - -# Resolved once at import: matches the legacy "patch-time, process-wide" invariant. -# Mid-process env mutation will not retroactively change behavior. -_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( - '1', 'true', 'on', 'yes' -) - - -def npu_gated_rms_norm_forward(self, hidden_states, gate=None): - """Forward replacement for Gated RMSNorm variants (e.g. Qwen3.5-MoE).""" - import torch_npu - - input_dtype = hidden_states.dtype - _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - - if _FORCE_FP32: - hidden_states = hidden_states.to(torch.float32) - weight = self.weight.float() - gate = gate.to(torch.float32) if gate is not None else None - else: - weight = self.weight - - hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] - if gate is not None: - hidden_states = hidden_states * F.silu(gate) +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Fused RMSNorm impls for Ascend NPU. + +Designed for class-replacement: do not define ``__init__``; rely on the +attributes already present on the original instance. +""" +from __future__ import annotations + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from twinkle import get_logger + +logger = get_logger() + + +class NpuRMSNorm(nn.Module): + """Class-replacement impl for HF RMSNorm variants. + + Required instance attributes (provided by the original class): + - ``weight``: ``nn.Parameter`` + - ``variance_epsilon`` *or* ``eps``: float + """ + + def _twinkle_residual_param(self) -> bool: + """Lazily detect residual parameterization (e.g. Qwen3.5: scale = 1 + weight).""" + cached = getattr(self, '_twinkle_residual_cached', None) + if cached is None: + cached = abs(self.weight.data.mean().item()) < 0.3 + self._twinkle_residual_cached = cached + if cached: + logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') + return cached + + def _twinkle_eps(self) -> float: + return getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + import torch_npu + target_dtype = hidden_states.dtype + if self._twinkle_residual_param(): + scale = (1.0 + self.weight).to(target_dtype) + else: + scale = self.weight.to(target_dtype) + return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] + + +# Resolved once at import: matches the legacy "patch-time, process-wide" invariant. +# Mid-process env mutation will not retroactively change behavior. +_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( + '1', 'true', 'on', 'yes' +) + + +def npu_gated_rms_norm_forward(self, hidden_states, gate=None): + """Forward replacement for Gated RMSNorm variants (e.g. 
Qwen3.5-MoE).""" + import torch_npu + + input_dtype = hidden_states.dtype + _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + if _FORCE_FP32: + hidden_states = hidden_states.to(torch.float32) + weight = self.weight.float() + gate = gate.to(torch.float32) if gate is not None else None + else: + weight = self.weight + + hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] + if gate is not None: + hidden_states = hidden_states * F.silu(gate) return hidden_states.to(input_dtype) \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/rotary.py b/src/twinkle/kernel/npu_impls/rotary.py index 1ed437a3..6493dc8b 100644 --- a/src/twinkle/kernel/npu_impls/rotary.py +++ b/src/twinkle/kernel/npu_impls/rotary.py @@ -1,66 +1,66 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Fused RoPE impls for Ascend NPU (lazy ``torch_npu`` import).""" -from __future__ import annotations - -import torch - - -def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): - if isinstance(position_ids, int) and unsqueeze_dim == 1: - return position_ids - return unsqueeze_dim - - -def _make_apply_npu_rotary_emb(): - """Closure with per-shape Partial-RoPE detection cache.""" - _cached_partial: dict[tuple[int, int], bool] = {} - - def _apply(q, k, cos, sin): - import torch_npu - rotary_dim = cos.shape[-1] - query_dim = q.shape[-1] - shape_key = (rotary_dim, query_dim) - - use_partial = _cached_partial.get(shape_key) - if use_partial is None: - use_partial = rotary_dim < query_dim - _cached_partial[shape_key] = use_partial - - if use_partial: - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - else: - q_embed = 
torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) - return q_embed, k_embed - - return _apply - - -_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() - - -def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Fused RoPE via ``torch_npu.npu_rotary_mul`` with Partial-RoPE support.""" - unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) - - -def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Multimodal RoPE for Qwen2.5-VL with Partial-RoPE support.""" - mrope_section = mrope_section * 2 - cos = torch.cat( - [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], - dim=-1, - ).unsqueeze(unsqueeze_dim) - sin = torch.cat( - [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], - dim=-1, - ).unsqueeze(unsqueeze_dim) +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Fused RoPE impls for Ascend NPU (lazy ``torch_npu`` import).""" +from __future__ import annotations + +import torch + + +def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): + if isinstance(position_ids, int) and unsqueeze_dim == 1: + return position_ids + return unsqueeze_dim + + +def _make_apply_npu_rotary_emb(): + """Closure with per-shape Partial-RoPE detection cache.""" + _cached_partial: dict[tuple[int, int], bool] = {} + + def _apply(q, k, cos, sin): + import torch_npu + rotary_dim = cos.shape[-1] + query_dim = q.shape[-1] + shape_key = (rotary_dim, query_dim) + + use_partial = _cached_partial.get(shape_key) + if use_partial is None: + use_partial = rotary_dim < query_dim + _cached_partial[shape_key] = use_partial + + if use_partial: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + else: + q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) + return q_embed, k_embed + + return _apply + + +_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() + + +def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Fused RoPE via ``torch_npu.npu_rotary_mul`` with Partial-RoPE support.""" + unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return _apply_npu_rotary_emb(q, k, cos, sin) + + +def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Multimodal RoPE for Qwen2.5-VL with Partial-RoPE support.""" + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, 
+ ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) return _apply_npu_rotary_emb(q, k, cos, sin) \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/swiglu.py b/src/twinkle/kernel/npu_impls/swiglu.py index c34a7bea..4be68184 100644 --- a/src/twinkle/kernel/npu_impls/swiglu.py +++ b/src/twinkle/kernel/npu_impls/swiglu.py @@ -1,20 +1,20 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Fused SwiGLU forward for Ascend NPU.""" -from __future__ import annotations - -import torch - - -def npu_swiglu_forward(self, hidden_state): - """Fused Qwen-style SwiGLU. - - Used as a class-attribute replacement on HF MLP classes. - Required instance attributes: ``gate_proj``, ``up_proj``, ``down_proj``. - """ - import torch_npu - return self.down_proj( - torch_npu.npu_swiglu( - torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), - dim=-1, - ) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused SwiGLU forward for Ascend NPU.""" +from __future__ import annotations + +import torch + + +def npu_swiglu_forward(self, hidden_state): + """Fused Qwen-style SwiGLU. + + Used as a class-attribute replacement on HF MLP classes. + Required instance attributes: ``gate_proj``, ``up_proj``, ``down_proj``. + """ + import torch_npu + return self.down_proj( + torch_npu.npu_swiglu( + torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), + dim=-1, + ) ) \ No newline at end of file diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 2b6e45ea..3f3e5c6a 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -1,533 +1,533 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-from __future__ import annotations - -import inspect -import torch -import torch.distributed as dist -import torch.nn.functional as F -from dataclasses import dataclass -from torch import nn -from typing import Any, Dict, Iterable, List, Optional, Tuple - -from twinkle.model.transformers.moe.ep_utils import preprocess, token_pre_all2all, tokens_post_all2all -from twinkle.utils import DeviceMesh - - -@dataclass -class ExpertParallelConfig: - enabled: bool = True - router_dtype: str = 'fp32' - keep_router_logits: bool = True - ignore_shared_experts: bool = False - ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic - - -@dataclass -class ExpertShardingSpec: - """Describes expert sharding info for a single MoE block. Extensible for other models.""" - block: nn.Module - experts_module: nn.Module - num_experts: int - experts_per_rank: int - local_start: int - local_end: int - ep_rank: int - ep_world_size: int - is_tensor_experts: bool - - -def apply_expert_parallel( - model: nn.Module, - device_mesh: DeviceMesh, - config: dict[str, Any] | None = None, - ep_fsdp_device_mesh: torch.distributed.DeviceMesh | None = None, -) -> list[ExpertShardingSpec]: - """Apply expert parallelism to all MoE blocks in the model.""" - cfg = _merge_config(config) - - # EP info comes from the separate ep_fsdp_device_mesh, not from main mesh - if not cfg.enabled or ep_fsdp_device_mesh is None: - return [] - - # Always query EP via the 1D submesh to avoid relying on Tensor named dims. - ep_mesh = ep_fsdp_device_mesh['ep'] - ep_world_size = ep_mesh.size() - if ep_world_size <= 1: - return [] - - if not dist.is_initialized(): - raise RuntimeError('torch.distributed is not initialized, cannot enable expert parallel.') - - # Get process group and local rank from EP submesh. 
- ep_group = ep_mesh.get_group() - ep_rank = ep_mesh.get_local_rank() - - specs = [] - for _, block in find_moe_blocks_with_names(model): - spec = shard_experts(block, ep_world_size, ep_rank, cfg) - patch_forward(block, ep_group, ep_world_size, cfg) - specs.append(spec) - - return specs - - -def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig: - cfg = ExpertParallelConfig() - if not config: - return cfg - for key, value in config.items(): - if not hasattr(cfg, key): - raise ValueError(f'Unknown expert parallel config: {key}') - setattr(cfg, key, value) - return cfg - - -def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: - return [block for _, block in find_moe_blocks_with_names(model)] - - -def find_moe_blocks_with_names(model: nn.Module) -> Iterable[tuple[str, nn.Module]]: - blocks = [] - for name, module in model.named_modules(): - experts = getattr(module, 'experts', None) - if experts is None: - continue - if not _is_moe_experts(experts): - continue - if not _get_gate(module): - continue - blocks.append((name, module)) - return blocks - - -def shard_experts( - block: nn.Module, - ep_world_size: int, - ep_rank: int, - cfg: ExpertParallelConfig, -) -> ExpertShardingSpec: - """Shard experts in a MoE block across EP ranks. - - Args: - block: The MoE block containing experts. - ep_world_size: The world size for expert parallelism. - ep_rank: The current rank in the EP group. - cfg: Expert parallel configuration. - - Returns an ExpertShardingSpec describing the sharding. 
- """ - num_experts = _get_num_experts(block) - - if num_experts % ep_world_size != 0: - raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).') - - experts_per_rank = num_experts // ep_world_size - local_start = ep_rank * experts_per_rank - local_end = local_start + experts_per_rank - - if isinstance(block.experts, nn.ModuleList): - local_experts = nn.ModuleList(block.experts[local_start:local_end]) - block.experts = local_experts - is_tensor_experts = False - else: - _shard_tensor_experts(block.experts, local_start, local_end) - is_tensor_experts = True - - block._ep_num_experts = num_experts - block._ep_experts_per_rank = experts_per_rank - block._ep_local_start = local_start - block._ep_local_end = local_end - block._ep_rank = ep_rank - block._ep_world_size = ep_world_size - block._ep_tensor_experts = is_tensor_experts - block._ep_ignore_shared_experts = cfg.ignore_shared_experts - - return ExpertShardingSpec( - block=block, - experts_module=block.experts, - num_experts=num_experts, - experts_per_rank=experts_per_rank, - local_start=local_start, - local_end=local_end, - ep_rank=ep_rank, - ep_world_size=ep_world_size, - is_tensor_experts=is_tensor_experts, - ) - - -def patch_forward( - block: nn.Module, - ep_group: dist.ProcessGroup, - ep_world_size: int, - cfg: ExpertParallelConfig, -) -> None: - """Replace the MoE block forward with EP-aware communication flow. - - Communication pattern: - preprocess → token_pre_all2all → expert_compute → tokens_post_all2all - - For tensor experts (gate_up_proj/down_proj), the expert compute is delegated - to block.experts(...) via nn.Module.__call__ so that FSDP2 pre/post-forward - hooks fire correctly (automatic unshard before forward, backward hook - registration, and reshard after forward). No manual unshard/reshard is needed. - - For ModuleList experts, each sub-expert is already called via __call__ inside - _run_local_experts, so the same principle applies. 
- - Args: - block: The MoE block to patch. - ep_group: The process group for EP communication (from ep_fsdp_device_mesh["ep"]). - ep_world_size: The world size for expert parallelism. - cfg: Expert parallel configuration. - """ - if getattr(block, '_ep_patched', False): - return - - gate = _get_gate(block) - if gate is None: - raise ValueError('MoE block must define gate/router module.') - - top_k = _get_top_k(block) - if top_k is None: - raise ValueError('MoE block must define top_k/num_experts_per_tok.') - - orig_forward = block.forward - return_annotation = inspect.signature(orig_forward).return_annotation - returns_router_logits = return_annotation in ( - tuple, - Tuple[torch.Tensor, torch.Tensor | None], - ) - num_experts = block._ep_num_experts - experts_per_rank = block._ep_experts_per_rank - is_tensor_experts = block._ep_tensor_experts - - # For tensor experts, install an ep_forward on the experts module so we can - # call block.experts(permuted_tokens, counts, experts_per_rank) via __call__, - # letting FSDP2 manage unshard/reshard automatically. 
- if is_tensor_experts: - _install_ep_forward(block.experts, experts_per_rank) - - def forward(hidden_states: torch.Tensor, *args, **kwargs): - if args: - raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.') - - orig_shape = hidden_states.shape - if hidden_states.ndim == 3: - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states_2d = hidden_states.view(-1, hidden_dim) - elif hidden_states.ndim == 2: - batch_size, seq_len = 1, hidden_states.shape[0] - hidden_dim = hidden_states.shape[1] - hidden_states_2d = hidden_states - else: - raise ValueError(f'Unsupported hidden_states ndim: {hidden_states.ndim}') - - router_logits, routing_weights, selected_experts = _run_router( - gate=gate, - hidden_states=hidden_states_2d, - top_k=top_k, - router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype), - norm_topk_prob=getattr(block, 'norm_topk_prob', False), - **kwargs, - ) - # Keep routing weights in activation dtype before unpermute weighting. - if routing_weights.dtype != hidden_states_2d.dtype: - routing_weights = routing_weights.to(hidden_states_2d.dtype) - # Build expert_mask: [num_experts, top_k, num_tokens] - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens] - - # 1. preprocess: compute splits and token counts - ( - input_splits, - output_splits, - num_global_tokens_per_local_expert, - num_global_sum_tokens_per_local_expert, - ) = preprocess(expert_mask, num_experts, ep_group) - - # 2. token_pre_all2all: permute → all_to_all → sort_chunks - ( - global_permuted_hidden_states, - local_input_permutation_mapping, - local_assignment_weights, - org_hidden_states_shape, - ) = token_pre_all2all( - hidden_states_2d, - expert_mask, - routing_weights, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - ep_group, - ) - - # 3. 
expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire. - # For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank) - # → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard - # For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__. - if is_tensor_experts: - expert_outputs = block.experts( - global_permuted_hidden_states, - num_global_sum_tokens_per_local_expert, - experts_per_rank, - ) - else: - expert_outputs = _run_local_experts( - block, - global_permuted_hidden_states, - num_global_sum_tokens_per_local_expert, - experts_per_rank, - ) - - # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) - final_hidden = tokens_post_all2all( - expert_outputs, - local_assignment_weights, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - local_input_permutation_mapping, - org_hidden_states_shape, - ep_group, - ) - - shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg) - if shared_out is not None: - final_hidden = final_hidden + shared_out - - if len(orig_shape) == 3: - final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) - - if cfg.keep_router_logits and returns_router_logits: - return final_hidden, router_logits - return final_hidden - - block._ep_original_forward = orig_forward - block.forward = forward - block._ep_patched = True - - -def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None: - if getattr(experts_mod, '_ep_forward_installed', False): - return - - def ep_forward( - self, - permuted_tokens: torch.Tensor, - num_global_sum_tokens_per_local_expert: torch.Tensor, - experts_per_rank: int, - ) -> torch.Tensor: - if permuted_tokens.numel() == 0: - # Preserve the autograd edge to token_pre_all2all. Returning a new - # empty tensor can make this rank skip the matching backward - # all-to-all, causing EP collective order divergence. 
- return permuted_tokens - - input_dtype = permuted_tokens.dtype - - cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) - for i in range(experts_per_rank): - cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) - - output_chunks = [] - for i in range(experts_per_rank): - start = int(cumsum[i].item()) - end = int(cumsum[i + 1].item()) - expert_in = permuted_tokens[start:end] - if expert_in.numel() == 0: - output_chunks.append(expert_in) - continue - - gate_up = self.gate_up_proj[i] - down = self.down_proj[i] - compute_dtype = gate_up.dtype - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - gate_up_out = F.linear(expert_in, gate_up) - if hasattr(self, '_apply_gate'): - out = self._apply_gate(gate_up_out) - else: - gate, up = gate_up_out.chunk(2, dim=-1) - out = self.act_fn(gate) * up - out = F.linear(out, down) - - if out.dtype != input_dtype: - out = out.to(input_dtype) - output_chunks.append(out) - - return torch.cat( - output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) - - import types - experts_mod.forward = types.MethodType(ep_forward, experts_mod) - experts_mod._ep_forward_installed = True - - -def _get_gate(block: nn.Module): - gate = getattr(block, 'gate', None) - if gate is None: - gate = getattr(block, 'router', None) - return gate - - -def _get_num_experts(block: nn.Module) -> int: - if hasattr(block, 'num_experts'): - return int(block.num_experts) - experts = getattr(block, 'experts', None) - if experts is None: - raise ValueError('MoE block has no experts.') - if isinstance(experts, nn.ModuleList): - return len(experts) - if hasattr(experts, 'num_experts'): - return int(experts.num_experts) - if hasattr(experts, 'gate_up_proj'): - return int(experts.gate_up_proj.shape[0]) - raise ValueError('Unable to infer num_experts for MoE block.') - - -def _get_top_k(block: nn.Module) -> int | None: - gate = _get_gate(block) - if gate is not None 
and hasattr(gate, 'top_k'): - value = getattr(gate, 'top_k') - if value is not None: - return int(value) - for name in ('num_experts_per_tok', 'top_k'): - if hasattr(block, name): - value = getattr(block, name) - if value is not None: - return int(value) - return None - - -def _get_router_dtype(router_dtype: str, default_dtype: torch.dtype) -> torch.dtype: - if router_dtype == 'fp32': - return torch.float32 - if router_dtype == 'bf16': - return torch.bfloat16 - if router_dtype == 'fp16': - return torch.float16 - return default_dtype - - -def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, cfg: ExpertParallelConfig): - if cfg.ignore_shared_experts: - return None - shared = getattr(block, 'shared_expert', None) - if shared is None: - shared = getattr(block, 'shared_experts', None) - if shared is None: - return None - return _run_module_with_casting(shared, hidden_states_2d) - - -def _is_moe_experts(experts: Any) -> bool: - if isinstance(experts, nn.ModuleList): - return True - if hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj'): - return True - return False - - -def _shard_tensor_experts(experts: nn.Module, start: int, end: int) -> None: - experts.gate_up_proj = nn.Parameter(experts.gate_up_proj.data[start:end].clone()) - experts.down_proj = nn.Parameter(experts.down_proj.data[start:end].clone()) - if hasattr(experts, 'num_experts'): - experts.num_experts = end - start - - -def _run_local_experts( - block: nn.Module, - permuted_tokens: torch.Tensor, - num_global_sum_tokens_per_local_expert: torch.Tensor, - experts_per_rank: int, -) -> torch.Tensor: - """Run ModuleList experts on permuted tokens via nn.Module.__call__. - Tokens are already grouped by expert (contiguous chunks), sizes given by - num_global_sum_tokens_per_local_expert. No routing weight is applied here; - that happens in unpermute. 
- """ - if permuted_tokens.numel() == 0: - # Keep the backward path through token_pre_all2all even when this EP - # rank owns no routed tokens for the current block. - return permuted_tokens - - input_dtype = permuted_tokens.dtype - experts = block.experts - - cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) - for i in range(experts_per_rank): - cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) - - output_chunks = [] - for i in range(experts_per_rank): - start = int(cumsum[i].item()) - end = int(cumsum[i + 1].item()) - expert_in = permuted_tokens[start:end] - if expert_in.numel() == 0: - output_chunks.append(expert_in) - continue - - expert = experts[i] - compute_dtype = _module_compute_dtype(expert, input_dtype) - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - out = expert(expert_in) - - if out.dtype != input_dtype: - out = out.to(input_dtype) - output_chunks.append(out) - - return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) - - -def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype: - for param in module.parameters(): - if param.dtype.is_floating_point: - return param.dtype - return default - - -def _run_module_with_casting(module: nn.Module, module_in: torch.Tensor) -> torch.Tensor: - input_dtype = module_in.dtype - compute_dtype = _module_compute_dtype(module, input_dtype) - if compute_dtype != input_dtype: - module_in = module_in.to(compute_dtype) - out = module(module_in) - if out.dtype != input_dtype: - out = out.to(input_dtype) - return out - - -def _run_router( - *, - gate: nn.Module, - hidden_states: torch.Tensor, - top_k: int, - router_dtype: torch.dtype, - norm_topk_prob: bool, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - gate_kwargs = {} - if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'): - gate_kwargs['input_ids'] = 
kwargs['input_ids'] - gate_out = gate(hidden_states, **gate_kwargs) - if isinstance(gate_out, tuple) and len(gate_out) >= 3: - router_logits, routing_weights, selected_experts = gate_out[:3] - return router_logits, routing_weights, selected_experts - - router_logits = gate_out - routing_weights = torch.softmax(router_logits, dim=-1, dtype=router_dtype) - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - if norm_topk_prob: - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - return router_logits, routing_weights, selected_experts - - -def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool: - signature = inspect.signature(module.forward) - for param in signature.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - return True - return kwarg in signature.parameters +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import inspect +import torch +import torch.distributed as dist +import torch.nn.functional as F +from dataclasses import dataclass +from torch import nn +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from twinkle.model.transformers.moe.ep_utils import preprocess, token_pre_all2all, tokens_post_all2all +from twinkle.utils import DeviceMesh + + +@dataclass +class ExpertParallelConfig: + enabled: bool = True + router_dtype: str = 'fp32' + keep_router_logits: bool = True + ignore_shared_experts: bool = False + ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic + + +@dataclass +class ExpertShardingSpec: + """Describes expert sharding info for a single MoE block. 
Extensible for other models.""" + block: nn.Module + experts_module: nn.Module + num_experts: int + experts_per_rank: int + local_start: int + local_end: int + ep_rank: int + ep_world_size: int + is_tensor_experts: bool + + +def apply_expert_parallel( + model: nn.Module, + device_mesh: DeviceMesh, + config: dict[str, Any] | None = None, + ep_fsdp_device_mesh: torch.distributed.DeviceMesh | None = None, +) -> list[ExpertShardingSpec]: + """Apply expert parallelism to all MoE blocks in the model.""" + cfg = _merge_config(config) + + # EP info comes from the separate ep_fsdp_device_mesh, not from main mesh + if not cfg.enabled or ep_fsdp_device_mesh is None: + return [] + + # Always query EP via the 1D submesh to avoid relying on Tensor named dims. + ep_mesh = ep_fsdp_device_mesh['ep'] + ep_world_size = ep_mesh.size() + if ep_world_size <= 1: + return [] + + if not dist.is_initialized(): + raise RuntimeError('torch.distributed is not initialized, cannot enable expert parallel.') + + # Get process group and local rank from EP submesh. 
+ ep_group = ep_mesh.get_group() + ep_rank = ep_mesh.get_local_rank() + + specs = [] + for _, block in find_moe_blocks_with_names(model): + spec = shard_experts(block, ep_world_size, ep_rank, cfg) + patch_forward(block, ep_group, ep_world_size, cfg) + specs.append(spec) + + return specs + + +def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig: + cfg = ExpertParallelConfig() + if not config: + return cfg + for key, value in config.items(): + if not hasattr(cfg, key): + raise ValueError(f'Unknown expert parallel config: {key}') + setattr(cfg, key, value) + return cfg + + +def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: + return [block for _, block in find_moe_blocks_with_names(model)] + + +def find_moe_blocks_with_names(model: nn.Module) -> Iterable[tuple[str, nn.Module]]: + blocks = [] + for name, module in model.named_modules(): + experts = getattr(module, 'experts', None) + if experts is None: + continue + if not _is_moe_experts(experts): + continue + if not _get_gate(module): + continue + blocks.append((name, module)) + return blocks + + +def shard_experts( + block: nn.Module, + ep_world_size: int, + ep_rank: int, + cfg: ExpertParallelConfig, +) -> ExpertShardingSpec: + """Shard experts in a MoE block across EP ranks. + + Args: + block: The MoE block containing experts. + ep_world_size: The world size for expert parallelism. + ep_rank: The current rank in the EP group. + cfg: Expert parallel configuration. + + Returns an ExpertShardingSpec describing the sharding. 
+ """ + num_experts = _get_num_experts(block) + + if num_experts % ep_world_size != 0: + raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).') + + experts_per_rank = num_experts // ep_world_size + local_start = ep_rank * experts_per_rank + local_end = local_start + experts_per_rank + + if isinstance(block.experts, nn.ModuleList): + local_experts = nn.ModuleList(block.experts[local_start:local_end]) + block.experts = local_experts + is_tensor_experts = False + else: + _shard_tensor_experts(block.experts, local_start, local_end) + is_tensor_experts = True + + block._ep_num_experts = num_experts + block._ep_experts_per_rank = experts_per_rank + block._ep_local_start = local_start + block._ep_local_end = local_end + block._ep_rank = ep_rank + block._ep_world_size = ep_world_size + block._ep_tensor_experts = is_tensor_experts + block._ep_ignore_shared_experts = cfg.ignore_shared_experts + + return ExpertShardingSpec( + block=block, + experts_module=block.experts, + num_experts=num_experts, + experts_per_rank=experts_per_rank, + local_start=local_start, + local_end=local_end, + ep_rank=ep_rank, + ep_world_size=ep_world_size, + is_tensor_experts=is_tensor_experts, + ) + + +def patch_forward( + block: nn.Module, + ep_group: dist.ProcessGroup, + ep_world_size: int, + cfg: ExpertParallelConfig, +) -> None: + """Replace the MoE block forward with EP-aware communication flow. + + Communication pattern: + preprocess → token_pre_all2all → expert_compute → tokens_post_all2all + + For tensor experts (gate_up_proj/down_proj), the expert compute is delegated + to block.experts(...) via nn.Module.__call__ so that FSDP2 pre/post-forward + hooks fire correctly (automatic unshard before forward, backward hook + registration, and reshard after forward). No manual unshard/reshard is needed. + + For ModuleList experts, each sub-expert is already called via __call__ inside + _run_local_experts, so the same principle applies. 
+ + Args: + block: The MoE block to patch. + ep_group: The process group for EP communication (from ep_fsdp_device_mesh["ep"]). + ep_world_size: The world size for expert parallelism. + cfg: Expert parallel configuration. + """ + if getattr(block, '_ep_patched', False): + return + + gate = _get_gate(block) + if gate is None: + raise ValueError('MoE block must define gate/router module.') + + top_k = _get_top_k(block) + if top_k is None: + raise ValueError('MoE block must define top_k/num_experts_per_tok.') + + orig_forward = block.forward + return_annotation = inspect.signature(orig_forward).return_annotation + returns_router_logits = return_annotation in ( + tuple, + Tuple[torch.Tensor, torch.Tensor | None], + ) + num_experts = block._ep_num_experts + experts_per_rank = block._ep_experts_per_rank + is_tensor_experts = block._ep_tensor_experts + + # For tensor experts, install an ep_forward on the experts module so we can + # call block.experts(permuted_tokens, counts, experts_per_rank) via __call__, + # letting FSDP2 manage unshard/reshard automatically. 
+ if is_tensor_experts: + _install_ep_forward(block.experts, experts_per_rank) + + def forward(hidden_states: torch.Tensor, *args, **kwargs): + if args: + raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.') + + orig_shape = hidden_states.shape + if hidden_states.ndim == 3: + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states_2d = hidden_states.view(-1, hidden_dim) + elif hidden_states.ndim == 2: + batch_size, seq_len = 1, hidden_states.shape[0] + hidden_dim = hidden_states.shape[1] + hidden_states_2d = hidden_states + else: + raise ValueError(f'Unsupported hidden_states ndim: {hidden_states.ndim}') + + router_logits, routing_weights, selected_experts = _run_router( + gate=gate, + hidden_states=hidden_states_2d, + top_k=top_k, + router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype), + norm_topk_prob=getattr(block, 'norm_topk_prob', False), + **kwargs, + ) + # Keep routing weights in activation dtype before unpermute weighting. + if routing_weights.dtype != hidden_states_2d.dtype: + routing_weights = routing_weights.to(hidden_states_2d.dtype) + # Build expert_mask: [num_experts, top_k, num_tokens] + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens] + + # 1. preprocess: compute splits and token counts + ( + input_splits, + output_splits, + num_global_tokens_per_local_expert, + num_global_sum_tokens_per_local_expert, + ) = preprocess(expert_mask, num_experts, ep_group) + + # 2. token_pre_all2all: permute → all_to_all → sort_chunks + ( + global_permuted_hidden_states, + local_input_permutation_mapping, + local_assignment_weights, + org_hidden_states_shape, + ) = token_pre_all2all( + hidden_states_2d, + expert_mask, + routing_weights, + num_experts, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + ep_group, + ) + + # 3. 
expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire. + # For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank) + # → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard + # For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__. + if is_tensor_experts: + expert_outputs = block.experts( + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, + ) + else: + expert_outputs = _run_local_experts( + block, + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, + ) + + # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) + final_hidden = tokens_post_all2all( + expert_outputs, + local_assignment_weights, + num_experts, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + local_input_permutation_mapping, + org_hidden_states_shape, + ep_group, + ) + + shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg) + if shared_out is not None: + final_hidden = final_hidden + shared_out + + if len(orig_shape) == 3: + final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) + + if cfg.keep_router_logits and returns_router_logits: + return final_hidden, router_logits + return final_hidden + + block._ep_original_forward = orig_forward + block.forward = forward + block._ep_patched = True + + +def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None: + if getattr(experts_mod, '_ep_forward_installed', False): + return + + def ep_forward( + self, + permuted_tokens: torch.Tensor, + num_global_sum_tokens_per_local_expert: torch.Tensor, + experts_per_rank: int, + ) -> torch.Tensor: + if permuted_tokens.numel() == 0: + # Preserve the autograd edge to token_pre_all2all. Returning a new + # empty tensor can make this rank skip the matching backward + # all-to-all, causing EP collective order divergence. 
+ return permuted_tokens + + input_dtype = permuted_tokens.dtype + + cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) + for i in range(experts_per_rank): + cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) + + output_chunks = [] + for i in range(experts_per_rank): + start = int(cumsum[i].item()) + end = int(cumsum[i + 1].item()) + expert_in = permuted_tokens[start:end] + if expert_in.numel() == 0: + output_chunks.append(expert_in) + continue + + gate_up = self.gate_up_proj[i] + down = self.down_proj[i] + compute_dtype = gate_up.dtype + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + gate_up_out = F.linear(expert_in, gate_up) + if hasattr(self, '_apply_gate'): + out = self._apply_gate(gate_up_out) + else: + gate, up = gate_up_out.chunk(2, dim=-1) + out = self.act_fn(gate) * up + out = F.linear(out, down) + + if out.dtype != input_dtype: + out = out.to(input_dtype) + output_chunks.append(out) + + return torch.cat( + output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) + + import types + experts_mod.forward = types.MethodType(ep_forward, experts_mod) + experts_mod._ep_forward_installed = True + + +def _get_gate(block: nn.Module): + gate = getattr(block, 'gate', None) + if gate is None: + gate = getattr(block, 'router', None) + return gate + + +def _get_num_experts(block: nn.Module) -> int: + if hasattr(block, 'num_experts'): + return int(block.num_experts) + experts = getattr(block, 'experts', None) + if experts is None: + raise ValueError('MoE block has no experts.') + if isinstance(experts, nn.ModuleList): + return len(experts) + if hasattr(experts, 'num_experts'): + return int(experts.num_experts) + if hasattr(experts, 'gate_up_proj'): + return int(experts.gate_up_proj.shape[0]) + raise ValueError('Unable to infer num_experts for MoE block.') + + +def _get_top_k(block: nn.Module) -> int | None: + gate = _get_gate(block) + if gate is not None 
and hasattr(gate, 'top_k'): + value = getattr(gate, 'top_k') + if value is not None: + return int(value) + for name in ('num_experts_per_tok', 'top_k'): + if hasattr(block, name): + value = getattr(block, name) + if value is not None: + return int(value) + return None + + +def _get_router_dtype(router_dtype: str, default_dtype: torch.dtype) -> torch.dtype: + if router_dtype == 'fp32': + return torch.float32 + if router_dtype == 'bf16': + return torch.bfloat16 + if router_dtype == 'fp16': + return torch.float16 + return default_dtype + + +def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, cfg: ExpertParallelConfig): + if cfg.ignore_shared_experts: + return None + shared = getattr(block, 'shared_expert', None) + if shared is None: + shared = getattr(block, 'shared_experts', None) + if shared is None: + return None + return _run_module_with_casting(shared, hidden_states_2d) + + +def _is_moe_experts(experts: Any) -> bool: + if isinstance(experts, nn.ModuleList): + return True + if hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj'): + return True + return False + + +def _shard_tensor_experts(experts: nn.Module, start: int, end: int) -> None: + experts.gate_up_proj = nn.Parameter(experts.gate_up_proj.data[start:end].clone()) + experts.down_proj = nn.Parameter(experts.down_proj.data[start:end].clone()) + if hasattr(experts, 'num_experts'): + experts.num_experts = end - start + + +def _run_local_experts( + block: nn.Module, + permuted_tokens: torch.Tensor, + num_global_sum_tokens_per_local_expert: torch.Tensor, + experts_per_rank: int, +) -> torch.Tensor: + """Run ModuleList experts on permuted tokens via nn.Module.__call__. + Tokens are already grouped by expert (contiguous chunks), sizes given by + num_global_sum_tokens_per_local_expert. No routing weight is applied here; + that happens in unpermute. 
+ """ + if permuted_tokens.numel() == 0: + # Keep the backward path through token_pre_all2all even when this EP + # rank owns no routed tokens for the current block. + return permuted_tokens + + input_dtype = permuted_tokens.dtype + experts = block.experts + + cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) + for i in range(experts_per_rank): + cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) + + output_chunks = [] + for i in range(experts_per_rank): + start = int(cumsum[i].item()) + end = int(cumsum[i + 1].item()) + expert_in = permuted_tokens[start:end] + if expert_in.numel() == 0: + output_chunks.append(expert_in) + continue + + expert = experts[i] + compute_dtype = _module_compute_dtype(expert, input_dtype) + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + out = expert(expert_in) + + if out.dtype != input_dtype: + out = out.to(input_dtype) + output_chunks.append(out) + + return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) + + +def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype: + for param in module.parameters(): + if param.dtype.is_floating_point: + return param.dtype + return default + + +def _run_module_with_casting(module: nn.Module, module_in: torch.Tensor) -> torch.Tensor: + input_dtype = module_in.dtype + compute_dtype = _module_compute_dtype(module, input_dtype) + if compute_dtype != input_dtype: + module_in = module_in.to(compute_dtype) + out = module(module_in) + if out.dtype != input_dtype: + out = out.to(input_dtype) + return out + + +def _run_router( + *, + gate: nn.Module, + hidden_states: torch.Tensor, + top_k: int, + router_dtype: torch.dtype, + norm_topk_prob: bool, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + gate_kwargs = {} + if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'): + gate_kwargs['input_ids'] = 
kwargs['input_ids'] + gate_out = gate(hidden_states, **gate_kwargs) + if isinstance(gate_out, tuple) and len(gate_out) >= 3: + router_logits, routing_weights, selected_experts = gate_out[:3] + return router_logits, routing_weights, selected_experts + + router_logits = gate_out + routing_weights = torch.softmax(router_logits, dim=-1, dtype=router_dtype) + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + if norm_topk_prob: + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + return router_logits, routing_weights, selected_experts + + +def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool: + signature = inspect.signature(module.forward) + for param in signature.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return kwarg in signature.parameters diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 6608a2b8..9da5794b 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -1,363 +1,363 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F -import warnings -from importlib import import_module -from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available -from typing import Any, Optional, Tuple - -from twinkle.model.transformers.strategy.sequence_parallel.utils import ( - get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard) -from twinkle.patch import Patch - -if is_flash_linear_attention_available(): - from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN - from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE -else: - 
_FLA_CAUSAL_CONV1D_FN = None - _FLA_CHUNK_GATED_DELTA_RULE = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN -else: - _CAUSAL_CONV1D_FN = None - -_SP_LINEAR_KERNEL_FALLBACK_WARNING = ( - 'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention ' - 'sequence parallel. This fallback only supports non-packed sequences.') - - -def _sp_is_enabled(sequence_parallel_context) -> bool: - return bool(sequence_parallel_context is not None and getattr(sequence_parallel_context, 'world_size', 1) > 1) - - -def _get_sp_rank(sequence_parallel_context) -> int: - if not _sp_is_enabled(sequence_parallel_context): - return 0 - if getattr(sequence_parallel_context, '_sp_group', None) is None: - return 0 - return dist.get_rank(group=sequence_parallel_context._sp_group) - - -def _get_local_padding_mask( - attention_mask: torch.Tensor, - local_seq_len: int, - sequence_parallel_context, -) -> torch.Tensor: - if attention_mask.shape[-1] == local_seq_len or not _sp_is_enabled(sequence_parallel_context): - return attention_mask - return sequence_parallel_context.split( - attention_mask, - dim=1, - position_ids=sequence_parallel_context.real_position_ids, - ) - - -def _apply_conv_activation(x: torch.Tensor, activation) -> torch.Tensor: - if activation is None: - return x - if activation in ('silu', 'swish'): - return F.silu(x) - if callable(activation): - return activation(x) - from transformers.activations import ACT2FN - if activation in ACT2FN: - return ACT2FN[activation](x) - raise ValueError(f'Unsupported causal conv activation: {activation!r}') - - -def _ensure_linear_attention_kernels(mod: torch.nn.Module): - """Bind causal_conv1d_fn and chunk_gated_delta_rule for SP forward.""" - - def _torch_causal_conv1d_fn( - *, - x, - weight, - bias=None, - activation=None, - seq_idx=None, - backend=None, - cu_seqlens=None, - ): - # Fallback priority: - # 1. 
flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above. - # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable. - # 3. plain torch conv1d is the final non-packed fallback. - del backend - if cu_seqlens is not None: - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' - 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' - 'Please install flash-linear-attention or disable padding_free/packing.') - if _CAUSAL_CONV1D_FN is not None: - out = _CAUSAL_CONV1D_FN( - x=x.transpose(1, 2).contiguous(), - weight=weight, - bias=bias, - activation=activation, - seq_idx=seq_idx, - ) - if isinstance(out, tuple): - out = out[0] - return out.transpose(1, 2).contiguous() - seq_len = x.shape[1] - x = x.transpose(1, 2).contiguous() - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1]) - out = _apply_conv_activation(out[:, :, :seq_len], activation) - return out.transpose(1, 2).contiguous() - - # NPU: MindSpeed Triton causal_conv1d and chunk_gated_delta_rule - # are both patched by twinkle.kernel.npu_impls.fla at model initialization. - # No need to set them here - they are already bound on the module. 
- if getattr(mod, '_twinkle_npu_patched', False): - return False - - if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None: - mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN - mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE - return False - - modeling_module = import_module(mod.__class__.__module__) - torch_chunk_gated_delta_rule = getattr(modeling_module, 'torch_chunk_gated_delta_rule') - mod.causal_conv1d_fn = _torch_causal_conv1d_fn - mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule - warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2) - return True - - -def _iter_qwen35_gated_delta_net_classes(): - class_specs = ( - ('transformers.models.qwen3_5.modeling_qwen3_5', 'Qwen3_5GatedDeltaNet'), - ('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', 'Qwen3_5MoeGatedDeltaNet'), - ) - for module_name, class_name in class_specs: - try: - modeling_module = import_module(module_name) - yield getattr(modeling_module, class_name) - except Exception: - continue - - -def _get_local_conv_weights( - mod: torch.nn.Module, - *, - sp_rank: int, - local_num_k_heads: int, - local_num_v_heads: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - local_key_dim = local_num_k_heads * mod.head_k_dim - local_value_dim = local_num_v_heads * mod.head_v_dim - conv_weight = mod.conv1d.weight.squeeze(1) - if conv_weight.shape[0] != (2 * mod.key_dim + mod.value_dim): - raise ValueError( - f'Unexpected conv weight dim {conv_weight.shape[0]}, expected {2 * mod.key_dim + mod.value_dim}.') - key_offset = sp_rank * local_key_dim - value_offset = sp_rank * local_value_dim - local_q_weight = conv_weight[key_offset:key_offset + local_key_dim] - local_k_weight = conv_weight[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] - local_v_weight = conv_weight[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] - local_conv_weight = torch.cat([local_q_weight, local_k_weight, local_v_weight], dim=0) - - 
conv_bias = getattr(mod.conv1d, 'bias', None) - if conv_bias is None: - return local_conv_weight, None - local_q_bias = conv_bias[key_offset:key_offset + local_key_dim] - local_k_bias = conv_bias[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] - local_v_bias = conv_bias[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] - return local_conv_weight, torch.cat([local_q_bias, local_k_bias, local_v_bias], dim=0) - - -class Qwen3_5GatedDeltaNetUlyssesPatch(Patch): - - @staticmethod - def _run_forward( - mod: torch.nn.Module, - hidden_states: torch.Tensor, - *, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - sequence_parallel_context=None, - ) -> torch.Tensor: - using_torch_fallback = _ensure_linear_attention_kernels(mod) - modeling_module = import_module(mod.__class__.__module__) - apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states') - - local_attention_mask = attention_mask - if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: - local_attention_mask = _get_local_padding_mask( - attention_mask, - hidden_states.shape[1], - sequence_parallel_context, - ) - hidden_states = apply_mask_to_padding_states(hidden_states, local_attention_mask) - batch_size, seq_len, _ = hidden_states.shape - - has_previous_state = bool(cache_params is not None and getattr(cache_params, 'has_previous_state', False)) - use_precomputed_states = has_previous_state and seq_len == 1 and cache_position is not None - if use_precomputed_states: - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel only supports training/prefill paths; decode with ' - 'cached states is not supported.') - - mixed_qkv = mod.in_proj_qkv(hidden_states) - z = mod.in_proj_z(hidden_states).reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) - b = mod.in_proj_b(hidden_states) - a = mod.in_proj_a(hidden_states) - - sp_enabled = 
_sp_is_enabled(sequence_parallel_context) - if sp_enabled: - sp_world_size = int(sequence_parallel_context.sp_world_size) - if mod.num_k_heads % sp_world_size != 0 or mod.num_v_heads % sp_world_size != 0: - raise RuntimeError( - 'Qwen3.5 linear attention sequence parallel requires sp_world_size to divide both ' - f'linear_num_key_heads ({mod.num_k_heads}) and linear_num_value_heads ({mod.num_v_heads}).') - local_num_k_heads = mod.num_k_heads // sp_world_size - local_num_v_heads = mod.num_v_heads // sp_world_size - q_proj, k_proj, v_proj = torch.split(mixed_qkv, [mod.key_dim, mod.key_dim, mod.value_dim], dim=-1) - q_proj = q_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) - k_proj = k_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) - v_proj = v_proj.reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) - q_proj = seq_to_head_shard(q_proj, sequence_parallel_context) - k_proj = seq_to_head_shard(k_proj, sequence_parallel_context) - v_proj = seq_to_head_shard(v_proj, sequence_parallel_context) - b = seq_to_head_shard(b.reshape(batch_size, seq_len, mod.num_v_heads, 1), - sequence_parallel_context).squeeze(-1) - a = seq_to_head_shard(a.reshape(batch_size, seq_len, mod.num_v_heads, 1), - sequence_parallel_context).squeeze(-1) - seq_after_shard = q_proj.shape[1] - mixed_qkv = torch.cat( - ( - q_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), - k_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), - v_proj.reshape(batch_size, seq_after_shard, local_num_v_heads * mod.head_v_dim), - ), - dim=-1, - ) - sp_rank = _get_sp_rank(sequence_parallel_context) - conv_weight, conv_bias = _get_local_conv_weights( - mod, sp_rank=sp_rank, local_num_k_heads=local_num_k_heads, local_num_v_heads=local_num_v_heads) - else: - local_num_k_heads = mod.num_k_heads - local_num_v_heads = mod.num_v_heads - sp_rank = 0 - b = b.reshape(batch_size, seq_len, mod.num_v_heads) - a = a.reshape(batch_size, 
seq_len, mod.num_v_heads) - conv_weight = mod.conv1d.weight.squeeze(1) - conv_bias = getattr(mod.conv1d, 'bias', None) - - packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context( - sequence_parallel_context, - device=mixed_qkv.device, - ) - extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {}) - if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None: - raise ValueError( - 'Qwen3.5 sequence parallel with padding_free/packed inputs requires packed sequence metadata ' - '(for example valid position_ids).') - if using_torch_fallback and packed_cu_seqlens is not None: - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' - 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' - 'Please install flash-linear-attention or disable padding_free/packing.') - if cache_params is not None: - cache_params.conv_states[mod.layer_idx] = F.pad( - mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0)) - mixed_qkv = mod.causal_conv1d_fn( - x=mixed_qkv, - weight=conv_weight, - bias=conv_bias, - activation=mod.activation, - seq_idx=None, - backend='triton', - cu_seqlens=packed_cu_seqlens, - ) - if isinstance(mixed_qkv, tuple): - mixed_qkv = mixed_qkv[0] - if mixed_qkv.dim() == 2: - mixed_qkv = mixed_qkv.unsqueeze(0) - if mixed_qkv.dim() != 3: - raise ValueError(f'Unexpected conv output dims: {tuple(mixed_qkv.shape)}') - - local_key_dim = local_num_k_heads * mod.head_k_dim - local_value_dim = local_num_v_heads * mod.head_v_dim - query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) - query = query.reshape(batch_size, query.shape[1], local_num_k_heads, mod.head_k_dim) - key = key.reshape(batch_size, key.shape[1], local_num_k_heads, mod.head_k_dim) - value = value.reshape(batch_size, value.shape[1], local_num_v_heads, mod.head_v_dim) - - beta = b.sigmoid() - 
head_slice = slice(sp_rank * local_num_v_heads, - (sp_rank + 1) * local_num_v_heads) if sp_enabled else slice(None) - g = -mod.A_log[head_slice].float().exp() * F.softplus(a.float() + mod.dt_bias[head_slice]) - - if local_num_v_heads // local_num_k_heads > 1: - repeat = local_num_v_heads // local_num_k_heads - query = query.repeat_interleave(repeat, dim=2) - key = key.repeat_interleave(repeat, dim=2) - - chunk_kwargs = { - 'g': g, - 'beta': beta, - 'initial_state': None, - 'output_final_state': cache_params is not None, - 'use_qk_l2norm_in_kernel': True, - } - if packed_cu_seqlens is not None: - chunk_kwargs['cu_seqlens'] = packed_cu_seqlens - core_attn_out, last_recurrent_state = mod.chunk_gated_delta_rule(query, key, value, **chunk_kwargs) - - if cache_params is not None: - cache_params.recurrent_states[mod.layer_idx] = last_recurrent_state - - if sp_enabled: - core_attn_out = head_to_seq_shard(core_attn_out, sequence_parallel_context) - core_attn_out = mod.norm(core_attn_out.reshape(-1, mod.head_v_dim), z.reshape(-1, mod.head_v_dim)) - core_attn_out = core_attn_out.reshape(batch_size, seq_len, local_value_dim if not sp_enabled else mod.value_dim) - return mod.out_proj(core_attn_out) - - def __call__(self, module, *args, **kwargs): - del module, args - sequence_parallel = kwargs.get('sequence_parallel', None) - if sequence_parallel is None: - return - if int(getattr(sequence_parallel, 'rp_world_size', 1) or 1) > 1: - raise NotImplementedError('Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' - '(derived ring attention).') - - for gated_delta_net_cls in _iter_qwen35_gated_delta_net_classes(): - if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): - continue - - origin_forward = gated_delta_net_cls.forward - - def sp_linear_forward( - mod, - hidden_states: torch.Tensor, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - _origin_forward=origin_forward, - **extra_kwargs, - ): 
- sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) - if not _sp_is_enabled(sequence_parallel_context): - return _origin_forward( - mod, - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - **extra_kwargs, - ) - return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( - mod, - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - sequence_parallel_context=sequence_parallel_context, - ) - - gated_delta_net_cls.forward = sp_linear_forward - gated_delta_net_cls._twinkle_sp_linear_patched = True +import torch +import torch.distributed as dist +import torch.nn.functional as F +import warnings +from importlib import import_module +from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available +from typing import Any, Optional, Tuple + +from twinkle.model.transformers.strategy.sequence_parallel.utils import ( + get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard) +from twinkle.patch import Patch + +if is_flash_linear_attention_available(): + from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN + from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE +else: + _FLA_CAUSAL_CONV1D_FN = None + _FLA_CHUNK_GATED_DELTA_RULE = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN +else: + _CAUSAL_CONV1D_FN = None + +_SP_LINEAR_KERNEL_FALLBACK_WARNING = ( + 'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention ' + 'sequence parallel. 
This fallback only supports non-packed sequences.') + + +def _sp_is_enabled(sequence_parallel_context) -> bool: + return bool(sequence_parallel_context is not None and getattr(sequence_parallel_context, 'world_size', 1) > 1) + + +def _get_sp_rank(sequence_parallel_context) -> int: + if not _sp_is_enabled(sequence_parallel_context): + return 0 + if getattr(sequence_parallel_context, '_sp_group', None) is None: + return 0 + return dist.get_rank(group=sequence_parallel_context._sp_group) + + +def _get_local_padding_mask( + attention_mask: torch.Tensor, + local_seq_len: int, + sequence_parallel_context, +) -> torch.Tensor: + if attention_mask.shape[-1] == local_seq_len or not _sp_is_enabled(sequence_parallel_context): + return attention_mask + return sequence_parallel_context.split( + attention_mask, + dim=1, + position_ids=sequence_parallel_context.real_position_ids, + ) + + +def _apply_conv_activation(x: torch.Tensor, activation) -> torch.Tensor: + if activation is None: + return x + if activation in ('silu', 'swish'): + return F.silu(x) + if callable(activation): + return activation(x) + from transformers.activations import ACT2FN + if activation in ACT2FN: + return ACT2FN[activation](x) + raise ValueError(f'Unsupported causal conv activation: {activation!r}') + + +def _ensure_linear_attention_kernels(mod: torch.nn.Module): + """Bind causal_conv1d_fn and chunk_gated_delta_rule for SP forward.""" + + def _torch_causal_conv1d_fn( + *, + x, + weight, + bias=None, + activation=None, + seq_idx=None, + backend=None, + cu_seqlens=None, + ): + # Fallback priority: + # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above. + # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable. + # 3. plain torch conv1d is the final non-packed fallback. 
+ del backend + if cu_seqlens is not None: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' + 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' + 'Please install flash-linear-attention or disable padding_free/packing.') + if _CAUSAL_CONV1D_FN is not None: + out = _CAUSAL_CONV1D_FN( + x=x.transpose(1, 2).contiguous(), + weight=weight, + bias=bias, + activation=activation, + seq_idx=seq_idx, + ) + if isinstance(out, tuple): + out = out[0] + return out.transpose(1, 2).contiguous() + seq_len = x.shape[1] + x = x.transpose(1, 2).contiguous() + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1]) + out = _apply_conv_activation(out[:, :, :seq_len], activation) + return out.transpose(1, 2).contiguous() + + # NPU: MindSpeed Triton causal_conv1d and chunk_gated_delta_rule + # are both patched by twinkle.kernel.npu_impls.fla at model initialization. + # No need to set them here - they are already bound on the module. 
+ if getattr(mod, '_twinkle_npu_patched', False): + return False + + if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None: + mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN + mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE + return False + + modeling_module = import_module(mod.__class__.__module__) + torch_chunk_gated_delta_rule = getattr(modeling_module, 'torch_chunk_gated_delta_rule') + mod.causal_conv1d_fn = _torch_causal_conv1d_fn + mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule + warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2) + return True + + +def _iter_qwen35_gated_delta_net_classes(): + class_specs = ( + ('transformers.models.qwen3_5.modeling_qwen3_5', 'Qwen3_5GatedDeltaNet'), + ('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', 'Qwen3_5MoeGatedDeltaNet'), + ) + for module_name, class_name in class_specs: + try: + modeling_module = import_module(module_name) + yield getattr(modeling_module, class_name) + except Exception: + continue + + +def _get_local_conv_weights( + mod: torch.nn.Module, + *, + sp_rank: int, + local_num_k_heads: int, + local_num_v_heads: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + local_key_dim = local_num_k_heads * mod.head_k_dim + local_value_dim = local_num_v_heads * mod.head_v_dim + conv_weight = mod.conv1d.weight.squeeze(1) + if conv_weight.shape[0] != (2 * mod.key_dim + mod.value_dim): + raise ValueError( + f'Unexpected conv weight dim {conv_weight.shape[0]}, expected {2 * mod.key_dim + mod.value_dim}.') + key_offset = sp_rank * local_key_dim + value_offset = sp_rank * local_value_dim + local_q_weight = conv_weight[key_offset:key_offset + local_key_dim] + local_k_weight = conv_weight[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] + local_v_weight = conv_weight[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] + local_conv_weight = torch.cat([local_q_weight, local_k_weight, local_v_weight], dim=0) + + 
conv_bias = getattr(mod.conv1d, 'bias', None) + if conv_bias is None: + return local_conv_weight, None + local_q_bias = conv_bias[key_offset:key_offset + local_key_dim] + local_k_bias = conv_bias[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] + local_v_bias = conv_bias[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] + return local_conv_weight, torch.cat([local_q_bias, local_k_bias, local_v_bias], dim=0) + + +class Qwen3_5GatedDeltaNetUlyssesPatch(Patch): + + @staticmethod + def _run_forward( + mod: torch.nn.Module, + hidden_states: torch.Tensor, + *, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + sequence_parallel_context=None, + ) -> torch.Tensor: + using_torch_fallback = _ensure_linear_attention_kernels(mod) + modeling_module = import_module(mod.__class__.__module__) + apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states') + + local_attention_mask = attention_mask + if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: + local_attention_mask = _get_local_padding_mask( + attention_mask, + hidden_states.shape[1], + sequence_parallel_context, + ) + hidden_states = apply_mask_to_padding_states(hidden_states, local_attention_mask) + batch_size, seq_len, _ = hidden_states.shape + + has_previous_state = bool(cache_params is not None and getattr(cache_params, 'has_previous_state', False)) + use_precomputed_states = has_previous_state and seq_len == 1 and cache_position is not None + if use_precomputed_states: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel only supports training/prefill paths; decode with ' + 'cached states is not supported.') + + mixed_qkv = mod.in_proj_qkv(hidden_states) + z = mod.in_proj_z(hidden_states).reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) + b = mod.in_proj_b(hidden_states) + a = mod.in_proj_a(hidden_states) + + sp_enabled = 
_sp_is_enabled(sequence_parallel_context) + if sp_enabled: + sp_world_size = int(sequence_parallel_context.sp_world_size) + if mod.num_k_heads % sp_world_size != 0 or mod.num_v_heads % sp_world_size != 0: + raise RuntimeError( + 'Qwen3.5 linear attention sequence parallel requires sp_world_size to divide both ' + f'linear_num_key_heads ({mod.num_k_heads}) and linear_num_value_heads ({mod.num_v_heads}).') + local_num_k_heads = mod.num_k_heads // sp_world_size + local_num_v_heads = mod.num_v_heads // sp_world_size + q_proj, k_proj, v_proj = torch.split(mixed_qkv, [mod.key_dim, mod.key_dim, mod.value_dim], dim=-1) + q_proj = q_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) + k_proj = k_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) + v_proj = v_proj.reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) + q_proj = seq_to_head_shard(q_proj, sequence_parallel_context) + k_proj = seq_to_head_shard(k_proj, sequence_parallel_context) + v_proj = seq_to_head_shard(v_proj, sequence_parallel_context) + b = seq_to_head_shard(b.reshape(batch_size, seq_len, mod.num_v_heads, 1), + sequence_parallel_context).squeeze(-1) + a = seq_to_head_shard(a.reshape(batch_size, seq_len, mod.num_v_heads, 1), + sequence_parallel_context).squeeze(-1) + seq_after_shard = q_proj.shape[1] + mixed_qkv = torch.cat( + ( + q_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), + k_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), + v_proj.reshape(batch_size, seq_after_shard, local_num_v_heads * mod.head_v_dim), + ), + dim=-1, + ) + sp_rank = _get_sp_rank(sequence_parallel_context) + conv_weight, conv_bias = _get_local_conv_weights( + mod, sp_rank=sp_rank, local_num_k_heads=local_num_k_heads, local_num_v_heads=local_num_v_heads) + else: + local_num_k_heads = mod.num_k_heads + local_num_v_heads = mod.num_v_heads + sp_rank = 0 + b = b.reshape(batch_size, seq_len, mod.num_v_heads) + a = a.reshape(batch_size, 
seq_len, mod.num_v_heads) + conv_weight = mod.conv1d.weight.squeeze(1) + conv_bias = getattr(mod.conv1d, 'bias', None) + + packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context( + sequence_parallel_context, + device=mixed_qkv.device, + ) + extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {}) + if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None: + raise ValueError( + 'Qwen3.5 sequence parallel with padding_free/packed inputs requires packed sequence metadata ' + '(for example valid position_ids).') + if using_torch_fallback and packed_cu_seqlens is not None: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' + 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' + 'Please install flash-linear-attention or disable padding_free/packing.') + if cache_params is not None: + cache_params.conv_states[mod.layer_idx] = F.pad( + mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0)) + mixed_qkv = mod.causal_conv1d_fn( + x=mixed_qkv, + weight=conv_weight, + bias=conv_bias, + activation=mod.activation, + seq_idx=None, + backend='triton', + cu_seqlens=packed_cu_seqlens, + ) + if isinstance(mixed_qkv, tuple): + mixed_qkv = mixed_qkv[0] + if mixed_qkv.dim() == 2: + mixed_qkv = mixed_qkv.unsqueeze(0) + if mixed_qkv.dim() != 3: + raise ValueError(f'Unexpected conv output dims: {tuple(mixed_qkv.shape)}') + + local_key_dim = local_num_k_heads * mod.head_k_dim + local_value_dim = local_num_v_heads * mod.head_v_dim + query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) + query = query.reshape(batch_size, query.shape[1], local_num_k_heads, mod.head_k_dim) + key = key.reshape(batch_size, key.shape[1], local_num_k_heads, mod.head_k_dim) + value = value.reshape(batch_size, value.shape[1], local_num_v_heads, mod.head_v_dim) + + beta = b.sigmoid() + 
head_slice = slice(sp_rank * local_num_v_heads, + (sp_rank + 1) * local_num_v_heads) if sp_enabled else slice(None) + g = -mod.A_log[head_slice].float().exp() * F.softplus(a.float() + mod.dt_bias[head_slice]) + + if local_num_v_heads // local_num_k_heads > 1: + repeat = local_num_v_heads // local_num_k_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + chunk_kwargs = { + 'g': g, + 'beta': beta, + 'initial_state': None, + 'output_final_state': cache_params is not None, + 'use_qk_l2norm_in_kernel': True, + } + if packed_cu_seqlens is not None: + chunk_kwargs['cu_seqlens'] = packed_cu_seqlens + core_attn_out, last_recurrent_state = mod.chunk_gated_delta_rule(query, key, value, **chunk_kwargs) + + if cache_params is not None: + cache_params.recurrent_states[mod.layer_idx] = last_recurrent_state + + if sp_enabled: + core_attn_out = head_to_seq_shard(core_attn_out, sequence_parallel_context) + core_attn_out = mod.norm(core_attn_out.reshape(-1, mod.head_v_dim), z.reshape(-1, mod.head_v_dim)) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, local_value_dim if not sp_enabled else mod.value_dim) + return mod.out_proj(core_attn_out) + + def __call__(self, module, *args, **kwargs): + del module, args + sequence_parallel = kwargs.get('sequence_parallel', None) + if sequence_parallel is None: + return + if int(getattr(sequence_parallel, 'rp_world_size', 1) or 1) > 1: + raise NotImplementedError('Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' + '(derived ring attention).') + + for gated_delta_net_cls in _iter_qwen35_gated_delta_net_classes(): + if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): + continue + + origin_forward = gated_delta_net_cls.forward + + def sp_linear_forward( + mod, + hidden_states: torch.Tensor, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + _origin_forward=origin_forward, + **extra_kwargs, + ): 
+ sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) + if not _sp_is_enabled(sequence_parallel_context): + return _origin_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + **extra_kwargs, + ) + return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + sequence_parallel_context=sequence_parallel_context, + ) + + gated_delta_net_cls.forward = sp_linear_forward + gated_delta_net_cls._twinkle_sp_linear_patched = True diff --git a/tests/kernel/npu_impls/test_attention.py b/tests/kernel/npu_impls/test_attention.py index ed916dba..06fa41ac 100644 --- a/tests/kernel/npu_impls/test_attention.py +++ b/tests/kernel/npu_impls/test_attention.py @@ -1,16 +1,16 @@ -def test_attention_imports(): - from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward - assert callable(npu_sdpa_attention_forward) - - -def test_attention_signature(): - import inspect - - from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward - - sig = inspect.signature(npu_sdpa_attention_forward) - params = list(sig.parameters) - assert params[:5] == ['module', 'query', 'key', 'value', 'attention_mask'] - assert sig.parameters['dropout'].default == 0.0 - assert sig.parameters['scaling'].default is None +def test_attention_imports(): + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + assert callable(npu_sdpa_attention_forward) + + +def test_attention_signature(): + import inspect + + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + + sig = inspect.signature(npu_sdpa_attention_forward) + params = list(sig.parameters) + assert params[:5] == ['module', 'query', 'key', 'value', 'attention_mask'] + assert sig.parameters['dropout'].default == 0.0 + assert sig.parameters['scaling'].default is None 
assert sig.parameters['is_causal'].default is None \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_fla.py b/tests/kernel/npu_impls/test_fla.py index 0cfeda1d..86de8c6b 100644 --- a/tests/kernel/npu_impls/test_fla.py +++ b/tests/kernel/npu_impls/test_fla.py @@ -1,55 +1,55 @@ -def test_fla_imports(): - from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla - assert callable(apply_qwen3_5_fla) - - -def test_fla_disabled_by_env(monkeypatch): - monkeypatch.setenv('TWINKLE_NPU_FLA', '0') - from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla - # With env=0, function returns 0 (no-op) without raising - assert apply_qwen3_5_fla(None) == 0 - - -def test_fla_skips_when_no_torch_npu(monkeypatch): - import sys - monkeypatch.setenv('TWINKLE_NPU_FLA', '1') - monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import - from twinkle.kernel.npu_impls import fla as fla_mod - # Reload-tolerant: should return 0 when torch_npu is missing. - assert fla_mod.apply_qwen3_5_fla(None) == 0 - - -def test_fla_does_not_flip_flag_when_mindspeed_missing(monkeypatch): - """On an NPU host where the MindSpeed FLA kernel cannot be imported, - ``apply_qwen3_5_fla`` must NOT flip the global ``is_flash_linear_attention_available`` - flag — otherwise HF transformers would route Qwen3.5 onto a FLA fast path - whose kernel is not installed (runtime failure).""" - import sys - import types - - import transformers.utils as tu - - monkeypatch.setenv('TWINKLE_NPU_FLA', '1') - # Fake torch_npu as importable (with a real __spec__ so find_spec doesn't trip) - import importlib.util - spec = importlib.util.spec_from_loader('torch_npu', loader=None) - fake_npu = importlib.util.module_from_spec(spec) - monkeypatch.setitem(sys.modules, 'torch_npu', fake_npu) - # Stub causal_conv1d so the heavy real import chain doesn't run - fake_conv = types.ModuleType('twinkle.kernel.causal_conv1d') - fake_conv.npu_causal_conv1d_fn = object() - monkeypatch.setitem(sys.modules, 
'twinkle.kernel.causal_conv1d', fake_conv) - # Force the MindSpeed-backed module import to fail - monkeypatch.setitem(sys.modules, 'twinkle.kernel.chunk_gated_delta_rule', None) - - original_flag = tu.is_flash_linear_attention_available - try: - from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla - assert apply_qwen3_5_fla(None) == 0 - assert tu.is_flash_linear_attention_available is original_flag, ( - 'is_flash_linear_attention_available was flipped to True while the ' - 'MindSpeed kernel is unavailable — this would break Qwen3.5 at runtime.' - ) - finally: - # Defensive cleanup in case the buggy path ran. +def test_fla_imports(): + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert callable(apply_qwen3_5_fla) + + +def test_fla_disabled_by_env(monkeypatch): + monkeypatch.setenv('TWINKLE_NPU_FLA', '0') + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + # With env=0, function returns 0 (no-op) without raising + assert apply_qwen3_5_fla(None) == 0 + + +def test_fla_skips_when_no_torch_npu(monkeypatch): + import sys + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import + from twinkle.kernel.npu_impls import fla as fla_mod + # Reload-tolerant: should return 0 when torch_npu is missing. 
+ assert fla_mod.apply_qwen3_5_fla(None) == 0 + + +def test_fla_does_not_flip_flag_when_mindspeed_missing(monkeypatch): + """On an NPU host where the MindSpeed FLA kernel cannot be imported, + ``apply_qwen3_5_fla`` must NOT flip the global ``is_flash_linear_attention_available`` + flag — otherwise HF transformers would route Qwen3.5 onto a FLA fast path + whose kernel is not installed (runtime failure).""" + import sys + import types + + import transformers.utils as tu + + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + # Fake torch_npu as importable (with a real __spec__ so find_spec doesn't trip) + import importlib.util + spec = importlib.util.spec_from_loader('torch_npu', loader=None) + fake_npu = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, 'torch_npu', fake_npu) + # Stub causal_conv1d so the heavy real import chain doesn't run + fake_conv = types.ModuleType('twinkle.kernel.causal_conv1d') + fake_conv.npu_causal_conv1d_fn = object() + monkeypatch.setitem(sys.modules, 'twinkle.kernel.causal_conv1d', fake_conv) + # Force the MindSpeed-backed module import to fail + monkeypatch.setitem(sys.modules, 'twinkle.kernel.chunk_gated_delta_rule', None) + + original_flag = tu.is_flash_linear_attention_available + try: + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert apply_qwen3_5_fla(None) == 0 + assert tu.is_flash_linear_attention_available is original_flag, ( + 'is_flash_linear_attention_available was flipped to True while the ' + 'MindSpeed kernel is unavailable — this would break Qwen3.5 at runtime.' + ) + finally: + # Defensive cleanup in case the buggy path ran. 
tu.is_flash_linear_attention_available = original_flag \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_moe.py b/tests/kernel/npu_impls/test_moe.py index 34452b61..e26a5856 100644 --- a/tests/kernel/npu_impls/test_moe.py +++ b/tests/kernel/npu_impls/test_moe.py @@ -1,12 +1,12 @@ -def test_moe_imports(): - from twinkle.kernel.npu_impls.moe import ( - GmmFunction, - npu_grouped_mm, - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) - import torch - assert issubclass(GmmFunction, torch.autograd.Function) - assert callable(npu_grouped_mm) - assert callable(npu_packed_moe_experts_forward) +def test_moe_imports(): + from twinkle.kernel.npu_impls.moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + import torch + assert issubclass(GmmFunction, torch.autograd.Function) + assert callable(npu_grouped_mm) + assert callable(npu_packed_moe_experts_forward) assert callable(npu_qwen3_5_moe_sparse_block_forward) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rms_norm.py b/tests/kernel/npu_impls/test_rms_norm.py index 184d7ef7..a50c6ad0 100644 --- a/tests/kernel/npu_impls/test_rms_norm.py +++ b/tests/kernel/npu_impls/test_rms_norm.py @@ -1,40 +1,40 @@ -import pytest -import torch -import torch.nn as nn - -try: - import torch_npu # noqa: F401 - _NPU_OK = True -except ImportError: - _NPU_OK = False - - -def test_imports(): - """NpuRMSNorm and npu_gated_rms_norm_forward import without torch_npu.""" - from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward - assert NpuRMSNorm is not None - assert callable(npu_gated_rms_norm_forward) - - -def test_npu_rmsnorm_has_no_init(): - """Class-replacement contract: NpuRMSNorm must not define its own __init__.""" - from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm - # If NpuRMSNorm defines __init__, it'd appear in NpuRMSNorm.__dict__ - assert '__init__' not in 
NpuRMSNorm.__dict__ - - -@pytest.mark.skipif(not _NPU_OK, reason='torch_npu unavailable') -def test_npu_rmsnorm_forward_runs_on_npu(): - from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm - - class _Orig(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.ones(8)) - self.variance_epsilon = 1e-6 - - m = _Orig().to('npu') - m.__class__ = NpuRMSNorm - x = torch.randn(2, 8, device='npu') - y = m(x) +import pytest +import torch +import torch.nn as nn + +try: + import torch_npu # noqa: F401 + _NPU_OK = True +except ImportError: + _NPU_OK = False + + +def test_imports(): + """NpuRMSNorm and npu_gated_rms_norm_forward import without torch_npu.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + assert NpuRMSNorm is not None + assert callable(npu_gated_rms_norm_forward) + + +def test_npu_rmsnorm_has_no_init(): + """Class-replacement contract: NpuRMSNorm must not define its own __init__.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + # If NpuRMSNorm defines __init__, it'd appear in NpuRMSNorm.__dict__ + assert '__init__' not in NpuRMSNorm.__dict__ + + +@pytest.mark.skipif(not _NPU_OK, reason='torch_npu unavailable') +def test_npu_rmsnorm_forward_runs_on_npu(): + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + + class _Orig(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(8)) + self.variance_epsilon = 1e-6 + + m = _Orig().to('npu') + m.__class__ = NpuRMSNorm + x = torch.randn(2, 8, device='npu') + y = m(x) assert y.shape == (2, 8) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rotary.py b/tests/kernel/npu_impls/test_rotary.py index 460d0fc3..fc15fc54 100644 --- a/tests/kernel/npu_impls/test_rotary.py +++ b/tests/kernel/npu_impls/test_rotary.py @@ -1,21 +1,21 @@ -def test_rotary_imports(): - from twinkle.kernel.npu_impls.rotary import ( - npu_apply_multimodal_rotary_pos_emb, - 
npu_apply_rotary_pos_emb, - ) - assert callable(npu_apply_rotary_pos_emb) - assert callable(npu_apply_multimodal_rotary_pos_emb) - - -def test_rotary_signature_compat(): - """Signature must match HF apply_rotary_pos_emb so setattr swap is safe.""" - import inspect - - from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb - - sig = inspect.signature(npu_apply_rotary_pos_emb) - params = list(sig.parameters) - assert params[:4] == ['q', 'k', 'cos', 'sin'] - # position_ids and unsqueeze_dim must be optional - assert sig.parameters['position_ids'].default is None +def test_rotary_imports(): + from twinkle.kernel.npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + assert callable(npu_apply_rotary_pos_emb) + assert callable(npu_apply_multimodal_rotary_pos_emb) + + +def test_rotary_signature_compat(): + """Signature must match HF apply_rotary_pos_emb so setattr swap is safe.""" + import inspect + + from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + + sig = inspect.signature(npu_apply_rotary_pos_emb) + params = list(sig.parameters) + assert params[:4] == ['q', 'k', 'cos', 'sin'] + # position_ids and unsqueeze_dim must be optional + assert sig.parameters['position_ids'].default is None assert sig.parameters['unsqueeze_dim'].default == 1 \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_swiglu.py b/tests/kernel/npu_impls/test_swiglu.py index d4ec2da9..b3547d7f 100644 --- a/tests/kernel/npu_impls/test_swiglu.py +++ b/tests/kernel/npu_impls/test_swiglu.py @@ -1,12 +1,12 @@ -def test_swiglu_imports(): - from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward - assert callable(npu_swiglu_forward) - - -def test_swiglu_signature(): - import inspect - - from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward - - params = list(inspect.signature(npu_swiglu_forward).parameters) +def test_swiglu_imports(): + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + 
assert callable(npu_swiglu_forward) + + +def test_swiglu_signature(): + import inspect + + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + + params = list(inspect.signature(npu_swiglu_forward).parameters) assert params == ['self', 'hidden_state'] \ No newline at end of file diff --git a/tests/kernel/test_builtin.py b/tests/kernel/test_builtin.py index 38d83915..0bd050ed 100644 --- a/tests/kernel/test_builtin.py +++ b/tests/kernel/test_builtin.py @@ -1,90 +1,90 @@ -import importlib.machinery -import sys -import types - -import torch -import torch.nn as nn - -import pytest - - -def _fake_module(name: str): - module = types.ModuleType(name) - module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) - return module - - -def test_npu_builtin_returns_dict(): - from twinkle.kernel.builtin import npu_builtin - bundle = npu_builtin() - assert isinstance(bundle, dict) - assert len(bundle) > 0 - - -def test_npu_builtin_values_are_npu_gated(): - """Every value in npu_builtin() must be wrapped in {'npu': ...} so it's - safely no-op on CUDA/CPU.""" - from twinkle.kernel.builtin import npu_builtin - for key, value in npu_builtin().items(): - assert isinstance(value, dict), f'value for {key!r} is not a device-dict' - assert 'npu' in value, f'value for {key!r} is missing npu entry' - - -def test_npu_builtin_compose_with_user_override(): - """User-supplied keys override the builtin (via plain dict merge).""" - from twinkle.kernel.builtin import npu_builtin - sentinel = object() - merged = {**npu_builtin(), 'fake.module.path.fn': sentinel} - assert merged['fake.module.path.fn'] is sentinel - - -def test_npu_builtin_safe_on_cpu_model(): - """kernelize(cpu_model, npu_builtin()) must not raise and not modify.""" - from twinkle.kernel import kernelize - from twinkle.kernel.builtin import npu_builtin - - m = nn.Sequential(nn.Linear(2, 2)) - pre_type = type(m[0]) - out = kernelize(m, npu_builtin()) - assert out is m - assert type(m[0]) is pre_type # no 
replacement happened (cpu device) - - -def test_npu_builtin_skips_missing_modeling_modules(): - """If transformers.models.qwen3_5 is not installed, the bundle must - still produce a dict (with whatever subset is available).""" - from twinkle.kernel.builtin import npu_builtin - bundle = npu_builtin() # must not raise - assert isinstance(bundle, dict) - - -def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(monkeypatch): - """Calling npu_builtin() on a CUDA/CPU host must not contaminate the - global HF SDPA registry. The NPU impl inverts boolean masks, which is - wrong for non-NPU execution.""" - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - from twinkle.kernel.builtin import npu_builtin - from twinkle.utils.device_mesh import Platform - - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) - original = ALL_ATTENTION_FUNCTIONS.get('sdpa') - npu_builtin() - assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original - - -def test_npu_builtin_skips_side_effects_on_non_npu_platform(monkeypatch): - from twinkle.kernel import builtin - from twinkle.kernel.npu_impls import fla - from twinkle.utils.device_mesh import Platform - - installs = [] - fla_calls = [] - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) - monkeypatch.setitem(sys.modules, 'torch_npu', _fake_module('torch_npu')) - monkeypatch.setattr(builtin, '_install_sdpa', lambda impl: installs.append(impl)) - monkeypatch.setattr(fla, 'apply_qwen3_5_fla', lambda model: fla_calls.append(model)) - - builtin.npu_builtin(nn.Linear(1, 1)) - - assert installs == [] - assert fla_calls == [] +import importlib.machinery +import sys +import types + +import torch +import torch.nn as nn + +import pytest + + +def _fake_module(name: str): + module = types.ModuleType(name) + module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) + return module + + +def test_npu_builtin_returns_dict(): + from 
twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() + assert isinstance(bundle, dict) + assert len(bundle) > 0 + + +def test_npu_builtin_values_are_npu_gated(): + """Every value in npu_builtin() must be wrapped in {'npu': ...} so it's + safely no-op on CUDA/CPU.""" + from twinkle.kernel.builtin import npu_builtin + for key, value in npu_builtin().items(): + assert isinstance(value, dict), f'value for {key!r} is not a device-dict' + assert 'npu' in value, f'value for {key!r} is missing npu entry' + + +def test_npu_builtin_compose_with_user_override(): + """User-supplied keys override the builtin (via plain dict merge).""" + from twinkle.kernel.builtin import npu_builtin + sentinel = object() + merged = {**npu_builtin(), 'fake.module.path.fn': sentinel} + assert merged['fake.module.path.fn'] is sentinel + + +def test_npu_builtin_safe_on_cpu_model(): + """kernelize(cpu_model, npu_builtin()) must not raise and not modify.""" + from twinkle.kernel import kernelize + from twinkle.kernel.builtin import npu_builtin + + m = nn.Sequential(nn.Linear(2, 2)) + pre_type = type(m[0]) + out = kernelize(m, npu_builtin()) + assert out is m + assert type(m[0]) is pre_type # no replacement happened (cpu device) + + +def test_npu_builtin_skips_missing_modeling_modules(): + """If transformers.models.qwen3_5 is not installed, the bundle must + still produce a dict (with whatever subset is available).""" + from twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() # must not raise + assert isinstance(bundle, dict) + + +def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(monkeypatch): + """Calling npu_builtin() on a CUDA/CPU host must not contaminate the + global HF SDPA registry. 
The NPU impl inverts boolean masks, which is + wrong for non-NPU execution.""" + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from twinkle.kernel.builtin import npu_builtin + from twinkle.utils.device_mesh import Platform + + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + original = ALL_ATTENTION_FUNCTIONS.get('sdpa') + npu_builtin() + assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original + + +def test_npu_builtin_skips_side_effects_on_non_npu_platform(monkeypatch): + from twinkle.kernel import builtin + from twinkle.kernel.npu_impls import fla + from twinkle.utils.device_mesh import Platform + + installs = [] + fla_calls = [] + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + monkeypatch.setitem(sys.modules, 'torch_npu', _fake_module('torch_npu')) + monkeypatch.setattr(builtin, '_install_sdpa', lambda impl: installs.append(impl)) + monkeypatch.setattr(fla, 'apply_qwen3_5_fla', lambda model: fla_calls.append(model)) + + builtin.npu_builtin(nn.Linear(1, 1)) + + assert installs == [] + assert fla_calls == [] diff --git a/tests/kernel/test_hub.py b/tests/kernel/test_hub.py index e1e2644e..a0a7cf63 100644 --- a/tests/kernel/test_hub.py +++ b/tests/kernel/test_hub.py @@ -1,54 +1,54 @@ -import pytest - -from twinkle.kernel.core import HubRef, hub - - -def test_hub_with_version(): - ref = hub('kernels-community/activation:SiluAndMul', version=1) - assert isinstance(ref, HubRef) - assert ref.repo_id == 'kernels-community/activation' - assert ref.layer_name == 'SiluAndMul' - assert ref.version == 1 - assert ref.revision is None - assert ref.backend is None - assert ref.trust_remote_code is False - - -def test_hub_with_revision(): - ref = hub('org/repo:Layer', revision='main') - assert ref.revision == 'main' - assert ref.version is None - - -def test_hub_with_backend_and_trust(): - ref = hub('org/repo:Layer', version=2, backend='cuda', trust_remote_code=True) - 
assert ref.backend == 'cuda' - assert ref.trust_remote_code is True - - -def test_hub_rejects_both_revision_and_version(): - with pytest.raises(ValueError, match='Exactly one'): - hub('org/repo:Layer', revision='main', version=1) - - -def test_hub_rejects_neither_revision_nor_version(): - with pytest.raises(ValueError, match='Exactly one'): - hub('org/repo:Layer') - - -def test_hub_rejects_missing_colon(): - with pytest.raises(ValueError, match='repo_id:LayerName'): - hub('org/repo', version=1) - - -def test_hub_handles_colon_in_repo_id(): - # rsplit takes only the last colon - ref = hub('org:sub/repo:Layer', version=1) - assert ref.repo_id == 'org:sub/repo' - assert ref.layer_name == 'Layer' - - -def test_hubref_is_frozen(): - ref = hub('org/repo:Layer', version=1) - with pytest.raises(Exception): +import pytest + +from twinkle.kernel.core import HubRef, hub + + +def test_hub_with_version(): + ref = hub('kernels-community/activation:SiluAndMul', version=1) + assert isinstance(ref, HubRef) + assert ref.repo_id == 'kernels-community/activation' + assert ref.layer_name == 'SiluAndMul' + assert ref.version == 1 + assert ref.revision is None + assert ref.backend is None + assert ref.trust_remote_code is False + + +def test_hub_with_revision(): + ref = hub('org/repo:Layer', revision='main') + assert ref.revision == 'main' + assert ref.version is None + + +def test_hub_with_backend_and_trust(): + ref = hub('org/repo:Layer', version=2, backend='cuda', trust_remote_code=True) + assert ref.backend == 'cuda' + assert ref.trust_remote_code is True + + +def test_hub_rejects_both_revision_and_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer', revision='main', version=1) + + +def test_hub_rejects_neither_revision_nor_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer') + + +def test_hub_rejects_missing_colon(): + with pytest.raises(ValueError, match='repo_id:LayerName'): + hub('org/repo', version=1) + + 
+def test_hub_handles_colon_in_repo_id(): + # rsplit takes only the last colon + ref = hub('org:sub/repo:Layer', version=1) + assert ref.repo_id == 'org:sub/repo' + assert ref.layer_name == 'Layer' + + +def test_hubref_is_frozen(): + ref = hub('org/repo:Layer', version=1) + with pytest.raises(Exception): ref.repo_id = 'other' \ No newline at end of file diff --git a/tests/kernel/test_infer_device.py b/tests/kernel/test_infer_device.py index 7f9d5581..a0f6fbb9 100644 --- a/tests/kernel/test_infer_device.py +++ b/tests/kernel/test_infer_device.py @@ -1,29 +1,29 @@ -import torch -import torch.nn as nn - -from twinkle.kernel.core import _infer_device - - -class _NoParamsNoBuffers(nn.Module): - pass - - -class _OnlyBuffer(nn.Module): - def __init__(self): - super().__init__() - self.register_buffer('b', torch.zeros(2)) - - -def test_infer_device_from_parameter(): - m = nn.Linear(2, 3) - assert _infer_device(m) == 'cpu' - - -def test_infer_device_from_buffer_when_no_params(): - m = _OnlyBuffer() - assert _infer_device(m) == 'cpu' - - -def test_infer_device_defaults_to_cpu_when_empty(): - m = _NoParamsNoBuffers() +import torch +import torch.nn as nn + +from twinkle.kernel.core import _infer_device + + +class _NoParamsNoBuffers(nn.Module): + pass + + +class _OnlyBuffer(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('b', torch.zeros(2)) + + +def test_infer_device_from_parameter(): + m = nn.Linear(2, 3) + assert _infer_device(m) == 'cpu' + + +def test_infer_device_from_buffer_when_no_params(): + m = _OnlyBuffer() + assert _infer_device(m) == 'cpu' + + +def test_infer_device_defaults_to_cpu_when_empty(): + m = _NoParamsNoBuffers() assert _infer_device(m) == 'cpu' \ No newline at end of file diff --git a/tests/kernel/test_kernelize.py b/tests/kernel/test_kernelize.py index cdb98cae..44018fc2 100644 --- a/tests/kernel/test_kernelize.py +++ b/tests/kernel/test_kernelize.py @@ -1,98 +1,98 @@ -import sys -import types - -import pytest -import torch 
-import torch.nn as nn - -from twinkle.kernel.core import HubRef, kernelize - - -class _SrcLayer(nn.Module): - def __init__(self): - super().__init__() - self.w = nn.Parameter(torch.zeros(1)) - - def forward(self, x): - return x - - -class _DstLayer(nn.Module): - def forward(self, x): - return x + 100 - - -def test_kernelize_class_to_class_replacement(): - parent = nn.Sequential(_SrcLayer(), _SrcLayer()) - out = kernelize(parent, {_SrcLayer: _DstLayer}) - assert out is parent - assert type(parent[0]) is _DstLayer - assert type(parent[1]) is _DstLayer - - -def test_kernelize_empty_mapping_returns_model(): - m = _SrcLayer() - assert kernelize(m, {}) is m - assert type(m) is _SrcLayer - - -def test_kernelize_string_key_calls_setattr(): - mod_name = 'tests.kernel._tmp_kernelize_str' - mod = types.ModuleType(mod_name) - mod.target_fn = lambda x: x - sys.modules[mod_name] = mod - try: - new_fn = lambda x: x * 3 # noqa: E731 - kernelize(nn.Linear(1, 1), {f'{mod_name}.target_fn': new_fn}) - assert mod.target_fn is new_fn - finally: - sys.modules.pop(mod_name, None) - - -def test_kernelize_device_dict_match(monkeypatch): - from twinkle.utils.device_mesh import Platform - - parent = nn.Sequential(_SrcLayer()) - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) - - kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) - - assert type(parent[0]) is _DstLayer - - -def test_kernelize_uses_platform_device_prefix(monkeypatch): - from twinkle.utils.device_mesh import Platform - - parent = nn.Sequential(_SrcLayer()) # params may still be CPU before FSDP placement - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'npu')) - - kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) - - assert type(parent[0]) is _DstLayer - - -def test_kernelize_device_dict_miss_skips_silently(monkeypatch): - from twinkle.utils.device_mesh import Platform - - parent = nn.Sequential(_SrcLayer()) - 
monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) - - kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) - - assert type(parent[0]) is _SrcLayer - - -def test_kernelize_rejects_unknown_key_type(): - with pytest.raises(TypeError, match='Unsupported mapping key'): - kernelize(nn.Linear(1, 1), {42: _DstLayer}) - - -def test_kernelize_loads_hub_ref(monkeypatch): - # Stand in for HF kernels: patch _load_hub_ref to return _DstLayer - from twinkle.kernel import core as _core - monkeypatch.setattr(_core, '_load_hub_ref', lambda ref: _DstLayer) - - parent = nn.Sequential(_SrcLayer()) - ref = HubRef('org/repo', 'X', revision='main') - kernelize(parent, {_SrcLayer: ref}) - assert type(parent[0]) is _DstLayer +import sys +import types + +import pytest +import torch +import torch.nn as nn + +from twinkle.kernel.core import HubRef, kernelize + + +class _SrcLayer(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return x + + +class _DstLayer(nn.Module): + def forward(self, x): + return x + 100 + + +def test_kernelize_class_to_class_replacement(): + parent = nn.Sequential(_SrcLayer(), _SrcLayer()) + out = kernelize(parent, {_SrcLayer: _DstLayer}) + assert out is parent + assert type(parent[0]) is _DstLayer + assert type(parent[1]) is _DstLayer + + +def test_kernelize_empty_mapping_returns_model(): + m = _SrcLayer() + assert kernelize(m, {}) is m + assert type(m) is _SrcLayer + + +def test_kernelize_string_key_calls_setattr(): + mod_name = 'tests.kernel._tmp_kernelize_str' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 3 # noqa: E731 + kernelize(nn.Linear(1, 1), {f'{mod_name}.target_fn': new_fn}) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_kernelize_device_dict_match(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = 
nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + + kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_uses_platform_device_prefix(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) # params may still be CPU before FSDP placement + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'npu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_device_dict_miss_skips_silently(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + + assert type(parent[0]) is _SrcLayer + + +def test_kernelize_rejects_unknown_key_type(): + with pytest.raises(TypeError, match='Unsupported mapping key'): + kernelize(nn.Linear(1, 1), {42: _DstLayer}) + + +def test_kernelize_loads_hub_ref(monkeypatch): + # Stand in for HF kernels: patch _load_hub_ref to return _DstLayer + from twinkle.kernel import core as _core + monkeypatch.setattr(_core, '_load_hub_ref', lambda ref: _DstLayer) + + parent = nn.Sequential(_SrcLayer()) + ref = HubRef('org/repo', 'X', revision='main') + kernelize(parent, {_SrcLayer: ref}) + assert type(parent[0]) is _DstLayer diff --git a/tests/kernel/test_load_hub_ref.py b/tests/kernel/test_load_hub_ref.py index 747e3fdc..96d9c78e 100644 --- a/tests/kernel/test_load_hub_ref.py +++ b/tests/kernel/test_load_hub_ref.py @@ -1,69 +1,69 @@ -import sys -import types -from unittest.mock import patch - -import pytest - -from twinkle.kernel.core import HubRef, _load_hub_ref - - -def _install_fake_kernels(layer_obj=None, no_layers=False): - """Install a fake `kernels` module with a controllable 
`get_kernel`.""" - fake = types.ModuleType('kernels') - - def fake_get_kernel(repo_id, **kwargs): - m = types.ModuleType('fake_kernel') - if not no_layers: - layers_ns = types.SimpleNamespace() - if layer_obj is not None: - layers_ns.MyLayer = layer_obj - m.layers = layers_ns - return m - - fake.get_kernel = fake_get_kernel - sys.modules['kernels'] = fake - - -def _uninstall_fake_kernels(): - sys.modules.pop('kernels', None) - - -def test_load_hub_ref_returns_layer(): - sentinel = object() - _install_fake_kernels(layer_obj=sentinel) - try: - ref = HubRef('org/repo', 'MyLayer', revision='main') - assert _load_hub_ref(ref) is sentinel - finally: - _uninstall_fake_kernels() - - -def test_load_hub_ref_raises_if_layers_missing(): - _install_fake_kernels(no_layers=True) - try: - ref = HubRef('org/repo', 'MyLayer', revision='main') - with pytest.raises(ValueError, match='does not define any layers'): - _load_hub_ref(ref) - finally: - _uninstall_fake_kernels() - - -def test_load_hub_ref_raises_if_layer_name_missing(): - _install_fake_kernels(layer_obj=None) # MyLayer not present - try: - ref = HubRef('org/repo', 'Missing', revision='main') - with pytest.raises(ValueError, match='not found'): - _load_hub_ref(ref) - finally: - _uninstall_fake_kernels() - - -def test_load_hub_ref_install_hint_when_kernels_missing(): - # Force `import kernels` to fail - sys.modules['kernels'] = None # short-circuits import to ImportError - try: - ref = HubRef('org/repo', 'MyLayer', revision='main') - with pytest.raises(ImportError, match='pip install kernels'): - _load_hub_ref(ref) - finally: +import sys +import types +from unittest.mock import patch + +import pytest + +from twinkle.kernel.core import HubRef, _load_hub_ref + + +def _install_fake_kernels(layer_obj=None, no_layers=False): + """Install a fake `kernels` module with a controllable `get_kernel`.""" + fake = types.ModuleType('kernels') + + def fake_get_kernel(repo_id, **kwargs): + m = types.ModuleType('fake_kernel') + if not 
no_layers: + layers_ns = types.SimpleNamespace() + if layer_obj is not None: + layers_ns.MyLayer = layer_obj + m.layers = layers_ns + return m + + fake.get_kernel = fake_get_kernel + sys.modules['kernels'] = fake + + +def _uninstall_fake_kernels(): + sys.modules.pop('kernels', None) + + +def test_load_hub_ref_returns_layer(): + sentinel = object() + _install_fake_kernels(layer_obj=sentinel) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + assert _load_hub_ref(ref) is sentinel + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layers_missing(): + _install_fake_kernels(no_layers=True) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ValueError, match='does not define any layers'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layer_name_missing(): + _install_fake_kernels(layer_obj=None) # MyLayer not present + try: + ref = HubRef('org/repo', 'Missing', revision='main') + with pytest.raises(ValueError, match='not found'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_install_hint_when_kernels_missing(): + # Force `import kernels` to fail + sys.modules['kernels'] = None # short-circuits import to ImportError + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ImportError, match='pip install kernels'): + _load_hub_ref(ref) + finally: sys.modules.pop('kernels', None) \ No newline at end of file diff --git a/tests/kernel/test_public_api.py b/tests/kernel/test_public_api.py index f9a17a2a..9000a758 100644 --- a/tests/kernel/test_public_api.py +++ b/tests/kernel/test_public_api.py @@ -1,22 +1,22 @@ -def test_public_exports_exactly_three_symbols(): - import twinkle.kernel as k - assert sorted(k.__all__) == ['hub', 'kernelize', 'npu_builtin'] - assert callable(k.kernelize) - assert callable(k.npu_builtin) - assert callable(k.hub) - - -def test_no_legacy_symbols(): - """Legacy 
registrar / patch helpers must be gone.""" - import twinkle.kernel as k - legacy = [ - 'kernelize_model', 'register_layer_kernel', 'register_function_kernel', - 'register_kernels', 'register_external_layer', 'apply_npu_patch', - 'apply_npu_fused_ops', 'apply_function_kernel', 'apply_layer_kernel', - 'register_layer_batch', 'register_npu_fused_function_kernels', - 'get_global_layer_registry', 'get_global_function_registry', - 'get_global_external_layer_registry', 'LayerRegistry', - 'ExternalLayerRegistry', 'FunctionRegistry', - ] - for name in legacy: +def test_public_exports_exactly_three_symbols(): + import twinkle.kernel as k + assert sorted(k.__all__) == ['hub', 'kernelize', 'npu_builtin'] + assert callable(k.kernelize) + assert callable(k.npu_builtin) + assert callable(k.hub) + + +def test_no_legacy_symbols(): + """Legacy registrar / patch helpers must be gone.""" + import twinkle.kernel as k + legacy = [ + 'kernelize_model', 'register_layer_kernel', 'register_function_kernel', + 'register_kernels', 'register_external_layer', 'apply_npu_patch', + 'apply_npu_fused_ops', 'apply_function_kernel', 'apply_layer_kernel', + 'register_layer_batch', 'register_npu_fused_function_kernels', + 'get_global_layer_registry', 'get_global_function_registry', + 'get_global_external_layer_registry', 'LayerRegistry', + 'ExternalLayerRegistry', 'FunctionRegistry', + ] + for name in legacy: assert not hasattr(k, name), f'unexpected legacy symbol: {name}' \ No newline at end of file diff --git a/tests/kernel/test_replace.py b/tests/kernel/test_replace.py index e649b2e3..40013f2c 100644 --- a/tests/kernel/test_replace.py +++ b/tests/kernel/test_replace.py @@ -1,74 +1,74 @@ -import sys -import types - -import torch.nn as nn - -from twinkle.kernel.core import _replace_attr, _replace_class - - -class _Target(nn.Module): - def forward(self, x): - return x - - -class _Impl(nn.Module): - def forward(self, x): - return x + 1 - - -class _SubTarget(_Target): - pass - - -def 
test_replace_class_rewrites_exact_match(): - m = _Target() - parent = nn.Sequential(_Target(), nn.Linear(1, 1)) - _replace_class(parent, _Target, _Impl) - assert type(parent[0]) is _Impl - - -def test_replace_class_skips_subclass(): - parent = nn.Sequential(_SubTarget()) - _replace_class(parent, _Target, _Impl) - # exact match only - _SubTarget should NOT be rewritten - assert type(parent[0]) is _SubTarget - - -def test_replace_class_idempotent(): - m = nn.Sequential(_Target()) - _replace_class(m, _Target, _Impl) - _replace_class(m, _Target, _Impl) # second call must be safe - assert type(m[0]) is _Impl - - -def test_replace_attr_sets_module_attribute(): - mod_name = 'tests.kernel._tmp_replace_attr' - mod = types.ModuleType(mod_name) - mod.target_fn = lambda x: x - sys.modules[mod_name] = mod - try: - new_fn = lambda x: x * 2 # noqa: E731 - _replace_attr(f'{mod_name}.target_fn', new_fn) - assert mod.target_fn is new_fn - finally: - sys.modules.pop(mod_name, None) - - -def test_replace_attr_supports_class_attribute(): - import sys - import types - - mod_name = 'tests.kernel._tmp_class_attr' - mod = types.ModuleType(mod_name) - - class Foo: - def forward(self, x): - return x - mod.Foo = Foo - sys.modules[mod_name] = mod - try: - new_forward = lambda self, x: x + 7 # noqa: E731 - _replace_attr(f'{mod_name}.Foo.forward', new_forward) - assert Foo.forward is new_forward - finally: +import sys +import types + +import torch.nn as nn + +from twinkle.kernel.core import _replace_attr, _replace_class + + +class _Target(nn.Module): + def forward(self, x): + return x + + +class _Impl(nn.Module): + def forward(self, x): + return x + 1 + + +class _SubTarget(_Target): + pass + + +def test_replace_class_rewrites_exact_match(): + m = _Target() + parent = nn.Sequential(_Target(), nn.Linear(1, 1)) + _replace_class(parent, _Target, _Impl) + assert type(parent[0]) is _Impl + + +def test_replace_class_skips_subclass(): + parent = nn.Sequential(_SubTarget()) + _replace_class(parent, 
_Target, _Impl) + # exact match only - _SubTarget should NOT be rewritten + assert type(parent[0]) is _SubTarget + + +def test_replace_class_idempotent(): + m = nn.Sequential(_Target()) + _replace_class(m, _Target, _Impl) + _replace_class(m, _Target, _Impl) # second call must be safe + assert type(m[0]) is _Impl + + +def test_replace_attr_sets_module_attribute(): + mod_name = 'tests.kernel._tmp_replace_attr' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 2 # noqa: E731 + _replace_attr(f'{mod_name}.target_fn', new_fn) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_replace_attr_supports_class_attribute(): + import sys + import types + + mod_name = 'tests.kernel._tmp_class_attr' + mod = types.ModuleType(mod_name) + + class Foo: + def forward(self, x): + return x + mod.Foo = Foo + sys.modules[mod_name] = mod + try: + new_forward = lambda self, x: x + 7 # noqa: E731 + _replace_attr(f'{mod_name}.Foo.forward', new_forward) + assert Foo.forward is new_forward + finally: sys.modules.pop(mod_name, None) \ No newline at end of file diff --git a/tests/kernel/test_resolve_value.py b/tests/kernel/test_resolve_value.py index 652783f5..abc419e7 100644 --- a/tests/kernel/test_resolve_value.py +++ b/tests/kernel/test_resolve_value.py @@ -1,48 +1,48 @@ -import torch.nn as nn - -from twinkle.kernel.core import HubRef, _resolve_value - - -class _ImplA(nn.Module): - pass - - -class _ImplB(nn.Module): - pass - - -def test_passthrough_class_value(): - assert _resolve_value(_ImplA, 'cuda') is _ImplA - - -def test_passthrough_callable_value(): - f = lambda x: x # noqa: E731 - assert _resolve_value(f, 'npu') is f - - -def test_passthrough_hubref(): - ref = HubRef('org/repo', 'Layer', revision='main') - assert _resolve_value(ref, 'cuda') is ref - - -def test_device_dict_match(): - val = {'npu': _ImplA, 'cuda': _ImplB} - assert _resolve_value(val, 'npu') is _ImplA - 
assert _resolve_value(val, 'cuda') is _ImplB - - -def test_device_dict_miss_returns_none(): - val = {'npu': _ImplA} - assert _resolve_value(val, 'cuda') is None - - -def test_device_dict_nested(): - # nested dict -> recursive resolve - val = {'npu': {'npu': _ImplA}} - assert _resolve_value(val, 'npu') is _ImplA - - -def test_device_dict_miss_then_passthrough(): - # nested dict whose inner is also a dict that misses -> None - val = {'npu': {'cuda': _ImplA}} +import torch.nn as nn + +from twinkle.kernel.core import HubRef, _resolve_value + + +class _ImplA(nn.Module): + pass + + +class _ImplB(nn.Module): + pass + + +def test_passthrough_class_value(): + assert _resolve_value(_ImplA, 'cuda') is _ImplA + + +def test_passthrough_callable_value(): + f = lambda x: x # noqa: E731 + assert _resolve_value(f, 'npu') is f + + +def test_passthrough_hubref(): + ref = HubRef('org/repo', 'Layer', revision='main') + assert _resolve_value(ref, 'cuda') is ref + + +def test_device_dict_match(): + val = {'npu': _ImplA, 'cuda': _ImplB} + assert _resolve_value(val, 'npu') is _ImplA + assert _resolve_value(val, 'cuda') is _ImplB + + +def test_device_dict_miss_returns_none(): + val = {'npu': _ImplA} + assert _resolve_value(val, 'cuda') is None + + +def test_device_dict_nested(): + # nested dict -> recursive resolve + val = {'npu': {'npu': _ImplA}} + assert _resolve_value(val, 'npu') is _ImplA + + +def test_device_dict_miss_then_passthrough(): + # nested dict whose inner is also a dict that misses -> None + val = {'npu': {'cuda': _ImplA}} assert _resolve_value(val, 'npu') is None \ No newline at end of file From a20da5af552900b015149c7ebf92cbbe7815a099 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:38:00 +0800 Subject: [PATCH 24/27] Revert "wip" This reverts commit 126efc3da85fbc286f31efbcb0db2cbf7b9107f6. 
--- .../transformers/ep_fsdp2_lora_deepseek_v4.py | 298 ++--- .../transformers/ep_fsdp2_lora_qwen3_5_moe.py | 306 ++--- cookbook/transformers/fsdp2.py | 218 ++-- cookbook/transformers/sp_fsdp_dense.py | 198 +-- docs/source_en/Components/Kernel/Kernel.md | 276 ++--- .../\345\206\205\346\240\270/Kernel.md" | 272 ++--- src/twinkle/kernel/__init__.py | 24 +- src/twinkle/kernel/builtin.py | 410 +++---- src/twinkle/kernel/chunk_gated_delta_rule.py | 724 +++++------ src/twinkle/kernel/core.py | 342 +++--- src/twinkle/kernel/npu_impls/__init__.py | 62 +- src/twinkle/kernel/npu_impls/attention.py | 106 +- src/twinkle/kernel/npu_impls/fla.py | 204 ++-- src/twinkle/kernel/npu_impls/moe.py | 300 ++--- src/twinkle/kernel/npu_impls/rms_norm.py | 148 +-- src/twinkle/kernel/npu_impls/rotary.py | 130 +- src/twinkle/kernel/npu_impls/swiglu.py | 38 +- .../model/transformers/moe/expert_parallel.py | 1066 ++++++++--------- .../sequence_parallel/linear_attention_sp.py | 726 +++++------ tests/kernel/npu_impls/test_attention.py | 30 +- tests/kernel/npu_impls/test_fla.py | 108 +- tests/kernel/npu_impls/test_moe.py | 22 +- tests/kernel/npu_impls/test_rms_norm.py | 78 +- tests/kernel/npu_impls/test_rotary.py | 40 +- tests/kernel/npu_impls/test_swiglu.py | 22 +- tests/kernel/test_builtin.py | 180 +-- tests/kernel/test_hub.py | 106 +- tests/kernel/test_infer_device.py | 56 +- tests/kernel/test_kernelize.py | 196 +-- tests/kernel/test_load_hub_ref.py | 136 +-- tests/kernel/test_public_api.py | 42 +- tests/kernel/test_replace.py | 146 +-- tests/kernel/test_resolve_value.py | 94 +- 33 files changed, 3552 insertions(+), 3552 deletions(-) diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py index bb4582e0..af72efa1 100644 --- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py @@ -1,149 +1,149 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-"""EP + FSDP2 + LoRA SFT cookbook for DeepSeek-V4. - -Run on 8 GPUs: - torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py -""" -import os -from pathlib import Path - -from peft import LoraConfig -from transformers import AutoConfig - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() - -MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LOG_INTERVAL = GRAD_ACCUM_STEPS -LR = float(os.environ.get('LR', '1e-4')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -LORA_R = int(os.environ.get('LORA_R', '8')) -LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) -ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4') -RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None -RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' -IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' -ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') -NUM_GPUS = int(os.environ.get('NUM_GPUS', '8')) - -device_mesh = DeviceMesh.from_sizes( - fsdp_size=NUM_GPUS, - dp_size=1, - ep_size=NUM_GPUS, - device_type=Platform.get_platform().device_prefix(), -) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def _build_lora_config(enable_ep: bool): - if enable_ep: - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - target_modules='all-linear', - exclude_modules=['o_a_proj'], - 
target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], - ) - # Expert weights are bare nn.Parameters. PEFT trains them through - # target_parameters/ParamWrapper, which dynamically parametrizes weights - # during forward. That is not stable with plain FSDP2, so non-EP mode uses - # regular module LoRA and does not train expert parameters. - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - exclude_modules=['o_a_proj'], - target_modules='all-linear', - ) - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - return model.save( - name=checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - text_config = getattr(config, 'text_config', config) - if hasattr(text_config, 'use_cache'): - text_config.use_cache = False - - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) - dataset.encode(batched=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - device_mesh=device_mesh, - strategy='native_fsdp', - memory_efficient_init=True, - fsdp_config={ - 'expert_parallel': { - 'enabled': ENABLE_EP, - 'router_dtype': 'fp32', - 'keep_router_logits': False, - } - }, - ) - lora_cfg = _build_lora_config(ENABLE_EP) - model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.set_optimizer('AdamW', lr=LR, foreach=False) - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - ) - - if RESUME_FROM_CHECKPOINT: - checkpoint_path = 
Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME - progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') - - optimizer_group = model.optimizer_group[ADAPTER_NAME] - for batch in dataloader: - if callable(batch): - batch = batch() - model.forward_backward(inputs=batch) - model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - cur_step = optimizer_group.cur_step - if cur_step > 0 and cur_step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - if callable(metric): - metric = metric() - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - - final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) - logger.info(f'Saved final adapter to {final_checkpoint}') - - -if __name__ == '__main__': - train() +# Copyright (c) ModelScope Contributors. All rights reserved. +"""EP + FSDP2 + LoRA SFT cookbook for DeepSeek-V4. 
+ +Run on 8 GPUs: + torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +""" +import os +from pathlib import Path + +from peft import LoraConfig +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-4')) +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +LORA_R = int(os.environ.get('LORA_R', '8')) +LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) +ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4') +RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None +RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' +IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') +NUM_GPUS = int(os.environ.get('NUM_GPUS', '8')) + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=NUM_GPUS, + dp_size=1, + ep_size=NUM_GPUS, + device_type=Platform.get_platform().device_prefix(), +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def _build_lora_config(enable_ep: bool): + if enable_ep: + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + exclude_modules=['o_a_proj'], + target_parameters=['mlp.experts.gate_up_proj', 
'mlp.experts.down_proj'], + ) + # Expert weights are bare nn.Parameters. PEFT trains them through + # target_parameters/ParamWrapper, which dynamically parametrizes weights + # during forward. That is not stable with plain FSDP2, so non-EP mode uses + # regular module LoRA and does not train expert parameters. + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + exclude_modules=['o_a_proj'], + target_modules='all-linear', + ) + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + text_config = getattr(config, 'text_config', config) + if hasattr(text_config, 'use_cache'): + text_config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.encode(batched=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + memory_efficient_init=True, + fsdp_config={ + 'expert_parallel': { + 'enabled': ENABLE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + lora_cfg = _build_lora_config(ENABLE_EP) + model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: 
+ kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py index eb3efed6..03e962e6 100644 --- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -1,153 +1,153 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""EP + FSDP2 + LoRA SFT cookbook for Qwen3.5-MoE. 
- -Run on 8 GPUs: - torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py -""" -import os -from pathlib import Path - -from peft import LoraConfig -from transformers import AutoConfig - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize, npu_builtin - -logger = get_logger() - -MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LOG_INTERVAL = GRAD_ACCUM_STEPS -LR = float(os.environ.get('LR', '1e-4')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -LORA_R = int(os.environ.get('LORA_R', '8')) -LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) -ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' -OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') -RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None -RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' -IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' -ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') - -device_mesh = DeviceMesh.from_sizes( - fsdp_size=8, - dp_size=1, - ep_size=8, - device_type=Platform.get_platform().device_prefix(), -) -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - - -def _build_lora_config(enable_ep: bool): - if enable_ep: - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - target_modules='all-linear', - target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], - ) - # 
Expert weights are bare nn.Parameters. PEFT trains them through - # target_parameters/ParamWrapper, which dynamically parametrizes weights - # during forward. That is not stable with plain FSDP2, so non-EP mode uses - # regular module LoRA and does not train expert parameters. - return LoraConfig( - r=LORA_R, - lora_alpha=LORA_ALPHA, - target_modules='all-linear', - ) - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - return model.save( - name=checkpoint_name, - output_dir=OUTPUT_DIR, - adapter_name=ADAPTER_NAME, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - text_config = getattr(config, 'text_config', config) - if hasattr(text_config, 'use_cache'): - text_config.use_cache = False - - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) - try: - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - except ValueError: - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) - dataset.encode(batched=True) - dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - device_mesh=device_mesh, - strategy='native_fsdp', - fsdp_config={ - 'expert_parallel': { - 'enabled': ENABLE_EP, - 'router_dtype': 'fp32', - 'keep_router_logits': False, - } - }, - ) - # npu patch - if Torch.is_npu_available(): - model = kernelize(model, npu_builtin(model)) - lora_cfg = _build_lora_config(ENABLE_EP) - model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.set_optimizer('AdamW', lr=LR, foreach=False) - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - ) - - if RESUME_FROM_CHECKPOINT: - checkpoint_path = 
Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() - kwargs = {} - if ADAPTER_NAME: - kwargs['adapter_name'] = ADAPTER_NAME - progress = model.resume_from_checkpoint( - str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) - if not IGNORE_DATA_SKIP: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') - - optimizer_group = model.optimizer_group[ADAPTER_NAME] - for batch in dataloader: - if callable(batch): - batch = batch() - model.forward_backward(inputs=batch) - model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - cur_step = optimizer_group.cur_step - if cur_step > 0 and cur_step % LOG_INTERVAL == 0: - metric = model.calculate_metric(is_training=True) - if callable(metric): - metric = metric() - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - - final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) - logger.info(f'Saved final adapter to {final_checkpoint}') - - -if __name__ == '__main__': - train() +# Copyright (c) ModelScope Contributors. All rights reserved. +"""EP + FSDP2 + LoRA SFT cookbook for Qwen3.5-MoE. 
+ +Run on 8 GPUs: + torchrun --nproc-per-node=8 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +""" +import os +from pathlib import Path + +from peft import LoraConfig +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize, npu_builtin + +logger = get_logger() + +MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-4')) +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +LORA_R = int(os.environ.get('LORA_R', '8')) +LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) +ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') +RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None +RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' +IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=8, + dp_size=1, + ep_size=8, + device_type=Platform.get_platform().device_prefix(), +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def _build_lora_config(enable_ep: bool): + if enable_ep: + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], + ) + # 
Expert weights are bare nn.Parameters. PEFT trains them through + # target_parameters/ParamWrapper, which dynamically parametrizes weights + # during forward. That is not stable with plain FSDP2, so non-EP mode uses + # regular module LoRA and does not train expert parameters. + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + ) + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + text_config = getattr(config, 'text_config', config) + if hasattr(text_config, 'use_cache'): + text_config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + try: + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + except ValueError: + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.encode(batched=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + fsdp_config={ + 'expert_parallel': { + 'enabled': ENABLE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + # npu patch + if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) + lora_cfg = _build_lora_config(ENABLE_EP) + model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = 
Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 0ccc6a32..dd7c0cb1 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,109 +1,109 @@ -from pathlib import Path - -from peft import LoraConfig -from tqdm import tqdm - -import twinkle -from twinkle import DeviceMesh, get_device_placement, get_logger -from twinkle.cli import CLI -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize, npu_builtin - -logger = 
get_logger() -args = CLI.from_args() - -device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size) -twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) - - -def build_dataset(num_samples: int) -> Dataset: - dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples))) - dataset.set_template(args.template.template_cls, model_id=args.model.model_id) - dataset.map(SelfCognitionProcessor( - args.extra.get('model_name', 'twinkle大模型'), - args.extra.get('model_author', 'ModelScope社区'), - )) - dataset.encode() - return dataset - - -def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): - model.save( - checkpoint_name, - output_dir=args.training.output_dir, - adapter_name=args.lora.adapter_name, - save_optimizer=True, - consumed_train_samples=dataloader.get_state()['consumed_train_samples'], - ) - - -def evaluate(model): - eval_samples = args.training.eval_samples or 100 - dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size) - for batch in tqdm(dataloader): - model.forward_only(inputs=batch) - model.calculate_loss() - return model.calculate_metric(is_training=False) - - -def train(): - train_samples = int(args.extra.get('train_samples', 1000)) - dataset = build_dataset(train_samples) - dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) - model = TransformersModel(model_id=args.model.model_id) - model.model._no_split_modules = {'Qwen3_5DecoderLayer'} - # npu patch - if Torch.is_npu_available(): - model = kernelize(model, npu_builtin(model)) - - lora_config = LoraConfig(**args.get_lora_args()) - model.add_adapter_to_model( - args.lora.adapter_name, lora_config, - gradient_accumulation_steps=args.training.gradient_accumulation_steps) - model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, lr=args.optimizer.learning_rate) - - # Add LRScheduler for lora `default` - 
model.set_lr_scheduler( - scheduler_cls=args.scheduler.scheduler_cls, - num_warmup_steps=args.scheduler.num_warmup_steps, - num_training_steps=len(dataloader)) - - if args.training.resume_from_checkpoint: - checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() - progress = model.resume_from_checkpoint( - str(checkpoint_path), - resume_only_model=args.training.resume_only_model, - adapter_name=args.lora.adapter_name) - if not args.training.ignore_data_skip: - dataloader.resume_from_checkpoint(progress['consumed_train_samples']) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info(f'Total steps: {len(dataloader)}') - optimizer_group = model.optimizer_group[args.lora.adapter_name] - best_loss = float('inf') - eval_interval = args.training.eval_interval or 40 - for batch in dataloader: - model.forward_backward(inputs=batch) - model.clip_grad_and_step() - cur_step = optimizer_group.cur_step - if cur_step % args.training.log_interval == 0: - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') - if cur_step > 0 and cur_step % eval_interval == 0: - metrics = evaluate(model) - logger.info(f'Eval metric: {metrics}') - metrics['step'] = cur_step - current_loss = float(metrics['loss']) - if current_loss < best_loss: - save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) - best_loss = current_loss - save_checkpoint(model, 'last-checkpoint', dataloader) - - -if __name__ == '__main__': - train() +from pathlib import Path + +from peft import LoraConfig +from tqdm import tqdm + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.cli import CLI +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch 
+from twinkle.kernel import kernelize, npu_builtin + +logger = get_logger() +args = CLI.from_args() + +device_mesh = DeviceMesh.from_sizes(fsdp_size=args.infra.fsdp_size, dp_size=args.infra.dp_size) +twinkle.initialize(mode=args.infra.mode, global_device_mesh=device_mesh) + + +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(args.dataset.dataset_id, data_slice=range(num_samples))) + dataset.set_template(args.template.template_cls, model_id=args.model.model_id) + dataset.map(SelfCognitionProcessor( + args.extra.get('model_name', 'twinkle大模型'), + args.extra.get('model_author', 'ModelScope社区'), + )) + dataset.encode() + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=args.training.output_dir, + adapter_name=args.lora.adapter_name, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + eval_samples = args.training.eval_samples or 100 + dataloader = DataLoader(dataset=build_dataset(eval_samples), batch_size=args.training.batch_size) + for batch in tqdm(dataloader): + model.forward_only(inputs=batch) + model.calculate_loss() + return model.calculate_metric(is_training=False) + + +def train(): + train_samples = int(args.extra.get('train_samples', 1000)) + dataset = build_dataset(train_samples) + dataloader = DataLoader(dataset=dataset, batch_size=args.training.batch_size) + model = TransformersModel(model_id=args.model.model_id) + model.model._no_split_modules = {'Qwen3_5DecoderLayer'} + # npu patch + if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) + + lora_config = LoraConfig(**args.get_lora_args()) + model.add_adapter_to_model( + args.lora.adapter_name, lora_config, + gradient_accumulation_steps=args.training.gradient_accumulation_steps) + model.set_optimizer(optimizer_cls=args.optimizer.optimizer_cls, 
lr=args.optimizer.learning_rate) + + # Add LRScheduler for lora `default` + model.set_lr_scheduler( + scheduler_cls=args.scheduler.scheduler_cls, + num_warmup_steps=args.scheduler.num_warmup_steps, + num_training_steps=len(dataloader)) + + if args.training.resume_from_checkpoint: + checkpoint_path = Path(args.training.resume_from_checkpoint).expanduser().resolve() + progress = model.resume_from_checkpoint( + str(checkpoint_path), + resume_only_model=args.training.resume_only_model, + adapter_name=args.lora.adapter_name) + if not args.training.ignore_data_skip: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + optimizer_group = model.optimizer_group[args.lora.adapter_name] + best_loss = float('inf') + eval_interval = args.training.eval_interval or 40 + for batch in dataloader: + model.forward_backward(inputs=batch) + model.clip_grad_and_step() + cur_step = optimizer_group.cur_step + if cur_step % args.training.log_interval == 0: + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % eval_interval == 0: + metrics = evaluate(model) + logger.info(f'Eval metric: {metrics}') + metrics['step'] = cur_step + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 8a8fb412..2fd4ecf1 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -1,99 +1,99 @@ -import numpy as np -from functools import partial -from peft import LoraConfig - -import twinkle -from twinkle import 
DeviceGroup, DeviceMesh, Platform, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import kernelize, npu_builtin - -logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen3.5-4B' -DATASETS = 'ms://swift/self-cognition' - -device_group = [DeviceGroup( - name='default', - ranks=[0, 1, 2, 3], - device_type=Platform.get_platform().device_prefix(), -)] - -# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2. -# In Transformers route, ulysses_size is the total sequence-parallel degree. -device_mesh = DeviceMesh( - device_type=Platform.get_platform().device_prefix(), - mesh=np.arange(4).reshape(2, 2), - mesh_dim_names=('dp', 'fsdp'), - ulysses_size=2, -) - -twinkle.initialize( - mode='local', - nproc_per_node=4, - global_device_mesh=device_mesh, - lazy_collect=False, -) - - -def eval(model): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=range(100)), - batch_size=4, - device_mesh=device_mesh, - ) - for _, batch in enumerate(dataloader): - model.forward_only(inputs=batch, adapter_name='default') - model.calculate_loss(adapter_name='default') - return model.calculate_metric(is_training=False, adapter_name='default') - - -def create_dataset(data_slice=None): - dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))) - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队')) - dataset.encode(batched=True) - return dataset - - -def train(): - dataloader = DataLoader( - dataset=partial(create_dataset, data_slice=None), - batch_size=8, - device_mesh=device_mesh, - ) - - model = TransformersModel( - model_id=MODEL_ID, - device_mesh=device_mesh, - strategy='native_fsdp', - ) - # npu patch - if Torch.is_npu_available(): - model = kernelize(model, 
npu_builtin(model)) - lora_config = LoraConfig(target_modules='all-linear') - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - adapter_name='default', - ) - - logger.info(model.get_train_configs(adapter_name='default')) - logger.info(f'Total steps: {len(dataloader)}') - - for step, batch in enumerate(dataloader): - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - if step % 20 == 0: - metric = model.calculate_metric(is_training=True, adapter_name='default') - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - model.save('last-checkpoint', interval=1) - - -if __name__ == '__main__': - train() +import numpy as np +from functools import partial +from peft import LoraConfig + +import twinkle +from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor +from twinkle.utils.framework import Torch +from twinkle.kernel import kernelize, npu_builtin + +logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASETS = 'ms://swift/self-cognition' + +device_group = [DeviceGroup( + name='default', + ranks=[0, 1, 2, 3], + device_type=Platform.get_platform().device_prefix(), +)] + +# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2. +# In Transformers route, ulysses_size is the total sequence-parallel degree. 
+device_mesh = DeviceMesh( + device_type=Platform.get_platform().device_prefix(), + mesh=np.arange(4).reshape(2, 2), + mesh_dim_names=('dp', 'fsdp'), + ulysses_size=2, +) + +twinkle.initialize( + mode='local', + nproc_per_node=4, + global_device_mesh=device_mesh, + lazy_collect=False, +) + + +def eval(model): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=range(100)), + batch_size=4, + device_mesh=device_mesh, + ) + for _, batch in enumerate(dataloader): + model.forward_only(inputs=batch, adapter_name='default') + model.calculate_loss(adapter_name='default') + return model.calculate_metric(is_training=False, adapter_name='default') + + +def create_dataset(data_slice=None): + dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队')) + dataset.encode(batched=True) + return dataset + + +def train(): + dataloader = DataLoader( + dataset=partial(create_dataset, data_slice=None), + batch_size=8, + device_mesh=device_mesh, + ) + + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=device_mesh, + strategy='native_fsdp', + ) + # npu patch + if Torch.is_npu_available(): + model = kernelize(model, npu_builtin(model)) + lora_config = LoraConfig(target_modules='all-linear') + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1) + model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + adapter_name='default', + ) + + logger.info(model.get_train_configs(adapter_name='default')) + logger.info(f'Total steps: {len(dataloader)}') + + for step, batch in enumerate(dataloader): + model.forward_backward(inputs=batch, adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + if step % 20 == 0: + metric = 
model.calculate_metric(is_training=True, adapter_name='default') + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') + model.save('last-checkpoint', interval=1) + + +if __name__ == '__main__': + train() diff --git a/docs/source_en/Components/Kernel/Kernel.md b/docs/source_en/Components/Kernel/Kernel.md index f9f168c3..f5ab78e9 100644 --- a/docs/source_en/Components/Kernel/Kernel.md +++ b/docs/source_en/Components/Kernel/Kernel.md @@ -1,139 +1,139 @@ -# Twinkle Kernel - -`twinkle.kernel` exposes a mapping-driven kernel replacement API. Replacing one -implementation with another collapses to a single `kernelize(model, mapping)` -call. - -The public surface is exactly three symbols: - -| Symbol | Purpose | -| --- | --- | -| `kernelize(model, mapping)` | Apply ``mapping`` to ``model`` (in place) and return it | -| `npu_builtin(model=None)` | Return the Ascend NPU built-in mapping (composes with user mappings) | -| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | Build a ``HubRef`` for use as a mapping value; the actual Hub download is deferred to ``kernelize`` | - -## Mapping semantics - -`mapping` keys describe the target to replace: - -- `type[nn.Module]` subclass — replace **every** instance whose exact type matches (`m.__class__ = impl`; subclasses are **not** touched) -- `str` of the form `'pkg.sub.attr'` or `'pkg.sub.ClassName.attr'` — `setattr(target, attr, impl)` - -`mapping` values describe the replacement: - -- `type[nn.Module]` subclass — used as the impl class. The class' `__init__` is **never** invoked; its forward must work against the attributes the original instance already has -- `Callable` — assigned with `setattr` -- `dict[str, V]` — device → impl dispatch. 
Device is inferred from the model; entries without a matching key are **silently skipped** -- `HubRef` — built via `hub(...)`; resolved lazily - -Device is inferred from `next(model.parameters()).device.type` (falling back to buffers, then `'cpu'`). - -## Examples - -### Enable the full NPU built-in bundle - -```python -import torch -from twinkle.kernel import kernelize, npu_builtin - -if torch.npu.is_available(): - model = kernelize(model, npu_builtin(model)) -``` - -### Custom class replacement - -```python -from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm -from twinkle.kernel import kernelize - -model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) -``` - -### Built-in + custom override - -```python -from twinkle.kernel import kernelize, npu_builtin - -model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) -``` - -Plain dict merge — later keys override earlier ones. - -### Hub kernel (HF Hub format) - -```python -from twinkle.kernel import kernelize, hub -from my_pkg import SiluAndMul - -model = kernelize(model, { - SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), -}) -``` - -Exactly one of `revision` / `version` must be passed. The `kernels` package is imported lazily; absence raises a clear "install kernels" error. - -### Function-level replacement - -```python -from twinkle.kernel import kernelize -from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb - -model = kernelize(model, { - 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': - npu_apply_rotary_pos_emb, -}) -``` - -### Cross-device mapping (NPU enabled, CUDA skipped) - -```python -from twinkle.kernel import kernelize - -model = kernelize(model, { - Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, -}) -``` - -Safe to run on CUDA — entries whose dict misses the current device just skip. 
- -## NPU built-in coverage - -`npu_builtin(model)` returns a dict that (as available transformers modules permit) covers: - -- RMSNorm class replacement for Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE families -- `apply_rotary_pos_emb` function replacement (fused RoPE) for the same families -- SwiGLU fused replacement for the MLP variants -- `Experts.forward` and `SparseMoeBlock.forward` for Qwen3-MoE / Qwen3.5-MoE -- GatedRMSNorm forward for Qwen3.5 / Qwen3.5-MoE -- `apply_multimodal_rotary_pos_emb` for Qwen2.5-VL -- Global SDPA replacement (one-shot side effect on `ALL_ATTENTION_FUNCTIONS['sdpa']`) -- Qwen3.5 Flash Linear Attention enablement (one-shot side effect + per-instance traversal, triggered inside `npu_builtin(model)`) - -**Not included by default:** the NPU replacement for `transformers.integrations.moe._grouped_mm`. Without Expert Parallelism the contiguous-copy overhead is ~8x. Opt in explicitly when EP is enabled: - -```python -from twinkle.kernel import kernelize, npu_builtin -from twinkle.kernel.npu_impls.moe import npu_grouped_mm - -mapping = { - **npu_builtin(model), - 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, -} -model = kernelize(model, mapping) -``` - -## Environment variables - -Only two remain: - -- `TWINKLE_NPU_FLA` — Qwen3.5 FLA switch (default on; `0`/`false` to disable) -- `TWINKLE_NPU_GATED_RMSNorm_FP32` — force FP32 in Gated RMSNorm forward (default off) - -The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` are gone — they're now "include the entry in the mapping or don't" decisions. - -## Caveats - -- `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract -- Exact match: `type(m) is target_cls`. 
Subclasses of `target_cls` are not replaced — add them to the mapping yourself -- `kernelize` is idempotent under repeated calls +# Twinkle Kernel + +`twinkle.kernel` exposes a mapping-driven kernel replacement API. Replacing one +implementation with another collapses to a single `kernelize(model, mapping)` +call. + +The public surface is exactly three symbols: + +| Symbol | Purpose | +| --- | --- | +| `kernelize(model, mapping)` | Apply ``mapping`` to ``model`` (in place) and return it | +| `npu_builtin(model=None)` | Return the Ascend NPU built-in mapping (composes with user mappings) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | Build a ``HubRef`` for use as a mapping value; the actual Hub download is deferred to ``kernelize`` | + +## Mapping semantics + +`mapping` keys describe the target to replace: + +- `type[nn.Module]` subclass — replace **every** instance whose exact type matches (`m.__class__ = impl`; subclasses are **not** touched) +- `str` of the form `'pkg.sub.attr'` or `'pkg.sub.ClassName.attr'` — `setattr(target, attr, impl)` + +`mapping` values describe the replacement: + +- `type[nn.Module]` subclass — used as the impl class. The class' `__init__` is **never** invoked; its forward must work against the attributes the original instance already has +- `Callable` — assigned with `setattr` +- `dict[str, V]` — device → impl dispatch. Device is inferred from the model; entries without a matching key are **silently skipped** +- `HubRef` — built via `hub(...)`; resolved lazily + +Device is inferred from `next(model.parameters()).device.type` (falling back to buffers, then `'cpu'`). 
+ +## Examples + +### Enable the full NPU built-in bundle + +```python +import torch +from twinkle.kernel import kernelize, npu_builtin + +if torch.npu.is_available(): + model = kernelize(model, npu_builtin(model)) +``` + +### Custom class replacement + +```python +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize + +model = kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) +``` + +### Built-in + custom override + +```python +from twinkle.kernel import kernelize, npu_builtin + +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` + +Plain dict merge — later keys override earlier ones. + +### Hub kernel (HF Hub format) + +```python +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul + +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) +``` + +Exactly one of `revision` / `version` must be passed. The `kernels` package is imported lazily; absence raises a clear "install kernels" error. + +### Function-level replacement + +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` + +### Cross-device mapping (NPU enabled, CUDA skipped) + +```python +from twinkle.kernel import kernelize + +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) +``` + +Safe to run on CUDA — entries whose dict misses the current device just skip. 
+ +## NPU built-in coverage + +`npu_builtin(model)` returns a dict that (as available transformers modules permit) covers: + +- RMSNorm class replacement for Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE families +- `apply_rotary_pos_emb` function replacement (fused RoPE) for the same families +- SwiGLU fused replacement for the MLP variants +- `Experts.forward` and `SparseMoeBlock.forward` for Qwen3-MoE / Qwen3.5-MoE +- GatedRMSNorm forward for Qwen3.5 / Qwen3.5-MoE +- `apply_multimodal_rotary_pos_emb` for Qwen2.5-VL +- Global SDPA replacement (one-shot side effect on `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention enablement (one-shot side effect + per-instance traversal, triggered inside `npu_builtin(model)`) + +**Not included by default:** the NPU replacement for `transformers.integrations.moe._grouped_mm`. Without Expert Parallelism the contiguous-copy overhead is ~8x. Opt in explicitly when EP is enabled: + +```python +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} +model = kernelize(model, mapping) +``` + +## Environment variables + +Only two remain: + +- `TWINKLE_NPU_FLA` — Qwen3.5 FLA switch (default on; `0`/`false` to disable) +- `TWINKLE_NPU_GATED_RMSNorm_FP32` — force FP32 in Gated RMSNorm forward (default off) + +The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` are gone — they're now "include the entry in the mapping or don't" decisions. + +## Caveats + +- `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract +- Exact match: `type(m) is target_cls`. 
Subclasses of `target_cls` are not replaced — add them to the mapping yourself +- `kernelize` is idempotent under repeated calls - There is no `unkernelize` — replacement is one-way \ No newline at end of file diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" index a7687020..8ed5ad78 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" @@ -1,137 +1,137 @@ -# Twinkle Kernel 模块 - -`twinkle.kernel` 提供一个 mapping 驱动的内核替换接口,把“用一种实现替换模型里的另一种实现”压缩为一次 `kernelize(model, mapping)` 调用。 - -公开符号只有三个: - -| 符号 | 作用 | -| --- | --- | -| `kernelize(model, mapping)` | 在 `model` 上应用 `mapping`,原地修改后返回 | -| `npu_builtin(model=None)` | 返回 Ascend NPU 内置替换的 mapping dict(可与用户 mapping 自由组合) | -| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | 构造一个 `HubRef`,用作 mapping value;真实下载推迟到 `kernelize` 执行 | - -## Mapping 语义 - -`mapping` 的 **key** 表示要替换的目标: - -- `type[nn.Module]` 子类:替换模型里**所有**该精确类型的实例(`m.__class__ = impl_class`,**不包含**子类) -- `str` 形如 `'pkg.sub.attr'` 或 `'pkg.sub.ClassName.attr'`:`setattr(target, attr, impl)` - -**value** 表示用什么替换: - -- `type[nn.Module]` 子类:直接作为 impl 类。该类**不会被 `__init__` 调用**,必须只依赖原 instance 已经有的 attribute(weight / eps / ...)正确工作 -- `Callable`:直接 `setattr` 上去 -- `dict[str, V]`:device → impl 嵌套分派。从 `model` 推断当前 device,未匹配则**静默跳过** -- `HubRef`:通过 `hub(...)` 构造的 Hub 引用,延迟加载 - -device 从 `next(model.parameters()).device.type` 推断(无参数则用 buffers,再无则为 `'cpu'`)。 - -## 场景示例 - -### 启用全部 NPU 内置优化 - -```python -import torch -from twinkle.kernel import kernelize, npu_builtin - -if torch.npu.is_available(): - model = kernelize(model, npu_builtin(model)) -``` - -### 自定义类替换 - -```python -from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm -from twinkle.kernel import kernelize - -model = kernelize(model, 
{Qwen2RMSNorm: MyRMSNorm}) -``` - -### 内置 + 自定义混合 - -```python -from twinkle.kernel import kernelize, npu_builtin - -model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) -``` - -后写入的 key 会覆盖前面的,普通 dict 合并语义。 - -### Hub Kernel(HF Hub 格式) - -```python -from twinkle.kernel import kernelize, hub -from my_pkg import SiluAndMul - -model = kernelize(model, { - SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), -}) -``` - -`revision` 与 `version` 二选一必传。`hub(...)` 触发 `kernels` 包的延迟 import,未安装时会提示 `pip install kernels`。 - -### 函数级替换 - -```python -from twinkle.kernel import kernelize -from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb - -model = kernelize(model, { - 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': - npu_apply_rotary_pos_emb, -}) -``` - -### 跨设备 mapping(NPU 启用、CUDA 跳过) - -```python -from twinkle.kernel import kernelize - -model = kernelize(model, { - Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, -}) -``` - -在 CUDA 模型上跑也安全:未匹配 device 的 entry 不会替换、不会报错。 - -## 内置 NPU 优化 - -`npu_builtin(model)` 返回的 dict 至少包含以下覆盖(实际条目随 transformers 已安装的 modeling 模块动态收集): - -- Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE 系列的 RMSNorm 类替换 -- 同上系列的 `apply_rotary_pos_emb` 函数替换(融合 RoPE) -- 同上系列 MLP 的 SwiGLU 融合替换 -- Qwen3-MoE / Qwen3.5-MoE 的 `Experts.forward` 与 `SparseMoeBlock.forward` 替换 -- Qwen3.5 / Qwen3.5-MoE 的 GatedRMSNorm forward 替换 -- Qwen2.5-VL 的 `apply_multimodal_rotary_pos_emb` 替换 -- 全局 SDPA 替换(一次性副作用,写入 `ALL_ATTENTION_FUNCTIONS['sdpa']`) -- Qwen3.5 Flash Linear Attention 启用(一次性副作用 + 实例遍历,由 `npu_builtin(model)` 内部触发) - -**未默认包含** `transformers.integrations.moe._grouped_mm` 的 NPU 替换(在没有 Expert Parallelism 时会带来约 8x 开销)。需要时手动加入: - -```python -from twinkle.kernel import kernelize, npu_builtin -from twinkle.kernel.npu_impls.moe import npu_grouped_mm - -mapping = { - **npu_builtin(model), - 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, -} -model = 
kernelize(model, mapping) -``` - -## 环境变量 - -只有两个保留: - -- `TWINKLE_NPU_FLA`:Qwen3.5 FLA 开关(默认开,设为 `0`/`false` 关闭) -- `TWINKLE_NPU_GATED_RMSNorm_FP32`:将 Gated RMSNorm 强制升到 FP32 计算(默认关) - -旧的 `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` 已移除——这些都改成"是否把对应 entry 写进 mapping"的显式选择。 - -## 注意事项 - -- `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 -- 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping -- 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) +# Twinkle Kernel 模块 + +`twinkle.kernel` 提供一个 mapping 驱动的内核替换接口,把“用一种实现替换模型里的另一种实现”压缩为一次 `kernelize(model, mapping)` 调用。 + +公开符号只有三个: + +| 符号 | 作用 | +| --- | --- | +| `kernelize(model, mapping)` | 在 `model` 上应用 `mapping`,原地修改后返回 | +| `npu_builtin(model=None)` | 返回 Ascend NPU 内置替换的 mapping dict(可与用户 mapping 自由组合) | +| `hub(ref, *, revision=None, version=None, backend=None, trust_remote_code=False)` | 构造一个 `HubRef`,用作 mapping value;真实下载推迟到 `kernelize` 执行 | + +## Mapping 语义 + +`mapping` 的 **key** 表示要替换的目标: + +- `type[nn.Module]` 子类:替换模型里**所有**该精确类型的实例(`m.__class__ = impl_class`,**不包含**子类) +- `str` 形如 `'pkg.sub.attr'` 或 `'pkg.sub.ClassName.attr'`:`setattr(target, attr, impl)` + +**value** 表示用什么替换: + +- `type[nn.Module]` 子类:直接作为 impl 类。该类**不会被 `__init__` 调用**,必须只依赖原 instance 已经有的 attribute(weight / eps / ...)正确工作 +- `Callable`:直接 `setattr` 上去 +- `dict[str, V]`:device → impl 嵌套分派。从 `model` 推断当前 device,未匹配则**静默跳过** +- `HubRef`:通过 `hub(...)` 构造的 Hub 引用,延迟加载 + +device 从 `next(model.parameters()).device.type` 推断(无参数则用 buffers,再无则为 `'cpu'`)。 + +## 场景示例 + +### 启用全部 NPU 内置优化 + +```python +import torch +from twinkle.kernel import kernelize, npu_builtin + +if torch.npu.is_available(): + model = kernelize(model, npu_builtin(model)) +``` + +### 自定义类替换 + +```python +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm +from twinkle.kernel import kernelize + +model = 
kernelize(model, {Qwen2RMSNorm: MyRMSNorm}) +``` + +### 内置 + 自定义混合 + +```python +from twinkle.kernel import kernelize, npu_builtin + +model = kernelize(model, {**npu_builtin(model), Qwen2RMSNorm: MyRMSNorm}) +``` + +后写入的 key 会覆盖前面的,普通 dict 合并语义。 + +### Hub Kernel(HF Hub 格式) + +```python +from twinkle.kernel import kernelize, hub +from my_pkg import SiluAndMul + +model = kernelize(model, { + SiluAndMul: hub('kernels-community/activation:SiluAndMul', version=1), +}) +``` + +`revision` 与 `version` 二选一必传。`hub(...)` 触发 `kernels` 包的延迟 import,未安装时会提示 `pip install kernels`。 + +### 函数级替换 + +```python +from twinkle.kernel import kernelize +from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + +model = kernelize(model, { + 'transformers.models.qwen2.modeling_qwen2.apply_rotary_pos_emb': + npu_apply_rotary_pos_emb, +}) +``` + +### 跨设备 mapping(NPU 启用、CUDA 跳过) + +```python +from twinkle.kernel import kernelize + +model = kernelize(model, { + Qwen2RMSNorm: {'npu': NpuRMSNorm, 'cuda': CudaRMSNorm}, +}) +``` + +在 CUDA 模型上跑也安全:未匹配 device 的 entry 不会替换、不会报错。 + +## 内置 NPU 优化 + +`npu_builtin(model)` 返回的 dict 至少包含以下覆盖(实际条目随 transformers 已安装的 modeling 模块动态收集): + +- Qwen2 / Qwen3 / Qwen3-MoE / Qwen2.5-VL / Qwen3.5 / Qwen3.5-MoE 系列的 RMSNorm 类替换 +- 同上系列的 `apply_rotary_pos_emb` 函数替换(融合 RoPE) +- 同上系列 MLP 的 SwiGLU 融合替换 +- Qwen3-MoE / Qwen3.5-MoE 的 `Experts.forward` 与 `SparseMoeBlock.forward` 替换 +- Qwen3.5 / Qwen3.5-MoE 的 GatedRMSNorm forward 替换 +- Qwen2.5-VL 的 `apply_multimodal_rotary_pos_emb` 替换 +- 全局 SDPA 替换(一次性副作用,写入 `ALL_ATTENTION_FUNCTIONS['sdpa']`) +- Qwen3.5 Flash Linear Attention 启用(一次性副作用 + 实例遍历,由 `npu_builtin(model)` 内部触发) + +**未默认包含** `transformers.integrations.moe._grouped_mm` 的 NPU 替换(在没有 Expert Parallelism 时会带来约 8x 开销)。需要时手动加入: + +```python +from twinkle.kernel import kernelize, npu_builtin +from twinkle.kernel.npu_impls.moe import npu_grouped_mm + +mapping = { + **npu_builtin(model), + 'transformers.integrations.moe._grouped_mm': {'npu': npu_grouped_mm}, +} 
+model = kernelize(model, mapping) +``` + +## 环境变量 + +只有两个保留: + +- `TWINKLE_NPU_FLA`:Qwen3.5 FLA 开关(默认开,设为 `0`/`false` 关闭) +- `TWINKLE_NPU_GATED_RMSNorm_FP32`:将 Gated RMSNorm 强制升到 FP32 计算(默认关) + +旧的 `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATCH` / `TWINKLE_USE_KERNELS` 已移除——这些都改成"是否把对应 entry 写进 mapping"的显式选择。 + +## 注意事项 + +- `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 +- 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping +- 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) - 没有 `unkernelize`——替换是单向的 \ No newline at end of file diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index f1de4b75..5d435c0f 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -1,13 +1,13 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Mapping-driven kernel replacement. - -Three public symbols: - -- :func:`kernelize` apply ``mapping`` to a model -- :func:`hub` build a Hub kernel reference -- :func:`npu_builtin` the Ascend NPU built-in bundle -""" -from .builtin import npu_builtin -from .core import hub, kernelize - +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Mapping-driven kernel replacement. + +Three public symbols: + +- :func:`kernelize` apply ``mapping`` to a model +- :func:`hub` build a Hub kernel reference +- :func:`npu_builtin` the Ascend NPU built-in bundle +""" +from .builtin import npu_builtin +from .core import hub, kernelize + __all__ = ['kernelize', 'hub', 'npu_builtin'] \ No newline at end of file diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py index b2376e77..8d7f3b59 100644 --- a/src/twinkle/kernel/builtin.py +++ b/src/twinkle/kernel/builtin.py @@ -1,205 +1,205 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""``npu_builtin()`` returns the bundle of Ascend NPU replacements. 
- -All values are wrapped in ``{'npu': impl}`` so the bundle composes safely on -CUDA/CPU systems — non-NPU devices silently skip every entry. - -GMM is **not** included by default (without EP it causes ~8x slowdown). Opt -in by merging: - - {**npu_builtin(model), 'transformers.integrations.moe._grouped_mm': - {'npu': npu_grouped_mm}} -""" -from __future__ import annotations - -import importlib -from typing import Any - -import torch.nn as nn - -from twinkle import get_logger -from twinkle.utils.device_mesh import Platform - -logger = get_logger() - - -def _import_optional(name: str): - try: - return importlib.import_module(name) - except ImportError: - return None - - -def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: - """Return the NPU builtin mapping; optionally apply per-instance FLA.""" - from .npu_impls.attention import npu_sdpa_attention_forward - from .npu_impls.fla import apply_qwen3_5_fla - from .npu_impls.moe import ( - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) - from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward - from .npu_impls.rotary import ( - npu_apply_multimodal_rotary_pos_emb, - npu_apply_rotary_pos_emb, - ) - from .npu_impls.swiglu import npu_swiglu_forward - - bundle: dict[Any, dict[str, Any]] = {} - - is_npu_platform = Platform.device_prefix() == 'npu' - - # Apply SDPA install eagerly (one-shot module-level mutation) on NPU - # platforms. The NPU impl inverts boolean masks, which is wrong for - # CUDA/CPU execution, so non-NPU platforms must not mutate the global HF - # registry even if ``torch_npu`` is importable in the environment. 
- if is_npu_platform: - _install_sdpa(npu_sdpa_attention_forward) - - # === per-family class + function entries === - _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) - _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) - _add_qwen3_moe_entries( - bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, - npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, - ) - _add_qwen2_5_vl_entries( - bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, - npu_apply_multimodal_rotary_pos_emb, - ) - _add_qwen3_5_entries( - bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, - npu_swiglu_forward, - ) - _add_qwen3_5_moe_entries( - bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, - npu_swiglu_forward, npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) - - # === FLA (side-effect; mapping-incompatible) === - if is_npu_platform: - apply_qwen3_5_fla(model) - - return bundle - - -def _install_sdpa(impl) -> None: - """One-shot install of SDPA attention forward (global modeling_utils dict). - - ``AttentionInterface._global_mapping`` is a private transformers attribute; - guard against its removal so an upstream change can't take down the rest - of ``npu_builtin()``. 
- """ - try: - from transformers.modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - AttentionInterface, - ) - except ImportError: - return - try: - AttentionInterface._global_mapping['sdpa'] = impl - except AttributeError: - logger.warning('[NPU] [SDPA] AttentionInterface._global_mapping unavailable; skipping') - ALL_ATTENTION_FUNCTIONS['sdpa'] = impl - - -# ---- helpers that conditionally add entries based on module availability ---- - -def _add_class_if_present(bundle, module_path, class_name, impl_cls): - mod = _import_optional(module_path) - if mod is None: - return - cls = getattr(mod, class_name, None) - if isinstance(cls, type): - bundle[cls] = {'npu': impl_cls} - - -def _add_swiglu_if_present(bundle, module_path, class_name, fn): - mod = _import_optional(module_path) - if mod is None: - return - cls = getattr(mod, class_name, None) - if isinstance(cls, type): - # Function-level: wrap as string-keyed forward replacement. - # We override on the *class object*, not the module attribute, by - # using a class-key with a synthetic impl wrapping the forward. - # The simplest way is to subclass and reassign __class__, but here - # we follow the legacy approach of overwriting the class's forward: - bundle[f'{module_path}.{class_name}.forward'] = {'npu': fn} - - -def _add_attr_if_present(bundle, module_path, attr_name, impl): - mod = _import_optional(module_path) - if mod is None: - return - if '.' in attr_name: - # Dotted attr like 'Qwen3MoeExperts.forward': resolve the class on - # the module, then check the trailing member on the class. - head, _, tail = attr_name.partition('.') - owner = getattr(mod, head, None) - if owner is None or not hasattr(owner, tail): - return - else: - if not hasattr(mod, attr_name): - return - bundle[f'{module_path}.{attr_name}'] = {'npu': impl} - - -def _add_qwen2_entries(bundle, rms_cls, rope_fn, swiglu_fn): - # Qwen2 (used by Qwen2.5-VL etc. 
via inheritance) - _add_class_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2RMSNorm', rms_cls) - _add_attr_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2MLP', swiglu_fn) - - -def _add_qwen3_entries(bundle, rms_cls, rope_fn, swiglu_fn): - base = 'transformers.models.qwen3.modeling_qwen3' - _add_class_if_present(bundle, base, 'Qwen3RMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3MLP', swiglu_fn) - - -def _add_qwen3_moe_entries(bundle, rms_cls, rope_fn, swiglu_fn, experts_fn, sparse_fn): - base = 'transformers.models.qwen3_moe.modeling_qwen3_moe' - _add_class_if_present(bundle, base, 'Qwen3MoeRMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3MoeMLP', swiglu_fn) - _add_attr_if_present(bundle, base, 'Qwen3MoeExperts.forward', experts_fn) - _add_attr_if_present(bundle, base, 'Qwen3MoeSparseMoeBlock.forward', sparse_fn) - - -def _add_qwen2_5_vl_entries(bundle, rms_cls, rope_fn, swiglu_fn, multimodal_rope_fn): - base = 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl' - _add_class_if_present(bundle, base, 'Qwen2_5_VLRMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_attr_if_present(bundle, base, 'apply_multimodal_rotary_pos_emb', multimodal_rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen2MLP', swiglu_fn) - _add_swiglu_if_present(bundle, base, 'Qwen2_5_VLMLP', swiglu_fn) - - -def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): - base = 'transformers.models.qwen3_5.modeling_qwen3_5' - if _import_optional(base) is None: - return - _add_class_if_present(bundle, base, 'Qwen3_5RMSNorm', rms_cls) - _add_class_if_present(bundle, base, 'Qwen3_5VisionRMSNorm', rms_cls) - 
_add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3_5MLP', swiglu_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3_5VisionMLP', swiglu_fn) - # Qwen3_5GatedRMSNorm: forward-level replacement - _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) - - -def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, - experts_fn, sparse_fn): - base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' - if _import_optional(base) is None: - return - _add_class_if_present(bundle, base, 'Qwen3_5MoeRMSNorm', rms_cls) - _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) - _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) - _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""``npu_builtin()`` returns the bundle of Ascend NPU replacements. + +All values are wrapped in ``{'npu': impl}`` so the bundle composes safely on +CUDA/CPU systems — non-NPU devices silently skip every entry. + +GMM is **not** included by default (without EP it causes ~8x slowdown). 
Opt +in by merging: + + {**npu_builtin(model), 'transformers.integrations.moe._grouped_mm': + {'npu': npu_grouped_mm}} +""" +from __future__ import annotations + +import importlib +from typing import Any + +import torch.nn as nn + +from twinkle import get_logger +from twinkle.utils.device_mesh import Platform + +logger = get_logger() + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: + """Return the NPU builtin mapping; optionally apply per-instance FLA.""" + from .npu_impls.attention import npu_sdpa_attention_forward + from .npu_impls.fla import apply_qwen3_5_fla + from .npu_impls.moe import ( + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + from .npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + from .npu_impls.swiglu import npu_swiglu_forward + + bundle: dict[Any, dict[str, Any]] = {} + + is_npu_platform = Platform.device_prefix() == 'npu' + + # Apply SDPA install eagerly (one-shot module-level mutation) on NPU + # platforms. The NPU impl inverts boolean masks, which is wrong for + # CUDA/CPU execution, so non-NPU platforms must not mutate the global HF + # registry even if ``torch_npu`` is importable in the environment. 
+ if is_npu_platform: + _install_sdpa(npu_sdpa_attention_forward) + + # === per-family class + function entries === + _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) + _add_qwen3_moe_entries( + bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, + ) + _add_qwen2_5_vl_entries( + bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + npu_apply_multimodal_rotary_pos_emb, + ) + _add_qwen3_5_entries( + bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + npu_swiglu_forward, + ) + _add_qwen3_5_moe_entries( + bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + npu_swiglu_forward, npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + + # === FLA (side-effect; mapping-incompatible) === + if is_npu_platform: + apply_qwen3_5_fla(model) + + return bundle + + +def _install_sdpa(impl) -> None: + """One-shot install of SDPA attention forward (global modeling_utils dict). + + ``AttentionInterface._global_mapping`` is a private transformers attribute; + guard against its removal so an upstream change can't take down the rest + of ``npu_builtin()``. 
+ """ + try: + from transformers.modeling_utils import ( + ALL_ATTENTION_FUNCTIONS, + AttentionInterface, + ) + except ImportError: + return + try: + AttentionInterface._global_mapping['sdpa'] = impl + except AttributeError: + logger.warning('[NPU] [SDPA] AttentionInterface._global_mapping unavailable; skipping') + ALL_ATTENTION_FUNCTIONS['sdpa'] = impl + + +# ---- helpers that conditionally add entries based on module availability ---- + +def _add_class_if_present(bundle, module_path, class_name, impl_cls): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + bundle[cls] = {'npu': impl_cls} + + +def _add_swiglu_if_present(bundle, module_path, class_name, fn): + mod = _import_optional(module_path) + if mod is None: + return + cls = getattr(mod, class_name, None) + if isinstance(cls, type): + # Function-level: wrap as string-keyed forward replacement. + # We override on the *class object*, not the module attribute, by + # using a class-key with a synthetic impl wrapping the forward. + # The simplest way is to subclass and reassign __class__, but here + # we follow the legacy approach of overwriting the class's forward: + bundle[f'{module_path}.{class_name}.forward'] = {'npu': fn} + + +def _add_attr_if_present(bundle, module_path, attr_name, impl): + mod = _import_optional(module_path) + if mod is None: + return + if '.' in attr_name: + # Dotted attr like 'Qwen3MoeExperts.forward': resolve the class on + # the module, then check the trailing member on the class. + head, _, tail = attr_name.partition('.') + owner = getattr(mod, head, None) + if owner is None or not hasattr(owner, tail): + return + else: + if not hasattr(mod, attr_name): + return + bundle[f'{module_path}.{attr_name}'] = {'npu': impl} + + +def _add_qwen2_entries(bundle, rms_cls, rope_fn, swiglu_fn): + # Qwen2 (used by Qwen2.5-VL etc. 
via inheritance) + _add_class_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2RMSNorm', rms_cls) + _add_attr_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, 'transformers.models.qwen2.modeling_qwen2', 'Qwen2MLP', swiglu_fn) + + +def _add_qwen3_entries(bundle, rms_cls, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3.modeling_qwen3' + _add_class_if_present(bundle, base, 'Qwen3RMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MLP', swiglu_fn) + + +def _add_qwen3_moe_entries(bundle, rms_cls, rope_fn, swiglu_fn, experts_fn, sparse_fn): + base = 'transformers.models.qwen3_moe.modeling_qwen3_moe' + _add_class_if_present(bundle, base, 'Qwen3MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3MoeSparseMoeBlock.forward', sparse_fn) + + +def _add_qwen2_5_vl_entries(bundle, rms_cls, rope_fn, swiglu_fn, multimodal_rope_fn): + base = 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl' + _add_class_if_present(bundle, base, 'Qwen2_5_VLRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_attr_if_present(bundle, base, 'apply_multimodal_rotary_pos_emb', multimodal_rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen2_5_VLMLP', swiglu_fn) + + +def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): + base = 'transformers.models.qwen3_5.modeling_qwen3_5' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5RMSNorm', rms_cls) + _add_class_if_present(bundle, base, 'Qwen3_5VisionRMSNorm', rms_cls) + 
_add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MLP', swiglu_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5VisionMLP', swiglu_fn) + # Qwen3_5GatedRMSNorm: forward-level replacement + _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) + + +def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, + experts_fn, sparse_fn): + base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' + if _import_optional(base) is None: + return + _add_class_if_present(bundle, base, 'Qwen3_5MoeRMSNorm', rms_cls) + _add_attr_if_present(bundle, base, 'apply_rotary_pos_emb', rope_fn) + _add_swiglu_if_present(bundle, base, 'Qwen3_5MoeMLP', swiglu_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeExperts.forward', experts_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeSparseMoeBlock.forward', sparse_fn) + _add_attr_if_present(bundle, base, 'Qwen3_5MoeGatedRMSNorm.forward', gated_rms_fn) diff --git a/src/twinkle/kernel/chunk_gated_delta_rule.py b/src/twinkle/kernel/chunk_gated_delta_rule.py index 09defec6..2d0beee7 100644 --- a/src/twinkle/kernel/chunk_gated_delta_rule.py +++ b/src/twinkle/kernel/chunk_gated_delta_rule.py @@ -1,362 +1,362 @@ -'''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). -This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, -redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. 
-It is consumed by twinkle.kernel.npu_impls.fla to enable the fast linear-attention -path of Qwen3.5 on Ascend hardware.''' - -import torch -import warnings -from mindspeed.lite.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h -from mindspeed.lite.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o -from mindspeed.lite.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from mindspeed.lite.ops.triton.cumsum import chunk_local_cumsum -from mindspeed.lite.ops.triton.solve_tril import solve_tril -from mindspeed.lite.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard -from mindspeed.lite.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd -from typing import Optional - - -def _torch_l2norm_fwd( - x: torch.Tensor, - eps: float = 1e-6, - output_dtype: Optional[torch.dtype] = None, -): - x_shape_og = x.shape - x = x.view(-1, x.shape[-1]) - x_float = x.float() - rstd = torch.rsqrt(torch.sum(x_float * x_float, dim=-1) + eps) - y = x_float * rstd.unsqueeze(-1) - y = y.to(output_dtype if output_dtype is not None else x.dtype) - return y.view(x_shape_og), rstd.view(x_shape_og[:-1]) - - -def _torch_l2norm_bwd( - y: torch.Tensor, - rstd: torch.Tensor, - dy: torch.Tensor, - eps: float = 1e-6, -): - y_shape_og = y.shape - y = y.view(-1, y.shape[-1]) - dy = dy.view(-1, dy.shape[-1]) - y_float = y.float() - dy_float = dy.float() - rstd = rstd.view(-1).float() - dx = dy_float * rstd.unsqueeze(-1) - dx = dx - torch.sum(dy_float * y_float, dim=-1, keepdim=True) * y_float * rstd.unsqueeze(-1) - return dx.to(y.dtype).view(y_shape_og) - - -def chunk_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, -): - g = chunk_local_cumsum(g, chunk_size=chunk_size, 
cu_seqlens=cu_seqlens, head_first=False) - # obtain WY representation. u is actually the new v. - A = chunk_scaled_dot_kkt_fwd( - k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) - w, u = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - g=g, - cu_seqlens=cu_seqlens, - ) - h, v_new, final_state = chunk_gated_delta_rule_fwd_h( - k=k, - w=w, - u=u, - g=g, - initial_state=initial_state, - output_final_state=output_final_state, - chunk_size=chunk_size, - cu_seqlens=cu_seqlens, - ) - o = chunk_fwd_o( - q=q, - k=k, - v=v_new, - h=h, - g=g, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - return g, o, A, final_state - - -def chunk_gated_delta_rule_bwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - A: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - do: torch.Tensor, - dht: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, -): - w, u = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - g=g, - cu_seqlens=cu_seqlens, - ) - h, v_new, _ = chunk_gated_delta_rule_fwd_h( - k=k, - w=w, - u=u, - g=g, - initial_state=initial_state, - output_final_state=False, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - dv = chunk_bwd_dv_local( - q=q, - k=k, - g=g, - do=do, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( - q=q, - k=k, - w=w, - g=g, - h0=initial_state, - dht=dht, - do=do, - dv=dv, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - ) - dq, dk, dw, dg = chunk_bwd_dqkwg( - q=q, - k=k, - v=v_new, - w=w, - g=g, - h=h, - dv=dv, - do=do, - dh=dh, - chunk_size=chunk_size, - scale=scale, - cu_seqlens=cu_seqlens, - ) - dk2, dv, db, dg2 = prepare_wy_repr_bwd( - k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size) - 
dk.add_(dk2) - dg.add_(dg2) - if dg.dtype != torch.float32: - raise ValueError(f'dg current type is {dg.dtype} , should be float32') - dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False) - return dq, dk, dv, db, dg, dh0 - - -class ChunkGatedDeltaRuleFunction(torch.autograd.Function): - - @staticmethod - @input_guard - @autocast_custom_fwd - def forward( - ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - use_qk_l2norm_in_kernel: bool = False, - chunk_size: int = 64, - ): - if use_qk_l2norm_in_kernel: - q, q_rstd = _torch_l2norm_fwd(q) - k, k_rstd = _torch_l2norm_fwd(k) - else: - q_rstd, k_rstd = None, None - - g, o, A, final_state = chunk_gated_delta_rule_fwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - initial_state=initial_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - chunk_size=chunk_size) - ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens) - ctx.scale = scale - ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel - ctx.chunk_size = chunk_size - return o.to(q.dtype), final_state - - @staticmethod - @input_guard - @autocast_custom_bwd - def backward(ctx, do: torch.Tensor, dht: torch.Tensor): - q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors - dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - A=A, - scale=ctx.scale, - initial_state=initial_state, - do=do, - dht=dht, - cu_seqlens=cu_seqlens, - chunk_size=ctx.chunk_size, - ) - if ctx.use_qk_l2norm_in_kernel: - dq = _torch_l2norm_bwd(q, q_rstd, dq) - dk = _torch_l2norm_bwd(k, k_rstd, dk) - return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None - - -@torch.compiler.disable -def 
chunk_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float = None, - initial_state: torch.Tensor = None, - output_final_state: bool = False, - use_qk_l2norm_in_kernel: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, - head_first: bool = False, -): - r""" - Args: - q (torch.Tensor): - queries of shape `[B, T, H, K]`. - k (torch.Tensor): - keys of shape `[B, T, H, K]`. - v (torch.Tensor): - values of shape `[B, T, H, V]`. - g (torch.Tensor): - (forget) gating tensor (in log space!) of shape `[B, T, H]`. - beta (torch.Tensor): - betas of shape `[B, T, H]`. - scale (Optional[float]): - Scale factor for the RetNet attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, H, K, V]` for `N` input sequences. - For equal-length input sequences, `N` equals the batch size `B`. - Default: `None`. - output_final_state (Optional[bool]): - Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. - use_qk_l2norm_in_kernel (bool): - Whether to apply L2norm to the q/k tensor internally. Default: `False`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. - head_first (Optional[bool]): - Whether the inputs are in the head-first format. Default: `False`. - This argument has been deprecated. - - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, H, V]`. - final_state (torch.Tensor): - Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
- - Examples:: - >>> import torch - >>> import torch.nn.functional as F - >>> from einops import rearrange - >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule - # inputs with equal lengths - >>> B, T, H, K, V = 4, 2048, 4, 512, 512 - >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') - >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) - >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') - >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() - >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) - >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') - >>> o, ht = chunk_gated_delta_rule( - q, k, v, g, beta, - initial_state=h0, - output_final_state=True - ) - # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required - >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) - # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o, ht = chunk_gated_delta_rule( - q, k, v, g, beta, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens - ) - """ - if q.dtype != k.dtype or k.dtype != v.dtype: - raise ValueError( - f'q current type is {q.dtype}, k current type is {k.dtype}, v current type is {v.dtype}, should be equal') - if q.dtype == torch.float32: - raise ValueError('ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.') - if len(beta.shape) != 3: - raise ValueError(f'beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] ' - f'if head_first=False, or [B, H, T] otherwise.') - if head_first: - warnings.warn('head_first is deprecated and will be removed in a future version. 
' - 'Please use head_first=False for now instead.') - if not head_first and q.shape[1] < q.shape[2]: - warnings.warn( - f'Input tensor shape suggests format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). ' - 'This may indicate the inputs were passed in head-first format [B, H, T, ...] ' - 'when head_first=False was specified. ' - 'Please verify your input tensor format matches the expected shape [B, T, H, ...].') - if cu_seqlens is not None: - if q.shape[0] != 1: - raise ValueError(f'The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`.' - f'Please flatten variable-length inputs before processing.') - if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError(f'The number of initial states is expected to be equal to the number of input sequences, ' - f'i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.') - if scale is None: - scale = k.shape[-1]**-0.5 - o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, - k, - v, - g, - beta, - scale, - initial_state, - output_final_state, - cu_seqlens, - use_qk_l2norm_in_kernel, - chunk_size, - ) - return o, final_state +'''Ascend NPU implementation of chunk_gated_delta_rule for Flash Linear Attention (FLA). +This module provides a drop-in replacement for fla.ops.gated_delta_rule.chunk_gated_delta_rule, +redirecting the underlying Triton kernels to MindSpeed's NPU-compatible counterparts. 
+It is consumed by twinkle.kernel.npu_impls.fla to enable the fast linear-attention +path of Qwen3.5 on Ascend hardware.''' + +import torch +import warnings +from mindspeed.lite.ops.triton.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from mindspeed.lite.ops.triton.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from mindspeed.lite.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from mindspeed.lite.ops.triton.cumsum import chunk_local_cumsum +from mindspeed.lite.ops.triton.solve_tril import solve_tril +from mindspeed.lite.ops.triton.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from mindspeed.lite.ops.triton.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from typing import Optional + + +def _torch_l2norm_fwd( + x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None, +): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + x_float = x.float() + rstd = torch.rsqrt(torch.sum(x_float * x_float, dim=-1) + eps) + y = x_float * rstd.unsqueeze(-1) + y = y.to(output_dtype if output_dtype is not None else x.dtype) + return y.view(x_shape_og), rstd.view(x_shape_og[:-1]) + + +def _torch_l2norm_bwd( + y: torch.Tensor, + rstd: torch.Tensor, + dy: torch.Tensor, + eps: float = 1e-6, +): + y_shape_og = y.shape + y = y.view(-1, y.shape[-1]) + dy = dy.view(-1, dy.shape[-1]) + y_float = y.float() + dy_float = dy.float() + rstd = rstd.view(-1).float() + dx = dy_float * rstd.unsqueeze(-1) + dx = dx - torch.sum(dy_float * y_float, dim=-1, keepdim=True) * y_float * rstd.unsqueeze(-1) + return dx.to(y.dtype).view(y_shape_og) + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + g = chunk_local_cumsum(g, chunk_size=chunk_size, 
cu_seqlens=cu_seqlens, head_first=False) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd( + k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + return g, o, A, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + chunk_size=chunk_size, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dk2, dv, db, dg2 = prepare_wy_repr_bwd( + k=k, v=v, beta=beta, g=g, A=A, dw=dw, du=dv, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + 
dk.add_(dk2) + dg.add_(dg2) + if dg.dtype != torch.float32: + raise ValueError(f'dg current type is {dg.dtype} , should be float32') + dg = chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens, head_first=False) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + chunk_size: int = 64, + ): + if use_qk_l2norm_in_kernel: + q, q_rstd = _torch_l2norm_fwd(q) + k, k_rstd = _torch_l2norm_fwd(k) + else: + q_rstd, k_rstd = None, None + + g, o, A, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size) + ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.chunk_size = chunk_size + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do: torch.Tensor, dht: torch.Tensor): + q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_size=ctx.chunk_size, + ) + if ctx.use_qk_l2norm_in_kernel: + dq = _torch_l2norm_bwd(q, q_rstd, dq) + dk = _torch_l2norm_bwd(k, k_rstd, dk) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def 
chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + head_first: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q/k tensor internally. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + This argument has been deprecated. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + if q.dtype != k.dtype or k.dtype != v.dtype: + raise ValueError( + f'q current type is {q.dtype}, k current type is {k.dtype}, v current type is {v.dtype}, should be equal') + if q.dtype == torch.float32: + raise ValueError('ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.') + if len(beta.shape) != 3: + raise ValueError(f'beta current shape len is {len(beta.shape)}, beta must be of shape [B, T, H] ' + f'if head_first=False, or [B, H, T] otherwise.') + if head_first: + warnings.warn('head_first is deprecated and will be removed in a future version. 
' + 'Please use head_first=False for now instead.') + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f'Input tensor shape suggests format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). ' + 'This may indicate the inputs were passed in head-first format [B, H, T, ...] ' + 'when head_first=False was specified. ' + 'Please verify your input tensor format matches the expected shape [B, T, H, ...].') + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f'The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`.' + f'Please flatten variable-length inputs before processing.') + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f'The number of initial states is expected to be equal to the number of input sequences, ' + f'i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.') + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + chunk_size, + ) + return o, final_state diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index 03cace34..a3a12f18 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -1,171 +1,171 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Minimal mapping-driven kernel replacement. - -Public API: ``kernelize``, ``hub`` (re-exported from ``twinkle.kernel``). -""" -from __future__ import annotations - -import importlib -from dataclasses import dataclass -from typing import Any - -import torch.nn as nn - -from twinkle.utils.device_mesh import Platform - - -@dataclass(frozen=True) -class HubRef: - """Lightweight reference to a HuggingFace Hub kernel layer. - - Resolved lazily by ``kernelize`` via the optional ``kernels`` package. 
- """ - repo_id: str - layer_name: str - revision: str | None = None - version: int | None = None - backend: str | None = None - trust_remote_code: bool = False - - -def hub( - ref: str, - *, - revision: str | None = None, - version: int | None = None, - backend: str | None = None, - trust_remote_code: bool = False, -) -> HubRef: - """Build a ``HubRef`` for use as a ``kernelize`` mapping value. - - ``ref`` is ``':'`` (e.g. ``'org/repo:SiluAndMul'``). - Exactly one of ``revision`` or ``version`` must be supplied. - """ - if (revision is None) == (version is None): - raise ValueError('Exactly one of `revision` or `version` must be specified.') - if ':' not in ref: - raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") - repo_id, layer_name = ref.rsplit(':', 1) - return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) - - - -def _resolve_value(value: Any, device: str) -> Any | None: - """Resolve a mapping value against the selected device. - - - ``dict``: device-conditional; recurse into ``value[device]`` or return None. - - anything else (including ``HubRef``): pass through. - """ - if isinstance(value, dict): - if device not in value: - return None - return _resolve_value(value[device], device) - return value - - -def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: - """Rewrite ``__class__`` of every module whose exact type is ``target_cls``. - - Uses ``type(m) is target_cls`` (not ``isinstance``) so user-defined - subclasses of ``target_cls`` are deliberately left alone. - """ - for m in model.modules(): - if type(m) is target_cls: - m.__class__ = impl_cls - - -def _replace_attr(dotted_path: str, impl) -> None: - """``setattr`` ``impl`` onto the attribute identified by the dotted path. 
- - Supports two forms: - - ``pkg.mod.attr`` (set module attribute) - - ``pkg.mod.ClassName.attr`` (set class attribute / method) - - The split is found by walking the prefix from the longest importable - module backwards until ``importlib.import_module`` succeeds. - """ - parts = dotted_path.split('.') - if len(parts) < 2: - raise ValueError(f"Expected at least 'pkg.attr', got: {dotted_path!r}") - - # Find the longest prefix that imports as a module. - last_err: ImportError | None = None - module = None - module_depth = 0 - for i in range(len(parts) - 1, 0, -1): - candidate = '.'.join(parts[:i]) - try: - module = importlib.import_module(candidate) - module_depth = i - break - except ImportError as e: - last_err = e - continue - if module is None: - raise ImportError(f'Could not import any prefix of {dotted_path!r}') from last_err - - # Walk remaining attributes; the last one is the target. - obj = module - for attr in parts[module_depth:-1]: - obj = getattr(obj, attr) - setattr(obj, parts[-1], impl) - - -def _load_hub_ref(ref: HubRef): - """Lazy-load a Hub kernel layer via the optional ``kernels`` package.""" - try: - from kernels import get_kernel - except ImportError as e: - raise ImportError( - 'Loading a Hub kernel requires the `kernels` package. ' - 'Install it with `pip install kernels`.' - ) from e - - kernel = get_kernel( - ref.repo_id, - revision=ref.revision, - version=ref.version, - backend=ref.backend, - trust_remote_code=ref.trust_remote_code, - ) - layers = getattr(kernel, 'layers', None) - if layers is None: - raise ValueError(f'Hub repo {ref.repo_id!r} does not define any layers.') - impl = getattr(layers, ref.layer_name, None) - if impl is None: - raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') - return impl - - -def kernelize(model: nn.Module, mapping: dict) -> nn.Module: - """Apply ``mapping`` to ``model`` and return it (modified in place). 
- - Keys: - - ``type[nn.Module]``: replace ``m.__class__`` for every module of the - exact type (no subclass walking). - - ``str`` (dotted path ``pkg.mod.attr``): ``setattr`` the impl onto the - identified module attribute. - - Values: - - ``dict[str, V]``: device-conditional dispatch using the current - Twinkle platform device prefix; non-matching devices skip. - - ``HubRef``: lazy-resolved via the optional ``kernels`` package. - - anything else: used directly as the impl. - """ - if not mapping: - return model - - device = Platform.device_prefix() - for key, value in mapping.items(): - impl = _resolve_value(value, device) - if impl is None: - continue - if isinstance(impl, HubRef): - impl = _load_hub_ref(impl) - if isinstance(key, type) and issubclass(key, nn.Module): - _replace_class(model, key, impl) - elif isinstance(key, str): - _replace_attr(key, impl) - else: - raise TypeError(f'Unsupported mapping key: {key!r}') - return model +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Minimal mapping-driven kernel replacement. + +Public API: ``kernelize``, ``hub`` (re-exported from ``twinkle.kernel``). +""" +from __future__ import annotations + +import importlib +from dataclasses import dataclass +from typing import Any + +import torch.nn as nn + +from twinkle.utils.device_mesh import Platform + + +@dataclass(frozen=True) +class HubRef: + """Lightweight reference to a HuggingFace Hub kernel layer. + + Resolved lazily by ``kernelize`` via the optional ``kernels`` package. + """ + repo_id: str + layer_name: str + revision: str | None = None + version: int | None = None + backend: str | None = None + trust_remote_code: bool = False + + +def hub( + ref: str, + *, + revision: str | None = None, + version: int | None = None, + backend: str | None = None, + trust_remote_code: bool = False, +) -> HubRef: + """Build a ``HubRef`` for use as a ``kernelize`` mapping value. + + ``ref`` is ``':'`` (e.g. ``'org/repo:SiluAndMul'``). 
+ Exactly one of ``revision`` or ``version`` must be supplied. + """ + if (revision is None) == (version is None): + raise ValueError('Exactly one of `revision` or `version` must be specified.') + if ':' not in ref: + raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") + repo_id, layer_name = ref.rsplit(':', 1) + return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) + + + +def _resolve_value(value: Any, device: str) -> Any | None: + """Resolve a mapping value against the selected device. + + - ``dict``: device-conditional; recurse into ``value[device]`` or return None. + - anything else (including ``HubRef``): pass through. + """ + if isinstance(value, dict): + if device not in value: + return None + return _resolve_value(value[device], device) + return value + + +def _replace_class(model: nn.Module, target_cls: type, impl_cls: type) -> None: + """Rewrite ``__class__`` of every module whose exact type is ``target_cls``. + + Uses ``type(m) is target_cls`` (not ``isinstance``) so user-defined + subclasses of ``target_cls`` are deliberately left alone. + """ + for m in model.modules(): + if type(m) is target_cls: + m.__class__ = impl_cls + + +def _replace_attr(dotted_path: str, impl) -> None: + """``setattr`` ``impl`` onto the attribute identified by the dotted path. + + Supports two forms: + - ``pkg.mod.attr`` (set module attribute) + - ``pkg.mod.ClassName.attr`` (set class attribute / method) + + The split is found by walking the prefix from the longest importable + module backwards until ``importlib.import_module`` succeeds. + """ + parts = dotted_path.split('.') + if len(parts) < 2: + raise ValueError(f"Expected at least 'pkg.attr', got: {dotted_path!r}") + + # Find the longest prefix that imports as a module. 
+ last_err: ImportError | None = None + module = None + module_depth = 0 + for i in range(len(parts) - 1, 0, -1): + candidate = '.'.join(parts[:i]) + try: + module = importlib.import_module(candidate) + module_depth = i + break + except ImportError as e: + last_err = e + continue + if module is None: + raise ImportError(f'Could not import any prefix of {dotted_path!r}') from last_err + + # Walk remaining attributes; the last one is the target. + obj = module + for attr in parts[module_depth:-1]: + obj = getattr(obj, attr) + setattr(obj, parts[-1], impl) + + +def _load_hub_ref(ref: HubRef): + """Lazy-load a Hub kernel layer via the optional ``kernels`` package.""" + try: + from kernels import get_kernel + except ImportError as e: + raise ImportError( + 'Loading a Hub kernel requires the `kernels` package. ' + 'Install it with `pip install kernels`.' + ) from e + + kernel = get_kernel( + ref.repo_id, + revision=ref.revision, + version=ref.version, + backend=ref.backend, + trust_remote_code=ref.trust_remote_code, + ) + layers = getattr(kernel, 'layers', None) + if layers is None: + raise ValueError(f'Hub repo {ref.repo_id!r} does not define any layers.') + impl = getattr(layers, ref.layer_name, None) + if impl is None: + raise ValueError(f'Layer {ref.layer_name!r} not found in {ref.repo_id!r}.') + return impl + + +def kernelize(model: nn.Module, mapping: dict) -> nn.Module: + """Apply ``mapping`` to ``model`` and return it (modified in place). + + Keys: + - ``type[nn.Module]``: replace ``m.__class__`` for every module of the + exact type (no subclass walking). + - ``str`` (dotted path ``pkg.mod.attr``): ``setattr`` the impl onto the + identified module attribute. + + Values: + - ``dict[str, V]``: device-conditional dispatch using the current + Twinkle platform device prefix; non-matching devices skip. + - ``HubRef``: lazy-resolved via the optional ``kernels`` package. + - anything else: used directly as the impl. 
+ """ + if not mapping: + return model + + device = Platform.device_prefix() + for key, value in mapping.items(): + impl = _resolve_value(value, device) + if impl is None: + continue + if isinstance(impl, HubRef): + impl = _load_hub_ref(impl) + if isinstance(key, type) and issubclass(key, nn.Module): + _replace_class(model, key, impl) + elif isinstance(key, str): + _replace_attr(key, impl) + else: + raise TypeError(f'Unsupported mapping key: {key!r}') + return model diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index 71606aaa..31d77581 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -1,32 +1,32 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Per-layer NPU implementations consumed by ``npu_builtin()``. - -Each impl is contracted to be applied via ``m.__class__ = ImplCls`` (class -replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl -here is meant to be instantiated directly. -""" -from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward -from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb -from .swiglu import npu_swiglu_forward -from .attention import npu_sdpa_attention_forward -from .moe import ( - GmmFunction, - npu_grouped_mm, - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, -) -from .fla import apply_qwen3_5_fla - -__all__ = [ - 'NpuRMSNorm', - 'npu_gated_rms_norm_forward', - 'npu_apply_rotary_pos_emb', - 'npu_apply_multimodal_rotary_pos_emb', - 'npu_swiglu_forward', - 'npu_sdpa_attention_forward', - 'GmmFunction', - 'npu_grouped_mm', - 'npu_packed_moe_experts_forward', - 'npu_qwen3_5_moe_sparse_block_forward', - 'apply_qwen3_5_fla', +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Per-layer NPU implementations consumed by ``npu_builtin()``. 
+ +Each impl is contracted to be applied via ``m.__class__ = ImplCls`` (class +replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl +here is meant to be instantiated directly. +""" +from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward +from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb +from .swiglu import npu_swiglu_forward +from .attention import npu_sdpa_attention_forward +from .moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, +) +from .fla import apply_qwen3_5_fla + +__all__ = [ + 'NpuRMSNorm', + 'npu_gated_rms_norm_forward', + 'npu_apply_rotary_pos_emb', + 'npu_apply_multimodal_rotary_pos_emb', + 'npu_swiglu_forward', + 'npu_sdpa_attention_forward', + 'GmmFunction', + 'npu_grouped_mm', + 'npu_packed_moe_experts_forward', + 'npu_qwen3_5_moe_sparse_block_forward', + 'apply_qwen3_5_fla', ] \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/attention.py b/src/twinkle/kernel/npu_impls/attention.py index 2bf4255b..f328b2d5 100644 --- a/src/twinkle/kernel/npu_impls/attention.py +++ b/src/twinkle/kernel/npu_impls/attention.py @@ -1,54 +1,54 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""SDPA forward with Ascend NPU compatibility fixes.""" -from __future__ import annotations - -import torch - - -def npu_sdpa_attention_forward( - module, - query, - key, - value, - attention_mask, - dropout=0.0, - scaling=None, - is_causal=None, - **kwargs, -): - """Drop-in replacement for ``transformers.integrations.sdpa_attention.sdpa_attention_forward``. - - Fixes: - - Repeats KV heads (NPU SDPA does not auto-broadcast num_kv_groups). - - Truncates causal_mask to key length. - - Forces contiguous tensors (NPU SDPA requirement). - - Inverts boolean masks (NPU treats ``True`` as masked). 
- """ - from transformers.integrations.sdpa_attention import repeat_kv - - if hasattr(module, 'num_key_value_groups'): - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, :key.shape[-2]] - - query, key, value = query.contiguous(), key.contiguous(), value.contiguous() - - if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None - - if causal_mask is not None and causal_mask.dtype != torch.bool: - causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=dropout, - scale=scaling, - is_causal=is_causal, - ) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""SDPA forward with Ascend NPU compatibility fixes.""" +from __future__ import annotations + +import torch + + +def npu_sdpa_attention_forward( + module, + query, + key, + value, + attention_mask, + dropout=0.0, + scaling=None, + is_causal=None, + **kwargs, +): + """Drop-in replacement for ``transformers.integrations.sdpa_attention.sdpa_attention_forward``. + + Fixes: + - Repeats KV heads (NPU SDPA does not auto-broadcast num_kv_groups). + - Truncates causal_mask to key length. + - Forces contiguous tensors (NPU SDPA requirement). + - Inverts boolean masks (NPU treats ``True`` as masked). 
+ """ + from transformers.integrations.sdpa_attention import repeat_kv + + if hasattr(module, 'num_key_value_groups'): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, :key.shape[-2]] + + query, key, value = query.contiguous(), key.contiguous(), value.contiguous() + + if is_causal is None: + is_causal = query.shape[2] > 1 and causal_mask is None + + if causal_mask is not None and causal_mask.dtype != torch.bool: + causal_mask = torch.logical_not(causal_mask.bool()).to(query.device) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) return attn_output.transpose(1, 2).contiguous(), None \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py index 9a387f77..d2fc43a9 100644 --- a/src/twinkle/kernel/npu_impls/fla.py +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -1,103 +1,103 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Qwen3.5 Flash Linear Attention enablement for Ascend NPU.""" -from __future__ import annotations - -import importlib -import os - -from twinkle import get_logger - -logger = get_logger() - - -def _is_env_enabled(var: str, default: bool = True) -> bool: - env = os.environ.get(var, '').lower().strip() - if not env: - return default - if env in ('1', 'true', 'on', 'yes'): - return True - if env in ('0', 'false', 'off', 'no'): - return False - return default - - -def _import_optional(name: str): - try: - return importlib.import_module(name) - except ImportError: - return None - - -def apply_qwen3_5_fla(model=None) -> int: - """Enable Flash Linear Attention fast path for Qwen3.5 on NPU. 
- - Returns the count of patched per-layer instances (0 when disabled or when - prerequisites are missing). Safe to call multiple times. - """ - if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): - logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA') - return 0 - - if _import_optional('torch_npu') is None: - logger.info('[NPU] [FLA] Skip: torch_npu unavailable') - return 0 - - # 1. Confirm the MindSpeed Triton kernel is actually importable BEFORE - # flipping any global availability flags. If we flip the flag and then - # fail to install the kernel, HF transformers would route Qwen3.5 onto - # a FLA fast path whose kernel is missing -> runtime failure on NPU. - try: - from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla - from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn - except ImportError as exc: - logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) - return 0 - - # 2. Only now can we safely claim FLA is available: flip the global flags - # and install the kernel path on Qwen3.5 modeling modules. - def _is_fla_available() -> bool: - return True - - for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): - utils_mod = _import_optional(utils_mod_name) - if utils_mod is not None: - setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) - - # 3. Patch Qwen3.5 modeling modules - fla_target_modules = [ - 'transformers.models.qwen3_5.modeling_qwen3_5', - 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', - ] - for module_name in fla_target_modules: - module = _import_optional(module_name) - if module is None: - continue - setattr(module, 'is_flash_linear_attention_available', _is_fla_available) - setattr(module, 'is_fast_path_available', True) - if hasattr(module, 'FusedRMSNormGated'): - setattr(module, 'FusedRMSNormGated', None) - setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) - - # 4. 
Traverse model and patch per-layer attributes - if model is None: - return 0 - - root = getattr(model, 'model', getattr(model, 'module', model)) - if not hasattr(root, 'named_modules'): - return 0 - - patched_instances = 0 - for _name, _module in root.named_modules(): - if hasattr(_module, 'chunk_gated_delta_rule') and callable( - getattr(_module, 'chunk_gated_delta_rule')): - if _module.chunk_gated_delta_rule is not mindspeed_fla: - _module.chunk_gated_delta_rule = mindspeed_fla - _module._twinkle_npu_patched = True - patched_instances += 1 - if hasattr(_module, 'causal_conv1d_fn'): - if getattr(_module, 'causal_conv1d_fn') is not npu_causal_conv1d_fn: - _module.causal_conv1d_fn = npu_causal_conv1d_fn - - if patched_instances: - logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Qwen3.5 Flash Linear Attention enablement for Ascend NPU.""" +from __future__ import annotations + +import importlib +import os + +from twinkle import get_logger + +logger = get_logger() + + +def _is_env_enabled(var: str, default: bool = True) -> bool: + env = os.environ.get(var, '').lower().strip() + if not env: + return default + if env in ('1', 'true', 'on', 'yes'): + return True + if env in ('0', 'false', 'off', 'no'): + return False + return default + + +def _import_optional(name: str): + try: + return importlib.import_module(name) + except ImportError: + return None + + +def apply_qwen3_5_fla(model=None) -> int: + """Enable Flash Linear Attention fast path for Qwen3.5 on NPU. + + Returns the count of patched per-layer instances (0 when disabled or when + prerequisites are missing). Safe to call multiple times. + """ + if not _is_env_enabled('TWINKLE_NPU_FLA', default=True): + logger.info('[NPU] [FLA] Disabled by TWINKLE_NPU_FLA') + return 0 + + if _import_optional('torch_npu') is None: + logger.info('[NPU] [FLA] Skip: torch_npu unavailable') + return 0 + + # 1. 
Confirm the MindSpeed Triton kernel is actually importable BEFORE + # flipping any global availability flags. If we flip the flag and then + # fail to install the kernel, HF transformers would route Qwen3.5 onto + # a FLA fast path whose kernel is missing -> runtime failure on NPU. + try: + from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla + from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn + except ImportError as exc: + logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) + return 0 + + # 2. Only now can we safely claim FLA is available: flip the global flags + # and install the kernel path on Qwen3.5 modeling modules. + def _is_fla_available() -> bool: + return True + + for utils_mod_name in ('transformers.utils', 'transformers.utils.import_utils'): + utils_mod = _import_optional(utils_mod_name) + if utils_mod is not None: + setattr(utils_mod, 'is_flash_linear_attention_available', _is_fla_available) + + # 3. Patch Qwen3.5 modeling modules + fla_target_modules = [ + 'transformers.models.qwen3_5.modeling_qwen3_5', + 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', + ] + for module_name in fla_target_modules: + module = _import_optional(module_name) + if module is None: + continue + setattr(module, 'is_flash_linear_attention_available', _is_fla_available) + setattr(module, 'is_fast_path_available', True) + if hasattr(module, 'FusedRMSNormGated'): + setattr(module, 'FusedRMSNormGated', None) + setattr(module, 'chunk_gated_delta_rule', mindspeed_fla) + + # 4. 
Traverse model and patch per-layer attributes + if model is None: + return 0 + + root = getattr(model, 'model', getattr(model, 'module', model)) + if not hasattr(root, 'named_modules'): + return 0 + + patched_instances = 0 + for _name, _module in root.named_modules(): + if hasattr(_module, 'chunk_gated_delta_rule') and callable( + getattr(_module, 'chunk_gated_delta_rule')): + if _module.chunk_gated_delta_rule is not mindspeed_fla: + _module.chunk_gated_delta_rule = mindspeed_fla + _module._twinkle_npu_patched = True + patched_instances += 1 + if hasattr(_module, 'causal_conv1d_fn'): + if getattr(_module, 'causal_conv1d_fn') is not npu_causal_conv1d_fn: + _module.causal_conv1d_fn = npu_causal_conv1d_fn + + if patched_instances: + logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) return patched_instances \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/moe.py b/src/twinkle/kernel/npu_impls/moe.py index c576cd04..efa7f71a 100644 --- a/src/twinkle/kernel/npu_impls/moe.py +++ b/src/twinkle/kernel/npu_impls/moe.py @@ -1,151 +1,151 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-"""MoE GMM + packed-experts + sparse-block impls for Ascend NPU.""" -from __future__ import annotations - -import torch -import torch.nn.functional as F - - -class GmmFunction(torch.autograd.Function): - """Custom autograd function for NPU grouped matrix multiplication.""" - - @staticmethod - def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): - import torch_npu - group_list = group_list.to(torch.int64) - ctx.save_for_backward(x, group_list, weight_ekn) - outputs = torch_npu.npu_grouped_matmul( - [x], [weight_ekn], group_list=group_list, - group_type=0, split_item=2, group_list_type=1, - ) - return outputs[0] - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - import torch_npu - x, group_list, weight_ekn = ctx.saved_tensors - grad_input = torch_npu.npu_grouped_matmul( - [grad_output], [weight_ekn.transpose(-2, -1).contiguous()], - bias=None, group_list=group_list, - group_type=0, split_item=2, group_list_type=1, - )[0] - grad_weight = torch_npu.npu_grouped_matmul( - [x.transpose(0, 1)], [grad_output], - bias=None, group_list=group_list, - group_type=2, split_item=3, group_list_type=1, - )[0] - return grad_input, None, grad_weight.contiguous() - - -def npu_grouped_mm(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: - """Drop-in replacement for ``transformers.integrations.moe._grouped_mm``.""" - counts = torch.empty_like(offs) - counts[0] = offs[0] - if offs.numel() > 1: - counts[1:] = offs[1:] - offs[:-1] - counts = counts.to(torch.int64) - return GmmFunction.apply(input, counts, weight_ekn) - - -def _normalize_packed_expert_weights(module, input_dtype, hidden_dim): - gate_up_proj = module.gate_up_proj.to(input_dtype) - down_proj = module.down_proj.to(input_dtype) - if gate_up_proj.shape[1] == hidden_dim: - gate_up_weight = gate_up_proj - elif gate_up_proj.shape[2] == hidden_dim: - gate_up_weight = gate_up_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported 
gate_up_proj shape: {tuple(gate_up_proj.shape)}.') - if down_proj.shape[2] == hidden_dim: - down_weight = down_proj - elif down_proj.shape[1] == hidden_dim: - down_weight = down_proj.transpose(1, 2) - else: - raise RuntimeError(f'Unsupported down_proj shape: {tuple(down_proj.shape)}.') - return gate_up_weight, down_weight - - -def _get_cached_expert_weights(self, target_dtype, hidden_dim): - requires_grad = ( - getattr(self.gate_up_proj, 'requires_grad', False) - or getattr(self.down_proj, 'requires_grad', False) - ) - cache_attr = '_npu_expert_cache' - if not requires_grad and hasattr(self, cache_attr): - cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) - if (cached_dtype == target_dtype - and cached_gv == self.gate_up_proj._version - and cached_dv == self.down_proj._version): - return cached - weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) - if not requires_grad: - setattr(self, cache_attr, - (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) - return weights - - -def npu_packed_moe_experts_forward(self, hidden_states, a, b): - """Packed MoE Experts.forward using NPU grouped matmul. - - Accepts both call orderings: ``(hidden_states, routing_weights, router_indices)`` - and ``(hidden_states, router_indices, routing_weights)`` — distinguishes by dtype. 
- """ - import torch_npu - if a.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: - router_indices, routing_weights = a, b - else: - routing_weights, router_indices = a, b - - output_shape = hidden_states.shape - hidden_dim = output_shape[-1] - hidden_states = hidden_states.reshape(-1, hidden_dim) - - if routing_weights.shape != router_indices.shape: - routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) - routing_weights = routing_weights.to(hidden_states.dtype) - router_indices = router_indices.to(torch.int32) - - permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) - tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) - gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) - - intermediate = GmmFunction.apply(permuted, tokens_per_expert, gate_up_weight) - activated = torch_npu.npu_swiglu(intermediate, dim=-1) - output = GmmFunction.apply(activated, tokens_per_expert, down_weight) - next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) - return next_states.view(*output_shape) - - -def _topk_from_router_logits(module, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) - if getattr(module, 'norm_topk_prob', True): - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - return routing_weights, router_indices - - -def _add_shared_expert(self, hidden_states, expert_output): - if not (hasattr(self, 'shared_expert') and hasattr(self, 'shared_expert_gate')): - return expert_output - shared = self.shared_expert(hidden_states) - shared = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared - return expert_output + shared 
- - -def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): - """SparseMoeBlock.forward replacement (Transformers 4.x and 5.x compatible).""" - batch_size, sequence_length, hidden_dim = hidden_states.shape - gate_output = self.gate(hidden_states.view(-1, hidden_dim)) - - if isinstance(gate_output, tuple): - _, routing_weights, selected_experts = gate_output - flat = hidden_states.view(-1, hidden_dim) - expert_output = self.experts(flat, selected_experts, routing_weights) - else: - flat = hidden_states.view(-1, hidden_dim) - routing_weights, selected_experts = _topk_from_router_logits(self, flat, gate_output) - expert_output = self.experts(flat, selected_experts, routing_weights) - - expert_output = _add_shared_expert(self, flat, expert_output) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""MoE GMM + packed-experts + sparse-block impls for Ascend NPU.""" +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +class GmmFunction(torch.autograd.Function): + """Custom autograd function for NPU grouped matrix multiplication.""" + + @staticmethod + def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Tensor): + import torch_npu + group_list = group_list.to(torch.int64) + ctx.save_for_backward(x, group_list, weight_ekn) + outputs = torch_npu.npu_grouped_matmul( + [x], [weight_ekn], group_list=group_list, + group_type=0, split_item=2, group_list_type=1, + ) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + import torch_npu + x, group_list, weight_ekn = ctx.saved_tensors + grad_input = torch_npu.npu_grouped_matmul( + [grad_output], [weight_ekn.transpose(-2, -1).contiguous()], + bias=None, group_list=group_list, + group_type=0, split_item=2, group_list_type=1, + )[0] + grad_weight = torch_npu.npu_grouped_matmul( + [x.transpose(0, 1)], [grad_output], + bias=None, group_list=group_list, + group_type=2, split_item=3, group_list_type=1, + )[0] + return 
grad_input, None, grad_weight.contiguous() + + +def npu_grouped_mm(input: torch.Tensor, weight_ekn: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: + """Drop-in replacement for ``transformers.integrations.moe._grouped_mm``.""" + counts = torch.empty_like(offs) + counts[0] = offs[0] + if offs.numel() > 1: + counts[1:] = offs[1:] - offs[:-1] + counts = counts.to(torch.int64) + return GmmFunction.apply(input, counts, weight_ekn) + + +def _normalize_packed_expert_weights(module, input_dtype, hidden_dim): + gate_up_proj = module.gate_up_proj.to(input_dtype) + down_proj = module.down_proj.to(input_dtype) + if gate_up_proj.shape[1] == hidden_dim: + gate_up_weight = gate_up_proj + elif gate_up_proj.shape[2] == hidden_dim: + gate_up_weight = gate_up_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported gate_up_proj shape: {tuple(gate_up_proj.shape)}.') + if down_proj.shape[2] == hidden_dim: + down_weight = down_proj + elif down_proj.shape[1] == hidden_dim: + down_weight = down_proj.transpose(1, 2) + else: + raise RuntimeError(f'Unsupported down_proj shape: {tuple(down_proj.shape)}.') + return gate_up_weight, down_weight + + +def _get_cached_expert_weights(self, target_dtype, hidden_dim): + requires_grad = ( + getattr(self.gate_up_proj, 'requires_grad', False) + or getattr(self.down_proj, 'requires_grad', False) + ) + cache_attr = '_npu_expert_cache' + if not requires_grad and hasattr(self, cache_attr): + cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) + if (cached_dtype == target_dtype + and cached_gv == self.gate_up_proj._version + and cached_dv == self.down_proj._version): + return cached + weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) + if not requires_grad: + setattr(self, cache_attr, + (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) + return weights + + +def npu_packed_moe_experts_forward(self, hidden_states, a, b): + """Packed MoE Experts.forward using NPU grouped matmul. 
+ + Accepts both call orderings: ``(hidden_states, routing_weights, router_indices)`` + and ``(hidden_states, router_indices, routing_weights)`` — distinguishes by dtype. + """ + import torch_npu + if a.dtype in {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}: + router_indices, routing_weights = a, b + else: + routing_weights, router_indices = a, b + + output_shape = hidden_states.shape + hidden_dim = output_shape[-1] + hidden_states = hidden_states.reshape(-1, hidden_dim) + + if routing_weights.shape != router_indices.shape: + routing_weights = torch.gather(routing_weights, dim=-1, index=router_indices.to(torch.long)) + routing_weights = routing_weights.to(hidden_states.dtype) + router_indices = router_indices.to(torch.int32) + + permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices) + tokens_per_expert = torch.bincount(router_indices.view(-1), minlength=self.num_experts).to(torch.int64) + gate_up_weight, down_weight = _get_cached_expert_weights(self, hidden_states.dtype, hidden_dim) + + intermediate = GmmFunction.apply(permuted, tokens_per_expert, gate_up_weight) + activated = torch_npu.npu_swiglu(intermediate, dim=-1) + output = GmmFunction.apply(activated, tokens_per_expert, down_weight) + next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights) + return next_states.view(*output_shape) + + +def _topk_from_router_logits(module, hidden_states, router_logits): + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, router_indices = torch.topk(routing_weights, module.top_k, dim=-1) + if getattr(module, 'norm_topk_prob', True): + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + return routing_weights, router_indices + + +def _add_shared_expert(self, hidden_states, expert_output): + if not (hasattr(self, 'shared_expert') and hasattr(self, 
'shared_expert_gate')): + return expert_output + shared = self.shared_expert(hidden_states) + shared = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared + return expert_output + shared + + +def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): + """SparseMoeBlock.forward replacement (Transformers 4.x and 5.x compatible).""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + gate_output = self.gate(hidden_states.view(-1, hidden_dim)) + + if isinstance(gate_output, tuple): + _, routing_weights, selected_experts = gate_output + flat = hidden_states.view(-1, hidden_dim) + expert_output = self.experts(flat, selected_experts, routing_weights) + else: + flat = hidden_states.view(-1, hidden_dim) + routing_weights, selected_experts = _topk_from_router_logits(self, flat, gate_output) + expert_output = self.experts(flat, selected_experts, routing_weights) + + expert_output = _add_shared_expert(self, flat, expert_output) return expert_output.reshape(batch_size, sequence_length, hidden_dim) \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/rms_norm.py b/src/twinkle/kernel/npu_impls/rms_norm.py index 98e95699..ecebdc23 100644 --- a/src/twinkle/kernel/npu_impls/rms_norm.py +++ b/src/twinkle/kernel/npu_impls/rms_norm.py @@ -1,75 +1,75 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Fused RMSNorm impls for Ascend NPU. - -Designed for class-replacement: do not define ``__init__``; rely on the -attributes already present on the original instance. -""" -from __future__ import annotations - -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from twinkle import get_logger - -logger = get_logger() - - -class NpuRMSNorm(nn.Module): - """Class-replacement impl for HF RMSNorm variants. 
- - Required instance attributes (provided by the original class): - - ``weight``: ``nn.Parameter`` - - ``variance_epsilon`` *or* ``eps``: float - """ - - def _twinkle_residual_param(self) -> bool: - """Lazily detect residual parameterization (e.g. Qwen3.5: scale = 1 + weight).""" - cached = getattr(self, '_twinkle_residual_cached', None) - if cached is None: - cached = abs(self.weight.data.mean().item()) < 0.3 - self._twinkle_residual_cached = cached - if cached: - logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') - return cached - - def _twinkle_eps(self) -> float: - return getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - import torch_npu - target_dtype = hidden_states.dtype - if self._twinkle_residual_param(): - scale = (1.0 + self.weight).to(target_dtype) - else: - scale = self.weight.to(target_dtype) - return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] - - -# Resolved once at import: matches the legacy "patch-time, process-wide" invariant. -# Mid-process env mutation will not retroactively change behavior. -_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( - '1', 'true', 'on', 'yes' -) - - -def npu_gated_rms_norm_forward(self, hidden_states, gate=None): - """Forward replacement for Gated RMSNorm variants (e.g. Qwen3.5-MoE).""" - import torch_npu - - input_dtype = hidden_states.dtype - _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) - - if _FORCE_FP32: - hidden_states = hidden_states.to(torch.float32) - weight = self.weight.float() - gate = gate.to(torch.float32) if gate is not None else None - else: - weight = self.weight - - hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] - if gate is not None: - hidden_states = hidden_states * F.silu(gate) +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Fused RMSNorm impls for Ascend NPU. + +Designed for class-replacement: do not define ``__init__``; rely on the +attributes already present on the original instance. +""" +from __future__ import annotations + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from twinkle import get_logger + +logger = get_logger() + + +class NpuRMSNorm(nn.Module): + """Class-replacement impl for HF RMSNorm variants. + + Required instance attributes (provided by the original class): + - ``weight``: ``nn.Parameter`` + - ``variance_epsilon`` *or* ``eps``: float + """ + + def _twinkle_residual_param(self) -> bool: + """Lazily detect residual parameterization (e.g. Qwen3.5: scale = 1 + weight).""" + cached = getattr(self, '_twinkle_residual_cached', None) + if cached is None: + cached = abs(self.weight.data.mean().item()) < 0.3 + self._twinkle_residual_cached = cached + if cached: + logger.debug('[NPU] NpuRMSNorm using residual parameterization (1.0 + weight)') + return cached + + def _twinkle_eps(self) -> float: + return getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + import torch_npu + target_dtype = hidden_states.dtype + if self._twinkle_residual_param(): + scale = (1.0 + self.weight).to(target_dtype) + else: + scale = self.weight.to(target_dtype) + return torch_npu.npu_rms_norm(hidden_states, scale, epsilon=self._twinkle_eps())[0] + + +# Resolved once at import: matches the legacy "patch-time, process-wide" invariant. +# Mid-process env mutation will not retroactively change behavior. +_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( + '1', 'true', 'on', 'yes' +) + + +def npu_gated_rms_norm_forward(self, hidden_states, gate=None): + """Forward replacement for Gated RMSNorm variants (e.g. 
Qwen3.5-MoE).""" + import torch_npu + + input_dtype = hidden_states.dtype + _eps = getattr(self, 'variance_epsilon', getattr(self, 'eps', 1e-6)) + + if _FORCE_FP32: + hidden_states = hidden_states.to(torch.float32) + weight = self.weight.float() + gate = gate.to(torch.float32) if gate is not None else None + else: + weight = self.weight + + hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] + if gate is not None: + hidden_states = hidden_states * F.silu(gate) return hidden_states.to(input_dtype) \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/rotary.py b/src/twinkle/kernel/npu_impls/rotary.py index 6493dc8b..1ed437a3 100644 --- a/src/twinkle/kernel/npu_impls/rotary.py +++ b/src/twinkle/kernel/npu_impls/rotary.py @@ -1,66 +1,66 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Fused RoPE impls for Ascend NPU (lazy ``torch_npu`` import).""" -from __future__ import annotations - -import torch - - -def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): - if isinstance(position_ids, int) and unsqueeze_dim == 1: - return position_ids - return unsqueeze_dim - - -def _make_apply_npu_rotary_emb(): - """Closure with per-shape Partial-RoPE detection cache.""" - _cached_partial: dict[tuple[int, int], bool] = {} - - def _apply(q, k, cos, sin): - import torch_npu - rotary_dim = cos.shape[-1] - query_dim = q.shape[-1] - shape_key = (rotary_dim, query_dim) - - use_partial = _cached_partial.get(shape_key) - if use_partial is None: - use_partial = rotary_dim < query_dim - _cached_partial[shape_key] = use_partial - - if use_partial: - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - else: - q_embed = 
torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) - k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) - return q_embed, k_embed - - return _apply - - -_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() - - -def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Fused RoPE via ``torch_npu.npu_rotary_mul`` with Partial-RoPE support.""" - unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) - - -def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Multimodal RoPE for Qwen2.5-VL with Partial-RoPE support.""" - mrope_section = mrope_section * 2 - cos = torch.cat( - [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], - dim=-1, - ).unsqueeze(unsqueeze_dim) - sin = torch.cat( - [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], - dim=-1, - ).unsqueeze(unsqueeze_dim) +# Copyright (c) ModelScope Contributors. All rights reserved. 
+"""Fused RoPE impls for Ascend NPU (lazy ``torch_npu`` import).""" +from __future__ import annotations + +import torch + + +def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): + if isinstance(position_ids, int) and unsqueeze_dim == 1: + return position_ids + return unsqueeze_dim + + +def _make_apply_npu_rotary_emb(): + """Closure with per-shape Partial-RoPE detection cache.""" + _cached_partial: dict[tuple[int, int], bool] = {} + + def _apply(q, k, cos, sin): + import torch_npu + rotary_dim = cos.shape[-1] + query_dim = q.shape[-1] + shape_key = (rotary_dim, query_dim) + + use_partial = _cached_partial.get(shape_key) + if use_partial is None: + use_partial = rotary_dim < query_dim + _cached_partial[shape_key] = use_partial + + if use_partial: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_embed = torch_npu.npu_rotary_mul(q_rot, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k_rot, cos, sin).to(k.dtype) + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + else: + q_embed = torch_npu.npu_rotary_mul(q, cos, sin).to(q.dtype) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin).to(k.dtype) + return q_embed, k_embed + + return _apply + + +_apply_npu_rotary_emb = _make_apply_npu_rotary_emb() + + +def npu_apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Fused RoPE via ``torch_npu.npu_rotary_mul`` with Partial-RoPE support.""" + unsqueeze_dim = _resolve_unsqueeze_dim(position_ids, unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return _apply_npu_rotary_emb(q, k, cos, sin) + + +def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Multimodal RoPE for Qwen2.5-VL with Partial-RoPE support.""" + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, 
+ ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) return _apply_npu_rotary_emb(q, k, cos, sin) \ No newline at end of file diff --git a/src/twinkle/kernel/npu_impls/swiglu.py b/src/twinkle/kernel/npu_impls/swiglu.py index 4be68184..c34a7bea 100644 --- a/src/twinkle/kernel/npu_impls/swiglu.py +++ b/src/twinkle/kernel/npu_impls/swiglu.py @@ -1,20 +1,20 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Fused SwiGLU forward for Ascend NPU.""" -from __future__ import annotations - -import torch - - -def npu_swiglu_forward(self, hidden_state): - """Fused Qwen-style SwiGLU. - - Used as a class-attribute replacement on HF MLP classes. - Required instance attributes: ``gate_proj``, ``up_proj``, ``down_proj``. - """ - import torch_npu - return self.down_proj( - torch_npu.npu_swiglu( - torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), - dim=-1, - ) +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Fused SwiGLU forward for Ascend NPU.""" +from __future__ import annotations + +import torch + + +def npu_swiglu_forward(self, hidden_state): + """Fused Qwen-style SwiGLU. + + Used as a class-attribute replacement on HF MLP classes. + Required instance attributes: ``gate_proj``, ``up_proj``, ``down_proj``. + """ + import torch_npu + return self.down_proj( + torch_npu.npu_swiglu( + torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), + dim=-1, + ) ) \ No newline at end of file diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 3f3e5c6a..2b6e45ea 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -1,533 +1,533 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-from __future__ import annotations - -import inspect -import torch -import torch.distributed as dist -import torch.nn.functional as F -from dataclasses import dataclass -from torch import nn -from typing import Any, Dict, Iterable, List, Optional, Tuple - -from twinkle.model.transformers.moe.ep_utils import preprocess, token_pre_all2all, tokens_post_all2all -from twinkle.utils import DeviceMesh - - -@dataclass -class ExpertParallelConfig: - enabled: bool = True - router_dtype: str = 'fp32' - keep_router_logits: bool = True - ignore_shared_experts: bool = False - ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic - - -@dataclass -class ExpertShardingSpec: - """Describes expert sharding info for a single MoE block. Extensible for other models.""" - block: nn.Module - experts_module: nn.Module - num_experts: int - experts_per_rank: int - local_start: int - local_end: int - ep_rank: int - ep_world_size: int - is_tensor_experts: bool - - -def apply_expert_parallel( - model: nn.Module, - device_mesh: DeviceMesh, - config: dict[str, Any] | None = None, - ep_fsdp_device_mesh: torch.distributed.DeviceMesh | None = None, -) -> list[ExpertShardingSpec]: - """Apply expert parallelism to all MoE blocks in the model.""" - cfg = _merge_config(config) - - # EP info comes from the separate ep_fsdp_device_mesh, not from main mesh - if not cfg.enabled or ep_fsdp_device_mesh is None: - return [] - - # Always query EP via the 1D submesh to avoid relying on Tensor named dims. - ep_mesh = ep_fsdp_device_mesh['ep'] - ep_world_size = ep_mesh.size() - if ep_world_size <= 1: - return [] - - if not dist.is_initialized(): - raise RuntimeError('torch.distributed is not initialized, cannot enable expert parallel.') - - # Get process group and local rank from EP submesh. 
- ep_group = ep_mesh.get_group() - ep_rank = ep_mesh.get_local_rank() - - specs = [] - for _, block in find_moe_blocks_with_names(model): - spec = shard_experts(block, ep_world_size, ep_rank, cfg) - patch_forward(block, ep_group, ep_world_size, cfg) - specs.append(spec) - - return specs - - -def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig: - cfg = ExpertParallelConfig() - if not config: - return cfg - for key, value in config.items(): - if not hasattr(cfg, key): - raise ValueError(f'Unknown expert parallel config: {key}') - setattr(cfg, key, value) - return cfg - - -def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: - return [block for _, block in find_moe_blocks_with_names(model)] - - -def find_moe_blocks_with_names(model: nn.Module) -> Iterable[tuple[str, nn.Module]]: - blocks = [] - for name, module in model.named_modules(): - experts = getattr(module, 'experts', None) - if experts is None: - continue - if not _is_moe_experts(experts): - continue - if not _get_gate(module): - continue - blocks.append((name, module)) - return blocks - - -def shard_experts( - block: nn.Module, - ep_world_size: int, - ep_rank: int, - cfg: ExpertParallelConfig, -) -> ExpertShardingSpec: - """Shard experts in a MoE block across EP ranks. - - Args: - block: The MoE block containing experts. - ep_world_size: The world size for expert parallelism. - ep_rank: The current rank in the EP group. - cfg: Expert parallel configuration. - - Returns an ExpertShardingSpec describing the sharding. 
- """ - num_experts = _get_num_experts(block) - - if num_experts % ep_world_size != 0: - raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).') - - experts_per_rank = num_experts // ep_world_size - local_start = ep_rank * experts_per_rank - local_end = local_start + experts_per_rank - - if isinstance(block.experts, nn.ModuleList): - local_experts = nn.ModuleList(block.experts[local_start:local_end]) - block.experts = local_experts - is_tensor_experts = False - else: - _shard_tensor_experts(block.experts, local_start, local_end) - is_tensor_experts = True - - block._ep_num_experts = num_experts - block._ep_experts_per_rank = experts_per_rank - block._ep_local_start = local_start - block._ep_local_end = local_end - block._ep_rank = ep_rank - block._ep_world_size = ep_world_size - block._ep_tensor_experts = is_tensor_experts - block._ep_ignore_shared_experts = cfg.ignore_shared_experts - - return ExpertShardingSpec( - block=block, - experts_module=block.experts, - num_experts=num_experts, - experts_per_rank=experts_per_rank, - local_start=local_start, - local_end=local_end, - ep_rank=ep_rank, - ep_world_size=ep_world_size, - is_tensor_experts=is_tensor_experts, - ) - - -def patch_forward( - block: nn.Module, - ep_group: dist.ProcessGroup, - ep_world_size: int, - cfg: ExpertParallelConfig, -) -> None: - """Replace the MoE block forward with EP-aware communication flow. - - Communication pattern: - preprocess → token_pre_all2all → expert_compute → tokens_post_all2all - - For tensor experts (gate_up_proj/down_proj), the expert compute is delegated - to block.experts(...) via nn.Module.__call__ so that FSDP2 pre/post-forward - hooks fire correctly (automatic unshard before forward, backward hook - registration, and reshard after forward). No manual unshard/reshard is needed. - - For ModuleList experts, each sub-expert is already called via __call__ inside - _run_local_experts, so the same principle applies. 
- - Args: - block: The MoE block to patch. - ep_group: The process group for EP communication (from ep_fsdp_device_mesh["ep"]). - ep_world_size: The world size for expert parallelism. - cfg: Expert parallel configuration. - """ - if getattr(block, '_ep_patched', False): - return - - gate = _get_gate(block) - if gate is None: - raise ValueError('MoE block must define gate/router module.') - - top_k = _get_top_k(block) - if top_k is None: - raise ValueError('MoE block must define top_k/num_experts_per_tok.') - - orig_forward = block.forward - return_annotation = inspect.signature(orig_forward).return_annotation - returns_router_logits = return_annotation in ( - tuple, - Tuple[torch.Tensor, torch.Tensor | None], - ) - num_experts = block._ep_num_experts - experts_per_rank = block._ep_experts_per_rank - is_tensor_experts = block._ep_tensor_experts - - # For tensor experts, install an ep_forward on the experts module so we can - # call block.experts(permuted_tokens, counts, experts_per_rank) via __call__, - # letting FSDP2 manage unshard/reshard automatically. 
- if is_tensor_experts: - _install_ep_forward(block.experts, experts_per_rank) - - def forward(hidden_states: torch.Tensor, *args, **kwargs): - if args: - raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.') - - orig_shape = hidden_states.shape - if hidden_states.ndim == 3: - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states_2d = hidden_states.view(-1, hidden_dim) - elif hidden_states.ndim == 2: - batch_size, seq_len = 1, hidden_states.shape[0] - hidden_dim = hidden_states.shape[1] - hidden_states_2d = hidden_states - else: - raise ValueError(f'Unsupported hidden_states ndim: {hidden_states.ndim}') - - router_logits, routing_weights, selected_experts = _run_router( - gate=gate, - hidden_states=hidden_states_2d, - top_k=top_k, - router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype), - norm_topk_prob=getattr(block, 'norm_topk_prob', False), - **kwargs, - ) - # Keep routing weights in activation dtype before unpermute weighting. - if routing_weights.dtype != hidden_states_2d.dtype: - routing_weights = routing_weights.to(hidden_states_2d.dtype) - # Build expert_mask: [num_experts, top_k, num_tokens] - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens] - - # 1. preprocess: compute splits and token counts - ( - input_splits, - output_splits, - num_global_tokens_per_local_expert, - num_global_sum_tokens_per_local_expert, - ) = preprocess(expert_mask, num_experts, ep_group) - - # 2. token_pre_all2all: permute → all_to_all → sort_chunks - ( - global_permuted_hidden_states, - local_input_permutation_mapping, - local_assignment_weights, - org_hidden_states_shape, - ) = token_pre_all2all( - hidden_states_2d, - expert_mask, - routing_weights, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - ep_group, - ) - - # 3. 
expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire. - # For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank) - # → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard - # For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__. - if is_tensor_experts: - expert_outputs = block.experts( - global_permuted_hidden_states, - num_global_sum_tokens_per_local_expert, - experts_per_rank, - ) - else: - expert_outputs = _run_local_experts( - block, - global_permuted_hidden_states, - num_global_sum_tokens_per_local_expert, - experts_per_rank, - ) - - # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) - final_hidden = tokens_post_all2all( - expert_outputs, - local_assignment_weights, - num_experts, - input_splits, - output_splits, - num_global_tokens_per_local_expert, - local_input_permutation_mapping, - org_hidden_states_shape, - ep_group, - ) - - shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg) - if shared_out is not None: - final_hidden = final_hidden + shared_out - - if len(orig_shape) == 3: - final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) - - if cfg.keep_router_logits and returns_router_logits: - return final_hidden, router_logits - return final_hidden - - block._ep_original_forward = orig_forward - block.forward = forward - block._ep_patched = True - - -def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None: - if getattr(experts_mod, '_ep_forward_installed', False): - return - - def ep_forward( - self, - permuted_tokens: torch.Tensor, - num_global_sum_tokens_per_local_expert: torch.Tensor, - experts_per_rank: int, - ) -> torch.Tensor: - if permuted_tokens.numel() == 0: - # Preserve the autograd edge to token_pre_all2all. Returning a new - # empty tensor can make this rank skip the matching backward - # all-to-all, causing EP collective order divergence. 
- return permuted_tokens - - input_dtype = permuted_tokens.dtype - - cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) - for i in range(experts_per_rank): - cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) - - output_chunks = [] - for i in range(experts_per_rank): - start = int(cumsum[i].item()) - end = int(cumsum[i + 1].item()) - expert_in = permuted_tokens[start:end] - if expert_in.numel() == 0: - output_chunks.append(expert_in) - continue - - gate_up = self.gate_up_proj[i] - down = self.down_proj[i] - compute_dtype = gate_up.dtype - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - gate_up_out = F.linear(expert_in, gate_up) - if hasattr(self, '_apply_gate'): - out = self._apply_gate(gate_up_out) - else: - gate, up = gate_up_out.chunk(2, dim=-1) - out = self.act_fn(gate) * up - out = F.linear(out, down) - - if out.dtype != input_dtype: - out = out.to(input_dtype) - output_chunks.append(out) - - return torch.cat( - output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) - - import types - experts_mod.forward = types.MethodType(ep_forward, experts_mod) - experts_mod._ep_forward_installed = True - - -def _get_gate(block: nn.Module): - gate = getattr(block, 'gate', None) - if gate is None: - gate = getattr(block, 'router', None) - return gate - - -def _get_num_experts(block: nn.Module) -> int: - if hasattr(block, 'num_experts'): - return int(block.num_experts) - experts = getattr(block, 'experts', None) - if experts is None: - raise ValueError('MoE block has no experts.') - if isinstance(experts, nn.ModuleList): - return len(experts) - if hasattr(experts, 'num_experts'): - return int(experts.num_experts) - if hasattr(experts, 'gate_up_proj'): - return int(experts.gate_up_proj.shape[0]) - raise ValueError('Unable to infer num_experts for MoE block.') - - -def _get_top_k(block: nn.Module) -> int | None: - gate = _get_gate(block) - if gate is not None 
and hasattr(gate, 'top_k'): - value = getattr(gate, 'top_k') - if value is not None: - return int(value) - for name in ('num_experts_per_tok', 'top_k'): - if hasattr(block, name): - value = getattr(block, name) - if value is not None: - return int(value) - return None - - -def _get_router_dtype(router_dtype: str, default_dtype: torch.dtype) -> torch.dtype: - if router_dtype == 'fp32': - return torch.float32 - if router_dtype == 'bf16': - return torch.bfloat16 - if router_dtype == 'fp16': - return torch.float16 - return default_dtype - - -def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, cfg: ExpertParallelConfig): - if cfg.ignore_shared_experts: - return None - shared = getattr(block, 'shared_expert', None) - if shared is None: - shared = getattr(block, 'shared_experts', None) - if shared is None: - return None - return _run_module_with_casting(shared, hidden_states_2d) - - -def _is_moe_experts(experts: Any) -> bool: - if isinstance(experts, nn.ModuleList): - return True - if hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj'): - return True - return False - - -def _shard_tensor_experts(experts: nn.Module, start: int, end: int) -> None: - experts.gate_up_proj = nn.Parameter(experts.gate_up_proj.data[start:end].clone()) - experts.down_proj = nn.Parameter(experts.down_proj.data[start:end].clone()) - if hasattr(experts, 'num_experts'): - experts.num_experts = end - start - - -def _run_local_experts( - block: nn.Module, - permuted_tokens: torch.Tensor, - num_global_sum_tokens_per_local_expert: torch.Tensor, - experts_per_rank: int, -) -> torch.Tensor: - """Run ModuleList experts on permuted tokens via nn.Module.__call__. - Tokens are already grouped by expert (contiguous chunks), sizes given by - num_global_sum_tokens_per_local_expert. No routing weight is applied here; - that happens in unpermute. 
- """ - if permuted_tokens.numel() == 0: - # Keep the backward path through token_pre_all2all even when this EP - # rank owns no routed tokens for the current block. - return permuted_tokens - - input_dtype = permuted_tokens.dtype - experts = block.experts - - cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) - for i in range(experts_per_rank): - cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) - - output_chunks = [] - for i in range(experts_per_rank): - start = int(cumsum[i].item()) - end = int(cumsum[i + 1].item()) - expert_in = permuted_tokens[start:end] - if expert_in.numel() == 0: - output_chunks.append(expert_in) - continue - - expert = experts[i] - compute_dtype = _module_compute_dtype(expert, input_dtype) - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - out = expert(expert_in) - - if out.dtype != input_dtype: - out = out.to(input_dtype) - output_chunks.append(out) - - return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) - - -def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype: - for param in module.parameters(): - if param.dtype.is_floating_point: - return param.dtype - return default - - -def _run_module_with_casting(module: nn.Module, module_in: torch.Tensor) -> torch.Tensor: - input_dtype = module_in.dtype - compute_dtype = _module_compute_dtype(module, input_dtype) - if compute_dtype != input_dtype: - module_in = module_in.to(compute_dtype) - out = module(module_in) - if out.dtype != input_dtype: - out = out.to(input_dtype) - return out - - -def _run_router( - *, - gate: nn.Module, - hidden_states: torch.Tensor, - top_k: int, - router_dtype: torch.dtype, - norm_topk_prob: bool, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - gate_kwargs = {} - if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'): - gate_kwargs['input_ids'] = 
kwargs['input_ids'] - gate_out = gate(hidden_states, **gate_kwargs) - if isinstance(gate_out, tuple) and len(gate_out) >= 3: - router_logits, routing_weights, selected_experts = gate_out[:3] - return router_logits, routing_weights, selected_experts - - router_logits = gate_out - routing_weights = torch.softmax(router_logits, dim=-1, dtype=router_dtype) - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - if norm_topk_prob: - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - return router_logits, routing_weights, selected_experts - - -def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool: - signature = inspect.signature(module.forward) - for param in signature.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - return True - return kwarg in signature.parameters +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import inspect +import torch +import torch.distributed as dist +import torch.nn.functional as F +from dataclasses import dataclass +from torch import nn +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from twinkle.model.transformers.moe.ep_utils import preprocess, token_pre_all2all, tokens_post_all2all +from twinkle.utils import DeviceMesh + + +@dataclass +class ExpertParallelConfig: + enabled: bool = True + router_dtype: str = 'fp32' + keep_router_logits: bool = True + ignore_shared_experts: bool = False + ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic + + +@dataclass +class ExpertShardingSpec: + """Describes expert sharding info for a single MoE block. 
Extensible for other models.""" + block: nn.Module + experts_module: nn.Module + num_experts: int + experts_per_rank: int + local_start: int + local_end: int + ep_rank: int + ep_world_size: int + is_tensor_experts: bool + + +def apply_expert_parallel( + model: nn.Module, + device_mesh: DeviceMesh, + config: dict[str, Any] | None = None, + ep_fsdp_device_mesh: torch.distributed.DeviceMesh | None = None, +) -> list[ExpertShardingSpec]: + """Apply expert parallelism to all MoE blocks in the model.""" + cfg = _merge_config(config) + + # EP info comes from the separate ep_fsdp_device_mesh, not from main mesh + if not cfg.enabled or ep_fsdp_device_mesh is None: + return [] + + # Always query EP via the 1D submesh to avoid relying on Tensor named dims. + ep_mesh = ep_fsdp_device_mesh['ep'] + ep_world_size = ep_mesh.size() + if ep_world_size <= 1: + return [] + + if not dist.is_initialized(): + raise RuntimeError('torch.distributed is not initialized, cannot enable expert parallel.') + + # Get process group and local rank from EP submesh. 
+ ep_group = ep_mesh.get_group() + ep_rank = ep_mesh.get_local_rank() + + specs = [] + for _, block in find_moe_blocks_with_names(model): + spec = shard_experts(block, ep_world_size, ep_rank, cfg) + patch_forward(block, ep_group, ep_world_size, cfg) + specs.append(spec) + + return specs + + +def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig: + cfg = ExpertParallelConfig() + if not config: + return cfg + for key, value in config.items(): + if not hasattr(cfg, key): + raise ValueError(f'Unknown expert parallel config: {key}') + setattr(cfg, key, value) + return cfg + + +def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: + return [block for _, block in find_moe_blocks_with_names(model)] + + +def find_moe_blocks_with_names(model: nn.Module) -> Iterable[tuple[str, nn.Module]]: + blocks = [] + for name, module in model.named_modules(): + experts = getattr(module, 'experts', None) + if experts is None: + continue + if not _is_moe_experts(experts): + continue + if not _get_gate(module): + continue + blocks.append((name, module)) + return blocks + + +def shard_experts( + block: nn.Module, + ep_world_size: int, + ep_rank: int, + cfg: ExpertParallelConfig, +) -> ExpertShardingSpec: + """Shard experts in a MoE block across EP ranks. + + Args: + block: The MoE block containing experts. + ep_world_size: The world size for expert parallelism. + ep_rank: The current rank in the EP group. + cfg: Expert parallel configuration. + + Returns an ExpertShardingSpec describing the sharding. 
+ """ + num_experts = _get_num_experts(block) + + if num_experts % ep_world_size != 0: + raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).') + + experts_per_rank = num_experts // ep_world_size + local_start = ep_rank * experts_per_rank + local_end = local_start + experts_per_rank + + if isinstance(block.experts, nn.ModuleList): + local_experts = nn.ModuleList(block.experts[local_start:local_end]) + block.experts = local_experts + is_tensor_experts = False + else: + _shard_tensor_experts(block.experts, local_start, local_end) + is_tensor_experts = True + + block._ep_num_experts = num_experts + block._ep_experts_per_rank = experts_per_rank + block._ep_local_start = local_start + block._ep_local_end = local_end + block._ep_rank = ep_rank + block._ep_world_size = ep_world_size + block._ep_tensor_experts = is_tensor_experts + block._ep_ignore_shared_experts = cfg.ignore_shared_experts + + return ExpertShardingSpec( + block=block, + experts_module=block.experts, + num_experts=num_experts, + experts_per_rank=experts_per_rank, + local_start=local_start, + local_end=local_end, + ep_rank=ep_rank, + ep_world_size=ep_world_size, + is_tensor_experts=is_tensor_experts, + ) + + +def patch_forward( + block: nn.Module, + ep_group: dist.ProcessGroup, + ep_world_size: int, + cfg: ExpertParallelConfig, +) -> None: + """Replace the MoE block forward with EP-aware communication flow. + + Communication pattern: + preprocess → token_pre_all2all → expert_compute → tokens_post_all2all + + For tensor experts (gate_up_proj/down_proj), the expert compute is delegated + to block.experts(...) via nn.Module.__call__ so that FSDP2 pre/post-forward + hooks fire correctly (automatic unshard before forward, backward hook + registration, and reshard after forward). No manual unshard/reshard is needed. + + For ModuleList experts, each sub-expert is already called via __call__ inside + _run_local_experts, so the same principle applies. 
+ + Args: + block: The MoE block to patch. + ep_group: The process group for EP communication (from ep_fsdp_device_mesh["ep"]). + ep_world_size: The world size for expert parallelism. + cfg: Expert parallel configuration. + """ + if getattr(block, '_ep_patched', False): + return + + gate = _get_gate(block) + if gate is None: + raise ValueError('MoE block must define gate/router module.') + + top_k = _get_top_k(block) + if top_k is None: + raise ValueError('MoE block must define top_k/num_experts_per_tok.') + + orig_forward = block.forward + return_annotation = inspect.signature(orig_forward).return_annotation + returns_router_logits = return_annotation in ( + tuple, + Tuple[torch.Tensor, torch.Tensor | None], + ) + num_experts = block._ep_num_experts + experts_per_rank = block._ep_experts_per_rank + is_tensor_experts = block._ep_tensor_experts + + # For tensor experts, install an ep_forward on the experts module so we can + # call block.experts(permuted_tokens, counts, experts_per_rank) via __call__, + # letting FSDP2 manage unshard/reshard automatically. 
+ if is_tensor_experts: + _install_ep_forward(block.experts, experts_per_rank) + + def forward(hidden_states: torch.Tensor, *args, **kwargs): + if args: + raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.') + + orig_shape = hidden_states.shape + if hidden_states.ndim == 3: + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states_2d = hidden_states.view(-1, hidden_dim) + elif hidden_states.ndim == 2: + batch_size, seq_len = 1, hidden_states.shape[0] + hidden_dim = hidden_states.shape[1] + hidden_states_2d = hidden_states + else: + raise ValueError(f'Unsupported hidden_states ndim: {hidden_states.ndim}') + + router_logits, routing_weights, selected_experts = _run_router( + gate=gate, + hidden_states=hidden_states_2d, + top_k=top_k, + router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype), + norm_topk_prob=getattr(block, 'norm_topk_prob', False), + **kwargs, + ) + # Keep routing weights in activation dtype before unpermute weighting. + if routing_weights.dtype != hidden_states_2d.dtype: + routing_weights = routing_weights.to(hidden_states_2d.dtype) + # Build expert_mask: [num_experts, top_k, num_tokens] + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens] + + # 1. preprocess: compute splits and token counts + ( + input_splits, + output_splits, + num_global_tokens_per_local_expert, + num_global_sum_tokens_per_local_expert, + ) = preprocess(expert_mask, num_experts, ep_group) + + # 2. token_pre_all2all: permute → all_to_all → sort_chunks + ( + global_permuted_hidden_states, + local_input_permutation_mapping, + local_assignment_weights, + org_hidden_states_shape, + ) = token_pre_all2all( + hidden_states_2d, + expert_mask, + routing_weights, + num_experts, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + ep_group, + ) + + # 3. 
expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire. + # For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank) + # → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard + # For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__. + if is_tensor_experts: + expert_outputs = block.experts( + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, + ) + else: + expert_outputs = _run_local_experts( + block, + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, + ) + + # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) + final_hidden = tokens_post_all2all( + expert_outputs, + local_assignment_weights, + num_experts, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + local_input_permutation_mapping, + org_hidden_states_shape, + ep_group, + ) + + shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg) + if shared_out is not None: + final_hidden = final_hidden + shared_out + + if len(orig_shape) == 3: + final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) + + if cfg.keep_router_logits and returns_router_logits: + return final_hidden, router_logits + return final_hidden + + block._ep_original_forward = orig_forward + block.forward = forward + block._ep_patched = True + + +def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None: + if getattr(experts_mod, '_ep_forward_installed', False): + return + + def ep_forward( + self, + permuted_tokens: torch.Tensor, + num_global_sum_tokens_per_local_expert: torch.Tensor, + experts_per_rank: int, + ) -> torch.Tensor: + if permuted_tokens.numel() == 0: + # Preserve the autograd edge to token_pre_all2all. Returning a new + # empty tensor can make this rank skip the matching backward + # all-to-all, causing EP collective order divergence. 
+ return permuted_tokens + + input_dtype = permuted_tokens.dtype + + cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) + for i in range(experts_per_rank): + cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) + + output_chunks = [] + for i in range(experts_per_rank): + start = int(cumsum[i].item()) + end = int(cumsum[i + 1].item()) + expert_in = permuted_tokens[start:end] + if expert_in.numel() == 0: + output_chunks.append(expert_in) + continue + + gate_up = self.gate_up_proj[i] + down = self.down_proj[i] + compute_dtype = gate_up.dtype + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + gate_up_out = F.linear(expert_in, gate_up) + if hasattr(self, '_apply_gate'): + out = self._apply_gate(gate_up_out) + else: + gate, up = gate_up_out.chunk(2, dim=-1) + out = self.act_fn(gate) * up + out = F.linear(out, down) + + if out.dtype != input_dtype: + out = out.to(input_dtype) + output_chunks.append(out) + + return torch.cat( + output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) + + import types + experts_mod.forward = types.MethodType(ep_forward, experts_mod) + experts_mod._ep_forward_installed = True + + +def _get_gate(block: nn.Module): + gate = getattr(block, 'gate', None) + if gate is None: + gate = getattr(block, 'router', None) + return gate + + +def _get_num_experts(block: nn.Module) -> int: + if hasattr(block, 'num_experts'): + return int(block.num_experts) + experts = getattr(block, 'experts', None) + if experts is None: + raise ValueError('MoE block has no experts.') + if isinstance(experts, nn.ModuleList): + return len(experts) + if hasattr(experts, 'num_experts'): + return int(experts.num_experts) + if hasattr(experts, 'gate_up_proj'): + return int(experts.gate_up_proj.shape[0]) + raise ValueError('Unable to infer num_experts for MoE block.') + + +def _get_top_k(block: nn.Module) -> int | None: + gate = _get_gate(block) + if gate is not None 
and hasattr(gate, 'top_k'): + value = getattr(gate, 'top_k') + if value is not None: + return int(value) + for name in ('num_experts_per_tok', 'top_k'): + if hasattr(block, name): + value = getattr(block, name) + if value is not None: + return int(value) + return None + + +def _get_router_dtype(router_dtype: str, default_dtype: torch.dtype) -> torch.dtype: + if router_dtype == 'fp32': + return torch.float32 + if router_dtype == 'bf16': + return torch.bfloat16 + if router_dtype == 'fp16': + return torch.float16 + return default_dtype + + +def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, cfg: ExpertParallelConfig): + if cfg.ignore_shared_experts: + return None + shared = getattr(block, 'shared_expert', None) + if shared is None: + shared = getattr(block, 'shared_experts', None) + if shared is None: + return None + return _run_module_with_casting(shared, hidden_states_2d) + + +def _is_moe_experts(experts: Any) -> bool: + if isinstance(experts, nn.ModuleList): + return True + if hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj'): + return True + return False + + +def _shard_tensor_experts(experts: nn.Module, start: int, end: int) -> None: + experts.gate_up_proj = nn.Parameter(experts.gate_up_proj.data[start:end].clone()) + experts.down_proj = nn.Parameter(experts.down_proj.data[start:end].clone()) + if hasattr(experts, 'num_experts'): + experts.num_experts = end - start + + +def _run_local_experts( + block: nn.Module, + permuted_tokens: torch.Tensor, + num_global_sum_tokens_per_local_expert: torch.Tensor, + experts_per_rank: int, +) -> torch.Tensor: + """Run ModuleList experts on permuted tokens via nn.Module.__call__. + Tokens are already grouped by expert (contiguous chunks), sizes given by + num_global_sum_tokens_per_local_expert. No routing weight is applied here; + that happens in unpermute. 
+ """ + if permuted_tokens.numel() == 0: + # Keep the backward path through token_pre_all2all even when this EP + # rank owns no routed tokens for the current block. + return permuted_tokens + + input_dtype = permuted_tokens.dtype + experts = block.experts + + cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) + for i in range(experts_per_rank): + cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) + + output_chunks = [] + for i in range(experts_per_rank): + start = int(cumsum[i].item()) + end = int(cumsum[i + 1].item()) + expert_in = permuted_tokens[start:end] + if expert_in.numel() == 0: + output_chunks.append(expert_in) + continue + + expert = experts[i] + compute_dtype = _module_compute_dtype(expert, input_dtype) + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + out = expert(expert_in) + + if out.dtype != input_dtype: + out = out.to(input_dtype) + output_chunks.append(out) + + return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) + + +def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype: + for param in module.parameters(): + if param.dtype.is_floating_point: + return param.dtype + return default + + +def _run_module_with_casting(module: nn.Module, module_in: torch.Tensor) -> torch.Tensor: + input_dtype = module_in.dtype + compute_dtype = _module_compute_dtype(module, input_dtype) + if compute_dtype != input_dtype: + module_in = module_in.to(compute_dtype) + out = module(module_in) + if out.dtype != input_dtype: + out = out.to(input_dtype) + return out + + +def _run_router( + *, + gate: nn.Module, + hidden_states: torch.Tensor, + top_k: int, + router_dtype: torch.dtype, + norm_topk_prob: bool, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + gate_kwargs = {} + if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'): + gate_kwargs['input_ids'] = 
kwargs['input_ids'] + gate_out = gate(hidden_states, **gate_kwargs) + if isinstance(gate_out, tuple) and len(gate_out) >= 3: + router_logits, routing_weights, selected_experts = gate_out[:3] + return router_logits, routing_weights, selected_experts + + router_logits = gate_out + routing_weights = torch.softmax(router_logits, dim=-1, dtype=router_dtype) + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + if norm_topk_prob: + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + return router_logits, routing_weights, selected_experts + + +def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool: + signature = inspect.signature(module.forward) + for param in signature.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + return kwarg in signature.parameters diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py index 9da5794b..6608a2b8 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -1,363 +1,363 @@ -import torch -import torch.distributed as dist -import torch.nn.functional as F -import warnings -from importlib import import_module -from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available -from typing import Any, Optional, Tuple - -from twinkle.model.transformers.strategy.sequence_parallel.utils import ( - get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard) -from twinkle.patch import Patch - -if is_flash_linear_attention_available(): - from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN - from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE -else: - 
_FLA_CAUSAL_CONV1D_FN = None - _FLA_CHUNK_GATED_DELTA_RULE = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN -else: - _CAUSAL_CONV1D_FN = None - -_SP_LINEAR_KERNEL_FALLBACK_WARNING = ( - 'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention ' - 'sequence parallel. This fallback only supports non-packed sequences.') - - -def _sp_is_enabled(sequence_parallel_context) -> bool: - return bool(sequence_parallel_context is not None and getattr(sequence_parallel_context, 'world_size', 1) > 1) - - -def _get_sp_rank(sequence_parallel_context) -> int: - if not _sp_is_enabled(sequence_parallel_context): - return 0 - if getattr(sequence_parallel_context, '_sp_group', None) is None: - return 0 - return dist.get_rank(group=sequence_parallel_context._sp_group) - - -def _get_local_padding_mask( - attention_mask: torch.Tensor, - local_seq_len: int, - sequence_parallel_context, -) -> torch.Tensor: - if attention_mask.shape[-1] == local_seq_len or not _sp_is_enabled(sequence_parallel_context): - return attention_mask - return sequence_parallel_context.split( - attention_mask, - dim=1, - position_ids=sequence_parallel_context.real_position_ids, - ) - - -def _apply_conv_activation(x: torch.Tensor, activation) -> torch.Tensor: - if activation is None: - return x - if activation in ('silu', 'swish'): - return F.silu(x) - if callable(activation): - return activation(x) - from transformers.activations import ACT2FN - if activation in ACT2FN: - return ACT2FN[activation](x) - raise ValueError(f'Unsupported causal conv activation: {activation!r}') - - -def _ensure_linear_attention_kernels(mod: torch.nn.Module): - """Bind causal_conv1d_fn and chunk_gated_delta_rule for SP forward.""" - - def _torch_causal_conv1d_fn( - *, - x, - weight, - bias=None, - activation=None, - seq_idx=None, - backend=None, - cu_seqlens=None, - ): - # Fallback priority: - # 1. 
flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above. - # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable. - # 3. plain torch conv1d is the final non-packed fallback. - del backend - if cu_seqlens is not None: - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' - 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' - 'Please install flash-linear-attention or disable padding_free/packing.') - if _CAUSAL_CONV1D_FN is not None: - out = _CAUSAL_CONV1D_FN( - x=x.transpose(1, 2).contiguous(), - weight=weight, - bias=bias, - activation=activation, - seq_idx=seq_idx, - ) - if isinstance(out, tuple): - out = out[0] - return out.transpose(1, 2).contiguous() - seq_len = x.shape[1] - x = x.transpose(1, 2).contiguous() - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1]) - out = _apply_conv_activation(out[:, :, :seq_len], activation) - return out.transpose(1, 2).contiguous() - - # NPU: MindSpeed Triton causal_conv1d and chunk_gated_delta_rule - # are both patched by twinkle.kernel.npu_impls.fla at model initialization. - # No need to set them here - they are already bound on the module. 
- if getattr(mod, '_twinkle_npu_patched', False): - return False - - if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None: - mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN - mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE - return False - - modeling_module = import_module(mod.__class__.__module__) - torch_chunk_gated_delta_rule = getattr(modeling_module, 'torch_chunk_gated_delta_rule') - mod.causal_conv1d_fn = _torch_causal_conv1d_fn - mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule - warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2) - return True - - -def _iter_qwen35_gated_delta_net_classes(): - class_specs = ( - ('transformers.models.qwen3_5.modeling_qwen3_5', 'Qwen3_5GatedDeltaNet'), - ('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', 'Qwen3_5MoeGatedDeltaNet'), - ) - for module_name, class_name in class_specs: - try: - modeling_module = import_module(module_name) - yield getattr(modeling_module, class_name) - except Exception: - continue - - -def _get_local_conv_weights( - mod: torch.nn.Module, - *, - sp_rank: int, - local_num_k_heads: int, - local_num_v_heads: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - local_key_dim = local_num_k_heads * mod.head_k_dim - local_value_dim = local_num_v_heads * mod.head_v_dim - conv_weight = mod.conv1d.weight.squeeze(1) - if conv_weight.shape[0] != (2 * mod.key_dim + mod.value_dim): - raise ValueError( - f'Unexpected conv weight dim {conv_weight.shape[0]}, expected {2 * mod.key_dim + mod.value_dim}.') - key_offset = sp_rank * local_key_dim - value_offset = sp_rank * local_value_dim - local_q_weight = conv_weight[key_offset:key_offset + local_key_dim] - local_k_weight = conv_weight[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] - local_v_weight = conv_weight[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] - local_conv_weight = torch.cat([local_q_weight, local_k_weight, local_v_weight], dim=0) - - 
conv_bias = getattr(mod.conv1d, 'bias', None) - if conv_bias is None: - return local_conv_weight, None - local_q_bias = conv_bias[key_offset:key_offset + local_key_dim] - local_k_bias = conv_bias[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] - local_v_bias = conv_bias[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] - return local_conv_weight, torch.cat([local_q_bias, local_k_bias, local_v_bias], dim=0) - - -class Qwen3_5GatedDeltaNetUlyssesPatch(Patch): - - @staticmethod - def _run_forward( - mod: torch.nn.Module, - hidden_states: torch.Tensor, - *, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - sequence_parallel_context=None, - ) -> torch.Tensor: - using_torch_fallback = _ensure_linear_attention_kernels(mod) - modeling_module = import_module(mod.__class__.__module__) - apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states') - - local_attention_mask = attention_mask - if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: - local_attention_mask = _get_local_padding_mask( - attention_mask, - hidden_states.shape[1], - sequence_parallel_context, - ) - hidden_states = apply_mask_to_padding_states(hidden_states, local_attention_mask) - batch_size, seq_len, _ = hidden_states.shape - - has_previous_state = bool(cache_params is not None and getattr(cache_params, 'has_previous_state', False)) - use_precomputed_states = has_previous_state and seq_len == 1 and cache_position is not None - if use_precomputed_states: - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel only supports training/prefill paths; decode with ' - 'cached states is not supported.') - - mixed_qkv = mod.in_proj_qkv(hidden_states) - z = mod.in_proj_z(hidden_states).reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) - b = mod.in_proj_b(hidden_states) - a = mod.in_proj_a(hidden_states) - - sp_enabled = 
_sp_is_enabled(sequence_parallel_context) - if sp_enabled: - sp_world_size = int(sequence_parallel_context.sp_world_size) - if mod.num_k_heads % sp_world_size != 0 or mod.num_v_heads % sp_world_size != 0: - raise RuntimeError( - 'Qwen3.5 linear attention sequence parallel requires sp_world_size to divide both ' - f'linear_num_key_heads ({mod.num_k_heads}) and linear_num_value_heads ({mod.num_v_heads}).') - local_num_k_heads = mod.num_k_heads // sp_world_size - local_num_v_heads = mod.num_v_heads // sp_world_size - q_proj, k_proj, v_proj = torch.split(mixed_qkv, [mod.key_dim, mod.key_dim, mod.value_dim], dim=-1) - q_proj = q_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) - k_proj = k_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) - v_proj = v_proj.reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) - q_proj = seq_to_head_shard(q_proj, sequence_parallel_context) - k_proj = seq_to_head_shard(k_proj, sequence_parallel_context) - v_proj = seq_to_head_shard(v_proj, sequence_parallel_context) - b = seq_to_head_shard(b.reshape(batch_size, seq_len, mod.num_v_heads, 1), - sequence_parallel_context).squeeze(-1) - a = seq_to_head_shard(a.reshape(batch_size, seq_len, mod.num_v_heads, 1), - sequence_parallel_context).squeeze(-1) - seq_after_shard = q_proj.shape[1] - mixed_qkv = torch.cat( - ( - q_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), - k_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), - v_proj.reshape(batch_size, seq_after_shard, local_num_v_heads * mod.head_v_dim), - ), - dim=-1, - ) - sp_rank = _get_sp_rank(sequence_parallel_context) - conv_weight, conv_bias = _get_local_conv_weights( - mod, sp_rank=sp_rank, local_num_k_heads=local_num_k_heads, local_num_v_heads=local_num_v_heads) - else: - local_num_k_heads = mod.num_k_heads - local_num_v_heads = mod.num_v_heads - sp_rank = 0 - b = b.reshape(batch_size, seq_len, mod.num_v_heads) - a = a.reshape(batch_size, 
seq_len, mod.num_v_heads) - conv_weight = mod.conv1d.weight.squeeze(1) - conv_bias = getattr(mod.conv1d, 'bias', None) - - packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context( - sequence_parallel_context, - device=mixed_qkv.device, - ) - extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {}) - if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None: - raise ValueError( - 'Qwen3.5 sequence parallel with padding_free/packed inputs requires packed sequence metadata ' - '(for example valid position_ids).') - if using_torch_fallback and packed_cu_seqlens is not None: - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' - 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' - 'Please install flash-linear-attention or disable padding_free/packing.') - if cache_params is not None: - cache_params.conv_states[mod.layer_idx] = F.pad( - mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0)) - mixed_qkv = mod.causal_conv1d_fn( - x=mixed_qkv, - weight=conv_weight, - bias=conv_bias, - activation=mod.activation, - seq_idx=None, - backend='triton', - cu_seqlens=packed_cu_seqlens, - ) - if isinstance(mixed_qkv, tuple): - mixed_qkv = mixed_qkv[0] - if mixed_qkv.dim() == 2: - mixed_qkv = mixed_qkv.unsqueeze(0) - if mixed_qkv.dim() != 3: - raise ValueError(f'Unexpected conv output dims: {tuple(mixed_qkv.shape)}') - - local_key_dim = local_num_k_heads * mod.head_k_dim - local_value_dim = local_num_v_heads * mod.head_v_dim - query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) - query = query.reshape(batch_size, query.shape[1], local_num_k_heads, mod.head_k_dim) - key = key.reshape(batch_size, key.shape[1], local_num_k_heads, mod.head_k_dim) - value = value.reshape(batch_size, value.shape[1], local_num_v_heads, mod.head_v_dim) - - beta = b.sigmoid() - 
head_slice = slice(sp_rank * local_num_v_heads, - (sp_rank + 1) * local_num_v_heads) if sp_enabled else slice(None) - g = -mod.A_log[head_slice].float().exp() * F.softplus(a.float() + mod.dt_bias[head_slice]) - - if local_num_v_heads // local_num_k_heads > 1: - repeat = local_num_v_heads // local_num_k_heads - query = query.repeat_interleave(repeat, dim=2) - key = key.repeat_interleave(repeat, dim=2) - - chunk_kwargs = { - 'g': g, - 'beta': beta, - 'initial_state': None, - 'output_final_state': cache_params is not None, - 'use_qk_l2norm_in_kernel': True, - } - if packed_cu_seqlens is not None: - chunk_kwargs['cu_seqlens'] = packed_cu_seqlens - core_attn_out, last_recurrent_state = mod.chunk_gated_delta_rule(query, key, value, **chunk_kwargs) - - if cache_params is not None: - cache_params.recurrent_states[mod.layer_idx] = last_recurrent_state - - if sp_enabled: - core_attn_out = head_to_seq_shard(core_attn_out, sequence_parallel_context) - core_attn_out = mod.norm(core_attn_out.reshape(-1, mod.head_v_dim), z.reshape(-1, mod.head_v_dim)) - core_attn_out = core_attn_out.reshape(batch_size, seq_len, local_value_dim if not sp_enabled else mod.value_dim) - return mod.out_proj(core_attn_out) - - def __call__(self, module, *args, **kwargs): - del module, args - sequence_parallel = kwargs.get('sequence_parallel', None) - if sequence_parallel is None: - return - if int(getattr(sequence_parallel, 'rp_world_size', 1) or 1) > 1: - raise NotImplementedError('Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' - '(derived ring attention).') - - for gated_delta_net_cls in _iter_qwen35_gated_delta_net_classes(): - if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): - continue - - origin_forward = gated_delta_net_cls.forward - - def sp_linear_forward( - mod, - hidden_states: torch.Tensor, - cache_params=None, - cache_position=None, - attention_mask: Optional[torch.Tensor] = None, - _origin_forward=origin_forward, - **extra_kwargs, - ): 
- sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) - if not _sp_is_enabled(sequence_parallel_context): - return _origin_forward( - mod, - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - **extra_kwargs, - ) - return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( - mod, - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - sequence_parallel_context=sequence_parallel_context, - ) - - gated_delta_net_cls.forward = sp_linear_forward - gated_delta_net_cls._twinkle_sp_linear_patched = True +import torch +import torch.distributed as dist +import torch.nn.functional as F +import warnings +from importlib import import_module +from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available +from typing import Any, Optional, Tuple + +from twinkle.model.transformers.strategy.sequence_parallel.utils import ( + get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard) +from twinkle.patch import Patch + +if is_flash_linear_attention_available(): + from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN + from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE +else: + _FLA_CAUSAL_CONV1D_FN = None + _FLA_CHUNK_GATED_DELTA_RULE = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN +else: + _CAUSAL_CONV1D_FN = None + +_SP_LINEAR_KERNEL_FALLBACK_WARNING = ( + 'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention ' + 'sequence parallel. 
This fallback only supports non-packed sequences.') + + +def _sp_is_enabled(sequence_parallel_context) -> bool: + return bool(sequence_parallel_context is not None and getattr(sequence_parallel_context, 'world_size', 1) > 1) + + +def _get_sp_rank(sequence_parallel_context) -> int: + if not _sp_is_enabled(sequence_parallel_context): + return 0 + if getattr(sequence_parallel_context, '_sp_group', None) is None: + return 0 + return dist.get_rank(group=sequence_parallel_context._sp_group) + + +def _get_local_padding_mask( + attention_mask: torch.Tensor, + local_seq_len: int, + sequence_parallel_context, +) -> torch.Tensor: + if attention_mask.shape[-1] == local_seq_len or not _sp_is_enabled(sequence_parallel_context): + return attention_mask + return sequence_parallel_context.split( + attention_mask, + dim=1, + position_ids=sequence_parallel_context.real_position_ids, + ) + + +def _apply_conv_activation(x: torch.Tensor, activation) -> torch.Tensor: + if activation is None: + return x + if activation in ('silu', 'swish'): + return F.silu(x) + if callable(activation): + return activation(x) + from transformers.activations import ACT2FN + if activation in ACT2FN: + return ACT2FN[activation](x) + raise ValueError(f'Unsupported causal conv activation: {activation!r}') + + +def _ensure_linear_attention_kernels(mod: torch.nn.Module): + """Bind causal_conv1d_fn and chunk_gated_delta_rule for SP forward.""" + + def _torch_causal_conv1d_fn( + *, + x, + weight, + bias=None, + activation=None, + seq_idx=None, + backend=None, + cu_seqlens=None, + ): + # Fallback priority: + # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above. + # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable. + # 3. plain torch conv1d is the final non-packed fallback. 
+ del backend + if cu_seqlens is not None: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' + 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' + 'Please install flash-linear-attention or disable padding_free/packing.') + if _CAUSAL_CONV1D_FN is not None: + out = _CAUSAL_CONV1D_FN( + x=x.transpose(1, 2).contiguous(), + weight=weight, + bias=bias, + activation=activation, + seq_idx=seq_idx, + ) + if isinstance(out, tuple): + out = out[0] + return out.transpose(1, 2).contiguous() + seq_len = x.shape[1] + x = x.transpose(1, 2).contiguous() + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1]) + out = _apply_conv_activation(out[:, :, :seq_len], activation) + return out.transpose(1, 2).contiguous() + + # NPU: MindSpeed Triton causal_conv1d and chunk_gated_delta_rule + # are both patched by twinkle.kernel.npu_impls.fla at model initialization. + # No need to set them here - they are already bound on the module. 
+ if getattr(mod, '_twinkle_npu_patched', False): + return False + + if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None: + mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN + mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE + return False + + modeling_module = import_module(mod.__class__.__module__) + torch_chunk_gated_delta_rule = getattr(modeling_module, 'torch_chunk_gated_delta_rule') + mod.causal_conv1d_fn = _torch_causal_conv1d_fn + mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule + warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2) + return True + + +def _iter_qwen35_gated_delta_net_classes(): + class_specs = ( + ('transformers.models.qwen3_5.modeling_qwen3_5', 'Qwen3_5GatedDeltaNet'), + ('transformers.models.qwen3_5_moe.modeling_qwen3_5_moe', 'Qwen3_5MoeGatedDeltaNet'), + ) + for module_name, class_name in class_specs: + try: + modeling_module = import_module(module_name) + yield getattr(modeling_module, class_name) + except Exception: + continue + + +def _get_local_conv_weights( + mod: torch.nn.Module, + *, + sp_rank: int, + local_num_k_heads: int, + local_num_v_heads: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + local_key_dim = local_num_k_heads * mod.head_k_dim + local_value_dim = local_num_v_heads * mod.head_v_dim + conv_weight = mod.conv1d.weight.squeeze(1) + if conv_weight.shape[0] != (2 * mod.key_dim + mod.value_dim): + raise ValueError( + f'Unexpected conv weight dim {conv_weight.shape[0]}, expected {2 * mod.key_dim + mod.value_dim}.') + key_offset = sp_rank * local_key_dim + value_offset = sp_rank * local_value_dim + local_q_weight = conv_weight[key_offset:key_offset + local_key_dim] + local_k_weight = conv_weight[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] + local_v_weight = conv_weight[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] + local_conv_weight = torch.cat([local_q_weight, local_k_weight, local_v_weight], dim=0) + + 
conv_bias = getattr(mod.conv1d, 'bias', None) + if conv_bias is None: + return local_conv_weight, None + local_q_bias = conv_bias[key_offset:key_offset + local_key_dim] + local_k_bias = conv_bias[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] + local_v_bias = conv_bias[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] + return local_conv_weight, torch.cat([local_q_bias, local_k_bias, local_v_bias], dim=0) + + +class Qwen3_5GatedDeltaNetUlyssesPatch(Patch): + + @staticmethod + def _run_forward( + mod: torch.nn.Module, + hidden_states: torch.Tensor, + *, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + sequence_parallel_context=None, + ) -> torch.Tensor: + using_torch_fallback = _ensure_linear_attention_kernels(mod) + modeling_module = import_module(mod.__class__.__module__) + apply_mask_to_padding_states = getattr(modeling_module, 'apply_mask_to_padding_states') + + local_attention_mask = attention_mask + if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: + local_attention_mask = _get_local_padding_mask( + attention_mask, + hidden_states.shape[1], + sequence_parallel_context, + ) + hidden_states = apply_mask_to_padding_states(hidden_states, local_attention_mask) + batch_size, seq_len, _ = hidden_states.shape + + has_previous_state = bool(cache_params is not None and getattr(cache_params, 'has_previous_state', False)) + use_precomputed_states = has_previous_state and seq_len == 1 and cache_position is not None + if use_precomputed_states: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel only supports training/prefill paths; decode with ' + 'cached states is not supported.') + + mixed_qkv = mod.in_proj_qkv(hidden_states) + z = mod.in_proj_z(hidden_states).reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) + b = mod.in_proj_b(hidden_states) + a = mod.in_proj_a(hidden_states) + + sp_enabled = 
_sp_is_enabled(sequence_parallel_context) + if sp_enabled: + sp_world_size = int(sequence_parallel_context.sp_world_size) + if mod.num_k_heads % sp_world_size != 0 or mod.num_v_heads % sp_world_size != 0: + raise RuntimeError( + 'Qwen3.5 linear attention sequence parallel requires sp_world_size to divide both ' + f'linear_num_key_heads ({mod.num_k_heads}) and linear_num_value_heads ({mod.num_v_heads}).') + local_num_k_heads = mod.num_k_heads // sp_world_size + local_num_v_heads = mod.num_v_heads // sp_world_size + q_proj, k_proj, v_proj = torch.split(mixed_qkv, [mod.key_dim, mod.key_dim, mod.value_dim], dim=-1) + q_proj = q_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) + k_proj = k_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) + v_proj = v_proj.reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) + q_proj = seq_to_head_shard(q_proj, sequence_parallel_context) + k_proj = seq_to_head_shard(k_proj, sequence_parallel_context) + v_proj = seq_to_head_shard(v_proj, sequence_parallel_context) + b = seq_to_head_shard(b.reshape(batch_size, seq_len, mod.num_v_heads, 1), + sequence_parallel_context).squeeze(-1) + a = seq_to_head_shard(a.reshape(batch_size, seq_len, mod.num_v_heads, 1), + sequence_parallel_context).squeeze(-1) + seq_after_shard = q_proj.shape[1] + mixed_qkv = torch.cat( + ( + q_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), + k_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), + v_proj.reshape(batch_size, seq_after_shard, local_num_v_heads * mod.head_v_dim), + ), + dim=-1, + ) + sp_rank = _get_sp_rank(sequence_parallel_context) + conv_weight, conv_bias = _get_local_conv_weights( + mod, sp_rank=sp_rank, local_num_k_heads=local_num_k_heads, local_num_v_heads=local_num_v_heads) + else: + local_num_k_heads = mod.num_k_heads + local_num_v_heads = mod.num_v_heads + sp_rank = 0 + b = b.reshape(batch_size, seq_len, mod.num_v_heads) + a = a.reshape(batch_size, 
seq_len, mod.num_v_heads) + conv_weight = mod.conv1d.weight.squeeze(1) + conv_bias = getattr(mod.conv1d, 'bias', None) + + packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context( + sequence_parallel_context, + device=mixed_qkv.device, + ) + extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {}) + if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None: + raise ValueError( + 'Qwen3.5 sequence parallel with padding_free/packed inputs requires packed sequence metadata ' + '(for example valid position_ids).') + if using_torch_fallback and packed_cu_seqlens is not None: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires ' + 'flash-linear-attention. The torch fallback only supports non-packed sequences. ' + 'Please install flash-linear-attention or disable padding_free/packing.') + if cache_params is not None: + cache_params.conv_states[mod.layer_idx] = F.pad( + mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0)) + mixed_qkv = mod.causal_conv1d_fn( + x=mixed_qkv, + weight=conv_weight, + bias=conv_bias, + activation=mod.activation, + seq_idx=None, + backend='triton', + cu_seqlens=packed_cu_seqlens, + ) + if isinstance(mixed_qkv, tuple): + mixed_qkv = mixed_qkv[0] + if mixed_qkv.dim() == 2: + mixed_qkv = mixed_qkv.unsqueeze(0) + if mixed_qkv.dim() != 3: + raise ValueError(f'Unexpected conv output dims: {tuple(mixed_qkv.shape)}') + + local_key_dim = local_num_k_heads * mod.head_k_dim + local_value_dim = local_num_v_heads * mod.head_v_dim + query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) + query = query.reshape(batch_size, query.shape[1], local_num_k_heads, mod.head_k_dim) + key = key.reshape(batch_size, key.shape[1], local_num_k_heads, mod.head_k_dim) + value = value.reshape(batch_size, value.shape[1], local_num_v_heads, mod.head_v_dim) + + beta = b.sigmoid() + 
head_slice = slice(sp_rank * local_num_v_heads, + (sp_rank + 1) * local_num_v_heads) if sp_enabled else slice(None) + g = -mod.A_log[head_slice].float().exp() * F.softplus(a.float() + mod.dt_bias[head_slice]) + + if local_num_v_heads // local_num_k_heads > 1: + repeat = local_num_v_heads // local_num_k_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + chunk_kwargs = { + 'g': g, + 'beta': beta, + 'initial_state': None, + 'output_final_state': cache_params is not None, + 'use_qk_l2norm_in_kernel': True, + } + if packed_cu_seqlens is not None: + chunk_kwargs['cu_seqlens'] = packed_cu_seqlens + core_attn_out, last_recurrent_state = mod.chunk_gated_delta_rule(query, key, value, **chunk_kwargs) + + if cache_params is not None: + cache_params.recurrent_states[mod.layer_idx] = last_recurrent_state + + if sp_enabled: + core_attn_out = head_to_seq_shard(core_attn_out, sequence_parallel_context) + core_attn_out = mod.norm(core_attn_out.reshape(-1, mod.head_v_dim), z.reshape(-1, mod.head_v_dim)) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, local_value_dim if not sp_enabled else mod.value_dim) + return mod.out_proj(core_attn_out) + + def __call__(self, module, *args, **kwargs): + del module, args + sequence_parallel = kwargs.get('sequence_parallel', None) + if sequence_parallel is None: + return + if int(getattr(sequence_parallel, 'rp_world_size', 1) or 1) > 1: + raise NotImplementedError('Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' + '(derived ring attention).') + + for gated_delta_net_cls in _iter_qwen35_gated_delta_net_classes(): + if getattr(gated_delta_net_cls, '_twinkle_sp_linear_patched', False): + continue + + origin_forward = gated_delta_net_cls.forward + + def sp_linear_forward( + mod, + hidden_states: torch.Tensor, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + _origin_forward=origin_forward, + **extra_kwargs, + ): 
+ sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) + if not _sp_is_enabled(sequence_parallel_context): + return _origin_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + **extra_kwargs, + ) + return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + sequence_parallel_context=sequence_parallel_context, + ) + + gated_delta_net_cls.forward = sp_linear_forward + gated_delta_net_cls._twinkle_sp_linear_patched = True diff --git a/tests/kernel/npu_impls/test_attention.py b/tests/kernel/npu_impls/test_attention.py index 06fa41ac..ed916dba 100644 --- a/tests/kernel/npu_impls/test_attention.py +++ b/tests/kernel/npu_impls/test_attention.py @@ -1,16 +1,16 @@ -def test_attention_imports(): - from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward - assert callable(npu_sdpa_attention_forward) - - -def test_attention_signature(): - import inspect - - from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward - - sig = inspect.signature(npu_sdpa_attention_forward) - params = list(sig.parameters) - assert params[:5] == ['module', 'query', 'key', 'value', 'attention_mask'] - assert sig.parameters['dropout'].default == 0.0 - assert sig.parameters['scaling'].default is None +def test_attention_imports(): + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + assert callable(npu_sdpa_attention_forward) + + +def test_attention_signature(): + import inspect + + from twinkle.kernel.npu_impls.attention import npu_sdpa_attention_forward + + sig = inspect.signature(npu_sdpa_attention_forward) + params = list(sig.parameters) + assert params[:5] == ['module', 'query', 'key', 'value', 'attention_mask'] + assert sig.parameters['dropout'].default == 0.0 + assert sig.parameters['scaling'].default is None 
assert sig.parameters['is_causal'].default is None \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_fla.py b/tests/kernel/npu_impls/test_fla.py index 86de8c6b..0cfeda1d 100644 --- a/tests/kernel/npu_impls/test_fla.py +++ b/tests/kernel/npu_impls/test_fla.py @@ -1,55 +1,55 @@ -def test_fla_imports(): - from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla - assert callable(apply_qwen3_5_fla) - - -def test_fla_disabled_by_env(monkeypatch): - monkeypatch.setenv('TWINKLE_NPU_FLA', '0') - from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla - # With env=0, function returns 0 (no-op) without raising - assert apply_qwen3_5_fla(None) == 0 - - -def test_fla_skips_when_no_torch_npu(monkeypatch): - import sys - monkeypatch.setenv('TWINKLE_NPU_FLA', '1') - monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import - from twinkle.kernel.npu_impls import fla as fla_mod - # Reload-tolerant: should return 0 when torch_npu is missing. - assert fla_mod.apply_qwen3_5_fla(None) == 0 - - -def test_fla_does_not_flip_flag_when_mindspeed_missing(monkeypatch): - """On an NPU host where the MindSpeed FLA kernel cannot be imported, - ``apply_qwen3_5_fla`` must NOT flip the global ``is_flash_linear_attention_available`` - flag — otherwise HF transformers would route Qwen3.5 onto a FLA fast path - whose kernel is not installed (runtime failure).""" - import sys - import types - - import transformers.utils as tu - - monkeypatch.setenv('TWINKLE_NPU_FLA', '1') - # Fake torch_npu as importable (with a real __spec__ so find_spec doesn't trip) - import importlib.util - spec = importlib.util.spec_from_loader('torch_npu', loader=None) - fake_npu = importlib.util.module_from_spec(spec) - monkeypatch.setitem(sys.modules, 'torch_npu', fake_npu) - # Stub causal_conv1d so the heavy real import chain doesn't run - fake_conv = types.ModuleType('twinkle.kernel.causal_conv1d') - fake_conv.npu_causal_conv1d_fn = object() - monkeypatch.setitem(sys.modules, 
'twinkle.kernel.causal_conv1d', fake_conv) - # Force the MindSpeed-backed module import to fail - monkeypatch.setitem(sys.modules, 'twinkle.kernel.chunk_gated_delta_rule', None) - - original_flag = tu.is_flash_linear_attention_available - try: - from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla - assert apply_qwen3_5_fla(None) == 0 - assert tu.is_flash_linear_attention_available is original_flag, ( - 'is_flash_linear_attention_available was flipped to True while the ' - 'MindSpeed kernel is unavailable — this would break Qwen3.5 at runtime.' - ) - finally: - # Defensive cleanup in case the buggy path ran. +def test_fla_imports(): + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert callable(apply_qwen3_5_fla) + + +def test_fla_disabled_by_env(monkeypatch): + monkeypatch.setenv('TWINKLE_NPU_FLA', '0') + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + # With env=0, function returns 0 (no-op) without raising + assert apply_qwen3_5_fla(None) == 0 + + +def test_fla_skips_when_no_torch_npu(monkeypatch): + import sys + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + monkeypatch.setitem(sys.modules, 'torch_npu', None) # forces ImportError on import + from twinkle.kernel.npu_impls import fla as fla_mod + # Reload-tolerant: should return 0 when torch_npu is missing. 
+ assert fla_mod.apply_qwen3_5_fla(None) == 0 + + +def test_fla_does_not_flip_flag_when_mindspeed_missing(monkeypatch): + """On an NPU host where the MindSpeed FLA kernel cannot be imported, + ``apply_qwen3_5_fla`` must NOT flip the global ``is_flash_linear_attention_available`` + flag — otherwise HF transformers would route Qwen3.5 onto a FLA fast path + whose kernel is not installed (runtime failure).""" + import sys + import types + + import transformers.utils as tu + + monkeypatch.setenv('TWINKLE_NPU_FLA', '1') + # Fake torch_npu as importable (with a real __spec__ so find_spec doesn't trip) + import importlib.util + spec = importlib.util.spec_from_loader('torch_npu', loader=None) + fake_npu = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, 'torch_npu', fake_npu) + # Stub causal_conv1d so the heavy real import chain doesn't run + fake_conv = types.ModuleType('twinkle.kernel.causal_conv1d') + fake_conv.npu_causal_conv1d_fn = object() + monkeypatch.setitem(sys.modules, 'twinkle.kernel.causal_conv1d', fake_conv) + # Force the MindSpeed-backed module import to fail + monkeypatch.setitem(sys.modules, 'twinkle.kernel.chunk_gated_delta_rule', None) + + original_flag = tu.is_flash_linear_attention_available + try: + from twinkle.kernel.npu_impls.fla import apply_qwen3_5_fla + assert apply_qwen3_5_fla(None) == 0 + assert tu.is_flash_linear_attention_available is original_flag, ( + 'is_flash_linear_attention_available was flipped to True while the ' + 'MindSpeed kernel is unavailable — this would break Qwen3.5 at runtime.' + ) + finally: + # Defensive cleanup in case the buggy path ran. 
tu.is_flash_linear_attention_available = original_flag \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_moe.py b/tests/kernel/npu_impls/test_moe.py index e26a5856..34452b61 100644 --- a/tests/kernel/npu_impls/test_moe.py +++ b/tests/kernel/npu_impls/test_moe.py @@ -1,12 +1,12 @@ -def test_moe_imports(): - from twinkle.kernel.npu_impls.moe import ( - GmmFunction, - npu_grouped_mm, - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) - import torch - assert issubclass(GmmFunction, torch.autograd.Function) - assert callable(npu_grouped_mm) - assert callable(npu_packed_moe_experts_forward) +def test_moe_imports(): + from twinkle.kernel.npu_impls.moe import ( + GmmFunction, + npu_grouped_mm, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, + ) + import torch + assert issubclass(GmmFunction, torch.autograd.Function) + assert callable(npu_grouped_mm) + assert callable(npu_packed_moe_experts_forward) assert callable(npu_qwen3_5_moe_sparse_block_forward) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rms_norm.py b/tests/kernel/npu_impls/test_rms_norm.py index a50c6ad0..184d7ef7 100644 --- a/tests/kernel/npu_impls/test_rms_norm.py +++ b/tests/kernel/npu_impls/test_rms_norm.py @@ -1,40 +1,40 @@ -import pytest -import torch -import torch.nn as nn - -try: - import torch_npu # noqa: F401 - _NPU_OK = True -except ImportError: - _NPU_OK = False - - -def test_imports(): - """NpuRMSNorm and npu_gated_rms_norm_forward import without torch_npu.""" - from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward - assert NpuRMSNorm is not None - assert callable(npu_gated_rms_norm_forward) - - -def test_npu_rmsnorm_has_no_init(): - """Class-replacement contract: NpuRMSNorm must not define its own __init__.""" - from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm - # If NpuRMSNorm defines __init__, it'd appear in NpuRMSNorm.__dict__ - assert '__init__' not in 
NpuRMSNorm.__dict__ - - -@pytest.mark.skipif(not _NPU_OK, reason='torch_npu unavailable') -def test_npu_rmsnorm_forward_runs_on_npu(): - from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm - - class _Orig(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.ones(8)) - self.variance_epsilon = 1e-6 - - m = _Orig().to('npu') - m.__class__ = NpuRMSNorm - x = torch.randn(2, 8, device='npu') - y = m(x) +import pytest +import torch +import torch.nn as nn + +try: + import torch_npu # noqa: F401 + _NPU_OK = True +except ImportError: + _NPU_OK = False + + +def test_imports(): + """NpuRMSNorm and npu_gated_rms_norm_forward import without torch_npu.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward + assert NpuRMSNorm is not None + assert callable(npu_gated_rms_norm_forward) + + +def test_npu_rmsnorm_has_no_init(): + """Class-replacement contract: NpuRMSNorm must not define its own __init__.""" + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + # If NpuRMSNorm defines __init__, it'd appear in NpuRMSNorm.__dict__ + assert '__init__' not in NpuRMSNorm.__dict__ + + +@pytest.mark.skipif(not _NPU_OK, reason='torch_npu unavailable') +def test_npu_rmsnorm_forward_runs_on_npu(): + from twinkle.kernel.npu_impls.rms_norm import NpuRMSNorm + + class _Orig(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(8)) + self.variance_epsilon = 1e-6 + + m = _Orig().to('npu') + m.__class__ = NpuRMSNorm + x = torch.randn(2, 8, device='npu') + y = m(x) assert y.shape == (2, 8) \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_rotary.py b/tests/kernel/npu_impls/test_rotary.py index fc15fc54..460d0fc3 100644 --- a/tests/kernel/npu_impls/test_rotary.py +++ b/tests/kernel/npu_impls/test_rotary.py @@ -1,21 +1,21 @@ -def test_rotary_imports(): - from twinkle.kernel.npu_impls.rotary import ( - npu_apply_multimodal_rotary_pos_emb, - 
npu_apply_rotary_pos_emb, - ) - assert callable(npu_apply_rotary_pos_emb) - assert callable(npu_apply_multimodal_rotary_pos_emb) - - -def test_rotary_signature_compat(): - """Signature must match HF apply_rotary_pos_emb so setattr swap is safe.""" - import inspect - - from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb - - sig = inspect.signature(npu_apply_rotary_pos_emb) - params = list(sig.parameters) - assert params[:4] == ['q', 'k', 'cos', 'sin'] - # position_ids and unsqueeze_dim must be optional - assert sig.parameters['position_ids'].default is None +def test_rotary_imports(): + from twinkle.kernel.npu_impls.rotary import ( + npu_apply_multimodal_rotary_pos_emb, + npu_apply_rotary_pos_emb, + ) + assert callable(npu_apply_rotary_pos_emb) + assert callable(npu_apply_multimodal_rotary_pos_emb) + + +def test_rotary_signature_compat(): + """Signature must match HF apply_rotary_pos_emb so setattr swap is safe.""" + import inspect + + from twinkle.kernel.npu_impls.rotary import npu_apply_rotary_pos_emb + + sig = inspect.signature(npu_apply_rotary_pos_emb) + params = list(sig.parameters) + assert params[:4] == ['q', 'k', 'cos', 'sin'] + # position_ids and unsqueeze_dim must be optional + assert sig.parameters['position_ids'].default is None assert sig.parameters['unsqueeze_dim'].default == 1 \ No newline at end of file diff --git a/tests/kernel/npu_impls/test_swiglu.py b/tests/kernel/npu_impls/test_swiglu.py index b3547d7f..d4ec2da9 100644 --- a/tests/kernel/npu_impls/test_swiglu.py +++ b/tests/kernel/npu_impls/test_swiglu.py @@ -1,12 +1,12 @@ -def test_swiglu_imports(): - from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward - assert callable(npu_swiglu_forward) - - -def test_swiglu_signature(): - import inspect - - from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward - - params = list(inspect.signature(npu_swiglu_forward).parameters) +def test_swiglu_imports(): + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + 
assert callable(npu_swiglu_forward) + + +def test_swiglu_signature(): + import inspect + + from twinkle.kernel.npu_impls.swiglu import npu_swiglu_forward + + params = list(inspect.signature(npu_swiglu_forward).parameters) assert params == ['self', 'hidden_state'] \ No newline at end of file diff --git a/tests/kernel/test_builtin.py b/tests/kernel/test_builtin.py index 0bd050ed..38d83915 100644 --- a/tests/kernel/test_builtin.py +++ b/tests/kernel/test_builtin.py @@ -1,90 +1,90 @@ -import importlib.machinery -import sys -import types - -import torch -import torch.nn as nn - -import pytest - - -def _fake_module(name: str): - module = types.ModuleType(name) - module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) - return module - - -def test_npu_builtin_returns_dict(): - from twinkle.kernel.builtin import npu_builtin - bundle = npu_builtin() - assert isinstance(bundle, dict) - assert len(bundle) > 0 - - -def test_npu_builtin_values_are_npu_gated(): - """Every value in npu_builtin() must be wrapped in {'npu': ...} so it's - safely no-op on CUDA/CPU.""" - from twinkle.kernel.builtin import npu_builtin - for key, value in npu_builtin().items(): - assert isinstance(value, dict), f'value for {key!r} is not a device-dict' - assert 'npu' in value, f'value for {key!r} is missing npu entry' - - -def test_npu_builtin_compose_with_user_override(): - """User-supplied keys override the builtin (via plain dict merge).""" - from twinkle.kernel.builtin import npu_builtin - sentinel = object() - merged = {**npu_builtin(), 'fake.module.path.fn': sentinel} - assert merged['fake.module.path.fn'] is sentinel - - -def test_npu_builtin_safe_on_cpu_model(): - """kernelize(cpu_model, npu_builtin()) must not raise and not modify.""" - from twinkle.kernel import kernelize - from twinkle.kernel.builtin import npu_builtin - - m = nn.Sequential(nn.Linear(2, 2)) - pre_type = type(m[0]) - out = kernelize(m, npu_builtin()) - assert out is m - assert type(m[0]) is pre_type # no 
replacement happened (cpu device) - - -def test_npu_builtin_skips_missing_modeling_modules(): - """If transformers.models.qwen3_5 is not installed, the bundle must - still produce a dict (with whatever subset is available).""" - from twinkle.kernel.builtin import npu_builtin - bundle = npu_builtin() # must not raise - assert isinstance(bundle, dict) - - -def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(monkeypatch): - """Calling npu_builtin() on a CUDA/CPU host must not contaminate the - global HF SDPA registry. The NPU impl inverts boolean masks, which is - wrong for non-NPU execution.""" - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - from twinkle.kernel.builtin import npu_builtin - from twinkle.utils.device_mesh import Platform - - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) - original = ALL_ATTENTION_FUNCTIONS.get('sdpa') - npu_builtin() - assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original - - -def test_npu_builtin_skips_side_effects_on_non_npu_platform(monkeypatch): - from twinkle.kernel import builtin - from twinkle.kernel.npu_impls import fla - from twinkle.utils.device_mesh import Platform - - installs = [] - fla_calls = [] - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) - monkeypatch.setitem(sys.modules, 'torch_npu', _fake_module('torch_npu')) - monkeypatch.setattr(builtin, '_install_sdpa', lambda impl: installs.append(impl)) - monkeypatch.setattr(fla, 'apply_qwen3_5_fla', lambda model: fla_calls.append(model)) - - builtin.npu_builtin(nn.Linear(1, 1)) - - assert installs == [] - assert fla_calls == [] +import importlib.machinery +import sys +import types + +import torch +import torch.nn as nn + +import pytest + + +def _fake_module(name: str): + module = types.ModuleType(name) + module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) + return module + + +def test_npu_builtin_returns_dict(): + from 
twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() + assert isinstance(bundle, dict) + assert len(bundle) > 0 + + +def test_npu_builtin_values_are_npu_gated(): + """Every value in npu_builtin() must be wrapped in {'npu': ...} so it's + safely no-op on CUDA/CPU.""" + from twinkle.kernel.builtin import npu_builtin + for key, value in npu_builtin().items(): + assert isinstance(value, dict), f'value for {key!r} is not a device-dict' + assert 'npu' in value, f'value for {key!r} is missing npu entry' + + +def test_npu_builtin_compose_with_user_override(): + """User-supplied keys override the builtin (via plain dict merge).""" + from twinkle.kernel.builtin import npu_builtin + sentinel = object() + merged = {**npu_builtin(), 'fake.module.path.fn': sentinel} + assert merged['fake.module.path.fn'] is sentinel + + +def test_npu_builtin_safe_on_cpu_model(): + """kernelize(cpu_model, npu_builtin()) must not raise and not modify.""" + from twinkle.kernel import kernelize + from twinkle.kernel.builtin import npu_builtin + + m = nn.Sequential(nn.Linear(2, 2)) + pre_type = type(m[0]) + out = kernelize(m, npu_builtin()) + assert out is m + assert type(m[0]) is pre_type # no replacement happened (cpu device) + + +def test_npu_builtin_skips_missing_modeling_modules(): + """If transformers.models.qwen3_5 is not installed, the bundle must + still produce a dict (with whatever subset is available).""" + from twinkle.kernel.builtin import npu_builtin + bundle = npu_builtin() # must not raise + assert isinstance(bundle, dict) + + +def test_npu_builtin_does_not_overwrite_global_sdpa_on_non_npu_host(monkeypatch): + """Calling npu_builtin() on a CUDA/CPU host must not contaminate the + global HF SDPA registry. 
The NPU impl inverts boolean masks, which is + wrong for non-NPU execution.""" + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from twinkle.kernel.builtin import npu_builtin + from twinkle.utils.device_mesh import Platform + + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + original = ALL_ATTENTION_FUNCTIONS.get('sdpa') + npu_builtin() + assert ALL_ATTENTION_FUNCTIONS.get('sdpa') is original + + +def test_npu_builtin_skips_side_effects_on_non_npu_platform(monkeypatch): + from twinkle.kernel import builtin + from twinkle.kernel.npu_impls import fla + from twinkle.utils.device_mesh import Platform + + installs = [] + fla_calls = [] + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cuda')) + monkeypatch.setitem(sys.modules, 'torch_npu', _fake_module('torch_npu')) + monkeypatch.setattr(builtin, '_install_sdpa', lambda impl: installs.append(impl)) + monkeypatch.setattr(fla, 'apply_qwen3_5_fla', lambda model: fla_calls.append(model)) + + builtin.npu_builtin(nn.Linear(1, 1)) + + assert installs == [] + assert fla_calls == [] diff --git a/tests/kernel/test_hub.py b/tests/kernel/test_hub.py index a0a7cf63..e1e2644e 100644 --- a/tests/kernel/test_hub.py +++ b/tests/kernel/test_hub.py @@ -1,54 +1,54 @@ -import pytest - -from twinkle.kernel.core import HubRef, hub - - -def test_hub_with_version(): - ref = hub('kernels-community/activation:SiluAndMul', version=1) - assert isinstance(ref, HubRef) - assert ref.repo_id == 'kernels-community/activation' - assert ref.layer_name == 'SiluAndMul' - assert ref.version == 1 - assert ref.revision is None - assert ref.backend is None - assert ref.trust_remote_code is False - - -def test_hub_with_revision(): - ref = hub('org/repo:Layer', revision='main') - assert ref.revision == 'main' - assert ref.version is None - - -def test_hub_with_backend_and_trust(): - ref = hub('org/repo:Layer', version=2, backend='cuda', trust_remote_code=True) - 
assert ref.backend == 'cuda' - assert ref.trust_remote_code is True - - -def test_hub_rejects_both_revision_and_version(): - with pytest.raises(ValueError, match='Exactly one'): - hub('org/repo:Layer', revision='main', version=1) - - -def test_hub_rejects_neither_revision_nor_version(): - with pytest.raises(ValueError, match='Exactly one'): - hub('org/repo:Layer') - - -def test_hub_rejects_missing_colon(): - with pytest.raises(ValueError, match='repo_id:LayerName'): - hub('org/repo', version=1) - - -def test_hub_handles_colon_in_repo_id(): - # rsplit takes only the last colon - ref = hub('org:sub/repo:Layer', version=1) - assert ref.repo_id == 'org:sub/repo' - assert ref.layer_name == 'Layer' - - -def test_hubref_is_frozen(): - ref = hub('org/repo:Layer', version=1) - with pytest.raises(Exception): +import pytest + +from twinkle.kernel.core import HubRef, hub + + +def test_hub_with_version(): + ref = hub('kernels-community/activation:SiluAndMul', version=1) + assert isinstance(ref, HubRef) + assert ref.repo_id == 'kernels-community/activation' + assert ref.layer_name == 'SiluAndMul' + assert ref.version == 1 + assert ref.revision is None + assert ref.backend is None + assert ref.trust_remote_code is False + + +def test_hub_with_revision(): + ref = hub('org/repo:Layer', revision='main') + assert ref.revision == 'main' + assert ref.version is None + + +def test_hub_with_backend_and_trust(): + ref = hub('org/repo:Layer', version=2, backend='cuda', trust_remote_code=True) + assert ref.backend == 'cuda' + assert ref.trust_remote_code is True + + +def test_hub_rejects_both_revision_and_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer', revision='main', version=1) + + +def test_hub_rejects_neither_revision_nor_version(): + with pytest.raises(ValueError, match='Exactly one'): + hub('org/repo:Layer') + + +def test_hub_rejects_missing_colon(): + with pytest.raises(ValueError, match='repo_id:LayerName'): + hub('org/repo', version=1) + + 
+def test_hub_handles_colon_in_repo_id(): + # rsplit takes only the last colon + ref = hub('org:sub/repo:Layer', version=1) + assert ref.repo_id == 'org:sub/repo' + assert ref.layer_name == 'Layer' + + +def test_hubref_is_frozen(): + ref = hub('org/repo:Layer', version=1) + with pytest.raises(Exception): ref.repo_id = 'other' \ No newline at end of file diff --git a/tests/kernel/test_infer_device.py b/tests/kernel/test_infer_device.py index a0f6fbb9..7f9d5581 100644 --- a/tests/kernel/test_infer_device.py +++ b/tests/kernel/test_infer_device.py @@ -1,29 +1,29 @@ -import torch -import torch.nn as nn - -from twinkle.kernel.core import _infer_device - - -class _NoParamsNoBuffers(nn.Module): - pass - - -class _OnlyBuffer(nn.Module): - def __init__(self): - super().__init__() - self.register_buffer('b', torch.zeros(2)) - - -def test_infer_device_from_parameter(): - m = nn.Linear(2, 3) - assert _infer_device(m) == 'cpu' - - -def test_infer_device_from_buffer_when_no_params(): - m = _OnlyBuffer() - assert _infer_device(m) == 'cpu' - - -def test_infer_device_defaults_to_cpu_when_empty(): - m = _NoParamsNoBuffers() +import torch +import torch.nn as nn + +from twinkle.kernel.core import _infer_device + + +class _NoParamsNoBuffers(nn.Module): + pass + + +class _OnlyBuffer(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer('b', torch.zeros(2)) + + +def test_infer_device_from_parameter(): + m = nn.Linear(2, 3) + assert _infer_device(m) == 'cpu' + + +def test_infer_device_from_buffer_when_no_params(): + m = _OnlyBuffer() + assert _infer_device(m) == 'cpu' + + +def test_infer_device_defaults_to_cpu_when_empty(): + m = _NoParamsNoBuffers() assert _infer_device(m) == 'cpu' \ No newline at end of file diff --git a/tests/kernel/test_kernelize.py b/tests/kernel/test_kernelize.py index 44018fc2..cdb98cae 100644 --- a/tests/kernel/test_kernelize.py +++ b/tests/kernel/test_kernelize.py @@ -1,98 +1,98 @@ -import sys -import types - -import pytest -import torch 
-import torch.nn as nn - -from twinkle.kernel.core import HubRef, kernelize - - -class _SrcLayer(nn.Module): - def __init__(self): - super().__init__() - self.w = nn.Parameter(torch.zeros(1)) - - def forward(self, x): - return x - - -class _DstLayer(nn.Module): - def forward(self, x): - return x + 100 - - -def test_kernelize_class_to_class_replacement(): - parent = nn.Sequential(_SrcLayer(), _SrcLayer()) - out = kernelize(parent, {_SrcLayer: _DstLayer}) - assert out is parent - assert type(parent[0]) is _DstLayer - assert type(parent[1]) is _DstLayer - - -def test_kernelize_empty_mapping_returns_model(): - m = _SrcLayer() - assert kernelize(m, {}) is m - assert type(m) is _SrcLayer - - -def test_kernelize_string_key_calls_setattr(): - mod_name = 'tests.kernel._tmp_kernelize_str' - mod = types.ModuleType(mod_name) - mod.target_fn = lambda x: x - sys.modules[mod_name] = mod - try: - new_fn = lambda x: x * 3 # noqa: E731 - kernelize(nn.Linear(1, 1), {f'{mod_name}.target_fn': new_fn}) - assert mod.target_fn is new_fn - finally: - sys.modules.pop(mod_name, None) - - -def test_kernelize_device_dict_match(monkeypatch): - from twinkle.utils.device_mesh import Platform - - parent = nn.Sequential(_SrcLayer()) - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) - - kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) - - assert type(parent[0]) is _DstLayer - - -def test_kernelize_uses_platform_device_prefix(monkeypatch): - from twinkle.utils.device_mesh import Platform - - parent = nn.Sequential(_SrcLayer()) # params may still be CPU before FSDP placement - monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'npu')) - - kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) - - assert type(parent[0]) is _DstLayer - - -def test_kernelize_device_dict_miss_skips_silently(monkeypatch): - from twinkle.utils.device_mesh import Platform - - parent = nn.Sequential(_SrcLayer()) - 
monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) - - kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) - - assert type(parent[0]) is _SrcLayer - - -def test_kernelize_rejects_unknown_key_type(): - with pytest.raises(TypeError, match='Unsupported mapping key'): - kernelize(nn.Linear(1, 1), {42: _DstLayer}) - - -def test_kernelize_loads_hub_ref(monkeypatch): - # Stand in for HF kernels: patch _load_hub_ref to return _DstLayer - from twinkle.kernel import core as _core - monkeypatch.setattr(_core, '_load_hub_ref', lambda ref: _DstLayer) - - parent = nn.Sequential(_SrcLayer()) - ref = HubRef('org/repo', 'X', revision='main') - kernelize(parent, {_SrcLayer: ref}) - assert type(parent[0]) is _DstLayer +import sys +import types + +import pytest +import torch +import torch.nn as nn + +from twinkle.kernel.core import HubRef, kernelize + + +class _SrcLayer(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return x + + +class _DstLayer(nn.Module): + def forward(self, x): + return x + 100 + + +def test_kernelize_class_to_class_replacement(): + parent = nn.Sequential(_SrcLayer(), _SrcLayer()) + out = kernelize(parent, {_SrcLayer: _DstLayer}) + assert out is parent + assert type(parent[0]) is _DstLayer + assert type(parent[1]) is _DstLayer + + +def test_kernelize_empty_mapping_returns_model(): + m = _SrcLayer() + assert kernelize(m, {}) is m + assert type(m) is _SrcLayer + + +def test_kernelize_string_key_calls_setattr(): + mod_name = 'tests.kernel._tmp_kernelize_str' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 3 # noqa: E731 + kernelize(nn.Linear(1, 1), {f'{mod_name}.target_fn': new_fn}) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_kernelize_device_dict_match(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = 
nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + + kernelize(parent, {_SrcLayer: {'cpu': _DstLayer, 'npu': nn.Identity}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_uses_platform_device_prefix(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) # params may still be CPU before FSDP placement + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'npu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + + assert type(parent[0]) is _DstLayer + + +def test_kernelize_device_dict_miss_skips_silently(monkeypatch): + from twinkle.utils.device_mesh import Platform + + parent = nn.Sequential(_SrcLayer()) + monkeypatch.setattr(Platform, 'device_prefix', staticmethod(lambda platform=None: 'cpu')) + + kernelize(parent, {_SrcLayer: {'npu': _DstLayer}}) + + assert type(parent[0]) is _SrcLayer + + +def test_kernelize_rejects_unknown_key_type(): + with pytest.raises(TypeError, match='Unsupported mapping key'): + kernelize(nn.Linear(1, 1), {42: _DstLayer}) + + +def test_kernelize_loads_hub_ref(monkeypatch): + # Stand in for HF kernels: patch _load_hub_ref to return _DstLayer + from twinkle.kernel import core as _core + monkeypatch.setattr(_core, '_load_hub_ref', lambda ref: _DstLayer) + + parent = nn.Sequential(_SrcLayer()) + ref = HubRef('org/repo', 'X', revision='main') + kernelize(parent, {_SrcLayer: ref}) + assert type(parent[0]) is _DstLayer diff --git a/tests/kernel/test_load_hub_ref.py b/tests/kernel/test_load_hub_ref.py index 96d9c78e..747e3fdc 100644 --- a/tests/kernel/test_load_hub_ref.py +++ b/tests/kernel/test_load_hub_ref.py @@ -1,69 +1,69 @@ -import sys -import types -from unittest.mock import patch - -import pytest - -from twinkle.kernel.core import HubRef, _load_hub_ref - - -def _install_fake_kernels(layer_obj=None, no_layers=False): - """Install a fake `kernels` module with a controllable 
`get_kernel`.""" - fake = types.ModuleType('kernels') - - def fake_get_kernel(repo_id, **kwargs): - m = types.ModuleType('fake_kernel') - if not no_layers: - layers_ns = types.SimpleNamespace() - if layer_obj is not None: - layers_ns.MyLayer = layer_obj - m.layers = layers_ns - return m - - fake.get_kernel = fake_get_kernel - sys.modules['kernels'] = fake - - -def _uninstall_fake_kernels(): - sys.modules.pop('kernels', None) - - -def test_load_hub_ref_returns_layer(): - sentinel = object() - _install_fake_kernels(layer_obj=sentinel) - try: - ref = HubRef('org/repo', 'MyLayer', revision='main') - assert _load_hub_ref(ref) is sentinel - finally: - _uninstall_fake_kernels() - - -def test_load_hub_ref_raises_if_layers_missing(): - _install_fake_kernels(no_layers=True) - try: - ref = HubRef('org/repo', 'MyLayer', revision='main') - with pytest.raises(ValueError, match='does not define any layers'): - _load_hub_ref(ref) - finally: - _uninstall_fake_kernels() - - -def test_load_hub_ref_raises_if_layer_name_missing(): - _install_fake_kernels(layer_obj=None) # MyLayer not present - try: - ref = HubRef('org/repo', 'Missing', revision='main') - with pytest.raises(ValueError, match='not found'): - _load_hub_ref(ref) - finally: - _uninstall_fake_kernels() - - -def test_load_hub_ref_install_hint_when_kernels_missing(): - # Force `import kernels` to fail - sys.modules['kernels'] = None # short-circuits import to ImportError - try: - ref = HubRef('org/repo', 'MyLayer', revision='main') - with pytest.raises(ImportError, match='pip install kernels'): - _load_hub_ref(ref) - finally: +import sys +import types +from unittest.mock import patch + +import pytest + +from twinkle.kernel.core import HubRef, _load_hub_ref + + +def _install_fake_kernels(layer_obj=None, no_layers=False): + """Install a fake `kernels` module with a controllable `get_kernel`.""" + fake = types.ModuleType('kernels') + + def fake_get_kernel(repo_id, **kwargs): + m = types.ModuleType('fake_kernel') + if not 
no_layers: + layers_ns = types.SimpleNamespace() + if layer_obj is not None: + layers_ns.MyLayer = layer_obj + m.layers = layers_ns + return m + + fake.get_kernel = fake_get_kernel + sys.modules['kernels'] = fake + + +def _uninstall_fake_kernels(): + sys.modules.pop('kernels', None) + + +def test_load_hub_ref_returns_layer(): + sentinel = object() + _install_fake_kernels(layer_obj=sentinel) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + assert _load_hub_ref(ref) is sentinel + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layers_missing(): + _install_fake_kernels(no_layers=True) + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ValueError, match='does not define any layers'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_raises_if_layer_name_missing(): + _install_fake_kernels(layer_obj=None) # MyLayer not present + try: + ref = HubRef('org/repo', 'Missing', revision='main') + with pytest.raises(ValueError, match='not found'): + _load_hub_ref(ref) + finally: + _uninstall_fake_kernels() + + +def test_load_hub_ref_install_hint_when_kernels_missing(): + # Force `import kernels` to fail + sys.modules['kernels'] = None # short-circuits import to ImportError + try: + ref = HubRef('org/repo', 'MyLayer', revision='main') + with pytest.raises(ImportError, match='pip install kernels'): + _load_hub_ref(ref) + finally: sys.modules.pop('kernels', None) \ No newline at end of file diff --git a/tests/kernel/test_public_api.py b/tests/kernel/test_public_api.py index 9000a758..f9a17a2a 100644 --- a/tests/kernel/test_public_api.py +++ b/tests/kernel/test_public_api.py @@ -1,22 +1,22 @@ -def test_public_exports_exactly_three_symbols(): - import twinkle.kernel as k - assert sorted(k.__all__) == ['hub', 'kernelize', 'npu_builtin'] - assert callable(k.kernelize) - assert callable(k.npu_builtin) - assert callable(k.hub) - - -def test_no_legacy_symbols(): - """Legacy 
registrar / patch helpers must be gone.""" - import twinkle.kernel as k - legacy = [ - 'kernelize_model', 'register_layer_kernel', 'register_function_kernel', - 'register_kernels', 'register_external_layer', 'apply_npu_patch', - 'apply_npu_fused_ops', 'apply_function_kernel', 'apply_layer_kernel', - 'register_layer_batch', 'register_npu_fused_function_kernels', - 'get_global_layer_registry', 'get_global_function_registry', - 'get_global_external_layer_registry', 'LayerRegistry', - 'ExternalLayerRegistry', 'FunctionRegistry', - ] - for name in legacy: +def test_public_exports_exactly_three_symbols(): + import twinkle.kernel as k + assert sorted(k.__all__) == ['hub', 'kernelize', 'npu_builtin'] + assert callable(k.kernelize) + assert callable(k.npu_builtin) + assert callable(k.hub) + + +def test_no_legacy_symbols(): + """Legacy registrar / patch helpers must be gone.""" + import twinkle.kernel as k + legacy = [ + 'kernelize_model', 'register_layer_kernel', 'register_function_kernel', + 'register_kernels', 'register_external_layer', 'apply_npu_patch', + 'apply_npu_fused_ops', 'apply_function_kernel', 'apply_layer_kernel', + 'register_layer_batch', 'register_npu_fused_function_kernels', + 'get_global_layer_registry', 'get_global_function_registry', + 'get_global_external_layer_registry', 'LayerRegistry', + 'ExternalLayerRegistry', 'FunctionRegistry', + ] + for name in legacy: assert not hasattr(k, name), f'unexpected legacy symbol: {name}' \ No newline at end of file diff --git a/tests/kernel/test_replace.py b/tests/kernel/test_replace.py index 40013f2c..e649b2e3 100644 --- a/tests/kernel/test_replace.py +++ b/tests/kernel/test_replace.py @@ -1,74 +1,74 @@ -import sys -import types - -import torch.nn as nn - -from twinkle.kernel.core import _replace_attr, _replace_class - - -class _Target(nn.Module): - def forward(self, x): - return x - - -class _Impl(nn.Module): - def forward(self, x): - return x + 1 - - -class _SubTarget(_Target): - pass - - -def 
test_replace_class_rewrites_exact_match(): - m = _Target() - parent = nn.Sequential(_Target(), nn.Linear(1, 1)) - _replace_class(parent, _Target, _Impl) - assert type(parent[0]) is _Impl - - -def test_replace_class_skips_subclass(): - parent = nn.Sequential(_SubTarget()) - _replace_class(parent, _Target, _Impl) - # exact match only - _SubTarget should NOT be rewritten - assert type(parent[0]) is _SubTarget - - -def test_replace_class_idempotent(): - m = nn.Sequential(_Target()) - _replace_class(m, _Target, _Impl) - _replace_class(m, _Target, _Impl) # second call must be safe - assert type(m[0]) is _Impl - - -def test_replace_attr_sets_module_attribute(): - mod_name = 'tests.kernel._tmp_replace_attr' - mod = types.ModuleType(mod_name) - mod.target_fn = lambda x: x - sys.modules[mod_name] = mod - try: - new_fn = lambda x: x * 2 # noqa: E731 - _replace_attr(f'{mod_name}.target_fn', new_fn) - assert mod.target_fn is new_fn - finally: - sys.modules.pop(mod_name, None) - - -def test_replace_attr_supports_class_attribute(): - import sys - import types - - mod_name = 'tests.kernel._tmp_class_attr' - mod = types.ModuleType(mod_name) - - class Foo: - def forward(self, x): - return x - mod.Foo = Foo - sys.modules[mod_name] = mod - try: - new_forward = lambda self, x: x + 7 # noqa: E731 - _replace_attr(f'{mod_name}.Foo.forward', new_forward) - assert Foo.forward is new_forward - finally: +import sys +import types + +import torch.nn as nn + +from twinkle.kernel.core import _replace_attr, _replace_class + + +class _Target(nn.Module): + def forward(self, x): + return x + + +class _Impl(nn.Module): + def forward(self, x): + return x + 1 + + +class _SubTarget(_Target): + pass + + +def test_replace_class_rewrites_exact_match(): + m = _Target() + parent = nn.Sequential(_Target(), nn.Linear(1, 1)) + _replace_class(parent, _Target, _Impl) + assert type(parent[0]) is _Impl + + +def test_replace_class_skips_subclass(): + parent = nn.Sequential(_SubTarget()) + _replace_class(parent, 
_Target, _Impl) + # exact match only - _SubTarget should NOT be rewritten + assert type(parent[0]) is _SubTarget + + +def test_replace_class_idempotent(): + m = nn.Sequential(_Target()) + _replace_class(m, _Target, _Impl) + _replace_class(m, _Target, _Impl) # second call must be safe + assert type(m[0]) is _Impl + + +def test_replace_attr_sets_module_attribute(): + mod_name = 'tests.kernel._tmp_replace_attr' + mod = types.ModuleType(mod_name) + mod.target_fn = lambda x: x + sys.modules[mod_name] = mod + try: + new_fn = lambda x: x * 2 # noqa: E731 + _replace_attr(f'{mod_name}.target_fn', new_fn) + assert mod.target_fn is new_fn + finally: + sys.modules.pop(mod_name, None) + + +def test_replace_attr_supports_class_attribute(): + import sys + import types + + mod_name = 'tests.kernel._tmp_class_attr' + mod = types.ModuleType(mod_name) + + class Foo: + def forward(self, x): + return x + mod.Foo = Foo + sys.modules[mod_name] = mod + try: + new_forward = lambda self, x: x + 7 # noqa: E731 + _replace_attr(f'{mod_name}.Foo.forward', new_forward) + assert Foo.forward is new_forward + finally: sys.modules.pop(mod_name, None) \ No newline at end of file diff --git a/tests/kernel/test_resolve_value.py b/tests/kernel/test_resolve_value.py index abc419e7..652783f5 100644 --- a/tests/kernel/test_resolve_value.py +++ b/tests/kernel/test_resolve_value.py @@ -1,48 +1,48 @@ -import torch.nn as nn - -from twinkle.kernel.core import HubRef, _resolve_value - - -class _ImplA(nn.Module): - pass - - -class _ImplB(nn.Module): - pass - - -def test_passthrough_class_value(): - assert _resolve_value(_ImplA, 'cuda') is _ImplA - - -def test_passthrough_callable_value(): - f = lambda x: x # noqa: E731 - assert _resolve_value(f, 'npu') is f - - -def test_passthrough_hubref(): - ref = HubRef('org/repo', 'Layer', revision='main') - assert _resolve_value(ref, 'cuda') is ref - - -def test_device_dict_match(): - val = {'npu': _ImplA, 'cuda': _ImplB} - assert _resolve_value(val, 'npu') is _ImplA - 
assert _resolve_value(val, 'cuda') is _ImplB - - -def test_device_dict_miss_returns_none(): - val = {'npu': _ImplA} - assert _resolve_value(val, 'cuda') is None - - -def test_device_dict_nested(): - # nested dict -> recursive resolve - val = {'npu': {'npu': _ImplA}} - assert _resolve_value(val, 'npu') is _ImplA - - -def test_device_dict_miss_then_passthrough(): - # nested dict whose inner is also a dict that misses -> None - val = {'npu': {'cuda': _ImplA}} +import torch.nn as nn + +from twinkle.kernel.core import HubRef, _resolve_value + + +class _ImplA(nn.Module): + pass + + +class _ImplB(nn.Module): + pass + + +def test_passthrough_class_value(): + assert _resolve_value(_ImplA, 'cuda') is _ImplA + + +def test_passthrough_callable_value(): + f = lambda x: x # noqa: E731 + assert _resolve_value(f, 'npu') is f + + +def test_passthrough_hubref(): + ref = HubRef('org/repo', 'Layer', revision='main') + assert _resolve_value(ref, 'cuda') is ref + + +def test_device_dict_match(): + val = {'npu': _ImplA, 'cuda': _ImplB} + assert _resolve_value(val, 'npu') is _ImplA + assert _resolve_value(val, 'cuda') is _ImplB + + +def test_device_dict_miss_returns_none(): + val = {'npu': _ImplA} + assert _resolve_value(val, 'cuda') is None + + +def test_device_dict_nested(): + # nested dict -> recursive resolve + val = {'npu': {'npu': _ImplA}} + assert _resolve_value(val, 'npu') is _ImplA + + +def test_device_dict_miss_then_passthrough(): + # nested dict whose inner is also a dict that misses -> None + val = {'npu': {'cuda': _ImplA}} assert _resolve_value(val, 'npu') is None \ No newline at end of file From ff67fc1fbee4be9343deec884cf81bc87d3d55fe Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:44:05 +0800 Subject: [PATCH 25/27] lint --- docs/source_en/Components/Kernel/Kernel.md | 2 +- .../\345\206\205\346\240\270/Kernel.md" | 2 +- src/twinkle/kernel/__init__.py | 2 +- src/twinkle/kernel/builtin.py | 48 ++++++++++--------- 
src/twinkle/kernel/core.py | 10 ++-- src/twinkle/kernel/npu_impls/__init__.py | 13 ++--- src/twinkle/kernel/npu_impls/attention.py | 2 +- src/twinkle/kernel/npu_impls/fla.py | 7 ++- src/twinkle/kernel/npu_impls/moe.py | 40 +++++++++------- src/twinkle/kernel/npu_impls/rms_norm.py | 7 +-- src/twinkle/kernel/npu_impls/rotary.py | 2 +- src/twinkle/kernel/npu_impls/swiglu.py | 3 +- 12 files changed, 68 insertions(+), 70 deletions(-) diff --git a/docs/source_en/Components/Kernel/Kernel.md b/docs/source_en/Components/Kernel/Kernel.md index f5ab78e9..fe0a505a 100644 --- a/docs/source_en/Components/Kernel/Kernel.md +++ b/docs/source_en/Components/Kernel/Kernel.md @@ -136,4 +136,4 @@ The legacy `TWINKLE_NPU_PATCH` / `TWINKLE_NPU_FUSED_OPS` / `TWINKLE_NPU_GMM_PATC - `m.__class__ = impl_cls` is Python class-replacement magic. The impl class **must** override only `forward` (and helpers); defining `__init__` is incompatible with the contract - Exact match: `type(m) is target_cls`. Subclasses of `target_cls` are not replaced — add them to the mapping yourself - `kernelize` is idempotent under repeated calls -- There is no `unkernelize` — replacement is one-way \ No newline at end of file +- There is no `unkernelize` — replacement is one-way diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" index 8ed5ad78..17e4d4a2 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" @@ -134,4 +134,4 @@ model = kernelize(model, mapping) - `m.__class__ = impl_cls` 是 Python class 替换魔法。impl 类**必须**只覆盖 `forward`(以及辅助方法),不能定义 `__init__`,否则原 instance 的 attribute 会与 impl 的预期错位 - 精确匹配:`type(m) is target_cls`。继承自 `target_cls` 的子类不会被替换;如需替换,把子类也放进 mapping - 调用 `kernelize` 多次是幂等的(`__class__` 已是 impl 时再设一次无害) -- 没有 `unkernelize`——替换是单向的 \ No newline at end of file +- 没有 
`unkernelize`——替换是单向的 diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py index 5d435c0f..f1499680 100644 --- a/src/twinkle/kernel/__init__.py +++ b/src/twinkle/kernel/__init__.py @@ -10,4 +10,4 @@ from .builtin import npu_builtin from .core import hub, kernelize -__all__ = ['kernelize', 'hub', 'npu_builtin'] \ No newline at end of file +__all__ = ['kernelize', 'hub', 'npu_builtin'] diff --git a/src/twinkle/kernel/builtin.py b/src/twinkle/kernel/builtin.py index 8d7f3b59..cda8813b 100644 --- a/src/twinkle/kernel/builtin.py +++ b/src/twinkle/kernel/builtin.py @@ -13,9 +13,8 @@ from __future__ import annotations import importlib -from typing import Any - import torch.nn as nn +from typing import Any from twinkle import get_logger from twinkle.utils.device_mesh import Platform @@ -34,15 +33,9 @@ def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: """Return the NPU builtin mapping; optionally apply per-instance FLA.""" from .npu_impls.attention import npu_sdpa_attention_forward from .npu_impls.fla import apply_qwen3_5_fla - from .npu_impls.moe import ( - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, - ) + from .npu_impls.moe import npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward from .npu_impls.rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward - from .npu_impls.rotary import ( - npu_apply_multimodal_rotary_pos_emb, - npu_apply_rotary_pos_emb, - ) + from .npu_impls.rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb from .npu_impls.swiglu import npu_swiglu_forward bundle: dict[Any, dict[str, Any]] = {} @@ -60,20 +53,34 @@ def npu_builtin(model: nn.Module | None = None) -> dict[Any, dict[str, Any]]: _add_qwen2_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) _add_qwen3_entries(bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward) _add_qwen3_moe_entries( - bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, 
npu_swiglu_forward, - npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, + bundle, + NpuRMSNorm, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, + npu_packed_moe_experts_forward, + npu_qwen3_5_moe_sparse_block_forward, ) _add_qwen2_5_vl_entries( - bundle, NpuRMSNorm, npu_apply_rotary_pos_emb, npu_swiglu_forward, + bundle, + NpuRMSNorm, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, npu_apply_multimodal_rotary_pos_emb, ) _add_qwen3_5_entries( - bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, + bundle, + NpuRMSNorm, + npu_gated_rms_norm_forward, + npu_apply_rotary_pos_emb, npu_swiglu_forward, ) _add_qwen3_5_moe_entries( - bundle, NpuRMSNorm, npu_gated_rms_norm_forward, npu_apply_rotary_pos_emb, - npu_swiglu_forward, npu_packed_moe_experts_forward, + bundle, + NpuRMSNorm, + npu_gated_rms_norm_forward, + npu_apply_rotary_pos_emb, + npu_swiglu_forward, + npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward, ) @@ -92,10 +99,7 @@ def _install_sdpa(impl) -> None: of ``npu_builtin()``. 
""" try: - from transformers.modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - AttentionInterface, - ) + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionInterface except ImportError: return try: @@ -107,6 +111,7 @@ def _install_sdpa(impl) -> None: # ---- helpers that conditionally add entries based on module availability ---- + def _add_class_if_present(bundle, module_path, class_name, impl_cls): mod = _import_optional(module_path) if mod is None: @@ -192,8 +197,7 @@ def _add_qwen3_5_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn): _add_attr_if_present(bundle, base, 'Qwen3_5GatedRMSNorm.forward', gated_rms_fn) -def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, - experts_fn, sparse_fn): +def _add_qwen3_5_moe_entries(bundle, rms_cls, gated_rms_fn, rope_fn, swiglu_fn, experts_fn, sparse_fn): base = 'transformers.models.qwen3_5_moe.modeling_qwen3_5_moe' if _import_optional(base) is None: return diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index a3a12f18..e6b33ee1 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -6,11 +6,10 @@ from __future__ import annotations import importlib +import torch.nn as nn from dataclasses import dataclass from typing import Any -import torch.nn as nn - from twinkle.utils.device_mesh import Platform @@ -49,7 +48,6 @@ def hub( return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) - def _resolve_value(value: Any, device: str) -> Any | None: """Resolve a mapping value against the selected device. @@ -116,10 +114,8 @@ def _load_hub_ref(ref: HubRef): try: from kernels import get_kernel except ImportError as e: - raise ImportError( - 'Loading a Hub kernel requires the `kernels` package. ' - 'Install it with `pip install kernels`.' - ) from e + raise ImportError('Loading a Hub kernel requires the `kernels` package. 
' + 'Install it with `pip install kernels`.') from e kernel = get_kernel( ref.repo_id, diff --git a/src/twinkle/kernel/npu_impls/__init__.py b/src/twinkle/kernel/npu_impls/__init__.py index 31d77581..47d2a0bf 100644 --- a/src/twinkle/kernel/npu_impls/__init__.py +++ b/src/twinkle/kernel/npu_impls/__init__.py @@ -5,17 +5,12 @@ replacement) or ``setattr(module, attr, fn)`` (function replacement). No impl here is meant to be instantiated directly. """ +from .attention import npu_sdpa_attention_forward +from .fla import apply_qwen3_5_fla +from .moe import GmmFunction, npu_grouped_mm, npu_packed_moe_experts_forward, npu_qwen3_5_moe_sparse_block_forward from .rms_norm import NpuRMSNorm, npu_gated_rms_norm_forward from .rotary import npu_apply_multimodal_rotary_pos_emb, npu_apply_rotary_pos_emb from .swiglu import npu_swiglu_forward -from .attention import npu_sdpa_attention_forward -from .moe import ( - GmmFunction, - npu_grouped_mm, - npu_packed_moe_experts_forward, - npu_qwen3_5_moe_sparse_block_forward, -) -from .fla import apply_qwen3_5_fla __all__ = [ 'NpuRMSNorm', @@ -29,4 +24,4 @@ 'npu_packed_moe_experts_forward', 'npu_qwen3_5_moe_sparse_block_forward', 'apply_qwen3_5_fla', -] \ No newline at end of file +] diff --git a/src/twinkle/kernel/npu_impls/attention.py b/src/twinkle/kernel/npu_impls/attention.py index f328b2d5..c63a858f 100644 --- a/src/twinkle/kernel/npu_impls/attention.py +++ b/src/twinkle/kernel/npu_impls/attention.py @@ -51,4 +51,4 @@ def npu_sdpa_attention_forward( scale=scaling, is_causal=is_causal, ) - return attn_output.transpose(1, 2).contiguous(), None \ No newline at end of file + return attn_output.transpose(1, 2).contiguous(), None diff --git a/src/twinkle/kernel/npu_impls/fla.py b/src/twinkle/kernel/npu_impls/fla.py index d2fc43a9..847832f3 100644 --- a/src/twinkle/kernel/npu_impls/fla.py +++ b/src/twinkle/kernel/npu_impls/fla.py @@ -47,8 +47,8 @@ def apply_qwen3_5_fla(model=None) -> int: # fail to install the kernel, HF transformers would 
route Qwen3.5 onto # a FLA fast path whose kernel is missing -> runtime failure on NPU. try: - from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla from twinkle.kernel.causal_conv1d import npu_causal_conv1d_fn + from twinkle.kernel.chunk_gated_delta_rule import chunk_gated_delta_rule as mindspeed_fla except ImportError as exc: logger.warning('[NPU] [FLA] MindSpeed unavailable: %s', exc) return 0 @@ -88,8 +88,7 @@ def _is_fla_available() -> bool: patched_instances = 0 for _name, _module in root.named_modules(): - if hasattr(_module, 'chunk_gated_delta_rule') and callable( - getattr(_module, 'chunk_gated_delta_rule')): + if hasattr(_module, 'chunk_gated_delta_rule') and callable(getattr(_module, 'chunk_gated_delta_rule')): if _module.chunk_gated_delta_rule is not mindspeed_fla: _module.chunk_gated_delta_rule = mindspeed_fla _module._twinkle_npu_patched = True @@ -100,4 +99,4 @@ def _is_fla_available() -> bool: if patched_instances: logger.info('[NPU] [FLA] Patched %d linear attention instance(s)', patched_instances) - return patched_instances \ No newline at end of file + return patched_instances diff --git a/src/twinkle/kernel/npu_impls/moe.py b/src/twinkle/kernel/npu_impls/moe.py index efa7f71a..1f847669 100644 --- a/src/twinkle/kernel/npu_impls/moe.py +++ b/src/twinkle/kernel/npu_impls/moe.py @@ -15,8 +15,12 @@ def forward(ctx, x: torch.Tensor, group_list: torch.Tensor, weight_ekn: torch.Te group_list = group_list.to(torch.int64) ctx.save_for_backward(x, group_list, weight_ekn) outputs = torch_npu.npu_grouped_matmul( - [x], [weight_ekn], group_list=group_list, - group_type=0, split_item=2, group_list_type=1, + [x], + [weight_ekn], + group_list=group_list, + group_type=0, + split_item=2, + group_list_type=1, ) return outputs[0] @@ -25,14 +29,22 @@ def backward(ctx, grad_output: torch.Tensor): import torch_npu x, group_list, weight_ekn = ctx.saved_tensors grad_input = torch_npu.npu_grouped_matmul( - [grad_output], 
[weight_ekn.transpose(-2, -1).contiguous()], - bias=None, group_list=group_list, - group_type=0, split_item=2, group_list_type=1, + [grad_output], + [weight_ekn.transpose(-2, -1).contiguous()], + bias=None, + group_list=group_list, + group_type=0, + split_item=2, + group_list_type=1, )[0] grad_weight = torch_npu.npu_grouped_matmul( - [x.transpose(0, 1)], [grad_output], - bias=None, group_list=group_list, - group_type=2, split_item=3, group_list_type=1, + [x.transpose(0, 1)], + [grad_output], + bias=None, + group_list=group_list, + group_type=2, + split_item=3, + group_list_type=1, )[0] return grad_input, None, grad_weight.contiguous() @@ -67,20 +79,16 @@ def _normalize_packed_expert_weights(module, input_dtype, hidden_dim): def _get_cached_expert_weights(self, target_dtype, hidden_dim): requires_grad = ( - getattr(self.gate_up_proj, 'requires_grad', False) - or getattr(self.down_proj, 'requires_grad', False) - ) + getattr(self.gate_up_proj, 'requires_grad', False) or getattr(self.down_proj, 'requires_grad', False)) cache_attr = '_npu_expert_cache' if not requires_grad and hasattr(self, cache_attr): cached_dtype, cached_gv, cached_dv, cached = getattr(self, cache_attr) - if (cached_dtype == target_dtype - and cached_gv == self.gate_up_proj._version + if (cached_dtype == target_dtype and cached_gv == self.gate_up_proj._version and cached_dv == self.down_proj._version): return cached weights = _normalize_packed_expert_weights(self, target_dtype, hidden_dim) if not requires_grad: - setattr(self, cache_attr, - (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) + setattr(self, cache_attr, (target_dtype, self.gate_up_proj._version, self.down_proj._version, weights)) return weights @@ -148,4 +156,4 @@ def npu_qwen3_5_moe_sparse_block_forward(self, hidden_states): expert_output = self.experts(flat, selected_experts, routing_weights) expert_output = _add_shared_expert(self, flat, expert_output) - return expert_output.reshape(batch_size, 
sequence_length, hidden_dim) \ No newline at end of file + return expert_output.reshape(batch_size, sequence_length, hidden_dim) diff --git a/src/twinkle/kernel/npu_impls/rms_norm.py b/src/twinkle/kernel/npu_impls/rms_norm.py index ecebdc23..7fd7a5f7 100644 --- a/src/twinkle/kernel/npu_impls/rms_norm.py +++ b/src/twinkle/kernel/npu_impls/rms_norm.py @@ -7,7 +7,6 @@ from __future__ import annotations import os - import torch import torch.nn as nn import torch.nn.functional as F @@ -50,9 +49,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Resolved once at import: matches the legacy "patch-time, process-wide" invariant. # Mid-process env mutation will not retroactively change behavior. -_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ( - '1', 'true', 'on', 'yes' -) +_FORCE_FP32 = os.environ.get('TWINKLE_NPU_GATED_RMSNorm_FP32', '0').lower() in ('1', 'true', 'on', 'yes') def npu_gated_rms_norm_forward(self, hidden_states, gate=None): @@ -72,4 +69,4 @@ def npu_gated_rms_norm_forward(self, hidden_states, gate=None): hidden_states = torch_npu.npu_rms_norm(hidden_states, weight, epsilon=_eps)[0] if gate is not None: hidden_states = hidden_states * F.silu(gate) - return hidden_states.to(input_dtype) \ No newline at end of file + return hidden_states.to(input_dtype) diff --git a/src/twinkle/kernel/npu_impls/rotary.py b/src/twinkle/kernel/npu_impls/rotary.py index 1ed437a3..aa70b1ff 100644 --- a/src/twinkle/kernel/npu_impls/rotary.py +++ b/src/twinkle/kernel/npu_impls/rotary.py @@ -63,4 +63,4 @@ def npu_apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1, ).unsqueeze(unsqueeze_dim) - return _apply_npu_rotary_emb(q, k, cos, sin) \ No newline at end of file + return _apply_npu_rotary_emb(q, k, cos, sin) diff --git a/src/twinkle/kernel/npu_impls/swiglu.py b/src/twinkle/kernel/npu_impls/swiglu.py index c34a7bea..782e16cc 100644 --- 
a/src/twinkle/kernel/npu_impls/swiglu.py +++ b/src/twinkle/kernel/npu_impls/swiglu.py @@ -16,5 +16,4 @@ def npu_swiglu_forward(self, hidden_state): torch_npu.npu_swiglu( torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1, - ) - ) \ No newline at end of file + )) From b06cfe507e07c2f74695a28809c01c633b90d871 Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Tue, 30 Jun 2026 10:06:10 +0800 Subject: [PATCH 26/27] delete --- tests/kernel/test_infer_device.py | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 tests/kernel/test_infer_device.py diff --git a/tests/kernel/test_infer_device.py b/tests/kernel/test_infer_device.py deleted file mode 100644 index 7f9d5581..00000000 --- a/tests/kernel/test_infer_device.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -import torch.nn as nn - -from twinkle.kernel.core import _infer_device - - -class _NoParamsNoBuffers(nn.Module): - pass - - -class _OnlyBuffer(nn.Module): - def __init__(self): - super().__init__() - self.register_buffer('b', torch.zeros(2)) - - -def test_infer_device_from_parameter(): - m = nn.Linear(2, 3) - assert _infer_device(m) == 'cpu' - - -def test_infer_device_from_buffer_when_no_params(): - m = _OnlyBuffer() - assert _infer_device(m) == 'cpu' - - -def test_infer_device_defaults_to_cpu_when_empty(): - m = _NoParamsNoBuffers() - assert _infer_device(m) == 'cpu' \ No newline at end of file From 1b5f8c9d4caf54595a9ad31fde454e1b96cc572f Mon Sep 17 00:00:00 2001 From: weikaiwen <34648228+kevssim@users.noreply.github.com> Date: Tue, 30 Jun 2026 15:28:04 +0800 Subject: [PATCH 27/27] wip --- scripts/kernelize_demo.py | 260 ++++++++++++++++++++++++++++ src/twinkle/kernel/core.py | 5 +- src/twinkle/kernel/csrc/placeholder | 0 tests/kernel/test_hub.py | 6 +- 4 files changed, 263 insertions(+), 8 deletions(-) create mode 100644 scripts/kernelize_demo.py delete mode 100644 src/twinkle/kernel/csrc/placeholder diff 
--git a/scripts/kernelize_demo.py b/scripts/kernelize_demo.py new file mode 100644 index 00000000..aaa19448 --- /dev/null +++ b/scripts/kernelize_demo.py @@ -0,0 +1,260 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""End-to-end demo for ``twinkle.kernel.core.kernelize``. + +Run with:: + + conda run -n twinkle python scripts/kernelize_demo.py + +The script exercises three replacement modes on CPU: + +1. Class replacement - rewrite ``__class__`` of matching ``nn.Module`` instances. +2. Attribute replacement - monkey-patch a module/function attribute via dotted path. +3. Hub replacement - lazy-load a kernel from a mocked ``kernels`` package. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Make the local ``src`` importable when running the script directly. +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(_PROJECT_ROOT / "src") not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT / "src")) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP + +from twinkle.kernel.core import hub, kernelize + + +def _assert(cond: bool, msg: str) -> None: + if not cond: + raise AssertionError(msg) + + +def _describe(obj) -> str: + """Best-effort name for a callable (plain function or kernels ``Func``).""" + qn = getattr(obj, "__qualname__", None) + if qn: + return qn + return f"<{type(obj).__module__}.{type(obj).__name__}>" + + +class FusedQwen3MLP(Qwen3MLP): + """Pretend fused kernel: same gated MLP + a constant +1.0 bias.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + print("[patched] FusedQwen3MLP.forward called") + return super().forward(x) + 1.0 + + +def _build_mlp() -> Qwen3MLP: + config = Qwen3Config( + hidden_size=8, + intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=1, + num_key_value_heads=1, + ) + return Qwen3MLP(config) + + +# 
--------------------------------------------------------------------------- # +# Demo 1: Class replacement +# --------------------------------------------------------------------------- # +def demo_class_replacement() -> None: + print("=" * 60) + print("Demo 1: class replacement (replace transformers Qwen3MLP)") + print("=" * 60) + + mlp = _build_mlp() + x = torch.randn(1, 4, mlp.config.hidden_size) + + out_before = mlp(x) + print(f"Before kernelize: type = {type(mlp).__name__}") + print(f" output[0,0,:3] = {out_before[0, 0, :3].tolist()}") + + # Pass the MLP itself as the model; ``model.modules()`` yields it. + kernelize(mlp, {Qwen3MLP: FusedQwen3MLP}) + + out_after = mlp(x) + print(f"After kernelize: type = {type(mlp).__name__}") + print(f" output[0,0,:3] = {out_after[0, 0, :3].tolist()}") + + _assert(type(mlp) is FusedQwen3MLP, "mlp should be FusedQwen3MLP after kernelize") + # Params (gate_proj/up_proj/down_proj) are preserved on the instance, so the + # only difference is the +1.0 added by the fused forward. 
+ _assert( + torch.allclose(out_after, out_before + 1.0), + "FusedQwen3MLP should add +1.0 to the original output", + ) + print("✓ Class replacement passed\n") + + +# --------------------------------------------------------------------------- # +# Demo 2: Attribute replacement (patch transformers qwen3 apply_rotary_pos_emb) +# --------------------------------------------------------------------------- # +_QWEN3_MOD_PATH = "transformers.models.qwen3.modeling_qwen3" +_ROPE_ATTR = "apply_rotary_pos_emb" + + +def demo_attr_replacement() -> None: + print("=" * 60) + print("Demo 2: attribute replacement (two forms)") + print("=" * 60) + + import importlib + + mod = importlib.import_module(_QWEN3_MOD_PATH) + + # ---- Form A: module attribute (pkg.mod.attr) -------------------------- # + print("-" * 60) + print("Form A: replace module-level function `apply_rotary_pos_emb`") + print("-" * 60) + + original_rope = getattr(mod, _ROPE_ATTR) + + q = torch.ones(1, 2, 4, 8) + k = torch.ones(1, 2, 4, 8) + cos = torch.ones(1, 1, 4, 8) + sin = torch.ones(1, 1, 4, 8) + + q_out_before, k_out_before = original_rope(q, k, cos, sin) + print(f"Before kernelize: {_describe(mod.apply_rotary_pos_emb)}") + + def fused_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): # noqa: ANN001 + return original_rope(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) + + fused_apply_rotary_pos_emb._kernelize_marker = True # type: ignore[attr-defined] + + try: + kernelize(nn.Linear(1, 1), {f"{_QWEN3_MOD_PATH}.{_ROPE_ATTR}": fused_apply_rotary_pos_emb}) + + patched_fn = getattr(mod, _ROPE_ATTR) + print(f"After kernelize: {_describe(patched_fn)}") + _assert(patched_fn is fused_apply_rotary_pos_emb, "module attr should be the fused fn") + q_out_after, k_out_after = patched_fn(q, k, cos, sin) + _assert( + torch.allclose(q_out_after, q_out_before) + and torch.allclose(k_out_after, k_out_before), + "wrapped RoPE should preserve the original output", + ) + print("✓ Form A (module attribute) passed\n") + finally: + 
setattr(mod, _ROPE_ATTR, original_rope) + + # ---- Form B: class attribute / method (pkg.mod.ClassName.attr) ------- # + print("-" * 60) + print("Form B: replace class method `Qwen3MLP.forward`") + print("-" * 60) + + original_forward = Qwen3MLP.forward + + mlp = _build_mlp() + x = torch.randn(1, 4, mlp.config.hidden_size) + out_before = mlp(x) + print(f"Before kernelize: {_describe(Qwen3MLP.forward)}") + print(f" output[0,0,:3] = {out_before[0, 0, :3].tolist()}") + + def fused_forward(self, x): # noqa: ANN001 + return original_forward(self, x) + 1.0 + + try: + # Dotted path lands on the class, then setattr the method on it. + kernelize(nn.Linear(1, 1), {f"{_QWEN3_MOD_PATH}.Qwen3MLP.forward": fused_forward}) + + patched_forward = Qwen3MLP.forward + print(f"After kernelize: {_describe(patched_forward)}") + _assert(patched_forward is fused_forward, "class method should be the fused fn") + + out_after = mlp(x) + print(f" output[0,0,:3] = {out_after[0, 0, :3].tolist()}") + _assert( + torch.allclose(out_after, out_before + 1.0), + "fused forward should add +1.0 to the original output", + ) + print("✓ Form B (class method) passed\n") + finally: + setattr(Qwen3MLP, "forward", original_forward) + + +# --------------------------------------------------------------------------- # +# Demo 3: Hub replacement (real HuggingFace Hub kernel via ``kernels``) +# --------------------------------------------------------------------------- # +# We use the real ``kernels-community/activation`` repo on the HF Hub, which +# ships a ``SiluAndMul`` layer (the SwiGLU activation used by Qwen3MLP). +# +# Note: the Hub kernel's ``forward`` calls a CUDA op, so it cannot *execute* +# on CPU. This demo verifies the parts that DO work on CPU: the kernel is +# downloaded lazily via ``_load_hub_ref`` and the target module's class is +# swapped to the Hub-loaded class. Running the fused forward requires CUDA. 
+_HUB_REPO = "kernels-community/activation" +_HUB_LAYER = "SiluAndMul" + + +class LocalSiluAndMul(nn.Module): + """Pure-torch SwiGLU activation, same interface as the Hub ``SiluAndMul``.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +def demo_hub_replacement() -> None: + print("=" * 60) + print("Demo 3: Hub replacement (real HF Hub kernel: kernels-community/activation)") + print("=" * 60) + + try: + from kernels import get_kernel # noqa: F401 + except ImportError: + print("Skipped: `kernels` package not installed (pip install kernels)") + return + + model = nn.Sequential(LocalSiluAndMul()) + x = torch.randn(1, 4, 16, dtype=torch.float32) + + out_before = model(x) + print(f"Before kernelize: type = {type(model[0]).__name__}") + print(f" output[0,0,:3] = {out_before[0, 0, :3].tolist()}") + + ref = hub(f"{_HUB_REPO}:{_HUB_LAYER}", version=1) + print(f"HubRef: repo_id={ref.repo_id!r}, layer_name={ref.layer_name!r}, version={ref.version}") + + try: + kernelize(model, {LocalSiluAndMul: ref}) + except Exception as e: + print(f"Skipped: could not load Hub kernel ({type(e).__name__}: {e})") + return + + hub_cls = type(model[0]) + print(f"After kernelize: type = {hub_cls.__name__}") + print(f" module = {hub_cls.__module__}") + + _assert(hub_cls.__name__ == _HUB_LAYER, "should be the Hub SiluAndMul class") + _assert( + "activation" in hub_cls.__module__, + "loaded class should come from the Hub activation kernel package", + ) + # The Hub forward is CUDA-only, so we do not execute it on CPU. 
+ print("(Hub kernel forward is CUDA-only; verified download + class swap on CPU)") + print("✓ Hub replacement passed\n") + + +# --------------------------------------------------------------------------- # +# Main +# --------------------------------------------------------------------------- # +def main() -> None: + print("Running kernelize end-to-end demos on CPU...\n") + demo_class_replacement() + demo_attr_replacement() + demo_hub_replacement() + print("All demos passed.") + + +if __name__ == "__main__": + main() diff --git a/src/twinkle/kernel/core.py b/src/twinkle/kernel/core.py index e6b33ee1..bdb12270 100644 --- a/src/twinkle/kernel/core.py +++ b/src/twinkle/kernel/core.py @@ -24,7 +24,6 @@ class HubRef: revision: str | None = None version: int | None = None backend: str | None = None - trust_remote_code: bool = False def hub( @@ -33,7 +32,6 @@ def hub( revision: str | None = None, version: int | None = None, backend: str | None = None, - trust_remote_code: bool = False, ) -> HubRef: """Build a ``HubRef`` for use as a ``kernelize`` mapping value. 
@@ -45,7 +43,7 @@ def hub( if ':' not in ref: raise ValueError(f"Hub ref must be 'repo_id:LayerName', got: {ref!r}") repo_id, layer_name = ref.rsplit(':', 1) - return HubRef(repo_id, layer_name, revision, version, backend, trust_remote_code) + return HubRef(repo_id, layer_name, revision, version, backend) def _resolve_value(value: Any, device: str) -> Any | None: @@ -122,7 +120,6 @@ def _load_hub_ref(ref: HubRef): revision=ref.revision, version=ref.version, backend=ref.backend, - trust_remote_code=ref.trust_remote_code, ) layers = getattr(kernel, 'layers', None) if layers is None: diff --git a/src/twinkle/kernel/csrc/placeholder b/src/twinkle/kernel/csrc/placeholder deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/kernel/test_hub.py b/tests/kernel/test_hub.py index e1e2644e..022f42e9 100644 --- a/tests/kernel/test_hub.py +++ b/tests/kernel/test_hub.py @@ -11,7 +11,6 @@ def test_hub_with_version(): assert ref.version == 1 assert ref.revision is None assert ref.backend is None - assert ref.trust_remote_code is False def test_hub_with_revision(): @@ -20,10 +19,9 @@ def test_hub_with_revision(): assert ref.version is None -def test_hub_with_backend_and_trust(): - ref = hub('org/repo:Layer', version=2, backend='cuda', trust_remote_code=True) +def test_hub_with_backend(): + ref = hub('org/repo:Layer', version=2, backend='cuda') assert ref.backend == 'cuda' - assert ref.trust_remote_code is True def test_hub_rejects_both_revision_and_version():