From 10c7c20b4a9040981a0ca16d9330a0d1f88e8b6c Mon Sep 17 00:00:00 2001 From: Grill cheese Date: Fri, 10 Apr 2026 12:57:16 -0400 Subject: [PATCH] v1.0.0-rc1 --- nn/prefix_scan.py | 57 ++++++++++++++++++++++++++++++++++++++++------- pyproject.toml | 2 +- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/nn/prefix_scan.py b/nn/prefix_scan.py index 7c88894..a61a38d 100644 --- a/nn/prefix_scan.py +++ b/nn/prefix_scan.py @@ -140,12 +140,39 @@ class CausalSequenceMixer: avoid a circular import at module load. """ - def __init__(self, d: int): + def __init__(self, d: int, proj_x_init_scale: float = 1.0): + """Build the mixer. + + Args: + d: hidden dimension. + proj_x_init_scale: multiplicative scale on the default Xavier + init of ``proj_x``. Use values <1 to tame the recurrence's + accumulated gain over a long sequence — for a sequence of + length S the effective amplification of x_t through the + Linear-RNN can be up to ~S, so very deep stacks need + ``proj_x_init_scale ≈ 1/sqrt(S)`` to keep the residual + stream stable. Defaults to 1.0 to preserve old behavior. + """ from grilly import nn self.d = d self.proj_x = nn.Linear(d, d, bias=False) self.proj_a = nn.Linear(d, d, bias=True) + + # Optionally shrink proj_x weights so the recurrence's gain stays O(1) + # over the full sequence length. The mixer accumulates x_t through + # ``h_t = a_t * h_{t-1} + x_t``, so even with a_t < 1 the steady-state + # ‖h‖ scales like ‖x_t‖ / (1 - a). At a_max = 0.95 that's a 20x + # amplification — without this rescale the residual stream blows up + # in deep stacks (12+ layers).
+            if proj_x_init_scale != 1.0: +                try: +                    w_arr = np.asarray(self.proj_x.weight.data if hasattr(self.proj_x.weight, "data") +                                       else self.proj_x.weight) +                    w_arr *= float(proj_x_init_scale) +                except Exception: +                    pass + # Initialize the gate bias to +1 so sigmoid(1) ≈ 0.73 — the model # starts out "remembering" most of the hidden state at t=0, which # matches the LiquidCell behavior the old code defaulted to. @@ -164,17 +191,31 @@ def __call__(self, x): # x: (B, S, D) — a Variable / Tensor from upstream. x_t = self.proj_x(x) - # Sigmoid without importing torch_api: use the identity - # sigmoid(z) = 0.5 * (1 + tanh(z/2)) - # which is numerically stable and uses only tanh, which grilly's - # autograd exposes via Variable.tanh(). + # Bounded gate in [0.05, 0.95] using tanh, NOT sigmoid. + # + # Why bounded: ``prefix_scan_causal`` runs the scan in log-space: + # ``subgroupInclusiveAdd(log(a))``. If ``a_t`` ever lands very near + # 0, ``log(a_t) -> -inf``, the partial sum overflows, and the + # forward output of the scan is nan. Sigmoid (which the previous + # version used) has range (0, 1), and as the projection weights grow + # during training some hidden dims drift to a_logits ~ -20, where + # ``sigmoid(-20) ~ 2e-9``. ``log(2e-9) ~ -20``, accumulated 32 times + # over a sequence = -640, and ``exp(-640) = 0`` -> divide-by-zero + # in the scan rescaling. Tanh-bounded gate (range [0.05, 0.95]) + # gives ``log(0.05) = -3.0`` even at saturation, which is safe. a_logits = self.proj_a(x) if hasattr(a_logits, "tanh"): - a_t = 0.5 * (1.0 + (a_logits * 0.5).tanh()) + # Use the same expression structure as before (Variable arithmetic + # exposed by grilly autograd: ``var * scalar``, ``scalar + var``, + # ``var.tanh()``). Build sigmoid via tanh, then squash into + # ``[0.05, 0.95]`` so log(a_t) >= -3.0 and the prefix scan stays + # numerically safe.
+ sig = 0.5 * (1.0 + (a_logits * 0.5).tanh()) + a_t = 0.05 + sig * 0.9 else: - # Fallback for plain ndarray — upstream should be a Variable. import numpy as _np - a_t = 0.5 * (1.0 + _np.tanh(_np.asarray(a_logits) * 0.5)) + sig = 0.5 * (1.0 + _np.tanh(_np.asarray(a_logits) * 0.5)) + a_t = 0.05 + sig * 0.9 h = prefix_scan_causal(x_t, a_t) return h diff --git a/pyproject.toml b/pyproject.toml index 309bd6a..0ece5c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "setuptools.build_meta" [project] name = "grilly" -version = "0.6.1" +version = "1.0.0-rc1" description = "GPU-accelerated neural network operations using Vulkan compute shaders" readme = "README.md" license = {text = "MIT"}