diff --git a/neutro/activations/softmax.py b/neutro/activations/softmax.py index 60ec938..73ef494 100644 --- a/neutro/activations/softmax.py +++ b/neutro/activations/softmax.py @@ -7,16 +7,19 @@ def __call__(self, x): self.last_output = exps / np.sum(exps, axis=-1, keepdims=True) return self.last_output def gradient(self, x): - return self.last_output * (1 - self.last_output) + # Softmax does not have a valid element-wise gradient because its + # Jacobian is a full matrix (diag(s) - s*s^T), not a diagonal one. + # s*(1-s) is only the diagonal and produces incorrect gradients. + # Always use gradient_fast(x, grad_output) instead. + raise NotImplementedError( + "Softmax.gradient() is not implemented because the softmax Jacobian " + "is not diagonal. Use gradient_fast(x, grad_output) for correct " + "chain-rule gradients: dL/dx = s * (g - dot(s, g))." + ) def gradient_fast(self, x, grad_output): - orig_shape = grad_output.shape - grad_flat = grad_output.reshape(-1, orig_shape[-1]) - out_flat = self.last_output.reshape(-1, orig_shape[-1]) - - n_samples, units = grad_flat.shape - res = np.zeros_like(grad_flat) - for i in range(n_samples): - s = out_flat[i].reshape(-1, 1) - jacobian = np.diagflat(s) - np.dot(s, s.T) - res[i] = np.dot(grad_flat[i], jacobian) - return res.reshape(orig_shape) + # Vectorized softmax backward: dL/dx = s * (g - dot(s, g)) + # Derived from the full Jacobian J = diag(s) - s*s^T: + # dL/dx_k = s_k * (g_k - sum_i(s_i * g_i)) + s = self.last_output + dot = np.sum(s * grad_output, axis=-1, keepdims=True) + return s * (grad_output - dot) diff --git a/neutro/layers/normalization/rmsnorm.py b/neutro/layers/normalization/rmsnorm.py index 5e472d2..acf1a81 100644 --- a/neutro/layers/normalization/rmsnorm.py +++ b/neutro/layers/normalization/rmsnorm.py @@ -24,19 +24,23 @@ def forward(self, x, training=False): return self.x_norm * self.params['weight'] def backward(self, grad_output): - # Naive but educational implementation of RMSNorm backward - # dW - self.grads['weight'] = np.sum(grad_output * self.x_norm, axis=(0, 1)) - + # dW: sum over all axes except the last (feature) axis so the result + # has the same shape as params['weight'] regardless of input rank + # (works for 2-D (batch, dim), 3-D (batch, seq, dim), etc.) + feature_axes = tuple(range(len(grad_output.shape) - 1)) + self.grads['weight'] = np.sum(grad_output * self.x_norm, axis=feature_axes) + # dX N = self.dim grad_x_norm = grad_output * self.params['weight'] - + # Backward through: x / sqrt(mean(x^2) + eps) - # Detailed derivation for the old folks: - # dx = (grad_x_norm / rms) - (x * sum(grad_x_norm * x) / (N * rms^3)) - + # Derivation: + # rms = sqrt(mean(x^2) + eps), d_rms/dx_j = x_j / (N * rms) + # d_xnorm_i/dx_j = delta_ij/rms - x_i*x_j / (N * rms^3) + # dL/dx_j = grad_x_norm_j/rms - x_j * sum(grad_x_norm*x) / (N * rms^3) + sum_grad_x = np.sum(grad_x_norm * self.x, axis=-1, keepdims=True) dx = (grad_x_norm / self.rms) - (self.x * sum_grad_x / (N * self.rms**3)) - + return dx diff --git a/neutro/layers/recurrent/simple_rnn.py b/neutro/layers/recurrent/simple_rnn.py index 4825025..3e3c6b2 100644 --- a/neutro/layers/recurrent/simple_rnn.py +++ b/neutro/layers/recurrent/simple_rnn.py @@ -45,7 +45,13 @@ def backward(self, grad_output): for t in range(timesteps - 1, -1, -1): dh = (grad_output[:, t, :] if self.return_sequences else (grad_output if t == timesteps - 1 else 0)) + dh_next - dz = dh * (1 - self.h_states[:, t+1, :]**2) + # Apply the derivative of the hidden-state activation. + # tanh: d(tanh(z))/dz = 1 - tanh(z)^2 = 1 - h_t^2 + # linear: d(z)/dz = 1, so dz = dh + if self.activation_name == 'tanh': + dz = dh * (1 - self.h_states[:, t+1, :]**2) + else: + dz = dh d_Wx += np.dot(self.inputs[:, t, :].T, dz) d_Wh += np.dot(self.h_states[:, t, :].T, dz) d_b += np.sum(dz, axis=0) diff --git a/tests/activations/test_softmax.py b/tests/activations/test_softmax.py index 68f1dac..47e1861 100644 --- a/tests/activations/test_softmax.py +++ b/tests/activations/test_softmax.py @@ -10,20 +10,42 @@ def test_softmax_forward(): assert np.allclose(np.sum(out), 1.0) assert out[0, 2] > out[0, 1] > out[0, 0] -def test_softmax_gradient(): +def test_softmax_gradient_raises(): + # Softmax.gradient() is intentionally not implemented because s*(1-s) is + # only the diagonal of the Jacobian and gives incorrect chain-rule results. softmax = Softmax() x = np.array([[1, 2, 3]], dtype=float) softmax(x) - grad = softmax.gradient(x) - assert grad.shape == x.shape + with pytest.raises(NotImplementedError): + softmax.gradient(x) def test_softmax_gradient_fast(): + # gradient_fast implements the correct vectorised backward: + # dL/dx = s * (g - dot(s, g)) + # which is equivalent to the full Jacobian J = diag(s) - s*s^T applied to g. softmax = Softmax() - x = np.array([[1, 2]], dtype=float) - out = softmax(x) - grad_output = np.array([[1, 0]], dtype=float) + x = np.array([[1.0, 2.0, 3.0]], dtype=float) + s = softmax(x) # populates last_output + + grad_output = np.array([[1.0, 0.0, 0.0]], dtype=float) + + # Expected: s * (g - dot(s, g)) for each sample + dot = np.sum(s * grad_output, axis=-1, keepdims=True) + expected = s * (grad_output - dot) + grad = softmax.gradient_fast(x, grad_output) assert grad.shape == x.shape + assert np.allclose(grad, expected) + +def test_softmax_gradient_fast_sum_zero(): + # A key property: the gradient of softmax sums to zero along the last axis + # (because softmax output sums to 1, the Jacobian rows sum to 0). + softmax = Softmax() + x = np.array([[0.5, 1.5, -0.5, 2.0]], dtype=float) + softmax(x) + grad_output = np.random.default_rng(42).standard_normal(x.shape) + grad = softmax.gradient_fast(x, grad_output) + assert np.allclose(grad.sum(axis=-1), 0.0, atol=1e-12) def test_get_activation(): assert isinstance(get('relu'), ReLU)