Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions nn/prefix_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,39 @@ class CausalSequenceMixer:
avoid a circular import at module load.
"""

def __init__(self, d: int):
def __init__(self, d: int, proj_x_init_scale: float = 1.0):
"""Build the mixer.

Args:
d: hidden dimension.
proj_x_init_scale: multiplicative scale on the default Xavier
init of ``proj_x``. Use values <1 to tame the recurrence's
accumulated gain over a long sequence — for a sequence of
length S the effective amplification of x_t through the
Linear-RNN can be up to ~S, so very deep stacks need
``proj_x_init_scale ≈ 1/sqrt(S)`` to keep the residual
stream stable. Defaults to 1.0 to preserve old behavior.
"""
from grilly import nn

self.d = d
self.proj_x = nn.Linear(d, d, bias=False)
self.proj_a = nn.Linear(d, d, bias=True)

# Optionally shrink proj_x weights so the recurrence's gain stays O(1)
# over the full sequence length. The mixer accumulates x_t through
# ``h_t = a_t * h_{t-1} + x_t``, so even with a_t < 1 the steady-state
# ‖h‖ scales like ‖x_t‖ / (1 - a). At a_max = 0.95 that's a 20x
# amplification — without this rescale the residual stream blows up
# in deep stacks (12+ layers).
if proj_x_init_scale != 1.0:
try:
w_arr = np.asarray(self.proj_x.weight.data if hasattr(self.proj_x.weight, "data")
else self.proj_x.weight)
w_arr *= float(proj_x_init_scale)
except Exception:
pass

# Initialize the gate bias to +1 so sigmoid(1) ≈ 0.73 — the model
# starts out "remembering" most of the hidden state at t=0, which
# matches the LiquidCell behavior the old code defaulted to.
Expand All @@ -164,17 +191,31 @@ def __call__(self, x):
# x: (B, S, D) — a Variable / Tensor from upstream.
x_t = self.proj_x(x)

# Sigmoid without importing torch_api: use the identity
# sigmoid(z) = 0.5 * (1 + tanh(z/2))
# which is numerically stable and uses only tanh, which grilly's
# autograd exposes via Variable.tanh().
# Bounded gate in [0.05, 0.95] using tanh, NOT sigmoid.
#
# Why bounded: ``prefix_scan_causal`` runs the scan in log-space:
# ``subgroupInclusiveAdd(log(a))``. If ``a_t`` ever lands very near
# 0, ``log(a_t) -> -inf``, the partial sum overflows, and the
# forward output of the scan is nan. Sigmoid (which the previous
# version used) has range (0, 1), and as the projection weights grow
# during training some hidden dims drift to a_logits ~ -20, where
# ``sigmoid(-20) ~ 2e-9``. ``log(2e-9) ~ -20``, accumulated 32 times
# over a sequence = -640, and ``exp(-640) = 0`` -> divide-by-zero
# in the scan rescaling. Tanh-bounded gate (range [0.05, 0.95])
# gives ``log(0.05) = -3.0`` even at saturation, which is safe.
a_logits = self.proj_a(x)
if hasattr(a_logits, "tanh"):
a_t = 0.5 * (1.0 + (a_logits * 0.5).tanh())
# Use the same expression structure as before (Variable arithmetic
# exposed by grilly autograd: ``var * scalar``, ``scalar + var``,
# ``var.tanh()``). Build sigmoid via tanh, then squash into
# ``[0.05, 0.95]`` so log(a_t) >= -3.0 and the prefix scan stays
# numerically safe.
sig = 0.5 * (1.0 + (a_logits * 0.5).tanh())
a_t = 0.05 + sig * 0.9
else:
# Fallback for plain ndarray — upstream should be a Variable.
import numpy as _np
a_t = 0.5 * (1.0 + _np.tanh(_np.asarray(a_logits) * 0.5))
sig = 0.5 * (1.0 + _np.tanh(_np.asarray(a_logits) * 0.5))
a_t = 0.05 + sig * 0.9

h = prefix_scan_causal(x_t, a_t)
return h
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "grilly"
version = "0.6.1"
version = "1.0.0-rc1"
description = "GPU-accelerated neural network operations using Vulkan compute shaders"
readme = "README.md"
license = {text = "MIT"}
Expand Down
Loading