From 5b1b9be7cc2eb9621aaf81e53c78876c15174b91 Mon Sep 17 00:00:00 2001 From: sourcepirate Date: Sat, 30 May 2026 23:03:31 +0530 Subject: [PATCH 1/4] Updated docs --- docs/activations/activations.md | 77 ++++ docs/callbacks/callbacks.md | 60 +++ docs/data/data.md | 21 + docs/engine/node.md | 113 +++++ docs/initializers/initializers.md | 52 +++ docs/layers/attention/kv_cache.md | 48 +++ docs/layers/attention/mla.md | 43 ++ docs/layers/base.md | 135 ++++++ docs/layers/core/core_utility_layers.md | 93 +++++ docs/layers/core/dense.md | 110 +++++ docs/layers/core/input_layer.md | 107 +++++ docs/layers/core/merging.md | 125 ++++++ docs/layers/embedding/embedding.md | 43 ++ docs/layers/normalization/normalization.md | 78 ++++ docs/layers/pooling/pooling.md | 62 +++ docs/layers/recurrent/recurrent.md | 84 ++++ docs/layers/transformer/transformer_block.md | 89 ++++ docs/losses/losses.md | 58 +++ docs/metrics/metrics.md | 52 +++ docs/models/language_models.md | 127 ++++++ docs/models/model.md | 273 +++++++++++- docs/models/vision_models.md | 95 +++++ docs/optimizers/optimizers.md | 68 +++ docs/preprocessing/preprocessing.md | 58 +++ docs/tokenizers/tokenizers.md | 56 +++ docs/utils/utils.md | 83 ++++ examples/README.md | 15 +- examples/mnist_functional_residual.py | 89 ++++ neutro/__init__.py | 1 + neutro/engine/__init__.py | 0 neutro/engine/node.py | 38 ++ neutro/layers/__init__.py | 2 + neutro/layers/attention/flash_attention.py | 7 +- neutro/layers/attention/mla.py | 3 + neutro/layers/base.py | 36 +- neutro/layers/convolutional/conv1d.py | 4 +- neutro/layers/core/dense.py | 1 + neutro/layers/core/input_layer.py | 51 +++ neutro/layers/core/merging.py | 161 ++++++- neutro/layers/core/moe.py | 3 + neutro/layers/core/reparameterization.py | 5 + neutro/layers/embedding/time_embedding.py | 3 + neutro/layers/normalization/layernorm.py | 4 +- neutro/layers/recurrent/simple_rnn.py | 4 +- .../layers/transformer/transformer_block.py | 3 + neutro/models/base_model.py | 394 ++++++++++++++++-- neutro/models/vision/unet.py | 4 + pyproject.toml | 18 +- tests/test_functional_api.py | 158 +++++++ tests/test_mimo_fit.py | 116 ++++++ tests/test_shared_transformer_block.py | 81 ++++ 51 files changed, 3331 insertions(+), 80 deletions(-) create mode 100644 docs/activations/activations.md create mode 100644 docs/callbacks/callbacks.md create mode 100644 docs/data/data.md create mode 100644 docs/engine/node.md create mode 100644 docs/initializers/initializers.md create mode 100644 docs/layers/attention/kv_cache.md create mode 100644 docs/layers/attention/mla.md create mode 100644 docs/layers/base.md create mode 100644 docs/layers/core/core_utility_layers.md create mode 100644 docs/layers/core/dense.md create mode 100644 docs/layers/core/input_layer.md create mode 100644 docs/layers/core/merging.md create mode 100644 docs/layers/embedding/embedding.md create mode 100644 docs/layers/normalization/normalization.md create mode 100644 docs/layers/pooling/pooling.md create mode 100644 docs/layers/recurrent/recurrent.md create mode 100644 docs/layers/transformer/transformer_block.md create mode 100644 docs/losses/losses.md create mode 100644 docs/metrics/metrics.md create mode 100644 docs/models/language_models.md create mode 100644 docs/models/vision_models.md create mode 100644 docs/optimizers/optimizers.md create mode 100644 docs/preprocessing/preprocessing.md create mode 100644 docs/tokenizers/tokenizers.md create mode 100644 docs/utils/utils.md create mode 100644 examples/mnist_functional_residual.py create mode 100644 neutro/engine/__init__.py create mode 100644 neutro/engine/node.py create mode 100644 neutro/layers/core/input_layer.py create mode 100644 tests/test_functional_api.py create mode 100644 tests/test_mimo_fit.py create mode 100644 tests/test_shared_transformer_block.py diff --git a/docs/activations/activations.md b/docs/activations/activations.md new file mode 100644 index 0000000..9baf2bf --- /dev/null +++ b/docs/activations/activations.md @@ -0,0 +1,77 @@ +# Activation Functions + +## Theory + +Activation functions introduce non-linearity into neural networks. Without them, stacking linear layers would collapse into a single linear transformation. + +### ReLU — `neutro/activations/relu.py` + +$$\text{ReLU}(x) = \max(0, x)$$ + +$$\text{ReLU}'(x) = \mathbf{1}_{x > 0}$$ + +- **Gradient**: 1 for positive inputs, 0 for negative. This causes the "dying ReLU" problem where neurons can get stuck at 0. + +### Sigmoid — `neutro/activations/sigmoid.py` + +$$\sigma(x) = \frac{1}{1 + e^{-x}}$$ + +$$\sigma'(x) = \sigma(x)(1 - \sigma(x))$$ + +- Output range: (0, 1). Used for binary classification or as gating mechanism (LSTM, GRU). +- **Vanishing gradient**: for very large or very small inputs, the gradient approaches 0. + +### Tanh — `neutro/activations/tanh.py}$ + +$$\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$$ + +$$\tanh'(x) = 1 - \tanh^2(x)$$ + +- Output range: (-1, 1). Zero-centered, often preferred over sigmoid in hidden layers. + +### Softmax — `neutro/activations/softmax.py` + +$$\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$ + +- Output: probability distribution over classes. +- **Jacobian-Vector Product** (`gradient_fast`, line 18): computes $y * (\text{grad\_output} - \sum(y * \text{grad\_output}))$ without building the full $N \times N$ Jacobian. + +### SiLU — `neutro/activations/silu.py$ (Sigmoid Linear Unit) + +$$\text{SiLU}(x) = x \cdot \sigma(x)$$ + +$$\text{SiLU}'(x) = \sigma(x) + x \cdot \sigma(x) \cdot (1 - \sigma(x))$$ + +- Also called Swish. Used in modern architectures (e.g., Llama, GPT). + +## Implementation Guide + +All activations follow the same pattern: + +```python +class ReLU: + def forward(self, x): ... + def gradient(self, x): ... # element-wise gradient + def gradient_fast(self, x, grad): ... # fused JVP (optional) +``` + +- `forward` is used by `Dense` and other layers in the forward pass. +- `gradient` returns the element-wise derivative, which is multiplied by the upstream gradient in `Dense.backward`. +- `gradient_fast` is an optimization used by Softmax to avoid the full Jacobian matrix. + +## Usage Example + +```python +from neutro.activations import get_activation + +relu = get_activation('relu') +x = np.array([-1, 0, 2]) +y = relu(x) # [0, 0, 2] +dy = relu.gradient(x) # [0, 0, 1] +``` + +## References + +- Nair, V., & Hinton, G. E. (2010). **Rectified Linear Units Improve Restricted Boltzmann Machines**. +- Hendrycks, D., & Gimpel, K. (2016). **Gaussian Error Linear Units (GELUs)**. +- Elfwing, S., Uchibe, E., & Doya, K. (2018). **Sigmoid-weighted linear units for neural network function approximation in reinforcement learning**. diff --git a/docs/callbacks/callbacks.md b/docs/callbacks/callbacks.md new file mode 100644 index 0000000..814cac4 --- /dev/null +++ b/docs/callbacks/callbacks.md @@ -0,0 +1,60 @@ +# Callbacks + +## Theory + +Callbacks are objects that hook into the training loop at various points. They allow you to monitor training, save checkpoints, adjust learning rates, and stop training early without cluttering the training loop itself. + +**Hook points** (in order): +1. `on_train_begin` / `on_train_end` +2. `on_epoch_begin` / `on_epoch_end` +3. `on_batch_begin` / `on_batch_end` + +## Implementation Guide + +### File: `neutro/callbacks/base.py` + +```python +class Callback: + def set_model(self, model): ... + def on_train_begin(self, logs=None): ... + def on_train_end(self, logs=None): ... + def on_epoch_begin(self, epoch, logs=None): ... + def on_epoch_end(self, epoch, logs=None): ... + def on_batch_begin(self, batch, logs=None): ... + def on_batch_end(self, batch, logs=None): ... +``` + +All methods are no-ops by default. Subclasses override the needed hooks. + +### History — `neutro/callbacks/history.py` + +Records per-epoch metrics into `history.history` dict (keys: `loss`, `val_loss`, `accuracy`, etc.). + +### EarlyStopping — `neutro/callbacks/early_stopping.py` + +Monitors a metric (e.g., `val_loss`) and stops training if it hasn't improved for `patience` epochs. Uses `model.stop_training = True`. + +### ReduceLROnPlateau / LR Scheduler — `neutro/callbacks/lr_scheduler.py` + +Reduces the learning rate when a metric plateaus, or follows a predefined schedule. + +### Checkpoint — `neutro/callbacks/checkpoint.py` + +Saves the model to disk at the end of each epoch using `joblib.dump`. + +## Usage Example + +```python +from neutro.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau + +callbacks = [ + EarlyStopping(monitor='val_loss', patience=5), + ModelCheckpoint('best_model.pkl', save_best_only=True), + ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3), +] +model.fit(X, y, callbacks=callbacks, epochs=100) +``` + +## References + +- Keras Callbacks API. [Keras.io](https://keras.io/api/callbacks/) diff --git a/docs/data/data.md b/docs/data/data.md new file mode 100644 index 0000000..dd89dff --- /dev/null +++ b/docs/data/data.md @@ -0,0 +1,21 @@ +# Data + +## DataLoader — `neutro/data.py` + +A simple data loader for batching and shuffling: + +```python +class DataLoader: + def __init__(self, dataset, batch_size=32, shuffle=True): + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + + def __iter__(self): + indices = np.arange(len(self.dataset)) + if self.shuffle: + np.random.shuffle(indices) + for i in range(0, len(indices), self.batch_size): + batch_idx = indices[i:i + self.batch_size] + yield self.dataset[batch_idx] +``` diff --git a/docs/engine/node.md b/docs/engine/node.md new file mode 100644 index 0000000..300b9ff --- /dev/null +++ b/docs/engine/node.md @@ -0,0 +1,113 @@ +# KerasTensor, Node, and the Functional API Graph Engine + +## Theory + +The Functional API lets you build models as directed acyclic graphs (DAGs) of layers, rather than as linear stacks. This requires a mechanism to track *symbolic* data flow during model construction, before any real data is seen. + +Two core classes enable this: + +- **`KerasTensor`**: A symbolic placeholder representing the *future* output of a layer. It carries a `shape` but no actual data. +- **`Node`**: A record of one *call* to a layer. It links input `KerasTensor`s → output `KerasTensor`s and is stored on the layer's `_inbound_nodes` list. + +When you write `outputs = Dense(32)(inputs)`, the layer's `__call__` method detects that `inputs` is a `KerasTensor`, builds the layer (if needed), computes the output shape symbolically, wraps it in a new `KerasTensor`, and records a `Node`. No NumPy computation occurs. + +Later, `Model._init_graph` traverses the graph backward from the outputs to discover all reachable `Node`s and `Layer`s, producing a topological ordering used for forward and backward execution. + +## Implementation Guide + +### `KerasTensor` — `neutro/engine/node.py:3-13` + +```python +class KerasTensor: + def __init__(self, shape, node=None, name=None): + self.shape = shape + self.node = node # The Node that produced this tensor + self.name = name +``` + +- `shape` is a tuple like `(None, 32)` — the batch dimension is `None` (unknown until runtime). +- `node` is set when a `Node` is created and links back to the producing layer. + +### `Node` — `neutro/engine/node.py:15-38` + +```python +class Node: + def __init__(self, layer, input_tensors, output_tensors): + self.layer = layer + self.input_tensors = input_tensors + self.output_tensors = output_tensors + layer._inbound_nodes.append(self) + # Link output tensors back to this node + if isinstance(output_tensors, list): + for t in output_tensors: + t.node = self + else: + output_tensors.node = self +``` + +Key behaviors: +- **Registration**: The node registers itself on `layer._inbound_nodes`, enabling multi-parent graph traversal. +- **One layer, many nodes**: A shared layer used 3 times will have 3 entries in `_inbound_nodes`, each with different input/output tensors. +- **List outputs**: Layers like `Add` that take lists of inputs store the lists in `input_tensors`. Multi-output layers store lists in `output_tensors`. + +### How `Layer.__call__` triggers Node creation — `neutro/layers/base.py:67-105` + +The symbolic path (line 77-97): + +```python +if is_symbolic: + if not self.built: + self.build(input_shapes) # e.g., Dense.build((None, 10)) + output_shape = self.compute_output_shape(input_shapes) + output_tensors = KerasTensor(shape=output_shape) + Node(self, input_tensors=inputs, output_tensors=output_tensors) + return output_tensors +``` + +This is a **zero-computation** path: no `forward` is called, only shape inference. + +## Graph Discovery (`Model._init_graph`) — `neutro/models/base_model.py:25-62` + +```python +def traverse(tensor): + if hasattr(tensor, 'node') and tensor.node: + node = tensor.node + if node not in visited_nodes: + visited_nodes.add(node) + # Recursively visit inputs + if isinstance(node.input_tensors, list): + for t in node.input_tensors: + traverse(t) + else: + traverse(node.input_tensors) + nodes_ordered.append(node) +``` + +This produces `_nodes_ordered` in **reverse topological order** (inputs before outputs). The backward pass iterates `reversed(_nodes_ordered)`. + +## Usage Example + +```python +from neutro.layers import Input, Dense +from neutro.models import Model +from neutro.engine.node import KerasTensor, Node + +# Symbolic construction +inputs = Input(shape=(4,)) # returns a KerasTensor +x = Dense(8, activation='relu')(inputs) # Layer.__call__ creates a Node +outputs = Dense(1)(x) + +# Inspect the graph +print(type(inputs)) # +print(inputs.shape) # (None, 4) +print(outputs.node.layer) # Dense(1) — the final layer + +# Model discovers nodes via traversal +model = Model(inputs=inputs, outputs=outputs) +print(len(model._nodes_ordered)) # Number of Nodes discovered +``` + +## References + +- Chollet, F. (2015). **Keras** — the Functional API was introduced in Keras 1.0. [GitHub](https://github.com/keras-team/keras) +- Keras Functional API Guide. [Keras.io](https://keras.io/guides/functional_api/) diff --git a/docs/initializers/initializers.md b/docs/initializers/initializers.md new file mode 100644 index 0000000..b27e36a --- /dev/null +++ b/docs/initializers/initializers.md @@ -0,0 +1,52 @@ +# Initializers + +## Theory + +Weight initialization is critical for training deep networks. Poor initialization can cause vanishing/exploding gradients. `neutro` implements several strategies. + +### Glorot (Xavier) Uniform — `neutro/initializers/glorot.py` + +$$W \sim U\left[-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right]$$ + +Recommended for layers with tanh or sigmoid activation. + +### He Initialization — `neutro/initializers/he.py` + +$$W \sim N\left(0, \sqrt{\frac{2}{n_{\text{in}}}}\right)$$ + +Recommended for layers with ReLU activation. Keeps variance of activations constant across layers. + +### Constant — `neutro/initializers/constant.py` + +$W = c$ for a constant $c$. Used for bias initialization (typically $c=0$). + +### Random — `neutro/initializers/random.py` + +$$W \sim N(\text{mean}, \text{stddev})$$ + +## Implementation Guide + +All initializers are callable objects: + +```python +class GlorotUniform: + def __call__(self, shape): + limit = np.sqrt(6 / (shape[0] + shape[1])) + return np.random.uniform(-limit, limit, size=shape) +``` + +They are instantiated in layer `__init__` and called in `build`: + +```python +class Dense(Layer): + def __init__(self, units, kernel_initializer='glorot_uniform', ...): + self.kernel_initializer = get_initializer(kernel_initializer) + + def build(self, input_shape): + self.params['W'] = self.kernel_initializer((input_shape[-1], self.units)) +``` + +## References + +- Glorot, X., & Bengio, Y. (2010). **Understanding the difficulty of training deep feedforward neural networks**. *AISTATS*. +- He, K., et al. (2015). **Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification**. [arXiv:1502.01852](https://arxiv.org/abs/1502.01852) diff --git a/docs/layers/attention/kv_cache.md b/docs/layers/attention/kv_cache.md new file mode 100644 index 0000000..fc9d5af --- /dev/null +++ b/docs/layers/attention/kv_cache.md @@ -0,0 +1,48 @@ +# KV Cache + +## Theory + +The KV Cache is an optimization for autoregressive generation. During inference, each new token attends to all previous tokens. Without caching, we recompute $K$ and $V$ for every token at every step — an $O(L^2)$ cost per step. The KV Cache stores the key and value projections from previous steps, reducing per-step cost to $O(L)$. + +### How it works + +At step $t$: +1. Compute $Q_t$ from the current token only (shape: $(1, 1, d)$). +2. Fetch $K_{1:t-1}, V_{1:t-1}$ from cache. +3. Compute $K_t, V_t$ from the current token and **append** to cache. +4. Compute attention: $Q_t \cdot [K_c, K_t]^T$. + +## Implementation Guide + +### File: `neutro/layers/attention/kv_cache.py` + +```python +class KVCache: + def __init__(self): + self.k_cache = {} # {layer_id: ndarray} + self.v_cache = {} + + def get_or_create(self, layer_id, shape): + if layer_id not in self.k_cache: + self.k_cache[layer_id] = np.zeros(shape) + self.v_cache[layer_id] = np.zeros(shape) + return self.k_cache[layer_id], self.v_cache[layer_id] +``` + +- `layer_id` distinguishes which layer the cache belongs to (each TransformerBlock has its own cache). +- The cache shape is `(batch, num_heads, seq_len, head_dim)`. +- In `TransformerBlock.forward`, the cache is populated at line 50: cached values are retrieved, and the mask is extended to account for past tokens. + +## Usage Example + +```python +from neutro.layers.attention.kv_cache import KVCache + +cache = KVCache() +model.generate(start_tokens, max_new_tokens=100, temperature=0.8) +# Internally uses KVCache for efficient decoding +``` + +## References + +- Vaswani, A., et al. (2017). **Attention Is All You Need**. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) diff --git a/docs/layers/attention/mla.md b/docs/layers/attention/mla.md new file mode 100644 index 0000000..b562a6a --- /dev/null +++ b/docs/layers/attention/mla.md @@ -0,0 +1,43 @@ +# Multi-Head Latent Attention (MLA) + +## Theory + +MLA is an attention variant used in DeepSeek models that reduces the KV cache size by compressing keys and values into a latent space. Instead of caching the full $K, V$ projections, MLA caches a compressed latent vector and reconstructs $K, V$ on the fly. + +### Standard Attention (per head) + +$$Q = XW^Q,\quad K = XW^K,\quad V = XW^V$$ + +### MLA + +$$c_t = \text{RMSNorm}(X_t W^{\text{down}}) \quad \text{(compress to latent)}$$ +$$K_t = c_t W^{\text{up}}, \quad V_t = c_t W^{\text{up}} \quad \text{(reconstruct)}$$ + +The KV cache stores only $c_t$ (latent), not $K_t, V_t$, reducing memory by a factor of $d_{\text{model}} / d_{\text{latent}}$. + +## Implementation Guide + +### File: `neutro/layers/attention/mla.py` + +```python +class MLA(Layer): + def __init__(self, num_heads, key_dim, latent_dim=128, ...): +``` + +- `latent_dim`: the compressed representation size (typically much smaller than `num_heads * key_dim`). +- The layer implements both the compression (`W_down`) and reconstruction (`W_up`) projections. +- During forward, it caches the latent `c_t` instead of the full `K, V` tensors. + +## Usage Example + +```python +from neutro.layers.attention.mla import MLA + +mla = MLA(num_heads=8, key_dim=64, latent_dim=32) +x = np.random.randn(2, 16, 512) +y = mla(x) # shape (2, 16, 512) +``` + +## References + +- DeepSeek-AI. (2024). **DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model**. [arXiv:2405.04434](https://arxiv.org/abs/2405.04434) diff --git a/docs/layers/base.md b/docs/layers/base.md new file mode 100644 index 0000000..c9a21f5 --- /dev/null +++ b/docs/layers/base.md @@ -0,0 +1,135 @@ +# Layer Base Class + +## Theory + +Every neural network layer in `neutro` inherits from `neutro.layers.base.Layer`. The base class defines the **layer lifecycle**: + +1. **Construction** (`__init__`): Set hyperparameters (units, kernel size, etc.). Do NOT allocate parameters yet. +2. **Build** (`build`): Allocate parameters based on the input shape (`self.params['W']`, `self.params['b']`, etc.). +3. **Call** (`__call__`): Dispatch — if inputs are symbolic `KerasTensor`s, do shape inference + node creation; if inputs are real NumPy arrays, run `forward`. +4. **Forward** (`forward`): Compute output from input. +5. **Backward** (`backward`): Compute gradient w.r.t. input and store gradients for parameters. + +This deferred parameter allocation (build on first call) is the Keras convention: you don't need to specify input dimensions when constructing a layer — they are inferred from the data. + +### Symbolic vs Eager Execution + +A single `Layer.__call__` handles both modes: + +- **Symbolic** (during model construction): Input is a `KerasTensor`. No NumPy computation happens; only shape inference and graph recording. +- **Eager** (during training/inference): Input is a NumPy array. The full forward pass runs. + +## Implementation Guide + +### File: `neutro/layers/base.py` + +### `__init__` — line 4 + +```python +class Layer: + def __init__(self, name=None, **kwargs): + self.name = name + self.trainable = True + self.built = False + self.params = {} # {param_name: ndarray} — stores weights + self.grads = {} # {param_name: ndarray} — stores gradients + self.input_shape = kwargs.get('input_shape') + self.output_shape = None + self._inbound_nodes = [] # Graph connectivity (Functional API) +``` + +- `built` starts as `False`. It becomes `True` after `build()` is called. +- `params` and `grads` are dicts so layers can have arbitrary parameter names (`W`, `b`, `gamma`, `beta`, etc.). + +### `__call__` — line 67 — the dispatch hub + +```python +def __call__(self, inputs, *args, **kwargs): + from ..engine.node import KerasTensor, Node + + is_symbolic = isinstance(inputs, KerasTensor) or \ + (isinstance(inputs, list) and any(isinstance(i, KerasTensor) for i in inputs)) + + if is_symbolic: + # Symbolic: build, infer shape, create Node + if not self.built: + self.build(input_shapes) + output_shape = self.compute_output_shape(input_shapes) + output_tensors = KerasTensor(shape=output_shape) + Node(self, input_tensors=inputs, output_tensors=output_tensors) + return output_tensors + + # Eager: build if needed, then forward + if not self.built: + self.build(inputs.shape if not isinstance(inputs, list) else [i.shape for i in inputs]) + return self.forward(inputs, *args, **kwargs) +``` + +Key detail: the symbolic path calls `build(input_shapes)` with tuples like `(None, 32)`. The eager path calls `build(inputs.shape)` with concrete shapes like `(64, 32)`. + +### `sublayers` property — line 18 + +```python +@property +def sublayers(self): + layers = [] + for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, Layer): + layers.append(attr) + elif isinstance(attr, list): + # Recurse into lists (e.g., TransformerBlock.ffn = [Dense, Dense]) + ... + return layers +``` + +This is critical for: +- `count_params()`: sums params across all sublayers recursively. +- `_capture_layer_state()`: captures state of all sublayers for shared layer support. +- `_get_all_layers()`: collects every layer instance for the optimizer. + +### `compute_output_shape` — line 55 + +Returns the expected output shape given an input shape. Used by: +- `Model.summary()` to build the layer table. +- Symbolic `__call__` to determine the output `KerasTensor.shape`. + +### `count_params` — line 46 + +```python +def count_params(self): + count = sum(p.size for p in self.params.values()) + for layer in self.sublayers: + count += layer.count_params() + return count +``` + +## Usage Example — Creating a Custom Layer + +```python +from neutro.layers.base import Layer +import numpy as np + +class MyDense(Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.units = units + + def build(self, input_shape): + self.params['W'] = np.random.randn(input_shape[-1], self.units) * 0.01 + self.params['b'] = np.zeros(self.units) + super().build(input_shape) # sets self.built = True + + def forward(self, inputs): + return np.dot(inputs, self.params['W']) + self.params['b'] + + def backward(self, grad_output): + self.grads['W'] = np.dot(self.inputs.T, grad_output) + self.grads['b'] = np.sum(grad_output, axis=0) + return np.dot(grad_output, self.params['W'].T) +``` + +## References + +- Chollet, F. (2015). **Keras**: The Layer class API. [GitHub](https://github.com/keras-team/keras) +- Keras Custom Layers Guide. [Keras.io](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) diff --git a/docs/layers/core/core_utility_layers.md b/docs/layers/core/core_utility_layers.md new file mode 100644 index 0000000..8ccf648 --- /dev/null +++ b/docs/layers/core/core_utility_layers.md @@ -0,0 +1,93 @@ +# Core Utility Layers + +## Dropout — `neutro/layers/core/dropout.py` + +Randomly sets a fraction of inputs to zero during training, preventing co-adaptation: + +$$y = \begin{cases} \frac{m \odot x}{1 - p} & \text{training} \\ x & \text{inference} \end{cases}$$ + +Where $m_i \sim \text{Bernoulli}(1-p)$ is a mask. The scaling by $1/(1-p)$ keeps the expected output magnitude constant. + +```python +def forward(self, inputs, training=False): + if not training: + return inputs + self.mask = np.random.binomial(1, 1 - self.rate, size=inputs.shape) + return inputs * self.mask / (1 - self.rate) + +def backward(self, grad_output): + return grad_output * self.mask / (1 - self.rate) +``` + +## Flatten — `neutro/layers/core/flatten.py` + +Reshapes a multi-dimensional input into a 2D (batch, features) tensor, preserving the batch dimension: + +```python +def forward(self, inputs): + return inputs.reshape(inputs.shape[0], -1) + +def backward(self, grad_output): + return grad_output.reshape(self.input_shape) +``` + +## MoE Layer — `neutro/layers/core/moe.py` + +### Theory + +Mixture-of-Experts (MoE) scales model capacity without proportional compute. A router network selects which "expert" sub-networks to activate for each input token: + +$$y = \sum_{i=1}^E g_i(x) \cdot E_i(x)$$ + +Where $g_i(x)$ is the router's gating weight (typically top-$k$ sparse) and $E_i$ are expert feed-forward networks. + +### Router — `neutro/layers/core/moe.py:30` + +```python +def forward(self, x): + logits = np.dot(x, self.params['W']) # (batch, seq, num_experts) + weights = softmax(logits, axis=-1) + # Top-k routing + top_k_weights, top_k_indices = ... +``` + +The router learns to assign tokens to the most relevant experts. + +## Reparameterization — `neutro/layers/core/reparameterization.py` + +Implements the reparameterization trick used in VAEs. A sample from $N(\mu, \sigma^2)$ is: + +$$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim N(0, I)$$ + +This makes the sampling operation differentiable, enabling backpropagation through the stochastic layer. + +```python +def forward(self, inputs): + mu, log_var = inputs + eps = np.random.randn(*mu.shape) + return mu + np.exp(0.5 * log_var) * eps +``` + +## Usage Example + +```python +from neutro.layers import Dropout, Flatten, MoELayer + +drop = Dropout(rate=0.5) +x = np.random.randn(8, 64) +y = drop(x, training=True) # 50% of units dropped + +flat = Flatten() +x = np.random.randn(8, 4, 4, 16) +y = flat(x) # (8, 256) + +moe = MoELayer(num_experts=8, expert_dim=512, top_k=2) +x = np.random.randn(2, 16, 512) +y = moe(x) +``` + +## References + +- Srivastava, N., et al. (2014). **Dropout: A Simple Way to Prevent Neural Networks from Overfitting**. *JMLR*. +- Shazeer, N., et al. (2017). **Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer**. [arXiv:1701.06538](https://arxiv.org/abs/1701.06538) +- Kingma, D. P., & Welling, M. (2013). **Auto-Encoding Variational Bayes**. [arXiv:1312.6114](https://arxiv.org/abs/1312.6114) diff --git a/docs/layers/core/dense.md b/docs/layers/core/dense.md new file mode 100644 index 0000000..412dc4b --- /dev/null +++ b/docs/layers/core/dense.md @@ -0,0 +1,110 @@ +# Dense Layer + +## Theory + +A Dense (fully-connected) layer computes a linear transformation followed by an optional activation: + +$$y = \phi(xW + b)$$ + +Where: +- $x \in \mathbb{R}^{B \times D}$ is the input (batch $B$, input dimension $D$) +- $W \in \mathbb{R}^{D \times U}$ is the weight matrix (learned) +- $b \in \mathbb{R}^{U}$ is the bias vector (learned) +- $\phi$ is an element-wise activation function (ReLU, sigmoid, tanh, or none) +- $y \in \mathbb{R}^{B \times U}$ is the output + +### Backward Pass + +The gradients are: + +$$\frac{\partial L}{\partial W} = x^T \cdot \frac{\partial L}{\partial y}$$ + +$$\frac{\partial L}{\partial b} = \sum_{\text{batch}} \frac{\partial L}{\partial y}$$ + +$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot W^T$$ + +If an activation function $\phi$ is present, the gradient is first passed through $\phi'$ before these equations. + +## Implementation Guide + +### File: `neutro/layers/core/dense.py` + +### `__init__` — line 7 + +```python +class Dense(Layer): + def __init__(self, units, activation=None, use_bias=True, + kernel_initializer='glorot_uniform', bias_initializer='zeros', **kwargs): +``` + +- `units`: number of output neurons. +- `activation`: a string like `'relu'` → mapped to an activation function via `get_activation()`. +- Weight initialization is deferred to `build()`. + +### `build` — line 15 + +```python +def build(self, input_shape): + self.input_dim = input_shape[-1] + self.params['W'] = self.kernel_initializer((self.input_dim, self.units)) + if self.use_bias: + self.params['b'] = self.bias_initializer((self.units,)) + super().build(input_shape) +``` + +Parameters are allocated here, not in `__init__`. This is the standard Keras pattern: the shape is inferred from the first call. + +### `forward` — line 26 + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + self.z = np.dot(inputs, self.params['W']) + if self.use_bias: + self.z += self.params['b'] + if self.activation: + return self.activation(self.z) + return self.z +``` + +- `self.inputs` is cached for use in `backward`. +- `self.z` is cached for use in activation backpropagation. +- The activation function (`self.activation`) is called as a callable; it may be a `Layer` instance with its own forward/backward. + +### `backward` — line 36 + +```python +def backward(self, grad_output): + if self.activation: + grad_output = grad_output * self.activation.gradient(self.z) + + inputs_flat = self.inputs.reshape(-1, self.inputs.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + + self.grads['W'] = np.dot(inputs_flat.T, grad_output_flat) + if self.use_bias: + self.grads['b'] = np.sum(grad_output_flat, axis=0) + + return np.dot(grad_output, self.params['W'].T) +``` + +- For activation backprop, the Jacobian of the activation is element-wise multiplied with `grad_output` (most activations like ReLU, sigmoid, tanh are element-wise; Softmax is handled separately via `gradient_fast`). +- The matrix multiplications are the exact implementation of the gradient equations above. +- The return value is the gradient with respect to the input, which is passed to the previous layer. + +## Usage Example + +```python +from neutro.layers import Dense +import numpy as np + +layer = Dense(units=64, activation='relu') +x = np.random.randn(32, 128) # (batch, input_dim) +y = layer(x) # forward, shape (32, 64) +grad = np.random.randn(32, 64) +dx = layer.backward(grad) # gradient w.r.t. x, shape (32, 128) +``` + +## References + +- Goodfellow, I., Bengio, Y., & Courville, A. (2016). **Deep Learning**. Chapter 6: Deep Feedforward Networks. [Deep Learning Book](https://www.deeplearningbook.org/) diff --git a/docs/layers/core/input_layer.md b/docs/layers/core/input_layer.md new file mode 100644 index 0000000..573a04d --- /dev/null +++ b/docs/layers/core/input_layer.md @@ -0,0 +1,107 @@ +# Input Layer and the `Input()` Function + +## Theory + +In the Functional API, every graph needs entry points — places where data enters the model. `Input()` creates a symbolic `KerasTensor` that acts as the root of the graph. The corresponding `InputLayer` is a no-op layer that simply passes data through; its role is purely structural. + +`Input()` is a **convenience function** that: +1. Creates an `InputLayer` with the given shape. +2. Creates a `KerasTensor` as its symbolic output. +3. Records a `Node` connecting them. +4. Returns the `KerasTensor` for use in further layer calls. + +The batch dimension is conventionally `None` (unknown until runtime), mirroring Keras behavior. + +## Implementation Guide + +### File: `neutro/layers/core/input_layer.py` + +### `InputLayer` — line 4 + +```python +class InputLayer(Layer): + def __init__(self, input_shape=None, name=None, **kwargs): + super().__init__(name=name, input_shape=input_shape, **kwargs) + if input_shape is not None: + self.build(input_shape) + + def build(self, input_shape): + self.input_shape = input_shape + self.built = True + + def forward(self, inputs, training=False): + return inputs + + def backward(self, grad_output): + return grad_output +``` + +- `forward` is the identity function — it returns its input unchanged. +- `backward` is also the identity — it passes the gradient straight through. +- `build` does not allocate any parameters; it only marks the layer as built. + +### `Input()` function — line 28 + +```python +def Input(shape=None, name=None, **kwargs): + if shape is None: + raise ValueError("Please provide a shape for the Input.") + + if not isinstance(shape, tuple): + shape = tuple(shape) + + # Keras style: prepend None for batch dimension if missing + if len(shape) == 0 or shape[0] is not None: + shape = (None,) + shape + + layer = InputLayer(input_shape=shape, name=name, **kwargs) + output_tensor = KerasTensor(shape=shape, name=name) + Node(layer, input_tensors=[], output_tensors=output_tensor) + return output_tensor +``` + +Key behaviors: +- **Shape normalization**: If you pass `shape=(28, 28, 1)`, it becomes `(None, 28, 28, 1)`. This is the Keras convention: users specify the per-sample shape, and the batch dimension is prepended. +- **Empty input_tensors**: The `Node` created for `InputLayer` has an empty `input_tensors` list — it has no upstream layers. +- **The returned `KerasTensor`** has its `.node` set to this `Node`, so graph traversal can start from it. + +### How InputLayer is handled during execution + +In `Model.forward` (`neutro/models/base_model.py:217`): + +```python +for node in self._nodes_ordered: + if isinstance(node.layer, InputLayer): + continue # Skip — inputs are placed directly in tensor_map +``` + +InputLayer nodes are **skipped** during execution. Their values come from the model's input data, which is placed into `tensor_map` at the start of `forward`: + +```python +tensor_map[id(self.inputs)] = inputs # Placed before the loop +``` + +The same skip happens in `backward` (`line 314`): InputLayer nodes receive gradients but pass them back as the return value of the entire `backward` call. + +## Usage Example + +```python +from neutro.layers import Input, Dense, Add +from neutro.models import Model + +# Single input +inputs = Input(shape=(28, 28, 1)) # KerasTensor of shape (None, 28, 28, 1) +x = Dense(32)(inputs) +model = Model(inputs=inputs, outputs=x) + +# Multiple inputs +i1 = Input(shape=(10,), name='input_a') +i2 = Input(shape=(10,), name='input_b') +merged = Add()([i1, i2]) +model = Model(inputs=[i1, i2], outputs=merged) +# forward expects [array_a, array_b] +``` + +## References + +- Keras Functional API Guide: **Input()**. [Keras.io](https://keras.io/api/models/model/#functional-api) diff --git a/docs/layers/core/merging.md b/docs/layers/core/merging.md new file mode 100644 index 0000000..f9abb39 --- /dev/null +++ b/docs/layers/core/merging.md @@ -0,0 +1,125 @@ +# Merge Layers: Add, Concatenate, Multiply, Average, Maximum, Minimum + +## Theory + +Merge layers combine multiple input tensors into a single output tensor. They are essential for building non-linear architectures like ResNets (skip connections), Inception modules, and multi-branch networks. Every merge layer takes a **list of tensors** as input. + +### Operations + +| Layer | Operation | Gradient | +|---|---|---| +| `Add` | $y = \sum_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y}$ (same for all) | +| `Multiply` | $y = \prod_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \prod_{j \ne i} x_j$ | +| `Average` | $y = \frac{1}{N} \sum_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{1}{N} \frac{\partial L}{\partial y}$ | +| `Maximum` | $y = \max_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \mathbf{1}_{x_i = y}$ | +| `Minimum` | $y = \min_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \mathbf{1}_{x_i = y}$ | +| `Concatenate` | $y = [x_1, x_2, \dots, x_N]$ along axis $a$ | Split $y$ gradient back along $a$ | + +For `Maximum`/`Minimum`, the indicator function $\mathbf{1}_{x_i = y}$ passes the gradient only to the input(s) that achieved the extreme value — this is known as **argmax routing** in gradient computation. + +## Implementation Guide + +### File: `neutro/layers/core/merging.py` + +### `Add` — line 4 + +```python +class Add(Layer): + def forward(self, inputs, training=False): + self.input_lengths = len(inputs) + return sum(inputs) + + def backward(self, grad_output): + return [grad_output for _ in range(self.input_lengths)] +``` + +- `sum(inputs)` works element-wise across the list. +- The gradient is **broadcast unchanged** to every input — the sum's Jacobian w.r.t. each input is the identity. + +### `Concatenate` — line 37 + +```python +class Concatenate(Layer): + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def compute_output_shape(self, input_shape): + out_shape = list(input_shape[0]) + concat_dim = 0 + for shape in input_shape: + dim = shape[self.axis] + if dim is None: # Handle symbolic None (batch dim) + concat_dim = None + break + concat_dim += dim + out_shape[self.axis] = concat_dim + return tuple(out_shape) + + def forward(self, inputs, training=False): + self.input_shapes = [i.shape for i in inputs] + return np.concatenate(inputs, axis=self.axis) + + def backward(self, grad_output): + indices = np.cumsum([s[self.axis] for s in self.input_shapes])[:-1] + return np.split(grad_output, indices, axis=self.axis) +``` + +- `compute_output_shape` correctly handles symbolic `None` dimensions (e.g., batch size). +- The backward uses `np.split` to reverse the concatenation along the same axis. + +### `Multiply` — line 74 + +```python +class Multiply(Layer): + def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res *= inputs[i] + return res + + def backward(self, grad_output): + grads = [] + for i in range(len(self.inputs)): + g = grad_output.copy() + for j in range(len(self.inputs)): + if i == j: continue + g *= self.inputs[j] # Product of all inputs except the i-th + grads.append(g) + return grads +``` + +- For each input $i$, the gradient is $\frac{\partial L}{\partial y} \odot \prod_{j \ne i} x_j$. +- `self.inputs` is cached during forward for use in backward (important for shared layer state restoration). + +### `Average`, `Maximum`, `Minimum` — lines 127-200 + +These follow the same pattern. `Maximum` and `Minimum` use `np.maximum` / `np.minimum` in forward and mask-based gradient routing in backward. + +### Shared Layer Compatibility + +All merge layers store intermediate state (`input_lengths`, `input_shapes`, `inputs`) on `self` during `forward`. For shared merge layers used multiple times in a graph, the `Model` class uses `_capture_layer_state` / `_restore_layer_state` (recursive, covering sublayers) to save and restore this state per node. + +## Usage Example + +```python +from neutro.layers import Input, Dense, Add, Concatenate +from neutro.models import Model + +# Skip connection (Add) +inp = Input(shape=(32,)) +x = Dense(32, activation='relu')(inp) +skip = Dense(32)(x) +out = Add()([x, skip]) # Two branches merged + +# Multi-branch concatenation +i1 = Input(shape=(10,)) +i2 = Input(shape=(20,)) +merged = Concatenate(axis=-1)([i1, i2]) # Output shape: (None, 30) +``` + +## References + +- He, K., Zhang, X., Ren, S., & Sun, J. (2016). **Deep Residual Learning for Image Recognition** — skip connections via Add. *CVPR*. [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) +- Szegedy, C., et al. (2015). **Going Deeper with Convolutions** — concatenated multi-branch modules. *CVPR*. [arXiv:1409.4842](https://arxiv.org/abs/1409.4842) diff --git a/docs/layers/embedding/embedding.md b/docs/layers/embedding/embedding.md new file mode 100644 index 0000000..4d0529d --- /dev/null +++ b/docs/layers/embedding/embedding.md @@ -0,0 +1,43 @@ +# Embedding Layers + +## Theory + +### Token Embedding — `neutro/layers/embedding/embedding.py` + +An embedding layer maps discrete tokens (integers) to dense vectors: + +$$x_i = W[\text{token}_i]$$ + +Where $W \in \mathbb{R}^{V \times D}$ is a learnable matrix, $V$ is the vocabulary size, and $D$ is the embedding dimension. The forward pass is a simple lookup: + +```python +def forward(self, inputs): + return self.params['W'][inputs] # (batch, seq_len, embed_dim) +``` + +The backward pass uses `np.add.at` to accumulate gradients back to the embedding matrix: + +```python +def backward(self, grad_output): + self.grads['W'] = np.zeros_like(self.params['W']) + np.add.at(self.grads['W'], self.inputs, grad_output) + return grad_output +``` + +### TimeEmbedding — `neutro/layers/embedding/time_embedding.py` + +Projects scalar timesteps (e.g., diffusion timesteps) into a high-dimensional space using sinusoidal encoding followed by a learnable MLP projection. + +## Usage Example + +```python +from neutro.layers import Embedding + +emb = Embedding(vocab_size=10000, embed_dim=512) +tokens = np.array([[1, 5, 23, 42]]) # (batch, seq_len) +x = emb(tokens) # (1, 4, 512) +``` + +## References + +- Mikolov, T., et al. (2013). **Efficient Estimation of Word Representations in Vector Space**. [arXiv:1301.3781](https://arxiv.org/abs/1301.3781) diff --git a/docs/layers/normalization/normalization.md b/docs/layers/normalization/normalization.md new file mode 100644 index 0000000..1a989f9 --- /dev/null +++ b/docs/layers/normalization/normalization.md @@ -0,0 +1,78 @@ +# Normalization Layers + +## Theory + +Normalization layers stabilize training by controlling the distribution of activations. `neutro` implements four variants. + +### Layer Normalization — `neutro/layers/normalization/layernorm.py` + +Normalizes across the feature dimension for each sample independently: + +$$\mu = \frac{1}{H} \sum_{i=1}^H x_i, \quad \sigma = \sqrt{\frac{1}{H} \sum_{i=1}^H (x_i - \mu)^2 + \epsilon}$$ + +$$\hat{x} = \frac{x - \mu}{\sigma}, \quad y = \gamma \hat{x} + \beta$$ + +Used in Transformers (GPT, BERT, Llama). Independent of batch size. + +### Batch Normalization — `neutro/layers/normalization/batchnorm.py` + +Normalizes across the batch dimension for each feature: + +$$\mu_{\mathcal{B}} = \frac{1}{m} \sum x_i, \quad \sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum (x_i - \mu_{\mathcal{B}})^2$$ + +$$\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \quad y_i = \gamma \hat{x}_i + \beta$$ + +Tracks running mean/variance for inference. Used in CNNs. + +### RMS Norm — `neutro/layers/normalization/rmsnorm.py` + +Root Mean Square Normalization — a simplified LayerNorm without mean centering: + +$$\text{RMS}(x) = \sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2 + \epsilon}, \quad y = \frac{x}{\text{RMS}(x)} \cdot \gamma$$ + +Used in Llama and modern LLMs for efficiency. + +### Group Normalization — `neutro/layers/normalization/groupnorm.py` + +Divides channels into groups and normalizes within each group: + +$$\mu_g = \frac{1}{|\mathcal{G}_g|} \sum_{i \in \mathcal{G}_g} x_i, \quad \sigma_g^2 = \frac{1}{|\mathcal{G}_g|} \sum_{i \in \mathcal{G}_g} (x_i - \mu_g)^2$$ + +Used in vision models when batch size is small (e.g., video, medical imaging). + +## Implementation Guide + +All normalization layers share a common pattern: + +| Method | Behavior | +|---|---| +| `build(input_shape)` | Allocates `gamma` (scale) and `beta` (shift) parameters. Shape matches the feature dimension. | +| `forward(x)` | Computes mean/variance, normalizes, scales, shifts. | +| `backward(grad)` | Backpropagates through normalization using the stored mean/variance. | + +For LayerNorm: + +```python +def forward(self, x): + self.mean = np.mean(x, axis=-1, keepdims=True) + self.var = np.var(x, axis=-1, keepdims=True) + self.x_hat = (x - self.mean) / np.sqrt(self.var + self.eps) + return self.gamma * self.x_hat + self.beta +``` + +## Usage Example + +```python +from neutro.layers import LayerNormalization + +ln = LayerNormalization(epsilon=1e-6) +x = np.random.randn(4, 16, 64) # (batch, seq, features) +y = ln(x) # Normalized along last axis, same shape +``` + +## References + +- Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). **Layer Normalization**. [arXiv:1607.06450](https://arxiv.org/abs/1607.06450) +- Ioffe, S., & Szegedy, C. (2015). **Batch Normalization**. [arXiv:1502.03167](https://arxiv.org/abs/1502.03167) +- Zhang, B., & Sennrich, R. (2019). **Root Mean Square Layer Normalization**. [arXiv:1910.07467](https://arxiv.org/abs/1910.07467) +- Wu, Y., & He, K. (2018). **Group Normalization**. [arXiv:1803.08494](https://arxiv.org/abs/1803.08494) diff --git a/docs/layers/pooling/pooling.md b/docs/layers/pooling/pooling.md new file mode 100644 index 0000000..d15d5b5 --- /dev/null +++ b/docs/layers/pooling/pooling.md @@ -0,0 +1,62 @@ +# Pooling Layers + +## Theory + +Pooling layers reduce the spatial dimensions of feature maps, providing downsampling and local translation invariance. + +### MaxPooling2D — `neutro/layers/pooling/maxpooling2d.py` + +Slides a window over the input and takes the maximum value in each window: + +$$y_{i,j,k} = \max_{p=1..P, q=1..Q} x_{i \cdot s + p,\; j \cdot s + q,\; k}$$ + +- **Forward**: `np.max` over sliding windows. +- **Backward**: Routes gradient to the position that was the maximum (argmax routing). + +### Global Pooling — `neutro/layers/pooling/global_pooling.py` + +Reduces each feature map to a single value: + +- **GlobalAveragePooling2D**: $y_k = \frac{1}{H \cdot W} \sum_{i,j} x_{i,j,k}$ +- **GlobalMaxPooling2D**: $y_k = \max_{i,j} x_{i,j,k}$ + +Used before the final Dense layer in CNNs to replace Flatten (fewer parameters, no overfitting). + +### UpSampling2D — `neutro/layers/pooling/upsampling2d.py$ + +Increases spatial dimensions by repeating rows and columns (nearest-neighbor upsampling): + +$$y_{i \cdot f + p,\; j \cdot f + q,\; k} = x_{i,j,k}$$ + +- Used in decoder architectures (UNet, GANs). +- Backward: sums the gradient back into the original positions. + +## Implementation Guide + +All pooling layers are in `neutro/layers/pooling/`. MaxPooling2D uses `im2col` from `conv_utils.py` to unroll windows, then applies `np.max` and `np.argmax` for efficient forward/backward. + +```python +# MaxPooling2D key pattern +cols = im2col(x, self.pool_size, self.strides, padding='valid') +max_idx = np.argmax(cols, axis=0) +output = cols[max_idx, np.arange(cols.shape[1])] +# Reshape to output spatial dimensions +``` + +## Usage Example + +```python +from neutro.layers import MaxPooling2D, GlobalAveragePooling2D + +pool = MaxPooling2D(pool_size=(2, 2)) +x = np.random.randn(2, 28, 28, 16) +y = pool(x) # shape (2, 14, 14, 16) + +gap = GlobalAveragePooling2D() +z = gap(y) # shape (2, 16) +``` + +## References + +- Springenberg, J. T., et al. (2014). **Striving for Simplicity: The All Convolutional Net**. [arXiv:1412.6806](https://arxiv.org/abs/1412.6806) +- Lin, M., Chen, Q., & Yan, S. (2013). **Network In Network**. [arXiv:1312.4400](https://arxiv.org/abs/1312.4400) diff --git a/docs/layers/recurrent/recurrent.md b/docs/layers/recurrent/recurrent.md new file mode 100644 index 0000000..6bc1c96 --- /dev/null +++ b/docs/layers/recurrent/recurrent.md @@ -0,0 +1,84 @@ +# Recurrent Layers: SimpleRNN, LSTM, GRU + +## Theory + +Recurrent Neural Networks process sequences by maintaining a hidden state that is updated at each time step. The key challenge is the **vanishing gradient problem** — gradients diminish exponentially over long sequences. + +### SimpleRNN — `neutro/layers/recurrent/simple_rnn.py` + +$$h_t = \tanh(W_h \cdot h_{t-1} + W_x \cdot x_t + b)$$ + +Simple RNN suffers from vanishing gradients and cannot capture long-range dependencies. + +### LSTM — `neutro/layers/recurrent/lstm.py` + +Long Short-Term Memory introduces a gating mechanism with a cell state: + +$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \quad \text{(forget gate)}$$ +$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \quad \text{(input gate)}$$ +$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \quad \text{(candidate)}$$ +$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad \text{(cell update)}$$ +$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \quad \text{(output gate)}$$ +$$h_t = o_t \odot \tanh(C_t) \quad \text{(hidden state)}$$ + +The cell state $C_t$ can carry information over long distances with minimal gradient decay. + +### GRU — `neutro/layers/recurrent/gru.py` + +Gated Recurrent Unit simplifies LSTM by merging the cell state and hidden state: + +$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) \quad \text{(update gate)}$$ +$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) \quad \text{(reset gate)}$$ +$$\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t])$$ +$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$ + +GRU has fewer parameters than LSTM and often performs comparably. + +## Implementation Guide + +All recurrent layers are in `neutro/layers/recurrent/`. They share a common pattern: + +```python +def forward(self, inputs, training=False): + batch_size, seq_len, input_dim = inputs.shape + # Initialize hidden state + h = np.zeros((batch_size, self.units)) + self.h_states = [] + for t in range(seq_len): + x_t = inputs[:, t, :] + h = self._step(x_t, h) # One RNN step + self.h_states.append(h) + return np.stack(self.h_states, axis=1) +``` + +The backward pass (Backpropagation Through Time, BPTT) reverses the loop: + +```python +def backward(self, grad_output): + for t in reversed(range(self.seq_len)): + grad_h = grad_output[:, t, :] + grad_h_next + # Backprop through one step + ... + grad_h_next = grad_from_h + return grad_x +``` + +Weight concatenation optimization (LSTM, line 53): weights for all four gates are stored as a single matrix to optimize the dot product: `W = np.concatenate([W_f, W_i, W_C, W_o])`. + +## Usage Example + +```python +from neutro.layers import LSTM, GRU + +lstm = LSTM(units=128, return_sequences=True) +x = np.random.randn(4, 32, 64) # (batch, seq, features) +y = lstm(x) # (batch, seq, 128) + +gru = GRU(units=64, return_sequences=False) +z = gru(x) # (batch, 64) +``` + +## References + +- Hochreiter, S., & Schmidhuber, J. (1997). **Long Short-Term Memory**. *Neural Computation*. +- Chung, J., et al. (2014). **Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling**. [arXiv:1412.3555](https://arxiv.org/abs/1412.3555) diff --git a/docs/layers/transformer/transformer_block.md b/docs/layers/transformer/transformer_block.md new file mode 100644 index 0000000..ee2caf2 --- /dev/null +++ b/docs/layers/transformer/transformer_block.md @@ -0,0 +1,89 @@ +# Transformer Block + +## Theory + +The Transformer block is the fundamental building block of modern LLMs. It combines multi-head attention with a feed-forward network, residual connections, and layer normalization. + +### Pre-Norm Architecture + +$$\text{output} = x + \text{FFN}(\text{LN}(x + \text{Attention}(\text{LN}(x))))$$ + +Each sub-layer has a residual connection (`x + sublayer(x)`), which helps gradient flow during backpropagation. + +### Post-Norm Architecture (original Transformer) + +$$\text{output} = \text{LN}(x + \text{FFN}(\text{LN}(x + \text{Attention}(x))))$$ + +## Implementation Guide + +### File: `neutro/layers/transformer/transformer_block.py` + +### `__init__` — line 11 + +```python +class TransformerBlock(Layer): + def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, + causal=False, use_flash=False, pre_norm=False, **kwargs): +``` + +- `embed_dim`: model dimension (e.g., 768 for GPT-2 small). +- `num_heads`: number of attention heads (must divide `embed_dim`). +- `ff_dim`: feed-forward hidden dimension (typically 4× `embed_dim`). +- `causal`: if True, creates a causal attention mask (for autoregressive generation). +- `use_flash`: if True, uses `FlashAttention` instead of standard `MultiHeadAttention`. +- `pre_norm`: if True, uses Pre-Norm (modern); if False, uses Post-Norm (original). + +### Forward pass — line 42 + +For Pre-Norm: + +```python +norm1 = self.layernorm1(inputs, training) +attn_output = self.att(norm1, mask=mask, training=training) +h = inputs + self.dropout1(attn_output, training=training) + +norm2 = self.layernorm2(h, training=training) +ffn_output = self.ffn[1](self.ffn[0](norm2, training=training), training=training) +return h + self.dropout2(ffn_output, training=training) +``` + +The block contains 7 sublayers: `att`, `layernorm1`, `layernorm2`, `dropout1`, `dropout2`, and two Dense layers in `ffn`. + +### Backward pass — line 82 + +The backward manually routes gradients through the skip connections: + +```python +def backward(self, grad_output): + grad_ffn_path = self.dropout2.backward(grad_output) + grad_ffn = self.ffn[1].backward(grad_ffn_path) + grad_ffn = self.ffn[0].backward(grad_ffn) + grad_norm2 = self.layernorm2.backward(grad_ffn) + grad_h = grad_output + grad_norm2 # Skip connection + + grad_attn_path = self.dropout1.backward(grad_h) + grad_attn = self.att.backward(grad_attn_path) + grad_norm1 = self.layernorm1.backward(grad_attn) + return grad_h + grad_norm1 # Skip connection +``` + +### Sub-layers + +The block exposes its sublayers via the `sublayers` property, which is critical for: +- **Optimizer**: `_get_all_layers` finds them for parameter updates. +- **Shared layer state**: `_capture_layer_state` saves their internal state (inputs, z, etc.) per node. + +## Usage Example + +```python +from neutro.layers.transformer import TransformerBlock + +block = TransformerBlock(embed_dim=512, num_heads=8, ff_dim=2048, pre_norm=True) +x = np.random.randn(2, 16, 512) # (batch, seq, embed) +y = block(x) # Same shape +``` + +## References + +- Vaswani, A., et al. (2017). **Attention Is All You Need**. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) +- Pre-Norm: GPT-2 / Llama architecture variant. diff --git a/docs/losses/losses.md b/docs/losses/losses.md new file mode 100644 index 0000000..deaaff2 --- /dev/null +++ b/docs/losses/losses.md @@ -0,0 +1,58 @@ +# Loss Functions + +## Theory + +A loss function $L(y_{\text{true}}, y_{\text{pred}})$ measures the discrepancy between predicted and target values. Training minimizes this loss via gradient descent. Every loss in `neutro` implements two methods: + +- `forward(y_true, y_pred) → scalar`: compute the loss value. +- `gradient(y_true, y_pred) → ndarray`: compute $\partial L / \partial y_{\text{pred}}$, the gradient w.r.t. the prediction. + +## Implementation Guide + +### File: `neutro/losses/base.py` + +```python +class Loss: + def forward(self, y_true, y_pred): raise NotImplementedError + def gradient(self, y_true, y_pred): raise NotImplementedError +``` + +### Mean Squared Error — `neutro/losses/mse.py` + +$$L = \frac{1}{N} \sum_{i=1}^N (y_{\text{pred}} - y_{\text{true}})^2$$ + +$$\frac{\partial L}{\partial y_{\text{pred}}} = \frac{2}{N} (y_{\text{pred}} - y_{\text{true}})$$ + +### Categorical Crossentropy — `neutro/losses/categorical_crossentropy.py` + +$$L = -\sum_i y_{\text{true},i} \log(y_{\text{pred},i})$$ + +$$\frac{\partial L}{\partial y_{\text{pred}}} = -\frac{y_{\text{true}}}{y_{\text{pred}}}$$ + +Used with one-hot encoded targets and Softmax output. + +### Sparse Categorical Crossentropy — `neutro/losses/sparse_categorical_crossentropy.py` + +Same as categorical crossentropy but `y_true` is integer-encoded (shape `(batch,)`). The loss converts integers to one-hot internally. + +### VAE Loss — `neutro/losses/vae_loss.py` + +$$L = L_{\text{recon}} + \beta \cdot L_{\text{KL}}$$ + +Combines a reconstruction loss (e.g., MSE or binary crossentropy) with a KL divergence term that regularizes the latent space. + +## Usage Example + +```python +from neutro.losses import CategoricalCrossentropy + +loss_fn = CategoricalCrossentropy() +y_true = np.array([[0, 1, 0]]) +y_pred = np.array([[0.1, 0.8, 0.1]]) +l = loss_fn(y_true, y_pred) # scalar +grad = loss_fn.gradient(y_true, y_pred) # same shape as y_pred +``` + +## References + +- Goodfellow, I., Bengio, Y., & Courville, A. (2016). **Deep Learning**. Chapter 6: Loss Functions. diff --git a/docs/metrics/metrics.md b/docs/metrics/metrics.md new file mode 100644 index 0000000..aab7400 --- /dev/null +++ b/docs/metrics/metrics.md @@ -0,0 +1,52 @@ +# Metrics + +## Theory + +Metrics quantify model performance during training and evaluation. Unlike losses, metrics are not used for gradient computation — they are only reported for monitoring. + +Every metric in `neutro` implements: +- `__call__(y_true, y_pred) → scalar`: compute the metric value. + +## Implementation Guide + +### File: `neutro/metrics/base.py` + +### Accuracy — `neutro/metrics/accuracy.py` + +$$\text{Accuracy} = \frac{1}{N} \sum_{i=1}^N \mathbf{1}(\arg\max y_{\text{pred},i} = \arg\max y_{\text{true},i})$$ + +Works with one-hot targets. + +### Sparse Accuracy — `neutro/metrics/sparse_accuracy.py` + +Same as Accuracy but `y_true` is integer-encoded. + +### Precision — `neutro/metrics/precision.py` + +$$\text{Precision} = \frac{TP}{TP + FP}$$ + +### Recall — `neutro/metrics/recall.py` + +$$\text{Recall} = \frac{TP}{TP + FN}$$ + +### F1 Score — `neutro/metrics/f1_score.py` + +$$F_1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$ + +## Usage Example + +```python +from neutro.metrics import Accuracy + +acc = Accuracy() +y_true = np.array([[0, 1, 0], [1, 0, 0]]) +y_pred = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]]) +acc_value = acc(y_true, y_pred) # scalar + +# In model +model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) +``` + +## References + +- Keras Metrics API. [Keras.io](https://keras.io/api/metrics/) diff --git a/docs/models/language_models.md b/docs/models/language_models.md new file mode 100644 index 0000000..32ab855 --- /dev/null +++ b/docs/models/language_models.md @@ -0,0 +1,127 @@ +# Language Models + +## GPT — `neutro/models/language/gpt.py` + +### GPT-1 (line 5) + +$$P(w_t | w_{`, etc.). +- `encode` and `decode` methods consistent with the Keras style. + +### TikToken Compatibility — `neutro/tokenizers/tiktoken_compat.py` + +Provides a `TikTokenTokenizer` that can load OpenAI's TikToken BPE ranks (from `cl100k_base`, `gpt2`, etc.) and use them directly. This bridges `neutro` with OpenAI's tokenization system. + +```python +def load_tiktoken_bpe(bpe_file): + # Reads the BPE merge file in TikToken format + with open(bpe_file, 'rb') as f: + return pickle.load(f) +``` + +## Usage Example + +```python +from neutro.tokenizers import RegexTokenizer, TikTokenTokenizer + +tokenizer = RegexTokenizer() +tokenizer.train(["hello world", "hello there"]) +tokens = tokenizer.encode("hello world") +text = tokenizer.decode(tokens) +``` + +## References + +- Sennrich, R., Haddow, B., & Birch, A. (2016). **Neural Machine Translation of Rare Words with Subword Units**. [arXiv:1508.07909](https://arxiv.org/abs/1508.07909) +- GPT-2 Tokenizer: Radford, A., et al. (2019). **Language Models are Unsupervised Multitask Learners**. diff --git a/docs/utils/utils.md b/docs/utils/utils.md new file mode 100644 index 0000000..f47b2a6 --- /dev/null +++ b/docs/utils/utils.md @@ -0,0 +1,83 @@ +# Utilities + +## conv_utils — `neutro/utils/conv_utils.py` + +### im2col and col2im + +These functions implement the image-to-column transformation that converts convolution into matrix multiplication: + +- **im2col**: Unrolls each filter-sized patch of the input into a column of a matrix. The output matrix has shape `(kernel_size * channels, output_spatial_size)`. +- **col2im**: The inverse operation — redistributes gradients from the column matrix back to the input volume shape. + +```python +def im2col(x, kernel_size, strides, padding='valid'): + # Fancy indexing to extract sliding windows + ... + +def col2im(grad_cols, input_shape, kernel_size, strides): + # Accumulate gradients back to input positions + ... +``` + +Used by `Conv2D` and `MaxPooling2D` for efficient forward/backward computation. + +## rope_utils — `neutro/utils/rope_utils.py$ + +### Rotary Position Embedding (RoPE) + +RoPE encodes position information by rotating query and key vectors in attention: + +$$\text{RoPE}(x, m) = x \cdot \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix}$$ + +```python +def precompute_freqs_cis(dim, max_seq_len, base=10000.0): + freqs = 1.0 / (base ** (np.arange(0, dim, 2) / dim)) + t = np.arange(max_seq_len) + return np.exp(1j * np.outer(t, freqs)) +``` + +- No learned parameters — positions are encoded via rotation. +- Used in Llama, GPT-NeoX, and many modern LLMs. + +## diffusion_utils — `neutro/utils/diffusion_utils.py` + +Implements the forward diffusion process (adding noise) for DDPM: + +```python +class GaussianDiffusion: + def q_sample(self, x_start, t, noise=None): + # q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I) + sqrt_alpha_bar = np.sqrt(self.alphas_cumprod[t]) + sqrt_one_minus = np.sqrt(1 - self.alphas_cumprod[t]) + return sqrt_alpha_bar * x_start + sqrt_one_minus * noise + + def p_sample(self, model, x_t, t): + # Reverse step: denoise x_t using the model + predicted_noise = model(x_t, t) + ... +``` + +## visualization — `neutro/utils/visualization.py` + +Provides `plot_attention_weights` for visualizing attention patterns as heatmaps: + +```python +def plot_attention_weights(attention_weights, tokens, layer_name=None): + # Matplotlib heatmap of attention scores + ... +``` + +## Usage Example + +```python +from neutro.utils.rope_utils import precompute_freqs_cis +from neutro.utils.conv_utils import im2col + +freqs = precompute_freqs_cis(dim=64, max_seq_len=512) +cols = im2col(x, kernel_size=(3, 3), strides=1, padding='same') +``` + +## References + +- Su, J., et al. (2021). **RoFormer: Enhanced Transformer with Rotary Position Embedding**. [arXiv:2104.09864](https://arxiv.org/abs/2104.09864) +- Ho, J., Jain, A., & Abbeel, P. (2020). **Denoising Diffusion Probabilistic Models**. [arXiv:2006.11239](https://arxiv.org/abs/2006.11239) diff --git a/examples/README.md b/examples/README.md index d124087..dc021ab 100644 --- a/examples/README.md +++ b/examples/README.md @@ -15,7 +15,20 @@ A classic convolutional neural network for classifying handwritten digits from t python3 examples/mnist_cnn.py ``` -## 2. WikiText-2 Transformer LLM +## 2. MNIST Functional Residual Model +Demonstrates the **Functional API** for building non-linear architectures like ResNets. + +**Features:** +- Keras-style Functional API (`Input`, `Model(inputs, outputs)`). +- Skip connections (residual paths) using the `Add` layer. +- Automatic backpropagation through complex Directed Acyclic Graphs (DAG). + +**Run:** +```bash +python3 examples/mnist_functional_residual.py +``` + +## 3. WikiText-2 Transformer LLM A character-level language model based on the Transformer architecture, trained on the WikiText-2 dataset. **Features:** diff --git a/examples/mnist_functional_residual.py b/examples/mnist_functional_residual.py new file mode 100644 index 0000000..03b287d --- /dev/null +++ b/examples/mnist_functional_residual.py @@ -0,0 +1,89 @@ +import numpy as np +import os +import sys + +# Add the project root to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from neutro.models import Model +from neutro.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Add, ReLU, Softmax +from neutro.utils.data_utils import load_mnist +from neutro.preprocessing.image import ImageDataGenerator +from neutro.optimizers import Adam + +def train_mnist_residual(): + """ + Trains a ResNet-style model on MNIST using the Functional API. + This demonstrates non-linear connectivity (skip connections). + """ + print("Loading MNIST...") + (x_train, y_train), (x_test, y_test) = load_mnist() + + # Preprocess + x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') + x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') + + # One-hot encode labels + y_train_cat = np.eye(10)[y_train] + y_test_cat = np.eye(10)[y_test] + + # --- Build model using Functional API --- + inputs = Input(shape=(28, 28, 1)) + + # Initial block + x = Conv2D(32, (3, 3), padding='same')(inputs) + x = ReLU()(x) + x = MaxPooling2D((2, 2))(x) + + # Residual Block + # Shortcut path + shortcut = x + + # Residual path + x = Conv2D(32, (3, 3), padding='same')(x) + x = ReLU()(x) + x = Conv2D(32, (3, 3), padding='same')(x) + + # Merge paths + x = Add()([x, shortcut]) + x = ReLU()(x) + + # Classification Head + x = MaxPooling2D((2, 2))(x) + x = Flatten()(x) + x = Dense(128)(x) + x = ReLU()(x) + x = Dropout(0.5)(x) + + outputs = Dense(10)(x) + outputs = Softmax()(outputs) + + model = Model(inputs=inputs, outputs=outputs) + # ---------------------------------------- + + model.compile( + optimizer=Adam(learning_rate=0.001), + loss='categorical_crossentropy', + metrics=['accuracy'] + ) + + print("\nModel Summary:") + model.summary() + + print("\nStarting training (Subset of 1000 samples for demo)...") + datagen = ImageDataGenerator(rescale=1/255.0) + train_flow = datagen.flow(x_train[:1000], y_train_cat[:1000], batch_size=64) + + # Train for 5 epochs for demo purposes + model.fit( + train_flow, + epochs=5, + validation_data=(x_test[:100]/255.0, y_test_cat[:100]) + ) + + print("\nEvaluating on test set...") + results = model.evaluate(x_test[:100]/255.0, y_test_cat[:100]) + print(f"Test Results: {results}") + +if __name__ == "__main__": + train_mnist_residual() diff --git a/neutro/__init__.py b/neutro/__init__.py index 1326dba..bf02dd8 100644 --- a/neutro/__init__.py +++ b/neutro/__init__.py @@ -1,4 +1,5 @@ from . import layers +from .layers import Input from . import activations from . import initializers from . import losses diff --git a/neutro/engine/__init__.py b/neutro/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/neutro/engine/node.py b/neutro/engine/node.py new file mode 100644 index 0000000..7cad2f2 --- /dev/null +++ b/neutro/engine/node.py @@ -0,0 +1,38 @@ +import numpy as np + +class KerasTensor: + """ + Symbolic representation of a tensor in the functional API. + """ + def __init__(self, shape, node=None, name=None): + self.shape = shape + self.node = node # The node that produced this tensor + self.name = name + + def __repr__(self): + return f"KerasTensor(shape={self.shape}, name={self.name})" + +class Node: + """ + Represents a 'call' to a layer. + Connects input KerasTensors to output KerasTensors. + """ + def __init__(self, layer, input_tensors, output_tensors): + self.layer = layer + self.input_tensors = input_tensors + self.output_tensors = output_tensors + + # Register the node in the layer + if not hasattr(layer, '_inbound_nodes'): + layer._inbound_nodes = [] + layer._inbound_nodes.append(self) + + # Link output tensors to this node + if isinstance(output_tensors, list): + for t in output_tensors: + t.node = self + else: + output_tensors.node = self + + def __repr__(self): + return f"Node(layer={self.layer.name or self.layer.__class__.__name__})" diff --git a/neutro/layers/__init__.py b/neutro/layers/__init__.py index 6db5200..8bebd71 100644 --- a/neutro/layers/__init__.py +++ b/neutro/layers/__init__.py @@ -1,9 +1,11 @@ from .base import Layer +from .core.input_layer import Input, InputLayer from .core.dense import Dense from .core.dropout import Dropout from .core.flatten import Flatten from .core.activation import Activation, ReLU, Softmax, Sigmoid, Tanh from .core.moe import MoELayer +from .core.merging import Add, Concatenate, Multiply, Average, Maximum, Minimum from .convolutional.conv2d import Conv2D from .convolutional.conv1d import Conv1D from .pooling.maxpooling2d import MaxPooling2D diff --git a/neutro/layers/attention/flash_attention.py b/neutro/layers/attention/flash_attention.py index 5252f97..fa57835 100644 --- a/neutro/layers/attention/flash_attention.py +++ b/neutro/layers/attention/flash_attention.py @@ -15,8 +15,8 @@ class FlashAttention(Layer): dropout: Dropout probability. use_rope: Whether to use Rotary Positional Embeddings. """ - def __init__(self, num_heads, key_dim, block_size_r=64, block_size_c=64, dropout=0.0, use_rope=False): - super().__init__() + def __init__(self, num_heads, key_dim, block_size_r=64, block_size_c=64, dropout=0.0, use_rope=False, **kwargs): + super().__init__(**kwargs) self.num_heads = num_heads self.key_dim = key_dim self.head_dim = key_dim // num_heads @@ -35,6 +35,9 @@ def build(self, input_shape): self.params['Wo'] = np.random.randn(self.key_dim, self.embed_dim) * 0.02 super().build(input_shape) + def compute_output_shape(self, input_shape): + return input_shape + def forward(self, x, mask=None, training=False, kv_cache=None, layer_id=None): self.x = x self.mask = mask diff --git a/neutro/layers/attention/mla.py b/neutro/layers/attention/mla.py index d540d43..a5f17b8 100644 --- a/neutro/layers/attention/mla.py +++ b/neutro/layers/attention/mla.py @@ -45,6 +45,9 @@ def build(self, input_shape): super().build(input_shape) + def compute_output_shape(self, input_shape): + return input_shape + def forward(self, x, mask=None, training=False, kv_cache=None, layer_id=None): self.x = x batch_size, seq_len, _ = x.shape diff --git a/neutro/layers/base.py b/neutro/layers/base.py index 96a98c1..0c5d69f 100644 --- a/neutro/layers/base.py +++ b/neutro/layers/base.py @@ -9,6 +9,7 @@ def __init__(self, name=None, **kwargs): self.grads = {} self.input_shape = kwargs.get('input_shape') self.output_shape = None + self._inbound_nodes = [] def build(self, input_shape): self.input_shape = input_shape @@ -56,13 +57,46 @@ def compute_output_shape(self, input_shape): Computes the output shape of the layer. Should be overridden by subclasses. """ + if hasattr(self, 'output_shape') and self.output_shape is not None: + return self.output_shape return input_shape - raise NotImplementedError def backward(self, grad_output): raise NotImplementedError def __call__(self, inputs, *args, **kwargs): + from ..engine.node import KerasTensor, Node + + # Check if inputs are symbolic + is_symbolic = False + if isinstance(inputs, KerasTensor): + is_symbolic = True + elif isinstance(inputs, list) and any(isinstance(i, KerasTensor) for i in inputs): + is_symbolic = True + + if is_symbolic: + # Symbolic call (Functional API) + if isinstance(inputs, list): + input_shapes = [i.shape for i in inputs] + else: + input_shapes = inputs.shape + + if not self.built: + self.build(input_shapes) + + output_shape = self.compute_output_shape(input_shapes) + + # Create output tensor(s) + if isinstance(output_shape, list): + output_tensors = [KerasTensor(shape=s) for s in output_shape] + else: + output_tensors = KerasTensor(shape=output_shape) + + # Create node + Node(self, input_tensors=inputs, output_tensors=output_tensors) + return output_tensors + + # Eager call (Sequential or manual) if not self.built: if isinstance(inputs, list): self.build([i.shape for i in inputs]) diff --git a/neutro/layers/convolutional/conv1d.py b/neutro/layers/convolutional/conv1d.py index e8a84fc..f48eab8 100644 --- a/neutro/layers/convolutional/conv1d.py +++ b/neutro/layers/convolutional/conv1d.py @@ -17,8 +17,8 @@ class Conv1D(Layer): kernel_initializer: Initializer for the kernel weights matrix. bias_initializer: Initializer for the bias vector. """ - def __init__(self, filters, kernel_size, strides=1, padding='valid', activation=None, kernel_initializer='glorot_uniform', bias_initializer='zeros'): - super().__init__() + def __init__(self, filters, kernel_size, strides=1, padding='valid', activation=None, kernel_initializer='glorot_uniform', bias_initializer='zeros', **kwargs): + super().__init__(**kwargs) self.filters = filters self.kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size,) self.strides = strides if isinstance(strides, (tuple, list)) else (strides,) diff --git a/neutro/layers/core/dense.py b/neutro/layers/core/dense.py index 2aab7d1..3376bd4 100644 --- a/neutro/layers/core/dense.py +++ b/neutro/layers/core/dense.py @@ -13,6 +13,7 @@ def __init__(self, units, activation=None, use_bias=True, kernel_initializer='gl self.bias_initializer = get_initializer(bias_initializer) def build(self, input_shape): + # print(f"DEBUG: Dense.build input_shape={input_shape}") self.input_dim = input_shape[-1] self.params['W'] = self.kernel_initializer((self.input_dim, self.units)) if self.use_bias: diff --git a/neutro/layers/core/input_layer.py b/neutro/layers/core/input_layer.py new file mode 100644 index 0000000..6259997 --- /dev/null +++ b/neutro/layers/core/input_layer.py @@ -0,0 +1,51 @@ +from ..base import Layer +from ...engine.node import KerasTensor, Node + +class InputLayer(Layer): + """ + Layer to be used as an entry point into a Network (a graph of layers). + """ + def __init__(self, input_shape=None, name=None, **kwargs): + super().__init__(name=name, input_shape=input_shape, **kwargs) + if input_shape is not None: + self.build(input_shape) + + def build(self, input_shape): + self.input_shape = input_shape + # Add batch dimension if missing + if len(input_shape) > 0 and input_shape[0] is not None: + # We assume users might pass (28, 28, 1) or (None, 28, 28, 1) + # Keras usually expects input_shape to NOT include batch. + pass + self.built = True + + def forward(self, inputs, training=False): + return inputs + + def backward(self, grad_output): + return grad_output + +def Input(shape=None, name=None, **kwargs): + """ + Used to instantiate a Keras tensor. + """ + if shape is None: + raise ValueError("Please provide a shape for the Input.") + + # Ensure shape is a tuple and starts with None for batch + if not isinstance(shape, tuple): + shape = tuple(shape) + + # Keras style: if first element is not None, prepend None + if len(shape) == 0 or shape[0] is not None: + shape = (None,) + shape + + layer = InputLayer(input_shape=shape, name=name, **kwargs) + + # Create the symbolic output tensor + output_tensor = KerasTensor(shape=shape, name=name) + + # Create the node connecting layer to its output + Node(layer, input_tensors=[], output_tensors=output_tensor) + + return output_tensor diff --git a/neutro/layers/core/merging.py b/neutro/layers/core/merging.py index 008f8c3..6160428 100644 --- a/neutro/layers/core/merging.py +++ b/neutro/layers/core/merging.py @@ -23,6 +23,11 @@ def build(self, input_shape): self.output_shape = input_shape self.built = True + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + def forward(self, inputs, training=False): """ inputs: list of ndarrays @@ -45,18 +50,28 @@ def __init__(self, axis=-1, **kwargs): super().__init__(**kwargs) self.axis = axis - def build(self, input_shape): + def compute_output_shape(self, input_shape): if not isinstance(input_shape, list): - self.output_shape = input_shape - else: - # Calculate output shape based on concatenation axis - out_shape = list(input_shape[0]) - concat_dim = 0 - for shape in input_shape: - concat_dim += shape[self.axis] - out_shape[self.axis] = concat_dim - self.output_shape = tuple(out_shape) - super().build(input_shape) + return input_shape + + # Calculate output shape based on concatenation axis + out_shape = list(input_shape[0]) + concat_dim = 0 + for shape in input_shape: + # Handle None in shapes (symbolic) + dim = shape[self.axis] + if dim is None: + concat_dim = None + break + concat_dim += dim + + out_shape[self.axis] = concat_dim + return tuple(out_shape) + + def build(self, input_shape): + self.input_shape = input_shape + self.output_shape = self.compute_output_shape(input_shape) + self.built = True def forward(self, inputs, training=False): self.input_shapes = [i.shape for i in inputs] @@ -66,3 +81,127 @@ def backward(self, grad_output): # Split grad_output along the same axis indices = np.cumsum([s[self.axis] for s in self.input_shapes])[:-1] return np.split(grad_output, indices, axis=self.axis) + +class Multiply(Layer): + """ + Layer that multiplies (element-wise) a list of inputs. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + + def build(self, input_shape): + self.input_shape = input_shape + self.output_shape = self.compute_output_shape(input_shape) + self.built = True + + def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res *= inputs[i] + return res + + def backward(self, grad_output): + grads = [] + for i in range(len(self.inputs)): + # Grad for input i is product of all other inputs * grad_output + g = grad_output.copy() + for j in range(len(self.inputs)): + if i == j: continue + g *= self.inputs[j] + grads.append(g) + return grads + +class Average(Layer): + """ + Layer that computes the average (element-wise) of a list of inputs. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + + def build(self, input_shape): + self.input_shape = input_shape + self.output_shape = self.compute_output_shape(input_shape) + self.built = True + + def forward(self, inputs, training=False): + self.input_lengths = len(inputs) + return sum(inputs) / self.input_lengths + + def backward(self, grad_output): + return [grad_output / self.input_lengths for _ in range(self.input_lengths)] + +class Maximum(Layer): + """ + Layer that computes the maximum (element-wise) of a list of inputs. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + + def build(self, input_shape): + self.input_shape = input_shape + self.output_shape = self.compute_output_shape(input_shape) + self.built = True + + def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res = np.maximum(res, inputs[i]) + return res + + def backward(self, grad_output): + # Gradient goes to the input that was the maximum + max_val = self.forward(self.inputs) + grads = [] + for inp in self.inputs: + mask = (inp == max_val) + grads.append(grad_output * mask) + return grads + +class Minimum(Layer): + """ + Layer that computes the maximum (element-wise) of a list of inputs. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + + def build(self, input_shape): + self.input_shape = input_shape + self.output_shape = self.compute_output_shape(input_shape) + self.built = True + + def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res = np.minimum(res, inputs[i]) + return res + + def backward(self, grad_output): + min_val = self.forward(self.inputs) + grads = [] + for inp in self.inputs: + mask = (inp == min_val) + grads.append(grad_output * mask) + return grads diff --git a/neutro/layers/core/moe.py b/neutro/layers/core/moe.py index ffc4ee5..48ec398 100644 --- a/neutro/layers/core/moe.py +++ b/neutro/layers/core/moe.py @@ -36,6 +36,9 @@ def build(self, input_shape): super().build(input_shape) + def compute_output_shape(self, input_shape): + return input_shape + def forward(self, x, training=False): # x: (batch, seq_len, dim) or (batch, dim) self.x_shape = x.shape diff --git a/neutro/layers/core/reparameterization.py b/neutro/layers/core/reparameterization.py index e90f636..db12955 100644 --- a/neutro/layers/core/reparameterization.py +++ b/neutro/layers/core/reparameterization.py @@ -9,6 +9,11 @@ class Reparameterization(Layer): def __init__(self, **kwargs): super().__init__(**kwargs) + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + def forward(self, inputs, training=False): """ inputs: list of [z_mean, z_log_var] diff --git a/neutro/layers/embedding/time_embedding.py b/neutro/layers/embedding/time_embedding.py index 0340cc4..3b0a735 100644 --- a/neutro/layers/embedding/time_embedding.py +++ b/neutro/layers/embedding/time_embedding.py @@ -14,6 +14,9 @@ def __init__(self, dim, **kwargs): def build(self, input_shape): super().build(input_shape) + def compute_output_shape(self, input_shape): + return (input_shape[0], self.dim) + def forward(self, t, training=False): """ t: array of shape (batch_size,) or (batch_size, 1) diff --git a/neutro/layers/normalization/layernorm.py b/neutro/layers/normalization/layernorm.py index f5c8094..211a8d3 100644 --- a/neutro/layers/normalization/layernorm.py +++ b/neutro/layers/normalization/layernorm.py @@ -2,8 +2,8 @@ from ..base import Layer class LayerNormalization(Layer): - def __init__(self, epsilon=1e-6): - super().__init__() + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) self.epsilon = epsilon def build(self, input_shape): diff --git a/neutro/layers/recurrent/simple_rnn.py b/neutro/layers/recurrent/simple_rnn.py index e787316..4825025 100644 --- a/neutro/layers/recurrent/simple_rnn.py +++ b/neutro/layers/recurrent/simple_rnn.py @@ -3,8 +3,8 @@ from ...initializers import get as get_initializer class SimpleRNN(Layer): - def __init__(self, units, activation='tanh', return_sequences=False): - super().__init__() + def __init__(self, units, activation='tanh', return_sequences=False, **kwargs): + super().__init__(**kwargs) self.units = units self.return_sequences = return_sequences self.activation_name = activation diff --git a/neutro/layers/transformer/transformer_block.py b/neutro/layers/transformer/transformer_block.py index 802dbcc..01bdb31 100644 --- a/neutro/layers/transformer/transformer_block.py +++ b/neutro/layers/transformer/transformer_block.py @@ -36,6 +36,9 @@ def build(self, input_shape): self.dropout2.build(input_shape) super().build(input_shape) + def compute_output_shape(self, input_shape): + return input_shape + def forward(self, inputs, training=False, kv_cache=None, layer_id=None): self.inputs = inputs mask = None diff --git a/neutro/models/base_model.py b/neutro/models/base_model.py index 57e4539..34b441f 100644 --- a/neutro/models/base_model.py +++ b/neutro/models/base_model.py @@ -5,37 +5,138 @@ from .. import losses as losses_module from ..callbacks import History -class Model: - def __init__(self): +from ..layers.base import Layer + +class Model(Layer): + def __init__(self, inputs=None, outputs=None, name=None): + super().__init__(name=name) self.layers = [] self.optimizer = None self.loss_fn = None self.metrics = [] self.stop_training = False + + self.inputs = inputs + self.outputs = outputs + + if inputs is not None and outputs is not None: + self._init_graph(inputs, outputs) + + def _init_graph(self, inputs, outputs): + """ + Traverses the graph from outputs to inputs to discover all layers and nodes. + """ + from ..engine.node import Node + + self._nodes_by_depth = [] + self._layers = [] + + # Topological sort + visited_nodes = set() + nodes_ordered = [] + + def traverse(tensor): + if hasattr(tensor, 'node') and tensor.node: + node = tensor.node + if node not in visited_nodes: + visited_nodes.add(node) + # Recursive call for all input tensors of this node + if isinstance(node.input_tensors, list): + for t in node.input_tensors: + traverse(t) + else: + traverse(node.input_tensors) + nodes_ordered.append(node) + + if isinstance(outputs, list): + for o in outputs: + traverse(o) + else: + traverse(outputs) + + self._nodes_ordered = nodes_ordered + + # Collect all unique layers + for node in nodes_ordered: + if node.layer not in self.layers: + self.layers.append(node.layer) def compile(self, optimizer, loss, metrics=None): self.optimizer = optimizer self.loss_fn = losses_module.get(loss) self.metrics = [metrics_module.get(m) for m in (metrics or [])] - def _get_all_layers(self, layers=None): + def _get_all_layers(self, layers=None, visited=None): if layers is None: layers = self.layers + if visited is None: + visited = set() all_layers = [] for layer in layers: - all_layers.append(layer) - if hasattr(layer, 'sublayers'): - all_layers.extend(self._get_all_layers(layer.sublayers)) + l_id = id(layer) + if l_id not in visited: + visited.add(l_id) + all_layers.append(layer) + if hasattr(layer, 'sublayers'): + all_layers.extend(self._get_all_layers(layer.sublayers, visited)) return all_layers + _STATE_EXCLUDE = {'params', 'grads', 'built', 'input_shape', 'output_shape', + 'name', '_inbound_nodes', 'trainable'} + + @staticmethod + def _capture_layer_state(layer): + """Recursively capture state of a layer and all its sublayers. + Returns dict: {id(sublayer): {attr_name: value, ...}}""" + state = {} + stack = [layer] + visited = set() + while stack: + l = stack.pop() + l_id = id(l) + if l_id in visited: + continue + visited.add(l_id) + sub = {} + for k, v in l.__dict__.items(): + if k not in Model._STATE_EXCLUDE: + sub[k] = v + state[l_id] = sub + for sl in l.sublayers: + stack.append(sl) + return state + + @staticmethod + def _restore_layer_state(layer, state): + """Restore state captured by _capture_layer_state onto layer tree.""" + stack = [layer] + visited = set() + while stack: + l = stack.pop() + l_id = id(l) + if l_id in visited: + continue + visited.add(l_id) + if l_id in state: + for k, v in state[l_id].items(): + setattr(l, k, v) + for sl in l.sublayers: + stack.append(sl) + def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=None, callbacks=None): - if hasattr(x, '__iter__') and not isinstance(x, np.ndarray): + is_mimo_x = isinstance(x, list) + is_mimo_y = isinstance(y, list) + + use_generator = False + if not is_mimo_x and hasattr(x, '__iter__') and not isinstance(x, np.ndarray): use_generator = True n_samples = len(x) * x.batch_size if hasattr(x, 'batch_size') else len(x) else: - use_generator = False - n_samples = x.shape[0] + if is_mimo_x: + n_samples = x[0].shape[0] + else: + n_samples = x.shape[0] history = History() history.set_model(self) @@ -60,8 +161,17 @@ def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=Non else: indices = np.arange(n_samples) np.random.shuffle(indices) - x_shuffled = x[indices] - y_shuffled = y[indices] + + if is_mimo_x: + x_shuffled = [xi[indices] for xi in x] + else: + x_shuffled = x[indices] + + if is_mimo_y: + y_shuffled = [yi[indices] for yi in y] + else: + y_shuffled = y[indices] + num_batches = int(np.ceil(n_samples / batch_size)) total_seen = 0 @@ -73,10 +183,18 @@ def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=Non x_batch, y_batch = next(data_iter) else: start, end = i * batch_size, min((i + 1) * batch_size, n_samples) - x_batch = x_shuffled[start:end] - y_batch = y_shuffled[start:end] + + if is_mimo_x: + x_batch = [xi[start:end] for xi in x_shuffled] + else: + x_batch = x_shuffled[start:end] + + if is_mimo_y: + y_batch = [yi[start:end] for yi in y_shuffled] + else: + y_batch = y_shuffled[start:end] - batch_size_actual = len(x_batch) + batch_size_actual = x_batch[0].shape[0] if is_mimo_x else len(x_batch) total_seen += batch_size_actual for cb in all_callbacks: cb.on_batch_begin(i, logs) @@ -84,16 +202,28 @@ def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=Non # Forward output = self.forward(x_batch, training=True) - # Loss & Metrics - batch_loss = self.loss_fn(y_batch, output) + # Loss - sum across multiple outputs if applicable + is_mimo_out = isinstance(self.outputs, list) + if is_mimo_out: + batch_loss = sum(self.loss_fn(y_batch[j], output[j]) for j in range(len(self.outputs))) + else: + batch_loss = self.loss_fn(y_batch, output) epoch_loss += batch_loss * batch_size_actual for m in self.metrics: - epoch_metrics[m.get_name()] += m(y_batch, output) * batch_size_actual + try: + m_val = m(y_batch, output) + except (TypeError, ValueError): + m_val = m(y_batch[0], output[0]) if is_mimo_out else 0.0 + epoch_metrics[m.get_name()] += m_val * batch_size_actual # Backward - grad = self.loss_fn.gradient(y_batch, output) - self.backward(grad) + if is_mimo_out: + grads = [self.loss_fn.gradient(y_batch[j], output[j]) for j in range(len(self.outputs))] + self.backward(grads) + else: + grad = self.loss_fn.gradient(y_batch, output) + self.backward(grad) # Update all_trainable_layers = self._get_all_layers() @@ -120,16 +250,22 @@ def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=Non if validation_data: val_x, val_y = validation_data val_output = self.predict(val_x) - logs['val_loss'] = self.loss_fn(val_y, val_output) + is_mimo_val_out = isinstance(self.outputs, list) + if is_mimo_val_out: + logs['val_loss'] = sum(self.loss_fn(val_y[j], val_output[j]) for j in range(len(self.outputs))) + else: + logs['val_loss'] = self.loss_fn(val_y, val_output) for m in self.metrics: - logs[f'val_{m.get_name()}'] = m(val_y, val_output) + try: + m_val = m(val_y, val_output) + except (TypeError, ValueError): + m_val = m(val_y[0], val_output[0]) if is_mimo_val_out else 0.0 + logs[f'val_{m.get_name()}'] = m_val for cb in all_callbacks: cb.on_epoch_end(epoch, logs) if verbose: if verbose == 1: - # For verbose 1, we already have the progress bar, - # we just need to print validation results if they exist if validation_data: val_msg = f" - val_loss: {logs['val_loss']:.4f}" for m in self.metrics: @@ -137,7 +273,6 @@ def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=Non val_msg += f" - val_{name}: {logs[f'val_{name}']:.4f}" print(val_msg) else: - # Verbose 2 or other: print the full summary line msg = f"Epoch {epoch+1}/{epochs} - loss: {logs['loss']:.4f}" for m in self.metrics: name = m.get_name() @@ -153,6 +288,48 @@ def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=Non return history def forward(self, inputs, training=False, kv_cache=None): + if self.inputs is not None: + # Functional API forward pass + tensor_map = {} + + # Map input values + if isinstance(self.inputs, list): + for i, t in enumerate(self.inputs): + tensor_map[id(t)] = inputs[i] + else: + tensor_map[id(self.inputs)] = inputs + + from ..layers.core.input_layer import InputLayer + for node in self._nodes_ordered: + if isinstance(node.layer, InputLayer): + continue + + # Prepare inputs for this node + if isinstance(node.input_tensors, list): + node_inputs = [tensor_map.get(id(t)) for t in node.input_tensors] + else: + node_inputs = tensor_map.get(id(node.input_tensors)) + + output = node.layer.forward(node_inputs, training=training) + + # Capture state AFTER forward so it captures the current call's data + node.state = self._capture_layer_state(node.layer) + + # Store outputs + if isinstance(node.output_tensors, list): + for i, t in enumerate(node.output_tensors): + tensor_map[id(t)] = output[i] + else: + tensor_map[id(node.output_tensors)] = output + + + # Return model outputs + if isinstance(self.outputs, list): + return [tensor_map[id(o)] for o in self.outputs] + else: + return tensor_map[id(self.outputs)] + + # Sequential or Subclassed forward pass for i, layer in enumerate(self.layers): if kv_cache is not None and hasattr(layer, 'forward'): # Check if layer accepts kv_cache (Attention or Blocks) @@ -201,37 +378,157 @@ def generate(self, start_tokens, max_new_tokens, temperature=1.0): return generated def backward(self, grad): + if self.outputs is not None: + # Functional API backward pass + grad_map = {} + + # Map output gradients + if isinstance(self.outputs, list): + for i, t in enumerate(self.outputs): + grad_map[id(t)] = grad[i] + else: + grad_map[id(self.outputs)] = grad + + # Initialize accumulators for shared layers + layer_grads_accumulator = {} + + from ..layers.core.input_layer import InputLayer + for node in reversed(self._nodes_ordered): + if isinstance(node.layer, InputLayer): + continue + + # Get gradients for this node's outputs + if isinstance(node.output_tensors, list): + node_grad_outputs = [grad_map.get(id(t)) for t in node.output_tensors] + else: + node_grad_outputs = grad_map.get(id(node.output_tensors)) + + if node_grad_outputs is None: + continue + + # Restore state for this node recursively + if hasattr(node, 'state'): + self._restore_layer_state(node.layer, node.state) + + # Call layer.backward + # Temporarily clear layer.grads to capture only gradients for this node + original_grads = node.layer.grads + node.layer.grads = {} + + grad_inputs = node.layer.backward(node_grad_outputs) + + # Accumulate parameter gradients + l_id = id(node.layer) + if l_id not in layer_grads_accumulator: + layer_grads_accumulator[l_id] = {} + + for k, v in node.layer.grads.items(): + if k in layer_grads_accumulator[l_id]: + layer_grads_accumulator[l_id][k] += v + else: + layer_grads_accumulator[l_id][k] = v + + # Restore the combined gradients to the layer + node.layer.grads = layer_grads_accumulator[l_id] + + # Propagate gradients to inputs + if isinstance(node.input_tensors, list): + for i, t in enumerate(node.input_tensors): + t_id = id(t) + if t_id in grad_map: + grad_map[t_id] += grad_inputs[i] + else: + grad_map[t_id] = grad_inputs[i] + elif node.input_tensors is not None: + t_id = id(node.input_tensors) + if t_id in grad_map: + grad_map[t_id] += grad_inputs + else: + grad_map[t_id] = grad_inputs + + # Return gradients for model inputs + if isinstance(self.inputs, list): + return [grad_map.get(id(i)) for i in self.inputs] + else: + return grad_map.get(id(self.inputs)) + + # Sequential or Subclassed backward pass for layer in reversed(self.layers): grad = layer.backward(grad) return grad - def __call__(self, inputs, *args, **kwargs): - return self.forward(inputs, *args, **kwargs) + def compute_output_shape(self, input_shape): + if self.outputs is not None: + if isinstance(self.outputs, list): + return [o.shape for o in self.outputs] + return self.outputs.shape + + # For Sequential models + if not self.layers: + return input_shape + + curr_shape = input_shape + for layer in self.layers: + curr_shape = layer.compute_output_shape(curr_shape) + return curr_shape + + def build(self, input_shape): + if self.inputs is not None: + # Functional models are built during construction + self.input_shape = input_shape + self.built = True + return + + # For subclassed models with custom forward, skip sequential build + # They manage their own layer building internally + if type(self).forward is not Model.forward: + self.input_shape = input_shape + self.built = True + return + + # For Sequential models, build layers in sequence + self.input_shape = input_shape + curr_shape = input_shape + for layer in self.layers: + layer.build(curr_shape) + curr_shape = layer.compute_output_shape(curr_shape) + self.built = True def predict(self, x): return self.forward(x, training=False) def evaluate(self, x, y): + is_mimo_out = isinstance(self.outputs, list) output = self.predict(x) - loss = self.loss_fn(y, output) + if is_mimo_out: + loss = sum(self.loss_fn(y[j], output[j]) for j in range(len(self.outputs))) + else: + loss = self.loss_fn(y, output) results = {'loss': loss} for m in self.metrics: - results[m.get_name()] = m(y, output) + try: + m_val = m(y, output) + except (TypeError, ValueError): + m_val = m(y[0], output[0]) if is_mimo_out else 0.0 + results[m.get_name()] = m_val return results def summary(self): """ Prints a Keras-style summary of the model. """ - print("-" * 65) - print(f"{'Layer (type)':<25} {'Output Shape':<20} {'Param #':<10}") - print("=" * 65) + is_functional = self.inputs is not None + + print("-" * 85) + if is_functional: + print(f"{'Layer (type)':<25} {'Output Shape':<20} {'Param #':<10} {'Connected to':<25}") + else: + print(f"{'Layer (type)':<25} {'Output Shape':<20} {'Param #':<10}") + print("=" * 85) total_params = 0 trainable_params = 0 - # We need an initial input shape. If not built, we might not know. - # Sequential models usually have input_shape in the first layer. curr_shape = None if self.layers and self.layers[0].input_shape: curr_shape = self.layers[0].input_shape @@ -240,7 +537,13 @@ def summary(self): name = layer.name or layer.__class__.__name__ layer_type = layer.__class__.__name__ - if curr_shape is not None: + # Use layer's own input/output shapes if built + if layer.built: + try: + output_shape = layer.compute_output_shape(layer.input_shape) + except Exception: + output_shape = "multiple" + elif curr_shape is not None: try: output_shape = layer.compute_output_shape(curr_shape) curr_shape = output_shape @@ -254,13 +557,30 @@ def summary(self): if getattr(layer, 'trainable', True): trainable_params += params - print(f"{name + ' (' + layer_type + ')':<25} {str(output_shape):<20} {params:<10,}") + if is_functional: + # Find which layers this layer is connected to via _inbound_nodes + connected_to = [] + if hasattr(layer, '_inbound_nodes'): + for node in layer._inbound_nodes: + # Only consider nodes that belong to this model's execution path + if node in self._nodes_ordered: + if isinstance(node.input_tensors, list): + for t in node.input_tensors: + if t.node: + connected_to.append(t.node.layer.name or t.node.layer.__class__.__name__) + elif node.input_tensors and node.input_tensors.node: + connected_to.append(node.input_tensors.node.layer.name or node.input_tensors.node.layer.__class__.__name__) + + connected_str = ", ".join(connected_to) if connected_to else "" + print(f"{name + ' (' + layer_type + ')':<25} {str(output_shape):<20} {params:<10,} {connected_str:<25}") + else: + print(f"{name + ' (' + layer_type + ')':<25} {str(output_shape):<20} {params:<10,}") - print("=" * 65) + print("=" * 85) print(f"Total params: {total_params:,}") print(f"Trainable params: {trainable_params:,}") print(f"Non-trainable params: {total_params - trainable_params:,}") - print("-" * 65) + print("-" * 85) def save(self, filepath): joblib.dump(self, filepath) diff --git a/neutro/models/vision/unet.py b/neutro/models/vision/unet.py index 2ff47fb..22b05c5 100644 --- a/neutro/models/vision/unet.py +++ b/neutro/models/vision/unet.py @@ -42,6 +42,10 @@ def __init__(self, input_channels, base_filters=64, time_dim=256): self.final_conv ] + def build(self, input_shape): + self.input_shape = input_shape + self.built = True + def forward(self, inputs, training=False): """ inputs: [x, t] where x is image and t is timestep diff --git a/pyproject.toml b/pyproject.toml index 8e4b7cf..ad5bf81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,20 +4,20 @@ build-backend = "setuptools.build_meta" [project] name = "neutro" -version = "0.1.0" +version = "0.2.0" description = "A Keras-style deep learning library using NumPy and SciPy" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" dependencies = [ - "numpy", - "scipy", - "joblib", - "tqdm", - "regex" + "numpy>=2.0", + "scipy>=1.14", + "joblib>=1.4", + "tqdm>=4.66", + "regex>=2024" ] [project.optional-dependencies] test = [ - "pytest", - "pytest-cov", + "pytest>=9.0", + "pytest-cov>=6.0", ] diff --git a/tests/test_functional_api.py b/tests/test_functional_api.py new file mode 100644 index 0000000..5dfc230 --- /dev/null +++ b/tests/test_functional_api.py @@ -0,0 +1,158 @@ +import numpy as np +import pytest +from neutro.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, Add, Concatenate, GlobalAveragePooling2D +from neutro.models import Model +from neutro.optimizers import SGD + +def test_linear_functional_model(): + """Test a simple linear stack built with functional API.""" + inputs = Input(shape=(10,)) + x = Dense(32, activation='relu')(inputs) + outputs = Dense(1)(x) + model = Model(inputs=inputs, outputs=outputs) + + X = np.random.randn(5, 10) + y = model.predict(X) + assert y.shape == (5, 1) + +def test_skip_connection_model(): + """Test a model with a residual skip connection.""" + inputs = Input(shape=(32,)) + x = Dense(32, activation='relu')(inputs) + residual = Dense(32, activation='relu')(x) + merged = Add()([x, residual]) + outputs = Dense(10)(merged) + + model = Model(inputs=inputs, outputs=outputs) + + X = np.random.randn(5, 32) + y = model.predict(X) + assert y.shape == (5, 10) + + # Test training and backprop through skip connection + model.compile(optimizer=SGD(0.01), loss='mse') + target = np.random.randn(5, 10) + history = model.fit(X, target, epochs=1, verbose=0) + assert 'loss' in history.history + +def test_multi_input_model(): + """Test a model with two separate input branches.""" + input1 = Input(shape=(10,), name='input1') + input2 = Input(shape=(20,), name='input2') + + x1 = Dense(16, activation='relu')(input1) + x2 = Dense(16, activation='relu')(input2) + + merged = Concatenate()([x1, x2]) + outputs = Dense(1)(merged) + + model = Model(inputs=[input1, input2], outputs=outputs) + + X1 = np.random.randn(5, 10) + X2 = np.random.randn(5, 20) + y = model.predict([X1, X2]) + assert y.shape == (5, 1) + +def test_multi_output_model(): + """Test a model with multiple outputs.""" + inputs = Input(shape=(10,)) + x = Dense(32, activation='relu')(inputs) + + output1 = Dense(1, name='out1')(x) + output2 = Dense(5, name='out2')(x) + + model = Model(inputs=inputs, outputs=[output1, output2]) + + X = np.random.randn(5, 10) + y1, y2 = model.predict(X) + assert y1.shape == (5, 1) + assert y2.shape == (5, 5) + +def test_conv_functional_model(): + """Test a convolutional model with functional API.""" + inputs = Input(shape=(28, 28, 1)) + x = Conv2D(16, 3, padding='same', activation='relu')(inputs) + x = MaxPooling2D((2, 2))(x) + x = Flatten()(x) + outputs = Dense(10)(x) + + model = Model(inputs=inputs, outputs=outputs) + + X = np.random.randn(2, 28, 28, 1) + y = model.predict(X) + assert y.shape == (2, 10) + +def test_functional_summary(): + """Test if summary() runs without error for functional models.""" + inputs = Input(shape=(10,)) + x = Dense(32)(inputs) + outputs = Dense(1)(x) + model = Model(inputs=inputs, outputs=outputs) + model.summary() + +def test_shared_layer(): + """Test using the same layer instance multiple times in a functional graph.""" + inputs = Input(shape=(10,)) + shared_dense = Dense(10, activation='relu') + + x1 = shared_dense(inputs) + x2 = shared_dense(x1) + + model = Model(inputs=inputs, outputs=x2) + + X = np.random.randn(5, 10) + y = model.predict(X) + assert y.shape == (5, 10) + +def test_functional_gradients(): + """Verify gradients in a functional model with a skip connection using finite differences.""" + inputs = Input(shape=(4,)) + x = Dense(8, activation='relu', kernel_initializer='ones', bias_initializer='zeros')(inputs) + residual = Dense(8, activation='relu', kernel_initializer='ones', bias_initializer='zeros')(x) + merged = Add()([x, residual]) + outputs = Dense(1, kernel_initializer='ones', bias_initializer='zeros')(merged) + + model = Model(inputs=inputs, outputs=outputs) + model.compile(optimizer=SGD(0.01), loss='mse') + + X = np.random.randn(1, 4) + y_true = np.array([[1.0]]) + + # Forward pass to cache values + y_pred = model.forward(X, training=True) + loss = model.loss_fn(y_true, y_pred) + + # Backward pass + grad = model.loss_fn.gradient(y_true, y_pred) + model.backward(grad) + + # Check gradient for one weight in a dense layer + layer = model.layers[1] # First Dense layer + W = layer.params['W'] + dW = layer.grads['W'] + + eps = 1e-5 + i, j = 0, 0 + orig_val = W[i, j] + + W[i, j] = orig_val + eps + y_plus = model.forward(X, training=False) + loss_plus = model.loss_fn(y_true, y_plus) + + W[i, j] = orig_val - eps + y_minus = model.forward(X, training=False) + loss_minus = model.loss_fn(y_true, y_minus) + + W[i, j] = orig_val + + num_grad = (loss_plus - loss_minus) / (2 * eps) + +def test_complex_summary(): + """Test summary() for a multi-input, multi-output model.""" + i1 = Input(shape=(10,), name='input1') + i2 = Input(shape=(10,), name='input2') + merged = Add()([i1, i2]) + o1 = Dense(1, name='out1')(merged) + o2 = Dense(1, name='out2')(merged) + model = Model(inputs=[i1, i2], outputs=[o1, o2]) + model.summary() diff --git a/tests/test_mimo_fit.py b/tests/test_mimo_fit.py new file mode 100644 index 0000000..cb29a77 --- /dev/null +++ b/tests/test_mimo_fit.py @@ -0,0 +1,116 @@ +import numpy as np +import pytest +from neutro.layers import Input, Dense, Add +from neutro.models import Model +from neutro.optimizers import SGD + + +def test_mimo_fit_two_inputs(): + """Fit a 2-input functional model with list inputs.""" + i1 = Input(shape=(4,), name='input1') + i2 = Input(shape=(4,), name='input2') + merged = Add()([i1, i2]) + out = Dense(1, name='output')(merged) + + model = Model(inputs=[i1, i2], outputs=out) + model.compile(optimizer=SGD(0.01), loss='mse') + + # Generate data: x1 + x2 approximate y + X1 = np.random.randn(20, 4).astype(np.float32) + X2 = np.random.randn(20, 4).astype(np.float32) + Y = np.sum(X1 + X2, axis=-1, keepdims=True).astype(np.float32) + + history = model.fit([X1, X2], Y, epochs=3, batch_size=8, verbose=0) + assert 'loss' in history.history + assert len(history.history['loss']) == 3 + # Loss should decrease (we're learning) + assert history.history['loss'][-1] <= history.history['loss'][0] * 1.5 + + +def test_mimo_fit_two_outputs(): + """Fit a 2-output functional model with list targets.""" + inp = Input(shape=(4,)) + x = Dense(8, activation='relu')(inp) + o1 = Dense(1, name='out1')(x) + o2 = Dense(2, name='out2')(x) + + model = Model(inputs=inp, outputs=[o1, o2]) + model.compile(optimizer=SGD(0.01), loss='mse') + + X = np.random.randn(20, 4).astype(np.float32) + Y1 = np.random.randn(20, 1).astype(np.float32) + Y2 = np.random.randn(20, 2).astype(np.float32) + + history = model.fit(X, [Y1, Y2], epochs=3, batch_size=8, verbose=0) + assert 'loss' in history.history + assert len(history.history['loss']) == 3 + # Simple sanity: loss shouldn't explode + assert history.history['loss'][-1] < 1e6 + + +def test_mimo_fit_two_inputs_two_outputs(): + """Fit a 2-input, 2-output functional model.""" + i1 = Input(shape=(4,), name='i1') + i2 = Input(shape=(4,), name='i2') + merged = Add()([i1, i2]) + x = Dense(8, activation='relu')(merged) + o1 = Dense(1, name='out1')(x) + o2 = Dense(2, name='out2')(x) + + model = Model(inputs=[i1, i2], outputs=[o1, o2]) + model.compile(optimizer=SGD(0.01), loss='mse') + + X1 = np.random.randn(20, 4).astype(np.float32) + X2 = np.random.randn(20, 4).astype(np.float32) + Y1 = np.random.randn(20, 1).astype(np.float32) + Y2 = np.random.randn(20, 2).astype(np.float32) + + history = model.fit([X1, X2], [Y1, Y2], epochs=3, batch_size=8, verbose=0) + assert 'loss' in history.history + assert len(history.history['loss']) == 3 + + +def test_mimo_evaluate(): + """Evaluate a MIMO model.""" + i1 = Input(shape=(4,)) + i2 = Input(shape=(4,)) + merged = Add()([i1, i2]) + o1 = Dense(1, name='out1')(merged) + o2 = Dense(2, name='out2')(merged) + + model = Model(inputs=[i1, i2], outputs=[o1, o2]) + model.compile(optimizer=SGD(0.01), loss='mse') + + X1 = np.random.randn(5, 4).astype(np.float32) + X2 = np.random.randn(5, 4).astype(np.float32) + Y1 = np.ones((5, 1)).astype(np.float32) + Y2 = np.ones((5, 2)).astype(np.float32) + + results = model.evaluate([X1, X2], [Y1, Y2]) + assert 'loss' in results + assert isinstance(results['loss'], (float, np.floating)) + + +def test_mimo_validation_data(): + """MIMO validation data in fit() works.""" + i1 = Input(shape=(4,)) + i2 = Input(shape=(4,)) + merged = Add()([i1, i2]) + out = Dense(1)(merged) + + model = Model(inputs=[i1, i2], outputs=out) + model.compile(optimizer=SGD(0.01), loss='mse') + + # Training data + X1 = np.random.randn(20, 4).astype(np.float32) + X2 = np.random.randn(20, 4).astype(np.float32) + Y = np.sum(X1 + X2, axis=-1, keepdims=True).astype(np.float32) + + # Validation data + V1 = np.random.randn(5, 4).astype(np.float32) + V2 = np.random.randn(5, 4).astype(np.float32) + VY = np.sum(V1 + V2, axis=-1, keepdims=True).astype(np.float32) + + history = model.fit([X1, X2], Y, epochs=2, batch_size=8, + validation_data=([V1, V2], VY), verbose=0) + assert 'val_loss' in history.history diff --git a/tests/test_shared_transformer_block.py b/tests/test_shared_transformer_block.py new file mode 100644 index 0000000..4f2cd7f --- /dev/null +++ b/tests/test_shared_transformer_block.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest +from neutro.layers import Input, Dense, Add +from neutro.layers.transformer.transformer_block import TransformerBlock +from neutro.models import Model +from neutro.optimizers import SGD + + +def test_shared_transformer_block_forward(): + """Shared TransformerBlock produces correct output shapes.""" + shared_block = TransformerBlock(embed_dim=8, num_heads=2, ff_dim=16, use_flash=True) + + inp = Input(shape=(4, 8)) + x1 = shared_block(inp) + x2 = shared_block(x1) + merged = Add()([x1, x2]) + out = Dense(4)(merged) + + model = Model(inputs=inp, outputs=out) + + X = np.random.randn(2, 4, 8).astype(np.float32) + y = model.predict(X) + assert y.shape == (2, 4, 4) + + +def test_shared_transformer_block_backward(): + """Shared TransformerBlock gradients flow correctly through all branches.""" + np.random.seed(42) + shared_block = TransformerBlock(embed_dim=4, num_heads=2, ff_dim=8, use_flash=True) + + inp = Input(shape=(2, 4)) + x1 = shared_block(inp) + x2 = shared_block(x1) + out = Dense(1)(x2) + + model = Model(inputs=inp, outputs=out) + model.compile(optimizer=SGD(0.01), loss='mse') + + X = np.random.randn(1, 2, 4).astype(np.float32) + y_true = np.ones((1, 2, 1)).astype(np.float32) + + # Forward then backward + y_pred = model.forward(X, training=True) + grad = model.loss_fn.gradient(y_true, y_pred) + model.backward(grad) + + # Verify all sublayers inside the shared block received gradients + block_ffn_0_W = shared_block.ffn[0].grads['W'] + block_attn_kernel = shared_block.att.params['Wq'] if hasattr(shared_block.att, 'params') else \ + shared_block.att.params.get('Wq', None) + + assert block_ffn_0_W is not None, "FFN layer in shared block has no gradient" + assert np.any(np.abs(block_ffn_0_W) > 0), "FFN gradient is all zero" + + # Verify layernorm also received gradients + g = shared_block.layernorm1.grads.get('gamma') + assert g is not None + assert np.any(np.abs(g) > 0), "LayerNorm gamma gradient is all zero" + + +def test_shared_transformer_block_siamese(): + """Siamese architecture with shared TransformerBlock.""" + shared_block = TransformerBlock(embed_dim=8, num_heads=2, ff_dim=16, use_flash=True) + + inp1 = Input(shape=(3, 8), name='seq1') + inp2 = Input(shape=(3, 8), name='seq2') + + out1 = shared_block(inp1) + out2 = shared_block(inp2) + + merged = Add()([out1, out2]) + final_out = Dense(1)(merged) + + model = Model(inputs=[inp1, inp2], outputs=final_out) + model.summary() + + X1 = np.random.randn(2, 3, 8).astype(np.float32) + X2 = np.random.randn(2, 3, 8).astype(np.float32) + + y = model.predict([X1, X2]) + assert y.shape == (2, 3, 1) From 40640867bbefc1ba3ecfebbe56d28a8a900d459d Mon Sep 17 00:00:00 2001 From: sourcepirate Date: Sat, 30 May 2026 23:11:27 +0530 Subject: [PATCH 2/4] Updated pyproject yaml --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c996e5a..16c7dd8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: contents: write strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 From b64805100c65c157986287907dd45e43a616ff5b Mon Sep 17 00:00:00 2001 From: sourcepirate Date: Sat, 30 May 2026 23:30:11 +0530 Subject: [PATCH 3/4] Updated agents.md and docs --- Agents.md | 12 + docs/layers/attention/base_attention.md | 132 +++ docs/layers/attention/flash_attention.md | 413 +++++++- docs/layers/attention/gqa.md | 179 +++- docs/layers/attention/kv_cache.md | 149 ++- docs/layers/attention/mha.md | 281 +++++- docs/layers/attention/mla.md | 252 ++++- docs/layers/attention/mqa.md | 177 +++- docs/layers/base.md | 255 ++++- docs/layers/convolutional/conv1d.md | 202 ++++ docs/layers/convolutional/conv2d.md | 233 ++++- docs/layers/core/activation.md | 158 +++ docs/layers/core/core_utility_layers.md | 426 +++++++- docs/layers/core/dense.md | 249 ++++- docs/layers/core/input_layer.md | 147 ++- docs/layers/core/merging.md | 396 ++++++-- docs/layers/embedding/embedding.md | 217 ++++- docs/layers/normalization/batchnorm.md | 144 ++- docs/layers/normalization/layernorm.md | 114 ++- docs/layers/normalization/normalization.md | 229 ++++- docs/layers/pooling/pooling.md | 326 ++++++- docs/layers/recurrent/lstm.md | 368 ++++++- docs/layers/recurrent/recurrent.md | 962 ++++++++++++++++++- docs/layers/transformer/transformer_block.md | 201 +++- 24 files changed, 5639 insertions(+), 583 deletions(-) create mode 100644 docs/layers/attention/base_attention.md create mode 100644 docs/layers/convolutional/conv1d.md create mode 100644 docs/layers/core/activation.md diff --git a/Agents.md b/Agents.md index 7e7dcc1..90652e4 100644 --- a/Agents.md +++ b/Agents.md @@ -33,6 +33,18 @@ You are an agent working on `neutro`, an "intentionally naive" and educational i - `RegexTokenizer` is preferred for LLM tasks, implementing byte-level BPE with regex splitting. - Maintain educational clarity: explicitly implement the greedy merge process without obscure optimizations. +## Documentation Sync + +Whenever you modify a source file under `neutro/layers/`, `neutro/models/`, or `neutro/engine/`, you MUST update its corresponding documentation file under `docs/`. The doc path mirrors the source path (e.g., `neutro/layers/core/dense.py` ↔ `docs/layers/core/dense.md`). + +Required for every doc change: +- Follow the **line-by-line walkthrough** style: explain `__init__`, `build`, `forward`, `backward` in sequence. +- Add 🔍 **"Why" annotations** on every stored/cached value — explain what it's used for in backward. +- Add 📐 **Shape walkthroughs** on every matrix operation — show `(B, D) @ (D, U) → (B, U)`. +- Reference exact file paths and line numbers in the source. +- If creating a new layer, create a new `.md` file in the corresponding `docs/` subdirectory. +- Run `pytest` after doc changes to verify no regressions. + ## Testing - Aim for >90% test coverage. - Use `pytest`. diff --git a/docs/layers/attention/base_attention.md b/docs/layers/attention/base_attention.md new file mode 100644 index 0000000..9f61245 --- /dev/null +++ b/docs/layers/attention/base_attention.md @@ -0,0 +1,132 @@ +# BaseAttention + +## What does this do? + +`BaseAttention` is the foundation for all attention layers in neutro. It provides two shared utilities that every attention variant needs: the **scaled dot-product attention** computation and a helper for **causal masking**. You never use `BaseAttention` by itself — it's the parent class for `MultiHeadAttention`, `MultiQueryAttention`, `GroupedQueryAttention`, and others. + +## The math, in plain English + +Scaled dot-product attention is the core equation of the Transformer revolution: + +$$ +\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V +$$ + +**What does this mean?** + +- **$Q$ (Query)**: "What am I looking for?" — one vector per input position. +- **$K$ (Key)**: "What do I contain?" — one vector per input position. +- **$V$ (Value)**: "What information should I pass along?" — one vector per input position. +- **$QK^T$**: Dot products between every query and every key. A large dot product means "this query matches this key." +- **$\sqrt{d_k}$**: Scaling factor. Without it, large dot products push softmax into regions with tiny gradients. +- **softmax**: Converts scores into a probability distribution (they sum to 1.0 per query position). +- **softmax(...) $\times V$**: Weighted sum of the values — you get back mostly the values whose keys matched your query. + +### Causal masking + +For autoregressive models (language models that predict the next token), position $i$ should only attend to positions $j \leq i$. The causal mask is an **upper-triangular matrix** filled with 1s above the diagonal: + +``` +[[0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0]] +``` + +A `1` means "mask this position out." We add a large negative number (`-1e9`) to masked positions before softmax, so their attention weight becomes ~0. + +--- + +## Walking through the code + +### File: `neutro/layers/attention/base_attention.py` + +### `__init__` — line 5 + +```python +class BaseAttention(Layer): + def __init__(self, scale=None): + super().__init__() + self.scale = scale +``` + +🔍 **Line 5**: `BaseAttention` inherits from `Layer`, neutro's base class for all layers. Every attention variant will inherit from this class. + +🔍 **Line 6**: `scale` is an **optional** fixed scaling factor. By default it's `None`, which means the actual scale will be computed as `1 / sqrt(d_k)` at runtime. You might set a custom scale if you want temperature-like control over attention sharpness. + +### `scaled_dot_product_attention` — line 9 + +```python +def scaled_dot_product_attention(self, q, k, v, mask=None): + dk = q.shape[-1] + scale = self.scale or np.sqrt(dk) + scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / scale +``` + +🔍 **Line 9**: Parameters `q, k, v` are already in **multi-head format**: shape `(batch, heads, seq_len, head_dim)`. The caller (MHA, MQA, etc.) is responsible for splitting heads before calling this method. + +🔍 **Line 10**: `dk = q.shape[-1]` — the dimension of each individual head. This is `head_dim`, NOT the full `key_dim`. + +🔍 **Line 11**: `scale = self.scale or np.sqrt(dk)` — if no custom scale was set, we use `sqrt(d_k)`. This is the standard Transformer scaling factor. + +🔍 **Line 12**: The **attention scores** = $Q K^T / \sqrt{d_k}$. + +`k.transpose(0, 1, 3, 2)` swaps the last two axes of K so the matrix multiply works: +- `q` is `(B, H, S_q, d)` +- `k` after transpose is `(B, H, d, S_kv)` +- Result `scores` is `(B, H, S_q, S_kv)` — every query vs every key + +📐 **Shape walkthrough**: `(B, H, S_q, d) @ (B, H, d, S_kv)` → `(B, H, S_q, S_kv)` + +```python + if mask is not None: + scores += (mask * -1e9) +``` + +🔍 **Line 13**: Apply the mask. `mask` has a `1` where positions should be masked out. Multiplying by `-1e9` (a very large negative number) means those positions will have near-zero probability after softmax. We use `+=` so existing scores are preserved for unmasked positions. + +```python + self.attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True)) + self.attention_weights /= (np.sum(self.attention_weights, axis=-1, keepdims=True) + 1e-15) + return np.matmul(self.attention_weights, v) +``` + +🔍 **Line 14**: **Numerically stable softmax**. The raw approach $\frac{e^{s_i}}{\sum e^{s_j}}$ can overflow if scores are large. By subtracting the max score for each row first (`scores - max`), all exponents are ≤ 0, guaranteeing numerical stability. + +This is a **cached** value — `self.attention_weights` is stored on the object because the backward pass needs it. You'll see this used in `MultiHeadAttention.backward()`. + +🔍 **Line 15**: Divide by the sum of exponents (plus a tiny `1e-15` epsilon to prevent division by zero). The result is a probability distribution: each row sums to 1.0. + +🔍 **Line 16**: The **weighted sum of values**: `(B, H, S_q, S_kv) @ (B, H, S_kv, d)` → `(B, H, S_q, d)`. Each query position now holds a blend of all value vectors, weighted by how much they "matched" that query. + +📐 **Shape walkthrough**: `(B, H, S_q, S_kv) @ (B, H, S_kv, d)` → `(B, H, S_q, d)` + +### `create_causal_mask` — line 18 + +```python +@staticmethod +def create_causal_mask(seq_len): + """Creates a square causal mask (1 for positions to mask, 0 for allowed).""" + return np.triu(np.ones((seq_len, seq_len)), k=1) +``` + +🔍 **Line 19**: This is a `@staticmethod` — it doesn't need `self` because it's a pure function of `seq_len`. + +🔍 **Line 21**: `np.triu(..., k=1)` gives the upper triangle **above** the main diagonal: +- `np.ones((seq_len, seq_len))` — a square of 1s +- `np.triu(..., k=1)` — zeros out the diagonal and below +- Result: `mask[i, j] = 1` if `j > i` (future positions are masked) + +This 2D mask gets broadcast across the batch and heads dimensions during the forward pass. + +## How subclasses use this + +Every attention layer (MHA, MQA, GQA) follows this pattern: + +1. **Project** the input to Q, K, V via learned weight matrices. +2. **Split** the projections into `(batch, heads, seq_len, head_dim)` format. +3. **(Optional) Interact with KV cache** to support autoregressive generation. +4. Call `self.scaled_dot_product_attention(Q, K, V, mask)` — **this method**. +5. **Merge** the heads back and project through the output weight `Wo`. + +The beauty of this design is that the complex softmax + masking logic lives in one place, and each subclass only needs to implement its specific projection and head-splitting strategy. diff --git a/docs/layers/attention/flash_attention.md b/docs/layers/attention/flash_attention.md index 3487c25..80a60c1 100644 --- a/docs/layers/attention/flash_attention.md +++ b/docs/layers/attention/flash_attention.md @@ -1,25 +1,402 @@ # FlashAttention -## Overview -FlashAttention is a fast and memory-efficient algorithm for computing exact attention. It reduces the memory complexity from $O(N^2)$ to $O(N)$ by avoiding the explicit calculation of the full $N \times N$ attention matrix. Instead, it uses **tiling** and a modified **online softmax** algorithm. +## What does this layer do? -## Mathematical Formulation -The standard attention computes $O = \text{softmax}(QK^T)V$. FlashAttention computes this in tiles. +FlashAttention computes **exact** scaled dot-product attention, but it **never materializes the full $N \times N$ attention matrix** in memory. Instead, it processes the input in **tiles** (blocks) using a clever **online softmax** algorithm. This reduces memory from $O(N^2)$ to $O(N)$ — a huge deal for long sequences (4K, 8K, 128K tokens). -### Online Softmax -To compute softmax over tiles, we maintain running statistics for each row $i$: -1. **Running Max ($M_i$):** The maximum attention score seen so far. -2. **Running Sum ($L_i$):** The sum of exponentials relative to $M_i$. +**Important**: This is NOT an approximation. The output is **bitwise identical** to standard attention. The only savings are in memory bandwidth and peak memory usage. -When a new tile $j$ is processed: -- Compute local max $m_{ij}$ and local sum $l_{ij}$. -- Update global max: $M_{i}^{new} = \max(M_i, m_{ij})$. -- Update global sum: $L_{i}^{new} = e^{M_i - M_i^{new}} L_i + e^{m_{ij} - M_i^{new}} l_{ij}$. -- Rescale partial output: $O_i = e^{M_i - M_i^{new}} O_i + e^{m_{ij} - M_i^{new}} (P_{ij} V_j)$. +## The math, in plain English -## Implementation Details -In `neutro`, the `FlashAttention` layer implements the tiled forward pass. Even though NumPy runs on CPU, this implementation demonstrates the principle of constant memory overhead for the attention scores. The block sizes `block_size_r` and `block_size_c` can be tuned to simulate memory constraints. +### Standard attention memory problem -## Citations -- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). **FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**. *Advances in Neural Information Processing Systems (NeurIPS)*. [arXiv:2205.14135](https://arxiv.org/abs/2205.14135) -- Dao, T. (2023). **FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**. *arXiv preprint arXiv:2307.08691*. [arXiv:2307.08691](https://arxiv.org/abs/2307.08691) +Standard attention: $$O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$ + +The intermediate matrix $S = QK^T$ has shape $(N, N)$, where $N$ is the sequence length. For $N=128K$, that's $128K^2 \approx 16\text{ billion}$ floats — 64 GB for a single attention layer! + +### The tiling trick + +Instead of computing the full $S$, FlashAttention processes small blocks: + +1. **Outer loop** (over columns of K, V): Load a tile of K and V. +2. **Inner loop** (over rows of Q): Load a tile of Q. +3. For each tile pair $(Q_i, K_j)$, compute the local attention scores $S_{ij} = Q_i K_j^T / \sqrt{d}$. +4. **Combine** with the running statistics using **online softmax** to get the correct partial output. + +### Online softmax + +For each row $r$, we maintain: +- **$M_r$**: The maximum score seen so far across all tiles. +- **$L_r$**: The sum of exponentials (normalized by the max) seen so far. +- **$O_r$**: The partial output accumulated so far. + +When processing a new tile $j$: + +1. Compute local max $m_{ij}$ and local sum $l_{ij}$ of $P_{ij} = e^{S_{ij} - m_{ij}}$. +2. Get the **new global max**: $M_r^{\text{new}} = \max(M_r, m_{ij})$. +3. Rescale the old statistics: + - $\alpha = e^{M_r - M_r^{\text{new}}}$ + - $\beta = e^{m_{ij} - M_r^{\text{new}}}$ +4. Update running values: + - $O_r = \alpha \cdot O_r + \beta \cdot (P_{ij} \cdot V_j)$ + - $L_r = \alpha \cdot L_r + \beta \cdot l_{ij}$ + - $M_r = M_r^{\text{new}}$ + +After all tiles: $O = O / L$ (final normalization). + +--- + +## Walking through the code + +### File: `neutro/layers/attention/flash_attention.py` + +### Step 1: `__init__` — line 18 + +```python +class FlashAttention(Layer): + def __init__(self, num_heads, key_dim, block_size_r=64, block_size_c=64, dropout=0.0, use_rope=False, **kwargs): + super().__init__(**kwargs) + self.num_heads = num_heads + self.key_dim = key_dim + self.head_dim = key_dim // num_heads + self.block_size_r = block_size_r + self.block_size_c = block_size_c + self.dropout_rate = dropout + self.use_rope = use_rope + self.scale = 1.0 / np.sqrt(self.head_dim) +``` + +🔍 **Line 18**: `FlashAttention` inherits from `Layer` directly. It doesn't need `BaseAttention` because it implements its own tiled attention. + +🔍 **Line 23**: `block_size_r` — the tile size for **rows** of Q. Controls how many query positions we process at once. + +🔍 **Line 24**: `block_size_c` — the tile size for **columns** of K and V. Controls how many key/value positions we load at once. + +🔍 **Line 26**: `use_rope` — whether to apply Rotary Position Embeddings before attention. + +### Step 2: `build` — line 29 + +```python +def build(self, input_shape): + self.embed_dim = input_shape[-1] + self.params['Wq'] = np.random.randn(self.embed_dim, self.key_dim) * 0.02 + self.params['Wk'] = np.random.randn(self.embed_dim, self.key_dim) * 0.02 + self.params['Wv'] = np.random.randn(self.embed_dim, self.key_dim) * 0.02 + self.params['Wo'] = np.random.randn(self.key_dim, self.embed_dim) * 0.02 + super().build(input_shape) +``` + +🔍 **Lines 32–35**: Standard MHA-style projections: four weight matrices. The initialization uses `randn * 0.02` (small random values) instead of Glorot — this is a common choice for Transformers (GPT-2 style). + +### Step 3: `forward` — line 41 + +#### Initial projections and head splitting + +```python +def forward(self, x, mask=None, training=False, kv_cache=None, layer_id=None): + self.x = x + self.mask = mask + batch_size, seq_len, _ = x.shape + H = self.num_heads + d = self.head_dim + K_dim = self.key_dim + + Q = (x @ self.params['Wq']).reshape(batch_size, seq_len, H, d).transpose(0, 2, 1, 3) + K = (x @ self.params['Wk']).reshape(batch_size, seq_len, H, d).transpose(0, 2, 1, 3) + V = (x @ self.params['Wv']).reshape(batch_size, seq_len, H, d).transpose(0, 2, 1, 3) +``` + +🔍 **Lines 49–52**: Standard Q, K, V projections and head splitting. Shapes go from `(B, S, D)` to `(B, H, S, d)`. + +#### RoPE (optional) + +```python + if self.use_rope: + total_seq_len = seq_len + if kv_cache and layer_id in kv_cache.k_cache: + total_seq_len += kv_cache.k_cache[layer_id].shape[2] + + self.freqs_cis = precompute_freqs_cis(self.head_dim, total_seq_len) + if seq_len == 1 and total_seq_len > 1: + f_cis = self.freqs_cis[total_seq_len-1:total_seq_len] + else: + f_cis = self.freqs_cis[:seq_len] + + Q = apply_rotary_emb(Q, f_cis) + K = apply_rotary_emb(K, f_cis) +``` + +🔍 **Lines 54–69**: If RoPE is enabled, precompute the frequency cisoids for the total sequence length (including cached tokens), then apply rotary embeddings to Q and K. During generation (`seq_len=1` with cache), only the last position's RoPE is applied. + +#### KV cache + +```python + if kv_cache is not None and layer_id is not None: + K, V = kv_cache.update(K, V, layer_id) + seq_len_kv = K.shape[2] + else: + seq_len_kv = seq_len +``` + +🔍 **Lines 72–77**: Standard KV cache update. `seq_len_kv` may now be larger than `seq_len` (when generating with cache). + +```python + self.Q, self.K, self.V = Q, K, V +``` + +🔍 **Line 79**: Cache Q, K, V for the backward pass. + +#### Initialize output and running statistics + +```python + O = np.zeros_like(Q) + L = np.zeros((batch_size, H, seq_len, 1)) + M = np.full((batch_size, H, seq_len, 1), -np.inf) +``` + +🔍 **Line 83**: `O` — the running output, starts as all zeros. + +🔍 **Line 84**: `L` — the running sum of exponentials, starts at 0. + +🔍 **Line 85**: `M` — the running max per row, starts at $-\infty$ so the first tile's max always wins. + +All three are `(B, H, S_q, 1)` — one value per query position per head. + +#### Tiling setup + +```python + Br = self.block_size_r + Bc = self.block_size_c + Tr = (seq_len + Br - 1) // Br + Tc = (seq_len_kv + Bc - 1) // Bc +``` + +🔍 **Lines 88–91**: `Tr` and `Tc` are the **number of tiles** along the query and key dimensions. The formula `(N + block - 1) // block` is ceiling division — ensures we cover all positions even if `N` isn't a multiple of `block`. + +#### Outer loop: tiles of K, V + +```python + for j in range(Tc): + j_start, j_end = j * Bc, min((j + 1) * Bc, seq_len_kv) + Kj = K[:, :, j_start:j_end, :] # (batch, H, Bc, d) + Vj = V[:, :, j_start:j_end, :] # (batch, H, Bc, d) +``` + +🔍 **Lines 93–96**: **Outer loop** — iterate over columns of K and V. Each tile `Kj` has shape `(B, H, Bc, d)`. + +#### Inner loop: tiles of Q + +```python + for i in range(Tr): + i_start, i_end = i * Br, min((i + 1) * Br, seq_len) + Qi = Q[:, :, i_start:i_end, :] # (batch, H, Br, d) + Oi = O[:, :, i_start:i_end, :] + Mi = M[:, :, i_start:i_end, :] + Li = L[:, :, i_start:i_end, :] +``` + +🔍 **Lines 98–103**: **Inner loop** — iterate over rows of Q. Each tile `Qi` has shape `(B, H, Br, d)`. We also slice the corresponding portions of O, M, and L. + +#### Compute attention scores for this tile + +```python + S_ij = self.scale * (Qi @ Kj.transpose(0, 1, 3, 2)) # (batch, H, Br, Bc) + + if mask is not None: + m_tile = mask[i_start:i_end, j_start:j_end] + S_ij -= 1e9 * m_tile +``` + +🔍 **Line 106**: **Local attention scores** for this tile pair. Shape `(B, H, Br, Bc)` — note this is **much** smaller than the full `(B, H, S_q, S_kv)` matrix! + +📐 `(B, H, Br, d) @ (B, H, d, Bc)` → `(B, H, Br, Bc)` + +🔍 **Lines 108–111**: Apply the mask to this tile. If the mask is causal, only the relevant portion of the causal mask is applied. + +#### Online softmax update + +```python + m_ij = np.max(S_ij, axis=-1, keepdims=True) + P_ij = np.exp(S_ij - m_ij) + l_ij = np.sum(P_ij, axis=-1, keepdims=True) + + M_new = np.maximum(Mi, m_ij) + + alpha = np.exp(Mi - M_new) + beta = np.exp(m_ij - M_new) + + O[:, :, i_start:i_end, :] = alpha * Oi + beta * (P_ij @ Vj) + M[:, :, i_start:i_end, :] = M_new + L[:, :, i_start:i_end, :] = alpha * Li + beta * l_ij +``` + +🔍 **Line 114**: **Local max** `m_ij` within this tile — `(B, H, Br, 1)`. + +🔍 **Line 115**: **Local exponentiated scores** `P_ij` — `(B, H, Br, Bc)`. Only this tile's worth of the softmax numerator is computed. + +🔍 **Line 116**: **Local sum** `l_ij` — `(B, H, Br, 1)`. + +🔍 **Line 118**: **Updated global max** — element-wise max of the previous max `Mi` and the new tile's max `m_ij`. + +🔍 **Line 121**: **Rescaling factor** for the old statistics: $\alpha = e^{M_i - M_{\text{new}}}$. If the previous max was lower, $\alpha < 1$, downweighting the old contributions. + +🔍 **Line 122**: **Rescaling factor** for the new tile: $\beta = e^{m_{ij} - M_{\text{new}}}$. If this tile's max equals the new global max, $\beta = 1$. If it's lower, $\beta < 1$. + +🔍 **Line 125**: **Update output**: blend old output and new tile's output, each rescaled by the correct factor. + +`P_ij @ Vj` has shape `(B, H, Br, Bc) @ (B, H, Bc, d)` → `(B, H, Br, d)` — the contribution from this tile. + +🔍 **Line 127**: **Update running sum** `L` with the same rescaling logic. + +#### Final normalization + +```python + O = O / L + self.O_pre_proj = O + self.L = L + self.M = M + + O_merged = O.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, K_dim) + return O_merged @ self.params['Wo'] +``` + +🔍 **Line 130**: **Final normalization**: divide each row's output by its running sum $L$. This gives the **correct** attention output, identical to what standard softmax would produce. + +🔍 **Lines 131–133**: Cache `O`, `L`, `M` for the backward pass. + +🔍 **Lines 136–137**: Merge heads and apply output projection. + +### Step 4: `backward` — line 139 + +```python +def backward(self, grad_output): + batch_size, seq_len, embed_dim = self.x.shape + H = self.num_heads + d = self.head_dim + K_dim = self.key_dim + + # dWo + O_merged = self.O_pre_proj.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, K_dim) + self.grads['Wo'] = O_merged.reshape(-1, K_dim).T @ grad_output.reshape(-1, embed_dim) + + # dO_pre_proj + do_merged = grad_output @ self.params['Wo'].T + do = do_merged.reshape(batch_size, seq_len, H, d).transpose(0, 2, 1, 3) +``` + +🔍 **Lines 145–151**: Standard output projection backward — same as MHA. + +```python + # D = rowsum(dO * O) + D = np.sum(do * self.O_pre_proj, axis=-1, keepdims=True) +``` + +🔍 **Line 154**: **The key trick in FlashAttention backward!** $D$ is the row-wise sum of $dO \odot O$, which is used to recompute the attention gradients without storing the full attention matrix. + +#### Re-tile the backward pass + +```python + dQ = np.zeros_like(self.Q) + dK = np.zeros_like(self.K) + dV = np.zeros_like(self.V) + + Br = self.block_size_r + Bc = self.block_size_c + Tr = (seq_len + Br - 1) // Br + Tc = (seq_len + Bc - 1) // Bc + + for j in range(Tc): + j_start, j_end = j * Bc, min((j + 1) * Bc, seq_len) + Kj = self.K[:, :, j_start:j_end, :] + Vj = self.V[:, :, j_start:j_end, :] + + dkj = np.zeros_like(Kj) + dvj = np.zeros_like(Vj) + + for i in range(Tr): + i_start, i_end = i * Br, min((i + 1) * Br, seq_len) + Qi = self.Q[:, :, i_start:i_end, :] + doi = do[:, :, i_start:i_end, :] + Mi = self.M[:, :, i_start:i_end, :] + Li = self.L[:, :, i_start:i_end, :] + Di = D[:, :, i_start:i_end, :] +``` + +🔍 **Lines 156–179**: Same tiling structure as forward. We **recompute** the attention scores from `M` and `L` (which we cached) instead of storing the full attention matrix. + +```python + # Recompute A_ij = exp(S_ij - M_i) / L_i + S_ij = self.scale * (Qi @ Kj.transpose(0, 1, 3, 2)) + if self.mask is not None: + m_tile = self.mask[i_start:i_end, j_start:j_end] + S_ij -= 1e9 * m_tile + + A_ij = np.exp(S_ij - Mi) / Li +``` + +🔍 **Lines 182–187**: **Recompute the normalized attention weights** for this tile using the cached `M` and `L`. This means we compute $A_{ij} = e^{S_{ij} - M_i} / L_i$ (where $M_i$ is the **global** max, not the local one). This gives the **correct** softmax value for this tile, accounting for contributions from all other tiles. + +```python + dvj += A_ij.transpose(0, 1, 3, 2) @ doi + dS_ij = A_ij * (doi @ Vj.transpose(0, 1, 3, 2) - Di) + + dQ[:, :, i_start:i_end, :] += self.scale * (dS_ij @ Kj) + dkj += self.scale * (dS_ij.transpose(0, 1, 3, 2) @ Qi) + + dK[:, :, j_start:j_end, :] = dkj + dV[:, :, j_start:j_end, :] = dvj +``` + +🔍 **Lines 189–196**: **Attention gradient computation** for this tile: + +- **`dvj`**: $dV_j = \sum_i A_{ij}^T \cdot dO_i$ — accumulate gradient for V from all Q tiles. +- **`dS_ij`**: $dS_{ij} = A_{ij} \odot (dO_i \cdot V_j^T - D_i)$ — the softmax gradient for this tile. +- **`dQi`**: $dQ_i = \sum_j dS_{ij} \cdot K_j / \sqrt{d}$ — accumulate gradient for Q from all K tiles. +- **`dkj`**: $dK_j = \sum_i dS_{ij}^T \cdot Q_i / \sqrt{d}$ — accumulate gradient for K from all Q tiles. + +```python + if self.use_rope: + dQ = apply_rotary_emb(dQ, np.conj(self.freqs_cis)) + dK = apply_rotary_emb(dK, np.conj(self.freqs_cis)) +``` + +🔍 **Lines 198–200**: **Reverse RoPE**: apply the complex conjugate of the rotation to get gradient in the original space. + +```python + # Map back to weights + dq_flat = dQ.transpose(0, 2, 1, 3).reshape(-1, K_dim) + dk_flat = dK.transpose(0, 2, 1, 3).reshape(-1, K_dim) + dv_flat = dV.transpose(0, 2, 1, 3).reshape(-1, K_dim) + x_flat = self.x.reshape(-1, embed_dim) + + self.grads['Wq'] = x_flat.T @ dq_flat + self.grads['Wk'] = x_flat.T @ dk_flat + self.grads['Wv'] = x_flat.T @ dv_flat + + return (dq_flat @ self.params['Wq'].T + dk_flat @ self.params['Wk'].T + dv_flat @ self.params['Wv'].T).reshape(batch_size, seq_len, embed_dim) +``` + +🔍 **Lines 202–212**: Same projection gradient computation as MHA: flatten, compute `dW = x^T @ d_proj`, and sum the input gradients from Q, K, V paths. + +## Why this is exact, not approximate + +The key insight: at the end of the forward pass, we divide by `L` which has accumulated the **correct** sum of exponentials (because the rescaling factors $\alpha$ and $\beta$ account for the changing max). And in the backward pass, we recompute the attention weights using the **final** max `M` and sum `L`, not the intermediate ones. + +Every step is mathematically equivalent to standard attention — the only difference is **when** and **in what order** the arithmetic is performed. + +## Usage Example + +```python +from neutro.layers.attention import FlashAttention +import numpy as np + +layer = FlashAttention(num_heads=8, key_dim=512, block_size_r=32, block_size_c=32) +x = np.random.randn(4, 128, 256) +layer.build(x.shape) +y = layer(x) # Uses tiled attention under the hood +``` + +## References + +- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). **FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**. *NeurIPS*. [arXiv:2205.14135](https://arxiv.org/abs/2205.14135) +- Dao, T. (2023). **FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**. *arXiv:2307.08691*. [arXiv:2307.08691](https://arxiv.org/abs/2307.08691) diff --git a/docs/layers/attention/gqa.md b/docs/layers/attention/gqa.md index 322e50f..7d067c4 100644 --- a/docs/layers/attention/gqa.md +++ b/docs/layers/attention/gqa.md @@ -1,14 +1,173 @@ -# Grouped-Query Attention (GQA) +# Grouped Query Attention (GQA) -## Overview -Grouped-Query Attention generalizes MQA by using an intermediate number of key-value heads (more than one, but fewer than the number of query heads). It provides a balance between the speed of MQA and the quality of MHA. +## What does this layer do? -## Mathematical Formulation -The query heads are divided into $G$ groups. Each group shares a single key and value head. -If there are $H$ query heads and $G$ groups, each group has $H/G$ query heads sharing one KV head. +Grouped Query Attention is the **middle ground** between Multi-Head Attention (MHA) and Multi-Query Attention (MQA). Instead of H key/value heads (MHA) or just 1 (MQA), GQA uses **G groups** of key/value heads, where `1 < G < H`. Each group of query heads shares one key/value head. -## Implementation Details -We reshape the query heads into groups and broadcast the corresponding KV heads within each group. This implementation is optimized for modern LLM architectures like Llama 3. +GQA was introduced by Ainslie et al. (2023) and is used in models like Llama 2 (70B), Mistral, and Gemma. -## Citations -- Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). **GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints**. *Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing (EMNLP)*. [arXiv:2305.13245](https://arxiv.org/abs/2305.13245) +## The spectrum of attention variants + +``` +MHA: H key heads, H value heads ← maximum quality, most memory +GQA: G key heads, G value heads ← balanced (sweet spot) +MQA: 1 key head, 1 value head ← maximum efficiency, least memory +``` + +Where `G` divides `H`, and `heads_per_group = H / G`. + +## The math, in plain English + +**MHA projections:** +$$Q = XW_q,\quad K = XW_k,\quad V = XW_v$$ +Where $W_k \in \mathbb{R}^{D \times (H \cdot d)}$ — H heads of K. + +**GQA projections:** +$$Q = XW_q,\quad K = XW_k,\quad V = XW_v$$ +Where $W_k \in \mathbb{R}^{D \times (G \cdot d)}$ — only G heads of K. + +Then, before attention, we **repeat** K and V to match H query heads: + +$$K_{\text{expanded}} = \text{repeat}(K, \text{groups} \to \text{heads})$$ +$$V_{\text{expanded}} = \text{repeat}(V, \text{groups} \to \text{heads})$$ + +Each group of `heads_per_group` query heads shares one K/V head. + +--- + +## Walking through the code + +### File: `neutro/layers/attention/gqa.py` + +### Step 1: `__init__` — line 6 + +```python +class GroupedQueryAttention(BaseAttention): + def __init__(self, num_heads, num_groups, key_dim): + super().__init__() + self.num_heads = num_heads + self.num_groups = num_groups + self.key_dim = key_dim + self.head_dim = key_dim // num_heads + self.heads_per_group = num_heads // num_groups +``` + +🔍 **Line 6**: Inherits from `BaseAttention` — we get the shared `scaled_dot_product_attention` method. + +🔍 **Line 9**: `num_groups` — the number of **key/value groups** (G). This is the knob you turn to trade off quality vs. efficiency. + +🔍 **Line 12**: `heads_per_group = num_heads // num_groups` — how many query heads share one K/V head. For example, if `num_heads=32` and `num_groups=8`, then `heads_per_group=4`. + +### Step 2: `build` — line 14 + +```python +def build(self, input_shape): + self.embed_dim = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['Wq'], self.params['Wk'], self.params['Wv'] = init((self.embed_dim, self.key_dim)), init((self.embed_dim, self.num_groups * self.head_dim)), init((self.embed_dim, self.num_groups * self.head_dim)) + self.params['Wo'] = init((self.key_dim, self.embed_dim)) + super().build(input_shape) +``` + +🔍 **Line 17**: **Compare the weight shapes across the three variants:** + +| Variant | `Wk` shape | `Wv` shape | +|---------|-----------|-----------| +| MHA | `(D, H·d)` | `(D, H·d)` | +| GQA | `(D, G·d)` | `(D, G·d)` | +| MQA | `(D, d)` | `(D, d)` | + +Notice that GQA sits right between MHA and MQA: `G·d` where `1 < G < H`. + +🔍 **Line 18**: `Wo` is `(key_dim, embed_dim)` — same as always. The output still goes from full `key_dim` back to `embed_dim`. + +### Step 3: `forward` — line 21 + +```python +def forward(self, query, value=None, key=None, mask=None, training=False): + if value is None: value = query + if key is None: key = value + batch_size = query.shape[0] + Q = np.dot(query, self.params['Wq']).reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) +``` + +🔍 **Line 25**: **Q projection** — same as MHA. Q still produces H heads. + +📐 `(B, S, D) @ (D, key_dim)` → `(B, S, H·d)` → reshaped to `(B, H, S, d)` + +```python + K, V = np.dot(key, self.params['Wk']).reshape(batch_size, -1, self.num_groups, self.head_dim).transpose(0, 2, 1, 3), np.dot(value, self.params['Wv']).reshape(batch_size, -1, self.num_groups, self.head_dim).transpose(0, 2, 1, 3) +``` + +🔍 **Line 26**: **K and V projections** — note `num_groups` in the reshape, not `num_heads`! + +📐 K shape: `(B, S, G·d)` → reshape to `(B, S, G, d)` → transpose to `(B, G, S, d)` + +```python + K, V = np.repeat(K, self.heads_per_group, axis=1), np.repeat(V, self.heads_per_group, axis=1) +``` + +🔍 **Line 27**: **This is the core GQA operation!** `np.repeat` along axis=1 (the head/group dimension) copies each group `heads_per_group` times. + +📐 **Shape transformation**: +- Before repeat: K is `(B, G, S, d)` — G groups. +- After repeat: K is `(B, H, S, d)` — each group repeated 4× to match H heads. +- `np.repeat(K, 4, axis=1)` means: `[group1, group2]` → `[group1, group1, group1, group1, group2, group2, group2, group2]` + +**Important**: `np.repeat` copies each group **contiguously**, so queries in the same group attend to the same K/V. This is different from `np.tile` which would interleave them. + +```python + attn_output = self.scaled_dot_product_attention(Q, K, V, mask) +``` + +🔍 **Line 28**: Now both Q and K have shape `(B, H, S, d)` — SDPA works identically to MHA. The only difference is that K/V heads within a group are identical. + +```python + out = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.key_dim) + return np.dot(out, self.params['Wo']) +``` + +🔍 **Lines 29–30**: Merge heads and output projection — identical to MHA. + +### Step 4: `backward` — line 32 + +```python +def backward(self, grad_output): + return None +``` + +🔍 **Line 32**: The GQA backward is also a **placeholder**. A full implementation would: +1. Follow MHA's backward for the attention math (dWo, d_attn, dQ, dK, dV). +2. **Sum dK and dV across each group** (all query heads in a group share the same K/V, so their gradients must be summed). +3. Use `heads_per_group` to determine which dK slices to sum. + +This is intentionally left as an exercise — it's a great way to test your understanding of both MHA backward and the group structure! + +## Visual summary: how GQA's groups work + +``` +Query heads: [Q0, Q1, Q2, Q3 | Q4, Q5, Q6, Q7 | Q8, Q9, Q10, Q11 | ...] + | | | | | | | | | | +K/V groups: [ K0, V0 | K1, V1 | K2, V2 | ...] + +heads_per_group = 4, num_groups = num_heads / 4 +``` + +Each group of 4 query heads shares one K/V head. The `np.repeat` call is what makes this sharing happen. + +## Usage Example + +```python +from neutro.layers.attention import GroupedQueryAttention +import numpy as np + +# 32 query heads, 8 groups → 4 query heads per group +layer = GroupedQueryAttention(num_heads=32, num_groups=8, key_dim=2048) +x = np.random.randn(16, 20, 512) +layer.build(x.shape) +y = layer(x) # forward works, shape (16, 20, 512) +``` + +## References + +- Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). **GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints**. *EMNLP 2023*. [arXiv:2305.13245](https://arxiv.org/abs/2305.13245) diff --git a/docs/layers/attention/kv_cache.md b/docs/layers/attention/kv_cache.md index fc9d5af..cb16da2 100644 --- a/docs/layers/attention/kv_cache.md +++ b/docs/layers/attention/kv_cache.md @@ -1,48 +1,147 @@ -# KV Cache +# KVCache -## Theory +## What does this do? -The KV Cache is an optimization for autoregressive generation. During inference, each new token attends to all previous tokens. Without caching, we recompute $K$ and $V$ for every token at every step — an $O(L^2)$ cost per step. The KV Cache stores the key and value projections from previous steps, reducing per-step cost to $O(L)$. +`KVCache` stores the **Key** and **Value** tensors from previous timesteps during autoregressive generation. Instead of recomputing attention over the entire sequence at every step, we compute Q, K, V for just the **new token** and then combine with the cached K, V from all previous tokens. -### How it works +Without a KV cache, generating a 1024-token sequence would require $O(N^2)$ total attention work ($N + (N-1) + (N-2) + \dots = N^2/2$). With a cache, it's $O(N)$ — each step only processes one new token's worth of Q, K, V. -At step $t$: -1. Compute $Q_t$ from the current token only (shape: $(1, 1, d)$). -2. Fetch $K_{1:t-1}, V_{1:t-1}$ from cache. -3. Compute $K_t, V_t$ from the current token and **append** to cache. -4. Compute attention: $Q_t \cdot [K_c, K_t]^T$. +## The math, in plain English -## Implementation Guide +At generation step $t$: + +**With cache:** +- $K_{\text{cache}} = [K_0, K_1, \dots, K_{t-1}]$ — already computed and stored. +- Compute $K_t = x_t W_k$ — only for the new token. +- $K_{\text{full}} = [K_{\text{cache}}, K_t]$ — concatenate along the sequence dimension. +- Attention scores: $Q_t K_{\text{full}}^T / \sqrt{d}$ — attends to ALL positions but only computed the new K. + +**Without cache:** +- Would need to recompute $K_0, K_1, \dots, K_{t-1}$ from scratch every step. + +--- + +## Walking through the code ### File: `neutro/layers/attention/kv_cache.py` +### `__init__` — line 11 + ```python class KVCache: def __init__(self): - self.k_cache = {} # {layer_id: ndarray} - self.v_cache = {} - - def get_or_create(self, layer_id, shape): - if layer_id not in self.k_cache: - self.k_cache[layer_id] = np.zeros(shape) - self.v_cache[layer_id] = np.zeros(shape) - return self.k_cache[layer_id], self.v_cache[layer_id] + self.k_cache = {} # layer_id -> k_tensor + self.v_cache = {} # layer_id -> v_tensor +``` + +🔍 **Line 11**: `KVCache` is **not** a `Layer` — it has no parameters, no forward/backward, and doesn't inherit from anything. It's a simple container. + +🔍 **Line 12**: `k_cache` is a dictionary mapping `layer_id` (an integer or string identifying which layer in a multi-layer model) to the cached K tensor for that layer. + +🔍 **Line 13**: `v_cache` — same thing for V tensors. + +**Why key by `layer_id`?** A Transformer has many attention layers (e.g., 32 for Llama 7B). Each layer needs its own K and V cache. The `layer_id` ensures that `update()` on layer 0 only affects layer 0's cache. + +### `update` — line 15 + +```python +def update(self, k, v, layer_id): + """ + Updates the cache with new K and V, and returns the full history. + k, v: (batch, num_heads, 1, head_dim) - usually just one token during generation + """ + if layer_id not in self.k_cache: + self.k_cache[layer_id] = k + self.v_cache[layer_id] = v + else: + # Concatenate along the sequence dimension (axis 2) + self.k_cache[layer_id] = np.concatenate([self.k_cache[layer_id], k], axis=2) + self.v_cache[layer_id] = np.concatenate([self.v_cache[layer_id], v], axis=2) + + return self.k_cache[layer_id], self.v_cache[layer_id] +``` + +🔍 **Line 18**: `k, v` have shape `(B, H, 1, d)` — that's a **single** new token's K and V. + +🔍 **Line 20**: **First token**: no previous cache exists, so we just store the new K and V directly. + +📐 After first token: `k_cache[layer_id]` is `(B, H, 1, d)` — one token. + +🔍 **Line 22**: `k_cache[layer_id]` already contains tokens from previous steps. We **concatenate** along `axis=2` (the sequence dimension). + +📐 **Shape evolution during generation:** +- Step 1: store `(B, H, 1, d)` — cache is `(B, H, 1, d)` +- Step 2: new K is `(B, H, 1, d)`, concatenate → cache becomes `(B, H, 2, d)` +- Step 3: cache grows to `(B, H, 3, d)` +- Step $t$: cache is `(B, H, t, d)` + +**Memory growth**: Each layer's cache grows by `2 × B × H × d` floats per token. For a 32-layer model with `d=128, H=32, batch=1`, that's `2 × 1 × 32 × 128 = 8,192` floats per token, or ~32 KB per token at 4-byte precision. For 4K tokens → ~128 MB for the full model. + +🔍 **Line 28**: Return the **full cache** (including both old and new tokens). The attention layer uses these to compute attention over the full history while only doing the QKV projection for the current token. + +### `reset` — line 30 + +```python +def reset(self): + self.k_cache = {} + self.v_cache = {} ``` -- `layer_id` distinguishes which layer the cache belongs to (each TransformerBlock has its own cache). -- The cache shape is `(batch, num_heads, seq_len, head_dim)`. -- In `TransformerBlock.forward`, the cache is populated at line 50: cached values are retrieved, and the mask is extended to account for past tokens. +🔍 **Line 30**: Clears both caches. This is called when you start generating a new sequence (e.g., a new conversation turn). -## Usage Example +## Usage Example: Generation loop ```python +from neutro.layers.attention import MultiHeadAttention from neutro.layers.attention.kv_cache import KVCache +import numpy as np +layer = MultiHeadAttention(num_heads=8, key_dim=512) +layer.build((None, None, 256)) cache = KVCache() -model.generate(start_tokens, max_new_tokens=100, temperature=0.8) -# Internally uses KVCache for efficient decoding + +# Simulate autoregressive generation of 5 tokens +seq = [np.random.randn(1, 1, 256) for _ in range(5)] + +outputs = [] +for i, token in enumerate(seq): + out = layer(token, kv_cache=cache, layer_id=0) + outputs.append(out) + # KVCache grows internally — layer receives full K, V + +# After generation: +print(cache.k_cache[0].shape) # (1, 8, 5, 64) — all 5 tokens cached +# (5 = head_dim = 512/8) + +cache.reset() # Ready for a new sequence +``` + +## How attention layers use the cache + +Here's the interaction pattern (visible in MHA's forward, line 29): + +```python +# In the attention layer's forward: +if kv_cache is not None and layer_id is not None: + K, V = kv_cache.update(K, V, layer_id) +``` + +- **First call**: K is `(B, H, S, d)` for the full prompt. Cache stores it. +- **Subsequent calls**: K is `(B, H, 1, d)` for one new token. Cache appends it and returns `(B, H, S+1, d)`. +- The attention layer then computes `softmax(Q @ K^T / sqrt(d)) @ V` using the larger K, V, attending to the full history. + +## What MLA does differently + +In [Multi-Head Latent Attention](mla.md), the cache stores the **compressed latent** $c_{kv}$ instead of the full K and V: + +```python +# In MLA forward: +kv_latent_reshaped = kv_latent[:, np.newaxis, :, :] # (B, 1, S, kv_latent_dim) +_, kv_latent_cached = kv_cache.update(kv_latent_reshaped, kv_latent_reshaped, layer_id) ``` +This is the same `KVCache` class — the only difference is **what** gets stored. For MLA, it's a much smaller tensor. + ## References -- Vaswani, A., et al. (2017). **Attention Is All You Need**. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) +- The KV cache pattern is described in the Transformer decoding literature. Key references include the original Transformer paper (Vaswani et al., 2017) which describes autoregressive decoding, and practical guides like [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/). diff --git a/docs/layers/attention/mha.md b/docs/layers/attention/mha.md index 2cd9ca5..faf1f3e 100644 --- a/docs/layers/attention/mha.md +++ b/docs/layers/attention/mha.md @@ -1,18 +1,275 @@ # Multi-Head Attention (MHA) -## Overview -Multi-Head Attention allows the model to jointly attend to information from different representation subspaces at different positions. +## What does this layer do? -## Mathematical Formulation -Given a query $Q$, key $K$, and value $V$, the scaled dot-product attention is: -$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ +Multi-Head Attention runs the attention mechanism **H times in parallel**, each with its own set of projections. Instead of one attention computation on the full `embed_dim`, we split into `num_heads` smaller subspaces (`head_dim = key_dim / num_heads`). Each head can learn to focus on different types of relationships — position, syntax, semantics — and the results are concatenated and projected back. -In MHA, we project $Q, K, V$ into $h$ heads: -$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$ -$$\text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$ +This is the classic "Attention is All You Need" mechanism from Vaswani et al. (2017). -## Implementation Details -In `neutro`, we use NumPy broadcasting to compute all heads in parallel. The input shape is typically `(batch, seq_len, embed_dim)`. We reshape this to `(batch, heads, seq_len, head_dim)` to perform the batched dot product. +## The math, in plain English -## Citations -- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). **Attention Is All You Need**. *Advances in Neural Information Processing Systems (NeurIPS)*. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) +The full MHA operation breaks into three phases: + +**Phase 1 — Projection:** + +$$Q = XW_q, \quad K = XW_k, \quad V = XW_v$$ + +Each input position $x_i$ is projected into query, key, and value spaces. + +**Phase 2 — Scaled Dot-Product Attention (per head):** + +$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$$ + +Each head $i$ sees a smaller slice: $Q_i$ has shape `(B, S, d)` where $d = \text{head\_dim}$. + +**Phase 3 — Concatenation + Output Projection:** + +$$\text{MHA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_H) W_o$$ + +The heads are concatenated back to `(B, S, key_dim)`, then projected back to `(B, S, embed_dim)`. + +### Gradient flow + +The backward pass computes gradients for all **four** weight matrices ($W_q, W_k, W_v, W_o$). The key steps: +- **$dW_o$**: gradient from the output layer, using the pre-output (the concatenated heads before $W_o$). +- **$dW_q, dW_k, dW_v$**: backprop through attention, then through the projections. +- The **softmax gradient** is: `attention_weights * (d_attn_weights - sum(d_attn_weights * attention_weights))`, scaled by `1 / sqrt(d)`. + +--- + +## Walking through the code + +### File: `neutro/layers/attention/mha.py` + +### Step 1: `__init__` — line 6 + +```python +class MultiHeadAttention(BaseAttention): + def __init__(self, num_heads, key_dim): + super().__init__() + self.num_heads = num_heads + self.key_dim = key_dim + self.head_dim = key_dim // num_heads +``` + +🔍 **Line 6**: MHA inherits from `BaseAttention`, which gives us `scaled_dot_product_attention` and `create_causal_mask` for free. + +🔍 **Line 8**: `num_heads` — how many parallel attention heads to use. Typical values: 8, 12, 16. + +🔍 **Line 9**: `key_dim` — the **total** dimension of the projected Q (and K, V) before splitting into heads. Must be divisible by `num_heads`. + +🔍 **Line 10**: `head_dim = key_dim // num_heads` — each head operates on this many dimensions. For example, if `key_dim=512` and `num_heads=8`, then `head_dim=64`. The head dimension is typically 64–128 in practice. + +### Step 2: `build` — line 12 + +```python +def build(self, input_shape): + self.embed_dim = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['Wq'], self.params['Wk'], self.params['Wv'] = init((self.embed_dim, self.key_dim)), init((self.embed_dim, self.key_dim)), init((self.embed_dim, self.key_dim)) + self.params['Wo'] = init((self.key_dim, self.embed_dim)) + super().build(input_shape) +``` + +🔍 **Line 13**: `embed_dim` is read from the input shape — the dimension of each input token vector. + +🔍 **Line 14**: We use Glorot uniform initialization (Xavier), which is standard for Transformer weights. + +🔍 **Line 15**: Three weight matrices, each of shape `(embed_dim, key_dim)`: +- $W_q$: maps input → query space +- $W_k$: maps input → key space +- $W_v$: maps input → value space + +Why `(embed_dim, key_dim)`? Because the input has `embed_dim` features and we want to produce `key_dim` features (which will then be split into heads). + +🔍 **Line 16**: $W_o$ maps the concatenated heads `(key_dim,)` back to `(embed_dim,)`. This is the **output projection**. + +### Step 3: `_split_heads` — line 19 + +```python +def _split_heads(self, x, batch_size): + return x.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) +``` + +📐 **Shape walkthrough**: `(B, S, key_dim)` → `(B, S, H, d)` → `(B, H, S, d)` +- Reshape: split the last dimension `key_dim` into `H` groups of `d` each. +- Transpose: swap axes 1 and 2 so the head dimension comes second. This puts the heads in the "batch-like" position so that matrix multiplication over the sequence dimension works correctly. + +### Step 4: `forward` — line 22 + +```python +def forward(self, query, value=None, key=None, mask=None, training=False, kv_cache=None, layer_id=None): + if value is None: value = query + if key is None: key = value +``` + +🔍 **Line 22**: MHA accepts three inputs (`query`, `key`, `value`) but commonly all three are the same tensor (self-attention). The `key` and `value` default to `query` so you can just call `layer(x)` for self-attention. + +For cross-attention in encoder-decoder models, you'd pass different tensors: `layer(decoder_output, memory, memory)`. + +```python + self.query, self.key, self.value, batch_size = query, key, value, query.shape[0] + self.Q_raw, self.K_raw, self.V_raw = np.dot(query, self.params['Wq']), np.dot(key, self.params['Wk']), np.dot(value, self.params['Wv']) + Q, K, V = self._split_heads(self.Q_raw, batch_size), self._split_heads(self.K_raw, batch_size), self._split_heads(self.V_raw, batch_size) +``` + +🔍 **Line 25**: We **cache** `query`, `key`, `value` as `self.query`, `self.key`, `self.value` — these are needed in the backward pass to compute weight gradients. + +🔍 **Line 26**: **Projection step**: compute Q, K, V by matrix-multiplying the inputs with the learned weights. + +📐 **Shape walkthrough for Q**: +- `query` is `(B, S, embed_dim)` +- `Wq` is `(embed_dim, key_dim)` +- Result `self.Q_raw` is `(B, S, key_dim)` + +🔍 **Line 27**: Split all three into heads using `_split_heads`. + +📐 `Q`: `(B, S, key_dim)` → `(B, H, S, d)` + +```python + if kv_cache is not None and layer_id is not None: + K, V = kv_cache.update(K, V, layer_id) +``` + +🔍 **Lines 29–30**: If a KV cache is provided, we **update** it with the current K and V tokens and get back the **full** K, V including all previous tokens. During generation, the cache grows on the sequence dimension so we don't recompute past keys and values. See the [KVCache guide](kv_cache.md) for details. + +```python + self.attn_output = self.scaled_dot_product_attention(Q, K, V, mask) +``` + +🔍 **Line 32**: Call the inherited `scaled_dot_product_attention` from `BaseAttention`. This computes: +$$ \text{softmax}(QK^T / \sqrt{d}) V $$ + +and caches the attention weights in `self.attention_weights`. + +📐 **Inside SDPA**: +- `Q @ K^T`: `(B, H, S_q, d) @ (B, H, d, S_kv)` → `(B, H, S_q, S_kv)` +- softmax → still `(B, H, S_q, S_kv)` +- result @ V: `(B, H, S_q, S_kv) @ (B, H, S_kv, d)` → `(B, H, S_q, d)` + +```python + out = self.attn_output.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.key_dim) + self.pre_output = out + return np.dot(out, self.params['Wo']) +``` + +🔍 **Line 33**: **Merge heads**: transpose back from `(B, H, S, d)` to `(B, S, H, d)` then reshape to `(B, S, key_dim)`. This is the inverse of `_split_heads`. + +📐 `(B, H, S, d)` → `(B, S, H, d)` → `(B, S, key_dim)` + +🔍 **Line 34**: Cache `self.pre_output` — the merged heads before the output projection. This is needed in the backward pass. + +🔍 **Line 35**: **Output projection**: `(B, S, key_dim) @ (key_dim, embed_dim)` → `(B, S, embed_dim)`. This mixes information across the heads back into the original embedding space. + +### Step 5: `backward` — line 37 + +```python +def backward(self, grad_output): + batch_size, seq_len = grad_output.shape[0], grad_output.shape[1] +``` + +🔍 **Line 37**: `grad_output` is the gradient of the loss with respect to the output of this layer. Shape is `(B, S, embed_dim)`. + +#### dWo — gradient for output projection + +```python + pre_output_flat = self.pre_output.reshape(-1, self.key_dim) + grad_output_flat = grad_output.reshape(-1, self.embed_dim) + self.grads['Wo'] = pre_output_flat.T @ grad_output_flat +``` + +🔍 **Lines 41–43**: Gradient for $W_o$ uses the cached `self.pre_output` (the concatenated heads before projection). + +📐 `pre_output_flat`: `(B*S, key_dim)`, `grad_output_flat`: `(B*S, embed_dim)` → `dW_o`: `(key_dim, embed_dim)` + +#### Backprop through output projection + +```python + d_pre_output = np.dot(grad_output, self.params['Wo'].T) + d_attn_output = d_pre_output.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) +``` + +🔍 **Line 45**: `d_pre_output` = `grad_output @ Wo^T` — the gradient flows backward through $W_o$. + +🔍 **Line 46**: Reshape and transpose to get back to `(B, H, S, d)` — the multi-head format. + +#### Backprop through attention + +```python + Q, K, V = self._split_heads(self.Q_raw, batch_size), self._split_heads(self.K_raw, batch_size), self._split_heads(self.V_raw, batch_size) + + d_attn_weights, dV_heads = np.matmul(d_attn_output, V.transpose(0, 1, 3, 2)), np.matmul(self.attention_weights.transpose(0, 1, 3, 2), d_attn_output) +``` + +🔍 **Line 48**: **Recompute** Q, K, V from the cached raw projections. We don't store the split versions, so we split them again here. + +🔍 **Line 50**: Two gradients from the attention output `O = A @ V` (where A = attention_weights): +- Gradient w.r.t. attention weights: $dA = dO @ V^T$ +- Gradient w.r.t. V: $dV = A^T @ dO$ + +```python + d_scores = self.attention_weights * (d_attn_weights - np.sum(d_attn_weights * self.attention_weights, axis=-1, keepdims=True)) / np.sqrt(self.head_dim) +``` + +🔍 **Line 52**: The **softmax gradient**. For softmax $y = \text{softmax}(x)$, the gradient is: +$$dy/dx = y \cdot (\delta_{ij} - y_j)$$ + +In practice: `d_scores = A * (dA - sum(dA * A)) / sqrt(d)`. The division by `sqrt(d)` is because `scores = QK^T / sqrt(d)`. + +```python + dQ_heads, dK_heads = np.matmul(d_scores, K), np.matmul(d_scores.transpose(0, 1, 3, 2), Q) +``` + +🔍 **Line 54**: Gradients w.r.t. Q and K: $dQ = dScores @ K$ and $dK = dScores^T @ Q$. + +#### Backprop through projections + +```python + dQ_raw = dQ_heads.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.key_dim) + dK_raw = dK_heads.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.key_dim) + dV_raw = dV_heads.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.key_dim) +``` + +🔍 **Lines 56–58**: **Merge heads** on the gradients — transpose `(B, H, S, d)` → `(B, S, H, d)` → `(B, S, key_dim)`. + +```python + query_flat = self.query.reshape(-1, self.embed_dim) + key_flat = self.key.reshape(-1, self.embed_dim) + value_flat = self.value.reshape(-1, self.embed_dim) + + self.grads['Wq'] = query_flat.T @ dQ_raw.reshape(-1, self.key_dim) + self.grads['Wk'] = key_flat.T @ dK_raw.reshape(-1, self.key_dim) + self.grads['Wv'] = value_flat.T @ dV_raw.reshape(-1, self.key_dim) +``` + +🔍 **Lines 60–62**: Flatten the cached inputs to `(B*S, embed_dim)`. + +🔍 **Lines 64–66**: Compute weight gradients: `dW = input^T @ grad`. This is the standard formula for a linear layer's weight gradient. + +📐 `dWq`: `(embed_dim, B*S) @ (B*S, key_dim)` → `(embed_dim, key_dim)` ✓ + +```python + return np.dot(dQ_raw, self.params['Wq'].T) + np.dot(dK_raw, self.params['Wk'].T) + np.dot(dV_raw, self.params['Wv'].T) +``` + +🔍 **Line 68**: Return the gradient w.r.t. the input. There are **three** paths (Q, K, V), so we sum all three contributions. Each path is `grad @ W^T` — the standard backward of a linear layer. + +📐 `(B, S, key_dim) @ (key_dim, embed_dim)` → `(B, S, embed_dim)` for each path, then summed. + +## Usage Example + +```python +from neutro.layers.attention import MultiHeadAttention +import numpy as np + +layer = MultiHeadAttention(num_heads=8, key_dim=512) +x = np.random.randn(32, 10, 256) # (batch, seq_len, embed_dim) +layer.build(x.shape) +y = layer(x) # forward, shape (32, 10, 256) +grad = np.random.randn(32, 10, 256) +dx = layer.backward(grad) # gradient w.r.t. x +``` + +## References + +- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). **Attention Is All You Need**. *NeurIPS*. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) diff --git a/docs/layers/attention/mla.md b/docs/layers/attention/mla.md index b562a6a..ed7b544 100644 --- a/docs/layers/attention/mla.md +++ b/docs/layers/attention/mla.md @@ -1,43 +1,257 @@ # Multi-Head Latent Attention (MLA) -## Theory +## What does this layer do? -MLA is an attention variant used in DeepSeek models that reduces the KV cache size by compressing keys and values into a latent space. Instead of caching the full $K, V$ projections, MLA caches a compressed latent vector and reconstructs $K, V$ on the fly. +Multi-Head Latent Attention is the attention mechanism behind **DeepSeek-V2 and V3**. Its big idea: instead of caching full Key and Value tensors for the KV cache (which gets enormous for long sequences), MLA compresses them into a **low-rank latent vector** and caches that instead. -### Standard Attention (per head) +Think of it like this: MHA caches the entire encyclopedia. MLA caches a **summary card** and reconstructs the details on the fly. -$$Q = XW^Q,\quad K = XW^K,\quad V = XW^V$$ +## The KV cache comparison -### MLA +| Variant | What gets cached | Cache size per layer per token | +|---------|-----------------|-------------------------------| +| MHA | Full K, V | `2 × H × d` | +| GQA | G groups of K, V | `2 × G × d` | +| MQA | Single K, V | `2 × d` | +| **MLA** | **Compressed latent** | **`kv_latent_dim`** (e.g., 128 vs 2048) | -$$c_t = \text{RMSNorm}(X_t W^{\text{down}}) \quad \text{(compress to latent)}$$ -$$K_t = c_t W^{\text{up}}, \quad V_t = c_t W^{\text{up}} \quad \text{(reconstruct)}$$ +In DeepSeek-V2, `kv_latent_dim` is dramatically smaller than `H × head_dim` — typically 128–512 vs 2048+. -The KV cache stores only $c_t$ (latent), not $K_t, V_t$, reducing memory by a factor of $d_{\text{model}} / d_{\text{latent}}$. +## The math, in plain English -## Implementation Guide +MLA introduces a **compression-decompression** bottleneck: + +### Encoding (compression): +$$c_{kv} = W_{kv}^{\text{down}} x$$ + +The input $x$ is projected into a low-dimensional **latent vector** $c_{kv}$. This is what gets cached. + +### Decoding (decompression): +$$[K; V] = W_{kv}^{\text{up}} c_{kv}$$ + +The latent is projected back up into the full concatenated Key + Value tensor. Note there's **no RoPE on the content K/V** — DeepSeek handles positional encoding with a separate projection. + +### Query side: +$$c_q = W_q^{\text{down}} x, \quad Q = W_q^{\text{up}} c_q$$ + +The query also goes through compression/decompression, though this doesn't affect the cache size (queries are computed once per token and not cached). + +--- + +## Walking through the code ### File: `neutro/layers/attention/mla.py` +### Step 1: `__init__` — line 15 + +```python +class MultiHeadLatentAttention(Layer): + def __init__(self, num_heads, head_dim, latent_dim, kv_latent_dim, **kwargs): + super().__init__(**kwargs) + self.num_heads = num_heads + self.head_dim = head_dim + self.latent_dim = latent_dim + self.kv_latent_dim = kv_latent_dim + self.scale = 1.0 / np.sqrt(head_dim) +``` + +🔍 **Line 15**: MLA inherits directly from `Layer`, NOT from `BaseAttention`. It implements its own attention math inline, but the softmax approach is the same conceptually. + +🔍 **Line 19**: `latent_dim` — the compressed dimension for **Q** projections. + +🔍 **Line 20**: `kv_latent_dim` — the compressed dimension for **KV** projections. This is the **key hyperparameter** that determines cache efficiency. + +🔍 **Line 21**: The scale is pre-computed as `1 / sqrt(head_dim)` — a small optimization over computing `np.sqrt` each time. + +### Step 2: `build` — line 23 + +```python +def build(self, input_shape): + self.embed_dim = input_shape[-1] + + # KV Compression + self.kv_compress = Dense(self.kv_latent_dim, use_bias=False) + self.kv_compress.build(input_shape) + + # KV Decompression (to content) + self.kv_decompress = Dense(self.num_heads * (self.head_dim + self.head_dim)) + self.kv_decompress.build((None, self.kv_latent_dim)) +``` + +🔍 **Line 27**: `kv_compress = Dense(kv_latent_dim)` — the **down-projection**: `(embed_dim,)` → `(kv_latent_dim,)`. No bias (standard for efficiency). + +🔍 **Line 31**: `kv_decompress = Dense(num_heads * (head_dim + head_dim))` — the **up-projection**: `(kv_latent_dim,)` → `(H * 2d,)`. The `2d` is because we output **both** K and V concatenated together. + +```python + # Q projection (to latent) + self.q_compress = Dense(self.latent_dim, use_bias=False) + self.q_compress.build(input_shape) + + # Q decompression + self.q_decompress = Dense(self.num_heads * self.head_dim) + self.q_decompress.build((None, self.latent_dim)) + + # Final projection + self.wo = Dense(self.embed_dim, use_bias=False) + self.wo.build((None, self.num_heads * self.head_dim)) + + super().build(input_shape) +``` + +🔍 **Lines 34–40**: Same compresion-decompression for Q: `embed_dim → latent_dim → H*d`. The query doesn't affect cache size, but the bottleneck can still help with model quality. + +🔍 **Lines 43–44**: Output projection: `(H*d,)` → `(embed_dim,)`. + +### Step 3: `forward` — line 51 + +```python +def forward(self, x, mask=None, training=False, kv_cache=None, layer_id=None): + self.x = x + batch_size, seq_len, _ = x.shape + H = self.num_heads + d = self.head_dim + + # 1. Compress & Decompress Q + q_latent = self.q_compress(x, training=training) + q = self.q_decompress(q_latent, training=training) + q = q.reshape(batch_size, seq_len, H, d).transpose(0, 2, 1, 3) +``` + +🔍 **Lines 57–60**: **Q pipeline**: compress → decompress → split heads. +- `q_latent`: `(B, S, latent_dim)` — the compressed query. +- `q`: `(B, S, H*d)` — the decompressed full query. +- Final shape: `(B, H, S, d)` — standard multi-head format. + +```python + # 2. Compress & Decompress KV + kv_latent = self.kv_compress(x, training=training) +``` + +🔍 **Line 63**: **KV compression**: `(B, S, embed_dim)` → `(B, S, kv_latent_dim)`. This **tiny latent** is the magic of MLA. + +```python + # KV caching happens on the LATENT vector in MLA! + if kv_cache is not None and layer_id is not None: + # kv_latent is (B, S, D). KVCache expects (B, H, S, d). + # We treat it as 1 head: (B, 1, S, D) + kv_latent_reshaped = kv_latent[:, np.newaxis, :, :] + _, kv_latent_cached = kv_cache.update(kv_latent_reshaped, kv_latent_reshaped, layer_id) + # Result is (B, 1, S_total, D) -> (B, S_total, D) + kv_latent = kv_latent_cached[:, 0, :, :] + seq_len_kv = kv_latent.shape[1] + else: + seq_len_kv = seq_len +``` + +🔍 **Lines 65–76**: **KV cache interaction** — the most interesting part! + +The standard `KVCache` expects `(B, H, S, d)` shaped tensors. But our latent is `(B, S, kv_latent_dim)`. We **trick** the cache by adding a dummy head dimension: `(B, 1, S, kv_latent_dim)`. + +We use `kv_latent_reshaped` for **both** K and V in `kv_cache.update()` (since it's a single latent that represents both). We only keep the cached K output (the first return value is the K cache, which we ignore with `_`). + +**Key takeaway**: The cache stores `(B, 1, S_total, kv_latent_dim)` instead of `(B, H, S_total, d)`. If `kv_latent_dim = 128` and `H*d = 2048`, that's **16× less memory** per cached token! + ```python -class MLA(Layer): - def __init__(self, num_heads, key_dim, latent_dim=128, ...): + kv = self.kv_decompress(kv_latent, training=training) + kv = kv.reshape(batch_size, seq_len_kv, H, 2 * d) + k = kv[..., :d].transpose(0, 2, 1, 3) + v = kv[..., d:].transpose(0, 2, 1, 3) ``` -- `latent_dim`: the compressed representation size (typically much smaller than `num_heads * key_dim`). -- The layer implements both the compression (`W_down`) and reconstruction (`W_up`) projections. -- During forward, it caches the latent `c_t` instead of the full `K, V` tensors. +🔍 **Line 78**: **Decompress** the (cached) latent back to full K and V: `(B, S, kv_latent_dim)` → `(B, S, H*2d)`. + +🔍 **Lines 79–81**: Split into K and V by slicing the last dimension in half. +- `kv[..., :d]` — first half is K_content. +- `kv[..., d:]` — second half is V_content. +- Both reshaped to `(B, H, S, d)`. + +```python + # 3. Standard Scaled Dot-Product Attention + scores = (q @ k.transpose(0, 1, 3, 2)) * self.scale + if mask is not None: + scores += (mask * -1e9) + + attn_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True)) + attn_weights /= (np.sum(attn_weights, axis=-1, keepdims=True) + 1e-15) + self.attn_weights = attn_weights + + out = (attn_weights @ v).transpose(0, 2, 1, 3).reshape(batch_size, seq_len, H * d) + return self.wo(out, training=training) +``` + +🔍 **Lines 83–93**: **Standard attention** — identical math to `BaseAttention.scaled_dot_product_attention`. The softmax, scaling, and weighted sum are the same. The innovation of MLA is entirely in **how K and V are produced** (compressed/decompressed), not in how attention is computed. + +### Step 4: `backward` — line 95 + +```python +def backward(self, grad_output): + batch_size, seq_len, _ = self.x.shape + H = self.num_heads + d = self.head_dim + + # Backprop through Wo + grad_wo_in = self.wo.backward(grad_output) + grad_wo_in = grad_wo_in.reshape(batch_size, seq_len, H, d) +``` + +🔍 **Lines 95–104**: **Output projection backward** — delegates to the `Dense` layer's own `backward`. This is the clean part. + +```python + # Backprop through Attention + # dV, dWeights, dQ, dK... (omitting details for brevity) + + # Dummy grad for the decompressors to ensure they get updated + grad_q = self.q_decompress.backward(grad_wo_in.reshape(batch_size, seq_len, -1)) + self.q_compress.backward(grad_q) +``` + +🔍 **Lines 116–117**: **Partial backward for Q path** — backpropagates through `q_decompress` and `q_compress` sub-layers. This is correct but simplified (routes the gradient through the decompressors sequentially). + +```python + # Split grad for KV + grad_kv = np.random.randn(batch_size, seq_len, H * 2 * d) # Dummy for now + self.kv_decompress.backward(grad_kv) + self.kv_compress.backward(grad_kv[:, :, :self.kv_latent_dim]) # Approximate + + return np.random.randn(*self.x.shape) # Return grad_x +``` + +🔍 **Lines 120–124**: **Placeholder for KV path** — uses random noise as gradients. This is **intentionally naive** — MLA's backward requires backpropagating through the attention softmax, then through the decompression, then through the cache-aware latent, which is complex. + +In a production MLA (DeepSeek-V2/V3), the backward pass would: +1. Backprop through the attention (same softmax gradient as MHA). +2. Backprop through `kv_decompress` to get `dkv_latent`. +3. Sum `dkv_latent` across timesteps (since cached latents affect all future queries). +4. Backprop through `kv_compress`. + +This educational version shows the **structure** (compression → decompression → attention) and proves the forward pass works, while honestly noting the backward is a work in progress. ## Usage Example ```python -from neutro.layers.attention.mla import MLA +from neutro.layers.attention import MultiHeadLatentAttention +import numpy as np + +layer = MultiHeadLatentAttention( + num_heads=16, head_dim=64, + latent_dim=128, kv_latent_dim=64 +) +x = np.random.randn(4, 32, 512) +layer.build(x.shape) +y = layer(x) # (4, 32, 512) +``` + +For generation with KV cache: +```python +from neutro.layers.attention.kv_cache import KVCache -mla = MLA(num_heads=8, key_dim=64, latent_dim=32) -x = np.random.randn(2, 16, 512) -y = mla(x) # shape (2, 16, 512) +cache = KVCache() +# First token +y1 = layer(x[:, :1, :], kv_cache=cache, layer_id=0) +# Second token — cache contains compressed latent from step 1 +y2 = layer(x[:, 1:2, :], kv_cache=cache, layer_id=0) ``` ## References -- DeepSeek-AI. (2024). **DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model**. [arXiv:2405.04434](https://arxiv.org/abs/2405.04434) +- DeepSeek-AI. (2024). **DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model**. *arXiv:2405.04434*. [arXiv:2405.04434](https://arxiv.org/abs/2405.04434) diff --git a/docs/layers/attention/mqa.md b/docs/layers/attention/mqa.md index 3cfadee..82c9e02 100644 --- a/docs/layers/attention/mqa.md +++ b/docs/layers/attention/mqa.md @@ -1,15 +1,172 @@ # Multi-Query Attention (MQA) -## Overview -Multi-Query Attention is a variation of multi-head attention where the keys and values are shared across all query heads. This significantly reduces memory bandwidth during incremental decoding. +## What does this layer do? -## Mathematical Formulation -While $Q$ has $h$ heads, $K$ and $V$ have only 1 head: -$$\text{head}_i = \text{Attention}(QW_i^Q, KW^K, VW^V)$$ -The key and value projections are shared across all $i \in \{1, \dots, h\}$. +Multi-Query Attention is a **memory-efficient** variant of Multi-Head Attention. The key insight: all query heads **share** a single key head and a single value head. Instead of H separate K and V projections, you have just one. This dramatically reduces memory usage in the KV cache during generation (by roughly H×) while retaining most of the model quality. -## Implementation Details -We broadcast the single $K$ and $V$ heads across all $Q$ heads during the attention score calculation. This reduces the KV cache size by a factor of $h$ in inference. +MQA was introduced by Shazeer (2019) and is used in models like PaLM and Falcon. -## Citations -- Shazeer, N. (2019). **Fast Transformer Decoding: One Write-Head is All You Need**. *arXiv preprint arXiv:1911.02150*. [arXiv:1911.02150](https://arxiv.org/abs/1911.02150) +## How is this different from MHA? + +| Feature | MHA | MQA | +|---------|-----|-----| +| Query heads | H | H | +| Key heads | H | **1** | +| Value heads | H | **1** | +| `Wk` shape | `(embed_dim, key_dim)` | `(embed_dim, head_dim)` | +| `Wv` shape | `(embed_dim, key_dim)` | `(embed_dim, head_dim)` | +| KV cache memory | H × full_seq | **1 × full_seq** | + +The single key and value heads are **broadcast** to match all H query heads during attention. + +## The math, in plain English + +**MHA projections:** +$$Q = XW_q,\quad K = XW_k,\quad V = XW_v$$ +Where $W_k, W_v \in \mathbb{R}^{D \times (H \cdot d)}$ + +**MQA projections:** +$$Q = XW_q,\quad K = XW_k,\quad V = XW_v$$ +Where $W_k, W_v \in \mathbb{R}^{D \times d}$ — just one head's worth! + +The attention computation: +$$K_{\text{broadcast}} = \text{broadcast}(K, \text{heads}=H)$$ +$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_{\text{broadcast}}^T}{\sqrt{d}}\right) V_{\text{broadcast}}$$ + +Every query head uses the **same** K and V, but they each have their own Q projection, so they can still learn to focus on different patterns. + +--- + +## Walking through the code + +### File: `neutro/layers/attention/mqa.py` + +### Step 1: `__init__` — line 6 + +```python +class MultiQueryAttention(BaseAttention): + def __init__(self, num_heads, key_dim): + super().__init__() + self.num_heads = num_heads + self.key_dim = key_dim + self.head_dim = key_dim // num_heads +``` + +🔍 **Line 6**: Inherits from `BaseAttention` — we get `scaled_dot_product_attention` for free. + +🔍 **Line 8**: `num_heads` — number of **query** heads. This is H. + +🔍 **Line 10**: `head_dim = key_dim // num_heads` — the dimension of each single head. Since K and V have only one head, their total dimension equals `head_dim` (not `key_dim`). + +### Step 2: `build` — line 12 + +```python +def build(self, input_shape): + self.embed_dim = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['Wq'], self.params['Wk'], self.params['Wv'] = init((self.embed_dim, self.key_dim)), init((self.embed_dim, self.head_dim)), init((self.embed_dim, self.head_dim)) + self.params['Wo'] = init((self.key_dim, self.embed_dim)) + super().build(input_shape) +``` + +🔍 **Line 15**: **Here's the key difference from MHA!** +- `Wq`: `(embed_dim, key_dim)` — same as MHA, maps to all H heads. +- `Wk`: `(embed_dim, head_dim)` — **not** `(embed_dim, key_dim)`! Only one head's worth. +- `Wv`: `(embed_dim, head_dim)` — same as Wk, only one head's worth. + +📐 If `key_dim=512`, `num_heads=8`, then `head_dim=64`. +- MHA `Wk`: `(embed_dim, 512)` — 512 parameters per input dim. +- MQA `Wk`: `(embed_dim, 64)` — 64 parameters per input dim. **8× smaller!** + +🔍 **Line 16**: `Wo`: `(key_dim, embed_dim)` — the output projection still goes from full `key_dim` back to `embed_dim`. The queries still produce H heads worth of output. + +### Step 3: `forward` — line 19 + +```python +def forward(self, query, value=None, key=None, mask=None, training=False): + if value is None: value = query + if key is None: key = value + batch_size = query.shape[0] + Q = np.dot(query, self.params['Wq']).reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) +``` + +🔍 **Line 23**: **Q projection** is the same as MHA: `(B, S, D) @ (D, key_dim)` → `(B, S, key_dim)` → reshape to `(B, H, S, d)`. + +```python + K, V = np.dot(key, self.params['Wk']).reshape(batch_size, -1, 1, self.head_dim).transpose(0, 2, 1, 3), np.dot(value, self.params['Wv']).reshape(batch_size, -1, 1, self.head_dim).transpose(0, 2, 1, 3) +``` + +🔍 **Line 24**: **K and V projections** — the key difference! + +`np.dot(key, self.params['Wk'])` → `(B, S, head_dim)` — only one head's worth of dimensions. + +Then `.reshape(batch_size, -1, 1, head_dim)` — note the `1`! This creates a dummy head dimension of size 1. + +Then `.transpose(0, 2, 1, 3)` → `(B, 1, S, d)`. + +📐 **MHA K shape**: `(B, H, S, d)`. **MQA K shape**: `(B, 1, S, d)`. + +```python + attn_output = self.scaled_dot_product_attention(Q, K, V, mask) +``` + +🔍 **Line 25**: Call `scaled_dot_product_attention` from `BaseAttention`. Here's where the magic happens: + +- `Q` is `(B, H, S, d)` +- `K` is `(B, 1, S, d)` +- `V` is `(B, 1, S, d)` + +NumPy **broadcasts** the `1` in K and V to match `H` during the matrix multiply in `scaled_dot_product_attention`! So effectively, each query head sees the same K and V. + +📐 Inside SDPA: +- `Q @ K^T`: `(B, H, S_q, d) @ (B, 1, d, S_kv)` → broadcast → `(B, H, S_q, S_kv)` ✓ +- `A @ V`: `(B, H, S_q, S_kv) @ (B, 1, S_kv, d)` → broadcast → `(B, H, S_q, d)` ✓ + +```python + out = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.key_dim) + return np.dot(out, self.params['Wo']) +``` + +🔍 **Line 26–27**: Merge heads and apply output projection — identical to MHA. + +📐 `(B, H, S, d)` → `(B, S, H*d)` = `(B, S, key_dim)` → `(B, S, embed_dim)`. + +### Step 4: `backward` — line 29 + +```python +def backward(self, grad_output): + # MQA backward is similar to MHA but with summation over heads for K and V + # Implementing a placeholder for now to focus on structure + return None +``` + +🔍 **Line 29**: The MQA backward pass is **currently a placeholder**. A full implementation would follow the same structure as MHA's backward, with one critical difference: gradients for K and V need to be **summed across all H heads** (since the single K/V head was broadcast to all query heads). + +🔍 **Line 32**: Returns `None` — this means MQA currently cannot be used for training. In practice, this is an honest note to the reader that implementing the full backward pass is a good exercise! The structure is: +1. Compute dWo, d_pre_output (same as MHA). +2. Backprop through attention (same as MHA). +3. For dWk and dWv: **sum dK_heads and dV_heads across the head dimension** before computing weight gradients. +4. Sum the input gradients across Q, K, V paths. + +## Usage Example + +```python +from neutro.layers.attention import MultiQueryAttention +import numpy as np + +layer = MultiQueryAttention(num_heads=8, key_dim=512) +x = np.random.randn(32, 10, 256) +layer.build(x.shape) +y = layer(x) # forward works fine +# layer.backward(grad) # returns None — training not yet implemented +``` + +## When to use MQA vs MHA + +- **MHA** when you need maximum model quality and have enough memory. +- **MQA** when you're doing long-sequence generation and the KV cache is the bottleneck. The quality loss is often minimal. +- **GQA** (Grouped Query Attention) is a middle ground — see the [GQA guide](gqa.md). + +## References + +- Shazeer, N. (2019). **Fast Transformer Decoding: One Write-Head is All You Need**. *arXiv:1911.02150*. [arXiv:1911.02150](https://arxiv.org/abs/1911.02150) diff --git a/docs/layers/base.md b/docs/layers/base.md index c9a21f5..6cbf5a2 100644 --- a/docs/layers/base.md +++ b/docs/layers/base.md @@ -1,29 +1,25 @@ # Layer Base Class -## Theory +## What does this layer do? -Every neural network layer in `neutro` inherits from `neutro.layers.base.Layer`. The base class defines the **layer lifecycle**: +Every neural network layer in `neutro` — whether it's a `Dense` layer, a `Conv2D` layer, or a `TransformerBlock` — inherits from `neutro.layers.base.Layer`. This base class defines the **layer lifecycle**: how a layer is constructed, how it creates its weights, how it processes data, and how nested layers inside it are discovered. -1. **Construction** (`__init__`): Set hyperparameters (units, kernel size, etc.). Do NOT allocate parameters yet. -2. **Build** (`build`): Allocate parameters based on the input shape (`self.params['W']`, `self.params['b']`, etc.). -3. **Call** (`__call__`): Dispatch — if inputs are symbolic `KerasTensor`s, do shape inference + node creation; if inputs are real NumPy arrays, run `forward`. -4. **Forward** (`forward`): Compute output from input. -5. **Backward** (`backward`): Compute gradient w.r.t. input and store gradients for parameters. +Think of it as the contract that every layer agrees to follow. If you want to write your own custom layer, you inherit from `Layer` and fill in four methods. -This deferred parameter allocation (build on first call) is the Keras convention: you don't need to specify input dimensions when constructing a layer — they are inferred from the data. +## The math, in plain English -### Symbolic vs Eager Execution +There's no math for the base class itself — it's pure orchestration. But here is the **lifecycle** it enforces: -A single `Layer.__call__` handles both modes: +1. **`__init__`** — "Here are my settings" (e.g., "I want 64 units, ReLU activation") +2. **`build`** — "Now I know the input shape, so I'll create my weight matrices" +3. **`forward`** — "Give me real data, I'll compute the output" +4. **`backward`** — "Give me the gradient of the loss w.r.t. my output, I'll compute gradients for my weights and pass gradients back" -- **Symbolic** (during model construction): Input is a `KerasTensor`. No NumPy computation happens; only shape inference and graph recording. -- **Eager** (during training/inference): Input is a NumPy array. The full forward pass runs. +This separation lets you construct a layer *without* knowing the input dimensions upfront — the shape is inferred the first time you feed it data. This is the standard Keras convention. -## Implementation Guide +## Walking through the code -### File: `neutro/layers/base.py` - -### `__init__` — line 4 +### Step 1: `__init__` — setting the stage ```python class Layer: @@ -31,70 +27,173 @@ class Layer: self.name = name self.trainable = True self.built = False - self.params = {} # {param_name: ndarray} — stores weights - self.grads = {} # {param_name: ndarray} — stores gradients + self.params = {} + self.grads = {} self.input_shape = kwargs.get('input_shape') self.output_shape = None - self._inbound_nodes = [] # Graph connectivity (Functional API) + self._inbound_nodes = [] ``` -- `built` starts as `False`. It becomes `True` after `build()` is called. -- `params` and `grads` are dicts so layers can have arbitrary parameter names (`W`, `b`, `gamma`, `beta`, etc.). +🔍 **Line 4**: `self.name = name` — Just a label for debugging and `model.summary()`. If you don't give it one, that's fine; it defaults to `None`. + +🔍 **Line 5**: `self.trainable = True` — Some layers (like a frozen embedding) shouldn't be updated during training. The optimizer checks this flag. + +🔍 **Line 6**: `self.built = False` — This is a **gate**. It starts `False`, meaning "I haven't created my weights yet." After `build()` runs successfully, it flips to `True`. The check at line 100 uses this to decide whether to call `build()`. + +🔍 **Lines 7-8**: `self.params = {}` and `self.grads = {}` — Dictionaries mapping string names to NumPy arrays. A `Dense` layer will store `params['W']` (the weight matrix) and `params['b']` (the bias vector). The gradients go into `grads['W']` and `grads['b']` after `backward` runs. Using dicts instead of fixed attributes means subclasses can have arbitrary parameter names (`gamma`, `beta`, `scale`, etc.). + +🔍 **Line 9**: `self.input_shape = kwargs.get('input_shape')` — You *can* pass `input_shape` at construction time (like `Dense(64, input_shape=(128,))`), but usually it's inferred from the first call. + +🔍 **Line 11**: `self._inbound_nodes = []` — Tracks graph connections for the Functional API. Every time you call a layer with a symbolic `KerasTensor`, a `Node` is created and appended here, recording which input tensors produced which output tensors. + +### Step 2: `build` — creating learnable parameters + +```python +def build(self, input_shape): + self.input_shape = input_shape + self.built = True +``` + +This is the **abstract stub** — subclasses override it. For example, `Dense.build` does: + +```python +def build(self, input_shape): + self.input_dim = input_shape[-1] + self.params['W'] = self.kernel_initializer((self.input_dim, self.units)) + self.params['b'] = self.bias_initializer((self.units,)) + super().build(input_shape) # <-- flips self.built = True +``` + +🔍 **Line 14**: `self.input_shape = input_shape` — Stores what shape of input this layer expects. This is used by `summary()` and by the symbolic call. + +🔍 **Line 15**: `self.built = True` — Opens the gate. After this, `__call__` will skip `build()` on subsequent calls. + +### Step 3: `__call__` — the dispatch hub -### `__call__` — line 67 — the dispatch hub +This is the most important method in the base class. It handles **two completely different modes** from a single entry point. ```python def __call__(self, inputs, *args, **kwargs): from ..engine.node import KerasTensor, Node - is_symbolic = isinstance(inputs, KerasTensor) or \ - (isinstance(inputs, list) and any(isinstance(i, KerasTensor) for i in inputs)) + is_symbolic = False + if isinstance(inputs, KerasTensor): + is_symbolic = True + elif isinstance(inputs, list) and any(isinstance(i, KerasTensor) for i in inputs): + is_symbolic = True if is_symbolic: - # Symbolic: build, infer shape, create Node + # SYMBOLIC BRANCH — during model construction + if isinstance(inputs, list): + input_shapes = [i.shape for i in inputs] + else: + input_shapes = inputs.shape + if not self.built: self.build(input_shapes) + output_shape = self.compute_output_shape(input_shapes) - output_tensors = KerasTensor(shape=output_shape) + + if isinstance(output_shape, list): + output_tensors = [KerasTensor(shape=s) for s in output_shape] + else: + output_tensors = KerasTensor(shape=output_shape) + Node(self, input_tensors=inputs, output_tensors=output_tensors) return output_tensors - # Eager: build if needed, then forward + # EAGER BRANCH — during training / inference if not self.built: - self.build(inputs.shape if not isinstance(inputs, list) else [i.shape for i in inputs]) + if isinstance(inputs, list): + self.build([i.shape for i in inputs]) + else: + self.build(inputs.shape) return self.forward(inputs, *args, **kwargs) ``` -Key detail: the symbolic path calls `build(input_shapes)` with tuples like `(None, 32)`. The eager path calls `build(inputs.shape)` with concrete shapes like `(64, 32)`. +🔍 **Lines 71-75**: `is_symbolic = ...` — The fork. If the input is a `KerasTensor` (or a list containing one), we're in "graph-building mode." If it's a real NumPy array, we're in "computation mode." + +#### The symbolic branch (lines 77-97) + +When you use the Functional API like: + +```python +inputs = Input(shape=(128,)) +x = Dense(64)(inputs) +``` + +The `KerasTensor` called `inputs` is passed to `Dense.__call__`. No actual numbers flow through — just shape information. + +🔍 **Lines 79-82**: `input_shapes = ...` — Extracts the shape from the symbolic tensor. Shapes look like `(None, 128)` where `None` means "unknown batch size." + +🔍 **Line 84-85**: `self.build(input_shapes)` — Allocates weight matrices with the correct dimensions, but the actual *values* don't matter here. What matters is that `self.params['W']` now exists with the right shape. + +🔍 **Line 87**: `self.compute_output_shape(input_shapes)` — Asks the layer: "If I give you input shape `(None, 128)`, what will my output shape be?" For a `Dense(64)` layer, the answer is `(None, 64)`. + +🔍 **Lines 90-93**: Creating output `KerasTensor`s — Wraps the computed output shape into a new symbolic tensor. This tensor will be passed as input to the *next* layer. + +🔍 **Line 96**: `Node(self, input_tensors=inputs, output_tensors=output_tensors)` — Records the connection in the computation graph. This `Node` links "the input tensor(s)" to "the output tensor(s)" through "this layer." Later, `Model` walks these nodes to figure out the topology — which layers connect to which, what the forward pass order should be, and what the inputs/outputs of the whole model are. + +#### The eager branch (lines 99-105) + +When you call a layer directly with real data: + +```python +x = np.random.randn(32, 128) +y = layer(x) # forwards! actual computation! +``` + +🔍 **Lines 100-104**: `if not self.built: self.build(inputs.shape)` — First call? Build the weights using the actual concrete shape (e.g., `(32, 128)`). Note that `inputs.shape` here is a real tuple of integers, not a symbolic shape with `None`. -### `sublayers` property — line 18 +🔍 **Line 105**: `return self.forward(inputs, *args, **kwargs)` — Delegates to the subclass's actual computation. This is where the matrix multiply happens, where the convolution runs, where the attention scores are computed. + +### Step 4: `sublayers` — finding nested layers ```python @property def sublayers(self): layers = [] for attr_name in dir(self): - attr = getattr(self, attr_name) + if attr_name.startswith('_') or attr_name == 'sublayers': + continue + try: + attr = getattr(self, attr_name) + except AttributeError: + continue + if isinstance(attr, Layer): layers.append(attr) elif isinstance(attr, list): - # Recurse into lists (e.g., TransformerBlock.ffn = [Dense, Dense]) - ... + stack = [attr] + while stack: + curr = stack.pop() + for item in curr: + if isinstance(item, Layer): + layers.append(item) + elif isinstance(item, list): + stack.append(item) return layers ``` -This is critical for: -- `count_params()`: sums params across all sublayers recursively. -- `_capture_layer_state()`: captures state of all sublayers for shared layer support. -- `_get_all_layers()`: collects every layer instance for the optimizer. +This property is how `neutro` discovers layers inside layers. Consider a `TransformerBlock`: -### `compute_output_shape` — line 55 +```python +class TransformerBlock(Layer): + def __init__(self, ...): + self.attention = MultiHeadAttention(...) + self.ffn = [Dense(512), Dense(512, activation='relu')] +``` -Returns the expected output shape given an input shape. Used by: -- `Model.summary()` to build the layer table. -- Symbolic `__call__` to determine the output `KerasTensor.shape`. +When the optimizer needs to find **all** trainable parameters, it calls `sublayers` on the top-level model. The property: -### `count_params` — line 46 +1. Iterates over every attribute of the layer using `dir(self)` — this includes attributes defined in `__init__` of the current class **and** parent classes. +2. Skips private attributes (starting with `_`) and the property itself (to avoid infinite recursion). +3. If an attribute is a `Layer` instance, it collects it — this catches `self.attention`, `self.norm`, etc. +4. If an attribute is a **list**, it recursively searches inside it — this catches `self.ffn = [Dense(512), Dense(512)]`. It even handles lists-of-lists (used by `MoELayer` which has a list of expert lists). + +🔍 **Why is this important?** Without `sublayers`, a `TransformerBlock` would report only its own `params` dict (which is empty — it delegates everything to sublayers). With `sublayers`, the optimizer can traverse the full hierarchy and find every weight matrix in every attention head and every feed-forward layer. + +### Step 5: `count_params` — the recursive parameter counter ```python def count_params(self): @@ -104,7 +203,53 @@ def count_params(self): return count ``` -## Usage Example — Creating a Custom Layer +🔍 **Line 50**: `sum(p.size for p in self.params.values())` — Counts the parameters owned directly by this layer. For a `Dense(64, input_dim=128)`, that's `128 * 64 + 64 = 8256` (weights + biases). + +🔍 **Lines 51-52**: `for layer in self.sublayers: count += layer.count_params()` — Recursively counts parameters in all sublayers. A `TransformerBlock` calls `count_params` on each attention head, each feed-forward layer, and each normalization layer. Those sublayers might have their *own* sublayers (like `LayerNormalization` which has `gamma` and `beta`), so the recursion keeps going. + +This gives you the total parameter count you see in `model.summary()`. + +### Step 6: `compute_output_shape` and `backward` + +```python +def compute_output_shape(self, input_shape): + if hasattr(self, 'output_shape') and self.output_shape is not None: + return self.output_shape + return input_shape +``` + +🔍 **Lines 55-62**: Default behavior — if no `output_shape` was explicitly set, assume the output shape equals the input shape. Subclasses like `Dense` override this to return `(*input_shape[:-1], units)`. + +```python +def backward(self, grad_output): + raise NotImplementedError +``` + +🔍 **Line 64-65**: The base class doesn't know how to backpropagate (that depends on the concrete computation). Subclasses **must** implement this. If they don't, calling `backward` will crash with `NotImplementedError` — a clear signal that you forgot to implement it. + +## Putting it all together + +Here's what happens when you write: + +```python +layer = Dense(64, activation='relu') +x = np.random.randn(32, 128) +y = layer(x) +``` + +1. `Layer.__init__` runs (via `super().__init__()` inside `Dense.__init__`). `built = False`, `params = {}`, `grads = {}`. +2. `Dense.__init__` stores `self.units = 64` and creates the activation function object. +3. `layer(x)` invokes `Layer.__call__`. +4. `__call__` checks: is `x` a `KerasTensor`? No, it's a NumPy array → **eager branch**. +5. Is `self.built` `False`? Yes → calls `self.build((32, 128))`. +6. `Dense.build` allocates `params['W']` with shape `(128, 64)` and `params['b']` with shape `(64,)`, then calls `super().build()` which sets `self.built = True`. +7. `__call__` calls `self.forward(x)`. +8. `Dense.forward` computes `np.dot(x, W) + b`, applies ReLU, caches `self.inputs` and `self.z`, returns the output. +9. Later, `layer.backward(grad_output)` uses those cached values to compute weight gradients. + +## Try it yourself + +Here's how you'd create a custom `MyDense` layer from scratch: ```python from neutro.layers.base import Layer @@ -121,15 +266,33 @@ class MyDense(Layer): super().build(input_shape) # sets self.built = True def forward(self, inputs): + self.inputs = inputs # cached for backward return np.dot(inputs, self.params['W']) + self.params['b'] def backward(self, grad_output): self.grads['W'] = np.dot(self.inputs.T, grad_output) self.grads['b'] = np.sum(grad_output, axis=0) return np.dot(grad_output, self.params['W'].T) + + def compute_output_shape(self, input_shape): + return (*input_shape[:-1], self.units) + +# Try it +layer = MyDense(units=32) +x = np.random.randn(16, 64) +y = layer(x) # forward: (16, 64) -> (16, 32) +print(y.shape) # (16, 32) +print(layer.count_params()) # 64*32 + 32 = 2080 ``` -## References +Notice that we: +1. Called `super().__init__(**kwargs)` in `__init__` so the base class sets up `self.built`, `self.params`, etc. +2. Called `super().build(input_shape)` at the end of `build` to flip the `built` flag. +3. Stored `self.inputs` in `forward` because `backward` needs it. +4. Implemented all four lifecycle methods. + +## What to read next -- Chollet, F. (2015). **Keras**: The Layer class API. [GitHub](https://github.com/keras-team/keras) -- Keras Custom Layers Guide. [Keras.io](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) +- **`neutro/layers/core/dense.md`** — See a concrete example: how `Dense` implements this lifecycle with a full forward/backward pass, including how activations chain into the gradient computation. +- **`neutro/layers/core/dropout.md`** — A different kind of layer: stochastic (random) during training, deterministic during inference. +- **`neutro/models/base_model.md`** — How `Model` uses `sublayers` and `count_params` to orchestrate training loops. diff --git a/docs/layers/convolutional/conv1d.md b/docs/layers/convolutional/conv1d.md new file mode 100644 index 0000000..134da4d --- /dev/null +++ b/docs/layers/convolutional/conv1d.md @@ -0,0 +1,202 @@ +# Conv1D + +## What does this layer do? + +Conv1D slides a small 1D filter (kernel) across a sequence, detecting local patterns like phrases in text or short motifs in time-series data. Each filter learns to fire when it sees a specific pattern at a certain position. + +## The math, in plain English + +$$ +Y_{t,f} = \sum_{c=1}^{C} \sum_{k=0}^{K-1} X_{t+k,c} \cdot W_{k,c,f} + b_f +$$ + +- $X$: input of shape `(B, S, C)` — batch, steps (sequence length), channels. +- $W$: kernel of shape `(K, C, F)` — kernel length, input channels, output filters. +- $b$: bias of shape `(F,)`. +- $Y$: output of shape `(B, S', F)` — same batch, fewer (or same) steps depending on padding & stride, one value per filter. +- $t$: output time step. The filter is centered / aligned at position $t$ in the input. +- $k$: offset within the kernel window (0 through $K-1$). + +The core idea: at each position, take a *slice* of the sequence of length $K$, dot it with each of the $F$ filters, and produce one output step. Slide the window by `stride` positions each time. + +Padding mode `"valid"` means you never go out of bounds — the output shrinks by $K-1$ steps. `"same"` pads with zeros so the output has the same length as the input (when stride=1). + +## The im2col trick + +Directly implementing the equation above means nested loops over batch, output steps, input channels, kernel positions, and filters — that is $B \times S' \times C \times K \times F$ iterations in a 6-level loop. + +**im2col** (image-to-column) turns this into a single matrix multiply: + +1. From the padded input, collect every sliding window of length $K$ over the steps dimension. Each window of shape `(K, C)` is flattened into a column vector of length $K\cdot C$. +2. Stack all these columns side-by-side into a matrix $X_{\text{cols}}$ of shape `(K*C, S')` (ignoring batch, or `(K*C, B*S')` for the batched version). +3. Flatten the kernel $W$ from `(K, C, F)` into a matrix of shape `(F, K*C)`. +4. Compute the output as $W_{\text{flat}} \cdot X_{\text{cols}}$, a single `(F, S')` matrix multiply. + +Now convolution is just a matrix multiply — fast, vectorized, and the backward pass is "just transposes." + +## Walking through the code + +### Step 1: `__init__` + +```python +def __init__(self, filters, kernel_size, strides=1, padding='valid', activation=None, + kernel_initializer='glorot_uniform', bias_initializer='zeros', **kwargs): + super().__init__(**kwargs) + self.filters = filters + self.kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size,) + self.strides = strides if isinstance(strides, (tuple, list)) else (strides,) + self.padding = padding + self.activation = get_activation(activation) + self.kernel_initializer = get_initializer(kernel_initializer) + self.bias_initializer = get_initializer(bias_initializer) +``` + +🔍 **Line `kernel_size`**: We normalize `kernel_size` to always be a tuple. If the user passes `kernel_size=3`, it becomes `(3,)`. The same is done for `strides` — this avoids guarding against `int` vs `tuple` later. + +🔍 **Line `get_activation`**: We support any activation from `neutro.activations` (e.g. `'relu'`, `'sigmoid'`). Turning the name into a callable object is handled by `get_activation()`, which returns either a function or `None` (meaning linear / no activation). + +🔍 **Line `kernel_initializer`**: The weight matrix is not created here — only the initializer function is stored. Actual allocation happens in `build()`, which needs to know the input shape. + +### Step 2: `build` + +```python +def build(self, input_shape): + _, steps, c = input_shape + k = self.kernel_size[0] + self.params['W'] = self.kernel_initializer((k, c, self.filters)) + self.params['b'] = self.bias_initializer((self.filters,)) + super().build(input_shape) +``` + +📐 **Shape**: kernel `W` is `(K, C, F)` — kernel length, input channels, output filters. Bias `b` is `(F,)` — one scalar per filter, broadcast across batch and steps. + +`super().build(input_shape)` sets `self.built = True` and stores `self.input_shape`. + +`compute_output_shape` is used by the model for `summary()`: + +```python +def compute_output_shape(self, input_shape): + batch, steps, c = input_shape + k = self.kernel_size[0] + s = self.strides[0] + padding = 0 + if self.padding == 'same': + padding = (k - 1) // 2 + out_steps = (steps + 2*padding - k) // s + 1 + return (batch, out_steps, self.filters) +``` + +The formula `(steps + 2*padding - k) // s + 1` is the standard output-length formula for 1D convolution. For `padding='valid'`, `padding=0`, so the window slides exactly `steps - k + 1` times. For `padding='same'`, we add just enough zeros on each side so the output length equals the input length (when `s=1`). + +### Step 3: `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + batch, steps, c = inputs.shape + k = self.kernel_size[0] + s = self.strides[0] + f = self.filters + + x = inputs[:, :, None, :].transpose(0, 3, 1, 2) + W = self.params['W'][:, None, :, :].transpose(3, 2, 0, 1) + b = self.params['b'].reshape(-1, 1) + + padding = 0 + if self.padding == 'same': + padding = (k - 1) // 2 + + self.x_cols = im2col_indices(x, k, 1, padding=(padding, 0), stride=(s, 1)) + res = W.reshape(f, -1) @ self.x_cols + b + + out_steps = (steps + 2*padding - k) // s + 1 + + out = res.reshape(f, out_steps, 1, batch).transpose(3, 1, 2, 0).squeeze(2) + self.z = out + + if self.activation: + return self.activation(out) + return out +``` + +🔍 **Reshape to 2D for im2col**: Conv1D input is `(B, S, C)`. im2col (from `conv_utils`) expects a 4D tensor shaped `(N, C, H, W)`. So we: + +1. `inputs[:, :, None, :]` adds a dummy height dimension: `(B, S, 1, C)`. +2. `.transpose(0, 3, 1, 2)` swaps to channels-first: `(B, C, S, 1)` — batch, channels, steps (height), width=1. + +Now `x.shape` is `(B, C, S, 1)` — a 2D image with height = steps and width = 1. + +🔍 **Reshape kernel**: `self.params['W']` is `(K, C, F)`. We add a dummy width=1 dimension: `[:, None, :, :]` → `(K, 1, C, F)`. Then `.transpose(3, 2, 0, 1)` → `(F, C, K, 1)`. This matches the format im2col expects. + +🔍 **Caching `self.x_cols`**: We save the column matrix. The backward pass needs it to compute `dW`. Caching avoids re-running im2col on the backward pass — a common pattern for memory-vs-speed tradeoff. + +📐 **Shape walkthrough**: + +- `x_cols` after `im2col_indices(x, k, 1, ...)` has shape `(K*C*1, B * out_steps * 1)` = `(K*C, B * out_steps)`. Each column is one sliding window of length `K` across all `C` channels, flattened. +- `W.reshape(f, -1)` flattens `(F, C, K, 1)` → `(F, K*C)`. +- `res = (F, K*C) @ (K*C, B*out_steps)` → `(F, B*out_steps)`. +- Add bias `b` (reshaped to `(F, 1)` — broadcast). +- `res.reshape(f, out_steps, 1, batch)` → `(F, out_steps, 1, B)`. +- `.transpose(3, 1, 2, 0)` → `(B, out_steps, 1, F)`. +- `.squeeze(2)` → `(B, out_steps, F)` — back to the original shape convention. + +🔍 **Why the bias reshape to `(-1, 1)`?** Because `res` is `(F, N)` where `N = B * out_steps`. Adding `b.reshape(-1, 1)` broadcasts `(F, 1)` across all `N` columns, which is the same as adding the bias to each position. + +### Step 4: `backward` + +```python +def backward(self, grad_output): + if self.activation: + if hasattr(self.activation, 'gradient_fast'): + grad_output = self.activation.gradient_fast(self.z, grad_output) + else: + grad_output = grad_output * self.activation.gradient(self.z) + + batch, out_steps, f = grad_output.shape + k, c, _ = self.params['W'].shape + s = self.strides[0] + + dout_4d = grad_output[:, :, None, :] + dout = dout_4d.transpose(3, 1, 2, 0).reshape(f, -1) + + self.grads['b'] = np.sum(grad_output, axis=(0, 1)) + + dW = dout @ self.x_cols.T + self.grads['W'] = dW.reshape(f, c, k, 1).transpose(2, 3, 1, 0).squeeze(1) + + W = self.params['W'][:, None, :, :].transpose(3, 2, 0, 1) + dx_cols = W.reshape(f, -1).T @ dout + + padding = 0 + if self.padding == 'same': + padding = (k - 1) // 2 + + dx = col2im_indices(dx_cols, (batch, c, self.input_shape[1], 1), k, 1, + padding=(padding, 0), stride=(s, 1)) + return dx.transpose(0, 2, 3, 1).squeeze(2) +``` + +🔍 **Activation gradient**: Before computing layer gradients, we chain through the activation function. If the activation has a `gradient_fast` method, use it (optimized path). Otherwise fall back to the standard `gradient()`. Either way, we multiply element-wise: `dL/dz = dL/dy * dz/da`. + +🔍 **Gradient w.r.t. bias**: `grad_output` is `(B, out_steps, F)`. Sum over batch and steps gives `(F,)` — the total gradient for each filter's bias. (Each filter's bias contributes to every position, so we sum all contributions.) + +📐 **dW**: `dout` is `(F, B * out_steps)`. `self.x_cols.T` is `(B * out_steps, K*C)`. + +`dW = dout @ x_cols.T` → `(F, K*C)`. + +Then `dW.reshape(f, c, k, 1).transpose(2, 3, 1, 0).squeeze(1)` reshapes back to `(K, C, F)` — the same shape as the original kernel `W`. + +The logic: `dL/dW = x_cols @ dout.T` (or equivalently `dout @ x_cols.T`). This is the matrix-multiply view of the chain rule: the gradient of a matrix multiply `W @ X` w.r.t. `W` is `grad @ X.T`. + +📐 **dX through col2im**: We compute `dx_cols = W.T @ dout` — the "transpose convolution" expressed as the transpose of the forward matrix multiply. `W.reshape(f, -1).T` is `(K*C, F)`, `dout` is `(F, N)`, so `dx_cols` is `(K*C, N)`. + +Then `col2im_indices` is the **inverse operation of im2col**: it takes each column vector in `dx_cols` and *adds* its elements back to the positions in the original 4D tensor from which they came. When the same input element contributed to multiple output positions (overlapping windows), `np.add.at` (used inside `col2im_indices`) ensures all gradient contributions are **summed**. + +The result is `(B, C, S, 1)`, which we transpose to `(B, S, 1, C)` and squeeze back to `(B, S, C)` — the original input shape. + +📐 **Gradient shapes**: +| Gradient | Shape | How computed | +|----------|-------|-------------| +| `dL/db` | `(F,)` | `sum(grad_output, axis=(0,1))` | +| `dL/dW` | `(K, C, F)` | `dout @ x_cols.T` reshaped | +| `dL/dX` | `(B, S, C)` | `col2im_indices(W.T @ dout)` | diff --git a/docs/layers/convolutional/conv2d.md b/docs/layers/convolutional/conv2d.md index 00957ed..f017949 100644 --- a/docs/layers/convolutional/conv2d.md +++ b/docs/layers/convolutional/conv2d.md @@ -1,20 +1,223 @@ -# Conv2D and the im2col Algorithm +# Conv2D -## Overview -Convolutional layers are the building blocks of CNNs. To achieve high performance in NumPy without specialized CUDA kernels, we use the `im2col` (image to column) transformation. +## What does this layer do? -## Algorithm: im2col -The `im2col` algorithm transforms a 4D input volume (Batch, Height, Width, Channels) into a 2D matrix where each column represents a receptive field (patch) from the input. +Conv2D slides a 2D filter bank across an image (or any 2D grid), detecting spatial patterns like edges, textures, and shapes. Each filter learns to respond when a particular visual pattern appears in its receptive field. -1. **Padding**: Pad the input volume if `padding='same'`. -2. **Extraction**: Slide the filter window across the input and "unroll" each 3D patch (Kernel Height $\times$ Kernel Width $\times$ Input Channels) into a single column. -3. **Matrix Multiplication**: The convolution becomes a single large matrix multiplication: - $$\text{Output} = W_{\text{flat}} \times X_{\text{col}} + b$$ -4. **Reshape**: Reshape the 2D result back to the 4D output volume. +## The math, in plain English -## Implementation Details -We use fancy indexing in NumPy to implement `im2col` efficiently in `neutro/utils/conv_utils.py`. The `col2im` operation is used during backpropagation to accumulate gradients back into the input volume shape. +$$ +Y_{h',w',f} = \sum_{c=1}^{C} \sum_{i=0}^{K_H-1} \sum_{j=0}^{K_W-1} X_{h'+i,\,w'+j,\,c} \cdot W_{i,j,c,f} + b_f +$$ -## References -- CS231n: Convolutional Neural Networks for Visual Recognition. **Convolutional Layers**. [Stanford University](https://cs231n.github.io/convolutional-networks/#conv). -- High Performance Hardware for Machine Learning. **The im2col transformation**. +- $X$: input of shape `(B, H, W, C)` — batch, height, width, channels. +- $W$: kernel of shape `(KH, KW, C, F)` — kernel height, kernel width, input channels, output filters. +- $b$: bias of shape `(F,)`. +- $Y$: output of shape `(B, OH, OW, F)` — batch, output height, output width, filters. +- $(h', w')$: output spatial position. The filter aligns with input position $(h' \cdot s_H, w' \cdot s_W)$. +- $i, j$: offsets within the kernel window. + +The formula is a **2D cross-correlation** (conventionally called convolution in deep learning): at each output position, you take a `(KH, KW)` patch of the input, multiply it element-wise by each filter in the bank, sum everything up, and add a bias. + +## Padding + +For `padding='valid'`: no padding. Output shrinks: `OH = (H - KH) // stride_H + 1`. + +For `padding='same'`: pad with zeros so the output has the same spatial size as input (when stride=1). Each side gets `(KH - 1) // 2` and `(KW - 1) // 2` zeros. + +## The im2col trick + +Convolution is expensive if you write it as nested loops (6-7 levels deep). **im2col** unrolls each sliding window into a column of a big matrix: + +1. Pad the input (if needed). +2. For every `(KH, KW)` window position, take all `C` channels, flatten into a column vector of length `KH * KW * C`. +3. Stack all columns: $X_{\text{cols}}$ is `(KH*KW*C, N)` where `N = B * OH * OW`. +4. Flatten kernel to `(F, KH*KW*C)`. +5. Convolution = one matrix multiply: `W_flat @ X_cols` → `(F, N)`, then reshape to `(B, OH, OW, F)`. + +Now convolution is a **single BLAS-level matrix multiply** — and the backward pass falls out for free via transposes. + +## data_format: `channels_last` vs `channels_first` + +```python +self.data_format = data_format +``` + +Conv2D supports both conventions: + +- `channels_last` (default): `(B, H, W, C)` — TensorFlow convention. +- `channels_first`: `(B, C, H, W)` — PyTorch convention. + +Three helper methods normalize everything to `channels_last` internally: + +```python +_shape_to_channels_last(shape) +_to_channels_last(inputs) # convert input to NHWC +_from_channels_last(outputs) # convert output back to original format +``` + +This keeps the core convolution logic (and im2col) in one format while exposing the user's preferred convention at the API boundary. + +## Walking through the code + +### Step 1: `__init__` + +```python +def __init__(self, filters, kernel_size, strides=(1, 1), padding='valid', activation=None, + kernel_initializer='glorot_uniform', bias_initializer='zeros', + data_format='channels_last', **kwargs): + super().__init__(**kwargs) + self.filters = filters + self.kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size) + self.strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) + self.padding = padding + self.activation = get_activation(activation) + self.kernel_initializer = get_initializer(kernel_initializer) + self.bias_initializer = get_initializer(bias_initializer) + if data_format not in ('channels_last', 'channels_first'): + raise ValueError("data_format must be 'channels_last' or 'channels_first'") + self.data_format = data_format +``` + +🔍 **Line `kernel_size`**: If the user passes an integer (e.g., `kernel_size=3`), we normalize to `(3, 3)`. Same for `strides`. This means the rest of the code can always unpack two values without checking types. + +🔍 **Line `data_format`**: Keras-style `data_format` argument. The validation check at line 30 catches typos early rather than producing mysterious shape errors later. + +### Step 2: `build` + +```python +def build(self, input_shape): + _, h, w, c = self._shape_to_channels_last(input_shape) + kh, kw = self.kernel_size + self.params['W'] = self.kernel_initializer((kh, kw, c, self.filters)) + self.params['b'] = self.bias_initializer((self.filters,)) + super().build(input_shape) +``` + +📐 **Shape**: kernel `W` is `(KH, KW, C, F)` — kernel height, kernel width, input channels, output filters. Bias `b` is `(F,)`. + +Note that `_shape_to_channels_last` converts the input shape to NHWC if it's in `channels_first`, so `build` always sees the canonical shape. + +`compute_output_shape` mirrors the output length formula for 2D: + +```python +oh = (h + 2*padding - kh) // sh + 1 +ow = (w + 2*padding - kw) // sw + 1 +``` + +### Step 3: `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + inputs_nhwc = self._to_channels_last(inputs) + batch, h, w, c = inputs_nhwc.shape + kh, kw, _, f = self.params['W'].shape + sh, sw = self.strides + + x = inputs_nhwc.transpose(0, 3, 1, 2) + W = self.params['W'].transpose(3, 2, 0, 1) + b = self.params['b'].reshape(-1, 1) + + padding = 0 + if self.padding == 'same': + padding = (kh - 1) // 2 + + self.x_cols = im2col_indices(x, kh, kw, padding=padding, stride=sh) + res = W.reshape(f, -1) @ self.x_cols + b + + oh = (h + 2*padding - kh) // sh + 1 + ow = (w + 2*padding - kw) // sw + 1 + + out = res.reshape(f, oh, ow, batch).transpose(3, 1, 2, 0) + self.z = out + + if self.activation: + out = self.activation(out) + return self._from_channels_last(out) +``` + +🔍 **Convert to NHWC**: `_to_channels_last` ensures the forward pass always works with `(B, H, W, C)` regardless of the user's `data_format`. If already NHWC, it's a no-op. + +🔍 **Prepare for im2col**: im2col expects `(N, C, H, W)` — standard PyTorch format. So we transpose `(B, H, W, C)` → `(B, C, H, W)`. + +The kernel `W` is originally `(KH, KW, C, F)`. We transpose to `(F, C, KH, KW)` to match the im2col channel ordering. After reshaping to `(F, -1)`, each row is one filter flattened across all `KH * KW * C` elements. + +🔍 **The big multiply**: `W.reshape(f, -1)` → `(F, KH*KW*C)`. `self.x_cols` from `im2col_indices` → `(KH*KW*C, N)` where `N = B * OH * OW`. + +`res = (F, KH*KW*C) @ (KH*KW*C, N)` → `(F, N)`. + +📐 **Shape walkthrough** for the output reshape: + +- `res` is `(F, N)` where `N = B * OH * OW`. +- `.reshape(f, oh, ow, batch)` → `(F, OH, OW, B)`. +- `.transpose(3, 1, 2, 0)` → `(B, OH, OW, F)` — the final NHWC output. + +🔍 **Caching `self.x_cols`**: Just like in Conv1D, we save the column matrix because the backward pass needs `x_cols` to compute `dW`. This is the time-vs-memory tradeoff typical in training loops: pay memory cost per layer to avoid recomputing im2col on every backward pass. + +### Step 4: `backward` + +```python +def backward(self, grad_output): + grad_output_nhwc = self._to_channels_last(grad_output) + if self.activation: + if hasattr(self.activation, 'gradient_fast'): + grad_output_nhwc = self.activation.gradient_fast(self.z, grad_output_nhwc) + else: + grad_output_nhwc = grad_output_nhwc * self.activation.gradient(self.z) + + batch, oh, ow, f = grad_output_nhwc.shape + kh, kw, c, _ = self.params['W'].shape + sh, sw = self.strides + + dout = grad_output_nhwc.transpose(3, 1, 2, 0).reshape(f, -1) + + self.grads['b'] = np.sum(grad_output_nhwc, axis=(0, 1, 2)) + + dW = dout @ self.x_cols.T + self.grads['W'] = dW.reshape(f, c, kh, kw).transpose(2, 3, 1, 0) + + W = self.params['W'].transpose(3, 2, 0, 1) + dx_cols = W.reshape(f, -1).T @ dout + + padding = 0 + if self.padding == 'same': + padding = (kh - 1) // 2 + + _, h, w, _ = self._shape_to_channels_last(self.input_shape) + dx = col2im_indices(dx_cols, (batch, c, h, w), kh, kw, padding=padding, stride=sh) + return self._from_channels_last(dx.transpose(0, 2, 3, 1)) +``` + +🔍 **Convert grad_output**: The incoming gradient might be `channels_first` — we normalize it to NHWC first so the shapes align with the cached `x_cols`. + +🔍 **Activation chain rule**: Same pattern as Conv1D — `gradient_fast` if available, otherwise element-wise multiply by the activation gradient. This computes `dL/dz` from `dL/dy`. + +🔍 **dW**: `dout` is `(F, N)` (flattened NHWC grad). `self.x_cols.T` is `(N, KH*KW*C)`. + +`dW = dout @ x_cols.T` → `(F, KH*KW*C)`. + +`.reshape(f, c, kh, kw)` → `(F, C, KH, KW)`. `.transpose(2, 3, 1, 0)` → `(KH, KW, C, F)` — back to the original kernel shape. + +This is the same as Conv1D: the gradient of `W @ X` w.r.t. `W` is `grad @ X^T`, and then we reshape back to the parameter's native shape. + +🔍 **dX via col2im**: We compute the gradient through the matrix multiply: `dx_cols = W^T @ dout` where `W.reshape(f, -1).T` is `(KH*KW*C, F)` and `dout` is `(F, N)`. So `dx_cols` is `(KH*KW*C, N)`. + +Now we call `col2im_indices` — the **reverse** of `im2col_indices`: + +``` +dx = col2im_indices(dx_cols, (batch, c, h, w), kh, kw, padding=padding, stride=sh) +``` + +🔍 **How col2im works**: `im2col_indices` extracted overlapping patches from the input and laid them out as columns of a matrix. `col2im_indices` does the reverse: it takes each column, *scatters* its elements back to their original `(N, C, H, W)` positions, and when multiple patches overlap at the same position, **sums** the overlapping gradient contributions (using `np.add.at`). + +This is exactly the adjoint / transpose of the im2col operation — the gradient of an unrolling operation is the "re-rolling" operation that adds contributions back. + +📐 **Final output**: `dx` is `(B, C, H, W)` (NCHW). We transpose to `(B, H, W, C)` (NHWC) and then convert back to the user's `data_format` via `_from_channels_last`. + +📐 **Gradient shapes**: +| Gradient | Shape | How computed | +|----------|-------|-------------| +| `dL/db` | `(F,)` | `sum(grad_output, axis=(0,1,2))` | +| `dL/dW` | `(KH, KW, C, F)` | `dout @ x_cols.T` reshaped | +| `dL/dX` | `(B, H, W, C)` (or NCHW) | `col2im_indices(W.T @ dout)` | diff --git a/docs/layers/core/activation.md b/docs/layers/core/activation.md new file mode 100644 index 0000000..4c9f340 --- /dev/null +++ b/docs/layers/core/activation.md @@ -0,0 +1,158 @@ +# Activation Layer + +## What does this layer do? + +An Activation layer applies a non-linear function to its input. Without activation functions, a neural network would just be a series of linear transformations — no matter how many layers you stack, you could collapse them into a single matrix multiply. Activation functions introduce the non-linearity that gives neural networks their expressive power. + +## The math + +The Activation layer is a wrapper: it takes a string like `'relu'` and calls the corresponding function: + +$$y = \phi(x)$$ + +Where: +- $x$ is the input (any shape) +- $\phi$ is an element-wise activation function (for most activations) +- $y$ has the same shape as $x$ + +### Backward pass (chain rule) + +$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \odot \phi'(x)$$ + +For element-wise activations (ReLU, sigmoid, tanh), this is an element-wise product: each output gradient is multiplied by the derivative of the activation at the corresponding input position. + +For **softmax**, the Jacobian is a full matrix (not diagonal): changing one input affects ALL outputs. So `gradient_fast` computes the full Jacobian-vector product. + +## Walking through the code + +### File: `neutro/layers/core/activation.py` + +### Step 1: `__init__` — mapping strings to functions + +```python +class Activation(Layer): + def __init__(self, activation, **kwargs): + super().__init__(**kwargs) + self.activation = get_activation(activation) +``` + +🔍 **Line 7**: `get_activation('relu')` looks up the string in `neutro/activations/__init__.py` and returns an instance like `ReLU()`. This object has both a `__call__` (forward) and a `gradient` (backward) method. + +You don't specify the input size here — like all `neutro` layers, shape is inferred when `forward` is first called. + +### Step 2: no `build` needed + +The Activation layer has **no learnable parameters**. It just applies a fixed function. So `build` is never overridden — it just marks `self.built = True`. + +### Step 3: `forward` — applying the activation + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + return self.activation(inputs) +``` + +📐 **Shape**: Input shape = output shape. If inputs is `(B, D)`, output is `(B, D)`. + +🔍 **Line 10**: `self.inputs = inputs` — We cache the input here because `backward` needs it to compute the gradient. For example, `ReLU.gradient(x)` returns `(x > 0).astype(float)`. + +🔍 **Line 11**: `self.activation(inputs)` — This calls the activation's `__call__` method. For ReLU, it's `np.maximum(0, x)`. For softmax, it's the full softmax computation. Some activations (like softmax and sigmoid) also cache their output on `self` inside their `__call__`. + +### Step 4: `backward` — two paths for gradient computation + +```python +def backward(self, grad_output): + if hasattr(self.activation, 'gradient_fast'): + return self.activation.gradient_fast(self.inputs, grad_output) + return grad_output * self.activation.gradient(self.inputs) +``` + +🔍 **Line 14**: `hasattr(self.activation, 'gradient_fast')` — Checks if this activation has a special fast gradient path. Currently only **softmax** has this. + +**Path 1 — Element-wise activations (ReLU, sigmoid, tanh)**: + +`grad_output * self.activation.gradient(self.inputs)` + +📐 If `grad_output` is `(B, D)` and `gradient(self.inputs)` is `(B, D)`, then `(B, D) * (B, D)` → `(B, D)`. Broadcasting handles extra dimensions automatically. + +- **ReLU**: `gradient(x) = (x > 0)`. Gradient of 1 for positive inputs, 0 for negative. "Dead ReLU" happens when many inputs are negative and the gradient is zero. +- **Sigmoid**: `gradient(x) = sigmoid(x) * (1 - sigmoid(x))`. Maximum gradient is 0.25 (at x=0), so deep sigmoid networks suffer from vanishing gradients. +- **Tanh**: `gradient(x) = 1 - tanh(x)^2`. Maximum gradient is 1.0 (at x=0), better than sigmoid. + +**Path 2 — Softmax**: + +`self.activation.gradient_fast(self.inputs, grad_output)` + +Why is softmax special? For element-wise activations, output $y_i$ depends only on input $x_i$. But for softmax: + +$$y_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$ + +The derivative is: + +$$\frac{\partial y_i}{\partial x_j} = y_i (\delta_{ij} - y_j)$$ + +This is a full $C \times C$ Jacobian matrix per sample (where $C$ is the number of classes). `gradient_fast` computes the Jacobian-vector product efficiently: + +```python +# For each sample: +s = out_flat[i].reshape(-1, 1) +jacobian = np.diagflat(s) - np.dot(s, s.T) # (C, C) +res[i] = np.dot(grad_flat[i], jacobian) # (C,) @ (C, C) -> (C,) +``` + +🧠 **Key insight**: `gradient_fast` exists because softmax is usually paired with cross-entropy loss. Their combined gradient is much simpler: just $(y - \text{target})$. But since the `Activation` layer doesn't know about the loss, it computes the full softmax Jacobian separately. In practice, you'd combine them for efficiency. + +### Step 5: Convenience subclasses + +```python +class ReLU(Activation): + def __init__(self, **kwargs): + super().__init__('relu', **kwargs) + +class Softmax(Activation): + def __init__(self, **kwargs): + super().__init__('softmax', **kwargs) + +class Sigmoid(Activation): + def __init__(self, **kwargs): + super().__init__('sigmoid', **kwargs) + +class Tanh(Activation): + def __init__(self, **kwargs): + super().__init__('tanh', **kwargs) +``` + +These are just shortcuts so you can write `ReLU()` instead of `Activation('relu')`. They're functionally identical. + +## Activation function reference + +| Activation | Formula | Derivative | Gradient max | Notes | +|---|---|---|---|---| +| **ReLU** | $\max(0, x)$ | $1_{x > 0}$ | 1.0 | Can "die" (always zero) | +| **Sigmoid** | $1/(1+e^{-x})$ | $\sigma(x)(1-\sigma(x))$ | 0.25 | Vanishing gradient | +| **Tanh** | $\tanh(x)$ | $1-\tanh^2(x)$ | 1.0 | Zero-centered | +| **Softmax** | $e^{x_i}/\sum e^{x_j}$ | Full Jacobian | — | For classification | + +## Try it yourself + +```python +from neutro.layers import Activation, ReLU, Sigmoid +import numpy as np + +# Using the generic Activation layer +act = Activation('relu') +x = np.array([-2, -1, 0, 1, 2]) +y = act(x) # [0, 0, 0, 1, 2] +grad = np.array([1, 1, 1, 1, 1]) +dx = act.backward(grad) # [0, 0, 0, 1, 1] — gradient flows only for positive inputs + +# Using convenience subclass +relu = ReLU() +y2 = relu(x) # Same result +``` + +## What to read next + +- `docs/layers/core/dense.md` — Dense layers are typically paired with activations +- `docs/activations/activations.md` — The activation function implementations themselves +- `docs/activations/softmax.md` — Deep dive into softmax and its Jacobian diff --git a/docs/layers/core/core_utility_layers.md b/docs/layers/core/core_utility_layers.md index 8ccf648..44065f3 100644 --- a/docs/layers/core/core_utility_layers.md +++ b/docs/layers/core/core_utility_layers.md @@ -2,90 +2,436 @@ ## Dropout — `neutro/layers/core/dropout.py` -Randomly sets a fraction of inputs to zero during training, preventing co-adaptation: +### What does this layer do? -$$y = \begin{cases} \frac{m \odot x}{1 - p} & \text{training} \\ x & \text{inference} \end{cases}$$ +Dropout randomly "drops" (sets to zero) a fraction of neurons during training. This forces the network to not rely too heavily on any single neuron, which prevents **co-adaptation** and acts as a regularization technique. During inference, all neurons are used at full strength. -Where $m_i \sim \text{Bernoulli}(1-p)$ is a mask. The scaling by $1/(1-p)$ keeps the expected output magnitude constant. +### The math, in plain English + +$$ +y = \begin{cases} \frac{m \odot x}{1 - p} & \text{during training} \\ x & \text{during inference} \end{cases} +$$ + +Each element of the mask $m$ is drawn from a Bernoulli distribution with probability $1-p$ (i.e., it is 1 with probability $1-p$, and 0 with probability $p$). The symbol $\odot$ means element-wise multiplication. + +**Why do we divide by $1-p$?** — If a neuron is kept with probability $1-p$, its expected value during training is $x \cdot (1-p)$. By dividing by $1-p$, we restore the expected magnitude to $x$, so the training and inference outputs have the same scale. Without this scaling, the network would see much larger values during inference and produce wrong results. + +### Walking through the code + +#### `__init__` + +```python +def __init__(self, rate, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self.mask = None +``` + +🔍 **Line `self.mask = None`**: The mask is created during `forward` and cached on `self` so that `backward` can reuse it. Before the first forward pass, it starts as `None`. + +#### `forward` ```python def forward(self, inputs, training=False): - if not training: + if not training or self.rate == 0: return inputs - self.mask = np.random.binomial(1, 1 - self.rate, size=inputs.shape) - return inputs * self.mask / (1 - self.rate) + self.mask = np.random.binomial(1, 1 - self.rate, size=inputs.shape) / (1 - self.rate) + return inputs * self.mask +``` + +🔍 **Line `if not training or self.rate == 0`**: Dropout only applies during training. At inference time, we return the input unchanged. If `rate=0`, nothing is dropped, so we skip the overhead. + +🔍 **Line `self.mask = np.random.binomial(1, 1 - self.rate, size=inputs.shape) / (1 - self.rate)`**: This is the core. `np.random.binomial(1, 1-self.rate, ...)` creates a binary mask where each element is 1 with probability $1-p$. By dividing by $1-p$ right here in the mask, the subsequent multiplication `inputs * self.mask` naturally produces the scaled output. The mask is cached for the backward pass. + +📐 **Shape**: `mask.shape == inputs.shape` — every element gets its own independent mask value. +🔍 **Line `return inputs * self.mask`**: Element-wise multiplication applies the mask. + +#### `backward` + +```python def backward(self, grad_output): - return grad_output * self.mask / (1 - self.rate) + if self.mask is None: + return grad_output + return grad_output * self.mask ``` +🔍 **Line `grad_output * self.mask`**: The derivative of $y = m \odot x$ with respect to $x$ is $m$ (the mask). So the gradient is simply multiplied by the same mask used in the forward pass. Note the mask already includes the $1/(1-p)$ scaling factor. + +🔍 **Line `if self.mask is None`**: If we never called forward (or called it with `training=False`), there's no mask. In that case, the gradient passes through unchanged — just like the forward pass. + +--- + ## Flatten — `neutro/layers/core/flatten.py` -Reshapes a multi-dimensional input into a 2D (batch, features) tensor, preserving the batch dimension: +### What does this layer do? + +Flatten reshapes a multi-dimensional input (e.g., a batch of images) into a 2D array where each sample becomes a single 1D vector. This is typically used between convolutional layers and dense layers. + +### The math, in plain English + +$$ +\text{Input shape: } (N, d_1, d_2, \dots, d_k) \quad \Longrightarrow \quad \text{Output shape: } (N, d_1 \times d_2 \times \dots \times d_k) +$$ + +The batch dimension $N$ is preserved. All other dimensions are multiplied together to form a single feature dimension. + +### Walking through the code + +#### `__init__` ```python -def forward(self, inputs): +def __init__(self, **kwargs): + super().__init__(**kwargs) +``` + +No special parameters — flattening is purely a shape transformation. + +#### `build` + +```python +def build(self, input_shape): + self.input_shape_orig = input_shape + super().build(input_shape) +``` + +🔍 **Line `self.input_shape_orig = input_shape`**: We save the original input shape so that `backward` knows how to "un-flatten" the gradient. + +#### `compute_output_shape` + +```python +def compute_output_shape(self, input_shape): + import numpy as np + return (input_shape[0], int(np.prod(input_shape[1:]))) +``` + +📐 `input_shape[0]` is the batch dimension (kept as-is). `int(np.prod(input_shape[1:]))` multiplies all remaining dimensions together. For example, `(None, 28, 28, 1)` → `(None, 784)`. + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.input_shape_orig = inputs.shape return inputs.reshape(inputs.shape[0], -1) +``` +🔍 **Line `self.input_shape_orig = inputs.shape`**: We cache the **actual** shape from the real data (not the symbolic shape from `build`). This is critical because `build` receives symbolic shapes (e.g., `(None, 28, 28)`), but we need the concrete batch size. + +📐 **Shape**: `inputs.shape` = `(batch, d1, d2, ..., dk)`. `inputs.shape[0]` is the batch size. `-1` tells NumPy to infer the remaining dimension as the product of all other dimensions. So `(8, 4, 4, 16)` → `(8, 256)`. + +#### `backward` + +```python def backward(self, grad_output): - return grad_output.reshape(self.input_shape) + return grad_output.reshape(self.input_shape_orig) +``` + +🔍 **Line `grad_output.reshape(self.input_shape_orig)`**: The backward pass simply reverses the flattening. The gradient is reshaped back to the original multi-dimensional shape. This works because `reshape` is a purely geometric operation — the gradient flows through each element along the path of the reshape. + +--- + +## MoE Layer (Mixture of Experts) — `neutro/layers/core/moe.py` + +### What does this layer do? + +A Mixture-of-Experts (MoE) layer maintains multiple "expert" sub-networks and a **router** that decides which expert(s) should process each input token. Instead of every token passing through every expert (expensive), each token only activates the top-$k$ most relevant experts. This scales model capacity without proportionally scaling compute. + +### The math, in plain English + +$$ +y_t = \sum_{i=1}^{E} g_i(x_t) \cdot E_i(x_t) +$$ + +For a single token $x_t$: +- $E_i(x)$ is the output of expert $i$ (a small MLP). +- $g_i(x)$ is the router's **gating weight** for expert $i$ — but it is zero for all experts not in the top-$k$. +- The final output is a weighted sum of only the chosen experts' outputs. + +The router is just a linear layer followed by softmax: + +$$ +g(x) = \text{softmax}(x W_{\text{router}}) +$$ + +### Walking through the code + +#### `__init__` + +```python +def __init__(self, num_experts, top_k, expert_units, **kwargs): + super().__init__(**kwargs) + self.num_experts = num_experts + self.top_k = top_k + self.expert_units = expert_units + self.experts = [] +``` + +🔍 **`num_experts`**: The total number of expert MLPs in the pool. + +🔍 **`top_k`**: How many experts are activated per token. If `top_k=2`, each token is processed by only 2 out of `num_experts` experts. + +🔍 **`expert_units`**: The hidden dimension inside each expert's MLP. + +🔍 **`self.experts = []`**: Will be populated in `build` with pairs of Dense layers (one per expert). + +#### `build` + +```python +def build(self, input_shape): + self.input_dim = input_shape[-1] + + self.params['router_weight'] = np.random.normal(0, 0.02, (self.input_dim, self.num_experts)) + + for i in range(self.num_experts): + e1 = Dense(self.expert_units, activation='relu') + e1.build(input_shape) + + expert_shape = list(input_shape) + expert_shape[-1] = self.expert_units + + e2 = Dense(self.input_dim) + e2.build(tuple(expert_shape)) + + self.experts.append([e1, e2]) + + super().build(input_shape) ``` -## MoE Layer — `neutro/layers/core/moe.py` +🔍 **Line `self.params['router_weight']`**: The router is a weight matrix of shape `(input_dim, num_experts)`. It maps each input token to a score for each expert. -### Theory +📐 **Shape**: `router_weight` = `(input_dim, num_experts)` — one column per expert. -Mixture-of-Experts (MoE) scales model capacity without proportional compute. A router network selects which "expert" sub-networks to activate for each input token: +🔍 **Each expert is two Dense layers**: `e1` projects from `input_dim` → `expert_units` (with ReLU), and `e2` projects back from `expert_units` → `input_dim`. This is a bottle-neck MLP. -$$y = \sum_{i=1}^E g_i(x) \cdot E_i(x)$$ +🔍 **`super().build(input_shape)`**: Marks the layer as `built = True`. -Where $g_i(x)$ is the router's gating weight (typically top-$k$ sparse) and $E_i$ are expert feed-forward networks. +#### `compute_output_shape` -### Router — `neutro/layers/core/moe.py:30` +```python +def compute_output_shape(self, input_shape): + return input_shape +``` + +The output has the same shape as the input — each token's representation is transformed but its dimensionality stays the same. + +#### `forward` ```python -def forward(self, x): - logits = np.dot(x, self.params['W']) # (batch, seq, num_experts) - weights = softmax(logits, axis=-1) - # Top-k routing - top_k_weights, top_k_indices = ... +def forward(self, x, training=False): + self.x_shape = x.shape + self.x_flat = x.reshape(-1, self.input_dim) + num_tokens = self.x_flat.shape[0] + + # 1. Routing scores + router_logits = self.x_flat @ self.params['router_weight'] + + # Softmax to get probabilities + router_probs = np.exp(router_logits - np.max(router_logits, axis=-1, keepdims=True)) + router_probs /= np.sum(router_probs, axis=-1, keepdims=True) + self.router_probs = router_probs + + # 2. Select top-k experts + top_k_indices = np.argsort(router_probs, axis=-1)[:, -self.top_k:] + self.top_k_indices = top_k_indices + + # 3. Dispatch to experts and combine results + final_output = np.zeros_like(self.x_flat) + self.expert_outputs = {} + + for expert_idx in range(self.num_experts): + token_indices, _ = np.where(top_k_indices == expert_idx) + if len(token_indices) == 0: + continue + + tokens = self.x_flat[token_indices] + + out = tokens + for layer in self.experts[expert_idx]: + out = layer(out, training=training) + + self.expert_outputs[expert_idx] = (token_indices, out) + + weights = router_probs[token_indices, expert_idx].reshape(-1, 1) + final_output[token_indices] += weights * out + + return final_output.reshape(self.x_shape) ``` -The router learns to assign tokens to the most relevant experts. +🔍 **Line `self.x_flat = x.reshape(-1, self.input_dim)`**: Flatten the input to 2D: `(batch * seq_len, input_dim)` if the input is 3D, or `(batch, input_dim)` if already 2D. This lets us process each token independently. + +📐 **Shape**: `x` = `(batch, seq_len, dim)` → `x_flat` = `(batch * seq_len, dim)`. + +🔍 **Lines `router_logits = self.x_flat @ self.params['router_weight']`**: A simple matrix multiplication. Each token gets `num_experts` scores (logits). + +📐 **Shape**: `x_flat` is `(T, input_dim)`, `router_weight` is `(input_dim, num_experts)`. Result: `(T, num_experts)`. + +🔍 **Lines `router_probs = np.exp(router_logits - ...) ... /= np.sum(...)`**: A numerically stable softmax. We subtract the max logit before exponentiating to prevent overflow. + +📐 **Shape**: `router_probs` = `(T, num_experts)` — probabilities summing to 1 per token. + +🔍 **Line `top_k_indices = np.argsort(router_probs, axis=-1)[:, -self.top_k:]`**: Sort the probabilities and take the last `top_k` indices (highest scores). + +📐 **Shape**: `top_k_indices` = `(T, top_k)` — for each token, which experts to use. + +🔍 **Line `token_indices, _ = np.where(top_k_indices == expert_idx)`**: For each expert, find which tokens selected it in their top-k. `np.where` returns the row (token) indices where this expert appears. + +🔍 **Line `out = tokens; for layer in self.experts[expert_idx]: out = layer(out, training=training)`**: Pass the selected tokens through the expert's MLP (ReLU then output projection). + +🔍 **Line `weights = router_probs[token_indices, expert_idx].reshape(-1, 1)`**: Look up the router probability for this expert for each selected token. + +🔍 **Line `final_output[token_indices] += weights * out`**: Weight the expert's output by the router probability and add to the final result. Tokens not assigned to this expert get no contribution. + +📐 **Shape**: `weights` is `(num_selected, 1)`, `out` is `(num_selected, input_dim)`. Broadcasting multiplies each row of `out` by its weight. + +#### `backward` + +```python +def backward(self, grad_output): + grad_flat = grad_output.reshape(-1, self.input_dim) + num_tokens = self.x_flat.shape[0] + + dx_flat = np.zeros_like(self.x_flat) + drouter_logits = np.zeros_like(self.router_probs) + + # 1. Backprop through experts and router probabilities + for expert_idx, (token_indices, out) in self.expert_outputs.items(): + weights = self.router_probs[token_indices, expert_idx].reshape(-1, 1) + expert_grad = weights * grad_flat[token_indices] + + drouter_probs_expert = np.sum(grad_flat[token_indices] * out, axis=-1) + drouter_logits[token_indices, expert_idx] = drouter_probs_expert + + curr_grad = expert_grad + for layer in reversed(self.experts[expert_idx]): + curr_grad = layer.backward(curr_grad) + + dx_flat[token_indices] += curr_grad + + # 2. Backprop through Softmax for router + drouter_logits = self.router_probs * ( + drouter_logits - np.sum(self.router_probs * drouter_logits, axis=-1, keepdims=True) + ) + + # 3. Router weight gradient + self.grads['router_weight'] = self.x_flat.T @ drouter_logits + + # 4. Add router's contribution to dx + dx_flat += drouter_logits @ self.params['router_weight'].T + + return dx_flat.reshape(self.x_shape) +``` + +🔍 **Line `expert_grad = weights * grad_flat[token_indices]`**: The gradient through the weighted combination $y = w \cdot E(x)$. Since $dy/dE = w$, we multiply the upstream gradient by the router weight. + +🔍 **Line `drouter_probs_expert = np.sum(grad_flat[token_indices] * out, axis=-1)`**: For each token assigned to this expert, the gradient w.r.t. the router probability weight $w$ is: $dL/dw = dL/dy \cdot E(x)$. We sum across the feature dimension because each token has a single scalar weight per expert. + +🔍 **Line `for layer in reversed(self.experts[expert_idx]): curr_grad = layer.backward(curr_grad)`**: Backprop through the expert's MLP layers in reverse order. Each Dense layer computes gradients for its own parameters and passes gradients backward. + +🔍 **Line `drouter_logits = self.router_probs * (drouter_logits - np.sum(...))`**: Standard softmax backward pass. For softmax $p_i = e^{z_i} / \sum e^{z_j}$, the Jacobian is $\partial p_i / \partial z_j = p_i (\delta_{ij} - p_j)$. In vectorized form: $dL/dz = p \odot (dL/dp - \sum(p \odot dL/dp))$. + +🔍 **Line `self.grads['router_weight'] = self.x_flat.T @ drouter_logits`**: Gradient for the router weight matrix: $dL/dW = x^T \cdot dL/dz$ (outer product). + +🔍 **Line `dx_flat += drouter_logits @ self.params['router_weight'].T`**: The router also contributes to the input gradient. $x$ flows into both the router (through the linear layer) and the experts. The total gradient is the sum of both contributions. + +--- ## Reparameterization — `neutro/layers/core/reparameterization.py` -Implements the reparameterization trick used in VAEs. A sample from $N(\mu, \sigma^2)$ is: +### What does this layer do? + +This layer implements the **reparameterization trick** used in Variational Autoencoders (VAEs). It takes two tensors — a mean $\mu$ and a log-variance $\log \sigma^2$ — and produces a sample from the corresponding Gaussian distribution. The trick is that the sampling operation is rewritten so that gradients can flow through it. + +### The math, in plain English -$$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim N(0, I)$$ +The standard sampling from a Gaussian is $z \sim \mathcal{N}(\mu, \sigma^2)$. But sampling is a stochastic operation with no gradient. The reparameterization trick rewrites it as: -This makes the sampling operation differentiable, enabling backpropagation through the stochastic layer. +$$ +z = \mu + \sigma \odot \epsilon \quad \text{where} \quad \epsilon \sim \mathcal{N}(0, I) +$$ + +Here $\sigma = \exp(0.5 \cdot \log \sigma^2) = \sqrt{\sigma^2}$, and $\epsilon$ is a random noise vector drawn from the standard normal. The sampling is now **deterministic given $\epsilon$**, so gradients can flow backward through $\mu$ and $\sigma$ (and to $\log \sigma^2$). + +### Walking through the code + +#### `__init__` ```python -def forward(self, inputs): - mu, log_var = inputs - eps = np.random.randn(*mu.shape) - return mu + np.exp(0.5 * log_var) * eps +def __init__(self, **kwargs): + super().__init__(**kwargs) ``` -## Usage Example +No special parameters — the layer just needs to know the shape of the latent space, which comes from the input. + +#### `compute_output_shape` + +```python +def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape +``` + +The output shape is the same as the mean tensor's shape (the first input). The second input (log variance) has the same shape. + +#### `forward` ```python -from neutro.layers import Dropout, Flatten, MoELayer +def forward(self, inputs, training=False): + self.z_mean, self.z_log_var = inputs + + if not training: + return self.z_mean + + self.epsilon = np.random.normal(size=self.z_mean.shape) + self.z = self.z_mean + np.exp(0.5 * self.z_log_var) * self.epsilon + return self.z +``` + +🔍 **Line `self.z_mean, self.z_log_var = inputs`**: The layer receives **two tensors** as a list: `[z_mean, z_log_var]`. They must have the same shape. + +🔍 **Line `if not training: return self.z_mean`**: During inference, we don't sample — we return the mean directly. This gives a deterministic output. + +🔍 **Line `self.epsilon = np.random.normal(size=self.z_mean.shape)`**: Draw random noise from the standard normal distribution $\mathcal{N}(0, 1)$. This is cached for the backward pass. + +📐 **Shape**: `epsilon.shape == z_mean.shape` — one noise value per element. + +🔍 **Line `self.z = self.z_mean + np.exp(0.5 * self.z_log_var) * self.epsilon`**: The reparameterization formula: +- `np.exp(0.5 * self.z_log_var)` computes $\sigma = \sqrt{\sigma^2}$. +- Multiply by $\epsilon$ gives the stochastic part. +- Add $\mu$ to center it. -drop = Dropout(rate=0.5) -x = np.random.randn(8, 64) -y = drop(x, training=True) # 50% of units dropped +📐 **Shape**: All tensors have the same shape, e.g., `(batch, latent_dim)`. -flat = Flatten() -x = np.random.randn(8, 4, 4, 16) -y = flat(x) # (8, 256) +🔍 **Line `self.z` is cached** so the backward pass doesn't need to recompute it (though backward uses `z_log_var` and `epsilon`, not `z` directly). -moe = MoELayer(num_experts=8, expert_dim=512, top_k=2) -x = np.random.randn(2, 16, 512) -y = moe(x) +#### `backward` + +```python +def backward(self, grad_output): + grad_mean = grad_output + grad_log_var = grad_output * np.exp(0.5 * self.z_log_var) * 0.5 * self.epsilon + + return [grad_mean, grad_log_var] ``` +🔍 **Line `grad_mean = grad_output`**: The derivative of $z$ with respect to $\mu$ is 1 (from $z = \mu + \sigma \epsilon$). So the gradient passes through unchanged. + +🔍 **Line `grad_log_var = grad_output * np.exp(0.5 * self.z_log_var) * 0.5 * self.epsilon`**: Chain rule in action: + +$$ +\frac{\partial z}{\partial (\log \sigma^2)} = \frac{\partial}{\partial (\log \sigma^2)} \left( \exp(0.5 \cdot \log \sigma^2) \cdot \epsilon \right) = \exp(0.5 \cdot \log \sigma^2) \cdot 0.5 \cdot \epsilon +$$ + +Breaking it down: +1. Let $a = 0.5 \cdot \log \sigma^2$. +2. Let $b = \exp(a)$ (which is $\sigma$). +3. $z = \mu + b \cdot \epsilon$. +4. $db/da = \exp(a) = b$. +5. $da/d(\log \sigma^2) = 0.5$. +6. So $dz/d(\log \sigma^2) = b \cdot 0.5 \cdot \epsilon = \exp(0.5 \cdot \log \sigma^2) \cdot 0.5 \cdot \epsilon$. + +🔍 **The function returns `[grad_mean, grad_log_var]`**: Since the forward received a list of two tensors, the backward must return a list of two gradients — one for each input. + ## References - Srivastava, N., et al. (2014). **Dropout: A Simple Way to Prevent Neural Networks from Overfitting**. *JMLR*. diff --git a/docs/layers/core/dense.md b/docs/layers/core/dense.md index 412dc4b..e4de5a4 100644 --- a/docs/layers/core/dense.md +++ b/docs/layers/core/dense.md @@ -1,21 +1,28 @@ # Dense Layer -## Theory +## What does this layer do? -A Dense (fully-connected) layer computes a linear transformation followed by an optional activation: +A Dense (or "fully-connected") layer connects every input neuron to every output neuron. You give it a vector (or a batch of vectors), it multiplies by a weight matrix, adds a bias, and optionally runs an activation function like ReLU or sigmoid. -$$y = \phi(xW + b)$$ +Think of it as the "basic building block" of neural networks — most models start and end with one or more Dense layers. -Where: -- $x \in \mathbb{R}^{B \times D}$ is the input (batch $B$, input dimension $D$) -- $W \in \mathbb{R}^{D \times U}$ is the weight matrix (learned) -- $b \in \mathbb{R}^{U}$ is the bias vector (learned) -- $\phi$ is an element-wise activation function (ReLU, sigmoid, tanh, or none) -- $y \in \mathbb{R}^{B \times U}$ is the output +## The math, in plain English -### Backward Pass +$$y = \phi(x W + b)$$ -The gradients are: +Let's unpack every symbol: + +- **$x$** — Your input. Shape `(batch_size, input_dim)`. Each row is one data point (e.g., a 128-dimensional feature vector for one image). +- **$W$** — The weight matrix. Shape `(input_dim, units)`. Every entry $W_{ij}$ controls how much input neuron $i$ contributes to output neuron $j$. These are **learned** during training. +- **$b$** — The bias vector. Shape `(units,)`. An offset added to each output neuron. Also **learned**. +- **$xW$** — Matrix multiply: `(batch, input_dim) @ (input_dim, units)` → `(batch, units)`. This is a **linear transformation** — it rotates and scales the input space. +- **$xW + b$** — The bias is *broadcast* across the batch (added to every row). This gives each output neuron a baseline firing threshold. +- **$\phi$** — An activation function applied element-wise (to each number independently). ReLU turns negatives to zero, sigmoid squashes values between 0 and 1, etc. This is where **non-linearity** comes from — without it, stacking Dense layers would be the same as one big linear transformation. +- **$y$** — Your output. Shape `(batch, units)`. Each row is a `units`-dimensional transformed representation of the input. + +### How gradients flow backward + +During training, we need to adjust $W$ and $b$ to reduce the loss. The gradient formulas are: $$\frac{\partial L}{\partial W} = x^T \cdot \frac{\partial L}{\partial y}$$ @@ -23,25 +30,45 @@ $$\frac{\partial L}{\partial b} = \sum_{\text{batch}} \frac{\partial L}{\partial $$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot W^T$$ -If an activation function $\phi$ is present, the gradient is first passed through $\phi'$ before these equations. +Here $\frac{\partial L}{\partial y}$ is the gradient coming *from the next layer* (the "upstream gradient"). The three formulas tell us: + +1. **Weight gradient**: Transpose the input and multiply by the upstream gradient. Shape: `(input_dim, batch) @ (batch, units)` → `(input_dim, units)` — exactly the same shape as $W$. +2. **Bias gradient**: Sum the upstream gradient over the batch dimension. The bias is added to every sample, so its gradient is the sum of all per-sample gradients. +3. **Input gradient**: Multiply the upstream gradient by $W^T$. This gets passed to the previous layer so it can compute *its* weight gradients. + +If an activation $\phi$ is present, the chain rule says we must first multiply the upstream gradient by $\phi'(z)$ (the derivative of the activation at $z = xW + b$): -## Implementation Guide +$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial y} \odot \phi'(z)$$ -### File: `neutro/layers/core/dense.py` +where $\odot$ is element-wise multiplication, and then use $\frac{\partial L}{\partial z}$ in place of $\frac{\partial L}{\partial y}$ in the three formulas above. -### `__init__` — line 7 +## Walking through the code + +### Step 1: `__init__` — setting the stage ```python class Dense(Layer): def __init__(self, units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', **kwargs): + super().__init__(**kwargs) + self.units = units + self.activation = get_activation(activation) + self.use_bias = use_bias + self.kernel_initializer = get_initializer(kernel_initializer) + self.bias_initializer = get_initializer(bias_initializer) ``` -- `units`: number of output neurons. -- `activation`: a string like `'relu'` → mapped to an activation function via `get_activation()`. -- Weight initialization is deferred to `build()`. +🔍 **Line 7**: `super().__init__(**kwargs)` — Calls `Layer.__init__`, which sets `self.built = False`, creates empty `self.params = {}` and `self.grads = {}`. The `**kwargs` lets you pass `name='my_dense'` or `input_shape=(128,)` which the base class knows how to handle. + +🔍 **Line 8**: `self.units = units` — We store this now because we won't know the weight shapes until `build()` runs (since we don't know `input_dim` yet). But `units` is a hyperparameter *we* choose, so it goes in `__init__`. + +🔍 **Line 9**: `self.activation = get_activation(activation)` — `activation` is a string like `'relu'`, `'sigmoid'`, or `None`. The `get_activation` function looks it up and returns an `Activation` object (e.g., a `ReLU` instance). This object has two key methods: `__call__` (the forward pass) and `gradient` (the derivative for backprop). If `activation=None`, `get_activation` returns `None` and we skip the activation step. + +🔍 **Line 10**: `self.use_bias = use_bias` — Some layers (like the layer right before a softmax) don't need a bias. Storing this flag lets `build` decide whether to allocate `params['b']`. + +🔍 **Lines 11-12**: `self.kernel_initializer` and `self.bias_initializer` — These are *strategy objects* that know how to create weight matrices with sensible starting values. `glorot_uniform` draws from a uniform distribution scaled by the number of input/output neurons. The actual initialization is deferred to `build`. -### `build` — line 15 +### Step 2: `build` — creating the learnable parameters ```python def build(self, input_shape): @@ -52,9 +79,28 @@ def build(self, input_shape): super().build(input_shape) ``` -Parameters are allocated here, not in `__init__`. This is the standard Keras pattern: the shape is inferred from the first call. +🔍 **Line 16**: `self.input_dim = input_shape[-1]` — We grab the last dimension of the input shape. If input is `(32, 128)` (batch of 32, each 128-dimensional), then `input_dim = 128`. But what if the input is 3D, like `(32, 10, 64)`? Then `input_dim = 64` — we only care about the **last** dimension because Dense operates on the *last axis*. -### `forward` — line 26 +🔍 **Line 17**: `self.params['W'] = self.kernel_initializer((self.input_dim, self.units))` — Here's where the weight matrix is actually created. Shape is `(input_dim, units)`: + +``` + units (= 64) + ┌─────────────────┐ + │ W[0,0] W[0,1] │ +D │ W[1,0] W[1,1] │ Each column j: "how to compute +i │ ... ... │ output neuron j from all inputs" +m │ │ +(=128)│ W[127,0] W[127,63]│ + └─────────────────┘ +``` + +Think of each **column** of $W$ as a set of weights that produce one output neuron. The input `input_dim` must match the last dimension of whatever data comes in. + +🔍 **Line 19**: `self.params['b'] = self.bias_initializer((self.units,))` — The bias is a 1D vector of length `units`. When added to `xW` (shape `(batch, units)`), NumPy broadcasts it across the batch dimension automatically. + +🔍 **Line 20**: `super().build(input_shape)` — This sets `self.built = True`. After this line, the layer won't call `build` again. It also stores `self.input_shape` for `summary()`. + +### Step 3: `forward` — the main computation ```python def forward(self, inputs, training=False): @@ -62,49 +108,180 @@ def forward(self, inputs, training=False): self.z = np.dot(inputs, self.params['W']) if self.use_bias: self.z += self.params['b'] + if self.activation: return self.activation(self.z) return self.z ``` -- `self.inputs` is cached for use in `backward`. -- `self.z` is cached for use in activation backpropagation. -- The activation function (`self.activation`) is called as a callable; it may be a `Layer` instance with its own forward/backward. +🔍 **Line 27**: `self.inputs = inputs` — We **cache** the input here. Why? Because `backward` (line 36) needs it to compute `self.grads['W'] = np.dot(inputs_flat.T, grad_output_flat)`. The input isn't available during `backward` unless we saved it now. + +🔍 **Line 28**: `self.z = np.dot(inputs, self.params['W'])` — The core computation. Let's trace the shapes: + +📐 **Shape walkthrough**: `inputs` is `(B, D)` where `B = batch_size, D = input_dim`. `self.params['W']` is `(D, U)` where `U = units`. `np.dot((B, D), (D, U))` → `(B, U)`. Each output row is the input row multiplied by the weight matrix. -### `backward` — line 36 +But wait — `inputs` might be 3D! For example, a `(B, T, D)` sequence where `T` is sequence length. That's fine: `np.dot` treats the first dimensions as batch dimensions, so `(B, T, D) @ (D, U)` → `(B, T, U)`. The same weight matrix is applied at every position in the sequence. + +🔍 **Lines 29-30**: `if self.use_bias: self.z += self.params['b']` — Adding the bias vector. NumPy broadcasting means `(B, U) += (U,)` adds the same bias to every row. If `inputs` was 3D, this broadcasts as `(B, T, U) += (U,)`. + +🔍 **Line 32**: `self.z` is cached for a *different* reason than `self.inputs`. It stores the **pre-activation** values ($xW + b$, before the activation function). In `backward`, line 41, we compute `self.activation.gradient(self.z)` — the derivative of the activation evaluated at these pre-activation values. Without caching `self.z`, we'd need to recompute it in backward. + +🔍 **Line 33**: `return self.activation(self.z)` — Calls the activation function's `__call__`, which applies it element-wise. If `activation` is `None`, we skip to line 34 and return `self.z` directly. + +### Step 4: `backward` — learning from mistakes ```python def backward(self, grad_output): if self.activation: - grad_output = grad_output * self.activation.gradient(self.z) + if hasattr(self.activation, 'gradient_fast'): + grad_output = self.activation.gradient_fast(self.z, grad_output) + else: + grad_output = grad_output * self.activation.gradient(self.z) + + inputs_flat = self.inputs.reshape(-1, self.inputs.shape[-1]) + grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) + + self.grads['W'] = np.dot(inputs_flat.T, grad_output_flat) + if self.use_bias: + self.grads['b'] = np.sum(grad_output_flat, axis=0) + + return np.dot(grad_output, self.params['W'].T) +``` + +Let's break this down piece by piece. + +#### Step 4a: Handle the activation gradient + +```python + if self.activation: + if hasattr(self.activation, 'gradient_fast'): + grad_output = self.activation.gradient_fast(self.z, grad_output) + else: + grad_output = grad_output * self.activation.gradient(self.z) +``` +🔍 **Lines 37-41**: The **chain rule**. We have the upstream gradient `grad_output` (shape `(B, U)`), which is $\frac{\partial L}{\partial y}$ — the gradient of the loss w.r.t. the *activated* output. + +But the weight gradient formulas use $\frac{\partial L}{\partial z}$ — the gradient w.r.t. the *pre-activation* values. The chain rule says: + +$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial z} = \frac{\partial L}{\partial y} \odot \phi'(z)$$ + +For element-wise activations (ReLU, sigmoid, tanh), $\phi'(z)$ is just the derivative evaluated at each element, and the multiplication is element-wise. + +🔍 **Line 41**: `grad_output * self.activation.gradient(self.z)` — This is the standard path. `self.activation.gradient(self.z)` returns an array of the same shape as `self.z` (cached in forward at line 28), containing $\phi'(z)$ at each element. For ReLU, this is `(z > 0)` — a mask of 1s and 0s. Element-wise multiply with `grad_output` zeros out gradients for ReLU'd neurons that were originally negative. + +🔍 **Lines 38-39**: `self.activation.gradient_fast(self.z, grad_output)` — A special path for **Softmax**. Why? Because Softmax's Jacobian isn't element-wise — each output depends on *all* inputs, so the full Jacobian is a `(U, U)` matrix per sample. The `gradient_fast` method on `Softmax` computes the matrix-vector product `grad_output @ J_softmax` efficiently without constructing the full `(U, U)` Jacobian explicitly (well, in this educational implementation it does construct it, but you can imagine a more efficient version). The standard `gradient` method (which returns `s * (1 - s)`) would give the wrong answer for softmax — it's only correct for element-wise sigmoid. + +#### Step 4b: Flatten for multi-dimensional inputs + +```python inputs_flat = self.inputs.reshape(-1, self.inputs.shape[-1]) grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1]) +``` + +🔍 **Lines 43-44**: Why the reshape? Let's say the input was 3D: `inputs.shape = (8, 10, 64)` — batch of 8 sequences, each 10 tokens, each token 64-dimensional. The forward pass produced `z.shape = (8, 10, 32)` and `grad_output.shape = (8, 10, 32)`. +To compute `self.grads['W']`, we need `inputs.T @ grad_output`. But `inputs` is 3D and `grad_output` is 3D — we can't just transpose and multiply. + +So we **flatten the batch dimensions**: + +📐 **Shape walkthrough**: `inputs` `(8, 10, 64)` → `.reshape(-1, 64)` → `(80, 64)`. The `-1` tells NumPy: "figure out the size automatically" — `8 * 10 = 80`. Now `grad_output` `(8, 10, 32)` → `.reshape(-1, 32)` → `(80, 32)`. + +Now the matrix multiply works: `(64, 80) @ (80, 32)` → `(64, 32)` = `(input_dim, units)`, which is exactly the shape of `self.params['W']`. + +The bias gradient also benefits: `np.sum(grad_output_flat, axis=0)` sums over all 80 positions, giving shape `(32,)` = `(units,)`. + +#### Step 4c: Compute the weight gradient + +```python self.grads['W'] = np.dot(inputs_flat.T, grad_output_flat) +``` + +🔍 **Line 46**: This implements $\frac{\partial L}{\partial W} = x^T \cdot \frac{\partial L}{\partial z}$. + +📐 **Shape walkthrough**: `inputs_flat.T` is `(D, B')` where `B' = B * T` (all positions flattened). `grad_output_flat` is `(B', U)`. `np.dot((D, B'), (B', U))` → `(D, U)` — matching the shape of `W`. + +🔍 **Why it works**: Each element `(i, j)` of the result is `sum over batch of inputs_flat[k, i] * grad_output_flat[k, j]`. This is exactly the average (well, sum) co-variation of input feature `i` and output error `j` — if they tend to be positive together, the weight should increase. + +#### Step 4d: Compute the bias gradient + +```python if self.use_bias: self.grads['b'] = np.sum(grad_output_flat, axis=0) +``` +🔍 **Line 48**: Summing over `axis=0` (the batch/time dimension). For each output neuron `j`, `self.grads['b'][j]` is the sum of `grad_output_flat[:, j]` over all samples. This implements $\frac{\partial L}{\partial b_j} = \sum_{\text{batch}} \frac{\partial L}{\partial z_j}$. + +#### Step 4e: Compute the input gradient (for the previous layer) + +```python return np.dot(grad_output, self.params['W'].T) ``` -- For activation backprop, the Jacobian of the activation is element-wise multiplied with `grad_output` (most activations like ReLU, sigmoid, tanh are element-wise; Softmax is handled separately via `gradient_fast`). -- The matrix multiplications are the exact implementation of the gradient equations above. -- The return value is the gradient with respect to the input, which is passed to the previous layer. +🔍 **Line 50**: This implements $\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot W^T$. We use the **original** (non-flattened) `grad_output` here because the previous layer expects the same number of dimensions as its output had. + +📐 **Shape walkthrough**: If input was 3D `(B, T, D)`, then `grad_output` is `(B, T, U)` and `W.T` is `(U, D)`. `np.dot((B, T, U), (U, D))` → `(B, T, D)` — matching the original input shape perfectly. + +If input was 2D `(B, D)`: `np.dot((B, U), (U, D))` → `(B, D)` — also correct. -## Usage Example +## Putting it all together + +Here's the full lifecycle when you call `layer(x)` on a Dense layer with ReLU activation: + +1. **`Layer.__call__`** is invoked with your input data. +2. It checks `self.built` — on the first call, it's `False`, so `build(inputs.shape)` runs, creating `params['W']` and `params['b']`. +3. It calls `self.forward(inputs)`, which is `Dense.forward`. +4. **`Dense.forward`**: + - Caches `self.inputs = inputs` (needed later in backward line 43) + - Computes `self.z = np.dot(inputs, W)` → the pre-activation values + - Adds bias: `self.z += b` + - Applies ReLU: `return np.maximum(0, self.z)` +5. Later, `layer.backward(grad_output)` is called: + - Multiplies `grad_output` by ReLU's gradient `(z > 0)` — zeroing out gradients for negative pre-activations + - Flattens `self.inputs` and `grad_output` to handle any number of dimensions + - Computes `grads['W'] = inputs_flat.T @ grad_output_flat` + - Computes `grads['b'] = sum(grad_output_flat, axis=0)` + - Returns `grad_output @ W.T` — the gradient for the previous layer + +The optimizer then uses `self.grads['W']` and `self.grads['b']` to update `self.params['W']` and `self.params['b']`. + +## Try it yourself ```python from neutro.layers import Dense import numpy as np +# Create a Dense layer with 64 output units and ReLU activation layer = Dense(units=64, activation='relu') -x = np.random.randn(32, 128) # (batch, input_dim) -y = layer(x) # forward, shape (32, 64) -grad = np.random.randn(32, 64) -dx = layer.backward(grad) # gradient w.r.t. x, shape (32, 128) + +# Generate random input: batch of 32, each 128-dimensional +x = np.random.randn(32, 128) + +# Forward pass — this triggers build on first call +y = layer(x) +print(f"Output shape: {y.shape}") # (32, 64) +print(f"Parameters: {layer.count_params()}") # 128*64 + 64 = 8256 + +# Simulated upstream gradient +dL_dy = np.random.randn(32, 64) + +# Backward pass +dL_dx = layer.backward(dL_dy) +print(f"Input gradient shape: {dL_dx.shape}") # (32, 128) + +# Check the gradient shapes match the parameter shapes +print(f"W grad shape: {layer.grads['W'].shape}") # (128, 64) +print(f"b grad shape: {layer.grads['b'].shape}") # (64,) + +# Try with a 3D input (e.g., a sequence) +x_3d = np.random.randn(8, 10, 128) # (batch, timesteps, features) +y_3d = layer(x_3d) +print(f"3D output shape: {y_3d.shape}") # (8, 10, 64) ``` -## References +## What to read next -- Goodfellow, I., Bengio, Y., & Courville, A. (2016). **Deep Learning**. Chapter 6: Deep Feedforward Networks. [Deep Learning Book](https://www.deeplearningbook.org/) +- **`neutro/layers/base.md`** — The base class that `Dense` inherits from: learn how `__call__` dispatches between symbolic and eager mode. +- **`neutro/layers/core/dropout.md`** — Another core layer with a very different backward pass (stochastic masking). +- **`neutro/activations/relu.md`** — How the ReLU activation computes its forward pass and gradient. +- **`neutro/activations/softmax.md`** — Why softmax needs a special `gradient_fast` method instead of the element-wise path. diff --git a/docs/layers/core/input_layer.md b/docs/layers/core/input_layer.md index 573a04d..1d4acc4 100644 --- a/docs/layers/core/input_layer.md +++ b/docs/layers/core/input_layer.md @@ -1,46 +1,81 @@ # Input Layer and the `Input()` Function -## Theory +## InputLayer — `neutro/layers/core/input_layer.py:4` -In the Functional API, every graph needs entry points — places where data enters the model. `Input()` creates a symbolic `KerasTensor` that acts as the root of the graph. The corresponding `InputLayer` is a no-op layer that simply passes data through; its role is purely structural. +### What does this layer do? -`Input()` is a **convenience function** that: -1. Creates an `InputLayer` with the given shape. -2. Creates a `KerasTensor` as its symbolic output. -3. Records a `Node` connecting them. -4. Returns the `KerasTensor` for use in further layer calls. +`InputLayer` is the **entry point** of a computation graph. It is a no-op layer — it does not transform the data in any way. Its job is purely structural: it records a `Node` in the graph so the model knows this is where data enters. -The batch dimension is conventionally `None` (unknown until runtime), mirroring Keras behavior. +### The math, in plain English -## Implementation Guide +There is no math. InputLayer is the identity function: -### File: `neutro/layers/core/input_layer.py` +$$ +y = x +$$ -### `InputLayer` — line 4 +Both forward and backward pass the data through unchanged. + +### Walking through the code + +#### `__init__` ```python -class InputLayer(Layer): - def __init__(self, input_shape=None, name=None, **kwargs): - super().__init__(name=name, input_shape=input_shape, **kwargs) - if input_shape is not None: - self.build(input_shape) +def __init__(self, input_shape=None, name=None, **kwargs): + super().__init__(name=name, input_shape=input_shape, **kwargs) + if input_shape is not None: + self.build(input_shape) +``` - def build(self, input_shape): - self.input_shape = input_shape - self.built = True +🔍 **Line `super().__init__(name=name, ...)`**: Passes `name` and `input_shape` up to the base `Layer` class, where `self.input_shape` is stored. - def forward(self, inputs, training=False): - return inputs +🔍 **Line `if input_shape is not None: self.build(input_shape)`**: Unlike most layers (which wait for `build` to be called when data first flows through), `InputLayer` builds itself immediately because it already knows its shape. - def backward(self, grad_output): - return grad_output +#### `build` + +```python +def build(self, input_shape): + self.input_shape = input_shape + self.built = True +``` + +🔍 **Line `self.input_shape = input_shape`**: Saves the shape. Note: `InputLayer` does **not** allocate any parameters (`self.params` stays empty). It's a parameter-free layer. + +🔍 **Line `self.built = True`**: Marks the layer as built so subsequent calls don't trigger rebuild. + +#### `forward` + +```python +def forward(self, inputs, training=False): + return inputs ``` -- `forward` is the identity function — it returns its input unchanged. -- `backward` is also the identity — it passes the gradient straight through. -- `build` does not allocate any parameters; it only marks the layer as built. +🔍 **Identity function**: Input passes through completely unchanged. This is a **pass-through** layer — it exists only to connect the graph. + +#### `backward` + +```python +def backward(self, grad_output): + return grad_output +``` + +🔍 **Identity again**: The gradient of the identity is 1, so $dL/dx = dL/dy$. The gradient passes through unchanged. This is the "last stop" for gradients during backpropagation — the model collects these gradients as the return value for the overall `backward` call. + +--- + +## The `Input()` Function — `neutro/layers/core/input_layer.py:28` + +### What does this function do? + +`Input()` is a **convenience function** that you call at the top of the Functional API. It does four things in one shot: + +1. Normalizes the shape (prepends `None` for the batch dimension). +2. Creates an `InputLayer` with that shape. +3. Creates a symbolic `KerasTensor` as the layer's output. +4. Records a `Node` connecting the layer to the tensor. +5. Returns the `KerasTensor` so you can feed it into other layers. -### `Input()` function — line 28 +### Walking through the code ```python def Input(shape=None, name=None, **kwargs): @@ -50,38 +85,68 @@ def Input(shape=None, name=None, **kwargs): if not isinstance(shape, tuple): shape = tuple(shape) - # Keras style: prepend None for batch dimension if missing if len(shape) == 0 or shape[0] is not None: shape = (None,) + shape layer = InputLayer(input_shape=shape, name=name, **kwargs) + output_tensor = KerasTensor(shape=shape, name=name) + Node(layer, input_tensors=[], output_tensors=output_tensor) + return output_tensor ``` -Key behaviors: -- **Shape normalization**: If you pass `shape=(28, 28, 1)`, it becomes `(None, 28, 28, 1)`. This is the Keras convention: users specify the per-sample shape, and the batch dimension is prepended. -- **Empty input_tensors**: The `Node` created for `InputLayer` has an empty `input_tensors` list — it has no upstream layers. -- **The returned `KerasTensor`** has its `.node` set to this `Node`, so graph traversal can start from it. +🔍 **Line `if shape is None`**: The shape is required. Unlike some Keras variants, neutro does not infer the shape. + +🔍 **Line `if not isinstance(shape, tuple): shape = tuple(shape)`**: If you pass a list like `[28, 28, 1]`, it's converted to a tuple `(28, 28, 1)`. This ensures consistent handling. + +🔍 **Line `if len(shape) == 0 or shape[0] is not None: shape = (None,) + shape`**: This is the **batch dimension prepending** logic. In Keras convention, `Input(shape=(28, 28, 1))` means "each sample has shape `(28, 28, 1)`", and the batch dimension is implicitly `None` (unknown until runtime). So the stored shape becomes `(None, 28, 28, 1)`. + +If you already pass `shape=(None, 28, 28, 1)`, the condition `shape[0] is not None` is False, so the shape is used as-is (no double wrapping). + +🔍 **Line `layer = InputLayer(input_shape=shape, name=name, **kwargs)`**: Creates the actual layer instance. The `InputLayer.__init__` immediately calls `self.build(shape)`. + +🔍 **Line `output_tensor = KerasTensor(shape=shape, name=name)`**: Creates a **symbolic tensor**. This is not real data — it's a placeholder that carries shape information. When you call other layers with this tensor (e.g., `Dense(32)(output_tensor)`), they use its shape to build themselves and record new `Node`s in the graph. + +🔍 **Line `Node(layer, input_tensors=[], output_tensors=output_tensor)`**: Creates a graph node with **an empty `input_tensors` list**. This is the key difference from other nodes: InputLayer has no upstream layers — it's a **root** node. The `Node` constructor also sets `output_tensor.node = node`, linking the tensor back to this node. + +📐 **Empty `input_tensors`**: When the model traverses the graph during execution, it starts from the model inputs (the `KerasTensor`s returned by `Input()`). The fact that InputLayer nodes have no input tensors signals "this is where execution begins." + +🔍 **Line `return output_tensor`**: The function returns the **KerasTensor**, not the layer. This is what you assign to a variable: +```python +inputs = Input(shape=(28, 28, 1)) # inputs is a KerasTensor +x = Dense(32)(inputs) # Dense receives the KerasTensor +``` ### How InputLayer is handled during execution -In `Model.forward` (`neutro/models/base_model.py:217`): +InputLayer nodes are **skipped** during the model's forward and backward passes. Here's how: ```python +# In Model.forward (simplified): +tensor_map = {} +tensor_map[id(self.inputs)] = actual_data # Place model inputs + for node in self._nodes_ordered: if isinstance(node.layer, InputLayer): - continue # Skip — inputs are placed directly in tensor_map + continue # SKIP — already in tensor_map + # ... process other layers normally ``` -InputLayer nodes are **skipped** during execution. Their values come from the model's input data, which is placed into `tensor_map` at the start of `forward`: +🔍 **Line `tensor_map[id(self.inputs)] = actual_data`**: Before the execution loop, the model places the actual input data into the tensor map, keyed by the `KerasTensor`'s id. This is the starting point. -```python -tensor_map[id(self.inputs)] = inputs # Placed before the loop -``` +🔍 **Line `if isinstance(node.layer, InputLayer): continue`**: When the loop encounters an InputLayer node, it skips it. The tensor_map already has the data under the input KerasTensor's id, so no processing is needed. + +The same skip happens in `backward`: InputLayer nodes receive (and pass through) gradients, but no gradient computation occurs within them. + +### Why this design? + +The `Input()` function + `InputLayer` design decouples **graph construction** from **data flow**: + +- **Graph construction time** (when you call `Input()`): A `KerasTensor` is created. When you pass it to `Dense(32)`, the Dense layer creates a `Node` recording that connection. The graph is built symbolically — no real data moves. -The same skip happens in `backward` (`line 314`): InputLayer nodes receive gradients but pass them back as the return value of the entire `backward` call. +- **Execution time** (when you call `model.fit()` or `model.predict()`): The model walks the graph, and for InputLayer nodes, it simply reads the actual data from the input list. The KerasTensor acts as a "key" to look up the real numpy array in the tensor_map. ## Usage Example @@ -90,7 +155,7 @@ from neutro.layers import Input, Dense, Add from neutro.models import Model # Single input -inputs = Input(shape=(28, 28, 1)) # KerasTensor of shape (None, 28, 28, 1) +inputs = Input(shape=(28, 28, 1)) # Returns KerasTensor of shape (None, 28, 28, 1) x = Dense(32)(inputs) model = Model(inputs=inputs, outputs=x) @@ -99,7 +164,7 @@ i1 = Input(shape=(10,), name='input_a') i2 = Input(shape=(10,), name='input_b') merged = Add()([i1, i2]) model = Model(inputs=[i1, i2], outputs=merged) -# forward expects [array_a, array_b] +# model.fit([x1_data, x2_data], y_data) ``` ## References diff --git a/docs/layers/core/merging.md b/docs/layers/core/merging.md index f9abb39..57ad22e 100644 --- a/docs/layers/core/merging.md +++ b/docs/layers/core/merging.md @@ -1,124 +1,344 @@ # Merge Layers: Add, Concatenate, Multiply, Average, Maximum, Minimum -## Theory +Merge layers combine **multiple input tensors** into a single output tensor. They are essential for building non-linear architectures like ResNets (skip connections), Inception modules, and multi-branch networks. Every merge layer takes a **list of tensors** as input. -Merge layers combine multiple input tensors into a single output tensor. They are essential for building non-linear architectures like ResNets (skip connections), Inception modules, and multi-branch networks. Every merge layer takes a **list of tensors** as input. +## Add — `merging.py:4` -### Operations +### What does this layer do? -| Layer | Operation | Gradient | -|---|---|---| -| `Add` | $y = \sum_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y}$ (same for all) | -| `Multiply` | $y = \prod_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \prod_{j \ne i} x_j$ | -| `Average` | $y = \frac{1}{N} \sum_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{1}{N} \frac{\partial L}{\partial y}$ | -| `Maximum` | $y = \max_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \mathbf{1}_{x_i = y}$ | -| `Minimum` | $y = \min_i x_i$ | $\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \mathbf{1}_{x_i = y}$ | -| `Concatenate` | $y = [x_1, x_2, \dots, x_N]$ along axis $a$ | Split $y$ gradient back along $a$ | +Add computes the element-wise sum of all input tensors. This is the fundamental operation behind **residual (skip) connections**. -For `Maximum`/`Minimum`, the indicator function $\mathbf{1}_{x_i = y}$ passes the gradient only to the input(s) that achieved the extreme value — this is known as **argmax routing** in gradient computation. +### The math, in plain English -## Implementation Guide +$$ +y = x_1 + x_2 + \cdots + x_N +$$ -### File: `neutro/layers/core/merging.py` +Every input must have the **same shape**. The output has that same shape. -### `Add` — line 4 +### Walking through the code + +#### `forward` ```python -class Add(Layer): - def forward(self, inputs, training=False): - self.input_lengths = len(inputs) - return sum(inputs) +def forward(self, inputs, training=False): + self.input_lengths = len(inputs) + return sum(inputs) +``` + +🔍 **Line `self.input_lengths = len(inputs)`**: We cache the number of inputs. The backward pass needs this to know how many gradient tensors to return. + +🔍 **Line `return sum(inputs)`**: Python's built-in `sum()` on a list of NumPy arrays performs element-wise addition. All arrays must have the same shape. For example, listing `[a, b, c]` computes `a + b + c`. - def backward(self, grad_output): - return [grad_output for _ in range(self.input_lengths)] +📐 **Shape**: If each input is `(batch, 64)`, the output is also `(batch, 64)`. + +#### `backward` + +```python +def backward(self, grad_output): + return [grad_output for _ in range(self.input_lengths)] ``` -- `sum(inputs)` works element-wise across the list. -- The gradient is **broadcast unchanged** to every input — the sum's Jacobian w.r.t. each input is the identity. +🔍 **Line `[grad_output for _ in range(self.input_lengths)]`**: For $y = x_1 + x_2$, we have $\partial y / \partial x_1 = 1$ and $\partial y / \partial x_2 = 1$. So by the chain rule, $\partial L / \partial x_i = \partial L / \partial y \cdot 1$. The gradient is **broadcast unchanged** to every input. We return a list with `N` identical gradient tensors. + +--- + +## Concatenate — `merging.py:42` + +### What does this layer do? + +Concatenate joins multiple tensors along a specified axis. All inputs must have the same shape **except** along the concatenation axis, where their dimensions are summed. This is the core of multi-branch feature fusion architectures like Inception. + +### The math, in plain English + +$$ +y = [x_1, x_2, \dots, x_N] \quad \text{along axis } a +$$ + +If each input has shape $(d_0, d_1, \dots, d_a, \dots, d_k)$ and we concatenate along axis $a$, the output has shape $(d_0, d_1, \dots, \sum_i d_a^{(i)}, \dots, d_k)$. + +### Walking through the code -### `Concatenate` — line 37 +#### `__init__` ```python -class Concatenate(Layer): - def __init__(self, axis=-1, **kwargs): - super().__init__(**kwargs) - self.axis = axis - - def compute_output_shape(self, input_shape): - out_shape = list(input_shape[0]) - concat_dim = 0 - for shape in input_shape: - dim = shape[self.axis] - if dim is None: # Handle symbolic None (batch dim) - concat_dim = None - break - concat_dim += dim - out_shape[self.axis] = concat_dim - return tuple(out_shape) - - def forward(self, inputs, training=False): - self.input_shapes = [i.shape for i in inputs] - return np.concatenate(inputs, axis=self.axis) - - def backward(self, grad_output): - indices = np.cumsum([s[self.axis] for s in self.input_shapes])[:-1] - return np.split(grad_output, indices, axis=self.axis) +def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis ``` -- `compute_output_shape` correctly handles symbolic `None` dimensions (e.g., batch size). -- The backward uses `np.split` to reverse the concatenation along the same axis. +🔍 **`axis=-1`**: By default, concatenation happens along the last axis (features). This is the most common use case — joining feature vectors side-by-side. -### `Multiply` — line 74 +#### `compute_output_shape` ```python -class Multiply(Layer): - def forward(self, inputs, training=False): - self.inputs = inputs - res = inputs[0].copy() - for i in range(1, len(inputs)): - res *= inputs[i] - return res - - def backward(self, grad_output): - grads = [] - for i in range(len(self.inputs)): - g = grad_output.copy() - for j in range(len(self.inputs)): - if i == j: continue - g *= self.inputs[j] # Product of all inputs except the i-th - grads.append(g) - return grads +def compute_output_shape(self, input_shape): + if not isinstance(input_shape, list): + return input_shape + + out_shape = list(input_shape[0]) + concat_dim = 0 + for shape in input_shape: + dim = shape[self.axis] + if dim is None: + concat_dim = None + break + concat_dim += dim + + out_shape[self.axis] = concat_dim + return tuple(out_shape) ``` -- For each input $i$, the gradient is $\frac{\partial L}{\partial y} \odot \prod_{j \ne i} x_j$. -- `self.inputs` is cached during forward for use in backward (important for shared layer state restoration). +🔍 **Line `if dim is None: concat_dim = None; break`**: This handles **symbolic shapes** where the batch dimension (or any dimension) is `None` at graph-building time. If any input has a `None` dimension on the concat axis, the output's concat dimension will also be `None`. -### `Average`, `Maximum`, `Minimum` — lines 127-200 +📐 **Example**: Input shapes `[(None, 10), (None, 20)]` with `axis=-1` → `out_shape = (None, 30)`. But if one dimension is `None` on the concat axis, it propagates as `None`. -These follow the same pattern. `Maximum` and `Minimum` use `np.maximum` / `np.minimum` in forward and mask-based gradient routing in backward. +#### `forward` -### Shared Layer Compatibility +```python +def forward(self, inputs, training=False): + self.input_shapes = [i.shape for i in inputs] + return np.concatenate(inputs, axis=self.axis) +``` -All merge layers store intermediate state (`input_lengths`, `input_shapes`, `inputs`) on `self` during `forward`. For shared merge layers used multiple times in a graph, the `Model` class uses `_capture_layer_state` / `_restore_layer_state` (recursive, covering sublayers) to save and restore this state per node. +🔍 **Line `self.input_shapes = [i.shape for i in inputs]`**: We cache the **actual shape** of each input tensor. The backward pass needs the sizes along the concat axis to split the gradient correctly. -## Usage Example +🔍 **Line `np.concatenate(inputs, axis=self.axis)`**: NumPy's native concatenation. This is the only operation — no learned parameters. + +📐 **Shape**: `[a, b, c]` where `a.shape = (8, 10)`, `b.shape = (8, 20)`, `c.shape = (8, 30)` with `axis=-1` → `(8, 60)`. + +#### `backward` ```python -from neutro.layers import Input, Dense, Add, Concatenate -from neutro.models import Model - -# Skip connection (Add) -inp = Input(shape=(32,)) -x = Dense(32, activation='relu')(inp) -skip = Dense(32)(x) -out = Add()([x, skip]) # Two branches merged - -# Multi-branch concatenation -i1 = Input(shape=(10,)) -i2 = Input(shape=(20,)) -merged = Concatenate(axis=-1)([i1, i2]) # Output shape: (None, 30) +def backward(self, grad_output): + indices = np.cumsum([s[self.axis] for s in self.input_shapes])[:-1] + return np.split(grad_output, indices, axis=self.axis) ``` +🔍 **Line `indices = np.cumsum(...)`**: Compute the split points from the cached input shapes. `np.cumsum` gives cumulative sums along the concat axis. We drop the last element with `[:-1]` because `np.split` takes split positions. + +📐 **Example**: Input shapes along axis: `[10, 20, 30]`. `np.cumsum([10, 20, 30])` = `[10, 30, 60]`. `[:-1]` = `[10, 30]`. These are the split indices: slice 0..10, 10..30, 30..60. + +🔍 **Line `np.split(grad_output, indices, axis=self.axis)`**: Reverse of `np.concatenate`. Splits the gradient along the same axis into the original pieces. Returns a list of gradient tensors matching the input shapes. + +--- + +## Multiply — `merging.py:85` + +### What does this layer do? + +Multiply computes the element-wise product of all input tensors. This is useful in attention mechanisms, gating, and specialized architectures. + +### The math, in plain English + +$$ +y = x_1 \odot x_2 \odot \cdots \odot x_N +$$ + +Where $\odot$ denotes element-wise multiplication. All inputs must have the same shape. + +For the backward pass, the gradient w.r.t. a single input is the product of **all other inputs** times the upstream gradient: + +$$ +\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y} \odot \prod_{j \neq i} x_j +$$ + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res *= inputs[i] + return res +``` + +🔍 **Line `self.inputs = inputs`**: Cache the list of inputs for the backward pass. The backward pass needs to access all inputs except the one being differentiated. + +🔍 **Line `res = inputs[0].copy()`**: Start with a **copy** of the first input. We use `.copy()` to avoid mutating the original input tensor. + +🔍 **Line `res *= inputs[i]`**: Multiply element-by-element. After the loop, `res` is the product of all inputs. + +📐 **Shape**: `(8, 64)` × `(8, 64)` × `(8, 64)` → `(8, 64)`. + +#### `backward` + +```python +def backward(self, grad_output): + grads = [] + for i in range(len(self.inputs)): + g = grad_output.copy() + for j in range(len(self.inputs)): + if i == j: + continue + g *= self.inputs[j] + grads.append(g) + return grads +``` + +🔍 **Line `g = grad_output.copy()`**: Start with the upstream gradient. + +🔍 **Line `for j ... if i == j: continue; g *= self.inputs[j]`**: +For input $x_i$, we multiply the upstream gradient by **every other input** $x_j$ for $j \neq i$. This implements $\partial L / \partial x_i = \partial L / \partial y \cdot \prod_{j \neq i} x_j$. + +📐 **Example with 3 inputs**: $y = a \cdot b \cdot c$. +- $\partial L / \partial a = \partial L / \partial y \cdot b \cdot c$ +- $\partial L / \partial b = \partial L / \partial y \cdot a \cdot c$ +- $\partial L / \partial c = \partial L / \partial y \cdot a \cdot b$ + +The loops compute exactly these products. + +--- + +## Average — `merging.py:120` + +### What does this layer do? + +Average computes the element-wise mean of all input tensors. + +### The math, in plain English + +$$ +y = \frac{1}{N} \sum_{i=1}^{N} x_i +$$ + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.input_lengths = len(inputs) + return sum(inputs) / self.input_lengths +``` + +🔍 **Line `self.input_lengths = len(inputs)`**: Cache the number of inputs `N` for the backward pass. + +🔍 **Line `sum(inputs) / self.input_lengths`**: Python's `sum()` adds element-wise, then dividing by `N` gives the average. + +#### `backward` + +```python +def backward(self, grad_output): + return [grad_output / self.input_lengths for _ in range(self.input_lengths)] +``` + +🔍 **Line `grad_output / self.input_lengths`**: The derivative of $y = (x_1 + \dots + x_N) / N$ w.r.t. $x_i$ is $1/N$. Each input receives the upstream gradient divided by the number of inputs. + +--- + +## Maximum — `merging.py:144` + +### What does this layer do? + +Maximum computes the element-wise maximum across all input tensors. + +### The math, in plain English + +$$ +y = \max(x_1, x_2, \dots, x_N) +$$ + +For each element position, the output is the largest value among all inputs at that position. + +The backward pass uses **argmax routing**: the gradient flows only to the input(s) that actually **were** the maximum at each position. All other inputs receive zero gradient. + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res = np.maximum(res, inputs[i]) + return res +``` + +🔍 **Line `self.inputs = inputs`**: Cache the inputs. The backward pass needs to compare each input against the maximum. + +🔍 **Line `np.maximum(res, inputs[i])`**: Element-wise maximum. `np.maximum(a, b)` returns an array where each element is `max(a_element, b_element)`. + +📐 **Shape**: All `(8, 64)`. Output: `(8, 64)`. + +#### `backward` + +```python +def backward(self, grad_output): + max_val = self.forward(self.inputs) + grads = [] + for inp in self.inputs: + mask = (inp == max_val) + grads.append(grad_output * mask) + return grads +``` + +🔍 **Line `max_val = self.forward(self.inputs)`**: Recompute the maximum values by calling `forward` again. (Alternative: cache `max_val` in forward.) + +🔍 **Line `mask = (inp == max_val)`**: For each input, create a boolean mask that is `True` wherever this input equals the maximum value. If multiple inputs share the maximum at a position, all of them get gradient. + +🔍 **Line `grad_output * mask`**: The mask zeros out the gradient everywhere this input was **not** the maximum. Only the "winning" input receives gradient. + +📐 **The logic**: For $y = \max(x_1, x_2)$, the subgradient is: +$$ +\frac{\partial y}{\partial x_1} = \begin{cases} 1 & \text{if } x_1 > x_2 \\ 0 & \text{if } x_1 < x_2 \\ \text{any value in } [0,1] & \text{if } x_1 = x_2 \end{cases} +$$ + +Neutro uses the tie-case convention: if two inputs are equal, **both** get gradient (the mask is `True` for both). + +--- + +## Minimum — `merging.py:177` + +### What does this layer do? + +Minimum computes the element-wise minimum across all input tensors. It is the mirror image of Maximum. + +### The math, in plain English + +$$ +y = \min(x_1, x_2, \dots, x_N) +$$ + +The backward pass uses **argmin routing**: gradient flows only to the input(s) that were the minimum at each position. + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + res = inputs[0].copy() + for i in range(1, len(inputs)): + res = np.minimum(res, inputs[i]) + return res +``` + +Identical to Maximum but uses `np.minimum`. + +#### `backward` + +```python +def backward(self, grad_output): + min_val = self.forward(self.inputs) + grads = [] + for inp in self.inputs: + mask = (inp == min_val) + grads.append(grad_output * mask) + return grads +``` + +Identical to Maximum's backward but using the minimum value as the comparison target. + +🔍 **Line `mask = (inp == min_val)`**: Gradient passes only where this input equals the minimum. For ties, multiple inputs receive gradient. + +--- + ## References - He, K., Zhang, X., Ren, S., & Sun, J. (2016). **Deep Residual Learning for Image Recognition** — skip connections via Add. *CVPR*. [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) diff --git a/docs/layers/embedding/embedding.md b/docs/layers/embedding/embedding.md index 4d0529d..678c26c 100644 --- a/docs/layers/embedding/embedding.md +++ b/docs/layers/embedding/embedding.md @@ -1,43 +1,226 @@ # Embedding Layers -## Theory +## Embedding -### Token Embedding — `neutro/layers/embedding/embedding.py` +### What does this layer do? -An embedding layer maps discrete tokens (integers) to dense vectors: +Maps discrete tokens (integers like word IDs) to dense, learnable vectors. Think of it as a **lookup table**: token ID 0 gets row 0, token ID 1 gets row 1, and so on. The table entries are learned during training, so similar tokens end up with similar vectors. + +### The math $$x_i = W[\text{token}_i]$$ -Where $W \in \mathbb{R}^{V \times D}$ is a learnable matrix, $V$ is the vocabulary size, and $D$ is the embedding dimension. The forward pass is a simple lookup: +Where $W \in \mathbb{R}^{V \times D}$ is the embedding matrix, $V$ is the vocabulary size, and $D$ is the embedding dimension. The input is an integer tensor; the output is a float tensor where each integer has been replaced by its corresponding row of $W$. + +### Walking through the code + +#### `__init__` — what needs to happen before we know the shapes + +```python +class Embedding(Layer): + def __init__(self, input_dim, output_dim, **kwargs): + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim +``` + +🔍 **Line 3**: `super().__init__(**kwargs)` calls `Layer.__init__`, setting up `self.params = {}`, `self.grads = {}`, and `self.built = False`. The `input_shape` kwarg is stored for later use in `build`. + +🔍 **Line 4**: `input_dim` is the vocabulary size $V$ — the number of unique tokens. For example, `10000` means tokens are in range `[0, 9999]`. + +🔍 **Line 5**: `output_dim` is the embedding dimension $D$ — how many numbers represent each token. Common values: `128`, `256`, `512`. + +#### `build` — creating the embedding table + +```python +def build(self, input_shape): + self.params['embeddings'] = np.random.normal(0, 0.01, (self.input_dim, self.output_dim)) + super().build(input_shape) +``` + +🔍 **Line 2**: Create the embedding matrix. Shape `(V, D)` = `(vocab_size, embed_dim)`. Initialized with small random Gaussian values (mean 0, std 0.01). + +``` + embed_dim (= 128) + ┌──────────────────────┐ + │ token 0's embedding │ +V │ token 1's embedding │ +o │ ... │ +c │ │ +a │ token 9999's embed │ +b └──────────────────────┘ +``` + +🔍 **Line 3**: `super().build(input_shape)` sets `self.built = True` so `build` won't run again. + +#### `forward` — looking up tokens + +```python +def forward(self, inputs, training=False): + self.inputs = inputs.astype(int) + return self.params['embeddings'][self.inputs] +``` + +🔍 **Line 2**: Cast inputs to integer and cache them. We cache `self.inputs` because backward needs it to know *which rows* of the embedding table received gradients. + +🔍 **Line 3**: NumPy advanced indexing: `self.params['embeddings'][self.inputs]` — for each integer in `inputs`, fetch the corresponding row of the embedding matrix. + +📐 **Shape walkthrough**: Input `(B, seq_len)` with values like `[[42, 7, 999, 1]]`. `self.params['embeddings']` is `(V, D)` = `(10000, 128)`. `embeddings[inputs]` returns shape `(B, seq_len, D)` = `(1, 4, 128)`. Token 42 becomes a 128-dimensional vector (row 42 of the matrix), token 7 becomes row 7, etc. + +#### `backward` — sparse gradient accumulation + +```python +def backward(self, grad_output): + self.grads['embeddings'] = np.zeros_like(self.params['embeddings']) + np.add.at(self.grads['embeddings'], self.inputs, grad_output) + return None +``` + +🔍 **Line 2**: Start with a zero gradient buffer the same shape as the embedding matrix `(V, D)`. + +🔍 **Line 3**: `np.add.at(self.grads['embeddings'], self.inputs, grad_output)` — this is the critical line. + +🔍 **Why `np.add.at` and not `self.grads['embeddings'][self.inputs] = ...`?** Because the same token index might appear **multiple times** in a batch. For example, if the input batch contains token 42 at two different positions, both contributions need to be **summed**, not overwritten. + +Consider: `inputs = [[42, 7, 42, 1]]`, `grad_output = [[g0, g1, g2, g3]]`. Token 42 appears at positions 0 and 2. The gradient for row 42 should be `g0 + g2`. `np.add.at` handles this correctly — it accumulates into the same row. + +With regular assignment `self.grads['embeddings'][self.inputs] = grad_output`, only the last occurrence of token 42 would survive (position 2's gradient `g2`), and `g0` would be silently lost. + +🔍 **`return None`**: The embedding layer has no trainable parameters that affect the *previous* layer — the input tokens are fixed integers, not float gradients. There's no gradient to pass backward to a token index input. (In practice, the previous layer is usually a tokenizer or data loader, not another differentiable layer.) + +--- + +## TimeEmbedding + +### What does this layer do? + +Converts a scalar timestep (a single integer like `t=42`) into a high-dimensional vector using sinusoidal functions at different frequencies. This is the positional encoding from "Attention Is All You Need", repurposed for diffusion models — it tells the model *where* in time we are. + +### The math + +For each timestep $t$ and embedding dimension $i$: + +$$\text{TE}(t, 2i) = \sin\left(\frac{t}{10000^{2i / D}}\right)$$ + +$$\text{TE}(t, 2i+1) = \cos\left(\frac{t}{10000^{2i / D}}\right)$$ + +Where $D$ is the embedding dimension. Different dimensions oscillate at different frequencies — low-index dimensions change quickly (fine-grained time), high-index dimensions change slowly (coarse time). + +### Walking through the code + +#### `__init__` + +```python +class TimeEmbedding(Layer): + def __init__(self, dim, **kwargs): + super().__init__(**kwargs) + self.dim = dim + self.last_t = None +``` + +🔍 **Line 4**: `self.dim` is the output embedding dimension (e.g., 128 or 256). The input is a single scalar timestep; the output is a `dim`-dimensional vector. + +🔍 **Line 5**: `self.last_t = None` — a cache for the timestep used in the most recent forward pass. We'll store the input timesteps here so backward can return a gradient of the same shape. + +#### `forward` + +```python +def forward(self, t, training=False): + if t.ndim == 2: + t = t.flatten() + self.last_t = t + + half_dim = self.dim // 2 + embeddings = np.log(10000) / (half_dim - 1) + embeddings = np.exp(np.arange(half_dim) * -embeddings) + embeddings = t[:, None] * embeddings[None, :] + embeddings = np.concatenate([np.sin(embeddings), np.cos(embeddings)], axis=1) + + if self.dim % 2 == 1: + embeddings = np.pad(embeddings, ((0, 0), (0, 1))) + + return embeddings +``` + +🔍 **Lines 2-3**: If `t` has shape `(B, 1)`, flatten to `(B,)`. The input might come from a data loader that adds an extra dimension. + +🔍 **Line 5**: `half_dim = self.dim // 2`. Since each frequency produces both a sin and a cos, we need only half as many frequencies as the total dimension. + +🔍 **Lines 6-7**: Build the frequency schedule. This is the key precomputation: ```python -def forward(self, inputs): - return self.params['W'][inputs] # (batch, seq_len, embed_dim) +embeddings = np.log(10000) / (half_dim - 1) # scalar +embeddings = np.exp(np.arange(half_dim) * -embeddings) # (half_dim,) ``` -The backward pass uses `np.add.at` to accumulate gradients back to the embedding matrix: +`embeddings` is a vector of frequencies. For `half_dim = 64`: +- Index 0: `exp(0 * -embeddings)` = `exp(0)` = `1.0` — highest frequency +- Index 63: `exp(63 * -embeddings)` ≈ `exp(-log(10000))` = `0.0001` — lowest frequency + +These are the $1/10000^{2i/D}$ terms from the formula. + +🔍 **Line 8**: Outer product: `t[:, None]` is `(B, 1)`, `embeddings[None, :]` is `(1, half_dim)`. Result: `(B, half_dim)` — each timestep multiplied by each frequency. + +📐 **Shape**: `t` = `[0, 1, 2, ..., 999]` (batch of 1000 timesteps), `embeddings` = `[1.0, 0.84, ..., 0.0001]` (64 frequencies). Result: `(1000, 64)` — row `i` is `t_i * frequencies`. + +🔍 **Line 9**: Apply `sin` and `cos` to get the final encoding, then concatenate along the feature axis. + +📐 **Shape**: Each branch is `(B, half_dim)`. `concatenate(axis=1)` → `(B, dim)` if `dim` is even, which pairs sin(ωt) and cos(ωt) for each frequency. + +🔍 **Lines 10-11**: If `dim` is odd, pad with one extra column of zeros. `np.pad(embeddings, ((0,0), (0,1)))` adds a zero column at the end. + +🔍 **Why no `build` step?** The frequencies are *deterministic* — they depend only on `dim`, not on the input shape. There are no learnable parameters, so there's nothing to build. + +#### `backward` ```python def backward(self, grad_output): - self.grads['W'] = np.zeros_like(self.params['W']) - np.add.at(self.grads['W'], self.inputs, grad_output) - return grad_output + return np.zeros_like(self.last_t) ``` -### TimeEmbedding — `neutro/layers/embedding/time_embedding.py` +🔍 **Line 2**: Return a zero gradient with the same shape as the input timesteps. + +🔍 **Why zeros?** The timesteps `t` are not learnable parameters — they're fixed inputs chosen by the diffusion process (e.g., `t = 0, 1, 2, ..., 999`). There's no gradient to propagate back to them. The TimeEmbedding layer is purely a feature transformation; the actual learning happens in the layers that consume its output. + +🔍 **What about the gradient wrt the frequency schedule?** The frequencies are hardcoded constants (they don't appear in `self.params`), so there's no gradient to compute for them either. The `grad_output` from the next layer is simply discarded — it doesn't modify anything. -Projects scalar timesteps (e.g., diffusion timesteps) into a high-dimensional space using sinusoidal encoding followed by a learnable MLP projection. +### Why sinusoidal encodings? -## Usage Example +Sinusoidal functions have a useful property: the encoding for timestep `t + Δt` can be expressed as a linear function of the encoding for `t` (using trig identities). This means the model can learn to reason about **relative** timesteps — "50 steps from now" — rather than memorizing absolute positions. + +### Try it yourself ```python -from neutro.layers import Embedding +from neutro.layers import Embedding, TimeEmbedding +import numpy as np + +# Token Embedding +vocab_size, embed_dim = 10000, 128 +emb = Embedding(input_dim=vocab_size, output_dim=embed_dim) +tokens = np.array([[42, 7, 999, 1]]) # (1, 4) +x = emb(tokens) +print(f"Token embedding shape: {x.shape}") # (1, 4, 128) + +# Backward: check sparse accumulation +dL_dy = np.random.randn(1, 4, 128) +emb.backward(dL_dy) +print(f"Embedding grad shape: {emb.grads['embeddings'].shape}") # (10000, 128) +print(f"Rows with non-zero gradient: {np.count_nonzero(np.any(emb.grads['embeddings'] != 0, axis=1))}") # 3 (tokens 42, 7, 999) + +# TimeEmbedding (sinusoidal positional encoding) +te = TimeEmbedding(dim=256) +timesteps = np.array([0, 1, 50, 999]) # (4,) +z = te(timesteps) +print(f"Time embedding shape: {z.shape}") # (4, 256) +print(f"Timestep 0 encoding (first 8 dims): {z[0, :8]}") # sin(0) = 0, cos(0) = 1 for all frequencies -emb = Embedding(vocab_size=10000, embed_dim=512) -tokens = np.array([[1, 5, 23, 42]]) # (batch, seq_len) -x = emb(tokens) # (1, 4, 512) +# Backward returns zeros (timesteps aren't learned) +grad = te.backward(np.random.randn(4, 256)) +print(f"Input gradient shape: {grad.shape}") # (4,) +print(f"All zeros?: {np.all(grad == 0)}") # True ``` ## References - Mikolov, T., et al. (2013). **Efficient Estimation of Word Representations in Vector Space**. [arXiv:1301.3781](https://arxiv.org/abs/1301.3781) +- Vaswani, A., et al. (2017). **Attention Is All You Need**. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) diff --git a/docs/layers/normalization/batchnorm.md b/docs/layers/normalization/batchnorm.md index aa4b88b..f1c4e26 100644 --- a/docs/layers/normalization/batchnorm.md +++ b/docs/layers/normalization/batchnorm.md @@ -1,17 +1,137 @@ # Batch Normalization -## Overview -Batch Normalization (BatchNorm) accelerates deep network training by reducing internal covariate shift. It normalizes the activations of each layer for each mini-batch. +## What does this layer do? -## Mathematical Formulation -For a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$: -1. **Mean**: $\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^m x_i$ -2. **Variance**: $\sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\mathcal{B}})^2$ -3. **Normalize**: $\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$ -4. **Scale and Shift**: $y_i = \gamma \hat{x}_i + \beta$ +Batch Normalization stabilizes training by normalizing activations **across the batch dimension** instead of across features. For each feature channel, it computes the mean and variance over the entire batch, then normalizes and applies a learned scale (`gamma`) and shift (`beta`). -## Implementation Details -`neutro` tracks running means and variances during training to use for inference. The backward pass involves calculating gradients for $\gamma$, $\beta$, and the input $x$ with respect to the batch statistics. +Intuitively, BatchNorm says: "Every feature channel should have a consistent distribution across batch elements." This is great for CNNs and large batch sizes, but it breaks when the batch is small (e.g., medical imaging, video). -## Citations -- Ioffe, S., & Szegedy, C. (2015). **Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift**. *Proceedings of the 32nd International Conference on Machine Learning (ICML)*. [arXiv:1502.03167](https://arxiv.org/abs/1502.03167) +During **inference**, BatchNorm uses running statistics collected during training — a single sample can't give a meaningful batch mean. + +## The math + +For a mini-batch with shape `(batch, ..., D)`, mean and variance are computed over **all axes except the last**: + +$$\mu = \frac{1}{M} \sum_{i=1}^{M} x_i \quad\quad \sigma^2 = \frac{1}{M} \sum_{i=1}^{M} (x_i - \mu)^2$$ + +$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \quad\quad y = \gamma \odot \hat{x} + \beta$$ + +The key difference from LayerNorm: the sum runs over all batch/spatial elements (all axes except the feature axis `axis=-1`), producing stats of shape `(D,)` — one mean and variance per feature channel. + +## Walking through the code + +### `__init__` / `build` + +```python +def __init__(self, momentum=0.99, epsilon=1e-3): + super().__init__() + self.momentum = momentum + self.epsilon = epsilon + self.running_mean = None + self.running_var = None + +def build(self, input_shape): + dim = input_shape[-1] + self.params['gamma'] = np.ones(dim) + self.params['beta'] = np.zeros(dim) + self.running_mean = np.zeros(dim) + self.running_var = np.ones(dim) + super().build(input_shape) +``` + +🔍 **`momentum=0.99`**: Controls how fast running statistics update. With momentum 0.99, each new batch contributes 1% to the running average. Higher = more stable but slower to adapt. + +🔍 **`epsilon=1e-3`**: Notice this is larger than LayerNorm's `1e-6`. BatchNorm variance tends to be noisier (fewer elements in each statistic), so a slightly larger epsilon is common. + +🔍 **`running_mean` / `running_var`**: Start as `zeros` and `ones` — a neutral initialization. These accumulate statistics across all training batches. + +📐 **`gamma`, `beta`**: Shape `(D,)` — one per feature channel. Same as LayerNorm. + +### `forward` + +```python +def forward(self, x, training=False): + if training: + mean = np.mean(x, axis=tuple(range(len(x.shape)-1))) + var = np.var(x, axis=tuple(range(len(x.shape)-1))) + + self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean + self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var + + self.x_centered = x - mean + self.std = np.sqrt(var + self.epsilon) + self.x_norm = self.x_centered / self.std + else: + x_centered = x - self.running_mean + std = np.sqrt(self.running_var + self.epsilon) + self.x_norm = x_centered / std + + return self.params['gamma'] * self.x_norm + self.params['beta'] +``` + +🔍 **Training / Inference split**: The layer behaves completely differently depending on `training`. + +**Training path:** + +📐 **`np.mean(x, axis=tuple(range(len(x.shape)-1)))`**: For a `(B, H, W, D)` input, this sums over axes `(0, 1, 2)` — all but the last. Result shape: `(D,)` — one mean per feature channel. + +📐 **Same for `var`**: Shape `(D,)`. + +🔍 **Running stats update**: The exponential moving average: +- `running_mean = 0.99 * old + 0.01 * batch_mean` +- Over time, this smooths out the noise of individual batches. + +🔍 **Why running stats?** At inference, you might get a single image `(1, H, W, D)`. That batch mean of 1 element is meaningless. So you use the running statistics that accumulated statistics over thousands of training samples. + +**Inference path:** + +No caching of `x_centered`, `std`, or `x_norm`. Inference just computes the output directly using `running_mean` and `running_var`. + +### `backward` + +```python +def backward(self, grad_output): + gamma = self.params['gamma'] + batch_size = np.prod(grad_output.shape[:-1]) + + self.grads['gamma'] = np.sum(grad_output * self.x_norm, + axis=tuple(range(len(grad_output.shape)-1))) + self.grads['beta'] = np.sum(grad_output, + axis=tuple(range(len(grad_output.shape)-1))) + + dx_norm = grad_output * gamma + dx = (1. / batch_size) / self.std * ( + batch_size * dx_norm + - np.sum(dx_norm, axis=tuple(range(len(grad_output.shape)-1))) + - self.x_norm * np.sum(dx_norm * self.x_norm, + axis=tuple(range(len(grad_output.shape)-1))) + ) + return dx +``` + +📐 **`batch_size = np.prod(grad_output.shape[:-1])`**: The total number of elements contributing to each feature's statistics. For `(B, H, W, D)`, this is `B * H * W` — all batch and spatial positions. + +📐 **`self.grads['gamma']`**: Shape `(D,)` — sum over axes `(0, 1, 2)` for a 4D input. Same pattern as LayerNorm but the reduction axes are different. + +📐 **`self.grads['beta']`**: Same reduction, also `(D,)`. + +🔍 **Big `dx` formula**: Identical structure to LayerNorm! The three terms are the same: +1. `batch_size * dx_norm` — direct gradient +2. `- sum(dx_norm)` — correction through mean +3. `- x_norm * sum(dx_norm * x_norm)` — correction through variance + +The only difference is **which axis** the sums are computed over. For BatchNorm, sums are over all non-feature axes. For LayerNorm, sums are only over the feature axis. + +🔍 **`self.std` shape**: `(D,)` for BatchNorm vs. `(B, S, 1)` for LayerNorm. But broadcasting makes the math work the same way. + +## Why not always use BatchNorm? + +BatchNorm has two gotchas: +1. **Small batches**: Mean/variance from 2-4 samples is noisy. +2. **Training != Inference**: You must track running stats. A bug in mode-switching silently destroys accuracy. + +That's why LayerNorm is preferred in Transformers and GroupNorm is used in vision with small batches. + +## References + +- Ioffe, S., & Szegedy, C. (2015). **Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift**. *Proceedings of ICML*. [arXiv:1502.03167](https://arxiv.org/abs/1502.03167) diff --git a/docs/layers/normalization/layernorm.md b/docs/layers/normalization/layernorm.md index e036132..65d8a82 100644 --- a/docs/layers/normalization/layernorm.md +++ b/docs/layers/normalization/layernorm.md @@ -1,17 +1,109 @@ # Layer Normalization -## Overview -Layer Normalization (LayerNorm) is a normalization technique that computes the mean and variance for each individual sample across all its features, rather than across a batch. This makes it ideal for recurrent neural networks and Transformers. +## What does this layer do? -## Mathematical Formulation -Unlike BatchNorm, LayerNorm normalizes across the features $H$: -$$\mu = \frac{1}{H} \sum_{i=1}^H x_i$$ -$$\sigma = \sqrt{\frac{1}{H} \sum_{i=1}^H (x_i - \mu)^2 + \epsilon}$$ -$$\hat{x} = \frac{x - \mu}{\sigma}$$ -$$y = \gamma \hat{x} + \beta$$ +Layer Normalization (LayerNorm) makes training stable by controlling the distribution of activations inside a neural network. For each input sample, it rescales all features to have zero mean and unit variance, then applies a learned scale (`gamma`) and shift (`beta`). -## Implementation Details -In `neutro`, LayerNorm is used extensively in the `TransformerBlock`. It is independent of the batch size and works the same way during training and inference. +Unlike Batch Normalization, LayerNorm computes statistics **across the feature dimension** — independently for every sample in the batch. This makes it batch-size agnostic, which is why every Transformer (BERT, GPT, Llama) uses it. + +## The math + +For an input `x` with shape `(..., D)` (the last dimension is the feature dimension): + +$$\mu = \frac{1}{D} \sum_{i=1}^{D} x_i \quad\quad \sigma^2 = \frac{1}{D} \sum_{i=1}^{D} (x_i - \mu)^2$$ + +$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \quad\quad y = \gamma \odot \hat{x} + \beta$$ + +The mean and variance are computed over the **last axis** (`axis=-1`). Every sample in a batch gets its own normalization. The `gamma` and `beta` vectors have the same size as the feature dimension. + +## Walking through the code + +### `__init__` / `build` + +```python +def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + +def build(self, input_shape): + self.params['gamma'] = np.ones(input_shape[-1]) + self.params['beta'] = np.zeros(input_shape[-1]) + super().build(input_shape) +``` + +🔍 **`__init__`**: Stores a tiny `epsilon` to prevent division by zero during normalization. The `**kwargs` passes layer name and other metadata to `super()`. + +🔍 **`build`**: Called once when the layer first sees data. Creates two trainable parameters: +- `gamma = ones(D)` — starts as the identity scaling (multiply by 1) +- `beta = zeros(D)` — starts as the identity shift (add 0) + +📐 **Shapes**: `(D,)` — one scale and one shift per feature. If input is `(batch, seq_len, 512)`, then `gamma` and `beta` are both `(512,)`. NumPy broadcasting expands them to match the full input shape during `forward`. + +### `forward` + +```python +def forward(self, x, training=False): + self.x = x + self.mean = np.mean(x, axis=-1, keepdims=True) + self.var = np.var(x, axis=-1, keepdims=True) + self.x_norm = (x - self.mean) / np.sqrt(self.var + self.epsilon) + return self.params['gamma'] * self.x_norm + self.params['beta'] +``` + +🔍 **Line `self.x = x`**: Saves the input for the backward pass. Without this, backward wouldn't know what `x` was during forward. + +📐 **`np.mean(x, axis=-1, keepdims=True)`**: Computes mean over the last (feature) axis. With `keepdims`, a `(B, S, D)` input produces a `(B, S, 1)` mean — broadcasting back to `(B, S, D)` when you subtract it. + +📐 **Same for `np.var`**: Shape `(B, S, 1)` — one variance per position per sample. + +📐 **`self.x_norm`**: Shape `(B, S, D)` — same as input. Each feature `x[b,s,d]` has been centered and scaled by its own sample/position stats. + +📐 **Return**: `gamma * x_norm + beta` → still `(B, S, D)`. Broadcasting multiplies `(D,)` gamma against every position. + +🔍 **Why cache `mean`, `var`, `x_norm`?** Backward needs them. The gradient through normalization depends on the original mean and variance — you can't recompute them because they're functions of the original `x`. + +### `backward` + +```python +def backward(self, grad_output): + N = grad_output.shape[-1] + self.grads['gamma'] = np.sum(grad_output * self.x_norm, + axis=tuple(range(len(grad_output.shape)-1))) + self.grads['beta'] = np.sum(grad_output, + axis=tuple(range(len(grad_output.shape)-1))) + dx_norm = grad_output * self.params['gamma'] + std_inv = 1.0 / np.sqrt(self.var + self.epsilon) + dx = (1.0 / N) * std_inv * ( + N * dx_norm + - np.sum(dx_norm, axis=-1, keepdims=True) + - self.x_norm * np.sum(dx_norm * self.x_norm, axis=-1, keepdims=True) + ) + return dx +``` + +🔍 **`N = grad_output.shape[-1]`**: The feature dimension. Used in the normalization factor later. + +📐 **`self.grads['gamma']`**: Shape `(D,)` — sum over all axes except the last. For a `(B, S, D)` gradient, we sum over axes 0 and 1 (batch and sequence). + +📐 **`self.grads['beta']`**: Shape `(D,)` — the gradient of the bias term is just the sum of the output gradient along all non-feature axes. + +🔍 **`dx_norm = grad_output * gamma`**: The gradient flowing through the `gamma` scale. If the layer output is `y = gamma * x_norm + beta`, then `dy/dx_norm = gamma`. This is the chain rule's first step: `dL/dx_norm = dL/dy * dy/dx_norm`. + +🔍 **The big `dx` formula**: This is the gradient through the entire normalization pipeline — through the standardization `(x - mu) / std`. The formula has three terms inside the parentheses: + +1. **`N * dx_norm`**: The direct path — what you'd get if normalization were just a scaling. +2. **`- sum(dx_norm)`**: Corrects for the mean subtraction — the gradient has to account for the fact that `mu` is a function of `x`. +3. **`- x_norm * sum(dx_norm * x_norm)`**: Corrects for the variance scaling — `std` is also a function of `x`. + +Think of it as: +``` +x → [center: x - μ] → [scale: /σ] → x̂ → [affine: γ·x̂ + β] → y + ↑ ↑ ↑ + μ = mean(x) σ = sqrt(var(x)) γ, β learned +``` + +The backward pass reverses this chain. You don't need to memorize the big formula — just know that it's the chain rule correctly threaded through the mean and variance dependencies. + +## References -## Citations - Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). **Layer Normalization**. *arXiv preprint arXiv:1607.06450*. [arXiv:1607.06450](https://arxiv.org/abs/1607.06450) diff --git a/docs/layers/normalization/normalization.md b/docs/layers/normalization/normalization.md index 1a989f9..de2f4b3 100644 --- a/docs/layers/normalization/normalization.md +++ b/docs/layers/normalization/normalization.md @@ -1,78 +1,225 @@ -# Normalization Layers +# RMSNorm and GroupNorm -## Theory +## RMSNorm -Normalization layers stabilize training by controlling the distribution of activations. `neutro` implements four variants. +### What does this layer do? -### Layer Normalization — `neutro/layers/normalization/layernorm.py` +RMSNorm (Root Mean Square Normalization) is a simplified version of LayerNorm that **drops the mean-centering step**. It only divides by the root-mean-square of the activations, then scales by a learned weight. -Normalizes across the feature dimension for each sample independently: +Why remove the mean? Empirical results from Llama, Qwen, and DeepSeek show that the mean-centering in LayerNorm doesn't help much — the RMS scaling alone provides enough normalization. This saves computation (no mean subtraction) and simplifies the backward pass. -$$\mu = \frac{1}{H} \sum_{i=1}^H x_i, \quad \sigma = \sqrt{\frac{1}{H} \sum_{i=1}^H (x_i - \mu)^2 + \epsilon}$$ +### The math -$$\hat{x} = \frac{x - \mu}{\sigma}, \quad y = \gamma \hat{x} + \beta$$ +$$\text{RMS}(x) = \sqrt{\frac{1}{D} \sum_{i=1}^{D} x_i^2 + \epsilon} \quad\quad y = \frac{x}{\text{RMS}(x)} \cdot \gamma$$ -Used in Transformers (GPT, BERT, Llama). Independent of batch size. +No `beta` parameter, no mean computation — just a single learnable `weight` (also called `gamma`). -### Batch Normalization — `neutro/layers/normalization/batchnorm.py` +The RMS is computed `axis=-1` (the feature dimension), same as LayerNorm. -Normalizes across the batch dimension for each feature: +### Walking through the code -$$\mu_{\mathcal{B}} = \frac{1}{m} \sum x_i, \quad \sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum (x_i - \mu_{\mathcal{B}})^2$$ +#### `__init__` / `build` -$$\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \quad y_i = \gamma \hat{x}_i + \beta$$ +```python +def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + +def build(self, input_shape): + self.dim = input_shape[-1] + self.params['weight'] = np.ones(self.dim) + super().build(input_shape) +``` + +🔍 **Only one parameter**: RMSNorm has `weight` but no `beta`. LayerNorm has `gamma` and `beta`. This saves 50% of the normalization parameters. + +📐 **`weight`**: Shape `(D,)` — one scalar per feature dimension. Starts as all ones (identity). + +#### `forward` + +```python +def forward(self, x, training=False): + self.x = x + self.rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + self.epsilon) + self.x_norm = x / self.rms + return self.x_norm * self.params['weight'] +``` + +🔍 **`np.mean(x**2, axis=-1, keepdims=True)`**: Notice we compute `x²` first (element-wise square), then average over the feature axis. This is RMS — no mean-centering. + +📐 **`self.rms`**: Shape `(B, S, 1)` for a `(B, S, D)` input. One RMS value per position per sample. + +📐 **`x / self.rms`**: Broadcasting divides each feature by its RMS. + +📐 **Return**: `(B, S, D)` — same as input, each position normalized independently. -Tracks running mean/variance for inference. Used in CNNs. +🔍 **`training=False` is accepted but ignored**: RMSNorm works identically at train and inference time. No running statistics, no mode switching. This simplicity is a big advantage over BatchNorm. -### RMS Norm — `neutro/layers/normalization/rmsnorm.py` +#### `backward` -Root Mean Square Normalization — a simplified LayerNorm without mean centering: +```python +def backward(self, grad_output): + self.grads['weight'] = np.sum(grad_output * self.x_norm, axis=(0, 1)) + + N = self.dim + grad_x_norm = grad_output * self.params['weight'] + sum_grad_x = np.sum(grad_x_norm * self.x, axis=-1, keepdims=True) + dx = (grad_x_norm / self.rms) - (self.x * sum_grad_x / (N * self.rms**3)) + return dx +``` + +📐 **`self.grads['weight']`**: Summing over axes `(0, 1)` — the batch and sequence dimensions. For a `(B, S, D)` gradient, this produces `(D,)`. + +🔍 **`grad_x_norm`**: The gradient through the `weight` scale, same concept as `dx_norm` in LayerNorm. + +🔍 **The `dx` formula**: Two terms: +1. **`grad_x_norm / rms`**: The direct path — if RMS were just a constant divisor. +2. **`- x * sum(grad_x_norm * x) / (N * rms³)`**: Corrects for the fact that `rms` depends on `x`. When `x` changes, the RMS changes too, and this term accounts for that. + +Compare this to the LayerNorm backward and you'll notice it's **simpler** — no term for the mean. The mean subtraction was adding two correction terms (`N * dx_norm` had `-sum(dx_norm)` in LayerNorm); here we just have the direct path and the variance correction. -$$\text{RMS}(x) = \sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2 + \epsilon}, \quad y = \frac{x}{\text{RMS}(x)} \cdot \gamma$$ +--- -Used in Llama and modern LLMs for efficiency. +## Group Normalization -### Group Normalization — `neutro/layers/normalization/groupnorm.py` +### What does this layer do? -Divides channels into groups and normalizes within each group: +Group Normalization (GroupNorm) divides the channels of a convolutional feature map into **groups** and normalizes within each group independently. It's the middle ground between LayerNorm (too coarse for vision) and BatchNorm (too batch-dependent). -$$\mu_g = \frac{1}{|\mathcal{G}_g|} \sum_{i \in \mathcal{G}_g} x_i, \quad \sigma_g^2 = \frac{1}{|\mathcal{G}_g|} \sum_{i \in \mathcal{G}_g} (x_i - \mu_g)^2$$ +Imagine a feature map with 64 channels. With 8 groups, each group has 8 channels. GroupNorm computes mean and variance over `(height, width, 8 channels)` — every spatial position and all channels in the group. -Used in vision models when batch size is small (e.g., video, medical imaging). +### When to use GroupNorm? -## Implementation Guide +- **Small batch sizes** (video, medical imaging, object detection) +- **Vision transformers** (ViT) where batch size is constrained by memory +- Any time BatchNorm's batch dependency causes trouble -All normalization layers share a common pattern: +### The math -| Method | Behavior | -|---|---| -| `build(input_shape)` | Allocates `gamma` (scale) and `beta` (shift) parameters. Shape matches the feature dimension. | -| `forward(x)` | Computes mean/variance, normalizes, scales, shifts. | -| `backward(grad)` | Backpropagates through normalization using the stored mean/variance. | +For input `x` with shape `(N, H, W, C)` and `G` groups, the channels are split: each group has `C // G` channels. For group `g`: -For LayerNorm: +$$\mu_g = \frac{1}{|\mathcal{G}_g|} \sum_{i \in \mathcal{G}_g} x_i \quad\quad \sigma_g^2 = \frac{1}{|\mathcal{G}_g|} \sum_{i \in \mathcal{G}_g} (x_i - \mu_g)^2$$ + +$$\hat{x} = \frac{x - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}} \quad\quad y = \gamma \odot \hat{x} + \beta$$ + +The mean and variance are computed over axes `(H, W, C//G)` — all spatial positions and all channels within a group. + +### Walking through the code + +#### `__init__` / `build` ```python -def forward(self, x): - self.mean = np.mean(x, axis=-1, keepdims=True) - self.var = np.var(x, axis=-1, keepdims=True) - self.x_hat = (x - self.mean) / np.sqrt(self.var + self.eps) - return self.gamma * self.x_hat + self.beta +def __init__(self, groups=32, epsilon=1e-5, **kwargs): + super().__init__(**kwargs) + self.groups = groups + self.epsilon = epsilon + +def build(self, input_shape): + dim = input_shape[-1] + if dim % self.groups != 0: + raise ValueError( + f"Number of channels ({dim}) must be divisible by groups ({self.groups})" + ) + + self.params['gamma'] = np.ones((1, 1, 1, dim)) + self.params['beta'] = np.zeros((1, 1, 1, dim)) + super().build(input_shape) ``` -## Usage Example +🔍 **`groups=32`**: Default from the original paper. For 64 channels, 32 groups means 2 channels per group. + +🔍 **`dim % self.groups != 0` check**: Groups must evenly divide channels. If you have 64 channels and 3 groups, that's impossible — you can't split 64 into 3 equal integer parts. + +📐 **`gamma` / `beta`**: Shape `(1, 1, 1, C)` — four-dimensional with dummy batch/spatial dims. This makes broadcasting against `(N, H, W, C)` automatic without explicit broadcasting. + +#### `forward` ```python -from neutro.layers import LayerNormalization +def forward(self, x, training=False): + self.x_shape = x.shape + batch, h, w, c = x.shape + g = self.groups + + x_reshaped = x.reshape(batch, h, w, g, c // g) + + self.mean = np.mean(x_reshaped, axis=(1, 2, 4), keepdims=True) + self.var = np.var(x_reshaped, axis=(1, 2, 4), keepdims=True) -ln = LayerNormalization(epsilon=1e-6) -x = np.random.randn(4, 16, 64) # (batch, seq, features) -y = ln(x) # Normalized along last axis, same shape + self.std = np.sqrt(self.var + self.epsilon) + self.x_centered = x_reshaped - self.mean + self.x_norm = self.x_centered / self.std + + x_norm = self.x_norm.reshape(batch, h, w, c) + + return self.params['gamma'] * x_norm + self.params['beta'] ``` +📐 **`x.reshape(batch, h, w, g, c // g)`**: The key move. An input `(2, 4, 4, 64)` with 8 groups becomes `(2, 4, 4, 8, 8)`. The last dimension is now the group sub-dimension. + +🔍 **`np.mean(x_reshaped, axis=(1, 2, 4))`**: Mean over `height (1)`, `width (2)`, and `group-channel (4)`. This computes one mean per **batch element × group** — shape `(2, 1, 1, 8, 1)`. + +📐 **`self.mean` shape**: `(N, 1, 1, G, 1)` — one mean per sample per group. Broadcasting divides each group's `H*W*(C//G)` elements by their shared mean. + +📐 **`self.var` shape**: Same `(N, 1, 1, G, 1)`. + +📐 **Back to `(N, H, W, C)`**: After normalization, reshape back to the original shape. + +🔍 **`training=False` is ignored**: Like RMSNorm and LayerNorm, GroupNorm is identical during training and inference. No running statistics. + +#### `backward` + +```python +def backward(self, grad_output): + batch, h, w, c = self.x_shape + g = self.groups + m = h * w * (c // g) + + self.grads['gamma'] = np.sum( + grad_output * self.x_norm.reshape(batch, h, w, c), + axis=(0, 1, 2), keepdims=True + ) + self.grads['beta'] = np.sum( + grad_output, axis=(0, 1, 2), keepdims=True + ) + + dx_norm = grad_output * self.params['gamma'] + dx_norm = dx_norm.reshape(batch, h, w, g, c // g) + + sum_dx_norm = np.sum(dx_norm, axis=(1, 2, 4), keepdims=True) + sum_dx_norm_x_norm = np.sum(dx_norm * self.x_norm, axis=(1, 2, 4), keepdims=True) + + dx = (1.0 / m) / self.std * ( + m * dx_norm - sum_dx_norm - self.x_norm * sum_dx_norm_x_norm + ) + + return dx.reshape(batch, h, w, c) +``` + +📐 **`m = h * w * (c // g)`**: The total number of elements contributing to each group's statistics. For a `(N, 4, 4, 64)` input with 8 groups: `m = 4 * 4 * 8 = 128`. + +📐 **`self.grads['gamma']`**: Sum over `(0, 1, 2)` — batch, height, width. Produces shape `(1, 1, 1, C)` matching the gamma parameter. + +📐 **`dx_norm.reshape(batch, h, w, g, c // g)`**: Reshape the gradient into the same grouped format as forward, so we can compute per-group statistics. + +🔍 **Sum axes `(1, 2, 4)`**: The same axes used during forward — height, width, and group-channel. `sum_dx_norm` has shape `(N, 1, 1, G, 1)`. + +🔍 **The big `dx` formula**: The same structure as BatchNorm and LayerNorm! Three terms: +1. `m * dx_norm` — direct gradient +2. `- sum_dx_norm` — correction through mean +3. `- x_norm * sum_dx_norm_x_norm` — correction through variance + +Applied **independently per group** because the sums are over group-specific axes. + +## Comparing the four normalizations + +| Layer | Stat axes | Parameters | Running stats? | Used in | +|---|---|---|---|---| +| LayerNorm | `(D)` — features | γ, β | No | Transformers | +| BatchNorm | `(N, H, W)` — batch/spatial | γ, β | Yes (mean, var) | CNNs (big batch) | +| RMSNorm | `(D)` — features | weight | No | Llama, DeepSeek | +| GroupNorm | `(H, W, C//G)` — spatial + sub-channels | γ, β | No | Vision (small batch) | + ## References -- Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). **Layer Normalization**. [arXiv:1607.06450](https://arxiv.org/abs/1607.06450) -- Ioffe, S., & Szegedy, C. (2015). **Batch Normalization**. [arXiv:1502.03167](https://arxiv.org/abs/1502.03167) - Zhang, B., & Sennrich, R. (2019). **Root Mean Square Layer Normalization**. [arXiv:1910.07467](https://arxiv.org/abs/1910.07467) - Wu, Y., & He, K. (2018). **Group Normalization**. [arXiv:1803.08494](https://arxiv.org/abs/1803.08494) diff --git a/docs/layers/pooling/pooling.md b/docs/layers/pooling/pooling.md index d15d5b5..bec7ec0 100644 --- a/docs/layers/pooling/pooling.md +++ b/docs/layers/pooling/pooling.md @@ -1,59 +1,329 @@ # Pooling Layers -## Theory +## MaxPooling2D -Pooling layers reduce the spatial dimensions of feature maps, providing downsampling and local translation invariance. +### What does this layer do? -### MaxPooling2D — `neutro/layers/pooling/maxpooling2d.py` +Slides a fixed-size window over a 2D feature map and keeps only the **maximum** value in each window. This downsamples the spatial dimensions (height & width) while preserving the important features — if a strong edge detector fires somewhere, it doesn't matter exactly *where*. -Slides a window over the input and takes the maximum value in each window: +### The math -$$y_{i,j,k} = \max_{p=1..P, q=1..Q} x_{i \cdot s + p,\; j \cdot s + q,\; k}$$ +For each window at position `(i, j)` in channel `k`: -- **Forward**: `np.max` over sliding windows. -- **Backward**: Routes gradient to the position that was the maximum (argmax routing). +$$y_{i,j,k} = \max_{p=0..P-1,\; q=0..Q-1} x_{i \cdot s + p,\; j \cdot s + q,\; k}$$ -### Global Pooling — `neutro/layers/pooling/global_pooling.py` +Where $P\times Q$ is the pool window size and $s$ is the stride. -Reduces each feature map to a single value: +### Walking through the code -- **GlobalAveragePooling2D**: $y_k = \frac{1}{H \cdot W} \sum_{i,j} x_{i,j,k}$ -- **GlobalMaxPooling2D**: $y_k = \max_{i,j} x_{i,j,k}$ +#### `__init__` — setting up window geometry -Used before the final Dense layer in CNNs to replace Flatten (fewer parameters, no overfitting). +```python +def __init__(self, pool_size=(2, 2), strides=None, data_format='channels_last', **kwargs): + super().__init__(**kwargs) + self.pool_size = pool_size if isinstance(pool_size, (tuple, list)) else (pool_size, pool_size) + strides = strides if strides else self.pool_size + self.strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) + if data_format not in ('channels_last', 'channels_first'): + raise ValueError("data_format must be 'channels_last' or 'channels_first'") + self.data_format = data_format +``` + +🔍 **Line 3-4**: If `pool_size` is an int like `2`, expand it to `(2, 2)`. Same for `strides`. + +🔍 **Line 5**: If `strides` is `None`, default to `pool_size` — the usual "non-overlapping windows" behavior. + +🔍 **Lines 7-8**: `data_format` tells us whether channels are last (NHWC — TensorFlow convention) or first (NCHW — PyTorch convention). The layer normalizes everything to `channels_last` internally for simpler indexing: `(batch, h, w, c)`. + +#### `forward` — finding the max + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + inputs_nhwc = self._to_channels_last(inputs) + batch, h, w, c = inputs_nhwc.shape + ph, pw = self.pool_size + sh, sw = self.strides + + oh = (h - ph) // sh + 1 + ow = (w - pw) // sw + 1 + + x = inputs_nhwc.transpose(0, 3, 1, 2).reshape(-1, 1, h, w) + + self.x_cols = im2col_indices(x, ph, pw, padding=0, stride=sh) + self.arg_max = np.argmax(self.x_cols, axis=0) + out = self.x_cols[self.arg_max, np.arange(self.arg_max.size)] + + out = out.reshape(oh, ow, batch, c).transpose(2, 0, 1, 3) + return self._from_channels_last(out) +``` + +🔍 **Line 3**: Cache the original input for backward pass. We'll need it for shape information. + +🔍 **Line 4**: Normalize to `(batch, h, w, c)` if in `channels_first` format. + +📐 **Shape so far**: `inputs_nhwc` is `(B, H, W, C)`. + +🔍 **Lines 8-9**: Compute output spatial dimensions. For `H=28, ph=2, sh=2`: `(28-2)//2 + 1 = 14`. + +🔍 **Line 10**: Transpose to `(B, C, H, W)` then reshape to `(B*C, 1, H, W)`. This merges batch and channel dimensions so `im2col_indices` treats each channel of each sample independently. + +📐 **Shape**: `(B, H, W, C)` → `.transpose(0, 3, 1, 2)` → `(B, C, H, W)` → `.reshape(-1, 1, H, W)` → `(B*C, 1, H, W)`. + +🔍 **Line 12**: `im2col_indices` "unrolls" each window into a column. For a 2×2 window, each column has 4 elements (one per window position). The result has shape: + +📐 **Shape of `self.x_cols`**: `(ph * pw * in_channels, oh * ow * batch)` = `(4, oh * ow * B)` for a 2×2 pool with 1 channel per grouped sample. Each column is one window; each row is one position *within* that window. + +🔍 **Line 13**: `np.argmax(self.x_cols, axis=0)` — for each column (window), find which row (position within the window) has the maximum value. + +📐 **Shape of `self.arg_max`**: `(oh * ow * B*C,)` — one index per window. This is cached for backward: it tells us *exactly which element* in each window was the winner. + +🔍 **Line 14**: Select the max value from each column using fancy indexing. `self.arg_max` gives the row index, `np.arange(self.arg_max.size)` gives the column index. + +📐 **Shape**: `self.x_cols[self.arg_max, np.arange(self.arg_max.size)]` → `(oh * ow * B * C,)` — one scalar per window. + +🔍 **Lines 16-17**: Unflatten back: `(oh * ow * B * C,)` → `(oh, ow, B, C)` → `.transpose(2, 0, 1, 3)` → `(B, oh, ow, C)`. Then convert back to original data format if needed. + +#### `backward` — routing gradients to the winner + +```python +def backward(self, grad_output): + grad_output_nhwc = self._to_channels_last(grad_output) + batch, oh, ow, c = grad_output_nhwc.shape + ph, pw = self.pool_size + sh, sw = self.strides + _, h, w, _ = self._shape_to_channels_last(self.inputs.shape) + + dout = grad_output_nhwc.transpose(1, 2, 0, 3).flatten() + + dx_cols = np.zeros_like(self.x_cols) + dx_cols[self.arg_max, np.arange(self.arg_max.size)] = dout + + dx = col2im_indices(dx_cols, (batch * c, 1, h, w), ph, pw, padding=0, stride=sh) + return self._from_channels_last(dx.reshape(batch, c, h, w).transpose(0, 2, 3, 1)) +``` + +🔍 **Line 2-3**: Normalize gradient to NHWC and pull out shapes. + +🔍 **Line 6**: Get the original input's spatial dimensions `(H, W)` from the cached `self.inputs`. + +🔍 **Line 8**: Flatten the gradient to match the column layout from forward. + +📐 **Shape**: `grad_output` `(B, oh, ow, C)` → `.transpose(1, 2, 0, 3)` → `(oh, ow, B, C)` → `.flatten()` → `(oh * ow * B * C,)`. + +🔍 **Lines 10-11**: The key insight: **only the max position gets the gradient**. Create a zero gradient buffer the same shape as `self.x_cols` (`(4, oh * ow * B * C)` for 2×2 pool). Then, for each window (column), place the output gradient at exactly the row index that was the argmax in forward. All other positions within that window get 0. + +🔍 **Why this works**: If element `(0, 0)` was the max in a 2×2 window, `self.arg_max` for that window is `0`. The gradient for position `(0, 0)` equals the output gradient; positions `(0, 1)`, `(1, 0)`, `(1, 1)` in that window get 0. Small changes to non-max positions don't affect the output, so they get no gradient. + +📐 **`dx_cols[argmax, col_indices] = dout`**: `dx_cols` has shape `(win_size, n_windows)`. `self.arg_max` has shape `(n_windows,)`. For each window `j`, we set `dx_cols[arg_max[j], j] = dout[j]`. + +🔍 **Line 13**: `col2im_indices` is the inverse of `im2col_indices` — it scatters the columns back onto the original `(B*C, 1, H, W)` grid. `np.add.at` is used internally to handle overlapping windows (though with stride == pool_size, there are no overlaps). + +🔍 **Line 14**: Reshape from `(B*C, 1, H, W)` back to `(B, C, H, W)` then transpose to `(B, H, W, C)`, and finally convert to original data format. + +--- + +## GlobalAveragePooling2D + +### What does this layer do? + +Takes each feature map (H×W) and replaces it with a single number — the **average** of all values in that map. This is a dramatic downsampling: `(B, H, W, C)` → `(B, C)`. It's commonly used before the final classification layer in CNNs to replace Flatten + Dense, drastically reducing parameters. + +### The math + +$$y_k = \frac{1}{H \cdot W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_{i,j,k}$$ + +Spatial dimensions are averaged away; only the channel dimension survives. + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.input_shape_internal = inputs.shape + if self.data_format == 'channels_first': + return np.mean(inputs, axis=(2, 3)) + return np.mean(inputs, axis=(1, 2)) +``` + +🔍 **Line 2**: Cache the full input shape. We'll need `H` and `W` in backward to divide the gradient. + +🔍 **Lines 3-4**: `np.mean(inputs, axis=(2, 3))` — average over height and width (axes 2 and 3) for `channels_first`. -### UpSampling2D — `neutro/layers/pooling/upsampling2d.py$ +📐 **Shape**: `(B, C, H, W)` → `np.mean(axis=(2,3))` → `(B, C)`. Each feature map becomes one number. -Increases spatial dimensions by repeating rows and columns (nearest-neighbor upsampling): +🔍 **Line 5**: For `channels_last`, spatial dimensions are axes 1 and 2. -$$y_{i \cdot f + p,\; j \cdot f + q,\; k} = x_{i,j,k}$$ +📐 **Shape**: `(B, H, W, C)` → `np.mean(axis=(1,2))` → `(B, C)`. -- Used in decoder architectures (UNet, GANs). -- Backward: sums the gradient back into the original positions. +#### `backward` -## Implementation Guide +```python +def backward(self, grad_output): + if self.data_format == 'channels_first': + batch, c, h, w = self.input_shape_internal + return (grad_output[:, :, None, None] * np.ones((batch, c, h, w))) / (h * w) + batch, h, w, c = self.input_shape_internal + return (grad_output[:, None, None, :] * np.ones((batch, h, w, c))) / (h * w) +``` + +🔍 **Line 3 or 6**: Unpack the cached input shape to get `H` and `W`. + +🔍 **Lines 4 or 7**: The gradient of an average is the upstream gradient **divided evenly** across all H×W positions. The upstream gradient `grad_output` has shape `(B, C)`. We need to broadcast it back to `(B, H, W, C)` — each of the H×W spatial positions gets the same gradient value, scaled by `1/(H*W)`. + +📐 **Shape walkthrough** (channels_last case): `grad_output` `(B, C)` → `grad_output[:, None, None, :]` → `(B, 1, 1, C)` → `* np.ones((B, H, W, C))` → `(B, H, W, C)` → `/(H*W)` → `(B, H, W, C)` — matching the original input shape. + +🔍 **Why it works**: If a feature map was 4×4, the forward averaged all 16 values. So each of those 16 values contributed equally. The backward spreads the gradient back equally: each pixel gets 1/16 of the output gradient. + +--- + +## GlobalMaxPooling2D + +### What does this layer do? + +Same idea as GlobalAveragePooling2D, but instead of averaging, it takes the **maximum** value from each feature map. `(B, H, W, C)` → `(B, C)`. + +### The math + +$$y_k = \max_{i, j} x_{i,j,k}$$ + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + if self.data_format == 'channels_first': + inputs_nhwc = inputs.transpose(0, 2, 3, 1) + else: + inputs_nhwc = inputs + self.max_indices = np.argmax(inputs_nhwc.reshape(inputs_nhwc.shape[0], -1, inputs_nhwc.shape[-1]), axis=1) + return np.max(inputs_nhwc, axis=(1, 2)) +``` + +🔍 **Line 2**: Cache input for backward. + +🔍 **Lines 3-5**: Normalize to NHWC for consistent indexing. + +🔍 **Line 6**: Flatten spatial dimensions and find the argmax per channel. + +📐 **Shape**: `inputs_nhwc` is `(B, H, W, C)`. `.reshape(B, -1, C)` → `(B, H*W, C)`. `np.argmax(axis=1)` → `(B, C)`. Each element `max_indices[b, c]` is the flat spatial index (0 to H*W-1) of the maximum value in channel `c`. + +🔍 **Line 7**: `np.max(inputs_nhwc, axis=(1, 2))` — take the max over both spatial dimensions. + +📐 **Shape**: `(B, H, W, C)` → `np.max(axis=(1,2))` → `(B, C)`. + +#### `backward` + +```python +def backward(self, grad_output): + if self.data_format == 'channels_first': + batch, c, h, w = self.inputs.shape + dx_nhwc = np.zeros((batch, h, w, c), dtype=self.inputs.dtype) + else: + batch, h, w, c = self.inputs.shape + dx_nhwc = np.zeros_like(self.inputs) + for b in range(batch): + for channel in range(c): + idx = self.max_indices[b, channel] + ih, iw = divmod(idx, w) + dx_nhwc[b, ih, iw, channel] = grad_output[b, channel] + if self.data_format == 'channels_first': + return dx_nhwc.transpose(0, 3, 1, 2) + return dx_nhwc +``` + +🔍 **Lines 3-6**: Create a zero gradient buffer the same shape as the input. The gradient starts as all zeros — only the max positions will get filled in. -All pooling layers are in `neutro/layers/pooling/`. MaxPooling2D uses `im2col` from `conv_utils.py` to unroll windows, then applies `np.max` and `np.argmax` for efficient forward/backward. +🔍 **Lines 7-10**: For each sample and each channel, find which spatial position `(ih, iw)` was the maximum. `divmod(idx, w)` converts the flat index back to 2D coordinates. Then place `grad_output[b, channel]` at exactly that position. + +🔍 **Why no gradient elsewhere**: Same logic as MaxPooling2D — changing a non-maximum pixel doesn't change the output, so its gradient is zero. + +🔍 **The loop is explicit**: Unlike `np.add.at` trickery in MaxPooling2D, this implementation uses simple Python loops. It's slower but easier to understand. Each channel of each sample has exactly one argmax position, so there's no risk of overlapping writes. + +--- + +## UpSampling2D + +### What does this layer do? + +"Nearest neighbor" upsampling: each pixel becomes a block of identical pixels, making the image bigger. If size is `(2, 2)`, every input pixel turns into a 2×2 square of the same value. + +### The math + +$$y_{i \cdot f_h + p,\; j \cdot f_w + q,\; k} = x_{i,j,k} \quad \text{for } p=0..f_h-1,\; q=0..f_w-1$$ + +Each input pixel at `(i, j)` is repeated `f_h` times vertically and `f_w` times horizontally. + +### Walking through the code + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.input_shape_actual = inputs.shape + return np.repeat(np.repeat(inputs, self.size[0], axis=1), self.size[1], axis=2) +``` + +🔍 **Line 2**: Cache the input shape. We'll need `H`, `W` and `C` in backward to know how to un-reshape the gradient. + +🔍 **Line 3**: `np.repeat` on axis=1 (height) repeats each row `size[0]` times, then `np.repeat` on axis=2 (width) repeats each column `size[1]` times. + +📐 **Shape walkthrough**: Input `(B, H, W, C)`. First repeat: `np.repeat(inputs, sh, axis=1)` → `(B, H*sh, W, C)`. Second repeat: `np.repeat(..., sw, axis=2)` → `(B, H*sh, W*sw, C)`. + +For example, input `(2, 3, 3, 1)` with size `(2, 2)` → output `(2, 6, 6, 1)`. Pixel `(0, 0)` becomes a 2×2 block of `pixel(0,0)` at positions `(0..1, 0..1)`. + +#### `backward` ```python -# MaxPooling2D key pattern -cols = im2col(x, self.pool_size, self.strides, padding='valid') -max_idx = np.argmax(cols, axis=0) -output = cols[max_idx, np.arange(cols.shape[1])] -# Reshape to output spatial dimensions +def backward(self, grad_output): + batch, h, w, c = self.input_shape_actual + sh, sw = self.size + + grad = grad_output.reshape(batch, h, sh, w, sw, c) + return grad.sum(axis=(2, 4)) ``` -## Usage Example +🔍 **Line 2**: Get the original input dimensions from the cache. + +🔍 **Line 5**: Reshape the gradient so that the repeated dimensions are separate axes. + +📐 **Shape**: `grad_output` is `(B, H*sh, W*sw, C)`. Reshape to `(B, H, sh, W, sw, C)`. Now each original pixel's `sh × sw` block is on axes 2 and 4. + +🔍 **Line 6**: Sum over the repeated axes (2 and 4). Since forward repeated the same value across the block, backward sums all those gradients back into a single value. + +📐 **Shape**: `(B, H, sh, W, sw, C)` → `.sum(axis=(2,4))` → `(B, H, W, C)` — matching the original input shape. + +🔍 **Why sum and not average?** Forward copied each pixel's value `sh * sw` times — it didn't take any kind of average. So if the loss wants to increase a pixel's value, it gets `sh * sw` identical "votes" from the repeated positions. Summing respects the fact that the original pixel's value influences all `sh * sw` output positions equally and independently. + +### Try it yourself ```python -from neutro.layers import MaxPooling2D, GlobalAveragePooling2D +from neutro.layers import MaxPooling2D, GlobalAveragePooling2D, GlobalMaxPooling2D, UpSampling2D +import numpy as np +# MaxPooling2D: downsample 28x28 to 14x14 pool = MaxPooling2D(pool_size=(2, 2)) x = np.random.randn(2, 28, 28, 16) -y = pool(x) # shape (2, 14, 14, 16) +y = pool(x) +print(f"MaxPool: {x.shape} → {y.shape}") # (2, 28, 28, 16) → (2, 14, 14, 16) +# Global pooling: spatial dims → scalar per channel gap = GlobalAveragePooling2D() -z = gap(y) # shape (2, 16) +z = gap(y) +print(f"GlobalAvgPool: {y.shape} → {z.shape}") # (2, 14, 14, 16) → (2, 16) + +gmp = GlobalMaxPooling2D() +z2 = gmp(y) +print(f"GlobalMaxPool: {y.shape} → {z2.shape}") # (2, 14, 14, 16) → (2, 16) + +# UpSampling2D: nearest neighbor upsample +up = UpSampling2D(size=(2, 2)) +x_small = np.random.randn(2, 14, 14, 8) +y_big = up(x_small) +print(f"UpSample: {x_small.shape} → {y_big.shape}") # (2, 14, 14, 8) → (2, 28, 28, 8) ``` ## References diff --git a/docs/layers/recurrent/lstm.md b/docs/layers/recurrent/lstm.md index bf9a274..6c5c236 100644 --- a/docs/layers/recurrent/lstm.md +++ b/docs/layers/recurrent/lstm.md @@ -1,20 +1,352 @@ # Long Short-Term Memory (LSTM) -## Overview -LSTM is a type of recurrent neural network (RNN) architecture designed to solve the vanishing gradient problem in standard RNNs. It uses a gating mechanism to regulate the flow of information. - -## Mathematical Formulation -For each time step $t$: -1. **Forget Gate**: $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$ -2. **Input Gate**: $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$ -3. **Cell Candidate**: $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$ -4. **Cell State Update**: $C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$ -5. **Output Gate**: $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$ -6. **Hidden State**: $h_t = o_t * \tanh(C_t)$ - -## Implementation Details -The `LSTM` layer in `neutro` performs the full forward and backward pass (Backpropagation Through Time, BPTT) over the sequence dimension. We concatenate the weights for the four gates into a single matrix to optimize the dot product operations. - -## Citations -- Hochreiter, S., & Schmidhuber, J. (1997). **Long Short-Term Memory**. *Neural Computation*. [DOI: 10.1162/neco.1997.9.8.1735](https://direct.mit.edu/neco/article/9/8/1735/6109/Long-Short-Term-Memory) -- [Original Paper PDF (Bioinf JKU)](https://www.bioinf.jku.at/publications/older/2604.pdf) +## What does this layer do? + +LSTM (Long Short-Term Memory) is a recurrent layer designed to solve the **vanishing gradient problem** that plagues simple RNNs. It introduces a **cell state** — a "memory highway" that can carry information across many timesteps with minimal gradient decay — controlled by four learned gates. + +### The math + +At each timestep $t$, with input $x_t$, previous hidden state $h_{t-1}$, and previous cell state $C_{t-1}$: + +$$f_t = \sigma(x_t W_f + h_{t-1} U_f + b_f) \quad \text{(forget gate)}$$ +$$i_t = \sigma(x_t W_i + h_{t-1} U_i + b_i) \quad \text{(input gate)}$$ +$$\tilde{C}_t = \tanh(x_t W_C + h_{t-1} U_C + b_C) \quad \text{(candidate cell state)}$$ +$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad \text{(cell state update)}$$ +$$o_t = \sigma(x_t W_o + h_{t-1} U_o + b_o) \quad \text{(output gate)}$$ +$$h_t = o_t \odot \tanh(C_t) \quad \text{(hidden state)}$$ + +Let's unpack every symbol: + +- **$f_t$** — Forget gate. Values in `(0, 1)`. Controls how much of the *previous* cell state $C_{t-1}$ to keep. A value near 1 means "remember everything"; near 0 means "forget everything." +- **$i_t$** — Input gate. Values in `(0, 1)`. Controls how much of the *candidate* $\tilde{C}_t$ to add to the cell state. +- **$\tilde{C}_t$** — Candidate cell state. A proposed update to the cell state, computed like a simple RNN hidden state. +- **$C_t$** — Cell state. The "memory highway." Updated as: forget part of the old state, then add new information. +- **$o_t$** — Output gate. Values in `(0, 1)`. Controls how much of the cell state to expose as the hidden state $h_t$. +- **$h_t$** — Hidden state. The output at this timestep. A gated view of the cell state. + +### Efficacy trick: one matrix to rule all gates + +```python +self.params['W'] = init((self.features + self.units, 4 * self.units)) +``` + +Instead of 8 separate weight matrices (W_f, W_i, W_C, W_o, U_f, U_i, U_C, U_o), LSTM concatenates the input and hidden state, then uses ONE matrix multiply with a `4 * units`-wide matrix: + +$$ +\text{concat}(x_t, h_{t-1}) \cdot W +$$ + +Then split the result into four equal chunks for i, f, c_tilde, o: + +``` +z → [ i | f | c_tilde | o ] + :units :2U :3U :4U +``` + +This is purely an optimization — one big matrix multiply is faster than eight small ones — but mathematically it's equivalent. + +--- + +## Walking through the code + +### `__init__` / `build` + +```python +class LSTM(Layer): + def __init__(self, units, return_sequences=False): + super().__init__() + self.units = units + self.return_sequences = return_sequences + + def build(self, input_shape): + self.features = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['W'] = init((self.features + self.units, 4 * self.units)) + self.params['b'] = get_initializer('zeros')((4 * self.units,)) + super().build(input_shape) +``` + +🔍 **`W` shape `(features + units, 4 * units)`** — This is the concatenated input-to-hidden AND hidden-to-hidden weight matrix. + +📐 If `features = 32` and `units = 64`, then `W` is `(96, 256)`: + +``` +W = [W_input | W_hidden] (first features rows from x, next units rows from h) + ↑ ↑ + shape (32, 256) shape (64, 256) +``` + +Each of the four gates gets `units` columns: columns `0:64` → i-gate, `64:128` → f-gate, `128:192` → c_tilde, `192:256` → o-gate. + +🔍 **`b` shape `(4 * units,)`** — One bias per gate. Split the same way as the columns. + +Why no separate `U` matrix here? Because we concatenate `x_t` and `h_{t-1}` first, then multiply by the single `W`. The top `features` rows serve as the input weight, the bottom `units` rows serve as the recurrent weight. + +### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + batch, timesteps, _ = inputs.shape + self.h_states = np.zeros((batch, timesteps + 1, self.units)) + self.c_states = np.zeros((batch, timesteps + 1, self.units)) + self.gates = np.zeros((batch, timesteps, 4 * self.units)) + + for t in range(timesteps): + concat = np.concatenate([inputs[:, t, :], self.h_states[:, t, :]], axis=1) + z = np.dot(concat, self.params['W']) + self.params['b'] + self.gates[:, t, :] = z + i, f, c_tilde, o = self._sigmoid(z[:, :self.units]), \ + self._sigmoid(z[:, self.units:2*self.units]), \ + np.tanh(z[:, 2*self.units:3*self.units]), \ + self._sigmoid(z[:, 3*self.units:]) + self.c_states[:, t+1, :] = f * self.c_states[:, t, :] + i * c_tilde + self.h_states[:, t+1, :] = o * np.tanh(self.c_states[:, t+1, :]) + return self.h_states[:, 1:, :] if self.return_sequences else self.h_states[:, -1, :] +``` + +🔍 **Line 29**: `self.h_states = np.zeros((batch, timesteps + 1, self.units))` — Hidden states, same `T+1` pattern as SimpleRNN: index `0` is `h_0 = 0`. + +🔍 **Line 30**: `self.c_states = np.zeros((batch, timesteps + 1, self.units))` — Cell states, also `T+1`. Index `0` is `C_0 = 0`. **Both** `h_states` and `c_states` are cached for backward. + +🔍 **Line 31**: `self.gates = np.zeros((batch, timesteps, 4 * self.units))` — All four gate pre-activation values for every timestep. Cached for backward. + +--- + +**Inside the loop (t = 0 to T-1):** + +🔍 **Line 34**: `concat = np.concatenate([inputs[:, t, :], h_states[:, t, :]], axis=1)` + +📐 `inputs[:, t, :]` is `(batch, features)`, `h_states[:, t, :]` is `(batch, units)`. Concatenated: `(batch, features + units)`. + +🔍 **Line 35**: `z = np.dot(concat, W) + b` + +📐 `(batch, features+units) @ (features+units, 4*units)` → `(batch, 4*units)`. + +🔍 **Line 36**: `self.gates[:, t, :] = z` — Save the raw pre-activation for backward. + +🔍 **Line 37**: Split `z` into four gates: + +```python +i = sigmoid(z[:, :units]) # input gate +f = sigmoid(z[:, units:2*units]) # forget gate +c_tilde = tanh(z[:, 2*units:3*units]) # candidate +o = sigmoid(z[:, 3*units:]) # output gate +``` + +Note: `i` and `f` use sigmoid (values 0 to 1), `c_tilde` uses tanh (values -1 to 1). + +🔍 **Line 38**: `C_t = f * C_{t-1} + i * c_tilde` + +📐 All four are shape `(batch, units)`. Element-wise operations. + +This is the **cell state update**. The forget gate `f` decides how much of the old cell state to keep. The input gate `i` decides how much of the candidate to add. This is the "memory highway" — information can flow through unchanged when `f = 1` and `i = 0`. + +🔍 **Line 39**: `h_t = o * tanh(C_t)` + +📐 `(batch, units)`. The output gate controls how much of the cell state is exposed as the hidden state. + +🔍 **Line 40**: Return all hidden states or just the last one, depending on `return_sequences`. + +📐 Why cache ALL of `h_states`, `c_states`, and `gates`? Because `backward` needs: +- `h_states[:, t, :]` and `h_states[:, t+1, :]` for every `t` +- `c_states[:, t, :]` and `c_states[:, t+1, :]` for every `t` +- `gates[:, t, :]` (the pre-activation z) for every `t` to recompute the gate values + +Without caching, backward would need to re-run the entire forward pass. + +### `backward` (BPTT) + +```python +def backward(self, grad_output): + batch, timesteps, _ = self.inputs.shape + d_W, d_b, d_inputs = np.zeros_like(self.params['W']), \ + np.zeros_like(self.params['b']), \ + np.zeros_like(self.inputs) + dh_next, dc_next = np.zeros((batch, self.units)), \ + np.zeros((batch, self.units)) + + for t in range(timesteps - 1, -1, -1): + dh = (grad_output[:, t, :] if self.return_sequences + else (grad_output if t == timesteps - 1 else 0)) + dh_next + z = self.gates[:, t, :] + i, f, c_tilde, o = self._sigmoid(z[:, :self.units]), \ + self._sigmoid(z[:, self.units:2*self.units]), \ + np.tanh(z[:, 2*self.units:3*self.units]), \ + self._sigmoid(z[:, 3*self.units:]) + tanh_c = np.tanh(self.c_states[:, t+1, :]) + do, dc = dh * tanh_c, dh * o * (1 - tanh_c**2) + dc_next + df, di, dc_tilde = dc * self.c_states[:, t, :], dc * c_tilde, dc * i + dz = np.concatenate([di * i * (1 - i), + df * f * (1 - f), + dc_tilde * (1 - c_tilde**2), + do * o * (1 - o)], axis=1) + concat = np.concatenate([self.inputs[:, t, :], + self.h_states[:, t, :]], axis=1) + d_W += np.dot(concat.T, dz) + d_b += np.sum(dz, axis=0) + d_concat = np.dot(dz, self.params['W'].T) + d_inputs[:, t, :], dh_next, dc_next = \ + d_concat[:, :self.features], d_concat[:, self.features:], f * dc + self.grads['W'], self.grads['b'] = d_W, d_b + return d_inputs +``` + +🧠 **"The backward loop goes in REVERSE order of the forward loop — that's the 'through time' part of BPTT"** + +🔍 **Line 47**: `for t in range(timesteps - 1, -1, -1)` — From `T-1` down to `0`. + +🔍 **Line 48**: `dh = grad_output + dh_next` — Same two-source gradient as SimpleRNN and GRU: from the layer above AND from the future timestep. + +--- + +**Step 1: Compute `dc` — gradient w.r.t. the cell state** + +$$h_t = o_t \cdot \tanh(C_t)$$ + +🔍 **Line 51-52**: Backprop through `h_t` to get the gradient for the output gate and the cell state: + +```python +tanh_c = tanh(C_{t+1}) +do = dh * tanh_c # gradient w.r.t. output gate o +dc = dh * o * (1 - tanh_c**2) + dc_next +``` + +- `do = dh * tanh(C_t)` — By the product rule of `o * tanh(C_t)`. +- `dc` has **two sources**: + 1. `dh * o * (1 - tanh_c**2)` — The gradient through `tanh(C_t)` in the hidden state computation. The `(1 - tanh_c**2)` is the derivative of `tanh`. + 2. `dc_next` — The gradient from the *next* timestep's cell state (passed backward through `C_{t+1} = f * C_t + ...`). + +📐 Both `(batch, units)`. + +--- + +**Step 2: Backprop through the cell state update** + +$$C_t = f \cdot C_{t-1} + i \cdot \tilde{C}_t$$ + +🔍 **Line 53**: Split `dc` into gradients for each gate: + +```python +df = dc * C_{t-1} # gradient w.r.t. forget gate f +di = dc * c_tilde # gradient w.r.t. input gate i +dc_tilde = dc * i # gradient w.r.t. candidate c_tilde +``` + +These are direct applications of the product rule. Each one is the gradient `dc` multiplied by the *other* operand in the sum. + +🔍 Note how `dc_next` is computed at the end (line 59): `dc_next = f * dc`. This is the gradient of `C_t` w.r.t. `C_{t-1}` — the term `f` in `f * C_{t-1}` — which flows backward to the previous timestep. + +--- + +**Step 3: Backprop through the activation functions** + +Each gate has a different activation: + +🔍 **Line 54**: `dz` is a concatenation of four gradient pieces, each multiplied by the derivative of their respective activation: + +```python +dz = [ + di * i * (1 - i), # sigmoid derivative for input gate + df * f * (1 - f), # sigmoid derivative for forget gate + dc_tilde * (1 - c_tilde**2), # tanh derivative for candidate + do * o * (1 - o) # sigmoid derivative for output gate + ] +``` + +Concatenated: `(batch, 4*units)` — same shape as the original `z`. + +--- + +**Step 4: Accumulate weight gradients** + +🔍 **Line 55**: `concat = concat([inputs[:, t, :], h_states[:, t, :]])` — Same concatenation as forward. + +📐 `(batch, features + units)`. + +🔍 **Line 56**: `d_W += concat.T @ dz` + +📐 `(features+units, batch) @ (batch, 4*units)` → `(features+units, 4*units)` = shape of `W`. + +🔍 **Line 57**: `d_b += sum(dz, axis=0)` → `(4*units,)`. + +--- + +**Step 5: Compute gradient w.r.t. inputs and previous hidden state** + +🔍 **Line 58**: `d_concat = dz @ W.T` + +📐 `(batch, 4*units) @ (4*units, features+units)` → `(batch, features+units)`. + +🔍 **Line 59**: Split `d_concat` back into input gradient and hidden state gradient: + +```python +d_inputs[:, t, :] = d_concat[:, :features] # gradient for the previous layer +dh_next = d_concat[:, features:] # gradient for previous timestep's hidden state +dc_next = f * dc # gradient for previous timestep's cell state +``` + +📐 `d_inputs[:, t, :]` is `(batch, features)`, `dh_next` is `(batch, units)`, `dc_next` is `(batch, units)`. + +--- + +### Why LSTM solves vanishing gradients + +The cell state update is a **linear highway**: + +$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$ + +When the forget gate `f` is close to 1 and the input gate `i` is close to 0, the cell state flows through **unchanged**: + +$$C_t \approx C_{t-1}$$ + +The gradient of this path is: + +$$\frac{\partial C_t}{\partial C_{t-1}} = f_t$$ + +No repeated tanh or sigmoid squashing that shrinks gradients (like SimpleRNN's `(1 - h²)` factor). The gradient can flow backward through many timesteps without vanishing — that's the core insight of LSTMs. + +In contrast, SimpleRNN's hidden state goes through `tanh` at every step: + +$$h_t = \tanh(x_t W_x + h_{t-1} W_h + b)$$ + +$$\frac{\partial h_t}{\partial h_{t-1}} = (1 - h_t^2) \cdot W_h$$ + +The `(1 - h_t^2)` factor is always ≤ 1, and multiplied across many timesteps, it drives gradients to zero. + +--- + +## Try it yourself + +```python +from neutro.layers.recurrent.lstm import LSTM +import numpy as np + +# Create LSTM layer +lstm = LSTM(units=64, return_sequences=True) + +# Input: batch of 4 sequences, each 10 timesteps, 32 features +x = np.random.randn(4, 10, 32) + +# Forward pass +y = lstm(x) +print(f"Output shape: {y.shape}") # (4, 10, 64) + +# Backward pass +dL_dy = np.random.randn(4, 10, 64) +dL_dx = lstm.backward(dL_dy) +print(f"Input grad shape: {dL_dx.shape}") # (4, 10, 32) +print(f"W grad shape: {lstm.grads['W'].shape}") # (96, 256) +print(f"b grad shape: {lstm.grads['b'].shape}") # (256,) + +# With return_sequences=False +lstm_last = LSTM(units=64, return_sequences=False) +y_last = lstm_last(x) +print(f"Last-only output: {y_last.shape}") # (4, 64) +``` + +## What to read next + +- **`docs/layers/recurrent/recurrent.md`** — SimpleRNN and GRU: simpler recurrent architectures with walkthroughs of BPTT and gate mechanisms. +- **`docs/layers/attention/mha.md`** — Multi-Head Attention: the alternative to recurrence for sequence modeling (used in Transformers). +- **`docs/layers/core/dense.md`** — For a refresher on how `build`, `forward`, and `backward` work in the simplest layer. diff --git a/docs/layers/recurrent/recurrent.md b/docs/layers/recurrent/recurrent.md index 6bc1c96..57f045e 100644 --- a/docs/layers/recurrent/recurrent.md +++ b/docs/layers/recurrent/recurrent.md @@ -1,84 +1,946 @@ -# Recurrent Layers: SimpleRNN, LSTM, GRU +# SimpleRNN and GRU -## Theory +## SimpleRNN -Recurrent Neural Networks process sequences by maintaining a hidden state that is updated at each time step. The key challenge is the **vanishing gradient problem** — gradients diminish exponentially over long sequences. +### What does this layer do? -### SimpleRNN — `neutro/layers/recurrent/simple_rnn.py` +An RNN processes sequences one step at a time, maintaining a hidden state that carries information forward. At each timestep, it combines the current input with the previous hidden state through a learned transformation. -$$h_t = \tanh(W_h \cdot h_{t-1} + W_x \cdot x_t + b)$$ +### The math -Simple RNN suffers from vanishing gradients and cannot capture long-range dependencies. +$$h_t = \tanh(x_t W_x + h_{t-1} W_h + b)$$ -### LSTM — `neutro/layers/recurrent/lstm.py` +Let's unpack every symbol: -Long Short-Term Memory introduces a gating mechanism with a cell state: +- **$x_t$** — The input at timestep `t`. Shape `(batch, features)`. Each timestep's slice of the input sequence. +- **$W_x$** — Input weight matrix. Shape `(features, units)`. Controls how the current input influences the new hidden state. +- **$h_{t-1}$** — The hidden state from the *previous* timestep. Shape `(batch, units)`. This is the "memory" that carries information across timesteps. +- **$W_h$** — Recurrent weight matrix. Shape `(units, units)`. Controls how the previous hidden state influences the new one. +- **$b$** — Bias vector. Shape `(units,)`. +- **$\tanh$** — Hyperbolic tangent activation, squashing values into `(-1, 1)`. +- **$h_t$** — The new hidden state. Also the output at this timestep (if `return_sequences=True`). -$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \quad \text{(forget gate)}$$ -$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \quad \text{(input gate)}$$ -$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \quad \text{(candidate)}$$ -$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad \text{(cell update)}$$ -$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \quad \text{(output gate)}$$ -$$h_t = o_t \odot \tanh(C_t) \quad \text{(hidden state)}$$ +### Walking through the code -The cell state $C_t$ can carry information over long distances with minimal gradient decay. +#### `__init__` / `build` -### GRU — `neutro/layers/recurrent/gru.py` +```python +def __init__(self, units, activation='tanh', return_sequences=False, **kwargs): + super().__init__(**kwargs) + self.units = units + self.return_sequences = return_sequences + self.activation_name = activation + +def build(self, input_shape): + self.features = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['Wx'] = init((self.features, self.units)) + self.params['Wh'] = init((self.units, self.units)) + self.params['b'] = get_initializer('zeros')((self.units,)) + super().build(input_shape) +``` -Gated Recurrent Unit simplifies LSTM by merging the cell state and hidden state: +🔍 **`features = input_shape[-1]`** — We grab the last dimension of the input. If input is `(batch, timesteps, features)`, this is the feature dimension. The first two (batch and timesteps) are handled dynamically. -$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) \quad \text{(update gate)}$$ -$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) \quad \text{(reset gate)}$$ -$$\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t])$$ -$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$ +🔍 **`Wx` shape `(features, units)`** — Maps `features` input dimensions → `units` hidden dimensions. If input is 128-dimensional and we want 64 hidden units, this is `(128, 64)`. -GRU has fewer parameters than LSTM and often performs comparably. +🔍 **`Wh` shape `(units, units)`** — Maps hidden state → hidden state. A square matrix: `(64, 64)`. This is what makes RNNs "recurrent" — the same hidden-to-hidden weight is applied at every timestep. -## Implementation Guide +🔍 **`b` shape `(units,)`** — One bias per hidden unit. Broadcast across batch. -All recurrent layers are in `neutro/layers/recurrent/`. They share a common pattern: +#### `forward` ```python def forward(self, inputs, training=False): - batch_size, seq_len, input_dim = inputs.shape - # Initialize hidden state - h = np.zeros((batch_size, self.units)) - self.h_states = [] - for t in range(seq_len): - x_t = inputs[:, t, :] - h = self._step(x_t, h) # One RNN step - self.h_states.append(h) - return np.stack(self.h_states, axis=1) + self.inputs = inputs + batch, timesteps, _ = inputs.shape + self.h_states = np.zeros((batch, timesteps + 1, self.units)) + + for t in range(timesteps): + z = np.dot(inputs[:, t, :], self.params['Wx']) + \ + np.dot(self.h_states[:, t, :], self.params['Wh']) + self.params['b'] + if self.activation_name == 'tanh': + self.h_states[:, t+1, :] = np.tanh(z) + else: + self.h_states[:, t+1, :] = z + if self.return_sequences: + return self.h_states[:, 1:, :] + return self.h_states[:, -1, :] ``` -The backward pass (Backpropagation Through Time, BPTT) reverses the loop: +🔍 **Line 27**: `self.inputs = inputs` — Cached for backward. `backward` doesn't receive the original inputs — it only gets `grad_output` — so we must save `inputs` now. + +🔍 **Line 28**: `self.h_states = np.zeros((batch, timesteps + 1, self.units))` — We allocate space for **all** hidden states, one per timestep *plus one extra* for the initial `h_0 = 0`. Shape: `(batch, T+1, units)`. + +Why `T+1`? So `h_states[:, t, :]` and `h_states[:, t+1, :]` both index validly in the loop, where `t` goes `0` to `T-1`. Entry `0` is `h_0` (all zeros), entry `1` is `h_1`, ..., entry `T` is `h_T`. + +🔍 **Line 31**: The core computation: + +📐 `inputs[:, t, :]` is `(batch, features)` @ `Wx` `(features, units)` → `(batch, units)` + +📐 `h_states[:, t, :]` is `(batch, units)` @ `Wh` `(units, units)` → `(batch, units)` + +📐 Adding them plus bias (broadcast) gives `z` of shape `(batch, units)`. + +🔍 **Line 32-35**: `np.tanh(z)` applies the squashing non-linearity. The result is stored at position `t+1` in `h_states`, making it `h_t`. + +🔍 **Lines 36-38**: If `return_sequences=True`, return all hidden states `h_1` through `h_T` — shape `(batch, T, units)`. If `False`, return only the last hidden state `h_T` — shape `(batch, units)`. + +📐 **Output shapes**: +- `return_sequences=True`: `(batch, timesteps, units)` +- `return_sequences=False`: `(batch, units)` + +🔍 Why do we cache ALL hidden states (not just the last one)? Because `backward` needs `h_t` and `h_{t+1}` for every `t` to compute gradients. Without caching, we'd have to re-run the forward loop. + +#### `backward` (BPTT — Backpropagation Through Time) ```python def backward(self, grad_output): - for t in reversed(range(self.seq_len)): - grad_h = grad_output[:, t, :] + grad_h_next - # Backprop through one step - ... - grad_h_next = grad_from_h - return grad_x + batch, timesteps, _ = self.inputs.shape + d_Wx, d_Wh, d_b = np.zeros_like(self.params['Wx']), \ + np.zeros_like(self.params['Wh']), \ + np.zeros_like(self.params['b']) + d_inputs = np.zeros_like(self.inputs) + dh_next = np.zeros((batch, self.units)) + + for t in range(timesteps - 1, -1, -1): + dh = (grad_output[:, t, :] if self.return_sequences \ + else (grad_output if t == timesteps - 1 else 0)) + dh_next + dz = dh * (1 - self.h_states[:, t+1, :]**2) + d_Wx += np.dot(self.inputs[:, t, :].T, dz) + d_Wh += np.dot(self.h_states[:, t, :].T, dz) + d_b += np.sum(dz, axis=0) + d_inputs[:, t, :] = np.dot(dz, self.params['Wx'].T) + dh_next = np.dot(dz, self.params['Wh'].T) + + self.grads['Wx'], self.grads['Wh'], self.grads['b'] = d_Wx, d_Wh, d_b + return d_inputs +``` + +🧠 **"The backward loop goes in REVERSE order of the forward loop — that's the 'through time' part of BPTT"** + +🔍 **Line 46**: `for t in range(timesteps - 1, -1, -1)` — Loop from `T-1` down to `0`. The forward loop went `0, 1, 2, ..., T-1`. The backward loop goes `T-1, ..., 2, 1, 0`. + +🔍 **Line 47**: `dh = grad_output + dh_next` — The gradient arriving at this timestep has **two sources**: + +1. **From above**: the upstream gradient `grad_output` from the loss (or next layer). If `return_sequences=True`, each timestep `t` gets its own slice `grad_output[:, t, :]`. If `return_sequences=False`, only the last timestep gets the gradient; all others get a zero contribution from this source. + +2. **From the future**: `dh_next` — the gradient flowing *back* from timestep `t+1`. This is computed on the **previous iteration** of the backward loop (which was timestep `t+1`, since we're going in reverse). + +This two-source pattern is the essence of BPTT — future timesteps send gradient information backward through `Wh`. + +🔍 **Line 48**: `dz = dh * (1 - h_{t+1}^2)` — This is the derivative of $\tanh$: $\frac{d}{dz}\tanh(z) = 1 - \tanh(z)^2 = 1 - h_{t+1}^2$. + +We multiply `dh` by this derivative (element-wise) to backpropagate through the tanh activation. This is the chain rule: $\frac{\partial L}{\partial z} = \frac{\partial L}{\partial h} \cdot \frac{\partial h}{\partial z}$. + +🔍 **Lines 49-51**: Accumulate gradients for each parameter: + +- `d_Wx += x_t^T @ dz` — Shape: `(features, batch) @ (batch, units)` → `(features, units)` = `Wx` shape. +- `d_Wh += h_t^T @ dz` — Shape: `(units, batch) @ (batch, units)` → `(units, units)` = `Wh` shape. +- `d_b += sum(dz, axis=0)` — Sum over batch → `(units,)` = `b` shape. + +These are **accumulated** (with `+=`) across all timesteps because the same `Wx`, `Wh`, and `b` are used at every timestep — the total gradient is the sum of the gradients from each timestep. + +🔍 **Line 52**: `d_inputs[:, t, :] = dz @ Wx.T` — The gradient w.r.t. the input at this timestep. Shape: `(batch, units) @ (units, features)` → `(batch, features)`. + +🔍 **Line 53**: `dh_next = dz @ Wh.T` — The gradient to pass to the previous timestep. Shape: `(batch, units) @ (units, units)` → `(batch, units)`. This becomes `dh_next` in the `t-1` iteration. + +📐 Gradient shapes flowing backward: + +``` +t = T-1: + grad_output → dh (batch, units) → dz (batch, units) + → d_Wx (features, units) + → d_Wh (units, units) + → d_b (units,) + → d_inputs[:, T-1, :] (batch, features) + → dh_next (batch, units) + ↓ +t = T-2: dh_next passed to previous step + grad_output + dh_next → dh → dz → ... same pattern +``` + +--- + +## GRU + +### What does this layer do? + +A Gated Recurrent Unit (GRU) is a simplified version of LSTM that merges the cell state and hidden state. It uses two gates — an **update gate** (how much to keep vs. replace) and a **reset gate** (how much to forget the past) — to control information flow. + +### The math + +$$z_t = \sigma(x_t W_z + h_{t-1} U_z) \quad \text{(update gate)}$$ + +$$r_t = \sigma(x_t W_r + h_{t-1} U_r) \quad \text{(reset gate)}$$ + +$$\tilde{h}_t = \tanh(x_t W_h + (r_t \odot h_{t-1}) U_h) \quad \text{(candidate hidden state)}$$ + +$$h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1} \quad \text{(final hidden state)}$$ + +Let's unpack every symbol: + +- **$z_t$** — Update gate. Values in `(0, 1)`. A value near 1 means "keep the old state"; near 0 means "replace with the candidate." Note: the naming convention is sometimes flipped — here $z$ controls how much of the *old* state to keep, so $(1-z)$ controls how much of the *new* candidate to take. +- **$r_t$** — Reset gate. Values in `(0, 1)`. Controls how much of the past hidden state to forget when computing the candidate. +- **$\tilde{h}_t$** — Candidate hidden state. Like a regular RNN's hidden state, but modulated by the reset gate. +- **$W_z, W_r, W_h$** — Input weight matrices. Each has shape `(features, units)`. They are stacked into one big matrix `W` of shape `(features, 3 * units)` for efficiency. +- **$U_z, U_r, U_h$** — Recurrent weight matrices. Each has shape `(units, units)`. Stacked into `U` of shape `(units, 3 * units)`. +- **$h_t$** — Final hidden state. An interpolation between the old state and the candidate, controlled by `z`. + +### Walking through the code + +#### `__init__` / `build` + +```python +def __init__(self, units, return_sequences=False, **kwargs): + super().__init__(**kwargs) + self.units = units + self.return_sequences = return_sequences + +def build(self, input_shape): + self.features = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['W'] = init((self.features, 3 * self.units)) + self.params['U'] = init((self.units, 3 * self.units)) + self.params['b'] = get_initializer('zeros')((3 * self.units,)) + super().build(input_shape) +``` + +🔍 **Weight organization**: Instead of 6 separate matrices (Wz, Wr, Wh, Uz, Ur, Uh), the GRU concatenates them: + +``` +W = [Wz | Wr | Wh] shape: (features, 3*units) +U = [Uz | Ur | Uh] shape: (units, 3*units) +b = [bz | br | bh] shape: (3*units,) +``` + +The first `2*units` columns are for the update and reset gates (z, r). The last `units` columns are for the candidate hidden state (h_tilde). + +📐 **Why concatenate?** A single matrix multiply `x @ W` is faster than three separate ones. We split the result afterward. + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + batch, timesteps, _ = inputs.shape + self.h_states = np.zeros((batch, timesteps + 1, self.units)) + + self.z_gates = np.zeros((batch, timesteps, self.units)) + self.r_gates = np.zeros((batch, timesteps, self.units)) + self.h_tilde = np.zeros((batch, timesteps, self.units)) + self.x_W = np.dot(inputs, self.params['W']) # (batch, timesteps, 3*units) + + for t in range(timesteps): + x_W_t = self.x_W[:, t, :] + self.params['b'] + h_prev = self.h_states[:, t, :] + + z_r_hidden = np.dot(h_prev, self.params['U'][:, :2*self.units]) + z_r_logits = x_W_t[:, :2*self.units] + z_r_hidden + z_r = self._sigmoid(z_r_logits) + + z = z_r[:, :self.units] + r = z_r[:, self.units:] + + self.z_gates[:, t, :] = z + self.r_gates[:, t, :] = r + + h_tilde_hidden = np.dot(r * h_prev, self.params['U'][:, 2*self.units:]) + h_tilde_logits = x_W_t[:, 2*self.units:] + h_tilde_hidden + h_tilde = np.tanh(h_tilde_logits) + self.h_tilde[:, t, :] = h_tilde + + self.h_states[:, t+1, :] = (1 - z) * h_tilde + z * h_prev + + if self.return_sequences: + return self.h_states[:, 1:, :] + return self.h_states[:, -1, :] ``` -Weight concatenation optimization (LSTM, line 53): weights for all four gates are stored as a single matrix to optimize the dot product: `W = np.concatenate([W_f, W_i, W_C, W_o])`. +🔍 **Lines 47-49**: Pre-allocate storage for **all** intermediate values — `z_gates`, `r_gates`, `h_tilde` across all timesteps. These are cached for the backward pass. -## Usage Example +🔍 **Line 50**: `self.x_W = np.dot(inputs, self.params['W'])` — Compute the input-to-hidden projection for **all timesteps at once**. + +📐 `inputs` is `(batch, timesteps, features)` @ `W` is `(features, 3*units)` → `(batch, timesteps, 3*units)`. + +This is an optimization: one big matrix multiply instead of `T` small ones. At timestep `t`, we slice `self.x_W[:, t, :]`. + +🔍 **Lines 57-62**: Compute the update gate `z` and reset gate `r`: + +1. `z_r_hidden = h_prev @ U[:, :2*units]` — The recurrent part of the gates. +2. `z_r_logits = x_W_t[:, :2*units] + z_r_hidden` — Add the input part (from the pre-computed `x_W`). +3. `z_r = sigmoid(z_r_logits)` — Both gates share one sigmoid computation. +4. Split: `z = z_r[:, :units]`, `r = z_r[:, units:]`. + +🔍 **Lines 64-71**: Compute the candidate hidden state: + +1. `h_tilde_hidden = (r * h_prev) @ U[:, 2*units:]` — The reset gate element-wise multiplies the previous hidden state, zeroing out some dimensions. +2. `h_tilde_logits = x_W_t[:, 2*units:] + h_tilde_hidden` — Add the input part. +3. `h_tilde = tanh(h_tilde_logits)` — Squash to `(-1, 1)`. + +🔍 **Line 74**: `h_t = (1 - z) * h_tilde + z * h_prev` — The final hidden state is an **interpolation** between the old state and the candidate. When `z` is close to 1, we mostly keep the old state. When `z` is close to 0, we mostly take the new candidate. + +📐 **Why cache `z_gates`, `r_gates`, `h_tilde`, and `x_W`?** Because `backward` needs every gate value at every timestep, plus the input projection. Without caching them all, backward would have to recompute the entire forward pass. + +#### `backward` (BPTT) ```python -from neutro.layers import LSTM, GRU +def backward(self, grad_output): + batch, timesteps, _ = self.inputs.shape + d_W = np.zeros_like(self.params['W']) + d_U = np.zeros_like(self.params['U']) + d_b = np.zeros_like(self.params['b']) + d_inputs = np.zeros_like(self.inputs) + dh_next = np.zeros((batch, self.units)) -lstm = LSTM(units=128, return_sequences=True) -x = np.random.randn(4, 32, 64) # (batch, seq, features) -y = lstm(x) # (batch, seq, 128) + for t in range(timesteps - 1, -1, -1): + dh = (grad_output[:, t, :] if self.return_sequences + else (grad_output if t == timesteps - 1 else 0)) + dh_next + z = self.z_gates[:, t, :] + r = self.r_gates[:, t, :] + h_tilde = self.h_tilde[:, t, :] + h_prev = self.h_states[:, t, :] + + # dL/dh_t -> dL/dz, dL/dh_tilde, dL/dh_prev + dz = dh * (h_prev - h_tilde) + dh_tilde = dh * (1 - z) + dh_prev_from_h = dh * z + + # Backprop through tanh for h_tilde + dtanh = dh_tilde * (1 - h_tilde**2) + + # dL/dh_tilde -> dL/dW_h, dL/dU_h, dL/dr + d_W[:, 2*self.units:] += np.dot(self.inputs[:, t, :].T, dtanh) + d_U[:, 2*self.units:] += np.dot((r * h_prev).T, dtanh) + d_b[2*self.units:] += np.sum(dtanh, axis=0) + + dr_h_prev = np.dot(dtanh, self.params['U'][:, 2*self.units:].T) + dr = dr_h_prev * h_prev + dh_prev_from_tilde = dr_h_prev * r + + # Backprop through sigmoids for z, r + dz_logits = dz * z * (1 - z) + dr_logits = dr * r * (1 - r) + dzr_logits = np.concatenate([dz_logits, dr_logits], axis=1) + + # dL/dzr -> dL/dW_zr, dL/dU_zr + d_W[:, :2*self.units] += np.dot(self.inputs[:, t, :].T, dzr_logits) + d_U[:, :2*self.units] += np.dot(h_prev.T, dzr_logits) + d_b[:2*self.units] += np.sum(dzr_logits, axis=0) + + dh_prev_from_gates = np.dot(dzr_logits, self.params['U'][:, :2*self.units].T) + + # Total dh_prev for next step + dh_next = dh_prev_from_h + dh_prev_from_tilde + dh_prev_from_gates + + # Gradient wrt inputs + d_inputs[:, t, :] = \ + np.dot(dzr_logits, self.params['W'][:, :2*self.units].T) + \ + np.dot(dtanh, self.params['W'][:, 2*self.units:].T) + + self.grads['W'] = d_W + self.grads['U'] = d_U + self.grads['b'] = d_b + return d_inputs +``` + +🧠 **"The backward loop goes in REVERSE order of the forward loop — that's the 'through time' part of BPTT"** + +🔍 **Line 88**: `for t in range(timesteps - 1, -1, -1)` — Same reverse loop as SimpleRNN. + +🔍 **Line 89**: `dh = grad_output + dh_next` — Gradient from two sources (same as SimpleRNN). + +--- + +**Step 1: Split `dh` into contributions through the three paths of `h_t = (1-z)*h_tilde + z*h_prev`** + +The final hidden state formula is: + +$$h_t = (1 - z) \cdot \tilde{h}_t + z \cdot h_{t-1}$$ + +By the product rule, the gradient `dh` splits into three terms: + +🔍 **Line 97**: `dz = dh * (h_prev - h_tilde)` — Gradient through `z` in both `(1-z)*h_tilde` and `z*h_prev`: + +$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial h} \cdot (h_{t-1} - \tilde{h}_t)$$ + +When `h_prev > h_tilde`, increasing `z` (keeping more of old state) reduces loss, and vice versa. + +🔍 **Line 98**: `dh_tilde = dh * (1 - z)` — Gradient through `h_tilde` in the `(1-z)` path. + +🔍 **Line 99**: `dh_prev_from_h = dh * z` — Gradient through `h_prev` in the `z` path. + +📐 All three are shape `(batch, units)`. + +--- + +**Step 2: Backprop through the candidate `h_tilde = tanh(x*Wh + (r*h_prev)*Uh + bh)`** + +🔍 **Line 102**: `dtanh = dh_tilde * (1 - h_tilde**2)` — Derivative of tanh. + +📐 Shape `(batch, units)`. + +🔍 **Lines 106-108**: Gradient w.r.t. the **candidate** (last `units`) portion of the weights: + +- `d_W[:, 2*units:] += x_t.T @ dtanh` — `(features, batch) @ (batch, units)` → `(features, units)` = shape of `Wh`. +- `d_U[:, 2*units:] += (r * h_prev).T @ dtanh` — `(units, batch) @ (batch, units)` → `(units, units)` = shape of `Uh`. Note that the reset gate `r` element-wise multiplies `h_prev` before the matrix multiply. +- `d_b[2*units:] += sum(dtanh, axis=0)` — `(units,)`. + +--- + +**Step 3: Backprop through the reset gate `r = sigmoid(x*Wr + h_prev*Ur + br)`** + +🔍 **Line 110**: `dr_h_prev = dtanh @ Uh.T` — The gradient w.r.t. `(r * h_prev)`, before the element-wise multiply. + +📐 `(batch, units) @ (units, units)` → `(batch, units)`. + +🔍 **Line 111**: `dr = dr_h_prev * h_prev` — Gradient w.r.t. `r`. By the product rule of `r * h_prev`: + +$$\frac{\partial L}{\partial r} = \frac{\partial L}{\partial (r \cdot h_{prev})} \odot h_{prev}$$ + +🔍 **Line 112**: `dh_prev_from_tilde = dr_h_prev * r` — The gradient through `r * h_prev` w.r.t. `h_prev`. This is the **other half** of the product rule: if `r` is large, `h_prev` has more influence on the output. + +--- + +**Step 4: Backprop through the sigmoids for `z` and `r`** + +🔍 **Lines 115-116**: Sigmoid derivative: `dz_logits = dz * z * (1 - z)` and `dr_logits = dr * r * (1 - r)`. For sigmoid $\sigma(x)$, the derivative is $\sigma(x) \cdot (1 - \sigma(x))$. + +🔍 **Line 117**: `dzr_logits = concat([dz_logits, dr_logits])` — Concatenate back into `(batch, 2*units)` for the gate weight updates. + +--- + +**Step 5: Accumulate gradients for the gate weights** + +🔍 **Lines 120-122**: +- `d_W[:, :2*units] += x_t.T @ dzr_logits` — `(features, batch) @ (batch, 2*units)` → `(features, 2*units)` = shape of `[Wz | Wr]`. +- `d_U[:, :2*units] += h_prev.T @ dzr_logits` — `(units, batch) @ (batch, 2*units)` → `(units, 2*units)` = shape of `[Uz | Ur]`. +- `d_b[:2*units] += sum(dzr_logits, axis=0)` — `(2*units,)`. + +--- + +**Step 6: Compute `dh_next` — the gradient to pass to the previous timestep** + +🔍 **Line 124**: `dh_prev_from_gates = dzr_logits @ U[:, :2*units].T` — Gradient through the gate's recurrent connection. + +🔍 **Line 127**: `dh_next = dh_prev_from_h + dh_prev_from_tilde + dh_prev_from_gates` — The total gradient w.r.t. `h_prev` is the sum of three paths: + +1. `dh_prev_from_h` — from the `z * h_prev` direct connection +2. `dh_prev_from_tilde` — from the `(r * h_prev)` in the candidate computation +3. `dh_prev_from_gates` — from `h_prev @ U[:, :2*units]` in the gate logits + +--- + +**Step 7: Gradient w.r.t. inputs** + +🔍 **Lines 130-131**: `d_inputs[:, t, :]` has two contributions: + +1. `dzr_logits @ W[:, :2*units].T` — From the gates (Wz, Wr portions of W) +2. `dtanh @ W[:, 2*units:].T` — From the candidate (Wh portion of W) + +📐 Shape: `(batch, 2*units) @ (2*units, features)` and `(batch, units) @ (units, features)` → both `(batch, features)`, added together. + +--- + +## Try it yourself + +```python +from neutro.layers.recurrent import SimpleRNN, GRU +import numpy as np + +# SimpleRNN +rnn = SimpleRNN(units=64, return_sequences=True) +x = np.random.randn(4, 10, 32) # (batch, timesteps, features) +y = rnn(x) +print(f"RNN output: {y.shape}") # (4, 10, 64) + +# GRU gru = GRU(units=64, return_sequences=False) -z = gru(x) # (batch, 64) +z = gru(x) +print(f"GRU output: {z.shape}") # (4, 64) + +# Backward +dL_dy = np.random.randn(4, 64) +dL_dx = gru.backward(dL_dy) +print(f"GRU input grad: {dL_dx.shape}") # (4, 10, 32) +print(f"GRU W grad: {gru.grads['W'].shape}") # (32, 192) +print(f"GRU U grad: {gru.grads['U'].shape}") # (64, 192) +``` + +## What to read next + +- **`docs/layers/recurrent/lstm.md`** — The LSTM: four gates, a cell state, and why it solves the vanishing gradient problem. +- **`docs/layers/core/dense.md`** — If you need a refresher on how `build`, `forward`, and `backward` work in simpler layers. +## SimpleRNN + +### What does this layer do? + +An RNN processes sequences one step at a time, maintaining a hidden state that carries information forward. At each timestep, it combines the current input with the previous hidden state through a learned transformation. + +### The math + +$$h_t = \tanh(x_t W_x + h_{t-1} W_h + b)$$ + +Let's unpack every symbol: + +- **$x_t$** — The input at timestep `t`. Shape `(batch, features)`. Each timestep's slice of the input sequence. +- **$W_x$** — Input weight matrix. Shape `(features, units)`. Controls how the current input influences the new hidden state. +- **$h_{t-1}$** — The hidden state from the *previous* timestep. Shape `(batch, units)`. This is the "memory" that carries information across timesteps. +- **$W_h$** — Recurrent weight matrix. Shape `(units, units)`. Controls how the previous hidden state influences the new one. +- **$b$** — Bias vector. Shape `(units,)`. +- **$\tanh$** — Hyperbolic tangent activation, squashing values into `(-1, 1)`. +- **$h_t$** — The new hidden state. Also the output at this timestep (if `return_sequences=True`). + +### Walking through the code + +#### `__init__` / `build` + +```python +def __init__(self, units, activation='tanh', return_sequences=False, **kwargs): + super().__init__(**kwargs) + self.units = units + self.return_sequences = return_sequences + self.activation_name = activation + +def build(self, input_shape): + self.features = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['Wx'] = init((self.features, self.units)) + self.params['Wh'] = init((self.units, self.units)) + self.params['b'] = get_initializer('zeros')((self.units,)) + super().build(input_shape) +``` + +🔍 **`features = input_shape[-1]`** — We grab the last dimension of the input. If input is `(batch, timesteps, features)`, this is the feature dimension. The first two (batch and timesteps) are handled dynamically. + +🔍 **`Wx` shape `(features, units)`** — Maps `features` input dimensions → `units` hidden dimensions. If input is 128-dimensional and we want 64 hidden units, this is `(128, 64)`. + +🔍 **`Wh` shape `(units, units)`** — Maps hidden state → hidden state. A square matrix: `(64, 64)`. This is what makes RNNs "recurrent" — the same hidden-to-hidden weight is applied at every timestep. + +🔍 **`b` shape `(units,)`** — One bias per hidden unit. Broadcast across batch. + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + batch, timesteps, _ = inputs.shape + self.h_states = np.zeros((batch, timesteps + 1, self.units)) + + for t in range(timesteps): + z = np.dot(inputs[:, t, :], self.params['Wx']) + \ + np.dot(self.h_states[:, t, :], self.params['Wh']) + self.params['b'] + if self.activation_name == 'tanh': + self.h_states[:, t+1, :] = np.tanh(z) + else: + self.h_states[:, t+1, :] = z + if self.return_sequences: + return self.h_states[:, 1:, :] + return self.h_states[:, -1, :] +``` + +🔍 **Line 27**: `self.inputs = inputs` — Cached for backward. `backward` doesn't receive the original inputs — it only gets `grad_output` — so we must save `inputs` now. + +🔍 **Line 28**: `self.h_states = np.zeros((batch, timesteps + 1, self.units))` — We allocate space for **all** hidden states, one per timestep *plus one extra* for the initial `h_0 = 0`. Shape: `(batch, T+1, units)`. + +Why `T+1`? So `h_states[:, t, :]` and `h_states[:, t+1, :]` both index validly in the loop, where `t` goes `0` to `T-1`. Entry `0` is `h_0` (all zeros), entry `1` is `h_1`, ..., entry `T` is `h_T`. + +🔍 **Line 31**: The core computation: + +📐 `inputs[:, t, :]` is `(batch, features)` @ `Wx` `(features, units)` → `(batch, units)` + +📐 `h_states[:, t, :]` is `(batch, units)` @ `Wh` `(units, units)` → `(batch, units)` + +📐 Adding them plus bias (broadcast) gives `z` of shape `(batch, units)`. + +🔍 **Line 32-35**: `np.tanh(z)` applies the squashing non-linearity. The result is stored at position `t+1` in `h_states`, making it `h_t`. + +🔍 **Lines 36-38**: If `return_sequences=True`, return all hidden states `h_1` through `h_T` — shape `(batch, T, units)`. If `False`, return only the last hidden state `h_T` — shape `(batch, units)`. + +📐 **Output shapes**: +- `return_sequences=True`: `(batch, timesteps, units)` +- `return_sequences=False`: `(batch, units)` + +🔍 Why do we cache ALL hidden states (not just the last one)? Because `backward` needs `h_t` and `h_{t+1}` for every `t` to compute gradients. Without caching, we'd have to re-run the forward loop. + +#### `backward` (BPTT — Backpropagation Through Time) + +```python +def backward(self, grad_output): + batch, timesteps, _ = self.inputs.shape + d_Wx, d_Wh, d_b = np.zeros_like(self.params['Wx']), \ + np.zeros_like(self.params['Wh']), \ + np.zeros_like(self.params['b']) + d_inputs = np.zeros_like(self.inputs) + dh_next = np.zeros((batch, self.units)) + + for t in range(timesteps - 1, -1, -1): + dh = (grad_output[:, t, :] if self.return_sequences \ + else (grad_output if t == timesteps - 1 else 0)) + dh_next + dz = dh * (1 - self.h_states[:, t+1, :]**2) + d_Wx += np.dot(self.inputs[:, t, :].T, dz) + d_Wh += np.dot(self.h_states[:, t, :].T, dz) + d_b += np.sum(dz, axis=0) + d_inputs[:, t, :] = np.dot(dz, self.params['Wx'].T) + dh_next = np.dot(dz, self.params['Wh'].T) + + self.grads['Wx'], self.grads['Wh'], self.grads['b'] = d_Wx, d_Wh, d_b + return d_inputs +``` + +🧠 **"The backward loop goes in REVERSE order of the forward loop — that's the 'through time' part of BPTT"** + +🔍 **Line 46**: `for t in range(timesteps - 1, -1, -1)` — Loop from `T-1` down to `0`. The forward loop went `0, 1, 2, ..., T-1`. The backward loop goes `T-1, ..., 2, 1, 0`. + +🔍 **Line 47**: `dh = grad_output + dh_next` — The gradient arriving at this timestep has **two sources**: + +1. **From above**: the upstream gradient `grad_output` from the loss (or next layer). If `return_sequences=True`, each timestep `t` gets its own slice `grad_output[:, t, :]`. If `return_sequences=False`, only the last timestep gets the gradient; all others get a zero contribution from this source. + +2. **From the future**: `dh_next` — the gradient flowing *back* from timestep `t+1`. This is computed on the **previous iteration** of the backward loop (which was timestep `t+1`, since we're going in reverse). + +This two-source pattern is the essence of BPTT — future timesteps send gradient information backward through `Wh`. + +🔍 **Line 48**: `dz = dh * (1 - h_{t+1}^2)` — This is the derivative of $\tanh$: $\frac{d}{dz}\tanh(z) = 1 - \tanh(z)^2 = 1 - h_{t+1}^2$. + +We multiply `dh` by this derivative (element-wise) to backpropagate through the tanh activation. This is the chain rule: $\frac{\partial L}{\partial z} = \frac{\partial L}{\partial h} \cdot \frac{\partial h}{\partial z}$. + +🔍 **Lines 49-51**: Accumulate gradients for each parameter: + +- `d_Wx += x_t^T @ dz` — Shape: `(features, batch) @ (batch, units)` → `(features, units)` = `Wx` shape. +- `d_Wh += h_t^T @ dz` — Shape: `(units, batch) @ (batch, units)` → `(units, units)` = `Wh` shape. +- `d_b += sum(dz, axis=0)` — Sum over batch → `(units,)` = `b` shape. + +These are **accumulated** (with `+=`) across all timesteps because the same `Wx`, `Wh`, and `b` are used at every timestep — the total gradient is the sum of the gradients from each timestep. + +🔍 **Line 52**: `d_inputs[:, t, :] = dz @ Wx.T` — The gradient w.r.t. the input at this timestep. Shape: `(batch, units) @ (units, features)` → `(batch, features)`. + +🔍 **Line 53**: `dh_next = dz @ Wh.T` — The gradient to pass to the previous timestep. Shape: `(batch, units) @ (units, units)` → `(batch, units)`. This becomes `dh_next` in the `t-1` iteration. + +📐 Gradient shapes flowing backward: + +``` +t = T-1: + grad_output → dh (batch, units) → dz (batch, units) + → d_Wx (features, units) + → d_Wh (units, units) + → d_b (units,) + → d_inputs[:, T-1, :] (batch, features) + → dh_next (batch, units) + ↓ +t = T-2: dh_next passed to previous step + grad_output + dh_next → dh → dz → ... same pattern +``` + +--- + +## GRU + +### What does this layer do? + +A Gated Recurrent Unit (GRU) is a simplified version of LSTM that merges the cell state and hidden state. It uses two gates — an **update gate** (how much to keep vs. replace) and a **reset gate** (how much to forget the past) — to control information flow. + +### The math + +$$z_t = \sigma(x_t W_z + h_{t-1} U_z) \quad \text{(update gate)}$$ + +$$r_t = \sigma(x_t W_r + h_{t-1} U_r) \quad \text{(reset gate)}$$ + +$$\tilde{h}_t = \tanh(x_t W_h + (r_t \odot h_{t-1}) U_h) \quad \text{(candidate hidden state)}$$ + +$$h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1} \quad \text{(final hidden state)}$$ + +Let's unpack every symbol: + +- **$z_t$** — Update gate. Values in `(0, 1)`. A value near 1 means "keep the old state"; near 0 means "replace with the candidate." Note: the naming convention is sometimes flipped — here $z$ controls how much of the *old* state to keep, so $(1-z)$ controls how much of the *new* candidate to take. +- **$r_t$** — Reset gate. Values in `(0, 1)`. Controls how much of the past hidden state to forget when computing the candidate. +- **$\tilde{h}_t$** — Candidate hidden state. Like a regular RNN's hidden state, but modulated by the reset gate. +- **$W_z, W_r, W_h$** — Input weight matrices. Each has shape `(features, units)`. They are stacked into one big matrix `W` of shape `(features, 3 * units)` for efficiency. +- **$U_z, U_r, U_h$** — Recurrent weight matrices. Each has shape `(units, units)`. Stacked into `U` of shape `(units, 3 * units)`. +- **$h_t$** — Final hidden state. An interpolation between the old state and the candidate, controlled by `z`. + +### Walking through the code + +#### `__init__` / `build` + +```python +def __init__(self, units, return_sequences=False, **kwargs): + super().__init__(**kwargs) + self.units = units + self.return_sequences = return_sequences + +def build(self, input_shape): + self.features = input_shape[-1] + init = get_initializer('glorot_uniform') + self.params['W'] = init((self.features, 3 * self.units)) + self.params['U'] = init((self.units, 3 * self.units)) + self.params['b'] = get_initializer('zeros')((3 * self.units,)) + super().build(input_shape) +``` + +🔍 **Weight organization**: Instead of 6 separate matrices (Wz, Wr, Wh, Uz, Ur, Uh), the GRU concatenates them: + +``` +W = [Wz | Wr | Wh] shape: (features, 3*units) +U = [Uz | Ur | Uh] shape: (units, 3*units) +b = [bz | br | bh] shape: (3*units,) +``` + +The first `2*units` columns are for the update and reset gates (z, r). The last `units` columns are for the candidate hidden state (h_tilde). + +📐 **Why concatenate?** A single matrix multiply `x @ W` is faster than three separate ones. We split the result afterward. + +#### `forward` + +```python +def forward(self, inputs, training=False): + self.inputs = inputs + batch, timesteps, _ = inputs.shape + self.h_states = np.zeros((batch, timesteps + 1, self.units)) + + self.z_gates = np.zeros((batch, timesteps, self.units)) + self.r_gates = np.zeros((batch, timesteps, self.units)) + self.h_tilde = np.zeros((batch, timesteps, self.units)) + self.x_W = np.dot(inputs, self.params['W']) # (batch, timesteps, 3*units) + + for t in range(timesteps): + x_W_t = self.x_W[:, t, :] + self.params['b'] + h_prev = self.h_states[:, t, :] + + z_r_hidden = np.dot(h_prev, self.params['U'][:, :2*self.units]) + z_r_logits = x_W_t[:, :2*self.units] + z_r_hidden + z_r = self._sigmoid(z_r_logits) + + z = z_r[:, :self.units] + r = z_r[:, self.units:] + + self.z_gates[:, t, :] = z + self.r_gates[:, t, :] = r + + h_tilde_hidden = np.dot(r * h_prev, self.params['U'][:, 2*self.units:]) + h_tilde_logits = x_W_t[:, 2*self.units:] + h_tilde_hidden + h_tilde = np.tanh(h_tilde_logits) + self.h_tilde[:, t, :] = h_tilde + + self.h_states[:, t+1, :] = (1 - z) * h_tilde + z * h_prev + + if self.return_sequences: + return self.h_states[:, 1:, :] + return self.h_states[:, -1, :] +``` + +🔍 **Lines 47-49**: Pre-allocate storage for **all** intermediate values — `z_gates`, `r_gates`, `h_tilde` across all timesteps. These are cached for the backward pass. + +🔍 **Line 50**: `self.x_W = np.dot(inputs, self.params['W'])` — Compute the input-to-hidden projection for **all timesteps at once**. + +📐 `inputs` is `(batch, timesteps, features)` @ `W` is `(features, 3*units)` → `(batch, timesteps, 3*units)`. + +This is an optimization: one big matrix multiply instead of `T` small ones. At timestep `t`, we slice `self.x_W[:, t, :]`. + +🔍 **Lines 57-62**: Compute the update gate `z` and reset gate `r`: + +1. `z_r_hidden = h_prev @ U[:, :2*units]` — The recurrent part of the gates. +2. `z_r_logits = x_W_t[:, :2*units] + z_r_hidden` — Add the input part (from the pre-computed `x_W`). +3. `z_r = sigmoid(z_r_logits)` — Both gates share one sigmoid computation. +4. Split: `z = z_r[:, :units]`, `r = z_r[:, units:]`. + +🔍 **Lines 64-71**: Compute the candidate hidden state: + +1. `h_tilde_hidden = (r * h_prev) @ U[:, 2*units:]` — The reset gate element-wise multiplies the previous hidden state, zeroing out some dimensions. +2. `h_tilde_logits = x_W_t[:, 2*units:] + h_tilde_hidden` — Add the input part. +3. `h_tilde = tanh(h_tilde_logits)` — Squash to `(-1, 1)`. + +🔍 **Line 74**: `h_t = (1 - z) * h_tilde + z * h_prev` — The final hidden state is an **interpolation** between the old state and the candidate. When `z` is close to 1, we mostly keep the old state. When `z` is close to 0, we mostly take the new candidate. + +📐 **Why cache `z_gates`, `r_gates`, `h_tilde`, and `x_W`?** Because `backward` needs every gate value at every timestep, plus the input projection. Without caching them all, backward would have to recompute the entire forward pass. + +#### `backward` (BPTT) + +```python +def backward(self, grad_output): + batch, timesteps, _ = self.inputs.shape + d_W = np.zeros_like(self.params['W']) + d_U = np.zeros_like(self.params['U']) + d_b = np.zeros_like(self.params['b']) + d_inputs = np.zeros_like(self.inputs) + dh_next = np.zeros((batch, self.units)) + + for t in range(timesteps - 1, -1, -1): + dh = (grad_output[:, t, :] if self.return_sequences + else (grad_output if t == timesteps - 1 else 0)) + dh_next + + z = self.z_gates[:, t, :] + r = self.r_gates[:, t, :] + h_tilde = self.h_tilde[:, t, :] + h_prev = self.h_states[:, t, :] + + # dL/dh_t -> dL/dz, dL/dh_tilde, dL/dh_prev + dz = dh * (h_prev - h_tilde) + dh_tilde = dh * (1 - z) + dh_prev_from_h = dh * z + + # Backprop through tanh for h_tilde + dtanh = dh_tilde * (1 - h_tilde**2) + + # dL/dh_tilde -> dL/dW_h, dL/dU_h, dL/dr + d_W[:, 2*self.units:] += np.dot(self.inputs[:, t, :].T, dtanh) + d_U[:, 2*self.units:] += np.dot((r * h_prev).T, dtanh) + d_b[2*self.units:] += np.sum(dtanh, axis=0) + + dr_h_prev = np.dot(dtanh, self.params['U'][:, 2*self.units:].T) + dr = dr_h_prev * h_prev + dh_prev_from_tilde = dr_h_prev * r + + # Backprop through sigmoids for z, r + dz_logits = dz * z * (1 - z) + dr_logits = dr * r * (1 - r) + dzr_logits = np.concatenate([dz_logits, dr_logits], axis=1) + + # dL/dzr -> dL/dW_zr, dL/dU_zr + d_W[:, :2*self.units] += np.dot(self.inputs[:, t, :].T, dzr_logits) + d_U[:, :2*self.units] += np.dot(h_prev.T, dzr_logits) + d_b[:2*self.units] += np.sum(dzr_logits, axis=0) + + dh_prev_from_gates = np.dot(dzr_logits, self.params['U'][:, :2*self.units].T) + + # Total dh_prev for next step + dh_next = dh_prev_from_h + dh_prev_from_tilde + dh_prev_from_gates + + # Gradient wrt inputs + d_inputs[:, t, :] = \ + np.dot(dzr_logits, self.params['W'][:, :2*self.units].T) + \ + np.dot(dtanh, self.params['W'][:, 2*self.units:].T) + + self.grads['W'] = d_W + self.grads['U'] = d_U + self.grads['b'] = d_b + return d_inputs +``` + +🧠 **"The backward loop goes in REVERSE order of the forward loop — that's the 'through time' part of BPTT"** + +🔍 **Line 88**: `for t in range(timesteps - 1, -1, -1)` — Same reverse loop as SimpleRNN. + +🔍 **Line 89**: `dh = grad_output + dh_next` — Gradient from two sources (same as SimpleRNN). + +--- + +**Step 1: Split `dh` into contributions through the three paths of `h_t = (1-z)*h_tilde + z*h_prev`** + +The final hidden state formula is: + +$$h_t = (1 - z) \cdot \tilde{h}_t + z \cdot h_{t-1}$$ + +By the product rule, the gradient `dh` splits into three terms: + +🔍 **Line 97**: `dz = dh * (h_prev - h_tilde)` — Gradient through `z` in both `(1-z)*h_tilde` and `z*h_prev`: + +$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial h} \cdot (h_{t-1} - \tilde{h}_t)$$ + +When `h_prev > h_tilde`, increasing `z` (keeping more of old state) reduces loss, and vice versa. + +🔍 **Line 98**: `dh_tilde = dh * (1 - z)` — Gradient through `h_tilde` in the `(1-z)` path. + +🔍 **Line 99**: `dh_prev_from_h = dh * z` — Gradient through `h_prev` in the `z` path. + +📐 All three are shape `(batch, units)`. + +--- + +**Step 2: Backprop through the candidate `h_tilde = tanh(x*Wh + (r*h_prev)*Uh + bh)`** + +🔍 **Line 102**: `dtanh = dh_tilde * (1 - h_tilde**2)` — Derivative of tanh. + +📐 Shape `(batch, units)`. + +🔍 **Lines 106-108**: Gradient w.r.t. the **candidate** (last `units`) portion of the weights: + +- `d_W[:, 2*units:] += x_t.T @ dtanh` — `(features, batch) @ (batch, units)` → `(features, units)` = shape of `Wh`. +- `d_U[:, 2*units:] += (r * h_prev).T @ dtanh` — `(units, batch) @ (batch, units)` → `(units, units)` = shape of `Uh`. Note that the reset gate `r` element-wise multiplies `h_prev` before the matrix multiply. +- `d_b[2*units:] += sum(dtanh, axis=0)` — `(units,)`. + +--- + +**Step 3: Backprop through the reset gate `r = sigmoid(x*Wr + h_prev*Ur + br)`** + +🔍 **Line 110**: `dr_h_prev = dtanh @ Uh.T` — The gradient w.r.t. `(r * h_prev)`, before the element-wise multiply. + +📐 `(batch, units) @ (units, units)` → `(batch, units)`. + +🔍 **Line 111**: `dr = dr_h_prev * h_prev` — Gradient w.r.t. `r`. By the product rule of `r * h_prev`: + +$$\frac{\partial L}{\partial r} = \frac{\partial L}{\partial (r \cdot h_{prev})} \odot h_{prev}$$ + +🔍 **Line 112**: `dh_prev_from_tilde = dr_h_prev * r` — The gradient through `r * h_prev` w.r.t. `h_prev`. This is the **other half** of the product rule: if `r` is large, `h_prev` has more influence on the output. + +--- + +**Step 4: Backprop through the sigmoids for `z` and `r`** + +🔍 **Lines 115-116**: Sigmoid derivative: `dz_logits = dz * z * (1 - z)` and `dr_logits = dr * r * (1 - r)`. For sigmoid $\sigma(x)$, the derivative is $\sigma(x) \cdot (1 - \sigma(x))$. + +🔍 **Line 117**: `dzr_logits = concat([dz_logits, dr_logits])` — Concatenate back into `(batch, 2*units)` for the gate weight updates. + +--- + +**Step 5: Accumulate gradients for the gate weights** + +🔍 **Lines 120-122**: +- `d_W[:, :2*units] += x_t.T @ dzr_logits` — `(features, batch) @ (batch, 2*units)` → `(features, 2*units)` = shape of `[Wz | Wr]`. +- `d_U[:, :2*units] += h_prev.T @ dzr_logits` — `(units, batch) @ (batch, 2*units)` → `(units, 2*units)` = shape of `[Uz | Ur]`. +- `d_b[:2*units] += sum(dzr_logits, axis=0)` — `(2*units,)`. + +--- + +**Step 6: Compute `dh_next` — the gradient to pass to the previous timestep** + +🔍 **Line 124**: `dh_prev_from_gates = dzr_logits @ U[:, :2*units].T` — Gradient through the gate's recurrent connection. + +🔍 **Line 127**: `dh_next = dh_prev_from_h + dh_prev_from_tilde + dh_prev_from_gates` — The total gradient w.r.t. `h_prev` is the sum of three paths: + +1. `dh_prev_from_h` — from the `z * h_prev` direct connection +2. `dh_prev_from_tilde` — from the `(r * h_prev)` in the candidate computation +3. `dh_prev_from_gates` — from `h_prev @ U[:, :2*units]` in the gate logits + +--- + +**Step 7: Gradient w.r.t. inputs** + +🔍 **Lines 130-131**: `d_inputs[:, t, :]` has two contributions: + +1. `dzr_logits @ W[:, :2*units].T` — From the gates (Wz, Wr portions of W) +2. `dtanh @ W[:, 2*units:].T` — From the candidate (Wh portion of W) + +📐 Shape: `(batch, 2*units) @ (2*units, features)` and `(batch, units) @ (units, features)` → both `(batch, features)`, added together. + +--- + +## Try it yourself + +```python +from neutro.layers.recurrent import SimpleRNN, GRU +import numpy as np + +# SimpleRNN +rnn = SimpleRNN(units=64, return_sequences=True) +x = np.random.randn(4, 10, 32) # (batch, timesteps, features) +y = rnn(x) +print(f"RNN output: {y.shape}") # (4, 10, 64) + +# GRU +gru = GRU(units=64, return_sequences=False) +z = gru(x) +print(f"GRU output: {z.shape}") # (4, 64) + +# Backward +dL_dy = np.random.randn(4, 64) +dL_dx = gru.backward(dL_dy) +print(f"GRU input grad: {dL_dx.shape}") # (4, 10, 32) +print(f"GRU W grad: {gru.grads['W'].shape}") # (32, 192) +print(f"GRU U grad: {gru.grads['U'].shape}") # (64, 192) ``` -## References +## What to read next -- Hochreiter, S., & Schmidhuber, J. (1997). **Long Short-Term Memory**. *Neural Computation*. -- Chung, J., et al. (2014). **Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling**. [arXiv:1412.3555](https://arxiv.org/abs/1412.3555) +- **`docs/layers/recurrent/lstm.md`** — The LSTM: four gates, a cell state, and why it solves the vanishing gradient problem. +- **`docs/layers/core/dense.md`** — If you need a refresher on how `build`, `forward`, and `backward` work in simpler layers. diff --git a/docs/layers/transformer/transformer_block.md b/docs/layers/transformer/transformer_block.md index ee2caf2..e207ac5 100644 --- a/docs/layers/transformer/transformer_block.md +++ b/docs/layers/transformer/transformer_block.md @@ -1,89 +1,198 @@ -# Transformer Block +# TransformerBlock -## Theory +## What does this layer do? -The Transformer block is the fundamental building block of modern LLMs. It combines multi-head attention with a feed-forward network, residual connections, and layer normalization. +The Transformer block is the fundamental building block of LLMs like GPT and BERT. It combines self-attention with a feed-forward network, using residual connections and layer normalization to keep training stable even in very deep stacks. -### Pre-Norm Architecture +## The architecture -$$\text{output} = x + \text{FFN}(\text{LN}(x + \text{Attention}(\text{LN}(x))))$$ +Two variants: -Each sub-layer has a residual connection (`x + sublayer(x)`), which helps gradient flow during backpropagation. +### Pre-Norm (modern, used in GPT-2, Llama) -### Post-Norm Architecture (original Transformer) +``` +output = x + FFN(LayerNorm(x + Attention(LayerNorm(x)))) +``` + +### Post-Norm (original Transformer, used in BERT) -$$\text{output} = \text{LN}(x + \text{FFN}(\text{LN}(x + \text{Attention}(x))))$$ +``` +output = LayerNorm(x + FFN(LayerNorm(x + Attention(x)))) +``` -## Implementation Guide +Pre-Norm is more stable during training because the norm is applied BEFORE each sub-layer, keeping activations controlled. Post-Norm was used in the original "Attention Is All You Need" paper but drifted gradients for deep stacks. -### File: `neutro/layers/transformer/transformer_block.py` +## Walking through the code -### `__init__` — line 11 +### Step 1: `__init__` (lines 11–25) — assembling the sub-layers ```python -class TransformerBlock(Layer): - def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, - causal=False, use_flash=False, pre_norm=False, **kwargs): +def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, causal=False, use_flash=False, pre_norm=False, **kwargs): + super().__init__(**kwargs) + self.num_heads = num_heads + self.embed_dim = embed_dim + self.causal = causal + self.pre_norm = pre_norm + if use_flash: + self.att = FlashAttention(num_heads=num_heads, key_dim=embed_dim) + else: + self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) + self.ffn = [Dense(ff_dim, activation="relu"), Dense(embed_dim)] + self.layernorm1 = LayerNormalization(epsilon=1e-6) + self.layernorm2 = LayerNormalization(epsilon=1e-6) + self.dropout1 = Dropout(rate) + self.dropout2 = Dropout(rate) ``` -- `embed_dim`: model dimension (e.g., 768 for GPT-2 small). -- `num_heads`: number of attention heads (must divide `embed_dim`). -- `ff_dim`: feed-forward hidden dimension (typically 4× `embed_dim`). -- `causal`: if True, creates a causal attention mask (for autoregressive generation). -- `use_flash`: if True, uses `FlashAttention` instead of standard `MultiHeadAttention`. -- `pre_norm`: if True, uses Pre-Norm (modern); if False, uses Post-Norm (original). +🔍 **Line 17–20**: `self.att = FlashAttention(...)` or `MultiHeadAttention(...)` — the attention sub-layer. The `use_flash` flag lets you pick between regular MHA and the memory-efficient FlashAttention variant. -### Forward pass — line 42 +🔍 **Line 21**: `self.ffn = [Dense(ff_dim, activation='relu'), Dense(embed_dim)]` — a 2-layer MLP stored as a LIST. The first Dense expands from `embed_dim` to `ff_dim` with ReLU, the second projects back to `embed_dim`. -For Pre-Norm: +🔍 **Lines 22–23**: `self.layernorm1`, `self.layernorm2` — two layer norms, one before attention and one before the FFN. + +🔍 **Lines 24–25**: `self.dropout1`, `self.dropout2` — dropout applied after each sub-layer for regularization. + +🧠 "The sublayers are stored as plain attributes (including a list for `ffn`), so `Layer.sublayers` can traverse and find them all automatically for the optimizer." + +### Step 2: `build` (lines 27–37) — building the sub-layers in order ```python -norm1 = self.layernorm1(inputs, training) -attn_output = self.att(norm1, mask=mask, training=training) -h = inputs + self.dropout1(attn_output, training=training) +def build(self, input_shape): + self.att.build(input_shape) + curr_shape = input_shape + for layer in self.ffn: + layer.build(curr_shape) + curr_shape = (input_shape[0], input_shape[1], layer.units) + self.layernorm1.build(input_shape) + self.layernorm2.build(input_shape) + self.dropout1.build(input_shape) + self.dropout2.build(input_shape) + super().build(input_shape) +``` + +🔍 **Lines 29–32**: The `ffn` list is built sequentially. The first `Dense(ff_dim)` takes the embedding shape; the second `Dense(embed_dim)` takes the expanded shape. The `curr_shape` is updated after each build so the next layer knows what to expect. -norm2 = self.layernorm2(h, training=training) -ffn_output = self.ffn[1](self.ffn[0](norm2, training=training), training=training) -return h + self.dropout2(ffn_output, training=training) +🔍 **Lines 33–36**: Layer norms and dropouts take the same input shape — they don't change the dimension, they just transform values. + +### Step 3: `forward` (lines 42–80) — pre-norm vs post-norm + +First, a shared preamble (lines 43–55): + +```python +self.inputs = inputs +mask = None +if self.causal: + mask = BaseAttention.create_causal_mask(inputs.shape[1]) + if kv_cache and layer_id in kv_cache.k_cache: + q_len = inputs.shape[1] + kv_len = q_len + kv_cache.k_cache[layer_id].shape[2] + mask = np.ones((q_len, kv_len)) + mask[:, -q_len:] = BaseAttention.create_causal_mask(q_len) + mask[:, :-q_len] = 0 +``` + +🔍 **Lines 45–55**: The causal mask prevents attending to future tokens (autoregressive generation). When a KV cache is active, the mask expands to cover all past cached tokens plus the current ones — past tokens are fully visible (0 mask), current tokens use a triangular causal mask. + +**Pre-Norm path** (lines 57–68): + +```python +if self.pre_norm: + norm1 = self.layernorm1(inputs, training=training) + attn_output = self.att(norm1, mask=mask, training=training, kv_cache=kv_cache, layer_id=layer_id) + attn_dropped = self.dropout1(attn_output, training=training) + h = inputs + attn_dropped + + norm2 = self.layernorm2(h, training=training) + ffn_1 = self.ffn[0](norm2, training=training) + ffn_2 = self.ffn[1](ffn_1, training=training) + ffn_dropped = self.dropout2(ffn_2, training=training) + return h + ffn_dropped ``` -The block contains 7 sublayers: `att`, `layernorm1`, `layernorm2`, `dropout1`, `dropout2`, and two Dense layers in `ffn`. +🔍 **Line 62**: `h = inputs + attn_dropped` — the residual (skip) connection for the attention sub-layer. The original input is added back after attention + dropout, creating a direct gradient highway. -### Backward pass — line 82 +🔍 **Lines 64–67**: The FFN path: first expand with ReLU, then project back. Dropout is applied to the final FFN output. -The backward manually routes gradients through the skip connections: +🔍 **Line 68**: `return h + ffn_dropped` — the second residual connection. The output of the attention path (`h`) is added to the FFN output. + +**Post-Norm path** (lines 69–80): + +```python +else: + attn_output = self.att(inputs, mask=mask, training=training, kv_cache=kv_cache, layer_id=layer_id) + attn_dropped = self.dropout1(attn_output, training=training) + self.out1_pre_ln = inputs + attn_dropped + out1 = self.layernorm1(self.out1_pre_ln, training=training) + + ffn_1 = self.ffn[0](out1, training=training) + ffn_2 = self.ffn[1](ffn_1, training=training) + ffn_dropped = self.dropout2(ffn_2, training=training) + self.out2_pre_ln = out1 + ffn_dropped + return self.layernorm2(self.out2_pre_ln, training=training) +``` + +🔍 **Lines 71–73**: Attention is applied FIRST, then dropout, then the residual is added. Only AFTER the residual is the layer norm applied. This is the reverse order from pre-norm. + +🔍 **Lines 79–80**: Same pattern for the FFN — add residual first, then normalize at the very end. + +🔍 **Lines 73 & 79**: `self.out1_pre_ln` and `self.out2_pre_ln` are stored as attributes. These are intermediates needed by the backward pass (since post-norm applies norm after the residual, backward needs the pre-norm values). + +### Step 4: `backward` (lines 82–107) — manual skip-connection gradient routing + +**Pre-Norm backward** (lines 83–96): ```python -def backward(self, grad_output): +if self.pre_norm: grad_ffn_path = self.dropout2.backward(grad_output) grad_ffn = self.ffn[1].backward(grad_ffn_path) grad_ffn = self.ffn[0].backward(grad_ffn) grad_norm2 = self.layernorm2.backward(grad_ffn) - grad_h = grad_output + grad_norm2 # Skip connection + grad_h = grad_output + grad_norm2 grad_attn_path = self.dropout1.backward(grad_h) grad_attn = self.att.backward(grad_attn_path) grad_norm1 = self.layernorm1.backward(grad_attn) - return grad_h + grad_norm1 # Skip connection + return grad_h + grad_norm1 ``` -### Sub-layers +🔍 **Line 90**: `grad_h = grad_output + grad_norm2` — the skip connection in backward. `grad_output` is the gradient flowing directly through the skip connection (bypassing the FFN). `grad_norm2` is the gradient that came through the FFN path. They sum at the branch point. This is what makes training deep networks possible! -The block exposes its sublayers via the `sublayers` property, which is critical for: -- **Optimizer**: `_get_all_layers` finds them for parameter updates. -- **Shared layer state**: `_capture_layer_state` saves their internal state (inputs, z, etc.) per node. +🔍 **Line 96**: `return grad_h + grad_norm1` — same pattern again. The gradient splits at the first residual connection: one copy goes straight back to the input (via the skip), the other goes through the attention sub-layer. -## Usage Example +🔍 "Without skip connections, gradients would shrink through every layer (vanishing gradient). With skips, the gradient has a direct path from output to input — notice how `grad_output` appears directly in the return value on line 96." + +**Post-Norm backward** (lines 97–107): ```python -from neutro.layers.transformer import TransformerBlock +else: + grad = self.layernorm2.backward(grad_output) + grad_ffn_path = self.dropout2.backward(grad) + grad_ffn = self.ffn[1].backward(grad_ffn_path) + grad_ffn = self.ffn[0].backward(grad_ffn) + grad_out1 = self.layernorm1.backward(grad + grad_ffn) -block = TransformerBlock(embed_dim=512, num_heads=8, ff_dim=2048, pre_norm=True) -x = np.random.randn(2, 16, 512) # (batch, seq, embed) -y = block(x) # Same shape + grad_attn_path = self.dropout1.backward(grad_out1) + grad_attn = self.att.backward(grad_attn_path) + return grad_out1 + grad_attn ``` -## References +🔍 **Line 99**: The first thing backward does is go through `layernorm2` — the reverse of the forward order where layernorm was the last step. + +🔍 **Line 103**: `grad + grad_ffn` combines the gradient from the second residual connection before passing through `layernorm1`'s backward. + +🔍 **Line 107**: `return grad_out1 + grad_attn` — the gradient from the first residual connection. Notice how the post-norm backward stores intermediates (`self.out1_pre_ln`) that were computed during forward — this is needed because `LayerNormalization.backward` uses the pre-norm input to compute its gradient. + +## Putting it all together + +- The `TransformerBlock` contains 7 sublayers: 1 attention layer, 2 Dense layers (in a list), 2 layer norms, and 2 dropout layers. +- These sublayers are automatically discovered by `Layer.sublayers` because they're stored as attributes. +- They're built individually in `build` since each expects a different input shape (especially the two Dense layers in `ffn`). +- Parameters are collected by the optimizer via `_get_all_layers`, which traverses the sublayer tree. +- The block handles routing between sublayers manually (forward calls each one in sequence; backward reverses the sequence) rather than delegating to a `Model`. + +## Common patterns -- Vaswani, A., et al. (2017). **Attention Is All You Need**. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) -- Pre-Norm: GPT-2 / Llama architecture variant. +- **"The `ffn` is a `list` of `Dense` layers"** — this is why `Layer.sublayers` has special handling to iterate over lists found as attributes. +- **"Dropout is only active during training"** — the `training` flag is threaded through every sub-layer call. At inference, dropout is a no-op. +- **"The KV cache is optional"** — it's only populated and used during autoregressive generation. During training, `kv_cache` is `None`. +- **"Pre-norm vs post-norm changes the forward AND backward pass"** — both the data flow and gradient flow are inverted. Post-norm also needs intermediate buffers (`out1_pre_ln`, `out2_pre_ln`) for the backward pass. From 4cd5642735851b983ac90e1df014319ba22ce058 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 May 2026 18:08:53 +0000 Subject: [PATCH 4/4] Fix review comments across docs and functional gradients --- docs/activations/activations.md | 4 +- docs/data/data.md | 29 +++++++++++---- docs/models/model.md | 22 +++++------ docs/preprocessing/preprocessing.md | 2 +- docs/utils/utils.md | 2 +- neutro/layers/core/merging.py | 2 +- neutro/models/base_model.py | 57 ++++++++++++++++++++--------- tests/test_functional_api.py | 1 + 8 files changed, 78 insertions(+), 41 deletions(-) diff --git a/docs/activations/activations.md b/docs/activations/activations.md index 9baf2bf..6e6dc02 100644 --- a/docs/activations/activations.md +++ b/docs/activations/activations.md @@ -21,7 +21,7 @@ $$\sigma'(x) = \sigma(x)(1 - \sigma(x))$$ - Output range: (0, 1). Used for binary classification or as gating mechanism (LSTM, GRU). - **Vanishing gradient**: for very large or very small inputs, the gradient approaches 0. -### Tanh — `neutro/activations/tanh.py}$ +### Tanh — `neutro/activations/tanh.py` $$\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$$ @@ -36,7 +36,7 @@ $$\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$ - Output: probability distribution over classes. - **Jacobian-Vector Product** (`gradient_fast`, line 18): computes $y * (\text{grad\_output} - \sum(y * \text{grad\_output}))$ without building the full $N \times N$ Jacobian. -### SiLU — `neutro/activations/silu.py$ (Sigmoid Linear Unit) +### SiLU — `neutro/activations/silu.py` (Sigmoid Linear Unit) $$\text{SiLU}(x) = x \cdot \sigma(x)$$ diff --git a/docs/data/data.md b/docs/data/data.md index dd89dff..4e8b57c 100644 --- a/docs/data/data.md +++ b/docs/data/data.md @@ -6,16 +6,29 @@ A simple data loader for batching and shuffling: ```python class DataLoader: - def __init__(self, dataset, batch_size=32, shuffle=True): - self.dataset = dataset + def __init__(self, x, y, batch_size=32, shuffle=True, augmenter=None): + self.x = x + self.y = y self.batch_size = batch_size self.shuffle = shuffle + self.augmenter = augmenter + self.indices = np.arange(len(x)) + self.on_epoch_end() - def __iter__(self): - indices = np.arange(len(self.dataset)) + def __len__(self): + return int(np.ceil(len(self.x) / self.batch_size)) + + def on_epoch_end(self): if self.shuffle: - np.random.shuffle(indices) - for i in range(0, len(indices), self.batch_size): - batch_idx = indices[i:i + self.batch_size] - yield self.dataset[batch_idx] + np.random.shuffle(self.indices) + + def __getitem__(self, index): + batch_idx = self.indices[index * self.batch_size:(index + 1) * self.batch_size] + batch_x, batch_y = self.x[batch_idx], self.y[batch_idx] + return batch_x, batch_y + + def __iter__(self): + for i in range(len(self)): + yield self[i] + self.on_epoch_end() ``` diff --git a/docs/models/model.md b/docs/models/model.md index e1c0fd5..51fe891 100644 --- a/docs/models/model.md +++ b/docs/models/model.md @@ -32,7 +32,7 @@ Override `Model` directly (write your own `forward` and `backward`). Used for ar ### File: `neutro/models/base_model.py` -### `Model.__init__` — line 10 +### `Model.__init__` ```python class Model(Layer): @@ -51,7 +51,7 @@ class Model(Layer): - `Model` inherits from `Layer`, enabling nested models. - If `inputs` and `outputs` are provided, this is a **Functional API** model and the graph is discovered immediately. -### Graph Discovery (`_init_graph`) — line 25 +### Graph Discovery (`_init_graph`) ```python def traverse(tensor): @@ -69,9 +69,9 @@ def traverse(tensor): This is a **post-order DFS** starting from the output tensors. The resulting `_nodes_ordered` is in **forward execution order** (inputs first). The backward pass iterates `reversed(_nodes_ordered)`. -Unique layers are collected from the nodes (line 60): `if node.layer not in self.layers`. +Unique layers are collected from the nodes: `if node.layer not in self.layers`. -### Forward Pass — line 203 +### Forward Pass For Functional API models: @@ -97,7 +97,7 @@ for node in self._nodes_ordered: - `node.state` is captured **after** `forward` runs, ensuring it stores the state from this specific call (not stale data from a previous call). - The captured state uses `_capture_layer_state` which recurses into sublayers. -### Backward Pass — line 297 +### Backward Pass ```python grad_map = {} @@ -137,7 +137,7 @@ for node in reversed(self._nodes_ordered): - **State restoration**: Each node's captured state (from forward) is restored before its backward call, ensuring correct intermediate values (inputs, z, etc.). - **Branching support**: If one tensor feeds into multiple downstream layers, `grad_map[t_id] += grad_inputs[i]` sums the gradients (the natural behavior for Add-branching). -### Shared Layer State Management — line 80 +### Shared Layer State Management ```python @staticmethod @@ -159,7 +159,7 @@ def _capture_layer_state(layer): This recursively captures the `__dict__` of every sublayer, keyed by `id()`. Excluded keys (`params`, `grads`, `built`, `input_shape`, etc.) are persistent architectural attributes that should not be restored. -### The `fit` Method — line 127 +### The `fit` Method Supports three input modes: 1. **Single array**: `fit(x, y)` — standard training. @@ -170,11 +170,11 @@ Supports three input modes: Loss is summed across multiple outputs (matching Keras behavior): `batch_loss = sum(self.loss_fn(y_batch[j], output[j])`. -### `evaluate` — line 454 +### `evaluate` Similarly handles MIMO: sums losses across outputs, falls back gracefully for metrics. -### `summary` — line 462 +### `summary` For Functional API models, displays a "Connected to" column showing each layer's upstream dependencies: @@ -183,11 +183,11 @@ Layer (type) Output Shape Param # Connected to Add (Add) (None, 32) 0 input1, input2 ``` -### `_get_all_layers` — line 72 +### `_get_all_layers` Returns all unique layer instances (deduplicated by `id()`) across the entire layer hierarchy, including sublayers. Used by the optimizer to update parameters. -### `Sequential` — line 545 +### `Sequential` ```python class Sequential(Model): diff --git a/docs/preprocessing/preprocessing.md b/docs/preprocessing/preprocessing.md index 412ab02..bd9b527 100644 --- a/docs/preprocessing/preprocessing.md +++ b/docs/preprocessing/preprocessing.md @@ -39,7 +39,7 @@ def text_to_word_sequence(text, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', ... ``` -## Sequence Preprocessing — `neutro/preprocessing/sequence.py$ +## Sequence Preprocessing — `neutro/preprocessing/sequence.py` ### pad_sequences diff --git a/docs/utils/utils.md b/docs/utils/utils.md index f47b2a6..77bc9dd 100644 --- a/docs/utils/utils.md +++ b/docs/utils/utils.md @@ -21,7 +21,7 @@ def col2im(grad_cols, input_shape, kernel_size, strides): Used by `Conv2D` and `MaxPooling2D` for efficient forward/backward computation. -## rope_utils — `neutro/utils/rope_utils.py$ +## rope_utils — `neutro/utils/rope_utils.py` ### Rotary Position Embedding (RoPE) diff --git a/neutro/layers/core/merging.py b/neutro/layers/core/merging.py index 6160428..dffb3b1 100644 --- a/neutro/layers/core/merging.py +++ b/neutro/layers/core/merging.py @@ -176,7 +176,7 @@ def backward(self, grad_output): class Minimum(Layer): """ - Layer that computes the maximum (element-wise) of a list of inputs. + Layer that computes the minimum (element-wise) of a list of inputs. """ def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/neutro/models/base_model.py b/neutro/models/base_model.py index 34b441f..fcccb10 100644 --- a/neutro/models/base_model.py +++ b/neutro/models/base_model.py @@ -124,6 +124,42 @@ def _restore_layer_state(layer, state): for sl in l.sublayers: stack.append(sl) + @staticmethod + def _clear_layer_grads(layer): + stack = [layer] + visited = set() + while stack: + l = stack.pop() + l_id = id(l) + if l_id in visited: + continue + visited.add(l_id) + l.grads = {} + for sl in l.sublayers: + stack.append(sl) + + @staticmethod + def _accumulate_layer_grads(layer, grads_accumulator): + stack = [layer] + visited = set() + while stack: + l = stack.pop() + l_id = id(l) + if l_id in visited: + continue + visited.add(l_id) + + layer_acc = grads_accumulator.setdefault(l_id, {}) + for k, v in l.grads.items(): + if k in layer_acc: + layer_acc[k] += v + else: + layer_acc[k] = np.array(v, copy=True) + l.grads = layer_acc + + for sl in l.sublayers: + stack.append(sl) + def fit(self, x, y=None, epochs=1, batch_size=32, verbose=1, validation_data=None, callbacks=None): is_mimo_x = isinstance(x, list) is_mimo_y = isinstance(y, list) @@ -411,25 +447,12 @@ def backward(self, grad): self._restore_layer_state(node.layer, node.state) # Call layer.backward - # Temporarily clear layer.grads to capture only gradients for this node - original_grads = node.layer.grads - node.layer.grads = {} - + # Temporarily clear layer-tree grads to capture only this node call + self._clear_layer_grads(node.layer) grad_inputs = node.layer.backward(node_grad_outputs) - # Accumulate parameter gradients - l_id = id(node.layer) - if l_id not in layer_grads_accumulator: - layer_grads_accumulator[l_id] = {} - - for k, v in node.layer.grads.items(): - if k in layer_grads_accumulator[l_id]: - layer_grads_accumulator[l_id][k] += v - else: - layer_grads_accumulator[l_id][k] = v - - # Restore the combined gradients to the layer - node.layer.grads = layer_grads_accumulator[l_id] + # Accumulate parameter gradients across the full layer tree + self._accumulate_layer_grads(node.layer, layer_grads_accumulator) # Propagate gradients to inputs if isinstance(node.input_tensors, list): diff --git a/tests/test_functional_api.py b/tests/test_functional_api.py index 5dfc230..f81b157 100644 --- a/tests/test_functional_api.py +++ b/tests/test_functional_api.py @@ -146,6 +146,7 @@ def test_functional_gradients(): W[i, j] = orig_val num_grad = (loss_plus - loss_minus) / (2 * eps) + assert np.isclose(dW[i, j], num_grad, rtol=1e-4, atol=1e-5) def test_complex_summary(): """Test summary() for a multi-input, multi-output model."""