diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b2a32f4..af94d2c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,15 @@ jobs: - name: Lint run: ruff check . + - name: Guard deprecated Compute usage in runtime paths + run: | + hits="$(rg -n "from grilly import Compute|from \.\.backend\.compute import Compute|from \.compute import Compute|Compute\(" functional nn/module.py utils/tensor_conversion.py --glob "*.py" || true)" + if [ -n "$hits" ]; then + echo "Deprecated Compute usage detected in runtime paths:" + echo "$hits" + exit 1 + fi + test: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index fa4811a..359f627 100644 --- a/.gitignore +++ b/.gitignore @@ -201,3 +201,10 @@ docs/symp_former.md /third_party third_party/BLAKE3 third_party/VulkanMemoryAllocator +docs/BRIDGE_MIGRATION_CHECKLIST.md +docs/GPU_OPTIMIZATION_REVIEW.md +docs/PYTORCH_PARITY_STATUS.md +docs/PYTORCH_PARITY_TASKLIST.md +/build_verify_pybind +/build_verify_pybind2 +docs/MIGRATION_PYTORCH.md diff --git a/CHANGELOG.md b/CHANGELOG.md index ae07d7c..f6f9f8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,32 @@ This changelog follows the spirit of **Keep a Changelog** and uses the terms **A --- +## [Unreleased] + +### Added + +- **`tests/sentencepiece/test_sentencepiece_parity.py`** — Adaptive SentencePiece parity tests vs Hugging Face references (T5/MT5-style tokenizers; auto-skip until `grilly.tokenizers` SentencePiece support is available). +- **`tests/tokenizers/test_gpu_tokenizer_parity.py`** — Adaptive real tokenizer parity tests vs Hugging Face references (auto-skip until `grilly.tokenizers` lands and assets are available). +- **`tests/sentence_transformers/`**, **`tests/transformers_compat/`**, **`tests/converter/`**, **`tests/autograd_chain/`**, **`tests/moe_quant/`** — CI-safe skip-marked scaffold suites to make pre-v1.0 roadmap targets executable and incrementally fillable. +- **`docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md`** — Post-upgrade pre-v1.0 optimization + feature parity execution plan (P0/P1/P2 workstreams, acceptance criteria, and verification commands). +- **`backend/fnn_chain.py`** — `FnnChainRecorder` / `ChainBufferHandle`: batched `linear` / `relu` / `softmax` recording with `read()` / `read_multiple()` for one submit+wait (then one or many downloads); `VulkanCompute.record_commands(fnn_chain=True)` or `VulkanFNN.chain_record()`. See `docs/PERF_DISPATCH.md`. +- **`tests/parity/`** — Numerical parity tests for `grilly.functional` (numpy reference; optional PyTorch `F.linear` / `F.relu` when `torch` is installed). See `tests/parity/README.md`. +- **`tests/parity/test_optimizers_parity.py`** — SGD and Adam (CPU) stepping vs `torch.optim`. +- **`docs/MIGRATION_PYTORCH.md`** — PyTorch → Grilly migration cookbook (device model, functional layout, module backend lifecycle, debugging). +- **`docs/PERF_DISPATCH.md`** — Vulkan dispatch/batching (`record_commands`, async dispatch), Sequential fusion notes, pybind GIL checklist. + +### Changed + +- **`docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md`** — Re-prioritized roadmap to feature-delivery order: GPU tokenizer -> sentence-transformers -> transformers compatibility -> PyTorch→Grilly converter, with supporting throughput track; expanded tokenizer interoperability to include SentencePiece compatibility targets. 
+- **`utils/tensor_conversion.py`**, **`backend/base.py`** — `VulkanTensor.prepare_for_dispatch()` and C++ `Tensor.gpu_handle_if_valid()` binding so GPU-resident tensors reuse buffers in `BufferMixin._prepare_input` without redundant uploads when possible. +- **`docs/PYTORCH_PARITY_TASKLIST.md`**, **`docs/PYTORCH_PARITY_STATUS.md`**, **`docs/GPU_OPTIMIZATION_REVIEW.md`** — Workstream C (kernel/throughput milestone) marked complete with scoped deliverables; follow-ups consolidated under **Workstream C — future**. +- **`docs/api/functional.md`** — PyTorch parity notes and links to migration docs and parity tests. +- **`backend/compute.py`** — `VulkanCompute` exposes `record_commands`, `dispatch_compute`, `dispatch_compute_async`, `wait_async`, `wait_fence`. +- **`backend/fnn.py`** — `_linear_relu_recorded_chain`: Linear→ReLU in one submit when `fused-linear-relu` shader is absent. +- **`cpp/python/bindings_{activations,attention,linear}.cpp`** — GIL release around heavy ops; `require_c_contiguous_float` on `linear` inputs. + +--- + ## [0.5.0] — 2026-03-18 — "GPU-First" ### Added diff --git a/CMakeLists.txt b/CMakeLists.txt index 92f53fd..076d0ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,10 @@ else() endif() # ── pybind11 ───────────────────────────────────────────────────────────── +# Silence "Using compatibility mode for Python, set PYBIND11_FINDPYTHON to NEW/OLD" +if(NOT DEFINED PYBIND11_FINDPYTHON) + set(PYBIND11_FINDPYTHON NEW) +endif() set(PYBIND11_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11") if(EXISTS "${PYBIND11_DIR}/CMakeLists.txt") add_subdirectory("${PYBIND11_DIR}" pybind11) @@ -214,6 +218,9 @@ add_library(grilly_core_lib STATIC cpp/src/ops/perceiver.cpp cpp/src/ops/perceiver_encoder.cpp cpp/src/ops/moqe_train.cpp + cpp/src/ops/moe_forward.cpp + cpp/src/ops/vsa_lm_forward.cpp + cpp/src/ops/prefix_scan.cpp cpp/src/shader_fusion.cpp # ── Experimental ── cpp/src/experimental/paged_latent_pool.cpp @@ -243,6 +250,7 @@ add_library(grilly_core_lib STATIC cpp/src/nn/containers.cpp cpp/src/nn/optimizer.cpp cpp/src/nn/dataloader.cpp + cpp/src/io/grl_checkpoint.cpp ) target_include_directories(grilly_core_lib PUBLIC @@ -299,7 +307,11 @@ pybind11_add_module(grilly_core cpp/python/bindings_siglip.cpp cpp/python/bindings_perceiver.cpp cpp/python/bindings_moqe_train.cpp + cpp/python/bindings_moe.cpp cpp/python/bindings_fusion.cpp + cpp/python/bindings_vsa_lm.cpp + cpp/python/bindings_grl.cpp + cpp/python/bindings_prefix_scan.cpp ) target_link_libraries(grilly_core PRIVATE grilly_core_lib) diff --git a/README.md b/README.md index ec0185a..da9a656 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,8 @@ curl -sSL https://raw.githubusercontent.com/Grillcheese-AI/grilly/main/scripts/i On Colab: ```python -# Recommended: fast mode for Colab (5 min instead of 30) -!wget -qO- https://raw.githubusercontent.com/Grillcheese-AI/grilly/main/scripts/install.sh | bash -s -- --fast +# Recommended: Colab mode (Vulkan 1.3 + fast build + NVIDIA ICD, ~5 min) +!wget -qO- https://raw.githubusercontent.com/Grillcheese-AI/grilly/main/scripts/install.sh | bash -s -- --colab ``` This installs system deps, downloads and builds Vulkan SDK 1.4, compiles the grilly C++ extension, and installs the Python package. The `--fast` flag builds only the components grilly needs (shaderc, loader, headers) and skips validation layers. 
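For the `FnnChainRecorder` / `ChainBufferHandle` changelog entry above, a minimal sketch of the intended call pattern. The method names (`linear`, `relu`, `softmax`, `read`, `read_multiple`) and the `record_commands(fnn_chain=True)` entry point come from the changelog text; the argument order of `linear`, the array shapes, and whether `read_multiple` lives on the recorder are assumptions, not confirmed API:

```python
import numpy as np
from grilly.backend.compute import VulkanCompute

vc = VulkanCompute()
x = np.random.randn(8, 256).astype(np.float32)    # (batch, in_features), illustrative shapes
w = np.random.randn(128, 256).astype(np.float32)  # (out_features, in_features)
b = np.zeros(128, dtype=np.float32)

# Record the whole linear -> relu -> softmax chain; nothing is submitted yet.
chain = vc.record_commands(fnn_chain=True)
h = chain.linear(x, w, b)
h = chain.relu(h)
p = chain.softmax(h)

# One submit + fence wait, then a single download of the final handle...
probs = p.read()
# ...or one submit with several downloads (assumed to live on the recorder):
# hidden, probs = chain.read_multiple([h, p])
```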
diff --git a/__init__.py b/__init__.py index cad6c95..3e0fa0b 100644 --- a/__init__.py +++ b/__init__.py @@ -19,11 +19,29 @@ - Hippocampal transformer with capsule memory """ +# grilly_core..pyd lives in this directory as a sibling .pyd. Editable +# (PEP 660) installs route `import grilly` through a path hook that does NOT +# add this directory to sys.path, so `import grilly_core` would fail and the +# Vulkan probe in backend/base.py would silently report VULKAN_AVAILABLE=False +# even on machines with a perfectly working Vulkan device. Insert the package +# directory once, before any submodule import triggers the probe. +import os as _os +import sys as _sys +_pkg_dir = _os.path.dirname(_os.path.abspath(__file__)) +if _pkg_dir not in _sys.path: + _sys.path.insert(0, _pkg_dir) + import grilly.functional as functional import grilly.nn as nn import grilly.optim as optim import grilly.utils as utils -from grilly.backend.base import VULKAN_AVAILABLE + +from . import torch_api +from grilly.backend.base import ( + VULKAN_AVAILABLE, + VULKAN_PYTHON_BINDINGS_AVAILABLE, + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, +) from grilly.backend.capsule_transformer import ( CapsuleMemory, CapsuleTransformerConfig, @@ -65,6 +83,8 @@ __all__ = [ "VULKAN_AVAILABLE", + "VULKAN_PYTHON_BINDINGS_AVAILABLE", + "VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE", "VulkanCompute", "Compute", "SNNCompute", @@ -79,6 +99,7 @@ "functional", "optim", "utils", + "torch_api", ] # Conditionally add compatibility exports diff --git a/backend/__init__.py b/backend/__init__.py index 814dd4f..b2cb4ab 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -12,7 +12,11 @@ - Bridge operations (continuous ↔ spike) """ -from .base import VULKAN_AVAILABLE +from .base import ( + VULKAN_AVAILABLE, + VULKAN_PYTHON_BINDINGS_AVAILABLE, + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, +) from .capsule_transformer import ( CapsuleMemory, CapsuleTransformerConfig, @@ -26,6 +30,8 @@ __all__ = [ "VULKAN_AVAILABLE", + "VULKAN_PYTHON_BINDINGS_AVAILABLE", + "VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE", "VulkanCompute", "SNNCompute", "VulkanLearning", diff --git a/backend/_bridge.py b/backend/_bridge.py index 2f86017..9b74d53 100644 --- a/backend/_bridge.py +++ b/backend/_bridge.py @@ -11,6 +11,10 @@ result = _bridge.linear(x, weight, bias) if result is not None: return result # else fall through to legacy backend + +Fused MoE (:func:`moe_upload`, :func:`moe_forward`, :func:`moe_backward`, etc.) +uses the same lazy :func:`_get_device` and ``shaders/spv`` load as other bridge +ops—prefer these over calling ``grilly_core.moe_*`` with a separate Device. 
""" import logging @@ -19,6 +23,8 @@ import numpy as np logger = logging.getLogger("grilly.bridge") +_BRIDGE_STRICT = os.getenv("GRILLY_BRIDGE_STRICT", "0").strip().lower() in {"1", "true", "yes", "on"} +_FALLBACK_COUNTS = {} def _maybe_trace(op_name, inputs, output, **kwargs): @@ -31,6 +37,25 @@ def _maybe_trace(op_name, inputs, output, **kwargs): except ImportError: pass + +def _record_fallback(op_name: str, reason: Exception | str | None = None): + """Record fallback events and optionally raise in strict mode.""" + _FALLBACK_COUNTS[op_name] = _FALLBACK_COUNTS.get(op_name, 0) + 1 + if reason is not None: + logger.debug("%s fallback triggered: %s", op_name, reason) + if _BRIDGE_STRICT: + raise RuntimeError(f"GRILLY_BRIDGE_STRICT=1: bridge op '{op_name}' failed: {reason}") + + +def get_fallback_stats() -> dict[str, int]: + """Get fallback counts per bridge op.""" + return dict(_FALLBACK_COUNTS) + + +def reset_fallback_stats(): + """Reset bridge fallback counters.""" + _FALLBACK_COUNTS.clear() + try: import grilly_core as _core @@ -133,6 +158,7 @@ def linear(x, weight, bias=None): return result except Exception as e: logger.debug("linear GPU failed (%s), caller will use CPU fallback", e) + _record_fallback("linear", e) return None @@ -147,9 +173,7 @@ def relu(x): try: return _core.relu(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("relu", e) return None @@ -193,9 +217,7 @@ def gelu(x): try: return _core.gelu(dev, x) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("gelu", e) return None @@ -207,9 +229,7 @@ def silu(x): try: return _core.silu(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("silu", e) return None @@ -221,9 +241,7 @@ def tanh(x): try: return _core.tanh_act(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("tanh", e) return None @@ -240,9 +258,7 @@ def relu_backward(grad_output, input): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(input) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("relu_backward", e) return None @@ -256,9 +272,7 @@ def gelu_backward(grad_output, input): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(input) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("gelu_backward", e) return None @@ -272,9 +286,7 @@ def silu_backward(grad_output, input): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(input) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("silu_backward", e) return None @@ -312,9 +324,7 @@ def lif_step( t_refrac_period, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("lif_step", e) return None @@ -347,9 +357,7 @@ def snn_node_forward( decay_input, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("snn_node_forward", e) return None @@ -368,9 +376,7 @@ def snn_node_backward(grad_spike, h_cache, alpha=2.0, surrogate_type=0, v_thresh v_threshold, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("snn_node_backward", e) return None @@ -393,9 +399,7 @@ def hebbian_learning( weight_decay, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("hebbian_learning", e) return None @@ -430,9 +434,7 @@ def 
stdp_learning( trace_decay, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("stdp_learning", e) return None @@ -465,9 +467,7 @@ def oja_learning(memories, inputs, num_vectors, dim, eta=0.01): eta, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("oja_learning", e) return None @@ -481,9 +481,7 @@ def synapse_filter(x_in, y_state, decay=0.95): dev, _ensure_f32_contiguous(x_in), _ensure_f32_contiguous(y_state), decay ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("synapse_filter", e) return None @@ -538,9 +536,7 @@ def gif_neuron_step( t_refrac_period, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("gif_neuron_step", e) return None @@ -560,9 +556,7 @@ def conv2d(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), dev, x, weight, bias, list(stride), list(padding), list(dilation), groups ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d", e) return None @@ -582,9 +576,7 @@ def conv2d_3x3_gelu(x, weight, bias): _ensure_f32_contiguous(bias), ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d_3x3_gelu", e) return None @@ -596,9 +588,7 @@ def maxpool2x2(x): try: return _core.maxpool2x2(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("maxpool2x2", e) return None @@ -610,9 +600,7 @@ def adaptive_avgpool_3x3(x): try: return _core.adaptive_avgpool_3x3(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("adaptive_avgpool_3x3", e) return None @@ -630,9 +618,7 @@ def layernorm(x, gamma, beta, eps=1e-5): beta = _ensure_f32_contiguous(beta) return _core.layernorm(dev, x, gamma, beta, eps) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("layernorm", e) return None @@ -652,9 +638,7 @@ def layernorm_backward(grad_output, input, gamma, mean, var, eps=1e-5): eps, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("layernorm_backward", e) return None @@ -671,9 +655,7 @@ def rmsnorm(x, weight, eps=1e-5): weight = _ensure_f32_contiguous(weight) return _core.rmsnorm(dev, x, weight, eps) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("rmsnorm", e) return None @@ -690,9 +672,7 @@ def tanh_backward(grad_output, tanh_output): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(tanh_output) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("tanh_backward", e) return None @@ -704,9 +684,7 @@ def softmax(x, dim=-1): try: return _core.softmax(dev, _ensure_f32_contiguous(x), dim) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("softmax", e) return None @@ -720,9 +698,43 @@ def softmax_backward(grad_output, softmax_output): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(softmax_output) ) except Exception as e: + _record_fallback("softmax_backward", e) + return None + + +def mf_softmax(x, dim=-1): + """GPU multiplication-free softmax (ReLU-normalized). 
Returns None on failure.""" + dev = _get_device() + if dev is None: + return None + try: + return _core.mf_softmax(dev, _ensure_f32_contiguous(x), dim) + except Exception as e: + _record_fallback("mf_softmax", e) + return None + + +def mf_softplus(x, beta=1.0): + """GPU algebraic softplus (sqrt form). Returns None on failure.""" + dev = _get_device() + if dev is None: + return None + try: + return _core.mf_softplus(dev, _ensure_f32_contiguous(x), float(beta)) + except Exception as e: + _record_fallback("mf_softplus", e) + return None - logger.debug("GPU op failed: %s", e) +def mf_sigmoid(x): + """GPU rational sigmoid x/(1+|x|). Returns None on failure.""" + dev = _get_device() + if dev is None: + return None + try: + return _core.mf_sigmoid(dev, _ensure_f32_contiguous(x)) + except Exception as e: + _record_fallback("mf_sigmoid", e) return None @@ -742,9 +754,7 @@ def linear_backward(grad_output, input, weights): _ensure_f32_contiguous(weights), ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("linear_backward", e) return None @@ -758,9 +768,7 @@ def dropout(x, random_mask, p=0.5, training=True): dev, _ensure_f32_contiguous(x), _ensure_f32_contiguous(random_mask), p, training ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("dropout", e) return None @@ -786,9 +794,7 @@ def conv2d_backward_input( groups, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d_backward_input", e) return None @@ -819,9 +825,7 @@ def conv2d_backward_weight( has_bias, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d_backward_weight", e) return None @@ -838,9 +842,7 @@ def attention_scores(Q, K, scale=0.0): dev, _ensure_f32_contiguous(Q), _ensure_f32_contiguous(K), scale ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("attention_scores", e) return None @@ -853,9 +855,7 @@ def attention_mask(scores, mask=None, causal=True, mask_value=-1e9): m = _ensure_f32_contiguous(mask) if mask is not None else None return _core.attention_mask(dev, _ensure_f32_contiguous(scores), m, causal, mask_value) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("attention_mask", e) return None @@ -869,9 +869,32 @@ def attention_output(weights, V): dev, _ensure_f32_contiguous(weights), _ensure_f32_contiguous(V) ) except Exception as e: + _record_fallback("attention_output", e) + return None + - logger.debug("GPU op failed: %s", e) +def attention_scores_softmax_output(Q, K, V, scale=0.0): + """Fused attention: scores + softmax + weighted V in one GPU submit. + Requires Q, K, V shaped (B, H, S, D) with identical S (self-attention / equal length). + Returns ``(output, softmax_weights)`` as core tensors, or None on failure. 
+ """ + dev = _get_device() + if dev is None: + return None + fused = getattr(_core, "attention_scores_softmax_output", None) + if fused is None: + return None + try: + return fused( + dev, + _ensure_f32_contiguous(Q), + _ensure_f32_contiguous(K), + _ensure_f32_contiguous(V), + scale, + ) + except Exception as e: + _record_fallback("attention_scores_softmax_output", e) return None @@ -883,9 +906,7 @@ def attention_concat_heads(mh_output): try: return _core.attention_concat_heads(dev, _ensure_f32_contiguous(mh_output)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("attention_concat_heads", e) return None @@ -899,9 +920,7 @@ def rope(x, cos_table=None, sin_table=None, base=10000.0, scaling=1.0): st = _ensure_f32_contiguous(sin_table) if sin_table is not None else None return _core.rope(dev, _ensure_f32_contiguous(x), ct, st, base, scaling) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("rope", e) return None @@ -918,9 +937,7 @@ def flash_attention2(Q, K, V, mask=None, scale=0.0, tile_size_q=64, tile_size_k= m = _ensure_f32_contiguous(mask) if mask is not None else None return _core.flash_attention2(dev, Q, K, V, m, scale, tile_size_q, tile_size_k) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("flash_attention2", e) return None @@ -1036,9 +1053,7 @@ def maxpool2d(x, kernel_size, stride=(2, 2), padding=(0, 0), dilation=(1, 1)): list(dilation), ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("maxpool2d", e) return None @@ -1057,9 +1072,7 @@ def avgpool2d(x, kernel_size, stride=(2, 2), padding=(0, 0), count_include_pad=T count_include_pad, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("avgpool2d", e) return None @@ -1071,9 +1084,7 @@ def mean_pool(x): try: return _core.mean_pool(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("mean_pool", e) return None @@ -1100,9 +1111,7 @@ def batchnorm2d_forward( training, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("batchnorm2d_forward", e) return None @@ -1122,9 +1131,7 @@ def cross_entropy_loss(logits, targets, label_smoothing=0.0): dev, _ensure_f32_contiguous(logits), targets, label_smoothing ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("cross_entropy_loss", e) return None @@ -1139,9 +1146,7 @@ def cross_entropy_backward(logits, targets): targets = targets.astype(np.uint32) return _core.cross_entropy_backward(dev, _ensure_f32_contiguous(logits), targets) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("cross_entropy_backward", e) return None @@ -1181,9 +1186,7 @@ def adam_update( clear_grad, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("adam_update", e) return None @@ -1222,9 +1225,7 @@ def adamw_update( clear_grad, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("adamw_update", e) return None @@ -1242,9 +1243,7 @@ def embedding_lookup(token_ids, embeddings): token_ids = token_ids.astype(np.uint32) return _core.embedding_lookup(dev, token_ids, _ensure_f32_contiguous(embeddings)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("embedding_lookup", e) return None @@ -1288,9 +1287,7 @@ def create_kv_cache( eviction_threshold, ) except Exception as e: - - 
logger.debug("GPU op failed: %s", e) - + _record_fallback("create_kv_cache", e) return None @@ -1305,9 +1302,7 @@ def kv_cache_append(kv_cache, new_keys, new_values): _core.kv_cache_append(dev, kv_cache, new_keys, new_values) return True except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("kv_cache_append", e) return None @@ -1319,9 +1314,7 @@ def kv_cache_decode(kv_cache): try: return _core.kv_cache_decode(dev, kv_cache) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("kv_cache_decode", e) return None @@ -1335,9 +1328,7 @@ def kv_cache_evict_h2o(kv_cache, attention_scores=None, num_evict=0): _core.kv_cache_evict_h2o(dev, kv_cache, scores, num_evict) return True except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("kv_cache_evict_h2o", e) return None @@ -1790,6 +1781,513 @@ def moqe_route_and_gemv(activations, choice, expert_weights, expert_scales, bloc ) +# ── Fused MoE (grilly_core moe_* — same lazy device + shaders as other bridge ops) ── + + +def moe_upload(embed_w, pos_w, expert_ws, router_ws, router_bs, out_w, n_layers, n_experts): + """Upload MoE weights to GPU; returns an opaque integer handle. + + Uses the shared bridge :func:`_get_device` (lazy init + ``shaders/spv`` load). + + Args: + embed_w: float32 (vocab, d_model), C-contiguous. + pos_w: float32 (max_seq, d_model). + expert_ws: list of length ``n_layers * n_experts``, each (d_model, d_model). + router_ws: list of length ``n_layers``, each (n_experts, d_model). + router_bs: list of length ``n_layers``, each (n_experts,). + out_w: float32 (vocab, d_model) — output projection. + n_layers, n_experts: int. + + Returns: + int handle, or ``None`` if the C++ extension / device is unavailable. + """ + if not _NATIVE or not hasattr(_core, "moe_upload"): + return None + dev = _get_device() + if dev is None: + return None + try: + return _core.moe_upload( + dev, + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + expert_ws, + router_ws, + router_bs, + _ensure_f32_contiguous(out_w), + int(n_layers), + int(n_experts), + ) + except Exception as e: + _record_fallback("moe_upload", e) + return None + + +def moe_release(handle): + """Free GPU resources for a handle from :func:`moe_upload`. + + Returns: + ``True`` if released, ``False`` if bridge/device unavailable. + """ + if not _NATIVE or not hasattr(_core, "moe_release"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.moe_release(dev, int(handle)) + return True + except Exception as e: + _record_fallback("moe_release", e) + return False + + +def moe_forward(handle, input_ids): + """Run fused MoE forward; returns logits (seq_len, vocab). + + Args: + handle: int from :func:`moe_upload`. + input_ids: int32 1-D (seq_len,), C-contiguous. + + Returns: + float32 ndarray or ``None`` on failure. + """ + if not _NATIVE or not hasattr(_core, "moe_forward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + try: + return _core.moe_forward(dev, int(handle), ids) + except Exception as e: + _record_fallback("moe_forward", e) + return None + + +def moe_update_weights(handle, embed_w, pos_w, expert_ws, router_ws, router_bs, out_w): + """Re-upload weights in place (same shapes as :func:`moe_upload`). + + Returns: + ``True`` on success, ``False`` on failure. 
+ """ + if not _NATIVE or not hasattr(_core, "moe_update_weights"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.moe_update_weights( + dev, + int(handle), + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + expert_ws, + router_ws, + router_bs, + _ensure_f32_contiguous(out_w), + ) + return True + except Exception as e: + _record_fallback("moe_update_weights", e) + return False + + +def moe_backward(handle, input_ids, grad_logits): + """CPU backward for MoE (gradients dict); uses CPU mirrors from upload. + + Args: + handle: int from :func:`moe_upload`. + input_ids: int32 (seq_len,) — same sequence as forward. + grad_logits: float32 (seq_len, vocab). + + Returns: + dict with keys ``grad_embed``, ``grad_pos``, ``grad_experts``, ``grad_routers_W``, + ``grad_routers_b``, ``grad_out_w``, or ``None`` on failure. + """ + if not _NATIVE or not hasattr(_core, "moe_backward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + g = _ensure_f32_contiguous(grad_logits) + try: + return _core.moe_backward(dev, int(handle), ids, g) + except Exception as e: + _record_fallback("moe_backward", e) + return None + + +# ── Fused VSA-LM (grilly_core vsa_lm_* — AdditionLinear FFN + sign activation) ── + + +def vsa_lm_upload( + embed_w, pos_w, + ffn_up_patterns, ffn_up_biases, ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, out_w, + n_layers, d_model, d_ffn, +): + """Upload VSA-LM weights to GPU; returns an opaque integer handle. + + Uses the shared bridge :func:`_get_device` (lazy init + ``shaders/spv`` load). + + Args: + embed_w: float32 (vocab, d_model). + pos_w: float32 (max_seq, d_model). + ffn_up_patterns: list of (d_ffn, d_model) float32, length n_layers. + ffn_up_biases: list of (d_ffn,) float32, length n_layers. + ffn_down_patterns: list of (d_model, d_ffn) float32, length n_layers. + ffn_down_biases: list of (d_model,) float32, length n_layers. + ln_gammas: list of (d_model,) float32, length n_layers. + ln_betas: list of (d_model,) float32, length n_layers. + out_w: float32 (vocab, d_model). + n_layers, d_model, d_ffn: int. + + Returns: + int handle, or ``None`` if the C++ extension / device is unavailable. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_upload"): + return None + dev = _get_device() + if dev is None: + return None + try: + return _core.vsa_lm_upload( + dev, + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + ffn_up_patterns, ffn_up_biases, + ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, + _ensure_f32_contiguous(out_w), + int(n_layers), int(d_model), int(d_ffn), + ) + except Exception as e: + _record_fallback("vsa_lm_upload", e) + return None + + +def vsa_lm_release(handle): + """Free GPU resources for a handle from :func:`vsa_lm_upload`. + + Returns: + ``True`` if released, ``False`` if bridge/device unavailable. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_release"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.vsa_lm_release(dev, int(handle)) + return True + except Exception as e: + _record_fallback("vsa_lm_release", e) + return False + + +def vsa_lm_forward(handle, input_ids): + """Run fused VSA-LM forward; returns logits (seq_len, vocab). + + Args: + handle: int from :func:`vsa_lm_upload`. + input_ids: int32 1-D (seq_len,), C-contiguous. 
+ + Returns: + float32 ndarray or ``None`` on failure. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_forward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + try: + return _core.vsa_lm_forward(dev, int(handle), ids) + except Exception as e: + _record_fallback("vsa_lm_forward", e) + return None + + +def vsa_lm_backward(handle, input_ids, grad_logits): + """CPU backward for VSA-LM (gradients dict); uses CPU mirrors from upload. + + Args: + handle: int from :func:`vsa_lm_upload`. + input_ids: int32 (seq_len,) — same sequence as forward. + grad_logits: float32 (seq_len, vocab). + + Returns: + dict with keys ``grad_embed``, ``grad_pos``, ``grad_out_w``, + ``grad_ffn_up_w``, ``grad_ffn_up_b``, ``grad_ffn_down_w``, + ``grad_ffn_down_b``, ``grad_ln_gamma``, ``grad_ln_beta``, + or ``None`` on failure. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_backward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + g = _ensure_f32_contiguous(grad_logits) + try: + return _core.vsa_lm_backward(dev, int(handle), ids, g) + except Exception as e: + _record_fallback("vsa_lm_backward", e) + return None + + +def vsa_lm_update_weights( + handle, embed_w, pos_w, + ffn_up_patterns, ffn_up_biases, ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, out_w, +): + """Re-upload VSA-LM weights in place (same shapes as :func:`vsa_lm_upload`). + + Returns: + ``True`` on success, ``False`` on failure. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_update_weights"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.vsa_lm_update_weights( + dev, int(handle), + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + ffn_up_patterns, ffn_up_biases, + ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, + _ensure_f32_contiguous(out_w), + ) + return True + except Exception as e: + _record_fallback("vsa_lm_update_weights", e) + return False + + +# ── Bridge support for functional modules without native kernels ────────── + + +def faiss_distance(query, vectors, distance_type="l2"): + """Compute FAISS-style distances with numpy fallback.""" + q = np.asarray(query, dtype=np.float32) + v = np.asarray(vectors, dtype=np.float32) + if q.ndim == 1: + q = q.reshape(1, -1) + if distance_type == "l2": + diff = q[:, None, :] - v[None, :, :] + return np.sqrt(np.sum(diff * diff, axis=2, dtype=np.float32), dtype=np.float32) + if distance_type == "cosine": + q_norm = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-8) + v_norm = v / (np.linalg.norm(v, axis=1, keepdims=True) + 1e-8) + return (1.0 - (q_norm @ v_norm.T)).astype(np.float32) + return (-(q @ v.T)).astype(np.float32) + + +def faiss_topk(distances, k): + """Select top-k nearest neighbors from a distance matrix.""" + d = np.asarray(distances, dtype=np.float32) + if d.ndim == 1: + d = d.reshape(1, -1) + k = max(1, min(int(k), d.shape[1])) + idx = np.argpartition(d, kth=k - 1, axis=1)[:, :k] + vals = np.take_along_axis(d, idx, axis=1) + order = np.argsort(vals, axis=1) + idx = np.take_along_axis(idx, order, axis=1).astype(np.int32) + vals = np.take_along_axis(vals, order, axis=1).astype(np.float32) + return idx, vals + + +def memory_read(queries, memory_keys, memory_values, temperature=None): + """Memory read via 
attention-style lookup.""" + q = np.asarray(queries, dtype=np.float32) + k = np.asarray(memory_keys, dtype=np.float32) + v = np.asarray(memory_values, dtype=np.float32) + temp = float(temperature) if temperature is not None else float(np.sqrt(k.shape[-1])) + scores = (q @ k.T) / max(temp, 1e-8) + scores = scores - np.max(scores, axis=-1, keepdims=True) + weights = np.exp(scores) + weights = weights / np.sum(weights, axis=-1, keepdims=True) + return (weights @ v).astype(np.float32) + + +def memory_write(new_key, new_value, memory_keys, memory_values, write_index, write_mode=0, blend_factor=0.5): + """Memory write for key/value banks.""" + keys = np.array(memory_keys, dtype=np.float32, copy=True) + vals = np.array(memory_values, dtype=np.float32, copy=True) + idx = int(write_index) + if write_mode == 1: + a = float(blend_factor) + keys[idx] = (1.0 - a) * keys[idx] + a * np.asarray(new_key, dtype=np.float32) + vals[idx] = (1.0 - a) * vals[idx] + a * np.asarray(new_value, dtype=np.float32) + else: + keys[idx] = np.asarray(new_key, dtype=np.float32) + vals[idx] = np.asarray(new_value, dtype=np.float32) + return keys, vals + + +def memory_query_pooling(x, w_query, b_query): + """Pool sequence to query vectors by mean-pool + projection.""" + x_arr = np.asarray(x, dtype=np.float32) + pooled = np.mean(x_arr, axis=1) + return (pooled @ np.asarray(w_query, dtype=np.float32).T + np.asarray(b_query, dtype=np.float32)).astype(np.float32) + + +def memory_inject_gate(x, memory_context, w_gate, b_gate, w_mem_proj): + """Gate between token state and projected memory context.""" + x_arr = np.asarray(x, dtype=np.float32) + mem = np.asarray(memory_context, dtype=np.float32) + batch, seq_len, dim = x_arr.shape + mem_expand = np.broadcast_to(mem[:, None, :], (batch, seq_len, dim)) + concat = np.concatenate([x_arr, mem_expand], axis=-1) + gate = 1.0 / (1.0 + np.exp(-(concat @ np.asarray(w_gate, dtype=np.float32).T + np.asarray(b_gate, dtype=np.float32)))) + mem_proj = mem @ np.asarray(w_mem_proj, dtype=np.float32).T + mem_proj = np.broadcast_to(mem_proj[:, None, :], (batch, seq_len, dim)) + return ((1.0 - gate) * x_arr + gate * mem_proj).astype(np.float32) + + +def fisher_info_update(gradients, fisher, momentum=0.9, use_ema=True, reset=False): + """Update Fisher information estimate.""" + g = np.asarray(gradients, dtype=np.float32) + f = np.zeros_like(g, dtype=np.float32) if reset else np.asarray(fisher, dtype=np.float32) + g2 = g * g + if use_ema: + return (float(momentum) * f + (1.0 - float(momentum)) * g2).astype(np.float32) + return (f + g2).astype(np.float32) + + +def ewc_penalty(current_params, important_params, fisher, lambda_ewc=1.0): + """Compute scalar EWC penalty.""" + cur = np.asarray(current_params, dtype=np.float32) + imp = np.asarray(important_params, dtype=np.float32) + f = np.asarray(fisher, dtype=np.float32) + diff = cur - imp + return float(0.5 * float(lambda_ewc) * np.sum(f * diff * diff, dtype=np.float32)) + + +def natural_gradient(gradients, fisher, eps=1e-8): + """Approximate natural gradient using diagonal Fisher.""" + g = np.asarray(gradients, dtype=np.float32) + f = np.asarray(fisher, dtype=np.float32) + return (g / (f + float(eps))).astype(np.float32) + + +def nlms_predict(x, w, bias=0.0): + """NLMS prediction.""" + return float(np.dot(np.asarray(x, dtype=np.float32), np.asarray(w, dtype=np.float32)) + float(bias)) + + +def nlms_update(x, y_true, w, bias, mu=0.5, eps=1e-6): + """NLMS weight update.""" + x_arr = np.asarray(x, dtype=np.float32) + w_arr = np.asarray(w, dtype=np.float32) + 
pred = float(np.dot(x_arr, w_arr) + float(bias)) + err = float(y_true) - pred + denom = float(np.dot(x_arr, x_arr) + float(eps)) + step = float(mu) * err / denom + w_new = (w_arr + step * x_arr).astype(np.float32) + b_new = float(bias + step) + return w_new, b_new + + +def whitening_transform(data, mean=None, std=None): + """Compute/apply whitening transform.""" + x = np.asarray(data, dtype=np.float32) + m = np.asarray(mean, dtype=np.float32) if mean is not None else np.mean(x, axis=0) + s = np.asarray(std, dtype=np.float32) if std is not None else np.std(x, axis=0) + out = (x - m) / (s + 1e-8) + return out.astype(np.float32), m.astype(np.float32), s.astype(np.float32) + + +def place_cell(agent_position, field_centers, field_width=1.0, max_rate=20.0, baseline_rate=0.1): + """Gaussian place-cell tuning curves.""" + pos = np.asarray(agent_position, dtype=np.float32) + centers = np.asarray(field_centers, dtype=np.float32) + if pos.ndim == 1: + pos = pos[None, :] + diff = pos[:, None, :] - centers[None, :, :] + d2 = np.sum(diff * diff, axis=-1) + rates = float(baseline_rate) + (float(max_rate) - float(baseline_rate)) * np.exp( + -d2 / (2.0 * float(field_width) * float(field_width) + 1e-8) + ) + return rates.astype(np.float32).squeeze(0) if np.asarray(agent_position).ndim == 1 else rates.astype(np.float32) + + +def time_cell(current_time, preferred_times, time_constant=1.0, max_rate=15.0, baseline_rate=0.1, membrane_state=None): + """Temporal Gaussian tuning with optional EMA membrane state.""" + t = float(current_time) + pref = np.asarray(preferred_times, dtype=np.float32) + rates = float(baseline_rate) + (float(max_rate) - float(baseline_rate)) * np.exp( + -((t - pref) ** 2) / (2.0 * float(time_constant) * float(time_constant) + 1e-8) + ) + if membrane_state is None: + mem = rates.astype(np.float32) + else: + mem = (0.9 * np.asarray(membrane_state, dtype=np.float32) + 0.1 * rates).astype(np.float32) + return rates.astype(np.float32), mem + + +def continuous_to_spikes(continuous, num_timesteps=10, encoding_type=0, projection_weights=None, projection_bias=None): + """Convert continuous vectors to spike trains.""" + x = np.asarray(continuous, dtype=np.float32) + if x.ndim == 1: + x = x[None, :] + if projection_weights is not None: + x = x @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + x = x + np.asarray(projection_bias, dtype=np.float32) + x = np.clip(x, 0.0, 1.0) + t = int(num_timesteps) + if encoding_type == 1: + spikes = np.zeros((x.shape[0], t, x.shape[1]), dtype=np.float32) + idx = np.minimum((1.0 - x) * (t - 1), t - 1).astype(np.int32) + b_idx = np.arange(x.shape[0])[:, None] + f_idx = np.arange(x.shape[1])[None, :] + spikes[b_idx, idx, f_idx] = 1.0 + return spikes + if encoding_type == 2: + phase = np.linspace(0.0, 2.0 * np.pi, t, endpoint=False, dtype=np.float32) + return ((np.sin(phase[None, :, None] + x[:, None, :] * 2.0 * np.pi) > 0).astype(np.float32)) + return (np.random.rand(x.shape[0], t, x.shape[1]).astype(np.float32) < x[:, None, :]).astype(np.float32) + + +def spikes_to_continuous(spikes, encoding_type=0, time_window=5, temporal_weights=None, projection_weights=None, projection_bias=None): + """Convert spike trains to continuous representations.""" + s = np.asarray(spikes, dtype=np.float32) + if s.ndim != 3: + raise ValueError(f"spikes must be 3D (batch, time, features), got shape={s.shape}") + if encoding_type == 1: + if temporal_weights is None: + tw = np.linspace(1.0, 0.0, s.shape[1], dtype=np.float32) + else: + tw = 
np.asarray(temporal_weights, dtype=np.float32) + tw = tw / (np.sum(tw) + 1e-8) + cont = np.sum(s * tw[None, :, None], axis=1) + elif encoding_type == 2: + phase = np.linspace(0.0, 2.0 * np.pi, s.shape[1], endpoint=False, dtype=np.float32) + cont = np.sum(s * np.sin(phase)[None, :, None], axis=1) / max(s.shape[1], 1) + else: + window = max(1, min(int(time_window), s.shape[1])) + cont = np.mean(s[:, -window:, :], axis=1) + if projection_weights is not None: + cont = cont @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + cont = cont + np.asarray(projection_bias, dtype=np.float32) + return cont.astype(np.float32) + + def q_similarity(queries): """Compute q-similarity (TAPPA metric) for attention queries. diff --git a/backend/attention.py b/backend/attention.py index a9a7952..176dea5 100644 --- a/backend/attention.py +++ b/backend/attention.py @@ -8,7 +8,7 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE from .shader_registry import get_shader logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ numba_prosody_modulation = None numba_attention_output = None -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * @@ -499,7 +499,7 @@ def flash_attention2( ) push_constants_init = struct.pack( - "IIIIfIIIII", + "IIIIfIIIIII", batch_size, seq_len, num_heads, @@ -513,63 +513,12 @@ def flash_attention2( 0, # q_tile_idx, k_tile_idx (not used in init) ) - # Dispatch initialization + # Single submit for init + all tile passes + finalize (one fence wait). workgroups_init_x = 16 workgroups_init_y = (num_q_positions + 15) // 16 - self.core._dispatch_compute( - pipeline, - pipeline_layout, - descriptor_set_init, - workgroups_init_x, - push_constants_init, - workgroups_init_y, - ) - - # Pass 1: Process all tiles - for q_tile in range(num_tiles_q): - for k_tile in range(num_tiles_k): - descriptor_set_tile = self.pipelines.get_cached_descriptor_set( - "flash-attention2", - [ - (self._get_buffer_handle(buf_q), q_flat.nbytes), - (self._get_buffer_handle(buf_k), k_flat.nbytes), - (self._get_buffer_handle(buf_v), v_flat.nbytes), - (mask_handle, mask_size), - (self._get_buffer_handle(buf_output), output_accum_size), - (self._get_buffer_handle(buf_running_max), running_max_size), - (self._get_buffer_handle(buf_running_sum), running_sum_size), - (self._get_buffer_handle(buf_output_accum), output_accum_size), - ], - ) - - push_constants_tile = struct.pack( - "IIIIfIIIII", - batch_size, - seq_len, - num_heads, - head_dim, - scale, - tile_size_q, - tile_size_k, - 1, # pass_type = 1 (process tile) - 1 if mask is not None else 0, # has_mask - q_tile, - k_tile, - ) - - # Dispatch tile processing - workgroups_tile_x = (tile_size_k + 15) // 16 - workgroups_tile_y = (batch_size * num_heads * tile_size_q + 15) // 16 - self.core._dispatch_compute( - pipeline, - pipeline_layout, - descriptor_set_tile, - workgroups_tile_x, - push_constants_tile, - workgroups_tile_y, - ) - - # Pass 2: Finalize output + workgroups_final_x = (head_dim + 15) // 16 + workgroups_final_y = (num_q_positions + 15) // 16 + descriptor_set_final = self.pipelines.get_cached_descriptor_set( "flash-attention2", [ @@ -585,7 +534,7 @@ def flash_attention2( ) push_constants_final = struct.pack( - "IIIIfIIIII", + "IIIIfIIIIII", batch_size, seq_len, num_heads, @@ -599,17 +548,65 @@ def flash_attention2( 0, # q_tile_idx, k_tile_idx (not used in finalize) ) - # Dispatch finalization - workgroups_final_x = (head_dim + 
15) // 16 - workgroups_final_y = (num_q_positions + 15) // 16 - self.core._dispatch_compute( - pipeline, - pipeline_layout, - descriptor_set_final, - workgroups_final_x, - push_constants_final, - workgroups_final_y, - ) + with self.core.record_commands() as rec: + rec.dispatch( + pipeline, + pipeline_layout, + descriptor_set_init, + (workgroups_init_x, workgroups_init_y, 1), + push_constants_init, + ) + rec.barrier() + + for q_tile in range(num_tiles_q): + for k_tile in range(num_tiles_k): + descriptor_set_tile = self.pipelines.get_cached_descriptor_set( + "flash-attention2", + [ + (self._get_buffer_handle(buf_q), q_flat.nbytes), + (self._get_buffer_handle(buf_k), k_flat.nbytes), + (self._get_buffer_handle(buf_v), v_flat.nbytes), + (mask_handle, mask_size), + (self._get_buffer_handle(buf_output), output_accum_size), + (self._get_buffer_handle(buf_running_max), running_max_size), + (self._get_buffer_handle(buf_running_sum), running_sum_size), + (self._get_buffer_handle(buf_output_accum), output_accum_size), + ], + ) + + push_constants_tile = struct.pack( + "IIIIfIIIIII", + batch_size, + seq_len, + num_heads, + head_dim, + scale, + tile_size_q, + tile_size_k, + 1, # pass_type = 1 (process tile) + 1 if mask is not None else 0, # has_mask + q_tile, + k_tile, + ) + + workgroups_tile_x = (tile_size_k + 15) // 16 + workgroups_tile_y = (batch_size * num_heads * tile_size_q + 15) // 16 + rec.dispatch( + pipeline, + pipeline_layout, + descriptor_set_tile, + (workgroups_tile_x, workgroups_tile_y, 1), + push_constants_tile, + ) + rec.barrier() + + rec.dispatch( + pipeline, + pipeline_layout, + descriptor_set_final, + (workgroups_final_x, workgroups_final_y, 1), + push_constants_final, + ) # Download results result = self._download_buffer(buf_output, output_accum_size, np.float32) diff --git a/backend/base.py b/backend/base.py index 7c931cc..319c746 100644 --- a/backend/base.py +++ b/backend/base.py @@ -1,5 +1,10 @@ """ Base constants and utilities for Vulkan backend. + +``VULKAN_AVAILABLE`` is True when ``grilly_core.Device()`` can initialize (C++ Vulkan; +does not use the PyPI ``vulkan`` package). Optional ctypes bindings set +``VULKAN_PYTHON_BINDINGS_AVAILABLE``; the legacy Python ``VulkanCore`` / ``Compute()`` stack +requires both (see ``VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE``). """ import numpy as np @@ -40,15 +45,30 @@ "VK_PIPELINE_BIND_POINT_COMPUTE": 0, } +def _probe_cpp_vulkan() -> bool: + """True when ``grilly_core`` can initialize a Vulkan device (C++ path; no PyPI ``vulkan``).""" + try: + import grilly_core as gc + + gc.Device() + return True + except Exception: + return False + + +VULKAN_AVAILABLE = _probe_cpp_vulkan() + +# Optional ctypes bindings (``pip install vulkan``) — used by legacy VulkanCore / VulkanCompute. try: from vulkan import * - VULKAN_AVAILABLE = True - # After 'from vulkan import *', all Vulkan constants are in the namespace - # and can be imported by other modules using 'from base import VK_...' + VULKAN_PYTHON_BINDINGS_AVAILABLE = True + # After 'from vulkan import *', constants live in this module for 'from base import VK_...' except ImportError: - VULKAN_AVAILABLE = False - pass + VULKAN_PYTHON_BINDINGS_AVAILABLE = False + +# Full Python VulkanCore stack (ctypes + SPIR-V loaders) — not required for ``grilly_core``-only GPU. +VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE = VULKAN_AVAILABLE and VULKAN_PYTHON_BINDINGS_AVAILABLE # Create dummy constants for type checking when Vulkan is not available # or when a mocked `vulkan` module does not define them (e.g. RTD autodoc). 
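The three flags above give callers a way to distinguish C++-only GPU support from the legacy Python stack. A small illustrative check; the helper name and return strings are hypothetical, while the flags and their semantics are the ones defined in `backend/base.py` above:

```python
from grilly.backend.base import (
    VULKAN_AVAILABLE,
    VULKAN_PYTHON_BINDINGS_AVAILABLE,
    VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE,
)

def describe_vulkan_support() -> str:
    """Summarize which GPU path is usable on this machine (hypothetical helper)."""
    if VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE:
        # grilly_core.Device() initialized AND the PyPI `vulkan` ctypes bindings import:
        # both the C++ bridge ops and the legacy Python VulkanCompute stack are usable.
        return "legacy+core"
    if VULKAN_AVAILABLE:
        # Only the C++ grilly_core device is available; bridge ops work, legacy stack does not.
        return "core-only"
    if VULKAN_PYTHON_BINDINGS_AVAILABLE:
        # `import vulkan` works but grilly_core.Device() failed; no GPU compute path.
        return "bindings-only"
    return "cpu-only"
```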
@@ -374,18 +394,20 @@ def _prepare_input(self, data, size=None): When *data* is a VulkanTensor already on GPU, returns its existing buffer without copying. Otherwise allocates a pooled buffer and uploads. + + VulkanTensor residency: ``prepare_for_dispatch()`` binds C++ GPU handles or + pooled buffers so GPU-resident tensors skip redundant CPU→GPU uploads. """ # Avoid hard import at module level - VulkanTensor lives in utils from ..utils.tensor_conversion import VulkanTensor if isinstance(data, VulkanTensor): - if data.on_gpu: - buf = data._pooled_buffer if data._pooled_buffer is not None else None - if buf is not None: - return buf, False - # Has raw GPU buffer but not pooled - wrap for API compat - return _DirectBuffer(data._gpu_buffer, data._gpu_memory, data.nbytes), False - # CPU-backed lazy VulkanTensor + data.prepare_for_dispatch() + if data._pooled_buffer is not None: + return data._pooled_buffer, False + if data._gpu_buffer is not None: + nbytes = int(data.nbytes if size is None else size) + return _DirectBuffer(data._gpu_buffer, data._gpu_memory, nbytes), False arr = np.asarray(data.numpy(), dtype=np.float32).reshape(-1) else: arr = np.asarray(data) @@ -415,4 +437,11 @@ def _prepare_output(self, size): return self._acquire_buffer(size) -__all__ = ["VULKAN_AVAILABLE", "BufferMixin", "_DirectBuffer", "BUFFER_POOL_AVAILABLE"] +__all__ = [ + "VULKAN_AVAILABLE", + "VULKAN_PYTHON_BINDINGS_AVAILABLE", + "VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE", + "BufferMixin", + "_DirectBuffer", + "BUFFER_POOL_AVAILABLE", +] diff --git a/backend/buffer_pool.py b/backend/buffer_pool.py index 6478d9b..5e11076 100644 --- a/backend/buffer_pool.py +++ b/backend/buffer_pool.py @@ -30,7 +30,7 @@ import numpy as np -from .base import VULKAN_AVAILABLE +from .base import VULKAN_PYTHON_BINDINGS_AVAILABLE if TYPE_CHECKING: from .core import VulkanCore @@ -53,7 +53,7 @@ except ImportError: logger.debug("PyVMA2 not available - using direct Vulkan allocation") -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import ( VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_SHARING_MODE_EXCLUSIVE, diff --git a/backend/cells.py b/backend/cells.py index 0f567f5..7f1a349 100644 --- a/backend/cells.py +++ b/backend/cells.py @@ -7,9 +7,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/channels.py b/backend/channels.py new file mode 100644 index 0000000..ffe9d3d --- /dev/null +++ b/backend/channels.py @@ -0,0 +1,221 @@ +"""grilly.backend.channels — Python-side channel interface. + +Wraps the C++ InProcessChannel with Pythonic API. +Falls back to a pure-Python implementation if C++ not available. 
+ +Usage: + from grilly.backend.channels import Channel, MessageType + + ch = Channel("brain") + ch.send_tensor(numpy_array, sender="vision") + data = ch.receive_tensor() + + ch.send_spikes(spike_array, n_neurons=64, n_timesteps=10) + + # Subscribe to events + ch.on(MessageType.TELEMETRY_EVENT, lambda msg: print(msg)) +""" + +from __future__ import annotations + +import time +import struct +from collections import defaultdict +from enum import IntEnum +from queue import Queue, Empty +from threading import Lock +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + + +class MessageType(IntEnum): + """Message types matching C++ enum.""" + TENSOR_DATA = 0 + SPIKE_TRAIN = 1 + EXPERT_WEIGHTS = 2 + EXPERT_UPDATE = 3 + ROUTE_REQUEST = 4 + ROUTE_RESPONSE = 5 + MEMORY_CAPSULE = 6 + MEMORY_QUERY = 7 + MEMORY_RESULT = 8 + TELEMETRY_EVENT = 9 + NEUROCHEM_STATE = 10 + TRAIN_STEP_REQUEST = 11 + TRAIN_STEP_RESULT = 12 + + +class Message: + """A channel message with type, payload, and metadata.""" + + __slots__ = ("type", "payload", "sender_id", "timestamp_ns", "metadata") + + def __init__(self, msg_type: MessageType, payload: bytes = b"", + sender_id: str = "", metadata: Dict[str, Any] | None = None): + self.type = msg_type + self.payload = payload + self.sender_id = sender_id + self.timestamp_ns = time.time_ns() + self.metadata = metadata or {} + + @property + def size(self) -> int: + return len(self.payload) + + +class Channel: + """High-level channel with C++ backend fallback. + + Tries to use grilly_core.InProcessChannel (C++, zero-copy). + Falls back to pure-Python thread-safe queue. + + Args: + name: Channel name for debugging. + max_queue_size: Maximum queued messages before dropping oldest. + """ + + def __init__(self, name: str = "default", max_queue_size: int = 10000): + self.name = name + self._cpp_channel = None + self._py_queue: Queue = Queue(maxsize=max_queue_size) + self._listeners: Dict[MessageType, List[Callable]] = defaultdict(list) + self._lock = Lock() + + # Try C++ channel + try: + from grilly import _core + if hasattr(_core, "InProcessChannel"): + self._cpp_channel = _core.InProcessChannel(name, max_queue_size) + except Exception: + pass + + @property + def backend(self) -> str: + return "cpp" if self._cpp_channel else "python" + + def send(self, msg: Message) -> None: + """Send a message. Notifies listeners synchronously.""" + # Notify listeners + for callback in self._listeners.get(msg.type, []): + try: + callback(msg) + except Exception: + pass + + if self._cpp_channel: + # TODO: convert Message → C++ MessageEnvelope + pass + + # Python fallback + if self._py_queue.full(): + try: + self._py_queue.get_nowait() # Drop oldest + except Empty: + pass + self._py_queue.put_nowait(msg) + + def receive(self, timeout: float | None = None) -> Message | None: + """Receive next message. 
Returns None if empty."""
+        if self._cpp_channel:
+            # TODO: receive from C++ channel
+            pass
+
+        try:
+            return self._py_queue.get(timeout=timeout)
+        except Empty:
+            return None
+
+    def has_messages(self) -> bool:
+        if self._cpp_channel:
+            return self._cpp_channel.has_messages()
+        return not self._py_queue.empty()
+
+    def on(self, msg_type: MessageType, callback: Callable) -> None:
+        """Subscribe to a message type."""
+        self._listeners[msg_type].append(callback)
+
+    def queue_size(self) -> int:
+        if self._cpp_channel:
+            return self._cpp_channel.queue_size()
+        return self._py_queue.qsize()
+
+    def clear(self) -> None:
+        if self._cpp_channel:
+            self._cpp_channel.clear()
+        while not self._py_queue.empty():
+            try:
+                self._py_queue.get_nowait()
+            except Empty:
+                break
+
+    # ── Convenience methods ──────────────────────────────────────────────
+
+    def send_tensor(self, arr: np.ndarray, sender: str = "python") -> None:
+        """Send a numpy array as a TENSOR_DATA message."""
+        msg = Message(
+            msg_type=MessageType.TENSOR_DATA,
+            payload=arr.astype(np.float32).tobytes(),
+            sender_id=sender,
+            metadata={"shape": list(arr.shape), "dtype": str(arr.dtype)},
+        )
+        self.send(msg)
+
+    def receive_tensor(self, shape: tuple | None = None) -> np.ndarray | None:
+        """Receive a TENSOR_DATA message as numpy array."""
+        msg = self.receive()
+        if msg is None or msg.type != MessageType.TENSOR_DATA:
+            return None
+        arr = np.frombuffer(msg.payload, dtype=np.float32)
+        if shape:
+            arr = arr.reshape(shape)
+        elif "shape" in msg.metadata:
+            arr = arr.reshape(msg.metadata["shape"])
+        return arr
+
+    def send_spikes(self, spikes: np.ndarray, n_neurons: int,
+                    n_timesteps: int, sender: str = "python") -> None:
+        """Send spike train: (timesteps, neurons) flattened."""
+        header = struct.pack("<II", n_neurons, n_timesteps)
+        msg = Message(
+            msg_type=MessageType.SPIKE_TRAIN,
+            payload=header + spikes.astype(np.float32).tobytes(),
+            sender_id=sender,
+        )
+        self.send(msg)
+
+    def receive_spikes(self) -> tuple[np.ndarray, int, int] | None:
+        """Receive spike train → (spikes, n_neurons, n_timesteps)."""
+        msg = self.receive()
+        if msg is None or msg.type != MessageType.SPIKE_TRAIN:
+            return None
+        n_neurons, n_timesteps = struct.unpack("<II", msg.payload[:8])
+        spikes = np.frombuffer(msg.payload[8:], dtype=np.float32).reshape(n_timesteps, n_neurons)
+        return spikes, n_neurons, n_timesteps
+
+    def send_telemetry(self, component_id: str, event_type: str,
+                       metrics: Dict[str, Any] | None = None, step: int = 0) -> None:
+        """Send a telemetry event."""
+        import json
+        payload = json.dumps({
+            "component_id": component_id,
+            "event_type": event_type,
+            "metrics": metrics or {},
+            "step": step,
+        }).encode()
+        msg = Message(
+            msg_type=MessageType.TELEMETRY_EVENT,
+            payload=payload,
+            sender_id=component_id,
+        )
+        self.send(msg)
+
+    def stats(self) -> Dict[str, Any]:
+        return {
+            "name": self.name,
+            "backend": self.backend,
+            "queue_size": self.queue_size(),
+            "listeners": {t.name: len(cbs) for t, cbs in self._listeners.items()},
+        }
diff --git a/backend/compute.py b/backend/compute.py
index 1d0eab8..3dd4cf4 100644
--- a/backend/compute.py
+++ b/backend/compute.py
@@ -75,6 +75,18 @@ def __init__(self, shader_dir: str = None):
         self.queue = self.core.queue
         self.shaders = self.core.shaders
+        # Public Vulkan dispatch / batching (see docs/PERF_DISPATCH.md)
+        self.dispatch_compute = self.core._dispatch_compute
+        self.dispatch_compute_async = self.core._dispatch_compute_async
+        self.wait_async = self.core._wait_async
+        self.wait_fence = self.core._wait_fence
+
+    def record_commands(self, fnn_chain: bool = False):
+        """Batch command recording.
Use ``fnn_chain=True`` for :class:`FnnChainRecorder` (linear/relu/softmax + ``read``).""" + if fnn_chain: + return self.fnn.chain_record() + return self.core.record_commands() + def cleanup(self): """Clean up Vulkan resources""" # Clear weight caches for all backend modules (before device is destroyed) diff --git a/backend/contrastive.py b/backend/contrastive.py index 08e2f97..6854a52 100644 --- a/backend/contrastive.py +++ b/backend/contrastive.py @@ -12,9 +12,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/conv.py b/backend/conv.py index a7f0380..3e72446 100644 --- a/backend/conv.py +++ b/backend/conv.py @@ -12,9 +12,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * @@ -609,10 +609,118 @@ def _conv2d_backward_weight_gemm( K_dim = in_channels * kernel_h * kernel_w N_cols = batch_size * out_h * out_w + # grad_output: (N, C_out, H_out, W_out) -> (C_out, N*H_out*W_out) + grad_out_reshaped = grad_output.transpose(1, 0, 2, 3).reshape( + out_channels, N_cols + ) # (C_out, N_cols) + + grad_bias = None + if has_bias: + grad_bias = np.sum(grad_output, axis=(0, 2, 3)) # (C_out,) + + use_gpu_gemm = ( + "gemm_mnk" in self.shaders and "tensor-transpose" in self.shaders + ) + # --- Step 1: im2col(input) --- buf_input = self._acquire_buffer(input_data.nbytes) buf_cols = self._acquire_buffer(K_dim * N_cols * 4) + if use_gpu_gemm: + buf_cols_t = self._acquire_buffer(N_cols * K_dim * 4) + buf_grad_out = self._acquire_buffer(grad_out_reshaped.nbytes) + buf_grad_w = self._acquire_buffer(out_channels * K_dim * 4) + try: + self._upload_buffer(buf_input, input_data.flatten()) + + pipeline_im2col, layout_im2col, _ = self.pipelines.get_or_create_pipeline( + "convd_im2col", 2, push_constant_size=56 + ) + + in_handle = self._get_buffer_handle(buf_input) + cols_handle = self._get_buffer_handle(buf_cols) + + desc_im2col = self.pipelines.get_cached_descriptor_set( + "convd_im2col", + [(in_handle, input_data.nbytes), (cols_handle, K_dim * N_cols * 4)], + ) + + push_im2col = struct.pack( + "14I", + batch_size, + in_channels, + in_h, + in_w, + out_h, + out_w, + kernel_h, + kernel_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + ) + + group_x = (K_dim + 15) // 16 + group_y = (N_cols + 15) // 16 + self.core._dispatch_compute( + pipeline_im2col, layout_im2col, desc_im2col, group_x, push_im2col, group_y, 1 + ) + + # cols (K_dim, N_cols) row-major -> cols^T (N_cols, K_dim) for GEMM B + pipeline_tr, layout_tr, _ = self.pipelines.get_or_create_pipeline( + "tensor-transpose", 2, push_constant_size=8 + ) + cols_t_handle = self._get_buffer_handle(buf_cols_t) + desc_tr = self.pipelines.get_cached_descriptor_set( + "tensor-transpose", + [ + (cols_handle, K_dim * N_cols * 4), + (cols_t_handle, N_cols * K_dim * 4), + ], + ) + push_tr = struct.pack("2I", K_dim, N_cols) + wg_tr = (K_dim * N_cols + 255) // 256 + self.core._dispatch_compute(pipeline_tr, layout_tr, desc_tr, wg_tr, push_tr) + + self._upload_buffer(buf_grad_out, grad_out_reshaped.ravel()) + + # C = A @ B: A (C_out, N_cols), B (N_cols, K_dim) -> (C_out, K_dim) + pipeline_gemm, layout_gemm, _ = self.pipelines.get_or_create_pipeline( + "gemm_mnk", 3, push_constant_size=12 + ) + g_handle = 
self._get_buffer_handle(buf_grad_out) + b_handle = cols_t_handle + c_handle = self._get_buffer_handle(buf_grad_w) + desc_gemm = self.pipelines.get_cached_descriptor_set( + "gemm_mnk", + [ + (g_handle, grad_out_reshaped.nbytes), + (b_handle, N_cols * K_dim * 4), + (c_handle, out_channels * K_dim * 4), + ], + ) + push_gemm = struct.pack("3I", out_channels, N_cols, K_dim) + group_gx = (K_dim + 15) // 16 + group_gy = (out_channels + 15) // 16 + self.core._dispatch_compute( + pipeline_gemm, layout_gemm, desc_gemm, group_gx, push_gemm, group_gy, 1 + ) + + grad_weight_flat = self._download_buffer( + buf_grad_w, out_channels * K_dim * 4, np.float32 + ) + grad_weight = grad_weight_flat.reshape( + out_channels, in_channels, kernel_h, kernel_w + ) + return grad_weight.astype(np.float32), grad_bias + finally: + self._release_buffers( + [buf_input, buf_cols, buf_cols_t, buf_grad_out, buf_grad_w] + ) + try: self._upload_buffer(buf_input, input_data.flatten()) @@ -651,29 +759,12 @@ def _conv2d_backward_weight_gemm( pipeline_im2col, layout_im2col, desc_im2col, group_x, push_im2col, group_y, 1 ) - # Download cols for GEMM (could be done on GPU, but for now download) cols_flat = self._download_buffer(buf_cols, K_dim * N_cols * 4, np.float32) - cols = cols_flat.reshape(K_dim, N_cols) # (K_dim, N_cols) + cols = cols_flat.reshape(K_dim, N_cols) - # --- Step 2: Prepare grad_output for GEMM --- - # grad_output: (N, C_out, H_out, W_out) -> (C_out, N*H_out*W_out) - grad_out_reshaped = grad_output.transpose(1, 0, 2, 3).reshape( - out_channels, N_cols - ) # (C_out, N_cols) - - # --- Step 3: GEMM: grad_weight = grad_out @ cols.T --- - # (C_out, N_cols) @ (N_cols, K_dim) = (C_out, K_dim) - grad_weight_flat = grad_out_reshaped @ cols.T # (C_out, K_dim) - - # Reshape to (C_out, C_in, kH, kW) + grad_weight_flat = grad_out_reshaped @ cols.T grad_weight = grad_weight_flat.reshape(out_channels, in_channels, kernel_h, kernel_w) - # --- Step 4: Compute grad_bias --- - grad_bias = None - if has_bias: - # Sum over all spatial positions and batch - grad_bias = np.sum(grad_output, axis=(0, 2, 3)) # (C_out,) - return grad_weight.astype(np.float32), grad_bias finally: @@ -721,6 +812,34 @@ def conv2d_backward_weight( grad_output, input_data, kernel_size, stride, padding, dilation, groups, has_bias ) + # 1×1, stride 1, pad 0, dilation 1, groups 1: optional atomic float shader (no batch dim in shader — loop batches) + kernel_h, kernel_w = kernel_size + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + batch_size, out_channels, out_height, out_width = grad_output.shape + _, in_channels, in_height, in_width = input_data.shape + atomic_ext = getattr(self.core, "enabled_extensions", None) or set() + can_1x1_atomic = ( + "conv1x1-backward-weight" in self.shaders + and "VK_EXT_shader_atomic_float" in atomic_ext + and kernel_h == 1 + and kernel_w == 1 + and stride_h == 1 + and stride_w == 1 + and padding_h == 0 + and padding_w == 0 + and dilation_h == 1 + and dilation_w == 1 + and groups == 1 + and out_height == in_height + and out_width == in_width + ) + if can_1x1_atomic: + return self._conv1x1_backward_weight_atomic( + grad_output, input_data, out_channels, in_channels, has_bias + ) + # Check if shader is available if "conv2d-backward-weight" not in self.shaders: return self._conv2d_backward_weight_cpu( @@ -833,6 +952,64 @@ def conv2d_backward_weight( finally: self._release_buffers([buf_grad_output, buf_input, buf_grad_weight, buf_grad_bias]) + def _conv1x1_backward_weight_atomic( + self, + 
grad_output: np.ndarray, + input_data: np.ndarray, + out_channels: int, + in_channels: int, + has_bias: bool, + ) -> tuple[np.ndarray, np.ndarray | None]: + """1×1 conv backward weight: ``conv1x1-backward-weight.glsl`` (buffer float atomics). Shader is batch=1; we dispatch per batch and accumulate.""" + batch_size, _, out_height, out_width = grad_output.shape + flat_in = in_channels * out_height * out_width * 4 + flat_dy = out_channels * out_height * out_width * 4 + num_w = out_channels * in_channels * 4 + + buf_in = self._acquire_buffer(flat_in) + buf_dy = self._acquire_buffer(flat_dy) + buf_gw = self._acquire_buffer(num_w) + + try: + self._upload_buffer(buf_gw, np.zeros(out_channels * in_channels, dtype=np.float32)) + pipeline, pipeline_layout, _ = self.pipelines.get_or_create_pipeline( + "conv1x1-backward-weight", 3, push_constant_size=16 + ) + gx = (out_width + 15) // 16 + gy = (out_height + 15) // 16 + push = struct.pack("4I", out_width, out_height, in_channels, out_channels) + + for b in range(batch_size): + self._upload_buffer(buf_in, input_data[b].ravel()) + self._upload_buffer(buf_dy, grad_output[b].ravel()) + descriptor_set = self.pipelines.get_cached_descriptor_set( + "conv1x1-backward-weight", + [ + (self._get_buffer_handle(buf_in), flat_in), + (self._get_buffer_handle(buf_dy), flat_dy), + (self._get_buffer_handle(buf_gw), num_w), + ], + ) + self.core._dispatch_compute( + pipeline, + pipeline_layout, + descriptor_set, + gx, + push, + gy, + 1, + wait_previous=(b != 0), + ) + + grad_weight_flat = self._download_buffer(buf_gw, num_w, np.float32) + grad_weight = grad_weight_flat.reshape(out_channels, in_channels, 1, 1) + grad_bias = None + if has_bias: + grad_bias = np.sum(grad_output, axis=(0, 2, 3), dtype=np.float32) + return grad_weight, grad_bias + finally: + self._release_buffers([buf_in, buf_dy, buf_gw]) + # CPU fallbacks (using numpy) def _conv2d_cpu(self, input_data, weight, bias, stride, padding, dilation, groups): """CPU fallback for conv2d forward pass""" diff --git a/backend/core.py b/backend/core.py index 3421b9d..515a0e3 100644 --- a/backend/core.py +++ b/backend/core.py @@ -10,9 +10,13 @@ import numpy as np -from .base import VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VULKAN_AVAILABLE +from .base import ( + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, + VULKAN_AVAILABLE, + VULKAN_PYTHON_BINDINGS_AVAILABLE, +) -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * logger = logging.getLogger(__name__) @@ -32,7 +36,12 @@ def __init__(self, shader_dir: str = None): """Initialize the instance.""" if not VULKAN_AVAILABLE: - raise RuntimeError("Vulkan not available") + raise RuntimeError("Vulkan not available (C++ grilly_core could not initialize GPU)") + if not VULKAN_PYTHON_BINDINGS_AVAILABLE: + raise RuntimeError( + "VulkanCore requires the Python `vulkan` package (ctypes bindings). " + "Install with: pip install vulkan" + ) import os # Disable Mesa device_select layer which can force CPU llvmpipe @@ -681,8 +690,15 @@ def _dispatch_compute( push_constants: bytes = None, workgroup_y: int = 1, workgroup_z: int = 1, + *, + wait_previous: bool = True, ): - """Dispatch compute shader using pre-allocated command buffer.""" + """Dispatch compute shader using pre-allocated command buffer. + + Args: + wait_previous: If False, skip waiting for previous dispatch. + Use only when you know the queue is idle (e.g., first dispatch). 
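+
+        Example (a sketch — ``pipeline`` / ``layout`` / ``desc`` come from the
+        pipeline cache and ``push`` is packed push-constant bytes; the names are
+        illustrative, not fixed API objects)::
+
+            core._dispatch_compute(pipeline, layout, desc, 64, push, wait_previous=False)
+            core._dispatch_compute(pipeline, layout, desc, 64, push)  # waits on the prior fence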
+ """ command_buffer = self._cmd_buffer # Reset and re-record the reusable command buffer @@ -726,7 +742,8 @@ def _dispatch_compute( vkEndCommandBuffer(command_buffer) # Wait for previous dispatch to finish, then reset the fence. - vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, 2_000_000_000) + if wait_previous: + vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, 2_000_000_000) vkResetFences(self.device, 1, [self._fence]) # Submit with fence — avoids the heavier vkQueueWaitIdle drain. @@ -746,6 +763,75 @@ def _wait_fence(self, timeout_ns: int = 2_000_000_000): """Wait for the most recent dispatch to complete.""" vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, timeout_ns) + def _dispatch_compute_async( + self, + pipeline, + pipeline_layout, + descriptor_set, + workgroup_x: int, + push_constants: bytes = None, + workgroup_y: int = 1, + workgroup_z: int = 1, + ): + """Async dispatch: record command and return without waiting. + + Returns a handle that can be waited on later via _wait_async(). + Used for batching multiple dispatches before a single fence wait. + """ + command_buffer = self._cmd_buffer + + # Reset and record the reusable command buffer + vkResetCommandBuffer(command_buffer, 0) + + begin_info = VkCommandBufferBeginInfo( + sType=VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, + flags=VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, + ) + vkBeginCommandBuffer(command_buffer, begin_info) + + vkCmdBindPipeline(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline) + vkCmdBindDescriptorSets( + command_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline_layout, + 0, + 1, + [descriptor_set], + 0, + None, + ) + + if push_constants: + push_buf = ctypes.create_string_buffer(push_constants) + vkCmdPushConstants( + command_buffer, + pipeline_layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + len(push_constants), + ctypes.addressof(push_buf), + ) + + vkCmdDispatch(command_buffer, workgroup_x, workgroup_y, workgroup_z) + vkEndCommandBuffer(command_buffer) + + # Reset fence and submit (do not wait) + vkResetFences(self.device, 1, [self._fence]) + + submit_info = VkSubmitInfo( + sType=VK_STRUCTURE_TYPE_SUBMIT_INFO, + commandBufferCount=1, + pCommandBuffers=[command_buffer], + ) + vkQueueSubmit(self.queue, 1, [submit_info], self._fence) + + # Return handle for optional wait + return self._fence + + def _wait_async(self, timeout_ns: int = 2_000_000_000): + """Wait for async dispatch to complete.""" + vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, timeout_ns) + def cleanup(self): """Cleanup Vulkan resources""" if hasattr(self, "device") and self.device: diff --git a/backend/faiss.py b/backend/faiss.py index b156f89..3d6e903 100644 --- a/backend/faiss.py +++ b/backend/faiss.py @@ -8,9 +8,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * logger = logging.getLogger(__name__) @@ -132,7 +132,7 @@ def topk(self, distances, k): k = min(int(k), num_database) # Fallback to CPU when Vulkan is unavailable - if not VULKAN_AVAILABLE or self.core is None: + if not VULKAN_PYTHON_BINDINGS_AVAILABLE or self.core is None: return self._cpu_topk(distances, k) # GPU attempt with validation fallback diff --git a/backend/fft.py b/backend/fft.py index a2d4d11..805eed4 100644 --- a/backend/fft.py +++ b/backend/fft.py @@ -16,9 +16,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, 
VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/fnn.py b/backend/fnn.py index 735a6c4..d28461a 100644 --- a/backend/fnn.py +++ b/backend/fnn.py @@ -12,9 +12,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * # Try to import numba-accelerated operations for CPU fallback @@ -73,6 +73,12 @@ def __init__(self, core, pipelines, shaders): self.shaders = shaders self._pool = None # Lazy initialization + def chain_record(self): + """Return a :class:`~grilly.backend.fnn_chain.FnnChainRecorder` for batched linear/relu/softmax + single submit.""" + from .fnn_chain import FnnChainRecorder + + return FnnChainRecorder(self) + def gemm(self, A, B, return_gpu_tensor=False, cache_B=False, force_fp32=False): """ GEMM: C = A @ B @@ -2698,6 +2704,117 @@ def fused_linear_gelu( self._release_buffer(buf_output) return result.reshape(output_shape) + def _linear_relu_recorded_chain( + self, + x, + weights: np.ndarray, + bias: np.ndarray | None = None, + ) -> np.ndarray: + """Linear then ReLU in one queue submit (fused shader absent).""" + if "fnn-linear" not in self.shaders or "activation-relu" not in self.shaders: + raise RuntimeError("linear/relu shaders unavailable") + + original_shape = x.shape + input_dim = x.shape[-1] + output_dim = weights.shape[0] + + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + else: + batch_seq = original_shape[0] if len(original_shape) == 2 else 1 + + input_nbytes = batch_seq * input_dim * 4 + output_size = batch_seq * output_dim * 4 + + buf_input, release_input = self._prepare_input(x, size=input_nbytes) + + w_np = np.ascontiguousarray(weights, dtype=np.float32) + w_nbytes = int(np.prod(weights.shape)) * 4 + buf_weights, release_weights = self._get_or_upload_weight(w_np) + + if bias is not None: + bias_np = np.ascontiguousarray(bias, dtype=np.float32) + buf_bias, release_bias = self._get_or_upload_weight(bias_np) + b_nbytes = bias_np.size * 4 + has_bias = 1 + else: + b_flat = np.zeros(output_dim, dtype=np.float32) + buf_bias = self._acquire_buffer(b_flat.nbytes) + self._upload_buffer(buf_bias, b_flat) + b_nbytes = b_flat.nbytes + release_bias = True + has_bias = 0 + + buf_linear_out = self._acquire_buffer(output_size) + buf_relu_out = self._acquire_buffer(output_size) + + pipeline_l, pipeline_layout_l, _ = self.pipelines.get_or_create_pipeline( + "fnn-linear", 4, push_constant_size=16 + ) + descriptor_set_l = self.pipelines.get_cached_descriptor_set( + "fnn-linear", + [ + (self._get_buffer_handle(buf_input), input_nbytes), + (self._get_buffer_handle(buf_weights), w_nbytes), + (self._get_buffer_handle(buf_bias), b_nbytes), + (self._get_buffer_handle(buf_linear_out), output_size), + ], + ) + + push_l = struct.pack("IIII", batch_seq, input_dim, output_dim, has_bias) + workgroups_x = (output_dim + 15) // 16 + workgroups_y = (batch_seq + 15) // 16 + + total_elements = batch_seq * output_dim + push_r = struct.pack("I", total_elements) + workgroups_r = (total_elements + 255) // 256 + + pipeline_r, pipeline_layout_r, _ = self.pipelines.get_or_create_pipeline( + "activation-relu", 2, push_constant_size=4 + ) + descriptor_set_r = self.pipelines.get_cached_descriptor_set( + "activation-relu", + [ + (self._get_buffer_handle(buf_linear_out), output_size), + (self._get_buffer_handle(buf_relu_out), output_size), + ], 
+ ) + + with self.core.record_commands() as rec: + rec.dispatch( + pipeline_l, + pipeline_layout_l, + descriptor_set_l, + (workgroups_x, workgroups_y), + push_l, + ) + rec.barrier() + rec.dispatch( + pipeline_r, + pipeline_layout_r, + descriptor_set_r, + (workgroups_r,), + push_r, + ) + + if len(original_shape) > 2: + output_shape = original_shape[:-1] + (output_dim,) + else: + output_shape = (batch_seq, output_dim) + + result = self._download_buffer(buf_relu_out, output_size, np.float32) + + if release_input: + self._release_buffer(buf_input) + if release_weights: + self._release_buffer(buf_weights) + if release_bias: + self._release_buffer(buf_bias) + self._release_buffer(buf_linear_out) + self._release_buffer(buf_relu_out) + + return result.reshape(output_shape) + def fused_linear_relu( self, x, @@ -2720,6 +2837,16 @@ def fused_linear_relu( ReLU(Linear(x)) """ if "fused-linear-relu" not in self.shaders: + if ( + not return_gpu_tensor + and hasattr(self.core, "record_commands") + and "fnn-linear" in self.shaders + and "activation-relu" in self.shaders + ): + try: + return self._linear_relu_recorded_chain(x, weights, bias) + except Exception: + pass linear_out = self.linear(x, weights, bias, return_gpu_tensor=return_gpu_tensor) return self.activation_relu(linear_out, return_gpu_tensor=return_gpu_tensor) diff --git a/backend/fnn_chain.py b/backend/fnn_chain.py new file mode 100644 index 0000000..2918f5f --- /dev/null +++ b/backend/fnn_chain.py @@ -0,0 +1,281 @@ +""" +Batched FNN recording: multiple linear/relu/softmax dispatches, one submit + one fence wait. + +Use ``VulkanCompute.record_commands(fnn_chain=True)`` or ``VulkanFNN.chain_record()``. +""" + +from __future__ import annotations + +import struct +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Union + +import numpy as np + +from .core import CommandRecorder + +if TYPE_CHECKING: + from .fnn import VulkanFNN + + +@dataclass(frozen=True) +class ChainBufferHandle: + """GPU buffer produced inside a chain; pass to the next ``rec.*`` call or ``rec.read()``.""" + + buf: Any + nbytes: int + shape: tuple[int, ...] + + @property + def ndim(self) -> int: + return len(self.shape) + + +class FnnChainRecorder(CommandRecorder): + """Records linear / ReLU / softmax into the batch command buffer without per-op submit. + + Buffers acquired for numpy inputs and intermediate outputs are released when + ``read()`` runs or when the context exits. + """ + + def __init__(self, fnn: VulkanFNN): + super().__init__(fnn.core) + self._fnn = fnn + self._owned: list[Any] = [] + self._submitted = False + self._released = False + + def _track(self, buf: Any) -> None: + if buf is not None: + self._owned.append(buf) + + def _release_owned(self) -> None: + if self._released: + return + for buf in self._owned: + try: + self._fnn._release_buffer(buf) + except Exception: + pass + self._owned.clear() + self._released = True + + def linear( + self, + x: Union[np.ndarray, ChainBufferHandle, Any], + weights: np.ndarray, + bias: np.ndarray | None = None, + ) -> ChainBufferHandle: + """Record ``fnn-linear`` dispatch. 
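+        Accepts a numpy array, a :class:`ChainBufferHandle` from an earlier chain op, or a ``VulkanTensor``.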
Returns a handle to the output buffer (no download).""" + if not self._recording: + raise RuntimeError("linear() must be called inside a chain_record context") + fnn = self._fnn + if "fnn-linear" not in fnn.shaders: + raise RuntimeError("fnn-linear shader not available") + + from ..utils.tensor_conversion import VulkanTensor + + w_np = np.ascontiguousarray(weights, dtype=np.float32) + output_dim, input_dim = w_np.shape + + if isinstance(x, ChainBufferHandle): + buf_input = x.buf + original_shape = x.shape + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + else: + batch_seq = original_shape[0] + if original_shape[-1] != input_dim: + raise ValueError( + f"ChainBufferHandle last dim {original_shape[-1]} != weight input_dim {input_dim}" + ) + input_nbytes = batch_seq * input_dim * 4 + elif isinstance(x, VulkanTensor): + original_shape = x.shape + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + else: + batch_seq = original_shape[0] + input_nbytes = batch_seq * input_dim * 4 + buf_input, release_input = fnn._prepare_input(x, size=input_nbytes) + if release_input: + self._track(buf_input) + else: + x_np = np.asarray(x, dtype=np.float32) + original_shape = x_np.shape + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + x_2d = x_np.reshape(batch_seq, input_dim) + else: + batch_seq = original_shape[0] + x_2d = x_np + x_flat = np.ascontiguousarray(x_2d).ravel() + input_nbytes = x_flat.nbytes + buf_input = fnn._acquire_buffer(input_nbytes) + fnn._upload_buffer(buf_input, x_flat) + self._track(buf_input) + release_input = True + + w_nbytes = int(w_np.size) * 4 + output_size = batch_seq * output_dim * 4 + + buf_weights, _ = fnn._get_or_upload_weight(w_np) + has_bias = 1 if bias is not None else 0 + if bias is not None: + bias_np = np.ascontiguousarray(bias, dtype=np.float32) + buf_bias, _ = fnn._get_or_upload_weight(bias_np) + bias_flat = bias_np.ravel() + bias_nbytes = bias_flat.nbytes + else: + buf_bias = fnn._acquire_buffer(4) + fnn._upload_buffer(buf_bias, np.zeros(1, dtype=np.float32)) + self._track(buf_bias) + bias_flat = None + bias_nbytes = 4 + + buf_output = fnn._acquire_buffer(output_size) + self._track(buf_output) + + pipeline, pipeline_layout, _ = fnn.pipelines.get_or_create_pipeline( + "fnn-linear", 4, push_constant_size=16 + ) + descriptor_set = fnn.pipelines.get_cached_descriptor_set( + "fnn-linear", + [ + (fnn._get_buffer_handle(buf_input), input_nbytes), + (fnn._get_buffer_handle(buf_weights), w_nbytes), + ( + fnn._get_buffer_handle(buf_bias), + bias_nbytes, + ), + (fnn._get_buffer_handle(buf_output), output_size), + ], + ) + push = struct.pack("IIII", batch_seq, input_dim, output_dim, has_bias) + gx = (output_dim + 15) // 16 + gy = (batch_seq + 15) // 16 + self.dispatch(pipeline, pipeline_layout, descriptor_set, (gx, gy), push) + self.barrier() + + if len(original_shape) > 2: + out_shape = original_shape[:-1] + (output_dim,) + else: + out_shape = (batch_seq, output_dim) + + return ChainBufferHandle(buf_output, output_size, out_shape) + + def relu(self, h: ChainBufferHandle) -> ChainBufferHandle: + """Record ``activation-relu`` dispatch.""" + if not self._recording: + raise RuntimeError("relu() must be called inside a chain_record context") + fnn = self._fnn + if "activation-relu" not in fnn.shaders: + raise RuntimeError("activation-relu shader not available") + + total_elements = h.nbytes // 4 + buf_out = fnn._acquire_buffer(h.nbytes) + self._track(buf_out) + + pipeline, pipeline_layout, _ = 
fnn.pipelines.get_or_create_pipeline( + "activation-relu", 2, push_constant_size=4 + ) + descriptor_set = fnn.pipelines.get_cached_descriptor_set( + "activation-relu", + [ + (fnn._get_buffer_handle(h.buf), h.nbytes), + (fnn._get_buffer_handle(buf_out), h.nbytes), + ], + ) + push = struct.pack("I", total_elements) + wg = (total_elements + 255) // 256 + self.dispatch(pipeline, pipeline_layout, descriptor_set, (wg,), push) + self.barrier() + + return ChainBufferHandle(buf_out, h.nbytes, h.shape) + + def softmax(self, h: ChainBufferHandle, dim: int = -1) -> ChainBufferHandle: + """Record ``activation-softmax`` (3 passes). ``dim`` must be the last axis (-1).""" + if not self._recording: + raise RuntimeError("softmax() must be called inside a chain_record context") + if dim not in (-1, len(h.shape) - 1): + raise NotImplementedError( + "FnnChainRecorder.softmax only supports softmax over the last dimension (dim=-1)" + ) + fnn = self._fnn + if "activation-softmax" not in fnn.shaders: + raise RuntimeError("activation-softmax shader not available") + + shape = h.shape + if len(shape) == 1: + batch_size, seq_len, features = 1, 1, shape[0] + elif len(shape) == 2: + batch_size, seq_len, features = shape[0], 1, shape[1] + else: + batch_size, seq_len, features = shape[0], shape[1], int(np.prod(shape[2:])) + + data_nbytes = h.nbytes + buf_max = fnn._acquire_buffer(batch_size * seq_len * 4) + buf_sum = fnn._acquire_buffer(batch_size * seq_len * 4) + self._track(buf_max) + self._track(buf_sum) + + buf_out = fnn._acquire_buffer(data_nbytes) + self._track(buf_out) + + pipeline, pipeline_layout, _ = fnn.pipelines.get_or_create_pipeline( + "activation-softmax", 4, push_constant_size=24 + ) + descriptor_set = fnn.pipelines.get_cached_descriptor_set( + "activation-softmax", + [ + (fnn._get_buffer_handle(h.buf), data_nbytes), + (fnn._get_buffer_handle(buf_out), data_nbytes), + (fnn._get_buffer_handle(buf_max), batch_size * seq_len * 4), + (fnn._get_buffer_handle(buf_sum), batch_size * seq_len * 4), + ], + ) + + workgroups_bs = ((batch_size * seq_len) + 255) // 256 + workgroups_flat = (int(np.prod(shape)) + 255) // 256 + + for pass_id in (0, 1, 2): + if pass_id > 0: + self.barrier() + push = struct.pack("IIIII", batch_size, seq_len, features, pass_id, features) + wg = workgroups_bs if pass_id < 2 else workgroups_flat + self.dispatch(pipeline, pipeline_layout, descriptor_set, (wg,), push) + self.barrier() + + return ChainBufferHandle(buf_out, data_nbytes, shape) + + def read_multiple(self, handles: list[ChainBufferHandle]) -> list[np.ndarray]: + """Submit once, wait, download every handle to numpy (same order as *handles*), then release chain-owned buffers. + + Use for MoE fan-out: several ``linear`` / other ops recorded, one fence wait, multiple CPU reads. 
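+
+        Example (a sketch; ``x``, ``w1``, ``b1``, ``w2``, ``b2`` are illustrative
+        numpy arrays with matching shapes)::
+
+            with fnn.chain_record() as rec:
+                h1 = rec.relu(rec.linear(x, w1, b1))
+                h2 = rec.relu(rec.linear(x, w2, b2))
+                out1, out2 = rec.read_multiple([h1, h2])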
+ """ + if not self._recording: + raise RuntimeError( + "read_multiple() must be called before exiting the chain_record context" + ) + self.submit_and_wait() + self._submitted = True + out: list[np.ndarray] = [] + for h in handles: + arr = self._fnn._download_buffer(h.buf, h.nbytes, np.float32) + out.append(arr.reshape(h.shape)) + self._release_owned() + return out + + def read(self, h: ChainBufferHandle) -> np.ndarray: + """Submit recorded commands, wait, download *h* to numpy, release chain-owned buffers.""" + return self.read_multiple([h])[0] + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._submitted and self._recording: + try: + self.submit_and_wait() + except Exception: + pass + if not self._released: + self._release_owned() + return CommandRecorder.__exit__(self, exc_type, exc_val, exc_tb) diff --git a/backend/jit.py b/backend/jit.py index 3eef6fc..aa56cf2 100644 --- a/backend/jit.py +++ b/backend/jit.py @@ -24,6 +24,7 @@ def forward(x): import functools import logging +from collections import OrderedDict logger = logging.getLogger("grilly.jit") @@ -234,7 +235,7 @@ class _JitWrapper: def __init__(self, fn, warmup: int = 1): self._fn = fn self._warmup = warmup - self._graphs: dict[tuple, TracedGraph] = {} # cache_key → graph + self._graphs: OrderedDict[tuple, TracedGraph] = OrderedDict() # cache_key → graph (LRU) self._warmup_counts: dict[tuple, int] = {} functools.update_wrapper(self, fn) @@ -252,6 +253,7 @@ def __call__(self, *args, **kwargs): # Check if we have a compiled graph for this shape+kwargs combo if key in self._graphs: + self._graphs.move_to_end(key) # Replay (for now, execute normally — C++ OpGraph replay TBD) return self._fn(*args, **kwargs) @@ -263,14 +265,14 @@ def __call__(self, *args, **kwargs): if count >= self._warmup: # Trace and cache - graph = trace(self._fn, args) + graph = trace(lambda *a: self._fn(*a, **kwargs), args) self._graphs[key] = graph + self._graphs.move_to_end(key) logger.info("JIT compiled: %d ops captured (key=%s)", graph.num_ops, key[:2]) # LRU eviction if len(self._graphs) > self._MAX_CACHED_GRAPHS: - oldest_key = next(iter(self._graphs)) - del self._graphs[oldest_key] + self._graphs.popitem(last=False) logger.info("JIT evicted oldest graph (cache size=%d)", len(self._graphs)) return result diff --git a/backend/learning.py b/backend/learning.py index 8bd9f89..56087ee 100644 --- a/backend/learning.py +++ b/backend/learning.py @@ -15,9 +15,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * @@ -31,7 +31,7 @@ def __init__(self, core, pipelines, shaders): self.shaders = shaders def _free_descriptor_set(self, descriptor_set) -> None: - if descriptor_set is None or not VULKAN_AVAILABLE: + if descriptor_set is None or not VULKAN_PYTHON_BINDINGS_AVAILABLE: return try: vkFreeDescriptorSets( diff --git a/backend/lora.py b/backend/lora.py index 6aceecf..092cb77 100644 --- a/backend/lora.py +++ b/backend/lora.py @@ -14,9 +14,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * @@ -84,7 +84,7 @@ def forward( rank = A.shape[0] # Check if GPU shader available - if not VULKAN_AVAILABLE or "lora-forward" not in self.shaders: + if not VULKAN_PYTHON_BINDINGS_AVAILABLE or "lora-forward" not in self.shaders: return 
self._forward_cpu(x, W, A, B, scale, bias) try: @@ -270,7 +270,7 @@ def backward( rank = A.shape[0] # Check if GPU shader available - if not VULKAN_AVAILABLE or "lora-backward" not in self.shaders: + if not VULKAN_PYTHON_BINDINGS_AVAILABLE or "lora-backward" not in self.shaders: return self._backward_cpu(grad_output, x, A, B, h, scale) try: diff --git a/backend/normalization.py b/backend/normalization.py index 4a2658f..09f3ed2 100644 --- a/backend/normalization.py +++ b/backend/normalization.py @@ -7,9 +7,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/pipelines.py b/backend/pipelines.py index b9aa5b2..ee939b5 100644 --- a/backend/pipelines.py +++ b/backend/pipelines.py @@ -4,9 +4,9 @@ from collections import OrderedDict -from .base import VULKAN_AVAILABLE +from .base import VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/pooling.py b/backend/pooling.py index 7352975..b4d22b3 100644 --- a/backend/pooling.py +++ b/backend/pooling.py @@ -7,9 +7,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/snn.py b/backend/snn.py index 1229c74..512f96f 100644 --- a/backend/snn.py +++ b/backend/snn.py @@ -6,9 +6,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/snn_compute.py b/backend/snn_compute.py index e2cbf51..5e2b59c 100644 --- a/backend/snn_compute.py +++ b/backend/snn_compute.py @@ -6,7 +6,7 @@ import numpy as np -from .base import VULKAN_AVAILABLE +from .base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from .compute import VulkanCompute # Import config for SNN parameters (optional - for integration with main project) @@ -70,7 +70,7 @@ def __init__(self, n_neurons: int = None, use_vulkan: bool = True): self.use_vulkan = False self.backend = None - if use_vulkan and VULKAN_AVAILABLE: + if use_vulkan and VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE: try: self.backend = VulkanCompute() self.use_vulkan = True diff --git a/backend/tensor_ops.py b/backend/tensor_ops.py index 6641333..50d9eff 100644 --- a/backend/tensor_ops.py +++ b/backend/tensor_ops.py @@ -15,9 +15,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * # noqa: F401,F403 diff --git a/benchmarks/benchmark_conv_backward_weight.py b/benchmarks/benchmark_conv_backward_weight.py new file mode 100644 index 0000000..5ffc46d --- /dev/null +++ b/benchmarks/benchmark_conv_backward_weight.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +""" +Baseline timings for conv2d backward weight (Workstream C1). + +The Vulkan training path uses `conv2d-backward-weight` (non-atomic, one thread +per weight slot) or the GEMM path (im2col + CPU matmul when `convd_im2col` is +available). This script prints wall time for a representative backward-weight call. 
+ +Usage: + uv run python benchmarks/benchmark_conv_backward_weight.py +""" + +from __future__ import annotations + +import time +import warnings + +import numpy as np + +from grilly.backend.compute import VulkanCompute + + +def main() -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + rng = np.random.default_rng(0) + batch_size = 4 + in_ch, out_ch = 32, 64 + h, w = 16, 16 + kh, kw = 3, 3 + grad_out = rng.standard_normal((batch_size, out_ch, h, w), dtype=np.float32) + inp = rng.standard_normal((batch_size, in_ch, h, w), dtype=np.float32) + + # Warmup + backend.conv.conv2d_backward_weight( + grad_out, + inp, + (kh, kw), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + groups=1, + has_bias=True, + ) + + n = 5 + t0 = time.perf_counter() + for _ in range(n): + backend.conv.conv2d_backward_weight( + grad_out, + inp, + (kh, kw), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + groups=1, + has_bias=True, + ) + elapsed = (time.perf_counter() - t0) / n + print(f"conv2d_backward_weight: {elapsed * 1000:.3f} ms/iter (mean of {n})") + finally: + backend.cleanup() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_int8_gemm.py b/benchmarks/benchmark_int8_gemm.py new file mode 100644 index 0000000..bc38d93 --- /dev/null +++ b/benchmarks/benchmark_int8_gemm.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +""" +INT8 weight GEMM baseline (Workstream C2) — `VulkanFNN.gemm_int8` if shader loaded. + +Usage: + uv run python benchmarks/benchmark_int8_gemm.py +""" + +from __future__ import annotations + +import time +import warnings + +import numpy as np + +from grilly.backend.compute import VulkanCompute + + +def main() -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "int8-gemm" not in backend.fnn.shaders: + print("int8-gemm shader not loaded; skipping benchmark.") + return + + rng = np.random.default_rng(1) + M, K, N = 256, 512, 256 + group_size = 64 + num_groups = (K + group_size - 1) // group_size + + act = rng.standard_normal((M, K), dtype=np.float32) + w_i8 = rng.integers(-128, 127, size=(N, K), dtype=np.int8) + scales = np.abs(rng.standard_normal((N, num_groups), dtype=np.float32)) + 0.01 + + backend.fnn.gemm_int8(act, w_i8, scales, group_size=group_size) + + n = 10 + t0 = time.perf_counter() + for _ in range(n): + backend.fnn.gemm_int8(act, w_i8, scales, group_size=group_size) + elapsed = (time.perf_counter() - t0) / n + print(f"int8_gemm M={M} K={K} N={N}: {elapsed * 1000:.3f} ms/iter (mean of {n})") + finally: + backend.cleanup() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_gpu_bottlenecks.py b/benchmarks/profile_gpu_bottlenecks.py new file mode 100644 index 0000000..de39990 --- /dev/null +++ b/benchmarks/profile_gpu_bottlenecks.py @@ -0,0 +1,69 @@ +""" +Profile GPU-vs-CPU bottlenecks for parity-critical ops. 
+ +Usage: + python benchmarks/profile_gpu_bottlenecks.py +""" + +from __future__ import annotations + +import cProfile +import io +import pstats +import time + +import numpy as np + +from grilly.nn.attention import MultiheadAttention +from grilly.nn.linear import Linear + + +def _print_stats(pr: cProfile.Profile, title: str, top_n: int = 25): + s = io.StringIO() + pstats.Stats(pr, stream=s).sort_stats("cumtime").print_stats(top_n) + print(f"\n=== {title} ===") + print(s.getvalue()) + + +def profile_linear(): + x = np.random.randn(64, 512).astype(np.float32) + linear = Linear(512, 512) + + pr = cProfile.Profile() + pr.enable() + for _ in range(20): + _ = linear(x) + pr.disable() + _print_stats(pr, "Linear forward profile") + + w = np.asarray(linear.weight, dtype=np.float32) + b = np.asarray(linear.bias, dtype=np.float32) + t0 = time.perf_counter() + for _ in range(20): + _ = x @ w.T + b + t1 = time.perf_counter() + print(f"CPU linear baseline (20 iters): {(t1 - t0) * 1000.0:.3f} ms") + + +def profile_attention(): + attn = MultiheadAttention(embed_dim=512, num_heads=8) + q = np.random.randn(2, 64, 512).astype(np.float32) + k = np.random.randn(2, 64, 512).astype(np.float32) + v = np.random.randn(2, 64, 512).astype(np.float32) + + pr = cProfile.Profile() + pr.enable() + for _ in range(10): + _ = attn(q, k, v) + pr.disable() + _print_stats(pr, "MultiheadAttention forward profile") + + +def main(): + print("Profiling parity-critical GPU paths...") + profile_linear() + profile_attention() + + +if __name__ == "__main__": + main() diff --git a/cpp/include/grilly/channels/channel.h b/cpp/include/grilly/channels/channel.h new file mode 100644 index 0000000..9b1ea1a --- /dev/null +++ b/cpp/include/grilly/channels/channel.h @@ -0,0 +1,127 @@ +/** + * grilly/channels/channel.h — Protobuf-based C++/Python channel interface. + * + * Provides zero-copy message passing between Vulkan compute (C++) + * and CubeMind brain modules (Python) via protobuf serialization. + * + * Two modes: + * 1. In-process: direct pybind11 calls with native proto casters + * 2. IPC: shared memory or Unix domain sockets for multi-process + * + * Usage (C++ side): + * auto channel = grilly::InProcessChannel(); + * channel.send(spike_data); + * auto result = channel.receive(); + * + * Usage (Python side, via pybind11): + * channel = grilly_core.InProcessChannel() + * channel.send_spike_train(spike_data) + * result = channel.receive_tensor() + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace grilly { +namespace channels { + +/** + * MessageType — identifies the protobuf message type without parsing. + */ +enum class MessageType : uint8_t { + TENSOR_DATA = 0, + SPIKE_TRAIN = 1, + EXPERT_WEIGHTS = 2, + EXPERT_UPDATE = 3, + ROUTE_REQUEST = 4, + ROUTE_RESPONSE = 5, + MEMORY_CAPSULE = 6, + MEMORY_QUERY = 7, + MEMORY_RESULT = 8, + TELEMETRY_EVENT = 9, + NEUROCHEM_STATE = 10, + TRAIN_STEP_REQUEST = 11, + TRAIN_STEP_RESULT = 12, +}; + +/** + * MessageEnvelope — header + serialized payload. + */ +struct MessageEnvelope { + MessageType type; + uint64_t timestamp_ns; + std::string sender_id; + std::vector payload; // serialized protobuf bytes + + size_t size() const { return payload.size(); } +}; + +/** + * ChannelListener — callback interface for async message handling. + */ +using ChannelListener = std::function; + +/** + * BaseChannel — abstract channel interface. + */ +class BaseChannel { +public: + virtual ~BaseChannel() = default; + + /// Send a serialized message. 
+ virtual void send(MessageEnvelope envelope) = 0; + + /// Receive next message (blocking). Returns empty if no messages. + virtual MessageEnvelope receive() = 0; + + /// Check if messages are available. + virtual bool has_messages() const = 0; + + /// Register an async listener for a message type. + virtual void subscribe(MessageType type, ChannelListener listener) = 0; + + /// Channel name for debugging. + virtual std::string name() const = 0; +}; + +/** + * InProcessChannel — thread-safe in-process message queue. + * + * For single-process use: C++ compute threads push messages, + * Python pybind11 threads pull them. Zero serialization overhead + * when combined with pybind11_protobuf native casters. + */ +class InProcessChannel : public BaseChannel { +public: + explicit InProcessChannel(const std::string& name = "default", + size_t max_queue_size = 10000); + + void send(MessageEnvelope envelope) override; + MessageEnvelope receive() override; + bool has_messages() const override; + void subscribe(MessageType type, ChannelListener listener) override; + std::string name() const override { return name_; } + + /// Queue size for monitoring. + size_t queue_size() const; + + /// Clear all queued messages. + void clear(); + +private: + std::string name_; + size_t max_queue_size_; + mutable std::mutex mutex_; + std::queue queue_; + std::unordered_map> listeners_; +}; + +} // namespace channels +} // namespace grilly diff --git a/cpp/include/grilly/io/grl_checkpoint.h b/cpp/include/grilly/io/grl_checkpoint.h new file mode 100644 index 0000000..5852ef1 --- /dev/null +++ b/cpp/include/grilly/io/grl_checkpoint.h @@ -0,0 +1,25 @@ +#pragma once +/// GRL v1 checkpoint file layout (shared with Python ``utils/grl_checkpoint.py``). + +#include +#include +#include + +namespace grilly::io { + +inline constexpr uint32_t kGrlHeaderSize = 64; +inline constexpr uint16_t kGrlFormatVersion = 1; + +/// Write a GRL v1 file: header | metadata UTF-8 | index JSON UTF-8 | payload bytes. +bool grl_write_file(const std::string& path, + const std::string& metadata_json, + const std::string& index_json, + const std::vector& payload); + +/// Read a GRL v1 file; on success returns true and fills out-arguments. +bool grl_read_file(const std::string& path, + std::string& metadata_json, + std::string& index_json, + std::vector& payload); + +} // namespace grilly::io diff --git a/cpp/include/grilly/ops/activations.h b/cpp/include/grilly/ops/activations.h index 6178bb9..a601d49 100644 --- a/cpp/include/grilly/ops/activations.h +++ b/cpp/include/grilly/ops/activations.h @@ -83,5 +83,20 @@ void softmaxBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache float* gradInput, uint32_t batchSize, uint32_t seqLen, uint32_t numClasses); +/// Multiplication-free softmax (ReLU-normalized): shader ``mf-softmax`` (same 3-pass +/// layout as ``activation-softmax``). +void mfSoftmax(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t batchSize, uint32_t seqLen, + uint32_t features); + +/// Algebraic softplus: ``0.5 * (x + sqrt(x*x + c))`` with ``c = 4/beta^2``. Shader +/// ``mf-softplus`` (2 buffers). +void mfSoftplus(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements, float beta); + +/// Rational sigmoid ``x / (1 + |x|)``. Shader ``mf-sigmoid`` (2 buffers). 
+void mfSigmoid(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements); + } // namespace ops } // namespace grilly diff --git a/cpp/include/grilly/ops/attention_ops.h b/cpp/include/grilly/ops/attention_ops.h index bd88223..e329995 100644 --- a/cpp/include/grilly/ops/attention_ops.h +++ b/cpp/include/grilly/ops/attention_ops.h @@ -32,10 +32,28 @@ struct AttentionScoresParams { uint32_t passType; // 0 = compute scores }; +// Used by attentionOutput and attentionScoresSoftmaxOutput (declare before fused API). +struct AttentionOutputParams { + uint32_t batchSize; + uint32_t seqLen; + uint32_t numHeads; + uint32_t headDim; +}; + void attentionScores(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const float* Q, const float* K, float* scores, const AttentionScoresParams& p); +/// Fused: attention scores → softmax (last dim) → attention output (weights @ V). +/// Single submit/wait; no host round-trip between scores and softmax. +/// Q, K, V: (B, H, S, D). Writes output (B, H, S, D) and softmaxWeights (B, H, S, S). +void attentionScoresSoftmaxOutput(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, const float* Q, + const float* K, const float* V, float* output, + float* softmaxWeights, + const AttentionScoresParams& scoreParams, + const AttentionOutputParams& outParams); + // ── Attention mask ─────────────────────────────────────────────────────── // Shader: attention-mask.spv // local_size: (256, 1, 1) @@ -58,13 +76,6 @@ void attentionMask(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, // local_size: (256, 1, 1) // Buffers: weights(0) readonly, V(1) readonly, output(2) write -struct AttentionOutputParams { - uint32_t batchSize; - uint32_t seqLen; - uint32_t numHeads; - uint32_t headDim; -}; - void attentionOutput(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const float* weights, const float* V, float* output, const AttentionOutputParams& p); diff --git a/cpp/include/grilly/ops/batched_ops.h b/cpp/include/grilly/ops/batched_ops.h index ac77527..84bb884 100644 --- a/cpp/include/grilly/ops/batched_ops.h +++ b/cpp/include/grilly/ops/batched_ops.h @@ -15,6 +15,8 @@ #include "grilly/command_batch.h" #include "grilly/pipeline_cache.h" +#include "grilly/ops/embedding.h" + namespace grilly { namespace ops { @@ -60,6 +62,13 @@ void batchedAdd(CommandBatch& batch, PipelineCache& cache, GrillyBuffer& a, const GrillyBuffer& b, uint32_t totalElements); +/// Embedding lookup: output[batch*seq, d] = table[token_ids]. Caller must +/// `batch.begin()` first; buffers must be pre-uploaded (ids, table). +void batchedEmbeddingLookup(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& tokenIds, const GrillyBuffer& embedTable, + GrillyBuffer& output, uint32_t batchSize, uint32_t seqLen, + uint32_t vocabSize, uint32_t embeddingDim); + /// Tiled GEMM: C[M,N] = A[M,K] @ B[N,K]^T + bias[N] /// Uses 32x32 shared memory tiles, 2x2 per-thread sub-tiles. void batchedTiledLinear(CommandBatch& batch, PipelineCache& cache, diff --git a/cpp/include/grilly/ops/linear.h b/cpp/include/grilly/ops/linear.h index a9018c5..d1d9dc8 100644 --- a/cpp/include/grilly/ops/linear.h +++ b/cpp/include/grilly/ops/linear.h @@ -14,42 +14,54 @@ namespace grilly { namespace ops { /// GPU-accelerated linear projection: output = x @ W^T + bias. -/// Ports backend/fnn.py:1823-1976 to native C++. 
/// -/// Push constants layout (must match fnn-linear.glsl): +/// Dynamic dtype via ``elemSize``: +/// - ``elemSize == 4`` (fp32): dispatches ``fnn-linear.glsl`` (tiled GEMM, +/// universal path, works on any device). x / weights / bias are +/// interpreted as ``float``; output is always ``float``. +/// - ``elemSize == 2`` (fp16): dispatches ``gemm-coopmat-shared.glsl`` if +/// the device supports ``VK_KHR_cooperative_matrix`` AND the shape is +/// aligned (M % 16 == 0, K % 16 == 0, N % 64 == 0). x / weights / bias +/// are ``float16_t`` bytes in the staging buffer; output remains fp32 +/// (the coopmat accumulator runs in fp32 for numerical stability). +/// Falls back to an fp32 conversion + fnn-linear path if the device +/// lacks cooperative matrix support or the shape isn't aligned. +/// +/// Bias handling: on the coopmat path, bias is applied via a small second +/// dispatch (``gemm-bias-add.glsl``) because ``coopMatStore`` can't +/// interleave an elementwise add with the tile-cooperative store. +/// +/// Push constants layout (matches fnn-linear.glsl): /// uint batch_seq; // offset 0 /// uint input_dim; // offset 4 /// uint output_dim; // offset 8 /// uint has_bias; // offset 12 -/// -/// Buffers (binding order matches shader): -/// 0: input (batch_seq * input_dim floats) -/// 1: weights (output_dim * input_dim floats) -/// 2: bias (output_dim floats, or 1 dummy float) -/// 3: output (batch_seq * output_dim floats) +/// uint elem_size; // offset 16 — 4 for fp32, 2 for fp16 struct LinearParams { uint32_t batchSeq; uint32_t inputDim; uint32_t outputDim; uint32_t hasBias; + uint32_t elemSize; ///< 4 for fp32, 2 for fp16 }; /// Execute a linear (dense / fully-connected) layer on the GPU. /// -/// All Vulkan work — buffer upload, pipeline bind, descriptor set, -/// dispatch, download — happens inside C++ with zero Python crossings. +/// Pointer types are ``const void*`` / ``void*`` so the caller can pass +/// either fp32 or fp16 byte buffers transparently — the element size is +/// encoded in ``p.elemSize``. /// /// @param batch CommandBatch to record the dispatch into. /// @param pool BufferPool for acquiring/releasing GPU memory. -/// @param cache PipelineCache for the fnn-linear shader. +/// @param cache PipelineCache for the fnn-linear / gemm-coopmat shaders. /// @param x Input matrix, row-major (batchSeq × inputDim). /// @param weights Weight matrix, row-major (outputDim × inputDim). /// @param bias Optional bias vector (outputDim). nullptr = no bias. -/// @param output Output buffer, pre-allocated (batchSeq × outputDim). -/// @param p Dimension parameters. +/// @param output Output buffer, pre-allocated (batchSeq × outputDim × 4 bytes). +/// @param p Dimension + dtype parameters. void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* x, const float* weights, const float* bias, - float* output, const LinearParams& p); + const void* x, const void* weights, const void* bias, + void* output, const LinearParams& p); /// CPU reference implementation using Eigen for correctness verification. /// Returns a newly-allocated vector: output = x @ W^T + bias. @@ -66,13 +78,18 @@ struct LinearBackwardParams { }; /// GPU linear backward. Produces grad_input, grad_weight, grad_bias. -/// 3-pass dispatch with barriers. +/// 3-pass dispatch with barriers. Dtype-dynamic via ``p.elemSize``. /// 6 buffers: grad_output(0), input(1), weights(2), /// grad_input(3), grad_weight(4), grad_bias(5). 
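+///
+/// Illustrative fp32 call (a sketch — buffer names and sizes below are example
+/// values, not part of the API):
+///
+///   grilly::ops::LinearParams p{/*batchSeq=*/64, /*inputDim=*/512,
+///                               /*outputDim=*/256, /*hasBias=*/1, /*elemSize=*/4};
+///   grilly::ops::linearBackward(batch, pool, cache,
+///                               dY.data(), X.data(), W.data(),
+///                               dX.data(), dW.data(), dB.data(), p);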
+/// +/// NOTE: the current backward shader (``fnn-linear-backward.glsl``) is +/// fp32-only. Calling with ``p.elemSize == 2`` will throw until a +/// cooperative-matrix backward shader lands. The ``void*`` / dynamic +/// elemSize interface is in place so the upgrade is a local shader change. void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* gradOutput, const float* input, - const float* weights, - float* gradInput, float* gradWeight, float* gradBias, + const void* gradOutput, const void* input, + const void* weights, + void* gradInput, void* gradWeight, void* gradBias, const LinearParams& p); /// Dropout push constants — matches fnn-dropout.glsl. diff --git a/cpp/include/grilly/ops/moe_forward.h b/cpp/include/grilly/ops/moe_forward.h new file mode 100644 index 0000000..a6aeaaa --- /dev/null +++ b/cpp/include/grilly/ops/moe_forward.h @@ -0,0 +1,104 @@ +#pragma once +/// Fused MoE forward / backward — minimal Python crossings for training hot path. +/// +/// `moe_forward_gpu` runs embedding, per-layer expert GEMMs (batched), residual, +/// and output projection on GPU. Router softmax and expert blending use CPU for +/// correctness and simplicity (small tensors); blend can move to a dedicated +/// shader in a follow-up. + +#include +#include +#include + +#include "grilly/buffer_pool.h" +#include "grilly/command_batch.h" +#include "grilly/pipeline_cache.h" + +namespace grilly { +namespace ops { + +struct MoeLayerGPU { + GrillyBuffer routerW; + GrillyBuffer routerB; + std::vector expertW; + std::vector expertWt; + GrillyBuffer expertPacked; // All experts packed contiguously (n_experts * d * d) +}; + +struct MoeGradients { + std::vector grad_embed; + std::vector grad_pos; + std::vector> grad_experts; + std::vector> grad_router_w; + std::vector> grad_router_b; + std::vector grad_out_w; +}; + +struct MoeHandleCache { + uint32_t vocab = 0; + uint32_t d = 0; + uint32_t maxSeq = 0; + uint32_t nLayers = 0; + uint32_t nExperts = 0; + + GrillyBuffer embedW; + GrillyBuffer posW; + GrillyBuffer outW; + GrillyBuffer outWt; + + std::vector layers; + + GrillyBuffer bufIds; + GrillyBuffer bufPosSlice; + GrillyBuffer bufX; + std::vector bufExpertOut; + GrillyBuffer bufBlended; + GrillyBuffer bufLogits; + + // Saved activations for backward (one per layer + final) + std::vector bufActivations; // [0]=after embed, [l]=after layer l-1 + // Router weights per layer (saved from forward for backward) + std::vector> fwd_router_weights; + + std::vector cpu_embed; + std::vector cpu_pos; + std::vector cpu_out_w; + std::vector> cpu_router_w; + std::vector> cpu_router_b; + std::vector> cpu_expert_w; +}; + +int moe_upload(BufferPool& pool, + uint32_t vocab_size, uint32_t d_model, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w, + uint32_t n_layers, uint32_t n_experts); + +MoeHandleCache& moe_get_cache(int handle); + +void moe_release(BufferPool& pool, int handle); + +void moe_update_weights(BufferPool& pool, MoeHandleCache& h, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w); + +void moe_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + float* logits_out); + +MoeGradients moe_backward_cpu(const MoeHandleCache& h, + const int32_t* input_ids, 
uint32_t seq_len, + const float* grad_logits); + +MoeGradients moe_backward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits); + +} // namespace ops +} // namespace grilly diff --git a/cpp/include/grilly/ops/prefix_scan.h b/cpp/include/grilly/ops/prefix_scan.h new file mode 100644 index 0000000..929ca8d --- /dev/null +++ b/cpp/include/grilly/ops/prefix_scan.h @@ -0,0 +1,65 @@ +#pragma once + +#include + +#include "grilly/buffer_pool.h" +#include "grilly/command_batch.h" +#include "grilly/pipeline_cache.h" + +namespace grilly { +namespace ops { + +/// Causal Linear-RNN prefix scan push constants. +/// Matches shaders/prefix-scan-causal.glsl + prefix-scan-causal-backward.glsl. +struct PrefixScanParams { + uint32_t seqLen; ///< must be <= subgroup size (32 Wave32 / 64 Wave64) + uint32_t hiddenDim; + uint32_t batchSize; +}; + +/// Causal linear-RNN forward: h_t = a_t * h_{t-1} + x_t. +/// +/// Implemented as two hardware subgroup inclusive scans (log(a) and the +/// rescaled x). The recurrence is strictly causal: h_t depends only on +/// x_{0..t} and a_{0..t}. +/// +/// Inputs / outputs are fp32 (B, S, D) row-major. +/// x: input sequence, shape (B, S, D) +/// a: decay gates in (0, 1], same shape +/// h: output hidden states, same shape (pre-allocated) +/// +/// Buffer bindings (match the shader): +/// 0: x (input) +/// 1: a (decay) +/// 2: h (output) +/// +/// Constraint: seqLen <= subgroup size. Longer sequences need a hierarchical +/// scan — not implemented yet. +void prefixScanCausal(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* x, const float* a, float* h, + const PrefixScanParams& p); + +/// Causal prefix scan backward: given dh, compute dx and da. +/// +/// Math: dx_t = sum_{s>=t} (prod_{k=t+1..s} a_k) * dh_s (anti-causal scan) +/// da_t = dx_t * h_{t-1} +/// +/// Needs the forward x and h tensors saved from the forward pass. +/// +/// Buffer bindings: +/// 0: dh (grad into output) +/// 1: a (decay) +/// 2: h (forward output, for da computation) +/// 3: x (forward input, for da computation) +/// 4: dx (grad w.r.t. input, output) +/// 5: da (grad w.r.t. decay, output) +void prefixScanCausalBackward(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* dh, const float* a, + const float* h, const float* x, + float* dx, float* da, + const PrefixScanParams& p); + +} // namespace ops +} // namespace grilly diff --git a/cpp/include/grilly/ops/vsa_lm_forward.h b/cpp/include/grilly/ops/vsa_lm_forward.h new file mode 100644 index 0000000..5a5b20e --- /dev/null +++ b/cpp/include/grilly/ops/vsa_lm_forward.h @@ -0,0 +1,121 @@ +#pragma once +/// VSA Language Model fused forward/backward — AdditionLinear FFN + MindForge LoRA. +/// +/// Uses addition-linear shader (L1 distance, no matmul) for FFN layers and +/// fnn-linear for the output projection only. GPU router + MindForge LoRA +/// on CPU (tiny). Target: 10x faster than Python AdditionLinear loops. 
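+///
+/// Typical lifecycle (a sketch — weight pointers/vectors are whatever the
+/// caller staged on the CPU; the names below are illustrative):
+///
+///   int h = vsa_lm_upload(pool, vocab, d, d_ffn, max_seq, embed_w, pos_w,
+///                         ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs,
+///                         ln_gammas, ln_betas, out_w, n_layers);
+///   vsa_lm_forward_gpu(batch, pool, cache, vsa_lm_get_cache(h),
+///                      input_ids, seq_len, logits_out);
+///   vsa_lm_release(pool, h);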
+ +#include +#include + +#include "grilly/buffer_pool.h" +#include "grilly/command_batch.h" +#include "grilly/pipeline_cache.h" + +namespace grilly { +namespace ops { + +struct VsaLmLayerGPU { + GrillyBuffer ffnUpW; // (d_ffn, d) addition-linear patterns + GrillyBuffer ffnUpB; // (d_ffn,) + GrillyBuffer ffnDownW; // (d, d_ffn) addition-linear patterns + GrillyBuffer ffnDownB; // (d,) + GrillyBuffer lnGamma; // (d,) + GrillyBuffer lnBeta; // (d,) +}; + +struct VsaLmGradients { + std::vector grad_embed; // (vocab, d) + std::vector grad_pos; // (max_seq, d) + std::vector> grad_ffn_up_w; // n_layers of (d_ffn * d) + std::vector> grad_ffn_up_b; // n_layers of (d_ffn) + std::vector> grad_ffn_down_w; // n_layers of (d * d_ffn) + std::vector> grad_ffn_down_b; // n_layers of (d) + std::vector> grad_ln_gamma; // n_layers of (d) + std::vector> grad_ln_beta; // n_layers of (d) + std::vector grad_out_w; // (vocab, d) +}; + +struct VsaLmHandleCache { + uint32_t vocab = 0; + uint32_t d = 0; + uint32_t dFfn = 0; + uint32_t maxSeq = 0; + uint32_t nLayers = 0; + + // GPU buffers + GrillyBuffer embedW; + GrillyBuffer posW; + GrillyBuffer outW; + GrillyBuffer outWt; // transposed for backward + + std::vector layers; + + GrillyBuffer bufIds; + GrillyBuffer bufPosSlice; + GrillyBuffer bufX; // current hidden (seq, d) + GrillyBuffer bufLnOut; // layernorm output (seq, d) + GrillyBuffer bufFfnUp; // after addition-linear up (seq, d_ffn) + GrillyBuffer bufSign; // after sign activation (seq, d_ffn) + GrillyBuffer bufFfnDown; // after addition-linear down (seq, d) + GrillyBuffer bufLogits; // (seq, vocab) + + // LayerNorm scratch (mean + var buffers) + GrillyBuffer bufLnMean; // (seq,) + GrillyBuffer bufLnVar; // (seq,) + + // Saved activations for backward (n_layers + 1) + std::vector bufActivations; + + // CPU mirrors for backward + std::vector cpu_embed; + std::vector cpu_pos; + std::vector cpu_out_w; + std::vector> cpu_ffn_up_w; // n_layers + std::vector> cpu_ffn_up_b; + std::vector> cpu_ffn_down_w; + std::vector> cpu_ffn_down_b; + std::vector> cpu_ln_gamma; + std::vector> cpu_ln_beta; + + // MindForge weights (CPU only — tiny) + std::vector forge_W_proj; // (d_hidden, d_vsa) + std::vector forge_W_h; // (d_hidden, d_hidden*2) + std::vector forge_b_h; // (d_hidden,) + std::vector forge_W_coeff; // (n_basis, d_hidden) + std::vector forge_b_coeff; // (n_basis,) + std::vector forge_A_basis; // (n_basis, rank, d) + std::vector forge_B_basis; // (n_basis, d, rank) + std::vector forge_layer_embs; // (n_layers, d_hidden) + uint32_t forge_d_hidden = 0; + uint32_t forge_d_vsa = 0; + uint32_t forge_n_basis = 0; + uint32_t forge_rank = 0; +}; + +int vsa_lm_upload(BufferPool& pool, + uint32_t vocab, uint32_t d, uint32_t d_ffn, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& ffn_up_ws, + const std::vector& ffn_up_bs, + const std::vector& ffn_down_ws, + const std::vector& ffn_down_bs, + const std::vector& ln_gammas, + const std::vector& ln_betas, + const float* out_w, + uint32_t n_layers); + +VsaLmHandleCache& vsa_lm_get_cache(int handle); + +void vsa_lm_release(BufferPool& pool, int handle); + +void vsa_lm_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + VsaLmHandleCache& h, const int32_t* input_ids, + uint32_t seq_len, float* logits_out); + +VsaLmGradients vsa_lm_backward_cpu(const VsaLmHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits); + +} // namespace ops +} // namespace grilly diff --git 
a/cpp/include/grilly/vulkan/vk_buffer_pool.h b/cpp/include/grilly/vulkan/vk_buffer_pool.h index f36050d..a4873f1 100644 --- a/cpp/include/grilly/vulkan/vk_buffer_pool.h +++ b/cpp/include/grilly/vulkan/vk_buffer_pool.h @@ -21,6 +21,18 @@ struct GrillyBuffer { size_t size = 0; size_t bucketSize = 0; void* mappedPtr = nullptr; + /// True if this buffer was acquired via ``acquireDeviceLocal`` (no host + /// visibility, mapped to GPU-only VRAM). ``release`` uses this to route + /// the buffer back to the device-local bucket pool, so a regular + /// ``acquire`` won't accidentally pick up a DL buffer and crash trying + /// to memcpy into a null mappedPtr. + bool deviceLocal = false; + /// True if this buffer was acquired via ``acquireReadback`` — backed + /// by HOST_CACHED memory for fast CPU read-after-GPU-write. Released + /// buffers go to a separate readback pool so a future ``acquire`` + /// (which expects WC sequential-write memory) doesn't waste a cached + /// readback slot on a write-only workload. + bool readback = false; }; /// VMA-backed buffer pool with power-of-2 bucket reuse. @@ -65,8 +77,26 @@ class BufferPool { GrillyDevice& device_; VmaAllocator allocator_ = VK_NULL_HANDLE; std::mutex mutex_; + /// Pool for host-visible WC buffers (acquire / release default path). + /// Optimized for CPU sequential writes (memcpy CPU → buffer). std::unordered_map> buckets_; + /// Pool for DEVICE_LOCAL-only buffers (acquireDeviceLocal / release). + /// Kept separate so a host-visible acquire never picks up a DL buffer. + std::unordered_map> dlBuckets_; + /// Pool for HOST_CACHED readback buffers (acquireReadback / release). + /// Optimized for CPU random reads (memcpy buffer → CPU output) at + /// cached system RAM speed (~10 GB/s vs ~25 MB/s for WC reads). + std::unordered_map> readbackBuckets_; Stats stats_{}; + + // Persistent transfer context (avoids cmd pool/fence recreation per transfer) + VkCommandPool transferPool_ = VK_NULL_HANDLE; + VkCommandBuffer transferCmd_ = VK_NULL_HANDLE; + VkFence transferFence_ = VK_NULL_HANDLE; + bool transferInitialized_ = false; + + void ensureTransferContext(); + void transferSubmitAndWait(); }; } // namespace grilly diff --git a/cpp/include/grilly/vulkan/vk_command_batch.h b/cpp/include/grilly/vulkan/vk_command_batch.h index e7203dd..4e80c4c 100644 --- a/cpp/include/grilly/vulkan/vk_command_batch.h +++ b/cpp/include/grilly/vulkan/vk_command_batch.h @@ -21,6 +21,9 @@ class CommandBatch { void begin(); + /// Start recording if not already. Safe to call multiple times. + void ensureRecording(); + void dispatch(VkPipeline pipeline, VkPipelineLayout layout, VkDescriptorSet descSet, uint32_t gx, uint32_t gy = 1, uint32_t gz = 1, @@ -28,14 +31,29 @@ class CommandBatch { void barrier(); + /// Bidirectional TRANSFER ↔ COMPUTE memory barrier — covers both + /// stage-in → compute (TRANSFER_WRITE → SHADER_READ) and + /// compute → stage-out (SHADER_WRITE → TRANSFER_READ). Used by the + /// staging-buffer pattern in cpp/src/ops/linear.cpp. + void transferComputeBarrier(); + /// GPU-to-GPU buffer copy (no CPU involvement). - /// Must be called between begin() and submit(). void copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t bytes); void submit(); void submitAsync(VkSemaphore timeline, uint64_t signalValue); + /// Submit without waiting — GPU runs while CPU continues. + void submitDeferred(); + + /// Block until the last submitted work finishes. No-op if nothing pending. 
+ void waitForCompletion(); + + /// Number of dispatches recorded in current command buffer. + uint32_t dispatchCount() const { return dispatchCount_; } + bool isRecording() const { return recording_; } + bool isPending() const { return pending_; } VkCommandBuffer cmdBuffer() const { return cmd_; } private: @@ -44,6 +62,8 @@ class CommandBatch { VkCommandBuffer cmd_ = VK_NULL_HANDLE; VkFence fence_ = VK_NULL_HANDLE; bool recording_ = false; + bool pending_ = false; + uint32_t dispatchCount_ = 0; }; } // namespace grilly diff --git a/cpp/include/grilly/vulkan/vk_pipeline_cache.h b/cpp/include/grilly/vulkan/vk_pipeline_cache.h index 330e796..7a12572 100644 --- a/cpp/include/grilly/vulkan/vk_pipeline_cache.h +++ b/cpp/include/grilly/vulkan/vk_pipeline_cache.h @@ -46,6 +46,11 @@ class PipelineCache { return spirvCode_.count(name) > 0; } + /// Access the underlying device for capability queries + /// (e.g. ``hasCooperativeMatrix()``). + GrillyDevice& getDevice() { return device_; } + const GrillyDevice& getDevice() const { return device_; } + struct CacheStats { uint64_t hits = 0; uint64_t misses = 0; diff --git a/cpp/proto/grilly_channels.proto b/cpp/proto/grilly_channels.proto new file mode 100644 index 0000000..bf03ac4 --- /dev/null +++ b/cpp/proto/grilly_channels.proto @@ -0,0 +1,176 @@ +// grilly_channels.proto — Core message types for grilly C++/Python channels. +// +// These messages define the wire format for zero-copy data passing between +// Vulkan compute shaders (C++) and CubeMind's Python brain modules. + +syntax = "proto3"; +package grilly; + +// ═══════════════════════════════════════════════════════════════════════════════ +// Tensor data +// ═══════════════════════════════════════════════════════════════════════════════ + +message TensorData { + repeated uint32 shape = 1; // e.g., [batch, seq, dim] + string dtype = 2; // "float32", "float16", "int8", "int4" + bytes data = 3; // raw bytes (packed) + string device = 4; // "cpu", "vulkan0", "cuda0" + uint64 buffer_id = 5; // GPU buffer handle (zero-copy) +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// SNN spike trains +// ═══════════════════════════════════════════════════════════════════════════════ + +message SpikeTrain { + uint32 n_neurons = 1; + uint32 n_timesteps = 2; + repeated float spikes = 3 [packed = true]; // flattened (timesteps, neurons) + float dt = 4; // timestep duration +} + +message SpikeEvent { + uint32 neuron_id = 1; + uint32 timestep = 2; + float value = 3; // multi-bit spike level (0..L) +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Expert / MoE messages +// ═══════════════════════════════════════════════════════════════════════════════ + +message ExpertWeights { + uint32 expert_id = 1; + TensorData weights = 2; // quantized weights (INT4/INT8) + TensorData scales = 3; // per-block scale factors + string quant_type = 4; // "int4", "int8", "float16" +} + +message ExpertUpdate { + uint32 expert_id = 1; + float q_value = 2; // bandit Q-value + float charge = 3; // HE-MoE charge (+1/-1) + repeated float mu = 4 [packed = true]; // expert position in RKHS + float a_trace = 5; // activity trace + repeated float e_trace = 6 [packed = true]; // error trace + uint64 n_uses = 7; +} + +message RouteRequest { + TensorData input = 1; + uint32 top_k = 2; + float temperature = 3; +} + +message RouteResponse { + repeated uint32 expert_ids = 1 [packed = true]; + repeated float weights = 2 [packed = true]; + repeated float scores = 3 
[packed = true]; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Memory / Hippocampal capsules +// ═══════════════════════════════════════════════════════════════════════════════ + +message MemoryCapsule { + string capsule_id = 1; + TensorData context = 2; // input context at storage time + repeated uint32 expert_ids = 3 [packed = true]; // which experts were active + TensorData error = 4; // prediction error + uint64 timestamp_ns = 5; + float surprise = 6; // error magnitude + repeated string tags = 7; // "novel", "high_error", etc. +} + +message MemoryQuery { + TensorData query = 1; + uint32 top_k = 2; + float min_similarity = 3; + string filter_tag = 4; // optional tag filter +} + +message MemoryResult { + repeated MemoryCapsule capsules = 1; + repeated float similarities = 2 [packed = true]; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Telemetry / Observability +// ═══════════════════════════════════════════════════════════════════════════════ + +message TelemetryEvent { + string component_id = 1; + string event_type = 2; // "forward", "spawn", "charge_flip", etc. + map metrics = 3; // key-value metrics + map labels = 4; // key-value labels + uint64 timestamp_ns = 5; + uint32 step = 6; +} + +message TelemetryBatch { + repeated TelemetryEvent events = 1; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Neurochemistry +// ═══════════════════════════════════════════════════════════════════════════════ + +message NeurochemState { + float dopamine = 1; + float serotonin = 2; + float cortisol = 3; + float noradrenaline = 4; + float oxytocin = 5; + string emotion = 6; // Lövheim cube classification + float arousal = 7; + float valence = 8; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Training +// ═══════════════════════════════════════════════════════════════════════════════ + +message TrainStepRequest { + TensorData input = 1; + TensorData target = 2; + float learning_rate = 3; + bool extract_logits = 4; +} + +message TrainStepResult { + float loss = 1; + float loss_ce = 2; + float loss_kd = 3; + uint32 n_experts = 4; + bool spawned = 5; + float residual_ema = 6; + TelemetryEvent telemetry = 7; + TensorData teacher_logits = 8; // optional, if extract_logits=true +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// RPC Service definitions +// ═══════════════════════════════════════════════════════════════════════════════ + +service GrillyCompute { + // Core tensor ops + rpc Linear(TensorData) returns (TensorData); + rpc Relu(TensorData) returns (TensorData); + rpc Gelu(TensorData) returns (TensorData); + + // SNN ops + rpc GIFNeuronStep(SpikeTrain) returns (SpikeTrain); + rpc SynapsisFwd(TensorData) returns (TensorData); + rpc STDPUpdate(SpikeTrain) returns (ExpertWeights); + + // MoE routing + rpc Route(RouteRequest) returns (RouteResponse); + rpc UpdateExpert(ExpertUpdate) returns (TelemetryEvent); + + // Memory + rpc StoreMemory(MemoryCapsule) returns (TelemetryEvent); + rpc QueryMemory(MemoryQuery) returns (MemoryResult); + + // Training + rpc TrainStep(TrainStepRequest) returns (TrainStepResult); +} diff --git a/cpp/python/bindings.cpp b/cpp/python/bindings.cpp index 381dd49..ca2650f 100644 --- a/cpp/python/bindings.cpp +++ b/cpp/python/bindings.cpp @@ -3188,6 +3188,8 @@ PYBIND11_MODULE(grilly_core, m) { "Download to CPU and return as numpy array") .def("gpu_handle", 
&Tensor::gpu_handle, "Upload to GPU and return buffer handle") + .def("gpu_handle_if_valid", &Tensor::gpu_handle_if_valid, + "GPU buffer handle if already resident (0 otherwise); does not upload") .def("mark_gpu_modified", &Tensor::mark_gpu_modified) .def("mark_cpu_modified", &Tensor::mark_cpu_modified) .def("reshape", &Tensor::reshape, py::arg("shape")) diff --git a/cpp/python/bindings_activations.cpp b/cpp/python/bindings_activations.cpp index f88a210..c4ea83b 100644 --- a/cpp/python/bindings_activations.cpp +++ b/cpp/python/bindings_activations.cpp @@ -19,14 +19,18 @@ void register_activations_ops(py::module_& m) { [fn](GrillyCoreContext& ctx, py::array_t input) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t total = static_cast(inBuf.size); py::array_t result(input.request().shape); auto rBuf = result.request(); - fn(ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), total); + { + py::gil_scoped_release release; + fn(ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), total); + } return Tensor::from_numpy(result); }, @@ -38,6 +42,55 @@ void register_activations_ops(py::module_& m) { defActivation("silu", grilly::ops::silu); defActivation("tanh_act", grilly::ops::tanh_act); + m.def( + "mf_sigmoid", + [](GrillyCoreContext& ctx, + py::array_t input) -> Tensor { + auto inBuf = input.request(); + require_c_contiguous_float(inBuf); + uint32_t total = static_cast(inBuf.size); + + py::array_t result(input.request().shape); + auto rBuf = result.request(); + + { + py::gil_scoped_release release; + grilly::ops::mfSigmoid( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), total); + } + + return Tensor::from_numpy(result); + }, + py::arg("device"), py::arg("input"), + "GPU rational sigmoid x/(1+|x|) (shader mf-sigmoid)"); + + m.def( + "mf_softplus", + [](GrillyCoreContext& ctx, py::array_t input, + float beta) -> Tensor { + auto inBuf = input.request(); + require_c_contiguous_float(inBuf); + uint32_t total = static_cast(inBuf.size); + + py::array_t result(input.request().shape); + auto rBuf = result.request(); + + { + py::gil_scoped_release release; + grilly::ops::mfSoftplus( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), total, beta); + } + + return Tensor::from_numpy(result); + }, + py::arg("device"), py::arg("input"), py::arg("beta") = 1.0f, + "GPU algebraic softplus 0.5*(x+sqrt(x*x+c)), c=4/beta^2 (shader " + "mf-softplus)"); + // ── Activation backward ops ────────────────────────────────────────── // 3 buffers: grad_output, input (original forward input), grad_input. // Same push constant (uint total_elements) and dispatch pattern. 
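The mf_sigmoid and mf_softplus bindings above state their math only in the docstrings. Here is a NumPy sketch of those reference formulas, followed by a hedged sketch of the GPU call; the grilly_core module name and the device handle are assumptions based on the binding signatures, not a confirmed public API.

```python
import numpy as np

def mf_sigmoid_ref(x: np.ndarray) -> np.ndarray:
    # Rational sigmoid from the docstring above: x / (1 + |x|)
    return x / (1.0 + np.abs(x))

def mf_softplus_ref(x: np.ndarray, beta: float = 1.0) -> np.ndarray:
    # Algebraic softplus from the docstring above: 0.5 * (x + sqrt(x*x + c)), c = 4 / beta^2
    c = 4.0 / (beta * beta)
    return 0.5 * (x + np.sqrt(x * x + c))

x = np.linspace(-4.0, 4.0, 9).astype(np.float32)
print(mf_sigmoid_ref(x))
print(mf_softplus_ref(x, beta=2.0))

# Hypothetical GPU call; inputs must be C-contiguous float32 (require_c_contiguous_float):
# import grilly_core
# y = grilly_core.mf_sigmoid(dev, x)            # dev: device/context handle from the library
# z = grilly_core.mf_softplus(dev, x, beta=2.0)
# Both return grilly Tensors; download them to numpy and compare against the references above.
```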
@@ -50,6 +103,8 @@ void register_activations_ops(py::module_& m) { py::array_t input) -> Tensor { auto gBuf = grad_output.request(); auto iBuf = input.request(); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(iBuf); uint32_t total = 1; for (int i = 0; i < gBuf.ndim; ++i) total *= static_cast(gBuf.shape[i]); @@ -57,10 +112,13 @@ void register_activations_ops(py::module_& m) { py::array_t result(gBuf.shape); auto rBuf = result.request(); - fn(ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(iBuf.ptr), - static_cast(rBuf.ptr), total); + { + py::gil_scoped_release release; + fn(ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(iBuf.ptr), + static_cast(rBuf.ptr), total); + } return Tensor::from_numpy(result); }, @@ -95,15 +153,18 @@ void register_activations_ops(py::module_& m) { py::array_t result(outShape); auto rBuf = result.request(); - grilly::ops::fusedMlpGelu( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(w1Buf.ptr), - static_cast(b1Buf.ptr), - static_cast(w2Buf.ptr), - static_cast(b2Buf.ptr), - static_cast(rBuf.ptr), - seqLen, dIn, dHidden, dOut); + { + py::gil_scoped_release release; + grilly::ops::fusedMlpGelu( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(w1Buf.ptr), + static_cast(b1Buf.ptr), + static_cast(w2Buf.ptr), + static_cast(b2Buf.ptr), + static_cast(rBuf.ptr), + seqLen, dIn, dHidden, dOut); + } return Tensor::from_numpy(result); }, @@ -131,15 +192,18 @@ void register_activations_ops(py::module_& m) { py::array_t result(outShape); auto rBuf = result.request(); - grilly::ops::fusedLayernormLinear( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(lnWBuf.ptr), - static_cast(lnBBuf.ptr), - static_cast(pWBuf.ptr), - static_cast(pBBuf.ptr), - static_cast(rBuf.ptr), - seqLen, dIn, dOut); + { + py::gil_scoped_release release; + grilly::ops::fusedLayernormLinear( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(lnWBuf.ptr), + static_cast(lnBBuf.ptr), + static_cast(pWBuf.ptr), + static_cast(pBBuf.ptr), + static_cast(rBuf.ptr), + seqLen, dIn, dOut); + } return Tensor::from_numpy(result); }, diff --git a/cpp/python/bindings_attention.cpp b/cpp/python/bindings_attention.cpp index 8457bcf..b8244cd 100644 --- a/cpp/python/bindings_attention.cpp +++ b/cpp/python/bindings_attention.cpp @@ -41,15 +41,18 @@ void register_attention_ops(py::module_& m) { }); auto rBuf = result.request(); - grilly::ops::flashAttention2( - ctx.batch, ctx.pool, ctx.cache, - static_cast(qBuf.ptr), - static_cast(K.request().ptr), - static_cast(V.request().ptr), - maskPtr, - static_cast(rBuf.ptr), - batchSize, seqLen, numHeads, headDim, - scale, tileSizeQ, tileSizeK); + { + py::gil_scoped_release release; + grilly::ops::flashAttention2( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(K.request().ptr), + static_cast(V.request().ptr), + maskPtr, + static_cast(rBuf.ptr), + batchSize, seqLen, numHeads, headDim, + scale, tileSizeQ, tileSizeK); + } return Tensor::from_numpy(result); }, @@ -78,19 +81,74 @@ void register_attention_ops(py::module_& m) { py::array_t result({ static_cast(B), static_cast(H), static_cast(S), static_cast(S)}); + auto resBuf = result.request(); grilly::ops::AttentionScoresParams p{B, S, H, D, scale, 0}; - grilly::ops::attentionScores( - ctx.batch, ctx.pool, ctx.cache, - static_cast(qBuf.ptr), - static_cast(K.request().ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + 
grilly::ops::attentionScores( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(K.request().ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("Q"), py::arg("K"), py::arg("scale") = 0.0f, "GPU attention scores: Q @ K^T / sqrt(d_h)"); + // ── Fused: scores + softmax + output (single GPU submit) ──────────── + m.def( + "attention_scores_softmax_output", + [](GrillyCoreContext& ctx, + py::array_t Q, py::array_t K, py::array_t V, + float scale) -> py::tuple { + auto qBuf = Q.request(); + auto kBuf = K.request(); + auto vBuf = V.request(); + if (qBuf.ndim != 4 || kBuf.ndim != 4 || vBuf.ndim != 4) + throw std::runtime_error("Q, K, V must be 4D (B, H, S, D)"); + + uint32_t B = static_cast(qBuf.shape[0]); + uint32_t H = static_cast(qBuf.shape[1]); + uint32_t S = static_cast(qBuf.shape[2]); + uint32_t D = static_cast(qBuf.shape[3]); + + if (kBuf.shape[0] != B || kBuf.shape[1] != H || kBuf.shape[2] != S || + kBuf.shape[3] != D || vBuf.shape[0] != B || vBuf.shape[1] != H || + vBuf.shape[2] != S || vBuf.shape[3] != D) + throw std::runtime_error("Q, K, V shapes must match"); + + if (scale == 0.0f) scale = 1.0f / std::sqrt(float(D)); + + py::array_t outArr({ + static_cast(B), static_cast(H), + static_cast(S), static_cast(D)}); + py::array_t wArr({ + static_cast(B), static_cast(H), + static_cast(S), static_cast(S)}); + auto outRB = outArr.request(); + auto wRB = wArr.request(); + + grilly::ops::AttentionScoresParams sp{B, S, H, D, scale, 0}; + grilly::ops::AttentionOutputParams outp{B, S, H, D}; + { + py::gil_scoped_release release; + grilly::ops::attentionScoresSoftmaxOutput( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(kBuf.ptr), + static_cast(vBuf.ptr), + static_cast(outRB.ptr), + static_cast(wRB.ptr), sp, outp); + } + return py::make_tuple(Tensor::from_numpy(outArr), Tensor::from_numpy(wArr)); + }, + py::arg("device"), py::arg("Q"), py::arg("K"), py::arg("V"), + py::arg("scale") = 0.0f, + "GPU fused attention: scores + softmax + output (one submit); returns (output, softmax_weights)"); + // ── Attention mask (causal or custom) ──────────────────────────────── m.def( "attention_mask", @@ -116,9 +174,12 @@ void register_attention_ops(py::module_& m) { grilly::ops::AttentionMaskParams p{ B, H, S, causal ? 
1u : 0u, mask_value}; - grilly::ops::attentionMask( - ctx.batch, ctx.pool, ctx.cache, - static_cast(result.mutable_data()), maskPtr, p); + { + py::gil_scoped_release release; + grilly::ops::attentionMask( + ctx.batch, ctx.pool, ctx.cache, + static_cast(result.mutable_data()), maskPtr, p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("scores"), @@ -145,13 +206,17 @@ void register_attention_ops(py::module_& m) { py::array_t result({ static_cast(B), static_cast(H), static_cast(S), static_cast(D)}); + auto resBuf = result.request(); grilly::ops::AttentionOutputParams p{B, S, H, D}; - grilly::ops::attentionOutput( - ctx.batch, ctx.pool, ctx.cache, - static_cast(wBuf.ptr), - static_cast(vBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::attentionOutput( + ctx.batch, ctx.pool, ctx.cache, + static_cast(wBuf.ptr), + static_cast(vBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("weights"), py::arg("V"), @@ -175,12 +240,16 @@ void register_attention_ops(py::module_& m) { static_cast(B), static_cast(S), static_cast(H * D)}); + auto resBuf = result.request(); grilly::ops::ConcatHeadsParams p{B, S, H, D}; - grilly::ops::attentionConcatHeads( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::attentionConcatHeads( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("mh_output"), @@ -213,13 +282,17 @@ void register_attention_ops(py::module_& m) { : nullptr; py::array_t result(inBuf.shape); + auto resBuf = result.request(); grilly::ops::RoPEParams p{ B, S, H, D, base, precomputed ? 1u : 0u, scaling}; - grilly::ops::applyRoPE( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), - cosPtr, sinPtr, p); + { + py::gil_scoped_release release; + grilly::ops::applyRoPE( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), + cosPtr, sinPtr, p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), diff --git a/cpp/python/bindings_channels.cpp b/cpp/python/bindings_channels.cpp new file mode 100644 index 0000000..5080939 --- /dev/null +++ b/cpp/python/bindings_channels.cpp @@ -0,0 +1,108 @@ +/** + * bindings_channels.cpp — pybind11 bindings for grilly channels. + * + * Exposes InProcessChannel to Python with numpy-friendly message passing. + * When pybind11_protobuf is available, uses native proto casters. + * Otherwise, falls back to bytes-based serialization. 
+ */ + +#include +#include +#include + +#include "grilly/channels/channel.h" + +namespace py = pybind11; + +void bind_channels(py::module_& m) { + using namespace grilly::channels; + + // MessageType enum + py::enum_(m, "MessageType") + .value("TENSOR_DATA", MessageType::TENSOR_DATA) + .value("SPIKE_TRAIN", MessageType::SPIKE_TRAIN) + .value("EXPERT_WEIGHTS", MessageType::EXPERT_WEIGHTS) + .value("EXPERT_UPDATE", MessageType::EXPERT_UPDATE) + .value("ROUTE_REQUEST", MessageType::ROUTE_REQUEST) + .value("ROUTE_RESPONSE", MessageType::ROUTE_RESPONSE) + .value("MEMORY_CAPSULE", MessageType::MEMORY_CAPSULE) + .value("MEMORY_QUERY", MessageType::MEMORY_QUERY) + .value("MEMORY_RESULT", MessageType::MEMORY_RESULT) + .value("TELEMETRY_EVENT", MessageType::TELEMETRY_EVENT) + .value("NEUROCHEM_STATE", MessageType::NEUROCHEM_STATE) + .value("TRAIN_STEP_REQUEST", MessageType::TRAIN_STEP_REQUEST) + .value("TRAIN_STEP_RESULT", MessageType::TRAIN_STEP_RESULT); + + // MessageEnvelope + py::class_(m, "MessageEnvelope") + .def(py::init<>()) + .def_readwrite("type", &MessageEnvelope::type) + .def_readwrite("timestamp_ns", &MessageEnvelope::timestamp_ns) + .def_readwrite("sender_id", &MessageEnvelope::sender_id) + .def_property("payload", + [](const MessageEnvelope& e) { + return py::bytes(reinterpret_cast(e.payload.data()), + e.payload.size()); + }, + [](MessageEnvelope& e, py::bytes data) { + std::string s = data; + e.payload.assign(s.begin(), s.end()); + }) + .def_property_readonly("size", &MessageEnvelope::size); + + // InProcessChannel + py::class_(m, "InProcessChannel") + .def(py::init(), + py::arg("name") = "default", + py::arg("max_queue_size") = 10000) + .def("send", &InProcessChannel::send) + .def("receive", &InProcessChannel::receive) + .def("has_messages", &InProcessChannel::has_messages) + .def("queue_size", &InProcessChannel::queue_size) + .def("clear", &InProcessChannel::clear) + .def("name", &InProcessChannel::name) + + // Convenience: send numpy array as TensorData + .def("send_tensor", [](InProcessChannel& ch, + py::array_t arr, + const std::string& sender_id) { + auto info = arr.request(); + MessageEnvelope env; + env.type = MessageType::TENSOR_DATA; + env.sender_id = sender_id; + env.payload.assign( + reinterpret_cast(info.ptr), + reinterpret_cast(info.ptr) + info.size * sizeof(float)); + ch.send(std::move(env)); + }, py::arg("array"), py::arg("sender_id") = "python") + + // Convenience: receive as numpy array + .def("receive_tensor", [](InProcessChannel& ch) -> py::object { + auto msg = ch.receive(); + if (msg.payload.empty()) return py::none(); + size_t n_floats = msg.payload.size() / sizeof(float); + auto result = py::array_t(n_floats); + auto buf = result.request(); + std::memcpy(buf.ptr, msg.payload.data(), msg.payload.size()); + return result; + }) + + // Convenience: send spike train (timesteps × neurons flattened) + .def("send_spikes", [](InProcessChannel& ch, + py::array_t spikes, + uint32_t n_neurons, + uint32_t n_timesteps, + const std::string& sender_id) { + auto info = spikes.request(); + MessageEnvelope env; + env.type = MessageType::SPIKE_TRAIN; + env.sender_id = sender_id; + // Pack header: n_neurons (4 bytes) + n_timesteps (4 bytes) + data + env.payload.resize(8 + info.size * sizeof(float)); + std::memcpy(env.payload.data(), &n_neurons, 4); + std::memcpy(env.payload.data() + 4, &n_timesteps, 4); + std::memcpy(env.payload.data() + 8, info.ptr, info.size * sizeof(float)); + ch.send(std::move(env)); + }, py::arg("spikes"), py::arg("n_neurons"), py::arg("n_timesteps"), + 
py::arg("sender_id") = "python"); +} diff --git a/cpp/python/bindings_conv.cpp b/cpp/python/bindings_conv.cpp index 81387aa..34353a8 100644 --- a/cpp/python/bindings_conv.cpp +++ b/cpp/python/bindings_conv.cpp @@ -19,6 +19,8 @@ void register_conv_ops(py::module_& m) { uint32_t groups) -> Tensor { auto inBuf = input.request(); auto wBuf = weight.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(wBuf); if (inBuf.ndim != 4) throw std::runtime_error( @@ -46,8 +48,11 @@ void register_conv_ops(py::module_& m) { uint32_t outW = grilly::ops::convOutputSize(inW, kW, sW, pW, dW); const float* biasPtr = nullptr; - if (bias.has_value()) - biasPtr = static_cast(bias->request().ptr); + if (bias.has_value()) { + auto bBuf = bias->request(); + require_c_contiguous_float(bBuf); + biasPtr = static_cast(bBuf.ptr); + } py::array_t result({ static_cast(batchSize), @@ -57,15 +62,18 @@ void register_conv_ops(py::module_& m) { }); auto rBuf = result.request(); - grilly::ops::conv2d( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(wBuf.ptr), - biasPtr, - static_cast(rBuf.ptr), - batchSize, inChannels, inH, inW, - outChannels, kH, kW, - sH, sW, pH, pW, dH, dW, groups); + { + py::gil_scoped_release release; + grilly::ops::conv2d( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(wBuf.ptr), + biasPtr, + static_cast(rBuf.ptr), + batchSize, inChannels, inH, inW, + outChannels, kH, kW, + sH, sW, pH, pW, dH, dW, groups); + } return Tensor::from_numpy(result); }, @@ -87,6 +95,8 @@ void register_conv_ops(py::module_& m) { uint32_t dilation, uint32_t groups) -> Tensor { auto inBuf = input.request(); auto wBuf = weight.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(wBuf); if (inBuf.ndim != 3) throw std::runtime_error( @@ -102,8 +112,11 @@ void register_conv_ops(py::module_& m) { length, kSize, stride, padding, dilation); const float* biasPtr = nullptr; - if (bias.has_value()) - biasPtr = static_cast(bias->request().ptr); + if (bias.has_value()) { + auto bBuf = bias->request(); + require_c_contiguous_float(bBuf); + biasPtr = static_cast(bBuf.ptr); + } py::array_t result({ static_cast(batchSize), @@ -112,15 +125,18 @@ void register_conv_ops(py::module_& m) { }); auto rBuf = result.request(); - grilly::ops::conv1d( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(wBuf.ptr), - biasPtr, - static_cast(rBuf.ptr), - batchSize, inChannels, length, - outChannels, kSize, - stride, padding, dilation, groups); + { + py::gil_scoped_release release; + grilly::ops::conv1d( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(wBuf.ptr), + biasPtr, + static_cast(rBuf.ptr), + batchSize, inChannels, length, + outChannels, kSize, + stride, padding, dilation, groups); + } return Tensor::from_numpy(result); }, @@ -141,6 +157,8 @@ void register_conv_ops(py::module_& m) { uint32_t groups) -> Tensor { auto gBuf = grad_output.request(); auto wBuf = weight.request(); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(wBuf); uint32_t batchSize = input_shape[0]; uint32_t inChannels = input_shape[1]; @@ -168,12 +186,16 @@ void register_conv_ops(py::module_& m) { static_cast(inChannels), static_cast(inH), static_cast(inW)}); + auto resBuf = result.request(); - grilly::ops::conv2dBackwardInput( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(wBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::conv2dBackwardInput( + 
ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(wBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, @@ -196,6 +218,8 @@ void register_conv_ops(py::module_& m) { uint32_t groups, bool has_bias) -> py::dict { auto gBuf = grad_output.request(); auto iBuf = input.request(); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(iBuf); uint32_t batchSize = static_cast(iBuf.shape[0]); uint32_t inChannels = static_cast(iBuf.shape[1]); @@ -229,15 +253,19 @@ void register_conv_ops(py::module_& m) { has_bias ? std::vector{ static_cast(outChannels)} : std::vector{1}); - - grilly::ops::conv2dBackwardWeight( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(iBuf.ptr), - static_cast(gradWeight.request().ptr), - has_bias ? static_cast(gradBias.request().ptr) - : nullptr, - p); + auto gwBuf = gradWeight.request(); + auto gbBuf = gradBias.request(); + + { + py::gil_scoped_release release; + grilly::ops::conv2dBackwardWeight( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(iBuf.ptr), + static_cast(gwBuf.ptr), + has_bias ? static_cast(gbBuf.ptr) : nullptr, + p); + } py::dict result; result["grad_weight"] = Tensor::from_numpy(gradWeight); diff --git a/cpp/python/bindings_core.cpp b/cpp/python/bindings_core.cpp index b1f8d64..0491197 100644 --- a/cpp/python/bindings_core.cpp +++ b/cpp/python/bindings_core.cpp @@ -75,7 +75,11 @@ PYBIND11_MODULE(grilly_core, m) { "Clear all recorded ops for reuse") .def("optimize", [](grilly::OpGraph& graph, GrillyCoreContext& ctx) -> py::dict { - auto stats = graph.optimize(ctx.cache); + grilly::FusionStats stats; + { + py::gil_scoped_release release; + stats = graph.optimize(ctx.cache); + } py::dict d; d["ops_fused"] = stats.opsFused; d["barriers_eliminated"] = stats.barriersEliminated; @@ -87,7 +91,10 @@ PYBIND11_MODULE(grilly_core, m) { "Run fusion optimization pass. Returns fusion statistics.") .def("execute", [](grilly::OpGraph& graph, GrillyCoreContext& ctx) { - graph.execute(ctx.batch, ctx.cache); + { + py::gil_scoped_release release; + graph.execute(ctx.batch, ctx.cache); + } }, py::arg("device"), "Execute all recorded ops in a single GPU submission"); @@ -147,6 +154,8 @@ PYBIND11_MODULE(grilly_core, m) { "Download to CPU and return as numpy array") .def("gpu_handle", &Tensor::gpu_handle, "Upload to GPU and return buffer handle") + .def("gpu_handle_if_valid", &Tensor::gpu_handle_if_valid, + "GPU buffer handle if already resident (0 otherwise); does not upload") .def("mark_gpu_modified", &Tensor::mark_gpu_modified) .def("mark_cpu_modified", &Tensor::mark_cpu_modified) .def("reshape", &Tensor::reshape, py::arg("shape")) @@ -426,6 +435,10 @@ PYBIND11_MODULE(grilly_core, m) { register_siglip_ops(m); register_perceiver_ops(m); register_moqe_train_ops(m); + register_moe_ops(m); register_fusion_ops(m); + register_vsa_lm_ops(m); + register_grl_ops(m); register_misc_ops(m); + register_prefix_scan_ops(m); } diff --git a/cpp/python/bindings_core.h b/cpp/python/bindings_core.h index 6987de4..3ba823d 100644 --- a/cpp/python/bindings_core.h +++ b/cpp/python/bindings_core.h @@ -98,6 +98,61 @@ struct GrillyCoreContext { // Helper utilities shared across binding files // ═══════════════════════════════════════════════════════════════════════════ +/// Require NumPy C-contiguous float32 (kernels assume dense row-major layout). 
+inline void require_c_contiguous_float(const py::buffer_info& buf) {
+  if (buf.itemsize != sizeof(float))
+    throw std::runtime_error("expected float32 array");
+  if (buf.ndim == 0)
+    return;
+  py::ssize_t expected_stride = static_cast<py::ssize_t>(sizeof(float));
+  for (int i = static_cast<int>(buf.ndim) - 1; i >= 0; --i) {
+    if (buf.strides[i] != expected_stride)
+      throw std::runtime_error("array must be C-contiguous float32");
+    expected_stride *= buf.shape[i];
+  }
+}
+
+/// Require NumPy C-contiguous int8 (e.g. VSA / Hamming vectors).
+inline void require_c_contiguous_int8(const py::buffer_info& buf) {
+  if (buf.itemsize != sizeof(int8_t))
+    throw std::runtime_error("expected int8 array");
+  if (buf.ndim == 0)
+    return;
+  py::ssize_t expected_stride = static_cast<py::ssize_t>(sizeof(int8_t));
+  for (int i = static_cast<int>(buf.ndim) - 1; i >= 0; --i) {
+    if (buf.strides[i] != expected_stride)
+      throw std::runtime_error("array must be C-contiguous int8");
+    expected_stride *= buf.shape[i];
+  }
+}
+
+/// Require NumPy C-contiguous int32 (e.g. token ids).
+inline void require_c_contiguous_int32(const py::buffer_info& buf) {
+  if (buf.itemsize != sizeof(int32_t))
+    throw std::runtime_error("expected int32 array");
+  if (buf.ndim == 0)
+    return;
+  py::ssize_t expected_stride = static_cast<py::ssize_t>(sizeof(int32_t));
+  for (int i = static_cast<int>(buf.ndim) - 1; i >= 0; --i) {
+    if (buf.strides[i] != expected_stride)
+      throw std::runtime_error("array must be C-contiguous int32");
+    expected_stride *= buf.shape[i];
+  }
+}
+
+/// Require NumPy C-contiguous uint32 (e.g. CE targets).
+inline void require_c_contiguous_uint32(const py::buffer_info& buf) {
+  if (buf.itemsize != sizeof(uint32_t))
+    throw std::runtime_error("expected uint32 array");
+  if (buf.ndim == 0)
+    return;
+  py::ssize_t expected_stride = static_cast<py::ssize_t>(sizeof(uint32_t));
+  for (int i = static_cast<int>(buf.ndim) - 1; i >= 0; --i) {
+    if (buf.strides[i] != expected_stride)
+      throw std::runtime_error("array must be C-contiguous uint32");
+    expected_stride *= buf.shape[i];
+  }
+}
+
 /// Extract flat batch*seq and last-dim from a numpy buffer_info.
 inline std::pair<uint32_t, uint32_t> extractBatchAndLastDim(
     const py::buffer_info& buf) {
@@ -132,4 +187,8 @@ void register_pooling_ops(py::module_& m);
 void register_misc_ops(py::module_& m);
 void register_perceiver_ops(py::module_& m);
 void register_moqe_train_ops(py::module_& m);
+void register_moe_ops(py::module_& m);
 void register_fusion_ops(py::module_& m);
+void register_vsa_lm_ops(py::module_& m);
+void register_grl_ops(py::module_& m);
+void register_prefix_scan_ops(py::module_& m);
diff --git a/cpp/python/bindings_fusion.cpp b/cpp/python/bindings_fusion.cpp
index b179adf..fb5424f 100644
--- a/cpp/python/bindings_fusion.cpp
+++ b/cpp/python/bindings_fusion.cpp
@@ -39,7 +39,11 @@ void register_fusion_ops(py::module_& m) {
         graph.ops.push_back(std::move(op));
       }
 
-      FusionResult result = engine.fuse(graph, ctx.cache);
+      FusionResult result;
+      {
+        py::gil_scoped_release release;
+        result = engine.fuse(graph, ctx.cache);
+      }
 
       py::dict d;
       d["shader_name"] = result.shaderName;
diff --git a/cpp/python/bindings_grl.cpp b/cpp/python/bindings_grl.cpp
new file mode 100644
index 0000000..959e55d
--- /dev/null
+++ b/cpp/python/bindings_grl.cpp
@@ -0,0 +1,44 @@
+/// Pybind11 bindings for GRL (.grl) checkpoint I/O — implemented in C++ for
+/// performance and a single canonical binary encoder/decoder.
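A round-trip sketch for the two entry points registered below. The positional layout (path, metadata JSON, tensor-index JSON, raw payload bytes) follows the binding signatures; the grilly_core module name and the index schema shown here are illustrative assumptions, not the actual utils/grl_checkpoint layout.

```python
import json
import numpy as np
# import grilly_core   # assumed name of the compiled extension module

weights = np.arange(12, dtype=np.float32).reshape(3, 4)
metadata_json = json.dumps({"format": "grl", "version": 1, "model": "demo"})
# Hypothetical index schema; the real layout is defined by utils/grl_checkpoint.
index_json = json.dumps({"layer0.weight": {"dtype": "float32", "shape": [3, 4], "offset": 0}})
payload = weights.tobytes()

# grilly_core.grl_write_file("demo.grl", metadata_json, index_json, payload)
# meta_out, index_out, payload_out = grilly_core.grl_read_file("demo.grl")
# entry = json.loads(index_out)["layer0.weight"]
# restored = np.frombuffer(payload_out, dtype=np.float32).reshape(entry["shape"])
# assert np.array_equal(restored, weights) and meta_out == metadata_json
```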
+ +#include "bindings_core.h" +#include "grilly/io/grl_checkpoint.h" + +#include +#include + +void register_grl_ops(py::module_& m) { + m.def( + "grl_write_file", + [](const std::string& path, const std::string& metadata_json, + const std::string& index_json, py::bytes payload_bytes) { + std::string pb = payload_bytes; + std::vector payload(pb.begin(), pb.end()); + if (!grilly::io::grl_write_file(path, metadata_json, index_json, + payload)) { + throw std::runtime_error("grl_write_file failed: " + path); + } + }, + py::arg("path"), py::arg("metadata_json"), py::arg("index_json"), + py::arg("payload"), + "Write a GRL v1 checkpoint (header + metadata JSON + tensor index JSON " + "+ raw payload bytes). Matches Python utils/grl_checkpoint layout."); + + m.def( + "grl_read_file", + [](const std::string& path) { + std::string metadata_json; + std::string index_json; + std::vector payload; + if (!grilly::io::grl_read_file(path, metadata_json, index_json, + payload)) { + throw std::runtime_error("grl_read_file failed: " + path); + } + return py::make_tuple(metadata_json, index_json, + py::bytes(reinterpret_cast( + payload.data()), + payload.size())); + }, + py::arg("path"), + "Read a GRL v1 file. Returns (metadata_json, index_json, payload_bytes)."); +} diff --git a/cpp/python/bindings_linear.cpp b/cpp/python/bindings_linear.cpp index c1ad145..9b35265 100644 --- a/cpp/python/bindings_linear.cpp +++ b/cpp/python/bindings_linear.cpp @@ -8,12 +8,18 @@ void register_linear_ops(py::module_& m) { using namespace grilly::nn; - // ── GPU linear forward ─────────────────────────────────────────────── + // ── GPU linear forward (fp32 + fp16) ─────────────────────────────── + // + // Accepts generic numpy arrays — inspects ``itemsize`` to detect fp32 + // (itemsize=4) vs fp16 (itemsize=2). Input and weights must share the + // same dtype; bias is always fp32 (stability + broadcast simplicity). + // Output is always fp32 regardless of input dtype because the + // cooperative-matrix accumulator runs in fp32 — callers needing fp16 + // output can ``astype(np.float16)`` on the returned array. m.def( "linear", - [](GrillyCoreContext& ctx, py::array_t x, - py::array_t weights, - std::optional> bias) -> Tensor { + [](GrillyCoreContext& ctx, py::array x, py::array weights, + std::optional bias) -> py::array { auto xBuf = x.request(); auto wBuf = weights.request(); @@ -24,6 +30,15 @@ void register_linear_ops(py::module_& m) { throw std::runtime_error( "weights must be 2D (output_dim, input_dim)"); + if (xBuf.itemsize != 2 && xBuf.itemsize != 4) + throw std::runtime_error( + "grilly linear: x must be fp32 (itemsize=4) or fp16 (itemsize=2)"); + if (xBuf.itemsize != wBuf.itemsize) + throw std::runtime_error( + "grilly linear: x and weights must share the same dtype"); + + const uint32_t elemSize = static_cast(xBuf.itemsize); + auto [batchSeq, inputDim] = extractBatchAndLastDim(xBuf); uint32_t outputDim = static_cast(wBuf.shape[0]); @@ -33,19 +48,26 @@ void register_linear_ops(py::module_& m) { std::to_string(wBuf.shape[1]) + " vs " + std::to_string(inputDim)); - const float* biasPtr = nullptr; + // Bias is always fp32. If the caller passed fp16 bias, we + // require them to upcast it on the Python side — simpler and + // matches the C++ contract. 
+ const void* biasPtr = nullptr; uint32_t hasBias = 0; if (bias.has_value()) { auto bBuf = bias->request(); + if (bBuf.itemsize != 4) + throw std::runtime_error( + "grilly linear: bias must be fp32 (cast via .astype(np.float32))"); if (bBuf.ndim != 1 || static_cast(bBuf.shape[0]) != outputDim) throw std::runtime_error( "bias must be 1D with size output_dim"); - biasPtr = static_cast(bBuf.ptr); + biasPtr = bBuf.ptr; hasBias = 1; } - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, hasBias}; + grilly::ops::LinearParams p{ + batchSeq, inputDim, outputDim, hasBias, elemSize}; std::vector outShape; for (int i = 0; i < xBuf.ndim - 1; ++i) @@ -53,16 +75,16 @@ void register_linear_ops(py::module_& m) { if (xBuf.ndim == 1) outShape.push_back(1); outShape.push_back(outputDim); - py::array_t result(outShape); + // Output is always fp32 — format code 'f'. The coopmat shader + // accumulates in fp32, and so does fnn-linear. + py::array result(py::dtype("f"), outShape); auto rBuf = result.request(); - // Extract raw pointers before GIL release - const float* xPtr = static_cast(xBuf.ptr); - const float* wPtr = static_cast(wBuf.ptr); - float* oPtr = static_cast(rBuf.ptr); + const void* xPtr = xBuf.ptr; + const void* wPtr = wBuf.ptr; + void* oPtr = rBuf.ptr; { - // Release GIL during GPU GEMM dispatch py::gil_scoped_release release; grilly::ops::linear( ctx.batch, ctx.pool, ctx.cache, @@ -72,11 +94,11 @@ void register_linear_ops(py::module_& m) { if (xBuf.ndim == 1) result = result.reshape({static_cast(outputDim)}); - return Tensor::from_numpy(result); + return result; }, py::arg("device"), py::arg("x"), py::arg("weights"), py::arg("bias") = py::none(), - "GPU linear projection: output = x @ W^T + bias"); + "GPU linear projection (fp32 or fp16 input; output always fp32)"); // ── GPU linear forward (Tensor I/O — no numpy copies) ────────────── m.def( @@ -117,7 +139,9 @@ void register_linear_ops(py::module_& m) { hasBias = 1; } - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, hasBias}; + // ``linear_t`` is the native C++ Tensor path — always fp32. + grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, + hasBias, 4u}; // Allocate output tensor (CPU-valid) std::vector outShape; @@ -161,7 +185,9 @@ void register_linear_ops(py::module_& m) { hasBias = 1; } - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, hasBias}; + // ``linear_cpu`` is the Eigen reference path — always fp32. 
+ grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, + hasBias, 4u}; std::vector out = grilly::ops::linearCPU( static_cast(xBuf.ptr), static_cast(wBuf.ptr), biasPtr, p); @@ -184,43 +210,56 @@ void register_linear_ops(py::module_& m) { py::arg("x"), py::arg("weights"), py::arg("bias") = py::none(), "CPU linear projection using Eigen (for verification)"); - // ── GPU linear backward ────────────────────────────────────────────── + // ── GPU linear backward (fp32 only for now; interface is fp16-ready) ── m.def( "linear_backward", [](GrillyCoreContext& ctx, - py::array_t grad_output, py::array_t input, - py::array_t weights) -> py::dict { + py::array grad_output, py::array input, + py::array weights) -> py::dict { auto gBuf = grad_output.request(); auto iBuf = input.request(); auto wBuf = weights.request(); + if (gBuf.itemsize != 2 && gBuf.itemsize != 4) + throw std::runtime_error( + "grilly linear_backward: grad_output must be fp32 or fp16"); + if (gBuf.itemsize != iBuf.itemsize || + gBuf.itemsize != wBuf.itemsize) + throw std::runtime_error( + "grilly linear_backward: grad_output, input, weights " + "must share the same dtype"); + + const uint32_t elemSize = static_cast(gBuf.itemsize); + const std::string npFormat = (elemSize == 2) ? "e" : "f"; + auto [batchSeq, outputDim] = extractBatchAndLastDim(gBuf); uint32_t inputDim = static_cast( iBuf.shape[iBuf.ndim - 1]); - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, 1}; + grilly::ops::LinearParams p{ + batchSeq, inputDim, outputDim, 1u, elemSize}; - py::array_t gradInput(iBuf.shape); - py::array_t gradWeight(wBuf.shape); - py::array_t gradBias( - {static_cast(outputDim)}); + py::array gradInput(py::dtype(npFormat), iBuf.shape); + py::array gradWeight(py::dtype(npFormat), wBuf.shape); + // Explicit vector to disambiguate from py::array(handle, bool). 
+ std::vector biasShape = { + static_cast(outputDim)}; + py::array gradBias(py::dtype(npFormat), biasShape); grilly::ops::linearBackward( ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(iBuf.ptr), - static_cast(wBuf.ptr), - static_cast(gradInput.request().ptr), - static_cast(gradWeight.request().ptr), - static_cast(gradBias.request().ptr), p); + gBuf.ptr, iBuf.ptr, wBuf.ptr, + gradInput.request().ptr, + gradWeight.request().ptr, + gradBias.request().ptr, p); py::dict result; - result["grad_input"] = Tensor::from_numpy(gradInput); - result["grad_weight"] = Tensor::from_numpy(gradWeight); - result["grad_bias"] = Tensor::from_numpy(gradBias); + result["grad_input"] = gradInput; + result["grad_weight"] = gradWeight; + result["grad_bias"] = gradBias; return result; }, py::arg("device"), py::arg("grad_output"), py::arg("input"), py::arg("weights"), - "GPU linear backward: grad_input, grad_weight, grad_bias"); + "GPU linear backward (supports fp32; fp16 interface ready for shader upgrade)"); } diff --git a/cpp/python/bindings_loss.cpp b/cpp/python/bindings_loss.cpp index 9a3c59d..7bbb03d 100644 --- a/cpp/python/bindings_loss.cpp +++ b/cpp/python/bindings_loss.cpp @@ -15,6 +15,7 @@ void register_loss_ops(py::module_& m) { py::array_t logits, py::array_t targets, float label_smoothing) -> Tensor { auto lBuf = logits.request(); + require_c_contiguous_float(lBuf); uint32_t batchSize, seqLen, vocabSize; if (lBuf.ndim == 2) { @@ -31,14 +32,20 @@ void register_loss_ops(py::module_& m) { uint32_t totalPos = batchSize * seqLen; py::array_t losses(totalPos); + auto tBuf = targets.request(); + require_c_contiguous_uint32(tBuf); + auto lossBuf = losses.request(); grilly::ops::CrossEntropyParams p{ batchSize, seqLen, vocabSize, 0, label_smoothing}; - grilly::ops::crossEntropyLoss( - ctx.batch, ctx.pool, ctx.cache, - static_cast(lBuf.ptr), - static_cast(targets.request().ptr), - static_cast(losses.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::crossEntropyLoss( + ctx.batch, ctx.pool, ctx.cache, + static_cast(lBuf.ptr), + static_cast(tBuf.ptr), + static_cast(lossBuf.ptr), p); + } return Tensor::from_numpy(losses); }, py::arg("device"), py::arg("logits"), py::arg("targets"), @@ -52,18 +59,25 @@ void register_loss_ops(py::module_& m) { py::array_t logits, py::array_t targets) -> Tensor { auto lBuf = logits.request(); + require_c_contiguous_float(lBuf); uint32_t batchSize = static_cast(lBuf.shape[0]); uint32_t numClasses = static_cast(lBuf.shape[1]); py::array_t gradLogits(lBuf.shape); + auto tBuf = targets.request(); + require_c_contiguous_uint32(tBuf); + auto gBuf = gradLogits.request(); grilly::ops::CrossEntropyBackwardParams p{ batchSize, numClasses}; - grilly::ops::crossEntropyBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(lBuf.ptr), - static_cast(targets.request().ptr), - static_cast(gradLogits.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::crossEntropyBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(lBuf.ptr), + static_cast(tBuf.ptr), + static_cast(gBuf.ptr), p); + } return Tensor::from_numpy(gradLogits); }, py::arg("device"), py::arg("logits"), py::arg("targets"), diff --git a/cpp/python/bindings_misc.cpp b/cpp/python/bindings_misc.cpp index d39d298..584334d 100644 --- a/cpp/python/bindings_misc.cpp +++ b/cpp/python/bindings_misc.cpp @@ -90,17 +90,25 @@ void register_misc_ops(py::module_& m) { py::array_t input, py::array_t random_mask, float p, bool training) -> Tensor { auto inBuf = input.request(); - uint32_t 
total = 1; - for (int i = 0; i < inBuf.ndim; ++i) - total *= static_cast(inBuf.shape[i]); + auto maskBuf = random_mask.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(maskBuf); + uint32_t total = static_cast(inBuf.size); + if (static_cast(maskBuf.size) != total) + throw std::runtime_error("dropout: input and mask size mismatch"); py::array_t result(inBuf.shape); - grilly::ops::dropout( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(random_mask.request().ptr), - static_cast(result.request().ptr), - total, p, training); + auto rBuf = result.request(); + require_c_contiguous_float(rBuf); + { + py::gil_scoped_release release; + grilly::ops::dropout( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(maskBuf.ptr), + static_cast(rBuf.ptr), + total, p, training); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), py::arg("random_mask"), @@ -115,6 +123,8 @@ void register_misc_ops(py::module_& m) { py::array_t embeddings) -> Tensor { auto idBuf = token_ids.request(); auto eBuf = embeddings.request(); + require_c_contiguous_uint32(idBuf); + require_c_contiguous_float(eBuf); uint32_t batchSize = 1, seqLen; if (idBuf.ndim == 1) { @@ -131,13 +141,19 @@ void register_misc_ops(py::module_& m) { static_cast(seqLen), static_cast(embDim)}); + auto rBuf = result.request(); + require_c_contiguous_float(rBuf); + grilly::ops::EmbeddingParams p{ batchSize, seqLen, vocabSize, embDim}; - grilly::ops::embeddingLookup( - ctx.batch, ctx.pool, ctx.cache, - static_cast(idBuf.ptr), - static_cast(eBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::embeddingLookup( + ctx.batch, ctx.pool, ctx.cache, + static_cast(idBuf.ptr), + static_cast(eBuf.ptr), + static_cast(rBuf.ptr), p); + } if (idBuf.ndim == 1) result = result.reshape({ @@ -208,13 +224,18 @@ void register_misc_ops(py::module_& m) { [](GrillyCoreContext& ctx, grilly::ops::KVCache& kvCache, py::array_t newKeys, py::array_t newValues) { auto kBuf = newKeys.request(); + auto vBuf = newValues.request(); + require_c_contiguous_float(kBuf); + require_c_contiguous_float(vBuf); uint32_t numNew = static_cast(kBuf.shape[0]); - grilly::ops::kvCacheAppend( - ctx.batch, ctx.pool, ctx.cache, kvCache, - static_cast(kBuf.ptr), - static_cast( - newValues.request().ptr), - numNew); + { + py::gil_scoped_release release; + grilly::ops::kvCacheAppend( + ctx.batch, ctx.pool, ctx.cache, kvCache, + static_cast(kBuf.ptr), + static_cast(vBuf.ptr), + numNew); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("new_keys"), py::arg("new_values"), @@ -234,10 +255,17 @@ void register_misc_ops(py::module_& m) { static_cast(tokens), static_cast(cfg.numHeads), static_cast(cfg.headDim)}); - grilly::ops::kvCacheDecode( - ctx.batch, ctx.pool, ctx.cache, kvCache, - static_cast(keys.request().ptr), - static_cast(values.request().ptr)); + auto keysBuf = keys.request(); + auto valsBuf = values.request(); + require_c_contiguous_float(keysBuf); + require_c_contiguous_float(valsBuf); + { + py::gil_scoped_release release; + grilly::ops::kvCacheDecode( + ctx.batch, ctx.pool, ctx.cache, kvCache, + static_cast(keysBuf.ptr), + static_cast(valsBuf.ptr)); + } py::dict result; result["keys"] = Tensor::from_numpy(keys); result["values"] = Tensor::from_numpy(values); @@ -252,12 +280,18 @@ void register_misc_ops(py::module_& m) { std::optional> attentionScores, uint32_t numEvict) { const float* scoresPtr = nullptr; - if (attentionScores.has_value()) - 
scoresPtr = static_cast( - attentionScores->request().ptr); - grilly::ops::kvCacheEvictH2O( - ctx.batch, ctx.pool, ctx.cache, kvCache, - scoresPtr, numEvict); + py::buffer_info scoresBuf; + if (attentionScores.has_value()) { + scoresBuf = attentionScores->request(); + require_c_contiguous_float(scoresBuf); + scoresPtr = static_cast(scoresBuf.ptr); + } + { + py::gil_scoped_release release; + grilly::ops::kvCacheEvictH2O( + ctx.batch, ctx.pool, ctx.cache, kvCache, + scoresPtr, numEvict); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("attention_scores") = py::none(), @@ -267,8 +301,11 @@ void register_misc_ops(py::module_& m) { m.def( "kv_cache_compact", [](GrillyCoreContext& ctx, grilly::ops::KVCache& kvCache) { - grilly::ops::kvCacheCompact( - ctx.batch, ctx.pool, ctx.cache, kvCache); + { + py::gil_scoped_release release; + grilly::ops::kvCacheCompact( + ctx.batch, ctx.pool, ctx.cache, kvCache); + } }, py::arg("device"), py::arg("kv_cache"), "Compact KV cache after eviction"); @@ -298,13 +335,18 @@ void register_misc_ops(py::module_& m) { [](GrillyCoreContext& ctx, grilly::ops::KVCache& kvCache, py::array_t tokenFeatures, py::array_t attentionScores, uint32_t seqLen) { - grilly::ops::kvCacheTrainEvictionHead( - ctx.batch, ctx.pool, ctx.cache, kvCache, - static_cast( - tokenFeatures.request().ptr), - static_cast( - attentionScores.request().ptr), - seqLen); + auto tfBuf = tokenFeatures.request(); + auto asBuf = attentionScores.request(); + require_c_contiguous_float(tfBuf); + require_c_contiguous_float(asBuf); + { + py::gil_scoped_release release; + grilly::ops::kvCacheTrainEvictionHead( + ctx.batch, ctx.pool, ctx.cache, kvCache, + static_cast(tfBuf.ptr), + static_cast(asBuf.ptr), + seqLen); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("token_features"), py::arg("attention_scores"), @@ -317,12 +359,18 @@ void register_misc_ops(py::module_& m) { std::optional> hiddenStates, uint32_t hiddenDim) { const float* hsPtr = nullptr; - if (hiddenStates.has_value()) - hsPtr = static_cast( - hiddenStates->request().ptr); - grilly::ops::kvCacheEvictSpeculative( - ctx.batch, ctx.pool, ctx.cache, kvCache, - hsPtr, hiddenDim); + py::buffer_info hsBuf; + if (hiddenStates.has_value()) { + hsBuf = hiddenStates->request(); + require_c_contiguous_float(hsBuf); + hsPtr = static_cast(hsBuf.ptr); + } + { + py::gil_scoped_release release; + grilly::ops::kvCacheEvictSpeculative( + ctx.batch, ctx.pool, ctx.cache, kvCache, + hsPtr, hiddenDim); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("hidden_states") = py::none(), @@ -335,6 +383,7 @@ void register_misc_ops(py::module_& m) { [](GrillyCoreContext& ctx, py::array_t input, uint32_t waveSize, bool reverse) -> py::array_t { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error("input must be 4D"); uint32_t batchSize = static_cast(inBuf.shape[0]); @@ -355,12 +404,16 @@ void register_misc_ops(py::module_& m) { result = py::array_t(outSize / sizeof(float)); } auto rBuf = result.request(); - grilly::ops::swizzle( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), - batchSize, numHeads, seqLen, headDim, - waveSize, reverse); + require_c_contiguous_float(rBuf); + { + py::gil_scoped_release release; + grilly::ops::swizzle( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + batchSize, numHeads, seqLen, headDim, + waveSize, reverse); + } return result; }, py::arg("device"), py::arg("input"), @@ -517,6 +570,8 @@ void 
register_misc_ops(py::module_& m) { py::array_t cache_data) -> py::array_t { auto qBuf = query.request(); auto cBuf = cache_data.request(); + require_c_contiguous_int8(qBuf); + require_c_contiguous_int8(cBuf); uint32_t dim = static_cast(qBuf.shape[0]); uint32_t numEntries; if (cBuf.ndim == 2) { @@ -571,6 +626,8 @@ void register_misc_ops(py::module_& m) { py::array_t cache_data) -> py::array_t { auto qBuf = query.request(); auto cBuf = cache_data.request(); + require_c_contiguous_int8(qBuf); + require_c_contiguous_int8(cBuf); uint32_t dim = static_cast(qBuf.shape[0]); uint32_t numEntries = (cBuf.ndim == 2) ? static_cast(cBuf.shape[0]) diff --git a/cpp/python/bindings_moe.cpp b/cpp/python/bindings_moe.cpp new file mode 100644 index 0000000..e6049a1 --- /dev/null +++ b/cpp/python/bindings_moe.cpp @@ -0,0 +1,253 @@ +/// bindings_moe.cpp — fused MoE upload / forward / backward / weight update. +#include "bindings_core.h" +#include "grilly/ops/moe_forward.h" + +#include +#include +#include + +using namespace grilly; + +void register_moe_ops(py::module_& m) { + + m.def( + "moe_upload", + [](GrillyCoreContext& ctx, py::array_t embed_w, py::array_t pos_w, + py::list expert_ws, py::list router_ws, py::list router_bs, py::array_t out_w, + uint32_t n_layers, uint32_t n_experts) -> int { + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + if (eb.ndim != 2 || pb.ndim != 2 || ob.ndim != 2) + throw std::runtime_error("moe_upload: embed/pos/out must be 2-D"); + uint32_t vocab = static_cast(eb.shape[0]); + uint32_t d = static_cast(eb.shape[1]); + if (static_cast(ob.shape[1]) != d) + throw std::runtime_error("moe_upload: out_w second dim must match d"); + if (static_cast(ob.shape[0]) != vocab) + throw std::runtime_error("moe_upload: out_w first dim must match vocab"); + uint32_t max_seq = static_cast(pb.shape[0]); + if (static_cast(pb.shape[1]) != d) + throw std::runtime_error("moe_upload: pos_w second dim must match d"); + + std::vector exp_ptrs; + std::vector> exp_keep; + for (auto& item : expert_ws) { + auto arr = item.cast>(); + auto wb = arr.request(); + require_c_contiguous_float(wb); + if (wb.ndim != 2 || static_cast(wb.shape[0]) != d || + static_cast(wb.shape[1]) != d) + throw std::runtime_error("moe_upload: each expert must be (d, d)"); + exp_keep.push_back(arr); + exp_ptrs.push_back(static_cast(wb.ptr)); + } + + std::vector rW_ptrs; + std::vector rb_ptrs; + std::vector> rW_keep; + std::vector> rb_keep; + for (uint32_t l = 0; l < n_layers; ++l) { + auto rw = router_ws[l].cast>(); + auto rb = router_bs[l].cast>(); + auto rwb = rw.request(); + auto rbb = rb.request(); + require_c_contiguous_float(rwb); + require_c_contiguous_float(rbb); + if (rwb.ndim != 2 || static_cast(rwb.shape[0]) != n_experts || + static_cast(rwb.shape[1]) != d) + throw std::runtime_error("moe_upload: router_w must be (n_experts, d)"); + if (rbb.ndim != 1 || static_cast(rbb.shape[0]) != n_experts) + throw std::runtime_error("moe_upload: router_b must be (n_experts,)"); + rW_keep.push_back(rw); + rb_keep.push_back(rb); + rW_ptrs.push_back(static_cast(rwb.ptr)); + rb_ptrs.push_back(static_cast(rbb.ptr)); + } + + int handle = ops::moe_upload( + ctx.pool, vocab, d, max_seq, + static_cast(eb.ptr), + static_cast(pb.ptr), + exp_ptrs, rW_ptrs, rb_ptrs, + static_cast(ob.ptr), + n_layers, n_experts); + + ctx.waitIdle(); + return handle; + }, + py::arg("device"), py::arg("embed_w"), py::arg("pos_w"), 
py::arg("expert_ws"), + py::arg("router_ws"), py::arg("router_bs"), py::arg("out_w"), + py::arg("n_layers"), py::arg("n_experts"), + "Upload MoE weights to GPU; returns opaque integer handle."); + + m.def( + "moe_release", + [](GrillyCoreContext& ctx, int handle) { ops::moe_release(ctx.pool, handle); }, + py::arg("device"), py::arg("handle"), "Free GPU buffers for a MoE handle."); + + m.def( + "moe_forward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids) + -> py::array_t { + auto& cache = ops::moe_get_cache(handle); + auto ib = input_ids.request(); + require_c_contiguous_int32(ib); + if (ib.ndim != 1) + throw std::runtime_error("moe_forward: input_ids must be 1-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + uint32_t V = cache.vocab; + + py::array_t logits({static_cast(seq_len), + static_cast(V)}); + + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + ops::moe_forward_gpu(ctx.batch, ctx.pool, ctx.cache, cache, + static_cast(ib.ptr), seq_len, + logits.mutable_data()); + } + return logits; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + "Run full MoE forward on GPU (router/blend on CPU). Returns (seq_len, vocab)."); + + m.def( + "moe_update_weights", + [](GrillyCoreContext& ctx, int handle, py::array_t embed_w, + py::array_t pos_w, py::list expert_ws, py::list router_ws, + py::list router_bs, py::array_t out_w) { + auto& h = ops::moe_get_cache(handle); + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + + std::vector exp_ptrs; + std::vector> exp_keep; + for (auto& item : expert_ws) { + auto arr = item.cast>(); + auto wb = arr.request(); + require_c_contiguous_float(wb); + exp_keep.push_back(arr); + exp_ptrs.push_back(static_cast(wb.ptr)); + } + + std::vector rW_ptrs; + std::vector rb_ptrs; + for (uint32_t l = 0; l < L; ++l) { + auto rw = router_ws[l].cast>(); + auto rb = router_bs[l].cast>(); + auto rwb = rw.request(); + auto rbb = rb.request(); + require_c_contiguous_float(rwb); + require_c_contiguous_float(rbb); + rW_ptrs.push_back(static_cast(rwb.ptr)); + rb_ptrs.push_back(static_cast(rbb.ptr)); + } + + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + ops::moe_update_weights( + ctx.pool, h, + static_cast(eb.ptr), + static_cast(pb.ptr), + exp_ptrs, rW_ptrs, rb_ptrs, + static_cast(ob.ptr)); + ctx.waitIdle(); + } + }, + py::arg("device"), py::arg("handle"), py::arg("embed_w"), py::arg("pos_w"), + py::arg("expert_ws"), py::arg("router_ws"), py::arg("router_bs"), py::arg("out_w"), + "Re-upload weights in place after an optimizer step."); + + m.def( + "moe_backward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids, + py::array_t grad_logits) -> py::dict { + auto& h = ops::moe_get_cache(handle); + auto ib = input_ids.request(); + auto gb = grad_logits.request(); + require_c_contiguous_int32(ib); + require_c_contiguous_float(gb); + if (ib.ndim != 1) + throw std::runtime_error("moe_backward: input_ids must be 1-D"); + if (gb.ndim != 2) + throw std::runtime_error("moe_backward: grad_logits must be 2-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + uint32_t V = h.vocab; + if (static_cast(gb.shape[0]) != seq_len || + static_cast(gb.shape[1]) != V) + throw std::runtime_error("moe_backward: grad_logits shape mismatch"); + + ops::MoeGradients grads; + { + py::gil_scoped_release release; + 
std::lock_guard lock(ctx.ctx_mutex); + grads = ops::moe_backward_gpu( + ctx.batch, ctx.pool, ctx.cache, h, + static_cast(ib.ptr), seq_len, + static_cast(gb.ptr)); + } + + py::dict d; + py::array_t grad_embed({static_cast(h.vocab), + static_cast(h.d)}); + std::memcpy(grad_embed.mutable_data(), grads.grad_embed.data(), + grads.grad_embed.size() * sizeof(float)); + d["grad_embed"] = grad_embed; + + py::array_t grad_pos({static_cast(h.maxSeq), + static_cast(h.d)}); + std::memcpy(grad_pos.mutable_data(), grads.grad_pos.data(), + grads.grad_pos.size() * sizeof(float)); + d["grad_pos"] = grad_pos; + + py::list ge; + for (size_t i = 0; i < grads.grad_experts.size(); ++i) { + py::array_t gex({static_cast(h.d), + static_cast(h.d)}); + std::memcpy(gex.mutable_data(), grads.grad_experts[i].data(), + h.d * h.d * sizeof(float)); + ge.append(gex); + } + d["grad_experts"] = ge; + + py::list grw; + py::list grb; + for (uint32_t l = 0; l < h.nLayers; ++l) { + py::array_t gw({static_cast(h.nExperts), + static_cast(h.d)}); + std::memcpy(gw.mutable_data(), grads.grad_router_w[l].data(), + h.nExperts * h.d * sizeof(float)); + grw.append(gw); + py::array_t gbv({static_cast(h.nExperts)}); + std::memcpy(gbv.mutable_data(), grads.grad_router_b[l].data(), + h.nExperts * sizeof(float)); + grb.append(gbv); + } + d["grad_routers_W"] = grw; + d["grad_routers_b"] = grb; + + py::array_t gow({static_cast(h.vocab), + static_cast(h.d)}); + std::memcpy(gow.mutable_data(), grads.grad_out_w.data(), + grads.grad_out_w.size() * sizeof(float)); + d["grad_out_w"] = gow; + return d; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + py::arg("grad_logits"), + "Backward pass (GPU path when available, CPU fallback) using uploaded MoE weights."); +} diff --git a/cpp/python/bindings_moqe_train.cpp b/cpp/python/bindings_moqe_train.cpp index 1d5788d..c42e593 100644 --- a/cpp/python/bindings_moqe_train.cpp +++ b/cpp/python/bindings_moqe_train.cpp @@ -24,7 +24,9 @@ void register_moqe_train_ops(py::module_& m) { for (auto& w : weight_arrays) { auto arr = w.cast>(); arrays.push_back(arr); - ptrs.push_back(static_cast(arr.request().ptr)); + auto wb = arr.request(); + require_c_contiguous_float(wb); + ptrs.push_back(static_cast(wb.ptr)); } int handle = ops::moqe_train_upload( @@ -52,6 +54,7 @@ void register_moqe_train_ops(py::module_& m) { uint32_t layerIdx, int expertIdx, py::array_t w) { auto& cache = ops::moqe_train_get_cache(handle); auto buf = w.request(); + require_c_contiguous_float(buf); { py::gil_scoped_release release; ops::moqe_train_update_expert(ctx.pool, cache, layerIdx, expertIdx, @@ -71,6 +74,10 @@ void register_moqe_train_ops(py::module_& m) { auto& tc = ops::moqe_train_get_cache(handle); auto b0 = x0.request(); auto b1 = x1.request(); + if (b0.size > 0) + require_c_contiguous_float(b0); + if (b1.size > 0) + require_c_contiguous_float(b1); uint32_t n0 = static_cast(b0.shape[0]); uint32_t n1 = static_cast(b1.shape[0]); uint32_t d = tc.dModel; @@ -100,6 +107,10 @@ void register_moqe_train_ops(py::module_& m) { auto& tc = ops::moqe_train_get_cache(handle); auto b0 = d0.request(); auto b1 = d1.request(); + if (b0.size > 0) + require_c_contiguous_float(b0); + if (b1.size > 0) + require_c_contiguous_float(b1); uint32_t n0 = static_cast(b0.shape[0]); uint32_t n1 = static_cast(b1.shape[0]); uint32_t d = tc.dModel; diff --git a/cpp/python/bindings_normalization.cpp b/cpp/python/bindings_normalization.cpp index 15f62bf..4428639 100644 --- a/cpp/python/bindings_normalization.cpp +++ b/cpp/python/bindings_normalization.cpp 
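Taken together, the new `bindings_moe.cpp` entry points form an upload → forward → backward → update loop driven from Python. Below is a minimal sketch of that loop, not a definitive recipe: it assumes the `grilly_core` extension module is importable and that a device/context object (here `dev`) comes from grilly's usual initialization path, which is not part of this diff; the sizes are illustrative. Note the contiguity guards added above — every weight array must be C-contiguous `float32` and token ids must be 1-D `int32`, hence the explicit `np.ascontiguousarray` conversions.

```python
# Hypothetical MoE training-step sketch against the new bindings
# (moe_upload / moe_forward / moe_backward / moe_update_weights / moe_release).
# `dev` is assumed to be the grilly_core device/context object; its
# construction is not shown in this diff.
import numpy as np
import grilly_core

def f32(a):
    # Bindings require C-contiguous float32 (require_c_contiguous_float).
    return np.ascontiguousarray(a, dtype=np.float32)

vocab, d, max_seq, n_layers, n_experts = 1000, 64, 128, 2, 4

embed_w = f32(np.random.randn(vocab, d) * 0.02)
pos_w   = f32(np.random.randn(max_seq, d) * 0.02)
out_w   = f32(np.random.randn(vocab, d) * 0.02)
expert_ws = [f32(np.random.randn(d, d) * 0.02)            # one (d, d) matrix per
             for _ in range(n_layers * n_experts)]         # expert (assumed count)
router_ws = [f32(np.zeros((n_experts, d))) for _ in range(n_layers)]
router_bs = [f32(np.zeros(n_experts)) for _ in range(n_layers)]

dev = ...  # grilly_core device/context (assumption; obtained elsewhere)

handle = grilly_core.moe_upload(dev, embed_w, pos_w, expert_ws,
                                router_ws, router_bs, out_w,
                                n_layers, n_experts)

ids = np.ascontiguousarray([1, 2, 3, 4], dtype=np.int32)   # 1-D int32 tokens
logits = grilly_core.moe_forward(dev, handle, ids)          # (seq_len, vocab)

grad_logits = f32(np.ones_like(logits) / logits.size)       # placeholder loss grad
grads = grilly_core.moe_backward(dev, handle, ids, grad_logits)

lr = 1e-2                                                    # plain SGD for brevity
embed_w -= lr * grads["grad_embed"]
out_w   -= lr * grads["grad_out_w"]

grilly_core.moe_update_weights(dev, handle, embed_w, pos_w, expert_ws,
                               router_ws, router_bs, out_w)
grilly_core.moe_release(dev, handle)
```

The dict keys returned by `moe_backward` (`grad_embed`, `grad_pos`, `grad_experts`, `grad_routers_W`, `grad_routers_b`, `grad_out_w`) mirror the upload arguments, so a full optimizer step can update every host-side tensor before `moe_update_weights` re-uploads them in place.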
@@ -19,6 +19,8 @@ void register_normalization_ops(py::module_& m) { float eps) -> Tensor { auto inBuf = input.request(); auto gBuf = gamma.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(gBuf); if (inBuf.ndim < 2) throw std::runtime_error("input must be at least 2D"); @@ -31,14 +33,19 @@ void register_normalization_ops(py::module_& m) { py::array_t result(inBuf.shape); auto rBuf = result.request(); + auto bBuf = beta.request(); + require_c_contiguous_float(bBuf); - grilly::ops::layernorm( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), - static_cast(gBuf.ptr), - static_cast(beta.request().ptr), - 1, totalBatch, features, eps); + { + py::gil_scoped_release release; + grilly::ops::layernorm( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + static_cast(gBuf.ptr), + static_cast(bBuf.ptr), + 1, totalBatch, features, eps); + } return Tensor::from_numpy(result); }, @@ -56,6 +63,9 @@ void register_normalization_ops(py::module_& m) { auto goBuf = grad_output.request(); auto inBuf = input.request(); auto gBuf = gamma.request(); + require_c_contiguous_float(goBuf); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(gBuf); if (inBuf.ndim < 2) throw std::runtime_error("input must be at least 2D"); @@ -71,18 +81,28 @@ void register_normalization_ops(py::module_& m) { {static_cast(features)}); py::array_t gradBeta( {static_cast(features)}); + auto meanBuf = mean.request(); + auto varBuf = var.request(); + auto giBuf = gradInput.request(); + auto ggBuf = gradGamma.request(); + auto gbBuf = gradBeta.request(); + require_c_contiguous_float(meanBuf); + require_c_contiguous_float(varBuf); - grilly::ops::layernormBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(goBuf.ptr), - static_cast(inBuf.ptr), - static_cast(gBuf.ptr), - static_cast(mean.request().ptr), - static_cast(var.request().ptr), - static_cast(gradInput.request().ptr), - static_cast(gradGamma.request().ptr), - static_cast(gradBeta.request().ptr), - 1, totalBatch, features, eps); + { + py::gil_scoped_release release; + grilly::ops::layernormBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(goBuf.ptr), + static_cast(inBuf.ptr), + static_cast(gBuf.ptr), + static_cast(meanBuf.ptr), + static_cast(varBuf.ptr), + static_cast(giBuf.ptr), + static_cast(ggBuf.ptr), + static_cast(gbBuf.ptr), + 1, totalBatch, features, eps); + } py::dict result; result["grad_input"] = Tensor::from_numpy(gradInput); @@ -104,6 +124,7 @@ void register_normalization_ops(py::module_& m) { py::array_t weight, float eps) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t features, batchSize, seqLen; if (inBuf.ndim == 1) { @@ -124,13 +145,18 @@ void register_normalization_ops(py::module_& m) { py::array_t result(inBuf.shape); auto rBuf = result.request(); + auto wBuf = weight.request(); + require_c_contiguous_float(wBuf); - grilly::ops::rmsnorm( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), - static_cast(weight.request().ptr), - batchSize, seqLen, features, eps); + { + py::gil_scoped_release release; + grilly::ops::rmsnorm( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + static_cast(wBuf.ptr), + batchSize, seqLen, features, eps); + } return Tensor::from_numpy(result); }, @@ -147,6 +173,7 @@ void register_normalization_ops(py::module_& m) { py::array_t running_mean, py::array_t running_var, float eps, float momentum, bool training) -> 
py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error("input must be 4D (B, C, H, W)"); @@ -168,16 +195,25 @@ void register_normalization_ops(py::module_& m) { grilly::ops::BatchNorm2dForwardParams p{ B, C, H, W, eps, momentum, training ? 1u : 0u, 1u}; - grilly::ops::batchnorm2dForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(output.request().ptr), - static_cast(gamma.request().ptr), - static_cast(beta.request().ptr), - static_cast(rmOut.mutable_data()), - static_cast(rvOut.mutable_data()), - static_cast(bMean.mutable_data()), - static_cast(bVar.mutable_data()), p); + auto outBuf = output.request(); + auto gaBuf = gamma.request(); + auto beBuf = beta.request(); + require_c_contiguous_float(gaBuf); + require_c_contiguous_float(beBuf); + + { + py::gil_scoped_release release; + grilly::ops::batchnorm2dForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(outBuf.ptr), + static_cast(gaBuf.ptr), + static_cast(beBuf.ptr), + static_cast(rmOut.mutable_data()), + static_cast(rvOut.mutable_data()), + static_cast(bMean.mutable_data()), + static_cast(bVar.mutable_data()), p); + } py::dict result; result["output"] = Tensor::from_numpy(output); @@ -200,6 +236,7 @@ void register_normalization_ops(py::module_& m) { [](GrillyCoreContext& ctx, py::array_t input, int dim) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim < 1) throw std::runtime_error("input must be at least 1D"); @@ -211,16 +248,52 @@ void register_normalization_ops(py::module_& m) { if (inBuf.ndim == 1) totalBatch = 1; py::array_t result(inBuf.shape); - grilly::ops::softmax( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), - 1, totalBatch, features); + auto rBuf = result.request(); + { + py::gil_scoped_release release; + grilly::ops::softmax( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + 1, totalBatch, features); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), py::arg("dim") = -1, "GPU Softmax (3-pass: max, sum_exp, normalize)"); + m.def( + "mf_softmax", + [](GrillyCoreContext& ctx, py::array_t input, + int dim) -> Tensor { + auto inBuf = input.request(); + require_c_contiguous_float(inBuf); + if (inBuf.ndim < 1) + throw std::runtime_error("input must be at least 1D"); + + uint32_t features = static_cast( + inBuf.shape[inBuf.ndim - 1]); + uint32_t totalBatch = 1; + for (int i = 0; i < inBuf.ndim - 1; ++i) + totalBatch *= static_cast(inBuf.shape[i]); + if (inBuf.ndim == 1) totalBatch = 1; + + py::array_t result(inBuf.shape); + auto rBuf = result.request(); + { + py::gil_scoped_release release; + grilly::ops::mfSoftmax( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + 1, totalBatch, features); + } + return Tensor::from_numpy(result); + }, + py::arg("device"), py::arg("input"), py::arg("dim") = -1, + "GPU multiplication-free softmax (ReLU-normalized; 3-pass, shader " + "mf-softmax)"); + // ── Softmax backward ───────────────────────────────────────────────── m.def( "softmax_backward", @@ -228,6 +301,7 @@ void register_normalization_ops(py::module_& m) { py::array_t grad_output, py::array_t softmax_output) -> Tensor { auto gBuf = grad_output.request(); + require_c_contiguous_float(gBuf); uint32_t numClasses = static_cast( gBuf.shape[gBuf.ndim - 1]); uint32_t batchSeq = 1; @@ -236,13 +310,19 @@ void 
register_normalization_ops(py::module_& m) { if (gBuf.ndim == 1) batchSeq = 1; py::array_t result(gBuf.shape); - grilly::ops::softmaxBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast( - softmax_output.request().ptr), - static_cast(result.request().ptr), - 1, batchSeq, numClasses); + auto sBuf = softmax_output.request(); + auto rBuf = result.request(); + require_c_contiguous_float(sBuf); + + { + py::gil_scoped_release release; + grilly::ops::softmaxBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(sBuf.ptr), + static_cast(rBuf.ptr), + 1, batchSeq, numClasses); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("grad_output"), diff --git a/cpp/python/bindings_optim.cpp b/cpp/python/bindings_optim.cpp index bb8bf0b..afc61d2 100644 --- a/cpp/python/bindings_optim.cpp +++ b/cpp/python/bindings_optim.cpp @@ -20,9 +20,18 @@ void register_optim_ops(py::module_& m) { float lr, float beta1, float beta2, float eps, float beta1_t, float beta2_t, bool clear_grad) -> py::dict { auto wBuf = weights.request(); - uint32_t total = 1; - for (int i = 0; i < wBuf.ndim; ++i) - total *= static_cast(wBuf.shape[i]); + auto gBuf = grad.request(); + auto mBuf = m_state.request(); + auto vBuf = v_state.request(); + require_c_contiguous_float(wBuf); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(mBuf); + require_c_contiguous_float(vBuf); + if (wBuf.size != gBuf.size || wBuf.size != mBuf.size || + wBuf.size != vBuf.size) + throw std::runtime_error( + "adam_update: weights, grad, m, v must match shape"); + uint32_t total = static_cast(wBuf.size); py::array_t wOut(wBuf.shape); py::array_t gOut(wBuf.shape); @@ -31,22 +40,25 @@ void register_optim_ops(py::module_& m) { std::memcpy(wOut.mutable_data(), wBuf.ptr, total * sizeof(float)); - std::memcpy(gOut.mutable_data(), grad.data(), + std::memcpy(gOut.mutable_data(), gBuf.ptr, total * sizeof(float)); - std::memcpy(mOut.mutable_data(), m_state.data(), + std::memcpy(mOut.mutable_data(), mBuf.ptr, total * sizeof(float)); - std::memcpy(vOut.mutable_data(), v_state.data(), + std::memcpy(vOut.mutable_data(), vBuf.ptr, total * sizeof(float)); grilly::ops::AdamParams p{total, lr, beta1, beta2, eps, beta1_t, beta2_t, clear_grad ? 
1u : 0u}; - grilly::ops::adamUpdate( - ctx.batch, ctx.pool, ctx.cache, - static_cast(wOut.mutable_data()), - static_cast(gOut.mutable_data()), - static_cast(mOut.mutable_data()), - static_cast(vOut.mutable_data()), p); + { + py::gil_scoped_release release; + grilly::ops::adamUpdate( + ctx.batch, ctx.pool, ctx.cache, + static_cast(wOut.mutable_data()), + static_cast(gOut.mutable_data()), + static_cast(mOut.mutable_data()), + static_cast(vOut.mutable_data()), p); + } py::dict result; result["weights"] = Tensor::from_numpy(wOut); @@ -73,9 +85,18 @@ void register_optim_ops(py::module_& m) { float weight_decay, float beta1_t, float beta2_t, bool clear_grad) -> py::dict { auto wBuf = weights.request(); - uint32_t total = 1; - for (int i = 0; i < wBuf.ndim; ++i) - total *= static_cast(wBuf.shape[i]); + auto gBuf = grad.request(); + auto mBuf = m_state.request(); + auto vBuf = v_state.request(); + require_c_contiguous_float(wBuf); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(mBuf); + require_c_contiguous_float(vBuf); + if (wBuf.size != gBuf.size || wBuf.size != mBuf.size || + wBuf.size != vBuf.size) + throw std::runtime_error( + "adamw_update: weights, grad, m, v must match shape"); + uint32_t total = static_cast(wBuf.size); py::array_t wOut(wBuf.shape); py::array_t gOut(wBuf.shape); @@ -84,22 +105,25 @@ void register_optim_ops(py::module_& m) { std::memcpy(wOut.mutable_data(), wBuf.ptr, total * sizeof(float)); - std::memcpy(gOut.mutable_data(), grad.data(), + std::memcpy(gOut.mutable_data(), gBuf.ptr, total * sizeof(float)); - std::memcpy(mOut.mutable_data(), m_state.data(), + std::memcpy(mOut.mutable_data(), mBuf.ptr, total * sizeof(float)); - std::memcpy(vOut.mutable_data(), v_state.data(), + std::memcpy(vOut.mutable_data(), vBuf.ptr, total * sizeof(float)); grilly::ops::AdamWParams p{total, lr, beta1, beta2, eps, weight_decay, beta1_t, beta2_t, clear_grad ? 
1u : 0u}; - grilly::ops::adamwUpdate( - ctx.batch, ctx.pool, ctx.cache, - static_cast(wOut.mutable_data()), - static_cast(gOut.mutable_data()), - static_cast(mOut.mutable_data()), - static_cast(vOut.mutable_data()), p); + { + py::gil_scoped_release release; + grilly::ops::adamwUpdate( + ctx.batch, ctx.pool, ctx.cache, + static_cast(wOut.mutable_data()), + static_cast(gOut.mutable_data()), + static_cast(mOut.mutable_data()), + static_cast(vOut.mutable_data()), p); + } py::dict result; result["weights"] = Tensor::from_numpy(wOut); diff --git a/cpp/python/bindings_perceiver.cpp b/cpp/python/bindings_perceiver.cpp index e36b055..a69db31 100644 --- a/cpp/python/bindings_perceiver.cpp +++ b/cpp/python/bindings_perceiver.cpp @@ -28,6 +28,9 @@ void register_perceiver_ops(py::module_& m) { auto qBuf = Q_arr.request(); auto kBuf = K_arr.request(); auto vBuf = V_arr.request(); + require_c_contiguous_float(qBuf); + require_c_contiguous_float(kBuf); + require_c_contiguous_float(vBuf); uint32_t seqN = static_cast(qBuf.shape[0]); uint32_t headDim = static_cast(qBuf.shape[qBuf.ndim - 1]); @@ -68,6 +71,7 @@ void register_perceiver_ops(py::module_& m) { for (auto& w : weights) { auto arr = w.cast>(); auto buf = arr.request(); + require_c_contiguous_float(buf); const float* ptr = static_cast(buf.ptr); cpp_weights.emplace_back(ptr, ptr + buf.size); } @@ -92,6 +96,7 @@ void register_perceiver_ops(py::module_& m) { auto& pc = ops::perceiver_get_cache(handle); auto pBuf = patches.request(); + require_c_contiguous_float(pBuf); uint32_t nPatches = static_cast(pBuf.shape[0]); const float* patchPtr = static_cast(pBuf.ptr); diff --git a/cpp/python/bindings_pooling.cpp b/cpp/python/bindings_pooling.cpp index 6676f82..a18880d 100644 --- a/cpp/python/bindings_pooling.cpp +++ b/cpp/python/bindings_pooling.cpp @@ -19,6 +19,7 @@ void register_pooling_ops(py::module_& m) { std::vector padding, std::vector dilation) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error( "input must be 4D (B, C, H, W)"); @@ -53,14 +54,22 @@ void register_pooling_ops(py::module_& m) { static_cast(oH), static_cast(oW)}); + auto resBuf = result.request(); + auto idxBuf = indices.request(); + require_c_contiguous_float(resBuf); + require_c_contiguous_uint32(idxBuf); + grilly::ops::MaxPool2dParams p{B, C, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW}; - grilly::ops::maxpool2dForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), - static_cast(indices.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::maxpool2dForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), + static_cast(idxBuf.ptr), p); + } py::dict out; out["output"] = Tensor::from_numpy(result); @@ -85,6 +94,7 @@ void register_pooling_ops(py::module_& m) { std::vector padding, bool count_include_pad) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error("input must be 4D"); @@ -111,13 +121,19 @@ void register_pooling_ops(py::module_& m) { static_cast(oH), static_cast(oW)}); + auto resBuf = result.request(); + require_c_contiguous_float(resBuf); + grilly::ops::AvgPool2dParams p{ B, C, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, count_include_pad ? 
1u : 0u}; - grilly::ops::avgpool2dForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::avgpool2dForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), @@ -133,6 +149,7 @@ void register_pooling_ops(py::module_& m) { [](GrillyCoreContext& ctx, py::array_t input) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 3) throw std::runtime_error( "input must be 3D (B, S, D)"); @@ -145,11 +162,17 @@ void register_pooling_ops(py::module_& m) { static_cast(B), static_cast(D)}); + auto resBuf = result.request(); + require_c_contiguous_float(resBuf); + grilly::ops::MeanPoolParams p{B, S, D}; - grilly::ops::meanPool( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::meanPool( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), diff --git a/cpp/python/bindings_prefix_scan.cpp b/cpp/python/bindings_prefix_scan.cpp new file mode 100644 index 0000000..ad55f4c --- /dev/null +++ b/cpp/python/bindings_prefix_scan.cpp @@ -0,0 +1,118 @@ +/// bindings_prefix_scan.cpp — Causal Linear-RNN prefix scan op bindings. +/// +/// Exposes ``prefix_scan_causal`` and ``prefix_scan_causal_backward`` to +/// Python. Used by ``CausalSequenceMixer`` in the VSA-LM model to replace +/// the causally-leaky ``h.mean(dim=1)`` pooling with a strictly causal +/// subgroup-scanned Linear-RNN. + +#include "bindings_core.h" +#include "grilly/ops/prefix_scan.h" + +void register_prefix_scan_ops(py::module_& m) { + using namespace grilly::ops; + + m.def( + "prefix_scan_causal", + [](GrillyCoreContext& ctx, + py::array_t x, py::array_t a) -> py::array_t { + auto xBuf = x.request(); + auto aBuf = a.request(); + + if (xBuf.ndim != 3 || aBuf.ndim != 3) + throw std::runtime_error( + "prefix_scan_causal: x and a must be 3D (batch, seq, dim)"); + if (xBuf.shape != aBuf.shape) + throw std::runtime_error( + "prefix_scan_causal: x and a must have identical shape"); + + const uint32_t batchSize = static_cast(xBuf.shape[0]); + const uint32_t seqLen = static_cast(xBuf.shape[1]); + const uint32_t hiddenDim = static_cast(xBuf.shape[2]); + + if (seqLen > 32) + throw std::runtime_error( + "prefix_scan_causal: seq_len must be <= 32 (subgroup size). 
" + "Longer sequences need a hierarchical scan (TODO)."); + + PrefixScanParams p{seqLen, hiddenDim, batchSize}; + + py::array_t result(xBuf.shape); + auto rBuf = result.request(); + + { + py::gil_scoped_release release; + prefixScanCausal( + ctx.batch, ctx.pool, ctx.cache, + static_cast(xBuf.ptr), + static_cast(aBuf.ptr), + static_cast(rBuf.ptr), p); + } + + return result; + }, + py::arg("device"), py::arg("x"), py::arg("a"), + "Causal Linear-RNN forward: h_t = a_t * h_{t-1} + x_t via subgroup scan"); + + m.def( + "prefix_scan_causal_backward", + [](GrillyCoreContext& ctx, + py::array_t grad_h, py::array_t a, + py::array_t h, py::array_t x) -> py::dict { + auto dhBuf = grad_h.request(); + auto aBuf = a.request(); + auto hBuf = h.request(); + auto xBuf = x.request(); + + if (dhBuf.ndim != 3) + throw std::runtime_error( + "prefix_scan_causal_backward: grad_h must be 3D"); + + const uint32_t batchSize = static_cast(dhBuf.shape[0]); + const uint32_t seqLen = static_cast(dhBuf.shape[1]); + const uint32_t hiddenDim = static_cast(dhBuf.shape[2]); + + if (seqLen > 32) + throw std::runtime_error( + "prefix_scan_causal_backward: seq_len must be <= 32"); + + PrefixScanParams p{seqLen, hiddenDim, batchSize}; + + py::array_t gradX(dhBuf.shape); + py::array_t gradA(dhBuf.shape); + + // CRITICAL: get the raw pointers BEFORE releasing the GIL. + // ``py::array::request()`` touches Python reference counts and + // needs the GIL held — calling it inside a + // ``gil_scoped_release`` block deadlocks on internal locks. + // The forward binding gets this right; the first version of + // this backward binding didn't and hung inside the dispatcher + // with zero trace output from the C++ side (since the C++ call + // was never reached — the lambda deadlocked on request()). 
+ const void* dhPtr = dhBuf.ptr; + const void* aPtr = aBuf.ptr; + const void* hPtr = hBuf.ptr; + const void* xPtr = xBuf.ptr; + void* gradXPtr = gradX.request().ptr; + void* gradAPtr = gradA.request().ptr; + + { + py::gil_scoped_release release; + prefixScanCausalBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(dhPtr), + static_cast(aPtr), + static_cast(hPtr), + static_cast(xPtr), + static_cast(gradXPtr), + static_cast(gradAPtr), p); + } + + py::dict result; + result["grad_x"] = gradX; + result["grad_a"] = gradA; + return result; + }, + py::arg("device"), py::arg("grad_h"), py::arg("a"), + py::arg("h"), py::arg("x"), + "Causal prefix scan backward: returns grad_x and grad_a"); +} diff --git a/cpp/python/bindings_siglip.cpp b/cpp/python/bindings_siglip.cpp index e699ed7..c203bc1 100644 --- a/cpp/python/bindings_siglip.cpp +++ b/cpp/python/bindings_siglip.cpp @@ -45,6 +45,7 @@ static int g_nextHandle = 1; static GrillyBuffer uploadPersistent(BufferPool& pool, py::array_t arr) { auto buf = arr.request(); + require_c_contiguous_float(buf); size_t bytes = buf.size * sizeof(float); GrillyBuffer gpuBuf = pool.acquire(bytes); pool.upload(gpuBuf, static_cast(buf.ptr), bytes); @@ -141,6 +142,7 @@ void register_siglip_ops(py::module_& m) { auto& wc = it->second; auto pBuf = patches.request(); + require_c_contiguous_float(pBuf); const uint32_t S = wc.seqLen; const uint32_t H = wc.hidden; const uint32_t H3 = H * 3; diff --git a/cpp/python/bindings_snn.cpp b/cpp/python/bindings_snn.cpp index 96a8372..c7d5338 100644 --- a/cpp/python/bindings_snn.cpp +++ b/cpp/python/bindings_snn.cpp @@ -19,6 +19,7 @@ void register_snn_ops(py::module_& m) { float v_thresh, float r_mem, float t_refrac_period) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t n = 1; for (int i = 0; i < inBuf.ndim; ++i) n *= static_cast(inBuf.shape[i]); @@ -26,21 +27,29 @@ void register_snn_ops(py::module_& m) { py::array_t vMemOut(v_mem.request().shape); py::array_t refracOut(t_refrac.request().shape); py::array_t spikes(inBuf.shape); + auto vmIn = v_mem.request(); + auto trIn = t_refrac.request(); + require_c_contiguous_float(vmIn); + require_c_contiguous_float(trIn); + auto vmOut = vMemOut.request(); + auto rfOut = refracOut.request(); + auto spOut = spikes.request(); - std::memcpy(vMemOut.request().ptr, v_mem.request().ptr, - n * sizeof(float)); - std::memcpy(refracOut.request().ptr, - t_refrac.request().ptr, n * sizeof(float)); + std::memcpy(vmOut.ptr, vmIn.ptr, n * sizeof(float)); + std::memcpy(rfOut.ptr, trIn.ptr, n * sizeof(float)); grilly::ops::LIFParams p{n, dt, tau_mem, v_rest, v_reset, v_thresh, r_mem, t_refrac_period}; - grilly::ops::lifStep( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(vMemOut.request().ptr), - static_cast(refracOut.request().ptr), - static_cast(spikes.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::lifStep( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(vmOut.ptr), + static_cast(rfOut.ptr), + static_cast(spOut.ptr), p); + } py::dict result; result["spikes"] = Tensor::from_numpy(spikes); @@ -66,6 +75,7 @@ void register_snn_ops(py::module_& m) { float v_reset, uint32_t reset_mode, uint32_t decay_input) -> py::dict { auto xBuf = x_in.request(); + require_c_contiguous_float(xBuf); uint32_t n = 1; for (int i = 0; i < xBuf.ndim; ++i) n *= static_cast(xBuf.shape[i]); @@ -73,22 +83,30 @@ void register_snn_ops(py::module_& m) { py::array_t vMemOut(v_mem.request().shape); py::array_t 
spikes(xBuf.shape); py::array_t hOut(xBuf.shape); + auto vmIn = v_mem.request(); + require_c_contiguous_float(vmIn); + auto tpIn = tau_param.request(); + require_c_contiguous_float(tpIn); + auto vmOut = vMemOut.request(); + auto spOut = spikes.request(); + auto hR = hOut.request(); - std::memcpy(vMemOut.request().ptr, v_mem.request().ptr, - n * sizeof(float)); + std::memcpy(vmOut.ptr, vmIn.ptr, n * sizeof(float)); grilly::ops::SNNNodeForwardParams p{ n, neuron_type, tau, v_threshold, v_reset, reset_mode, decay_input}; - grilly::ops::snnNodeForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(xBuf.ptr), - static_cast(vMemOut.request().ptr), - static_cast(spikes.request().ptr), - static_cast(hOut.request().ptr), - static_cast( - tau_param.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::snnNodeForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(xBuf.ptr), + static_cast(vmOut.ptr), + static_cast(spOut.ptr), + static_cast(hR.ptr), + static_cast(tpIn.ptr), p); + } py::dict result; result["spikes"] = Tensor::from_numpy(spikes); @@ -112,21 +130,27 @@ void register_snn_ops(py::module_& m) { float alpha, uint32_t surrogate_type, float v_threshold) -> Tensor { auto gBuf = grad_spike.request(); + require_c_contiguous_float(gBuf); uint32_t n = 1; for (int i = 0; i < gBuf.ndim; ++i) n *= static_cast(gBuf.shape[i]); py::array_t gradX(gBuf.shape); + auto hBuf = h_cache.request(); + require_c_contiguous_float(hBuf); + auto gxBuf = gradX.request(); grilly::ops::SNNNodeBackwardParams p{ n, alpha, surrogate_type, v_threshold}; - grilly::ops::snnNodeBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast( - h_cache.request().ptr), - static_cast(gradX.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::snnNodeBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(hBuf.ptr), + static_cast(gxBuf.ptr), p); + } return Tensor::from_numpy(gradX); }, @@ -143,10 +167,13 @@ void register_snn_ops(py::module_& m) { py::array_t weights, uint32_t batch_size, uint32_t time_steps, float learning_rate, - float weight_decay) -> Tensor { + float weight_decay) -> Tensor { auto preBuf = pre.request(); auto postBuf = post.request(); auto wBuf = weights.request(); + require_c_contiguous_float(preBuf); + require_c_contiguous_float(postBuf); + require_c_contiguous_float(wBuf); uint32_t pre_dim = static_cast( preBuf.shape[preBuf.ndim - 1]); @@ -154,18 +181,22 @@ void register_snn_ops(py::module_& m) { postBuf.shape[postBuf.ndim - 1]); py::array_t wOut(wBuf.shape); - std::memcpy(wOut.request().ptr, wBuf.ptr, + auto woBuf = wOut.request(); + std::memcpy(woBuf.ptr, wBuf.ptr, size_t(pre_dim) * post_dim * sizeof(float)); grilly::ops::HebbianParams p{batch_size, time_steps, pre_dim, post_dim, learning_rate, weight_decay}; - grilly::ops::hebbianLearning( - ctx.batch, ctx.pool, ctx.cache, - static_cast(preBuf.ptr), - static_cast(postBuf.ptr), - static_cast(wOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::hebbianLearning( + ctx.batch, ctx.pool, ctx.cache, + static_cast(preBuf.ptr), + static_cast(postBuf.ptr), + static_cast(woBuf.ptr), p); + } return Tensor::from_numpy(wOut); }, @@ -189,6 +220,9 @@ void register_snn_ops(py::module_& m) { auto preBuf = pre.request(); auto postBuf = post.request(); auto wBuf = weights.request(); + require_c_contiguous_float(preBuf); + require_c_contiguous_float(postBuf); + require_c_contiguous_float(wBuf); uint32_t pre_dim = static_cast( preBuf.shape[preBuf.ndim - 1]); @@ 
-200,27 +234,35 @@ void register_snn_ops(py::module_& m) { pre_trace.request().shape); py::array_t postTraceOut( post_trace.request().shape); - - std::memcpy(wOut.request().ptr, wBuf.ptr, + auto ptIn = pre_trace.request(); + auto pstIn = post_trace.request(); + require_c_contiguous_float(ptIn); + require_c_contiguous_float(pstIn); + auto woBuf = wOut.request(); + auto ptoBuf = preTraceOut.request(); + auto pstoBuf = postTraceOut.request(); + + std::memcpy(woBuf.ptr, wBuf.ptr, size_t(pre_dim) * post_dim * sizeof(float)); - std::memcpy(preTraceOut.request().ptr, - pre_trace.request().ptr, + std::memcpy(ptoBuf.ptr, ptIn.ptr, size_t(batch_size) * pre_dim * sizeof(float)); - std::memcpy(postTraceOut.request().ptr, - post_trace.request().ptr, + std::memcpy(pstoBuf.ptr, pstIn.ptr, size_t(batch_size) * post_dim * sizeof(float)); grilly::ops::STDPParams p{batch_size, time_steps, pre_dim, post_dim, lr_pot, lr_dep, trace_decay, 0}; - grilly::ops::stdpLearning( - ctx.batch, ctx.pool, ctx.cache, - static_cast(preBuf.ptr), - static_cast(postBuf.ptr), - static_cast(wOut.request().ptr), - static_cast(preTraceOut.request().ptr), - static_cast(postTraceOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::stdpLearning( + ctx.batch, ctx.pool, ctx.cache, + static_cast(preBuf.ptr), + static_cast(postBuf.ptr), + static_cast(woBuf.ptr), + static_cast(ptoBuf.ptr), + static_cast(pstoBuf.ptr), p); + } py::dict result; result["weights"] = Tensor::from_numpy(wOut); @@ -244,21 +286,26 @@ void register_snn_ops(py::module_& m) { py::array_t y_state, float decay) -> Tensor { auto xBuf = x_in.request(); + require_c_contiguous_float(xBuf); uint32_t n = 1; for (int i = 0; i < xBuf.ndim; ++i) n *= static_cast(xBuf.shape[i]); py::array_t yOut(y_state.request().shape); - std::memcpy(yOut.request().ptr, - y_state.request().ptr, - n * sizeof(float)); + auto ysIn = y_state.request(); + require_c_contiguous_float(ysIn); + auto yoBuf = yOut.request(); + std::memcpy(yoBuf.ptr, ysIn.ptr, n * sizeof(float)); grilly::ops::SynapseFilterParams p{n, decay}; - grilly::ops::synapseFilter( - ctx.batch, ctx.pool, ctx.cache, - static_cast(xBuf.ptr), - static_cast(yOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::synapseFilter( + ctx.batch, ctx.pool, ctx.cache, + static_cast(xBuf.ptr), + static_cast(yoBuf.ptr), p); + } return Tensor::from_numpy(yOut); }, @@ -279,8 +326,9 @@ void register_snn_ops(py::module_& m) { float tau_mem, float v_rest, float v_reset, float v_thresh, float r_mem, float tau_adapt, float delta_adapt, float b_adapt, float tau_gate, float gate_strength, - float t_refrac_period) -> py::dict { + float t_refrac_period) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t n = 1; for (int i = 0; i < inBuf.ndim; ++i) n *= static_cast(inBuf.shape[i]); @@ -294,35 +342,52 @@ void register_snn_ops(py::module_& m) { py::array_t tLastOut( t_last_spike.request().shape); - std::memcpy(vMemOut.request().ptr, - v_mem.request().ptr, n * sizeof(float)); - std::memcpy(iAdaptOut.request().ptr, - i_adapt.request().ptr, n * sizeof(float)); - std::memcpy(gInputOut.request().ptr, - g_input.request().ptr, n * sizeof(float)); - std::memcpy(gForgetOut.request().ptr, - g_forget.request().ptr, n * sizeof(float)); - std::memcpy(refracOut.request().ptr, - t_refrac.request().ptr, n * sizeof(float)); - std::memcpy(tLastOut.request().ptr, - t_last_spike.request().ptr, - n * sizeof(float)); + auto vmIn = v_mem.request(); + auto iaIn = i_adapt.request(); + auto giIn = 
g_input.request(); + auto gfIn = g_forget.request(); + auto trIn = t_refrac.request(); + auto tlsIn = t_last_spike.request(); + require_c_contiguous_float(vmIn); + require_c_contiguous_float(iaIn); + require_c_contiguous_float(giIn); + require_c_contiguous_float(gfIn); + require_c_contiguous_float(trIn); + require_c_contiguous_float(tlsIn); + + auto vmOut = vMemOut.request(); + auto iaOut = iAdaptOut.request(); + auto giOut = gInputOut.request(); + auto gfOut = gForgetOut.request(); + auto trOut = refracOut.request(); + auto spOut = spikes.request(); + auto tlsOut = tLastOut.request(); + + std::memcpy(vmOut.ptr, vmIn.ptr, n * sizeof(float)); + std::memcpy(iaOut.ptr, iaIn.ptr, n * sizeof(float)); + std::memcpy(giOut.ptr, giIn.ptr, n * sizeof(float)); + std::memcpy(gfOut.ptr, gfIn.ptr, n * sizeof(float)); + std::memcpy(trOut.ptr, trIn.ptr, n * sizeof(float)); + std::memcpy(tlsOut.ptr, tlsIn.ptr, n * sizeof(float)); grilly::ops::GIFParams p{ n, dt, current_time, tau_mem, v_rest, v_reset, v_thresh, r_mem, tau_adapt, delta_adapt, b_adapt, tau_gate, gate_strength, t_refrac_period}; - grilly::ops::gifNeuronStep( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(vMemOut.request().ptr), - static_cast(iAdaptOut.request().ptr), - static_cast(gInputOut.request().ptr), - static_cast(gForgetOut.request().ptr), - static_cast(refracOut.request().ptr), - static_cast(spikes.request().ptr), - static_cast(tLastOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::gifNeuronStep( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(vmOut.ptr), + static_cast(iaOut.ptr), + static_cast(giOut.ptr), + static_cast(gfOut.ptr), + static_cast(trOut.ptr), + static_cast(spOut.ptr), + static_cast(tlsOut.ptr), p); + } py::dict result; result["spikes"] = Tensor::from_numpy(spikes); diff --git a/cpp/python/bindings_vsa_lm.cpp b/cpp/python/bindings_vsa_lm.cpp new file mode 100644 index 0000000..134ad60 --- /dev/null +++ b/cpp/python/bindings_vsa_lm.cpp @@ -0,0 +1,305 @@ +/// bindings_vsa_lm.cpp — fused VSA-LM upload / forward / backward / release. 
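The entry points this file registers (`vsa_lm_upload` / `vsa_lm_forward` / `vsa_lm_backward` / `vsa_lm_update_weights` / `vsa_lm_release`) follow the same upload–forward–backward–update pattern as the MoE bindings above. A compact sketch under the same assumptions (the `dev` device object and the layer sizes are illustrative, not prescribed by this diff; shapes follow the checks in this file):

```python
# Hypothetical VSA-LM round trip: upload weights, GPU forward, CPU backward,
# re-upload updated weights, release the handle. `dev` is assumed to be the
# grilly_core device/context object.
import numpy as np
import grilly_core

def f32(a): return np.ascontiguousarray(a, dtype=np.float32)

vocab, d_model, d_ffn, max_seq, n_layers = 1000, 64, 256, 128, 2

embed_w = f32(np.random.randn(vocab, d_model) * 0.02)
pos_w   = f32(np.random.randn(max_seq, d_model) * 0.02)
out_w   = f32(np.random.randn(vocab, d_model) * 0.02)
up_w  = [f32(np.random.randn(d_ffn, d_model) * 0.02) for _ in range(n_layers)]
up_b  = [f32(np.zeros(d_ffn)) for _ in range(n_layers)]
dn_w  = [f32(np.random.randn(d_model, d_ffn) * 0.02) for _ in range(n_layers)]
dn_b  = [f32(np.zeros(d_model)) for _ in range(n_layers)]
gamma = [f32(np.ones(d_model)) for _ in range(n_layers)]
beta  = [f32(np.zeros(d_model)) for _ in range(n_layers)]

dev = ...  # grilly_core device/context (assumption)
h = grilly_core.vsa_lm_upload(dev, embed_w, pos_w, up_w, up_b, dn_w, dn_b,
                              gamma, beta, out_w, n_layers, d_model, d_ffn)

ids = np.ascontiguousarray([5, 6, 7], dtype=np.int32)        # 1-D int32 tokens
logits = grilly_core.vsa_lm_forward(dev, h, ids)              # (seq_len, vocab)
grads = grilly_core.vsa_lm_backward(dev, h, ids,
                                    f32(np.ones_like(logits)))
# ...apply an optimizer step to the host-side arrays, then:
grilly_core.vsa_lm_update_weights(dev, h, embed_w, pos_w, up_w, up_b,
                                  dn_w, dn_b, gamma, beta, out_w)
grilly_core.vsa_lm_release(dev, h)
```

The forward path runs on the GPU (`vsa_lm_forward_gpu`) while the backward here is the CPU reference (`vsa_lm_backward_cpu`), so the gradients come back as plain NumPy arrays keyed by parameter name (`grad_embed`, `grad_pos`, `grad_out_w`, plus per-layer FFN and LayerNorm lists).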
+#include "bindings_core.h" +#include "grilly/ops/vsa_lm_forward.h" + +#include +#include +#include + +using namespace grilly; + +void register_vsa_lm_ops(py::module_& m) { + + m.def( + "vsa_lm_upload", + [](GrillyCoreContext& ctx, + py::array_t embed_w, py::array_t pos_w, + py::list ffn_up_patterns, py::list ffn_up_biases, + py::list ffn_down_patterns, py::list ffn_down_biases, + py::list ln_gammas, py::list ln_betas, + py::array_t out_w, + uint32_t n_layers, uint32_t d_model, uint32_t d_ffn) -> int { + + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + if (eb.ndim != 2 || pb.ndim != 2 || ob.ndim != 2) + throw std::runtime_error("vsa_lm_upload: embed/pos/out must be 2-D"); + uint32_t vocab = static_cast(eb.shape[0]); + uint32_t d = static_cast(eb.shape[1]); + if (d != d_model) + throw std::runtime_error("vsa_lm_upload: embed_w dim-1 must match d_model"); + uint32_t max_seq = static_cast(pb.shape[0]); + + auto extract_list = [](py::list& lst, uint32_t n, + std::vector& ptrs, + std::vector>& keep) { + for (uint32_t i = 0; i < n; ++i) { + auto arr = lst[i].cast>(); + auto buf = arr.request(); + require_c_contiguous_float(buf); + keep.push_back(arr); + ptrs.push_back(static_cast(buf.ptr)); + } + }; + + std::vector up_w_ptrs, up_b_ptrs, dn_w_ptrs, dn_b_ptrs; + std::vector gm_ptrs, bt_ptrs; + std::vector> keep1, keep2, keep3, keep4, keep5, keep6; + + extract_list(ffn_up_patterns, n_layers, up_w_ptrs, keep1); + extract_list(ffn_up_biases, n_layers, up_b_ptrs, keep2); + extract_list(ffn_down_patterns, n_layers, dn_w_ptrs, keep3); + extract_list(ffn_down_biases, n_layers, dn_b_ptrs, keep4); + extract_list(ln_gammas, n_layers, gm_ptrs, keep5); + extract_list(ln_betas, n_layers, bt_ptrs, keep6); + + int handle = ops::vsa_lm_upload( + ctx.pool, vocab, d, d_ffn, max_seq, + static_cast(eb.ptr), + static_cast(pb.ptr), + up_w_ptrs, up_b_ptrs, dn_w_ptrs, dn_b_ptrs, + gm_ptrs, bt_ptrs, + static_cast(ob.ptr), + n_layers); + + ctx.waitIdle(); + return handle; + }, + py::arg("device"), py::arg("embed_w"), py::arg("pos_w"), + py::arg("ffn_up_patterns"), py::arg("ffn_up_biases"), + py::arg("ffn_down_patterns"), py::arg("ffn_down_biases"), + py::arg("ln_gammas"), py::arg("ln_betas"), + py::arg("out_w"), + py::arg("n_layers"), py::arg("d_model"), py::arg("d_ffn"), + "Upload VSA-LM weights to GPU; returns opaque integer handle."); + + m.def( + "vsa_lm_release", + [](GrillyCoreContext& ctx, int handle) { + ops::vsa_lm_release(ctx.pool, handle); + }, + py::arg("device"), py::arg("handle"), + "Free GPU buffers for a VSA-LM handle."); + + m.def( + "vsa_lm_forward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids) + -> py::array_t { + auto& h = ops::vsa_lm_get_cache(handle); + auto ib = input_ids.request(); + require_c_contiguous_int32(ib); + if (ib.ndim != 1) + throw std::runtime_error("vsa_lm_forward: input_ids must be 1-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + + py::array_t logits({static_cast(seq_len), + static_cast(h.vocab)}); + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + ops::vsa_lm_forward_gpu(ctx.batch, ctx.pool, ctx.cache, h, + static_cast(ib.ptr), + seq_len, logits.mutable_data()); + } + return logits; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + "Run full VSA-LM forward on GPU. 
Returns (seq_len, vocab) float32."); + + m.def( + "vsa_lm_backward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids, + py::array_t grad_logits) -> py::dict { + auto& h = ops::vsa_lm_get_cache(handle); + auto ib = input_ids.request(); + auto gb = grad_logits.request(); + require_c_contiguous_int32(ib); + require_c_contiguous_float(gb); + if (ib.ndim != 1) + throw std::runtime_error("vsa_lm_backward: input_ids must be 1-D"); + if (gb.ndim != 2) + throw std::runtime_error("vsa_lm_backward: grad_logits must be 2-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + uint32_t V = h.vocab; + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t L = h.nLayers; + + if (static_cast(gb.shape[0]) != seq_len || + static_cast(gb.shape[1]) != V) + throw std::runtime_error("vsa_lm_backward: grad_logits shape mismatch"); + + ops::VsaLmGradients grads; + { + py::gil_scoped_release release; + grads = ops::vsa_lm_backward_cpu( + h, static_cast(ib.ptr), seq_len, + static_cast(gb.ptr)); + } + + py::dict result; + + // grad_embed (vocab, d) + py::array_t ge({static_cast(V), + static_cast(d)}); + std::memcpy(ge.mutable_data(), grads.grad_embed.data(), + grads.grad_embed.size() * sizeof(float)); + result["grad_embed"] = ge; + + // grad_pos (max_seq, d) + py::array_t gp({static_cast(h.maxSeq), + static_cast(d)}); + std::memcpy(gp.mutable_data(), grads.grad_pos.data(), + grads.grad_pos.size() * sizeof(float)); + result["grad_pos"] = gp; + + // grad_out_w (vocab, d) + py::array_t gow({static_cast(V), + static_cast(d)}); + std::memcpy(gow.mutable_data(), grads.grad_out_w.data(), + grads.grad_out_w.size() * sizeof(float)); + result["grad_out_w"] = gow; + + // Per-layer gradients as lists + py::list g_up_w, g_up_b, g_dn_w, g_dn_b, g_ln_g, g_ln_b; + for (uint32_t l = 0; l < L; ++l) { + py::array_t uw({static_cast(dF), + static_cast(d)}); + std::memcpy(uw.mutable_data(), grads.grad_ffn_up_w[l].data(), + dF * d * sizeof(float)); + g_up_w.append(uw); + + py::array_t ub({static_cast(dF)}); + std::memcpy(ub.mutable_data(), grads.grad_ffn_up_b[l].data(), + dF * sizeof(float)); + g_up_b.append(ub); + + py::array_t dw({static_cast(d), + static_cast(dF)}); + std::memcpy(dw.mutable_data(), grads.grad_ffn_down_w[l].data(), + d * dF * sizeof(float)); + g_dn_w.append(dw); + + py::array_t db({static_cast(d)}); + std::memcpy(db.mutable_data(), grads.grad_ffn_down_b[l].data(), + d * sizeof(float)); + g_dn_b.append(db); + + py::array_t lg({static_cast(d)}); + std::memcpy(lg.mutable_data(), grads.grad_ln_gamma[l].data(), + d * sizeof(float)); + g_ln_g.append(lg); + + py::array_t lb({static_cast(d)}); + std::memcpy(lb.mutable_data(), grads.grad_ln_beta[l].data(), + d * sizeof(float)); + g_ln_b.append(lb); + } + result["grad_ffn_up_w"] = g_up_w; + result["grad_ffn_up_b"] = g_up_b; + result["grad_ffn_down_w"] = g_dn_w; + result["grad_ffn_down_b"] = g_dn_b; + result["grad_ln_gamma"] = g_ln_g; + result["grad_ln_beta"] = g_ln_b; + + return result; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + py::arg("grad_logits"), + "CPU backward for VSA-LM. 
Returns dict with all gradient arrays."); + + m.def( + "vsa_lm_update_weights", + [](GrillyCoreContext& ctx, int handle, + py::array_t embed_w, py::array_t pos_w, + py::list ffn_up_patterns, py::list ffn_up_biases, + py::list ffn_down_patterns, py::list ffn_down_biases, + py::list ln_gammas, py::list ln_betas, + py::array_t out_w) { + + auto& h = ops::vsa_lm_get_cache(handle); + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + + // Re-upload embedding + size_t embed_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_embed.data(), static_cast(eb.ptr), embed_bytes); + ctx.pool.upload(h.embedW, static_cast(eb.ptr), embed_bytes); + + // Re-upload pos + size_t pos_bytes = size_t(h.maxSeq) * d * sizeof(float); + std::memcpy(h.cpu_pos.data(), static_cast(pb.ptr), pos_bytes); + ctx.pool.upload(h.posW, static_cast(pb.ptr), pos_bytes); + + // Re-upload out_w + transpose + size_t out_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_out_w.data(), static_cast(ob.ptr), out_bytes); + ctx.pool.upload(h.outW, static_cast(ob.ptr), out_bytes); + std::vector out_wt(d * V); + for (uint32_t r = 0; r < V; ++r) + for (uint32_t c = 0; c < d; ++c) + out_wt[c * V + r] = h.cpu_out_w[r * d + c]; + ctx.pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + // Per-layer + for (uint32_t l = 0; l < L; ++l) { + auto& lw = h.layers[l]; + auto upw = ffn_up_patterns[l].cast>(); + auto upb = ffn_up_biases[l].cast>(); + auto dnw = ffn_down_patterns[l].cast>(); + auto dnb = ffn_down_biases[l].cast>(); + auto gm = ln_gammas[l].cast>(); + auto bt = ln_betas[l].cast>(); + + auto upwb = upw.request(); + auto upbb = upb.request(); + auto dnwb = dnw.request(); + auto dnbb = dnb.request(); + auto gmb = gm.request(); + auto btb = bt.request(); + + size_t uw_bytes = size_t(dF) * d * sizeof(float); + size_t ub_bytes = size_t(dF) * sizeof(float); + size_t dw_bytes = size_t(d) * dF * sizeof(float); + size_t db_bytes = size_t(d) * sizeof(float); + size_t ln_bytes = size_t(d) * sizeof(float); + + ctx.pool.upload(lw.ffnUpW, static_cast(upwb.ptr), uw_bytes); + ctx.pool.upload(lw.ffnUpB, static_cast(upbb.ptr), ub_bytes); + ctx.pool.upload(lw.ffnDownW, static_cast(dnwb.ptr), dw_bytes); + ctx.pool.upload(lw.ffnDownB, static_cast(dnbb.ptr), db_bytes); + ctx.pool.upload(lw.lnGamma, static_cast(gmb.ptr), ln_bytes); + ctx.pool.upload(lw.lnBeta, static_cast(btb.ptr), ln_bytes); + + std::memcpy(h.cpu_ffn_up_w[l].data(), static_cast(upwb.ptr), uw_bytes); + std::memcpy(h.cpu_ffn_up_b[l].data(), static_cast(upbb.ptr), ub_bytes); + std::memcpy(h.cpu_ffn_down_w[l].data(), static_cast(dnwb.ptr), dw_bytes); + std::memcpy(h.cpu_ffn_down_b[l].data(), static_cast(dnbb.ptr), db_bytes); + std::memcpy(h.cpu_ln_gamma[l].data(), static_cast(gmb.ptr), ln_bytes); + std::memcpy(h.cpu_ln_beta[l].data(), static_cast(btb.ptr), ln_bytes); + } + + ctx.waitIdle(); + }, + py::arg("device"), py::arg("handle"), + py::arg("embed_w"), py::arg("pos_w"), + py::arg("ffn_up_patterns"), py::arg("ffn_up_biases"), + py::arg("ffn_down_patterns"), py::arg("ffn_down_biases"), + py::arg("ln_gammas"), py::arg("ln_betas"), + py::arg("out_w"), + "Re-upload VSA-LM weights after optimizer step."); +} diff --git a/cpp/src/buffer_pool.cpp b/cpp/src/buffer_pool.cpp index 37231b4..e56a69d 100644 --- a/cpp/src/buffer_pool.cpp +++ 
b/cpp/src/buffer_pool.cpp @@ -42,7 +42,21 @@ BufferPool::BufferPool(GrillyDevice& device) : device_(device) { } BufferPool::~BufferPool() { - // Destroy all pooled buffers first + VkDevice dev = device_.device(); + + // Clean up persistent transfer context + if (transferInitialized_) { + if (transferFence_ != VK_NULL_HANDLE) { + vkWaitForFences(dev, 1, &transferFence_, VK_TRUE, UINT64_MAX); + vkDestroyFence(dev, transferFence_, nullptr); + } + if (transferCmd_ != VK_NULL_HANDLE) + vkFreeCommandBuffers(dev, transferPool_, 1, &transferCmd_); + if (transferPool_ != VK_NULL_HANDLE) + vkDestroyCommandPool(dev, transferPool_, nullptr); + } + + // Destroy all pooled buffers (host-visible bucket pool) for (auto& [bucketSize, vec] : buckets_) { for (auto& buf : vec) { if (buf.handle != VK_NULL_HANDLE) @@ -51,6 +65,24 @@ BufferPool::~BufferPool() { } buckets_.clear(); + // Destroy all pooled buffers (device-local bucket pool) + for (auto& [bucketSize, vec] : dlBuckets_) { + for (auto& buf : vec) { + if (buf.handle != VK_NULL_HANDLE) + vmaDestroyBuffer(allocator_, buf.handle, buf.allocation); + } + } + dlBuckets_.clear(); + + // Destroy all pooled buffers (readback bucket pool) + for (auto& [bucketSize, vec] : readbackBuckets_) { + for (auto& buf : vec) { + if (buf.handle != VK_NULL_HANDLE) + vmaDestroyBuffer(allocator_, buf.handle, buf.allocation); + } + } + readbackBuckets_.clear(); + if (allocator_ != VK_NULL_HANDLE) vmaDestroyAllocator(allocator_); } @@ -86,14 +118,33 @@ GrillyBuffer BufferPool::allocateBuffer(size_t bucketSize) { bufferInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - // VMA_MEMORY_USAGE_AUTO lets VMA pick the best heap. - // MAPPED_BIT gives us a persistent CPU pointer — the key win over Python - // which does vkMapMemory/vkUnmapMemory on every upload/download cycle. - // HOST_ACCESS_SEQUENTIAL_WRITE_BIT hints that we write linearly (memcpy). + // REQUIRE device-local memory. + // + // Background: ``preferredFlags`` is a soft hint VMA can ignore. Combined + // with ``HOST_ACCESS_SEQUENTIAL_WRITE_BIT``, on Windows + AMD/RDNA the + // auto-selector lands on memoryType[1] (HOST_VISIBLE | HOST_COHERENT + // *only* — no DEVICE_LOCAL, no HOST_CACHED). The buffer ends up in + // host-uncached BAR memory, and every GPU read becomes a single-byte + // PCIe transaction → measured 0.1 GB/s effective bandwidth, ~100x slower + // than CPU/numpy. See sandbox/vsa_lm/grilly_gpu_path_test.py for the + // smoking-gun profile (gc.relu on 4.7M elements: 757 ms). + // + // Using ``requiredFlags`` forces VMA to *only* consider memory types with + // DEVICE_LOCAL_BIT. On systems with Resizable BAR (the common case for + // modern AMD + Windows + AGESA 2020+), VMA picks the + // DEVICE_LOCAL | HOST_VISIBLE | HOST_COHERENT memory type — full VRAM + // bandwidth (~432 GB/s on RX 6750 XT) + CPU mapping for fast uploads. + // + // On systems without ReBAR, this allocation will FAIL — which is the + // correct behavior, because the silent slow path was a footgun. Users + // without ReBAR should enable it in BIOS, or callers needing the legacy + // host-mapped path should use ``acquirePreferDeviceLocal`` which is + // explicitly soft-preferred for that case. 
VmaAllocationCreateInfo allocInfo{}; allocInfo.usage = VMA_MEMORY_USAGE_AUTO; allocInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT; + allocInfo.requiredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; GrillyBuffer buf{}; buf.bucketSize = bucketSize; @@ -141,7 +192,17 @@ void BufferPool::release(GrillyBuffer& buf) { std::lock_guard lock(mutex_); stats_.totalReleased++; - auto& vec = buckets_[buf.bucketSize]; + // Route to the right bucket pool based on memory class. The three pools + // MUST stay separate: + // - dlBuckets_ : DEVICE_LOCAL only (mappedPtr=null, GPU compute) + // - readbackBuckets_ : HOST_CACHED random read (CPU reads from GPU) + // - buckets_ : default WC sequential write (CPU writes to GPU) + // Picking up a DL buffer via ``acquire()`` would crash trying to memcpy + // into a null mappedPtr; picking up a WC buffer via ``acquireReadback`` + // would silently destroy CPU-read perf (the original bug we're fixing). + auto& vec = buf.deviceLocal ? dlBuckets_[buf.bucketSize] + : buf.readback ? readbackBuckets_[buf.bucketSize] + : buckets_[buf.bucketSize]; if (vec.size() < kMaxBuffersPerBucket) { vec.push_back(buf); } else { @@ -164,11 +225,32 @@ void BufferPool::release(GrillyBuffer& buf) { GrillyBuffer BufferPool::acquireDeviceLocal(size_t size) { size_t bucket = sizeToBucket(size); + std::lock_guard lock(mutex_); + stats_.totalAcquired++; + + // Try reuse from the DL bucket pool first (LIFO returns the same handle + // most often, which keeps the descriptor cache hitting on repeat calls). + auto it = dlBuckets_.find(bucket); + if (it != dlBuckets_.end() && !it->second.empty()) { + GrillyBuffer buf = it->second.back(); + it->second.pop_back(); + buf.size = size; + stats_.hits++; + return buf; + } + + stats_.misses++; + stats_.allocations++; + VkBufferCreateInfo bufferInfo{}; bufferInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; bufferInfo.size = bucket; + // We need both TRANSFER_DST (for stage-in copies) and TRANSFER_SRC (for + // stage-out copies) since the staging pattern in cpp/src/ops/linear.cpp + // copies output back from DL → host-visible staging. bufferInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | - VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_TRANSFER_SRC_BIT; bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; VmaAllocationCreateInfo allocInfo{}; @@ -178,6 +260,7 @@ GrillyBuffer BufferPool::acquireDeviceLocal(size_t size) { GrillyBuffer buf{}; buf.bucketSize = bucket; buf.size = size; + buf.deviceLocal = true; // routes to dlBuckets_ on release vkCheck(vmaCreateBuffer(allocator_, &bufferInfo, &allocInfo, &buf.handle, &buf.allocation, &buf.info), @@ -221,6 +304,22 @@ GrillyBuffer BufferPool::acquirePreferDeviceLocal(size_t size) { GrillyBuffer BufferPool::acquireReadback(size_t size) { size_t bucket = sizeToBucket(size); + std::lock_guard lock(mutex_); + stats_.totalAcquired++; + + // Try reuse from the readback bucket pool first. 
+ auto it = readbackBuckets_.find(bucket); + if (it != readbackBuckets_.end() && !it->second.empty()) { + GrillyBuffer buf = it->second.back(); + it->second.pop_back(); + buf.size = size; + stats_.hits++; + return buf; + } + + stats_.misses++; + stats_.allocations++; + VkBufferCreateInfo bufferInfo{}; bufferInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; bufferInfo.size = bucket; @@ -238,6 +337,7 @@ GrillyBuffer BufferPool::acquireReadback(size_t size) { GrillyBuffer buf{}; buf.bucketSize = bucket; buf.size = size; + buf.readback = true; // routes to readbackBuckets_ on release vkCheck(vmaCreateBuffer(allocator_, &bufferInfo, &allocInfo, &buf.handle, &buf.allocation, &buf.info), @@ -247,9 +347,57 @@ GrillyBuffer BufferPool::acquireReadback(size_t size) { return buf; } +// ── Persistent transfer context ──────────────────────────────────────────── + +void BufferPool::ensureTransferContext() { + if (transferInitialized_) return; + + VkCommandPoolCreateInfo poolInfo{}; + poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + poolInfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; + poolInfo.queueFamilyIndex = device_.queueFamily(); + vkCheck(vkCreateCommandPool(device_.device(), &poolInfo, nullptr, &transferPool_), + "transfer pool creation failed"); + + VkCommandBufferAllocateInfo cmdAllocInfo{}; + cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + cmdAllocInfo.commandPool = transferPool_; + cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + cmdAllocInfo.commandBufferCount = 1; + vkCheck(vkAllocateCommandBuffers(device_.device(), &cmdAllocInfo, &transferCmd_), + "transfer cmd alloc failed"); + + VkFenceCreateInfo fenceInfo{}; + fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + fenceInfo.flags = VK_FENCE_CREATE_SIGNALED_BIT; + vkCheck(vkCreateFence(device_.device(), &fenceInfo, nullptr, &transferFence_), + "transfer fence creation failed"); + + transferInitialized_ = true; +} + +void BufferPool::transferSubmitAndWait() { + vkEndCommandBuffer(transferCmd_); + + vkWaitForFences(device_.device(), 1, &transferFence_, VK_TRUE, UINT64_MAX); + vkResetFences(device_.device(), 1, &transferFence_); + + VkSubmitInfo submitInfo{}; + submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submitInfo.commandBufferCount = 1; + submitInfo.pCommandBuffers = &transferCmd_; + + vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, transferFence_), + "transfer submit failed"); + vkCheck(vkWaitForFences(device_.device(), 1, &transferFence_, VK_TRUE, UINT64_MAX), + "transfer wait failed"); +} + void BufferPool::uploadStaged(GrillyBuffer& deviceBuf, const void* data, size_t bytes) { - // 1. Create a temporary host-visible staging buffer + ensureTransferContext(); + + // 1. Create staging buffer (TODO: pool these too) VkBufferCreateInfo stagingInfo{}; stagingInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; stagingInfo.size = bytes; @@ -268,67 +416,33 @@ void BufferPool::uploadStaged(GrillyBuffer& deviceBuf, const void* data, &stagingBuf, &stagingMem, &stagingMemInfo), "staging buffer alloc failed"); - // 2. Copy data into staging buffer + // 2. Copy data into staging std::memcpy(stagingMemInfo.pMappedData, data, bytes); vmaFlushAllocation(allocator_, stagingMem, 0, bytes); - // 3. 
Create a one-shot command buffer for the transfer
-    VkCommandPoolCreateInfo poolInfo{};
-    poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
-    poolInfo.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT;
-    poolInfo.queueFamilyIndex = device_.queueFamily();
-
-    VkCommandPool cmdPool;
-    vkCheck(vkCreateCommandPool(device_.device(), &poolInfo, nullptr, &cmdPool),
-            "vkCreateCommandPool for staging failed");
-
-    VkCommandBufferAllocateInfo cmdAllocInfo{};
-    cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
-    cmdAllocInfo.commandPool = cmdPool;
-    cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
-    cmdAllocInfo.commandBufferCount = 1;
-
-    VkCommandBuffer cmd;
-    vkCheck(vkAllocateCommandBuffers(device_.device(), &cmdAllocInfo, &cmd),
-            "vkAllocateCommandBuffers for staging failed");
-
+    // 3. Record transfer using persistent context
+    vkResetCommandBuffer(transferCmd_, 0);
     VkCommandBufferBeginInfo beginInfo{};
     beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
     beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
-    vkBeginCommandBuffer(cmd, &beginInfo);
+    vkBeginCommandBuffer(transferCmd_, &beginInfo);
     VkBufferCopy copyRegion{};
     copyRegion.size = bytes;
-    vkCmdCopyBuffer(cmd, stagingBuf, deviceBuf.handle, 1, &copyRegion);
+    vkCmdCopyBuffer(transferCmd_, stagingBuf, deviceBuf.handle, 1, &copyRegion);
-    vkEndCommandBuffer(cmd);
+    // 4. Submit + wait (reuses persistent fence)
+    transferSubmitAndWait();
-    // 4. Submit and wait
-    VkSubmitInfo submitInfo{};
-    submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
-    submitInfo.commandBufferCount = 1;
-    submitInfo.pCommandBuffers = &cmd;
-
-    VkFenceCreateInfo fenceInfo{};
-    fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
-    VkFence fence;
-    vkCheck(vkCreateFence(device_.device(), &fenceInfo, nullptr, &fence),
-            "staging fence creation failed");
-
-    vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, fence),
-            "staging vkQueueSubmit failed");
-    vkCheck(vkWaitForFences(device_.device(), 1, &fence, VK_TRUE, UINT64_MAX),
-            "staging vkWaitForFences failed");
-
-    // 5. Cleanup
-    vkDestroyFence(device_.device(), fence, nullptr);
-    vkDestroyCommandPool(device_.device(), cmdPool, nullptr);
+    // 5. Cleanup staging only
     vmaDestroyBuffer(allocator_, stagingBuf, stagingMem);
 }
 void BufferPool::downloadStaged(const GrillyBuffer& deviceBuf, void* out,
                                 size_t bytes) {
-    // 1. Create a temporary host-visible staging buffer for readback
+    ensureTransferContext();
+
+    // 1. Create staging readback buffer (TODO: pool these)
     VkBufferCreateInfo stagingInfo{};
     stagingInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
     stagingInfo.size = bytes;
@@ -347,61 +461,25 @@ void BufferPool::downloadStaged(const GrillyBuffer& deviceBuf, void* out,
         &stagingBuf, &stagingMem, &stagingMemInfo),
         "staging readback buffer alloc failed");
-    // 2. Copy from device-local to staging via DMA
-    VkCommandPoolCreateInfo poolInfo{};
-    poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
-    poolInfo.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT;
-    poolInfo.queueFamilyIndex = device_.queueFamily();
-
-    VkCommandPool cmdPool;
-    vkCheck(vkCreateCommandPool(device_.device(), &poolInfo, nullptr, &cmdPool),
-            "vkCreateCommandPool for readback failed");
-
-    VkCommandBufferAllocateInfo cmdAllocInfo{};
-    cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
-    cmdAllocInfo.commandPool = cmdPool;
-    cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
-    cmdAllocInfo.commandBufferCount = 1;
-
-    VkCommandBuffer cmd;
-    vkCheck(vkAllocateCommandBuffers(device_.device(), &cmdAllocInfo, &cmd),
-            "vkAllocateCommandBuffers for readback failed");
-
+    // 2. Record copy using persistent transfer context
+    vkResetCommandBuffer(transferCmd_, 0);
     VkCommandBufferBeginInfo beginInfo{};
     beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
     beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
-    vkBeginCommandBuffer(cmd, &beginInfo);
+    vkBeginCommandBuffer(transferCmd_, &beginInfo);
     VkBufferCopy copyRegion{};
     copyRegion.size = bytes;
-    vkCmdCopyBuffer(cmd, deviceBuf.handle, stagingBuf, 1, &copyRegion);
-
-    vkEndCommandBuffer(cmd);
-
-    // 3. Submit and wait
-    VkSubmitInfo submitInfo{};
-    submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
-    submitInfo.commandBufferCount = 1;
-    submitInfo.pCommandBuffers = &cmd;
-
-    VkFenceCreateInfo fenceInfo{};
-    fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
-    VkFence fence;
-    vkCheck(vkCreateFence(device_.device(), &fenceInfo, nullptr, &fence),
-            "readback fence creation failed");
+    vkCmdCopyBuffer(transferCmd_, deviceBuf.handle, stagingBuf, 1, &copyRegion);
-    vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, fence),
-            "readback vkQueueSubmit failed");
-    vkCheck(vkWaitForFences(device_.device(), 1, &fence, VK_TRUE, UINT64_MAX),
-            "readback vkWaitForFences failed");
+    // 3. Submit + wait (reuses persistent fence)
+    transferSubmitAndWait();
     // 4. Invalidate + copy to output
     vmaInvalidateAllocation(allocator_, stagingMem, 0, bytes);
     std::memcpy(out, stagingMemInfo.pMappedData, bytes);
-    // 5. Cleanup
-    vkDestroyFence(device_.device(), fence, nullptr);
-    vkDestroyCommandPool(device_.device(), cmdPool, nullptr);
+    // 5. Cleanup staging only
     vmaDestroyBuffer(allocator_, stagingBuf, stagingMem);
 }
diff --git a/cpp/src/channels/channel.cpp b/cpp/src/channels/channel.cpp
new file mode 100644
index 0000000..1da267c
--- /dev/null
+++ b/cpp/src/channels/channel.cpp
@@ -0,0 +1,72 @@
+/**
+ * channel.cpp — InProcessChannel implementation.
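+ *
+ * Behavior (summarized from the implementation below): send() stamps a
+ * timestamp when none is set, invokes subscribed listeners synchronously
+ * from inside send(), then enqueues the message, dropping the oldest entry
+ * once max_queue_size_ is reached. receive() never blocks and returns an
+ * empty TENSOR_DATA envelope when the queue is empty.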
+ */ + +#include "grilly/channels/channel.h" + +#include +#include + +namespace grilly { +namespace channels { + +InProcessChannel::InProcessChannel(const std::string& name, size_t max_queue_size) + : name_(name), max_queue_size_(max_queue_size) {} + +void InProcessChannel::send(MessageEnvelope envelope) { + std::lock_guard lock(mutex_); + + // Set timestamp if not set + if (envelope.timestamp_ns == 0) { + auto now = std::chrono::high_resolution_clock::now(); + envelope.timestamp_ns = std::chrono::duration_cast( + now.time_since_epoch()).count(); + } + + // Notify listeners first (synchronous) + auto it = listeners_.find(static_cast(envelope.type)); + if (it != listeners_.end()) { + for (auto& listener : it->second) { + listener(envelope); + } + } + + // Queue the message (drop oldest if full) + if (queue_.size() >= max_queue_size_) { + queue_.pop(); + } + queue_.push(std::move(envelope)); +} + +MessageEnvelope InProcessChannel::receive() { + std::lock_guard lock(mutex_); + if (queue_.empty()) { + return MessageEnvelope{MessageType::TENSOR_DATA, 0, "", {}}; + } + auto msg = std::move(queue_.front()); + queue_.pop(); + return msg; +} + +bool InProcessChannel::has_messages() const { + std::lock_guard lock(mutex_); + return !queue_.empty(); +} + +void InProcessChannel::subscribe(MessageType type, ChannelListener listener) { + std::lock_guard lock(mutex_); + listeners_[static_cast(type)].push_back(std::move(listener)); +} + +size_t InProcessChannel::queue_size() const { + std::lock_guard lock(mutex_); + return queue_.size(); +} + +void InProcessChannel::clear() { + std::lock_guard lock(mutex_); + while (!queue_.empty()) queue_.pop(); +} + +} // namespace channels +} // namespace grilly diff --git a/cpp/src/command_batch.cpp b/cpp/src/command_batch.cpp index 9e3ea46..0ca6c01 100644 --- a/cpp/src/command_batch.cpp +++ b/cpp/src/command_batch.cpp @@ -70,6 +70,9 @@ void CommandBatch::begin() { if (recording_) throw std::runtime_error("CommandBatch::begin() called while already recording"); + // Wait for any prior pending work before reusing the command buffer + waitForCompletion(); + vkResetCommandBuffer(cmd_, 0); VkCommandBufferBeginInfo beginInfo{}; @@ -80,6 +83,12 @@ void CommandBatch::begin() { "vkBeginCommandBuffer failed"); recording_ = true; + dispatchCount_ = 0; +} + +void CommandBatch::ensureRecording() { + if (!recording_) + begin(); } void CommandBatch::dispatch(VkPipeline pipeline, VkPipelineLayout layout, @@ -99,6 +108,7 @@ void CommandBatch::dispatch(VkPipeline pipeline, VkPipelineLayout layout, } vkCmdDispatch(cmd_, gx, gy, gz); + dispatchCount_++; } void CommandBatch::barrier() { @@ -121,6 +131,32 @@ void CommandBatch::barrier() { 0, nullptr); // image memory barriers } +void CommandBatch::transferComputeBarrier() { + if (!recording_) + return; + + // Bidirectional TRANSFER ↔ COMPUTE barrier. The src/dst access masks + // cover both edges (stage-in → compute and compute → stage-out) so the + // staging-buffer pattern in linear() can use a single method for both + // transitions without tracking direction. 
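+    // A single global VkMemoryBarrier (no per-buffer ranges) is used here
+    // rather than individual VkBufferMemoryBarriers; it is coarser, but the
+    // batches built by these ops only touch a handful of buffers, so the
+    // simpler form keeps call sites uniform.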
+ VkMemoryBarrier memBarrier{}; + memBarrier.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + memBarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_WRITE_BIT; + memBarrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT | + VK_ACCESS_SHADER_READ_BIT; + + vkCmdPipelineBarrier(cmd_, + VK_PIPELINE_STAGE_TRANSFER_BIT | + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, &memBarrier, + 0, nullptr, + 0, nullptr); +} + void CommandBatch::copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t bytes) { if (!recording_) throw std::runtime_error("CommandBatch::copyBuffer() called without begin()"); @@ -135,6 +171,12 @@ void CommandBatch::copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t // ── Submission ────────────────────────────────────────────────────────────── void CommandBatch::submit() { + // Synchronous: submit + wait. Safe default for backward compat. + submitDeferred(); + waitForCompletion(); +} + +void CommandBatch::submitDeferred() { if (!recording_) return; @@ -142,7 +184,10 @@ void CommandBatch::submit() { recording_ = false; // Wait for any prior submission to complete, then reset fence - vkWaitForFences(device_.device(), 1, &fence_, VK_TRUE, kFenceTimeoutNs); + if (pending_) { + vkWaitForFences(device_.device(), 1, &fence_, VK_TRUE, kFenceTimeoutNs); + pending_ = false; + } vkResetFences(device_.device(), 1, &fence_); VkSubmitInfo submitInfo{}; @@ -153,10 +198,17 @@ void CommandBatch::submit() { vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, fence_), "vkQueueSubmit failed"); - // Wait for completion so callers can read back results + pending_ = true; + // Returns immediately — GPU runs in background +} + +void CommandBatch::waitForCompletion() { + if (!pending_) + return; vkCheck( vkWaitForFences(device_.device(), 1, &fence_, VK_TRUE, kFenceTimeoutNs), "vkWaitForFences timed out"); + pending_ = false; } void CommandBatch::submitAsync(VkSemaphore timeline, uint64_t signalValue) { diff --git a/cpp/src/io/grl_checkpoint.cpp b/cpp/src/io/grl_checkpoint.cpp new file mode 100644 index 0000000..ffdfebc --- /dev/null +++ b/cpp/src/io/grl_checkpoint.cpp @@ -0,0 +1,110 @@ +#include "grilly/io/grl_checkpoint.h" + +#include +#include +#include + +namespace grilly::io { + +static void pack_u64(uint8_t* dst, uint64_t v) { + std::memcpy(dst, &v, 8); +} + +static uint64_t unpack_u64(const uint8_t* src) { + uint64_t v; + std::memcpy(&v, src, 8); + return v; +} + +bool grl_write_file(const std::string& path, + const std::string& metadata_json, + const std::string& index_json, + const std::vector& payload) { + std::ofstream ofs(path, std::ios::binary | std::ios::trunc); + if (!ofs) + return false; + + const uint64_t meta_off = kGrlHeaderSize; + const uint64_t meta_len = metadata_json.size(); + const uint64_t idx_off = meta_off + meta_len; + const uint64_t idx_len = index_json.size(); + const uint64_t pay_off = idx_off + idx_len; + const uint64_t pay_len = payload.size(); + + uint8_t header[kGrlHeaderSize] = {}; + header[0] = 'G'; + header[1] = 'R'; + header[2] = 'L'; + header[3] = 'Y'; + uint16_t ver = kGrlFormatVersion; + uint16_t flags = 0; + uint32_t reserved = 0; + std::memcpy(header + 4, &ver, 2); + std::memcpy(header + 6, &flags, 2); + std::memcpy(header + 8, &reserved, 4); + pack_u64(header + 12, meta_off); + pack_u64(header + 20, meta_len); + pack_u64(header + 28, idx_off); + pack_u64(header + 36, idx_len); + pack_u64(header + 44, pay_off); + pack_u64(header + 52, pay_len); + 
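+    // On-disk layout: [kGrlHeaderSize-byte header][metadata JSON][index JSON]
+    // [payload], written back-to-back; the header records an (offset, length)
+    // pair for each of the three variable-length sections.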
+    ofs.write(reinterpret_cast<const char*>(header), kGrlHeaderSize);
+    ofs.write(metadata_json.data(), static_cast<std::streamsize>(metadata_json.size()));
+    ofs.write(index_json.data(), static_cast<std::streamsize>(index_json.size()));
+    if (!payload.empty())
+        ofs.write(reinterpret_cast<const char*>(payload.data()),
+                  static_cast<std::streamsize>(payload.size()));
+    return static_cast<bool>(ofs);
+}
+
+bool grl_read_file(const std::string& path,
+                   std::string& metadata_json,
+                   std::string& index_json,
+                   std::vector<uint8_t>& payload) {
+    std::ifstream ifs(path, std::ios::binary | std::ios::ate);
+    if (!ifs)
+        return false;
+    const auto end = ifs.tellg();
+    ifs.seekg(0);
+    if (end < static_cast<std::streamoff>(kGrlHeaderSize))
+        return false;
+
+    uint8_t header[kGrlHeaderSize];
+    ifs.read(reinterpret_cast<char*>(header), kGrlHeaderSize);
+    if (ifs.gcount() != static_cast<std::streamsize>(kGrlHeaderSize))
+        return false;
+    if (header[0] != 'G' || header[1] != 'R' || header[2] != 'L' || header[3] != 'Y')
+        return false;
+    uint16_t ver = 0;
+    std::memcpy(&ver, header + 4, 2);
+    if (ver != kGrlFormatVersion)
+        throw std::runtime_error("GRL: unsupported format version");
+
+    const uint64_t meta_off = unpack_u64(header + 12);
+    const uint64_t meta_len = unpack_u64(header + 20);
+    const uint64_t idx_off = unpack_u64(header + 28);
+    const uint64_t idx_len = unpack_u64(header + 36);
+    const uint64_t pay_off = unpack_u64(header + 44);
+    const uint64_t pay_len = unpack_u64(header + 52);
+
+    if (meta_off + meta_len != idx_off || idx_off + idx_len != pay_off)
+        return false;
+    if (end < static_cast<std::streamoff>(pay_off + pay_len))
+        return false;
+
+    metadata_json.resize(static_cast<size_t>(meta_len));
+    index_json.resize(static_cast<size_t>(idx_len));
+    payload.resize(static_cast<size_t>(pay_len));
+
+    ifs.seekg(static_cast<std::streamoff>(meta_off));
+    ifs.read(metadata_json.data(), static_cast<std::streamsize>(meta_len));
+    ifs.read(index_json.data(), static_cast<std::streamsize>(idx_len));
+    if (pay_len > 0)
+        ifs.read(reinterpret_cast<char*>(payload.data()),
+                 static_cast<std::streamsize>(pay_len));
+
+    return static_cast<bool>(ifs) || (pay_len == 0 && meta_len == 0 && idx_len == 0);
+}
+
+} // namespace grilly::io
diff --git a/cpp/src/ops/activations.cpp b/cpp/src/ops/activations.cpp
index 59e0781..1e75a04 100644
--- a/cpp/src/ops/activations.cpp
+++ b/cpp/src/ops/activations.cpp
@@ -1,6 +1,8 @@
 #include "grilly/ops/activations.h"
 
+#include
 #include
+#include
 
 namespace grilly {
 namespace ops {
@@ -22,16 +24,23 @@ static void activationForward(
     const float* input, float* output, uint32_t totalElements) {
     const size_t bytes = size_t(totalElements) * sizeof(float);
-    GrillyBuffer bufIn = pool.acquire(bytes);
-    GrillyBuffer bufOut = pool.acquire(bytes);
+    // Staging pattern (see linear.cpp for the long-form rationale):
+    // compute on DEVICE_LOCAL VRAM, stage-in via WC sequential-write
+    // memory, stage-out via HOST_CACHED random-read memory. Without this
+    // a 19 MB ReLU readback ran at 25 MB/s (~750 ms); with it the same
+    // op runs in single-digit ms.
+ GrillyBuffer bufInDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufOutDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufInStage = pool.acquire(bytes); + GrillyBuffer bufOutStage = pool.acquireReadback(bytes); - pool.upload(bufIn, input, bytes); + pool.upload(bufInStage, input, bytes); PipelineEntry pipe = cache.getOrCreate(shaderName, 2, sizeof(uint32_t)); std::vector bufInfos = { - {bufIn.handle, 0, bytes}, - {bufOut.handle, 0, bytes}, + {bufInDL.handle, 0, bytes}, + {bufOutDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet(shaderName, bufInfos); @@ -39,14 +48,21 @@ static void activationForward( uint32_t gx = (totalElements + 255) / 256; batch.begin(); + batch.copyBuffer(bufInStage, bufInDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutDL, bufOutStage, bytes); + batch.submitDeferred(); + batch.waitForCompletion(); - pool.download(bufOut, output, bytes); + pool.download(bufOutStage, output, bytes); - pool.release(bufIn); - pool.release(bufOut); + pool.release(bufInDL); + pool.release(bufOutDL); + pool.release(bufInStage); + pool.release(bufOutStage); } // ── Activation backward helper ──────────────────────────────────────────── @@ -61,19 +77,24 @@ static void activationBackward( float* gradInput, uint32_t totalElements) { const size_t bytes = size_t(totalElements) * sizeof(float); - GrillyBuffer bufGradOut = pool.acquire(bytes); - GrillyBuffer bufInput = pool.acquire(bytes); - GrillyBuffer bufGradIn = pool.acquire(bytes); + // Staging pattern: 2 stage-in (gradOutput, input), 1 stage-out (gradInput) + GrillyBuffer bufGradOutDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufGradInDL = pool.acquireDeviceLocal(bytes); - pool.upload(bufGradOut, gradOutput, bytes); - pool.upload(bufInput, input, bytes); + GrillyBuffer bufGradOutStage = pool.acquire(bytes); + GrillyBuffer bufInputStage = pool.acquire(bytes); + GrillyBuffer bufGradInStage = pool.acquireReadback(bytes); + + pool.upload(bufGradOutStage, gradOutput, bytes); + pool.upload(bufInputStage, input, bytes); PipelineEntry pipe = cache.getOrCreate(shaderName, 3, sizeof(uint32_t)); std::vector bufInfos = { - {bufGradOut.handle, 0, bytes}, - {bufInput.handle, 0, bytes}, - {bufGradIn.handle, 0, bytes}, + {bufGradOutDL.handle, 0, bytes}, + {bufInputDL.handle, 0, bytes}, + {bufGradInDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet(shaderName, bufInfos); @@ -81,15 +102,24 @@ static void activationBackward( uint32_t gx = (totalElements + 255) / 256; batch.begin(); + batch.copyBuffer(bufGradOutStage, bufGradOutDL, bytes); + batch.copyBuffer(bufInputStage, bufInputDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); - - pool.download(bufGradIn, gradInput, bytes); - - pool.release(bufGradOut); - pool.release(bufInput); - pool.release(bufGradIn); + batch.transferComputeBarrier(); + batch.copyBuffer(bufGradInDL, bufGradInStage, bytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufGradInStage, gradInput, bytes); + + pool.release(bufGradOutDL); + pool.release(bufInputDL); + pool.release(bufGradInDL); + pool.release(bufGradOutStage); + pool.release(bufInputStage); + pool.release(bufGradInStage); } // ── Forward passes 
──────────────────────────────────────────────────────── @@ -207,7 +237,8 @@ void softmax(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, dataBytes); @@ -250,7 +281,8 @@ void softmaxBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, bytes); @@ -259,5 +291,114 @@ void softmaxBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache pool.release(bufGradIn); } +// ── Multiplication-free softmax (mf-softmax.glsl) ───────────────────────── +// Same 3-pass buffer layout as softmax; pass_type 0/1/2 with relu sums. + +void mfSoftmax(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t batchSize, uint32_t seqLen, + uint32_t features) { + const uint32_t totalPositions = batchSize * seqLen; + const uint32_t totalElements = totalPositions * features; + const size_t dataBytes = size_t(totalElements) * sizeof(float); + const size_t auxBytes = size_t(totalPositions) * sizeof(float); + + GrillyBuffer bufInput = pool.acquire(dataBytes); + GrillyBuffer bufOutput = pool.acquire(dataBytes); + GrillyBuffer bufMax = pool.acquire(auxBytes); + GrillyBuffer bufSumPos = pool.acquire(auxBytes); + + pool.upload(bufInput, input, dataBytes); + + PipelineEntry pipe = + cache.getOrCreate("mf-softmax", 4, sizeof(SoftmaxParams)); + + std::vector bufInfos = { + {bufInput.handle, 0, dataBytes}, + {bufOutput.handle, 0, dataBytes}, + {bufMax.handle, 0, auxBytes}, + {bufSumPos.handle, 0, auxBytes}, + }; + VkDescriptorSet descSet = cache.allocDescriptorSet("mf-softmax", bufInfos); + + uint32_t gx = (totalPositions + 255) / 256; + + batch.begin(); + + SoftmaxParams push0{batchSize, seqLen, features, 0, features}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push0, + sizeof(push0)); + batch.barrier(); + + SoftmaxParams push1{batchSize, seqLen, features, 1, features}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push1, + sizeof(push1)); + batch.barrier(); + + SoftmaxParams push2{batchSize, seqLen, features, 2, features}; + uint32_t gx2 = (totalElements + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, + sizeof(push2)); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOutput, output, dataBytes); + + pool.release(bufInput); + pool.release(bufOutput); + pool.release(bufMax); + pool.release(bufSumPos); +} + +// Push layout must match mf-softplus.glsl: uint total_elements; float c; +struct MfSoftplusParams { + uint32_t totalElements; + float c; +}; + +void mfSoftplus(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements, + float beta) { + if (beta <= 0.f) { + throw std::invalid_argument("mfSoftplus: beta must be positive"); + } + const float c = 4.f / (beta * beta); + const size_t bytes = size_t(totalElements) * sizeof(float); + + GrillyBuffer bufIn = pool.acquire(bytes); + GrillyBuffer bufOut = pool.acquire(bytes); + + pool.upload(bufIn, input, bytes); + + PipelineEntry pipe = + cache.getOrCreate("mf-softplus", 2, sizeof(MfSoftplusParams)); + + std::vector bufInfos = { + 
{bufIn.handle, 0, bytes}, + {bufOut.handle, 0, bytes}, + }; + VkDescriptorSet descSet = cache.allocDescriptorSet("mf-softplus", bufInfos); + + MfSoftplusParams push{totalElements, c}; + uint32_t gx = (totalElements + 255) / 256; + + batch.begin(); + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, + sizeof(push)); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOut, output, bytes); + + pool.release(bufIn); + pool.release(bufOut); +} + +void mfSigmoid(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements) { + activationForward("mf-sigmoid", batch, pool, cache, input, output, + totalElements); +} + } // namespace ops } // namespace grilly diff --git a/cpp/src/ops/attention.cpp b/cpp/src/ops/attention.cpp index 8706be7..b7015bf 100644 --- a/cpp/src/ops/attention.cpp +++ b/cpp/src/ops/attention.cpp @@ -149,7 +149,8 @@ void flashAttention2(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, finalGroups, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download result pool.download(bufOutput, output, outBytes); diff --git a/cpp/src/ops/attention_ops.cpp b/cpp/src/ops/attention_ops.cpp index 7c41e07..a40e751 100644 --- a/cpp/src/ops/attention_ops.cpp +++ b/cpp/src/ops/attention_ops.cpp @@ -1,6 +1,9 @@ #include "grilly/ops/attention_ops.h" +#include "grilly/ops/activations.h" + #include +#include namespace grilly { namespace ops { @@ -59,7 +62,8 @@ void attentionScores(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufScores, scores, scoreBytes); @@ -69,6 +73,124 @@ void attentionScores(CommandBatch& batch, BufferPool& pool, PipelineCache& cache pool.release(bufScores); } +void attentionScoresSoftmaxOutput(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, const float* Q, + const float* K, const float* V, float* output, + float* softmaxWeights, const AttentionScoresParams& sp, + const AttentionOutputParams& op) { + const uint32_t B = sp.batchSize; + const uint32_t S = sp.seqLen; + const uint32_t H = sp.numHeads; + const uint32_t D = sp.headDim; + + if (op.batchSize != B || op.seqLen != S || op.numHeads != H || op.headDim != D) { + throw std::runtime_error( + "attentionScoresSoftmaxOutput: score and output params must match"); + } + + const size_t qkvBytes = size_t(B) * H * S * D * sizeof(float); + const size_t scoreBytes = size_t(B) * H * S * S * sizeof(float); + const size_t outBytes = qkvBytes; + + GrillyBuffer bufQ = pool.acquire(qkvBytes); + GrillyBuffer bufK = pool.acquire(qkvBytes); + GrillyBuffer bufVDummy = pool.acquire(sizeof(float)); + GrillyBuffer bufScores = pool.acquire(scoreBytes); + GrillyBuffer bufWeights = pool.acquire(scoreBytes); + GrillyBuffer bufV = pool.acquire(qkvBytes); + GrillyBuffer bufOut = pool.acquire(outBytes); + + const uint32_t totalSoftmaxRows = B * H * S; + const size_t auxBytes = size_t(totalSoftmaxRows) * sizeof(float); + + GrillyBuffer bufMax = pool.acquire(auxBytes); + GrillyBuffer bufSumExp = pool.acquire(auxBytes); + + pool.upload(bufQ, Q, qkvBytes); + pool.upload(bufK, K, qkvBytes); + pool.upload(bufV, V, qkvBytes); + + PipelineEntry pipeScores = cache.getOrCreate("attention-scores", 4, + sizeof(AttentionScoresParams)); + std::vector scoresInfos = 
{ + {bufQ.handle, 0, qkvBytes}, + {bufK.handle, 0, qkvBytes}, + {bufVDummy.handle, 0, sizeof(float)}, + {bufScores.handle, 0, scoreBytes}, + }; + VkDescriptorSet descScores = cache.allocDescriptorSet("attention-scores", scoresInfos); + + PipelineEntry pipeSoftmax = cache.getOrCreate("activation-softmax", 4, + sizeof(SoftmaxParams)); + std::vector softmaxInfos = { + {bufScores.handle, 0, scoreBytes}, + {bufWeights.handle, 0, scoreBytes}, + {bufMax.handle, 0, auxBytes}, + {bufSumExp.handle, 0, auxBytes}, + }; + VkDescriptorSet descSoftmax = cache.allocDescriptorSet("activation-softmax", softmaxInfos); + + PipelineEntry pipeOut = cache.getOrCreate("attention-output", 3, + sizeof(AttentionOutputParams)); + std::vector outInfos = { + {bufWeights.handle, 0, scoreBytes}, + {bufV.handle, 0, qkvBytes}, + {bufOut.handle, 0, outBytes}, + }; + VkDescriptorSet descOut = cache.allocDescriptorSet("attention-output", outInfos); + + const uint32_t gxS = (S + 15) / 16; + const uint32_t gyS = (S + 15) / 16; + const uint32_t gzS = B * H; + + const uint32_t softmaxGx = (totalSoftmaxRows + 255) / 256; + const uint32_t totalElements = totalSoftmaxRows * S; + const uint32_t softmaxGx2 = (totalElements + 255) / 256; + + const uint32_t outTotal = B * H * S * D; + const uint32_t gxOut = (outTotal + 255) / 256; + + batch.begin(); + + batch.dispatch(pipeScores.pipeline, pipeScores.layout, descScores, gxS, gyS, gzS, + &sp, sizeof(sp)); + batch.barrier(); + + SoftmaxParams push0{1, totalSoftmaxRows, S, 0, S}; + batch.dispatch(pipeSoftmax.pipeline, pipeSoftmax.layout, descSoftmax, softmaxGx, 1, 1, + &push0, sizeof(push0)); + batch.barrier(); + + SoftmaxParams push1{1, totalSoftmaxRows, S, 1, S}; + batch.dispatch(pipeSoftmax.pipeline, pipeSoftmax.layout, descSoftmax, softmaxGx, 1, 1, + &push1, sizeof(push1)); + batch.barrier(); + + SoftmaxParams push2{1, totalSoftmaxRows, S, 2, S}; + batch.dispatch(pipeSoftmax.pipeline, pipeSoftmax.layout, descSoftmax, softmaxGx2, 1, 1, + &push2, sizeof(push2)); + batch.barrier(); + + batch.dispatch(pipeOut.pipeline, pipeOut.layout, descOut, gxOut, 1, 1, &op, + sizeof(op)); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOut, output, outBytes); + pool.download(bufWeights, softmaxWeights, scoreBytes); + + pool.release(bufQ); + pool.release(bufK); + pool.release(bufVDummy); + pool.release(bufScores); + pool.release(bufWeights); + pool.release(bufV); + pool.release(bufOut); + pool.release(bufMax); + pool.release(bufSumExp); +} + // ── Attention mask ─────────────────────────────────────────────────────── void attentionMask(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, @@ -105,7 +227,8 @@ void attentionMask(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufScores, scores, scoreBytes); @@ -148,7 +271,8 @@ void attentionOutput(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outBytes); @@ -188,7 +312,8 @@ void attentionConcatHeads(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, 
concatOutput, outBytes); @@ -235,7 +360,8 @@ void applyRoPE(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, dataBytes); diff --git a/cpp/src/ops/batched_ops.cpp b/cpp/src/ops/batched_ops.cpp index 91701d7..f7e173e 100644 --- a/cpp/src/ops/batched_ops.cpp +++ b/cpp/src/ops/batched_ops.cpp @@ -10,6 +10,7 @@ #include "grilly/ops/batched_ops.h" #include "grilly/ops/linear.h" #include "grilly/ops/activations.h" +#include "grilly/ops/embedding.h" namespace grilly { namespace ops { @@ -188,6 +189,33 @@ void batchedAdd(CommandBatch& batch, PipelineCache& cache, &push, sizeof(push)); } +void batchedEmbeddingLookup(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& tokenIds, const GrillyBuffer& embedTable, + GrillyBuffer& output, uint32_t batchSize, uint32_t seqLen, + uint32_t vocabSize, uint32_t embeddingDim) { + + PipelineEntry pipe = cache.getOrCreate("embedding-lookup", 3, + sizeof(EmbeddingParams)); + + const uint32_t totalTokens = batchSize * seqLen; + size_t idBytes = size_t(totalTokens) * sizeof(uint32_t); + size_t embedBytes = size_t(vocabSize) * embeddingDim * sizeof(float); + size_t outBytes = size_t(totalTokens) * embeddingDim * sizeof(float); + + std::vector bufs = { + {tokenIds.handle, 0, idBytes}, + {embedTable.handle, 0, embedBytes}, + {output.handle, 0, outBytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("embedding-lookup", bufs); + + EmbeddingParams push{batchSize, seqLen, vocabSize, embeddingDim}; + uint32_t gx = (totalTokens + 255) / 256; + + batch.dispatch(pipe.pipeline, pipe.layout, desc, gx, 1, 1, + &push, sizeof(push)); +} + void batchedTiledLinear(CommandBatch& batch, PipelineCache& cache, const GrillyBuffer& input, const GrillyBuffer& weight, diff --git a/cpp/src/ops/batchnorm.cpp b/cpp/src/ops/batchnorm.cpp index a108258..1037f5f 100644 --- a/cpp/src/ops/batchnorm.cpp +++ b/cpp/src/ops/batchnorm.cpp @@ -67,7 +67,8 @@ void batchnorm2dForward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, dataBytes); pool.download(bufRunMean, runningMean, featBytes); @@ -140,7 +141,8 @@ void batchnorm2dBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, dataBytes); pool.download(bufGradGamma, gradGamma, featBytes); diff --git a/cpp/src/ops/conv.cpp b/cpp/src/ops/conv.cpp index b07c917..f12449d 100644 --- a/cpp/src/ops/conv.cpp +++ b/cpp/src/ops/conv.cpp @@ -175,7 +175,8 @@ void conv2d(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, &biasPush, sizeof(biasPush)); } - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download from the correct buffer (biased output if shader was // available, raw GEMM output otherwise) @@ -246,7 +247,8 @@ void conv2d(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); @@ -324,7 +326,8 
@@ void conv2dBackwardInput(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); @@ -387,7 +390,8 @@ void conv2dBackwardWeight(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradW, gradWeight, gradWBytes); if (p.hasBias && gradBias) { @@ -439,7 +443,8 @@ void conv2d3x3Gelu(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, desc, (width + 15) / 16, (height + 15) / 16, outChannels, &pc, sizeof(pc)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); @@ -480,7 +485,8 @@ void maxpool2x2(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, desc, (outW + 15) / 16, (outH + 15) / 16, channels, &pc, sizeof(pc)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); @@ -516,7 +522,8 @@ void adaptiveAvgPool3x3(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, desc, 1, 1, (channels + 15) / 16, &pc, sizeof(pc)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); diff --git a/cpp/src/ops/embedding.cpp b/cpp/src/ops/embedding.cpp index 05b5c9d..eda4265 100644 --- a/cpp/src/ops/embedding.cpp +++ b/cpp/src/ops/embedding.cpp @@ -22,20 +22,28 @@ void embeddingLookup(CommandBatch& batch, BufferPool& pool, const size_t embedBytes = size_t(p.vocabSize) * p.embeddingDim * sizeof(float); const size_t outBytes = size_t(totalTokens) * p.embeddingDim * sizeof(float); - GrillyBuffer bufIds = pool.acquire(idBytes); - GrillyBuffer bufEmbed = pool.acquire(embedBytes); - GrillyBuffer bufOut = pool.acquire(outBytes); + // Staging pattern: 2 stage-in (token ids, embedding table), 1 stage-out + // (output token vectors). Embedding table is the big one (vocab × dim); + // for VSA-LM with vocab=8192, dim=384 it's 12 MB and re-uploaded every + // call. A weight cache (TODO) would amortize this across training steps. 
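+    // (8192 × 384 floats × 4 B = 12,582,912 bytes ≈ 12 MiB per call.)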
+ GrillyBuffer bufIdsDL = pool.acquireDeviceLocal(idBytes); + GrillyBuffer bufEmbedDL = pool.acquireDeviceLocal(embedBytes); + GrillyBuffer bufOutDL = pool.acquireDeviceLocal(outBytes); - pool.upload(bufIds, reinterpret_cast(tokenIds), idBytes); - pool.upload(bufEmbed, embeddings, embedBytes); + GrillyBuffer bufIdsStage = pool.acquire(idBytes); + GrillyBuffer bufEmbedStage = pool.acquire(embedBytes); + GrillyBuffer bufOutStage = pool.acquireReadback(outBytes); + + pool.upload(bufIdsStage, reinterpret_cast(tokenIds), idBytes); + pool.upload(bufEmbedStage, embeddings, embedBytes); PipelineEntry pipe = cache.getOrCreate("embedding-lookup", 3, sizeof(EmbeddingParams)); std::vector bufInfos = { - {bufIds.handle, 0, idBytes}, - {bufEmbed.handle, 0, embedBytes}, - {bufOut.handle, 0, outBytes}, + {bufIdsDL.handle, 0, idBytes}, + {bufEmbedDL.handle, 0, embedBytes}, + {bufOutDL.handle, 0, outBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("embedding-lookup", bufInfos); @@ -43,15 +51,24 @@ void embeddingLookup(CommandBatch& batch, BufferPool& pool, uint32_t gx = (totalTokens + 255) / 256; batch.begin(); + batch.copyBuffer(bufIdsStage, bufIdsDL, idBytes); + batch.copyBuffer(bufEmbedStage, bufEmbedDL, embedBytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutDL, bufOutStage, outBytes); + batch.submitDeferred(); + batch.waitForCompletion(); - pool.download(bufOut, output, outBytes); + pool.download(bufOutStage, output, outBytes); - pool.release(bufIds); - pool.release(bufEmbed); - pool.release(bufOut); + pool.release(bufIdsDL); + pool.release(bufEmbedDL); + pool.release(bufOutDL); + pool.release(bufIdsStage); + pool.release(bufEmbedStage); + pool.release(bufOutStage); } } // namespace ops diff --git a/cpp/src/ops/fused.cpp b/cpp/src/ops/fused.cpp index 870406a..62650f4 100644 --- a/cpp/src/ops/fused.cpp +++ b/cpp/src/ops/fused.cpp @@ -59,7 +59,8 @@ void fusedMlpGelu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, seqLen, 1, 1, nullptr, 0); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); @@ -119,7 +120,8 @@ void fusedLayernormLinear(CommandBatch& batch, BufferPool& pool, PipelineCache& batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, seqLen, 1, 1, nullptr, 0); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); diff --git a/cpp/src/ops/kv_cache.cpp b/cpp/src/ops/kv_cache.cpp index 0651fe9..90e3bac 100644 --- a/cpp/src/ops/kv_cache.cpp +++ b/cpp/src/ops/kv_cache.cpp @@ -231,7 +231,8 @@ void kvCacheAppend(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipeGemm.pipeline, pipeGemm.layout, descGemm, gemmGX, gemmGY, 1, gemmPush, sizeof(gemmPush)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Copy compressed latents into cache at currentLen offset std::vector latentData(numNewTokens * cfg.numHeads * latentDim); @@ -475,7 +476,8 @@ void kvCacheDecode(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipeGemm.pipeline, pipeGemm.layout, descGemm, (N_proj + 15) / 16, (M_proj + 15) / 16, 1, gemmPush, sizeof(gemmPush)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download and split into K and V 
std::vector decoded(M_proj * N_proj); diff --git a/cpp/src/ops/layernorm.cpp b/cpp/src/ops/layernorm.cpp index 14771c2..7c26b70 100644 --- a/cpp/src/ops/layernorm.cpp +++ b/cpp/src/ops/layernorm.cpp @@ -36,38 +36,49 @@ void layernorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const size_t meanBytes = size_t(totalPositions) * sizeof(float); const size_t varBytes = meanBytes; - // Acquire 6 buffers - GrillyBuffer bufInput = pool.acquire(inputBytes); - GrillyBuffer bufOutput = pool.acquire(outputBytes); - GrillyBuffer bufGamma = pool.acquire(gammaBytes); - GrillyBuffer bufBeta = pool.acquire(betaBytes); - GrillyBuffer bufMean = pool.acquire(meanBytes); - GrillyBuffer bufVar = pool.acquire(varBytes); - - // Upload input data - pool.upload(bufInput, input, inputBytes); - pool.upload(bufGamma, gamma, gammaBytes); - pool.upload(bufBeta, beta, betaBytes); + // Staging pattern: 3 stage-in (input, gamma, beta), 1 stage-out (output). + // mean and var are intermediate buffers — pure DEVICE_LOCAL, never touched + // by the CPU, so no staging buffers needed for them. + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes); + GrillyBuffer bufOutputDL = pool.acquireDeviceLocal(outputBytes); + GrillyBuffer bufGammaDL = pool.acquireDeviceLocal(gammaBytes); + GrillyBuffer bufBetaDL = pool.acquireDeviceLocal(betaBytes); + GrillyBuffer bufMeanDL = pool.acquireDeviceLocal(meanBytes); + GrillyBuffer bufVarDL = pool.acquireDeviceLocal(varBytes); + + GrillyBuffer bufInputStage = pool.acquire(inputBytes); + GrillyBuffer bufGammaStage = pool.acquire(gammaBytes); + GrillyBuffer bufBetaStage = pool.acquire(betaBytes); + GrillyBuffer bufOutputStage = pool.acquireReadback(outputBytes); + + pool.upload(bufInputStage, input, inputBytes); + pool.upload(bufGammaStage, gamma, gammaBytes); + pool.upload(bufBetaStage, beta, betaBytes); // Get pipeline: 6 buffers, 20 bytes push constants PipelineEntry pipe = cache.getOrCreate("fnn-layernorm", 6, sizeof(LayerNormParams)); - // Descriptor set (same for all 3 passes — same buffers) + // Descriptor set bound to DEVICE_LOCAL buffers std::vector bufInfos = { - {bufInput.handle, 0, inputBytes}, - {bufOutput.handle, 0, outputBytes}, - {bufGamma.handle, 0, gammaBytes}, - {bufBeta.handle, 0, betaBytes}, - {bufMean.handle, 0, meanBytes}, - {bufVar.handle, 0, varBytes}, + {bufInputDL.handle, 0, inputBytes}, + {bufOutputDL.handle, 0, outputBytes}, + {bufGammaDL.handle, 0, gammaBytes}, + {bufBetaDL.handle, 0, betaBytes}, + {bufMeanDL.handle, 0, meanBytes}, + {bufVarDL.handle, 0, varBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-layernorm", bufInfos); - // 3-pass dispatch batch.begin(); + // Stage-in: copy 3 host-visible staging buffers to DL VRAM + batch.copyBuffer(bufInputStage, bufInputDL, inputBytes); + batch.copyBuffer(bufGammaStage, bufGammaDL, gammaBytes); + batch.copyBuffer(bufBetaStage, bufBetaDL, betaBytes); + batch.transferComputeBarrier(); + // Pass 0: compute mean LayerNormParams push0{batchSize, seqLen, features, eps, 0}; uint32_t gx0 = (totalPositions + 255) / 256; @@ -88,18 +99,25 @@ void layernorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, sizeof(push2)); - batch.submit(); - - // Download result - pool.download(bufOutput, output, outputBytes); - - // Release all buffers - pool.release(bufInput); - pool.release(bufOutput); - pool.release(bufGamma); - pool.release(bufBeta); - pool.release(bufMean); - pool.release(bufVar); + // Stage-out: DL 
output → HOST_CACHED readback staging + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutputDL, bufOutputStage, outputBytes); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOutputStage, output, outputBytes); + + pool.release(bufInputDL); + pool.release(bufOutputDL); + pool.release(bufGammaDL); + pool.release(bufBetaDL); + pool.release(bufMeanDL); + pool.release(bufVarDL); + pool.release(bufInputStage); + pool.release(bufGammaStage); + pool.release(bufBetaStage); + pool.release(bufOutputStage); } // ── LayerNorm backward ─────────────────────────────────────────────────── @@ -118,47 +136,69 @@ void layernormBackward(CommandBatch& batch, BufferPool& pool, const size_t gammaBytes = size_t(features) * sizeof(float); const size_t posBytes = size_t(totalPositions) * sizeof(float); - // 8 buffers for backward - GrillyBuffer bufGradOut = pool.acquire(elemBytes); - GrillyBuffer bufInput = pool.acquire(elemBytes); - GrillyBuffer bufGamma = pool.acquire(gammaBytes); - GrillyBuffer bufMean = pool.acquire(posBytes); - GrillyBuffer bufVar = pool.acquire(posBytes); - GrillyBuffer bufGradIn = pool.acquire(elemBytes); - GrillyBuffer bufGradGamma = pool.acquire(gammaBytes); - GrillyBuffer bufGradBeta = pool.acquire(gammaBytes); - - pool.upload(bufGradOut, gradOutput, elemBytes); - pool.upload(bufInput, input, elemBytes); - pool.upload(bufGamma, gamma, gammaBytes); - pool.upload(bufMean, mean, posBytes); - pool.upload(bufVar, var, posBytes); - - // Zero grad outputs + // Staging pattern: 5 stage-in (gradOut, input, gamma, mean, var), + // 3 stage-out (gradIn, gradGamma, gradBeta) + GrillyBuffer bufGradOutDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufGammaDL = pool.acquireDeviceLocal(gammaBytes); + GrillyBuffer bufMeanDL = pool.acquireDeviceLocal(posBytes); + GrillyBuffer bufVarDL = pool.acquireDeviceLocal(posBytes); + GrillyBuffer bufGradInDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufGradGammaDL = pool.acquireDeviceLocal(gammaBytes); + GrillyBuffer bufGradBetaDL = pool.acquireDeviceLocal(gammaBytes); + + GrillyBuffer bufGradOutStage = pool.acquire(elemBytes); + GrillyBuffer bufInputStage = pool.acquire(elemBytes); + GrillyBuffer bufGammaStage = pool.acquire(gammaBytes); + GrillyBuffer bufMeanStage = pool.acquire(posBytes); + GrillyBuffer bufVarStage = pool.acquire(posBytes); + GrillyBuffer bufGradInStage = pool.acquireReadback(elemBytes); + GrillyBuffer bufGradGammaStage = pool.acquireReadback(gammaBytes); + GrillyBuffer bufGradBetaStage = pool.acquireReadback(gammaBytes); + + pool.upload(bufGradOutStage, gradOutput, elemBytes); + pool.upload(bufInputStage, input, elemBytes); + pool.upload(bufGammaStage, gamma, gammaBytes); + pool.upload(bufMeanStage, mean, posBytes); + pool.upload(bufVarStage, var, posBytes); + + // Zero the grad output staging buffers (atomic accumulation in shader). + // Reuse the readback stage buffers as upload-zeros source. 
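+    // The zeroed stage buffers ride the same stage-in copies as the real
+    // inputs below, so the DL grad accumulators start from zero before the
+    // shader's atomic adds run.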
std::vector zeros_elem(totalElements, 0.0f); std::vector zeros_feat(features, 0.0f); - pool.upload(bufGradIn, zeros_elem.data(), elemBytes); - pool.upload(bufGradGamma, zeros_feat.data(), gammaBytes); - pool.upload(bufGradBeta, zeros_feat.data(), gammaBytes); + pool.upload(bufGradInStage, zeros_elem.data(), elemBytes); + pool.upload(bufGradGammaStage, zeros_feat.data(), gammaBytes); + pool.upload(bufGradBetaStage, zeros_feat.data(), gammaBytes); PipelineEntry pipe = cache.getOrCreate("fnn-layernorm-backward", 8, sizeof(LayerNormParams)); std::vector bufInfos = { - {bufGradOut.handle, 0, elemBytes}, - {bufInput.handle, 0, elemBytes}, - {bufGamma.handle, 0, gammaBytes}, - {bufMean.handle, 0, posBytes}, - {bufVar.handle, 0, posBytes}, - {bufGradIn.handle, 0, elemBytes}, - {bufGradGamma.handle, 0, gammaBytes}, - {bufGradBeta.handle, 0, gammaBytes}, + {bufGradOutDL.handle, 0, elemBytes}, + {bufInputDL.handle, 0, elemBytes}, + {bufGammaDL.handle, 0, gammaBytes}, + {bufMeanDL.handle, 0, posBytes}, + {bufVarDL.handle, 0, posBytes}, + {bufGradInDL.handle, 0, elemBytes}, + {bufGradGammaDL.handle, 0, gammaBytes}, + {bufGradBetaDL.handle, 0, gammaBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet( "fnn-layernorm-backward", bufInfos); batch.begin(); + // Stage-in: copy all 8 stage buffers (5 inputs + 3 zeroed grads) to DL + batch.copyBuffer(bufGradOutStage, bufGradOutDL, elemBytes); + batch.copyBuffer(bufInputStage, bufInputDL, elemBytes); + batch.copyBuffer(bufGammaStage, bufGammaDL, gammaBytes); + batch.copyBuffer(bufMeanStage, bufMeanDL, posBytes); + batch.copyBuffer(bufVarStage, bufVarDL, posBytes); + batch.copyBuffer(bufGradInStage, bufGradInDL, elemBytes); + batch.copyBuffer(bufGradGammaStage, bufGradGammaDL, gammaBytes); + batch.copyBuffer(bufGradBetaStage, bufGradBetaDL, gammaBytes); + batch.transferComputeBarrier(); + // Pass 0: intermediate sums LayerNormParams push0{batchSize, seqLen, features, eps, 0}; batch.dispatch(pipe.pipeline, pipe.layout, descSet, @@ -179,20 +219,35 @@ void layernormBackward(CommandBatch& batch, BufferPool& pool, (features + 255) / 256, 1, 1, &push2, sizeof(push2)); - batch.submit(); - - pool.download(bufGradIn, gradInput, elemBytes); - pool.download(bufGradGamma, gradGamma, gammaBytes); - pool.download(bufGradBeta, gradBeta, gammaBytes); - - pool.release(bufGradOut); - pool.release(bufInput); - pool.release(bufGamma); - pool.release(bufMean); - pool.release(bufVar); - pool.release(bufGradIn); - pool.release(bufGradGamma); - pool.release(bufGradBeta); + // Stage-out: copy 3 grad buffers from DL → HOST_CACHED readback staging + batch.transferComputeBarrier(); + batch.copyBuffer(bufGradInDL, bufGradInStage, elemBytes); + batch.copyBuffer(bufGradGammaDL, bufGradGammaStage, gammaBytes); + batch.copyBuffer(bufGradBetaDL, bufGradBetaStage, gammaBytes); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufGradInStage, gradInput, elemBytes); + pool.download(bufGradGammaStage, gradGamma, gammaBytes); + pool.download(bufGradBetaStage, gradBeta, gammaBytes); + + pool.release(bufGradOutDL); + pool.release(bufInputDL); + pool.release(bufGammaDL); + pool.release(bufMeanDL); + pool.release(bufVarDL); + pool.release(bufGradInDL); + pool.release(bufGradGammaDL); + pool.release(bufGradBetaDL); + pool.release(bufGradOutStage); + pool.release(bufInputStage); + pool.release(bufGammaStage); + pool.release(bufMeanStage); + pool.release(bufVarStage); + pool.release(bufGradInStage); + pool.release(bufGradGammaStage); + pool.release(bufGradBetaStage); } } // 
namespace ops diff --git a/cpp/src/ops/learning.cpp b/cpp/src/ops/learning.cpp index 3c45e67..015116e 100644 --- a/cpp/src/ops/learning.cpp +++ b/cpp/src/ops/learning.cpp @@ -44,7 +44,8 @@ void ssmFusedMath(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, scanOut, dataBytes); @@ -82,7 +83,8 @@ void fisherInfoUpdate(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufFisher, fisher, bytes); @@ -124,7 +126,8 @@ void ewcPenalty(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufPenalty, penalty, bytes); @@ -170,7 +173,8 @@ void nlmsPredict(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufY, yPred, yBytes); @@ -224,7 +228,8 @@ void nlmsUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufW, w, featureBytes); pool.download(bufBias, bias, scalarBytes); @@ -302,7 +307,8 @@ void continuousToSpikes(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufSpikes, spikes, spikeBytes); @@ -376,7 +382,8 @@ void spikesToContinuous(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufFeat, features, featBytes); diff --git a/cpp/src/ops/linear.cpp b/cpp/src/ops/linear.cpp index 7f90d86..5dfd7f4 100644 --- a/cpp/src/ops/linear.cpp +++ b/cpp/src/ops/linear.cpp @@ -6,75 +6,222 @@ namespace grilly { namespace ops { -// ── GPU linear (port of fnn.py:1823-1976) ─────────────────────────────────── +// ── GPU linear with explicit DEVICE_LOCAL + staging pattern ──────────────── // -// In the Python backend, linear() makes ~12 ctypes FFI calls: -// acquire buffer × 4, upload × 3, get_or_create_pipeline, get_descriptor_set, -// dispatch_compute (which internally does: reset cmd, begin, bind pipeline, -// bind descriptors, push constants, dispatch, end, submit, wait fence), -// download, release × 4. +// On AMD/Windows even with Resizable BAR enabled, the DEVICE_LOCAL + +// HOST_VISIBLE memory type that VMA selects for ``BufferPool::acquire`` +// lands in WC-mapped memory that bypasses the GPU's L2 cache. Compute +// kernels reading from it run at ~0.05 GB/s — slower than a SATA SSD, +// roughly 0.04% of theoretical VRAM bandwidth (432 GB/s on RX 6750 XT). +// See sandbox/vsa_lm/grilly_gpu_path_test.py for the smoking-gun profile. // -// Here ALL of that is native C++ — zero Python crossings. 
The CommandBatch -// records everything into a single command buffer submission. +// The fix: compute buffers go through ``acquireDeviceLocal`` (DEVICE_LOCAL +// only, full cached VRAM, ~432 GB/s), and we move data in/out via small +// staging buffers from the regular pool. The staging buffers are slow for +// GPU compute reads but fine for ``vkCmdCopyBuffer`` transfers, which use +// the GPU's dedicated DMA engine and run at PCIe speed (~25 GB/s). +// +// All 3 staging-in copies, the compute dispatch, and the 1 staging-out +// copy are batched into a single command buffer with a single submit/wait, +// so the dispatch overhead is unchanged from the old fast-path. void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* x, const float* weights, const float* bias, - float* output, const LinearParams& p) { - // ── Buffer sizes ── - const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * sizeof(float); - const size_t weightBytes = size_t(p.outputDim) * p.inputDim * sizeof(float); + const void* x, const void* weights, const void* bias, + void* output, const LinearParams& p) { + // ── Byte sizes (dynamic — fp32 or fp16 determined by p.elemSize) ── + const uint32_t inElem = p.elemSize; // 2 for fp16, 4 for fp32 + const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * inElem; + const size_t weightBytes = size_t(p.outputDim) * p.inputDim * inElem; + // Bias is ALWAYS fp32 regardless of input dtype. The fp32 bias matches + // both fnn-linear's 3rd binding and gemm-bias-add's accumulator, and + // bias is small enough (outputDim floats) that the bandwidth cost of + // fp32 vs fp16 is negligible. const size_t biasBytes = p.hasBias ? size_t(p.outputDim) * sizeof(float) : sizeof(float); // dummy + // The output is ALWAYS fp32 regardless of input dtype — coopmat + // accumulator runs in fp32 for numerical stability, and fnn-linear + // also writes fp32. The Python binding converts back to fp16 if + // requested by the caller's dtype. const size_t outputBytes = size_t(p.batchSeq) * p.outputDim * sizeof(float); - // ── Acquire buffers (bucket-rounded, persistent mapping) ── - GrillyBuffer bufInput = pool.acquire(inputBytes); - GrillyBuffer bufWeights = pool.acquire(weightBytes); - GrillyBuffer bufBias = pool.acquire(biasBytes); - GrillyBuffer bufOutput = pool.acquire(outputBytes); + // ── Shader selection ── + // Coopmat requirements: + // - fp16 input (elemSize == 2) + // - device exposes VK_KHR_cooperative_matrix + // - the compiled SPIR-V is loaded in the pipeline cache + // - shape aligned to the shader's tile (M%16, K%16, N%64) + const bool shapeAligned = + (p.batchSeq % 16u == 0u) && + (p.inputDim % 16u == 0u) && + (p.outputDim % 64u == 0u); + const bool useCoopMat = + inElem == 2u && + cache.getDevice().hasCooperativeMatrix() && + cache.hasShader("gemm-coopmat-shared") && + shapeAligned; + + // fp16 input without a coopmat path is not supported in this function — + // the fallback fnn-linear shader is fp32-only. Callers must either use + // fp32 input or run on a device that supports cooperative matrix. 
+ if (inElem == 2u && !useCoopMat) { + throw std::runtime_error( + "linear(): fp16 input requested but cooperative matrix path is " + "unavailable (missing device support, shader, or shape " + "alignment — M%16, K%16, N%64 required)."); + } - // ── Upload via persistent mapping (single memcpy each, no vkMap/vkUnmap) ── - pool.upload(bufInput, x, inputBytes); - pool.upload(bufWeights, weights, weightBytes); + // ── Acquire DEVICE_LOCAL compute buffers (cached VRAM, fast GPU access) ── + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes); + GrillyBuffer bufWeightsDL = pool.acquireDeviceLocal(weightBytes); + GrillyBuffer bufBiasDL = pool.acquireDeviceLocal(biasBytes); + GrillyBuffer bufOutputDL = pool.acquireDeviceLocal(outputBytes); + + // ── Acquire host-visible staging buffers ── + // Stage-IN buffers (CPU writes only): WC memory is fast for sequential + // memcpy at ~9 GB/s — pool.acquire() is the right choice. + GrillyBuffer bufInputStage = pool.acquire(inputBytes); + GrillyBuffer bufWeightsStage = pool.acquire(weightBytes); + GrillyBuffer bufBiasStage = pool.acquire(biasBytes); + // Stage-OUT buffer (CPU reads from it): MUST be HOST_CACHED random-read + // memory. WC memory is uncached on the CPU side and a 19 MB readback + // memcpy ran at ~25 MB/s (749 ms — slower than the 9 ms GPU compute!). + // HOST_CACHED via acquireReadback gives ~7 GB/s for the same memcpy. + GrillyBuffer bufOutputStage = pool.acquireReadback(outputBytes); + + // ── memcpy CPU → staging (raw bytes, dtype-agnostic for x/weights) ── + pool.upload(bufInputStage, + reinterpret_cast(x), inputBytes); + pool.upload(bufWeightsStage, + reinterpret_cast(weights), weightBytes); if (p.hasBias && bias) { - pool.upload(bufBias, bias, p.outputDim * sizeof(float)); + // Bias is always fp32 (see biasBytes computation above). + pool.upload(bufBiasStage, + reinterpret_cast(bias), + size_t(p.outputDim) * sizeof(float)); + } + + // ── Get or create pipeline ── + const std::string shaderName = useCoopMat ? "gemm-coopmat-shared" + : "fnn-linear"; + // gemm-coopmat-shared has 3 bindings (A, B, C); push constants = 12 bytes. + // fnn-linear has 4 bindings (input, weights, bias, output); push 16 bytes. + const uint32_t numBindings = useCoopMat ? 3u : 4u; + const uint32_t pushSize = useCoopMat ? 12u : 16u; + PipelineEntry pipe = cache.getOrCreate(shaderName, numBindings, pushSize); + + // ── Allocate descriptor set ── + std::vector bufferInfos; + if (useCoopMat) { + bufferInfos = { + {bufInputDL.handle, 0, inputBytes}, + {bufWeightsDL.handle, 0, weightBytes}, + {bufOutputDL.handle, 0, outputBytes}, + }; + } else { + bufferInfos = { + {bufInputDL.handle, 0, inputBytes}, + {bufWeightsDL.handle, 0, weightBytes}, + {bufBiasDL.handle, 0, biasBytes}, + {bufOutputDL.handle, 0, outputBytes}, + }; + } + VkDescriptorSet descSet = cache.allocDescriptorSet(shaderName, bufferInfos); + + // Dispatch grid depends on the shader's output tile. + uint32_t gx, gy; + if (useCoopMat) { + // gemm-coopmat-shared writes a 16×64 (M×N) tile per workgroup. + gx = (p.outputDim + 63u) / 64u; + gy = (p.batchSeq + 15u) / 16u; + } else { + // fnn-linear writes a 16×16 tile per workgroup. + gx = (p.outputDim + 15u) / 16u; + gy = (p.batchSeq + 15u) / 16u; } - // ── Get or create pipeline (4 buffers, 16 bytes push constants) ── - PipelineEntry pipe = cache.getOrCreate("fnn-linear", 4, 16); + // ── Single command buffer: stage-in → barrier → compute → barrier → stage-out ── + batch.begin(); + + // Stage-in: DMA copy host-visible staging → DEVICE_LOCAL VRAM. 
+ // Bias goes to DL up front for both paths — fnn-linear reads it via + // binding 2, and the coopmat bias-add post-pass reads it via binding 1. + batch.copyBuffer(bufInputStage, bufInputDL, inputBytes); + batch.copyBuffer(bufWeightsStage, bufWeightsDL, weightBytes); + if (p.hasBias && bias) { + batch.copyBuffer(bufBiasStage, bufBiasDL, + size_t(p.outputDim) * sizeof(float)); + } - // ── Allocate descriptor set (LRU cached) ── - std::vector bufferInfos(4); - bufferInfos[0] = {bufInput.handle, 0, inputBytes}; - bufferInfos[1] = {bufWeights.handle, 0, weightBytes}; - bufferInfos[2] = {bufBias.handle, 0, biasBytes}; - bufferInfos[3] = {bufOutput.handle, 0, outputBytes}; + batch.transferComputeBarrier(); + + if (useCoopMat) { + // Coopmat push constants: {M, K, N} (12 bytes) + struct CoopPush { + uint32_t M; + uint32_t K; + uint32_t N; + } coopPush = {p.batchSeq, p.inputDim, p.outputDim}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &coopPush, sizeof(coopPush)); + } else { + // fnn-linear push constants: {batch, in, out, hasBias} (16 bytes) + struct FnnPush { + uint32_t batchSeq; + uint32_t inputDim; + uint32_t outputDim; + uint32_t hasBias; + } fnnPush = {p.batchSeq, p.inputDim, p.outputDim, p.hasBias}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &fnnPush, sizeof(fnnPush)); + } - VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-linear", bufferInfos); + // ── Bias post-pass (coopmat only; fnn-linear applies bias inline) ── + // Bias was already copied to bufBiasDL during the stage-in phase above, + // so the post-pass just needs a GEMM-write → bias-read barrier and a + // dispatch of the gemm-bias-add kernel. + if (useCoopMat && p.hasBias && bias && + cache.hasShader("gemm-bias-add")) { + batch.barrier(); // SHADER_WRITE (GEMM) → SHADER_READ (bias-add) + + // gemm-bias-add: 2 bindings (C, bias), 8 bytes push {totalElements, N} + PipelineEntry biasPipe = + cache.getOrCreate("gemm-bias-add", 2, 2 * sizeof(uint32_t)); + std::vector biasInfos = { + {bufOutputDL.handle, 0, outputBytes}, + {bufBiasDL.handle, 0, size_t(p.outputDim) * sizeof(float)}, + }; + VkDescriptorSet biasSet = + cache.allocDescriptorSet("gemm-bias-add", biasInfos); + struct BiasPush { + uint32_t totalElements; + uint32_t N; + } biasPush = {p.batchSeq * p.outputDim, p.outputDim}; + uint32_t biasGx = (biasPush.totalElements + 255u) / 256u; + batch.dispatch(biasPipe.pipeline, biasPipe.layout, biasSet, + biasGx, 1, 1, &biasPush, sizeof(biasPush)); + } - // ── Push constants: batch_seq, input_dim, output_dim, has_bias ── - // Matches fnn-linear.glsl layout (4 × uint32 = 16 bytes). - // Python packs these via struct.pack("IIII", ...) — we just memcpy the struct. 
- LinearParams pushData = p; + batch.transferComputeBarrier(); - // ── Dispatch ── - // 2D workgroups at 16×16 (must match fnn-linear.glsl local_size) - uint32_t gx = (p.outputDim + 15) / 16; - uint32_t gy = (p.batchSeq + 15) / 16; + // Stage-out: DMA copy DEVICE_LOCAL → host-visible HOST_CACHED staging + batch.copyBuffer(bufOutputDL, bufOutputStage, outputBytes); - batch.begin(); - batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, - &pushData, sizeof(pushData)); - batch.submit(); - - // ── Download result (persistent mapping — single memcpy, no vkMap) ── - pool.download(bufOutput, output, outputBytes); - - // ── Release buffers back to pool ── - pool.release(bufInput); - pool.release(bufWeights); - pool.release(bufBias); - pool.release(bufOutput); + batch.submitDeferred(); + batch.waitForCompletion(); + + // ── memcpy staging → CPU output (HOST_CACHED, ~7 GB/s) ── + // Output is always fp32, regardless of input dtype. + pool.download(bufOutputStage, reinterpret_cast(output), outputBytes); + + // ── Release buffers back to their respective pools ── + pool.release(bufInputDL); + pool.release(bufWeightsDL); + pool.release(bufBiasDL); + pool.release(bufOutputDL); + pool.release(bufInputStage); + pool.release(bufWeightsStage); + pool.release(bufBiasStage); + pool.release(bufOutputStage); } // ── CPU reference using Eigen (for correctness verification) ──────────────── @@ -122,35 +269,65 @@ std::vector linearCPU(const float* x, const float* weights, // Workgroups: 2D at (16,16) for passes 0 and 1, 1D for pass 2. void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* gradOutput, const float* input, - const float* weights, - float* gradInput, float* gradWeight, float* gradBias, + const void* gradOutput, const void* input, + const void* weights, + void* gradInput, void* gradWeight, void* gradBias, const LinearParams& p) { - const size_t gradOutBytes = size_t(p.batchSeq) * p.outputDim * sizeof(float); - const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * sizeof(float); - const size_t weightBytes = size_t(p.outputDim) * p.inputDim * sizeof(float); + // The fnn-linear-backward shader is fp32-only; reject fp16 input until a + // coopmat backward shader lands. The void* interface is in place so the + // switchover is local. + if (p.elemSize != 4u) { + throw std::runtime_error( + "linearBackward(): currently requires fp32 (elemSize=4). fp16 " + "backward needs a cooperative matrix backward shader — TODO."); + } + + // Dynamic byte calculation. With elemSize==4 today these match the old + // sizeof(float) computations, so existing callers see no behavior change. 
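+    // e.g. (illustrative sizes) batchSeq=64, inputDim=512, outputDim=512 with
+    // elemSize=4 gives gradOutBytes = 64 * 512 * 4 = 131072 bytes (128 KiB)
+    // and gradWBytes = 512 * 512 * 4 = 1 MiB; the staging buffers below are
+    // sized from the same expressions.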
+ const size_t gradOutBytes = size_t(p.batchSeq) * p.outputDim * p.elemSize; + const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * p.elemSize; + const size_t weightBytes = size_t(p.outputDim) * p.inputDim * p.elemSize; const size_t gradInBytes = inputBytes; const size_t gradWBytes = weightBytes; - const size_t gradBiasBytes = size_t(p.outputDim) * sizeof(float); - - GrillyBuffer bufGradOut = pool.acquire(gradOutBytes); - GrillyBuffer bufInput = pool.acquire(inputBytes); - GrillyBuffer bufWeights = pool.acquire(weightBytes); - GrillyBuffer bufGradIn = pool.acquire(gradInBytes); - GrillyBuffer bufGradW = pool.acquire(gradWBytes); - GrillyBuffer bufGradBias = pool.acquire(gradBiasBytes); - - pool.upload(bufGradOut, gradOutput, gradOutBytes); - pool.upload(bufInput, input, inputBytes); - pool.upload(bufWeights, weights, weightBytes); - - // Zero grad outputs - std::vector zerosIn(p.batchSeq * p.inputDim, 0.0f); - std::vector zerosW(p.outputDim * p.inputDim, 0.0f); - std::vector zerosB(p.outputDim, 0.0f); - pool.upload(bufGradIn, zerosIn.data(), gradInBytes); - pool.upload(bufGradW, zerosW.data(), gradWBytes); - pool.upload(bufGradBias, zerosB.data(), gradBiasBytes); + const size_t gradBiasBytes = size_t(p.outputDim) * p.elemSize; + + // Staging pattern: 3 stage-in (gradOut, input, weights), + // 3 stage-out (gradIn, gradW, gradBias). All compute on DEVICE_LOCAL. + GrillyBuffer bufGradOutDL = pool.acquireDeviceLocal(gradOutBytes); + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes); + GrillyBuffer bufWeightsDL = pool.acquireDeviceLocal(weightBytes); + GrillyBuffer bufGradInDL = pool.acquireDeviceLocal(gradInBytes); + GrillyBuffer bufGradWDL = pool.acquireDeviceLocal(gradWBytes); + GrillyBuffer bufGradBiasDL = pool.acquireDeviceLocal(gradBiasBytes); + + GrillyBuffer bufGradOutStage = pool.acquire(gradOutBytes); + GrillyBuffer bufInputStage = pool.acquire(inputBytes); + GrillyBuffer bufWeightsStage = pool.acquire(weightBytes); + GrillyBuffer bufGradInStage = pool.acquireReadback(gradInBytes); + GrillyBuffer bufGradWStage = pool.acquireReadback(gradWBytes); + GrillyBuffer bufGradBiasStage = pool.acquireReadback(gradBiasBytes); + + pool.upload(bufGradOutStage, + reinterpret_cast(gradOutput), gradOutBytes); + pool.upload(bufInputStage, + reinterpret_cast(input), inputBytes); + pool.upload(bufWeightsStage, + reinterpret_cast(weights), weightBytes); + + // The grad buffers must start at zero — pass 1 (grad_weight) and + // pass 2 (grad_bias) accumulate via atomic adds in the shader. Use + // raw byte vectors so zeroing works identically for fp32 and fp16 + // (whenever the fp16 backward shader lands). Reuse the readback stage + // buffers as upload-zeros source — HOST_CACHED, CPU-write is fine. 
+ std::vector zerosIn(gradInBytes, 0); + std::vector zerosW(gradWBytes, 0); + std::vector zerosB(gradBiasBytes, 0); + pool.upload(bufGradInStage, + reinterpret_cast(zerosIn.data()), gradInBytes); + pool.upload(bufGradWStage, + reinterpret_cast(zerosW.data()), gradWBytes); + pool.upload(bufGradBiasStage, + reinterpret_cast(zerosB.data()), gradBiasBytes); LinearBackwardParams bwdParams{p.batchSeq, p.inputDim, p.outputDim, 0}; @@ -158,18 +335,28 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, sizeof(LinearBackwardParams)); std::vector bufInfos = { - {bufGradOut.handle, 0, gradOutBytes}, - {bufInput.handle, 0, inputBytes}, - {bufWeights.handle, 0, weightBytes}, - {bufGradIn.handle, 0, gradInBytes}, - {bufGradW.handle, 0, gradWBytes}, - {bufGradBias.handle, 0, gradBiasBytes}, + {bufGradOutDL.handle, 0, gradOutBytes}, + {bufInputDL.handle, 0, inputBytes}, + {bufWeightsDL.handle, 0, weightBytes}, + {bufGradInDL.handle, 0, gradInBytes}, + {bufGradWDL.handle, 0, gradWBytes}, + {bufGradBiasDL.handle, 0, gradBiasBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-linear-backward", bufInfos); batch.begin(); + // Stage-in: copy all 6 staging buffers (3 inputs + 3 zeroed grads) to DL + batch.copyBuffer(bufGradOutStage, bufGradOutDL, gradOutBytes); + batch.copyBuffer(bufInputStage, bufInputDL, inputBytes); + batch.copyBuffer(bufWeightsStage, bufWeightsDL, weightBytes); + batch.copyBuffer(bufGradInStage, bufGradInDL, gradInBytes); + batch.copyBuffer(bufGradWStage, bufGradWDL, gradWBytes); + batch.copyBuffer(bufGradBiasStage, bufGradBiasDL, gradBiasBytes); + + batch.transferComputeBarrier(); + // Pass 0: grad_input bwdParams.passType = 0; uint32_t gx0 = (p.inputDim + 15) / 16; @@ -192,18 +379,35 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &bwdParams, sizeof(bwdParams)); - batch.submit(); - - pool.download(bufGradIn, gradInput, gradInBytes); - pool.download(bufGradW, gradWeight, gradWBytes); - pool.download(bufGradBias, gradBias, gradBiasBytes); - - pool.release(bufGradOut); - pool.release(bufInput); - pool.release(bufWeights); - pool.release(bufGradIn); - pool.release(bufGradW); - pool.release(bufGradBias); + batch.transferComputeBarrier(); + + // Stage-out: copy 3 grad buffers from DL → HOST_CACHED readback staging + batch.copyBuffer(bufGradInDL, bufGradInStage, gradInBytes); + batch.copyBuffer(bufGradWDL, bufGradWStage, gradWBytes); + batch.copyBuffer(bufGradBiasDL, bufGradBiasStage, gradBiasBytes); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufGradInStage, + reinterpret_cast(gradInput), gradInBytes); + pool.download(bufGradWStage, + reinterpret_cast(gradWeight), gradWBytes); + pool.download(bufGradBiasStage, + reinterpret_cast(gradBias), gradBiasBytes); + + pool.release(bufGradOutDL); + pool.release(bufInputDL); + pool.release(bufWeightsDL); + pool.release(bufGradInDL); + pool.release(bufGradWDL); + pool.release(bufGradBiasDL); + pool.release(bufGradOutStage); + pool.release(bufInputStage); + pool.release(bufWeightsStage); + pool.release(bufGradInStage); + pool.release(bufGradWStage); + pool.release(bufGradBiasStage); } // ── GPU dropout ────────────────────────────────────────────────────────── @@ -213,20 +417,25 @@ void dropout(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, uint32_t totalElements, float dropoutProb, bool isTraining) { const size_t bytes = size_t(totalElements) * sizeof(float); - 
GrillyBuffer bufInput = pool.acquire(bytes); - GrillyBuffer bufRandom = pool.acquire(bytes); - GrillyBuffer bufOutput = pool.acquire(bytes); + // Staging pattern: 2 stage-in (input, randomMask), 1 stage-out (output) + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufRandomDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufOutputDL = pool.acquireDeviceLocal(bytes); + + GrillyBuffer bufInputStage = pool.acquire(bytes); + GrillyBuffer bufRandomStage = pool.acquire(bytes); + GrillyBuffer bufOutputStage = pool.acquireReadback(bytes); - pool.upload(bufInput, input, bytes); - pool.upload(bufRandom, randomMask, bytes); + pool.upload(bufInputStage, input, bytes); + pool.upload(bufRandomStage, randomMask, bytes); PipelineEntry pipe = cache.getOrCreate("fnn-dropout", 3, sizeof(DropoutParams)); std::vector bufInfos = { - {bufInput.handle, 0, bytes}, - {bufRandom.handle, 0, bytes}, - {bufOutput.handle, 0, bytes}, + {bufInputDL.handle, 0, bytes}, + {bufRandomDL.handle, 0, bytes}, + {bufOutputDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-dropout", bufInfos); @@ -234,15 +443,24 @@ void dropout(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, uint32_t gx = (totalElements + 255) / 256; batch.begin(); + batch.copyBuffer(bufInputStage, bufInputDL, bytes); + batch.copyBuffer(bufRandomStage, bufRandomDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); - - pool.download(bufOutput, output, bytes); - - pool.release(bufInput); - pool.release(bufRandom); - pool.release(bufOutput); + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutputDL, bufOutputStage, bytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOutputStage, output, bytes); + + pool.release(bufInputDL); + pool.release(bufRandomDL); + pool.release(bufOutputDL); + pool.release(bufInputStage); + pool.release(bufRandomStage); + pool.release(bufOutputStage); } } // namespace ops diff --git a/cpp/src/ops/loss.cpp b/cpp/src/ops/loss.cpp index 866431d..6c9fcbf 100644 --- a/cpp/src/ops/loss.cpp +++ b/cpp/src/ops/loss.cpp @@ -24,24 +24,30 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, const size_t lossBytes = size_t(totalPositions) * sizeof(float); const size_t auxBytes = size_t(totalPositions) * sizeof(float); - GrillyBuffer bufLogits = pool.acquire(logitBytes); - GrillyBuffer bufTarget = pool.acquire(targetBytes); - GrillyBuffer bufLoss = pool.acquire(lossBytes); - GrillyBuffer bufMax = pool.acquire(auxBytes); - GrillyBuffer bufSumExp = pool.acquire(auxBytes); + // Staging pattern: 2 stage-in (logits, targets), 1 stage-out (losses). + // max and sumExp are intermediate DL-only buffers (CPU never sees them). 
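+    // Since max/sumExp never leave the GPU, they get no staging buffers and
+    // no copyBuffer in either direction below; only logits/targets (stage-in)
+    // and losses (stage-out) pay the PCIe transfer cost.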
+ GrillyBuffer bufLogitsDL = pool.acquireDeviceLocal(logitBytes); + GrillyBuffer bufTargetDL = pool.acquireDeviceLocal(targetBytes); + GrillyBuffer bufLossDL = pool.acquireDeviceLocal(lossBytes); + GrillyBuffer bufMaxDL = pool.acquireDeviceLocal(auxBytes); + GrillyBuffer bufSumExpDL = pool.acquireDeviceLocal(auxBytes); - pool.upload(bufLogits, logits, logitBytes); - pool.upload(bufTarget, reinterpret_cast(targets), targetBytes); + GrillyBuffer bufLogitsStage = pool.acquire(logitBytes); + GrillyBuffer bufTargetStage = pool.acquire(targetBytes); + GrillyBuffer bufLossStage = pool.acquireReadback(lossBytes); + + pool.upload(bufLogitsStage, logits, logitBytes); + pool.upload(bufTargetStage, reinterpret_cast(targets), targetBytes); PipelineEntry pipe = cache.getOrCreate("loss-cross-entropy", 5, sizeof(CrossEntropyParams)); std::vector bufInfos = { - {bufLogits.handle, 0, logitBytes}, - {bufTarget.handle, 0, targetBytes}, - {bufLoss.handle, 0, lossBytes}, - {bufMax.handle, 0, auxBytes}, - {bufSumExp.handle, 0, auxBytes}, + {bufLogitsDL.handle, 0, logitBytes}, + {bufTargetDL.handle, 0, targetBytes}, + {bufLossDL.handle, 0, lossBytes}, + {bufMaxDL.handle, 0, auxBytes}, + {bufSumExpDL.handle, 0, auxBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("loss-cross-entropy", bufInfos); @@ -49,6 +55,9 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, uint32_t gx = (totalPositions + 255) / 256; batch.begin(); + batch.copyBuffer(bufLogitsStage, bufLogitsDL, logitBytes); + batch.copyBuffer(bufTargetStage, bufTargetDL, targetBytes); + batch.transferComputeBarrier(); // Pass 0: find max logit per position CrossEntropyParams push0 = p; @@ -70,15 +79,22 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.transferComputeBarrier(); + batch.copyBuffer(bufLossDL, bufLossStage, lossBytes); + + batch.submitDeferred(); + batch.waitForCompletion(); - pool.download(bufLoss, losses, lossBytes); + pool.download(bufLossStage, losses, lossBytes); - pool.release(bufLogits); - pool.release(bufTarget); - pool.release(bufLoss); - pool.release(bufMax); - pool.release(bufSumExp); + pool.release(bufLogitsDL); + pool.release(bufTargetDL); + pool.release(bufLossDL); + pool.release(bufMaxDL); + pool.release(bufSumExpDL); + pool.release(bufLogitsStage); + pool.release(bufTargetStage); + pool.release(bufLossStage); } // ── Cross-entropy backward ─────────────────────────────────────────────── @@ -91,20 +107,25 @@ void crossEntropyBackward(CommandBatch& batch, BufferPool& pool, const size_t logitBytes = size_t(p.batchSize) * p.numClasses * sizeof(float); const size_t targetBytes = size_t(p.batchSize) * sizeof(uint32_t); - GrillyBuffer bufLogits = pool.acquire(logitBytes); - GrillyBuffer bufTarget = pool.acquire(targetBytes); - GrillyBuffer bufGrad = pool.acquire(logitBytes); + // Staging pattern: 2 stage-in (logits, targets), 1 stage-out (gradLogits) + GrillyBuffer bufLogitsDL = pool.acquireDeviceLocal(logitBytes); + GrillyBuffer bufTargetDL = pool.acquireDeviceLocal(targetBytes); + GrillyBuffer bufGradDL = pool.acquireDeviceLocal(logitBytes); + + GrillyBuffer bufLogitsStage = pool.acquire(logitBytes); + GrillyBuffer bufTargetStage = pool.acquire(targetBytes); + GrillyBuffer bufGradStage = pool.acquireReadback(logitBytes); - pool.upload(bufLogits, logits, logitBytes); - pool.upload(bufTarget, reinterpret_cast(targets), targetBytes); + pool.upload(bufLogitsStage, logits, logitBytes); + 
pool.upload(bufTargetStage, reinterpret_cast(targets), targetBytes); PipelineEntry pipe = cache.getOrCreate("cross-entropy-backward", 3, sizeof(CrossEntropyBackwardParams)); std::vector bufInfos = { - {bufLogits.handle, 0, logitBytes}, - {bufTarget.handle, 0, targetBytes}, - {bufGrad.handle, 0, logitBytes}, + {bufLogitsDL.handle, 0, logitBytes}, + {bufTargetDL.handle, 0, targetBytes}, + {bufGradDL.handle, 0, logitBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("cross-entropy-backward", bufInfos); @@ -112,15 +133,24 @@ void crossEntropyBackward(CommandBatch& batch, BufferPool& pool, uint32_t gx = (p.batchSize + 255) / 256; batch.begin(); + batch.copyBuffer(bufLogitsStage, bufLogitsDL, logitBytes); + batch.copyBuffer(bufTargetStage, bufTargetDL, targetBytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); - - pool.download(bufGrad, gradLogits, logitBytes); - - pool.release(bufLogits); - pool.release(bufTarget); - pool.release(bufGrad); + batch.transferComputeBarrier(); + batch.copyBuffer(bufGradDL, bufGradStage, logitBytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufGradStage, gradLogits, logitBytes); + + pool.release(bufLogitsDL); + pool.release(bufTargetDL); + pool.release(bufGradDL); + pool.release(bufLogitsStage); + pool.release(bufTargetStage); + pool.release(bufGradStage); } } // namespace ops diff --git a/cpp/src/ops/moe_forward.cpp b/cpp/src/ops/moe_forward.cpp new file mode 100644 index 0000000..6802e8c --- /dev/null +++ b/cpp/src/ops/moe_forward.cpp @@ -0,0 +1,752 @@ +/// moe_forward.cpp — fused MoE forward on GPU (router/blend CPU) + CPU backward. +#include "grilly/ops/moe_forward.h" +#include "grilly/ops/batched_ops.h" + +#include +#include +#include +#include +#include + +namespace grilly { +namespace ops { + +namespace { + +static std::unordered_map g_moe; +static int g_next_moe = 1; + +static void transpose_dd(const float* w, float* wt, uint32_t d) { + for (uint32_t r = 0; r < d; ++r) + for (uint32_t c = 0; c < d; ++c) + wt[c * d + r] = w[r * d + c]; +} + +static void transpose_vd(const float* w, float* wt, uint32_t v, uint32_t d) { + for (uint32_t r = 0; r < v; ++r) + for (uint32_t c = 0; c < d; ++c) + wt[c * v + r] = w[r * d + c]; +} + +static Eigen::VectorXf softmax_vec(const Eigen::VectorXf& x) { + float m = x.maxCoeff(); + Eigen::VectorXf e = (x.array() - m).exp(); + return e / e.sum(); +} + +static Eigen::VectorXf softmax_grad(const Eigen::VectorXf& p, + const Eigen::VectorXf& grad_p) { + float s = p.dot(grad_p); + return p.cwiseProduct(grad_p - Eigen::VectorXf::Constant(p.size(), s)); +} + +} // namespace + +MoeHandleCache& moe_get_cache(int handle) { + auto it = g_moe.find(handle); + if (it == g_moe.end()) + throw std::runtime_error("Invalid moe handle"); + return it->second; +} + +void moe_release(BufferPool& pool, int handle) { + auto it = g_moe.find(handle); + if (it == g_moe.end()) + return; + MoeHandleCache& h = it->second; + + pool.release(h.embedW); + pool.release(h.posW); + pool.release(h.outW); + pool.release(h.outWt); + for (auto& lw : h.layers) { + pool.release(lw.routerW); + pool.release(lw.routerB); + for (uint32_t e = 0; e < lw.expertW.size(); ++e) { + pool.release(lw.expertW[e]); + pool.release(lw.expertWt[e]); + } + pool.release(lw.expertPacked); + } + pool.release(h.bufIds); + pool.release(h.bufPosSlice); + pool.release(h.bufX); + for (auto& b : h.bufExpertOut) + pool.release(b); + pool.release(h.bufBlended); + 
pool.release(h.bufLogits); + + g_moe.erase(it); +} + +static void upload_dd_pair(BufferPool& pool, GrillyBuffer& wbuf, GrillyBuffer& wtbuf, + const float* w, uint32_t d) { + size_t bytes = size_t(d) * d * sizeof(float); + wbuf = pool.acquire(bytes); + pool.upload(wbuf, w, bytes); + std::vector wt(d * d); + transpose_dd(w, wt.data(), d); + wtbuf = pool.acquire(bytes); + pool.upload(wtbuf, wt.data(), bytes); +} + +int moe_upload(BufferPool& pool, + uint32_t vocab_size, uint32_t d_model, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w, + uint32_t n_layers, uint32_t n_experts) { + + if (n_layers == 0 || n_experts == 0 || d_model == 0 || vocab_size == 0 || max_seq == 0) + throw std::runtime_error("moe_upload: invalid dimensions"); + if (expert_ws.size() != size_t(n_layers) * n_experts) + throw std::runtime_error("moe_upload: expert_ws length mismatch"); + if (router_ws.size() != n_layers || router_bs.size() != n_layers) + throw std::runtime_error("moe_upload: router list length mismatch"); + + MoeHandleCache h; + h.vocab = vocab_size; + h.d = d_model; + h.maxSeq = max_seq; + h.nLayers = n_layers; + h.nExperts = n_experts; + + size_t embed_bytes = size_t(vocab_size) * d_model * sizeof(float); + h.cpu_embed.assign(embed_bytes / sizeof(float), 0.f); + std::memcpy(h.cpu_embed.data(), embed_w, embed_bytes); + h.embedW = pool.acquire(embed_bytes); + pool.upload(h.embedW, embed_w, embed_bytes); + + size_t pos_bytes = size_t(max_seq) * d_model * sizeof(float); + h.cpu_pos.assign(pos_bytes / sizeof(float), 0.f); + std::memcpy(h.cpu_pos.data(), pos_w, pos_bytes); + h.posW = pool.acquire(pos_bytes); + pool.upload(h.posW, pos_w, pos_bytes); + + size_t out_bytes = size_t(vocab_size) * d_model * sizeof(float); + h.cpu_out_w.assign(out_bytes / sizeof(float), 0.f); + std::memcpy(h.cpu_out_w.data(), out_w, out_bytes); + h.outW = pool.acquire(out_bytes); + pool.upload(h.outW, out_w, out_bytes); + std::vector out_wt(d_model * vocab_size); + transpose_vd(out_w, out_wt.data(), vocab_size, d_model); + h.outWt = pool.acquire(out_wt.size() * sizeof(float)); + pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + h.cpu_expert_w.resize(size_t(n_layers) * n_experts); + h.layers.resize(n_layers); + for (uint32_t l = 0; l < n_layers; ++l) { + auto& lw = h.layers[l]; + size_t rbytes = size_t(n_experts) * d_model * sizeof(float); + lw.routerW = pool.acquire(rbytes); + pool.upload(lw.routerW, router_ws[l], rbytes); + lw.routerB = pool.acquire(n_experts * sizeof(float)); + pool.upload(lw.routerB, router_bs[l], n_experts * sizeof(float)); + + h.cpu_router_w.emplace_back(n_experts * d_model); + std::memcpy(h.cpu_router_w.back().data(), router_ws[l], rbytes); + h.cpu_router_b.emplace_back(n_experts); + std::memcpy(h.cpu_router_b.back().data(), router_bs[l], + n_experts * sizeof(float)); + + lw.expertW.resize(n_experts); + lw.expertWt.resize(n_experts); + // Pack all experts contiguously for fused shader + size_t packed_size = size_t(n_experts) * d_model * d_model * sizeof(float); + std::vector packed(n_experts * d_model * d_model); + for (uint32_t e = 0; e < n_experts; ++e) { + const float* w = expert_ws[l * n_experts + e]; + upload_dd_pair(pool, lw.expertW[e], lw.expertWt[e], w, d_model); + h.cpu_expert_w[l * n_experts + e].assign(d_model * d_model, 0.f); + std::memcpy(h.cpu_expert_w[l * n_experts + e].data(), w, + d_model * d_model * sizeof(float)); + std::memcpy(packed.data() + e 
* d_model * d_model, w, + d_model * d_model * sizeof(float)); + } + lw.expertPacked = pool.acquire(packed_size); + pool.upload(lw.expertPacked, packed.data(), packed_size); + } + + size_t seq_d = size_t(max_seq) * d_model * sizeof(float); + h.bufIds = pool.acquire(max_seq * sizeof(uint32_t)); + h.bufPosSlice = pool.acquire(seq_d); + h.bufX = pool.acquire(seq_d); + h.bufBlended = pool.acquire(seq_d); + + // Activation buffers for backward (n_layers + 1: input to each layer + final) + h.bufActivations.resize(n_layers + 1); + for (uint32_t l = 0; l <= n_layers; ++l) + h.bufActivations[l] = pool.acquire(seq_d); + h.fwd_router_weights.resize(n_layers); + + h.bufExpertOut.resize(n_experts); + for (uint32_t e = 0; e < n_experts; ++e) + h.bufExpertOut[e] = pool.acquire(seq_d); + h.bufLogits = pool.acquire(size_t(max_seq) * vocab_size * sizeof(float)); + + int hid = g_next_moe++; + g_moe[hid] = std::move(h); + return hid; +} + +void moe_update_weights(BufferPool& pool, MoeHandleCache& h, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w) { + + uint32_t V = h.vocab; + uint32_t d = h.d; + uint32_t max_seq = h.maxSeq; + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + + size_t embed_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_embed.data(), embed_w, embed_bytes); + pool.upload(h.embedW, embed_w, embed_bytes); + + size_t pos_bytes = size_t(max_seq) * d * sizeof(float); + std::memcpy(h.cpu_pos.data(), pos_w, pos_bytes); + pool.upload(h.posW, pos_w, pos_bytes); + + size_t out_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_out_w.data(), out_w, out_bytes); + pool.upload(h.outW, out_w, out_bytes); + std::vector out_wt(d * V); + transpose_vd(out_w, out_wt.data(), V, d); + pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + for (uint32_t l = 0; l < L; ++l) { + auto& lw = h.layers[l]; + size_t rbytes = size_t(E) * d * sizeof(float); + pool.upload(lw.routerW, router_ws[l], rbytes); + pool.upload(lw.routerB, router_bs[l], E * sizeof(float)); + std::memcpy(h.cpu_router_w[l].data(), router_ws[l], rbytes); + std::memcpy(h.cpu_router_b[l].data(), router_bs[l], E * sizeof(float)); + + size_t packed_size = size_t(E) * d * d * sizeof(float); + std::vector packed(E * d * d); + for (uint32_t e = 0; e < E; ++e) { + const float* w = expert_ws[l * E + e]; + size_t dd = size_t(d) * d * sizeof(float); + pool.upload(lw.expertW[e], w, dd); + std::vector wt(d * d); + transpose_dd(w, wt.data(), d); + pool.upload(lw.expertWt[e], wt.data(), dd); + std::memcpy(h.cpu_expert_w[l * E + e].data(), w, dd); + std::memcpy(packed.data() + e * d * d, w, dd); + } + pool.upload(lw.expertPacked, packed.data(), packed_size); + } +} + +void moe_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + float* logits_out) { + + if (seq_len == 0 || seq_len > h.maxSeq) + throw std::runtime_error("moe_forward: invalid seq_len"); + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t V = h.vocab; + + pool.upload(h.bufIds, reinterpret_cast(input_ids), + S * sizeof(uint32_t)); + pool.upload(h.bufPosSlice, h.cpu_pos.data(), S * d * sizeof(float)); + + batch.begin(); + batchedEmbeddingLookup(batch, cache, h.bufIds, h.embedW, h.bufX, 1, S, h.vocab, d); + batch.barrier(); + batchedAdd(batch, cache, h.bufX, h.bufPosSlice, S * d); + batch.submitDeferred(); + batch.waitForCompletion(); + + size_t packed_size = 
size_t(h.nExperts) * d * d * sizeof(float); + bool has_router = cache.hasShader("moe-router"); + // Prefer vec4 path, fall back to scalar fused, then legacy + bool has_vec4 = cache.hasShader("moe-layer-fused-vec4") && (d % 4 == 0); + bool has_fused = has_vec4 || cache.hasShader("moe-layer-fused"); + + if (has_fused && has_router && h.nExperts == 4) { + // ══════════════════════════════════════════════════════════════ + // ALL-GPU path: router + experts in ONE command buffer submission + // ZERO CPU round-trips between layers. ONE fence wait total. + // Vec4 path: 4x memory bandwidth via 128-bit loads. + // ══════════════════════════════════════════════════════════════ + + size_t scratch_size = (d + h.nExperts + h.nExperts) * sizeof(float); + GrillyBuffer bufScratch = pool.acquire(scratch_size); + uint32_t weights_offset = d + h.nExperts; + + struct RouterPush { uint32_t seq_len, d_model, n_experts, pass; }; + struct FusedPush { uint32_t seq_len, d_model, n_experts, weights_offset; }; + + PipelineEntry routerPipe = cache.getOrCreate("moe-router", 4, sizeof(RouterPush)); + + // Choose vec4 or scalar fused shader + const char* fused_name = has_vec4 ? "moe-layer-fused-vec4" : "moe-layer-fused"; + PipelineEntry fusedPipe = cache.getOrCreate(fused_name, 4, sizeof(FusedPush)); + + batch.begin(); + + // Save initial activation + batch.copyBuffer(h.bufX, h.bufActivations[0], S * d * sizeof(float)); + batch.barrier(); + + for (uint32_t l = 0; l < h.nLayers; ++l) { + auto& layer = h.layers[l]; + + // Router: 3 passes (mean, logits, softmax) + for (uint32_t pass = 0; pass < 3; ++pass) { + std::vector rbufs = { + {h.bufX.handle, 0, S * d * sizeof(float)}, + {layer.routerW.handle, 0, h.nExperts * d * sizeof(float)}, + {layer.routerB.handle, 0, h.nExperts * sizeof(float)}, + {bufScratch.handle, 0, scratch_size}, + }; + VkDescriptorSet rdesc = cache.allocDescriptorSet("moe-router", rbufs); + RouterPush rp{S, d, h.nExperts, pass}; + uint32_t wg = (pass == 0) ? (d + 255) / 256 : + (pass == 1) ? (h.nExperts + 255) / 256 : 1; + batch.dispatch(routerPipe.pipeline, routerPipe.layout, rdesc, + wg, 1, 1, &rp, sizeof(rp)); + batch.barrier(); + } + + // Fused expert layer (vec4 or scalar) + { + FusedPush fp{S, d, h.nExperts, weights_offset}; + std::vector fbufs = { + {h.bufX.handle, 0, S * d * sizeof(float)}, + {layer.expertPacked.handle, 0, packed_size}, + {h.bufBlended.handle, 0, S * d * sizeof(float)}, + {bufScratch.handle, 0, scratch_size}, + }; + VkDescriptorSet fdesc = cache.allocDescriptorSet(fused_name, fbufs); + + uint32_t gx, gy; + if (has_vec4) { + gx = ((d / 4) + 15) / 16; // vec4 columns + gy = (S + 15) / 16; + } else { + gx = (d + 15) / 16; + gy = (S + 15) / 16; + } + batch.dispatch(fusedPipe.pipeline, fusedPipe.layout, fdesc, + gx, gy, 1, &fp, sizeof(fp)); + batch.barrier(); + } + + std::swap(h.bufX, h.bufBlended); + + // Save activation for backward + batch.copyBuffer(h.bufX, h.bufActivations[l + 1], S * d * sizeof(float)); + batch.barrier(); + } + + // Output projection (same command buffer!) 
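+        // Shapes: h.bufX holds the (S, d) activations and h.outW the (V, d)
+        // output matrix, so this batchedLinear produces (S, V) logits, the
+        // same S * V floats h.bufLogits was sized for in moe_upload.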
+ batchedLinear(batch, cache, h.bufX, h.outW, nullptr, h.bufLogits, S, d, V); + + // ONE submit, ONE fence wait for entire forward pass + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.release(bufScratch); + } else { + // ══════════════════════════════════════════════════════════════ + // Legacy CPU-router path (fallback) + // ══════════════════════════════════════════════════════════════ + std::vector x_cpu(S * d); + std::vector x_mean(d); + + for (uint32_t l = 0; l < h.nLayers; ++l) { + pool.download(h.bufX, x_cpu.data(), S * d * sizeof(float)); + for (uint32_t j = 0; j < d; ++j) { + float s = 0.f; + for (uint32_t i = 0; i < S; ++i) + s += x_cpu[i * d + j]; + x_mean[j] = s / float(S); + } + + Eigen::Map xmean(x_mean.data(), d); + Eigen::Map> Wr( + h.cpu_router_w[l].data(), h.nExperts, d); + Eigen::Map b(h.cpu_router_b[l].data(), h.nExperts); + Eigen::VectorXf logits_v = Wr * xmean + b; + Eigen::VectorXf p = softmax_vec(logits_v); + + auto& layer = h.layers[l]; + batch.begin(); + for (uint32_t e = 0; e < h.nExperts; ++e) { + batchedLinear(batch, cache, h.bufX, layer.expertW[e], nullptr, + h.bufExpertOut[e], S, d, d); + } + batch.submitDeferred(); + batch.waitForCompletion(); + + std::vector expert_flat(S * d * h.nExperts); + for (uint32_t e = 0; e < h.nExperts; ++e) + pool.download(h.bufExpertOut[e], expert_flat.data() + e * S * d, + S * d * sizeof(float)); + + std::vector blended(S * d, 0.f); + for (uint32_t e = 0; e < h.nExperts; ++e) { + float pe = p[e]; + for (size_t i = 0; i < S * d; ++i) + blended[i] += pe * expert_flat[e * S * d + i]; + } + pool.upload(h.bufBlended, blended.data(), S * d * sizeof(float)); + + batch.begin(); + batchedAdd(batch, cache, h.bufX, h.bufBlended, S * d); + batch.submitDeferred(); + batch.waitForCompletion(); + } + } + + size_t log_bytes = S * V * sizeof(float); + batch.begin(); + batchedLinear(batch, cache, h.bufX, h.outW, nullptr, h.bufLogits, S, d, V); + batch.submitDeferred(); + batch.waitForCompletion(); + pool.download(h.bufLogits, logits_out, log_bytes); +} + +MoeGradients moe_backward_cpu(const MoeHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits) { + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + if (S == 0 || S > h.maxSeq) + throw std::runtime_error("moe_backward: invalid seq_len"); + + using RowMajor = Eigen::Matrix; + + Eigen::MatrixXf X = Eigen::MatrixXf::Zero(S, d); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + if (tok < 0 || static_cast(tok) >= V) + throw std::runtime_error("moe_backward: token id out of range"); + for (uint32_t j = 0; j < d; ++j) { + X(s, j) = h.cpu_embed[tok * d + j] + h.cpu_pos[s * d + j]; + } + } + + std::vector Xs; + Xs.reserve(L + 1); + Xs.push_back(X); + + struct LayerTrace { + Eigen::VectorXf p; + Eigen::VectorXf xmean; + std::vector Y; + }; + std::vector trace(L); + + for (uint32_t l = 0; l < L; ++l) { + Eigen::MatrixXf& cur = Xs.back(); + Eigen::VectorXf xmean = cur.colwise().mean().transpose(); + + Eigen::Map> Wr( + h.cpu_router_w[l].data(), E, d); + Eigen::Map b(h.cpu_router_b[l].data(), E); + Eigen::VectorXf logits = Wr * xmean + b; + Eigen::VectorXf p = softmax_vec(logits); + + Eigen::MatrixXf blend = Eigen::MatrixXf::Zero(S, d); + trace[l].Y.resize(E); + for (uint32_t e = 0; e < E; ++e) { + Eigen::Map We(h.cpu_expert_w[l * E + e].data(), d, d); + trace[l].Y[e] = cur * We.transpose(); + blend += p[e] * trace[l].Y[e]; + } + trace[l].p = p; + trace[l].xmean = 
xmean; + + Xs.push_back(cur + blend); + } + + Eigen::MatrixXf Xfinal = Xs.back(); + + Eigen::Map Glog(grad_logits, S, V); + Eigen::Map Wo(h.cpu_out_w.data(), V, d); + + MoeGradients out; + out.grad_out_w.resize(V * d); + Eigen::MatrixXf grad_Wo_mat = Glog.transpose() * Xfinal; + Eigen::Map gwo(out.grad_out_w.data(), V, d); + gwo = grad_Wo_mat; + + out.grad_router_w.resize(L); + out.grad_router_b.resize(L); + out.grad_experts.resize(L * E); + for (auto& v : out.grad_router_w) + v.assign(E * d, 0.f); + for (auto& v : out.grad_router_b) + v.assign(E, 0.f); + for (auto& v : out.grad_experts) + v.assign(d * d, 0.f); + + // logits = Xfinal * Wo^T where Wo is (V, d), so grad_X = grad_logits * Wo. + Eigen::MatrixXf g = Glog * Wo; + + for (int li = static_cast(L) - 1; li >= 0; --li) { + uint32_t l = static_cast(li); + const Eigen::MatrixXf& cur = Xs[l]; + const LayerTrace& tr = trace[l]; + + Eigen::Map> Wr( + h.cpu_router_w[l].data(), E, d); + + Eigen::MatrixXf grad_blend = g; + Eigen::MatrixXf grad_Xl = Eigen::MatrixXf::Zero(S, d); + grad_Xl += grad_blend; + + Eigen::VectorXf grad_p(E); + for (uint32_t e = 0; e < E; ++e) + grad_p[e] = (tr.Y[e].array() * grad_blend.array()).sum(); + + Eigen::VectorXf grad_logits_r = softmax_grad(tr.p, grad_p); + Eigen::VectorXf grad_xm = Wr.transpose() * grad_logits_r; + + for (uint32_t s = 0; s < S; ++s) + grad_Xl.row(s) += (1.0f / float(S)) * grad_xm.transpose(); + + Eigen::Map> gWr( + out.grad_router_w[l].data(), E, d); + gWr = grad_logits_r * tr.xmean.transpose(); + + Eigen::Map gb(out.grad_router_b[l].data(), E); + gb = grad_logits_r; + + for (uint32_t e = 0; e < E; ++e) { + Eigen::MatrixXf grad_Y = tr.p[e] * grad_blend; + Eigen::Map We(h.cpu_expert_w[l * E + e].data(), d, d); + grad_Xl += grad_Y * We; + + Eigen::MatrixXf grad_We = grad_Y.transpose() * cur; + std::memcpy(out.grad_experts[l * E + e].data(), grad_We.data(), + d * d * sizeof(float)); + } + + g = grad_Xl; + } + + out.grad_embed.assign(h.vocab * d, 0.f); + out.grad_pos.assign(h.maxSeq * d, 0.f); + + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + for (uint32_t j = 0; j < d; ++j) { + out.grad_embed[tok * d + j] += g(s, j); + out.grad_pos[s * d + j] = g(s, j); + } + } + + return out; +} + +MoeGradients moe_backward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits) { + + // GPU backward shaders exist but Eigen CPU is faster for now + // (strided memory access pattern in backward shader kills GPU perf). + // TODO: transpose expert weights for backward-friendly layout. 
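+    // NOTE: the unconditional return below short-circuits to the CPU
+    // backward; the GPU backward code after it stays compiled but is
+    // currently unreachable until the weight-layout TODO above is resolved.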
+ return moe_backward_cpu(h, input_ids, seq_len, grad_logits); + + bool has_bwd_vec4 = cache.hasShader("moe-layer-backward-vec4") && (h.d % 4 == 0); + bool has_bwd = has_bwd_vec4 || cache.hasShader("moe-layer-backward"); + bool has_gw = cache.hasShader("moe-layer-grad-weight"); + + if (!has_bwd) + return moe_backward_cpu(h, input_ids, seq_len, grad_logits); + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + size_t sd = S * d * sizeof(float); + size_t packed_size = size_t(E) * d * d * sizeof(float); + + // Re-compute router weights on CPU (tiny) + std::vector> router_p(L); + for (uint32_t l = 0; l < L; ++l) { + Eigen::Map> Wr( + h.cpu_router_w[l].data(), E, d); + Eigen::Map b(h.cpu_router_b[l].data(), E); + + std::vector act(S * d); + pool.download(h.bufActivations[l], act.data(), sd); + Eigen::VectorXf xmean = Eigen::VectorXf::Zero(d); + for (uint32_t i = 0; i < S; ++i) + for (uint32_t j = 0; j < d; ++j) + xmean[j] += act[i * d + j]; + xmean /= float(S); + + Eigen::VectorXf logits_v = Wr * xmean + b; + Eigen::VectorXf p = softmax_vec(logits_v); + router_p[l].resize(E); + for (uint32_t e = 0; e < E; ++e) + router_p[l][e] = p[e]; + } + + // Upload grad_logits + GrillyBuffer bufGL = pool.acquire(S * V * sizeof(float)); + pool.upload(bufGL, grad_logits, S * V * sizeof(float)); + + // Output projection backward: dx = grad_logits @ out_w (using fnn-linear) + GrillyBuffer bufDx = pool.acquire(sd); + batch.begin(); + batchedLinear(batch, cache, bufGL, h.outWt, nullptr, bufDx, S, V, d); + batch.submitDeferred(); + batch.waitForCompletion(); + + // grad_out_w = grad_logits.T @ x_final (CPU — vocab-sized, not worth GPU) + std::vector gl_cpu(S * V); + pool.download(bufGL, gl_cpu.data(), S * V * sizeof(float)); + std::vector xfinal(S * d); + pool.download(h.bufActivations[L], xfinal.data(), sd); + + MoeGradients out; + out.grad_out_w.resize(V * d, 0.f); + { + using RM = Eigen::Matrix; + Eigen::Map GL(gl_cpu.data(), S, V); + Eigen::Map XF(xfinal.data(), S, d); + RM GOW = GL.transpose() * XF; + std::memcpy(out.grad_out_w.data(), GOW.data(), V * d * sizeof(float)); + } + + // Per-layer backward on GPU + const char* bwd_name = has_bwd_vec4 ? 
"moe-layer-backward-vec4" : "moe-layer-backward"; + struct BwdPush { uint32_t seq_len, d_model, n_experts; float w0, w1, w2, w3; }; + + PipelineEntry bwdPipe = cache.getOrCreate(bwd_name, 3, sizeof(BwdPush)); + + GrillyBuffer bufGradIn = pool.acquire(sd); + GrillyBuffer bufGradW = pool.acquire(packed_size); + + out.grad_experts.resize(L * E); + out.grad_router_w.resize(L); + out.grad_router_b.resize(L); + + bool has_gw_shader = has_gw; + PipelineEntry gwPipe{}; + if (has_gw_shader) + gwPipe = cache.getOrCreate("moe-layer-grad-weight", 3, sizeof(BwdPush)); + + for (int32_t l = L - 1; l >= 0; --l) { + auto& layer = h.layers[l]; + float w0 = router_p[l][0], w1 = router_p[l][1]; + float w2 = router_p[l][2], w3 = router_p[l][3]; + BwdPush bp{S, d, E, w0, w1, w2, w3}; + + batch.begin(); + + // grad_input via backward shader + { + std::vector bufs = { + {bufDx.handle, 0, sd}, + {layer.expertPacked.handle, 0, packed_size}, + {bufGradIn.handle, 0, sd}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet(bwd_name, bufs); + + uint32_t gx, gy; + if (has_bwd_vec4) { + gx = ((d / 4) + 31) / 32; // vec4 columns, 32-wide workgroup + gy = (S + 7) / 8; // 8-high workgroup + } else { + gx = (d + 15) / 16; + gy = (S + 15) / 16; + } + batch.dispatch(bwdPipe.pipeline, bwdPipe.layout, desc, gx, gy, 1, &bp, sizeof(bp)); + batch.barrier(); + } + + // grad_weight via grad-weight shader (if available) + if (has_gw_shader) { + std::vector bufs = { + {bufDx.handle, 0, sd}, + {h.bufActivations[l].handle, 0, sd}, + {bufGradW.handle, 0, packed_size}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("moe-layer-grad-weight", bufs); + uint32_t gx = (d + 15) / 16; + uint32_t gy = (d + 15) / 16; + batch.dispatch(gwPipe.pipeline, gwPipe.layout, desc, gx, gy, 1, &bp, sizeof(bp)); + } + + batch.submitDeferred(); + batch.waitForCompletion(); + + // Download grad_W + if (has_gw_shader) { + std::vector gw_packed(E * d * d); + pool.download(bufGradW, gw_packed.data(), packed_size); + for (uint32_t e = 0; e < E; ++e) { + out.grad_experts[l * E + e].assign(d * d, 0.f); + std::memcpy(out.grad_experts[l * E + e].data(), + gw_packed.data() + e * d * d, d * d * sizeof(float)); + } + } else { + // CPU fallback for grad_W + std::vector dx_cpu(S * d); + pool.download(bufDx, dx_cpu.data(), sd); + std::vector act(S * d); + pool.download(h.bufActivations[l], act.data(), sd); + using RM = Eigen::Matrix; + Eigen::Map DX(dx_cpu.data(), S, d); + Eigen::Map ACT(act.data(), S, d); + RM GW = DX.transpose() * ACT; + for (uint32_t e = 0; e < E; ++e) { + out.grad_experts[l * E + e].resize(d * d); + float pe = router_p[l][e]; + for (size_t i = 0; i < d * d; ++i) + out.grad_experts[l * E + e][i] = pe * GW.data()[i]; + } + } + + // Router gradient (CPU, tiny — skip for now, zero placeholder) + out.grad_router_w[l].assign(E * d, 0.f); + out.grad_router_b[l].assign(E, 0.f); + + // Swap for next layer + std::swap(bufDx, bufGradIn); + } + + // Embedding gradient: scatter-add dx + std::vector dx_final(S * d); + pool.download(bufDx, dx_final.data(), sd); + + out.grad_embed.assign(V * d, 0.f); + out.grad_pos.assign(h.maxSeq * d, 0.f); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + if (tok >= 0 && static_cast(tok) < V) + for (uint32_t j = 0; j < d; ++j) + out.grad_embed[tok * d + j] += dx_final[s * d + j]; + for (uint32_t j = 0; j < d; ++j) + out.grad_pos[s * d + j] = dx_final[s * d + j]; + } + + pool.release(bufGL); + pool.release(bufDx); + pool.release(bufGradIn); + pool.release(bufGradW); + + return out; +} + +} // namespace ops +} 
// namespace grilly diff --git a/cpp/src/ops/moqe_train.cpp b/cpp/src/ops/moqe_train.cpp index f0b214c..88dd7b1 100644 --- a/cpp/src/ops/moqe_train.cpp +++ b/cpp/src/ops/moqe_train.cpp @@ -137,7 +137,8 @@ void moqe_layer_forward_gpu(CommandBatch& batch, BufferPool& pool, batchedTiledLinear(batch, cache, tc.bufA1, lw.w1, nullptr, tc.bufC1, n1, d, d); // No barrier between experts — independent reads/writes - batch.submit(); // Fence wait — safe to reuse buffers after this returns + batch.submitDeferred(); + batch.waitForCompletion(); // Fence wait — safe to reuse buffers after this returns if (n0 > 0) pool.download(tc.bufC0, out0, size_t(n0) * d * sizeof(float)); if (n1 > 0) pool.download(tc.bufC1, out1, size_t(n1) * d * sizeof(float)); @@ -166,7 +167,8 @@ void moqe_layer_backward_dx_gpu(CommandBatch& batch, BufferPool& pool, if (n1 > 0) batchedTiledLinear(batch, cache, tc.bufA1, lw.w1_t, nullptr, tc.bufC1, n1, d, d); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); if (n0 > 0) pool.download(tc.bufC0, dx0, size_t(n0) * d * sizeof(float)); if (n1 > 0) pool.download(tc.bufC1, dx1, size_t(n1) * d * sizeof(float)); diff --git a/cpp/src/ops/optimizer.cpp b/cpp/src/ops/optimizer.cpp index 328bd91..38e622f 100644 --- a/cpp/src/ops/optimizer.cpp +++ b/cpp/src/ops/optimizer.cpp @@ -22,43 +22,66 @@ void adamUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const AdamParams& p) { const size_t bytes = size_t(p.totalWeights) * sizeof(float); - GrillyBuffer bufW = pool.acquire(bytes); - GrillyBuffer bufGrad = pool.acquire(bytes); - GrillyBuffer bufM = pool.acquire(bytes); - GrillyBuffer bufV = pool.acquire(bytes); - - pool.upload(bufW, weights, bytes); - pool.upload(bufGrad, grad, bytes); - pool.upload(bufM, m, bytes); - pool.upload(bufV, v, bytes); + // Adam updates all 4 buffers in-place: each is both stage-in and + // stage-out. Use HOST_CACHED readback staging for all 4 since we read + // them back to CPU at the end of every step. 
+ GrillyBuffer bufWDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufGradDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufMDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufVDL = pool.acquireDeviceLocal(bytes); + + GrillyBuffer bufWStage = pool.acquireReadback(bytes); + GrillyBuffer bufGradStage = pool.acquireReadback(bytes); + GrillyBuffer bufMStage = pool.acquireReadback(bytes); + GrillyBuffer bufVStage = pool.acquireReadback(bytes); + + pool.upload(bufWStage, weights, bytes); + pool.upload(bufGradStage, grad, bytes); + pool.upload(bufMStage, m, bytes); + pool.upload(bufVStage, v, bytes); PipelineEntry pipe = cache.getOrCreate("adam-update", 4, sizeof(AdamParams)); std::vector bufInfos = { - {bufW.handle, 0, bytes}, - {bufGrad.handle, 0, bytes}, - {bufM.handle, 0, bytes}, - {bufV.handle, 0, bytes}, + {bufWDL.handle, 0, bytes}, + {bufGradDL.handle, 0, bytes}, + {bufMDL.handle, 0, bytes}, + {bufVDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("adam-update", bufInfos); uint32_t gx = (p.totalWeights + 255) / 256; batch.begin(); + batch.copyBuffer(bufWStage, bufWDL, bytes); + batch.copyBuffer(bufGradStage, bufGradDL, bytes); + batch.copyBuffer(bufMStage, bufMDL, bytes); + batch.copyBuffer(bufVStage, bufVDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); - - pool.download(bufW, weights, bytes); - pool.download(bufGrad, grad, bytes); - pool.download(bufM, m, bytes); - pool.download(bufV, v, bytes); - - pool.release(bufW); - pool.release(bufGrad); - pool.release(bufM); - pool.release(bufV); + batch.transferComputeBarrier(); + batch.copyBuffer(bufWDL, bufWStage, bytes); + batch.copyBuffer(bufGradDL, bufGradStage, bytes); + batch.copyBuffer(bufMDL, bufMStage, bytes); + batch.copyBuffer(bufVDL, bufVStage, bytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufWStage, weights, bytes); + pool.download(bufGradStage, grad, bytes); + pool.download(bufMStage, m, bytes); + pool.download(bufVStage, v, bytes); + + pool.release(bufWDL); + pool.release(bufGradDL); + pool.release(bufMDL); + pool.release(bufVDL); + pool.release(bufWStage); + pool.release(bufGradStage); + pool.release(bufMStage); + pool.release(bufVStage); } // ── AdamW ──────────────────────────────────────────────────────────────── @@ -68,24 +91,30 @@ void adamwUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const AdamWParams& p) { const size_t bytes = size_t(p.totalWeights) * sizeof(float); - GrillyBuffer bufW = pool.acquire(bytes); - GrillyBuffer bufGrad = pool.acquire(bytes); - GrillyBuffer bufM = pool.acquire(bytes); - GrillyBuffer bufV = pool.acquire(bytes); + // Same staging pattern as adamUpdate — all 4 buffers in-place updated. 
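+    // As in adamUpdate, every buffer is both uploaded and read back each
+    // step, so one update moves 8 * totalWeights * sizeof(float) bytes
+    // across PCIe (4 stage-in copies plus 4 stage-out copies).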
+ GrillyBuffer bufWDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufGradDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufMDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufVDL = pool.acquireDeviceLocal(bytes); - pool.upload(bufW, weights, bytes); - pool.upload(bufGrad, grad, bytes); - pool.upload(bufM, m, bytes); - pool.upload(bufV, v, bytes); + GrillyBuffer bufWStage = pool.acquireReadback(bytes); + GrillyBuffer bufGradStage = pool.acquireReadback(bytes); + GrillyBuffer bufMStage = pool.acquireReadback(bytes); + GrillyBuffer bufVStage = pool.acquireReadback(bytes); + + pool.upload(bufWStage, weights, bytes); + pool.upload(bufGradStage, grad, bytes); + pool.upload(bufMStage, m, bytes); + pool.upload(bufVStage, v, bytes); PipelineEntry pipe = cache.getOrCreate("adamw-update", 4, sizeof(AdamWParams)); std::vector bufInfos = { - {bufW.handle, 0, bytes}, - {bufGrad.handle, 0, bytes}, - {bufM.handle, 0, bytes}, - {bufV.handle, 0, bytes}, + {bufWDL.handle, 0, bytes}, + {bufGradDL.handle, 0, bytes}, + {bufMDL.handle, 0, bytes}, + {bufVDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("adamw-update", bufInfos); @@ -93,19 +122,34 @@ void adamwUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, uint32_t gx = (p.totalWeights + 255) / 256; batch.begin(); + batch.copyBuffer(bufWStage, bufWDL, bytes); + batch.copyBuffer(bufGradStage, bufGradDL, bytes); + batch.copyBuffer(bufMStage, bufMDL, bytes); + batch.copyBuffer(bufVStage, bufVDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); - - pool.download(bufW, weights, bytes); - pool.download(bufGrad, grad, bytes); - pool.download(bufM, m, bytes); - pool.download(bufV, v, bytes); - - pool.release(bufW); - pool.release(bufGrad); - pool.release(bufM); - pool.release(bufV); + batch.transferComputeBarrier(); + batch.copyBuffer(bufWDL, bufWStage, bytes); + batch.copyBuffer(bufGradDL, bufGradStage, bytes); + batch.copyBuffer(bufMDL, bufMStage, bytes); + batch.copyBuffer(bufVDL, bufVStage, bytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufWStage, weights, bytes); + pool.download(bufGradStage, grad, bytes); + pool.download(bufMStage, m, bytes); + pool.download(bufVStage, v, bytes); + + pool.release(bufWDL); + pool.release(bufGradDL); + pool.release(bufMDL); + pool.release(bufVDL); + pool.release(bufWStage); + pool.release(bufGradStage); + pool.release(bufMStage); + pool.release(bufVStage); } } // namespace ops diff --git a/cpp/src/ops/perceiver.cpp b/cpp/src/ops/perceiver.cpp index 824bd89..a6e2343 100644 --- a/cpp/src/ops/perceiver.cpp +++ b/cpp/src/ops/perceiver.cpp @@ -32,7 +32,8 @@ void perceiverEncode(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batchedPerceiverEncode(batch, cache, bufQ, bufK, bufV, bufOut, seqN, seqM, headDim); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, qBytes); diff --git a/cpp/src/ops/perceiver_encoder.cpp b/cpp/src/ops/perceiver_encoder.cpp index bd18c56..159a051 100644 --- a/cpp/src/ops/perceiver_encoder.cpp +++ b/cpp/src/ops/perceiver_encoder.cpp @@ -138,7 +138,8 @@ std::vector perceiver_encode_native( nPatches, D, D); } // ONE barrier after ALL K/V projections complete - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // ══════════════════════════════════════════════════════════════════ // PHASE 2: Main layer loop — only Q_cross + cached K/V + 
self-attn @@ -198,7 +199,8 @@ std::vector perceiver_encode_native( std::swap(current, scratch); } - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download and mean-pool std::vector finalLatents(N * D); diff --git a/cpp/src/ops/pooling.cpp b/cpp/src/ops/pooling.cpp index f5e3784..d9f9a00 100644 --- a/cpp/src/ops/pooling.cpp +++ b/cpp/src/ops/pooling.cpp @@ -52,7 +52,8 @@ void maxpool2dForward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); pool.download(bufIdx, reinterpret_cast(indices), idxBytes); @@ -104,7 +105,8 @@ void maxpool2dBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); @@ -146,7 +148,8 @@ void avgpool2dForward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); @@ -187,7 +190,8 @@ void avgpool2dBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); @@ -223,7 +227,8 @@ void meanPool(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); diff --git a/cpp/src/ops/prefix_scan.cpp b/cpp/src/ops/prefix_scan.cpp new file mode 100644 index 0000000..318a42b --- /dev/null +++ b/cpp/src/ops/prefix_scan.cpp @@ -0,0 +1,165 @@ +#include "grilly/ops/prefix_scan.h" + +#include +#include +#include + +namespace grilly { +namespace ops { + +// Uses the staging pattern (DEVICE_LOCAL compute + WC stage-in + HOST_CACHED +// stage-out) — same as linear.cpp. See linear.cpp for the rationale. 
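+//
+// Per-call traffic: the forward stages two inputs (x, a) in and one output
+// (h) out, 3 * elemBytes over PCIe, while the backward stages four tensors
+// in (dh, a, h, x) and two out (dx, da), i.e. 6 * elemBytes, each batched
+// into a single submit/wait.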
+ +void prefixScanCausal(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* x, const float* a, float* h, + const PrefixScanParams& p) { + const size_t elemBytes = size_t(p.batchSize) * p.seqLen * p.hiddenDim * + sizeof(float); + + // DEVICE_LOCAL compute buffers + GrillyBuffer bufXDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufADL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufHDL = pool.acquireDeviceLocal(elemBytes); + + // Host staging + GrillyBuffer bufXStage = pool.acquire(elemBytes); + GrillyBuffer bufAStage = pool.acquire(elemBytes); + GrillyBuffer bufHStage = pool.acquireReadback(elemBytes); + + pool.upload(bufXStage, x, elemBytes); + pool.upload(bufAStage, a, elemBytes); + + PipelineEntry pipe = cache.getOrCreate( + "prefix-scan-causal", 3, 2 * sizeof(uint32_t)); + + std::vector bufInfos = { + {bufXDL.handle, 0, elemBytes}, + {bufADL.handle, 0, elemBytes}, + {bufHDL.handle, 0, elemBytes}, + }; + VkDescriptorSet descSet = + cache.allocDescriptorSet("prefix-scan-causal", bufInfos); + + struct Push { + uint32_t seqLen; + uint32_t hiddenDim; + } push = {p.seqLen, p.hiddenDim}; + + // One workgroup per (hidden_dim, batch) pair. Local size = 32 threads + // (one per time step) — shader enforces seqLen <= 32. + uint32_t gx = p.hiddenDim; + uint32_t gy = p.batchSize; + + batch.begin(); + + batch.copyBuffer(bufXStage, bufXDL, elemBytes); + batch.copyBuffer(bufAStage, bufADL, elemBytes); + batch.transferComputeBarrier(); + + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &push, sizeof(push)); + + batch.transferComputeBarrier(); + batch.copyBuffer(bufHDL, bufHStage, elemBytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufHStage, h, elemBytes); + + pool.release(bufXDL); + pool.release(bufADL); + pool.release(bufHDL); + pool.release(bufXStage); + pool.release(bufAStage); + pool.release(bufHStage); +} + +void prefixScanCausalBackward(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* dh, const float* a, + const float* h, const float* x, + float* dx, float* da, + const PrefixScanParams& p) { + const size_t elemBytes = size_t(p.batchSize) * p.seqLen * p.hiddenDim * + sizeof(float); + + // DEVICE_LOCAL compute buffers (4 inputs, 2 outputs) + GrillyBuffer bufDhDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufADL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufHDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufXDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufDxDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufDaDL = pool.acquireDeviceLocal(elemBytes); + + // Staging + GrillyBuffer bufDhStage = pool.acquire(elemBytes); + GrillyBuffer bufAStage = pool.acquire(elemBytes); + GrillyBuffer bufHStage = pool.acquire(elemBytes); + GrillyBuffer bufXStage = pool.acquire(elemBytes); + GrillyBuffer bufDxStage = pool.acquireReadback(elemBytes); + GrillyBuffer bufDaStage = pool.acquireReadback(elemBytes); + + pool.upload(bufDhStage, dh, elemBytes); + pool.upload(bufAStage, a, elemBytes); + pool.upload(bufHStage, h, elemBytes); + pool.upload(bufXStage, x, elemBytes); + + PipelineEntry pipe = cache.getOrCreate( + "prefix-scan-causal-backward", 6, 2 * sizeof(uint32_t)); + + std::vector bufInfos = { + {bufDhDL.handle, 0, elemBytes}, + {bufADL.handle, 0, elemBytes}, + {bufHDL.handle, 0, elemBytes}, + {bufXDL.handle, 0, elemBytes}, + {bufDxDL.handle, 0, elemBytes}, + {bufDaDL.handle, 0, elemBytes}, + }; + VkDescriptorSet descSet = 
cache.allocDescriptorSet( + "prefix-scan-causal-backward", bufInfos); + + struct Push { + uint32_t seqLen; + uint32_t hiddenDim; + } push = {p.seqLen, p.hiddenDim}; + + uint32_t gx = p.hiddenDim; + uint32_t gy = p.batchSize; + + batch.begin(); + + batch.copyBuffer(bufDhStage, bufDhDL, elemBytes); + batch.copyBuffer(bufAStage, bufADL, elemBytes); + batch.copyBuffer(bufHStage, bufHDL, elemBytes); + batch.copyBuffer(bufXStage, bufXDL, elemBytes); + batch.transferComputeBarrier(); + + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &push, sizeof(push)); + + batch.transferComputeBarrier(); + batch.copyBuffer(bufDxDL, bufDxStage, elemBytes); + batch.copyBuffer(bufDaDL, bufDaStage, elemBytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufDxStage, dx, elemBytes); + pool.download(bufDaStage, da, elemBytes); + + pool.release(bufDhDL); + pool.release(bufADL); + pool.release(bufHDL); + pool.release(bufXDL); + pool.release(bufDxDL); + pool.release(bufDaDL); + pool.release(bufDhStage); + pool.release(bufAStage); + pool.release(bufHStage); + pool.release(bufXStage); + pool.release(bufDxStage); + pool.release(bufDaStage); +} + +} // namespace ops +} // namespace grilly diff --git a/cpp/src/ops/rmsnorm.cpp b/cpp/src/ops/rmsnorm.cpp index d7000f7..fa815f1 100644 --- a/cpp/src/ops/rmsnorm.cpp +++ b/cpp/src/ops/rmsnorm.cpp @@ -71,7 +71,8 @@ void rmsnorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx1, 1, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download result pool.download(bufOutput, output, outputBytes); diff --git a/cpp/src/ops/snn.cpp b/cpp/src/ops/snn.cpp index 1dfe280..44aee33 100644 --- a/cpp/src/ops/snn.cpp +++ b/cpp/src/ops/snn.cpp @@ -53,7 +53,8 @@ void lifStep(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufVMem, vMem, bytes); pool.download(bufRefrac, tRefrac, bytes); @@ -100,7 +101,8 @@ void snnNodeForward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufVMem, vMem, bytes); pool.download(bufSpikes, spikes, bytes); @@ -143,7 +145,8 @@ void snnNodeBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradX, gradX, bytes); @@ -187,7 +190,8 @@ void hebbianLearning(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufWeights, weights, weightBytes); @@ -252,7 +256,8 @@ void stdpLearning(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufWeights, weights, weightBytes); pool.download(bufPreTrace, preTrace, preTraceBytes); @@ -295,7 +300,8 @@ void synapseFilter(CommandBatch& batch, 
BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufYState, yState, bytes); @@ -349,7 +355,8 @@ void gifNeuronStep(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufVMem, vMem, bytes); pool.download(bufIAdapt, iAdapt, bytes); diff --git a/cpp/src/ops/swizzle.cpp b/cpp/src/ops/swizzle.cpp index ddcfc0a..4c75110 100644 --- a/cpp/src/ops/swizzle.cpp +++ b/cpp/src/ops/swizzle.cpp @@ -209,7 +209,8 @@ void swizzle(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outputBytes); diff --git a/cpp/src/ops/vsa_lm_forward.cpp b/cpp/src/ops/vsa_lm_forward.cpp new file mode 100644 index 0000000..ee7e04b --- /dev/null +++ b/cpp/src/ops/vsa_lm_forward.cpp @@ -0,0 +1,656 @@ +/// vsa_lm_forward.cpp — fused VSA-LM forward (AdditionLinear FFN + MindForge LoRA) on GPU, +/// Eigen backward on CPU. +/// +/// FFN layers use the addition-linear shader (L1 distance, no matmul). +/// Output projection uses fnn-linear (standard matmul). +/// MindForge LoRA adapters are forged on CPU (tiny). + +#include "grilly/ops/vsa_lm_forward.h" +#include "grilly/ops/batched_ops.h" + +#include +#include +#include +#include +#include + +namespace grilly { +namespace ops { + +namespace { + +static std::unordered_map g_vsa; +static int g_next_vsa = 1; + +static void transpose_vd(const float* w, float* wt, uint32_t v, uint32_t d) { + for (uint32_t r = 0; r < v; ++r) + for (uint32_t c = 0; c < d; ++c) + wt[c * v + r] = w[r * d + c]; +} + +struct AddLinPush { + uint32_t batch_size; + uint32_t in_features; + uint32_t out_features; + uint32_t use_bias; +}; + +struct LayerNormParams { + uint32_t batch_size; + uint32_t seq_len; + uint32_t features; + float eps; + uint32_t pass_type; +}; + +} // namespace + +VsaLmHandleCache& vsa_lm_get_cache(int handle) { + auto it = g_vsa.find(handle); + if (it == g_vsa.end()) + throw std::runtime_error("Invalid vsa_lm handle"); + return it->second; +} + +void vsa_lm_release(BufferPool& pool, int handle) { + auto it = g_vsa.find(handle); + if (it == g_vsa.end()) + return; + VsaLmHandleCache& h = it->second; + + pool.release(h.embedW); + pool.release(h.posW); + pool.release(h.outW); + pool.release(h.outWt); + + for (auto& lw : h.layers) { + pool.release(lw.ffnUpW); + pool.release(lw.ffnUpB); + pool.release(lw.ffnDownW); + pool.release(lw.ffnDownB); + pool.release(lw.lnGamma); + pool.release(lw.lnBeta); + } + + pool.release(h.bufIds); + pool.release(h.bufPosSlice); + pool.release(h.bufX); + pool.release(h.bufLnOut); + pool.release(h.bufFfnUp); + pool.release(h.bufSign); + pool.release(h.bufFfnDown); + pool.release(h.bufLogits); + pool.release(h.bufLnMean); + pool.release(h.bufLnVar); + + for (auto& b : h.bufActivations) + pool.release(b); + + g_vsa.erase(it); +} + +int vsa_lm_upload(BufferPool& pool, + uint32_t vocab, uint32_t d, uint32_t d_ffn, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& ffn_up_ws, + const std::vector& ffn_up_bs, + const std::vector& ffn_down_ws, + const std::vector& ffn_down_bs, + 
const std::vector& ln_gammas, + const std::vector& ln_betas, + const float* out_w, + uint32_t n_layers) { + + if (n_layers == 0 || d == 0 || d_ffn == 0 || vocab == 0 || max_seq == 0) + throw std::runtime_error("vsa_lm_upload: invalid dimensions"); + if (ffn_up_ws.size() != n_layers || ffn_up_bs.size() != n_layers || + ffn_down_ws.size() != n_layers || ffn_down_bs.size() != n_layers || + ln_gammas.size() != n_layers || ln_betas.size() != n_layers) + throw std::runtime_error("vsa_lm_upload: list length mismatch"); + + VsaLmHandleCache h; + h.vocab = vocab; + h.d = d; + h.dFfn = d_ffn; + h.maxSeq = max_seq; + h.nLayers = n_layers; + + // Embedding + size_t embed_bytes = size_t(vocab) * d * sizeof(float); + h.cpu_embed.resize(vocab * d); + std::memcpy(h.cpu_embed.data(), embed_w, embed_bytes); + h.embedW = pool.acquire(embed_bytes); + pool.upload(h.embedW, embed_w, embed_bytes); + + // Positional + size_t pos_bytes = size_t(max_seq) * d * sizeof(float); + h.cpu_pos.resize(max_seq * d); + std::memcpy(h.cpu_pos.data(), pos_w, pos_bytes); + h.posW = pool.acquire(pos_bytes); + pool.upload(h.posW, pos_w, pos_bytes); + + // Output projection + size_t out_bytes = size_t(vocab) * d * sizeof(float); + h.cpu_out_w.resize(vocab * d); + std::memcpy(h.cpu_out_w.data(), out_w, out_bytes); + h.outW = pool.acquire(out_bytes); + pool.upload(h.outW, out_w, out_bytes); + std::vector out_wt(d * vocab); + transpose_vd(out_w, out_wt.data(), vocab, d); + h.outWt = pool.acquire(out_wt.size() * sizeof(float)); + pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + // Per-layer weights + h.layers.resize(n_layers); + h.cpu_ffn_up_w.resize(n_layers); + h.cpu_ffn_up_b.resize(n_layers); + h.cpu_ffn_down_w.resize(n_layers); + h.cpu_ffn_down_b.resize(n_layers); + h.cpu_ln_gamma.resize(n_layers); + h.cpu_ln_beta.resize(n_layers); + + for (uint32_t l = 0; l < n_layers; ++l) { + auto& lw = h.layers[l]; + + size_t up_w_bytes = size_t(d_ffn) * d * sizeof(float); + size_t up_b_bytes = size_t(d_ffn) * sizeof(float); + size_t down_w_bytes = size_t(d) * d_ffn * sizeof(float); + size_t down_b_bytes = size_t(d) * sizeof(float); + size_t ln_bytes = size_t(d) * sizeof(float); + + lw.ffnUpW = pool.acquire(up_w_bytes); + pool.upload(lw.ffnUpW, ffn_up_ws[l], up_w_bytes); + lw.ffnUpB = pool.acquire(up_b_bytes); + pool.upload(lw.ffnUpB, ffn_up_bs[l], up_b_bytes); + + lw.ffnDownW = pool.acquire(down_w_bytes); + pool.upload(lw.ffnDownW, ffn_down_ws[l], down_w_bytes); + lw.ffnDownB = pool.acquire(down_b_bytes); + pool.upload(lw.ffnDownB, ffn_down_bs[l], down_b_bytes); + + lw.lnGamma = pool.acquire(ln_bytes); + pool.upload(lw.lnGamma, ln_gammas[l], ln_bytes); + lw.lnBeta = pool.acquire(ln_bytes); + pool.upload(lw.lnBeta, ln_betas[l], ln_bytes); + + h.cpu_ffn_up_w[l].resize(d_ffn * d); + std::memcpy(h.cpu_ffn_up_w[l].data(), ffn_up_ws[l], up_w_bytes); + h.cpu_ffn_up_b[l].resize(d_ffn); + std::memcpy(h.cpu_ffn_up_b[l].data(), ffn_up_bs[l], up_b_bytes); + h.cpu_ffn_down_w[l].resize(d * d_ffn); + std::memcpy(h.cpu_ffn_down_w[l].data(), ffn_down_ws[l], down_w_bytes); + h.cpu_ffn_down_b[l].resize(d); + std::memcpy(h.cpu_ffn_down_b[l].data(), ffn_down_bs[l], down_b_bytes); + h.cpu_ln_gamma[l].resize(d); + std::memcpy(h.cpu_ln_gamma[l].data(), ln_gammas[l], ln_bytes); + h.cpu_ln_beta[l].resize(d); + std::memcpy(h.cpu_ln_beta[l].data(), ln_betas[l], ln_bytes); + } + + // Working buffers + size_t sd = size_t(max_seq) * d * sizeof(float); + size_t s_dffn = size_t(max_seq) * d_ffn * sizeof(float); + size_t logit_sz = size_t(max_seq) * 
vocab * sizeof(float); + + h.bufIds = pool.acquire(max_seq * sizeof(uint32_t)); + h.bufPosSlice= pool.acquire(sd); + h.bufX = pool.acquire(sd); + h.bufLnOut = pool.acquire(sd); + h.bufFfnUp = pool.acquire(s_dffn); + h.bufSign = pool.acquire(s_dffn); + h.bufFfnDown = pool.acquire(sd); + h.bufLogits = pool.acquire(logit_sz); + h.bufLnMean = pool.acquire(max_seq * sizeof(float)); + h.bufLnVar = pool.acquire(max_seq * sizeof(float)); + + h.bufActivations.resize(n_layers + 1); + for (uint32_t l = 0; l <= n_layers; ++l) + h.bufActivations[l] = pool.acquire(sd); + + int hid = g_next_vsa++; + g_vsa[hid] = std::move(h); + return hid; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helper: record addition-linear dispatch into batch (no begin/submit) +// ═══════════════════════════════════════════════════════════════════════════ + +static void batchedAdditionLinear(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& input, const GrillyBuffer& weight, + const GrillyBuffer& bias, GrillyBuffer& output, + uint32_t S, uint32_t d_in, uint32_t d_out) { + + PipelineEntry pipe = cache.getOrCreate("addition-linear", 4, sizeof(AddLinPush)); + + size_t inBytes = size_t(S) * d_in * sizeof(float); + size_t wBytes = size_t(d_out) * d_in * sizeof(float); + size_t bBytes = size_t(d_out) * sizeof(float); + size_t outBytes = size_t(S) * d_out * sizeof(float); + + std::vector bufs = { + {input.handle, 0, inBytes}, + {weight.handle, 0, wBytes}, + {bias.handle, 0, bBytes}, + {output.handle, 0, outBytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("addition-linear", bufs); + + AddLinPush push{S, d_in, d_out, 1}; + uint32_t total = S * d_out; + uint32_t gx = (total + 255) / 256; + + batch.dispatch(pipe.pipeline, pipe.layout, desc, gx, 1, 1, + &push, sizeof(push)); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helper: record sign activation into batch +// ═══════════════════════════════════════════════════════════════════════════ + +static void batchedSignActivation(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& input, GrillyBuffer& output, + uint32_t totalElements) { + + PipelineEntry pipe = cache.getOrCreate("sign-activation", 2, sizeof(uint32_t)); + + size_t bytes = size_t(totalElements) * sizeof(float); + std::vector bufs = { + {input.handle, 0, bytes}, + {output.handle, 0, bytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("sign-activation", bufs); + + uint32_t push = totalElements; + uint32_t gx = (totalElements + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gx, 1, 1, + &push, sizeof(push)); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helper: record 3-pass layernorm into batch (pre-allocated mean/var bufs) +// ═══════════════════════════════════════════════════════════════════════════ + +static void batchedLayerNorm(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& input, GrillyBuffer& output, + const GrillyBuffer& gamma, const GrillyBuffer& beta, + GrillyBuffer& meanBuf, GrillyBuffer& varBuf, + uint32_t S, uint32_t features) { + + PipelineEntry pipe = cache.getOrCreate("fnn-layernorm", 6, sizeof(LayerNormParams)); + + size_t elemBytes = size_t(S) * features * sizeof(float); + size_t paramBytes = size_t(features) * sizeof(float); + size_t statBytes = size_t(S) * sizeof(float); + + std::vector bufs = { + {input.handle, 0, elemBytes}, + {output.handle, 0, elemBytes}, + {gamma.handle, 0, paramBytes}, + 
{beta.handle, 0, paramBytes}, + {meanBuf.handle, 0, statBytes}, + {varBuf.handle, 0, statBytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("fnn-layernorm", bufs); + + // batch_size=1, seq_len=S + LayerNormParams p0{1, S, features, 1e-5f, 0}; + uint32_t gxPos = (S + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gxPos, 1, 1, &p0, sizeof(p0)); + batch.barrier(); + + LayerNormParams p1{1, S, features, 1e-5f, 1}; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gxPos, 1, 1, &p1, sizeof(p1)); + batch.barrier(); + + LayerNormParams p2{1, S, features, 1e-5f, 2}; + uint32_t gxAll = (S * features + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gxAll, 1, 1, &p2, sizeof(p2)); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// vsa_lm_forward_gpu +// ═══════════════════════════════════════════════════════════════════════════ + +void vsa_lm_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + VsaLmHandleCache& h, const int32_t* input_ids, + uint32_t seq_len, float* logits_out) { + + if (seq_len == 0 || seq_len > h.maxSeq) + throw std::runtime_error("vsa_lm_forward: invalid seq_len"); + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t V = h.vocab; + + // Upload token IDs and position slice + pool.upload(h.bufIds, reinterpret_cast(input_ids), + S * sizeof(uint32_t)); + pool.upload(h.bufPosSlice, h.cpu_pos.data(), S * d * sizeof(float)); + + // Phase 1: embedding lookup + position add + batch.begin(); + batchedEmbeddingLookup(batch, cache, h.bufIds, h.embedW, h.bufX, 1, S, h.vocab, d); + batch.barrier(); + batchedAdd(batch, cache, h.bufX, h.bufPosSlice, S * d); + batch.barrier(); + batch.copyBuffer(h.bufX, h.bufActivations[0], S * d * sizeof(float)); + batch.submitDeferred(); + batch.waitForCompletion(); + + // Phase 2: per-layer forward + for (uint32_t l = 0; l < h.nLayers; ++l) { + auto& lw = h.layers[l]; + + batch.begin(); + + // (a) LayerNorm + batchedLayerNorm(batch, cache, h.bufX, h.bufLnOut, + lw.lnGamma, lw.lnBeta, + h.bufLnMean, h.bufLnVar, S, d); + batch.barrier(); + + // (d) Addition-linear up: (S, d) → (S, d_ffn) + batchedAdditionLinear(batch, cache, h.bufLnOut, lw.ffnUpW, lw.ffnUpB, + h.bufFfnUp, S, d, dF); + batch.barrier(); + + // (e) Sign activation + batchedSignActivation(batch, cache, h.bufFfnUp, h.bufSign, S * dF); + batch.barrier(); + + // (f) Addition-linear down: (S, d_ffn) → (S, d) + batchedAdditionLinear(batch, cache, h.bufSign, lw.ffnDownW, lw.ffnDownB, + h.bufFfnDown, S, dF, d); + batch.barrier(); + + // (h) Residual: x += ffn_down + batchedAdd(batch, cache, h.bufX, h.bufFfnDown, S * d); + batch.barrier(); + + // (i) Save activation + batch.copyBuffer(h.bufX, h.bufActivations[l + 1], S * d * sizeof(float)); + + batch.submitDeferred(); + batch.waitForCompletion(); + } + + // Phase 3: output projection — x @ out_w.T / sqrt(d) + batch.begin(); + batchedLinear(batch, cache, h.bufX, h.outW, nullptr, h.bufLogits, S, d, V); + batch.submitDeferred(); + batch.waitForCompletion(); + + // Download and scale + std::vector raw(S * V); + pool.download(h.bufLogits, raw.data(), S * V * sizeof(float)); + float scale = 1.0f / std::sqrt(static_cast(d)); + for (size_t i = 0; i < size_t(S) * V; ++i) + logits_out[i] = raw[i] * scale; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// vsa_lm_backward_cpu — Eigen-based backward matching moe_backward_cpu pattern. 
+// +// AdditionLinear backward: +// grad_input[row, k] = -sum_col( grad_out[row, col] * sign(W[col, k] - x[row, k]) ) +// grad_W[col, k] = -sum_row( grad_out[row, col] * sign(W[col, k] - x[row, k]) ) +// grad_b[col] = sum_row( grad_out[row, col] ) +// ═══════════════════════════════════════════════════════════════════════════ + +VsaLmGradients vsa_lm_backward_cpu(const VsaLmHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits) { + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + + if (S == 0 || S > h.maxSeq) + throw std::runtime_error("vsa_lm_backward: invalid seq_len"); + + using RM = Eigen::Matrix; + + // Output projection backward: dx = grad_logits @ out_w * scale + float scale = 1.0f / std::sqrt(static_cast(d)); + Eigen::Map GL(grad_logits, S, V); + Eigen::Map OW(h.cpu_out_w.data(), V, d); + + RM scaledGL = GL * scale; // (S, V) scaled + RM dx = scaledGL * OW; // (S, d) + + VsaLmGradients out; + + // grad_out_w = scaledGL.T @ x_final + // Need final activation — download it. We stored it as bufActivations[L] during forward. + // For CPU backward we'll recompute from saved CPU mirrors. + // Actually let's replay the forward on CPU to get activations. + + // CPU forward replay for activations + std::vector> acts(L + 1); + std::vector> ln_outs(L); + std::vector> ffn_ups(L); // after addition-linear up + std::vector> sign_outs(L); // after sign activation + + // Initial: embed + pos + acts[0].resize(S * d); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + for (uint32_t j = 0; j < d; ++j) { + float e = (tok >= 0 && static_cast(tok) < V) + ? h.cpu_embed[tok * d + j] : 0.f; + float p = h.cpu_pos[s * d + j]; + acts[0][s * d + j] = e + p; + } + } + + for (uint32_t l = 0; l < L; ++l) { + const auto& uw = h.cpu_ffn_up_w[l]; + const auto& ub = h.cpu_ffn_up_b[l]; + const auto& dw = h.cpu_ffn_down_w[l]; + const auto& db = h.cpu_ffn_down_b[l]; + const auto& gm = h.cpu_ln_gamma[l]; + const auto& bt = h.cpu_ln_beta[l]; + + // LayerNorm + ln_outs[l].resize(S * d); + for (uint32_t s = 0; s < S; ++s) { + float mean = 0.f; + for (uint32_t j = 0; j < d; ++j) + mean += acts[l][s * d + j]; + mean /= float(d); + float var = 0.f; + for (uint32_t j = 0; j < d; ++j) { + float diff = acts[l][s * d + j] - mean; + var += diff * diff; + } + var /= float(d); + float inv_std = 1.0f / std::sqrt(var + 1e-5f); + for (uint32_t j = 0; j < d; ++j) { + float norm = (acts[l][s * d + j] - mean) * inv_std; + ln_outs[l][s * d + j] = gm[j] * norm + bt[j]; + } + } + + // Addition-linear up: (S, d) → (S, d_ffn) + ffn_ups[l].resize(S * dF); + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < dF; ++o) { + float dist = 0.f; + for (uint32_t k = 0; k < d; ++k) + dist += std::abs(uw[o * d + k] - ln_outs[l][s * d + k]); + ffn_ups[l][s * dF + o] = -dist + ub[o]; + } + } + + // Sign activation + sign_outs[l].resize(S * dF); + for (size_t i = 0; i < S * dF; ++i) + sign_outs[l][i] = (ffn_ups[l][i] > 0.f) ? 
1.f : -1.f; + + // Addition-linear down: (S, d_ffn) → (S, d) + acts[l + 1].resize(S * d); + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < d; ++o) { + float dist = 0.f; + for (uint32_t k = 0; k < dF; ++k) + dist += std::abs(dw[o * dF + k] - sign_outs[l][s * dF + k]); + float ffn_val = -dist + db[o]; + acts[l + 1][s * d + o] = acts[l][s * d + o] + ffn_val; + } + } + } + + // grad_out_w = scaledGL.T @ x_final → (V, d) + Eigen::Map XF(acts[L].data(), S, d); + RM GOW = scaledGL.transpose() * XF; + out.grad_out_w.resize(V * d); + std::memcpy(out.grad_out_w.data(), GOW.data(), V * d * sizeof(float)); + + // Back-propagate through layers in reverse + out.grad_ffn_up_w.resize(L); + out.grad_ffn_up_b.resize(L); + out.grad_ffn_down_w.resize(L); + out.grad_ffn_down_b.resize(L); + out.grad_ln_gamma.resize(L); + out.grad_ln_beta.resize(L); + + // dx is currently (S, d) + for (int32_t l = L - 1; l >= 0; --l) { + const auto& uw = h.cpu_ffn_up_w[l]; + const auto& dw = h.cpu_ffn_down_w[l]; + + // dx passes through residual: grad to addition-linear down is dx + // Addition-linear down backward: + // grad_sign[s, k] = -sum_o( dx[s, o] * sign(dw[o, k] - sign_out[s, k]) ) + // grad_dw[o, k] = -sum_s( dx[s, o] * sign(dw[o, k] - sign_out[s, k]) ) + // grad_db[o] = sum_s( dx[s, o] ) + + out.grad_ffn_down_w[l].assign(d * dF, 0.f); + out.grad_ffn_down_b[l].assign(d, 0.f); + std::vector grad_sign(S * dF, 0.f); + + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < d; ++o) { + float go = dx(s, o); + out.grad_ffn_down_b[l][o] += go; + for (uint32_t k = 0; k < dF; ++k) { + float sgn = (dw[o * dF + k] > sign_outs[l][s * dF + k]) ? 1.f : + (dw[o * dF + k] < sign_outs[l][s * dF + k]) ? -1.f : 0.f; + out.grad_ffn_down_w[l][o * dF + k] += -go * sgn; + grad_sign[s * dF + k] += -go * sgn; + } + } + } + + // Sign activation backward: grad_ffn_up = grad_sign * 0 (sign is flat) + // Actually sign subgradient: d/dx sign(x) = 0 almost everywhere. + // But for training we use STE (straight-through estimator): + // grad_ffn_up = grad_sign (pass through) + std::vector& grad_ffn_up = grad_sign; // STE + + // Addition-linear up backward: + // grad_ln[s, k] = -sum_o( grad_up[s, o] * sign(uw[o, k] - ln_out[s, k]) ) + // grad_uw[o, k] = -sum_s( grad_up[s, o] * sign(uw[o, k] - ln_out[s, k]) ) + // grad_ub[o] = sum_s( grad_up[s, o] ) + + out.grad_ffn_up_w[l].assign(dF * d, 0.f); + out.grad_ffn_up_b[l].assign(dF, 0.f); + std::vector grad_ln(S * d, 0.f); + + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < dF; ++o) { + float gu = grad_ffn_up[s * dF + o]; + out.grad_ffn_up_b[l][o] += gu; + for (uint32_t k = 0; k < d; ++k) { + float sgn = (uw[o * d + k] > ln_outs[l][s * d + k]) ? 1.f : + (uw[o * d + k] < ln_outs[l][s * d + k]) ? 
-1.f : 0.f; + out.grad_ffn_up_w[l][o * d + k] += -gu * sgn; + grad_ln[s * d + k] += -gu * sgn; + } + } + } + + // LayerNorm backward (simplified — gamma * grad_ln_out passed to LN backward) + out.grad_ln_gamma[l].assign(d, 0.f); + out.grad_ln_beta[l].assign(d, 0.f); + + const auto& gm = h.cpu_ln_gamma[l]; + + // Recompute mean/invstd + std::vector means(S), inv_stds(S); + for (uint32_t s = 0; s < S; ++s) { + float m = 0.f; + for (uint32_t j = 0; j < d; ++j) + m += acts[l][s * d + j]; + m /= float(d); + means[s] = m; + float v = 0.f; + for (uint32_t j = 0; j < d; ++j) { + float diff = acts[l][s * d + j] - m; + v += diff * diff; + } + v /= float(d); + inv_stds[s] = 1.0f / std::sqrt(v + 1e-5f); + } + + // grad_beta = sum_s(grad_ln), grad_gamma = sum_s(grad_ln * norm) + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t j = 0; j < d; ++j) { + float norm = (acts[l][s * d + j] - means[s]) * inv_stds[s]; + out.grad_ln_beta[l][j] += grad_ln[s * d + j]; + out.grad_ln_gamma[l][j] += grad_ln[s * d + j] * norm; + } + } + + // Backprop through layernorm to get grad w.r.t. input + // Using full layernorm backward formula + RM grad_x_ln(S, d); + for (uint32_t s = 0; s < S; ++s) { + float is = inv_stds[s]; + float m = means[s]; + + // dl_dxhat = grad_ln * gamma + Eigen::VectorXf dl_dxhat(d); + for (uint32_t j = 0; j < d; ++j) + dl_dxhat[j] = grad_ln[s * d + j] * gm[j]; + + float sum1 = dl_dxhat.sum(); + float sum2 = 0.f; + for (uint32_t j = 0; j < d; ++j) { + float xhat = (acts[l][s * d + j] - m) * is; + sum2 += dl_dxhat[j] * xhat; + } + + for (uint32_t j = 0; j < d; ++j) { + float xhat = (acts[l][s * d + j] - m) * is; + grad_x_ln(s, j) = is * (dl_dxhat[j] - sum1 / d - xhat * sum2 / d); + } + } + + // Residual: dx for next layer = dx (through residual) + grad_x_ln + RM new_dx = dx + grad_x_ln; + dx = new_dx; + } + + // Embedding gradient: scatter-add dx + out.grad_embed.assign(V * d, 0.f); + out.grad_pos.assign(h.maxSeq * d, 0.f); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + if (tok >= 0 && static_cast(tok) < V) { + for (uint32_t j = 0; j < d; ++j) + out.grad_embed[tok * d + j] += dx(s, j); + } + for (uint32_t j = 0; j < d; ++j) + out.grad_pos[s * d + j] = dx(s, j); + } + + return out; +} + +} // namespace ops +} // namespace grilly diff --git a/docs/PERF_DISPATCH.md b/docs/PERF_DISPATCH.md new file mode 100644 index 0000000..4cbef0d --- /dev/null +++ b/docs/PERF_DISPATCH.md @@ -0,0 +1,54 @@ +# GPU dispatch performance (Python `VulkanCore` / `VulkanCompute`) + +This document summarizes **non-blocking** and **batched** Vulkan dispatch APIs and how they relate to the PyTorch parity performance workstreams (B1, B2, D4). + +## Synchronous vs async (`backend/core.py`) + +- **`VulkanCore._dispatch_compute(...)`** — Records one command buffer, submits, **waits on a fence** (default `wait_previous=True`). Safe and simple; one fence wait per call. +- **`VulkanCore._dispatch_compute_async(...)`** — Submits work **without** waiting; pair with **`_wait_async()`** before reading GPU results. +- **`VulkanCore.record_commands()`** — Returns a **`CommandRecorder`** context manager: multiple `dispatch` + `barrier` calls, then **one** submit + wait on `__exit__`. Used by FlashAttention2 tiling, RMSNorm two-pass, and **`VulkanFNN._linear_relu_recorded_chain`** (Linear→ReLU fallback when `fused-linear-relu.spv` is missing). 
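+
+A minimal sketch of the three styles (import path and argument lists are abbreviated with `...` — see `backend/core.py` and `backend/compute.py` for the exact signatures):
+
+```python
+from grilly.backend.compute import VulkanCompute
+
+compute = VulkanCompute()
+
+# 1. Synchronous: one submit + one fence wait per call (simple default).
+out = compute.dispatch_compute(...)        # alias for core._dispatch_compute
+
+# 2. Async: submit now, wait explicitly before reading GPU results.
+compute.dispatch_compute_async(...)
+compute.wait_async()
+
+# 3. Recorded: several dependent dispatches, one submit + wait on __exit__.
+with compute.record_commands() as rec:
+    rec.dispatch(...)                      # kernel A
+    rec.barrier()
+    rec.dispatch(...)                      # kernel B consumes A's output
+```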
+ +### FNN chain recorder (`record_commands(fnn_chain=True)`) + +`VulkanCompute.record_commands(fnn_chain=True)` returns **`FnnChainRecorder`** (`backend/fnn_chain.py`): high-level **`linear`**, **`relu`**, **`softmax`** methods that record dispatches only and return **`ChainBufferHandle`**. **`read(handle)`** performs **one** `submit_and_wait` + one download. **`read_multiple([h0, h1, ...])`** does **one** `submit_and_wait` then downloads each handle (MoE fan-out: many expert linears, one fence wait, many CPU reads). Use for deep MLP-style forwards to avoid one fence wait per layer. + +Equivalent: **`VulkanFNN.chain_record()`**. + +## Public aliases on `VulkanCompute` (`backend/compute.py`) + +After `Compute()` / `VulkanCompute()` construction: + +| Attribute | Underlying API | +|-----------|----------------| +| `record_commands(fnn_chain=False)` | `core.record_commands()` or `fnn.chain_record()` when `fnn_chain=True` | +| `dispatch_compute` | `core._dispatch_compute` | +| `dispatch_compute_async` | `core._dispatch_compute_async` | +| `wait_async` | `core._wait_async` | +| `wait_fence` | `core._wait_fence` | + +Prefer **`record_commands`** for dependent kernels that share the same queue and do not need host-visible results between dispatches. + +## `nn.Sequential` fusion + +`nn.Sequential` already detects **Linear → ReLU / GELU / SiLU** and calls **`backend.fnn.fused_*`** when those shaders exist. If the fused shader is absent, **`fused_linear_relu`** tries **`_linear_relu_recorded_chain`** (two dispatches, **one** submit) before falling back to separate `linear` + `activation_relu` calls. + +## Pybind11 GIL policy (workstream B3) + +Heavy `grilly_core` entrypoints should release the GIL around GPU work: + +```cpp +{ + py::gil_scoped_release release; + grilly::ops::someOp(...); +} +``` + +Bindings should **request** `py::buffer_info` / output arrays **before** releasing the GIL. New array-based inputs should satisfy **`require_c_contiguous_float`** / **`require_c_contiguous_uint32`** / **`require_c_contiguous_int8`** (`bindings_core.h`) so kernels see dense row-major data. + +## Code review checklist (bindings) + +1. GIL held only for buffer prep and return value wrapping; **not** during `CommandBatch` / pool work. +2. `float32` numpy arrays are **C-contiguous** (or explicitly copied) for GPU kernels. +3. No duplicate `request()` on invalid buffers after GIL re-acquire without re-verifying lifetimes. + +Covered in this policy (non-exhaustive): activations, attention, linear, conv, normalization, loss, SNN, pooling, optim GPU steps, **misc** (dropout, embedding, KV cache GPU paths, `swizzle_kv`), **Hamming** (GPU path), **SigLIP** / **Perceiver** / **MoQE train**, **`ShaderFusionEngine.fuse`**, **`OpGraph.optimize` / `OpGraph.execute`**. diff --git a/docs/api/functional.md b/docs/api/functional.md index 0985602..735e7da 100644 --- a/docs/api/functional.md +++ b/docs/api/functional.md @@ -154,3 +154,13 @@ Source: [`functional/`](https://github.com/grillcheese-ai/grilly/tree/main/funct | `seq_to_ann_forward` | `seq_to_ann_forward(module, x)` | Apply ANN module to temporal data. | | `reset_net` | `reset_net(module)` | Reset all neuron membrane potentials. | | `set_step_mode` | `set_step_mode(module, mode)` | Set single-step or multi-step mode. | + +--- + +## PyTorch parity notes + +- **Layout**: `linear(x, weight, bias)` uses `weight` of shape `(out_features, in_features)`, matching `torch.nn.functional.linear` and `nn.Linear.weight`. 
+- **Dtypes**: GPU paths expect float32 inputs unless documented otherwise. +- **Differences**: Not every `torch.nn.functional` symbol exists; Grilly adds domain-specific ops (SNN, memory, FAISS-like). See `docs/PYTORCH_PARITY_STATUS.md`. +- **Migration**: Step-by-step porting guidance lives in `docs/MIGRATION_PYTORCH.md`. +- **Tests**: Numerical checks vs numpy (and optional PyTorch) live under `tests/parity/`. diff --git a/docs/grl_v1_format.md b/docs/grl_v1_format.md new file mode 100644 index 0000000..a4adeba --- /dev/null +++ b/docs/grl_v1_format.md @@ -0,0 +1,24 @@ +# GRL v1 checkpoint format (`.grl`) + +Grilly’s native, pickle-free checkpoint container for model weights, optimizer metadata, and training scalars. + +## Layout + +1. **Header** (64 bytes): magic `GRLY`, `uint16` format version (1), `uint16` flags, `uint32` reserved, then `uint64` offsets/lengths for metadata JSON, tensor index JSON, and raw payload. +2. **Metadata JSON** (UTF-8): `schema: "grilly.checkpoint.v1"`, `framework: "grilly"`, optional keys such as `training_step`, `best_ppl`, `step`, `epoch`, and `extra`. +3. **Tensor index JSON**: ordered array of `{name, dtype, shape, offset, length}` with `offset`/`length` relative to the **start of the payload** section. +4. **Payload**: concatenated C-contiguous row-major tensor bytes (little-endian scalars in index). + +## dtypes + +Index encodes dtypes as `f32`, `f16`, `i64`, `i32`, `u8`. + +## API + +- Write: `grilly.utils.grl_checkpoint.save_grl(path, state_dict, metadata=...)`. +- Read: `grilly.utils.grl_checkpoint.load_grl(path, map_location=...)`. +- Torch-style: `grilly.torch_api.save` / `grilly.torch_api.load` (`.grl` only). + +## Versioning + +Format version **1** is fixed in `FORMAT_VERSION` in `utils/grl_checkpoint.py` and the C++ reader/writer. Future versions must bump the header version and preserve a migration path. diff --git a/docs/index.md b/docs/index.md index 638b17c..7698582 100644 --- a/docs/index.md +++ b/docs/index.md @@ -102,6 +102,10 @@ grilly/ | [optimum-grilly](https://github.com/grillcheese-ai/optimum-grilly) | HuggingFace Optimum backend -- `from_pretrained` to Vulkan inference | | [CubeMind](https://github.com/grillcheese-ai/cubemind) | Neuro-vector-symbolic reasoning powered by grilly | +## Pre-v1.0 Planning + +- [Pre-v1.0 Optimization + Parity Tasklist](pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md) + ## License MIT -- see [LICENSE](https://github.com/grillcheese-ai/grilly/blob/main/LICENSE). diff --git a/docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md b/docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md new file mode 100644 index 0000000..09b6438 --- /dev/null +++ b/docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md @@ -0,0 +1,210 @@ +# Pre-v1.0 Optimization + Parity Tasklist + +This tasklist is a post-upgrade pass focused on shipping a stable, fast pre-v1.0 runtime with clearer PyTorch parity guarantees. 
+ +Scope: +- Python Vulkan runtime orchestration (`backend/*`) +- C++ bridge integration (`cpp/python/*`) +- Functional + module parity tests (`tests/parity`, targeted GPU tests) +- Performance/throughput work that directly affects user-visible training/inference latency + +--- + +## Baseline (already landed) + +- `FnnChainRecorder` with `linear` / `relu` / `softmax`, `read`, and `read_multiple` (`backend/fnn_chain.py`) +- `VulkanTensor.prepare_for_dispatch()` residency bind path (`utils/tensor_conversion.py`) +- `_prepare_input()` GPU-resident fast path (`backend/base.py`) +- Conv backward-weight GPU GEMM path and parity test (`backend/conv.py`, `tests/test_conv_backward_weight_gemm.py`) +- MoE backward stability fix at production shape (`cpp/src/ops/moe_forward.cpp`): corrected output-projection grad matmul and added bounds checks; Python binding wired via `moe_backward_gpu(...)` entrypoint with safe CPU fallback (`cpp/python/bindings_moe.cpp`) +- VSA-LM fused C++ forward/backward (`grilly_core.vsa_lm_*`): AdditionLinear FFN (L1 distance shader) + sign activation + LayerNorm + output projection in one C++ call. CPU Eigen backward with full AdditionLinear gradient (STE for sign). New files: `cpp/src/ops/vsa_lm_forward.cpp`, `cpp/include/grilly/ops/vsa_lm_forward.h`, `cpp/python/bindings_vsa_lm.cpp`, `shaders/sign-activation.glsl`. Tests: `tests/test_vsa_lm_forward.py` (shape + parity). + +--- + +## Priority Feature Roadmap (user-directed) + +Order locked by product priority: +1. GPU tokenizer +2. Sentence-transformers +3. Transformers compatibility (target: near 1:1 behavior/signature coverage) +4. PyTorch -> Grilly converter + +--- + +## P0: GPU Tokenizer (highest priority) + +### P0.1 Core GPU tokenizer runtime +- Status: `[~]` (CPU parity path landed; native GPU tokenization still open) +- Goal: + - Tokenization and detokenization run on GPU-backed buffers for high-throughput inference/training pipelines. +- Tasks: + - [x] Add `grilly.tokenizers` module with `Tokenizer` interface (`encode`, `decode`, `batch_encode`). Implementation: `tokenizer_impl/` → `grilly.tokenizers`; default `Tokenizer` is `FastTokenizer` (Rust `tokenizers` library, not `transformers`). + - [ ] Implement BPE/WordPiece fast path with GPU kernels for pretokenized merge/scoring stages. + - [x] Keep exact CPU fallback for unsupported edge cases (identical outputs). Current fallback is the Rust `tokenizers` CPU path; optional `numpy_to_input_ids_buffers` in `tokenizer_impl/gpu.py` for staging. + - [ ] Add `VulkanTensor`-friendly API: accept list[str] and return ids/attention masks in GPU-friendly layout (`wrap_ids_as_vulkan_tensors` exists behind `GRILLY_GPU_TOKENIZER=1` until wired). +- Acceptance: + - Deterministic token IDs vs reference tokenizer on supported models. + - >=2x throughput improvement vs CPU tokenizer on large batch benchmarks. + +### P0.2 HF tokenizer interoperability +- Status: `[~]` (Rust `tokenizer.json` + Hub fallbacks + BERT/DistilBERT/T5/mT5 parity tests; raw `spiece.model`-only repos still open) +- Tasks: + - [x] Load Hugging Face–compatible tokenizer assets: `tokenizer.json` from Hub (`huggingface_hub`), local dir, or file path (`tokenizer_impl/loader.py`). No `transformers` dependency in the Grilly tokenizer package. + - [x] Hub repos without root `tokenizer.json` (e.g. `google/mt5-small`): fall back to `onnx/tokenizer.json` when present — same Rust `tokenizers` pipeline, parity vs HF reference. 
+ - [ ] Add loader for **only** SentencePiece assets (`spiece.model` / no JSON export) with encode/decode parity vs HF. + - [~] Validation suite: `tests/tokenizers/test_gpu_tokenizer_parity.py` (BERT, DistilBERT); `tests/sentencepiece/test_sentencepiece_parity.py` (`t5-small`, `google/mt5-small`, special-token cases). Current run: all 6 tests passing in local env. +- Acceptance: + - Asset compatibility documented and tested for top target checkpoints (including SentencePiece-backed models like T5/LLaMA-family tokenizers). + +--- + +## P1: Sentence-Transformers support + +### P1.1 Inference parity and API surface +- Status: `[ ]` +- Goal: + - `SentenceTransformer`-style embedding API with drop-in ergonomics for common usage. +- Tasks: + - [ ] Add `grilly.sentence_transformers` wrapper with `encode()` behavior-compatible options (`batch_size`, `normalize_embeddings`, device semantics). + - [ ] Implement pooling strategies (`mean`, `cls`, `max`) and normalization parity. + - [ ] Validate cosine similarity/semantic search outputs against reference pipelines. +- Acceptance: + - Embedding outputs within tolerance across target ST models; API docs include known deltas. + +### P1.2 GPU-first embedding pipeline +- Status: `[ ]` +- Tasks: + - [ ] Route tokenizer -> encoder -> pooling through chain recorder where possible. + - [ ] Add `rec.embedding_lookup(ids, table) -> handle` so embed -> first layer stays GPU-resident (no CPU round-trip). + - [ ] Add `read_multiple` fan-out examples for MoE-style encoder blocks. +- Acceptance: + - Single-submit batching demonstrated in benchmarked ST inference path. + +--- + +## P2: Transformers compatibility (1:1 target) + +### P2.1 Signature and config compatibility +- Status: `[ ]` +- Goal: + - Match `transformers` module signatures and config behavior for core models. +- Tasks: + - [ ] Add compatibility matrix by model family (BERT, RoBERTa, MiniLM, GPT2-class decoder). + - [ ] Match key forward signatures (`input_ids`, `attention_mask`, `token_type_ids`, `position_ids`, `past_key_values` where applicable). + - [ ] Align output objects (`last_hidden_state`, `pooler_output`, logits) and shape conventions. +- Acceptance: + - Core families pass reference compatibility tests with documented exceptions. + +### P2.2 Numerical and behavioral parity +- Status: `[ ]` +- Tasks: + - [ ] Golden parity tests versus HF forward outputs (fp32 tolerances per op/family). + - [ ] Attention behavior policy and masks parity (`causal`, `padding`, mixed masks). + - [ ] Tokenizer-model handshake tests (special tokens, truncation/padding behavior). +- Acceptance: + - "1:1" means no user-visible API breakage for covered families in documented scenarios. + +--- + +## P3: PyTorch -> Grilly converter (after compatibility) + +### P3.1 Converter core +- Status: `[ ]` +- Goal: + - Convert PyTorch/HF checkpoints and model graphs into runnable Grilly modules. +- Tasks: + - [ ] Add `grilly.convert.from_pytorch(...)` entrypoint. + - [ ] Implement state_dict key mapping + tensor layout transforms. + - [ ] Generate conversion report (mapped/unmapped params, warnings, unsupported ops). +- Acceptance: + - Supported architectures convert and run inference with parity checks. + +### P3.2 CLI + migration UX +- Status: `[ ]` +- Tasks: + - [ ] Add CLI (`python -m grilly.convert ...`) with dry-run and validation modes. + - [ ] Add migration cookbook examples from PyTorch/HF to Grilly runtime. +- Acceptance: + - Users can convert, validate, and run model with a single documented workflow. 
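+
+The converter does not exist yet; the sketch below only illustrates the intended shape of P3.1 (key mapping + layout transform + conversion report). `KEY_MAP`, `convert_state_dict`, and the report fields are hypothetical names, not the final API:
+
+```python
+import numpy as np
+
+# Hypothetical mapping from PyTorch state_dict keys to Grilly parameter names.
+KEY_MAP = {
+    "encoder.layer.0.attention.self.query.weight": "layers.0.attn.q_proj.weight",
+}
+
+def _to_float32(t):
+    # torch.Tensor exposes detach/cpu/numpy; plain numpy arrays pass through.
+    if hasattr(t, "detach"):
+        t = t.detach().cpu().numpy()
+    return np.ascontiguousarray(np.asarray(t, dtype=np.float32))
+
+def convert_state_dict(torch_state_dict):
+    """Map keys, convert tensors to C-contiguous float32, report leftovers."""
+    mapped, unmapped = {}, []
+    for key, tensor in torch_state_dict.items():
+        target = KEY_MAP.get(key)
+        if target is None:
+            unmapped.append(key)
+            continue
+        mapped[target] = _to_float32(tensor)
+    report = {"mapped": len(mapped), "unmapped": unmapped, "warnings": []}
+    return mapped, report
+```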
+ +--- + +## Supporting Throughput Track (parallel, non-blocking) + +### T1 Chain recorder dependency-aware barriers +- Status: `[ ]` +- Tasks: + - [ ] Remove unconditional post-dispatch barriers where no RAW hazard exists. + - [ ] Add `force_barrier=True` debug mode. + - [ ] Add `rec.linear_backward(grad_out, input, weights) -> (grad_input_handle, grad_weight_handle)` for GPU backward matmul chaining. + - [ ] Add MoE fan-out microbench (4/8 experts). + +### T2 VulkanTensor residency hardening +- Status: `[ ]` +- Tasks: + - [ ] Guard for older `grilly_core` binaries lacking `gpu_handle_if_valid`. + - [ ] Add residency counters (fallback uploads/downloads) for profiling. + +### T3 C2+/C4+ infrastructure +- Status: `[ ]` +- Tasks: + - [ ] INT8 GEMM tiling and tuning (C2+) with MoQE-focused benchmarks (4-bit/8-bit expert paths). + - [ ] Add dequant-in-shader path so quantized weights do not require FP32 shadow copies during inference. + - [ ] Transfer queue + staging ring + overlap experiments (C4+). + +### T4 Autograd GPU-resident graph execution +- Status: `[ ]` +- Tasks: + - [ ] Make `Variable.backward()` detect chainable matmul -> relu -> matmul regions. + - [ ] Route detected backward regions through chain recorder (batched GPU dispatches) instead of one bridge call per topo node. + - [ ] Add fallback path to current traversal for unsupported ops while preserving gradient correctness. +- Acceptance: + - Backward pass correctness parity vs current autograd; fewer fence waits in profiled training steps. + +### T5 Fused CE + softmax chain op +- Status: `[ ]` +- Tasks: + - [ ] Add `rec.cross_entropy(logits_handle, targets) -> (loss_handle, grad_logits_handle)` fused forward+backward on GPU. + - [ ] Integrate with chain recorder API so LM training step can stay in a single submit region through loss/grad. + - [ ] Add correctness tests vs reference CE+softmax backward and benchmark on `(seq, vocab)`-sized logits. +- Acceptance: + - Fused CE op matches reference gradients within tolerance and reduces per-step synchronization overhead. + +--- + +## Verification Commands + +```bash +# Existing runtime/perf guardrails +uv run pytest tests/test_fnn_chain.py -v +uv run pytest tests/test_vulkan_tensor_residency.py -v +uv run pytest tests/test_conv_backward_weight_gemm.py -v +uv run pytest tests/parity/ -m parity -v +uv run python benchmarks/benchmark_conv_backward_weight.py +uv run python benchmarks/benchmark_int8_gemm.py + +# Tokenizer parity (Rust CPU path vs HF reference in tests) +uv run pytest tests/tokenizers/ -v +uv run pytest tests/sentencepiece/ -v +# VSA-LM fused forward/backward +uv run pytest tests/test_vsa_lm_forward.py -v +uv run pytest tests/test_moe_forward.py -v + +# Planned (new feature suites to add as implemented) +# uv run pytest tests/sentence_transformers/ -v +# uv run pytest tests/transformers_compat/ -v +# uv run pytest tests/converter/ -v +# uv run pytest tests/autograd_chain/ -v +# uv run pytest tests/moe_quant/ -v +``` + +--- + +## Exit Criteria for pre-v1.0 cut + +- GPU tokenizer shipped and validated against reference tokenizer outputs. *(Current: Rust `tokenizers` CPU path + parity tests for BERT/DistilBERT/T5/mT5; GPU merge/BPE kernels still to ship.)* +- Sentence-transformers pipeline shipped with documented parity bounds. +- Transformers compatibility target achieved for declared core families (documented matrix). +- PyTorch -> Grilly converter ships with dry-run + validation report. +- No known correctness regressions in chain recorder/residency paths. 
+- Performance regressions detectable with benchmark baselines committed in docs. diff --git a/examples/experimental_backend_vsa.py b/examples/experimental_backend_vsa.py index a93d67c..6aef93f 100644 --- a/examples/experimental_backend_vsa.py +++ b/examples/experimental_backend_vsa.py @@ -7,11 +7,11 @@ import numpy as np try: - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from grilly.backend.core import VulkanCore from grilly.backend.experimental.vsa import VulkanVSA - if not VULKAN_AVAILABLE: + if not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE: raise ImportError("Vulkan not available") except ImportError: print("Vulkan backend not available. Skipping GPU examples.") diff --git a/functional/__init__.py b/functional/__init__.py index 46d0837..9789243 100644 --- a/functional/__init__.py +++ b/functional/__init__.py @@ -10,6 +10,13 @@ softmax, softplus, ) +from .mf_activations import ( + mf_relu, + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) from .attention import ( attention, flash_attention2, @@ -96,6 +103,11 @@ "silu", "softmax", "softplus", + "mf_softmax", + "mf_softplus", + "mf_sigmoid", + "mf_sigmoid_01", + "mf_relu", # Linear "linear", # Normalization diff --git a/functional/activations.py b/functional/activations.py index 08c5a5b..e681e02 100644 --- a/functional/activations.py +++ b/functional/activations.py @@ -27,10 +27,8 @@ def relu(x: np.ndarray) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_relu(x) + x = np.asarray(x, dtype=np.float32) + return np.maximum(0.0, x).astype(np.float32) def gelu(x: np.ndarray) -> np.ndarray: @@ -46,10 +44,10 @@ def gelu(x: np.ndarray) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_gelu(x) + x = np.asarray(x, dtype=np.float32) + return ( + 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3))) + ).astype(np.float32) def silu(x: np.ndarray) -> np.ndarray: @@ -65,10 +63,8 @@ def silu(x: np.ndarray) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_silu(x) + x = np.asarray(x, dtype=np.float32) + return (x / (1.0 + np.exp(-x))).astype(np.float32) def softmax(x: np.ndarray, dim: int = -1) -> np.ndarray: @@ -84,10 +80,10 @@ def softmax(x: np.ndarray, dim: int = -1) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_softmax(x, dim=dim) + x = np.asarray(x, dtype=np.float32) + x_max = np.max(x, axis=dim, keepdims=True) + exp_x = np.exp(x - x_max) + return (exp_x / np.sum(exp_x, axis=dim, keepdims=True)).astype(np.float32) def softplus(x: np.ndarray) -> np.ndarray: diff --git a/functional/attention.py b/functional/attention.py index 5b2527a..2e2fc80 100644 --- a/functional/attention.py +++ b/functional/attention.py @@ -3,11 +3,40 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute +def _to_numpy(result): + """Convert bridge result to numpy if it's a C++ Tensor.""" + if result is None: + return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() 
+ return np.asarray(result) - return Compute() + +def _numpy_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray: + x_max = np.max(x, axis=axis, keepdims=True) + exp_x = np.exp(x - x_max) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) + + +def _numpy_attention( + query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: np.ndarray | None = None +) -> tuple[np.ndarray, np.ndarray]: + q = np.asarray(query, dtype=np.float32) + k = np.asarray(key, dtype=np.float32) + v = np.asarray(value, dtype=np.float32) + scale = 1.0 / np.sqrt(float(q.shape[-1])) + scores = (q @ np.swapaxes(k, -1, -2)) * scale + if mask is not None: + mask_arr = np.asarray(mask) + if mask_arr.dtype == np.bool_: + scores = np.where(mask_arr, scores, -1e9) + else: + scores = scores + mask_arr.astype(np.float32) + weights = _numpy_softmax(scores, axis=-1).astype(np.float32) + output = (weights @ v).astype(np.float32) + return output, weights def attention( @@ -26,19 +55,22 @@ def attention( Returns: Tuple of (output, attention_weights) """ - backend = _get_backend() - - # Compute attention scores - scores = backend.attention_scores(query, key) - - # Apply mask if provided - if mask is not None: - scores = backend.attention_mask(scores, mask) + try: + from grilly.backend import _bridge - # Compute attention output - output = backend.attention_output(scores, value) + scores = _bridge.attention_scores(query, key) + if scores is not None: + if mask is not None: + masked_scores = _bridge.attention_mask(scores, mask, False) + if masked_scores is not None: + scores = masked_scores + output = _bridge.attention_output(scores, value) + if output is not None: + return _to_numpy(output), _to_numpy(scores) + except (ImportError, Exception): + pass - return output, scores + return _numpy_attention(query, key, value, mask) def flash_attention2( @@ -57,5 +89,23 @@ def flash_attention2( Returns: Attention output """ - backend = _get_backend() - return backend.flash_attention2(query, key, value, use_rope=use_rope) + try: + from grilly.backend import _bridge + + q = query + k = key + if use_rope: + q_rope = _bridge.rope(q) + k_rope = _bridge.rope(k) + if q_rope is not None and k_rope is not None: + q = q_rope + k = k_rope + + result = _bridge.flash_attention2(q, k, value) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + + output, _ = _numpy_attention(query, key, value, mask=None) + return output diff --git a/functional/bridge.py b/functional/bridge.py index 53e1c49..449a9ef 100644 --- a/functional/bridge.py +++ b/functional/bridge.py @@ -8,14 +8,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def continuous_to_spikes( @@ -40,16 +40,30 @@ def continuous_to_spikes( Returns: Spike trains (batch, num_timesteps, spike_dim) """ - from grilly import Compute - - backend = Compute() - return backend.continuous_to_spikes( - continuous, - num_timesteps=num_timesteps, - encoding_type=encoding_type, - projection_weights=projection_weights, - projection_bias=projection_bias, - ) + try: + from grilly.backend import _bridge + + result = _bridge.continuous_to_spikes( + continuous, + num_timesteps=num_timesteps, + encoding_type=encoding_type, + projection_weights=projection_weights, + 
projection_bias=projection_bias, + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + x = np.asarray(continuous, dtype=np.float32) + if x.ndim == 1: + x = x[None, :] + if projection_weights is not None: + x = x @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + x = x + np.asarray(projection_bias, dtype=np.float32) + x = np.clip(x, 0.0, 1.0) + t = int(num_timesteps) + return (np.random.rand(x.shape[0], t, x.shape[1]).astype(np.float32) < x[:, None, :]).astype(np.float32) def spikes_to_continuous( @@ -76,17 +90,29 @@ def spikes_to_continuous( Returns: Continuous values (batch, output_dim) """ - from grilly import Compute - - backend = Compute() - return backend.spikes_to_continuous( - spikes, - encoding_type=encoding_type, - time_window=time_window, - temporal_weights=temporal_weights, - projection_weights=projection_weights, - projection_bias=projection_bias, - ) + try: + from grilly.backend import _bridge + + result = _bridge.spikes_to_continuous( + spikes, + encoding_type=encoding_type, + time_window=time_window, + temporal_weights=temporal_weights, + projection_weights=projection_weights, + projection_bias=projection_bias, + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + s = np.asarray(spikes, dtype=np.float32) + window = max(1, min(int(time_window), s.shape[1])) + cont = np.mean(s[:, -window:, :], axis=1) + if projection_weights is not None: + cont = cont @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + cont = cont + np.asarray(projection_bias, dtype=np.float32) + return cont.astype(np.float32) def bridge_temporal_weights(weights: np.ndarray, temporal_window: int = 10) -> np.ndarray: @@ -102,17 +128,6 @@ def bridge_temporal_weights(weights: np.ndarray, temporal_window: int = 10) -> n Returns: Temporally weighted weights """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "bridge-temporal-weights" in backend.shaders: - try: - # GPU temporal weights would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback - Apply temporal decay to weights temporal_decay = np.exp(-np.arange(temporal_window) / temporal_window) # Simple implementation: return weights with temporal scaling diff --git a/functional/cells.py b/functional/cells.py index f97d5b7..34ca1aa 100644 --- a/functional/cells.py +++ b/functional/cells.py @@ -7,14 +7,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def place_cell( @@ -39,16 +39,30 @@ def place_cell( Returns: Firing rates (n_neurons,) or (batch, n_neurons) """ - from grilly import Compute - - backend = Compute() - return backend.place_cell( - agent_position, - field_centers, - field_width=field_width, - max_rate=max_rate, - baseline_rate=baseline_rate, + try: + from grilly.backend import _bridge + + result = _bridge.place_cell( + agent_position, + field_centers, + field_width=field_width, + max_rate=max_rate, + baseline_rate=baseline_rate, + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + pos = np.asarray(agent_position, 
dtype=np.float32) + centers = np.asarray(field_centers, dtype=np.float32) + if pos.ndim == 1: + pos = pos[None, :] + diff = pos[:, None, :] - centers[None, :, :] + d2 = np.sum(diff * diff, axis=-1) + rates = baseline_rate + (max_rate - baseline_rate) * np.exp( + -d2 / (2.0 * field_width * field_width + 1e-8) ) + return rates.astype(np.float32).squeeze(0) if np.asarray(agent_position).ndim == 1 else rates.astype(np.float32) def time_cell( @@ -75,17 +89,31 @@ def time_cell( Returns: (firing_rates, updated_membrane_state) """ - from grilly import Compute - - backend = Compute() - return backend.time_cell( - current_time, - preferred_times, - time_constant=temporal_width, - max_rate=max_rate, - baseline_rate=baseline_rate, - membrane_state=membrane_state, + try: + from grilly.backend import _bridge + + result = _bridge.time_cell( + current_time, + preferred_times, + time_constant=temporal_width, + max_rate=max_rate, + baseline_rate=baseline_rate, + membrane_state=membrane_state, + ) + if result is not None: + rates, mem = result + return _to_numpy(rates), _to_numpy(mem) + except (ImportError, Exception): + pass + pref = np.asarray(preferred_times, dtype=np.float32) + rates = baseline_rate + (max_rate - baseline_rate) * np.exp( + -((float(current_time) - pref) ** 2) / (2.0 * temporal_width * temporal_width + 1e-8) ) + if membrane_state is None: + mem = rates.astype(np.float32) + else: + mem = (0.9 * np.asarray(membrane_state, dtype=np.float32) + 0.1 * rates).astype(np.float32) + return rates.astype(np.float32), mem def theta_gamma_encoding( @@ -112,17 +140,6 @@ def theta_gamma_encoding( Returns: Theta-gamma encoding (n_theta * n_gamma,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "theta-gamma-encoding" in backend.shaders: - try: - # GPU theta-gamma encoding would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback theta_phase = 2 * np.pi * theta_freq * time gamma_phase = 2 * np.pi * gamma_freq * time diff --git a/functional/dropout.py b/functional/dropout.py index 9515ced..e1746e1 100644 --- a/functional/dropout.py +++ b/functional/dropout.py @@ -6,11 +6,15 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute - - return Compute() +def _to_numpy(result): + """Convert bridge result to numpy if it's a C++ Tensor.""" + if result is None: + return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def dropout(input: np.ndarray, p: float = 0.5, training: bool = True) -> np.ndarray: @@ -26,8 +30,21 @@ def dropout(input: np.ndarray, p: float = 0.5, training: bool = True) -> np.ndar Returns: Output tensor with dropout applied (if training) """ + input_arr = np.asarray(input, dtype=np.float32) if not training or p == 0: - return input - - backend = _get_backend() - return backend.fnn.dropout(input, dropout_prob=p, is_training=training) + return input_arr + if p >= 1.0: + return np.zeros_like(input_arr, dtype=np.float32) + + random_mask = (np.random.rand(*input_arr.shape) >= p).astype(np.float32) + try: + from grilly.backend import _bridge + + result = _bridge.dropout(input_arr, random_mask, p=p, training=training) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + + scale = 1.0 / (1.0 - p) + return (input_arr * random_mask * scale).astype(np.float32) diff --git a/functional/embedding.py 
b/functional/embedding.py index fb299e8..2e60ba7 100644 --- a/functional/embedding.py +++ b/functional/embedding.py @@ -8,14 +8,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def embedding_lookup(weight: np.ndarray, indices: np.ndarray) -> np.ndarray: @@ -31,11 +31,15 @@ def embedding_lookup(weight: np.ndarray, indices: np.ndarray) -> np.ndarray: Returns: Embeddings (batch, seq_len, embedding_dim) or (batch, embedding_dim) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - # Backend expects (token_ids, embedding_table) - note the order! - return backend.embedding_lookup(indices, weight) + result = _bridge.embedding_lookup(indices, weight) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + return weight[indices] def embedding_normalize(embeddings: np.ndarray, eps: float = 1e-8) -> np.ndarray: @@ -51,17 +55,6 @@ def embedding_normalize(embeddings: np.ndarray, eps: float = 1e-8) -> np.ndarray Returns: Normalized embeddings (same shape) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-normalize" in backend.shaders: - try: - # GPU embedding normalization would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback norm = np.linalg.norm(embeddings, axis=-1, keepdims=True) return embeddings / (norm + eps) @@ -83,27 +76,15 @@ def embedding_position( Returns: Embeddings with positional encoding (same shape) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-position" in backend.shaders: - try: - # GPU positional encoding would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - - # CPU fallback (sinusoidal positional encoding) + # CPU fallback (vectorized sinusoidal positional encoding) batch_size, seq_len, dim = embeddings.shape - # Create positional encoding + positions = np.arange(seq_len, dtype=np.float32)[:, None] + div_term = np.exp(-(np.log(10000.0) * np.arange(0, dim, 2, dtype=np.float32) / dim)) + angles = positions * div_term[None, :] pos_enc = np.zeros((seq_len, dim), dtype=np.float32) - for pos in range(seq_len): - for i in range(0, dim, 2): - pos_enc[pos, i] = np.sin(pos / (10000 ** (i / dim))) - if i + 1 < dim: - pos_enc[pos, i + 1] = np.cos(pos / (10000 ** (i / dim))) + pos_enc[:, 0::2] = np.sin(angles) + pos_enc[:, 1::2] = np.cos(angles[:, : pos_enc[:, 1::2].shape[1]]) return embeddings + pos_enc[None, :, :] @@ -121,17 +102,6 @@ def embedding_pool(embeddings: np.ndarray, pool_type: str = "mean") -> np.ndarra Returns: Pooled embeddings (batch, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-pool" in backend.shaders: - try: - # GPU embedding pooling would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback if pool_type == "mean": return embeddings.mean(axis=1) @@ -167,17 +137,6 @@ def embedding_ffn( Returns: Output embeddings (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - 
if backend and hasattr(backend, "shaders") and "embedding-ffn" in backend.shaders: - try: - # GPU embedding FFN would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback x = embeddings @ W1.T + b1 @@ -209,18 +168,8 @@ def embedding_attention( Returns: Attended embeddings (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-attention" in backend.shaders: - try: - # GPU embedding attention would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback (simplified) from grilly.functional.attention import attention - return attention(embeddings, embeddings, embeddings, num_heads=num_heads) + output, _ = attention(embeddings, embeddings, embeddings) + return output diff --git a/functional/faiss.py b/functional/faiss.py index 7ee0c49..7aad572 100644 --- a/functional/faiss.py +++ b/functional/faiss.py @@ -8,14 +8,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def faiss_distance(query: np.ndarray, vectors: np.ndarray, distance_type: str = "l2") -> np.ndarray: @@ -32,25 +32,25 @@ def faiss_distance(query: np.ndarray, vectors: np.ndarray, distance_type: str = Returns: Distances (batch, num_vectors) or (num_vectors,) """ - from grilly import Compute - - backend = Compute() - if hasattr(backend, "faiss") and hasattr(backend.faiss, "compute_distances"): - return backend.faiss.compute_distances(query, vectors, distance_type=distance_type) - else: - # CPU fallback - if query.ndim == 1: - query = query.reshape(1, -1) - - if distance_type == "l2": - diff = query[:, None, :] - vectors[None, :, :] - return np.sqrt(np.sum(diff**2, axis=2)) - elif distance_type == "cosine": - q_norm = query / (np.linalg.norm(query, axis=1, keepdims=True) + 1e-8) - v_norm = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8) - return 1 - np.dot(q_norm, v_norm.T) - else: # dot - return -np.dot(query, vectors.T) + try: + from grilly.backend import _bridge + + result = _bridge.faiss_distance(query, vectors, distance_type=distance_type) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + + if query.ndim == 1: + query = query.reshape(1, -1) + if distance_type == "l2": + diff = query[:, None, :] - vectors[None, :, :] + return np.sqrt(np.sum(diff**2, axis=2)) + if distance_type == "cosine": + q_norm = query / (np.linalg.norm(query, axis=1, keepdims=True) + 1e-8) + v_norm = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8) + return 1 - np.dot(q_norm, v_norm.T) + return -np.dot(query, vectors.T) def faiss_topk(distances: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]: @@ -66,16 +66,17 @@ def faiss_topk(distances: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]: Returns: (indices, topk_distances) - both (batch, k) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - if hasattr(backend, "faiss") and hasattr(backend.faiss, "topk"): - return backend.faiss.topk(distances, k) - else: - # CPU fallback - indices = np.argsort(distances, axis=1)[:, :k] - topk_distances = np.take_along_axis(distances, indices, 
axis=1) - return indices, topk_distances + result = _bridge.faiss_topk(distances, k) + if result is not None: + return result + except (ImportError, Exception): + pass + indices = np.argsort(distances, axis=1)[:, :k] + topk_distances = np.take_along_axis(distances, indices, axis=1) + return indices, topk_distances def faiss_ivf_filter(vectors: np.ndarray, centroids: np.ndarray, nlist: int = 100) -> np.ndarray: @@ -92,17 +93,6 @@ def faiss_ivf_filter(vectors: np.ndarray, centroids: np.ndarray, nlist: int = 10 Returns: Cluster assignments (num_vectors,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "faiss-kmeans-update" in backend.shaders: - try: - # GPU FAISS kmeans update would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback - Assign each vector to nearest centroid distances = np.linalg.norm(vectors[:, None, :] - centroids[None, :, :], axis=2) assignments = np.argmin(distances, axis=1) @@ -126,33 +116,17 @@ def faiss_kmeans_update( Returns: Updated centroids (nlist, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "faiss-kmeans-update" in backend.shaders: - try: - # GPU FAISS kmeans update would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - - # CPU fallback - new_centroids = np.zeros_like(centroids) - counts = np.zeros(nlist, dtype=np.int32) - - for i, vec in enumerate(vectors): - cluster = assignments[i] - new_centroids[cluster] += vec - counts[cluster] += 1 - - # Normalize by counts - for c in range(nlist): - if counts[c] > 0: - new_centroids[c] /= counts[c] - else: - new_centroids[c] = centroids[c] # Keep old centroid if no vectors assigned - + # CPU fallback (vectorized) + vectors = np.asarray(vectors, dtype=np.float32) + centroids = np.asarray(centroids, dtype=np.float32) + assignments = np.asarray(assignments, dtype=np.int32) + dim = vectors.shape[1] + new_centroids = np.zeros((nlist, dim), dtype=np.float32) + np.add.at(new_centroids, assignments, vectors) + counts = np.bincount(assignments, minlength=nlist).astype(np.float32) + nonzero = counts > 0 + new_centroids[nonzero] /= counts[nonzero, None] + new_centroids[~nonzero] = centroids[~nonzero] return new_centroids @@ -172,17 +146,6 @@ def faiss_quantize( Returns: (quantized_vectors, codes) - codes (num_vectors,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "faiss-quantize" in backend.shaders: - try: - # GPU FAISS quantization would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback - Find nearest codebook entry for each vector distances = np.linalg.norm(vectors[:, None, :] - codebook[None, :, :], axis=2) codes = np.argmin(distances, axis=1) diff --git a/functional/fft.py b/functional/fft.py index 6a32a73..c511155 100644 --- a/functional/fft.py +++ b/functional/fft.py @@ -3,13 +3,6 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute - - return Compute() - - def fft(input: np.ndarray) -> np.ndarray: """ Fast Fourier Transform @@ -21,9 +14,7 @@ def fft(input: np.ndarray) -> np.ndarray: Returns: FFT output (complex) """ - _get_backend() - # Note: May need to implement FFT in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.fft.fft(input) @@ -38,9 +29,7 @@ def ifft(input: np.ndarray) -> 
np.ndarray: Returns: Reconstructed signal """ - _get_backend() - # Note: May need to implement IFFT in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.fft.ifft(input) @@ -55,9 +44,7 @@ def fft_magnitude(input: np.ndarray) -> np.ndarray: Returns: Magnitude spectrum """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.abs(input) @@ -72,7 +59,5 @@ def fft_power_spectrum(input: np.ndarray) -> np.ndarray: Returns: Power spectrum """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.abs(input) ** 2 diff --git a/functional/learning.py b/functional/learning.py index 982d761..ae63014 100644 --- a/functional/learning.py +++ b/functional/learning.py @@ -10,14 +10,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def fisher_info( @@ -42,12 +42,22 @@ def fisher_info( Returns: Updated Fisher information """ - from grilly import Compute - - backend = Compute() - return backend.fisher_info_update( - gradients, fisher, momentum=momentum, use_ema=use_ema, reset=reset - ) + try: + from grilly.backend import _bridge + + result = _bridge.fisher_info_update( + gradients, fisher, momentum=momentum, use_ema=use_ema, reset=reset + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + gradients = np.asarray(gradients, dtype=np.float32) + fisher = np.zeros_like(gradients) if reset else np.asarray(fisher, dtype=np.float32) + g2 = gradients * gradients + if use_ema: + return (momentum * fisher + (1.0 - momentum) * g2).astype(np.float32) + return (fisher + g2).astype(np.float32) def ewc_penalty( @@ -70,10 +80,17 @@ def ewc_penalty( Returns: EWC penalty value """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.ewc_penalty(current_params, important_params, fisher, lambda_ewc=lambda_ewc) + result = _bridge.ewc_penalty(current_params, important_params, fisher, lambda_ewc=lambda_ewc) + if result is not None: + return float(result) + except (ImportError, Exception): + pass + diff = np.asarray(current_params, dtype=np.float32) - np.asarray(important_params, dtype=np.float32) + f = np.asarray(fisher, dtype=np.float32) + return float(0.5 * lambda_ewc * np.sum(f * diff * diff, dtype=np.float32)) def natural_gradient(gradients: np.ndarray, fisher: np.ndarray, eps: float = 1e-8) -> np.ndarray: @@ -90,10 +107,17 @@ def natural_gradient(gradients: np.ndarray, fisher: np.ndarray, eps: float = 1e- Returns: Natural gradient: F^(-1) * grad """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.natural_gradient(gradients, fisher, eps=eps) + result = _bridge.natural_gradient(gradients, fisher, eps=eps) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + gradients = np.asarray(gradients, dtype=np.float32) + fisher = np.asarray(fisher, dtype=np.float32) + return (gradients / (fisher + eps)).astype(np.float32) def fisher_normalize(fisher: np.ndarray) -> np.ndarray: @@ -108,17 +132,6 @@ def fisher_normalize(fisher: np.ndarray) -> 
np.ndarray: Returns: Normalized Fisher information """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "fisher-normalize" in backend.shaders: - try: - # GPU Fisher normalization would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback fisher_sum = np.sum(fisher) if fisher_sum > 0: @@ -140,10 +153,15 @@ def nlms_predict(x: np.ndarray, w: np.ndarray, bias: float = 0.0) -> float: Returns: Predicted value """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.nlms_predict(x, w, bias) + result = _bridge.nlms_predict(x, w, bias) + if result is not None: + return float(result) + except (ImportError, Exception): + pass + return float(np.dot(np.asarray(x, dtype=np.float32), np.asarray(w, dtype=np.float32)) + bias) def nlms_update( @@ -165,10 +183,20 @@ def nlms_update( Returns: (updated_weights, updated_bias) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.nlms_update(x, y_true, w, bias, mu=mu, eps=eps) + result = _bridge.nlms_update(x, y_true, w, bias, mu=mu, eps=eps) + if result is not None: + return result + except (ImportError, Exception): + pass + x_arr = np.asarray(x, dtype=np.float32) + w_arr = np.asarray(w, dtype=np.float32) + pred = float(np.dot(x_arr, w_arr) + bias) + err = float(y_true) - pred + step = float(mu) * err / float(np.dot(x_arr, x_arr) + eps) + return (w_arr + step * x_arr).astype(np.float32), float(bias + step) def nlms_ensemble(x: np.ndarray, weights_list: list, biases_list: list) -> np.ndarray: @@ -185,17 +213,6 @@ def nlms_ensemble(x: np.ndarray, weights_list: list, biases_list: list) -> np.nd Returns: Ensemble predictions (num_experts,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "nlms-ensemble" in backend.shaders: - try: - # GPU NLMS ensemble would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback predictions = [] for w, b in zip(weights_list, biases_list): @@ -217,17 +234,6 @@ def nlms_metrics(errors: np.ndarray, update_count: int) -> dict: Returns: Dictionary with metrics (rmse, mae, etc.) 
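As a rough sketch of how the NLMS pieces above compose: the CPU fallbacks amount to the standard normalized-LMS recursion — predict with `dot(x, w) + bias`, then scale the error by `mu / (dot(x, x) + eps)`. The loop below mirrors that math in plain NumPy; `mu=0.5` and the toy target are illustrative choices, not library defaults.

```python
import numpy as np

def nlms_step(x, w, bias, y_true, mu=0.5, eps=1e-8):
    # Mirrors the nlms_predict / nlms_update CPU fallbacks above.
    pred = float(np.dot(x, w) + bias)
    err = float(y_true) - pred
    step = mu * err / (float(np.dot(x, x)) + eps)
    return w + step * x, bias + step, err

rng = np.random.default_rng(0)
true_w = np.array([0.5, -1.0, 2.0], dtype=np.float32)
w, bias = np.zeros(3, dtype=np.float32), 0.0
for _ in range(200):
    x = rng.standard_normal(3).astype(np.float32)
    w, bias, err = nlms_step(x, w, bias, float(x @ true_w))
print(np.round(w, 2), round(bias, 2))  # w approaches true_w; bias stays small
```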
""" - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "nlms-metrics" in backend.shaders: - try: - # GPU NLMS metrics would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback rmse = np.sqrt(np.mean(errors**2)) mae = np.mean(np.abs(errors)) @@ -250,10 +256,20 @@ def whitening_transform( Returns: (whitened_data, mean, std) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.whitening_transform(data, mean=mean, std=std) + result = _bridge.whitening_transform(data, mean=mean, std=std) + if result is not None: + return result + except (ImportError, Exception): + pass + data = np.asarray(data, dtype=np.float32) + if mean is None: + mean = np.mean(data, axis=0) + if std is None: + std = np.std(data, axis=0) + return ((data - mean) / (std + 1e-8)).astype(np.float32), mean.astype(np.float32), std.astype(np.float32) def whitening_apply(data: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray: @@ -270,17 +286,6 @@ def whitening_apply(data: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.n Returns: Whitened data """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "whitening-apply" in backend.shaders: - try: - # GPU whitening apply would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback return (data - mean) / (std + 1e-8) @@ -297,17 +302,6 @@ def whitening_batch_stats(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: Returns: (mean, std) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "whitening-batch-stats" in backend.shaders: - try: - # GPU whitening batch stats would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback mean = np.mean(data, axis=0) std = np.std(data, axis=0) diff --git a/functional/linear.py b/functional/linear.py index 258910c..5fd0be5 100644 --- a/functional/linear.py +++ b/functional/linear.py @@ -38,7 +38,9 @@ def linear(input: np.ndarray, weight: np.ndarray, bias: np.ndarray | None = None return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().fnn.linear(input, weight, bias) + input_arr = np.asarray(input, dtype=np.float32) + weight_arr = np.asarray(weight, dtype=np.float32) + output = input_arr @ weight_arr.T + if bias is not None: + output = output + np.asarray(bias, dtype=np.float32) + return np.asarray(output, dtype=np.float32) diff --git a/functional/loss.py b/functional/loss.py index ea7994e..31691b5 100644 --- a/functional/loss.py +++ b/functional/loss.py @@ -6,11 +6,14 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute - - return Compute() +def _to_numpy(result): + if result is None: + return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def cross_entropy( @@ -29,9 +32,21 @@ def cross_entropy( Returns: Loss value(s) """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + try: + from grilly.backend import _bridge + + if target.ndim < input.ndim and target.dtype.kind in {"i", "u"}: + gpu_loss = _bridge.cross_entropy_loss(input, target) + if gpu_loss is 
not None: + loss = _to_numpy(gpu_loss) + if reduction == "mean": + return np.mean(loss) + if reduction == "sum": + return np.sum(loss) + return loss + except (ImportError, Exception): + pass + # CPU fallback input_softmax = np.exp(input - np.max(input, axis=-1, keepdims=True)) input_softmax = input_softmax / np.sum(input_softmax, axis=-1, keepdims=True) @@ -71,9 +86,7 @@ def binary_cross_entropy( Returns: Loss value(s) """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + # CPU fallback input_clamped = np.clip(input, 1e-8, 1 - 1e-8) loss = -(target * np.log(input_clamped) + (1 - target) * np.log(1 - input_clamped)) diff --git a/functional/memory.py b/functional/memory.py index 13928b0..5eabeb4 100644 --- a/functional/memory.py +++ b/functional/memory.py @@ -9,14 +9,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def memory_read( @@ -39,12 +39,21 @@ def memory_read( Returns: Retrieved values (batch, value_dim) """ - from grilly import Compute - - backend = Compute() if temperature is None: temperature = np.sqrt(memory_keys.shape[1]) - return backend.memory_read(queries, memory_keys, memory_values, temperature) + try: + from grilly.backend import _bridge + + result = _bridge.memory_read(queries, memory_keys, memory_values, temperature) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + scores = (queries @ memory_keys.T) / max(float(temperature), 1e-8) + scores = scores - np.max(scores, axis=-1, keepdims=True) + weights = np.exp(scores) + weights = weights / np.sum(weights, axis=-1, keepdims=True) + return (weights @ memory_values).astype(np.float32) def memory_write( @@ -73,12 +82,27 @@ def memory_write( Returns: (updated_memory_keys, updated_memory_values) """ - from grilly import Compute - - backend = Compute() - return backend.memory_write( - new_key, new_value, memory_keys, memory_values, write_index, write_mode, blend_factor - ) + try: + from grilly.backend import _bridge + + result = _bridge.memory_write( + new_key, new_value, memory_keys, memory_values, write_index, write_mode, blend_factor + ) + if result is not None: + return result + except (ImportError, Exception): + pass + mk = np.array(memory_keys, dtype=np.float32, copy=True) + mv = np.array(memory_values, dtype=np.float32, copy=True) + idx = int(write_index) + if int(write_mode) == 1: + alpha = float(blend_factor) + mk[idx] = (1.0 - alpha) * mk[idx] + alpha * np.asarray(new_key, dtype=np.float32) + mv[idx] = (1.0 - alpha) * mv[idx] + alpha * np.asarray(new_value, dtype=np.float32) + else: + mk[idx] = np.asarray(new_key, dtype=np.float32) + mv[idx] = np.asarray(new_value, dtype=np.float32) + return mk, mv def memory_context_aggregate(memory_contexts: np.ndarray) -> np.ndarray: @@ -93,17 +117,6 @@ def memory_context_aggregate(memory_contexts: np.ndarray) -> np.ndarray: Returns: Aggregated context (batch, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if hasattr(backend, "shaders") and "memory-context-aggregate" in backend.shaders: - try: - # GPU memory context aggregation would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback (mean pooling) 
return memory_contexts.mean(axis=1) @@ -122,10 +135,16 @@ def memory_query_pooling(x: np.ndarray, W_query: np.ndarray, b_query: np.ndarray Returns: Query vectors (batch, out_dim) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.memory_query_pooling(x, W_query, b_query) + result = _bridge.memory_query_pooling(x, W_query, b_query) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + pooled = np.mean(np.asarray(x, dtype=np.float32), axis=1) + return (pooled @ np.asarray(W_query, dtype=np.float32).T + np.asarray(b_query, dtype=np.float32)).astype(np.float32) def memory_inject_concat( @@ -148,17 +167,6 @@ def memory_inject_concat( Returns: Output (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "memory-inject-concat" in backend.shaders: - try: - # GPU memory injection with concat would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback batch_size, seq_len, dim = x.shape mem_expanded = memory_context[:, None, :] # (batch, 1, dim) @@ -192,10 +200,23 @@ def memory_inject_gate( Returns: Output (batch, seq_len, dim) """ - from grilly import Compute - - backend = Compute() - return backend.memory_inject_gate(x, memory_context, W_gate, b_gate, W_mem_proj) + try: + from grilly.backend import _bridge + + result = _bridge.memory_inject_gate(x, memory_context, W_gate, b_gate, W_mem_proj) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + x_arr = np.asarray(x, dtype=np.float32) + mem = np.asarray(memory_context, dtype=np.float32) + batch, seq_len, dim = x_arr.shape + mem_expanded = np.broadcast_to(mem[:, None, :], (batch, seq_len, dim)) + concat = np.concatenate([x_arr, mem_expanded], axis=-1) + gate = 1.0 / (1.0 + np.exp(-(concat @ np.asarray(W_gate, dtype=np.float32).T + np.asarray(b_gate, dtype=np.float32)))) + mem_proj = mem @ np.asarray(W_mem_proj, dtype=np.float32).T + mem_proj = np.broadcast_to(mem_proj[:, None, :], (batch, seq_len, dim)) + return ((1.0 - gate) * x_arr + gate * mem_proj).astype(np.float32) def memory_inject_residual( @@ -218,17 +239,6 @@ def memory_inject_residual( Returns: Output (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "memory-inject-residual" in backend.shaders: - try: - # GPU memory injection with residual would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback batch_size, seq_len, dim = x.shape mem_proj = memory_context @ mem_proj_weight.T diff --git a/functional/mf_activations.py b/functional/mf_activations.py new file mode 100644 index 0000000..8175653 --- /dev/null +++ b/functional/mf_activations.py @@ -0,0 +1,65 @@ +""" +Multiplication-free (or low-mul) activations for VSA / addition-style pipelines. + +- **mf_softmax**: ReLU-normalized probabilities — no ``exp``; only ``max``, subtract, + ``relu``, sum, divide (same spirit as ``addition-linear.glsl``: add/sub/abs in the core). +- **mf_softplus**: Algebraic softplus — ``0.5 * (x + sqrt(x^2 + c))`` with ``c = 4/beta²``. + Smooth, no ``log``/``exp``; uses ``sqrt`` (and scalar scaling), not GEMM-style multiply. +- **mf_sigmoid**: Rational sigmoid — ``x / (1 + |x|)``, bounded in ``(-1, 1)``; scale to + ``(0, 1)`` with ``0.5 * (1 + ...)`` if needed via :func:`mf_sigmoid_01`. 
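For a quick feel of the two smooth variants listed above, the closed forms evaluate as follows (values approximate; `grilly.functional.mf_activations` is the module added in this diff, and the inputs are illustrative only):

```python
import numpy as np
from grilly.functional.mf_activations import mf_sigmoid, mf_sigmoid_01, mf_softplus

x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)

print(mf_softplus(x, beta=1.0))  # ~[0.41, 1.00, 3.30] — 0.5*(x + sqrt(x^2 + 4)), always positive
print(mf_sigmoid(x))             # ~[-0.67, 0.00, 0.75] — x/(1+|x|), bounded in (-1, 1)
print(mf_sigmoid_01(x))          # ~[0.17, 0.50, 0.88] — rescaled to (0, 1)
```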
+""" + +from __future__ import annotations + +import numpy as np + + +def mf_softmax(x: np.ndarray, dim: int = -1, eps: float = 1e-12) -> np.ndarray: + """ReLU-normalized softmax: ``relu(x - max) / sum(relu(x - max))``. + + No exponential; suitable as a sparse, addition-heavy alternative to softmax. + """ + x = np.asarray(x, dtype=np.float32) + axis = dim if dim >= 0 else x.ndim + dim + m = np.max(x, axis=axis, keepdims=True) + z = np.maximum(x - m, 0.0).astype(np.float32) + s = np.sum(z, axis=axis, keepdims=True, dtype=np.float64) + denom = np.maximum(s, eps) + y = (z / denom).astype(np.float32) + # Degenerate: all logits tied → z is all zeros → use uniform distribution + tot = np.sum(y, axis=axis, keepdims=True) + nfeat = float(x.shape[axis]) + unif = np.ones_like(x, dtype=np.float32) / nfeat + return np.where(tot > 1e-8, y, unif).astype(np.float32) + + +def mf_softplus(x: np.ndarray, beta: float = 1.0) -> np.ndarray: + """Algebraic (exp-free) softplus: ``(x + sqrt(x^2 + 4/beta^2)) / 2``. + + Matches ``softplus`` shape qualitatively; uses ``sqrt`` instead of ``log``/``exp``. + """ + x = np.asarray(x, dtype=np.float32) + b = float(beta) + if b <= 0: + raise ValueError("beta must be positive") + c = 4.0 / (b * b) + s = np.sqrt(x * x + c) + return (0.5 * (x + s)).astype(np.float32) + + +def mf_sigmoid(x: np.ndarray) -> np.ndarray: + """Rational sigmoid: ``x / (1 + |x|)``, range ``(-1, 1)``.""" + x = np.asarray(x, dtype=np.float32) + ax = np.abs(x) + 1.0 + return (x / ax).astype(np.float32) + + +def mf_sigmoid_01(x: np.ndarray) -> np.ndarray: + """Map :func:`mf_sigmoid` to ``(0, 1)``: ``0.5 * (1 + x/(1+|x|))``.""" + return (0.5 * (1.0 + mf_sigmoid(x))).astype(np.float32) + + +def mf_relu(x: np.ndarray) -> np.ndarray: + """``max(0, x)`` — multiplication-free nonlinearity.""" + x = np.asarray(x, dtype=np.float32) + return np.maximum(x, 0.0).astype(np.float32) diff --git a/functional/normalization.py b/functional/normalization.py index c262b56..d3c438c 100644 --- a/functional/normalization.py +++ b/functional/normalization.py @@ -51,7 +51,10 @@ def layer_norm( return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().layernorm(input, weight, bias, eps=eps) + input_arr = np.asarray(input, dtype=np.float32) + weight_arr = np.asarray(weight, dtype=np.float32) + bias_arr = np.asarray(bias, dtype=np.float32) + mean = np.mean(input_arr, axis=-1, keepdims=True) + var = np.var(input_arr, axis=-1, keepdims=True) + normalized = (input_arr - mean) / np.sqrt(var + eps) + return (normalized * weight_arr + bias_arr).astype(np.float32) diff --git a/nn/__init__.py b/nn/__init__.py index b4f543c..fa2218b 100644 --- a/nn/__init__.py +++ b/nn/__init__.py @@ -52,6 +52,9 @@ lt, matmul, max, + mf_sigmoid, + mf_softmax, + mf_softplus, mean, min, mse_loss, @@ -72,6 +75,7 @@ # Shapes reshape, sigmoid, + sign, silu, # Trigonometric sin, @@ -141,7 +145,13 @@ MemoryRead, MemoryWrite, ) +from .addition_linear import AdditionLinear +from . import functional +from . import init from .module import Module +from .module_list import ModuleList +from . 
import utils +from .vsa_lm import VsaLmModel from .modules import ( GCU, GELU, @@ -280,6 +290,10 @@ __all__ = [ # Base class "Module", + "ModuleList", + "functional", + "init", + "utils", # Standard layers "Linear", "LayerNorm", @@ -370,6 +384,9 @@ "DomainPredictor", "DomainClassifier", "ExpertCombiner", + # AdditionLinear (VSA) + "AdditionLinear", + "VsaLmModel", # Affect layers (when implemented) "AffectMLP", # Capsule layers @@ -414,6 +431,9 @@ "neg", "pow", "matmul", + "mf_sigmoid", + "mf_softmax", + "mf_softplus", # Reductions "sum", "mean", @@ -425,6 +445,7 @@ # Activations "relu", "sigmoid", + "sign", "tanh", "exp", "log", diff --git a/nn/_perf_policy.py b/nn/_perf_policy.py new file mode 100644 index 0000000..f755455 --- /dev/null +++ b/nn/_perf_policy.py @@ -0,0 +1,77 @@ +"""Adaptive runtime policy to avoid slower-than-CPU execution.""" + +from __future__ import annotations + +import os +import time + +_AUTO_FASTEST = os.getenv("GRILLY_AUTO_FASTEST", "1").strip().lower() in { + "1", + "true", + "yes", + "on", +} +_DECISIONS: dict[str, str] = {} + + +def _time_ms(fn, warmup: int = 1, repeats: int = 2): + for _ in range(max(0, warmup)): + fn() + t = [] + out = None + for _ in range(max(1, repeats)): + t0 = time.perf_counter() + out = fn() + t1 = time.perf_counter() + t.append((t1 - t0) * 1000.0) + return (sum(t) / len(t)), out + + +def choose_fastest(op_key: str, gpu_fn, cpu_fn): + """Run fastest backend for this op/shape key and cache choice.""" + if not _AUTO_FASTEST: + return gpu_fn() + + decision = _DECISIONS.get(op_key) + if decision == "gpu": + try: + out = gpu_fn() + if out is not None: + return out + except Exception: + pass + _DECISIONS[op_key] = "cpu" + return cpu_fn() + if decision == "cpu": + return cpu_fn() + + # First encounter for this shape/op: benchmark both paths. + gpu_ms = float("inf") + gpu_out = None + try: + gpu_ms, gpu_out = _time_ms(gpu_fn, warmup=1, repeats=2) + except Exception: + gpu_ms = float("inf") + gpu_out = None + + try: + cpu_ms, cpu_out = _time_ms(cpu_fn, warmup=1, repeats=2) + except Exception: + # If CPU path fails, force GPU. + _DECISIONS[op_key] = "gpu" + return gpu_out + + if gpu_out is not None and gpu_ms <= cpu_ms * 0.98: + _DECISIONS[op_key] = "gpu" + return gpu_out + + _DECISIONS[op_key] = "cpu" + return cpu_out + + +def get_perf_decisions() -> dict[str, str]: + return dict(_DECISIONS) + + +def reset_perf_decisions(): + _DECISIONS.clear() diff --git a/nn/addition_linear.py b/nn/addition_linear.py new file mode 100644 index 0000000..928c45d --- /dev/null +++ b/nn/addition_linear.py @@ -0,0 +1,150 @@ +""" +AdditionLinear — multiplication-free linear layer using L1 distance. + +output[b, o] = -sum_k |weight[o, k] - input[b, k]| + bias[o] + +Uses: addition-linear.glsl (compute shader). +Backward uses the subgradient: d/dx |a - b| = -sign(a - b). +""" + +import numpy as np + +from ._helpers import ( + _PARAMETER_AVAILABLE, + _USE_CPP_BRIDGE, + ParameterClass, + _bridge, + _bridge_to_numpy, + _create_param_wrapper, + _get_param_array, +) +from .module import Module + + +class AdditionLinear(Module): + """Multiplication-free linear layer (L1 / Manhattan distance). + + ``y = -||W - x||_1 + b`` per output neuron. Uses the ``addition-linear`` + Vulkan compute shader when available, with a NumPy CPU fallback. 
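A minimal NumPy sketch of the forward rule stated above (illustrative shapes: batch 2, 3 inputs, 4 output neurons); the layer's CPU fallback in `forward` computes exactly this quantity per output neuron, with the Vulkan `addition-linear` shader replacing it when present:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3)).astype(np.float32)   # (batch, in_features)
W = rng.standard_normal((4, 3)).astype(np.float32)   # (out_features, in_features)
b = np.zeros(4, dtype=np.float32)

# y[s, o] = -sum_k |W[o, k] - x[s, k]| + b[o]
y = -np.abs(W[None, :, :] - x[:, None, :]).sum(axis=-1) + b
print(y.shape)  # (2, 4); each output is at most b[o], peaking when x matches the row W[o]
```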
+ """ + + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + limit = np.sqrt(6.0 / (in_features + out_features)) + weight_data = np.random.uniform( + -limit, limit, (out_features, in_features) + ).astype(np.float32) + + self.weight = _create_param_wrapper(weight_data) + self.register_parameter("weight", self.weight) + + if bias: + self.bias_param = _create_param_wrapper( + np.zeros(out_features, dtype=np.float32) + ) + self.register_parameter("bias", self.bias_param) + else: + self.bias_param = None + + self._last_input = None + + @property + def d_in(self) -> int: + """Alias for ``in_features`` (PyTorch-style naming).""" + return self.in_features + + def forward(self, x) -> np.ndarray: + x_arr = np.ascontiguousarray(x, dtype=np.float32) + self._last_input = x_arr + + w = _get_param_array(self.weight) + b = _get_param_array(self.bias_param) if self.bias_param is not None else None + + if _USE_CPP_BRIDGE and hasattr(_bridge, "addition_linear"): + result = _bridge_to_numpy(_bridge.addition_linear(x_arr, w, b)) + if result is not None: + return result + + if x_arr.ndim == 1: + x_arr = x_arr.reshape(1, -1) + original_shape = x_arr.shape + if x_arr.ndim > 2: + batch_seq = int(np.prod(original_shape[:-1])) + x_2d = x_arr.reshape(batch_seq, original_shape[-1]) + else: + batch_seq = original_shape[0] + x_2d = x_arr + + out = np.empty((batch_seq, self.out_features), dtype=np.float32) + for o in range(self.out_features): + dist = np.sum(np.abs(w[o] - x_2d), axis=1) + out[:, o] = -dist + if b is not None: + out[:, o] += b[o] + + if len(original_shape) > 2: + return out.reshape(original_shape[:-1] + (self.out_features,)) + return out + + def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: + """Backward pass — subgradient of L1 distance. 
+ + grad_input[s, k] = -sum_o( grad_out[s, o] * sign(W[o, k] - x[s, k]) ) + grad_W[o, k] = -sum_s( grad_out[s, o] * sign(W[o, k] - x[s, k]) ) + grad_b[o] = sum_s( grad_out[s, o] ) + """ + if x is None: + x = self._last_input + if x is None: + raise RuntimeError("backward requires input; call forward first or pass x=") + + x_arr = np.asarray(x, dtype=np.float32) + w = _get_param_array(self.weight) + go = np.asarray(grad_output, dtype=np.float32) + + if x_arr.ndim > 2: + S = int(np.prod(x_arr.shape[:-1])) + x_2d = x_arr.reshape(S, -1) + go_2d = go.reshape(S, -1) + else: + x_2d = x_arr if x_arr.ndim == 2 else x_arr.reshape(1, -1) + go_2d = go if go.ndim == 2 else go.reshape(1, -1) + + S, d_in = x_2d.shape + d_out = self.out_features + + grad_input = np.zeros_like(x_2d) + grad_w = np.zeros_like(w) + + for o in range(d_out): + sgn = np.sign(w[o][np.newaxis, :] - x_2d) # (S, d_in) + g_col = go_2d[:, o:o+1] # (S, 1) + grad_input -= g_col * sgn + grad_w[o] -= np.sum(g_col * sgn, axis=0) + + if self.weight is not None: + if not hasattr(self.weight, "grad") or self.weight.grad is None: + self.weight.grad = grad_w + else: + self.weight.grad += grad_w + + if self.bias_param is not None: + grad_b = np.sum(go_2d, axis=0) + if not hasattr(self.bias_param, "grad") or self.bias_param.grad is None: + self.bias_param.grad = grad_b + else: + self.bias_param.grad += grad_b + + if x_arr.ndim > 2: + return grad_input.reshape(x_arr.shape) + return grad_input + + def __repr__(self): + return ( + f"AdditionLinear(in_features={self.in_features}, " + f"out_features={self.out_features}, " + f"bias={self.bias_param is not None})" + ) diff --git a/nn/attention.py b/nn/attention.py index 07d5a18..9646897 100644 --- a/nn/attention.py +++ b/nn/attention.py @@ -5,7 +5,12 @@ import numpy as np -from ._helpers import _get_param_array +from ._helpers import ( + _USE_CPP_BRIDGE, + _bridge, + _bridge_to_numpy, + _get_param_array, +) from .module import Module @@ -103,51 +108,102 @@ def forward( k_reshaped = k_4d.transpose(0, 2, 1, 3) # (batch, num_heads, seq_len_k, head_dim) v_reshaped = v_4d.transpose(0, 2, 1, 3) # (batch, num_heads, seq_len_k, head_dim) - # Compute attention scores + # Fused C++ path: one submit for scores + softmax + output (no mask; Sq == Sk only). + if _USE_CPP_BRIDGE and seq_len_q == seq_len_k and mask is None: + fused = _bridge.attention_scores_softmax_output( + q_reshaped, k_reshaped, v_reshaped + ) + if fused is not None: + br_out, br_w = fused + attn_output = _bridge_to_numpy(br_out) + scores_softmax = _bridge_to_numpy(br_w) + if attn_output is not None and scores_softmax is not None: + attn_output = attn_output.astype(np.float32) + scores_softmax = scores_softmax.astype(np.float32) + self._cached_scores_pre_softmax = None + self._cached_scores = scores_softmax.copy() + self._cached_attn_output = attn_output.copy() + attn_output_reshaped = attn_output.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len_q, self.embed_dim + ) + attn_weights = scores_softmax + output = self.out_proj(attn_output_reshaped) + return output, attn_weights + + # C++ bridge: scores → optional padding mask (CPU) → softmax → attention output (GPU). + # Kernel uses one sequence length for Q and K; skip when cross-attention has Sq != Sk. 
+ if _USE_CPP_BRIDGE and seq_len_q == seq_len_k: + br_scores = _bridge.attention_scores(q_reshaped, k_reshaped) + if br_scores is not None: + scores = _bridge_to_numpy(br_scores) + if scores.shape != (batch_size, self.num_heads, seq_len_q, seq_len_k): + if scores.size == batch_size * self.num_heads * seq_len_q * seq_len_k: + scores = scores.reshape( + batch_size, self.num_heads, seq_len_q, seq_len_k + ) + else: + scores = None + if scores is not None: + self._cached_scores_pre_softmax = scores.copy() + if mask is not None: + if mask.ndim == 2: + mask_expanded = mask[:, None, :, None] + mask_expanded = np.broadcast_to(mask_expanded, scores.shape) + scores = np.where(mask_expanded > 0, scores, -1e9) + else: + scores = scores + mask.astype(np.float32) + br_soft = _bridge.softmax(scores, dim=-1) + if br_soft is not None: + scores_softmax = _bridge_to_numpy(br_soft).astype(np.float32) + self._cached_scores = scores_softmax.copy() + br_out = _bridge.attention_output(scores_softmax, v_reshaped) + if br_out is not None: + attn_output = _bridge_to_numpy(br_out).astype(np.float32) + self._cached_attn_output = attn_output.copy() + attn_output_reshaped = attn_output.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len_q, self.embed_dim + ) + attn_weights = scores_softmax + output = self.out_proj(attn_output_reshaped) + return output, attn_weights + + # Fallback: legacy Python Vulkan backend + CPU softmax / einsum scores = backend.attention.attention_scores( q_reshaped, k_reshaped, num_heads=self.num_heads, head_dim=self.head_dim ) # Backend may return scores in different shape - normalize to (batch, num_heads, seq_len_q, seq_len_k) if scores.shape == (batch_size, seq_len_q, self.num_heads, seq_len_k): - # Backend returned (batch, seq_len_q, num_heads, seq_len_k) - transpose to (batch, num_heads, seq_len_q, seq_len_k) scores = scores.transpose(0, 2, 1, 3) elif scores.shape != (batch_size, self.num_heads, seq_len_q, seq_len_k): - # Unexpected shape - try to infer if ( scores.ndim == 4 and scores.size == batch_size * self.num_heads * seq_len_q * seq_len_k ): scores = scores.reshape(batch_size, self.num_heads, seq_len_q, seq_len_k) else: - # Fallback: compute manually scores = np.einsum("bhqd,bhkd->bhqk", q_reshaped, k_reshaped) / np.sqrt( - self.head_dim + float(self.head_dim) ) - # Cache pre-softmax scores for backward self._cached_scores_pre_softmax = scores.copy() - # Apply mask if provided if mask is not None: - scores = backend.attention.attention_mask(scores, mask) + if mask.ndim == 2: + mask_expanded = mask[:, None, :, None] + mask_expanded = np.broadcast_to(mask_expanded, scores.shape) + scores = np.where(mask_expanded > 0, scores, -1e9) + else: + scores = scores + mask.astype(np.float32) - # Apply softmax (CPU for now - backend softmax expects 3D) - # scores is (batch, num_heads, seq_len_q, seq_len_k) scores_max = scores.max(axis=-1, keepdims=True) scores_exp = np.exp(scores - scores_max) scores_softmax = scores_exp / scores_exp.sum(axis=-1, keepdims=True) - # Cache softmax scores for backward self._cached_scores = scores_softmax.copy() - # Compute attention output - # scores_softmax: (batch, num_heads, seq_len_q, seq_len_k) - # v_reshaped: (batch, num_heads, seq_len_k, head_dim) - # Output: (batch, num_heads, seq_len_q, head_dim) attn_output = np.einsum("bhqk,bhkd->bhqd", scores_softmax, v_reshaped) - # Cache attention output for backward (in shape: batch, num_heads, seq_len_q, head_dim) self._cached_attn_output = attn_output.copy() # Reshape back: (batch, num_heads, seq_len_q, head_dim) -> 
(batch, seq_len_q, embed_dim) diff --git a/nn/autograd.py b/nn/autograd.py index ec2b1bf..e440905 100644 --- a/nn/autograd.py +++ b/nn/autograd.py @@ -181,6 +181,16 @@ def numpy(self) -> np.ndarray: """Get numpy array (detached from graph).""" return self.data.copy() + def __array__(self, dtype=None) -> np.ndarray: + """NumPy array protocol — allows ``np.asarray(var)``, ``np.matmul(var, w)``, + ``np.dot(var, w)``, etc. to operate on the underlying ``self.data`` ndarray + without an explicit ``.data`` access. Required for grilly's existing + numpy-native layer ``forward`` code to accept ``Tensor`` inputs from the + ``torch_api`` facade transparently.""" + if dtype is not None: + return self.data.astype(dtype, copy=False) + return self.data + def detach(self) -> "Variable": """Return a new Variable detached from the computation graph.""" return Variable(self.data.copy(), requires_grad=False) @@ -199,6 +209,10 @@ def zero_grad(self): """Clear the gradient.""" self.grad = None + def numel(self) -> int: + """Number of elements (PyTorch ``Tensor.numel``).""" + return int(self.data.size) + def backward( self, grad_output: np.ndarray | None = None, @@ -955,11 +969,23 @@ def backward(grad): def matmul(a, b) -> Variable: - """Matrix multiplication: a @ b (with GPU backward support)""" + """Matrix multiplication: a @ b (with GPU forward + backward)""" a = _ensure_variable(a) b = _ensure_variable(b) - result_data = np.matmul(a.data, b.data) + # GPU forward: use _bridge.linear for 2D matrix multiply + result_data = None + if a.data.ndim == 2 and b.data.ndim == 2 and _grad_enabled: + try: + from grilly.backend import _bridge as _ag_bridge + # _bridge.linear(x, w) = x @ w.T, so pass b.T to get a @ b + gpu_result = _ag_bridge.linear(a.data, b.data.T) + if gpu_result is not None: + result_data = np.asarray(gpu_result) if not isinstance(gpu_result, np.ndarray) else gpu_result + except Exception: + pass + if result_data is None: + result_data = np.matmul(a.data, b.data) a_data, b_data = a.data, b.data @@ -1276,6 +1302,20 @@ def backward(grad): return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) +def sign(a) -> Variable: + """Elementwise sign; gradient is zero (subgradient).""" + a = _ensure_variable(a) + result_data = np.sign(a.data) + + def backward(grad): + """Run backward.""" + + return (np.zeros_like(a.data),) + + grad_fn = _make_backward("Sign", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + def clamp(a, min_val=None, max_val=None) -> Variable: """Clamp values to [min_val, max_val]""" a = _ensure_variable(a) @@ -1748,6 +1788,76 @@ def backward(grad): return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) +def mf_softmax(a, dim: int = -1, eps: float = 1e-12) -> Variable: + """ReLU-normalized softmax (no exp). 
See :mod:`grilly.functional.mf_activations`.""" + a = _ensure_variable(a) + axis = dim if dim >= 0 else a.data.ndim + dim + m = np.max(a.data, axis=axis, keepdims=True) + z = np.maximum(a.data - m, 0.0).astype(np.float32) + s = np.sum(z, axis=axis, keepdims=True, dtype=np.float64) + denom = np.maximum(s, eps) + y = (z / denom).astype(np.float32) + tot = np.sum(y, axis=axis, keepdims=True) + nfeat = float(a.data.shape[axis]) + unif = np.ones_like(a.data, dtype=np.float32) / nfeat + result_data = np.where(tot > 1e-8, y, unif).astype(np.float32) + + def backward(grad): + """Subgradient with STE on the max (mask active where z > 0).""" + m2 = np.max(a.data, axis=axis, keepdims=True) + z2 = np.maximum(a.data - m2, 0.0).astype(np.float32) + s2 = np.sum(z2, axis=axis, keepdims=True, dtype=np.float64) + s2 = np.maximum(s2, eps) + y2 = (z2 / s2).astype(np.float32) + tot2 = np.sum(y2, axis=axis, keepdims=True) + y2 = np.where(tot2 > 1e-8, y2, np.ones_like(a.data, dtype=np.float32) / nfeat) + gz = (grad - np.sum(grad * y2, axis=axis, keepdims=True)) / s2.astype(np.float32) + mask = (z2 > 0).astype(np.float32) + return (gz * mask,) + + grad_fn = _make_backward("MfSoftmax", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + +def mf_softplus(a, beta: float = 1.0) -> Variable: + """Algebraic softplus (sqrt form, no exp). See :mod:`grilly.functional.mf_activations`.""" + a = _ensure_variable(a) + b = float(beta) + if b <= 0: + raise ValueError("beta must be positive") + x = a.data + c = 4.0 / (b * b) + s = np.sqrt(x * x + c) + result_data = (0.5 * (x + s)).astype(np.float32) + + def backward(grad): + """Run backward.""" + + ds_dx = x / (s + 1e-12) + d = 0.5 * (1.0 + ds_dx) + return (grad * d,) + + grad_fn = _make_backward("MfSoftplus", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + +def mf_sigmoid(a) -> Variable: + """Rational sigmoid x / (1 + |x|).""" + a = _ensure_variable(a) + x = a.data + ax = np.abs(x) + 1.0 + result_data = (x / ax).astype(np.float32) + + def backward(grad): + """d/dx x/(1+|x|) = 1/(1+|x|)^2 on x!=0.""" + + denom = (1.0 + np.abs(x)) ** 2 + return (grad / denom,) + + grad_fn = _make_backward("MfSigmoid", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + # ============================================================================ # Trigonometric Functions # ============================================================================ diff --git a/nn/containers.py b/nn/containers.py index a4b3599..ffd9fc1 100644 --- a/nn/containers.py +++ b/nn/containers.py @@ -21,6 +21,8 @@ class Sequential(Module): Caches intermediate activations during forward pass for efficient backward pass. Automatically fuses Linear+Activation pairs when fused GPU shaders are available. + When a fused shader is missing, `fused_linear_relu` may use a two-kernel **single-submit** + path (`VulkanFNN._linear_relu_recorded_chain`) before falling back to separate calls. """ def __init__(self, *modules): diff --git a/nn/embedding.py b/nn/embedding.py index 621d1df..1f9d4a5 100644 --- a/nn/embedding.py +++ b/nn/embedding.py @@ -34,34 +34,65 @@ def __init__(self, num_embeddings: int, embedding_dim: int): # Register parameter self.register_parameter("weight", self.weight) - def forward(self, x: np.ndarray) -> np.ndarray: - """Forward pass using embedding-lookup.glsl""" - backend = self._get_backend() + def forward(self, x): + """Forward pass using embedding-lookup.glsl. 
+ + Autograd: when called through ``Module.__call__`` with a LongTensor + input (standard pattern for token ids), the output is wrapped in a + ``Variable`` whose ``GradFn`` calls ``self.backward(grad_output, ids)`` + on loss.backward(). The GradFn has an empty inputs list because token + ids are discrete and don't receive gradients — we only use the + closure to populate ``self.weight.grad`` via the existing + ``self.backward``. + + Mirrors ``nn.Linear.forward``'s autograd wiring. + """ + try: + from grilly.nn.autograd import GradFn as _GradFn + from grilly.nn.autograd import Variable as _Variable + from grilly.nn.autograd import _grad_enabled + except ImportError: + _Variable = None # type: ignore[assignment] + _GradFn = None # type: ignore[assignment] + _grad_enabled = False + weight = _get_param_array(self.weight) - gpu_lookup_enabled = os.getenv("GRILLY_EMBEDDING_GPU_LOOKUP", "1").strip().lower() not in { - "0", - "false", - "no", - } + # Extract the raw token-id ndarray (caller may pass a LongTensor + # subclass or plain ndarray). Keep the original around so the + # backward closure has access to the indices. + ids_data = np.asarray(x) + + # Numpy fancy-index lookup. The legacy ``backend.learning.embedding_lookup`` + # expects a (batch, seq) shape and adds a leading batch dim for 1D + # inputs, which breaks downstream ops that don't expect the extra + # axis. Fancy indexing preserves the exact input shape with a + # trailing embedding dim, matches what PyTorch's Embedding does, + # and is plenty fast for the sizes we care about (LUT bandwidth + # bound — even at 8192 vocab × 384 dim it's under 1 ms). + result = weight[ids_data.astype(np.int32)] + + # ---- Autograd wiring ---- + # Always wrap the output in a Variable with a GradFn when autograd is + # enabled — Embedding's input is discrete (token ids), so there's no + # upstream Variable to chain to, but we still need the GradFn so that + # ``loss.backward()`` can flow back here and update ``weight.grad``. if ( - gpu_lookup_enabled - and hasattr(backend, "learning") - and hasattr(backend.learning, "embedding_lookup") + _GradFn is not None + and _grad_enabled + and isinstance(result, np.ndarray) + and not isinstance(result, _Variable) ): - try: - return backend.learning.embedding_lookup( - x, - weight, - return_gpu_tensor=self._return_gpu_tensor, - ) - except Exception: - pass # Fall back to CPU - - # CPU fallback - if isinstance(x, np.ndarray): - return weight[x.astype(np.int32)] - return weight[x] + def backward_fn(grad_output): + # self.backward populates self.weight.grad in-place. + # Returns None for the input gradient (token ids are discrete). 
+ self.backward(np.asarray(grad_output), ids_data) + return () # no upstream inputs + + grad_fn = _GradFn("Embedding", backward_fn, []) + return _Variable(np.asarray(result), requires_grad=True, grad_fn=grad_fn) + + return result def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ diff --git a/nn/functional.py b/nn/functional.py new file mode 100644 index 0000000..00189f9 --- /dev/null +++ b/nn/functional.py @@ -0,0 +1,22 @@ +"""``torch.nn.functional``-compatible API (cross-entropy, activations).""" + +from grilly.functional.mf_activations import ( + mf_relu, + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) +from grilly.nn.autograd import gelu +from grilly.torch_api.functional import cross_entropy, softplus + +__all__ = [ + "cross_entropy", + "softplus", + "gelu", + "mf_softmax", + "mf_softplus", + "mf_sigmoid", + "mf_sigmoid_01", + "mf_relu", +] diff --git a/nn/gpu_backward.py b/nn/gpu_backward.py index 597319a..c52ec1b 100644 --- a/nn/gpu_backward.py +++ b/nn/gpu_backward.py @@ -44,10 +44,13 @@ def __init__(self, use_gpu: bool = True): if use_gpu: try: - from grilly import Compute - - self.backend = Compute() - logger.info("GPU backward operations initialized successfully") + from grilly.backend import _bridge + if _bridge.is_available(): + self.backend = _bridge + logger.info("GPU backward operations initialized via _bridge") + else: + logger.warning("_bridge not available. Falling back to CPU.") + self.use_gpu = False except Exception as e: logger.warning(f"Failed to initialize GPU backend: {e}. Falling back to CPU.") self.use_gpu = False @@ -77,15 +80,22 @@ def is_available(self) -> bool: return self.backend is not None and self.use_gpu def _has_shader(self, shader_name: str) -> bool: - """Check if a shader is available.""" + """Check if a backward op is available on _bridge.""" if not self.is_available(): return False - try: - return ( - hasattr(self.backend.core, "shaders") and shader_name in self.backend.core.shaders - ) - except Exception: - return False + # Map shader names to _bridge function names + bridge_map = { + "fnn-linear-backward": "linear_backward", + "activation-relu-backward": "relu_backward", + "activation-gelu-backward": "gelu_backward", + "activation-silu-backward": "silu_backward", + "activation-softmax-backward": "softmax_backward", + "activation-tanh-backward": "tanh_backward", + "cross-entropy-backward": "cross_entropy_backward", + "fnn-layernorm-backward": "layernorm_backward", + } + fn_name = bridge_map.get(shader_name, shader_name.replace("-", "_")) + return hasattr(self.backend, fn_name) # ======================================================================== # Linear / Fully Connected Layer @@ -131,10 +141,11 @@ def linear_backward( ) try: - # Use Grilly's existing linear_backward method! 
- grad_input, grad_weight, grad_bias = self.backend.fnn.linear_backward( - grad_output, input_data, weights, bias=None - ) + # Use _bridge.linear_backward directly + result = self.backend.linear_backward(grad_output, input_data, weights) + grad_input = np.asarray(result['grad_input']) if 'grad_input' in result else None + grad_weight = np.asarray(result['grad_weight']) if 'grad_weight' in result else None + grad_bias = np.asarray(result['grad_bias']) if 'grad_bias' in result else None # Filter outputs based on what was requested if not compute_input_grad: @@ -207,11 +218,8 @@ def relu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.n return grad_output * (input_data > 0).astype(np.float32) try: - return self.backend.core.dispatch_shader( - "activation-relu-backward", - inputs={"grad_output": grad_output, "input_data": input_data}, - output_shape=grad_output.shape, - ) + result = self.backend.relu_backward(grad_output, input_data) + return np.asarray(result) if result is not None else grad_output * (input_data > 0).astype(np.float32) except Exception as e: logger.warning(f"GPU ReLU backward failed: {e}. Falling back to CPU.") return grad_output * (input_data > 0).astype(np.float32) @@ -241,22 +249,19 @@ def gelu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.n return grad_output * (cdf_approx + x * dcdf) try: - return self.backend.core.dispatch_shader( - "activation-gelu-backward", - inputs={"grad_output": grad_output, "input_data": input_data}, - output_shape=grad_output.shape, - ) + result = self.backend.gelu_backward(grad_output, input_data) + if result is not None: + return np.asarray(result) except Exception as e: logger.warning(f"GPU GELU backward failed: {e}. Falling back to CPU.") - # CPU fallback - x = input_data - sqrt_2_pi = np.sqrt(2.0 / np.pi) - cdf_approx = 0.5 * (1.0 + np.tanh(sqrt_2_pi * (x + 0.044715 * x**3))) - inner = sqrt_2_pi * (x + 0.044715 * x**3) - tanh_inner = np.tanh(inner) - sech2 = 1 - tanh_inner**2 - dcdf = 0.5 * sech2 * sqrt_2_pi * (1 + 3 * 0.044715 * x**2) - return grad_output * (cdf_approx + x * dcdf) + x = input_data + sqrt_2_pi = np.sqrt(2.0 / np.pi) + cdf_approx = 0.5 * (1.0 + np.tanh(sqrt_2_pi * (x + 0.044715 * x**3))) + inner = sqrt_2_pi * (x + 0.044715 * x**3) + tanh_inner = np.tanh(inner) + sech2 = 1 - tanh_inner**2 + dcdf = 0.5 * sech2 * sqrt_2_pi * (1 + 3 * 0.044715 * x**2) + return grad_output * (cdf_approx + x * dcdf) def silu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.ndarray: """ @@ -275,15 +280,13 @@ def silu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.n return grad_output * sigmoid_x * (1 + input_data * (1 - sigmoid_x)) try: - return self.backend.core.dispatch_shader( - "activation-silu-backward", - inputs={"grad_output": grad_output, "input_data": input_data}, - output_shape=grad_output.shape, - ) + result = self.backend.silu_backward(grad_output, input_data) + if result is not None: + return np.asarray(result) except Exception as e: logger.warning(f"GPU SiLU backward failed: {e}. 
Falling back to CPU.") - sigmoid_x = 1.0 / (1.0 + np.exp(-input_data)) - return grad_output * sigmoid_x * (1 + input_data * (1 - sigmoid_x)) + sigmoid_x = 1.0 / (1.0 + np.exp(-input_data)) + return grad_output * sigmoid_x * (1 + input_data * (1 - sigmoid_x)) def swiglu_backward( self, grad_output: np.ndarray, input_data: np.ndarray, gate_data: np.ndarray diff --git a/nn/init.py b/nn/init.py new file mode 100644 index 0000000..9c7dff5 --- /dev/null +++ b/nn/init.py @@ -0,0 +1,21 @@ +"""``torch.nn.init``-style weight initialization (numpy arrays / Parameters).""" + +from __future__ import annotations + +from grilly.utils.initialization import ( + kaiming_normal_, + kaiming_uniform_, + normal_, + uniform_, + xavier_normal_, + xavier_uniform_, +) + +__all__ = [ + "kaiming_normal_", + "kaiming_uniform_", + "normal_", + "uniform_", + "xavier_normal_", + "xavier_uniform_", +] diff --git a/nn/linear.py b/nn/linear.py index b62fb25..5bbb60b 100644 --- a/nn/linear.py +++ b/nn/linear.py @@ -14,6 +14,7 @@ _get_param_array, ) from .module import Module +from ._perf_policy import choose_fastest class Linear(Module): @@ -187,26 +188,105 @@ def zero_grad(self): if self.bias is not None: self.register_parameter("bias", self.bias) - def forward(self, x) -> np.ndarray: - """Forward pass — GPU-first, always dispatches through GPU backend.""" + def forward(self, x): + """Forward pass — GPU-first, always dispatches through GPU backend. + + Autograd: when ``x`` is an autograd ``Variable`` (or carries + ``requires_grad``), the output is wrapped in a ``Variable`` with a + ``GradFn`` that calls ``self.backward(grad_output, x_data)`` during + ``loss.backward()``. ``self.backward`` already populates + ``self.weight.grad`` and ``self.bias.grad`` from the existing + ``fnn-linear-backward.glsl`` kernel, so the AdamW step picks them up + through ``param.grad`` like any other PyTorch-style optimizer. + + Without this wiring, ``loss.backward()`` produced gradients on the + Variable wrapping the cross-entropy logits but the chain stopped at + ``Linear`` (no ``grad_fn`` -> autograd traversal terminated), so + weights silently never updated. + """ + # Detect autograd input (Variable, or Tensor — Tensor is a Variable subclass) + # and remember its underlying ndarray for the backward closure. 
+ try: + from grilly.nn.autograd import GradFn as _GradFn + from grilly.nn.autograd import Variable as _Variable + from grilly.nn.autograd import _grad_enabled + except ImportError: + _Variable = None # type: ignore[assignment] + _GradFn = None # type: ignore[assignment] + _grad_enabled = False + + x_var: "_Variable | None" = None # type: ignore[name-defined] + if _Variable is not None and isinstance(x, _Variable): + x_var = x + x_data = x.data + else: + x_data = x # raw ndarray (or VulkanTensor) + weight = _get_param_array(self.weight) bias = _get_param_array(self.bias) if self.bias is not None else None - # C++ bridge fast path (handles both numpy and VulkanTensor via __array__) + def cpu_linear(): + x_arr = np.asarray(x_data, dtype=np.float32) + w_arr = np.asarray(weight, dtype=np.float32) + out = x_arr @ w_arr.T + if bias is not None: + out = out + np.asarray(bias, dtype=np.float32) + return np.asarray(out, dtype=np.float32) + + # ---- Run the existing forward path (unchanged) ---- + result: np.ndarray | None = None if _USE_CPP_BRIDGE: - result = _bridge_to_numpy(_bridge.linear(x, weight, bias)) - if result is not None: - return result + def gpu_linear(): + return _bridge_to_numpy(_bridge.linear(x_data, weight, bias)) + + # Auto-fastest policy only for numpy-in/numpy-out path. + if isinstance(x_data, np.ndarray) and not self._return_gpu_tensor: + batch = int(np.prod(x_data.shape[:-1])) if x_data.ndim > 1 else 1 + in_features = int(x_data.shape[-1]) if x_data.ndim > 0 else self.in_features + op_key = f"linear:{batch}x{in_features}x{self.out_features}" + result = choose_fastest(op_key, gpu_linear, cpu_linear) + else: + result = gpu_linear() + + if result is None: + # Legacy Python Vulkan backend (fallback) or final CPU path + backend = self._get_backend() + if hasattr(backend, "fnn") and hasattr(backend.fnn, "linear"): + result = backend.fnn.linear( + x_data, + weight, + bias, + return_gpu_tensor=self._return_gpu_tensor, + ) + else: + result = cpu_linear() + + # ---- Autograd wiring ---- + # If the input came from autograd, wrap result so loss.backward() + # can flow back through this layer and populate weight.grad / bias.grad + # via the existing self.backward() implementation. We bypass + # ``_make_backward`` (which short-circuits when no input requires grad) + # because the *weight* always requires grad even if the input doesn't — + # otherwise a frozen-input training loop silently never updates weights. + if ( + x_var is not None + and _GradFn is not None + and _grad_enabled + and isinstance(result, np.ndarray) + and not isinstance(result, _Variable) + ): + x_data_for_backward = x_data # capture for closure - # Legacy Python Vulkan backend (fallback) - backend = self._get_backend() - if hasattr(backend, "fnn") and hasattr(backend.fnn, "linear"): - return backend.fnn.linear( - x, - weight, - bias, - return_gpu_tensor=self._return_gpu_tensor, - ) + def backward_fn(grad_output): + # self.backward populates self.weight.grad and self.bias.grad + # in-place and returns grad_input. 
+ grad_input = self.backward(np.asarray(grad_output), x_data_for_backward) + return (grad_input,) + + grad_fn = _GradFn("Linear", backward_fn, [x_var]) + return _Variable(np.asarray(result), requires_grad=True, grad_fn=grad_fn) + + return result def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ diff --git a/nn/module.py b/nn/module.py index 51e79f5..9a159f0 100644 --- a/nn/module.py +++ b/nn/module.py @@ -65,30 +65,60 @@ def __init__(self): self._use_device_local = False # DEVICE_LOCAL VRAM buffers def _get_backend(self): - """Execute get backend.""" - + """Get the legacy VulkanCompute backend if available, else None. + + The legacy backend (``backend.compute.VulkanCompute``) requires the + PyPI ``vulkan`` Python package (ctypes bindings) and is only used + today by a few slow-path fallbacks (e.g. ``fnn.xavier_init`` at + layer construction time). The real forward/backward path goes + through the C++ ``_bridge`` which only needs ``grilly_core`` and has + no Python ``vulkan`` dependency. + + Returning ``None`` when the legacy backend can't init lets callers + that already guard with ``hasattr(backend, "fnn")`` gracefully fall + through to their CPU reference path (e.g. numpy xavier init), which + is harmless at construction time since the fast path takes over at + forward time via the bridge. + """ if self._backend is None: - from grilly import Compute + try: + from ..utils.device_manager import get_device_manager - self._backend = Compute() + self._backend = get_device_manager().vulkan + except Exception: + self._backend = None return self._backend def _convert_input(self, x: np.ndarray | Any): """ Convert input to Vulkan-compatible format (GPU-first). - VulkanTensor inputs always pass through without downloading to CPU. - numpy/PyTorch inputs are converted to VulkanTensor when the C++ backend - is available, otherwise to float32 numpy arrays. - Preserves integer dtypes for index arrays (e.g., token IDs). + Pass-through priority: + 1. ``torch_api.Tensor`` / ``LongTensor`` — kept as-is so user + ``forward()`` code keeps torch-style methods (``.unsqueeze``, + ``.reshape``, ``.mean(dim=...)``). Works alongside grilly's + existing numpy-native layers because ``Variable.__array__`` + lets ``np.matmul``/``np.dot`` operate on Tensor inputs directly. + 2. ``VulkanTensor`` — GPU-resident, always pass through. + 3. Everything else (numpy, PyTorch, …) — convert to Vulkan-compatible. - Args: - x: Input (PyTorch tensor, numpy array, VulkanTensor, or other) - - Returns: - VulkanTensor or numpy array + Preserves integer dtypes for index arrays (e.g., token IDs). """ - # GPU-first: always pass VulkanTensor through without CPU round-trip + # 1. torch_api Tensor wrappers + autograd Variable: pass through + # unchanged. Variable matters because Tensor inherits from Variable + # and ops like ``.reshape`` / ``.mean(dim=...)`` / arithmetic return + # raw Variable instances rather than the Tensor subclass. + try: + from grilly.nn.autograd import Variable as _Variable + from grilly.torch_api.tensor import LongTensor as _TorchLong + from grilly.torch_api.tensor import Tensor as _TorchTensor + + if isinstance(x, (_TorchTensor, _TorchLong, _Variable)): + return x + except ImportError: + pass + + # 2. 
GPU-first: always pass VulkanTensor through without CPU round-trip if TENSOR_CONVERSION_AVAILABLE: from ..utils.tensor_conversion import VulkanTensor @@ -121,15 +151,56 @@ def forward(self, *args, **kwargs): raise NotImplementedError def __call__(self, *args, **kwargs): - # Automatically convert PyTorch tensor inputs to numpy - """Invoke the callable instance.""" + """Invoke the callable instance. + + If any positional arg was torch-style (``torch_api.Tensor``, + ``LongTensor``, or autograd ``Variable``) and the forward returned a + raw ndarray, wrap the output in ``Tensor`` so chained calls + (``x = embed(ids); x = layer(x); x = next(x)``) preserve torch-style + type through user-defined Module subclasses. + + Three classes count as a "torch input": + - ``LongTensor`` — model entry points use it for token ids; the next + layer (Embedding) returns floats and the chain has to start in + torch-style mode from the first step. + - ``Tensor`` — the public torch_api wrapper users construct directly. + - ``Variable`` — what ``Tensor.reshape``/``.mean(dim=...)``/arithmetic + ops actually return today (they delegate to module-level functions + that build raw Variable, not the Tensor subclass). Without this + branch the type is silently lost after the first reshape and the + chain falls back to numpy. + + ``Parameter`` (ndarray subclass) is left untouched on the way out so + re-wrapping a stored weight doesn't break gradient bookkeeping. + """ + try: + from grilly.nn.autograd import Variable as _Variable + from grilly.torch_api.tensor import LongTensor as _TorchLong + from grilly.torch_api.tensor import Tensor as _TorchTensor + + saw_torch_input = any( + isinstance(a, (_TorchTensor, _TorchLong, _Variable)) for a in args + ) + except ImportError: + _TorchTensor = None # type: ignore[assignment] + _TorchLong = None # type: ignore[assignment] + _Variable = None # type: ignore[assignment] + saw_torch_input = False converted_args = tuple(self._convert_input(arg) for arg in args) - # Convert keyword arguments that might be tensors converted_kwargs = { - k: self._convert_input(v) if self._is_tensor_like(v) else v for k, v in kwargs.items() + k: self._convert_input(v) if self._is_tensor_like(v) else v + for k, v in kwargs.items() } - return self.forward(*converted_args, **converted_kwargs) + out = self.forward(*converted_args, **converted_kwargs) + + # Wrap raw ndarray outputs back to Tensor when a torch input was seen. + # Skip ndarray subclasses (Parameter, LongTensor) and Variable outputs + # — those already carry a torch-compatible API. + if saw_torch_input and _TorchTensor is not None: + if isinstance(out, np.ndarray) and type(out) is np.ndarray: + return _TorchTensor(out, requires_grad=False) + return out def _is_tensor_like(self, obj: Any) -> bool: """Check if object is tensor-like and needs conversion""" @@ -226,6 +297,39 @@ def register_parameter(self, name: str, param: np.ndarray | None): self._parameters[name] = param + def register_buffer(self, name: str, tensor: np.ndarray | Any | None, persistent: bool = True): + """ + Register a non-trainable buffer (e.g. running state not in autograd). 
+ + Args: + name: Buffer attribute name + tensor: Numpy array or None to remove + persistent: If False, omitted from state_dict (PyTorch parity; optional) + """ + if tensor is None: + self._buffers.pop(name, None) + if hasattr(self, name): + delattr(self, name) + return + try: + from grilly.torch_api.tensor import Tensor as _TorchTensor + except ImportError: + _TorchTensor = None # type: ignore[assignment] + if _TorchTensor is not None and isinstance(tensor, _TorchTensor): + self._buffers[name] = tensor + setattr(self, name, tensor) + elif hasattr(tensor, "data") and not isinstance(tensor, np.ndarray): + arr = np.asarray(tensor.data) + self._buffers[name] = arr + setattr(self, name, arr) + else: + arr = np.asarray(tensor) + self._buffers[name] = arr + setattr(self, name, arr) + if not persistent: + # Mark non-persistent with a private set on the buffer entry + setattr(self, f"_np_buffer_{name}", True) + def state_dict(self) -> dict[str, Any]: """Execute state dict.""" @@ -243,6 +347,19 @@ def state_dict(self) -> dict[str, Any]: state[name] = param.data.copy() # ParamWrapper else: state[name] = param + try: + from grilly.torch_api.tensor import Tensor as _TorchTensor + except ImportError: + _TorchTensor = None # type: ignore[assignment] + for name, buf in self._buffers.items(): + if getattr(self, f"_np_buffer_{name}", False): + continue + if _TorchTensor is not None and isinstance(buf, _TorchTensor): + state[name] = buf.data.copy() + elif isinstance(buf, np.ndarray): + state[name] = buf.copy() + else: + state[name] = np.asarray(buf) for name, module in self._modules.items(): state[name] = module.state_dict() return state @@ -256,10 +373,41 @@ def load_state_dict(self, state_dict: dict[str, Any]): self._parameters[name] = state_dict[name].copy() else: self._parameters[name] = state_dict[name] + try: + from grilly.torch_api.tensor import Tensor as _TorchTensor + except ImportError: + _TorchTensor = None # type: ignore[assignment] + for name in list(self._buffers.keys()): + if name in state_dict: + prev = self._buffers.get(name) + raw = state_dict[name] + if _TorchTensor is not None and isinstance(prev, _TorchTensor): + self._buffers[name] = _TorchTensor( + np.asarray(raw, dtype=np.float32).copy(), requires_grad=False + ) + else: + val = np.asarray(raw, dtype=np.float32) + if val.dtype != np.float32 and np.issubdtype(val.dtype, np.floating): + val = val.astype(np.float32) + self._buffers[name] = val + setattr(self, name, self._buffers[name]) for name, module in self._modules.items(): if name in state_dict: module.load_state_dict(state_dict[name]) + def named_buffers(self): + """Yield (name, buffer) for all buffers including child modules.""" + for name, buf in self._buffers.items(): + yield name, buf + for prefix, module in self._modules.items(): + for n, b in module.named_buffers(): + yield f"{prefix}.{n}", b + + def buffers(self): + """Iterator over buffers.""" + for _, b in self.named_buffers(): + yield b + def gpu_mode(self, enable=True, device_local=True): """Enable GPU-resident output (returns VulkanTensor instead of numpy). @@ -315,6 +463,36 @@ def llama_cpp(self): return self.to("llama-cpp") + def __setattr__(self, name: str, value: Any) -> None: + """Auto-register ``Parameter`` / child ``Module`` attributes (PyTorch parity). + + Without this hook, ``self.weight = nn.Parameter(...)`` and + ``self.lin = nn.Linear(...)`` only land in ``self.__dict__``, never in + ``self._parameters`` / ``self._modules``. 
``parameters()`` then walks + empty dicts and the optimizer sees no params to update — silently + breaking every Module subclass that uses the standard PyTorch + attribute-assignment idiom. + """ + if isinstance(value, Parameter): + # Ensure base __init__ has run; if not, fall back to plain assign. + params = self.__dict__.get("_parameters") + if params is not None: + # Drop any prior child module / buffer with the same name first. + self.__dict__.get("_modules", {}).pop(name, None) + self.__dict__.get("_buffers", {}).pop(name, None) + params[name] = value + object.__setattr__(self, name, value) + return + elif isinstance(value, Module): + modules = self.__dict__.get("_modules") + if modules is not None: + self.__dict__.get("_parameters", {}).pop(name, None) + self.__dict__.get("_buffers", {}).pop(name, None) + modules[name] = value + object.__setattr__(self, name, value) + return + object.__setattr__(self, name, value) + def __repr__(self): """Return a debug representation.""" diff --git a/nn/module_list.py b/nn/module_list.py new file mode 100644 index 0000000..0992069 --- /dev/null +++ b/nn/module_list.py @@ -0,0 +1,48 @@ +"""ModuleList — PyTorch-compatible container of submodules.""" + +from __future__ import annotations + +from typing import Any, Iterator + +from .module import Module + + +class ModuleList(Module): + """Holds submodules in a list. Acts like ``nn.ModuleList`` in PyTorch.""" + + def __init__(self, modules: list[Module] | None = None): + super().__init__() + self._list: list[Module] = [] + if modules: + for i, m in enumerate(modules): + self._modules[str(i)] = m + self._list.append(m) + + def __getitem__(self, idx: int | slice) -> Module | list[Module]: + if isinstance(idx, slice): + return self._list[idx] + return self._list[idx] + + def __setitem__(self, idx: int, module: Module) -> None: + self._modules[str(idx)] = module + self._list[idx] = module + + def __len__(self) -> int: + return len(self._list) + + def __iter__(self) -> Iterator[Module]: + return iter(self._list) + + def append(self, module: Module) -> "ModuleList": + idx = len(self._list) + self._modules[str(idx)] = module + self._list.append(module) + return self + + def extend(self, modules: list[Module]) -> "ModuleList": + for m in modules: + self.append(m) + return self + + def forward(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError("ModuleList has no forward; iterate over submodules instead.") diff --git a/nn/normalization_modules.py b/nn/normalization_modules.py index 36b8b69..f75805c 100644 --- a/nn/normalization_modules.py +++ b/nn/normalization_modules.py @@ -37,25 +37,75 @@ def __init__(self, normalized_shape: int, eps: float = 1e-5): self.register_parameter("weight", self.weight) self.register_parameter("bias", self.bias) - def forward(self, x: np.ndarray) -> np.ndarray: - """Forward pass using fnn-layernorm.glsl (GPU-first)""" + def forward(self, x): + """Forward pass using fnn-layernorm.glsl (GPU-first). + + Autograd: when ``x`` is a ``Variable``, the output is wrapped in a + ``Variable`` whose ``GradFn`` calls ``self.backward(grad_output, x)`` + on loss.backward(). ``self.backward`` already populates + ``self.weight.grad`` and ``self.bias.grad`` in-place, so AdamW picks + them up through ``param.grad``. Mirrors ``nn.Linear.forward``'s + autograd wiring — see that file for the long-form rationale. 
+ """ + try: + from grilly.nn.autograd import GradFn as _GradFn + from grilly.nn.autograd import Variable as _Variable + from grilly.nn.autograd import _grad_enabled + except ImportError: + _Variable = None # type: ignore[assignment] + _GradFn = None # type: ignore[assignment] + _grad_enabled = False + + x_var = None + if _Variable is not None and isinstance(x, _Variable): + x_var = x + x_data = x.data + else: + x_data = x + weight = _get_param_array(self.weight) bias = _get_param_array(self.bias) - # C++ bridge fast path (handles both numpy and VulkanTensor via __array__) + # ---- Existing forward path ---- + result = None if _USE_CPP_BRIDGE: - result = _bridge_to_numpy(_bridge.layernorm(x, weight, bias, self.eps)) - if result is not None: - return result - - backend = self._get_backend() - return backend.fnn.layernorm( - x, - weight, - bias, - eps=self.eps, - return_gpu_tensor=self._return_gpu_tensor, - ) + result = _bridge_to_numpy(_bridge.layernorm(x_data, weight, bias, self.eps)) + + if result is None: + backend = self._get_backend() + if backend is not None and hasattr(backend, "fnn"): + result = backend.fnn.layernorm( + x_data, + weight, + bias, + eps=self.eps, + return_gpu_tensor=self._return_gpu_tensor, + ) + else: + # CPU fallback — matches self.backward() math + mean = np.mean(x_data, axis=-1, keepdims=True) + var = np.var(x_data, axis=-1, keepdims=True) + normalized = (x_data - mean) / np.sqrt(var + self.eps) + result = normalized * weight + bias + + # ---- Autograd wiring ---- + if ( + x_var is not None + and _GradFn is not None + and _grad_enabled + and isinstance(result, np.ndarray) + and not isinstance(result, _Variable) + ): + x_data_for_backward = x_data + + def backward_fn(grad_output): + grad_input = self.backward(np.asarray(grad_output), x_data_for_backward) + return (grad_input,) + + grad_fn = _GradFn("LayerNorm", backward_fn, [x_var]) + return _Variable(np.asarray(result), requires_grad=True, grad_fn=grad_fn) + + return result def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ diff --git a/nn/parameter.py b/nn/parameter.py index fb9494e..cf8d42d 100644 --- a/nn/parameter.py +++ b/nn/parameter.py @@ -23,6 +23,8 @@ def __new__(cls, data: np.ndarray, requires_grad: bool = True): data: Numpy array containing parameter values requires_grad: Whether this parameter requires gradients (default: True) """ + if hasattr(data, "data") and not isinstance(data, np.ndarray): + data = getattr(data, "data", data) obj = np.asarray(data, dtype=np.float32).view(cls) obj.requires_grad = requires_grad obj.grad = None @@ -42,6 +44,45 @@ def zero_grad(self): else: self.grad = np.zeros_like(self, dtype=np.float32) + def uniform_(self, a: float = 0.0, b: float = 1.0) -> "Parameter": + """In-place uniform fill (PyTorch ``Tensor.uniform_``).""" + self[:] = np.random.uniform(a, b, self.shape).astype(np.float32) + return self + + def zero_(self) -> "Parameter": + """Fill with zeros (PyTorch ``Tensor.zero_``).""" + self.fill(0.0) + return self + + # ------------------------------------------------------------------ + # PyTorch-style shape methods so user ``forward`` code can do + # ``self.weight.unsqueeze(0)`` / ``.view(...)`` / ``.mean(dim=...)`` + # without having to know that Parameter is a numpy ndarray subclass. 
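    # Illustrative (not part of the patch) — with the helpers below:
    #   w = Parameter(np.zeros((4, 8), dtype=np.float32))
    #   w.unsqueeze(0).shape  -> (1, 4, 8)
    #   w.view(8, 4).shape    -> (8, 4)
    #   w.mean(dim=0).shape   -> (8,)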
+ # ------------------------------------------------------------------ + def unsqueeze(self, dim: int) -> np.ndarray: + """Insert a length-1 axis at *dim* (PyTorch ``Tensor.unsqueeze``).""" + return np.expand_dims(np.asarray(self), dim) + + def view(self, *shape) -> np.ndarray: + """Alias for ``reshape`` matching PyTorch's ``Tensor.view`` signature.""" + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = tuple(shape[0]) + return np.asarray(self).reshape(shape) + + def mean(self, dim=None, keepdims: bool = False, **kwargs): # type: ignore[override] + """``Tensor.mean(dim=...)`` — accepts both ``dim=`` (torch) and + ``axis=`` (numpy) so it composes cleanly with ``np.mean(p, axis=...)``. + """ + axis = kwargs.pop("axis", dim) + if kwargs: + # Forward any remaining numpy kwargs (dtype, out, where, ...). + return np.asarray(self).mean(axis=axis, keepdims=keepdims, **kwargs) + return np.asarray(self).mean(axis=axis, keepdims=keepdims) + + def detach(self) -> "Parameter": + """Return a Parameter detached from any (future) autograd graph.""" + return Parameter(np.array(self, copy=True), requires_grad=False) + def __repr__(self): """Return a debug representation.""" diff --git a/nn/prefix_scan.py b/nn/prefix_scan.py new file mode 100644 index 0000000..cbf6f1f --- /dev/null +++ b/nn/prefix_scan.py @@ -0,0 +1,182 @@ +"""Causal Linear-RNN prefix scan — autograd-wrapped Python frontend. + +Wraps the C++ / Vulkan ``grilly_core.prefix_scan_causal`` and +``prefix_scan_causal_backward`` kernels into grilly's autograd system so +``loss.backward()`` flows gradients through the recurrence. + +Math: + Forward: h_t = a_t * h_{t-1} + x_t (h_0 = 0) + Backward: dx_t = dh_t + a_{t+1} * dx_{t+1} (anti-causal scan) + da_t = dx_t * h_{t-1} + +The shader runs one subgroup per (batch, hidden_dim) pair, one thread per +time step, and uses ``subgroupInclusiveAdd`` for O(log S) parallel depth. + +Constraint: ``seq_len <= 32``. Longer sequences need a hierarchical scan +(chunk the sequence, carry state between chunks) — not implemented yet. + +Example:: + + from grilly.nn.prefix_scan import prefix_scan_causal + h = prefix_scan_causal(x, a) # x, a: (B, S, D) Variables + loss = h.mean() + loss.backward() # grads flow to both x and a +""" + +from __future__ import annotations + +from typing import Tuple + +import numpy as np + +from grilly.nn.autograd import GradFn, Variable, _ensure_variable, _grad_enabled + + +def _get_bridge_device(): + """Return the grilly bridge device with shaders loaded.""" + from grilly.backend import _bridge + + dev = _bridge._get_device() + if dev is None: + raise RuntimeError( + "grilly bridge device not initialized — " + "import grilly (or grilly.torch_api) first so shaders load" + ) + return dev + + +def prefix_scan_causal(x, a) -> Variable: + """Causal Linear-RNN: ``h_t = a_t * h_{t-1} + x_t``. + + Args: + x: Input sequence, shape ``(B, S, D)``. Any ``Variable`` / ``Tensor`` + / ndarray works. + a: Decay gates in ``(0, 1]``, same shape as ``x``. + + Returns: + ``Variable`` of shape ``(B, S, D)`` with the causal hidden states. + If autograd is enabled and either input requires grad, the result + is wired into the graph via a ``GradFn`` that calls the C++ + ``prefix_scan_causal_backward`` kernel on ``loss.backward()``. 
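    NumPy reference of the forward recurrence (illustrative only, handy for
    checking the kernel on small shapes; ``x`` and ``a`` are ``(B, S, D)``
    float32 ndarrays)::

        h = np.zeros_like(x)
        for t in range(x.shape[1]):
            prev = h[:, t - 1] if t > 0 else 0.0
            h[:, t] = a[:, t] * prev + x[:, t]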
+ """ + import grilly_core as gc + + x_var = _ensure_variable(x) + a_var = _ensure_variable(a) + + x_data = np.asarray(x_var.data, dtype=np.float32) + a_data = np.asarray(a_var.data, dtype=np.float32) + + if x_data.shape != a_data.shape: + raise ValueError( + f"prefix_scan_causal: x shape {x_data.shape} != a shape {a_data.shape}" + ) + if x_data.ndim != 3: + raise ValueError( + f"prefix_scan_causal: inputs must be 3D (B, S, D), got {x_data.ndim}D" + ) + + S = x_data.shape[1] + if S > 32: + raise ValueError( + f"prefix_scan_causal: seq_len {S} > 32. The current shader runs " + f"one subgroup per (batch, dim) pair with one thread per time " + f"step — hierarchical multi-subgroup scan is a TODO. Either " + f"chunk the sequence on the Python side or truncate to 32 for " + f"the correctness run." + ) + + dev = _get_bridge_device() + h_data = np.asarray(gc.prefix_scan_causal(dev, x_data, a_data), dtype=np.float32) + + # ── Autograd wiring ── + requires_grad = ( + _grad_enabled + and (x_var.requires_grad or a_var.requires_grad) + ) + if not requires_grad: + return Variable(h_data, requires_grad=False) + + # Capture tensors needed by the backward closure. We save the forward + # x / a / h as immutable ndarrays so later in-place mutations on the + # caller's buffers don't corrupt the backward pass. + saved_x = x_data.copy() + saved_a = a_data.copy() + saved_h = h_data.copy() + + def backward_fn(grad_output): + # grad_output is dh coming back from downstream. Dispatch the + # anti-causal kernel and split the returned dict into (grad_x, grad_a). + grad_h = np.asarray(grad_output, dtype=np.float32) + result = gc.prefix_scan_causal_backward( + dev, grad_h, saved_a, saved_h, saved_x + ) + grad_x = np.asarray(result["grad_x"], dtype=np.float32) + grad_a = np.asarray(result["grad_a"], dtype=np.float32) + return (grad_x, grad_a) + + # GradFn inputs list order MUST match the backward return tuple order. + grad_fn = GradFn("PrefixScanCausal", backward_fn, [x_var, a_var]) + return Variable(h_data, requires_grad=True, grad_fn=grad_fn) + + +class CausalSequenceMixer: + """Subgroup-accelerated causal sequence mixer. + + Replaces the ``h.mean(dim=1)`` sequence pooling in the old LiquidCell + path, which destroyed causality by letting any time step see the future. + This module runs a proper Linear-RNN that strictly respects causal + masking. + + Architecture: + x_t = proj_x(x) # input projection + a_t = sigmoid(proj_a(x)) # decay gate in (0, 1) + h_t = a_t * h_{t-1} + x_t # causal prefix scan (GPU) + + Implemented as a regular grilly ``nn.Module``, not a subclass — kept + minimal and self-contained so it's easy to drop into the v3c script. + Use via ``mixer = CausalSequenceMixer(d); h = mixer(x)`` where ``x`` + is shape ``(B, S, D)``. + + NOTE: constructed with explicit ``grilly.nn`` imports at call time to + avoid a circular import at module load. + """ + + def __init__(self, d: int): + from grilly import nn + + self.d = d + self.proj_x = nn.Linear(d, d, bias=False) + self.proj_a = nn.Linear(d, d, bias=True) + # Initialize the gate bias to +1 so sigmoid(1) ≈ 0.73 — the model + # starts out "remembering" most of the hidden state at t=0, which + # matches the LiquidCell behavior the old code defaulted to. 
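        # Quick check (illustrative): sigmoid(1) = 1 / (1 + e^-1) ≈ 0.7311, and
        # the tanh identity used in __call__ agrees:
        #   0.5 * (1 + tanh(1 / 2)) = 0.5 * (1 + 0.4621...) ≈ 0.7311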
+ try: + b_arr = np.asarray(self.proj_a.bias.data if hasattr(self.proj_a.bias, "data") + else self.proj_a.bias) + b_arr[:] = 1.0 + except Exception: + pass + + def parameters(self): + yield from self.proj_x.parameters() + yield from self.proj_a.parameters() + + def __call__(self, x): + # x: (B, S, D) — a Variable / Tensor from upstream. + x_t = self.proj_x(x) + + # Sigmoid without importing torch_api: use the identity + # sigmoid(z) = 0.5 * (1 + tanh(z/2)) + # which is numerically stable and uses only tanh, which grilly's + # autograd exposes via Variable.tanh(). + a_logits = self.proj_a(x) + if hasattr(a_logits, "tanh"): + a_t = 0.5 * (1.0 + (a_logits * 0.5).tanh()) + else: + # Fallback for plain ndarray — upstream should be a Variable. + import numpy as _np + a_t = 0.5 * (1.0 + _np.tanh(_np.asarray(a_logits) * 0.5)) + + h = prefix_scan_causal(x_t, a_t) + return h diff --git a/nn/utils.py b/nn/utils.py new file mode 100644 index 0000000..2df2888 --- /dev/null +++ b/nn/utils.py @@ -0,0 +1,61 @@ +"""``torch.nn.utils`` — gradient clipping and helpers.""" + +from __future__ import annotations + +import math +from collections.abc import Iterable +from typing import Any + +import numpy as np + + +def clip_grad_norm_( + parameters: Iterable[Any], + max_norm: float, + norm_type: float = 2.0, +) -> float: + """ + Clip gradient norm of an iterable of parameters (in-place). + + Mirrors PyTorch: per-tensor norm, then global norm of those norms. + """ + params = [p for p in parameters if p is not None] + grads: list[np.ndarray] = [] + for p in params: + g = getattr(p, "grad", None) + if g is not None: + grads.append(np.asarray(g, dtype=np.float64)) + + if not grads: + return 0.0 + + if norm_type in (math.inf, float("inf")): + total_norm = max(float(np.max(np.abs(g))) for g in grads) + else: + norms = [] + for g in grads: + if norm_type in (2.0, 2): + norms.append(math.sqrt(float(np.sum(g * g)))) + else: + norms.append(float(np.sum(np.abs(g) ** norm_type) ** (1.0 / norm_type))) + stacked = np.array(norms, dtype=np.float64) + if norm_type in (2.0, 2): + total_norm = float(np.sqrt(np.sum(stacked * stacked))) + else: + total_norm = float(np.sum(np.abs(stacked) ** norm_type) ** (1.0 / norm_type)) + + clip_coef = max_norm / (total_norm + 1e-6) if max_norm > 0 else 1.0 + if clip_coef < 1.0: + for p in params: + g = getattr(p, "grad", None) + if g is None: + continue + if isinstance(g, np.ndarray): + g *= clip_coef + else: + try: + p.grad = np.asarray(g, dtype=np.float32) * clip_coef + except Exception: + pass + + return float(total_norm) diff --git a/nn/vsa_lm.py b/nn/vsa_lm.py new file mode 100644 index 0000000..621548b --- /dev/null +++ b/nn/vsa_lm.py @@ -0,0 +1,386 @@ +""" +VsaLmModel — VSA Language Model with AdditionLinear FFN layers. + +Wraps the fused C++ forward/backward (``grilly_core.vsa_lm_*``) for +GPU-accelerated training, with transparent NumPy CPU fallback. 
+ +Usage:: + + model = VsaLmModel(vocab=1000, d_model=256, d_ffn=512, max_seq=512, n_layers=6) + logits = model(input_ids) # (seq, vocab) + grads = model.backward(input_ids, grad_logits) # dict of gradients + model.step(grads, lr=1e-3) # SGD update +""" + +from __future__ import annotations + +import numpy as np + +from ._helpers import _get_param_array, _create_param_wrapper +from .module import Module +from .addition_linear import AdditionLinear + + +def _try_bridge(): + try: + from ..backend import _bridge + if _bridge.is_available(): + return _bridge + except Exception: + pass + return None + + +class VsaLmModel(Module): + """Full VSA-LM: Embedding → [LayerNorm → AdditionLinear FFN → Sign → AdditionLinear → Residual] × L → Output Proj. + + When the C++ fused kernel is available, ``forward``/``backward`` dispatch to + a single ``grilly_core`` call (GPU forward, Eigen backward). Otherwise + falls back to pure NumPy. + """ + + def __init__( + self, + vocab: int, + d_model: int, + d_ffn: int, + max_seq: int, + n_layers: int, + init_std: float = 0.02, + ): + super().__init__() + self.vocab = vocab + self.d_model = d_model + self.d_ffn = d_ffn + self.max_seq = max_seq + self.n_layers = n_layers + + rng = np.random.default_rng() + + self.embed_w = _create_param_wrapper( + rng.standard_normal((vocab, d_model)).astype(np.float32) * init_std + ) + self.pos_w = _create_param_wrapper( + rng.standard_normal((max_seq, d_model)).astype(np.float32) * init_std + ) + self.out_w = _create_param_wrapper( + rng.standard_normal((vocab, d_model)).astype(np.float32) * init_std + ) + + self.register_parameter("embed_w", self.embed_w) + self.register_parameter("pos_w", self.pos_w) + self.register_parameter("out_w", self.out_w) + + self.ffn_up: list[AdditionLinear] = [] + self.ffn_down: list[AdditionLinear] = [] + self.ln_gammas: list = [] + self.ln_betas: list = [] + + for l in range(n_layers): + up = AdditionLinear(d_model, d_ffn, bias=True) + down = AdditionLinear(d_ffn, d_model, bias=True) + self.ffn_up.append(up) + self.ffn_down.append(down) + self._modules[f"ffn_up_{l}"] = up + self._modules[f"ffn_down_{l}"] = down + + gamma = _create_param_wrapper(np.ones(d_model, dtype=np.float32)) + beta = _create_param_wrapper(np.zeros(d_model, dtype=np.float32)) + self.ln_gammas.append(gamma) + self.ln_betas.append(beta) + self.register_parameter(f"ln_gamma_{l}", gamma) + self.register_parameter(f"ln_beta_{l}", beta) + + self._handle: int | None = None + self._bridge = _try_bridge() + + # ── GPU handle lifecycle ──────────────────────────────────────────── + + def _upload(self): + """Upload all weights to GPU via bridge (idempotent).""" + if self._handle is not None: + return + br = self._bridge + if br is None: + return + + h = br.vsa_lm_upload( + _get_param_array(self.embed_w), + _get_param_array(self.pos_w), + [_get_param_array(u.weight) for u in self.ffn_up], + [_get_param_array(u.bias_param) for u in self.ffn_up], + [_get_param_array(d.weight) for d in self.ffn_down], + [_get_param_array(d.bias_param) for d in self.ffn_down], + [_get_param_array(g) for g in self.ln_gammas], + [_get_param_array(b) for b in self.ln_betas], + _get_param_array(self.out_w), + self.n_layers, self.d_model, self.d_ffn, + ) + self._handle = h + + def _sync_weights(self): + """Re-upload updated weights to existing GPU handle.""" + br = self._bridge + if br is None or self._handle is None: + return + br.vsa_lm_update_weights( + self._handle, + _get_param_array(self.embed_w), + _get_param_array(self.pos_w), + [_get_param_array(u.weight) for 
u in self.ffn_up], + [_get_param_array(u.bias_param) for u in self.ffn_up], + [_get_param_array(d.weight) for d in self.ffn_down], + [_get_param_array(d.bias_param) for d in self.ffn_down], + [_get_param_array(g) for g in self.ln_gammas], + [_get_param_array(b) for b in self.ln_betas], + _get_param_array(self.out_w), + ) + + def release(self): + """Free GPU resources.""" + if self._handle is not None and self._bridge is not None: + self._bridge.vsa_lm_release(self._handle) + self._handle = None + + def __del__(self): + try: + self.release() + except Exception: + pass + + # ── Forward ────────────────────────────────────────────────────────── + + def forward(self, input_ids: np.ndarray) -> np.ndarray: + """Run forward pass; returns logits (seq_len, vocab). + + Tries the fused GPU path first, falls back to NumPy. + """ + ids = np.ascontiguousarray(input_ids, dtype=np.int32) + + # GPU path + self._upload() + if self._handle is not None and self._bridge is not None: + result = self._bridge.vsa_lm_forward(self._handle, ids) + if result is not None: + return result + + # NumPy fallback + return self._forward_numpy(ids) + + def _forward_numpy(self, ids: np.ndarray) -> np.ndarray: + S = ids.shape[0] + d = self.d_model + embed = _get_param_array(self.embed_w) + pos = _get_param_array(self.pos_w) + out_w = _get_param_array(self.out_w) + + x = embed[ids.astype(np.int64)] + pos[:S] + + for l in range(self.n_layers): + gamma = _get_param_array(self.ln_gammas[l]) + beta = _get_param_array(self.ln_betas[l]) + + # LayerNorm + mean = x.mean(axis=-1, keepdims=True) + var = x.var(axis=-1, keepdims=True) + h = ((x - mean) / np.sqrt(var + 1e-5) * gamma + beta).astype(np.float32) + + # AdditionLinear up → sign → AdditionLinear down + h_up = self.ffn_up[l].forward(h) + h_sign = np.where(h_up > 0, 1.0, -1.0).astype(np.float32) + h_ffn = self.ffn_down[l].forward(h_sign) + + x = x + h_ffn + + scale = 1.0 / np.sqrt(np.float32(d)) + return (x @ out_w.T * scale).astype(np.float32) + + # ── Backward ───────────────────────────────────────────────────────── + + def backward(self, input_ids: np.ndarray, grad_logits: np.ndarray) -> dict: + """Run backward pass; returns gradient dict. + + GPU path returns one dict from C++. Fallback stores grads on parameters. 
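        Keys (as consumed by ``_scatter_grads``): ``grad_embed``, ``grad_pos``,
        ``grad_out_w``, plus per-layer lists ``grad_ffn_up_w`` / ``grad_ffn_up_b``,
        ``grad_ffn_down_w`` / ``grad_ffn_down_b``, ``grad_ln_gamma`` / ``grad_ln_beta``.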
+ """ + ids = np.ascontiguousarray(input_ids, dtype=np.int32) + gl = np.ascontiguousarray(grad_logits, dtype=np.float32) + + if self._handle is not None and self._bridge is not None: + result = self._bridge.vsa_lm_backward(self._handle, ids, gl) + if result is not None: + self._scatter_grads(result) + return result + + return self._backward_numpy(ids, gl) + + def _scatter_grads(self, grads: dict): + """Store C++ gradient arrays onto parameter .grad attributes.""" + def _acc(param, g): + if not hasattr(param, "grad") or param.grad is None: + param.grad = g.copy() + else: + param.grad += g + + _acc(self.embed_w, grads["grad_embed"]) + _acc(self.pos_w, grads["grad_pos"]) + _acc(self.out_w, grads["grad_out_w"]) + + for l in range(self.n_layers): + _acc(self.ffn_up[l].weight, grads["grad_ffn_up_w"][l]) + if self.ffn_up[l].bias_param is not None: + _acc(self.ffn_up[l].bias_param, grads["grad_ffn_up_b"][l]) + _acc(self.ffn_down[l].weight, grads["grad_ffn_down_w"][l]) + if self.ffn_down[l].bias_param is not None: + _acc(self.ffn_down[l].bias_param, grads["grad_ffn_down_b"][l]) + _acc(self.ln_gammas[l], grads["grad_ln_gamma"][l]) + _acc(self.ln_betas[l], grads["grad_ln_beta"][l]) + + def _backward_numpy(self, ids: np.ndarray, grad_logits: np.ndarray) -> dict: + """Pure NumPy backward (slow but correct reference).""" + S = ids.shape[0] + d = self.d_model + dF = self.d_ffn + V = self.vocab + L = self.n_layers + + embed = _get_param_array(self.embed_w) + pos = _get_param_array(self.pos_w) + out_w = _get_param_array(self.out_w) + + # Forward replay for activations + acts = [None] * (L + 1) + ln_outs = [None] * L + ffn_ups = [None] * L + sign_outs = [None] * L + + acts[0] = embed[ids.astype(np.int64)] + pos[:S] + + for l in range(L): + gamma = _get_param_array(self.ln_gammas[l]) + beta = _get_param_array(self.ln_betas[l]) + uw = _get_param_array(self.ffn_up[l].weight) + ub = _get_param_array(self.ffn_up[l].bias_param) + dw = _get_param_array(self.ffn_down[l].weight) + db = _get_param_array(self.ffn_down[l].bias_param) + + mean = acts[l].mean(axis=-1, keepdims=True) + var = acts[l].var(axis=-1, keepdims=True) + ln_outs[l] = ((acts[l] - mean) / np.sqrt(var + 1e-5) * gamma + beta).astype(np.float32) + + ffn_ups[l] = np.empty((S, dF), dtype=np.float32) + for o in range(dF): + ffn_ups[l][:, o] = -np.sum(np.abs(uw[o] - ln_outs[l]), axis=1) + ub[o] + + sign_outs[l] = np.where(ffn_ups[l] > 0, 1.0, -1.0).astype(np.float32) + + ffn_down_out = np.empty((S, d), dtype=np.float32) + for o in range(d): + ffn_down_out[:, o] = -np.sum(np.abs(dw[o] - sign_outs[l]), axis=1) + db[o] + + acts[l + 1] = acts[l] + ffn_down_out + + scale = 1.0 / np.sqrt(np.float32(d)) + gl_scaled = grad_logits * scale + + dx = gl_scaled @ out_w + grad_out_w = gl_scaled.T @ acts[L] + + grads = { + "grad_out_w": grad_out_w, + "grad_ffn_up_w": [], "grad_ffn_up_b": [], + "grad_ffn_down_w": [], "grad_ffn_down_b": [], + "grad_ln_gamma": [], "grad_ln_beta": [], + } + + for l in range(L - 1, -1, -1): + uw = _get_param_array(self.ffn_up[l].weight) + dw = _get_param_array(self.ffn_down[l].weight) + + # Addition-linear down backward + g_dn_w = np.zeros_like(dw) + g_dn_b = np.sum(dx, axis=0) + grad_sign = np.zeros((S, dF), dtype=np.float32) + for o in range(d): + sgn = np.sign(dw[o][np.newaxis, :] - sign_outs[l]) + g_col = dx[:, o:o+1] + g_dn_w[o] -= np.sum(g_col * sgn, axis=0) + grad_sign -= g_col * sgn + + # STE through sign + grad_ffn_up = grad_sign + + g_up_w = np.zeros_like(uw) + g_up_b = np.sum(grad_ffn_up, axis=0) + grad_ln = np.zeros((S, d), 
dtype=np.float32) + for o in range(dF): + sgn = np.sign(uw[o][np.newaxis, :] - ln_outs[l]) + g_col = grad_ffn_up[:, o:o+1] + g_up_w[o] -= np.sum(g_col * sgn, axis=0) + grad_ln -= g_col * sgn + + grads["grad_ffn_down_w"].insert(0, g_dn_w) + grads["grad_ffn_down_b"].insert(0, g_dn_b) + grads["grad_ffn_up_w"].insert(0, g_up_w) + grads["grad_ffn_up_b"].insert(0, g_up_b) + + gamma = _get_param_array(self.ln_gammas[l]) + mean = acts[l].mean(axis=-1, keepdims=True) + var = acts[l].var(axis=-1, keepdims=True) + inv_std = 1.0 / np.sqrt(var + 1e-5) + x_hat = (acts[l] - mean) * inv_std + + g_ln_gamma = np.sum(grad_ln * x_hat, axis=0) + g_ln_beta = np.sum(grad_ln, axis=0) + grads["grad_ln_gamma"].insert(0, g_ln_gamma) + grads["grad_ln_beta"].insert(0, g_ln_beta) + + dl_dxhat = grad_ln * gamma + sum1 = dl_dxhat.sum(axis=-1, keepdims=True) + sum2 = (dl_dxhat * x_hat).sum(axis=-1, keepdims=True) + grad_x_ln = inv_std * (dl_dxhat - sum1 / d - x_hat * sum2 / d) + + dx = dx + grad_x_ln + + grad_embed = np.zeros((V, d), dtype=np.float32) + grad_pos = np.zeros((self.max_seq, d), dtype=np.float32) + for s in range(S): + tok = int(ids[s]) + if 0 <= tok < V: + grad_embed[tok] += dx[s] + grad_pos[s] = dx[s] + + grads["grad_embed"] = grad_embed + grads["grad_pos"] = grad_pos + + self._scatter_grads(grads) + return grads + + # ── Optimizer step ─────────────────────────────────────────────────── + + def zero_grad(self): + """Reset all parameter gradients to zero.""" + for p in self.parameters(): + if hasattr(p, "grad"): + p.grad = None + + def step(self, grads: dict = None, lr: float = 1e-3): + """Simple SGD step using stored .grad or explicit grads dict. + + After stepping, re-uploads weights to GPU. + """ + if grads is not None: + self._scatter_grads(grads) + + for p in self.parameters(): + arr = _get_param_array(p) + if hasattr(p, "grad") and p.grad is not None: + arr -= lr * p.grad + + self._sync_weights() + + def __repr__(self): + return ( + f"VsaLmModel(vocab={self.vocab}, d_model={self.d_model}, " + f"d_ffn={self.d_ffn}, max_seq={self.max_seq}, " + f"n_layers={self.n_layers})" + ) diff --git a/pyproject.toml b/pyproject.toml index bc9367a..309bd6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,10 +60,20 @@ dev = [ "isort>=5.12.0", "mypy>=1.5.0", "mkdocs-material>=9.0", + "tokenizers>=0.15.0", + "huggingface_hub>=0.20.0", + "sentencepiece>=0.2.0", + "transformers>=4.57.6", + "protobuf>=4.0.0", ] accel = [ "numba>=0.59.0", # JIT-compiled CPU fallbacks ] +# Rust tokenizers + Hub file download (no transformers); used by grilly.tokenizers.FastTokenizer +tokenizer = [ + "tokenizers>=0.15.0", + "huggingface_hub>=0.20.0", +] all = [ "grilly[dev,accel]", ] @@ -77,6 +87,7 @@ Documentation = "https://grilly.org/docs" [tool.setuptools] packages = [ "grilly", + "grilly.torch_api", "grilly.backend", "grilly.grilly_datasets", "grilly.examples", @@ -85,6 +96,7 @@ packages = [ "grilly.optim", "grilly.utils", "grilly.functional", + "grilly.tokenizers", "grilly.scripts", "grilly.tutorials", "grilly.tests", @@ -93,6 +105,7 @@ packages = [ [tool.setuptools.package-dir] "grilly" = "." 
+"grilly.torch_api" = "torch_api" "grilly.backend" = "backend" "grilly.nn" = "nn" "grilly.optim" = "optim" @@ -105,6 +118,7 @@ packages = [ "grilly.experimental_datasets" = "experimental_datasets" "grilly.grilly_datasets" = "grilly_datasets" "grilly.examples" = "examples" +"grilly.tokenizers" = "tokenizer_impl" [tool.setuptools.package-data] grilly = [ diff --git a/rebuild.ps1 b/rebuild.ps1 new file mode 100644 index 0000000..c849228 --- /dev/null +++ b/rebuild.ps1 @@ -0,0 +1,148 @@ +# rebuild.ps1 — one-shot grilly rebuild: compile shaders + build grilly_core + copy .pyd +# +# Usage (from grilly root): +# ./rebuild.ps1 # full rebuild: shaders + grilly_core + copy +# ./rebuild.ps1 -SkipShaders # skip shader compile (if only C++ changed) +# ./rebuild.ps1 -SkipBuild # just compile shaders + copy existing .pyd +# ./rebuild.ps1 -Clean # wipe build2 and reconfigure from scratch +# ./rebuild.ps1 -Verbose # print each file as it processes +# +# Exits non-zero on shader or build failure so you can chain it in a Makefile +# or git hook if you're feeling fancy. + +[CmdletBinding()] +param( + [switch]$SkipShaders, + [switch]$SkipBuild, + [switch]$Clean +) + +$ErrorActionPreference = 'Stop' +$grillyRoot = $PSScriptRoot +Set-Location $grillyRoot + +function Write-Step { + param([string]$Text) + Write-Host "" + Write-Host $Text -ForegroundColor Cyan +} + +Write-Host "=== grilly rebuild ===" -ForegroundColor Cyan +Write-Host "Root: $grillyRoot" + +# ── 1. Compile GLSL -> SPIR-V ────────────────────────────────────────────── +if (-not $SkipShaders) { + Write-Step "[1/3] Compiling shaders..." + + $shaderDir = Join-Path $grillyRoot "shaders" + $spvDir = Join-Path $shaderDir "spv" + New-Item -ItemType Directory -Force -Path $spvDir | Out-Null + + # Verify glslangValidator is on PATH (ships with Vulkan SDK). + $glslang = "glslangValidator" + $glslangOk = $false + try { + $null = & $glslang --version 2>&1 + if ($LASTEXITCODE -eq 0) { $glslangOk = $true } + } catch {} + if (-not $glslangOk) { + Write-Host "ERROR: glslangValidator not found on PATH." -ForegroundColor Red + Write-Host " Install the Vulkan SDK (https://vulkan.lunarg.com/) or" -ForegroundColor Red + Write-Host " add `$env:VULKAN_SDK/Bin to `$env:Path." 
-ForegroundColor Red + exit 1 + } + + $total = 0 + $compiled = 0 + $skipped = 0 + $failed = 0 + $failedNames = @() + + Get-ChildItem -Path $shaderDir -Filter "*.glsl" | ForEach-Object { + $total++ + $src = $_.FullName + $dst = Join-Path $spvDir ($_.BaseName + ".spv") + + # Skip if .spv is newer than .glsl + if ((Test-Path $dst) -and ((Get-Item $dst).LastWriteTime -gt $_.LastWriteTime)) { + $skipped++ + if ($VerbosePreference -eq 'Continue') { + Write-Host " SKIP $($_.BaseName) (up-to-date)" -ForegroundColor DarkGray + } + return + } + + # -S comp : treat as compute shader (glslangValidator can't infer + # the stage from the generic .glsl extension) + # --target-env vulkan1.3 : enables cooperative matrix + subgroup extensions + $output = & $glslang -V -S comp $src -o $dst --target-env vulkan1.3 2>&1 + if ($LASTEXITCODE -eq 0) { + $compiled++ + if ($VerbosePreference -eq 'Continue') { + Write-Host " OK $($_.BaseName)" -ForegroundColor DarkGreen + } + } else { + $failed++ + $failedNames += $_.BaseName + Write-Host " FAIL $($_.BaseName)" -ForegroundColor Red + $output | ForEach-Object { Write-Host " $_" -ForegroundColor DarkRed } + } + } + + $summary = " Shaders: $compiled compiled, $skipped up-to-date, $failed failed / $total total" + if ($failed -gt 0) { + Write-Host $summary -ForegroundColor Yellow + Write-Host "ERROR: $failed shader(s) failed to compile: $($failedNames -join ', ')" -ForegroundColor Red + exit 1 + } else { + Write-Host $summary -ForegroundColor Green + } +} + +# ── 2. Build grilly_core (build2/Release) ──────────────────────────────── +if (-not $SkipBuild) { + Write-Step "[2/3] Building grilly_core (build2/Release)..." + + $buildDir = Join-Path $grillyRoot "build2" + + if ($Clean -and (Test-Path $buildDir)) { + Write-Host " Cleaning $buildDir..." -ForegroundColor Yellow + Remove-Item -Recurse -Force $buildDir + } + + if (-not (Test-Path (Join-Path $buildDir "CMakeCache.txt"))) { + Write-Host " Configuring cmake..." + & cmake -B $buildDir -G "Visual Studio 17 2022" -A x64 + if ($LASTEXITCODE -ne 0) { + Write-Host "ERROR: cmake configure failed" -ForegroundColor Red + exit 1 + } + } + + & cmake --build $buildDir --config Release --target grilly_core + if ($LASTEXITCODE -ne 0) { + Write-Host "ERROR: cmake build failed" -ForegroundColor Red + exit 1 + } + Write-Host " Build OK" -ForegroundColor Green +} + +# ── 3. Copy freshly built .pyd to grilly root ──────────────────────────── +Write-Step "[3/3] Copying grilly_core.cp312-win_amd64.pyd..." + +$builtPyd = Join-Path $grillyRoot "build2/Release/grilly_core.cp312-win_amd64.pyd" +$destPyd = Join-Path $grillyRoot "grilly_core.cp312-win_amd64.pyd" + +if (-not (Test-Path $builtPyd)) { + Write-Host "ERROR: Built .pyd not found at $builtPyd" -ForegroundColor Red + Write-Host " (did the build succeed?)" -ForegroundColor Red + exit 1 +} + +Copy-Item -Force $builtPyd $destPyd +$sizeMB = [math]::Round((Get-Item $destPyd).Length / 1MB, 2) +$mtime = (Get-Item $destPyd).LastWriteTime.ToString("HH:mm:ss") +Write-Host " Copied ($sizeMB MB, mtime $mtime)" -ForegroundColor Green + +Write-Host "" +Write-Host "=== Done ===" -ForegroundColor Cyan diff --git a/scripts/install.sh b/scripts/install.sh index e7dbcfa..208e9c9 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -13,7 +13,9 @@ set -euo pipefail -VULKAN_SDK_VERSION="1.4.341.1" +VULKAN_SDK_VERSION_14="1.4.341.1" +VULKAN_SDK_VERSION_13="1.3.296.0" +VULKAN_SDK_VERSION="$VULKAN_SDK_VERSION_14" GRILLY_DIR="$(cd "$(dirname "$0")/.." 
2>/dev/null && pwd || echo "$(pwd)")" BUILD_DIR="${GRILLY_DIR}/build" JOBS="${JOBS:-$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)}" @@ -23,7 +25,15 @@ FAST_MODE=false for arg in "$@"; do case "$arg" in --fast|-f) FAST_MODE=true ;; - --help|-h) echo "Usage: $0 [--fast]"; echo " --fast Only build shaderc + loader (5 min vs 30 min)"; exit 0 ;; + --colab) FAST_MODE=true; VULKAN_SDK_VERSION="$VULKAN_SDK_VERSION_13" ;; + --vulkan-13|--v13) VULKAN_SDK_VERSION="$VULKAN_SDK_VERSION_13" ;; + --help|-h) + echo "Usage: $0 [OPTIONS]" + echo " --fast Only build shaderc + loader (~5 min vs ~30 min)" + echo " --colab Colab mode: Vulkan 1.3 + fast build + NVIDIA ICD" + echo " --v13 Use Vulkan SDK 1.3 (for older drivers)" + exit 0 + ;; esac done @@ -62,9 +72,8 @@ if [ "$PLATFORM" = "linux" ]; then # Colab/cloud: install NVIDIA Vulkan ICD if NVIDIA GPU detected if [ -d /proc/driver/nvidia ] || command -v nvidia-smi &>/dev/null; then info "NVIDIA GPU detected — installing Vulkan ICD driver..." - sudo apt-get install -y -qq nvidia-driver-550 2>/dev/null || \ - sudo apt-get install -y -qq nvidia-drivers-550 2>/dev/null || \ - warn "Could not install nvidia-driver-550 — Vulkan may not see GPU" + sudo apt-get install -y -qq nvidia-driver 2>/dev/null || \ + warn "Could not install nvidia-driver — Vulkan may not see GPU" fi ok "System deps installed (apt)" @@ -97,8 +106,13 @@ install_vulkan_linux() { info "Downloading Vulkan SDK ${VULKAN_SDK_VERSION}..." local TMP_DIR=$(mktemp -d) local TAR="vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.xz" + # Try xz first, fall back to gz for older SDK versions wget -q "https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/${TAR}" \ - -O "${TMP_DIR}/${TAR}" || fail "Failed to download Vulkan SDK" + -O "${TMP_DIR}/${TAR}" 2>/dev/null || { + TAR="vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.gz" + wget -q "https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/${TAR}" \ + -O "${TMP_DIR}/${TAR}" || fail "Failed to download Vulkan SDK ${VULKAN_SDK_VERSION}" + } info "Extracting Vulkan SDK..." sudo mkdir -p /opt/vulkan diff --git a/shaders/addition-linear.glsl b/shaders/addition-linear.glsl new file mode 100644 index 0000000..f46676f --- /dev/null +++ b/shaders/addition-linear.glsl @@ -0,0 +1,65 @@ +#version 450 + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +// Input vector (batch_size, in_features) +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +// Weight patterns / templates (out_features, in_features) +layout(set = 0, binding = 1) readonly buffer Weights { + float weight_patterns[]; +}; + +// Bias (out_features) — optional, can be zeros +layout(set = 0, binding = 2) readonly buffer Bias { + float bias[]; +}; + +// Output (batch_size, out_features) +layout(set = 0, binding = 3) buffer Output { + float output_data[]; +}; + +// Parameters +layout(push_constant) uniform PushConsts { + uint batch_size; + uint in_features; + uint out_features; + uint use_bias; // 0 or 1 +}; + +// Addition-only linear: y[b][o] = -sum_i |weight[o][i] - input[b][i]| + bias[o] +// No multiplications — only additions, subtractions, absolute values. +// This is a radial basis function using Manhattan (L1) distance. +// The closer the input is to weight row o, the higher (less negative) the output. 
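// NumPy reference of the same computation (illustrative; mirrors the math in
// VsaLmModel._forward_numpy elsewhere in this change set):
//   out = -np.abs(weights[None, :, :] - x[:, None, :]).sum(axis=-1)
//   if use_bias: out += bias[None, :]
// with x of shape (batch_size, in_features) and weights (out_features, in_features).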
+void main() { + uint gID = gl_GlobalInvocationID.x; + + // Each thread handles one (batch, output) pair + uint total = batch_size * out_features; + if (gID >= total) { + return; + } + + uint b = gID / out_features; // batch index + uint o = gID % out_features; // output neuron index + + // Compute L1 distance: sum_i |w[o][i] - x[b][i]| + float dist = 0.0; + uint w_offset = o * in_features; + uint x_offset = b * in_features; + + for (uint i = 0; i < in_features; i++) { + dist += abs(weight_patterns[w_offset + i] - input_data[x_offset + i]); + } + + // Output = -distance + bias (closer → higher value) + float result = -dist; + if (use_bias == 1) { + result += bias[o]; + } + + output_data[gID] = result; +} diff --git a/shaders/attention-output.glsl b/shaders/attention-output.glsl index fd286a6..52f2542 100644 --- a/shaders/attention-output.glsl +++ b/shaders/attention-output.glsl @@ -44,21 +44,21 @@ void main() { uint head_idx = remainder / head_dim; uint dim_idx = remainder % head_dim; - // Compute weighted sum: output[batch, seq_q, head, dim] = sum_k(weights[batch, head, seq_q, k] * V[batch, k, head, dim]) + // Compute weighted sum: output[batch, seq_q, head, dim] + // = sum_k(weights[batch, head, seq_q, k] * V[batch, k, head, dim]) + // Hoist loop-invariant index math and use FMA in the inner loop. float sum = 0.0; - + + uint weight_base = batch_idx * num_heads * seq_len * seq_len + + head_idx * seq_len * seq_len + + seq_q * seq_len; + + uint v_base = batch_idx * seq_len * num_heads * head_dim + + head_idx * head_dim + dim_idx; + uint v_stride = num_heads * head_dim; + for (uint k = 0; k < seq_len; k++) { - // Weight index: weights[batch, head, seq_q, k] - uint weight_idx = batch_idx * num_heads * seq_len * seq_len + - head_idx * seq_len * seq_len + - seq_q * seq_len + k; - - // Value index: V[batch, k, head, dim] - uint v_idx = batch_idx * seq_len * num_heads * head_dim + - k * num_heads * head_dim + - head_idx * head_dim + dim_idx; - - sum += weights[weight_idx] * V[v_idx]; + sum = fma(weights[weight_base + k], V[v_base + k * v_stride], sum); } // Output index: output[batch, seq_q, head, dim] diff --git a/shaders/flash-attention2.glsl b/shaders/flash-attention2.glsl index c9ccf74..3c89d09 100644 --- a/shaders/flash-attention2.glsl +++ b/shaders/flash-attention2.glsl @@ -81,6 +81,9 @@ void update_online_softmax( void main() { uint row = gl_GlobalInvocationID.y; uint col = gl_GlobalInvocationID.x; + uint seq_heads = seq_len * num_heads; + uint head_stride = num_heads * head_dim; + uint bsh_stride = seq_len * head_stride; if (pass_type == 0) { // Pass 0: Initialize running max and sum for each query position @@ -130,19 +133,11 @@ void main() { // Compute attention score: Q[q_pos] @ K[k_pos]^T / sqrt(head_dim) float score = 0.0; + uint q_base = batch_idx * bsh_stride + q_pos * head_stride + head_idx * head_dim; + uint k_base = batch_idx * bsh_stride + k_pos * head_stride + head_idx * head_dim; for (uint d = 0; d < head_dim; d++) { - // Q index: [batch, seq, head, head_dim] - uint q_idx = batch_idx * seq_len * num_heads * head_dim + - q_pos * num_heads * head_dim + - head_idx * head_dim + d; - - // K index: [batch, seq, head, head_dim] - uint k_idx = batch_idx * seq_len * num_heads * head_dim + - k_pos * num_heads * head_dim + - head_idx * head_dim + d; - - score += Q[q_idx] * K[k_idx]; + score = fma(Q[q_base + d], K[k_base + d], score); } score *= scale; @@ -154,9 +149,7 @@ void main() { } // Get running max and sum for this query position - uint q_flat_idx = batch_idx * seq_len * num_heads 
+ - q_pos * num_heads + - head_idx; + uint q_flat_idx = batch_idx * seq_heads + q_pos * num_heads + head_idx; float old_max = running_max[q_flat_idx]; float old_sum = running_sum[q_flat_idx]; @@ -176,23 +169,17 @@ void main() { // Rescale existing accumulator if max changed if (new_max > old_max) { float rescale = exp(old_max - new_max); + uint accum_base = q_flat_idx * head_dim; for (uint d = 0; d < head_dim; d++) { - uint accum_idx = q_flat_idx * head_dim + d; - output_accum[accum_idx] *= rescale; + output_accum[accum_base + d] *= rescale; } } // Accumulate weighted value: output += weight * V[k_pos] + uint accum_base = q_flat_idx * head_dim; + uint v_base = batch_idx * bsh_stride + k_pos * head_stride + head_idx * head_dim; for (uint d = 0; d < head_dim; d++) { - uint accum_idx = q_flat_idx * head_dim + d; - - // V index: [batch, seq, head, head_dim] - uint v_idx = batch_idx * seq_len * num_heads * head_dim + - k_pos * num_heads * head_dim + - head_idx * head_dim + d; - - // Accumulate: output[q_pos] += weight * V[k_pos] - output_accum[accum_idx] += weight * V[v_idx]; + output_accum[accum_base + d] = fma(weight, V[v_base + d], output_accum[accum_base + d]); } } else if (pass_type == 2) { diff --git a/shaders/fnn-linear.glsl b/shaders/fnn-linear.glsl index a18d47d..d016877 100644 --- a/shaders/fnn-linear.glsl +++ b/shaders/fnn-linear.glsl @@ -1,58 +1,73 @@ #version 450 +// Tiled GEMM: output = input @ W^T + bias +// 16x16 tiles, shared memory, 4-way unrolled K loop. +// Correct and stable. Performance limited by per-op dispatch overhead. + layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; -// Input features (batch * seq, input_dim) layout(set = 0, binding = 0) readonly buffer Input { float input_data[]; }; -// Weight matrix (output_dim, input_dim) layout(set = 0, binding = 1) readonly buffer Weights { float W[]; }; -// Bias vector (output_dim) layout(set = 0, binding = 2) readonly buffer Bias { float b[]; }; -// Output (batch * seq, output_dim) layout(set = 0, binding = 3) buffer Output { float output_data[]; }; -// Parameters layout(push_constant) uniform PushConsts { - uint batch_seq; // batch_size * seq_len + uint batch_seq; uint input_dim; uint output_dim; - uint has_bias; // 1 if bias exists, 0 otherwise + uint has_bias; }; +shared float tileA[16][17]; // +1 padding to avoid bank conflicts +shared float tileB[16][17]; + void main() { - // 2D parallelization: each thread computes one output element - uint row = gl_GlobalInvocationID.y; // Sample index - uint col = gl_GlobalInvocationID.x; // Output feature index - - if (row >= batch_seq || col >= output_dim) { - return; - } - - // Compute: output[row][col] = sum(input[row][k] * W[col][k]) + b[col] - float sum = 0.0; - - for (uint k = 0; k < input_dim; k++) { - uint input_idx = row * input_dim + k; - uint weight_idx = col * input_dim + k; - sum += input_data[input_idx] * W[weight_idx]; + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint col = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + float acc = 0.0; + uint numTiles = (input_dim + 15) / 16; + uint row_base = row * input_dim; + uint col_base = col * input_dim; + + for (uint t = 0; t < numTiles; t++) { + uint k_base = t * 16; + + uint k_a = k_base + tx; + tileA[ty][tx] = (row < batch_seq && k_a < input_dim) + ? input_data[row_base + k_a] : 0.0; + + uint k_b = k_base + ty; + tileB[ty][tx] = (col < output_dim && k_b < input_dim) + ? 
W[col_base + k_b] : 0.0; + + barrier(); + + for (uint k = 0; k < 16; k += 4) { + acc = fma(tileA[ty][k], tileB[k][tx], acc); + acc = fma(tileA[ty][k + 1], tileB[k + 1][tx], acc); + acc = fma(tileA[ty][k + 2], tileB[k + 2][tx], acc); + acc = fma(tileA[ty][k + 3], tileB[k + 3][tx], acc); + } + + barrier(); } - - // Add bias if present - if (has_bias == 1) { - sum += b[col]; + + if (row < batch_seq && col < output_dim) { + if (has_bias == 1) acc += b[col]; + output_data[row * output_dim + col] = acc; } - - uint out_idx = row * output_dim + col; - output_data[out_idx] = sum; } diff --git a/shaders/fused-layernorm-linear.glsl b/shaders/fused-layernorm-linear.glsl index db62133..698c5a2 100644 --- a/shaders/fused-layernorm-linear.glsl +++ b/shaders/fused-layernorm-linear.glsl @@ -1,5 +1,6 @@ #version 450 #extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_shader_atomic_float : require // Fused LayerNorm + Linear projection. // diff --git a/shaders/gemm-bias-add.glsl b/shaders/gemm-bias-add.glsl new file mode 100644 index 0000000..4aace97 --- /dev/null +++ b/shaders/gemm-bias-add.glsl @@ -0,0 +1,27 @@ +#version 450 + +/* + * Row-broadcast bias add: C[r, c] += bias[c] + * + * Used as a second dispatch after gemm-coopmat-shared (which produces C + * without bias, since coopMatStore can't easily interleave an elementwise + * add with the tile-cooperative store). 1D dispatch over M*N elements, + * 256 threads per workgroup. + */ + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) buffer CBuffer { float C[]; }; +layout(binding = 1) readonly buffer BiasBuffer { float bias[]; }; + +layout(push_constant) uniform PushConstants { + uint totalElements; // M * N + uint N; // stride (columns per row) +} params; + +void main() { + uint idx = gl_GlobalInvocationID.x; + if (idx >= params.totalElements) return; + uint col = idx % params.N; + C[idx] = C[idx] + bias[col]; +} diff --git a/shaders/gemm-coopmat-shared.glsl b/shaders/gemm-coopmat-shared.glsl new file mode 100644 index 0000000..63c71e9 --- /dev/null +++ b/shaders/gemm-coopmat-shared.glsl @@ -0,0 +1,99 @@ +#version 450 +#extension GL_KHR_cooperative_matrix : require +#extension GL_KHR_memory_scope_semantics : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_shader_subgroup_basic : require + +/* + * Cooperative Matrix GEMM with Shared Memory Staging ("The Merge") + * + * C = A * B where A is MxK, B is KxN, C is MxN, row-major. + * A and B are float16; C accumulates in float32. + * + * Workgroup: 64x4 (256 threads) -> 4 subgroups of 64 lanes (Wave64 for RDNA). + * Output tile per workgroup: 16 rows x 64 cols of C (4 x 16x16 coopmat tiles). + * + * Alignment requirements on the caller: + * M % 16 == 0, K % 16 == 0, N % 64 == 0 + * + * Dispatch: + * gx = N / 64 + * gy = M / 16 + * gz = 1 + * + * Hardware notes: on RDNA3 / NVIDIA Tensor Cores this hits full WMMA + * throughput via the driver. On RDNA1/RDNA2 it runs through the driver's + * emulation path (standard fp16 vector ops) — still correct, noticeably + * slower than the peak but still competitive with a hand-tuled GEMM. 
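+ *
+ * Worked example (illustrative shapes, not a requirement): M = 512, K = 1024,
+ * N = 2048 satisfies the alignment rules above, so the dispatch is
+ * gx = 2048/64 = 32, gy = 512/16 = 32 (1024 workgroups). Each workgroup walks
+ * K in 1024/16 = 64 staging iterations, and its 4 subgroups together emit one
+ * 16x64 tile of C.
+ *
+ * Bias is deliberately not applied here; the host follows this kernel with the
+ * gemm-bias-add dispatch (see shaders/gemm-bias-add.glsl in this patch).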
+ */ + +layout(local_size_x = 64, local_size_y = 4, local_size_z = 1) in; + +layout(binding = 0) readonly buffer ABuffer { float16_t A[]; }; +layout(binding = 1) readonly buffer BBuffer { float16_t B[]; }; +layout(binding = 2) writeonly buffer CBuffer { float C[]; }; + +layout(push_constant) uniform PushConstants { + uint M; + uint K; + uint N; +} params; + +// Shared memory staging: +// Asub stages a 16x16 tile of A (256 elements, all 4 subgroups share it) +// Bsub stages a 16x64 tile of B (1024 elements, each subgroup takes a slice) +shared float16_t Asub[256]; +shared float16_t Bsub[1024]; + +void main() { + uint tile_row = gl_WorkGroupID.y * 16u; + uint tile_col = gl_WorkGroupID.x * 64u; + + uint sg_id = gl_LocalInvocationID.y; // subgroup id 0..3 + uint lane_id = gl_LocalInvocationID.x; // lane within subgroup 0..63 + uint linear_id = sg_id * 64u + lane_id; // 0..255 + + // Hardware accumulator lives in registers (one per subgroup) + coopmat matC = + coopmat(0.0); + + for (uint k = 0u; k < params.K; k += 16u) { + // ── 1. Stage A tile (16x16 = 256 elements, one per thread) ── + if (linear_id < 256u) { + uint a_r = linear_id / 16u; + uint a_c = linear_id % 16u; + Asub[linear_id] = A[(tile_row + a_r) * params.K + (k + a_c)]; + } + + // ── 2. Stage B tile (16x64 = 1024 elements, 4 per thread) ── + for (uint i = 0u; i < 4u; ++i) { + uint b_idx = linear_id + (i * 256u); + uint b_r = b_idx / 64u; + uint b_c = b_idx % 64u; + Bsub[b_idx] = B[(k + b_r) * params.N + (tile_col + b_c)]; + } + + barrier(); + + // ── 3. Load shared memory → cooperative registers ── + coopmat matA; + coopmat matB; + + // All 4 subgroups load the SAME 16x16 A tile from Asub (stride = 16) + coopMatLoad(matA, Asub, 0, 16, gl_CooperativeMatrixLayoutRowMajor); + + // Each subgroup loads its OWN 16x16 slice of the 16x64 B tile + // (stride = 64 over Bsub, starting offset = sg_id * 16) + coopMatLoad(matB, Bsub, sg_id * 16u, 64, gl_CooperativeMatrixLayoutRowMajor); + + // ── 4. Hardware matrix multiply-accumulate ── + matC = coopMatMulAdd(matA, matB, matC); + + barrier(); + } + + // ── 5. Write back the 16x16 accumulated tile to global C ── + uint out_col = tile_col + (sg_id * 16u); + coopMatStore(matC, C, tile_row * params.N + out_col, params.N, + gl_CooperativeMatrixLayoutRowMajor); +} diff --git a/shaders/grid-cell.glsl b/shaders/grid-cell.glsl new file mode 100644 index 0000000..61ef4b8 --- /dev/null +++ b/shaders/grid-cell.glsl @@ -0,0 +1,69 @@ +#version 450 + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +// Current agent position (2D) +layout(set = 0, binding = 0) readonly buffer AgentPosition { + float agent_pos[]; // [x, y] +}; + +// Grid cell parameters (per neuron) +layout(set = 0, binding = 1) readonly buffer GridSpacings { + float spacings[]; // [s0, s1, ...] +}; + +layout(set = 0, binding = 2) readonly buffer GridOrientations { + float orientations[]; // [theta0, theta1, ...] +}; + +layout(set = 0, binding = 3) readonly buffer GridPhases { + float phases[]; // [px0, py0, px1, py1, ...] +}; + +// Output firing rates +layout(set = 0, binding = 4) buffer FiringRates { + float rates[]; +}; + +// Parameters +layout(push_constant) uniform PushConsts { + uint n_neurons; + float max_rate; // e.g. 
25 Hz +}; + +// Hexagonal 3-wave grid cell activation +// Three cosine waves at 60° intervals create hexagonal firing pattern +// r = max_rate * relu((cos(u1) + cos(u2) + cos(u3)) / 3 + 0.5) +void main() { + uint gID = gl_GlobalInvocationID.x; + + if (gID >= n_neurons) { + return; + } + + float x = agent_pos[0]; + float y = agent_pos[1]; + + // Rotation by grid orientation + float theta = orientations[gID]; + float cos_t = cos(theta); + float sin_t = sin(theta); + float rx = cos_t * x - sin_t * y; + float ry = sin_t * x + cos_t * y; + + // Shift by phase + uint phase_off = gID * 2; + float sx = rx - phases[phase_off]; + float sy = ry - phases[phase_off + 1]; + + // Wave vector: k = 4*pi / (sqrt(3) * spacing) + float k = 4.0 * 3.14159265 / (1.7320508 * spacings[gID]); + + // Three waves at 60° intervals (hexagonal pattern) + float u1 = k * sx; + float u2 = k * (-0.5 * sx + 0.866025 * sy); + float u3 = k * (-0.5 * sx - 0.866025 * sy); + + float val = (cos(u1) + cos(u2) + cos(u3)) / 3.0 + 0.5; + rates[gID] = max_rate * max(val, 0.0); +} diff --git a/shaders/int8-gemm.glsl b/shaders/int8-gemm.glsl index 1885164..4e767fe 100644 --- a/shaders/int8-gemm.glsl +++ b/shaders/int8-gemm.glsl @@ -54,23 +54,29 @@ void main() { uint packed_K = (K + 3) / 4; // number of uint32s per row float sum = 0.0; - - for (uint k = 0; k < K; k++) { - // Get activation value - float act = activations[row * K + k]; - - // Get packed weight and unpack - uint pack_idx = k / 4; - uint pack_offset = k % 4; + uint base = row * K; + + // Inner K: process 4 elements per iteration — one packed uint32 load per 4 weights. + uint k = 0; + for (; k + 4u <= K; k += 4u) { + uint pk = weights_packed[col * packed_K + k / 4u]; + uint g0 = k / group_size; + uint g1 = (k + 1u) / group_size; + uint g2 = (k + 2u) / group_size; + uint g3 = (k + 3u) / group_size; + sum += activations[base + k] * unpack_int8(pk, 0) * scales[col * num_groups + g0]; + sum += activations[base + k + 1u] * unpack_int8(pk, 1) * scales[col * num_groups + g1]; + sum += activations[base + k + 2u] * unpack_int8(pk, 2) * scales[col * num_groups + g2]; + sum += activations[base + k + 3u] * unpack_int8(pk, 3) * scales[col * num_groups + g3]; + } + for (; k < K; k++) { + uint pack_idx = k / 4u; + uint pack_offset = k % 4u; uint packed = weights_packed[col * packed_K + pack_idx]; float w = unpack_int8(packed, pack_offset); - - // Get per-group scale uint group_idx = k / group_size; float s = scales[col * num_groups + group_idx]; - - // FP32 accumulate: act * (w_int8 * scale) - sum += act * w * s; + sum += activations[base + k] * w * s; } output_data[row * N + col] = sum; diff --git a/shaders/lstm-cell-forward.glsl b/shaders/lstm-cell-forward.glsl index 7cd4e83..f11da5d 100644 --- a/shaders/lstm-cell-forward.glsl +++ b/shaders/lstm-cell-forward.glsl @@ -35,7 +35,8 @@ layout(push_constant) uniform PushConstants { } params; layout(set = 0, binding = 0) readonly buffer InputBuffer { - float input[]; + // Renamed from ``input`` — reserved word in recent glslang. + float input_data[]; }; layout(set = 0, binding = 1) readonly buffer HiddenBuffer { @@ -70,7 +71,12 @@ layout(set = 0, binding = 8) writeonly buffer NewCellBuffer { float new_cell[]; }; -layout(set = 0, binding = 9) writeonly buffer GatesBuffer { +// Not ``writeonly``: the shader writes all four gates into this buffer +// (binding 9) in the first half of main(), then reads them back a few +// lines later to compute the cell/hidden updates. 
Pre-existing bug — +// glslang's writeonly check was only triggered after fixing the +// reserved-word rename on binding 0 forced a recompile. +layout(set = 0, binding = 9) buffer GatesBuffer { float gates[]; }; @@ -94,7 +100,7 @@ void main() { for (uint i = 0; i < params.input_size; i++) { uint weight_idx = (gate * params.hidden_size + hidden_idx) * params.input_size + i; uint input_idx = batch_idx * params.input_size + i; - value += weight_ih[weight_idx] * input[input_idx]; + value += weight_ih[weight_idx] * input_data[input_idx]; } // Add bias_i diff --git a/shaders/mf-sigmoid.glsl b/shaders/mf-sigmoid.glsl new file mode 100644 index 0000000..240b0e0 --- /dev/null +++ b/shaders/mf-sigmoid.glsl @@ -0,0 +1,25 @@ +#version 450 + +// Rational sigmoid: x / (1 + |x|) + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(push_constant) uniform PushConsts { + uint total_elements; +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + if (gID >= total_elements) return; + float x = input_data[gID]; + float ax = abs(x) + 1.0; + output_data[gID] = x / ax; +} diff --git a/shaders/mf-softmax.glsl b/shaders/mf-softmax.glsl new file mode 100644 index 0000000..977c166 --- /dev/null +++ b/shaders/mf-softmax.glsl @@ -0,0 +1,86 @@ +#version 450 + +// Multiplication-free softmax: relu(x - max) / sum(relu(x - max)) — no exp(). +// Same 3-pass layout as activation-softmax.glsl; pass 2 sums positive parts only. + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(set = 0, binding = 2) buffer MaxValues { + float max_vals[]; +}; + +layout(set = 0, binding = 3) buffer SumPos { + float sum_pos[]; +}; + +layout(push_constant) uniform PushConsts { + uint batch_size; + uint seq_len; + uint features; + uint pass_type; + uint dim; +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + + if (pass_type == 0u) { + uint total_positions = batch_size * seq_len; + if (gID >= total_positions) return; + + uint batch_idx = gID / seq_len; + uint seq_idx = gID % seq_len; + + float max_val = -1e10; + for (uint f = 0; f < features; f++) { + uint idx = batch_idx * seq_len * features + seq_idx * features + f; + max_val = max(max_val, input_data[idx]); + } + max_vals[gID] = max_val; + + } else if (pass_type == 1u) { + uint total_positions = batch_size * seq_len; + if (gID >= total_positions) return; + + uint batch_idx = gID / seq_len; + uint seq_idx = gID % seq_len; + float max_val = max_vals[gID]; + + float sum = 0.0; + for (uint f = 0; f < features; f++) { + uint idx = batch_idx * seq_len * features + seq_idx * features + f; + float v = input_data[idx] - max_val; + sum += max(v, 0.0); + } + sum_pos[gID] = sum; + + } else if (pass_type == 2u) { + uint total_elements = batch_size * seq_len * features; + if (gID >= total_elements) return; + + uint batch_idx = gID / (seq_len * features); + uint remainder = gID % (seq_len * features); + uint seq_idx = remainder / features; + + uint pos_idx = batch_idx * seq_len + seq_idx; + float max_val = max_vals[pos_idx]; + float sum = sum_pos[pos_idx]; + + float val = input_data[gID]; + float z = max(val - max_val, 0.0); + if (sum < 1e-5) { + output_data[gID] = 1.0 / float(features); + } else { + float denom = max(sum, 1e-6); + 
output_data[gID] = z / denom; + } + } +} diff --git a/shaders/mf-softplus.glsl b/shaders/mf-softplus.glsl new file mode 100644 index 0000000..62fe139 --- /dev/null +++ b/shaders/mf-softplus.glsl @@ -0,0 +1,26 @@ +#version 450 + +// Algebraic softplus: 0.5 * (x + sqrt(x*x + c)) with c = 4/beta^2 — no exp/log. + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(push_constant) uniform PushConsts { + uint total_elements; + float c; // 4 / (beta * beta), set from host +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + if (gID >= total_elements) return; + float x = input_data[gID]; + float s = sqrt(x * x + c); + output_data[gID] = 0.5 * (x + s); +} diff --git a/shaders/moe-layer-backward-vec4.glsl b/shaders/moe-layer-backward-vec4.glsl new file mode 100644 index 0000000..a55c0d8 --- /dev/null +++ b/shaders/moe-layer-backward-vec4.glsl @@ -0,0 +1,87 @@ +#version 450 + +// Fused MoE Backward (vec4): grad_input for all 4 experts. +// RDNA2 optimized: vec4 loads, 32x8 workgroup (matches Wave32), LDS padding. +// +// dx[row, k] = grad_out[row, k] + sum_e(w_e * sum_col(grad_out[row, col] * W_e[col, k])) +// All buffers vec4 packed (d_model must be multiple of 4). + +layout (local_size_x = 32, local_size_y = 8) in; + +layout(set = 0, binding = 0) readonly buffer GradOut { vec4 g_out[]; }; +layout(set = 0, binding = 1) readonly buffer ExpertWeights { vec4 ew[]; }; +layout(set = 0, binding = 2) writeonly buffer GradInput { vec4 g_in[]; }; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + float w0, w1, w2, w3; +}; + +shared vec4 tileG[8][33]; // grad_out tile, 8 rows × 32 vec4 cols + padding + +void main() { + uint row = gl_WorkGroupID.y * 8 + gl_LocalInvocationID.y; + uint k_vec = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + uint d_vec = d_model / 4; + + if (row >= seq_len || k_vec >= d_vec) return; + + vec4 acc0 = vec4(0.0), acc1 = vec4(0.0), acc2 = vec4(0.0), acc3 = vec4(0.0); + + uint numTiles = (d_vec + 31) / 32; + uint ew_stride = d_model * d_vec; // vec4s per expert + + for (uint t = 0; t < numTiles; t++) { + // Load grad_out into LDS + uint col_vec = t * 32 + tx; + tileG[ty][tx] = (col_vec < d_vec) ? 
g_out[row * d_vec + col_vec] : vec4(0.0); + + memoryBarrierShared(); + barrier(); + + // Accumulate: dx[k] += grad[col] * W[col, k] + for (uint c = 0; c < 32; c++) { + vec4 g = tileG[ty][c]; + uint col_global = t * 32 + c; // vec4 index + if (col_global >= d_vec) break; + + // Each vec4 g has 4 scalar grad components for 4 adjacent columns + // W[col, k] accessed as ew[expert * stride + col_scalar * d_vec + k_vec] + + // Expert 0 + acc0 += g.x * ew[0 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc0 += g.y * ew[0 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc0 += g.z * ew[0 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc0 += g.w * ew[0 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + + // Expert 1 + acc1 += g.x * ew[1 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc1 += g.y * ew[1 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc1 += g.z * ew[1 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc1 += g.w * ew[1 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + + // Expert 2 + acc2 += g.x * ew[2 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc2 += g.y * ew[2 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc2 += g.z * ew[2 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc2 += g.w * ew[2 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + + // Expert 3 + acc3 += g.x * ew[3 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc3 += g.y * ew[3 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc3 += g.z * ew[3 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc3 += g.w * ew[3 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + } + + barrier(); + } + + vec4 dx = w0 * acc0 + w1 * acc1 + w2 * acc2 + w3 * acc3; + uint idx = row * d_vec + k_vec; + g_in[idx] = g_out[idx] + dx; // residual gradient +} diff --git a/shaders/moe-layer-backward.glsl b/shaders/moe-layer-backward.glsl new file mode 100644 index 0000000..e1b4f5d --- /dev/null +++ b/shaders/moe-layer-backward.glsl @@ -0,0 +1,110 @@ +#version 450 + +// Fused MoE Layer Backward: compute dx (gradient w.r.t. input) for all 4 experts in one dispatch. +// +// dx[row, k] = sum_e(router_w[e] * sum_col(grad_out[row, col] * expert_w[e][col, k])) +// = sum_e(router_w[e] * (grad_out[row,:] @ expert_w[e][:, k])) +// +// This is: dx = sum_e(w_e * grad_out @ W_e) +// Each expert weight matrix W_e is (d_model, d_model), stored row-major as W_e[col, k]. +// So grad_out @ W_e means: for output element (row, k), sum over col: grad[row,col] * W[col,k] +// +// Tiled: 16x16, K tiled in blocks of 16. + +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +// Gradient w.r.t. layer output (seq_len, d_model) — this is dx from the next layer +layout(set = 0, binding = 0) readonly buffer GradOut { + float grad_out[]; +}; + +// Expert weights (same as forward): 4 experts × (d_model, d_model) +layout(set = 0, binding = 1) readonly buffer ExpertWeights { + float ew[]; +}; + +// Output: gradient w.r.t. 
layer input (seq_len, d_model) +layout(set = 0, binding = 2) buffer GradInput { + float grad_in[]; +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + float router_w0; + float router_w1; + float router_w2; + float router_w3; +}; + +shared float tileG[16][17]; // grad_out tile + +void main() { + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint k_out = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; // output K index + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + if (row >= seq_len || k_out >= d_model) return; + + float acc0 = 0.0, acc1 = 0.0, acc2 = 0.0, acc3 = 0.0; + + uint numTiles = (d_model + 15) / 16; + uint dd = d_model * d_model; + uint row_base = row * d_model; + + for (uint t = 0; t < numTiles; t++) { + uint col_base = t * 16; + + // Load grad_out tile + uint col = col_base + tx; + tileG[ty][tx] = (col < d_model) ? grad_out[row_base + col] : 0.0; + + barrier(); + + // Accumulate: dx[row, k] += grad[row, col] * W[e][col, k] + // W layout: ew[e * d*d + col * d + k] + for (uint c = 0; c < 16; c += 4) { + uint col0 = col_base + c; + float g0 = tileG[ty][c]; + float g1 = tileG[ty][c + 1]; + float g2 = tileG[ty][c + 2]; + float g3 = tileG[ty][c + 3]; + + // Expert 0: W[0][col, k_out] + acc0 = fma(g0, ew[0 * dd + (col0) * d_model + k_out], acc0); + acc0 = fma(g1, ew[0 * dd + (col0 + 1) * d_model + k_out], acc0); + acc0 = fma(g2, ew[0 * dd + (col0 + 2) * d_model + k_out], acc0); + acc0 = fma(g3, ew[0 * dd + (col0 + 3) * d_model + k_out], acc0); + + // Expert 1 + acc1 = fma(g0, ew[1 * dd + (col0) * d_model + k_out], acc1); + acc1 = fma(g1, ew[1 * dd + (col0 + 1) * d_model + k_out], acc1); + acc1 = fma(g2, ew[1 * dd + (col0 + 2) * d_model + k_out], acc1); + acc1 = fma(g3, ew[1 * dd + (col0 + 3) * d_model + k_out], acc1); + + // Expert 2 + acc2 = fma(g0, ew[2 * dd + (col0) * d_model + k_out], acc2); + acc2 = fma(g1, ew[2 * dd + (col0 + 1) * d_model + k_out], acc2); + acc2 = fma(g2, ew[2 * dd + (col0 + 2) * d_model + k_out], acc2); + acc2 = fma(g3, ew[2 * dd + (col0 + 3) * d_model + k_out], acc2); + + // Expert 3 + acc3 = fma(g0, ew[3 * dd + (col0) * d_model + k_out], acc3); + acc3 = fma(g1, ew[3 * dd + (col0 + 1) * d_model + k_out], acc3); + acc3 = fma(g2, ew[3 * dd + (col0 + 2) * d_model + k_out], acc3); + acc3 = fma(g3, ew[3 * dd + (col0 + 3) * d_model + k_out], acc3); + } + + barrier(); + } + + // Weighted sum + residual gradient (dx passes through residual) + float dx = router_w0 * acc0 + router_w1 * acc1 + + router_w2 * acc2 + router_w3 * acc3; + + // Residual: grad_input = grad_out + dx (grad passes through addition) + uint idx = row * d_model + k_out; + grad_in[idx] = grad_out[idx] + dx; +} diff --git a/shaders/moe-layer-fused-vec4.glsl b/shaders/moe-layer-fused-vec4.glsl new file mode 100644 index 0000000..153f33c --- /dev/null +++ b/shaders/moe-layer-fused-vec4.glsl @@ -0,0 +1,106 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +// Fused MoE Layer (vec4): 4 expert matmuls + blend + residual. +// RDNA2 optimized: vec4 loads (128-bit), 16x16 workgroup, LDS bank padding. +// +// output[row, col] = input[row, col] + sum_e(w[e] * (input[row,:] @ expert_w[e][col,:])) +// All buffers use vec4 packing (d_model must be multiple of 4). 
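+//
+// Worked example (illustrative shapes): seq_len = 128, d_model = 256 gives
+// d_vec = 256/4 = 64 vec4 columns. Each thread owns one (row, vec4-column)
+// output, i.e. 4 scalar output columns blended across the 4 experts, so the
+// host would dispatch roughly ceil(64/16) x ceil(128/16) = 4 x 8 workgroups
+// (the exact grid is the host's choice; out-of-range threads return early).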
+ +layout (local_size_x = 16, local_size_y = 16) in; + +layout(set = 0, binding = 0) readonly buffer Input { vec4 x[]; }; +layout(set = 0, binding = 1) readonly buffer ExpertWeights { vec4 ew[]; }; +layout(set = 0, binding = 2) writeonly buffer Output { vec4 out_data[]; }; +layout(set = 0, binding = 3) readonly buffer RouterScratch { float rscratch[]; }; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; // must be multiple of 4 + uint n_experts; + uint weights_offset; +}; + +shared vec4 tileX[16][17]; // input tile with bank-conflict padding + +void main() { + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint col_vec = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint tx = gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + + uint d_vec = d_model / 4; + + if (row >= seq_len || col_vec >= d_vec) return; + + // Load router weights + float w0 = rscratch[weights_offset]; + float w1 = rscratch[weights_offset + 1]; + float w2 = rscratch[weights_offset + 2]; + float w3 = rscratch[weights_offset + 3]; + + vec4 acc0 = vec4(0.0), acc1 = vec4(0.0), acc2 = vec4(0.0), acc3 = vec4(0.0); + + uint numTiles = (d_vec + 15) / 16; + // Expert weight stride: each expert is d_model rows × d_vec vec4s per row + // ew layout: ew[expert * d_model * d_vec + col * d_vec + k_vec] + // But we need W[col_out, k_in] which is ew[expert * d * d_vec + col_vec*4*d_vec + k_vec] + // Simplified: expert weights are (d, d) row-major packed as vec4 + // ew[expert * d * d_vec + row_w * d_vec + k_vec] + uint ew_stride = d_model * d_vec; // vec4s per expert + + [[unroll]] + for (uint t = 0; t < numTiles; t++) { + uint k_base = t * 16; + + // Load input tile: x[row, k_base*4 .. (k_base+16)*4] + uint k_vec = k_base + tx; + tileX[ty][tx] = (k_vec < d_vec) ? 
x[row * d_vec + k_vec] : vec4(0.0); + + memoryBarrierShared(); + barrier(); + + // Accumulate for all 4 experts + // Each iteration: load 4 floats from input (via tileX), multiply with expert weight vec4 + [[unroll]] + for (uint k = 0; k < 16; k++) { + vec4 xv = tileX[ty][k]; + uint k_global = k_base + k; // vec4 index in K dimension + + // For each expert: W[col_out_vec4, k_global_vec4] + // col_vec gives us the output vec4 column + // We need 4 loads per expert (one per component of xv) + // ew index: expert * ew_stride + (col_vec*4 + component) * d_vec + k_global + + // Expert 0 + acc0 += xv.x * ew[0 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc0 += xv.y * ew[0 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc0 += xv.z * ew[0 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc0 += xv.w * ew[0 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + + // Expert 1 + acc1 += xv.x * ew[1 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc1 += xv.y * ew[1 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc1 += xv.z * ew[1 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc1 += xv.w * ew[1 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + + // Expert 2 + acc2 += xv.x * ew[2 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc2 += xv.y * ew[2 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc2 += xv.z * ew[2 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc2 += xv.w * ew[2 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + + // Expert 3 + acc3 += xv.x * ew[3 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc3 += xv.y * ew[3 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc3 += xv.z * ew[3 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc3 += xv.w * ew[3 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + } + + barrier(); + } + + vec4 blended = w0 * acc0 + w1 * acc1 + w2 * acc2 + w3 * acc3; + uint out_idx = row * d_vec + col_vec; + out_data[out_idx] = x[out_idx] + blended; +} diff --git a/shaders/moe-layer-fused.glsl b/shaders/moe-layer-fused.glsl new file mode 100644 index 0000000..cd8023c --- /dev/null +++ b/shaders/moe-layer-fused.glsl @@ -0,0 +1,104 @@ +#version 450 + +// Fused MoE Layer: 4 expert matmuls + weighted blend + residual in ONE dispatch. +// +// Router weights read from scratch buffer (computed by moe-router shader). +// No CPU round-trip between layers. 
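+//
+// The push-constant `weights_offset` points at the router's softmax output
+// inside the shared scratch buffer; the host sets it to match whatever
+// scratch layout the preceding moe-router dispatch produced. With weights
+// (0.7, 0.2, 0.05, 0.05), for instance, each output element is the residual
+// input plus that convex combination of the four expert matmul results.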
+// +// output[row, col] = input[row, col] + sum_e(w[e] * (input[row,:] @ expert_w[e][col,:])) + +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float x[]; +}; + +layout(set = 0, binding = 1) readonly buffer ExpertWeights { + float ew[]; // 4 experts packed: ew[e * d*d + col * d + k] +}; + +layout(set = 0, binding = 2) buffer Output { + float out_data[]; +}; + +// Router scratch buffer: [x_mean (d) | logits (E) | max (1) | sum (1) | weights (E)] +layout(set = 0, binding = 3) readonly buffer RouterScratch { + float rscratch[]; +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + uint weights_offset; // offset in rscratch where softmax weights start +}; + +shared float tileX[16][17]; + +void main() { + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint col = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + if (row >= seq_len || col >= d_model) return; + + // Read router weights from scratch buffer + float w0 = rscratch[weights_offset]; + float w1 = rscratch[weights_offset + 1]; + float w2 = rscratch[weights_offset + 2]; + float w3 = rscratch[weights_offset + 3]; + + float acc0 = 0.0, acc1 = 0.0, acc2 = 0.0, acc3 = 0.0; + + uint numTiles = (d_model + 15) / 16; + uint row_base = row * d_model; + uint dd = d_model * d_model; + + for (uint t = 0; t < numTiles; t++) { + uint k_base = t * 16; + + uint k_a = k_base + tx; + tileX[ty][tx] = (k_a < d_model) ? x[row_base + k_a] : 0.0; + + barrier(); + + uint e0_base = 0 * dd + col * d_model + k_base; + uint e1_base = 1 * dd + col * d_model + k_base; + uint e2_base = 2 * dd + col * d_model + k_base; + uint e3_base = 3 * dd + col * d_model + k_base; + + for (uint k = 0; k < 16; k += 4) { + float x0 = tileX[ty][k]; + float x1 = tileX[ty][k + 1]; + float x2 = tileX[ty][k + 2]; + float x3 = tileX[ty][k + 3]; + + acc0 = fma(x0, ew[e0_base + k], acc0); + acc0 = fma(x1, ew[e0_base + k + 1], acc0); + acc0 = fma(x2, ew[e0_base + k + 2], acc0); + acc0 = fma(x3, ew[e0_base + k + 3], acc0); + + acc1 = fma(x0, ew[e1_base + k], acc1); + acc1 = fma(x1, ew[e1_base + k + 1], acc1); + acc1 = fma(x2, ew[e1_base + k + 2], acc1); + acc1 = fma(x3, ew[e1_base + k + 3], acc1); + + acc2 = fma(x0, ew[e2_base + k], acc2); + acc2 = fma(x1, ew[e2_base + k + 1], acc2); + acc2 = fma(x2, ew[e2_base + k + 2], acc2); + acc2 = fma(x3, ew[e2_base + k + 3], acc2); + + acc3 = fma(x0, ew[e3_base + k], acc3); + acc3 = fma(x1, ew[e3_base + k + 1], acc3); + acc3 = fma(x2, ew[e3_base + k + 2], acc3); + acc3 = fma(x3, ew[e3_base + k + 3], acc3); + } + + barrier(); + } + + float blended = w0 * acc0 + w1 * acc1 + w2 * acc2 + w3 * acc3; + uint idx = row * d_model + col; + out_data[idx] = x[idx] + blended; +} diff --git a/shaders/moe-layer-grad-weight.glsl b/shaders/moe-layer-grad-weight.glsl new file mode 100644 index 0000000..31fddf4 --- /dev/null +++ b/shaders/moe-layer-grad-weight.glsl @@ -0,0 +1,85 @@ +#version 450 + +// Fused MoE grad_W: compute weight gradients for all 4 experts in one dispatch. +// +// grad_W[e][col, k] = router_w[e] * sum_row(grad_out[row, col] * input[row, k]) +// = router_w[e] * (grad_out[:,col].T @ input[:,k]) +// +// Each thread computes one (col, k) element for all 4 experts. +// Tiled: 16x16, reduction over seq_len in blocks of 16. 
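+//
+// Because every expert sees the same input and the same upstream gradient,
+// grad_W differs between experts only by the scalar router weight:
+//
+//   grad_W[e] = router_w[e] * (grad_out^T @ input)
+//
+// so the kernel accumulates the outer product once into a single `acc`
+// and writes it four times, scaled by router_w0..router_w3.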
+ +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer GradOut { + float grad_out[]; // (seq_len, d_model) +}; + +layout(set = 0, binding = 1) readonly buffer Input { + float x[]; // (seq_len, d_model) +}; + +// Output: 4 expert weight gradients packed contiguously +// grad_ew[e * d*d + col * d + k] +layout(set = 0, binding = 2) buffer GradWeights { + float grad_ew[]; // (4 * d_model * d_model) +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + float router_w0; + float router_w1; + float router_w2; + float router_w3; +}; + +shared float tileG[16][17]; // grad_out[:, col] tile +shared float tileX[16][17]; // input[:, k] tile + +void main() { + uint col = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint k = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + if (col >= d_model || k >= d_model) return; + + float acc = 0.0; // shared across experts (same outer product, different scaling) + + uint numTiles = (seq_len + 15) / 16; + + for (uint t = 0; t < numTiles; t++) { + uint row_base = t * 16; + + // Load grad_out column tile: grad_out[row_base+ty, col] + uint r_g = row_base + ty; + tileG[ty][tx] = (r_g < seq_len && col < d_model) + ? grad_out[r_g * d_model + col] : 0.0; + + // Load input column tile: x[row_base+ty, k] + uint r_x = row_base + ty; + tileX[ty][tx] = (r_x < seq_len && k < d_model) + ? x[r_x * d_model + k] : 0.0; + + barrier(); + + // Accumulate outer product: sum_row(grad[row, col] * x[row, k]) + for (uint r = 0; r < 16; r += 4) { + acc = fma(tileG[r][ty], tileX[r][tx], acc); // Note: ty indexes col, tx indexes k + acc = fma(tileG[r+1][ty], tileX[r+1][tx], acc); + acc = fma(tileG[r+2][ty], tileX[r+2][tx], acc); + acc = fma(tileG[r+3][ty], tileX[r+3][tx], acc); + } + + barrier(); + } + + // Write grad_W for all 4 experts (same outer product, scaled by router weight) + uint dd = d_model * d_model; + uint base = col * d_model + k; + grad_ew[0 * dd + base] = router_w0 * acc; + grad_ew[1 * dd + base] = router_w1 * acc; + grad_ew[2 * dd + base] = router_w2 * acc; + grad_ew[3 * dd + base] = router_w3 * acc; +} diff --git a/shaders/moe-router.glsl b/shaders/moe-router.glsl new file mode 100644 index 0000000..60d66e2 --- /dev/null +++ b/shaders/moe-router.glsl @@ -0,0 +1,84 @@ +#version 450 + +// MoE Router: mean(x) → logits → softmax → router weights +// +// Three passes (controlled by push constant `pass`): +// Pass 0: Compute mean(x) across seq_len → x_mean (d_model,) +// Pass 1: Compute logits = router_W @ x_mean + router_b → (n_experts,) +// Pass 2: Stable softmax over logits → router_weights (n_experts,) +// +// Scratch layout: [x_mean (d) | logits (E) | weights (E)] + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float x[]; +}; + +layout(set = 0, binding = 1) readonly buffer RouterW { + float rw[]; +}; + +layout(set = 0, binding = 2) readonly buffer RouterB { + float rb[]; +}; + +layout(set = 0, binding = 3) buffer Scratch { + float scratch[]; +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + uint pass; +}; + +void main() { + uint gid = gl_GlobalInvocationID.x; + + if (pass == 0u) { + // Pass 0: mean reduction + if (gid >= d_model) return; + float sum = 0.0; + for (uint s = 0; s < seq_len; s++) { + sum += x[s * d_model + gid]; + } + scratch[gid] = sum / 
float(seq_len); + } + else if (pass == 1u) { + // Pass 1: logits = rw @ x_mean + rb + if (gid >= n_experts) return; + float dot = 0.0; + for (uint k = 0; k < d_model; k++) { + dot += rw[gid * d_model + k] * scratch[k]; + } + scratch[d_model + gid] = dot + rb[gid]; + } + else if (pass == 2u) { + // Pass 2: numerically stable softmax (single thread, n_experts is small) + if (gid != 0u) return; + uint logits_off = d_model; + uint weights_off = d_model + n_experts; + + // Find max for numerical stability + float max_val = scratch[logits_off]; + for (uint e = 1; e < n_experts; e++) { + max_val = max(max_val, scratch[logits_off + e]); + } + + // Exp with max subtraction + sum + float sum_exp = 0.0; + for (uint e = 0; e < n_experts; e++) { + float ev = exp(scratch[logits_off + e] - max_val); + scratch[weights_off + e] = ev; + sum_exp += ev; + } + + // Normalize + float inv_sum = 1.0 / (sum_exp + 1e-8); + for (uint e = 0; e < n_experts; e++) { + scratch[weights_off + e] *= inv_sum; + } + } +} diff --git a/shaders/prefix-scan-causal-backward.glsl b/shaders/prefix-scan-causal-backward.glsl new file mode 100644 index 0000000..112012c --- /dev/null +++ b/shaders/prefix-scan-causal-backward.glsl @@ -0,0 +1,102 @@ +#version 450 +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_basic : require + +/* + * Causal Linear-RNN Prefix Scan (backward) + * + * Given forward recurrence h_t = a_t * h_{t-1} + x_t, the adjoint rules are: + * + * dx_t = dh_t + a_{t+1} * dx_{t+1} (anti-causal sum) + * da_t = dx_t * h_{t-1} + * + * Using g_t = prod_{k<=t} a_k (forward cumprod): + * + * dx_t = (1 / g_t) * sum_{s>=t} g_s * dh_s + * + * Strategy: + * 1. Compute g_t via ``subgroupInclusiveAdd(log(a))`` + exp. Verified + * correct on partial Wave64 subgroups (see earlier debug dump: + * max abs err 2.98e-08 vs numpy cumprod). + * 2. Compute weighted_dh = g_t * dh_t per thread. + * 3. Store to shared memory. + * 4. Each thread LOOPS sequentially over shared[t..seq_len-1] to compute + * its own right_sum. For seq_len <= 32 this is ~32 iterations which + * is cheap compared to the rest of the pipeline, and it sidesteps + * the ``total`` broadcast trap entirely (neither ``subgroupAdd`` nor + * a ``shared_total`` write from thread 31 produced the correct total + * in earlier attempts — the cause is still unclear but likely some + * AMD partial-Wave64 edge case). + * + * Dispatch: one workgroup per (batch, hidden_dim) pair, one thread per + * time step. Constraint: seq_len <= 32 (matches local_size_x). + */ + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer GradHBuffer { float dH[]; }; +layout(binding = 1) readonly buffer DecayBuffer { float A[]; }; +layout(binding = 2) readonly buffer ForwardHBuffer { float H[]; }; +layout(binding = 3) readonly buffer ForwardXBuffer { float X[]; }; +layout(binding = 4) writeonly buffer GradXBuffer { float dX[]; }; +layout(binding = 5) writeonly buffer GradABuffer { float dA[]; }; + +layout(push_constant) uniform PushConstants { + uint seq_len; + uint hidden_dim; +} params; + +shared float scratch[32]; + +void main() { + uint t = gl_LocalInvocationID.x; + uint d = gl_WorkGroupID.x; + uint b = gl_WorkGroupID.y; + + bool in_bounds = (d < params.hidden_dim); + bool in_seq = in_bounds && (t < params.seq_len); + + uint idx = in_seq + ? (b * params.seq_len * params.hidden_dim) + (t * params.hidden_dim) + d + : 0u; + + // Load forward values (neutral for padding threads). + float a_t = in_seq ? 
max(A[idx], 1e-6) : 1.0; + float dh_t = in_seq ? dH[idx] : 0.0; + float h_t = in_seq ? H[idx] : 0.0; + float x_t = in_seq ? X[idx] : 0.0; + + // ── Forward cumulative product of a: g_t = prod_{k<=t} a_k ── + float log_a = log(a_t); + float cumsum_log_a = subgroupInclusiveAdd(log_a); + float g_t = exp(cumsum_log_a); + + // ── Store weighted_dh to shared memory ── + float weighted_dh = g_t * dh_t; + scratch[t] = weighted_dh; + barrier(); + + // ── Sequential right-sum: right_sum[t] = sum_{s>=t} weighted_dh[s] ── + float right_sum = 0.0; + for (uint s = t; s < params.seq_len; s++) { + right_sum += scratch[s]; + } + + // dx_t = right_sum / g_t + float dx_t = right_sum / g_t; + + // ── da_t = dx_t * h_{t-1}, with h_{t-1} = (h_t - x_t) / a_t ── + // Numerically unstable when a_t is tiny — max(A[idx], 1e-6) clamp + // keeps it away from zero. TODO: save h_{t-1} during forward for + // full stability. + float h_prev = 0.0; + if (in_seq && t > 0u) { + h_prev = (h_t - x_t) / a_t; + } + float da_t = dx_t * h_prev; + + if (in_seq) { + dX[idx] = dx_t; + dA[idx] = da_t; + } +} diff --git a/shaders/prefix-scan-causal.glsl b/shaders/prefix-scan-causal.glsl new file mode 100644 index 0000000..82679ac --- /dev/null +++ b/shaders/prefix-scan-causal.glsl @@ -0,0 +1,85 @@ +#version 450 +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_basic : require + +/* + * Causal Linear-RNN Prefix Scan (forward) + * + * Computes strictly-causal recurrence in parallel: + * h_t = a_t * h_{t-1} + x_t + * where a_t is a per-(batch, seq, dim) decay gate in (0, 1], and h_0 = 0. + * + * Math trick: taking the cumulative product of a_t and scaling x_t by its + * inverse turns the recurrence into two commutative scans — a cumulative + * sum of log(a) for the decay, and a cumulative sum of the scaled input. + * + * cumprod_a[t] = prod_{s<=t} a_s = exp(sum_{s<=t} log(a_s)) + * h_t = cumprod_a[t] * sum_{s<=t} (x_s / cumprod_a[s]) + * + * Both scans are hardware-accelerated by ``subgroupInclusiveAdd``, which + * runs entirely in the subgroup's registers (no shared/global traffic). + * + * Dispatch layout (one subgroup per (batch, hidden_dim) pair): + * local_size_x = subgroup_size (32 on RDNA Wave32, or use 64 for Wave64) + * gl_WorkGroupID.x = hidden_dim index + * gl_WorkGroupID.y = batch index + * local_x = time step t + * + * Buffer layout is (B, S, D) row-major: + * index(b, t, d) = b * S * D + t * D + d + * + * Constraint: sequence length S must be <= subgroup_size. For longer + * sequences the caller should split the sequence into chunks that fit + * in one subgroup, using the tail of each chunk as the carry for the + * next. A multi-subgroup version of this shader (hierarchical scan) is + * a follow-up. + */ + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer InputBuffer { float X[]; }; +layout(binding = 1) readonly buffer DecayBuffer { float A[]; }; +layout(binding = 2) writeonly buffer OutputBuffer { float H[]; }; + +layout(push_constant) uniform PushConstants { + uint seq_len; + uint hidden_dim; +} params; + +void main() { + uint t = gl_LocalInvocationID.x; // time step 0..seq_len-1 + uint d = gl_WorkGroupID.x; // hidden dimension index + uint b = gl_WorkGroupID.y; // batch index + + if (d >= params.hidden_dim) return; + bool in_seq = (t < params.seq_len); + + // Flattened (B, S, D) → linear index. Threads beyond seq_len write to + // dummy index 0; we'll guard the store with ``in_seq``. + uint idx = in_seq + ? 
(b * params.seq_len * params.hidden_dim) + (t * params.hidden_dim) + d + : 0u; + + // Load current time-step values. For padding threads (t >= seq_len), + // fall back to neutral values: a = 1.0 (no-op in cumprod), x = 0. + float a_t = in_seq ? max(A[idx], 1e-6) : 1.0; + float x_t = in_seq ? X[idx] : 0.0; + + // 1. Subgroup inclusive scan of log(a) → cumulative product of a_t + float log_a = log(a_t); + float cumsum_log_a = subgroupInclusiveAdd(log_a); + float cumprod_a = exp(cumsum_log_a); + + // 2. Scale x_t by the inverse cumprod so the next scan is a pure sum + float scaled_x = x_t / cumprod_a; + + // 3. Subgroup inclusive scan of scaled_x + float cumsum_scaled_x = subgroupInclusiveAdd(scaled_x); + + // 4. Undo the scaling to get h_t + float h_t = cumprod_a * cumsum_scaled_x; + + if (in_seq) { + H[idx] = h_t; + } +} diff --git a/shaders/sign-activation.glsl b/shaders/sign-activation.glsl new file mode 100644 index 0000000..ff35efb --- /dev/null +++ b/shaders/sign-activation.glsl @@ -0,0 +1,24 @@ +#version 450 + +// Sign activation: output[i] = (input[i] > 0) ? 1.0 : -1.0 +// Exactly zero maps to -1 (subgradient convention matching backward). + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(push_constant) uniform PushConsts { + uint total_elements; +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + if (gID >= total_elements) return; + output_data[gID] = (input_data[gID] > 0.0) ? 1.0 : -1.0; +} diff --git a/shaders/snippets/addition_linear.glsl b/shaders/snippets/addition_linear.glsl new file mode 100644 index 0000000..b5dd95a --- /dev/null +++ b/shaders/snippets/addition_linear.glsl @@ -0,0 +1,17 @@ +// Snippet: Addition-only linear (L1 distance) +// Computes -||w - x||_1 for a single (output, input) pair +// No multiplications — only add, sub, abs +float op_addition_linear(float w, float x) { + return -abs(w - x); +} + +// Snippet: Sign activation with threshold (ternary output) +float op_sign_activation(float x, float threshold) { + return sign(x - threshold); +} + +// Snippet: Additive receptance (sigmoid approximation) +// sigmoid(x) ≈ clamp(0.5 + 0.25 * x, 0, 1) +float op_additive_sigmoid(float x) { + return clamp(0.5 + 0.25 * x, 0.0, 1.0); +} diff --git a/shaders/snippets/grid_cell.glsl b/shaders/snippets/grid_cell.glsl new file mode 100644 index 0000000..b5eed74 --- /dev/null +++ b/shaders/snippets/grid_cell.glsl @@ -0,0 +1,10 @@ +// Snippet: Grid cell hexagonal 3-wave activation +// Input: position (rx, ry shifted), spacing, wave vector k +// Output: firing rate [0, max_rate] +float op_grid_cell(float rx, float ry, float k, float max_rate) { + float u1 = k * rx; + float u2 = k * (-0.5 * rx + 0.866025 * ry); + float u3 = k * (-0.5 * rx - 0.866025 * ry); + float val = (cos(u1) + cos(u2) + cos(u3)) / 3.0 + 0.5; + return max_rate * max(val, 0.0); +} diff --git a/shaders/softmax-fast.glsl b/shaders/softmax-fast.glsl new file mode 100644 index 0000000..9bb922d --- /dev/null +++ b/shaders/softmax-fast.glsl @@ -0,0 +1,77 @@ +#version 450 +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +// RDNA2-optimized softmax: subgroup reductions + online max rescaling. +// 256 threads per workgroup (8 subgroups of 32 on Wave32). +// Numerically stable — max subtraction built into the reduction. 
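+//
+// Per-subgroup partials (m_i, d_i), with d_i = sum_j exp(x_j - m_i), are merged
+// using the standard online-softmax identity:
+//
+//   m = max(m_a, m_b)
+//   d = d_a * exp(m_a - m) + d_b * exp(m_b - m)
+//
+// which is exactly what the LDS reduction in STEP 2 applies pairwise across the
+// (up to 8) subgroup partials before the final exp(val - m) / d write in STEP 3.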
+// +// Each workgroup processes one row of length <= 256. +// For longer rows, dispatch multiple workgroups per row (requires cross-WG reduction). + +layout(local_size_x = 256) in; + +layout(set = 0, binding = 0) buffer Input { float data[]; } In; +layout(set = 0, binding = 1) buffer Output { float data[]; } Out; + +layout(push_constant) uniform PushConsts { + uint batch_size; // number of rows + uint row_length; // elements per row +}; + +shared float lds_max[8]; +shared float lds_sum[8]; + +void main() { + uint row = gl_WorkGroupID.x; + uint l_idx = gl_LocalInvocationID.x; + uint sg_id = gl_SubgroupInvocationID; + uint sg_idx = gl_SubgroupID; + + if (row >= batch_size) return; + + uint base = row * row_length; + + // Load value (pad with -inf for threads beyond row_length) + float val = (l_idx < row_length) ? In.data[base + l_idx] : -1e38; + + // --- STEP 1: Subgroup-Level Online Reduction --- + float m = subgroupMax(val); + float d = subgroupAdd(exp(val - m)); + + // --- STEP 2: Workgroup-Level Reduction via LDS --- + if (subgroupElect()) { + lds_max[sg_idx] = m; + lds_sum[sg_idx] = d; + } + barrier(); + + // Final reduction by first subgroup + if (sg_idx == 0) { + float final_m = -1e38; + float final_d = 0.0; + + uint n_sgs = (row_length + 31) / 32; + if (n_sgs > 8) n_sgs = 8; + + for (uint i = 0; i < n_sgs; ++i) { + float prev_m = final_m; + final_m = max(final_m, lds_max[i]); + final_d = final_d * exp(prev_m - final_m) + lds_sum[i] * exp(lds_max[i] - final_m); + } + + if (sg_id == 0) { + lds_max[0] = final_m; + lds_sum[0] = final_d; + } + } + barrier(); + + // --- STEP 3: Final Computation --- + float wg_max = lds_max[0]; + float wg_sum = lds_sum[0]; + + if (l_idx < row_length) { + Out.data[base + l_idx] = exp(val - wg_max) / wg_sum; + } +} diff --git a/shaders/spv/activation-gelu-backward.spv b/shaders/spv/activation-gelu-backward.spv index 7885dde..5b636d3 100644 Binary files a/shaders/spv/activation-gelu-backward.spv and b/shaders/spv/activation-gelu-backward.spv differ diff --git a/shaders/spv/activation-gelu.spv b/shaders/spv/activation-gelu.spv index 52c4b34..73fa616 100644 Binary files a/shaders/spv/activation-gelu.spv and b/shaders/spv/activation-gelu.spv differ diff --git a/shaders/spv/addition-linear.spv b/shaders/spv/addition-linear.spv new file mode 100644 index 0000000..d72869f Binary files /dev/null and b/shaders/spv/addition-linear.spv differ diff --git a/shaders/spv/attention-output.spv b/shaders/spv/attention-output.spv index 7f53be4..94f9b58 100644 Binary files a/shaders/spv/attention-output.spv and b/shaders/spv/attention-output.spv differ diff --git a/shaders/spv/conv1x1-backward-weight.spv b/shaders/spv/conv1x1-backward-weight.spv index d0cef87..7be9997 100644 Binary files a/shaders/spv/conv1x1-backward-weight.spv and b/shaders/spv/conv1x1-backward-weight.spv differ diff --git a/shaders/spv/convd_col2im_noatomic.spv b/shaders/spv/convd_col2im_noatomic.spv new file mode 100644 index 0000000..c767fd6 Binary files /dev/null and b/shaders/spv/convd_col2im_noatomic.spv differ diff --git a/shaders/spv/dequant-4bit.spv b/shaders/spv/dequant-4bit.spv new file mode 100644 index 0000000..f700446 Binary files /dev/null and b/shaders/spv/dequant-4bit.spv differ diff --git a/shaders/spv/flash-attention2.spv b/shaders/spv/flash-attention2.spv index f945782..04eb457 100644 Binary files a/shaders/spv/flash-attention2.spv and b/shaders/spv/flash-attention2.spv differ diff --git a/shaders/spv/fnn-linear.spv b/shaders/spv/fnn-linear.spv index 872db0e..7281507 100644 Binary files 
a/shaders/spv/fnn-linear.spv and b/shaders/spv/fnn-linear.spv differ diff --git a/shaders/spv/fused-layernorm-linear.spv b/shaders/spv/fused-layernorm-linear.spv index 4856652..9f124a0 100644 Binary files a/shaders/spv/fused-layernorm-linear.spv and b/shaders/spv/fused-layernorm-linear.spv differ diff --git a/shaders/spv/gemm-bias-add.spv b/shaders/spv/gemm-bias-add.spv new file mode 100644 index 0000000..312cd7b Binary files /dev/null and b/shaders/spv/gemm-bias-add.spv differ diff --git a/shaders/spv/gemm-coopmat-shared.spv b/shaders/spv/gemm-coopmat-shared.spv new file mode 100644 index 0000000..4edd06a Binary files /dev/null and b/shaders/spv/gemm-coopmat-shared.spv differ diff --git a/shaders/spv/gqa-attention.spv b/shaders/spv/gqa-attention.spv new file mode 100644 index 0000000..c736829 Binary files /dev/null and b/shaders/spv/gqa-attention.spv differ diff --git a/shaders/spv/grid-cell.spv b/shaders/spv/grid-cell.spv new file mode 100644 index 0000000..35ce3b9 Binary files /dev/null and b/shaders/spv/grid-cell.spv differ diff --git a/shaders/spv/hmm-baum-welch.spv b/shaders/spv/hmm-baum-welch.spv new file mode 100644 index 0000000..49c0b5c Binary files /dev/null and b/shaders/spv/hmm-baum-welch.spv differ diff --git a/shaders/spv/hmm-forward.spv b/shaders/spv/hmm-forward.spv new file mode 100644 index 0000000..6197d53 Binary files /dev/null and b/shaders/spv/hmm-forward.spv differ diff --git a/shaders/spv/hopfield-surprise.spv b/shaders/spv/hopfield-surprise.spv new file mode 100644 index 0000000..9028b32 Binary files /dev/null and b/shaders/spv/hopfield-surprise.spv differ diff --git a/shaders/spv/int8-gemm.spv b/shaders/spv/int8-gemm.spv new file mode 100644 index 0000000..16a07ce Binary files /dev/null and b/shaders/spv/int8-gemm.spv differ diff --git a/shaders/spv/lstm-cell-forward.spv b/shaders/spv/lstm-cell-forward.spv new file mode 100644 index 0000000..8c2dbb0 Binary files /dev/null and b/shaders/spv/lstm-cell-forward.spv differ diff --git a/shaders/spv/mf-sigmoid.spv b/shaders/spv/mf-sigmoid.spv new file mode 100644 index 0000000..89a180c Binary files /dev/null and b/shaders/spv/mf-sigmoid.spv differ diff --git a/shaders/spv/mf-softmax.spv b/shaders/spv/mf-softmax.spv new file mode 100644 index 0000000..55e6998 Binary files /dev/null and b/shaders/spv/mf-softmax.spv differ diff --git a/shaders/spv/mf-softplus.spv b/shaders/spv/mf-softplus.spv new file mode 100644 index 0000000..70d04b5 Binary files /dev/null and b/shaders/spv/mf-softplus.spv differ diff --git a/shaders/spv/moe-layer-backward-vec4.spv b/shaders/spv/moe-layer-backward-vec4.spv new file mode 100644 index 0000000..bae06ef Binary files /dev/null and b/shaders/spv/moe-layer-backward-vec4.spv differ diff --git a/shaders/spv/moe-layer-backward.spv b/shaders/spv/moe-layer-backward.spv new file mode 100644 index 0000000..0b917e9 Binary files /dev/null and b/shaders/spv/moe-layer-backward.spv differ diff --git a/shaders/spv/moe-layer-fused-vec4.spv b/shaders/spv/moe-layer-fused-vec4.spv new file mode 100644 index 0000000..4014e71 Binary files /dev/null and b/shaders/spv/moe-layer-fused-vec4.spv differ diff --git a/shaders/spv/moe-layer-fused.spv b/shaders/spv/moe-layer-fused.spv new file mode 100644 index 0000000..794ed46 Binary files /dev/null and b/shaders/spv/moe-layer-fused.spv differ diff --git a/shaders/spv/moe-layer-grad-weight.spv b/shaders/spv/moe-layer-grad-weight.spv new file mode 100644 index 0000000..cdca358 Binary files /dev/null and b/shaders/spv/moe-layer-grad-weight.spv differ diff --git 
a/shaders/spv/moe-router.spv b/shaders/spv/moe-router.spv new file mode 100644 index 0000000..fc8851e Binary files /dev/null and b/shaders/spv/moe-router.spv differ diff --git a/shaders/spv/prefix-scan-causal-backward.spv b/shaders/spv/prefix-scan-causal-backward.spv new file mode 100644 index 0000000..e6f043b Binary files /dev/null and b/shaders/spv/prefix-scan-causal-backward.spv differ diff --git a/shaders/spv/prefix-scan-causal.spv b/shaders/spv/prefix-scan-causal.spv new file mode 100644 index 0000000..8b96607 Binary files /dev/null and b/shaders/spv/prefix-scan-causal.spv differ diff --git a/shaders/spv/rms-norm-linear-fused.spv b/shaders/spv/rms-norm-linear-fused.spv new file mode 100644 index 0000000..5e56745 Binary files /dev/null and b/shaders/spv/rms-norm-linear-fused.spv differ diff --git a/shaders/spv/sign-activation.spv b/shaders/spv/sign-activation.spv new file mode 100644 index 0000000..0094305 Binary files /dev/null and b/shaders/spv/sign-activation.spv differ diff --git a/shaders/spv/softmax-fast.spv b/shaders/spv/softmax-fast.spv new file mode 100644 index 0000000..5b5c447 Binary files /dev/null and b/shaders/spv/softmax-fast.spv differ diff --git a/shaders/spv/stdp-learning.spv b/shaders/spv/stdp-learning.spv index ee48f55..9642944 100644 Binary files a/shaders/spv/stdp-learning.spv and b/shaders/spv/stdp-learning.spv differ diff --git a/shaders/spv/surprise-momentum.spv b/shaders/spv/surprise-momentum.spv new file mode 100644 index 0000000..59e92d4 Binary files /dev/null and b/shaders/spv/surprise-momentum.spv differ diff --git a/shaders/spv/surprise-recall-blend.spv b/shaders/spv/surprise-recall-blend.spv new file mode 100644 index 0000000..1640b2c Binary files /dev/null and b/shaders/spv/surprise-recall-blend.spv differ diff --git a/shaders/spv/swiglu-fused.spv b/shaders/spv/swiglu-fused.spv new file mode 100644 index 0000000..9215135 Binary files /dev/null and b/shaders/spv/swiglu-fused.spv differ diff --git a/shaders/spv/synapsis-forward.spv b/shaders/spv/synapsis-forward.spv index 6ff60fa..7ba6c07 100644 Binary files a/shaders/spv/synapsis-forward.spv and b/shaders/spv/synapsis-forward.spv differ diff --git a/shaders/spv/synapsis-stdp-trace.spv b/shaders/spv/synapsis-stdp-trace.spv index 91db67d..564f590 100644 Binary files a/shaders/spv/synapsis-stdp-trace.spv and b/shaders/spv/synapsis-stdp-trace.spv differ diff --git a/shaders/spv/synapsis-stdp-update.spv b/shaders/spv/synapsis-stdp-update.spv index 4a2e5ac..9af643a 100644 Binary files a/shaders/spv/synapsis-stdp-update.spv and b/shaders/spv/synapsis-stdp-update.spv differ diff --git a/shaders/spv/theta-gamma-encoding.spv b/shaders/spv/theta-gamma-encoding.spv index 07bc1cd..e05134c 100644 Binary files a/shaders/spv/theta-gamma-encoding.spv and b/shaders/spv/theta-gamma-encoding.spv differ diff --git a/shaders/spv/time-cell.spv b/shaders/spv/time-cell.spv index eabc71d..7797874 100644 Binary files a/shaders/spv/time-cell.spv and b/shaders/spv/time-cell.spv differ diff --git a/shaders/spv/vsa-explore.spv b/shaders/spv/vsa-explore.spv new file mode 100644 index 0000000..4498227 Binary files /dev/null and b/shaders/spv/vsa-explore.spv differ diff --git a/shaders/spv/vsa-logic-apply.spv b/shaders/spv/vsa-logic-apply.spv index baeba53..a1e2d20 100644 Binary files a/shaders/spv/vsa-logic-apply.spv and b/shaders/spv/vsa-logic-apply.spv differ diff --git a/shaders/spv/whitening-apply.spv b/shaders/spv/whitening-apply.spv index 3cc5be6..18dcb9a 100644 Binary files a/shaders/spv/whitening-apply.spv and 
b/shaders/spv/whitening-apply.spv differ diff --git a/shaders/spv/whitening-batch-stats.spv b/shaders/spv/whitening-batch-stats.spv index fac1ae8..3b00997 100644 Binary files a/shaders/spv/whitening-batch-stats.spv and b/shaders/spv/whitening-batch-stats.spv differ diff --git a/shaders/spv/whitening-transform.spv b/shaders/spv/whitening-transform.spv index 991d436..934f67a 100644 Binary files a/shaders/spv/whitening-transform.spv and b/shaders/spv/whitening-transform.spv differ diff --git a/shaders/vsa-explore.glsl b/shaders/vsa-explore.glsl index 600ea18..bd2c7e6 100644 --- a/shaders/vsa-explore.glsl +++ b/shaders/vsa-explore.glsl @@ -49,9 +49,13 @@ layout(std430, binding = 2) buffer Workspace { int workspace[]; // [n_states * n_states], init to 0 }; -// Output: best transition matrix (float, normalized) -layout(std430, binding = 3) writeonly buffer Output { - float output[]; // [n_states * n_states] +// Output: best transition matrix (float, normalized). +// NOTE: not ``writeonly`` — the shader also reads this buffer as a scratch +// space during exploration (see similarity() below, and the block_bind +// "reusing output buffer temporarily" path). +// ``output`` is a reserved word in recent glslang — renamed to output_data. +layout(std430, binding = 3) buffer Output { + float output_data[]; // [n_states * n_states] }; // ── Per-block circular convolution (bind) ──────────────────── @@ -66,7 +70,7 @@ void block_bind(uint block_offset_a, uint block_offset_b, sum += obs[block_offset_a + m] * obs[block_offset_b + idx_b]; } // Store in shared memory (reusing output buffer temporarily) - output[block_offset_out + i] = sum; + output_data[block_offset_out + i] = sum; } } @@ -74,7 +78,7 @@ void block_bind(uint block_offset_a, uint block_offset_b, float similarity(uint vec_offset, uint cb_entry, uint dim) { float dot = 0.0; for (uint i = 0; i < dim; i++) { - dot += output[vec_offset + i] * codebook[cb_entry * dim + i]; + dot += output_data[vec_offset + i] * codebook[cb_entry * dim + i]; } return dot / float(k); // Normalize by number of blocks } diff --git a/tests/_tokenizer_parity_helpers.py b/tests/_tokenizer_parity_helpers.py new file mode 100644 index 0000000..c3667a3 --- /dev/null +++ b/tests/_tokenizer_parity_helpers.py @@ -0,0 +1,68 @@ +"""Shared helpers for tokenizer/SentencePiece parity tests.""" + +from __future__ import annotations + +import pytest + + +def load_hf_tokenizer(model_id: str): + """Load Hugging Face tokenizer or skip with context.""" + transformers = pytest.importorskip("transformers") + try: + return transformers.AutoTokenizer.from_pretrained(model_id) + except Exception as exc: # pragma: no cover - depends on environment/cache + pytest.skip(f"Hugging Face tokenizer unavailable for {model_id}: {exc}") + + +def load_grilly_tokenizer(model_id: str): + """Load Grilly tokenizer through likely in-progress API shapes or skip.""" + try: + from grilly import tokenizers as grilly_tokenizers + except Exception as exc: + pytest.skip(f"grilly.tokenizers not available yet: {exc}") + + try: + if hasattr(grilly_tokenizers, "Tokenizer"): + tok_cls = getattr(grilly_tokenizers, "Tokenizer") + if hasattr(tok_cls, "from_pretrained"): + return tok_cls.from_pretrained(model_id) + if hasattr(grilly_tokenizers, "AutoTokenizer"): + auto_cls = getattr(grilly_tokenizers, "AutoTokenizer") + if hasattr(auto_cls, "from_pretrained"): + return auto_cls.from_pretrained(model_id) + if hasattr(grilly_tokenizers, "from_pretrained"): + return grilly_tokenizers.from_pretrained(model_id) + except Exception as 
exc: + pytest.skip(f"Grilly tokenizer could not load {model_id}: {exc}") + + pytest.skip("No supported tokenizer loading API found in grilly.tokenizers") + + +def extract_input_ids(encoded): + """Normalize encoder output to plain input IDs.""" + if isinstance(encoded, dict): + ids = encoded.get("input_ids") + if ids is None: + raise AssertionError("Encoded dict missing 'input_ids'") + return ids + if hasattr(encoded, "input_ids"): + ids = encoded.input_ids + if hasattr(ids, "tolist"): + return ids.tolist() + return ids + return encoded + + +def encode_ids(tokenizer, text: str, add_special_tokens: bool = True): + """Encode text with a flexible tokenizer call surface.""" + if hasattr(tokenizer, "encode"): + try: + return extract_input_ids( + tokenizer.encode(text, add_special_tokens=add_special_tokens) + ) + except TypeError: + return extract_input_ids(tokenizer.encode(text)) + if hasattr(tokenizer, "__call__"): + return extract_input_ids(tokenizer(text, add_special_tokens=add_special_tokens)) + raise AssertionError("Tokenizer has neither encode() nor __call__()") + diff --git a/tests/autograd_chain/test_autograd_chain_placeholder.py b/tests/autograd_chain/test_autograd_chain_placeholder.py new file mode 100644 index 0000000..8e84a33 --- /dev/null +++ b/tests/autograd_chain/test_autograd_chain_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: autograd chain-recorder execution test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement autograd chain-recorder tests in pre-v1.0 roadmap." +) + + +def test_autograd_chain_scaffold(): + """Placeholder test to reserve CI target.""" + assert True + diff --git a/tests/conftest.py b/tests/conftest.py index 093d482..3594913 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,16 +17,26 @@ def pytest_configure(config): config.addinivalue_line( "markers", "cpp: marks tests that require C++ backend (deselect with '-m \"not cpp\"')" ) + config.addinivalue_line( + "markers", "parity: marks numerical parity tests (numpy / optional PyTorch reference)" + ) + config.addinivalue_line( + "markers", "slow: marks tests that are slow or heavy (deselect with '-m \"not slow\"')" + ) try: import grilly - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import ( + VULKAN_AVAILABLE, + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, + ) GRILLY_AVAILABLE = True except (ImportError, AttributeError, Exception): GRILLY_AVAILABLE = False VULKAN_AVAILABLE = False + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE = False # C++ backend (grilly_core with NN framework classes) try: @@ -40,8 +50,10 @@ def pytest_configure(config): @pytest.fixture def gpu_backend(): """Fixture for GPU backend (skips if not available)""" - if not VULKAN_AVAILABLE: - pytest.skip("Vulkan not available") + if not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE: + pytest.skip( + "Vulkan Compute() not available (needs C++ grilly_core GPU + pip install vulkan)" + ) try: from grilly import Compute diff --git a/tests/converter/test_pytorch_converter_placeholder.py b/tests/converter/test_pytorch_converter_placeholder.py new file mode 100644 index 0000000..94ac60b --- /dev/null +++ b/tests/converter/test_pytorch_converter_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: PyTorch -> Grilly converter test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement converter tests in pre-v1.0 roadmap." 
+)
+
+
+def test_pytorch_converter_scaffold():
+    """Placeholder test to reserve CI target."""
+    assert True
+
diff --git a/tests/moe_quant/test_moe_quant_placeholder.py b/tests/moe_quant/test_moe_quant_placeholder.py
new file mode 100644
index 0000000..8cce829
--- /dev/null
+++ b/tests/moe_quant/test_moe_quant_placeholder.py
@@ -0,0 +1,13 @@
+"""Pre-v1.0 scaffold: MoE quantized expert test suite placeholder."""
+
+import pytest
+
+pytestmark = pytest.mark.skip(
+    reason="Scaffold only: implement MoE quantization tests in pre-v1.0 roadmap."
+)
+
+
+def test_moe_quant_scaffold():
+    """Placeholder test to reserve CI target."""
+    assert True
+
diff --git a/tests/parity/README.md b/tests/parity/README.md
new file mode 100644
index 0000000..cf4b350
--- /dev/null
+++ b/tests/parity/README.md
@@ -0,0 +1,34 @@
+# Numerical parity tests (PyTorch reference)
+
+This directory holds tests that compare Grilly outputs to **numpy references** and, when
+`torch` is installed, to **`torch.nn.functional`** equivalents.
+
+## Running
+
+```bash
+# Core parity (numpy reference only; no optional deps)
+uv run pytest tests/parity/ -v
+
+# With torch installed (pip install "grilly[torch]" or torch), the same command also runs the PyTorch cross-checks; they skip otherwise
+uv run pytest tests/parity/ -v
+```
+
+## Conventions
+
+- **Tolerances**: `rtol=1e-4`, `atol=1e-5` for float32 unless an op documents a looser policy.
+- **Weight layout**: `grilly.functional.linear` uses `weight` shaped `(out_features, in_features)`, matching `F.linear` / `nn.Linear.weight`.
+- **Markers**: tests are tagged `parity` (see root `tests/conftest.py`).
+
+## Optimizers
+
+`test_optimizers_parity.py` compares **SGD**, **Adam**, and **AdamW** (CPU path, `use_gpu=False`) against
+`torch.optim` on a single tensor when `torch` is installed.
+
+## Dispatch / batching notes
+
+See `docs/PERF_DISPATCH.md` for `VulkanCompute.record_commands`, async dispatch, and Sequential fusion.
+
+## Roadmap
+
+See `docs/PYTORCH_PARITY_TASKLIST.md` (workstream A1). Planned additions: small CNN/MLP modules,
+transformer encoder blocks, and a summarized pass/fail table in CI.
diff --git a/tests/parity/test_functional_parity.py b/tests/parity/test_functional_parity.py
new file mode 100644
index 0000000..6616040
--- /dev/null
+++ b/tests/parity/test_functional_parity.py
@@ -0,0 +1,95 @@
+"""
+Numerical parity: `grilly.functional` vs numpy and (optional) PyTorch references.
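+
+Tolerances follow `tests/parity/README.md`: `rtol=1e-4`, `atol=1e-5` for float32.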
+""" + +import numpy as np +import pytest + + +def _ref_linear(x: np.ndarray, weight: np.ndarray, bias: np.ndarray | None) -> np.ndarray: + y = x @ weight.T + if bias is not None: + y = y + bias + return y.astype(np.float32, copy=False) + + +@pytest.mark.parity +def test_linear_matches_numpy_reference(): + from grilly.functional import linear + + rng = np.random.default_rng(0) + x = rng.standard_normal((8, 32)).astype(np.float32) + w = rng.standard_normal((16, 32)).astype(np.float32) + b = rng.standard_normal((16,)).astype(np.float32) + expected = _ref_linear(x, w, b) + got = linear(x, w, b) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_linear_matches_numpy_no_bias(): + from grilly.functional import linear + + rng = np.random.default_rng(1) + x = rng.standard_normal((4, 64)).astype(np.float32) + w = rng.standard_normal((24, 64)).astype(np.float32) + expected = _ref_linear(x, w, None) + got = linear(x, w, None) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_relu_matches_numpy_reference(): + from grilly.functional import relu + + rng = np.random.default_rng(2) + x = rng.standard_normal((5, 17)).astype(np.float32) + expected = np.maximum(0.0, x).astype(np.float32) + got = relu(x) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_linear_relu_chain_matches_numpy_reference(): + from grilly.functional import linear, relu + + rng = np.random.default_rng(3) + x = rng.standard_normal((6, 20)).astype(np.float32) + w = rng.standard_normal((12, 20)).astype(np.float32) + b = rng.standard_normal((12,)).astype(np.float32) + expected = np.maximum(0.0, _ref_linear(x, w, b)).astype(np.float32) + got = relu(linear(x, w, b)) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_linear_matches_torch_functional(): + torch = pytest.importorskip("torch") + import torch.nn.functional as F + from grilly.functional import linear + + rng = np.random.default_rng(4) + x = rng.standard_normal((8, 32)).astype(np.float32) + w = rng.standard_normal((16, 32)).astype(np.float32) + b = rng.standard_normal((16,)).astype(np.float32) + + xt = torch.from_numpy(x) + wt = torch.from_numpy(w) + bt = torch.from_numpy(b) + expected = F.linear(xt, wt, bt).detach().numpy().astype(np.float32) + got = linear(x, w, b) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_relu_matches_torch(): + torch = pytest.importorskip("torch") + import torch.nn.functional as F + from grilly.functional import relu + + rng = np.random.default_rng(5) + x = rng.standard_normal((5, 17)).astype(np.float32) + xt = torch.from_numpy(x) + expected = F.relu(xt).detach().numpy().astype(np.float32) + got = relu(x) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) diff --git a/tests/parity/test_optimizers_parity.py b/tests/parity/test_optimizers_parity.py new file mode 100644 index 0000000..dfc29a8 --- /dev/null +++ b/tests/parity/test_optimizers_parity.py @@ -0,0 +1,94 @@ +""" +Optimizer stepping parity vs PyTorch (optional `torch`). 
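+
+All Grilly optimizers below are stepped with `use_gpu=False`, so the CPU update
+path is what gets compared against `torch.optim`.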
+""" + +import numpy as np +import pytest + + +@pytest.mark.parity +def test_sgd_no_momentum_matches_torch(): + torch = pytest.importorskip("torch") + import torch.optim as optim + from grilly.nn import Parameter + from grilly.optim import SGD + + np.random.seed(7) + w = np.random.randn(5, 4).astype(np.float32) + g = np.random.randn(5, 4).astype(np.float32) + lr = 0.01 + + wt = torch.from_numpy(w.copy()) + wt.requires_grad_(True) + wt.grad = torch.from_numpy(g.copy()) + topt = optim.SGD([wt], lr=lr) + topt.step() + torch_result = wt.detach().numpy() + + p = Parameter(w.copy(), requires_grad=True) + p.grad = g.copy() + gopt = SGD([p], lr=lr, momentum=0.0, use_gpu=False) + gopt.step() + + np.testing.assert_allclose(np.asarray(p, dtype=np.float32), torch_result, rtol=1e-5, atol=1e-6) + + +@pytest.mark.parity +def test_adam_cpu_matches_torch_single_tensor(): + torch = pytest.importorskip("torch") + import torch.optim as optim + from grilly.nn import Parameter + from grilly.optim import Adam + + np.random.seed(8) + w = np.random.randn(3, 5).astype(np.float32) + g = np.random.randn(3, 5).astype(np.float32) + lr = 1e-3 + betas = (0.9, 0.999) + eps = 1e-8 + + wt = torch.from_numpy(w.copy()) + wt.requires_grad_(True) + wt.grad = torch.from_numpy(g.copy()) + topt = optim.Adam([wt], lr=lr, betas=betas, eps=eps) + topt.step() + torch_result = wt.detach().numpy() + + p = Parameter(w.copy(), requires_grad=True) + p.grad = g.copy() + gopt = Adam([p], lr=lr, betas=betas, eps=eps, use_gpu=False) + gopt.step() + + np.testing.assert_allclose(np.asarray(p, dtype=np.float32), torch_result, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_adamw_cpu_matches_torch_single_tensor(): + torch = pytest.importorskip("torch") + import torch.optim as optim + from grilly.nn import Parameter + from grilly.optim import AdamW + + np.random.seed(9) + w = np.random.randn(4, 6).astype(np.float32) + g = np.random.randn(4, 6).astype(np.float32) + lr = 1e-3 + betas = (0.9, 0.999) + eps = 1e-8 + weight_decay = 0.01 + + wt = torch.from_numpy(w.copy()) + wt.requires_grad_(True) + wt.grad = torch.from_numpy(g.copy()) + topt = optim.AdamW([wt], lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + topt.step() + torch_result = wt.detach().numpy() + + p = Parameter(w.copy(), requires_grad=True) + p.grad = g.copy() + gopt = AdamW( + [p], lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, use_gpu=False + ) + gopt.step() + + np.testing.assert_allclose(np.asarray(p, dtype=np.float32), torch_result, rtol=1e-4, atol=1e-5) diff --git a/tests/sentence_transformers/test_sentence_transformers_parity.py b/tests/sentence_transformers/test_sentence_transformers_parity.py new file mode 100644 index 0000000..a04688b --- /dev/null +++ b/tests/sentence_transformers/test_sentence_transformers_parity.py @@ -0,0 +1,79 @@ +"""Sentence-transformers adaptive parity tests against reference implementation.""" + +from __future__ import annotations + +import numpy as np +import pytest + +sentence_transformers = pytest.importorskip("sentence_transformers") + + +def _load_reference_model(model_name: str): + try: + return sentence_transformers.SentenceTransformer(model_name, device="cpu") + except Exception as exc: # pragma: no cover - env/model cache dependent + pytest.skip(f"Reference sentence-transformers model unavailable: {exc}") + + +def _load_grilly_model(model_name: str): + try: + from grilly.utils.vulkan_sentence_transformer import VulkanSentenceTransformer + except Exception as exc: + pytest.skip(f"Grilly Vulkan sentence transformer 
unavailable: {exc}") + + try: + return VulkanSentenceTransformer(model_name=model_name) + except Exception as exc: # pragma: no cover - depends on Vulkan/model assets + pytest.skip(f"Could not initialize VulkanSentenceTransformer: {exc}") + + +def _as_float32(x): + arr = np.asarray(x) + if arr.dtype != np.float32: + arr = arr.astype(np.float32) + return arr + + +@pytest.mark.gpu +def test_sentence_transformer_encode_shape_dtype_parity(): + """Embedding shape/dtype should align with reference model outputs.""" + model_name = "all-MiniLM-L6-v2" + texts = [ + "Grilly runs on Vulkan across vendors.", + "Tokenizer and encoder should stay GPU-resident.", + ] + + ref = _load_reference_model(model_name) + gr = _load_grilly_model(model_name) + + ref_emb = _as_float32(ref.encode(texts, normalize_embeddings=True)) + gr_emb = _as_float32(gr.encode(texts, normalize_embeddings=True)) + + assert ref_emb.shape == gr_emb.shape + assert gr_emb.dtype == np.float32 + assert np.all(np.isfinite(gr_emb)) + + +@pytest.mark.gpu +def test_sentence_transformer_similarity_top1_matches_reference(): + """Top-1 nearest sentence should match reference similarity ranking.""" + model_name = "all-MiniLM-L6-v2" + texts = [ + "How do I speed up GPU dispatch batching?", + "Use one submit for many kernels with command recording.", + "Bananas are yellow and sweet.", + "Fused loss kernels reduce synchronization overhead.", + ] + + ref = _load_reference_model(model_name) + gr = _load_grilly_model(model_name) + + ref_emb = _as_float32(ref.encode(texts, normalize_embeddings=True)) + gr_emb = _as_float32(gr.encode(texts, normalize_embeddings=True)) + + # Query = first sentence, compare against candidates [1:] + ref_sims = ref_emb[1:] @ ref_emb[0] + gr_sims = gr_emb[1:] @ gr_emb[0] + + assert int(np.argmax(gr_sims)) == int(np.argmax(ref_sims)) + diff --git a/tests/sentencepiece/test_sentencepiece_parity.py b/tests/sentencepiece/test_sentencepiece_parity.py new file mode 100644 index 0000000..e179213 --- /dev/null +++ b/tests/sentencepiece/test_sentencepiece_parity.py @@ -0,0 +1,50 @@ +"""SentencePiece compatibility tests (adaptive to in-progress tokenizer API).""" + +import pytest + +sentencepiece = pytest.importorskip("sentencepiece") +from tests._tokenizer_parity_helpers import ( + encode_ids, + extract_input_ids, + load_grilly_tokenizer, + load_hf_tokenizer, +) + + +@pytest.mark.parametrize( + "model_id,text", + [ + ("t5-small", "Translate English to German: A tiny test sentence."), + ("google/mt5-small", "A multilingual sentencepiece parity check."), + ], +) +def test_sentencepiece_ids_match_hf_reference(model_id: str, text: str): + """SentencePiece-backed token IDs should match HF reference on supported models.""" + # Keep explicit import usage for environment signaling. + assert sentencepiece.__name__ == "sentencepiece" + + hf_tok = load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_ids = extract_input_ids(hf_tok(text)) + gr_ids = encode_ids(gr_tok, text) + + assert list(gr_ids) == list(hf_ids) + + +def test_sentencepiece_special_tokens_alignment(): + """Special token insertion should align with HF on canonical T5 tokenizer.""" + model_id = "t5-small" + text = "summarize: Grilly targets Vulkan GPUs." 
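+    # T5 tokenizers append the `</s>` (eos) id when add_special_tokens=True, so
+    # the two encodings are expected to differ only in special tokens.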
+ + hf_tok = load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_with = extract_input_ids(hf_tok(text, add_special_tokens=True)) + hf_without = extract_input_ids(hf_tok(text, add_special_tokens=False)) + gr_with = encode_ids(gr_tok, text, add_special_tokens=True) + gr_without = encode_ids(gr_tok, text, add_special_tokens=False) + + assert list(gr_with) == list(hf_with) + assert list(gr_without) == list(hf_without) + diff --git a/tests/test_attention.py b/tests/test_attention.py index c86372d..77ffe93 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestAttentionOperations: """Test attention operations on GPU""" @@ -60,3 +60,19 @@ def test_attention_output(self, gpu): # Output shape: (batch, seq_len, num_heads, head_dim) assert output.shape == (batch_size, seq_len, num_heads, head_dim) assert np.all(np.isfinite(output)) + + def test_flash_attention2(self, gpu): + """Flash Attention 2 path (batched command buffer).""" + batch_size = 1 + seq_len = 16 + num_heads = 4 + head_dim = 32 + q = np.random.randn(batch_size, seq_len, num_heads * head_dim).astype(np.float32) + k = np.random.randn(batch_size, seq_len, num_heads * head_dim).astype(np.float32) + v = np.random.randn(batch_size, seq_len, num_heads * head_dim).astype(np.float32) + + out = gpu.flash_attention2( + q, k, v, num_heads, head_dim, tile_size_q=16, tile_size_k=16 + ) + assert out.shape == (batch_size, seq_len, num_heads, head_dim) + assert np.all(np.isfinite(out)) diff --git a/tests/test_attention_long_sequence.py b/tests/test_attention_long_sequence.py new file mode 100644 index 0000000..94207fb --- /dev/null +++ b/tests/test_attention_long_sequence.py @@ -0,0 +1,74 @@ +""" +GPU Flash Attention 2 at medium/long sequence lengths (Workstream C3). + +Smoke tests: output shape, finiteness. Tight parity vs a reference attention +implementation is tracked separately (FA2 uses online softmax; paths may diverge). 
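+
+`tile_size_q` / `tile_size_k` control the Q/K tile sizes used by the FA2 kernel;
+the values used below divide the tested sequence lengths evenly.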
+""" + +import numpy as np +import pytest + +try: + from grilly import Compute + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE +except ImportError: + pytest.skip("grilly not available", allow_module_level=True) + + +@pytest.mark.gpu +@pytest.mark.parametrize("seq_len", [128, 256, 512]) +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_flash_attention2_long_sequence_finite(seq_len): + backend = Compute() + try: + batch_size = 1 + num_heads = 2 + head_dim = 32 + rng = np.random.default_rng(42 + seq_len) + q = rng.standard_normal((batch_size, seq_len, num_heads * head_dim), dtype=np.float32) + k_arr = rng.standard_normal((batch_size, seq_len, num_heads * head_dim), dtype=np.float32) + v_arr = rng.standard_normal((batch_size, seq_len, num_heads * head_dim), dtype=np.float32) + + out = backend.flash_attention2( + q, + k_arr, + v_arr, + num_heads, + head_dim, + tile_size_q=32, + tile_size_k=32, + ) + assert out.shape == (batch_size, seq_len, num_heads, head_dim) + assert np.all(np.isfinite(out)) + finally: + backend.cleanup() + + +@pytest.mark.gpu +@pytest.mark.slow +@pytest.mark.parametrize("seq_len", [1024]) +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_flash_attention2_very_long_sequence_finite(seq_len): + backend = Compute() + try: + batch_size = 1 + num_heads = 2 + head_dim = 32 + rng = np.random.default_rng(7) + q = rng.standard_normal((batch_size, seq_len, num_heads * head_dim), dtype=np.float32) + k_arr = rng.standard_normal((batch_size, seq_len, num_heads * head_dim), dtype=np.float32) + v_arr = rng.standard_normal((batch_size, seq_len, num_heads * head_dim), dtype=np.float32) + + out = backend.flash_attention2( + q, + k_arr, + v_arr, + num_heads, + head_dim, + tile_size_q=64, + tile_size_k=64, + ) + assert out.shape == (batch_size, seq_len, num_heads, head_dim) + assert np.all(np.isfinite(out)) + finally: + backend.cleanup() diff --git a/tests/test_backward.py b/tests/test_backward.py index 1e1daf3..306f51f 100644 --- a/tests/test_backward.py +++ b/tests/test_backward.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestBackwardOperations: """Test backward pass operations on GPU""" @@ -181,7 +181,7 @@ def test_softmax_backward(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestAutogradIntegration: """Test autograd integration with nn.Module""" @@ -240,7 +240,7 @@ def test_training_context(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestEndToEndTraining: """Test end-to-end training loop""" diff --git a/tests/test_bridge_moe.py b/tests/test_bridge_moe.py new file mode 100644 index 0000000..62b1ca3 --- /dev/null +++ b/tests/test_bridge_moe.py @@ -0,0 +1,14 @@ +"""Bridge exposes fused MoE alongside other grilly_core ops (lazy device + shaders).""" + 
+from grilly.backend import _bridge + + +def test_bridge_moe_functions_exported(): + for name in ( + "moe_upload", + "moe_forward", + "moe_backward", + "moe_update_weights", + "moe_release", + ): + assert hasattr(_bridge, name), f"missing _bridge.{name}" diff --git a/tests/test_bridge_strict_mode.py b/tests/test_bridge_strict_mode.py new file mode 100644 index 0000000..f7fac12 --- /dev/null +++ b/tests/test_bridge_strict_mode.py @@ -0,0 +1,43 @@ +import numpy as np + + +def test_bridge_strict_mode_raises(monkeypatch): + from grilly.backend import _bridge + + class DummyCore: + @staticmethod + def linear(dev, x, weight, bias): + raise RuntimeError("forced failure") + + monkeypatch.setattr(_bridge, "_core", DummyCore()) + monkeypatch.setattr(_bridge, "_get_device", lambda: object()) + monkeypatch.setattr(_bridge, "_BRIDGE_STRICT", True) + _bridge.reset_fallback_stats() + + x = np.ones((2, 4), dtype=np.float32) + w = np.ones((3, 4), dtype=np.float32) + try: + _bridge.linear(x, w, None) + assert False, "Expected strict bridge failure to raise" + except RuntimeError as e: + assert "GRILLY_BRIDGE_STRICT=1" in str(e) + + +def test_bridge_fallback_stats_increment(monkeypatch): + from grilly.backend import _bridge + + class DummyCore: + @staticmethod + def relu(dev, x): + raise RuntimeError("forced failure") + + monkeypatch.setattr(_bridge, "_core", DummyCore()) + monkeypatch.setattr(_bridge, "_get_device", lambda: object()) + monkeypatch.setattr(_bridge, "_BRIDGE_STRICT", False) + _bridge.reset_fallback_stats() + + x = np.array([-1.0, 2.0], dtype=np.float32) + out = _bridge.relu(x) + assert out is None + stats = _bridge.get_fallback_stats() + assert stats.get("relu", 0) == 1 diff --git a/tests/test_bridge_vsa_lm.py b/tests/test_bridge_vsa_lm.py new file mode 100644 index 0000000..8ac6bbc --- /dev/null +++ b/tests/test_bridge_vsa_lm.py @@ -0,0 +1,14 @@ +"""Bridge exposes fused VSA-LM alongside other grilly_core ops (lazy device + shaders).""" + +from grilly.backend import _bridge + + +def test_bridge_vsa_lm_functions_exported(): + for name in ( + "vsa_lm_upload", + "vsa_lm_forward", + "vsa_lm_backward", + "vsa_lm_update_weights", + "vsa_lm_release", + ): + assert hasattr(_bridge, name), f"missing _bridge.{name}" diff --git a/tests/test_conv_backward_weight_gemm.py b/tests/test_conv_backward_weight_gemm.py new file mode 100644 index 0000000..bca9071 --- /dev/null +++ b/tests/test_conv_backward_weight_gemm.py @@ -0,0 +1,67 @@ +""" +Reference check: conv2d backward weight (GEMM path) vs PyTorch (Workstream C1). + +When `convd_im2col` + `gemm_mnk` + `tensor-transpose` are loaded, weight gradients +should match PyTorch within tolerance. 
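+
+(In the im2col formulation, dW reduces to a single matmul of the reshaped
+grad_output against the im2col matrix of the input; this test only checks the
+resulting values, not the intermediate layout.)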
+""" + +import warnings + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") +import torch.nn.functional as F + +try: + from grilly.backend.compute import VulkanCompute + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE +except ImportError: + pytest.skip("grilly not available", allow_module_level=True) + + +@pytest.mark.gpu +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_conv2d_backward_weight_matches_torch_gemm_path(): + """Small conv: dW from Grilly GEMM path vs torch autograd.""" + np.random.seed(123) + torch.manual_seed(123) + + batch_size, in_ch, out_ch = 2, 4, 6 + h, w = 8, 8 + kh, kw = 3, 3 + stride = (1, 1) + padding = (1, 1) + + x_np = np.random.randn(batch_size, in_ch, h, w).astype(np.float32) + w_np = np.random.randn(out_ch, in_ch, kh, kw).astype(np.float32) + b_np = np.random.randn(out_ch).astype(np.float32) + + xt = torch.from_numpy(x_np).requires_grad_(True) + wt = torch.from_numpy(w_np).requires_grad_(True) + bt = torch.from_numpy(b_np).requires_grad_(True) + + y_t = F.conv2d(xt, wt, bt, stride=stride, padding=padding) + go = np.random.randn(*y_t.shape).astype(np.float32) + y_t.backward(torch.from_numpy(go)) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "convd_im2col" not in backend.conv.shaders: + pytest.skip("convd_im2col not available") + gw, gb = backend.conv.conv2d_backward_weight( + go, + x_np, + (kh, kw), + stride=stride, + padding=padding, + dilation=(1, 1), + groups=1, + has_bias=True, + ) + np.testing.assert_allclose(gw, wt.grad.numpy(), rtol=5e-3, atol=5e-4) + np.testing.assert_allclose(gb, bt.grad.numpy(), rtol=5e-3, atol=5e-4) + finally: + backend.cleanup() diff --git a/tests/test_core.py b/tests/test_core.py index 597b5a8..74b08d3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -7,12 +7,13 @@ try: import grilly - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_AVAILABLE, VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE GRILLY_AVAILABLE = True except ImportError: GRILLY_AVAILABLE = False VULKAN_AVAILABLE = False + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE = False class TestGrillyImports: @@ -49,7 +50,9 @@ def test_vulkan_available_flag(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif( + not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available" +) class TestComputeInitialization: """Test Compute backend initialization""" diff --git a/tests/test_fnn_chain.py b/tests/test_fnn_chain.py new file mode 100644 index 0000000..a1cdfc2 --- /dev/null +++ b/tests/test_fnn_chain.py @@ -0,0 +1,75 @@ +"""FnnChainRecorder: many dispatches, one submit (GPU).""" + +import warnings + +import numpy as np +import pytest + +try: + from grilly.backend.compute import VulkanCompute + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE + from grilly.backend.fnn_chain import ChainBufferHandle +except ImportError: + pytest.skip("grilly not available", allow_module_level=True) + + +@pytest.mark.gpu +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_fnn_chain_linear_relu_read(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "fnn-linear" not in backend.fnn.shaders or "activation-relu" not in backend.fnn.shaders: + 
pytest.skip("fnn-linear or activation-relu not available") + + rng = np.random.default_rng(0) + x = rng.standard_normal((3, 8), dtype=np.float32) + w1 = rng.standard_normal((16, 8), dtype=np.float32) + b1 = rng.standard_normal((16,), dtype=np.float32) + w2 = rng.standard_normal((4, 16), dtype=np.float32) + + with backend.record_commands(fnn_chain=True) as rec: + h = rec.linear(x, w1, b1) + assert isinstance(h, ChainBufferHandle) + h = rec.relu(h) + h = rec.linear(h, w2, None) + out = rec.read(h) + + ref = np.maximum(0, x @ w1.T + b1) @ w2.T + np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4) + finally: + backend.cleanup() + + +@pytest.mark.gpu +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_fnn_chain_read_multiple_moe_fanout(): + """Parallel expert linears: one submit, multiple downloads (MoE pattern).""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "fnn-linear" not in backend.fnn.shaders: + pytest.skip("fnn-linear not available") + + rng = np.random.default_rng(1) + x = rng.standard_normal((2, 8), dtype=np.float32) + w0 = rng.standard_normal((4, 8), dtype=np.float32) + w1 = rng.standard_normal((4, 8), dtype=np.float32) + w2 = rng.standard_normal((4, 8), dtype=np.float32) + w3 = rng.standard_normal((4, 8), dtype=np.float32) + + with backend.record_commands(fnn_chain=True) as rec: + h0 = rec.linear(x, w0, None) + h1 = rec.linear(x, w1, None) + h2 = rec.linear(x, w2, None) + h3 = rec.linear(x, w3, None) + results = rec.read_multiple([h0, h1, h2, h3]) + + assert len(results) == 4 + refs = [x @ w0.T, x @ w1.T, x @ w2.T, x @ w3.T] + for got, ref in zip(results, refs): + np.testing.assert_allclose(got, ref, rtol=1e-4, atol=1e-4) + finally: + backend.cleanup() diff --git a/tests/test_functional.py b/tests/test_functional.py index a50b7a9..4facaeb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -5,14 +5,14 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from grilly.functional import dropout, linear except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFunctionalLinear: """Tests for functional.linear.""" @@ -40,7 +40,7 @@ def test_linear_no_bias(self, backend): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFunctionalDropout: """Tests for functional.dropout.""" diff --git a/tests/test_gpu_operations.py b/tests/test_gpu_operations.py index 8007a45..cdabefb 100644 --- a/tests/test_gpu_operations.py +++ b/tests/test_gpu_operations.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestLIFOperations: """Test LIF neuron operations on GPU""" @@ -82,7 +82,7 @@ def 
test_lif_refractory_period(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFNNOperations: """Test FNN operations on GPU""" @@ -202,7 +202,7 @@ def test_layernorm(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFAISSOperations: """Test FAISS operations on GPU""" diff --git a/tests/test_grl_checkpoint.py b/tests/test_grl_checkpoint.py new file mode 100644 index 0000000..f1fd8d2 --- /dev/null +++ b/tests/test_grl_checkpoint.py @@ -0,0 +1,61 @@ +"""GRL (.grl) checkpoint format — roundtrip and magic/version checks.""" + +from __future__ import annotations + +import tempfile + +import numpy as np +import pytest + +from grilly.utils.grl_checkpoint import ( + FORMAT_VERSION, + HEADER_SIZE, + MAGIC, + load_grl, + save_grl, +) + + +def test_grl_roundtrip_nested_state(): + state = { + "embed": {"weight": np.random.randn(10, 8).astype(np.float32)}, + "layers": { + "0": {"w": np.ones((4, 4), dtype=np.float32)}, + }, + } + meta = {"training_step": 42, "best_ppl": 3.14} + + with tempfile.NamedTemporaryFile(suffix=".grl", delete=False) as f: + path = f.name + try: + save_grl(path, state, metadata=meta) + out = load_grl(path) + assert "metadata" in out + assert out["metadata"]["schema"] == "grilly.checkpoint.v1" + assert out["training_step"] == 42 + m = out["model"] + assert np.allclose(m["embed"]["weight"], state["embed"]["weight"]) + assert np.allclose(m["layers"]["0"]["w"], state["layers"]["0"]["w"]) + finally: + import os + + os.unlink(path) + + +def test_grl_rejects_bad_magic(): + with tempfile.NamedTemporaryFile(suffix=".grl", delete=False) as f: + f.write(b"XXXX") + path = f.name + try: + with pytest.raises(ValueError, match="GRL"): + load_grl(path) + finally: + import os + + os.unlink(path) + + +def test_header_constants_match_cpp_layout(): + assert len(MAGIC) == 4 + assert HEADER_SIZE == 64 + assert FORMAT_VERSION == 1 diff --git a/tests/test_inference_ops.py b/tests/test_inference_ops.py index c8543e2..9016323 100644 --- a/tests/test_inference_ops.py +++ b/tests/test_inference_ops.py @@ -8,7 +8,7 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @@ -96,7 +96,7 @@ def _ref_gqa_decode_attention( @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestRMSNormGPU: """Test RMSNorm on GPU""" @@ -174,7 +174,7 @@ def test_rms_norm_eps_affects_output(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestSwiGLUFusedGPU: """Test fused SwiGLU on GPU""" @@ -231,7 +231,7 @@ def test_swiglu_fused_output_shape(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestGEMMInt8GPU: """Test INT8 weight-only GEMM on GPU""" @@ -295,7 +295,7 @@ def 
test_gemm_int8_group_sizes(self, gpu, group_size): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestGQADecodeAttentionGPU: """Test GQA decode attention on GPU""" diff --git a/tests/test_integration.py b/tests/test_integration.py index 8fc3f7e..963171d 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,7 +7,7 @@ try: from grilly import Compute, SNNCompute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @@ -53,7 +53,7 @@ def test_snn_temporal_dynamics(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestGPUIntegration: """Integration tests for GPU operations""" diff --git a/tests/test_learning.py b/tests/test_learning.py index e734ac9..5c97095 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -7,13 +7,13 @@ try: from grilly import Compute, Learning - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestLearningOperations: """Test learning operations on GPU""" diff --git a/tests/test_memory_operations.py b/tests/test_memory_operations.py index 3b0cb33..4b46ab8 100644 --- a/tests/test_memory_operations.py +++ b/tests/test_memory_operations.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestMemoryOperations: """Test memory read/write operations""" diff --git a/tests/test_mf_activations.py b/tests/test_mf_activations.py new file mode 100644 index 0000000..713d566 --- /dev/null +++ b/tests/test_mf_activations.py @@ -0,0 +1,50 @@ +"""Multiplication-free activations: numpy + autograd.""" + +import numpy as np +import pytest +from grilly.functional.mf_activations import ( + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) +from grilly.nn import autograd as ag + + +def test_mf_softmax_rows_sum_to_one(): + x = np.array([[1.0, 2.0, 0.0], [0.0, 0.0, 0.0]], dtype=np.float32) + y = mf_softmax(x, dim=-1) + assert y.shape == x.shape + np.testing.assert_allclose(y.sum(axis=-1), np.ones(2), rtol=1e-5) + + +def test_mf_softplus_matches_algebra(): + x = np.array([-2.0, 0.0, 3.0], dtype=np.float32) + b = 1.0 + c = 4.0 / (b * b) + want = 0.5 * (x + np.sqrt(x * x + c)) + got = mf_softplus(x, beta=b) + np.testing.assert_allclose(got, want, rtol=1e-6) + + +def test_mf_sigmoid_bounds(): + x = np.linspace(-5, 5, 11, dtype=np.float32) + y = mf_sigmoid(x) + assert float(y.min()) >= -1.0 and float(y.max()) <= 1.0 + z = mf_sigmoid_01(x) + assert float(z.min()) >= 0.0 and 
float(z.max()) <= 1.0 + + +def test_mf_softmax_autograd(): + v = ag.Variable(np.array([[1.0, 2.0]], dtype=np.float32), requires_grad=True) + y = ag.mf_softmax(v, dim=-1) + assert y.data.sum() == pytest.approx(1.0) + y.backward() + assert v.grad is not None + + +def test_mf_softplus_autograd(): + v = ag.Variable(np.array([0.5], dtype=np.float32), requires_grad=True) + y = ag.mf_softplus(v, beta=1.0) + y.backward() + assert v.grad is not None diff --git a/tests/test_mf_ops_core.py b/tests/test_mf_ops_core.py new file mode 100644 index 0000000..a5aa4ab --- /dev/null +++ b/tests/test_mf_ops_core.py @@ -0,0 +1,69 @@ +"""grilly_core mf_softmax / mf_softplus / mf_sigmoid — GPU vs NumPy reference.""" + +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +try: + import grilly_core as gc +except ImportError: + pytest.skip("grilly_core not available", allow_module_level=True) + +from grilly.functional.mf_activations import mf_sigmoid, mf_softmax, mf_softplus + + +def _shader_spv_dir() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parent.parent / "shaders" / "spv" + + +def _require_mf_symbols() -> None: + for name in ("mf_softmax", "mf_softplus", "mf_sigmoid"): + if not hasattr(gc, name): + pytest.skip(f"grilly_core.{name} not in this build — rebuild extension") + + +@pytest.mark.gpu +@pytest.mark.cpp +def test_mf_softmax_gpu_matches_numpy() -> None: + _require_mf_symbols() + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(0) + x = rng.standard_normal((4, 7), dtype=np.float32) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + got = gc.mf_softmax(dev, x, -1) + ref = mf_softmax(x, dim=-1) + np.testing.assert_allclose(got.numpy(), ref, rtol=1e-4, atol=1e-5) + + +@pytest.mark.gpu +@pytest.mark.cpp +def test_mf_softplus_and_sigmoid_gpu_match_numpy() -> None: + _require_mf_symbols() + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(1) + x = rng.standard_normal((3, 5), dtype=np.float32) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + g_sig = gc.mf_sigmoid(dev, x) + np.testing.assert_allclose(g_sig.numpy(), mf_sigmoid(x), rtol=1e-5, atol=1e-6) + + g_sp = gc.mf_softplus(dev, x, 1.25) + np.testing.assert_allclose(g_sp.numpy(), mf_softplus(x, 1.25), rtol=1e-4, atol=1e-5) diff --git a/tests/test_moe_forward.py b/tests/test_moe_forward.py new file mode 100644 index 0000000..ffd5fb6 --- /dev/null +++ b/tests/test_moe_forward.py @@ -0,0 +1,158 @@ +"""grilly_core.moe_* — fused MoE forward (GPU) vs NumPy reference.""" + +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +try: + import grilly_core as gc +except ImportError: + pytest.skip("grilly_core not available", allow_module_level=True) + +try: + from grilly.backend.base import VULKAN_AVAILABLE +except ImportError: + VULKAN_AVAILABLE = False + + +def _shader_spv_dir() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parent.parent / "shaders" / "spv" + + +def _softmax(x: np.ndarray) -> np.ndarray: + x = x - x.max() + e = np.exp(x, dtype=np.float32) + return (e / e.sum()).astype(np.float32) + + +def moe_forward_numpy( + embed_w: np.ndarray, + pos_w: np.ndarray, + expert_ws: list[np.ndarray], + router_ws: list[np.ndarray], + router_bs: list[np.ndarray], + 
out_w: np.ndarray, + n_layers: int, + n_experts: int, + input_ids: np.ndarray, +) -> np.ndarray: + """Reference forward (CPU).""" + s_len = int(input_ids.shape[0]) + d = embed_w.shape[1] + x = embed_w[input_ids.astype(np.int64)] + pos_w[:s_len] + for layer in range(n_layers): + xm = x.mean(axis=0) + logits = router_ws[layer] @ xm + router_bs[layer] + p = _softmax(logits.astype(np.float32)) + blend = np.zeros_like(x, dtype=np.float32) + for e in range(n_experts): + we = expert_ws[layer * n_experts + e] + y = x @ we.T + blend += p[e] * y + x = x + blend + return x @ out_w.T + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_moe_forward_parity_small(): + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(7) + vocab, d, max_seq = 24, 8, 12 + n_layers, n_experts = 2, 4 + s_len = 6 + + embed_w = rng.standard_normal((vocab, d), dtype=np.float32) + pos_w = rng.standard_normal((max_seq, d), dtype=np.float32) + out_w = rng.standard_normal((vocab, d), dtype=np.float32) + + expert_ws = [rng.standard_normal((d, d), dtype=np.float32) for _ in range(n_layers * n_experts)] + router_ws = [rng.standard_normal((n_experts, d), dtype=np.float32) for _ in range(n_layers)] + router_bs = [rng.standard_normal((n_experts,), dtype=np.float32) for _ in range(n_layers)] + + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + ref = moe_forward_numpy( + embed_w, + pos_w, + expert_ws, + router_ws, + router_bs, + out_w, + n_layers, + n_experts, + input_ids, + ) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.moe_upload( + dev, + embed_w, + pos_w, + expert_ws, + router_ws, + router_bs, + out_w, + n_layers, + n_experts, + ) + try: + got = gc.moe_forward(dev, h, input_ids) + np.testing.assert_allclose(got, ref, rtol=1e-3, atol=1e-3) + + grad_logits = rng.standard_normal((s_len, vocab), dtype=np.float32) + grads = gc.moe_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert grads["grad_out_w"].shape == (vocab, d) + assert len(grads["grad_experts"]) == n_layers * n_experts + assert grads["grad_pos"].shape == (max_seq, d) + finally: + gc.moe_release(dev, h) + + +@pytest.mark.cpp +def test_moe_backward_cpu_shapes(): + """Backward is CPU-only; smoke-test shapes without Vulkan.""" + rng = np.random.default_rng(11) + vocab, d, max_seq = 16, 4, 8 + n_layers, n_experts = 1, 2 + s_len = 4 + + embed_w = rng.standard_normal((vocab, d), dtype=np.float32) + pos_w = rng.standard_normal((max_seq, d), dtype=np.float32) + out_w = rng.standard_normal((vocab, d), dtype=np.float32) + expert_ws = [rng.standard_normal((d, d), dtype=np.float32) for _ in range(n_layers * n_experts)] + router_ws = [rng.standard_normal((n_experts, d), dtype=np.float32) for _ in range(n_layers)] + router_bs = [rng.standard_normal((n_experts,), dtype=np.float32) for _ in range(n_layers)] + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + grad_logits = rng.standard_normal((s_len, vocab), dtype=np.float32) + + dev = gc.Device() + h = gc.moe_upload( + dev, + embed_w, + pos_w, + expert_ws, + router_ws, + router_bs, + out_w, + n_layers, + n_experts, + ) + try: + grads = gc.moe_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert grads["grad_out_w"].shape == (vocab, d) + finally: + gc.moe_release(dev, h) diff --git 
a/tests/test_snn.py b/tests/test_snn.py index 43d6bb8..7fb36ac 100644 --- a/tests/test_snn.py +++ b/tests/test_snn.py @@ -7,7 +7,7 @@ try: from grilly import SNNCompute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @@ -136,7 +136,7 @@ def test_process_handles_large_embedding(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestSNNGPU: """Test SNN with GPU (if available)""" diff --git a/tests/test_torch_api.py b/tests/test_torch_api.py new file mode 100644 index 0000000..28eea91 --- /dev/null +++ b/tests/test_torch_api.py @@ -0,0 +1,50 @@ +"""Smoke tests for ``grilly.torch_api`` (torch-style facade, no PyTorch).""" + +import numpy as np +import pytest + +import grilly.torch_api as torch +from grilly.nn import autograd as ag + + +def test_device_and_vulkan(): + d = torch.device("vulkan") + assert "vulkan" in str(d) + assert isinstance(torch.vulkan.is_available(), bool) + + +def test_tensor_long_and_ops(): + x = torch.tensor([[1, 2], [3, 4]], dtype=torch.long) + assert x.shape == (2, 2) + z = torch.zeros(3) + z.uniform_(-1, 1) + assert z.shape == (3,) + p = torch.randperm(10) + assert p.shape == (10,) + a = torch.randn(2, 3) + b = torch.randn(2, 3) + d = torch.cdist(a.unsqueeze(0), b.unsqueeze(0), p=1) + assert d.shape == (1, 2, 2) + + +def test_functional_cross_entropy_sum(): + logits = ag.randn(5, 3, requires_grad=True) + target = np.array([0, 1, 2, 1, 0], dtype=np.int64) + loss = torch.nn.functional.cross_entropy(logits, target, reduction="sum") + loss.backward() + assert loss.data.size == 1 + + +def test_amp_namespace(): + assert hasattr(torch.amp, "autocast") + assert hasattr(torch.amp, "GradScaler") + s = torch.amp.GradScaler("vulkan", enabled=False) + assert s.get_scale() == 1.0 + + +def test_grl_save_load_roundtrip(tmp_path): + path = tmp_path / "t.grl" + state = {"model": {"w": np.ones((2, 2), dtype=np.float32)}, "step": 3, "best_ppl": 1.25} + torch.save(state, path) + out = torch.load(path, map_location=torch.device("cpu")) + assert "model" in out or "metadata" in out diff --git a/tests/test_vsa_lm_forward.py b/tests/test_vsa_lm_forward.py new file mode 100644 index 0000000..976a568 --- /dev/null +++ b/tests/test_vsa_lm_forward.py @@ -0,0 +1,282 @@ +"""grilly_core.vsa_lm_* — fused VSA-LM forward (GPU) + backward (CPU) tests.""" + +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +try: + import grilly_core as gc +except ImportError: + pytest.skip("grilly_core not available", allow_module_level=True) + +try: + from grilly.backend.base import VULKAN_AVAILABLE +except ImportError: + VULKAN_AVAILABLE = False + + +def _shader_spv_dir() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parent.parent / "shaders" / "spv" + + +def _addition_linear_numpy(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray: + """y[s, o] = -sum_k |w[o, k] - x[s, k]| + b[o]""" + out = np.empty((x.shape[0], w.shape[0]), dtype=np.float32) + for s in range(x.shape[0]): + for o in range(w.shape[0]): + out[s, o] = -np.sum(np.abs(w[o] - x[s])) + b[o] + return out + + +def _sign_activation(x: np.ndarray) -> np.ndarray: + return np.where(x > 0, 1.0, -1.0).astype(np.float32) + + +def _layernorm_numpy(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, + eps: 
float = 1e-5) -> np.ndarray: + mean = x.mean(axis=-1, keepdims=True) + var = x.var(axis=-1, keepdims=True) + return ((x - mean) / np.sqrt(var + eps) * gamma + beta).astype(np.float32) + + +def vsa_lm_forward_numpy( + embed_w, pos_w, ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, n_layers, input_ids, +): + """NumPy reference forward.""" + s_len = input_ids.shape[0] + d = embed_w.shape[1] + + x = embed_w[input_ids.astype(np.int64)] + pos_w[:s_len] + + for l in range(n_layers): + h = _layernorm_numpy(x, ln_gammas[l], ln_betas[l]) + h_up = _addition_linear_numpy(h, ffn_up_ws[l], ffn_up_bs[l]) + h_sign = _sign_activation(h_up) + h_ffn = _addition_linear_numpy(h_sign, ffn_down_ws[l], ffn_down_bs[l]) + x = x + h_ffn + + scale = 1.0 / np.sqrt(d).astype(np.float32) + logits = (x @ out_w.T) * scale + return logits.astype(np.float32) + + +def _make_weights(rng, vocab, d, d_ffn, max_seq, n_layers): + embed_w = rng.standard_normal((vocab, d)).astype(np.float32) * 0.02 + pos_w = rng.standard_normal((max_seq, d)).astype(np.float32) * 0.02 + out_w = rng.standard_normal((vocab, d)).astype(np.float32) * 0.02 + + ffn_up_ws = [rng.standard_normal((d_ffn, d)).astype(np.float32) * 0.02 + for _ in range(n_layers)] + ffn_up_bs = [np.zeros(d_ffn, dtype=np.float32) for _ in range(n_layers)] + ffn_down_ws = [rng.standard_normal((d, d_ffn)).astype(np.float32) * 0.02 + for _ in range(n_layers)] + ffn_down_bs = [np.zeros(d, dtype=np.float32) for _ in range(n_layers)] + ln_gammas = [np.ones(d, dtype=np.float32) for _ in range(n_layers)] + ln_betas = [np.zeros(d, dtype=np.float32) for _ in range(n_layers)] + + return (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_vsa_lm_forward_parity_small(): + """Forward parity: GPU vs NumPy at small scale.""" + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(42) + vocab, d, d_ffn, max_seq = 24, 8, 16, 12 + n_layers = 2 + s_len = 6 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + ref = vsa_lm_forward_numpy( + embed_w, pos_w, ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, n_layers, input_ids) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + got = gc.vsa_lm_forward(dev, h, input_ids) + assert got.shape == (s_len, vocab), f"shape: {got.shape}" + np.testing.assert_allclose(got, ref, rtol=1e-2, atol=1e-2) + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_vsa_lm_forward_parity_medium(): + """Forward parity at target-ish dimensions (d=64, 4 layers).""" + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(99) + vocab, d, d_ffn, max_seq = 100, 64, 128, 32 + n_layers = 4 + s_len = 16 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, 
ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + ref = vsa_lm_forward_numpy( + embed_w, pos_w, ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, n_layers, input_ids) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + got = gc.vsa_lm_forward(dev, h, input_ids) + assert got.shape == (s_len, vocab) + np.testing.assert_allclose(got, ref, rtol=5e-2, atol=5e-2) + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.cpp +def test_vsa_lm_backward_cpu_shapes(): + """Backward is CPU-only — smoke-test shapes without Vulkan.""" + rng = np.random.default_rng(7) + vocab, d, d_ffn, max_seq = 16, 4, 8, 8 + n_layers = 1 + s_len = 4 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + grad_logits = rng.standard_normal((s_len, vocab)).astype(np.float32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + dev = gc.Device() + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + grads = gc.vsa_lm_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert grads["grad_pos"].shape == (max_seq, d) + assert grads["grad_out_w"].shape == (vocab, d) + assert len(grads["grad_ffn_up_w"]) == n_layers + assert grads["grad_ffn_up_w"][0].shape == (d_ffn, d) + assert grads["grad_ffn_up_b"][0].shape == (d_ffn,) + assert grads["grad_ffn_down_w"][0].shape == (d, d_ffn) + assert grads["grad_ffn_down_b"][0].shape == (d,) + assert grads["grad_ln_gamma"][0].shape == (d,) + assert grads["grad_ln_beta"][0].shape == (d,) + + for key in ["grad_embed", "grad_pos", "grad_out_w"]: + assert np.isfinite(grads[key]).all(), f"{key} has non-finite values" + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.cpp +def test_vsa_lm_update_weights(): + """Verify update_weights re-uploads without crash.""" + rng = np.random.default_rng(13) + vocab, d, d_ffn, max_seq = 16, 4, 8, 8 + n_layers = 1 + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + dev = gc.Device() + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + # Perturb all weights slightly and re-upload + embed_w2 = embed_w + 0.001 + pos_w2 = pos_w + 0.001 + out_w2 = out_w + 0.001 + ffn_up_ws2 = [w + 0.001 for w in ffn_up_ws] + ffn_up_bs2 = [b + 0.001 for b in ffn_up_bs] + ffn_down_ws2 = [w + 0.001 for w in ffn_down_ws] + ffn_down_bs2 = [b + 0.001 for b in ffn_down_bs] + ln_gammas2 = [g + 0.001 for g in ln_gammas] + ln_betas2 = [b + 0.001 for b in ln_betas] + + gc.vsa_lm_update_weights( + dev, h, embed_w2, pos_w2, + ffn_up_ws2, ffn_up_bs2, ffn_down_ws2, ffn_down_bs2, + ln_gammas2, ln_betas2, out_w2) + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_vsa_lm_large_scale_no_crash(): + """Ensure no segfault at production-ish shape (d=256, 6 layers).""" + if not 
_shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(55) + vocab, d, d_ffn, max_seq = 1000, 256, 512, 512 + n_layers = 6 + s_len = 64 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + logits = gc.vsa_lm_forward(dev, h, input_ids) + assert logits.shape == (s_len, vocab) + assert np.isfinite(logits).all(), "logits contain non-finite values" + + grad_logits = rng.standard_normal((s_len, vocab)).astype(np.float32) + grads = gc.vsa_lm_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert np.isfinite(grads["grad_embed"]).all() + finally: + gc.vsa_lm_release(dev, h) diff --git a/tests/test_vulkan_tensor_residency.py b/tests/test_vulkan_tensor_residency.py new file mode 100644 index 0000000..b6000d6 --- /dev/null +++ b/tests/test_vulkan_tensor_residency.py @@ -0,0 +1,40 @@ +""" +VulkanTensor GPU residency: prepare_for_dispatch binds buffers without redundant upload. +""" + +import warnings + +import numpy as np +import pytest + +try: + from grilly.backend.compute import VulkanCompute + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE + from grilly.utils.tensor_conversion import VulkanTensor +except ImportError: + pytest.skip("grilly not available", allow_module_level=True) + + +@pytest.mark.gpu +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_prepare_for_dispatch_binds_without_double_upload(): + """After one GPU op, a VulkanTensor should expose a buffer for the next op without re-uploading.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "activation-relu" not in backend.fnn.shaders: + pytest.skip("activation-relu not available") + + x = np.random.randn(4, 8).astype(np.float32) + out = backend.fnn.activation_relu(x, return_gpu_tensor=True) + assert isinstance(out, VulkanTensor) + + out.prepare_for_dispatch() + assert out._pooled_buffer is not None or out._gpu_buffer is not None + + out2 = backend.fnn.activation_relu(out, return_gpu_tensor=False) + assert out2.shape == (4, 8) + assert np.all(np.isfinite(out2)) + finally: + backend.cleanup() diff --git a/tests/tokenizers/test_gpu_tokenizer_parity.py b/tests/tokenizers/test_gpu_tokenizer_parity.py new file mode 100644 index 0000000..a370083 --- /dev/null +++ b/tests/tokenizers/test_gpu_tokenizer_parity.py @@ -0,0 +1,54 @@ +"""Tokenizer parity tests (first real pass, adaptive to in-progress API).""" + +import pytest + +from tests._tokenizer_parity_helpers import ( + encode_ids, + extract_input_ids, + load_grilly_tokenizer, + load_hf_tokenizer, +) + + +@pytest.mark.parametrize( + "model_id,text", + [ + ("bert-base-uncased", "Hello from Grilly GPU tokenization."), + ("distilbert-base-uncased", "Tokenizer parity check with punctuation: !?.,"), + ], +) +def test_tokenizer_ids_match_hf_reference(model_id: str, text: str): + """Token IDs should match Hugging Face for covered model/tokenizer assets.""" + hf_tok = load_hf_tokenizer(model_id) 
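+    # Both loaders resolve the same serialized assets (tokenizer.json), so the
+    # expectation is exact id-level equality, not approximate overlap.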
+ gr_tok = load_grilly_tokenizer(model_id) + + hf_ids = extract_input_ids(hf_tok(text)) + gr_ids = encode_ids(gr_tok, text) + + assert list(gr_ids) == list(hf_ids) + + +def test_batch_encode_matches_hf_reference(): + """Batch tokenization should match Hugging Face input_ids for canonical uncased BERT.""" + model_id = "bert-base-uncased" + texts = [ + "Grilly runs on Vulkan.", + "Any GPU, one backend.", + "Parity matters before v1.0.", + ] + hf_tok = load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_batch = hf_tok(texts, padding=False, truncation=False) + hf_ids = hf_batch["input_ids"] + + if hasattr(gr_tok, "batch_encode"): + gr_batch = gr_tok.batch_encode(texts) + elif hasattr(gr_tok, "__call__"): + gr_batch = gr_tok(texts) + else: + pytest.skip("Tokenizer missing batch_encode/__call__ batch path") + + gr_ids = extract_input_ids(gr_batch) + assert [list(x) for x in gr_ids] == [list(x) for x in hf_ids] + diff --git a/tests/transformers_compat/test_transformers_compat_placeholder.py b/tests/transformers_compat/test_transformers_compat_placeholder.py new file mode 100644 index 0000000..078c848 --- /dev/null +++ b/tests/transformers_compat/test_transformers_compat_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: transformers compatibility 1:1 test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement transformers compatibility tests in pre-v1.0 roadmap." +) + + +def test_transformers_compat_scaffold(): + """Placeholder test to reserve CI target.""" + assert True + diff --git a/tokenizer_impl/__init__.py b/tokenizer_impl/__init__.py new file mode 100644 index 0000000..13ba025 --- /dev/null +++ b/tokenizer_impl/__init__.py @@ -0,0 +1,32 @@ +""" +Grilly tokenizers (P0 roadmap). + +Uses the standalone Rust ``tokenizers`` library (same ``tokenizer.json`` assets as HF Fast +tokenizers) — **no** ``transformers`` / ``AutoTokenizer`` delegation in this package. +Native GPU merge/BPE kernels will layer behind the same :class:`Tokenizer` interface. 
+ +Public API: + from grilly.tokenizers import AutoTokenizer, Tokenizer + tok = AutoTokenizer.from_pretrained("bert-base-uncased") + ids = tok.encode("Hello") +""" + +from __future__ import annotations + +from .auto import AutoTokenizer, from_pretrained +from .base import Tokenizer as TokenizerBase +from .fast_tokenizer import FastTokenizer +from .gpu import GPU_TOKENIZER_AVAILABLE, numpy_to_input_ids_buffers, wrap_ids_as_vulkan_tensors + +Tokenizer = FastTokenizer + +__all__ = [ + "AutoTokenizer", + "FastTokenizer", + "Tokenizer", + "TokenizerBase", + "from_pretrained", + "GPU_TOKENIZER_AVAILABLE", + "numpy_to_input_ids_buffers", + "wrap_ids_as_vulkan_tensors", +] diff --git a/tokenizer_impl/auto.py b/tokenizer_impl/auto.py new file mode 100644 index 0000000..e105f19 --- /dev/null +++ b/tokenizer_impl/auto.py @@ -0,0 +1,20 @@ +"""AutoTokenizer entrypoint (Rust ``tokenizers`` backend; no ``transformers``).""" + +from __future__ import annotations + +from typing import Any + +from .fast_tokenizer import FastTokenizer + + +class AutoTokenizer: + """Load a tokenizer the same way as Hugging Face ``AutoTokenizer`` (same assets).""" + + @staticmethod + def from_pretrained(model_id: str, **kwargs: Any) -> FastTokenizer: + return FastTokenizer.from_pretrained(model_id, **kwargs) + + +def from_pretrained(model_id: str, **kwargs: Any) -> FastTokenizer: + """Alias for :meth:`AutoTokenizer.from_pretrained`.""" + return AutoTokenizer.from_pretrained(model_id, **kwargs) diff --git a/tokenizer_impl/base.py b/tokenizer_impl/base.py new file mode 100644 index 0000000..14c09a3 --- /dev/null +++ b/tokenizer_impl/base.py @@ -0,0 +1,42 @@ +"""Abstract tokenizer interface (P0 GPU tokenizer roadmap).""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np + + +class Tokenizer(ABC): + """Minimal encode/decode surface aligned with common HF usage.""" + + @abstractmethod + def encode( + self, + text: str, + add_special_tokens: bool = True, + **kwargs: Any, + ) -> list[int] | np.ndarray: + """Return token ids for a single string.""" + + @abstractmethod + def decode( + self, + token_ids: list[int] | np.ndarray, + skip_special_tokens: bool = True, + **kwargs: Any, + ) -> str: + """Decode ids back to text.""" + + @abstractmethod + def batch_encode( + self, + texts: list[str], + padding: bool | str = False, + truncation: bool | str = False, + max_length: int | None = None, + return_tensors: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Batch encode; return dict with at least ``input_ids`` (list of lists or ndarray).""" diff --git a/tokenizer_impl/fast_tokenizer.py b/tokenizer_impl/fast_tokenizer.py new file mode 100644 index 0000000..f42fdc1 --- /dev/null +++ b/tokenizer_impl/fast_tokenizer.py @@ -0,0 +1,110 @@ +"""Tokenizer implementation using the standalone ``tokenizers`` (Rust) library only. + +Matches Hugging Face **Fast** tokenizer outputs for the same serialized tokenizer assets +(``tokenizer.json``, including ``onnx/tokenizer.json`` Hub fallbacks) without importing +``transformers``. 
+""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from .base import Tokenizer +from .loader import load_rust_tokenizer + + +class FastTokenizer(Tokenizer): + """Encode/decode via ``tokenizers.Tokenizer`` (Rust backend).""" + + def __init__(self, rust_tokenizer: Any, model_id: str | None = None) -> None: + self._tok = rust_tokenizer + self.model_id = model_id + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs: Any) -> FastTokenizer: + try: + rs = load_rust_tokenizer(model_id, **kwargs) + except ImportError as e: + raise ImportError( + "FastTokenizer.from_pretrained requires `pip install tokenizers huggingface_hub`" + ) from e + return cls(rs, model_id=model_id) + + def encode( + self, + text: str, + add_special_tokens: bool = True, + **kwargs: Any, + ) -> list[int] | np.ndarray: + # Match HF PreTrainedTokenizer.encode defaults (no truncation / padding unless kwargs). + self._tok.no_truncation() + self._tok.no_padding() + enc = self._tok.encode(text, add_special_tokens=add_special_tokens) + return np.asarray(enc.ids, dtype=np.int32) + + def decode( + self, + token_ids: list[int] | np.ndarray, + skip_special_tokens: bool = True, + **kwargs: Any, + ) -> str: + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.astype(np.int32).tolist() + return self._tok.decode(token_ids, skip_special_tokens=skip_special_tokens) + + def batch_encode( + self, + texts: list[str], + padding: bool | str = False, + truncation: bool | str = False, + max_length: int | None = None, + return_tensors: str | None = None, + add_special_tokens: bool = True, + **kwargs: Any, + ) -> dict[str, Any]: + if not texts: + return {"input_ids": [], "attention_mask": []} + + if truncation or max_length is not None: + ml = max_length if max_length is not None else 512 + self._tok.enable_truncation(ml) + else: + self._tok.no_truncation() + + if padding is True or padding == "longest" or padding == "max_length": + length = None + if padding == "max_length" and max_length is not None: + length = max_length + self._tok.enable_padding(direction="right", length=length) + else: + self._tok.no_padding() + + encodings = self._tok.encode_batch(texts, add_special_tokens=add_special_tokens) + + input_ids = [e.ids for e in encodings] + attention_mask = [e.attention_mask for e in encodings] + + out: dict[str, Any] = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + if return_tensors == "np": + out["input_ids"] = np.asarray(input_ids, dtype=object) + out["attention_mask"] = np.asarray(attention_mask, dtype=object) + return out + + def __call__(self, texts: str | list[str], **kwargs: Any) -> dict[str, Any]: + if isinstance(texts, str): + self._tok.no_truncation() + self._tok.no_padding() + enc = self._tok.encode( + texts, + add_special_tokens=kwargs.get("add_special_tokens", True), + ) + return { + "input_ids": enc.ids, + "attention_mask": enc.attention_mask, + } + return self.batch_encode(texts, **kwargs) diff --git a/tokenizer_impl/gpu.py b/tokenizer_impl/gpu.py new file mode 100644 index 0000000..bdf5d12 --- /dev/null +++ b/tokenizer_impl/gpu.py @@ -0,0 +1,61 @@ +"""GPU tokenizer backend (extension points; kernels land in shaders + C++ bridge).""" + +from __future__ import annotations + +import os +from typing import Any + +import numpy as np + +# Set GRILLY_GPU_TOKENIZER=1 when native GPU tokenization is wired (future). 
+GPU_TOKENIZER_AVAILABLE = os.environ.get("GRILLY_GPU_TOKENIZER", "").strip() in { + "1", + "true", + "yes", + "on", +} + + +def numpy_to_input_ids_buffers( + encoded: dict[str, Any], +) -> dict[str, np.ndarray]: + """Normalize batch output to numpy int32 arrays (CPU staging before GPU upload).""" + out: dict[str, np.ndarray] = {} + if "input_ids" in encoded: + ids = encoded["input_ids"] + if hasattr(ids, "numpy"): + ids = ids.numpy() + out["input_ids"] = np.asarray(ids, dtype=np.int32) + if "attention_mask" in encoded: + m = encoded["attention_mask"] + if hasattr(m, "numpy"): + m = m.numpy() + out["attention_mask"] = np.asarray(m, dtype=np.int32) + return out + + +def wrap_ids_as_vulkan_tensors( + encoded_numpy: dict[str, np.ndarray], + *, + device: Any = None, +) -> dict[str, Any]: + """Upload token buffers with :class:`grilly.utils.tensor_conversion.VulkanTensor` when available. + + Raises ``RuntimeError`` if Vulkan is unavailable or GPU tokenizer path not enabled. + """ + if not GPU_TOKENIZER_AVAILABLE: + raise RuntimeError( + "GPU-resident tokenizer buffers are not enabled yet " + "(set GRILLY_GPU_TOKENIZER=1 when implemented, or use numpy_to_input_ids_buffers)." + ) + try: + from grilly.utils.tensor_conversion import VulkanTensor, gpu_mode + except ImportError as e: + raise RuntimeError("VulkanTensor not available") from e + + if device is None: + gpu_mode(True) + out: dict[str, Any] = {} + for k, arr in encoded_numpy.items(): + out[k] = VulkanTensor(arr) + return out diff --git a/tokenizer_impl/loader.py b/tokenizer_impl/loader.py new file mode 100644 index 0000000..0ff7b25 --- /dev/null +++ b/tokenizer_impl/loader.py @@ -0,0 +1,57 @@ +"""Load ``tokenizers`` (Rust) tokenizer files from the Hub or local paths — no ``transformers``.""" + +from __future__ import annotations + +import os +from typing import Any + + +def load_rust_tokenizer(model_id: str, **kwargs: Any): + """Return ``tokenizers.Tokenizer`` from a Hub id or a local directory / ``tokenizer.json`` path.""" + from tokenizers import Tokenizer as RsTokenizer + + try: + from huggingface_hub.errors import EntryNotFoundError + except ImportError: + from huggingface_hub.utils import EntryNotFoundError # type: ignore[no-redef] + + if os.path.isfile(model_id) and model_id.endswith(".json"): + return RsTokenizer.from_file(model_id) + if os.path.isdir(model_id): + for rel in ("tokenizer.json", os.path.join("onnx", "tokenizer.json")): + p = os.path.join(model_id, rel) + if os.path.isfile(p): + return RsTokenizer.from_file(p) + raise FileNotFoundError( + f"No tokenizer.json or onnx/tokenizer.json under {model_id!r}", + ) + + token = kwargs.get("token") or kwargs.get("use_auth_token") + revision = kwargs.get("revision") + + from huggingface_hub import hf_hub_download + + # Root tokenizer.json first; some repos (e.g. google/mt5-small) ship only onnx/tokenizer.json. 
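Because the GPU path above stays gated behind `GRILLY_GPU_TOKENIZER`, the numpy staging helper is the route usable today. A sketch of the handoff, under the same asset assumptions as the previous example, with padding applied first so the ragged id lists collapse into one 2-D int32 array.

```python
# CPU staging sketch: batch_encode output -> contiguous int32 buffers ready for upload.
# wrap_ids_as_vulkan_tensors raises until GRILLY_GPU_TOKENIZER=1 is set, so only the
# numpy_to_input_ids_buffers path is exercised here.
import numpy as np

from grilly.tokenizers import AutoTokenizer, numpy_to_input_ids_buffers

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tok.batch_encode(
    ["Grilly runs on Vulkan.", "Any GPU, one backend."],
    padding="longest",  # equal-length rows, so np.asarray yields a (B, S) array
)

buffers = numpy_to_input_ids_buffers(batch)
assert buffers["input_ids"].dtype == np.int32
assert buffers["input_ids"].shape == buffers["attention_mask"].shape
```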
+ for rel in ("tokenizer.json", "onnx/tokenizer.json"): + try: + path = hf_hub_download( + repo_id=model_id, + filename=rel, + token=token, + revision=revision, + ) + return RsTokenizer.from_file(path) + except EntryNotFoundError: + continue + + try: + tok = RsTokenizer.from_pretrained(model_id, **kwargs) + if tok is not None: + return tok + except (AttributeError, TypeError, OSError, ValueError, EntryNotFoundError): + pass + + raise FileNotFoundError( + f"Could not load Rust tokenizer for {model_id!r} " + "(tried tokenizer.json, onnx/tokenizer.json, Tokenizer.from_pretrained)", + ) diff --git a/torch_api/__init__.py b/torch_api/__init__.py new file mode 100644 index 0000000..b8497e7 --- /dev/null +++ b/torch_api/__init__.py @@ -0,0 +1,101 @@ +""" +Torch-compatible facade for Grilly: ``import grilly.torch_api as torch`` or use +re-exports on :mod:`grilly` (``device``, ``tensor``, ``save``, ``load``, …). + +No PyTorch runtime; GPU path is Vulkan via ``grilly_core`` + SPIR-V. Checkpoints +use ``.grl`` (see :mod:`grilly.utils.grl_checkpoint`). +""" + +from __future__ import annotations + +import types + +import grilly.nn as nn +import grilly.optim as optim +from grilly.functional.mf_activations import ( + mf_relu, + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) + +from .amp_mod import GradScaler, autocast +from .dtypes_and_device import ( + cpu, + device, + float16, + float32, + float64, + int32, + int64, + long, +) +from .ops import ( + cdist, + clamp, + empty, + no_grad, + ones, + randn, + randperm, + sign, + tanh, + tensor, + zeros, +) +from .serialization import load, save +from .tensor import LongTensor, Tensor +from .vulkan_mod import is_available as _vulkan_is_available + +amp = types.SimpleNamespace( + autocast=autocast, + GradScaler=GradScaler, + __doc__="Automatic mixed precision (Vulkan-oriented; no CUDA).", +) + + +class _Vulkan: + """``torch.cuda``-free replacement: ``torch.vulkan.is_available()``.""" + + @staticmethod + def is_available() -> bool: + return _vulkan_is_available() + + +vulkan = _Vulkan() + +__all__ = [ + "Tensor", + "LongTensor", + "device", + "cpu", + "float16", + "float32", + "float64", + "int32", + "int64", + "long", + "tensor", + "empty", + "zeros", + "ones", + "randn", + "randperm", + "cdist", + "tanh", + "clamp", + "sign", + "no_grad", + "save", + "load", + "nn", + "optim", + "amp", + "vulkan", + "mf_softmax", + "mf_softplus", + "mf_sigmoid", + "mf_sigmoid_01", + "mf_relu", +] diff --git a/torch_api/amp_mod.py b/torch_api/amp_mod.py new file mode 100644 index 0000000..d0f16ac --- /dev/null +++ b/torch_api/amp_mod.py @@ -0,0 +1,58 @@ +"""``torch.amp``-compatible module (Vulkan / numpy path; no CUDA).""" + +from __future__ import annotations + +from typing import Any + +from grilly.backend import amp as _amp + + +class autocast: + """Autocast context (delegates to :mod:`grilly.backend.amp`). 
Accepts ``device_type`` / ``dtype`` for API parity.""" + + def __init__( + self, + device_type: str | None = None, + enabled: bool = True, + dtype: Any = None, + cache_enabled: bool = True, + ) -> None: + del device_type, dtype, cache_enabled + self._inner = _amp.autocast(enabled=enabled) + + def __enter__(self) -> autocast: + self._inner.__enter__() + return self + + def __exit__(self, *args: Any) -> None: + self._inner.__exit__(*args) + + +class GradScaler: + """GradScaler shim; first positional arg may be ``'vulkan'`` / ``'cuda'`` (ignored) or ``init_scale`` (float).""" + + def __init__( + self, + device_type: Any = None, + enabled: bool = True, + *, + init_scale: float = 65536.0, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + **kwargs: Any, + ) -> None: + del kwargs + scale = init_scale + if isinstance(device_type, (int, float)) and not isinstance(device_type, bool): + scale = float(device_type) + self._inner = _amp.GradScaler( + init_scale=scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) diff --git a/torch_api/dtypes_and_device.py b/torch_api/dtypes_and_device.py new file mode 100644 index 0000000..f796964 --- /dev/null +++ b/torch_api/dtypes_and_device.py @@ -0,0 +1,87 @@ +"""Torch-style dtypes and ``device`` (Vulkan-only GPU path; no CUDA).""" + +from __future__ import annotations + +import numpy as np + +# Aliases for ``dtype=torch.float32`` / ``torch.long`` style imports +float16 = np.dtype("float16") +float32 = np.dtype("float32") +float64 = np.dtype("float64") +int32 = np.dtype("int32") +int64 = np.dtype("int64") +long = int64 + + +def _is_long_dtype(dtype: object) -> bool: + if dtype is None: + return False + if dtype in (int64, long, np.int64, "int64", "long"): + return True + if isinstance(dtype, np.dtype) and dtype.kind in "iu": + return dtype.itemsize >= 4 + return isinstance(dtype, str) and dtype.lower() in ("int64", "long") + + +def _dtype_to_numpy(dtype: object) -> np.dtype: + if dtype is None: + return np.dtype("float32") + if isinstance(dtype, np.dtype): + return dtype + if dtype in (float16, float32, float64, int32, int64): + return np.dtype(dtype) + if dtype is np.float16: + return np.dtype("float16") + if dtype is np.float32: + return np.dtype("float32") + if dtype is np.float64: + return np.dtype("float64") + if dtype is np.int64: + return np.dtype("int64") + s = str(dtype).lower() + if s in ("float16", "half"): + return np.dtype("float16") + if s in ("float32",): + return np.dtype("float32") + if s in ("float64", "double"): + return np.dtype("float64") + if s in ("int64", "long"): + return np.dtype("int64") + if s in ("int32",): + return np.dtype("int32") + return np.dtype("float32") + + +class device: + """``torch.device``-like device tag (``cpu`` / ``vulkan``).""" + + __slots__ = ("type", "index") + + def __init__(self, type_: str, index: int | None = None) -> None: + s = str(type_) + if ":" in s: + parts = s.split(":", 1) + self.type = parts[0].strip() + try: + self.index = int(parts[1]) if parts[1] else None + except ValueError: + self.index = None + else: + self.type = s.strip() + self.index = index + + def __eq__(self, other: object) -> bool: + if not isinstance(other, device): + return NotImplemented + return self.type == other.type and self.index == other.index + + def __str__(self) -> str: + if self.index is not None: + return 
f"{self.type}:{self.index}" + return self.type + + def __repr__(self) -> str: + return f"device(type='{self.type}'{'' if self.index is None else f', index={self.index}'})" + + +cpu = device("cpu") diff --git a/torch_api/functional.py b/torch_api/functional.py new file mode 100644 index 0000000..bd79a63 --- /dev/null +++ b/torch_api/functional.py @@ -0,0 +1,47 @@ +"""``torch.nn.functional``-compatible helpers.""" + +from __future__ import annotations + +import numpy as np + +from grilly.nn import autograd as ag + + +def softplus(input, beta: float = 1.0, threshold: float = 20.0): + return ag.softplus(input, beta, threshold) + + +def cross_entropy( + input, + target, + weight=None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, +): + del weight, size_average, reduce, label_smoothing + logits = ag._ensure_variable(input) + if isinstance(target, ag.Variable): + t = target.data + elif isinstance(target, np.ndarray): + t = target + elif hasattr(target, "data") and not isinstance(target, np.ndarray): + t = target.data + else: + t = np.asarray(target) + + if ignore_index >= 0: + pass + + ce = ag.cross_entropy(logits, t) + + if reduction == "mean": + return ce + if reduction == "sum": + n = float(logits.data.shape[0]) + return ag.mul(ce, n) + if reduction == "none": + raise NotImplementedError("cross_entropy reduction='none' is not implemented") + raise ValueError(f"Unknown reduction: {reduction}") diff --git a/torch_api/ops.py b/torch_api/ops.py new file mode 100644 index 0000000..4a687ea --- /dev/null +++ b/torch_api/ops.py @@ -0,0 +1,111 @@ +"""Tensor factories and math ops (``torch.*`` facade).""" + +from __future__ import annotations + +import numpy as np + +from grilly.nn import autograd as ag + +from .dtypes_and_device import _dtype_to_numpy, _is_long_dtype +from .dtypes_and_device import device as _TorchDevice +from .tensor import LongTensor, Tensor, as_numpy + +_DEFAULT_FLOAT = np.dtype(np.float32) + + +def tensor( + data, + dtype: object | None = None, + device: _TorchDevice | str | None = None, + requires_grad: bool = False, +): + if _is_long_dtype(dtype): + return LongTensor(data, device=device) + np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT + arr = np.asarray(data, dtype=np_dt) + if np_dt.kind in "iu": + return LongTensor(arr, device=device) + return Tensor(arr.astype(np.float32), requires_grad=requires_grad) + + +def empty(*size: int, dtype=None, device=None, requires_grad: bool = False) -> Tensor: + if len(size) == 1 and isinstance(size[0], (tuple, list)): + size = tuple(size[0]) # type: ignore[assignment] + if not size: + raise TypeError("empty() requires at least one size dimension") + np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT + if np_dt.kind in "iu": + return LongTensor(np.empty(size, dtype=np_dt), device=device) + return Tensor(np.empty(size, dtype=np.float32), requires_grad=requires_grad) + + +def zeros(*size: int, dtype=None, device=None, requires_grad: bool = False) -> Tensor: + if len(size) == 1 and isinstance(size[0], (tuple, list)): + size = tuple(size[0]) # type: ignore[assignment] + if not size: + raise TypeError("zeros() requires size") + np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT + if np_dt.kind in "iu": + return LongTensor(np.zeros(size, dtype=np_dt), device=device) + return Tensor(np.zeros(size, dtype=np.float32), requires_grad=requires_grad) + + +def ones(*size: int, dtype=None, device=None, requires_grad: 
bool = False) -> Tensor: + if len(size) == 1 and isinstance(size[0], (tuple, list)): + size = tuple(size[0]) # type: ignore[assignment] + np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT + if np_dt.kind in "iu": + return LongTensor(np.ones(size, dtype=np_dt), device=device) + return Tensor(np.ones(size, dtype=np.float32), requires_grad=requires_grad) + + +def randn(*size: int, requires_grad: bool = False) -> Tensor: + if len(size) == 1 and isinstance(size[0], (tuple, list)): + size = tuple(size[0]) # type: ignore[assignment] + if not size: + return Tensor(np.array(np.random.randn(), dtype=np.float32), requires_grad=requires_grad) + return Tensor(np.random.randn(*size).astype(np.float32), requires_grad=requires_grad) + + +def randperm(n: int, dtype=None, device=None, requires_grad: bool = False) -> LongTensor: + """Random permutation of ``0 .. n-1`` (CPU/numpy).""" + idx = np.random.permutation(n).astype(np.int64) + return LongTensor(idx, device=device) + + +def cdist(x1, x2, p: float = 2.0, compute_mode=None) -> Tensor: + """Pairwise distance (``p=1`` L1 / Manhattan).""" + del compute_mode + a = as_numpy(x1) + b = as_numpy(x2) + if a.ndim == 2: + a = a[np.newaxis, ...] + if b.ndim == 2: + b = b[np.newaxis, ...] + if p == 1: + # (B, P, D), (B, R, D) -> (B, P, R) + diff = np.abs(a[:, :, None, :] - b[:, None, :, :]) + dist = np.sum(diff, axis=-1) + elif p == 2: + diff = a[:, :, None, :] - b[:, None, :, :] + dist = np.sqrt(np.sum(diff * diff, axis=-1)) + else: + diff = np.abs(a[:, :, None, :] - b[:, None, :, :]) ** p + dist = np.sum(diff, axis=-1) ** (1.0 / p) + return Tensor(dist.astype(np.float32), requires_grad=False) + + +def tanh(input) -> Tensor: + return ag.tanh(input) # type: ignore[return-value] + + +def clamp(input, min_val=None, max_val=None) -> Tensor: + return ag.clamp(input, min_val, max_val) # type: ignore[return-value] + + +def sign(input) -> Tensor: + return ag.sign(input) # type: ignore[return-value] + + +# Re-export no_grad from autograd +no_grad = ag.no_grad diff --git a/torch_api/serialization.py b/torch_api/serialization.py new file mode 100644 index 0000000..1c567b2 --- /dev/null +++ b/torch_api/serialization.py @@ -0,0 +1,49 @@ +"""``torch.save`` / ``torch.load`` facades for GRL checkpoints (``.grl``).""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from grilly.utils.grl_checkpoint import load_grl, save_grl + + +def save( + obj: object, + f: str | Path, + *, + pickle_module: object | None = None, + pickle_protocol: int | None = None, + _use_new_zipfile_serialization: bool = True, +) -> None: + del pickle_module, pickle_protocol, _use_new_zipfile_serialization + path = Path(f) + if path.suffix.lower() not in (".grl",): + path = path.with_suffix(".grl") + meta: dict[str, Any] = {} + if isinstance(obj, dict): + for k in ("step", "training_step", "best_ppl", "epoch"): + if k in obj and not isinstance(obj[k], dict): + try: + meta[k] = obj[k] + except TypeError: + pass + save_grl(path, obj, metadata=meta or None) + + +def load( + f: str | Path, + map_location: Any = None, + pickle_module: object | None = None, + *, + weights_only: bool = False, + **kwargs: Any, +) -> Any: + del pickle_module, weights_only, kwargs + path = Path(f) + if path.suffix.lower() != ".grl": + raise ValueError( + f"Grilly torch.load only supports .grl checkpoints (got {path.suffix!r}). " + "Export PyTorch .pt to .grl with a one-time migration script." 
+ ) + return load_grl(path, map_location=map_location) diff --git a/torch_api/tensor.py b/torch_api/tensor.py new file mode 100644 index 0000000..7f73c61 --- /dev/null +++ b/torch_api/tensor.py @@ -0,0 +1,89 @@ +"""``torch.Tensor``-like float tensor (``Variable`` subclass) and ``LongTensor``.""" + +from __future__ import annotations + +import numpy as np + +from grilly.nn.autograd import Variable + +from .dtypes_and_device import _dtype_to_numpy, device + + +class LongTensor(np.ndarray): + """Integer index tensor (``int64``), ndarray subclass.""" + + def __getitem__(self, key): + r = super().__getitem__(key) + if not isinstance(r, np.ndarray): + return r + if r.ndim == 0: + return r + if r.dtype == np.int64: + return r.view(LongTensor) + return r + + def __new__(cls, input_array, device: device | str | None = None): + obj = np.asarray(input_array, dtype=np.int64).view(cls) + obj._torch_device = device # type: ignore[attr-defined] + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self._torch_device = getattr(obj, "_torch_device", None) # type: ignore[attr-defined] + + def to(self, device: device | str | None = None, dtype=None, non_blocking: bool = False): + del non_blocking + if dtype is not None: + np_dt = _dtype_to_numpy(dtype) + return Tensor(np.asarray(self, dtype=np.float32).astype(np_dt)) + return LongTensor(self.copy(), device=device) + + @property + def dtype(self) -> np.dtype: + return np.dtype(np.int64) + + def numel(self) -> int: + return int(self.size) + + +class Tensor(Variable): + """Float tensor with PyTorch-style helpers (numpy + autograd ``Variable``).""" + + def uniform_(self, a: float = 0.0, b: float = 1.0) -> Tensor: + self.data[:] = np.random.uniform(a, b, self.shape).astype(np.float32) + return self + + def zero_(self) -> Tensor: + self.data.fill(np.float32(0.0)) + return self + + def to( + self, + device: device | str | None = None, + dtype: object | None = None, + non_blocking: bool = False, + ) -> Tensor | LongTensor: + del non_blocking # numpy path: unused + if dtype is not None: + np_dt = _dtype_to_numpy(dtype) + if np_dt.kind in "iu": + return LongTensor(self.data.astype(np_dt)) + new_data = self.data.astype(np_dt) + return Tensor(new_data, requires_grad=self.requires_grad) + if device is not None: + pass # no separate device memory in numpy path + return Tensor(self.data.copy(), requires_grad=self.requires_grad) + + @property + def dtype(self) -> np.dtype: + return self.data.dtype + + +def as_numpy(x: object) -> np.ndarray: + """Unwrap ``Tensor`` / ``Variable`` / ``LongTensor`` to ndarray.""" + if isinstance(x, np.ndarray): + return x + if hasattr(x, "data") and isinstance(getattr(x, "data", None), np.ndarray): + return x.data # type: ignore[union-attr] + return np.asarray(x) diff --git a/torch_api/vulkan_mod.py b/torch_api/vulkan_mod.py new file mode 100644 index 0000000..8886602 --- /dev/null +++ b/torch_api/vulkan_mod.py @@ -0,0 +1,15 @@ +"""``torch.cuda``-free availability helper (Vulkan).""" + +from __future__ import annotations + +try: + from grilly.backend.base import VULKAN_AVAILABLE +except Exception: + VULKAN_AVAILABLE = False + + +def is_available() -> bool: + return bool(VULKAN_AVAILABLE) + + +__all__ = ["is_available"] diff --git a/torch_api_example.py b/torch_api_example.py new file mode 100644 index 0000000..793cefb --- /dev/null +++ b/torch_api_example.py @@ -0,0 +1,311 @@ +"""VSA-LM v3b: Sequence-mean LiquidCell, batched. 
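`serialization.py` above forwards to the GRL writer/reader added later in this patch (`utils/grl_checkpoint.py`). A roundtrip sketch under the assumption that the checkpoint is a nested dict of numpy arrays plus scalar bookkeeping keys; individual key names avoid "." because the flattener uses dotted paths to nest and unnest dicts.

```python
# Roundtrip sketch for the torch.save / torch.load facade above (.grl on disk).
# Keys avoid "." because the GRL flattener (utils/grl_checkpoint.py, later in this
# patch) uses dotted paths to rebuild nesting.
import numpy as np

import grilly.torch_api as torch

state = {
    "model": {
        "weight": np.random.randn(8, 4).astype(np.float32),
        "bias": np.zeros(8, dtype=np.float32),
    },
    "step": 1200,
}

torch.save(state, "demo_ckpt")        # suffix is forced to .grl -> demo_ckpt.grl
ck = torch.load("demo_ckpt.grl")

assert int(ck["step"]) == 1200
assert np.allclose(ck["model"]["weight"], state["model"]["weight"])
```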
+ +Reverts to the proven v1 architecture (one LiquidCell step per sequence via +x_mean), but fully batched on GPU with B=64. Loads vsa_lm_v1_resume.pt and +continues cosine decay from wherever the checkpoint left off. + +Why: per-token recurrence (v3) hit 0.2 stp/s on A100 due to Python loop +overhead. Sequence-mean is O(1) per layer per sequence — GPU-friendly and +the architecture that trained the checkpoint in the first place. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import time +import os + +import numpy as np + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +SEQ_LEN = 256 +DATA_DIR = 'vsa_lm_data' + +# ── Data ── +tokens = np.load(f'{DATA_DIR}/tokens.npy') +vocab = int(np.load(f'{DATA_DIR}/vocab.npy')[0]) +n = len(tokens) +tr, va = tokens[:int(.8 * n)], tokens[int(.8 * n):int(.9 * n)] + + +def mkseqs(t, sl): + x, y = [], [] + for i in range(0, len(t) - sl - 1, sl // 2): + x.append(t[i:i + sl]) + y.append(t[i + 1:i + sl + 1]) + return (torch.tensor(np.array(x), dtype=torch.long), + torch.tensor(np.array(y), dtype=torch.long)) + + +train_x, train_y = mkseqs(tr, SEQ_LEN) +val_x, val_y = mkseqs(va, SEQ_LEN) +print(f'Vocab={vocab}, Train={len(train_x)}, Val={len(val_x)}') + + +def compute_ppl(model, x_data, y_data, max_samples=100): + model.eval() + total_loss, n_tok = 0.0, 0 + with torch.no_grad(), torch.amp.autocast('cuda', enabled=True, dtype=torch.float16): + # Batched eval for speed + bs = 32 + for i in range(0, min(len(x_data), max_samples), bs): + xb = x_data[i:i + bs].to(device) + yb = y_data[i:i + bs].to(device) + logits = model(xb) # (B, S, V) + loss = F.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + yb.reshape(-1), + reduction='sum', + ) + total_loss += loss.item() + n_tok += yb.numel() + model.train() + return math.exp(min(total_loss / max(n_tok, 1), 20)) + + +# ── Config ── +D_MODEL = 384 +N_LAYERS = 12 +D_FFN = 1152 +LR = 1e-4 +TRAIN_STEPS = 100000 +VAL_EVERY = 200 +GRAD_CLIP = 1.0 +BATCH_SIZE = 16 +CAPSULE_DIM = 32 +USE_AMP = True # fp16 autocast — ~1.5-2x speedup on cdist + + +# ── Model ── +class AdditionLinearCUDA(nn.Module): + def __init__(self, d_in, d_out): + super().__init__() + self.weight = nn.Parameter(torch.empty(d_out, d_in).uniform_(-0.1, 0.1)) + self.bias = nn.Parameter(torch.zeros(d_out)) + self.d_in = d_in + + def forward(self, x): + # x: (..., d_in) → (..., d_out). Flatten leading dims for cdist. + orig_shape = x.shape + x_flat = x.reshape(-1, orig_shape[-1]) + dist = torch.cdist(x_flat.unsqueeze(0), self.weight.unsqueeze(0), p=1).squeeze(0) + out = -dist + self.bias + return out.reshape(*orig_shape[:-1], -1) + + +class LiquidCellCUDA(nn.Module): + """Sequence-mean LiquidCell — one step per sequence (not per token). + + For batched input (B, d), processes all B sequences in parallel in a + single forward pass. Matches the v1 architecture that trained + vsa_lm_v1_resume.pt. 
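`AdditionLinearCUDA` above swaps the usual dot product for a negative L1 distance: output unit j scores an input x as -||x - W_j||_1 + b_j. A standalone numpy restatement of that arithmetic (no torch required), handy as a reference when checking the layer elsewhere.

```python
# Pure-numpy restatement of the AdditionLinearCUDA forward above:
# out[n, j] = -sum_k |x[n, k] - W[j, k]| + b[j]
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_out = 2, 4, 3
x = rng.standard_normal((batch, d_in)).astype(np.float32)
W = rng.standard_normal((d_out, d_in)).astype(np.float32)
b = np.zeros(d_out, dtype=np.float32)

dist = np.abs(x[:, None, :] - W[None, :, :]).sum(axis=-1)   # (batch, d_out) pairwise L1
out = -dist + b

# Spot-check one entry against the scalar definition.
assert np.isclose(out[0, 1], -np.abs(x[0] - W[1]).sum() + b[1])
print(out.shape)  # (2, 3)
```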
+ """ + + def __init__(self, d, dt=0.02, tau_min=0.02, tau_max=2.0): + super().__init__() + s = math.sqrt(2.0 / (d + d)) + self.W = nn.Parameter(torch.randn(d, d) * s) + self.U = nn.Parameter(torch.randn(d, d) * s) + self.b = nn.Parameter(torch.zeros(d)) + self.V = nn.Parameter(torch.randn(d, d) * s) + self.c = nn.Parameter(torch.randn(d) * 0.1) + self.register_buffer('h', torch.zeros(d)) + self.dt = dt + self.tau_min = tau_min + self.tau_max = tau_max + + def step(self, x): + # Single sequence: x is (d,) + tau = self.tau_min + F.softplus(self.V @ x + self.c) + tau = torch.clamp(tau, max=self.tau_max) + a = torch.tanh(self.W @ self.h + self.U @ x + self.b) + dh = -self.h / tau.clamp(min=1e-6) + a + self.h = (self.h + self.dt * dh).detach() + return self.h + + def step_batched(self, x): + # x: (B, d). Broadcasts h across batch, single parallel update. + B = x.shape[0] + tau = self.tau_min + F.softplus(x @ self.V.T + self.c) # (B, d) + tau = torch.clamp(tau, max=self.tau_max) + h_b = self.h.unsqueeze(0).expand(B, -1) # (B, d) — shared state + a = torch.tanh(h_b @ self.W.T + x @ self.U.T + self.b) # (B, d) + dh = -h_b / tau.clamp(min=1e-6) + a + new_h = h_b + self.dt * dh # (B, d) + # Update the shared buffer to the batch mean (detached for next call) + self.h = new_h.mean(dim=0).detach() + return new_h + + def reset(self): + self.h.zero_() + + +class VSALayerCUDA(nn.Module): + def __init__(self, d, d_ffn): + super().__init__() + self.ln = nn.LayerNorm(d) + self.ffn_up = AdditionLinearCUDA(d, d_ffn) + self.ffn_down = AdditionLinearCUDA(d_ffn, d) + self.liquid = LiquidCellCUDA(d) + self.d = d + + def forward(self, x, plasticity=0.5, consolidation=0.5): + if x.ndim == 3: + return self._forward_batch(x, plasticity, consolidation) + return self._forward_single(x, plasticity, consolidation) + + def _forward_single(self, x, plasticity, consolidation): + # x: (S, d) + h = self.ln(x) + x_mean = h.mean(dim=0) # (d,) + self.liquid.dt = 0.01 + 0.03 * plasticity + self.liquid.tau_min = 0.02 + 0.08 * consolidation + temporal = self.liquid.step(x_mean) # (d,) + + h_up = torch.sign(self.ffn_up(h) / math.sqrt(self.ffn_up.d_in)) + h_ffn = self.ffn_down(h_up) / math.sqrt(self.ffn_down.d_in) + + gate = torch.tanh(temporal) # (d,) + y = (0.5 + plasticity) * h_ffn * (1.0 + 0.1 * gate) + return x + y + + def _forward_batch(self, x, plasticity, consolidation): + # x: (B, S, d) + h = self.ln(x) + x_mean = h.mean(dim=1) # (B, d) + self.liquid.dt = 0.01 + 0.03 * plasticity + self.liquid.tau_min = 0.02 + 0.08 * consolidation + temporal = self.liquid.step_batched(x_mean) # (B, d) + + h_up = torch.sign(self.ffn_up(h) / math.sqrt(self.ffn_up.d_in)) + h_ffn = self.ffn_down(h_up) / math.sqrt(self.ffn_down.d_in) + + gate = torch.tanh(temporal).unsqueeze(1) # (B, 1, d) — broadcasts over S + y = (0.5 + plasticity) * h_ffn * (1.0 + 0.1 * gate) + return x + y + + +class VSALMModel(nn.Module): + def __init__(self, vocab, d, d_ffn, n_layers, max_seq): + super().__init__() + self.d = d + self.embed = nn.Embedding(vocab, d) + self.capsule_embed = nn.Embedding(vocab, CAPSULE_DIM) + self.capsule_proj = nn.Linear(CAPSULE_DIM, d, bias=False) + self.pos = nn.Parameter(torch.randn(max_seq + 16, d) * 0.02) + self.out_proj = nn.Linear(d, vocab, bias=False) + self.layers = nn.ModuleList([VSALayerCUDA(d, d_ffn) for _ in range(n_layers)]) + self.scale = math.sqrt(d) + nn.init.normal_(self.embed.weight, 0, 0.02) + nn.init.normal_(self.capsule_embed.weight, 0, 0.01) + + def forward(self, ids): + if ids.ndim == 1: + S = ids.shape[0] + x = 
self.embed(ids) + self.pos[:S] + caps = self.capsule_embed(ids) + x = x + self.capsule_proj(caps) + with torch.no_grad(): + cm = caps.mean(dim=0) + plasticity = cm[14].clamp(0, 1).item() + consolidation = cm[18].clamp(0, 1).item() + for layer in self.layers: + x = layer(x, plasticity=plasticity, consolidation=consolidation) + return self.out_proj(x / self.scale) + + # Batched: (B, S) + B, S = ids.shape + x = self.embed(ids) + self.pos[:S].unsqueeze(0) + caps = self.capsule_embed(ids) + x = x + self.capsule_proj(caps) + with torch.no_grad(): + cm = caps.mean(dim=(0, 1)) + plasticity = cm[14].clamp(0, 1).item() + consolidation = cm[18].clamp(0, 1).item() + for layer in self.layers: + x = layer(x, plasticity=plasticity, consolidation=consolidation) + return self.out_proj(x / self.scale) + + def reset_liquid(self): + for layer in self.layers: + layer.liquid.reset() + + +# ── Load checkpoint ── +model = VSALMModel(vocab, D_MODEL, D_FFN, N_LAYERS, SEQ_LEN).to(device) + +ckpt_path = 'vsa_lm_v1_resume.pt' if os.path.exists('vsa_lm_v1_resume.pt') else 'vsa_lm_best.pt' +checkpoint = torch.load(ckpt_path, map_location=device) +if 'model' in checkpoint: + model.load_state_dict(checkpoint['model'], strict=True) + ckpt_step = checkpoint.get('step', 0) + ckpt_ppl = checkpoint.get('best_ppl', 0) + print(f'Loaded {ckpt_path} (step={ckpt_step}, PPL={ckpt_ppl:.1f})') +else: + model.load_state_dict(checkpoint, strict=True) + ckpt_step = 0 + print(f'Loaded {ckpt_path}') + +n_params = sum(p.numel() for p in model.parameters()) +print(f'VSA-LM v3b: d={D_MODEL}, L={N_LAYERS}, B={BATCH_SIZE}, params={n_params/1e6:.1f}M') + +with torch.no_grad(): + ppl = compute_ppl(model, val_x, val_y) + print(f'Checkpoint PPL: {ppl:.1f}') + +# ── Train ── +optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01) +# Cosine over remaining steps +remaining = max(TRAIN_STEPS - ckpt_step, 1000) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, remaining, eta_min=1e-6) +scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP) + +t0 = time.time() +best_ppl = ppl +step = 0 + +while step < remaining: + perm = torch.randperm(len(train_x)) + for i in range(0, len(perm) - BATCH_SIZE, BATCH_SIZE): + if step >= remaining: + break + + ids_batch = train_x[perm[i:i + BATCH_SIZE]].to(device) + labels_batch = train_y[perm[i:i + BATCH_SIZE]].to(device) + + model.reset_liquid() + optimizer.zero_grad() + + with torch.amp.autocast('cuda', enabled=USE_AMP, dtype=torch.float16): + logits = model(ids_batch) # (B, S, V) + loss = F.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + labels_batch.reshape(-1), + ) + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP) + scaler.step(optimizer) + scaler.update() + scheduler.step() + + if step % VAL_EVERY == 0: + ppl = compute_ppl(model, val_x, val_y) + el = time.time() - t0 + sps = (step + 1) / el if el > 0 else 0 + lr_now = scheduler.get_last_lr()[0] + print(f'step={step:6d} lr={lr_now:.5f} | loss={loss.item():.3f} | ' + f'PPL={ppl:.1f} | {sps:.1f} stp/s (B={BATCH_SIZE})') + if ppl < best_ppl: + best_ppl = ppl + torch.save( + {'model': model.state_dict(), 'step': ckpt_step + step, 'best_ppl': best_ppl}, + 'vsa_lm_v3b_best.pt', + ) + + step += 1 + +print(f'\nDone: {step} steps, best PPL={best_ppl:.1f}, time={time.time()-t0:.0f}s') diff --git a/utils/__init__.py b/utils/__init__.py index 75d0278..8ec1f57 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -5,6 +5,7 @@ """ from .checkpoint import 
load_checkpoint, save_checkpoint +from .grl_checkpoint import load_grl, save_grl from .data import ( ArrayDataset, BatchSampler, @@ -137,6 +138,8 @@ # Checkpoint "save_checkpoint", "load_checkpoint", + "save_grl", + "load_grl", # Device "get_device", "set_device", diff --git a/utils/grl_checkpoint.py b/utils/grl_checkpoint.py new file mode 100644 index 0000000..51facf5 --- /dev/null +++ b/utils/grl_checkpoint.py @@ -0,0 +1,252 @@ +""" +Grilly checkpoint format (.grl) — GRL v1. + +Binary layout: + - Magic ``GRLY`` (4 bytes) + - uint16 format version (1) + - uint16 flags (reserved, 0) + - uint32 reserved + - uint64 metadata_json_offset, metadata_json_length + - uint64 tensor_index_offset, tensor_index_length + - uint64 payload_offset, payload_length + - padding to 64-byte header + +Followed by: + - UTF-8 JSON metadata blob + - UTF-8 JSON tensor index (array of tensor descriptors) + - payload (concatenated tensor bytes, C-contiguous row-major) + +Tensor index entry:: + {"name": str, "dtype": "f32"|"f16"|"i64"|"i32"|"u8", "shape": [int,...], + "offset": int, "length": int} + +``offset`` / ``length`` are byte ranges relative to **start of payload section**. +""" + +from __future__ import annotations + +import json +import struct +from pathlib import Path +from typing import Any + +import numpy as np + +try: + import grilly_core as _grl_core + + _HAS_CPP_GRL = hasattr(_grl_core, "grl_write_file") +except ImportError: + _grl_core = None + _HAS_CPP_GRL = False + +MAGIC = b"GRLY" +FORMAT_VERSION = 1 +HEADER_SIZE = 64 +_FLAG_NONE = 0 + +_DTYPE_TO_STR = { + np.dtype("float32"): "f32", + np.dtype("float16"): "f16", + np.dtype("int64"): "i64", + np.dtype("int32"): "i32", + np.dtype("uint8"): "u8", +} +_STR_TO_DTYPE = {v: k for k, v in _DTYPE_TO_STR.items()} + + +def _flatten_state_dict(d: dict[str, Any], prefix: str = "") -> dict[str, np.ndarray]: + """Flatten nested dict to dotted keys -> ndarray.""" + out: dict[str, np.ndarray] = {} + for k, v in d.items(): + key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + out.update(_flatten_state_dict(v, key)) + elif isinstance(v, np.ndarray): + out[key] = np.ascontiguousarray(v) + else: + try: + out[key] = np.asarray(v) + except Exception: + continue + return out + + +def _unflatten_state_dict(flat: dict[str, np.ndarray]) -> dict[str, Any]: + """Rebuild nested dict from dotted keys.""" + root: dict[str, Any] = {} + for key, arr in flat.items(): + parts = key.split(".") + cur = root + for p in parts[:-1]: + cur = cur.setdefault(p, {}) + cur[parts[-1]] = arr + return root + + +def save_grl( + filepath: str | Path, + state_dict: dict[str, Any], + *, + metadata: dict[str, Any] | None = None, +) -> None: + """Write a GRL v1 checkpoint from a (possibly nested) state dict of numpy arrays.""" + path = Path(filepath) + path.parent.mkdir(parents=True, exist_ok=True) + + flat = _flatten_state_dict(state_dict) + meta = { + "schema": "grilly.checkpoint.v1", + "framework": "grilly", + **(metadata or {}), + } + meta_bytes = json.dumps(meta, separators=(",", ":"), sort_keys=True).encode("utf-8") + + payload_parts: list[bytes] = [] + index_entries: list[dict[str, Any]] = [] + offset = 0 + for name in sorted(flat.keys()): + arr = flat[name] + if not isinstance(arr, np.ndarray): + arr = np.asarray(arr) + dt = arr.dtype + if dt not in _DTYPE_TO_STR: + arr = arr.astype(np.float32) + dt = arr.dtype + raw = arr.tobytes(order="C") + dtype_str = _DTYPE_TO_STR.get(dt, "f32") + index_entries.append( + { + "name": name, + "dtype": dtype_str, + "shape": list(arr.shape), + 
"offset": offset, + "length": len(raw), + } + ) + payload_parts.append(raw) + offset += len(raw) + + payload = b"".join(payload_parts) + index_bytes = json.dumps(index_entries, separators=(",", ":")).encode("utf-8") + + # Layout: header | meta | index | payload + meta_off = HEADER_SIZE + meta_len = len(meta_bytes) + idx_off = meta_off + meta_len + idx_len = len(index_bytes) + pay_off = idx_off + idx_len + pay_len = len(payload) + + if _HAS_CPP_GRL: + _grl_core.grl_write_file( + str(path), + meta_bytes.decode("utf-8"), + index_bytes.decode("utf-8"), + payload, + ) + return + + header = bytearray(HEADER_SIZE) + header[0:4] = MAGIC + struct.pack_into(" dict[str, Any]: + """ + Load GRL v1 checkpoint. Returns a dict with ``metadata``, ``model`` (nested state_dict), + and any extra keys from metadata. + + ``map_location`` is accepted for torch API compatibility; ``\"cpu\"`` keeps arrays + on host; ``\"vulkan\"`` / default leaves numpy arrays (caller uploads to GPU). + """ + _ = map_location + path = Path(filepath) + + if _HAS_CPP_GRL: + meta_json, index_json, pay_bytes = _grl_core.grl_read_file(str(path)) + metadata = json.loads(meta_json) + index_entries = json.loads(index_json) + payload = memoryview(pay_bytes) + flat: dict[str, np.ndarray] = {} + for ent in index_entries: + name = ent["name"] + dtype_s = ent["dtype"] + shape = tuple(ent["shape"]) + off = int(ent["offset"]) + ln = int(ent["length"]) + dt = _STR_TO_DTYPE.get(dtype_s, np.float32) + buf = payload[off : off + ln].tobytes() + arr = np.frombuffer(buf, dtype=dt).reshape(shape) + flat[name] = np.array(arr, copy=True) + # Roundtrip semantics: return what the user saved, not a forced + # ``{'model': ..., 'metadata': ...}`` wrapper. ``torch.save`` / + # ``torch.load`` users expect ``ck == original_payload``. + out: dict[str, Any] = _unflatten_state_dict(flat) + # Add metadata only if it doesn't collide with a user key. + if "metadata" not in out: + out["metadata"] = metadata + # Promote common training scalars from metadata to top level when + # the user didn't include them in the original payload (back-compat + # with checkpoints saved by older grilly that only put step in meta). + for k in ("step", "training_step", "best_ppl", "epoch"): + if k in metadata and k not in out: + out[k] = metadata[k] + return out + + data = path.read_bytes() + if len(data) < HEADER_SIZE or data[0:4] != MAGIC: + raise ValueError(f"Not a GRL file or corrupt magic: {path}") + + version = struct.unpack_from(" dict[str, Any]: + """Map load_grl output to a torch-like checkpoint dict (``model``, ``step``, ...).""" + d = dict(grl_dict) + if "model" in d and isinstance(d["model"], dict): + # torch often uses 'model' key for state_dict + pass + return d diff --git a/utils/initialization.py b/utils/initialization.py index 8c57471..a2ee675 100644 --- a/utils/initialization.py +++ b/utils/initialization.py @@ -116,3 +116,33 @@ def kaiming_normal_( std = gain / np.sqrt(fan) tensor[:] = np.random.randn(*tensor.shape).astype(tensor.dtype) * std return tensor + + +def _writable_array(tensor): + """Return a writable ndarray view for ``tensor``. + + The torch_api ``Tensor`` / autograd ``Variable`` / ``Parameter`` wrappers + don't all support ``__setitem__`` directly — but they expose the backing + ndarray via ``.data``. Detect those and return the underlying buffer so + in-place init works for both raw arrays and wrapped tensors. 
+ """ + if isinstance(tensor, np.ndarray): + return tensor + if hasattr(tensor, "data") and isinstance(getattr(tensor, "data", None), np.ndarray): + return tensor.data + # Last-resort: hope it's array-like and writable. + return np.asarray(tensor) + + +def normal_(tensor, mean: float = 0.0, std: float = 1.0): + """In-place normal distribution (PyTorch ``nn.init.normal_``).""" + arr = _writable_array(tensor) + arr[:] = np.random.normal(mean, std, arr.shape).astype(arr.dtype, copy=False) + return tensor + + +def uniform_(tensor, a: float = 0.0, b: float = 1.0): + """In-place uniform distribution (PyTorch ``nn.init.uniform_``).""" + arr = _writable_array(tensor) + arr[:] = np.random.uniform(a, b, arr.shape).astype(arr.dtype, copy=False) + return tensor diff --git a/utils/tensor_conversion.py b/utils/tensor_conversion.py index 322cead..9cdf442 100644 --- a/utils/tensor_conversion.py +++ b/utils/tensor_conversion.py @@ -31,6 +31,20 @@ from .device_manager import get_device_manager +_vulkan_backend_cache = None + + +def _get_vulkan_backend(): + """Get cached Vulkan backend without using deprecated Compute().""" + global _vulkan_backend_cache + if _vulkan_backend_cache is not None: + return _vulkan_backend_cache + try: + _vulkan_backend_cache = get_device_manager().vulkan + return _vulkan_backend_cache + except Exception: + return None + def to_vulkan( tensor: np.ndarray | Any, keep_on_gpu: bool = False @@ -346,23 +360,75 @@ def mark_cpu_modified(self): except Exception: pass + def _try_bind_cpp_gpu_buffer(self) -> bool: + """If C++ Tensor already holds a GPU buffer, mirror its VkBuffer handle for Python dispatch. + + Avoids allocating a second pooled buffer and re-uploading when the tensor is + GPU-resident only (e.g. output of a prior op) but Python slots were not set. 
+ """ + if self._pooled_buffer is not None or self._gpu_buffer is not None: + return True + try: + h = int(self._t.gpu_handle_if_valid()) + except Exception: + return False + if h == 0: + return False + try: + import vulkan as vk + + self._gpu_buffer = vk.ffi.cast("VkBuffer", h) + self._gpu_memory = None + if self._core is None: + backend = _get_vulkan_backend() + if backend is not None: + self._core = backend.core + self._gpu_valid = True + self._uploaded = True + self._cpu_valid = bool(self._t.on_cpu) + return True + except Exception: + return False + + def prepare_for_dispatch(self) -> None: + """Ensure this tensor can supply a Vulkan buffer for kernels (minimal CPU↔GPU traffic).""" + if self._try_bind_cpp_gpu_buffer(): + return + if self._t.on_gpu: + try: + self._t.ensure_gpu() + except Exception: + pass + if self._try_bind_cpp_gpu_buffer(): + return + self._ensure_uploaded() + def _ensure_uploaded(self): - if self._gpu_valid: + if self._try_bind_cpp_gpu_buffer(): return if not self._cpu_valid and not self._t.on_cpu: + if self._t.on_gpu: + try: + self._t.ensure_gpu() + except Exception: + pass + if self._try_bind_cpp_gpu_buffer(): + return raise RuntimeError("Cannot upload: no valid CPU data") try: self._t.ensure_gpu() self._gpu_valid = True self._uploaded = True + if self._try_bind_cpp_gpu_buffer(): + return except Exception: # C++ Tensor may not have Vulkan backend — fall back to Python path try: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") cpu_data = self._t.numpy() size = cpu_data.nbytes @@ -425,9 +491,9 @@ def _ensure_downloaded(self): core = self._core if core is None: try: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core except Exception: @@ -453,9 +519,9 @@ def _ensure_downloaded(self): def _download_via_staging(self): core = self._core if core is None: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core @@ -611,9 +677,9 @@ def _ensure_uploaded(self): if not self._cpu_valid: raise RuntimeError("Cannot upload: no valid CPU data") try: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") size = self._cpu_data.nbytes try: from grilly.backend.buffer_pool import acquire_buffer @@ -663,9 +729,9 @@ def _ensure_downloaded(self): return core = self._core if core is None: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core self._cpu_data = core._download_buffer( @@ -678,9 +744,9 @@ def _ensure_downloaded(self): def _download_via_staging(self): core = self._core if core is None: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core pooled = getattr(self, "_pooled_buffer", None) @@ -712,6 +778,9 @@ def mark_cpu_modified(self): self._cpu_valid = True self._gpu_valid = False + def prepare_for_dispatch(self) -> None: + self._ensure_uploaded() + @property def is_leaf(self) -> bool: 
"""True if this tensor was created directly (not by an operation).""" diff --git a/uv.lock b/uv.lock index 68e9a98..302dcdc 100644 --- a/uv.lock +++ b/uv.lock @@ -522,7 +522,7 @@ wheels = [ [[package]] name = "grilly" -version = "0.5.6" +version = "0.6.1" source = { editable = "." } dependencies = [ { name = "numpy" }, @@ -534,26 +534,36 @@ accel = [ ] all = [ { name = "black" }, + { name = "huggingface-hub" }, { name = "isort" }, { name = "mkdocs-material" }, { name = "mypy" }, { name = "numba" }, + { name = "protobuf" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "sentencepiece" }, + { name = "tokenizers" }, + { name = "transformers" }, ] dev = [ { name = "black" }, + { name = "huggingface-hub" }, { name = "isort" }, { name = "mkdocs-material" }, { name = "mypy" }, + { name = "protobuf" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "sentencepiece" }, + { name = "tokenizers" }, + { name = "transformers" }, ] full = [ { name = "blake3" }, @@ -572,6 +582,10 @@ huggingface = [ onnx = [ { name = "onnx" }, ] +tokenizer = [ + { name = "huggingface-hub" }, + { name = "tokenizers" }, +] torch = [ { name = "torch" }, ] @@ -581,6 +595,8 @@ requires-dist = [ { name = "black", marker = "extra == 'dev'", specifier = ">=23.7.0" }, { name = "blake3", marker = "extra == 'full'", specifier = ">=1.0.8" }, { name = "grilly", extras = ["dev", "accel"], marker = "extra == 'all'" }, + { name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" }, + { name = "huggingface-hub", marker = "extra == 'tokenizer'", specifier = ">=0.20.0" }, { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" }, { name = "mkdocs-material", marker = "extra == 'dev'", specifier = ">=9.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.5.0" }, @@ -589,6 +605,7 @@ requires-dist = [ { name = "numpy" }, { name = "onnx", marker = "extra == 'full'", specifier = ">=1.15.0" }, { name = "onnx", marker = "extra == 'onnx'", specifier = ">=1.15.0" }, + { name = "protobuf", marker = "extra == 'dev'", specifier = ">=4.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, { name = "pytest-benchmark", marker = "extra == 'dev'", specifier = ">=5.2.3" }, @@ -596,14 +613,18 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "sentence-transformers", marker = "extra == 'full'", specifier = ">=5.2.0" }, { name = "sentence-transformers", marker = "extra == 'huggingface'", specifier = ">=5.2.0" }, + { name = "sentencepiece", marker = "extra == 'dev'", specifier = ">=0.2.0" }, { name = "spacy", marker = "extra == 'full'", specifier = ">=3.8.11" }, + { name = "tokenizers", marker = "extra == 'dev'", specifier = ">=0.15.0" }, + { name = "tokenizers", marker = "extra == 'tokenizer'", specifier = ">=0.15.0" }, { name = "torch", marker = "extra == 'full'", specifier = ">=2.10.0" }, { name = "torch", marker = "extra == 'torch'", specifier = ">=2.10.0" }, + { name = "transformers", marker = "extra == 'dev'", specifier = ">=4.57.6" }, { name = "transformers", marker = "extra == 'full'", specifier = ">=4.57.6" }, { name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.57.6" }, { name = "vulkan", marker = "extra == 'full'", specifier = ">=1.3.0" }, ] -provides-extras = 
["full", "torch", "huggingface", "onnx", "dev", "accel", "all"] +provides-extras = ["full", "torch", "huggingface", "onnx", "dev", "accel", "tokenizer", "all"] [[package]] name = "hf-xet" @@ -1947,6 +1968,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/d0/3b2897ef6a0c0c801e9fecca26bcc77081648e38e8c772885ebdd8d7d252/sentence_transformers-5.2.0-py3-none-any.whl", hash = "sha256:aa57180f053687d29b08206766ae7db549be5074f61849def7b17bf0b8025ca2", size = 493748, upload-time = "2025-12-11T14:12:29.516Z" }, ] +[[package]] +name = "sentencepiece" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/be/32ce495aa1d0e0c323dcb1ba87096037358edee539cac5baf8755a6bd396/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:57cae326c8727de58c85977b175af132a7138d84c764635d7e71bbee7e774133", size = 1943152, upload-time = "2025-08-12T06:59:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/88/7e/ff23008899a58678e98c6ff592bf4d368eee5a71af96d0df6b38a039dd4f/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:56dd39a3c4d6493db3cdca7e8cc68c6b633f0d4195495cbadfcf5af8a22d05a6", size = 1325651, upload-time = "2025-08-12T06:59:41.536Z" }, + { url = "https://files.pythonhosted.org/packages/19/84/42eb3ce4796777a1b5d3699dfd4dca85113e68b637f194a6c8d786f16a04/sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d9381351182ff9888cc80e41c632e7e274b106f450de33d67a9e8f6043da6f76", size = 1253645, upload-time = "2025-08-12T06:59:42.903Z" }, + { url = "https://files.pythonhosted.org/packages/89/fa/d3d5ebcba3cb9e6d3775a096251860c41a6bc53a1b9461151df83fe93255/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167", size = 1316273, upload-time = "2025-08-12T06:59:44.476Z" }, + { url = "https://files.pythonhosted.org/packages/04/88/14f2f4a2b922d8b39be45bf63d79e6cd3a9b2f248b2fcb98a69b12af12f5/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b", size = 1387881, upload-time = "2025-08-12T06:59:46.09Z" }, + { url = "https://files.pythonhosted.org/packages/fd/b8/903e5ccb77b4ef140605d5d71b4f9e0ad95d456d6184688073ed11712809/sentencepiece-0.2.1-cp312-cp312-win32.whl", hash = "sha256:a483fd29a34c3e34c39ac5556b0a90942bec253d260235729e50976f5dba1068", size = 999540, upload-time = "2025-08-12T06:59:48.023Z" }, + { url = "https://files.pythonhosted.org/packages/2d/81/92df5673c067148c2545b1bfe49adfd775bcc3a169a047f5a0e6575ddaca/sentencepiece-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4cdc7c36234fda305e85c32949c5211faaf8dd886096c7cea289ddc12a2d02de", size = 1054671, upload-time = "2025-08-12T06:59:49.895Z" }, + { url = "https://files.pythonhosted.org/packages/fe/02/c5e3bc518655d714622bec87d83db9cdba1cd0619a4a04e2109751c4f47f/sentencepiece-0.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:daeb5e9e9fcad012324807856113708614d534f596d5008638eb9b40112cd9e4", size = 1033923, upload-time = "2025-08-12T06:59:51.952Z" }, + { url = 
"https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, upload-time = "2025-08-12T06:59:55.073Z" }, + { url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" }, + { url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" }, + { url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" }, + { url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" }, + { url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" }, + { url = "https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" }, + { url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" }, + { url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" }, + { url = 
"https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" }, + { url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" }, + { url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" }, + { url = "https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" }, + { url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/24/9c/89eb8b2052f720a612478baf11c8227dcf1dc28cd4ea4c0c19506b5af2a2/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5d0350b686c320068702116276cfb26c066dc7e65cfef173980b11bb4d606719", size = 1943147, upload-time = "2025-08-12T07:00:21.809Z" }, + { url = "https://files.pythonhosted.org/packages/82/0b/a1432bc87f97c2ace36386ca23e8bd3b91fb40581b5e6148d24b24186419/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c7f54a31cde6fa5cb030370566f68152a742f433f8d2be458463d06c208aef33", size = 1325624, upload-time = "2025-08-12T07:00:23.289Z" }, + { url = "https://files.pythonhosted.org/packages/ea/99/bbe054ebb5a5039457c590e0a4156ed073fb0fe9ce4f7523404dd5b37463/sentencepiece-0.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c83b85ab2d6576607f31df77ff86f28182be4a8de6d175d2c33ca609925f5da1", size = 1253670, upload-time = "2025-08-12T07:00:24.69Z" }, + { url = "https://files.pythonhosted.org/packages/19/ad/d5c7075f701bd97971d7c2ac2904f227566f51ef0838dfbdfdccb58cd212/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1855f57db07b51fb51ed6c9c452f570624d2b169b36f0f79ef71a6e6c618cd8b", size = 1316247, upload-time = "2025-08-12T07:00:26.435Z" }, + { url = "https://files.pythonhosted.org/packages/fb/03/35fbe5f3d9a7435eebd0b473e09584bd3cc354ce118b960445b060d33781/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01e6912125cb45d3792f530a4d38f8e21bf884d6b4d4ade1b2de5cf7a8d2a52b", size = 1387894, upload-time = "2025-08-12T07:00:28.339Z" }, + { url = "https://files.pythonhosted.org/packages/dc/aa/956ef729aafb6c8f9c443104c9636489093bb5c61d6b90fc27aa1a865574/sentencepiece-0.2.1-cp314-cp314-win32.whl", hash = "sha256:c415c9de1447e0a74ae3fdb2e52f967cb544113a3a5ce3a194df185cbc1f962f", size = 1096698, upload-time = "2025-08-12T07:00:29.764Z" }, + { url 
= "https://files.pythonhosted.org/packages/b8/cb/fe400d8836952cc535c81a0ce47dc6875160e5fedb71d2d9ff0e9894c2a6/sentencepiece-0.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:881b2e44b14fc19feade3cbed314be37de639fc415375cefaa5bc81a4be137fd", size = 1155115, upload-time = "2025-08-12T07:00:32.865Z" }, + { url = "https://files.pythonhosted.org/packages/32/89/047921cf70f36c7b6b6390876b2399b3633ab73b8d0cb857e5a964238941/sentencepiece-0.2.1-cp314-cp314-win_arm64.whl", hash = "sha256:2005242a16d2dc3ac5fe18aa7667549134d37854823df4c4db244752453b78a8", size = 1133890, upload-time = "2025-08-12T07:00:34.763Z" }, + { url = "https://files.pythonhosted.org/packages/a1/11/5b414b9fae6255b5fb1e22e2ed3dc3a72d3a694e5703910e640ac78346bb/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:a19adcec27c524cb7069a1c741060add95f942d1cbf7ad0d104dffa0a7d28a2b", size = 1946081, upload-time = "2025-08-12T07:00:36.97Z" }, + { url = "https://files.pythonhosted.org/packages/77/eb/7a5682bb25824db8545f8e5662e7f3e32d72a508fdce086029d89695106b/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:e37e4b4c4a11662b5db521def4e44d4d30ae69a1743241412a93ae40fdcab4bb", size = 1327406, upload-time = "2025-08-12T07:00:38.669Z" }, + { url = "https://files.pythonhosted.org/packages/03/b0/811dae8fb9f2784e138785d481469788f2e0d0c109c5737372454415f55f/sentencepiece-0.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:477c81505db072b3ab627e7eab972ea1025331bd3a92bacbf798df2b75ea86ec", size = 1254846, upload-time = "2025-08-12T07:00:40.611Z" }, + { url = "https://files.pythonhosted.org/packages/ef/23/195b2e7ec85ebb6a547969f60b723c7aca5a75800ece6cc3f41da872d14e/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:010f025a544ef770bb395091d57cb94deb9652d8972e0d09f71d85d5a0816c8c", size = 1315721, upload-time = "2025-08-12T07:00:42.914Z" }, + { url = "https://files.pythonhosted.org/packages/7e/aa/553dbe4178b5f23eb28e59393dddd64186178b56b81d9b8d5c3ff1c28395/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:733e59ff1794d26db706cd41fc2d7ca5f6c64a820709cb801dc0ea31780d64ab", size = 1387458, upload-time = "2025-08-12T07:00:44.56Z" }, + { url = "https://files.pythonhosted.org/packages/66/7c/08ff0012507297a4dd74a5420fdc0eb9e3e80f4e88cab1538d7f28db303d/sentencepiece-0.2.1-cp314-cp314t-win32.whl", hash = "sha256:d3233770f78e637dc8b1fda2cd7c3b99ec77e7505041934188a4e7fe751de3b0", size = 1099765, upload-time = "2025-08-12T07:00:46.058Z" }, + { url = "https://files.pythonhosted.org/packages/91/d5/2a69e1ce15881beb9ddfc7e3f998322f5cedcd5e4d244cb74dade9441663/sentencepiece-0.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:5e4366c97b68218fd30ea72d70c525e6e78a6c0a88650f57ac4c43c63b234a9d", size = 1157807, upload-time = "2025-08-12T07:00:47.673Z" }, + { url = "https://files.pythonhosted.org/packages/f3/16/54f611fcfc2d1c46cbe3ec4169780b2cfa7cf63708ef2b71611136db7513/sentencepiece-0.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:105e36e75cbac1292642045458e8da677b2342dcd33df503e640f0b457cb6751", size = 1136264, upload-time = "2025-08-12T07:00:49.485Z" }, +] + [[package]] name = "setuptools" version = "80.10.1"