From 6a02a74968b811a2e99db71672bc2066ef9432e3 Mon Sep 17 00:00:00 2001 From: Jingcheng Yu Date: Thu, 5 Feb 2026 14:48:11 +0800 Subject: [PATCH] add softmax_cap parameter to pivotal attention --- CHANGELOG.md | 4 ++++ README.md | 6 +++--- pyproject.toml | 2 +- src/floydnet/functional.py | 16 +++++++++++++++- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74c6e23..a235c9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ All notable changes to this project will be documented in this file. +## [1.1.0] - 2026-02-05 +- Added `softmax_cap` parameter to `pivotal_attention` and `pivotal_attention3` for improved numerical stability. +- Added LRGB example script. + ## [1.0.0] - 2026-01-25 - Full release with training and evaluation scripts for Graph Count, BREC, and TSP. - Added `pivotal_attention3` functional API for 3-Floyd attention. diff --git a/README.md b/README.md index b2d601b..159c021 100644 --- a/README.md +++ b/README.md @@ -128,9 +128,9 @@ uv pip install -e . ## Changelog (latest) -- Full release with training and evaluation scripts for Graph Count, BREC, and TSP. -- Added `pivotal_attention3` functional API for 3-Floyd attention. -- Added additional configuration options in `PivotalAttentionBlock`. +- Added `softmax_cap` parameter to `pivotal_attention` and `pivotal_attention3` for improved numerical stability. +- Added LRGB example script. + The full changelog is in [CHANGELOG.md](CHANGELOG.md). 
diff --git a/pyproject.toml b/pyproject.toml index 41d121c..315457b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "floydnet" -version = "1.0.0" +version = "1.1.0" description = "Floyd Multi-Head Attention: a drop-in variant of PyTorch MHA with module and function APIs" readme = "README.md" requires-python = ">=3.9" diff --git a/src/floydnet/functional.py b/src/floydnet/functional.py index db68b54..6ae761d 100644 --- a/src/floydnet/functional.py +++ b/src/floydnet/functional.py @@ -31,6 +31,7 @@ def pivotal_attention( dropout: float = 0.0, scale: Optional[float] = None, inf: float = 1e9, + softmax_cap: float = -1, ) -> torch.Tensor: """Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING". @@ -47,6 +48,9 @@ def pivotal_attention( dropout: Dropout probability applied to attention weights (only effective if > 0). scale: Optional custom scaling factor. If None, defaults to 1/sqrt(2*D). inf: Value to use for -infinity in masks. + softmax_cap: If > 0, applies a tanh-based logit cap before softmax. + Note: when using a non-boolean (additive) attn_mask, ensure its magnitude/semantics remain compatible + with capping (e.g., very large negative values used to approximate -inf can interact with logit shaping). 
Returns: Tensor of shape (B, H, L_i, L_k, D) @@ -65,6 +69,9 @@ def pivotal_attention( attn_scores = torch.einsum("bhikd,bhijd->bhikj", q_ik, k_ij) \ + torch.einsum("bhikd,bhjkd->bhikj", q_ik, k_jk) + if softmax_cap > 0: + attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap) + if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_scores = attn_scores.masked_fill(attn_mask, -inf) @@ -93,6 +100,7 @@ def pivotal_attention3( dropout: float = 0.0, scale: Optional[float] = None, inf: float = 1e9, + softmax_cap: float = -1, ) -> torch.Tensor: """3-Pivotal attention as described in "FLOYDNET: A LEARNING PARADIGM FOR GLOBAL RELATIONAL REASONING". @@ -111,9 +119,12 @@ def pivotal_attention3( dropout: Dropout probability applied to attention weights (only effective if > 0). scale: Optional custom scaling factor. If None, defaults to 1/sqrt(3*D). inf: Value to use for -infinity in masks. + softmax_cap: If > 0, applies a tanh-based logit cap before softmax. + Note: when using a non-boolean (additive) attn_mask, ensure its magnitude/semantics remain compatible + with capping (e.g., very large negative values used to approximate -inf can interact with logit shaping). Returns: - Tensor of shape (B, H, L_i, l_j, L_k, D) + Tensor of shape (B, H, L_i, L_j, L_k, D) """ assert all([t.dim() == 6 for t in [q_ijk, k_pjk, k_ipk, k_ijp, v_pjk, v_ipk, v_ijp]]), "All inputs must be 6D tensors" B, H, L_i, L_j, L_k, D = q_ijk.shape @@ -130,6 +141,9 @@ def pivotal_attention3( attn_scores = torch.einsum("bhijkd,bhpjkd->bhijkp", q_ijk, k_pjk) \ + torch.einsum("bhijkd,bhipkd->bhijkp", q_ijk, k_ipk) \ + torch.einsum("bhijkd,bhijpd->bhijkp", q_ijk, k_ijp) + + if softmax_cap > 0: + attn_scores = softmax_cap * torch.tanh(attn_scores / softmax_cap) if attn_mask is not None: if attn_mask.dtype == torch.bool: