From 261d13effe3022be3e13467b6a15b632eb83bc3a Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Tue, 31 Mar 2026 15:57:14 -0400 Subject: [PATCH 01/17] install update for colab --- README.md | 4 ++-- scripts/install.sh | 21 ++++++++++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ec0185a..da9a656 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,8 @@ curl -sSL https://raw.githubusercontent.com/Grillcheese-AI/grilly/main/scripts/i On Colab: ```python -# Recommended: fast mode for Colab (5 min instead of 30) -!wget -qO- https://raw.githubusercontent.com/Grillcheese-AI/grilly/main/scripts/install.sh | bash -s -- --fast +# Recommended: Colab mode (Vulkan 1.3 + fast build + NVIDIA ICD, ~5 min) +!wget -qO- https://raw.githubusercontent.com/Grillcheese-AI/grilly/main/scripts/install.sh | bash -s -- --colab ``` This installs system deps, downloads and builds Vulkan SDK 1.4, compiles the grilly C++ extension, and installs the Python package. The `--fast` flag builds only the components grilly needs (shaderc, loader, headers) and skips validation layers. diff --git a/scripts/install.sh b/scripts/install.sh index e7dbcfa..8f2cb3f 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -13,7 +13,9 @@ set -euo pipefail -VULKAN_SDK_VERSION="1.4.341.1" +VULKAN_SDK_VERSION_14="1.4.341.1" +VULKAN_SDK_VERSION_13="1.3.296.0" +VULKAN_SDK_VERSION="$VULKAN_SDK_VERSION_14" GRILLY_DIR="$(cd "$(dirname "$0")/.." 
2>/dev/null && pwd || echo "$(pwd)")" BUILD_DIR="${GRILLY_DIR}/build" JOBS="${JOBS:-$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)}" @@ -23,7 +25,15 @@ FAST_MODE=false for arg in "$@"; do case "$arg" in --fast|-f) FAST_MODE=true ;; - --help|-h) echo "Usage: $0 [--fast]"; echo " --fast Only build shaderc + loader (5 min vs 30 min)"; exit 0 ;; + --colab) FAST_MODE=true; VULKAN_SDK_VERSION="$VULKAN_SDK_VERSION_13" ;; + --vulkan-13|--v13) VULKAN_SDK_VERSION="$VULKAN_SDK_VERSION_13" ;; + --help|-h) + echo "Usage: $0 [OPTIONS]" + echo " --fast Only build shaderc + loader (~5 min vs ~30 min)" + echo " --colab Colab mode: Vulkan 1.3 + fast build + NVIDIA ICD" + echo " --v13 Use Vulkan SDK 1.3 (for older drivers)" + exit 0 + ;; esac done @@ -97,8 +107,13 @@ install_vulkan_linux() { info "Downloading Vulkan SDK ${VULKAN_SDK_VERSION}..." local TMP_DIR=$(mktemp -d) local TAR="vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.xz" + # Try xz first, fall back to gz for older SDK versions wget -q "https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/${TAR}" \ - -O "${TMP_DIR}/${TAR}" || fail "Failed to download Vulkan SDK" + -O "${TMP_DIR}/${TAR}" 2>/dev/null || { + TAR="vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.gz" + wget -q "https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/${TAR}" \ + -O "${TMP_DIR}/${TAR}" || fail "Failed to download Vulkan SDK ${VULKAN_SDK_VERSION}" + } info "Extracting Vulkan SDK..." 
sudo mkdir -p /opt/vulkan From cfbe2a9f869b72e4ac042c9dfbd0b4dad8dc941b Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Tue, 31 Mar 2026 17:26:11 -0400 Subject: [PATCH 02/17] Update install.sh --- scripts/install.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/install.sh b/scripts/install.sh index 8f2cb3f..208e9c9 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -72,9 +72,8 @@ if [ "$PLATFORM" = "linux" ]; then # Colab/cloud: install NVIDIA Vulkan ICD if NVIDIA GPU detected if [ -d /proc/driver/nvidia ] || command -v nvidia-smi &>/dev/null; then info "NVIDIA GPU detected — installing Vulkan ICD driver..." - sudo apt-get install -y -qq nvidia-driver-550 2>/dev/null || \ - sudo apt-get install -y -qq nvidia-drivers-550 2>/dev/null || \ - warn "Could not install nvidia-driver-550 — Vulkan may not see GPU" + sudo apt-get install -y -qq nvidia-driver 2>/dev/null || \ + warn "Could not install nvidia-driver — Vulkan may not see GPU" fi ok "System deps installed (apt)" From b02403089519c64c1160a54956db71f72cb0dcaa Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Thu, 2 Apr 2026 14:07:01 -0400 Subject: [PATCH 03/17] feat: grid-cell + addition-linear shaders + JIT snippets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - grid-cell.glsl: hexagonal 3-wave grid cell firing pattern Three cosine waves at 60° intervals, orientation + phase + spacing params - addition-linear.glsl: multiplication-free L1 distance linear layer y = -||W - x||₁ + bias, no multiplications, only add/sub/abs - snippets/grid_cell.glsl: JIT-fusable op_grid_cell() - snippets/addition_linear.glsl: op_addition_linear(), op_sign_activation(), op_additive_sigmoid() for JIT fusion pipeline Co-Authored-By: Claude Opus 4.6 (1M context) --- shaders/addition-linear.glsl | 65 +++++++++++++++++++++++++ shaders/grid-cell.glsl | 69 +++++++++++++++++++++++++++ shaders/snippets/addition_linear.glsl | 17 +++++++ 
shaders/snippets/grid_cell.glsl | 10 ++++ 4 files changed, 161 insertions(+) create mode 100644 shaders/addition-linear.glsl create mode 100644 shaders/grid-cell.glsl create mode 100644 shaders/snippets/addition_linear.glsl create mode 100644 shaders/snippets/grid_cell.glsl diff --git a/shaders/addition-linear.glsl b/shaders/addition-linear.glsl new file mode 100644 index 0000000..f46676f --- /dev/null +++ b/shaders/addition-linear.glsl @@ -0,0 +1,65 @@ +#version 450 + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +// Input vector (batch_size, in_features) +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +// Weight patterns / templates (out_features, in_features) +layout(set = 0, binding = 1) readonly buffer Weights { + float weight_patterns[]; +}; + +// Bias (out_features) — optional, can be zeros +layout(set = 0, binding = 2) readonly buffer Bias { + float bias[]; +}; + +// Output (batch_size, out_features) +layout(set = 0, binding = 3) buffer Output { + float output_data[]; +}; + +// Parameters +layout(push_constant) uniform PushConsts { + uint batch_size; + uint in_features; + uint out_features; + uint use_bias; // 0 or 1 +}; + +// Addition-only linear: y[b][o] = -sum_i |weight[o][i] - input[b][i]| + bias[o] +// No multiplications — only additions, subtractions, absolute values. +// This is a radial basis function using Manhattan (L1) distance. +// The closer the input is to weight row o, the higher (less negative) the output. 
+void main() { + uint gID = gl_GlobalInvocationID.x; + + // Each thread handles one (batch, output) pair + uint total = batch_size * out_features; + if (gID >= total) { + return; + } + + uint b = gID / out_features; // batch index + uint o = gID % out_features; // output neuron index + + // Compute L1 distance: sum_i |w[o][i] - x[b][i]| + float dist = 0.0; + uint w_offset = o * in_features; + uint x_offset = b * in_features; + + for (uint i = 0; i < in_features; i++) { + dist += abs(weight_patterns[w_offset + i] - input_data[x_offset + i]); + } + + // Output = -distance + bias (closer → higher value) + float result = -dist; + if (use_bias == 1) { + result += bias[o]; + } + + output_data[gID] = result; +} diff --git a/shaders/grid-cell.glsl b/shaders/grid-cell.glsl new file mode 100644 index 0000000..61ef4b8 --- /dev/null +++ b/shaders/grid-cell.glsl @@ -0,0 +1,69 @@ +#version 450 + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +// Current agent position (2D) +layout(set = 0, binding = 0) readonly buffer AgentPosition { + float agent_pos[]; // [x, y] +}; + +// Grid cell parameters (per neuron) +layout(set = 0, binding = 1) readonly buffer GridSpacings { + float spacings[]; // [s0, s1, ...] +}; + +layout(set = 0, binding = 2) readonly buffer GridOrientations { + float orientations[]; // [theta0, theta1, ...] +}; + +layout(set = 0, binding = 3) readonly buffer GridPhases { + float phases[]; // [px0, py0, px1, py1, ...] +}; + +// Output firing rates +layout(set = 0, binding = 4) buffer FiringRates { + float rates[]; +}; + +// Parameters +layout(push_constant) uniform PushConsts { + uint n_neurons; + float max_rate; // e.g. 
25 Hz +}; + +// Hexagonal 3-wave grid cell activation +// Three cosine waves at 60° intervals create hexagonal firing pattern +// r = max_rate * relu((cos(u1) + cos(u2) + cos(u3)) / 3 + 0.5) +void main() { + uint gID = gl_GlobalInvocationID.x; + + if (gID >= n_neurons) { + return; + } + + float x = agent_pos[0]; + float y = agent_pos[1]; + + // Rotation by grid orientation + float theta = orientations[gID]; + float cos_t = cos(theta); + float sin_t = sin(theta); + float rx = cos_t * x - sin_t * y; + float ry = sin_t * x + cos_t * y; + + // Shift by phase + uint phase_off = gID * 2; + float sx = rx - phases[phase_off]; + float sy = ry - phases[phase_off + 1]; + + // Wave vector: k = 4*pi / (sqrt(3) * spacing) + float k = 4.0 * 3.14159265 / (1.7320508 * spacings[gID]); + + // Three waves at 60° intervals (hexagonal pattern) + float u1 = k * sx; + float u2 = k * (-0.5 * sx + 0.866025 * sy); + float u3 = k * (-0.5 * sx - 0.866025 * sy); + + float val = (cos(u1) + cos(u2) + cos(u3)) / 3.0 + 0.5; + rates[gID] = max_rate * max(val, 0.0); +} diff --git a/shaders/snippets/addition_linear.glsl b/shaders/snippets/addition_linear.glsl new file mode 100644 index 0000000..b5dd95a --- /dev/null +++ b/shaders/snippets/addition_linear.glsl @@ -0,0 +1,17 @@ +// Snippet: Addition-only linear (L1 distance) +// Computes -||w - x||_1 for a single (output, input) pair +// No multiplications — only add, sub, abs +float op_addition_linear(float w, float x) { + return -abs(w - x); +} + +// Snippet: Sign activation with threshold (ternary output) +float op_sign_activation(float x, float threshold) { + return sign(x - threshold); +} + +// Snippet: Additive receptance (sigmoid approximation) +// sigmoid(x) ≈ clamp(0.5 + 0.25 * x, 0, 1) +float op_additive_sigmoid(float x) { + return clamp(0.5 + 0.25 * x, 0.0, 1.0); +} diff --git a/shaders/snippets/grid_cell.glsl b/shaders/snippets/grid_cell.glsl new file mode 100644 index 0000000..b5eed74 --- /dev/null +++ b/shaders/snippets/grid_cell.glsl @@ 
-0,0 +1,10 @@ +// Snippet: Grid cell hexagonal 3-wave activation +// Input: position (rx, ry shifted), spacing, wave vector k +// Output: firing rate [0, max_rate] +float op_grid_cell(float rx, float ry, float k, float max_rate) { + float u1 = k * rx; + float u2 = k * (-0.5 * rx + 0.866025 * ry); + float u3 = k * (-0.5 * rx - 0.866025 * ry); + float val = (cos(u1) + cos(u2) + cos(u3)) / 3.0 + 0.5; + return max_rate * max(val, 0.0); +} From 3c0bb96ec6c512836fd130bd4317f08850deb80e Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Thu, 2 Apr 2026 19:52:21 -0400 Subject: [PATCH 04/17] =?UTF-8?q?feat:=20protobuf=20channels=20=E2=80=94?= =?UTF-8?q?=20C++/Python=20zero-copy=20message=20passing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core infrastructure for typed message passing between Vulkan compute (C++) and CubeMind brain modules (Python). Proto definitions (cpp/proto/grilly_channels.proto): - TensorData, SpikeTrain, SpikeEvent - ExpertWeights, ExpertUpdate, RouteRequest/Response - MemoryCapsule, MemoryQuery/Result - TelemetryEvent, NeurochemState - TrainStepRequest/Result - GrillyCompute RPC service C++ channel (cpp/include/grilly/channels/channel.h): - BaseChannel abstract interface - InProcessChannel: thread-safe queue with listener pattern - MessageEnvelope with type, timestamp, sender, payload Python channel (backend/channels.py): - Channel class: C++ backend with pure-Python fallback - Convenience: send_tensor, receive_tensor, send_spikes, receive_spikes, send_telemetry - Subscriber pattern: ch.on(MessageType.TELEMETRY, callback) pybind11 bindings (cpp/python/bindings_channels.cpp): - MessageType enum, MessageEnvelope, InProcessChannel - numpy array convenience methods Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/channels.py | 221 ++++++++++++++++++++++++++ cpp/include/grilly/channels/channel.h | 127 +++++++++++++++ cpp/proto/grilly_channels.proto | 176 ++++++++++++++++++++ 
cpp/python/bindings_channels.cpp | 108 +++++++++++++ cpp/src/channels/channel.cpp | 72 +++++++++ 5 files changed, 704 insertions(+) create mode 100644 backend/channels.py create mode 100644 cpp/include/grilly/channels/channel.h create mode 100644 cpp/proto/grilly_channels.proto create mode 100644 cpp/python/bindings_channels.cpp create mode 100644 cpp/src/channels/channel.cpp diff --git a/backend/channels.py b/backend/channels.py new file mode 100644 index 0000000..ffe9d3d --- /dev/null +++ b/backend/channels.py @@ -0,0 +1,221 @@ +"""grilly.backend.channels — Python-side channel interface. + +Wraps the C++ InProcessChannel with Pythonic API. +Falls back to a pure-Python implementation if C++ not available. + +Usage: + from grilly.backend.channels import Channel, MessageType + + ch = Channel("brain") + ch.send_tensor(numpy_array, sender="vision") + data = ch.receive_tensor() + + ch.send_spikes(spike_array, n_neurons=64, n_timesteps=10) + + # Subscribe to events + ch.on(MessageType.TELEMETRY_EVENT, lambda msg: print(msg)) +""" + +from __future__ import annotations + +import time +import struct +from collections import defaultdict +from enum import IntEnum +from queue import Queue, Empty +from threading import Lock +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + + +class MessageType(IntEnum): + """Message types matching C++ enum.""" + TENSOR_DATA = 0 + SPIKE_TRAIN = 1 + EXPERT_WEIGHTS = 2 + EXPERT_UPDATE = 3 + ROUTE_REQUEST = 4 + ROUTE_RESPONSE = 5 + MEMORY_CAPSULE = 6 + MEMORY_QUERY = 7 + MEMORY_RESULT = 8 + TELEMETRY_EVENT = 9 + NEUROCHEM_STATE = 10 + TRAIN_STEP_REQUEST = 11 + TRAIN_STEP_RESULT = 12 + + +class Message: + """A channel message with type, payload, and metadata.""" + + __slots__ = ("type", "payload", "sender_id", "timestamp_ns", "metadata") + + def __init__(self, msg_type: MessageType, payload: bytes = b"", + sender_id: str = "", metadata: Dict[str, Any] | None = None): + self.type = msg_type + self.payload = payload + 
self.sender_id = sender_id + self.timestamp_ns = time.time_ns() + self.metadata = metadata or {} + + @property + def size(self) -> int: + return len(self.payload) + + +class Channel: + """High-level channel with C++ backend fallback. + + Tries to use grilly_core.InProcessChannel (C++, zero-copy). + Falls back to pure-Python thread-safe queue. + + Args: + name: Channel name for debugging. + max_queue_size: Maximum queued messages before dropping oldest. + """ + + def __init__(self, name: str = "default", max_queue_size: int = 10000): + self.name = name + self._cpp_channel = None + self._py_queue: Queue = Queue(maxsize=max_queue_size) + self._listeners: Dict[MessageType, List[Callable]] = defaultdict(list) + self._lock = Lock() + + # Try C++ channel + try: + from grilly import _core + if hasattr(_core, "InProcessChannel"): + self._cpp_channel = _core.InProcessChannel(name, max_queue_size) + except Exception: + pass + + @property + def backend(self) -> str: + return "cpp" if self._cpp_channel else "python" + + def send(self, msg: Message) -> None: + """Send a message. Notifies listeners synchronously.""" + # Notify listeners + for callback in self._listeners.get(msg.type, []): + try: + callback(msg) + except Exception: + pass + + if self._cpp_channel: + # TODO: convert Message → C++ MessageEnvelope + pass + + # Python fallback + if self._py_queue.full(): + try: + self._py_queue.get_nowait() # Drop oldest + except Empty: + pass + self._py_queue.put_nowait(msg) + + def receive(self, timeout: float | None = None) -> Message | None: + """Receive next message. 
Returns None if empty.""" + if self._cpp_channel: + # TODO: receive from C++ channel + pass + + try: + return self._py_queue.get(timeout=timeout) + except Empty: + return None + + def has_messages(self) -> bool: + if self._cpp_channel: + return self._cpp_channel.has_messages() + return not self._py_queue.empty() + + def on(self, msg_type: MessageType, callback: Callable) -> None: + """Subscribe to a message type.""" + self._listeners[msg_type].append(callback) + + def queue_size(self) -> int: + if self._cpp_channel: + return self._cpp_channel.queue_size() + return self._py_queue.qsize() + + def clear(self) -> None: + if self._cpp_channel: + self._cpp_channel.clear() + while not self._py_queue.empty(): + try: + self._py_queue.get_nowait() + except Empty: + break + + # ── Convenience methods ────────────────────────────────────────────── + + def send_tensor(self, arr: np.ndarray, sender: str = "python") -> None: + """Send a numpy array as a TENSOR_DATA message.""" + msg = Message( + msg_type=MessageType.TENSOR_DATA, + payload=arr.astype(np.float32).tobytes(), + sender_id=sender, + metadata={"shape": list(arr.shape), "dtype": str(arr.dtype)}, + ) + self.send(msg) + + def receive_tensor(self, shape: tuple | None = None) -> np.ndarray | None: + """Receive a TENSOR_DATA message as numpy array.""" + msg = self.receive() + if msg is None or msg.type != MessageType.TENSOR_DATA: + return None + arr = np.frombuffer(msg.payload, dtype=np.float32) + if shape: + arr = arr.reshape(shape) + elif "shape" in msg.metadata: + arr = arr.reshape(msg.metadata["shape"]) + return arr + + def send_spikes(self, spikes: np.ndarray, n_neurons: int, + n_timesteps: int, sender: str = "python") -> None: + """Send spike train: (timesteps, neurons) flattened.""" + header = struct.pack(" tuple[np.ndarray, int, int] | None: + """Receive spike train → (spikes, n_neurons, n_timesteps).""" + msg = self.receive() + if msg is None or msg.type != MessageType.SPIKE_TRAIN: + return None + n_neurons, 
n_timesteps = struct.unpack(" None: + """Send a telemetry event.""" + import json + payload = json.dumps({ + "component_id": component_id, + "event_type": event_type, + "metrics": metrics or {}, + "step": step, + }).encode() + msg = Message( + msg_type=MessageType.TELEMETRY_EVENT, + payload=payload, + sender_id=component_id, + ) + self.send(msg) + + def stats(self) -> Dict[str, Any]: + return { + "name": self.name, + "backend": self.backend, + "queue_size": self.queue_size(), + "listeners": {t.name: len(cbs) for t, cbs in self._listeners.items()}, + } diff --git a/cpp/include/grilly/channels/channel.h b/cpp/include/grilly/channels/channel.h new file mode 100644 index 0000000..9b1ea1a --- /dev/null +++ b/cpp/include/grilly/channels/channel.h @@ -0,0 +1,127 @@ +/** + * grilly/channels/channel.h — Protobuf-based C++/Python channel interface. + * + * Provides zero-copy message passing between Vulkan compute (C++) + * and CubeMind brain modules (Python) via protobuf serialization. + * + * Two modes: + * 1. In-process: direct pybind11 calls with native proto casters + * 2. IPC: shared memory or Unix domain sockets for multi-process + * + * Usage (C++ side): + * auto channel = grilly::InProcessChannel(); + * channel.send(spike_data); + * auto result = channel.receive(); + * + * Usage (Python side, via pybind11): + * channel = grilly_core.InProcessChannel() + * channel.send_spike_train(spike_data) + * result = channel.receive_tensor() + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace grilly { +namespace channels { + +/** + * MessageType — identifies the protobuf message type without parsing. 
+ */ +enum class MessageType : uint8_t { + TENSOR_DATA = 0, + SPIKE_TRAIN = 1, + EXPERT_WEIGHTS = 2, + EXPERT_UPDATE = 3, + ROUTE_REQUEST = 4, + ROUTE_RESPONSE = 5, + MEMORY_CAPSULE = 6, + MEMORY_QUERY = 7, + MEMORY_RESULT = 8, + TELEMETRY_EVENT = 9, + NEUROCHEM_STATE = 10, + TRAIN_STEP_REQUEST = 11, + TRAIN_STEP_RESULT = 12, +}; + +/** + * MessageEnvelope — header + serialized payload. + */ +struct MessageEnvelope { + MessageType type; + uint64_t timestamp_ns; + std::string sender_id; + std::vector payload; // serialized protobuf bytes + + size_t size() const { return payload.size(); } +}; + +/** + * ChannelListener — callback interface for async message handling. + */ +using ChannelListener = std::function; + +/** + * BaseChannel — abstract channel interface. + */ +class BaseChannel { +public: + virtual ~BaseChannel() = default; + + /// Send a serialized message. + virtual void send(MessageEnvelope envelope) = 0; + + /// Receive next message (blocking). Returns empty if no messages. + virtual MessageEnvelope receive() = 0; + + /// Check if messages are available. + virtual bool has_messages() const = 0; + + /// Register an async listener for a message type. + virtual void subscribe(MessageType type, ChannelListener listener) = 0; + + /// Channel name for debugging. + virtual std::string name() const = 0; +}; + +/** + * InProcessChannel — thread-safe in-process message queue. + * + * For single-process use: C++ compute threads push messages, + * Python pybind11 threads pull them. Zero serialization overhead + * when combined with pybind11_protobuf native casters. 
+ */ +class InProcessChannel : public BaseChannel { +public: + explicit InProcessChannel(const std::string& name = "default", + size_t max_queue_size = 10000); + + void send(MessageEnvelope envelope) override; + MessageEnvelope receive() override; + bool has_messages() const override; + void subscribe(MessageType type, ChannelListener listener) override; + std::string name() const override { return name_; } + + /// Queue size for monitoring. + size_t queue_size() const; + + /// Clear all queued messages. + void clear(); + +private: + std::string name_; + size_t max_queue_size_; + mutable std::mutex mutex_; + std::queue queue_; + std::unordered_map> listeners_; +}; + +} // namespace channels +} // namespace grilly diff --git a/cpp/proto/grilly_channels.proto b/cpp/proto/grilly_channels.proto new file mode 100644 index 0000000..bf03ac4 --- /dev/null +++ b/cpp/proto/grilly_channels.proto @@ -0,0 +1,176 @@ +// grilly_channels.proto — Core message types for grilly C++/Python channels. +// +// These messages define the wire format for zero-copy data passing between +// Vulkan compute shaders (C++) and CubeMind's Python brain modules. 
+ +syntax = "proto3"; +package grilly; + +// ═══════════════════════════════════════════════════════════════════════════════ +// Tensor data +// ═══════════════════════════════════════════════════════════════════════════════ + +message TensorData { + repeated uint32 shape = 1; // e.g., [batch, seq, dim] + string dtype = 2; // "float32", "float16", "int8", "int4" + bytes data = 3; // raw bytes (packed) + string device = 4; // "cpu", "vulkan0", "cuda0" + uint64 buffer_id = 5; // GPU buffer handle (zero-copy) +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// SNN spike trains +// ═══════════════════════════════════════════════════════════════════════════════ + +message SpikeTrain { + uint32 n_neurons = 1; + uint32 n_timesteps = 2; + repeated float spikes = 3 [packed = true]; // flattened (timesteps, neurons) + float dt = 4; // timestep duration +} + +message SpikeEvent { + uint32 neuron_id = 1; + uint32 timestep = 2; + float value = 3; // multi-bit spike level (0..L) +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Expert / MoE messages +// ═══════════════════════════════════════════════════════════════════════════════ + +message ExpertWeights { + uint32 expert_id = 1; + TensorData weights = 2; // quantized weights (INT4/INT8) + TensorData scales = 3; // per-block scale factors + string quant_type = 4; // "int4", "int8", "float16" +} + +message ExpertUpdate { + uint32 expert_id = 1; + float q_value = 2; // bandit Q-value + float charge = 3; // HE-MoE charge (+1/-1) + repeated float mu = 4 [packed = true]; // expert position in RKHS + float a_trace = 5; // activity trace + repeated float e_trace = 6 [packed = true]; // error trace + uint64 n_uses = 7; +} + +message RouteRequest { + TensorData input = 1; + uint32 top_k = 2; + float temperature = 3; +} + +message RouteResponse { + repeated uint32 expert_ids = 1 [packed = true]; + repeated float weights = 2 [packed = true]; + repeated 
float scores = 3 [packed = true]; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Memory / Hippocampal capsules +// ═══════════════════════════════════════════════════════════════════════════════ + +message MemoryCapsule { + string capsule_id = 1; + TensorData context = 2; // input context at storage time + repeated uint32 expert_ids = 3 [packed = true]; // which experts were active + TensorData error = 4; // prediction error + uint64 timestamp_ns = 5; + float surprise = 6; // error magnitude + repeated string tags = 7; // "novel", "high_error", etc. +} + +message MemoryQuery { + TensorData query = 1; + uint32 top_k = 2; + float min_similarity = 3; + string filter_tag = 4; // optional tag filter +} + +message MemoryResult { + repeated MemoryCapsule capsules = 1; + repeated float similarities = 2 [packed = true]; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Telemetry / Observability +// ═══════════════════════════════════════════════════════════════════════════════ + +message TelemetryEvent { + string component_id = 1; + string event_type = 2; // "forward", "spawn", "charge_flip", etc. 
+ map metrics = 3; // key-value metrics + map labels = 4; // key-value labels + uint64 timestamp_ns = 5; + uint32 step = 6; +} + +message TelemetryBatch { + repeated TelemetryEvent events = 1; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Neurochemistry +// ═══════════════════════════════════════════════════════════════════════════════ + +message NeurochemState { + float dopamine = 1; + float serotonin = 2; + float cortisol = 3; + float noradrenaline = 4; + float oxytocin = 5; + string emotion = 6; // Lövheim cube classification + float arousal = 7; + float valence = 8; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Training +// ═══════════════════════════════════════════════════════════════════════════════ + +message TrainStepRequest { + TensorData input = 1; + TensorData target = 2; + float learning_rate = 3; + bool extract_logits = 4; +} + +message TrainStepResult { + float loss = 1; + float loss_ce = 2; + float loss_kd = 3; + uint32 n_experts = 4; + bool spawned = 5; + float residual_ema = 6; + TelemetryEvent telemetry = 7; + TensorData teacher_logits = 8; // optional, if extract_logits=true +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// RPC Service definitions +// ═══════════════════════════════════════════════════════════════════════════════ + +service GrillyCompute { + // Core tensor ops + rpc Linear(TensorData) returns (TensorData); + rpc Relu(TensorData) returns (TensorData); + rpc Gelu(TensorData) returns (TensorData); + + // SNN ops + rpc GIFNeuronStep(SpikeTrain) returns (SpikeTrain); + rpc SynapsisFwd(TensorData) returns (TensorData); + rpc STDPUpdate(SpikeTrain) returns (ExpertWeights); + + // MoE routing + rpc Route(RouteRequest) returns (RouteResponse); + rpc UpdateExpert(ExpertUpdate) returns (TelemetryEvent); + + // Memory + rpc StoreMemory(MemoryCapsule) returns (TelemetryEvent); + rpc QueryMemory(MemoryQuery) 
returns (MemoryResult); + + // Training + rpc TrainStep(TrainStepRequest) returns (TrainStepResult); +} diff --git a/cpp/python/bindings_channels.cpp b/cpp/python/bindings_channels.cpp new file mode 100644 index 0000000..5080939 --- /dev/null +++ b/cpp/python/bindings_channels.cpp @@ -0,0 +1,108 @@ +/** + * bindings_channels.cpp — pybind11 bindings for grilly channels. + * + * Exposes InProcessChannel to Python with numpy-friendly message passing. + * When pybind11_protobuf is available, uses native proto casters. + * Otherwise, falls back to bytes-based serialization. + */ + +#include +#include +#include + +#include "grilly/channels/channel.h" + +namespace py = pybind11; + +void bind_channels(py::module_& m) { + using namespace grilly::channels; + + // MessageType enum + py::enum_(m, "MessageType") + .value("TENSOR_DATA", MessageType::TENSOR_DATA) + .value("SPIKE_TRAIN", MessageType::SPIKE_TRAIN) + .value("EXPERT_WEIGHTS", MessageType::EXPERT_WEIGHTS) + .value("EXPERT_UPDATE", MessageType::EXPERT_UPDATE) + .value("ROUTE_REQUEST", MessageType::ROUTE_REQUEST) + .value("ROUTE_RESPONSE", MessageType::ROUTE_RESPONSE) + .value("MEMORY_CAPSULE", MessageType::MEMORY_CAPSULE) + .value("MEMORY_QUERY", MessageType::MEMORY_QUERY) + .value("MEMORY_RESULT", MessageType::MEMORY_RESULT) + .value("TELEMETRY_EVENT", MessageType::TELEMETRY_EVENT) + .value("NEUROCHEM_STATE", MessageType::NEUROCHEM_STATE) + .value("TRAIN_STEP_REQUEST", MessageType::TRAIN_STEP_REQUEST) + .value("TRAIN_STEP_RESULT", MessageType::TRAIN_STEP_RESULT); + + // MessageEnvelope + py::class_(m, "MessageEnvelope") + .def(py::init<>()) + .def_readwrite("type", &MessageEnvelope::type) + .def_readwrite("timestamp_ns", &MessageEnvelope::timestamp_ns) + .def_readwrite("sender_id", &MessageEnvelope::sender_id) + .def_property("payload", + [](const MessageEnvelope& e) { + return py::bytes(reinterpret_cast(e.payload.data()), + e.payload.size()); + }, + [](MessageEnvelope& e, py::bytes data) { + std::string s = data; + 
e.payload.assign(s.begin(), s.end()); + }) + .def_property_readonly("size", &MessageEnvelope::size); + + // InProcessChannel + py::class_(m, "InProcessChannel") + .def(py::init(), + py::arg("name") = "default", + py::arg("max_queue_size") = 10000) + .def("send", &InProcessChannel::send) + .def("receive", &InProcessChannel::receive) + .def("has_messages", &InProcessChannel::has_messages) + .def("queue_size", &InProcessChannel::queue_size) + .def("clear", &InProcessChannel::clear) + .def("name", &InProcessChannel::name) + + // Convenience: send numpy array as TensorData + .def("send_tensor", [](InProcessChannel& ch, + py::array_t arr, + const std::string& sender_id) { + auto info = arr.request(); + MessageEnvelope env; + env.type = MessageType::TENSOR_DATA; + env.sender_id = sender_id; + env.payload.assign( + reinterpret_cast(info.ptr), + reinterpret_cast(info.ptr) + info.size * sizeof(float)); + ch.send(std::move(env)); + }, py::arg("array"), py::arg("sender_id") = "python") + + // Convenience: receive as numpy array + .def("receive_tensor", [](InProcessChannel& ch) -> py::object { + auto msg = ch.receive(); + if (msg.payload.empty()) return py::none(); + size_t n_floats = msg.payload.size() / sizeof(float); + auto result = py::array_t(n_floats); + auto buf = result.request(); + std::memcpy(buf.ptr, msg.payload.data(), msg.payload.size()); + return result; + }) + + // Convenience: send spike train (timesteps × neurons flattened) + .def("send_spikes", [](InProcessChannel& ch, + py::array_t spikes, + uint32_t n_neurons, + uint32_t n_timesteps, + const std::string& sender_id) { + auto info = spikes.request(); + MessageEnvelope env; + env.type = MessageType::SPIKE_TRAIN; + env.sender_id = sender_id; + // Pack header: n_neurons (4 bytes) + n_timesteps (4 bytes) + data + env.payload.resize(8 + info.size * sizeof(float)); + std::memcpy(env.payload.data(), &n_neurons, 4); + std::memcpy(env.payload.data() + 4, &n_timesteps, 4); + std::memcpy(env.payload.data() + 8, info.ptr, 
info.size * sizeof(float)); + ch.send(std::move(env)); + }, py::arg("spikes"), py::arg("n_neurons"), py::arg("n_timesteps"), + py::arg("sender_id") = "python"); +} diff --git a/cpp/src/channels/channel.cpp b/cpp/src/channels/channel.cpp new file mode 100644 index 0000000..1da267c --- /dev/null +++ b/cpp/src/channels/channel.cpp @@ -0,0 +1,72 @@ +/** + * channel.cpp — InProcessChannel implementation. + */ + +#include "grilly/channels/channel.h" + +#include +#include + +namespace grilly { +namespace channels { + +InProcessChannel::InProcessChannel(const std::string& name, size_t max_queue_size) + : name_(name), max_queue_size_(max_queue_size) {} + +void InProcessChannel::send(MessageEnvelope envelope) { + std::lock_guard lock(mutex_); + + // Set timestamp if not set + if (envelope.timestamp_ns == 0) { + auto now = std::chrono::high_resolution_clock::now(); + envelope.timestamp_ns = std::chrono::duration_cast( + now.time_since_epoch()).count(); + } + + // Notify listeners first (synchronous) + auto it = listeners_.find(static_cast(envelope.type)); + if (it != listeners_.end()) { + for (auto& listener : it->second) { + listener(envelope); + } + } + + // Queue the message (drop oldest if full) + if (queue_.size() >= max_queue_size_) { + queue_.pop(); + } + queue_.push(std::move(envelope)); +} + +MessageEnvelope InProcessChannel::receive() { + std::lock_guard lock(mutex_); + if (queue_.empty()) { + return MessageEnvelope{MessageType::TENSOR_DATA, 0, "", {}}; + } + auto msg = std::move(queue_.front()); + queue_.pop(); + return msg; +} + +bool InProcessChannel::has_messages() const { + std::lock_guard lock(mutex_); + return !queue_.empty(); +} + +void InProcessChannel::subscribe(MessageType type, ChannelListener listener) { + std::lock_guard lock(mutex_); + listeners_[static_cast(type)].push_back(std::move(listener)); +} + +size_t InProcessChannel::queue_size() const { + std::lock_guard lock(mutex_); + return queue_.size(); +} + +void InProcessChannel::clear() { + 
std::lock_guard lock(mutex_); + while (!queue_.empty()) queue_.pop(); +} + +} // namespace channels +} // namespace grilly From 64c500c370d7ed974b2b520ec3ac99e0aa4da55b Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 15:10:47 -0400 Subject: [PATCH 05/17] pre v1.0.0 --- .gitignore | 4 + backend/_bridge.py | 211 +++++++++++++++++++ cpp/include/grilly/vulkan/vk_buffer_pool.h | 9 + cpp/include/grilly/vulkan/vk_command_batch.h | 16 +- cpp/src/buffer_pool.cpp | 181 ++++++++-------- cpp/src/command_batch.cpp | 30 ++- cpp/src/ops/activations.cpp | 12 +- cpp/src/ops/attention.cpp | 3 +- cpp/src/ops/attention_ops.cpp | 15 +- cpp/src/ops/batchnorm.cpp | 6 +- cpp/src/ops/conv.cpp | 21 +- cpp/src/ops/embedding.cpp | 3 +- cpp/src/ops/fused.cpp | 6 +- cpp/src/ops/kv_cache.cpp | 6 +- cpp/src/ops/layernorm.cpp | 6 +- cpp/src/ops/learning.cpp | 21 +- cpp/src/ops/linear.cpp | 13 +- cpp/src/ops/loss.cpp | 6 +- cpp/src/ops/moqe_train.cpp | 6 +- cpp/src/ops/optimizer.cpp | 6 +- cpp/src/ops/perceiver.cpp | 3 +- cpp/src/ops/perceiver_encoder.cpp | 6 +- cpp/src/ops/pooling.cpp | 15 +- cpp/src/ops/rmsnorm.cpp | 3 +- cpp/src/ops/snn.cpp | 21 +- cpp/src/ops/swizzle.cpp | 3 +- functional/activations.py | 28 ++- functional/attention.py | 84 ++++++-- functional/dropout.py | 35 ++- functional/embedding.py | 97 ++------- functional/faiss.py | 131 +++++------- functional/fft.py | 23 +- functional/linear.py | 10 +- functional/loss.py | 35 ++- functional/memory.py | 124 ++++++----- functional/normalization.py | 11 +- nn/gpu_backward.py | 85 ++++---- shaders/fnn-linear.glsl | 71 ++++--- shaders/spv/fnn-linear.spv | Bin 3356 -> 6104 bytes 39 files changed, 843 insertions(+), 523 deletions(-) diff --git a/.gitignore b/.gitignore index fa4811a..3268438 100644 --- a/.gitignore +++ b/.gitignore @@ -201,3 +201,7 @@ docs/symp_former.md /third_party third_party/BLAKE3 third_party/VulkanMemoryAllocator +docs/BRIDGE_MIGRATION_CHECKLIST.md +docs/GPU_OPTIMIZATION_REVIEW.md 
+docs/PYTORCH_PARITY_STATUS.md +docs/PYTORCH_PARITY_TASKLIST.md diff --git a/backend/_bridge.py b/backend/_bridge.py index 2f86017..131ee49 100644 --- a/backend/_bridge.py +++ b/backend/_bridge.py @@ -1790,6 +1790,217 @@ def moqe_route_and_gemv(activations, choice, expert_weights, expert_scales, bloc ) +# ── Bridge support for functional modules without native kernels ────────── + + +def faiss_distance(query, vectors, distance_type="l2"): + """Compute FAISS-style distances with numpy fallback.""" + q = np.asarray(query, dtype=np.float32) + v = np.asarray(vectors, dtype=np.float32) + if q.ndim == 1: + q = q.reshape(1, -1) + if distance_type == "l2": + diff = q[:, None, :] - v[None, :, :] + return np.sqrt(np.sum(diff * diff, axis=2, dtype=np.float32), dtype=np.float32) + if distance_type == "cosine": + q_norm = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-8) + v_norm = v / (np.linalg.norm(v, axis=1, keepdims=True) + 1e-8) + return (1.0 - (q_norm @ v_norm.T)).astype(np.float32) + return (-(q @ v.T)).astype(np.float32) + + +def faiss_topk(distances, k): + """Select top-k nearest neighbors from a distance matrix.""" + d = np.asarray(distances, dtype=np.float32) + if d.ndim == 1: + d = d.reshape(1, -1) + k = max(1, min(int(k), d.shape[1])) + idx = np.argpartition(d, kth=k - 1, axis=1)[:, :k] + vals = np.take_along_axis(d, idx, axis=1) + order = np.argsort(vals, axis=1) + idx = np.take_along_axis(idx, order, axis=1).astype(np.int32) + vals = np.take_along_axis(vals, order, axis=1).astype(np.float32) + return idx, vals + + +def memory_read(queries, memory_keys, memory_values, temperature=None): + """Memory read via attention-style lookup.""" + q = np.asarray(queries, dtype=np.float32) + k = np.asarray(memory_keys, dtype=np.float32) + v = np.asarray(memory_values, dtype=np.float32) + temp = float(temperature) if temperature is not None else float(np.sqrt(k.shape[-1])) + scores = (q @ k.T) / max(temp, 1e-8) + scores = scores - np.max(scores, axis=-1, keepdims=True) + 
weights = np.exp(scores) + weights = weights / np.sum(weights, axis=-1, keepdims=True) + return (weights @ v).astype(np.float32) + + +def memory_write(new_key, new_value, memory_keys, memory_values, write_index, write_mode=0, blend_factor=0.5): + """Memory write for key/value banks.""" + keys = np.array(memory_keys, dtype=np.float32, copy=True) + vals = np.array(memory_values, dtype=np.float32, copy=True) + idx = int(write_index) + if write_mode == 1: + a = float(blend_factor) + keys[idx] = (1.0 - a) * keys[idx] + a * np.asarray(new_key, dtype=np.float32) + vals[idx] = (1.0 - a) * vals[idx] + a * np.asarray(new_value, dtype=np.float32) + else: + keys[idx] = np.asarray(new_key, dtype=np.float32) + vals[idx] = np.asarray(new_value, dtype=np.float32) + return keys, vals + + +def memory_query_pooling(x, w_query, b_query): + """Pool sequence to query vectors by mean-pool + projection.""" + x_arr = np.asarray(x, dtype=np.float32) + pooled = np.mean(x_arr, axis=1) + return (pooled @ np.asarray(w_query, dtype=np.float32).T + np.asarray(b_query, dtype=np.float32)).astype(np.float32) + + +def memory_inject_gate(x, memory_context, w_gate, b_gate, w_mem_proj): + """Gate between token state and projected memory context.""" + x_arr = np.asarray(x, dtype=np.float32) + mem = np.asarray(memory_context, dtype=np.float32) + batch, seq_len, dim = x_arr.shape + mem_expand = np.broadcast_to(mem[:, None, :], (batch, seq_len, dim)) + concat = np.concatenate([x_arr, mem_expand], axis=-1) + gate = 1.0 / (1.0 + np.exp(-(concat @ np.asarray(w_gate, dtype=np.float32).T + np.asarray(b_gate, dtype=np.float32)))) + mem_proj = mem @ np.asarray(w_mem_proj, dtype=np.float32).T + mem_proj = np.broadcast_to(mem_proj[:, None, :], (batch, seq_len, dim)) + return ((1.0 - gate) * x_arr + gate * mem_proj).astype(np.float32) + + +def fisher_info_update(gradients, fisher, momentum=0.9, use_ema=True, reset=False): + """Update Fisher information estimate.""" + g = np.asarray(gradients, dtype=np.float32) + f = 
np.zeros_like(g, dtype=np.float32) if reset else np.asarray(fisher, dtype=np.float32) + g2 = g * g + if use_ema: + return (float(momentum) * f + (1.0 - float(momentum)) * g2).astype(np.float32) + return (f + g2).astype(np.float32) + + +def ewc_penalty(current_params, important_params, fisher, lambda_ewc=1.0): + """Compute scalar EWC penalty.""" + cur = np.asarray(current_params, dtype=np.float32) + imp = np.asarray(important_params, dtype=np.float32) + f = np.asarray(fisher, dtype=np.float32) + diff = cur - imp + return float(0.5 * float(lambda_ewc) * np.sum(f * diff * diff, dtype=np.float32)) + + +def natural_gradient(gradients, fisher, eps=1e-8): + """Approximate natural gradient using diagonal Fisher.""" + g = np.asarray(gradients, dtype=np.float32) + f = np.asarray(fisher, dtype=np.float32) + return (g / (f + float(eps))).astype(np.float32) + + +def nlms_predict(x, w, bias=0.0): + """NLMS prediction.""" + return float(np.dot(np.asarray(x, dtype=np.float32), np.asarray(w, dtype=np.float32)) + float(bias)) + + +def nlms_update(x, y_true, w, bias, mu=0.5, eps=1e-6): + """NLMS weight update.""" + x_arr = np.asarray(x, dtype=np.float32) + w_arr = np.asarray(w, dtype=np.float32) + pred = float(np.dot(x_arr, w_arr) + float(bias)) + err = float(y_true) - pred + denom = float(np.dot(x_arr, x_arr) + float(eps)) + step = float(mu) * err / denom + w_new = (w_arr + step * x_arr).astype(np.float32) + b_new = float(bias + step) + return w_new, b_new + + +def whitening_transform(data, mean=None, std=None): + """Compute/apply whitening transform.""" + x = np.asarray(data, dtype=np.float32) + m = np.asarray(mean, dtype=np.float32) if mean is not None else np.mean(x, axis=0) + s = np.asarray(std, dtype=np.float32) if std is not None else np.std(x, axis=0) + out = (x - m) / (s + 1e-8) + return out.astype(np.float32), m.astype(np.float32), s.astype(np.float32) + + +def place_cell(agent_position, field_centers, field_width=1.0, max_rate=20.0, baseline_rate=0.1): + """Gaussian 
place-cell tuning curves.""" + pos = np.asarray(agent_position, dtype=np.float32) + centers = np.asarray(field_centers, dtype=np.float32) + if pos.ndim == 1: + pos = pos[None, :] + diff = pos[:, None, :] - centers[None, :, :] + d2 = np.sum(diff * diff, axis=-1) + rates = float(baseline_rate) + (float(max_rate) - float(baseline_rate)) * np.exp( + -d2 / (2.0 * float(field_width) * float(field_width) + 1e-8) + ) + return rates.astype(np.float32).squeeze(0) if np.asarray(agent_position).ndim == 1 else rates.astype(np.float32) + + +def time_cell(current_time, preferred_times, time_constant=1.0, max_rate=15.0, baseline_rate=0.1, membrane_state=None): + """Temporal Gaussian tuning with optional EMA membrane state.""" + t = float(current_time) + pref = np.asarray(preferred_times, dtype=np.float32) + rates = float(baseline_rate) + (float(max_rate) - float(baseline_rate)) * np.exp( + -((t - pref) ** 2) / (2.0 * float(time_constant) * float(time_constant) + 1e-8) + ) + if membrane_state is None: + mem = rates.astype(np.float32) + else: + mem = (0.9 * np.asarray(membrane_state, dtype=np.float32) + 0.1 * rates).astype(np.float32) + return rates.astype(np.float32), mem + + +def continuous_to_spikes(continuous, num_timesteps=10, encoding_type=0, projection_weights=None, projection_bias=None): + """Convert continuous vectors to spike trains.""" + x = np.asarray(continuous, dtype=np.float32) + if x.ndim == 1: + x = x[None, :] + if projection_weights is not None: + x = x @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + x = x + np.asarray(projection_bias, dtype=np.float32) + x = np.clip(x, 0.0, 1.0) + t = int(num_timesteps) + if encoding_type == 1: + spikes = np.zeros((x.shape[0], t, x.shape[1]), dtype=np.float32) + idx = np.minimum((1.0 - x) * (t - 1), t - 1).astype(np.int32) + b_idx = np.arange(x.shape[0])[:, None] + f_idx = np.arange(x.shape[1])[None, :] + spikes[b_idx, idx, f_idx] = 1.0 + return spikes + if encoding_type == 2: + phase = 
np.linspace(0.0, 2.0 * np.pi, t, endpoint=False, dtype=np.float32) + return ((np.sin(phase[None, :, None] + x[:, None, :] * 2.0 * np.pi) > 0).astype(np.float32)) + return (np.random.rand(x.shape[0], t, x.shape[1]).astype(np.float32) < x[:, None, :]).astype(np.float32) + + +def spikes_to_continuous(spikes, encoding_type=0, time_window=5, temporal_weights=None, projection_weights=None, projection_bias=None): + """Convert spike trains to continuous representations.""" + s = np.asarray(spikes, dtype=np.float32) + if s.ndim != 3: + raise ValueError(f"spikes must be 3D (batch, time, features), got shape={s.shape}") + if encoding_type == 1: + if temporal_weights is None: + tw = np.linspace(1.0, 0.0, s.shape[1], dtype=np.float32) + else: + tw = np.asarray(temporal_weights, dtype=np.float32) + tw = tw / (np.sum(tw) + 1e-8) + cont = np.sum(s * tw[None, :, None], axis=1) + elif encoding_type == 2: + phase = np.linspace(0.0, 2.0 * np.pi, s.shape[1], endpoint=False, dtype=np.float32) + cont = np.sum(s * np.sin(phase)[None, :, None], axis=1) / max(s.shape[1], 1) + else: + window = max(1, min(int(time_window), s.shape[1])) + cont = np.mean(s[:, -window:, :], axis=1) + if projection_weights is not None: + cont = cont @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + cont = cont + np.asarray(projection_bias, dtype=np.float32) + return cont.astype(np.float32) + + def q_similarity(queries): """Compute q-similarity (TAPPA metric) for attention queries. 
diff --git a/cpp/include/grilly/vulkan/vk_buffer_pool.h b/cpp/include/grilly/vulkan/vk_buffer_pool.h index f36050d..c938ee1 100644 --- a/cpp/include/grilly/vulkan/vk_buffer_pool.h +++ b/cpp/include/grilly/vulkan/vk_buffer_pool.h @@ -67,6 +67,15 @@ class BufferPool { std::mutex mutex_; std::unordered_map> buckets_; Stats stats_{}; + + // Persistent transfer context (avoids cmd pool/fence recreation per transfer) + VkCommandPool transferPool_ = VK_NULL_HANDLE; + VkCommandBuffer transferCmd_ = VK_NULL_HANDLE; + VkFence transferFence_ = VK_NULL_HANDLE; + bool transferInitialized_ = false; + + void ensureTransferContext(); + void transferSubmitAndWait(); }; } // namespace grilly diff --git a/cpp/include/grilly/vulkan/vk_command_batch.h b/cpp/include/grilly/vulkan/vk_command_batch.h index e7203dd..783ebb9 100644 --- a/cpp/include/grilly/vulkan/vk_command_batch.h +++ b/cpp/include/grilly/vulkan/vk_command_batch.h @@ -21,6 +21,9 @@ class CommandBatch { void begin(); + /// Start recording if not already. Safe to call multiple times. + void ensureRecording(); + void dispatch(VkPipeline pipeline, VkPipelineLayout layout, VkDescriptorSet descSet, uint32_t gx, uint32_t gy = 1, uint32_t gz = 1, @@ -29,13 +32,22 @@ class CommandBatch { void barrier(); /// GPU-to-GPU buffer copy (no CPU involvement). - /// Must be called between begin() and submit(). void copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t bytes); void submit(); void submitAsync(VkSemaphore timeline, uint64_t signalValue); + /// Submit without waiting — GPU runs while CPU continues. + void submitDeferred(); + + /// Block until the last submitted work finishes. No-op if nothing pending. + void waitForCompletion(); + + /// Number of dispatches recorded in current command buffer. 
+ uint32_t dispatchCount() const { return dispatchCount_; } + bool isRecording() const { return recording_; } + bool isPending() const { return pending_; } VkCommandBuffer cmdBuffer() const { return cmd_; } private: @@ -44,6 +56,8 @@ class CommandBatch { VkCommandBuffer cmd_ = VK_NULL_HANDLE; VkFence fence_ = VK_NULL_HANDLE; bool recording_ = false; + bool pending_ = false; + uint32_t dispatchCount_ = 0; }; } // namespace grilly diff --git a/cpp/src/buffer_pool.cpp b/cpp/src/buffer_pool.cpp index 37231b4..d7d1cc2 100644 --- a/cpp/src/buffer_pool.cpp +++ b/cpp/src/buffer_pool.cpp @@ -42,7 +42,21 @@ BufferPool::BufferPool(GrillyDevice& device) : device_(device) { } BufferPool::~BufferPool() { - // Destroy all pooled buffers first + VkDevice dev = device_.device(); + + // Clean up persistent transfer context + if (transferInitialized_) { + if (transferFence_ != VK_NULL_HANDLE) { + vkWaitForFences(dev, 1, &transferFence_, VK_TRUE, UINT64_MAX); + vkDestroyFence(dev, transferFence_, nullptr); + } + if (transferCmd_ != VK_NULL_HANDLE) + vkFreeCommandBuffers(dev, transferPool_, 1, &transferCmd_); + if (transferPool_ != VK_NULL_HANDLE) + vkDestroyCommandPool(dev, transferPool_, nullptr); + } + + // Destroy all pooled buffers for (auto& [bucketSize, vec] : buckets_) { for (auto& buf : vec) { if (buf.handle != VK_NULL_HANDLE) @@ -86,14 +100,15 @@ GrillyBuffer BufferPool::allocateBuffer(size_t bucketSize) { bufferInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - // VMA_MEMORY_USAGE_AUTO lets VMA pick the best heap. - // MAPPED_BIT gives us a persistent CPU pointer — the key win over Python - // which does vkMapMemory/vkUnmapMemory on every upload/download cycle. - // HOST_ACCESS_SEQUENTIAL_WRITE_BIT hints that we write linearly (memcpy). + // Use device-local memory with host-visible fallback. 
+ // On discrete GPUs with Resizable BAR, VMA places this in VRAM with + // CPU-visible mapping — best of both worlds (fast GPU access + memcpy upload). + // Without ReBAR, falls back to host-visible (system RAM). VmaAllocationCreateInfo allocInfo{}; allocInfo.usage = VMA_MEMORY_USAGE_AUTO; allocInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT; + allocInfo.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; GrillyBuffer buf{}; buf.bucketSize = bucketSize; @@ -247,9 +262,57 @@ GrillyBuffer BufferPool::acquireReadback(size_t size) { return buf; } +// ── Persistent transfer context ──────────────────────────────────────────── + +void BufferPool::ensureTransferContext() { + if (transferInitialized_) return; + + VkCommandPoolCreateInfo poolInfo{}; + poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + poolInfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; + poolInfo.queueFamilyIndex = device_.queueFamily(); + vkCheck(vkCreateCommandPool(device_.device(), &poolInfo, nullptr, &transferPool_), + "transfer pool creation failed"); + + VkCommandBufferAllocateInfo cmdAllocInfo{}; + cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + cmdAllocInfo.commandPool = transferPool_; + cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + cmdAllocInfo.commandBufferCount = 1; + vkCheck(vkAllocateCommandBuffers(device_.device(), &cmdAllocInfo, &transferCmd_), + "transfer cmd alloc failed"); + + VkFenceCreateInfo fenceInfo{}; + fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + fenceInfo.flags = VK_FENCE_CREATE_SIGNALED_BIT; + vkCheck(vkCreateFence(device_.device(), &fenceInfo, nullptr, &transferFence_), + "transfer fence creation failed"); + + transferInitialized_ = true; +} + +void BufferPool::transferSubmitAndWait() { + vkEndCommandBuffer(transferCmd_); + + vkWaitForFences(device_.device(), 1, &transferFence_, VK_TRUE, UINT64_MAX); + vkResetFences(device_.device(), 1, 
&transferFence_); + + VkSubmitInfo submitInfo{}; + submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submitInfo.commandBufferCount = 1; + submitInfo.pCommandBuffers = &transferCmd_; + + vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, transferFence_), + "transfer submit failed"); + vkCheck(vkWaitForFences(device_.device(), 1, &transferFence_, VK_TRUE, UINT64_MAX), + "transfer wait failed"); +} + void BufferPool::uploadStaged(GrillyBuffer& deviceBuf, const void* data, size_t bytes) { - // 1. Create a temporary host-visible staging buffer + ensureTransferContext(); + + // 1. Create staging buffer (TODO: pool these too) VkBufferCreateInfo stagingInfo{}; stagingInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; stagingInfo.size = bytes; @@ -268,67 +331,33 @@ void BufferPool::uploadStaged(GrillyBuffer& deviceBuf, const void* data, &stagingBuf, &stagingMem, &stagingMemInfo), "staging buffer alloc failed"); - // 2. Copy data into staging buffer + // 2. Copy data into staging std::memcpy(stagingMemInfo.pMappedData, data, bytes); vmaFlushAllocation(allocator_, stagingMem, 0, bytes); - // 3. Create a one-shot command buffer for the transfer - VkCommandPoolCreateInfo poolInfo{}; - poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - poolInfo.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT; - poolInfo.queueFamilyIndex = device_.queueFamily(); - - VkCommandPool cmdPool; - vkCheck(vkCreateCommandPool(device_.device(), &poolInfo, nullptr, &cmdPool), - "vkCreateCommandPool for staging failed"); - - VkCommandBufferAllocateInfo cmdAllocInfo{}; - cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - cmdAllocInfo.commandPool = cmdPool; - cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; - cmdAllocInfo.commandBufferCount = 1; - - VkCommandBuffer cmd; - vkCheck(vkAllocateCommandBuffers(device_.device(), &cmdAllocInfo, &cmd), - "vkAllocateCommandBuffers for staging failed"); - + // 3. 
Record transfer using persistent context + vkResetCommandBuffer(transferCmd_, 0); VkCommandBufferBeginInfo beginInfo{}; beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; - vkBeginCommandBuffer(cmd, &beginInfo); + vkBeginCommandBuffer(transferCmd_, &beginInfo); VkBufferCopy copyRegion{}; copyRegion.size = bytes; - vkCmdCopyBuffer(cmd, stagingBuf, deviceBuf.handle, 1, ©Region); + vkCmdCopyBuffer(transferCmd_, stagingBuf, deviceBuf.handle, 1, ©Region); - vkEndCommandBuffer(cmd); + // 4. Submit + wait (reuses persistent fence) + transferSubmitAndWait(); - // 4. Submit and wait - VkSubmitInfo submitInfo{}; - submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; - submitInfo.commandBufferCount = 1; - submitInfo.pCommandBuffers = &cmd; - - VkFenceCreateInfo fenceInfo{}; - fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; - VkFence fence; - vkCheck(vkCreateFence(device_.device(), &fenceInfo, nullptr, &fence), - "staging fence creation failed"); - - vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, fence), - "staging vkQueueSubmit failed"); - vkCheck(vkWaitForFences(device_.device(), 1, &fence, VK_TRUE, UINT64_MAX), - "staging vkWaitForFences failed"); - - // 5. Cleanup - vkDestroyFence(device_.device(), fence, nullptr); - vkDestroyCommandPool(device_.device(), cmdPool, nullptr); + // 5. Cleanup staging only vmaDestroyBuffer(allocator_, stagingBuf, stagingMem); } void BufferPool::downloadStaged(const GrillyBuffer& deviceBuf, void* out, size_t bytes) { - // 1. Create a temporary host-visible staging buffer for readback + ensureTransferContext(); + + // 1. 
Create staging readback buffer (TODO: pool these) VkBufferCreateInfo stagingInfo{}; stagingInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; stagingInfo.size = bytes; @@ -347,61 +376,25 @@ void BufferPool::downloadStaged(const GrillyBuffer& deviceBuf, void* out, &stagingBuf, &stagingMem, &stagingMemInfo), "staging readback buffer alloc failed"); - // 2. Copy from device-local to staging via DMA - VkCommandPoolCreateInfo poolInfo{}; - poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - poolInfo.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT; - poolInfo.queueFamilyIndex = device_.queueFamily(); - - VkCommandPool cmdPool; - vkCheck(vkCreateCommandPool(device_.device(), &poolInfo, nullptr, &cmdPool), - "vkCreateCommandPool for readback failed"); - - VkCommandBufferAllocateInfo cmdAllocInfo{}; - cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - cmdAllocInfo.commandPool = cmdPool; - cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; - cmdAllocInfo.commandBufferCount = 1; - - VkCommandBuffer cmd; - vkCheck(vkAllocateCommandBuffers(device_.device(), &cmdAllocInfo, &cmd), - "vkAllocateCommandBuffers for readback failed"); - + // 2. Record copy using persistent transfer context + vkResetCommandBuffer(transferCmd_, 0); VkCommandBufferBeginInfo beginInfo{}; beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; - vkBeginCommandBuffer(cmd, &beginInfo); + vkBeginCommandBuffer(transferCmd_, &beginInfo); VkBufferCopy copyRegion{}; copyRegion.size = bytes; - vkCmdCopyBuffer(cmd, deviceBuf.handle, stagingBuf, 1, ©Region); - - vkEndCommandBuffer(cmd); - - // 3. 
Submit and wait - VkSubmitInfo submitInfo{}; - submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; - submitInfo.commandBufferCount = 1; - submitInfo.pCommandBuffers = &cmd; - - VkFenceCreateInfo fenceInfo{}; - fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; - VkFence fence; - vkCheck(vkCreateFence(device_.device(), &fenceInfo, nullptr, &fence), - "readback fence creation failed"); + vkCmdCopyBuffer(transferCmd_, deviceBuf.handle, stagingBuf, 1, ©Region); - vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, fence), - "readback vkQueueSubmit failed"); - vkCheck(vkWaitForFences(device_.device(), 1, &fence, VK_TRUE, UINT64_MAX), - "readback vkWaitForFences failed"); + // 3. Submit + wait (reuses persistent fence) + transferSubmitAndWait(); // 4. Invalidate + copy to output vmaInvalidateAllocation(allocator_, stagingMem, 0, bytes); std::memcpy(out, stagingMemInfo.pMappedData, bytes); - // 5. Cleanup - vkDestroyFence(device_.device(), fence, nullptr); - vkDestroyCommandPool(device_.device(), cmdPool, nullptr); + // 5. 
Cleanup staging only vmaDestroyBuffer(allocator_, stagingBuf, stagingMem); } diff --git a/cpp/src/command_batch.cpp b/cpp/src/command_batch.cpp index 9e3ea46..ac5cb33 100644 --- a/cpp/src/command_batch.cpp +++ b/cpp/src/command_batch.cpp @@ -70,6 +70,9 @@ void CommandBatch::begin() { if (recording_) throw std::runtime_error("CommandBatch::begin() called while already recording"); + // Wait for any prior pending work before reusing the command buffer + waitForCompletion(); + vkResetCommandBuffer(cmd_, 0); VkCommandBufferBeginInfo beginInfo{}; @@ -80,6 +83,12 @@ void CommandBatch::begin() { "vkBeginCommandBuffer failed"); recording_ = true; + dispatchCount_ = 0; +} + +void CommandBatch::ensureRecording() { + if (!recording_) + begin(); } void CommandBatch::dispatch(VkPipeline pipeline, VkPipelineLayout layout, @@ -99,6 +108,7 @@ void CommandBatch::dispatch(VkPipeline pipeline, VkPipelineLayout layout, } vkCmdDispatch(cmd_, gx, gy, gz); + dispatchCount_++; } void CommandBatch::barrier() { @@ -135,6 +145,12 @@ void CommandBatch::copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t // ── Submission ────────────────────────────────────────────────────────────── void CommandBatch::submit() { + // Synchronous: submit + wait. Safe default for backward compat. 
+ submitDeferred(); + waitForCompletion(); +} + +void CommandBatch::submitDeferred() { if (!recording_) return; @@ -142,7 +158,10 @@ void CommandBatch::submit() { recording_ = false; // Wait for any prior submission to complete, then reset fence - vkWaitForFences(device_.device(), 1, &fence_, VK_TRUE, kFenceTimeoutNs); + if (pending_) { + vkWaitForFences(device_.device(), 1, &fence_, VK_TRUE, kFenceTimeoutNs); + pending_ = false; + } vkResetFences(device_.device(), 1, &fence_); VkSubmitInfo submitInfo{}; @@ -153,10 +172,17 @@ void CommandBatch::submit() { vkCheck(vkQueueSubmit(device_.computeQueue(), 1, &submitInfo, fence_), "vkQueueSubmit failed"); - // Wait for completion so callers can read back results + pending_ = true; + // Returns immediately — GPU runs in background +} + +void CommandBatch::waitForCompletion() { + if (!pending_) + return; vkCheck( vkWaitForFences(device_.device(), 1, &fence_, VK_TRUE, kFenceTimeoutNs), "vkWaitForFences timed out"); + pending_ = false; } void CommandBatch::submitAsync(VkSemaphore timeline, uint64_t signalValue) { diff --git a/cpp/src/ops/activations.cpp b/cpp/src/ops/activations.cpp index 59e0781..f5f6704 100644 --- a/cpp/src/ops/activations.cpp +++ b/cpp/src/ops/activations.cpp @@ -41,7 +41,8 @@ static void activationForward( batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, bytes); @@ -83,7 +84,8 @@ static void activationBackward( batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, bytes); @@ -207,7 +209,8 @@ void softmax(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + 
batch.waitForCompletion(); pool.download(bufOutput, output, dataBytes); @@ -250,7 +253,8 @@ void softmaxBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, bytes); diff --git a/cpp/src/ops/attention.cpp b/cpp/src/ops/attention.cpp index 8706be7..b7015bf 100644 --- a/cpp/src/ops/attention.cpp +++ b/cpp/src/ops/attention.cpp @@ -149,7 +149,8 @@ void flashAttention2(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, finalGroups, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download result pool.download(bufOutput, output, outBytes); diff --git a/cpp/src/ops/attention_ops.cpp b/cpp/src/ops/attention_ops.cpp index 7c41e07..b7755b6 100644 --- a/cpp/src/ops/attention_ops.cpp +++ b/cpp/src/ops/attention_ops.cpp @@ -59,7 +59,8 @@ void attentionScores(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufScores, scores, scoreBytes); @@ -105,7 +106,8 @@ void attentionMask(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufScores, scores, scoreBytes); @@ -148,7 +150,8 @@ void attentionOutput(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outBytes); @@ -188,7 +191,8 @@ void 
attentionConcatHeads(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, concatOutput, outBytes); @@ -235,7 +239,8 @@ void applyRoPE(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, dataBytes); diff --git a/cpp/src/ops/batchnorm.cpp b/cpp/src/ops/batchnorm.cpp index a108258..1037f5f 100644 --- a/cpp/src/ops/batchnorm.cpp +++ b/cpp/src/ops/batchnorm.cpp @@ -67,7 +67,8 @@ void batchnorm2dForward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, dataBytes); pool.download(bufRunMean, runningMean, featBytes); @@ -140,7 +141,8 @@ void batchnorm2dBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, dataBytes); pool.download(bufGradGamma, gradGamma, featBytes); diff --git a/cpp/src/ops/conv.cpp b/cpp/src/ops/conv.cpp index b07c917..f12449d 100644 --- a/cpp/src/ops/conv.cpp +++ b/cpp/src/ops/conv.cpp @@ -175,7 +175,8 @@ void conv2d(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, &biasPush, sizeof(biasPush)); } - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download from the correct buffer (biased output if shader was // available, raw GEMM output otherwise) @@ -246,7 +247,8 @@ void conv2d(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); 
batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); @@ -324,7 +326,8 @@ void conv2dBackwardInput(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); @@ -387,7 +390,8 @@ void conv2dBackwardWeight(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradW, gradWeight, gradWBytes); if (p.hasBias && gradBias) { @@ -439,7 +443,8 @@ void conv2d3x3Gelu(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, desc, (width + 15) / 16, (height + 15) / 16, outChannels, &pc, sizeof(pc)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); @@ -480,7 +485,8 @@ void maxpool2x2(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, desc, (outW + 15) / 16, (outH + 15) / 16, channels, &pc, sizeof(pc)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); @@ -516,7 +522,8 @@ void adaptiveAvgPool3x3(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, desc, 1, 1, (channels + 15) / 16, &pc, sizeof(pc)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); diff --git a/cpp/src/ops/embedding.cpp b/cpp/src/ops/embedding.cpp index 05b5c9d..410f70d 100644 --- a/cpp/src/ops/embedding.cpp +++ b/cpp/src/ops/embedding.cpp @@ -45,7 +45,8 @@ void embeddingLookup(CommandBatch& batch, BufferPool& pool, batch.begin(); 
batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); diff --git a/cpp/src/ops/fused.cpp b/cpp/src/ops/fused.cpp index 870406a..62650f4 100644 --- a/cpp/src/ops/fused.cpp +++ b/cpp/src/ops/fused.cpp @@ -59,7 +59,8 @@ void fusedMlpGelu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, seqLen, 1, 1, nullptr, 0); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); @@ -119,7 +120,8 @@ void fusedLayernormLinear(CommandBatch& batch, BufferPool& pool, PipelineCache& batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, seqLen, 1, 1, nullptr, 0); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); diff --git a/cpp/src/ops/kv_cache.cpp b/cpp/src/ops/kv_cache.cpp index 0651fe9..90e3bac 100644 --- a/cpp/src/ops/kv_cache.cpp +++ b/cpp/src/ops/kv_cache.cpp @@ -231,7 +231,8 @@ void kvCacheAppend(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipeGemm.pipeline, pipeGemm.layout, descGemm, gemmGX, gemmGY, 1, gemmPush, sizeof(gemmPush)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Copy compressed latents into cache at currentLen offset std::vector latentData(numNewTokens * cfg.numHeads * latentDim); @@ -475,7 +476,8 @@ void kvCacheDecode(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipeGemm.pipeline, pipeGemm.layout, descGemm, (N_proj + 15) / 16, (M_proj + 15) / 16, 1, gemmPush, sizeof(gemmPush)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download and split into K and V std::vector decoded(M_proj * N_proj); diff --git a/cpp/src/ops/layernorm.cpp b/cpp/src/ops/layernorm.cpp index 
14771c2..b7bd3dd 100644 --- a/cpp/src/ops/layernorm.cpp +++ b/cpp/src/ops/layernorm.cpp @@ -88,7 +88,8 @@ void layernorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download result pool.download(bufOutput, output, outputBytes); @@ -179,7 +180,8 @@ void layernormBackward(CommandBatch& batch, BufferPool& pool, (features + 255) / 256, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, elemBytes); pool.download(bufGradGamma, gradGamma, gammaBytes); diff --git a/cpp/src/ops/learning.cpp b/cpp/src/ops/learning.cpp index 3c45e67..015116e 100644 --- a/cpp/src/ops/learning.cpp +++ b/cpp/src/ops/learning.cpp @@ -44,7 +44,8 @@ void ssmFusedMath(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, scanOut, dataBytes); @@ -82,7 +83,8 @@ void fisherInfoUpdate(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufFisher, fisher, bytes); @@ -124,7 +126,8 @@ void ewcPenalty(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufPenalty, penalty, bytes); @@ -170,7 +173,8 @@ void nlmsPredict(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + 
batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufY, yPred, yBytes); @@ -224,7 +228,8 @@ void nlmsUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufW, w, featureBytes); pool.download(bufBias, bias, scalarBytes); @@ -302,7 +307,8 @@ void continuousToSpikes(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufSpikes, spikes, spikeBytes); @@ -376,7 +382,8 @@ void spikesToContinuous(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufFeat, features, featBytes); diff --git a/cpp/src/ops/linear.cpp b/cpp/src/ops/linear.cpp index 7f90d86..85bbada 100644 --- a/cpp/src/ops/linear.cpp +++ b/cpp/src/ops/linear.cpp @@ -58,16 +58,17 @@ void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, LinearParams pushData = p; // ── Dispatch ── - // 2D workgroups at 16×16 (must match fnn-linear.glsl local_size) + // 2D workgroups at 16×16 (matches fnn-linear.glsl tiled GEMM) uint32_t gx = (p.outputDim + 15) / 16; uint32_t gy = (p.batchSeq + 15) / 16; batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &pushData, sizeof(pushData)); - batch.submit(); + batch.submitDeferred(); // Submit without waiting — GPU runs async - // ── Download result (persistent mapping — single memcpy, no vkMap) ── + // ── Sync + download (only waits here, not at submit) ── + batch.waitForCompletion(); pool.download(bufOutput, output, outputBytes); // ── Release buffers back to pool ── @@ -192,7 +193,8 @@ void 
linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &bwdParams, sizeof(bwdParams)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); pool.download(bufGradW, gradWeight, gradWBytes); @@ -236,7 +238,8 @@ void dropout(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOutput, output, bytes); diff --git a/cpp/src/ops/loss.cpp b/cpp/src/ops/loss.cpp index 866431d..f87a219 100644 --- a/cpp/src/ops/loss.cpp +++ b/cpp/src/ops/loss.cpp @@ -70,7 +70,8 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push2, sizeof(push2)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufLoss, losses, lossBytes); @@ -114,7 +115,8 @@ void crossEntropyBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGrad, gradLogits, logitBytes); diff --git a/cpp/src/ops/moqe_train.cpp b/cpp/src/ops/moqe_train.cpp index f0b214c..88dd7b1 100644 --- a/cpp/src/ops/moqe_train.cpp +++ b/cpp/src/ops/moqe_train.cpp @@ -137,7 +137,8 @@ void moqe_layer_forward_gpu(CommandBatch& batch, BufferPool& pool, batchedTiledLinear(batch, cache, tc.bufA1, lw.w1, nullptr, tc.bufC1, n1, d, d); // No barrier between experts — independent reads/writes - batch.submit(); // Fence wait — safe to reuse buffers after this returns + batch.submitDeferred(); + batch.waitForCompletion(); // Fence wait — safe to reuse buffers after this returns if (n0 > 0) pool.download(tc.bufC0, out0, size_t(n0) 
* d * sizeof(float)); if (n1 > 0) pool.download(tc.bufC1, out1, size_t(n1) * d * sizeof(float)); @@ -166,7 +167,8 @@ void moqe_layer_backward_dx_gpu(CommandBatch& batch, BufferPool& pool, if (n1 > 0) batchedTiledLinear(batch, cache, tc.bufA1, lw.w1_t, nullptr, tc.bufC1, n1, d, d); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); if (n0 > 0) pool.download(tc.bufC0, dx0, size_t(n0) * d * sizeof(float)); if (n1 > 0) pool.download(tc.bufC1, dx1, size_t(n1) * d * sizeof(float)); diff --git a/cpp/src/ops/optimizer.cpp b/cpp/src/ops/optimizer.cpp index 328bd91..d780531 100644 --- a/cpp/src/ops/optimizer.cpp +++ b/cpp/src/ops/optimizer.cpp @@ -48,7 +48,8 @@ void adamUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufW, weights, bytes); pool.download(bufGrad, grad, bytes); @@ -95,7 +96,8 @@ void adamwUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufW, weights, bytes); pool.download(bufGrad, grad, bytes); diff --git a/cpp/src/ops/perceiver.cpp b/cpp/src/ops/perceiver.cpp index 824bd89..a6e2343 100644 --- a/cpp/src/ops/perceiver.cpp +++ b/cpp/src/ops/perceiver.cpp @@ -32,7 +32,8 @@ void perceiverEncode(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batchedPerceiverEncode(batch, cache, bufQ, bufK, bufV, bufOut, seqN, seqM, headDim); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, qBytes); diff --git a/cpp/src/ops/perceiver_encoder.cpp b/cpp/src/ops/perceiver_encoder.cpp index bd18c56..159a051 100644 --- a/cpp/src/ops/perceiver_encoder.cpp +++ b/cpp/src/ops/perceiver_encoder.cpp 
@@ -138,7 +138,8 @@ std::vector perceiver_encode_native( nPatches, D, D); } // ONE barrier after ALL K/V projections complete - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // ══════════════════════════════════════════════════════════════════ // PHASE 2: Main layer loop — only Q_cross + cached K/V + self-attn @@ -198,7 +199,8 @@ std::vector perceiver_encode_native( std::swap(current, scratch); } - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download and mean-pool std::vector finalLatents(N * D); diff --git a/cpp/src/ops/pooling.cpp b/cpp/src/ops/pooling.cpp index f5e3784..d9f9a00 100644 --- a/cpp/src/ops/pooling.cpp +++ b/cpp/src/ops/pooling.cpp @@ -52,7 +52,8 @@ void maxpool2dForward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); pool.download(bufIdx, reinterpret_cast(indices), idxBytes); @@ -104,7 +105,8 @@ void maxpool2dBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); @@ -146,7 +148,8 @@ void avgpool2dForward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); @@ -187,7 +190,8 @@ void avgpool2dBackward(CommandBatch& batch, BufferPool& pool, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, gz, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradIn, gradInput, gradInBytes); @@ -223,7 +227,8 @@ void meanPool(CommandBatch& 
batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outBytes); diff --git a/cpp/src/ops/rmsnorm.cpp b/cpp/src/ops/rmsnorm.cpp index d7000f7..fa815f1 100644 --- a/cpp/src/ops/rmsnorm.cpp +++ b/cpp/src/ops/rmsnorm.cpp @@ -71,7 +71,8 @@ void rmsnorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx1, 1, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); // Download result pool.download(bufOutput, output, outputBytes); diff --git a/cpp/src/ops/snn.cpp b/cpp/src/ops/snn.cpp index 1dfe280..44aee33 100644 --- a/cpp/src/ops/snn.cpp +++ b/cpp/src/ops/snn.cpp @@ -53,7 +53,8 @@ void lifStep(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufVMem, vMem, bytes); pool.download(bufRefrac, tRefrac, bytes); @@ -100,7 +101,8 @@ void snnNodeForward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufVMem, vMem, bytes); pool.download(bufSpikes, spikes, bytes); @@ -143,7 +145,8 @@ void snnNodeBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufGradX, gradX, bytes); @@ -187,7 +190,8 @@ void hebbianLearning(CommandBatch& batch, BufferPool& pool, PipelineCache& cache batch.begin(); batch.dispatch(pipe.pipeline, 
pipe.layout, descSet, gx, gy, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufWeights, weights, weightBytes); @@ -252,7 +256,8 @@ void stdpLearning(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &push1, sizeof(push1)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufWeights, weights, weightBytes); pool.download(bufPreTrace, preTrace, preTraceBytes); @@ -295,7 +300,8 @@ void synapseFilter(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufYState, yState, bytes); @@ -349,7 +355,8 @@ void gifNeuronStep(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufVMem, vMem, bytes); pool.download(bufIAdapt, iAdapt, bytes); diff --git a/cpp/src/ops/swizzle.cpp b/cpp/src/ops/swizzle.cpp index ddcfc0a..4c75110 100644 --- a/cpp/src/ops/swizzle.cpp +++ b/cpp/src/ops/swizzle.cpp @@ -209,7 +209,8 @@ void swizzle(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.begin(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); - batch.submit(); + batch.submitDeferred(); + batch.waitForCompletion(); pool.download(bufOut, output, outputBytes); diff --git a/functional/activations.py b/functional/activations.py index 08c5a5b..e681e02 100644 --- a/functional/activations.py +++ b/functional/activations.py @@ -27,10 +27,8 @@ def relu(x: np.ndarray) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - 
return Compute().activation_relu(x) + x = np.asarray(x, dtype=np.float32) + return np.maximum(0.0, x).astype(np.float32) def gelu(x: np.ndarray) -> np.ndarray: @@ -46,10 +44,10 @@ def gelu(x: np.ndarray) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_gelu(x) + x = np.asarray(x, dtype=np.float32) + return ( + 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3))) + ).astype(np.float32) def silu(x: np.ndarray) -> np.ndarray: @@ -65,10 +63,8 @@ def silu(x: np.ndarray) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_silu(x) + x = np.asarray(x, dtype=np.float32) + return (x / (1.0 + np.exp(-x))).astype(np.float32) def softmax(x: np.ndarray, dim: int = -1) -> np.ndarray: @@ -84,10 +80,10 @@ def softmax(x: np.ndarray, dim: int = -1) -> np.ndarray: return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().activation_softmax(x, dim=dim) + x = np.asarray(x, dtype=np.float32) + x_max = np.max(x, axis=dim, keepdims=True) + exp_x = np.exp(x - x_max) + return (exp_x / np.sum(exp_x, axis=dim, keepdims=True)).astype(np.float32) def softplus(x: np.ndarray) -> np.ndarray: diff --git a/functional/attention.py b/functional/attention.py index 5b2527a..2e2fc80 100644 --- a/functional/attention.py +++ b/functional/attention.py @@ -3,11 +3,40 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute +def _to_numpy(result): + """Convert bridge result to numpy if it's a C++ Tensor.""" + if result is None: + return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) - return Compute() + +def 
_numpy_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray: + x_max = np.max(x, axis=axis, keepdims=True) + exp_x = np.exp(x - x_max) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) + + +def _numpy_attention( + query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: np.ndarray | None = None +) -> tuple[np.ndarray, np.ndarray]: + q = np.asarray(query, dtype=np.float32) + k = np.asarray(key, dtype=np.float32) + v = np.asarray(value, dtype=np.float32) + scale = 1.0 / np.sqrt(float(q.shape[-1])) + scores = (q @ np.swapaxes(k, -1, -2)) * scale + if mask is not None: + mask_arr = np.asarray(mask) + if mask_arr.dtype == np.bool_: + scores = np.where(mask_arr, scores, -1e9) + else: + scores = scores + mask_arr.astype(np.float32) + weights = _numpy_softmax(scores, axis=-1).astype(np.float32) + output = (weights @ v).astype(np.float32) + return output, weights def attention( @@ -26,19 +55,22 @@ def attention( Returns: Tuple of (output, attention_weights) """ - backend = _get_backend() - - # Compute attention scores - scores = backend.attention_scores(query, key) - - # Apply mask if provided - if mask is not None: - scores = backend.attention_mask(scores, mask) + try: + from grilly.backend import _bridge - # Compute attention output - output = backend.attention_output(scores, value) + scores = _bridge.attention_scores(query, key) + if scores is not None: + if mask is not None: + masked_scores = _bridge.attention_mask(scores, mask, False) + if masked_scores is not None: + scores = masked_scores + output = _bridge.attention_output(scores, value) + if output is not None: + return _to_numpy(output), _to_numpy(scores) + except (ImportError, Exception): + pass - return output, scores + return _numpy_attention(query, key, value, mask) def flash_attention2( @@ -57,5 +89,23 @@ def flash_attention2( Returns: Attention output """ - backend = _get_backend() - return backend.flash_attention2(query, key, value, use_rope=use_rope) + try: + from grilly.backend import _bridge + 
+ q = query + k = key + if use_rope: + q_rope = _bridge.rope(q) + k_rope = _bridge.rope(k) + if q_rope is not None and k_rope is not None: + q = q_rope + k = k_rope + + result = _bridge.flash_attention2(q, k, value) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + + output, _ = _numpy_attention(query, key, value, mask=None) + return output diff --git a/functional/dropout.py b/functional/dropout.py index 9515ced..e1746e1 100644 --- a/functional/dropout.py +++ b/functional/dropout.py @@ -6,11 +6,15 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute - - return Compute() +def _to_numpy(result): + """Convert bridge result to numpy if it's a C++ Tensor.""" + if result is None: + return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def dropout(input: np.ndarray, p: float = 0.5, training: bool = True) -> np.ndarray: @@ -26,8 +30,21 @@ def dropout(input: np.ndarray, p: float = 0.5, training: bool = True) -> np.ndar Returns: Output tensor with dropout applied (if training) """ + input_arr = np.asarray(input, dtype=np.float32) if not training or p == 0: - return input - - backend = _get_backend() - return backend.fnn.dropout(input, dropout_prob=p, is_training=training) + return input_arr + if p >= 1.0: + return np.zeros_like(input_arr, dtype=np.float32) + + random_mask = (np.random.rand(*input_arr.shape) >= p).astype(np.float32) + try: + from grilly.backend import _bridge + + result = _bridge.dropout(input_arr, random_mask, p=p, training=training) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + + scale = 1.0 / (1.0 - p) + return (input_arr * random_mask * scale).astype(np.float32) diff --git a/functional/embedding.py b/functional/embedding.py index fb299e8..2e60ba7 100644 --- a/functional/embedding.py +++ b/functional/embedding.py @@ -8,14 
+8,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def embedding_lookup(weight: np.ndarray, indices: np.ndarray) -> np.ndarray: @@ -31,11 +31,15 @@ def embedding_lookup(weight: np.ndarray, indices: np.ndarray) -> np.ndarray: Returns: Embeddings (batch, seq_len, embedding_dim) or (batch, embedding_dim) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - # Backend expects (token_ids, embedding_table) - note the order! - return backend.embedding_lookup(indices, weight) + result = _bridge.embedding_lookup(indices, weight) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + return weight[indices] def embedding_normalize(embeddings: np.ndarray, eps: float = 1e-8) -> np.ndarray: @@ -51,17 +55,6 @@ def embedding_normalize(embeddings: np.ndarray, eps: float = 1e-8) -> np.ndarray Returns: Normalized embeddings (same shape) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-normalize" in backend.shaders: - try: - # GPU embedding normalization would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback norm = np.linalg.norm(embeddings, axis=-1, keepdims=True) return embeddings / (norm + eps) @@ -83,27 +76,15 @@ def embedding_position( Returns: Embeddings with positional encoding (same shape) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-position" in backend.shaders: - try: - # GPU positional encoding would go here - # For now, use CPU fallback - pass - except Exception: - pass # 
Fall back to CPU - - # CPU fallback (sinusoidal positional encoding) + # CPU fallback (vectorized sinusoidal positional encoding) batch_size, seq_len, dim = embeddings.shape - # Create positional encoding + positions = np.arange(seq_len, dtype=np.float32)[:, None] + div_term = np.exp(-(np.log(10000.0) * np.arange(0, dim, 2, dtype=np.float32) / dim)) + angles = positions * div_term[None, :] pos_enc = np.zeros((seq_len, dim), dtype=np.float32) - for pos in range(seq_len): - for i in range(0, dim, 2): - pos_enc[pos, i] = np.sin(pos / (10000 ** (i / dim))) - if i + 1 < dim: - pos_enc[pos, i + 1] = np.cos(pos / (10000 ** (i / dim))) + pos_enc[:, 0::2] = np.sin(angles) + pos_enc[:, 1::2] = np.cos(angles[:, : pos_enc[:, 1::2].shape[1]]) return embeddings + pos_enc[None, :, :] @@ -121,17 +102,6 @@ def embedding_pool(embeddings: np.ndarray, pool_type: str = "mean") -> np.ndarra Returns: Pooled embeddings (batch, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-pool" in backend.shaders: - try: - # GPU embedding pooling would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback if pool_type == "mean": return embeddings.mean(axis=1) @@ -167,17 +137,6 @@ def embedding_ffn( Returns: Output embeddings (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-ffn" in backend.shaders: - try: - # GPU embedding FFN would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback x = embeddings @ W1.T + b1 @@ -209,18 +168,8 @@ def embedding_attention( Returns: Attended embeddings (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "embedding-attention" in backend.shaders: - try: - # GPU embedding attention would go here - # For now, 
use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback (simplified) from grilly.functional.attention import attention - return attention(embeddings, embeddings, embeddings, num_heads=num_heads) + output, _ = attention(embeddings, embeddings, embeddings) + return output diff --git a/functional/faiss.py b/functional/faiss.py index 7ee0c49..7aad572 100644 --- a/functional/faiss.py +++ b/functional/faiss.py @@ -8,14 +8,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def faiss_distance(query: np.ndarray, vectors: np.ndarray, distance_type: str = "l2") -> np.ndarray: @@ -32,25 +32,25 @@ def faiss_distance(query: np.ndarray, vectors: np.ndarray, distance_type: str = Returns: Distances (batch, num_vectors) or (num_vectors,) """ - from grilly import Compute - - backend = Compute() - if hasattr(backend, "faiss") and hasattr(backend.faiss, "compute_distances"): - return backend.faiss.compute_distances(query, vectors, distance_type=distance_type) - else: - # CPU fallback - if query.ndim == 1: - query = query.reshape(1, -1) - - if distance_type == "l2": - diff = query[:, None, :] - vectors[None, :, :] - return np.sqrt(np.sum(diff**2, axis=2)) - elif distance_type == "cosine": - q_norm = query / (np.linalg.norm(query, axis=1, keepdims=True) + 1e-8) - v_norm = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8) - return 1 - np.dot(q_norm, v_norm.T) - else: # dot - return -np.dot(query, vectors.T) + try: + from grilly.backend import _bridge + + result = _bridge.faiss_distance(query, vectors, distance_type=distance_type) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + + if query.ndim == 
1: + query = query.reshape(1, -1) + if distance_type == "l2": + diff = query[:, None, :] - vectors[None, :, :] + return np.sqrt(np.sum(diff**2, axis=2)) + if distance_type == "cosine": + q_norm = query / (np.linalg.norm(query, axis=1, keepdims=True) + 1e-8) + v_norm = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8) + return 1 - np.dot(q_norm, v_norm.T) + return -np.dot(query, vectors.T) def faiss_topk(distances: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]: @@ -66,16 +66,17 @@ def faiss_topk(distances: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]: Returns: (indices, topk_distances) - both (batch, k) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - if hasattr(backend, "faiss") and hasattr(backend.faiss, "topk"): - return backend.faiss.topk(distances, k) - else: - # CPU fallback - indices = np.argsort(distances, axis=1)[:, :k] - topk_distances = np.take_along_axis(distances, indices, axis=1) - return indices, topk_distances + result = _bridge.faiss_topk(distances, k) + if result is not None: + return result + except (ImportError, Exception): + pass + indices = np.argsort(distances, axis=1)[:, :k] + topk_distances = np.take_along_axis(distances, indices, axis=1) + return indices, topk_distances def faiss_ivf_filter(vectors: np.ndarray, centroids: np.ndarray, nlist: int = 100) -> np.ndarray: @@ -92,17 +93,6 @@ def faiss_ivf_filter(vectors: np.ndarray, centroids: np.ndarray, nlist: int = 10 Returns: Cluster assignments (num_vectors,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "faiss-kmeans-update" in backend.shaders: - try: - # GPU FAISS kmeans update would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback - Assign each vector to nearest centroid distances = np.linalg.norm(vectors[:, None, :] - centroids[None, :, :], axis=2) assignments = np.argmin(distances, 
axis=1) @@ -126,33 +116,17 @@ def faiss_kmeans_update( Returns: Updated centroids (nlist, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "faiss-kmeans-update" in backend.shaders: - try: - # GPU FAISS kmeans update would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - - # CPU fallback - new_centroids = np.zeros_like(centroids) - counts = np.zeros(nlist, dtype=np.int32) - - for i, vec in enumerate(vectors): - cluster = assignments[i] - new_centroids[cluster] += vec - counts[cluster] += 1 - - # Normalize by counts - for c in range(nlist): - if counts[c] > 0: - new_centroids[c] /= counts[c] - else: - new_centroids[c] = centroids[c] # Keep old centroid if no vectors assigned - + # CPU fallback (vectorized) + vectors = np.asarray(vectors, dtype=np.float32) + centroids = np.asarray(centroids, dtype=np.float32) + assignments = np.asarray(assignments, dtype=np.int32) + dim = vectors.shape[1] + new_centroids = np.zeros((nlist, dim), dtype=np.float32) + np.add.at(new_centroids, assignments, vectors) + counts = np.bincount(assignments, minlength=nlist).astype(np.float32) + nonzero = counts > 0 + new_centroids[nonzero] /= counts[nonzero, None] + new_centroids[~nonzero] = centroids[~nonzero] return new_centroids @@ -172,17 +146,6 @@ def faiss_quantize( Returns: (quantized_vectors, codes) - codes (num_vectors,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "faiss-quantize" in backend.shaders: - try: - # GPU FAISS quantization would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback - Find nearest codebook entry for each vector distances = np.linalg.norm(vectors[:, None, :] - codebook[None, :, :], axis=2) codes = np.argmin(distances, axis=1) diff --git a/functional/fft.py b/functional/fft.py index 6a32a73..c511155 100644 --- 
a/functional/fft.py +++ b/functional/fft.py @@ -3,13 +3,6 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute - - return Compute() - - def fft(input: np.ndarray) -> np.ndarray: """ Fast Fourier Transform @@ -21,9 +14,7 @@ def fft(input: np.ndarray) -> np.ndarray: Returns: FFT output (complex) """ - _get_backend() - # Note: May need to implement FFT in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.fft.fft(input) @@ -38,9 +29,7 @@ def ifft(input: np.ndarray) -> np.ndarray: Returns: Reconstructed signal """ - _get_backend() - # Note: May need to implement IFFT in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.fft.ifft(input) @@ -55,9 +44,7 @@ def fft_magnitude(input: np.ndarray) -> np.ndarray: Returns: Magnitude spectrum """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.abs(input) @@ -72,7 +59,5 @@ def fft_power_spectrum(input: np.ndarray) -> np.ndarray: Returns: Power spectrum """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + # CPU fallback return np.abs(input) ** 2 diff --git a/functional/linear.py b/functional/linear.py index 258910c..5fd0be5 100644 --- a/functional/linear.py +++ b/functional/linear.py @@ -38,7 +38,9 @@ def linear(input: np.ndarray, weight: np.ndarray, bias: np.ndarray | None = None return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().fnn.linear(input, weight, bias) + input_arr = np.asarray(input, dtype=np.float32) + weight_arr = np.asarray(weight, dtype=np.float32) + output = input_arr @ weight_arr.T + if bias is not None: + output = output + np.asarray(bias, dtype=np.float32) + return np.asarray(output, dtype=np.float32) diff --git a/functional/loss.py b/functional/loss.py index 
ea7994e..31691b5 100644 --- a/functional/loss.py +++ b/functional/loss.py @@ -6,11 +6,14 @@ import numpy as np -def _get_backend(): - """Get compute backend""" - from grilly import Compute - - return Compute() +def _to_numpy(result): + if result is None: + return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def cross_entropy( @@ -29,9 +32,21 @@ def cross_entropy( Returns: Loss value(s) """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + try: + from grilly.backend import _bridge + + if target.ndim < input.ndim and target.dtype.kind in {"i", "u"}: + gpu_loss = _bridge.cross_entropy_loss(input, target) + if gpu_loss is not None: + loss = _to_numpy(gpu_loss) + if reduction == "mean": + return np.mean(loss) + if reduction == "sum": + return np.sum(loss) + return loss + except (ImportError, Exception): + pass + # CPU fallback input_softmax = np.exp(input - np.max(input, axis=-1, keepdims=True)) input_softmax = input_softmax / np.sum(input_softmax, axis=-1, keepdims=True) @@ -71,9 +86,7 @@ def binary_cross_entropy( Returns: Loss value(s) """ - _get_backend() - # Note: May need to implement in backend if not already exposed - # CPU fallback for now + # CPU fallback input_clamped = np.clip(input, 1e-8, 1 - 1e-8) loss = -(target * np.log(input_clamped) + (1 - target) * np.log(1 - input_clamped)) diff --git a/functional/memory.py b/functional/memory.py index 13928b0..5eabeb4 100644 --- a/functional/memory.py +++ b/functional/memory.py @@ -9,14 +9,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def memory_read( @@ 
-39,12 +39,21 @@ def memory_read( Returns: Retrieved values (batch, value_dim) """ - from grilly import Compute - - backend = Compute() if temperature is None: temperature = np.sqrt(memory_keys.shape[1]) - return backend.memory_read(queries, memory_keys, memory_values, temperature) + try: + from grilly.backend import _bridge + + result = _bridge.memory_read(queries, memory_keys, memory_values, temperature) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + scores = (queries @ memory_keys.T) / max(float(temperature), 1e-8) + scores = scores - np.max(scores, axis=-1, keepdims=True) + weights = np.exp(scores) + weights = weights / np.sum(weights, axis=-1, keepdims=True) + return (weights @ memory_values).astype(np.float32) def memory_write( @@ -73,12 +82,27 @@ def memory_write( Returns: (updated_memory_keys, updated_memory_values) """ - from grilly import Compute - - backend = Compute() - return backend.memory_write( - new_key, new_value, memory_keys, memory_values, write_index, write_mode, blend_factor - ) + try: + from grilly.backend import _bridge + + result = _bridge.memory_write( + new_key, new_value, memory_keys, memory_values, write_index, write_mode, blend_factor + ) + if result is not None: + return result + except (ImportError, Exception): + pass + mk = np.array(memory_keys, dtype=np.float32, copy=True) + mv = np.array(memory_values, dtype=np.float32, copy=True) + idx = int(write_index) + if int(write_mode) == 1: + alpha = float(blend_factor) + mk[idx] = (1.0 - alpha) * mk[idx] + alpha * np.asarray(new_key, dtype=np.float32) + mv[idx] = (1.0 - alpha) * mv[idx] + alpha * np.asarray(new_value, dtype=np.float32) + else: + mk[idx] = np.asarray(new_key, dtype=np.float32) + mv[idx] = np.asarray(new_value, dtype=np.float32) + return mk, mv def memory_context_aggregate(memory_contexts: np.ndarray) -> np.ndarray: @@ -93,17 +117,6 @@ def memory_context_aggregate(memory_contexts: np.ndarray) -> np.ndarray: Returns: 
Aggregated context (batch, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if hasattr(backend, "shaders") and "memory-context-aggregate" in backend.shaders: - try: - # GPU memory context aggregation would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback (mean pooling) return memory_contexts.mean(axis=1) @@ -122,10 +135,16 @@ def memory_query_pooling(x: np.ndarray, W_query: np.ndarray, b_query: np.ndarray Returns: Query vectors (batch, out_dim) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.memory_query_pooling(x, W_query, b_query) + result = _bridge.memory_query_pooling(x, W_query, b_query) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + pooled = np.mean(np.asarray(x, dtype=np.float32), axis=1) + return (pooled @ np.asarray(W_query, dtype=np.float32).T + np.asarray(b_query, dtype=np.float32)).astype(np.float32) def memory_inject_concat( @@ -148,17 +167,6 @@ def memory_inject_concat( Returns: Output (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "memory-inject-concat" in backend.shaders: - try: - # GPU memory injection with concat would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback batch_size, seq_len, dim = x.shape mem_expanded = memory_context[:, None, :] # (batch, 1, dim) @@ -192,10 +200,23 @@ def memory_inject_gate( Returns: Output (batch, seq_len, dim) """ - from grilly import Compute - - backend = Compute() - return backend.memory_inject_gate(x, memory_context, W_gate, b_gate, W_mem_proj) + try: + from grilly.backend import _bridge + + result = _bridge.memory_inject_gate(x, memory_context, W_gate, b_gate, W_mem_proj) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + 
x_arr = np.asarray(x, dtype=np.float32) + mem = np.asarray(memory_context, dtype=np.float32) + batch, seq_len, dim = x_arr.shape + mem_expanded = np.broadcast_to(mem[:, None, :], (batch, seq_len, dim)) + concat = np.concatenate([x_arr, mem_expanded], axis=-1) + gate = 1.0 / (1.0 + np.exp(-(concat @ np.asarray(W_gate, dtype=np.float32).T + np.asarray(b_gate, dtype=np.float32)))) + mem_proj = mem @ np.asarray(W_mem_proj, dtype=np.float32).T + mem_proj = np.broadcast_to(mem_proj[:, None, :], (batch, seq_len, dim)) + return ((1.0 - gate) * x_arr + gate * mem_proj).astype(np.float32) def memory_inject_residual( @@ -218,17 +239,6 @@ def memory_inject_residual( Returns: Output (batch, seq_len, dim) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "memory-inject-residual" in backend.shaders: - try: - # GPU memory injection with residual would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback batch_size, seq_len, dim = x.shape mem_proj = memory_context @ mem_proj_weight.T diff --git a/functional/normalization.py b/functional/normalization.py index c262b56..d3c438c 100644 --- a/functional/normalization.py +++ b/functional/normalization.py @@ -51,7 +51,10 @@ def layer_norm( return _to_numpy(result) except (ImportError, Exception): pass - # Fallback to legacy Compute() path - from grilly import Compute - - return Compute().layernorm(input, weight, bias, eps=eps) + input_arr = np.asarray(input, dtype=np.float32) + weight_arr = np.asarray(weight, dtype=np.float32) + bias_arr = np.asarray(bias, dtype=np.float32) + mean = np.mean(input_arr, axis=-1, keepdims=True) + var = np.var(input_arr, axis=-1, keepdims=True) + normalized = (input_arr - mean) / np.sqrt(var + eps) + return (normalized * weight_arr + bias_arr).astype(np.float32) diff --git a/nn/gpu_backward.py b/nn/gpu_backward.py index 597319a..c52ec1b 100644 --- a/nn/gpu_backward.py +++ 
b/nn/gpu_backward.py @@ -44,10 +44,13 @@ def __init__(self, use_gpu: bool = True): if use_gpu: try: - from grilly import Compute - - self.backend = Compute() - logger.info("GPU backward operations initialized successfully") + from grilly.backend import _bridge + if _bridge.is_available(): + self.backend = _bridge + logger.info("GPU backward operations initialized via _bridge") + else: + logger.warning("_bridge not available. Falling back to CPU.") + self.use_gpu = False except Exception as e: logger.warning(f"Failed to initialize GPU backend: {e}. Falling back to CPU.") self.use_gpu = False @@ -77,15 +80,22 @@ def is_available(self) -> bool: return self.backend is not None and self.use_gpu def _has_shader(self, shader_name: str) -> bool: - """Check if a shader is available.""" + """Check if a backward op is available on _bridge.""" if not self.is_available(): return False - try: - return ( - hasattr(self.backend.core, "shaders") and shader_name in self.backend.core.shaders - ) - except Exception: - return False + # Map shader names to _bridge function names + bridge_map = { + "fnn-linear-backward": "linear_backward", + "activation-relu-backward": "relu_backward", + "activation-gelu-backward": "gelu_backward", + "activation-silu-backward": "silu_backward", + "activation-softmax-backward": "softmax_backward", + "activation-tanh-backward": "tanh_backward", + "cross-entropy-backward": "cross_entropy_backward", + "fnn-layernorm-backward": "layernorm_backward", + } + fn_name = bridge_map.get(shader_name, shader_name.replace("-", "_")) + return hasattr(self.backend, fn_name) # ======================================================================== # Linear / Fully Connected Layer @@ -131,10 +141,11 @@ def linear_backward( ) try: - # Use Grilly's existing linear_backward method! 
- grad_input, grad_weight, grad_bias = self.backend.fnn.linear_backward( - grad_output, input_data, weights, bias=None - ) + # Use _bridge.linear_backward directly + result = self.backend.linear_backward(grad_output, input_data, weights) + grad_input = np.asarray(result['grad_input']) if 'grad_input' in result else None + grad_weight = np.asarray(result['grad_weight']) if 'grad_weight' in result else None + grad_bias = np.asarray(result['grad_bias']) if 'grad_bias' in result else None # Filter outputs based on what was requested if not compute_input_grad: @@ -207,11 +218,8 @@ def relu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.n return grad_output * (input_data > 0).astype(np.float32) try: - return self.backend.core.dispatch_shader( - "activation-relu-backward", - inputs={"grad_output": grad_output, "input_data": input_data}, - output_shape=grad_output.shape, - ) + result = self.backend.relu_backward(grad_output, input_data) + return np.asarray(result) if result is not None else grad_output * (input_data > 0).astype(np.float32) except Exception as e: logger.warning(f"GPU ReLU backward failed: {e}. Falling back to CPU.") return grad_output * (input_data > 0).astype(np.float32) @@ -241,22 +249,19 @@ def gelu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.n return grad_output * (cdf_approx + x * dcdf) try: - return self.backend.core.dispatch_shader( - "activation-gelu-backward", - inputs={"grad_output": grad_output, "input_data": input_data}, - output_shape=grad_output.shape, - ) + result = self.backend.gelu_backward(grad_output, input_data) + if result is not None: + return np.asarray(result) except Exception as e: logger.warning(f"GPU GELU backward failed: {e}. 
Falling back to CPU.") - # CPU fallback - x = input_data - sqrt_2_pi = np.sqrt(2.0 / np.pi) - cdf_approx = 0.5 * (1.0 + np.tanh(sqrt_2_pi * (x + 0.044715 * x**3))) - inner = sqrt_2_pi * (x + 0.044715 * x**3) - tanh_inner = np.tanh(inner) - sech2 = 1 - tanh_inner**2 - dcdf = 0.5 * sech2 * sqrt_2_pi * (1 + 3 * 0.044715 * x**2) - return grad_output * (cdf_approx + x * dcdf) + x = input_data + sqrt_2_pi = np.sqrt(2.0 / np.pi) + cdf_approx = 0.5 * (1.0 + np.tanh(sqrt_2_pi * (x + 0.044715 * x**3))) + inner = sqrt_2_pi * (x + 0.044715 * x**3) + tanh_inner = np.tanh(inner) + sech2 = 1 - tanh_inner**2 + dcdf = 0.5 * sech2 * sqrt_2_pi * (1 + 3 * 0.044715 * x**2) + return grad_output * (cdf_approx + x * dcdf) def silu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.ndarray: """ @@ -275,15 +280,13 @@ def silu_backward(self, grad_output: np.ndarray, input_data: np.ndarray) -> np.n return grad_output * sigmoid_x * (1 + input_data * (1 - sigmoid_x)) try: - return self.backend.core.dispatch_shader( - "activation-silu-backward", - inputs={"grad_output": grad_output, "input_data": input_data}, - output_shape=grad_output.shape, - ) + result = self.backend.silu_backward(grad_output, input_data) + if result is not None: + return np.asarray(result) except Exception as e: logger.warning(f"GPU SiLU backward failed: {e}. Falling back to CPU.") - sigmoid_x = 1.0 / (1.0 + np.exp(-input_data)) - return grad_output * sigmoid_x * (1 + input_data * (1 - sigmoid_x)) + sigmoid_x = 1.0 / (1.0 + np.exp(-input_data)) + return grad_output * sigmoid_x * (1 + input_data * (1 - sigmoid_x)) def swiglu_backward( self, grad_output: np.ndarray, input_data: np.ndarray, gate_data: np.ndarray diff --git a/shaders/fnn-linear.glsl b/shaders/fnn-linear.glsl index a18d47d..5f4781c 100644 --- a/shaders/fnn-linear.glsl +++ b/shaders/fnn-linear.glsl @@ -1,58 +1,71 @@ #version 450 +// Tiled GEMM: output = input @ W^T + bias +// 16x16 tiles, shared memory, 4-way unrolled K loop. 
+// Correct and stable. Performance limited by per-op dispatch overhead. + layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; -// Input features (batch * seq, input_dim) layout(set = 0, binding = 0) readonly buffer Input { float input_data[]; }; -// Weight matrix (output_dim, input_dim) layout(set = 0, binding = 1) readonly buffer Weights { float W[]; }; -// Bias vector (output_dim) layout(set = 0, binding = 2) readonly buffer Bias { float b[]; }; -// Output (batch * seq, output_dim) layout(set = 0, binding = 3) buffer Output { float output_data[]; }; -// Parameters layout(push_constant) uniform PushConsts { - uint batch_seq; // batch_size * seq_len + uint batch_seq; uint input_dim; uint output_dim; - uint has_bias; // 1 if bias exists, 0 otherwise + uint has_bias; }; +shared float tileA[16][17]; // +1 padding to avoid bank conflicts +shared float tileB[16][17]; + void main() { - // 2D parallelization: each thread computes one output element - uint row = gl_GlobalInvocationID.y; // Sample index - uint col = gl_GlobalInvocationID.x; // Output feature index - - if (row >= batch_seq || col >= output_dim) { - return; - } - - // Compute: output[row][col] = sum(input[row][k] * W[col][k]) + b[col] - float sum = 0.0; - - for (uint k = 0; k < input_dim; k++) { - uint input_idx = row * input_dim + k; - uint weight_idx = col * input_dim + k; - sum += input_data[input_idx] * W[weight_idx]; + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint col = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + float acc = 0.0; + uint numTiles = (input_dim + 15) / 16; + + for (uint t = 0; t < numTiles; t++) { + uint k_base = t * 16; + + uint k_a = k_base + tx; + tileA[ty][tx] = (row < batch_seq && k_a < input_dim) + ? input_data[row * input_dim + k_a] : 0.0; + + uint k_b = k_base + ty; + tileB[ty][tx] = (col < output_dim && k_b < input_dim) + ? 
W[col * input_dim + k_b] : 0.0; + + barrier(); + + for (uint k = 0; k < 16; k += 4) { + acc += tileA[ty][k] * tileB[k][tx]; + acc += tileA[ty][k + 1] * tileB[k + 1][tx]; + acc += tileA[ty][k + 2] * tileB[k + 2][tx]; + acc += tileA[ty][k + 3] * tileB[k + 3][tx]; + } + + barrier(); } - - // Add bias if present - if (has_bias == 1) { - sum += b[col]; + + if (row < batch_seq && col < output_dim) { + if (has_bias == 1) acc += b[col]; + output_data[row * output_dim + col] = acc; } - - uint out_idx = row * output_dim + col; - output_data[out_idx] = sum; } diff --git a/shaders/spv/fnn-linear.spv b/shaders/spv/fnn-linear.spv index 872db0e5adab137875fe6e11b326c35432017162..7788a2215a03467567563ef2936b7276ae406b69 100644 GIT binary patch literal 6104 zcmZ9O2b5i96@_n@l2IZAfkXx33<&`YiJ%FfKrjJKB0*5lSQ%d?Lo#ME6EbfoHcANg z0@y2d#fH5r#$LeQTWnygh*H(H{J#6{CjWYO)>&tted=HCebcko{FzlXquRR~sQz1Z zp8cx5U@Ex&TCZAr*4iag?F~zoE<4(YeX5>1(3pLznN=UU7unJpo6xa0a$jUN`SSbK zSldn&**$cqs+|~pM3lLG_y+K8o4f#De_dPZJSqZo1EUd=2UYA>iY0&Cr4W2 zYbMS|wZ|qW^y=@a=AcI=$7?(f+`g#84~Z}A;DzAU$Otig)e+!{=`ClEjgL+NYwk$& znbT97PMMsTYQy&PC+D-i)gIY2JT-bAIJ|2co7g(t9^NpvrQlubK-){(rNe!?^oM>_t@6-G2Xq`Z!Pf< z;uHJOj#=B6fy-XovC~WJlU}jscj&s8=aXJyd%13VX_vij!Y+GxX6cplORv~vFV8f+ z#P)JO?WJAz`UiH|%QNrv>Z@}8dusRAXKy`Ytmj=Q>-`7Y^S?i`KV9uRM6~hZ1Ie-9 zAmTHKJrk|S&&9r1gU!)@Z}bdA{)EQ1huk7;^Sl$<#^0yG^6}3GE6MksmicqBedaLf z0`y=lz13l8XOpv*`f$W&9-WUw+ou#~co3ZP>bIx%qK0!1jsg!NnL>W1GJOY<{_h zW3cx{4qzuwL}zV}1p9qIn7l`!ol#DE6 zJh1PSwsoF{_B`dRqwPB+r+>7;=4w9!ZC~y5-P&;J>$lQ;`(B1l-|b*+>s*dDM$S6g zI~p$gab=f%752c4J^Qs2+q;$c8^OJe?fQ2i`kmt@bmq7jtnD0sLOX|?b7=2wIP-k3 zl6MPO+q{1^F>>Z<`wm&_C{Eim_Z{-Qlv&8Jd^`0!qmQ%Q3DbwR#!U18qK@rvZ)wl# zvWE)$Fz~{{Ht&co+uy;m-tsPcMPXa-iG^+6cwy_mxUlv6o26dCs(gRtGd>(K z*YC@N5x+@(ce2i-8$RU0Lx`=@336r zSAykYp8|IOvKOakv0Ly2+XM)X> zzk{BSL;JXPeUC-tT)Q}R&H_7MdYujKKI_NB$*0#7z~;$&KB=p|jknYNUx(P&eKy{n z&h=!pk8>IGBxDofT;j|%3@-QJ>Db+KwczA)ZtKD3$*0Z;*!AUXHh_)m@cs4g&$-y= zAnvbza~GgDBKkZ-zf;<-*EKk=bI0fVQu>TZoiVWaEBG4u`?eYJx72%Z8e$DO<1+V_ 
zn%`qjV0&k>j!CdF^5&1D<&u9MxV&TAu!jo&6nGAjbDsvgzilKn&Iilq+%Eu|C+}Ha zg!b_)^<9X_xi{j}c_!HT((74Zdo4y1^K7tudOZhhp1k{-y6W7q=Ys9)UK?*u=eiW_ z<6OpEg1i)QE^&HZ0WSC8dDz`^Js(a!ckBgV^W;F_I(K}zWAUj|=Vf5?bH`o|-imk+JbP=%8MlkxSE7}x5Pesp z=OXrg1=`+fzsIjazZ&t;{~EM@b@t%3U~A<3UI&(&jd%{)uSBe6t~j}`2bXvC4cOMo zTHXlueKAg7*8C>0zSMa$*u2=+fQOLOc?;M*G42{9b>0frw}amf&*W{0wdE7@cJM;v z(kA8|U^!#mKYeoP{Z4SXrgvdGTkgoa!E&>Z%=aF!>oiw?dS46Hm)h?Im$lx9EoZI7 zydPYy{R7yp-B^98{Xww4to=H${p1t#A#l0&4`a(2o3+cO_ea3x+CPe2&igTJIqzTA z{&BEtH&=gpe*&y8wLb|iYkdk^&RU82G`L**XRuwnvHDW`vtWH$`{%&+lTXa&!R6Y& zfGuZi)-IRc*MrNoe-XQ!_ez)O+5Q{M;6Ek>-bPtN>hX!DOlC;tcF0VKamKLjsAPDc{+ zBd~nFYd;2?C+}~~PtiVp%@ zfxiOV*L4|hPv`m#+Q+$!`86WvT;j~}Td;HNLNdqi!1C8N|4aWKEa&&d+J8j*SXU#_5QotKEC0?gkQ2|i|DWGj_6|`!-$0V(7Y?}Z{DC!rT(f{J` z!SOfvpS;dE?CCS7^f;J?Ge><6;Zs46AUlnA$Bl91L3p`|3~O5@5jjlulzzvk;8BjNus6ou zt?n*+V_BY!?fU$7x7pZU?A-3Q8vS;+vv?IBea+eGZi6f7DERuIw|S-8>GgYw>Z2J4 zotlk)YqQ>4zX=ZBUfP`-gMNLjy;I`ttvl%NZG^uk-fZ;h&31$AdHsh4EPHt#-5cy= z_yzD*##hp7=)6vQ?M}hXp?BBY8=L)lMtTp6aEtbXwx+hL=&{jnOm_Rbw-%^B-~uVr_V_m*8i?vve<&6#6R z`w!Y~gIZ zcGMEPqE^_ooF-rXrculHwU+m3E$woxTiE4VzHwfwlD$9X-g_DETboaxu*>s1j~6`K!HCi2Iy??sD^asNK_5F&rRu&p6CgRQSKp{@S|1(uKf z39xIDE+)}aL^+T8l*RaCaMp6)dfMKnoc(*YDiZazA4OcoKZcGoSl|5`;`xtqsu^@O zlaX{1?fK>ORnH=>h&>Co)^TE;d3~n9i)d|YzKu4&oIPtV6`cF;6xdkvIM0#CdDYhM zeYAPx%%kmm%0)k)gMAzGG&4T$rv=tF&l=hoIrC^cdveA(cTwM&(>BiUpgwZOX*+Xr z*7vQxMf?vqiRC>zb8?DvaF*Y@{(cu+_WCzW1#M36+wY1xZ1aWfoA_N1TkW^KZ2NsL z+jINut-_w8s`Gdd8At5zLUE2m1%Hta?MFNO9|G&|+=l;Qu)O~Es~vksz~(UC`Sbhh z%sQ9hbG+bx#V7v_IM2REO25#P+!M$Y@)tFx(XM~lldmFj-laJ5o&eiV)P53d?HUq3 zPl4s5_S0bFK{&DzT|h`so3`d1PAdJgTfFMXaxT6 zXMS6e;HhU!zt2PEkGgMyjgyZ&Z-L8o=dtyZx9%KTF8r6l_AK-`pMhJ7FsU+Z-eFH9J^r85p!|}JU~3ZzUGkA&+o=L Date: Fri, 3 Apr 2026 15:11:03 -0400 Subject: [PATCH 06/17] Update learning.py --- functional/learning.py | 160 ++++++++++++++++++++--------------------- 1 file changed, 77 insertions(+), 83 deletions(-) diff --git a/functional/learning.py b/functional/learning.py index 982d761..ae63014 100644 
--- a/functional/learning.py +++ b/functional/learning.py @@ -10,14 +10,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def fisher_info( @@ -42,12 +42,22 @@ def fisher_info( Returns: Updated Fisher information """ - from grilly import Compute - - backend = Compute() - return backend.fisher_info_update( - gradients, fisher, momentum=momentum, use_ema=use_ema, reset=reset - ) + try: + from grilly.backend import _bridge + + result = _bridge.fisher_info_update( + gradients, fisher, momentum=momentum, use_ema=use_ema, reset=reset + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + gradients = np.asarray(gradients, dtype=np.float32) + fisher = np.zeros_like(gradients) if reset else np.asarray(fisher, dtype=np.float32) + g2 = gradients * gradients + if use_ema: + return (momentum * fisher + (1.0 - momentum) * g2).astype(np.float32) + return (fisher + g2).astype(np.float32) def ewc_penalty( @@ -70,10 +80,17 @@ def ewc_penalty( Returns: EWC penalty value """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.ewc_penalty(current_params, important_params, fisher, lambda_ewc=lambda_ewc) + result = _bridge.ewc_penalty(current_params, important_params, fisher, lambda_ewc=lambda_ewc) + if result is not None: + return float(result) + except (ImportError, Exception): + pass + diff = np.asarray(current_params, dtype=np.float32) - np.asarray(important_params, dtype=np.float32) + f = np.asarray(fisher, dtype=np.float32) + return float(0.5 * lambda_ewc * np.sum(f * diff * diff, dtype=np.float32)) def natural_gradient(gradients: np.ndarray, fisher: np.ndarray, eps: float = 
1e-8) -> np.ndarray: @@ -90,10 +107,17 @@ def natural_gradient(gradients: np.ndarray, fisher: np.ndarray, eps: float = 1e- Returns: Natural gradient: F^(-1) * grad """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.natural_gradient(gradients, fisher, eps=eps) + result = _bridge.natural_gradient(gradients, fisher, eps=eps) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + gradients = np.asarray(gradients, dtype=np.float32) + fisher = np.asarray(fisher, dtype=np.float32) + return (gradients / (fisher + eps)).astype(np.float32) def fisher_normalize(fisher: np.ndarray) -> np.ndarray: @@ -108,17 +132,6 @@ def fisher_normalize(fisher: np.ndarray) -> np.ndarray: Returns: Normalized Fisher information """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "fisher-normalize" in backend.shaders: - try: - # GPU Fisher normalization would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback fisher_sum = np.sum(fisher) if fisher_sum > 0: @@ -140,10 +153,15 @@ def nlms_predict(x: np.ndarray, w: np.ndarray, bias: float = 0.0) -> float: Returns: Predicted value """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.nlms_predict(x, w, bias) + result = _bridge.nlms_predict(x, w, bias) + if result is not None: + return float(result) + except (ImportError, Exception): + pass + return float(np.dot(np.asarray(x, dtype=np.float32), np.asarray(w, dtype=np.float32)) + bias) def nlms_update( @@ -165,10 +183,20 @@ def nlms_update( Returns: (updated_weights, updated_bias) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.nlms_update(x, y_true, w, bias, mu=mu, eps=eps) + result = _bridge.nlms_update(x, y_true, w, bias, mu=mu, eps=eps) + if result 
is not None: + return result + except (ImportError, Exception): + pass + x_arr = np.asarray(x, dtype=np.float32) + w_arr = np.asarray(w, dtype=np.float32) + pred = float(np.dot(x_arr, w_arr) + bias) + err = float(y_true) - pred + step = float(mu) * err / float(np.dot(x_arr, x_arr) + eps) + return (w_arr + step * x_arr).astype(np.float32), float(bias + step) def nlms_ensemble(x: np.ndarray, weights_list: list, biases_list: list) -> np.ndarray: @@ -185,17 +213,6 @@ def nlms_ensemble(x: np.ndarray, weights_list: list, biases_list: list) -> np.nd Returns: Ensemble predictions (num_experts,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "nlms-ensemble" in backend.shaders: - try: - # GPU NLMS ensemble would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback predictions = [] for w, b in zip(weights_list, biases_list): @@ -217,17 +234,6 @@ def nlms_metrics(errors: np.ndarray, update_count: int) -> dict: Returns: Dictionary with metrics (rmse, mae, etc.) 
""" - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "nlms-metrics" in backend.shaders: - try: - # GPU NLMS metrics would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback rmse = np.sqrt(np.mean(errors**2)) mae = np.mean(np.abs(errors)) @@ -250,10 +256,20 @@ def whitening_transform( Returns: (whitened_data, mean, std) """ - from grilly import Compute + try: + from grilly.backend import _bridge - backend = Compute() - return backend.whitening_transform(data, mean=mean, std=std) + result = _bridge.whitening_transform(data, mean=mean, std=std) + if result is not None: + return result + except (ImportError, Exception): + pass + data = np.asarray(data, dtype=np.float32) + if mean is None: + mean = np.mean(data, axis=0) + if std is None: + std = np.std(data, axis=0) + return ((data - mean) / (std + 1e-8)).astype(np.float32), mean.astype(np.float32), std.astype(np.float32) def whitening_apply(data: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray: @@ -270,17 +286,6 @@ def whitening_apply(data: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.n Returns: Whitened data """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "whitening-apply" in backend.shaders: - try: - # GPU whitening apply would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback return (data - mean) / (std + 1e-8) @@ -297,17 +302,6 @@ def whitening_batch_stats(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: Returns: (mean, std) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "whitening-batch-stats" in backend.shaders: - try: - # GPU whitening batch stats would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback mean = np.mean(data, 
axis=0) std = np.std(data, axis=0) From d289a835e4f7e2d6880dfd5e6243164b90b08fdd Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 15:22:57 -0400 Subject: [PATCH 07/17] optimizations --- backend/jit.py | 10 ++-- functional/bridge.py | 93 ++++++++++++++++++------------- functional/cells.py | 91 ++++++++++++++++++------------ shaders/attention-output.glsl | 26 ++++----- shaders/flash-attention2.glsl | 37 ++++-------- shaders/fnn-linear.glsl | 14 +++-- shaders/spv/attention-output.spv | Bin 5096 -> 5152 bytes shaders/spv/flash-attention2.spv | Bin 13092 -> 12828 bytes shaders/spv/fnn-linear.spv | Bin 6104 -> 6272 bytes 9 files changed, 147 insertions(+), 124 deletions(-) diff --git a/backend/jit.py b/backend/jit.py index 3eef6fc..aa56cf2 100644 --- a/backend/jit.py +++ b/backend/jit.py @@ -24,6 +24,7 @@ def forward(x): import functools import logging +from collections import OrderedDict logger = logging.getLogger("grilly.jit") @@ -234,7 +235,7 @@ class _JitWrapper: def __init__(self, fn, warmup: int = 1): self._fn = fn self._warmup = warmup - self._graphs: dict[tuple, TracedGraph] = {} # cache_key → graph + self._graphs: OrderedDict[tuple, TracedGraph] = OrderedDict() # cache_key → graph (LRU) self._warmup_counts: dict[tuple, int] = {} functools.update_wrapper(self, fn) @@ -252,6 +253,7 @@ def __call__(self, *args, **kwargs): # Check if we have a compiled graph for this shape+kwargs combo if key in self._graphs: + self._graphs.move_to_end(key) # Replay (for now, execute normally — C++ OpGraph replay TBD) return self._fn(*args, **kwargs) @@ -263,14 +265,14 @@ def __call__(self, *args, **kwargs): if count >= self._warmup: # Trace and cache - graph = trace(self._fn, args) + graph = trace(lambda *a: self._fn(*a, **kwargs), args) self._graphs[key] = graph + self._graphs.move_to_end(key) logger.info("JIT compiled: %d ops captured (key=%s)", graph.num_ops, key[:2]) # LRU eviction if len(self._graphs) > self._MAX_CACHED_GRAPHS: - oldest_key = 
next(iter(self._graphs)) - del self._graphs[oldest_key] + self._graphs.popitem(last=False) logger.info("JIT evicted oldest graph (cache size=%d)", len(self._graphs)) return result diff --git a/functional/bridge.py b/functional/bridge.py index 53e1c49..449a9ef 100644 --- a/functional/bridge.py +++ b/functional/bridge.py @@ -8,14 +8,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def _to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def continuous_to_spikes( @@ -40,16 +40,30 @@ def continuous_to_spikes( Returns: Spike trains (batch, num_timesteps, spike_dim) """ - from grilly import Compute - - backend = Compute() - return backend.continuous_to_spikes( - continuous, - num_timesteps=num_timesteps, - encoding_type=encoding_type, - projection_weights=projection_weights, - projection_bias=projection_bias, - ) + try: + from grilly.backend import _bridge + + result = _bridge.continuous_to_spikes( + continuous, + num_timesteps=num_timesteps, + encoding_type=encoding_type, + projection_weights=projection_weights, + projection_bias=projection_bias, + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + x = np.asarray(continuous, dtype=np.float32) + if x.ndim == 1: + x = x[None, :] + if projection_weights is not None: + x = x @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + x = x + np.asarray(projection_bias, dtype=np.float32) + x = np.clip(x, 0.0, 1.0) + t = int(num_timesteps) + return (np.random.rand(x.shape[0], t, x.shape[1]).astype(np.float32) < x[:, None, :]).astype(np.float32) def spikes_to_continuous( @@ -76,17 +90,29 @@ def spikes_to_continuous( Returns: Continuous values (batch, output_dim) """ - from grilly import Compute - - backend 
= Compute() - return backend.spikes_to_continuous( - spikes, - encoding_type=encoding_type, - time_window=time_window, - temporal_weights=temporal_weights, - projection_weights=projection_weights, - projection_bias=projection_bias, - ) + try: + from grilly.backend import _bridge + + result = _bridge.spikes_to_continuous( + spikes, + encoding_type=encoding_type, + time_window=time_window, + temporal_weights=temporal_weights, + projection_weights=projection_weights, + projection_bias=projection_bias, + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + s = np.asarray(spikes, dtype=np.float32) + window = max(1, min(int(time_window), s.shape[1])) + cont = np.mean(s[:, -window:, :], axis=1) + if projection_weights is not None: + cont = cont @ np.asarray(projection_weights, dtype=np.float32).T + if projection_bias is not None: + cont = cont + np.asarray(projection_bias, dtype=np.float32) + return cont.astype(np.float32) def bridge_temporal_weights(weights: np.ndarray, temporal_window: int = 10) -> np.ndarray: @@ -102,17 +128,6 @@ def bridge_temporal_weights(weights: np.ndarray, temporal_window: int = 10) -> n Returns: Temporally weighted weights """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "bridge-temporal-weights" in backend.shaders: - try: - # GPU temporal weights would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback - Apply temporal decay to weights temporal_decay = np.exp(-np.arange(temporal_window) / temporal_window) # Simple implementation: return weights with temporal scaling diff --git a/functional/cells.py b/functional/cells.py index f97d5b7..34ca1aa 100644 --- a/functional/cells.py +++ b/functional/cells.py @@ -7,14 +7,14 @@ import numpy as np -def _get_backend(): - """Get backend instance""" - try: - from ..backend.compute import Compute - - return Compute() - except Exception: +def 
_to_numpy(result): + if result is None: return None + if isinstance(result, np.ndarray): + return result + if hasattr(result, "numpy"): + return result.numpy() + return np.asarray(result) def place_cell( @@ -39,16 +39,30 @@ def place_cell( Returns: Firing rates (n_neurons,) or (batch, n_neurons) """ - from grilly import Compute - - backend = Compute() - return backend.place_cell( - agent_position, - field_centers, - field_width=field_width, - max_rate=max_rate, - baseline_rate=baseline_rate, + try: + from grilly.backend import _bridge + + result = _bridge.place_cell( + agent_position, + field_centers, + field_width=field_width, + max_rate=max_rate, + baseline_rate=baseline_rate, + ) + if result is not None: + return _to_numpy(result) + except (ImportError, Exception): + pass + pos = np.asarray(agent_position, dtype=np.float32) + centers = np.asarray(field_centers, dtype=np.float32) + if pos.ndim == 1: + pos = pos[None, :] + diff = pos[:, None, :] - centers[None, :, :] + d2 = np.sum(diff * diff, axis=-1) + rates = baseline_rate + (max_rate - baseline_rate) * np.exp( + -d2 / (2.0 * field_width * field_width + 1e-8) ) + return rates.astype(np.float32).squeeze(0) if np.asarray(agent_position).ndim == 1 else rates.astype(np.float32) def time_cell( @@ -75,17 +89,31 @@ def time_cell( Returns: (firing_rates, updated_membrane_state) """ - from grilly import Compute - - backend = Compute() - return backend.time_cell( - current_time, - preferred_times, - time_constant=temporal_width, - max_rate=max_rate, - baseline_rate=baseline_rate, - membrane_state=membrane_state, + try: + from grilly.backend import _bridge + + result = _bridge.time_cell( + current_time, + preferred_times, + time_constant=temporal_width, + max_rate=max_rate, + baseline_rate=baseline_rate, + membrane_state=membrane_state, + ) + if result is not None: + rates, mem = result + return _to_numpy(rates), _to_numpy(mem) + except (ImportError, Exception): + pass + pref = np.asarray(preferred_times, 
dtype=np.float32) + rates = baseline_rate + (max_rate - baseline_rate) * np.exp( + -((float(current_time) - pref) ** 2) / (2.0 * temporal_width * temporal_width + 1e-8) ) + if membrane_state is None: + mem = rates.astype(np.float32) + else: + mem = (0.9 * np.asarray(membrane_state, dtype=np.float32) + 0.1 * rates).astype(np.float32) + return rates.astype(np.float32), mem def theta_gamma_encoding( @@ -112,17 +140,6 @@ def theta_gamma_encoding( Returns: Theta-gamma encoding (n_theta * n_gamma,) """ - backend = _get_backend() - - # Try GPU shader if available - if backend and hasattr(backend, "shaders") and "theta-gamma-encoding" in backend.shaders: - try: - # GPU theta-gamma encoding would go here - # For now, use CPU fallback - pass - except Exception: - pass # Fall back to CPU - # CPU fallback theta_phase = 2 * np.pi * theta_freq * time gamma_phase = 2 * np.pi * gamma_freq * time diff --git a/shaders/attention-output.glsl b/shaders/attention-output.glsl index fd286a6..52f2542 100644 --- a/shaders/attention-output.glsl +++ b/shaders/attention-output.glsl @@ -44,21 +44,21 @@ void main() { uint head_idx = remainder / head_dim; uint dim_idx = remainder % head_dim; - // Compute weighted sum: output[batch, seq_q, head, dim] = sum_k(weights[batch, head, seq_q, k] * V[batch, k, head, dim]) + // Compute weighted sum: output[batch, seq_q, head, dim] + // = sum_k(weights[batch, head, seq_q, k] * V[batch, k, head, dim]) + // Hoist loop-invariant index math and use FMA in the inner loop. 
float sum = 0.0; - + + uint weight_base = batch_idx * num_heads * seq_len * seq_len + + head_idx * seq_len * seq_len + + seq_q * seq_len; + + uint v_base = batch_idx * seq_len * num_heads * head_dim + + head_idx * head_dim + dim_idx; + uint v_stride = num_heads * head_dim; + for (uint k = 0; k < seq_len; k++) { - // Weight index: weights[batch, head, seq_q, k] - uint weight_idx = batch_idx * num_heads * seq_len * seq_len + - head_idx * seq_len * seq_len + - seq_q * seq_len + k; - - // Value index: V[batch, k, head, dim] - uint v_idx = batch_idx * seq_len * num_heads * head_dim + - k * num_heads * head_dim + - head_idx * head_dim + dim_idx; - - sum += weights[weight_idx] * V[v_idx]; + sum = fma(weights[weight_base + k], V[v_base + k * v_stride], sum); } // Output index: output[batch, seq_q, head, dim] diff --git a/shaders/flash-attention2.glsl b/shaders/flash-attention2.glsl index c9ccf74..3c89d09 100644 --- a/shaders/flash-attention2.glsl +++ b/shaders/flash-attention2.glsl @@ -81,6 +81,9 @@ void update_online_softmax( void main() { uint row = gl_GlobalInvocationID.y; uint col = gl_GlobalInvocationID.x; + uint seq_heads = seq_len * num_heads; + uint head_stride = num_heads * head_dim; + uint bsh_stride = seq_len * head_stride; if (pass_type == 0) { // Pass 0: Initialize running max and sum for each query position @@ -130,19 +133,11 @@ void main() { // Compute attention score: Q[q_pos] @ K[k_pos]^T / sqrt(head_dim) float score = 0.0; + uint q_base = batch_idx * bsh_stride + q_pos * head_stride + head_idx * head_dim; + uint k_base = batch_idx * bsh_stride + k_pos * head_stride + head_idx * head_dim; for (uint d = 0; d < head_dim; d++) { - // Q index: [batch, seq, head, head_dim] - uint q_idx = batch_idx * seq_len * num_heads * head_dim + - q_pos * num_heads * head_dim + - head_idx * head_dim + d; - - // K index: [batch, seq, head, head_dim] - uint k_idx = batch_idx * seq_len * num_heads * head_dim + - k_pos * num_heads * head_dim + - head_idx * head_dim + d; - - 
score += Q[q_idx] * K[k_idx]; + score = fma(Q[q_base + d], K[k_base + d], score); } score *= scale; @@ -154,9 +149,7 @@ void main() { } // Get running max and sum for this query position - uint q_flat_idx = batch_idx * seq_len * num_heads + - q_pos * num_heads + - head_idx; + uint q_flat_idx = batch_idx * seq_heads + q_pos * num_heads + head_idx; float old_max = running_max[q_flat_idx]; float old_sum = running_sum[q_flat_idx]; @@ -176,23 +169,17 @@ void main() { // Rescale existing accumulator if max changed if (new_max > old_max) { float rescale = exp(old_max - new_max); + uint accum_base = q_flat_idx * head_dim; for (uint d = 0; d < head_dim; d++) { - uint accum_idx = q_flat_idx * head_dim + d; - output_accum[accum_idx] *= rescale; + output_accum[accum_base + d] *= rescale; } } // Accumulate weighted value: output += weight * V[k_pos] + uint accum_base = q_flat_idx * head_dim; + uint v_base = batch_idx * bsh_stride + k_pos * head_stride + head_idx * head_dim; for (uint d = 0; d < head_dim; d++) { - uint accum_idx = q_flat_idx * head_dim + d; - - // V index: [batch, seq, head, head_dim] - uint v_idx = batch_idx * seq_len * num_heads * head_dim + - k_pos * num_heads * head_dim + - head_idx * head_dim + d; - - // Accumulate: output[q_pos] += weight * V[k_pos] - output_accum[accum_idx] += weight * V[v_idx]; + output_accum[accum_base + d] = fma(weight, V[v_base + d], output_accum[accum_base + d]); } } else if (pass_type == 2) { diff --git a/shaders/fnn-linear.glsl b/shaders/fnn-linear.glsl index 5f4781c..d016877 100644 --- a/shaders/fnn-linear.glsl +++ b/shaders/fnn-linear.glsl @@ -40,25 +40,27 @@ void main() { float acc = 0.0; uint numTiles = (input_dim + 15) / 16; + uint row_base = row * input_dim; + uint col_base = col * input_dim; for (uint t = 0; t < numTiles; t++) { uint k_base = t * 16; uint k_a = k_base + tx; tileA[ty][tx] = (row < batch_seq && k_a < input_dim) - ? input_data[row * input_dim + k_a] : 0.0; + ? 
input_data[row_base + k_a] : 0.0; uint k_b = k_base + ty; tileB[ty][tx] = (col < output_dim && k_b < input_dim) - ? W[col * input_dim + k_b] : 0.0; + ? W[col_base + k_b] : 0.0; barrier(); for (uint k = 0; k < 16; k += 4) { - acc += tileA[ty][k] * tileB[k][tx]; - acc += tileA[ty][k + 1] * tileB[k + 1][tx]; - acc += tileA[ty][k + 2] * tileB[k + 2][tx]; - acc += tileA[ty][k + 3] * tileB[k + 3][tx]; + acc = fma(tileA[ty][k], tileB[k][tx], acc); + acc = fma(tileA[ty][k + 1], tileB[k + 1][tx], acc); + acc = fma(tileA[ty][k + 2], tileB[k + 2][tx], acc); + acc = fma(tileA[ty][k + 3], tileB[k + 3][tx], acc); } barrier(); diff --git a/shaders/spv/attention-output.spv b/shaders/spv/attention-output.spv index 7f53be4b489c5ce174f6c0aec32d56663e0820cc..94f9b5841486888b8010794b936f9dd889ab9055 100644 GIT binary patch literal 5152 zcmZ9OiI*Hz6^ARCo=JeP6A*+jAwU3iWCul-7;q98NLWlzFrChHn`!CkNv5kqSd=YB zS&g#qplARU5nOObWKrA@xBo=Jb3BLdtF9Z~smi&zzx#W4d-uJno}sx*=4IKOY>#Zu z?BCho**BXDlfjJ?dVJ#2iLqYZ96R}xbPbtqw7V|sf$Re>x32__ z`YJ>Gn~Q!0qn7QSjZf6aSFIYKIIljrX;ZzIU(;#TJMC_(-fYjdCiC`|mVWaSzjk-B z({H8@!+50ihp9228XH%hOKkjl) z1)i5Zj9!-K434R;Et0v_VlJ-v3~Ka>`i~B%qSm*;YZ{$?tC(*s)Am_|*z3;164$a9 z!8855+(#|@3-+phzNw!N=Bn+34_v+3$QwnDw+3`JP7d#;awsxRe$UM_T7hSr?GkLC z;5=73?^tlg$a&9#8!EX4Xlp%$9zl$=mb1`j1x(E41BE@j%QmNv{?5sDbBPxtK4Bj* zFbn%Ba9YbbMy;4>)CxPT<(#9I*jnz}TH0x?KVzr0oco~GaB(J>#cb5ZJAd;zv#`^A z|AO~S4o3FJOTWcrG^cnWIjlE|_>5xDLo4zJVeeLAV_e@IJqM9LqqMCd_aSWKJSQLh z_e8rUzli_63*uDt+Yip1uA5ie_g2n64?vG1QAc|r;$!@Q=;+lxTzB88|6$nfJJuJ2 z-PZxIOVI~nU5Oqoq?UaS?JVSsQP&Zl!|?SU$}IwW@3igZ5wvyW+>5sNOwRQuOKhxp zykn8aJEX1OchJVjnMZqP$r!E{c7v){W}Qp&U!Dy-*;a= zd=3HY*6uwU2`3Vo%3|t1~odhQz`#u?LoP5Nc0I0H^Tdig}@lkVY6IQi)1EUa#U9QDr+YXL+d64~pTd@pUdBtG^!}}YlaF3jf{l}pd7clp zo@>t2Z-ZR)asfE)Wdgg}%Y|_A(aS}pPuj~WIQi&hHP|@$n9Idr>$w*9Q!aYB6rA>Q z8FqF4pN5l<{jC8TCvQ%_b#jsOa&VgSGuYLfpM{f;KCb{9CvVQRXgTpZbRCKO4X}4K zX0rio{`4#-vE^enO|WtD5!V76cQ6vUKM$6VxQ$?ApuNw;61#{Eg*1G7GjZ*P^E$*!vl|b8t0xA5K1U 
zZvh)8AGxmvo8L9hY&}{o>Rba(&+J<4YMtxgAdpge@O&H-n9pkG^jKr)RbeyPD_AaPpC7JJ>k+n9o_Z=L!- zSuML8eGlT}`n_n^)v>4hz~+d4?g!g*`i?z-E&qJ^8~h;HIC+0}wZDN_&-~)Z{}_0E zVm}VHUd-?bus-t3_>1I8^f!@5k%h<;h`Fvre+zLfaz71r=3zerHh0+1f}K^&@w;Gs zVF@+6S2PQQU5uxYhnKYybf8wUlaECJTeD45wTbO_3@nabKP_BOq`QC zdi^0doy!Z@a-Kt+vs`-4&O73N1Rg`;9DfY1p5ss8yOkk>|Hy<6cH0?iFyg{_o)A zBkonOvGQ^LuYv8$wdm<}@Fhs(egj<1{d+k1$o&Vfaq^M-O|bc0i`;(%uSLS=PhkDz nJ?9tEa^e3su$;Dc>}|xm6nFLS;I|O(s=nrs)6c(sPe_ZmDnj+!Xlk!ReikESY0=7W|G%vjW@fsG+}^@ zbp8S6n8_T=7adP;GxL6yH!CB}sb$sXqSorE@oJ|r)nYVbe2`wVtTWZAHY@dJeX`!_ zwB-l#-cRecCl*Y#+H!;Zq(65+-u1QXD=W=yv@0;$lKZ6n}w?)~tq>h5T*zN+Bn;d}EoYM19+f3^VcoGXi*Gw^xcUJ8}^ zv!!rMQ;~9>i_~-{b36vV6m!l?@R!va%O^T{DolU27H)MfB%YByLOj3IVUyUwrFpW; z%J;xAzg~?t`F+{C9+lU68oZ?1?AG(W?9Yt-mSE?+iS%5Z^VncofX|q zZx4jiB)!YCB^i)c1@$kA-B$InE{a3(hl?^F9TqM$S7GTwlR$ zig&Ju@q?H;=W;L1SqKyRvO^vZY>2JtXTJO7xV6N?m|w)Z_QE`V3Y^a6KE_0It}%q4cNX_#{vr*r+4IGxLV@69!kp9x{UH+uE%zxCX+h|_xig!fElW808r-Y^xd zDV{|Q=NrQOhKOh2HTms{H!83i$2Z1L!{m=DV&{l*tTm9x)n@k3b5V>}DVS z&o0Ef7jkO!70hoJa=nLgbHLs^V|#fN?>us@#dsYi=lJ6VR%@N7@KNV!u(5f+$E%UE zj`8|}Q|EmNop;4p-IqmE@Vs+53AmYjU5qz~ql%#jelTeDB-#n0M5B82SFzb$d|`w>_^bVFdZ@}b)mtZo#hMNTm-p}> zIQhsqxX4NO>kv5kSofh|b@HJ*3|!u?!{OvZ_hYbH`5THeI|6KJUIF2^L(&6d224k%ZV?-R|-6a-xT{FzriZr-=hBC{9ecLHO$ZP zI^J=8^mZ|L42ykT4z{=1SmaEATC1V0H4Ln*x{nS^*~?x|LwH@}XM=F6-KG@}cX1)yhZTU9jgEbykDRbuNRGk2;rw z)yaqM3UFC>C7gWdehyYEzoz*8{sL@Yjzv#bf&J}7?W@7%+SkCzN9{FWb@EaBTCnxk zU{U)z@E8_3zXY2n@145@FBkdOgVS@pfwN|86+>(ES#y zRzCLicVPQ+EPA>JoZhv2iOaR`gOiWi_k-2RN9_l|)^{vwKL}omMb1ND^W;6(wRpM6 ze+=wd&*y(HkK>=f)?u?S$IQC`|0L#E=>7n2Wn(bH$(^lpDnT(12`IQgjk zC$KvCsQqWK^&N}ae*vd=`wL?8HCWEryY~&|J&Qa0E%+0LQhqi~s-t diff --git a/shaders/spv/flash-attention2.spv b/shaders/spv/flash-attention2.spv index f945782f5f9c478d220fe3028ae7e67c9984b12e..04eb457759c24ad8900e7ed78fefef0ce95a638e 100644 GIT binary patch literal 12828 zcmZvi2Y_8wxrI-dnIu4H(iJeF20;YrB47kekOT;!i0yirnUFD)naSLlP*kKFET{wv zR=@(HQba|ufS_U*%R>bpAOcn_C{+=^@7}YL&3*5f?Oto`|Nr;D>p63^XRYm~Rn?Sg 
z-D8;kO=FV@-J?f~r^ABtEkBv226Q>OiHim~r1{(uIUZrMutXCZs5^a(Hw5;>lS)za#^kSR-2#?o;ubT9Uj20 zr<$$Z=qcdo?YTI4xG{Hl^rYtS;UgMUbt-QK0{x+T4z&LW6f5pF>%^h8D~2Bl4gr} zZ7l_^ws#$tH`+Z64V=10cP*9{x^uR#jy%rU9P0KMS`*_#1A}dC9@Sffv?u0dbk7+t)$HAo`0Lp#>A z>U!#>|8HAwbwk^a*=^hQRyVb6!~3c`;ENeRdvI-=zUofu&OUB0J9+h0_n@`cv?HPP zR`;W=IitPR1MuaIlbS6ym+cS3m#*eL)zk3I={g*Wg_DEhLxU}5gImW#!n2U#T+OJt zaW)1YF?gCoXxqwtj-Y6}>1x})Y8SBkc+GgGR(qn&nV1+Haf2=BY;w=0x=;4KvGy)v z9Ryz9STj53ayS}yig!?Zj0_&%(?5S?nO&p_kkA=jd1pz8DsZV52AI(vS!~s zgf=odzN|US9fz}X@F-g6{B_Q=+^0TsykMH`TpzFXx%l}0^IfqOWv=T??HxQ1%iP`D z!F{6XuGU*<`lxwl#;&K!-dseCHb5zwOVXo z&0S?&?0%+&I@# z>!W5Ix%Wrn>`2?K1vl=)@TVWVYd7QOZZWR+irl!5Q5!?exN`52GVZ_Op7~wLc4unu z6*Y_Jb01b!d*{l>z5Y0gYi-P(sAp5A6@2$DzE>BY)5Yg?@kL#{zl%?FasO^9`(M?? zFYDs|jZxZP-^Kk~qHOnXiIV%bM9KX-qU8P!5pI2d!FK{b(8c}xA@;WW_e073)-So= z>*3aZtC!qw^^z~_;)}ZYab4VR_0r#O^^*IoUUI+H!*?Q|$u919dfD!`ddXLH@k_h7 z-`}OZ-`pj?x8Ul2bC>OYZ--m^t!+JrUF>^jH;V5j_xpkDY0ujWDC!5Xk?+I4d-u*Av-gn3W6yH0(TVns>T93^^wT^b(4G9lfXhn~hpx zycBFdWsF1M>ajVjw(-4{eqIJvPd{_P`l&mfIn-*2aX7epJTFI6Ps}6Q_G@C!hijv5 zOwYVpVjc-LrtiaW2k+eWX)__NO=dv`$?SG$AOLQyEw7^ zJ5N0}CxE-h^a?cf%*%;j{nX?4O0a(AoWBaL9=}(E^;LK5i>TF-%WJ^yvGCV|wYj<; zcLS`Ay5nwA`#5gfUPn=L+~TLIDQ&Elm;+$@u>XE)dD@1+YIAr8ETtZ%ETK4Gw&|nh zH%NSzgG-+le2k)xZThH<*VsFI56TndI7w};){ioVQ>ag+_}G3Lwe8mF?{u&+GVgBy zJMZN_ej{A{+IrsJ1lCX8xstD-*q`ymiT`G>^ON{zg1h6Ng{Gc$JsYf_dg7b|HqSHb zGkY#rJ9X#e3~II5UjSCi+`kpPlHxqtrjMHQ=-fC@=5C!BZv&Toz8$W%KSh7r)XKSc z4&mjH&5xr7AN+5z{b9qlKkHb zR*%j5z}lF9Y~Bx6kIe_b+FU}3%?H8ivH1{Ko3gJD!_~8f9|7y9es;ZP9|hZ=Z8LaN zc=ujPVK{AUzS=mqtEexfq|c9m?elU<^7=SfJvN^JYg3NzlW_IeTv^*(L7h0C0;|X7 zs@kUP>(glJ8Q*8X`l)App9R~WZ5iLy;Ed0FwQ+oxQQMXrt_6Ee!mk562N%?H_c^dO z>iSr zgSAsPmz${7Vt*Uh{L=4t!1i12vG2mw^IrKLSU>eEIky{9f1jf6nf@WQk7ru_2NX5$ z3vv4X5!igv_m9Edecz6zp1ywq)=%AgC~>VH;mn!KPbuc&{h_^mn%B>%eauUnpHbAz zOPu$gT5|jaxZJnDgs-BUU+>#rfwfV0e(#`Gi~n6aqC)_=bZ25j=+ypFe@MQP1A|GuT+RCC)=&bIW}F1?<=krNrj1VD-fQ8(2T}#QrpQgdh&PzY%JU2_awL+f|{twvrwmHU!snx_!Q}03XjCeu2$F2QtT8p}e;$wSn-EN&6`@qJ? 
zI!uLk=Qs^bJ>N~!!TPCt#^h5d_Gka%#9t5WnuMS1GqNo`Nr80?C&_oVrRR) zQ^6a7ZA+}pz|KYZ=I}E17I1CU6MIYWrW9k_p4iU@+ZO&D@cr1#;8z~|-U`mY_TO;k zp}jUaSGGIH&bM=5oxHXNmw9aiSKFU>SwppQ4V_EZWG3}2ifiI;Qps=IS`Xh2ZocQ? zW4_zN)wD~FYFXzU!N%NyI(hB{R?j!wY_Qs{*m8 zZeNZ!HoKvzXWzd7te<*pb_d^tPxi_laP8C`lfQeY#eN^KTGo1B@ZJ>H+BWT+oBhDH zom0=%{$RDOXv_Q^09W_-x%j;p>>Qs}`yB{Z(>8Og<~n9xtdrlt;PNhd34AZ&We(=R zwNa1HOTjl_les(u?pz+&E}i$vp=j!fa~N1Z^~~MNz{a;NvFC#CAx>=O!L?I&JpPWP z7W9tfS2I!{Mx3S^E)4Gn=$}7EsI6b`03K zwk@QVr|np<_O|&umOO38gIzz{{7p-qwpW0SW!nkV^0d7YY%JSOq?Yf`HwZUv^(ycT zYUAl=d;DHqYx)^mp0?M3wYAOIa@%$!)@#*k?tC`D`e)8w2hO~=ZHb|#U1BuBWsH;H z_LCU>;LLm5mKbWk=^ks^?H_kM9uJ{#Vv#=k;K1 z)t!R@YPHxe1>0BnFxWM}5&sQ|y$r04x_8qkwU2j`Z6g#l?mYksRh?Y zJ!6^x^RGQ7+s${3THQFV*$Q<^&gm&&&*`C**qn-{o^yH{SU+`RPExChPp8g%{S9Ep zVeNbUP1I*le7tMkT(?`NuQS2MNIqwQyYo34O+D}RbHMtkJCE`=QtZ$E#fg6|*nGp! z1KVHjlJmjZsB7cicnjEh%N|?_UWxAgp?&NwP^aWBek)i_|M3g^YT7y{&W|}; zC&#yg%N*YUSIc|YHZ}b-=Z;;Ty=0wu?*x}?{w{d;9KRb)J!^guSU+`hzK~iiF)s#} z?}+!n=P^I=e=l4c_4vFG?0Yi&{qQ-3&n0ke)U!Sx02|A;_BnV z`FsSdpL$|{6l{Fk9OEi#HSwj?+1HnWT_bDv^~b0`PVsThKT)?^C&y2MjgcI$0C(qj zC7OEn^{2r4sk^V`ms9M|{>6!Z71%Wi|1{YCGPcivwNcl`ef?Rm_g8#A2X@Ti*MrL# zpNDIsp8fm<@O2c=g6)a@MX+s&bpyD}?@RD9_Lt$>s3-P~;G7lP6Z zH1(W`AAt2!kKYf$#xB3xeuSnTzaNA3Rd;T0qgG2Uw}WRFx%~vLO@1@^DOelz{9f}j zu)6c`Kh!?XgZj@YYR-c={oDcG9lPZI3wZY#`z4xs`u-JIKXunGajo-v&7EL#an7~3 zPxJaUwU2pea~DO;yu|svMlCts19mQ)r@O)OwEYHb{L7dgA;EY<{_0{tVVm-F_dWR*U^# zz-r0;ui%_V+mid=z_x||9qbtVE!VOA1I)kn{Mq&}CEsQL1naMjIXpzImK>e{n?qth z3Dzb#{R{j!#WSR>G1RnkU*z0bmuK9wEli_oL+hW8inT&veSZ#GOXHc7fL^ z@Va2*tViwt4?){aDAvh+eX#!F8-R_GZ_JIr+NirvH>UP+pW5~;%5y39BTk%6!Nv*S z3~WF7Zr&WMje5Q@w*WIG?Qcxmoj3DUPn<2mjw}4xU~Tf9StA*-vPD;tNXXD zG4)qVjBUW?W}7+7T?fbNxEx<>XM)T5ng!P<0o{9PS^QSXdhLU{U7pA?{3aOpGpUjt?|OhFby0Miz}PZJxdRqJ0l(|Do=l zY6I+s2Ua%5`bU#BCwZe%5RMr$t48fx?njVx#m9W%UQq`x^nI5K?9k=SVKg!YdZzo*&{{E~^W#Yc<` zkByI2R5Go_L01c!pV}s{Xq%*Fq^Eo!KtT8k&JZbCP4pZCWf#yO&O!BUn9P_#q zI2prv8|!ZlIril10gn%|5Tsl$`aZj$H&)5ftTa(n{ zFP-F?&|uMZRDQn7UV?? 
zpf$LbO;2?jeS3elmXo@As@vgO`>jpEbXRx6tv_d7)m_M$v-`ly>Y&;_hRlfE9xdN3 z*z6}x3@jfU7-KfND`g+o-MQMTmi2u)_}GDS974-h+2>fAw%bqI_Eh_V-Jk2nGp#xp zZtnQ_z_1(axc27t%*p#?KO5_iBG!@MWsUW-V=eRHxEH)fTH`pimfL%%yE+Zo`zQL- zYdQJ`@@R8;(@kN#W{2G0ArB+392h)n@p#+*T;$~g?GtXj=OMS}yr;SZ+}d!H^W0Tk zQRvtHAANUqE!>Jp`|j#n$fq=iCRmo%Tr2x{rp$M1Eqi|N0cQ_C_5Ai!52AbO+;vqC z!KK!R!HWimIs49xv3sg@a8IqDn~Mu@cx3s~<`8!m&d%8eaP9NgKF`WN^_k-((QMVa zQO)<^<2%gzetTLU7$)@VL0z%Ip0y?^pW#D8P0oH zZp(u64wl=t;Jkn3tTn#olbcm=-n(+U6r6XiobN2hJ)hnj*5kNOrthK|SL~y`kM^$f zY?`%3-U;lZd>VZj&$~47a<7*0yi<**kN(PKyzyF>`n*3AFL!Af&pXt3&XMty%Xm+l z^3?jgSK9H`-!Yl%+&Syxx9=VG#`X>}mrvv}*LEGbyqA(IcT{pkE^~QrC6{*TevlMw|A@4@7pZqq+d(_zc z>YqhFg(iPUEqgb~8DGD>YFYbd*H}LKeZf9E@{PYAy}GhZ>yS&GG51IB-J!2FDEnO| z=a>$ppG8Ye$_LRLm;MLSYinFCIu> z33_GYuAn!DoN<-Cj}q5;IiKcTnz$E(ox8}FfsGycaA&sv2H)^mB(OQYh!v3_zw2W zMBb-E_U}KX-f!PhUeF=?omA@AbjX)>$X9mAYdhqtJ7oWEP{zNdL%yv;_U{C-x4*NH zcLU$mA>Y>_Kh`1p?Vfo0`~4l+etw7Sw|A*OtwZ+PyVU#bUCPZ4*>CSs@3(g;pIgYg zQIFr=(JTA?UCMrYm$KjArR=wNDc{o}`wd;{A1Y*dzoAS0xQQ$n?KL#dk~p#Ff{oovOY9ed!Cm+9~ z!TR;l;&%*KK7KC->nrcrJ-c$TKNjr#MV<%N=K6Zv$APtxcibn?`#5fO$J68-xA=)_ zN(;*+=1E|4nBTiXIl5E9a&x(_d@sI&b_%VRrcNI@-=FbmfJ>hyvhT%SnmT>t&aCkQ zde?0oHTKh6tNkO4VSs)S%}4!N^y=-CZ!y>ynfF(Lo%canY?grKZ>Z;O2&|vHbEUkH zW!q{>90<%oO|a$*)wE3&rQbuTCig;&*JNl<#QI# z1Iz72b5E|O_i>)pt)ksIr5#(G*cX6}eLgMqUkH|u%|&2stUord2g}Fi60kNG(_-@m zuzYOZ2=2`FCOG-5;hVww$*-x`>@8sPshi0yIEVhNG=|f{)~k(kyM|s}a=s00&dX@2 z>+N9q*t`SWIlgzo$;akhwaw-9iE{;5J~r>JZOUBlfs@bp-V4@GKI3~I*nH|TzW0MO zKI_%S@m)%^X^i71%krq@KGEg0;DdrvH`naI;b71Y{t>t=px!8XmY<3ytAo11BJ)+NsSPcAin3taBoZzHdvT~zPeTfy4M zJHNNk%f=w z&-?ZlVEyD>8|Axc<}=`jzYOTV(n8 z{SK_JeEc2&>sOBF_sH_``vX{C`PB4B@YS@$`4hM^|DWOH6X!2r{p92KS8!+kzro4J z@9$uJ<#P}GFW9=&rKShL?t{es2e>o#LvZqm{ZFud@`?R0*!b!a`w_7FAU2PJwUc+x z-A6AM`*q-QT^~obmU3PHg)E<1{teboK4W+StY5jV|3Q|I-;-c{<>Ti|Ouuqnr@+a_ zZz@<{`PAJ7zP^aljog{P2Tne5rh)a7PyXrP&ipgrxN+KQkR-G0+;JL z6S*_?#&Gh9y$M)9`NZB7Yf*Np*jVw~5!v63^ivnVXM)vT$bD>_Szu$y$7VLzJ8pGtvlFtMwz>c0ywkkj 
zymRbR%g*S^G4Fyb=iby`om_cmdUv_DbKrNQIsd!Tr>5O&KJp&O)|hcT3t3LP)F_wz zvKQEx{x+66pADAJxAxxP7KfehzkQH>ykpfphbHHE#fkG=aCsJ=hitC$Eba#s1IRNb0$bLEySvz^hv3~(rE^B@$*x&GaiJ?wA*Z43o?fiN@4+qO_k1q2! z7g@fS7QYvQ%YASJvYfV=V>#C`^J1U+jslnS@*?EUdFg|b&%7KB)=xfbcMMp+@{W2j zoP7L_1?wxH{PV!&JLEXz&N#=z$tTVUVEyE?4le;4UtMCK2=>m1%}L1G$vZbk(#ysE zWUyS;>ZRaQ@b9Im)6TUz1+2~(^R=mUbI`pET*i1g@)_{GwJtGU0k6&&r_w7&_e!vF z)%jbTa&%{awO4mKy>fJCg53-18erw<7J!YVu1T*P-9oUj)b-OVAIx{>RALQ)XVM!_ zKlSljRCD?nTRFP3z}l)awz4|sXffD1itQj+|IGWVz?pY#jUlI9Vk`ldF@})MlNd|E znRjiCA*WsPj)0x7$fL;Cku$gqtd0CkR^GK=4(3nw44V4*j)Bd8aeaQr!P?53dzfA> z_7h-pMP31R&2PkiGh(j#6k;>3A1*f^1211`t(T4ZhHGp5&p z`O_Mcdh7K!T6yt#^!a{X1vaO>-_L94&!_o#$6QeB?XAJ}z7TAT+!YsrJ>%tf$?M_d z^Zk4=SU-8kth}0LKJ$wc{|#XGLgK#>+!_B(aPqlB-VD}HK5^axw$9uaZv|^7?>t;W zFBkhu!E(y(jmyB!TlV1Fz$@W*#zcXD`{C>k9g;`Mbg9vUknjOaDHakG}7(_4cXdO0Y4qCLaKI zuE|w!@>%l_g7uSk&6VFnGoSgziT@#Rx#k~6?u`EtIQgvk)nNVP6X&B~>&%*81J+L7 zd00y?7yIkL<@@qu$mT7-BR&o%pS+&{>nET2xE`!u`M&%loP7Ll0P8CszfXZX$N6bE z`S^VXtgn3P{w%osrvDsrXa3K_$tTVi!1~E2{};iX`M(4wAHN&H`pRcszYMl6b*bqq z;PRXPtH_*Y##(?c`n8Yw6`;|4s1PTJ|je4|urH z-v-tv=kq(@TWP)6sxyv$-vz77+5R5bSn~1xJ~(IFyz#vqtZsGv4*3CC&KU8%16=w# z7CCJ_i=Io*ihaiRLvT44KSGwvy`WA`|D1KtobqbMWV}0(vj^@1yG{qulHmG>NANz_rS?A|cb<(u!^tQ2U%>jwyLO3dpWhMx3bq#KTzhj`*Wc-VtV^4}(d4X4 zoZk`UQsY0s&V}>zAXquNe}b)F-9upI=pF%UpFRC3INxgOwDT-I23BW`hv}6Q<8g31 zM*IJn{tLNte*O(7pS67gte<@1{0D4(xl5h|YbS5sb@XzvpTbA6Txy?+ob#wIwReHl zMeYVW2LB&}W1EJI*qT3e{(FmjJ5C4duZ=aZ1eIKBmvmMxRMSeP1o80&MY9qh8zQ=a}%lmhEW9l!L7(0TkO`SC>yAF=k yaXG%&&H|V7H5*xsWvsge)ds zeo2nC?Ua$-MTacA38OpPKN}t`4R72yJbG5CGCf^twszDeOSNi!vNTbhnXI&`+b50d zOX8~aN^N!`4d@{vouP*wee{@V%rReYKGWt}X{<4`Wq77BJH7tQ48AX)6MnQ&Dc9E5 zx1(CsMqRJ|u52l~(x~NlIk>f>%@2q#Y~xkna-~8{Pj(czKD+h2YHhL!%(G)_(n(YlpOo@BR>7u^U;d@GVBKx*c|=)qvs*=rxdn5yY&DGw9wug4~oh!KLyA#`d`(A>MzP?Y| z)_En`7&+@`?oTk+;_C_%bZ14E0#5V4z4%^?$G~eIK)b@8WwMRPa4T){P@x(UI-@eG#zdNzz z{Y_I3v1;G%1Ca%Y`*SLL(2HJ#$X`s1_L72gA0zKzu?>TEn z{1M>Jo-5$wjrSgCi@gI!BIf%2c@#2;_+5&19$WA+@A24j$056m{d+W6-nCoTyCfGi 
zP6WHB(dQ&^XP?Kw$;W=43^q^Rn%+mbsQEasdlfZL0e9AXJe+*gJQZx7yfwYka$@hZ z=X5Xc*=cCsH}zfIwITF6#7F-bX#MJ#=S;9Q!afV^KHR__osAwwjFG>On%bu$*0z>7 za@T{c6>~oUEEoGb0+x&TQLtRtPXv3;ac1X$jgfb)8_;qQe;(L-6!ufV=7;@MaC&}E z!*;L7l9;E10A?NALlZrf;154631Lsuy-)*=YpLx_IL}}82KT-Mm6+S zWD0Sg^qad9T}SkJ7UPBO+|FWO`-ji>GxZr0b*91Q+ix3s2HAvoo}QsK<$2gkkvPBSgFU~UNYr=%SU%40g<$jKJ(FE% zAJ0VJi;!K2dn1lI7lEBGdc7EIuhmGzyaX&Cy)FiuC-1)QLR(iIcjcvE`%WOn+tayT zj`ndbV_t^59x-1WJzoWGe_PtWQ?JH$&e-GEfQ^xlyYgD_9>lw%-&!ltuS4{C7P|}E zxt+zn_77j&mGG&f&Ktny$6a|NcnpeX6G>8m{_y^==m{ly5^5#r*nP+Th99yYyKqIHJhtH=J*s?U)26I zIIZ;=Y&mO1%xA&rnm>m33S9cy+D> z^Gn!rHzTp;FN0mPx%y*{uYmPM?XQB9h|QD8`!SdSbb6Zn_zume+!(h z`P8m{_xZ-+AAE-&=Le@EveE=XbH?Zb8h~CzpOxeE&PYKl<0fTVEaZzXx86 z#Jlr-u-xjLZ~spxXZ~qu^M}xp{{wIz62Bim1P>x(NW}aIEFa&kAA`-4_c!6EXdk~} z`hJ4Q*-IRCeg$;4$r*r)V?c-d= z{2GySE^*B9Td;HNMPiQMf#t6+{&)R7SkCW=wf~6rv9`WHAad3gN1Z=|+iPzBKlB&u zeTe(3&$yN7zasj={u|gkcoF^e{~eJtF7BP2dlG%s(f42A^j!aq-MQCy!pX<~*mr@= vlaD%ggPlL_*F9k4ty8rU!hxj literal 6104 zcmZ9O2b5i96@_n@l2IZAfkXx33<&`YiJ%FfKrjJKB0*5lSQ%d?Lo#ME6EbfoHcANg z0@y2d#fH5r#$LeQTWnygh*H(H{J#6{CjWYO)>&tted=HCebcko{FzlXquRR~sQz1Z zp8cx5U@Ex&TCZAr*4iag?F~zoE<4(YeX5>1(3pLznN=UU7unJpo6xa0a$jUN`SSbK zSldn&**$cqs+|~pM3lLG_y+K8o4f#De_dPZJSqZo1EUd=2UYA>iY0&Cr4W2 zYbMS|wZ|qW^y=@a=AcI=$7?(f+`g#84~Z}A;DzAU$Otig)e+!{=`ClEjgL+NYwk$& znbT97PMMsTYQy&PC+D-i)gIY2JT-bAIJ|2co7g(t9^NpvrQlubK-){(rNe!?^oM>_t@6-G2Xq`Z!Pf< z;uHJOj#=B6fy-XovC~WJlU}jscj&s8=aXJyd%13VX_vij!Y+GxX6cplORv~vFV8f+ z#P)JO?WJAz`UiH|%QNrv>Z@}8dusRAXKy`Ytmj=Q>-`7Y^S?i`KV9uRM6~hZ1Ie-9 zAmTHKJrk|S&&9r1gU!)@Z}bdA{)EQ1huk7;^Sl$<#^0yG^6}3GE6MksmicqBedaLf z0`y=lz13l8XOpv*`f$W&9-WUw+ou#~co3ZP>bIx%qK0!1jsg!NnL>W1GJOY<{_h zW3cx{4qzuwL}zV}1p9qIn7l`!ol#DE6 zJh1PSwsoF{_B`dRqwPB+r+>7;=4w9!ZC~y5-P&;J>$lQ;`(B1l-|b*+>s*dDM$S6g zI~p$gab=f%752c4J^Qs2+q;$c8^OJe?fQ2i`kmt@bmq7jtnD0sLOX|?b7=2wIP-k3 zl6MPO+q{1^F>>Z<`wm&_C{Eim_Z{-Qlv&8Jd^`0!qmQ%Q3DbwR#!U18qK@rvZ)wl# zvWE)$Fz~{{Ht&co+uy;m-tsPcMPXa-iG^+6cwy_mxUlv6o26dCs(gRtGd>(K z*YC@N5x+@(ce2i-8$RU0Lx`=@336r 
zSAykYp8|IOvKOakv0Ly2+XM)X> zzk{BSL;JXPeUC-tT)Q}R&H_7MdYujKKI_NB$*0#7z~;$&KB=p|jknYNUx(P&eKy{n z&h=!pk8>IGBxDofT;j|%3@-QJ>Db+KwczA)ZtKD3$*0Z;*!AUXHh_)m@cs4g&$-y= zAnvbza~GgDBKkZ-zf;<-*EKk=bI0fVQu>TZoiVWaEBG4u`?eYJx72%Z8e$DO<1+V_ zn%`qjV0&k>j!CdF^5&1D<&u9MxV&TAu!jo&6nGAjbDsvgzilKn&Iilq+%Eu|C+}Ha zg!b_)^<9X_xi{j}c_!HT((74Zdo4y1^K7tudOZhhp1k{-y6W7q=Ys9)UK?*u=eiW_ z<6OpEg1i)QE^&HZ0WSC8dDz`^Js(a!ckBgV^W;F_I(K}zWAUj|=Vf5?bH`o|-imk+JbP=%8MlkxSE7}x5Pesp z=OXrg1=`+fzsIjazZ&t;{~EM@b@t%3U~A<3UI&(&jd%{)uSBe6t~j}`2bXvC4cOMo zTHXlueKAg7*8C>0zSMa$*u2=+fQOLOc?;M*G42{9b>0frw}amf&*W{0wdE7@cJM;v z(kA8|U^!#mKYeoP{Z4SXrgvdGTkgoa!E&>Z%=aF!>oiw?dS46Hm)h?Im$lx9EoZI7 zydPYy{R7yp-B^98{Xww4to=H${p1t#A#l0&4`a(2o3+cO_ea3x+CPe2&igTJIqzTA z{&BEtH&=gpe*&y8wLb|iYkdk^&RU82G`L**XRuwnvHDW`vtWH$`{%&+lTXa&!R6Y& zfGuZi)-IRc*MrNoe-XQ!_ez)O+5Q{M;6Ek>-bPtN>hX!DOlC;tcF0VKamKLjsAPDc{+ zBd~nFYd;2?C+}~~PtiVp%@ zfxiOV*L4|hPv`m#+Q+$!`86WvT;j~}Td;HNLNdqi!1C8N|4aWKEa&&d+J8j*SXU Date: Fri, 3 Apr 2026 16:28:26 -0400 Subject: [PATCH 08/17] pre 1.0 --- .github/workflows/ci.yml | 9 + .gitignore | 2 + CMakeLists.txt | 4 + backend/_bridge.py | 235 ++++++++++--------------- backend/attention.py | 133 +++++++------- backend/core.py | 81 ++++++++- benchmarks/profile_gpu_bottlenecks.py | 69 ++++++++ cpp/include/grilly/ops/attention_ops.h | 25 ++- cpp/python/bindings_attention.cpp | 46 +++++ cpp/src/ops/attention_ops.cpp | 121 +++++++++++++ nn/_perf_policy.py | 77 ++++++++ nn/attention.py | 90 ++++++++-- nn/linear.py | 23 ++- nn/module.py | 4 +- tests/parity/README.md | 25 +++ tests/parity/test_functional_parity.py | 97 ++++++++++ tests/test_attention.py | 16 ++ tests/test_bridge_strict_mode.py | 43 +++++ utils/tensor_conversion.py | 50 ++++-- uv.lock | 2 +- 20 files changed, 894 insertions(+), 258 deletions(-) create mode 100644 benchmarks/profile_gpu_bottlenecks.py create mode 100644 nn/_perf_policy.py create mode 100644 tests/parity/README.md create mode 100644 tests/parity/test_functional_parity.py create mode 100644 
tests/test_bridge_strict_mode.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b2a32f4..af94d2c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,6 +23,15 @@ jobs: - name: Lint run: ruff check . + - name: Guard deprecated Compute usage in runtime paths + run: | + hits="$(rg -n "from grilly import Compute|from \.\.backend\.compute import Compute|from \.compute import Compute|Compute\(" functional nn/module.py utils/tensor_conversion.py --glob "*.py" || true)" + if [ -n "$hits" ]; then + echo "Deprecated Compute usage detected in runtime paths:" + echo "$hits" + exit 1 + fi + test: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index 3268438..c4e0698 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,5 @@ docs/BRIDGE_MIGRATION_CHECKLIST.md docs/GPU_OPTIMIZATION_REVIEW.md docs/PYTORCH_PARITY_STATUS.md docs/PYTORCH_PARITY_TASKLIST.md +/build_verify_pybind +/build_verify_pybind2 diff --git a/CMakeLists.txt b/CMakeLists.txt index 92f53fd..3b038fe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,10 @@ else() endif() # ── pybind11 ───────────────────────────────────────────────────────────── +# Silence "Using compatibility mode for Python, set PYBIND11_FINDPYTHON to NEW/OLD" +if(NOT DEFINED PYBIND11_FINDPYTHON) + set(PYBIND11_FINDPYTHON NEW) +endif() set(PYBIND11_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11") if(EXISTS "${PYBIND11_DIR}/CMakeLists.txt") add_subdirectory("${PYBIND11_DIR}" pybind11) diff --git a/backend/_bridge.py b/backend/_bridge.py index 131ee49..24c69df 100644 --- a/backend/_bridge.py +++ b/backend/_bridge.py @@ -19,6 +19,8 @@ import numpy as np logger = logging.getLogger("grilly.bridge") +_BRIDGE_STRICT = os.getenv("GRILLY_BRIDGE_STRICT", "0").strip().lower() in {"1", "true", "yes", "on"} +_FALLBACK_COUNTS = {} def _maybe_trace(op_name, inputs, output, **kwargs): @@ -31,6 +33,25 @@ def _maybe_trace(op_name, inputs, output, **kwargs): except ImportError: pass 
+ +def _record_fallback(op_name: str, reason: Exception | str | None = None): + """Record fallback events and optionally raise in strict mode.""" + _FALLBACK_COUNTS[op_name] = _FALLBACK_COUNTS.get(op_name, 0) + 1 + if reason is not None: + logger.debug("%s fallback triggered: %s", op_name, reason) + if _BRIDGE_STRICT: + raise RuntimeError(f"GRILLY_BRIDGE_STRICT=1: bridge op '{op_name}' failed: {reason}") + + +def get_fallback_stats() -> dict[str, int]: + """Get fallback counts per bridge op.""" + return dict(_FALLBACK_COUNTS) + + +def reset_fallback_stats(): + """Reset bridge fallback counters.""" + _FALLBACK_COUNTS.clear() + try: import grilly_core as _core @@ -133,6 +154,7 @@ def linear(x, weight, bias=None): return result except Exception as e: logger.debug("linear GPU failed (%s), caller will use CPU fallback", e) + _record_fallback("linear", e) return None @@ -147,9 +169,7 @@ def relu(x): try: return _core.relu(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("relu", e) return None @@ -193,9 +213,7 @@ def gelu(x): try: return _core.gelu(dev, x) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("gelu", e) return None @@ -207,9 +225,7 @@ def silu(x): try: return _core.silu(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("silu", e) return None @@ -221,9 +237,7 @@ def tanh(x): try: return _core.tanh_act(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("tanh", e) return None @@ -240,9 +254,7 @@ def relu_backward(grad_output, input): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(input) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("relu_backward", e) return None @@ -256,9 +268,7 @@ def gelu_backward(grad_output, input): dev, _ensure_f32_contiguous(grad_output), 
_ensure_f32_contiguous(input) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("gelu_backward", e) return None @@ -272,9 +282,7 @@ def silu_backward(grad_output, input): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(input) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("silu_backward", e) return None @@ -312,9 +320,7 @@ def lif_step( t_refrac_period, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("lif_step", e) return None @@ -347,9 +353,7 @@ def snn_node_forward( decay_input, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("snn_node_forward", e) return None @@ -368,9 +372,7 @@ def snn_node_backward(grad_spike, h_cache, alpha=2.0, surrogate_type=0, v_thresh v_threshold, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("snn_node_backward", e) return None @@ -393,9 +395,7 @@ def hebbian_learning( weight_decay, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("hebbian_learning", e) return None @@ -430,9 +430,7 @@ def stdp_learning( trace_decay, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("stdp_learning", e) return None @@ -465,9 +463,7 @@ def oja_learning(memories, inputs, num_vectors, dim, eta=0.01): eta, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("oja_learning", e) return None @@ -481,9 +477,7 @@ def synapse_filter(x_in, y_state, decay=0.95): dev, _ensure_f32_contiguous(x_in), _ensure_f32_contiguous(y_state), decay ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("synapse_filter", e) return None @@ -538,9 +532,7 @@ def gif_neuron_step( t_refrac_period, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("gif_neuron_step", e) return None @@ -560,9 +552,7 @@ def 
conv2d(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), dev, x, weight, bias, list(stride), list(padding), list(dilation), groups ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d", e) return None @@ -582,9 +572,7 @@ def conv2d_3x3_gelu(x, weight, bias): _ensure_f32_contiguous(bias), ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d_3x3_gelu", e) return None @@ -596,9 +584,7 @@ def maxpool2x2(x): try: return _core.maxpool2x2(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("maxpool2x2", e) return None @@ -610,9 +596,7 @@ def adaptive_avgpool_3x3(x): try: return _core.adaptive_avgpool_3x3(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("adaptive_avgpool_3x3", e) return None @@ -630,9 +614,7 @@ def layernorm(x, gamma, beta, eps=1e-5): beta = _ensure_f32_contiguous(beta) return _core.layernorm(dev, x, gamma, beta, eps) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("layernorm", e) return None @@ -652,9 +634,7 @@ def layernorm_backward(grad_output, input, gamma, mean, var, eps=1e-5): eps, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("layernorm_backward", e) return None @@ -671,9 +651,7 @@ def rmsnorm(x, weight, eps=1e-5): weight = _ensure_f32_contiguous(weight) return _core.rmsnorm(dev, x, weight, eps) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("rmsnorm", e) return None @@ -690,9 +668,7 @@ def tanh_backward(grad_output, tanh_output): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(tanh_output) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("tanh_backward", e) return None @@ -704,9 +680,7 @@ def softmax(x, dim=-1): try: return _core.softmax(dev, 
_ensure_f32_contiguous(x), dim) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("softmax", e) return None @@ -720,9 +694,7 @@ def softmax_backward(grad_output, softmax_output): dev, _ensure_f32_contiguous(grad_output), _ensure_f32_contiguous(softmax_output) ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("softmax_backward", e) return None @@ -742,9 +714,7 @@ def linear_backward(grad_output, input, weights): _ensure_f32_contiguous(weights), ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("linear_backward", e) return None @@ -758,9 +728,7 @@ def dropout(x, random_mask, p=0.5, training=True): dev, _ensure_f32_contiguous(x), _ensure_f32_contiguous(random_mask), p, training ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("dropout", e) return None @@ -786,9 +754,7 @@ def conv2d_backward_input( groups, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d_backward_input", e) return None @@ -819,9 +785,7 @@ def conv2d_backward_weight( has_bias, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("conv2d_backward_weight", e) return None @@ -838,9 +802,7 @@ def attention_scores(Q, K, scale=0.0): dev, _ensure_f32_contiguous(Q), _ensure_f32_contiguous(K), scale ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("attention_scores", e) return None @@ -853,9 +815,7 @@ def attention_mask(scores, mask=None, causal=True, mask_value=-1e9): m = _ensure_f32_contiguous(mask) if mask is not None else None return _core.attention_mask(dev, _ensure_f32_contiguous(scores), m, causal, mask_value) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("attention_mask", e) return None @@ -869,9 +829,32 @@ def attention_output(weights, V): dev, _ensure_f32_contiguous(weights), _ensure_f32_contiguous(V) ) 
except Exception as e: + _record_fallback("attention_output", e) + return None + - logger.debug("GPU op failed: %s", e) +def attention_scores_softmax_output(Q, K, V, scale=0.0): + """Fused attention: scores + softmax + weighted V in one GPU submit. + Requires Q, K, V shaped (B, H, S, D) with identical S (self-attention / equal length). + Returns ``(output, softmax_weights)`` as core tensors, or None on failure. + """ + dev = _get_device() + if dev is None: + return None + fused = getattr(_core, "attention_scores_softmax_output", None) + if fused is None: + return None + try: + return fused( + dev, + _ensure_f32_contiguous(Q), + _ensure_f32_contiguous(K), + _ensure_f32_contiguous(V), + scale, + ) + except Exception as e: + _record_fallback("attention_scores_softmax_output", e) return None @@ -883,9 +866,7 @@ def attention_concat_heads(mh_output): try: return _core.attention_concat_heads(dev, _ensure_f32_contiguous(mh_output)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("attention_concat_heads", e) return None @@ -899,9 +880,7 @@ def rope(x, cos_table=None, sin_table=None, base=10000.0, scaling=1.0): st = _ensure_f32_contiguous(sin_table) if sin_table is not None else None return _core.rope(dev, _ensure_f32_contiguous(x), ct, st, base, scaling) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("rope", e) return None @@ -918,9 +897,7 @@ def flash_attention2(Q, K, V, mask=None, scale=0.0, tile_size_q=64, tile_size_k= m = _ensure_f32_contiguous(mask) if mask is not None else None return _core.flash_attention2(dev, Q, K, V, m, scale, tile_size_q, tile_size_k) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("flash_attention2", e) return None @@ -1036,9 +1013,7 @@ def maxpool2d(x, kernel_size, stride=(2, 2), padding=(0, 0), dilation=(1, 1)): list(dilation), ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("maxpool2d", e) 
return None @@ -1057,9 +1032,7 @@ def avgpool2d(x, kernel_size, stride=(2, 2), padding=(0, 0), count_include_pad=T count_include_pad, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("avgpool2d", e) return None @@ -1071,9 +1044,7 @@ def mean_pool(x): try: return _core.mean_pool(dev, _ensure_f32_contiguous(x)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("mean_pool", e) return None @@ -1100,9 +1071,7 @@ def batchnorm2d_forward( training, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("batchnorm2d_forward", e) return None @@ -1122,9 +1091,7 @@ def cross_entropy_loss(logits, targets, label_smoothing=0.0): dev, _ensure_f32_contiguous(logits), targets, label_smoothing ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("cross_entropy_loss", e) return None @@ -1139,9 +1106,7 @@ def cross_entropy_backward(logits, targets): targets = targets.astype(np.uint32) return _core.cross_entropy_backward(dev, _ensure_f32_contiguous(logits), targets) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("cross_entropy_backward", e) return None @@ -1181,9 +1146,7 @@ def adam_update( clear_grad, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("adam_update", e) return None @@ -1222,9 +1185,7 @@ def adamw_update( clear_grad, ) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("adamw_update", e) return None @@ -1242,9 +1203,7 @@ def embedding_lookup(token_ids, embeddings): token_ids = token_ids.astype(np.uint32) return _core.embedding_lookup(dev, token_ids, _ensure_f32_contiguous(embeddings)) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("embedding_lookup", e) return None @@ -1288,9 +1247,7 @@ def create_kv_cache( eviction_threshold, ) except Exception as e: - - logger.debug("GPU op failed: %s", 
e) - + _record_fallback("create_kv_cache", e) return None @@ -1305,9 +1262,7 @@ def kv_cache_append(kv_cache, new_keys, new_values): _core.kv_cache_append(dev, kv_cache, new_keys, new_values) return True except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("kv_cache_append", e) return None @@ -1319,9 +1274,7 @@ def kv_cache_decode(kv_cache): try: return _core.kv_cache_decode(dev, kv_cache) except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("kv_cache_decode", e) return None @@ -1335,9 +1288,7 @@ def kv_cache_evict_h2o(kv_cache, attention_scores=None, num_evict=0): _core.kv_cache_evict_h2o(dev, kv_cache, scores, num_evict) return True except Exception as e: - - logger.debug("GPU op failed: %s", e) - + _record_fallback("kv_cache_evict_h2o", e) return None diff --git a/backend/attention.py b/backend/attention.py index a9a7952..7f4cff3 100644 --- a/backend/attention.py +++ b/backend/attention.py @@ -499,7 +499,7 @@ def flash_attention2( ) push_constants_init = struct.pack( - "IIIIfIIIII", + "IIIIfIIIIII", batch_size, seq_len, num_heads, @@ -513,63 +513,12 @@ def flash_attention2( 0, # q_tile_idx, k_tile_idx (not used in init) ) - # Dispatch initialization + # Single submit for init + all tile passes + finalize (one fence wait). 
workgroups_init_x = 16 workgroups_init_y = (num_q_positions + 15) // 16 - self.core._dispatch_compute( - pipeline, - pipeline_layout, - descriptor_set_init, - workgroups_init_x, - push_constants_init, - workgroups_init_y, - ) - - # Pass 1: Process all tiles - for q_tile in range(num_tiles_q): - for k_tile in range(num_tiles_k): - descriptor_set_tile = self.pipelines.get_cached_descriptor_set( - "flash-attention2", - [ - (self._get_buffer_handle(buf_q), q_flat.nbytes), - (self._get_buffer_handle(buf_k), k_flat.nbytes), - (self._get_buffer_handle(buf_v), v_flat.nbytes), - (mask_handle, mask_size), - (self._get_buffer_handle(buf_output), output_accum_size), - (self._get_buffer_handle(buf_running_max), running_max_size), - (self._get_buffer_handle(buf_running_sum), running_sum_size), - (self._get_buffer_handle(buf_output_accum), output_accum_size), - ], - ) - - push_constants_tile = struct.pack( - "IIIIfIIIII", - batch_size, - seq_len, - num_heads, - head_dim, - scale, - tile_size_q, - tile_size_k, - 1, # pass_type = 1 (process tile) - 1 if mask is not None else 0, # has_mask - q_tile, - k_tile, - ) - - # Dispatch tile processing - workgroups_tile_x = (tile_size_k + 15) // 16 - workgroups_tile_y = (batch_size * num_heads * tile_size_q + 15) // 16 - self.core._dispatch_compute( - pipeline, - pipeline_layout, - descriptor_set_tile, - workgroups_tile_x, - push_constants_tile, - workgroups_tile_y, - ) - - # Pass 2: Finalize output + workgroups_final_x = (head_dim + 15) // 16 + workgroups_final_y = (num_q_positions + 15) // 16 + descriptor_set_final = self.pipelines.get_cached_descriptor_set( "flash-attention2", [ @@ -585,7 +534,7 @@ def flash_attention2( ) push_constants_final = struct.pack( - "IIIIfIIIII", + "IIIIfIIIIII", batch_size, seq_len, num_heads, @@ -599,17 +548,65 @@ def flash_attention2( 0, # q_tile_idx, k_tile_idx (not used in finalize) ) - # Dispatch finalization - workgroups_final_x = (head_dim + 15) // 16 - workgroups_final_y = (num_q_positions + 15) // 16 - 
self.core._dispatch_compute( - pipeline, - pipeline_layout, - descriptor_set_final, - workgroups_final_x, - push_constants_final, - workgroups_final_y, - ) + with self.core.record_commands() as rec: + rec.dispatch( + pipeline, + pipeline_layout, + descriptor_set_init, + (workgroups_init_x, workgroups_init_y, 1), + push_constants_init, + ) + rec.barrier() + + for q_tile in range(num_tiles_q): + for k_tile in range(num_tiles_k): + descriptor_set_tile = self.pipelines.get_cached_descriptor_set( + "flash-attention2", + [ + (self._get_buffer_handle(buf_q), q_flat.nbytes), + (self._get_buffer_handle(buf_k), k_flat.nbytes), + (self._get_buffer_handle(buf_v), v_flat.nbytes), + (mask_handle, mask_size), + (self._get_buffer_handle(buf_output), output_accum_size), + (self._get_buffer_handle(buf_running_max), running_max_size), + (self._get_buffer_handle(buf_running_sum), running_sum_size), + (self._get_buffer_handle(buf_output_accum), output_accum_size), + ], + ) + + push_constants_tile = struct.pack( + "IIIIfIIIIII", + batch_size, + seq_len, + num_heads, + head_dim, + scale, + tile_size_q, + tile_size_k, + 1, # pass_type = 1 (process tile) + 1 if mask is not None else 0, # has_mask + q_tile, + k_tile, + ) + + workgroups_tile_x = (tile_size_k + 15) // 16 + workgroups_tile_y = (batch_size * num_heads * tile_size_q + 15) // 16 + rec.dispatch( + pipeline, + pipeline_layout, + descriptor_set_tile, + (workgroups_tile_x, workgroups_tile_y, 1), + push_constants_tile, + ) + rec.barrier() + + rec.dispatch( + pipeline, + pipeline_layout, + descriptor_set_final, + (workgroups_final_x, workgroups_final_y, 1), + push_constants_final, + ) # Download results result = self._download_buffer(buf_output, output_accum_size, np.float32) diff --git a/backend/core.py b/backend/core.py index 3421b9d..97e0ea9 100644 --- a/backend/core.py +++ b/backend/core.py @@ -681,8 +681,15 @@ def _dispatch_compute( push_constants: bytes = None, workgroup_y: int = 1, workgroup_z: int = 1, + *, + wait_previous: 
bool = True, ): - """Dispatch compute shader using pre-allocated command buffer.""" + """Dispatch compute shader using pre-allocated command buffer. + + Args: + wait_previous: If False, skip waiting for previous dispatch. + Use only when you know the queue is idle (e.g., first dispatch). + """ command_buffer = self._cmd_buffer # Reset and re-record the reusable command buffer @@ -726,7 +733,8 @@ def _dispatch_compute( vkEndCommandBuffer(command_buffer) # Wait for previous dispatch to finish, then reset the fence. - vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, 2_000_000_000) + if wait_previous: + vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, 2_000_000_000) vkResetFences(self.device, 1, [self._fence]) # Submit with fence — avoids the heavier vkQueueWaitIdle drain. @@ -746,6 +754,75 @@ def _wait_fence(self, timeout_ns: int = 2_000_000_000): """Wait for the most recent dispatch to complete.""" vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, timeout_ns) + def _dispatch_compute_async( + self, + pipeline, + pipeline_layout, + descriptor_set, + workgroup_x: int, + push_constants: bytes = None, + workgroup_y: int = 1, + workgroup_z: int = 1, + ): + """Async dispatch: record command and return without waiting. + + Returns a handle that can be waited on later via _wait_async(). + Used for batching multiple dispatches before a single fence wait. 
+ """ + command_buffer = self._cmd_buffer + + # Reset and record the reusable command buffer + vkResetCommandBuffer(command_buffer, 0) + + begin_info = VkCommandBufferBeginInfo( + sType=VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, + flags=VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, + ) + vkBeginCommandBuffer(command_buffer, begin_info) + + vkCmdBindPipeline(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline) + vkCmdBindDescriptorSets( + command_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline_layout, + 0, + 1, + [descriptor_set], + 0, + None, + ) + + if push_constants: + push_buf = ctypes.create_string_buffer(push_constants) + vkCmdPushConstants( + command_buffer, + pipeline_layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + len(push_constants), + ctypes.addressof(push_buf), + ) + + vkCmdDispatch(command_buffer, workgroup_x, workgroup_y, workgroup_z) + vkEndCommandBuffer(command_buffer) + + # Reset fence and submit (do not wait) + vkResetFences(self.device, 1, [self._fence]) + + submit_info = VkSubmitInfo( + sType=VK_STRUCTURE_TYPE_SUBMIT_INFO, + commandBufferCount=1, + pCommandBuffers=[command_buffer], + ) + vkQueueSubmit(self.queue, 1, [submit_info], self._fence) + + # Return handle for optional wait + return self._fence + + def _wait_async(self, timeout_ns: int = 2_000_000_000): + """Wait for async dispatch to complete.""" + vkWaitForFences(self.device, 1, [self._fence], VK_TRUE, timeout_ns) + def cleanup(self): """Cleanup Vulkan resources""" if hasattr(self, "device") and self.device: diff --git a/benchmarks/profile_gpu_bottlenecks.py b/benchmarks/profile_gpu_bottlenecks.py new file mode 100644 index 0000000..de39990 --- /dev/null +++ b/benchmarks/profile_gpu_bottlenecks.py @@ -0,0 +1,69 @@ +""" +Profile GPU-vs-CPU bottlenecks for parity-critical ops. 
+ +Usage: + python benchmarks/profile_gpu_bottlenecks.py +""" + +from __future__ import annotations + +import cProfile +import io +import pstats +import time + +import numpy as np + +from grilly.nn.attention import MultiheadAttention +from grilly.nn.linear import Linear + + +def _print_stats(pr: cProfile.Profile, title: str, top_n: int = 25): + s = io.StringIO() + pstats.Stats(pr, stream=s).sort_stats("cumtime").print_stats(top_n) + print(f"\n=== {title} ===") + print(s.getvalue()) + + +def profile_linear(): + x = np.random.randn(64, 512).astype(np.float32) + linear = Linear(512, 512) + + pr = cProfile.Profile() + pr.enable() + for _ in range(20): + _ = linear(x) + pr.disable() + _print_stats(pr, "Linear forward profile") + + w = np.asarray(linear.weight, dtype=np.float32) + b = np.asarray(linear.bias, dtype=np.float32) + t0 = time.perf_counter() + for _ in range(20): + _ = x @ w.T + b + t1 = time.perf_counter() + print(f"CPU linear baseline (20 iters): {(t1 - t0) * 1000.0:.3f} ms") + + +def profile_attention(): + attn = MultiheadAttention(embed_dim=512, num_heads=8) + q = np.random.randn(2, 64, 512).astype(np.float32) + k = np.random.randn(2, 64, 512).astype(np.float32) + v = np.random.randn(2, 64, 512).astype(np.float32) + + pr = cProfile.Profile() + pr.enable() + for _ in range(10): + _ = attn(q, k, v) + pr.disable() + _print_stats(pr, "MultiheadAttention forward profile") + + +def main(): + print("Profiling parity-critical GPU paths...") + profile_linear() + profile_attention() + + +if __name__ == "__main__": + main() diff --git a/cpp/include/grilly/ops/attention_ops.h b/cpp/include/grilly/ops/attention_ops.h index bd88223..e329995 100644 --- a/cpp/include/grilly/ops/attention_ops.h +++ b/cpp/include/grilly/ops/attention_ops.h @@ -32,10 +32,28 @@ struct AttentionScoresParams { uint32_t passType; // 0 = compute scores }; +// Used by attentionOutput and attentionScoresSoftmaxOutput (declare before fused API). 
+struct AttentionOutputParams { + uint32_t batchSize; + uint32_t seqLen; + uint32_t numHeads; + uint32_t headDim; +}; + void attentionScores(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const float* Q, const float* K, float* scores, const AttentionScoresParams& p); +/// Fused: attention scores → softmax (last dim) → attention output (weights @ V). +/// Single submit/wait; no host round-trip between scores and softmax. +/// Q, K, V: (B, H, S, D). Writes output (B, H, S, D) and softmaxWeights (B, H, S, S). +void attentionScoresSoftmaxOutput(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, const float* Q, + const float* K, const float* V, float* output, + float* softmaxWeights, + const AttentionScoresParams& scoreParams, + const AttentionOutputParams& outParams); + // ── Attention mask ─────────────────────────────────────────────────────── // Shader: attention-mask.spv // local_size: (256, 1, 1) @@ -58,13 +76,6 @@ void attentionMask(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, // local_size: (256, 1, 1) // Buffers: weights(0) readonly, V(1) readonly, output(2) write -struct AttentionOutputParams { - uint32_t batchSize; - uint32_t seqLen; - uint32_t numHeads; - uint32_t headDim; -}; - void attentionOutput(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const float* weights, const float* V, float* output, const AttentionOutputParams& p); diff --git a/cpp/python/bindings_attention.cpp b/cpp/python/bindings_attention.cpp index 8457bcf..55766d4 100644 --- a/cpp/python/bindings_attention.cpp +++ b/cpp/python/bindings_attention.cpp @@ -91,6 +91,52 @@ void register_attention_ops(py::module_& m) { py::arg("scale") = 0.0f, "GPU attention scores: Q @ K^T / sqrt(d_h)"); + // ── Fused: scores + softmax + output (single GPU submit) ──────────── + m.def( + "attention_scores_softmax_output", + [](GrillyCoreContext& ctx, + py::array_t Q, py::array_t K, py::array_t V, + float scale) -> py::tuple { + auto qBuf = Q.request(); + 
auto kBuf = K.request(); + auto vBuf = V.request(); + if (qBuf.ndim != 4 || kBuf.ndim != 4 || vBuf.ndim != 4) + throw std::runtime_error("Q, K, V must be 4D (B, H, S, D)"); + + uint32_t B = static_cast(qBuf.shape[0]); + uint32_t H = static_cast(qBuf.shape[1]); + uint32_t S = static_cast(qBuf.shape[2]); + uint32_t D = static_cast(qBuf.shape[3]); + + if (kBuf.shape[0] != B || kBuf.shape[1] != H || kBuf.shape[2] != S || + kBuf.shape[3] != D || vBuf.shape[0] != B || vBuf.shape[1] != H || + vBuf.shape[2] != S || vBuf.shape[3] != D) + throw std::runtime_error("Q, K, V shapes must match"); + + if (scale == 0.0f) scale = 1.0f / std::sqrt(float(D)); + + py::array_t outArr({ + static_cast(B), static_cast(H), + static_cast(S), static_cast(D)}); + py::array_t wArr({ + static_cast(B), static_cast(H), + static_cast(S), static_cast(S)}); + + grilly::ops::AttentionScoresParams sp{B, S, H, D, scale, 0}; + grilly::ops::AttentionOutputParams outp{B, S, H, D}; + grilly::ops::attentionScoresSoftmaxOutput( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(kBuf.ptr), + static_cast(vBuf.ptr), + static_cast(outArr.request().ptr), + static_cast(wArr.request().ptr), sp, outp); + return py::make_tuple(Tensor::from_numpy(outArr), Tensor::from_numpy(wArr)); + }, + py::arg("device"), py::arg("Q"), py::arg("K"), py::arg("V"), + py::arg("scale") = 0.0f, + "GPU fused attention: scores + softmax + output (one submit); returns (output, softmax_weights)"); + // ── Attention mask (causal or custom) ──────────────────────────────── m.def( "attention_mask", diff --git a/cpp/src/ops/attention_ops.cpp b/cpp/src/ops/attention_ops.cpp index b7755b6..a40e751 100644 --- a/cpp/src/ops/attention_ops.cpp +++ b/cpp/src/ops/attention_ops.cpp @@ -1,6 +1,9 @@ #include "grilly/ops/attention_ops.h" +#include "grilly/ops/activations.h" + #include +#include namespace grilly { namespace ops { @@ -70,6 +73,124 @@ void attentionScores(CommandBatch& batch, BufferPool& pool, PipelineCache& cache 
pool.release(bufScores); } +void attentionScoresSoftmaxOutput(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, const float* Q, + const float* K, const float* V, float* output, + float* softmaxWeights, const AttentionScoresParams& sp, + const AttentionOutputParams& op) { + const uint32_t B = sp.batchSize; + const uint32_t S = sp.seqLen; + const uint32_t H = sp.numHeads; + const uint32_t D = sp.headDim; + + if (op.batchSize != B || op.seqLen != S || op.numHeads != H || op.headDim != D) { + throw std::runtime_error( + "attentionScoresSoftmaxOutput: score and output params must match"); + } + + const size_t qkvBytes = size_t(B) * H * S * D * sizeof(float); + const size_t scoreBytes = size_t(B) * H * S * S * sizeof(float); + const size_t outBytes = qkvBytes; + + GrillyBuffer bufQ = pool.acquire(qkvBytes); + GrillyBuffer bufK = pool.acquire(qkvBytes); + GrillyBuffer bufVDummy = pool.acquire(sizeof(float)); + GrillyBuffer bufScores = pool.acquire(scoreBytes); + GrillyBuffer bufWeights = pool.acquire(scoreBytes); + GrillyBuffer bufV = pool.acquire(qkvBytes); + GrillyBuffer bufOut = pool.acquire(outBytes); + + const uint32_t totalSoftmaxRows = B * H * S; + const size_t auxBytes = size_t(totalSoftmaxRows) * sizeof(float); + + GrillyBuffer bufMax = pool.acquire(auxBytes); + GrillyBuffer bufSumExp = pool.acquire(auxBytes); + + pool.upload(bufQ, Q, qkvBytes); + pool.upload(bufK, K, qkvBytes); + pool.upload(bufV, V, qkvBytes); + + PipelineEntry pipeScores = cache.getOrCreate("attention-scores", 4, + sizeof(AttentionScoresParams)); + std::vector scoresInfos = { + {bufQ.handle, 0, qkvBytes}, + {bufK.handle, 0, qkvBytes}, + {bufVDummy.handle, 0, sizeof(float)}, + {bufScores.handle, 0, scoreBytes}, + }; + VkDescriptorSet descScores = cache.allocDescriptorSet("attention-scores", scoresInfos); + + PipelineEntry pipeSoftmax = cache.getOrCreate("activation-softmax", 4, + sizeof(SoftmaxParams)); + std::vector softmaxInfos = { + {bufScores.handle, 0, scoreBytes}, + 
{bufWeights.handle, 0, scoreBytes}, + {bufMax.handle, 0, auxBytes}, + {bufSumExp.handle, 0, auxBytes}, + }; + VkDescriptorSet descSoftmax = cache.allocDescriptorSet("activation-softmax", softmaxInfos); + + PipelineEntry pipeOut = cache.getOrCreate("attention-output", 3, + sizeof(AttentionOutputParams)); + std::vector outInfos = { + {bufWeights.handle, 0, scoreBytes}, + {bufV.handle, 0, qkvBytes}, + {bufOut.handle, 0, outBytes}, + }; + VkDescriptorSet descOut = cache.allocDescriptorSet("attention-output", outInfos); + + const uint32_t gxS = (S + 15) / 16; + const uint32_t gyS = (S + 15) / 16; + const uint32_t gzS = B * H; + + const uint32_t softmaxGx = (totalSoftmaxRows + 255) / 256; + const uint32_t totalElements = totalSoftmaxRows * S; + const uint32_t softmaxGx2 = (totalElements + 255) / 256; + + const uint32_t outTotal = B * H * S * D; + const uint32_t gxOut = (outTotal + 255) / 256; + + batch.begin(); + + batch.dispatch(pipeScores.pipeline, pipeScores.layout, descScores, gxS, gyS, gzS, + &sp, sizeof(sp)); + batch.barrier(); + + SoftmaxParams push0{1, totalSoftmaxRows, S, 0, S}; + batch.dispatch(pipeSoftmax.pipeline, pipeSoftmax.layout, descSoftmax, softmaxGx, 1, 1, + &push0, sizeof(push0)); + batch.barrier(); + + SoftmaxParams push1{1, totalSoftmaxRows, S, 1, S}; + batch.dispatch(pipeSoftmax.pipeline, pipeSoftmax.layout, descSoftmax, softmaxGx, 1, 1, + &push1, sizeof(push1)); + batch.barrier(); + + SoftmaxParams push2{1, totalSoftmaxRows, S, 2, S}; + batch.dispatch(pipeSoftmax.pipeline, pipeSoftmax.layout, descSoftmax, softmaxGx2, 1, 1, + &push2, sizeof(push2)); + batch.barrier(); + + batch.dispatch(pipeOut.pipeline, pipeOut.layout, descOut, gxOut, 1, 1, &op, + sizeof(op)); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOut, output, outBytes); + pool.download(bufWeights, softmaxWeights, scoreBytes); + + pool.release(bufQ); + pool.release(bufK); + pool.release(bufVDummy); + pool.release(bufScores); + pool.release(bufWeights); + 
pool.release(bufV); + pool.release(bufOut); + pool.release(bufMax); + pool.release(bufSumExp); +} + // ── Attention mask ─────────────────────────────────────────────────────── void attentionMask(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, diff --git a/nn/_perf_policy.py b/nn/_perf_policy.py new file mode 100644 index 0000000..f755455 --- /dev/null +++ b/nn/_perf_policy.py @@ -0,0 +1,77 @@ +"""Adaptive runtime policy to avoid slower-than-CPU execution.""" + +from __future__ import annotations + +import os +import time + +_AUTO_FASTEST = os.getenv("GRILLY_AUTO_FASTEST", "1").strip().lower() in { + "1", + "true", + "yes", + "on", +} +_DECISIONS: dict[str, str] = {} + + +def _time_ms(fn, warmup: int = 1, repeats: int = 2): + for _ in range(max(0, warmup)): + fn() + t = [] + out = None + for _ in range(max(1, repeats)): + t0 = time.perf_counter() + out = fn() + t1 = time.perf_counter() + t.append((t1 - t0) * 1000.0) + return (sum(t) / len(t)), out + + +def choose_fastest(op_key: str, gpu_fn, cpu_fn): + """Run fastest backend for this op/shape key and cache choice.""" + if not _AUTO_FASTEST: + return gpu_fn() + + decision = _DECISIONS.get(op_key) + if decision == "gpu": + try: + out = gpu_fn() + if out is not None: + return out + except Exception: + pass + _DECISIONS[op_key] = "cpu" + return cpu_fn() + if decision == "cpu": + return cpu_fn() + + # First encounter for this shape/op: benchmark both paths. + gpu_ms = float("inf") + gpu_out = None + try: + gpu_ms, gpu_out = _time_ms(gpu_fn, warmup=1, repeats=2) + except Exception: + gpu_ms = float("inf") + gpu_out = None + + try: + cpu_ms, cpu_out = _time_ms(cpu_fn, warmup=1, repeats=2) + except Exception: + # If CPU path fails, force GPU. 
+ _DECISIONS[op_key] = "gpu" + return gpu_out + + if gpu_out is not None and gpu_ms <= cpu_ms * 0.98: + _DECISIONS[op_key] = "gpu" + return gpu_out + + _DECISIONS[op_key] = "cpu" + return cpu_out + + +def get_perf_decisions() -> dict[str, str]: + return dict(_DECISIONS) + + +def reset_perf_decisions(): + _DECISIONS.clear() diff --git a/nn/attention.py b/nn/attention.py index 07d5a18..9646897 100644 --- a/nn/attention.py +++ b/nn/attention.py @@ -5,7 +5,12 @@ import numpy as np -from ._helpers import _get_param_array +from ._helpers import ( + _USE_CPP_BRIDGE, + _bridge, + _bridge_to_numpy, + _get_param_array, +) from .module import Module @@ -103,51 +108,102 @@ def forward( k_reshaped = k_4d.transpose(0, 2, 1, 3) # (batch, num_heads, seq_len_k, head_dim) v_reshaped = v_4d.transpose(0, 2, 1, 3) # (batch, num_heads, seq_len_k, head_dim) - # Compute attention scores + # Fused C++ path: one submit for scores + softmax + output (no mask; Sq == Sk only). + if _USE_CPP_BRIDGE and seq_len_q == seq_len_k and mask is None: + fused = _bridge.attention_scores_softmax_output( + q_reshaped, k_reshaped, v_reshaped + ) + if fused is not None: + br_out, br_w = fused + attn_output = _bridge_to_numpy(br_out) + scores_softmax = _bridge_to_numpy(br_w) + if attn_output is not None and scores_softmax is not None: + attn_output = attn_output.astype(np.float32) + scores_softmax = scores_softmax.astype(np.float32) + self._cached_scores_pre_softmax = None + self._cached_scores = scores_softmax.copy() + self._cached_attn_output = attn_output.copy() + attn_output_reshaped = attn_output.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len_q, self.embed_dim + ) + attn_weights = scores_softmax + output = self.out_proj(attn_output_reshaped) + return output, attn_weights + + # C++ bridge: scores → optional padding mask (CPU) → softmax → attention output (GPU). + # Kernel uses one sequence length for Q and K; skip when cross-attention has Sq != Sk. 
+ if _USE_CPP_BRIDGE and seq_len_q == seq_len_k: + br_scores = _bridge.attention_scores(q_reshaped, k_reshaped) + if br_scores is not None: + scores = _bridge_to_numpy(br_scores) + if scores.shape != (batch_size, self.num_heads, seq_len_q, seq_len_k): + if scores.size == batch_size * self.num_heads * seq_len_q * seq_len_k: + scores = scores.reshape( + batch_size, self.num_heads, seq_len_q, seq_len_k + ) + else: + scores = None + if scores is not None: + self._cached_scores_pre_softmax = scores.copy() + if mask is not None: + if mask.ndim == 2: + mask_expanded = mask[:, None, :, None] + mask_expanded = np.broadcast_to(mask_expanded, scores.shape) + scores = np.where(mask_expanded > 0, scores, -1e9) + else: + scores = scores + mask.astype(np.float32) + br_soft = _bridge.softmax(scores, dim=-1) + if br_soft is not None: + scores_softmax = _bridge_to_numpy(br_soft).astype(np.float32) + self._cached_scores = scores_softmax.copy() + br_out = _bridge.attention_output(scores_softmax, v_reshaped) + if br_out is not None: + attn_output = _bridge_to_numpy(br_out).astype(np.float32) + self._cached_attn_output = attn_output.copy() + attn_output_reshaped = attn_output.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len_q, self.embed_dim + ) + attn_weights = scores_softmax + output = self.out_proj(attn_output_reshaped) + return output, attn_weights + + # Fallback: legacy Python Vulkan backend + CPU softmax / einsum scores = backend.attention.attention_scores( q_reshaped, k_reshaped, num_heads=self.num_heads, head_dim=self.head_dim ) # Backend may return scores in different shape - normalize to (batch, num_heads, seq_len_q, seq_len_k) if scores.shape == (batch_size, seq_len_q, self.num_heads, seq_len_k): - # Backend returned (batch, seq_len_q, num_heads, seq_len_k) - transpose to (batch, num_heads, seq_len_q, seq_len_k) scores = scores.transpose(0, 2, 1, 3) elif scores.shape != (batch_size, self.num_heads, seq_len_q, seq_len_k): - # Unexpected shape - try to infer if ( 
scores.ndim == 4 and scores.size == batch_size * self.num_heads * seq_len_q * seq_len_k ): scores = scores.reshape(batch_size, self.num_heads, seq_len_q, seq_len_k) else: - # Fallback: compute manually scores = np.einsum("bhqd,bhkd->bhqk", q_reshaped, k_reshaped) / np.sqrt( - self.head_dim + float(self.head_dim) ) - # Cache pre-softmax scores for backward self._cached_scores_pre_softmax = scores.copy() - # Apply mask if provided if mask is not None: - scores = backend.attention.attention_mask(scores, mask) + if mask.ndim == 2: + mask_expanded = mask[:, None, :, None] + mask_expanded = np.broadcast_to(mask_expanded, scores.shape) + scores = np.where(mask_expanded > 0, scores, -1e9) + else: + scores = scores + mask.astype(np.float32) - # Apply softmax (CPU for now - backend softmax expects 3D) - # scores is (batch, num_heads, seq_len_q, seq_len_k) scores_max = scores.max(axis=-1, keepdims=True) scores_exp = np.exp(scores - scores_max) scores_softmax = scores_exp / scores_exp.sum(axis=-1, keepdims=True) - # Cache softmax scores for backward self._cached_scores = scores_softmax.copy() - # Compute attention output - # scores_softmax: (batch, num_heads, seq_len_q, seq_len_k) - # v_reshaped: (batch, num_heads, seq_len_k, head_dim) - # Output: (batch, num_heads, seq_len_q, head_dim) attn_output = np.einsum("bhqk,bhkd->bhqd", scores_softmax, v_reshaped) - # Cache attention output for backward (in shape: batch, num_heads, seq_len_q, head_dim) self._cached_attn_output = attn_output.copy() # Reshape back: (batch, num_heads, seq_len_q, head_dim) -> (batch, seq_len_q, embed_dim) diff --git a/nn/linear.py b/nn/linear.py index b62fb25..bfcc4ec 100644 --- a/nn/linear.py +++ b/nn/linear.py @@ -14,6 +14,7 @@ _get_param_array, ) from .module import Module +from ._perf_policy import choose_fastest class Linear(Module): @@ -192,9 +193,27 @@ def forward(self, x) -> np.ndarray: weight = _get_param_array(self.weight) bias = _get_param_array(self.bias) if self.bias is not None else None + 
def cpu_linear(): + x_arr = np.asarray(x, dtype=np.float32) + w_arr = np.asarray(weight, dtype=np.float32) + out = x_arr @ w_arr.T + if bias is not None: + out = out + np.asarray(bias, dtype=np.float32) + return np.asarray(out, dtype=np.float32) + # C++ bridge fast path (handles both numpy and VulkanTensor via __array__) if _USE_CPP_BRIDGE: - result = _bridge_to_numpy(_bridge.linear(x, weight, bias)) + def gpu_linear(): + return _bridge_to_numpy(_bridge.linear(x, weight, bias)) + + # Auto-fastest policy only for numpy-in/numpy-out path. + if isinstance(x, np.ndarray) and not self._return_gpu_tensor: + batch = int(np.prod(x.shape[:-1])) if x.ndim > 1 else 1 + in_features = int(x.shape[-1]) if x.ndim > 0 else self.in_features + op_key = f"linear:{batch}x{in_features}x{self.out_features}" + return choose_fastest(op_key, gpu_linear, cpu_linear) + + result = gpu_linear() if result is not None: return result @@ -208,6 +227,8 @@ def forward(self, x) -> np.ndarray: return_gpu_tensor=self._return_gpu_tensor, ) + return cpu_linear() + def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ Backward pass using fnn-linear-backward.glsl diff --git a/nn/module.py b/nn/module.py index 51e79f5..75058b2 100644 --- a/nn/module.py +++ b/nn/module.py @@ -68,9 +68,9 @@ def _get_backend(self): """Execute get backend.""" if self._backend is None: - from grilly import Compute + from ..utils.device_manager import get_device_manager - self._backend = Compute() + self._backend = get_device_manager().vulkan return self._backend def _convert_input(self, x: np.ndarray | Any): diff --git a/tests/parity/README.md b/tests/parity/README.md new file mode 100644 index 0000000..ea5d275 --- /dev/null +++ b/tests/parity/README.md @@ -0,0 +1,25 @@ +# Numerical parity tests (PyTorch reference) + +This directory holds tests that compare Grilly outputs to **numpy references** and, when +`torch` is installed, to **`torch.nn.functional`** equivalents. 
+ +## Running + +```bash +# Core parity (numpy reference only; no optional deps) +uv run pytest tests/parity/ -v + +# Include PyTorch cross-checks (requires: pip install "grilly[torch]" or torch) +uv run pytest tests/parity/ -v +``` + +## Conventions + +- **Tolerances**: `rtol=1e-4`, `atol=1e-5` for float32 unless an op documents a looser policy. +- **Weight layout**: `grilly.functional.linear` uses `weight` shaped `(out_features, in_features)`, matching `F.linear` / `nn.Linear.weight`. +- **Markers**: tests are tagged `parity` (see root `tests/conftest.py`). + +## Roadmap + +See `docs/PYTORCH_PARITY_TASKLIST.md` (workstream A1). Planned additions: small CNN/MLP modules, +transformer encoder blocks, optimizer stepping snapshots, and a summarized pass/fail table in CI. diff --git a/tests/parity/test_functional_parity.py b/tests/parity/test_functional_parity.py new file mode 100644 index 0000000..c0ce877 --- /dev/null +++ b/tests/parity/test_functional_parity.py @@ -0,0 +1,97 @@ +""" +Numerical parity: `grilly.functional` vs numpy and (optional) PyTorch references. 
+""" + +import numpy as np +import pytest + + +def _ref_linear(x: np.ndarray, weight: np.ndarray, bias: np.ndarray | None) -> np.ndarray: + y = x @ weight.T + if bias is not None: + y = y + bias + return y.astype(np.float32, copy=False) + + +@pytest.mark.parity +def test_linear_matches_numpy_reference(): + from grilly.functional import linear + + rng = np.random.default_rng(0) + x = rng.standard_normal((8, 32)).astype(np.float32) + w = rng.standard_normal((16, 32)).astype(np.float32) + b = rng.standard_normal((16,)).astype(np.float32) + expected = _ref_linear(x, w, b) + got = linear(x, w, b) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_linear_matches_numpy_no_bias(): + from grilly.functional import linear + + rng = np.random.default_rng(1) + x = rng.standard_normal((4, 64)).astype(np.float32) + w = rng.standard_normal((24, 64)).astype(np.float32) + expected = _ref_linear(x, w, None) + got = linear(x, w, None) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_relu_matches_numpy_reference(): + from grilly.functional import relu + + rng = np.random.default_rng(2) + x = rng.standard_normal((5, 17)).astype(np.float32) + expected = np.maximum(0.0, x).astype(np.float32) + got = relu(x) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_linear_relu_chain_matches_numpy_reference(): + from grilly.functional import linear, relu + + rng = np.random.default_rng(3) + x = rng.standard_normal((6, 20)).astype(np.float32) + w = rng.standard_normal((12, 20)).astype(np.float32) + b = rng.standard_normal((12,)).astype(np.float32) + expected = np.maximum(0.0, _ref_linear(x, w, b)).astype(np.float32) + got = relu(linear(x, w, b)) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_linear_matches_torch_functional(): + torch = pytest.importorskip("torch") + import torch.nn.functional as 
F + + from grilly.functional import linear + + rng = np.random.default_rng(4) + x = rng.standard_normal((8, 32)).astype(np.float32) + w = rng.standard_normal((16, 32)).astype(np.float32) + b = rng.standard_normal((16,)).astype(np.float32) + + xt = torch.from_numpy(x) + wt = torch.from_numpy(w) + bt = torch.from_numpy(b) + expected = F.linear(xt, wt, bt).detach().numpy().astype(np.float32) + got = linear(x, w, b) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parity +def test_relu_matches_torch(): + torch = pytest.importorskip("torch") + import torch.nn.functional as F + + from grilly.functional import relu + + rng = np.random.default_rng(5) + x = rng.standard_normal((5, 17)).astype(np.float32) + xt = torch.from_numpy(x) + expected = F.relu(xt).detach().numpy().astype(np.float32) + got = relu(x) + np.testing.assert_allclose(got, expected, rtol=1e-4, atol=1e-5) diff --git a/tests/test_attention.py b/tests/test_attention.py index c86372d..289b21f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -60,3 +60,19 @@ def test_attention_output(self, gpu): # Output shape: (batch, seq_len, num_heads, head_dim) assert output.shape == (batch_size, seq_len, num_heads, head_dim) assert np.all(np.isfinite(output)) + + def test_flash_attention2(self, gpu): + """Flash Attention 2 path (batched command buffer).""" + batch_size = 1 + seq_len = 16 + num_heads = 4 + head_dim = 32 + q = np.random.randn(batch_size, seq_len, num_heads * head_dim).astype(np.float32) + k = np.random.randn(batch_size, seq_len, num_heads * head_dim).astype(np.float32) + v = np.random.randn(batch_size, seq_len, num_heads * head_dim).astype(np.float32) + + out = gpu.flash_attention2( + q, k, v, num_heads, head_dim, tile_size_q=16, tile_size_k=16 + ) + assert out.shape == (batch_size, seq_len, num_heads, head_dim) + assert np.all(np.isfinite(out)) diff --git a/tests/test_bridge_strict_mode.py b/tests/test_bridge_strict_mode.py new file mode 100644 index 
0000000..f7fac12 --- /dev/null +++ b/tests/test_bridge_strict_mode.py @@ -0,0 +1,43 @@ +import numpy as np + + +def test_bridge_strict_mode_raises(monkeypatch): + from grilly.backend import _bridge + + class DummyCore: + @staticmethod + def linear(dev, x, weight, bias): + raise RuntimeError("forced failure") + + monkeypatch.setattr(_bridge, "_core", DummyCore()) + monkeypatch.setattr(_bridge, "_get_device", lambda: object()) + monkeypatch.setattr(_bridge, "_BRIDGE_STRICT", True) + _bridge.reset_fallback_stats() + + x = np.ones((2, 4), dtype=np.float32) + w = np.ones((3, 4), dtype=np.float32) + try: + _bridge.linear(x, w, None) + assert False, "Expected strict bridge failure to raise" + except RuntimeError as e: + assert "GRILLY_BRIDGE_STRICT=1" in str(e) + + +def test_bridge_fallback_stats_increment(monkeypatch): + from grilly.backend import _bridge + + class DummyCore: + @staticmethod + def relu(dev, x): + raise RuntimeError("forced failure") + + monkeypatch.setattr(_bridge, "_core", DummyCore()) + monkeypatch.setattr(_bridge, "_get_device", lambda: object()) + monkeypatch.setattr(_bridge, "_BRIDGE_STRICT", False) + _bridge.reset_fallback_stats() + + x = np.array([-1.0, 2.0], dtype=np.float32) + out = _bridge.relu(x) + assert out is None + stats = _bridge.get_fallback_stats() + assert stats.get("relu", 0) == 1 diff --git a/utils/tensor_conversion.py b/utils/tensor_conversion.py index 322cead..ebc200a 100644 --- a/utils/tensor_conversion.py +++ b/utils/tensor_conversion.py @@ -31,6 +31,20 @@ from .device_manager import get_device_manager +_vulkan_backend_cache = None + + +def _get_vulkan_backend(): + """Get cached Vulkan backend without using deprecated Compute().""" + global _vulkan_backend_cache + if _vulkan_backend_cache is not None: + return _vulkan_backend_cache + try: + _vulkan_backend_cache = get_device_manager().vulkan + return _vulkan_backend_cache + except Exception: + return None + def to_vulkan( tensor: np.ndarray | Any, keep_on_gpu: bool = False @@ 
-360,9 +374,9 @@ def _ensure_uploaded(self): except Exception: # C++ Tensor may not have Vulkan backend — fall back to Python path try: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") cpu_data = self._t.numpy() size = cpu_data.nbytes @@ -425,9 +439,9 @@ def _ensure_downloaded(self): core = self._core if core is None: try: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core except Exception: @@ -453,9 +467,9 @@ def _ensure_downloaded(self): def _download_via_staging(self): core = self._core if core is None: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core @@ -611,9 +625,9 @@ def _ensure_uploaded(self): if not self._cpu_valid: raise RuntimeError("Cannot upload: no valid CPU data") try: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") size = self._cpu_data.nbytes try: from grilly.backend.buffer_pool import acquire_buffer @@ -663,9 +677,9 @@ def _ensure_downloaded(self): return core = self._core if core is None: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core self._core = core self._cpu_data = core._download_buffer( @@ -678,9 +692,9 @@ def _ensure_downloaded(self): def _download_via_staging(self): core = self._core if core is None: - from grilly import Compute - - backend = Compute() + backend = _get_vulkan_backend() + if backend is None: + raise RuntimeError("Vulkan backend is unavailable") core = backend.core 
self._core = core pooled = getattr(self, "_pooled_buffer", None) diff --git a/uv.lock b/uv.lock index 68e9a98..a8db432 100644 --- a/uv.lock +++ b/uv.lock @@ -522,7 +522,7 @@ wheels = [ [[package]] name = "grilly" -version = "0.5.6" +version = "0.6.1" source = { editable = "." } dependencies = [ { name = "numpy" }, From a53c1855c255c00f3934476c9667208e05837c23 Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 16:37:43 -0400 Subject: [PATCH 09/17] pre 1.0 --- .gitignore | 1 + CHANGELOG.md | 13 +++++++++++++ backend/compute.py | 7 +++++++ cpp/python/bindings_core.h | 14 ++++++++++++++ cpp/python/bindings_linear.cpp | 3 +++ docs/api/functional.md | 10 ++++++++++ tests/conftest.py | 3 +++ 7 files changed, 51 insertions(+) diff --git a/.gitignore b/.gitignore index c4e0698..359f627 100644 --- a/.gitignore +++ b/.gitignore @@ -207,3 +207,4 @@ docs/PYTORCH_PARITY_STATUS.md docs/PYTORCH_PARITY_TASKLIST.md /build_verify_pybind /build_verify_pybind2 +docs/MIGRATION_PYTORCH.md diff --git a/CHANGELOG.md b/CHANGELOG.md index ae07d7c..36cd360 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,19 @@ This changelog follows the spirit of **Keep a Changelog** and uses the terms **A --- +## [Unreleased] + +### Added + +- **`tests/parity/`** — Numerical parity tests for `grilly.functional` (numpy reference; optional PyTorch `F.linear` / `F.relu` when `torch` is installed). See `tests/parity/README.md`. +- **`docs/MIGRATION_PYTORCH.md`** — PyTorch → Grilly migration cookbook (device model, functional layout, module backend lifecycle, debugging). + +### Changed + +- **`docs/api/functional.md`** — PyTorch parity notes and links to migration docs and parity tests. 
+ +--- + ## [0.5.0] — 2026-03-18 — "GPU-First" ### Added diff --git a/backend/compute.py b/backend/compute.py index 1d0eab8..620d3a8 100644 --- a/backend/compute.py +++ b/backend/compute.py @@ -75,6 +75,13 @@ def __init__(self, shader_dir: str = None): self.queue = self.core.queue self.shaders = self.core.shaders + # Public Vulkan dispatch / batching (see docs/PERF_DISPATCH.md) + self.record_commands = self.core.record_commands + self.dispatch_compute = self.core._dispatch_compute + self.dispatch_compute_async = self.core._dispatch_compute_async + self.wait_async = self.core._wait_async + self.wait_fence = self.core._wait_fence + def cleanup(self): """Clean up Vulkan resources""" # Clear weight caches for all backend modules (before device is destroyed) diff --git a/cpp/python/bindings_core.h b/cpp/python/bindings_core.h index 6987de4..ce8f062 100644 --- a/cpp/python/bindings_core.h +++ b/cpp/python/bindings_core.h @@ -98,6 +98,20 @@ struct GrillyCoreContext { // Helper utilities shared across binding files // ═══════════════════════════════════════════════════════════════════════════ +/// Require NumPy C-contiguous float32 (kernels assume dense row-major layout). +inline void require_c_contiguous_float(const py::buffer_info& buf) { + if (buf.itemsize != sizeof(float)) + throw std::runtime_error("expected float32 array"); + if (buf.ndim == 0) + return; + py::ssize_t expected_stride = static_cast(sizeof(float)); + for (int i = static_cast(buf.ndim) - 1; i >= 0; --i) { + if (buf.strides[i] != expected_stride) + throw std::runtime_error("array must be C-contiguous float32"); + expected_stride *= buf.shape[i]; + } +} + /// Extract flat batch*seq and last-dim from a numpy buffer_info. 
inline std::pair extractBatchAndLastDim( const py::buffer_info& buf) { diff --git a/cpp/python/bindings_linear.cpp b/cpp/python/bindings_linear.cpp index c1ad145..dd93bda 100644 --- a/cpp/python/bindings_linear.cpp +++ b/cpp/python/bindings_linear.cpp @@ -16,6 +16,8 @@ void register_linear_ops(py::module_& m) { std::optional> bias) -> Tensor { auto xBuf = x.request(); auto wBuf = weights.request(); + require_c_contiguous_float(xBuf); + require_c_contiguous_float(wBuf); if (xBuf.ndim < 1 || xBuf.ndim > 3) throw std::runtime_error( @@ -37,6 +39,7 @@ void register_linear_ops(py::module_& m) { uint32_t hasBias = 0; if (bias.has_value()) { auto bBuf = bias->request(); + require_c_contiguous_float(bBuf); if (bBuf.ndim != 1 || static_cast(bBuf.shape[0]) != outputDim) throw std::runtime_error( diff --git a/docs/api/functional.md b/docs/api/functional.md index 0985602..735e7da 100644 --- a/docs/api/functional.md +++ b/docs/api/functional.md @@ -154,3 +154,13 @@ Source: [`functional/`](https://github.com/grillcheese-ai/grilly/tree/main/funct | `seq_to_ann_forward` | `seq_to_ann_forward(module, x)` | Apply ANN module to temporal data. | | `reset_net` | `reset_net(module)` | Reset all neuron membrane potentials. | | `set_step_mode` | `set_step_mode(module, mode)` | Set single-step or multi-step mode. | + +--- + +## PyTorch parity notes + +- **Layout**: `linear(x, weight, bias)` uses `weight` of shape `(out_features, in_features)`, matching `torch.nn.functional.linear` and `nn.Linear.weight`. +- **Dtypes**: GPU paths expect float32 inputs unless documented otherwise. +- **Differences**: Not every `torch.nn.functional` symbol exists; Grilly adds domain-specific ops (SNN, memory, FAISS-like). See `docs/PYTORCH_PARITY_STATUS.md`. +- **Migration**: Step-by-step porting guidance lives in `docs/MIGRATION_PYTORCH.md`. +- **Tests**: Numerical checks vs numpy (and optional PyTorch) live under `tests/parity/`. 
diff --git a/tests/conftest.py b/tests/conftest.py index 093d482..e70a6ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "cpp: marks tests that require C++ backend (deselect with '-m \"not cpp\"')" ) + config.addinivalue_line( + "markers", "parity: marks numerical parity tests (numpy / optional PyTorch reference)" + ) try: From 6e2beeaa1a62eebd1a96ee1de1a6ee518b792d8d Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 16:38:19 -0400 Subject: [PATCH 10/17] Update bindings_activations.cpp --- cpp/python/bindings_activations.cpp | 65 ++++++++++++++++++----------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/cpp/python/bindings_activations.cpp b/cpp/python/bindings_activations.cpp index f88a210..cd8cd9f 100644 --- a/cpp/python/bindings_activations.cpp +++ b/cpp/python/bindings_activations.cpp @@ -19,14 +19,18 @@ void register_activations_ops(py::module_& m) { [fn](GrillyCoreContext& ctx, py::array_t input) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t total = static_cast(inBuf.size); py::array_t result(input.request().shape); auto rBuf = result.request(); - fn(ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), total); + { + py::gil_scoped_release release; + fn(ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), total); + } return Tensor::from_numpy(result); }, @@ -50,6 +54,8 @@ void register_activations_ops(py::module_& m) { py::array_t input) -> Tensor { auto gBuf = grad_output.request(); auto iBuf = input.request(); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(iBuf); uint32_t total = 1; for (int i = 0; i < gBuf.ndim; ++i) total *= static_cast(gBuf.shape[i]); @@ -57,10 +63,13 @@ void register_activations_ops(py::module_& m) { py::array_t result(gBuf.shape); auto rBuf = result.request(); - fn(ctx.batch, ctx.pool, ctx.cache, 
- static_cast(gBuf.ptr), - static_cast(iBuf.ptr), - static_cast(rBuf.ptr), total); + { + py::gil_scoped_release release; + fn(ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(iBuf.ptr), + static_cast(rBuf.ptr), total); + } return Tensor::from_numpy(result); }, @@ -95,15 +104,18 @@ void register_activations_ops(py::module_& m) { py::array_t result(outShape); auto rBuf = result.request(); - grilly::ops::fusedMlpGelu( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(w1Buf.ptr), - static_cast(b1Buf.ptr), - static_cast(w2Buf.ptr), - static_cast(b2Buf.ptr), - static_cast(rBuf.ptr), - seqLen, dIn, dHidden, dOut); + { + py::gil_scoped_release release; + grilly::ops::fusedMlpGelu( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(w1Buf.ptr), + static_cast(b1Buf.ptr), + static_cast(w2Buf.ptr), + static_cast(b2Buf.ptr), + static_cast(rBuf.ptr), + seqLen, dIn, dHidden, dOut); + } return Tensor::from_numpy(result); }, @@ -131,15 +143,18 @@ void register_activations_ops(py::module_& m) { py::array_t result(outShape); auto rBuf = result.request(); - grilly::ops::fusedLayernormLinear( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(lnWBuf.ptr), - static_cast(lnBBuf.ptr), - static_cast(pWBuf.ptr), - static_cast(pBBuf.ptr), - static_cast(rBuf.ptr), - seqLen, dIn, dOut); + { + py::gil_scoped_release release; + grilly::ops::fusedLayernormLinear( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(lnWBuf.ptr), + static_cast(lnBBuf.ptr), + static_cast(pWBuf.ptr), + static_cast(pBBuf.ptr), + static_cast(rBuf.ptr), + seqLen, dIn, dOut); + } return Tensor::from_numpy(result); }, From f6793d80136736ac2c55789ea652197e6400b31a Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 16:38:28 -0400 Subject: [PATCH 11/17] Update bindings_attention.cpp --- cpp/python/bindings_attention.cpp | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 
deletions(-) diff --git a/cpp/python/bindings_attention.cpp b/cpp/python/bindings_attention.cpp index 55766d4..57c2a3b 100644 --- a/cpp/python/bindings_attention.cpp +++ b/cpp/python/bindings_attention.cpp @@ -41,15 +41,18 @@ void register_attention_ops(py::module_& m) { }); auto rBuf = result.request(); - grilly::ops::flashAttention2( - ctx.batch, ctx.pool, ctx.cache, - static_cast(qBuf.ptr), - static_cast(K.request().ptr), - static_cast(V.request().ptr), - maskPtr, - static_cast(rBuf.ptr), - batchSize, seqLen, numHeads, headDim, - scale, tileSizeQ, tileSizeK); + { + py::gil_scoped_release release; + grilly::ops::flashAttention2( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(K.request().ptr), + static_cast(V.request().ptr), + maskPtr, + static_cast(rBuf.ptr), + batchSize, seqLen, numHeads, headDim, + scale, tileSizeQ, tileSizeK); + } return Tensor::from_numpy(result); }, @@ -78,13 +81,17 @@ void register_attention_ops(py::module_& m) { py::array_t result({ static_cast(B), static_cast(H), static_cast(S), static_cast(S)}); + auto resBuf = result.request(); grilly::ops::AttentionScoresParams p{B, S, H, D, scale, 0}; - grilly::ops::attentionScores( - ctx.batch, ctx.pool, ctx.cache, - static_cast(qBuf.ptr), - static_cast(K.request().ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::attentionScores( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(K.request().ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("Q"), py::arg("K"), From 33e169d5145d7140eb1870debca9937f1d20a416 Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 16:39:41 -0400 Subject: [PATCH 12/17] Update bindings_attention.cpp --- cpp/python/bindings_attention.cpp | 68 ++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/cpp/python/bindings_attention.cpp b/cpp/python/bindings_attention.cpp index 
57c2a3b..b8244cd 100644 --- a/cpp/python/bindings_attention.cpp +++ b/cpp/python/bindings_attention.cpp @@ -128,16 +128,21 @@ void register_attention_ops(py::module_& m) { py::array_t wArr({ static_cast(B), static_cast(H), static_cast(S), static_cast(S)}); + auto outRB = outArr.request(); + auto wRB = wArr.request(); grilly::ops::AttentionScoresParams sp{B, S, H, D, scale, 0}; grilly::ops::AttentionOutputParams outp{B, S, H, D}; - grilly::ops::attentionScoresSoftmaxOutput( - ctx.batch, ctx.pool, ctx.cache, - static_cast(qBuf.ptr), - static_cast(kBuf.ptr), - static_cast(vBuf.ptr), - static_cast(outArr.request().ptr), - static_cast(wArr.request().ptr), sp, outp); + { + py::gil_scoped_release release; + grilly::ops::attentionScoresSoftmaxOutput( + ctx.batch, ctx.pool, ctx.cache, + static_cast(qBuf.ptr), + static_cast(kBuf.ptr), + static_cast(vBuf.ptr), + static_cast(outRB.ptr), + static_cast(wRB.ptr), sp, outp); + } return py::make_tuple(Tensor::from_numpy(outArr), Tensor::from_numpy(wArr)); }, py::arg("device"), py::arg("Q"), py::arg("K"), py::arg("V"), @@ -169,9 +174,12 @@ void register_attention_ops(py::module_& m) { grilly::ops::AttentionMaskParams p{ B, H, S, causal ? 
1u : 0u, mask_value}; - grilly::ops::attentionMask( - ctx.batch, ctx.pool, ctx.cache, - static_cast(result.mutable_data()), maskPtr, p); + { + py::gil_scoped_release release; + grilly::ops::attentionMask( + ctx.batch, ctx.pool, ctx.cache, + static_cast(result.mutable_data()), maskPtr, p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("scores"), @@ -198,13 +206,17 @@ void register_attention_ops(py::module_& m) { py::array_t result({ static_cast(B), static_cast(H), static_cast(S), static_cast(D)}); + auto resBuf = result.request(); grilly::ops::AttentionOutputParams p{B, S, H, D}; - grilly::ops::attentionOutput( - ctx.batch, ctx.pool, ctx.cache, - static_cast(wBuf.ptr), - static_cast(vBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::attentionOutput( + ctx.batch, ctx.pool, ctx.cache, + static_cast(wBuf.ptr), + static_cast(vBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("weights"), py::arg("V"), @@ -228,12 +240,16 @@ void register_attention_ops(py::module_& m) { static_cast(B), static_cast(S), static_cast(H * D)}); + auto resBuf = result.request(); grilly::ops::ConcatHeadsParams p{B, S, H, D}; - grilly::ops::attentionConcatHeads( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::attentionConcatHeads( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("mh_output"), @@ -266,13 +282,17 @@ void register_attention_ops(py::module_& m) { : nullptr; py::array_t result(inBuf.shape); + auto resBuf = result.request(); grilly::ops::RoPEParams p{ B, S, H, D, base, precomputed ? 
1u : 0u, scaling}; - grilly::ops::applyRoPE( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), - cosPtr, sinPtr, p); + { + py::gil_scoped_release release; + grilly::ops::applyRoPE( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), + cosPtr, sinPtr, p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), From e5d443624d8c92feb19cb3708874b60f50da7765 Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 3 Apr 2026 17:43:26 -0400 Subject: [PATCH 13/17] pre v1.0 --- CHANGELOG.md | 5 + backend/conv.py | 129 +++++++++-- backend/fnn.py | 121 ++++++++++ benchmarks/benchmark_conv_backward_weight.py | 68 ++++++ benchmarks/benchmark_int8_gemm.py | 50 +++++ cpp/python/bindings_conv.cpp | 100 ++++++--- cpp/python/bindings_core.cpp | 11 +- cpp/python/bindings_core.h | 28 +++ cpp/python/bindings_fusion.cpp | 6 +- cpp/python/bindings_loss.cpp | 34 ++- cpp/python/bindings_misc.cpp | 159 +++++++++----- cpp/python/bindings_moqe_train.cpp | 13 +- cpp/python/bindings_normalization.cpp | 140 ++++++++---- cpp/python/bindings_optim.cpp | 72 ++++-- cpp/python/bindings_perceiver.cpp | 5 + cpp/python/bindings_pooling.cpp | 49 +++-- cpp/python/bindings_siglip.cpp | 2 + cpp/python/bindings_snn.cpp | 219 ++++++++++++------- docs/PERF_DISPATCH.md | 48 ++++ nn/containers.py | 2 + shaders/int8-gemm.glsl | 32 +-- shaders/spv/int8-gemm.spv | Bin 0 -> 7620 bytes tests/conftest.py | 3 + tests/parity/README.md | 11 +- tests/parity/test_functional_parity.py | 2 - tests/parity/test_optimizers_parity.py | 94 ++++++++ tests/test_attention_long_sequence.py | 74 +++++++ tests/test_conv_backward_weight_gemm.py | 67 ++++++ 28 files changed, 1248 insertions(+), 296 deletions(-) create mode 100644 benchmarks/benchmark_conv_backward_weight.py create mode 100644 benchmarks/benchmark_int8_gemm.py create mode 100644 docs/PERF_DISPATCH.md create mode 100644 shaders/spv/int8-gemm.spv create mode 100644 
tests/parity/test_optimizers_parity.py create mode 100644 tests/test_attention_long_sequence.py create mode 100644 tests/test_conv_backward_weight_gemm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 36cd360..ef1ce8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,16 @@ This changelog follows the spirit of **Keep a Changelog** and uses the terms **A ### Added - **`tests/parity/`** — Numerical parity tests for `grilly.functional` (numpy reference; optional PyTorch `F.linear` / `F.relu` when `torch` is installed). See `tests/parity/README.md`. +- **`tests/parity/test_optimizers_parity.py`** — SGD and Adam (CPU) stepping vs `torch.optim`. - **`docs/MIGRATION_PYTORCH.md`** — PyTorch → Grilly migration cookbook (device model, functional layout, module backend lifecycle, debugging). +- **`docs/PERF_DISPATCH.md`** — Vulkan dispatch/batching (`record_commands`, async dispatch), Sequential fusion notes, pybind GIL checklist. ### Changed - **`docs/api/functional.md`** — PyTorch parity notes and links to migration docs and parity tests. +- **`backend/compute.py`** — `VulkanCompute` exposes `record_commands`, `dispatch_compute`, `dispatch_compute_async`, `wait_async`, `wait_fence`. +- **`backend/fnn.py`** — `_linear_relu_recorded_chain`: Linear→ReLU in one submit when `fused-linear-relu` shader is absent. +- **`cpp/python/bindings_{activations,attention,linear}.cpp`** — GIL release around heavy ops; `require_c_contiguous_float` on `linear` inputs. 
--- diff --git a/backend/conv.py b/backend/conv.py index a7f0380..26c1fc7 100644 --- a/backend/conv.py +++ b/backend/conv.py @@ -609,10 +609,118 @@ def _conv2d_backward_weight_gemm( K_dim = in_channels * kernel_h * kernel_w N_cols = batch_size * out_h * out_w + # grad_output: (N, C_out, H_out, W_out) -> (C_out, N*H_out*W_out) + grad_out_reshaped = grad_output.transpose(1, 0, 2, 3).reshape( + out_channels, N_cols + ) # (C_out, N_cols) + + grad_bias = None + if has_bias: + grad_bias = np.sum(grad_output, axis=(0, 2, 3)) # (C_out,) + + use_gpu_gemm = ( + "gemm_mnk" in self.shaders and "tensor-transpose" in self.shaders + ) + # --- Step 1: im2col(input) --- buf_input = self._acquire_buffer(input_data.nbytes) buf_cols = self._acquire_buffer(K_dim * N_cols * 4) + if use_gpu_gemm: + buf_cols_t = self._acquire_buffer(N_cols * K_dim * 4) + buf_grad_out = self._acquire_buffer(grad_out_reshaped.nbytes) + buf_grad_w = self._acquire_buffer(out_channels * K_dim * 4) + try: + self._upload_buffer(buf_input, input_data.flatten()) + + pipeline_im2col, layout_im2col, _ = self.pipelines.get_or_create_pipeline( + "convd_im2col", 2, push_constant_size=56 + ) + + in_handle = self._get_buffer_handle(buf_input) + cols_handle = self._get_buffer_handle(buf_cols) + + desc_im2col = self.pipelines.get_cached_descriptor_set( + "convd_im2col", + [(in_handle, input_data.nbytes), (cols_handle, K_dim * N_cols * 4)], + ) + + push_im2col = struct.pack( + "14I", + batch_size, + in_channels, + in_h, + in_w, + out_h, + out_w, + kernel_h, + kernel_w, + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + ) + + group_x = (K_dim + 15) // 16 + group_y = (N_cols + 15) // 16 + self.core._dispatch_compute( + pipeline_im2col, layout_im2col, desc_im2col, group_x, push_im2col, group_y, 1 + ) + + # cols (K_dim, N_cols) row-major -> cols^T (N_cols, K_dim) for GEMM B + pipeline_tr, layout_tr, _ = self.pipelines.get_or_create_pipeline( + "tensor-transpose", 2, push_constant_size=8 + ) + 
cols_t_handle = self._get_buffer_handle(buf_cols_t) + desc_tr = self.pipelines.get_cached_descriptor_set( + "tensor-transpose", + [ + (cols_handle, K_dim * N_cols * 4), + (cols_t_handle, N_cols * K_dim * 4), + ], + ) + push_tr = struct.pack("2I", K_dim, N_cols) + wg_tr = (K_dim * N_cols + 255) // 256 + self.core._dispatch_compute(pipeline_tr, layout_tr, desc_tr, wg_tr, push_tr) + + self._upload_buffer(buf_grad_out, grad_out_reshaped.ravel()) + + # C = A @ B: A (C_out, N_cols), B (N_cols, K_dim) -> (C_out, K_dim) + pipeline_gemm, layout_gemm, _ = self.pipelines.get_or_create_pipeline( + "gemm_mnk", 3, push_constant_size=12 + ) + g_handle = self._get_buffer_handle(buf_grad_out) + b_handle = cols_t_handle + c_handle = self._get_buffer_handle(buf_grad_w) + desc_gemm = self.pipelines.get_cached_descriptor_set( + "gemm_mnk", + [ + (g_handle, grad_out_reshaped.nbytes), + (b_handle, N_cols * K_dim * 4), + (c_handle, out_channels * K_dim * 4), + ], + ) + push_gemm = struct.pack("3I", out_channels, N_cols, K_dim) + group_gx = (K_dim + 15) // 16 + group_gy = (out_channels + 15) // 16 + self.core._dispatch_compute( + pipeline_gemm, layout_gemm, desc_gemm, group_gx, push_gemm, group_gy, 1 + ) + + grad_weight_flat = self._download_buffer( + buf_grad_w, out_channels * K_dim * 4, np.float32 + ) + grad_weight = grad_weight_flat.reshape( + out_channels, in_channels, kernel_h, kernel_w + ) + return grad_weight.astype(np.float32), grad_bias + finally: + self._release_buffers( + [buf_input, buf_cols, buf_cols_t, buf_grad_out, buf_grad_w] + ) + try: self._upload_buffer(buf_input, input_data.flatten()) @@ -651,29 +759,12 @@ def _conv2d_backward_weight_gemm( pipeline_im2col, layout_im2col, desc_im2col, group_x, push_im2col, group_y, 1 ) - # Download cols for GEMM (could be done on GPU, but for now download) cols_flat = self._download_buffer(buf_cols, K_dim * N_cols * 4, np.float32) - cols = cols_flat.reshape(K_dim, N_cols) # (K_dim, N_cols) - - # --- Step 2: Prepare grad_output for GEMM 
--- - # grad_output: (N, C_out, H_out, W_out) -> (C_out, N*H_out*W_out) - grad_out_reshaped = grad_output.transpose(1, 0, 2, 3).reshape( - out_channels, N_cols - ) # (C_out, N_cols) - - # --- Step 3: GEMM: grad_weight = grad_out @ cols.T --- - # (C_out, N_cols) @ (N_cols, K_dim) = (C_out, K_dim) - grad_weight_flat = grad_out_reshaped @ cols.T # (C_out, K_dim) + cols = cols_flat.reshape(K_dim, N_cols) - # Reshape to (C_out, C_in, kH, kW) + grad_weight_flat = grad_out_reshaped @ cols.T grad_weight = grad_weight_flat.reshape(out_channels, in_channels, kernel_h, kernel_w) - # --- Step 4: Compute grad_bias --- - grad_bias = None - if has_bias: - # Sum over all spatial positions and batch - grad_bias = np.sum(grad_output, axis=(0, 2, 3)) # (C_out,) - return grad_weight.astype(np.float32), grad_bias finally: diff --git a/backend/fnn.py b/backend/fnn.py index 735a6c4..52ef3df 100644 --- a/backend/fnn.py +++ b/backend/fnn.py @@ -2698,6 +2698,117 @@ def fused_linear_gelu( self._release_buffer(buf_output) return result.reshape(output_shape) + def _linear_relu_recorded_chain( + self, + x, + weights: np.ndarray, + bias: np.ndarray | None = None, + ) -> np.ndarray: + """Linear then ReLU in one queue submit (fused shader absent).""" + if "fnn-linear" not in self.shaders or "activation-relu" not in self.shaders: + raise RuntimeError("linear/relu shaders unavailable") + + original_shape = x.shape + input_dim = x.shape[-1] + output_dim = weights.shape[0] + + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + else: + batch_seq = original_shape[0] if len(original_shape) == 2 else 1 + + input_nbytes = batch_seq * input_dim * 4 + output_size = batch_seq * output_dim * 4 + + buf_input, release_input = self._prepare_input(x, size=input_nbytes) + + w_np = np.ascontiguousarray(weights, dtype=np.float32) + w_nbytes = int(np.prod(weights.shape)) * 4 + buf_weights, release_weights = self._get_or_upload_weight(w_np) + + if bias is not None: + bias_np = 
np.ascontiguousarray(bias, dtype=np.float32) + buf_bias, release_bias = self._get_or_upload_weight(bias_np) + b_nbytes = bias_np.size * 4 + has_bias = 1 + else: + b_flat = np.zeros(output_dim, dtype=np.float32) + buf_bias = self._acquire_buffer(b_flat.nbytes) + self._upload_buffer(buf_bias, b_flat) + b_nbytes = b_flat.nbytes + release_bias = True + has_bias = 0 + + buf_linear_out = self._acquire_buffer(output_size) + buf_relu_out = self._acquire_buffer(output_size) + + pipeline_l, pipeline_layout_l, _ = self.pipelines.get_or_create_pipeline( + "fnn-linear", 4, push_constant_size=16 + ) + descriptor_set_l = self.pipelines.get_cached_descriptor_set( + "fnn-linear", + [ + (self._get_buffer_handle(buf_input), input_nbytes), + (self._get_buffer_handle(buf_weights), w_nbytes), + (self._get_buffer_handle(buf_bias), b_nbytes), + (self._get_buffer_handle(buf_linear_out), output_size), + ], + ) + + push_l = struct.pack("IIII", batch_seq, input_dim, output_dim, has_bias) + workgroups_x = (output_dim + 15) // 16 + workgroups_y = (batch_seq + 15) // 16 + + total_elements = batch_seq * output_dim + push_r = struct.pack("I", total_elements) + workgroups_r = (total_elements + 255) // 256 + + pipeline_r, pipeline_layout_r, _ = self.pipelines.get_or_create_pipeline( + "activation-relu", 2, push_constant_size=4 + ) + descriptor_set_r = self.pipelines.get_cached_descriptor_set( + "activation-relu", + [ + (self._get_buffer_handle(buf_linear_out), output_size), + (self._get_buffer_handle(buf_relu_out), output_size), + ], + ) + + with self.core.record_commands() as rec: + rec.dispatch( + pipeline_l, + pipeline_layout_l, + descriptor_set_l, + (workgroups_x, workgroups_y), + push_l, + ) + rec.barrier() + rec.dispatch( + pipeline_r, + pipeline_layout_r, + descriptor_set_r, + (workgroups_r,), + push_r, + ) + + if len(original_shape) > 2: + output_shape = original_shape[:-1] + (output_dim,) + else: + output_shape = (batch_seq, output_dim) + + result = self._download_buffer(buf_relu_out, 
output_size, np.float32) + + if release_input: + self._release_buffer(buf_input) + if release_weights: + self._release_buffer(buf_weights) + if release_bias: + self._release_buffer(buf_bias) + self._release_buffer(buf_linear_out) + self._release_buffer(buf_relu_out) + + return result.reshape(output_shape) + def fused_linear_relu( self, x, @@ -2720,6 +2831,16 @@ def fused_linear_relu( ReLU(Linear(x)) """ if "fused-linear-relu" not in self.shaders: + if ( + not return_gpu_tensor + and hasattr(self.core, "record_commands") + and "fnn-linear" in self.shaders + and "activation-relu" in self.shaders + ): + try: + return self._linear_relu_recorded_chain(x, weights, bias) + except Exception: + pass linear_out = self.linear(x, weights, bias, return_gpu_tensor=return_gpu_tensor) return self.activation_relu(linear_out, return_gpu_tensor=return_gpu_tensor) diff --git a/benchmarks/benchmark_conv_backward_weight.py b/benchmarks/benchmark_conv_backward_weight.py new file mode 100644 index 0000000..5ffc46d --- /dev/null +++ b/benchmarks/benchmark_conv_backward_weight.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +""" +Baseline timings for conv2d backward weight (Workstream C1). + +The Vulkan training path uses `conv2d-backward-weight` (non-atomic, one thread +per weight slot) or the GEMM path (im2col + CPU matmul when `convd_im2col` is +available). This script prints wall time for a representative backward-weight call. 
+ +Usage: + uv run python benchmarks/benchmark_conv_backward_weight.py +""" + +from __future__ import annotations + +import time +import warnings + +import numpy as np + +from grilly.backend.compute import VulkanCompute + + +def main() -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + rng = np.random.default_rng(0) + batch_size = 4 + in_ch, out_ch = 32, 64 + h, w = 16, 16 + kh, kw = 3, 3 + grad_out = rng.standard_normal((batch_size, out_ch, h, w), dtype=np.float32) + inp = rng.standard_normal((batch_size, in_ch, h, w), dtype=np.float32) + + # Warmup + backend.conv.conv2d_backward_weight( + grad_out, + inp, + (kh, kw), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + groups=1, + has_bias=True, + ) + + n = 5 + t0 = time.perf_counter() + for _ in range(n): + backend.conv.conv2d_backward_weight( + grad_out, + inp, + (kh, kw), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + groups=1, + has_bias=True, + ) + elapsed = (time.perf_counter() - t0) / n + print(f"conv2d_backward_weight: {elapsed * 1000:.3f} ms/iter (mean of {n})") + finally: + backend.cleanup() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_int8_gemm.py b/benchmarks/benchmark_int8_gemm.py new file mode 100644 index 0000000..bc38d93 --- /dev/null +++ b/benchmarks/benchmark_int8_gemm.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +""" +INT8 weight GEMM baseline (Workstream C2) — `VulkanFNN.gemm_int8` if shader loaded. 
+ +Usage: + uv run python benchmarks/benchmark_int8_gemm.py +""" + +from __future__ import annotations + +import time +import warnings + +import numpy as np + +from grilly.backend.compute import VulkanCompute + + +def main() -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "int8-gemm" not in backend.fnn.shaders: + print("int8-gemm shader not loaded; skipping benchmark.") + return + + rng = np.random.default_rng(1) + M, K, N = 256, 512, 256 + group_size = 64 + num_groups = (K + group_size - 1) // group_size + + act = rng.standard_normal((M, K), dtype=np.float32) + w_i8 = rng.integers(-128, 127, size=(N, K), dtype=np.int8) + scales = np.abs(rng.standard_normal((N, num_groups), dtype=np.float32)) + 0.01 + + backend.fnn.gemm_int8(act, w_i8, scales, group_size=group_size) + + n = 10 + t0 = time.perf_counter() + for _ in range(n): + backend.fnn.gemm_int8(act, w_i8, scales, group_size=group_size) + elapsed = (time.perf_counter() - t0) / n + print(f"int8_gemm M={M} K={K} N={N}: {elapsed * 1000:.3f} ms/iter (mean of {n})") + finally: + backend.cleanup() + + +if __name__ == "__main__": + main() diff --git a/cpp/python/bindings_conv.cpp b/cpp/python/bindings_conv.cpp index 81387aa..34353a8 100644 --- a/cpp/python/bindings_conv.cpp +++ b/cpp/python/bindings_conv.cpp @@ -19,6 +19,8 @@ void register_conv_ops(py::module_& m) { uint32_t groups) -> Tensor { auto inBuf = input.request(); auto wBuf = weight.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(wBuf); if (inBuf.ndim != 4) throw std::runtime_error( @@ -46,8 +48,11 @@ void register_conv_ops(py::module_& m) { uint32_t outW = grilly::ops::convOutputSize(inW, kW, sW, pW, dW); const float* biasPtr = nullptr; - if (bias.has_value()) - biasPtr = static_cast(bias->request().ptr); + if (bias.has_value()) { + auto bBuf = bias->request(); + require_c_contiguous_float(bBuf); + biasPtr = static_cast(bBuf.ptr); + } 
py::array_t result({ static_cast(batchSize), @@ -57,15 +62,18 @@ void register_conv_ops(py::module_& m) { }); auto rBuf = result.request(); - grilly::ops::conv2d( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(wBuf.ptr), - biasPtr, - static_cast(rBuf.ptr), - batchSize, inChannels, inH, inW, - outChannels, kH, kW, - sH, sW, pH, pW, dH, dW, groups); + { + py::gil_scoped_release release; + grilly::ops::conv2d( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(wBuf.ptr), + biasPtr, + static_cast(rBuf.ptr), + batchSize, inChannels, inH, inW, + outChannels, kH, kW, + sH, sW, pH, pW, dH, dW, groups); + } return Tensor::from_numpy(result); }, @@ -87,6 +95,8 @@ void register_conv_ops(py::module_& m) { uint32_t dilation, uint32_t groups) -> Tensor { auto inBuf = input.request(); auto wBuf = weight.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(wBuf); if (inBuf.ndim != 3) throw std::runtime_error( @@ -102,8 +112,11 @@ void register_conv_ops(py::module_& m) { length, kSize, stride, padding, dilation); const float* biasPtr = nullptr; - if (bias.has_value()) - biasPtr = static_cast(bias->request().ptr); + if (bias.has_value()) { + auto bBuf = bias->request(); + require_c_contiguous_float(bBuf); + biasPtr = static_cast(bBuf.ptr); + } py::array_t result({ static_cast(batchSize), @@ -112,15 +125,18 @@ void register_conv_ops(py::module_& m) { }); auto rBuf = result.request(); - grilly::ops::conv1d( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(wBuf.ptr), - biasPtr, - static_cast(rBuf.ptr), - batchSize, inChannels, length, - outChannels, kSize, - stride, padding, dilation, groups); + { + py::gil_scoped_release release; + grilly::ops::conv1d( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(wBuf.ptr), + biasPtr, + static_cast(rBuf.ptr), + batchSize, inChannels, length, + outChannels, kSize, + stride, padding, dilation, groups); + } return 
Tensor::from_numpy(result); }, @@ -141,6 +157,8 @@ void register_conv_ops(py::module_& m) { uint32_t groups) -> Tensor { auto gBuf = grad_output.request(); auto wBuf = weight.request(); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(wBuf); uint32_t batchSize = input_shape[0]; uint32_t inChannels = input_shape[1]; @@ -168,12 +186,16 @@ void register_conv_ops(py::module_& m) { static_cast(inChannels), static_cast(inH), static_cast(inW)}); + auto resBuf = result.request(); - grilly::ops::conv2dBackwardInput( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(wBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::conv2dBackwardInput( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(wBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, @@ -196,6 +218,8 @@ void register_conv_ops(py::module_& m) { uint32_t groups, bool has_bias) -> py::dict { auto gBuf = grad_output.request(); auto iBuf = input.request(); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(iBuf); uint32_t batchSize = static_cast(iBuf.shape[0]); uint32_t inChannels = static_cast(iBuf.shape[1]); @@ -229,15 +253,19 @@ void register_conv_ops(py::module_& m) { has_bias ? std::vector{ static_cast(outChannels)} : std::vector{1}); - - grilly::ops::conv2dBackwardWeight( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(iBuf.ptr), - static_cast(gradWeight.request().ptr), - has_bias ? static_cast(gradBias.request().ptr) - : nullptr, - p); + auto gwBuf = gradWeight.request(); + auto gbBuf = gradBias.request(); + + { + py::gil_scoped_release release; + grilly::ops::conv2dBackwardWeight( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(iBuf.ptr), + static_cast(gwBuf.ptr), + has_bias ? 
static_cast(gbBuf.ptr) : nullptr, + p); + } py::dict result; result["grad_weight"] = Tensor::from_numpy(gradWeight); diff --git a/cpp/python/bindings_core.cpp b/cpp/python/bindings_core.cpp index b1f8d64..5c2d3fc 100644 --- a/cpp/python/bindings_core.cpp +++ b/cpp/python/bindings_core.cpp @@ -75,7 +75,11 @@ PYBIND11_MODULE(grilly_core, m) { "Clear all recorded ops for reuse") .def("optimize", [](grilly::OpGraph& graph, GrillyCoreContext& ctx) -> py::dict { - auto stats = graph.optimize(ctx.cache); + grilly::FusionStats stats; + { + py::gil_scoped_release release; + stats = graph.optimize(ctx.cache); + } py::dict d; d["ops_fused"] = stats.opsFused; d["barriers_eliminated"] = stats.barriersEliminated; @@ -87,7 +91,10 @@ PYBIND11_MODULE(grilly_core, m) { "Run fusion optimization pass. Returns fusion statistics.") .def("execute", [](grilly::OpGraph& graph, GrillyCoreContext& ctx) { - graph.execute(ctx.batch, ctx.cache); + { + py::gil_scoped_release release; + graph.execute(ctx.batch, ctx.cache); + } }, py::arg("device"), "Execute all recorded ops in a single GPU submission"); diff --git a/cpp/python/bindings_core.h b/cpp/python/bindings_core.h index ce8f062..437ef72 100644 --- a/cpp/python/bindings_core.h +++ b/cpp/python/bindings_core.h @@ -112,6 +112,34 @@ inline void require_c_contiguous_float(const py::buffer_info& buf) { } } +/// Require NumPy C-contiguous int8 (e.g. VSA / Hamming vectors). +inline void require_c_contiguous_int8(const py::buffer_info& buf) { + if (buf.itemsize != sizeof(int8_t)) + throw std::runtime_error("expected int8 array"); + if (buf.ndim == 0) + return; + py::ssize_t expected_stride = static_cast(sizeof(int8_t)); + for (int i = static_cast(buf.ndim) - 1; i >= 0; --i) { + if (buf.strides[i] != expected_stride) + throw std::runtime_error("array must be C-contiguous int8"); + expected_stride *= buf.shape[i]; + } +} + +/// Require NumPy C-contiguous uint32 (e.g. CE targets). 
+inline void require_c_contiguous_uint32(const py::buffer_info& buf) { + if (buf.itemsize != sizeof(uint32_t)) + throw std::runtime_error("expected uint32 array"); + if (buf.ndim == 0) + return; + py::ssize_t expected_stride = static_cast(sizeof(uint32_t)); + for (int i = static_cast(buf.ndim) - 1; i >= 0; --i) { + if (buf.strides[i] != expected_stride) + throw std::runtime_error("array must be C-contiguous uint32"); + expected_stride *= buf.shape[i]; + } +} + /// Extract flat batch*seq and last-dim from a numpy buffer_info. inline std::pair extractBatchAndLastDim( const py::buffer_info& buf) { diff --git a/cpp/python/bindings_fusion.cpp b/cpp/python/bindings_fusion.cpp index b179adf..fb5424f 100644 --- a/cpp/python/bindings_fusion.cpp +++ b/cpp/python/bindings_fusion.cpp @@ -39,7 +39,11 @@ void register_fusion_ops(py::module_& m) { graph.ops.push_back(std::move(op)); } - FusionResult result = engine.fuse(graph, ctx.cache); + FusionResult result; + { + py::gil_scoped_release release; + result = engine.fuse(graph, ctx.cache); + } py::dict d; d["shader_name"] = result.shaderName; diff --git a/cpp/python/bindings_loss.cpp b/cpp/python/bindings_loss.cpp index 9a3c59d..7bbb03d 100644 --- a/cpp/python/bindings_loss.cpp +++ b/cpp/python/bindings_loss.cpp @@ -15,6 +15,7 @@ void register_loss_ops(py::module_& m) { py::array_t logits, py::array_t targets, float label_smoothing) -> Tensor { auto lBuf = logits.request(); + require_c_contiguous_float(lBuf); uint32_t batchSize, seqLen, vocabSize; if (lBuf.ndim == 2) { @@ -31,14 +32,20 @@ void register_loss_ops(py::module_& m) { uint32_t totalPos = batchSize * seqLen; py::array_t losses(totalPos); + auto tBuf = targets.request(); + require_c_contiguous_uint32(tBuf); + auto lossBuf = losses.request(); grilly::ops::CrossEntropyParams p{ batchSize, seqLen, vocabSize, 0, label_smoothing}; - grilly::ops::crossEntropyLoss( - ctx.batch, ctx.pool, ctx.cache, - static_cast(lBuf.ptr), - static_cast(targets.request().ptr), - 
static_cast(losses.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::crossEntropyLoss( + ctx.batch, ctx.pool, ctx.cache, + static_cast(lBuf.ptr), + static_cast(tBuf.ptr), + static_cast(lossBuf.ptr), p); + } return Tensor::from_numpy(losses); }, py::arg("device"), py::arg("logits"), py::arg("targets"), @@ -52,18 +59,25 @@ void register_loss_ops(py::module_& m) { py::array_t logits, py::array_t targets) -> Tensor { auto lBuf = logits.request(); + require_c_contiguous_float(lBuf); uint32_t batchSize = static_cast(lBuf.shape[0]); uint32_t numClasses = static_cast(lBuf.shape[1]); py::array_t gradLogits(lBuf.shape); + auto tBuf = targets.request(); + require_c_contiguous_uint32(tBuf); + auto gBuf = gradLogits.request(); grilly::ops::CrossEntropyBackwardParams p{ batchSize, numClasses}; - grilly::ops::crossEntropyBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(lBuf.ptr), - static_cast(targets.request().ptr), - static_cast(gradLogits.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::crossEntropyBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(lBuf.ptr), + static_cast(tBuf.ptr), + static_cast(gBuf.ptr), p); + } return Tensor::from_numpy(gradLogits); }, py::arg("device"), py::arg("logits"), py::arg("targets"), diff --git a/cpp/python/bindings_misc.cpp b/cpp/python/bindings_misc.cpp index d39d298..584334d 100644 --- a/cpp/python/bindings_misc.cpp +++ b/cpp/python/bindings_misc.cpp @@ -90,17 +90,25 @@ void register_misc_ops(py::module_& m) { py::array_t input, py::array_t random_mask, float p, bool training) -> Tensor { auto inBuf = input.request(); - uint32_t total = 1; - for (int i = 0; i < inBuf.ndim; ++i) - total *= static_cast(inBuf.shape[i]); + auto maskBuf = random_mask.request(); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(maskBuf); + uint32_t total = static_cast(inBuf.size); + if (static_cast(maskBuf.size) != total) + throw std::runtime_error("dropout: input and mask size 
mismatch"); py::array_t result(inBuf.shape); - grilly::ops::dropout( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(random_mask.request().ptr), - static_cast(result.request().ptr), - total, p, training); + auto rBuf = result.request(); + require_c_contiguous_float(rBuf); + { + py::gil_scoped_release release; + grilly::ops::dropout( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(maskBuf.ptr), + static_cast(rBuf.ptr), + total, p, training); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), py::arg("random_mask"), @@ -115,6 +123,8 @@ void register_misc_ops(py::module_& m) { py::array_t embeddings) -> Tensor { auto idBuf = token_ids.request(); auto eBuf = embeddings.request(); + require_c_contiguous_uint32(idBuf); + require_c_contiguous_float(eBuf); uint32_t batchSize = 1, seqLen; if (idBuf.ndim == 1) { @@ -131,13 +141,19 @@ void register_misc_ops(py::module_& m) { static_cast(seqLen), static_cast(embDim)}); + auto rBuf = result.request(); + require_c_contiguous_float(rBuf); + grilly::ops::EmbeddingParams p{ batchSize, seqLen, vocabSize, embDim}; - grilly::ops::embeddingLookup( - ctx.batch, ctx.pool, ctx.cache, - static_cast(idBuf.ptr), - static_cast(eBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::embeddingLookup( + ctx.batch, ctx.pool, ctx.cache, + static_cast(idBuf.ptr), + static_cast(eBuf.ptr), + static_cast(rBuf.ptr), p); + } if (idBuf.ndim == 1) result = result.reshape({ @@ -208,13 +224,18 @@ void register_misc_ops(py::module_& m) { [](GrillyCoreContext& ctx, grilly::ops::KVCache& kvCache, py::array_t newKeys, py::array_t newValues) { auto kBuf = newKeys.request(); + auto vBuf = newValues.request(); + require_c_contiguous_float(kBuf); + require_c_contiguous_float(vBuf); uint32_t numNew = static_cast(kBuf.shape[0]); - grilly::ops::kvCacheAppend( - ctx.batch, ctx.pool, ctx.cache, kvCache, - static_cast(kBuf.ptr), - static_cast( 
- newValues.request().ptr), - numNew); + { + py::gil_scoped_release release; + grilly::ops::kvCacheAppend( + ctx.batch, ctx.pool, ctx.cache, kvCache, + static_cast(kBuf.ptr), + static_cast(vBuf.ptr), + numNew); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("new_keys"), py::arg("new_values"), @@ -234,10 +255,17 @@ void register_misc_ops(py::module_& m) { static_cast(tokens), static_cast(cfg.numHeads), static_cast(cfg.headDim)}); - grilly::ops::kvCacheDecode( - ctx.batch, ctx.pool, ctx.cache, kvCache, - static_cast(keys.request().ptr), - static_cast(values.request().ptr)); + auto keysBuf = keys.request(); + auto valsBuf = values.request(); + require_c_contiguous_float(keysBuf); + require_c_contiguous_float(valsBuf); + { + py::gil_scoped_release release; + grilly::ops::kvCacheDecode( + ctx.batch, ctx.pool, ctx.cache, kvCache, + static_cast(keysBuf.ptr), + static_cast(valsBuf.ptr)); + } py::dict result; result["keys"] = Tensor::from_numpy(keys); result["values"] = Tensor::from_numpy(values); @@ -252,12 +280,18 @@ void register_misc_ops(py::module_& m) { std::optional> attentionScores, uint32_t numEvict) { const float* scoresPtr = nullptr; - if (attentionScores.has_value()) - scoresPtr = static_cast( - attentionScores->request().ptr); - grilly::ops::kvCacheEvictH2O( - ctx.batch, ctx.pool, ctx.cache, kvCache, - scoresPtr, numEvict); + py::buffer_info scoresBuf; + if (attentionScores.has_value()) { + scoresBuf = attentionScores->request(); + require_c_contiguous_float(scoresBuf); + scoresPtr = static_cast(scoresBuf.ptr); + } + { + py::gil_scoped_release release; + grilly::ops::kvCacheEvictH2O( + ctx.batch, ctx.pool, ctx.cache, kvCache, + scoresPtr, numEvict); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("attention_scores") = py::none(), @@ -267,8 +301,11 @@ void register_misc_ops(py::module_& m) { m.def( "kv_cache_compact", [](GrillyCoreContext& ctx, grilly::ops::KVCache& kvCache) { - grilly::ops::kvCacheCompact( - ctx.batch, ctx.pool, ctx.cache, 
kvCache); + { + py::gil_scoped_release release; + grilly::ops::kvCacheCompact( + ctx.batch, ctx.pool, ctx.cache, kvCache); + } }, py::arg("device"), py::arg("kv_cache"), "Compact KV cache after eviction"); @@ -298,13 +335,18 @@ void register_misc_ops(py::module_& m) { [](GrillyCoreContext& ctx, grilly::ops::KVCache& kvCache, py::array_t tokenFeatures, py::array_t attentionScores, uint32_t seqLen) { - grilly::ops::kvCacheTrainEvictionHead( - ctx.batch, ctx.pool, ctx.cache, kvCache, - static_cast( - tokenFeatures.request().ptr), - static_cast( - attentionScores.request().ptr), - seqLen); + auto tfBuf = tokenFeatures.request(); + auto asBuf = attentionScores.request(); + require_c_contiguous_float(tfBuf); + require_c_contiguous_float(asBuf); + { + py::gil_scoped_release release; + grilly::ops::kvCacheTrainEvictionHead( + ctx.batch, ctx.pool, ctx.cache, kvCache, + static_cast(tfBuf.ptr), + static_cast(asBuf.ptr), + seqLen); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("token_features"), py::arg("attention_scores"), @@ -317,12 +359,18 @@ void register_misc_ops(py::module_& m) { std::optional> hiddenStates, uint32_t hiddenDim) { const float* hsPtr = nullptr; - if (hiddenStates.has_value()) - hsPtr = static_cast( - hiddenStates->request().ptr); - grilly::ops::kvCacheEvictSpeculative( - ctx.batch, ctx.pool, ctx.cache, kvCache, - hsPtr, hiddenDim); + py::buffer_info hsBuf; + if (hiddenStates.has_value()) { + hsBuf = hiddenStates->request(); + require_c_contiguous_float(hsBuf); + hsPtr = static_cast(hsBuf.ptr); + } + { + py::gil_scoped_release release; + grilly::ops::kvCacheEvictSpeculative( + ctx.batch, ctx.pool, ctx.cache, kvCache, + hsPtr, hiddenDim); + } }, py::arg("device"), py::arg("kv_cache"), py::arg("hidden_states") = py::none(), @@ -335,6 +383,7 @@ void register_misc_ops(py::module_& m) { [](GrillyCoreContext& ctx, py::array_t input, uint32_t waveSize, bool reverse) -> py::array_t { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if 
(inBuf.ndim != 4) throw std::runtime_error("input must be 4D"); uint32_t batchSize = static_cast(inBuf.shape[0]); @@ -355,12 +404,16 @@ void register_misc_ops(py::module_& m) { result = py::array_t(outSize / sizeof(float)); } auto rBuf = result.request(); - grilly::ops::swizzle( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), - batchSize, numHeads, seqLen, headDim, - waveSize, reverse); + require_c_contiguous_float(rBuf); + { + py::gil_scoped_release release; + grilly::ops::swizzle( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + batchSize, numHeads, seqLen, headDim, + waveSize, reverse); + } return result; }, py::arg("device"), py::arg("input"), @@ -517,6 +570,8 @@ void register_misc_ops(py::module_& m) { py::array_t cache_data) -> py::array_t { auto qBuf = query.request(); auto cBuf = cache_data.request(); + require_c_contiguous_int8(qBuf); + require_c_contiguous_int8(cBuf); uint32_t dim = static_cast(qBuf.shape[0]); uint32_t numEntries; if (cBuf.ndim == 2) { @@ -571,6 +626,8 @@ void register_misc_ops(py::module_& m) { py::array_t cache_data) -> py::array_t { auto qBuf = query.request(); auto cBuf = cache_data.request(); + require_c_contiguous_int8(qBuf); + require_c_contiguous_int8(cBuf); uint32_t dim = static_cast(qBuf.shape[0]); uint32_t numEntries = (cBuf.ndim == 2) ? 
static_cast(cBuf.shape[0]) diff --git a/cpp/python/bindings_moqe_train.cpp b/cpp/python/bindings_moqe_train.cpp index 1d5788d..c42e593 100644 --- a/cpp/python/bindings_moqe_train.cpp +++ b/cpp/python/bindings_moqe_train.cpp @@ -24,7 +24,9 @@ void register_moqe_train_ops(py::module_& m) { for (auto& w : weight_arrays) { auto arr = w.cast>(); arrays.push_back(arr); - ptrs.push_back(static_cast(arr.request().ptr)); + auto wb = arr.request(); + require_c_contiguous_float(wb); + ptrs.push_back(static_cast(wb.ptr)); } int handle = ops::moqe_train_upload( @@ -52,6 +54,7 @@ void register_moqe_train_ops(py::module_& m) { uint32_t layerIdx, int expertIdx, py::array_t w) { auto& cache = ops::moqe_train_get_cache(handle); auto buf = w.request(); + require_c_contiguous_float(buf); { py::gil_scoped_release release; ops::moqe_train_update_expert(ctx.pool, cache, layerIdx, expertIdx, @@ -71,6 +74,10 @@ void register_moqe_train_ops(py::module_& m) { auto& tc = ops::moqe_train_get_cache(handle); auto b0 = x0.request(); auto b1 = x1.request(); + if (b0.size > 0) + require_c_contiguous_float(b0); + if (b1.size > 0) + require_c_contiguous_float(b1); uint32_t n0 = static_cast(b0.shape[0]); uint32_t n1 = static_cast(b1.shape[0]); uint32_t d = tc.dModel; @@ -100,6 +107,10 @@ void register_moqe_train_ops(py::module_& m) { auto& tc = ops::moqe_train_get_cache(handle); auto b0 = d0.request(); auto b1 = d1.request(); + if (b0.size > 0) + require_c_contiguous_float(b0); + if (b1.size > 0) + require_c_contiguous_float(b1); uint32_t n0 = static_cast(b0.shape[0]); uint32_t n1 = static_cast(b1.shape[0]); uint32_t d = tc.dModel; diff --git a/cpp/python/bindings_normalization.cpp b/cpp/python/bindings_normalization.cpp index 15f62bf..bb0e58b 100644 --- a/cpp/python/bindings_normalization.cpp +++ b/cpp/python/bindings_normalization.cpp @@ -19,6 +19,8 @@ void register_normalization_ops(py::module_& m) { float eps) -> Tensor { auto inBuf = input.request(); auto gBuf = gamma.request(); + 
require_c_contiguous_float(inBuf); + require_c_contiguous_float(gBuf); if (inBuf.ndim < 2) throw std::runtime_error("input must be at least 2D"); @@ -31,14 +33,19 @@ void register_normalization_ops(py::module_& m) { py::array_t result(inBuf.shape); auto rBuf = result.request(); + auto bBuf = beta.request(); + require_c_contiguous_float(bBuf); - grilly::ops::layernorm( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), - static_cast(gBuf.ptr), - static_cast(beta.request().ptr), - 1, totalBatch, features, eps); + { + py::gil_scoped_release release; + grilly::ops::layernorm( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + static_cast(gBuf.ptr), + static_cast(bBuf.ptr), + 1, totalBatch, features, eps); + } return Tensor::from_numpy(result); }, @@ -56,6 +63,9 @@ void register_normalization_ops(py::module_& m) { auto goBuf = grad_output.request(); auto inBuf = input.request(); auto gBuf = gamma.request(); + require_c_contiguous_float(goBuf); + require_c_contiguous_float(inBuf); + require_c_contiguous_float(gBuf); if (inBuf.ndim < 2) throw std::runtime_error("input must be at least 2D"); @@ -71,18 +81,28 @@ void register_normalization_ops(py::module_& m) { {static_cast(features)}); py::array_t gradBeta( {static_cast(features)}); + auto meanBuf = mean.request(); + auto varBuf = var.request(); + auto giBuf = gradInput.request(); + auto ggBuf = gradGamma.request(); + auto gbBuf = gradBeta.request(); + require_c_contiguous_float(meanBuf); + require_c_contiguous_float(varBuf); - grilly::ops::layernormBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(goBuf.ptr), - static_cast(inBuf.ptr), - static_cast(gBuf.ptr), - static_cast(mean.request().ptr), - static_cast(var.request().ptr), - static_cast(gradInput.request().ptr), - static_cast(gradGamma.request().ptr), - static_cast(gradBeta.request().ptr), - 1, totalBatch, features, eps); + { + py::gil_scoped_release release; + 
grilly::ops::layernormBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(goBuf.ptr), + static_cast(inBuf.ptr), + static_cast(gBuf.ptr), + static_cast(meanBuf.ptr), + static_cast(varBuf.ptr), + static_cast(giBuf.ptr), + static_cast(ggBuf.ptr), + static_cast(gbBuf.ptr), + 1, totalBatch, features, eps); + } py::dict result; result["grad_input"] = Tensor::from_numpy(gradInput); @@ -104,6 +124,7 @@ void register_normalization_ops(py::module_& m) { py::array_t weight, float eps) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t features, batchSize, seqLen; if (inBuf.ndim == 1) { @@ -124,13 +145,18 @@ void register_normalization_ops(py::module_& m) { py::array_t result(inBuf.shape); auto rBuf = result.request(); + auto wBuf = weight.request(); + require_c_contiguous_float(wBuf); - grilly::ops::rmsnorm( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(rBuf.ptr), - static_cast(weight.request().ptr), - batchSize, seqLen, features, eps); + { + py::gil_scoped_release release; + grilly::ops::rmsnorm( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + static_cast(wBuf.ptr), + batchSize, seqLen, features, eps); + } return Tensor::from_numpy(result); }, @@ -147,6 +173,7 @@ void register_normalization_ops(py::module_& m) { py::array_t running_mean, py::array_t running_var, float eps, float momentum, bool training) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error("input must be 4D (B, C, H, W)"); @@ -168,16 +195,25 @@ void register_normalization_ops(py::module_& m) { grilly::ops::BatchNorm2dForwardParams p{ B, C, H, W, eps, momentum, training ? 
1u : 0u, 1u}; - grilly::ops::batchnorm2dForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(output.request().ptr), - static_cast(gamma.request().ptr), - static_cast(beta.request().ptr), - static_cast(rmOut.mutable_data()), - static_cast(rvOut.mutable_data()), - static_cast(bMean.mutable_data()), - static_cast(bVar.mutable_data()), p); + auto outBuf = output.request(); + auto gaBuf = gamma.request(); + auto beBuf = beta.request(); + require_c_contiguous_float(gaBuf); + require_c_contiguous_float(beBuf); + + { + py::gil_scoped_release release; + grilly::ops::batchnorm2dForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(outBuf.ptr), + static_cast(gaBuf.ptr), + static_cast(beBuf.ptr), + static_cast(rmOut.mutable_data()), + static_cast(rvOut.mutable_data()), + static_cast(bMean.mutable_data()), + static_cast(bVar.mutable_data()), p); + } py::dict result; result["output"] = Tensor::from_numpy(output); @@ -200,6 +236,7 @@ void register_normalization_ops(py::module_& m) { [](GrillyCoreContext& ctx, py::array_t input, int dim) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim < 1) throw std::runtime_error("input must be at least 1D"); @@ -211,11 +248,15 @@ void register_normalization_ops(py::module_& m) { if (inBuf.ndim == 1) totalBatch = 1; py::array_t result(inBuf.shape); - grilly::ops::softmax( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), - 1, totalBatch, features); + auto rBuf = result.request(); + { + py::gil_scoped_release release; + grilly::ops::softmax( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + 1, totalBatch, features); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), py::arg("dim") = -1, @@ -228,6 +269,7 @@ void register_normalization_ops(py::module_& m) { py::array_t grad_output, py::array_t softmax_output) -> Tensor { auto gBuf = 
grad_output.request(); + require_c_contiguous_float(gBuf); uint32_t numClasses = static_cast( gBuf.shape[gBuf.ndim - 1]); uint32_t batchSeq = 1; @@ -236,13 +278,19 @@ void register_normalization_ops(py::module_& m) { if (gBuf.ndim == 1) batchSeq = 1; py::array_t result(gBuf.shape); - grilly::ops::softmaxBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast( - softmax_output.request().ptr), - static_cast(result.request().ptr), - 1, batchSeq, numClasses); + auto sBuf = softmax_output.request(); + auto rBuf = result.request(); + require_c_contiguous_float(sBuf); + + { + py::gil_scoped_release release; + grilly::ops::softmaxBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(sBuf.ptr), + static_cast(rBuf.ptr), + 1, batchSeq, numClasses); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("grad_output"), diff --git a/cpp/python/bindings_optim.cpp b/cpp/python/bindings_optim.cpp index bb8bf0b..afc61d2 100644 --- a/cpp/python/bindings_optim.cpp +++ b/cpp/python/bindings_optim.cpp @@ -20,9 +20,18 @@ void register_optim_ops(py::module_& m) { float lr, float beta1, float beta2, float eps, float beta1_t, float beta2_t, bool clear_grad) -> py::dict { auto wBuf = weights.request(); - uint32_t total = 1; - for (int i = 0; i < wBuf.ndim; ++i) - total *= static_cast(wBuf.shape[i]); + auto gBuf = grad.request(); + auto mBuf = m_state.request(); + auto vBuf = v_state.request(); + require_c_contiguous_float(wBuf); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(mBuf); + require_c_contiguous_float(vBuf); + if (wBuf.size != gBuf.size || wBuf.size != mBuf.size || + wBuf.size != vBuf.size) + throw std::runtime_error( + "adam_update: weights, grad, m, v must match shape"); + uint32_t total = static_cast(wBuf.size); py::array_t wOut(wBuf.shape); py::array_t gOut(wBuf.shape); @@ -31,22 +40,25 @@ void register_optim_ops(py::module_& m) { std::memcpy(wOut.mutable_data(), wBuf.ptr, total * 
sizeof(float)); - std::memcpy(gOut.mutable_data(), grad.data(), + std::memcpy(gOut.mutable_data(), gBuf.ptr, total * sizeof(float)); - std::memcpy(mOut.mutable_data(), m_state.data(), + std::memcpy(mOut.mutable_data(), mBuf.ptr, total * sizeof(float)); - std::memcpy(vOut.mutable_data(), v_state.data(), + std::memcpy(vOut.mutable_data(), vBuf.ptr, total * sizeof(float)); grilly::ops::AdamParams p{total, lr, beta1, beta2, eps, beta1_t, beta2_t, clear_grad ? 1u : 0u}; - grilly::ops::adamUpdate( - ctx.batch, ctx.pool, ctx.cache, - static_cast(wOut.mutable_data()), - static_cast(gOut.mutable_data()), - static_cast(mOut.mutable_data()), - static_cast(vOut.mutable_data()), p); + { + py::gil_scoped_release release; + grilly::ops::adamUpdate( + ctx.batch, ctx.pool, ctx.cache, + static_cast(wOut.mutable_data()), + static_cast(gOut.mutable_data()), + static_cast(mOut.mutable_data()), + static_cast(vOut.mutable_data()), p); + } py::dict result; result["weights"] = Tensor::from_numpy(wOut); @@ -73,9 +85,18 @@ void register_optim_ops(py::module_& m) { float weight_decay, float beta1_t, float beta2_t, bool clear_grad) -> py::dict { auto wBuf = weights.request(); - uint32_t total = 1; - for (int i = 0; i < wBuf.ndim; ++i) - total *= static_cast(wBuf.shape[i]); + auto gBuf = grad.request(); + auto mBuf = m_state.request(); + auto vBuf = v_state.request(); + require_c_contiguous_float(wBuf); + require_c_contiguous_float(gBuf); + require_c_contiguous_float(mBuf); + require_c_contiguous_float(vBuf); + if (wBuf.size != gBuf.size || wBuf.size != mBuf.size || + wBuf.size != vBuf.size) + throw std::runtime_error( + "adamw_update: weights, grad, m, v must match shape"); + uint32_t total = static_cast(wBuf.size); py::array_t wOut(wBuf.shape); py::array_t gOut(wBuf.shape); @@ -84,22 +105,25 @@ void register_optim_ops(py::module_& m) { std::memcpy(wOut.mutable_data(), wBuf.ptr, total * sizeof(float)); - std::memcpy(gOut.mutable_data(), grad.data(), + std::memcpy(gOut.mutable_data(), gBuf.ptr, 
total * sizeof(float)); - std::memcpy(mOut.mutable_data(), m_state.data(), + std::memcpy(mOut.mutable_data(), mBuf.ptr, total * sizeof(float)); - std::memcpy(vOut.mutable_data(), v_state.data(), + std::memcpy(vOut.mutable_data(), vBuf.ptr, total * sizeof(float)); grilly::ops::AdamWParams p{total, lr, beta1, beta2, eps, weight_decay, beta1_t, beta2_t, clear_grad ? 1u : 0u}; - grilly::ops::adamwUpdate( - ctx.batch, ctx.pool, ctx.cache, - static_cast(wOut.mutable_data()), - static_cast(gOut.mutable_data()), - static_cast(mOut.mutable_data()), - static_cast(vOut.mutable_data()), p); + { + py::gil_scoped_release release; + grilly::ops::adamwUpdate( + ctx.batch, ctx.pool, ctx.cache, + static_cast(wOut.mutable_data()), + static_cast(gOut.mutable_data()), + static_cast(mOut.mutable_data()), + static_cast(vOut.mutable_data()), p); + } py::dict result; result["weights"] = Tensor::from_numpy(wOut); diff --git a/cpp/python/bindings_perceiver.cpp b/cpp/python/bindings_perceiver.cpp index e36b055..a69db31 100644 --- a/cpp/python/bindings_perceiver.cpp +++ b/cpp/python/bindings_perceiver.cpp @@ -28,6 +28,9 @@ void register_perceiver_ops(py::module_& m) { auto qBuf = Q_arr.request(); auto kBuf = K_arr.request(); auto vBuf = V_arr.request(); + require_c_contiguous_float(qBuf); + require_c_contiguous_float(kBuf); + require_c_contiguous_float(vBuf); uint32_t seqN = static_cast(qBuf.shape[0]); uint32_t headDim = static_cast(qBuf.shape[qBuf.ndim - 1]); @@ -68,6 +71,7 @@ void register_perceiver_ops(py::module_& m) { for (auto& w : weights) { auto arr = w.cast>(); auto buf = arr.request(); + require_c_contiguous_float(buf); const float* ptr = static_cast(buf.ptr); cpp_weights.emplace_back(ptr, ptr + buf.size); } @@ -92,6 +96,7 @@ void register_perceiver_ops(py::module_& m) { auto& pc = ops::perceiver_get_cache(handle); auto pBuf = patches.request(); + require_c_contiguous_float(pBuf); uint32_t nPatches = static_cast(pBuf.shape[0]); const float* patchPtr = static_cast(pBuf.ptr); diff 
--git a/cpp/python/bindings_pooling.cpp b/cpp/python/bindings_pooling.cpp index 6676f82..a18880d 100644 --- a/cpp/python/bindings_pooling.cpp +++ b/cpp/python/bindings_pooling.cpp @@ -19,6 +19,7 @@ void register_pooling_ops(py::module_& m) { std::vector padding, std::vector dilation) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error( "input must be 4D (B, C, H, W)"); @@ -53,14 +54,22 @@ void register_pooling_ops(py::module_& m) { static_cast(oH), static_cast(oW)}); + auto resBuf = result.request(); + auto idxBuf = indices.request(); + require_c_contiguous_float(resBuf); + require_c_contiguous_uint32(idxBuf); + grilly::ops::MaxPool2dParams p{B, C, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW}; - grilly::ops::maxpool2dForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), - static_cast(indices.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::maxpool2dForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), + static_cast(idxBuf.ptr), p); + } py::dict out; out["output"] = Tensor::from_numpy(result); @@ -85,6 +94,7 @@ void register_pooling_ops(py::module_& m) { std::vector padding, bool count_include_pad) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 4) throw std::runtime_error("input must be 4D"); @@ -111,13 +121,19 @@ void register_pooling_ops(py::module_& m) { static_cast(oH), static_cast(oW)}); + auto resBuf = result.request(); + require_c_contiguous_float(resBuf); + grilly::ops::AvgPool2dParams p{ B, C, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, count_include_pad ? 
1u : 0u}; - grilly::ops::avgpool2dForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::avgpool2dForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), @@ -133,6 +149,7 @@ void register_pooling_ops(py::module_& m) { [](GrillyCoreContext& ctx, py::array_t input) -> Tensor { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); if (inBuf.ndim != 3) throw std::runtime_error( "input must be 3D (B, S, D)"); @@ -145,11 +162,17 @@ void register_pooling_ops(py::module_& m) { static_cast(B), static_cast(D)}); + auto resBuf = result.request(); + require_c_contiguous_float(resBuf); + grilly::ops::MeanPoolParams p{B, S, D}; - grilly::ops::meanPool( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(result.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::meanPool( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(resBuf.ptr), p); + } return Tensor::from_numpy(result); }, py::arg("device"), py::arg("input"), diff --git a/cpp/python/bindings_siglip.cpp b/cpp/python/bindings_siglip.cpp index e699ed7..c203bc1 100644 --- a/cpp/python/bindings_siglip.cpp +++ b/cpp/python/bindings_siglip.cpp @@ -45,6 +45,7 @@ static int g_nextHandle = 1; static GrillyBuffer uploadPersistent(BufferPool& pool, py::array_t arr) { auto buf = arr.request(); + require_c_contiguous_float(buf); size_t bytes = buf.size * sizeof(float); GrillyBuffer gpuBuf = pool.acquire(bytes); pool.upload(gpuBuf, static_cast(buf.ptr), bytes); @@ -141,6 +142,7 @@ void register_siglip_ops(py::module_& m) { auto& wc = it->second; auto pBuf = patches.request(); + require_c_contiguous_float(pBuf); const uint32_t S = wc.seqLen; const uint32_t H = wc.hidden; const uint32_t H3 = H * 3; diff --git 
a/cpp/python/bindings_snn.cpp b/cpp/python/bindings_snn.cpp index 96a8372..c7d5338 100644 --- a/cpp/python/bindings_snn.cpp +++ b/cpp/python/bindings_snn.cpp @@ -19,6 +19,7 @@ void register_snn_ops(py::module_& m) { float v_thresh, float r_mem, float t_refrac_period) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t n = 1; for (int i = 0; i < inBuf.ndim; ++i) n *= static_cast(inBuf.shape[i]); @@ -26,21 +27,29 @@ void register_snn_ops(py::module_& m) { py::array_t vMemOut(v_mem.request().shape); py::array_t refracOut(t_refrac.request().shape); py::array_t spikes(inBuf.shape); + auto vmIn = v_mem.request(); + auto trIn = t_refrac.request(); + require_c_contiguous_float(vmIn); + require_c_contiguous_float(trIn); + auto vmOut = vMemOut.request(); + auto rfOut = refracOut.request(); + auto spOut = spikes.request(); - std::memcpy(vMemOut.request().ptr, v_mem.request().ptr, - n * sizeof(float)); - std::memcpy(refracOut.request().ptr, - t_refrac.request().ptr, n * sizeof(float)); + std::memcpy(vmOut.ptr, vmIn.ptr, n * sizeof(float)); + std::memcpy(rfOut.ptr, trIn.ptr, n * sizeof(float)); grilly::ops::LIFParams p{n, dt, tau_mem, v_rest, v_reset, v_thresh, r_mem, t_refrac_period}; - grilly::ops::lifStep( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(vMemOut.request().ptr), - static_cast(refracOut.request().ptr), - static_cast(spikes.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::lifStep( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(vmOut.ptr), + static_cast(rfOut.ptr), + static_cast(spOut.ptr), p); + } py::dict result; result["spikes"] = Tensor::from_numpy(spikes); @@ -66,6 +75,7 @@ void register_snn_ops(py::module_& m) { float v_reset, uint32_t reset_mode, uint32_t decay_input) -> py::dict { auto xBuf = x_in.request(); + require_c_contiguous_float(xBuf); uint32_t n = 1; for (int i = 0; i < xBuf.ndim; ++i) n *= static_cast(xBuf.shape[i]); @@ -73,22 
+83,30 @@ void register_snn_ops(py::module_& m) { py::array_t vMemOut(v_mem.request().shape); py::array_t spikes(xBuf.shape); py::array_t hOut(xBuf.shape); + auto vmIn = v_mem.request(); + require_c_contiguous_float(vmIn); + auto tpIn = tau_param.request(); + require_c_contiguous_float(tpIn); + auto vmOut = vMemOut.request(); + auto spOut = spikes.request(); + auto hR = hOut.request(); - std::memcpy(vMemOut.request().ptr, v_mem.request().ptr, - n * sizeof(float)); + std::memcpy(vmOut.ptr, vmIn.ptr, n * sizeof(float)); grilly::ops::SNNNodeForwardParams p{ n, neuron_type, tau, v_threshold, v_reset, reset_mode, decay_input}; - grilly::ops::snnNodeForward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(xBuf.ptr), - static_cast(vMemOut.request().ptr), - static_cast(spikes.request().ptr), - static_cast(hOut.request().ptr), - static_cast( - tau_param.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::snnNodeForward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(xBuf.ptr), + static_cast(vmOut.ptr), + static_cast(spOut.ptr), + static_cast(hR.ptr), + static_cast(tpIn.ptr), p); + } py::dict result; result["spikes"] = Tensor::from_numpy(spikes); @@ -112,21 +130,27 @@ void register_snn_ops(py::module_& m) { float alpha, uint32_t surrogate_type, float v_threshold) -> Tensor { auto gBuf = grad_spike.request(); + require_c_contiguous_float(gBuf); uint32_t n = 1; for (int i = 0; i < gBuf.ndim; ++i) n *= static_cast(gBuf.shape[i]); py::array_t gradX(gBuf.shape); + auto hBuf = h_cache.request(); + require_c_contiguous_float(hBuf); + auto gxBuf = gradX.request(); grilly::ops::SNNNodeBackwardParams p{ n, alpha, surrogate_type, v_threshold}; - grilly::ops::snnNodeBackward( - ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast( - h_cache.request().ptr), - static_cast(gradX.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::snnNodeBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(gBuf.ptr), + static_cast(hBuf.ptr), 
+ static_cast(gxBuf.ptr), p); + } return Tensor::from_numpy(gradX); }, @@ -143,10 +167,13 @@ void register_snn_ops(py::module_& m) { py::array_t weights, uint32_t batch_size, uint32_t time_steps, float learning_rate, - float weight_decay) -> Tensor { + float weight_decay) -> Tensor { auto preBuf = pre.request(); auto postBuf = post.request(); auto wBuf = weights.request(); + require_c_contiguous_float(preBuf); + require_c_contiguous_float(postBuf); + require_c_contiguous_float(wBuf); uint32_t pre_dim = static_cast( preBuf.shape[preBuf.ndim - 1]); @@ -154,18 +181,22 @@ void register_snn_ops(py::module_& m) { postBuf.shape[postBuf.ndim - 1]); py::array_t wOut(wBuf.shape); - std::memcpy(wOut.request().ptr, wBuf.ptr, + auto woBuf = wOut.request(); + std::memcpy(woBuf.ptr, wBuf.ptr, size_t(pre_dim) * post_dim * sizeof(float)); grilly::ops::HebbianParams p{batch_size, time_steps, pre_dim, post_dim, learning_rate, weight_decay}; - grilly::ops::hebbianLearning( - ctx.batch, ctx.pool, ctx.cache, - static_cast(preBuf.ptr), - static_cast(postBuf.ptr), - static_cast(wOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::hebbianLearning( + ctx.batch, ctx.pool, ctx.cache, + static_cast(preBuf.ptr), + static_cast(postBuf.ptr), + static_cast(woBuf.ptr), p); + } return Tensor::from_numpy(wOut); }, @@ -189,6 +220,9 @@ void register_snn_ops(py::module_& m) { auto preBuf = pre.request(); auto postBuf = post.request(); auto wBuf = weights.request(); + require_c_contiguous_float(preBuf); + require_c_contiguous_float(postBuf); + require_c_contiguous_float(wBuf); uint32_t pre_dim = static_cast( preBuf.shape[preBuf.ndim - 1]); @@ -200,27 +234,35 @@ void register_snn_ops(py::module_& m) { pre_trace.request().shape); py::array_t postTraceOut( post_trace.request().shape); - - std::memcpy(wOut.request().ptr, wBuf.ptr, + auto ptIn = pre_trace.request(); + auto pstIn = post_trace.request(); + require_c_contiguous_float(ptIn); + require_c_contiguous_float(pstIn); + auto 
woBuf = wOut.request(); + auto ptoBuf = preTraceOut.request(); + auto pstoBuf = postTraceOut.request(); + + std::memcpy(woBuf.ptr, wBuf.ptr, size_t(pre_dim) * post_dim * sizeof(float)); - std::memcpy(preTraceOut.request().ptr, - pre_trace.request().ptr, + std::memcpy(ptoBuf.ptr, ptIn.ptr, size_t(batch_size) * pre_dim * sizeof(float)); - std::memcpy(postTraceOut.request().ptr, - post_trace.request().ptr, + std::memcpy(pstoBuf.ptr, pstIn.ptr, size_t(batch_size) * post_dim * sizeof(float)); grilly::ops::STDPParams p{batch_size, time_steps, pre_dim, post_dim, lr_pot, lr_dep, trace_decay, 0}; - grilly::ops::stdpLearning( - ctx.batch, ctx.pool, ctx.cache, - static_cast(preBuf.ptr), - static_cast(postBuf.ptr), - static_cast(wOut.request().ptr), - static_cast(preTraceOut.request().ptr), - static_cast(postTraceOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::stdpLearning( + ctx.batch, ctx.pool, ctx.cache, + static_cast(preBuf.ptr), + static_cast(postBuf.ptr), + static_cast(woBuf.ptr), + static_cast(ptoBuf.ptr), + static_cast(pstoBuf.ptr), p); + } py::dict result; result["weights"] = Tensor::from_numpy(wOut); @@ -244,21 +286,26 @@ void register_snn_ops(py::module_& m) { py::array_t y_state, float decay) -> Tensor { auto xBuf = x_in.request(); + require_c_contiguous_float(xBuf); uint32_t n = 1; for (int i = 0; i < xBuf.ndim; ++i) n *= static_cast(xBuf.shape[i]); py::array_t yOut(y_state.request().shape); - std::memcpy(yOut.request().ptr, - y_state.request().ptr, - n * sizeof(float)); + auto ysIn = y_state.request(); + require_c_contiguous_float(ysIn); + auto yoBuf = yOut.request(); + std::memcpy(yoBuf.ptr, ysIn.ptr, n * sizeof(float)); grilly::ops::SynapseFilterParams p{n, decay}; - grilly::ops::synapseFilter( - ctx.batch, ctx.pool, ctx.cache, - static_cast(xBuf.ptr), - static_cast(yOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::synapseFilter( + ctx.batch, ctx.pool, ctx.cache, + static_cast(xBuf.ptr), + 
static_cast(yoBuf.ptr), p); + } return Tensor::from_numpy(yOut); }, @@ -279,8 +326,9 @@ void register_snn_ops(py::module_& m) { float tau_mem, float v_rest, float v_reset, float v_thresh, float r_mem, float tau_adapt, float delta_adapt, float b_adapt, float tau_gate, float gate_strength, - float t_refrac_period) -> py::dict { + float t_refrac_period) -> py::dict { auto inBuf = input.request(); + require_c_contiguous_float(inBuf); uint32_t n = 1; for (int i = 0; i < inBuf.ndim; ++i) n *= static_cast(inBuf.shape[i]); @@ -294,35 +342,52 @@ void register_snn_ops(py::module_& m) { py::array_t tLastOut( t_last_spike.request().shape); - std::memcpy(vMemOut.request().ptr, - v_mem.request().ptr, n * sizeof(float)); - std::memcpy(iAdaptOut.request().ptr, - i_adapt.request().ptr, n * sizeof(float)); - std::memcpy(gInputOut.request().ptr, - g_input.request().ptr, n * sizeof(float)); - std::memcpy(gForgetOut.request().ptr, - g_forget.request().ptr, n * sizeof(float)); - std::memcpy(refracOut.request().ptr, - t_refrac.request().ptr, n * sizeof(float)); - std::memcpy(tLastOut.request().ptr, - t_last_spike.request().ptr, - n * sizeof(float)); + auto vmIn = v_mem.request(); + auto iaIn = i_adapt.request(); + auto giIn = g_input.request(); + auto gfIn = g_forget.request(); + auto trIn = t_refrac.request(); + auto tlsIn = t_last_spike.request(); + require_c_contiguous_float(vmIn); + require_c_contiguous_float(iaIn); + require_c_contiguous_float(giIn); + require_c_contiguous_float(gfIn); + require_c_contiguous_float(trIn); + require_c_contiguous_float(tlsIn); + + auto vmOut = vMemOut.request(); + auto iaOut = iAdaptOut.request(); + auto giOut = gInputOut.request(); + auto gfOut = gForgetOut.request(); + auto trOut = refracOut.request(); + auto spOut = spikes.request(); + auto tlsOut = tLastOut.request(); + + std::memcpy(vmOut.ptr, vmIn.ptr, n * sizeof(float)); + std::memcpy(iaOut.ptr, iaIn.ptr, n * sizeof(float)); + std::memcpy(giOut.ptr, giIn.ptr, n * sizeof(float)); + 
std::memcpy(gfOut.ptr, gfIn.ptr, n * sizeof(float)); + std::memcpy(trOut.ptr, trIn.ptr, n * sizeof(float)); + std::memcpy(tlsOut.ptr, tlsIn.ptr, n * sizeof(float)); grilly::ops::GIFParams p{ n, dt, current_time, tau_mem, v_rest, v_reset, v_thresh, r_mem, tau_adapt, delta_adapt, b_adapt, tau_gate, gate_strength, t_refrac_period}; - grilly::ops::gifNeuronStep( - ctx.batch, ctx.pool, ctx.cache, - static_cast(inBuf.ptr), - static_cast(vMemOut.request().ptr), - static_cast(iAdaptOut.request().ptr), - static_cast(gInputOut.request().ptr), - static_cast(gForgetOut.request().ptr), - static_cast(refracOut.request().ptr), - static_cast(spikes.request().ptr), - static_cast(tLastOut.request().ptr), p); + { + py::gil_scoped_release release; + grilly::ops::gifNeuronStep( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(vmOut.ptr), + static_cast(iaOut.ptr), + static_cast(giOut.ptr), + static_cast(gfOut.ptr), + static_cast(trOut.ptr), + static_cast(spOut.ptr), + static_cast(tlsOut.ptr), p); + } py::dict result; result["spikes"] = Tensor::from_numpy(spikes); diff --git a/docs/PERF_DISPATCH.md b/docs/PERF_DISPATCH.md new file mode 100644 index 0000000..b9fb49c --- /dev/null +++ b/docs/PERF_DISPATCH.md @@ -0,0 +1,48 @@ +# GPU dispatch performance (Python `VulkanCore` / `VulkanCompute`) + +This document summarizes **non-blocking** and **batched** Vulkan dispatch APIs and how they relate to the PyTorch parity performance workstreams (B1, B2, D4). + +## Synchronous vs async (`backend/core.py`) + +- **`VulkanCore._dispatch_compute(...)`** — Records one command buffer, submits, **waits on a fence** (default `wait_previous=True`). Safe and simple; one fence wait per call. +- **`VulkanCore._dispatch_compute_async(...)`** — Submits work **without** waiting; pair with **`_wait_async()`** before reading GPU results. 
+- **`VulkanCore.record_commands()`** — Returns a **`CommandRecorder`** context manager: multiple `dispatch` + `barrier` calls, then **one** submit + wait on `__exit__`. Used by FlashAttention2 tiling, RMSNorm two-pass, and **`VulkanFNN._linear_relu_recorded_chain`** (Linear→ReLU fallback when `fused-linear-relu.spv` is missing). + +## Public aliases on `VulkanCompute` (`backend/compute.py`) + +After `Compute()` / `VulkanCompute()` construction: + +| Attribute | Underlying API | +|-----------|----------------| +| `record_commands` | `core.record_commands()` | +| `dispatch_compute` | `core._dispatch_compute` | +| `dispatch_compute_async` | `core._dispatch_compute_async` | +| `wait_async` | `core._wait_async` | +| `wait_fence` | `core._wait_fence` | + +Prefer **`record_commands`** for dependent kernels that share the same queue and do not need host-visible results between dispatches. + +## `nn.Sequential` fusion + +`nn.Sequential` already detects **Linear → ReLU / GELU / SiLU** and calls **`backend.fnn.fused_*`** when those shaders exist. If the fused shader is absent, **`fused_linear_relu`** tries **`_linear_relu_recorded_chain`** (two dispatches, **one** submit) before falling back to separate `linear` + `activation_relu` calls. + +## Pybind11 GIL policy (workstream B3) + +Heavy `grilly_core` entrypoints should release the GIL around GPU work: + +```cpp +{ + py::gil_scoped_release release; + grilly::ops::someOp(...); +} +``` + +Bindings should **request** `py::buffer_info` / output arrays **before** releasing the GIL. New array-based inputs should satisfy **`require_c_contiguous_float`** / **`require_c_contiguous_uint32`** / **`require_c_contiguous_int8`** (`bindings_core.h`) so kernels see dense row-major data. + +## Code review checklist (bindings) + +1. GIL held only for buffer prep and return value wrapping; **not** during `CommandBatch` / pool work. +2. `float32` numpy arrays are **C-contiguous** (or explicitly copied) for GPU kernels. +3. 
No duplicate `request()` on invalid buffers after GIL re-acquire without re-verifying lifetimes. + +Covered in this policy (non-exhaustive): activations, attention, linear, conv, normalization, loss, SNN, pooling, optim GPU steps, **misc** (dropout, embedding, KV cache GPU paths, `swizzle_kv`), **Hamming** (GPU path), **SigLIP** / **Perceiver** / **MoQE train**, **`ShaderFusionEngine.fuse`**, **`OpGraph.optimize` / `OpGraph.execute`**. diff --git a/nn/containers.py b/nn/containers.py index a4b3599..ffd9fc1 100644 --- a/nn/containers.py +++ b/nn/containers.py @@ -21,6 +21,8 @@ class Sequential(Module): Caches intermediate activations during forward pass for efficient backward pass. Automatically fuses Linear+Activation pairs when fused GPU shaders are available. + When a fused shader is missing, `fused_linear_relu` may use a two-kernel **single-submit** + path (`VulkanFNN._linear_relu_recorded_chain`) before falling back to separate calls. """ def __init__(self, *modules): diff --git a/shaders/int8-gemm.glsl b/shaders/int8-gemm.glsl index 1885164..4e767fe 100644 --- a/shaders/int8-gemm.glsl +++ b/shaders/int8-gemm.glsl @@ -54,23 +54,29 @@ void main() { uint packed_K = (K + 3) / 4; // number of uint32s per row float sum = 0.0; - - for (uint k = 0; k < K; k++) { - // Get activation value - float act = activations[row * K + k]; - - // Get packed weight and unpack - uint pack_idx = k / 4; - uint pack_offset = k % 4; + uint base = row * K; + + // Inner K: process 4 elements per iteration — one packed uint32 load per 4 weights. 
+ uint k = 0; + for (; k + 4u <= K; k += 4u) { + uint pk = weights_packed[col * packed_K + k / 4u]; + uint g0 = k / group_size; + uint g1 = (k + 1u) / group_size; + uint g2 = (k + 2u) / group_size; + uint g3 = (k + 3u) / group_size; + sum += activations[base + k] * unpack_int8(pk, 0) * scales[col * num_groups + g0]; + sum += activations[base + k + 1u] * unpack_int8(pk, 1) * scales[col * num_groups + g1]; + sum += activations[base + k + 2u] * unpack_int8(pk, 2) * scales[col * num_groups + g2]; + sum += activations[base + k + 3u] * unpack_int8(pk, 3) * scales[col * num_groups + g3]; + } + for (; k < K; k++) { + uint pack_idx = k / 4u; + uint pack_offset = k % 4u; uint packed = weights_packed[col * packed_K + pack_idx]; float w = unpack_int8(packed, pack_offset); - - // Get per-group scale uint group_idx = k / group_size; float s = scales[col * num_groups + group_idx]; - - // FP32 accumulate: act * (w_int8 * scale) - sum += act * w * s; + sum += activations[base + k] * w * s; } output_data[row * N + col] = sum; diff --git a/shaders/spv/int8-gemm.spv b/shaders/spv/int8-gemm.spv new file mode 100644 index 0000000000000000000000000000000000000000..16a07cea8fce3c3444327ee2e889cc2cc27ff7c1 GIT binary patch literal 7620 zcmZ9Q37nN>6^B2Vg<%KTk!1k408v0z(kG|Yb59fdW=e*~fce~%c(=c$>pe!4Z4b6sU z%LionGdddxlfgCTx^35r|wzU@&+WYzn<^G-BorUgFPiLW{w5_wfzqGy6xRxZY z)YINw=|}^bh)C->Fm&MV4hi4>bN5J zjR2Q7mp1jIoA~bp^v0e2orUeiZeY$$Mwd%nJ)Ip@)*Q3p+Ize4x8(Eh>Mpc(_iik9 zukP92+g|K1_4cejKOf(aEkJMU-I3>70=}qH-n_E6r`%u8@>P@bxe##`wOzY4iR~Ib zJGYy2zdO>kt+&D&mab&Y>}eUCkMnSBPi1Q%kD-p-2Ds`v6m-KhWCe7&vK8Ez6~P;e zWwKR$o6uWo_;Pd~hIvSrZRjj@ZSF6xX07RYP|iVTJF0vkUGItx9_w9&?&4I5_lW0L z@m?_tCBIL+P~4Q=2wvg2$d|NCt>&DA&WZ`|xc!nm&gB5QuehzawPqizvv0*--(KwQ z%=e%v)AlLXIP=|JrSg1-z-hjF>g@aK?ECBNBX#!CHGXVOEZ$qcGJcm(B>=U)ZPHTA< z^_xpB;?mDtBYP z*XA?EeC~akZyC1xKM5H_RqM_m+IaCe##pZv@oB{#gjVFo6X~9}A~V>x!D#J>nNu6@ z`$RtM@!%&A^E>xeWI(PP^IjY;Z@hXMvKtvoSJTnK=W}J?1mfnv+mpT9=UlWo<oZ66xv<7fgY`b>OT}?Z;heXA>uCE~E*E_+2m2V$Yc2Oh9p|?a 
zalIpcEjaqrzYcL;(bvUw_9fWfEAzh;?fS)ft1m@-rjjM@;WV)KOxtn3L+wG%d1!mj z&%k2Nw(+j9k9+Y5jNhfq z>3aK}LLD}*!{h8Vb@qnDcF%m5r{jxt_NF@9cX}G{yE(P@)Y-nHQ@`)&)ILyWAFQ(v z)!Dv7)A7DTQ`>iFY9FbykJj18>g)lmeLCKEXV|Wn@6NE*!|QC{pQ+#XXV{MS-Ko8T zUG(nrwI+Wb`k8PJd)kbiU*peZ$K7A;h(8PL=ZWVd;unGCjd%aG!?zeL=XgH@o{soA z;b%j{oKxfPAjZ$Ra}meM-$qYz&p_lKrp5}i&tv%H&qL%qo8rj(OtAe&?eoFbUWr7^ z1z`E8{VcHK{5 zU-aa6X8DM@2yBd>hY|A}u>5YiaL=BLI8OeW8n?d2yH}CxdEok7&xe!0t~TxkV8_b4 zr`};Pl+IW7pU0fRo=- z>$?-|IC*mx(Q@KVXwSHN+J$yc)DN*&CG=&8kNz!a{p#qa8*Gl4-&U}kbBT570oV7~ z3nza!J$iQf5XZ`UPP8{8_FzqM)Z7mC{D!>)?3%@VuK*h(A2B<@_8Dt_C0NcrqxYAA z9jk8zr)L-X-0-5&FK9k{;E>*3_1 z&KtmvmAB4bv|QAABRH+|CTuzHMbvpSxW3L?;N+vuTfvT%x6VGaT-137*z+FtJHct) zcVQbN@1Ea;emAloanJQTwgr7NqAzm22khQk_W)X3d=NbkSw=p`y%!mV7^Cl6&b<8l zkedl$9I5@laD#w3BDb1PWod`cY*aqPag)` zmwd$B4K5&8*ZTelSk74It4}Wa`6xJD^N(StYyNR;W8|afPk`kYBGL0F!Lbkeqn}TK z^+ml;gJT~Kfwk@H9<=#lA3g(iAB@o#bN(#Yocbf@=fLScd>%X9hc93oBOhZ9gZCgY zr!RtylaD!l2^{;NKj!phu)gT&D`5MQkC=PG={|fFTh7?n2f66yYv6PrzK)&l!#A*v zk&m9g36{SBiJrd&j(yM{{d^m&FY0{<9Q$w|SlhntN1HG9;k)2rh%x$N&ff!@Q-9?A zJ~-WnA7H2Z@I!240@k*# zqiFNRKKvf+J{Y4f=KKe+IrT@*KZ4VJ_!D-z4}ZosMn1;;1-uW5IsFxEoP5mbZ{XMm z{V}J%gY`vE{{Y*Ue8l_{obJQFu;q-6eUS6J%VlW4N2sH}e}mI~_z$++C5Yqo$)(?o zN;UgIwE5J2S9&*oAKhxPo4e=#Q~O!TQ__fA49BZv+xX6$(z&PmvRw55uC2w zB<%G5OvW}wKE|8~HczbG6tHpfaethSJb!LM16XPD~k2OY*#GD3B@AlKM<&1Ss z`sBQ8kAS^vYR7w)XX8ufGzVMGyXJU(a*mI??4647XX0Cc9d~s;wr5~v60-nKKF<<&5*U@>0aVLAlSKWuL*w65$f35h)kJh^7a97bZBC+0aCEH?_VH}lCw{1~u3ggq8)%;DPS z%s8+y^7q!b31Ip0h&k;^E^?d*mfOYtPDD>ZX7i$%x!6B Date: Sat, 4 Apr 2026 11:14:33 -0400 Subject: [PATCH 14/17] pre v1.0 --- CHANGELOG.md | 2 + backend/base.py | 16 +++-- backend/conv.py | 86 ++++++++++++++++++++++++ cpp/python/bindings.cpp | 2 + cpp/python/bindings_core.cpp | 2 + shaders/spv/conv1x1-backward-weight.spv | Bin 3616 -> 3688 bytes tests/test_vulkan_tensor_residency.py | 40 +++++++++++ utils/tensor_conversion.py | 57 +++++++++++++++- 8 files changed, 197 insertions(+), 8 deletions(-) create mode 100644 
tests/test_vulkan_tensor_residency.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ef1ce8e..0d852ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ This changelog follows the spirit of **Keep a Changelog** and uses the terms **A ### Changed +- **`utils/tensor_conversion.py`**, **`backend/base.py`** — `VulkanTensor.prepare_for_dispatch()` and C++ `Tensor.gpu_handle_if_valid()` binding so GPU-resident tensors reuse buffers in `BufferMixin._prepare_input` without redundant uploads when possible. +- **`docs/PYTORCH_PARITY_TASKLIST.md`**, **`docs/PYTORCH_PARITY_STATUS.md`**, **`docs/GPU_OPTIMIZATION_REVIEW.md`** — Workstream C (kernel/throughput milestone) marked complete with scoped deliverables; follow-ups consolidated under **Workstream C — future**. - **`docs/api/functional.md`** — PyTorch parity notes and links to migration docs and parity tests. - **`backend/compute.py`** — `VulkanCompute` exposes `record_commands`, `dispatch_compute`, `dispatch_compute_async`, `wait_async`, `wait_fence`. - **`backend/fnn.py`** — `_linear_relu_recorded_chain`: Linear→ReLU in one submit when `fused-linear-relu` shader is absent. diff --git a/backend/base.py b/backend/base.py index 7c931cc..630b8fa 100644 --- a/backend/base.py +++ b/backend/base.py @@ -374,18 +374,20 @@ def _prepare_input(self, data, size=None): When *data* is a VulkanTensor already on GPU, returns its existing buffer without copying. Otherwise allocates a pooled buffer and uploads. + + VulkanTensor residency: ``prepare_for_dispatch()`` binds C++ GPU handles or + pooled buffers so GPU-resident tensors skip redundant CPU→GPU uploads. 
""" # Avoid hard import at module level - VulkanTensor lives in utils from ..utils.tensor_conversion import VulkanTensor if isinstance(data, VulkanTensor): - if data.on_gpu: - buf = data._pooled_buffer if data._pooled_buffer is not None else None - if buf is not None: - return buf, False - # Has raw GPU buffer but not pooled - wrap for API compat - return _DirectBuffer(data._gpu_buffer, data._gpu_memory, data.nbytes), False - # CPU-backed lazy VulkanTensor + data.prepare_for_dispatch() + if data._pooled_buffer is not None: + return data._pooled_buffer, False + if data._gpu_buffer is not None: + nbytes = int(data.nbytes if size is None else size) + return _DirectBuffer(data._gpu_buffer, data._gpu_memory, nbytes), False arr = np.asarray(data.numpy(), dtype=np.float32).reshape(-1) else: arr = np.asarray(data) diff --git a/backend/conv.py b/backend/conv.py index 26c1fc7..4a3377b 100644 --- a/backend/conv.py +++ b/backend/conv.py @@ -812,6 +812,34 @@ def conv2d_backward_weight( grad_output, input_data, kernel_size, stride, padding, dilation, groups, has_bias ) + # 1×1, stride 1, pad 0, dilation 1, groups 1: optional atomic float shader (no batch dim in shader — loop batches) + kernel_h, kernel_w = kernel_size + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + batch_size, out_channels, out_height, out_width = grad_output.shape + _, in_channels, in_height, in_width = input_data.shape + atomic_ext = getattr(self.core, "enabled_extensions", None) or set() + can_1x1_atomic = ( + "conv1x1-backward-weight" in self.shaders + and "VK_EXT_shader_atomic_float" in atomic_ext + and kernel_h == 1 + and kernel_w == 1 + and stride_h == 1 + and stride_w == 1 + and padding_h == 0 + and padding_w == 0 + and dilation_h == 1 + and dilation_w == 1 + and groups == 1 + and out_height == in_height + and out_width == in_width + ) + if can_1x1_atomic: + return self._conv1x1_backward_weight_atomic( + grad_output, input_data, out_channels, 
in_channels, has_bias + ) + # Check if shader is available if "conv2d-backward-weight" not in self.shaders: return self._conv2d_backward_weight_cpu( @@ -924,6 +952,64 @@ def conv2d_backward_weight( finally: self._release_buffers([buf_grad_output, buf_input, buf_grad_weight, buf_grad_bias]) + def _conv1x1_backward_weight_atomic( + self, + grad_output: np.ndarray, + input_data: np.ndarray, + out_channels: int, + in_channels: int, + has_bias: bool, + ) -> tuple[np.ndarray, np.ndarray | None]: + """1×1 conv backward weight: ``conv1x1-backward-weight.glsl`` (buffer float atomics). Shader is batch=1; we dispatch per batch and accumulate.""" + batch_size, _, out_height, out_width = grad_output.shape + flat_in = in_channels * out_height * out_width * 4 + flat_dy = out_channels * out_height * out_width * 4 + num_w = out_channels * in_channels * 4 + + buf_in = self._acquire_buffer(flat_in) + buf_dy = self._acquire_buffer(flat_dy) + buf_gw = self._acquire_buffer(num_w) + + try: + self._upload_buffer(buf_gw, np.zeros(out_channels * in_channels, dtype=np.float32)) + pipeline, pipeline_layout, _ = self.pipelines.get_or_create_pipeline( + "conv1x1-backward-weight", 3, push_constant_size=16 + ) + gx = (out_width + 15) // 16 + gy = (out_height + 15) // 16 + push = struct.pack("4I", out_width, out_height, in_channels, out_channels) + + for b in range(batch_size): + self._upload_buffer(buf_in, input_data[b].ravel()) + self._upload_buffer(buf_dy, grad_output[b].ravel()) + descriptor_set = self.pipelines.get_cached_descriptor_set( + "conv1x1-backward-weight", + [ + (self._get_buffer_handle(buf_in), flat_in), + (self._get_buffer_handle(buf_dy), flat_dy), + (self._get_buffer_handle(buf_gw), num_w), + ], + ) + self.core._dispatch_compute( + pipeline, + pipeline_layout, + descriptor_set, + gx, + push, + gy, + 1, + wait_previous=(b != 0), + ) + + grad_weight_flat = self._download_buffer(buf_gw, num_w, np.float32) + grad_weight = grad_weight_flat.reshape(out_channels, in_channels, 1, 1) + 
grad_bias = None + if has_bias: + grad_bias = np.sum(grad_output, axis=(0, 2, 3), dtype=np.float32) + return grad_weight, grad_bias + finally: + self._release_buffers([buf_in, buf_dy, buf_gw]) + # CPU fallbacks (using numpy) def _conv2d_cpu(self, input_data, weight, bias, stride, padding, dilation, groups): """CPU fallback for conv2d forward pass""" diff --git a/cpp/python/bindings.cpp b/cpp/python/bindings.cpp index 381dd49..ca2650f 100644 --- a/cpp/python/bindings.cpp +++ b/cpp/python/bindings.cpp @@ -3188,6 +3188,8 @@ PYBIND11_MODULE(grilly_core, m) { "Download to CPU and return as numpy array") .def("gpu_handle", &Tensor::gpu_handle, "Upload to GPU and return buffer handle") + .def("gpu_handle_if_valid", &Tensor::gpu_handle_if_valid, + "GPU buffer handle if already resident (0 otherwise); does not upload") .def("mark_gpu_modified", &Tensor::mark_gpu_modified) .def("mark_cpu_modified", &Tensor::mark_cpu_modified) .def("reshape", &Tensor::reshape, py::arg("shape")) diff --git a/cpp/python/bindings_core.cpp b/cpp/python/bindings_core.cpp index 5c2d3fc..cf593ea 100644 --- a/cpp/python/bindings_core.cpp +++ b/cpp/python/bindings_core.cpp @@ -154,6 +154,8 @@ PYBIND11_MODULE(grilly_core, m) { "Download to CPU and return as numpy array") .def("gpu_handle", &Tensor::gpu_handle, "Upload to GPU and return buffer handle") + .def("gpu_handle_if_valid", &Tensor::gpu_handle_if_valid, + "GPU buffer handle if already resident (0 otherwise); does not upload") .def("mark_gpu_modified", &Tensor::mark_gpu_modified) .def("mark_cpu_modified", &Tensor::mark_cpu_modified) .def("reshape", &Tensor::reshape, py::arg("shape")) diff --git a/shaders/spv/conv1x1-backward-weight.spv b/shaders/spv/conv1x1-backward-weight.spv index d0cef8759dd72ec7e194cc53aafe968ed74815b0..7be9997f3726fe3daa0cd8306e81d987a09de6e9 100644 GIT binary patch delta 205 zcmZ1=^FoH3nMs+Qfq{{Mn}K&C_g;4v1}+A7pLlnFe|I0(_~e3u_~MevoYeT7%)He2 zl+2>k^(+j@2Y DG5IRp delta 134 
zcmaDMvp|NMnMs+Qftitkn}K5@_ukC{j3rErOp{xgvw`GC5V<*)WjUiF4+8^(8xUJF zurRm-X%8R bool: + """If C++ Tensor already holds a GPU buffer, mirror its VkBuffer handle for Python dispatch. + + Avoids allocating a second pooled buffer and re-uploading when the tensor is + GPU-resident only (e.g. output of a prior op) but Python slots were not set. + """ + if self._pooled_buffer is not None or self._gpu_buffer is not None: + return True + try: + h = int(self._t.gpu_handle_if_valid()) + except Exception: + return False + if h == 0: + return False + try: + import vulkan as vk + + self._gpu_buffer = vk.ffi.cast("VkBuffer", h) + self._gpu_memory = None + if self._core is None: + backend = _get_vulkan_backend() + if backend is not None: + self._core = backend.core + self._gpu_valid = True + self._uploaded = True + self._cpu_valid = bool(self._t.on_cpu) + return True + except Exception: + return False + + def prepare_for_dispatch(self) -> None: + """Ensure this tensor can supply a Vulkan buffer for kernels (minimal CPU↔GPU traffic).""" + if self._try_bind_cpp_gpu_buffer(): + return + if self._t.on_gpu: + try: + self._t.ensure_gpu() + except Exception: + pass + if self._try_bind_cpp_gpu_buffer(): + return + self._ensure_uploaded() + def _ensure_uploaded(self): - if self._gpu_valid: + if self._try_bind_cpp_gpu_buffer(): return if not self._cpu_valid and not self._t.on_cpu: + if self._t.on_gpu: + try: + self._t.ensure_gpu() + except Exception: + pass + if self._try_bind_cpp_gpu_buffer(): + return raise RuntimeError("Cannot upload: no valid CPU data") try: self._t.ensure_gpu() self._gpu_valid = True self._uploaded = True + if self._try_bind_cpp_gpu_buffer(): + return except Exception: # C++ Tensor may not have Vulkan backend — fall back to Python path try: @@ -726,6 +778,9 @@ def mark_cpu_modified(self): self._cpu_valid = True self._gpu_valid = False + def prepare_for_dispatch(self) -> None: + self._ensure_uploaded() + @property def is_leaf(self) -> bool: """True if this tensor was 
created directly (not by an operation).""" From 34ee480c68b5fb0d448ae046a780207970ec1a28 Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Sat, 4 Apr 2026 12:26:13 -0400 Subject: [PATCH 15/17] prev1.0 --- CHANGELOG.md | 6 ++++++ backend/compute.py | 7 ++++++- backend/fnn.py | 6 ++++++ docs/PERF_DISPATCH.md | 8 +++++++- docs/index.md | 4 ++++ 5 files changed, 29 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d852ed..f6f9f8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ This changelog follows the spirit of **Keep a Changelog** and uses the terms **A ### Added +- **`tests/sentencepiece/test_sentencepiece_parity.py`** — Adaptive SentencePiece parity tests vs Hugging Face references (T5/MT5-style tokenizers; auto-skip until `grilly.tokenizers` SentencePiece support is available). +- **`tests/tokenizers/test_gpu_tokenizer_parity.py`** — Adaptive real tokenizer parity tests vs Hugging Face references (auto-skip until `grilly.tokenizers` lands and assets are available). +- **`tests/sentence_transformers/`**, **`tests/transformers_compat/`**, **`tests/converter/`**, **`tests/autograd_chain/`**, **`tests/moe_quant/`** — CI-safe skip-marked scaffold suites to make pre-v1.0 roadmap targets executable and incrementally fillable. +- **`docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md`** — Post-upgrade pre-v1.0 optimization + feature parity execution plan (P0/P1/P2 workstreams, acceptance criteria, and verification commands). +- **`backend/fnn_chain.py`** — `FnnChainRecorder` / `ChainBufferHandle`: batched `linear` / `relu` / `softmax` recording with `read()` / `read_multiple()` for one submit+wait (then one or many downloads); `VulkanCompute.record_commands(fnn_chain=True)` or `VulkanFNN.chain_record()`. See `docs/PERF_DISPATCH.md`. - **`tests/parity/`** — Numerical parity tests for `grilly.functional` (numpy reference; optional PyTorch `F.linear` / `F.relu` when `torch` is installed). See `tests/parity/README.md`. 
- **`tests/parity/test_optimizers_parity.py`** — SGD and Adam (CPU) stepping vs `torch.optim`. - **`docs/MIGRATION_PYTORCH.md`** — PyTorch → Grilly migration cookbook (device model, functional layout, module backend lifecycle, debugging). @@ -17,6 +22,7 @@ This changelog follows the spirit of **Keep a Changelog** and uses the terms **A ### Changed +- **`docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md`** — Re-prioritized roadmap to feature-delivery order: GPU tokenizer -> sentence-transformers -> transformers compatibility -> PyTorch→Grilly converter, with supporting throughput track; expanded tokenizer interoperability to include SentencePiece compatibility targets. - **`utils/tensor_conversion.py`**, **`backend/base.py`** — `VulkanTensor.prepare_for_dispatch()` and C++ `Tensor.gpu_handle_if_valid()` binding so GPU-resident tensors reuse buffers in `BufferMixin._prepare_input` without redundant uploads when possible. - **`docs/PYTORCH_PARITY_TASKLIST.md`**, **`docs/PYTORCH_PARITY_STATUS.md`**, **`docs/GPU_OPTIMIZATION_REVIEW.md`** — Workstream C (kernel/throughput milestone) marked complete with scoped deliverables; follow-ups consolidated under **Workstream C — future**. - **`docs/api/functional.md`** — PyTorch parity notes and links to migration docs and parity tests. diff --git a/backend/compute.py b/backend/compute.py index 620d3a8..3dd4cf4 100644 --- a/backend/compute.py +++ b/backend/compute.py @@ -76,12 +76,17 @@ def __init__(self, shader_dir: str = None): self.shaders = self.core.shaders # Public Vulkan dispatch / batching (see docs/PERF_DISPATCH.md) - self.record_commands = self.core.record_commands self.dispatch_compute = self.core._dispatch_compute self.dispatch_compute_async = self.core._dispatch_compute_async self.wait_async = self.core._wait_async self.wait_fence = self.core._wait_fence + def record_commands(self, fnn_chain: bool = False): + """Batch command recording. 
Use ``fnn_chain=True`` for :class:`FnnChainRecorder` (linear/relu/softmax + ``read``).""" + if fnn_chain: + return self.fnn.chain_record() + return self.core.record_commands() + def cleanup(self): """Clean up Vulkan resources""" # Clear weight caches for all backend modules (before device is destroyed) diff --git a/backend/fnn.py b/backend/fnn.py index 52ef3df..50591c0 100644 --- a/backend/fnn.py +++ b/backend/fnn.py @@ -73,6 +73,12 @@ def __init__(self, core, pipelines, shaders): self.shaders = shaders self._pool = None # Lazy initialization + def chain_record(self): + """Return a :class:`~grilly.backend.fnn_chain.FnnChainRecorder` for batched linear/relu/softmax + single submit.""" + from .fnn_chain import FnnChainRecorder + + return FnnChainRecorder(self) + def gemm(self, A, B, return_gpu_tensor=False, cache_B=False, force_fp32=False): """ GEMM: C = A @ B diff --git a/docs/PERF_DISPATCH.md b/docs/PERF_DISPATCH.md index b9fb49c..4cbef0d 100644 --- a/docs/PERF_DISPATCH.md +++ b/docs/PERF_DISPATCH.md @@ -8,13 +8,19 @@ This document summarizes **non-blocking** and **batched** Vulkan dispatch APIs a - **`VulkanCore._dispatch_compute_async(...)`** — Submits work **without** waiting; pair with **`_wait_async()`** before reading GPU results. - **`VulkanCore.record_commands()`** — Returns a **`CommandRecorder`** context manager: multiple `dispatch` + `barrier` calls, then **one** submit + wait on `__exit__`. Used by FlashAttention2 tiling, RMSNorm two-pass, and **`VulkanFNN._linear_relu_recorded_chain`** (Linear→ReLU fallback when `fused-linear-relu.spv` is missing). +### FNN chain recorder (`record_commands(fnn_chain=True)`) + +`VulkanCompute.record_commands(fnn_chain=True)` returns **`FnnChainRecorder`** (`backend/fnn_chain.py`): high-level **`linear`**, **`relu`**, **`softmax`** methods that record dispatches only and return **`ChainBufferHandle`**. **`read(handle)`** performs **one** `submit_and_wait` + one download. 
**`read_multiple([h0, h1, ...])`** does **one** `submit_and_wait` then downloads each handle (MoE fan-out: many expert linears, one fence wait, many CPU reads). Use for deep MLP-style forwards to avoid one fence wait per layer. + +Equivalent: **`VulkanFNN.chain_record()`**. + ## Public aliases on `VulkanCompute` (`backend/compute.py`) After `Compute()` / `VulkanCompute()` construction: | Attribute | Underlying API | |-----------|----------------| -| `record_commands` | `core.record_commands()` | +| `record_commands(fnn_chain=False)` | `core.record_commands()` or `fnn.chain_record()` when `fnn_chain=True` | | `dispatch_compute` | `core._dispatch_compute` | | `dispatch_compute_async` | `core._dispatch_compute_async` | | `wait_async` | `core._wait_async` | diff --git a/docs/index.md b/docs/index.md index 638b17c..7698582 100644 --- a/docs/index.md +++ b/docs/index.md @@ -102,6 +102,10 @@ grilly/ | [optimum-grilly](https://github.com/grillcheese-ai/optimum-grilly) | HuggingFace Optimum backend -- `from_pretrained` to Vulkan inference | | [CubeMind](https://github.com/grillcheese-ai/cubemind) | Neuro-vector-symbolic reasoning powered by grilly | +## Pre-v1.0 Planning + +- [Pre-v1.0 Optimization + Parity Tasklist](pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md) + ## License MIT -- see [LICENSE](https://github.com/grillcheese-ai/grilly/blob/main/LICENSE). 
From 13b7ac3bc7fea40080bce75583b43ba38e80af07 Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Mon, 6 Apr 2026 23:13:59 -0400 Subject: [PATCH 16/17] pre v1.0 --- CMakeLists.txt | 6 + __init__.py | 23 +- backend/__init__.py | 8 +- backend/_bridge.py | 336 ++++++++ backend/attention.py | 4 +- backend/base.py | 39 +- backend/buffer_pool.py | 4 +- backend/cells.py | 4 +- backend/contrastive.py | 4 +- backend/conv.py | 4 +- backend/core.py | 15 +- backend/faiss.py | 6 +- backend/fft.py | 4 +- backend/fnn.py | 4 +- backend/fnn_chain.py | 281 +++++++ backend/learning.py | 6 +- backend/lora.py | 8 +- backend/normalization.py | 4 +- backend/pipelines.py | 4 +- backend/pooling.py | 4 +- backend/snn.py | 4 +- backend/snn_compute.py | 4 +- backend/tensor_ops.py | 4 +- cpp/include/grilly/io/grl_checkpoint.h | 25 + cpp/include/grilly/ops/activations.h | 15 + cpp/include/grilly/ops/batched_ops.h | 9 + cpp/include/grilly/ops/moe_forward.h | 104 +++ cpp/include/grilly/ops/vsa_lm_forward.h | 121 +++ cpp/include/grilly/vulkan/vk_buffer_pool.h | 21 + cpp/include/grilly/vulkan/vk_command_batch.h | 6 + cpp/python/bindings_activations.cpp | 49 ++ cpp/python/bindings_core.cpp | 3 + cpp/python/bindings_core.h | 16 + cpp/python/bindings_grl.cpp | 44 + cpp/python/bindings_moe.cpp | 253 ++++++ cpp/python/bindings_normalization.cpp | 32 + cpp/python/bindings_vsa_lm.cpp | 305 +++++++ cpp/src/buffer_pool.cpp | 101 ++- cpp/src/command_batch.cpp | 26 + cpp/src/io/grl_checkpoint.cpp | 110 +++ cpp/src/ops/activations.cpp | 177 ++++- cpp/src/ops/batched_ops.cpp | 28 + cpp/src/ops/embedding.cpp | 40 +- cpp/src/ops/layernorm.cpp | 193 +++-- cpp/src/ops/linear.cpp | 255 ++++-- cpp/src/ops/loss.cpp | 88 +- cpp/src/ops/moe_forward.cpp | 752 ++++++++++++++++++ cpp/src/ops/optimizer.cpp | 128 ++- cpp/src/ops/vsa_lm_forward.cpp | 656 +++++++++++++++ docs/grl_v1_format.md | 24 + docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md | 210 +++++ examples/experimental_backend_vsa.py | 4 +- functional/__init__.py | 12 + 
functional/mf_activations.py | 65 ++ nn/__init__.py | 21 + nn/addition_linear.py | 150 ++++ nn/autograd.py | 114 ++- nn/embedding.py | 79 +- nn/functional.py | 22 + nn/init.py | 21 + nn/linear.py | 103 ++- nn/module.py | 216 ++++- nn/module_list.py | 48 ++ nn/normalization_modules.py | 80 +- nn/parameter.py | 41 + nn/utils.py | 61 ++ nn/vsa_lm.py | 386 +++++++++ pyproject.toml | 14 + shaders/mf-sigmoid.glsl | 25 + shaders/mf-softmax.glsl | 86 ++ shaders/mf-softplus.glsl | 26 + shaders/moe-layer-backward-vec4.glsl | 87 ++ shaders/moe-layer-backward.glsl | 110 +++ shaders/moe-layer-fused-vec4.glsl | 106 +++ shaders/moe-layer-fused.glsl | 104 +++ shaders/moe-layer-grad-weight.glsl | 85 ++ shaders/moe-router.glsl | 84 ++ shaders/sign-activation.glsl | 24 + shaders/softmax-fast.glsl | 77 ++ shaders/spv/mf-sigmoid.spv | Bin 0 -> 1580 bytes shaders/spv/mf-softmax.spv | Bin 0 -> 6684 bytes shaders/spv/mf-softplus.spv | Bin 0 -> 1744 bytes shaders/spv/moe-layer-backward-vec4.spv | Bin 0 -> 10392 bytes shaders/spv/moe-layer-backward.spv | Bin 0 -> 10264 bytes shaders/spv/moe-layer-fused-vec4.spv | Bin 0 -> 10864 bytes shaders/spv/moe-layer-fused.spv | Bin 0 -> 9660 bytes shaders/spv/moe-layer-grad-weight.spv | Bin 0 -> 6696 bytes shaders/spv/moe-router.spv | Bin 0 -> 5816 bytes shaders/spv/sign-activation.spv | Bin 0 -> 1520 bytes shaders/spv/softmax-fast.spv | Bin 0 -> 4764 bytes tests/_tokenizer_parity_helpers.py | 68 ++ .../test_autograd_chain_placeholder.py | 13 + tests/conftest.py | 12 +- .../test_pytorch_converter_placeholder.py | 13 + tests/moe_quant/test_moe_quant_placeholder.py | 13 + .../test_sentence_transformers_parity.py | 79 ++ .../test_sentencepiece_parity.py | 50 ++ tests/test_attention.py | 4 +- tests/test_attention_long_sequence.py | 6 +- tests/test_backward.py | 8 +- tests/test_bridge_moe.py | 14 + tests/test_bridge_vsa_lm.py | 14 + tests/test_conv_backward_weight_gemm.py | 4 +- tests/test_core.py | 7 +- tests/test_fnn_chain.py | 75 ++ 
tests/test_functional.py | 6 +- tests/test_gpu_operations.py | 8 +- tests/test_grl_checkpoint.py | 61 ++ tests/test_inference_ops.py | 10 +- tests/test_integration.py | 4 +- tests/test_learning.py | 4 +- tests/test_memory_operations.py | 4 +- tests/test_mf_activations.py | 50 ++ tests/test_mf_ops_core.py | 69 ++ tests/test_moe_forward.py | 158 ++++ tests/test_snn.py | 4 +- tests/test_torch_api.py | 50 ++ tests/test_vsa_lm_forward.py | 282 +++++++ tests/test_vulkan_tensor_residency.py | 4 +- tests/tokenizers/test_gpu_tokenizer_parity.py | 54 ++ .../test_transformers_compat_placeholder.py | 13 + tokenizer_impl/__init__.py | 32 + tokenizer_impl/auto.py | 20 + tokenizer_impl/base.py | 42 + tokenizer_impl/fast_tokenizer.py | 110 +++ tokenizer_impl/gpu.py | 61 ++ tokenizer_impl/loader.py | 57 ++ torch_api/__init__.py | 101 +++ torch_api/amp_mod.py | 58 ++ torch_api/dtypes_and_device.py | 87 ++ torch_api/functional.py | 47 ++ torch_api/ops.py | 111 +++ torch_api/serialization.py | 49 ++ torch_api/tensor.py | 89 +++ torch_api/vulkan_mod.py | 15 + torch_api_example.py | 311 ++++++++ utils/__init__.py | 3 + utils/grl_checkpoint.py | 252 ++++++ utils/initialization.py | 30 + uv.lock | 71 +- 140 files changed, 9097 insertions(+), 432 deletions(-) create mode 100644 backend/fnn_chain.py create mode 100644 cpp/include/grilly/io/grl_checkpoint.h create mode 100644 cpp/include/grilly/ops/moe_forward.h create mode 100644 cpp/include/grilly/ops/vsa_lm_forward.h create mode 100644 cpp/python/bindings_grl.cpp create mode 100644 cpp/python/bindings_moe.cpp create mode 100644 cpp/python/bindings_vsa_lm.cpp create mode 100644 cpp/src/io/grl_checkpoint.cpp create mode 100644 cpp/src/ops/moe_forward.cpp create mode 100644 cpp/src/ops/vsa_lm_forward.cpp create mode 100644 docs/grl_v1_format.md create mode 100644 docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md create mode 100644 functional/mf_activations.py create mode 100644 nn/addition_linear.py create mode 100644 nn/functional.py create 
mode 100644 nn/init.py create mode 100644 nn/module_list.py create mode 100644 nn/utils.py create mode 100644 nn/vsa_lm.py create mode 100644 shaders/mf-sigmoid.glsl create mode 100644 shaders/mf-softmax.glsl create mode 100644 shaders/mf-softplus.glsl create mode 100644 shaders/moe-layer-backward-vec4.glsl create mode 100644 shaders/moe-layer-backward.glsl create mode 100644 shaders/moe-layer-fused-vec4.glsl create mode 100644 shaders/moe-layer-fused.glsl create mode 100644 shaders/moe-layer-grad-weight.glsl create mode 100644 shaders/moe-router.glsl create mode 100644 shaders/sign-activation.glsl create mode 100644 shaders/softmax-fast.glsl create mode 100644 shaders/spv/mf-sigmoid.spv create mode 100644 shaders/spv/mf-softmax.spv create mode 100644 shaders/spv/mf-softplus.spv create mode 100644 shaders/spv/moe-layer-backward-vec4.spv create mode 100644 shaders/spv/moe-layer-backward.spv create mode 100644 shaders/spv/moe-layer-fused-vec4.spv create mode 100644 shaders/spv/moe-layer-fused.spv create mode 100644 shaders/spv/moe-layer-grad-weight.spv create mode 100644 shaders/spv/moe-router.spv create mode 100644 shaders/spv/sign-activation.spv create mode 100644 shaders/spv/softmax-fast.spv create mode 100644 tests/_tokenizer_parity_helpers.py create mode 100644 tests/autograd_chain/test_autograd_chain_placeholder.py create mode 100644 tests/converter/test_pytorch_converter_placeholder.py create mode 100644 tests/moe_quant/test_moe_quant_placeholder.py create mode 100644 tests/sentence_transformers/test_sentence_transformers_parity.py create mode 100644 tests/sentencepiece/test_sentencepiece_parity.py create mode 100644 tests/test_bridge_moe.py create mode 100644 tests/test_bridge_vsa_lm.py create mode 100644 tests/test_fnn_chain.py create mode 100644 tests/test_grl_checkpoint.py create mode 100644 tests/test_mf_activations.py create mode 100644 tests/test_mf_ops_core.py create mode 100644 tests/test_moe_forward.py create mode 100644 tests/test_torch_api.py 
create mode 100644 tests/test_vsa_lm_forward.py create mode 100644 tests/tokenizers/test_gpu_tokenizer_parity.py create mode 100644 tests/transformers_compat/test_transformers_compat_placeholder.py create mode 100644 tokenizer_impl/__init__.py create mode 100644 tokenizer_impl/auto.py create mode 100644 tokenizer_impl/base.py create mode 100644 tokenizer_impl/fast_tokenizer.py create mode 100644 tokenizer_impl/gpu.py create mode 100644 tokenizer_impl/loader.py create mode 100644 torch_api/__init__.py create mode 100644 torch_api/amp_mod.py create mode 100644 torch_api/dtypes_and_device.py create mode 100644 torch_api/functional.py create mode 100644 torch_api/ops.py create mode 100644 torch_api/serialization.py create mode 100644 torch_api/tensor.py create mode 100644 torch_api/vulkan_mod.py create mode 100644 torch_api_example.py create mode 100644 utils/grl_checkpoint.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b038fe..dc2a917 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -218,6 +218,8 @@ add_library(grilly_core_lib STATIC cpp/src/ops/perceiver.cpp cpp/src/ops/perceiver_encoder.cpp cpp/src/ops/moqe_train.cpp + cpp/src/ops/moe_forward.cpp + cpp/src/ops/vsa_lm_forward.cpp cpp/src/shader_fusion.cpp # ── Experimental ── cpp/src/experimental/paged_latent_pool.cpp @@ -247,6 +249,7 @@ add_library(grilly_core_lib STATIC cpp/src/nn/containers.cpp cpp/src/nn/optimizer.cpp cpp/src/nn/dataloader.cpp + cpp/src/io/grl_checkpoint.cpp ) target_include_directories(grilly_core_lib PUBLIC @@ -303,7 +306,10 @@ pybind11_add_module(grilly_core cpp/python/bindings_siglip.cpp cpp/python/bindings_perceiver.cpp cpp/python/bindings_moqe_train.cpp + cpp/python/bindings_moe.cpp cpp/python/bindings_fusion.cpp + cpp/python/bindings_vsa_lm.cpp + cpp/python/bindings_grl.cpp ) target_link_libraries(grilly_core PRIVATE grilly_core_lib) diff --git a/__init__.py b/__init__.py index cad6c95..3e0fa0b 100644 --- a/__init__.py +++ b/__init__.py @@ -19,11 +19,29 @@ - Hippocampal 
transformer with capsule memory """ +# grilly_core..pyd lives in this directory as a sibling .pyd. Editable +# (PEP 660) installs route `import grilly` through a path hook that does NOT +# add this directory to sys.path, so `import grilly_core` would fail and the +# Vulkan probe in backend/base.py would silently report VULKAN_AVAILABLE=False +# even on machines with a perfectly working Vulkan device. Insert the package +# directory once, before any submodule import triggers the probe. +import os as _os +import sys as _sys +_pkg_dir = _os.path.dirname(_os.path.abspath(__file__)) +if _pkg_dir not in _sys.path: + _sys.path.insert(0, _pkg_dir) + import grilly.functional as functional import grilly.nn as nn import grilly.optim as optim import grilly.utils as utils -from grilly.backend.base import VULKAN_AVAILABLE + +from . import torch_api +from grilly.backend.base import ( + VULKAN_AVAILABLE, + VULKAN_PYTHON_BINDINGS_AVAILABLE, + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, +) from grilly.backend.capsule_transformer import ( CapsuleMemory, CapsuleTransformerConfig, @@ -65,6 +83,8 @@ __all__ = [ "VULKAN_AVAILABLE", + "VULKAN_PYTHON_BINDINGS_AVAILABLE", + "VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE", "VulkanCompute", "Compute", "SNNCompute", @@ -79,6 +99,7 @@ "functional", "optim", "utils", + "torch_api", ] # Conditionally add compatibility exports diff --git a/backend/__init__.py b/backend/__init__.py index 814dd4f..b2cb4ab 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -12,7 +12,11 @@ - Bridge operations (continuous ↔ spike) """ -from .base import VULKAN_AVAILABLE +from .base import ( + VULKAN_AVAILABLE, + VULKAN_PYTHON_BINDINGS_AVAILABLE, + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, +) from .capsule_transformer import ( CapsuleMemory, CapsuleTransformerConfig, @@ -26,6 +30,8 @@ __all__ = [ "VULKAN_AVAILABLE", + "VULKAN_PYTHON_BINDINGS_AVAILABLE", + "VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE", "VulkanCompute", "SNNCompute", "VulkanLearning", diff --git 
a/backend/_bridge.py b/backend/_bridge.py index 24c69df..9b74d53 100644 --- a/backend/_bridge.py +++ b/backend/_bridge.py @@ -11,6 +11,10 @@ result = _bridge.linear(x, weight, bias) if result is not None: return result # else fall through to legacy backend + +Fused MoE (:func:`moe_upload`, :func:`moe_forward`, :func:`moe_backward`, etc.) +uses the same lazy :func:`_get_device` and ``shaders/spv`` load as other bridge +ops—prefer these over calling ``grilly_core.moe_*`` with a separate Device. """ import logging @@ -698,6 +702,42 @@ def softmax_backward(grad_output, softmax_output): return None +def mf_softmax(x, dim=-1): + """GPU multiplication-free softmax (ReLU-normalized). Returns None on failure.""" + dev = _get_device() + if dev is None: + return None + try: + return _core.mf_softmax(dev, _ensure_f32_contiguous(x), dim) + except Exception as e: + _record_fallback("mf_softmax", e) + return None + + +def mf_softplus(x, beta=1.0): + """GPU algebraic softplus (sqrt form). Returns None on failure.""" + dev = _get_device() + if dev is None: + return None + try: + return _core.mf_softplus(dev, _ensure_f32_contiguous(x), float(beta)) + except Exception as e: + _record_fallback("mf_softplus", e) + return None + + +def mf_sigmoid(x): + """GPU rational sigmoid x/(1+|x|). Returns None on failure.""" + dev = _get_device() + if dev is None: + return None + try: + return _core.mf_sigmoid(dev, _ensure_f32_contiguous(x)) + except Exception as e: + _record_fallback("mf_sigmoid", e) + return None + + # ── Linear Backward + Dropout ──────────────────────────────────────────── @@ -1741,6 +1781,302 @@ def moqe_route_and_gemv(activations, choice, expert_weights, expert_scales, bloc ) +# ── Fused MoE (grilly_core moe_* — same lazy device + shaders as other bridge ops) ── + + +def moe_upload(embed_w, pos_w, expert_ws, router_ws, router_bs, out_w, n_layers, n_experts): + """Upload MoE weights to GPU; returns an opaque integer handle. 
+ + Uses the shared bridge :func:`_get_device` (lazy init + ``shaders/spv`` load). + + Args: + embed_w: float32 (vocab, d_model), C-contiguous. + pos_w: float32 (max_seq, d_model). + expert_ws: list of length ``n_layers * n_experts``, each (d_model, d_model). + router_ws: list of length ``n_layers``, each (n_experts, d_model). + router_bs: list of length ``n_layers``, each (n_experts,). + out_w: float32 (vocab, d_model) — output projection. + n_layers, n_experts: int. + + Returns: + int handle, or ``None`` if the C++ extension / device is unavailable. + """ + if not _NATIVE or not hasattr(_core, "moe_upload"): + return None + dev = _get_device() + if dev is None: + return None + try: + return _core.moe_upload( + dev, + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + expert_ws, + router_ws, + router_bs, + _ensure_f32_contiguous(out_w), + int(n_layers), + int(n_experts), + ) + except Exception as e: + _record_fallback("moe_upload", e) + return None + + +def moe_release(handle): + """Free GPU resources for a handle from :func:`moe_upload`. + + Returns: + ``True`` if released, ``False`` if bridge/device unavailable. + """ + if not _NATIVE or not hasattr(_core, "moe_release"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.moe_release(dev, int(handle)) + return True + except Exception as e: + _record_fallback("moe_release", e) + return False + + +def moe_forward(handle, input_ids): + """Run fused MoE forward; returns logits (seq_len, vocab). + + Args: + handle: int from :func:`moe_upload`. + input_ids: int32 1-D (seq_len,), C-contiguous. + + Returns: + float32 ndarray or ``None`` on failure. 
+ """ + if not _NATIVE or not hasattr(_core, "moe_forward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + try: + return _core.moe_forward(dev, int(handle), ids) + except Exception as e: + _record_fallback("moe_forward", e) + return None + + +def moe_update_weights(handle, embed_w, pos_w, expert_ws, router_ws, router_bs, out_w): + """Re-upload weights in place (same shapes as :func:`moe_upload`). + + Returns: + ``True`` on success, ``False`` on failure. + """ + if not _NATIVE or not hasattr(_core, "moe_update_weights"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.moe_update_weights( + dev, + int(handle), + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + expert_ws, + router_ws, + router_bs, + _ensure_f32_contiguous(out_w), + ) + return True + except Exception as e: + _record_fallback("moe_update_weights", e) + return False + + +def moe_backward(handle, input_ids, grad_logits): + """CPU backward for MoE (gradients dict); uses CPU mirrors from upload. + + Args: + handle: int from :func:`moe_upload`. + input_ids: int32 (seq_len,) — same sequence as forward. + grad_logits: float32 (seq_len, vocab). + + Returns: + dict with keys ``grad_embed``, ``grad_pos``, ``grad_experts``, ``grad_routers_W``, + ``grad_routers_b``, ``grad_out_w``, or ``None`` on failure. 
+ """ + if not _NATIVE or not hasattr(_core, "moe_backward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + g = _ensure_f32_contiguous(grad_logits) + try: + return _core.moe_backward(dev, int(handle), ids, g) + except Exception as e: + _record_fallback("moe_backward", e) + return None + + +# ── Fused VSA-LM (grilly_core vsa_lm_* — AdditionLinear FFN + sign activation) ── + + +def vsa_lm_upload( + embed_w, pos_w, + ffn_up_patterns, ffn_up_biases, ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, out_w, + n_layers, d_model, d_ffn, +): + """Upload VSA-LM weights to GPU; returns an opaque integer handle. + + Uses the shared bridge :func:`_get_device` (lazy init + ``shaders/spv`` load). + + Args: + embed_w: float32 (vocab, d_model). + pos_w: float32 (max_seq, d_model). + ffn_up_patterns: list of (d_ffn, d_model) float32, length n_layers. + ffn_up_biases: list of (d_ffn,) float32, length n_layers. + ffn_down_patterns: list of (d_model, d_ffn) float32, length n_layers. + ffn_down_biases: list of (d_model,) float32, length n_layers. + ln_gammas: list of (d_model,) float32, length n_layers. + ln_betas: list of (d_model,) float32, length n_layers. + out_w: float32 (vocab, d_model). + n_layers, d_model, d_ffn: int. + + Returns: + int handle, or ``None`` if the C++ extension / device is unavailable. 
+ """ + if not _NATIVE or not hasattr(_core, "vsa_lm_upload"): + return None + dev = _get_device() + if dev is None: + return None + try: + return _core.vsa_lm_upload( + dev, + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + ffn_up_patterns, ffn_up_biases, + ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, + _ensure_f32_contiguous(out_w), + int(n_layers), int(d_model), int(d_ffn), + ) + except Exception as e: + _record_fallback("vsa_lm_upload", e) + return None + + +def vsa_lm_release(handle): + """Free GPU resources for a handle from :func:`vsa_lm_upload`. + + Returns: + ``True`` if released, ``False`` if bridge/device unavailable. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_release"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.vsa_lm_release(dev, int(handle)) + return True + except Exception as e: + _record_fallback("vsa_lm_release", e) + return False + + +def vsa_lm_forward(handle, input_ids): + """Run fused VSA-LM forward; returns logits (seq_len, vocab). + + Args: + handle: int from :func:`vsa_lm_upload`. + input_ids: int32 1-D (seq_len,), C-contiguous. + + Returns: + float32 ndarray or ``None`` on failure. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_forward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + try: + return _core.vsa_lm_forward(dev, int(handle), ids) + except Exception as e: + _record_fallback("vsa_lm_forward", e) + return None + + +def vsa_lm_backward(handle, input_ids, grad_logits): + """CPU backward for VSA-LM (gradients dict); uses CPU mirrors from upload. + + Args: + handle: int from :func:`vsa_lm_upload`. + input_ids: int32 (seq_len,) — same sequence as forward. + grad_logits: float32 (seq_len, vocab). 
+ + Returns: + dict with keys ``grad_embed``, ``grad_pos``, ``grad_out_w``, + ``grad_ffn_up_w``, ``grad_ffn_up_b``, ``grad_ffn_down_w``, + ``grad_ffn_down_b``, ``grad_ln_gamma``, ``grad_ln_beta``, + or ``None`` on failure. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_backward"): + return None + dev = _get_device() + if dev is None: + return None + ids = np.asarray(input_ids, dtype=np.int32) + if not ids.flags.c_contiguous: + ids = np.ascontiguousarray(ids, dtype=np.int32) + g = _ensure_f32_contiguous(grad_logits) + try: + return _core.vsa_lm_backward(dev, int(handle), ids, g) + except Exception as e: + _record_fallback("vsa_lm_backward", e) + return None + + +def vsa_lm_update_weights( + handle, embed_w, pos_w, + ffn_up_patterns, ffn_up_biases, ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, out_w, +): + """Re-upload VSA-LM weights in place (same shapes as :func:`vsa_lm_upload`). + + Returns: + ``True`` on success, ``False`` on failure. + """ + if not _NATIVE or not hasattr(_core, "vsa_lm_update_weights"): + return False + dev = _get_device() + if dev is None: + return False + try: + _core.vsa_lm_update_weights( + dev, int(handle), + _ensure_f32_contiguous(embed_w), + _ensure_f32_contiguous(pos_w), + ffn_up_patterns, ffn_up_biases, + ffn_down_patterns, ffn_down_biases, + ln_gammas, ln_betas, + _ensure_f32_contiguous(out_w), + ) + return True + except Exception as e: + _record_fallback("vsa_lm_update_weights", e) + return False + + # ── Bridge support for functional modules without native kernels ────────── diff --git a/backend/attention.py b/backend/attention.py index 7f4cff3..176dea5 100644 --- a/backend/attention.py +++ b/backend/attention.py @@ -8,7 +8,7 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE from .shader_registry import get_shader logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ numba_prosody_modulation = None numba_attention_output = None 
-if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/base.py b/backend/base.py index 630b8fa..319c746 100644 --- a/backend/base.py +++ b/backend/base.py @@ -1,5 +1,10 @@ """ Base constants and utilities for Vulkan backend. + +``VULKAN_AVAILABLE`` is True when ``grilly_core.Device()`` can initialize (C++ Vulkan; +does not use the PyPI ``vulkan`` package). Optional ctypes bindings set +``VULKAN_PYTHON_BINDINGS_AVAILABLE``; the legacy Python ``VulkanCore`` / ``Compute()`` stack +requires both (see ``VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE``). """ import numpy as np @@ -40,15 +45,30 @@ "VK_PIPELINE_BIND_POINT_COMPUTE": 0, } +def _probe_cpp_vulkan() -> bool: + """True when ``grilly_core`` can initialize a Vulkan device (C++ path; no PyPI ``vulkan``).""" + try: + import grilly_core as gc + + gc.Device() + return True + except Exception: + return False + + +VULKAN_AVAILABLE = _probe_cpp_vulkan() + +# Optional ctypes bindings (``pip install vulkan``) — used by legacy VulkanCore / VulkanCompute. try: from vulkan import * - VULKAN_AVAILABLE = True - # After 'from vulkan import *', all Vulkan constants are in the namespace - # and can be imported by other modules using 'from base import VK_...' + VULKAN_PYTHON_BINDINGS_AVAILABLE = True + # After 'from vulkan import *', constants live in this module for 'from base import VK_...' except ImportError: - VULKAN_AVAILABLE = False - pass + VULKAN_PYTHON_BINDINGS_AVAILABLE = False + +# Full Python VulkanCore stack (ctypes + SPIR-V loaders) — not required for ``grilly_core``-only GPU. +VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE = VULKAN_AVAILABLE and VULKAN_PYTHON_BINDINGS_AVAILABLE # Create dummy constants for type checking when Vulkan is not available # or when a mocked `vulkan` module does not define them (e.g. RTD autodoc). 
@@ -417,4 +437,11 @@ def _prepare_output(self, size): return self._acquire_buffer(size) -__all__ = ["VULKAN_AVAILABLE", "BufferMixin", "_DirectBuffer", "BUFFER_POOL_AVAILABLE"] +__all__ = [ + "VULKAN_AVAILABLE", + "VULKAN_PYTHON_BINDINGS_AVAILABLE", + "VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE", + "BufferMixin", + "_DirectBuffer", + "BUFFER_POOL_AVAILABLE", +] diff --git a/backend/buffer_pool.py b/backend/buffer_pool.py index 6478d9b..5e11076 100644 --- a/backend/buffer_pool.py +++ b/backend/buffer_pool.py @@ -30,7 +30,7 @@ import numpy as np -from .base import VULKAN_AVAILABLE +from .base import VULKAN_PYTHON_BINDINGS_AVAILABLE if TYPE_CHECKING: from .core import VulkanCore @@ -53,7 +53,7 @@ except ImportError: logger.debug("PyVMA2 not available - using direct Vulkan allocation") -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import ( VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_SHARING_MODE_EXCLUSIVE, diff --git a/backend/cells.py b/backend/cells.py index 0f567f5..7f1a349 100644 --- a/backend/cells.py +++ b/backend/cells.py @@ -7,9 +7,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/contrastive.py b/backend/contrastive.py index 08e2f97..6854a52 100644 --- a/backend/contrastive.py +++ b/backend/contrastive.py @@ -12,9 +12,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/conv.py b/backend/conv.py index 4a3377b..3e72446 100644 --- a/backend/conv.py +++ b/backend/conv.py @@ -12,9 +12,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if 
VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/core.py b/backend/core.py index 97e0ea9..515a0e3 100644 --- a/backend/core.py +++ b/backend/core.py @@ -10,9 +10,13 @@ import numpy as np -from .base import VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VULKAN_AVAILABLE +from .base import ( + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, + VULKAN_AVAILABLE, + VULKAN_PYTHON_BINDINGS_AVAILABLE, +) -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * logger = logging.getLogger(__name__) @@ -32,7 +36,12 @@ def __init__(self, shader_dir: str = None): """Initialize the instance.""" if not VULKAN_AVAILABLE: - raise RuntimeError("Vulkan not available") + raise RuntimeError("Vulkan not available (C++ grilly_core could not initialize GPU)") + if not VULKAN_PYTHON_BINDINGS_AVAILABLE: + raise RuntimeError( + "VulkanCore requires the Python `vulkan` package (ctypes bindings). " + "Install with: pip install vulkan" + ) import os # Disable Mesa device_select layer which can force CPU llvmpipe diff --git a/backend/faiss.py b/backend/faiss.py index b156f89..3d6e903 100644 --- a/backend/faiss.py +++ b/backend/faiss.py @@ -8,9 +8,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * logger = logging.getLogger(__name__) @@ -132,7 +132,7 @@ def topk(self, distances, k): k = min(int(k), num_database) # Fallback to CPU when Vulkan is unavailable - if not VULKAN_AVAILABLE or self.core is None: + if not VULKAN_PYTHON_BINDINGS_AVAILABLE or self.core is None: return self._cpu_topk(distances, k) # GPU attempt with validation fallback diff --git a/backend/fft.py b/backend/fft.py index a2d4d11..805eed4 100644 --- a/backend/fft.py +++ b/backend/fft.py @@ -16,9 +16,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE 
-if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/fnn.py b/backend/fnn.py index 50591c0..d28461a 100644 --- a/backend/fnn.py +++ b/backend/fnn.py @@ -12,9 +12,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * # Try to import numba-accelerated operations for CPU fallback diff --git a/backend/fnn_chain.py b/backend/fnn_chain.py new file mode 100644 index 0000000..2918f5f --- /dev/null +++ b/backend/fnn_chain.py @@ -0,0 +1,281 @@ +""" +Batched FNN recording: multiple linear/relu/softmax dispatches, one submit + one fence wait. + +Use ``VulkanCompute.record_commands(fnn_chain=True)`` or ``VulkanFNN.chain_record()``. +""" + +from __future__ import annotations + +import struct +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Union + +import numpy as np + +from .core import CommandRecorder + +if TYPE_CHECKING: + from .fnn import VulkanFNN + + +@dataclass(frozen=True) +class ChainBufferHandle: + """GPU buffer produced inside a chain; pass to the next ``rec.*`` call or ``rec.read()``.""" + + buf: Any + nbytes: int + shape: tuple[int, ...] + + @property + def ndim(self) -> int: + return len(self.shape) + + +class FnnChainRecorder(CommandRecorder): + """Records linear / ReLU / softmax into the batch command buffer without per-op submit. + + Buffers acquired for numpy inputs and intermediate outputs are released when + ``read()`` runs or when the context exits. 
+ """ + + def __init__(self, fnn: VulkanFNN): + super().__init__(fnn.core) + self._fnn = fnn + self._owned: list[Any] = [] + self._submitted = False + self._released = False + + def _track(self, buf: Any) -> None: + if buf is not None: + self._owned.append(buf) + + def _release_owned(self) -> None: + if self._released: + return + for buf in self._owned: + try: + self._fnn._release_buffer(buf) + except Exception: + pass + self._owned.clear() + self._released = True + + def linear( + self, + x: Union[np.ndarray, ChainBufferHandle, Any], + weights: np.ndarray, + bias: np.ndarray | None = None, + ) -> ChainBufferHandle: + """Record ``fnn-linear`` dispatch. Returns a handle to the output buffer (no download).""" + if not self._recording: + raise RuntimeError("linear() must be called inside a chain_record context") + fnn = self._fnn + if "fnn-linear" not in fnn.shaders: + raise RuntimeError("fnn-linear shader not available") + + from ..utils.tensor_conversion import VulkanTensor + + w_np = np.ascontiguousarray(weights, dtype=np.float32) + output_dim, input_dim = w_np.shape + + if isinstance(x, ChainBufferHandle): + buf_input = x.buf + original_shape = x.shape + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + else: + batch_seq = original_shape[0] + if original_shape[-1] != input_dim: + raise ValueError( + f"ChainBufferHandle last dim {original_shape[-1]} != weight input_dim {input_dim}" + ) + input_nbytes = batch_seq * input_dim * 4 + elif isinstance(x, VulkanTensor): + original_shape = x.shape + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + else: + batch_seq = original_shape[0] + input_nbytes = batch_seq * input_dim * 4 + buf_input, release_input = fnn._prepare_input(x, size=input_nbytes) + if release_input: + self._track(buf_input) + else: + x_np = np.asarray(x, dtype=np.float32) + original_shape = x_np.shape + if len(original_shape) > 2: + batch_seq = int(np.prod(original_shape[:-1])) + x_2d = 
x_np.reshape(batch_seq, input_dim) + else: + batch_seq = original_shape[0] + x_2d = x_np + x_flat = np.ascontiguousarray(x_2d).ravel() + input_nbytes = x_flat.nbytes + buf_input = fnn._acquire_buffer(input_nbytes) + fnn._upload_buffer(buf_input, x_flat) + self._track(buf_input) + release_input = True + + w_nbytes = int(w_np.size) * 4 + output_size = batch_seq * output_dim * 4 + + buf_weights, _ = fnn._get_or_upload_weight(w_np) + has_bias = 1 if bias is not None else 0 + if bias is not None: + bias_np = np.ascontiguousarray(bias, dtype=np.float32) + buf_bias, _ = fnn._get_or_upload_weight(bias_np) + bias_flat = bias_np.ravel() + bias_nbytes = bias_flat.nbytes + else: + buf_bias = fnn._acquire_buffer(4) + fnn._upload_buffer(buf_bias, np.zeros(1, dtype=np.float32)) + self._track(buf_bias) + bias_flat = None + bias_nbytes = 4 + + buf_output = fnn._acquire_buffer(output_size) + self._track(buf_output) + + pipeline, pipeline_layout, _ = fnn.pipelines.get_or_create_pipeline( + "fnn-linear", 4, push_constant_size=16 + ) + descriptor_set = fnn.pipelines.get_cached_descriptor_set( + "fnn-linear", + [ + (fnn._get_buffer_handle(buf_input), input_nbytes), + (fnn._get_buffer_handle(buf_weights), w_nbytes), + ( + fnn._get_buffer_handle(buf_bias), + bias_nbytes, + ), + (fnn._get_buffer_handle(buf_output), output_size), + ], + ) + push = struct.pack("IIII", batch_seq, input_dim, output_dim, has_bias) + gx = (output_dim + 15) // 16 + gy = (batch_seq + 15) // 16 + self.dispatch(pipeline, pipeline_layout, descriptor_set, (gx, gy), push) + self.barrier() + + if len(original_shape) > 2: + out_shape = original_shape[:-1] + (output_dim,) + else: + out_shape = (batch_seq, output_dim) + + return ChainBufferHandle(buf_output, output_size, out_shape) + + def relu(self, h: ChainBufferHandle) -> ChainBufferHandle: + """Record ``activation-relu`` dispatch.""" + if not self._recording: + raise RuntimeError("relu() must be called inside a chain_record context") + fnn = self._fnn + if 
"activation-relu" not in fnn.shaders: + raise RuntimeError("activation-relu shader not available") + + total_elements = h.nbytes // 4 + buf_out = fnn._acquire_buffer(h.nbytes) + self._track(buf_out) + + pipeline, pipeline_layout, _ = fnn.pipelines.get_or_create_pipeline( + "activation-relu", 2, push_constant_size=4 + ) + descriptor_set = fnn.pipelines.get_cached_descriptor_set( + "activation-relu", + [ + (fnn._get_buffer_handle(h.buf), h.nbytes), + (fnn._get_buffer_handle(buf_out), h.nbytes), + ], + ) + push = struct.pack("I", total_elements) + wg = (total_elements + 255) // 256 + self.dispatch(pipeline, pipeline_layout, descriptor_set, (wg,), push) + self.barrier() + + return ChainBufferHandle(buf_out, h.nbytes, h.shape) + + def softmax(self, h: ChainBufferHandle, dim: int = -1) -> ChainBufferHandle: + """Record ``activation-softmax`` (3 passes). ``dim`` must be the last axis (-1).""" + if not self._recording: + raise RuntimeError("softmax() must be called inside a chain_record context") + if dim not in (-1, len(h.shape) - 1): + raise NotImplementedError( + "FnnChainRecorder.softmax only supports softmax over the last dimension (dim=-1)" + ) + fnn = self._fnn + if "activation-softmax" not in fnn.shaders: + raise RuntimeError("activation-softmax shader not available") + + shape = h.shape + if len(shape) == 1: + batch_size, seq_len, features = 1, 1, shape[0] + elif len(shape) == 2: + batch_size, seq_len, features = shape[0], 1, shape[1] + else: + batch_size, seq_len, features = shape[0], shape[1], int(np.prod(shape[2:])) + + data_nbytes = h.nbytes + buf_max = fnn._acquire_buffer(batch_size * seq_len * 4) + buf_sum = fnn._acquire_buffer(batch_size * seq_len * 4) + self._track(buf_max) + self._track(buf_sum) + + buf_out = fnn._acquire_buffer(data_nbytes) + self._track(buf_out) + + pipeline, pipeline_layout, _ = fnn.pipelines.get_or_create_pipeline( + "activation-softmax", 4, push_constant_size=24 + ) + descriptor_set = fnn.pipelines.get_cached_descriptor_set( + 
"activation-softmax", + [ + (fnn._get_buffer_handle(h.buf), data_nbytes), + (fnn._get_buffer_handle(buf_out), data_nbytes), + (fnn._get_buffer_handle(buf_max), batch_size * seq_len * 4), + (fnn._get_buffer_handle(buf_sum), batch_size * seq_len * 4), + ], + ) + + workgroups_bs = ((batch_size * seq_len) + 255) // 256 + workgroups_flat = (int(np.prod(shape)) + 255) // 256 + + for pass_id in (0, 1, 2): + if pass_id > 0: + self.barrier() + push = struct.pack("IIIII", batch_size, seq_len, features, pass_id, features) + wg = workgroups_bs if pass_id < 2 else workgroups_flat + self.dispatch(pipeline, pipeline_layout, descriptor_set, (wg,), push) + self.barrier() + + return ChainBufferHandle(buf_out, data_nbytes, shape) + + def read_multiple(self, handles: list[ChainBufferHandle]) -> list[np.ndarray]: + """Submit once, wait, download every handle to numpy (same order as *handles*), then release chain-owned buffers. + + Use for MoE fan-out: several ``linear`` / other ops recorded, one fence wait, multiple CPU reads. 
+ """ + if not self._recording: + raise RuntimeError( + "read_multiple() must be called before exiting the chain_record context" + ) + self.submit_and_wait() + self._submitted = True + out: list[np.ndarray] = [] + for h in handles: + arr = self._fnn._download_buffer(h.buf, h.nbytes, np.float32) + out.append(arr.reshape(h.shape)) + self._release_owned() + return out + + def read(self, h: ChainBufferHandle) -> np.ndarray: + """Submit recorded commands, wait, download *h* to numpy, release chain-owned buffers.""" + return self.read_multiple([h])[0] + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._submitted and self._recording: + try: + self.submit_and_wait() + except Exception: + pass + if not self._released: + self._release_owned() + return CommandRecorder.__exit__(self, exc_type, exc_val, exc_tb) diff --git a/backend/learning.py b/backend/learning.py index 8bd9f89..56087ee 100644 --- a/backend/learning.py +++ b/backend/learning.py @@ -15,9 +15,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * @@ -31,7 +31,7 @@ def __init__(self, core, pipelines, shaders): self.shaders = shaders def _free_descriptor_set(self, descriptor_set) -> None: - if descriptor_set is None or not VULKAN_AVAILABLE: + if descriptor_set is None or not VULKAN_PYTHON_BINDINGS_AVAILABLE: return try: vkFreeDescriptorSets( diff --git a/backend/lora.py b/backend/lora.py index 6aceecf..092cb77 100644 --- a/backend/lora.py +++ b/backend/lora.py @@ -14,9 +14,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * @@ -84,7 +84,7 @@ def forward( rank = A.shape[0] # Check if GPU shader available - if not VULKAN_AVAILABLE or "lora-forward" not in self.shaders: + if not 
VULKAN_PYTHON_BINDINGS_AVAILABLE or "lora-forward" not in self.shaders: return self._forward_cpu(x, W, A, B, scale, bias) try: @@ -270,7 +270,7 @@ def backward( rank = A.shape[0] # Check if GPU shader available - if not VULKAN_AVAILABLE or "lora-backward" not in self.shaders: + if not VULKAN_PYTHON_BINDINGS_AVAILABLE or "lora-backward" not in self.shaders: return self._backward_cpu(grad_output, x, A, B, h, scale) try: diff --git a/backend/normalization.py b/backend/normalization.py index 4a2658f..09f3ed2 100644 --- a/backend/normalization.py +++ b/backend/normalization.py @@ -7,9 +7,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/pipelines.py b/backend/pipelines.py index b9aa5b2..ee939b5 100644 --- a/backend/pipelines.py +++ b/backend/pipelines.py @@ -4,9 +4,9 @@ from collections import OrderedDict -from .base import VULKAN_AVAILABLE +from .base import VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/pooling.py b/backend/pooling.py index 7352975..b4d22b3 100644 --- a/backend/pooling.py +++ b/backend/pooling.py @@ -7,9 +7,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/snn.py b/backend/snn.py index 1229c74..512f96f 100644 --- a/backend/snn.py +++ b/backend/snn.py @@ -6,9 +6,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * diff --git a/backend/snn_compute.py b/backend/snn_compute.py index e2cbf51..5e2b59c 100644 --- 
a/backend/snn_compute.py +++ b/backend/snn_compute.py @@ -6,7 +6,7 @@ import numpy as np -from .base import VULKAN_AVAILABLE +from .base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from .compute import VulkanCompute # Import config for SNN parameters (optional - for integration with main project) @@ -70,7 +70,7 @@ def __init__(self, n_neurons: int = None, use_vulkan: bool = True): self.use_vulkan = False self.backend = None - if use_vulkan and VULKAN_AVAILABLE: + if use_vulkan and VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE: try: self.backend = VulkanCompute() self.use_vulkan = True diff --git a/backend/tensor_ops.py b/backend/tensor_ops.py index 6641333..50d9eff 100644 --- a/backend/tensor_ops.py +++ b/backend/tensor_ops.py @@ -15,9 +15,9 @@ import numpy as np -from .base import VULKAN_AVAILABLE, BufferMixin +from .base import BufferMixin, VULKAN_PYTHON_BINDINGS_AVAILABLE -if VULKAN_AVAILABLE: +if VULKAN_PYTHON_BINDINGS_AVAILABLE: from vulkan import * # noqa: F401,F403 diff --git a/cpp/include/grilly/io/grl_checkpoint.h b/cpp/include/grilly/io/grl_checkpoint.h new file mode 100644 index 0000000..5852ef1 --- /dev/null +++ b/cpp/include/grilly/io/grl_checkpoint.h @@ -0,0 +1,25 @@ +#pragma once +/// GRL v1 checkpoint file layout (shared with Python ``utils/grl_checkpoint.py``). + +#include +#include +#include + +namespace grilly::io { + +inline constexpr uint32_t kGrlHeaderSize = 64; +inline constexpr uint16_t kGrlFormatVersion = 1; + +/// Write a GRL v1 file: header | metadata UTF-8 | index JSON UTF-8 | payload bytes. +bool grl_write_file(const std::string& path, + const std::string& metadata_json, + const std::string& index_json, + const std::vector& payload); + +/// Read a GRL v1 file; on success returns true and fills out-arguments. 
+bool grl_read_file(const std::string& path, + std::string& metadata_json, + std::string& index_json, + std::vector& payload); + +} // namespace grilly::io diff --git a/cpp/include/grilly/ops/activations.h b/cpp/include/grilly/ops/activations.h index 6178bb9..a601d49 100644 --- a/cpp/include/grilly/ops/activations.h +++ b/cpp/include/grilly/ops/activations.h @@ -83,5 +83,20 @@ void softmaxBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache float* gradInput, uint32_t batchSize, uint32_t seqLen, uint32_t numClasses); +/// Multiplication-free softmax (ReLU-normalized): shader ``mf-softmax`` (same 3-pass +/// layout as ``activation-softmax``). +void mfSoftmax(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t batchSize, uint32_t seqLen, + uint32_t features); + +/// Algebraic softplus: ``0.5 * (x + sqrt(x*x + c))`` with ``c = 4/beta^2``. Shader +/// ``mf-softplus`` (2 buffers). +void mfSoftplus(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements, float beta); + +/// Rational sigmoid ``x / (1 + |x|)``. Shader ``mf-sigmoid`` (2 buffers). +void mfSigmoid(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements); + } // namespace ops } // namespace grilly diff --git a/cpp/include/grilly/ops/batched_ops.h b/cpp/include/grilly/ops/batched_ops.h index ac77527..84bb884 100644 --- a/cpp/include/grilly/ops/batched_ops.h +++ b/cpp/include/grilly/ops/batched_ops.h @@ -15,6 +15,8 @@ #include "grilly/command_batch.h" #include "grilly/pipeline_cache.h" +#include "grilly/ops/embedding.h" + namespace grilly { namespace ops { @@ -60,6 +62,13 @@ void batchedAdd(CommandBatch& batch, PipelineCache& cache, GrillyBuffer& a, const GrillyBuffer& b, uint32_t totalElements); +/// Embedding lookup: output[batch*seq, d] = table[token_ids]. 
Caller must +/// `batch.begin()` first; buffers must be pre-uploaded (ids, table). +void batchedEmbeddingLookup(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& tokenIds, const GrillyBuffer& embedTable, + GrillyBuffer& output, uint32_t batchSize, uint32_t seqLen, + uint32_t vocabSize, uint32_t embeddingDim); + /// Tiled GEMM: C[M,N] = A[M,K] @ B[N,K]^T + bias[N] /// Uses 32x32 shared memory tiles, 2x2 per-thread sub-tiles. void batchedTiledLinear(CommandBatch& batch, PipelineCache& cache, diff --git a/cpp/include/grilly/ops/moe_forward.h b/cpp/include/grilly/ops/moe_forward.h new file mode 100644 index 0000000..a6aeaaa --- /dev/null +++ b/cpp/include/grilly/ops/moe_forward.h @@ -0,0 +1,104 @@ +#pragma once +/// Fused MoE forward / backward — minimal Python crossings for training hot path. +/// +/// `moe_forward_gpu` runs embedding, per-layer expert GEMMs (batched), residual, +/// and output projection on GPU. Router softmax and expert blending use CPU for +/// correctness and simplicity (small tensors); blend can move to a dedicated +/// shader in a follow-up. 
+ +#include +#include +#include + +#include "grilly/buffer_pool.h" +#include "grilly/command_batch.h" +#include "grilly/pipeline_cache.h" + +namespace grilly { +namespace ops { + +struct MoeLayerGPU { + GrillyBuffer routerW; + GrillyBuffer routerB; + std::vector expertW; + std::vector expertWt; + GrillyBuffer expertPacked; // All experts packed contiguously (n_experts * d * d) +}; + +struct MoeGradients { + std::vector grad_embed; + std::vector grad_pos; + std::vector> grad_experts; + std::vector> grad_router_w; + std::vector> grad_router_b; + std::vector grad_out_w; +}; + +struct MoeHandleCache { + uint32_t vocab = 0; + uint32_t d = 0; + uint32_t maxSeq = 0; + uint32_t nLayers = 0; + uint32_t nExperts = 0; + + GrillyBuffer embedW; + GrillyBuffer posW; + GrillyBuffer outW; + GrillyBuffer outWt; + + std::vector layers; + + GrillyBuffer bufIds; + GrillyBuffer bufPosSlice; + GrillyBuffer bufX; + std::vector bufExpertOut; + GrillyBuffer bufBlended; + GrillyBuffer bufLogits; + + // Saved activations for backward (one per layer + final) + std::vector bufActivations; // [0]=after embed, [l]=after layer l-1 + // Router weights per layer (saved from forward for backward) + std::vector> fwd_router_weights; + + std::vector cpu_embed; + std::vector cpu_pos; + std::vector cpu_out_w; + std::vector> cpu_router_w; + std::vector> cpu_router_b; + std::vector> cpu_expert_w; +}; + +int moe_upload(BufferPool& pool, + uint32_t vocab_size, uint32_t d_model, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w, + uint32_t n_layers, uint32_t n_experts); + +MoeHandleCache& moe_get_cache(int handle); + +void moe_release(BufferPool& pool, int handle); + +void moe_update_weights(BufferPool& pool, MoeHandleCache& h, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* 
out_w); + +void moe_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + float* logits_out); + +MoeGradients moe_backward_cpu(const MoeHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits); + +MoeGradients moe_backward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits); + +} // namespace ops +} // namespace grilly diff --git a/cpp/include/grilly/ops/vsa_lm_forward.h b/cpp/include/grilly/ops/vsa_lm_forward.h new file mode 100644 index 0000000..5a5b20e --- /dev/null +++ b/cpp/include/grilly/ops/vsa_lm_forward.h @@ -0,0 +1,121 @@ +#pragma once +/// VSA Language Model fused forward/backward — AdditionLinear FFN + MindForge LoRA. +/// +/// Uses addition-linear shader (L1 distance, no matmul) for FFN layers and +/// fnn-linear for the output projection only. GPU router + MindForge LoRA +/// on CPU (tiny). Target: 10x faster than Python AdditionLinear loops. 
+ +#include +#include + +#include "grilly/buffer_pool.h" +#include "grilly/command_batch.h" +#include "grilly/pipeline_cache.h" + +namespace grilly { +namespace ops { + +struct VsaLmLayerGPU { + GrillyBuffer ffnUpW; // (d_ffn, d) addition-linear patterns + GrillyBuffer ffnUpB; // (d_ffn,) + GrillyBuffer ffnDownW; // (d, d_ffn) addition-linear patterns + GrillyBuffer ffnDownB; // (d,) + GrillyBuffer lnGamma; // (d,) + GrillyBuffer lnBeta; // (d,) +}; + +struct VsaLmGradients { + std::vector grad_embed; // (vocab, d) + std::vector grad_pos; // (max_seq, d) + std::vector> grad_ffn_up_w; // n_layers of (d_ffn * d) + std::vector> grad_ffn_up_b; // n_layers of (d_ffn) + std::vector> grad_ffn_down_w; // n_layers of (d * d_ffn) + std::vector> grad_ffn_down_b; // n_layers of (d) + std::vector> grad_ln_gamma; // n_layers of (d) + std::vector> grad_ln_beta; // n_layers of (d) + std::vector grad_out_w; // (vocab, d) +}; + +struct VsaLmHandleCache { + uint32_t vocab = 0; + uint32_t d = 0; + uint32_t dFfn = 0; + uint32_t maxSeq = 0; + uint32_t nLayers = 0; + + // GPU buffers + GrillyBuffer embedW; + GrillyBuffer posW; + GrillyBuffer outW; + GrillyBuffer outWt; // transposed for backward + + std::vector layers; + + GrillyBuffer bufIds; + GrillyBuffer bufPosSlice; + GrillyBuffer bufX; // current hidden (seq, d) + GrillyBuffer bufLnOut; // layernorm output (seq, d) + GrillyBuffer bufFfnUp; // after addition-linear up (seq, d_ffn) + GrillyBuffer bufSign; // after sign activation (seq, d_ffn) + GrillyBuffer bufFfnDown; // after addition-linear down (seq, d) + GrillyBuffer bufLogits; // (seq, vocab) + + // LayerNorm scratch (mean + var buffers) + GrillyBuffer bufLnMean; // (seq,) + GrillyBuffer bufLnVar; // (seq,) + + // Saved activations for backward (n_layers + 1) + std::vector bufActivations; + + // CPU mirrors for backward + std::vector cpu_embed; + std::vector cpu_pos; + std::vector cpu_out_w; + std::vector> cpu_ffn_up_w; // n_layers + std::vector> cpu_ffn_up_b; + std::vector> 
cpu_ffn_down_w; + std::vector> cpu_ffn_down_b; + std::vector> cpu_ln_gamma; + std::vector> cpu_ln_beta; + + // MindForge weights (CPU only — tiny) + std::vector forge_W_proj; // (d_hidden, d_vsa) + std::vector forge_W_h; // (d_hidden, d_hidden*2) + std::vector forge_b_h; // (d_hidden,) + std::vector forge_W_coeff; // (n_basis, d_hidden) + std::vector forge_b_coeff; // (n_basis,) + std::vector forge_A_basis; // (n_basis, rank, d) + std::vector forge_B_basis; // (n_basis, d, rank) + std::vector forge_layer_embs; // (n_layers, d_hidden) + uint32_t forge_d_hidden = 0; + uint32_t forge_d_vsa = 0; + uint32_t forge_n_basis = 0; + uint32_t forge_rank = 0; +}; + +int vsa_lm_upload(BufferPool& pool, + uint32_t vocab, uint32_t d, uint32_t d_ffn, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& ffn_up_ws, + const std::vector& ffn_up_bs, + const std::vector& ffn_down_ws, + const std::vector& ffn_down_bs, + const std::vector& ln_gammas, + const std::vector& ln_betas, + const float* out_w, + uint32_t n_layers); + +VsaLmHandleCache& vsa_lm_get_cache(int handle); + +void vsa_lm_release(BufferPool& pool, int handle); + +void vsa_lm_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + VsaLmHandleCache& h, const int32_t* input_ids, + uint32_t seq_len, float* logits_out); + +VsaLmGradients vsa_lm_backward_cpu(const VsaLmHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits); + +} // namespace ops +} // namespace grilly diff --git a/cpp/include/grilly/vulkan/vk_buffer_pool.h b/cpp/include/grilly/vulkan/vk_buffer_pool.h index c938ee1..a4873f1 100644 --- a/cpp/include/grilly/vulkan/vk_buffer_pool.h +++ b/cpp/include/grilly/vulkan/vk_buffer_pool.h @@ -21,6 +21,18 @@ struct GrillyBuffer { size_t size = 0; size_t bucketSize = 0; void* mappedPtr = nullptr; + /// True if this buffer was acquired via ``acquireDeviceLocal`` (no host + /// visibility, mapped to GPU-only VRAM). 
``release`` uses this to route + /// the buffer back to the device-local bucket pool, so a regular + /// ``acquire`` won't accidentally pick up a DL buffer and crash trying + /// to memcpy into a null mappedPtr. + bool deviceLocal = false; + /// True if this buffer was acquired via ``acquireReadback`` — backed + /// by HOST_CACHED memory for fast CPU read-after-GPU-write. Released + /// buffers go to a separate readback pool so a future ``acquire`` + /// (which expects WC sequential-write memory) doesn't waste a cached + /// readback slot on a write-only workload. + bool readback = false; }; /// VMA-backed buffer pool with power-of-2 bucket reuse. @@ -65,7 +77,16 @@ class BufferPool { GrillyDevice& device_; VmaAllocator allocator_ = VK_NULL_HANDLE; std::mutex mutex_; + /// Pool for host-visible WC buffers (acquire / release default path). + /// Optimized for CPU sequential writes (memcpy CPU → buffer). std::unordered_map> buckets_; + /// Pool for DEVICE_LOCAL-only buffers (acquireDeviceLocal / release). + /// Kept separate so a host-visible acquire never picks up a DL buffer. + std::unordered_map> dlBuckets_; + /// Pool for HOST_CACHED readback buffers (acquireReadback / release). + /// Optimized for CPU random reads (memcpy buffer → CPU output) at + /// cached system RAM speed (~10 GB/s vs ~25 MB/s for WC reads). + std::unordered_map> readbackBuckets_; Stats stats_{}; // Persistent transfer context (avoids cmd pool/fence recreation per transfer) diff --git a/cpp/include/grilly/vulkan/vk_command_batch.h b/cpp/include/grilly/vulkan/vk_command_batch.h index 783ebb9..4e80c4c 100644 --- a/cpp/include/grilly/vulkan/vk_command_batch.h +++ b/cpp/include/grilly/vulkan/vk_command_batch.h @@ -31,6 +31,12 @@ class CommandBatch { void barrier(); + /// Bidirectional TRANSFER ↔ COMPUTE memory barrier — covers both + /// stage-in → compute (TRANSFER_WRITE → SHADER_READ) and + /// compute → stage-out (SHADER_WRITE → TRANSFER_READ). 
Used by the + /// staging-buffer pattern in cpp/src/ops/linear.cpp. + void transferComputeBarrier(); + /// GPU-to-GPU buffer copy (no CPU involvement). void copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t bytes); diff --git a/cpp/python/bindings_activations.cpp b/cpp/python/bindings_activations.cpp index cd8cd9f..c4ea83b 100644 --- a/cpp/python/bindings_activations.cpp +++ b/cpp/python/bindings_activations.cpp @@ -42,6 +42,55 @@ void register_activations_ops(py::module_& m) { defActivation("silu", grilly::ops::silu); defActivation("tanh_act", grilly::ops::tanh_act); + m.def( + "mf_sigmoid", + [](GrillyCoreContext& ctx, + py::array_t input) -> Tensor { + auto inBuf = input.request(); + require_c_contiguous_float(inBuf); + uint32_t total = static_cast(inBuf.size); + + py::array_t result(input.request().shape); + auto rBuf = result.request(); + + { + py::gil_scoped_release release; + grilly::ops::mfSigmoid( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), total); + } + + return Tensor::from_numpy(result); + }, + py::arg("device"), py::arg("input"), + "GPU rational sigmoid x/(1+|x|) (shader mf-sigmoid)"); + + m.def( + "mf_softplus", + [](GrillyCoreContext& ctx, py::array_t input, + float beta) -> Tensor { + auto inBuf = input.request(); + require_c_contiguous_float(inBuf); + uint32_t total = static_cast(inBuf.size); + + py::array_t result(input.request().shape); + auto rBuf = result.request(); + + { + py::gil_scoped_release release; + grilly::ops::mfSoftplus( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), total, beta); + } + + return Tensor::from_numpy(result); + }, + py::arg("device"), py::arg("input"), py::arg("beta") = 1.0f, + "GPU algebraic softplus 0.5*(x+sqrt(x*x+c)), c=4/beta^2 (shader " + "mf-softplus)"); + // ── Activation backward ops ────────────────────────────────────────── // 3 buffers: grad_output, input (original forward input), grad_input. 
// Same push constant (uint total_elements) and dispatch pattern. diff --git a/cpp/python/bindings_core.cpp b/cpp/python/bindings_core.cpp index cf593ea..0e9f82c 100644 --- a/cpp/python/bindings_core.cpp +++ b/cpp/python/bindings_core.cpp @@ -435,6 +435,9 @@ PYBIND11_MODULE(grilly_core, m) { register_siglip_ops(m); register_perceiver_ops(m); register_moqe_train_ops(m); + register_moe_ops(m); register_fusion_ops(m); + register_vsa_lm_ops(m); + register_grl_ops(m); register_misc_ops(m); } diff --git a/cpp/python/bindings_core.h b/cpp/python/bindings_core.h index 437ef72..6861e11 100644 --- a/cpp/python/bindings_core.h +++ b/cpp/python/bindings_core.h @@ -127,6 +127,19 @@ inline void require_c_contiguous_int8(const py::buffer_info& buf) { } /// Require NumPy C-contiguous uint32 (e.g. CE targets). +inline void require_c_contiguous_int32(const py::buffer_info& buf) { + if (buf.itemsize != sizeof(int32_t)) + throw std::runtime_error("expected int32 array"); + if (buf.ndim == 0) + return; + py::ssize_t expected_stride = static_cast(sizeof(int32_t)); + for (int i = static_cast(buf.ndim) - 1; i >= 0; --i) { + if (buf.strides[i] != expected_stride) + throw std::runtime_error("array must be C-contiguous int32"); + expected_stride *= buf.shape[i]; + } +} + inline void require_c_contiguous_uint32(const py::buffer_info& buf) { if (buf.itemsize != sizeof(uint32_t)) throw std::runtime_error("expected uint32 array"); @@ -174,4 +187,7 @@ void register_pooling_ops(py::module_& m); void register_misc_ops(py::module_& m); void register_perceiver_ops(py::module_& m); void register_moqe_train_ops(py::module_& m); +void register_moe_ops(py::module_& m); void register_fusion_ops(py::module_& m); +void register_vsa_lm_ops(py::module_& m); +void register_grl_ops(py::module_& m); diff --git a/cpp/python/bindings_grl.cpp b/cpp/python/bindings_grl.cpp new file mode 100644 index 0000000..959e55d --- /dev/null +++ b/cpp/python/bindings_grl.cpp @@ -0,0 +1,44 @@ +/// Pybind11 bindings for GRL 
(.grl) checkpoint I/O — implemented in C++ for +/// performance and a single canonical binary encoder/decoder. + +#include "bindings_core.h" +#include "grilly/io/grl_checkpoint.h" + +#include +#include + +void register_grl_ops(py::module_& m) { + m.def( + "grl_write_file", + [](const std::string& path, const std::string& metadata_json, + const std::string& index_json, py::bytes payload_bytes) { + std::string pb = payload_bytes; + std::vector payload(pb.begin(), pb.end()); + if (!grilly::io::grl_write_file(path, metadata_json, index_json, + payload)) { + throw std::runtime_error("grl_write_file failed: " + path); + } + }, + py::arg("path"), py::arg("metadata_json"), py::arg("index_json"), + py::arg("payload"), + "Write a GRL v1 checkpoint (header + metadata JSON + tensor index JSON " + "+ raw payload bytes). Matches Python utils/grl_checkpoint layout."); + + m.def( + "grl_read_file", + [](const std::string& path) { + std::string metadata_json; + std::string index_json; + std::vector payload; + if (!grilly::io::grl_read_file(path, metadata_json, index_json, + payload)) { + throw std::runtime_error("grl_read_file failed: " + path); + } + return py::make_tuple(metadata_json, index_json, + py::bytes(reinterpret_cast( + payload.data()), + payload.size())); + }, + py::arg("path"), + "Read a GRL v1 file. Returns (metadata_json, index_json, payload_bytes)."); +} diff --git a/cpp/python/bindings_moe.cpp b/cpp/python/bindings_moe.cpp new file mode 100644 index 0000000..e6049a1 --- /dev/null +++ b/cpp/python/bindings_moe.cpp @@ -0,0 +1,253 @@ +/// bindings_moe.cpp — fused MoE upload / forward / backward / weight update. 
+#include "bindings_core.h" +#include "grilly/ops/moe_forward.h" + +#include +#include +#include + +using namespace grilly; + +void register_moe_ops(py::module_& m) { + + m.def( + "moe_upload", + [](GrillyCoreContext& ctx, py::array_t embed_w, py::array_t pos_w, + py::list expert_ws, py::list router_ws, py::list router_bs, py::array_t out_w, + uint32_t n_layers, uint32_t n_experts) -> int { + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + if (eb.ndim != 2 || pb.ndim != 2 || ob.ndim != 2) + throw std::runtime_error("moe_upload: embed/pos/out must be 2-D"); + uint32_t vocab = static_cast(eb.shape[0]); + uint32_t d = static_cast(eb.shape[1]); + if (static_cast(ob.shape[1]) != d) + throw std::runtime_error("moe_upload: out_w second dim must match d"); + if (static_cast(ob.shape[0]) != vocab) + throw std::runtime_error("moe_upload: out_w first dim must match vocab"); + uint32_t max_seq = static_cast(pb.shape[0]); + if (static_cast(pb.shape[1]) != d) + throw std::runtime_error("moe_upload: pos_w second dim must match d"); + + std::vector exp_ptrs; + std::vector> exp_keep; + for (auto& item : expert_ws) { + auto arr = item.cast>(); + auto wb = arr.request(); + require_c_contiguous_float(wb); + if (wb.ndim != 2 || static_cast(wb.shape[0]) != d || + static_cast(wb.shape[1]) != d) + throw std::runtime_error("moe_upload: each expert must be (d, d)"); + exp_keep.push_back(arr); + exp_ptrs.push_back(static_cast(wb.ptr)); + } + + std::vector rW_ptrs; + std::vector rb_ptrs; + std::vector> rW_keep; + std::vector> rb_keep; + for (uint32_t l = 0; l < n_layers; ++l) { + auto rw = router_ws[l].cast>(); + auto rb = router_bs[l].cast>(); + auto rwb = rw.request(); + auto rbb = rb.request(); + require_c_contiguous_float(rwb); + require_c_contiguous_float(rbb); + if (rwb.ndim != 2 || static_cast(rwb.shape[0]) != n_experts || + 
static_cast(rwb.shape[1]) != d) + throw std::runtime_error("moe_upload: router_w must be (n_experts, d)"); + if (rbb.ndim != 1 || static_cast(rbb.shape[0]) != n_experts) + throw std::runtime_error("moe_upload: router_b must be (n_experts,)"); + rW_keep.push_back(rw); + rb_keep.push_back(rb); + rW_ptrs.push_back(static_cast(rwb.ptr)); + rb_ptrs.push_back(static_cast(rbb.ptr)); + } + + int handle = ops::moe_upload( + ctx.pool, vocab, d, max_seq, + static_cast(eb.ptr), + static_cast(pb.ptr), + exp_ptrs, rW_ptrs, rb_ptrs, + static_cast(ob.ptr), + n_layers, n_experts); + + ctx.waitIdle(); + return handle; + }, + py::arg("device"), py::arg("embed_w"), py::arg("pos_w"), py::arg("expert_ws"), + py::arg("router_ws"), py::arg("router_bs"), py::arg("out_w"), + py::arg("n_layers"), py::arg("n_experts"), + "Upload MoE weights to GPU; returns opaque integer handle."); + + m.def( + "moe_release", + [](GrillyCoreContext& ctx, int handle) { ops::moe_release(ctx.pool, handle); }, + py::arg("device"), py::arg("handle"), "Free GPU buffers for a MoE handle."); + + m.def( + "moe_forward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids) + -> py::array_t { + auto& cache = ops::moe_get_cache(handle); + auto ib = input_ids.request(); + require_c_contiguous_int32(ib); + if (ib.ndim != 1) + throw std::runtime_error("moe_forward: input_ids must be 1-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + uint32_t V = cache.vocab; + + py::array_t logits({static_cast(seq_len), + static_cast(V)}); + + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + ops::moe_forward_gpu(ctx.batch, ctx.pool, ctx.cache, cache, + static_cast(ib.ptr), seq_len, + logits.mutable_data()); + } + return logits; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + "Run full MoE forward on GPU (router/blend on CPU). 
Returns (seq_len, vocab)."); + + m.def( + "moe_update_weights", + [](GrillyCoreContext& ctx, int handle, py::array_t embed_w, + py::array_t pos_w, py::list expert_ws, py::list router_ws, + py::list router_bs, py::array_t out_w) { + auto& h = ops::moe_get_cache(handle); + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + + std::vector exp_ptrs; + std::vector> exp_keep; + for (auto& item : expert_ws) { + auto arr = item.cast>(); + auto wb = arr.request(); + require_c_contiguous_float(wb); + exp_keep.push_back(arr); + exp_ptrs.push_back(static_cast(wb.ptr)); + } + + std::vector rW_ptrs; + std::vector rb_ptrs; + for (uint32_t l = 0; l < L; ++l) { + auto rw = router_ws[l].cast>(); + auto rb = router_bs[l].cast>(); + auto rwb = rw.request(); + auto rbb = rb.request(); + require_c_contiguous_float(rwb); + require_c_contiguous_float(rbb); + rW_ptrs.push_back(static_cast(rwb.ptr)); + rb_ptrs.push_back(static_cast(rbb.ptr)); + } + + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + ops::moe_update_weights( + ctx.pool, h, + static_cast(eb.ptr), + static_cast(pb.ptr), + exp_ptrs, rW_ptrs, rb_ptrs, + static_cast(ob.ptr)); + ctx.waitIdle(); + } + }, + py::arg("device"), py::arg("handle"), py::arg("embed_w"), py::arg("pos_w"), + py::arg("expert_ws"), py::arg("router_ws"), py::arg("router_bs"), py::arg("out_w"), + "Re-upload weights in place after an optimizer step."); + + m.def( + "moe_backward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids, + py::array_t grad_logits) -> py::dict { + auto& h = ops::moe_get_cache(handle); + auto ib = input_ids.request(); + auto gb = grad_logits.request(); + require_c_contiguous_int32(ib); + require_c_contiguous_float(gb); + if (ib.ndim != 1) + throw std::runtime_error("moe_backward: input_ids must be 1-D"); + 
if (gb.ndim != 2) + throw std::runtime_error("moe_backward: grad_logits must be 2-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + uint32_t V = h.vocab; + if (static_cast(gb.shape[0]) != seq_len || + static_cast(gb.shape[1]) != V) + throw std::runtime_error("moe_backward: grad_logits shape mismatch"); + + ops::MoeGradients grads; + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + grads = ops::moe_backward_gpu( + ctx.batch, ctx.pool, ctx.cache, h, + static_cast(ib.ptr), seq_len, + static_cast(gb.ptr)); + } + + py::dict d; + py::array_t grad_embed({static_cast(h.vocab), + static_cast(h.d)}); + std::memcpy(grad_embed.mutable_data(), grads.grad_embed.data(), + grads.grad_embed.size() * sizeof(float)); + d["grad_embed"] = grad_embed; + + py::array_t grad_pos({static_cast(h.maxSeq), + static_cast(h.d)}); + std::memcpy(grad_pos.mutable_data(), grads.grad_pos.data(), + grads.grad_pos.size() * sizeof(float)); + d["grad_pos"] = grad_pos; + + py::list ge; + for (size_t i = 0; i < grads.grad_experts.size(); ++i) { + py::array_t gex({static_cast(h.d), + static_cast(h.d)}); + std::memcpy(gex.mutable_data(), grads.grad_experts[i].data(), + h.d * h.d * sizeof(float)); + ge.append(gex); + } + d["grad_experts"] = ge; + + py::list grw; + py::list grb; + for (uint32_t l = 0; l < h.nLayers; ++l) { + py::array_t gw({static_cast(h.nExperts), + static_cast(h.d)}); + std::memcpy(gw.mutable_data(), grads.grad_router_w[l].data(), + h.nExperts * h.d * sizeof(float)); + grw.append(gw); + py::array_t gbv({static_cast(h.nExperts)}); + std::memcpy(gbv.mutable_data(), grads.grad_router_b[l].data(), + h.nExperts * sizeof(float)); + grb.append(gbv); + } + d["grad_routers_W"] = grw; + d["grad_routers_b"] = grb; + + py::array_t gow({static_cast(h.vocab), + static_cast(h.d)}); + std::memcpy(gow.mutable_data(), grads.grad_out_w.data(), + grads.grad_out_w.size() * sizeof(float)); + d["grad_out_w"] = gow; + return d; + }, + py::arg("device"), py::arg("handle"), 
py::arg("input_ids"), + py::arg("grad_logits"), + "Backward pass (GPU path when available, CPU fallback) using uploaded MoE weights."); +} diff --git a/cpp/python/bindings_normalization.cpp b/cpp/python/bindings_normalization.cpp index bb0e58b..4428639 100644 --- a/cpp/python/bindings_normalization.cpp +++ b/cpp/python/bindings_normalization.cpp @@ -262,6 +262,38 @@ void register_normalization_ops(py::module_& m) { py::arg("device"), py::arg("input"), py::arg("dim") = -1, "GPU Softmax (3-pass: max, sum_exp, normalize)"); + m.def( + "mf_softmax", + [](GrillyCoreContext& ctx, py::array_t input, + int dim) -> Tensor { + auto inBuf = input.request(); + require_c_contiguous_float(inBuf); + if (inBuf.ndim < 1) + throw std::runtime_error("input must be at least 1D"); + + uint32_t features = static_cast( + inBuf.shape[inBuf.ndim - 1]); + uint32_t totalBatch = 1; + for (int i = 0; i < inBuf.ndim - 1; ++i) + totalBatch *= static_cast(inBuf.shape[i]); + if (inBuf.ndim == 1) totalBatch = 1; + + py::array_t result(inBuf.shape); + auto rBuf = result.request(); + { + py::gil_scoped_release release; + grilly::ops::mfSoftmax( + ctx.batch, ctx.pool, ctx.cache, + static_cast(inBuf.ptr), + static_cast(rBuf.ptr), + 1, totalBatch, features); + } + return Tensor::from_numpy(result); + }, + py::arg("device"), py::arg("input"), py::arg("dim") = -1, + "GPU multiplication-free softmax (ReLU-normalized; 3-pass, shader " + "mf-softmax)"); + // ── Softmax backward ───────────────────────────────────────────────── m.def( "softmax_backward", diff --git a/cpp/python/bindings_vsa_lm.cpp b/cpp/python/bindings_vsa_lm.cpp new file mode 100644 index 0000000..134ad60 --- /dev/null +++ b/cpp/python/bindings_vsa_lm.cpp @@ -0,0 +1,305 @@ +/// bindings_vsa_lm.cpp — fused VSA-LM upload / forward / backward / release. 
+#include "bindings_core.h" +#include "grilly/ops/vsa_lm_forward.h" + +#include +#include +#include + +using namespace grilly; + +void register_vsa_lm_ops(py::module_& m) { + + m.def( + "vsa_lm_upload", + [](GrillyCoreContext& ctx, + py::array_t embed_w, py::array_t pos_w, + py::list ffn_up_patterns, py::list ffn_up_biases, + py::list ffn_down_patterns, py::list ffn_down_biases, + py::list ln_gammas, py::list ln_betas, + py::array_t out_w, + uint32_t n_layers, uint32_t d_model, uint32_t d_ffn) -> int { + + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + if (eb.ndim != 2 || pb.ndim != 2 || ob.ndim != 2) + throw std::runtime_error("vsa_lm_upload: embed/pos/out must be 2-D"); + uint32_t vocab = static_cast(eb.shape[0]); + uint32_t d = static_cast(eb.shape[1]); + if (d != d_model) + throw std::runtime_error("vsa_lm_upload: embed_w dim-1 must match d_model"); + uint32_t max_seq = static_cast(pb.shape[0]); + + auto extract_list = [](py::list& lst, uint32_t n, + std::vector& ptrs, + std::vector>& keep) { + for (uint32_t i = 0; i < n; ++i) { + auto arr = lst[i].cast>(); + auto buf = arr.request(); + require_c_contiguous_float(buf); + keep.push_back(arr); + ptrs.push_back(static_cast(buf.ptr)); + } + }; + + std::vector up_w_ptrs, up_b_ptrs, dn_w_ptrs, dn_b_ptrs; + std::vector gm_ptrs, bt_ptrs; + std::vector> keep1, keep2, keep3, keep4, keep5, keep6; + + extract_list(ffn_up_patterns, n_layers, up_w_ptrs, keep1); + extract_list(ffn_up_biases, n_layers, up_b_ptrs, keep2); + extract_list(ffn_down_patterns, n_layers, dn_w_ptrs, keep3); + extract_list(ffn_down_biases, n_layers, dn_b_ptrs, keep4); + extract_list(ln_gammas, n_layers, gm_ptrs, keep5); + extract_list(ln_betas, n_layers, bt_ptrs, keep6); + + int handle = ops::vsa_lm_upload( + ctx.pool, vocab, d, d_ffn, max_seq, + static_cast(eb.ptr), + static_cast(pb.ptr), + up_w_ptrs, 
up_b_ptrs, dn_w_ptrs, dn_b_ptrs, + gm_ptrs, bt_ptrs, + static_cast(ob.ptr), + n_layers); + + ctx.waitIdle(); + return handle; + }, + py::arg("device"), py::arg("embed_w"), py::arg("pos_w"), + py::arg("ffn_up_patterns"), py::arg("ffn_up_biases"), + py::arg("ffn_down_patterns"), py::arg("ffn_down_biases"), + py::arg("ln_gammas"), py::arg("ln_betas"), + py::arg("out_w"), + py::arg("n_layers"), py::arg("d_model"), py::arg("d_ffn"), + "Upload VSA-LM weights to GPU; returns opaque integer handle."); + + m.def( + "vsa_lm_release", + [](GrillyCoreContext& ctx, int handle) { + ops::vsa_lm_release(ctx.pool, handle); + }, + py::arg("device"), py::arg("handle"), + "Free GPU buffers for a VSA-LM handle."); + + m.def( + "vsa_lm_forward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids) + -> py::array_t { + auto& h = ops::vsa_lm_get_cache(handle); + auto ib = input_ids.request(); + require_c_contiguous_int32(ib); + if (ib.ndim != 1) + throw std::runtime_error("vsa_lm_forward: input_ids must be 1-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + + py::array_t logits({static_cast(seq_len), + static_cast(h.vocab)}); + { + py::gil_scoped_release release; + std::lock_guard lock(ctx.ctx_mutex); + ops::vsa_lm_forward_gpu(ctx.batch, ctx.pool, ctx.cache, h, + static_cast(ib.ptr), + seq_len, logits.mutable_data()); + } + return logits; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + "Run full VSA-LM forward on GPU. 
Returns (seq_len, vocab) float32."); + + m.def( + "vsa_lm_backward", + [](GrillyCoreContext& ctx, int handle, py::array_t input_ids, + py::array_t grad_logits) -> py::dict { + auto& h = ops::vsa_lm_get_cache(handle); + auto ib = input_ids.request(); + auto gb = grad_logits.request(); + require_c_contiguous_int32(ib); + require_c_contiguous_float(gb); + if (ib.ndim != 1) + throw std::runtime_error("vsa_lm_backward: input_ids must be 1-D"); + if (gb.ndim != 2) + throw std::runtime_error("vsa_lm_backward: grad_logits must be 2-D"); + uint32_t seq_len = static_cast(ib.shape[0]); + uint32_t V = h.vocab; + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t L = h.nLayers; + + if (static_cast(gb.shape[0]) != seq_len || + static_cast(gb.shape[1]) != V) + throw std::runtime_error("vsa_lm_backward: grad_logits shape mismatch"); + + ops::VsaLmGradients grads; + { + py::gil_scoped_release release; + grads = ops::vsa_lm_backward_cpu( + h, static_cast(ib.ptr), seq_len, + static_cast(gb.ptr)); + } + + py::dict result; + + // grad_embed (vocab, d) + py::array_t ge({static_cast(V), + static_cast(d)}); + std::memcpy(ge.mutable_data(), grads.grad_embed.data(), + grads.grad_embed.size() * sizeof(float)); + result["grad_embed"] = ge; + + // grad_pos (max_seq, d) + py::array_t gp({static_cast(h.maxSeq), + static_cast(d)}); + std::memcpy(gp.mutable_data(), grads.grad_pos.data(), + grads.grad_pos.size() * sizeof(float)); + result["grad_pos"] = gp; + + // grad_out_w (vocab, d) + py::array_t gow({static_cast(V), + static_cast(d)}); + std::memcpy(gow.mutable_data(), grads.grad_out_w.data(), + grads.grad_out_w.size() * sizeof(float)); + result["grad_out_w"] = gow; + + // Per-layer gradients as lists + py::list g_up_w, g_up_b, g_dn_w, g_dn_b, g_ln_g, g_ln_b; + for (uint32_t l = 0; l < L; ++l) { + py::array_t uw({static_cast(dF), + static_cast(d)}); + std::memcpy(uw.mutable_data(), grads.grad_ffn_up_w[l].data(), + dF * d * sizeof(float)); + g_up_w.append(uw); + + py::array_t 
ub({static_cast(dF)}); + std::memcpy(ub.mutable_data(), grads.grad_ffn_up_b[l].data(), + dF * sizeof(float)); + g_up_b.append(ub); + + py::array_t dw({static_cast(d), + static_cast(dF)}); + std::memcpy(dw.mutable_data(), grads.grad_ffn_down_w[l].data(), + d * dF * sizeof(float)); + g_dn_w.append(dw); + + py::array_t db({static_cast(d)}); + std::memcpy(db.mutable_data(), grads.grad_ffn_down_b[l].data(), + d * sizeof(float)); + g_dn_b.append(db); + + py::array_t lg({static_cast(d)}); + std::memcpy(lg.mutable_data(), grads.grad_ln_gamma[l].data(), + d * sizeof(float)); + g_ln_g.append(lg); + + py::array_t lb({static_cast(d)}); + std::memcpy(lb.mutable_data(), grads.grad_ln_beta[l].data(), + d * sizeof(float)); + g_ln_b.append(lb); + } + result["grad_ffn_up_w"] = g_up_w; + result["grad_ffn_up_b"] = g_up_b; + result["grad_ffn_down_w"] = g_dn_w; + result["grad_ffn_down_b"] = g_dn_b; + result["grad_ln_gamma"] = g_ln_g; + result["grad_ln_beta"] = g_ln_b; + + return result; + }, + py::arg("device"), py::arg("handle"), py::arg("input_ids"), + py::arg("grad_logits"), + "CPU backward for VSA-LM. 
Returns dict with all gradient arrays."); + + m.def( + "vsa_lm_update_weights", + [](GrillyCoreContext& ctx, int handle, + py::array_t embed_w, py::array_t pos_w, + py::list ffn_up_patterns, py::list ffn_up_biases, + py::list ffn_down_patterns, py::list ffn_down_biases, + py::list ln_gammas, py::list ln_betas, + py::array_t out_w) { + + auto& h = ops::vsa_lm_get_cache(handle); + auto eb = embed_w.request(); + auto pb = pos_w.request(); + auto ob = out_w.request(); + require_c_contiguous_float(eb); + require_c_contiguous_float(pb); + require_c_contiguous_float(ob); + + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + + // Re-upload embedding + size_t embed_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_embed.data(), static_cast(eb.ptr), embed_bytes); + ctx.pool.upload(h.embedW, static_cast(eb.ptr), embed_bytes); + + // Re-upload pos + size_t pos_bytes = size_t(h.maxSeq) * d * sizeof(float); + std::memcpy(h.cpu_pos.data(), static_cast(pb.ptr), pos_bytes); + ctx.pool.upload(h.posW, static_cast(pb.ptr), pos_bytes); + + // Re-upload out_w + transpose + size_t out_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_out_w.data(), static_cast(ob.ptr), out_bytes); + ctx.pool.upload(h.outW, static_cast(ob.ptr), out_bytes); + std::vector out_wt(d * V); + for (uint32_t r = 0; r < V; ++r) + for (uint32_t c = 0; c < d; ++c) + out_wt[c * V + r] = h.cpu_out_w[r * d + c]; + ctx.pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + // Per-layer + for (uint32_t l = 0; l < L; ++l) { + auto& lw = h.layers[l]; + auto upw = ffn_up_patterns[l].cast>(); + auto upb = ffn_up_biases[l].cast>(); + auto dnw = ffn_down_patterns[l].cast>(); + auto dnb = ffn_down_biases[l].cast>(); + auto gm = ln_gammas[l].cast>(); + auto bt = ln_betas[l].cast>(); + + auto upwb = upw.request(); + auto upbb = upb.request(); + auto dnwb = dnw.request(); + auto dnbb = dnb.request(); + auto gmb = gm.request(); + auto btb = bt.request(); 
+ + size_t uw_bytes = size_t(dF) * d * sizeof(float); + size_t ub_bytes = size_t(dF) * sizeof(float); + size_t dw_bytes = size_t(d) * dF * sizeof(float); + size_t db_bytes = size_t(d) * sizeof(float); + size_t ln_bytes = size_t(d) * sizeof(float); + + ctx.pool.upload(lw.ffnUpW, static_cast(upwb.ptr), uw_bytes); + ctx.pool.upload(lw.ffnUpB, static_cast(upbb.ptr), ub_bytes); + ctx.pool.upload(lw.ffnDownW, static_cast(dnwb.ptr), dw_bytes); + ctx.pool.upload(lw.ffnDownB, static_cast(dnbb.ptr), db_bytes); + ctx.pool.upload(lw.lnGamma, static_cast(gmb.ptr), ln_bytes); + ctx.pool.upload(lw.lnBeta, static_cast(btb.ptr), ln_bytes); + + std::memcpy(h.cpu_ffn_up_w[l].data(), static_cast(upwb.ptr), uw_bytes); + std::memcpy(h.cpu_ffn_up_b[l].data(), static_cast(upbb.ptr), ub_bytes); + std::memcpy(h.cpu_ffn_down_w[l].data(), static_cast(dnwb.ptr), dw_bytes); + std::memcpy(h.cpu_ffn_down_b[l].data(), static_cast(dnbb.ptr), db_bytes); + std::memcpy(h.cpu_ln_gamma[l].data(), static_cast(gmb.ptr), ln_bytes); + std::memcpy(h.cpu_ln_beta[l].data(), static_cast(btb.ptr), ln_bytes); + } + + ctx.waitIdle(); + }, + py::arg("device"), py::arg("handle"), + py::arg("embed_w"), py::arg("pos_w"), + py::arg("ffn_up_patterns"), py::arg("ffn_up_biases"), + py::arg("ffn_down_patterns"), py::arg("ffn_down_biases"), + py::arg("ln_gammas"), py::arg("ln_betas"), + py::arg("out_w"), + "Re-upload VSA-LM weights after optimizer step."); +} diff --git a/cpp/src/buffer_pool.cpp b/cpp/src/buffer_pool.cpp index d7d1cc2..e56a69d 100644 --- a/cpp/src/buffer_pool.cpp +++ b/cpp/src/buffer_pool.cpp @@ -56,7 +56,7 @@ BufferPool::~BufferPool() { vkDestroyCommandPool(dev, transferPool_, nullptr); } - // Destroy all pooled buffers + // Destroy all pooled buffers (host-visible bucket pool) for (auto& [bucketSize, vec] : buckets_) { for (auto& buf : vec) { if (buf.handle != VK_NULL_HANDLE) @@ -65,6 +65,24 @@ BufferPool::~BufferPool() { } buckets_.clear(); + // Destroy all pooled buffers (device-local bucket pool) + for 
(auto& [bucketSize, vec] : dlBuckets_) { + for (auto& buf : vec) { + if (buf.handle != VK_NULL_HANDLE) + vmaDestroyBuffer(allocator_, buf.handle, buf.allocation); + } + } + dlBuckets_.clear(); + + // Destroy all pooled buffers (readback bucket pool) + for (auto& [bucketSize, vec] : readbackBuckets_) { + for (auto& buf : vec) { + if (buf.handle != VK_NULL_HANDLE) + vmaDestroyBuffer(allocator_, buf.handle, buf.allocation); + } + } + readbackBuckets_.clear(); + if (allocator_ != VK_NULL_HANDLE) vmaDestroyAllocator(allocator_); } @@ -100,15 +118,33 @@ GrillyBuffer BufferPool::allocateBuffer(size_t bucketSize) { bufferInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - // Use device-local memory with host-visible fallback. - // On discrete GPUs with Resizable BAR, VMA places this in VRAM with - // CPU-visible mapping — best of both worlds (fast GPU access + memcpy upload). - // Without ReBAR, falls back to host-visible (system RAM). + // REQUIRE device-local memory. + // + // Background: ``preferredFlags`` is a soft hint VMA can ignore. Combined + // with ``HOST_ACCESS_SEQUENTIAL_WRITE_BIT``, on Windows + AMD/RDNA the + // auto-selector lands on memoryType[1] (HOST_VISIBLE | HOST_COHERENT + // *only* — no DEVICE_LOCAL, no HOST_CACHED). The buffer ends up in + // host-uncached BAR memory, and every GPU read becomes a single-byte + // PCIe transaction → measured 0.1 GB/s effective bandwidth, ~100x slower + // than CPU/numpy. See sandbox/vsa_lm/grilly_gpu_path_test.py for the + // smoking-gun profile (gc.relu on 4.7M elements: 757 ms). + // + // Using ``requiredFlags`` forces VMA to *only* consider memory types with + // DEVICE_LOCAL_BIT. On systems with Resizable BAR (the common case for + // modern AMD + Windows + AGESA 2020+), VMA picks the + // DEVICE_LOCAL | HOST_VISIBLE | HOST_COHERENT memory type — full VRAM + // bandwidth (~432 GB/s on RX 6750 XT) + CPU mapping for fast uploads. 
+ // + // On systems without ReBAR, this allocation will FAIL — which is the + // correct behavior, because the silent slow path was a footgun. Users + // without ReBAR should enable it in BIOS, or callers needing the legacy + // host-mapped path should use ``acquirePreferDeviceLocal`` which is + // explicitly soft-preferred for that case. VmaAllocationCreateInfo allocInfo{}; allocInfo.usage = VMA_MEMORY_USAGE_AUTO; allocInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT; - allocInfo.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + allocInfo.requiredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; GrillyBuffer buf{}; buf.bucketSize = bucketSize; @@ -156,7 +192,17 @@ void BufferPool::release(GrillyBuffer& buf) { std::lock_guard lock(mutex_); stats_.totalReleased++; - auto& vec = buckets_[buf.bucketSize]; + // Route to the right bucket pool based on memory class. The three pools + // MUST stay separate: + // - dlBuckets_ : DEVICE_LOCAL only (mappedPtr=null, GPU compute) + // - readbackBuckets_ : HOST_CACHED random read (CPU reads from GPU) + // - buckets_ : default WC sequential write (CPU writes to GPU) + // Picking up a DL buffer via ``acquire()`` would crash trying to memcpy + // into a null mappedPtr; picking up a WC buffer via ``acquireReadback`` + // would silently destroy CPU-read perf (the original bug we're fixing). + auto& vec = buf.deviceLocal ? dlBuckets_[buf.bucketSize] + : buf.readback ? readbackBuckets_[buf.bucketSize] + : buckets_[buf.bucketSize]; if (vec.size() < kMaxBuffersPerBucket) { vec.push_back(buf); } else { @@ -179,11 +225,32 @@ void BufferPool::release(GrillyBuffer& buf) { GrillyBuffer BufferPool::acquireDeviceLocal(size_t size) { size_t bucket = sizeToBucket(size); + std::lock_guard lock(mutex_); + stats_.totalAcquired++; + + // Try reuse from the DL bucket pool first (LIFO returns the same handle + // most often, which keeps the descriptor cache hitting on repeat calls). 
+ auto it = dlBuckets_.find(bucket); + if (it != dlBuckets_.end() && !it->second.empty()) { + GrillyBuffer buf = it->second.back(); + it->second.pop_back(); + buf.size = size; + stats_.hits++; + return buf; + } + + stats_.misses++; + stats_.allocations++; + VkBufferCreateInfo bufferInfo{}; bufferInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; bufferInfo.size = bucket; + // We need both TRANSFER_DST (for stage-in copies) and TRANSFER_SRC (for + // stage-out copies) since the staging pattern in cpp/src/ops/linear.cpp + // copies output back from DL → host-visible staging. bufferInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | - VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_TRANSFER_SRC_BIT; bufferInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; VmaAllocationCreateInfo allocInfo{}; @@ -193,6 +260,7 @@ GrillyBuffer BufferPool::acquireDeviceLocal(size_t size) { GrillyBuffer buf{}; buf.bucketSize = bucket; buf.size = size; + buf.deviceLocal = true; // routes to dlBuckets_ on release vkCheck(vmaCreateBuffer(allocator_, &bufferInfo, &allocInfo, &buf.handle, &buf.allocation, &buf.info), @@ -236,6 +304,22 @@ GrillyBuffer BufferPool::acquirePreferDeviceLocal(size_t size) { GrillyBuffer BufferPool::acquireReadback(size_t size) { size_t bucket = sizeToBucket(size); + std::lock_guard lock(mutex_); + stats_.totalAcquired++; + + // Try reuse from the readback bucket pool first. 
+ auto it = readbackBuckets_.find(bucket); + if (it != readbackBuckets_.end() && !it->second.empty()) { + GrillyBuffer buf = it->second.back(); + it->second.pop_back(); + buf.size = size; + stats_.hits++; + return buf; + } + + stats_.misses++; + stats_.allocations++; + VkBufferCreateInfo bufferInfo{}; bufferInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; bufferInfo.size = bucket; @@ -253,6 +337,7 @@ GrillyBuffer BufferPool::acquireReadback(size_t size) { GrillyBuffer buf{}; buf.bucketSize = bucket; buf.size = size; + buf.readback = true; // routes to readbackBuckets_ on release vkCheck(vmaCreateBuffer(allocator_, &bufferInfo, &allocInfo, &buf.handle, &buf.allocation, &buf.info), diff --git a/cpp/src/command_batch.cpp b/cpp/src/command_batch.cpp index ac5cb33..0ca6c01 100644 --- a/cpp/src/command_batch.cpp +++ b/cpp/src/command_batch.cpp @@ -131,6 +131,32 @@ void CommandBatch::barrier() { 0, nullptr); // image memory barriers } +void CommandBatch::transferComputeBarrier() { + if (!recording_) + return; + + // Bidirectional TRANSFER ↔ COMPUTE barrier. The src/dst access masks + // cover both edges (stage-in → compute and compute → stage-out) so the + // staging-buffer pattern in linear() can use a single method for both + // transitions without tracking direction. 
+ VkMemoryBarrier memBarrier{}; + memBarrier.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + memBarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_WRITE_BIT; + memBarrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT | + VK_ACCESS_SHADER_READ_BIT; + + vkCmdPipelineBarrier(cmd_, + VK_PIPELINE_STAGE_TRANSFER_BIT | + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, &memBarrier, + 0, nullptr, + 0, nullptr); +} + void CommandBatch::copyBuffer(const GrillyBuffer& src, GrillyBuffer& dst, size_t bytes) { if (!recording_) throw std::runtime_error("CommandBatch::copyBuffer() called without begin()"); diff --git a/cpp/src/io/grl_checkpoint.cpp b/cpp/src/io/grl_checkpoint.cpp new file mode 100644 index 0000000..ffdfebc --- /dev/null +++ b/cpp/src/io/grl_checkpoint.cpp @@ -0,0 +1,110 @@ +#include "grilly/io/grl_checkpoint.h" + +#include +#include +#include + +namespace grilly::io { + +static void pack_u64(uint8_t* dst, uint64_t v) { + std::memcpy(dst, &v, 8); +} + +static uint64_t unpack_u64(const uint8_t* src) { + uint64_t v; + std::memcpy(&v, src, 8); + return v; +} + +bool grl_write_file(const std::string& path, + const std::string& metadata_json, + const std::string& index_json, + const std::vector& payload) { + std::ofstream ofs(path, std::ios::binary | std::ios::trunc); + if (!ofs) + return false; + + const uint64_t meta_off = kGrlHeaderSize; + const uint64_t meta_len = metadata_json.size(); + const uint64_t idx_off = meta_off + meta_len; + const uint64_t idx_len = index_json.size(); + const uint64_t pay_off = idx_off + idx_len; + const uint64_t pay_len = payload.size(); + + uint8_t header[kGrlHeaderSize] = {}; + header[0] = 'G'; + header[1] = 'R'; + header[2] = 'L'; + header[3] = 'Y'; + uint16_t ver = kGrlFormatVersion; + uint16_t flags = 0; + uint32_t reserved = 0; + std::memcpy(header + 4, &ver, 2); + std::memcpy(header + 6, &flags, 2); + std::memcpy(header + 8, &reserved, 4); + 
pack_u64(header + 12, meta_off); + pack_u64(header + 20, meta_len); + pack_u64(header + 28, idx_off); + pack_u64(header + 36, idx_len); + pack_u64(header + 44, pay_off); + pack_u64(header + 52, pay_len); + + ofs.write(reinterpret_cast(header), kGrlHeaderSize); + ofs.write(metadata_json.data(), static_cast(metadata_json.size())); + ofs.write(index_json.data(), static_cast(index_json.size())); + if (!payload.empty()) + ofs.write(reinterpret_cast(payload.data()), + static_cast(payload.size())); + return static_cast(ofs); +} + +bool grl_read_file(const std::string& path, + std::string& metadata_json, + std::string& index_json, + std::vector& payload) { + std::ifstream ifs(path, std::ios::binary | std::ios::ate); + if (!ifs) + return false; + const auto end = ifs.tellg(); + ifs.seekg(0); + if (end < static_cast(kGrlHeaderSize)) + return false; + + uint8_t header[kGrlHeaderSize]; + ifs.read(reinterpret_cast(header), kGrlHeaderSize); + if (ifs.gcount() != static_cast(kGrlHeaderSize)) + return false; + if (header[0] != 'G' || header[1] != 'R' || header[2] != 'L' || header[3] != 'Y') + return false; + uint16_t ver = 0; + std::memcpy(&ver, header + 4, 2); + if (ver != kGrlFormatVersion) + throw std::runtime_error("GRL: unsupported format version"); + + const uint64_t meta_off = unpack_u64(header + 12); + const uint64_t meta_len = unpack_u64(header + 20); + const uint64_t idx_off = unpack_u64(header + 28); + const uint64_t idx_len = unpack_u64(header + 36); + const uint64_t pay_off = unpack_u64(header + 44); + const uint64_t pay_len = unpack_u64(header + 52); + + if (meta_off + meta_len != idx_off || idx_off + idx_len != pay_off) + return false; + if (end < static_cast(pay_off + pay_len)) + return false; + + metadata_json.resize(static_cast(meta_len)); + index_json.resize(static_cast(idx_len)); + payload.resize(static_cast(pay_len)); + + ifs.seekg(static_cast(meta_off)); + ifs.read(metadata_json.data(), static_cast(meta_len)); + ifs.read(index_json.data(), 
static_cast(idx_len)); + if (pay_len > 0) + ifs.read(reinterpret_cast(payload.data()), + static_cast(pay_len)); + + return static_cast(ifs) || (pay_len == 0 && meta_len == 0 && idx_len == 0); +} + +} // namespace grilly::io diff --git a/cpp/src/ops/activations.cpp b/cpp/src/ops/activations.cpp index f5f6704..1e75a04 100644 --- a/cpp/src/ops/activations.cpp +++ b/cpp/src/ops/activations.cpp @@ -1,6 +1,8 @@ #include "grilly/ops/activations.h" +#include #include +#include namespace grilly { namespace ops { @@ -22,16 +24,23 @@ static void activationForward( const float* input, float* output, uint32_t totalElements) { const size_t bytes = size_t(totalElements) * sizeof(float); - GrillyBuffer bufIn = pool.acquire(bytes); - GrillyBuffer bufOut = pool.acquire(bytes); + // Staging pattern (see linear.cpp for the long-form rationale): + // compute on DEVICE_LOCAL VRAM, stage-in via WC sequential-write + // memory, stage-out via HOST_CACHED random-read memory. Without this + // a 19 MB ReLU readback ran at 25 MB/s (~750 ms); with it the same + // op runs in single-digit ms. 
+ GrillyBuffer bufInDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufOutDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufInStage = pool.acquire(bytes); + GrillyBuffer bufOutStage = pool.acquireReadback(bytes); - pool.upload(bufIn, input, bytes); + pool.upload(bufInStage, input, bytes); PipelineEntry pipe = cache.getOrCreate(shaderName, 2, sizeof(uint32_t)); std::vector bufInfos = { - {bufIn.handle, 0, bytes}, - {bufOut.handle, 0, bytes}, + {bufInDL.handle, 0, bytes}, + {bufOutDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet(shaderName, bufInfos); @@ -39,15 +48,21 @@ static void activationForward( uint32_t gx = (totalElements + 255) / 256; batch.begin(); + batch.copyBuffer(bufInStage, bufInDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutDL, bufOutStage, bytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufOut, output, bytes); + pool.download(bufOutStage, output, bytes); - pool.release(bufIn); - pool.release(bufOut); + pool.release(bufInDL); + pool.release(bufOutDL); + pool.release(bufInStage); + pool.release(bufOutStage); } // ── Activation backward helper ──────────────────────────────────────────── @@ -62,19 +77,24 @@ static void activationBackward( float* gradInput, uint32_t totalElements) { const size_t bytes = size_t(totalElements) * sizeof(float); - GrillyBuffer bufGradOut = pool.acquire(bytes); - GrillyBuffer bufInput = pool.acquire(bytes); - GrillyBuffer bufGradIn = pool.acquire(bytes); + // Staging pattern: 2 stage-in (gradOutput, input), 1 stage-out (gradInput) + GrillyBuffer bufGradOutDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufGradInDL = pool.acquireDeviceLocal(bytes); - pool.upload(bufGradOut, gradOutput, bytes); - pool.upload(bufInput, input, bytes); + GrillyBuffer bufGradOutStage 
= pool.acquire(bytes); + GrillyBuffer bufInputStage = pool.acquire(bytes); + GrillyBuffer bufGradInStage = pool.acquireReadback(bytes); + + pool.upload(bufGradOutStage, gradOutput, bytes); + pool.upload(bufInputStage, input, bytes); PipelineEntry pipe = cache.getOrCreate(shaderName, 3, sizeof(uint32_t)); std::vector bufInfos = { - {bufGradOut.handle, 0, bytes}, - {bufInput.handle, 0, bytes}, - {bufGradIn.handle, 0, bytes}, + {bufGradOutDL.handle, 0, bytes}, + {bufInputDL.handle, 0, bytes}, + {bufGradInDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet(shaderName, bufInfos); @@ -82,16 +102,24 @@ static void activationBackward( uint32_t gx = (totalElements + 255) / 256; batch.begin(); + batch.copyBuffer(bufGradOutStage, bufGradOutDL, bytes); + batch.copyBuffer(bufInputStage, bufInputDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufGradInDL, bufGradInStage, bytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufGradIn, gradInput, bytes); + pool.download(bufGradInStage, gradInput, bytes); - pool.release(bufGradOut); - pool.release(bufInput); - pool.release(bufGradIn); + pool.release(bufGradOutDL); + pool.release(bufInputDL); + pool.release(bufGradInDL); + pool.release(bufGradOutStage); + pool.release(bufInputStage); + pool.release(bufGradInStage); } // ── Forward passes ──────────────────────────────────────────────────────── @@ -263,5 +291,114 @@ void softmaxBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache pool.release(bufGradIn); } +// ── Multiplication-free softmax (mf-softmax.glsl) ───────────────────────── +// Same 3-pass buffer layout as softmax; pass_type 0/1/2 with relu sums. 
+ +void mfSoftmax(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t batchSize, uint32_t seqLen, + uint32_t features) { + const uint32_t totalPositions = batchSize * seqLen; + const uint32_t totalElements = totalPositions * features; + const size_t dataBytes = size_t(totalElements) * sizeof(float); + const size_t auxBytes = size_t(totalPositions) * sizeof(float); + + GrillyBuffer bufInput = pool.acquire(dataBytes); + GrillyBuffer bufOutput = pool.acquire(dataBytes); + GrillyBuffer bufMax = pool.acquire(auxBytes); + GrillyBuffer bufSumPos = pool.acquire(auxBytes); + + pool.upload(bufInput, input, dataBytes); + + PipelineEntry pipe = + cache.getOrCreate("mf-softmax", 4, sizeof(SoftmaxParams)); + + std::vector bufInfos = { + {bufInput.handle, 0, dataBytes}, + {bufOutput.handle, 0, dataBytes}, + {bufMax.handle, 0, auxBytes}, + {bufSumPos.handle, 0, auxBytes}, + }; + VkDescriptorSet descSet = cache.allocDescriptorSet("mf-softmax", bufInfos); + + uint32_t gx = (totalPositions + 255) / 256; + + batch.begin(); + + SoftmaxParams push0{batchSize, seqLen, features, 0, features}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push0, + sizeof(push0)); + batch.barrier(); + + SoftmaxParams push1{batchSize, seqLen, features, 1, features}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push1, + sizeof(push1)); + batch.barrier(); + + SoftmaxParams push2{batchSize, seqLen, features, 2, features}; + uint32_t gx2 = (totalElements + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, + sizeof(push2)); + + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOutput, output, dataBytes); + + pool.release(bufInput); + pool.release(bufOutput); + pool.release(bufMax); + pool.release(bufSumPos); +} + +// Push layout must match mf-softplus.glsl: uint total_elements; float c; +struct MfSoftplusParams { + uint32_t totalElements; + float c; +}; + +void 
mfSoftplus(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements, + float beta) { + if (beta <= 0.f) { + throw std::invalid_argument("mfSoftplus: beta must be positive"); + } + const float c = 4.f / (beta * beta); + const size_t bytes = size_t(totalElements) * sizeof(float); + + GrillyBuffer bufIn = pool.acquire(bytes); + GrillyBuffer bufOut = pool.acquire(bytes); + + pool.upload(bufIn, input, bytes); + + PipelineEntry pipe = + cache.getOrCreate("mf-softplus", 2, sizeof(MfSoftplusParams)); + + std::vector bufInfos = { + {bufIn.handle, 0, bytes}, + {bufOut.handle, 0, bytes}, + }; + VkDescriptorSet descSet = cache.allocDescriptorSet("mf-softplus", bufInfos); + + MfSoftplusParams push{totalElements, c}; + uint32_t gx = (totalElements + 255) / 256; + + batch.begin(); + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, + sizeof(push)); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufOut, output, bytes); + + pool.release(bufIn); + pool.release(bufOut); +} + +void mfSigmoid(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + const float* input, float* output, uint32_t totalElements) { + activationForward("mf-sigmoid", batch, pool, cache, input, output, + totalElements); +} + } // namespace ops } // namespace grilly diff --git a/cpp/src/ops/batched_ops.cpp b/cpp/src/ops/batched_ops.cpp index 91701d7..f7e173e 100644 --- a/cpp/src/ops/batched_ops.cpp +++ b/cpp/src/ops/batched_ops.cpp @@ -10,6 +10,7 @@ #include "grilly/ops/batched_ops.h" #include "grilly/ops/linear.h" #include "grilly/ops/activations.h" +#include "grilly/ops/embedding.h" namespace grilly { namespace ops { @@ -188,6 +189,33 @@ void batchedAdd(CommandBatch& batch, PipelineCache& cache, &push, sizeof(push)); } +void batchedEmbeddingLookup(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& tokenIds, const GrillyBuffer& embedTable, + GrillyBuffer& output, uint32_t batchSize, 
uint32_t seqLen, + uint32_t vocabSize, uint32_t embeddingDim) { + + PipelineEntry pipe = cache.getOrCreate("embedding-lookup", 3, + sizeof(EmbeddingParams)); + + const uint32_t totalTokens = batchSize * seqLen; + size_t idBytes = size_t(totalTokens) * sizeof(uint32_t); + size_t embedBytes = size_t(vocabSize) * embeddingDim * sizeof(float); + size_t outBytes = size_t(totalTokens) * embeddingDim * sizeof(float); + + std::vector bufs = { + {tokenIds.handle, 0, idBytes}, + {embedTable.handle, 0, embedBytes}, + {output.handle, 0, outBytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("embedding-lookup", bufs); + + EmbeddingParams push{batchSize, seqLen, vocabSize, embeddingDim}; + uint32_t gx = (totalTokens + 255) / 256; + + batch.dispatch(pipe.pipeline, pipe.layout, desc, gx, 1, 1, + &push, sizeof(push)); +} + void batchedTiledLinear(CommandBatch& batch, PipelineCache& cache, const GrillyBuffer& input, const GrillyBuffer& weight, diff --git a/cpp/src/ops/embedding.cpp b/cpp/src/ops/embedding.cpp index 410f70d..eda4265 100644 --- a/cpp/src/ops/embedding.cpp +++ b/cpp/src/ops/embedding.cpp @@ -22,20 +22,28 @@ void embeddingLookup(CommandBatch& batch, BufferPool& pool, const size_t embedBytes = size_t(p.vocabSize) * p.embeddingDim * sizeof(float); const size_t outBytes = size_t(totalTokens) * p.embeddingDim * sizeof(float); - GrillyBuffer bufIds = pool.acquire(idBytes); - GrillyBuffer bufEmbed = pool.acquire(embedBytes); - GrillyBuffer bufOut = pool.acquire(outBytes); + // Staging pattern: 2 stage-in (token ids, embedding table), 1 stage-out + // (output token vectors). Embedding table is the big one (vocab × dim); + // for VSA-LM with vocab=8192, dim=384 it's 12 MB and re-uploaded every + // call. A weight cache (TODO) would amortize this across training steps. 
+ GrillyBuffer bufIdsDL = pool.acquireDeviceLocal(idBytes); + GrillyBuffer bufEmbedDL = pool.acquireDeviceLocal(embedBytes); + GrillyBuffer bufOutDL = pool.acquireDeviceLocal(outBytes); - pool.upload(bufIds, reinterpret_cast(tokenIds), idBytes); - pool.upload(bufEmbed, embeddings, embedBytes); + GrillyBuffer bufIdsStage = pool.acquire(idBytes); + GrillyBuffer bufEmbedStage = pool.acquire(embedBytes); + GrillyBuffer bufOutStage = pool.acquireReadback(outBytes); + + pool.upload(bufIdsStage, reinterpret_cast(tokenIds), idBytes); + pool.upload(bufEmbedStage, embeddings, embedBytes); PipelineEntry pipe = cache.getOrCreate("embedding-lookup", 3, sizeof(EmbeddingParams)); std::vector bufInfos = { - {bufIds.handle, 0, idBytes}, - {bufEmbed.handle, 0, embedBytes}, - {bufOut.handle, 0, outBytes}, + {bufIdsDL.handle, 0, idBytes}, + {bufEmbedDL.handle, 0, embedBytes}, + {bufOutDL.handle, 0, outBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("embedding-lookup", bufInfos); @@ -43,16 +51,24 @@ void embeddingLookup(CommandBatch& batch, BufferPool& pool, uint32_t gx = (totalTokens + 255) / 256; batch.begin(); + batch.copyBuffer(bufIdsStage, bufIdsDL, idBytes); + batch.copyBuffer(bufEmbedStage, bufEmbedDL, embedBytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutDL, bufOutStage, outBytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufOut, output, outBytes); + pool.download(bufOutStage, output, outBytes); - pool.release(bufIds); - pool.release(bufEmbed); - pool.release(bufOut); + pool.release(bufIdsDL); + pool.release(bufEmbedDL); + pool.release(bufOutDL); + pool.release(bufIdsStage); + pool.release(bufEmbedStage); + pool.release(bufOutStage); } } // namespace ops diff --git a/cpp/src/ops/layernorm.cpp b/cpp/src/ops/layernorm.cpp index b7bd3dd..7c26b70 100644 --- a/cpp/src/ops/layernorm.cpp +++ 
b/cpp/src/ops/layernorm.cpp @@ -36,38 +36,49 @@ void layernorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const size_t meanBytes = size_t(totalPositions) * sizeof(float); const size_t varBytes = meanBytes; - // Acquire 6 buffers - GrillyBuffer bufInput = pool.acquire(inputBytes); - GrillyBuffer bufOutput = pool.acquire(outputBytes); - GrillyBuffer bufGamma = pool.acquire(gammaBytes); - GrillyBuffer bufBeta = pool.acquire(betaBytes); - GrillyBuffer bufMean = pool.acquire(meanBytes); - GrillyBuffer bufVar = pool.acquire(varBytes); - - // Upload input data - pool.upload(bufInput, input, inputBytes); - pool.upload(bufGamma, gamma, gammaBytes); - pool.upload(bufBeta, beta, betaBytes); + // Staging pattern: 3 stage-in (input, gamma, beta), 1 stage-out (output). + // mean and var are intermediate buffers — pure DEVICE_LOCAL, never touched + // by the CPU, so no staging buffers needed for them. + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes); + GrillyBuffer bufOutputDL = pool.acquireDeviceLocal(outputBytes); + GrillyBuffer bufGammaDL = pool.acquireDeviceLocal(gammaBytes); + GrillyBuffer bufBetaDL = pool.acquireDeviceLocal(betaBytes); + GrillyBuffer bufMeanDL = pool.acquireDeviceLocal(meanBytes); + GrillyBuffer bufVarDL = pool.acquireDeviceLocal(varBytes); + + GrillyBuffer bufInputStage = pool.acquire(inputBytes); + GrillyBuffer bufGammaStage = pool.acquire(gammaBytes); + GrillyBuffer bufBetaStage = pool.acquire(betaBytes); + GrillyBuffer bufOutputStage = pool.acquireReadback(outputBytes); + + pool.upload(bufInputStage, input, inputBytes); + pool.upload(bufGammaStage, gamma, gammaBytes); + pool.upload(bufBetaStage, beta, betaBytes); // Get pipeline: 6 buffers, 20 bytes push constants PipelineEntry pipe = cache.getOrCreate("fnn-layernorm", 6, sizeof(LayerNormParams)); - // Descriptor set (same for all 3 passes — same buffers) + // Descriptor set bound to DEVICE_LOCAL buffers std::vector bufInfos = { - {bufInput.handle, 0, inputBytes}, - 
{bufOutput.handle, 0, outputBytes}, - {bufGamma.handle, 0, gammaBytes}, - {bufBeta.handle, 0, betaBytes}, - {bufMean.handle, 0, meanBytes}, - {bufVar.handle, 0, varBytes}, + {bufInputDL.handle, 0, inputBytes}, + {bufOutputDL.handle, 0, outputBytes}, + {bufGammaDL.handle, 0, gammaBytes}, + {bufBetaDL.handle, 0, betaBytes}, + {bufMeanDL.handle, 0, meanBytes}, + {bufVarDL.handle, 0, varBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-layernorm", bufInfos); - // 3-pass dispatch batch.begin(); + // Stage-in: copy 3 host-visible staging buffers to DL VRAM + batch.copyBuffer(bufInputStage, bufInputDL, inputBytes); + batch.copyBuffer(bufGammaStage, bufGammaDL, gammaBytes); + batch.copyBuffer(bufBetaStage, bufBetaDL, betaBytes); + batch.transferComputeBarrier(); + // Pass 0: compute mean LayerNormParams push0{batchSize, seqLen, features, eps, 0}; uint32_t gx0 = (totalPositions + 255) / 256; @@ -88,19 +99,25 @@ void layernorm(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &push2, sizeof(push2)); + // Stage-out: DL output → HOST_CACHED readback staging + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutputDL, bufOutputStage, outputBytes); + batch.submitDeferred(); batch.waitForCompletion(); - // Download result - pool.download(bufOutput, output, outputBytes); - - // Release all buffers - pool.release(bufInput); - pool.release(bufOutput); - pool.release(bufGamma); - pool.release(bufBeta); - pool.release(bufMean); - pool.release(bufVar); + pool.download(bufOutputStage, output, outputBytes); + + pool.release(bufInputDL); + pool.release(bufOutputDL); + pool.release(bufGammaDL); + pool.release(bufBetaDL); + pool.release(bufMeanDL); + pool.release(bufVarDL); + pool.release(bufInputStage); + pool.release(bufGammaStage); + pool.release(bufBetaStage); + pool.release(bufOutputStage); } // ── LayerNorm backward ─────────────────────────────────────────────────── @@ -119,47 +136,69 @@ 
void layernormBackward(CommandBatch& batch, BufferPool& pool, const size_t gammaBytes = size_t(features) * sizeof(float); const size_t posBytes = size_t(totalPositions) * sizeof(float); - // 8 buffers for backward - GrillyBuffer bufGradOut = pool.acquire(elemBytes); - GrillyBuffer bufInput = pool.acquire(elemBytes); - GrillyBuffer bufGamma = pool.acquire(gammaBytes); - GrillyBuffer bufMean = pool.acquire(posBytes); - GrillyBuffer bufVar = pool.acquire(posBytes); - GrillyBuffer bufGradIn = pool.acquire(elemBytes); - GrillyBuffer bufGradGamma = pool.acquire(gammaBytes); - GrillyBuffer bufGradBeta = pool.acquire(gammaBytes); - - pool.upload(bufGradOut, gradOutput, elemBytes); - pool.upload(bufInput, input, elemBytes); - pool.upload(bufGamma, gamma, gammaBytes); - pool.upload(bufMean, mean, posBytes); - pool.upload(bufVar, var, posBytes); - - // Zero grad outputs + // Staging pattern: 5 stage-in (gradOut, input, gamma, mean, var), + // 3 stage-out (gradIn, gradGamma, gradBeta) + GrillyBuffer bufGradOutDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufGammaDL = pool.acquireDeviceLocal(gammaBytes); + GrillyBuffer bufMeanDL = pool.acquireDeviceLocal(posBytes); + GrillyBuffer bufVarDL = pool.acquireDeviceLocal(posBytes); + GrillyBuffer bufGradInDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufGradGammaDL = pool.acquireDeviceLocal(gammaBytes); + GrillyBuffer bufGradBetaDL = pool.acquireDeviceLocal(gammaBytes); + + GrillyBuffer bufGradOutStage = pool.acquire(elemBytes); + GrillyBuffer bufInputStage = pool.acquire(elemBytes); + GrillyBuffer bufGammaStage = pool.acquire(gammaBytes); + GrillyBuffer bufMeanStage = pool.acquire(posBytes); + GrillyBuffer bufVarStage = pool.acquire(posBytes); + GrillyBuffer bufGradInStage = pool.acquireReadback(elemBytes); + GrillyBuffer bufGradGammaStage = pool.acquireReadback(gammaBytes); + GrillyBuffer bufGradBetaStage = pool.acquireReadback(gammaBytes); + + 
pool.upload(bufGradOutStage, gradOutput, elemBytes); + pool.upload(bufInputStage, input, elemBytes); + pool.upload(bufGammaStage, gamma, gammaBytes); + pool.upload(bufMeanStage, mean, posBytes); + pool.upload(bufVarStage, var, posBytes); + + // Zero the grad output staging buffers (atomic accumulation in shader). + // Reuse the readback stage buffers as upload-zeros source. std::vector zeros_elem(totalElements, 0.0f); std::vector zeros_feat(features, 0.0f); - pool.upload(bufGradIn, zeros_elem.data(), elemBytes); - pool.upload(bufGradGamma, zeros_feat.data(), gammaBytes); - pool.upload(bufGradBeta, zeros_feat.data(), gammaBytes); + pool.upload(bufGradInStage, zeros_elem.data(), elemBytes); + pool.upload(bufGradGammaStage, zeros_feat.data(), gammaBytes); + pool.upload(bufGradBetaStage, zeros_feat.data(), gammaBytes); PipelineEntry pipe = cache.getOrCreate("fnn-layernorm-backward", 8, sizeof(LayerNormParams)); std::vector bufInfos = { - {bufGradOut.handle, 0, elemBytes}, - {bufInput.handle, 0, elemBytes}, - {bufGamma.handle, 0, gammaBytes}, - {bufMean.handle, 0, posBytes}, - {bufVar.handle, 0, posBytes}, - {bufGradIn.handle, 0, elemBytes}, - {bufGradGamma.handle, 0, gammaBytes}, - {bufGradBeta.handle, 0, gammaBytes}, + {bufGradOutDL.handle, 0, elemBytes}, + {bufInputDL.handle, 0, elemBytes}, + {bufGammaDL.handle, 0, gammaBytes}, + {bufMeanDL.handle, 0, posBytes}, + {bufVarDL.handle, 0, posBytes}, + {bufGradInDL.handle, 0, elemBytes}, + {bufGradGammaDL.handle, 0, gammaBytes}, + {bufGradBetaDL.handle, 0, gammaBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet( "fnn-layernorm-backward", bufInfos); batch.begin(); + // Stage-in: copy all 8 stage buffers (5 inputs + 3 zeroed grads) to DL + batch.copyBuffer(bufGradOutStage, bufGradOutDL, elemBytes); + batch.copyBuffer(bufInputStage, bufInputDL, elemBytes); + batch.copyBuffer(bufGammaStage, bufGammaDL, gammaBytes); + batch.copyBuffer(bufMeanStage, bufMeanDL, posBytes); + batch.copyBuffer(bufVarStage, bufVarDL, 
posBytes); + batch.copyBuffer(bufGradInStage, bufGradInDL, elemBytes); + batch.copyBuffer(bufGradGammaStage, bufGradGammaDL, gammaBytes); + batch.copyBuffer(bufGradBetaStage, bufGradBetaDL, gammaBytes); + batch.transferComputeBarrier(); + // Pass 0: intermediate sums LayerNormParams push0{batchSize, seqLen, features, eps, 0}; batch.dispatch(pipe.pipeline, pipe.layout, descSet, @@ -180,21 +219,35 @@ void layernormBackward(CommandBatch& batch, BufferPool& pool, (features + 255) / 256, 1, 1, &push2, sizeof(push2)); + // Stage-out: copy 3 grad buffers from DL → HOST_CACHED readback staging + batch.transferComputeBarrier(); + batch.copyBuffer(bufGradInDL, bufGradInStage, elemBytes); + batch.copyBuffer(bufGradGammaDL, bufGradGammaStage, gammaBytes); + batch.copyBuffer(bufGradBetaDL, bufGradBetaStage, gammaBytes); + batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufGradIn, gradInput, elemBytes); - pool.download(bufGradGamma, gradGamma, gammaBytes); - pool.download(bufGradBeta, gradBeta, gammaBytes); - - pool.release(bufGradOut); - pool.release(bufInput); - pool.release(bufGamma); - pool.release(bufMean); - pool.release(bufVar); - pool.release(bufGradIn); - pool.release(bufGradGamma); - pool.release(bufGradBeta); + pool.download(bufGradInStage, gradInput, elemBytes); + pool.download(bufGradGammaStage, gradGamma, gammaBytes); + pool.download(bufGradBetaStage, gradBeta, gammaBytes); + + pool.release(bufGradOutDL); + pool.release(bufInputDL); + pool.release(bufGammaDL); + pool.release(bufMeanDL); + pool.release(bufVarDL); + pool.release(bufGradInDL); + pool.release(bufGradGammaDL); + pool.release(bufGradBetaDL); + pool.release(bufGradOutStage); + pool.release(bufInputStage); + pool.release(bufGammaStage); + pool.release(bufMeanStage); + pool.release(bufVarStage); + pool.release(bufGradInStage); + pool.release(bufGradGammaStage); + pool.release(bufGradBetaStage); } } // namespace ops diff --git a/cpp/src/ops/linear.cpp b/cpp/src/ops/linear.cpp index 
85bbada..3b0bd64 100644 --- a/cpp/src/ops/linear.cpp +++ b/cpp/src/ops/linear.cpp @@ -6,16 +6,24 @@ namespace grilly { namespace ops { -// ── GPU linear (port of fnn.py:1823-1976) ─────────────────────────────────── +// ── GPU linear with explicit DEVICE_LOCAL + staging pattern ──────────────── // -// In the Python backend, linear() makes ~12 ctypes FFI calls: -// acquire buffer × 4, upload × 3, get_or_create_pipeline, get_descriptor_set, -// dispatch_compute (which internally does: reset cmd, begin, bind pipeline, -// bind descriptors, push constants, dispatch, end, submit, wait fence), -// download, release × 4. +// On AMD/Windows even with Resizable BAR enabled, the DEVICE_LOCAL + +// HOST_VISIBLE memory type that VMA selects for ``BufferPool::acquire`` +// lands in WC-mapped memory that bypasses the GPU's L2 cache. Compute +// kernels reading from it run at ~0.05 GB/s — slower than a SATA SSD, +// roughly 0.04% of theoretical VRAM bandwidth (432 GB/s on RX 6750 XT). +// See sandbox/vsa_lm/grilly_gpu_path_test.py for the smoking-gun profile. // -// Here ALL of that is native C++ — zero Python crossings. The CommandBatch -// records everything into a single command buffer submission. +// The fix: compute buffers go through ``acquireDeviceLocal`` (DEVICE_LOCAL +// only, full cached VRAM, ~432 GB/s), and we move data in/out via small +// staging buffers from the regular pool. The staging buffers are slow for +// GPU compute reads but fine for ``vkCmdCopyBuffer`` transfers, which use +// the GPU's dedicated DMA engine and run at PCIe speed (~25 GB/s). +// +// All 3 staging-in copies, the compute dispatch, and the 1 staging-out +// copy are batched into a single command buffer with a single submit/wait, +// so the dispatch overhead is unchanged from the old fast-path. 
void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const float* x, const float* weights, const float* bias, @@ -27,55 +35,89 @@ void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, : sizeof(float); // dummy const size_t outputBytes = size_t(p.batchSeq) * p.outputDim * sizeof(float); - // ── Acquire buffers (bucket-rounded, persistent mapping) ── - GrillyBuffer bufInput = pool.acquire(inputBytes); - GrillyBuffer bufWeights = pool.acquire(weightBytes); - GrillyBuffer bufBias = pool.acquire(biasBytes); - GrillyBuffer bufOutput = pool.acquire(outputBytes); - - // ── Upload via persistent mapping (single memcpy each, no vkMap/vkUnmap) ── - pool.upload(bufInput, x, inputBytes); - pool.upload(bufWeights, weights, weightBytes); + // ── Acquire DEVICE_LOCAL compute buffers (cached VRAM, fast GPU access) ── + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes); + GrillyBuffer bufWeightsDL = pool.acquireDeviceLocal(weightBytes); + GrillyBuffer bufBiasDL = pool.acquireDeviceLocal(biasBytes); + GrillyBuffer bufOutputDL = pool.acquireDeviceLocal(outputBytes); + + // ── Acquire host-visible staging buffers ── + // Stage-IN buffers (CPU writes only): WC memory is fast for sequential + // memcpy at ~9 GB/s — pool.acquire() is the right choice. + GrillyBuffer bufInputStage = pool.acquire(inputBytes); + GrillyBuffer bufWeightsStage = pool.acquire(weightBytes); + GrillyBuffer bufBiasStage = pool.acquire(biasBytes); + // Stage-OUT buffer (CPU reads from it): MUST be HOST_CACHED random-read + // memory. WC memory is uncached on the CPU side and a 19 MB readback + // memcpy ran at ~25 MB/s (749 ms — slower than the 9 ms GPU compute!). + // HOST_CACHED via acquireReadback gives ~7 GB/s for the same memcpy. 
+ GrillyBuffer bufOutputStage = pool.acquireReadback(outputBytes); + + // ── memcpy CPU → staging (no GPU sync needed, persistent mapping) ── + pool.upload(bufInputStage, x, inputBytes); + pool.upload(bufWeightsStage, weights, weightBytes); if (p.hasBias && bias) { - pool.upload(bufBias, bias, p.outputDim * sizeof(float)); + pool.upload(bufBiasStage, bias, p.outputDim * sizeof(float)); } // ── Get or create pipeline (4 buffers, 16 bytes push constants) ── PipelineEntry pipe = cache.getOrCreate("fnn-linear", 4, 16); - // ── Allocate descriptor set (LRU cached) ── + // ── Allocate descriptor set bound to DEVICE_LOCAL buffers ── + // The descriptor cache keys on (shader_name, [(buffer.handle, range)]), + // so as long as the pool returns stable handles for repeated bucket + // requests (LIFO), this hits across calls. std::vector bufferInfos(4); - bufferInfos[0] = {bufInput.handle, 0, inputBytes}; - bufferInfos[1] = {bufWeights.handle, 0, weightBytes}; - bufferInfos[2] = {bufBias.handle, 0, biasBytes}; - bufferInfos[3] = {bufOutput.handle, 0, outputBytes}; + bufferInfos[0] = {bufInputDL.handle, 0, inputBytes}; + bufferInfos[1] = {bufWeightsDL.handle, 0, weightBytes}; + bufferInfos[2] = {bufBiasDL.handle, 0, biasBytes}; + bufferInfos[3] = {bufOutputDL.handle, 0, outputBytes}; VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-linear", bufferInfos); - // ── Push constants: batch_seq, input_dim, output_dim, has_bias ── - // Matches fnn-linear.glsl layout (4 × uint32 = 16 bytes). - // Python packs these via struct.pack("IIII", ...) — we just memcpy the struct. 
LinearParams pushData = p; - // ── Dispatch ── - // 2D workgroups at 16×16 (matches fnn-linear.glsl tiled GEMM) uint32_t gx = (p.outputDim + 15) / 16; uint32_t gy = (p.batchSeq + 15) / 16; + // ── Single command buffer: stage-in → barrier → compute → barrier → stage-out ── batch.begin(); + + // Stage-in: DMA copy host-visible staging → DEVICE_LOCAL VRAM + batch.copyBuffer(bufInputStage, bufInputDL, inputBytes); + batch.copyBuffer(bufWeightsStage, bufWeightsDL, weightBytes); + if (p.hasBias && bias) { + batch.copyBuffer(bufBiasStage, bufBiasDL, p.outputDim * sizeof(float)); + } + + // Barrier: TRANSFER_WRITE → SHADER_READ + batch.transferComputeBarrier(); + + // Compute on DEVICE_LOCAL buffers (full ~432 GB/s VRAM bandwidth) batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, &pushData, sizeof(pushData)); - batch.submitDeferred(); // Submit without waiting — GPU runs async - // ── Sync + download (only waits here, not at submit) ── + // Barrier: SHADER_WRITE → TRANSFER_READ + batch.transferComputeBarrier(); + + // Stage-out: DMA copy DEVICE_LOCAL → host-visible HOST_CACHED staging + batch.copyBuffer(bufOutputDL, bufOutputStage, outputBytes); + + batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufOutput, output, outputBytes); - // ── Release buffers back to pool ── - pool.release(bufInput); - pool.release(bufWeights); - pool.release(bufBias); - pool.release(bufOutput); + // ── memcpy staging → CPU output (HOST_CACHED, ~7 GB/s) ── + pool.download(bufOutputStage, output, outputBytes); + + // ── Release buffers back to their respective pools ── + pool.release(bufInputDL); + pool.release(bufWeightsDL); + pool.release(bufBiasDL); + pool.release(bufOutputDL); + pool.release(bufInputStage); + pool.release(bufWeightsStage); + pool.release(bufBiasStage); + pool.release(bufOutputStage); } // ── CPU reference using Eigen (for correctness verification) ──────────────── @@ -134,24 +176,43 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, 
PipelineCache& cache,
     const size_t gradWBytes = weightBytes;
     const size_t gradBiasBytes = size_t(p.outputDim) * sizeof(float);
 
-    GrillyBuffer bufGradOut = pool.acquire(gradOutBytes);
-    GrillyBuffer bufInput = pool.acquire(inputBytes);
-    GrillyBuffer bufWeights = pool.acquire(weightBytes);
-    GrillyBuffer bufGradIn = pool.acquire(gradInBytes);
-    GrillyBuffer bufGradW = pool.acquire(gradWBytes);
-    GrillyBuffer bufGradBias = pool.acquire(gradBiasBytes);
-
-    pool.upload(bufGradOut, gradOutput, gradOutBytes);
-    pool.upload(bufInput, input, inputBytes);
-    pool.upload(bufWeights, weights, weightBytes);
-
-    // Zero grad outputs
+    // Staging pattern: 3 stage-in (gradOut, input, weights),
+    // 3 stage-out (gradIn, gradW, gradBias). All compute on DEVICE_LOCAL.
+    GrillyBuffer bufGradOutDL = pool.acquireDeviceLocal(gradOutBytes);
+    GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes);
+    GrillyBuffer bufWeightsDL = pool.acquireDeviceLocal(weightBytes);
+    GrillyBuffer bufGradInDL = pool.acquireDeviceLocal(gradInBytes);
+    GrillyBuffer bufGradWDL = pool.acquireDeviceLocal(gradWBytes);
+    GrillyBuffer bufGradBiasDL = pool.acquireDeviceLocal(gradBiasBytes);
+
+    GrillyBuffer bufGradOutStage = pool.acquire(gradOutBytes);
+    GrillyBuffer bufInputStage = pool.acquire(inputBytes);
+    GrillyBuffer bufWeightsStage = pool.acquire(weightBytes);
+    GrillyBuffer bufGradInStage = pool.acquireReadback(gradInBytes);
+    GrillyBuffer bufGradWStage = pool.acquireReadback(gradWBytes);
+    GrillyBuffer bufGradBiasStage = pool.acquireReadback(gradBiasBytes);
+
+    pool.upload(bufGradOutStage, gradOutput, gradOutBytes);
+    pool.upload(bufInputStage, input, inputBytes);
+    pool.upload(bufWeightsStage, weights, weightBytes);
+
+    // The grad output buffers must start at zero — pass 1 (grad_weight) and
+    // pass 2 (grad_bias) accumulate via atomic adds in the shader. We zero
+    // them by uploading zeros through the staging buffers below and DMA-
+    // copying them to the DEVICE_LOCAL grad buffers in the same command
+    // buffer as the dispatches, so every call starts from a clean slate.
+    // (TODO: vkCmdFillBuffer could zero them GPU-side and skip the CPU
+    // memset + upload entirely.)
     std::vector<float> zerosIn(p.batchSeq * p.inputDim, 0.0f);
     std::vector<float> zerosW(p.outputDim * p.inputDim, 0.0f);
     std::vector<float> zerosB(p.outputDim, 0.0f);
-    pool.upload(bufGradIn, zerosIn.data(), gradInBytes);
-    pool.upload(bufGradW, zerosW.data(), gradWBytes);
-    pool.upload(bufGradBias, zerosB.data(), gradBiasBytes);
+    // Reuse the readback stage buffers as upload-zeros source: they're
+    // host-visible (HOST_CACHED), CPU-write is fine even though it's not
+    // optimal for sequential write — total bytes is small relative to GPU
+    // compute. Upload then DMA copy in the command buffer.
+    pool.upload(bufGradInStage, zerosIn.data(), gradInBytes);
+    pool.upload(bufGradWStage, zerosW.data(), gradWBytes);
+    pool.upload(bufGradBiasStage, zerosB.data(), gradBiasBytes);
 
     LinearBackwardParams bwdParams{p.batchSeq, p.inputDim, p.outputDim, 0};
 
@@ -159,18 +220,28 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache,
                                           sizeof(LinearBackwardParams));
 
     std::vector<VkDescriptorBufferInfo> bufInfos = {
-        {bufGradOut.handle, 0, gradOutBytes},
-        {bufInput.handle, 0, inputBytes},
-        {bufWeights.handle, 0, weightBytes},
-        {bufGradIn.handle, 0, gradInBytes},
-        {bufGradW.handle, 0, gradWBytes},
-        {bufGradBias.handle, 0, gradBiasBytes},
+        {bufGradOutDL.handle, 0, gradOutBytes},
+        {bufInputDL.handle, 0, inputBytes},
+        {bufWeightsDL.handle, 0, weightBytes},
+        {bufGradInDL.handle, 0, gradInBytes},
+        {bufGradWDL.handle, 0, gradWBytes},
+        {bufGradBiasDL.handle, 0, gradBiasBytes},
     };
     VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-linear-backward", bufInfos);
 
     batch.begin();
+    // Stage-in: copy all 6 staging buffers (3 inputs + 3 zeroed grads) to DL
+    batch.copyBuffer(bufGradOutStage, bufGradOutDL, gradOutBytes);
+    batch.copyBuffer(bufInputStage, bufInputDL, inputBytes);
+    batch.copyBuffer(bufWeightsStage, bufWeightsDL, 
weightBytes); + batch.copyBuffer(bufGradInStage, bufGradInDL, gradInBytes); + batch.copyBuffer(bufGradWStage, bufGradWDL, gradWBytes); + batch.copyBuffer(bufGradBiasStage, bufGradBiasDL, gradBiasBytes); + + batch.transferComputeBarrier(); + // Pass 0: grad_input bwdParams.passType = 0; uint32_t gx0 = (p.inputDim + 15) / 16; @@ -193,19 +264,32 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx2, 1, 1, &bwdParams, sizeof(bwdParams)); + batch.transferComputeBarrier(); + + // Stage-out: copy 3 grad buffers from DL → HOST_CACHED readback staging + batch.copyBuffer(bufGradInDL, bufGradInStage, gradInBytes); + batch.copyBuffer(bufGradWDL, bufGradWStage, gradWBytes); + batch.copyBuffer(bufGradBiasDL, bufGradBiasStage, gradBiasBytes); + batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufGradIn, gradInput, gradInBytes); - pool.download(bufGradW, gradWeight, gradWBytes); - pool.download(bufGradBias, gradBias, gradBiasBytes); - - pool.release(bufGradOut); - pool.release(bufInput); - pool.release(bufWeights); - pool.release(bufGradIn); - pool.release(bufGradW); - pool.release(bufGradBias); + pool.download(bufGradInStage, gradInput, gradInBytes); + pool.download(bufGradWStage, gradWeight, gradWBytes); + pool.download(bufGradBiasStage, gradBias, gradBiasBytes); + + pool.release(bufGradOutDL); + pool.release(bufInputDL); + pool.release(bufWeightsDL); + pool.release(bufGradInDL); + pool.release(bufGradWDL); + pool.release(bufGradBiasDL); + pool.release(bufGradOutStage); + pool.release(bufInputStage); + pool.release(bufWeightsStage); + pool.release(bufGradInStage); + pool.release(bufGradWStage); + pool.release(bufGradBiasStage); } // ── GPU dropout ────────────────────────────────────────────────────────── @@ -215,20 +299,25 @@ void dropout(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, uint32_t totalElements, float dropoutProb, bool isTraining) { const size_t 
bytes = size_t(totalElements) * sizeof(float); - GrillyBuffer bufInput = pool.acquire(bytes); - GrillyBuffer bufRandom = pool.acquire(bytes); - GrillyBuffer bufOutput = pool.acquire(bytes); + // Staging pattern: 2 stage-in (input, randomMask), 1 stage-out (output) + GrillyBuffer bufInputDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufRandomDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufOutputDL = pool.acquireDeviceLocal(bytes); + + GrillyBuffer bufInputStage = pool.acquire(bytes); + GrillyBuffer bufRandomStage = pool.acquire(bytes); + GrillyBuffer bufOutputStage = pool.acquireReadback(bytes); - pool.upload(bufInput, input, bytes); - pool.upload(bufRandom, randomMask, bytes); + pool.upload(bufInputStage, input, bytes); + pool.upload(bufRandomStage, randomMask, bytes); PipelineEntry pipe = cache.getOrCreate("fnn-dropout", 3, sizeof(DropoutParams)); std::vector bufInfos = { - {bufInput.handle, 0, bytes}, - {bufRandom.handle, 0, bytes}, - {bufOutput.handle, 0, bytes}, + {bufInputDL.handle, 0, bytes}, + {bufRandomDL.handle, 0, bytes}, + {bufOutputDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-dropout", bufInfos); @@ -236,16 +325,24 @@ void dropout(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, uint32_t gx = (totalElements + 255) / 256; batch.begin(); + batch.copyBuffer(bufInputStage, bufInputDL, bytes); + batch.copyBuffer(bufRandomStage, bufRandomDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push, sizeof(push)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufOutputDL, bufOutputStage, bytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufOutput, output, bytes); + pool.download(bufOutputStage, output, bytes); - pool.release(bufInput); - pool.release(bufRandom); - pool.release(bufOutput); + pool.release(bufInputDL); + pool.release(bufRandomDL); + pool.release(bufOutputDL); + pool.release(bufInputStage); + 
pool.release(bufRandomStage); + pool.release(bufOutputStage); } } // namespace ops diff --git a/cpp/src/ops/loss.cpp b/cpp/src/ops/loss.cpp index f87a219..6c9fcbf 100644 --- a/cpp/src/ops/loss.cpp +++ b/cpp/src/ops/loss.cpp @@ -24,24 +24,30 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, const size_t lossBytes = size_t(totalPositions) * sizeof(float); const size_t auxBytes = size_t(totalPositions) * sizeof(float); - GrillyBuffer bufLogits = pool.acquire(logitBytes); - GrillyBuffer bufTarget = pool.acquire(targetBytes); - GrillyBuffer bufLoss = pool.acquire(lossBytes); - GrillyBuffer bufMax = pool.acquire(auxBytes); - GrillyBuffer bufSumExp = pool.acquire(auxBytes); + // Staging pattern: 2 stage-in (logits, targets), 1 stage-out (losses). + // max and sumExp are intermediate DL-only buffers (CPU never sees them). + GrillyBuffer bufLogitsDL = pool.acquireDeviceLocal(logitBytes); + GrillyBuffer bufTargetDL = pool.acquireDeviceLocal(targetBytes); + GrillyBuffer bufLossDL = pool.acquireDeviceLocal(lossBytes); + GrillyBuffer bufMaxDL = pool.acquireDeviceLocal(auxBytes); + GrillyBuffer bufSumExpDL = pool.acquireDeviceLocal(auxBytes); - pool.upload(bufLogits, logits, logitBytes); - pool.upload(bufTarget, reinterpret_cast(targets), targetBytes); + GrillyBuffer bufLogitsStage = pool.acquire(logitBytes); + GrillyBuffer bufTargetStage = pool.acquire(targetBytes); + GrillyBuffer bufLossStage = pool.acquireReadback(lossBytes); + + pool.upload(bufLogitsStage, logits, logitBytes); + pool.upload(bufTargetStage, reinterpret_cast(targets), targetBytes); PipelineEntry pipe = cache.getOrCreate("loss-cross-entropy", 5, sizeof(CrossEntropyParams)); std::vector bufInfos = { - {bufLogits.handle, 0, logitBytes}, - {bufTarget.handle, 0, targetBytes}, - {bufLoss.handle, 0, lossBytes}, - {bufMax.handle, 0, auxBytes}, - {bufSumExp.handle, 0, auxBytes}, + {bufLogitsDL.handle, 0, logitBytes}, + {bufTargetDL.handle, 0, targetBytes}, + {bufLossDL.handle, 0, lossBytes}, + 
{bufMaxDL.handle, 0, auxBytes}, + {bufSumExpDL.handle, 0, auxBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("loss-cross-entropy", bufInfos); @@ -49,6 +55,9 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, uint32_t gx = (totalPositions + 255) / 256; batch.begin(); + batch.copyBuffer(bufLogitsStage, bufLogitsDL, logitBytes); + batch.copyBuffer(bufTargetStage, bufTargetDL, targetBytes); + batch.transferComputeBarrier(); // Pass 0: find max logit per position CrossEntropyParams push0 = p; @@ -70,16 +79,22 @@ void crossEntropyLoss(CommandBatch& batch, BufferPool& pool, batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &push2, sizeof(push2)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufLossDL, bufLossStage, lossBytes); + batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufLoss, losses, lossBytes); + pool.download(bufLossStage, losses, lossBytes); - pool.release(bufLogits); - pool.release(bufTarget); - pool.release(bufLoss); - pool.release(bufMax); - pool.release(bufSumExp); + pool.release(bufLogitsDL); + pool.release(bufTargetDL); + pool.release(bufLossDL); + pool.release(bufMaxDL); + pool.release(bufSumExpDL); + pool.release(bufLogitsStage); + pool.release(bufTargetStage); + pool.release(bufLossStage); } // ── Cross-entropy backward ─────────────────────────────────────────────── @@ -92,20 +107,25 @@ void crossEntropyBackward(CommandBatch& batch, BufferPool& pool, const size_t logitBytes = size_t(p.batchSize) * p.numClasses * sizeof(float); const size_t targetBytes = size_t(p.batchSize) * sizeof(uint32_t); - GrillyBuffer bufLogits = pool.acquire(logitBytes); - GrillyBuffer bufTarget = pool.acquire(targetBytes); - GrillyBuffer bufGrad = pool.acquire(logitBytes); + // Staging pattern: 2 stage-in (logits, targets), 1 stage-out (gradLogits) + GrillyBuffer bufLogitsDL = pool.acquireDeviceLocal(logitBytes); + GrillyBuffer bufTargetDL = pool.acquireDeviceLocal(targetBytes); + GrillyBuffer bufGradDL = 
pool.acquireDeviceLocal(logitBytes); + + GrillyBuffer bufLogitsStage = pool.acquire(logitBytes); + GrillyBuffer bufTargetStage = pool.acquire(targetBytes); + GrillyBuffer bufGradStage = pool.acquireReadback(logitBytes); - pool.upload(bufLogits, logits, logitBytes); - pool.upload(bufTarget, reinterpret_cast(targets), targetBytes); + pool.upload(bufLogitsStage, logits, logitBytes); + pool.upload(bufTargetStage, reinterpret_cast(targets), targetBytes); PipelineEntry pipe = cache.getOrCreate("cross-entropy-backward", 3, sizeof(CrossEntropyBackwardParams)); std::vector bufInfos = { - {bufLogits.handle, 0, logitBytes}, - {bufTarget.handle, 0, targetBytes}, - {bufGrad.handle, 0, logitBytes}, + {bufLogitsDL.handle, 0, logitBytes}, + {bufTargetDL.handle, 0, targetBytes}, + {bufGradDL.handle, 0, logitBytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("cross-entropy-backward", bufInfos); @@ -113,16 +133,24 @@ void crossEntropyBackward(CommandBatch& batch, BufferPool& pool, uint32_t gx = (p.batchSize + 255) / 256; batch.begin(); + batch.copyBuffer(bufLogitsStage, bufLogitsDL, logitBytes); + batch.copyBuffer(bufTargetStage, bufTargetDL, targetBytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufGradDL, bufGradStage, logitBytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufGrad, gradLogits, logitBytes); + pool.download(bufGradStage, gradLogits, logitBytes); - pool.release(bufLogits); - pool.release(bufTarget); - pool.release(bufGrad); + pool.release(bufLogitsDL); + pool.release(bufTargetDL); + pool.release(bufGradDL); + pool.release(bufLogitsStage); + pool.release(bufTargetStage); + pool.release(bufGradStage); } } // namespace ops diff --git a/cpp/src/ops/moe_forward.cpp b/cpp/src/ops/moe_forward.cpp new file mode 100644 index 0000000..6802e8c --- /dev/null +++ b/cpp/src/ops/moe_forward.cpp @@ -0,0 +1,752 @@ +/// 
moe_forward.cpp — fused MoE forward on GPU (router/blend CPU) + CPU backward. +#include "grilly/ops/moe_forward.h" +#include "grilly/ops/batched_ops.h" + +#include +#include +#include +#include +#include + +namespace grilly { +namespace ops { + +namespace { + +static std::unordered_map g_moe; +static int g_next_moe = 1; + +static void transpose_dd(const float* w, float* wt, uint32_t d) { + for (uint32_t r = 0; r < d; ++r) + for (uint32_t c = 0; c < d; ++c) + wt[c * d + r] = w[r * d + c]; +} + +static void transpose_vd(const float* w, float* wt, uint32_t v, uint32_t d) { + for (uint32_t r = 0; r < v; ++r) + for (uint32_t c = 0; c < d; ++c) + wt[c * v + r] = w[r * d + c]; +} + +static Eigen::VectorXf softmax_vec(const Eigen::VectorXf& x) { + float m = x.maxCoeff(); + Eigen::VectorXf e = (x.array() - m).exp(); + return e / e.sum(); +} + +static Eigen::VectorXf softmax_grad(const Eigen::VectorXf& p, + const Eigen::VectorXf& grad_p) { + float s = p.dot(grad_p); + return p.cwiseProduct(grad_p - Eigen::VectorXf::Constant(p.size(), s)); +} + +} // namespace + +MoeHandleCache& moe_get_cache(int handle) { + auto it = g_moe.find(handle); + if (it == g_moe.end()) + throw std::runtime_error("Invalid moe handle"); + return it->second; +} + +void moe_release(BufferPool& pool, int handle) { + auto it = g_moe.find(handle); + if (it == g_moe.end()) + return; + MoeHandleCache& h = it->second; + + pool.release(h.embedW); + pool.release(h.posW); + pool.release(h.outW); + pool.release(h.outWt); + for (auto& lw : h.layers) { + pool.release(lw.routerW); + pool.release(lw.routerB); + for (uint32_t e = 0; e < lw.expertW.size(); ++e) { + pool.release(lw.expertW[e]); + pool.release(lw.expertWt[e]); + } + pool.release(lw.expertPacked); + } + pool.release(h.bufIds); + pool.release(h.bufPosSlice); + pool.release(h.bufX); + for (auto& b : h.bufExpertOut) + pool.release(b); + pool.release(h.bufBlended); + pool.release(h.bufLogits); + + g_moe.erase(it); +} + +static void upload_dd_pair(BufferPool& 
pool, GrillyBuffer& wbuf, GrillyBuffer& wtbuf, + const float* w, uint32_t d) { + size_t bytes = size_t(d) * d * sizeof(float); + wbuf = pool.acquire(bytes); + pool.upload(wbuf, w, bytes); + std::vector wt(d * d); + transpose_dd(w, wt.data(), d); + wtbuf = pool.acquire(bytes); + pool.upload(wtbuf, wt.data(), bytes); +} + +int moe_upload(BufferPool& pool, + uint32_t vocab_size, uint32_t d_model, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w, + uint32_t n_layers, uint32_t n_experts) { + + if (n_layers == 0 || n_experts == 0 || d_model == 0 || vocab_size == 0 || max_seq == 0) + throw std::runtime_error("moe_upload: invalid dimensions"); + if (expert_ws.size() != size_t(n_layers) * n_experts) + throw std::runtime_error("moe_upload: expert_ws length mismatch"); + if (router_ws.size() != n_layers || router_bs.size() != n_layers) + throw std::runtime_error("moe_upload: router list length mismatch"); + + MoeHandleCache h; + h.vocab = vocab_size; + h.d = d_model; + h.maxSeq = max_seq; + h.nLayers = n_layers; + h.nExperts = n_experts; + + size_t embed_bytes = size_t(vocab_size) * d_model * sizeof(float); + h.cpu_embed.assign(embed_bytes / sizeof(float), 0.f); + std::memcpy(h.cpu_embed.data(), embed_w, embed_bytes); + h.embedW = pool.acquire(embed_bytes); + pool.upload(h.embedW, embed_w, embed_bytes); + + size_t pos_bytes = size_t(max_seq) * d_model * sizeof(float); + h.cpu_pos.assign(pos_bytes / sizeof(float), 0.f); + std::memcpy(h.cpu_pos.data(), pos_w, pos_bytes); + h.posW = pool.acquire(pos_bytes); + pool.upload(h.posW, pos_w, pos_bytes); + + size_t out_bytes = size_t(vocab_size) * d_model * sizeof(float); + h.cpu_out_w.assign(out_bytes / sizeof(float), 0.f); + std::memcpy(h.cpu_out_w.data(), out_w, out_bytes); + h.outW = pool.acquire(out_bytes); + pool.upload(h.outW, out_w, out_bytes); + std::vector out_wt(d_model * vocab_size); + 
transpose_vd(out_w, out_wt.data(), vocab_size, d_model); + h.outWt = pool.acquire(out_wt.size() * sizeof(float)); + pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + h.cpu_expert_w.resize(size_t(n_layers) * n_experts); + h.layers.resize(n_layers); + for (uint32_t l = 0; l < n_layers; ++l) { + auto& lw = h.layers[l]; + size_t rbytes = size_t(n_experts) * d_model * sizeof(float); + lw.routerW = pool.acquire(rbytes); + pool.upload(lw.routerW, router_ws[l], rbytes); + lw.routerB = pool.acquire(n_experts * sizeof(float)); + pool.upload(lw.routerB, router_bs[l], n_experts * sizeof(float)); + + h.cpu_router_w.emplace_back(n_experts * d_model); + std::memcpy(h.cpu_router_w.back().data(), router_ws[l], rbytes); + h.cpu_router_b.emplace_back(n_experts); + std::memcpy(h.cpu_router_b.back().data(), router_bs[l], + n_experts * sizeof(float)); + + lw.expertW.resize(n_experts); + lw.expertWt.resize(n_experts); + // Pack all experts contiguously for fused shader + size_t packed_size = size_t(n_experts) * d_model * d_model * sizeof(float); + std::vector packed(n_experts * d_model * d_model); + for (uint32_t e = 0; e < n_experts; ++e) { + const float* w = expert_ws[l * n_experts + e]; + upload_dd_pair(pool, lw.expertW[e], lw.expertWt[e], w, d_model); + h.cpu_expert_w[l * n_experts + e].assign(d_model * d_model, 0.f); + std::memcpy(h.cpu_expert_w[l * n_experts + e].data(), w, + d_model * d_model * sizeof(float)); + std::memcpy(packed.data() + e * d_model * d_model, w, + d_model * d_model * sizeof(float)); + } + lw.expertPacked = pool.acquire(packed_size); + pool.upload(lw.expertPacked, packed.data(), packed_size); + } + + size_t seq_d = size_t(max_seq) * d_model * sizeof(float); + h.bufIds = pool.acquire(max_seq * sizeof(uint32_t)); + h.bufPosSlice = pool.acquire(seq_d); + h.bufX = pool.acquire(seq_d); + h.bufBlended = pool.acquire(seq_d); + + // Activation buffers for backward (n_layers + 1: input to each layer + final) + h.bufActivations.resize(n_layers + 1); 
+ for (uint32_t l = 0; l <= n_layers; ++l) + h.bufActivations[l] = pool.acquire(seq_d); + h.fwd_router_weights.resize(n_layers); + + h.bufExpertOut.resize(n_experts); + for (uint32_t e = 0; e < n_experts; ++e) + h.bufExpertOut[e] = pool.acquire(seq_d); + h.bufLogits = pool.acquire(size_t(max_seq) * vocab_size * sizeof(float)); + + int hid = g_next_moe++; + g_moe[hid] = std::move(h); + return hid; +} + +void moe_update_weights(BufferPool& pool, MoeHandleCache& h, + const float* embed_w, const float* pos_w, + const std::vector& expert_ws, + const std::vector& router_ws, + const std::vector& router_bs, + const float* out_w) { + + uint32_t V = h.vocab; + uint32_t d = h.d; + uint32_t max_seq = h.maxSeq; + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + + size_t embed_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_embed.data(), embed_w, embed_bytes); + pool.upload(h.embedW, embed_w, embed_bytes); + + size_t pos_bytes = size_t(max_seq) * d * sizeof(float); + std::memcpy(h.cpu_pos.data(), pos_w, pos_bytes); + pool.upload(h.posW, pos_w, pos_bytes); + + size_t out_bytes = size_t(V) * d * sizeof(float); + std::memcpy(h.cpu_out_w.data(), out_w, out_bytes); + pool.upload(h.outW, out_w, out_bytes); + std::vector out_wt(d * V); + transpose_vd(out_w, out_wt.data(), V, d); + pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + for (uint32_t l = 0; l < L; ++l) { + auto& lw = h.layers[l]; + size_t rbytes = size_t(E) * d * sizeof(float); + pool.upload(lw.routerW, router_ws[l], rbytes); + pool.upload(lw.routerB, router_bs[l], E * sizeof(float)); + std::memcpy(h.cpu_router_w[l].data(), router_ws[l], rbytes); + std::memcpy(h.cpu_router_b[l].data(), router_bs[l], E * sizeof(float)); + + size_t packed_size = size_t(E) * d * d * sizeof(float); + std::vector packed(E * d * d); + for (uint32_t e = 0; e < E; ++e) { + const float* w = expert_ws[l * E + e]; + size_t dd = size_t(d) * d * sizeof(float); + pool.upload(lw.expertW[e], w, dd); + std::vector wt(d * 
d); + transpose_dd(w, wt.data(), d); + pool.upload(lw.expertWt[e], wt.data(), dd); + std::memcpy(h.cpu_expert_w[l * E + e].data(), w, dd); + std::memcpy(packed.data() + e * d * d, w, dd); + } + pool.upload(lw.expertPacked, packed.data(), packed_size); + } +} + +void moe_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + float* logits_out) { + + if (seq_len == 0 || seq_len > h.maxSeq) + throw std::runtime_error("moe_forward: invalid seq_len"); + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t V = h.vocab; + + pool.upload(h.bufIds, reinterpret_cast(input_ids), + S * sizeof(uint32_t)); + pool.upload(h.bufPosSlice, h.cpu_pos.data(), S * d * sizeof(float)); + + batch.begin(); + batchedEmbeddingLookup(batch, cache, h.bufIds, h.embedW, h.bufX, 1, S, h.vocab, d); + batch.barrier(); + batchedAdd(batch, cache, h.bufX, h.bufPosSlice, S * d); + batch.submitDeferred(); + batch.waitForCompletion(); + + size_t packed_size = size_t(h.nExperts) * d * d * sizeof(float); + bool has_router = cache.hasShader("moe-router"); + // Prefer vec4 path, fall back to scalar fused, then legacy + bool has_vec4 = cache.hasShader("moe-layer-fused-vec4") && (d % 4 == 0); + bool has_fused = has_vec4 || cache.hasShader("moe-layer-fused"); + + if (has_fused && has_router && h.nExperts == 4) { + // ══════════════════════════════════════════════════════════════ + // ALL-GPU path: router + experts in ONE command buffer submission + // ZERO CPU round-trips between layers. ONE fence wait total. + // Vec4 path: 4x memory bandwidth via 128-bit loads. 
+ // ══════════════════════════════════════════════════════════════ + + size_t scratch_size = (d + h.nExperts + h.nExperts) * sizeof(float); + GrillyBuffer bufScratch = pool.acquire(scratch_size); + uint32_t weights_offset = d + h.nExperts; + + struct RouterPush { uint32_t seq_len, d_model, n_experts, pass; }; + struct FusedPush { uint32_t seq_len, d_model, n_experts, weights_offset; }; + + PipelineEntry routerPipe = cache.getOrCreate("moe-router", 4, sizeof(RouterPush)); + + // Choose vec4 or scalar fused shader + const char* fused_name = has_vec4 ? "moe-layer-fused-vec4" : "moe-layer-fused"; + PipelineEntry fusedPipe = cache.getOrCreate(fused_name, 4, sizeof(FusedPush)); + + batch.begin(); + + // Save initial activation + batch.copyBuffer(h.bufX, h.bufActivations[0], S * d * sizeof(float)); + batch.barrier(); + + for (uint32_t l = 0; l < h.nLayers; ++l) { + auto& layer = h.layers[l]; + + // Router: 3 passes (mean, logits, softmax) + for (uint32_t pass = 0; pass < 3; ++pass) { + std::vector rbufs = { + {h.bufX.handle, 0, S * d * sizeof(float)}, + {layer.routerW.handle, 0, h.nExperts * d * sizeof(float)}, + {layer.routerB.handle, 0, h.nExperts * sizeof(float)}, + {bufScratch.handle, 0, scratch_size}, + }; + VkDescriptorSet rdesc = cache.allocDescriptorSet("moe-router", rbufs); + RouterPush rp{S, d, h.nExperts, pass}; + uint32_t wg = (pass == 0) ? (d + 255) / 256 : + (pass == 1) ? 
(h.nExperts + 255) / 256 : 1; + batch.dispatch(routerPipe.pipeline, routerPipe.layout, rdesc, + wg, 1, 1, &rp, sizeof(rp)); + batch.barrier(); + } + + // Fused expert layer (vec4 or scalar) + { + FusedPush fp{S, d, h.nExperts, weights_offset}; + std::vector fbufs = { + {h.bufX.handle, 0, S * d * sizeof(float)}, + {layer.expertPacked.handle, 0, packed_size}, + {h.bufBlended.handle, 0, S * d * sizeof(float)}, + {bufScratch.handle, 0, scratch_size}, + }; + VkDescriptorSet fdesc = cache.allocDescriptorSet(fused_name, fbufs); + + uint32_t gx, gy; + if (has_vec4) { + gx = ((d / 4) + 15) / 16; // vec4 columns + gy = (S + 15) / 16; + } else { + gx = (d + 15) / 16; + gy = (S + 15) / 16; + } + batch.dispatch(fusedPipe.pipeline, fusedPipe.layout, fdesc, + gx, gy, 1, &fp, sizeof(fp)); + batch.barrier(); + } + + std::swap(h.bufX, h.bufBlended); + + // Save activation for backward + batch.copyBuffer(h.bufX, h.bufActivations[l + 1], S * d * sizeof(float)); + batch.barrier(); + } + + // Output projection (same command buffer!) 
+ batchedLinear(batch, cache, h.bufX, h.outW, nullptr, h.bufLogits, S, d, V); + + // ONE submit, ONE fence wait for entire forward pass + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.release(bufScratch); + } else { + // ══════════════════════════════════════════════════════════════ + // Legacy CPU-router path (fallback) + // ══════════════════════════════════════════════════════════════ + std::vector x_cpu(S * d); + std::vector x_mean(d); + + for (uint32_t l = 0; l < h.nLayers; ++l) { + pool.download(h.bufX, x_cpu.data(), S * d * sizeof(float)); + for (uint32_t j = 0; j < d; ++j) { + float s = 0.f; + for (uint32_t i = 0; i < S; ++i) + s += x_cpu[i * d + j]; + x_mean[j] = s / float(S); + } + + Eigen::Map xmean(x_mean.data(), d); + Eigen::Map> Wr( + h.cpu_router_w[l].data(), h.nExperts, d); + Eigen::Map b(h.cpu_router_b[l].data(), h.nExperts); + Eigen::VectorXf logits_v = Wr * xmean + b; + Eigen::VectorXf p = softmax_vec(logits_v); + + auto& layer = h.layers[l]; + batch.begin(); + for (uint32_t e = 0; e < h.nExperts; ++e) { + batchedLinear(batch, cache, h.bufX, layer.expertW[e], nullptr, + h.bufExpertOut[e], S, d, d); + } + batch.submitDeferred(); + batch.waitForCompletion(); + + std::vector expert_flat(S * d * h.nExperts); + for (uint32_t e = 0; e < h.nExperts; ++e) + pool.download(h.bufExpertOut[e], expert_flat.data() + e * S * d, + S * d * sizeof(float)); + + std::vector blended(S * d, 0.f); + for (uint32_t e = 0; e < h.nExperts; ++e) { + float pe = p[e]; + for (size_t i = 0; i < S * d; ++i) + blended[i] += pe * expert_flat[e * S * d + i]; + } + pool.upload(h.bufBlended, blended.data(), S * d * sizeof(float)); + + batch.begin(); + batchedAdd(batch, cache, h.bufX, h.bufBlended, S * d); + batch.submitDeferred(); + batch.waitForCompletion(); + } + } + + size_t log_bytes = S * V * sizeof(float); + batch.begin(); + batchedLinear(batch, cache, h.bufX, h.outW, nullptr, h.bufLogits, S, d, V); + batch.submitDeferred(); + batch.waitForCompletion(); + 
pool.download(h.bufLogits, logits_out, log_bytes); +} + +MoeGradients moe_backward_cpu(const MoeHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits) { + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + if (S == 0 || S > h.maxSeq) + throw std::runtime_error("moe_backward: invalid seq_len"); + + using RowMajor = Eigen::Matrix; + + Eigen::MatrixXf X = Eigen::MatrixXf::Zero(S, d); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + if (tok < 0 || static_cast(tok) >= V) + throw std::runtime_error("moe_backward: token id out of range"); + for (uint32_t j = 0; j < d; ++j) { + X(s, j) = h.cpu_embed[tok * d + j] + h.cpu_pos[s * d + j]; + } + } + + std::vector Xs; + Xs.reserve(L + 1); + Xs.push_back(X); + + struct LayerTrace { + Eigen::VectorXf p; + Eigen::VectorXf xmean; + std::vector Y; + }; + std::vector trace(L); + + for (uint32_t l = 0; l < L; ++l) { + Eigen::MatrixXf& cur = Xs.back(); + Eigen::VectorXf xmean = cur.colwise().mean().transpose(); + + Eigen::Map> Wr( + h.cpu_router_w[l].data(), E, d); + Eigen::Map b(h.cpu_router_b[l].data(), E); + Eigen::VectorXf logits = Wr * xmean + b; + Eigen::VectorXf p = softmax_vec(logits); + + Eigen::MatrixXf blend = Eigen::MatrixXf::Zero(S, d); + trace[l].Y.resize(E); + for (uint32_t e = 0; e < E; ++e) { + Eigen::Map We(h.cpu_expert_w[l * E + e].data(), d, d); + trace[l].Y[e] = cur * We.transpose(); + blend += p[e] * trace[l].Y[e]; + } + trace[l].p = p; + trace[l].xmean = xmean; + + Xs.push_back(cur + blend); + } + + Eigen::MatrixXf Xfinal = Xs.back(); + + Eigen::Map Glog(grad_logits, S, V); + Eigen::Map Wo(h.cpu_out_w.data(), V, d); + + MoeGradients out; + out.grad_out_w.resize(V * d); + Eigen::MatrixXf grad_Wo_mat = Glog.transpose() * Xfinal; + Eigen::Map gwo(out.grad_out_w.data(), V, d); + gwo = grad_Wo_mat; + + out.grad_router_w.resize(L); + out.grad_router_b.resize(L); + out.grad_experts.resize(L * 
E); + for (auto& v : out.grad_router_w) + v.assign(E * d, 0.f); + for (auto& v : out.grad_router_b) + v.assign(E, 0.f); + for (auto& v : out.grad_experts) + v.assign(d * d, 0.f); + + // logits = Xfinal * Wo^T where Wo is (V, d), so grad_X = grad_logits * Wo. + Eigen::MatrixXf g = Glog * Wo; + + for (int li = static_cast(L) - 1; li >= 0; --li) { + uint32_t l = static_cast(li); + const Eigen::MatrixXf& cur = Xs[l]; + const LayerTrace& tr = trace[l]; + + Eigen::Map> Wr( + h.cpu_router_w[l].data(), E, d); + + Eigen::MatrixXf grad_blend = g; + Eigen::MatrixXf grad_Xl = Eigen::MatrixXf::Zero(S, d); + grad_Xl += grad_blend; + + Eigen::VectorXf grad_p(E); + for (uint32_t e = 0; e < E; ++e) + grad_p[e] = (tr.Y[e].array() * grad_blend.array()).sum(); + + Eigen::VectorXf grad_logits_r = softmax_grad(tr.p, grad_p); + Eigen::VectorXf grad_xm = Wr.transpose() * grad_logits_r; + + for (uint32_t s = 0; s < S; ++s) + grad_Xl.row(s) += (1.0f / float(S)) * grad_xm.transpose(); + + Eigen::Map> gWr( + out.grad_router_w[l].data(), E, d); + gWr = grad_logits_r * tr.xmean.transpose(); + + Eigen::Map gb(out.grad_router_b[l].data(), E); + gb = grad_logits_r; + + for (uint32_t e = 0; e < E; ++e) { + Eigen::MatrixXf grad_Y = tr.p[e] * grad_blend; + Eigen::Map We(h.cpu_expert_w[l * E + e].data(), d, d); + grad_Xl += grad_Y * We; + + Eigen::MatrixXf grad_We = grad_Y.transpose() * cur; + std::memcpy(out.grad_experts[l * E + e].data(), grad_We.data(), + d * d * sizeof(float)); + } + + g = grad_Xl; + } + + out.grad_embed.assign(h.vocab * d, 0.f); + out.grad_pos.assign(h.maxSeq * d, 0.f); + + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + for (uint32_t j = 0; j < d; ++j) { + out.grad_embed[tok * d + j] += g(s, j); + out.grad_pos[s * d + j] = g(s, j); + } + } + + return out; +} + +MoeGradients moe_backward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + MoeHandleCache& h, const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits) { + + // GPU 
backward shaders exist but Eigen CPU is faster for now + // (strided memory access pattern in backward shader kills GPU perf). + // TODO: transpose expert weights for backward-friendly layout. + return moe_backward_cpu(h, input_ids, seq_len, grad_logits); + + bool has_bwd_vec4 = cache.hasShader("moe-layer-backward-vec4") && (h.d % 4 == 0); + bool has_bwd = has_bwd_vec4 || cache.hasShader("moe-layer-backward"); + bool has_gw = cache.hasShader("moe-layer-grad-weight"); + + if (!has_bwd) + return moe_backward_cpu(h, input_ids, seq_len, grad_logits); + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + uint32_t E = h.nExperts; + size_t sd = S * d * sizeof(float); + size_t packed_size = size_t(E) * d * d * sizeof(float); + + // Re-compute router weights on CPU (tiny) + std::vector> router_p(L); + for (uint32_t l = 0; l < L; ++l) { + Eigen::Map> Wr( + h.cpu_router_w[l].data(), E, d); + Eigen::Map b(h.cpu_router_b[l].data(), E); + + std::vector act(S * d); + pool.download(h.bufActivations[l], act.data(), sd); + Eigen::VectorXf xmean = Eigen::VectorXf::Zero(d); + for (uint32_t i = 0; i < S; ++i) + for (uint32_t j = 0; j < d; ++j) + xmean[j] += act[i * d + j]; + xmean /= float(S); + + Eigen::VectorXf logits_v = Wr * xmean + b; + Eigen::VectorXf p = softmax_vec(logits_v); + router_p[l].resize(E); + for (uint32_t e = 0; e < E; ++e) + router_p[l][e] = p[e]; + } + + // Upload grad_logits + GrillyBuffer bufGL = pool.acquire(S * V * sizeof(float)); + pool.upload(bufGL, grad_logits, S * V * sizeof(float)); + + // Output projection backward: dx = grad_logits @ out_w (using fnn-linear) + GrillyBuffer bufDx = pool.acquire(sd); + batch.begin(); + batchedLinear(batch, cache, bufGL, h.outWt, nullptr, bufDx, S, V, d); + batch.submitDeferred(); + batch.waitForCompletion(); + + // grad_out_w = grad_logits.T @ x_final (CPU — vocab-sized, not worth GPU) + std::vector gl_cpu(S * V); + pool.download(bufGL, gl_cpu.data(), S * V * sizeof(float)); + 
std::vector xfinal(S * d); + pool.download(h.bufActivations[L], xfinal.data(), sd); + + MoeGradients out; + out.grad_out_w.resize(V * d, 0.f); + { + using RM = Eigen::Matrix; + Eigen::Map GL(gl_cpu.data(), S, V); + Eigen::Map XF(xfinal.data(), S, d); + RM GOW = GL.transpose() * XF; + std::memcpy(out.grad_out_w.data(), GOW.data(), V * d * sizeof(float)); + } + + // Per-layer backward on GPU + const char* bwd_name = has_bwd_vec4 ? "moe-layer-backward-vec4" : "moe-layer-backward"; + struct BwdPush { uint32_t seq_len, d_model, n_experts; float w0, w1, w2, w3; }; + + PipelineEntry bwdPipe = cache.getOrCreate(bwd_name, 3, sizeof(BwdPush)); + + GrillyBuffer bufGradIn = pool.acquire(sd); + GrillyBuffer bufGradW = pool.acquire(packed_size); + + out.grad_experts.resize(L * E); + out.grad_router_w.resize(L); + out.grad_router_b.resize(L); + + bool has_gw_shader = has_gw; + PipelineEntry gwPipe{}; + if (has_gw_shader) + gwPipe = cache.getOrCreate("moe-layer-grad-weight", 3, sizeof(BwdPush)); + + for (int32_t l = L - 1; l >= 0; --l) { + auto& layer = h.layers[l]; + float w0 = router_p[l][0], w1 = router_p[l][1]; + float w2 = router_p[l][2], w3 = router_p[l][3]; + BwdPush bp{S, d, E, w0, w1, w2, w3}; + + batch.begin(); + + // grad_input via backward shader + { + std::vector bufs = { + {bufDx.handle, 0, sd}, + {layer.expertPacked.handle, 0, packed_size}, + {bufGradIn.handle, 0, sd}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet(bwd_name, bufs); + + uint32_t gx, gy; + if (has_bwd_vec4) { + gx = ((d / 4) + 31) / 32; // vec4 columns, 32-wide workgroup + gy = (S + 7) / 8; // 8-high workgroup + } else { + gx = (d + 15) / 16; + gy = (S + 15) / 16; + } + batch.dispatch(bwdPipe.pipeline, bwdPipe.layout, desc, gx, gy, 1, &bp, sizeof(bp)); + batch.barrier(); + } + + // grad_weight via grad-weight shader (if available) + if (has_gw_shader) { + std::vector bufs = { + {bufDx.handle, 0, sd}, + {h.bufActivations[l].handle, 0, sd}, + {bufGradW.handle, 0, packed_size}, + }; + 
VkDescriptorSet desc = cache.allocDescriptorSet("moe-layer-grad-weight", bufs); + uint32_t gx = (d + 15) / 16; + uint32_t gy = (d + 15) / 16; + batch.dispatch(gwPipe.pipeline, gwPipe.layout, desc, gx, gy, 1, &bp, sizeof(bp)); + } + + batch.submitDeferred(); + batch.waitForCompletion(); + + // Download grad_W + if (has_gw_shader) { + std::vector gw_packed(E * d * d); + pool.download(bufGradW, gw_packed.data(), packed_size); + for (uint32_t e = 0; e < E; ++e) { + out.grad_experts[l * E + e].assign(d * d, 0.f); + std::memcpy(out.grad_experts[l * E + e].data(), + gw_packed.data() + e * d * d, d * d * sizeof(float)); + } + } else { + // CPU fallback for grad_W + std::vector dx_cpu(S * d); + pool.download(bufDx, dx_cpu.data(), sd); + std::vector act(S * d); + pool.download(h.bufActivations[l], act.data(), sd); + using RM = Eigen::Matrix; + Eigen::Map DX(dx_cpu.data(), S, d); + Eigen::Map ACT(act.data(), S, d); + RM GW = DX.transpose() * ACT; + for (uint32_t e = 0; e < E; ++e) { + out.grad_experts[l * E + e].resize(d * d); + float pe = router_p[l][e]; + for (size_t i = 0; i < d * d; ++i) + out.grad_experts[l * E + e][i] = pe * GW.data()[i]; + } + } + + // Router gradient (CPU, tiny — skip for now, zero placeholder) + out.grad_router_w[l].assign(E * d, 0.f); + out.grad_router_b[l].assign(E, 0.f); + + // Swap for next layer + std::swap(bufDx, bufGradIn); + } + + // Embedding gradient: scatter-add dx + std::vector dx_final(S * d); + pool.download(bufDx, dx_final.data(), sd); + + out.grad_embed.assign(V * d, 0.f); + out.grad_pos.assign(h.maxSeq * d, 0.f); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + if (tok >= 0 && static_cast(tok) < V) + for (uint32_t j = 0; j < d; ++j) + out.grad_embed[tok * d + j] += dx_final[s * d + j]; + for (uint32_t j = 0; j < d; ++j) + out.grad_pos[s * d + j] = dx_final[s * d + j]; + } + + pool.release(bufGL); + pool.release(bufDx); + pool.release(bufGradIn); + pool.release(bufGradW); + + return out; +} + +} // namespace ops +} 
// namespace grilly diff --git a/cpp/src/ops/optimizer.cpp b/cpp/src/ops/optimizer.cpp index d780531..38e622f 100644 --- a/cpp/src/ops/optimizer.cpp +++ b/cpp/src/ops/optimizer.cpp @@ -22,44 +22,66 @@ void adamUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const AdamParams& p) { const size_t bytes = size_t(p.totalWeights) * sizeof(float); - GrillyBuffer bufW = pool.acquire(bytes); - GrillyBuffer bufGrad = pool.acquire(bytes); - GrillyBuffer bufM = pool.acquire(bytes); - GrillyBuffer bufV = pool.acquire(bytes); - - pool.upload(bufW, weights, bytes); - pool.upload(bufGrad, grad, bytes); - pool.upload(bufM, m, bytes); - pool.upload(bufV, v, bytes); + // Adam updates all 4 buffers in-place: each is both stage-in and + // stage-out. Use HOST_CACHED readback staging for all 4 since we read + // them back to CPU at the end of every step. + GrillyBuffer bufWDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufGradDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufMDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufVDL = pool.acquireDeviceLocal(bytes); + + GrillyBuffer bufWStage = pool.acquireReadback(bytes); + GrillyBuffer bufGradStage = pool.acquireReadback(bytes); + GrillyBuffer bufMStage = pool.acquireReadback(bytes); + GrillyBuffer bufVStage = pool.acquireReadback(bytes); + + pool.upload(bufWStage, weights, bytes); + pool.upload(bufGradStage, grad, bytes); + pool.upload(bufMStage, m, bytes); + pool.upload(bufVStage, v, bytes); PipelineEntry pipe = cache.getOrCreate("adam-update", 4, sizeof(AdamParams)); std::vector bufInfos = { - {bufW.handle, 0, bytes}, - {bufGrad.handle, 0, bytes}, - {bufM.handle, 0, bytes}, - {bufV.handle, 0, bytes}, + {bufWDL.handle, 0, bytes}, + {bufGradDL.handle, 0, bytes}, + {bufMDL.handle, 0, bytes}, + {bufVDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("adam-update", bufInfos); uint32_t gx = (p.totalWeights + 255) / 256; batch.begin(); + batch.copyBuffer(bufWStage, bufWDL, bytes); + 
batch.copyBuffer(bufGradStage, bufGradDL, bytes); + batch.copyBuffer(bufMStage, bufMDL, bytes); + batch.copyBuffer(bufVStage, bufVDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufWDL, bufWStage, bytes); + batch.copyBuffer(bufGradDL, bufGradStage, bytes); + batch.copyBuffer(bufMDL, bufMStage, bytes); + batch.copyBuffer(bufVDL, bufVStage, bytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufW, weights, bytes); - pool.download(bufGrad, grad, bytes); - pool.download(bufM, m, bytes); - pool.download(bufV, v, bytes); - - pool.release(bufW); - pool.release(bufGrad); - pool.release(bufM); - pool.release(bufV); + pool.download(bufWStage, weights, bytes); + pool.download(bufGradStage, grad, bytes); + pool.download(bufMStage, m, bytes); + pool.download(bufVStage, v, bytes); + + pool.release(bufWDL); + pool.release(bufGradDL); + pool.release(bufMDL); + pool.release(bufVDL); + pool.release(bufWStage); + pool.release(bufGradStage); + pool.release(bufMStage); + pool.release(bufVStage); } // ── AdamW ──────────────────────────────────────────────────────────────── @@ -69,24 +91,30 @@ void adamwUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, const AdamWParams& p) { const size_t bytes = size_t(p.totalWeights) * sizeof(float); - GrillyBuffer bufW = pool.acquire(bytes); - GrillyBuffer bufGrad = pool.acquire(bytes); - GrillyBuffer bufM = pool.acquire(bytes); - GrillyBuffer bufV = pool.acquire(bytes); + // Same staging pattern as adamUpdate — all 4 buffers in-place updated. 
+ GrillyBuffer bufWDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufGradDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufMDL = pool.acquireDeviceLocal(bytes); + GrillyBuffer bufVDL = pool.acquireDeviceLocal(bytes); - pool.upload(bufW, weights, bytes); - pool.upload(bufGrad, grad, bytes); - pool.upload(bufM, m, bytes); - pool.upload(bufV, v, bytes); + GrillyBuffer bufWStage = pool.acquireReadback(bytes); + GrillyBuffer bufGradStage = pool.acquireReadback(bytes); + GrillyBuffer bufMStage = pool.acquireReadback(bytes); + GrillyBuffer bufVStage = pool.acquireReadback(bytes); + + pool.upload(bufWStage, weights, bytes); + pool.upload(bufGradStage, grad, bytes); + pool.upload(bufMStage, m, bytes); + pool.upload(bufVStage, v, bytes); PipelineEntry pipe = cache.getOrCreate("adamw-update", 4, sizeof(AdamWParams)); std::vector bufInfos = { - {bufW.handle, 0, bytes}, - {bufGrad.handle, 0, bytes}, - {bufM.handle, 0, bytes}, - {bufV.handle, 0, bytes}, + {bufWDL.handle, 0, bytes}, + {bufGradDL.handle, 0, bytes}, + {bufMDL.handle, 0, bytes}, + {bufVDL.handle, 0, bytes}, }; VkDescriptorSet descSet = cache.allocDescriptorSet("adamw-update", bufInfos); @@ -94,20 +122,34 @@ void adamwUpdate(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, uint32_t gx = (p.totalWeights + 255) / 256; batch.begin(); + batch.copyBuffer(bufWStage, bufWDL, bytes); + batch.copyBuffer(bufGradStage, bufGradDL, bytes); + batch.copyBuffer(bufMStage, bufMDL, bytes); + batch.copyBuffer(bufVStage, bufVDL, bytes); + batch.transferComputeBarrier(); batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, 1, 1, &p, sizeof(p)); + batch.transferComputeBarrier(); + batch.copyBuffer(bufWDL, bufWStage, bytes); + batch.copyBuffer(bufGradDL, bufGradStage, bytes); + batch.copyBuffer(bufMDL, bufMStage, bytes); + batch.copyBuffer(bufVDL, bufVStage, bytes); batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufW, weights, bytes); - pool.download(bufGrad, grad, bytes); - pool.download(bufM, 
m, bytes); - pool.download(bufV, v, bytes); - - pool.release(bufW); - pool.release(bufGrad); - pool.release(bufM); - pool.release(bufV); + pool.download(bufWStage, weights, bytes); + pool.download(bufGradStage, grad, bytes); + pool.download(bufMStage, m, bytes); + pool.download(bufVStage, v, bytes); + + pool.release(bufWDL); + pool.release(bufGradDL); + pool.release(bufMDL); + pool.release(bufVDL); + pool.release(bufWStage); + pool.release(bufGradStage); + pool.release(bufMStage); + pool.release(bufVStage); } } // namespace ops diff --git a/cpp/src/ops/vsa_lm_forward.cpp b/cpp/src/ops/vsa_lm_forward.cpp new file mode 100644 index 0000000..ee7e04b --- /dev/null +++ b/cpp/src/ops/vsa_lm_forward.cpp @@ -0,0 +1,656 @@ +/// vsa_lm_forward.cpp — fused VSA-LM forward (AdditionLinear FFN + MindForge LoRA) on GPU, +/// Eigen backward on CPU. +/// +/// FFN layers use the addition-linear shader (L1 distance, no matmul). +/// Output projection uses fnn-linear (standard matmul). +/// MindForge LoRA adapters are forged on CPU (tiny). 
+ +#include "grilly/ops/vsa_lm_forward.h" +#include "grilly/ops/batched_ops.h" + +#include +#include +#include +#include +#include + +namespace grilly { +namespace ops { + +namespace { + +static std::unordered_map g_vsa; +static int g_next_vsa = 1; + +static void transpose_vd(const float* w, float* wt, uint32_t v, uint32_t d) { + for (uint32_t r = 0; r < v; ++r) + for (uint32_t c = 0; c < d; ++c) + wt[c * v + r] = w[r * d + c]; +} + +struct AddLinPush { + uint32_t batch_size; + uint32_t in_features; + uint32_t out_features; + uint32_t use_bias; +}; + +struct LayerNormParams { + uint32_t batch_size; + uint32_t seq_len; + uint32_t features; + float eps; + uint32_t pass_type; +}; + +} // namespace + +VsaLmHandleCache& vsa_lm_get_cache(int handle) { + auto it = g_vsa.find(handle); + if (it == g_vsa.end()) + throw std::runtime_error("Invalid vsa_lm handle"); + return it->second; +} + +void vsa_lm_release(BufferPool& pool, int handle) { + auto it = g_vsa.find(handle); + if (it == g_vsa.end()) + return; + VsaLmHandleCache& h = it->second; + + pool.release(h.embedW); + pool.release(h.posW); + pool.release(h.outW); + pool.release(h.outWt); + + for (auto& lw : h.layers) { + pool.release(lw.ffnUpW); + pool.release(lw.ffnUpB); + pool.release(lw.ffnDownW); + pool.release(lw.ffnDownB); + pool.release(lw.lnGamma); + pool.release(lw.lnBeta); + } + + pool.release(h.bufIds); + pool.release(h.bufPosSlice); + pool.release(h.bufX); + pool.release(h.bufLnOut); + pool.release(h.bufFfnUp); + pool.release(h.bufSign); + pool.release(h.bufFfnDown); + pool.release(h.bufLogits); + pool.release(h.bufLnMean); + pool.release(h.bufLnVar); + + for (auto& b : h.bufActivations) + pool.release(b); + + g_vsa.erase(it); +} + +int vsa_lm_upload(BufferPool& pool, + uint32_t vocab, uint32_t d, uint32_t d_ffn, uint32_t max_seq, + const float* embed_w, const float* pos_w, + const std::vector& ffn_up_ws, + const std::vector& ffn_up_bs, + const std::vector& ffn_down_ws, + const std::vector& ffn_down_bs, + 
const std::vector& ln_gammas, + const std::vector& ln_betas, + const float* out_w, + uint32_t n_layers) { + + if (n_layers == 0 || d == 0 || d_ffn == 0 || vocab == 0 || max_seq == 0) + throw std::runtime_error("vsa_lm_upload: invalid dimensions"); + if (ffn_up_ws.size() != n_layers || ffn_up_bs.size() != n_layers || + ffn_down_ws.size() != n_layers || ffn_down_bs.size() != n_layers || + ln_gammas.size() != n_layers || ln_betas.size() != n_layers) + throw std::runtime_error("vsa_lm_upload: list length mismatch"); + + VsaLmHandleCache h; + h.vocab = vocab; + h.d = d; + h.dFfn = d_ffn; + h.maxSeq = max_seq; + h.nLayers = n_layers; + + // Embedding + size_t embed_bytes = size_t(vocab) * d * sizeof(float); + h.cpu_embed.resize(vocab * d); + std::memcpy(h.cpu_embed.data(), embed_w, embed_bytes); + h.embedW = pool.acquire(embed_bytes); + pool.upload(h.embedW, embed_w, embed_bytes); + + // Positional + size_t pos_bytes = size_t(max_seq) * d * sizeof(float); + h.cpu_pos.resize(max_seq * d); + std::memcpy(h.cpu_pos.data(), pos_w, pos_bytes); + h.posW = pool.acquire(pos_bytes); + pool.upload(h.posW, pos_w, pos_bytes); + + // Output projection + size_t out_bytes = size_t(vocab) * d * sizeof(float); + h.cpu_out_w.resize(vocab * d); + std::memcpy(h.cpu_out_w.data(), out_w, out_bytes); + h.outW = pool.acquire(out_bytes); + pool.upload(h.outW, out_w, out_bytes); + std::vector out_wt(d * vocab); + transpose_vd(out_w, out_wt.data(), vocab, d); + h.outWt = pool.acquire(out_wt.size() * sizeof(float)); + pool.upload(h.outWt, out_wt.data(), out_wt.size() * sizeof(float)); + + // Per-layer weights + h.layers.resize(n_layers); + h.cpu_ffn_up_w.resize(n_layers); + h.cpu_ffn_up_b.resize(n_layers); + h.cpu_ffn_down_w.resize(n_layers); + h.cpu_ffn_down_b.resize(n_layers); + h.cpu_ln_gamma.resize(n_layers); + h.cpu_ln_beta.resize(n_layers); + + for (uint32_t l = 0; l < n_layers; ++l) { + auto& lw = h.layers[l]; + + size_t up_w_bytes = size_t(d_ffn) * d * sizeof(float); + size_t up_b_bytes = 
size_t(d_ffn) * sizeof(float); + size_t down_w_bytes = size_t(d) * d_ffn * sizeof(float); + size_t down_b_bytes = size_t(d) * sizeof(float); + size_t ln_bytes = size_t(d) * sizeof(float); + + lw.ffnUpW = pool.acquire(up_w_bytes); + pool.upload(lw.ffnUpW, ffn_up_ws[l], up_w_bytes); + lw.ffnUpB = pool.acquire(up_b_bytes); + pool.upload(lw.ffnUpB, ffn_up_bs[l], up_b_bytes); + + lw.ffnDownW = pool.acquire(down_w_bytes); + pool.upload(lw.ffnDownW, ffn_down_ws[l], down_w_bytes); + lw.ffnDownB = pool.acquire(down_b_bytes); + pool.upload(lw.ffnDownB, ffn_down_bs[l], down_b_bytes); + + lw.lnGamma = pool.acquire(ln_bytes); + pool.upload(lw.lnGamma, ln_gammas[l], ln_bytes); + lw.lnBeta = pool.acquire(ln_bytes); + pool.upload(lw.lnBeta, ln_betas[l], ln_bytes); + + h.cpu_ffn_up_w[l].resize(d_ffn * d); + std::memcpy(h.cpu_ffn_up_w[l].data(), ffn_up_ws[l], up_w_bytes); + h.cpu_ffn_up_b[l].resize(d_ffn); + std::memcpy(h.cpu_ffn_up_b[l].data(), ffn_up_bs[l], up_b_bytes); + h.cpu_ffn_down_w[l].resize(d * d_ffn); + std::memcpy(h.cpu_ffn_down_w[l].data(), ffn_down_ws[l], down_w_bytes); + h.cpu_ffn_down_b[l].resize(d); + std::memcpy(h.cpu_ffn_down_b[l].data(), ffn_down_bs[l], down_b_bytes); + h.cpu_ln_gamma[l].resize(d); + std::memcpy(h.cpu_ln_gamma[l].data(), ln_gammas[l], ln_bytes); + h.cpu_ln_beta[l].resize(d); + std::memcpy(h.cpu_ln_beta[l].data(), ln_betas[l], ln_bytes); + } + + // Working buffers + size_t sd = size_t(max_seq) * d * sizeof(float); + size_t s_dffn = size_t(max_seq) * d_ffn * sizeof(float); + size_t logit_sz = size_t(max_seq) * vocab * sizeof(float); + + h.bufIds = pool.acquire(max_seq * sizeof(uint32_t)); + h.bufPosSlice= pool.acquire(sd); + h.bufX = pool.acquire(sd); + h.bufLnOut = pool.acquire(sd); + h.bufFfnUp = pool.acquire(s_dffn); + h.bufSign = pool.acquire(s_dffn); + h.bufFfnDown = pool.acquire(sd); + h.bufLogits = pool.acquire(logit_sz); + h.bufLnMean = pool.acquire(max_seq * sizeof(float)); + h.bufLnVar = pool.acquire(max_seq * sizeof(float)); + + 
h.bufActivations.resize(n_layers + 1); + for (uint32_t l = 0; l <= n_layers; ++l) + h.bufActivations[l] = pool.acquire(sd); + + int hid = g_next_vsa++; + g_vsa[hid] = std::move(h); + return hid; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helper: record addition-linear dispatch into batch (no begin/submit) +// ═══════════════════════════════════════════════════════════════════════════ + +static void batchedAdditionLinear(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& input, const GrillyBuffer& weight, + const GrillyBuffer& bias, GrillyBuffer& output, + uint32_t S, uint32_t d_in, uint32_t d_out) { + + PipelineEntry pipe = cache.getOrCreate("addition-linear", 4, sizeof(AddLinPush)); + + size_t inBytes = size_t(S) * d_in * sizeof(float); + size_t wBytes = size_t(d_out) * d_in * sizeof(float); + size_t bBytes = size_t(d_out) * sizeof(float); + size_t outBytes = size_t(S) * d_out * sizeof(float); + + std::vector bufs = { + {input.handle, 0, inBytes}, + {weight.handle, 0, wBytes}, + {bias.handle, 0, bBytes}, + {output.handle, 0, outBytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("addition-linear", bufs); + + AddLinPush push{S, d_in, d_out, 1}; + uint32_t total = S * d_out; + uint32_t gx = (total + 255) / 256; + + batch.dispatch(pipe.pipeline, pipe.layout, desc, gx, 1, 1, + &push, sizeof(push)); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helper: record sign activation into batch +// ═══════════════════════════════════════════════════════════════════════════ + +static void batchedSignActivation(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& input, GrillyBuffer& output, + uint32_t totalElements) { + + PipelineEntry pipe = cache.getOrCreate("sign-activation", 2, sizeof(uint32_t)); + + size_t bytes = size_t(totalElements) * sizeof(float); + std::vector bufs = { + {input.handle, 0, bytes}, + {output.handle, 0, bytes}, + }; + 
VkDescriptorSet desc = cache.allocDescriptorSet("sign-activation", bufs); + + uint32_t push = totalElements; + uint32_t gx = (totalElements + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gx, 1, 1, + &push, sizeof(push)); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Helper: record 3-pass layernorm into batch (pre-allocated mean/var bufs) +// ═══════════════════════════════════════════════════════════════════════════ + +static void batchedLayerNorm(CommandBatch& batch, PipelineCache& cache, + const GrillyBuffer& input, GrillyBuffer& output, + const GrillyBuffer& gamma, const GrillyBuffer& beta, + GrillyBuffer& meanBuf, GrillyBuffer& varBuf, + uint32_t S, uint32_t features) { + + PipelineEntry pipe = cache.getOrCreate("fnn-layernorm", 6, sizeof(LayerNormParams)); + + size_t elemBytes = size_t(S) * features * sizeof(float); + size_t paramBytes = size_t(features) * sizeof(float); + size_t statBytes = size_t(S) * sizeof(float); + + std::vector bufs = { + {input.handle, 0, elemBytes}, + {output.handle, 0, elemBytes}, + {gamma.handle, 0, paramBytes}, + {beta.handle, 0, paramBytes}, + {meanBuf.handle, 0, statBytes}, + {varBuf.handle, 0, statBytes}, + }; + VkDescriptorSet desc = cache.allocDescriptorSet("fnn-layernorm", bufs); + + // batch_size=1, seq_len=S + LayerNormParams p0{1, S, features, 1e-5f, 0}; + uint32_t gxPos = (S + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gxPos, 1, 1, &p0, sizeof(p0)); + batch.barrier(); + + LayerNormParams p1{1, S, features, 1e-5f, 1}; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gxPos, 1, 1, &p1, sizeof(p1)); + batch.barrier(); + + LayerNormParams p2{1, S, features, 1e-5f, 2}; + uint32_t gxAll = (S * features + 255) / 256; + batch.dispatch(pipe.pipeline, pipe.layout, desc, gxAll, 1, 1, &p2, sizeof(p2)); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// vsa_lm_forward_gpu +// 
═══════════════════════════════════════════════════════════════════════════ + +void vsa_lm_forward_gpu(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, + VsaLmHandleCache& h, const int32_t* input_ids, + uint32_t seq_len, float* logits_out) { + + if (seq_len == 0 || seq_len > h.maxSeq) + throw std::runtime_error("vsa_lm_forward: invalid seq_len"); + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t V = h.vocab; + + // Upload token IDs and position slice + pool.upload(h.bufIds, reinterpret_cast(input_ids), + S * sizeof(uint32_t)); + pool.upload(h.bufPosSlice, h.cpu_pos.data(), S * d * sizeof(float)); + + // Phase 1: embedding lookup + position add + batch.begin(); + batchedEmbeddingLookup(batch, cache, h.bufIds, h.embedW, h.bufX, 1, S, h.vocab, d); + batch.barrier(); + batchedAdd(batch, cache, h.bufX, h.bufPosSlice, S * d); + batch.barrier(); + batch.copyBuffer(h.bufX, h.bufActivations[0], S * d * sizeof(float)); + batch.submitDeferred(); + batch.waitForCompletion(); + + // Phase 2: per-layer forward + for (uint32_t l = 0; l < h.nLayers; ++l) { + auto& lw = h.layers[l]; + + batch.begin(); + + // (a) LayerNorm + batchedLayerNorm(batch, cache, h.bufX, h.bufLnOut, + lw.lnGamma, lw.lnBeta, + h.bufLnMean, h.bufLnVar, S, d); + batch.barrier(); + + // (d) Addition-linear up: (S, d) → (S, d_ffn) + batchedAdditionLinear(batch, cache, h.bufLnOut, lw.ffnUpW, lw.ffnUpB, + h.bufFfnUp, S, d, dF); + batch.barrier(); + + // (e) Sign activation + batchedSignActivation(batch, cache, h.bufFfnUp, h.bufSign, S * dF); + batch.barrier(); + + // (f) Addition-linear down: (S, d_ffn) → (S, d) + batchedAdditionLinear(batch, cache, h.bufSign, lw.ffnDownW, lw.ffnDownB, + h.bufFfnDown, S, dF, d); + batch.barrier(); + + // (h) Residual: x += ffn_down + batchedAdd(batch, cache, h.bufX, h.bufFfnDown, S * d); + batch.barrier(); + + // (i) Save activation + batch.copyBuffer(h.bufX, h.bufActivations[l + 1], S * d * sizeof(float)); + + batch.submitDeferred(); + 
batch.waitForCompletion(); + } + + // Phase 3: output projection — x @ out_w.T / sqrt(d) + batch.begin(); + batchedLinear(batch, cache, h.bufX, h.outW, nullptr, h.bufLogits, S, d, V); + batch.submitDeferred(); + batch.waitForCompletion(); + + // Download and scale + std::vector raw(S * V); + pool.download(h.bufLogits, raw.data(), S * V * sizeof(float)); + float scale = 1.0f / std::sqrt(static_cast(d)); + for (size_t i = 0; i < size_t(S) * V; ++i) + logits_out[i] = raw[i] * scale; +} + +// ═══════════════════════════════════════════════════════════════════════════ +// vsa_lm_backward_cpu — Eigen-based backward matching moe_backward_cpu pattern. +// +// AdditionLinear backward: +// grad_input[row, k] = -sum_col( grad_out[row, col] * sign(W[col, k] - x[row, k]) ) +// grad_W[col, k] = -sum_row( grad_out[row, col] * sign(W[col, k] - x[row, k]) ) +// grad_b[col] = sum_row( grad_out[row, col] ) +// ═══════════════════════════════════════════════════════════════════════════ + +VsaLmGradients vsa_lm_backward_cpu(const VsaLmHandleCache& h, + const int32_t* input_ids, uint32_t seq_len, + const float* grad_logits) { + + uint32_t S = seq_len; + uint32_t d = h.d; + uint32_t dF = h.dFfn; + uint32_t V = h.vocab; + uint32_t L = h.nLayers; + + if (S == 0 || S > h.maxSeq) + throw std::runtime_error("vsa_lm_backward: invalid seq_len"); + + using RM = Eigen::Matrix; + + // Output projection backward: dx = grad_logits @ out_w * scale + float scale = 1.0f / std::sqrt(static_cast(d)); + Eigen::Map GL(grad_logits, S, V); + Eigen::Map OW(h.cpu_out_w.data(), V, d); + + RM scaledGL = GL * scale; // (S, V) scaled + RM dx = scaledGL * OW; // (S, d) + + VsaLmGradients out; + + // grad_out_w = scaledGL.T @ x_final + // Need final activation — download it. We stored it as bufActivations[L] during forward. + // For CPU backward we'll recompute from saved CPU mirrors. + // Actually let's replay the forward on CPU to get activations. 
+ + // CPU forward replay for activations + std::vector> acts(L + 1); + std::vector> ln_outs(L); + std::vector> ffn_ups(L); // after addition-linear up + std::vector> sign_outs(L); // after sign activation + + // Initial: embed + pos + acts[0].resize(S * d); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + for (uint32_t j = 0; j < d; ++j) { + float e = (tok >= 0 && static_cast(tok) < V) + ? h.cpu_embed[tok * d + j] : 0.f; + float p = h.cpu_pos[s * d + j]; + acts[0][s * d + j] = e + p; + } + } + + for (uint32_t l = 0; l < L; ++l) { + const auto& uw = h.cpu_ffn_up_w[l]; + const auto& ub = h.cpu_ffn_up_b[l]; + const auto& dw = h.cpu_ffn_down_w[l]; + const auto& db = h.cpu_ffn_down_b[l]; + const auto& gm = h.cpu_ln_gamma[l]; + const auto& bt = h.cpu_ln_beta[l]; + + // LayerNorm + ln_outs[l].resize(S * d); + for (uint32_t s = 0; s < S; ++s) { + float mean = 0.f; + for (uint32_t j = 0; j < d; ++j) + mean += acts[l][s * d + j]; + mean /= float(d); + float var = 0.f; + for (uint32_t j = 0; j < d; ++j) { + float diff = acts[l][s * d + j] - mean; + var += diff * diff; + } + var /= float(d); + float inv_std = 1.0f / std::sqrt(var + 1e-5f); + for (uint32_t j = 0; j < d; ++j) { + float norm = (acts[l][s * d + j] - mean) * inv_std; + ln_outs[l][s * d + j] = gm[j] * norm + bt[j]; + } + } + + // Addition-linear up: (S, d) → (S, d_ffn) + ffn_ups[l].resize(S * dF); + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < dF; ++o) { + float dist = 0.f; + for (uint32_t k = 0; k < d; ++k) + dist += std::abs(uw[o * d + k] - ln_outs[l][s * d + k]); + ffn_ups[l][s * dF + o] = -dist + ub[o]; + } + } + + // Sign activation + sign_outs[l].resize(S * dF); + for (size_t i = 0; i < S * dF; ++i) + sign_outs[l][i] = (ffn_ups[l][i] > 0.f) ? 
1.f : -1.f; + + // Addition-linear down: (S, d_ffn) → (S, d) + acts[l + 1].resize(S * d); + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < d; ++o) { + float dist = 0.f; + for (uint32_t k = 0; k < dF; ++k) + dist += std::abs(dw[o * dF + k] - sign_outs[l][s * dF + k]); + float ffn_val = -dist + db[o]; + acts[l + 1][s * d + o] = acts[l][s * d + o] + ffn_val; + } + } + } + + // grad_out_w = scaledGL.T @ x_final → (V, d) + Eigen::Map XF(acts[L].data(), S, d); + RM GOW = scaledGL.transpose() * XF; + out.grad_out_w.resize(V * d); + std::memcpy(out.grad_out_w.data(), GOW.data(), V * d * sizeof(float)); + + // Back-propagate through layers in reverse + out.grad_ffn_up_w.resize(L); + out.grad_ffn_up_b.resize(L); + out.grad_ffn_down_w.resize(L); + out.grad_ffn_down_b.resize(L); + out.grad_ln_gamma.resize(L); + out.grad_ln_beta.resize(L); + + // dx is currently (S, d) + for (int32_t l = L - 1; l >= 0; --l) { + const auto& uw = h.cpu_ffn_up_w[l]; + const auto& dw = h.cpu_ffn_down_w[l]; + + // dx passes through residual: grad to addition-linear down is dx + // Addition-linear down backward: + // grad_sign[s, k] = -sum_o( dx[s, o] * sign(dw[o, k] - sign_out[s, k]) ) + // grad_dw[o, k] = -sum_s( dx[s, o] * sign(dw[o, k] - sign_out[s, k]) ) + // grad_db[o] = sum_s( dx[s, o] ) + + out.grad_ffn_down_w[l].assign(d * dF, 0.f); + out.grad_ffn_down_b[l].assign(d, 0.f); + std::vector grad_sign(S * dF, 0.f); + + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < d; ++o) { + float go = dx(s, o); + out.grad_ffn_down_b[l][o] += go; + for (uint32_t k = 0; k < dF; ++k) { + float sgn = (dw[o * dF + k] > sign_outs[l][s * dF + k]) ? 1.f : + (dw[o * dF + k] < sign_outs[l][s * dF + k]) ? -1.f : 0.f; + out.grad_ffn_down_w[l][o * dF + k] += -go * sgn; + grad_sign[s * dF + k] += -go * sgn; + } + } + } + + // Sign activation backward: grad_ffn_up = grad_sign * 0 (sign is flat) + // Actually sign subgradient: d/dx sign(x) = 0 almost everywhere. 
+ // But for training we use STE (straight-through estimator): + // grad_ffn_up = grad_sign (pass through) + std::vector& grad_ffn_up = grad_sign; // STE + + // Addition-linear up backward: + // grad_ln[s, k] = -sum_o( grad_up[s, o] * sign(uw[o, k] - ln_out[s, k]) ) + // grad_uw[o, k] = -sum_s( grad_up[s, o] * sign(uw[o, k] - ln_out[s, k]) ) + // grad_ub[o] = sum_s( grad_up[s, o] ) + + out.grad_ffn_up_w[l].assign(dF * d, 0.f); + out.grad_ffn_up_b[l].assign(dF, 0.f); + std::vector grad_ln(S * d, 0.f); + + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t o = 0; o < dF; ++o) { + float gu = grad_ffn_up[s * dF + o]; + out.grad_ffn_up_b[l][o] += gu; + for (uint32_t k = 0; k < d; ++k) { + float sgn = (uw[o * d + k] > ln_outs[l][s * d + k]) ? 1.f : + (uw[o * d + k] < ln_outs[l][s * d + k]) ? -1.f : 0.f; + out.grad_ffn_up_w[l][o * d + k] += -gu * sgn; + grad_ln[s * d + k] += -gu * sgn; + } + } + } + + // LayerNorm backward (simplified — gamma * grad_ln_out passed to LN backward) + out.grad_ln_gamma[l].assign(d, 0.f); + out.grad_ln_beta[l].assign(d, 0.f); + + const auto& gm = h.cpu_ln_gamma[l]; + + // Recompute mean/invstd + std::vector means(S), inv_stds(S); + for (uint32_t s = 0; s < S; ++s) { + float m = 0.f; + for (uint32_t j = 0; j < d; ++j) + m += acts[l][s * d + j]; + m /= float(d); + means[s] = m; + float v = 0.f; + for (uint32_t j = 0; j < d; ++j) { + float diff = acts[l][s * d + j] - m; + v += diff * diff; + } + v /= float(d); + inv_stds[s] = 1.0f / std::sqrt(v + 1e-5f); + } + + // grad_beta = sum_s(grad_ln), grad_gamma = sum_s(grad_ln * norm) + for (uint32_t s = 0; s < S; ++s) { + for (uint32_t j = 0; j < d; ++j) { + float norm = (acts[l][s * d + j] - means[s]) * inv_stds[s]; + out.grad_ln_beta[l][j] += grad_ln[s * d + j]; + out.grad_ln_gamma[l][j] += grad_ln[s * d + j] * norm; + } + } + + // Backprop through layernorm to get grad w.r.t. 
input + // Using full layernorm backward formula + RM grad_x_ln(S, d); + for (uint32_t s = 0; s < S; ++s) { + float is = inv_stds[s]; + float m = means[s]; + + // dl_dxhat = grad_ln * gamma + Eigen::VectorXf dl_dxhat(d); + for (uint32_t j = 0; j < d; ++j) + dl_dxhat[j] = grad_ln[s * d + j] * gm[j]; + + float sum1 = dl_dxhat.sum(); + float sum2 = 0.f; + for (uint32_t j = 0; j < d; ++j) { + float xhat = (acts[l][s * d + j] - m) * is; + sum2 += dl_dxhat[j] * xhat; + } + + for (uint32_t j = 0; j < d; ++j) { + float xhat = (acts[l][s * d + j] - m) * is; + grad_x_ln(s, j) = is * (dl_dxhat[j] - sum1 / d - xhat * sum2 / d); + } + } + + // Residual: dx for next layer = dx (through residual) + grad_x_ln + RM new_dx = dx + grad_x_ln; + dx = new_dx; + } + + // Embedding gradient: scatter-add dx + out.grad_embed.assign(V * d, 0.f); + out.grad_pos.assign(h.maxSeq * d, 0.f); + for (uint32_t s = 0; s < S; ++s) { + int32_t tok = input_ids[s]; + if (tok >= 0 && static_cast(tok) < V) { + for (uint32_t j = 0; j < d; ++j) + out.grad_embed[tok * d + j] += dx(s, j); + } + for (uint32_t j = 0; j < d; ++j) + out.grad_pos[s * d + j] = dx(s, j); + } + + return out; +} + +} // namespace ops +} // namespace grilly diff --git a/docs/grl_v1_format.md b/docs/grl_v1_format.md new file mode 100644 index 0000000..a4adeba --- /dev/null +++ b/docs/grl_v1_format.md @@ -0,0 +1,24 @@ +# GRL v1 checkpoint format (`.grl`) + +Grilly’s native, pickle-free checkpoint container for model weights, optimizer metadata, and training scalars. + +## Layout + +1. **Header** (64 bytes): magic `GRLY`, `uint16` format version (1), `uint16` flags, `uint32` reserved, then `uint64` offsets/lengths for metadata JSON, tensor index JSON, and raw payload. +2. **Metadata JSON** (UTF-8): `schema: "grilly.checkpoint.v1"`, `framework: "grilly"`, optional keys such as `training_step`, `best_ppl`, `step`, `epoch`, and `extra`. +3. 
**Tensor index JSON**: ordered array of `{name, dtype, shape, offset, length}` with `offset`/`length` relative to the **start of the payload** section. +4. **Payload**: concatenated C-contiguous row-major tensor bytes (little-endian scalars in index). + +## dtypes + +Index encodes dtypes as `f32`, `f16`, `i64`, `i32`, `u8`. + +## API + +- Write: `grilly.utils.grl_checkpoint.save_grl(path, state_dict, metadata=...)`. +- Read: `grilly.utils.grl_checkpoint.load_grl(path, map_location=...)`. +- Torch-style: `grilly.torch_api.save` / `grilly.torch_api.load` (`.grl` only). + +## Versioning + +Format version **1** is fixed in `FORMAT_VERSION` in `utils/grl_checkpoint.py` and the C++ reader/writer. Future versions must bump the header version and preserve a migration path. diff --git a/docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md b/docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md new file mode 100644 index 0000000..09b6438 --- /dev/null +++ b/docs/pre-v1.0/OPTIMIZATION_PARITY_TASKLIST.md @@ -0,0 +1,210 @@ +# Pre-v1.0 Optimization + Parity Tasklist + +This tasklist is a post-upgrade pass focused on shipping a stable, fast pre-v1.0 runtime with clearer PyTorch parity guarantees. 
+ +Scope: +- Python Vulkan runtime orchestration (`backend/*`) +- C++ bridge integration (`cpp/python/*`) +- Functional + module parity tests (`tests/parity`, targeted GPU tests) +- Performance/throughput work that directly affects user-visible training/inference latency + +--- + +## Baseline (already landed) + +- `FnnChainRecorder` with `linear` / `relu` / `softmax`, `read`, and `read_multiple` (`backend/fnn_chain.py`) +- `VulkanTensor.prepare_for_dispatch()` residency bind path (`utils/tensor_conversion.py`) +- `_prepare_input()` GPU-resident fast path (`backend/base.py`) +- Conv backward-weight GPU GEMM path and parity test (`backend/conv.py`, `tests/test_conv_backward_weight_gemm.py`) +- MoE backward stability fix at production shape (`cpp/src/ops/moe_forward.cpp`): corrected output-projection grad matmul and added bounds checks; Python binding wired via `moe_backward_gpu(...)` entrypoint with safe CPU fallback (`cpp/python/bindings_moe.cpp`) +- VSA-LM fused C++ forward/backward (`grilly_core.vsa_lm_*`): AdditionLinear FFN (L1 distance shader) + sign activation + LayerNorm + output projection in one C++ call. CPU Eigen backward with full AdditionLinear gradient (STE for sign). New files: `cpp/src/ops/vsa_lm_forward.cpp`, `cpp/include/grilly/ops/vsa_lm_forward.h`, `cpp/python/bindings_vsa_lm.cpp`, `shaders/sign-activation.glsl`. Tests: `tests/test_vsa_lm_forward.py` (shape + parity). + +--- + +## Priority Feature Roadmap (user-directed) + +Order locked by product priority: +1. GPU tokenizer +2. Sentence-transformers +3. Transformers compatibility (target: near 1:1 behavior/signature coverage) +4. PyTorch -> Grilly converter + +--- + +## P0: GPU Tokenizer (highest priority) + +### P0.1 Core GPU tokenizer runtime +- Status: `[~]` (CPU parity path landed; native GPU tokenization still open) +- Goal: + - Tokenization and detokenization run on GPU-backed buffers for high-throughput inference/training pipelines. 
+- Tasks: + - [x] Add `grilly.tokenizers` module with `Tokenizer` interface (`encode`, `decode`, `batch_encode`). Implementation: `tokenizer_impl/` → `grilly.tokenizers`; default `Tokenizer` is `FastTokenizer` (Rust `tokenizers` library, not `transformers`). + - [ ] Implement BPE/WordPiece fast path with GPU kernels for pretokenized merge/scoring stages. + - [x] Keep exact CPU fallback for unsupported edge cases (identical outputs). Current fallback is the Rust `tokenizers` CPU path; optional `numpy_to_input_ids_buffers` in `tokenizer_impl/gpu.py` for staging. + - [ ] Add `VulkanTensor`-friendly API: accept list[str] and return ids/attention masks in GPU-friendly layout (`wrap_ids_as_vulkan_tensors` exists behind `GRILLY_GPU_TOKENIZER=1` until wired). +- Acceptance: + - Deterministic token IDs vs reference tokenizer on supported models. + - >=2x throughput improvement vs CPU tokenizer on large batch benchmarks. + +### P0.2 HF tokenizer interoperability +- Status: `[~]` (Rust `tokenizer.json` + Hub fallbacks + BERT/DistilBERT/T5/mT5 parity tests; raw `spiece.model`-only repos still open) +- Tasks: + - [x] Load Hugging Face–compatible tokenizer assets: `tokenizer.json` from Hub (`huggingface_hub`), local dir, or file path (`tokenizer_impl/loader.py`). No `transformers` dependency in the Grilly tokenizer package. + - [x] Hub repos without root `tokenizer.json` (e.g. `google/mt5-small`): fall back to `onnx/tokenizer.json` when present — same Rust `tokenizers` pipeline, parity vs HF reference. + - [ ] Add loader for **only** SentencePiece assets (`spiece.model` / no JSON export) with encode/decode parity vs HF. + - [~] Validation suite: `tests/tokenizers/test_gpu_tokenizer_parity.py` (BERT, DistilBERT); `tests/sentencepiece/test_sentencepiece_parity.py` (`t5-small`, `google/mt5-small`, special-token cases). Current run: all 6 tests passing in local env. 
+- Acceptance: + - Asset compatibility documented and tested for top target checkpoints (including SentencePiece-backed models like T5/LLaMA-family tokenizers). + +--- + +## P1: Sentence-Transformers support + +### P1.1 Inference parity and API surface +- Status: `[ ]` +- Goal: + - `SentenceTransformer`-style embedding API with drop-in ergonomics for common usage. +- Tasks: + - [ ] Add `grilly.sentence_transformers` wrapper with `encode()` behavior-compatible options (`batch_size`, `normalize_embeddings`, device semantics). + - [ ] Implement pooling strategies (`mean`, `cls`, `max`) and normalization parity. + - [ ] Validate cosine similarity/semantic search outputs against reference pipelines. +- Acceptance: + - Embedding outputs within tolerance across target ST models; API docs include known deltas. + +### P1.2 GPU-first embedding pipeline +- Status: `[ ]` +- Tasks: + - [ ] Route tokenizer -> encoder -> pooling through chain recorder where possible. + - [ ] Add `rec.embedding_lookup(ids, table) -> handle` so embed -> first layer stays GPU-resident (no CPU round-trip). + - [ ] Add `read_multiple` fan-out examples for MoE-style encoder blocks. +- Acceptance: + - Single-submit batching demonstrated in benchmarked ST inference path. + +--- + +## P2: Transformers compatibility (1:1 target) + +### P2.1 Signature and config compatibility +- Status: `[ ]` +- Goal: + - Match `transformers` module signatures and config behavior for core models. +- Tasks: + - [ ] Add compatibility matrix by model family (BERT, RoBERTa, MiniLM, GPT2-class decoder). + - [ ] Match key forward signatures (`input_ids`, `attention_mask`, `token_type_ids`, `position_ids`, `past_key_values` where applicable). + - [ ] Align output objects (`last_hidden_state`, `pooler_output`, logits) and shape conventions. +- Acceptance: + - Core families pass reference compatibility tests with documented exceptions. 
+
+### P2.2 Numerical and behavioral parity
+- Status: `[ ]`
+- Tasks:
+  - [ ] Golden parity tests versus HF forward outputs (fp32 tolerances per op/family).
+  - [ ] Attention behavior policy and masks parity (`causal`, `padding`, mixed masks).
+  - [ ] Tokenizer-model handshake tests (special tokens, truncation/padding behavior).
+- Acceptance:
+  - "1:1" means no user-visible API breakage for covered families in documented scenarios.
+
+---
+
+## P3: PyTorch -> Grilly converter (after compatibility)
+
+### P3.1 Converter core
+- Status: `[ ]`
+- Goal:
+  - Convert PyTorch/HF checkpoints and model graphs into runnable Grilly modules.
+- Tasks:
+  - [ ] Add `grilly.convert.from_pytorch(...)` entrypoint.
+  - [ ] Implement state_dict key mapping + tensor layout transforms.
+  - [ ] Generate conversion report (mapped/unmapped params, warnings, unsupported ops).
+- Acceptance:
+  - Supported architectures convert and run inference with parity checks.
+
+### P3.2 CLI + migration UX
+- Status: `[ ]`
+- Tasks:
+  - [ ] Add CLI (`python -m grilly.convert ...`) with dry-run and validation modes.
+  - [ ] Add migration cookbook examples from PyTorch/HF to Grilly runtime.
+- Acceptance:
+  - Users can convert, validate, and run the model with a single documented workflow.
+
+---
+
+## Supporting Throughput Track (parallel, non-blocking)
+
+### T1 Chain recorder dependency-aware barriers
+- Status: `[ ]`
+- Tasks:
+  - [ ] Remove unconditional post-dispatch barriers where no RAW hazard exists.
+  - [ ] Add `force_barrier=True` debug mode.
+  - [ ] Add `rec.linear_backward(grad_out, input, weights) -> (grad_input_handle, grad_weight_handle)` for GPU backward matmul chaining.
+  - [ ] Add MoE fan-out microbench (4/8 experts).
+
+### T2 VulkanTensor residency hardening
+- Status: `[ ]`
+- Tasks:
+  - [ ] Guard for older `grilly_core` binaries lacking `gpu_handle_if_valid`.
+  - [ ] Add residency counters (fallback uploads/downloads) for profiling. 
+ +### T3 C2+/C4+ infrastructure +- Status: `[ ]` +- Tasks: + - [ ] INT8 GEMM tiling and tuning (C2+) with MoQE-focused benchmarks (4-bit/8-bit expert paths). + - [ ] Add dequant-in-shader path so quantized weights do not require FP32 shadow copies during inference. + - [ ] Transfer queue + staging ring + overlap experiments (C4+). + +### T4 Autograd GPU-resident graph execution +- Status: `[ ]` +- Tasks: + - [ ] Make `Variable.backward()` detect chainable matmul -> relu -> matmul regions. + - [ ] Route detected backward regions through chain recorder (batched GPU dispatches) instead of one bridge call per topo node. + - [ ] Add fallback path to current traversal for unsupported ops while preserving gradient correctness. +- Acceptance: + - Backward pass correctness parity vs current autograd; fewer fence waits in profiled training steps. + +### T5 Fused CE + softmax chain op +- Status: `[ ]` +- Tasks: + - [ ] Add `rec.cross_entropy(logits_handle, targets) -> (loss_handle, grad_logits_handle)` fused forward+backward on GPU. + - [ ] Integrate with chain recorder API so LM training step can stay in a single submit region through loss/grad. + - [ ] Add correctness tests vs reference CE+softmax backward and benchmark on `(seq, vocab)`-sized logits. +- Acceptance: + - Fused CE op matches reference gradients within tolerance and reduces per-step synchronization overhead. 
+ +--- + +## Verification Commands + +```bash +# Existing runtime/perf guardrails +uv run pytest tests/test_fnn_chain.py -v +uv run pytest tests/test_vulkan_tensor_residency.py -v +uv run pytest tests/test_conv_backward_weight_gemm.py -v +uv run pytest tests/parity/ -m parity -v +uv run python benchmarks/benchmark_conv_backward_weight.py +uv run python benchmarks/benchmark_int8_gemm.py + +# Tokenizer parity (Rust CPU path vs HF reference in tests) +uv run pytest tests/tokenizers/ -v +uv run pytest tests/sentencepiece/ -v +# VSA-LM fused forward/backward +uv run pytest tests/test_vsa_lm_forward.py -v +uv run pytest tests/test_moe_forward.py -v + +# Planned (new feature suites to add as implemented) +# uv run pytest tests/sentence_transformers/ -v +# uv run pytest tests/transformers_compat/ -v +# uv run pytest tests/converter/ -v +# uv run pytest tests/autograd_chain/ -v +# uv run pytest tests/moe_quant/ -v +``` + +--- + +## Exit Criteria for pre-v1.0 cut + +- GPU tokenizer shipped and validated against reference tokenizer outputs. *(Current: Rust `tokenizers` CPU path + parity tests for BERT/DistilBERT/T5/mT5; GPU merge/BPE kernels still to ship.)* +- Sentence-transformers pipeline shipped with documented parity bounds. +- Transformers compatibility target achieved for declared core families (documented matrix). +- PyTorch -> Grilly converter ships with dry-run + validation report. +- No known correctness regressions in chain recorder/residency paths. +- Performance regressions detectable with benchmark baselines committed in docs. 
diff --git a/examples/experimental_backend_vsa.py b/examples/experimental_backend_vsa.py index a93d67c..6aef93f 100644 --- a/examples/experimental_backend_vsa.py +++ b/examples/experimental_backend_vsa.py @@ -7,11 +7,11 @@ import numpy as np try: - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from grilly.backend.core import VulkanCore from grilly.backend.experimental.vsa import VulkanVSA - if not VULKAN_AVAILABLE: + if not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE: raise ImportError("Vulkan not available") except ImportError: print("Vulkan backend not available. Skipping GPU examples.") diff --git a/functional/__init__.py b/functional/__init__.py index 46d0837..9789243 100644 --- a/functional/__init__.py +++ b/functional/__init__.py @@ -10,6 +10,13 @@ softmax, softplus, ) +from .mf_activations import ( + mf_relu, + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) from .attention import ( attention, flash_attention2, @@ -96,6 +103,11 @@ "silu", "softmax", "softplus", + "mf_softmax", + "mf_softplus", + "mf_sigmoid", + "mf_sigmoid_01", + "mf_relu", # Linear "linear", # Normalization diff --git a/functional/mf_activations.py b/functional/mf_activations.py new file mode 100644 index 0000000..8175653 --- /dev/null +++ b/functional/mf_activations.py @@ -0,0 +1,65 @@ +""" +Multiplication-free (or low-mul) activations for VSA / addition-style pipelines. + +- **mf_softmax**: ReLU-normalized probabilities — no ``exp``; only ``max``, subtract, + ``relu``, sum, divide (same spirit as ``addition-linear.glsl``: add/sub/abs in the core). +- **mf_softplus**: Algebraic softplus — ``0.5 * (x + sqrt(x^2 + c))`` with ``c = 4/beta²``. + Smooth, no ``log``/``exp``; uses ``sqrt`` (and scalar scaling), not GEMM-style multiply. +- **mf_sigmoid**: Rational sigmoid — ``x / (1 + |x|)``, bounded in ``(-1, 1)``; scale to + ``(0, 1)`` with ``0.5 * (1 + ...)`` if needed via :func:`mf_sigmoid_01`. 
+""" + +from __future__ import annotations + +import numpy as np + + +def mf_softmax(x: np.ndarray, dim: int = -1, eps: float = 1e-12) -> np.ndarray: + """ReLU-normalized softmax: ``relu(x - max) / sum(relu(x - max))``. + + No exponential; suitable as a sparse, addition-heavy alternative to softmax. + """ + x = np.asarray(x, dtype=np.float32) + axis = dim if dim >= 0 else x.ndim + dim + m = np.max(x, axis=axis, keepdims=True) + z = np.maximum(x - m, 0.0).astype(np.float32) + s = np.sum(z, axis=axis, keepdims=True, dtype=np.float64) + denom = np.maximum(s, eps) + y = (z / denom).astype(np.float32) + # Degenerate: all logits tied → z is all zeros → use uniform distribution + tot = np.sum(y, axis=axis, keepdims=True) + nfeat = float(x.shape[axis]) + unif = np.ones_like(x, dtype=np.float32) / nfeat + return np.where(tot > 1e-8, y, unif).astype(np.float32) + + +def mf_softplus(x: np.ndarray, beta: float = 1.0) -> np.ndarray: + """Algebraic (exp-free) softplus: ``(x + sqrt(x^2 + 4/beta^2)) / 2``. + + Matches ``softplus`` shape qualitatively; uses ``sqrt`` instead of ``log``/``exp``. 
+ """ + x = np.asarray(x, dtype=np.float32) + b = float(beta) + if b <= 0: + raise ValueError("beta must be positive") + c = 4.0 / (b * b) + s = np.sqrt(x * x + c) + return (0.5 * (x + s)).astype(np.float32) + + +def mf_sigmoid(x: np.ndarray) -> np.ndarray: + """Rational sigmoid: ``x / (1 + |x|)``, range ``(-1, 1)``.""" + x = np.asarray(x, dtype=np.float32) + ax = np.abs(x) + 1.0 + return (x / ax).astype(np.float32) + + +def mf_sigmoid_01(x: np.ndarray) -> np.ndarray: + """Map :func:`mf_sigmoid` to ``(0, 1)``: ``0.5 * (1 + x/(1+|x|))``.""" + return (0.5 * (1.0 + mf_sigmoid(x))).astype(np.float32) + + +def mf_relu(x: np.ndarray) -> np.ndarray: + """``max(0, x)`` — multiplication-free nonlinearity.""" + x = np.asarray(x, dtype=np.float32) + return np.maximum(x, 0.0).astype(np.float32) diff --git a/nn/__init__.py b/nn/__init__.py index b4f543c..fa2218b 100644 --- a/nn/__init__.py +++ b/nn/__init__.py @@ -52,6 +52,9 @@ lt, matmul, max, + mf_sigmoid, + mf_softmax, + mf_softplus, mean, min, mse_loss, @@ -72,6 +75,7 @@ # Shapes reshape, sigmoid, + sign, silu, # Trigonometric sin, @@ -141,7 +145,13 @@ MemoryRead, MemoryWrite, ) +from .addition_linear import AdditionLinear +from . import functional +from . import init from .module import Module +from .module_list import ModuleList +from . 
import utils +from .vsa_lm import VsaLmModel from .modules import ( GCU, GELU, @@ -280,6 +290,10 @@ __all__ = [ # Base class "Module", + "ModuleList", + "functional", + "init", + "utils", # Standard layers "Linear", "LayerNorm", @@ -370,6 +384,9 @@ "DomainPredictor", "DomainClassifier", "ExpertCombiner", + # AdditionLinear (VSA) + "AdditionLinear", + "VsaLmModel", # Affect layers (when implemented) "AffectMLP", # Capsule layers @@ -414,6 +431,9 @@ "neg", "pow", "matmul", + "mf_sigmoid", + "mf_softmax", + "mf_softplus", # Reductions "sum", "mean", @@ -425,6 +445,7 @@ # Activations "relu", "sigmoid", + "sign", "tanh", "exp", "log", diff --git a/nn/addition_linear.py b/nn/addition_linear.py new file mode 100644 index 0000000..928c45d --- /dev/null +++ b/nn/addition_linear.py @@ -0,0 +1,150 @@ +""" +AdditionLinear — multiplication-free linear layer using L1 distance. + +output[b, o] = -sum_k |weight[o, k] - input[b, k]| + bias[o] + +Uses: addition-linear.glsl (compute shader). +Backward uses the subgradient: d/dx |a - b| = -sign(a - b). +""" + +import numpy as np + +from ._helpers import ( + _PARAMETER_AVAILABLE, + _USE_CPP_BRIDGE, + ParameterClass, + _bridge, + _bridge_to_numpy, + _create_param_wrapper, + _get_param_array, +) +from .module import Module + + +class AdditionLinear(Module): + """Multiplication-free linear layer (L1 / Manhattan distance). + + ``y = -||W - x||_1 + b`` per output neuron. Uses the ``addition-linear`` + Vulkan compute shader when available, with a NumPy CPU fallback. 
+ """ + + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + limit = np.sqrt(6.0 / (in_features + out_features)) + weight_data = np.random.uniform( + -limit, limit, (out_features, in_features) + ).astype(np.float32) + + self.weight = _create_param_wrapper(weight_data) + self.register_parameter("weight", self.weight) + + if bias: + self.bias_param = _create_param_wrapper( + np.zeros(out_features, dtype=np.float32) + ) + self.register_parameter("bias", self.bias_param) + else: + self.bias_param = None + + self._last_input = None + + @property + def d_in(self) -> int: + """Alias for ``in_features`` (PyTorch-style naming).""" + return self.in_features + + def forward(self, x) -> np.ndarray: + x_arr = np.ascontiguousarray(x, dtype=np.float32) + self._last_input = x_arr + + w = _get_param_array(self.weight) + b = _get_param_array(self.bias_param) if self.bias_param is not None else None + + if _USE_CPP_BRIDGE and hasattr(_bridge, "addition_linear"): + result = _bridge_to_numpy(_bridge.addition_linear(x_arr, w, b)) + if result is not None: + return result + + if x_arr.ndim == 1: + x_arr = x_arr.reshape(1, -1) + original_shape = x_arr.shape + if x_arr.ndim > 2: + batch_seq = int(np.prod(original_shape[:-1])) + x_2d = x_arr.reshape(batch_seq, original_shape[-1]) + else: + batch_seq = original_shape[0] + x_2d = x_arr + + out = np.empty((batch_seq, self.out_features), dtype=np.float32) + for o in range(self.out_features): + dist = np.sum(np.abs(w[o] - x_2d), axis=1) + out[:, o] = -dist + if b is not None: + out[:, o] += b[o] + + if len(original_shape) > 2: + return out.reshape(original_shape[:-1] + (self.out_features,)) + return out + + def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: + """Backward pass — subgradient of L1 distance. 
+
+        grad_input[s, k] = sum_o( grad_out[s, o] * sign(W[o, k] - x[s, k]) )
+        grad_W[o, k] = -sum_s( grad_out[s, o] * sign(W[o, k] - x[s, k]) )
+        grad_b[o] = sum_s( grad_out[s, o] )
+        """
+        if x is None:
+            x = self._last_input
+        if x is None:
+            raise RuntimeError("backward requires input; call forward first or pass x=")
+
+        x_arr = np.asarray(x, dtype=np.float32)
+        w = _get_param_array(self.weight)
+        go = np.asarray(grad_output, dtype=np.float32)
+
+        if x_arr.ndim > 2:
+            S = int(np.prod(x_arr.shape[:-1]))
+            x_2d = x_arr.reshape(S, -1)
+            go_2d = go.reshape(S, -1)
+        else:
+            x_2d = x_arr if x_arr.ndim == 2 else x_arr.reshape(1, -1)
+            go_2d = go if go.ndim == 2 else go.reshape(1, -1)
+
+        S, d_in = x_2d.shape
+        d_out = self.out_features
+
+        grad_input = np.zeros_like(x_2d)
+        grad_w = np.zeros_like(w)
+
+        for o in range(d_out):
+            sgn = np.sign(w[o][np.newaxis, :] - x_2d)  # (S, d_in)
+            g_col = go_2d[:, o:o+1]  # (S, 1)
+            grad_input += g_col * sgn
+            grad_w[o] -= np.sum(g_col * sgn, axis=0)
+
+        if self.weight is not None:
+            if not hasattr(self.weight, "grad") or self.weight.grad is None:
+                self.weight.grad = grad_w
+            else:
+                self.weight.grad += grad_w
+
+        if self.bias_param is not None:
+            grad_b = np.sum(go_2d, axis=0)
+            if not hasattr(self.bias_param, "grad") or self.bias_param.grad is None:
+                self.bias_param.grad = grad_b
+            else:
+                self.bias_param.grad += grad_b
+
+        if x_arr.ndim > 2:
+            return grad_input.reshape(x_arr.shape)
+        return grad_input
+
+    def __repr__(self):
+        return (
+            f"AdditionLinear(in_features={self.in_features}, "
+            f"out_features={self.out_features}, "
+            f"bias={self.bias_param is not None})"
+        )
diff --git a/nn/autograd.py b/nn/autograd.py
index ec2b1bf..e440905 100644
--- a/nn/autograd.py
+++ b/nn/autograd.py
@@ -181,6 +181,16 @@ def numpy(self) -> np.ndarray:
     """Get numpy array (detached from graph).""" 
     return self.data.copy() 
 
+    def __array__(self, dtype=None) -> np.ndarray:
+        """NumPy array protocol — allows ``np.asarray(var)``, ``np.matmul(var, w)``,
+ 
``np.dot(var, w)``, etc. to operate on the underlying ``self.data`` ndarray + without an explicit ``.data`` access. Required for grilly's existing + numpy-native layer ``forward`` code to accept ``Tensor`` inputs from the + ``torch_api`` facade transparently.""" + if dtype is not None: + return self.data.astype(dtype, copy=False) + return self.data + def detach(self) -> "Variable": """Return a new Variable detached from the computation graph.""" return Variable(self.data.copy(), requires_grad=False) @@ -199,6 +209,10 @@ def zero_grad(self): """Clear the gradient.""" self.grad = None + def numel(self) -> int: + """Number of elements (PyTorch ``Tensor.numel``).""" + return int(self.data.size) + def backward( self, grad_output: np.ndarray | None = None, @@ -955,11 +969,23 @@ def backward(grad): def matmul(a, b) -> Variable: - """Matrix multiplication: a @ b (with GPU backward support)""" + """Matrix multiplication: a @ b (with GPU forward + backward)""" a = _ensure_variable(a) b = _ensure_variable(b) - result_data = np.matmul(a.data, b.data) + # GPU forward: use _bridge.linear for 2D matrix multiply + result_data = None + if a.data.ndim == 2 and b.data.ndim == 2 and _grad_enabled: + try: + from grilly.backend import _bridge as _ag_bridge + # _bridge.linear(x, w) = x @ w.T, so pass b.T to get a @ b + gpu_result = _ag_bridge.linear(a.data, b.data.T) + if gpu_result is not None: + result_data = np.asarray(gpu_result) if not isinstance(gpu_result, np.ndarray) else gpu_result + except Exception: + pass + if result_data is None: + result_data = np.matmul(a.data, b.data) a_data, b_data = a.data, b.data @@ -1276,6 +1302,20 @@ def backward(grad): return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) +def sign(a) -> Variable: + """Elementwise sign; gradient is zero (subgradient).""" + a = _ensure_variable(a) + result_data = np.sign(a.data) + + def backward(grad): + """Run backward.""" + + return (np.zeros_like(a.data),) + + grad_fn = 
_make_backward("Sign", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + def clamp(a, min_val=None, max_val=None) -> Variable: """Clamp values to [min_val, max_val]""" a = _ensure_variable(a) @@ -1748,6 +1788,76 @@ def backward(grad): return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) +def mf_softmax(a, dim: int = -1, eps: float = 1e-12) -> Variable: + """ReLU-normalized softmax (no exp). See :mod:`grilly.functional.mf_activations`.""" + a = _ensure_variable(a) + axis = dim if dim >= 0 else a.data.ndim + dim + m = np.max(a.data, axis=axis, keepdims=True) + z = np.maximum(a.data - m, 0.0).astype(np.float32) + s = np.sum(z, axis=axis, keepdims=True, dtype=np.float64) + denom = np.maximum(s, eps) + y = (z / denom).astype(np.float32) + tot = np.sum(y, axis=axis, keepdims=True) + nfeat = float(a.data.shape[axis]) + unif = np.ones_like(a.data, dtype=np.float32) / nfeat + result_data = np.where(tot > 1e-8, y, unif).astype(np.float32) + + def backward(grad): + """Subgradient with STE on the max (mask active where z > 0).""" + m2 = np.max(a.data, axis=axis, keepdims=True) + z2 = np.maximum(a.data - m2, 0.0).astype(np.float32) + s2 = np.sum(z2, axis=axis, keepdims=True, dtype=np.float64) + s2 = np.maximum(s2, eps) + y2 = (z2 / s2).astype(np.float32) + tot2 = np.sum(y2, axis=axis, keepdims=True) + y2 = np.where(tot2 > 1e-8, y2, np.ones_like(a.data, dtype=np.float32) / nfeat) + gz = (grad - np.sum(grad * y2, axis=axis, keepdims=True)) / s2.astype(np.float32) + mask = (z2 > 0).astype(np.float32) + return (gz * mask,) + + grad_fn = _make_backward("MfSoftmax", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + +def mf_softplus(a, beta: float = 1.0) -> Variable: + """Algebraic softplus (sqrt form, no exp). 
See :mod:`grilly.functional.mf_activations`.""" + a = _ensure_variable(a) + b = float(beta) + if b <= 0: + raise ValueError("beta must be positive") + x = a.data + c = 4.0 / (b * b) + s = np.sqrt(x * x + c) + result_data = (0.5 * (x + s)).astype(np.float32) + + def backward(grad): + """Run backward.""" + + ds_dx = x / (s + 1e-12) + d = 0.5 * (1.0 + ds_dx) + return (grad * d,) + + grad_fn = _make_backward("MfSoftplus", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + +def mf_sigmoid(a) -> Variable: + """Rational sigmoid x / (1 + |x|).""" + a = _ensure_variable(a) + x = a.data + ax = np.abs(x) + 1.0 + result_data = (x / ax).astype(np.float32) + + def backward(grad): + """d/dx x/(1+|x|) = 1/(1+|x|)^2 on x!=0.""" + + denom = (1.0 + np.abs(x)) ** 2 + return (grad / denom,) + + grad_fn = _make_backward("MfSigmoid", [a], backward) + return Variable(result_data, requires_grad=a.requires_grad, grad_fn=grad_fn) + + # ============================================================================ # Trigonometric Functions # ============================================================================ diff --git a/nn/embedding.py b/nn/embedding.py index 621d1df..1f9d4a5 100644 --- a/nn/embedding.py +++ b/nn/embedding.py @@ -34,34 +34,65 @@ def __init__(self, num_embeddings: int, embedding_dim: int): # Register parameter self.register_parameter("weight", self.weight) - def forward(self, x: np.ndarray) -> np.ndarray: - """Forward pass using embedding-lookup.glsl""" - backend = self._get_backend() + def forward(self, x): + """Forward pass using embedding-lookup.glsl. + + Autograd: when called through ``Module.__call__`` with a LongTensor + input (standard pattern for token ids), the output is wrapped in a + ``Variable`` whose ``GradFn`` calls ``self.backward(grad_output, ids)`` + on loss.backward(). 
The GradFn has an empty inputs list because token + ids are discrete and don't receive gradients — we only use the + closure to populate ``self.weight.grad`` via the existing + ``self.backward``. + + Mirrors ``nn.Linear.forward``'s autograd wiring. + """ + try: + from grilly.nn.autograd import GradFn as _GradFn + from grilly.nn.autograd import Variable as _Variable + from grilly.nn.autograd import _grad_enabled + except ImportError: + _Variable = None # type: ignore[assignment] + _GradFn = None # type: ignore[assignment] + _grad_enabled = False + weight = _get_param_array(self.weight) - gpu_lookup_enabled = os.getenv("GRILLY_EMBEDDING_GPU_LOOKUP", "1").strip().lower() not in { - "0", - "false", - "no", - } + # Extract the raw token-id ndarray (caller may pass a LongTensor + # subclass or plain ndarray). Keep the original around so the + # backward closure has access to the indices. + ids_data = np.asarray(x) + + # Numpy fancy-index lookup. The legacy ``backend.learning.embedding_lookup`` + # expects a (batch, seq) shape and adds a leading batch dim for 1D + # inputs, which breaks downstream ops that don't expect the extra + # axis. Fancy indexing preserves the exact input shape with a + # trailing embedding dim, matches what PyTorch's Embedding does, + # and is plenty fast for the sizes we care about (LUT bandwidth + # bound — even at 8192 vocab × 384 dim it's under 1 ms). + result = weight[ids_data.astype(np.int32)] + + # ---- Autograd wiring ---- + # Always wrap the output in a Variable with a GradFn when autograd is + # enabled — Embedding's input is discrete (token ids), so there's no + # upstream Variable to chain to, but we still need the GradFn so that + # ``loss.backward()`` can flow back here and update ``weight.grad``. 
if ( - gpu_lookup_enabled - and hasattr(backend, "learning") - and hasattr(backend.learning, "embedding_lookup") + _GradFn is not None + and _grad_enabled + and isinstance(result, np.ndarray) + and not isinstance(result, _Variable) ): - try: - return backend.learning.embedding_lookup( - x, - weight, - return_gpu_tensor=self._return_gpu_tensor, - ) - except Exception: - pass # Fall back to CPU - - # CPU fallback - if isinstance(x, np.ndarray): - return weight[x.astype(np.int32)] - return weight[x] + def backward_fn(grad_output): + # self.backward populates self.weight.grad in-place. + # Returns None for the input gradient (token ids are discrete). + self.backward(np.asarray(grad_output), ids_data) + return () # no upstream inputs + + grad_fn = _GradFn("Embedding", backward_fn, []) + return _Variable(np.asarray(result), requires_grad=True, grad_fn=grad_fn) + + return result def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ diff --git a/nn/functional.py b/nn/functional.py new file mode 100644 index 0000000..00189f9 --- /dev/null +++ b/nn/functional.py @@ -0,0 +1,22 @@ +"""``torch.nn.functional``-compatible API (cross-entropy, activations).""" + +from grilly.functional.mf_activations import ( + mf_relu, + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) +from grilly.nn.autograd import gelu +from grilly.torch_api.functional import cross_entropy, softplus + +__all__ = [ + "cross_entropy", + "softplus", + "gelu", + "mf_softmax", + "mf_softplus", + "mf_sigmoid", + "mf_sigmoid_01", + "mf_relu", +] diff --git a/nn/init.py b/nn/init.py new file mode 100644 index 0000000..9c7dff5 --- /dev/null +++ b/nn/init.py @@ -0,0 +1,21 @@ +"""``torch.nn.init``-style weight initialization (numpy arrays / Parameters).""" + +from __future__ import annotations + +from grilly.utils.initialization import ( + kaiming_normal_, + kaiming_uniform_, + normal_, + uniform_, + xavier_normal_, + xavier_uniform_, +) + +__all__ = [ + "kaiming_normal_", + 
"kaiming_uniform_", + "normal_", + "uniform_", + "xavier_normal_", + "xavier_uniform_", +] diff --git a/nn/linear.py b/nn/linear.py index bfcc4ec..5bbb60b 100644 --- a/nn/linear.py +++ b/nn/linear.py @@ -188,46 +188,105 @@ def zero_grad(self): if self.bias is not None: self.register_parameter("bias", self.bias) - def forward(self, x) -> np.ndarray: - """Forward pass — GPU-first, always dispatches through GPU backend.""" + def forward(self, x): + """Forward pass — GPU-first, always dispatches through GPU backend. + + Autograd: when ``x`` is an autograd ``Variable`` (or carries + ``requires_grad``), the output is wrapped in a ``Variable`` with a + ``GradFn`` that calls ``self.backward(grad_output, x_data)`` during + ``loss.backward()``. ``self.backward`` already populates + ``self.weight.grad`` and ``self.bias.grad`` from the existing + ``fnn-linear-backward.glsl`` kernel, so the AdamW step picks them up + through ``param.grad`` like any other PyTorch-style optimizer. + + Without this wiring, ``loss.backward()`` produced gradients on the + Variable wrapping the cross-entropy logits but the chain stopped at + ``Linear`` (no ``grad_fn`` -> autograd traversal terminated), so + weights silently never updated. + """ + # Detect autograd input (Variable, or Tensor — Tensor is a Variable subclass) + # and remember its underlying ndarray for the backward closure. 
+ try: + from grilly.nn.autograd import GradFn as _GradFn + from grilly.nn.autograd import Variable as _Variable + from grilly.nn.autograd import _grad_enabled + except ImportError: + _Variable = None # type: ignore[assignment] + _GradFn = None # type: ignore[assignment] + _grad_enabled = False + + x_var: "_Variable | None" = None # type: ignore[name-defined] + if _Variable is not None and isinstance(x, _Variable): + x_var = x + x_data = x.data + else: + x_data = x # raw ndarray (or VulkanTensor) + weight = _get_param_array(self.weight) bias = _get_param_array(self.bias) if self.bias is not None else None def cpu_linear(): - x_arr = np.asarray(x, dtype=np.float32) + x_arr = np.asarray(x_data, dtype=np.float32) w_arr = np.asarray(weight, dtype=np.float32) out = x_arr @ w_arr.T if bias is not None: out = out + np.asarray(bias, dtype=np.float32) return np.asarray(out, dtype=np.float32) - # C++ bridge fast path (handles both numpy and VulkanTensor via __array__) + # ---- Run the existing forward path (unchanged) ---- + result: np.ndarray | None = None if _USE_CPP_BRIDGE: def gpu_linear(): - return _bridge_to_numpy(_bridge.linear(x, weight, bias)) + return _bridge_to_numpy(_bridge.linear(x_data, weight, bias)) # Auto-fastest policy only for numpy-in/numpy-out path. 
- if isinstance(x, np.ndarray) and not self._return_gpu_tensor: - batch = int(np.prod(x.shape[:-1])) if x.ndim > 1 else 1 - in_features = int(x.shape[-1]) if x.ndim > 0 else self.in_features + if isinstance(x_data, np.ndarray) and not self._return_gpu_tensor: + batch = int(np.prod(x_data.shape[:-1])) if x_data.ndim > 1 else 1 + in_features = int(x_data.shape[-1]) if x_data.ndim > 0 else self.in_features op_key = f"linear:{batch}x{in_features}x{self.out_features}" - return choose_fastest(op_key, gpu_linear, cpu_linear) + result = choose_fastest(op_key, gpu_linear, cpu_linear) + else: + result = gpu_linear() + + if result is None: + # Legacy Python Vulkan backend (fallback) or final CPU path + backend = self._get_backend() + if hasattr(backend, "fnn") and hasattr(backend.fnn, "linear"): + result = backend.fnn.linear( + x_data, + weight, + bias, + return_gpu_tensor=self._return_gpu_tensor, + ) + else: + result = cpu_linear() + + # ---- Autograd wiring ---- + # If the input came from autograd, wrap result so loss.backward() + # can flow back through this layer and populate weight.grad / bias.grad + # via the existing self.backward() implementation. We bypass + # ``_make_backward`` (which short-circuits when no input requires grad) + # because the *weight* always requires grad even if the input doesn't — + # otherwise a frozen-input training loop silently never updates weights. + if ( + x_var is not None + and _GradFn is not None + and _grad_enabled + and isinstance(result, np.ndarray) + and not isinstance(result, _Variable) + ): + x_data_for_backward = x_data # capture for closure - result = gpu_linear() - if result is not None: - return result + def backward_fn(grad_output): + # self.backward populates self.weight.grad and self.bias.grad + # in-place and returns grad_input. 
+ grad_input = self.backward(np.asarray(grad_output), x_data_for_backward) + return (grad_input,) - # Legacy Python Vulkan backend (fallback) - backend = self._get_backend() - if hasattr(backend, "fnn") and hasattr(backend.fnn, "linear"): - return backend.fnn.linear( - x, - weight, - bias, - return_gpu_tensor=self._return_gpu_tensor, - ) + grad_fn = _GradFn("Linear", backward_fn, [x_var]) + return _Variable(np.asarray(result), requires_grad=True, grad_fn=grad_fn) - return cpu_linear() + return result def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ diff --git a/nn/module.py b/nn/module.py index 75058b2..9a159f0 100644 --- a/nn/module.py +++ b/nn/module.py @@ -65,30 +65,60 @@ def __init__(self): self._use_device_local = False # DEVICE_LOCAL VRAM buffers def _get_backend(self): - """Execute get backend.""" - + """Get the legacy VulkanCompute backend if available, else None. + + The legacy backend (``backend.compute.VulkanCompute``) requires the + PyPI ``vulkan`` Python package (ctypes bindings) and is only used + today by a few slow-path fallbacks (e.g. ``fnn.xavier_init`` at + layer construction time). The real forward/backward path goes + through the C++ ``_bridge`` which only needs ``grilly_core`` and has + no Python ``vulkan`` dependency. + + Returning ``None`` when the legacy backend can't init lets callers + that already guard with ``hasattr(backend, "fnn")`` gracefully fall + through to their CPU reference path (e.g. numpy xavier init), which + is harmless at construction time since the fast path takes over at + forward time via the bridge. 
+ """ if self._backend is None: - from ..utils.device_manager import get_device_manager + try: + from ..utils.device_manager import get_device_manager - self._backend = get_device_manager().vulkan + self._backend = get_device_manager().vulkan + except Exception: + self._backend = None return self._backend def _convert_input(self, x: np.ndarray | Any): """ Convert input to Vulkan-compatible format (GPU-first). - VulkanTensor inputs always pass through without downloading to CPU. - numpy/PyTorch inputs are converted to VulkanTensor when the C++ backend - is available, otherwise to float32 numpy arrays. - Preserves integer dtypes for index arrays (e.g., token IDs). + Pass-through priority: + 1. ``torch_api.Tensor`` / ``LongTensor`` — kept as-is so user + ``forward()`` code keeps torch-style methods (``.unsqueeze``, + ``.reshape``, ``.mean(dim=...)``). Works alongside grilly's + existing numpy-native layers because ``Variable.__array__`` + lets ``np.matmul``/``np.dot`` operate on Tensor inputs directly. + 2. ``VulkanTensor`` — GPU-resident, always pass through. + 3. Everything else (numpy, PyTorch, …) — convert to Vulkan-compatible. - Args: - x: Input (PyTorch tensor, numpy array, VulkanTensor, or other) - - Returns: - VulkanTensor or numpy array + Preserves integer dtypes for index arrays (e.g., token IDs). """ - # GPU-first: always pass VulkanTensor through without CPU round-trip + # 1. torch_api Tensor wrappers + autograd Variable: pass through + # unchanged. Variable matters because Tensor inherits from Variable + # and ops like ``.reshape`` / ``.mean(dim=...)`` / arithmetic return + # raw Variable instances rather than the Tensor subclass. + try: + from grilly.nn.autograd import Variable as _Variable + from grilly.torch_api.tensor import LongTensor as _TorchLong + from grilly.torch_api.tensor import Tensor as _TorchTensor + + if isinstance(x, (_TorchTensor, _TorchLong, _Variable)): + return x + except ImportError: + pass + + # 2. 
GPU-first: always pass VulkanTensor through without CPU round-trip if TENSOR_CONVERSION_AVAILABLE: from ..utils.tensor_conversion import VulkanTensor @@ -121,15 +151,56 @@ def forward(self, *args, **kwargs): raise NotImplementedError def __call__(self, *args, **kwargs): - # Automatically convert PyTorch tensor inputs to numpy - """Invoke the callable instance.""" + """Invoke the callable instance. + + If any positional arg was torch-style (``torch_api.Tensor``, + ``LongTensor``, or autograd ``Variable``) and the forward returned a + raw ndarray, wrap the output in ``Tensor`` so chained calls + (``x = embed(ids); x = layer(x); x = next(x)``) preserve torch-style + type through user-defined Module subclasses. + + Three classes count as a "torch input": + - ``LongTensor`` — model entry points use it for token ids; the next + layer (Embedding) returns floats and the chain has to start in + torch-style mode from the first step. + - ``Tensor`` — the public torch_api wrapper users construct directly. + - ``Variable`` — what ``Tensor.reshape``/``.mean(dim=...)``/arithmetic + ops actually return today (they delegate to module-level functions + that build raw Variable, not the Tensor subclass). Without this + branch the type is silently lost after the first reshape and the + chain falls back to numpy. + + ``Parameter`` (ndarray subclass) is left untouched on the way out so + re-wrapping a stored weight doesn't break gradient bookkeeping. 
+ """ + try: + from grilly.nn.autograd import Variable as _Variable + from grilly.torch_api.tensor import LongTensor as _TorchLong + from grilly.torch_api.tensor import Tensor as _TorchTensor + + saw_torch_input = any( + isinstance(a, (_TorchTensor, _TorchLong, _Variable)) for a in args + ) + except ImportError: + _TorchTensor = None # type: ignore[assignment] + _TorchLong = None # type: ignore[assignment] + _Variable = None # type: ignore[assignment] + saw_torch_input = False converted_args = tuple(self._convert_input(arg) for arg in args) - # Convert keyword arguments that might be tensors converted_kwargs = { - k: self._convert_input(v) if self._is_tensor_like(v) else v for k, v in kwargs.items() + k: self._convert_input(v) if self._is_tensor_like(v) else v + for k, v in kwargs.items() } - return self.forward(*converted_args, **converted_kwargs) + out = self.forward(*converted_args, **converted_kwargs) + + # Wrap raw ndarray outputs back to Tensor when a torch input was seen. + # Skip ndarray subclasses (Parameter, LongTensor) and Variable outputs + # — those already carry a torch-compatible API. + if saw_torch_input and _TorchTensor is not None: + if isinstance(out, np.ndarray) and type(out) is np.ndarray: + return _TorchTensor(out, requires_grad=False) + return out def _is_tensor_like(self, obj: Any) -> bool: """Check if object is tensor-like and needs conversion""" @@ -226,6 +297,39 @@ def register_parameter(self, name: str, param: np.ndarray | None): self._parameters[name] = param + def register_buffer(self, name: str, tensor: np.ndarray | Any | None, persistent: bool = True): + """ + Register a non-trainable buffer (e.g. running state not in autograd). 
+ + Args: + name: Buffer attribute name + tensor: Numpy array or None to remove + persistent: If False, omitted from state_dict (PyTorch parity; optional) + """ + if tensor is None: + self._buffers.pop(name, None) + if hasattr(self, name): + delattr(self, name) + return + try: + from grilly.torch_api.tensor import Tensor as _TorchTensor + except ImportError: + _TorchTensor = None # type: ignore[assignment] + if _TorchTensor is not None and isinstance(tensor, _TorchTensor): + self._buffers[name] = tensor + setattr(self, name, tensor) + elif hasattr(tensor, "data") and not isinstance(tensor, np.ndarray): + arr = np.asarray(tensor.data) + self._buffers[name] = arr + setattr(self, name, arr) + else: + arr = np.asarray(tensor) + self._buffers[name] = arr + setattr(self, name, arr) + if not persistent: + # Mark non-persistent with a private set on the buffer entry + setattr(self, f"_np_buffer_{name}", True) + def state_dict(self) -> dict[str, Any]: """Execute state dict.""" @@ -243,6 +347,19 @@ def state_dict(self) -> dict[str, Any]: state[name] = param.data.copy() # ParamWrapper else: state[name] = param + try: + from grilly.torch_api.tensor import Tensor as _TorchTensor + except ImportError: + _TorchTensor = None # type: ignore[assignment] + for name, buf in self._buffers.items(): + if getattr(self, f"_np_buffer_{name}", False): + continue + if _TorchTensor is not None and isinstance(buf, _TorchTensor): + state[name] = buf.data.copy() + elif isinstance(buf, np.ndarray): + state[name] = buf.copy() + else: + state[name] = np.asarray(buf) for name, module in self._modules.items(): state[name] = module.state_dict() return state @@ -256,10 +373,41 @@ def load_state_dict(self, state_dict: dict[str, Any]): self._parameters[name] = state_dict[name].copy() else: self._parameters[name] = state_dict[name] + try: + from grilly.torch_api.tensor import Tensor as _TorchTensor + except ImportError: + _TorchTensor = None # type: ignore[assignment] + for name in 
list(self._buffers.keys()): + if name in state_dict: + prev = self._buffers.get(name) + raw = state_dict[name] + if _TorchTensor is not None and isinstance(prev, _TorchTensor): + self._buffers[name] = _TorchTensor( + np.asarray(raw, dtype=np.float32).copy(), requires_grad=False + ) + else: + val = np.asarray(raw, dtype=np.float32) + if val.dtype != np.float32 and np.issubdtype(val.dtype, np.floating): + val = val.astype(np.float32) + self._buffers[name] = val + setattr(self, name, self._buffers[name]) for name, module in self._modules.items(): if name in state_dict: module.load_state_dict(state_dict[name]) + def named_buffers(self): + """Yield (name, buffer) for all buffers including child modules.""" + for name, buf in self._buffers.items(): + yield name, buf + for prefix, module in self._modules.items(): + for n, b in module.named_buffers(): + yield f"{prefix}.{n}", b + + def buffers(self): + """Iterator over buffers.""" + for _, b in self.named_buffers(): + yield b + def gpu_mode(self, enable=True, device_local=True): """Enable GPU-resident output (returns VulkanTensor instead of numpy). @@ -315,6 +463,36 @@ def llama_cpp(self): return self.to("llama-cpp") + def __setattr__(self, name: str, value: Any) -> None: + """Auto-register ``Parameter`` / child ``Module`` attributes (PyTorch parity). + + Without this hook, ``self.weight = nn.Parameter(...)`` and + ``self.lin = nn.Linear(...)`` only land in ``self.__dict__``, never in + ``self._parameters`` / ``self._modules``. ``parameters()`` then walks + empty dicts and the optimizer sees no params to update — silently + breaking every Module subclass that uses the standard PyTorch + attribute-assignment idiom. + """ + if isinstance(value, Parameter): + # Ensure base __init__ has run; if not, fall back to plain assign. + params = self.__dict__.get("_parameters") + if params is not None: + # Drop any prior child module / buffer with the same name first. 
+ self.__dict__.get("_modules", {}).pop(name, None) + self.__dict__.get("_buffers", {}).pop(name, None) + params[name] = value + object.__setattr__(self, name, value) + return + elif isinstance(value, Module): + modules = self.__dict__.get("_modules") + if modules is not None: + self.__dict__.get("_parameters", {}).pop(name, None) + self.__dict__.get("_buffers", {}).pop(name, None) + modules[name] = value + object.__setattr__(self, name, value) + return + object.__setattr__(self, name, value) + def __repr__(self): """Return a debug representation.""" diff --git a/nn/module_list.py b/nn/module_list.py new file mode 100644 index 0000000..0992069 --- /dev/null +++ b/nn/module_list.py @@ -0,0 +1,48 @@ +"""ModuleList — PyTorch-compatible container of submodules.""" + +from __future__ import annotations + +from typing import Any, Iterator + +from .module import Module + + +class ModuleList(Module): + """Holds submodules in a list. Acts like ``nn.ModuleList`` in PyTorch.""" + + def __init__(self, modules: list[Module] | None = None): + super().__init__() + self._list: list[Module] = [] + if modules: + for i, m in enumerate(modules): + self._modules[str(i)] = m + self._list.append(m) + + def __getitem__(self, idx: int | slice) -> Module | list[Module]: + if isinstance(idx, slice): + return self._list[idx] + return self._list[idx] + + def __setitem__(self, idx: int, module: Module) -> None: + self._modules[str(idx)] = module + self._list[idx] = module + + def __len__(self) -> int: + return len(self._list) + + def __iter__(self) -> Iterator[Module]: + return iter(self._list) + + def append(self, module: Module) -> "ModuleList": + idx = len(self._list) + self._modules[str(idx)] = module + self._list.append(module) + return self + + def extend(self, modules: list[Module]) -> "ModuleList": + for m in modules: + self.append(m) + return self + + def forward(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError("ModuleList has no forward; iterate over submodules instead.") 
diff --git a/nn/normalization_modules.py b/nn/normalization_modules.py index 36b8b69..f75805c 100644 --- a/nn/normalization_modules.py +++ b/nn/normalization_modules.py @@ -37,25 +37,75 @@ def __init__(self, normalized_shape: int, eps: float = 1e-5): self.register_parameter("weight", self.weight) self.register_parameter("bias", self.bias) - def forward(self, x: np.ndarray) -> np.ndarray: - """Forward pass using fnn-layernorm.glsl (GPU-first)""" + def forward(self, x): + """Forward pass using fnn-layernorm.glsl (GPU-first). + + Autograd: when ``x`` is a ``Variable``, the output is wrapped in a + ``Variable`` whose ``GradFn`` calls ``self.backward(grad_output, x)`` + on loss.backward(). ``self.backward`` already populates + ``self.weight.grad`` and ``self.bias.grad`` in-place, so AdamW picks + them up through ``param.grad``. Mirrors ``nn.Linear.forward``'s + autograd wiring — see that file for the long-form rationale. + """ + try: + from grilly.nn.autograd import GradFn as _GradFn + from grilly.nn.autograd import Variable as _Variable + from grilly.nn.autograd import _grad_enabled + except ImportError: + _Variable = None # type: ignore[assignment] + _GradFn = None # type: ignore[assignment] + _grad_enabled = False + + x_var = None + if _Variable is not None and isinstance(x, _Variable): + x_var = x + x_data = x.data + else: + x_data = x + weight = _get_param_array(self.weight) bias = _get_param_array(self.bias) - # C++ bridge fast path (handles both numpy and VulkanTensor via __array__) + # ---- Existing forward path ---- + result = None if _USE_CPP_BRIDGE: - result = _bridge_to_numpy(_bridge.layernorm(x, weight, bias, self.eps)) - if result is not None: - return result - - backend = self._get_backend() - return backend.fnn.layernorm( - x, - weight, - bias, - eps=self.eps, - return_gpu_tensor=self._return_gpu_tensor, - ) + result = _bridge_to_numpy(_bridge.layernorm(x_data, weight, bias, self.eps)) + + if result is None: + backend = self._get_backend() + if backend 
is not None and hasattr(backend, "fnn"): + result = backend.fnn.layernorm( + x_data, + weight, + bias, + eps=self.eps, + return_gpu_tensor=self._return_gpu_tensor, + ) + else: + # CPU fallback — matches self.backward() math + mean = np.mean(x_data, axis=-1, keepdims=True) + var = np.var(x_data, axis=-1, keepdims=True) + normalized = (x_data - mean) / np.sqrt(var + self.eps) + result = normalized * weight + bias + + # ---- Autograd wiring ---- + if ( + x_var is not None + and _GradFn is not None + and _grad_enabled + and isinstance(result, np.ndarray) + and not isinstance(result, _Variable) + ): + x_data_for_backward = x_data + + def backward_fn(grad_output): + grad_input = self.backward(np.asarray(grad_output), x_data_for_backward) + return (grad_input,) + + grad_fn = _GradFn("LayerNorm", backward_fn, [x_var]) + return _Variable(np.asarray(result), requires_grad=True, grad_fn=grad_fn) + + return result def backward(self, grad_output: np.ndarray, x: np.ndarray = None) -> np.ndarray: """ diff --git a/nn/parameter.py b/nn/parameter.py index fb9494e..cf8d42d 100644 --- a/nn/parameter.py +++ b/nn/parameter.py @@ -23,6 +23,8 @@ def __new__(cls, data: np.ndarray, requires_grad: bool = True): data: Numpy array containing parameter values requires_grad: Whether this parameter requires gradients (default: True) """ + if hasattr(data, "data") and not isinstance(data, np.ndarray): + data = getattr(data, "data", data) obj = np.asarray(data, dtype=np.float32).view(cls) obj.requires_grad = requires_grad obj.grad = None @@ -42,6 +44,45 @@ def zero_grad(self): else: self.grad = np.zeros_like(self, dtype=np.float32) + def uniform_(self, a: float = 0.0, b: float = 1.0) -> "Parameter": + """In-place uniform fill (PyTorch ``Tensor.uniform_``).""" + self[:] = np.random.uniform(a, b, self.shape).astype(np.float32) + return self + + def zero_(self) -> "Parameter": + """Fill with zeros (PyTorch ``Tensor.zero_``).""" + self.fill(0.0) + return self + + # 
------------------------------------------------------------------ + # PyTorch-style shape methods so user ``forward`` code can do + # ``self.weight.unsqueeze(0)`` / ``.view(...)`` / ``.mean(dim=...)`` + # without having to know that Parameter is a numpy ndarray subclass. + # ------------------------------------------------------------------ + def unsqueeze(self, dim: int) -> np.ndarray: + """Insert a length-1 axis at *dim* (PyTorch ``Tensor.unsqueeze``).""" + return np.expand_dims(np.asarray(self), dim) + + def view(self, *shape) -> np.ndarray: + """Alias for ``reshape`` matching PyTorch's ``Tensor.view`` signature.""" + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = tuple(shape[0]) + return np.asarray(self).reshape(shape) + + def mean(self, dim=None, keepdims: bool = False, **kwargs): # type: ignore[override] + """``Tensor.mean(dim=...)`` — accepts both ``dim=`` (torch) and + ``axis=`` (numpy) so it composes cleanly with ``np.mean(p, axis=...)``. + """ + axis = kwargs.pop("axis", dim) + if kwargs: + # Forward any remaining numpy kwargs (dtype, out, where, ...). + return np.asarray(self).mean(axis=axis, keepdims=keepdims, **kwargs) + return np.asarray(self).mean(axis=axis, keepdims=keepdims) + + def detach(self) -> "Parameter": + """Return a Parameter detached from any (future) autograd graph.""" + return Parameter(np.array(self, copy=True), requires_grad=False) + def __repr__(self): """Return a debug representation.""" diff --git a/nn/utils.py b/nn/utils.py new file mode 100644 index 0000000..2df2888 --- /dev/null +++ b/nn/utils.py @@ -0,0 +1,61 @@ +"""``torch.nn.utils`` — gradient clipping and helpers.""" + +from __future__ import annotations + +import math +from collections.abc import Iterable +from typing import Any + +import numpy as np + + +def clip_grad_norm_( + parameters: Iterable[Any], + max_norm: float, + norm_type: float = 2.0, +) -> float: + """ + Clip gradient norm of an iterable of parameters (in-place). 
+ + Mirrors PyTorch: per-tensor norm, then global norm of those norms. + """ + params = [p for p in parameters if p is not None] + grads: list[np.ndarray] = [] + for p in params: + g = getattr(p, "grad", None) + if g is not None: + grads.append(np.asarray(g, dtype=np.float64)) + + if not grads: + return 0.0 + + if norm_type in (math.inf, float("inf")): + total_norm = max(float(np.max(np.abs(g))) for g in grads) + else: + norms = [] + for g in grads: + if norm_type in (2.0, 2): + norms.append(math.sqrt(float(np.sum(g * g)))) + else: + norms.append(float(np.sum(np.abs(g) ** norm_type) ** (1.0 / norm_type))) + stacked = np.array(norms, dtype=np.float64) + if norm_type in (2.0, 2): + total_norm = float(np.sqrt(np.sum(stacked * stacked))) + else: + total_norm = float(np.sum(np.abs(stacked) ** norm_type) ** (1.0 / norm_type)) + + clip_coef = max_norm / (total_norm + 1e-6) if max_norm > 0 else 1.0 + if clip_coef < 1.0: + for p in params: + g = getattr(p, "grad", None) + if g is None: + continue + if isinstance(g, np.ndarray): + g *= clip_coef + else: + try: + p.grad = np.asarray(g, dtype=np.float32) * clip_coef + except Exception: + pass + + return float(total_norm) diff --git a/nn/vsa_lm.py b/nn/vsa_lm.py new file mode 100644 index 0000000..621548b --- /dev/null +++ b/nn/vsa_lm.py @@ -0,0 +1,386 @@ +""" +VsaLmModel — VSA Language Model with AdditionLinear FFN layers. + +Wraps the fused C++ forward/backward (``grilly_core.vsa_lm_*``) for +GPU-accelerated training, with transparent NumPy CPU fallback. 
+ +Usage:: + + model = VsaLmModel(vocab=1000, d_model=256, d_ffn=512, max_seq=512, n_layers=6) + logits = model(input_ids) # (seq, vocab) + grads = model.backward(input_ids, grad_logits) # dict of gradients + model.step(grads, lr=1e-3) # SGD update +""" + +from __future__ import annotations + +import numpy as np + +from ._helpers import _get_param_array, _create_param_wrapper +from .module import Module +from .addition_linear import AdditionLinear + + +def _try_bridge(): + try: + from ..backend import _bridge + if _bridge.is_available(): + return _bridge + except Exception: + pass + return None + + +class VsaLmModel(Module): + """Full VSA-LM: Embedding → [LayerNorm → AdditionLinear FFN → Sign → AdditionLinear → Residual] × L → Output Proj. + + When the C++ fused kernel is available, ``forward``/``backward`` dispatch to + a single ``grilly_core`` call (GPU forward, Eigen backward). Otherwise + falls back to pure NumPy. + """ + + def __init__( + self, + vocab: int, + d_model: int, + d_ffn: int, + max_seq: int, + n_layers: int, + init_std: float = 0.02, + ): + super().__init__() + self.vocab = vocab + self.d_model = d_model + self.d_ffn = d_ffn + self.max_seq = max_seq + self.n_layers = n_layers + + rng = np.random.default_rng() + + self.embed_w = _create_param_wrapper( + rng.standard_normal((vocab, d_model)).astype(np.float32) * init_std + ) + self.pos_w = _create_param_wrapper( + rng.standard_normal((max_seq, d_model)).astype(np.float32) * init_std + ) + self.out_w = _create_param_wrapper( + rng.standard_normal((vocab, d_model)).astype(np.float32) * init_std + ) + + self.register_parameter("embed_w", self.embed_w) + self.register_parameter("pos_w", self.pos_w) + self.register_parameter("out_w", self.out_w) + + self.ffn_up: list[AdditionLinear] = [] + self.ffn_down: list[AdditionLinear] = [] + self.ln_gammas: list = [] + self.ln_betas: list = [] + + for l in range(n_layers): + up = AdditionLinear(d_model, d_ffn, bias=True) + down = AdditionLinear(d_ffn, d_model, 
bias=True) + self.ffn_up.append(up) + self.ffn_down.append(down) + self._modules[f"ffn_up_{l}"] = up + self._modules[f"ffn_down_{l}"] = down + + gamma = _create_param_wrapper(np.ones(d_model, dtype=np.float32)) + beta = _create_param_wrapper(np.zeros(d_model, dtype=np.float32)) + self.ln_gammas.append(gamma) + self.ln_betas.append(beta) + self.register_parameter(f"ln_gamma_{l}", gamma) + self.register_parameter(f"ln_beta_{l}", beta) + + self._handle: int | None = None + self._bridge = _try_bridge() + + # ── GPU handle lifecycle ──────────────────────────────────────────── + + def _upload(self): + """Upload all weights to GPU via bridge (idempotent).""" + if self._handle is not None: + return + br = self._bridge + if br is None: + return + + h = br.vsa_lm_upload( + _get_param_array(self.embed_w), + _get_param_array(self.pos_w), + [_get_param_array(u.weight) for u in self.ffn_up], + [_get_param_array(u.bias_param) for u in self.ffn_up], + [_get_param_array(d.weight) for d in self.ffn_down], + [_get_param_array(d.bias_param) for d in self.ffn_down], + [_get_param_array(g) for g in self.ln_gammas], + [_get_param_array(b) for b in self.ln_betas], + _get_param_array(self.out_w), + self.n_layers, self.d_model, self.d_ffn, + ) + self._handle = h + + def _sync_weights(self): + """Re-upload updated weights to existing GPU handle.""" + br = self._bridge + if br is None or self._handle is None: + return + br.vsa_lm_update_weights( + self._handle, + _get_param_array(self.embed_w), + _get_param_array(self.pos_w), + [_get_param_array(u.weight) for u in self.ffn_up], + [_get_param_array(u.bias_param) for u in self.ffn_up], + [_get_param_array(d.weight) for d in self.ffn_down], + [_get_param_array(d.bias_param) for d in self.ffn_down], + [_get_param_array(g) for g in self.ln_gammas], + [_get_param_array(b) for b in self.ln_betas], + _get_param_array(self.out_w), + ) + + def release(self): + """Free GPU resources.""" + if self._handle is not None and self._bridge is not None: + 
self._bridge.vsa_lm_release(self._handle) + self._handle = None + + def __del__(self): + try: + self.release() + except Exception: + pass + + # ── Forward ────────────────────────────────────────────────────────── + + def forward(self, input_ids: np.ndarray) -> np.ndarray: + """Run forward pass; returns logits (seq_len, vocab). + + Tries the fused GPU path first, falls back to NumPy. + """ + ids = np.ascontiguousarray(input_ids, dtype=np.int32) + + # GPU path + self._upload() + if self._handle is not None and self._bridge is not None: + result = self._bridge.vsa_lm_forward(self._handle, ids) + if result is not None: + return result + + # NumPy fallback + return self._forward_numpy(ids) + + def _forward_numpy(self, ids: np.ndarray) -> np.ndarray: + S = ids.shape[0] + d = self.d_model + embed = _get_param_array(self.embed_w) + pos = _get_param_array(self.pos_w) + out_w = _get_param_array(self.out_w) + + x = embed[ids.astype(np.int64)] + pos[:S] + + for l in range(self.n_layers): + gamma = _get_param_array(self.ln_gammas[l]) + beta = _get_param_array(self.ln_betas[l]) + + # LayerNorm + mean = x.mean(axis=-1, keepdims=True) + var = x.var(axis=-1, keepdims=True) + h = ((x - mean) / np.sqrt(var + 1e-5) * gamma + beta).astype(np.float32) + + # AdditionLinear up → sign → AdditionLinear down + h_up = self.ffn_up[l].forward(h) + h_sign = np.where(h_up > 0, 1.0, -1.0).astype(np.float32) + h_ffn = self.ffn_down[l].forward(h_sign) + + x = x + h_ffn + + scale = 1.0 / np.sqrt(np.float32(d)) + return (x @ out_w.T * scale).astype(np.float32) + + # ── Backward ───────────────────────────────────────────────────────── + + def backward(self, input_ids: np.ndarray, grad_logits: np.ndarray) -> dict: + """Run backward pass; returns gradient dict. + + GPU path returns one dict from C++. Fallback stores grads on parameters. 
+ """ + ids = np.ascontiguousarray(input_ids, dtype=np.int32) + gl = np.ascontiguousarray(grad_logits, dtype=np.float32) + + if self._handle is not None and self._bridge is not None: + result = self._bridge.vsa_lm_backward(self._handle, ids, gl) + if result is not None: + self._scatter_grads(result) + return result + + return self._backward_numpy(ids, gl) + + def _scatter_grads(self, grads: dict): + """Store C++ gradient arrays onto parameter .grad attributes.""" + def _acc(param, g): + if not hasattr(param, "grad") or param.grad is None: + param.grad = g.copy() + else: + param.grad += g + + _acc(self.embed_w, grads["grad_embed"]) + _acc(self.pos_w, grads["grad_pos"]) + _acc(self.out_w, grads["grad_out_w"]) + + for l in range(self.n_layers): + _acc(self.ffn_up[l].weight, grads["grad_ffn_up_w"][l]) + if self.ffn_up[l].bias_param is not None: + _acc(self.ffn_up[l].bias_param, grads["grad_ffn_up_b"][l]) + _acc(self.ffn_down[l].weight, grads["grad_ffn_down_w"][l]) + if self.ffn_down[l].bias_param is not None: + _acc(self.ffn_down[l].bias_param, grads["grad_ffn_down_b"][l]) + _acc(self.ln_gammas[l], grads["grad_ln_gamma"][l]) + _acc(self.ln_betas[l], grads["grad_ln_beta"][l]) + + def _backward_numpy(self, ids: np.ndarray, grad_logits: np.ndarray) -> dict: + """Pure NumPy backward (slow but correct reference).""" + S = ids.shape[0] + d = self.d_model + dF = self.d_ffn + V = self.vocab + L = self.n_layers + + embed = _get_param_array(self.embed_w) + pos = _get_param_array(self.pos_w) + out_w = _get_param_array(self.out_w) + + # Forward replay for activations + acts = [None] * (L + 1) + ln_outs = [None] * L + ffn_ups = [None] * L + sign_outs = [None] * L + + acts[0] = embed[ids.astype(np.int64)] + pos[:S] + + for l in range(L): + gamma = _get_param_array(self.ln_gammas[l]) + beta = _get_param_array(self.ln_betas[l]) + uw = _get_param_array(self.ffn_up[l].weight) + ub = _get_param_array(self.ffn_up[l].bias_param) + dw = _get_param_array(self.ffn_down[l].weight) + db = 
_get_param_array(self.ffn_down[l].bias_param) + + mean = acts[l].mean(axis=-1, keepdims=True) + var = acts[l].var(axis=-1, keepdims=True) + ln_outs[l] = ((acts[l] - mean) / np.sqrt(var + 1e-5) * gamma + beta).astype(np.float32) + + ffn_ups[l] = np.empty((S, dF), dtype=np.float32) + for o in range(dF): + ffn_ups[l][:, o] = -np.sum(np.abs(uw[o] - ln_outs[l]), axis=1) + ub[o] + + sign_outs[l] = np.where(ffn_ups[l] > 0, 1.0, -1.0).astype(np.float32) + + ffn_down_out = np.empty((S, d), dtype=np.float32) + for o in range(d): + ffn_down_out[:, o] = -np.sum(np.abs(dw[o] - sign_outs[l]), axis=1) + db[o] + + acts[l + 1] = acts[l] + ffn_down_out + + scale = 1.0 / np.sqrt(np.float32(d)) + gl_scaled = grad_logits * scale + + dx = gl_scaled @ out_w + grad_out_w = gl_scaled.T @ acts[L] + + grads = { + "grad_out_w": grad_out_w, + "grad_ffn_up_w": [], "grad_ffn_up_b": [], + "grad_ffn_down_w": [], "grad_ffn_down_b": [], + "grad_ln_gamma": [], "grad_ln_beta": [], + } + + for l in range(L - 1, -1, -1): + uw = _get_param_array(self.ffn_up[l].weight) + dw = _get_param_array(self.ffn_down[l].weight) + + # Addition-linear down backward + g_dn_w = np.zeros_like(dw) + g_dn_b = np.sum(dx, axis=0) + grad_sign = np.zeros((S, dF), dtype=np.float32) + for o in range(d): + sgn = np.sign(dw[o][np.newaxis, :] - sign_outs[l]) + g_col = dx[:, o:o+1] + g_dn_w[o] -= np.sum(g_col * sgn, axis=0) + grad_sign -= g_col * sgn + + # STE through sign + grad_ffn_up = grad_sign + + g_up_w = np.zeros_like(uw) + g_up_b = np.sum(grad_ffn_up, axis=0) + grad_ln = np.zeros((S, d), dtype=np.float32) + for o in range(dF): + sgn = np.sign(uw[o][np.newaxis, :] - ln_outs[l]) + g_col = grad_ffn_up[:, o:o+1] + g_up_w[o] -= np.sum(g_col * sgn, axis=0) + grad_ln -= g_col * sgn + + grads["grad_ffn_down_w"].insert(0, g_dn_w) + grads["grad_ffn_down_b"].insert(0, g_dn_b) + grads["grad_ffn_up_w"].insert(0, g_up_w) + grads["grad_ffn_up_b"].insert(0, g_up_b) + + gamma = _get_param_array(self.ln_gammas[l]) + mean = 
acts[l].mean(axis=-1, keepdims=True) + var = acts[l].var(axis=-1, keepdims=True) + inv_std = 1.0 / np.sqrt(var + 1e-5) + x_hat = (acts[l] - mean) * inv_std + + g_ln_gamma = np.sum(grad_ln * x_hat, axis=0) + g_ln_beta = np.sum(grad_ln, axis=0) + grads["grad_ln_gamma"].insert(0, g_ln_gamma) + grads["grad_ln_beta"].insert(0, g_ln_beta) + + dl_dxhat = grad_ln * gamma + sum1 = dl_dxhat.sum(axis=-1, keepdims=True) + sum2 = (dl_dxhat * x_hat).sum(axis=-1, keepdims=True) + grad_x_ln = inv_std * (dl_dxhat - sum1 / d - x_hat * sum2 / d) + + dx = dx + grad_x_ln + + grad_embed = np.zeros((V, d), dtype=np.float32) + grad_pos = np.zeros((self.max_seq, d), dtype=np.float32) + for s in range(S): + tok = int(ids[s]) + if 0 <= tok < V: + grad_embed[tok] += dx[s] + grad_pos[s] = dx[s] + + grads["grad_embed"] = grad_embed + grads["grad_pos"] = grad_pos + + self._scatter_grads(grads) + return grads + + # ── Optimizer step ─────────────────────────────────────────────────── + + def zero_grad(self): + """Reset all parameter gradients to zero.""" + for p in self.parameters(): + if hasattr(p, "grad"): + p.grad = None + + def step(self, grads: dict = None, lr: float = 1e-3): + """Simple SGD step using stored .grad or explicit grads dict. + + After stepping, re-uploads weights to GPU. 
+ """ + if grads is not None: + self._scatter_grads(grads) + + for p in self.parameters(): + arr = _get_param_array(p) + if hasattr(p, "grad") and p.grad is not None: + arr -= lr * p.grad + + self._sync_weights() + + def __repr__(self): + return ( + f"VsaLmModel(vocab={self.vocab}, d_model={self.d_model}, " + f"d_ffn={self.d_ffn}, max_seq={self.max_seq}, " + f"n_layers={self.n_layers})" + ) diff --git a/pyproject.toml b/pyproject.toml index bc9367a..309bd6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,10 +60,20 @@ dev = [ "isort>=5.12.0", "mypy>=1.5.0", "mkdocs-material>=9.0", + "tokenizers>=0.15.0", + "huggingface_hub>=0.20.0", + "sentencepiece>=0.2.0", + "transformers>=4.57.6", + "protobuf>=4.0.0", ] accel = [ "numba>=0.59.0", # JIT-compiled CPU fallbacks ] +# Rust tokenizers + Hub file download (no transformers); used by grilly.tokenizers.FastTokenizer +tokenizer = [ + "tokenizers>=0.15.0", + "huggingface_hub>=0.20.0", +] all = [ "grilly[dev,accel]", ] @@ -77,6 +87,7 @@ Documentation = "https://grilly.org/docs" [tool.setuptools] packages = [ "grilly", + "grilly.torch_api", "grilly.backend", "grilly.grilly_datasets", "grilly.examples", @@ -85,6 +96,7 @@ packages = [ "grilly.optim", "grilly.utils", "grilly.functional", + "grilly.tokenizers", "grilly.scripts", "grilly.tutorials", "grilly.tests", @@ -93,6 +105,7 @@ packages = [ [tool.setuptools.package-dir] "grilly" = "." 
+"grilly.torch_api" = "torch_api" "grilly.backend" = "backend" "grilly.nn" = "nn" "grilly.optim" = "optim" @@ -105,6 +118,7 @@ packages = [ "grilly.experimental_datasets" = "experimental_datasets" "grilly.grilly_datasets" = "grilly_datasets" "grilly.examples" = "examples" +"grilly.tokenizers" = "tokenizer_impl" [tool.setuptools.package-data] grilly = [ diff --git a/shaders/mf-sigmoid.glsl b/shaders/mf-sigmoid.glsl new file mode 100644 index 0000000..240b0e0 --- /dev/null +++ b/shaders/mf-sigmoid.glsl @@ -0,0 +1,25 @@ +#version 450 + +// Rational sigmoid: x / (1 + |x|) + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(push_constant) uniform PushConsts { + uint total_elements; +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + if (gID >= total_elements) return; + float x = input_data[gID]; + float ax = abs(x) + 1.0; + output_data[gID] = x / ax; +} diff --git a/shaders/mf-softmax.glsl b/shaders/mf-softmax.glsl new file mode 100644 index 0000000..977c166 --- /dev/null +++ b/shaders/mf-softmax.glsl @@ -0,0 +1,86 @@ +#version 450 + +// Multiplication-free softmax: relu(x - max) / sum(relu(x - max)) — no exp(). +// Same 3-pass layout as activation-softmax.glsl; pass 2 sums positive parts only. 
+ +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(set = 0, binding = 2) buffer MaxValues { + float max_vals[]; +}; + +layout(set = 0, binding = 3) buffer SumPos { + float sum_pos[]; +}; + +layout(push_constant) uniform PushConsts { + uint batch_size; + uint seq_len; + uint features; + uint pass_type; + uint dim; +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + + if (pass_type == 0u) { + uint total_positions = batch_size * seq_len; + if (gID >= total_positions) return; + + uint batch_idx = gID / seq_len; + uint seq_idx = gID % seq_len; + + float max_val = -1e10; + for (uint f = 0; f < features; f++) { + uint idx = batch_idx * seq_len * features + seq_idx * features + f; + max_val = max(max_val, input_data[idx]); + } + max_vals[gID] = max_val; + + } else if (pass_type == 1u) { + uint total_positions = batch_size * seq_len; + if (gID >= total_positions) return; + + uint batch_idx = gID / seq_len; + uint seq_idx = gID % seq_len; + float max_val = max_vals[gID]; + + float sum = 0.0; + for (uint f = 0; f < features; f++) { + uint idx = batch_idx * seq_len * features + seq_idx * features + f; + float v = input_data[idx] - max_val; + sum += max(v, 0.0); + } + sum_pos[gID] = sum; + + } else if (pass_type == 2u) { + uint total_elements = batch_size * seq_len * features; + if (gID >= total_elements) return; + + uint batch_idx = gID / (seq_len * features); + uint remainder = gID % (seq_len * features); + uint seq_idx = remainder / features; + + uint pos_idx = batch_idx * seq_len + seq_idx; + float max_val = max_vals[pos_idx]; + float sum = sum_pos[pos_idx]; + + float val = input_data[gID]; + float z = max(val - max_val, 0.0); + if (sum < 1e-5) { + output_data[gID] = 1.0 / float(features); + } else { + float denom = max(sum, 1e-6); + output_data[gID] = z / denom; + } + } +} diff 
--git a/shaders/mf-softplus.glsl b/shaders/mf-softplus.glsl new file mode 100644 index 0000000..62fe139 --- /dev/null +++ b/shaders/mf-softplus.glsl @@ -0,0 +1,26 @@ +#version 450 + +// Algebraic softplus: 0.5 * (x + sqrt(x*x + c)) with c = 4/beta^2 — no exp/log. + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(push_constant) uniform PushConsts { + uint total_elements; + float c; // 4 / (beta * beta), set from host +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + if (gID >= total_elements) return; + float x = input_data[gID]; + float s = sqrt(x * x + c); + output_data[gID] = 0.5 * (x + s); +} diff --git a/shaders/moe-layer-backward-vec4.glsl b/shaders/moe-layer-backward-vec4.glsl new file mode 100644 index 0000000..a55c0d8 --- /dev/null +++ b/shaders/moe-layer-backward-vec4.glsl @@ -0,0 +1,87 @@ +#version 450 + +// Fused MoE Backward (vec4): grad_input for all 4 experts. +// RDNA2 optimized: vec4 loads, 32x8 workgroup (matches Wave32), LDS padding. +// +// dx[row, k] = grad_out[row, k] + sum_e(w_e * sum_col(grad_out[row, col] * W_e[col, k])) +// All buffers vec4 packed (d_model must be multiple of 4). 
+ +layout (local_size_x = 32, local_size_y = 8) in; + +layout(set = 0, binding = 0) readonly buffer GradOut { vec4 g_out[]; }; +layout(set = 0, binding = 1) readonly buffer ExpertWeights { vec4 ew[]; }; +layout(set = 0, binding = 2) writeonly buffer GradInput { vec4 g_in[]; }; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + float w0, w1, w2, w3; +}; + +shared vec4 tileG[8][33]; // grad_out tile, 8 rows × 32 vec4 cols + padding + +void main() { + uint row = gl_WorkGroupID.y * 8 + gl_LocalInvocationID.y; + uint k_vec = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + uint d_vec = d_model / 4; + + if (row >= seq_len || k_vec >= d_vec) return; + + vec4 acc0 = vec4(0.0), acc1 = vec4(0.0), acc2 = vec4(0.0), acc3 = vec4(0.0); + + uint numTiles = (d_vec + 31) / 32; + uint ew_stride = d_model * d_vec; // vec4s per expert + + for (uint t = 0; t < numTiles; t++) { + // Load grad_out into LDS + uint col_vec = t * 32 + tx; + tileG[ty][tx] = (col_vec < d_vec) ? 
g_out[row * d_vec + col_vec] : vec4(0.0); + + memoryBarrierShared(); + barrier(); + + // Accumulate: dx[k] += grad[col] * W[col, k] + for (uint c = 0; c < 32; c++) { + vec4 g = tileG[ty][c]; + uint col_global = t * 32 + c; // vec4 index + if (col_global >= d_vec) break; + + // Each vec4 g has 4 scalar grad components for 4 adjacent columns + // W[col, k] accessed as ew[expert * stride + col_scalar * d_vec + k_vec] + + // Expert 0 + acc0 += g.x * ew[0 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc0 += g.y * ew[0 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc0 += g.z * ew[0 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc0 += g.w * ew[0 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + + // Expert 1 + acc1 += g.x * ew[1 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc1 += g.y * ew[1 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc1 += g.z * ew[1 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc1 += g.w * ew[1 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + + // Expert 2 + acc2 += g.x * ew[2 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc2 += g.y * ew[2 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc2 += g.z * ew[2 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc2 += g.w * ew[2 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + + // Expert 3 + acc3 += g.x * ew[3 * ew_stride + (col_global * 4 + 0) * d_vec + k_vec]; + acc3 += g.y * ew[3 * ew_stride + (col_global * 4 + 1) * d_vec + k_vec]; + acc3 += g.z * ew[3 * ew_stride + (col_global * 4 + 2) * d_vec + k_vec]; + acc3 += g.w * ew[3 * ew_stride + (col_global * 4 + 3) * d_vec + k_vec]; + } + + barrier(); + } + + vec4 dx = w0 * acc0 + w1 * acc1 + w2 * acc2 + w3 * acc3; + uint idx = row * d_vec + k_vec; + g_in[idx] = g_out[idx] + dx; // residual gradient +} diff --git a/shaders/moe-layer-backward.glsl b/shaders/moe-layer-backward.glsl new file mode 100644 index 0000000..e1b4f5d --- /dev/null +++ 
b/shaders/moe-layer-backward.glsl @@ -0,0 +1,110 @@ +#version 450 + +// Fused MoE Layer Backward: compute dx (gradient w.r.t. input) for all 4 experts in one dispatch. +// +// dx[row, k] = sum_e(router_w[e] * sum_col(grad_out[row, col] * expert_w[e][col, k])) +// = sum_e(router_w[e] * (grad_out[row,:] @ expert_w[e][:, k])) +// +// This is: dx = sum_e(w_e * grad_out @ W_e) +// Each expert weight matrix W_e is (d_model, d_model), stored row-major as W_e[col, k]. +// So grad_out @ W_e means: for output element (row, k), sum over col: grad[row,col] * W[col,k] +// +// Tiled: 16x16, K tiled in blocks of 16. + +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +// Gradient w.r.t. layer output (seq_len, d_model) — this is dx from the next layer +layout(set = 0, binding = 0) readonly buffer GradOut { + float grad_out[]; +}; + +// Expert weights (same as forward): 4 experts × (d_model, d_model) +layout(set = 0, binding = 1) readonly buffer ExpertWeights { + float ew[]; +}; + +// Output: gradient w.r.t. layer input (seq_len, d_model) +layout(set = 0, binding = 2) buffer GradInput { + float grad_in[]; +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + float router_w0; + float router_w1; + float router_w2; + float router_w3; +}; + +shared float tileG[16][17]; // grad_out tile + +void main() { + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint k_out = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; // output K index + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + if (row >= seq_len || k_out >= d_model) return; + + float acc0 = 0.0, acc1 = 0.0, acc2 = 0.0, acc3 = 0.0; + + uint numTiles = (d_model + 15) / 16; + uint dd = d_model * d_model; + uint row_base = row * d_model; + + for (uint t = 0; t < numTiles; t++) { + uint col_base = t * 16; + + // Load grad_out tile + uint col = col_base + tx; + tileG[ty][tx] = (col < d_model) ? 
grad_out[row_base + col] : 0.0; + + barrier(); + + // Accumulate: dx[row, k] += grad[row, col] * W[e][col, k] + // W layout: ew[e * d*d + col * d + k] + for (uint c = 0; c < 16; c += 4) { + uint col0 = col_base + c; + float g0 = tileG[ty][c]; + float g1 = tileG[ty][c + 1]; + float g2 = tileG[ty][c + 2]; + float g3 = tileG[ty][c + 3]; + + // Expert 0: W[0][col, k_out] + acc0 = fma(g0, ew[0 * dd + (col0) * d_model + k_out], acc0); + acc0 = fma(g1, ew[0 * dd + (col0 + 1) * d_model + k_out], acc0); + acc0 = fma(g2, ew[0 * dd + (col0 + 2) * d_model + k_out], acc0); + acc0 = fma(g3, ew[0 * dd + (col0 + 3) * d_model + k_out], acc0); + + // Expert 1 + acc1 = fma(g0, ew[1 * dd + (col0) * d_model + k_out], acc1); + acc1 = fma(g1, ew[1 * dd + (col0 + 1) * d_model + k_out], acc1); + acc1 = fma(g2, ew[1 * dd + (col0 + 2) * d_model + k_out], acc1); + acc1 = fma(g3, ew[1 * dd + (col0 + 3) * d_model + k_out], acc1); + + // Expert 2 + acc2 = fma(g0, ew[2 * dd + (col0) * d_model + k_out], acc2); + acc2 = fma(g1, ew[2 * dd + (col0 + 1) * d_model + k_out], acc2); + acc2 = fma(g2, ew[2 * dd + (col0 + 2) * d_model + k_out], acc2); + acc2 = fma(g3, ew[2 * dd + (col0 + 3) * d_model + k_out], acc2); + + // Expert 3 + acc3 = fma(g0, ew[3 * dd + (col0) * d_model + k_out], acc3); + acc3 = fma(g1, ew[3 * dd + (col0 + 1) * d_model + k_out], acc3); + acc3 = fma(g2, ew[3 * dd + (col0 + 2) * d_model + k_out], acc3); + acc3 = fma(g3, ew[3 * dd + (col0 + 3) * d_model + k_out], acc3); + } + + barrier(); + } + + // Weighted sum + residual gradient (dx passes through residual) + float dx = router_w0 * acc0 + router_w1 * acc1 + + router_w2 * acc2 + router_w3 * acc3; + + // Residual: grad_input = grad_out + dx (grad passes through addition) + uint idx = row * d_model + k_out; + grad_in[idx] = grad_out[idx] + dx; +} diff --git a/shaders/moe-layer-fused-vec4.glsl b/shaders/moe-layer-fused-vec4.glsl new file mode 100644 index 0000000..153f33c --- /dev/null +++ b/shaders/moe-layer-fused-vec4.glsl @@ -0,0 
+1,106 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +// Fused MoE Layer (vec4): 4 expert matmuls + blend + residual. +// RDNA2 optimized: vec4 loads (128-bit), 16x16 workgroup, LDS bank padding. +// +// output[row, col] = input[row, col] + sum_e(w[e] * (input[row,:] @ expert_w[e][col,:])) +// All buffers use vec4 packing (d_model must be multiple of 4). + +layout (local_size_x = 16, local_size_y = 16) in; + +layout(set = 0, binding = 0) readonly buffer Input { vec4 x[]; }; +layout(set = 0, binding = 1) readonly buffer ExpertWeights { vec4 ew[]; }; +layout(set = 0, binding = 2) writeonly buffer Output { vec4 out_data[]; }; +layout(set = 0, binding = 3) readonly buffer RouterScratch { float rscratch[]; }; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; // must be multiple of 4 + uint n_experts; + uint weights_offset; +}; + +shared vec4 tileX[16][17]; // input tile with bank-conflict padding + +void main() { + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint col_vec = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint tx = gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + + uint d_vec = d_model / 4; + + if (row >= seq_len || col_vec >= d_vec) return; + + // Load router weights + float w0 = rscratch[weights_offset]; + float w1 = rscratch[weights_offset + 1]; + float w2 = rscratch[weights_offset + 2]; + float w3 = rscratch[weights_offset + 3]; + + vec4 acc0 = vec4(0.0), acc1 = vec4(0.0), acc2 = vec4(0.0), acc3 = vec4(0.0); + + uint numTiles = (d_vec + 15) / 16; + // Expert weight stride: each expert is d_model rows × d_vec vec4s per row + // ew layout: ew[expert * d_model * d_vec + col * d_vec + k_vec] + // But we need W[col_out, k_in] which is ew[expert * d * d_vec + col_vec*4*d_vec + k_vec] + // Simplified: expert weights are (d, d) row-major packed as vec4 + // ew[expert * d * d_vec + row_w * d_vec + k_vec] + uint ew_stride = d_model * d_vec; // vec4s per expert + + [[unroll]] + 
for (uint t = 0; t < numTiles; t++) { + uint k_base = t * 16; + + // Load input tile: x[row, k_base*4 .. (k_base+16)*4] + uint k_vec = k_base + tx; + tileX[ty][tx] = (k_vec < d_vec) ? x[row * d_vec + k_vec] : vec4(0.0); + + memoryBarrierShared(); + barrier(); + + // Accumulate for all 4 experts + // Each iteration: load 4 floats from input (via tileX), multiply with expert weight vec4 + [[unroll]] + for (uint k = 0; k < 16; k++) { + vec4 xv = tileX[ty][k]; + uint k_global = k_base + k; // vec4 index in K dimension + + // For each expert: W[col_out_vec4, k_global_vec4] + // col_vec gives us the output vec4 column + // We need 4 loads per expert (one per component of xv) + // ew index: expert * ew_stride + (col_vec*4 + component) * d_vec + k_global + + // Expert 0 + acc0 += xv.x * ew[0 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc0 += xv.y * ew[0 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc0 += xv.z * ew[0 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc0 += xv.w * ew[0 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + + // Expert 1 + acc1 += xv.x * ew[1 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc1 += xv.y * ew[1 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc1 += xv.z * ew[1 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc1 += xv.w * ew[1 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + + // Expert 2 + acc2 += xv.x * ew[2 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc2 += xv.y * ew[2 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc2 += xv.z * ew[2 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc2 += xv.w * ew[2 * ew_stride + (col_vec * 4 + 3) * d_vec + k_global]; + + // Expert 3 + acc3 += xv.x * ew[3 * ew_stride + (col_vec * 4 + 0) * d_vec + k_global]; + acc3 += xv.y * ew[3 * ew_stride + (col_vec * 4 + 1) * d_vec + k_global]; + acc3 += xv.z * ew[3 * ew_stride + (col_vec * 4 + 2) * d_vec + k_global]; + acc3 += xv.w * ew[3 * ew_stride + (col_vec * 4 + 3) 
* d_vec + k_global]; + } + + barrier(); + } + + vec4 blended = w0 * acc0 + w1 * acc1 + w2 * acc2 + w3 * acc3; + uint out_idx = row * d_vec + col_vec; + out_data[out_idx] = x[out_idx] + blended; +} diff --git a/shaders/moe-layer-fused.glsl b/shaders/moe-layer-fused.glsl new file mode 100644 index 0000000..cd8023c --- /dev/null +++ b/shaders/moe-layer-fused.glsl @@ -0,0 +1,104 @@ +#version 450 + +// Fused MoE Layer: 4 expert matmuls + weighted blend + residual in ONE dispatch. +// +// Router weights read from scratch buffer (computed by moe-router shader). +// No CPU round-trip between layers. +// +// output[row, col] = input[row, col] + sum_e(w[e] * (input[row,:] @ expert_w[e][col,:])) + +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float x[]; +}; + +layout(set = 0, binding = 1) readonly buffer ExpertWeights { + float ew[]; // 4 experts packed: ew[e * d*d + col * d + k] +}; + +layout(set = 0, binding = 2) buffer Output { + float out_data[]; +}; + +// Router scratch buffer: [x_mean (d) | logits (E) | max (1) | sum (1) | weights (E)] +layout(set = 0, binding = 3) readonly buffer RouterScratch { + float rscratch[]; +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + uint weights_offset; // offset in rscratch where softmax weights start +}; + +shared float tileX[16][17]; + +void main() { + uint row = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint col = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + if (row >= seq_len || col >= d_model) return; + + // Read router weights from scratch buffer + float w0 = rscratch[weights_offset]; + float w1 = rscratch[weights_offset + 1]; + float w2 = rscratch[weights_offset + 2]; + float w3 = rscratch[weights_offset + 3]; + + float acc0 = 0.0, acc1 = 0.0, acc2 = 0.0, acc3 = 0.0; + + uint numTiles = (d_model + 15) / 16; + uint 
row_base = row * d_model; + uint dd = d_model * d_model; + + for (uint t = 0; t < numTiles; t++) { + uint k_base = t * 16; + + uint k_a = k_base + tx; + tileX[ty][tx] = (k_a < d_model) ? x[row_base + k_a] : 0.0; + + barrier(); + + uint e0_base = 0 * dd + col * d_model + k_base; + uint e1_base = 1 * dd + col * d_model + k_base; + uint e2_base = 2 * dd + col * d_model + k_base; + uint e3_base = 3 * dd + col * d_model + k_base; + + for (uint k = 0; k < 16; k += 4) { + float x0 = tileX[ty][k]; + float x1 = tileX[ty][k + 1]; + float x2 = tileX[ty][k + 2]; + float x3 = tileX[ty][k + 3]; + + acc0 = fma(x0, ew[e0_base + k], acc0); + acc0 = fma(x1, ew[e0_base + k + 1], acc0); + acc0 = fma(x2, ew[e0_base + k + 2], acc0); + acc0 = fma(x3, ew[e0_base + k + 3], acc0); + + acc1 = fma(x0, ew[e1_base + k], acc1); + acc1 = fma(x1, ew[e1_base + k + 1], acc1); + acc1 = fma(x2, ew[e1_base + k + 2], acc1); + acc1 = fma(x3, ew[e1_base + k + 3], acc1); + + acc2 = fma(x0, ew[e2_base + k], acc2); + acc2 = fma(x1, ew[e2_base + k + 1], acc2); + acc2 = fma(x2, ew[e2_base + k + 2], acc2); + acc2 = fma(x3, ew[e2_base + k + 3], acc2); + + acc3 = fma(x0, ew[e3_base + k], acc3); + acc3 = fma(x1, ew[e3_base + k + 1], acc3); + acc3 = fma(x2, ew[e3_base + k + 2], acc3); + acc3 = fma(x3, ew[e3_base + k + 3], acc3); + } + + barrier(); + } + + float blended = w0 * acc0 + w1 * acc1 + w2 * acc2 + w3 * acc3; + uint idx = row * d_model + col; + out_data[idx] = x[idx] + blended; +} diff --git a/shaders/moe-layer-grad-weight.glsl b/shaders/moe-layer-grad-weight.glsl new file mode 100644 index 0000000..31fddf4 --- /dev/null +++ b/shaders/moe-layer-grad-weight.glsl @@ -0,0 +1,85 @@ +#version 450 + +// Fused MoE grad_W: compute weight gradients for all 4 experts in one dispatch. +// +// grad_W[e][col, k] = router_w[e] * sum_row(grad_out[row, col] * input[row, k]) +// = router_w[e] * (grad_out[:,col].T @ input[:,k]) +// +// Each thread computes one (col, k) element for all 4 experts. 
+// Tiled: 16x16, reduction over seq_len in blocks of 16. + +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer GradOut { + float grad_out[]; // (seq_len, d_model) +}; + +layout(set = 0, binding = 1) readonly buffer Input { + float x[]; // (seq_len, d_model) +}; + +// Output: 4 expert weight gradients packed contiguously +// grad_ew[e * d*d + col * d + k] +layout(set = 0, binding = 2) buffer GradWeights { + float grad_ew[]; // (4 * d_model * d_model) +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + float router_w0; + float router_w1; + float router_w2; + float router_w3; +}; + +shared float tileG[16][17]; // grad_out[:, col] tile +shared float tileX[16][17]; // input[:, k] tile + +void main() { + uint col = gl_WorkGroupID.y * 16 + gl_LocalInvocationID.y; + uint k = gl_WorkGroupID.x * 16 + gl_LocalInvocationID.x; + uint ty = gl_LocalInvocationID.y; + uint tx = gl_LocalInvocationID.x; + + if (col >= d_model || k >= d_model) return; + + float acc = 0.0; // shared across experts (same outer product, different scaling) + + uint numTiles = (seq_len + 15) / 16; + + for (uint t = 0; t < numTiles; t++) { + uint row_base = t * 16; + + // Load grad_out column tile: grad_out[row_base+ty, col] + uint r_g = row_base + ty; + tileG[ty][tx] = (r_g < seq_len && col < d_model) + ? grad_out[r_g * d_model + col] : 0.0; + + // Load input column tile: x[row_base+ty, k] + uint r_x = row_base + ty; + tileX[ty][tx] = (r_x < seq_len && k < d_model) + ? 
x[r_x * d_model + k] : 0.0; + + barrier(); + + // Accumulate outer product: sum_row(grad[row, col] * x[row, k]) + for (uint r = 0; r < 16; r += 4) { + acc = fma(tileG[r][ty], tileX[r][tx], acc); // Note: ty indexes col, tx indexes k + acc = fma(tileG[r+1][ty], tileX[r+1][tx], acc); + acc = fma(tileG[r+2][ty], tileX[r+2][tx], acc); + acc = fma(tileG[r+3][ty], tileX[r+3][tx], acc); + } + + barrier(); + } + + // Write grad_W for all 4 experts (same outer product, scaled by router weight) + uint dd = d_model * d_model; + uint base = col * d_model + k; + grad_ew[0 * dd + base] = router_w0 * acc; + grad_ew[1 * dd + base] = router_w1 * acc; + grad_ew[2 * dd + base] = router_w2 * acc; + grad_ew[3 * dd + base] = router_w3 * acc; +} diff --git a/shaders/moe-router.glsl b/shaders/moe-router.glsl new file mode 100644 index 0000000..60d66e2 --- /dev/null +++ b/shaders/moe-router.glsl @@ -0,0 +1,84 @@ +#version 450 + +// MoE Router: mean(x) → logits → softmax → router weights +// +// Three passes (controlled by push constant `pass`): +// Pass 0: Compute mean(x) across seq_len → x_mean (d_model,) +// Pass 1: Compute logits = router_W @ x_mean + router_b → (n_experts,) +// Pass 2: Stable softmax over logits → router_weights (n_experts,) +// +// Scratch layout: [x_mean (d) | logits (E) | weights (E)] + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float x[]; +}; + +layout(set = 0, binding = 1) readonly buffer RouterW { + float rw[]; +}; + +layout(set = 0, binding = 2) readonly buffer RouterB { + float rb[]; +}; + +layout(set = 0, binding = 3) buffer Scratch { + float scratch[]; +}; + +layout(push_constant) uniform PushConsts { + uint seq_len; + uint d_model; + uint n_experts; + uint pass; +}; + +void main() { + uint gid = gl_GlobalInvocationID.x; + + if (pass == 0u) { + // Pass 0: mean reduction + if (gid >= d_model) return; + float sum = 0.0; + for (uint s = 0; s < seq_len; s++) { + sum += x[s * 
d_model + gid]; + } + scratch[gid] = sum / float(seq_len); + } + else if (pass == 1u) { + // Pass 1: logits = rw @ x_mean + rb + if (gid >= n_experts) return; + float dot = 0.0; + for (uint k = 0; k < d_model; k++) { + dot += rw[gid * d_model + k] * scratch[k]; + } + scratch[d_model + gid] = dot + rb[gid]; + } + else if (pass == 2u) { + // Pass 2: numerically stable softmax (single thread, n_experts is small) + if (gid != 0u) return; + uint logits_off = d_model; + uint weights_off = d_model + n_experts; + + // Find max for numerical stability + float max_val = scratch[logits_off]; + for (uint e = 1; e < n_experts; e++) { + max_val = max(max_val, scratch[logits_off + e]); + } + + // Exp with max subtraction + sum + float sum_exp = 0.0; + for (uint e = 0; e < n_experts; e++) { + float ev = exp(scratch[logits_off + e] - max_val); + scratch[weights_off + e] = ev; + sum_exp += ev; + } + + // Normalize + float inv_sum = 1.0 / (sum_exp + 1e-8); + for (uint e = 0; e < n_experts; e++) { + scratch[weights_off + e] *= inv_sum; + } + } +} diff --git a/shaders/sign-activation.glsl b/shaders/sign-activation.glsl new file mode 100644 index 0000000..ff35efb --- /dev/null +++ b/shaders/sign-activation.glsl @@ -0,0 +1,24 @@ +#version 450 + +// Sign activation: output[i] = (input[i] > 0) ? 1.0 : -1.0 +// Exactly zero maps to -1 (subgradient convention matching backward). + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(set = 0, binding = 0) readonly buffer Input { + float input_data[]; +}; + +layout(set = 0, binding = 1) buffer Output { + float output_data[]; +}; + +layout(push_constant) uniform PushConsts { + uint total_elements; +}; + +void main() { + uint gID = gl_GlobalInvocationID.x; + if (gID >= total_elements) return; + output_data[gID] = (input_data[gID] > 0.0) ? 
1.0 : -1.0; +} diff --git a/shaders/softmax-fast.glsl b/shaders/softmax-fast.glsl new file mode 100644 index 0000000..9bb922d --- /dev/null +++ b/shaders/softmax-fast.glsl @@ -0,0 +1,77 @@ +#version 450 +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +// RDNA2-optimized softmax: subgroup reductions + online max rescaling. +// 256 threads per workgroup (8 subgroups of 32 on Wave32). +// Numerically stable — max subtraction built into the reduction. +// +// Each workgroup processes one row of length <= 256. +// For longer rows, dispatch multiple workgroups per row (requires cross-WG reduction). + +layout(local_size_x = 256) in; + +layout(set = 0, binding = 0) buffer Input { float data[]; } In; +layout(set = 0, binding = 1) buffer Output { float data[]; } Out; + +layout(push_constant) uniform PushConsts { + uint batch_size; // number of rows + uint row_length; // elements per row +}; + +shared float lds_max[8]; +shared float lds_sum[8]; + +void main() { + uint row = gl_WorkGroupID.x; + uint l_idx = gl_LocalInvocationID.x; + uint sg_id = gl_SubgroupInvocationID; + uint sg_idx = gl_SubgroupID; + + if (row >= batch_size) return; + + uint base = row * row_length; + + // Load value (pad with -inf for threads beyond row_length) + float val = (l_idx < row_length) ? 
In.data[base + l_idx] : -1e38; + + // --- STEP 1: Subgroup-Level Online Reduction --- + float m = subgroupMax(val); + float d = subgroupAdd(exp(val - m)); + + // --- STEP 2: Workgroup-Level Reduction via LDS --- + if (subgroupElect()) { + lds_max[sg_idx] = m; + lds_sum[sg_idx] = d; + } + barrier(); + + // Final reduction by first subgroup + if (sg_idx == 0) { + float final_m = -1e38; + float final_d = 0.0; + + uint n_sgs = (row_length + 31) / 32; + if (n_sgs > 8) n_sgs = 8; + + for (uint i = 0; i < n_sgs; ++i) { + float prev_m = final_m; + final_m = max(final_m, lds_max[i]); + final_d = final_d * exp(prev_m - final_m) + lds_sum[i] * exp(lds_max[i] - final_m); + } + + if (sg_id == 0) { + lds_max[0] = final_m; + lds_sum[0] = final_d; + } + } + barrier(); + + // --- STEP 3: Final Computation --- + float wg_max = lds_max[0]; + float wg_sum = lds_sum[0]; + + if (l_idx < row_length) { + Out.data[base + l_idx] = exp(val - wg_max) / wg_sum; + } +} diff --git a/shaders/spv/mf-sigmoid.spv b/shaders/spv/mf-sigmoid.spv new file mode 100644 index 0000000000000000000000000000000000000000..89a180c492eba201c0a40f64adfc58d6fadae14b GIT binary patch literal 1580 zcmZ9MYflqV5QYyeEhxxEZeDRK-VntbAZpYQO-d6IP5c4PDg;SNYtqH|QNQ_X{1N^s zjfsiRb9NUR51Gu&J2UT`IkTl&SsDppD2#>C@HFISI#fUcSUu|I&byt}ezv>1wtnA? 
ziBOG+=1h`PLk}Z|=|R`P7%~MeHZ4G+xk{D43i>Bbjf}y%8nIDg`>of+MkCLDr`_!I z-lv^b_jB(&e%23k9Lf6;&w5$fY43IR4)?molTi6CE?|Ux#wPXLJz8bu{ak;eR^OSzc zp}z_~*vonHUbW-dPPm8TT=ESSe9k$@_XOWr&m%KTSwS~olv8_-p7u)+PlC@!8AW^n z{}fu=8sn$Y_7i_r`1TN6!8dmt@tA)GZA^R-7T;Lg*jcnZVsh#Gt%!NIS#*L}FSq_2 z;<0`nUEXh&iJWtu9*Z%a=Ul;dCb7%J8inm$d@Fsqd{^H@%)98jFEQgc3R@oUqP~NK z@b#s1KU_J)QxOzkK!)GuQbyk$?SiSk3?d literal 0 HcmV?d00001 diff --git a/shaders/spv/mf-softmax.spv b/shaders/spv/mf-softmax.spv new file mode 100644 index 0000000000000000000000000000000000000000..55e6998a29899a1c6267b5bc5f6c9737a8baa7b0 GIT binary patch literal 6684 zcmZvg37C~t8OOg_?l=m9Y)TO{NF`4iS+SKVBGy1*QrKkcbYU(qF`L{wJR?OR+JLsx zzDs*%MrCQ5WtP=ro6&9ywozG8St(d%zu$Mi7f;{Q(f4`(@Bh5#f8Miv=K^i5^Nz`~ zmTW@SnN4fS@-`)Fg~{MLa^Jmr?dk>fM&E*ki%vITQr4E|nKPN33cej1to9GtF#$UP zUT%s6iRNl;{MU-Vg`h%4I<7)rCw&`Nol9S5uCsBVw|iiCU3Fm9(AMGg)kgpD5WSAA zFh`cH8Le+xF+5an)H7!6$T{e2UA3`(Q*XV0TTQ;1>-_cFMZE*Hp`*FhvY}ROjBcr+ z@Suh(9I4jpy^V`UkRZQ3*S6li{y`MmvZ-KOh2DqZ8^ew2K<~(Kz27qH@RjV*@XcEK z`^NIz_Uy5E>yU8|784Iv$9lI`2Z-CUrNkTZ+_r25KJqRAuNoQ|ZQxz2v8`XOx3Ai$ zT1#8j)1<+btQUM~qX!LJ=1Tr-@XX(f-L)Npt6HSUw3S?i7XX05e>+F)(S zyWrYyga6-JE7^XyEj4dyUu{c1zCF7GpVohXn0bz#&q3mJJ$IR3tmhCo&h2jSHq15L zk3WC3ksp?OZfqOQx#nTH=MOaLd`u@7w68WaJP6l~hToCq-RdUS`|f>RNw(NFuj!6* zNaURNH1Zr{p5In*-fy{y3FjC&wVi)0z5|nYe(#&(R>H*Ho0-Rzqp{;{=6etAS4%t> zvqgMzGtA=`fz!FXqs_URci7YiwcGQ%s-A08uZmCgK2DtKd8a~8Y)kdLPwJVYo^h)8 zZ{k$X^9;Q>+f>hURj&n8&p6eaN{46r5NsN&a2<0|R8u^IG0xY8*}8~%sB@D)iTGH& z*fI7WhetN&&nRN&kUN#wT+iEP{_%MGnWMJx)Pl3tY2Yp_<}#j+ zMQt;{QSWqC=)JHe^JZhVh`Ye*oXF^fcw=i=gm(|*%+)W!jEG*86MD z9PkBrV|AX5H%HF37<)hE?7y(Uj#bCI89LreWAk2&cZ{4m##KztdfdlNSW6MBu?w@Q z>m3bU@1(K1pTMgjr>?R0G1a{;iPboaU482*OVstQhOT$hSl$2N)sR!y*n1hej%l&j zY{#@-&Ad3ulKT&K2}w@E$U44dbF7{)&%IH%|5JEV`;_brY6 z%}ir|FC+F2EKXwirKNa9DfYK4#_PvQ@%AKEe`hHkPh!XKF2#G2*ztRl*nEGV()s;e zO5^>d*x#eH|3DI}f3Ot$dldQhA1cNEE~Wkc&KNJ}RD1&_VZIZ-4QFyj-rLET{8^l! 
z=VfgE4E%|h-~Mzg@=q%GtH_xN_Fa=x^IkM&;cfS!Z{I^P|6JwW5Qp|0u-d)}q5W{M zeB{gpo706w&dFf;%b4H#=V6YMcMaY#xyXM6*mXtxNU%BA6!kv}Y>vG3&&S)W-@a2Y zIqMhS&qpDa3(Zr(&JlH=1}@kA7&!UyqV5G?$H`l_--29do(@hmA4go)d_0`|<%Q-G zz>bqw)9*=6?6>3_D)u|_ypFJrC*oa?{vOtR2L8#I&Hgj-_Uogzv%qS^dd~)1+cGS2 zmV?j6_7(el4%l(>?yd2YFy~WW9AnP~yC0$d6mWU%Plc0@d(aJboP3O12`<;Q3Qj)8 zJq_$w`I!IdVD~O`R)fnr=fTN`&NIM{liyLC$N6CO?F;QSVD~O^o(VQj-m~#dlQZ9U z`dOG~7vrA;HfK4v|3dt8u^!Ct$i5i&Jg|Kec>&aUKGuTG$JDex?uC7xwdd|x=wrSY zfYW+kNG#Wd#a)t1?~-R4`=PmP^LEFK5%D)`;_iI{iia=iPa_icL>{?(Yx{@38`*GDa{1*;Kz_BycZOW(HF z6U)cj_6D%y{$7j{~ciWKXl#+F6+DtPCj(r4R)M-yj$-9t8ZUuzZdM@Mb4FA^W;6JJ@_lJ$bTPL zF5ZIogL|;qn0@AX?jOMHi}-_J&;L^1S^GbPnIrE!dkZe|KTIqaZ{bJ4E!cd_KJQNC zeiZC|x}?y)3M}WCxNCCWAMcXqt9QKL@yD=q53Vk_F1+LIlS}WX_aw%D65LtDuI*Ep z^~4^18oUp4kIYq5&O9|YQXs40fD+jQbMUam%n6cMDiP#(f!lGbZm| z+=!Qp{I7sxKRusYF>~Ub{VLcTdDs1Qyv=pn_cct;b&KPjm5VvP0d|h4`Ry@A^-`&S#&s?8nPREk6az#XLU)TT436&xz$@eZK%ZPF~GB@N%JfC)hQI z<}bnJ_5BJ?K5G9p_)bhdbbbR?JKn0_g3Xh6-UE2K$p0N!F22LxgR9tV%s%rL;{Slz zr^dk|4vjw&ry74Ej&Iq%(D*ahJ~i&b8;8bUz|He*{#O2#xCgTi`$FSyVEfcKgg1`9 zzk}r>{s-7v<30T+*c|zIGw%VnVDmBicVTMZjkhnf?gd-(^}L@m@kcPZ!?|zE{)M;s zec5*(b_~LPUt)DS@J`lTXY=o(t;FTK(*`F$QtW*@SZ;Umc2~guJBNInTL<_7Vgb*yq2UsN>(a?Ta-`1iOaNIiA@4ig=R!SkyZi oY>s@?djdG>RnK|k%uVaHZsQy9)~`>04>lL^tfF7u zO#Hs?KI6nQsZ`Zp+o`HEUFx0d3!x_rg#NG-@-rHGK>}Dcs@OS6!kFQ(JcG?b9?8`9Vkhao#xr)yx;4ZW(DfaJJo;aPYZIR!i*Gl9rm?vL@f<62 z$2_mNxz)!C%)5>wCy@E9r9J_9tThRb_Z_3d>jgdow=e5VafjJLoc%@;=OSi( zb$2GF{ca(fC)R(6+*8<>;MR|I+^4d2+@-pCR^a9ovyQsE6Em+dYY=BE&U90aLQ&t! z)v(9}t+@^9BQCA-#?9wD_PD=aAbwY@={)={$km-u@}E9m=sQHY$05jFyWb6B+}$uF zzDWi1sYlxgvY7Gi#{Yr)bZ61Ww<-P|ANO?wGEV#_Ij7*B-_+CQe_704l>aakWwDs^ z4su?%n(gc!x_h*)HvOFOeMp;kPYONS9-xaweWB35aZh6xA$`Q{&)q+Q?9Vq5gvhPaU W4q1KA8<20$*;vEcV){DUHRvy1Q(o@? 
literal 0 HcmV?d00001 diff --git a/shaders/spv/moe-layer-backward-vec4.spv b/shaders/spv/moe-layer-backward-vec4.spv new file mode 100644 index 0000000000000000000000000000000000000000..bae06ef6715468505cc9efca54e1a9685ea55262 GIT binary patch literal 10392 zcmZ9R1(=pq8is!y3B?XD5l}I}!tOypKv6LfYXoKlbeO?mD1{YN3~XJyyA`{;yK_so zyVu%vcklb3|Iu^Jzt?-6=Xu`uocEmPJKqPgbxW+*K@TrG|isY)aG>5C9ArQ&BDEK&pd03 z@3|h%8vH%)Cf~ASzZ-%5=~X@B_=yv%&p|%4gR@dcUIWOpQU@R2!C6V(wX!>7%Z#Z9 zOlfX%dsS^u#->>f)7#soG&WVWy2=>F_9|YoGLCU#Yjf2ldrmON0wZeqX2ub1;~V#x z(a!n0R^tra*EFkR{mVOOwb$J|n8xb!t*g8Yo&xdgZ-L$CsFrEfeYsaPf0G(I?yh@%r>2eI ztlni~4MOg{makdmb>BVnk*ww2@waZrP>JSpYCiL5-V-(NdbGNlwoIWp$GQ~?&AXzu zN}+jQ)K(+h=P`_G&Ktp*d?v0xl3bpTwe}8ozCY`IURg`r&jdf$miMo+4t@!^?B%mb zuRN#p3NL&4?9xkYFFmuDyzKQRiL#f^FTL_CS7(mr3NL&4Y|Xp2xF6VO5nlHCguKt% z`jx4&t{>g(Ctj5`+)H1=Utf54Mnk+H&?ujZGfgjyNt zITGDJ1Hz90_aoL~tp-NttI3aKRMTgBk7slrM-qAEU6MO*Dx>#A z&3SU)9X0cQ8^z9*&tY^;dG>icxJQ@o`CNGBUkKI{y>K7ChuO!aV7cpF#^@Y1_aXPa zQcIsVz}`{0&tx$@b7$`X`wZ=q{DWxjHT-e7-=%Q>Z0`nlU%s>1m+z$9eZ9@-9@N~I z+;>uK5VP&+JLvuOS@>RTLl^V*_U9hILb11N=!fryW4M18D7p8bN%!Oh=|hpBh;&qT)|xaRj| zIie?_zdPfh&uw`^{bXw7{vA%u=a6|TgZ=iD^ZYYOeSU4;s$l1;@50ma*{Ed=zx(#~ zo0ZP2ewz{Y!9|Z{ifPFJAhrw+?VWaCq{qYa7WDTNT~V#iNCFM zsbV$HQqQ3`c?0MiTLC53suT!9KoUYVl|^xU=tjp{b|O z-eCLWj_m`sM%^`iN7cl>pWd^tIp@BNp4agUeAu7y0K%X70~yUbrq@AW*9adAcF&7z zJ|7IWMqO{^`w_0~TH?$-1ngSb_n}}ly-3Yr;QmBBJREGDy5||ksFwPPVDEZ(BiQ-j zO~a<=JU z>(qzvrD|uKL9`Hq3G>eN&!3ruIX#$M<32O@>)zaBa^A@@XHC|b4R*f$<}e;dOd#}8 z?_5L8y3fIWACD(KV;^%F4Y_sm7<&`;`T?W+aP)4?XIwz|Gk+qZdB^y164*8JEKUZi zc^~v&egZL{aIQFWPXU+D{ZzPXH#Jv-)ixt?=4-&}>D3K)_Ya z^vt^+ZclURcLTT=k(wL9)~V;6xCwkY;X7g8xw|mlOqe?x59GJhG|!!yTfz0jf?CaO zU^Q!fUiFM>@$q)>Lc)9R8v1w#p`Jb82`+uS3$7RHsks|m`nVXbR{D4kTs^(?@m^wm zBL3e8woW}B-4C{>x%7JgT>AJR+&cC6_z*Zgns=_gJxrL3kB`(e&z+h_!SzIZdglDA&k?1M&%>=# zkB={a?P)IkUIdpuz67^UJwCn+j*sS@tBAJ4+&cC6_$fF(ns=@~{)jLaAAelaJa=k-0GwNu z>ErL=)~UzGKY-(-dFSfm9|?2u@lQ4FGse{X89bDTkADHHS*wrc)Z*h`!TP8#uAz^A 
zBh<6!zk^F3{{gpNJvILXmp=Xru2%Z^1zbJ7^zlog^zkdWb?Wi)Yp^}drQbK;(#L

`(?P)IkdV)(Imw{WS9v_zl z$4B$d)yL()=Hlb>U^UO3niasMk1N8}tc{Oq@o^=vKIWUHk1M08XV0sEOCMK->!W&V zdVxzHd&AXAA6J8`rwwM0 z$G%`S&z+in;L^wba5Zb=qgs4i7p#x@X6fU4XzJPX`ry*X0dVh>dTKTRmp*O?S1Wzo z2(F%9`nWN?^l>2EI`#NC2y9Pt>9+~E^l?+Tb?WhPGjMz~ACEQ%n~RTIfYm&AY6gRc zvA+eizdvpXSF=_h&8fx5t-$)2ZZ#cdT>7{@T&?tR2e^8A z>En*@(#M_P)~UzGox%1rmwvl|OCNVNN2tfgdT@NK=4o;yAc1(!Y! zgR5B^AJyXHaIik+o28Gtqp4@lBfzDPBjMJor)Ctm^l=ZkTIu7SaP{=k$Ia{##X@j$qmwee9cJ{|ElswHEZLeT6~-Y*2jFa^l>tpdiFd8 zT>5x4Trbp9a}2oju^Fyb`ZyJ?o?iOc0xx}RgZzFvE`6K_S1Wz|0bD)3^zj6E>EnF3b?WhP0ob1A((goY>ElUo>(t}p$>8{CJ{~Ou zn~RU9fYm&AYEA{0KAr|wvo=1e#mCdZ`j~H)J}yF2&z{c!mp+~e*GKi#oCPj@JR7c7 z`gjgpJ-zhtTzKi@d2s909MrE`7WaZk>93ya^m1&Bvpg!RF%QEnqdzotj(0rH{A4)vS$= zYW{oV0?y&TPda9Qx09O(&l&E3s|`eRzB#q>-;U>_b^do|^9P{2zGK$EleuHyS^qA$ zTGltG=KOtY^Y>@;x!;3+FR>APF{5kx-0!RD@!)=VJ$ih40Ip_T)_Dl*Iu9~torl5d zS?3Y3e;3a>kHXcgTTuIMJ_dGub6NWduxmfgn6;k-t7q+}z{85QpN6Yhm$jb(yS}-s z{T$e}pJmM2&x6&o_6uO|K=%D2T+Oj+x@AoU<9*p6y!pr;q8r&N7 z)VvNZ@AeyT>(uS>GNW4R-vz6MzgOdV58j8ju!q5fIqUrI(GS4p!aoGN|D5e3ur=yg V>tpacglpN?{is>%|A)98@qY^QojU*k literal 0 HcmV?d00001 diff --git a/shaders/spv/moe-layer-backward.spv b/shaders/spv/moe-layer-backward.spv new file mode 100644 index 0000000000000000000000000000000000000000..0b917e9a1a48ac331fbe3c86fbae6a84946d0927 GIT binary patch literal 10264 zcmZvh1(?-U8pbaUWs71ff`KA-7s6P8VxnS?$c%svGYErV0*Zm%VxePWcXw@fx7*!o z?z+3b-+%5aXYaH3x#v0W`+nbdzH{DxTp62|*{ENw)~mLBt$*#EUbX5optcN34XtlA zj~+LE+*UKYTDRI^$8Ahlsn%5GnX_`OU#$;wZ^pEVQ#%YS&sdeQ3iav(DlE5Cjp`;k z)M{@)`jAo9_951v*o@BE#QIiiO>S?Q&^hDO(K9+{O&`0DHT|pg(Z_X8n%F+J<8-F3 zshu50d)8co`P7!qSzQ%xstsc9I>T5~ZHRb|xKC|U@FBBiPT9M&V`kUP8XbEydgv+Ge!O*0kh%*67<7x*kT~ zuF(4y`u1S=KaAOr-i=_MIB8OKUF0KsxJy^@Q9axxE8g70T?*c(wik29tZ9c&ZEth? 
zC{497%&o0G@q@rziY2DMJ(|m@c`rrtK2-DB ziPlutRwy)U%v+_6yLcWv>?ql)b!uJ-wD*%cX1guCAr) z>9^~7ZOirEhI@V2W~@e*c|%yyHN~q_W52-+KZD`@m>ueC!k4XMYm6_)%(9ifd!5@u zZ8+RIud&?x73)|%aj&-{_1<@7{hDw;1KH)q%xkdYq0ECBgXp{|v+HFTD~EkILv!!u z+DAUDu6gdw!Gjs;D<97AWBnG)IlKEdz5~O2<9jgclYNal@5%6!bM6aXx!{w)0~qU( z-^y%_n%ozfn!dRH$;{s0>h51Yxvs5`A3g)t4+49S%dPXp>9tj}PVPOeW_(T^TPr_@ z**@~@doI}g@Oj|=y_WKBxc7tmxR*KmxDPCM-TRqcL(P52m(;badp&E~Tdu#iz;mD7 z3fBMZ%XoJ~{(W%Y7pZ?3-WTq^e3r5=pCh^ZdWqS6sktw?&yCs^RNK?%CikfK@=kOy z?q2=4#}85L?Hb+({Ta^T{^pn5-~5sfDY*BNznzK8{p~FIs0Q9#aP#(R;A0xNznxjn z+275Q``cM^e>cP3kN+_h?mWMNFD$tE{uU-*eoX^k)WH2+Ec5-H3wQRn&bb-aeW$F% z@V(>p8NoI1y8EoDFCa(mv!>>?OxFsZ-TG`v*ReP6p>%rC2dwsCA>h|_NQcLd*z>WJKjHYhB_nAET8-g2qZiJ?8 z{t#w)&anyD^IGpaa7)H8hVRDYY*p7EA!lRo)(q>^@8?>mZNpIifF9d1`}vrd`nC); zpCNJ9-5%_D(|bp-y?yT{XD6_FdhZOjPW>a+%ev0*amToqT^R1g=hb|By06`s{kSi4 zc4er!FYza}UKLi;E1yx{@nWCP;86u`uH&3xZ+PPw_CZt6HQx`cwlCwt`aL)rtnU3| zKc8>4tUDImxbFUF>d86a|KyBAQ_p>RAlN!}`}p2ai{A%>%lqLFcw^r~(bTi%VPNaj zUGpGjHSu_6?|-k=;mmsN{65#NRt zr+}@=^*a^Z=yf}qdd@oyY@NFI^mJxFp50g{<1B_}7pK=Ou-7I0bg;d1mf2u))SG#m z&S5@-(Zv|fFmCNo<}(?_Jab2#d;Q#teeIu^&rBIJC+nOIw%&f{FrUkq%+NQzat$@} z-lg|kX2(2+vH8rox6T9G+u3_-0rLe6KlZ8&xCrbT@%Cb{n)ic#%Fk!GmbK#4 zUIH%fjZ5LK6+bTno1^Yt7c#4*{t9q;Z(IqtK3-e})(i8D#j~ry#$5AqW_i}U7F@1* z9lYGr^>DQn80m8Z*gn=8PoEpX#_V$qvpjun0+)SmhL?SAfvb6~(&tvNeXKQ}KDU95 z*=G^6Jbi8lmwoPlmwoPpt93Kd=Ps~)tTmoKcY}@DXEC!pv3tOJnZ6Hz$1w7qco6J; zrkUFZuil2oE4?t^oOoqCweNtvf3vU8fJ?7F3)d_44o{`#4F>e{TA3~F#Gy8xb*5faJ^Db&UeA3SKouHbu;4C_rZ-`{QylpUi}bkt$Mur z5!jx_((lLM(yO1q%~Ow8?}Fo%@tp6cU}NDwtMF26tm`YSlTcy9B}(HHZKr}l5)^4;`zc|L)z3Ky3n@e82>I-l5svnwqyz=iz z)~d&=Wx@6|mVV2DORtuPo2MSHRshE<<2m1oU}NDcflIGehV!j{^Yp@ebK;fp)UEqTMZrTFQ zxAMhvn{ST3m}fk-TY}4X(^l}(m#yLYqMp5P1J;+EZ(F!p`EJ?{-ssi#XzKB52e7s3 zS!+kIJ&mQ`PThE90r%16;nF zM!-w2_Jr$|diFjNTzWMMu4Z2Dt7dqkS9_tU$E&@;)~d&=eZclKmVWz!ORx5Wo2MSH zMuX#(@tkiA*jV^j@F?~ful9%Yt-R6;^UaA@##4I$xO_K_gO^?%2-gGk?EN5c>D9q- zwQlm_)gkakuMS01k5`9*tyPa#315q^lAp&JoR`r 
z6CAIM=X_mYW8t&F&Fnq*(CKi#l~;OUzB%#Acxq>Z%XiZpc4hD`ppBEUd@M_ryj4)1IH`lIp6tUW8n+H&4pJN z!1-2Q>4o{`#4F>ey%1c!n=XQvUR@0DW-ayXeIdB?>JmAlo4h=Om%N@n)T<^>lZLL{!P^QgnHaL>o1~qJUr{)1XuIlIIeF@ z&HBT^?(axuufeV8w=uj1w=n1Yi|cy$?N!Xbx%R34UT_E89QCYoC)jnmnNxQcSUvo1 z@JMp9&OLB*)aTdNz8CEJ#RI~cXzIE5 zp9Nc|o^_rBm-qSeaP!pd`y{hk@?Qa~g}++oc_vv3**1#I%pCA literal 0 HcmV?d00001 diff --git a/shaders/spv/moe-layer-fused-vec4.spv b/shaders/spv/moe-layer-fused-vec4.spv new file mode 100644 index 0000000000000000000000000000000000000000..4014e715faedb8ccbc34d08a2f55a1e9cc64e475 GIT binary patch literal 10864 zcmZ9R2bi2y8HFd=Bn0WwED#bBL?l3{(h@=dLo^x;7>JZ%TMF6igxyUxG$DW}g2aXe z34#;_6r%{D6af_}Dn+GP00qI`X^Q&3Gymba&OFce+;h(N-TU2p{(tU-rDNp8b!xQ{ zwe@OaYtN0SH9wowMxxZv#x%#7vk#xW^Wad=&U@@R#exlM9nCsxHma>t8_hV1SlYR` z-^6;vYl)4Suf9nW%k5NSb_X45wL2lBYU|c!&aNMF)DiXWM*q-2qp!Z8uQ6Qj92y!} z+_ikDcd*9#qyBIG(aayq{DHWMa zTWih{@7BWe#Jh_})#~8R?(XI@mv^-|vs(G?7H4LQ_p~^(<)dr8jQz`(9-!{1) z*KzIDP1eF|z)LFqO2(Bd)JNB@QLiuTYjkz?0a3=(zRGw=^9dcmy4CnFG7x%1e@?XH`EFWra*5@SmV@*%zP^Z{EKT+v3S<~;5 z@diTf_e%c8ChvIleL9qx+oAcH*cxilTuyBi5p86p`8|r(QEBTl+Uo$uF@*DGvajik z)*Ol=_VYTz`O_tQ(%bqOaPzJuo&oRC7Ow%j8?B#P#FR{J+ zxVF6PwH99X@|mVro^N`Em%V)E=_R(89@|S^_WC)AvX{?1z4H8D&m27rFMIhen0IaQ z1n|g$m%W}P@4I3Brc_xsfo}E_Z_XO-Wjx_$JelLd$-?ljP6U_dUMAUekPMq?<9_I zi>$p6eIp|Lbnqs`R;+afqx03|s~FE9e4h3`lhJ*sdsg{bm8M62ubr<)esAT@yMR$Y z)SM^xd#Yys!U{WAei5U6Io za_9Y$(Hb@9$^HJRP3G;4~E}?z4#8Bcb|TIzAvJTW^@gok@u}*xc>$zx&Hy+sP8;|4J@vjf^)}ws#=F~iPr+TMw~hPzp7k9YZG2@L_xC;Z z?&sWstFI}z`6~-<-rw$WzQ4iYj{XKa9>`PhZrYIW{_=TE<(YWT`0lHpO^w`lU(M&5 zd7Fa0x5{~&!PU>G=4}pkuKF~dzRy}MYxtXCZ|}SG_O4b>Z-3|2%HICg+q;~%6E~4b z> zRjso(*tN{Pz~1&{^z$xzBGUU1u=;_#S>E$Ai5Y}v zGH0Fl^`V40eVJDAtowF&`?~KyQ_u6C4Oa90Os{u>)60B%y$fv4Ufvh-^g0aO-s|0H z>gjbjSZywmUPpk_%Y1sh2W-w>a~S36btJgG*HLKd>GfW)+I%9tjs~Zf`SdylY|dWu z80BK`xXuc971+BidQXL2%R68JF^R}N7FF6p#=Osq!Ri|m_VF%M%lVdoJzx2q=!2`L z&r+~`^8N1zTchrpCorms8;n~LFYyj8WAt5ge4afIFb)zU2=haX<{i^(IoLJASAgB~ z+Uh$n47Nty_g;P?;o7bx&fJs0u9bbC3|5O5E5T~1KNYMNej3<2C;q+m2SJ%v#% 
z^=E^~XPu9L#}@pf;PQEW z3~r5j=6oFNUUIfifUQ&a9lnt9lf-Ib3Sr*4$1;A3FsBFSR@~?5e%+gUOwM<=%vqCl zJ_B~X{VrnsEO8d0k9y}CYSujk_TIRdc#?h0Wi;g0tzpc2cL~^iIQrgQ%6J)(_wMsx z^N#W3ah|K#I*gEyhyBTaxbLn>r*t4YORO`lz0o zN5Q3!kHOVSA0LOSrf=+y zlSF*{Ri%0E)chJ;`uH?l&D!{=79XDh>*IXFHT3Z}gnIV;EV%UXIk-Nmr{=fd(#PMy z)k+_K4_8kwef$Ga`uIn7|da5T%d*fm^2@AO8!sr@8d=N38U51l&6H_&5?AAI&>gA3MP2;^Qc=n&(cTrbp9vlY1Xacj6*>Ek%KdV1;OHt^EN@o?+Z zEj`A>(t}p+raVB zd_0;7HWwcc1*>`P)XV~RbH-KG{rGmcnzj09PAxvZ1KfT;&PG$up5F;BeS8;OAJtPc z2VDAi7+kIN@!fFs@_w8PFMT{5Zk>93JOXS_bLsaUaOvYbxOM9B@knrdG#`(S0-KAE z^TBGKJ2meGmp&d1SF<)gs>R1+!0q?rv1sbq^KszP$2wde)l+jkxb(3Tu2%Zk1y?Wc z$8LD(V-MUq_4wEewx_xDTL3P7TnM*LJw7f1$4B$=XffDad^`cH=DAa|1YG*q2UoK; zKB~pXrQr7au^&x6dv1VBAD6-PLOnGnf=eF<;A*9hgK+ioejI|AJ}!q_ryd_yfbD56 z{f5D%k1OHUsmI5Y!12+1JUSU{E3L<7dI~(R@6* z7;G*+t^uog?$lfYE`9tQT+Q0}s1_eD1-IXim!YX=&z}dEK3)#jNA=WP0WN*K60TPI zcokf|ydS>+FMa$X+&cC6cs1Cb=F;yPaOvZ3L% z3Aau?K7I>qPji`fGr08e7Pxil@$pu0d^GP||9j>(u(|m7ZLpf>PR;G$(#JdCYSzX_ zHUHLo8Eg7CQpfD?PI7DDIm370YCE7g-<+EBF9AFMQpWbb^_m~7@{U>myUcCCv;Ozs zYFXc$n)8=c^9QQ={x`z+$xW-cp4<)h?+Nkb9=Muy_H@p@;913-`{4c?J9B;jSF;-%8HKvnhjOmGK zCNVL^6w^$@G>z%anwaGO?|Tc+_rBcko$su*_C9;Bd+)jT@rIUlb{JZ#4XUkQ8&-R8 zP^~&PsjY)jLmN`f(`L?^xmRC*$6ovHH{OH|Yb{ltIUCi6)>@edGnTY3>Nc=GV+3Pk z)~jz)VY!`ZtZt!0t@Z?@m5g$2E3sk3dV7`=8&cI;(A9QYPw(Psy*&dsg%R3h>SlHj!)^pA| zeVzUF-s8|c^f#G#IakwjJQRFV&p>}?@2vT~?fvr?!tJ5%(c8CHuGUVe%N3puUcOJg z=5gW)b$p6=qIhs^4!C{({OaDx=hiu^>hZQZXJwtwt8-S%TWjsi-2+QbUewj;X;#_- zX6{kj+3kIub*-h=1?cE7-csuYbK2m+wUx|^8=1x~u6TPDTV)=FXE$1pN_)%|FxJtJ;h z(`UtvYr00vRq>v^0G?Xi`P1rm8m?b;kIv=wx-Zsgb*>k%iOyzIb#!)AJ+Ff=>R3_5 zZUY}b&|jUn_g?P7wrJ~U?{63D^Nxl-jb1)yj;$E-5e>d|#aq_y(dn#ouY2rJpK8%8 zr{*&l&3&yltkB%kYQq`!I)r%$!*$c}=CYT2BE7Od(kr~|<(^3|vAsO3EiZdL4=;OpKhrDsHod~jUf%EY65DGK!(Q^T z*R2G~Uf%P1uhzBq|IONit9x%f?~nD|3uV0r;O>Vl8Jm)2-dHMHQ@lBA>^FwtF$O-A z*`YoPzHS4%#`t>7l&$nh4Q>y$-QlitPsq*Rpn=sB-x%!3diQC$eiYneGhFV-JesOI zGh3q^*53t9kH)PhAJ@=4zg@v&7^yGcjp5Tr>bVDly*6sw z5_8XsM}ytta%&#O?DbXClY9z8Ej5^($Z1&^%9{x16WXliYQ;F?&vGu9N%xs2RVX 
z*}ATc&j-Q72Gy_M^NpC-?**{jb-!b_hMINdFE+H)8K8!>_2+XFe?BvE>)g(49X0F7 zeMZ#wU?+LseMY;M+rxYGDoQJ}HN5}62kZNw&i%JV$wxQwu?2Sz`CXrU z`Sd34_jVcgySwCbn)uu%-qytD72G=QO?*KU_j^12oO_%2$|ins!9AZF3-0=dn)vD_ z{#X-#yoonz-N5WqeH&_j@;SXTNuyr*M~jUv0?no#uX+%sn5%?6akQ zB{_1REj9N|)@=$N!6?^l23Nndv2Js)Yt;{7ce$6;Qp4{Gd;1PdZ{OeQ>FxKsTG@MR zxV_7D+rZV+dt0z;)$Q%RRZH*fz|H>0psAbhzLqC{dvJ5l9njRxAImJydF%xCd|mH* zcMrxmhVSF#?Ag#CA!kSMUJTc%ujXE;?afesg&zAbd%R9eeLO?WXI7lL6TqHtdhZLi z_e4f=_5-V@_x@nlslP_O)OCJ|9ix{67<%#jV7@)|br7?MzRWq0p{6hK8?`|dR*R=2 zz+;b!UK^>mH4!p1pAlSk1kgUNb9P z?`1r_js+XD*K}ridL0jL?sWp1dU~A%Ry&cAUb8A(?`1r_P6ivZ*Ky49^g0#X-0L(n z_4GO&tTvmGUT0Le-phD;oe4H(uTz-i;yKJ`HL&m6VBfRB^BOqb&W7u)++XwI>habI zR_kD_YP=uL0jqD!uy;GNnjU@cEMyF0q|YL7bDwk3)bqYr40fHmH5V|erN$C)bB%5^ z_2l&YpPZ#=>gjVH*mdgm>0(wB_cD)S_#E{y`y4pG%y|tkFJpKZU(RgYIp@6sY>n{q z!Csq38}Eq=z~-p?%*p#1*0z>7YcB*_E50uRtL2)n1gj?%GMr_b`lk&Nnu=_fIeOwSQthCuPi>)VUAr zdi&ka`~YJW!|Ud?vWA*@&(YiWhT}m-&f~3M_gn7i!{AX2_2j$_Jd%;?{C05jb$$n$ zdg{Cr>^k*aqj!NHW*AHDcZ1y{$$1ahJazp)#H^P5_kv5m?}HC#s3+(B;L`60;LUzN zh^8LDkAPjL9xopPTi;k}KMF4Wei&|^x_(zPt0n&<;L`6$;obxFGv`vrKj-!G!6 z$L|wh*Qv+vm%!FHmfBwimwvwjH&0!^k29-@pJaBg`h0wq*=NVucjHscUuSq2{|2*h z=k)p}*cy2^d<(4R`^@)&{A&zrxmKLD-v*B^_;gqMAO1y^%_rO&Uy_IZwBJbiuxHfEnS%<{y33-(&3 z@9)7675oq2vhN?^=BV3uD)XNhFEVyx7*Fj#gN=p%1w4l_oPSMtKmW=Y#F)U)hxz8{ z(>&u@`!}%nB0m2PE`9z3Ui$ne+#L1nftSGQ*#mlinGx^C)9;l=Ec{h?>HRf$yz9+; zbK>22*1it*{>J;iz@_(p!%OdPz|B#Q_y2&^v#<32Uq-wePd|+r3m*h8y{`k0cfCo> ziFe~!+X8ma#QR`)>Ae+RdLIHeM?Kz$g4MHE_3qze;@x=qtqV35z8<*rzCN5!wP*BZ zzB%!3JZp!8%e}n;e06ay8^TNP8^O&{&)(h`9G}KhdlRs+@Dbp-)Xv`C6wasesSoqb z(WiOFvvxCZxwkimmp(_rOP{0Q=BVcyZvj>>_x6_XcsHJYTY-&*Zw)TJZv*F3dDoly z=ES@4tlbt|?(Na=())Js()$>=IqLB~7OYV02>S65nOuT3C^eTt~c|| ziFe~!yEC}l+q=L^@8jU5_g&%UsK@(mVD)lu?+%Z5+}j7k8CyYcip3T!NV8o2a69nPont~c|| ziFe~!I|ICiT)F$Z3!e?1NA2wG)8Tw7pZYN0 z9DSN+JZsMYmwWq6cx??ZwKos-p_{fsl4mW zd~@R6c-GDbmwUSdUV86@m)_5To1-4@3&85--d+fgcjM`|2y86;TyW`qF`Q53U2o=_ z6Ys{ewhLVD?IrNidpErF-UByBJ>Hjs)yuto9z5QSr(ZAFSa=_}^xhBWQ+d~$`R2sC 
zaeqhoo66rp&ej^jyo^}sbvayZfZ=*$YUS^=r_h@JCSiOTy7iq?e+6rM;W_W~;c64n ztZz)s^?hL14>0FkR}!0C@C)H{;8*cCF#ipI5!@X0thpFGy;ySz-0!)pxfE`W`elte zmw~NiEOjmiA6L}50`5B{b*_Y)qnw^PsGV#dt#KHUR07Je_-bISSN2R27NweAPs!LXKn^`mC4|Cewex;5usEwfg+ki$}+L7oTwAu?9@8TI)Dtrc@KEc61xEv1edNM<;S9atQhI zQ){ekr;6-WI#kuo80|!qx$XES;#)U7h_9or)j!z1c6ek%*U0eZO{>l_XJTC+e)aIW zp21Z^m!rl8hKKa(Z>?scH#G2U^w<^S{9N%?aeMU$@Ohg@*Pl5&G&(j~(QTZAu15PV z>mKYI+Q&I}Z}-OG-o8P&jyl&h40ZQy-PAW?`aa&>Vb-y}k?t*SSmIr0_~Q!R{R_XO z;5!Qbc(ChQfc6oeitbsr4%}Xy1|HhH@q&TDK1ZqjE6`&#Zm-Tkj|^|=zNBZg4_N!x z$aVLF+p3GeV&)qy$&+lIP^uu+tHnB5o6!3_*p+~y6*7df7t#@JV zYpY&`UNy9-PV`o2`*0L>?zOcwj(34a8vJeO-d_2(YBzZQ{O#2qY-e2CH_*SH6Y%`x zeCj#-w$!!n9H*)-YIK5?ha%eEQ|-fRyLI2Wt|Zg@;WKYsRpGn~a-L^6?}D6nH(YCj zn^bV-80Wi?IL}{hdck@2a)%+#c_~`X{4Vx-1KRp4VZ=VhS~Ir!KE``asU`M3@Ugb` zQMEO0zY<*b@~qM;=aycv%U+&kdWr4jp4m&g>~#lr*=rB>c(46uxxaQ>GwCSv{y9Q{-%}wHHB@x{=znHTVdnY+F8fL!LHZ*Md(K%3y_6KViq^}8;Cg?{Ak2H`Rh3g zxnmLe`{=O*?eic$`Qs2d@31&^j|aQn^gaRH-1|g0`Sd;sY@YlB)Jt9UJ=`E?c?{w# zzAwhx)43jt_Hiy_PDbRMOT4#gsj-}U#&>eVlSM0di8<6Oq9N46l&CC*$!;BpU!v76`G1Sg+!y9{ieeCmvV zU0=>-6l|RQa=xBp=*`GR#QUw^-1+Fs5q+Lvf5UdYuEBYoJ3il+(q~NSYz3Qdzbnws zLM}qQ2cEq(^a!w9eXaeG4kebL(3(92iW-7&jZgb?B|2c$sKzE*!^`6 zl5-_kKIeWF*gSd9@`Y$0&r;vjh@5*PPMsHl%kSNbvF%lUYhQvbpI$Eon}6p44j{(c)45)O_Hiy_UXHv4aV~M@dJVYTgV$m=&-FSu`P{MBgUypqoi~78U+&l& z!N$qwj=c$dE#m&_H`hD%W<;N7cum80y{^G|ojbnVvG~-f^H#9=xnpkwZ%1|@p1n2X zjJuuQJJHH@h`wFuS%|&gj<&bj@9{g)??Qa^zZ%nq zSDf7Ug3XQnKCpK;Yk5D|_r*AUS@Q?L`jYoSa9QUD?72wld~JYuIy<#`eJ_*T(0?B z*e!_X>>S1$lQrwlT;B$lbA1O}?v^6vyWn!o-@|sz#_G!)-v{eU?H_>4+CRjWD{KD< zyIk{+v0byV`cnHRV12QF3NF|DGwiI{IgB?ZYu4{O@B8n2tInK12bXjH0$c7@#C&~n zCP|31VR zdEdJ+^!>=4$TGw}&Maqa=6L{I&io*@-!J*Z{2N@({2y%NQ>d=UQcE6 zoR?(Tq->XLYWDA}7<*=uVKTTWx$axGV%dUrXJElG3y(Bnx2%!}8nZhwRdf$B+&|RR zu?w;Xyxg7=1RASU_%|874Wmj#IJ7Ll`CR*G*6Qm=>Z4BI z5$ozk`rA@esbnt~FvN$T+hfByK2)6ZJ=rnnrOlDC4%(G#JJuu4Ta;Vkp6pEUiq)h2 zoz-jen%a(b$r;~QQ2E@8zyqyLj%(nH5%atay}UKnsgJJAYic`2*A+Dzg_*~00vEBT 
z<+0k1(N#t4=E5Yln%x9$v<8Pd?OJQi8vWJmX1I0rp~1B!Y5fVf;r{it@%{!8m23-k z9np6WntgJ2bKlk-zDl-DJYK{-0v>9P*J4kfEaUpf;XP9JNBYR`*?MM6$Qb852ip;x z=Pc*F3eGul-mTy&B{vOityiO`AkMRv=V#0km^jOQbGy33HmAdQ&sV>>#Iq4c*asD6 zZhL>zTAqJVtD4^pLwDxxVNQ|H81reT`G&C5eBOg1UvWn;^WI#0+>xF6E1CBsdKcTf z`)DogwAR(wX{}AzQA_NIeF}R{L6fhNdC#Jj*jkejYiXyo{)O%N?}zL~lyS4kVLkEQ z%(0hV#LcS)F@iKgJz^IBy@K=c0Ra>B$a8yIwhC)$^aXwsEgTJ4enu+TK;U z*xTjcNu}-DRwDY%JA{tB-cxP!UW_(I&b->*%b0gHIOc5vYdi0nGDgmM+TJneML#zr zwmBX}vTVaCv7e7&t=m5T=!pNPZWeGjH^B4qj>@?%Wf5=8{5^=n^Zg%86>Sfmm){|^ z_QB{w5?kKimDKO|KWw$X0jcfxI<@^)hi(3e#FpQZ*!uleN4$D#V(Z`5W&6!eta-wx|Vtyy4e`KBV~K(PE>WvyPY^W?2{5ZYm_Ui554&RXIZxDaAF z*XMc8E3s#3&;Q4N7}_}Xf5w#vNU;4MfkezvVEMbt{*MMb zPu`xjUx`@H{Nk8)$W#1;E1=bpKLu1OueoCbDH(aY&zInUhl(I=Ol ziF*_Cmw;VM^s^LfKXC@H19zXn>*3@_$~7$m%UxTp;ViIuny7b0@j631JSbG;X#|!Rh`qvAgrP;N;`2jDVddA9*eT z+h^R*QLu6HaW~rF2IAe&@7%fQ4x%q|je(0b7jJAFJMNUen70nBZ*zGc*MrR^A2DwS zdmbChm`lNO#>QF6x&NMp`=gFtE(53Qx*S_B-km^ynv+>DlswSEF@?RX=%f#WRn#W{Wwtj`=1Xl?Np^m4>=`V`tdQ2TxPH2O1$ zL;q*d`qfeY4zM|5O?QG_cePTyNCB z8*Kg9uP=g)k&iw35;*omf8_o$Sl=}MQk#qZ3Ni^f6tS-U*e`wVf%{>P>geIC;IxOY zVapwW#2(9~d+b{E-GiQ2;;m@!sBvFMTSx6#+>8Du;?Vysw0?E;@olg<;w-)c?mmnA z;N;^hz6*BVeTZkF{SCx=<`+l)?}5|3|30?$V()(dHby@7{(f-mz5dwy2f+Hm{vmh* z*;IaKe*`v0ej0x(%|-th@qA_?`rW^6X!+>zCt%-5*xSKrj}KxSBOkSY3N}yl{4;R$ zY`k^kjEgxx2RkS1Uw}Q!jb;B2f#r;gbC&b$JV*Cd9eq9wPW${Nwp`pHeRApD@QkAV zufS>j$Fb!eL*nnPC%}&)`t0cuw4C@!blknC!1ky1?)?V+Tf|}g-=X!ZqyOK7%@KF+ z58&>*_eVJSxO-27ohR?z)BZJLJ@bnr|DV9=x&9g3dU5yu0yaiI_T;Z%*BIyeH*lP* z@e%iTu)g>^;2E%-^JDI_;B>BgCTFaB;{MpPI(mK%oc8z+Y`Hj7eR9r^J$LQeH=@m_ S*1n*$)!H*lTkY@Famb5l10^H? 
literal 0 HcmV?d00001 diff --git a/shaders/spv/sign-activation.spv b/shaders/spv/sign-activation.spv new file mode 100644 index 0000000000000000000000000000000000000000..a2334b0cf3e4f45dde2a3eda1af1ba5bb577967e GIT binary patch literal 1520 zcmYk6+iKKM6o&U4JGS;@PabQg^;{iKtq7u0w8Vjd5qy9UXF?66$>3xP-uXJ-_yFDr zzLX*e{@+eE%`S_z|9>6#TA4QLyK~0W%z{}o2c|fyrY^>aYg_FPJ`8%(eB9eVc!bB2 zX;?>mmgUouZc0wVIK{9aSrHGn3L!^)k%s>2(!T_nX5RRNz#k6%!BH@pOoD0tC5eJ0 zPNQHPe~w0Zd>Y}`srbd|D4C6`4lQ|9>$g;+qZ-Hkx9Y89`#DYmKgm9ZNk2W!MqwUj zsjNo!mhJKVZ2IXzEKXl`Xi<#O^@u9-!b}7EIV0}us3$~ zvaGvY@~@Nuf3UNy=nh9bbKjD%EB4`h_%7kU%ZFWDmJkR3TRHj7asH|vwkr}idgcC4 zkLC{Na`Np{*Wo)QKGZCj^*ssmrw?rS6|*mFKgwPn$hIe87vB_%5C4SW*#|ptIUno3 QBOxdCnFk#H^!QxzALrO=&;S4c literal 0 HcmV?d00001 diff --git a/shaders/spv/softmax-fast.spv b/shaders/spv/softmax-fast.spv new file mode 100644 index 0000000000000000000000000000000000000000..5b5c447b2b02ecfa894117d518727b34b70bb34d GIT binary patch literal 4764 zcmaKuX>(Ln5QcAUHe4jSY#?P9huBL z`n5U8TuOnzap<2R#pKAOtEaSX?fOzWFd*d<(w-2%AzR$PhFH_Ih$n${*uQ+d z6Swm>1AWa%V_@}Az0qtWM0TV$%*mE=bL&8=v|LMl7l6C#+eVt`4lHuN-g2|7cCO{#k6>D>aW1-rACG2rg?_Z)f5+#o#<`DeWMIbV_%1F zucWC`exSuomGY0_SC2H)MY$%qUw%yo;lw0e^v5%;oP@~Tce(q}#hN|qe)`1v&PaJ7 zIRWjb&;8VAza8gz@4#AW5c|z-sY5pRPFoRij%DbML%fjgN6Zm>6LT29oY>ijU+_6C zn(~jp`T6Wc%olyh&u3qp&$^sXo}ceWxNAHSIi8uEYc^{wpv6-eW4=zruM<8Vt!SSH zKO)19(SIb`eA<_1+#K5G!yPvX@iYD?v_9<-@1ALs>vx~>+`W;{!tY)x6^ z8F}t+LZ18P=ech~aL<5mah@+3!+me_{_ZincMSJkjd*os4EJ3Pzxj8L;l8bT|ACCJ zq6^-uqY>|>yUT#IIwp!AAdYtyYco<9&`DgdUy5NOL_QwbEaoL z`zjBgZ@0P4+liiyc!xb#5i=*Ve@2XV^i;%g+J7V8>1e+{S(m=k5N)0T@n0M!v9_4= z3~)a0e7JeN_Yt!ItbGr8?blg|1IVlC%_$If{kHtkAILn3!h`*zv}CKjh)N9IUOsts~937TksS>2r*CMth980_>RJ>%hi5 z$QrG?2QfzbB-W_yN~8_(H$ngAEZgp`UeYh6v=RRn^2GM3e z#Mb1QxB>A@M~xf7*6117guW5E9x=|Ga${or&EWj_Tj16e<8KA~P8;WVxiRiR)T91| zdVIf?5)$*@26q0Bv-4I4Ynz0)&aG%a*Qsv{qRn-R&Do1~-^G3Cek9KL0C*;nzc*F5 z_Be}ofVJI@Jd~Z!L9q53;=C2Kwy3)f)^;$fdkCy;J!0K=qW!E}-!?>>b&F$8)IpdqhReZ 
z<}q-7%;Rv^uRX>*0oES;NpQaIr{KnDk9Xr~aMZ0o>V5`n4f=OyYk3yzdi2>xd#25^ z;<<5u)lu_vU~7*1@H|+XcP7rFw)|Q2oW%GSz*CU8!~4P7Jip$Bm(YHmGkq^2+N?ty zeRvscAL8u10=9Rp--ALTHzwj=1-rJ0e+}&ZM7-P>>yCVCdAuR&v2Rn%_XfVmA35K| zrro~0j+RIKTj0na@o!_(Zu|kXT%YIb9b_`%_)he@$k?~)z04l>_I2}pJ^HL51Ik8u>b%7 literal 0 HcmV?d00001 diff --git a/tests/_tokenizer_parity_helpers.py b/tests/_tokenizer_parity_helpers.py new file mode 100644 index 0000000..c3667a3 --- /dev/null +++ b/tests/_tokenizer_parity_helpers.py @@ -0,0 +1,68 @@ +"""Shared helpers for tokenizer/SentencePiece parity tests.""" + +from __future__ import annotations + +import pytest + + +def load_hf_tokenizer(model_id: str): + """Load Hugging Face tokenizer or skip with context.""" + transformers = pytest.importorskip("transformers") + try: + return transformers.AutoTokenizer.from_pretrained(model_id) + except Exception as exc: # pragma: no cover - depends on environment/cache + pytest.skip(f"Hugging Face tokenizer unavailable for {model_id}: {exc}") + + +def load_grilly_tokenizer(model_id: str): + """Load Grilly tokenizer through likely in-progress API shapes or skip.""" + try: + from grilly import tokenizers as grilly_tokenizers + except Exception as exc: + pytest.skip(f"grilly.tokenizers not available yet: {exc}") + + try: + if hasattr(grilly_tokenizers, "Tokenizer"): + tok_cls = getattr(grilly_tokenizers, "Tokenizer") + if hasattr(tok_cls, "from_pretrained"): + return tok_cls.from_pretrained(model_id) + if hasattr(grilly_tokenizers, "AutoTokenizer"): + auto_cls = getattr(grilly_tokenizers, "AutoTokenizer") + if hasattr(auto_cls, "from_pretrained"): + return auto_cls.from_pretrained(model_id) + if hasattr(grilly_tokenizers, "from_pretrained"): + return grilly_tokenizers.from_pretrained(model_id) + except Exception as exc: + pytest.skip(f"Grilly tokenizer could not load {model_id}: {exc}") + + pytest.skip("No supported tokenizer loading API found in grilly.tokenizers") + + +def extract_input_ids(encoded): + """Normalize encoder output to 
plain input IDs.""" + if isinstance(encoded, dict): + ids = encoded.get("input_ids") + if ids is None: + raise AssertionError("Encoded dict missing 'input_ids'") + return ids + if hasattr(encoded, "input_ids"): + ids = encoded.input_ids + if hasattr(ids, "tolist"): + return ids.tolist() + return ids + return encoded + + +def encode_ids(tokenizer, text: str, add_special_tokens: bool = True): + """Encode text with a flexible tokenizer call surface.""" + if hasattr(tokenizer, "encode"): + try: + return extract_input_ids( + tokenizer.encode(text, add_special_tokens=add_special_tokens) + ) + except TypeError: + return extract_input_ids(tokenizer.encode(text)) + if hasattr(tokenizer, "__call__"): + return extract_input_ids(tokenizer(text, add_special_tokens=add_special_tokens)) + raise AssertionError("Tokenizer has neither encode() nor __call__()") + diff --git a/tests/autograd_chain/test_autograd_chain_placeholder.py b/tests/autograd_chain/test_autograd_chain_placeholder.py new file mode 100644 index 0000000..8e84a33 --- /dev/null +++ b/tests/autograd_chain/test_autograd_chain_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: autograd chain-recorder execution test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement autograd chain-recorder tests in pre-v1.0 roadmap." 
+) + + +def test_autograd_chain_scaffold(): + """Placeholder test to reserve CI target.""" + assert True + diff --git a/tests/conftest.py b/tests/conftest.py index 897f42e..3594913 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,12 +27,16 @@ def pytest_configure(config): try: import grilly - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import ( + VULKAN_AVAILABLE, + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, + ) GRILLY_AVAILABLE = True except (ImportError, AttributeError, Exception): GRILLY_AVAILABLE = False VULKAN_AVAILABLE = False + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE = False # C++ backend (grilly_core with NN framework classes) try: @@ -46,8 +50,10 @@ def pytest_configure(config): @pytest.fixture def gpu_backend(): """Fixture for GPU backend (skips if not available)""" - if not VULKAN_AVAILABLE: - pytest.skip("Vulkan not available") + if not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE: + pytest.skip( + "Vulkan Compute() not available (needs C++ grilly_core GPU + pip install vulkan)" + ) try: from grilly import Compute diff --git a/tests/converter/test_pytorch_converter_placeholder.py b/tests/converter/test_pytorch_converter_placeholder.py new file mode 100644 index 0000000..94ac60b --- /dev/null +++ b/tests/converter/test_pytorch_converter_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: PyTorch -> Grilly converter test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement converter tests in pre-v1.0 roadmap." 
+) + + +def test_pytorch_converter_scaffold(): + """Placeholder test to reserve CI target.""" + assert True + diff --git a/tests/moe_quant/test_moe_quant_placeholder.py b/tests/moe_quant/test_moe_quant_placeholder.py new file mode 100644 index 0000000..8cce829 --- /dev/null +++ b/tests/moe_quant/test_moe_quant_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: MoE quantized expert test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement MoE quantization tests in pre-v1.0 roadmap." +) + + +def test_moe_quant_scaffold(): + """Placeholder test to reserve CI target.""" + assert True + diff --git a/tests/sentence_transformers/test_sentence_transformers_parity.py b/tests/sentence_transformers/test_sentence_transformers_parity.py new file mode 100644 index 0000000..a04688b --- /dev/null +++ b/tests/sentence_transformers/test_sentence_transformers_parity.py @@ -0,0 +1,79 @@ +"""Sentence-transformers adaptive parity tests against reference implementation.""" + +from __future__ import annotations + +import numpy as np +import pytest + +sentence_transformers = pytest.importorskip("sentence_transformers") + + +def _load_reference_model(model_name: str): + try: + return sentence_transformers.SentenceTransformer(model_name, device="cpu") + except Exception as exc: # pragma: no cover - env/model cache dependent + pytest.skip(f"Reference sentence-transformers model unavailable: {exc}") + + +def _load_grilly_model(model_name: str): + try: + from grilly.utils.vulkan_sentence_transformer import VulkanSentenceTransformer + except Exception as exc: + pytest.skip(f"Grilly Vulkan sentence transformer unavailable: {exc}") + + try: + return VulkanSentenceTransformer(model_name=model_name) + except Exception as exc: # pragma: no cover - depends on Vulkan/model assets + pytest.skip(f"Could not initialize VulkanSentenceTransformer: {exc}") + + +def _as_float32(x): + arr = np.asarray(x) + if arr.dtype != np.float32: + arr = 
arr.astype(np.float32) + return arr + + +@pytest.mark.gpu +def test_sentence_transformer_encode_shape_dtype_parity(): + """Embedding shape/dtype should align with reference model outputs.""" + model_name = "all-MiniLM-L6-v2" + texts = [ + "Grilly runs on Vulkan across vendors.", + "Tokenizer and encoder should stay GPU-resident.", + ] + + ref = _load_reference_model(model_name) + gr = _load_grilly_model(model_name) + + ref_emb = _as_float32(ref.encode(texts, normalize_embeddings=True)) + gr_emb = _as_float32(gr.encode(texts, normalize_embeddings=True)) + + assert ref_emb.shape == gr_emb.shape + assert gr_emb.dtype == np.float32 + assert np.all(np.isfinite(gr_emb)) + + +@pytest.mark.gpu +def test_sentence_transformer_similarity_top1_matches_reference(): + """Top-1 nearest sentence should match reference similarity ranking.""" + model_name = "all-MiniLM-L6-v2" + texts = [ + "How do I speed up GPU dispatch batching?", + "Use one submit for many kernels with command recording.", + "Bananas are yellow and sweet.", + "Fused loss kernels reduce synchronization overhead.", + ] + + ref = _load_reference_model(model_name) + gr = _load_grilly_model(model_name) + + ref_emb = _as_float32(ref.encode(texts, normalize_embeddings=True)) + gr_emb = _as_float32(gr.encode(texts, normalize_embeddings=True)) + + # Query = first sentence, compare against candidates [1:] + ref_sims = ref_emb[1:] @ ref_emb[0] + gr_sims = gr_emb[1:] @ gr_emb[0] + + assert int(np.argmax(gr_sims)) == int(np.argmax(ref_sims)) + diff --git a/tests/sentencepiece/test_sentencepiece_parity.py b/tests/sentencepiece/test_sentencepiece_parity.py new file mode 100644 index 0000000..e179213 --- /dev/null +++ b/tests/sentencepiece/test_sentencepiece_parity.py @@ -0,0 +1,50 @@ +"""SentencePiece compatibility tests (adaptive to in-progress tokenizer API).""" + +import pytest + +sentencepiece = pytest.importorskip("sentencepiece") +from tests._tokenizer_parity_helpers import ( + encode_ids, + extract_input_ids, + 
load_grilly_tokenizer, + load_hf_tokenizer, +) + + +@pytest.mark.parametrize( + "model_id,text", + [ + ("t5-small", "Translate English to German: A tiny test sentence."), + ("google/mt5-small", "A multilingual sentencepiece parity check."), + ], +) +def test_sentencepiece_ids_match_hf_reference(model_id: str, text: str): + """SentencePiece-backed token IDs should match HF reference on supported models.""" + # Keep explicit import usage for environment signaling. + assert sentencepiece.__name__ == "sentencepiece" + + hf_tok = load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_ids = extract_input_ids(hf_tok(text)) + gr_ids = encode_ids(gr_tok, text) + + assert list(gr_ids) == list(hf_ids) + + +def test_sentencepiece_special_tokens_alignment(): + """Special token insertion should align with HF on canonical T5 tokenizer.""" + model_id = "t5-small" + text = "summarize: Grilly targets Vulkan GPUs." + + hf_tok = load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_with = extract_input_ids(hf_tok(text, add_special_tokens=True)) + hf_without = extract_input_ids(hf_tok(text, add_special_tokens=False)) + gr_with = encode_ids(gr_tok, text, add_special_tokens=True) + gr_without = encode_ids(gr_tok, text, add_special_tokens=False) + + assert list(gr_with) == list(hf_with) + assert list(gr_without) == list(hf_without) + diff --git a/tests/test_attention.py b/tests/test_attention.py index 289b21f..77ffe93 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class 
TestAttentionOperations: """Test attention operations on GPU""" diff --git a/tests/test_attention_long_sequence.py b/tests/test_attention_long_sequence.py index b9b8813..94207fb 100644 --- a/tests/test_attention_long_sequence.py +++ b/tests/test_attention_long_sequence.py @@ -10,14 +10,14 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu @pytest.mark.parametrize("seq_len", [128, 256, 512]) -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") def test_flash_attention2_long_sequence_finite(seq_len): backend = Compute() try: @@ -47,7 +47,7 @@ def test_flash_attention2_long_sequence_finite(seq_len): @pytest.mark.gpu @pytest.mark.slow @pytest.mark.parametrize("seq_len", [1024]) -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") def test_flash_attention2_very_long_sequence_finite(seq_len): backend = Compute() try: diff --git a/tests/test_backward.py b/tests/test_backward.py index 1e1daf3..306f51f 100644 --- a/tests/test_backward.py +++ b/tests/test_backward.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestBackwardOperations: """Test backward pass operations on GPU""" @@ -181,7 +181,7 @@ def test_softmax_backward(self, gpu): 
@pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestAutogradIntegration: """Test autograd integration with nn.Module""" @@ -240,7 +240,7 @@ def test_training_context(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestEndToEndTraining: """Test end-to-end training loop""" diff --git a/tests/test_bridge_moe.py b/tests/test_bridge_moe.py new file mode 100644 index 0000000..62b1ca3 --- /dev/null +++ b/tests/test_bridge_moe.py @@ -0,0 +1,14 @@ +"""Bridge exposes fused MoE alongside other grilly_core ops (lazy device + shaders).""" + +from grilly.backend import _bridge + + +def test_bridge_moe_functions_exported(): + for name in ( + "moe_upload", + "moe_forward", + "moe_backward", + "moe_update_weights", + "moe_release", + ): + assert hasattr(_bridge, name), f"missing _bridge.{name}" diff --git a/tests/test_bridge_vsa_lm.py b/tests/test_bridge_vsa_lm.py new file mode 100644 index 0000000..8ac6bbc --- /dev/null +++ b/tests/test_bridge_vsa_lm.py @@ -0,0 +1,14 @@ +"""Bridge exposes fused VSA-LM alongside other grilly_core ops (lazy device + shaders).""" + +from grilly.backend import _bridge + + +def test_bridge_vsa_lm_functions_exported(): + for name in ( + "vsa_lm_upload", + "vsa_lm_forward", + "vsa_lm_backward", + "vsa_lm_update_weights", + "vsa_lm_release", + ): + assert hasattr(_bridge, name), f"missing _bridge.{name}" diff --git a/tests/test_conv_backward_weight_gemm.py b/tests/test_conv_backward_weight_gemm.py index 265818f..bca9071 100644 --- a/tests/test_conv_backward_weight_gemm.py +++ b/tests/test_conv_backward_weight_gemm.py @@ -15,13 +15,13 @@ try: from grilly.backend.compute import VulkanCompute - from grilly.backend.base import VULKAN_AVAILABLE + from 
grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") def test_conv2d_backward_weight_matches_torch_gemm_path(): """Small conv: dW from Grilly GEMM path vs torch autograd.""" np.random.seed(123) diff --git a/tests/test_core.py b/tests/test_core.py index 597b5a8..74b08d3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -7,12 +7,13 @@ try: import grilly - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_AVAILABLE, VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE GRILLY_AVAILABLE = True except ImportError: GRILLY_AVAILABLE = False VULKAN_AVAILABLE = False + VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE = False class TestGrillyImports: @@ -49,7 +50,9 @@ def test_vulkan_available_flag(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif( + not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available" +) class TestComputeInitialization: """Test Compute backend initialization""" diff --git a/tests/test_fnn_chain.py b/tests/test_fnn_chain.py new file mode 100644 index 0000000..a1cdfc2 --- /dev/null +++ b/tests/test_fnn_chain.py @@ -0,0 +1,75 @@ +"""FnnChainRecorder: many dispatches, one submit (GPU).""" + +import warnings + +import numpy as np +import pytest + +try: + from grilly.backend.compute import VulkanCompute + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE + from grilly.backend.fnn_chain import ChainBufferHandle +except ImportError: + pytest.skip("grilly not available", allow_module_level=True) + + +@pytest.mark.gpu +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_fnn_chain_linear_relu_read(): + with 
warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "fnn-linear" not in backend.fnn.shaders or "activation-relu" not in backend.fnn.shaders: + pytest.skip("fnn-linear or activation-relu not available") + + rng = np.random.default_rng(0) + x = rng.standard_normal((3, 8), dtype=np.float32) + w1 = rng.standard_normal((16, 8), dtype=np.float32) + b1 = rng.standard_normal((16,), dtype=np.float32) + w2 = rng.standard_normal((4, 16), dtype=np.float32) + + with backend.record_commands(fnn_chain=True) as rec: + h = rec.linear(x, w1, b1) + assert isinstance(h, ChainBufferHandle) + h = rec.relu(h) + h = rec.linear(h, w2, None) + out = rec.read(h) + + ref = np.maximum(0, x @ w1.T + b1) @ w2.T + np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4) + finally: + backend.cleanup() + + +@pytest.mark.gpu +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") +def test_fnn_chain_read_multiple_moe_fanout(): + """Parallel expert linears: one submit, multiple downloads (MoE pattern).""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + backend = VulkanCompute() + try: + if "fnn-linear" not in backend.fnn.shaders: + pytest.skip("fnn-linear not available") + + rng = np.random.default_rng(1) + x = rng.standard_normal((2, 8), dtype=np.float32) + w0 = rng.standard_normal((4, 8), dtype=np.float32) + w1 = rng.standard_normal((4, 8), dtype=np.float32) + w2 = rng.standard_normal((4, 8), dtype=np.float32) + w3 = rng.standard_normal((4, 8), dtype=np.float32) + + with backend.record_commands(fnn_chain=True) as rec: + h0 = rec.linear(x, w0, None) + h1 = rec.linear(x, w1, None) + h2 = rec.linear(x, w2, None) + h3 = rec.linear(x, w3, None) + results = rec.read_multiple([h0, h1, h2, h3]) + + assert len(results) == 4 + refs = [x @ w0.T, x @ w1.T, x @ w2.T, x @ w3.T] + for got, ref in zip(results, refs): + np.testing.assert_allclose(got, ref, 
rtol=1e-4, atol=1e-4) + finally: + backend.cleanup() diff --git a/tests/test_functional.py b/tests/test_functional.py index a50b7a9..4facaeb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -5,14 +5,14 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from grilly.functional import dropout, linear except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFunctionalLinear: """Tests for functional.linear.""" @@ -40,7 +40,7 @@ def test_linear_no_bias(self, backend): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFunctionalDropout: """Tests for functional.dropout.""" diff --git a/tests/test_gpu_operations.py b/tests/test_gpu_operations.py index 8007a45..cdabefb 100644 --- a/tests/test_gpu_operations.py +++ b/tests/test_gpu_operations.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestLIFOperations: """Test LIF neuron operations on GPU""" @@ -82,7 +82,7 @@ def test_lif_refractory_period(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") 
class TestFNNOperations: """Test FNN operations on GPU""" @@ -202,7 +202,7 @@ def test_layernorm(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestFAISSOperations: """Test FAISS operations on GPU""" diff --git a/tests/test_grl_checkpoint.py b/tests/test_grl_checkpoint.py new file mode 100644 index 0000000..f1fd8d2 --- /dev/null +++ b/tests/test_grl_checkpoint.py @@ -0,0 +1,61 @@ +"""GRL (.grl) checkpoint format — roundtrip and magic/version checks.""" + +from __future__ import annotations + +import tempfile + +import numpy as np +import pytest + +from grilly.utils.grl_checkpoint import ( + FORMAT_VERSION, + HEADER_SIZE, + MAGIC, + load_grl, + save_grl, +) + + +def test_grl_roundtrip_nested_state(): + state = { + "embed": {"weight": np.random.randn(10, 8).astype(np.float32)}, + "layers": { + "0": {"w": np.ones((4, 4), dtype=np.float32)}, + }, + } + meta = {"training_step": 42, "best_ppl": 3.14} + + with tempfile.NamedTemporaryFile(suffix=".grl", delete=False) as f: + path = f.name + try: + save_grl(path, state, metadata=meta) + out = load_grl(path) + assert "metadata" in out + assert out["metadata"]["schema"] == "grilly.checkpoint.v1" + assert out["training_step"] == 42 + m = out["model"] + assert np.allclose(m["embed"]["weight"], state["embed"]["weight"]) + assert np.allclose(m["layers"]["0"]["w"], state["layers"]["0"]["w"]) + finally: + import os + + os.unlink(path) + + +def test_grl_rejects_bad_magic(): + with tempfile.NamedTemporaryFile(suffix=".grl", delete=False) as f: + f.write(b"XXXX") + path = f.name + try: + with pytest.raises(ValueError, match="GRL"): + load_grl(path) + finally: + import os + + os.unlink(path) + + +def test_header_constants_match_cpp_layout(): + assert len(MAGIC) == 4 + assert HEADER_SIZE == 64 + assert FORMAT_VERSION == 1 diff --git a/tests/test_inference_ops.py 
b/tests/test_inference_ops.py index c8543e2..9016323 100644 --- a/tests/test_inference_ops.py +++ b/tests/test_inference_ops.py @@ -8,7 +8,7 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @@ -96,7 +96,7 @@ def _ref_gqa_decode_attention( @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestRMSNormGPU: """Test RMSNorm on GPU""" @@ -174,7 +174,7 @@ def test_rms_norm_eps_affects_output(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestSwiGLUFusedGPU: """Test fused SwiGLU on GPU""" @@ -231,7 +231,7 @@ def test_swiglu_fused_output_shape(self, gpu): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestGEMMInt8GPU: """Test INT8 weight-only GEMM on GPU""" @@ -295,7 +295,7 @@ def test_gemm_int8_group_sizes(self, gpu, group_size): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestGQADecodeAttentionGPU: """Test GQA decode attention on GPU""" diff --git a/tests/test_integration.py b/tests/test_integration.py index 8fc3f7e..963171d 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,7 +7,7 @@ try: from grilly import Compute, SNNCompute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: 
pytest.skip("grilly not available", allow_module_level=True) @@ -53,7 +53,7 @@ def test_snn_temporal_dynamics(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestGPUIntegration: """Integration tests for GPU operations""" diff --git a/tests/test_learning.py b/tests/test_learning.py index e734ac9..5c97095 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -7,13 +7,13 @@ try: from grilly import Compute, Learning - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestLearningOperations: """Test learning operations on GPU""" diff --git a/tests/test_memory_operations.py b/tests/test_memory_operations.py index 3b0cb33..4b46ab8 100644 --- a/tests/test_memory_operations.py +++ b/tests/test_memory_operations.py @@ -7,13 +7,13 @@ try: from grilly import Compute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestMemoryOperations: """Test memory read/write operations""" diff --git a/tests/test_mf_activations.py b/tests/test_mf_activations.py new file mode 100644 index 0000000..713d566 --- /dev/null +++ b/tests/test_mf_activations.py @@ -0,0 +1,50 @@ +"""Multiplication-free activations: numpy + autograd.""" + +import numpy as np 
+import pytest +from grilly.functional.mf_activations import ( + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) +from grilly.nn import autograd as ag + + +def test_mf_softmax_rows_sum_to_one(): + x = np.array([[1.0, 2.0, 0.0], [0.0, 0.0, 0.0]], dtype=np.float32) + y = mf_softmax(x, dim=-1) + assert y.shape == x.shape + np.testing.assert_allclose(y.sum(axis=-1), np.ones(2), rtol=1e-5) + + +def test_mf_softplus_matches_algebra(): + x = np.array([-2.0, 0.0, 3.0], dtype=np.float32) + b = 1.0 + c = 4.0 / (b * b) + want = 0.5 * (x + np.sqrt(x * x + c)) + got = mf_softplus(x, beta=b) + np.testing.assert_allclose(got, want, rtol=1e-6) + + +def test_mf_sigmoid_bounds(): + x = np.linspace(-5, 5, 11, dtype=np.float32) + y = mf_sigmoid(x) + assert float(y.min()) >= -1.0 and float(y.max()) <= 1.0 + z = mf_sigmoid_01(x) + assert float(z.min()) >= 0.0 and float(z.max()) <= 1.0 + + +def test_mf_softmax_autograd(): + v = ag.Variable(np.array([[1.0, 2.0]], dtype=np.float32), requires_grad=True) + y = ag.mf_softmax(v, dim=-1) + assert y.data.sum() == pytest.approx(1.0) + y.backward() + assert v.grad is not None + + +def test_mf_softplus_autograd(): + v = ag.Variable(np.array([0.5], dtype=np.float32), requires_grad=True) + y = ag.mf_softplus(v, beta=1.0) + y.backward() + assert v.grad is not None diff --git a/tests/test_mf_ops_core.py b/tests/test_mf_ops_core.py new file mode 100644 index 0000000..a5aa4ab --- /dev/null +++ b/tests/test_mf_ops_core.py @@ -0,0 +1,69 @@ +"""grilly_core mf_softmax / mf_softplus / mf_sigmoid — GPU vs NumPy reference.""" + +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +try: + import grilly_core as gc +except ImportError: + pytest.skip("grilly_core not available", allow_module_level=True) + +from grilly.functional.mf_activations import mf_sigmoid, mf_softmax, mf_softplus + + +def _shader_spv_dir() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parent.parent / "shaders" / "spv" + + +def 
_require_mf_symbols() -> None: + for name in ("mf_softmax", "mf_softplus", "mf_sigmoid"): + if not hasattr(gc, name): + pytest.skip(f"grilly_core.{name} not in this build — rebuild extension") + + +@pytest.mark.gpu +@pytest.mark.cpp +def test_mf_softmax_gpu_matches_numpy() -> None: + _require_mf_symbols() + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(0) + x = rng.standard_normal((4, 7), dtype=np.float32) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + got = gc.mf_softmax(dev, x, -1) + ref = mf_softmax(x, dim=-1) + np.testing.assert_allclose(got.numpy(), ref, rtol=1e-4, atol=1e-5) + + +@pytest.mark.gpu +@pytest.mark.cpp +def test_mf_softplus_and_sigmoid_gpu_match_numpy() -> None: + _require_mf_symbols() + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(1) + x = rng.standard_normal((3, 5), dtype=np.float32) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + g_sig = gc.mf_sigmoid(dev, x) + np.testing.assert_allclose(g_sig.numpy(), mf_sigmoid(x), rtol=1e-5, atol=1e-6) + + g_sp = gc.mf_softplus(dev, x, 1.25) + np.testing.assert_allclose(g_sp.numpy(), mf_softplus(x, 1.25), rtol=1e-4, atol=1e-5) diff --git a/tests/test_moe_forward.py b/tests/test_moe_forward.py new file mode 100644 index 0000000..ffd5fb6 --- /dev/null +++ b/tests/test_moe_forward.py @@ -0,0 +1,158 @@ +"""grilly_core.moe_* — fused MoE forward (GPU) vs NumPy reference.""" + +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +try: + import grilly_core as gc +except ImportError: + pytest.skip("grilly_core not available", allow_module_level=True) + +try: + from grilly.backend.base import VULKAN_AVAILABLE +except ImportError: + VULKAN_AVAILABLE = False + + +def 
_shader_spv_dir() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parent.parent / "shaders" / "spv" + + +def _softmax(x: np.ndarray) -> np.ndarray: + x = x - x.max() + e = np.exp(x, dtype=np.float32) + return (e / e.sum()).astype(np.float32) + + +def moe_forward_numpy( + embed_w: np.ndarray, + pos_w: np.ndarray, + expert_ws: list[np.ndarray], + router_ws: list[np.ndarray], + router_bs: list[np.ndarray], + out_w: np.ndarray, + n_layers: int, + n_experts: int, + input_ids: np.ndarray, +) -> np.ndarray: + """Reference forward (CPU).""" + s_len = int(input_ids.shape[0]) + d = embed_w.shape[1] + x = embed_w[input_ids.astype(np.int64)] + pos_w[:s_len] + for layer in range(n_layers): + xm = x.mean(axis=0) + logits = router_ws[layer] @ xm + router_bs[layer] + p = _softmax(logits.astype(np.float32)) + blend = np.zeros_like(x, dtype=np.float32) + for e in range(n_experts): + we = expert_ws[layer * n_experts + e] + y = x @ we.T + blend += p[e] * y + x = x + blend + return x @ out_w.T + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_moe_forward_parity_small(): + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(7) + vocab, d, max_seq = 24, 8, 12 + n_layers, n_experts = 2, 4 + s_len = 6 + + embed_w = rng.standard_normal((vocab, d), dtype=np.float32) + pos_w = rng.standard_normal((max_seq, d), dtype=np.float32) + out_w = rng.standard_normal((vocab, d), dtype=np.float32) + + expert_ws = [rng.standard_normal((d, d), dtype=np.float32) for _ in range(n_layers * n_experts)] + router_ws = [rng.standard_normal((n_experts, d), dtype=np.float32) for _ in range(n_layers)] + router_bs = [rng.standard_normal((n_experts,), dtype=np.float32) for _ in range(n_layers)] + + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + ref = moe_forward_numpy( + embed_w, + pos_w, + expert_ws, + router_ws, + router_bs, + out_w, + n_layers, + 
n_experts, + input_ids, + ) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.moe_upload( + dev, + embed_w, + pos_w, + expert_ws, + router_ws, + router_bs, + out_w, + n_layers, + n_experts, + ) + try: + got = gc.moe_forward(dev, h, input_ids) + np.testing.assert_allclose(got, ref, rtol=1e-3, atol=1e-3) + + grad_logits = rng.standard_normal((s_len, vocab), dtype=np.float32) + grads = gc.moe_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert grads["grad_out_w"].shape == (vocab, d) + assert len(grads["grad_experts"]) == n_layers * n_experts + assert grads["grad_pos"].shape == (max_seq, d) + finally: + gc.moe_release(dev, h) + + +@pytest.mark.cpp +def test_moe_backward_cpu_shapes(): + """Backward is CPU-only; smoke-test shapes without Vulkan.""" + rng = np.random.default_rng(11) + vocab, d, max_seq = 16, 4, 8 + n_layers, n_experts = 1, 2 + s_len = 4 + + embed_w = rng.standard_normal((vocab, d), dtype=np.float32) + pos_w = rng.standard_normal((max_seq, d), dtype=np.float32) + out_w = rng.standard_normal((vocab, d), dtype=np.float32) + expert_ws = [rng.standard_normal((d, d), dtype=np.float32) for _ in range(n_layers * n_experts)] + router_ws = [rng.standard_normal((n_experts, d), dtype=np.float32) for _ in range(n_layers)] + router_bs = [rng.standard_normal((n_experts,), dtype=np.float32) for _ in range(n_layers)] + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + grad_logits = rng.standard_normal((s_len, vocab), dtype=np.float32) + + dev = gc.Device() + h = gc.moe_upload( + dev, + embed_w, + pos_w, + expert_ws, + router_ws, + router_bs, + out_w, + n_layers, + n_experts, + ) + try: + grads = gc.moe_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert grads["grad_out_w"].shape == (vocab, d) + finally: + gc.moe_release(dev, h) diff --git a/tests/test_snn.py 
b/tests/test_snn.py index 43d6bb8..7fb36ac 100644 --- a/tests/test_snn.py +++ b/tests/test_snn.py @@ -7,7 +7,7 @@ try: from grilly import SNNCompute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE except ImportError: pytest.skip("grilly not available", allow_module_level=True) @@ -136,7 +136,7 @@ def test_process_handles_large_embedding(self): @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") class TestSNNGPU: """Test SNN with GPU (if available)""" diff --git a/tests/test_torch_api.py b/tests/test_torch_api.py new file mode 100644 index 0000000..28eea91 --- /dev/null +++ b/tests/test_torch_api.py @@ -0,0 +1,50 @@ +"""Smoke tests for ``grilly.torch_api`` (torch-style facade, no PyTorch).""" + +import numpy as np +import pytest + +import grilly.torch_api as torch +from grilly.nn import autograd as ag + + +def test_device_and_vulkan(): + d = torch.device("vulkan") + assert "vulkan" in str(d) + assert isinstance(torch.vulkan.is_available(), bool) + + +def test_tensor_long_and_ops(): + x = torch.tensor([[1, 2], [3, 4]], dtype=torch.long) + assert x.shape == (2, 2) + z = torch.zeros(3) + z.uniform_(-1, 1) + assert z.shape == (3,) + p = torch.randperm(10) + assert p.shape == (10,) + a = torch.randn(2, 3) + b = torch.randn(2, 3) + d = torch.cdist(a.unsqueeze(0), b.unsqueeze(0), p=1) + assert d.shape == (1, 2, 2) + + +def test_functional_cross_entropy_sum(): + logits = ag.randn(5, 3, requires_grad=True) + target = np.array([0, 1, 2, 1, 0], dtype=np.int64) + loss = torch.nn.functional.cross_entropy(logits, target, reduction="sum") + loss.backward() + assert loss.data.size == 1 + + +def test_amp_namespace(): + assert hasattr(torch.amp, "autocast") + assert hasattr(torch.amp, "GradScaler") + s = torch.amp.GradScaler("vulkan", enabled=False) + assert s.get_scale() == 
1.0 + + +def test_grl_save_load_roundtrip(tmp_path): + path = tmp_path / "t.grl" + state = {"model": {"w": np.ones((2, 2), dtype=np.float32)}, "step": 3, "best_ppl": 1.25} + torch.save(state, path) + out = torch.load(path, map_location=torch.device("cpu")) + assert "model" in out or "metadata" in out diff --git a/tests/test_vsa_lm_forward.py b/tests/test_vsa_lm_forward.py new file mode 100644 index 0000000..976a568 --- /dev/null +++ b/tests/test_vsa_lm_forward.py @@ -0,0 +1,282 @@ +"""grilly_core.vsa_lm_* — fused VSA-LM forward (GPU) + backward (CPU) tests.""" + +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +try: + import grilly_core as gc +except ImportError: + pytest.skip("grilly_core not available", allow_module_level=True) + +try: + from grilly.backend.base import VULKAN_AVAILABLE +except ImportError: + VULKAN_AVAILABLE = False + + +def _shader_spv_dir() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parent.parent / "shaders" / "spv" + + +def _addition_linear_numpy(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray: + """y[s, o] = -sum_k |w[o, k] - x[s, k]| + b[o]""" + out = np.empty((x.shape[0], w.shape[0]), dtype=np.float32) + for s in range(x.shape[0]): + for o in range(w.shape[0]): + out[s, o] = -np.sum(np.abs(w[o] - x[s])) + b[o] + return out + + +def _sign_activation(x: np.ndarray) -> np.ndarray: + return np.where(x > 0, 1.0, -1.0).astype(np.float32) + + +def _layernorm_numpy(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, + eps: float = 1e-5) -> np.ndarray: + mean = x.mean(axis=-1, keepdims=True) + var = x.var(axis=-1, keepdims=True) + return ((x - mean) / np.sqrt(var + eps) * gamma + beta).astype(np.float32) + + +def vsa_lm_forward_numpy( + embed_w, pos_w, ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, n_layers, input_ids, +): + """NumPy reference forward.""" + s_len = input_ids.shape[0] + d = embed_w.shape[1] + + x = embed_w[input_ids.astype(np.int64)] 
+ pos_w[:s_len] + + for l in range(n_layers): + h = _layernorm_numpy(x, ln_gammas[l], ln_betas[l]) + h_up = _addition_linear_numpy(h, ffn_up_ws[l], ffn_up_bs[l]) + h_sign = _sign_activation(h_up) + h_ffn = _addition_linear_numpy(h_sign, ffn_down_ws[l], ffn_down_bs[l]) + x = x + h_ffn + + scale = 1.0 / np.sqrt(d).astype(np.float32) + logits = (x @ out_w.T) * scale + return logits.astype(np.float32) + + +def _make_weights(rng, vocab, d, d_ffn, max_seq, n_layers): + embed_w = rng.standard_normal((vocab, d)).astype(np.float32) * 0.02 + pos_w = rng.standard_normal((max_seq, d)).astype(np.float32) * 0.02 + out_w = rng.standard_normal((vocab, d)).astype(np.float32) * 0.02 + + ffn_up_ws = [rng.standard_normal((d_ffn, d)).astype(np.float32) * 0.02 + for _ in range(n_layers)] + ffn_up_bs = [np.zeros(d_ffn, dtype=np.float32) for _ in range(n_layers)] + ffn_down_ws = [rng.standard_normal((d, d_ffn)).astype(np.float32) * 0.02 + for _ in range(n_layers)] + ffn_down_bs = [np.zeros(d, dtype=np.float32) for _ in range(n_layers)] + ln_gammas = [np.ones(d, dtype=np.float32) for _ in range(n_layers)] + ln_betas = [np.zeros(d, dtype=np.float32) for _ in range(n_layers)] + + return (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_vsa_lm_forward_parity_small(): + """Forward parity: GPU vs NumPy at small scale.""" + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(42) + vocab, d, d_ffn, max_seq = 24, 8, 16, 12 + n_layers = 2 + s_len = 6 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + ref = vsa_lm_forward_numpy( + embed_w, pos_w, ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + 
ln_gammas, ln_betas, out_w, n_layers, input_ids) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + got = gc.vsa_lm_forward(dev, h, input_ids) + assert got.shape == (s_len, vocab), f"shape: {got.shape}" + np.testing.assert_allclose(got, ref, rtol=1e-2, atol=1e-2) + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_vsa_lm_forward_parity_medium(): + """Forward parity at target-ish dimensions (d=64, 4 layers).""" + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(99) + vocab, d, d_ffn, max_seq = 100, 64, 128, 32 + n_layers = 4 + s_len = 16 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + ref = vsa_lm_forward_numpy( + embed_w, pos_w, ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, n_layers, input_ids) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + got = gc.vsa_lm_forward(dev, h, input_ids) + assert got.shape == (s_len, vocab) + np.testing.assert_allclose(got, ref, rtol=5e-2, atol=5e-2) + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.cpp +def test_vsa_lm_backward_cpu_shapes(): + """Backward is CPU-only — smoke-test shapes without Vulkan.""" + rng = np.random.default_rng(7) + vocab, d, d_ffn, 
max_seq = 16, 4, 8, 8 + n_layers = 1 + s_len = 4 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + grad_logits = rng.standard_normal((s_len, vocab)).astype(np.float32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + dev = gc.Device() + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + grads = gc.vsa_lm_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert grads["grad_pos"].shape == (max_seq, d) + assert grads["grad_out_w"].shape == (vocab, d) + assert len(grads["grad_ffn_up_w"]) == n_layers + assert grads["grad_ffn_up_w"][0].shape == (d_ffn, d) + assert grads["grad_ffn_up_b"][0].shape == (d_ffn,) + assert grads["grad_ffn_down_w"][0].shape == (d, d_ffn) + assert grads["grad_ffn_down_b"][0].shape == (d,) + assert grads["grad_ln_gamma"][0].shape == (d,) + assert grads["grad_ln_beta"][0].shape == (d,) + + for key in ["grad_embed", "grad_pos", "grad_out_w"]: + assert np.isfinite(grads[key]).all(), f"{key} has non-finite values" + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.cpp +def test_vsa_lm_update_weights(): + """Verify update_weights re-uploads without crash.""" + rng = np.random.default_rng(13) + vocab, d, d_ffn, max_seq = 16, 4, 8, 8 + n_layers = 1 + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + dev = gc.Device() + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + # Perturb all weights slightly and re-upload + embed_w2 = embed_w + 0.001 + pos_w2 = pos_w + 0.001 + out_w2 = out_w + 0.001 + ffn_up_ws2 = [w + 0.001 for w in ffn_up_ws] + ffn_up_bs2 = [b + 0.001 
for b in ffn_up_bs] + ffn_down_ws2 = [w + 0.001 for w in ffn_down_ws] + ffn_down_bs2 = [b + 0.001 for b in ffn_down_bs] + ln_gammas2 = [g + 0.001 for g in ln_gammas] + ln_betas2 = [b + 0.001 for b in ln_betas] + + gc.vsa_lm_update_weights( + dev, h, embed_w2, pos_w2, + ffn_up_ws2, ffn_up_bs2, ffn_down_ws2, ffn_down_bs2, + ln_gammas2, ln_betas2, out_w2) + finally: + gc.vsa_lm_release(dev, h) + + +@pytest.mark.gpu +@pytest.mark.cpp +@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +def test_vsa_lm_large_scale_no_crash(): + """Ensure no segfault at production-ish shape (d=256, 6 layers).""" + if not _shader_spv_dir().exists(): + pytest.skip("shaders/spv not present") + + rng = np.random.default_rng(55) + vocab, d, d_ffn, max_seq = 1000, 256, 512, 512 + n_layers = 6 + s_len = 64 + input_ids = rng.integers(0, vocab, size=(s_len,), dtype=np.int32) + + (embed_w, pos_w, out_w, ffn_up_ws, ffn_up_bs, + ffn_down_ws, ffn_down_bs, ln_gammas, ln_betas) = _make_weights( + rng, vocab, d, d_ffn, max_seq, n_layers) + + dev = gc.Device() + try: + dev.load_shaders(str(_shader_spv_dir())) + except Exception as e: + pytest.skip(f"load_shaders failed: {e}") + + h = gc.vsa_lm_upload( + dev, embed_w, pos_w, + ffn_up_ws, ffn_up_bs, ffn_down_ws, ffn_down_bs, + ln_gammas, ln_betas, out_w, + n_layers, d, d_ffn) + try: + logits = gc.vsa_lm_forward(dev, h, input_ids) + assert logits.shape == (s_len, vocab) + assert np.isfinite(logits).all(), "logits contain non-finite values" + + grad_logits = rng.standard_normal((s_len, vocab)).astype(np.float32) + grads = gc.vsa_lm_backward(dev, h, input_ids, grad_logits) + assert grads["grad_embed"].shape == (vocab, d) + assert np.isfinite(grads["grad_embed"]).all() + finally: + gc.vsa_lm_release(dev, h) diff --git a/tests/test_vulkan_tensor_residency.py b/tests/test_vulkan_tensor_residency.py index f8b890a..b6000d6 100644 --- a/tests/test_vulkan_tensor_residency.py +++ b/tests/test_vulkan_tensor_residency.py @@ -9,14 +9,14 @@ try: 
from grilly.backend.compute import VulkanCompute - from grilly.backend.base import VULKAN_AVAILABLE + from grilly.backend.base import VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE from grilly.utils.tensor_conversion import VulkanTensor except ImportError: pytest.skip("grilly not available", allow_module_level=True) @pytest.mark.gpu -@pytest.mark.skipif(not VULKAN_AVAILABLE, reason="Vulkan not available") +@pytest.mark.skipif(not VULKAN_PYTHON_LEGACY_BACKEND_AVAILABLE, reason="Vulkan not available") def test_prepare_for_dispatch_binds_without_double_upload(): """After one GPU op, a VulkanTensor should expose a buffer for the next op without re-uploading.""" with warnings.catch_warnings(): diff --git a/tests/tokenizers/test_gpu_tokenizer_parity.py b/tests/tokenizers/test_gpu_tokenizer_parity.py new file mode 100644 index 0000000..a370083 --- /dev/null +++ b/tests/tokenizers/test_gpu_tokenizer_parity.py @@ -0,0 +1,54 @@ +"""Tokenizer parity tests (first real pass, adaptive to in-progress API).""" + +import pytest + +from tests._tokenizer_parity_helpers import ( + encode_ids, + extract_input_ids, + load_grilly_tokenizer, + load_hf_tokenizer, +) + + +@pytest.mark.parametrize( + "model_id,text", + [ + ("bert-base-uncased", "Hello from Grilly GPU tokenization."), + ("distilbert-base-uncased", "Tokenizer parity check with punctuation: !?.,"), + ], +) +def test_tokenizer_ids_match_hf_reference(model_id: str, text: str): + """Token IDs should match Hugging Face for covered model/tokenizer assets.""" + hf_tok = load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_ids = extract_input_ids(hf_tok(text)) + gr_ids = encode_ids(gr_tok, text) + + assert list(gr_ids) == list(hf_ids) + + +def test_batch_encode_matches_hf_reference(): + """Batch tokenization should match Hugging Face input_ids for canonical uncased BERT.""" + model_id = "bert-base-uncased" + texts = [ + "Grilly runs on Vulkan.", + "Any GPU, one backend.", + "Parity matters before v1.0.", + ] + hf_tok 
= load_hf_tokenizer(model_id) + gr_tok = load_grilly_tokenizer(model_id) + + hf_batch = hf_tok(texts, padding=False, truncation=False) + hf_ids = hf_batch["input_ids"] + + if hasattr(gr_tok, "batch_encode"): + gr_batch = gr_tok.batch_encode(texts) + elif hasattr(gr_tok, "__call__"): + gr_batch = gr_tok(texts) + else: + pytest.skip("Tokenizer missing batch_encode/__call__ batch path") + + gr_ids = extract_input_ids(gr_batch) + assert [list(x) for x in gr_ids] == [list(x) for x in hf_ids] + diff --git a/tests/transformers_compat/test_transformers_compat_placeholder.py b/tests/transformers_compat/test_transformers_compat_placeholder.py new file mode 100644 index 0000000..078c848 --- /dev/null +++ b/tests/transformers_compat/test_transformers_compat_placeholder.py @@ -0,0 +1,13 @@ +"""Pre-v1.0 scaffold: transformers compatibility 1:1 test suite placeholder.""" + +import pytest + +pytestmark = pytest.mark.skip( + reason="Scaffold only: implement transformers compatibility tests in pre-v1.0 roadmap." +) + + +def test_transformers_compat_scaffold(): + """Placeholder test to reserve CI target.""" + assert True + diff --git a/tokenizer_impl/__init__.py b/tokenizer_impl/__init__.py new file mode 100644 index 0000000..13ba025 --- /dev/null +++ b/tokenizer_impl/__init__.py @@ -0,0 +1,32 @@ +""" +Grilly tokenizers (P0 roadmap). + +Uses the standalone Rust ``tokenizers`` library (same ``tokenizer.json`` assets as HF Fast +tokenizers) — **no** ``transformers`` / ``AutoTokenizer`` delegation in this package. +Native GPU merge/BPE kernels will layer behind the same :class:`Tokenizer` interface. 
+ +Public API: + from grilly.tokenizers import AutoTokenizer, Tokenizer + tok = AutoTokenizer.from_pretrained("bert-base-uncased") + ids = tok.encode("Hello") +""" + +from __future__ import annotations + +from .auto import AutoTokenizer, from_pretrained +from .base import Tokenizer as TokenizerBase +from .fast_tokenizer import FastTokenizer +from .gpu import GPU_TOKENIZER_AVAILABLE, numpy_to_input_ids_buffers, wrap_ids_as_vulkan_tensors + +Tokenizer = FastTokenizer + +__all__ = [ + "AutoTokenizer", + "FastTokenizer", + "Tokenizer", + "TokenizerBase", + "from_pretrained", + "GPU_TOKENIZER_AVAILABLE", + "numpy_to_input_ids_buffers", + "wrap_ids_as_vulkan_tensors", +] diff --git a/tokenizer_impl/auto.py b/tokenizer_impl/auto.py new file mode 100644 index 0000000..e105f19 --- /dev/null +++ b/tokenizer_impl/auto.py @@ -0,0 +1,20 @@ +"""AutoTokenizer entrypoint (Rust ``tokenizers`` backend; no ``transformers``).""" + +from __future__ import annotations + +from typing import Any + +from .fast_tokenizer import FastTokenizer + + +class AutoTokenizer: + """Load a tokenizer the same way as Hugging Face ``AutoTokenizer`` (same assets).""" + + @staticmethod + def from_pretrained(model_id: str, **kwargs: Any) -> FastTokenizer: + return FastTokenizer.from_pretrained(model_id, **kwargs) + + +def from_pretrained(model_id: str, **kwargs: Any) -> FastTokenizer: + """Alias for :meth:`AutoTokenizer.from_pretrained`.""" + return AutoTokenizer.from_pretrained(model_id, **kwargs) diff --git a/tokenizer_impl/base.py b/tokenizer_impl/base.py new file mode 100644 index 0000000..14c09a3 --- /dev/null +++ b/tokenizer_impl/base.py @@ -0,0 +1,42 @@ +"""Abstract tokenizer interface (P0 GPU tokenizer roadmap).""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np + + +class Tokenizer(ABC): + """Minimal encode/decode surface aligned with common HF usage.""" + + @abstractmethod + def encode( + self, + text: str, + 
add_special_tokens: bool = True, + **kwargs: Any, + ) -> list[int] | np.ndarray: + """Return token ids for a single string.""" + + @abstractmethod + def decode( + self, + token_ids: list[int] | np.ndarray, + skip_special_tokens: bool = True, + **kwargs: Any, + ) -> str: + """Decode ids back to text.""" + + @abstractmethod + def batch_encode( + self, + texts: list[str], + padding: bool | str = False, + truncation: bool | str = False, + max_length: int | None = None, + return_tensors: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Batch encode; return dict with at least ``input_ids`` (list of lists or ndarray).""" diff --git a/tokenizer_impl/fast_tokenizer.py b/tokenizer_impl/fast_tokenizer.py new file mode 100644 index 0000000..f42fdc1 --- /dev/null +++ b/tokenizer_impl/fast_tokenizer.py @@ -0,0 +1,110 @@ +"""Tokenizer implementation using the standalone ``tokenizers`` (Rust) library only. + +Matches Hugging Face **Fast** tokenizer outputs for the same serialized tokenizer assets +(``tokenizer.json``, including ``onnx/tokenizer.json`` Hub fallbacks) without importing +``transformers``. 
class FastTokenizer(Tokenizer):
    """Encode/decode via ``tokenizers.Tokenizer`` (Rust backend).

    NOTE: ``encode`` / ``batch_encode`` / ``__call__`` reconfigure the shared
    Rust tokenizer in place (``no_truncation`` / ``enable_padding`` / …)
    before each call, so one instance is not safe for concurrent use from
    multiple threads.
    """

    def __init__(self, rust_tokenizer: Any, model_id: str | None = None) -> None:
        # rust_tokenizer: a tokenizers.Tokenizer; model_id kept for repr/debug.
        self._tok = rust_tokenizer
        self.model_id = model_id

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs: Any) -> FastTokenizer:
        """Load from a Hub id or local path via :func:`load_rust_tokenizer`.

        Raises ImportError with an install hint if the optional
        ``tokenizers`` / ``huggingface_hub`` dependencies are missing.
        """
        try:
            rs = load_rust_tokenizer(model_id, **kwargs)
        except ImportError as e:
            raise ImportError(
                "FastTokenizer.from_pretrained requires `pip install tokenizers huggingface_hub`"
            ) from e
        return cls(rs, model_id=model_id)

    def encode(
        self,
        text: str,
        add_special_tokens: bool = True,
        **kwargs: Any,
    ) -> list[int] | np.ndarray:
        """Token ids for one string; always returns an int32 ndarray."""
        # Match HF PreTrainedTokenizer.encode defaults (no truncation / padding unless kwargs).
        self._tok.no_truncation()
        self._tok.no_padding()
        enc = self._tok.encode(text, add_special_tokens=add_special_tokens)
        return np.asarray(enc.ids, dtype=np.int32)

    def decode(
        self,
        token_ids: list[int] | np.ndarray,
        skip_special_tokens: bool = True,
        **kwargs: Any,
    ) -> str:
        """Decode ids back to text (Rust backend wants a plain int list)."""
        if isinstance(token_ids, np.ndarray):
            # int32 is wide enough for vocab ids; the Rust binding rejects ndarrays.
            token_ids = token_ids.astype(np.int32).tolist()
        return self._tok.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def batch_encode(
        self,
        texts: list[str],
        padding: bool | str = False,
        truncation: bool | str = False,
        max_length: int | None = None,
        return_tensors: str | None = None,
        add_special_tokens: bool = True,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Batch encode with HF-like padding/truncation flags.

        Returns ``{"input_ids", "attention_mask"}`` as lists of lists, or as
        object-dtype ndarrays when ``return_tensors == "np"`` (object dtype
        because rows may be ragged when padding is off).
        """
        if not texts:
            return {"input_ids": [], "attention_mask": []}

        if truncation or max_length is not None:
            # HF convention: truncation with no explicit max_length caps at 512.
            ml = max_length if max_length is not None else 512
            self._tok.enable_truncation(ml)
        else:
            self._tok.no_truncation()

        if padding is True or padding == "longest" or padding == "max_length":
            # length=None pads to the longest sequence in the batch.
            length = None
            if padding == "max_length" and max_length is not None:
                length = max_length
            self._tok.enable_padding(direction="right", length=length)
        else:
            self._tok.no_padding()

        encodings = self._tok.encode_batch(texts, add_special_tokens=add_special_tokens)

        input_ids = [e.ids for e in encodings]
        attention_mask = [e.attention_mask for e in encodings]

        out: dict[str, Any] = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if return_tensors == "np":
            out["input_ids"] = np.asarray(input_ids, dtype=object)
            out["attention_mask"] = np.asarray(attention_mask, dtype=object)
        return out

    def __call__(self, texts: str | list[str], **kwargs: Any) -> dict[str, Any]:
        """HF-style call: single string -> flat id lists; list -> batch_encode."""
        if isinstance(texts, str):
            # Single-string path mirrors encode() defaults but also returns the mask.
            self._tok.no_truncation()
            self._tok.no_padding()
            enc = self._tok.encode(
                texts,
                add_special_tokens=kwargs.get("add_special_tokens", True),
            )
            return {
                "input_ids": enc.ids,
                "attention_mask": enc.attention_mask,
            }
        return self.batch_encode(texts, **kwargs)
# Set GRILLY_GPU_TOKENIZER=1 once native GPU tokenization is wired (future).
GPU_TOKENIZER_AVAILABLE = os.environ.get("GRILLY_GPU_TOKENIZER", "").strip() in {
    "1",
    "true",
    "yes",
    "on",
}


def numpy_to_input_ids_buffers(
    encoded: dict[str, Any],
) -> dict[str, np.ndarray]:
    """Normalize batch output to numpy int32 arrays (CPU staging before GPU upload).

    Only ``input_ids`` / ``attention_mask`` are converted; any other key in
    ``encoded`` is ignored. Missing keys are simply absent from the result.
    """
    staged: dict[str, np.ndarray] = {}
    for field in ("input_ids", "attention_mask"):
        if field not in encoded:
            continue
        value = encoded[field]
        # Torch-style tensors expose .numpy(); unwrap before the int32 cast.
        if hasattr(value, "numpy"):
            value = value.numpy()
        staged[field] = np.asarray(value, dtype=np.int32)
    return staged
+ ) + try: + from grilly.utils.tensor_conversion import VulkanTensor, gpu_mode + except ImportError as e: + raise RuntimeError("VulkanTensor not available") from e + + if device is None: + gpu_mode(True) + out: dict[str, Any] = {} + for k, arr in encoded_numpy.items(): + out[k] = VulkanTensor(arr) + return out diff --git a/tokenizer_impl/loader.py b/tokenizer_impl/loader.py new file mode 100644 index 0000000..0ff7b25 --- /dev/null +++ b/tokenizer_impl/loader.py @@ -0,0 +1,57 @@ +"""Load ``tokenizers`` (Rust) tokenizer files from the Hub or local paths — no ``transformers``.""" + +from __future__ import annotations + +import os +from typing import Any + + +def load_rust_tokenizer(model_id: str, **kwargs: Any): + """Return ``tokenizers.Tokenizer`` from a Hub id or a local directory / ``tokenizer.json`` path.""" + from tokenizers import Tokenizer as RsTokenizer + + try: + from huggingface_hub.errors import EntryNotFoundError + except ImportError: + from huggingface_hub.utils import EntryNotFoundError # type: ignore[no-redef] + + if os.path.isfile(model_id) and model_id.endswith(".json"): + return RsTokenizer.from_file(model_id) + if os.path.isdir(model_id): + for rel in ("tokenizer.json", os.path.join("onnx", "tokenizer.json")): + p = os.path.join(model_id, rel) + if os.path.isfile(p): + return RsTokenizer.from_file(p) + raise FileNotFoundError( + f"No tokenizer.json or onnx/tokenizer.json under {model_id!r}", + ) + + token = kwargs.get("token") or kwargs.get("use_auth_token") + revision = kwargs.get("revision") + + from huggingface_hub import hf_hub_download + + # Root tokenizer.json first; some repos (e.g. google/mt5-small) ship only onnx/tokenizer.json. 
+ for rel in ("tokenizer.json", "onnx/tokenizer.json"): + try: + path = hf_hub_download( + repo_id=model_id, + filename=rel, + token=token, + revision=revision, + ) + return RsTokenizer.from_file(path) + except EntryNotFoundError: + continue + + try: + tok = RsTokenizer.from_pretrained(model_id, **kwargs) + if tok is not None: + return tok + except (AttributeError, TypeError, OSError, ValueError, EntryNotFoundError): + pass + + raise FileNotFoundError( + f"Could not load Rust tokenizer for {model_id!r} " + "(tried tokenizer.json, onnx/tokenizer.json, Tokenizer.from_pretrained)", + ) diff --git a/torch_api/__init__.py b/torch_api/__init__.py new file mode 100644 index 0000000..b8497e7 --- /dev/null +++ b/torch_api/__init__.py @@ -0,0 +1,101 @@ +""" +Torch-compatible facade for Grilly: ``import grilly.torch_api as torch`` or use +re-exports on :mod:`grilly` (``device``, ``tensor``, ``save``, ``load``, …). + +No PyTorch runtime; GPU path is Vulkan via ``grilly_core`` + SPIR-V. Checkpoints +use ``.grl`` (see :mod:`grilly.utils.grl_checkpoint`). 
+""" + +from __future__ import annotations + +import types + +import grilly.nn as nn +import grilly.optim as optim +from grilly.functional.mf_activations import ( + mf_relu, + mf_sigmoid, + mf_sigmoid_01, + mf_softmax, + mf_softplus, +) + +from .amp_mod import GradScaler, autocast +from .dtypes_and_device import ( + cpu, + device, + float16, + float32, + float64, + int32, + int64, + long, +) +from .ops import ( + cdist, + clamp, + empty, + no_grad, + ones, + randn, + randperm, + sign, + tanh, + tensor, + zeros, +) +from .serialization import load, save +from .tensor import LongTensor, Tensor +from .vulkan_mod import is_available as _vulkan_is_available + +amp = types.SimpleNamespace( + autocast=autocast, + GradScaler=GradScaler, + __doc__="Automatic mixed precision (Vulkan-oriented; no CUDA).", +) + + +class _Vulkan: + """``torch.cuda``-free replacement: ``torch.vulkan.is_available()``.""" + + @staticmethod + def is_available() -> bool: + return _vulkan_is_available() + + +vulkan = _Vulkan() + +__all__ = [ + "Tensor", + "LongTensor", + "device", + "cpu", + "float16", + "float32", + "float64", + "int32", + "int64", + "long", + "tensor", + "empty", + "zeros", + "ones", + "randn", + "randperm", + "cdist", + "tanh", + "clamp", + "sign", + "no_grad", + "save", + "load", + "nn", + "optim", + "amp", + "vulkan", + "mf_softmax", + "mf_softplus", + "mf_sigmoid", + "mf_sigmoid_01", + "mf_relu", +] diff --git a/torch_api/amp_mod.py b/torch_api/amp_mod.py new file mode 100644 index 0000000..d0f16ac --- /dev/null +++ b/torch_api/amp_mod.py @@ -0,0 +1,58 @@ +"""``torch.amp``-compatible module (Vulkan / numpy path; no CUDA).""" + +from __future__ import annotations + +from typing import Any + +from grilly.backend import amp as _amp + + +class autocast: + """Autocast context (delegates to :mod:`grilly.backend.amp`). 
class GradScaler:
    """GradScaler shim; first positional arg may be ``'vulkan'`` / ``'cuda'`` (ignored) or ``init_scale`` (float)."""

    def __init__(
        self,
        device_type: Any = None,
        enabled: bool = True,
        *,
        init_scale: float = 65536.0,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        **kwargs: Any,
    ) -> None:
        del kwargs  # accepted only for torch API parity
        # torch also allows GradScaler(65536.0): a numeric first positional
        # is the initial scale (bool excluded — it's an int subclass).
        first_is_scale = isinstance(device_type, (int, float)) and not isinstance(
            device_type, bool
        )
        self._inner = _amp.GradScaler(
            init_scale=float(device_type) if first_is_scale else init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )

    def __getattr__(self, name: str) -> Any:
        # Everything else (scale / step / update / unscale_ / ...) is
        # forwarded untouched to the backend scaler.
        return getattr(self._inner, name)
class device:
    """``torch.device``-like device tag (``cpu`` / ``vulkan``).

    Accepts ``device("vulkan")``, ``device("vulkan:0")`` or
    ``device("vulkan", 0)``. Instances are immutable value objects:
    equality and hashing follow ``(type, index)`` so they can be used as
    dict keys / set members exactly like ``torch.device``.
    """

    __slots__ = ("type", "index")

    def __init__(self, type_: str, index: int | None = None) -> None:
        s = str(type_)
        if ":" in s:
            # "vulkan:0" form; a malformed or empty index is treated as absent.
            head, _, tail = s.partition(":")
            self.type = head.strip()
            try:
                self.index = int(tail) if tail else None
            except ValueError:
                self.index = None
        else:
            self.type = s.strip()
            self.index = index

    def __eq__(self, other: object) -> bool:
        # torch parity: device("cpu") == "cpu" / "vulkan:0" works too.
        if isinstance(other, str):
            other = device(other)
        if not isinstance(other, device):
            return NotImplemented
        return self.type == other.type and self.index == other.index

    def __hash__(self) -> int:
        # Bug fix: defining __eq__ without __hash__ set __hash__ to None,
        # making device unhashable — torch.device must work as a dict key.
        # Hash the same tuple that equality compares.
        return hash((self.type, self.index))

    def __str__(self) -> str:
        return self.type if self.index is None else f"{self.type}:{self.index}"

    def __repr__(self) -> str:
        if self.index is None:
            return f"device(type='{self.type}')"
        return f"device(type='{self.type}', index={self.index})"


# Canonical CPU tag shared by the facade.
cpu = device("cpu")
def softplus(input, beta: float = 1.0, threshold: float = 20.0):
    # Same positional signature as torch.nn.functional.softplus; delegates
    # straight to grilly autograd.
    return ag.softplus(input, beta, threshold)


def cross_entropy(
    input,
    target,
    weight=None,
    size_average=None,
    ignore_index: int = -100,
    reduce=None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
):
    """``F.cross_entropy`` facade over grilly autograd.

    ``weight``, ``size_average``, ``reduce`` and ``label_smoothing`` are
    accepted for torch API parity but ignored. Only ``reduction`` in
    {"mean", "sum"} is supported; "none" raises NotImplementedError.
    """
    del weight, size_average, reduce, label_smoothing
    logits = ag._ensure_variable(input)
    # Unwrap target to a raw ndarray whatever wrapper it arrived in
    # (Variable, ndarray, torch-like .data holder, or plain sequence).
    if isinstance(target, ag.Variable):
        t = target.data
    elif isinstance(target, np.ndarray):
        t = target
    elif hasattr(target, "data") and not isinstance(target, np.ndarray):
        t = target.data
    else:
        t = np.asarray(target)

    if ignore_index >= 0:
        # NOTE(review): ignore_index masking is NOT implemented — a
        # non-negative value is currently discarded. Confirm no caller
        # relies on it before shipping.
        pass

    ce = ag.cross_entropy(logits, t)

    if reduction == "mean":
        return ce
    if reduction == "sum":
        # Reconstructs sum as mean * batch_size — assumes ag.cross_entropy
        # returns the batch mean; TODO confirm against the autograd impl.
        n = float(logits.data.shape[0])
        return ag.mul(ce, n)
    if reduction == "none":
        raise NotImplementedError("cross_entropy reduction='none' is not implemented")
    raise ValueError(f"Unknown reduction: {reduction}")
def empty(*size: int, dtype=None, device=None, requires_grad: bool = False) -> Tensor:
    """Uninitialized tensor of ``size`` (torch.empty facade).

    Integer dtypes produce a ``LongTensor``; float data is stored as float32.
    Accepts either ``empty(2, 3)`` or ``empty((2, 3))``.
    """
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])  # type: ignore[assignment]
    if not size:
        raise TypeError("empty() requires at least one size dimension")
    np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT
    if np_dt.kind in "iu":
        return LongTensor(np.empty(size, dtype=np_dt), device=device)
    return Tensor(np.empty(size, dtype=np.float32), requires_grad=requires_grad)


def zeros(*size: int, dtype=None, device=None, requires_grad: bool = False) -> Tensor:
    """Zero-filled tensor (torch.zeros facade); int dtypes give LongTensor."""
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])  # type: ignore[assignment]
    if not size:
        raise TypeError("zeros() requires size")
    np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT
    if np_dt.kind in "iu":
        return LongTensor(np.zeros(size, dtype=np_dt), device=device)
    return Tensor(np.zeros(size, dtype=np.float32), requires_grad=requires_grad)


def ones(*size: int, dtype=None, device=None, requires_grad: bool = False) -> Tensor:
    """One-filled tensor (torch.ones facade); int dtypes give LongTensor."""
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])  # type: ignore[assignment]
    if not size:
        # Consistency fix: zeros()/empty() reject an empty size, but ones()
        # used to fall through and silently build a 0-d array.
        raise TypeError("ones() requires size")
    np_dt = _dtype_to_numpy(dtype) if dtype is not None else _DEFAULT_FLOAT
    if np_dt.kind in "iu":
        return LongTensor(np.ones(size, dtype=np_dt), device=device)
    return Tensor(np.ones(size, dtype=np.float32), requires_grad=requires_grad)


def randn(*size: int, requires_grad: bool = False) -> Tensor:
    """Standard-normal tensor (torch.randn facade).

    Unlike zeros()/ones(), calling with no size deliberately returns a 0-d
    sample (existing callers depend on this).
    """
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])  # type: ignore[assignment]
    if not size:
        return Tensor(np.array(np.random.randn(), dtype=np.float32), requires_grad=requires_grad)
    return Tensor(np.random.randn(*size).astype(np.float32), requires_grad=requires_grad)


def randperm(n: int, dtype=None, device=None, requires_grad: bool = False) -> LongTensor:
    """Random permutation of ``0 .. n-1`` (CPU/numpy).

    ``dtype`` / ``requires_grad`` are accepted for torch parity but ignored;
    the result is always an int64 LongTensor.
    """
    idx = np.random.permutation(n).astype(np.int64)
    return LongTensor(idx, device=device)
def cdist(x1, x2, p: float = 2.0, compute_mode=None) -> Tensor:
    """Batched pairwise p-norm distance (torch.cdist facade; ``p=1`` = L1/Manhattan).

    2-D inputs are promoted to a singleton batch, matching torch semantics:
    (P, D) x (R, D) -> (1, P, R).
    """
    del compute_mode  # accepted for torch parity; the numpy path ignores it
    lhs = as_numpy(x1)
    rhs = as_numpy(x2)
    if lhs.ndim == 2:
        lhs = lhs[None, ...]
    if rhs.ndim == 2:
        rhs = rhs[None, ...]
    # (B, P, 1, D) - (B, 1, R, D) -> (B, P, R, D)
    delta = lhs[:, :, None, :] - rhs[:, None, :, :]
    if p == 1:
        pairwise = np.abs(delta).sum(axis=-1)
    elif p == 2:
        pairwise = np.sqrt((delta * delta).sum(axis=-1))
    else:
        pairwise = (np.abs(delta) ** p).sum(axis=-1) ** (1.0 / p)
    return Tensor(pairwise.astype(np.float32), requires_grad=False)


def tanh(input) -> Tensor:
    """Elementwise tanh via autograd."""
    return ag.tanh(input)  # type: ignore[return-value]


def clamp(input, min_val=None, max_val=None) -> Tensor:
    """Clamp to [min_val, max_val]; either bound may be None."""
    return ag.clamp(input, min_val, max_val)  # type: ignore[return-value]


def sign(input) -> Tensor:
    """Elementwise sign via autograd."""
    return ag.sign(input)  # type: ignore[return-value]


# no_grad is re-exported directly from autograd.
no_grad = ag.no_grad
def load(
    f: str | Path,
    map_location: Any = None,
    pickle_module: object | None = None,
    *,
    weights_only: bool = False,
    **kwargs: Any,
) -> Any:
    """``torch.load`` facade restricted to GRL checkpoints.

    ``pickle_module`` / ``weights_only`` / extra kwargs exist only for torch
    API parity and are ignored. Any extension other than ``.grl`` raises
    ``ValueError`` — there is deliberately no pickle fallback.
    """
    del pickle_module, weights_only, kwargs
    path = Path(f)
    if path.suffix.lower() != ".grl":
        raise ValueError(
            f"Grilly torch.load only supports .grl checkpoints (got {path.suffix!r}). "
            "Export PyTorch .pt to .grl with a one-time migration script."
        )
    return load_grl(path, map_location=map_location)
def as_numpy(x: object) -> np.ndarray:
    """Unwrap ``Tensor`` / ``Variable`` / ``LongTensor`` to a plain ndarray.

    ndarrays (including LongTensor, an ndarray subclass) pass through
    untouched; objects carrying an ndarray ``.data`` attribute yield that
    array; anything else is coerced with ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    inner = getattr(x, "data", None)
    if isinstance(inner, np.ndarray):
        # Variable/Tensor-style wrapper: hand back the backing array itself.
        return inner
    return np.asarray(x)
def mkseqs(t, sl):
    """Slice token stream ``t`` into overlapping (input, target) pairs.

    Windows of length ``sl`` with 50% overlap (stride ``sl // 2``); targets
    are the inputs shifted one token right (next-token prediction). Returns
    two long tensors of shape (N, sl).
    """
    inputs, targets = [], []
    stride = sl // 2
    for start in range(0, len(t) - sl - 1, stride):
        inputs.append(t[start:start + sl])
        targets.append(t[start + 1:start + sl + 1])
    return (torch.tensor(np.array(inputs), dtype=torch.long),
            torch.tensor(np.array(targets), dtype=torch.long))


def compute_ppl(model, x_data, y_data, max_samples=100):
    """Perplexity of ``model`` over up to ``max_samples`` validation sequences.

    Evaluates in fp16 autocast batches of 32, sums token-level cross-entropy,
    and caps the exponent at 20 to avoid overflow on divergence.
    """
    model.eval()
    loss_sum, token_count = 0.0, 0
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=True, dtype=torch.float16):
        chunk = 32  # batched eval for speed
        for start in range(0, min(len(x_data), max_samples), chunk):
            batch_x = x_data[start:start + chunk].to(device)
            batch_y = y_data[start:start + chunk].to(device)
            logits = model(batch_x)  # (B, S, V)
            batch_loss = F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                batch_y.reshape(-1),
                reduction='sum',
            )
            loss_sum += batch_loss.item()
            token_count += batch_y.numel()
    model.train()
    return math.exp(min(loss_sum / max(token_count, 1), 20))
class AdditionLinearCUDA(nn.Module):
    """Linear-like layer scoring inputs by negative L1 distance to weight rows.

    ``out[..., j] = -||x - W[j]||_1 + b[j]`` — a distance-based alternative
    to the usual dot-product linear layer, computed with ``torch.cdist``.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_out, d_in).uniform_(-0.1, 0.1))
        self.bias = nn.Parameter(torch.zeros(d_out))
        self.d_in = d_in

    def forward(self, x):
        # x: (..., d_in) -> (..., d_out). cdist wants (B, P, D), so collapse
        # all leading dims into one and restore them afterwards.
        lead = x.shape[:-1]
        flat = x.reshape(-1, x.shape[-1])
        l1 = torch.cdist(flat.unsqueeze(0), self.weight.unsqueeze(0), p=1).squeeze(0)
        scores = self.bias - l1
        return scores.reshape(*lead, -1)
+ B = x.shape[0] + tau = self.tau_min + F.softplus(x @ self.V.T + self.c) # (B, d) + tau = torch.clamp(tau, max=self.tau_max) + h_b = self.h.unsqueeze(0).expand(B, -1) # (B, d) — shared state + a = torch.tanh(h_b @ self.W.T + x @ self.U.T + self.b) # (B, d) + dh = -h_b / tau.clamp(min=1e-6) + a + new_h = h_b + self.dt * dh # (B, d) + # Update the shared buffer to the batch mean (detached for next call) + self.h = new_h.mean(dim=0).detach() + return new_h + + def reset(self): + self.h.zero_() + + +class VSALayerCUDA(nn.Module): + def __init__(self, d, d_ffn): + super().__init__() + self.ln = nn.LayerNorm(d) + self.ffn_up = AdditionLinearCUDA(d, d_ffn) + self.ffn_down = AdditionLinearCUDA(d_ffn, d) + self.liquid = LiquidCellCUDA(d) + self.d = d + + def forward(self, x, plasticity=0.5, consolidation=0.5): + if x.ndim == 3: + return self._forward_batch(x, plasticity, consolidation) + return self._forward_single(x, plasticity, consolidation) + + def _forward_single(self, x, plasticity, consolidation): + # x: (S, d) + h = self.ln(x) + x_mean = h.mean(dim=0) # (d,) + self.liquid.dt = 0.01 + 0.03 * plasticity + self.liquid.tau_min = 0.02 + 0.08 * consolidation + temporal = self.liquid.step(x_mean) # (d,) + + h_up = torch.sign(self.ffn_up(h) / math.sqrt(self.ffn_up.d_in)) + h_ffn = self.ffn_down(h_up) / math.sqrt(self.ffn_down.d_in) + + gate = torch.tanh(temporal) # (d,) + y = (0.5 + plasticity) * h_ffn * (1.0 + 0.1 * gate) + return x + y + + def _forward_batch(self, x, plasticity, consolidation): + # x: (B, S, d) + h = self.ln(x) + x_mean = h.mean(dim=1) # (B, d) + self.liquid.dt = 0.01 + 0.03 * plasticity + self.liquid.tau_min = 0.02 + 0.08 * consolidation + temporal = self.liquid.step_batched(x_mean) # (B, d) + + h_up = torch.sign(self.ffn_up(h) / math.sqrt(self.ffn_up.d_in)) + h_ffn = self.ffn_down(h_up) / math.sqrt(self.ffn_down.d_in) + + gate = torch.tanh(temporal).unsqueeze(1) # (B, 1, d) — broadcasts over S + y = (0.5 + plasticity) * h_ffn * (1.0 + 0.1 * gate) + 
class VSALMModel(nn.Module):
    """VSA language model: token + capsule embeddings, stacked VSA layers,
    and a tied-scale output projection.

    The small capsule embedding supplies two scalar control signals
    (plasticity / consolidation) derived from its mean activation; both are
    broadcast to every layer.
    """

    def __init__(self, vocab, d, d_ffn, n_layers, max_seq):
        super().__init__()
        self.d = d
        self.embed = nn.Embedding(vocab, d)
        self.capsule_embed = nn.Embedding(vocab, CAPSULE_DIM)
        self.capsule_proj = nn.Linear(CAPSULE_DIM, d, bias=False)
        # Learned absolute positions; +16 slack rows beyond max_seq.
        self.pos = nn.Parameter(torch.randn(max_seq + 16, d) * 0.02)
        self.out_proj = nn.Linear(d, vocab, bias=False)
        self.layers = nn.ModuleList([VSALayerCUDA(d, d_ffn) for _ in range(n_layers)])
        # Logits are divided by sqrt(d) before the output projection.
        self.scale = math.sqrt(d)
        nn.init.normal_(self.embed.weight, 0, 0.02)
        nn.init.normal_(self.capsule_embed.weight, 0, 0.01)

    def forward(self, ids):
        # Unbatched path: ids is (S,) -> logits (S, V).
        if ids.ndim == 1:
            S = ids.shape[0]
            x = self.embed(ids) + self.pos[:S]
            caps = self.capsule_embed(ids)
            x = x + self.capsule_proj(caps)
            # Control signals are computed under no_grad so they steer the
            # layers without receiving gradients themselves.
            with torch.no_grad():
                cm = caps.mean(dim=0)
                # Channels 14 / 18 of the capsule mean are read as the
                # plasticity / consolidation controls — magic indices,
                # presumably fixed by the original training run; confirm
                # before changing.
                plasticity = cm[14].clamp(0, 1).item()
                consolidation = cm[18].clamp(0, 1).item()
            for layer in self.layers:
                x = layer(x, plasticity=plasticity, consolidation=consolidation)
            return self.out_proj(x / self.scale)

        # Batched: (B, S) -> logits (B, S, V)
        B, S = ids.shape
        x = self.embed(ids) + self.pos[:S].unsqueeze(0)
        caps = self.capsule_embed(ids)
        x = x + self.capsule_proj(caps)
        with torch.no_grad():
            # NOTE: batched variant pools the capsule mean over batch AND
            # sequence, so every sequence in the batch shares one
            # (plasticity, consolidation) pair — unlike the per-sequence
            # values of the unbatched path.
            cm = caps.mean(dim=(0, 1))
            plasticity = cm[14].clamp(0, 1).item()
            consolidation = cm[18].clamp(0, 1).item()
        for layer in self.layers:
            x = layer(x, plasticity=plasticity, consolidation=consolidation)
        return self.out_proj(x / self.scale)

    def reset_liquid(self):
        # Zero every layer's liquid-cell hidden state (call between batches).
        for layer in self.layers:
            layer.liquid.reset()
print(f'Loaded {ckpt_path} (step={ckpt_step}, PPL={ckpt_ppl:.1f})') +else: + model.load_state_dict(checkpoint, strict=True) + ckpt_step = 0 + print(f'Loaded {ckpt_path}') + +n_params = sum(p.numel() for p in model.parameters()) +print(f'VSA-LM v3b: d={D_MODEL}, L={N_LAYERS}, B={BATCH_SIZE}, params={n_params/1e6:.1f}M') + +with torch.no_grad(): + ppl = compute_ppl(model, val_x, val_y) + print(f'Checkpoint PPL: {ppl:.1f}') + +# ── Train ── +optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01) +# Cosine over remaining steps +remaining = max(TRAIN_STEPS - ckpt_step, 1000) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, remaining, eta_min=1e-6) +scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP) + +t0 = time.time() +best_ppl = ppl +step = 0 + +while step < remaining: + perm = torch.randperm(len(train_x)) + for i in range(0, len(perm) - BATCH_SIZE, BATCH_SIZE): + if step >= remaining: + break + + ids_batch = train_x[perm[i:i + BATCH_SIZE]].to(device) + labels_batch = train_y[perm[i:i + BATCH_SIZE]].to(device) + + model.reset_liquid() + optimizer.zero_grad() + + with torch.amp.autocast('cuda', enabled=USE_AMP, dtype=torch.float16): + logits = model(ids_batch) # (B, S, V) + loss = F.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + labels_batch.reshape(-1), + ) + + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP) + scaler.step(optimizer) + scaler.update() + scheduler.step() + + if step % VAL_EVERY == 0: + ppl = compute_ppl(model, val_x, val_y) + el = time.time() - t0 + sps = (step + 1) / el if el > 0 else 0 + lr_now = scheduler.get_last_lr()[0] + print(f'step={step:6d} lr={lr_now:.5f} | loss={loss.item():.3f} | ' + f'PPL={ppl:.1f} | {sps:.1f} stp/s (B={BATCH_SIZE})') + if ppl < best_ppl: + best_ppl = ppl + torch.save( + {'model': model.state_dict(), 'step': ckpt_step + step, 'best_ppl': best_ppl}, + 'vsa_lm_v3b_best.pt', + ) + + step += 1 + 
+print(f'\nDone: {step} steps, best PPL={best_ppl:.1f}, time={time.time()-t0:.0f}s') diff --git a/utils/__init__.py b/utils/__init__.py index 75d0278..8ec1f57 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -5,6 +5,7 @@ """ from .checkpoint import load_checkpoint, save_checkpoint +from .grl_checkpoint import load_grl, save_grl from .data import ( ArrayDataset, BatchSampler, @@ -137,6 +138,8 @@ # Checkpoint "save_checkpoint", "load_checkpoint", + "save_grl", + "load_grl", # Device "get_device", "set_device", diff --git a/utils/grl_checkpoint.py b/utils/grl_checkpoint.py new file mode 100644 index 0000000..51facf5 --- /dev/null +++ b/utils/grl_checkpoint.py @@ -0,0 +1,252 @@ +""" +Grilly checkpoint format (.grl) — GRL v1. + +Binary layout: + - Magic ``GRLY`` (4 bytes) + - uint16 format version (1) + - uint16 flags (reserved, 0) + - uint32 reserved + - uint64 metadata_json_offset, metadata_json_length + - uint64 tensor_index_offset, tensor_index_length + - uint64 payload_offset, payload_length + - padding to 64-byte header + +Followed by: + - UTF-8 JSON metadata blob + - UTF-8 JSON tensor index (array of tensor descriptors) + - payload (concatenated tensor bytes, C-contiguous row-major) + +Tensor index entry:: + {"name": str, "dtype": "f32"|"f16"|"i64"|"i32"|"u8", "shape": [int,...], + "offset": int, "length": int} + +``offset`` / ``length`` are byte ranges relative to **start of payload section**. 
+""" + +from __future__ import annotations + +import json +import struct +from pathlib import Path +from typing import Any + +import numpy as np + +try: + import grilly_core as _grl_core + + _HAS_CPP_GRL = hasattr(_grl_core, "grl_write_file") +except ImportError: + _grl_core = None + _HAS_CPP_GRL = False + +MAGIC = b"GRLY" +FORMAT_VERSION = 1 +HEADER_SIZE = 64 +_FLAG_NONE = 0 + +_DTYPE_TO_STR = { + np.dtype("float32"): "f32", + np.dtype("float16"): "f16", + np.dtype("int64"): "i64", + np.dtype("int32"): "i32", + np.dtype("uint8"): "u8", +} +_STR_TO_DTYPE = {v: k for k, v in _DTYPE_TO_STR.items()} + + +def _flatten_state_dict(d: dict[str, Any], prefix: str = "") -> dict[str, np.ndarray]: + """Flatten nested dict to dotted keys -> ndarray.""" + out: dict[str, np.ndarray] = {} + for k, v in d.items(): + key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + out.update(_flatten_state_dict(v, key)) + elif isinstance(v, np.ndarray): + out[key] = np.ascontiguousarray(v) + else: + try: + out[key] = np.asarray(v) + except Exception: + continue + return out + + +def _unflatten_state_dict(flat: dict[str, np.ndarray]) -> dict[str, Any]: + """Rebuild nested dict from dotted keys.""" + root: dict[str, Any] = {} + for key, arr in flat.items(): + parts = key.split(".") + cur = root + for p in parts[:-1]: + cur = cur.setdefault(p, {}) + cur[parts[-1]] = arr + return root + + +def save_grl( + filepath: str | Path, + state_dict: dict[str, Any], + *, + metadata: dict[str, Any] | None = None, +) -> None: + """Write a GRL v1 checkpoint from a (possibly nested) state dict of numpy arrays.""" + path = Path(filepath) + path.parent.mkdir(parents=True, exist_ok=True) + + flat = _flatten_state_dict(state_dict) + meta = { + "schema": "grilly.checkpoint.v1", + "framework": "grilly", + **(metadata or {}), + } + meta_bytes = json.dumps(meta, separators=(",", ":"), sort_keys=True).encode("utf-8") + + payload_parts: list[bytes] = [] + index_entries: list[dict[str, Any]] = [] + offset = 
0 + for name in sorted(flat.keys()): + arr = flat[name] + if not isinstance(arr, np.ndarray): + arr = np.asarray(arr) + dt = arr.dtype + if dt not in _DTYPE_TO_STR: + arr = arr.astype(np.float32) + dt = arr.dtype + raw = arr.tobytes(order="C") + dtype_str = _DTYPE_TO_STR.get(dt, "f32") + index_entries.append( + { + "name": name, + "dtype": dtype_str, + "shape": list(arr.shape), + "offset": offset, + "length": len(raw), + } + ) + payload_parts.append(raw) + offset += len(raw) + + payload = b"".join(payload_parts) + index_bytes = json.dumps(index_entries, separators=(",", ":")).encode("utf-8") + + # Layout: header | meta | index | payload + meta_off = HEADER_SIZE + meta_len = len(meta_bytes) + idx_off = meta_off + meta_len + idx_len = len(index_bytes) + pay_off = idx_off + idx_len + pay_len = len(payload) + + if _HAS_CPP_GRL: + _grl_core.grl_write_file( + str(path), + meta_bytes.decode("utf-8"), + index_bytes.decode("utf-8"), + payload, + ) + return + + header = bytearray(HEADER_SIZE) + header[0:4] = MAGIC + struct.pack_into(" dict[str, Any]: + """ + Load GRL v1 checkpoint. Returns a dict with ``metadata``, ``model`` (nested state_dict), + and any extra keys from metadata. + + ``map_location`` is accepted for torch API compatibility; ``\"cpu\"`` keeps arrays + on host; ``\"vulkan\"`` / default leaves numpy arrays (caller uploads to GPU). 
+ """ + _ = map_location + path = Path(filepath) + + if _HAS_CPP_GRL: + meta_json, index_json, pay_bytes = _grl_core.grl_read_file(str(path)) + metadata = json.loads(meta_json) + index_entries = json.loads(index_json) + payload = memoryview(pay_bytes) + flat: dict[str, np.ndarray] = {} + for ent in index_entries: + name = ent["name"] + dtype_s = ent["dtype"] + shape = tuple(ent["shape"]) + off = int(ent["offset"]) + ln = int(ent["length"]) + dt = _STR_TO_DTYPE.get(dtype_s, np.float32) + buf = payload[off : off + ln].tobytes() + arr = np.frombuffer(buf, dtype=dt).reshape(shape) + flat[name] = np.array(arr, copy=True) + # Roundtrip semantics: return what the user saved, not a forced + # ``{'model': ..., 'metadata': ...}`` wrapper. ``torch.save`` / + # ``torch.load`` users expect ``ck == original_payload``. + out: dict[str, Any] = _unflatten_state_dict(flat) + # Add metadata only if it doesn't collide with a user key. + if "metadata" not in out: + out["metadata"] = metadata + # Promote common training scalars from metadata to top level when + # the user didn't include them in the original payload (back-compat + # with checkpoints saved by older grilly that only put step in meta). 
+ for k in ("step", "training_step", "best_ppl", "epoch"): + if k in metadata and k not in out: + out[k] = metadata[k] + return out + + data = path.read_bytes() + if len(data) < HEADER_SIZE or data[0:4] != MAGIC: + raise ValueError(f"Not a GRL file or corrupt magic: {path}") + + version = struct.unpack_from(" dict[str, Any]: + """Map load_grl output to a torch-like checkpoint dict (``model``, ``step``, ...).""" + d = dict(grl_dict) + if "model" in d and isinstance(d["model"], dict): + # torch often uses 'model' key for state_dict + pass + return d diff --git a/utils/initialization.py b/utils/initialization.py index 8c57471..a2ee675 100644 --- a/utils/initialization.py +++ b/utils/initialization.py @@ -116,3 +116,33 @@ def kaiming_normal_( std = gain / np.sqrt(fan) tensor[:] = np.random.randn(*tensor.shape).astype(tensor.dtype) * std return tensor + + +def _writable_array(tensor): + """Return a writable ndarray view for ``tensor``. + + The torch_api ``Tensor`` / autograd ``Variable`` / ``Parameter`` wrappers + don't all support ``__setitem__`` directly — but they expose the backing + ndarray via ``.data``. Detect those and return the underlying buffer so + in-place init works for both raw arrays and wrapped tensors. + """ + if isinstance(tensor, np.ndarray): + return tensor + if hasattr(tensor, "data") and isinstance(getattr(tensor, "data", None), np.ndarray): + return tensor.data + # Last-resort: hope it's array-like and writable. 
+ return np.asarray(tensor) + + +def normal_(tensor, mean: float = 0.0, std: float = 1.0): + """In-place normal distribution (PyTorch ``nn.init.normal_``).""" + arr = _writable_array(tensor) + arr[:] = np.random.normal(mean, std, arr.shape).astype(arr.dtype, copy=False) + return tensor + + +def uniform_(tensor, a: float = 0.0, b: float = 1.0): + """In-place uniform distribution (PyTorch ``nn.init.uniform_``).""" + arr = _writable_array(tensor) + arr[:] = np.random.uniform(a, b, arr.shape).astype(arr.dtype, copy=False) + return tensor diff --git a/uv.lock b/uv.lock index a8db432..302dcdc 100644 --- a/uv.lock +++ b/uv.lock @@ -534,26 +534,36 @@ accel = [ ] all = [ { name = "black" }, + { name = "huggingface-hub" }, { name = "isort" }, { name = "mkdocs-material" }, { name = "mypy" }, { name = "numba" }, + { name = "protobuf" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "sentencepiece" }, + { name = "tokenizers" }, + { name = "transformers" }, ] dev = [ { name = "black" }, + { name = "huggingface-hub" }, { name = "isort" }, { name = "mkdocs-material" }, { name = "mypy" }, + { name = "protobuf" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "sentencepiece" }, + { name = "tokenizers" }, + { name = "transformers" }, ] full = [ { name = "blake3" }, @@ -572,6 +582,10 @@ huggingface = [ onnx = [ { name = "onnx" }, ] +tokenizer = [ + { name = "huggingface-hub" }, + { name = "tokenizers" }, +] torch = [ { name = "torch" }, ] @@ -581,6 +595,8 @@ requires-dist = [ { name = "black", marker = "extra == 'dev'", specifier = ">=23.7.0" }, { name = "blake3", marker = "extra == 'full'", specifier = ">=1.0.8" }, { name = "grilly", extras = ["dev", "accel"], marker = "extra == 'all'" }, + { name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" }, + { name = "huggingface-hub", marker 
= "extra == 'tokenizer'", specifier = ">=0.20.0" }, { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" }, { name = "mkdocs-material", marker = "extra == 'dev'", specifier = ">=9.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.5.0" }, @@ -589,6 +605,7 @@ requires-dist = [ { name = "numpy" }, { name = "onnx", marker = "extra == 'full'", specifier = ">=1.15.0" }, { name = "onnx", marker = "extra == 'onnx'", specifier = ">=1.15.0" }, + { name = "protobuf", marker = "extra == 'dev'", specifier = ">=4.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, { name = "pytest-benchmark", marker = "extra == 'dev'", specifier = ">=5.2.3" }, @@ -596,14 +613,18 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "sentence-transformers", marker = "extra == 'full'", specifier = ">=5.2.0" }, { name = "sentence-transformers", marker = "extra == 'huggingface'", specifier = ">=5.2.0" }, + { name = "sentencepiece", marker = "extra == 'dev'", specifier = ">=0.2.0" }, { name = "spacy", marker = "extra == 'full'", specifier = ">=3.8.11" }, + { name = "tokenizers", marker = "extra == 'dev'", specifier = ">=0.15.0" }, + { name = "tokenizers", marker = "extra == 'tokenizer'", specifier = ">=0.15.0" }, { name = "torch", marker = "extra == 'full'", specifier = ">=2.10.0" }, { name = "torch", marker = "extra == 'torch'", specifier = ">=2.10.0" }, + { name = "transformers", marker = "extra == 'dev'", specifier = ">=4.57.6" }, { name = "transformers", marker = "extra == 'full'", specifier = ">=4.57.6" }, { name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.57.6" }, { name = "vulkan", marker = "extra == 'full'", specifier = ">=1.3.0" }, ] -provides-extras = ["full", "torch", "huggingface", "onnx", "dev", "accel", "all"] +provides-extras = ["full", "torch", "huggingface", "onnx", "dev", 
"accel", "tokenizer", "all"] [[package]] name = "hf-xet" @@ -1947,6 +1968,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/d0/3b2897ef6a0c0c801e9fecca26bcc77081648e38e8c772885ebdd8d7d252/sentence_transformers-5.2.0-py3-none-any.whl", hash = "sha256:aa57180f053687d29b08206766ae7db549be5074f61849def7b17bf0b8025ca2", size = 493748, upload-time = "2025-12-11T14:12:29.516Z" }, ] +[[package]] +name = "sentencepiece" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/be/32ce495aa1d0e0c323dcb1ba87096037358edee539cac5baf8755a6bd396/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:57cae326c8727de58c85977b175af132a7138d84c764635d7e71bbee7e774133", size = 1943152, upload-time = "2025-08-12T06:59:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/88/7e/ff23008899a58678e98c6ff592bf4d368eee5a71af96d0df6b38a039dd4f/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:56dd39a3c4d6493db3cdca7e8cc68c6b633f0d4195495cbadfcf5af8a22d05a6", size = 1325651, upload-time = "2025-08-12T06:59:41.536Z" }, + { url = "https://files.pythonhosted.org/packages/19/84/42eb3ce4796777a1b5d3699dfd4dca85113e68b637f194a6c8d786f16a04/sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d9381351182ff9888cc80e41c632e7e274b106f450de33d67a9e8f6043da6f76", size = 1253645, upload-time = "2025-08-12T06:59:42.903Z" }, + { url = "https://files.pythonhosted.org/packages/89/fa/d3d5ebcba3cb9e6d3775a096251860c41a6bc53a1b9461151df83fe93255/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167", size = 1316273, upload-time = "2025-08-12T06:59:44.476Z" }, + { url = "https://files.pythonhosted.org/packages/04/88/14f2f4a2b922d8b39be45bf63d79e6cd3a9b2f248b2fcb98a69b12af12f5/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b", size = 1387881, upload-time = "2025-08-12T06:59:46.09Z" }, + { url = "https://files.pythonhosted.org/packages/fd/b8/903e5ccb77b4ef140605d5d71b4f9e0ad95d456d6184688073ed11712809/sentencepiece-0.2.1-cp312-cp312-win32.whl", hash = "sha256:a483fd29a34c3e34c39ac5556b0a90942bec253d260235729e50976f5dba1068", size = 999540, upload-time = "2025-08-12T06:59:48.023Z" }, + { url = "https://files.pythonhosted.org/packages/2d/81/92df5673c067148c2545b1bfe49adfd775bcc3a169a047f5a0e6575ddaca/sentencepiece-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4cdc7c36234fda305e85c32949c5211faaf8dd886096c7cea289ddc12a2d02de", size = 1054671, upload-time = "2025-08-12T06:59:49.895Z" }, + { url = "https://files.pythonhosted.org/packages/fe/02/c5e3bc518655d714622bec87d83db9cdba1cd0619a4a04e2109751c4f47f/sentencepiece-0.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:daeb5e9e9fcad012324807856113708614d534f596d5008638eb9b40112cd9e4", size = 1033923, upload-time = "2025-08-12T06:59:51.952Z" }, + { url = "https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, 
upload-time = "2025-08-12T06:59:55.073Z" }, + { url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" }, + { url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" }, + { url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" }, + { url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" }, + { url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" }, + { url = 
"https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" }, + { url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" }, + { url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" }, + { url = "https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" }, + { url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" }, + { url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" }, + { url = 
"https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" }, + { url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/24/9c/89eb8b2052f720a612478baf11c8227dcf1dc28cd4ea4c0c19506b5af2a2/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5d0350b686c320068702116276cfb26c066dc7e65cfef173980b11bb4d606719", size = 1943147, upload-time = "2025-08-12T07:00:21.809Z" }, + { url = "https://files.pythonhosted.org/packages/82/0b/a1432bc87f97c2ace36386ca23e8bd3b91fb40581b5e6148d24b24186419/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c7f54a31cde6fa5cb030370566f68152a742f433f8d2be458463d06c208aef33", size = 1325624, upload-time = "2025-08-12T07:00:23.289Z" }, + { url = "https://files.pythonhosted.org/packages/ea/99/bbe054ebb5a5039457c590e0a4156ed073fb0fe9ce4f7523404dd5b37463/sentencepiece-0.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c83b85ab2d6576607f31df77ff86f28182be4a8de6d175d2c33ca609925f5da1", size = 1253670, upload-time = "2025-08-12T07:00:24.69Z" }, + { url = "https://files.pythonhosted.org/packages/19/ad/d5c7075f701bd97971d7c2ac2904f227566f51ef0838dfbdfdccb58cd212/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1855f57db07b51fb51ed6c9c452f570624d2b169b36f0f79ef71a6e6c618cd8b", size = 1316247, upload-time = "2025-08-12T07:00:26.435Z" }, + { url = 
"https://files.pythonhosted.org/packages/fb/03/35fbe5f3d9a7435eebd0b473e09584bd3cc354ce118b960445b060d33781/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01e6912125cb45d3792f530a4d38f8e21bf884d6b4d4ade1b2de5cf7a8d2a52b", size = 1387894, upload-time = "2025-08-12T07:00:28.339Z" }, + { url = "https://files.pythonhosted.org/packages/dc/aa/956ef729aafb6c8f9c443104c9636489093bb5c61d6b90fc27aa1a865574/sentencepiece-0.2.1-cp314-cp314-win32.whl", hash = "sha256:c415c9de1447e0a74ae3fdb2e52f967cb544113a3a5ce3a194df185cbc1f962f", size = 1096698, upload-time = "2025-08-12T07:00:29.764Z" }, + { url = "https://files.pythonhosted.org/packages/b8/cb/fe400d8836952cc535c81a0ce47dc6875160e5fedb71d2d9ff0e9894c2a6/sentencepiece-0.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:881b2e44b14fc19feade3cbed314be37de639fc415375cefaa5bc81a4be137fd", size = 1155115, upload-time = "2025-08-12T07:00:32.865Z" }, + { url = "https://files.pythonhosted.org/packages/32/89/047921cf70f36c7b6b6390876b2399b3633ab73b8d0cb857e5a964238941/sentencepiece-0.2.1-cp314-cp314-win_arm64.whl", hash = "sha256:2005242a16d2dc3ac5fe18aa7667549134d37854823df4c4db244752453b78a8", size = 1133890, upload-time = "2025-08-12T07:00:34.763Z" }, + { url = "https://files.pythonhosted.org/packages/a1/11/5b414b9fae6255b5fb1e22e2ed3dc3a72d3a694e5703910e640ac78346bb/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:a19adcec27c524cb7069a1c741060add95f942d1cbf7ad0d104dffa0a7d28a2b", size = 1946081, upload-time = "2025-08-12T07:00:36.97Z" }, + { url = "https://files.pythonhosted.org/packages/77/eb/7a5682bb25824db8545f8e5662e7f3e32d72a508fdce086029d89695106b/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:e37e4b4c4a11662b5db521def4e44d4d30ae69a1743241412a93ae40fdcab4bb", size = 1327406, upload-time = "2025-08-12T07:00:38.669Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/b0/811dae8fb9f2784e138785d481469788f2e0d0c109c5737372454415f55f/sentencepiece-0.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:477c81505db072b3ab627e7eab972ea1025331bd3a92bacbf798df2b75ea86ec", size = 1254846, upload-time = "2025-08-12T07:00:40.611Z" }, + { url = "https://files.pythonhosted.org/packages/ef/23/195b2e7ec85ebb6a547969f60b723c7aca5a75800ece6cc3f41da872d14e/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:010f025a544ef770bb395091d57cb94deb9652d8972e0d09f71d85d5a0816c8c", size = 1315721, upload-time = "2025-08-12T07:00:42.914Z" }, + { url = "https://files.pythonhosted.org/packages/7e/aa/553dbe4178b5f23eb28e59393dddd64186178b56b81d9b8d5c3ff1c28395/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:733e59ff1794d26db706cd41fc2d7ca5f6c64a820709cb801dc0ea31780d64ab", size = 1387458, upload-time = "2025-08-12T07:00:44.56Z" }, + { url = "https://files.pythonhosted.org/packages/66/7c/08ff0012507297a4dd74a5420fdc0eb9e3e80f4e88cab1538d7f28db303d/sentencepiece-0.2.1-cp314-cp314t-win32.whl", hash = "sha256:d3233770f78e637dc8b1fda2cd7c3b99ec77e7505041934188a4e7fe751de3b0", size = 1099765, upload-time = "2025-08-12T07:00:46.058Z" }, + { url = "https://files.pythonhosted.org/packages/91/d5/2a69e1ce15881beb9ddfc7e3f998322f5cedcd5e4d244cb74dade9441663/sentencepiece-0.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:5e4366c97b68218fd30ea72d70c525e6e78a6c0a88650f57ac4c43c63b234a9d", size = 1157807, upload-time = "2025-08-12T07:00:47.673Z" }, + { url = "https://files.pythonhosted.org/packages/f3/16/54f611fcfc2d1c46cbe3ec4169780b2cfa7cf63708ef2b71611136db7513/sentencepiece-0.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:105e36e75cbac1292642045458e8da677b2342dcd33df503e640f0b457cb6751", size = 1136264, upload-time = "2025-08-12T07:00:49.485Z" }, +] + [[package]] name = "setuptools" version = "80.10.1" From 
e1c0e688b5b7dd1364c6a2c5c6925830598f57d5 Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Tue, 7 Apr 2026 11:55:07 -0400 Subject: [PATCH 17/17] pre v1.0 Massive changes. Memory fixes, many optimization. Pre-release --- CMakeLists.txt | 2 + cpp/include/grilly/ops/linear.h | 55 ++-- cpp/include/grilly/ops/prefix_scan.h | 65 +++++ cpp/include/grilly/vulkan/vk_pipeline_cache.h | 5 + cpp/python/bindings_core.cpp | 1 + cpp/python/bindings_core.h | 1 + cpp/python/bindings_linear.cpp | 112 +++++--- cpp/python/bindings_prefix_scan.cpp | 118 ++++++++ cpp/src/ops/linear.cpp | 252 +++++++++++++----- cpp/src/ops/prefix_scan.cpp | 165 ++++++++++++ nn/prefix_scan.py | 182 +++++++++++++ rebuild.ps1 | 148 ++++++++++ shaders/fused-layernorm-linear.glsl | 1 + shaders/gemm-bias-add.glsl | 27 ++ shaders/gemm-coopmat-shared.glsl | 99 +++++++ shaders/lstm-cell-forward.glsl | 12 +- shaders/prefix-scan-causal-backward.glsl | 102 +++++++ shaders/prefix-scan-causal.glsl | 85 ++++++ shaders/spv/activation-gelu-backward.spv | Bin 2860 -> 3068 bytes shaders/spv/activation-gelu.spv | Bin 2464 -> 2212 bytes shaders/spv/addition-linear.spv | Bin 0 -> 3428 bytes shaders/spv/convd_col2im_noatomic.spv | Bin 0 -> 6992 bytes shaders/spv/dequant-4bit.spv | Bin 0 -> 5000 bytes shaders/spv/fused-layernorm-linear.spv | Bin 7252 -> 6128 bytes shaders/spv/gemm-bias-add.spv | Bin 0 -> 1596 bytes shaders/spv/gemm-coopmat-shared.spv | Bin 0 -> 5012 bytes shaders/spv/gqa-attention.spv | Bin 0 -> 9812 bytes shaders/spv/grid-cell.spv | Bin 0 -> 4236 bytes shaders/spv/hmm-baum-welch.spv | Bin 0 -> 4536 bytes shaders/spv/hmm-forward.spv | Bin 0 -> 4644 bytes shaders/spv/hopfield-surprise.spv | Bin 0 -> 5976 bytes shaders/spv/lstm-cell-forward.spv | Bin 0 -> 8140 bytes shaders/spv/prefix-scan-causal-backward.spv | Bin 0 -> 5868 bytes shaders/spv/prefix-scan-causal.spv | Bin 0 -> 3684 bytes shaders/spv/rms-norm-linear-fused.spv | Bin 0 -> 4820 bytes shaders/spv/sign-activation.spv | Bin 1520 -> 1444 bytes 
shaders/spv/stdp-learning.spv | Bin 7748 -> 7756 bytes shaders/spv/surprise-momentum.spv | Bin 0 -> 4268 bytes shaders/spv/surprise-recall-blend.spv | Bin 0 -> 3128 bytes shaders/spv/swiglu-fused.spv | Bin 0 -> 3932 bytes shaders/spv/synapsis-forward.spv | Bin 3368 -> 3372 bytes shaders/spv/synapsis-stdp-trace.spv | Bin 3304 -> 3304 bytes shaders/spv/synapsis-stdp-update.spv | Bin 3676 -> 3676 bytes shaders/spv/theta-gamma-encoding.spv | Bin 4908 -> 4916 bytes shaders/spv/time-cell.spv | Bin 4596 -> 4600 bytes shaders/spv/vsa-explore.spv | Bin 0 -> 5208 bytes shaders/spv/vsa-logic-apply.spv | Bin 2640 -> 2568 bytes shaders/spv/whitening-apply.spv | Bin 2768 -> 2772 bytes shaders/spv/whitening-batch-stats.spv | Bin 4712 -> 4720 bytes shaders/spv/whitening-transform.spv | Bin 3924 -> 3928 bytes shaders/vsa-explore.glsl | 14 +- 51 files changed, 1314 insertions(+), 132 deletions(-) create mode 100644 cpp/include/grilly/ops/prefix_scan.h create mode 100644 cpp/python/bindings_prefix_scan.cpp create mode 100644 cpp/src/ops/prefix_scan.cpp create mode 100644 nn/prefix_scan.py create mode 100644 rebuild.ps1 create mode 100644 shaders/gemm-bias-add.glsl create mode 100644 shaders/gemm-coopmat-shared.glsl create mode 100644 shaders/prefix-scan-causal-backward.glsl create mode 100644 shaders/prefix-scan-causal.glsl create mode 100644 shaders/spv/addition-linear.spv create mode 100644 shaders/spv/convd_col2im_noatomic.spv create mode 100644 shaders/spv/dequant-4bit.spv create mode 100644 shaders/spv/gemm-bias-add.spv create mode 100644 shaders/spv/gemm-coopmat-shared.spv create mode 100644 shaders/spv/gqa-attention.spv create mode 100644 shaders/spv/grid-cell.spv create mode 100644 shaders/spv/hmm-baum-welch.spv create mode 100644 shaders/spv/hmm-forward.spv create mode 100644 shaders/spv/hopfield-surprise.spv create mode 100644 shaders/spv/lstm-cell-forward.spv create mode 100644 shaders/spv/prefix-scan-causal-backward.spv create mode 100644 
shaders/spv/prefix-scan-causal.spv create mode 100644 shaders/spv/rms-norm-linear-fused.spv create mode 100644 shaders/spv/surprise-momentum.spv create mode 100644 shaders/spv/surprise-recall-blend.spv create mode 100644 shaders/spv/swiglu-fused.spv create mode 100644 shaders/spv/vsa-explore.spv diff --git a/CMakeLists.txt b/CMakeLists.txt index dc2a917..076d0ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,6 +220,7 @@ add_library(grilly_core_lib STATIC cpp/src/ops/moqe_train.cpp cpp/src/ops/moe_forward.cpp cpp/src/ops/vsa_lm_forward.cpp + cpp/src/ops/prefix_scan.cpp cpp/src/shader_fusion.cpp # ── Experimental ── cpp/src/experimental/paged_latent_pool.cpp @@ -310,6 +311,7 @@ pybind11_add_module(grilly_core cpp/python/bindings_fusion.cpp cpp/python/bindings_vsa_lm.cpp cpp/python/bindings_grl.cpp + cpp/python/bindings_prefix_scan.cpp ) target_link_libraries(grilly_core PRIVATE grilly_core_lib) diff --git a/cpp/include/grilly/ops/linear.h b/cpp/include/grilly/ops/linear.h index a9018c5..d1d9dc8 100644 --- a/cpp/include/grilly/ops/linear.h +++ b/cpp/include/grilly/ops/linear.h @@ -14,42 +14,54 @@ namespace grilly { namespace ops { /// GPU-accelerated linear projection: output = x @ W^T + bias. -/// Ports backend/fnn.py:1823-1976 to native C++. /// -/// Push constants layout (must match fnn-linear.glsl): +/// Dynamic dtype via ``elemSize``: +/// - ``elemSize == 4`` (fp32): dispatches ``fnn-linear.glsl`` (tiled GEMM, +/// universal path, works on any device). x / weights / bias are +/// interpreted as ``float``; output is always ``float``. +/// - ``elemSize == 2`` (fp16): dispatches ``gemm-coopmat-shared.glsl`` if +/// the device supports ``VK_KHR_cooperative_matrix`` AND the shape is +/// aligned (M % 16 == 0, K % 16 == 0, N % 64 == 0). x / weights / bias +/// are ``float16_t`` bytes in the staging buffer; output remains fp32 +/// (the coopmat accumulator runs in fp32 for numerical stability). 
+/// Falls back to an fp32 conversion + fnn-linear path if the device +/// lacks cooperative matrix support or the shape isn't aligned. +/// +/// Bias handling: on the coopmat path, bias is applied via a small second +/// dispatch (``gemm-bias-add.glsl``) because ``coopMatStore`` can't +/// interleave an elementwise add with the tile-cooperative store. +/// +/// Push constants layout (matches fnn-linear.glsl): /// uint batch_seq; // offset 0 /// uint input_dim; // offset 4 /// uint output_dim; // offset 8 /// uint has_bias; // offset 12 -/// -/// Buffers (binding order matches shader): -/// 0: input (batch_seq * input_dim floats) -/// 1: weights (output_dim * input_dim floats) -/// 2: bias (output_dim floats, or 1 dummy float) -/// 3: output (batch_seq * output_dim floats) +/// uint elem_size; // offset 16 — 4 for fp32, 2 for fp16 struct LinearParams { uint32_t batchSeq; uint32_t inputDim; uint32_t outputDim; uint32_t hasBias; + uint32_t elemSize; ///< 4 for fp32, 2 for fp16 }; /// Execute a linear (dense / fully-connected) layer on the GPU. /// -/// All Vulkan work — buffer upload, pipeline bind, descriptor set, -/// dispatch, download — happens inside C++ with zero Python crossings. +/// Pointer types are ``const void*`` / ``void*`` so the caller can pass +/// either fp32 or fp16 byte buffers transparently — the element size is +/// encoded in ``p.elemSize``. /// /// @param batch CommandBatch to record the dispatch into. /// @param pool BufferPool for acquiring/releasing GPU memory. -/// @param cache PipelineCache for the fnn-linear shader. +/// @param cache PipelineCache for the fnn-linear / gemm-coopmat shaders. /// @param x Input matrix, row-major (batchSeq × inputDim). /// @param weights Weight matrix, row-major (outputDim × inputDim). /// @param bias Optional bias vector (outputDim). nullptr = no bias. -/// @param output Output buffer, pre-allocated (batchSeq × outputDim). -/// @param p Dimension parameters. 
+/// @param output Output buffer, pre-allocated (batchSeq × outputDim × 4 bytes). +/// @param p Dimension + dtype parameters. void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* x, const float* weights, const float* bias, - float* output, const LinearParams& p); + const void* x, const void* weights, const void* bias, + void* output, const LinearParams& p); /// CPU reference implementation using Eigen for correctness verification. /// Returns a newly-allocated vector: output = x @ W^T + bias. @@ -66,13 +78,18 @@ struct LinearBackwardParams { }; /// GPU linear backward. Produces grad_input, grad_weight, grad_bias. -/// 3-pass dispatch with barriers. +/// 3-pass dispatch with barriers. Dtype-dynamic via ``p.elemSize``. /// 6 buffers: grad_output(0), input(1), weights(2), /// grad_input(3), grad_weight(4), grad_bias(5). +/// +/// NOTE: the current backward shader (``fnn-linear-backward.glsl``) is +/// fp32-only. Calling with ``p.elemSize == 2`` will throw until a +/// cooperative-matrix backward shader lands. The ``void*`` / dynamic +/// elemSize interface is in place so the upgrade is a local shader change. void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* gradOutput, const float* input, - const float* weights, - float* gradInput, float* gradWeight, float* gradBias, + const void* gradOutput, const void* input, + const void* weights, + void* gradInput, void* gradWeight, void* gradBias, const LinearParams& p); /// Dropout push constants — matches fnn-dropout.glsl. diff --git a/cpp/include/grilly/ops/prefix_scan.h b/cpp/include/grilly/ops/prefix_scan.h new file mode 100644 index 0000000..929ca8d --- /dev/null +++ b/cpp/include/grilly/ops/prefix_scan.h @@ -0,0 +1,65 @@ +#pragma once + +#include + +#include "grilly/buffer_pool.h" +#include "grilly/command_batch.h" +#include "grilly/pipeline_cache.h" + +namespace grilly { +namespace ops { + +/// Causal Linear-RNN prefix scan push constants. 
+/// Matches shaders/prefix-scan-causal.glsl + prefix-scan-causal-backward.glsl. +struct PrefixScanParams { + uint32_t seqLen; ///< must be <= subgroup size (32 Wave32 / 64 Wave64) + uint32_t hiddenDim; + uint32_t batchSize; +}; + +/// Causal linear-RNN forward: h_t = a_t * h_{t-1} + x_t. +/// +/// Implemented as two hardware subgroup inclusive scans (log(a) and the +/// rescaled x). The recurrence is strictly causal: h_t depends only on +/// x_{0..t} and a_{0..t}. +/// +/// Inputs / outputs are fp32 (B, S, D) row-major. +/// x: input sequence, shape (B, S, D) +/// a: decay gates in (0, 1], same shape +/// h: output hidden states, same shape (pre-allocated) +/// +/// Buffer bindings (match the shader): +/// 0: x (input) +/// 1: a (decay) +/// 2: h (output) +/// +/// Constraint: seqLen <= subgroup size. Longer sequences need a hierarchical +/// scan — not implemented yet. +void prefixScanCausal(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* x, const float* a, float* h, + const PrefixScanParams& p); + +/// Causal prefix scan backward: given dh, compute dx and da. +/// +/// Math: dx_t = sum_{s>=t} (prod_{k=t+1..s} a_k) * dh_s (anti-causal scan) +/// da_t = dx_t * h_{t-1} +/// +/// Needs the forward x and h tensors saved from the forward pass. +/// +/// Buffer bindings: +/// 0: dh (grad into output) +/// 1: a (decay) +/// 2: h (forward output, for da computation) +/// 3: x (forward input, for da computation) +/// 4: dx (grad w.r.t. input, output) +/// 5: da (grad w.r.t. 
decay, output) +void prefixScanCausalBackward(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* dh, const float* a, + const float* h, const float* x, + float* dx, float* da, + const PrefixScanParams& p); + +} // namespace ops +} // namespace grilly diff --git a/cpp/include/grilly/vulkan/vk_pipeline_cache.h b/cpp/include/grilly/vulkan/vk_pipeline_cache.h index 330e796..7a12572 100644 --- a/cpp/include/grilly/vulkan/vk_pipeline_cache.h +++ b/cpp/include/grilly/vulkan/vk_pipeline_cache.h @@ -46,6 +46,11 @@ class PipelineCache { return spirvCode_.count(name) > 0; } + /// Access the underlying device for capability queries + /// (e.g. ``hasCooperativeMatrix()``). + GrillyDevice& getDevice() { return device_; } + const GrillyDevice& getDevice() const { return device_; } + struct CacheStats { uint64_t hits = 0; uint64_t misses = 0; diff --git a/cpp/python/bindings_core.cpp b/cpp/python/bindings_core.cpp index 0e9f82c..0491197 100644 --- a/cpp/python/bindings_core.cpp +++ b/cpp/python/bindings_core.cpp @@ -440,4 +440,5 @@ PYBIND11_MODULE(grilly_core, m) { register_vsa_lm_ops(m); register_grl_ops(m); register_misc_ops(m); + register_prefix_scan_ops(m); } diff --git a/cpp/python/bindings_core.h b/cpp/python/bindings_core.h index 6861e11..3ba823d 100644 --- a/cpp/python/bindings_core.h +++ b/cpp/python/bindings_core.h @@ -191,3 +191,4 @@ void register_moe_ops(py::module_& m); void register_fusion_ops(py::module_& m); void register_vsa_lm_ops(py::module_& m); void register_grl_ops(py::module_& m); +void register_prefix_scan_ops(py::module_& m); diff --git a/cpp/python/bindings_linear.cpp b/cpp/python/bindings_linear.cpp index dd93bda..9b35265 100644 --- a/cpp/python/bindings_linear.cpp +++ b/cpp/python/bindings_linear.cpp @@ -8,16 +8,20 @@ void register_linear_ops(py::module_& m) { using namespace grilly::nn; - // ── GPU linear forward ─────────────────────────────────────────────── + // ── GPU linear forward (fp32 + fp16) 
─────────────────────────────── + // + // Accepts generic numpy arrays — inspects ``itemsize`` to detect fp32 + // (itemsize=4) vs fp16 (itemsize=2). Input and weights must share the + // same dtype; bias is always fp32 (stability + broadcast simplicity). + // Output is always fp32 regardless of input dtype because the + // cooperative-matrix accumulator runs in fp32 — callers needing fp16 + // output can ``astype(np.float16)`` on the returned array. m.def( "linear", - [](GrillyCoreContext& ctx, py::array_t x, - py::array_t weights, - std::optional> bias) -> Tensor { + [](GrillyCoreContext& ctx, py::array x, py::array weights, + std::optional bias) -> py::array { auto xBuf = x.request(); auto wBuf = weights.request(); - require_c_contiguous_float(xBuf); - require_c_contiguous_float(wBuf); if (xBuf.ndim < 1 || xBuf.ndim > 3) throw std::runtime_error( @@ -26,6 +30,15 @@ void register_linear_ops(py::module_& m) { throw std::runtime_error( "weights must be 2D (output_dim, input_dim)"); + if (xBuf.itemsize != 2 && xBuf.itemsize != 4) + throw std::runtime_error( + "grilly linear: x must be fp32 (itemsize=4) or fp16 (itemsize=2)"); + if (xBuf.itemsize != wBuf.itemsize) + throw std::runtime_error( + "grilly linear: x and weights must share the same dtype"); + + const uint32_t elemSize = static_cast(xBuf.itemsize); + auto [batchSeq, inputDim] = extractBatchAndLastDim(xBuf); uint32_t outputDim = static_cast(wBuf.shape[0]); @@ -35,20 +48,26 @@ void register_linear_ops(py::module_& m) { std::to_string(wBuf.shape[1]) + " vs " + std::to_string(inputDim)); - const float* biasPtr = nullptr; + // Bias is always fp32. If the caller passed fp16 bias, we + // require them to upcast it on the Python side — simpler and + // matches the C++ contract. 
+ const void* biasPtr = nullptr; uint32_t hasBias = 0; if (bias.has_value()) { auto bBuf = bias->request(); - require_c_contiguous_float(bBuf); + if (bBuf.itemsize != 4) + throw std::runtime_error( + "grilly linear: bias must be fp32 (cast via .astype(np.float32))"); if (bBuf.ndim != 1 || static_cast(bBuf.shape[0]) != outputDim) throw std::runtime_error( "bias must be 1D with size output_dim"); - biasPtr = static_cast(bBuf.ptr); + biasPtr = bBuf.ptr; hasBias = 1; } - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, hasBias}; + grilly::ops::LinearParams p{ + batchSeq, inputDim, outputDim, hasBias, elemSize}; std::vector outShape; for (int i = 0; i < xBuf.ndim - 1; ++i) @@ -56,16 +75,16 @@ void register_linear_ops(py::module_& m) { if (xBuf.ndim == 1) outShape.push_back(1); outShape.push_back(outputDim); - py::array_t result(outShape); + // Output is always fp32 — format code 'f'. The coopmat shader + // accumulates in fp32, and so does fnn-linear. + py::array result(py::dtype("f"), outShape); auto rBuf = result.request(); - // Extract raw pointers before GIL release - const float* xPtr = static_cast(xBuf.ptr); - const float* wPtr = static_cast(wBuf.ptr); - float* oPtr = static_cast(rBuf.ptr); + const void* xPtr = xBuf.ptr; + const void* wPtr = wBuf.ptr; + void* oPtr = rBuf.ptr; { - // Release GIL during GPU GEMM dispatch py::gil_scoped_release release; grilly::ops::linear( ctx.batch, ctx.pool, ctx.cache, @@ -75,11 +94,11 @@ void register_linear_ops(py::module_& m) { if (xBuf.ndim == 1) result = result.reshape({static_cast(outputDim)}); - return Tensor::from_numpy(result); + return result; }, py::arg("device"), py::arg("x"), py::arg("weights"), py::arg("bias") = py::none(), - "GPU linear projection: output = x @ W^T + bias"); + "GPU linear projection (fp32 or fp16 input; output always fp32)"); // ── GPU linear forward (Tensor I/O — no numpy copies) ────────────── m.def( @@ -120,7 +139,9 @@ void register_linear_ops(py::module_& m) { hasBias = 1; } - 
grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, hasBias}; + // ``linear_t`` is the native C++ Tensor path — always fp32. + grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, + hasBias, 4u}; // Allocate output tensor (CPU-valid) std::vector outShape; @@ -164,7 +185,9 @@ void register_linear_ops(py::module_& m) { hasBias = 1; } - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, hasBias}; + // ``linear_cpu`` is the Eigen reference path — always fp32. + grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, + hasBias, 4u}; std::vector out = grilly::ops::linearCPU( static_cast(xBuf.ptr), static_cast(wBuf.ptr), biasPtr, p); @@ -187,43 +210,56 @@ void register_linear_ops(py::module_& m) { py::arg("x"), py::arg("weights"), py::arg("bias") = py::none(), "CPU linear projection using Eigen (for verification)"); - // ── GPU linear backward ────────────────────────────────────────────── + // ── GPU linear backward (fp32 only for now; interface is fp16-ready) ── m.def( "linear_backward", [](GrillyCoreContext& ctx, - py::array_t grad_output, py::array_t input, - py::array_t weights) -> py::dict { + py::array grad_output, py::array input, + py::array weights) -> py::dict { auto gBuf = grad_output.request(); auto iBuf = input.request(); auto wBuf = weights.request(); + if (gBuf.itemsize != 2 && gBuf.itemsize != 4) + throw std::runtime_error( + "grilly linear_backward: grad_output must be fp32 or fp16"); + if (gBuf.itemsize != iBuf.itemsize || + gBuf.itemsize != wBuf.itemsize) + throw std::runtime_error( + "grilly linear_backward: grad_output, input, weights " + "must share the same dtype"); + + const uint32_t elemSize = static_cast(gBuf.itemsize); + const std::string npFormat = (elemSize == 2) ? 
"e" : "f"; + auto [batchSeq, outputDim] = extractBatchAndLastDim(gBuf); uint32_t inputDim = static_cast( iBuf.shape[iBuf.ndim - 1]); - grilly::ops::LinearParams p{batchSeq, inputDim, outputDim, 1}; + grilly::ops::LinearParams p{ + batchSeq, inputDim, outputDim, 1u, elemSize}; - py::array_t gradInput(iBuf.shape); - py::array_t gradWeight(wBuf.shape); - py::array_t gradBias( - {static_cast(outputDim)}); + py::array gradInput(py::dtype(npFormat), iBuf.shape); + py::array gradWeight(py::dtype(npFormat), wBuf.shape); + // Explicit vector to disambiguate from py::array(handle, bool). + std::vector biasShape = { + static_cast(outputDim)}; + py::array gradBias(py::dtype(npFormat), biasShape); grilly::ops::linearBackward( ctx.batch, ctx.pool, ctx.cache, - static_cast(gBuf.ptr), - static_cast(iBuf.ptr), - static_cast(wBuf.ptr), - static_cast(gradInput.request().ptr), - static_cast(gradWeight.request().ptr), - static_cast(gradBias.request().ptr), p); + gBuf.ptr, iBuf.ptr, wBuf.ptr, + gradInput.request().ptr, + gradWeight.request().ptr, + gradBias.request().ptr, p); py::dict result; - result["grad_input"] = Tensor::from_numpy(gradInput); - result["grad_weight"] = Tensor::from_numpy(gradWeight); - result["grad_bias"] = Tensor::from_numpy(gradBias); + result["grad_input"] = gradInput; + result["grad_weight"] = gradWeight; + result["grad_bias"] = gradBias; return result; }, py::arg("device"), py::arg("grad_output"), py::arg("input"), py::arg("weights"), - "GPU linear backward: grad_input, grad_weight, grad_bias"); + "GPU linear backward (supports fp32; fp16 interface ready for shader upgrade)"); } diff --git a/cpp/python/bindings_prefix_scan.cpp b/cpp/python/bindings_prefix_scan.cpp new file mode 100644 index 0000000..ad55f4c --- /dev/null +++ b/cpp/python/bindings_prefix_scan.cpp @@ -0,0 +1,118 @@ +/// bindings_prefix_scan.cpp — Causal Linear-RNN prefix scan op bindings. +/// +/// Exposes ``prefix_scan_causal`` and ``prefix_scan_causal_backward`` to +/// Python. 
Used by ``CausalSequenceMixer`` in the VSA-LM model to replace +/// the causally-leaky ``h.mean(dim=1)`` pooling with a strictly causal +/// subgroup-scanned Linear-RNN. + +#include "bindings_core.h" +#include "grilly/ops/prefix_scan.h" + +void register_prefix_scan_ops(py::module_& m) { + using namespace grilly::ops; + + m.def( + "prefix_scan_causal", + [](GrillyCoreContext& ctx, + py::array_t x, py::array_t a) -> py::array_t { + auto xBuf = x.request(); + auto aBuf = a.request(); + + if (xBuf.ndim != 3 || aBuf.ndim != 3) + throw std::runtime_error( + "prefix_scan_causal: x and a must be 3D (batch, seq, dim)"); + if (xBuf.shape != aBuf.shape) + throw std::runtime_error( + "prefix_scan_causal: x and a must have identical shape"); + + const uint32_t batchSize = static_cast(xBuf.shape[0]); + const uint32_t seqLen = static_cast(xBuf.shape[1]); + const uint32_t hiddenDim = static_cast(xBuf.shape[2]); + + if (seqLen > 32) + throw std::runtime_error( + "prefix_scan_causal: seq_len must be <= 32 (subgroup size). 
" + "Longer sequences need a hierarchical scan (TODO)."); + + PrefixScanParams p{seqLen, hiddenDim, batchSize}; + + py::array_t result(xBuf.shape); + auto rBuf = result.request(); + + { + py::gil_scoped_release release; + prefixScanCausal( + ctx.batch, ctx.pool, ctx.cache, + static_cast(xBuf.ptr), + static_cast(aBuf.ptr), + static_cast(rBuf.ptr), p); + } + + return result; + }, + py::arg("device"), py::arg("x"), py::arg("a"), + "Causal Linear-RNN forward: h_t = a_t * h_{t-1} + x_t via subgroup scan"); + + m.def( + "prefix_scan_causal_backward", + [](GrillyCoreContext& ctx, + py::array_t grad_h, py::array_t a, + py::array_t h, py::array_t x) -> py::dict { + auto dhBuf = grad_h.request(); + auto aBuf = a.request(); + auto hBuf = h.request(); + auto xBuf = x.request(); + + if (dhBuf.ndim != 3) + throw std::runtime_error( + "prefix_scan_causal_backward: grad_h must be 3D"); + + const uint32_t batchSize = static_cast(dhBuf.shape[0]); + const uint32_t seqLen = static_cast(dhBuf.shape[1]); + const uint32_t hiddenDim = static_cast(dhBuf.shape[2]); + + if (seqLen > 32) + throw std::runtime_error( + "prefix_scan_causal_backward: seq_len must be <= 32"); + + PrefixScanParams p{seqLen, hiddenDim, batchSize}; + + py::array_t gradX(dhBuf.shape); + py::array_t gradA(dhBuf.shape); + + // CRITICAL: get the raw pointers BEFORE releasing the GIL. + // ``py::array::request()`` touches Python reference counts and + // needs the GIL held — calling it inside a + // ``gil_scoped_release`` block deadlocks on internal locks. + // The forward binding gets this right; the first version of + // this backward binding didn't and hung inside the dispatcher + // with zero trace output from the C++ side (since the C++ call + // was never reached — the lambda deadlocked on request()). 
+ const void* dhPtr = dhBuf.ptr; + const void* aPtr = aBuf.ptr; + const void* hPtr = hBuf.ptr; + const void* xPtr = xBuf.ptr; + void* gradXPtr = gradX.request().ptr; + void* gradAPtr = gradA.request().ptr; + + { + py::gil_scoped_release release; + prefixScanCausalBackward( + ctx.batch, ctx.pool, ctx.cache, + static_cast(dhPtr), + static_cast(aPtr), + static_cast(hPtr), + static_cast(xPtr), + static_cast(gradXPtr), + static_cast(gradAPtr), p); + } + + py::dict result; + result["grad_x"] = gradX; + result["grad_a"] = gradA; + return result; + }, + py::arg("device"), py::arg("grad_h"), py::arg("a"), + py::arg("h"), py::arg("x"), + "Causal prefix scan backward: returns grad_x and grad_a"); +} diff --git a/cpp/src/ops/linear.cpp b/cpp/src/ops/linear.cpp index 3b0bd64..5dfd7f4 100644 --- a/cpp/src/ops/linear.cpp +++ b/cpp/src/ops/linear.cpp @@ -26,15 +26,50 @@ namespace ops { // so the dispatch overhead is unchanged from the old fast-path. void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* x, const float* weights, const float* bias, - float* output, const LinearParams& p) { - // ── Buffer sizes ── - const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * sizeof(float); - const size_t weightBytes = size_t(p.outputDim) * p.inputDim * sizeof(float); + const void* x, const void* weights, const void* bias, + void* output, const LinearParams& p) { + // ── Byte sizes (dynamic — fp32 or fp16 determined by p.elemSize) ── + const uint32_t inElem = p.elemSize; // 2 for fp16, 4 for fp32 + const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * inElem; + const size_t weightBytes = size_t(p.outputDim) * p.inputDim * inElem; + // Bias is ALWAYS fp32 regardless of input dtype. The fp32 bias matches + // both fnn-linear's 3rd binding and gemm-bias-add's accumulator, and + // bias is small enough (outputDim floats) that the bandwidth cost of + // fp32 vs fp16 is negligible. const size_t biasBytes = p.hasBias ? 
size_t(p.outputDim) * sizeof(float) : sizeof(float); // dummy + // The output is ALWAYS fp32 regardless of input dtype — coopmat + // accumulator runs in fp32 for numerical stability, and fnn-linear + // also writes fp32. The Python binding converts back to fp16 if + // requested by the caller's dtype. const size_t outputBytes = size_t(p.batchSeq) * p.outputDim * sizeof(float); + // ── Shader selection ── + // Coopmat requirements: + // - fp16 input (elemSize == 2) + // - device exposes VK_KHR_cooperative_matrix + // - the compiled SPIR-V is loaded in the pipeline cache + // - shape aligned to the shader's tile (M%16, K%16, N%64) + const bool shapeAligned = + (p.batchSeq % 16u == 0u) && + (p.inputDim % 16u == 0u) && + (p.outputDim % 64u == 0u); + const bool useCoopMat = + inElem == 2u && + cache.getDevice().hasCooperativeMatrix() && + cache.hasShader("gemm-coopmat-shared") && + shapeAligned; + + // fp16 input without a coopmat path is not supported in this function — + // the fallback fnn-linear shader is fp32-only. Callers must either use + // fp32 input or run on a device that supports cooperative matrix. + if (inElem == 2u && !useCoopMat) { + throw std::runtime_error( + "linear(): fp16 input requested but cooperative matrix path is " + "unavailable (missing device support, shader, or shape " + "alignment — M%16, K%16, N%64 required)."); + } + // ── Acquire DEVICE_LOCAL compute buffers (cached VRAM, fast GPU access) ── GrillyBuffer bufInputDL = pool.acquireDeviceLocal(inputBytes); GrillyBuffer bufWeightsDL = pool.acquireDeviceLocal(weightBytes); @@ -53,51 +88,119 @@ void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, // HOST_CACHED via acquireReadback gives ~7 GB/s for the same memcpy. 
GrillyBuffer bufOutputStage = pool.acquireReadback(outputBytes); - // ── memcpy CPU → staging (no GPU sync needed, persistent mapping) ── - pool.upload(bufInputStage, x, inputBytes); - pool.upload(bufWeightsStage, weights, weightBytes); + // ── memcpy CPU → staging (raw bytes, dtype-agnostic for x/weights) ── + pool.upload(bufInputStage, + reinterpret_cast(x), inputBytes); + pool.upload(bufWeightsStage, + reinterpret_cast(weights), weightBytes); if (p.hasBias && bias) { - pool.upload(bufBiasStage, bias, p.outputDim * sizeof(float)); + // Bias is always fp32 (see biasBytes computation above). + pool.upload(bufBiasStage, + reinterpret_cast(bias), + size_t(p.outputDim) * sizeof(float)); } - // ── Get or create pipeline (4 buffers, 16 bytes push constants) ── - PipelineEntry pipe = cache.getOrCreate("fnn-linear", 4, 16); - - // ── Allocate descriptor set bound to DEVICE_LOCAL buffers ── - // The descriptor cache keys on (shader_name, [(buffer.handle, range)]), - // so as long as the pool returns stable handles for repeated bucket - // requests (LIFO), this hits across calls. - std::vector bufferInfos(4); - bufferInfos[0] = {bufInputDL.handle, 0, inputBytes}; - bufferInfos[1] = {bufWeightsDL.handle, 0, weightBytes}; - bufferInfos[2] = {bufBiasDL.handle, 0, biasBytes}; - bufferInfos[3] = {bufOutputDL.handle, 0, outputBytes}; - - VkDescriptorSet descSet = cache.allocDescriptorSet("fnn-linear", bufferInfos); - - LinearParams pushData = p; - - uint32_t gx = (p.outputDim + 15) / 16; - uint32_t gy = (p.batchSeq + 15) / 16; + // ── Get or create pipeline ── + const std::string shaderName = useCoopMat ? "gemm-coopmat-shared" + : "fnn-linear"; + // gemm-coopmat-shared has 3 bindings (A, B, C); push constants = 12 bytes. + // fnn-linear has 4 bindings (input, weights, bias, output); push 16 bytes. + const uint32_t numBindings = useCoopMat ? 3u : 4u; + const uint32_t pushSize = useCoopMat ? 
12u : 16u; + PipelineEntry pipe = cache.getOrCreate(shaderName, numBindings, pushSize); + + // ── Allocate descriptor set ── + std::vector bufferInfos; + if (useCoopMat) { + bufferInfos = { + {bufInputDL.handle, 0, inputBytes}, + {bufWeightsDL.handle, 0, weightBytes}, + {bufOutputDL.handle, 0, outputBytes}, + }; + } else { + bufferInfos = { + {bufInputDL.handle, 0, inputBytes}, + {bufWeightsDL.handle, 0, weightBytes}, + {bufBiasDL.handle, 0, biasBytes}, + {bufOutputDL.handle, 0, outputBytes}, + }; + } + VkDescriptorSet descSet = cache.allocDescriptorSet(shaderName, bufferInfos); + + // Dispatch grid depends on the shader's output tile. + uint32_t gx, gy; + if (useCoopMat) { + // gemm-coopmat-shared writes a 16×64 (M×N) tile per workgroup. + gx = (p.outputDim + 63u) / 64u; + gy = (p.batchSeq + 15u) / 16u; + } else { + // fnn-linear writes a 16×16 tile per workgroup. + gx = (p.outputDim + 15u) / 16u; + gy = (p.batchSeq + 15u) / 16u; + } // ── Single command buffer: stage-in → barrier → compute → barrier → stage-out ── batch.begin(); - // Stage-in: DMA copy host-visible staging → DEVICE_LOCAL VRAM + // Stage-in: DMA copy host-visible staging → DEVICE_LOCAL VRAM. + // Bias goes to DL up front for both paths — fnn-linear reads it via + // binding 2, and the coopmat bias-add post-pass reads it via binding 1. 
batch.copyBuffer(bufInputStage, bufInputDL, inputBytes); batch.copyBuffer(bufWeightsStage, bufWeightsDL, weightBytes); if (p.hasBias && bias) { - batch.copyBuffer(bufBiasStage, bufBiasDL, p.outputDim * sizeof(float)); + batch.copyBuffer(bufBiasStage, bufBiasDL, + size_t(p.outputDim) * sizeof(float)); } - // Barrier: TRANSFER_WRITE → SHADER_READ batch.transferComputeBarrier(); - // Compute on DEVICE_LOCAL buffers (full ~432 GB/s VRAM bandwidth) - batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, - &pushData, sizeof(pushData)); + if (useCoopMat) { + // Coopmat push constants: {M, K, N} (12 bytes) + struct CoopPush { + uint32_t M; + uint32_t K; + uint32_t N; + } coopPush = {p.batchSeq, p.inputDim, p.outputDim}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &coopPush, sizeof(coopPush)); + } else { + // fnn-linear push constants: {batch, in, out, hasBias} (16 bytes) + struct FnnPush { + uint32_t batchSeq; + uint32_t inputDim; + uint32_t outputDim; + uint32_t hasBias; + } fnnPush = {p.batchSeq, p.inputDim, p.outputDim, p.hasBias}; + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &fnnPush, sizeof(fnnPush)); + } + + // ── Bias post-pass (coopmat only; fnn-linear applies bias inline) ── + // Bias was already copied to bufBiasDL during the stage-in phase above, + // so the post-pass just needs a GEMM-write → bias-read barrier and a + // dispatch of the gemm-bias-add kernel. 
+ if (useCoopMat && p.hasBias && bias && + cache.hasShader("gemm-bias-add")) { + batch.barrier(); // SHADER_WRITE (GEMM) → SHADER_READ (bias-add) + + // gemm-bias-add: 2 bindings (C, bias), 8 bytes push {totalElements, N} + PipelineEntry biasPipe = + cache.getOrCreate("gemm-bias-add", 2, 2 * sizeof(uint32_t)); + std::vector biasInfos = { + {bufOutputDL.handle, 0, outputBytes}, + {bufBiasDL.handle, 0, size_t(p.outputDim) * sizeof(float)}, + }; + VkDescriptorSet biasSet = + cache.allocDescriptorSet("gemm-bias-add", biasInfos); + struct BiasPush { + uint32_t totalElements; + uint32_t N; + } biasPush = {p.batchSeq * p.outputDim, p.outputDim}; + uint32_t biasGx = (biasPush.totalElements + 255u) / 256u; + batch.dispatch(biasPipe.pipeline, biasPipe.layout, biasSet, + biasGx, 1, 1, &biasPush, sizeof(biasPush)); + } - // Barrier: SHADER_WRITE → TRANSFER_READ batch.transferComputeBarrier(); // Stage-out: DMA copy DEVICE_LOCAL → host-visible HOST_CACHED staging @@ -107,7 +210,8 @@ void linear(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.waitForCompletion(); // ── memcpy staging → CPU output (HOST_CACHED, ~7 GB/s) ── - pool.download(bufOutputStage, output, outputBytes); + // Output is always fp32, regardless of input dtype. + pool.download(bufOutputStage, reinterpret_cast(output), outputBytes); // ── Release buffers back to their respective pools ── pool.release(bufInputDL); @@ -165,16 +269,27 @@ std::vector linearCPU(const float* x, const float* weights, // Workgroups: 2D at (16,16) for passes 0 and 1, 1D for pass 2. 
void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, - const float* gradOutput, const float* input, - const float* weights, - float* gradInput, float* gradWeight, float* gradBias, + const void* gradOutput, const void* input, + const void* weights, + void* gradInput, void* gradWeight, void* gradBias, const LinearParams& p) { - const size_t gradOutBytes = size_t(p.batchSeq) * p.outputDim * sizeof(float); - const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * sizeof(float); - const size_t weightBytes = size_t(p.outputDim) * p.inputDim * sizeof(float); + // The fnn-linear-backward shader is fp32-only; reject fp16 input until a + // coopmat backward shader lands. The void* interface is in place so the + // switchover is local. + if (p.elemSize != 4u) { + throw std::runtime_error( + "linearBackward(): currently requires fp32 (elemSize=4). fp16 " + "backward needs a cooperative matrix backward shader — TODO."); + } + + // Dynamic byte calculation. With elemSize==4 today these match the old + // sizeof(float) computations, so existing callers see no behavior change. + const size_t gradOutBytes = size_t(p.batchSeq) * p.outputDim * p.elemSize; + const size_t inputBytes = size_t(p.batchSeq) * p.inputDim * p.elemSize; + const size_t weightBytes = size_t(p.outputDim) * p.inputDim * p.elemSize; const size_t gradInBytes = inputBytes; const size_t gradWBytes = weightBytes; - const size_t gradBiasBytes = size_t(p.outputDim) * sizeof(float); + const size_t gradBiasBytes = size_t(p.outputDim) * p.elemSize; // Staging pattern: 3 stage-in (gradOut, input, weights), // 3 stage-out (gradIn, gradW, gradBias). All compute on DEVICE_LOCAL. 
@@ -192,27 +307,27 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, GrillyBuffer bufGradWStage = pool.acquireReadback(gradWBytes); GrillyBuffer bufGradBiasStage = pool.acquireReadback(gradBiasBytes); - pool.upload(bufGradOutStage, gradOutput, gradOutBytes); - pool.upload(bufInputStage, input, inputBytes); - pool.upload(bufWeightsStage, weights, weightBytes); - - // The grad output buffers must start at zero — pass 1 (grad_weight) and - // pass 2 (grad_bias) accumulate via atomic adds in the shader. We zero - // them on the GPU side via vkCmdFillBuffer rather than uploading zeros - // through staging (which was the old code path). - // (Workaround: upload zeros to a small temporary stage and copy. The - // simpler path: keep the upload-zeros-via-stage approach since we need - // to reset every call.) - std::vector zerosIn(p.batchSeq * p.inputDim, 0.0f); - std::vector zerosW(p.outputDim * p.inputDim, 0.0f); - std::vector zerosB(p.outputDim, 0.0f); - // Reuse the readback stage buffers as upload-zeros source: they're - // host-visible (HOST_CACHED), CPU-write is fine even though it's not - // optimal for sequential write — total bytes is small relative to GPU - // compute. Upload then DMA copy in the command buffer. - pool.upload(bufGradInStage, zerosIn.data(), gradInBytes); - pool.upload(bufGradWStage, zerosW.data(), gradWBytes); - pool.upload(bufGradBiasStage, zerosB.data(), gradBiasBytes); + pool.upload(bufGradOutStage, + reinterpret_cast(gradOutput), gradOutBytes); + pool.upload(bufInputStage, + reinterpret_cast(input), inputBytes); + pool.upload(bufWeightsStage, + reinterpret_cast(weights), weightBytes); + + // The grad buffers must start at zero — pass 1 (grad_weight) and + // pass 2 (grad_bias) accumulate via atomic adds in the shader. Use + // raw byte vectors so zeroing works identically for fp32 and fp16 + // (whenever the fp16 backward shader lands). 
Reuse the readback stage + // buffers as upload-zeros source — HOST_CACHED, CPU-write is fine. + std::vector zerosIn(gradInBytes, 0); + std::vector zerosW(gradWBytes, 0); + std::vector zerosB(gradBiasBytes, 0); + pool.upload(bufGradInStage, + reinterpret_cast(zerosIn.data()), gradInBytes); + pool.upload(bufGradWStage, + reinterpret_cast(zerosW.data()), gradWBytes); + pool.upload(bufGradBiasStage, + reinterpret_cast(zerosB.data()), gradBiasBytes); LinearBackwardParams bwdParams{p.batchSeq, p.inputDim, p.outputDim, 0}; @@ -274,9 +389,12 @@ void linearBackward(CommandBatch& batch, BufferPool& pool, PipelineCache& cache, batch.submitDeferred(); batch.waitForCompletion(); - pool.download(bufGradInStage, gradInput, gradInBytes); - pool.download(bufGradWStage, gradWeight, gradWBytes); - pool.download(bufGradBiasStage, gradBias, gradBiasBytes); + pool.download(bufGradInStage, + reinterpret_cast(gradInput), gradInBytes); + pool.download(bufGradWStage, + reinterpret_cast(gradWeight), gradWBytes); + pool.download(bufGradBiasStage, + reinterpret_cast(gradBias), gradBiasBytes); pool.release(bufGradOutDL); pool.release(bufInputDL); diff --git a/cpp/src/ops/prefix_scan.cpp b/cpp/src/ops/prefix_scan.cpp new file mode 100644 index 0000000..318a42b --- /dev/null +++ b/cpp/src/ops/prefix_scan.cpp @@ -0,0 +1,165 @@ +#include "grilly/ops/prefix_scan.h" + +#include +#include +#include + +namespace grilly { +namespace ops { + +// Uses the staging pattern (DEVICE_LOCAL compute + WC stage-in + HOST_CACHED +// stage-out) — same as linear.cpp. See linear.cpp for the rationale. 
+ +void prefixScanCausal(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* x, const float* a, float* h, + const PrefixScanParams& p) { + const size_t elemBytes = size_t(p.batchSize) * p.seqLen * p.hiddenDim * + sizeof(float); + + // DEVICE_LOCAL compute buffers + GrillyBuffer bufXDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufADL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufHDL = pool.acquireDeviceLocal(elemBytes); + + // Host staging + GrillyBuffer bufXStage = pool.acquire(elemBytes); + GrillyBuffer bufAStage = pool.acquire(elemBytes); + GrillyBuffer bufHStage = pool.acquireReadback(elemBytes); + + pool.upload(bufXStage, x, elemBytes); + pool.upload(bufAStage, a, elemBytes); + + PipelineEntry pipe = cache.getOrCreate( + "prefix-scan-causal", 3, 2 * sizeof(uint32_t)); + + std::vector bufInfos = { + {bufXDL.handle, 0, elemBytes}, + {bufADL.handle, 0, elemBytes}, + {bufHDL.handle, 0, elemBytes}, + }; + VkDescriptorSet descSet = + cache.allocDescriptorSet("prefix-scan-causal", bufInfos); + + struct Push { + uint32_t seqLen; + uint32_t hiddenDim; + } push = {p.seqLen, p.hiddenDim}; + + // One workgroup per (hidden_dim, batch) pair. Local size = 32 threads + // (one per time step) — shader enforces seqLen <= 32. 
+ uint32_t gx = p.hiddenDim; + uint32_t gy = p.batchSize; + + batch.begin(); + + batch.copyBuffer(bufXStage, bufXDL, elemBytes); + batch.copyBuffer(bufAStage, bufADL, elemBytes); + batch.transferComputeBarrier(); + + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &push, sizeof(push)); + + batch.transferComputeBarrier(); + batch.copyBuffer(bufHDL, bufHStage, elemBytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufHStage, h, elemBytes); + + pool.release(bufXDL); + pool.release(bufADL); + pool.release(bufHDL); + pool.release(bufXStage); + pool.release(bufAStage); + pool.release(bufHStage); +} + +void prefixScanCausalBackward(CommandBatch& batch, BufferPool& pool, + PipelineCache& cache, + const float* dh, const float* a, + const float* h, const float* x, + float* dx, float* da, + const PrefixScanParams& p) { + const size_t elemBytes = size_t(p.batchSize) * p.seqLen * p.hiddenDim * + sizeof(float); + + // DEVICE_LOCAL compute buffers (4 inputs, 2 outputs) + GrillyBuffer bufDhDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufADL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufHDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufXDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufDxDL = pool.acquireDeviceLocal(elemBytes); + GrillyBuffer bufDaDL = pool.acquireDeviceLocal(elemBytes); + + // Staging + GrillyBuffer bufDhStage = pool.acquire(elemBytes); + GrillyBuffer bufAStage = pool.acquire(elemBytes); + GrillyBuffer bufHStage = pool.acquire(elemBytes); + GrillyBuffer bufXStage = pool.acquire(elemBytes); + GrillyBuffer bufDxStage = pool.acquireReadback(elemBytes); + GrillyBuffer bufDaStage = pool.acquireReadback(elemBytes); + + pool.upload(bufDhStage, dh, elemBytes); + pool.upload(bufAStage, a, elemBytes); + pool.upload(bufHStage, h, elemBytes); + pool.upload(bufXStage, x, elemBytes); + + PipelineEntry pipe = cache.getOrCreate( + "prefix-scan-causal-backward", 6, 2 * sizeof(uint32_t)); + + 
std::vector bufInfos = { + {bufDhDL.handle, 0, elemBytes}, + {bufADL.handle, 0, elemBytes}, + {bufHDL.handle, 0, elemBytes}, + {bufXDL.handle, 0, elemBytes}, + {bufDxDL.handle, 0, elemBytes}, + {bufDaDL.handle, 0, elemBytes}, + }; + VkDescriptorSet descSet = cache.allocDescriptorSet( + "prefix-scan-causal-backward", bufInfos); + + struct Push { + uint32_t seqLen; + uint32_t hiddenDim; + } push = {p.seqLen, p.hiddenDim}; + + uint32_t gx = p.hiddenDim; + uint32_t gy = p.batchSize; + + batch.begin(); + + batch.copyBuffer(bufDhStage, bufDhDL, elemBytes); + batch.copyBuffer(bufAStage, bufADL, elemBytes); + batch.copyBuffer(bufHStage, bufHDL, elemBytes); + batch.copyBuffer(bufXStage, bufXDL, elemBytes); + batch.transferComputeBarrier(); + + batch.dispatch(pipe.pipeline, pipe.layout, descSet, gx, gy, 1, + &push, sizeof(push)); + + batch.transferComputeBarrier(); + batch.copyBuffer(bufDxDL, bufDxStage, elemBytes); + batch.copyBuffer(bufDaDL, bufDaStage, elemBytes); + batch.submitDeferred(); + batch.waitForCompletion(); + + pool.download(bufDxStage, dx, elemBytes); + pool.download(bufDaStage, da, elemBytes); + + pool.release(bufDhDL); + pool.release(bufADL); + pool.release(bufHDL); + pool.release(bufXDL); + pool.release(bufDxDL); + pool.release(bufDaDL); + pool.release(bufDhStage); + pool.release(bufAStage); + pool.release(bufHStage); + pool.release(bufXStage); + pool.release(bufDxStage); + pool.release(bufDaStage); +} + +} // namespace ops +} // namespace grilly diff --git a/nn/prefix_scan.py b/nn/prefix_scan.py new file mode 100644 index 0000000..cbf6f1f --- /dev/null +++ b/nn/prefix_scan.py @@ -0,0 +1,182 @@ +"""Causal Linear-RNN prefix scan — autograd-wrapped Python frontend. + +Wraps the C++ / Vulkan ``grilly_core.prefix_scan_causal`` and +``prefix_scan_causal_backward`` kernels into grilly's autograd system so +``loss.backward()`` flows gradients through the recurrence. 
+ +Math: + Forward: h_t = a_t * h_{t-1} + x_t (h_0 = 0) + Backward: dx_t = dh_t + a_{t+1} * dx_{t+1} (anti-causal scan) + da_t = dx_t * h_{t-1} + +The shader runs one subgroup per (batch, hidden_dim) pair, one thread per +time step, and uses ``subgroupInclusiveAdd`` for O(log S) parallel depth. + +Constraint: ``seq_len <= 32``. Longer sequences need a hierarchical scan +(chunk the sequence, carry state between chunks) — not implemented yet. + +Example:: + + from grilly.nn.prefix_scan import prefix_scan_causal + h = prefix_scan_causal(x, a) # x, a: (B, S, D) Variables + loss = h.mean() + loss.backward() # grads flow to both x and a +""" + +from __future__ import annotations + +from typing import Tuple + +import numpy as np + +from grilly.nn.autograd import GradFn, Variable, _ensure_variable, _grad_enabled + + +def _get_bridge_device(): + """Return the grilly bridge device with shaders loaded.""" + from grilly.backend import _bridge + + dev = _bridge._get_device() + if dev is None: + raise RuntimeError( + "grilly bridge device not initialized — " + "import grilly (or grilly.torch_api) first so shaders load" + ) + return dev + + +def prefix_scan_causal(x, a) -> Variable: + """Causal Linear-RNN: ``h_t = a_t * h_{t-1} + x_t``. + + Args: + x: Input sequence, shape ``(B, S, D)``. Any ``Variable`` / ``Tensor`` + / ndarray works. + a: Decay gates in ``(0, 1]``, same shape as ``x``. + + Returns: + ``Variable`` of shape ``(B, S, D)`` with the causal hidden states. + If autograd is enabled and either input requires grad, the result + is wired into the graph via a ``GradFn`` that calls the C++ + ``prefix_scan_causal_backward`` kernel on ``loss.backward()``. 
+ """ + import grilly_core as gc + + x_var = _ensure_variable(x) + a_var = _ensure_variable(a) + + x_data = np.asarray(x_var.data, dtype=np.float32) + a_data = np.asarray(a_var.data, dtype=np.float32) + + if x_data.shape != a_data.shape: + raise ValueError( + f"prefix_scan_causal: x shape {x_data.shape} != a shape {a_data.shape}" + ) + if x_data.ndim != 3: + raise ValueError( + f"prefix_scan_causal: inputs must be 3D (B, S, D), got {x_data.ndim}D" + ) + + S = x_data.shape[1] + if S > 32: + raise ValueError( + f"prefix_scan_causal: seq_len {S} > 32. The current shader runs " + f"one subgroup per (batch, dim) pair with one thread per time " + f"step — hierarchical multi-subgroup scan is a TODO. Either " + f"chunk the sequence on the Python side or truncate to 32 for " + f"the correctness run." + ) + + dev = _get_bridge_device() + h_data = np.asarray(gc.prefix_scan_causal(dev, x_data, a_data), dtype=np.float32) + + # ── Autograd wiring ── + requires_grad = ( + _grad_enabled + and (x_var.requires_grad or a_var.requires_grad) + ) + if not requires_grad: + return Variable(h_data, requires_grad=False) + + # Capture tensors needed by the backward closure. We save the forward + # x / a / h as immutable ndarrays so later in-place mutations on the + # caller's buffers don't corrupt the backward pass. + saved_x = x_data.copy() + saved_a = a_data.copy() + saved_h = h_data.copy() + + def backward_fn(grad_output): + # grad_output is dh coming back from downstream. Dispatch the + # anti-causal kernel and split the returned dict into (grad_x, grad_a). + grad_h = np.asarray(grad_output, dtype=np.float32) + result = gc.prefix_scan_causal_backward( + dev, grad_h, saved_a, saved_h, saved_x + ) + grad_x = np.asarray(result["grad_x"], dtype=np.float32) + grad_a = np.asarray(result["grad_a"], dtype=np.float32) + return (grad_x, grad_a) + + # GradFn inputs list order MUST match the backward return tuple order. 
+ grad_fn = GradFn("PrefixScanCausal", backward_fn, [x_var, a_var]) + return Variable(h_data, requires_grad=True, grad_fn=grad_fn) + + +class CausalSequenceMixer: + """Subgroup-accelerated causal sequence mixer. + + Replaces the ``h.mean(dim=1)`` sequence pooling in the old LiquidCell + path, which destroyed causality by letting any time step see the future. + This module runs a proper Linear-RNN that strictly respects causal + masking. + + Architecture: + x_t = proj_x(x) # input projection + a_t = sigmoid(proj_a(x)) # decay gate in (0, 1) + h_t = a_t * h_{t-1} + x_t # causal prefix scan (GPU) + + Implemented as a regular grilly ``nn.Module``, not a subclass — kept + minimal and self-contained so it's easy to drop into the v3c script. + Use via ``mixer = CausalSequenceMixer(d); h = mixer(x)`` where ``x`` + is shape ``(B, S, D)``. + + NOTE: constructed with explicit ``grilly.nn`` imports at call time to + avoid a circular import at module load. + """ + + def __init__(self, d: int): + from grilly import nn + + self.d = d + self.proj_x = nn.Linear(d, d, bias=False) + self.proj_a = nn.Linear(d, d, bias=True) + # Initialize the gate bias to +1 so sigmoid(1) ≈ 0.73 — the model + # starts out "remembering" most of the hidden state at t=0, which + # matches the LiquidCell behavior the old code defaulted to. + try: + b_arr = np.asarray(self.proj_a.bias.data if hasattr(self.proj_a.bias, "data") + else self.proj_a.bias) + b_arr[:] = 1.0 + except Exception: + pass + + def parameters(self): + yield from self.proj_x.parameters() + yield from self.proj_a.parameters() + + def __call__(self, x): + # x: (B, S, D) — a Variable / Tensor from upstream. + x_t = self.proj_x(x) + + # Sigmoid without importing torch_api: use the identity + # sigmoid(z) = 0.5 * (1 + tanh(z/2)) + # which is numerically stable and uses only tanh, which grilly's + # autograd exposes via Variable.tanh(). 
+ a_logits = self.proj_a(x) + if hasattr(a_logits, "tanh"): + a_t = 0.5 * (1.0 + (a_logits * 0.5).tanh()) + else: + # Fallback for plain ndarray — upstream should be a Variable. + import numpy as _np + a_t = 0.5 * (1.0 + _np.tanh(_np.asarray(a_logits) * 0.5)) + + h = prefix_scan_causal(x_t, a_t) + return h diff --git a/rebuild.ps1 b/rebuild.ps1 new file mode 100644 index 0000000..c849228 --- /dev/null +++ b/rebuild.ps1 @@ -0,0 +1,148 @@ +# rebuild.ps1 — one-shot grilly rebuild: compile shaders + build grilly_core + copy .pyd +# +# Usage (from grilly root): +# ./rebuild.ps1 # full rebuild: shaders + grilly_core + copy +# ./rebuild.ps1 -SkipShaders # skip shader compile (if only C++ changed) +# ./rebuild.ps1 -SkipBuild # just compile shaders + copy existing .pyd +# ./rebuild.ps1 -Clean # wipe build2 and reconfigure from scratch +# ./rebuild.ps1 -Verbose # print each file as it processes +# +# Exits non-zero on shader or build failure so you can chain it in a Makefile +# or git hook if you're feeling fancy. + +[CmdletBinding()] +param( + [switch]$SkipShaders, + [switch]$SkipBuild, + [switch]$Clean +) + +$ErrorActionPreference = 'Stop' +$grillyRoot = $PSScriptRoot +Set-Location $grillyRoot + +function Write-Step { + param([string]$Text) + Write-Host "" + Write-Host $Text -ForegroundColor Cyan +} + +Write-Host "=== grilly rebuild ===" -ForegroundColor Cyan +Write-Host "Root: $grillyRoot" + +# ── 1. Compile GLSL -> SPIR-V ────────────────────────────────────────────── +if (-not $SkipShaders) { + Write-Step "[1/3] Compiling shaders..." + + $shaderDir = Join-Path $grillyRoot "shaders" + $spvDir = Join-Path $shaderDir "spv" + New-Item -ItemType Directory -Force -Path $spvDir | Out-Null + + # Verify glslangValidator is on PATH (ships with Vulkan SDK). 
+ $glslang = "glslangValidator" + $glslangOk = $false + try { + $null = & $glslang --version 2>&1 + if ($LASTEXITCODE -eq 0) { $glslangOk = $true } + } catch {} + if (-not $glslangOk) { + Write-Host "ERROR: glslangValidator not found on PATH." -ForegroundColor Red + Write-Host " Install the Vulkan SDK (https://vulkan.lunarg.com/) or" -ForegroundColor Red + Write-Host " add `$env:VULKAN_SDK/Bin to `$env:Path." -ForegroundColor Red + exit 1 + } + + $total = 0 + $compiled = 0 + $skipped = 0 + $failed = 0 + $failedNames = @() + + Get-ChildItem -Path $shaderDir -Filter "*.glsl" | ForEach-Object { + $total++ + $src = $_.FullName + $dst = Join-Path $spvDir ($_.BaseName + ".spv") + + # Skip if .spv is newer than .glsl + if ((Test-Path $dst) -and ((Get-Item $dst).LastWriteTime -gt $_.LastWriteTime)) { + $skipped++ + if ($VerbosePreference -eq 'Continue') { + Write-Host " SKIP $($_.BaseName) (up-to-date)" -ForegroundColor DarkGray + } + return + } + + # -S comp : treat as compute shader (glslangValidator can't infer + # the stage from the generic .glsl extension) + # --target-env vulkan1.3 : enables cooperative matrix + subgroup extensions + $output = & $glslang -V -S comp $src -o $dst --target-env vulkan1.3 2>&1 + if ($LASTEXITCODE -eq 0) { + $compiled++ + if ($VerbosePreference -eq 'Continue') { + Write-Host " OK $($_.BaseName)" -ForegroundColor DarkGreen + } + } else { + $failed++ + $failedNames += $_.BaseName + Write-Host " FAIL $($_.BaseName)" -ForegroundColor Red + $output | ForEach-Object { Write-Host " $_" -ForegroundColor DarkRed } + } + } + + $summary = " Shaders: $compiled compiled, $skipped up-to-date, $failed failed / $total total" + if ($failed -gt 0) { + Write-Host $summary -ForegroundColor Yellow + Write-Host "ERROR: $failed shader(s) failed to compile: $($failedNames -join ', ')" -ForegroundColor Red + exit 1 + } else { + Write-Host $summary -ForegroundColor Green + } +} + +# ── 2. 
Build grilly_core (build2/Release) ──────────────────────────────── +if (-not $SkipBuild) { + Write-Step "[2/3] Building grilly_core (build2/Release)..." + + $buildDir = Join-Path $grillyRoot "build2" + + if ($Clean -and (Test-Path $buildDir)) { + Write-Host " Cleaning $buildDir..." -ForegroundColor Yellow + Remove-Item -Recurse -Force $buildDir + } + + if (-not (Test-Path (Join-Path $buildDir "CMakeCache.txt"))) { + Write-Host " Configuring cmake..." + & cmake -B $buildDir -G "Visual Studio 17 2022" -A x64 + if ($LASTEXITCODE -ne 0) { + Write-Host "ERROR: cmake configure failed" -ForegroundColor Red + exit 1 + } + } + + & cmake --build $buildDir --config Release --target grilly_core + if ($LASTEXITCODE -ne 0) { + Write-Host "ERROR: cmake build failed" -ForegroundColor Red + exit 1 + } + Write-Host " Build OK" -ForegroundColor Green +} + +# ── 3. Copy freshly built .pyd to grilly root ──────────────────────────── +Write-Step "[3/3] Copying grilly_core.cp312-win_amd64.pyd..." + +$builtPyd = Join-Path $grillyRoot "build2/Release/grilly_core.cp312-win_amd64.pyd" +$destPyd = Join-Path $grillyRoot "grilly_core.cp312-win_amd64.pyd" + +if (-not (Test-Path $builtPyd)) { + Write-Host "ERROR: Built .pyd not found at $builtPyd" -ForegroundColor Red + Write-Host " (did the build succeed?)" -ForegroundColor Red + exit 1 +} + +Copy-Item -Force $builtPyd $destPyd +$sizeMB = [math]::Round((Get-Item $destPyd).Length / 1MB, 2) +$mtime = (Get-Item $destPyd).LastWriteTime.ToString("HH:mm:ss") +Write-Host " Copied ($sizeMB MB, mtime $mtime)" -ForegroundColor Green + +Write-Host "" +Write-Host "=== Done ===" -ForegroundColor Cyan diff --git a/shaders/fused-layernorm-linear.glsl b/shaders/fused-layernorm-linear.glsl index db62133..698c5a2 100644 --- a/shaders/fused-layernorm-linear.glsl +++ b/shaders/fused-layernorm-linear.glsl @@ -1,5 +1,6 @@ #version 450 #extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_shader_atomic_float : require // Fused LayerNorm + Linear 
projection. // diff --git a/shaders/gemm-bias-add.glsl b/shaders/gemm-bias-add.glsl new file mode 100644 index 0000000..4aace97 --- /dev/null +++ b/shaders/gemm-bias-add.glsl @@ -0,0 +1,27 @@ +#version 450 + +/* + * Row-broadcast bias add: C[r, c] += bias[c] + * + * Used as a second dispatch after gemm-coopmat-shared (which produces C + * without bias, since coopMatStore can't easily interleave an elementwise + * add with the tile-cooperative store). 1D dispatch over M*N elements, + * 256 threads per workgroup. + */ + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) buffer CBuffer { float C[]; }; +layout(binding = 1) readonly buffer BiasBuffer { float bias[]; }; + +layout(push_constant) uniform PushConstants { + uint totalElements; // M * N + uint N; // stride (columns per row) +} params; + +void main() { + uint idx = gl_GlobalInvocationID.x; + if (idx >= params.totalElements) return; + uint col = idx % params.N; + C[idx] = C[idx] + bias[col]; +} diff --git a/shaders/gemm-coopmat-shared.glsl b/shaders/gemm-coopmat-shared.glsl new file mode 100644 index 0000000..63c71e9 --- /dev/null +++ b/shaders/gemm-coopmat-shared.glsl @@ -0,0 +1,99 @@ +#version 450 +#extension GL_KHR_cooperative_matrix : require +#extension GL_KHR_memory_scope_semantics : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_shader_subgroup_basic : require + +/* + * Cooperative Matrix GEMM with Shared Memory Staging ("The Merge") + * + * C = A * B where A is MxK, B is KxN, C is MxN, row-major. + * A and B are float16; C accumulates in float32. + * + * Workgroup: 64x4 (256 threads) -> 4 subgroups of 64 lanes (Wave64 for RDNA). + * Output tile per workgroup: 16 rows x 64 cols of C (4 x 16x16 coopmat tiles). 
+ * + * Alignment requirements on the caller: + * M % 16 == 0, K % 16 == 0, N % 64 == 0 + * + * Dispatch: + * gx = N / 64 + * gy = M / 16 + * gz = 1 + * + * Hardware notes: on RDNA3 / NVIDIA Tensor Cores this hits full WMMA + * throughput via the driver. On RDNA1/RDNA2 it runs through the driver's + * emulation path (standard fp16 vector ops) — still correct, noticeably + * slower than the peak but still competitive with a hand-tuled GEMM. + */ + +layout(local_size_x = 64, local_size_y = 4, local_size_z = 1) in; + +layout(binding = 0) readonly buffer ABuffer { float16_t A[]; }; +layout(binding = 1) readonly buffer BBuffer { float16_t B[]; }; +layout(binding = 2) writeonly buffer CBuffer { float C[]; }; + +layout(push_constant) uniform PushConstants { + uint M; + uint K; + uint N; +} params; + +// Shared memory staging: +// Asub stages a 16x16 tile of A (256 elements, all 4 subgroups share it) +// Bsub stages a 16x64 tile of B (1024 elements, each subgroup takes a slice) +shared float16_t Asub[256]; +shared float16_t Bsub[1024]; + +void main() { + uint tile_row = gl_WorkGroupID.y * 16u; + uint tile_col = gl_WorkGroupID.x * 64u; + + uint sg_id = gl_LocalInvocationID.y; // subgroup id 0..3 + uint lane_id = gl_LocalInvocationID.x; // lane within subgroup 0..63 + uint linear_id = sg_id * 64u + lane_id; // 0..255 + + // Hardware accumulator lives in registers (one per subgroup) + coopmat matC = + coopmat(0.0); + + for (uint k = 0u; k < params.K; k += 16u) { + // ── 1. Stage A tile (16x16 = 256 elements, one per thread) ── + if (linear_id < 256u) { + uint a_r = linear_id / 16u; + uint a_c = linear_id % 16u; + Asub[linear_id] = A[(tile_row + a_r) * params.K + (k + a_c)]; + } + + // ── 2. Stage B tile (16x64 = 1024 elements, 4 per thread) ── + for (uint i = 0u; i < 4u; ++i) { + uint b_idx = linear_id + (i * 256u); + uint b_r = b_idx / 64u; + uint b_c = b_idx % 64u; + Bsub[b_idx] = B[(k + b_r) * params.N + (tile_col + b_c)]; + } + + barrier(); + + // ── 3. 
Load shared memory → cooperative registers ── + coopmat matA; + coopmat matB; + + // All 4 subgroups load the SAME 16x16 A tile from Asub (stride = 16) + coopMatLoad(matA, Asub, 0, 16, gl_CooperativeMatrixLayoutRowMajor); + + // Each subgroup loads its OWN 16x16 slice of the 16x64 B tile + // (stride = 64 over Bsub, starting offset = sg_id * 16) + coopMatLoad(matB, Bsub, sg_id * 16u, 64, gl_CooperativeMatrixLayoutRowMajor); + + // ── 4. Hardware matrix multiply-accumulate ── + matC = coopMatMulAdd(matA, matB, matC); + + barrier(); + } + + // ── 5. Write back the 16x16 accumulated tile to global C ── + uint out_col = tile_col + (sg_id * 16u); + coopMatStore(matC, C, tile_row * params.N + out_col, params.N, + gl_CooperativeMatrixLayoutRowMajor); +} diff --git a/shaders/lstm-cell-forward.glsl b/shaders/lstm-cell-forward.glsl index 7cd4e83..f11da5d 100644 --- a/shaders/lstm-cell-forward.glsl +++ b/shaders/lstm-cell-forward.glsl @@ -35,7 +35,8 @@ layout(push_constant) uniform PushConstants { } params; layout(set = 0, binding = 0) readonly buffer InputBuffer { - float input[]; + // Renamed from ``input`` — reserved word in recent glslang. + float input_data[]; }; layout(set = 0, binding = 1) readonly buffer HiddenBuffer { @@ -70,7 +71,12 @@ layout(set = 0, binding = 8) writeonly buffer NewCellBuffer { float new_cell[]; }; -layout(set = 0, binding = 9) writeonly buffer GatesBuffer { +// Not ``writeonly``: the shader writes all four gates into this buffer +// (binding 9) in the first half of main(), then reads them back a few +// lines later to compute the cell/hidden updates. Pre-existing bug — +// glslang's writeonly check was only triggered after fixing the +// reserved-word rename on binding 0 forced a recompile. 
+layout(set = 0, binding = 9) buffer GatesBuffer { float gates[]; }; @@ -94,7 +100,7 @@ void main() { for (uint i = 0; i < params.input_size; i++) { uint weight_idx = (gate * params.hidden_size + hidden_idx) * params.input_size + i; uint input_idx = batch_idx * params.input_size + i; - value += weight_ih[weight_idx] * input[input_idx]; + value += weight_ih[weight_idx] * input_data[input_idx]; } // Add bias_i diff --git a/shaders/prefix-scan-causal-backward.glsl b/shaders/prefix-scan-causal-backward.glsl new file mode 100644 index 0000000..112012c --- /dev/null +++ b/shaders/prefix-scan-causal-backward.glsl @@ -0,0 +1,102 @@ +#version 450 +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_basic : require + +/* + * Causal Linear-RNN Prefix Scan (backward) + * + * Given forward recurrence h_t = a_t * h_{t-1} + x_t, the adjoint rules are: + * + * dx_t = dh_t + a_{t+1} * dx_{t+1} (anti-causal sum) + * da_t = dx_t * h_{t-1} + * + * Using g_t = prod_{k<=t} a_k (forward cumprod): + * + * dx_t = (1 / g_t) * sum_{s>=t} g_s * dh_s + * + * Strategy: + * 1. Compute g_t via ``subgroupInclusiveAdd(log(a))`` + exp. Verified + * correct on partial Wave64 subgroups (see earlier debug dump: + * max abs err 2.98e-08 vs numpy cumprod). + * 2. Compute weighted_dh = g_t * dh_t per thread. + * 3. Store to shared memory. + * 4. Each thread LOOPS sequentially over shared[t..seq_len-1] to compute + * its own right_sum. For seq_len <= 32 this is ~32 iterations which + * is cheap compared to the rest of the pipeline, and it sidesteps + * the ``total`` broadcast trap entirely (neither ``subgroupAdd`` nor + * a ``shared_total`` write from thread 31 produced the correct total + * in earlier attempts — the cause is still unclear but likely some + * AMD partial-Wave64 edge case). + * + * Dispatch: one workgroup per (batch, hidden_dim) pair, one thread per + * time step. Constraint: seq_len <= 32 (matches local_size_x). 
+ */ + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer GradHBuffer { float dH[]; }; +layout(binding = 1) readonly buffer DecayBuffer { float A[]; }; +layout(binding = 2) readonly buffer ForwardHBuffer { float H[]; }; +layout(binding = 3) readonly buffer ForwardXBuffer { float X[]; }; +layout(binding = 4) writeonly buffer GradXBuffer { float dX[]; }; +layout(binding = 5) writeonly buffer GradABuffer { float dA[]; }; + +layout(push_constant) uniform PushConstants { + uint seq_len; + uint hidden_dim; +} params; + +shared float scratch[32]; + +void main() { + uint t = gl_LocalInvocationID.x; + uint d = gl_WorkGroupID.x; + uint b = gl_WorkGroupID.y; + + bool in_bounds = (d < params.hidden_dim); + bool in_seq = in_bounds && (t < params.seq_len); + + uint idx = in_seq + ? (b * params.seq_len * params.hidden_dim) + (t * params.hidden_dim) + d + : 0u; + + // Load forward values (neutral for padding threads). + float a_t = in_seq ? max(A[idx], 1e-6) : 1.0; + float dh_t = in_seq ? dH[idx] : 0.0; + float h_t = in_seq ? H[idx] : 0.0; + float x_t = in_seq ? X[idx] : 0.0; + + // ── Forward cumulative product of a: g_t = prod_{k<=t} a_k ── + float log_a = log(a_t); + float cumsum_log_a = subgroupInclusiveAdd(log_a); + float g_t = exp(cumsum_log_a); + + // ── Store weighted_dh to shared memory ── + float weighted_dh = g_t * dh_t; + scratch[t] = weighted_dh; + barrier(); + + // ── Sequential right-sum: right_sum[t] = sum_{s>=t} weighted_dh[s] ── + float right_sum = 0.0; + for (uint s = t; s < params.seq_len; s++) { + right_sum += scratch[s]; + } + + // dx_t = right_sum / g_t + float dx_t = right_sum / g_t; + + // ── da_t = dx_t * h_{t-1}, with h_{t-1} = (h_t - x_t) / a_t ── + // Numerically unstable when a_t is tiny — max(A[idx], 1e-6) clamp + // keeps it away from zero. TODO: save h_{t-1} during forward for + // full stability. 
+ float h_prev = 0.0; + if (in_seq && t > 0u) { + h_prev = (h_t - x_t) / a_t; + } + float da_t = dx_t * h_prev; + + if (in_seq) { + dX[idx] = dx_t; + dA[idx] = da_t; + } +} diff --git a/shaders/prefix-scan-causal.glsl b/shaders/prefix-scan-causal.glsl new file mode 100644 index 0000000..82679ac --- /dev/null +++ b/shaders/prefix-scan-causal.glsl @@ -0,0 +1,85 @@ +#version 450 +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_basic : require + +/* + * Causal Linear-RNN Prefix Scan (forward) + * + * Computes strictly-causal recurrence in parallel: + * h_t = a_t * h_{t-1} + x_t + * where a_t is a per-(batch, seq, dim) decay gate in (0, 1], and h_0 = 0. + * + * Math trick: taking the cumulative product of a_t and scaling x_t by its + * inverse turns the recurrence into two commutative scans — a cumulative + * sum of log(a) for the decay, and a cumulative sum of the scaled input. + * + * cumprod_a[t] = prod_{s<=t} a_s = exp(sum_{s<=t} log(a_s)) + * h_t = cumprod_a[t] * sum_{s<=t} (x_s / cumprod_a[s]) + * + * Both scans are hardware-accelerated by ``subgroupInclusiveAdd``, which + * runs entirely in the subgroup's registers (no shared/global traffic). + * + * Dispatch layout (one subgroup per (batch, hidden_dim) pair): + * local_size_x = subgroup_size (32 on RDNA Wave32, or use 64 for Wave64) + * gl_WorkGroupID.x = hidden_dim index + * gl_WorkGroupID.y = batch index + * local_x = time step t + * + * Buffer layout is (B, S, D) row-major: + * index(b, t, d) = b * S * D + t * D + d + * + * Constraint: sequence length S must be <= subgroup_size. For longer + * sequences the caller should split the sequence into chunks that fit + * in one subgroup, using the tail of each chunk as the carry for the + * next. A multi-subgroup version of this shader (hierarchical scan) is + * a follow-up. 
+ */ + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer InputBuffer { float X[]; }; +layout(binding = 1) readonly buffer DecayBuffer { float A[]; }; +layout(binding = 2) writeonly buffer OutputBuffer { float H[]; }; + +layout(push_constant) uniform PushConstants { + uint seq_len; + uint hidden_dim; +} params; + +void main() { + uint t = gl_LocalInvocationID.x; // time step 0..seq_len-1 + uint d = gl_WorkGroupID.x; // hidden dimension index + uint b = gl_WorkGroupID.y; // batch index + + if (d >= params.hidden_dim) return; + bool in_seq = (t < params.seq_len); + + // Flattened (B, S, D) → linear index. Threads beyond seq_len write to + // dummy index 0; we'll guard the store with ``in_seq``. + uint idx = in_seq + ? (b * params.seq_len * params.hidden_dim) + (t * params.hidden_dim) + d + : 0u; + + // Load current time-step values. For padding threads (t >= seq_len), + // fall back to neutral values: a = 1.0 (no-op in cumprod), x = 0. + float a_t = in_seq ? max(A[idx], 1e-6) : 1.0; + float x_t = in_seq ? X[idx] : 0.0; + + // 1. Subgroup inclusive scan of log(a) → cumulative product of a_t + float log_a = log(a_t); + float cumsum_log_a = subgroupInclusiveAdd(log_a); + float cumprod_a = exp(cumsum_log_a); + + // 2. Scale x_t by the inverse cumprod so the next scan is a pure sum + float scaled_x = x_t / cumprod_a; + + // 3. Subgroup inclusive scan of scaled_x + float cumsum_scaled_x = subgroupInclusiveAdd(scaled_x); + + // 4. 
Undo the scaling to get h_t + float h_t = cumprod_a * cumsum_scaled_x; + + if (in_seq) { + H[idx] = h_t; + } +} diff --git a/shaders/spv/activation-gelu-backward.spv b/shaders/spv/activation-gelu-backward.spv index 7885dde0609ceb68b762c69131e4819ef214d52b..5b636d3b9404523cf8cf51cf86ffc2cf792048c9 100644 GIT binary patch literal 3068 zcmZ9NS#wlX6oqf5CJiEKRkJ}cV(Y&aVNPGq&z*B5Ut zPWK0`>CNB@22MLd6&(kk-?`#d|Yy9N44YfwAIcOGRT%XO_C+OGi zHCv5d8kV$EXmzxutwBm=E;U=k>af21WHZ-U`n7gvt6@;mMzcBk2aVOOm3Aw|Uxhy` zeb?y+&F*>w!9DyC-fyoy;>*GNca#5fc%g67b#W&*E#iwkhh|S;j@M`VfBZ8t?%H~TJBv@tJqs;WgVmGJ($lJ^SP&y zZxU5~iv97Gi4PKE{v5OA@E!D;`eFFa5<5o!F8U!%{X)qnO6?Tfal0{_@sH6PuYQEe zo?lUOGnD^-g56x%-b{byOj+U^N?>9lr`4UIoyXEqpqFhy?M`kmp1RvUnqXpt>L-&9k!Qy zR{1u5VjK4xR*j$A#;{6zl7;??s>^0$27c}<5jqK*BtsHhkw@k%;EXS!}mHk$9^l<_9j@( zJMY@g(A!*_zSG!w%=L)P=^6Vb{zfxipD|J6EpU$gTGqG#R`cCh<08Gy8v5SBW-;@K zqlWLw8pi80CTh%s^9sLI@}H>bK3>L*QI9cKz&RE%?}Cj}x8@wZTEt%i=N0}Qcn*v4 z?}L5Y>OYos7r~Cx7jYke>zF?4_y*LX&JtK{H@DZh-@t}2&sCpeuG6c>n2*5KF*o6k zIe_W=xt!0(V9!dwd3;}Lk>?YzTJTRx9&`T;Zj8F)m+93a{!4IhHk6IgG>*vHt)}m)lDK literal 2860 zcmZ9Nd2t1_WqmV$$SKKs|BVO zgWT|Xtq0=uTAW>3t2|rL|D8iDPQDo#`QY_@nKoJPsW>&z&WaPeN5&w2uV#Gc$CZWq1mioTnLKFZg7Yn*(c59o zfsFO!h98kJSbtQr^>Z(t(i4wG!Fz$z#}AsRiAEprSY~~22LqXFJioWilB(y9Ki8~^LB}rYdt^P| z9iu6GCWYU7YH-f{&N9O}`(TQnp5o`GIKNT#^~+Oyb&7{m{N@zrcZz%~>XNs(PsV%Y zjV`K3&SAd{eMvi_f4KP$2&1u{_jXLi8|58Z4Bt2UpNioPz9eHE`rpdQ?;7JD)x-C) z42`n`|Eo&?qgl>3h0D4f|MpIa^B$>-k2voAEg3#~4=Qds7R2j0-WBI9k^`UR;GZ}? 
za&SLz^Sv(|)c6M#|6O~f{&^WO=(gs9a3Hgoi^9aAQ-eE4v-l5%gBrggOpLAnSeS1F z{kN*_qA=_5S=^E^-xfaV@Rrc5&Q)QwJ-T&zUy(IroX1reYj|(yw&uEUea#JV*6>?| z?~kgVPlWky@so#lie`D%gwc$zSKNC4RGb)e*7L5>EdDd$KxXl`g!x@Sx0s%AAhVd; z!o;E5x!n<_Pka`4SD16dM{T|#G|T&iFy~@9z7!@0y2X4Y9Mt&N!tB9nd?QQ@x~=(E zm|7O|oiK6eR_l9V>fy7v`@;MN;-g=_aWwMsJ#WZ(_g1Gb%)2(82(!lWrNYFZuPDPw U%|FUw8E=Dp)IlSb@Ab0mKWClMLjV8( diff --git a/shaders/spv/activation-gelu.spv b/shaders/spv/activation-gelu.spv index 52c4b34e0f974107baa535ecf357bfe65c6a7056..73fa61675b256b7930c3eea48122c2e0a62edb0b 100644 GIT binary patch literal 2212 zcmZ9M*;5li5XL9u00cSY@^%pg1w>RXMUDU(qES%1-ckWln5wX`%*NtTAN+TGwtUln z%A?9Ezu)W**>X}}b${JG-958CNw$o%noPoUm^Smqgsa!In1s|0ld&>a*eOiagZ-(Q z*=ZuWOwtAt({0*KS~4XZmMT7u4xvxDB0Lj%lm~yWKtlF`q<$@uKSa`si04Sl*CyXV zenY-CoAaPr%vEc)P}?g7m6|Uvd($@ecBB4wz2?`0x>3c}u%7%uEhtrs{8X7q!zD%gtqrJxk<;cP7Wo{~H$R~wEmNI5F*HFnGU z*1s%X@%{39N#eI8w;F*hLJc^pW^G{+`r__bIql(2!FkJ*_H5kUrZU=ll|CY{h#!{B z2;6B-xj07{AD&UvoEF&c&e+_awBHUtxi~BNSdP0r@TaoJYjK{iR{HpC&D<$_F7hR! ze4Nwe>(fs54y$LT41GW~$nOa(PrOwUU>_21k979nKOsqe?28fS{;*Amv#(QN5r0w= zANGr?%$6{1K z1^I>}onE*bIJvk-`o=~taJ~;V{G*Xh9?ppz7v#-4emmsf9iOwj8NVVvzE)Uq{;YW4 z*!+^?55xFlf8O$sEA-7QP?Nry12}zcOVT4Y`iAq}v5}X%ekqtllfMe7nUt6Ndo2)y z9n5KVnGa^1#kc<@3;##d;6C{eLBhFX??2Aljr)1KalRb!tg^7N-%C;7-I0A+ z9Z(PMd_B^!v42oWX6R>es^uM|^bZjo? 
zRq0{fjGc)wapr`&_=w{>T@&!pyBBfiyDpCHTh!NtbZpEMecg~`(HFi+VOpRUbeHF* zbWflL_YME881@s(9v6E{x+i>%YTl8KjhUt9tR#z?_-2H8fjsE0#+-C&5RZ=-SL3d9 zFUIdh{IhCuCyN3x*xjD{(mlb&EJ-I0J2jabY%YFPx)M)PkT%DYBPjK;@(s`fQUCblt?CaD!;(U*VgupD|bM*_-dCT}+z9-Vj xiO=0Fw)ow0$MNsUccnIO`nkXyx?XmqGXq(9M)Fc`+cq1Yq(V%Du=%Qv9U0o6rHMg1?cc*ZxhFOZ8rSMzw z!5@<^{s2Gvr%Y0nD&ObKG)!@*uI~5kKHYuJ=~;#b&mVAZz#Vdj-5)Mr$K9Y9CvL>+ zTH{G$u9vmv=C6NE#1S{-1Bp4Rn2PkUWFzXNI1WjUiHAE5p+I7hA^iGjH<@^>OJ)~ z>ibzw!dlXO5hbhXcDEH}oo*_xlfUA7T%~|GhuKz*| z{mlz^(lp)*vv_MGkNMU(zo+7srL!n~8E!|3OrNJJy&flh!#~Mr_A}og{ef9s&*cqr zrl&dWYU+izM9xTR^6{R$UtJaa+sZyKVG%ziJtE;fi02NNmbw_bkB>dB?Bc!sLyo(_ zPcCfou?$-d@DuUnTAU}ZRoOk;UhZLkF7gc&`8a1e-|xbl^%Kd2PB|z&rHtgnKB1b_ z3nZ*SSRaoU_*23MrLozAe@vQs@QVee2HX{4V#g&c;txyXgFmY*+%Y!3kEF>1M=mh4 z2uE*|(t(8iM4xa&fQR2^_tE zc`rEpmkOIa^pa}OfM`MUKE~LY1}rvLb!9g_`ZjK?&7A$dEd+i zHR+qV0Mpmc()0*N-(cndjyl}+uM#(~hkwhVo0XUQV@`;HM{ej_U_Ud@;_d$t!|x(B zxKDnM&|vO3*oQCg!~6!6jxK!TL}2@ST>xaW~li>yokIEayepdEF5| zL({^{6?O3u$Gd(Z!AI{wfz3A~3^%S!#C|CmkX(`ALo<8q<=*kz-mhfC8Rj1z__fYR zuB#Ga;4S8w>~hRE!ouwSw|#Zl?8Qf&sx+MCSd|?}thWcU>5cCT-`ir|9?HhYTQYBtq*=@xzJ>&j8N;^x zk7WlE%d;jMt|_sYA7lrTzl;0%Q8sb##980)|0=)LL6d{uk!KQS*>XOYUGc0<9{{jFv`qj2 diff --git a/shaders/spv/addition-linear.spv b/shaders/spv/addition-linear.spv new file mode 100644 index 0000000000000000000000000000000000000000..d72869fda98a9f0e55c493ad5688892450ec7811 GIT binary patch literal 3428 zcmZ9NX;WNP5QZ-d11K?zWN`})iCauWjY&jfA|z2DViHX>?!y2BjKT~ucSwv@sikF= zf5i`$zri2nR%MmXbLTc(=cJl``#q2#yiFUd?apvp`p4gj|vV)%4$Eh-U0NJQDnr@6BqsUd{7P5hCBm0TxxdVuS^;b&# z>qq~NQRYN3R++tF_EzVvus58=UR|%wthW|x>vPTR)>18Pw3_U8efOi&R$5!n>~fat zMyI`Yxz%i^?Su*jyZPB)tffn9)pp}veV4bMMzgw7uce)>dYjWj*?H@0b<*7p?1<%S zblUamVxvX|>nLUS5BvmRG3QU9TNxfme9YxWJ4NABPEMomR9h=6?fMQk33vB@Tq(H# z*T{SZl3DcadSi8s%UP?o>rTd1H*0BH-)gqAn(T+mH+~mA*WB!w$a`qJ8gkX;T3YMg z!~VJSSwQn9opqeFzrenf-+2&5+f{rAd++kRxvqE8EU7uQUA*f~Qs2&5an0c4EGds= zc4?;v+Y3eFzNIY+_SKRIV9xKhsTL7Qs^J&5QxLmj@QezP#~t`B8)dB=|2 zp7(v-H^3B{rAK*xJQhpJNwu*{8@AJ&Rmn 
z7jxOe$R#$H{^rsy=JE$n%w>M{eiw|+e_!Y!e z!5%^@@<*`m%dtN0-;eG?u~>J#Csm(@8jsh+r8uQ4D1YRn|luJo#m`u+nJSf|5}dqHIDNdahykO&)r1pBWE1#TR9i?EMeOVZQpPO ziTa$)i0kZW8+RRTTsh-vJ8N>*nCqa=T1)S%sEeg z^WGQHzPtJldNYNdMqKV+Lc6bydM;;px1KB5z4g2SC;wYM4{w6?lXnKRUq_}8g4SgNCia7V~>+5&q2I5}CdJpVdg#CVocVpki_KbYQ z{s8Q^&)Duq><_{2{Xkvj`Uo+f{2uC(`xxm%1`zkd_Y<(Q^L_67DOk?)F;{Z-z<&5f z>Ztp(4DZ%m#g=o@S6Yt literal 0 HcmV?d00001 diff --git a/shaders/spv/convd_col2im_noatomic.spv b/shaders/spv/convd_col2im_noatomic.spv new file mode 100644 index 0000000000000000000000000000000000000000..c767fd6d299273c242e72c11933e3733e358ced0 GIT binary patch literal 6992 zcmbW6d7PDH6~|wg1q5V~%?%jD(t;?9$TCgD8mWmaDO;@ZGBeCvn7Mbjb1#cg4N8(p z*=Cz%o0&zqRJQlgEZX;d&+0?_9U9WCz| z7#?G-GdqE`^BZGBYpa#9dZ|)3sI%aK*~U`6Z>V>ye96FJ-aV8ny?sNaO2s_k@1Ed? z2Fim&b#U-b4t{&Nzdi(<;N3SJO$3C$vz3p4;5*^B4AiLEdqkcohxzTmiTuuStX?bk z4;<-#)M5Sz|7q}}rT+eMWiStjx}I(D1_VC?++QBf55j=pk4||4JY(<0XsK2jVZa@n z|7vV6j%#bswSke-Q2AO9y4=@X+t6HlWOJ=OyU4Xl@6NDGO`BWNjCC+$jgdU>X7Kvn zzA8tQ=WSqb!NL0D;=H$xY#+GN7#XOQ>eU*ly@S^r!R>%)&u)YBTr94Aps+2jeGt8+ zd1Z6^5H{6w#+)5o9J?u8WU?7>YvImmY}zzXbB0Zh(A-yTzrIsV+?gE@?yi;k*HyU8 zMWC~n_ZeiZx9nmw)_2jTn;tz&<;23CTiD0uc3b{Fc#qGcs_%x+vZk8gOeg344L@V# zymP^|6`c1ZIAipiop9cFxp@iaU6-@3i1nV!*?Vx-lj}-2@3-8^3FqCGJ2l~)gWQUQ zvsXDkdQrpsEVl~DXUG8T&%JbWHlAT~X6-qKp6J^(ALD)W_bi+9ZqM=3=2D*9wUZM2 zIPA2RXB@TSypO}*Uc*jndG@Z`w|FTSHDO1sxnwo(Qu=E{i{}$#?JmTp3)?eL;@zOumf8tYeqJ-e{42769nUkmme!@dse z*@e9y?3sjpE7-X#VsFzKhI?2H_I=Q{kE7AnkaPCh-hVmQXBXI5^Y|`A9`CxgekY)f zku#6BcU^7;Q?QRNZ10WY{aVemJQHW=<9XZ*;|$H=tlW<}Z2ugQ+Wv;6_L>&Er^VjT zVqer^mlAsrJe1hh@%JpP=kHl+k0-WgysyRfH!b{K_xCKd4(LB<@Zp_ybBQHsKIu%(fRm5CJsxbF ze8imzZk@?laPkrN1hBF4mlx-MHrPI0b56dqXCtwPbHM2y)?&Br;aoWR=;eupPr8RE z!O2H2-C*P7BknwK>mJs@$w%Ci!N$s8QS9L0pThBFn=|RgyFV6(0y*vxMwU-OwzANiS4XWEg3S^Aya;SR zyO8jCFZp4;cpQn| zUIw|NT0-(vJD5xJ=MDzNpgLL%nXVEKr7 z4Oq_KHgmidF^_!2ybhes`1RQ4y*lxE16V%J{EcAae%g6rS z12#@RYTW>qkNv$DY(07RcMV$3y3YE2hz~+(P z4|f{+W<-7oas##GZb9VzUH%Z-$NMG!K}62`C64|+47TU!?;~LQyA_H4J_?qP{yqkl 
zi~c?iHjjMt_X)7P{rT=hZS`IJ>Cb*YiP*3AU4Q$s_fMmJ>|LKvA#(OEHs7AYmXE#M z4o+wN8Ep5G&iW2)`Pj>6!E$?xS$__EJ0kCS-iDTo8lMNJ^ZEj|HPU%~5nDcLdi0=uzBOGzXi7U z5k#L(^tTcD=>0ndcT>^Jcd^YQ?>R&rb)4(>!1m_Z=%4oReQfs|d;0-c&b?h<)c7IT z9InOP{t;N-{r&{)<9_9TjL5lPaqRb}U~9#Ge+G8HTanoB&%yGs-(P^`V!wBS%_HxA zqmDZE`%AFBIfFY9d$7-6qkZgCpI;$z_9;&9v3&IO8*mR2eg770-@B0T`5jn3zE!^m z8z&!ee*hb|7m2t(g5@LbPhex^!{^W71`>Jx0&dOoS2+2|^Ea?@^6_`5zk|*1TIBvm zzJA!}pV<1z`z{6NF*<-&g!SZ)o|@ZVYH zApYCS5^}hv-zM-e#ah^Nv4`My^Nzbd4_i(@>$oPDz7u=kqV9aw#<8Q%1=w=S;EZ=o z&iLJ6-?7Wk+B*wd9r+gG>m84L$70JxKG)=o-&Vx$Na7c`6Cp%zUMZEc6iB)~{A6POu-C~YVO zt*E%8)}@Noy45OO?D_-0^e^gqjvhascixMS^PJpszxVsx?b+T*_pIgpSzlK+C!3wU zm9@u$Y*y9&_kZnV-$gdeJ?|WVu>%;#Oo8QbDGX z3&_*RW#oEt>boGv+F4$so1C-2S4i56N1CrU>!(VVwQ5u4%8}B#(Q0#LYwZ^P+P>E^ z=gjz+b~k-zS#|70o@XAoakzT0i8k*I=vsAjbYi^iy9B#ZpI~lw-uI!2((pumv^=q^ zcD!CGH>>s9uKV)X?ra76SbZAb-t6Pxy{*RK9raqH*~s!$$vN&vT&1>qv^%j~>SpcMCUv2rWV{(aP@-+8{_8h$!IS*hu*A(y??xV z=y0>KZhN&!wf>xg&ZgUZDLpsuio;wp=;_kYR=L*BeHweJe5`CQ&Y$@|yEi+B-QLG` zldj=BoIUCPUGzw$JTVT@2czw16rBFvQ-}B>`ecrKvY(?L8b4Mym$BN8lLe>0_uC=9 zgmx{~{RjGiR&%Pw1$4i)9rc_mjg_0_Jmzwn&W4Hc{menUo5SSu9(iwe;u&{rIkqD> z@0nar!WknsJK?$u&i7=k1L!`)IBR(y+O>MF-i^g@)``1fPKW*#wDaZ?4U5Q#@r?tF0QA=#CZp2#JX{{Hq(^}rAcCFs4ck6$-d-89t`P%2(wjfCtHS39Q#?$-*h+_b|AFaqQ#MApUfDCdcu0d;u&mBd)-*sz-y%0PfSp>TjZ7yq? 
ze;K+9k>66-z5%(29{@Y_A4EH6j+p;Q!QI3vR)Uu?>s!$$731%Kvv23Er|maa&h_7k z9zd*Ne(lwW!}z<<`bK@{HzMv|_-_NJ@jI|xZ}jVLfO{79UT}L~d$KPSbK1_UA4D8W z$#MWKw+viHYunc-+IKDIy0j|==loX-Y^-_WJ(|b&rmf%c!cWdT+P*JyJAbOcuHQVq zwdlvUq^;lg(8kD_N87g}XPoaQ;(Qm{#{B}VkDPJZzJrMK4Md#xU)#8U7Cv&uY5QF@ zF828iuy=D6r{{fp6OmKALmRl$o}+iy;hDV;(~CBzXXbZV9rlV2+u!VTes^Mfwj+tH z-@e3l-rwvrezL=!>ahJ?3V-8g5?lUshkd5QK9|_WpYO0QcG#B+dx%{b1mA}EUI&n^ z?2-3kIU@flcF_IO4*!pW^>@F+|6^cz{oPOPn7bV;XS{E34YCUHO@`0ff`1;LmEcbx z#>qcNPja6`SNdDizP6$r_NC7~h@5?i|DAQ^ST1@R0^8FzBzoEomJgr%!20Y!!spXq z`5C(K%szt{C;vpj4Hvv;6}fhTJ9F)VlRsC)eHLu2yl3j0kc%99z@0tc4<{e}JOGyS zTN!yj50>AHnAf)@C-z-ESm1r=0mS|PBHH~^zstUU3H@cn;rv5r=hd+{4};AS{eA`9 z+3zE8@-MJv&-78mIC;-cdp}}5^NS<@7@bjK4Z_S zVEyEWI90#-N018QdvV^_+2{$xIrq+Qj&|g&fz2EHUI)v?zWYs*3;&~F{hum&JqDHw z|0%Fs_@4kTN$jtIJU_jL4tCD9`n}YS zxhKGK&YebU$J~=(dw1>(T6;6K+`qHPT*MgX^qE1+NBlRz={@=uw&$F_!>6$2W3Rsr zHcsB0=g@MI^ZVd*?|y*o`-;8$Ay^;zi1`uNJh68_2J0su_v9zw3y6E?ys@tHr-*a* za=x%5?=P{_UVeq0_VR0NedHtNH{i6FXR!5?k6wNYej2ftXOQUScVOr2Ki8faR`7?8SJw@V^mk&9H9* r>vM%3J&T(WedOORxJ3oO5N%F-kc%9P!E#f@9^C?#`wwdoy9fUT)xDo# literal 0 HcmV?d00001 diff --git a/shaders/spv/fused-layernorm-linear.spv b/shaders/spv/fused-layernorm-linear.spv index 4856652387d8c666ab5641028229d12ef7469ff7..9f124a0dabb020f1377f03ae1caf2bef5a72d64a 100644 GIT binary patch literal 6128 zcmZ{miFaIe6~|wi>9mvzsgxEGFrh3WSP^KbLc!ER8X4N!7K(}tlT4BhO)_a_(jqQ^ z1;ib!;*Pjg#Rar#-Ow2K9j$fO<6q%+j_3IKym?oC=Gb4(J@jKI*6#VLfKe#%}mSrboJGR|ez2>H!)zeRQT;Yd0qwBh|5q zX02VV)$8Od=Q#!|J1T2i?fTlw)?IA&ld^7EZaX>a$rgi4&}6MKrT#={1+)e_7wU%| zfF6P#h7LpX&@rfo8pi5xLVK3F`C9~^M<{Wx!OHPk)xJ#ob(^=xRa&#d<1@|K>1u7J z(cV2d+HQ=HtHhm;<2PJuoQ*~%d zM=ocl0h+Vz|8ca>({sPAH|&Jlr}O4n&WzTpt=Y*=4d)-ay&ZM#dja|h8~d-WJ9`eg z24wEb;6wHXFY~AD>GWH#wQKry=g(HXFSXOLd{;o8k5?t?rbpWP92`5-v0<#UW$@l2yu$Z0K~Z`2Z7EAFA3)_Om3TFd7j zwPF{dR>)~BcOzm=XsR!Z*N-bTiDWC zbI4Ik?1@@)xzzct3w>Vm@TeuWR(EbI<+ileuh6Hpjv_}bvF90(yAg7FFZWctGZS0O z|2-i`ty9Qq-ab097xAgYSi2YU^dc{TOX^o6pHN_9v@e5qLF%t4WZw;S7a|+C9P;S@ 
zG`KeP;qUvS?vlcP71*Xi_We;_Ntb5=S0VJldvhw~d*QQFXHM_922y7a>d%2ZVfULK zc6&Jwat6`!`QW8pot~c!b_V7-liHUS_$;vRr?Pij2luY(tfSmt=$!54;9iKdc^&2T zkjMBd;l}qt(Z>MTfRL|6boL=vI_6w9j!#_n#O3#`3;SD<4LF1RZv*P1&U$Z$I|KFh ztL$$Ob>8zJTz~KBw?WyscfsvbopH*33)C6scOc??&y|h)AY31H#wq(=N1X3;#QDxD z8~266N1buXhYOwcd_N;@4y+s1)0ZpuMhIRN8!rm`8C`a zb>>k%R_HEblimG4LjzEypnl#u_sLm$d}jYZFI+$QUx_~tE}ah<=Wgg5y<7;k7j5o|vUx6on@9dP@~nkF5AtZg7_MC& zxt3wf2bg{?U)TPhHIY#{nu=Qh)C&BKhN1HwDfvb-_o&xLR z9#6wP?y>qNq|QFYQD-mMJh8_!;B%qaW9{}BvGU{Xal43(J)T8&S7ML%f%Q=zbG;p` zem}IF$@Rfs1$9Bakaq9df~$}Dz8c(9*zSNU+vjWG-cjzJzYcyUNJ z?D^}#>g?V5DZdsnm$BlAy$77``Mt>Iiaoy%tdIKG^EZI|q4kjVcoz4AwT1jfaJuJj zLe@up_`DhHY+{$+0@hExb2|W67yUj6w!esf2Uwr5zq_zUukS&A2(njg5%*rOwy66) zu({Mnt`C8&6}cV(KMa{mTjcsM_z@^};v-=5sUP4SJsbW}=mE(7y_d3nK2x8s&qW@! zKL)mTJco~i)eS(#Yg3o*jL-M~?RFXcC!pB3LxpVYC&AX1yKj%eKMi@be+I5y9=&}Q zY>wEs&wwSDiC+e%@Oief$EP-szXf>Yjw+nW#&jiSvv2-+)g`$urvu%N3{>k4CyqpB zz{}0dv9?(y28auRdJs9jTAi z^qZLY)f>J2qkXAE2Oequ4l0^Rj`K&Go6Z?*G)I~l)pX_@bkD^Rm)-DDIZ*BF) z(C7$>2NK>|2S-N^PPi*$+kx6p^Wb{(4^^AZ%E;y+N+&#P4%%SJbBs0|nE4=1d(fBWxFfq0 z-54AmsP@+{tFcLQ?S<=aR5%)i{Uhw|>c9Z4W_KN#wqs+-*^j#lDqp`3ysYMzDg3n^ zJte3A{({Q=e-B<)t8d)IVfGAaJJuok-jiEJEnc)`<$23*A-1Cj(YLj<-N#|z)%p1@ z&VaTf&b)`mHIqxe8|?XSMOH8)b}T5=K~8tMb|m~iJj)!haI`Zj>xqyx6-{7XFPK4XWm?`h`F@WTz6xq zxx53BOYDeTVW<1@jzli8xm<59?KIc_u+v;NAado+U_Nq%o#yh6MJ} z$UJVR_tW_h=K^=J_AqQkztc;5Qpp+H++7G|WG9xk_uKuuek$7a#&lloU3bi;LVq*p zHwPWLkH9_xTVMC#y!|zYyme^1XF2DODskkou6c;V`Qy;WSFA(55OGA^i^0Y_nykJt z`pg6Srf3`IRJ47Svkqmi_F1hH_8tkZND>&k8 z18W=SahiEUqYCAKkdLb7bjS<#1^v7+y{fc<{@xgXz~sKfU}+d6)X zwhlS#(Dwb1bDeKTtlJ0HcHOVg`s_nor@bGEb$5efo%dYZb$=**?yf;{e!u zV}9>=#K~H++%`@N*fny-$?cMJo%c1?c~`YvHwo=JIoD}>HzU92a5lDQ{UlDZ=hnL_ zr}%A`P@VqX6^A{40LGe)>7AI1m|y!i^zn)9{0ZaiwTZ3&`f+w&V!OVU*!uh1E%G_< z@2}L}nb`X89%uVIEc~7KH&|->8!T+&`#UUb{r8Ns{XLe>?=9`6bji168sdB8n{x&| zvhULo`D^j1{W{f`H`PTrceAA^|B_~Kam zL~z=l)3C?Ke-fPhcH;XjpNzOpKH{7Xw$3f3&r`tq$uDJa_Kj*H?e*N*q%f^E5X(i z-{5KB@jZGvocvI^r&VCNE6Y8c2R4p;^y3-eZp40=)7s?pwKngBd-EPd?(@Oso`OWZ 
z&jicI-MRoQw+8Y2UWj&h7Myz)@>;~$Vq<%^)*<8X*0W3A+SY=#Bkyy-a#7oJ!Pe$} z!smHl`M6up2fI$*+FpouSetV%K;*1V9Cu4DVphRL-r`*KU`HRF)6a9$i#TWP>q}eg zd%LjYYv>t>{oa5!uiE>$30+4V&c6ulygF*R2yBe#K|k0Yr1x_GTR!e*1ME6^_pQAV zF`x0p5q}7r_V{A#@$rY@Jy<^8fH#2Uu0p(LSEC)CH|O4n z$QfH~?4!_cLXJe%AQAh`C2wszz}n&e7O?&^knn#iSYH2~XziGL8`v1Jm$!r6i~A3s zcYx*N&3Gr+b@Csk{+Z}^A@c6+J!pq}lYciN=ibB-_ZqPEMDF*3&Akf=pZ9^~Blr8k z*C6tKkBF~#ScWU3~&!cZg z9L|3M?YugAaR=BK@y)*ob|1b2;qxW1e7w(J2D?t)Z>Rk^#C*mV$J(!ei(VJMeZPu5 zKK|F>x}pLn_&IqcnRWJaNf1{_}hqc_HIvU$J}?Y<--0h z*z*!^_xHg1$cNAO!G4c;pML<>Pu?7NqUH4OME?XSdR&~@pJLyISf_Il_h(?|!u~nf zSdr%!V148x_OHQd?B8JTM~v-U#QH7Rxj2iz0~=4?n)jmRBJLl->09t8?D602&v5eb z-Tnf0oqWXkE7Z9Vs-xcf!D+n@V9PB=T<@G*dfzBu}d(_{0 b<@B@OtKsdFTKmG%R%_2LZMA<-PDB0=&5xq= diff --git a/shaders/spv/gemm-bias-add.spv b/shaders/spv/gemm-bias-add.spv new file mode 100644 index 0000000000000000000000000000000000000000..312cd7bb996915bb06bf5a98156b17ca81d00494 GIT binary patch literal 1596 zcmYk6*-leY6oz-{1XdYja)cs^C|D;z)Mx@|kc1c$AHW+vK&^#JCbxIr?>4%mv!kO1#=#0Z*>>3m7|lU_6J?7SkTOnqy=JL!JfN;}D zP5&Ntr8&;Vv(4SkOkB%&G3(9Qydg518~Qx(M4P>9b8gnSg{=JN#bA%IM(j7%Fy5Z$ zfG6@BIn8(j{cJtomDej}Z#480upfP%)B5@Q^jq7vdY3&%9`72bs&(g>NfD_&Nsf2d zfu|0yB8B#8_)uRrNB=pb^|e3mb7!h;32xqb;4%IJQlIu2xN}w)`Y$4_rQLdReO5_rs?LdG5^c-3WijH{!YH^qx@BIVc7-JyI2eE})A)@`J{YqJ-*`_ra> zsjr)B9d{S&xEr}~e!%8vvyR-|Xp1v?f*uN|d_AXOnaTK`r@$EPVvYT~vl8&QkKb7Q zzu+D2%x_s4x!;z2gDCef4BVA{ydcLNi~#K~sc61D`bN>UneUvZfxB`3F~;AX_8-JJ z`zyda?Z2q$x90i7efq8fZT7DImm#QYi#2D_&)z!+_ue{)G5)T#f9GD`;WuHPcHiRL zwZ-^5=+=zHEHY&Tapab<2{$ zzGOu*l>9j&v?JJOvBU_P4WEvq| zavyRUIfuOZ(^9e%5!-J+e@oDRfees$d@9|2$1Q1RHk+xp)B4ecW~0{Vrdhkuot>+9 z8@051Y@yys4>Vg@_o{8g3>owPw!`}JSf8uUwc5whPK{lso%&ohPr6Rt%SS$C-Ol3v zgYDMhLb^ZeG-|l8-vR#qx8J2$hTz>sv!1qFUoU(H^K&@ZOmAc;G<2o#&PNj3VZ7W_-jMx{>37H4Z_2fSGYaVCQ$2~oY=G<&0=O+uUYtB!hTZ>)ZnP$#kf$ys1 zmiu_{=62?N#>u?`t$d>7-n0C|ocHJFzXUu%WZcC%xKG%;4{gq1*o@KU99pNiza)R3 z#_7THE%FZJ7@E(V@p1N1M}N=XnKYMpE#edWl7i;^8{k32{Kk5=Q7g{bT-FL+t>xKA 
zEwQ!SqqXGKS}((`<=(C3Sx2onchA&X!K<~r3)jspwpP)1!K2oBShcUEMt|vZJ~7rF zL3~Ew%g~DU)o{PJVq;uC1Ko#c-(GU>TiX?IeLq%m{m(41_E=S$ycqbD1 z-Jkha(Z!|c5hQ#s1KV5h^uI*Ls_d}a!yA3^p*t<34*Oc7&Yr%28E9rWB>1T}{h=J~T0%-pi5wiy^ zcikK}p`9J=_8^}uZ5I>kd!zpvubr%PM(x)<)F|J`W)%_UFK)|JQE z&z82}i*TLn`v_29ZN~efV!U?i%a4{e`}Q5w-@bjvpqZ6Df&Y!t|igqUSaef{{_xf5#-(R5h)yIBcMqi59?<;7zey^g9(PqE$ z)1}Szw@Ym7CZ=mS`cKFzWEL@Y8}sLWJTD*T{hwHtqRnX^Lx?)KzsXhZ_o&MKy{&S8 zZ-Z|DZ>(_bn|k>L44%?|4f_j(hg41Gt(A1R)WppJVcK3!P+CoYOpr%#~hxsw#eb{px>)% z9~Z*4>+e0t!~bHi{?(dm;M(=~p5)Qb2-qCa&swlH`w5>*z{VXf_q`6RpLTP4*V@G1 z^A#obj_u=L%+{4?XH5Nf&hyjgEr^fnSD{^3N8QhW&EejzMPG+(MXo`@=latA20q)t zpF@n({%`8YHzU?Fzc^xV0NZ=Sem-AwA4&2L_A-k0F^_9I5pCuXN1k0^^PDKp=SHx$ zn2Ry6w($P~SpSDg|8cOk@OSpLh5r;-|KK-+oweZmz}ATSxdp6`_Bgv+!M$g98#e86 z_qT(M)gHO;0DA|)zXGBYP3o%{PJmD&m@brs%IB zlZboXflaQD-;ni)vl4fpcIGm4%-Mdh_oH^sX3%xS$JhgC*VVD7gJ5&S95le%W)bgR zUPH`ftTwZ0G5M|;?Q05+FvF;5SH??ap?*UdkH{vqO8%+rs+M-lhzJjwOhfUHNH z3FpImRY#vc278au!%x84{Px7mX)``%ZU+8W`Y^BZFcN+I3_OWAr;XvE{Qm?$2Wxi^ zzWZAcTLQ+tl{5+$yu{`o43++xhiJryigCx!-4dxAPsQZS;(>jgC=`@r`kf z*Ntj4pWPdy8^ZC8j;3C;Y~`}qL&H6@kDNP4hg}#x?Vw*WbCQe{gkI z|B``?gWX-jeS-r_PR2)H^KKZtu$gBPczAHQtG{zSc6(z_?9)ev)+`tt7#bdG&`n2^ zgKn(u8tz`xIn;M??@r#{21eF)uJ2sa+to9K_t>Vty~}nI&Yr%tz#Y7E z>h9`Z)7#nK>q2!@{kVQZ-Cg~Zg=uT99kAK23*PFhuV+(}8`GGJZuO&2N8>nb`Z#aH z;K;hFefTkr`7Pd7nt2yu_Y4j<@nSKe-%9lQ&Yk&JV_#6%mtwCN>D|!R>-@(ww0%|} z_Uk_TxV~4Qm-b$?z-h`kM>(IRoz2-dV>}l=#%x0mbq{Xn-BGinaUd~9?2=)3mZF%>q( z@yDV`+vJb0Y|paXENtUEdp`Q_f_6+k=aUHz=B z?ONKSbJ+(y6*0H9wD(1P>}5Z6=9BR$aQaVTMO>@!Gr^vL*z>@ynR!k@yGHiqxOx%d z6aVF4*X#iP>)q043fQ}&ZJy_&T`M`~qwU?1b9_sMjWv(=CV9LU+WK9IHb%}o+E-Pa zao&N%-3Zn;?oDXpZbXdJz6p`bJl+a!NA!Oiy0zA=arf!ni0h)i_Ir?VqjuWw#r8Kh z>-0_V!=oqP5t~mip3Nu4HullGe>~g1y=&>)yQFO&7o+V_&c3z1J95T(ZxZL7(Khat zXno|2)Ao+YrT-hSeNP_3Y@F+BvE>xc?NO|T^Yv`_xW^B{xJKsmT=+g#$2MPS`_3)x zIW>E3&GtK|^q*hYr+^n0wtlX6>F;+=X|Jx?+iJGo5~aW25v9Gou&wWRL^SjTfJ7ylVxYsY>9Sf87TpN>8d z(MSG4;-852`3G}$>>xzW`yx)BCxKmq^m;JZ`S=cw&n&QfY8?VLPX1x?C9nE^R^NV} 
zj99lF(chZ(btu}$zVw-m$k~@TH4g(j-}H4j*uKg&JOW$(%4!Xd1RE#s8qPt>CFfD# zGUw6Q^_<7R$zN0DJQi%6yg7Y8$fd92!Sx&`z{$tw#L8zLI=!3(mQU|b0UOui+nWEr zl8gUS!Pbtw0PMRd@h5}zk+;wJXgTph^r8wcMtc|BwJ*%;8z0X9xP zac6?-eVqj-pSWj%jg?RRXMfCgK&F@%pp9^;F;?o7z z?_9)vT8Wm6e>b>XiymxyDeq=4wtRYdF4#Ev%;7w+apn12gDs!9KCrR!i8~)$pZf)H z@`>vQ8!Ml_*Mi3($uj`1*B^wFPo8yP!9$c@#0Zu+~Ltta&v#!Hn`*JKjjexxq z$-NO=&wU}Bd~$CB8z-OK7lF<1SaM$sF8B2kZ2jcj*VSmb_-_Wg*YkPvZ$ZBRxeVDC z*^KBn8GSk8SmLe*n=5hGfUiP~b1ZQ$1Uq&GXVE+_Ld+!}pKHOs7q(PB*Ma5q%{xQR z_rNl=_f(x;UJNeh{1R-rDTwio$(8S!rPzsoDY!h_FT-{}jE~RD!SdOUSAdO^&py5q zY}`yFao2<86Za~xvGVbGHFz$PJg))Q^V|R@pFFPx8z-Oh@H(*h9ZT-lgU2E9c>`EK zdDmteS}y)?2Fq!?UvEL&8+&^j`t69ly%mYiJ1RbV{7!7+`!d*A`PBak*q)N-tKfQ`uffSD&)31m$>$t@18jcBoac75ocK2MH!FNQ+Wqy6 z-GO%B)P7@p8~q)`$MNr?9apE{?}5#ce!mZ{_xl4l`JBBUf{l~++-QFbv7Y(G$^Rqp z_QL)#*m{}gPr&-f=PvyeoV(<>xhA83hB)RL+*#Qf`#H8;?7P6O(Kd zs_?z&?D22G)=|61zeoQ8@iEUI(T=Or$DhFF$R7V0T;Jorz{zKi{|Yuv-Zj$x9b!H6 zi{#rF!0zMa?4xJkVMHJK39Pc*Kaf$# zeu(3)-+gHL^!QKk_QL)bxa{!}Y<=WY``=*mq~}M$=~;j4$my3D8!;yKC~Ws~OV$5q zY&re1&vNda`{>%L({~$VWuIfP<#L7`lQTYN!#zsQYtsaBT|C|4&bj-g) zU0wC|Sg?Pm(#PKXdxV_WzeVH>P5|4d+B3K-dN;(!J|da$zusL!D9|x|V!9C#Q za|S1Yjg$8bYVU$r&-&uzpA6nk{n&eATQ6sDFR(uHS)0AV&N2JD4>Fb)fKK86mj|YIu9;aZ-Wp5pmGd^qX+_i6EuI5w!Ki?CNCvIY~ zK2xzLVP~zUVe2FB`*8-^$8~pXIwEIHapv#@@HqUkuLokgugj7690VtyT2BNUC+}R6 zSDn9UJPB-F=cm6l?dxE)kA3Mg6OpqoacUj{cE0KB$=LSg+Qw%#oP54l4h0)0@BYm~ z%fwsD|$Z`PCmUK2R2UL zI{qzCF1?=sws!1!V0{vQB3K`JPH`c0JP5d~m(*r#g=0yJi8{xF*j##`p8d z*v8E)=5h+0eBu^@jg?RRMPU0+{l(yVo+WVd$@4U@aq@W&ECrk2G3V*uMCHWG(D}_h z6`c3oX<*Nn+V|b*=oN^M=lAK=xH|o=1e+uMJ_B6u_nC0=xfiRz#>soWwU;B-v%Wa_ z&jj0h?6bhu%RHY2)<-_?sAq%oj&j^se-oYqcFZ+61FfCB=YY$7@5HXp<6JoTta}&O zIQh(fHQ2cF40prHC$0xC z@~PhsuGe1+C!e?hu(9%4|3R>QIhLN*fxTDBy&hc8y#Y=>xre~U$tU+P*!+$q_Xybc zaC|mm>nHENo{g5%-#xz&>|Q7Sd0>5VRxSl!g6xa0V~Kk{*s=T;UIsRoe0(;8y-QoF z?}{y8Iel}_*9k`z7#c=Y;^AfOe^4Y^}VDmec+%E-} z`}{I&{p4NSt!TOUzX~j;?Y_Sn?A~ObUjx1#ai8@yhn#+{@4;yINBu}+6n~o^)%XvI CacfTi literal 0 HcmV?d00001 diff --git a/shaders/spv/grid-cell.spv 
b/shaders/spv/grid-cell.spv new file mode 100644 index 0000000000000000000000000000000000000000..35ce3b9e9fe823eed6f837a585622717fb9d0a52 GIT binary patch literal 4236 zcmZ9NX>$}+6oxxY5~4&F+1$WkQ9%&dMG;5<320<>K@pQ7Bw-}X#LNT%H&#UvMMTB* zC-{XQV3ohXZ+@`+9a>gd<@0p+fv()CQ}=z}bI&=q?>#+TnzU$g(%zCxNjj3BlWfdL zCM7MvDM@=e_w*j@xxO6fw^OiMbGHs)6BM5S8OF$J52EyfHufDK{S zvFYS=?F`K1=Fm#~CNcj8X=6n(UmLy-e8ao8t275)V?Fop5SN)xrMNv= z$lN_VSZnUBH>y^Jn@8PI(bqRxH@~r-fkV86`E-g~ljY3KV}s2~$|ol)n0v;mNB4|X zNsvB|Daf4p#Ptb2bEX%9&of)gJWy}+HFN$HbE8_zc`11YKE{I>w^ht$Jg?`$qsJ1-ccQ+%zkdR!C7V+ za^hWE+`RUsqlYuJceRI7UP}9B5_l&o?JLKkY`U%4KWg@@ITkRt-{z&XH?kb9xBKKi z{N`N9?1-E_?8lzbZ!WR<9Kj<-{sellmVFep#Ez&HyjaWriCSW78E{)k#)Yn3o-$&0lH;Kf?@P1F)QqE_%?t)bj+Zn3pmu;9g7U!WIjU4v(} z?)1x@+^uQ9n6EX@Xa7XLnM4}9fG6|W#Ispx{xaq$!zVLq>T}_D<=7bgcQdzO>Kk)z zU#cyG8`p_BTz?OKeI7v=8@Z<#XN`L?#I05s9=9i+tnTX z2>g#*e{HeO066N{t8(l3u`#Ebb>u_27J0_tj9-I~Lw>(?cuVjrsk^Hy_|=prXe|zC>wd{>T{Y{jN?Oklj&vja^Z|1~2A9?tugVl`p9rXXt_ttka*4&@#zp%#l z@;AHSN)mm>c)!bXpFOn-%bqW*yBfX{v!=etu?DQq9QK4feCxn!o$TFu>oM9iJ!^f& z*f;7CvjJR;c^Ynvccjl4drLiHHiFF){8?}ryTrL$cN69s^;okRY@Xm-z@9E9X33rY9xqKhHz~<8D zK75zd%<0{|fOTLOSucMv*Ls;<@4KcJd)Nc^St9-=uxp}!UIzQ0>)q>%xL3gX;(T5O zn@im~zQby<{xz_BjXJM`t#dZt=Nn+PLs-;44366RBiEZ?eNnp)Y<_iX`+uPpwT8i- zN#s2S9>RPkeUY~c))#m77WfFJemOt?{&$4I@mSp#*Z+o z#rn6w{RQ3t7vr07zeVa1a}w+`M~!#DlyxTm-jZQ!#z9 z?h@GNh;<)yjv!Pbq~PrwnYFJeCh8?Qh5;xn-N6)e_$4pzU4`Ap}T z)vV{c`4!fIb<t(>J*u{9CwrI{Bjc|Mnf`9+zPHZenq!--C^JP4xZ` T;Op2N%=+$2&9#1WHemk)trjlz literal 0 HcmV?d00001 diff --git a/shaders/spv/hmm-baum-welch.spv b/shaders/spv/hmm-baum-welch.spv new file mode 100644 index 0000000000000000000000000000000000000000..49c0b5ca7c08112388fb5cbc0174a74f8ce0d4b4 GIT binary patch literal 4536 zcmZ{lX>$}+6oxxwWe3^ZfB-5kh!Fu%WC;NS7%>VWE;tS(BqNg20;c9iDUdMBA z+3!Mhty(LVQ@c4?p`GH^#DhIL=m~Na+5T#DDZ}@mpG|Q~vI%Xi6KuBf>&EV@4t1B0 z3>SfE?{Y5virI$G?VY&nXBXNy?KT)~mobw1dw(wdpG4!oMO%Mum;RFad)F@g?Op%D zS~2ZO+ok^jNtQ9rhRc|f=o6*Fn1`hO3U>FdV9!H`dXE@@VS#9IKLF*%Doc06~{diB8kfbT$f3Gd= 
z<)@5WidNP#L-y3wz}>ab@u*b5u%6%BS*gYAE8KA-=yVQ&O)YOqf> z*#6(>-%Xdk8#55!AD?>zJ^S1<5&4Zgk^O7y?`-=m@_lgT!~dF$51(uQ!{<6U`LpD) zhW}N@$y>*Bk&B#*z;cnZ6KvhZNcb!PcOs{=oXfy+-c8iG0cv`_9-9vp9Xt&k-q?z^IeGEp8+@aUWAj6-VcC{m5Fk6_E?d#_+O_FjdPkKT`fjg^nw&w=y3 zAH~*BK6)Po`(I@5?nm!6uzS(_F|hs1d#1XkMnsIY@B@foCD|2=QV8o zc0c+)54Nw13)<7Z-vI0Dy-h>hk2-IHz2jdvljZ2Q5IOzA=WVdh6!tq{ zdyANN!TQKw%JRGiHkW&TpWjEj{64#P0g>~YERNhCfa~X1{~mmZ-Glgy?&;@Ue1y0c zaUX;8ai3tv?6?M&1~ literal 0 HcmV?d00001 diff --git a/shaders/spv/hmm-forward.spv b/shaders/spv/hmm-forward.spv new file mode 100644 index 0000000000000000000000000000000000000000..6197d531748ac6142ce76edaf802fe8f454b810d GIT binary patch literal 4644 zcmZ{mXLDOs5QeW~NpVPkP!dXlaR?9~A@mv`7znn(Kp>$PSuWTDTQ-txOb;#e-aCfS zLJPfq;v;{Cfe*A9W_X^fyC5{f&Ft*E@9x>%bIu+)?ekX7PYP|x!lWblK52ZrCG(Ot z;KHPk>Yn~h{cCIW;kCyeccc-!B<*RSF}o6#N9T}TNq*CBb3eU6S7d&VaROXGc&naQD2y*ydzJ1g}U(mpp%*T&AA ztkmkIO1%bO;9uH5xKf0!k4UFGlXpLxQ@au8zJ{i&)safQSO?pKXX7XCL^t?+atL^Q zao2#Wtt?W@o zKVMnb!$ZFmdoKHJiyo5E@?n3;!u;-%{`Mt0gWY{(OcS5%z^6N6&JC$3F z?Yc#XpYc1R^~rZ;_bd)EW_N7!IRoa@_C3nk66`j_I?hSd831dWr-U}=0Ad~OK}0Ujz#l{NM4cJ1 zU)cN%B^{5={S+<|H6mv}%5v&_gSH=e`}h{^`;)g1ZU4T=9nZac2Vy^G z>310aMj7ip_?FdS`x_kgDzLx9&Gz~h+uy?QtNkqudlh(di#^?B8|Uv@bNtp8d#=Ua zmf79x!WrKM@f~>Xr?RI4dRIjLGA=)DkU-not#-@W`t#5mLYBCeDF zjhz0D`u#~geajFz&sO{wmnfEtI$dDv_%5Q(3UC)^a_*?f+7+oWD&`^D=O2&CB8B<7}@0yH-AOUkSGNINPhh z#>vOoUJZ_W*B@tl4Om~~8Ui<;!!UMhPa|;h(bKhH*U3koQE>Bpk6{}pA3c@9gNU=J zKYF?jtS{#Hda!-Ud)NLh%0>J*ST5`du(R=Pc25o=2fnFUY#3#)94u_`nVC?yq}w}7r7n>H=n~3*sVQ12`3*tJq321eB^l=+<`<-&w!1SkDi_d$NSJ9Jv|53 z7w_YFuzkvV*IUtY5&r^MF65~0Tn!y_D%U)x&Wm=;v^x5E1MFVnt-lGDi?{JM*w6XX z_ZIR#VqfCO^A6ZNVZRG@pK%`Vfjy7qh`!kWcCc&p+y6GSoc&tcI@SwcUyIL}sQCfd z^~TPjwe@*!A0obwcz+*(Tj%a$IQf{n9bnhVN1ji>9Z1aGr(omcqmR$P>znM)!S)k- b{sL@_e9Y>X;G>Wwh&kPZoU#7D;zZ&WC*xADW~(q{)79s_xiQ%S6_3b9_J;c)KQ=F(MFhuk@;$)rC~L) z4!IoJfs7-Q$bHB;X-4> zQSVnPQ`KG_i1*6W-`!W6ow;gdzCB-W^%v%sZPweY&QH~i9%lSNy!}Z4OT)$Ey-%-@1eLVv_-E4F! 
z^VK6o{+TG9YOh!6AMLm?&(-^Kh%aTHP4P%_8S_N9T5GURj)vG72TW$+on7NyJH%Ho zQzE#WT*Z7_dJy%RaZuAe%R90|s@E!cFxT(MOspSCt^+r?zG|yqu~^^4E8u#SPPe|8 z<=BaSFh8GN)6-r@TBp3D#|>unl+w2-XPn)styMpq>$fw{R_Ze|sa{U*V7|RmtM==+ zF7!JKebh9ryrWicruSsclPP1|Qo8r-{7MNW=4UaS+y!2aUrLUld!{42tCP%2_dnh4 zaxT7QY~|zw*!JcR<{-8@i@XPUM=$5hZI3xTyVK0oTD8+}EY>T{F3fOp2Ci0FtoknX zKg%5Z|0eiQ`j&dcA5PRAhY)?wrCL_+Q{brSz*6T}S2@wUZ%u)-ALTA6aQ30xa|)dOCugmZ z!+dgM1d;b{%*!y*-Pefy}8sQ*E(4F&Fsv*iPz&}?lHtM zhVHvoe z?MZe0Ud-%%<*Y~D{*=?cEyoe(O7w_x6aM$qSsyvq zsgEOav9DKwr{kj3&|DTHMcMs_PA5^q=2Xz0Z zDcb$Nrl=nq&`%ENrw8;i1Nzxq-^H%|rkscP-SNzBWY3=e`H1{ZPRR36*WaGM5Lt)# zEeLh@3J1NeqA69seP|w&x)=#IqMioQuKci(JnKn`;viK4W0{NAp~p z!LA!a%(aEtVJ>Z#Aadps|Cfg)mW!BM!CBsN`ki?Ry5Ao2YSYiRwjI&tx%e*BBgWO> zVvK9hcOb^l7BQ{`Ycqy#Nj+@WgN+;d%fR}aVz1_ZIiipJC^sc{12Tm8rnMW_Hz^;r z?E)8TdnNifVr|+Y-_2lc#_%1hdl!3{ClK$#Z(#WE&H2#xp_}{Oyxv!X^^uP_`@w^8 zUIQl|`+P0<7DPT`-wL*-h>b5^-UxO-hmi1j6IedE;ozVUw1M) z>@RH%M9$dai1T)^^+v8H*j&~ZKJ#Gt=;apJb@JxwFgwhpt&PZ;OB}r{7cqO_EN^zF zee}x_^J>#?4f6t`&2#b_tR68AgNreapzlD8p)Fz@1#2_LBC~qLcn7!`C^!pHN)fTmW1gtIekAgk_=>3m@^^uP~ z$oYmNud{pF!~Ai?J>8E)%unQe=$}OQTw=|qz^-`!37=1c8A2{YT(A8^ z{+;tuxAxHzYNw- zK5F|4_(4Ql?Dwl+_xlhMvAzbD5B=+4-%G@P7_5(c#Qp}@_}aqn5wLsH=DyD`%Ne)K z{B5N8F1~|)4)I-Ri~QdOYcs}KX7#B5dtmD?-tqU*<-_L(xzA(F5%Y&&`Ka|rVAsiq z&yT_83!k5W^^*^upMu?2oWsw+a{9)e<--5xVEv=EUx2OuaU^`60L$Nb9s2LV)*SEe4`6-d<2UAy z;J7>EnMY3F$nz&~G0&gTv$M#)$A3YeL^dJ%8beM$d;Csjzb($i->r@4e($eG{s&9t B{|5j7 literal 0 HcmV?d00001 diff --git a/shaders/spv/lstm-cell-forward.spv b/shaders/spv/lstm-cell-forward.spv new file mode 100644 index 0000000000000000000000000000000000000000..8c2dbb088bfc2ae6494080ba15533369868e1a7b GIT binary patch literal 8140 zcmZ9R37l0`9mnr53#%;2qA26cfC_>rq=Aw+fH_Q&SypD-_?Q`(m%|J(Ghnueq?Sed zMmteySeY%hFJc=-`);;KJC&{EmRjHMz4y-?-jV<3|NH;We>>-%d)|9v{g}C9v!=Ri zVm2;YT9@TxS~eyVPRyEe-LdYHb&H3FdlsK{_L(|N&gyeNeWqm7vPN_RvZ+!XaN>BR z4e3U%K(--!kiE!l$X&=(f=@ygB2%ej{4~UZ{L7`DI%B|hKs4f!)@X#8z}!%E0I}$+>ZPZyWo&I&!*mVC!IaWw<&x(0M@_(~MqU 
z8Sd^Qq{!2hwZL6GGSs(baA0V-GB7-Z-Nb+Lj>3PadR6Zz?;5HDn@5Jbj^&*-`;aF^zOooiQ*Y}nAdrRZPY5jpz6*6?{c%xA4m^Ob?_hn%CR(^DBH zk#U{@MzV7S_T_P&8;3b|po=~lvz_4R&4*;(UCtN#x_L|Q z)_mdeyWxv9u03jv@{Tf3U#xLYLFYMd0Dv$d(*fow;EaK!H z`5LL{x_loTKI8r9{2b(VV|D=TGcS)KK8@L(a2;?%NA|topmSf(rg?qG9L5|(n`fwN zePxIPW35AQ)vn@1Y2SmrA+Zk|lh_Yo4<50}TF3j{+wYs@Ym(=jB-iEk7j-9NufZpF z&*v=XFjj4^-US72$foBx26}z=KEFBG7xivS)4&$IF?e>Rxx7;xu4yi~uWK~df+E*Z zkJMc94s)f?)Vs;iK`!sKiO7k$U7xj}#~^E|BzDX#)Dg~bIq$9TGe&M4V&2QqO^7}n z>_r7#*zh)Q_h=P zY^~yX6n0wcM&{C5-UU%B?uDoocDk;;_{X}$)~Z9St)15T7)QZ#-f?@0&E=F)S+W-S<9vk2@fG4|_C! zaVJFY-VfF?-dfsetrM}+THYB^OYHFd^&JX3t>vBJyt&1#;CMf^(^~Vf(^}pgQ7i6| zs1mLd0Tm#>(SzwAZxee(vZg)E7wZz_AHiLh%KIkWZCxf zXJAhUwP-S0uxFt4m7j$@tHjoG-u6IP^aO?VNmbc`u4X^l8QR z%(Nip)OO8s_UK-FYN z!@gdQ)~6X?0oc4VN zb|%mKbl{m_sYN^q-@PRm`p7Gf;T>|nJi`vxd_RnzgVyl<^*4_1aM=ESN$mwS_Ois* zuRXENx4Opmw?vxnvKo6wjlHwR-j&$q*`L_fJ5XaEOlc%hk|*N+YscJ4U^(Nxt4=}M5bv{uPX6sbo8(Wte@H*9*4dVao9r# z+Ie;Kuoi5N*tf@n-M3C8e4YT7k3QFdjgxn;v@bxcXMS^VZx^CH z59;V;3pic#5VoA>&3NbJjF0o|xr+EJ!RdK=0d^ZQFY&nwEFXLSLa=f2k>^F=^gO*7 zTR-{Or8*zOvMY}(0-`DHV*CP(+ zUxs#G9euwXY>qhdH-Kx`{|Y$ycwb)$HcsC6Rr^}Rdgd2L{#Sw1{e3ldZT{E5$;bY_ z7Hph+n9(*zX4pE`;Bn&k^4nHCX??lUm|69S$ zNYs5BxVG-w;pC(4JHW=tN1k_r)4K1%)=%ELyU=ps|6Z`Y#Cdujcn{)va!x*ft`!@^}}Fu$%oG^;C3X==10JC`o`Im^SpWXJP+#VZ2jb8pFR)1 z4RN2Gk3IMT*txL32yRDWPrd|}(=YZ!E^6Nnb`3Xk=PgCwf!vPlL;T+G6|}>5%egNj zUq{?aan$`PxE-0qQ}?-l4XHzBAkIZ?dH2)2<}g>}_y)Kg(RV*u`<(JF{w8<=A|JWr z(p=_@b$tumjKq8KZEzcsKAZ1g%g3|%E_fv(AKdrA=5j7_e;=Ga_a9)_=Kdj^eB}8N z*f{xkw|@*azjKlMC*brf{1jV1dDndaEf@a30Ly8+@4rOakvKQM0^fyrZuB*WoPO@z z3bfBb9XWpuc74(3-C(&+#CYfKMvPyMHvU|6#Qz4I?)`7EYxn1OaPqM~zX!Yb^1=N9 zY%b>__aDLO{`?7BKY4pQh?WcgKZE6BKmG!a{cuh{_wBEUbLKcy+L7aL*hP-wcf#MX zYuED+IQiJKe}avZkJ|qNM{Vb$_P@c-S^FNecGUh4IIVpzc5UtZ;N)Z9?g#Hf{Wu@}{1@z;{TxPXM?VjP(|-PkUE7Z@zI^mkhy4h)eDpI0Y%b@bpL%fg z<9zhf0CvuP9wJ6N`e_2EXKgHYZ9n7SXsdOYI&H4z=XA5-#Ivp>d7M&x2GQ^2Ip|K{La?DN#z$N$gnchaKfw32h~5$b5$ G)BgcJYT^|D literal 0 HcmV?d00001 diff --git 
a/shaders/spv/prefix-scan-causal-backward.spv b/shaders/spv/prefix-scan-causal-backward.spv new file mode 100644 index 0000000000000000000000000000000000000000..e6f043b8c3c7115a8bbd42ffe19c5ec0e505d417 GIT binary patch literal 5868 zcmZ{mi*sCM8HZ1r>=J09Rti*FVoIqOz!sGPilqcfl0w8v6)FmwO?JDxlI)hfv_(aW zg<2616h&y&dc_M?#QTMcs1+5&+aKUR;dRCtpXcm(nKL`%mwD&?{oecco$s8@@RD^W zCB>m+X)>DpHR-Pv$&zFUsNS6EE$F35F||+4?3mfm=~gyeyy+q{PD_SWsc(6*JSl)j zki~MNseLK30vSWDM`n>0awqaI@*MIa@-ph1XY4rQ49YNnOVH0kMv_yKshQGs(>Ihl z^>U@!E_HfywRWqwyHsvBy7k3sw=qxbNOJQ3i<>KV(h%kxN&Cf|1^SGV+eHc>=dLZ3 zW?J**g>B6}sBWXx+_n`n{$b`#l1hq;>CD95)N0>0126p(A#^R`qiywr(14zwHL|M?&?%;FD+D?*ovvm^Qbo}m1?t8X)M}fBt6&N<#u_I zWchkx=tF9MhK=&dzI2{RaJke47m^w9*6Muuj;ngRc2(Qy>1%UMAkOdea#?pLx>6@O z<%Q%{^i;cCna<^Fb5*AM`LikO;~LoX0-YE=nQXQ9mfJ6-r~S3NNKgBFPh8H@L+{J- z-w)36cV_L=eC;l0HOD*Va{dG8g;uRpPI)moh@S5)c6y7YFsC{9Vy|UqbU%1+wNa~g ztCdovp4RcscIMmV?!3H^9EP_I;VwGAeUBKE;*sPSy5is&e+peM?QU21^yfOR--nX= z3&{zz=NcRHt6yr)(A?C*&i<|gPUf#Lr756WC1<&`IB|dtI?d7ueEIQw70iN&W^Nn6&SM&;=@8$iCUSfNhZ!dYi z*Zbi4Uf$*C756%N1<&{Lu17Diy@nBc$@9H_4bS&F0gql{*DDa8SMYo<@1K5ai|rM6 zL7wmRJUrjayBWR2uILp!-^;rjy~OtNH^$!b=yf_(t-F?8yKnIs#MpZragD=ILMz(O zf}fmWbM&8r9zwKVo^js~Z5P1JTZXudKNYP{d&K)r<>SxD_)4bo8JXvvoI^ea_g%Aw zqW>JUGibN2d@QrMuUCS{k?51>=C93St>@hO&qrKmazgw>rTNBQ19lc|=iu{w7LS2_ zU**QV746=&Ig{M?Rh$0p88+9R>K(|_mmb^{{6=_^?B8OaTR_prcU5kk655*D>?5B= zv_+jN*kzpsbktb{%dOKyo1@J-a^E{`=K0P=p6`_0ybqv_(Pp0f&dg?i&-4)DoeBOB zcpO=S{}X8I_H84{C&43#cI(Og9nogLM>1?q)O!@{9S;5&cyy>gx9_j9KA#h4xpjVn zc5ZFfrv*l;mASd1W(~q@M(zOE}!FOo=FkC9MOIS zPt4~dkNB5^S0MTL)8X2U_xZ^q{!Fm(-iN4v7FfISK3{qGR)ei!{#x|8h~KnvBx2Tq zwLeLW-^TM0^R!#jJEF~a{jWl_{S)77(60aB)Bb8ioA**2JzfiT?=i;(VCV4r7%>-u zwZ|N<1DmJ)-_(n`>c8_8oMi)I-vVO1J)LJ5eH-#l#9rd4_cpL~COF~k=(i)+Biz8h{LWxM`-bnv0iQ9^<0i2A#!jN;vA1`DV{h*PyWE@hcOwCGe*W9xEyVftMc+19pM4u>x%1n{n$`|q%pX2=)aioFH@1eB$NYQ1G5;N4m-B1i zhunqOOC0lm5S*X?Lva7^aDIK!cRyI4efOf}&Tk)UT049(fB4i<=fhz0jqRc3;kz5` zom$2{--^Bm8A8?|`ph|i)*d+@0c*Q;sGsvu@FB!JedZiQYmc0dfwe`S`@qJpM)a9; 
zFIs!#JOI`f{NrHn&pmu&ob^G(813f!JFCt3BKlLvXvU5IG-8b3$VbqhK^{hoGhc4Z z-^hOytvrV4JB;>>qxNUP*4B3fE%(`c4(+p1|ARV5(O*DZ`oDN0G`3hKj{7!upY@T-SoBZ?0QN;S<$bB3f_1*W^5cjR!7|-VGi1y&$0K4bAva|js zSetQi9@^ZK`?IGydOiWRXMF3w1=co>#F=T!pV?7(^#3;4{=uIFdlwF6^Lz(vjCSiC zLu(U1g{~szu0nqoS)Xz5#rF`O-4oQ(|NYFi4E=a!`vKTIeg1y`2<`H$_5Bdh=Dx)- z=a0d)9RCTphaAY}{wdfP?UC~{@FWs@dm3z<_Sla$&nkMVBmd`M_aFQjuxAkO$uGdh zJcC&8I9i+WtI*FPo{#(d724%Kwf_>)=03&dc!s}4HfG%W^E<>D;`j3R;OCI@5o7JE m&A6!XM{vHzpWx0GHU12K5wV7`*3f30XAyI$|I63pGUR`t!6ncD literal 0 HcmV?d00001 diff --git a/shaders/spv/prefix-scan-causal.spv b/shaders/spv/prefix-scan-causal.spv new file mode 100644 index 0000000000000000000000000000000000000000..8b96607c11431349248038d57b526934a05e1e65 GIT binary patch literal 3684 zcmZ{ld2^Fh6o=ooX;>AJ%>^(8M^q3|*%Vm{)wWXE1>73aCbW?zDOs?%Rt3gS|Bg{e)+duQ~t&5ZU?(al)S4`yIWqWHQ|aW=h+wQ)oQI-uGSmYE2J}9 zuCyx+lC}2{Q*V|@?e}|4ySkG0cB7SY7d)DL0G_DLSMHzf-M(FGr<0a*jU(>WyKsrG zqgS%r&%o2o!POni8nz+$qsn0LBTh%Of zG-<(gn0KvOUI7wSOq^4e|JY^Ed+>|t9*fC$;ETQPf1AYiHOafC5OY>FSd;j?U&VK* z?Yq-{F11I}JMi62kr{Wi7uyw_cO+LptTBx)KDI~Fx%X=nQj9Bk1~EV8A2lc0pLd|& zTH-Op9%1hRyR<(==X-f4(JSsJdWD_u<()+@vAtZ@*3S3(7CYa|xkfLs%NcpDu=BlE zGrzUP_A)>0=(Pn_evDa+WHH;wv3Ch^m9W>K75N?5>oRPP{`Ke~ME*o(d$)4?vCZ3v zxQu@StxrDUeGmEgZJE8Bc+YCAV!m(2Uh?K^`>y2N|4wuXiF(?*GTZ#!;5frpCNPFA z{}S4BYP)aqUO`;jnb28uDV0KU4sER5E_~A&-VOGxX-y%Ru5U`) zy5FO%D`#EpRYWf8JOG!F15|ac4-q-u=7ZZ!aR}-HZQ8hEgmSHD3mM z*68~x*uMS_BW6EX{@1MU0kC=Uo@F0e&Up8A5Rr3FZ=hZ7N&XOW260c~=bwg!PkuGq z(@C)P^+oMdVDC0!PNzO+*Poxi9l5`liSu$cXVA`8pKmVjGbU<}gN-rP-=wxa@8LZp z?qLe-@*d^h1 zeeT~q*w4Pi1XWt=rhM}pnT-i z!E&=h{hT{s|EHLz&zwcHeB?C1a?z&=Hr}_b&zvQ+eB>;H-EY`;!N&Z|8+Wg5#29(= zTNxMeU2M77XAd0v)MriyEgv~wg6$FZ3fP#3S(ep>Jb;t2uK>viSBEAn}?JH-T?;&cc{k9%Q{sp|1D6jwk literal 0 HcmV?d00001 diff --git a/shaders/spv/rms-norm-linear-fused.spv b/shaders/spv/rms-norm-linear-fused.spv new file mode 100644 index 0000000000000000000000000000000000000000..5e567457179b15ce569223028212c6086c32bd3a GIT binary patch literal 4820 
zcmZ9N`Eyi75XWCg0w_^L5fudkl!A)jfuf=)qJ~HW6a~R!S&|KjCcA5P6D~zi@jmfB z@IDanzWvev$+OBTKcDwr!>deHZ-4tc-90@sJ#Skl%$t-~U9;`ZXJaPuI6j_ZNKn@~bBPWql z-)Sk1&S zbq&?FmWR5lW3}FLbFfzJz9IE(OU^|%Mu$rc3N&^Bc7Hn8np}$Bk>a-GGIX~!quqzL zBi~24SqKGh8v}^@(`z$#d#MFYuJn)scld3J>cHj z5V$R=g9oc)C1TULUD(u~;@#l;5##rxqweRy)!N8#r7x@LMeH@$8!Lm`2H1=>$vN~Z zZOseu-Y2-eQ5HD6X+W*76lFfO0> z8nwqAN9|!B%b(qM)b8E8mS=74+Qqfr!YV;#xxudXz}{!reBO?p&nVb) zj&*%2v951H+jU<;yRMw;YTNsA)@$EKoPDir+&gG}Y+fxpE;%^xL`^{7zKfG{ z&a?T8T~9|l{>9(9Q;_wDXCXHCO!S$EZ^O9A?e|}PIe9!Q?Uf(NiCEh_w8Pq*n~TU*gM#l``ieY(>Lx%&NKJ?Jri~8Q<^-g*LieEuil1jmm>D9^Ks_Sft~a0_M^2U@5|t9|JmKXf_(rnhjWqRRj_kmzXmo}d@o-I>mwgN zZ-BkG@OcxgpS-m^kCuyC-US!m!h6``&+2_R`8cZ&z{b6gM4k`9#k2YdTR-_YtB=74 z5%=kQ)cgt9xv)P4TW9RyGq67L;d2OVPexCF4wln5`c=*zv>)veb<}$pT&(vCY`OSf za89n+<9o0p{u^+y$G*k(+y)9h-+|@h+`b1JCm(r!0DJz?V?ToRlaIar1a_}+jz5FH bMm$IT-LIT}*1HMqe%0F3b6f4V<_hG0`eup2 literal 0 HcmV?d00001 diff --git a/shaders/spv/sign-activation.spv b/shaders/spv/sign-activation.spv index a2334b0cf3e4f45dde2a3eda1af1ba5bb577967e..00943054535bd0d1111691601f05e6fe82e97b82 100644 GIT binary patch delta 140 zcmeysy@Z>anMs+QfsK)Yn}K5@x4Aqg11kdq0}Bx6CT8XVDWC)cgD?;)0kQhTQ1y*R zG#D9~CckC$2a>i-!J8K_i7<-t0A&?`*czx;5lDkn^GrU;tgHx*TIjQkEnR%)4DVas7$t9U(sUVFUNa`~4l5tKD~Np(%r<2+-(161z^KRrRBZ*s)B{`kJ VG6swg(+q*M5fJliF66u{0su@~AsGMw delta 231 zcmX?ObHs*?nMs+Qfng$>jR+e9D+2=q3lQffX66AY25yFpd1)++%#*LO_yb8*)@&d- zkrl*#2xj}Tne(~>wOBEN6f-bLZf;~NW>jPXGOd8v8YpEAq-}s0B&Go5+kx1VMLFC> zK_YQL3{n*jq!WOcX>uV)GFWCZRAvg0o(h(E$)OCE*#ec>3Z%DzWfVD`!7>IwXMjvI P1ky%8%(S_Z^Rfs4lN%q> diff --git a/shaders/spv/surprise-momentum.spv b/shaders/spv/surprise-momentum.spv new file mode 100644 index 0000000000000000000000000000000000000000..59e92d4552d202dd16e78304ac82093dc42466c5 GIT binary patch literal 4268 zcmZvd33F6M6oosPB~hY)Y>I+|LayRUT1eds~ z@9RfT9^KZ@8ryd6+D^osuH*xWS?b1IMSWS;ikoc)mdIAh#$}JnCS~u+uFIB5li1}l z#L>V?`VFi9Lrz5zp(8c9J?%`z>5=wAXFATBo%WGKimB*#m-^%L 
z{kh3byPx%)R;+qXw97hKoJL8SwEDMr)=#oHnu_}gF;zF>Gij&OTbxI(&Lz#+xh!gE-UYe0_E8m{PMh7R z6`v3J5zi-aFPiPe4bE8d`6UKiQ$Oo*#XX?DuBfK=tD{Rh=LB6oJD)QSp~n0z>a|7T zs@tmmgg;QyKu5062Rto5Y|<#U;^%Y6H12JO`q?O%nel$bJ*~c}Ylz!f)KzwgIUued zb$iJ|zUGAdlSRHJrl(1>-#0w1ubavDSMs$4hfdv(wUcw+uej6d<3`-gnhQyk_Wa(a z#Bm7D$)57qvYQh=?eBp)$bC^>gIMa7MozA&Kh&|EJ^Dc39enq!{r$7*d+ zzvRCojH5-pu@<>9KV0je_;9TW`T1HE2c|7*mwip4M+|zC>O(!=R;Wkqm>!rG>QOuB zm1OAQhk94#hkDoKn;w`J>QPhZ<#$o`HD-D%G?Td-v>NMyS1A(xnv7PH&;4TXYvkWn zVCFD>yZRv+{NBQ+cHy?lXCAdiBmNHcjKQzg%6!jjvh`A;w&CIa&lp2vg?iQ_o;C5Q zQ#j7BN_|bnI_v>|wT#C6HR|mgD|Pa<1%JPK&dS>Bw6_ffUN21E@R?7JITIY~<5S0Q zjBhHi=}^O9bf{5$;vP}Y95{6FsnMLv_x*%?%k@rS_HDWc3ykiBFh07^sz(D3UHr*{ zvwR$uKP+*(V=ARL!j>rpjZReqm@maU2o^|0^7oQpq&(oF9x5ae% zBjaAtrM^w~nlL`PpQ}d$4qbd|JJh`)e?(^IxhYJ;NB<`r9mxEjb5}X9`kek!#K2>_ zwM+8DnQ7d?Ut-uhdtj|m89hF=QycKN4EQ?+eE#FY`TW0yetp2_e<~cG8t`Wde?q(D zo!u$pEt02wMIM$G{C=H?{o~sh-+wsf^G5h?@~(J07IRO*|E?I`&Al?_!T%}Eb?RyV zXg$XGUw|Wb;D2=qU^vr!P&i+=>fg>o@_A3J%NTLg#by~}?7dd_rm;e9IS^Nva@aDfL%-WWNqr$|%o6gI^gF45=!P_~H3-cX*h_Um>!mN8%#yD{6v@-c;@pBfg_$=y(bH6CY8@OoSI?GbX&rVp)sQ#-*VL!9?Vro9K6N6 zU&M6Pn`TcK-uBuTW*)r7WWuaxG4sO2!CTCNFz2#7o)d;6*3JxP@#lq!x4nHJ%>FOP zEapRD_>0OD_x6zt&gOh9JUHhQaqwfxD042#xa+Ml)`ZHm>yGz2X14pc_^M!D@&X@A@T;y-#EBV)D+&{5sz!AsU*Qozm z#@TPkO!u3D|4lJ#g};@-jmh|*_+CAY{|Uyvlfm)d05+W;gu{CLQU17${V>LU;BLzP N)?Cw|=f7r;>_2M1Gt2-0 literal 0 HcmV?d00001 diff --git a/shaders/spv/surprise-recall-blend.spv b/shaders/spv/surprise-recall-blend.spv new file mode 100644 index 0000000000000000000000000000000000000000..1640b2c55a5eb8b0357d42b6f9f5f03180ba9f7f GIT binary patch literal 3128 zcmZ9N*>W3I5QfK=C7U>8AsefaV+c!v36KOsHtaaqCWeHVeN$vhPDDnMiR8c?Pz6=+ zG8ETb@)EoWTNPD&-^_F}<>Tt=?!Wu|-RGPhxiUI4p43K?$z&qAljLiEGMbEFPA0WX zFRiYv&JKpH*;A)anz1LTWQpeNC98@aL%PjQ&%h*d0C^cXkGzd+BKzPoZ$Bcj`;`iR zqv+pps$`V8s>CLUbz0vLo5;9trj4bvzurt&dUyL9&0(kCTUjKh%AefVwgJ zI^9m%ywe$W+Elk*HN&;Z%`j_vH!gdgL-)6b8~tv(;HoBP!B+c&!D2fdHnVfoGR!?G zc%zjE~$i#oW^_;qLdXTva;vz82DUuhni{ 
z-5zdj5Ai*ZxNAL+Wp#aD*AjQ=`zw!UekHqa-*=hJcnj~*6`1cKHipRYBD!|ZuViRX}h9U_~lyOf7H^pmiy(_ zFV}htzg+7|k-JQ;wkv9dU#{i*L@jM=xvZsMuC;+*uI0Ps=eqxf_qoTiZ&bbtBHt!D z_ztksz0*XEYab#<{yO5SA)a zogZz*_(N#9#O2fXdls|5gXlUUhy4175SRN8qoc=zoazz$!<=v`i|08{VtXbrIn>jL zD>$FWw#F0eb#~2}!ge%0cG>yMdp@;ax%>m2I4?;KhV zF?sc!JF)0@0h@JE9%nFkoH2cQR?+SelSki~D)Th(8(Tl0lxG7wc$_VLdD>|Ah{>bB zSzz|#+x3xTB=1juD|eFY7 zeDh{XzWM&1m-o-@@cnHyK2KHWau4F{de-yBx%L+L1v<1J{fOtE^_dkhitF0}vSQnkhGp^ly|Y|kIIXSeTv$bAa!9`&E_ z9z#EaxQsuGHm;67pTm|T&T<^vv-llE&IxSsUyHN6fbBkU``3ROv7Y?evG*11oWGWR z7pL+4R=g`?<~e(35M#S}8U0t05yW1N#d&>OW5Is`yUc$vi{<>rg8vO{W8W6_=CQ@T z=SS84K<=vF$7TC2VtkEV!4k%@KFM^CuShE7)TCHS{}3 zex`iBuHx^UuWMl9KTyLx?;`FK51#AT6UcW(&KkCP;`8kCcj`UF-v-~qc=Yr>wy`+F z2iW#09yuRk*O9oxkFdqejhPjT{ExBYoZk5-h_f8~KP@ox+;2?G*>#ScEp?plGi>iE z&i6UCn6vACV`AlR;~0MQ^98mwe&c@pMsFZu4a9GB9qn?)Z=jn99-e|xd6^`mYV+q%Kg!wwE(a)?bMWe&Og IeVIf42Qw$s*Z=?k literal 0 HcmV?d00001 diff --git a/shaders/spv/swiglu-fused.spv b/shaders/spv/swiglu-fused.spv new file mode 100644 index 0000000000000000000000000000000000000000..92151357c1d639a948fa99dbf1203dd7f8ffc1e3 GIT binary patch literal 3932 zcmZ9N=XO+O6oya8Bm_dQihw}`QHr8ektVo86a}SPhe0K`iLcqJ^w+E9W^HJsQ5#%$cJnS{+mjjiE+OJ&=qiSI4S1kY2Y03O z4AslM^+tcC-Zy%oF;HpM8l!y&^Sq7d@y11BJCcXNM<<%Y2O6WzRx@GZ&Xj{r`YWx0 z;c~Nj9vr-9P#YbaXq5+RBPrjN#=D=|XsbFtQXQ;ST26ks3g>K|jmv($=%Lgr3~@54Tq`r4D@=xo2GUym)Gjkp$3VGcU@1u1)V?_%Pd zsVBA7rQ}|)HRP9Kug$S@=)Vu$hRE;AZNGVP);I5d#AW;gXnpeG-;{GJ^LMK+#GFU4 z{kHk0?V-IT=bU*fxCe=z+S?G9`P-giI zg~&OF_UW9{U(T_)))_!Y9pA3Daef1=BWE4$vpHv8yi@!84zV&HFHqkuy)*_hnw(!8c(4gElbnTJ*PwoZ|g#<8C~!65?t{euwEmThn*42Jx+g zy|G}+_fE0>wio05wuh}gKEGzT95eld~ds%$@|`b$nRld=hKe(2f@a> zzlirwUEX--*ACxiu$=k6sU3*#(6<&bJ9GXAV*HIgikK(=Ju}Jicb)RT(PKB*^(Q|0 z#}GNsHpAII_P{xR^xg-yx9>Q5?+5QgqW9xq^W^`aUes0p!bv;J0mK=-N8|13Tu-1~ z&SlI&M9#Uye6ntnI8k2C-3~)Pa*cRzBqE92N&=31?;KyUxbsNq`q(JCB!`WsPi(| zIj`n1uYiq{-_LK#Z^IenG~#>FZ?4~m3UUVV9FOOA`1-NsqR$X`4HD-u44yh~4NgAh 
zeHCn;eAGDyF3wxWHcmd~9Rd4Ya$fy0(W?oeO|;a?(-7%)P24NCm;K~3^q?b_W3%vxX(ARjgyakz6tg_>OS?yOmBhp#Xhfq zol`zyu7Z8t=F8r{jY;P=dj<+ zZFBuTYVXYN;3jw`A|HEtFU3is&zwoLT+H_ocq0<;?PIX_Q2g$Ef-N8S@G00l`Ka?5 z*!zs%$LC<<MR6X=L&Jo zvj~x!g{8e1X+!Qp^u?Q60$zaVznSl4Y0l~Un?Bn1vv2gXj{Tzka&Xkw=UKT|`N&xT gmJ54jZpSQlV>^p{%+dokPoJ}BJB#P?Pksmg0+ui~f&c&j literal 0 HcmV?d00001 diff --git a/shaders/spv/synapsis-forward.spv b/shaders/spv/synapsis-forward.spv index 6ff60fa9cfc1a26feaa527d83c40ba4c27ee3504..7ba6c0768865eb539726d4de0f1e1e69158acb9b 100644 GIT binary patch delta 176 zcmZ1>wML4KnMs+Qfo&q2jXF02D+2=q3lQffX66AY25umh0Ag<-4hG^>ATHe)+Rwzu zH2Eo0HjoTt29bNgq$Z2`<^q-iMn;~=4_SqId4P%)fX4X(X`ac9Z0e#Qkw74}28sp& mX^@dTlWp0IfjSD=gn?u$n=)8cF;rCvkS+zO+Ps%7n+*U^#~Ql; delta 198 zcmZ1@wL*%GnMs+Qfng$>jW8PnD+2=q3lQffX66AY2JVfS{Y;F^lYcU01Ibio5P1YBG9^%%QXpLhlG$v@ww?_DOI8=} delta 166 zcmaDM`9hM7nMs+Qfng$>jW8PnD+2=q3lQffX66AY2JVfi^O+c#C$lnV1Ia=#d6(Io z*Bz*;1|$!}lA8rtk{K15fJ|Q?wgyW10cn3A28k&E`GFwz7?~#P zGG_zHnP8HY#eB05O9i7M4^V*(5L*L{um#c}O&~D^Am0_lp8S+mT@)nZ4a6W-K0q2| aEYD;`HfOL*8C0emNLPSlHdnGeW(NSgDHehN delta 166 zcmca3b4P}anMs+Qfng$>jW8PnD+2=q3lQffX66AY2JVfi*O?fZC)+Y-1Id+OQk2D< z*Bz*;3nUN3lAAqQsu>lTfJ_@8wgyVs0%OiE+ZbBR z!pJoFBuh4sG-U;m3&A8W8;D)WX1@6VTLhyb4^Xi$5L*L{^#jr%gFs>mKz<;IJ=u~Y zoRMepL=Isfxs*d$6eODi#2_8HKpJEv&*Zxt%0PKWPLL{5&SJ2tB~VpMf%Gzvs?Bpb G`#Ax9WFW5q delta 188 zcmdm@wnmMOnMs+Qfng$>jW8PnD+2=q3lQffX66AY2JVfy#Vm}>lOM8V14&<25V;df zs4c-QILoh w5Q9`{18I;2Jd+jKo53jW8PnD+2=q3lQffX66AY2JVfShnX3fC+o6g1Id|Sl9kmT z$PQ&S=XD3FT?tYJ#FCpkS?w7WnSe|sAhrfdDFbO0AO?vk0QssQ_T-mr>Woa21=)px iq$+!}C`dL6h(S7{fpiR5)n4{skSY$Ks?Dk#mpK8$)fm_S diff --git a/shaders/spv/vsa-explore.spv b/shaders/spv/vsa-explore.spv new file mode 100644 index 0000000000000000000000000000000000000000..4498227883fe16a5bf79b59cdd831a6a9a2692ea GIT binary patch literal 5208 zcmZ{m>6cXH5r;3#Op8KP6mXZp4Hra>YebC*Ix`W)4R^iXo@tH)_CRUol~cN^{aZT>b`Hix2JFBp}Qot8OiKq zR`O-id*&uHlNrF-Nv+UBt1ex&w40ADJ?Z4*jo3ZuD*}z#gQx*?KQfUuTRLVV^N=-2 zifl)AAbXO}xVealU9a@OyUmF- 
z>%}gc=6pTbUd=PmzO40!#AW>vaI@fR$))Jk^=_lHvDoA3y^;5!w_?3Bxw;)^BJx1)^!ejTz3n4y7#(VExC1?D|(0R^}7ijy>Ahhy>A8gdLKyM1Ygna zZ0JsAqaKDcAm?*)g|`pyf{(NJIePU}J~_p1yGGlmU2y4GmS@EE6?aE|h&A3V<-o%3 zo1UFvBArK{LwhO-&Y6>QhJth6bMEB!MeOquw4CdQ zcn8nXn!_-0rp8)3Y}fl3?>R><@gf~O=GH!>uzL2j;A$;rAZo=tM6Ix^wVaKpCAOA( zW-aY%tv9f%wLIrut==497W<%fz2|K{W5TZH`wYIC&l!k(VxJ#+?Js7*e0_-dcAz8Q z-mG-(BE0&=^NBJ4AmTHKy$f2A-v@iw61ztKZs-|^{Ib$^=H!mScAYcgWBf1B`s5c- z+3))xaxfX3N%^0WUR-~&H-2jwYdvjeOV0l1qX!XdSYLYq;$!`N(D4rQcAf5ubyo;vJAX80?H| zoA)xb`zPo7YCF?%`mZXnYt7?)M;_-_+qmn|u8}j3w(~0II%n0rvc7Yw?Yg(n#>lx& zdwa?0|FFcajeYq9oXl{4YQ9;v_QIX)#XH&lmQ~|Vtk|BZzg<*50Mv5$~J(b0&M>T<(F$pT!QjAKKyb+brjL@4~)__s2UD zG5eML>%@3J4nSNd|6g(*g!cK4dioAT7UB9`kcQRPc9wO!xa5d+t*v845(^-`hJEwkc z#LlSmAa*{*-*O&jqCFS&H=NzC(Ptw*`j?~itD}c=z~+cES_yVP{gy?{xnTJ>c^Bv9 zJj8YK&WQF|i1o}bwznblaEZMq(dTc##;NUd75W0i$38Da>sLn)tHI_7`y#M=7d>4J zHb&l_w9iM(Z7y-Fy&UX$g?$Cs+;LV{f{l@1&QxB7z8YDBEI{BaJBvzwsG=NzX47W>+6sDYr*>N;%w|^9b!-N5wjlbS=?U6G{JJl z#yQEk_nw3MrH=k?1XugpfGy{YxL%)}>*H+Q_n4`1uroD?M9vAY{A9V0EwEhl+y>k8 z-DS^{U^(NWXSs;K32c1W9k6%iqjF}uU}NO{#!aDpoE3dJBImn{BhN;#d97(aZGAQL z&4_2M&pWB@9z-6s_StCjs$=bM!LE&&`yE&=_Eevo>*HQ~zpQl&`t%atik?&AO(k~C z@4?nod&h1=Z$^Cd-;UO=j-KuSn^gb(Qu_~x^~^7}=Pl^? 
z7HkFEvwhr+HcoAy_n`lT_~`#LTE9B_yccYauOT%n5$o%Z`cHuM#kb%|usz8~%u`^`BEAJrgXN5kbCPrKJqPzo9sNB6uJ-pV zwp`4uKDp}Iy6-Vl&w-t(K_qfM50;N_!3$uy==nvkJ;%4;C9s@v(X(8{zYI1$>{r0v znfMmG3N}XGZ^1v%KF*51zaw(KyEyW^1~zZpL2Z3C^gj{LTAz1P+dYUpYVFNv^QvR* zzre1InR^{97kjEt&h>GxZ^MrDZ-bX1vB&R#-DB5B%)4Oun7#MFu9J_L_rdlaeS82m kPTtx35d9Id9dY*b|C|4A=c7MH^esW=AlCN(#M6-f1JDfJwEzGB literal 0 HcmV?d00001 diff --git a/shaders/spv/vsa-logic-apply.spv b/shaders/spv/vsa-logic-apply.spv index baeba5384c8053013b01ce003ec3b4df17a067ab..a1e2d20741a7ee14f3f2bce11b13c68a30db48ef 100644 GIT binary patch delta 71 zcmca0(jmgl%%sfDz{beH&A>5{+gzQSft7)Qfdz|oq*VD RVx;=U6D`b}V_3eh0st*o3ZDP~ delta 141 zcmeAWxgf&L%%sfDz|6?N&A>a6+gzNTft7)QfdzIQ#NoH9pNHYhLy3D-foYE8|1snTYnR(rT Ps=}a#OKuiq`OXRe8si~| diff --git a/shaders/spv/whitening-apply.spv b/shaders/spv/whitening-apply.spv index 3cc5be6d2544e38f2d4551a05c05bc246d3a5ba5..18dcb9adea1e76ff4a9161832046ce143a9a3e0f 100644 GIT binary patch delta 176 zcmca0dPS6tnMs+Qfo&q2jXF02D+2=q3lQffX66AY25ul00b+e1wg6%WAP(Oc+Rens zH2EP@HjwmX29Z0#q$-Q~<{TCqMnxW=GF>3H1{$UZq(Rz2VhTXMA&5QsB&#|j&*YD+ n!a$OhO<5EqYX`(29ri%l0f>1f+p;+WjW8PnD+2=q3lQffX66AY2JVfS-As(klfN=$1Ia{Y5P23% zTC$k)x&zh5f>Z&qQd?XRu5tR3;2ahl6A`&t$V^0RRJP7{y?%3OkM<&imc|m?hGsp zTNoG^xPe%5b0}*tqaqWKX#vF6Kq*TgZ3V<2F$Eys2E?9xl1-hFY4S%lVIaxM?kozD rZ3SYGjy53O4ptS*9t@P<$PQ9Jz17bT@)l@ w55yo<4nP`Y0ng+}wqUSK3RET)NQ129nS7AV7%bBTmFWi3Js_FQob28l0P&6)y8r+H delta 170 zcmca1cSVkknMs+Qfng$>jW8PnD+2=q3lQffX66AY2JVfS!pw}!lWUo?f#geOe;{ee z0wOzE%z52`YA1nI0kP!fy)4d*j7*am*@StSfC36YNi!hLG+B{NT@)l@55(3$Q3oLH l2*ga2W7&d%IySNi1IeRo#$Z)lP*vSPx(B3cGcUV02LR}I81Dc8 diff --git a/shaders/vsa-explore.glsl b/shaders/vsa-explore.glsl index 600ea18..bd2c7e6 100644 --- a/shaders/vsa-explore.glsl +++ b/shaders/vsa-explore.glsl @@ -49,9 +49,13 @@ layout(std430, binding = 2) buffer Workspace { int workspace[]; // [n_states * n_states], init to 0 }; -// Output: best transition matrix (float, normalized) -layout(std430, binding = 3) writeonly buffer Output { - float 
output[]; // [n_states * n_states] +// Output: best transition matrix (float, normalized). +// NOTE: not ``writeonly`` — the shader also reads this buffer as a scratch +// space during exploration (see similarity() below, and the block_bind +// "reusing output buffer temporarily" path). +// ``output`` is a reserved word in recent glslang — renamed to output_data. +layout(std430, binding = 3) buffer Output { + float output_data[]; // [n_states * n_states] }; // ── Per-block circular convolution (bind) ──────────────────── @@ -66,7 +70,7 @@ void block_bind(uint block_offset_a, uint block_offset_b, sum += obs[block_offset_a + m] * obs[block_offset_b + idx_b]; } // Store in shared memory (reusing output buffer temporarily) - output[block_offset_out + i] = sum; + output_data[block_offset_out + i] = sum; } } @@ -74,7 +78,7 @@ void block_bind(uint block_offset_a, uint block_offset_b, float similarity(uint vec_offset, uint cb_entry, uint dim) { float dot = 0.0; for (uint i = 0; i < dim; i++) { - dot += output[vec_offset + i] * codebook[cb_entry * dim + i]; + dot += output_data[vec_offset + i] * codebook[cb_entry * dim + i]; } return dot / float(k); // Normalize by number of blocks }